diff --git a/ORBIT/config.py b/ORBIT/config.py index d97926d9..579cacd1 100644 --- a/ORBIT/config.py +++ b/ORBIT/config.py @@ -9,7 +9,9 @@ from pathlib import Path import yaml +import numpy as np from yaml import Dumper +from benedict.dicts import benedict from ORBIT.core import loader @@ -30,15 +32,47 @@ def load_config(filepath): return data -def save_config(config, filepath, overwrite=False): +def prepare_config_for_save(config: dict | benedict) -> dict: + """Prepare the configuration file for compatbility with the YAML + ``SafeDump`` class used for saving configurations to file. + + Parameters + ---------- + config : dict | benedict + ORBIT configuration dictionary. + + Returns + ------- + dict + ORBIT configuration dictionary where all NumPy data types are converted + to standard Python data types, e.g. ``np.float64`` -> ``float``. """ - Save an ORBIT `config` to `filepath`. + for k, v in config.items(): + match v: + case np.ndarray(): + config[k] = v.tolist() + case np.floating(): + config[k] = float(v) + case np.integer(): + config[k] = int(v) + case dict(): + config[k] = prepare_config_for_save(v) + case _: + pass + return config + + +def save_config( + config: dict | benedict, filepath: str | Path, overwrite: bool = False +): + """ + Save an ORBIT configuration to :py:attr:`filepath`. Parameters ---------- - config : dict + config : dict | benedict ORBIT configuration. - filepath : str + filepath : str | Path Location to save config. overwrite : bool (optional) Overwrite file if it already exists. Default: False. @@ -53,5 +87,9 @@ def save_config(config, filepath, overwrite=False): if filepath.exists(): raise FileExistsError(f"File already exists at '{filepath}'.") + config = prepare_config_for_save(config) + if isinstance(config, benedict): + config = config.dict() + with filepath.open("w") as f: yaml.dump(config, f, Dumper=Dumper, default_flow_style=False) diff --git a/pyproject.toml b/pyproject.toml index 3a601972..c8fb3108 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ dev = [ "pre-commit", "black", "isort", - "pytest", + "pytest>=9", "pytest-cov", "sphinx", "sphinx-rtd-theme", diff --git a/tests/test_config_management.py b/tests/test_config_management.py index 13a59cef..474ea81c 100644 --- a/tests/test_config_management.py +++ b/tests/test_config_management.py @@ -5,17 +5,26 @@ from ORBIT import ProjectManager, load_config, save_config +from ORBIT.config import prepare_config_for_save from ORBIT.core.library import extract_library_specs -complete_project = extract_library_specs("config", "complete_project") +def test_save_and_load_equality(subtests, tmp_yaml_del): -def test_save_and_load_equality(tmp_yaml_del): - + complete_project = extract_library_specs("config", "complete_project") save_config(complete_project, "tmp.yaml", overwrite=True) new = load_config("tmp.yaml") - assert new == complete_project + with subtests.test("Check direct file equality"): + assert new == complete_project + + with subtests.test("Check ProjectManager equality"): + new_project = ProjectManager(new) + new = prepare_config_for_save(new_project.config) + + expected_project = ProjectManager(complete_project) + complete_project = prepare_config_for_save(expected_project.config) + assert new == complete_project def test_orbit_version_ProjectManager():