From ea98dc9712ac7045cb0d98f6695701af54dd255a Mon Sep 17 00:00:00 2001 From: Krzysztof Rudnicki Date: Sun, 20 Oct 2024 18:27:51 +0200 Subject: [PATCH] feat: changed to proper pytests --- code/main.py | 8 +++---- code/richardson_method.py | 8 ++++--- code/tests.py | 48 +++++++++++++++++++++------------------ 3 files changed, 34 insertions(+), 30 deletions(-) diff --git a/code/main.py b/code/main.py index 16977fa9..b468d54c 100644 --- a/code/main.py +++ b/code/main.py @@ -1,7 +1,5 @@ -from tests import run_tests - -def main(): - run_tests() +import pytest if __name__ == "__main__": - main() + # Run pytest and exit with the appropriate status code + pytest.main(["-v", "tests.py"]) diff --git a/code/richardson_method.py b/code/richardson_method.py index 59aa2964..256c2a9c 100644 --- a/code/richardson_method.py +++ b/code/richardson_method.py @@ -15,20 +15,22 @@ class RichardsonMethod: raise ValueError("Matrix A is not positive semi-definite.") self.lambda_max = EigenvalueMethods.power_method(self.A) self.omega = 2 / (self.lambda_min + self.lambda_max) + def will_converge(self) -> bool: wA = LinearAlgebraUtils.matrix_scalar_multiply(self.A, self.omega) IMinuswA = LinearAlgebraUtils.matrix_matrix_subtraction(self.I, wA) - return LinearAlgebraUtils.matrix_norm(IMinuswA) < 1 + norm = LinearAlgebraUtils.matrix_norm(IMinuswA) + return norm < 1 def solve(self): x = self.x0[:] if not self.will_converge(): - print("Richardson method for those values will NOT converge") + return "Richardson method for those values will NOT converge" + for iteration in range(self.max_iterations): Ax = LinearAlgebraUtils.matrix_vector_multiply(self.A, x) residual = LinearAlgebraUtils.vector_vector_subtraction(self.b, Ax) x = LinearAlgebraUtils.vector_vector_addition(x, LinearAlgebraUtils.scalar_matrix_multiply(self.omega, residual)) - print('Maximum number of iterations reached without convergence.') return x diff --git a/code/tests.py b/code/tests.py index 80852a32..3b3a89b9 100644 --- a/code/tests.py +++ b/code/tests.py @@ -1,32 +1,36 @@ +import pytest import numpy as np from scipy.sparse.linalg import cg from matrix_generator import MatrixGenerator from richardson_method import RichardsonMethod -def run_tests(): - test_sizes = [2, 3, 4, 5, 10, 20, 50, 100] +@pytest.mark.parametrize("n", [2, 3, 4, 5, 10, 20, 50, 100]) +def test_richardson_vs_cg(n): tolerance = 1e-5 + A, b = MatrixGenerator.generate_random_matrix_and_vector(n) - for n in test_sizes: - print(f"\nRunning test for n = {n}") - - A, b = MatrixGenerator.generate_random_matrix_and_vector(n) - print("A: ", A) - print("b: ", b) - richardson_solver = RichardsonMethod(A, b, size=n, max_iterations=1000, tol=1e-7) - solution_richardson = richardson_solver.solve() - print("Richardson Method Solution:", solution_richardson) - - solution_cg, info = cg(A, b) - if info == 0: - print("SciPy Conjugate Gradient solution:", solution_cg) - else: - print("SciPy Conjugate Gradient did not converge.") + richardson_solver = RichardsonMethod(A, b, size=n, max_iterations=1000, tol=1e-7) + solution_richardson = richardson_solver.solve() + + solution_cg, info = cg(A, b) + + if info == 0: # SciPy CG converged + assert_scipy_converged(solution_richardson, solution_cg, tolerance) + else: # SciPy CG did not converge + assert_scipy_not_converged(solution_richardson) +def assert_scipy_converged(solution_richardson, solution_cg, tolerance): + if solution_richardson == "Richardson method for those values will NOT converge": + print("Richardson did not converge, while SciPy did") + assert False, "Richardson did not converge, while SciPy did" + else: difference = np.linalg.norm(solution_richardson - solution_cg) print(f"Difference between Richardson and CG solutions: {difference:.8f}") - - if difference < tolerance: - print("The solutions are effectively the same.") - else: - print("The solutions are different!") + assert difference < tolerance, f"The solutions are different! Difference: {difference:.8f}" + +def assert_scipy_not_converged(solution_richardson): + if solution_richardson == "Richardson method for those values will NOT converge": + print("Richardson and SciPy did not converge") + else: + print("Richardson converged while SciPy did not:", solution_richardson) + assert False, "Richardson converged while SciPy did not"