feat: allow for changing metric and algorithm from arguments

This commit is contained in:
Krzysztof Rudnicki 2023-05-29 22:33:29 +02:00
parent 8120fefc58
commit 94c005c0fa

View File

@ -6,6 +6,7 @@ import pandas as pd
import numpy as np
import argparse
import sklearn
from sklearn.neighbors import NearestNeighbors
from scipy.sparse import csr_matrix
@ -176,12 +177,12 @@ def predict(prediction_model, pivot_table, seed=42):
)
def create_model(pivot_table, metric="cosine", algorithm="brute"):
    """
    Create a nearest-neighbour model for anime prediction.

    Args:
        pivot_table: Table whose ``.values`` are converted to a SciPy CSR
            sparse matrix and used to fit the model.
        metric: Distance metric forwarded to ``NearestNeighbors``.
            Defaults to ``"cosine"``.
        algorithm: Neighbour-search algorithm forwarded to
            ``NearestNeighbors``. Defaults to ``"brute"``.

    Returns:
        The fitted ``NearestNeighbors`` instance.
    """
    pivot_table_matrix = csr_matrix(pivot_table.values)
    # NOTE(review): the model is fitted on sparse input. The scikit-learn
    # documentation states that fitting on sparse input overrides `algorithm`
    # with brute force, and not every metric accepts sparse data -- confirm
    # that each metric/algorithm combination offered on the command line
    # actually works with this matrix.
    model = NearestNeighbors(metric=metric, algorithm=algorithm)
    model.fit(pivot_table_matrix)
    return model
@ -196,17 +197,24 @@ def handle_arguments():
type=bool, required=False, default=False)
parser.add_argument('--database', '-db', help='Specify database path',
required=False, default="database")
allowed_metric = ["cosine", "mahalanobis", "euclidean"]
parser.add_argument('--metric', '-m', help='Specify metric for NearestNeighbor learner',
required=False, default="cosine", choices=allowed_metric)
allowed_algorithms = ['auto', 'ball_tree', 'kd_tree', 'brute']
parser.add_argument('--algorithm', '-a', help='Specify alrgorithm for Nearest Neighbor learner',
required=False, default="brute", choices=allowed_algorithms)
# Parse the command-line arguments
args = parser.parse_args()
# Access the values of the arguments
return args.seed, args.debug, args.data_limit, args.database
return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm
if __name__ == "__main__":
    # handle_arguments() returns six values; unpack all of them so the
    # metric and algorithm chosen on the command line reach the model.
    seed, debug, data_limit, db, metric, algorithm = handle_arguments()
    RATING_DATA, ANIME_CONTACT_DATA = get_data(data_limit, db)
    PIVOT_TABLE = preprocessing(RATING_DATA, ANIME_CONTACT_DATA, debug)
    MODEL = create_model(PIVOT_TABLE, metric, algorithm)
    predict(MODEL, PIVOT_TABLE, seed)