diff --git a/midterm/code/main.py b/midterm/code/main.py index 8650068f..7d9f67a7 100644 --- a/midterm/code/main.py +++ b/midterm/code/main.py @@ -6,6 +6,7 @@ import pandas as pd import numpy as np import argparse +import sklearn from sklearn.neighbors import NearestNeighbors from scipy.sparse import csr_matrix @@ -176,12 +177,12 @@ def predict(prediction_model, pivot_table, seed=42): ) -def create_model(pivot_table): +def create_model(pivot_table, metric="cosine", algorithm="brute"): """ Creates model based on neaarest neighbor for anime prediction """ pivot_table_matrix = csr_matrix(pivot_table.values) - model = NearestNeighbors(metric="cosine", algorithm="brute") + model = NearestNeighbors(metric=metric, algorithm=algorithm) model.fit(pivot_table_matrix) return model @@ -196,17 +197,24 @@ def handle_arguments(): type=bool, required=False, default=False) parser.add_argument('--database', '-db', help='Specify database path', required=False, default="database") + + allowed_metric = ["cosine", "mahalanobis", "euclidean"] + parser.add_argument('--metric', '-m', help='Specify metric for NearestNeighbor learner', + required=False, default="cosine", choices=allowed_metric) + allowed_algorithms = ['auto', 'ball_tree', 'kd_tree', 'brute'] + parser.add_argument('--algorithm', '-a', help='Specify alrgorithm for Nearest Neighbor learner', + required=False, default="brute", choices=allowed_algorithms) # Parse the command-line arguments args = parser.parse_args() # Access the values of the arguments - return args.seed, args.debug, args.data_limit, args.database + return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm if __name__ == "__main__": - seed, debug, data_limit, db = handle_arguments() + seed, debug, data_limit, db, metric, algorithm = handle_arguments() RATING_DATA, ANIME_CONTACT_DATA = get_data(data_limit, db) PIVOT_TABLE = preprocessing(RATING_DATA, ANIME_CONTACT_DATA, debug) - MODEL = create_model(PIVOT_TABLE) + MODEL = create_model(PIVOT_TABLE, metric, algorithm) predict(MODEL, PIVOT_TABLE, seed)