From c24397c9d2ca50b3989de5107498666724e8c50c Mon Sep 17 00:00:00 2001 From: Thomas Albers Raviola Date: Tue, 21 May 2024 00:23:43 +0200 Subject: Modify gaussian_eliminate to instead use LUP factorization * linsolver/solvers.py (gaussian_eliminate): Use LUP factorization. * linsolver/solvers.py (lu_factorization): New function. * linsolver/solvers.py (back_substitution): New function. * linsolver/solvers.py (forward_substitution): New function. --- solvers.py | 112 ++++++++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 85 insertions(+), 27 deletions(-) diff --git a/solvers.py b/solvers.py index 933bc2f..b544f1b 100644 --- a/solvers.py +++ b/solvers.py @@ -1,59 +1,117 @@ -"""Routines for solving a linear system of equations.""" +'''Routines for solving a linear system of equations.''' import numpy as np from numpy.typing import NDArray - -def gaussian_eliminate( - aa: NDArray[np.float_], - bb: NDArray[np.float_], +def lu_factorization( + mm: NDArray[np.float_], tolerance: float = 1e-6 -) -> NDArray[np.float_] | None: - """Solves a linear system of equations (Ax = b) by Gauss-elimination +) -> NDArray[np.float_]: + '''Computes the LUP factorization of a matrix Args: - aa: Matrix with the coefficients. Shape: (n, n). - bb: Right hand side of the equation. Shape: (n,) + mm: Matrix to factorize tolerance: Greatest absolute value considered as 0 Returns: - Vector xx with the solution of the linear equation or None - if the equations are linearly dependent. - """ - nn = aa.shape[0] + ll: Lower triangular matrix + uu: Upper triangular matrix + pp: Permutation matrix + ''' + nn = mm.shape[0] - ee = np.zeros((nn, nn + 1)) - ee[:, 0:nn] = aa - ee[:, nn] = bb + uu = mm.copy() + ll = np.zeros((nn, nn), dtype=np.float_) + pp = np.eye(nn, dtype=np.float_) i = 0 j = 0 while i < nn and j < nn: - pivot_row = i + np.argmax(np.abs(ee[i:, j])) + pivot_row = i + np.argmax(np.abs(uu[i:, j])) - if np.abs(ee[pivot_row, j]) < tolerance: + if np.abs(uu[pivot_row, j]) < tolerance: j = j + 1 continue if i != pivot_row: # Swap rows - ee[[i, pivot_row]] = ee[[pivot_row, i]] + uu[[i, pivot_row]] = uu[[pivot_row, i]] + ll[[i, pivot_row]] = ll[[pivot_row, i]] + pp[[i, pivot_row]] = pp[[pivot_row, i]] for k in range(i + 1, nn): - q = ee[k, j] / ee[i, j] - ee[k, :] = ee[k, :] - q * ee[i, :] + q = uu[k, j] / uu[i, j] + uu[k, :] = uu[k, :] - q * uu[i, :] + ll[k, j] = q i = i + 1 j = j + 1 - # Check if rank of matrix is lower than nn - if np.abs(ee[nn - 1, nn - 1]) < tolerance: - return None + return ll + np.eye(nn), uu, pp + +def forward_substitution( + aa: NDArray[np.float_], + bb: NDArray[np.float_] +) -> NDArray[np.float_]: + '''Solves (T x = b), where T is a lower triangular matrix - # Back substitution + Args: + aa: lower triangular matrix + bb: vector + + Returns: + x: solution of the system of equations + ''' + nn = aa.shape[0] xx = np.zeros((nn, 1), dtype=np.float_) - for i in range(nn - 1, -1, -1): - xx[i] = (ee[i, nn] - np.dot(ee[i, i:nn], xx[i:nn])) / ee[i, i] + for i in range(nn): + xx[i] = (bb[i] - np.dot(aa[i, 0:i], xx[0:i])) / aa[i, i] + return xx + +def back_substitution( + aa: NDArray[np.float_], + bb: NDArray[np.float_] +) -> NDArray[np.float_]: + '''Solves (T x = b), where T is a upper triangular matrix + + Args: + aa: upper triangular matrix + bb: vector + Returns: + x: solution of the system of equations + ''' + nn = aa.shape[0] + xx = np.zeros((nn, 1), dtype=np.float_) + for i in range(nn - 1, -1, -1): + xx[i] = (bb[i] - np.dot(aa[i, i:], xx[i:nn])) / aa[i, i] return xx + +def gaussian_eliminate( + aa: NDArray[np.float_], + bb: NDArray[np.float_], + tolerance: float = 1e-6 +) -> NDArray[np.float_] | None: + '''Solves a linear system of equations (A x = b) by LUP factorization. + + Args: + aa: upper triangular matrix + bb: vector + tolerance: Greatest absolute value considered as 0 + + Returns: + x: solution of the system of equations + ''' + ll, uu, pp = lu_factorization(aa) + + nn = uu.shape[0] + # Check if rank of matrix is lower than nn + if np.abs(uu[nn - 1, nn - 1]) < tolerance: + return None + + # L y = P @ b + y = forward_substitution(ll, pp @ bb) + # U @ x = y + x = back_substitution(uu, y) + return x -- cgit v1.2.3