aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Albers Raviola <thomas@thomaslabs.org>2024-07-19 15:40:58 +0200
committerThomas Albers Raviola <thomas@thomaslabs.org>2024-07-19 15:40:58 +0200
commit9ff2d4dd949f74b0052d21ecd83e398ab818b02e (patch)
tree5bc810243fe3da424d70b2aaca77cd89d830072d
parent27c2a2af862ff21ecbf6a7523d6a52bda42825fd (diff)
Add output-dir parameter to solver
-rw-r--r--schroedinger/schrodinger_solve.py32
1 files changed, 24 insertions, 8 deletions
diff --git a/schroedinger/schrodinger_solve.py b/schroedinger/schrodinger_solve.py
index 8f9f515..6885100 100644
--- a/schroedinger/schrodinger_solve.py
+++ b/schroedinger/schrodinger_solve.py
@@ -1,4 +1,6 @@
+import sys
import argparse
+from pathlib import Path
import numpy as np
@@ -7,6 +9,8 @@ from schroedinger import (
)
+DESCRIPTION='Solve time independent Schrödinger\'s equation for a given system.'
+
def save_wavefuncs(filename, x, v):
wavefuncs = np.zeros((x.shape[0], v.shape[1] + 1))
wavefuncs[:, 0] = x
@@ -30,21 +34,33 @@ def save_expvalues(filename, x, v):
def main():
parser = argparse.ArgumentParser(
prog='schrodinger_solve',
- description='a',
- epilog='a')
+ description=DESCRIPTION,
+ epilog='')
+
+ parser.add_argument('filename',
+ help='File describing the system to solve')
+ parser.add_argument('-o', '--output-dir',
+ help='Output directory for the results')
- parser.add_argument('filename')
args = parser.parse_args()
+
conf = Config(args.filename)
+ output_path = Path(args.output_dir) if args.output_dir else Path.cwd()
potential, delta = build_potential(conf)
- np.savetxt('potential.dat', potential)
- e, v = solve_schroedinger(conf.mass, potential[:, 1], delta, conf.eig_interval)
+ e, v = solve_schroedinger(conf.mass, potential[:, 1], delta,
+ conf.eig_interval)
- np.savetxt('energies.dat', e)
- save_wavefuncs('wavefuncs.dat', potential[:, 0], v)
- save_expvalues('expvalues.dat', potential[:, 0], v)
+ try:
+ np.savetxt(output_path / 'potential.dat', potential)
+ np.savetxt(output_path / 'energies.dat', e)
+ save_wavefuncs(output_path / 'wavefuncs.dat', potential[:, 0], v)
+ save_expvalues(output_path / 'expvalues.dat', potential[:, 0], v)
+ except FileNotFoundError:
+ print('Output files could not be saved.'
+ ' Are you sure the output directory exists?')
+ sys.exit(1)
if __name__ == '__main__':