From 4efcbc29a4faeeb7625922dba78d5d4bc3be5753 Mon Sep 17 00:00:00 2001 From: Thomas Albers Raviola Date: Fri, 12 Jul 2024 17:59:08 +0200 Subject: Split library for handling config and interpolation --- schroedinger/__init__.py | 1 + schroedinger/schrodinger_solve.py | 60 +++---------------------------------- schroedinger/schroedinger.py | 62 ++++++++++++++++++++++++++++++++------- 3 files changed, 57 insertions(+), 66 deletions(-) create mode 100644 schroedinger/__init__.py diff --git a/schroedinger/__init__.py b/schroedinger/__init__.py new file mode 100644 index 0000000..1de4510 --- /dev/null +++ b/schroedinger/__init__.py @@ -0,0 +1 @@ +from .schroedinger import Config diff --git a/schroedinger/schrodinger_solve.py b/schroedinger/schrodinger_solve.py index d301498..618d4ce 100644 --- a/schroedinger/schrodinger_solve.py +++ b/schroedinger/schrodinger_solve.py @@ -1,69 +1,17 @@ import argparse -from pathlib import Path import numpy as np -from numpy.polynomial import Polynomial -from numpy.typing import NDArray - from scipy.linalg import eigh_tridiagonal -import scipy.interpolate as interp - - -class Config: - def __init__(self, path): - current_line = 0 - - # Ensure Path object - if path is not Path: - path = Path(path) - - def next_parameter(fd): - '''Read next parameter, ignoring comments or empty lines''' - content = None - nonlocal current_line - while not content: - str = fd.readline() - current_line += 1 - index = str.find('#') - content = str[0:index].strip() - return content - - with open(path, 'r') as fd: - try: - self.mass = float(next_parameter(fd)) - start, end, steps = next_parameter(fd).split() - self.interval = [float(start), float(end), int(steps)] - self.eig_interval = [int(attr) for attr in next_parameter(fd).split()] - self.interpolation = next_parameter(fd) - - npoints = int(next_parameter(fd)) - self.points = np.zeros((npoints, 2)) - for i in range(npoints): - line = next_parameter(fd) - self.points[i] = np.array([float(comp) for comp in line.split()]) - except: - print('Syntax error in \'{}\' line {}'.format(path.name, current_line)) +from schroedinger import Config, potential_interp def build_potential(config: Config): start, end, steps = config.interval potential = np.zeros((steps, 2)) potential[:, 0] = np.linspace(start, end, steps) delta = np.abs(potential[1, 0] - potential[0, 0]) - - if config.interpolation == 'linear': - potential[:, 1] = np.interp(potential[:, 0], config.points[:, 0], - config.points[:, 1]) - elif config.interpolation == 'polynomial': - p = Polynomial.fit(config.points[:, 0], config.points[:, 1], - config.points.shape[0] - 1) - potential[:, 1] = p(potential[:, 0]) - elif config.interpolation == 'cspline': - cs = CubicSpline(config.points[:, 0], config.points[:, 1]) - potential[:, 1] = cs(potential[:, 0]) - else: - raise ValueError() - + interp = potential_interp(config.interpolation, config.points) + potential[:, 1] = interp(potential[:, 0]) return potential, delta @@ -116,7 +64,7 @@ def main(): potential, delta = build_potential(conf) np.savetxt('potential.dat', potential) - e, f = solve_schroedinger(conf.mass, potential[:, 1], delta, conf.eig_interval) + e, v = solve_schroedinger(conf.mass, potential[:, 1], delta, conf.eig_interval) np.savetxt('energies.dat', e) save_wavefuncs('wavefuncs.dat', potential[:, 0], v) diff --git a/schroedinger/schroedinger.py b/schroedinger/schroedinger.py index 63ea5f8..10d994e 100644 --- a/schroedinger/schroedinger.py +++ b/schroedinger/schroedinger.py @@ -1,17 +1,59 @@ +from pathlib import Path -def hello(): - """ - This is a description :) - :return: None - """ - print('Hello') +import numpy as np +from numpy.polynomial import Polynomial +from numpy.typing import NDArray +from scipy.interpolate import CubicSpline +class Config: + def __init__(self, path): + current_line = 0 -def main(): - hello() + # Ensure Path object + if path is not Path: + path = Path(path) + def next_parameter(fd): + '''Read next parameter, ignoring comments or empty lines''' + content = None + nonlocal current_line + while not content: + str = fd.readline() + current_line += 1 + index = str.find('#') + content = str[0:index].strip() + return content + with open(path, 'r') as fd: + try: + self.mass = float(next_parameter(fd)) + start, end, steps = next_parameter(fd).split() + self.interval = [float(start), float(end), int(steps)] + self.eig_interval = [int(attr) for attr in next_parameter(fd).split()] + self.interpolation = next_parameter(fd) -if __name__ == "__main__": + npoints = int(next_parameter(fd)) + self.points = np.zeros((npoints, 2)) + for i in range(npoints): + line = next_parameter(fd) + self.points[i] = np.array([float(comp) for comp in line.split()]) + # TODO: don't be a moron, catch only relevant exceptions + except: + print('Syntax error in \'{}\' line {}'.format(path.name, current_line)) + raise ValueError() - main() + +def potential_interp(interpolation, points): + if interpolation == 'linear': + def line(x): + return np.interp(x, points[:, 0], points[:, 1]) + return line + elif interpolation == 'polynomial': + poly = Polynomial.fit(points[:, 0], points[:, 1], + points.shape[0] - 1) + return poly + elif interpolation == 'cspline': + cs = CubicSpline(points[:, 0], points[:, 1]) + return cs + + raise ValueError() -- cgit v1.2.3