diff options
-rw-r--r-- | schroedinger/schroedinger.py | 71 |
1 files changed, 39 insertions, 32 deletions
diff --git a/schroedinger/schroedinger.py b/schroedinger/schroedinger.py index 3554c6b..0c26041 100644 --- a/schroedinger/schroedinger.py +++ b/schroedinger/schroedinger.py @@ -1,3 +1,9 @@ +'''schroedinger.py + +Collection of method common to solve and tests +''' + +import sys from pathlib import Path from typing import TextIO, Type @@ -24,34 +30,33 @@ class Config: if path is not Path: path = Path(path) - def next_parameter(fd: TextIO): + def next_parameter(file: TextIO): '''Read next parameter, ignoring comments or empty lines''' content = None nonlocal current_line while not content: - str = fd.readline() + line = file.readline() current_line += 1 - index = str.find('#') - content = str[0:index].strip() + index = line.find('#') + content = line[0:index].strip() return content - with open(path, 'r') as fd: + with open(path, 'r', encoding='utf8') as file: try: - self.mass = float(next_parameter(fd)) - start, end, steps = next_parameter(fd).split() + self.mass = float(next_parameter(file)) + start, end, steps = next_parameter(file).split() self.interval = [float(start), float(end), int(steps)] - self.eig_interval = [int(attr) - 1 for attr in next_parameter(fd).split()] - self.interpolation = next_parameter(fd) + self.eig_interval = [int(attr) - 1 for attr in next_parameter(file).split()] + self.interpolation = next_parameter(file) - npoints = int(next_parameter(fd)) + npoints = int(next_parameter(file)) self.points = np.zeros((npoints, 2)) for i in range(npoints): - line = next_parameter(fd) + line = next_parameter(file) self.points[i] = np.array([float(comp) for comp in line.split()]) - # TODO: don't be a moron, catch only relevant exceptions - except: - print('Syntax error in \'{}\' line {}'.format(path.name, current_line)) - raise ValueError() + except ValueError: + print(f'Syntax error in \'{path.name}\' line {current_line}') + sys.exit(1) def potential_interp( @@ -64,19 +69,21 @@ def potential_interp( :param points: Points to interpolate within :return: Interpolating object ''' + interpolator = None + if interpolation == 'linear': - def line(x): - return np.interp(x, points[:, 0], points[:, 1]) - return line + def line(pos): + return np.interp(pos, points[:, 0], points[:, 1]) + interpolator = line elif interpolation == 'polynomial': - poly = Polynomial.fit(points[:, 0], points[:, 1], - points.shape[0] - 1) - return poly + interpolator = Polynomial.fit(points[:, 0], points[:, 1], + points.shape[0] - 1) elif interpolation == 'cspline': - cs = CubicSpline(points[:, 0], points[:, 1]) - return cs + interpolator = CubicSpline(points[:, 0], points[:, 1]) + else: + raise ValueError('Invalid interpolator kind. Use any of: linear, polynomial or cspline') - raise ValueError() + return interpolator def build_potential( @@ -111,13 +118,13 @@ def solve_schroedinger( :param eig_interval: Interval of quantum numbers for which the states are to be calculated :return: Eigenvalues and normalized wave functions ''' - n = potential.shape[0] - a = 1 / mass / delta**2 - w, v = eigh_tridiagonal(a + potential, - -a * np.ones(n - 1, dtype=np.float64) / 2.0, - select='i', select_range=eig_interval) + qnumber = potential.shape[0] + coeff = 1 / mass / delta**2 + eig, vec = eigh_tridiagonal(coeff + potential, + -coeff * np.ones(qnumber - 1, dtype=np.float64) / 2.0, + select='i', select_range=eig_interval) # Normalize eigenfunctions - for i in range(w.shape[0]): - v[:, i] /= np.sqrt(delta * np.sum(np.abs(v[:, i])**2)) + for i in range(eig.shape[0]): + vec[:, i] /= np.sqrt(delta * np.sum(np.abs(vec[:, i])**2)) - return w, v + return eig, vec |