aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Albers Raviola <thomas@thomaslabs.org>2024-07-12 17:59:08 +0200
committerThomas Albers Raviola <thomas@thomaslabs.org>2024-07-12 17:59:08 +0200
commit4efcbc29a4faeeb7625922dba78d5d4bc3be5753 (patch)
treeabb2a3b09801c588a36c07d024c63a00b49b348a
parent599daa156166ae93c315358769ae56145c46ac12 (diff)
Split library for handling config and interpolation
-rw-r--r--schroedinger/__init__.py1
-rw-r--r--schroedinger/schrodinger_solve.py60
-rw-r--r--schroedinger/schroedinger.py62
3 files changed, 57 insertions, 66 deletions
diff --git a/schroedinger/__init__.py b/schroedinger/__init__.py
new file mode 100644
index 0000000..1de4510
--- /dev/null
+++ b/schroedinger/__init__.py
@@ -0,0 +1 @@
+from .schroedinger import Config
diff --git a/schroedinger/schrodinger_solve.py b/schroedinger/schrodinger_solve.py
index d301498..618d4ce 100644
--- a/schroedinger/schrodinger_solve.py
+++ b/schroedinger/schrodinger_solve.py
@@ -1,69 +1,17 @@
import argparse
-from pathlib import Path
import numpy as np
-from numpy.polynomial import Polynomial
-from numpy.typing import NDArray
-
from scipy.linalg import eigh_tridiagonal
-import scipy.interpolate as interp
-
-
-class Config:
- def __init__(self, path):
- current_line = 0
-
- # Ensure Path object
- if path is not Path:
- path = Path(path)
-
- def next_parameter(fd):
- '''Read next parameter, ignoring comments or empty lines'''
- content = None
- nonlocal current_line
- while not content:
- str = fd.readline()
- current_line += 1
- index = str.find('#')
- content = str[0:index].strip()
- return content
-
- with open(path, 'r') as fd:
- try:
- self.mass = float(next_parameter(fd))
- start, end, steps = next_parameter(fd).split()
- self.interval = [float(start), float(end), int(steps)]
- self.eig_interval = [int(attr) for attr in next_parameter(fd).split()]
- self.interpolation = next_parameter(fd)
-
- npoints = int(next_parameter(fd))
- self.points = np.zeros((npoints, 2))
- for i in range(npoints):
- line = next_parameter(fd)
- self.points[i] = np.array([float(comp) for comp in line.split()])
- except:
- print('Syntax error in \'{}\' line {}'.format(path.name, current_line))
+from schroedinger import Config, potential_interp
def build_potential(config: Config):
start, end, steps = config.interval
potential = np.zeros((steps, 2))
potential[:, 0] = np.linspace(start, end, steps)
delta = np.abs(potential[1, 0] - potential[0, 0])
-
- if config.interpolation == 'linear':
- potential[:, 1] = np.interp(potential[:, 0], config.points[:, 0],
- config.points[:, 1])
- elif config.interpolation == 'polynomial':
- p = Polynomial.fit(config.points[:, 0], config.points[:, 1],
- config.points.shape[0] - 1)
- potential[:, 1] = p(potential[:, 0])
- elif config.interpolation == 'cspline':
- cs = CubicSpline(config.points[:, 0], config.points[:, 1])
- potential[:, 1] = cs(potential[:, 0])
- else:
- raise ValueError()
-
+ interp = potential_interp(config.interpolation, config.points)
+ potential[:, 1] = interp(potential[:, 0])
return potential, delta
@@ -116,7 +64,7 @@ def main():
potential, delta = build_potential(conf)
np.savetxt('potential.dat', potential)
- e, f = solve_schroedinger(conf.mass, potential[:, 1], delta, conf.eig_interval)
+ e, v = solve_schroedinger(conf.mass, potential[:, 1], delta, conf.eig_interval)
np.savetxt('energies.dat', e)
save_wavefuncs('wavefuncs.dat', potential[:, 0], v)
diff --git a/schroedinger/schroedinger.py b/schroedinger/schroedinger.py
index 63ea5f8..10d994e 100644
--- a/schroedinger/schroedinger.py
+++ b/schroedinger/schroedinger.py
@@ -1,17 +1,59 @@
+from pathlib import Path
-def hello():
- """
- This is a description :)
- :return: None
- """
- print('Hello')
+import numpy as np
+from numpy.polynomial import Polynomial
+from numpy.typing import NDArray
+from scipy.interpolate import CubicSpline
+class Config:
+ def __init__(self, path):
+ current_line = 0
-def main():
- hello()
+ # Ensure Path object
+ if path is not Path:
+ path = Path(path)
+ def next_parameter(fd):
+ '''Read next parameter, ignoring comments or empty lines'''
+ content = None
+ nonlocal current_line
+ while not content:
+ str = fd.readline()
+ current_line += 1
+ index = str.find('#')
+ content = str[0:index].strip()
+ return content
+ with open(path, 'r') as fd:
+ try:
+ self.mass = float(next_parameter(fd))
+ start, end, steps = next_parameter(fd).split()
+ self.interval = [float(start), float(end), int(steps)]
+ self.eig_interval = [int(attr) for attr in next_parameter(fd).split()]
+ self.interpolation = next_parameter(fd)
-if __name__ == "__main__":
+ npoints = int(next_parameter(fd))
+ self.points = np.zeros((npoints, 2))
+ for i in range(npoints):
+ line = next_parameter(fd)
+ self.points[i] = np.array([float(comp) for comp in line.split()])
+ # TODO: don't be a moron, catch only relevant exceptions
+ except:
+ print('Syntax error in \'{}\' line {}'.format(path.name, current_line))
+ raise ValueError()
- main()
+
+def potential_interp(interpolation, points):
+ if interpolation == 'linear':
+ def line(x):
+ return np.interp(x, points[:, 0], points[:, 1])
+ return line
+ elif interpolation == 'polynomial':
+ poly = Polynomial.fit(points[:, 0], points[:, 1],
+ points.shape[0] - 1)
+ return poly
+ elif interpolation == 'cspline':
+ cs = CubicSpline(points[:, 0], points[:, 1])
+ return cs
+
+ raise ValueError()