feat: add option to specifiy anime to choose from

This commit is contained in:
Krzysztof Rudnicki 2023-05-29 22:54:18 +02:00
parent 94c005c0fa
commit 0d9182a28f

View File

@ -159,18 +159,25 @@ def preprocessing(rating_data, anime_contact_data, debug=False):
return pivot_table
def predict(prediction_model, pivot_table, seed=42):
def predict(prediction_model, pivot_table, seed=42, anime="RANDOM"):
"""
This will choose a random anime name and our prediction_model will predict similar anime.
"""
np.random.seed(seed)
random_anime = np.random.choice(pivot_table.shape[0])
query = pivot_table.iloc[random_anime, :].values.reshape(1, -1)
distance, suggestions = prediction_model.kneighbors(query, n_neighbors=6)
random_anime_name = pivot_table.index[random_anime]
print(pivot_table)
if anime == "RANDOM":
chosen_anime = np.random.choice(pivot_table.shape[0])
query = pivot_table.iloc[chosen_anime, :].values.reshape(1, -1)
chosen_anime_name = pivot_table.index[chosen_anime]
else:
query = pivot_table.loc[anime].values.reshape(1, -1)
chosen_anime_name = anime
distance, suggestions = prediction_model.kneighbors(
query, n_neighbors=6)
for i in range(0, len(distance.flatten())):
if i == 0:
print(f"Recommendations for {random_anime_name}:\n")
print(f"Recommendations for {chosen_anime_name}:\n")
else:
print(
f"{i}: {pivot_table.index[suggestions.flatten()[i]]}, with distance of {distance.flatten()[i]}:"
@ -202,19 +209,21 @@ def handle_arguments():
parser.add_argument('--metric', '-m', help='Specify metric for NearestNeighbor learner',
required=False, default="cosine", choices=allowed_metric)
allowed_algorithms = ['auto', 'ball_tree', 'kd_tree', 'brute']
parser.add_argument('--algorithm', '-a', help='Specify alrgorithm for Nearest Neighbor learner',
parser.add_argument('--algorithm', '-a', help='Specify algorithm for Nearest Neighbor learner',
required=False, default="brute", choices=allowed_algorithms)
parser.add_argument('--anime', '-an', help='Specify anime to choose',
required=False, default="RANDOM")
# Parse the command-line arguments
args = parser.parse_args()
# Access the values of the arguments
return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm
return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm, args.anime
if __name__ == "__main__":
seed, debug, data_limit, db, metric, algorithm = handle_arguments()
seed, debug, data_limit, db, metric, algorithm, anime = handle_arguments()
RATING_DATA, ANIME_CONTACT_DATA = get_data(data_limit, db)
PIVOT_TABLE = preprocessing(RATING_DATA, ANIME_CONTACT_DATA, debug)
MODEL = create_model(PIVOT_TABLE, metric, algorithm)
predict(MODEL, PIVOT_TABLE, seed)
predict(MODEL, PIVOT_TABLE, seed, anime)