mirror of
https://github.com/kuhyx/WUT_Computer_Science.git
synced 2026-07-04 20:23:04 +02:00
feat: add option to specifiy anime to choose from
This commit is contained in:
parent
94c005c0fa
commit
0d9182a28f
@ -159,18 +159,25 @@ def preprocessing(rating_data, anime_contact_data, debug=False):
|
||||
return pivot_table
|
||||
|
||||
|
||||
def predict(prediction_model, pivot_table, seed=42):
|
||||
def predict(prediction_model, pivot_table, seed=42, anime="RANDOM"):
|
||||
"""
|
||||
This will choose a random anime name and our prediction_model will predict similar anime.
|
||||
"""
|
||||
np.random.seed(seed)
|
||||
random_anime = np.random.choice(pivot_table.shape[0])
|
||||
query = pivot_table.iloc[random_anime, :].values.reshape(1, -1)
|
||||
distance, suggestions = prediction_model.kneighbors(query, n_neighbors=6)
|
||||
random_anime_name = pivot_table.index[random_anime]
|
||||
print(pivot_table)
|
||||
if anime == "RANDOM":
|
||||
chosen_anime = np.random.choice(pivot_table.shape[0])
|
||||
query = pivot_table.iloc[chosen_anime, :].values.reshape(1, -1)
|
||||
chosen_anime_name = pivot_table.index[chosen_anime]
|
||||
else:
|
||||
query = pivot_table.loc[anime].values.reshape(1, -1)
|
||||
chosen_anime_name = anime
|
||||
|
||||
distance, suggestions = prediction_model.kneighbors(
|
||||
query, n_neighbors=6)
|
||||
for i in range(0, len(distance.flatten())):
|
||||
if i == 0:
|
||||
print(f"Recommendations for {random_anime_name}:\n")
|
||||
print(f"Recommendations for {chosen_anime_name}:\n")
|
||||
else:
|
||||
print(
|
||||
f"{i}: {pivot_table.index[suggestions.flatten()[i]]}, with distance of {distance.flatten()[i]}:"
|
||||
@ -202,19 +209,21 @@ def handle_arguments():
|
||||
parser.add_argument('--metric', '-m', help='Specify metric for NearestNeighbor learner',
|
||||
required=False, default="cosine", choices=allowed_metric)
|
||||
allowed_algorithms = ['auto', 'ball_tree', 'kd_tree', 'brute']
|
||||
parser.add_argument('--algorithm', '-a', help='Specify alrgorithm for Nearest Neighbor learner',
|
||||
parser.add_argument('--algorithm', '-a', help='Specify algorithm for Nearest Neighbor learner',
|
||||
required=False, default="brute", choices=allowed_algorithms)
|
||||
parser.add_argument('--anime', '-an', help='Specify anime to choose',
|
||||
required=False, default="RANDOM")
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Access the values of the arguments
|
||||
return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm
|
||||
return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm, args.anime
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
seed, debug, data_limit, db, metric, algorithm = handle_arguments()
|
||||
seed, debug, data_limit, db, metric, algorithm, anime = handle_arguments()
|
||||
|
||||
RATING_DATA, ANIME_CONTACT_DATA = get_data(data_limit, db)
|
||||
PIVOT_TABLE = preprocessing(RATING_DATA, ANIME_CONTACT_DATA, debug)
|
||||
MODEL = create_model(PIVOT_TABLE, metric, algorithm)
|
||||
predict(MODEL, PIVOT_TABLE, seed)
|
||||
predict(MODEL, PIVOT_TABLE, seed, anime)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user