From 17faed5a64ca7c989b81dc3b27467e5c0aa14733 Mon Sep 17 00:00:00 2001 From: Thomas Albers Raviola Date: Fri, 19 Jul 2024 16:11:50 +0200 Subject: Add type hinting --- schroedinger/schrodinger_solve.py | 16 ++++++++++++---- schroedinger/schroedinger.py | 21 +++++++++++++++++---- test/test_graphic.py | 10 +++------- test/test_infinite.py | 9 +++++---- 4 files changed, 37 insertions(+), 19 deletions(-) diff --git a/schroedinger/schrodinger_solve.py b/schroedinger/schrodinger_solve.py index 6885100..2eb83e8 100644 --- a/schroedinger/schrodinger_solve.py +++ b/schroedinger/schrodinger_solve.py @@ -3,15 +3,19 @@ import argparse from pathlib import Path import numpy as np +from numpy.typing import NDArray from schroedinger import ( Config, potential_interp, build_potential, solve_schroedinger - ) DESCRIPTION='Solve time independent Schrödinger\'s equation for a given system.' -def save_wavefuncs(filename, x, v): +def save_wavefuncs( + filename: Path, + x: NDArray[np.float64], + v: NDArray[np.float64] +) -> None: wavefuncs = np.zeros((x.shape[0], v.shape[1] + 1)) wavefuncs[:, 0] = x for i in range(v.shape[1]): @@ -19,7 +23,11 @@ def save_wavefuncs(filename, x, v): np.savetxt(filename, wavefuncs) -def save_expvalues(filename, x, v): +def save_expvalues( + filename: Path, + x: NDArray[np.float64], + v: NDArray[np.float64] +) -> None: n = v.shape[1] delta = np.abs(x[1] - x[0]) expvalues = np.zeros((n, 2)) @@ -31,7 +39,7 @@ def save_expvalues(filename, x, v): np.savetxt(filename, expvalues) -def main(): +def main() -> None: parser = argparse.ArgumentParser( prog='schrodinger_solve', description=DESCRIPTION, diff --git a/schroedinger/schroedinger.py b/schroedinger/schroedinger.py index 4c5d1f7..4556d92 100644 --- a/schroedinger/schroedinger.py +++ b/schroedinger/schroedinger.py @@ -1,11 +1,14 @@ from pathlib import Path +from typing import TextIO import numpy as np from numpy.polynomial import Polynomial from numpy.typing import NDArray + from scipy.interpolate import CubicSpline from scipy.linalg import eigh_tridiagonal + class Config: def __init__(self, path): current_line = 0 @@ -14,7 +17,7 @@ class Config: if path is not Path: path = Path(path) - def next_parameter(fd): + def next_parameter(fd: TextIO): '''Read next parameter, ignoring comments or empty lines''' content = None nonlocal current_line @@ -44,7 +47,10 @@ class Config: raise ValueError() -def potential_interp(interpolation, points): +def potential_interp( + interpolation: str, + points: NDArray[np.float64] +): if interpolation == 'linear': def line(x): return np.interp(x, points[:, 0], points[:, 1]) @@ -60,7 +66,9 @@ def potential_interp(interpolation, points): raise ValueError() -def build_potential(config: Config): +def build_potential( + config: Config +) -> tuple[NDArray[np.float64], np.float64]: start, end, steps = config.interval potential = np.zeros((steps, 2)) potential[:, 0] = np.linspace(start, end, steps) @@ -70,7 +78,12 @@ def build_potential(config: Config): return potential, delta -def solve_schroedinger(mass, potential, delta, eig_interval=None): +def solve_schroedinger( + mass: float, + potential: NDArray[np.float64], + delta: float, + eig_interval: tuple[int, int]=None +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: ''' returns eigen values and wave functions specified by eig_interval ''' diff --git a/test/test_graphic.py b/test/test_graphic.py index 73d8b82..8433cae 100644 --- a/test/test_graphic.py +++ b/test/test_graphic.py @@ -1,7 +1,3 @@ -from pathlib import Path -import sys -sys.path.insert(0, str(Path.cwd().parent/'schroedinger')) - import pytest import numpy as np @@ -18,19 +14,19 @@ v = {} FORMS = ['asymmetric', 'double_linear', 'double_cubic'] @pytest.mark.parametrize('form', FORMS) -def test_potential(form): +def test_potential(form: str) -> None: potential_prec = np.loadtxt(f'test/{form}/potential.dat') assert np.allclose(potential[form], potential_prec, rtol=1e-2, atol=1e-2) @pytest.mark.parametrize('form', FORMS) -def test_e(form): +def test_e(form: str) -> None: e_prec = np.loadtxt(f'test/{form}/energies.dat') assert np.allclose(e[form], e_prec, rtol=1e-2, atol=1e-2) @pytest.mark.parametrize('form', FORMS) -def test_v(form): +def test_v(form: str) -> None: v_prec = np.loadtxt(f'test/{form}/wavefuncs.dat')[:, 1:] assert np.allclose(v[form], v_prec, rtol=1e-2, atol=1e-2) diff --git a/test/test_infinite.py b/test/test_infinite.py index aa85ae5..bca21b6 100644 --- a/test/test_infinite.py +++ b/test/test_infinite.py @@ -1,5 +1,6 @@ import pytest import numpy as np +from numpy.typing import NDArray import matplotlib.pyplot as plt @@ -8,17 +9,17 @@ from schroedinger import ( ) -def psi(x, n, a): +def psi(x: NDArray[np.float64], n: int, a: float) -> NDArray[np.float64]: n += 1 # Index starting from 0 x = -x + a / 2 # Reflect x and move to the left return np.sqrt(2 / a) * np.sin(n * np.pi * x / a) -def energy(n, a, m): - return (n + 1)**2 * np.pi**2 / m / a**2 / 2.0 +def energy(n: int, a: float, mass: float) -> float: + return (n + 1)**2 * np.pi**2 / mass / a**2 / 2.0 -def test_infinite(): +def test_infinite() -> None: conf = Config('test/infinite.inp') potential, delta = build_potential(conf) -- cgit v1.2.3 From 3907db7c10c6745dda60050d21036a602f2c3570 Mon Sep 17 00:00:00 2001 From: Thomas Albers Raviola Date: Fri, 19 Jul 2024 16:53:53 +0200 Subject: Add documentation to common library --- schroedinger/schroedinger.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/schroedinger/schroedinger.py b/schroedinger/schroedinger.py index 4556d92..3554c6b 100644 --- a/schroedinger/schroedinger.py +++ b/schroedinger/schroedinger.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import TextIO +from typing import TextIO, Type import numpy as np from numpy.polynomial import Polynomial @@ -9,7 +9,14 @@ from scipy.interpolate import CubicSpline from scipy.linalg import eigh_tridiagonal +Interpolator = Type[callable] | Type[Polynomial] | Type[CubicSpline] +# type Interpolator = callable | Polynomial | CubicSpline + + class Config: + '''Wrapper for reading, parsing and storing the contents of a configuration + file + ''' def __init__(self, path): current_line = 0 @@ -50,7 +57,13 @@ class Config: def potential_interp( interpolation: str, points: NDArray[np.float64] -): +) -> Interpolator: + '''Create an interpolator for a set of predefined potential points + + :param interpolation: Kind of interpolation + :param points: Points to interpolate within + :return: Interpolating object + ''' if interpolation == 'linear': def line(x): return np.interp(x, points[:, 0], points[:, 1]) @@ -69,6 +82,11 @@ def potential_interp( def build_potential( config: Config ) -> tuple[NDArray[np.float64], np.float64]: + '''Build a potential based on the options inside a configuration file + + :param config: System parameters + :return: Potential and distance between samples + ''' start, end, steps = config.interval potential = np.zeros((steps, 2)) potential[:, 0] = np.linspace(start, end, steps) @@ -84,8 +102,14 @@ def solve_schroedinger( delta: float, eig_interval: tuple[int, int]=None ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: - ''' - returns eigen values and wave functions specified by eig_interval + '''Solve the one dimensional, time independent Schrödinger's equation for a + particle of given mass inside a discretized potential + + :param mass: Mass of the particle + :param potential: Discretized potential + :param delta: Distance between points in the potential + :param eig_interval: Interval of quantum numbers for which the states are to be calculated + :return: Eigenvalues and normalized wave functions ''' n = potential.shape[0] a = 1 / mass / delta**2 -- cgit v1.2.3