diff options
author | Thomas Albers Raviola <thomas@thomaslabs.org> | 2024-07-19 15:40:58 +0200 |
---|---|---|
committer | Thomas Albers Raviola <thomas@thomaslabs.org> | 2024-07-19 15:40:58 +0200 |
commit | 9ff2d4dd949f74b0052d21ecd83e398ab818b02e (patch) | |
tree | 5bc810243fe3da424d70b2aaca77cd89d830072d | |
parent | 27c2a2af862ff21ecbf6a7523d6a52bda42825fd (diff) |
Add output-dir parameter to solver
-rw-r--r-- | schroedinger/schrodinger_solve.py | 32 |
1 files changed, 24 insertions, 8 deletions
diff --git a/schroedinger/schrodinger_solve.py b/schroedinger/schrodinger_solve.py index 8f9f515..6885100 100644 --- a/schroedinger/schrodinger_solve.py +++ b/schroedinger/schrodinger_solve.py @@ -1,4 +1,6 @@ +import sys import argparse +from pathlib import Path import numpy as np @@ -7,6 +9,8 @@ from schroedinger import ( ) +DESCRIPTION='Solve time independent Schrödinger\'s equation for a given system.' + def save_wavefuncs(filename, x, v): wavefuncs = np.zeros((x.shape[0], v.shape[1] + 1)) wavefuncs[:, 0] = x @@ -30,21 +34,33 @@ def save_expvalues(filename, x, v): def main(): parser = argparse.ArgumentParser( prog='schrodinger_solve', - description='a', - epilog='a') + description=DESCRIPTION, + epilog='') + + parser.add_argument('filename', + help='File describing the system to solve') + parser.add_argument('-o', '--output-dir', + help='Output directory for the results') - parser.add_argument('filename') args = parser.parse_args() + conf = Config(args.filename) + output_path = Path(args.output_dir) if args.output_dir else Path.cwd() potential, delta = build_potential(conf) - np.savetxt('potential.dat', potential) - e, v = solve_schroedinger(conf.mass, potential[:, 1], delta, conf.eig_interval) + e, v = solve_schroedinger(conf.mass, potential[:, 1], delta, + conf.eig_interval) - np.savetxt('energies.dat', e) - save_wavefuncs('wavefuncs.dat', potential[:, 0], v) - save_expvalues('expvalues.dat', potential[:, 0], v) + try: + np.savetxt(output_path / 'potential.dat', potential) + np.savetxt(output_path / 'energies.dat', e) + save_wavefuncs(output_path / 'wavefuncs.dat', potential[:, 0], v) + save_expvalues(output_path / 'expvalues.dat', potential[:, 0], v) + except FileNotFoundError: + print('Output files could not be saved.' + ' Are you sure the output directory exists?') + sys.exit(1) if __name__ == '__main__': |