Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions ORBIT/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from pathlib import Path

import yaml
import numpy as np
from yaml import Dumper
from benedict.dicts import benedict

from ORBIT.core import loader

Expand All @@ -30,15 +32,47 @@ def load_config(filepath):
return data


def save_config(config, filepath, overwrite=False):
def prepare_config_for_save(config: dict | benedict) -> dict:
"""Prepare the configuration file for compatbility with the YAML
``SafeDump`` class used for saving configurations to file.

Parameters
----------
config : dict | benedict
ORBIT configuration dictionary.

Returns
-------
dict
ORBIT configuration dictionary where all NumPy data types are converted
to standard Python data types, e.g. ``np.float64`` -> ``float``.
"""
Save an ORBIT `config` to `filepath`.
for k, v in config.items():
match v:
case np.ndarray():
config[k] = v.tolist()
case np.floating():
config[k] = float(v)
case np.integer():
config[k] = int(v)
case dict():
config[k] = prepare_config_for_save(v)
case _:
pass
return config


def save_config(
config: dict | benedict, filepath: str | Path, overwrite: bool = False
):
"""
Save an ORBIT configuration to :py:attr:`filepath`.

Parameters
----------
config : dict
config : dict | benedict
ORBIT configuration.
filepath : str
filepath : str | Path
Location to save config.
overwrite : bool (optional)
Overwrite file if it already exists. Default: False.
Expand All @@ -53,5 +87,9 @@ def save_config(config, filepath, overwrite=False):
if filepath.exists():
raise FileExistsError(f"File already exists at '{filepath}'.")

config = prepare_config_for_save(config)
if isinstance(config, benedict):
config = config.dict()

with filepath.open("w") as f:
yaml.dump(config, f, Dumper=Dumper, default_flow_style=False)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ dev = [
"pre-commit",
"black",
"isort",
"pytest",
"pytest>=9",
"pytest-cov",
"sphinx",
"sphinx-rtd-theme",
Expand Down
17 changes: 13 additions & 4 deletions tests/test_config_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,26 @@


from ORBIT import ProjectManager, load_config, save_config
from ORBIT.config import prepare_config_for_save
from ORBIT.core.library import extract_library_specs

complete_project = extract_library_specs("config", "complete_project")

def test_save_and_load_equality(subtests, tmp_yaml_del):

def test_save_and_load_equality(tmp_yaml_del):

complete_project = extract_library_specs("config", "complete_project")
save_config(complete_project, "tmp.yaml", overwrite=True)
new = load_config("tmp.yaml")

assert new == complete_project
with subtests.test("Check direct file equality"):
assert new == complete_project

with subtests.test("Check ProjectManager equality"):
new_project = ProjectManager(new)
new = prepare_config_for_save(new_project.config)

expected_project = ProjectManager(complete_project)
complete_project = prepare_config_for_save(expected_project.config)
assert new == complete_project


def test_orbit_version_ProjectManager():
Expand Down