Skip to content
10 changes: 10 additions & 0 deletions src/access/esmf_trace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,13 @@

with suppress(PackageNotFoundError):
__version__ = version("esmf_trace")

from access.esmf_trace.library import (
post_summary_from_config,
run_from_config,
)

# Public package-level API: the two entry points imported from
# access.esmf_trace.library above.
__all__ = [
    "run_from_config",
    "post_summary_from_config",
]
4 changes: 2 additions & 2 deletions src/access/esmf_trace/batch_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from .config import ConfigError, DefaultSettings, RunSettings
from .run import run as single_run
from .utils import extract_index_list, output_name_to_index
from .utils import extract_index_list_from_str, output_name_to_index


def _find_traceout_dir(output_dir: Path, stream_prefix: str) -> Path | None:
Expand All @@ -29,7 +29,7 @@ def _gather_outputs(archive_dir: Path, output_index: str | None) -> list[Path]:
all_outputs = [p for p in archive_dir.glob("output*") if p.is_dir()]
all_outputs = [p for p in all_outputs if output_name_to_index(p) is not None]
output_dirs = sorted(all_outputs, key=output_name_to_index)
selected = extract_index_list(output_index)
selected = extract_index_list_from_str(output_index)
if selected is not None:
sel = set(selected)
present = {output_name_to_index(p) for p in output_dirs}
Expand Down
33 changes: 33 additions & 0 deletions src/access/esmf_trace/common_vars.py
Original file line number Diff line number Diff line change
@@ -1 +1,34 @@
from typing import Literal

# Number of nanoseconds in one second (kept as a float).
seconds_to_nanoseconds = 1e9

# The kinds of config file currently recognised: "run" and "post-summary".
# Extend this Literal if more config kinds need to be supported.
config_kind = Literal["run", "post-summary"]

# Key names of default settings, grouped by the prefix of each list below.

# Run settings that are on/off flags.
RUN_DEFAULT_FLAG_KEYS = [
    "merge_adjacent",
    "xaxis_datetime",
    "separate_plots",
    "show_html",
]

# Run settings that carry a value rather than a flag.
RUN_DEFAULT_KEYS = [
    "stream_prefix",
    "model_component",
    "max_depth",
    "merge_gap_ns",
    "cmap",
    "renderer",
    "max_workers",
]

# Settings that may be given as defaults in a post-summary config.
POST_SUMMARY_DEFAULT_KEYS = [
    "timeseries_suffix",
    "save_json_path",
    "stats_start_index",
    "stats_end_index",
    "pets",
    "model_component",
]
233 changes: 179 additions & 54 deletions src/access/esmf_trace/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, overload

from .common_vars import config_kind
from .tmp_yaml_parser import read_yaml
from .utils import extract_index_list_from_str, extract_pets


class ConfigError(Exception):
Expand Down Expand Up @@ -84,66 +89,186 @@ def to_job_kwargs(
}


def _require_key(d: dict, keys: list[str]) -> str:
@dataclass(frozen=True)
class PostSummarySettings:
post_base_path: Path
model_component: list[str] | None = None
pets: list[int] | None = None
stats_start_index: int | None = None
stats_end_index: int | None = None
timeseries_suffix: str = "_timeseries.json"
save_json_path: Path | None = None


@dataclass(frozen=True)
class PostRunSettings:
name: str
output_index: list[str] | None = None
model_component: list[str] | None = None
pets: list[int] | None = None
stats_start_index: int | None = None
stats_end_index: int | None = None
save_json_path: Path | None = None


def _as_mapping(x, what: str) -> dict:
if not isinstance(x, dict):
raise ConfigError(f"{what} must be a mapping (dict)")
return x


def _as_list(x, what: str) -> list:
if not isinstance(x, list):
raise ConfigError(f"{what} must be a list")
return x


def _require_keys(d: dict, keys: list[str], where: str) -> None:
missing = [k for k in keys if k not in d]
if missing:
raise ConfigError(f"missing required config key(s): {', '.join(missing)}")


def _parse_defaults(d: dict) -> DefaultSettings:
    """Build a DefaultSettings from the ``default_settings`` mapping.

    Missing keys fall back to the defaults spelled out below; flag values
    are coerced with bool() and depth/gap values with int().
    """
    get = d.get
    return DefaultSettings(
        post_base_path=get("post_base_path"),
        stream_prefix=get("stream_prefix", "esmf_stream"),
        model_component=get("model_component", "[ESMF]/[ensemble] RunPhase1/[ESM0001] RunPhase1"),
        max_workers=get("max_workers"),
        xaxis_datetime=bool(get("xaxis_datetime", False)),
        separate_plots=bool(get("separate_plots", False)),
        cmap=get("cmap", "tab10"),
        renderer=get("renderer", "browser"),
        show_html=bool(get("show_html", False)),
        max_depth=int(get("max_depth", 6)),
        merge_adjacent=bool(get("merge_adjacent", False)),
        merge_gap_ns=int(get("merge_gap_ns", 1000)),
    )


def _parse_runs(lst: list[dict]) -> list[RunSettings]:
runs = []
for item in lst:
if not isinstance(item, dict):
raise ConfigError("Each run must be a mapping (dict)")

has_exact_path = item.get("exact_path")
has_other_parts = item.get("run_base") and item.get("run_name") and item.get("branch")
if not has_exact_path and not has_other_parts:
raise ConfigError(
"Each run must have either 'exact_path' or all of 'run_base', 'run_name', and 'branch' set"
)
raise ConfigError(f"missing required config key(s) in {where}: {', '.join(missing)}")


def _norm_model_component(v: str | list | tuple | set | None) -> list[str] | None:
"""
Normalise model_component to a list of strings.
Accepts a comma-separated str or a list[str].
"""
if v is None:
return None

if isinstance(v, (list, tuple, set)):
parts = [str(x).strip() for x in v if str(x).strip()]
return parts or None

s = str(v).strip()
if not s:
return None

# split on commas
parts = [p.strip() for p in s.split(",") if p.strip()]
return parts or None


runs.append(
RunSettings(
base_prefix=item.get("base_prefix"),
post_base_path=item.get("post_base_path"),
exact_path=Path(item["exact_path"]) if item.get("exact_path") else None,
run_base=Path(item["run_base"]) if item.get("run_base") else None,
run_name=item.get("run_name"),
branch=item.get("branch"),
pets=item.get("pets"),
model_component=item.get("model_component"),
output_index=item.get("output_index"),
def _norm_int_or_none(v: int | str | None) -> int | None:
if v is None or v == "":
return None
return int(v)


def _norm_path_or_none(v: str | Path | None) -> Path | None:
if v is None:
return None
return Path(v).expanduser()


def _load_run_settings(default: dict, runs: list) -> tuple[DefaultSettings, list[RunSettings]]:
    """Build the settings objects for a config of kind "run"."""
    defaults = DefaultSettings(
        post_base_path=default.get("post_base_path"),
        stream_prefix=default.get("stream_prefix", "esmf_stream"),
        model_component=default.get("model_component", "[ESMF]/[ensemble] RunPhase1/[ESM0001] RunPhase1"),
        max_workers=default.get("max_workers"),
        xaxis_datetime=bool(default.get("xaxis_datetime", False)),
        separate_plots=bool(default.get("separate_plots", False)),
        cmap=default.get("cmap", "tab10"),
        renderer=default.get("renderer", "browser"),
        show_html=bool(default.get("show_html", False)),
        max_depth=int(default.get("max_depth", 6)),
        merge_adjacent=bool(default.get("merge_adjacent", False)),
        merge_gap_ns=int(default.get("merge_gap_ns", 1000)),
    )

    run_settings: list[RunSettings] = []
    for i, item in enumerate(runs):
        item = _as_mapping(item, what=f"runs[{i}]")

        # A run is located either by one exact path or by the full
        # (run_base, run_name, branch) triple; empty values count as unset.
        has_exact_path = item.get("exact_path")
        has_other_parts = item.get("run_base") and item.get("run_name") and item.get("branch")
        if not has_exact_path and not has_other_parts:
            raise ConfigError(
                "Each run must have either 'exact_path' or "
                f"all of 'run_base', 'run_name', and 'branch' set (error in runs[{i}])"
            )

        run_settings.append(
            RunSettings(
                base_prefix=item.get("base_prefix"),
                post_base_path=item.get("post_base_path"),
                # `or None` maps empty values (e.g. "") to None before path conversion
                exact_path=_norm_path_or_none(item.get("exact_path") or None),
                run_base=_norm_path_or_none(item.get("run_base") or None),
                run_name=item.get("run_name"),
                branch=item.get("branch"),
                archive=item.get("archive", "archive"),
                pets=item.get("pets"),
                model_component=item.get("model_component"),
                output_index=item.get("output_index"),
            )
        )

    return defaults, run_settings


def _load_post_summary_settings(default: dict, runs: list) -> tuple[PostSummarySettings, list[PostRunSettings]]:
    """Build the settings objects for a config of kind "post-summary".

    Per-run values fall back to the already-normalised default settings
    when a run entry does not provide them.
    """
    post_base = default.get("post_base_path")
    if not post_base:
        raise ConfigError("default_settings.post_base_path is required for post-summary config")

    defaults = PostSummarySettings(
        post_base_path=Path(post_base).expanduser(),
        model_component=_norm_model_component(default.get("model_component")),
        # NOTE(review): extract_pets is passed None when "pets" is absent,
        # as before - confirm it accepts None.
        pets=extract_pets(default.get("pets")),
        stats_start_index=_norm_int_or_none(default.get("stats_start_index")),
        stats_end_index=_norm_int_or_none(default.get("stats_end_index")),
        timeseries_suffix=default.get("timeseries_suffix", "_timeseries.json"),
        save_json_path=_norm_path_or_none(default.get("save_json_path")),
    )

    post_runs: list[PostRunSettings] = []
    for i, item in enumerate(runs):
        item = _as_mapping(item, what=f"runs[{i}]")
        _require_keys(item, ["name"], where=f"runs[{i}]")

        # output_index may be given as a list or as a string expression
        oi = item.get("output_index")
        if isinstance(oi, list):
            output_index = [int(x) for x in oi]
        elif isinstance(oi, str):
            output_index = extract_index_list_from_str(oi)
        else:
            output_index = None

        pets_input = item.get("pets", defaults.pets)
        if pets_input is None or isinstance(pets_input, list):
            pets = pets_input
        else:
            pets = extract_pets(str(pets_input))

        post_runs.append(
            PostRunSettings(
                name=str(item["name"]),
                output_index=output_index,
                model_component=_norm_model_component(item.get("model_component", defaults.model_component)),
                pets=pets,
                # Fall back to the parsed `defaults` object; the raw `default`
                # dict has no such attributes.
                stats_start_index=_norm_int_or_none(item.get("stats_start_index", defaults.stats_start_index)),
                stats_end_index=_norm_int_or_none(item.get("stats_end_index", defaults.stats_end_index)),
                save_json_path=_norm_path_or_none(item.get("save_json_path", defaults.save_json_path)),
            )
        )
    return defaults, post_runs


# define overloads for type checking of load_yaml_config
@overload
def load_yaml_config(config_path: Path, kind: Literal["run"]) -> tuple[DefaultSettings, list[RunSettings]]: ...
@overload
def load_yaml_config(
    config_path: Path, kind: Literal["post-summary"]
) -> tuple[PostSummarySettings, list[PostRunSettings]]: ...


def load_yaml_config(config_path: Path, kind: config_kind):
    """
    Load and validate an esmf-trace yaml configuration file.

    Args:
        config_path: path of the yaml file to read.
        kind: "run" or "post-summary"; selects how the file is interpreted.

    Returns:
        A ``(defaults, runs)`` tuple: ``(DefaultSettings, list[RunSettings])``
        for kind "run", ``(PostSummarySettings, list[PostRunSettings])`` for
        kind "post-summary".

    Raises:
        ConfigError: if the file content is not a mapping, a required key is
            missing, or a section has the wrong type.
        ValueError: if ``kind`` is not one of the supported kinds.
    """
    config_path = Path(config_path)
    # A yaml file whose root is not a mapping (e.g. an empty file) is
    # reported as a ConfigError instead of failing inside the key check.
    data = _as_mapping(read_yaml(config_path), what=str(config_path))

    _require_keys(data, ["default_settings", "runs"], where=str(config_path))
    default = _as_mapping(data["default_settings"], what="default_settings")
    runs = _as_list(data["runs"], what="runs")

    if kind == "run":
        return _load_run_settings(default, runs)

    if kind == "post-summary":
        return _load_post_summary_settings(default, runs)

    raise ValueError(f"Invalid config kind: {kind}")
4 changes: 2 additions & 2 deletions src/access/esmf_trace/ctf_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ def open_selected_streams(traceout_path: Path, stream_paths: iter):
tmpdir = Path(tempfile.mkdtemp(prefix="ctf_stage_")).resolve()
try:
# link metadata and the selected streams into the temp bundle
Path.symlink(meta, tmpdir / "metadata", target_is_directory=False)
(tmpdir / "metadata").symlink_to(meta)
for s in streams:
Path.symlink(s, tmpdir / s.name, target_is_directory=False)
(tmpdir / s.name).symlink_to(s)

yield bt2.TraceCollectionMessageIterator(str(tmpdir))
finally:
Expand Down
55 changes: 55 additions & 0 deletions src/access/esmf_trace/library.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from dataclasses import replace
from pathlib import Path

from .batch_runs import run_batch_jobs
from .config import DefaultSettings, PostRunSettings, PostSummarySettings, RunSettings, load_yaml_config
from .postprocess import post_summary_from_yaml


def run_from_config(
    config_path: str | Path | dict,
    run_overrides: dict | None = None,
):
    """
    Run the batch jobs described by a config.

    config_path: a yaml path (str or Path), or a dict with the same
        structure ("default_settings" and "runs" entries) whose values are
        passed directly to DefaultSettings / RunSettings.

    run_overrides: optional dict of DefaultSettings field overrides
        e.g. {"stream_prefix": "esmf_stream", "max_workers": 8}
    """

    if not isinstance(config_path, (str, Path)):
        # anything that is not a path is treated as an already-loaded config
        raw = config_path
        defaults = DefaultSettings(**raw["default_settings"])
        runs = [RunSettings(**entry) for entry in raw["runs"]]
    else:
        defaults, runs = load_yaml_config(Path(config_path), kind="run")

    if run_overrides:
        overrides = dict(run_overrides)
        defaults = replace(defaults, **overrides)

    run_batch_jobs(defaults, runs)


def post_summary_from_config(
    config_path: str | Path | dict,
    post_overrides: dict | None = None,
    save_json_path: str | Path | None = None,
):
    """
    Produce the post-summary described by a config.

    config_path: a yaml path (str or Path), or a dict with the same
        structure ("default_settings" and "runs" entries) whose values are
        passed directly to PostSummarySettings / PostRunSettings.

    post_overrides: optional dict of PostSummarySettings field overrides
        e.g. {"timeseries_suffix": "_timeseries.json", "stats_start_index": 1}

    save_json_path: optional path, forwarded as a string to
        post_summary_from_yaml (None is forwarded unchanged).
    """

    if not isinstance(config_path, (str, Path)):
        # anything that is not a path is treated as an already-loaded config
        raw = config_path
        defaults = PostSummarySettings(**raw["default_settings"])
        runs = [PostRunSettings(**entry) for entry in raw["runs"]]
    else:
        defaults, runs = load_yaml_config(Path(config_path), kind="post-summary")
        assert isinstance(defaults, PostSummarySettings)

    if post_overrides:
        overrides = dict(post_overrides)
        defaults = replace(defaults, **overrides)

    out_path = None if save_json_path is None else str(save_json_path)
    post_summary_from_yaml(defaults, runs, save_json_path=out_path)
Loading