diff --git a/src/access/esmf_trace/library.py b/src/access/esmf_trace/library.py index 0b8a35c..d59178c 100644 --- a/src/access/esmf_trace/library.py +++ b/src/access/esmf_trace/library.py @@ -5,6 +5,7 @@ from .batch_runs import run_batch_jobs from .config import DefaultSettings, PostRunSettings, PostSummarySettings, RunSettings, load_yaml_config from .postprocess import post_summary_from_yaml +from .utils import normalise_str_list def run_from_config( @@ -75,7 +76,7 @@ def __init__( branches: list[str], post_base_path: str | Path, exact_paths: list[Path], - model_component: str, + model_component: str | list[str], branch_pattern: re.Pattern[str] | None = None, pets_components: list[str] | None = None, pets_prefix: str = "0", @@ -87,7 +88,7 @@ def __init__( branches: Experiment branch directory names; each string must match the regex provided in branch_pattern post_base_path: where esmf-trace writes postprocessed outputs for this config exact_paths: list of exact paths for each branch - model_component: comma-separated esmf component selector string. 
+ model_component: comma-separated esmf component selector string or list[str] of selectors branch_pattern: regex pattern to parse layout values, with capture groups for each layout variable pets_components: list[str], keys to include in pets string in order pets_prefix: str, prefix for pets string (default "0") @@ -120,8 +121,8 @@ def _validate(self) -> None: if not self.branches: raise ValueError("At least one branch must be provided.") - if not isinstance(self.model_component, str) or not self.model_component: - raise ValueError("model_component must be a non-empty string.") + if normalise_str_list(self.model_component) is None: + raise ValueError("model_component must be a non-empty string or list[str].") if not isinstance(self.max_workers, int) or self.max_workers < 1: raise ValueError("max_workers must be an int >= 1") @@ -208,7 +209,7 @@ def build_config(self) -> dict: return { "default_settings": { "post_base_path": str(self.post_base_path), - "model_component": self.model_component, + "model_component": normalise_str_list(self.model_component), **self.default_settings, }, "runs": runs, @@ -271,12 +272,8 @@ def build_config(self, runs: list[dict]) -> dict: "timeseries_suffix": self.timeseries_suffix, } - if self.model_component is not None: - default_settings["model_component"] = ( - self.model_component - if isinstance(self.model_component, list) - else [s.strip() for s in str(self.model_component).split(",") if s.strip()] - ) + default_settings["model_component"] = normalise_str_list(self.model_component) + if self.pets is not None: default_settings["pets"] = ( self.pets diff --git a/src/access/esmf_trace/utils.py b/src/access/esmf_trace/utils.py index 03df9c7..a36ae81 100644 --- a/src/access/esmf_trace/utils.py +++ b/src/access/esmf_trace/utils.py @@ -89,3 +89,11 @@ def construct_stream_paths(traceout_path: Path, pet_indices: list[int], prefix: """ traceout_path = Path(traceout_path).expanduser().resolve() return [traceout_path / f"{prefix}_{p:04d}" for 
def normalise_str_list(value: str | list[str] | tuple[str, ...] | None) -> list[str] | None:
    """Normalise a string or a sequence of items into a clean list of strings.

    Args:
        value: ``None``, a comma-separated string, or a list/tuple of items.
            Sequence items are converted with ``str()`` and are NOT split on
            commas; a plain string is split on ``","``. Every piece is
            stripped of surrounding whitespace and empty pieces are dropped.

    Returns:
        ``None`` only when ``value`` is ``None``; otherwise a new list of
        non-empty, stripped strings. The list may be empty (for example for
        ``""`` or ``[]``), so callers that require at least one entry must
        test truthiness rather than compare against ``None``.
    """
    if value is None:
        return None
    # Tuples are handled like lists; otherwise they would fall through to the
    # string branch and be split on the commas of their repr.
    if isinstance(value, (list, tuple)):
        candidates = (str(item).strip() for item in value)
    else:
        candidates = (part.strip() for part in str(value).split(","))
    # Strip once per piece, then discard the ones that ended up empty.
    return [text for text in candidates if text]