diff options
Diffstat (limited to 'schroedinger/schrodinger_solve.py')
-rw-r--r-- | schroedinger/schrodinger_solve.py | 60 |
1 files changed, 4 insertions, 56 deletions
diff --git a/schroedinger/schrodinger_solve.py b/schroedinger/schrodinger_solve.py index d301498..618d4ce 100644 --- a/schroedinger/schrodinger_solve.py +++ b/schroedinger/schrodinger_solve.py @@ -1,69 +1,17 @@ import argparse -from pathlib import Path import numpy as np -from numpy.polynomial import Polynomial -from numpy.typing import NDArray - from scipy.linalg import eigh_tridiagonal -import scipy.interpolate as interp - - -class Config: - def __init__(self, path): - current_line = 0 - - # Ensure Path object - if path is not Path: - path = Path(path) - - def next_parameter(fd): - '''Read next parameter, ignoring comments or empty lines''' - content = None - nonlocal current_line - while not content: - str = fd.readline() - current_line += 1 - index = str.find('#') - content = str[0:index].strip() - return content - - with open(path, 'r') as fd: - try: - self.mass = float(next_parameter(fd)) - start, end, steps = next_parameter(fd).split() - self.interval = [float(start), float(end), int(steps)] - self.eig_interval = [int(attr) for attr in next_parameter(fd).split()] - self.interpolation = next_parameter(fd) - - npoints = int(next_parameter(fd)) - self.points = np.zeros((npoints, 2)) - for i in range(npoints): - line = next_parameter(fd) - self.points[i] = np.array([float(comp) for comp in line.split()]) - except: - print('Syntax error in \'{}\' line {}'.format(path.name, current_line)) +from schroedinger import Config, potential_interp def build_potential(config: Config): start, end, steps = config.interval potential = np.zeros((steps, 2)) potential[:, 0] = np.linspace(start, end, steps) delta = np.abs(potential[1, 0] - potential[0, 0]) - - if config.interpolation == 'linear': - potential[:, 1] = np.interp(potential[:, 0], config.points[:, 0], - config.points[:, 1]) - elif config.interpolation == 'polynomial': - p = Polynomial.fit(config.points[:, 0], config.points[:, 1], - config.points.shape[0] - 1) - potential[:, 1] = p(potential[:, 0]) - elif config.interpolation == 'cspline': - cs = CubicSpline(config.points[:, 0], config.points[:, 1]) - potential[:, 1] = cs(potential[:, 0]) - else: - raise ValueError() - + interp = potential_interp(config.interpolation, config.points) + potential[:, 1] = interp(potential[:, 0]) return potential, delta @@ -116,7 +64,7 @@ def main(): potential, delta = build_potential(conf) np.savetxt('potential.dat', potential) - e, f = solve_schroedinger(conf.mass, potential[:, 1], delta, conf.eig_interval) + e, v = solve_schroedinger(conf.mass, potential[:, 1], delta, conf.eig_interval) np.savetxt('energies.dat', e) save_wavefuncs('wavefuncs.dat', potential[:, 0], v) |