aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Albers Raviola <thomas@thomaslabs.org>2024-06-04 11:08:51 +0200
committerThomas Albers Raviola <thomas@thomaslabs.org>2024-06-04 11:08:51 +0200
commit68870236395435606a767c4c8c1a6b45f5c1ff64 (patch)
tree535d787e504f336ea934137a5bb137cbd5b0b49a
parent255c2e7b7bf721296d4d4f22cb106ab3cf197e0a (diff)
Add tests for LU routine
-rw-r--r--solvers.py4
-rw-r--r--test/test_solvers.py25
2 files changed, 27 insertions, 2 deletions
diff --git a/solvers.py b/solvers.py
index 040c6f4..c7d421c 100644
--- a/solvers.py
+++ b/solvers.py
@@ -6,7 +6,7 @@ from numpy.typing import NDArray
def lu_factorization(
mm: NDArray[np.float_],
tolerance: float = 1e-6
-) -> NDArray[np.float_]:
+) -> tuple[NDArray[np.float_], NDArray[np.float_], NDArray[np.float_]]:
'''Computes the LUP factorization of a matrix
Args:
@@ -28,7 +28,7 @@ def lu_factorization(
j = 0
while i < nn and j < nn:
- pivot_row = i + np.argmax(np.abs(uu[i:, j]))
+ pivot_row = i + int(np.argmax(np.abs(uu[i:, j])))
if np.abs(uu[pivot_row, j]) < tolerance:
j = j + 1
diff --git a/test/test_solvers.py b/test/test_solvers.py
index e5b73d8..5b5da6f 100644
--- a/test/test_solvers.py
+++ b/test/test_solvers.py
@@ -3,8 +3,10 @@
import numpy as np
from numpy.typing import NDArray
+
import solvers
+
def test_elimination_3() -> None:
"""Tests elimination with 3 variables."""
aa = np.array([[2.0, 4.0, 4.0], [5.0, 4.0, 2.0], [1.0, 2.0, -1.0]], dtype=np.float_)
@@ -30,3 +32,26 @@ def test_lindep_3() -> None:
xx_expected = None
xx_gauss = solvers.gaussian_eliminate(aa, bb)
assert xx_expected == xx_gauss
+
+
+def test_lu_factorization() -> None:
+ aa = np.array([[2.0, 4.0, 4.0], [1.0, 2.0, -1.0], [5.0, 4.0, 2.0]])
+ ll, uu, pp = solvers.lu_factorization(aa)
+ assert np.allclose(aa, pp.T @ ll @ uu)
+
+
+def test_forward_substitution() -> None:
+ ll = np.array([[2.0, 0.0, 0.0], [-1.0, 2.0, 0.0], [4.0, 4.0, 2.0]])
+ bb = np.array([3.0, 2.0, 1.0])
+ xx_result = solvers.forward_substitution(ll, bb)[:, 0]
+ xx_expected = np.array([3 / 2, 7.0 / 4.0, -6])
+ print(xx_result)
+ assert np.allclose(xx_expected, xx_result)
+
+def test_back_substitution() -> None:
+ uu = np.array([[2.0, 4.0, 4.0], [0.0, 2.0, -1.0], [0.0, 0.0, 2.0]])
+ bb = np.array([1.0, 2.0, 3.0])
+ xx_result = solvers.back_substitution(uu, bb)[:, 0]
+ xx_expected = np.array([-6, 7.0 / 4.0, 3.0 / 2.0])
+ print(xx_result)
+ assert np.allclose(xx_expected, xx_result)