mirror of
https://github.com/kuhyx/WUT_Computer_Science.git
synced 2026-07-04 19:03:01 +02:00
feat: added auto mode
This commit is contained in:
parent
5b817e9e92
commit
3b96a6f0f4
@ -2,12 +2,9 @@
|
||||
Code for preprocessing data and creating model that predicts and
|
||||
recomends anime based on another anime entered by user
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
import numpy as np
|
||||
import argparse
|
||||
|
||||
import sklearn
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
@ -138,7 +135,8 @@ def get_data_info(rating_data, debug=False):
|
||||
f"Min total rating: {smallest_rating}, Max total rating: {highest_rating}")
|
||||
|
||||
|
||||
def preprocessing(rating_data, anime_contact_data, debug=False, user_threshold=500, anime_threshold=200):
|
||||
def preprocessing(rating_data, anime_contact_data,
|
||||
debug=False, user_threshold=500, anime_threshold=200, auto=False):
|
||||
"""
|
||||
Preprocesses data for making model more accurate and/or faster
|
||||
"""
|
||||
@ -151,19 +149,19 @@ def preprocessing(rating_data, anime_contact_data, debug=False, user_threshold=5
|
||||
|
||||
rating_data = rating_data.drop(columns="rating_x")
|
||||
rating_data = rating_data.rename(columns={"rating_y": "rating"})
|
||||
if debug:
|
||||
if debug and not auto:
|
||||
print(rating_data)
|
||||
get_data_info(rating_data)
|
||||
|
||||
pivot_table = rating_data.pivot_table(
|
||||
index="Name", columns="user_id", values="rating"
|
||||
).fillna(0)
|
||||
if debug:
|
||||
if debug and not auto:
|
||||
print(pivot_table)
|
||||
return pivot_table
|
||||
|
||||
|
||||
def predict(prediction_model, pivot_table, seed=42, anime="RANDOM", recommendation_number=6):
|
||||
def predict(prediction_model, pivot_table, seed=42, anime="RANDOM", recommendation_number=6, auto=False):
|
||||
"""
|
||||
This will choose a random anime name and our prediction_model will predict similar anime.
|
||||
"""
|
||||
@ -180,11 +178,12 @@ def predict(prediction_model, pivot_table, seed=42, anime="RANDOM", recommendati
|
||||
distance, suggestions = prediction_model.kneighbors(
|
||||
query, n_neighbors=recommendation_number)
|
||||
for i in range(0, len(distance.flatten())):
|
||||
if i == 0:
|
||||
if i == 0 and not auto:
|
||||
print(f"Recommendations for {chosen_anime_name}:\n")
|
||||
else:
|
||||
elif not auto:
|
||||
print(
|
||||
f"{i}: {pivot_table.index[suggestions.flatten()[i]]}, with distance of {distance.flatten()[i]}:"
|
||||
f"""{i}: {pivot_table.index[suggestions.flatten()[i]]},
|
||||
with distance of {distance.flatten()[i]}:"""
|
||||
)
|
||||
|
||||
|
||||
@ -200,45 +199,66 @@ def create_model(pivot_table, metric="cosine", algorithm="brute", neighbors=5):
|
||||
|
||||
|
||||
def handle_arguments():
|
||||
"""
|
||||
Handles all arguments that can be used to change algorithm behaviour or program display
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='Example script with pyargs')
|
||||
parser.add_argument('--data_limit', '-dl',
|
||||
help='Specify data limit, Recommended at least 500k, set to -1 for no limit', required=False, type=int, default=-1)
|
||||
parser.add_argument('--seed', '-s', help='Specify seed',
|
||||
help="""Specify data limit,
|
||||
Recommended at least 500k, set to -1 for no limit""",
|
||||
required=False, type=int, default=-1)
|
||||
parser.add_argument('--seed', '-s',
|
||||
help='Specify seed',
|
||||
type=int, required=False, default=42)
|
||||
parser.add_argument('--debug', '-d', help='Use debug (more information) prints',
|
||||
parser.add_argument('--debug', '-d',
|
||||
help='Use debug (more information) prints',
|
||||
type=bool, required=False, default=False)
|
||||
parser.add_argument('--database', '-db', help='Specify database path',
|
||||
parser.add_argument('--database', '-db',
|
||||
help='Specify database path',
|
||||
required=False, default="database")
|
||||
|
||||
allowed_metric = ["cosine", "mahalanobis", "euclidean"]
|
||||
parser.add_argument('--metric', '-m', help='Specify metric for NearestNeighbor learner',
|
||||
parser.add_argument('--metric', '-m',
|
||||
help='Specify metric for NearestNeighbor learner',
|
||||
required=False, default="cosine", choices=allowed_metric)
|
||||
allowed_algorithms = ['auto', 'ball_tree', 'kd_tree', 'brute']
|
||||
parser.add_argument('--algorithm', '-a', help='Specify algorithm for Nearest Neighbor learner',
|
||||
parser.add_argument('--algorithm', '-a',
|
||||
help='Specify algorithm for Nearest Neighbor learner',
|
||||
required=False, default="brute", choices=allowed_algorithms)
|
||||
parser.add_argument('--anime', '-an', help='Specify anime to choose',
|
||||
parser.add_argument('--anime', '-an',
|
||||
help='Specify anime to choose',
|
||||
required=False, default="RANDOM")
|
||||
parser.add_argument('--neighbors', '-n', help='Specify number of nearest neighbors',
|
||||
parser.add_argument('--neighbors', '-n',
|
||||
help='Specify number of nearest neighbors',
|
||||
required=False, default=5)
|
||||
parser.add_argument('--user_threshold', '-ut', help='Specify minimal number of votes required for user to be included in the data, set to -1 for no threshold',
|
||||
parser.add_argument('--user_threshold', '-ut',
|
||||
help="""Specify minimal number of votes required for user to be
|
||||
included in the data, set to -1 for no threshold""",
|
||||
required=False, type=int, default=500)
|
||||
parser.add_argument('--anime_threshold', '-at', help='Specify minimal number of votes required for anime to be included in the data, set to -1 for no threshold',
|
||||
parser.add_argument('--anime_threshold', '-at',
|
||||
help="""Specify minimal number of votes required for anime
|
||||
to be included in the data, set to -1 for no threshold""",
|
||||
required=False, type=int, default=200)
|
||||
parser.add_argument('--recommendation_amount', '-ra', help='Specify how much anime should be recommended',
|
||||
parser.add_argument('--recommendation_amount', '-ra',
|
||||
help='Specify how much anime should be recommended',
|
||||
required=False, type=int, default=5)
|
||||
parser.add_argument('--auto', '-au',
|
||||
help="""Enable auto mode, no debug, no user parameters,
|
||||
automatic testing and saving results""",
|
||||
type=bool, required=False, default=False)
|
||||
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
args.recommendation_amount = args.recommendation_amount + 1
|
||||
# Access the values of the arguments
|
||||
return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm, args.anime, args.neighbors, args.user_threshold, args.anime_threshold, args.recommendation_amount
|
||||
return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm, args.anime, args.neighbors, args.user_threshold, args.anime_threshold, args.recommendation_amount, args.auto
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
seed, debug, data_limit, db, metric, algorithm, anime, neighbors, user_threshold, anime_threshold, recommendation_amount = handle_arguments()
|
||||
SEED, DEBUG, DATA_LIMT, DB, METRIC, ALGORITHM, ANIME, NEIGHBORS, USER_THRESHOLD, ANIME_THRESHOLD, RECOMMENDATION_AMOUNT, AUTO = handle_arguments()
|
||||
|
||||
RATING_DATA, ANIME_CONTACT_DATA = get_data(data_limit, db)
|
||||
RATING_DATA, ANIME_CONTACT_DATA = get_data(DATA_LIMT, DB)
|
||||
PIVOT_TABLE = preprocessing(
|
||||
RATING_DATA, ANIME_CONTACT_DATA, debug, user_threshold, anime_threshold)
|
||||
MODEL = create_model(PIVOT_TABLE, metric, algorithm, neighbors)
|
||||
predict(MODEL, PIVOT_TABLE, seed, anime, recommendation_amount)
|
||||
RATING_DATA, ANIME_CONTACT_DATA, DEBUG, USER_THRESHOLD, ANIME_THRESHOLD)
|
||||
MODEL = create_model(PIVOT_TABLE, METRIC, ALGORITHM, NEIGHBORS)
|
||||
predict(MODEL, PIVOT_TABLE, SEED, ANIME, RECOMMENDATION_AMOUNT)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user