From cc803f1d661d04cae4b3fc80e0f3ce56ac7bb57f Mon Sep 17 00:00:00 2001 From: Krzysztof Rudnicki Date: Sun, 11 Jun 2023 20:08:50 +0200 Subject: [PATCH] feat: tests now test different algorithms --- final/code/main.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/final/code/main.py b/final/code/main.py index bea2dc6f..c8a21e30 100644 --- a/final/code/main.py +++ b/final/code/main.py @@ -222,8 +222,11 @@ def create_model(pivot_table, rows_number, metric="cosine", algorithm="brute", n """ neighbors_number = calculate_neighbors(pivot_table.shape[0], neighbors) pivot_table_matrix = csr_matrix(pivot_table.values) - model = NearestNeighbors(n_neighbors=neighbors_number, - metric=metric, algorithm=algorithm) + if algorithm == "brute": + model = NearestNeighbors(n_neighbors=neighbors_number, + metric=metric, algorithm=algorithm) + else: + model = NearestNeighbors(n_neighbors=neighbors_number, algorithm=algorithm) try: model.fit(pivot_table_matrix) except: @@ -290,7 +293,7 @@ def handle_arguments(): def auto_mode(data_limit = -1, seed = 42, anime="RANDOM"): print("Started auto mode") - algorithm_spread = ['ball_tree', 'kd_tree', 'brute'] + algorithm_spread = ['auto', 'ball_tree', 'kd_tree', 'brute'] neighbor_spread = [5, "sqrt", "half", "log", "n-1"] # No reason to access and waste computational power every time we run the simulation starting_rating_data, starting_anime_contact_data, starting_rows_number = get_data(limit_data=data_limit) @@ -299,8 +302,12 @@ def auto_mode(data_limit = -1, seed = 42, anime="RANDOM"): if os.path.exists('test_results'): shutil.rmtree('test_results') for algorithm in algorithm_spread: - print("testing for algorithm: ", algorithm) - possibleMetrics = sorted(VALID_METRICS_SPARSE[algorithm]) + possibleMetrics = [] + if algorithm != 'auto': + possibleMetrics = sorted(VALID_METRICS_SPARSE[algorithm]) + print("testing for algorithm: ", algorithm, possibleMetrics) + if possibleMetrics == []: + possibleMetrics = [""] for metric in possibleMetrics: print("testing for algorithm, metric: ", algorithm, metric) for neighbor_amount in neighbor_spread: