diff --git a/CHANGELOG.md b/CHANGELOG.md index d4979bc7..9681b2f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Add files to support mybinder.org +- (plaid-check) add a simple app to check the integrity of a plaid database - (sample/features) add_field: check field size consistency with geometrical support. - (sample) add `set_trees` to `Sample` delegated methods: `sample.set_trees(...)` now works as a direct proxy to `SampleFeatures.set_trees`, consistent with other delegated tree methods. diff --git a/pyproject.toml b/pyproject.toml index 70bdd20e..8b01932b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,9 @@ files=["LICENSE.txt"] file="README.md" content-type = "text/markdown" +[project.scripts] +plaid-check = "plaid.cli.plaidcheck:main" + [tool.setuptools] platforms = [ "Linux", diff --git a/src/plaid/cli/plaidcheck.py b/src/plaid/cli/plaidcheck.py new file mode 100644 index 00000000..1d8e7095 --- /dev/null +++ b/src/plaid/cli/plaidcheck.py @@ -0,0 +1,560 @@ +"""CLI tool to validate integrity of a PLAID dataset stored on disk.""" + +import argparse +import json +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Optional + +import CGNS.PAT.cgnsutils as CGU +import numpy as np + +from plaid.constants import CGNS_FIELD_LOCATIONS +from plaid.storage import init_from_disk +from plaid.storage.common.reader import ( + load_infos_from_disk, + load_metadata_from_disk, + load_problem_definitions_from_disk, +) + + +@dataclass +class CheckMessage: + """One integrity check message. + + Args: + severity: Message severity (`error`, `warning`, or `info`). + code: Stable message code identifier. + location: Path-like location string related to the issue. + message: Human-readable message. 
@dataclass
class CheckReport:
    """Container for check results and summary helpers.

    Args:
        messages: Integrity check messages collected during validation.
    """

    # Messages are kept in insertion order; `add` appends to this list.
    messages: list[CheckMessage]

    def add(self, severity: str, code: str, location: str, message: str) -> None:
        """Append a new message to the report.

        Args:
            severity: Message severity (`error`, `warning`, or `info`).
            code: Stable message code identifier.
            location: Path-like location string related to the issue.
            message: Human-readable message.
        """
        self.messages.append(
            CheckMessage(
                severity=severity,
                code=code,
                location=location,
                message=message,
            )
        )

    def counts(self) -> dict[str, int]:
        """Return counts by severity.

        The message list is traversed once. Messages whose severity is not one
        of `error`, `warning` or `info` are not counted.

        Returns:
            Mapping from severity names to message counts.
        """
        totals = {"error": 0, "warning": 0, "info": 0}
        for msg in self.messages:
            if msg.severity in totals:
                totals[msg.severity] += 1
        return totals

    def has_errors(self) -> bool:
        """Return whether at least one error was reported.

        Returns:
            True when the report contains one or more error messages.
        """
        return any(msg.severity == "error" for msg in self.messages)

    def has_warnings(self) -> bool:
        """Return whether at least one warning was reported.

        Returns:
            True when the report contains one or more warning messages.
        """
        return any(msg.severity == "warning" for msg in self.messages)

    def to_json(self) -> str:
        """Serialize report to JSON string.

        Returns:
            JSON string containing severity counts and message details.
        """
        payload = {
            "counts": self.counts(),
            "messages": [asdict(msg) for msg in self.messages],
        }
        return json.dumps(payload, indent=2)
def _check_numeric_content(value: Any) -> Optional[str]:
    """Inspect a feature value for invalid numeric or object content.

    The value is converted with ``np.asarray`` and checked, in order, for
    emptiness, NaN/Inf entries (real and complex floating dtypes) and ``None``
    entries (object dtype). Only the first issue found is reported.

    Args:
        value: Feature value to validate.

    Returns:
        Description of the detected issue, or None when the value is valid.
    """
    if value is None:
        return "value is None"
    arr = np.asarray(value)
    if arr.size == 0:
        return "value is empty"
    # `np.inexact` covers both real and complex floating dtypes, so NaN/Inf
    # hidden in complex-valued data is reported too (`np.floating` alone
    # silently accepted it).
    if np.issubdtype(arr.dtype, np.inexact):
        if np.isnan(arr).any():
            return "contains NaN"
        if np.isinf(arr).any():
            return "contains Inf"
    elif arr.dtype == object:
        if any(v is None for v in arr.flat):
            return "contains None in object array"
    return None
def _check_sample_values(sample: Any, label: str, report: CheckReport) -> None:
    """Report invalid values found in the globals and fields of one sample.

    Args:
        sample: Converted sample exposing the global/field accessors used below.
        label: Location prefix identifying the sample, e.g. ``train[3]``.
        report: Report updated with one warning per invalid value.
    """
    for global_name in sample.get_global_names():
        global_path = "Global/" + global_name
        value = sample.get_feature_by_path(global_path)

        # Branch nodes carry no value of their own; only their children do.
        if _is_branch_without_data(sample, global_path):
            continue

        issue = _check_numeric_content(value)
        if issue is not None:
            report.add(
                "warning",
                "INVALID_DATA_VALUE A",
                f"{label} global/{global_name}",
                issue,
            )

    for time in sample.get_all_time_values():
        for base in sample.get_base_names(time=time):
            for zone in sample.features.get_zone_names(base=base, time=time):
                for location in CGNS_FIELD_LOCATIONS:
                    field_names = sample.get_field_names(
                        location=location,
                        zone=zone,
                        base=base,
                        time=time,
                    )
                    for f_name in field_names:
                        field_value = sample.get_field(
                            f_name,
                            location=location,
                            zone=zone,
                            base=base,
                            time=time,
                        )
                        issue = _check_numeric_content(field_value)
                        if issue is not None:
                            report.add(
                                "warning",
                                "INVALID_DATA_VALUE A",
                                f"{label}[{time}] {base}/{zone}/{location}/{f_name}",
                                issue,
                            )


def _report_duplicated_samples(
    checksums: dict[tuple[int, str], str],
    report: CheckReport,
) -> None:
    """Emit one warning per group of samples sharing the same checksum.

    Args:
        checksums: Mapping from ``(index, split)`` to the sample checksum.
        report: Report updated with ``DUPLICATED_DATA`` warnings.
    """
    groups: dict[str, list[tuple[int, str]]] = {}
    for key, checksum in checksums.items():
        groups.setdefault(checksum, []).append(key)

    for members in groups.values():
        if len(members) < 2:
            continue
        # List every member of the group so the user can locate all copies.
        location = ", ".join(f"{split}[{idx}]" for idx, split in members)
        report.add("warning", "DUPLICATED_DATA", location, "duplicated sample")


def _check_problem_definitions(
    path: Path,
    variable_schema: Any,
    flat_cst: Any,
    datasetdict: Any,
    report: CheckReport,
) -> None:
    """Validate problem definitions against known features, splits and indices.

    Args:
        path: Dataset directory.
        variable_schema: Mapping whose keys are variable feature names.
        flat_cst: Mapping from split name to a mapping of constant features.
        datasetdict: Mapping from split name to a sized dataset object.
        report: Report updated with problem-definition errors.
    """
    try:
        pb_defs = load_problem_definitions_from_disk(path)
    except Exception as exc:
        report.add(
            "error",
            "PB_DEF_READ_ERROR",
            "problem_definitions",
            str(exc),
        )
        return

    all_features = set(variable_schema.keys())
    for split_cst in flat_cst.values():
        all_features.update(split_cst.keys())

    for pb_name, pb_def in pb_defs.items():
        pb_location = f"problem_definitions/{pb_name}"

        for feat in pb_def.input_features:
            if feat not in all_features:
                report.add(
                    "error",
                    "PB_DEF_UNKNOWN_INPUT",
                    pb_location,
                    f"Unknown input feature: {feat}",
                )

        for feat in pb_def.output_features:
            if feat not in all_features:
                report.add(
                    "error",
                    "PB_DEF_UNKNOWN_OUTPUT",
                    pb_location,
                    f"Unknown output feature: {feat}",
                )

        for split_dict_name in ("train_split", "test_split"):
            split_dict = getattr(pb_def, split_dict_name)
            # An empty mapping declares nothing to validate; treating it like
            # None avoids a StopIteration from `next(iter(...))` below.
            if split_dict is None or len(split_dict) == 0:
                continue
            # A split declaration must reference exactly one dataset split.
            if len(split_dict) > 1:
                report.add(
                    "error",
                    "PB_DEF_SPLIT",
                    pb_location,
                    f"{split_dict_name} has more than 1 split: {list(split_dict.keys())}",
                )
                continue

            split_name, split_ids = next(iter(split_dict.items()))
            if split_name not in datasetdict:
                report.add(
                    "error",
                    "PB_DEF_UNKNOWN_SPLIT",
                    pb_location,
                    f"Unknown split in {split_dict_name}: {split_name}",
                )
                continue

            # Guard with isinstance: comparing an array of ids to "all" with
            # `==` would not yield a plain boolean.
            if isinstance(split_ids, str) and split_ids == "all":
                continue

            ids_list = list(split_ids)
            if len(ids_list) != len(set(ids_list)):
                report.add(
                    "error",
                    "PB_DEF_DUPLICATE_INDICES",
                    pb_location,
                    f"Duplicated indices in {split_dict_name}",
                )
            split_len = len(datasetdict[split_name])
            bad = [i for i in ids_list if i < 0 or i >= split_len]
            if bad:
                report.add(
                    "error",
                    "PB_DEF_OUT_OF_RANGE_INDICES",
                    pb_location,
                    f"Out-of-range indices in {split_dict_name} (first 10): {bad[:10]}",
                )


def check_dataset(
    path: Path,
    splits: Optional[list[str]] = None,
) -> CheckReport:
    """Run integrity checks on a local PLAID dataset.

    Algorithm overview:
    1. Validate the required on-disk PLAID layout.
    2. Load infos, metadata, and split-specific dataset/converter objects.
    3. Validate top-level declarations from ``infos.yaml`` (backend, sample counts).
    4. Resolve requested splits and report unknown ones.
    5. For each checked split:
       - compare the declared sample count with the actual dataset length,
       - verify that constant schema and constant values exist for the split,
       - convert each sample and validate its global and field values,
       - compute a checksum per sample for duplicate-data detection.
    6. Report groups of samples that share the same checksum.
    7. Validate optional problem definitions against available features/splits/indices.
    8. Emit an ``OK`` info message when no issue is detected.

    Steps 1 and 2 stop the check at the first failure, because later steps
    depend on the loaded objects.

    Args:
        path: Dataset directory.
        splits: Optional selected split names. When omitted or empty, every
            split found in the dataset is checked.

    Returns:
        A populated :class:`CheckReport`.
    """
    report = CheckReport(messages=[])

    # First verify the dataset has the required on-disk files and folders.
    # Later checks rely on these paths being present and readable.
    _check_required_layout(path, report)
    if report.has_errors():
        return report

    # Load dataset descriptors and metadata before touching sample payloads.
    # Each loading step is isolated so the report points to the failing layer.
    try:
        infos = load_infos_from_disk(path)
    except Exception as exc:
        report.add("error", "INFOS_READ_ERROR", "infos.yaml", str(exc))
        return report

    try:
        flat_cst, variable_schema, constant_schema, _ = load_metadata_from_disk(path)
    except Exception as exc:
        report.add("error", "METADATA_READ_ERROR", str(path), str(exc))
        return report

    try:
        datasetdict, converterdict = init_from_disk(path)
    except Exception as exc:
        report.add("error", "DATASET_INIT_ERROR", str(path), str(exc))
        return report

    # Validate top-level dataset declarations from infos.yaml.
    declared_backend = infos.get("storage_backend")
    if not isinstance(declared_backend, str):
        report.add(
            "error",
            "BACKEND_MISSING",
            "infos.yaml",
            "Missing or invalid 'storage_backend' in infos.yaml",
        )

    num_samples = infos.get("num_samples", {})
    if not isinstance(num_samples, dict):
        report.add(
            "error", "NUM_SAMPLES_INVALID", "infos.yaml", "'num_samples' must be a dict"
        )
        num_samples = {}

    # Resolve the user-requested splits against the splits actually available.
    dataset_splits = set(datasetdict.keys())
    target_splits = set(splits) if splits else dataset_splits
    unknown_splits = target_splits - dataset_splits
    if unknown_splits:
        # Sorted so the message is stable from one run to the next.
        available = " and ".join(f'"{x}"' for x in sorted(dataset_splits))
        for split in sorted(unknown_splits):
            report.add(
                "error",
                "UNKNOWN_SPLIT",
                split,
                f"Split not found in dataset, available are {available}",
            )
    target_splits = target_splits & dataset_splits

    # Checksums of every converted sample, keyed by (index, split).
    checksums: dict[tuple[int, str], str] = {}
    for split in sorted(target_splits):
        dataset = datasetdict[split]
        converter = converterdict[split]

        # Check split-level consistency between metadata, schemas, and storage.
        expected_n = num_samples.get(split)
        actual_n = len(dataset)
        if isinstance(expected_n, int) and expected_n != actual_n:
            report.add(
                "error",
                "SPLIT_COUNT_MISMATCH",
                split,
                f"Expected {expected_n} samples from infos.yaml, found {actual_n}",
            )

        if split not in constant_schema:
            report.add(
                "error",
                "MISSING_CONSTANT_SCHEMA",
                split,
                "No constant schema for split",
            )

        if split not in flat_cst:
            report.add(
                "error",
                "MISSING_CONSTANT_VALUES",
                split,
                "No constant values for split",
            )

        # Deep check: convert every sample and look for invalid values
        # (None, empty, NaN, Inf) in its globals and fields.
        for idx in range(actual_n):
            sample_label = f"{split}[{idx}]"
            try:
                sample = converter.to_plaid(dataset, idx)
            except Exception as exc:
                report.add(
                    "error",
                    "SAMPLE_CONVERSION_ERROR",
                    sample_label,
                    str(exc),
                )
                continue

            # Track whole-sample checksums to detect duplicated data across
            # all checked splits after the per-split loop completes.
            checksums[(idx, split)] = compute_checksum(sample)
            _check_sample_values(sample, sample_label, report)

    # Compare checksums from every checked sample to flag identical sample data.
    _report_duplicated_samples(checksums, report)

    # If problem definitions are present, verify that their feature references,
    # split names, and sample indices are compatible with the dataset.
    if (path / "problem_definitions").exists():
        _check_problem_definitions(
            path, variable_schema, flat_cst, datasetdict, report
        )

    # Emit an explicit success message when no errors or warnings were found.
    if not report.messages:
        report.add("info", "OK", str(path), "No issue detected")
    return report
+ """ + parser = _build_parser() + args = parser.parse_args(argv) + + report = check_dataset( + path=args.path, + splits=args.split + ) + + if args.json: + print(report.to_json()) + else: + for msg in report.messages: + print(f"[{msg.severity.upper()}] {msg.code} {msg.location}: {msg.message}") + counts = report.counts() + print( + f"Summary: {counts['error']} error(s), " + f"{counts['warning']} warning(s), {counts['info']} info message(s)" + ) + + if report.has_errors(): + return 1 + if args.strict and report.has_warnings(): + return 2 + return 0 + + +if __name__ == "__main__": # pragma: no cover + raise SystemExit(main()) diff --git a/tests/cli/test_plaidcheck.py b/tests/cli/test_plaidcheck.py new file mode 100644 index 00000000..7493e92a --- /dev/null +++ b/tests/cli/test_plaidcheck.py @@ -0,0 +1,463 @@ +"""Tests for the plaidcheck CLI and checker helpers.""" + +import json +import shutil +from pathlib import Path +from typing import Any + +import numpy as np +import yaml + +from plaid.cli import plaidcheck +from plaid.cli.plaidcheck import ( + CheckReport, + _check_numeric_content, + _is_branch_without_data, + _is_branch_without_data_in_mapping, + check_dataset, + main, +) + + +def _copy_reference_dataset(tmp_path: Path) -> Path: + """Copy the small reference dataset used by container tests. + + Args: + tmp_path: Temporary pytest directory. + + Returns: + Path to the copied dataset root. 
+ """ + src = Path(__file__).resolve().parent.parent / "containers" / "dataset" + dst = tmp_path / "dataset" + shutil.copytree(src, dst) + return dst + + +def test_check_dataset_valid_reference(tmp_path: Path) -> None: + """Reference dataset should pass with no errors.""" + dataset_path = _copy_reference_dataset(tmp_path) + + report = check_dataset(dataset_path) + + assert not report.has_errors() + + +def test_check_dataset_missing_infos(tmp_path: Path) -> None: + """Missing infos.yaml should be reported as an error.""" + dataset_path = _copy_reference_dataset(tmp_path) + (dataset_path / "infos.yaml").unlink() + + report = check_dataset(dataset_path) + + assert report.has_errors() + assert any(msg.code == "MISSING_PATH" for msg in report.messages) + + +def test_check_dataset_num_samples_mismatch(tmp_path: Path) -> None: + """Tampering with num_samples should raise split mismatch errors.""" + dataset_path = _copy_reference_dataset(tmp_path) + infos_path = dataset_path / "infos.yaml" + infos = yaml.safe_load(infos_path.read_text(encoding="utf-8")) + infos["num_samples"]["train"] = 1 + infos_path.write_text(yaml.dump(infos, sort_keys=False), encoding="utf-8") + + report = check_dataset(dataset_path) + + assert any(msg.code == "SPLIT_COUNT_MISMATCH" for msg in report.messages) + + +def test_main_json_output_and_exit_code(tmp_path: Path, capsys) -> None: + """CLI should output JSON and return expected status code.""" + dataset_path = _copy_reference_dataset(tmp_path) + + code = main([str(dataset_path), "--json", ]) + out = capsys.readouterr().out + payload = json.loads(out) + + assert code == 0 + assert "counts" in payload + assert "messages" in payload + + +def test_main_strict_fails_on_warning(tmp_path: Path) -> None: + """In strict mode, warnings should make the command fail.""" + dataset_path = _copy_reference_dataset(tmp_path) + infos_path = dataset_path / "infos.yaml" + infos = yaml.safe_load(infos_path.read_text(encoding="utf-8")) + infos["num_samples"]["train"] 
= 11 + infos_path.write_text(yaml.dump(infos, sort_keys=False), encoding="utf-8") + + code = main([str(dataset_path), "--strict"]) + + assert code in {1, 2} + + +class _FakeSample: + """Minimal sample-like object exposing `get_tree` for helper tests.""" + + def get_tree(self) -> dict[str, str]: + """Return a sentinel tree object used by monkeypatched CGU access.""" + return {"tree": "sentinel"} + + +def test_is_branch_without_data_true_for_none_with_children(monkeypatch) -> None: + """Branch nodes with children and no data should be ignored by numeric checks.""" + + def _fake_get_node_by_path(tree: Any, path: str) -> list[Any]: + assert tree == {"tree": "sentinel"} + assert path == "Global/Branch" + return ["Branch", None, [["Child", 1.0, [], "DataArray_t"]], "UserDefinedData_t"] + + monkeypatch.setattr(plaidcheck.CGU, "getNodeByPath", _fake_get_node_by_path) + + assert _is_branch_without_data(_FakeSample(), "Global/Branch") + + +def test_is_branch_without_data_false_for_none_leaf(monkeypatch) -> None: + """Leaf nodes with None data must still be reported as invalid values.""" + + def _fake_get_node_by_path(tree: Any, path: str) -> list[Any]: + assert tree == {"tree": "sentinel"} + assert path == "Global/Leaf" + return ["Leaf", None, [], "DataArray_t"] + + monkeypatch.setattr(plaidcheck.CGU, "getNodeByPath", _fake_get_node_by_path) + + assert not _is_branch_without_data(_FakeSample(), "Global/Leaf") + assert _check_numeric_content(None) == "value is None" + + +def test_is_branch_without_data_in_mapping_true_for_branch_entry() -> None: + """Dict-based branch entry with None data should be skipped in B-path checks.""" + feat_map = { + "Global": None, + "Global/ParamA": 1.0, + "Global/ParamB": 2.0, + } + + assert _is_branch_without_data_in_mapping("Global", None, feat_map) + + +def test_is_branch_without_data_in_mapping_false_for_leaf_none() -> None: + """Dict-based leaf None entry should still be checked and reported.""" + feat_map = { + "Global/Leaf": None, + 
"Global/Other": 1.0, + } + + assert not _is_branch_without_data_in_mapping("Global/Leaf", None, feat_map) + + +def _make_minimal_layout(root: Path) -> Path: + """Create the minimal expected dataset layout for checker entry checks. + + Args: + root: Temporary path where the dataset directory is created. + + Returns: + Path to the dataset root. + """ + dataset = root / "dataset_min" + dataset.mkdir() + (dataset / "infos.yaml").write_text("storage_backend: zarr\n", encoding="utf-8") + (dataset / "variable_schema.yaml").write_text("{}\n", encoding="utf-8") + (dataset / "cgns_types.yaml").write_text("{}\n", encoding="utf-8") + (dataset / "constants").mkdir() + (dataset / "data").mkdir() + return dataset + + +class _FakeFeatures: + """Minimal features wrapper used by fake sample objects.""" + + def get_zone_names(self, base: str, time: float) -> list[str]: + """Return a single deterministic zone name. + + Args: + base: Ignored. + time: Ignored. + + Returns: + A single zone name list. + """ + return ["ZoneA"] + + +class _FakeSampleForCheck: + """Sample-like object implementing methods used by `check_dataset`.""" + + def __init__( + self, + global_value: Any = 1.0, + field_value: Any = 1.0, + global_names: list[str] | None = None, + tree: Any = None, + checksum: str = "same", + ) -> None: + self._global_value = global_value + self._field_value = field_value + self._global_names = ["G"] if global_names is None else global_names + self._tree = tree + self._checksum = checksum + self.features = _FakeFeatures() + + def get_global_names(self) -> list[str]: + """Return configured global names.""" + return self._global_names + + def get_feature_by_path(self, path: str) -> Any: + """Return configured global value. + + Args: + path: Ignored. + + Returns: + Global value payload. 
+ """ + return self._global_value + + def get_tree(self): + """Return no CGNS tree to disable branch skipping.""" + return self._tree + + def get_all_time_values(self) -> list[float]: + """Return one time value.""" + return [0.0] + + def get_base_names(self, time: float) -> list[str]: + """Return one base name. + + Args: + time: Ignored. + + Returns: + A single base name list. + """ + return ["BaseA"] + + def get_field_names( + self, + location: str, + zone: str, + base: str, + time: float, + ) -> list[str]: + """Return one field name. + + Args: + location: Ignored. + zone: Ignored. + base: Ignored. + time: Ignored. + + Returns: + A single field name list. + """ + return ["F"] + + def get_field(self, *args, **kwargs) -> Any: + """Return configured field value. + + Returns: + Field value payload. + """ + return self._field_value + + +class _FakeDataset: + """Dataset-like object exposing only `__len__`.""" + + def __init__(self, n: int) -> None: + self._n = n + + def __len__(self) -> int: + """Return dataset size.""" + return self._n + + +class _FakeConverter: + """Converter-like object exposing `to_plaid`.""" + + def __init__(self, samples: list[Any], fail_indices: set[int] | None = None) -> None: + self._samples = samples + self._fail_indices = set() if fail_indices is None else fail_indices + + def to_plaid(self, dataset: _FakeDataset, idx: int) -> Any: + """Return fake sample or raise conversion error. + + Args: + dataset: Ignored. + idx: Sample index. + + Returns: + Fake sample instance. 
+ """ + if idx in self._fail_indices: + raise RuntimeError("boom") + return self._samples[idx] + + +def test_check_numeric_content_all_remaining_branches() -> None: + """Numeric checker should report all remaining invalid content cases.""" + assert _check_numeric_content([]) == "value is empty" + assert _check_numeric_content(np.array([1.0, np.nan])) == "contains NaN" + assert _check_numeric_content(np.array([1.0, np.inf])) == "contains Inf" + assert ( + _check_numeric_content(np.array([None, "x"], dtype=object)) + == "contains None in object array" + ) + + +def test_is_branch_without_data_false_variants(monkeypatch) -> None: + """Branch helper should return False for missing tree/node/children layout.""" + + class _SampleNoTree: + def get_tree(self): + return None + + class _SampleWithTree: + def get_tree(self): + return {"tree": 1} + + assert not _is_branch_without_data(_SampleNoTree(), "any") + + monkeypatch.setattr(plaidcheck.CGU, "getNodeByPath", lambda tree, path: None) + assert not _is_branch_without_data(_SampleWithTree(), "any") + + monkeypatch.setattr(plaidcheck.CGU, "getNodeByPath", lambda tree, path: ["X", None]) + assert not _is_branch_without_data(_SampleWithTree(), "any") + + +def test_is_branch_without_data_in_mapping_false_when_value_present() -> None: + """Mapping helper should immediately reject non-None entries.""" + assert not _is_branch_without_data_in_mapping("Global", 1.0, {"Global/Child": 2.0}) + + +def test_check_dataset_loader_failures_and_header_validations( + tmp_path: Path, + monkeypatch, +) -> None: + """Checker should report infos/metadata/init failures and header errors.""" + dataset = _make_minimal_layout(tmp_path) + + monkeypatch.setattr(plaidcheck, "load_infos_from_disk", lambda path: (_ for _ in ()).throw(RuntimeError("infos"))) + report_infos = check_dataset(dataset) + assert any(msg.code == "INFOS_READ_ERROR" for msg in report_infos.messages) + + monkeypatch.setattr(plaidcheck, "load_infos_from_disk", lambda path: 
{"storage_backend": "zarr", "num_samples": {"train": 1}}) + monkeypatch.setattr(plaidcheck, "load_metadata_from_disk", lambda path: (_ for _ in ()).throw(RuntimeError("meta"))) + report_meta = check_dataset(dataset) + assert any(msg.code == "METADATA_READ_ERROR" for msg in report_meta.messages) + + monkeypatch.setattr(plaidcheck, "load_metadata_from_disk", lambda path: ({"train": {}}, {}, {"train": {}}, None)) + monkeypatch.setattr(plaidcheck, "init_from_disk", lambda path: (_ for _ in ()).throw(RuntimeError("init"))) + report_init = check_dataset(dataset) + assert any(msg.code == "DATASET_INIT_ERROR" for msg in report_init.messages) + + monkeypatch.setattr(plaidcheck, "init_from_disk", lambda path: ({"train": _FakeDataset(0)}, {"train": _FakeConverter([])})) + monkeypatch.setattr(plaidcheck, "load_infos_from_disk", lambda path: {"storage_backend": 12, "num_samples": "bad"}) + report_header = check_dataset(dataset) + assert any(msg.code == "BACKEND_MISSING" for msg in report_header.messages) + assert any(msg.code == "NUM_SAMPLES_INVALID" for msg in report_header.messages) + + +def test_check_dataset_split_and_data_warnings_and_duplicates(tmp_path: Path, monkeypatch) -> None: + """Checker should report split errors, warnings and duplicated samples.""" + dataset = _make_minimal_layout(tmp_path) + + monkeypatch.setattr(plaidcheck, "load_infos_from_disk", lambda path: {"storage_backend": "zarr", "num_samples": {"train": 3}}) + monkeypatch.setattr(plaidcheck, "load_metadata_from_disk", lambda path: ({}, {"Var": {}}, {}, None)) + + monkeypatch.setattr( + plaidcheck.CGU, + "getNodeByPath", + lambda tree, path: ["Branch", None, [["child", 1.0, [], "DataArray_t"]], "UserDefinedData_t"], + ) + + samples = [ + _FakeSampleForCheck(global_value=1.0, field_value=1.0, checksum="dup"), + _FakeSampleForCheck(global_value=np.array([np.nan]), field_value=np.array([np.nan]), tree={"branch": 1}, checksum="unique"), + _FakeSampleForCheck(global_value=np.array([np.nan]), 
field_value=np.array([np.nan]), checksum="dup"), + ] + converter = _FakeConverter(samples=samples) + datasetdict = {"train": _FakeDataset(3)} + monkeypatch.setattr(plaidcheck, "init_from_disk", lambda path: (datasetdict, {"train": converter})) + monkeypatch.setattr(plaidcheck, "compute_checksum", lambda sample: sample._checksum) + + report = check_dataset(dataset, splits=["train", "ghost"]) + + assert any(msg.code == "UNKNOWN_SPLIT" for msg in report.messages) + assert any(msg.code == "MISSING_CONSTANT_SCHEMA" for msg in report.messages) + assert any(msg.code == "MISSING_CONSTANT_VALUES" for msg in report.messages) + assert any(msg.code == "INVALID_DATA_VALUE A" for msg in report.messages) + assert any(msg.code == "DUPLICATED_DATA" for msg in report.messages) + + +def test_check_dataset_sample_conversion_error(tmp_path: Path, monkeypatch) -> None: + """Checker should emit conversion errors when converter fails on an index.""" + dataset = _make_minimal_layout(tmp_path) + + monkeypatch.setattr(plaidcheck, "load_infos_from_disk", lambda path: {"storage_backend": "zarr", "num_samples": {"train": 1}}) + monkeypatch.setattr(plaidcheck, "load_metadata_from_disk", lambda path: ({"train": {}}, {"Var": {}}, {"train": {}}, None)) + monkeypatch.setattr( + plaidcheck, + "init_from_disk", + lambda path: ({"train": _FakeDataset(1)}, {"train": _FakeConverter([_FakeSampleForCheck()], fail_indices={0})}), + ) + + report = check_dataset(dataset, splits=["train"]) + + assert any(msg.code == "SAMPLE_CONVERSION_ERROR" for msg in report.messages) + + +def test_check_dataset_problem_definition_validation_paths(tmp_path: Path, monkeypatch) -> None: + """Checker should cover problem-definition read/validation branches.""" + dataset = _make_minimal_layout(tmp_path) + (dataset / "problem_definitions").mkdir() + + monkeypatch.setattr(plaidcheck, "load_infos_from_disk", lambda path: {"storage_backend": "zarr", "num_samples": {"train": 2}}) + monkeypatch.setattr(plaidcheck, 
"load_metadata_from_disk", lambda path: ({"train": {"Known": 1.0}}, {"KnownVar": {}}, {"train": {}}, None)) + monkeypatch.setattr( + plaidcheck, + "init_from_disk", + lambda path: ({"train": _FakeDataset(2)}, {"train": _FakeConverter([_FakeSampleForCheck(), _FakeSampleForCheck()])}), + ) + + monkeypatch.setattr(plaidcheck, "load_problem_definitions_from_disk", lambda path: (_ for _ in ()).throw(RuntimeError("pb"))) + report_read = check_dataset(dataset) + assert any(msg.code == "PB_DEF_READ_ERROR" for msg in report_read.messages) + + class _PBDef: + def __init__(self, train_split, test_split): + self.input_features = ["UnknownInput"] + self.output_features = ["UnknownOutput"] + self.train_split = train_split + self.test_split = test_split + + pb_defs = { + "pb_many": _PBDef(train_split={"train": [0], "other": [1]}, test_split=None), + "pb_unknown_split": _PBDef(train_split={"ghost": [0]}, test_split=None), + "pb_indices": _PBDef(train_split={"train": [0, 0, -1, 9]}, test_split={"train": "all"}), + } + monkeypatch.setattr(plaidcheck, "load_problem_definitions_from_disk", lambda path: pb_defs) + report_pb = check_dataset(dataset) + + assert any(msg.code == "PB_DEF_UNKNOWN_INPUT" for msg in report_pb.messages) + assert any(msg.code == "PB_DEF_UNKNOWN_OUTPUT" for msg in report_pb.messages) + assert any(msg.code == "PB_DEF_SPLIT" for msg in report_pb.messages) + assert any(msg.code == "PB_DEF_UNKNOWN_SPLIT" for msg in report_pb.messages) + assert any(msg.code == "PB_DEF_DUPLICATE_INDICES" for msg in report_pb.messages) + assert any(msg.code == "PB_DEF_OUT_OF_RANGE_INDICES" for msg in report_pb.messages) + + +def test_main_strict_returns_warning_exit_code(monkeypatch, tmp_path: Path, capsys) -> None: + """Main should return exit code 2 when strict mode sees warnings only.""" + report = CheckReport(messages=[]) + report.add("warning", "W", "loc", "msg") + + monkeypatch.setattr(plaidcheck, "check_dataset", lambda path, splits=None: report) + code = main([str(tmp_path), 
"--strict"]) + _ = capsys.readouterr().out + + assert code == 2