aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Albers Raviola <thomas@thomaslabs.org>2024-07-19 16:11:50 +0200
committerThomas Albers Raviola <thomas@thomaslabs.org>2024-07-19 16:11:50 +0200
commit17faed5a64ca7c989b81dc3b27467e5c0aa14733 (patch)
tree4ea3e866cf5e41db3deff4676fde6fafc56757a3
parent9ff2d4dd949f74b0052d21ecd83e398ab818b02e (diff)
Add type hinting
-rw-r--r--schroedinger/schrodinger_solve.py16
-rw-r--r--schroedinger/schroedinger.py21
-rw-r--r--test/test_graphic.py10
-rw-r--r--test/test_infinite.py9
4 files changed, 37 insertions, 19 deletions
diff --git a/schroedinger/schrodinger_solve.py b/schroedinger/schrodinger_solve.py
index 6885100..2eb83e8 100644
--- a/schroedinger/schrodinger_solve.py
+++ b/schroedinger/schrodinger_solve.py
@@ -3,15 +3,19 @@ import argparse
from pathlib import Path
import numpy as np
+from numpy.typing import NDArray
from schroedinger import (
Config, potential_interp, build_potential, solve_schroedinger
-
)
DESCRIPTION='Solve time independent Schrödinger\'s equation for a given system.'
-def save_wavefuncs(filename, x, v):
+def save_wavefuncs(
+ filename: Path,
+ x: NDArray[np.float64],
+ v: NDArray[np.float64]
+) -> None:
wavefuncs = np.zeros((x.shape[0], v.shape[1] + 1))
wavefuncs[:, 0] = x
for i in range(v.shape[1]):
@@ -19,7 +23,11 @@ def save_wavefuncs(filename, x, v):
np.savetxt(filename, wavefuncs)
-def save_expvalues(filename, x, v):
+def save_expvalues(
+ filename: Path,
+ x: NDArray[np.float64],
+ v: NDArray[np.float64]
+) -> None:
n = v.shape[1]
delta = np.abs(x[1] - x[0])
expvalues = np.zeros((n, 2))
@@ -31,7 +39,7 @@ def save_expvalues(filename, x, v):
np.savetxt(filename, expvalues)
-def main():
+def main() -> None:
parser = argparse.ArgumentParser(
prog='schrodinger_solve',
description=DESCRIPTION,
diff --git a/schroedinger/schroedinger.py b/schroedinger/schroedinger.py
index 4c5d1f7..4556d92 100644
--- a/schroedinger/schroedinger.py
+++ b/schroedinger/schroedinger.py
@@ -1,11 +1,14 @@
from pathlib import Path
+from typing import TextIO
import numpy as np
from numpy.polynomial import Polynomial
from numpy.typing import NDArray
+
from scipy.interpolate import CubicSpline
from scipy.linalg import eigh_tridiagonal
+
class Config:
def __init__(self, path):
current_line = 0
@@ -14,7 +17,7 @@ class Config:
if path is not Path:
path = Path(path)
- def next_parameter(fd):
+ def next_parameter(fd: TextIO):
'''Read next parameter, ignoring comments or empty lines'''
content = None
nonlocal current_line
@@ -44,7 +47,10 @@ class Config:
raise ValueError()
-def potential_interp(interpolation, points):
+def potential_interp(
+ interpolation: str,
+ points: NDArray[np.float64]
+):
if interpolation == 'linear':
def line(x):
return np.interp(x, points[:, 0], points[:, 1])
@@ -60,7 +66,9 @@ def potential_interp(interpolation, points):
raise ValueError()
-def build_potential(config: Config):
+def build_potential(
+ config: Config
+) -> tuple[NDArray[np.float64], np.float64]:
start, end, steps = config.interval
potential = np.zeros((steps, 2))
potential[:, 0] = np.linspace(start, end, steps)
@@ -70,7 +78,12 @@ def build_potential(config: Config):
return potential, delta
-def solve_schroedinger(mass, potential, delta, eig_interval=None):
+def solve_schroedinger(
+ mass: float,
+ potential: NDArray[np.float64],
+ delta: float,
+ eig_interval: tuple[int, int]=None
+) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
'''
returns eigen values and wave functions specified by eig_interval
'''
diff --git a/test/test_graphic.py b/test/test_graphic.py
index 73d8b82..8433cae 100644
--- a/test/test_graphic.py
+++ b/test/test_graphic.py
@@ -1,7 +1,3 @@
-from pathlib import Path
-import sys
-sys.path.insert(0, str(Path.cwd().parent/'schroedinger'))
-
import pytest
import numpy as np
@@ -18,19 +14,19 @@ v = {}
FORMS = ['asymmetric', 'double_linear', 'double_cubic']
@pytest.mark.parametrize('form', FORMS)
-def test_potential(form):
+def test_potential(form: str) -> None:
potential_prec = np.loadtxt(f'test/{form}/potential.dat')
assert np.allclose(potential[form], potential_prec, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize('form', FORMS)
-def test_e(form):
+def test_e(form: str) -> None:
e_prec = np.loadtxt(f'test/{form}/energies.dat')
assert np.allclose(e[form], e_prec, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize('form', FORMS)
-def test_v(form):
+def test_v(form: str) -> None:
v_prec = np.loadtxt(f'test/{form}/wavefuncs.dat')[:, 1:]
assert np.allclose(v[form], v_prec, rtol=1e-2, atol=1e-2)
diff --git a/test/test_infinite.py b/test/test_infinite.py
index aa85ae5..bca21b6 100644
--- a/test/test_infinite.py
+++ b/test/test_infinite.py
@@ -1,5 +1,6 @@
import pytest
import numpy as np
+from numpy.typing import NDArray
import matplotlib.pyplot as plt
@@ -8,17 +9,17 @@ from schroedinger import (
)
-def psi(x, n, a):
+def psi(x: NDArray[np.float64], n: int, a: float) -> NDArray[np.float64]:
n += 1 # Index starting from 0
x = -x + a / 2 # Reflect x and move to the left
return np.sqrt(2 / a) * np.sin(n * np.pi * x / a)
-def energy(n, a, m):
- return (n + 1)**2 * np.pi**2 / m / a**2 / 2.0
+def energy(n: int, a: float, mass: float) -> float:
+ return (n + 1)**2 * np.pi**2 / mass / a**2 / 2.0
-def test_infinite():
+def test_infinite() -> None:
conf = Config('test/infinite.inp')
potential, delta = build_potential(conf)