From 68870236395435606a767c4c8c1a6b45f5c1ff64 Mon Sep 17 00:00:00 2001 From: Thomas Albers Raviola Date: Tue, 4 Jun 2024 11:08:51 +0200 Subject: Add tests for LU routine --- solvers.py | 4 ++-- test/test_solvers.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/solvers.py b/solvers.py index 040c6f4..c7d421c 100644 --- a/solvers.py +++ b/solvers.py @@ -6,7 +6,7 @@ from numpy.typing import NDArray def lu_factorization( mm: NDArray[np.float_], tolerance: float = 1e-6 -) -> NDArray[np.float_]: +) -> tuple[NDArray[np.float_], NDArray[np.float_], NDArray[np.float_]]: '''Computes the LUP factorization of a matrix Args: @@ -28,7 +28,7 @@ def lu_factorization( j = 0 while i < nn and j < nn: - pivot_row = i + np.argmax(np.abs(uu[i:, j])) + pivot_row = i + int(np.argmax(np.abs(uu[i:, j]))) if np.abs(uu[pivot_row, j]) < tolerance: j = j + 1 diff --git a/test/test_solvers.py b/test/test_solvers.py index e5b73d8..5b5da6f 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -3,8 +3,10 @@ import numpy as np from numpy.typing import NDArray + import solvers + def test_elimination_3() -> None: """Tests elimination with 3 variables.""" aa = np.array([[2.0, 4.0, 4.0], [5.0, 4.0, 2.0], [1.0, 2.0, -1.0]], dtype=np.float_) @@ -30,3 +32,26 @@ def test_lindep_3() -> None: xx_expected = None xx_gauss = solvers.gaussian_eliminate(aa, bb) assert xx_expected == xx_gauss + + +def test_lu_factorization() -> None: + aa = np.array([[2.0, 4.0, 4.0], [1.0, 2.0, -1.0], [5.0, 4.0, 2.0]]) + ll, uu, pp = solvers.lu_factorization(aa) + assert np.allclose(aa, pp.T @ ll @ uu) + + +def test_forward_substitution() -> None: + ll = np.array([[2.0, 0.0, 0.0], [-1.0, 2.0, 0.0], [4.0, 4.0, 2.0]]) + bb = np.array([3.0, 2.0, 1.0]) + xx_result = solvers.forward_substitution(ll, bb)[:, 0] + xx_expected = np.array([3 / 2, 7.0 / 4.0, -6]) + print(xx_result) + assert np.allclose(xx_expected, xx_result) + +def test_back_substitution() -> None: + uu = np.array([[2.0, 4.0, 4.0], [0.0, 2.0, -1.0], [0.0, 0.0, 2.0]]) + bb = np.array([1.0, 2.0, 3.0]) + xx_result = solvers.back_substitution(uu, bb)[:, 0] + xx_expected = np.array([-6, 7.0 / 4.0, 3.0 / 2.0]) + print(xx_result) + assert np.allclose(xx_expected, xx_result) -- cgit v1.2.3