aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--README.md9
-rw-r--r--solvers.py2
-rw-r--r--test/__init__.py0
-rw-r--r--test/test_solvers.py (renamed from test_solvers.py)28
5 files changed, 14 insertions, 26 deletions
diff --git a/.gitignore b/.gitignore
index 3883e04..a2eeca3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
__pycache__
.mypy_cache
+.pytest_cache
diff --git a/README.md b/README.md
index 22bc85d..28ed767 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,12 @@
# linsolver
Toy implementation of gaussian elimination with partial pivoting in python
+
+# Testing
+
+In order to test the code, run the following command in the project main's
+folder
+
+```
+$ python -m pytest
+```
diff --git a/solvers.py b/solvers.py
index b544f1b..040c6f4 100644
--- a/solvers.py
+++ b/solvers.py
@@ -114,4 +114,4 @@ def gaussian_eliminate(
y = forward_substitution(ll, pp @ bb)
# U @ x = y
x = back_substitution(uu, y)
- return x
+ return x[:, 0]
diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/__init__.py
diff --git a/test_solvers.py b/test/test_solvers.py
index ad37c8c..e5b73d8 100644
--- a/test_solvers.py
+++ b/test/test_solvers.py
@@ -5,25 +5,13 @@ import numpy as np
from numpy.typing import NDArray
import solvers
-
-def main() -> None:
- """Main testing function."""
-
- print("\nTest elimination")
- test_elimination_3()
- print("\nTest pivot")
- test_pivot_3()
- print("\nTest linear dependency")
- test_lindep_3()
-
-
def test_elimination_3() -> None:
"""Tests elimination with 3 variables."""
aa = np.array([[2.0, 4.0, 4.0], [5.0, 4.0, 2.0], [1.0, 2.0, -1.0]], dtype=np.float_)
bb = np.array([1.0, 4.0, 2.0], dtype=np.float_)
xx_expected = np.array([0.666666666666667, 0.416666666666667, -0.5], dtype=np.float_)
xx_gauss = solvers.gaussian_eliminate(aa, bb)
- _check_result(xx_expected, xx_gauss)
+ assert np.allclose(xx_expected, xx_gauss)
def test_pivot_3() -> None:
@@ -32,7 +20,7 @@ def test_pivot_3() -> None:
bb = np.array([1.0, 2.0, 4.0])
xx_expected = np.array([0.666666666666667, 0.416666666666667, -0.5])
xx_gauss = solvers.gaussian_eliminate(aa, bb)
- _check_result(xx_expected, xx_gauss)
+ assert np.allclose(xx_expected, xx_gauss)
def test_lindep_3() -> None:
@@ -41,14 +29,4 @@ def test_lindep_3() -> None:
bb = np.array([1.0, 2.0, 3.0])
xx_expected = None
xx_gauss = solvers.gaussian_eliminate(aa, bb)
- _check_result(xx_expected, xx_gauss)
-
-
-def _check_result(expected: NDArray[np.float_] | None, obtained: NDArray[np.float_] | None) -> None:
- """Checks results by printing expected and obtained one."""
- print("Expected:", expected)
- print("Obtained:", obtained)
-
-
-if __name__ == "__main__":
- main()
+ assert xx_expected == xx_gauss