aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--schroedinger/schroedinger.py71
1 files changed, 39 insertions, 32 deletions
diff --git a/schroedinger/schroedinger.py b/schroedinger/schroedinger.py
index 3554c6b..0c26041 100644
--- a/schroedinger/schroedinger.py
+++ b/schroedinger/schroedinger.py
@@ -1,3 +1,9 @@
+'''schroedinger.py
+
+Collection of method common to solve and tests
+'''
+
+import sys
from pathlib import Path
from typing import TextIO, Type
@@ -24,34 +30,33 @@ class Config:
if path is not Path:
path = Path(path)
- def next_parameter(fd: TextIO):
+ def next_parameter(file: TextIO):
'''Read next parameter, ignoring comments or empty lines'''
content = None
nonlocal current_line
while not content:
- str = fd.readline()
+ line = file.readline()
current_line += 1
- index = str.find('#')
- content = str[0:index].strip()
+ index = line.find('#')
+ content = line[0:index].strip()
return content
- with open(path, 'r') as fd:
+ with open(path, 'r', encoding='utf8') as file:
try:
- self.mass = float(next_parameter(fd))
- start, end, steps = next_parameter(fd).split()
+ self.mass = float(next_parameter(file))
+ start, end, steps = next_parameter(file).split()
self.interval = [float(start), float(end), int(steps)]
- self.eig_interval = [int(attr) - 1 for attr in next_parameter(fd).split()]
- self.interpolation = next_parameter(fd)
+ self.eig_interval = [int(attr) - 1 for attr in next_parameter(file).split()]
+ self.interpolation = next_parameter(file)
- npoints = int(next_parameter(fd))
+ npoints = int(next_parameter(file))
self.points = np.zeros((npoints, 2))
for i in range(npoints):
- line = next_parameter(fd)
+ line = next_parameter(file)
self.points[i] = np.array([float(comp) for comp in line.split()])
- # TODO: don't be a moron, catch only relevant exceptions
- except:
- print('Syntax error in \'{}\' line {}'.format(path.name, current_line))
- raise ValueError()
+ except ValueError:
+ print(f'Syntax error in \'{path.name}\' line {current_line}')
+ sys.exit(1)
def potential_interp(
@@ -64,19 +69,21 @@ def potential_interp(
:param points: Points to interpolate within
:return: Interpolating object
'''
+ interpolator = None
+
if interpolation == 'linear':
- def line(x):
- return np.interp(x, points[:, 0], points[:, 1])
- return line
+ def line(pos):
+ return np.interp(pos, points[:, 0], points[:, 1])
+ interpolator = line
elif interpolation == 'polynomial':
- poly = Polynomial.fit(points[:, 0], points[:, 1],
- points.shape[0] - 1)
- return poly
+ interpolator = Polynomial.fit(points[:, 0], points[:, 1],
+ points.shape[0] - 1)
elif interpolation == 'cspline':
- cs = CubicSpline(points[:, 0], points[:, 1])
- return cs
+ interpolator = CubicSpline(points[:, 0], points[:, 1])
+ else:
+ raise ValueError('Invalid interpolator kind. Use any of: linear, polynomial or cspline')
- raise ValueError()
+ return interpolator
def build_potential(
@@ -111,13 +118,13 @@ def solve_schroedinger(
:param eig_interval: Interval of quantum numbers for which the states are to be calculated
:return: Eigenvalues and normalized wave functions
'''
- n = potential.shape[0]
- a = 1 / mass / delta**2
- w, v = eigh_tridiagonal(a + potential,
- -a * np.ones(n - 1, dtype=np.float64) / 2.0,
- select='i', select_range=eig_interval)
+ qnumber = potential.shape[0]
+ coeff = 1 / mass / delta**2
+ eig, vec = eigh_tridiagonal(coeff + potential,
+ -coeff * np.ones(qnumber - 1, dtype=np.float64) / 2.0,
+ select='i', select_range=eig_interval)
# Normalize eigenfunctions
- for i in range(w.shape[0]):
- v[:, i] /= np.sqrt(delta * np.sum(np.abs(v[:, i])**2))
+ for i in range(eig.shape[0]):
+ vec[:, i] /= np.sqrt(delta * np.sum(np.abs(vec[:, i])**2))
- return w, v
+ return eig, vec