Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions src/access/esmf_trace/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .batch_runs import run_batch_jobs
from .config import DefaultSettings, PostRunSettings, PostSummarySettings, RunSettings, load_yaml_config
from .postprocess import post_summary_from_yaml
from .utils import normalise_str_list


def run_from_config(
Expand Down Expand Up @@ -75,7 +76,7 @@ def __init__(
branches: list[str],
post_base_path: str | Path,
exact_paths: list[Path],
model_component: str,
model_component: str | list[str],
branch_pattern: re.Pattern[str] | None = None,
pets_components: list[str] | None = None,
pets_prefix: str = "0",
Expand All @@ -87,7 +88,7 @@ def __init__(
branches: Experiment branch directory names; each string must match the regex provided in branch_pattern
post_base_path: where esmf-trace writes postprocessed outputs for this config
exact_paths: list of exact paths for each branch
model_component: comma-separated esmf component selector string.
model_component: comma-separated esmf component selector string or list[str] of selectors
branch_pattern: regex pattern to parse layout values, with capture groups for each layout variable
pets_components: list[str], keys to include in pets string in order
pets_prefix: str, prefix for pets string (default "0")
Expand Down Expand Up @@ -120,8 +121,8 @@ def _validate(self) -> None:
if not self.branches:
raise ValueError("At least one branch must be provided.")

if not isinstance(self.model_component, str) or not self.model_component:
raise ValueError("model_component must be a non-empty string.")
if normalise_str_list(self.model_component) is None:
raise ValueError("model_component must be a non-empty string or list[str].")

if not isinstance(self.max_workers, int) or self.max_workers < 1:
raise ValueError("max_workers must be an int >= 1")
Expand Down Expand Up @@ -208,7 +209,7 @@ def build_config(self) -> dict:
return {
"default_settings": {
"post_base_path": str(self.post_base_path),
"model_component": self.model_component,
"model_component": normalise_str_list(self.model_component),
**self.default_settings,
},
"runs": runs,
Expand Down Expand Up @@ -271,12 +272,8 @@ def build_config(self, runs: list[dict]) -> dict:
"timeseries_suffix": self.timeseries_suffix,
}

if self.model_component is not None:
default_settings["model_component"] = (
self.model_component
if isinstance(self.model_component, list)
else [s.strip() for s in str(self.model_component).split(",") if s.strip()]
)
default_settings["model_component"] = normalise_str_list(self.model_component)

if self.pets is not None:
default_settings["pets"] = (
self.pets
Expand Down
8 changes: 8 additions & 0 deletions src/access/esmf_trace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,11 @@ def construct_stream_paths(traceout_path: Path, pet_indices: list[int], prefix:
"""
traceout_path = Path(traceout_path).expanduser().resolve()
return [traceout_path / f"{prefix}_{p:04d}" for p in pet_indices]


def normalise_str_list(value: str | list[str] | None) -> list[str] | None:
if value is None:
return None
if isinstance(value, list):
return [str(v).strip() for v in value if str(v).strip()]
return [s.strip() for s in str(value).split(",") if s.strip()]