aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/test_infinite.py55
1 files changed, 28 insertions, 27 deletions
diff --git a/test/test_infinite.py b/test/test_infinite.py
index bca21b6..066604c 100644
--- a/test/test_infinite.py
+++ b/test/test_infinite.py
@@ -1,46 +1,47 @@
-import pytest
+'''
+Test solutions to the infinite potential well by comparing with theoretical
+values
+'''
+
import numpy as np
from numpy.typing import NDArray
-import matplotlib.pyplot as plt
-
from schroedinger import (
- Config, potential_interp, build_potential, solve_schroedinger
+ Config, build_potential, solve_schroedinger
)
-
-def psi(x: NDArray[np.float64], n: int, a: float) -> NDArray[np.float64]:
- n += 1 # Index starting from 0
- x = -x + a / 2 # Reflect x and move to the left
- return np.sqrt(2 / a) * np.sin(n * np.pi * x / a)
+def psi(x: NDArray[np.float64], level: int, resolution: float) -> NDArray[np.float64]:
+ '''Generate the n=level wave function'''
+ level += 1 # Index starting from 0
+ x = -x + resolution / 2 # Reflect x and move to the left
+ return np.sqrt(2 / resolution) * np.sin(level * np.pi * x / resolution)
-def energy(n: int, a: float, mass: float) -> float:
- return (n + 1)**2 * np.pi**2 / mass / a**2 / 2.0
+def energy(level: int, resolution: float, mass: float) -> float:
+ '''Energy eigenvalue for the n=level wave function'''
+ return (level + 1)**2 * np.pi**2 / mass / resolution**2 / 2.0
def test_infinite() -> None:
+ '''Test infinite potential well'''
conf = Config('test/infinite.inp')
potential, delta = build_potential(conf)
- e, v = solve_schroedinger(conf.mass, potential[:, 1], delta, conf.eig_interval)
- # Account for -v also being an eigenvector if v is one
- v = np.abs(v)
+ energies, wavefuncs = solve_schroedinger(conf.mass, potential[:, 1], delta,
+ conf.eig_interval)
- a = np.abs(conf.points[1][0] - conf.points[0][0])
- e_theory = np.zeros(e.shape)
- v_theory = np.zeros(v.shape)
+ # Account for -v also being an eigenvector if v is one
+ wavefuncs = np.abs(wavefuncs)
+ resolution = np.abs(conf.points[1][0] - conf.points[0][0])
+ energies_theory = np.zeros(energies.shape)
+ wavefuncs_theory = np.zeros(wavefuncs.shape)
- for n in range(e.shape[0]):
- e_theory[n] = energy(n, a, conf.mass)
- v_theory[:, n] = np.abs(psi(potential[:, 0], n, a))
- # for n in range(e.shape[0]):
- # plt.plot(potential[:, 0], np.abs(v[:, n]), label='Num{}'.format(n))
- # plt.plot(potential[:, 0], np.abs(v_theory[:, n]), ls='--', label='Theory{}'.format(n))
- # plt.legend()
- # plt.savefig('test.pdf')
+ for level in range(energies.shape[0]):
+ energies_theory[level] = energy(level, resolution, conf.mass)
+ wavefuncs_theory[:, level] = np.abs(psi(potential[:, 0], level,
+ resolution))
- assert (np.allclose(e, e_theory, rtol=1e-2, atol=1e-2)
- and np.allclose(v, v_theory, rtol=1e-2, atol=1e-2))
+ assert (np.allclose(energies, energies_theory, rtol=1e-2, atol=1e-2)
+ and np.allclose(wavefuncs, wavefuncs_theory, rtol=1e-2, atol=1e-2))