diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test_graphic.py | 38 |
1 files changed, 20 insertions, 18 deletions
diff --git a/test/test_graphic.py b/test/test_graphic.py index 1c23d08..0fa2640 100644 --- a/test/test_graphic.py +++ b/test/test_graphic.py @@ -1,44 +1,46 @@ +''' +Test solutions by comparing with results known to be correct +''' + import pytest import numpy as np -import matplotlib.pyplot as plt - from schroedinger import ( - Config, potential_interp, build_potential, solve_schroedinger + Config, build_potential, solve_schroedinger ) -potential = {} -e = {} -v = {} +POTENTIAL = {} +ENERGIES = {} +WAVEFUNCS = {} FORMS = ['asymmetric', 'double_linear', 'double_cubic', 'harmonic', 'finite'] @pytest.mark.parametrize('form', FORMS) def test_potential(form: str) -> None: + '''Compare potential with stored result''' potential_prec = np.loadtxt(f'test/{form}/potential.dat') - assert np.allclose(potential[form], potential_prec, rtol=1e-2, atol=1e-2) + assert np.allclose(POTENTIAL[form], potential_prec, rtol=1e-2, atol=1e-2) @pytest.mark.parametrize('form', FORMS) -def test_e(form: str) -> None: +def test_energy(form: str) -> None: + '''Compare energy with stored result''' e_prec = np.loadtxt(f'test/{form}/energies.dat') - assert np.allclose(e[form], e_prec, rtol=1e-2, atol=1e-2) + assert np.allclose(ENERGIES[form], e_prec, rtol=1e-2, atol=1e-2) @pytest.mark.parametrize('form', FORMS) -def test_v(form: str) -> None: +def test_wavefunc(form: str) -> None: + '''Compare wave functions with stored result''' v_prec = np.loadtxt(f'test/{form}/wavefuncs.dat')[:, 1:] - assert np.allclose(v[form], v_prec, rtol=1e-2, atol=1e-2) + assert np.allclose(WAVEFUNCS[form], v_prec, rtol=1e-2, atol=1e-2) def setup_module(): - global conf - global potential - global e - global v + '''Set up global variables for tests to run''' for form in FORMS: conf = Config(f'test/{form}.inp') - potential[form], delta = build_potential(conf) - e[form], v[form] = solve_schroedinger(conf.mass, potential[form][:, 1], - delta, conf.eig_interval) + POTENTIAL[form], delta = build_potential(conf) + ENERGIES[form], WAVEFUNCS[form] = solve_schroedinger( + conf.mass, POTENTIAL[form][:, 1], delta, conf.eig_interval) |