aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--solvers.py6
-rw-r--r--test/test_solvers.py4
2 files changed, 5 insertions, 5 deletions
diff --git a/solvers.py b/solvers.py
index c7d421c..72f1b1d 100644
--- a/solvers.py
+++ b/solvers.py
@@ -64,7 +64,7 @@ def forward_substitution(
x: solution of the system of equations
'''
nn = aa.shape[0]
- xx = np.zeros((nn, 1), dtype=np.float_)
+ xx = np.zeros(nn, dtype=np.float_)
for i in range(nn):
xx[i] = (bb[i] - np.dot(aa[i, 0:i], xx[0:i])) / aa[i, i]
return xx
@@ -83,7 +83,7 @@ def back_substitution(
x: solution of the system of equations
'''
nn = aa.shape[0]
- xx = np.zeros((nn, 1), dtype=np.float_)
+ xx = np.zeros(nn, dtype=np.float_)
for i in range(nn - 1, -1, -1):
xx[i] = (bb[i] - np.dot(aa[i, i:], xx[i:nn])) / aa[i, i]
return xx
@@ -114,4 +114,4 @@ def gaussian_eliminate(
y = forward_substitution(ll, pp @ bb)
# U @ x = y
x = back_substitution(uu, y)
- return x[:, 0]
+ return x
diff --git a/test/test_solvers.py b/test/test_solvers.py
index 5b5da6f..a7c175f 100644
--- a/test/test_solvers.py
+++ b/test/test_solvers.py
@@ -43,7 +43,7 @@ def test_lu_factorization() -> None:
def test_forward_substitution() -> None:
ll = np.array([[2.0, 0.0, 0.0], [-1.0, 2.0, 0.0], [4.0, 4.0, 2.0]])
bb = np.array([3.0, 2.0, 1.0])
- xx_result = solvers.forward_substitution(ll, bb)[:, 0]
+ xx_result = solvers.forward_substitution(ll, bb)
xx_expected = np.array([3 / 2, 7.0 / 4.0, -6])
print(xx_result)
assert np.allclose(xx_expected, xx_result)
@@ -51,7 +51,7 @@ def test_forward_substitution() -> None:
def test_back_substitution() -> None:
uu = np.array([[2.0, 4.0, 4.0], [0.0, 2.0, -1.0], [0.0, 0.0, 2.0]])
bb = np.array([1.0, 2.0, 3.0])
- xx_result = solvers.back_substitution(uu, bb)[:, 0]
+ xx_result = solvers.back_substitution(uu, bb)
xx_expected = np.array([-6, 7.0 / 4.0, 3.0 / 2.0])
print(xx_result)
assert np.allclose(xx_expected, xx_result)