diff --git a/midterm/code/main.py b/midterm/code/main.py index 7d9f67a7..dc2d75b4 100644 --- a/midterm/code/main.py +++ b/midterm/code/main.py @@ -159,18 +159,25 @@ def preprocessing(rating_data, anime_contact_data, debug=False): return pivot_table -def predict(prediction_model, pivot_table, seed=42): +def predict(prediction_model, pivot_table, seed=42, anime="RANDOM"): """ This will choose a random anime name and our prediction_model will predict similar anime. """ np.random.seed(seed) - random_anime = np.random.choice(pivot_table.shape[0]) - query = pivot_table.iloc[random_anime, :].values.reshape(1, -1) - distance, suggestions = prediction_model.kneighbors(query, n_neighbors=6) - random_anime_name = pivot_table.index[random_anime] + print(pivot_table) + if anime == "RANDOM": + chosen_anime = np.random.choice(pivot_table.shape[0]) + query = pivot_table.iloc[chosen_anime, :].values.reshape(1, -1) + chosen_anime_name = pivot_table.index[chosen_anime] + else: + query = pivot_table.loc[anime].values.reshape(1, -1) + chosen_anime_name = anime + + distance, suggestions = prediction_model.kneighbors( + query, n_neighbors=6) for i in range(0, len(distance.flatten())): if i == 0: - print(f"Recommendations for {random_anime_name}:\n") + print(f"Recommendations for {chosen_anime_name}:\n") else: print( f"{i}: {pivot_table.index[suggestions.flatten()[i]]}, with distance of {distance.flatten()[i]}:" @@ -202,19 +209,21 @@ def handle_arguments(): parser.add_argument('--metric', '-m', help='Specify metric for NearestNeighbor learner', required=False, default="cosine", choices=allowed_metric) allowed_algorithms = ['auto', 'ball_tree', 'kd_tree', 'brute'] - parser.add_argument('--algorithm', '-a', help='Specify alrgorithm for Nearest Neighbor learner', + parser.add_argument('--algorithm', '-a', help='Specify algorithm for Nearest Neighbor learner', required=False, default="brute", choices=allowed_algorithms) + parser.add_argument('--anime', '-an', help='Specify anime to choose', + required=False, default="RANDOM") # Parse the command-line arguments args = parser.parse_args() # Access the values of the arguments - return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm + return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm, args.anime if __name__ == "__main__": - seed, debug, data_limit, db, metric, algorithm = handle_arguments() + seed, debug, data_limit, db, metric, algorithm, anime = handle_arguments() RATING_DATA, ANIME_CONTACT_DATA = get_data(data_limit, db) PIVOT_TABLE = preprocessing(RATING_DATA, ANIME_CONTACT_DATA, debug) MODEL = create_model(PIVOT_TABLE, metric, algorithm) - predict(MODEL, PIVOT_TABLE, seed) + predict(MODEL, PIVOT_TABLE, seed, anime)