diff options
Diffstat (limited to 'schroedinger/schroedinger.py')
-rw-r--r-- | schroedinger/schroedinger.py | 89 |
1 files changed, 79 insertions, 10 deletions
diff --git a/schroedinger/schroedinger.py b/schroedinger/schroedinger.py index 63ea5f8..3ed2c3f 100644 --- a/schroedinger/schroedinger.py +++ b/schroedinger/schroedinger.py @@ -1,17 +1,86 @@ +from pathlib import Path -def hello(): - """ - This is a description :) - :return: None - """ - print('Hello') +import numpy as np +from numpy.polynomial import Polynomial +from numpy.typing import NDArray +from scipy.interpolate import CubicSpline +from scipy.linalg import eigh_tridiagonal +class Config: + def __init__(self, path): + current_line = 0 -def main(): - hello() + # Ensure Path object + if path is not Path: + path = Path(path) + def next_parameter(fd): + '''Read next parameter, ignoring comments or empty lines''' + content = None + nonlocal current_line + while not content: + str = fd.readline() + current_line += 1 + index = str.find('#') + content = str[0:index].strip() + return content + with open(path, 'r') as fd: + try: + self.mass = float(next_parameter(fd)) + start, end, steps = next_parameter(fd).split() + self.interval = [float(start), float(end), int(steps)] + self.eig_interval = [int(attr) - 1 for attr in next_parameter(fd).split()] + self.interpolation = next_parameter(fd) -if __name__ == "__main__": + npoints = int(next_parameter(fd)) + self.points = np.zeros((npoints, 2)) + for i in range(npoints): + line = next_parameter(fd) + self.points[i] = np.array([float(comp) for comp in line.split()]) + # TODO: don't be a moron, catch only relevant exceptions + except: + print('Syntax error in \'{}\' line {}'.format(path.name, current_line)) + raise ValueError() - main() + +def potential_interp(interpolation, points): + if interpolation == 'linear': + def line(x): + return np.interp(x, points[:, 0], points[:, 1]) + return line + elif interpolation == 'polynomial': + poly = Polynomial.fit(points[:, 0], points[:, 1], + points.shape[0] - 1) + return poly + elif interpolation == 'cspline': + cs = CubicSpline(points[:, 0], points[:, 1]) + return cs + + raise ValueError() + + +def build_potential(config: Config): + start, end, steps = config.interval + potential = np.zeros((steps, 2)) + potential[:, 0] = np.linspace(start, end, steps) + delta = np.abs(potential[1, 0] - potential[0, 0]) + interp = potential_interp(config.interpolation, config.points) + potential[:, 1] = interp(potential[:, 0]) + return potential, delta + + +def solve_schroedinger(mass, potential, delta, eig_interval=None): + ''' + returns eigen values and wave functions specified by eig_interval + ''' + n = potential.shape[0] + a = 1 / mass / delta**2 + w, v = eigh_tridiagonal(a + potential, + -a * np.ones(n - 1, dtype=np.float_) / 2.0, + select='i', select_range=eig_interval) + # Normalize eigenfunctions + for i in range(w.shape[0]): + v[:, i] /= np.sqrt(delta * np.sum(np.abs(v[:, i])**2)) + + return w, v |