aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/test_graphic.py38
1 files changed, 20 insertions, 18 deletions
diff --git a/test/test_graphic.py b/test/test_graphic.py
index 1c23d08..0fa2640 100644
--- a/test/test_graphic.py
+++ b/test/test_graphic.py
@@ -1,44 +1,46 @@
+'''
+Test solutions by comparing with results known to be correct
+'''
+
import pytest
import numpy as np
-import matplotlib.pyplot as plt
-
from schroedinger import (
- Config, potential_interp, build_potential, solve_schroedinger
+ Config, build_potential, solve_schroedinger
)
-potential = {}
-e = {}
-v = {}
+POTENTIAL = {}
+ENERGIES = {}
+WAVEFUNCS = {}
FORMS = ['asymmetric', 'double_linear', 'double_cubic', 'harmonic', 'finite']
@pytest.mark.parametrize('form', FORMS)
def test_potential(form: str) -> None:
+ '''Compare potential with stored result'''
potential_prec = np.loadtxt(f'test/{form}/potential.dat')
- assert np.allclose(potential[form], potential_prec, rtol=1e-2, atol=1e-2)
+ assert np.allclose(POTENTIAL[form], potential_prec, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize('form', FORMS)
-def test_e(form: str) -> None:
+def test_energy(form: str) -> None:
+ '''Compare energy with stored result'''
e_prec = np.loadtxt(f'test/{form}/energies.dat')
- assert np.allclose(e[form], e_prec, rtol=1e-2, atol=1e-2)
+ assert np.allclose(ENERGIES[form], e_prec, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize('form', FORMS)
-def test_v(form: str) -> None:
+def test_wavefunc(form: str) -> None:
+ '''Compare wave functions with stored result'''
v_prec = np.loadtxt(f'test/{form}/wavefuncs.dat')[:, 1:]
- assert np.allclose(v[form], v_prec, rtol=1e-2, atol=1e-2)
+ assert np.allclose(WAVEFUNCS[form], v_prec, rtol=1e-2, atol=1e-2)
def setup_module():
- global conf
- global potential
- global e
- global v
+ '''Set up global variables for tests to run'''
for form in FORMS:
conf = Config(f'test/{form}.inp')
- potential[form], delta = build_potential(conf)
- e[form], v[form] = solve_schroedinger(conf.mass, potential[form][:, 1],
- delta, conf.eig_interval)
+ POTENTIAL[form], delta = build_potential(conf)
+ ENERGIES[form], WAVEFUNCS[form] = solve_schroedinger(
+ conf.mass, POTENTIAL[form][:, 1], delta, conf.eig_interval)