Skip to content

Commit e06708a

Browse files
author
ronald.jaepel
committed
Add method to save Cadet sim as python file which can generate the Cadet sim again
Add test for the save_as_python method
1 parent 645ccce commit e06708a

File tree

2 files changed

+186
-17
lines changed

2 files changed

+186
-17
lines changed

cadet/cadet.py

+122-17
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from addict import Dict
2-
31
import warnings
2+
43
with warnings.catch_warnings():
5-
warnings.filterwarnings("ignore",category=FutureWarning)
4+
warnings.filterwarnings("ignore", category=FutureWarning)
65
import h5py
76
import numpy
87
import subprocess
@@ -14,9 +13,11 @@
1413
import contextlib
1514

1615
from pathlib import Path
16+
from addict import Dict
1717

1818
from cadet.cadet_dll import CadetDLL
1919

20+
2021
class H5():
2122
pp = pprint.PrettyPrinter(indent=4)
2223

@@ -77,6 +78,32 @@ def load_json(self, filename, update=False):
7778
else:
7879
self.root = data
7980

81+
def save_as_python_script(self, filename: str, only_return_pythonic_representation=False):
    """
    Write (or return) a Python script that rebuilds this simulation and saves it as an .h5 file.

    The script imports numpy and Cadet, creates ``sim = Cadet()``, assigns every leaf
    of ``self.root`` to ``root...`` and finally sets ``sim.filename`` to ``filename``
    with its ".py" suffix replaced by ".h5" and calls ``sim.save()``.

    :param filename: path of the script; must end with ".py"
    :param only_return_pythonic_representation: if True, nothing is written and the
        list of code lines is returned instead
    :return: the list of code lines if only_return_pythonic_representation is True, else None
    :raises Warning: if filename does not end with ".py"
    """
    if not filename.endswith(".py"):
        raise Warning("The filename given to .save_as_python_script isn't a python file name.")

    code_lines_list = [
        "import numpy",
        "from cadet import Cadet",
        "",
        "sim = Cadet()",
        "root = sim.root",
    ]

    code_lines_list = recursively_turn_dict_to_python_list(dictionary=self.root,
                                                           current_lines_list=code_lines_list,
                                                           prefix="root")

    # Swap only the trailing suffix: str.replace would also rewrite a ".py" that
    # appears earlier in the path (e.g. inside a directory name).
    filename_for_reproduced_h5_file = filename[:-len(".py")] + ".h5"
    # repr() quotes and escapes the path, so backslashes (Windows paths) or quote
    # characters can't break or alter the generated assignment.
    code_lines_list.append(f"sim.filename = {filename_for_reproduced_h5_file!r}")
    code_lines_list.append("sim.save()")

    if only_return_pythonic_representation:
        return code_lines_list

    with open(filename, "w") as handle:
        handle.writelines(line + "\n" for line in code_lines_list)
106+
80107
def append(self, lock=False):
81108
"This can only be used to write new keys to the system, this is faster than having to read the data before writing it"
82109
if self.filename is not None:
@@ -117,10 +144,12 @@ def __setitem__(self, key, value):
117144
obj = obj[i]
118145
obj[parts[-1]] = value
119146

147+
120148
def is_dll(value):
    """
    Tell whether a path points to a shared library rather than an executable.

    The decision is made on the file suffix only; the comparison ignores case so
    that e.g. "cadet.DLL" is recognised as well.

    :param value: path (str or path-like) to inspect
    :return: True if the suffix is ".so" or ".dll", else False
    """
    suffix = Path(value).suffix.lower()
    return suffix in {'.so', '.dll'}
123151

152+
124153
class CadetMeta(type):
125154
_cadet_runner_class = None
126155
_is_file_class = None
@@ -144,7 +173,7 @@ def __init__(cls):
144173
del cls._cadet_runner_class
145174

146175
if is_dll(value):
147-
cls._cadet_runner_class = CadetDLL(value)
176+
cls._cadet_runner_class = CadetDLL(value)
148177
cls._is_file_class = False
149178
else:
150179
cls._cadet_runner_class = CadetFile(value)
@@ -154,8 +183,9 @@ def __init__(cls):
154183
def cadet_path(cls):
155184
del cls._cadet_runner_class
156185

186+
157187
class Cadet(H5, metaclass=CadetMeta):
158-
#cadet_path must be set in order for simulations to run
188+
# cadet_path must be set in order for simulations to run
159189
def __init__(self, *data):
160190
super().__init__(*data)
161191
self._cadet_runner = None
@@ -188,7 +218,7 @@ def cadet_path(self, value):
188218
del self._cadet_runner
189219

190220
if is_dll(value):
191-
self._cadet_runner = CadetDLL(value)
221+
self._cadet_runner = CadetDLL(value)
192222
self._is_file = False
193223
else:
194224
self._cadet_runner = CadetFile(value)
@@ -209,14 +239,14 @@ def load_results(self):
209239
if runner is not None:
210240
runner.load_results(self)
211241

212-
def run(self, timeout = None, check=None):
242+
def run(self, timeout=None, check=None):
213243
data = self.cadet_runner.run(simulation=self.root.input, filename=self.filename, timeout=timeout, check=check)
214-
#self.return_information = data
244+
# self.return_information = data
215245
return data
216246

217-
def run_load(self, timeout = None, check=None, clear=True):
247+
def run_load(self, timeout=None, check=None, clear=True):
218248
data = self.cadet_runner.run(simulation=self.root.input, filename=self.filename, timeout=timeout, check=check)
219-
#self.return_information = data
249+
# self.return_information = data
220250
self.load_results()
221251
if clear:
222252
self.clear()
@@ -227,14 +257,15 @@ def clear(self):
227257
if runner is not None:
228258
runner.clear()
229259

260+
230261
class CadetFile:
231262

232263
def __init__(self, cadet_path):
233264
self.cadet_path = cadet_path
234265

235-
def run(self, filename = None, simulation=None, timeout = None, check=None):
266+
def run(self, filename=None, simulation=None, timeout=None, check=None):
236267
if filename is not None:
237-
data = subprocess.run([self.cadet_path, filename], timeout = timeout, check=check, capture_output=True)
268+
data = subprocess.run([self.cadet_path, filename], timeout=timeout, check=check, capture_output=True)
238269
return data
239270
else:
240271
print("Filename must be set before run can be used")
@@ -245,9 +276,10 @@ def clear(self):
245276
def load_results(self, sim):
246277
sim.load(paths=["/meta", "/output"], update=True)
247278

279+
248280
def convert_from_numpy(data, func):
249281
ans = Dict()
250-
for key_original,item in data.items():
282+
for key_original, item in data.items():
251283
key = func(key_original)
252284
if isinstance(item, numpy.ndarray):
253285
item = item.tolist()
@@ -264,16 +296,18 @@ def convert_from_numpy(data, func):
264296
ans[key] = item
265297
return ans
266298

267-
def recursively_load_dict(data, func):
    """
    Copy a nested dictionary into an addict Dict, passing every key through func.

    :param data: (nested) dictionary to copy
    :param func: callable applied to each key at every nesting level
    :return: Dict with converted keys; nested dictionaries are converted recursively
    """
    converted = Dict()
    for raw_key, entry in data.items():
        new_key = func(raw_key)
        # only real dictionaries are descended into, every other value is stored unchanged
        converted[new_key] = recursively_load_dict(entry, func) if isinstance(entry, dict) else entry
    return converted
276309

310+
277311
def set_path(obj, path, value):
278312
"paths need to be broken up so that subobjects are correctly made"
279313
path = path.split('/')
@@ -285,7 +319,8 @@ def set_path(obj, path, value):
285319

286320
temp[path[-1]] = value
287321

288-
def recursively_load( h5file, path, func, paths):
322+
323+
def recursively_load(h5file, path, func, paths):
289324
ans = Dict()
290325
if paths is not None:
291326
for path in paths:
@@ -306,8 +341,8 @@ def recursively_load( h5file, path, func, paths):
306341
ans[key] = recursively_load(h5file, local_path + '/', func, None)
307342
return ans
308343

309-
def recursively_save(h5file, path, dic, func):
310344

345+
def recursively_save(h5file, path, dic, func):
311346
if not isinstance(path, str):
312347
raise ValueError("path must be a string")
313348
if not isinstance(h5file, h5py._hl.files.File):
@@ -347,3 +382,73 @@ def recursively_save(h5file, path, dic, func):
347382
raise KeyError(f'Name conflict with upper and lower case entries for key "{path}{key}".')
348383
else:
349384
raise
385+
386+
387+
def recursively_turn_dict_to_python_list(dictionary: dict, current_lines_list: list = None, prefix: str = None):
    """
    Recursively turn a nested dictionary or addict.Dict into a list of Python code that
    can generate the nested dictionary.

    Keys are visited in sorted order. Every non-dictionary value produces one line
    ``<prefix>.<key> = <value representation>``. Nested dictionaries contribute only
    the lines of their own leaves, so an empty dictionary produces no line at all.

    :param dictionary: (nested) dictionary to convert
    :param current_lines_list: list the generated lines are appended to in place;
        a new list is created if None
    :param prefix: expression the keys are attached to (e.g. "root"); if None the
        top-level keys are used without a prefix
    :return: list of Python code lines (the same object as current_lines_list if given)
    """
    if current_lines_list is None:
        current_lines_list = []

    for key in sorted(dictionary):
        value = dictionary[key]

        # Combine key and prefix to "prefix.key", or just the key when there is no prefix
        absolute_key = key if prefix is None else f"{prefix}.{key}"

        # isinstance covers dict, addict.Dict and any other dict subclass
        if isinstance(value, dict):
            recursively_turn_dict_to_python_list(value, current_lines_list, prefix=absolute_key)
        else:
            value_representation = get_pythonic_representation_of_value(value)
            current_lines_list.append(f"{clean_up_key(absolute_key)} = {value_representation}")

    return current_lines_list
425+
426+
427+
def clean_up_key(absolute_key: str):
    """
    Rewrite path segments that can't be used with attribute access to item access.

    The key is split at ".". The first segment (the prefix) is kept as it is. Every
    following segment that is a Python keyword (such as "return") or not a valid
    identifier is written as ``['segment']`` instead of ``.segment``. Segments that
    merely start with a keyword (e.g. "return_value") are left untouched.

    :param absolute_key: dotted key such as "root.input.return.split_foobar"
    :return: key that is valid on the left-hand side of an assignment,
        e.g. "root.input['return'].split_foobar"
    """
    import keyword  # function-scope import keeps this helper self-contained

    first_segment, *other_segments = absolute_key.split(".")
    pieces = [first_segment]
    for segment in other_segments:
        if keyword.iskeyword(segment) or not segment.isidentifier():
            pieces.append(f"[{segment!r}]")
        else:
            pieces.append(f".{segment}")
    return "".join(pieces)
436+
437+
438+
def get_pythonic_representation_of_value(value):
    """
    Return a Python expression that recreates ``value`` in a script that has run ``import numpy``.

    * numpy arrays become ``numpy.array(<nested list>, dtype=<dtype string>)``. The
      nested list comes from ``tolist()``, so large arrays are never abbreviated
      with "..." and floats keep their full precision.
    * numpy scalars become ``numpy.<type name>(<python value>)``.
    * non-finite Python floats become ``float('nan')`` / ``float('inf')`` / ``float('-inf')``.
    * every other value uses ``repr()`` unchanged.

    :param value: leaf value to represent
    :return: string with the Python expression
    :raises ValueError: if the expression can't be evaluated with only numpy in scope
    """
    import math
    import re

    def spell_out_non_finite(literal):
        # repr() prints non-finite floats as the bare words nan/inf, which are not Python names
        return re.sub(r"\b(nan|inf)\b", r"numpy.\1", literal)

    if isinstance(value, numpy.ndarray):
        literal = repr(value.tolist())
        if value.dtype.kind == "f":
            literal = spell_out_non_finite(literal)
        value_representation = f"numpy.array({literal}, dtype={value.dtype.str!r})"
        if value.size == 0 and value.ndim > 1:
            # tolist() of an empty multi-dimensional array can't carry the full shape
            value_representation += f".reshape({value.shape!r})"
    elif isinstance(value, numpy.generic):
        # checked before float because numpy.float64 is a subclass of float
        literal = repr(value.item())
        if value.dtype.kind == "f":
            literal = spell_out_non_finite(literal)
        value_representation = f"numpy.{type(value).__name__}({literal})"
    elif isinstance(value, float) and not math.isfinite(value):
        value_representation = f"float({str(value)!r})"
    else:
        value_representation = repr(value)

    try:
        # Evaluate with exactly the names the generated script provides.
        # NOTE: eval() executes text derived from repr(); only use this on trusted simulation data.
        eval(value_representation, {"numpy": numpy})
    except Exception as e:
        raise ValueError(
            f"Encountered a value of '{value_representation}' that can't be directly reproduced in python.\n"
            f"Please report this to the CADET-Python developers.") from e

    return value_representation

tests/test_save_as_python.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import tempfile
2+
3+
import numpy as np
4+
import pytest
5+
from addict import Dict
6+
7+
from cadet import Cadet
8+
9+
10+
@pytest.fixture
def temp_cadet_file():
    """
    Create a new Cadet object for use in tests.

    The object's filename is the path of a named temporary file that exists
    for as long as the test using the fixture runs.
    """
    model = Cadet()

    with tempfile.NamedTemporaryFile() as temp:
        # store the path, not the open file object returned by NamedTemporaryFile
        model.filename = temp.name
        yield model
20+
21+
22+
def test_save_as_python(temp_cadet_file):
    """
    Test that the code lines generated by save_as_python_script rebuild an equal simulation.
    """
    # Populate temp_cadet_file with all tricky cases currently known
    temp_cadet_file.root.input.foo = 1
    temp_cadet_file.root.input.bar.baryon = np.arange(10)
    temp_cadet_file.root.input.bar.barometer = np.linspace(0, 10, 9)
    temp_cadet_file.root.input.bar.init_q = np.array([], dtype=np.float64)
    temp_cadet_file.root.input["return"].split_foobar = 1

    code_lines = temp_cadet_file.save_as_python_script(filename="temp.py", only_return_pythonic_representation=True)

    # remove code lines that save the file
    code_lines = code_lines[:-2]

    # Run the generated code in an explicit namespace. Relying on the implicit
    # function locals() does not carry names such as "root" from one exec() call
    # to the next on Python 3.13+.
    namespace = {}
    exec("\n".join(code_lines), namespace)
    sim = namespace["sim"]

    # test that "sim" is equal to "temp_cadet_file"
    recursive_equality_check(sim.root, temp_cadet_file.root)
47+
48+
49+
def recursive_equality_check(dict_a: dict, dict_b: dict):
    """
    Assert that two nested dictionaries hold the same keys and equal values.

    Nested dictionaries (including dict subclasses) are compared recursively and
    numpy arrays are compared element-wise; everything else is compared with ``==``.

    :param dict_a: first (nested) dictionary
    :param dict_b: second (nested) dictionary
    :return: True if no assertion failed
    :raises AssertionError: if keys or values differ
    """
    assert dict_a.keys() == dict_b.keys()
    for key, value_a in dict_a.items():
        value_b = dict_b[key]
        if isinstance(value_a, dict):
            assert isinstance(value_b, dict)
            recursive_equality_check(value_a, value_b)
        elif isinstance(value_a, np.ndarray) or isinstance(value_b, np.ndarray):
            # "==" on an array yields an array whose truth value is ambiguous
            np.testing.assert_array_equal(value_a, value_b)
        else:
            assert value_a == value_b
    return True
61+
62+
63+
# Invoke pytest when this module is executed directly as a script.
if __name__ == "__main__":
    pytest.main()

0 commit comments

Comments
 (0)