feat: added auto mode

This commit is contained in:
Krzysztof Rudnicki 2023-06-08 16:20:20 +02:00
parent 5b817e9e92
commit 3b96a6f0f4

View File

@ -2,12 +2,9 @@
Code for preprocessing data and creating model that predicts and
recomends anime based on another anime entered by user
"""
import pandas as pd
import numpy as np
import argparse
import sklearn
import pandas as pd
import numpy as np
from sklearn.neighbors import NearestNeighbors
from scipy.sparse import csr_matrix
@ -138,7 +135,8 @@ def get_data_info(rating_data, debug=False):
f"Min total rating: {smallest_rating}, Max total rating: {highest_rating}")
def preprocessing(rating_data, anime_contact_data, debug=False, user_threshold=500, anime_threshold=200):
def preprocessing(rating_data, anime_contact_data,
debug=False, user_threshold=500, anime_threshold=200, auto=False):
"""
Preprocesses data for making model more accurate and/or faster
"""
@ -151,19 +149,19 @@ def preprocessing(rating_data, anime_contact_data, debug=False, user_threshold=5
rating_data = rating_data.drop(columns="rating_x")
rating_data = rating_data.rename(columns={"rating_y": "rating"})
if debug:
if debug and not auto:
print(rating_data)
get_data_info(rating_data)
pivot_table = rating_data.pivot_table(
index="Name", columns="user_id", values="rating"
).fillna(0)
if debug:
if debug and not auto:
print(pivot_table)
return pivot_table
def predict(prediction_model, pivot_table, seed=42, anime="RANDOM", recommendation_number=6):
def predict(prediction_model, pivot_table, seed=42, anime="RANDOM", recommendation_number=6, auto=False):
"""
This will choose a random anime name and our prediction_model will predict similar anime.
"""
@ -180,11 +178,12 @@ def predict(prediction_model, pivot_table, seed=42, anime="RANDOM", recommendati
distance, suggestions = prediction_model.kneighbors(
query, n_neighbors=recommendation_number)
for i in range(0, len(distance.flatten())):
if i == 0:
if i == 0 and not auto:
print(f"Recommendations for {chosen_anime_name}:\n")
else:
elif not auto:
print(
f"{i}: {pivot_table.index[suggestions.flatten()[i]]}, with distance of {distance.flatten()[i]}:"
f"""{i}: {pivot_table.index[suggestions.flatten()[i]]},
with distance of {distance.flatten()[i]}:"""
)
@ -200,45 +199,66 @@ def create_model(pivot_table, metric="cosine", algorithm="brute", neighbors=5):
def handle_arguments():
"""
Handles all arguments that can be used to change algorithm behaviour or program display
"""
parser = argparse.ArgumentParser(description='Example script with pyargs')
parser.add_argument('--data_limit', '-dl',
help='Specify data limit, Recommended at least 500k, set to -1 for no limit', required=False, type=int, default=-1)
parser.add_argument('--seed', '-s', help='Specify seed',
help="""Specify data limit,
Recommended at least 500k, set to -1 for no limit""",
required=False, type=int, default=-1)
parser.add_argument('--seed', '-s',
help='Specify seed',
type=int, required=False, default=42)
parser.add_argument('--debug', '-d', help='Use debug (more information) prints',
parser.add_argument('--debug', '-d',
help='Use debug (more information) prints',
type=bool, required=False, default=False)
parser.add_argument('--database', '-db', help='Specify database path',
parser.add_argument('--database', '-db',
help='Specify database path',
required=False, default="database")
allowed_metric = ["cosine", "mahalanobis", "euclidean"]
parser.add_argument('--metric', '-m', help='Specify metric for NearestNeighbor learner',
parser.add_argument('--metric', '-m',
help='Specify metric for NearestNeighbor learner',
required=False, default="cosine", choices=allowed_metric)
allowed_algorithms = ['auto', 'ball_tree', 'kd_tree', 'brute']
parser.add_argument('--algorithm', '-a', help='Specify algorithm for Nearest Neighbor learner',
parser.add_argument('--algorithm', '-a',
help='Specify algorithm for Nearest Neighbor learner',
required=False, default="brute", choices=allowed_algorithms)
parser.add_argument('--anime', '-an', help='Specify anime to choose',
parser.add_argument('--anime', '-an',
help='Specify anime to choose',
required=False, default="RANDOM")
parser.add_argument('--neighbors', '-n', help='Specify number of nearest neighbors',
parser.add_argument('--neighbors', '-n',
help='Specify number of nearest neighbors',
required=False, default=5)
parser.add_argument('--user_threshold', '-ut', help='Specify minimal number of votes required for user to be included in the data, set to -1 for no threshold',
parser.add_argument('--user_threshold', '-ut',
help="""Specify minimal number of votes required for user to be
included in the data, set to -1 for no threshold""",
required=False, type=int, default=500)
parser.add_argument('--anime_threshold', '-at', help='Specify minimal number of votes required for anime to be included in the data, set to -1 for no threshold',
parser.add_argument('--anime_threshold', '-at',
help="""Specify minimal number of votes required for anime
to be included in the data, set to -1 for no threshold""",
required=False, type=int, default=200)
parser.add_argument('--recommendation_amount', '-ra', help='Specify how much anime should be recommended',
parser.add_argument('--recommendation_amount', '-ra',
help='Specify how much anime should be recommended',
required=False, type=int, default=5)
parser.add_argument('--auto', '-au',
help="""Enable auto mode, no debug, no user parameters,
automatic testing and saving results""",
type=bool, required=False, default=False)
# Parse the command-line arguments
args = parser.parse_args()
args.recommendation_amount = args.recommendation_amount + 1
# Access the values of the arguments
return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm, args.anime, args.neighbors, args.user_threshold, args.anime_threshold, args.recommendation_amount
return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm, args.anime, args.neighbors, args.user_threshold, args.anime_threshold, args.recommendation_amount, args.auto
if __name__ == "__main__":
seed, debug, data_limit, db, metric, algorithm, anime, neighbors, user_threshold, anime_threshold, recommendation_amount = handle_arguments()
SEED, DEBUG, DATA_LIMT, DB, METRIC, ALGORITHM, ANIME, NEIGHBORS, USER_THRESHOLD, ANIME_THRESHOLD, RECOMMENDATION_AMOUNT, AUTO = handle_arguments()
RATING_DATA, ANIME_CONTACT_DATA = get_data(data_limit, db)
RATING_DATA, ANIME_CONTACT_DATA = get_data(DATA_LIMT, DB)
PIVOT_TABLE = preprocessing(
RATING_DATA, ANIME_CONTACT_DATA, debug, user_threshold, anime_threshold)
MODEL = create_model(PIVOT_TABLE, metric, algorithm, neighbors)
predict(MODEL, PIVOT_TABLE, seed, anime, recommendation_amount)
RATING_DATA, ANIME_CONTACT_DATA, DEBUG, USER_THRESHOLD, ANIME_THRESHOLD)
MODEL = create_model(PIVOT_TABLE, METRIC, ALGORITHM, NEIGHBORS)
predict(MODEL, PIVOT_TABLE, SEED, ANIME, RECOMMENDATION_AMOUNT)