From fe852453c0cb80e05a0adc4a95ceb52873461851 Mon Sep 17 00:00:00 2001 From: Thomas Albers Raviola Date: Tue, 4 Jun 2024 11:11:20 +0200 Subject: Modify substitution routines to return vectors --- solvers.py | 6 +++--- test/test_solvers.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/solvers.py b/solvers.py index c7d421c..72f1b1d 100644 --- a/solvers.py +++ b/solvers.py @@ -64,7 +64,7 @@ def forward_substitution( x: solution of the system of equations ''' nn = aa.shape[0] - xx = np.zeros((nn, 1), dtype=np.float_) + xx = np.zeros(nn, dtype=np.float_) for i in range(nn): xx[i] = (bb[i] - np.dot(aa[i, 0:i], xx[0:i])) / aa[i, i] return xx @@ -83,7 +83,7 @@ def back_substitution( x: solution of the system of equations ''' nn = aa.shape[0] - xx = np.zeros((nn, 1), dtype=np.float_) + xx = np.zeros(nn, dtype=np.float_) for i in range(nn - 1, -1, -1): xx[i] = (bb[i] - np.dot(aa[i, i:], xx[i:nn])) / aa[i, i] return xx @@ -114,4 +114,4 @@ def gaussian_eliminate( y = forward_substitution(ll, pp @ bb) # U @ x = y x = back_substitution(uu, y) - return x[:, 0] + return x diff --git a/test/test_solvers.py b/test/test_solvers.py index 5b5da6f..a7c175f 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -43,7 +43,7 @@ def test_lu_factorization() -> None: def test_forward_substitution() -> None: ll = np.array([[2.0, 0.0, 0.0], [-1.0, 2.0, 0.0], [4.0, 4.0, 2.0]]) bb = np.array([3.0, 2.0, 1.0]) - xx_result = solvers.forward_substitution(ll, bb)[:, 0] + xx_result = solvers.forward_substitution(ll, bb) xx_expected = np.array([3 / 2, 7.0 / 4.0, -6]) print(xx_result) assert np.allclose(xx_expected, xx_result) @@ -51,7 +51,7 @@ def test_forward_substitution() -> None: def test_back_substitution() -> None: uu = np.array([[2.0, 4.0, 4.0], [0.0, 2.0, -1.0], [0.0, 0.0, 2.0]]) bb = np.array([1.0, 2.0, 3.0]) - xx_result = solvers.back_substitution(uu, bb)[:, 0] + xx_result = solvers.back_substitution(uu, bb) xx_expected = np.array([-6, 7.0 / 4.0, 3.0 / 2.0]) print(xx_result) assert np.allclose(xx_expected, xx_result) -- cgit v1.2.3