aboutsummaryrefslogtreecommitdiff
path: root/schroedinger/schroedinger.py
blob: 3ed2c3fa67a995e08bf0eae2fa4c2330d1274ef2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from pathlib import Path

import numpy as np
from numpy.polynomial import Polynomial
from numpy.typing import NDArray
from scipy.interpolate import CubicSpline
from scipy.linalg import eigh_tridiagonal

class Config:
    '''
    Problem parameters parsed from a plain-text input file.

    The file is read line by line. Everything after a '#' is a comment,
    and lines that are empty once comments are removed are skipped. The
    remaining lines are expected in this order:

    1. mass (one float)
    2. interval: "start end steps" (float, float, int)
    3. eigenvalue range: whitespace-separated integers
    4. interpolation type (kept as a string, not validated here)
    5. number of interpolation points (int)
    6. one line per interpolation point, whitespace-separated floats

    Attributes set on success:
        mass          -- float
        interval      -- list [start, end, steps]
        eig_interval  -- list of the eigenvalue-range integers, each
                         decremented by one
        interpolation -- str
        points        -- array of shape (npoints, 2)

    Raises ValueError if the file ends early or a value cannot be
    parsed. Errors from opening the file propagate unchanged.
    '''

    def __init__(self, path):
        current_line = 0

        # Ensure Path object (path.name is used in the error message)
        if not isinstance(path, Path):
            path = Path(path)

        def next_parameter(fd):
            '''Read next parameter, ignoring comments or empty lines'''
            nonlocal current_line
            while True:
                raw = fd.readline()
                if not raw:
                    # readline() returns '' only at end of file; without
                    # this check a truncated file would loop forever.
                    raise ValueError('unexpected end of file')
                current_line += 1
                # Split on the first '#' rather than slicing with find():
                # find() returns -1 when there is no comment, which would
                # cut off the last character of the line.
                content = raw.split('#', 1)[0].strip()
                if content:
                    return content

        with open(path, 'r') as fd:
            try:
                self.mass = float(next_parameter(fd))
                start, end, steps = next_parameter(fd).split()
                self.interval = [float(start), float(end), int(steps)]
                self.eig_interval = [int(attr) - 1
                                     for attr in next_parameter(fd).split()]
                self.interpolation = next_parameter(fd)

                npoints = int(next_parameter(fd))
                self.points = np.zeros((npoints, 2))
                for i in range(npoints):
                    line = next_parameter(fd)
                    self.points[i] = np.array(
                        [float(comp) for comp in line.split()])
            # Conversion failures, wrong field counts, a premature end of
            # file and shape mismatches all surface as ValueError.
            except ValueError as err:
                message = 'Syntax error in \'{}\' line {}'.format(
                    path.name, current_line)
                print(message)
                raise ValueError(message) from err


def potential_interp(interpolation, points):
    '''
    Build a callable that interpolates the given sample points.

    points is indexed as a 2-D array: column 0 holds the x coordinates
    and column 1 the values at those coordinates.

    interpolation selects the method:
        'linear'     -- piecewise linear (np.interp)
        'polynomial' -- polynomial of degree (number of points - 1)
                        fitted with Polynomial.fit
        'cspline'    -- CubicSpline through the points

    Returns an object that can be called with x values. Raises
    ValueError for any other interpolation name.

    NOTE(review): np.interp does not check that the x coordinates are
    increasing -- confirm that callers supply sorted points.
    '''
    if interpolation == 'linear':
        return lambda x: np.interp(x, points[:, 0], points[:, 1])

    if interpolation == 'polynomial':
        degree = points.shape[0] - 1
        return Polynomial.fit(points[:, 0], points[:, 1], degree)

    if interpolation == 'cspline':
        return CubicSpline(points[:, 0], points[:, 1])

    raise ValueError()


def build_potential(config: Config):
    '''
    Sample the interpolated potential on an evenly spaced grid.

    The grid comes from config.interval (start, end, steps) via
    np.linspace; the potential values come from the interpolant built
    from config.interpolation and config.points.

    Returns (potential, delta): potential is an array of shape
    (steps, 2) with the grid in column 0 and the interpolated values in
    column 1; delta is the absolute distance between the first two grid
    points.
    '''
    x_min, x_max, n_points = config.interval

    table = np.empty((n_points, 2))
    grid = np.linspace(x_min, x_max, n_points)
    table[:, 0] = grid
    spacing = np.abs(grid[1] - grid[0])

    interpolant = potential_interp(config.interpolation, config.points)
    table[:, 1] = interpolant(grid)

    return table, spacing


def solve_schroedinger(mass, potential, delta, eig_interval=None):
    '''
    returns eigen values and wave functions specified by eig_interval

    A symmetric tridiagonal matrix is built with
        diagonal     = 1 / (mass * delta**2) + potential
        off-diagonal = -1 / (2 * mass * delta**2)
    and handed to scipy.linalg.eigh_tridiagonal.

    Parameters:
        mass         -- scalar mass
        potential    -- array of potential values; it is added to the
                        diagonal as-is, and eigh_tridiagonal requires a
                        1-D diagonal
        delta        -- grid spacing
        eig_interval -- (first, last) indices of the eigenpairs to
                        compute, passed as select_range (both ends
                        inclusive); None computes all of them

    Returns (w, v): w holds the eigenvalues in ascending order and
    v[:, i] is the eigenvector for w[i], scaled so that
    delta * sum(|v[:, i]|**2) == 1.
    '''
    n = potential.shape[0]
    a = 1 / mass / delta**2
    diagonal = a + potential
    # np.full yields float64 here; the removed np.float_ alias is not needed.
    off_diagonal = np.full(n - 1, -a / 2.0)

    if eig_interval is None:
        # select='i' rejects select_range=None, so request all eigenpairs.
        w, v = eigh_tridiagonal(diagonal, off_diagonal)
    else:
        w, v = eigh_tridiagonal(diagonal, off_diagonal,
                                select='i', select_range=eig_interval)

    # Normalize eigenfunctions (all columns at once)
    v /= np.sqrt(delta * np.sum(np.abs(v)**2, axis=0))

    return w, v