fix: minor errors

This commit is contained in:
Gromiusz 2025-01-13 18:21:02 +01:00
parent e38f28a0b5
commit 9f8671352c

View File

@ -1,5 +1,6 @@
import gc
import time
import numpy as np
from numba import njit, prange
from time_measurement import time_measurement_longest, longest_threads_time_accumulator, tests_time
import linear_algebra_utils as linAlg
@ -8,23 +9,28 @@ import linear_algebra_utils as linAlg
@njit(parallel=True)
def numba_matrix_vector_multiply(A, input_x, Ax):
for i in prange(len(A)):
Ax[i] = sum(A[i][j] * input_x[j] for j in range(len(input_x)))
acc = 0.0
for j in range(len(input_x)):
acc += A[i][j] * input_x[j]
Ax[i] = acc
@njit(parallel=True)
def numba_vector_vector_subtraction(b, Ax, residual):
for i in prange(len(b)):
residual[i] = b[i] - Ax[i]
@njit(parallel=True)
@njit(nopython=True)
def numba_scalar_vector_multiply(omega, vector, result):
for i in prange(len(vector)):
result[i] = omega * vector[i]
omega_real = omega.real
for i in range(len(vector)):
result[i] = omega_real * vector[i]
@njit(parallel=True)
def numba_vector_vector_addition(input_x, vector, output_x):
for i in prange(len(input_x)):
output_x[i] = input_x[i] + vector[i]
# Funkcje z dekoratorem
@time_measurement_longest(longest_threads_time_accumulator)
def matrix_vector_multiply(A, input_x, Ax):