mirror of
https://github.com/kuhyx/WUT_Computer_Science.git
synced 2026-07-04 18:03:14 +02:00
feat: allow for changing metric and algorithm from arguments
This commit is contained in:
parent
8120fefc58
commit
94c005c0fa
@ -6,6 +6,7 @@ import pandas as pd
|
||||
import numpy as np
|
||||
import argparse
|
||||
|
||||
import sklearn
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
@ -176,12 +177,12 @@ def predict(prediction_model, pivot_table, seed=42):
|
||||
)
|
||||
|
||||
|
||||
def create_model(pivot_table):
|
||||
def create_model(pivot_table, metric="cosine", algorithm="brute"):
|
||||
"""
|
||||
Creates model based on nearest neighbor for anime prediction
|
||||
"""
|
||||
pivot_table_matrix = csr_matrix(pivot_table.values)
|
||||
model = NearestNeighbors(metric="cosine", algorithm="brute")
|
||||
model = NearestNeighbors(metric=metric, algorithm=algorithm)
|
||||
model.fit(pivot_table_matrix)
|
||||
return model
|
||||
|
||||
@ -196,17 +197,24 @@ def handle_arguments():
|
||||
type=bool, required=False, default=False)
|
||||
parser.add_argument('--database', '-db', help='Specify database path',
|
||||
required=False, default="database")
|
||||
|
||||
allowed_metric = ["cosine", "mahalanobis", "euclidean"]
|
||||
parser.add_argument('--metric', '-m', help='Specify metric for NearestNeighbor learner',
|
||||
required=False, default="cosine", choices=allowed_metric)
|
||||
allowed_algorithms = ['auto', 'ball_tree', 'kd_tree', 'brute']
|
||||
parser.add_argument('--algorithm', '-a', help='Specify algorithm for Nearest Neighbor learner',
|
||||
required=False, default="brute", choices=allowed_algorithms)
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Access the values of the arguments
|
||||
return args.seed, args.debug, args.data_limit, args.database
|
||||
return args.seed, args.debug, args.data_limit, args.database, args.metric, args.algorithm
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
seed, debug, data_limit, db = handle_arguments()
|
||||
seed, debug, data_limit, db, metric, algorithm = handle_arguments()
|
||||
|
||||
RATING_DATA, ANIME_CONTACT_DATA = get_data(data_limit, db)
|
||||
PIVOT_TABLE = preprocessing(RATING_DATA, ANIME_CONTACT_DATA, debug)
|
||||
MODEL = create_model(PIVOT_TABLE)
|
||||
MODEL = create_model(PIVOT_TABLE, metric, algorithm)
|
||||
predict(MODEL, PIVOT_TABLE, seed)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user