Skip to content

Commit efefceb

Browse files
committed
Add .save_as_python_script method and test
1 parent 645ccce commit efefceb

File tree

2 files changed

+158
-0
lines changed

2 files changed

+158
-0
lines changed

cadet/cadet.py

+94
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,32 @@ def load_json(self, filename, update=False):
7777
else:
7878
self.root = data
7979

80+
def save_as_python_script(self, filename: str, only_return_pythonic_representation=False):
81+
if not filename.endswith(".py"):
82+
raise Warning(f"The filename given to .save_as_python_script isn't a python file name.")
83+
84+
code_lines_list = [
85+
"import numpy",
86+
"from cadet import Cadet",
87+
"",
88+
"sim = Cadet()",
89+
"root = sim.root",
90+
]
91+
92+
code_lines_list = recursively_turn_dict_to_python_list(dictionary=self.root,
93+
current_lines_list=code_lines_list,
94+
prefix="root")
95+
96+
filename_for_reproduced_h5_file = filename.replace(".py", ".h5")
97+
code_lines_list.append(f"sim.filename = '{filename_for_reproduced_h5_file}'")
98+
code_lines_list.append("sim.save()")
99+
100+
if not only_return_pythonic_representation:
101+
with open(filename, "w") as handle:
102+
handle.writelines([line + "\n" for line in code_lines_list])
103+
else:
104+
return code_lines_list
105+
80106
def append(self, lock=False):
81107
"This can only be used to write new keys to the system, this is faster than having to read the data before writing it"
82108
if self.filename is not None:
@@ -347,3 +373,71 @@ def recursively_save(h5file, path, dic, func):
347373
raise KeyError(f'Name conflict with upper and lower case entries for key "{path}{key}".')
348374
else:
349375
raise
376+
377+
378+
def recursively_turn_dict_to_python_list(dictionary: dict, current_lines_list: list = None, prefix: str = None):
379+
"""
380+
Recursively turn a nested dictionary or addict.Dict into a list of Python code that
381+
can generate the nested dictionary.
382+
383+
:param dictionary:
384+
:param current_lines_list:
385+
:param prefix_list:
386+
:return: list of Python code lines
387+
"""
388+
389+
def merge_to_absolute_key(prefix, key):
390+
"""
391+
Combine key and prefix to "prefix.key" except if there is no prefix, then return key
392+
"""
393+
if prefix is None:
394+
return key
395+
else:
396+
return f"{prefix}.{key}"
397+
398+
def clean_up_key(absolute_key: str):
399+
"""
400+
Remove problematic phrases from key, such as blank "return"
401+
402+
:param absolute_key:
403+
:return:
404+
"""
405+
absolute_key = absolute_key.replace(".return", "['return']")
406+
return absolute_key
407+
408+
def get_pythonic_representation_of_value(value):
409+
"""
410+
Use repr() to get a pythonic representation of the value
411+
and add "np." to "array" and "float64"
412+
413+
"""
414+
value_representation = repr(value)
415+
value_representation = value_representation.replace("array", "numpy.array")
416+
value_representation = value_representation.replace("float64", "numpy.float64")
417+
try:
418+
eval(value_representation)
419+
except NameError as e:
420+
raise ValueError(
421+
f"Encountered a value of '{value_representation}' that can't be directly reproduced in python.\n"
422+
f"Please report this to the CADET-Python developers.") from e
423+
424+
return value_representation
425+
426+
if current_lines_list is None:
427+
current_lines_list = []
428+
429+
for key in sorted(dictionary.keys()):
430+
value = dictionary[key]
431+
432+
absolute_key = merge_to_absolute_key(prefix, key)
433+
434+
if type(value) in (dict, Dict):
435+
current_lines_list = recursively_turn_dict_to_python_list(value, current_lines_list, prefix=absolute_key)
436+
else:
437+
value_representation = get_pythonic_representation_of_value(value)
438+
439+
absolute_key = clean_up_key(absolute_key)
440+
441+
current_lines_list.append(f"{absolute_key} = {value_representation}")
442+
443+
return current_lines_list

tests/test_save_as_python.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import tempfile
2+
3+
import numpy as np
4+
import pytest
5+
from addict import Dict
6+
7+
from cadet import Cadet
8+
9+
10+
@pytest.fixture
11+
def temp_cadet_file():
12+
"""
13+
Create a new Cadet object for use in tests.
14+
"""
15+
model = Cadet()
16+
17+
with tempfile.NamedTemporaryFile() as temp:
18+
model.filename = temp
19+
yield model
20+
21+
22+
def test_save_as_python(temp_cadet_file):
23+
"""
24+
Test that the Cadet class raises a KeyError exception when duplicate keys are set on it.
25+
"""
26+
# initialize "sim" variable to be overwritten by the exec lines later
27+
sim = Cadet()
28+
29+
# Populate temp_cadet_file with all tricky cases currently known
30+
temp_cadet_file.root.input.foo = 1
31+
temp_cadet_file.root.input.bar.baryon = np.arange(10)
32+
temp_cadet_file.root.input.bar.barometer = np.linspace(0, 10, 9)
33+
temp_cadet_file.root.input.bar.init_q = np.array([], dtype=np.float64)
34+
temp_cadet_file.root.input["return"].split_foobar = 1
35+
36+
code_lines = temp_cadet_file.save_as_python_script(filename="temp.py", only_return_pythonic_representation=True)
37+
38+
# remove code lines that save the file
39+
code_lines = code_lines[:-2]
40+
41+
# populate "sim" variable using the generated code lines
42+
for line in code_lines:
43+
exec(line)
44+
45+
# test that "sim" is equal to "temp_cadet_file"
46+
recursive_equality_check(sim.root, temp_cadet_file.root)
47+
48+
49+
def recursive_equality_check(dict_a: dict, dict_b: dict):
50+
assert dict_a.keys() == dict_b.keys()
51+
for key in dict_a.keys():
52+
value_a = dict_a[key]
53+
value_b = dict_b[key]
54+
if type(value_a) in (dict, Dict):
55+
recursive_equality_check(value_a, value_b)
56+
elif type(value_a) == np.ndarray:
57+
np.testing.assert_array_equal(value_a, value_b)
58+
else:
59+
assert value_a == value_b
60+
return True
61+
62+
63+
if __name__ == "__main__":
64+
pytest.main()

0 commit comments

Comments
 (0)