aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Albers Raviola <thomas@thomaslabs.org>2024-05-14 17:34:43 +0200
committerThomas Albers Raviola <thomas@thomaslabs.org>2024-05-14 17:34:43 +0200
commit35b9fdf6b4c8e34e13dac72899f05be9afccc0e1 (patch)
tree0c1c21676346566f6fb84a6039e5744f96150c15
parent45aaf49cb3fe07b006316bacac1318f07a08cc19 (diff)
linsolver: Finish gaussian elimination
* linsolver/solvers.py (gaussian_eliminate): Finish implementation. * linsolver/README.md: Add description. * linsolver/.gitignore: New file.
-rw-r--r--.gitignore2
-rw-r--r--README.md3
-rw-r--r--solvers.py43
3 files changed, 46 insertions, 2 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3883e04
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+__pycache__
+.mypy_cache
diff --git a/README.md b/README.md
index e69de29..22bc85d 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1,3 @@
+# linsolver
+
+Toy implementation of gaussian elimination with partial pivoting in python
diff --git a/solvers.py b/solvers.py
index 7aa9568..da38e2b 100644
--- a/solvers.py
+++ b/solvers.py
@@ -4,17 +4,56 @@ import numpy as np
from numpy.typing import NDArray
-def gaussian_eliminate(aa: NDArray[np.float_], bb: NDArray[np.float_]) -> NDArray[np.float_] | None:
+def gaussian_eliminate(
+ aa: NDArray[np.float_],
+ bb: NDArray[np.float_],
+ tolerance: float = 1e-6
+) -> NDArray[np.float_] | None:
"""Solves a linear system of equations (Ax = b) by Gauss-elimination
Args:
aa: Matrix with the coefficients. Shape: (n, n).
bb: Right hand side of the equation. Shape: (n,)
+ tolerance: Greatest absolute value considered as 0
Returns:
Vector xx with the solution of the linear equation or None
if the equations are linearly dependent.
"""
nn = aa.shape[0]
- xx = np.zeros((nn,), dtype=np.float_)
+
+ ee = np.zeros((nn, nn + 1))
+ ee[:, 0:nn] = aa
+ ee[:, nn] = bb
+
+ i = 0
+ j = 0
+
+ while i < nn and j < nn:
+ pivot_row = i + np.argmax(np.abs(ee[i:, j]))
+
+ if np.abs(ee[pivot_row, j]) < tolerance:
+ j = j + 1
+ continue
+
+ if i != pivot_row:
+ # Swap rows
+ ee[[i, pivot_row]] = ee[[pivot_row, i]]
+
+ for k in range(i + 1, nn):
+ q = ee[k, j] / ee[i, j]
+ ee[k, :] = ee[k, :] - q * ee[i, :]
+
+ i = i + 1
+ j = j + 1
+
+ # Check if rank of matrix is lower than nn
+ if np.abs(ee[nn - 1, nn - 1]) < tolerance:
+ return None
+
+ # Back substitution
+ xx = np.zeros((nn, 1))
+ for i in range(nn - 1, -1, -1):
+ xx[i] = (ee[i, nn] - np.dot(ee[i, i:nn], xx[i:nn])) / ee[i, i]
+
return xx