From 3b96a6f0f440b5ededd21901d48e71896779d968 Mon Sep 17 00:00:00 2001
From: Krzysztof Rudnicki
Date: Thu, 8 Jun 2023 16:20:20 +0200
Subject: [PATCH] feat: added auto mode

---
Reviewer notes (text between the three-dash line and the diff is not part
of the commit message):

 * AUTO was unpacked from handle_arguments() but never used; it is now
   forwarded to preprocessing() and predict(), whose new `auto` parameter
   is the last positional parameter in both signatures.
 * DATA_LIMT renamed to DATA_LIMIT (typo).
 * predict(): the recommendation line was wrapped with a triple-quoted
   f-string, which puts a newline and the continuation indent into the
   printed text. Replaced by two adjacent f-string literals so the output
   is the same single line as before this patch.
 * --auto: `type=bool` calls bool() on the raw string, so any non-empty
   value (including "False") enables auto mode. Changed to a
   `store_true` flag; the option is new in this patch, so no existing
   invocation is affected.
 * Left untouched because they are unchanged context: --debug has the
   same `type=bool` pitfall, and --neighbors has no `type=int`, so a value
   given on the command line reaches create_model() as a str.
 * NOTE(review): line breaks, blank lines and indentation were
   reconstructed from a whitespace-collapsed copy of this patch to match
   the hunk header counts; run `git apply --check` before applying. The
   post-image hash on the `index` line belongs to the original submission.

 final/code/main.py | 78 +++++++++++++++++++++++++++++-----------------
 1 file changed, 49 insertions(+), 29 deletions(-)

diff --git a/final/code/main.py b/final/code/main.py
index cc9b662a..bdfe7180 100644
--- a/final/code/main.py
+++ b/final/code/main.py
@@ -2,12 +2,9 @@
 Code for preprocessing data and creating model that predicts and recomends anime
 based on another anime entered by user
 """
-import pandas as pd
-
-import numpy as np
 import argparse
-
-import sklearn
+import pandas as pd
+import numpy as np
 from sklearn.neighbors import NearestNeighbors
 from scipy.sparse import csr_matrix
 
@@ -138,7 +135,8 @@ def get_data_info(rating_data, debug=False):
             f"Min total rating: {smallest_rating}, Max total rating: {highest_rating}")
 
 
-def preprocessing(rating_data, anime_contact_data, debug=False, user_threshold=500, anime_threshold=200):
+def preprocessing(rating_data, anime_contact_data,
+                  debug=False, user_threshold=500, anime_threshold=200, auto=False):
     """
     Preprocesses data for making model more accurate and/or faster
     """
@@ -151,19 +149,19 @@ def preprocessing(rating_data, anime_contact_data, debug=False, user_threshold=5
     rating_data = rating_data.drop(columns="rating_x")
     rating_data = rating_data.rename(columns={"rating_y": "rating"})
 
-    if debug:
+    if debug and not auto:
         print(rating_data)
         get_data_info(rating_data)
 
     pivot_table = rating_data.pivot_table(
         index="Name", columns="user_id", values="rating"
     ).fillna(0)
-    if debug:
+    if debug and not auto:
         print(pivot_table)
     return pivot_table
 
 
-def predict(prediction_model, pivot_table, seed=42, anime="RANDOM", recommendation_number=6):
+def predict(prediction_model, pivot_table, seed=42, anime="RANDOM", recommendation_number=6, auto=False):
     """
     This will choose a random anime name and our prediction_model will predict similar anime.
     """
@@ -180,11 +178,12 @@ def predict(prediction_model, pivot_table, seed=42, anime="RANDOM", recommendati
     distance, suggestions = prediction_model.kneighbors(
         query, n_neighbors=recommendation_number)
     for i in range(0, len(distance.flatten())):
-        if i == 0:
+        if i == 0 and not auto:
             print(f"Recommendations for {chosen_anime_name}:\n")
-        else:
+        elif not auto:
             print(
-                f"{i}: {pivot_table.index[suggestions.flatten()[i]]}, with distance of {distance.flatten()[i]}:"
+                f"{i}: {pivot_table.index[suggestions.flatten()[i]]}, "
+                f"with distance of {distance.flatten()[i]}:"
             )
 
 
@@ -200,45 +199,66 @@ def create_model(pivot_table, metric="cosine", algorithm="brute", neighbors=5):
 
 
 def handle_arguments():
+    """
+    Handles all arguments that can be used to change algorithm behaviour or program display
+    """
     parser = argparse.ArgumentParser(description='Example script with pyargs')
     parser.add_argument('--data_limit', '-dl',
-                        help='Specify data limit, Recommended at least 500k, set to -1 for no limit', required=False, type=int, default=-1)
-    parser.add_argument('--seed', '-s', help='Specify seed',
+                        help="""Specify data limit,
+                        Recommended at least 500k, set to -1 for no limit""",
+                        required=False, type=int, default=-1)
+    parser.add_argument('--seed', '-s',
+                        help='Specify seed',
                         type=int, required=False, default=42)
-    parser.add_argument('--debug', '-d', help='Use debug (more information) prints',
+    parser.add_argument('--debug', '-d',
+                        help='Use debug (more information) prints',
                         type=bool, required=False, default=False)
-    parser.add_argument('--database', '-db', help='Specify database path',
+    parser.add_argument('--database', '-db',
+                        help='Specify database path',
                         required=False, default="database")
     allowed_metric = ["cosine", "mahalanobis", "euclidean"]
-    parser.add_argument('--metric', '-m', help='Specify metric for NearestNeighbor learner',
+    parser.add_argument('--metric', '-m',
+                        help='Specify metric for NearestNeighbor learner',
                         required=False, default="cosine", choices=allowed_metric)
     allowed_algorithms = ['auto', 'ball_tree', 'kd_tree', 'brute']
-    parser.add_argument('--algorithm', '-a', help='Specify algorithm for Nearest Neighbor learner',
+    parser.add_argument('--algorithm', '-a',
+                        help='Specify algorithm for Nearest Neighbor learner',
                         required=False, default="brute", choices=allowed_algorithms)
-    parser.add_argument('--anime', '-an', help='Specify anime to choose',
+    parser.add_argument('--anime', '-an',
+                        help='Specify anime to choose',
                         required=False, default="RANDOM")
-    parser.add_argument('--neighbors', '-n', help='Specify number of nearest neighbors',
+    parser.add_argument('--neighbors', '-n',
+                        help='Specify number of nearest neighbors',
                         required=False, default=5)
-    parser.add_argument('--user_threshold', '-ut', help='Specify minimal number of votes required for user to be included in the data, set to -1 for no threshold',
+    parser.add_argument('--user_threshold', '-ut',
+                        help="""Specify minimal number of votes required for user to be
+                        included in the data, set to -1 for no threshold""",
                         required=False, type=int, default=500)
-    parser.add_argument('--anime_threshold', '-at', help='Specify minimal number of votes required for anime to be included in the data, set to -1 for no threshold',
+    parser.add_argument('--anime_threshold', '-at',
+                        help="""Specify minimal number of votes required for anime
+                        to be included in the data, set to -1 for no threshold""",
                         required=False, type=int, default=200)
-    parser.add_argument('--recommendation_amount', '-ra', help='Specify how much anime should be recommended',
+    parser.add_argument('--recommendation_amount', '-ra',
+                        help='Specify how much anime should be recommended',
                         required=False, type=int, default=5)
+    parser.add_argument('--auto', '-au',
+                        help="""Enable auto mode, no debug, no user parameters,
+                        automatic testing and saving results""",
+                        action='store_true', required=False, default=False)
 
     # Parse the command-line arguments
     args = parser.parse_args()
     args.recommendation_amount = args.recommendation_amount + 1
 
     # Access the values of the arguments
-    return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm, args.anime, args.neighbors, args.user_threshold, args.anime_threshold, args.recommendation_amount
+    return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm, args.anime, args.neighbors, args.user_threshold, args.anime_threshold, args.recommendation_amount, args.auto
 
 
 if __name__ == "__main__":
-    seed, debug, data_limit, db, metric, algorithm, anime, neighbors, user_threshold, anime_threshold, recommendation_amount = handle_arguments()
+    SEED, DEBUG, DATA_LIMIT, DB, METRIC, ALGORITHM, ANIME, NEIGHBORS, USER_THRESHOLD, ANIME_THRESHOLD, RECOMMENDATION_AMOUNT, AUTO = handle_arguments()
 
-    RATING_DATA, ANIME_CONTACT_DATA = get_data(data_limit, db)
+    RATING_DATA, ANIME_CONTACT_DATA = get_data(DATA_LIMIT, DB)
     PIVOT_TABLE = preprocessing(
-        RATING_DATA, ANIME_CONTACT_DATA, debug, user_threshold, anime_threshold)
-    MODEL = create_model(PIVOT_TABLE, metric, algorithm, neighbors)
-    predict(MODEL, PIVOT_TABLE, seed, anime, recommendation_amount)
+        RATING_DATA, ANIME_CONTACT_DATA, DEBUG, USER_THRESHOLD, ANIME_THRESHOLD, AUTO)
+    MODEL = create_model(PIVOT_TABLE, METRIC, ALGORITHM, NEIGHBORS)
+    predict(MODEL, PIVOT_TABLE, SEED, ANIME, RECOMMENDATION_AMOUNT, AUTO)