aboutsummaryrefslogtreecommitdiff
path: root/schroedinger/schrodinger_solve.py
diff options
context:
space:
mode:
authorThomas Albers Raviola <thomas@thomaslabs.org>2024-07-12 14:38:44 +0200
committerThomas Albers Raviola <thomas@thomaslabs.org>2024-07-12 14:38:44 +0200
commit2723cb9cbb66666df451c5a93bcdbf2537eea0dd (patch)
tree0828bbe99920aca0f5ad9616ba253d7d472c9f26 /schroedinger/schrodinger_solve.py
parent665bff51d17329259a80e0aeae1b2af2f4caaa26 (diff)
Add config class
Diffstat (limited to 'schroedinger/schrodinger_solve.py')
-rw-r--r--schroedinger/schrodinger_solve.py52
1 files changed, 51 insertions, 1 deletions
diff --git a/schroedinger/schrodinger_solve.py b/schroedinger/schrodinger_solve.py
index 9369e2e..a759c3a 100644
--- a/schroedinger/schrodinger_solve.py
+++ b/schroedinger/schrodinger_solve.py
@@ -1,5 +1,55 @@
+import numpy as np
+from pathlib import Path
+import argparse
+
+class Config:
+ def __init__(self, path):
+ current_line = 0
+
+ # Ensure Path object
+ if path is not Path:
+ path = Path(path)
+
+ def next_parameter(fd):
+ '''Read next parameter, ignoring comments or empty lines'''
+ content = None
+ nonlocal current_line
+ while not content:
+ str = fd.readline()
+ current_line += 1
+ index = str.find('#')
+ content = str[0:index].strip()
+ return content
+
+ with open(path, 'r') as fd:
+ try:
+ self.mass = float(next_parameter(fd))
+ self.interval = [float(attr) for attr in next_parameter(fd).split()]
+ self.eig_interval = [int(attr) for attr in next_parameter(fd).split()]
+ self.interpolation = next_parameter(fd)
+
+ npoints = int(next_parameter(fd))
+ self.points = np.zeros((npoints, 2))
+ for i in range(npoints):
+ line = next_parameter(fd)
+ self.points[i] = np.array([float(comp) for comp in line.split()])
+ except:
+ print('Syntax error in \'{}\' line {}'.format(path.name, current_line))
+
+
def main():
- pass
+ parser = argparse.ArgumentParser(
+ prog='schrodinger_solve',
+ description='a',
+ epilog='a')
+ parser.add_argument('filename')
+ args = parser.parse_args()
+ conf = Config(args.filename)
+ print(conf.mass)
+ print(conf.interval)
+ print(conf.eig_interval)
+ print(conf.interpolation)
+ print(conf.points)
if __name__ == '__main__':
main()