diff --git a/curryer/correction/__init__.py b/curryer/correction/__init__.py index e69de29b..a273d816 100644 --- a/curryer/correction/__init__.py +++ b/curryer/correction/__init__.py @@ -0,0 +1,155 @@ +"""Correction module for iterative geolocation alignment. + +Provides tools for Monte Carlo-style correction loops, image matching +against ground control points, PSF modelling, and error statistics. + +Sub-module layout +----------------- +config + Config dataclasses (``CorrectionConfig``, ``ParameterConfig``, etc.) + and the ``ParameterType`` enum. +parameters + Random parameter-set generation (:func:`load_param_sets`). +kernel_ops + SPICE kernel creation and telemetry/time-offset application. +results_io + NetCDF result file read/write and checkpoint support. +pipeline + Main :func:`loop` orchestration and all per-iteration helpers. + Preferred-name aliases: :func:`run_correction`, :func:`run_image_matching`, + :func:`compute_error_stats`. + :func:`run_correction` also accepts :class:`CorrectionInput` objects. +correction + Thin re-export shim -- keeps all existing + ``from curryer.correction import correction`` import paths working. +correction_config + Utilities for reading and validating JSON config files. +data_structures + Shared data-container dataclasses (``ImageGrid``, ``PSFGrid``, ...). +dataio + Validation helpers and S3 data-access utilities. The S3 utilities + (``S3Configuration``, ``find_netcdf_objects``, ``download_netcdf_objects``) + are optional convenience helpers for mission-specific data pipelines; + the core correction API does not depend on them. +error_stats + Error statistics computation (``ErrorStatsProcessor``). +image_match + Image-matching algorithm (``integrated_image_match``). +io + Unified path resolution (``resolve_path``). Transparently handles + local paths and S3 URIs (``s3://…``) when ``boto3`` is installed. + Optional convenience — the public API contract is local ``Path`` + objects; S3 support is opt-in. 
+pairing + Ground-control-point pairing utilities. +psf + Point-spread-function modelling. +search + Image-search / correlation routines. +verification + Standalone geolocation compliance check (:func:`verify`). +""" + +# Sub-modules (ensure `curryer.correction.psf` etc. work as attributes) +from . import ( + config, + correction, + correction_config, + data_structures, + dataio, + error_stats, + image_io, + image_match, + io, + kernel_ops, + pairing, + parameters, + pipeline, + psf, + regrid, + results, + results_io, + search, + verification, +) + +# Key public names lifted to package level +from .config import ( + CorrectionConfig, + CorrectionInput, + DataConfig, + GeolocationConfig, + NetCDFConfig, + NetCDFParameterMetadata, + ParameterConfig, + ParameterType, + RequirementsConfig, + SearchStrategy, + load_config_from_json, +) +from .data_structures import ImageGrid, PSFGrid, PSFSamplingConfig, RegridConfig, SearchConfig +from .error_stats import ErrorStatsConfig, ErrorStatsProcessor, compute_percent_below +from .io import resolve_path +from .pipeline import compute_error_stats, loop, run_correction, run_image_matching +from .results import CorrectionResult, ParameterSetResult +from .verification import GCPError, VerificationResult, compare_results, verify + +__all__ = [ + # Sub-modules + "config", + "correction", + "correction_config", + "data_structures", + "dataio", + "error_stats", + "image_io", + "image_match", + "io", + "kernel_ops", + "pairing", + "parameters", + "pipeline", + "psf", + "regrid", + "results", + "results_io", + "search", + "verification", + # Config + "CorrectionConfig", + "CorrectionInput", + "DataConfig", + "GeolocationConfig", + "NetCDFConfig", + "NetCDFParameterMetadata", + "ParameterConfig", + "ParameterType", + "RequirementsConfig", + "SearchStrategy", + "load_config_from_json", + # Pipeline entry points + "loop", + "run_correction", + "compute_error_stats", + "run_image_matching", + # Data structures + "ImageGrid", + "PSFGrid", + 
"PSFSamplingConfig", + "RegridConfig", + "SearchConfig", + # Error stats + "ErrorStatsConfig", + "ErrorStatsProcessor", + "compute_percent_below", + # IO + "resolve_path", + # Verification + "GCPError", + "VerificationResult", + "compare_results", + "verify", + # Structured results + "CorrectionResult", + "ParameterSetResult", +] diff --git a/curryer/correction/config.py b/curryer/correction/config.py new file mode 100644 index 00000000..86107ea2 --- /dev/null +++ b/curryer/correction/config.py @@ -0,0 +1,928 @@ +"""Configuration models and enumerations for the geolocation correction pipeline. + +This module defines the data structures that represent the complete configuration +for a correction analysis run, including: + +- ``ParameterType`` – enum of the three parameter variation strategies +- ``ParameterData`` – typed container for a parameter's sampling spec +- ``ParameterConfig`` – a single parameter to vary (kernel or time offset) +- ``GeolocationConfig`` – SPICE kernel paths and instrument settings +- ``NetCDFParameterMetadata`` / ``NetCDFConfig`` – NetCDF output metadata +- ``CorrectionConfig`` – the single top-level config object passed to ``pipeline.loop()`` +- ``KernelContext``, ``CalibrationData``, ``ImageMatchingContext`` – lightweight NamedTuples + used to pass state between pipeline helper functions +- ``load_config_from_json`` – build a ``CorrectionConfig`` from a JSON file + +All mission-specific values (kernel filenames, parameter ranges, instrument names) +live in mission configuration modules (e.g. ``tests/test_correction/clarreo_config.py``) +and are injected via ``CorrectionConfig``. 
+ +All config objects are ``pydantic.BaseModel`` subclasses which provide: +- Automatic type validation and clear ``ValidationError`` messages on construction +- Free JSON serialization via ``model_dump_json()`` / ``model_validate_json()`` +- IDE autocomplete on every field +""" + +import json +import logging +import warnings +from enum import Enum, auto +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal, NamedTuple + +import numpy as np +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator + +if TYPE_CHECKING: + from curryer import meta + +# ============================================================================ +# Standard NetCDF Variable Attributes (Mission-Agnostic) +# ============================================================================ + +STANDARD_NETCDF_ATTRIBUTES = { + # Geolocation error metrics (per GCP pair) + "rms_error_m": {"units": "meters", "long_name": "RMS geolocation error"}, + "mean_error_m": {"units": "meters", "long_name": "Mean geolocation error"}, + "max_error_m": {"units": "meters", "long_name": "Maximum geolocation error"}, + "std_error_m": {"units": "meters", "long_name": "Standard deviation of geolocation error"}, + "n_measurements": {"units": "count", "long_name": "Number of measurement points"}, + # Aggregate performance metrics (per parameter set) + "mean_rms_all_pairs": {"units": "meters", "long_name": "Mean RMS error across all GCP pairs"}, + "worst_pair_rms": {"units": "meters", "long_name": "Worst performing GCP pair RMS error"}, + "best_pair_rms": {"units": "meters", "long_name": "Best performing GCP pair RMS error"}, + # Image matching metrics (per GCP pair) + "im_lat_error_km": {"units": "kilometers", "long_name": "Image matching latitude error"}, + "im_lon_error_km": {"units": "kilometers", "long_name": "Image matching longitude error"}, + "im_ccv": {"units": "dimensionless", "long_name": "Image matching correlation coefficient"}, + "im_grid_step_m": {"units": 
"meters", "long_name": "Image matching final grid step size"}, +} + + +# ============================================================================ +# Standard Data Variable Names (Mission-Agnostic Keys) +# ============================================================================ + +# Standard variable names that should be present in image matching results. +# Used for extracting data from xarray.Dataset objects. +STANDARD_VAR_NAMES = { + # Error measurements (required) + "lat_error_deg": "lat_error_deg", + "lon_error_deg": "lon_error_deg", + # Spacecraft state (configurable names) + "spacecraft_position": "sc_position", # Generic default + "boresight": "boresight", # Generic default + "transformation_matrix": "t_inst2ref", # Generic default + # Control point location (optional) + "gcp_lat_deg": "gcp_lat_deg", + "gcp_lon_deg": "gcp_lon_deg", + "gcp_alt": "gcp_alt", +} + + +# ============================================================================ +# Pipeline Helper NamedTuples (not config – pass-through state only) +# ============================================================================ + + +class KernelContext(NamedTuple): + """Context for SPICE kernel loading during geolocation.""" + + mkrn: "meta.MetaKernel" + dynamic_kernels: list[Path] + param_kernels: list[Path] + + +class CalibrationData(NamedTuple): + """Pre-loaded calibration data for image matching.""" + + los_vectors: np.ndarray | None + optical_psfs: list | None + + +class ImageMatchingContext(NamedTuple): + """Context data needed for image matching operations.""" + + gcp_pairs: list[tuple] + params: list[tuple] + pair_idx: int + sci_key: str + + +# ============================================================================ +# Data Loading Configuration +# ============================================================================ + + +class DataConfig(BaseModel): + """Configuration for config-driven internal data loading. 
+ + Replaces mission-specific loader callables with a declarative specification + of how files should be read. The pipeline reads telemetry and science data + directly from the provided file paths using pandas/xarray, applying the + ``time_scale_factor`` to convert the science time column to uGPS. + + Attributes + ---------- + file_format + File format for both telemetry and science data files. + ``"csv"`` uses :func:`pandas.read_csv`; ``"netcdf"`` converts via + :func:`xarray.open_dataset`; ``"hdf5"`` uses :func:`pandas.read_hdf`. + time_scale_factor + Multiply science timestamps by this factor to obtain uGPS + (microseconds since GPS epoch). For example, ``1e6`` converts GPS + seconds to uGPS; ``1.0`` means the file already contains uGPS. + The time column name is taken from :attr:`GeolocationConfig.time_field` + (single source of truth). + """ + + file_format: Literal["csv", "netcdf", "hdf5"] = "csv" + time_scale_factor: float = 1.0 + # Explicit column name mappings for telemetry spacecraft-position data. + # e.g. ["sc_pos_x", "sc_pos_y", "sc_pos_z"]. None means use mission defaults. + position_columns: list[str] | None = None + + +# ============================================================================ +# Parameter Configuration +# ============================================================================ + + +class ParameterType(Enum): + CONSTANT_KERNEL = auto() # Set a specific value. + OFFSET_KERNEL = auto() # Modify input kernel data by an offset. + OFFSET_TIME = auto() # Modify input timetags by an offset + + +class SearchStrategy(str, Enum): + """Strategy used to generate parameter sets during correction analysis. + + Attributes + ---------- + RANDOM + Monte Carlo random walk (current default). Each iteration draws an + independent sample from a normal distribution centred on the + parameter's ``current_value`` with the specified ``sigma``, clipped + to ``bounds``. Requires ``seed`` and ``n_iterations`` on + :class:`CorrectionConfig`. 
+ GRID_SEARCH + Deterministic cartesian-product sweep. For every parameter, + ``grid_points_per_param`` evenly-spaced values are generated across + the full ``bounds`` offset range and the cartesian product of all + per-parameter grids is enumerated. ``n_iterations`` is ignored. + SINGLE_OFFSET + Deterministic single-parameter sweep. Each parameter is varied + independently across ``n_iterations`` evenly-spaced values (spanning + its ``bounds`` offset range) while all other parameters are held at + their nominal ``current_value``. Total parameter sets produced: + ``len(parameters) × n_iterations``. + """ + + RANDOM = "random" + GRID_SEARCH = "grid" + SINGLE_OFFSET = "single" + + +class ParameterData(BaseModel): + """Typed sampling specification for a single correction parameter. + + Supports dict-style access (``get``, ``__getitem__``, ``__contains__``) + for backward compatibility with code written against the old ``dict``-based + ``ParameterConfig.data`` API. + + Attributes + ---------- + current_value + Baseline parameter value(s). A scalar for OFFSET_KERNEL/OFFSET_TIME + and a 3-element list ``[roll, pitch, yaw]`` for CONSTANT_KERNEL. + bounds + ``[min, max]`` offset limits (same units as ``sigma``). + sigma + Standard deviation for normal-distribution sampling. ``None`` means + the parameter is held fixed at ``current_value``. + units + Physical units string, e.g. ``"arcseconds"`` or ``"milliseconds"``. + distribution + Sampling distribution name. Stored for documentation purposes; + the current implementation always uses a normal distribution. + field + Telemetry / science DataFrame column that this parameter modifies + (required for ``OFFSET_KERNEL`` and ``OFFSET_TIME``). + transformation_type + Optional hint consumed by kernel-creation routines (e.g. + ``"dcm_rotation"`` or ``"angle_bias"``). + coordinate_frames + Optional list of SPICE frame names affected by this parameter. 
+ """ + + model_config = ConfigDict(extra="allow") + + current_value: float | list[float] = 0.0 + bounds: list[float] = Field(default_factory=lambda: [-1.0, 1.0]) + sigma: float | None = None + units: str | None = None + distribution: str = "normal" + field: str | None = None + transformation_type: str | None = None + coordinate_frames: list[str] | None = None + + # ------------------------------------------------------------------ + # Backward-compatible dict-style access + # ------------------------------------------------------------------ + + def _get_raw(self, key: str) -> Any: + """Return the raw value for *key* from declared fields or extra fields.""" + if key in type(self).model_fields: + return getattr(self, key, None) + extra = self.__pydantic_extra__ or {} + return extra.get(key) + + def get(self, key: str, default: Any = None) -> Any: + """``dict.get()`` shim for backward compatibility. + + Returns *default* when the value is ``None`` (i.e. field was not + explicitly set), mirroring ``dict.get`` on a mapping that only + contains keys with non-``None`` values. + """ + val = self._get_raw(key) + return default if val is None else val + + def __contains__(self, key: str) -> bool: + """``key in data`` shim – ``True`` when the value is not ``None``.""" + return self._get_raw(key) is not None + + def __getitem__(self, key: str) -> Any: + """``data[key]`` shim for backward compatibility.""" + if key in type(self).model_fields: + return getattr(self, key) + extra = self.__pydantic_extra__ or {} + if key in extra: + return extra[key] + raise KeyError(key) + + +class ParameterConfig(BaseModel): + """A single parameter to vary during correction analysis. + + Attributes + ---------- + ptype + How this parameter is applied (constant kernel, offset kernel, or + time offset). + config_file + Path to the SPICE kernel JSON template, or ``None`` for time + offsets that require no kernel file. + data + Sampling specification. 
Accepts a plain ``dict`` or ``None`` on + construction (Pydantic coerces both to :class:`ParameterData` + automatically; ``None`` becomes an empty ``ParameterData()``). + """ + + ptype: ParameterType + config_file: Path | None = None + data: ParameterData = Field(default_factory=ParameterData) + + @model_validator(mode="before") + @classmethod + def _coerce_none_data(cls, values: Any) -> Any: + """Convert ``data=None`` to an empty ``ParameterData`` (backward compat).""" + if isinstance(values, dict) and values.get("data") is None: + values = dict(values) + values["data"] = {} + return values + + +# ============================================================================ +# Geolocation Configuration +# ============================================================================ + + +class GeolocationConfig(BaseModel): + """SPICE kernel paths and instrument settings for geolocation. + + Attributes + ---------- + meta_kernel_file + Path to the mission meta-kernel JSON file. + generic_kernel_dir + Directory containing generic/shared SPICE kernels. + dynamic_kernels + Kernels regenerated from telemetry each run (SC-SPK, SC-CK, etc.) + but *not* altered by parameter variations. + instrument_name + SPICE instrument name (e.g. ``"CPRS_HYSICS"``). + time_field + Column name in the science DataFrame that holds uGPS timestamps. + minimum_correlation + Optional image-matching quality filter threshold (0.0–1.0). 
+ """ + + meta_kernel_file: Path + generic_kernel_dir: Path + dynamic_kernels: list[Path] = Field(default_factory=list) + instrument_name: str + time_field: str + minimum_correlation: float | None = None + + +# ============================================================================ +# NetCDF Output Configuration +# ============================================================================ + + +class NetCDFParameterMetadata(BaseModel): + """NetCDF metadata for a single output parameter variable.""" + + variable_name: str + units: str + long_name: str + + +class NetCDFConfig(BaseModel): + """Configuration for NetCDF output structure and metadata. + + Attributes + ---------- + performance_threshold_m + Accuracy threshold in metres used to derive threshold-specific + variable names (e.g. ``"percent_under_250m"``). + title + Global title attribute for the output NetCDF file. + description + Global description attribute for the output NetCDF file. + parameter_metadata + Optional mapping of parameter key → :class:`NetCDFParameterMetadata`. + Auto-generated from ``CorrectionConfig.parameters`` when ``None``. + standard_attributes + Optional mission-specific attribute overrides. Falls back to the + module-level :data:`STANDARD_NETCDF_ATTRIBUTES` when ``None``. 
+ """ + + performance_threshold_m: float + title: str = "Correction Geolocation Analysis Results" + description: str = "Parameter sensitivity analysis" + parameter_metadata: dict[str, NetCDFParameterMetadata] | None = None + standard_attributes: dict[str, dict[str, str]] | None = None + + def get_threshold_metric_name(self) -> str: + """Generate metric name dynamically from threshold.""" + threshold_m = int(self.performance_threshold_m) + return f"percent_under_{threshold_m}m" + + def get_standard_attributes(self) -> dict[str, dict[str, str]]: + """Get standard variable attributes, using mission overrides if provided.""" + if self.standard_attributes is not None: + return self.standard_attributes + return STANDARD_NETCDF_ATTRIBUTES.copy() + + def get_parameter_netcdf_metadata( + self, param_config: "ParameterConfig", angle_type: str | None = None + ) -> "NetCDFParameterMetadata": + """Get NetCDF metadata for a parameter.""" + if param_config.config_file: + param_stem = param_config.config_file.stem + lookup_key = f"{param_stem}_{angle_type}" if angle_type else param_stem + else: + lookup_key = f"param_{param_config.ptype.name.lower()}" + + if self.parameter_metadata and lookup_key in self.parameter_metadata: + return self.parameter_metadata[lookup_key] + + return self._auto_generate_metadata(param_config, angle_type, lookup_key) + + def _auto_generate_metadata( + self, param_config: "ParameterConfig", angle_type: str | None, base_key: str + ) -> "NetCDFParameterMetadata": + """Auto-generate NetCDF metadata from parameter configuration.""" + if param_config.ptype == ParameterType.CONSTANT_KERNEL: + units = "arcseconds" + elif param_config.ptype == ParameterType.OFFSET_KERNEL: + units = "arcseconds" + elif param_config.ptype == ParameterType.OFFSET_TIME: + units = "milliseconds" + else: + units = "unknown" + + # Use declared units field (replaces old isinstance(data, dict) check) + if param_config.data.units is not None: + units = param_config.data.units + + var_name 
= base_key.replace(".", "_").replace("-", "_") + if not var_name.startswith("param_"): + var_name = f"param_{var_name}" + + if param_config.config_file: + file_stem = param_config.config_file.stem + clean_name = file_stem.replace("_v01", "").replace("_v02", "").replace(".attitude.ck", "") + clean_name = clean_name.replace("_", " ").title() + if angle_type: + long_name = f"{clean_name} {angle_type} correction" + else: + long_name = f"{clean_name} correction" + else: + long_name = f"{param_config.ptype.name.replace('_', ' ').title()} parameter" + + return NetCDFParameterMetadata(variable_name=var_name, units=units, long_name=long_name) + + +# ============================================================================ +# Verification Requirements Configuration +# ============================================================================ + + +class RequirementsConfig(BaseModel): + """Verification requirements / thresholds. + + Can be attached as an optional ``verification`` field on + :class:`CorrectionConfig`, or passed directly to + :func:`~curryer.correction.verification.verify`. When neither is supplied, + :func:`~curryer.correction.verification.verify` falls back to + :attr:`CorrectionConfig.performance_threshold_m` and + :attr:`CorrectionConfig.performance_spec_percent`. + + Attributes + ---------- + performance_threshold_m : float + Per-measurement nadir-equivalent error limit in metres. + A measurement *passes* when its error is **below** this value. + performance_spec_percent : float + Minimum fraction of measurements (0–100) that must pass for the + overall verification to be considered successful. 
+ """ + + performance_threshold_m: float + performance_spec_percent: float + + +# ============================================================================ +# Top-Level Correction Configuration +# ============================================================================ + + +class CorrectionConfig(BaseModel): + """The configuration object for geolocation correction analysis. + + This config contains everything needed for a Correction run: + - What parameters to vary (parameters list) + - How to vary them (seed, n_iterations) + - How to load data (telemetry_loader, science_loader) + - How to process data (gcp_pairing_func, image_matching_func) + - Geolocation settings (geo: GeolocationConfig) + - Success criteria (performance_threshold_m, performance_spec_percent) + - Output configuration (netcdf: NetCDFConfig, output_filename) + + Create one CorrectionConfig object and pass it to pipeline.loop() to run. + + Serialisation + ------------- + ``model_dump_json()`` / ``model_validate_json()`` provide lossless + JSON round-trips for all typed fields. Callable fields (loaders, + pairing/matching functions) are **excluded** from serialisation because + they cannot be represented as JSON; re-attach them after deserialising. + + Parameters + ---------- + CORE CORRECTION SETTINGS: + seed : int | None + Random seed for reproducibility, or None for non-reproducible runs. + n_iterations : int + Number of parameter set iterations. + parameters : list[ParameterConfig] + Parameters to vary (defines sensitivity analysis). + + GEOLOCATION & PERFORMANCE REQUIREMENTS: + geo : GeolocationConfig + performance_threshold_m : float + performance_spec_percent : float + + DATA LOADING CONFIGURATION: + data : DataConfig | None + Specifies file format, time field, scale factor, and optional GCP + discovery settings. When provided, telemetry and science files are + read internally by the pipeline from the paths supplied in + ``tlm_sci_gcp_sets``. 
+ + PROCESSING FUNCTION (optional override): + image_matching_func + Defaults to the built-in ``pipeline.image_matching`` when ``None``. + Override only for missions with fundamentally different matching. + + OUTPUT CONFIGURATION: + netcdf : NetCDFConfig | None + output_filename : str | None + + CALIBRATION CONFIGURATION: + calibration_dir : Path | None + calibration_file_names : dict[str, str] | None + + MISSION-SPECIFIC NAMING: + spacecraft_position_name, boresight_name, transformation_matrix_name + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + # CORE CORRECTION SETTINGS + seed: int | None = None + n_iterations: int + parameters: list[ParameterConfig] + + # SEARCH STRATEGY + search_strategy: SearchStrategy = SearchStrategy.RANDOM + grid_points_per_param: int = Field( + default=10, + ge=2, + description="Number of evenly-spaced grid points per parameter for GRID_SEARCH strategy.", + ) + max_grid_sets: int = Field( + default=100_000, + ge=1, + description=( + "Hard upper bound on the total number of parameter sets that GRID_SEARCH may materialise. " + "Prevents accidental out-of-memory runs caused by large cartesian products " + "(e.g. 10 points × 6 params = 1,000,000 sets). " + "Raise this value deliberately, or switch to SINGLE_OFFSET for high-dimensional sweeps." + ), + ) + + # GEOLOCATION & PERFORMANCE REQUIREMENTS + geo: GeolocationConfig + performance_threshold_m: float + performance_spec_percent: float + + # DATA LOADING CONFIGURATION (config-driven; replaces mission-specific loader callables) + data: DataConfig | None = None + + # Private test-injection override for image matching. + # Not part of the public API; not serialised to JSON (PrivateAttr is always excluded). + # Usage: config._image_matching_override = your_func + # TODO(#151): Add Requirement model with evaluate_all() for multi-metric requirements. 
+ _image_matching_override: Any = PrivateAttr(default=None) + + @property + def image_matching_func(self) -> Any: + """Deprecated — use ``_image_matching_override`` for test injection. + + .. deprecated:: + Set ``config._image_matching_override = func`` instead. + This property will be removed in a future release. + """ + warnings.warn( + "image_matching_func is deprecated. Use config._image_matching_override = func for test injection.", + DeprecationWarning, + stacklevel=2, + ) + return self._image_matching_override + + @image_matching_func.setter + def image_matching_func(self, value: Any) -> None: + warnings.warn( + "image_matching_func is deprecated. Use config._image_matching_override = func for test injection.", + DeprecationWarning, + stacklevel=2, + ) + self._image_matching_override = value + + # OUTPUT CONFIGURATION + netcdf: NetCDFConfig | None = None + output_filename: str | None = None + + # CALIBRATION CONFIGURATION + calibration_dir: Path | None = None + calibration_file_names: dict[str, str] | None = None + # Direct calibration file paths (alternative to calibration_dir + calibration_file_names) + psf_file: Path | None = None + los_vectors_file: Path | None = None + + # MISSION-SPECIFIC NAMING + spacecraft_position_name: str = "sc_position" + boresight_name: str = "boresight" + transformation_matrix_name: str = "t_inst2ref" + + @model_validator(mode="after") + def _validate_search_strategy(self) -> "CorrectionConfig": + """Ensure strategy-specific settings are consistent.""" + if self.search_strategy in (SearchStrategy.GRID_SEARCH, SearchStrategy.SINGLE_OFFSET): + if not self.parameters: + raise ValueError( + f"SearchStrategy.{self.search_strategy.name} requires at least one parameter in `parameters`." 
+ ) + return self + + def get_calibration_file(self, file_type: str, default: str = None) -> str: + """Get calibration filename for given type with fallback to default.""" + if self.calibration_file_names and file_type in self.calibration_file_names: + return self.calibration_file_names[file_type] + if default: + return default + raise ValueError(f"No calibration file configured for type: {file_type}") + + def validate(self, check_loaders: bool = False): + """Validate that all required configuration values are present. + + Args: + check_loaders: Deprecated – accepted for backward compatibility but + has no effect. Loader callables no longer exist on + this config; data loading is driven by the ``data`` + field (:class:`DataConfig`). + + Raises: + ValueError: If any required fields are missing or invalid + """ + import logging + + logger = logging.getLogger(__name__) + errors = [] + + if self.n_iterations is None or self.n_iterations <= 0: + errors.append("n_iterations must be a positive integer") + + if self.parameters is None or len(self.parameters) == 0: + errors.append("parameters list cannot be empty") + + if self.geo is None: + errors.append("geo (GeolocationConfig) is required") + + if self.performance_threshold_m is None or self.performance_threshold_m <= 0: + errors.append("performance_threshold_m must be a positive number (e.g., 250.0 meters)") + + if self.performance_spec_percent is None or not (0 <= self.performance_spec_percent <= 100): + errors.append("performance_spec_percent must be between 0 and 100 (e.g., 39.0)") + + if errors: + error_msg = "CorrectionConfig validation failed:\n - " + "\n - ".join(errors) + error_msg += "\n\nThese values must be provided in your mission configuration." + error_msg += "\nSee tests/test_correction/clarreo_config.py for an example." 
+ raise ValueError(error_msg) + + logger.debug("CorrectionConfig validation passed") + + def ensure_netcdf_config(self): + """Ensure NetCDFConfig exists, creating with defaults if needed.""" + if self.netcdf is None: + self.netcdf = NetCDFConfig(performance_threshold_m=self.performance_threshold_m) + + def get_output_filename(self, default: str = "correction_results.nc") -> str: + """Get output filename with optional auto-generation.""" + if self.output_filename: + return self.output_filename + return default + + @staticmethod + def generate_timestamped_filename(prefix: str = "correction", suffix: str = "") -> str: + """Generate a timestamped output filename for production use.""" + import datetime + + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + if suffix: + return f"{prefix}_{timestamp}_{suffix}.nc" + return f"{prefix}_{timestamp}.nc" + + +# ============================================================================ +# JSON Config Loading +# ============================================================================ + +# Imported here (not at module top) to avoid a circular import with the +# correction_config sibling module, which itself has no dependency on config. +from curryer.correction import correction_config as _correction_config # noqa: E402 + +_config_logger = logging.getLogger(__name__) + + +def load_config_from_json(config_path: Path) -> "CorrectionConfig": + """Load correction configuration from a JSON file. 
+ + Args: + config_path: Path to the JSON configuration file (e.g., gcs_config.json) + + Returns: + CorrectionConfig object populated from the JSON file + + Raises: + FileNotFoundError: If config file doesn't exist + ValueError: If config file format is invalid + KeyError: If required config sections are missing + """ + config_path = Path(config_path) + + if not config_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + _config_logger.info(f"Loading Correction configuration from: {config_path}") + + try: + with open(config_path) as f: + config_data = json.load(f) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in config file {config_path}: {e}") + + # Extract mission configuration and kernel mappings + mission_config = _correction_config.extract_mission_config(config_data) + constant_kernel_map = _correction_config.get_kernel_mapping(config_data, "constant_kernel") + offset_kernel_map = _correction_config.get_kernel_mapping(config_data, "offset_kernel") + + _config_logger.debug(f"Mission: {mission_config.get('mission_name', 'UNKNOWN')}") + _config_logger.debug(f"Constant kernel mappings: {constant_kernel_map}") + _config_logger.debug(f"Offset kernel mappings: {offset_kernel_map}") + + # Validate required sections exist + if "correction" not in config_data: + raise KeyError("Missing required 'correction' section in config file") + if "geolocation" not in config_data: + raise KeyError("Missing required 'geolocation' section in config file") + + # Extract correction section + corr_config = config_data.get("correction", {}) + geo_config = config_data.get("geolocation", {}) + + # Validate correction section + if "parameters" not in corr_config: + raise KeyError("Missing required 'parameters' in correction section") + if not isinstance(corr_config["parameters"], list): + raise ValueError("'parameters' must be a list") + if len(corr_config["parameters"]) == 0: + raise ValueError("No parameters defined in 
configuration") + + # Parse parameters and group related ones together + parameters = [] + param_groups = {} + + # First pass: group parameters by their base name and type + for param_dict in corr_config.get("parameters", []): + param_name = param_dict.get("name", "") + ptype_str = param_dict.get("parameter_type", "CONSTANT_KERNEL") + ptype = ParameterType[ptype_str] + + # Group CONSTANT_KERNEL parameters by their base frame name + if ptype == ParameterType.CONSTANT_KERNEL: + # Extract base name (e.g., "hysics_to_cradle" from "hysics_to_cradle_roll") + if "_roll" in param_name: + base_name = param_name.replace("_roll", "") + angle_type = "roll" + elif "_pitch" in param_name: + base_name = param_name.replace("_pitch", "") + angle_type = "pitch" + elif "_yaw" in param_name: + base_name = param_name.replace("_yaw", "") + angle_type = "yaw" + else: + base_name = param_name + angle_type = "single" + + if base_name not in param_groups: + param_groups[base_name] = {"type": ptype, "angles": {}, "template": param_dict, "config_file": None} + + param_groups[base_name]["angles"][angle_type] = param_dict.get("initial_value", 0.0) + + # Determine config file based on kernel mapping from config + kernel_file = _correction_config.find_kernel_file(base_name, constant_kernel_map) + if kernel_file: + param_groups[base_name]["config_file"] = Path(kernel_file) + _config_logger.debug(f"Mapped CONSTANT_KERNEL '{base_name}' → {kernel_file}") + else: + _config_logger.warning(f"No kernel mapping found for CONSTANT_KERNEL parameter: {base_name}") + + else: + # OFFSET_KERNEL and OFFSET_TIME parameters are individual + param_groups[param_name] = {"type": ptype, "param_dict": param_dict, "config_file": None} + + if ptype == ParameterType.OFFSET_KERNEL: + kernel_file = _correction_config.find_kernel_file(param_name, offset_kernel_map) + if kernel_file: + param_groups[param_name]["config_file"] = Path(kernel_file) + _config_logger.debug(f"Mapped OFFSET_KERNEL '{param_name}' → {kernel_file}") + 
elif param_dict.get("config_file"): + param_groups[param_name]["config_file"] = Path(param_dict["config_file"]) + _config_logger.debug(f"Using explicit config_file for OFFSET_KERNEL '{param_name}'") + else: + _config_logger.warning(f"No kernel mapping found for OFFSET_KERNEL parameter: {param_name}") + + # Second pass: create ParameterConfig objects from groups + for group_name, group_data in param_groups.items(): + if group_data["type"] == ParameterType.CONSTANT_KERNEL: + template = group_data["template"] + angles = group_data["angles"] + center_values = [angles.get("roll", 0.0), angles.get("pitch", 0.0), angles.get("yaw", 0.0)] + param_data = { + "current_value": center_values, + "bounds": template.get("bounds", [-100, 100]), + "sigma": template.get("sigma"), + "units": template.get("units", "arcseconds"), + "distribution": template.get("distribution_type", "normal"), + "field": template.get("application_target", {}).get("field_name", None), + } + else: + param_dict = group_data["param_dict"] + param_data = { + "current_value": param_dict.get("current_value", param_dict.get("initial_value", 0.0)), + "bounds": param_dict.get("bounds", [-100, 100]), + "sigma": param_dict.get("sigma"), + "units": param_dict.get("units", "radians"), + "distribution": param_dict.get("distribution_type", "normal"), + "field": (param_dict.get("field") or param_dict.get("application_target", {}).get("field_name", None)), + } + + parameters.append( + ParameterConfig(ptype=group_data["type"], config_file=group_data["config_file"], data=param_data) + ) + + _config_logger.info( + f"Loaded {len(parameters)} parameter groups from {len(corr_config.get('parameters', []))} individual parameters" + ) + + # Parse geolocation configuration + default_instrument = mission_config.get("instrument_name") + instrument_name = geo_config.get("instrument_name", default_instrument) + if instrument_name is None: + raise ValueError("instrument_name must be specified in config (either in geolocation or mission 
section)") + + time_field = geo_config.get("time_field") + if time_field is None: + raise ValueError("time_field must be specified in geolocation config") + + geo = GeolocationConfig( + meta_kernel_file=Path(geo_config.get("meta_kernel_file", "")), + generic_kernel_dir=Path(geo_config.get("generic_kernel_dir", "")), + dynamic_kernels=[Path(k) for k in geo_config.get("dynamic_kernels", [])], + instrument_name=instrument_name, + time_field=time_field, + ) + + # Extract required mission-specific parameters from correction section + earth_radius_m = corr_config.get("earth_radius_m") + if earth_radius_m is not None: + _config_logger.warning( + "earth_radius_m in config is deprecated and ignored. " + "The WGS84 value from curryer.compute.constants is used instead." + ) + + performance_threshold_m = corr_config.get("performance_threshold_m") + if performance_threshold_m is None: + raise KeyError( + "Missing required 'performance_threshold_m' in correction config section. " + "This must be specified for your mission (e.g., 250.0 meters for CLARREO)." + ) + + performance_spec_percent = corr_config.get("performance_spec_percent") + if performance_spec_percent is None: + raise KeyError( + "Missing required 'performance_spec_percent' in correction config section. " + "This must be specified for your mission (e.g., 39.0 percent for CLARREO)." 
+ ) + + config = CorrectionConfig( + seed=corr_config.get("seed"), + n_iterations=corr_config.get("n_iterations", 10), + parameters=parameters, + geo=geo, + performance_threshold_m=performance_threshold_m, + performance_spec_percent=performance_spec_percent, + ) + + config.validate() + + _config_logger.info( + f"Configuration loaded and validated: {config.n_iterations} iterations, " + f"{len(config.parameters)} parameter groups" + ) + return config + + +# ============================================================================ +# Typed Input Structure +# ============================================================================ + + +class CorrectionInput(BaseModel): + """A single input set for the correction loop. + + Replaces the positional tuple ``(telemetry_path, science_path, gcp_path)`` + with named fields for clarity and IDE autocomplete. + + Parameters + ---------- + telemetry_file : Path + Path to the telemetry CSV (or NetCDF/HDF5) file. + science_file : Path + Path to the science/timing CSV (or NetCDF/HDF5) file. + gcp_file : Path + Path to the GCP reference image (``.mat`` file). + + Examples + -------- + >>> from curryer.correction import CorrectionInput, run_correction + >>> inputs = [ + ... CorrectionInput( + ... telemetry_file="data/tlm_20240317.csv", + ... science_file="data/sci_20240317.csv", + ... gcp_file="gcps/landsat_chip_001.mat", + ... ) + ... ] + >>> result = run_correction(config, work_dir, inputs) + >>> results = result.results + >>> netcdf_data = result.netcdf_data + """ + + telemetry_file: Path + science_file: Path + gcp_file: Path diff --git a/curryer/correction/correction.py b/curryer/correction/correction.py index 014d5e4c..e5d5482c 100644 --- a/curryer/correction/correction.py +++ b/curryer/correction/correction.py @@ -1,2880 +1,118 @@ -""" -Mission-Agnostic Geolocation Correction Pipeline. - -This module provides generic correction infrastructure for -geolocation component sensitivity analysis. 
It was developed for CLARREO, -but intended to work with any Earth observation mission through configuration. - -Configuration Strategy: ----------------------- -**Do not edit this file to configure for your mission.** - -Instead, follow these steps: - -1. Create a mission-specific config module (e.g., tests/test_correction/your_mission_config.py) -2. Copy from tests/test_correction/clarreo_config.py as a template -3. Define your mission's: - - Kernel file paths - - Parameter definitions (bounds, sigma, units) - - Instrument name and settings - - Performance thresholds - - Earth radius and geodetic parameters -4. Use create_your_mission_correction_config() to build CorrectionConfig object -5. Optionally save config to JSON for reproducibility -6. Pass the configuration object to Correction pipeline functions - -Quick Start: ------------ - from curryer.correction import correction - from your_mission_config import create_mission_correction_config - from your_mission_loaders import load_mission_telemetry, load_mission_science - - # Create configuration - config = create_mission_correction_config(data_dir, generic_dir) - - # Add required loaders - config.telemetry_loader = load_mission_telemetry - config.science_loader = load_mission_science - config.gcp_pairing_func = your_pairing_function - config.image_matching_func = your_image_matching_function +"""Backward-compatibility re-export shim for the correction package. - # Run Correction analysis - results, netcdf_data = correction.loop(config, work_dir, tlm_sci_gcp_sets) +**Do not add new code here.** -For CLARREO Example: -------------------- -See tests/test_correction/clarreo_config.py for a complete reference implementation. +All implementation has been split into focused sub-modules: -For Configuration Details: -------------------------- -See CONFIGURATION_GUIDE.md in the repository root. 
+- :mod:`curryer.correction.config` -- config dataclasses & enums +- :mod:`curryer.correction.parameters` -- parameter set generation +- :mod:`curryer.correction.kernel_ops` -- SPICE kernel creation & offsets +- :mod:`curryer.correction.results_io` -- NetCDF read/write +- :mod:`curryer.correction.pipeline` -- main ``loop()`` orchestration -Mission-Agnostic Design: ------------------------ -This module contains NO mission-specific values, column names, or hardcoded constants. -All mission-specific parameters must be provided through CorrectionConfig. - -Core modules in curryer/correction/ are generic. Mission-specific code belongs in -your test or application directories (e.g., tests/test_correction/clarreo_*). +This module re-exports every public name so that existing code using +``from curryer.correction import correction`` continues to work without +modification. """ -import json -import logging -import time -import typing -from dataclasses import dataclass -from enum import Enum, auto -from pathlib import Path -from typing import Any, NamedTuple - -import numpy as np -import pandas as pd -import xarray as xr - -from curryer import meta -from curryer import spicierpy as sp -from curryer.compute import spatial -from curryer.correction import correction_config -from curryer.correction.data_structures import GeolocationConfig as ImageMatchGeolocationConfig -from curryer.correction.data_structures import SearchConfig - -# Import data loader protocols and validation -from curryer.correction.dataio import ( - GCPLoader, - ScienceLoader, - TelemetryLoader, - validate_science_output, - validate_telemetry_output, +# Config dataclasses, enums, and JSON loader +from curryer.correction.config import ( + STANDARD_NETCDF_ATTRIBUTES, + STANDARD_VAR_NAMES, + CalibrationData, + CorrectionConfig, + DataConfig, + GeolocationConfig, + ImageMatchingContext, + KernelContext, + NetCDFConfig, + NetCDFParameterMetadata, + ParameterConfig, + ParameterType, + load_config_from_json, ) -# 
Import image matching modules -from curryer.correction.image_match import ( - ImageMatchingFunc, - integrated_image_match, - load_image_grid_from_mat, - load_los_vectors_from_mat, - load_optical_psf_from_mat, - validate_image_matching_output, +# Kernel operations +from curryer.correction.kernel_ops import ( + _create_dynamic_kernels, + _create_parameter_kernels, + apply_offset, ) -# Import pairing protocols and validation -from curryer.correction.pairing import ( - GCPPairingFunc, +# Parameter generation +from curryer.correction.parameters import load_param_sets + +# Pipeline orchestration +from curryer.correction.pipeline import ( + _aggregate_image_matching_results, + _compute_parameter_set_metrics, + _extract_error_metrics, + _extract_parameter_values, + _extract_spacecraft_position_midframe, + _geolocate_and_match, + _geolocated_to_image_grid, + _load_calibration_data, + _load_file, + _load_image_pair_data, + _resolve_gcp_pairs, + _store_gcp_pair_results, + _store_parameter_values, + call_error_stats_module, + image_matching, + loop, ) -from curryer.correction.pairing import ( - validate_pairing_output as validate_gcp_pairing_output, -) -from curryer.kernels import create - -logger = logging.getLogger(__name__) - - -# ============================================================================ -# Helper Data Structures -# ============================================================================ - - -class KernelContext(NamedTuple): - """Context for SPICE kernel loading during geolocation. - - Attributes - ---------- - mkrn : meta.MetaKernel - MetaKernel instance with SDS and mission kernels - dynamic_kernels : list[Path] - List of dynamic kernel file paths (SC-SPK, SC-CK) - param_kernels : list[Path] - List of parameter-specific kernel file paths - """ - - mkrn: "meta.MetaKernel" - dynamic_kernels: list[Path] - param_kernels: list[Path] - - -class CalibrationData(NamedTuple): - """Pre-loaded calibration data for image matching. 
- - Attributes - ---------- - los_vectors : Optional[np.ndarray] - Line-of-sight vectors array, or None if not using calibration - optical_psfs : Optional[list] - List of optical PSF entries, or None if not using calibration - """ - - los_vectors: np.ndarray | None - optical_psfs: list | None - - -class ImageMatchingContext(NamedTuple): - """Context data needed for image matching operations. - - Attributes - ---------- - gcp_pairs : list[tuple] - List of GCP pairing tuples from pairing function - params : list[tuple] - List of (ParameterConfig, parameter_value) tuples for this iteration - pair_idx : int - Index of current GCP pair being processed - sci_key : str - Science dataset identifier for this pair - """ - - gcp_pairs: list[tuple] - params: list[tuple] - pair_idx: int - sci_key: str - - -# ============================================================================ -# Standard NetCDF Variable Attributes (Mission-Agnostic) -# ============================================================================ - -# Standard metric attributes for NetCDF output -# These are generic geolocation/error metrics that apply to most missions -# Missions can override these in their NetCDFConfig if needed -STANDARD_NETCDF_ATTRIBUTES = { - # Geolocation error metrics (per GCP pair) - "rms_error_m": {"units": "meters", "long_name": "RMS geolocation error"}, - "mean_error_m": {"units": "meters", "long_name": "Mean geolocation error"}, - "max_error_m": {"units": "meters", "long_name": "Maximum geolocation error"}, - "std_error_m": {"units": "meters", "long_name": "Standard deviation of geolocation error"}, - "n_measurements": {"units": "count", "long_name": "Number of measurement points"}, - # Aggregate performance metrics (per parameter set) - "mean_rms_all_pairs": {"units": "meters", "long_name": "Mean RMS error across all GCP pairs"}, - "worst_pair_rms": {"units": "meters", "long_name": "Worst performing GCP pair RMS error"}, - "best_pair_rms": {"units": "meters", "long_name": 
"Best performing GCP pair RMS error"}, - # Image matching metrics (per GCP pair) - "im_lat_error_km": {"units": "kilometers", "long_name": "Image matching latitude error"}, - "im_lon_error_km": {"units": "kilometers", "long_name": "Image matching longitude error"}, - "im_ccv": {"units": "dimensionless", "long_name": "Image matching correlation coefficient"}, - "im_grid_step_m": {"units": "meters", "long_name": "Image matching final grid step size"}, -} - - -# ============================================================================ -# Standard Data Variable Names (Mission-Agnostic Keys) -# ============================================================================ - -# Standard variable names that should be present in image matching results -# Used for extracting data from xarray.Dataset objects -STANDARD_VAR_NAMES = { - # Error measurements (required) - "lat_error_deg": "lat_error_deg", - "lon_error_deg": "lon_error_deg", - # Spacecraft state (configurable names) - "spacecraft_position": "sc_position", # Generic default - "boresight": "boresight", # Generic default - "transformation_matrix": "t_inst2ref", # Generic default - # Control point location (optional) - "gcp_lat_deg": "gcp_lat_deg", - "gcp_lon_deg": "gcp_lon_deg", - "gcp_alt": "gcp_alt", -} - - -# ============================================================================ -# Internal Adapter Functions (Correction <-> Image Matching) -# ============================================================================ - - -def _geolocated_to_image_grid(geo_dataset: xr.Dataset): - """ - Convert Correction geolocation output to ImageGrid for image matching. - - Internal adapter function: converts xarray.Dataset from geolocation step - to ImageGrid format expected by image_match module. 
- - Args: - geo_dataset: xarray.Dataset with latitude, longitude, altitude/height - - Returns: - ImageGrid suitable for integrated_image_match() - """ - from curryer.correction.data_structures import ImageGrid - - lat = geo_dataset["latitude"].values - lon = geo_dataset["longitude"].values - - # Try different field names for altitude/height - if "altitude" in geo_dataset: - h = geo_dataset["altitude"].values - elif "height" in geo_dataset: - h = geo_dataset["height"].values - else: - h = np.zeros_like(lat) - - # Get actual radiance/reflectance data when available - if "radiance" in geo_dataset: - data = geo_dataset["radiance"].values - elif "reflectance" in geo_dataset: - data = geo_dataset["reflectance"].values - else: - data = np.ones_like(lat) - - return ImageGrid(data=data, lat=lat, lon=lon, h=h) - - -def _extract_spacecraft_position_midframe(telemetry: pd.DataFrame) -> np.ndarray: - """ - Extract spacecraft position at mid-frame from telemetry. - - Internal adapter function: extracts position from telemetry DataFrame - with fallback logic for different column naming conventions. 
- - Args: - telemetry: Telemetry DataFrame with spacecraft position columns - - Returns: - np.ndarray, shape (3,) - [x, y, z] position in meters (J2000 frame) - - Raises: - ValueError: If position columns cannot be found - """ - mid_idx = len(telemetry) // 2 - - # Try common column name patterns - position_patterns = [ - ["sc_pos_x", "sc_pos_y", "sc_pos_z"], - ["position_x", "position_y", "position_z"], - ["r_x", "r_y", "r_z"], - ["pos_x", "pos_y", "pos_z"], - ] - - for cols in position_patterns: - if all(c in telemetry.columns for c in cols): - position = telemetry[cols].iloc[mid_idx].values.astype(np.float64) - logger.debug(f"Extracted spacecraft position from columns {cols}: {position}") - return position - - # If patterns don't match, try to find any column containing 'pos' or 'r_' - pos_cols = [c for c in telemetry.columns if "pos" in c.lower() or c.startswith("r_")] - if len(pos_cols) >= 3: - logger.warning(f"Using first 3 position-like columns: {pos_cols[:3]}") - return telemetry[pos_cols[:3]].iloc[mid_idx].values.astype(np.float64) - - raise ValueError(f"Cannot find position columns in telemetry. Available columns: {telemetry.columns.tolist()}") - - -# Configuration Loading Functions - - -def load_config_from_json(config_path: Path) -> "CorrectionConfig": - """Load correction configuration from a JSON file. 
- - Args: - config_path: Path to the JSON configuration file (e.g., gcs_config.json) - - Returns: - CorrectionConfig object populated from the JSON file - - Raises: - FileNotFoundError: If config file doesn't exist - ValueError: If config file format is invalid - KeyError: If required config sections are missing - """ - config_path = Path(config_path) - - if not config_path.exists(): - raise FileNotFoundError(f"Configuration file not found: {config_path}") - - logger.info(f"Loading Correction configuration from: {config_path}") - - try: - with open(config_path) as f: - config_data = json.load(f) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON in config file {config_path}: {e}") - - # Extract mission configuration and kernel mappings - mission_config = correction_config.extract_mission_config(config_data) - constant_kernel_map = correction_config.get_kernel_mapping(config_data, "constant_kernel") - offset_kernel_map = correction_config.get_kernel_mapping(config_data, "offset_kernel") - - logger.debug(f"Mission: {mission_config.get('mission_name', 'UNKNOWN')}") - logger.debug(f"Constant kernel mappings: {constant_kernel_map}") - logger.debug(f"Offset kernel mappings: {offset_kernel_map}") - - # Validate required sections exist - if "correction" not in config_data: - raise KeyError("Missing required 'correction' section in config file") - if "geolocation" not in config_data: - raise KeyError("Missing required 'geolocation' section in config file") - - # Extract correction section - corr_config = config_data.get("correction", {}) - geo_config = config_data.get("geolocation", {}) - - # Validate correction section - if "parameters" not in corr_config: - raise KeyError("Missing required 'parameters' in correction section") - if not isinstance(corr_config["parameters"], list): - raise ValueError("'parameters' must be a list") - if len(corr_config["parameters"]) == 0: - raise ValueError("No parameters defined in configuration") - - # Parse parameters 
and group related ones together - parameters = [] - param_groups = {} - - # First pass: group parameters by their base name and type - for param_dict in corr_config.get("parameters", []): - param_name = param_dict.get("name", "") - ptype_str = param_dict.get("parameter_type", "CONSTANT_KERNEL") - ptype = ParameterType[ptype_str] - - # Group CONSTANT_KERNEL parameters by their base frame name - if ptype == ParameterType.CONSTANT_KERNEL: - # Extract base name (e.g., "hysics_to_cradle" from "hysics_to_cradle_roll") - if "_roll" in param_name: - base_name = param_name.replace("_roll", "") - angle_type = "roll" - elif "_pitch" in param_name: - base_name = param_name.replace("_pitch", "") - angle_type = "pitch" - elif "_yaw" in param_name: - base_name = param_name.replace("_yaw", "") - angle_type = "yaw" - else: - base_name = param_name - angle_type = "single" - - if base_name not in param_groups: - param_groups[base_name] = {"type": ptype, "angles": {}, "template": param_dict, "config_file": None} - - param_groups[base_name]["angles"][angle_type] = param_dict.get("initial_value", 0.0) - - # Determine config file based on kernel mapping from config - kernel_file = correction_config.find_kernel_file(base_name, constant_kernel_map) - if kernel_file: - param_groups[base_name]["config_file"] = Path(kernel_file) - logger.debug(f"Mapped CONSTANT_KERNEL '{base_name}' → {kernel_file}") - else: - logger.warning(f"No kernel mapping found for CONSTANT_KERNEL parameter: {base_name}") - - else: - # OFFSET_KERNEL and OFFSET_TIME parameters are individual - param_groups[param_name] = {"type": ptype, "param_dict": param_dict, "config_file": None} - - if ptype == ParameterType.OFFSET_KERNEL: - # Determine config file based on kernel mapping from config - kernel_file = correction_config.find_kernel_file(param_name, offset_kernel_map) - if kernel_file: - param_groups[param_name]["config_file"] = Path(kernel_file) - logger.debug(f"Mapped OFFSET_KERNEL '{param_name}' → {kernel_file}") - 
else: - logger.warning(f"No kernel mapping found for OFFSET_KERNEL parameter: {param_name}") - - # Second pass: create ParameterConfig objects from groups - for group_name, group_data in param_groups.items(): - if group_data["type"] == ParameterType.CONSTANT_KERNEL: - # For CONSTANT_KERNEL, combine roll/pitch/yaw into a single parameter - template = group_data["template"] - angles = group_data["angles"] - - # Create center values array [roll, pitch, yaw] with defaults of 0.0 - center_values = [angles.get("roll", 0.0), angles.get("pitch", 0.0), angles.get("yaw", 0.0)] - - param_data = { - "center": center_values, - "arange": template.get("bounds", [-100, 100]), - "sigma": template.get("sigma"), - "units": template.get("units", "arcseconds"), - "distribution": template.get("distribution_type", "normal"), - "field": template.get("application_target", {}).get("field_name", None), - } - - else: - # For OFFSET_KERNEL and OFFSET_TIME, use the parameter as-is - param_dict = group_data["param_dict"] - param_data = { - "center": param_dict.get("initial_value", 0.0), - "arange": param_dict.get("bounds", [-100, 100]), - "sigma": param_dict.get("sigma"), - "units": param_dict.get("units", "radians"), - "distribution": param_dict.get("distribution_type", "normal"), - "field": param_dict.get("application_target", {}).get("field_name", None), - } - - parameters.append( - ParameterConfig(ptype=group_data["type"], config_file=group_data["config_file"], data=param_data) - ) - - logger.info( - f"Loaded {len(parameters)} parameter groups from {len(corr_config.get('parameters', []))} individual parameters" - ) - - # Parse geolocation configuration - # Use instrument_name from geolocation config, falling back to mission_config - # If not specified, raise error - instrument name is required - default_instrument = mission_config.get("instrument_name") - instrument_name = geo_config.get("instrument_name", default_instrument) - if instrument_name is None: - raise ValueError("instrument_name 
must be specified in config (either in geolocation or mission section)") - - # Time field is required - no default to avoid mission-specific assumptions - time_field = geo_config.get("time_field") - if time_field is None: - raise ValueError("time_field must be specified in geolocation config") - - geo = GeolocationConfig( - meta_kernel_file=Path(geo_config.get("meta_kernel_file", "")), - generic_kernel_dir=Path(geo_config.get("generic_kernel_dir", "")), - dynamic_kernels=[Path(k) for k in geo_config.get("dynamic_kernels", [])], - instrument_name=instrument_name, - time_field=time_field, - ) - - # Extract required mission-specific parameters from correction section - # These MUST be provided in the config file - no defaults - earth_radius_m = corr_config.get("earth_radius_m") - if earth_radius_m is None: - raise KeyError( - "Missing required 'earth_radius_m' in correction config section. " - "This must be specified for your mission (e.g., 6378140.0 for WGS84)." - ) - - performance_threshold_m = corr_config.get("performance_threshold_m") - if performance_threshold_m is None: - raise KeyError( - "Missing required 'performance_threshold_m' in correction config section. " - "This must be specified for your mission (e.g., 250.0 meters for CLARREO)." - ) - - performance_spec_percent = corr_config.get("performance_spec_percent") - if performance_spec_percent is None: - raise KeyError( - "Missing required 'performance_spec_percent' in correction config section. " - "This must be specified for your mission (e.g., 39.0 percent for CLARREO)." 
- ) - - # Create CorrectionConfig - config = CorrectionConfig( - seed=corr_config.get("seed"), - n_iterations=corr_config.get("n_iterations", 10), - parameters=parameters, - geo=geo, - performance_threshold_m=performance_threshold_m, - performance_spec_percent=performance_spec_percent, - earth_radius_m=earth_radius_m, - ) - - # Validate the loaded configuration - config.validate() - - logger.info( - f"Configuration loaded and validated: {config.n_iterations} iterations, " - f"{len(config.parameters)} parameter groups" - ) - return config - - -# ============================================================================ -# ADAPTER FUNCTIONS -# ============================================================================ - - -def image_matching( - geolocated_data: xr.Dataset, - gcp_reference_file: Path, - telemetry: pd.DataFrame, - calibration_dir: Path, - params_info: list, - config: "CorrectionConfig", - los_vectors_cached: np.ndarray | None = None, - optical_psfs_cached: list | None = None, -) -> xr.Dataset: - """ - Image matching using integrated_image_match() module. - - This function performs actual image correlation between geolocated - pixels and Landsat GCP reference imagery. 
- - Args: - geolocated_data: xarray.Dataset with latitude, longitude from geolocation - gcp_reference_file: Path to GCP reference image (MATLAB .mat file) - telemetry: Telemetry DataFrame with spacecraft state - calibration_dir: Directory containing calibration files (LOS vectors, PSF) - params_info: Current parameter values for error tracking - config: CorrectionConfig with coordinate name mappings - los_vectors_cached: Pre-loaded LOS vectors (optional, for performance) - optical_psfs_cached: Pre-loaded optical PSF entries (optional, for performance) - - Returns: - xarray.Dataset with error measurements in format expected by error_stats: - - lat_error_deg, lon_error_deg: Spatial errors in degrees - - Additional metadata for error statistics processing - - Raises: - FileNotFoundError: If calibration files are missing - ValueError: If geolocation data is invalid - """ - logger.info(f"Image Matching: correlation with {gcp_reference_file.name}") - start_time = time.time() - - # Convert geolocation output to ImageGrid - logger.info(" Converting geolocation data to ImageGrid format...") - subimage = _geolocated_to_image_grid(geolocated_data) - logger.info(f" Subimage shape: {subimage.data.shape}") - - # Load GCP reference image - logger.info(f" Loading GCP reference from {gcp_reference_file}...") - gcp = load_image_grid_from_mat(gcp_reference_file, key="GCP") - # Get GCP center location (center pixel) - gcp_center_lat = float(gcp.lat[gcp.lat.shape[0] // 2, gcp.lat.shape[1] // 2]) - gcp_center_lon = float(gcp.lon[gcp.lon.shape[0] // 2, gcp.lon.shape[1] // 2]) - logger.info(f" GCP shape: {gcp.data.shape}, center: ({gcp_center_lat:.4f}, {gcp_center_lon:.4f})") - - # Use cached calibration data if available, otherwise load - logger.info(" Loading calibration data...") - - if los_vectors_cached is not None and optical_psfs_cached is not None: - # Use cached data (fast path) - los_vectors = los_vectors_cached - optical_psfs = optical_psfs_cached - logger.info(" Using cached 
calibration data") - else: - # Load from files - # Use configurable calibration file names - los_filename = config.get_calibration_file("los_vectors", default="b_HS.mat") - los_file = calibration_dir / los_filename - los_vectors = load_los_vectors_from_mat(los_file) - logger.info(f" LOS vectors: {los_vectors.shape}") - - psf_filename = config.get_calibration_file("optical_psf", default="optical_PSF_675nm_upsampled.mat") - psf_file = calibration_dir / psf_filename - optical_psfs = load_optical_psf_from_mat(psf_file) - logger.info(f" Optical PSF: {len(optical_psfs)} entries") - - # Extract spacecraft position from telemetry - r_iss_midframe = _extract_spacecraft_position_midframe(telemetry) - logger.info(f" Spacecraft position: {r_iss_midframe}") - - # Run real image matching - logger.info(" Running integrated_image_match()...") - geolocation_config = ImageMatchGeolocationConfig() - search_config = SearchConfig() - - result = integrated_image_match( - subimage=subimage, - gcp=gcp, - r_iss_midframe_m=r_iss_midframe, - los_vectors_hs=los_vectors, - optical_psfs=optical_psfs, - geolocation_config=geolocation_config, - search_config=search_config, - ) - - # Convert IntegratedImageMatchResult to xarray.Dataset format - logger.info(" Converting results to error_stats format...") - - # Create single measurement result (image matching produces one correlation per GCP) - - # NOTE: Boresight and transformation matrix for error_stats module - # ---------------------------------------------------------------- - # These values are NOT used by image_matching() itself - the image correlation - # is complete and accurate without them. They are needed by call_error_stats_module() - # for converting off-nadir errors to nadir-equivalent errors. 
- # - # Currently using simplified nadir assumptions which are acceptable for: - # - Near-nadir observations (< ~5 degrees off-nadir) - # - Testing image matching correlation accuracy (doesn't affect matching) - # - # For accurate nadir-equivalent error conversion with off-nadir pointing, these - # should be extracted from SPICE/geolocation data: - # - boresight: Extract from spicierpy.getfov(instrument) and transform via geo_dataset['attitude'] - # - t_matrix: Extract from geo_dataset['attitude'] (transformation from instrument to CTRS) - # - # See: geolocation_error_stats.py _transform_boresight_vectors() for usage - # See: BORESIGHT_TRANSFORM_ANALYSIS.md for detailed analysis and future enhancement plan - - t_matrix = np.eye(3) # Simplified: Identity matrix (no rotation) - boresight = np.array([0.0, 0.0, 1.0]) # Simplified: Nadir pointing assumption - - # Convert errors from km to degrees - lat_error_deg = result.lat_error_km / 111.0 # ~111 km per degree latitude - lon_radius_km = 6378.0 * np.cos(np.deg2rad(gcp_center_lat)) - lon_error_deg = result.lon_error_km / (lon_radius_km * np.pi / 180.0) - - processing_time = time.time() - start_time - - logger.info(f" Image matching complete in {processing_time:.2f}s:") - logger.info(f" Lat error: {result.lat_error_km:.3f} km ({lat_error_deg:.6f}°)") - logger.info(f" Lon error: {result.lon_error_km:.3f} km ({lon_error_deg:.6f}°)") - logger.info(f" Correlation: {result.ccv_final:.4f}") - logger.info(f" Grid step: {result.final_grid_step_m:.1f} m") - - # Get coordinate names from config - sc_pos_name = config.spacecraft_position_name - boresight_name = config.boresight_name - transform_name = config.transformation_matrix_name - - # Create output dataset in error_stats format (use config names) - output = xr.Dataset( - { - "lat_error_deg": (["measurement"], [lat_error_deg]), - "lon_error_deg": (["measurement"], [lon_error_deg]), - sc_pos_name: (["measurement", "xyz"], [r_iss_midframe]), - boresight_name: (["measurement", 
"xyz"], [boresight]), - transform_name: (["measurement", "xyz_from", "xyz_to"], t_matrix[np.newaxis, :, :]), - "gcp_lat_deg": (["measurement"], [gcp_center_lat]), - "gcp_lon_deg": (["measurement"], [gcp_center_lon]), - "gcp_alt": (["measurement"], [0.0]), # GCP at ground level - }, - coords={"measurement": [0], "xyz": ["x", "y", "z"], "xyz_from": ["x", "y", "z"], "xyz_to": ["x", "y", "z"]}, - ) - - # Add detailed metadata (Fix #3 Part B: Add km errors to attrs) - output.attrs.update( - { - "lat_error_km": result.lat_error_km, - "lon_error_km": result.lon_error_km, - "correlation_ccv": result.ccv_final, - "final_grid_step_m": result.final_grid_step_m, - "final_index_row": result.final_index_row, - "final_index_col": result.final_index_col, - "processing_time_s": processing_time, - "gcp_file": str(gcp_reference_file.name), - "gcp_center_lat": gcp_center_lat, - "gcp_center_lon": gcp_center_lon, - } - ) - - return output - - -def call_error_stats_module(image_matching_results, correction_config: "CorrectionConfig"): - """ - Call the error_stats module with image matching output. 
- - Args: - image_matching_results: Either a single image matching result (xarray.Dataset) - or a list of image matching results from multiple GCP pairs - correction_config: CorrectionConfig with all configuration (REQUIRED) - - Returns: - Aggregate error statistics dataset - """ - # Handle both single result and list of results - if not isinstance(image_matching_results, list): - image_matching_results = [image_matching_results] - - try: - from curryer.correction.geolocation_error_stats import ErrorStatsProcessor - from curryer.correction.geolocation_error_stats import GeolocationConfig as ErrorStatsGeolocationConfig - - logger.info(f"Error Statistics: Processing geolocation errors from {len(image_matching_results)} GCP pairs") - - # Create error stats config directly from Correction config (single source of truth) - error_config = ErrorStatsGeolocationConfig.from_correction_config(correction_config) - - processor = ErrorStatsProcessor(config=error_config) - - if len(image_matching_results) == 1: - # Single GCP pair case - error_results = processor.process_geolocation_errors(image_matching_results[0]) - else: - # Multiple GCP pairs - aggregate the data first - aggregated_data = _aggregate_image_matching_results(image_matching_results, correction_config) - error_results = processor.process_geolocation_errors(aggregated_data) - - return error_results - - except ImportError as e: - logger.warning(f"Error stats module not available: {e}") - logger.info(f"Error Statistics: Using placeholder calculations for {len(image_matching_results)} GCP pairs") - - # Fallback: compute basic statistics across all GCP pairs - all_lat_errors = [] - all_lon_errors = [] - total_measurements = 0 - - for result in image_matching_results: - lat_errors = result["lat_error_deg"].values - lon_errors = result["lon_error_deg"].values - all_lat_errors.extend(lat_errors) - all_lon_errors.extend(lon_errors) - total_measurements += len(lat_errors) - - all_lat_errors = np.array(all_lat_errors) - 
        all_lon_errors = np.array(all_lon_errors)

        # Convert to meters (approximate)
        lat_error_m = all_lat_errors * 111000
        lon_error_m = all_lon_errors * 111000
        total_error_m = np.sqrt(lat_error_m**2 + lon_error_m**2)

        mean_error = float(np.mean(total_error_m))
        rms_error = float(np.sqrt(np.mean(total_error_m**2)))
        std_error = float(np.std(total_error_m))

        return xr.Dataset(
            {
                "mean_error": mean_error,
                "rms_error": rms_error,
                "std_error": std_error,
                "max_error": float(np.max(total_error_m)),
                "min_error": float(np.min(total_error_m)),
            }
        )


def _aggregate_image_matching_results(image_matching_results, config: "CorrectionConfig") -> xr.Dataset:
    """
    Aggregate multiple image matching results into a single dataset for error stats processing.

    Args:
        image_matching_results: List of xarray.Dataset objects from image matching
        config: CorrectionConfig with coordinate name mappings

    Returns:
        Single aggregated xarray.Dataset with all measurements combined
    """
    logger.info(f"Aggregating {len(image_matching_results)} image matching results")

    # Get coordinate names from config
    sc_pos_name = config.spacecraft_position_name
    boresight_name = config.boresight_name
    transform_name = config.transformation_matrix_name

    # Combine all measurements into single arrays
    all_lat_errors = []
    all_lon_errors = []
    all_sc_positions = []
    all_boresights = []
    all_transforms = []
    all_gcp_lats = []
    all_gcp_lons = []
    all_gcp_alts = []

    for i, result in enumerate(image_matching_results):
        # Add GCP pair identifier to track source
        n_measurements = len(result["lat_error_deg"])

        all_lat_errors.extend(result["lat_error_deg"].values)
        all_lon_errors.extend(result["lon_error_deg"].values)

        # Handle coordinate transformation data (use config names)
        # NOTE: Individual results have shape (1, 3) for vectors and (1, 3, 3) for matrices
        if sc_pos_name in result:
            # Shape: (1, 3) -> extract as (3,) for each measurement
            for j in range(n_measurements):
                all_sc_positions.append(result[sc_pos_name].values[j])
        if boresight_name in result:
            # Shape: (1, 3) -> extract as (3,) for each measurement
            for j in range(n_measurements):
                all_boresights.append(result[boresight_name].values[j])
        if transform_name in result:
            # Shape: (1, 3, 3) -> extract as (3, 3) for each measurement
            for j in range(n_measurements):
                all_transforms.append(result[transform_name].values[j, :, :])
        if "gcp_lat_deg" in result:
            all_gcp_lats.extend(result["gcp_lat_deg"].values)
        if "gcp_lon_deg" in result:
            all_gcp_lons.extend(result["gcp_lon_deg"].values)
        if "gcp_alt" in result:
            all_gcp_alts.extend(result["gcp_alt"].values)

    n_total = len(all_lat_errors)

    # Create aggregated dataset with correct dimension names for error_stats
    aggregated = xr.Dataset(
        {
            "lat_error_deg": (["measurement"], np.array(all_lat_errors)),
            "lon_error_deg": (["measurement"], np.array(all_lon_errors)),
        },
        coords={"measurement": np.arange(n_total)},
    )

    # Add optional coordinate transformation data if available (use config names)
    # Use dimension names that match error_stats expectations
    if all_sc_positions:
        # Stack into (n_measurements, 3)
        aggregated[sc_pos_name] = (["measurement", "xyz"], np.array(all_sc_positions))
        aggregated = aggregated.assign_coords({"xyz": ["x", "y", "z"]})

    if all_boresights:
        # Stack into (n_measurements, 3)
        aggregated[boresight_name] = (["measurement", "xyz"], np.array(all_boresights))

    if all_transforms:
        # Stack into (n_measurements, 3, 3) to match error_stats format
        t_stacked = np.stack(all_transforms, axis=0)
        aggregated[transform_name] = (["measurement", "xyz_from", "xyz_to"], t_stacked)
        aggregated = aggregated.assign_coords({"xyz_from": ["x", "y", "z"], "xyz_to": ["x", "y", "z"]})

    if all_gcp_lats:
        aggregated["gcp_lat_deg"] = (["measurement"], np.array(all_gcp_lats))
    if all_gcp_lons:
        aggregated["gcp_lon_deg"] = (["measurement"], np.array(all_gcp_lons))
    if all_gcp_alts:
        aggregated["gcp_alt"] = (["measurement"], np.array(all_gcp_alts))

    aggregated.attrs["source_gcp_pairs"] = len(image_matching_results)
    aggregated.attrs["total_measurements"] = n_total

    logger.info(f" Aggregated dataset: {n_total} measurements from {len(image_matching_results)} GCP pairs")
    logger.info(f" Dimensions: {dict(aggregated.sizes)}")

    return aggregated


# Original Functions


class ParameterType(Enum):
    """Kinds of parameter perturbations supported by the correction loop."""

    CONSTANT_KERNEL = auto()  # Set a specific value.
    OFFSET_KERNEL = auto()  # Modify input kernel data by an offset.
    OFFSET_TIME = auto()  # Modify input timetags by an offset


@dataclass
class ParameterConfig:
    """A single parameter to vary during the correction loop."""

    ptype: ParameterType  # How the parameter is applied (see ParameterType).
    config_file: Path | None  # Kernel/config file the parameter applies to, if any.
    data: typing.Any  # Distribution settings (current_value, sigma, bounds, units, ...).


@dataclass
class GeolocationConfig:
    """SPICE kernel locations and instrument settings used for geolocation."""

    meta_kernel_file: Path
    generic_kernel_dir: Path
    dynamic_kernels: list[Path]  # Kernels that are dynamic but *not* altered by param!
    instrument_name: str
    time_field: str
    minimum_correlation: float | None = None  # Filter threshold for image matching quality (0.0-1.0)


@dataclass
class NetCDFParameterMetadata:
    """Metadata for a single parameter in NetCDF output."""

    variable_name: str  # NetCDF variable name (e.g., 'param_hysics_roll')
    units: str  # Units (e.g., 'arcseconds', 'milliseconds')
    long_name: str  # Human-readable description


@dataclass
class NetCDFConfig:
    """Configuration for NetCDF output structure and metadata.

    This class defines the structure and metadata for NetCDF output files.
    All mission-specific information should be provided here rather than
    hardcoded in the correction module.

    The performance_threshold_m is required and should match the value in
    CorrectionConfig. It's used to generate threshold-specific variable names
    in the NetCDF output (e.g., "percent_under_250m").
    """

    performance_threshold_m: float  # Required: accuracy threshold in meters
    title: str = "Correction Geolocation Analysis Results"
    description: str = "Parameter sensitivity analysis"

    # Parameter metadata - maps parameter config to NetCDF metadata
    # If None, will be auto-generated from config.parameters
    parameter_metadata: dict[str, NetCDFParameterMetadata] | None = None

    # Standard variable attributes - allows mission-specific overrides
    # If None, uses STANDARD_NETCDF_ATTRIBUTES module constant
    standard_attributes: dict[str, dict[str, str]] | None = None

    def get_threshold_metric_name(self) -> str:
        """Generate metric name dynamically from threshold."""
        threshold_m = int(self.performance_threshold_m)
        return f"percent_under_{threshold_m}m"

    def get_standard_attributes(self) -> dict[str, dict[str, str]]:
        """
        Get standard variable attributes, using mission overrides if provided.

        Returns:
            Dictionary mapping variable names to their attributes (units, long_name)
        """
        if self.standard_attributes is not None:
            # Use mission-specific overrides
            return self.standard_attributes
        else:
            # Use module-level defaults
            return STANDARD_NETCDF_ATTRIBUTES.copy()

    def get_parameter_netcdf_metadata(
        self, param_config: ParameterConfig, angle_type: str | None = None
    ) -> NetCDFParameterMetadata:
        """
        Get NetCDF metadata for a parameter.

        Args:
            param_config: Parameter configuration
            angle_type: For CONSTANT_KERNEL parameters: 'roll', 'pitch', or 'yaw'

        Returns:
            NetCDFParameterMetadata with variable name, units, and description
        """
        # Generate key for lookup
        if param_config.config_file:
            param_stem = param_config.config_file.stem
            if angle_type:
                lookup_key = f"{param_stem}_{angle_type}"
            else:
                lookup_key = param_stem
        else:
            lookup_key = f"param_{param_config.ptype.name.lower()}"

        # Try to find in provided metadata
        if self.parameter_metadata and lookup_key in self.parameter_metadata:
            return self.parameter_metadata[lookup_key]

        # Auto-generate if not provided
        return self._auto_generate_metadata(param_config, angle_type, lookup_key)

    def _auto_generate_metadata(
        self, param_config: ParameterConfig, angle_type: str | None, base_key: str
    ) -> NetCDFParameterMetadata:
        """Auto-generate NetCDF metadata from parameter configuration."""

        # Determine units based on parameter type
        if param_config.ptype == ParameterType.CONSTANT_KERNEL:
            units = "arcseconds"
        elif param_config.ptype == ParameterType.OFFSET_KERNEL:
            units = "arcseconds"  # Typical for angle offsets
        elif param_config.ptype == ParameterType.OFFSET_TIME:
            units = "milliseconds"
        else:
            units = "unknown"

        # Check if units specified in parameter data
        if isinstance(param_config.data, dict) and "units" in param_config.data:
            units = param_config.data["units"]

        # Generate variable name (ensure it starts with 'param_')
        var_name = base_key.replace(".", "_").replace("-", "_")
        if not var_name.startswith("param_"):
            var_name = f"param_{var_name}"

        # Generate human-readable description
        if param_config.config_file:
            # Extract frame names from config file path
            file_stem = param_config.config_file.stem
            # Remove version numbers and file extensions
            clean_name = file_stem.replace("_v01", "").replace("_v02", "").replace(".attitude.ck", "")
            clean_name = clean_name.replace("_", " ").title()

            if angle_type:
                long_name = f"{clean_name} {angle_type} correction"
            else:
                long_name = f"{clean_name} correction"
        else:
            long_name = f"{param_config.ptype.name.replace('_', ' ').title()} parameter"

        return NetCDFParameterMetadata(variable_name=var_name, units=units, long_name=long_name)


@dataclass
class CorrectionConfig:
    """The configuration object for geolocation correction analysis.

    This config contains everything needed for a Correction run:
    - What parameters to vary (parameters list)
    - How to vary them (seed, n_iterations)
    - How to load data (telemetry_loader, science_loader)
    - How to process data (gcp_pairing_func, image_matching_func)
    - Geolocation settings (geo: GeolocationConfig)
    - Success criteria (performance_threshold_m, performance_spec_percent)
    - Output configuration (netcdf: NetCDFConfig, output_filename)

    Create one CorrectionConfig object and pass it to correction.loop() to run.

    Parameters
    ----------
    CORE CORRECTION SETTINGS (Required - define what the analysis does):
    seed : Optional[int]
        Random seed for reproducibility, or None for non-reproducible runs.
    n_iterations : int
        Number of parameter set iterations (e.g., 5, 100, 1000).
    parameters : list[ParameterConfig]
        List of parameters to vary (defines sensitivity analysis).

    GEOLOCATION & PERFORMANCE REQUIREMENTS (Required - mission-specific settings):
    geo : GeolocationConfig
        SPICE kernels and instrument configuration.
    performance_threshold_m : float
        Nadir-equivalent accuracy threshold in meters (e.g., 250.0 for CLARREO).
    performance_spec_percent : float
        Required percentage of observations meeting threshold (e.g., 39.0 for CLARREO).
    earth_radius_m : float
        Earth radius for geodetic calculations (e.g., 6378137.0 for WGS84).
    DATA LOADERS (Required for pipeline execution - mission-specific implementations):
    telemetry_loader : Optional[TelemetryLoader], default=None
        Load spacecraft telemetry. Must be set before calling correction.loop().
    science_loader : Optional[ScienceLoader], default=None
        Load science frame timing. Must be set before calling correction.loop().
    gcp_loader : Optional[GCPLoader], default=None
        Load GCP reference data (optional).

    PROCESSING FUNCTIONS (Optional - will use defaults/stubs if not provided):
    gcp_pairing_func : Optional[GCPPairingFunc], default=None
        Spatial pairing of science data to GCP.
    image_matching_func : Optional[ImageMatchingFunc], default=None
        Image correlation for errors.

    OUTPUT CONFIGURATION (Optional - sensible defaults provided):
    netcdf : Optional[NetCDFConfig], default=None
        NetCDF metadata (auto-generated if None).
    output_filename : Optional[str], default=None
        Output filename (auto-generates with timestamp if None).

    CALIBRATION CONFIGURATION (Optional - only needed when image_matching_func uses calibration):
    calibration_dir : Optional[Path], default=None
        Directory with LOS vectors, optical PSF, GCP files.
        Set when using image_matching_func that requires calibration files.
    calibration_file_names : Optional[dict[str, str]], default=None
        Mission-specific calibration filenames.
        Example: {'los_vectors': 'b_HS.mat', 'optical_psf': 'optical_PSF_675nm_upsampled.mat'}

    MISSION-SPECIFIC NAMING (Optional - override generic defaults):
    spacecraft_position_name : str, default="sc_position"
        Variable name for spacecraft position in output NetCDF.
    boresight_name : str, default="boresight"
        Variable name for boresight in output NetCDF.
    transformation_matrix_name : str, default="t_inst2ref"
        Variable name for transformation matrix in output NetCDF.
    """

    # CORE CORRECTION SETTINGS
    seed: int | None
    n_iterations: int
    parameters: list[ParameterConfig]

    # GEOLOCATION & PERFORMANCE REQUIREMENTS
    geo: GeolocationConfig
    performance_threshold_m: float
    performance_spec_percent: float
    earth_radius_m: float

    # DATA LOADERS
    telemetry_loader: TelemetryLoader | None = None
    science_loader: ScienceLoader | None = None
    gcp_loader: GCPLoader | None = None

    # PROCESSING FUNCTIONS
    gcp_pairing_func: GCPPairingFunc | None = None
    image_matching_func: ImageMatchingFunc | None = None

    # OUTPUT CONFIGURATION
    netcdf: NetCDFConfig | None = None
    output_filename: str | None = None

    # CALIBRATION CONFIGURATION
    calibration_dir: Path | None = None
    calibration_file_names: dict[str, str] | None = None

    # MISSION-SPECIFIC NAMING
    spacecraft_position_name: str = "sc_position"
    boresight_name: str = "boresight"
    transformation_matrix_name: str = "t_inst2ref"

    def get_calibration_file(self, file_type: str, default: str | None = None) -> str:
        """Get calibration filename for given type with fallback to default."""
        if self.calibration_file_names and file_type in self.calibration_file_names:
            return self.calibration_file_names[file_type]
        if default:
            return default
        raise ValueError(f"No calibration file configured for type: {file_type}")

    def validate(self, check_loaders: bool = False):
        """Validate that all required configuration values are present.

        Args:
            check_loaders: If True, validate that loaders are present.
                Set to False when validating configs during creation,
                before loaders have been added.

        Raises:
            ValueError: If any required fields are missing or invalid
        """
        errors = []

        # Check required fields
        if self.n_iterations is None or self.n_iterations <= 0:
            errors.append("n_iterations must be a positive integer")

        if self.parameters is None or len(self.parameters) == 0:
            errors.append("parameters list cannot be empty")

        if self.geo is None:
            errors.append("geo (GeolocationConfig) is required")

        if self.earth_radius_m is None or self.earth_radius_m <= 0:
            errors.append("earth_radius_m must be a positive number (e.g., 6378140.0 for WGS84)")

        if self.performance_threshold_m is None or self.performance_threshold_m <= 0:
            errors.append("performance_threshold_m must be a positive number (e.g., 250.0 meters)")

        if self.performance_spec_percent is None or not (0 <= self.performance_spec_percent <= 100):
            errors.append("performance_spec_percent must be between 0 and 100 (e.g., 39.0)")

        # Check required data loaders (Config-Centric Design) - only if requested
        if check_loaders:
            if self.telemetry_loader is None:
                errors.append(
                    "telemetry_loader is required.\n"
                    " Add to config: config.telemetry_loader = load_your_telemetry\n"
                    " Example: from your_loaders import load_mission_telemetry\n"
                    " config.telemetry_loader = load_mission_telemetry"
                )

            if self.science_loader is None:
                errors.append(
                    "science_loader is required.\n"
                    " Add to config: config.science_loader = load_your_science\n"
                    " Example: from your_loaders import load_mission_science\n"
                    " config.science_loader = load_mission_science"
                )

        if errors:
            error_msg = "CorrectionConfig validation failed:\n - " + "\n - ".join(errors)
            error_msg += "\n\nThese values must be provided in your mission configuration."
            error_msg += "\nSee tests/test_correction/clarreo_config.py for an example."
            raise ValueError(error_msg)

        # Check optional processing functions (warnings only) - only if checking loaders
        if check_loaders:
            if self.gcp_pairing_func is None:
                logger.warning(
                    "gcp_pairing_func not provided - GCP pairing will return empty results.\n"
                    " For testing: config.gcp_pairing_func = synthetic_gcp_pairing\n"
                    " For production: config.gcp_pairing_func = real_spatial_pairing"
                )

            if self.image_matching_func is None:
                logger.warning(
                    "image_matching_func not provided - will use empty stub.\n"
                    " For testing: config.image_matching_func = synthetic_image_matching\n"
                    " For production: config.image_matching_func = real_image_matching\n"
                    " and set config.calibration_dir if needed"
                )

        logger.debug("CorrectionConfig validation passed")

    def ensure_netcdf_config(self):
        """Ensure NetCDFConfig exists, creating with defaults if needed."""
        if self.netcdf is None:
            self.netcdf = NetCDFConfig(performance_threshold_m=self.performance_threshold_m)

    def get_output_filename(self, default: str = "correction_results.nc") -> str:
        """
        Get output filename with optional auto-generation.

        Args:
            default: Default filename if output_filename is None

        Returns:
            Filename string (can include timestamp/parameters if configured)
        """
        if self.output_filename:
            return self.output_filename
        return default

    @staticmethod
    def generate_timestamped_filename(prefix: str = "correction", suffix: str = "") -> str:
        """
        Generate a timestamped output filename for production use.

        This prevents overwriting previous results and provides unique identifiers.
        Args:
            prefix: Filename prefix (e.g., 'correction', 'clarreo_gcs')
            suffix: Optional suffix before extension (e.g., 'upstream', 'test')

        Returns:
            Filename with format: {prefix}_YYYYMMDD_HHMMSS[_{suffix}].nc

        Examples:
            >>> CorrectionConfig.generate_timestamped_filename()
            'correction_20251029_143022.nc'

            >>> CorrectionConfig.generate_timestamped_filename('clarreo_gcs', 'production')
            'clarreo_gcs_20251029_143022_production.nc'
        """
        import datetime

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        if suffix:
            return f"{prefix}_{timestamp}_{suffix}.nc"
        return f"{prefix}_{timestamp}.nc"


def load_param_sets(config: CorrectionConfig) -> list[list[tuple[ParameterConfig, typing.Any]]]:
    """
    Generate random parameter sets for Correction iterations.
    Each parameter is sampled according to its distribution and bounds.

    The parameter generation works as follows:
    - current_value: The baseline/current parameter value
    - bounds: The limits for random offsets (in same units as current_value and sigma)
    - sigma: Standard deviation for normal distribution of offsets
    - Generated offsets are centered around 0, then applied to current_value
    - Final value = current_value + random_offset

    Handles all parameter types:
    - CONSTANT_KERNEL: 3D attitude corrections (roll, pitch, yaw)
    - OFFSET_KERNEL: Single angle biases for telemetry fields
    - OFFSET_TIME: Timing corrections for science frames
    """

    if config.seed is not None:
        np.random.seed(config.seed)
        logger.info(f"Set random seed to {config.seed} for reproducible parameter generation")

    output = []

    logger.info(f"Generating {config.n_iterations} parameter sets for {len(config.parameters)} parameters:")
    for i, param in enumerate(config.parameters):
        param_name = param.config_file.name if param.config_file else f"param_{i}"
        current_value = param.data.get("current_value", param.data.get("center", 0.0))
        bounds = param.data.get("bounds", param.data.get("arange", [-1.0, 1.0]))
        logger.info(
            f" {i + 1}. {param_name} ({param.ptype.name}): "
            f"current_value={current_value}, sigma={param.data.get('sigma', 'N/A')}, "
            f"bounds={bounds}, units={param.data.get('units', 'N/A')}"
        )

    for ith in range(config.n_iterations):
        out_set = []
        logger.debug(f"Generating parameter set {ith + 1}/{config.n_iterations}")

        for param_idx, param in enumerate(config.parameters):
            # Get parameter configuration with backward compatibility
            current_value = param.data.get("current_value", param.data.get("center", 0.0))
            bounds = param.data.get("bounds", param.data.get("arange", [-1.0, 1.0]))

            # Handle different parameter structure types
            if param.ptype == ParameterType.CONSTANT_KERNEL:
                # CONSTANT_KERNEL parameters are 3D attitude corrections (roll, pitch, yaw)
                if isinstance(current_value, list) and len(current_value) == 3:
                    # Multi-dimensional parameter (roll, pitch, yaw)
                    param_vals = []
                    for i, current_val in enumerate(current_value):
                        # Check if parameter should be varied
                        if "sigma" in param.data and param.data["sigma"] is not None and param.data["sigma"] > 0:
                            # Apply variation: Generate offset around 0, then apply to current_value
                            if param.data.get("units") == "arcseconds":
                                # Convert arcsec to radians for sampling
                                sigma_rad = np.deg2rad(param.data["sigma"] / 3600.0)
                                current_val_rad = np.deg2rad(current_val / 3600.0) if current_val != 0 else current_val
                                # Convert bounds from arcsec to radians (these are offset bounds around 0)
                                bounds_rad = [np.deg2rad(bounds[0] / 3600.0), np.deg2rad(bounds[1] / 3600.0)]
                            else:
                                # Assume all values are already in radians
                                sigma_rad = param.data["sigma"]
                                current_val_rad = current_val
                                bounds_rad = bounds

                            # Generate offset around 0, clamp to bounds, and add to current value
                            offset = np.random.normal(0, sigma_rad)
                            offset = np.clip(offset, bounds_rad[0], bounds_rad[1])
                            param_vals.append(current_val_rad + offset)
                        else:
                            # No variation: use current_value directly
                            if "sigma" not in param.data or param.data["sigma"] is None:
                                logger.debug(
                                    f" Parameter {param_idx} axis {i}: No sigma specified, using fixed current_value"
                                )
                            elif param.data["sigma"] == 0:
                                logger.debug(f" Parameter {param_idx} axis {i}: sigma=0, using fixed current_value")

                            # Convert to appropriate units if needed
                            if param.data.get("units") == "arcseconds":
                                current_val_rad = np.deg2rad(current_val / 3600.0) if current_val != 0 else current_val
                            else:
                                current_val_rad = current_val
                            param_vals.append(current_val_rad)
                else:
                    # Single angle or default to zero for each axis
                    param_vals = [0.0, 0.0, 0.0]  # [roll, pitch, yaw]
                    if "sigma" in param.data and param.data["sigma"] is not None and param.data["sigma"] > 0:
                        # Apply variation
                        if param.data.get("units") == "arcseconds":
                            sigma_rad = np.deg2rad(param.data["sigma"] / 3600.0)
                            bounds_rad = [np.deg2rad(bounds[0] / 3600.0), np.deg2rad(bounds[1] / 3600.0)]
                            current_val_rad = (
                                np.deg2rad(current_value / 3600.0) if current_value != 0 else current_value
                            )
                        else:
                            sigma_rad = param.data["sigma"]
                            bounds_rad = bounds
                            current_val_rad = current_value

                        for i in range(3):
                            # Generate offset around 0, clamp to bounds, add to current value
                            offset = np.random.normal(0, sigma_rad)
                            offset = np.clip(offset, bounds_rad[0], bounds_rad[1])
                            param_vals[i] = current_val_rad + offset
                    else:
                        # No variation: use current_value directly for all axes
                        if "sigma" not in param.data or param.data["sigma"] is None:
                            logger.debug(f" Parameter {param_idx}: No sigma specified, using fixed current_value")
                        elif param.data["sigma"] == 0:
                            logger.debug(f" Parameter {param_idx}: sigma=0, using fixed current_value")

                        # Convert to appropriate units if needed
                        if param.data.get("units") == "arcseconds":
                            current_val_rad = (
                                np.deg2rad(current_value / 3600.0) if current_value != 0 else current_value
                            )
                        else:
                            current_val_rad = current_value

                        # Use same value for all three axes (this handles scalar current_value)
                        param_vals = [current_val_rad, current_val_rad, current_val_rad]

                # Convert to DataFrame format expected by kernel creation
                param_vals = pd.DataFrame(
                    {
                        "ugps": [0, 2209075218000000],  # Start and end times
                        "angle_x": [param_vals[0], param_vals[0]],  # Roll (constant over time)
                        "angle_y": [param_vals[1], param_vals[1]],  # Pitch (constant over time)
                        "angle_z": [param_vals[2], param_vals[2]],  # Yaw (constant over time)
                    }
                )

                logger.debug(
                    f" CONSTANT_KERNEL {param_idx}: angles=[{param_vals['angle_x'].iloc[0]:.6e}, "
                    f"{param_vals['angle_y'].iloc[0]:.6e}, {param_vals['angle_z'].iloc[0]:.6e}] rad"
                )

            elif param.ptype == ParameterType.OFFSET_KERNEL:
                # OFFSET_KERNEL parameters are angle biases (single values)
                if "sigma" in param.data and param.data["sigma"] is not None and param.data["sigma"] > 0:
                    # Apply variation: Generate offset around 0, then apply to current_value
                    if param.data.get("units") == "arcseconds":
                        # Convert arcsec to radians for sampling
                        sigma_rad = np.deg2rad(param.data["sigma"] / 3600.0)
                        current_val_rad = np.deg2rad(current_value / 3600.0) if current_value != 0 else current_value
                        # Convert bounds from arcsec to radians (these are offset bounds around 0)
                        bounds_rad = [np.deg2rad(bounds[0] / 3600.0), np.deg2rad(bounds[1] / 3600.0)]
                    else:
                        # Assume all values are already in radians
                        sigma_rad = param.data["sigma"]
                        current_val_rad = current_value
                        bounds_rad = bounds
                    # Generate offset around 0, clamp to bounds, and add to current value
                    offset = np.random.normal(0, sigma_rad)
                    offset = np.clip(offset, bounds_rad[0], bounds_rad[1])
                    param_vals = current_val_rad + offset
                else:
                    # No variation: use current_value directly
                    if "sigma" not in param.data or param.data["sigma"] is None:
                        logger.debug(f" Parameter {param_idx}: No sigma specified, using fixed current_value")
                    elif param.data["sigma"] == 0:
                        logger.debug(f" Parameter {param_idx}: sigma=0, using fixed current_value")

                    # Convert to appropriate units if needed
                    if param.data.get("units") == "arcseconds":
                        current_val_rad = np.deg2rad(current_value / 3600.0) if current_value != 0 else current_value
                    else:
                        current_val_rad = current_value
                    param_vals = current_val_rad

                logger.debug(f" OFFSET_KERNEL {param_idx}: {param_vals:.6e} rad")

            elif param.ptype == ParameterType.OFFSET_TIME:
                # OFFSET_TIME parameters are timing corrections (single values)
                if "sigma" in param.data and param.data["sigma"] is not None and param.data["sigma"] > 0:
                    # Apply variation: Generate offset around 0, then apply to current_value
                    if param.data.get("units") == "seconds":
                        # Time parameters typically use seconds, no conversion needed
                        sigma_time = param.data["sigma"]
                        current_val_time = current_value
                        bounds_time = bounds
                    elif param.data.get("units") == "milliseconds":
                        # Convert milliseconds to seconds
                        sigma_time = param.data["sigma"] / 1000.0
                        current_val_time = current_value / 1000.0
                        bounds_time = [bounds[0] / 1000.0, bounds[1] / 1000.0]
                    elif param.data.get("units") == "microseconds":
                        # Convert microseconds to seconds
                        sigma_time = param.data["sigma"] / 1000000.0
                        current_val_time = current_value / 1000000.0
                        bounds_time = [bounds[0] / 1000000.0, bounds[1] / 1000000.0]
                    else:
                        # Default to seconds if units not specified
                        sigma_time = param.data["sigma"]
                        current_val_time = current_value
                        bounds_time = bounds

                    # Generate offset around 0, clamp to bounds, then add to current value
                    offset = np.random.normal(0, sigma_time)
                    offset = np.clip(offset, bounds_time[0], bounds_time[1])
                    param_vals = current_val_time + offset
                else:
                    # No variation: use current_value directly
                    if "sigma" not in param.data or param.data["sigma"] is None:
                        logger.debug(f" Parameter {param_idx}: No sigma specified, using fixed current_value")
                    elif param.data["sigma"] == 0:
                        logger.debug(f" Parameter {param_idx}: sigma=0, using fixed current_value")

                    # Convert to appropriate units if needed
                    if param.data.get("units") == "milliseconds":
                        current_val_time = current_value / 1000.0
                    elif param.data.get("units") == "microseconds":
                        current_val_time = current_value / 1000000.0
                    else:
                        current_val_time = current_value
                    param_vals = current_val_time

                logger.debug(f" OFFSET_TIME {param_idx}: {param_vals:.6e} seconds")

            out_set.append((param, param_vals))
        output.append(out_set)

    # Log summary of generated parameter sets & a table of parameter values for verification
    if output:
        logger.info(f"Generated {len(output)} parameter sets with {len(output[0])} parameters each")
        logger.info("\nParameter Set Summary:")
        logger.info("-" * 100)
        for param_set_idx, param_set in enumerate(output):
            logger.info(f" Set {param_set_idx}:")
            for param_idx, (param, param_vals) in enumerate(param_set):
                field_name = param.data.get("field", "unknown")
                ptype_name = param.ptype.name

                if param.ptype == ParameterType.CONSTANT_KERNEL:
                    if isinstance(param_vals, pd.DataFrame) and "angle_x" in param_vals.columns:
                        angles = [
                            param_vals["angle_x"].iloc[0],
                            param_vals["angle_y"].iloc[0],
                            param_vals["angle_z"].iloc[0],
                        ]
                        logger.info(
                            f" {ptype_name:16s} {field_name:25s}: [{angles[0]:+.6e}, {angles[1]:+.6e}, {angles[2]:+.6e}] rad"
                        )
                    else:
                        logger.info(f" {ptype_name:16s} {field_name:25s}: (constant kernel data)")
                elif param.ptype == ParameterType.OFFSET_KERNEL:
                    units = param.data.get("units", "")
                    if units == "arcseconds":
                        # Convert back to arcseconds for display
                        param_arcsec = np.rad2deg(param_vals) * 3600.0
                        logger.info(
                            f" {ptype_name:16s} {field_name:25s}: {param_arcsec:+10.3f} arcsec ({param_vals:+.9f} rad)"
                        )
                    else:
                        logger.info(f" {ptype_name:16s} {field_name:25s}: {param_vals:+.9f} {units}")
                elif param.ptype == ParameterType.OFFSET_TIME:
                    units = param.data.get("units", "")
                    if units == "milliseconds":
                        # param_vals is in seconds, convert to ms for display
                        param_ms = param_vals * \
1000.0 - logger.info( - f" {ptype_name:16s} {field_name:25s}: {param_ms:+10.3f} ms ({param_vals:+.9f} s)" - ) - else: - logger.info(f" {ptype_name:16s} {field_name:25s}: {param_vals:+.9f} {units}") - logger.info("-" * 100) - - return output - - -def load_telemetry(tlm_key: str, config: CorrectionConfig, loader_func=None) -> pd.DataFrame: - """ - Load telemetry data using provided mission-specific loader function. - - This is a generic interface. The actual telemetry loading logic should be - provided by the mission-specific loader function. - - Args: - tlm_key: Identifier for telemetry data (path, key, etc.) - config: Correction configuration - loader_func: Mission-specific loader function(tlm_key, config) -> DataFrame - - Returns: - DataFrame with telemetry data - - Raises: - ValueError: If no loader function provided - - Example: - from clarreo_data_loaders import load_clarreo_telemetry - tlm_data = load_telemetry(tlm_key, config, loader_func=load_clarreo_telemetry) - """ - if loader_func is None: - raise ValueError( - "No telemetry loader function provided. " - "Pass loader_func parameter with mission-specific loader.\n" - "Example: load_telemetry(tlm_key, config, loader_func=load_clarreo_telemetry)" - ) - - return loader_func(tlm_key, config) - - -def load_science(sci_key: str, config: CorrectionConfig, loader_func=None) -> pd.DataFrame: - """ - Load science data using provided mission-specific loader function. - - This is a generic interface. The actual science data loading logic should be - provided by the mission-specific loader function. - - Args: - sci_key: Identifier for science data (path, key, etc.) 
- config: Correction configuration - loader_func: Mission-specific loader function(sci_key, config) -> DataFrame - - Returns: - DataFrame with science data - - Raises: - ValueError: If no loader function provided - - Example: - from clarreo_data_loaders import load_clarreo_science - sci_data = load_science(sci_key, config, loader_func=load_clarreo_science) - """ - if loader_func is None: - raise ValueError( - "No science loader function provided. " - "Pass loader_func parameter with mission-specific loader.\n" - "Example: load_science(sci_key, config, loader_func=load_clarreo_science)" - ) - - return loader_func(sci_key, config) - - -def load_gcp(gcp_key: str, config: CorrectionConfig, loader_func=None): - """ - Load Ground Control Point (GCP) reference data using mission-specific loader. - - This is a generic interface. The actual GCP loading logic should be - provided by the mission-specific loader function. - - Args: - gcp_key: Identifier for GCP data (path, key, etc.) - config: Correction configuration - loader_func: Mission-specific loader function(gcp_key, config) -> GCP data - - Returns: - GCP reference data (format defined by mission) - - Note: - If loader_func is None, returns None (allows placeholder behavior) - - Example: - from clarreo_data_loaders import load_clarreo_gcp - gcp_data = load_gcp(gcp_key, config, loader_func=load_clarreo_gcp) - """ - if loader_func is None: - logger.info(f"No GCP loader provided for: {gcp_key} (returning None)") - return None - - return loader_func(gcp_key, config) - - -def apply_offset(config: ParameterConfig, param_data, input_data): - """ - Apply parameter offsets to input data based on parameter type. 
- - Args: - config: ParameterConfig specifying how to apply the offset - param_data: The parameter values to apply (offset amounts) - input_data: The input dataset to modify - - Returns: - Modified copy of input_data with parameter offsets applied - """ - logger.info(f"Applying {config.ptype.name} offset to {config.data.get('field', 'unknown field')}") - - # Make a copy to avoid modifying the original - if isinstance(input_data, pd.DataFrame): - modified_data = input_data.copy() - else: - modified_data = input_data.copy() if hasattr(input_data, "copy") else input_data - - if config.ptype == ParameterType.OFFSET_KERNEL: - # Apply offset to telemetry fields for dynamic kernels (azimuth/elevation angles) - # OFFSET_KERNEL is ONLY for angle biases, not time offsets - # Valid units: "arcseconds" (converted to radians) or None (radians assumed) - # For time offsets, use OFFSET_TIME instead - field_name = config.data.get("field") - if not field_name: - raise ValueError("OFFSET_KERNEL parameter requires 'field' to be specified in config") - - if field_name in modified_data.columns: - # Convert parameter value to appropriate units - # OFFSET_KERNEL is for angle biases only (azimuth/elevation angles) - offset_value = param_data - original_value = offset_value - if config.data.get("units") == "arcseconds": - # Convert arcseconds to radians for application - offset_value = np.deg2rad(param_data / 3600.0) - logger.info(f"✓ Applying OFFSET_KERNEL to field '{field_name}'") - logger.info(f" Offset: {original_value:.6f} arcsec = {offset_value:.9f} rad") - else: - # No units specified - assume radians (direct application) - logger.info(f"✓ Applying OFFSET_KERNEL to field '{field_name}'") - logger.info(f" Offset: {offset_value:.9f} rad (no unit conversion)") - - # Store original values for logging - original_mean = modified_data[field_name].mean() - - # Apply additive offset - modified_data[field_name] = modified_data[field_name] + offset_value - - # Log the effect - new_mean = 
modified_data[field_name].mean() - logger.info(f" Original mean: {original_mean:.9f}") - logger.info(f" New mean: {new_mean:.9f}") - logger.info(f" Delta: {new_mean - original_mean:.9f}") - else: - available_cols = list(modified_data.columns) if hasattr(modified_data, "columns") else [] - logger.warning(f"Field '{field_name}' not found in telemetry data for offset application") - logger.warning(f"Available columns: {available_cols}") - - elif config.ptype == ParameterType.OFFSET_TIME: - # Apply time offset to science frame timing - # NOTE: param_data is in seconds while target field (e.g., corrected_timestamp) is typically in microseconds - field_name = config.data.get("field", "corrected_timestamp") - if hasattr(modified_data, "__getitem__") and field_name in modified_data: - # param_data is already in seconds (converted by load_param_sets) - # Convert seconds to microseconds for the timestamp field - offset_value_seconds = param_data - offset_value_us = param_data * 1000000.0 # seconds to microseconds - - logger.info(f"✓ Applying OFFSET_TIME to field '{field_name}'") - units = config.data.get("units", "seconds") - if units == "milliseconds": - logger.info(f" Offset: {offset_value_seconds * 1000.0:.6f} ms (configured) = {offset_value_us:.6f} µs") - elif units == "microseconds": - logger.info( - f" Offset: {offset_value_seconds * 1000000.0:.6f} µs (configured) = {offset_value_us:.6f} µs" - ) - else: - logger.info(f" Offset: {offset_value_seconds:.6f} s = {offset_value_us:.6f} µs") - - # Store original values for logging - if hasattr(modified_data[field_name], "mean"): - original_mean = modified_data[field_name].mean() - else: - original_mean = np.mean(modified_data[field_name]) - - # Apply additive offset in microseconds - modified_data[field_name] = modified_data[field_name] + offset_value_us - - # Log the effect - if hasattr(modified_data[field_name], "mean"): - new_mean = modified_data[field_name].mean() - else: - new_mean = np.mean(modified_data[field_name]) - 
logger.info(f" Original mean: {original_mean:.6f}") - logger.info(f" New mean: {new_mean:.6f}") - logger.info(f" Delta: {new_mean - original_mean:.6f}") - else: - logger.warning(f"Field '{field_name}' not found in science data for time offset application") - - elif config.ptype == ParameterType.CONSTANT_KERNEL: - # For constant kernels, param_data should already be in the correct format - # (DataFrame with ugps, angle_x, angle_y, angle_z columns) - logger.info( - f"Using constant kernel data with {len(param_data) if hasattr(param_data, '__len__') else 1} entries" - ) - modified_data = param_data - - else: - raise NotImplementedError(f"Parameter type {config.ptype} not implemented") - - return modified_data - - -def _build_netcdf_structure(config: CorrectionConfig, n_param_sets: int, n_gcp_pairs: int) -> dict: - """ - Build NetCDF data structure dynamically from configuration. - - This creates the netcdf_data dictionary with proper variable names based on - the parameters defined in the configuration, avoiding hardcoded mission-specific names. 
- - Args: - config: CorrectionConfig with parameters and optional NetCDF config - n_param_sets: Number of parameter sets (iterations) - n_gcp_pairs: Number of GCP pairs - - Returns: - Dictionary with initialized arrays for all NetCDF variables - """ - logger.info(f"Building NetCDF data structure for {n_param_sets} parameter sets × {n_gcp_pairs} GCP pairs") - - # Ensure NetCDFConfig exists - config.ensure_netcdf_config() - - # Start with coordinate dimensions - netcdf_data = { - "parameter_set_id": np.arange(n_param_sets), - "gcp_pair_id": np.arange(n_gcp_pairs), - } - - # Add parameter variables dynamically based on config.parameters - param_count = 0 - for param in config.parameters: - if param.ptype == ParameterType.CONSTANT_KERNEL: - # CONSTANT_KERNEL parameters have roll, pitch, yaw components - for angle in ["roll", "pitch", "yaw"]: - metadata = config.netcdf.get_parameter_netcdf_metadata(param, angle) - var_name = metadata.variable_name - netcdf_data[var_name] = np.full(n_param_sets, np.nan) - logger.debug(f" Added parameter variable: {var_name} ({metadata.long_name})") - param_count += 1 - else: - # OFFSET_KERNEL and OFFSET_TIME are single values - metadata = config.netcdf.get_parameter_netcdf_metadata(param) - var_name = metadata.variable_name - netcdf_data[var_name] = np.full(n_param_sets, np.nan) - logger.debug(f" Added parameter variable: {var_name} ({metadata.long_name})") - param_count += 1 - - logger.info(f" Created {param_count} parameter variables from {len(config.parameters)} parameter configs") - - # Add standard error statistics (2D: parameter_set_id × gcp_pair_id) - error_metrics = { - "rms_error_m": "RMS geolocation error", - "mean_error_m": "Mean geolocation error", - "max_error_m": "Maximum geolocation error", - "std_error_m": "Standard deviation of geolocation error", - "n_measurements": "Number of measurement points", - } - - for var_name, description in error_metrics.items(): - if var_name == "n_measurements": - netcdf_data[var_name] = 
np.full((n_param_sets, n_gcp_pairs), 0, dtype=int) - else: - netcdf_data[var_name] = np.full((n_param_sets, n_gcp_pairs), np.nan) - logger.debug(f" Added error metric: {var_name}") - - # Add image matching results (2D: parameter_set_id × gcp_pair_id) - image_match_vars = { - "im_lat_error_km": "Image matching latitude error", - "im_lon_error_km": "Image matching longitude error", - "im_ccv": "Image matching correlation coefficient", - "im_grid_step_m": "Image matching final grid step size", - } - - for var_name, description in image_match_vars.items(): - netcdf_data[var_name] = np.full((n_param_sets, n_gcp_pairs), np.nan) - logger.debug(f" Added image matching variable: {var_name}") - - # Add overall performance metrics (1D: parameter_set_id) - # Use dynamic threshold metric name - threshold_metric = config.netcdf.get_threshold_metric_name() - overall_metrics = { - threshold_metric: f"Percentage of pairs with error < {config.performance_threshold_m}m", - "mean_rms_all_pairs": "Mean RMS error across all GCP pairs", - "worst_pair_rms": "Worst performing GCP pair RMS error", - "best_pair_rms": "Best performing GCP pair RMS error", - } - - for var_name, description in overall_metrics.items(): - netcdf_data[var_name] = np.full(n_param_sets, np.nan) - logger.debug(f" Added overall metric: {var_name}") - - logger.info(f"NetCDF data structure created with {len(netcdf_data)} variables") - - return netcdf_data - - -# ============================================================================= -# HELPER FUNCTIONS -# ============================================================================= -# These functions extract reusable logic from the main loop to simplify the structure - - -def _load_calibration_data(config: "CorrectionConfig") -> CalibrationData: - """Load LOS vectors and optical PSF if calibration_dir is configured. 
- - This function centralizes calibration data loading, which is now called once - per GCP pair in the optimized implementation (previously called once per parameter set). - - Parameters - ---------- - config : CorrectionConfig - Configuration with calibration_dir and calibration settings - - Returns - ------- - CalibrationData - NamedTuple containing (los_vectors, optical_psfs), or (None, None) if - no calibration directory configured - - Raises - ------ - FileNotFoundError - If calibration directory is configured but files don't exist - ValueError - If calibration files exist but fail to load properly - - Examples - -------- - >>> calib_data = _load_calibration_data(config) - >>> if calib_data.los_vectors is not None: - ... # Use calibration data in image matching - ... pass - """ - if not config.calibration_dir: - return CalibrationData(los_vectors=None, optical_psfs=None) - - logger.info("Loading calibration data...") - - # Use configurable calibration file names - los_filename = config.get_calibration_file("los_vectors", default="b_HS.mat") - los_file = config.calibration_dir / los_filename - - if not los_file.exists(): - raise FileNotFoundError( - f"LOS vectors calibration file not found: {los_file}\nExpected in calibration_dir: {config.calibration_dir}" - ) - - los_vectors_cached = load_los_vectors_from_mat(los_file) - - if los_vectors_cached is None: - raise ValueError( - f"Failed to load LOS vectors from {los_file}. File exists but load_los_vectors_from_mat() returned None." 
- ) - - psf_filename = config.get_calibration_file("optical_psf", default="optical_PSF_675nm_upsampled.mat") - psf_file = config.calibration_dir / psf_filename - - if not psf_file.exists(): - raise FileNotFoundError( - f"Optical PSF calibration file not found: {psf_file}\nExpected in calibration_dir: {config.calibration_dir}" - ) - - optical_psfs_cached = load_optical_psf_from_mat(psf_file) - - if optical_psfs_cached is None: - raise ValueError( - f"Failed to load optical PSF from {psf_file}. File exists but load_optical_psf_from_mat() returned None." - ) - - logger.info(f" Cached LOS vectors: {los_vectors_cached.shape}") - logger.info(f" Cached optical PSF: {len(optical_psfs_cached)} entries") - - return CalibrationData(los_vectors=los_vectors_cached, optical_psfs=optical_psfs_cached) - - -def _load_image_pair_data( - tlm_key: str, - sci_key: str, - config: "CorrectionConfig", - telemetry_loader: TelemetryLoader, - science_loader: ScienceLoader, -) -> tuple[pd.DataFrame, pd.DataFrame, Any]: - """Load telemetry and science data for an image pair. - - Parameters - ---------- - tlm_key : str - Identifier for telemetry data. - sci_key : str - Identifier for science data. - config : CorrectionConfig - Configuration containing geolocation settings and field names. - telemetry_loader : callable - Function with signature ``telemetry_loader(tlm_key, config) -> pandas.DataFrame`` that - loads telemetry (L1) data for the given key. - science_loader : callable - Function with signature ``science_loader(sci_key, config) -> pandas.DataFrame`` that - loads science frame timing (L1A) data for the given key. - - Returns - ------- - tlm_dataset : pandas.DataFrame - DataFrame containing spacecraft state / telemetry records (position, velocity, attitude, time). - sci_dataset : pandas.DataFrame - DataFrame containing science frame timing information. - ugps_times : array_like - Time array extracted from the science dataset (e.g., uGPS times). 
- - Raises - ------ - ValueError - If required loader functions are not provided or returned data are invalid. - FileNotFoundError - If underlying loader raises when expected files are missing. - - Notes - ----- - This function loads and validates both telemetry and science datasets for a - single GCP pair. In the current implementation it is called once per image pair. - - Examples - -------- - >>> tlm, sci, times = _load_image_pair_data( - ... "tlm_001", "sci_001", config, load_clarreo_telemetry, load_clarreo_science - ... ) - """ - # Load telemetry (L1) data using mission-specific loader - tlm_dataset = load_telemetry(tlm_key, config, loader_func=telemetry_loader) - validate_telemetry_output(tlm_dataset, config) - - # Load science (L1A) data using mission-specific loader - sci_dataset = load_science(sci_key, config, loader_func=science_loader) - validate_science_output(sci_dataset, config) - ugps_times = sci_dataset[config.geo.time_field] # Can be altered by later steps - - return tlm_dataset, sci_dataset, ugps_times - - -def _create_dynamic_kernels( - config: "CorrectionConfig", - work_dir: Path, - tlm_dataset: pd.DataFrame, - creator: "create.KernelCreator", -) -> list[Path]: - """Create dynamic SPICE kernels from telemetry data. - - Dynamic kernels (SC-SPK, SC-CK) are generated from spacecraft telemetry - and do not change with parameter variations. In the current implementation, - these are created once per image. 
- - Parameters - ---------- - config : CorrectionConfig - Configuration with geo settings and dynamic_kernels list - work_dir : Path - Working directory for kernel files - tlm_dataset : pd.DataFrame - Spacecraft state data with position, velocity, attitude, and time columns - creator : create.KernelCreator - KernelCreator instance for writing kernels - - Returns - ------- - list[Path] - List of kernel file paths created (e.g., [sc_ephemeris.bsp, sc_attitude.bc]) - - Examples - -------- - >>> from curryer.kernels import create - >>> creator = create.KernelCreator(overwrite=True, append=False) - >>> dynamic_kernels = _create_dynamic_kernels(config, work_dir, tlm_dataset, creator) - >>> # Use in SPICE context - >>> with sp.ext.load_kernel(dynamic_kernels): - ... # Perform geolocation - ... pass - """ - logger.info(" Creating dynamic kernels from telemetry...") - dynamic_kernels = [] - for kernel_config in config.geo.dynamic_kernels: - dynamic_kernels.append( - creator.write_from_json( - kernel_config, - output_kernel=work_dir, - input_data=tlm_dataset, - ) - ) - logger.info(f" Created {len(dynamic_kernels)} dynamic kernels") - return dynamic_kernels - - -def _create_parameter_kernels( - params: list[tuple["ParameterConfig", Any]], - work_dir: Path, - tlm_dataset: pd.DataFrame, - sci_dataset: pd.DataFrame, - ugps_times: Any, - config: "CorrectionConfig", - creator: "create.KernelCreator", -) -> tuple[list[Path], Any]: - """Create parameter-specific SPICE kernels and apply time offsets. - - This function applies parameter variations by creating modified kernels - (CONSTANT_KERNEL, OFFSET_KERNEL) or modifying time tags (OFFSET_TIME). - Each parameter set produces different kernels and/or time modifications. 
- - Parameters - ---------- - params : list[tuple[ParameterConfig, Any]] - List of (ParameterConfig, parameter_value) tuples for this iteration - work_dir : Path - Working directory for kernel files - tlm_dataset : pd.DataFrame - Spacecraft state data (may be modified for OFFSET_KERNEL) with position, velocity, attitude, and time columns - sci_dataset : pd.DataFrame - Science frame time data (may be modified for OFFSET_TIME), may include optional measurement columns - ugps_times : array-like - Original time array from science dataset - config : CorrectionConfig - Configuration with geo settings - creator : create.KernelCreator - KernelCreator instance for writing kernels - - Returns - ------- - param_kernels : list[Path] - List of parameter-specific kernel file paths - ugps_times_modified : array-like - Modified time array if OFFSET_TIME applied, otherwise original times - - Examples - -------- - >>> param_kernels, times = _create_parameter_kernels( - ... params, work_dir, tlm_dataset, sci_dataset, ugps_times, config, creator - ... ) - >>> # Use in SPICE context with dynamic kernels - >>> with sp.ext.load_kernel([dynamic_kernels, param_kernels]): - ... 
geo = geolocate(times) - """ - param_kernels = [] - ugps_times_modified = ugps_times.copy() if hasattr(ugps_times, "copy") else ugps_times - - # Apply each individual parameter change - logger.info(" Applying parameter changes:") - for a_param, p_data in params: # [ParameterConfig, typing.Any] - # Log parameter details - param_name = a_param.data.get("field", "unknown") - if a_param.ptype == ParameterType.CONSTANT_KERNEL: - logger.info(f" {a_param.ptype.name}: {param_name} (constant kernel data)") - elif a_param.ptype == ParameterType.OFFSET_KERNEL: - units = a_param.data.get("units", "") - logger.info( - f" {a_param.ptype.name}: {param_name} = {p_data:.6f} " - f"(internal units; configured units: {units or 'unspecified'})" - ) - elif a_param.ptype == ParameterType.OFFSET_TIME: - units = a_param.data.get("units", "") - logger.info( - f" {a_param.ptype.name}: {param_name} = {p_data:.6f} " - f"(internal units; configured units: {units or 'unspecified'})" - ) - - # Create static changing SPICE kernels - if a_param.ptype == ParameterType.CONSTANT_KERNEL: - # Aka: BASE-CK, YOKE-CK, HYSICS-CK - param_kernels.append( - creator.write_from_json( - a_param.config_file, - output_kernel=work_dir, - input_data=p_data, - ) - ) - - # Create dynamic changing SPICE kernels - elif a_param.ptype == ParameterType.OFFSET_KERNEL: - # Aka: AZ-CK, EL-CK - tlm_dataset_alt = apply_offset(a_param, p_data, tlm_dataset) - param_kernels.append( - creator.write_from_json( - a_param.config_file, - output_kernel=work_dir, - input_data=tlm_dataset_alt, - ) - ) - - # Alter non-kernel data - elif a_param.ptype == ParameterType.OFFSET_TIME: - # Aka: Frame-times... 
- sci_dataset_alt = apply_offset(a_param, p_data, sci_dataset) - ugps_times_modified = sci_dataset_alt[config.geo.time_field].values - - else: - raise NotImplementedError(a_param.ptype) - - logger.info(f" Created {len(param_kernels)} parameter-specific kernels") - return param_kernels, ugps_times_modified - - -def _geolocate_and_match( - config: "CorrectionConfig", - kernel_ctx: KernelContext, - ugps_times_modified: Any, - tlm_dataset: pd.DataFrame, - calibration: CalibrationData, - image_matching_func: ImageMatchingFunc, - match_ctx: ImageMatchingContext, -) -> tuple[xr.Dataset, xr.Dataset]: - """Perform geolocation and image matching for a parameter set. - - This function loads SPICE kernels, performs geolocation, and runs image - matching against GCP reference data. It's the core computation step that - combines all previous setup (kernels, data loading) into results. - - Parameters - ---------- - config : CorrectionConfig - Configuration with geo and image matching settings - kernel_ctx : KernelContext - NamedTuple containing: - - mkrn: MetaKernel instance with SDS and mission kernels - - dynamic_kernels: List of dynamic kernel file paths - - param_kernels: List of parameter-specific kernel file paths - ugps_times_modified : array-like - Time array (possibly modified by OFFSET_TIME parameter) - tlm_dataset : pd.DataFrame - Spacecraft state telemetry data - calibration : CalibrationData - NamedTuple containing: - - los_vectors: Pre-loaded LOS vectors (or None) - - optical_psfs: Pre-loaded optical PSF (or None) - image_matching_func : ImageMatchingFunc - Function to perform image matching (e.g., integrated_image_match) - match_ctx : ImageMatchingContext - NamedTuple containing: - - gcp_pairs: List of GCP pairing tuples - - params: List of (ParameterConfig, parameter_value) tuples - - pair_idx: Index of current GCP pair - - sci_key: Science dataset identifier for this pair - - Returns - ------- - geo_dataset : xr.Dataset - Geolocated points with latitude, 
longitude, altitude - image_matching_output : xr.Dataset - Matching results with error measurements and metadata - - Examples - -------- - >>> kernel_ctx = KernelContext(mkrn, dynamic_kernels, param_kernels) - >>> calibration = CalibrationData(los_vectors, optical_psfs) - >>> match_ctx = ImageMatchingContext(gcp_pairs, params, 0, "sci_001") - >>> geo, matching = _geolocate_and_match( - ... config, kernel_ctx, times, tlm_dataset, - ... calibration, integrated_image_match, match_ctx - ... ) - """ - logger.info(" Performing geolocation...") - with sp.ext.load_kernel( - [ - kernel_ctx.mkrn.sds_kernels, - kernel_ctx.mkrn.mission_kernels, - kernel_ctx.dynamic_kernels, - kernel_ctx.param_kernels, - ] - ): - geoloc_inst = spatial.Geolocate(config.geo.instrument_name) - geo_dataset = geoloc_inst(ugps_times_modified) - - # === IMAGE MATCHING MODULE === - logger.info(" === IMAGE MATCHING MODULE ===") - - # Use injected image matching function - gcp_file = Path(match_ctx.gcp_pairs[0][1]) if match_ctx.gcp_pairs else Path("synthetic_gcp.tif") - - # All image matching functions use the same signature - image_matching_output = image_matching_func( - geolocated_data=geo_dataset, - gcp_reference_file=gcp_file, - telemetry=tlm_dataset, - calibration_dir=config.calibration_dir, - params_info=match_ctx.params, - config=config, - los_vectors_cached=calibration.los_vectors, - optical_psfs_cached=calibration.optical_psfs, - ) - validate_image_matching_output(image_matching_output) - logger.info(" Image matching complete") - - logger.info(f" Generated error measurements for {len(image_matching_output.measurement)} points") - - # Store metadata for tracking - image_matching_output.attrs["gcp_pair_index"] = match_ctx.pair_idx - image_matching_output.attrs["gcp_pair_id"] = f"{match_ctx.sci_key}_pair_{match_ctx.pair_idx}" - - return geo_dataset, image_matching_output - - -def loop( - config: CorrectionConfig, work_dir: Path, tlm_sci_gcp_sets: [(str, str, str)], resume_from_checkpoint: bool = 
False -): - """ - Correction loop for parameter sensitivity analysis. - - Parameters - ---------- - config : CorrectionConfig - The single configuration containing all settings: - - Required: parameters, iterations, thresholds, geo config - - Required loaders: telemetry_loader, science_loader - - Optional processing: gcp_pairing_func, image_matching_func - - Calibration: `calibration_dir` (if image_matching_func uses calibration) - - Output: netcdf, output_filename - work_dir : Path - Working directory for temporary files. - tlm_sci_gcp_sets : list of (str, str, str) - List of (`telemetry_key`, `science_key`, `gcp_key`) tuples. - resume_from_checkpoint : bool, optional - If True, resume from an existing checkpoint. - - Returns - ------- - results : list - List of iteration results (order: `pair_idx * N + param_idx`). - netcdf_data : dict - Dictionary of NetCDF variables indexed as `[param_idx, pair_idx]`. - - Notes - ----- - This implementation uses a pair-outer, parameter-inner loop order: - - Outer loop: GCP pairs (load data once per image) - - Inner loop: Parameter sets (reuse loaded data) - This reduces file I/O and centralizes mission-specific behavior through the - `config` object. 
- - Examples - -------- - Correction mode (parameter optimization): - - >>> from clarreo_data_loaders import load_clarreo_telemetry, load_clarreo_science - >>> from clarreo_config import create_clarreo_correction_config - >>> - >>> # Create base config with required parameters - >>> config = create_clarreo_correction_config(data_dir, generic_dir) - >>> - >>> # Add required loaders - >>> config.telemetry_loader = load_clarreo_telemetry - >>> config.science_loader = load_clarreo_science - >>> config.gcp_pairing_func = spatial_pairing - >>> config.image_matching_func = image_matching - >>> - >>> # Run correction analysis - >>> results, netcdf_data = loop(config, work_dir, tlm_sci_gcp_sets) - - Verification mode (performance checking only): - - >>> # Create config with minimal iterations for verification - >>> config = CorrectionConfig( - ... seed=42, - ... n_iterations=1, # Verification mode - ... parameters=[], # No parameter variation - ... geo=geo_config, - ... performance_threshold_m=250.0, - ... performance_spec_percent=39.0, - ... earth_radius_m=6378137.0, - ... 
) - >>> config.telemetry_loader = load_clarreo_telemetry - >>> config.science_loader = load_clarreo_science - >>> config.gcp_pairing_func = spatial_pairing - >>> config.image_matching_func = image_matching - >>> results, netcdf_data = loop(config, work_dir, tlm_sci_gcp_sets) - """ - logger.info("=== CORRECTION PIPELINE ===") - logger.info(f" GCP pairs: {len(tlm_sci_gcp_sets)} (outer loop - load data once)") - - # Extract injected functions - telemetry_loader = config.telemetry_loader - science_loader = config.science_loader - image_matching_func = config.image_matching_func - gcp_pairing_func = config.gcp_pairing_func - - # Validate required loaders - if telemetry_loader is None: - raise ValueError("config.telemetry_loader is required but was None.") - if science_loader is None: - raise ValueError("config.science_loader is required but was None.") - - # Validate required processing functions - if gcp_pairing_func is None: - raise ValueError("config.gcp_pairing_func is required but was None.") - if image_matching_func is None: - raise ValueError("config.image_matching_func is required but was None.") - - # Initialize parameter sets - params_set = load_param_sets(config) - logger.info(f" Parameter sets: {len(params_set)} (inner loop)") - - # Build NetCDF data structure - n_param_sets = len(params_set) - n_gcp_pairs = len(tlm_sci_gcp_sets) - - # Try to load checkpoint if resuming - output_file = work_dir / config.get_output_filename() - start_pair_idx = 0 - # Currently, checkpoint is bugged, since the nadir equivalent stats are not calculated until the end. 
- # TODO [CURRYER-100]: Fix checkpoint resume for Monte Carlo GCS - if resume_from_checkpoint: - checkpoint_data, completed_pairs = _load_checkpoint(output_file, config) - if checkpoint_data is not None: - netcdf_data = checkpoint_data - start_pair_idx = completed_pairs - logger.info(f"Resuming from checkpoint: starting at GCP pair {start_pair_idx + 1}/{n_gcp_pairs}") - else: - netcdf_data = _build_netcdf_structure(config, n_param_sets, n_gcp_pairs) - logger.info("No valid checkpoint found, starting from beginning") - else: - netcdf_data = _build_netcdf_structure(config, n_param_sets, n_gcp_pairs) - - # Initialize results dict with (param_idx, pair_idx) keys - # This avoids nested search complexity when aggregating statistics - results_dict = {} - - # Prepare SPICE environment - mkrn = meta.MetaKernel.from_json( - config.geo.meta_kernel_file, - relative=True, - sds_dir=config.geo.generic_kernel_dir, - ) - creator = create.KernelCreator(overwrite=True, append=False) - - # Load calibration data once (LOS vectors and optical PSF are static instrument calibration) - calibration_data = _load_calibration_data(config) - - # Store parameter values once (before loops) - for param_idx, params in enumerate(params_set): - param_values = _extract_parameter_values(params) - _store_parameter_values(netcdf_data, param_idx, param_values) - - # OUTER LOOP: Iterate through GCP pairs - for pair_idx, (tlm_key, sci_key, gcp_key) in enumerate(tlm_sci_gcp_sets): - # Skip already-completed pairs if resuming - if pair_idx < start_pair_idx: - logger.info(f"=== GCP Pair {pair_idx + 1}/{n_gcp_pairs}: {sci_key} === (SKIPPED - already completed)") - continue - - logger.info(f"=== GCP Pair {pair_idx + 1}/{n_gcp_pairs}: {sci_key} ===") - - # Load image pair data once - tlm_dataset, sci_dataset, ugps_times = _load_image_pair_data( - tlm_key, sci_key, config, telemetry_loader, science_loader - ) - - # Create dynamic kernels once (these don't change with parameters) - dynamic_kernels = 
_create_dynamic_kernels(config, work_dir, tlm_dataset, creator) - - # Get GCP pairing ONCE - gcp_pairs = gcp_pairing_func([sci_key]) - validate_gcp_pairing_output(gcp_pairs) - logger.info(f" Found {len(gcp_pairs)} GCP pairs for processing") - - # INNER LOOP: Iterate through parameter sets - for param_idx, params in enumerate(params_set): - logger.info(f" Parameter Set {param_idx + 1}/{n_param_sets}") - - # Create parameter-specific kernels (these change with parameters) - param_kernels, ugps_times_modified = _create_parameter_kernels( - params, work_dir, tlm_dataset, sci_dataset, ugps_times, config, creator - ) - - # Prepare context objects for cleaner function call - kernel_ctx = KernelContext(mkrn=mkrn, dynamic_kernels=dynamic_kernels, param_kernels=param_kernels) - match_ctx = ImageMatchingContext(gcp_pairs=gcp_pairs, params=params, pair_idx=pair_idx, sci_key=sci_key) - - # Geolocate and perform image matching - geo_dataset, image_matching_output = _geolocate_and_match( - config, - kernel_ctx, - ugps_times_modified, - tlm_dataset, - calibration_data, - image_matching_func, - match_ctx, - ) - - # Process individual pair error statistics - individual_stats = call_error_stats_module(image_matching_output, correction_config=config) - individual_metrics = _extract_error_metrics(individual_stats) - - # Store results in NetCDF (maintain [param_idx, pair_idx] ordering) - _store_gcp_pair_results(netcdf_data, param_idx, pair_idx, individual_metrics) - netcdf_data["im_lat_error_km"][param_idx, pair_idx] = image_matching_output.attrs.get( - "lat_error_km", np.nan - ) - netcdf_data["im_lon_error_km"][param_idx, pair_idx] = image_matching_output.attrs.get( - "lon_error_km", np.nan - ) - netcdf_data["im_ccv"][param_idx, pair_idx] = image_matching_output.attrs.get("correlation_ccv", np.nan) - netcdf_data["im_grid_step_m"][param_idx, pair_idx] = image_matching_output.attrs.get( - "final_grid_step_m", np.nan - ) - - # Store results in dict with (param_idx, pair_idx) key - # Note: 
iteration index reflects reversed order (pair_idx * n_params + param_idx) - param_values = _extract_parameter_values(params) - iteration_result = { - "iteration": pair_idx * n_param_sets + param_idx, - "pair_index": pair_idx, - "param_index": param_idx, - "parameters": param_values, - "geolocation": geo_dataset, - "gcp_pairs": gcp_pairs, - "image_matching": image_matching_output, - "error_stats": individual_stats, - "rms_error_m": individual_metrics["rms_error_m"], - "aggregate_rms_error_m": None, - } - results_dict[(param_idx, pair_idx)] = iteration_result - - logger.info( - f" RMS error: {individual_metrics['rms_error_m']:.2f}m " - f"({individual_metrics['n_measurements']} measurements)" - ) - - logger.info(f" GCP pair {pair_idx + 1} complete (processed {n_param_sets} parameter sets)") - - # Save checkpoint after each pair completes - if resume_from_checkpoint: - _save_netcdf_checkpoint(netcdf_data, output_file, config, pair_idx) - - # Compute aggregate statistics for each parameter set (after all pairs complete) - logger.info("=== Computing aggregate statistics for all parameter sets ===") - for param_idx in range(n_param_sets): - # Collect all image matching results for this parameter set - param_image_matching_results = [] - for pair_idx in range(n_gcp_pairs): - result = results_dict.get((param_idx, pair_idx)) - if result: - param_image_matching_results.append(result["image_matching"]) - - # Compute aggregate statistics - aggregate_stats = call_error_stats_module(param_image_matching_results, correction_config=config) - aggregate_error_metrics = _extract_error_metrics(aggregate_stats) - - # Extract pair errors for threshold calculation - pair_errors = [netcdf_data["rms_error_m"][param_idx, pair_idx] for pair_idx in range(n_gcp_pairs)] - _compute_parameter_set_metrics(netcdf_data, param_idx, pair_errors, threshold_m=config.performance_threshold_m) - - logger.info(f" Parameter set {param_idx + 1}: Aggregate RMS = {aggregate_error_metrics['rms_error_m']:.2f}m") - 
- # Add aggregate stats to all results for this parameter set - for pair_idx in range(n_gcp_pairs): - key = (param_idx, pair_idx) - if key in results_dict: - results_dict[key]["aggregate_error_stats"] = aggregate_stats - results_dict[key]["aggregate_rms_error_m"] = aggregate_error_metrics["rms_error_m"] - # Convert results_dict back to list for backward compatibility - # Sort by iteration index to maintain consistent ordering - results = [results_dict[key] for key in sorted(results_dict.keys(), key=lambda k: results_dict[k]["iteration"])] - - # Save final NetCDF results - _save_netcdf_results(netcdf_data, output_file, config) - - # Clean up checkpoint file after successful completion - if resume_from_checkpoint: - _cleanup_checkpoint(output_file) - - logger.info(f"=== Loop Complete: Processed {n_gcp_pairs} GCP pairs × {n_param_sets} parameter sets ===") - logger.info(f" Total iterations: {len(results)}") - logger.info(f" NetCDF output: {output_file}") - - return results, netcdf_data - - -def _extract_parameter_values(params): - """Extract parameter values from a parameter set into a dictionary.""" - param_values = {} - - for param_config, param_data in params: - if param_config.config_file: - param_name = param_config.config_file.stem - - if param_config.ptype == ParameterType.CONSTANT_KERNEL: - # Extract roll, pitch, yaw from DataFrame - if isinstance(param_data, pd.DataFrame) and "angle_x" in param_data.columns: - # Convert back to arcseconds for storage - param_values[f"{param_name}_roll"] = np.degrees(param_data["angle_x"].iloc[0]) * 3600 - param_values[f"{param_name}_pitch"] = np.degrees(param_data["angle_y"].iloc[0]) * 3600 - param_values[f"{param_name}_yaw"] = np.degrees(param_data["angle_z"].iloc[0]) * 3600 - - elif param_config.ptype == ParameterType.OFFSET_KERNEL: - # Single bias value (keep in original units) - param_values[param_name] = param_data - - elif param_config.ptype == ParameterType.OFFSET_TIME: - # Time correction (keep in original units) - 
param_values[param_name] = param_data - - return param_values - - -def _store_parameter_values(netcdf_data, param_idx, param_values): - """Store parameter values in the NetCDF data structure. - - This function maps parameter names to NetCDF variable names for storage. - It handles the naming convention used by _build_netcdf_structure. - """ - - for param_name, value in param_values.items(): - # Generate NetCDF variable name using same logic as _build_netcdf_structure - # Replace dots and dashes with underscores, ensure param_ prefix - netcdf_var = param_name.replace(".", "_").replace("-", "_") - if not netcdf_var.startswith("param_"): - netcdf_var = f"param_{netcdf_var}" - - if netcdf_var in netcdf_data: - netcdf_data[netcdf_var][param_idx] = value - logger.debug(f" Stored {netcdf_var}[{param_idx}] = {value}") - else: - # Try to find a matching variable with debug info - logger.warning( - f" Parameter variable '{netcdf_var}' not found in netcdf_data. Available keys: {[k for k in netcdf_data.keys() if k.startswith('param_')]}" - ) - - -def _extract_error_metrics(stats_dataset): - """Extract error metrics from error statistics dataset.""" - if hasattr(stats_dataset, "attrs"): - # Real error stats module - return { - "rms_error_m": stats_dataset.attrs.get("rms_error_m", np.nan), - "mean_error_m": stats_dataset.attrs.get("mean_error_m", np.nan), - "max_error_m": stats_dataset.attrs.get("max_error_m", np.nan), - "std_error_m": stats_dataset.attrs.get("std_error_m", np.nan), - "n_measurements": stats_dataset.attrs.get("total_measurements", 0), - } - else: - # Fallback for placeholder - return { - "rms_error_m": float(stats_dataset.get("rms_error", np.nan)), - "mean_error_m": float(stats_dataset.get("mean_error", np.nan)), - "max_error_m": float(stats_dataset.get("max_error", np.nan)), - "std_error_m": float(stats_dataset.get("std_error", np.nan)), - "n_measurements": int(stats_dataset.get("n_measurements", 0)), - } - - -def _store_gcp_pair_results(netcdf_data, param_idx, 
pair_idx, error_metrics): - """Store GCP pair results in the NetCDF data structure.""" - netcdf_data["rms_error_m"][param_idx, pair_idx] = error_metrics["rms_error_m"] - netcdf_data["mean_error_m"][param_idx, pair_idx] = error_metrics["mean_error_m"] - netcdf_data["max_error_m"][param_idx, pair_idx] = error_metrics["max_error_m"] - netcdf_data["std_error_m"][param_idx, pair_idx] = error_metrics["std_error_m"] - netcdf_data["n_measurements"][param_idx, pair_idx] = error_metrics["n_measurements"] - - -def _compute_parameter_set_metrics(netcdf_data, param_idx, pair_errors, threshold_m=250.0): - """ - Compute overall performance metrics for a parameter set. - - Args: - netcdf_data: NetCDF data dictionary - param_idx: Parameter set index - pair_errors: Array of RMS errors for each GCP pair - threshold_m: Performance threshold in meters - """ - pair_errors = np.array(pair_errors) - valid_errors = pair_errors[~np.isnan(pair_errors)] - - if len(valid_errors) > 0: - # Percentage of pairs with error < threshold - # Find the threshold metric key dynamically - threshold_metric = None - for key in netcdf_data.keys(): - if key.startswith("percent_under_") and key.endswith("m"): - threshold_metric = key - break - - if threshold_metric: - percent_under_threshold = (valid_errors < threshold_m).sum() / len(valid_errors) * 100 - netcdf_data[threshold_metric][param_idx] = percent_under_threshold - - # Mean RMS across all pairs - netcdf_data["mean_rms_all_pairs"][param_idx] = np.mean(valid_errors) - - # Best and worst pair performance - netcdf_data["best_pair_rms"][param_idx] = np.min(valid_errors) - netcdf_data["worst_pair_rms"][param_idx] = np.max(valid_errors) - - -# ============================================================================= -# Incremental NetCDF Saving (Checkpoint/Resume) -# ============================================================================= - - -def _save_netcdf_checkpoint(netcdf_data, output_file, config, pair_idx_completed): - """ - Save NetCDF 
checkpoint with partial results after each GCP pair completes. - - This enables resuming correction runs if they are interrupted. - Adapted for pair-outer loop order where each pair processes all parameters. - - Args: - netcdf_data: Dictionary with current NetCDF data - output_file: Path to final output file (checkpoint uses .checkpoint.nc suffix) - config: CorrectionConfig with metadata - pair_idx_completed: Index of the last completed GCP pair (for pair-outer loop) - """ - import xarray as xr - - checkpoint_file = output_file.parent / f"{output_file.stem}_checkpoint.nc" - - # Ensure NetCDFConfig exists - config.ensure_netcdf_config() - - # Create coordinate arrays - coords = { - "parameter_set_id": netcdf_data["parameter_set_id"], - "gcp_pair_id": netcdf_data["gcp_pair_id"], - } - - # Build variable list dynamically from netcdf_data keys - data_vars = {} - for var_name, var_data in netcdf_data.items(): - if var_name not in coords: - if isinstance(var_data, np.ndarray): - if var_data.ndim == 1: - data_vars[var_name] = (["parameter_set_id"], var_data) - elif var_data.ndim == 2: - data_vars[var_name] = (["parameter_set_id", "gcp_pair_id"], var_data) - - # Create dataset - ds = xr.Dataset(data_vars, coords=coords) - - # Add regular metadata - ds.attrs.update( - { - "title": config.netcdf.title, - "description": config.netcdf.description, - "created": pd.Timestamp.now().isoformat(), - "correction_iterations": config.n_iterations, - "performance_threshold_m": config.netcdf.performance_threshold_m, - "parameter_count": len(config.parameters), - "random_seed": str(config.seed) if config.seed is not None else "None", - } - ) - - # Add checkpoint-specific metadata (NetCDF-compatible types) - ds.attrs["checkpoint"] = 1 # Use integer instead of boolean for NetCDF compatibility - ds.attrs["completed_gcp_pairs"] = int(pair_idx_completed + 1) - ds.attrs["total_gcp_pairs"] = int(len(netcdf_data["gcp_pair_id"])) - ds.attrs["checkpoint_timestamp"] = pd.Timestamp.now().isoformat() 
- - # Add parameter variable attributes from config - for param in config.parameters: - if param.ptype == ParameterType.CONSTANT_KERNEL: - for angle in ["roll", "pitch", "yaw"]: - metadata = config.netcdf.get_parameter_netcdf_metadata(param, angle) - if metadata.variable_name in ds.data_vars: - ds[metadata.variable_name].attrs.update({"units": metadata.units, "long_name": metadata.long_name}) - else: - metadata = config.netcdf.get_parameter_netcdf_metadata(param) - if metadata.variable_name in ds.data_vars: - ds[metadata.variable_name].attrs.update({"units": metadata.units, "long_name": metadata.long_name}) - - # Add standard metric attributes - standard_attrs = config.netcdf.get_standard_attributes() - threshold_metric = config.netcdf.get_threshold_metric_name() - standard_attrs[threshold_metric] = { - "units": "percent", - "long_name": f"Percentage of pairs with error < {config.performance_threshold_m}m", - } - for var, attrs in standard_attrs.items(): - if var in ds.data_vars: - ds[var].attrs.update(attrs) - - # Save to file in one operation - checkpoint_file.parent.mkdir(parents=True, exist_ok=True) - ds.to_netcdf(checkpoint_file, mode="w") # Force overwrite mode - ds.close() - - logger.info(f" Checkpoint saved: {pair_idx_completed + 1}/{len(netcdf_data['gcp_pair_id'])} GCP pairs complete") - - -def _load_checkpoint(output_file, config): - """ - Load checkpoint if it exists and convert back to netcdf_data dict. 
- - Args: - output_file: Path to final output file (will check for .checkpoint.nc) - config: CorrectionConfig for structure information - - Returns: - Tuple of (netcdf_data dict, start_idx) or (None, 0) if no checkpoint - """ - import xarray as xr - - checkpoint_file = output_file.parent / f"{output_file.stem}_checkpoint.nc" - - if not checkpoint_file.exists(): - return None, 0 - - logger.info(f"Found checkpoint file: {checkpoint_file}") - - try: - ds = xr.open_dataset(checkpoint_file, decode_timedelta=False) - - # Verify this is actually a checkpoint (checkpoint attribute is 1 for true, 0 or missing for false) - checkpoint_flag = ds.attrs.get("checkpoint", 0) - if not checkpoint_flag: # Will be True if checkpoint=1, False if checkpoint=0 or missing - logger.warning("File exists but is not marked as checkpoint, ignoring") - ds.close() - return None, 0 - - completed = ds.attrs.get("completed_gcp_pairs", 0) - total = ds.attrs.get("total_gcp_pairs", 0) - timestamp = ds.attrs.get("checkpoint_timestamp", "unknown") - - logger.info(f"Checkpoint from {timestamp}: {completed}/{total} GCP pairs complete") - - # Convert xarray.Dataset back to netcdf_data dictionary - netcdf_data = {} - - # Add coordinates - netcdf_data["parameter_set_id"] = ds.coords["parameter_set_id"].values - netcdf_data["gcp_pair_id"] = ds.coords["gcp_pair_id"].values - - # Add all data variables - for var_name in ds.data_vars: - netcdf_data[var_name] = ds[var_name].values - - ds.close() - - logger.info(f"Checkpoint loaded successfully, resuming from GCP pair {completed}") - - return netcdf_data, completed - - except Exception as e: - logger.error(f"Failed to load checkpoint: {e}") - return None, 0 - - -def _cleanup_checkpoint(output_file): - """ - Remove checkpoint file after successful completion. 
- - Args: - output_file: Path to final output file (will remove .checkpoint.nc) - """ - checkpoint_file = output_file.parent / f"{output_file.stem}_checkpoint.nc" - - if checkpoint_file.exists(): - try: - checkpoint_file.unlink() - logger.info(f"Checkpoint file cleaned up: {checkpoint_file}") - except Exception as e: - logger.warning(f"Failed to remove checkpoint file: {e}") - - -def _save_netcdf_results(netcdf_data, output_file, config): - """ - Save results to NetCDF file using config-driven metadata. - - This function dynamically builds the NetCDF file structure from the - netcdf_data dictionary, using configuration for all metadata rather - than hardcoding mission-specific values. - - Args: - netcdf_data: Dictionary with all NetCDF variables and data - output_file: Path to output NetCDF file - config: CorrectionConfig with NetCDF metadata - """ - import xarray as xr - - logger.info(f"Saving NetCDF results to: {output_file}") - - # Ensure NetCDFConfig exists - config.ensure_netcdf_config() - - # Create coordinate arrays - coords = { - "parameter_set_id": netcdf_data["parameter_set_id"], - "gcp_pair_id": netcdf_data["gcp_pair_id"], - } - - # Build variable list dynamically from netcdf_data keys - data_vars = {} - - # Add all non-coordinate variables, determining dimensions from array shape - for var_name, var_data in netcdf_data.items(): - if var_name not in coords: - # Determine dimensions from array shape - if isinstance(var_data, np.ndarray): - if var_data.ndim == 1: - data_vars[var_name] = (["parameter_set_id"], var_data) - elif var_data.ndim == 2: - data_vars[var_name] = (["parameter_set_id", "gcp_pair_id"], var_data) - - logger.info(f" Creating dataset with {len(data_vars)} data variables") - - # Create dataset - ds = xr.Dataset(data_vars, coords=coords) - - # Add global metadata from config - ds.attrs.update( - { - "title": config.netcdf.title, - "description": config.netcdf.description, - "created": pd.Timestamp.now().isoformat(), - 
"correction_iterations": config.n_iterations, - "performance_threshold_m": config.netcdf.performance_threshold_m, - "parameter_count": len(config.parameters), - "random_seed": str(config.seed) if config.seed is not None else "None", - } - ) - - # Add parameter variable attributes from config - for param in config.parameters: - if param.ptype == ParameterType.CONSTANT_KERNEL: - # Add metadata for roll, pitch, yaw components - for angle in ["roll", "pitch", "yaw"]: - metadata = config.netcdf.get_parameter_netcdf_metadata(param, angle) - if metadata.variable_name in ds.data_vars: - ds[metadata.variable_name].attrs.update({"units": metadata.units, "long_name": metadata.long_name}) - else: - # Add metadata for single-value parameters - metadata = config.netcdf.get_parameter_netcdf_metadata(param) - if metadata.variable_name in ds.data_vars: - ds[metadata.variable_name].attrs.update({"units": metadata.units, "long_name": metadata.long_name}) - - # Add standard metric attributes from config (allows mission overrides) - standard_attrs = config.netcdf.get_standard_attributes() - - # Add dynamic threshold metric - threshold_metric = config.netcdf.get_threshold_metric_name() - standard_attrs[threshold_metric] = { - "units": "percent", - "long_name": f"Percentage of pairs with error < {config.performance_threshold_m}m", - } - - for var, attrs in standard_attrs.items(): - if var in ds.data_vars: - ds[var].attrs.update(attrs) - - # Save to file - output_file.parent.mkdir(parents=True, exist_ok=True) - ds.to_netcdf(output_file) +# NetCDF I/O +from curryer.correction.results_io import ( + _build_netcdf_structure, + _cleanup_checkpoint, + _load_checkpoint, + _save_netcdf_checkpoint, + _save_netcdf_results, +) - logger.info(f" NetCDF file saved successfully") - logger.info(f" Dimensions: {dict(ds.sizes)}") - logger.info(f" Data variables: {len(list(ds.data_vars.keys()))}") - logger.info(f" File: {output_file}") +__all__ = [ + # Config + "STANDARD_NETCDF_ATTRIBUTES", + 
"STANDARD_VAR_NAMES", + "CalibrationData", + "CorrectionConfig", + "DataConfig", + "GeolocationConfig", + "ImageMatchingContext", + "KernelContext", + "NetCDFConfig", + "NetCDFParameterMetadata", + "ParameterConfig", + "ParameterType", + # Parameters + "load_param_sets", + # Kernel ops + "_create_dynamic_kernels", + "_create_parameter_kernels", + "apply_offset", + # Results I/O + "_build_netcdf_structure", + "_cleanup_checkpoint", + "_load_checkpoint", + "_save_netcdf_checkpoint", + "_save_netcdf_results", + # Pipeline + "_aggregate_image_matching_results", + "_compute_parameter_set_metrics", + "_extract_error_metrics", + "_extract_parameter_values", + "_extract_spacecraft_position_midframe", + "_geolocate_and_match", + "_geolocated_to_image_grid", + "_load_calibration_data", + "_load_file", + "_load_image_pair_data", + "_resolve_gcp_pairs", + "_store_gcp_pair_results", + "_store_parameter_values", + "call_error_stats_module", + "image_matching", + "load_config_from_json", + "loop", +] diff --git a/curryer/correction/data_structures.py b/curryer/correction/data_structures.py index f37c8377..e3e0d4e6 100644 --- a/curryer/correction/data_structures.py +++ b/curryer/correction/data_structures.py @@ -3,6 +3,7 @@ from dataclasses import dataclass import numpy as np +from pydantic import BaseModel, field_validator, model_validator @dataclass @@ -100,8 +101,23 @@ def __post_init__(self) -> None: @dataclass -class GeolocationConfig: - """Configuration parameters for PSF geolocation modelling.""" +class PSFSamplingConfig: + """Configuration parameters for PSF sampling during image matching. + + Parameters + ---------- + gcp_step_m : float, optional + Ground control point step size in meters. Default is 30.0. + motion_convolution_step_m : float, optional + Step size for spacecraft motion convolution in meters. + Default is ``gcp_step_m / 20.0``. + psf_lat_sample_dist_deg : float, optional + PSF sample distance in the latitude direction in degrees. 
+ Default is 2.4397105613972e-05. + psf_lon_sample_dist_deg : float, optional + PSF sample distance in the longitude direction in degrees. + Default is 2.8737038710207e-05. + """ gcp_step_m: float = 30.0 motion_convolution_step_m: float = gcp_step_m / 20.0 @@ -117,3 +133,96 @@ class SearchConfig: grid_span_km: float = 11.0 reduction_factor: float = 0.8 spacing_limit_m: float = 10.0 + + +class RegridConfig(BaseModel): + """Configuration for GCP chip regridding. + + Specifies output grid parameters for transforming irregular geodetic grids + to regular latitude/longitude grids. ECEF → geodetic conversion always + uses the WGS84 ellipsoid, which is the only ellipsoid supported by + ``curryer.compute.spatial.ecef_to_geodetic``. + + Parameters + ---------- + output_grid_size : tuple[int, int], optional + Desired output grid dimensions as (nrows, ncols). Mutually exclusive + with ``output_resolution_deg``. + output_resolution_deg : tuple[float, float], optional + Desired output resolution as (dlat, dlon) in degrees. Mutually + exclusive with ``output_grid_size``. Required when ``output_bounds`` + is set. + output_bounds : tuple[float, float, float, float], optional + Explicit output grid bounds as (minlon, maxlon, minlat, maxlat) in + degrees. Requires ``output_resolution_deg``. + conservative_bounds : bool, default=True + If True, shrink bounds to ensure all output points lie within the + input irregular grid (avoids edge extrapolation). + interpolation_method : str, default="bilinear" + Interpolation method; one of ``"bilinear"`` or ``"nearest"``. + fill_value : float, default=NaN + Value assigned to output points that fall outside the input grid. 
+ """ + + output_grid_size: tuple[int, int] | None = None + output_resolution_deg: tuple[float, float] | None = None + output_bounds: tuple[float, float, float, float] | None = None + conservative_bounds: bool = True + interpolation_method: str = "bilinear" + fill_value: float = float("nan") + + @field_validator("interpolation_method") + @classmethod + def validate_method(cls, v: str) -> str: + """Validate interpolation method name.""" + valid = {"bilinear", "nearest"} + if v not in valid: + raise ValueError(f"interpolation_method must be one of {valid}, got '{v}'") + return v + + @field_validator("output_grid_size") + @classmethod + def validate_grid_size(cls, v: tuple[int, int] | None) -> tuple[int, int] | None: + """Validate that grid size has at least 2 rows and 2 columns.""" + if v is not None: + if v[0] < 2: + raise ValueError(f"Grid size must have at least 2 rows and 2 columns, got {v}") + if v[1] < 2: + raise ValueError(f"Grid size must have at least 2 rows and 2 columns, got {v}") + return v + + @field_validator("output_resolution_deg") + @classmethod + def validate_resolution(cls, v: tuple[float, float] | None) -> tuple[float, float] | None: + """Validate that resolution values are positive.""" + if v is not None: + if v[0] <= 0 or v[1] <= 0: + raise ValueError(f"Resolution values must be positive (dlat, dlon), got {v}") + return v + + @field_validator("output_bounds") + @classmethod + def validate_bounds(cls, v: tuple[float, float, float, float] | None) -> tuple[float, float, float, float] | None: + """Validate that bounds are properly ordered.""" + if v is not None: + minlon, maxlon, minlat, maxlat = v + if minlon >= maxlon: + raise ValueError(f"minlon must be < maxlon, got {minlon} >= {maxlon}") + if minlat >= maxlat: + raise ValueError(f"minlat must be < maxlat, got {minlat} >= {maxlat}") + return v + + @model_validator(mode="after") + def validate_grid_spec(self) -> RegridConfig: + """Validate that grid specification options are mutually 
consistent.""" + has_size = self.output_grid_size is not None + has_res = self.output_resolution_deg is not None + has_bounds = self.output_bounds is not None + + if has_size and has_res: + raise ValueError("Cannot specify both output_grid_size and output_resolution_deg") + if has_bounds and not has_res: + raise ValueError("output_bounds requires output_resolution_deg") + if has_bounds and has_size: + raise ValueError("Cannot specify both output_bounds and output_grid_size") + return self diff --git a/curryer/correction/dataio.py b/curryer/correction/dataio.py index f69eb867..4cddedb3 100644 --- a/curryer/correction/dataio.py +++ b/curryer/correction/dataio.py @@ -1,6 +1,6 @@ -"""Helpers for querying and downloading NetCDF data from AWS S3. +"""Validation helpers and S3 data-access utilities for the correction pipeline. -All interactions rely on the boto3 S3 client. Callers may either provide an +S3 access relies on the boto3 S3 client. Callers may either provide an explicit client instance (useful for testing) or rely on the default client, in which case boto3 must be installed and AWS credentials are read from the standard ``AWS_*`` environment variables. @@ -12,7 +12,7 @@ import os from collections.abc import Iterable from pathlib import Path -from typing import TYPE_CHECKING, Protocol +from typing import TYPE_CHECKING # TODO: Remove if boto3 is made a required dependency! try: # pragma: no cover - exercised indirectly when boto3 is available @@ -25,95 +25,10 @@ # ============================================================================ -# Data Loader Interface Protocols +# Data Validation Helpers # ============================================================================ -class TelemetryLoader(Protocol): - """ - Protocol for mission-specific telemetry loading functions. - - Telemetry loaders are responsible for reading spacecraft state data - (position, attitude, timing) from mission-specific formats and returning - it in a standard DataFrame format. 
- - Standard Signature: - def load_telemetry(tlm_key: str, config) -> pd.DataFrame - - Requirements: - - Accept tlm_key (path or identifier) and config object - - Return DataFrame with mission-specific telemetry fields - - Include time fields needed for SPICE kernel creation - - Include attitude data (quaternions or DCMs) - - Include position data if creating SPK kernels - - Example: - def load_clarreo_telemetry(tlm_key: str, config) -> pd.DataFrame: - # Load from multiple CSV files - # Convert formats (DCM to quaternion, etc.) - # Merge and return - return telemetry_df - """ - - def __call__(self, tlm_key: str, config) -> pd.DataFrame: - """Load telemetry data for a given key.""" - ... - - -class ScienceLoader(Protocol): - """ - Protocol for mission-specific science frame loading functions. - - Science loaders provide frame timing and metadata for the instrument - observations that will be geolocated. - - Standard Signature: - def load_science(sci_key: str, config) -> pd.DataFrame - - Requirements: - - Accept sci_key (path or identifier) and config object - - Return DataFrame with frame timing data - - Must include time field specified in config.geo.time_field - - Time values should match expected format (e.g., GPS microseconds) - - Example: - def load_clarreo_science(sci_key: str, config) -> pd.DataFrame: - # Load frame timestamps - # Convert to required units (e.g., GPS µs) - return science_df - """ - - def __call__(self, sci_key: str, config) -> pd.DataFrame: - """Load science frame timing/metadata.""" - ... - - -class GCPLoader(Protocol): - """ - Protocol for mission-specific GCP (Ground Control Point) loading functions. - - GCP loaders retrieve reference imagery or coordinates for ground truth - comparison. - - Standard Signature: - def load_gcp(gcp_key: str, config) -> Any - - Note: - This interface is currently a placeholder. The return type and structure - will be standardized when GCP loading is fully integrated into the pipeline. 
- - Example: - def load_clarreo_gcp(gcp_key: str, config): - # Load Landsat reference image - # Or load GCP coordinate database - return gcp_data - """ - - def __call__(self, gcp_key: str, config): - """Load GCP reference data.""" - ... - - def validate_telemetry_output(df: pd.DataFrame, config) -> None: """ Validate that telemetry loader output has expected structure. @@ -152,7 +67,7 @@ def validate_science_output(df: pd.DataFrame, config) -> None: ValueError: If DataFrame is empty or missing required time field Example: - >>> sci_df = load_science("sci_001", config) + >>> sci_df = pd.read_csv("science.csv") >>> validate_science_output(sci_df, config) """ import pandas as pd diff --git a/curryer/correction/geolocation_error_stats.py b/curryer/correction/error_stats.py similarity index 64% rename from curryer/correction/geolocation_error_stats.py rename to curryer/correction/error_stats.py index f26f4b7c..dc0d8ecb 100644 --- a/curryer/correction/geolocation_error_stats.py +++ b/curryer/correction/error_stats.py @@ -1,15 +1,22 @@ """ Geolocation statistics processor with Xarray inputs and outputs. -This module processes geolocation errors from the GCS image matching algorithm and -produces performance verification metrics, specifically the nadir-equivalent geolocation errors -which pass when less than 250m. +This module processes geolocation errors from the image matching algorithm and +produces nadir-equivalent geolocation errors together with mission-agnostic +summary statistics. The main processing pipeline: -1) Convert angular errors to N-S and E-W distances -2) Transform error components to view-plane/cross-view-plane distances -3) Scale to nadir-equivalent using geometric factors -4) Compute statistical performance metrics + +1. Convert angular errors to N-S and E-W distances. +2. Transform error components to view-plane / cross-view-plane distances. +3. Scale to nadir-equivalent using geometric factors. +4. 
(Optional) Compute comprehensive statistics across all measurements. + +Pass/fail evaluation is intentionally **not** included here — whether the +statistics meet mission requirements is the caller's responsibility. Use +:func:`compute_percent_below` for custom threshold queries, or compare the +fixed threshold-table entries (``percent_below_100m``, ``percent_below_250m``, +etc.) directly. """ import logging @@ -20,40 +27,85 @@ import numpy as np import xarray as xr +from curryer.compute import constants + logger = logging.getLogger(__name__) +# WGS84 Earth radius in meters – single source of truth from curryer.compute.constants. +_EARTH_RADIUS_M: float = constants.WGS84_SEMI_MAJOR_AXIS_KM * 1000.0 -@dataclass -class GeolocationConfig: - """Configuration parameters for geolocation processing. - All values should be provided from CorrectionConfig - no hardcoded defaults. +def compute_percent_below(errors: np.ndarray, threshold_m: float) -> float: + """Compute the percentage of errors below a given threshold. + + Useful for evaluating custom thresholds not in the standard table + produced by :meth:`ErrorStatsProcessor._calculate_statistics`. + + Parameters + ---------- + errors : np.ndarray + Array of nadir-equivalent geolocation errors in metres. + threshold_m : float + Threshold in metres. + + Returns + ------- + float + Percentage (0–100) of errors strictly below *threshold_m*. + Returns ``0.0`` when *errors* is empty. """ + if len(errors) == 0: + return 0.0 + return float(np.sum(errors < threshold_m) / len(errors) * 100) - earth_radius_m: float # Earth radius in meters (e.g., WGS84: 6378140.0) - performance_threshold_m: float # Accuracy threshold (e.g., 250.0) - performance_spec_percent: float # Performance requirement percentage (e.g., 39.0) - minimum_correlation: float | None = None # Filter threshold (0.0-1.0) + +@dataclass +class ErrorStatsConfig: + """Configuration for geolocation error statistics processing. 
+ + Parameters + ---------- + minimum_correlation : float or None, optional + Minimum correlation filter threshold (0.0–1.0). Measurements whose + correlation score falls below this value are excluded before + processing. Default is ``None`` (no filtering). + variable_names : dict of str to str or None, optional + Mission-agnostic variable name mappings from semantic names to actual + dataset variable names. If ``None``, generic defaults are used. + + Notes + ----- + Pass/fail thresholds are **not** part of this config. + ``ErrorStatsProcessor`` computes statistics only; whether those numbers + meet mission requirements is the caller's responsibility. + + Earth radius is not a config field either. ``_EARTH_RADIUS_M`` (derived + from ``curryer.compute.constants.WGS84_SEMI_MAJOR_AXIS_KM``) is used + directly in all calculations. + """ + + minimum_correlation: float | None = None # Mission-agnostic variable name mappings # Maps semantic names to actual variable names in the dataset - variable_names: dict[str, str] | None = None # If None, uses CLARREO defaults + variable_names: dict[str, str] | None = None # If None, uses generic defaults @classmethod - def from_correction_config(cls, correction_config) -> "GeolocationConfig": - """ - Create GeolocationConfig from CorrectionConfig. + def from_correction_config(cls, correction_config) -> "ErrorStatsConfig": + """Create an :class:`ErrorStatsConfig` from a :class:`CorrectionConfig`. - This is the preferred way to create this config - extracts all settings - from the single source of truth (CorrectionConfig). + This is the preferred way to create this config — it extracts all + settings from the single source of truth (:class:`CorrectionConfig`). - Args: - correction_config: CorrectionConfig instance + Parameters + ---------- + correction_config : CorrectionConfig + Top-level correction configuration. 
- Returns: - GeolocationConfig with settings from CorrectionConfig + Returns + ------- + ErrorStatsConfig """ - # Create variable name mapping variable_names = { "spacecraft_position": correction_config.spacecraft_position_name, "boresight": correction_config.boresight_name, @@ -61,9 +113,6 @@ def from_correction_config(cls, correction_config) -> "GeolocationConfig": } return cls( - earth_radius_m=correction_config.earth_radius_m, - performance_threshold_m=correction_config.performance_threshold_m, - performance_spec_percent=correction_config.performance_spec_percent, minimum_correlation=correction_config.geo.minimum_correlation, variable_names=variable_names, ) @@ -72,19 +121,25 @@ def get_variable_name(self, semantic_name: str) -> str: """ Get actual variable name for a semantic concept. - Args: - semantic_name: Semantic name like 'spacecraft_position', 'boresight', etc. + Parameters + ---------- + semantic_name : str + Semantic name like 'spacecraft_position', 'boresight', etc. - Returns: - Actual variable name in the dataset + Returns + ------- + str + Actual variable name in the dataset. - Raises: - ValueError: If variable_names not provided and semantic_name not found + Raises + ------ + ValueError + If variable_names is None or semantic_name is not found. """ if self.variable_names is None: raise ValueError( - f"GeolocationConfig.variable_names is None. " - f"Use GeolocationConfig.from_correction_config() to create config with proper variable names." + f"ErrorStatsConfig.variable_names is None. " + f"Use ErrorStatsConfig.from_correction_config() to create config with proper variable names." ) if semantic_name not in self.variable_names: @@ -99,17 +154,20 @@ def get_variable_name(self, semantic_name: str) -> str: class ErrorStatsProcessor: """Production-ready processor for geolocation error statistics.""" - def __init__(self, config: GeolocationConfig): + def __init__(self, config: ErrorStatsConfig): """ Initialize processor with configuration. 
    def __init__(self, config: ErrorStatsConfig):
        """
        Initialize processor with configuration.

        Parameters
        ----------
        config : ErrorStatsConfig
            Configuration for error statistics processing. Use
            ``ErrorStatsConfig.from_correction_config()`` to create from
            CorrectionConfig.

        Raises
        ------
        ValueError
            If *config* is ``None``.
        """
        # Fail fast on a missing config: a ValueError here is clearer than an
        # AttributeError later when self.config is first dereferenced.
        if config is None:
            raise ValueError(
                "ErrorStatsConfig is required. Use ErrorStatsConfig.from_correction_config(correction_config) to create."
            )
        self.config = config
No statistical attributes are set on the output. + + Raises + ------ + ValueError + If required variables are missing or all measurements are filtered + out by the correlation threshold. """ - # Validate input data self._validate_input_data(input_data) - - # NEW: Apply correlation filtering if configured filtered_data = self._filter_by_correlation(input_data) if len(filtered_data.measurement) == 0: raise ValueError("No measurements remaining after correlation filtering") - # Extract data arrays (now using filtered_data) n_measurements = len(filtered_data.measurement) - # Get actual variable names from config sc_pos_var = self.config.get_variable_name("spacecraft_position") boresight_var = self.config.get_variable_name("boresight") transform_var = self.config.get_variable_name("transformation_matrix") - # Convert angular errors to distance errors lat_error_rad = np.deg2rad(filtered_data.lat_error_deg.values) lon_error_rad = np.deg2rad(filtered_data.lon_error_deg.values) gcp_lat_rad = np.deg2rad(filtered_data.gcp_lat_deg.values) gcp_lon_rad = np.deg2rad(filtered_data.gcp_lon_deg.values) - # Calculate N-S and E-W error distances in meters - ns_error_dist_m = self.config.earth_radius_m * lat_error_rad - ew_error_dist_m = self.config.earth_radius_m * np.cos(gcp_lat_rad) * lon_error_rad + ns_error_dist_m = _EARTH_RADIUS_M * lat_error_rad + ew_error_dist_m = _EARTH_RADIUS_M * np.cos(gcp_lat_rad) * lon_error_rad - # Transform boresight vectors using configurable variable names bhat_ctrs = self._transform_boresight_vectors( filtered_data[boresight_var].values, filtered_data[transform_var].values ) - # Process each measurement to nadir-equivalent using configurable variable name results = self._process_to_nadir_equivalent( ns_error_dist_m, ew_error_dist_m, @@ -202,11 +272,29 @@ def process_geolocation_errors(self, input_data: xr.Dataset) -> xr.Dataset: n_measurements, ) - # Create output dataset - output_data = self._create_output_dataset(filtered_data, results) + return 
self._create_output_dataset(filtered_data, results) - # Add statistics as global attributes - stats = self._calculate_statistics(results["nadir_equiv_total_error_m"]) + def process_geolocation_errors(self, input_data: xr.Dataset) -> xr.Dataset: + """Full processing: nadir-equivalent errors + aggregate statistics. + + Use this for final aggregation after the loop, or in :func:`verify`. + For per-iteration computation (single GCP pair), prefer + :meth:`compute_nadir_equivalent_errors` to avoid computing aggregate + statistics on a small or single-measurement sample. + + Parameters + ---------- + input_data : xr.Dataset + Dataset with required error measurement variables. + + Returns + ------- + xr.Dataset + Dataset with ``nadir_equiv_total_error_m`` and related intermediate + variables, plus comprehensive statistics as global attributes. + """ + output_data = self.compute_nadir_equivalent_errors(input_data) + stats = self._calculate_statistics(output_data["nadir_equiv_total_error_m"].values) output_data.attrs.update(stats) return output_data @@ -342,8 +430,8 @@ def _calculate_view_plane_vectors(self, bhat_uen: np.ndarray) -> tuple[np.ndarra def _calculate_scaling_factors(self, riss_ctrs: np.ndarray, theta: float) -> tuple[float, float]: """Calculate scaling factors for nadir-equivalent transformation.""" r_magnitude = np.linalg.norm(riss_ctrs) - f = r_magnitude / self.config.earth_radius_m - h = r_magnitude - self.config.earth_radius_m + f = r_magnitude / _EARTH_RADIUS_M + h = r_magnitude - _EARTH_RADIUS_M # Calculate discriminant for sqrt - should be positive for physically valid geometries discriminant = 1 - f**2 * np.sin(theta) ** 2 @@ -362,10 +450,10 @@ def _calculate_scaling_factors(self, riss_ctrs: np.ndarray, theta: float) -> tup temp1 = np.maximum(temp1, 1e-10) # View-plane scaling factor - vp_factor = h / self.config.earth_radius_m / (-1 + f * np.cos(theta) / temp1) + vp_factor = h / _EARTH_RADIUS_M / (-1 + f * np.cos(theta) / temp1) # Cross-view-plane 
scaling factor - xvp_factor = h / self.config.earth_radius_m / np.cos(theta) / (f * np.cos(theta) - temp1) + xvp_factor = h / _EARTH_RADIUS_M / np.cos(theta) / (f * np.cos(theta) - temp1) return vp_factor, xvp_factor @@ -437,8 +525,7 @@ def _create_output_dataset(self, input_data: xr.Dataset, results: dict[str, np.n attrs={ "title": "Geolocation Error Statistics Results", "processing_timestamp": np.datetime64("now"), - "earth_radius_m": self.config.earth_radius_m, - "performance_threshold_m": self.config.performance_threshold_m, + "earth_radius_m": _EARTH_RADIUS_M, }, ) @@ -450,27 +537,53 @@ def _create_output_dataset(self, input_data: xr.Dataset, results: dict[str, np.n return output_ds def _calculate_statistics(self, nadir_equiv_errors_m: np.ndarray) -> dict[str, float | int]: - """Calculate performance statistics on nadir-equivalent errors.""" - - # Count errors below threshold - num_below_threshold = np.sum(nadir_equiv_errors_m < self.config.performance_threshold_m) - - # Calculate statistics - stats = { - "mean_error_distance_m": float(np.mean(nadir_equiv_errors_m)), - "std_error_distance_m": float(np.std(nadir_equiv_errors_m)), - "min_error_distance_m": float(np.min(nadir_equiv_errors_m)), - "max_error_distance_m": float(np.max(nadir_equiv_errors_m)), - "percent_below_250m": float(num_below_threshold / len(nadir_equiv_errors_m) * 100), - "num_below_250m": int(num_below_threshold), - "total_measurements": int(len(nadir_equiv_errors_m)), - "performance_spec_met": bool( - num_below_threshold / len(nadir_equiv_errors_m) * 100 > self.config.performance_spec_percent - ), + """Calculate comprehensive, mission-agnostic performance statistics. + + This method intentionally does NOT include any pass/fail evaluation. + Whether these statistics meet mission requirements is the caller's + responsibility. Use :func:`compute_percent_below` for custom threshold + queries not covered by the standard table. 
+ + Parameters + ---------- + nadir_equiv_errors_m : np.ndarray + Array of nadir-equivalent geolocation errors in metres. + + Returns + ------- + dict[str, float | int] + Keys: central tendency (``mean_error_m``, ``median_error_m``, + ``rms_error_m``), spread (``std_error_m``, ``min_error_m``, + ``max_error_m``), percentiles (``p25_error_m`` … ``p99_error_m``), + count (``total_measurements``), and a threshold table at standard + intervals (``percent_below_100m`` … ``percent_below_1000m``). + """ + n = len(nadir_equiv_errors_m) + return { + # Central tendency + "mean_error_m": float(np.mean(nadir_equiv_errors_m)), + "median_error_m": float(np.median(nadir_equiv_errors_m)), + "rms_error_m": float(np.sqrt(np.mean(nadir_equiv_errors_m**2))), + # Spread + "std_error_m": float(np.std(nadir_equiv_errors_m)), + "min_error_m": float(np.min(nadir_equiv_errors_m)), + "max_error_m": float(np.max(nadir_equiv_errors_m)), + # Percentiles + "p25_error_m": float(np.percentile(nadir_equiv_errors_m, 25)), + "p75_error_m": float(np.percentile(nadir_equiv_errors_m, 75)), + "p90_error_m": float(np.percentile(nadir_equiv_errors_m, 90)), + "p95_error_m": float(np.percentile(nadir_equiv_errors_m, 95)), + "p99_error_m": float(np.percentile(nadir_equiv_errors_m, 99)), + # Count + "total_measurements": int(n), + # Threshold table (standard intervals, for quick reference) + "percent_below_100m": float(np.sum(nadir_equiv_errors_m < 100.0) / n * 100), + "percent_below_250m": float(np.sum(nadir_equiv_errors_m < 250.0) / n * 100), + "percent_below_500m": float(np.sum(nadir_equiv_errors_m < 500.0) / n * 100), + "percent_below_750m": float(np.sum(nadir_equiv_errors_m < 750.0) / n * 100), + "percent_below_1000m": float(np.sum(nadir_equiv_errors_m < 1000.0) / n * 100), } - return stats - def process_from_netcdf(self, filepath: Union[str, "Path"], minimum_correlation: float | None = None) -> xr.Dataset: """ Load previous results from NetCDF and reprocess error statistics. 
def load_image_grid_from_mat(
    mat_file: str | Path, key: str = "subimage", name: str | None = None, as_named: bool = False
) -> ImageGrid | NamedImageGrid:
    """
    Load ImageGrid from MATLAB .mat file.

    Parameters
    ----------
    mat_file : str or Path
        Path to .mat file (local path or ``s3://`` URI).
    key : str, default="subimage"
        MATLAB struct key (e.g., "subimage" for L1A, "GCP" for reference).
    name : str, optional
        Name for NamedImageGrid. Defaults to file path.
    as_named : bool, default=False
        If True, return NamedImageGrid; otherwise return ImageGrid.

    Returns
    -------
    ImageGrid or NamedImageGrid
        Loaded image grid with data, lat, lon, h fields.

    Raises
    ------
    FileNotFoundError
        If mat_file is a local path and doesn't exist.
    ImportError
        If mat_file is an S3 URI and boto3 is not installed.
    KeyError
        If key not found in MATLAB file.

    Examples
    --------
    >>> # Load L1A subimage
    >>> l1a = load_image_grid_from_mat(Path("subimage.mat"), key="subimage")
    >>> # Load GCP reference
    >>> gcp = load_image_grid_from_mat(Path("gcp.mat"), key="GCP")
    """
    from scipy.io import loadmat

    from curryer.correction.io import resolve_path

    # resolve_path validates local existence / downloads from S3 as needed.
    mat_file = resolve_path(mat_file)
    mat_data = loadmat(str(mat_file), squeeze_me=True, struct_as_record=False)

    if key not in mat_data:
        available_keys = [k for k in mat_data.keys() if not k.startswith("__")]
        raise KeyError(f"Key '{key}' not found in {mat_file.name}. Available keys: {available_keys}")

    struct = mat_data[key]

    # loadmat already hands back numpy arrays; the grid's __post_init__ does
    # any remaining type conversion, so no asarray() calls are needed here.
    payload = {
        "data": struct.data,
        "lat": struct.lat,
        "lon": struct.lon,
        "h": getattr(struct, "h", None),
    }

    if not as_named:
        return ImageGrid(**payload)

    payload["name"] = name or str(mat_file)
    return NamedImageGrid(**payload)
def load_optical_psf_from_mat(mat_file: str | Path, key: str = "PSF_struct_675nm") -> list[OpticalPSFEntry]:
    """
    Load optical PSF entries from MATLAB .mat file.

    Parameters
    ----------
    mat_file : str or Path
        Path to MATLAB file with PSF data (local path or ``s3://`` URI).
    key : str, default="PSF_struct_675nm"
        Primary key to try for PSF data.

    Returns
    -------
    list[OpticalPSFEntry]
        Optical PSF samples with data, x, and field_angle arrays.

    Raises
    ------
    FileNotFoundError
        If mat_file is a local path and doesn't exist.
    ImportError
        If mat_file is an S3 URI and boto3 is not installed.
    KeyError
        If no PSF data found with common key names.
    ValueError
        If PSF entries missing field angle attribute.
    """
    from scipy.io import loadmat

    from curryer.correction.io import resolve_path

    # resolve_path validates local existence / downloads from S3 as needed.
    mat_file = resolve_path(mat_file)
    mat_data = loadmat(str(mat_file), squeeze_me=True, struct_as_record=False)

    # Try common keys in order of preference; dict.fromkeys drops a
    # duplicate probe when `key` equals one of the default candidates.
    for try_key in dict.fromkeys([key, "PSF_struct_675nm", "optical_PSF", "PSF"]):
        if try_key in mat_data:
            psf_struct = mat_data[try_key]
            psf_entries_raw = np.atleast_1d(psf_struct)

            psf_entries = []
            for entry in psf_entries_raw:
                # Handle both 'FA' and 'field_angle' attribute names.
                # Check presence/emptiness explicitly to avoid the NumPy
                # array truthiness ambiguity.
                field_angle = getattr(entry, "FA", None)
                if field_angle is None or (
                    isinstance(field_angle, list | tuple | np.ndarray) and len(field_angle) == 0
                ):
                    # Fallback if FA is missing, None, or empty
                    field_angle = getattr(entry, "field_angle", None)

                if field_angle is None:
                    # Plain string: the message has no placeholders.
                    raise ValueError("PSF entry missing field angle attribute (tried 'FA' and 'field_angle')")

                # loadmat already returns numpy arrays; OpticalPSFEntry's
                # __post_init__ handles remaining conversion. atleast_1d +
                # ravel guarantee 1-D arrays.
                psf_entries.append(
                    OpticalPSFEntry(
                        data=entry.data,
                        x=np.atleast_1d(entry.x).ravel(),
                        field_angle=np.atleast_1d(field_angle).ravel(),
                    )
                )

            logger.info(f"Loaded {len(psf_entries)} optical PSF entries from {mat_file.name}")
            return psf_entries

    available_keys = [k for k in mat_data.keys() if not k.startswith("__")]
    raise KeyError(f"No PSF data found in {mat_file.name}. Available keys: {available_keys}")
def load_los_vectors_from_mat(mat_file: str | Path, key: str = "b_HS") -> np.ndarray:
    """
    Load line-of-sight vectors from MATLAB .mat file.

    Parameters
    ----------
    mat_file : str or Path
        Path to MATLAB file with LOS vectors (local path or ``s3://`` URI).
    key : str, default="b_HS"
        Primary key to try for LOS data.

    Returns
    -------
    np.ndarray
        LOS unit vectors in instrument frame, shape (n_pixels, 3).

    Raises
    ------
    FileNotFoundError
        If mat_file is a local path and doesn't exist.
    ImportError
        If mat_file is an S3 URI and boto3 is not installed.
    KeyError
        If no LOS vectors found with common key names.
    """
    from scipy.io import loadmat

    from curryer.correction.io import resolve_path

    # resolve_path validates local existence / downloads from S3 as needed.
    mat_file = resolve_path(mat_file)
    mat_data = loadmat(str(mat_file))

    # Probe the preferred key first, then the common fallbacks.
    found = next((k for k in (key, "b_HS", "los_vectors", "pixel_vectors") if k in mat_data), None)
    if found is None:
        available_keys = [k for k in mat_data.keys() if not k.startswith("__")]
        raise KeyError(f"No LOS vectors found in {mat_file.name}. Available keys: {available_keys}")

    los = mat_data[found]

    # Normalise orientation: stored as (3, n_pixels) → transpose to (n_pixels, 3).
    if los.shape[0] == 3 and los.shape[1] > 3:
        los = los.T

    logger.info(f"Loaded LOS vectors from {mat_file.name}: shape {los.shape}")
    return los
def save_image_grid_to_netcdf(
    filepath: Path,
    image_grid: ImageGrid,
    metadata: dict[str, str] | None = None,
    compression: bool = True,
) -> None:
    """
    Save ImageGrid to NetCDF file (CF-1.8 compliant).

    This function can be used for any ImageGrid, including regridded GCP chips,
    L1A subimages, or other gridded data.

    Parameters
    ----------
    filepath : Path
        Output NetCDF file path.
    image_grid : ImageGrid
        Image data with lat/lon coordinates to save.
    metadata : dict[str, str], optional
        Additional global attributes to include in the NetCDF file.
        Common keys: 'source_file', 'mission', 'sensor', 'processing_date', 'band'.
    compression : bool, default=True
        Enable zlib compression for data variables (~50% size reduction).

    Raises
    ------
    ImportError
        If netCDF4 is not installed.
    OSError
        If file cannot be written.

    Notes
    -----
    The NetCDF file follows CF-1.8 conventions and contains:
    - Variables: 'band_data', 'lat', 'lon', 'h' (if available)
    - Dimensions: 'y' (rows), 'x' (columns)
    - Coordinates: lat(y, x) or (y,), lon(y, x) or (x,)
    - Attributes: grid info, CRS metadata, processing information

    For regular grids (1D coordinates), variables are stored efficiently as:
    - lat(y): Single column
    - lon(x): Single row

    For irregular grids (2D coordinates), full arrays are stored:
    - lat(y, x): Full grid
    - lon(y, x): Full grid

    Examples
    --------
    Save regridded GCP chip:

    >>> save_image_grid_to_netcdf(
    ...     Path("regridded.nc"),
    ...     regridded_chip,
    ...     metadata={
    ...         'source_file': 'LT08CHP.20140803.p002r071.c01.v001.hdf',
    ...         'mission': 'CLARREO Pathfinder',
    ...         'sensor': 'Landsat-8',
    ...         'band': 'red',
    ...         'processing_date': '2026-02-02'
    ...     }
    ... )

    Save L1A subimage:

    >>> save_image_grid_to_netcdf(
    ...     Path("l1a_subimage.nc"),
    ...     l1a_grid,
    ...     metadata={'mission': 'CLARREO', 'level': 'L1A'}
    ... )
    """
    try:
        import datetime

        from netCDF4 import Dataset
    except ImportError as e:
        raise ImportError("netCDF4 is required to save NetCDF files. Install with: pip install netCDF4") from e

    filepath = Path(filepath)
    logger.info(f"Saving ImageGrid to NetCDF: {filepath}")

    # Create NetCDF file
    with Dataset(filepath, "w", format="NETCDF4") as nc:
        # Global attributes. NOTE: defaults are written first, so any key in
        # `metadata` (including 'title') overrides them below.
        nc.setncattr("title", "Regridded GCP Chip")
        nc.setncattr("institution", "NASA Langley Research Center")
        nc.setncattr("source", "Curryer GCP Regridding Module")
        nc.setncattr(
            "history", f"Created {datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')}"
        )
        nc.setncattr("Conventions", "CF-1.8")
        nc.setncattr("grid_type", "regular_lat_lon")

        # Add user-provided metadata
        if metadata:
            for key, value in metadata.items():
                nc.setncattr(key, str(value))

        # Create dimensions
        nrows, ncols = image_grid.data.shape
        nc.createDimension("y", nrows)
        nc.createDimension("x", ncols)

        # Compression settings
        comp_kwargs = {"zlib": True, "complevel": 4} if compression else {}

        # Determine if coordinates are 1D or 2D. A 2D array whose columns
        # (rows) are all identical is treated as effectively 1D so the file
        # stores the compact representation.
        lat_is_1d = image_grid.lat.ndim == 1 or (
            image_grid.lat.ndim == 2 and np.allclose(image_grid.lat, image_grid.lat[:, 0:1])
        )
        lon_is_1d = image_grid.lon.ndim == 1 or (
            image_grid.lon.ndim == 2 and np.allclose(image_grid.lon, image_grid.lon[0:1, :])
        )

        # Create coordinate variables
        # Variable names match load_gcp_chip_from_netcdf defaults: lat, lon, band_data, h
        if lat_is_1d:
            # 1D latitude (varies with y only)
            lat_var = nc.createVariable("lat", "f8", ("y",), **comp_kwargs)
            lat_var[:] = image_grid.lat[:, 0] if image_grid.lat.ndim == 2 else image_grid.lat
        else:
            # 2D latitude
            lat_var = nc.createVariable("lat", "f8", ("y", "x"), **comp_kwargs)
            lat_var[:] = image_grid.lat

        lat_var.units = "degrees_north"
        lat_var.long_name = "latitude"
        lat_var.standard_name = "latitude"

        if lon_is_1d:
            # 1D longitude (varies with x only)
            lon_var = nc.createVariable("lon", "f8", ("x",), **comp_kwargs)
            lon_var[:] = image_grid.lon[0, :] if image_grid.lon.ndim == 2 else image_grid.lon
        else:
            # 2D longitude
            lon_var = nc.createVariable("lon", "f8", ("y", "x"), **comp_kwargs)
            lon_var[:] = image_grid.lon

        lon_var.units = "degrees_east"
        lon_var.long_name = "longitude"
        lon_var.standard_name = "longitude"

        # Create data variable
        data_var = nc.createVariable("band_data", "f8", ("y", "x"), fill_value=np.nan, **comp_kwargs)
        data_var[:] = image_grid.data
        data_var.long_name = "regridded_radiance"
        data_var.units = "DN"
        data_var.coordinates = "lat lon"
        data_var.grid_mapping = "crs"

        # Add grid statistics as attributes
        valid_mask = ~np.isnan(image_grid.data)
        valid_pixels = int(np.sum(valid_mask))
        data_var.setncattr("valid_pixels", valid_pixels)
        if valid_pixels > 0:
            data_var.setncattr("valid_min", float(np.nanmin(image_grid.data)))
            data_var.setncattr("valid_max", float(np.nanmax(image_grid.data)))

        # Add height if available
        if image_grid.h is not None:
            h_var = nc.createVariable("h", "f8", ("y", "x"), fill_value=np.nan, **comp_kwargs)
            h_var[:] = image_grid.h
            h_var.units = "meters"
            h_var.long_name = "height_above_reference_ellipsoid"
            h_var.standard_name = "height_above_reference_ellipsoid"
            h_var.coordinates = "lat lon"

        # Add CRS information (WGS84)
        crs = nc.createVariable("crs", "i4")
        crs.grid_mapping_name = "latitude_longitude"
        crs.semi_major_axis = 6378137.0
        crs.inverse_flattening = 298.257223563
        crs.long_name = "WGS84"
        crs.crs_wkt = 'GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["degree",0.0174532925199433]]'

    logger.info(f"NetCDF file saved successfully: {filepath} ({filepath.stat().st_size / 1024:.1f} KB)")
def load_image_grid_from_netcdf(filepath: Path) -> ImageGrid:
    """
    Load ImageGrid from NetCDF file.

    Loads regridded GCP chips or other gridded data saved in NetCDF format.
    Automatically handles both 1D and 2D coordinate arrays, and accepts both
    canonical variable names ('band_data', 'lat', 'lon', 'h') and legacy
    names ('data', 'latitude', 'longitude', 'height').

    Parameters
    ----------
    filepath : Path
        Input NetCDF file path.

    Returns
    -------
    ImageGrid
        Loaded image grid with data, lat, lon, and h (if available).

    Raises
    ------
    ImportError
        If netCDF4 is not installed.
    FileNotFoundError
        If filepath doesn't exist.
    KeyError
        If required variables not found in file.

    Examples
    --------
    >>> regridded = load_image_grid_from_netcdf(Path("regridded.nc"))
    >>> regridded.data.shape
    (421, 433)
    """
    try:
        from netCDF4 import Dataset
    except ImportError as e:
        raise ImportError("netCDF4 is required to load NetCDF files. Install with: pip install netCDF4") from e

    filepath = Path(filepath)

    if not filepath.exists():
        raise FileNotFoundError(f"NetCDF file not found: {filepath}")

    logger.info(f"Loading ImageGrid from NetCDF: {filepath}")

    with Dataset(filepath, "r") as nc:
        # Load data — try canonical name first, then legacy name. Error
        # messages list every accepted name so a mismatch is diagnosable.
        data_name = next((n for n in ("band_data", "data") if n in nc.variables), None)
        if data_name is None:
            raise KeyError(f"Required data variable not found in {filepath.name} (tried 'band_data', 'data')")
        data = nc.variables[data_name][:]

        # Load coordinates — try canonical names first, then legacy names
        lat_name = next((n for n in ("lat", "latitude") if n in nc.variables), None)
        lon_name = next((n for n in ("lon", "longitude") if n in nc.variables), None)
        if lat_name is None:
            raise KeyError(f"Required latitude variable not found in {filepath.name} (tried 'lat', 'latitude')")
        if lon_name is None:
            raise KeyError(f"Required longitude variable not found in {filepath.name} (tried 'lon', 'longitude')")

        lat_var = nc.variables[lat_name]
        lon_var = nc.variables[lon_name]

        # Handle 1D or 2D coordinates.
        # NOTE(review): dimensionality is decided from latitude alone —
        # assumes lat and lon are either both 1D or both 2D; confirm writers
        # never mix the two forms.
        if lat_var.ndim == 1:
            # Expand 1D to 2D
            lat_1d = lat_var[:]
            lon_1d = lon_var[:]
            lon, lat = np.meshgrid(lon_1d, lat_1d)
        else:
            lat = lat_var[:]
            lon = lon_var[:]

        # Load height if available — try canonical name first, then legacy name
        h_name = next((n for n in ("h", "height") if n in nc.variables), None)
        h = nc.variables[h_name][:] if h_name is not None else None

    logger.info(f"Loaded ImageGrid from NetCDF: shape {data.shape}")
    return ImageGrid(data=data, lat=lat, lon=lon, h=h)
def load_gcp_chip_from_hdf(
    filepath: Path,
    band_name: str = "Band_1",
    coord_names: tuple[str, str, str] = (
        "ECR_x_coordinate_array",
        "ECR_y_coordinate_array",
        "ECR_z_coordinate_array",
    ),
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load raw GCP chip data from HDF file (Landsat format).

    Supports both HDF4 and HDF5 formats. Tries HDF4 first (Landsat standard),
    then falls back to HDF5 if needed.

    Parameters
    ----------
    filepath : Path
        Path to HDF file containing GCP chip data.
    band_name : str, default="Band_1"
        Name of the dataset containing band/radiometric data.
    coord_names : tuple[str, str, str], default=("ECR_x_coordinate_array", ...)
        Names of X, Y, Z coordinate datasets (ECEF coordinates in meters).

    Returns
    -------
    band_data : np.ndarray
        2D array of radiometric values, shape (nrows, ncols).
    ecef_x, ecef_y, ecef_z : np.ndarray
        2D arrays of ECEF coordinates in meters, each shape (nrows, ncols).

    Raises
    ------
    FileNotFoundError
        If filepath doesn't exist.
    ImportError
        If neither pyhdf (HDF4) nor h5py (HDF5) is available.
    KeyError
        If required datasets not found in HDF file.
    ValueError
        If array shapes are inconsistent.

    Examples
    --------
    >>> band, x, y, z = load_gcp_chip_from_hdf("LT08CHP.20140803.hdf")
    >>> band.shape
    (1400, 1400)
    """
    if not filepath.exists():
        raise FileNotFoundError(f"HDF file not found: {filepath}")

    coord_x_name, coord_y_name, coord_z_name = coord_names

    band_data = None
    hdf4_error: Exception | None = None

    # --- Try HDF4 first (Landsat standard format) ---------------------------
    try:
        from pyhdf.SD import SD, SDC
    except ImportError:
        pass  # pyhdf not available; fall through to the h5py attempt below
    else:
        try:
            hdf = SD(str(filepath), SDC.READ)
            try:
                datasets = hdf.datasets()

                # Check for required datasets
                if band_name not in datasets:
                    available = list(datasets.keys())
                    raise KeyError(f"Band '{band_name}' not found. Available datasets: {available}")
                if coord_x_name not in datasets:
                    raise KeyError(f"X coordinate '{coord_x_name}' not found in {filepath.name}")
                if coord_y_name not in datasets:
                    raise KeyError(f"Y coordinate '{coord_y_name}' not found in {filepath.name}")
                if coord_z_name not in datasets:
                    raise KeyError(f"Z coordinate '{coord_z_name}' not found in {filepath.name}")

                # Load datasets
                band_data = np.array(hdf.select(band_name).get(), dtype=np.float64)
                ecef_x = np.array(hdf.select(coord_x_name).get(), dtype=np.float64)
                ecef_y = np.array(hdf.select(coord_y_name).get(), dtype=np.float64)
                ecef_z = np.array(hdf.select(coord_z_name).get(), dtype=np.float64)
            finally:
                # Always release the HDF4 handle, even when a dataset is
                # missing or a read fails (previously leaked on error).
                hdf.end()

            logger.info(f"Loaded GCP chip from HDF4 file {filepath.name}: shape {band_data.shape}")
        except Exception as e:
            # pyhdf raised (possibly HDF4Error because the file is actually
            # HDF5); remember the error and try h5py below.
            hdf4_error = e
            band_data = None

    # --- Fall back to HDF5 if HDF4 failed or pyhdf is not installed ---------
    if band_data is None:
        try:
            import h5py
        except ImportError as e:
            # Neither library available
            error_msg = f"Cannot read HDF file {filepath}. Neither pyhdf (for HDF4) nor h5py (for HDF5) is available."
            if hdf4_error:
                error_msg += f" HDF4 error: {hdf4_error}"
            raise ImportError(error_msg) from e

        with h5py.File(filepath, "r") as hdf:
            # Load band data
            if band_name not in hdf:
                available = list(hdf.keys())
                raise KeyError(f"Band '{band_name}' not found. Available datasets: {available}")

            band_data = np.array(hdf[band_name], dtype=np.float64)

            # Load ECEF coordinates
            if coord_x_name not in hdf:
                raise KeyError(f"X coordinate '{coord_x_name}' not found in {filepath.name}")
            if coord_y_name not in hdf:
                raise KeyError(f"Y coordinate '{coord_y_name}' not found in {filepath.name}")
            if coord_z_name not in hdf:
                raise KeyError(f"Z coordinate '{coord_z_name}' not found in {filepath.name}")

            ecef_x = np.array(hdf[coord_x_name], dtype=np.float64)
            ecef_y = np.array(hdf[coord_y_name], dtype=np.float64)
            ecef_z = np.array(hdf[coord_z_name], dtype=np.float64)

        logger.info(f"Loaded GCP chip from HDF5 file {filepath.name}: shape {band_data.shape}")

    # Validate shapes — now applied to BOTH paths (previously the HDF4 path
    # returned early and skipped this check).
    if not (band_data.shape == ecef_x.shape == ecef_y.shape == ecef_z.shape):
        raise ValueError(
            f"Array shape mismatch in {filepath.name}: "
            f"band={band_data.shape}, x={ecef_x.shape}, y={ecef_y.shape}, z={ecef_z.shape}"
        )

    return band_data, ecef_x, ecef_y, ecef_z
Available datasets: {available}") + + band_data = np.array(hdf[band_name], dtype=np.float64) + + # Load ECEF coordinates + if coord_x_name not in hdf: + raise KeyError(f"X coordinate '{coord_x_name}' not found in {filepath.name}") + if coord_y_name not in hdf: + raise KeyError(f"Y coordinate '{coord_y_name}' not found in {filepath.name}") + if coord_z_name not in hdf: + raise KeyError(f"Z coordinate '{coord_z_name}' not found in {filepath.name}") + + ecef_x = np.array(hdf[coord_x_name], dtype=np.float64) + ecef_y = np.array(hdf[coord_y_name], dtype=np.float64) + ecef_z = np.array(hdf[coord_z_name], dtype=np.float64) + + logger.info(f"Loaded GCP chip from HDF5 file {filepath.name}: shape {band_data.shape}") + + except ImportError as e: + # Neither library available + error_msg = f"Cannot read HDF file {filepath}. Neither pyhdf (for HDF4) nor h5py (for HDF5) is available." + if hdf4_error: + error_msg += f" HDF4 error: {hdf4_error}" + raise ImportError(error_msg) from e + except (KeyError, OSError, ValueError): + # Re-raise validation/IO errors from h5py + raise + + # Validate shapes + if not (band_data.shape == ecef_x.shape == ecef_y.shape == ecef_z.shape): + raise ValueError( + f"Array shape mismatch in {filepath.name}: " + f"band={band_data.shape}, x={ecef_x.shape}, y={ecef_y.shape}, z={ecef_z.shape}" + ) + + return band_data, ecef_x, ecef_y, ecef_z + + +def load_gcp_chip_from_netcdf( + filepath: Path, + band_var: str = "band_data", + lat_var: str = "lat", + lon_var: str = "lon", + height_var: str = "h", +) -> ImageGrid: + """ + Load regridded GCP chip from NetCDF file. + + Parameters + ---------- + filepath : Path + Path to NetCDF file. + band_var : str, default="band_data" + Name of the band data variable. + lat_var : str, default="lat" + Name of the latitude variable. + lon_var : str, default="lon" + Name of the longitude variable. + height_var : str, default="h" + Name of the height variable (optional). 
+
+
+    Returns
+    -------
+    ImageGrid
+        Loaded image grid with data, lat, lon, h fields.
+
+    Raises
+    ------
+    FileNotFoundError
+        If filepath doesn't exist.
+    OSError
+        If the file cannot be read or a required variable is missing.
+
+    Examples
+    --------
+    >>> gcp = load_gcp_chip_from_netcdf("regridded_chip.nc")
+    >>> gcp.data.shape
+    (420, 420)
+    """
+    import xarray as xr
+
+    if not filepath.exists():
+        raise FileNotFoundError(f"NetCDF file not found: {filepath}")
+
+    try:
+        ds = xr.open_dataset(filepath)
+
+        # Check required variables
+        if band_var not in ds:
+            raise KeyError(f"Band variable '{band_var}' not found in {filepath.name}")
+        if lat_var not in ds:
+            raise KeyError(f"Latitude variable '{lat_var}' not found in {filepath.name}")
+        if lon_var not in ds:
+            raise KeyError(f"Longitude variable '{lon_var}' not found in {filepath.name}")
+
+        # Load data
+        data = ds[band_var].values
+        lat = ds[lat_var].values
+        lon = ds[lon_var].values
+        h = ds[height_var].values if height_var in ds else None
+
+        ds.close()
+
+    except Exception as e:
+        raise OSError(f"Error reading NetCDF file {filepath}: {e}") from e
+
+    logger.info(f"Loaded GCP chip from {filepath.name}: shape {data.shape}")
+
+    return ImageGrid(data=data, lat=lat, lon=lon, h=h)
+
+
+# ============================================================================
+# Generic Image Savers (new for regridding + general use)
+# ============================================================================
+
+
+def save_image_grid(
+    filepath: Path,
+    image_grid: ImageGrid,
+    format: str = "netcdf",
+    metadata: dict | None = None,
+) -> None:
+    """
+    Save ImageGrid to file (netcdf, mat, or geotiff).
+
+    This function works with any ImageGrid, not just regridded GCPs.
+    It can be used throughout the correction pipeline.
+
+    Parameters
+    ----------
+    filepath : Path
+        Output file path.
+    image_grid : ImageGrid
+        Image data with lat/lon coordinates.
+    format : str, default="netcdf"
+        Output format: 'netcdf', 'mat', or 'geotiff'. 
+ metadata : dict, optional + Additional metadata to include in output. + + Raises + ------ + ValueError + If format is not supported. + IOError + If file cannot be written. + + Examples + -------- + >>> save_image_grid("output.nc", regridded_gcp, format="netcdf") + >>> save_image_grid("output.mat", regridded_gcp, format="mat") + """ + format = format.lower() + + if format == "netcdf": + _save_image_grid_netcdf(filepath, image_grid, metadata) + elif format == "mat": + _save_image_grid_mat(filepath, image_grid, metadata) + elif format == "geotiff": + _save_image_grid_geotiff(filepath, image_grid, metadata) + else: + raise ValueError(f"Unsupported format: {format}. Supported: 'netcdf', 'mat', 'geotiff'") + + logger.info(f"Saved ImageGrid to {filepath} (format: {format})") + + +def _save_image_grid_netcdf(filepath: Path, image_grid: ImageGrid, metadata: dict | None) -> None: + """Save ImageGrid to NetCDF file (internal helper).""" + import xarray as xr + + # Create xarray Dataset + nrows, ncols = image_grid.data.shape + + ds = xr.Dataset( + { + "band_data": (["y", "x"], image_grid.data), + "lat": (["y", "x"], image_grid.lat), + "lon": (["y", "x"], image_grid.lon), + }, + coords={ + "y": np.arange(nrows), + "x": np.arange(ncols), + }, + ) + + # Add height if present + if image_grid.h is not None: + ds["h"] = (["y", "x"], image_grid.h) + + # Add metadata + if metadata: + ds.attrs.update(metadata) + + # Add standard attributes + ds.attrs["title"] = "Regridded GCP Chip" + ds.attrs["Conventions"] = "CF-1.8" + + # Add variable attributes + ds["band_data"].attrs["long_name"] = "Band radiometric data" + ds["band_data"].attrs["units"] = "digital_number" + ds["lat"].attrs["long_name"] = "Latitude" + ds["lat"].attrs["units"] = "degrees_north" + ds["lon"].attrs["long_name"] = "Longitude" + ds["lon"].attrs["units"] = "degrees_east" + + if "h" in ds: + ds["h"].attrs["long_name"] = "Height above ellipsoid" + ds["h"].attrs["units"] = "meters" + + # Write to file + try: + 
ds.to_netcdf(filepath, engine="netcdf4") + except Exception as e: + raise OSError(f"Error writing NetCDF file {filepath}: {e}") from e + + +def _save_image_grid_mat(filepath: Path, image_grid: ImageGrid, metadata: dict | None) -> None: + """Save ImageGrid to MATLAB .mat file (internal helper).""" + from scipy.io import savemat + + # Prepare data structure + mat_dict = { + "data": image_grid.data, + "lat": image_grid.lat, + "lon": image_grid.lon, + } + + if image_grid.h is not None: + mat_dict["h"] = image_grid.h + + if metadata: + mat_dict["metadata"] = metadata + + # Write to file + try: + savemat(filepath, {"GCP": mat_dict}, do_compression=True) + except Exception as e: + raise OSError(f"Error writing MAT file {filepath}: {e}") from e + + +def _save_image_grid_geotiff(filepath: Path, image_grid: ImageGrid, metadata: dict | None) -> None: + """Save ImageGrid to GeoTIFF file (internal helper).""" + try: + import rasterio + from rasterio.transform import from_bounds + except ImportError as e: + raise ImportError("GeoTIFF export requires 'rasterio'. 
Install with: pip install rasterio") from e + + nrows, ncols = image_grid.data.shape + + # Compute geographic bounds + minlon = image_grid.lon.min() + maxlon = image_grid.lon.max() + minlat = image_grid.lat.min() + maxlat = image_grid.lat.max() + + # Create affine transform + transform = from_bounds(minlon, minlat, maxlon, maxlat, ncols, nrows) + + # Write to file + try: + with rasterio.open( + filepath, + "w", + driver="GTiff", + height=nrows, + width=ncols, + count=1, + dtype=image_grid.data.dtype, + crs="EPSG:4326", # WGS84 + transform=transform, + compress="lzw", + ) as dst: + dst.write(image_grid.data, 1) + + # Add metadata as tags + if metadata: + dst.update_tags(**metadata) + + except Exception as e: + raise OSError(f"Error writing GeoTIFF file {filepath}: {e}") from e diff --git a/curryer/correction/image_match.py b/curryer/correction/image_match.py index 507bdd0f..8916bc16 100644 --- a/curryer/correction/image_match.py +++ b/curryer/correction/image_match.py @@ -1,20 +1,21 @@ from __future__ import annotations import logging +import warnings from collections.abc import Iterable from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Protocol +from typing import TYPE_CHECKING import numpy as np from .data_structures import ( - GeolocationConfig, ImageGrid, NamedImageGrid, OpticalPSFEntry, ProjectedPSF, PSFGrid, + PSFSamplingConfig, SearchConfig, ) from .psf import ( @@ -28,60 +29,16 @@ from .search import im_search if TYPE_CHECKING: - import pandas as pd import xarray as xr logger = logging.getLogger(__name__) # ============================================================================ -# Image Matching Interface Protocol +# Output Validation # ============================================================================ -class ImageMatchingFunc(Protocol): - """ - Protocol for image matching functions in Correction pipeline. 
- - Image matching functions perform spatial correlation between geolocated - observations and GCP reference imagery to measure geolocation errors. - - Standard Signature: - def image_matching( - geolocated_data: xr.Dataset, - gcp_reference_file: Path, - telemetry: pd.DataFrame, - calibration_dir: Path, - params_info: list, - config, - los_vectors_cached: Optional[np.ndarray] = None, - optical_psfs_cached: Optional[list] = None, - ) -> xr.Dataset - - Required Output Fields: - - lat_error_deg: (measurement,) Latitude errors in degrees - - lon_error_deg: (measurement,) Longitude errors in degrees - - gcp_lat_deg, gcp_lon_deg, gcp_alt: GCP location - - Spacecraft state: position, boresight, transformation matrix - - See correction.image_matching() for reference implementation. - """ - - def __call__( - self, - geolocated_data: xr.Dataset, - gcp_reference_file: Path, - telemetry: pd.DataFrame, - calibration_dir: Path, - params_info: list, - config, - los_vectors_cached: np.ndarray | None = None, - optical_psfs_cached: list | None = None, - ) -> xr.Dataset: - """Perform image matching and return error measurements.""" - ... - - def validate_image_matching_output(output: xr.Dataset) -> None: """ Validate image matching output conforms to expected format. @@ -156,7 +113,7 @@ def integrated_image_match( r_iss_midframe_m: np.ndarray, los_vectors_hs: np.ndarray, optical_psfs: Iterable[OpticalPSFEntry], - geolocation_config: GeolocationConfig | None = None, + geolocation_config: PSFSamplingConfig | None = None, search_config: SearchConfig | None = None, ) -> IntegratedImageMatchResult: """ @@ -177,7 +134,7 @@ def integrated_image_match( Line-of-sight vectors in instrument frame, shape (n_pixels, 3). optical_psfs : Iterable[OpticalPSFEntry] Optical PSF samples at different field angles. - geolocation_config : GeolocationConfig, optional + geolocation_config : PSFSamplingConfig, optional PSF geolocation parameters. Defaults to standard config. 
search_config : SearchConfig, optional Image search parameters. Defaults to standard config. @@ -188,7 +145,7 @@ def integrated_image_match( Geolocation errors, correlation value, and intermediate products. """ - geo_config = geolocation_config or GeolocationConfig() + geo_config = geolocation_config or PSFSamplingConfig() search_cfg = search_config or SearchConfig() logger.debug("Projecting the PSF...") @@ -233,7 +190,11 @@ def integrated_image_match( # ============================================================================ -# MATLAB File Loading Utilities +# MATLAB File Loading Utilities — Deprecation Shims +# +# These functions have been moved to curryer.correction.image_io. +# They are kept here for backward compatibility and will be removed in a +# future release. # ============================================================================ @@ -243,176 +204,51 @@ def load_image_grid_from_mat( """ Load ImageGrid from MATLAB .mat file. - Parameters - ---------- - mat_file : Path - Path to .mat file. - key : str, default="subimage" - MATLAB struct key (e.g., "subimage" for L1A, "GCP" for reference). - name : str, optional - Name for NamedImageGrid. Defaults to file path. - as_named : bool, default=False - If True, return NamedImageGrid; otherwise return ImageGrid. - - Returns - ------- - ImageGrid or NamedImageGrid - Loaded image grid with data, lat, lon, h fields. - - Raises - ------ - FileNotFoundError - If mat_file doesn't exist. - KeyError - If key not found in MATLAB file. - - Examples - -------- - >>> # Load L1A subimage - >>> l1a = load_image_grid_from_mat(Path("subimage.mat"), key="subimage") - >>> # Load GCP reference - >>> gcp = load_image_grid_from_mat(Path("gcp.mat"), key="GCP") + .. deprecated:: + Use :func:`curryer.correction.image_io.load_image_grid_from_mat` instead. 
""" - from scipy.io import loadmat - - if not mat_file.exists(): - raise FileNotFoundError(f"MATLAB file not found: {mat_file}") - - mat_data = loadmat(str(mat_file), squeeze_me=True, struct_as_record=False) - - if key not in mat_data: - available_keys = [k for k in mat_data.keys() if not k.startswith("__")] - raise KeyError(f"Key '{key}' not found in {mat_file.name}. Available keys: {available_keys}") - - struct = mat_data[key] - h = getattr(struct, "h", None) - - # Optimize: loadmat already returns numpy arrays, avoid redundant asarray() calls - # ImageGrid.__post_init__ will handle final type conversion - grid_kwargs = { - "data": struct.data, - "lat": struct.lat, - "lon": struct.lon, - "h": h, - } + warnings.warn( + "load_image_grid_from_mat has moved to curryer.correction.image_io. " + "Importing from image_match is deprecated and will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) + from .image_io import load_image_grid_from_mat as _impl - if as_named: - grid_kwargs["name"] = name or str(mat_file) - return NamedImageGrid(**grid_kwargs) - else: - return ImageGrid(**grid_kwargs) + return _impl(mat_file, key=key, name=name, as_named=as_named) def load_optical_psf_from_mat(mat_file: Path, key: str = "PSF_struct_675nm") -> list[OpticalPSFEntry]: """ Load optical PSF entries from MATLAB .mat file. - Parameters - ---------- - mat_file : Path - Path to MATLAB file with PSF data. - key : str, default="PSF_struct_675nm" - Primary key to try for PSF data. - - Returns - ------- - list[OpticalPSFEntry] - Optical PSF samples with data, x, and field_angle arrays. - - Raises - ------ - FileNotFoundError - If mat_file doesn't exist. - KeyError - If no PSF data found with common key names. - ValueError - If PSF entries missing field angle attribute. + .. deprecated:: + Use :func:`curryer.correction.image_io.load_optical_psf_from_mat` instead. 
""" - from scipy.io import loadmat - - if not mat_file.exists(): - raise FileNotFoundError(f"Optical PSF file not found: {mat_file}") - - mat_data = loadmat(str(mat_file), squeeze_me=True, struct_as_record=False) - - # Try common keys in order of preference - for try_key in [key, "PSF_struct_675nm", "optical_PSF", "PSF"]: - if try_key in mat_data: - psf_struct = mat_data[try_key] - psf_entries_raw = np.atleast_1d(psf_struct) - - psf_entries = [] - for entry in psf_entries_raw: - # Handle both 'FA' and 'field_angle' attribute names - # Check if attribute exists first to avoid NumPy array boolean ambiguity - field_angle = getattr(entry, "FA", None) - if field_angle is None or ( - isinstance(field_angle, list | tuple | np.ndarray) and len(field_angle) == 0 - ): - # Fallback if FA is missing, None, or empty - field_angle = getattr(entry, "field_angle", None) - - if field_angle is None: - raise ValueError(f"PSF entry missing field angle attribute (tried 'FA' and 'field_angle')") - - # Optimize: loadmat already returns numpy arrays, OpticalPSFEntry.__post_init__ handles conversion - # Use np.atleast_1d to ensure 1D arrays efficiently - psf_entries.append( - OpticalPSFEntry( - data=entry.data, - x=np.atleast_1d(entry.x).ravel(), - field_angle=np.atleast_1d(field_angle).ravel(), - ) - ) - - logger.info(f"Loaded {len(psf_entries)} optical PSF entries from {mat_file.name}") - return psf_entries - - available_keys = [k for k in mat_data.keys() if not k.startswith("__")] - raise KeyError(f"No PSF data found in {mat_file.name}. Available keys: {available_keys}") + warnings.warn( + "load_optical_psf_from_mat has moved to curryer.correction.image_io. 
" + "Importing from image_match is deprecated and will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) + from .image_io import load_optical_psf_from_mat as _impl + + return _impl(mat_file, key=key) def load_los_vectors_from_mat(mat_file: Path, key: str = "b_HS") -> np.ndarray: """ Load line-of-sight vectors from MATLAB .mat file. - Parameters - ---------- - mat_file : Path - Path to MATLAB file with LOS vectors. - key : str, default="b_HS" - Primary key to try for LOS data. - - Returns - ------- - np.ndarray - LOS unit vectors in instrument frame, shape (n_pixels, 3). - - Raises - ------ - FileNotFoundError - If mat_file doesn't exist. - KeyError - If no LOS vectors found with common key names. + .. deprecated:: + Use :func:`curryer.correction.image_io.load_los_vectors_from_mat` instead. """ - from scipy.io import loadmat - - if not mat_file.exists(): - raise FileNotFoundError(f"LOS vector file not found: {mat_file}") - - mat_data = loadmat(str(mat_file)) - - # Try common keys in order of preference - for try_key in [key, "b_HS", "los_vectors", "pixel_vectors"]: - if try_key in mat_data: - los = mat_data[try_key] - - # Ensure shape is (n_pixels, 3) not (3, n_pixels) - if los.shape[0] == 3 and los.shape[1] > 3: - los = los.T - - logger.info(f"Loaded LOS vectors from {mat_file.name}: shape {los.shape}") - return los + warnings.warn( + "load_los_vectors_from_mat has moved to curryer.correction.image_io. " + "Importing from image_match is deprecated and will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) + from .image_io import load_los_vectors_from_mat as _impl - available_keys = [k for k in mat_data.keys() if not k.startswith("__")] - raise KeyError(f"No LOS vectors found in {mat_file.name}. 
Available keys: {available_keys}") + return _impl(mat_file, key=key) diff --git a/curryer/correction/io.py b/curryer/correction/io.py new file mode 100644 index 00000000..fcdfcc34 --- /dev/null +++ b/curryer/correction/io.py @@ -0,0 +1,158 @@ +"""Unified path resolution for the correction pipeline. + +All file loading in the correction package should go through resolve_path() +to transparently handle local paths and S3 URIs. + +S3 access follows the same pattern as :mod:`curryer.correction.dataio`: +callers may provide an explicit ``s3_client`` (useful for testing) or rely +on the default boto3 client. +""" + +from __future__ import annotations + +import atexit +import logging +import tempfile +from pathlib import Path + +logger = logging.getLogger(__name__) + +# Registry of temp files created by S3 downloads. Cleaned up at process exit. +_temp_files: list[Path] = [] + + +def _cleanup_temp_files() -> None: + """Remove temporary files created by S3 downloads.""" + for path in _temp_files: + try: + if path.exists(): + path.unlink() + logger.debug("Cleaned up temp file: %s", path) + except OSError as exc: + logger.warning("Failed to clean up temp file %s: %s", path, exc) + _temp_files.clear() + + +atexit.register(_cleanup_temp_files) + + +def _require_client(client: object | None) -> object: + """Return *client* if given, otherwise create a default boto3 S3 client. + + This is intentionally identical to :func:`curryer.correction.dataio._require_client`. + + Parameters + ---------- + client : object or None + An injected S3 client, or None to use the default. + + Returns + ------- + object + A boto3 S3 client. + + Raises + ------ + ImportError + If *client* is None and boto3 is not installed. + """ + if client is not None: + return client + try: + import boto3 + except ImportError: + raise ImportError("S3 paths require boto3. 
Install with: pip install boto3") + return boto3.client("s3") + + +def resolve_path(path: str | Path, *, s3_client=None) -> Path: + """Resolve a file path, downloading from S3 if necessary. + + Parameters + ---------- + path : str or Path + Local file path or S3 URI (``s3://bucket/key``). + s3_client : boto3 S3 client, optional + Injected client for testing. If omitted and *path* is an S3 URI, + a default client is created via boto3. + + Returns + ------- + Path + Local file path. For S3 URIs, a temporary local file that is + cleaned up at process exit via :func:`atexit`. + + Raises + ------ + ImportError + If *path* is an S3 URI and boto3 is not installed. + FileNotFoundError + If *path* is a local path that does not exist. + ValueError + If *path* is an S3 URI with no object key (e.g. ``s3://bucket``). + """ + path_str = str(path) + if path_str.startswith("s3://"): + return _download_from_s3(path_str, s3_client=s3_client) + local_path = Path(path) + if not local_path.exists(): + raise FileNotFoundError(f"File not found: {local_path}") + return local_path + + +def _download_from_s3(s3_uri: str, *, s3_client=None) -> Path: + """Download an S3 object to a local temporary file. + + The temp file is registered for cleanup at process exit via + :func:`_cleanup_temp_files`. + + Parameters + ---------- + s3_uri : str + S3 URI in the form ``s3://bucket/key``. + s3_client : boto3 S3 client, optional + Injected client for testing. + + Returns + ------- + Path + Path to the local temporary file. + + Raises + ------ + ImportError + If boto3 is not installed and no client was injected. + ValueError + If the URI has no object key. 
+ """ + client = _require_client(s3_client) + + # Parse s3://bucket/key + stripped = s3_uri.replace("s3://", "", 1) + parts = stripped.split("/", 1) + bucket = parts[0] + key = parts[1] if len(parts) > 1 else "" + + if not key or not key.strip(): + raise ValueError( + f"S3 URI must include an object key: {s3_uri!r}\nExpected format: s3://bucket/path/to/file.ext" + ) + + # Determine suffix from key for proper file handling + suffix = Path(key).suffix or "" + + logger.info("Downloading from S3: %s", s3_uri) + # Create and immediately close the temp file so the handle is released + # before boto3 writes to it (avoids Windows file-locking errors). + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: + tmp_name = tmp.name + tmp_path = Path(tmp_name) + try: + client.download_file(bucket, key, tmp_name) + except Exception: + # Clean up partial download on failure + tmp_path.unlink(missing_ok=True) + raise + _temp_files.append(tmp_path) + logger.info(" Downloaded to: %s", tmp_path) + return tmp_path diff --git a/curryer/correction/kernel_ops.py b/curryer/correction/kernel_ops.py new file mode 100644 index 00000000..9648dd2b --- /dev/null +++ b/curryer/correction/kernel_ops.py @@ -0,0 +1,297 @@ +"""SPICE kernel file management for the correction pipeline. + +This module creates and applies parameter-specific SPICE kernels: + +- :func:`apply_offset` -- modifies telemetry/science data for + ``OFFSET_KERNEL`` and ``OFFSET_TIME`` parameters. +- :func:`_create_dynamic_kernels` -- writes SC-SPK/SC-CK kernels from + telemetry data (once per image pair, not per parameter set). +- :func:`_create_parameter_kernels` -- writes parameter-specific kernels + and applies time offsets for each iteration. 
+""" + +import logging +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd + +from curryer.correction.config import CorrectionConfig, ParameterConfig, ParameterType +from curryer.kernels import create + +logger = logging.getLogger(__name__) + + +def apply_offset(config: ParameterConfig, param_data, input_data): + """ + Apply parameter offsets to input data based on parameter type. + + Args: + config: ParameterConfig specifying how to apply the offset + param_data: The parameter values to apply (offset amounts) + input_data: The input dataset to modify + + Returns: + Modified copy of input_data with parameter offsets applied + """ + logger.info(f"Applying {config.ptype.name} offset to {config.data.get('field', 'unknown field')}") + + # Make a copy to avoid modifying the original + if isinstance(input_data, pd.DataFrame): + modified_data = input_data.copy() + else: + modified_data = input_data.copy() if hasattr(input_data, "copy") else input_data + + if config.ptype == ParameterType.OFFSET_KERNEL: + # Apply offset to telemetry fields for dynamic kernels (azimuth/elevation angles) + # OFFSET_KERNEL is ONLY for angle biases, not time offsets + # Valid units: "arcseconds" (converted to radians) or None (radians assumed) + # For time offsets, use OFFSET_TIME instead + field_name = config.data.get("field") + if not field_name: + raise ValueError("OFFSET_KERNEL parameter requires 'field' to be specified in config") + + if field_name in modified_data.columns: + # Convert parameter value to appropriate units + # OFFSET_KERNEL is for angle biases only (azimuth/elevation angles) + offset_value = param_data + original_value = offset_value + if config.data.get("units") == "arcseconds": + # Convert arcseconds to radians for application + offset_value = np.deg2rad(param_data / 3600.0) + logger.info(f"✓ Applying OFFSET_KERNEL to field '{field_name}'") + logger.info(f" Offset: {original_value:.6f} arcsec = {offset_value:.9f} rad") + else: + # 
No units specified - assume radians (direct application) + logger.info(f"✓ Applying OFFSET_KERNEL to field '{field_name}'") + logger.info(f" Offset: {offset_value:.9f} rad (no unit conversion)") + + # Store original values for logging + original_mean = modified_data[field_name].mean() + + # Apply additive offset + modified_data[field_name] = modified_data[field_name] + offset_value + + # Log the effect + new_mean = modified_data[field_name].mean() + logger.info(f" Original mean: {original_mean:.9f}") + logger.info(f" New mean: {new_mean:.9f}") + logger.info(f" Delta: {new_mean - original_mean:.9f}") + else: + available_cols = list(modified_data.columns) if hasattr(modified_data, "columns") else [] + logger.warning(f"Field '{field_name}' not found in telemetry data for offset application") + logger.warning(f"Available columns: {available_cols}") + + elif config.ptype == ParameterType.OFFSET_TIME: + # Apply time offset to science frame timing + # NOTE: param_data is in seconds while target field (e.g., corrected_timestamp) is typically in microseconds + field_name = config.data.get("field", "corrected_timestamp") + if hasattr(modified_data, "__getitem__") and field_name in modified_data: + # param_data is already in seconds (converted by load_param_sets) + # Convert seconds to microseconds for the timestamp field + offset_value_seconds = param_data + offset_value_us = param_data * 1000000.0 # seconds to microseconds + + logger.info(f"✓ Applying OFFSET_TIME to field '{field_name}'") + units = config.data.get("units", "seconds") + if units == "milliseconds": + logger.info(f" Offset: {offset_value_seconds * 1000.0:.6f} ms (configured) = {offset_value_us:.6f} µs") + elif units == "microseconds": + logger.info( + f" Offset: {offset_value_seconds * 1000000.0:.6f} µs (configured) = {offset_value_us:.6f} µs" + ) + else: + logger.info(f" Offset: {offset_value_seconds:.6f} s = {offset_value_us:.6f} µs") + + # Store original values for logging + if 
hasattr(modified_data[field_name], "mean"): + original_mean = modified_data[field_name].mean() + else: + original_mean = np.mean(modified_data[field_name]) + + # Apply additive offset in microseconds + modified_data[field_name] = modified_data[field_name] + offset_value_us + + # Log the effect + if hasattr(modified_data[field_name], "mean"): + new_mean = modified_data[field_name].mean() + else: + new_mean = np.mean(modified_data[field_name]) + logger.info(f" Original mean: {original_mean:.6f}") + logger.info(f" New mean: {new_mean:.6f}") + logger.info(f" Delta: {new_mean - original_mean:.6f}") + else: + logger.warning(f"Field '{field_name}' not found in science data for time offset application") + + elif config.ptype == ParameterType.CONSTANT_KERNEL: + # For constant kernels, param_data should already be in the correct format + # (DataFrame with ugps, angle_x, angle_y, angle_z columns) + logger.info( + f"Using constant kernel data with {len(param_data) if hasattr(param_data, '__len__') else 1} entries" + ) + modified_data = param_data + + else: + raise NotImplementedError(f"Parameter type {config.ptype} not implemented") + + return modified_data + + +def _create_dynamic_kernels( + config: "CorrectionConfig", + work_dir: Path, + tlm_dataset: pd.DataFrame, + creator: "create.KernelCreator", +) -> list[Path]: + """Create dynamic SPICE kernels from telemetry data. + + Dynamic kernels (SC-SPK, SC-CK) are generated from spacecraft telemetry + and do not change with parameter variations. In the current implementation, + these are created once per image. 
+ + Parameters + ---------- + config : CorrectionConfig + Configuration with geo settings and dynamic_kernels list + work_dir : Path + Working directory for kernel files + tlm_dataset : pd.DataFrame + Spacecraft state data with position, velocity, attitude, and time columns + creator : create.KernelCreator + KernelCreator instance for writing kernels + + Returns + ------- + list[Path] + List of kernel file paths created (e.g., [sc_ephemeris.bsp, sc_attitude.bc]) + + Examples + -------- + >>> from curryer.kernels import create + >>> creator = create.KernelCreator(overwrite=True, append=False) + >>> dynamic_kernels = _create_dynamic_kernels(config, work_dir, tlm_dataset, creator) + >>> # Use in SPICE context + >>> with sp.ext.load_kernel(dynamic_kernels): + ... # Perform geolocation + ... pass + """ + logger.info(" Creating dynamic kernels from telemetry...") + dynamic_kernels = [] + for kernel_config in config.geo.dynamic_kernels: + dynamic_kernels.append( + creator.write_from_json( + kernel_config, + output_kernel=work_dir, + input_data=tlm_dataset, + ) + ) + logger.info(f" Created {len(dynamic_kernels)} dynamic kernels") + return dynamic_kernels + + +def _create_parameter_kernels( + params: list[tuple["ParameterConfig", Any]], + work_dir: Path, + tlm_dataset: pd.DataFrame, + sci_dataset: pd.DataFrame, + ugps_times: Any, + config: "CorrectionConfig", + creator: "create.KernelCreator", +) -> tuple[list[Path], Any]: + """Create parameter-specific SPICE kernels and apply time offsets. + + This function applies parameter variations by creating modified kernels + (CONSTANT_KERNEL, OFFSET_KERNEL) or modifying time tags (OFFSET_TIME). + Each parameter set produces different kernels and/or time modifications. 
+ + Parameters + ---------- + params : list[tuple[ParameterConfig, Any]] + List of (ParameterConfig, parameter_value) tuples for this iteration + work_dir : Path + Working directory for kernel files + tlm_dataset : pd.DataFrame + Spacecraft state data (may be modified for OFFSET_KERNEL) with position, velocity, attitude, and time columns + sci_dataset : pd.DataFrame + Science frame time data (may be modified for OFFSET_TIME), may include optional measurement columns + ugps_times : array-like + Original time array from science dataset + config : CorrectionConfig + Configuration with geo settings + creator : create.KernelCreator + KernelCreator instance for writing kernels + + Returns + ------- + param_kernels : list[Path] + List of parameter-specific kernel file paths + ugps_times_modified : array-like + Modified time array if OFFSET_TIME applied, otherwise original times + + Examples + -------- + >>> param_kernels, times = _create_parameter_kernels( + ... params, work_dir, tlm_dataset, sci_dataset, ugps_times, config, creator + ... ) + >>> # Use in SPICE context with dynamic kernels + >>> with sp.ext.load_kernel([dynamic_kernels, param_kernels]): + ... 
geo = geolocate(times) + """ + param_kernels = [] + ugps_times_modified = ugps_times.copy() if hasattr(ugps_times, "copy") else ugps_times + + # Apply each individual parameter change + logger.info(" Applying parameter changes:") + for a_param, p_data in params: # [ParameterConfig, typing.Any] + # Log parameter details + param_name = a_param.data.get("field", "unknown") + if a_param.ptype == ParameterType.CONSTANT_KERNEL: + logger.info(f" {a_param.ptype.name}: {param_name} (constant kernel data)") + elif a_param.ptype == ParameterType.OFFSET_KERNEL: + units = a_param.data.get("units", "") + logger.info( + f" {a_param.ptype.name}: {param_name} = {p_data:.6f} " + f"(internal units; configured units: {units or 'unspecified'})" + ) + elif a_param.ptype == ParameterType.OFFSET_TIME: + units = a_param.data.get("units", "") + logger.info( + f" {a_param.ptype.name}: {param_name} = {p_data:.6f} " + f"(internal units; configured units: {units or 'unspecified'})" + ) + + # Create static changing SPICE kernels + if a_param.ptype == ParameterType.CONSTANT_KERNEL: + # Aka: BASE-CK, YOKE-CK, HYSICS-CK + param_kernels.append( + creator.write_from_json( + a_param.config_file, + output_kernel=work_dir, + input_data=p_data, + ) + ) + + # Create dynamic changing SPICE kernels + elif a_param.ptype == ParameterType.OFFSET_KERNEL: + # Aka: AZ-CK, EL-CK + tlm_dataset_alt = apply_offset(a_param, p_data, tlm_dataset) + param_kernels.append( + creator.write_from_json( + a_param.config_file, + output_kernel=work_dir, + input_data=tlm_dataset_alt, + ) + ) + + # Alter non-kernel data + elif a_param.ptype == ParameterType.OFFSET_TIME: + # Aka: Frame-times... 
+ sci_dataset_alt = apply_offset(a_param, p_data, sci_dataset) + ugps_times_modified = sci_dataset_alt[config.geo.time_field].values + + else: + raise NotImplementedError(a_param.ptype) + + logger.info(f" Created {len(param_kernels)} parameter-specific kernels") + return param_kernels, ugps_times_modified diff --git a/curryer/correction/pairing.py b/curryer/correction/pairing.py index a2ce3ea6..c3eb427a 100644 --- a/curryer/correction/pairing.py +++ b/curryer/correction/pairing.py @@ -16,7 +16,6 @@ from collections.abc import Iterable, Sequence from dataclasses import dataclass from pathlib import Path -from typing import Protocol import numpy as np @@ -26,80 +25,6 @@ logger = logging.getLogger(__name__) -# ============================================================================ -# GCP Pairing Interface Protocol -# ============================================================================ - - -class GCPPairingFunc(Protocol): - """ - Protocol for GCP pairing functions in Correction pipeline. - - Pairing functions determine which science observations (L1A images) - overlap with which ground control points (GCP reference images). - - Standard Signature: - def pair_gcps(science_keys: List[str]) -> List[Tuple[str, str]] - - Returns: - List of (science_key, gcp_reference_path) tuples, one per valid pair - - Note: - This is a simplified interface for Correction compatibility. - Real implementations (like find_l1a_gcp_pairs below) may use more - sophisticated spatial algorithms internally, but must return results - in this simple tuple format. 
- - Examples: - # Real spatial pairing - def spatial_gcp_pairing(science_keys): - l1a_images = load_images(science_keys) - gcp_images = discover_gcps() - pairs = find_spatial_overlaps(l1a_images, gcp_images) - return [(l1a.name, gcp.path) for l1a, gcp in pairs] - - # Test/synthetic pairing - def synthetic_gcp_pairing(science_keys): - return [(key, f"synthetic_gcp_{i}.tif") - for i, key in enumerate(science_keys)] - """ - - def __call__(self, science_keys: list[str]) -> list[tuple[str, str]]: - """Find GCP pairs for given science observations.""" - ... - - -def validate_pairing_output(pairs: list[tuple[str, str]]) -> None: - """ - Validate that GCP pairing output conforms to expected format. - - Args: - pairs: List of (science_key, gcp_path) tuples - - Raises: - TypeError: If structure is invalid - ValueError: If tuple elements have wrong types - - Example: - >>> pairs = gcp_pairing_func(["sci_001", "sci_002"]) - >>> validate_pairing_output(pairs) - """ - if not isinstance(pairs, list): - raise TypeError(f"GCP pairing must return list, got {type(pairs)}") - - for i, pair in enumerate(pairs): - if not isinstance(pair, tuple) or len(pair) != 2: - raise ValueError( - f"GCP pairing output[{i}] must be (str, str) tuple, " - f"got {type(pair)} with length {len(pair) if isinstance(pair, tuple) else 'N/A'}" - ) - sci_key, gcp_path = pair - if not isinstance(sci_key, str) or not isinstance(gcp_path, str | Path): - raise ValueError( - f"GCP pairing output[{i}] = ({type(sci_key).__name__}, {type(gcp_path).__name__}), expected (str, str)" - ) - - # ============================================================================ # Spatial Pairing Implementation # ============================================================================ @@ -529,7 +454,7 @@ def pair_files( >>> for l1a, gcp in pairs: ... 
print(f" {l1a.name} → {gcp.name}") """ - from .image_match import load_image_grid_from_mat + from .image_io import load_image_grid_from_mat gcp_dir = Path(gcp_directory) if not gcp_dir.is_dir(): diff --git a/curryer/correction/parameters.py b/curryer/correction/parameters.py new file mode 100644 index 00000000..cd64a899 --- /dev/null +++ b/curryer/correction/parameters.py @@ -0,0 +1,557 @@ +"""Parameter set generation for the correction pipeline. + +This module provides :func:`load_param_sets`, which generates parameter sets +for correction analysis. Three search strategies are supported: + +``RANDOM`` (default) + Monte Carlo random walk. Each parameter is sampled from a normal + distribution centred on ``current_value`` with the configured ``sigma``, + clipped to ``bounds``. Controlled by ``seed`` and ``n_iterations``. + +``GRID_SEARCH`` + Deterministic cartesian-product sweep. ``grid_points_per_param`` + evenly-spaced offset values are produced for each parameter (spanning its + full ``bounds`` range) and the cartesian product of all per-parameter grids + is enumerated. ``n_iterations`` is ignored for this strategy. + +``SINGLE_OFFSET`` + Deterministic single-parameter sweep. Each parameter is varied + independently across ``n_iterations`` evenly-spaced values while all other + parameters are held at their nominal ``current_value``. + +Supported parameter types: +- ``CONSTANT_KERNEL`` – 3-D attitude corrections (roll, pitch, yaw) stored as + a ``pandas.DataFrame`` with ``ugps``, ``angle_x``, ``angle_y``, ``angle_z``. +- ``OFFSET_KERNEL`` – single-axis angle bias (float, in radians). +- ``OFFSET_TIME`` – timing correction (float, in seconds). 
+""" + +import itertools +import logging +import typing + +import numpy as np +import pandas as pd + +from curryer.correction.config import CorrectionConfig, ParameterConfig, ParameterType, SearchStrategy + +logger = logging.getLogger(__name__) + +# ============================================================================ +# Unit-conversion helpers +# ============================================================================ + +_UGPS_EPOCH_END = 2_209_075_218_000_000 # sentinel end-of-mission ugps for CK DataFrames + + +def _arcsec_to_rad(value: float) -> float: + """Convert arcseconds to radians.""" + return np.deg2rad(value / 3600.0) if value != 0 else 0.0 + + +def _bounds_to_rad(bounds: list[float], units: str | None) -> list[float]: + if units == "arcseconds": + return [np.deg2rad(bounds[0] / 3600.0), np.deg2rad(bounds[1] / 3600.0)] + return list(bounds) + + +def _val_to_rad(value: float, units: str | None) -> float: + if units == "arcseconds": + return _arcsec_to_rad(value) + return value + + +def _val_to_seconds(value: float, units: str | None) -> float: + if units == "milliseconds": + return value / 1_000.0 + if units == "microseconds": + return value / 1_000_000.0 + return value + + +def _bounds_to_seconds(bounds: list[float], units: str | None) -> list[float]: + if units == "milliseconds": + return [bounds[0] / 1_000.0, bounds[1] / 1_000.0] + if units == "microseconds": + return [bounds[0] / 1_000_000.0, bounds[1] / 1_000_000.0] + return list(bounds) + + +# ============================================================================ +# DataFrame builder for CONSTANT_KERNEL +# ============================================================================ + + +def _make_ck_dataframe(angle_vals: list[float]) -> pd.DataFrame: + """Wrap ``[angle_x, angle_y, angle_z]`` (radians) into the pipeline DataFrame format.""" + return pd.DataFrame( + { + "ugps": [0, _UGPS_EPOCH_END], + "angle_x": [angle_vals[0], angle_vals[0]], + "angle_y": [angle_vals[1], 
angle_vals[1]], + "angle_z": [angle_vals[2], angle_vals[2]], + } + ) + + +# ============================================================================ +# Scalar current-value extraction (handles list vs scalar current_value) +# ============================================================================ + + +def _scalar_current_value(param: ParameterConfig) -> float: + """Return ``param.data.current_value`` as a scalar float. + + Raises + ------ + TypeError + If ``current_value`` is not a scalar numeric type. This helps surface + misconfigured parameters early instead of silently coercing them to 0.0. + """ + cv = param.data.current_value + if isinstance(cv, (int, float, np.number)): + return float(cv) + + param_name = getattr(param, "name", "") + raise TypeError( + f"Parameter '{param_name}' (type {getattr(param.ptype, 'name', param.ptype)}) " + f"has non-scalar current_value of type {type(cv).__name__}; expected a scalar " + "numeric value (int, float, or NumPy scalar)." + ) + + +# ============================================================================ +# Nominal value (no offset applied) – used by SINGLE_OFFSET for held params +# ============================================================================ + + +def _get_nominal_value(param: ParameterConfig) -> typing.Any: + """Return the un-perturbed, unit-converted value for *param*. + + For ``CONSTANT_KERNEL``, returns a :class:`~pandas.DataFrame` with angles + equal to the ``current_value`` in radians. For ``OFFSET_KERNEL`` / + ``OFFSET_TIME``, returns a float in radians / seconds respectively. 
+ """ + units = param.data.get("units") + current_value = param.data.current_value + + if param.ptype == ParameterType.CONSTANT_KERNEL: + if isinstance(current_value, list) and len(current_value) == 3: + angle_vals = [_val_to_rad(v, units) for v in current_value] + else: + cv_rad = _val_to_rad(_scalar_current_value(param), units) + angle_vals = [cv_rad, cv_rad, cv_rad] + return _make_ck_dataframe(angle_vals) + + if param.ptype == ParameterType.OFFSET_KERNEL: + return _val_to_rad(_scalar_current_value(param), units) + + if param.ptype == ParameterType.OFFSET_TIME: + return _val_to_seconds(_scalar_current_value(param), units) + + param_name = getattr(param, "name", "") + raise NotImplementedError( + f"Unsupported ParameterType '{getattr(param.ptype, 'name', param.ptype)}' " + f"for nominal value of parameter '{param_name}'." + ) + + +# ============================================================================ +# Grid values (evenly-spaced across the offset range) – GRID_SEARCH / SINGLE_OFFSET +# ============================================================================ + + +def _get_grid_values(param: ParameterConfig, n_points: int) -> list[typing.Any]: + """Return *n_points* evenly-spaced sampled values for *param*. + + Offsets are linearly spaced over ``[bounds[0], bounds[1]]`` (in the + parameter's native units before conversion) and added to the converted + ``current_value``. + + For ``CONSTANT_KERNEL`` the scalar offset is applied uniformly to all + three rotation axes. + + Parameters + ---------- + param : ParameterConfig + Parameter specification. + n_points : int + Number of evenly-spaced points (>= 2). + + Returns + ------- + list + List of *n_points* values; each element matches what the pipeline + expects for that parameter type (DataFrame or float). 
+ """ + units = param.data.get("units") + bounds = param.data.bounds + current_value = param.data.current_value + + if param.ptype == ParameterType.CONSTANT_KERNEL: + bounds_rad = _bounds_to_rad(bounds, units) + offsets = np.linspace(bounds_rad[0], bounds_rad[1], n_points) + + if isinstance(current_value, list) and len(current_value) == 3: + base_vals = [_val_to_rad(v, units) for v in current_value] + else: + cv_rad = _val_to_rad(_scalar_current_value(param), units) + base_vals = [cv_rad, cv_rad, cv_rad] + + return [_make_ck_dataframe([bv + offset for bv in base_vals]) for offset in offsets] + + if param.ptype == ParameterType.OFFSET_KERNEL: + cv_rad = _val_to_rad(_scalar_current_value(param), units) + bounds_rad = _bounds_to_rad(bounds, units) + offsets = np.linspace(bounds_rad[0], bounds_rad[1], n_points) + return list(cv_rad + offsets) + + if param.ptype == ParameterType.OFFSET_TIME: + cv_s = _val_to_seconds(_scalar_current_value(param), units) + bounds_s = _bounds_to_seconds(bounds, units) + offsets = np.linspace(bounds_s[0], bounds_s[1], n_points) + return list(cv_s + offsets) + + raise ValueError(f"Unsupported parameter type for grid generation: {param.ptype!r}") + + +# ============================================================================ +# Strategy implementations +# ============================================================================ + + +def _generate_random(config: CorrectionConfig) -> list[list[tuple[ParameterConfig, typing.Any]]]: + """Generate random parameter sets – exact current behaviour preserved.""" + if config.seed is not None: + np.random.seed(config.seed) + logger.info(f"Set random seed to {config.seed} for reproducible parameter generation") + + output = [] + + for ith in range(config.n_iterations): + out_set = [] + logger.debug(f"Generating parameter set {ith + 1}/{config.n_iterations}") + + for param_idx, param in enumerate(config.parameters): + current_value = param.data.current_value + bounds = param.data.bounds + + if 
param.ptype == ParameterType.CONSTANT_KERNEL: + if isinstance(current_value, list) and len(current_value) == 3: + param_vals = [] + for i, current_val in enumerate(current_value): + if "sigma" in param.data and param.data["sigma"] is not None and param.data["sigma"] > 0: + if param.data.get("units") == "arcseconds": + sigma_rad = np.deg2rad(param.data["sigma"] / 3600.0) + current_val_rad = np.deg2rad(current_val / 3600.0) if current_val != 0 else current_val + bounds_rad = [np.deg2rad(bounds[0] / 3600.0), np.deg2rad(bounds[1] / 3600.0)] + else: + sigma_rad = param.data["sigma"] + current_val_rad = current_val + bounds_rad = bounds + offset = np.random.normal(0, sigma_rad) + offset = np.clip(offset, bounds_rad[0], bounds_rad[1]) + param_vals.append(current_val_rad + offset) + else: + if "sigma" not in param.data or param.data["sigma"] is None: + logger.debug( + f" Parameter {param_idx} axis {i}: No sigma specified, using fixed current_value" + ) + elif param.data["sigma"] == 0: + logger.debug(f" Parameter {param_idx} axis {i}: sigma=0, using fixed current_value") + if param.data.get("units") == "arcseconds": + current_val_rad = np.deg2rad(current_val / 3600.0) if current_val != 0 else current_val + else: + current_val_rad = current_val + param_vals.append(current_val_rad) + else: + param_vals = [0.0, 0.0, 0.0] + if "sigma" in param.data and param.data["sigma"] is not None and param.data["sigma"] > 0: + if param.data.get("units") == "arcseconds": + sigma_rad = np.deg2rad(param.data["sigma"] / 3600.0) + bounds_rad = [np.deg2rad(bounds[0] / 3600.0), np.deg2rad(bounds[1] / 3600.0)] + current_val_rad = ( + np.deg2rad(current_value / 3600.0) if current_value != 0 else current_value + ) + else: + sigma_rad = param.data["sigma"] + bounds_rad = bounds + current_val_rad = current_value + for i in range(3): + offset = np.random.normal(0, sigma_rad) + offset = np.clip(offset, bounds_rad[0], bounds_rad[1]) + param_vals[i] = current_val_rad + offset + else: + if "sigma" not in 
param.data or param.data["sigma"] is None: + logger.debug(f" Parameter {param_idx}: No sigma specified, using fixed current_value") + elif param.data["sigma"] == 0: + logger.debug(f" Parameter {param_idx}: sigma=0, using fixed current_value") + if param.data.get("units") == "arcseconds": + current_val_rad = ( + np.deg2rad(current_value / 3600.0) if current_value != 0 else current_value + ) + else: + current_val_rad = current_value + param_vals = [current_val_rad, current_val_rad, current_val_rad] + + param_vals = _make_ck_dataframe(param_vals) + logger.debug( + f" CONSTANT_KERNEL {param_idx}: angles=[{param_vals['angle_x'].iloc[0]:.6e}, " + f"{param_vals['angle_y'].iloc[0]:.6e}, {param_vals['angle_z'].iloc[0]:.6e}] rad" + ) + + elif param.ptype == ParameterType.OFFSET_KERNEL: + if "sigma" in param.data and param.data["sigma"] is not None and param.data["sigma"] > 0: + if param.data.get("units") == "arcseconds": + sigma_rad = np.deg2rad(param.data["sigma"] / 3600.0) + current_val_rad = np.deg2rad(current_value / 3600.0) if current_value != 0 else current_value + bounds_rad = [np.deg2rad(bounds[0] / 3600.0), np.deg2rad(bounds[1] / 3600.0)] + else: + sigma_rad = param.data["sigma"] + current_val_rad = current_value + bounds_rad = bounds + offset = np.random.normal(0, sigma_rad) + offset = np.clip(offset, bounds_rad[0], bounds_rad[1]) + param_vals = current_val_rad + offset + else: + if "sigma" not in param.data or param.data["sigma"] is None: + logger.debug(f" Parameter {param_idx}: No sigma specified, using fixed current_value") + elif param.data["sigma"] == 0: + logger.debug(f" Parameter {param_idx}: sigma=0, using fixed current_value") + if param.data.get("units") == "arcseconds": + current_val_rad = np.deg2rad(current_value / 3600.0) if current_value != 0 else current_value + else: + current_val_rad = current_value + param_vals = current_val_rad + logger.debug(f" OFFSET_KERNEL {param_idx}: {param_vals:.6e} rad") + + elif param.ptype == ParameterType.OFFSET_TIME: + 
if "sigma" in param.data and param.data["sigma"] is not None and param.data["sigma"] > 0: + if param.data.get("units") == "seconds": + sigma_time = param.data["sigma"] + current_val_time = current_value + bounds_time = bounds + elif param.data.get("units") == "milliseconds": + sigma_time = param.data["sigma"] / 1000.0 + current_val_time = current_value / 1000.0 + bounds_time = [bounds[0] / 1000.0, bounds[1] / 1000.0] + elif param.data.get("units") == "microseconds": + sigma_time = param.data["sigma"] / 1000000.0 + current_val_time = current_value / 1000000.0 + bounds_time = [bounds[0] / 1000000.0, bounds[1] / 1000000.0] + else: + sigma_time = param.data["sigma"] + current_val_time = current_value + bounds_time = bounds + offset = np.random.normal(0, sigma_time) + offset = np.clip(offset, bounds_time[0], bounds_time[1]) + param_vals = current_val_time + offset + else: + if "sigma" not in param.data or param.data["sigma"] is None: + logger.debug(f" Parameter {param_idx}: No sigma specified, using fixed current_value") + elif param.data["sigma"] == 0: + logger.debug(f" Parameter {param_idx}: sigma=0, using fixed current_value") + if param.data.get("units") == "milliseconds": + current_val_time = current_value / 1000.0 + elif param.data.get("units") == "microseconds": + current_val_time = current_value / 1000000.0 + else: + current_val_time = current_value + param_vals = current_val_time + logger.debug(f" OFFSET_TIME {param_idx}: {param_vals:.6e} seconds") + + out_set.append((param, param_vals)) + output.append(out_set) + + return output + + +def _generate_grid_search(config: CorrectionConfig) -> list[list[tuple[ParameterConfig, typing.Any]]]: + """Generate parameter sets via deterministic cartesian-product grid sweep. + + Produces ``grid_points_per_param ^ len(parameters)`` parameter sets. + ``n_iterations`` is not used for this strategy. + + Raises + ------ + ValueError + If the total number of parameter sets would exceed ``config.max_grid_sets``. 
def _generate_single_offset(config: CorrectionConfig) -> list[list[tuple[ParameterConfig, typing.Any]]]:
    """Sweep each parameter independently while the others stay nominal.

    Each parameter receives ``config.n_iterations`` evenly spaced values
    spanning its bounds; every other parameter is pinned to its nominal
    (unit-converted) ``current_value``. Produces
    ``len(parameters) * n_iterations`` parameter sets in total.
    """
    steps = config.n_iterations
    n_params = len(config.parameters)
    logger.info(f"SINGLE_OFFSET: {n_params} parameter(s) × {steps} values each = {n_params * steps} total parameter sets")

    nominal_values = [_get_nominal_value(p) for p in config.parameters]

    output = []
    for active_idx, active_param in enumerate(config.parameters):
        sweep = _get_grid_values(active_param, steps)
        label = active_param.config_file.name if active_param.config_file else f"param_{active_idx}"
        logger.debug(f" SINGLE_OFFSET: sweeping parameter {active_idx} ({label}) with {len(sweep)} values")
        for sampled in sweep:
            # Substitute the swept value only at the active parameter's slot.
            row = [
                (p, sampled if idx == active_idx else nominal_values[idx])
                for idx, p in enumerate(config.parameters)
            ]
            output.append(row)

    logger.info(f"SINGLE_OFFSET: generated {len(output)} parameter sets")
    return output
+ """ + n = config.n_iterations + n_params = len(config.parameters) + logger.info(f"SINGLE_OFFSET: {n_params} parameter(s) × {n} values each = {n_params * n} total parameter sets") + + nominals = [_get_nominal_value(param) for param in config.parameters] + + output = [] + for sweep_idx, sweep_param in enumerate(config.parameters): + sweep_values = _get_grid_values(sweep_param, n) + param_name = sweep_param.config_file.name if sweep_param.config_file else f"param_{sweep_idx}" + logger.debug(f" SINGLE_OFFSET: sweeping parameter {sweep_idx} ({param_name}) with {len(sweep_values)} values") + for val in sweep_values: + out_set = [] + for param_idx, param in enumerate(config.parameters): + out_set.append((param, val if param_idx == sweep_idx else nominals[param_idx])) + output.append(out_set) + + logger.info(f"SINGLE_OFFSET: generated {len(output)} parameter sets") + return output + + +# ============================================================================ +# Logging helper +# ============================================================================ + + +def _log_param_set_summary(output: list[list[tuple[ParameterConfig, typing.Any]]]) -> None: + """Log a structured summary of the generated parameter sets. + + High-level counts are always emitted at INFO. The full per-set detail + (angles / offsets for every set) is emitted at DEBUG only, so large + GRID_SEARCH / SINGLE_OFFSET sweeps do not flood the INFO log. 
+ """ + if not output: + return + + logger.info(f"Generated {len(output)} parameter sets with {len(output[0])} parameters each") + + if not logger.isEnabledFor(logging.DEBUG): + return + + logger.debug("\nParameter Set Summary:") + logger.debug("-" * 100) + for param_set_idx, param_set in enumerate(output): + logger.debug(f" Set {param_set_idx}:") + for param_idx, (param, param_vals) in enumerate(param_set): + field_name = param.data.get("field", "unknown") + ptype_name = param.ptype.name + + if param.ptype == ParameterType.CONSTANT_KERNEL: + if isinstance(param_vals, pd.DataFrame) and "angle_x" in param_vals.columns: + angles = [ + param_vals["angle_x"].iloc[0], + param_vals["angle_y"].iloc[0], + param_vals["angle_z"].iloc[0], + ] + logger.debug( + f" {ptype_name:16s} {field_name:25s}: " + f"[{angles[0]:+.6e}, {angles[1]:+.6e}, {angles[2]:+.6e}] rad" + ) + else: + logger.debug(f" {ptype_name:16s} {field_name:25s}: (constant kernel data)") + elif param.ptype == ParameterType.OFFSET_KERNEL: + units = param.data.get("units", "") + if units == "arcseconds": + param_arcsec = np.rad2deg(param_vals) * 3600.0 + logger.debug( + f" {ptype_name:16s} {field_name:25s}: {param_arcsec:+10.3f} arcsec ({param_vals:+.9f} rad)" + ) + else: + logger.debug(f" {ptype_name:16s} {field_name:25s}: {param_vals:+.9f} {units}") + elif param.ptype == ParameterType.OFFSET_TIME: + units = param.data.get("units", "") + if units == "milliseconds": + param_ms = param_vals * 1000.0 + logger.debug(f" {ptype_name:16s} {field_name:25s}: {param_ms:+10.3f} ms ({param_vals:+.9f} s)") + else: + logger.debug(f" {ptype_name:16s} {field_name:25s}: {param_vals:+.9f} {units}") + logger.debug("-" * 100) + + +# ============================================================================ +# Public API +# ============================================================================ + + +def load_param_sets(config: CorrectionConfig) -> list[list[tuple[ParameterConfig, typing.Any]]]: + """Generate parameter sets for 
correction iterations. + + Dispatches to the appropriate generator based on + ``config.search_strategy``: + + - :attr:`~SearchStrategy.RANDOM` – Monte Carlo random walk (default). + - :attr:`~SearchStrategy.GRID_SEARCH` – deterministic cartesian-product + sweep across ``grid_points_per_param`` evenly-spaced values per parameter. + - :attr:`~SearchStrategy.SINGLE_OFFSET` – deterministic single-parameter + sweep; each parameter is varied independently while others stay at nominal. + + Parameters + ---------- + config : CorrectionConfig + Complete correction configuration including parameters, strategy, and + sampling settings. + + Returns + ------- + list[list[tuple[ParameterConfig, Any]]] + Outer list: one element per parameter set (iteration). + Inner list: one ``(ParameterConfig, sampled_value)`` pair per parameter. + ``sampled_value`` is a :class:`~pandas.DataFrame` for + ``CONSTANT_KERNEL`` and a ``float`` for ``OFFSET_KERNEL`` / + ``OFFSET_TIME``. + """ + strategy = config.search_strategy + + logger.info( + f"Generating parameter sets for {len(config.parameters)} parameter(s) using strategy: {strategy.value!r}" + ) + for i, param in enumerate(config.parameters): + param_name = param.config_file.name if param.config_file else f"param_{i}" + current_value = param.data.current_value + bounds = param.data.bounds + logger.info( + f" {i + 1}. {param_name} ({param.ptype.name}): " + f"current_value={current_value}, sigma={param.data.get('sigma', 'N/A')}, " + f"bounds={bounds}, units={param.data.get('units', 'N/A')}" + ) + + if strategy == SearchStrategy.RANDOM: + output = _generate_random(config) + elif strategy == SearchStrategy.GRID_SEARCH: + output = _generate_grid_search(config) + elif strategy == SearchStrategy.SINGLE_OFFSET: + output = _generate_single_offset(config) + else: + raise ValueError(f"Unknown SearchStrategy: {strategy!r}. 
Valid values are: {[s.value for s in SearchStrategy]}") + + _log_param_set_summary(output) + return output diff --git a/curryer/correction/pipeline.py b/curryer/correction/pipeline.py new file mode 100644 index 00000000..f48b3e77 --- /dev/null +++ b/curryer/correction/pipeline.py @@ -0,0 +1,1435 @@ +"""Main correction pipeline orchestration. + +This module contains the public-facing :func:`loop` function that drives +the Monte Carlo parameter sensitivity analysis, plus all of the helper +functions it calls: + +- Adapter functions that bridge between the geolocation/image-matching + sub-modules and the correction loop. +- :func:`load_config_from_json` -- build a :class:`CorrectionConfig` from + a JSON file. +- :func:`_load_file` -- internal helper that reads CSV/NetCDF/HDF5 files + into DataFrames, replacing the old mission-specific loader callables. +- :func:`_load_image_pair_data`, :func:`_load_calibration_data`, + :func:`_geolocate_and_match` -- per-iteration computation helpers. +- :func:`loop` -- outer GCP-pair loop, inner parameter-set loop. 
+""" + +import logging +import time +from collections.abc import Sequence +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +import pandas as pd +import xarray as xr + +if TYPE_CHECKING: + from curryer.correction.results import CorrectionResult + +from curryer import meta +from curryer import spicierpy as sp +from curryer.compute import constants, spatial +from curryer.correction.config import ( + CalibrationData, + CorrectionConfig, + CorrectionInput, + ImageMatchingContext, + KernelContext, + ParameterType, +) +from curryer.correction.data_structures import ( + ImageGrid, + PSFSamplingConfig, + SearchConfig, +) +from curryer.correction.dataio import ( + validate_science_output, + validate_telemetry_output, +) +from curryer.correction.error_stats import ErrorStatsConfig, ErrorStatsProcessor +from curryer.correction.image_io import ( + load_image_grid_from_mat, + load_los_vectors_from_mat, + load_optical_psf_from_mat, +) +from curryer.correction.image_match import ( + integrated_image_match, + validate_image_matching_output, +) +from curryer.correction.kernel_ops import ( + _create_dynamic_kernels, + _create_parameter_kernels, +) +from curryer.correction.parameters import load_param_sets +from curryer.correction.results_io import ( + _build_netcdf_structure, + _cleanup_checkpoint, + _load_checkpoint, + _save_netcdf_checkpoint, + _save_netcdf_results, +) +from curryer.kernels import create + +logger = logging.getLogger(__name__) + + +def _geolocated_to_image_grid(geo_dataset: xr.Dataset): + """ + Convert Correction geolocation output to ImageGrid for image matching. + + Internal adapter function: converts xarray.Dataset from geolocation step + to ImageGrid format expected by image_match module. 
+ + Args: + geo_dataset: xarray.Dataset with latitude, longitude, altitude/height + + Returns: + ImageGrid suitable for integrated_image_match() + """ + + lat = geo_dataset["latitude"].values + lon = geo_dataset["longitude"].values + + # Try different field names for altitude/height + if "altitude" in geo_dataset: + h = geo_dataset["altitude"].values + elif "height" in geo_dataset: + h = geo_dataset["height"].values + else: + h = np.zeros_like(lat) + + # Get actual radiance/reflectance data when available + if "radiance" in geo_dataset: + data = geo_dataset["radiance"].values + elif "reflectance" in geo_dataset: + data = geo_dataset["reflectance"].values + else: + data = np.ones_like(lat) + + return ImageGrid(data=data, lat=lat, lon=lon, h=h) + + +def _extract_spacecraft_position_midframe( + telemetry: pd.DataFrame, + config: "CorrectionConfig | None" = None, +) -> np.ndarray: + """Extract spacecraft position at mid-frame from telemetry. + + Parameters + ---------- + telemetry : pd.DataFrame + Telemetry DataFrame with spacecraft position columns. + config : CorrectionConfig or None, optional + If provided and ``config.data.position_columns`` is set, those + column names are used directly. Otherwise falls back to + pattern-guessing (with a deprecation warning). + + Returns + ------- + np.ndarray + Shape ``(3,)`` — ``[x, y, z]`` position in metres (J2000 frame). + + Raises + ------ + ValueError + If ``position_columns`` has wrong length, or specified columns are + not found, or pattern-guessing fails. 
+ """ + mid_idx = len(telemetry) // 2 + + # Prefer explicit column names from config + if config is not None and config.data is not None and config.data.position_columns is not None: + cols = config.data.position_columns + if len(cols) != 3: + raise ValueError(f"position_columns must have exactly 3 entries, got {len(cols)}: {cols}") + missing = [c for c in cols if c not in telemetry.columns] + if missing: + raise ValueError( + f"position_columns {missing} not found in telemetry. Available: {telemetry.columns.tolist()}" + ) + position = telemetry[cols].iloc[mid_idx].values.astype(np.float64) + logger.debug( + "Extracted spacecraft position from config.data.position_columns %s: %s", + cols, + position, + ) + return position + + # Legacy fallback: pattern guessing (log deprecation warning) + logger.warning( + "position_columns not configured — falling back to column name pattern-guessing. " + "Set config.data.position_columns = ['col_x', 'col_y', 'col_z'] to silence this warning." + ) + + # Try common column name patterns + position_patterns = [ + ["sc_pos_x", "sc_pos_y", "sc_pos_z"], + ["position_x", "position_y", "position_z"], + ["r_x", "r_y", "r_z"], + ["pos_x", "pos_y", "pos_z"], + ] + + for cols in position_patterns: + if all(c in telemetry.columns for c in cols): + position = telemetry[cols].iloc[mid_idx].values.astype(np.float64) + logger.debug(f"Extracted spacecraft position from columns {cols}: {position}") + return position + + # If patterns don't match, try to find any column containing 'pos' or 'r_' + pos_cols = [c for c in telemetry.columns if "pos" in c.lower() or c.startswith("r_")] + if len(pos_cols) >= 3: + logger.warning(f"Using first 3 position-like columns: {pos_cols[:3]}") + return telemetry[pos_cols[:3]].iloc[mid_idx].values.astype(np.float64) + + raise ValueError(f"Cannot find position columns in telemetry. 
def image_matching(
    geolocated_data: xr.Dataset,
    gcp_reference_file: Path,
    telemetry: pd.DataFrame,
    calibration_dir: Path,
    params_info: list,
    config: "CorrectionConfig",
    los_vectors_cached: np.ndarray | None = None,
    optical_psfs_cached: list | None = None,
) -> xr.Dataset:
    """
    Image matching using integrated_image_match() module.

    This function performs actual image correlation between geolocated
    pixels and Landsat GCP reference imagery.

    Args:
        geolocated_data: xarray.Dataset with latitude, longitude from geolocation
        gcp_reference_file: Path to GCP reference image (MATLAB .mat file)
        telemetry: Telemetry DataFrame with spacecraft state
        calibration_dir: Directory containing calibration files (LOS vectors, PSF)
        params_info: Current parameter values for error tracking
        config: CorrectionConfig with coordinate name mappings
        los_vectors_cached: Pre-loaded LOS vectors (optional, for performance)
        optical_psfs_cached: Pre-loaded optical PSF entries (optional, for performance)

    Returns:
        xarray.Dataset with error measurements in format expected by error_stats:
        - lat_error_deg, lon_error_deg: Spatial errors in degrees
        - Additional metadata for error statistics processing

    Raises:
        FileNotFoundError: If calibration files are missing
        ValueError: If geolocation data is invalid
    """
    logger.info(f"Image Matching: correlation with {gcp_reference_file.name}")
    start_time = time.time()

    # Convert geolocation output to ImageGrid
    logger.info(" Converting geolocation data to ImageGrid format...")
    subimage = _geolocated_to_image_grid(geolocated_data)
    logger.info(f" Subimage shape: {subimage.data.shape}")

    # Load GCP reference image
    logger.info(f" Loading GCP reference from {gcp_reference_file}...")
    gcp = load_image_grid_from_mat(gcp_reference_file, key="GCP")
    # Get GCP center location (center pixel)
    gcp_center_lat = float(gcp.lat[gcp.lat.shape[0] // 2, gcp.lat.shape[1] // 2])
    gcp_center_lon = float(gcp.lon[gcp.lon.shape[0] // 2, gcp.lon.shape[1] // 2])
    logger.info(f" GCP shape: {gcp.data.shape}, center: ({gcp_center_lat:.4f}, {gcp_center_lon:.4f})")

    # Use cached calibration data if available, otherwise load
    logger.info(" Loading calibration data...")

    if los_vectors_cached is not None and optical_psfs_cached is not None:
        # Use cached data (fast path) — both caches must be supplied; a single
        # cached item falls through to the full load below.
        los_vectors = los_vectors_cached
        optical_psfs = optical_psfs_cached
        logger.info(" Using cached calibration data")
    else:
        # Prefer direct file paths from config; fall back to calibration_dir parameter.
        if config.los_vectors_file is not None:
            los_file = Path(config.los_vectors_file)
        elif calibration_dir is not None:
            los_filename = config.get_calibration_file("los_vectors", default="b_HS.mat")
            los_file = calibration_dir / los_filename
        else:
            raise ValueError("No LOS vectors source configured. Set config.los_vectors_file or config.calibration_dir.")
        los_vectors = load_los_vectors_from_mat(los_file)
        logger.info(f" LOS vectors: {los_vectors.shape}")

        if config.psf_file is not None:
            psf_file = Path(config.psf_file)
        elif calibration_dir is not None:
            psf_filename = config.get_calibration_file("optical_psf", default="optical_PSF_675nm_upsampled.mat")
            psf_file = calibration_dir / psf_filename
        else:
            raise ValueError("No PSF source configured. Set config.psf_file or config.calibration_dir.")
        optical_psfs = load_optical_psf_from_mat(psf_file)
        logger.info(f" Optical PSF: {len(optical_psfs)} entries")

    # Extract spacecraft position from telemetry
    r_iss_midframe = _extract_spacecraft_position_midframe(telemetry, config=config)
    logger.info(f" Spacecraft position: {r_iss_midframe}")

    # Run real image matching
    logger.info(" Running integrated_image_match()...")
    geolocation_config = PSFSamplingConfig()
    search_config = SearchConfig()

    result = integrated_image_match(
        subimage=subimage,
        gcp=gcp,
        r_iss_midframe_m=r_iss_midframe,
        los_vectors_hs=los_vectors,
        optical_psfs=optical_psfs,
        geolocation_config=geolocation_config,
        search_config=search_config,
    )

    # Convert IntegratedImageMatchResult to xarray.Dataset format
    logger.info(" Converting results to error_stats format...")

    # Create single measurement result (image matching produces one correlation per GCP)

    # NOTE: Boresight and transformation matrix for error_stats module
    # ----------------------------------------------------------------
    # These values are NOT used by image_matching() itself - the image correlation
    # is complete and accurate without them. They are needed by call_error_stats_module()
    # for converting off-nadir errors to nadir-equivalent errors.
    #
    # Currently using simplified nadir assumptions which are acceptable for:
    # - Near-nadir observations (< ~5 degrees off-nadir)
    # - Testing image matching correlation accuracy (doesn't affect matching)
    #
    # For accurate nadir-equivalent error conversion with off-nadir pointing, these
    # should be extracted from SPICE/geolocation data:
    # - boresight: Extract from spicierpy.getfov(instrument) and transform via geo_dataset['attitude']
    # - t_matrix: Extract from geo_dataset['attitude'] (transformation from instrument to CTRS)
    #
    # See: error_stats.py _transform_boresight_vectors() for usage
    # See: BORESIGHT_TRANSFORM_ANALYSIS.md for detailed analysis and future enhancement plan

    t_matrix = np.eye(3)  # Simplified: Identity matrix (no rotation)
    boresight = np.array([0.0, 0.0, 1.0])  # Simplified: Nadir pointing assumption

    # Convert errors from km to degrees
    lat_error_deg = result.lat_error_km / 111.0  # ~111 km per degree latitude
    # Longitude degrees shrink with latitude: scale by the radius of the local
    # parallel (WGS84 equatorial radius * cos(lat)) before converting to degrees.
    lon_radius_km = constants.WGS84_SEMI_MAJOR_AXIS_KM * np.cos(np.deg2rad(gcp_center_lat))
    lon_error_deg = result.lon_error_km / (lon_radius_km * np.pi / 180.0)

    processing_time = time.time() - start_time

    logger.info(f" Image matching complete in {processing_time:.2f}s:")
    logger.info(f" Lat error: {result.lat_error_km:.3f} km ({lat_error_deg:.6f}°)")
    logger.info(f" Lon error: {result.lon_error_km:.3f} km ({lon_error_deg:.6f}°)")
    logger.info(f" Correlation: {result.ccv_final:.4f}")
    logger.info(f" Grid step: {result.final_grid_step_m:.1f} m")

    # Get coordinate names from config
    sc_pos_name = config.spacecraft_position_name
    boresight_name = config.boresight_name
    transform_name = config.transformation_matrix_name

    # Create output dataset in error_stats format (use config names)
    output = xr.Dataset(
        {
            "lat_error_deg": (["measurement"], [lat_error_deg]),
            "lon_error_deg": (["measurement"], [lon_error_deg]),
            sc_pos_name: (["measurement", "xyz"], [r_iss_midframe]),
            boresight_name: (["measurement", "xyz"], [boresight]),
            transform_name: (["measurement", "xyz_from", "xyz_to"], t_matrix[np.newaxis, :, :]),
            "gcp_lat_deg": (["measurement"], [gcp_center_lat]),
            "gcp_lon_deg": (["measurement"], [gcp_center_lon]),
            "gcp_alt": (["measurement"], [0.0]),  # GCP at ground level
        },
        coords={"measurement": [0], "xyz": ["x", "y", "z"], "xyz_from": ["x", "y", "z"], "xyz_to": ["x", "y", "z"]},
    )

    # Add detailed metadata (Fix #3 Part B: Add km errors to attrs)
    output.attrs.update(
        {
            "lat_error_km": result.lat_error_km,
            "lon_error_km": result.lon_error_km,
            "correlation_ccv": result.ccv_final,
            "final_grid_step_m": result.final_grid_step_m,
            "final_index_row": result.final_index_row,
            "final_index_col": result.final_index_col,
            "processing_time_s": processing_time,
            "gcp_file": str(gcp_reference_file.name),
            "gcp_center_lat": gcp_center_lat,
            "gcp_center_lon": gcp_center_lon,
        }
    )

    return output
def call_error_stats_module(image_matching_results, correction_config: "CorrectionConfig"):
    """
    Call the error_stats module with image matching output.

    Args:
        image_matching_results: Either a single image matching result (xarray.Dataset)
            or a list of image matching results from multiple GCP pairs
        correction_config: CorrectionConfig with all configuration (REQUIRED)

    Returns:
        Aggregate error statistics dataset. When the error_stats module cannot
        be imported, a simplified placeholder dataset with basic mean/RMS/std
        statistics is returned instead.
    """
    # Handle both single result and list of results
    if not isinstance(image_matching_results, list):
        image_matching_results = [image_matching_results]

    try:
        # Imported lazily so the ImportError fallback below can take over when
        # the error_stats module (or one of its dependencies) is unavailable.
        from curryer.correction.error_stats import ErrorStatsConfig, ErrorStatsProcessor

        logger.info(f"Error Statistics: Processing geolocation errors from {len(image_matching_results)} GCP pairs")

        # Create error stats config directly from Correction config (single source of truth)
        error_config = ErrorStatsConfig.from_correction_config(correction_config)

        processor = ErrorStatsProcessor(config=error_config)

        if len(image_matching_results) == 1:
            # Single GCP pair case
            error_results = processor.process_geolocation_errors(image_matching_results[0])
        else:
            # Multiple GCP pairs - aggregate the data first
            aggregated_data = _aggregate_image_matching_results(image_matching_results, correction_config)
            error_results = processor.process_geolocation_errors(aggregated_data)

        return error_results

    except ImportError as e:
        logger.warning(f"Error stats module not available: {e}")
        logger.info(f"Error Statistics: Using placeholder calculations for {len(image_matching_results)} GCP pairs")

        # Fallback: compute basic statistics across all GCP pairs
        all_lat_errors = []
        all_lon_errors = []
        total_measurements = 0

        for result in image_matching_results:
            lat_errors = result["lat_error_deg"].values
            lon_errors = result["lon_error_deg"].values
            all_lat_errors.extend(lat_errors)
            all_lon_errors.extend(lon_errors)
            total_measurements += len(lat_errors)

        all_lat_errors = np.array(all_lat_errors)
        all_lon_errors = np.array(all_lon_errors)

        # Convert to meters (approximate)
        # NOTE(review): 111 km/deg is exact only for latitude; longitude meters
        # per degree shrink by cos(lat). Acceptable here since this branch is an
        # explicitly-labeled placeholder.
        lat_error_m = all_lat_errors * 111000
        lon_error_m = all_lon_errors * 111000
        total_error_m = np.sqrt(lat_error_m**2 + lon_error_m**2)

        mean_error = float(np.mean(total_error_m))
        rms_error = float(np.sqrt(np.mean(total_error_m**2)))
        std_error = float(np.std(total_error_m))

        return xr.Dataset(
            {
                "mean_error": mean_error,
                "rms_error": rms_error,
                "std_error": std_error,
                "max_error": float(np.max(total_error_m)),
                "min_error": float(np.min(total_error_m)),
            }
        )
def _aggregate_image_matching_results(image_matching_results, config: "CorrectionConfig"):
    """
    Aggregate multiple image matching results into a single dataset for error stats processing.

    Args:
        image_matching_results: List of xarray.Dataset objects from image matching
        config: CorrectionConfig with coordinate name mappings

    Returns:
        Single aggregated xarray.Dataset with all measurements combined along a
        fresh ``measurement`` dimension; optional coordinate-transform variables
        are carried over only when present in the inputs.
    """
    logger.info(f"Aggregating {len(image_matching_results)} image matching results")

    # Get coordinate names from config
    sc_pos_name = config.spacecraft_position_name
    boresight_name = config.boresight_name
    transform_name = config.transformation_matrix_name

    # Combine all measurements into single arrays
    all_lat_errors = []
    all_lon_errors = []
    all_sc_positions = []
    all_boresights = []
    all_transforms = []
    all_gcp_lats = []
    all_gcp_lons = []
    all_gcp_alts = []

    # NOTE: the enumerate index `i` is currently unused (kept for source tracking).
    for i, result in enumerate(image_matching_results):
        # Add GCP pair identifier to track source
        n_measurements = len(result["lat_error_deg"])

        all_lat_errors.extend(result["lat_error_deg"].values)
        all_lon_errors.extend(result["lon_error_deg"].values)

        # Handle coordinate transformation data (use config names)
        # NOTE: Individual results have shape (1, 3) for vectors and (1, 3, 3) for matrices
        if sc_pos_name in result:
            # Shape: (1, 3) -> extract as (3,) for each measurement
            for j in range(n_measurements):
                all_sc_positions.append(result[sc_pos_name].values[j])
        if boresight_name in result:
            # Shape: (1, 3) -> extract as (3,) for each measurement
            for j in range(n_measurements):
                all_boresights.append(result[boresight_name].values[j])
        if transform_name in result:
            # Shape: (1, 3, 3) -> extract as (3, 3) for each measurement
            for j in range(n_measurements):
                all_transforms.append(result[transform_name].values[j, :, :])
        if "gcp_lat_deg" in result:
            all_gcp_lats.extend(result["gcp_lat_deg"].values)
        if "gcp_lon_deg" in result:
            all_gcp_lons.extend(result["gcp_lon_deg"].values)
        if "gcp_alt" in result:
            all_gcp_alts.extend(result["gcp_alt"].values)

    n_total = len(all_lat_errors)

    # Create aggregated dataset with correct dimension names for error_stats
    aggregated = xr.Dataset(
        {
            "lat_error_deg": (["measurement"], np.array(all_lat_errors)),
            "lon_error_deg": (["measurement"], np.array(all_lon_errors)),
        },
        coords={"measurement": np.arange(n_total)},
    )

    # Add optional coordinate transformation data if available (use config names)
    # Use dimension names that match error_stats expectations
    if all_sc_positions:
        # Stack into (n_measurements, 3)
        aggregated[sc_pos_name] = (["measurement", "xyz"], np.array(all_sc_positions))
        aggregated = aggregated.assign_coords({"xyz": ["x", "y", "z"]})

    if all_boresights:
        # Stack into (n_measurements, 3)
        aggregated[boresight_name] = (["measurement", "xyz"], np.array(all_boresights))

    if all_transforms:
        # Stack into (n_measurements, 3, 3) to match error_stats format
        t_stacked = np.stack(all_transforms, axis=0)
        aggregated[transform_name] = (["measurement", "xyz_from", "xyz_to"], t_stacked)
        aggregated = aggregated.assign_coords({"xyz_from": ["x", "y", "z"], "xyz_to": ["x", "y", "z"]})

    if all_gcp_lats:
        aggregated["gcp_lat_deg"] = (["measurement"], np.array(all_gcp_lats))
    if all_gcp_lons:
        aggregated["gcp_lon_deg"] = (["measurement"], np.array(all_gcp_lons))
    if all_gcp_alts:
        aggregated["gcp_alt"] = (["measurement"], np.array(all_gcp_alts))

    aggregated.attrs["source_gcp_pairs"] = len(image_matching_results)
    aggregated.attrs["total_measurements"] = n_total

    logger.info(f" Aggregated dataset: {n_total} measurements from {len(image_matching_results)} GCP pairs")
    logger.info(f" Dimensions: {dict(aggregated.sizes)}")

    return aggregated
aggregated.attrs["source_gcp_pairs"] = len(image_matching_results) + aggregated.attrs["total_measurements"] = n_total + + logger.info(f" Aggregated dataset: {n_total} measurements from {len(image_matching_results)} GCP pairs") + logger.info(f" Dimensions: {dict(aggregated.sizes)}") + + return aggregated + + +def _resolve_gcp_pairs( + sci_key: str, + gcp_key: str, + config: "CorrectionConfig", +) -> list[tuple[str, str]]: + """Return the ``[(sci_key, gcp_key)]`` pair, validating that ``gcp_key`` is set. + + Parameters + ---------- + sci_key : str + Science file path for this outer-loop iteration. + gcp_key : str + GCP ``.mat`` file path supplied as the third element of the + ``tlm_sci_gcp_sets`` tuple. Must be non-empty. + config : CorrectionConfig + Unused directly; reserved for future extension. + + Returns + ------- + list of (sci_key, gcp_path) tuples — always length 1. + + Raises + ------ + ValueError + If ``gcp_key`` is empty or whitespace-only. + + Notes + ----- + For spatial-overlap-based pairing (many L1A images × many GCP chips) call + :func:`~curryer.correction.pairing.pair_files` *before* :func:`loop` to + build ``tlm_sci_gcp_sets`` from the ``(l1a_file, gcp_file)`` results. + """ + if not gcp_key or not gcp_key.strip(): + raise ValueError( + "gcp_key must be a non-empty file path to a GCP .mat file.\n" + "Pass the GCP file path as the third element of each tlm_sci_gcp_sets tuple:\n" + " tlm_sci_gcp_sets = [(tlm_path, sci_path, gcp_path), ...]\n" + "\n" + "To compute which GCP chips overlap a given L1A footprint use:\n" + " from curryer.correction.pairing import pair_files\n" + " pairs = pair_files(l1a_files, gcp_directory, max_distance_m=0.0)" + ) + return [(sci_key, gcp_key)] + + +def _load_file(file_path: str | Path, file_format: str = "csv") -> pd.DataFrame: + """Load a telemetry or science data file into a pandas DataFrame. + + Parameters + ---------- + file_path : str | Path + Local path or S3 URI (``s3://bucket/key``). 
def _load_file(file_path: str | Path, file_format: str = "csv") -> pd.DataFrame:
    """Read a telemetry or science data file into a pandas DataFrame.

    Parameters
    ----------
    file_path : str | Path
        Local path or S3 URI (``s3://bucket/key``); S3 URIs are resolved to a
        local file by :func:`~curryer.correction.io.resolve_path`.
    file_format : str
        One of ``"csv"``, ``"netcdf"``, or ``"hdf5"``.

    Returns
    -------
    pd.DataFrame

    Raises
    ------
    FileNotFoundError
        If *file_path* is local and does not exist.
    ImportError
        If *file_path* is an S3 URI and boto3 is not installed.
    ValueError
        If ``file_format`` is not recognised.
    """
    from curryer.correction.io import resolve_path

    # resolve_path validates existence (or downloads the S3 object), so no
    # separate exists() check is required here.
    local_path = resolve_path(file_path)

    if file_format == "csv":
        return pd.read_csv(local_path, index_col=0)
    if file_format == "netcdf":
        return xr.load_dataset(local_path).to_dataframe().reset_index()
    if file_format == "hdf5":
        return pd.read_hdf(local_path)

    raise ValueError(f"Unsupported file_format '{file_format}'. Must be 'csv', 'netcdf', or 'hdf5'.")
def _load_calibration_data(config: "CorrectionConfig") -> CalibrationData:
    """Load LOS vectors and optical PSF if calibration sources are configured.

    This function centralizes calibration data loading, which is now called once
    per GCP pair in the optimized implementation (previously called once per parameter set).

    Parameters
    ----------
    config : CorrectionConfig
        Configuration with calibration_dir and calibration settings

    Returns
    -------
    CalibrationData
        NamedTuple containing (los_vectors, optical_psfs), or (None, None) if
        no calibration source is configured

    Raises
    ------
    FileNotFoundError
        If calibration directory is configured but files don't exist
    ValueError
        If calibration files exist but fail to load properly, or if only one
        of the two direct file paths is set with no calibration_dir fallback

    Note
    ----
    Supports three resolution strategies (in priority order):

    1. ``config.los_vectors_file`` / ``config.psf_file`` — direct file paths
       (set in PR 1 via ``CorrectionConfig``).
    2. ``config.calibration_dir`` + ``config.calibration_file_names`` — legacy
       directory-based lookup.
    3. Neither configured — returns ``CalibrationData(None, None)`` so that
       missions without calibration files still work.

    Examples
    --------
    >>> calib_data = _load_calibration_data(config)
    >>> if calib_data.los_vectors is not None:
    ...     # Use calibration data in image matching
    ...     pass
    """
    has_direct = config.los_vectors_file is not None or config.psf_file is not None
    has_dir = bool(config.calibration_dir)

    # Nothing configured at all: calibration is optional, so signal "absent"
    # rather than raising.
    if not has_direct and not has_dir:
        return CalibrationData(los_vectors=None, optical_psfs=None)

    logger.info("Loading calibration data...")

    # ---- LOS vectors ----
    if config.los_vectors_file is not None:
        los_file = Path(config.los_vectors_file)
    elif config.calibration_dir is not None:
        los_filename = config.get_calibration_file("los_vectors", default="b_HS.mat")
        los_file = config.calibration_dir / los_filename
    else:
        # Reached when only psf_file was set directly: there is no LOS source.
        raise ValueError("No LOS vectors source configured. Set config.los_vectors_file or config.calibration_dir.")

    if not los_file.exists():
        raise FileNotFoundError(
            f"LOS vectors calibration file not found: {los_file}\nSet config.los_vectors_file to the correct path."
        )

    los_vectors_cached = load_los_vectors_from_mat(los_file)

    if los_vectors_cached is None:
        raise ValueError(
            f"Failed to load LOS vectors from {los_file}. File exists but load_los_vectors_from_mat() returned None."
        )

    # ---- Optical PSF ----
    if config.psf_file is not None:
        psf_file = Path(config.psf_file)
    elif config.calibration_dir is not None:
        psf_filename = config.get_calibration_file("optical_psf", default="optical_PSF_675nm_upsampled.mat")
        psf_file = config.calibration_dir / psf_filename
    else:
        # Reached when only los_vectors_file was set directly: no PSF source.
        raise ValueError("No PSF source configured. Set config.psf_file or config.calibration_dir.")

    if not psf_file.exists():
        raise FileNotFoundError(
            f"Optical PSF calibration file not found: {psf_file}\nSet config.psf_file to the correct path."
        )

    optical_psfs_cached = load_optical_psf_from_mat(psf_file)

    if optical_psfs_cached is None:
        raise ValueError(
            f"Failed to load optical PSF from {psf_file}. File exists but load_optical_psf_from_mat() returned None."
        )

    logger.info(f" Cached LOS vectors: {los_vectors_cached.shape}")
    logger.info(f" Cached optical PSF: {len(optical_psfs_cached)} entries")

    return CalibrationData(los_vectors=los_vectors_cached, optical_psfs=optical_psfs_cached)
def _load_image_pair_data(
    tlm_key: str,
    sci_key: str,
    config: "CorrectionConfig",
) -> tuple[pd.DataFrame, pd.DataFrame, Any]:
    """Load and validate the telemetry/science file pair for one GCP image.

    Parameters
    ----------
    tlm_key : str
        Path to the telemetry data file.
    sci_key : str
        Path to the science frame timing file.
    config : CorrectionConfig
        Supplies the file format and time-scale factor (via ``config.data``,
        defaulting to CSV with no scaling) and the science time field name
        (via ``config.geo.time_field``).

    Returns
    -------
    tlm_dataset : pandas.DataFrame
        Spacecraft state / telemetry records.
    sci_dataset : pandas.DataFrame
        Science frame timing information, with times scaled to uGPS when a
        non-unity ``time_scale_factor`` is configured.
    ugps_times : array_like
        The (possibly scaled) time column extracted from the science data.

    Raises
    ------
    FileNotFoundError
        If the telemetry or science file does not exist.
    ValueError
        If the science file is missing the required time field.
    """
    if config.data is not None:
        fmt = config.data.file_format
        scale = config.data.time_scale_factor
    else:
        fmt, scale = "csv", 1.0

    tlm_dataset = _load_file(tlm_key, fmt)
    validate_telemetry_output(tlm_dataset, config)

    sci_dataset = _load_file(sci_key, fmt)

    # Scale the science times into uGPS only when both a non-trivial factor is
    # configured and the column is actually present; copy first so the loaded
    # frame is never mutated in place.
    time_field = config.geo.time_field
    if scale != 1.0 and time_field in sci_dataset.columns:
        sci_dataset = sci_dataset.copy()
        sci_dataset[time_field] = sci_dataset[time_field] * scale

    validate_science_output(sci_dataset, config)

    return tlm_dataset, sci_dataset, sci_dataset[time_field]
+ """ + file_format = "csv" + time_scale_factor = 1.0 + + if config.data is not None: + file_format = config.data.file_format + time_scale_factor = config.data.time_scale_factor + + # Load telemetry from file + tlm_dataset = _load_file(tlm_key, file_format) + validate_telemetry_output(tlm_dataset, config) + + # Load science from file + sci_dataset = _load_file(sci_key, file_format) + + # Apply time scale factor to convert to uGPS if needed + time_field = config.geo.time_field + if time_field in sci_dataset.columns and time_scale_factor != 1.0: + sci_dataset = sci_dataset.copy() + sci_dataset[time_field] = sci_dataset[time_field] * time_scale_factor + + validate_science_output(sci_dataset, config) + ugps_times = sci_dataset[time_field] + + return tlm_dataset, sci_dataset, ugps_times + + +def _geolocate_and_match( + config: "CorrectionConfig", + kernel_ctx: KernelContext, + ugps_times_modified: Any, + tlm_dataset: pd.DataFrame, + calibration: CalibrationData, + image_matching_func: Any, + match_ctx: ImageMatchingContext, +) -> tuple[xr.Dataset, xr.Dataset]: + """Perform geolocation and image matching for a parameter set. + + This function loads SPICE kernels, performs geolocation, and runs image + matching against GCP reference data. It's the core computation step that + combines all previous setup (kernels, data loading) into results. 
+ + Parameters + ---------- + config : CorrectionConfig + Configuration with geo and image matching settings + kernel_ctx : KernelContext + NamedTuple containing: + - mkrn: MetaKernel instance with SDS and mission kernels + - dynamic_kernels: List of dynamic kernel file paths + - param_kernels: List of parameter-specific kernel file paths + ugps_times_modified : array-like + Time array (possibly modified by OFFSET_TIME parameter) + tlm_dataset : pd.DataFrame + Spacecraft state telemetry data + calibration : CalibrationData + NamedTuple containing: + - los_vectors: Pre-loaded LOS vectors (or None) + - optical_psfs: Pre-loaded optical PSF (or None) + image_matching_func : callable + Function to perform image matching; defaults to the built-in + ``image_matching`` function when not overridden in config. + match_ctx : ImageMatchingContext + NamedTuple containing: + - gcp_pairs: List of GCP pairing tuples + - params: List of (ParameterConfig, parameter_value) tuples + - pair_idx: Index of current GCP pair + - sci_key: Science dataset identifier for this pair + + Returns + ------- + geo_dataset : xr.Dataset + Geolocated points with latitude, longitude, altitude + image_matching_output : xr.Dataset + Matching results with error measurements and metadata + + Examples + -------- + >>> kernel_ctx = KernelContext(mkrn, dynamic_kernels, param_kernels) + >>> calibration = CalibrationData(los_vectors, optical_psfs) + >>> match_ctx = ImageMatchingContext(gcp_pairs, params, 0, "sci_001") + >>> geo, matching = _geolocate_and_match( + ... config, kernel_ctx, times, tlm_dataset, + ... calibration, integrated_image_match, match_ctx + ... 
def loop(
    config: CorrectionConfig,
    work_dir: Path,
    tlm_sci_gcp_sets: list[tuple[str, str, str]],
    resume_from_checkpoint: bool = False,
):
    """
    Correction loop for parameter sensitivity analysis.

    Parameters
    ----------
    config : CorrectionConfig
        The single configuration containing all settings:
        - Required: parameters, iterations, thresholds, geo config
        - Data loading: ``data`` (:class:`~curryer.correction.config.DataConfig`)
          specifying file format and time scaling
        - Optional: ``_image_matching_override`` override on ``config``
          (test injection only)
        - Calibration: `calibration_dir` (if the image-matching override uses calibration)
        - Output: netcdf, output_filename
    work_dir : Path
        Working directory for temporary files.
    tlm_sci_gcp_sets : list of (str, str, str)
        List of (`telemetry_key`, `science_key`, `gcp_key`) tuples.
        File paths are expected to be local. S3 URIs (``s3://…``) are also
        accepted as a convenience when ``boto3`` is installed; see
        :func:`~curryer.correction.io.resolve_path`.
    resume_from_checkpoint : bool, optional
        If True, resume from an existing checkpoint.

    Returns
    -------
    results : list
        List of iteration results (order: `pair_idx * N + param_idx`).
    netcdf_data : dict
        Dictionary of NetCDF variables indexed as `[param_idx, pair_idx]`.

    Notes
    -----
    This implementation uses a pair-outer, parameter-inner loop order:
    - Outer loop: GCP pairs (load data once per image)
    - Inner loop: Parameter sets (reuse loaded data)
    This reduces file I/O and centralizes mission-specific behavior through the
    `config` object.

    Examples
    --------
    Correction mode (parameter optimization)::

        from curryer.correction.config import CorrectionConfig, DataConfig

        config = CorrectionConfig(
            seed=42,
            n_iterations=100,
            parameters=parameters,
            geo=geo_config,
            performance_threshold_m=250.0,
            performance_spec_percent=39.0,
            data=DataConfig(file_format="csv", time_scale_factor=1e6),
        )
        results, netcdf_data = loop(config, work_dir, tlm_sci_gcp_sets)

    Where each element of ``tlm_sci_gcp_sets`` is a tuple of file paths::

        tlm_sci_gcp_sets = [
            ("telemetry.csv", "science.csv", "landsat_chip_001.mat"),
        ]
    """
    logger.info("=== CORRECTION PIPELINE ===")
    logger.info(f" GCP pairs: {len(tlm_sci_gcp_sets)} (outer loop - load data once)")

    # Use injected image matching function override, or fall back to built-in implementation
    image_matching_func = getattr(config, "_image_matching_override", None) or image_matching

    # Initialize parameter sets
    params_set = load_param_sets(config)
    logger.info(f" Parameter sets: {len(params_set)} (inner loop)")

    # Build NetCDF data structure
    n_param_sets = len(params_set)
    n_gcp_pairs = len(tlm_sci_gcp_sets)

    # Try to load checkpoint if resuming
    output_file = work_dir / config.get_output_filename()
    start_pair_idx = 0
    # Currently, checkpoint is bugged, since the nadir equivalent stats are not calculated until the end.
    # TODO [CURRYER-100]: Fix checkpoint resume for Monte Carlo GCS
    if resume_from_checkpoint:
        checkpoint_data, completed_pairs = _load_checkpoint(output_file, config)
        if checkpoint_data is not None:
            netcdf_data = checkpoint_data
            start_pair_idx = completed_pairs
            logger.info(f"Resuming from checkpoint: starting at GCP pair {start_pair_idx + 1}/{n_gcp_pairs}")
        else:
            netcdf_data = _build_netcdf_structure(config, n_param_sets, n_gcp_pairs)
            logger.info("No valid checkpoint found, starting from beginning")
    else:
        netcdf_data = _build_netcdf_structure(config, n_param_sets, n_gcp_pairs)

    # Initialize results dict with (param_idx, pair_idx) keys
    # This avoids nested search complexity when aggregating statistics
    results_dict = {}

    # Prepare SPICE environment
    mkrn = meta.MetaKernel.from_json(
        config.geo.meta_kernel_file,
        relative=True,
        sds_dir=config.geo.generic_kernel_dir,
    )
    creator = create.KernelCreator(overwrite=True, append=False)

    # Load calibration data once (LOS vectors and optical PSF are static instrument calibration)
    calibration_data = _load_calibration_data(config)

    # Create error stats processor once (config is constant; processor is stateless)
    error_config = ErrorStatsConfig.from_correction_config(config)
    error_processor = ErrorStatsProcessor(config=error_config)

    # Store parameter values once (before loops)
    for param_idx, params in enumerate(params_set):
        param_values = _extract_parameter_values(params)
        _store_parameter_values(netcdf_data, param_idx, param_values)

    # OUTER LOOP: Iterate through GCP pairs
    for pair_idx, (tlm_key, sci_key, gcp_key) in enumerate(tlm_sci_gcp_sets):
        # Skip already-completed pairs if resuming
        if pair_idx < start_pair_idx:
            logger.info(f"=== GCP Pair {pair_idx + 1}/{n_gcp_pairs}: {sci_key} === (SKIPPED - already completed)")
            continue

        logger.info(f"=== GCP Pair {pair_idx + 1}/{n_gcp_pairs}: {sci_key} ===")

        # Load image pair data once (internal file-based loading)
        tlm_dataset, sci_dataset, ugps_times = _load_image_pair_data(tlm_key, sci_key, config)

        # Create dynamic kernels once (these don't change with parameters)
        dynamic_kernels = _create_dynamic_kernels(config, work_dir, tlm_dataset, creator)

        # Use gcp_key directly as the GCP file path — no pairing function needed.
        # Users specify exactly which GCP file pairs with each science file in
        # tlm_sci_gcp_sets. An empty string disables image matching for that pair.
        gcp_pairs = [(sci_key, gcp_key)]
        logger.info(f" GCP file: {gcp_key or '(none)'}")

        # INNER LOOP: Iterate through parameter sets
        for param_idx, params in enumerate(params_set):
            logger.info(f" Parameter Set {param_idx + 1}/{n_param_sets}")

            # Create parameter-specific kernels (these change with parameters)
            param_kernels, ugps_times_modified = _create_parameter_kernels(
                params, work_dir, tlm_dataset, sci_dataset, ugps_times, config, creator
            )

            # Prepare context objects for cleaner function call
            kernel_ctx = KernelContext(mkrn=mkrn, dynamic_kernels=dynamic_kernels, param_kernels=param_kernels)
            match_ctx = ImageMatchingContext(gcp_pairs=gcp_pairs, params=params, pair_idx=pair_idx, sci_key=sci_key)

            # Geolocate and perform image matching
            geo_dataset, image_matching_output = _geolocate_and_match(
                config,
                kernel_ctx,
                ugps_times_modified,
                tlm_dataset,
                calibration_data,
                image_matching_func,
                match_ctx,
            )

            # Compute nadir-equivalent errors for this GCP pair.
            # compute_nadir_equivalent_errors() skips aggregate statistics —
            # computing mean/std/percentiles on a single GCP pair is
            # mathematically uninformative and wastes time in a tight loop.
            individual_nadir = error_processor.compute_nadir_equivalent_errors(image_matching_output)

            nadir_errors = individual_nadir["nadir_equiv_total_error_m"].values
            if len(nadir_errors) == 1:
                # Single measurement: all point statistics collapse to that value.
                nadir_error = float(nadir_errors[0])
                individual_metrics = {
                    "rms_error_m": nadir_error,
                    "mean_error_m": nadir_error,
                    "max_error_m": nadir_error,
                    "std_error_m": 0.0,
                    "n_measurements": 1,
                }
            else:
                individual_metrics = {
                    "rms_error_m": float(np.sqrt(np.mean(nadir_errors**2))),
                    "mean_error_m": float(np.mean(nadir_errors)),
                    "max_error_m": float(np.max(nadir_errors)),
                    "std_error_m": float(np.std(nadir_errors)),
                    "n_measurements": len(nadir_errors),
                }
            individual_stats = individual_nadir

            # Store results in NetCDF (maintain [param_idx, pair_idx] ordering)
            _store_gcp_pair_results(netcdf_data, param_idx, pair_idx, individual_metrics)
            netcdf_data["im_lat_error_km"][param_idx, pair_idx] = image_matching_output.attrs.get(
                "lat_error_km", np.nan
            )
            netcdf_data["im_lon_error_km"][param_idx, pair_idx] = image_matching_output.attrs.get(
                "lon_error_km", np.nan
            )
            netcdf_data["im_ccv"][param_idx, pair_idx] = image_matching_output.attrs.get("correlation_ccv", np.nan)
            netcdf_data["im_grid_step_m"][param_idx, pair_idx] = image_matching_output.attrs.get(
                "final_grid_step_m", np.nan
            )

            # Store results in dict with (param_idx, pair_idx) key
            # Note: iteration index reflects reversed order (pair_idx * n_params + param_idx)
            param_values = _extract_parameter_values(params)
            iteration_result = {
                "iteration": pair_idx * n_param_sets + param_idx,
                "pair_index": pair_idx,
                "param_index": param_idx,
                "parameters": param_values,
                "geolocation": geo_dataset,
                "gcp_pairs": gcp_pairs,
                "image_matching": image_matching_output,
                "error_stats": individual_stats,
                "rms_error_m": individual_metrics["rms_error_m"],
                # Filled in after all pairs complete (aggregate pass below).
                "aggregate_rms_error_m": None,
            }
            results_dict[(param_idx, pair_idx)] = iteration_result

            logger.info(
                f" RMS error: {individual_metrics['rms_error_m']:.2f}m "
                f"({individual_metrics['n_measurements']} measurements)"
            )

        logger.info(f" GCP pair {pair_idx + 1} complete (processed {n_param_sets} parameter sets)")

        # Save checkpoint after each pair completes
        if resume_from_checkpoint:
            _save_netcdf_checkpoint(netcdf_data, output_file, config, pair_idx)

    # Compute aggregate statistics for each parameter set (after all pairs complete)
    logger.info("=== Computing aggregate statistics for all parameter sets ===")
    for param_idx in range(n_param_sets):
        # Collect all image matching results for this parameter set
        param_image_matching_results = []
        for pair_idx in range(n_gcp_pairs):
            result = results_dict.get((param_idx, pair_idx))
            if result:
                param_image_matching_results.append(result["image_matching"])

        # Compute aggregate statistics
        aggregate_stats = call_error_stats_module(param_image_matching_results, correction_config=config)
        aggregate_error_metrics = _extract_error_metrics(aggregate_stats)

        # Extract pair errors for threshold calculation
        pair_errors = [netcdf_data["rms_error_m"][param_idx, pair_idx] for pair_idx in range(n_gcp_pairs)]
        _compute_parameter_set_metrics(netcdf_data, param_idx, pair_errors, threshold_m=config.performance_threshold_m)

        logger.info(f" Parameter set {param_idx + 1}: Aggregate RMS = {aggregate_error_metrics['rms_error_m']:.2f}m")

        # Add aggregate stats to all results for this parameter set
        for pair_idx in range(n_gcp_pairs):
            key = (param_idx, pair_idx)
            if key in results_dict:
                results_dict[key]["aggregate_error_stats"] = aggregate_stats
                results_dict[key]["aggregate_rms_error_m"] = aggregate_error_metrics["rms_error_m"]
    # Convert results_dict back to list for backward compatibility
    # Sort by iteration index to maintain consistent ordering
    results = [results_dict[key] for key in sorted(results_dict.keys(), key=lambda k: results_dict[k]["iteration"])]

    # Save final NetCDF results
    _save_netcdf_results(netcdf_data, output_file, config)

    # Clean up checkpoint file after successful completion
    if resume_from_checkpoint:
        _cleanup_checkpoint(output_file)

    logger.info(f"=== Loop Complete: Processed {n_gcp_pairs} GCP pairs × {n_param_sets} parameter sets ===")
    logger.info(f" Total iterations: {len(results)}")
    logger.info(f" NetCDF output: {output_file}")

    return results, netcdf_data
    _save_netcdf_results(netcdf_data, output_file, config)

    # Clean up checkpoint file after successful completion
    if resume_from_checkpoint:
        _cleanup_checkpoint(output_file)

    logger.info(f"=== Loop Complete: Processed {n_gcp_pairs} GCP pairs × {n_param_sets} parameter sets ===")
    logger.info(f" Total iterations: {len(results)}")
    logger.info(f" NetCDF output: {output_file}")

    return results, netcdf_data


def _extract_parameter_values(params):
    """Extract parameter values from a parameter set into a dictionary.

    Args:
        params: Iterable of ``(param_config, param_data)`` pairs, where
            ``param_config`` carries the parameter metadata (``config_file``,
            ``ptype``) and ``param_data`` holds the sampled value(s).

    Returns:
        dict: Mapping of flattened parameter names to scalar values.
            CONSTANT_KERNEL entries expand into ``<name>_roll``,
            ``<name>_pitch`` and ``<name>_yaw`` components (arcseconds);
            OFFSET_KERNEL / OFFSET_TIME entries are stored as-is.
    """
    param_values = {}

    for param_config, param_data in params:
        # Parameters without a config file have no stable name to key on; skip.
        if param_config.config_file:
            param_name = param_config.config_file.stem

            if param_config.ptype == ParameterType.CONSTANT_KERNEL:
                # Extract roll, pitch, yaw from DataFrame
                if isinstance(param_data, pd.DataFrame) and "angle_x" in param_data.columns:
                    # Convert back to arcseconds for storage
                    # (np.degrees implies the angles are held in radians; 1 deg = 3600 arcsec)
                    param_values[f"{param_name}_roll"] = np.degrees(param_data["angle_x"].iloc[0]) * 3600
                    param_values[f"{param_name}_pitch"] = np.degrees(param_data["angle_y"].iloc[0]) * 3600
                    param_values[f"{param_name}_yaw"] = np.degrees(param_data["angle_z"].iloc[0]) * 3600

            elif param_config.ptype == ParameterType.OFFSET_KERNEL:
                # Single bias value (keep in original units)
                param_values[param_name] = param_data

            elif param_config.ptype == ParameterType.OFFSET_TIME:
                # Time correction (keep in original units)
                param_values[param_name] = param_data

    return param_values


def _store_parameter_values(netcdf_data, param_idx, param_values):
    """Store parameter values in the NetCDF data structure.

    This function maps parameter names to NetCDF variable names for storage.
    It handles the naming convention used by _build_netcdf_structure.

    Args:
        netcdf_data: Dict-like NetCDF data structure (variable name -> array).
        param_idx: Parameter-set index (first axis of each parameter variable).
        param_values: Mapping of parameter names to scalar values, as produced
            by :func:`_extract_parameter_values`.
    """

    for param_name, value in param_values.items():
        # Generate NetCDF variable name using same logic as _build_netcdf_structure
        # Replace dots and dashes with underscores, ensure param_ prefix
        netcdf_var = param_name.replace(".", "_").replace("-", "_")
        if not netcdf_var.startswith("param_"):
            netcdf_var = f"param_{netcdf_var}"

        if netcdf_var in netcdf_data:
            netcdf_data[netcdf_var][param_idx] = value
            logger.debug(f" Stored {netcdf_var}[{param_idx}] = {value}")
        else:
            # Try to find a matching variable with debug info
            logger.warning(
                f" Parameter variable '{netcdf_var}' not found in netcdf_data. Available keys: {[k for k in netcdf_data.keys() if k.startswith('param_')]}"
            )


def _extract_error_metrics(stats_dataset):
    """Extract error metrics from error statistics dataset.

    Accepts either an xarray-style dataset (metrics carried in ``.attrs``)
    or a plain mapping used by placeholder implementations.

    Args:
        stats_dataset: Error-statistics container.

    Returns:
        dict: Keys ``rms_error_m``, ``mean_error_m``, ``max_error_m``,
            ``std_error_m`` and ``n_measurements`` (NaN / 0 when absent).
    """
    if hasattr(stats_dataset, "attrs"):
        # Real error stats module
        return {
            "rms_error_m": stats_dataset.attrs.get("rms_error_m", np.nan),
            "mean_error_m": stats_dataset.attrs.get("mean_error_m", np.nan),
            "max_error_m": stats_dataset.attrs.get("max_error_m", np.nan),
            "std_error_m": stats_dataset.attrs.get("std_error_m", np.nan),
            "n_measurements": stats_dataset.attrs.get("total_measurements", 0),
        }
    else:
        # Fallback for placeholder
        return {
            "rms_error_m": float(stats_dataset.get("rms_error", np.nan)),
            "mean_error_m": float(stats_dataset.get("mean_error", np.nan)),
            "max_error_m": float(stats_dataset.get("max_error", np.nan)),
            "std_error_m": float(stats_dataset.get("std_error", np.nan)),
            "n_measurements": int(stats_dataset.get("n_measurements", 0)),
        }


def _store_gcp_pair_results(netcdf_data, param_idx, pair_idx, error_metrics):
    """Store GCP pair results in the NetCDF data structure."""
    # Per-pair metrics are indexed [param_idx, pair_idx] to match the sweep layout.
    netcdf_data["rms_error_m"][param_idx, pair_idx] = error_metrics["rms_error_m"]
    netcdf_data["mean_error_m"][param_idx, pair_idx] = error_metrics["mean_error_m"]
    netcdf_data["max_error_m"][param_idx, pair_idx] = error_metrics["max_error_m"]
    netcdf_data["std_error_m"][param_idx, pair_idx] = error_metrics["std_error_m"]
    netcdf_data["n_measurements"][param_idx, pair_idx] = error_metrics["n_measurements"]


def _compute_parameter_set_metrics(netcdf_data, param_idx, pair_errors, threshold_m=250.0):
    """
    Compute overall performance metrics for a parameter set.

    Args:
        netcdf_data: NetCDF data dictionary
        param_idx: Parameter set index
        pair_errors: Array of RMS errors for each GCP pair
        threshold_m: Performance threshold in meters

    Notes:
        NaN pair errors (e.g. failed pairs) are excluded before aggregating.
        If every pair error is NaN, no metrics are stored for this set.
    """
    pair_errors = np.array(pair_errors)
    valid_errors = pair_errors[~np.isnan(pair_errors)]

    if len(valid_errors) > 0:
        # Percentage of pairs with error < threshold
        # Find the threshold metric key dynamically
        # NOTE(review): the first "percent_under_*m" key wins; it is assumed to
        # correspond to `threshold_m` — confirm against _build_netcdf_structure.
        threshold_metric = None
        for key in netcdf_data.keys():
            if key.startswith("percent_under_") and key.endswith("m"):
                threshold_metric = key
                break

        if threshold_metric:
            percent_under_threshold = (valid_errors < threshold_m).sum() / len(valid_errors) * 100
            netcdf_data[threshold_metric][param_idx] = percent_under_threshold

        # Mean RMS across all pairs
        netcdf_data["mean_rms_all_pairs"][param_idx] = np.mean(valid_errors)

        # Best and worst pair performance
        netcdf_data["best_pair_rms"][param_idx] = np.min(valid_errors)
        netcdf_data["worst_pair_rms"][param_idx] = np.max(valid_errors)


# =============================================================================
# Incremental NetCDF Saving (Checkpoint/Resume)
# =============================================================================

# =============================================================================
# Preferred-name aliases (backward-compat originals kept above)
# =============================================================================


def run_correction(
    config: CorrectionConfig,
    work_dir: Path,
    inputs: Sequence[CorrectionInput | tuple[str, str, str]],
    resume_from_checkpoint: bool = False,
) -> "CorrectionResult":
    """Run the correction parameter sweep.

    This is the preferred user-facing entry point (compared to :func:`loop`).
    Returns a structured :class:`~curryer.correction.results.CorrectionResult`
    with the best parameter set, pass/fail verdict, recommendation, and a
    human-readable summary table. The raw ``results`` list and ``netcdf_data``
    dict from :func:`loop` are available as ``result.results`` and
    ``result.netcdf_data`` for advanced use.

    Parameters
    ----------
    config : CorrectionConfig
        Full correction configuration.
    work_dir : Path
        Working directory for temporary files.
    inputs : list of CorrectionInput or list of (str, str, str)
        Each element is either a :class:`~curryer.correction.config.CorrectionInput`
        (named fields) or a legacy ``(telemetry_key, science_key, gcp_key)`` tuple.
        Both forms may be mixed in the same list.
        File paths are expected to be local. S3 URIs (``s3://…``) are also
        accepted as a convenience when ``boto3`` is installed; see
        :func:`~curryer.correction.io.resolve_path`.
    resume_from_checkpoint : bool, optional
        If True, resume from an existing checkpoint.

    Returns
    -------
    CorrectionResult
        Structured result with best parameters, pass/fail verdict,
        recommendation, summary table, and raw NetCDF/intermediate data
        available on the returned object (for example,
        ``result.netcdf_data``).
    """
    # Imported lazily to avoid a circular import at module load time.
    from curryer.correction.results import build_correction_result

    run_start = time.time()

    # Normalise every input to the legacy 3-tuple form accepted by loop().
    normalized: list[tuple[str, str, str]] = []
    for inp in inputs:
        if isinstance(inp, CorrectionInput):
            normalized.append((str(inp.telemetry_file), str(inp.science_file), str(inp.gcp_file)))
        else:
            normalized.append(inp)

    results, netcdf_data = loop(config, work_dir, normalized, resume_from_checkpoint)
    elapsed = time.time() - run_start
    netcdf_path = work_dir / config.get_output_filename()

    correction_result = build_correction_result(
        config=config,
        results=results,
        netcdf_data=netcdf_data,
        netcdf_path=netcdf_path,
        elapsed_time_s=elapsed,
    )

    logger.info("\n%s", correction_result.summary_table)
    logger.info(correction_result.recommendation)

    return correction_result


def compute_error_stats(image_matching_results, correction_config: "CorrectionConfig"):
    """Compute error statistics from image matching results.

    This is the preferred name for :func:`call_error_stats_module`.
    See :func:`call_error_stats_module` for full documentation.

    Parameters
    ----------
    image_matching_results : xr.Dataset or list of xr.Dataset
        Output from image matching, either a single dataset or a list.
    correction_config : CorrectionConfig
        Correction configuration used to initialise the error stats processor.

    Returns
    -------
    xr.Dataset
        Aggregate error statistics dataset.
    """
    return call_error_stats_module(image_matching_results, correction_config)


def run_image_matching(
    geolocated_data: "xr.Dataset",
    gcp_reference_file: Path,
    telemetry: "pd.DataFrame",
    calibration_dir: Path,
    params_info: list,
    config: "CorrectionConfig",
    los_vectors_cached: "np.ndarray | None" = None,
    optical_psfs_cached: "list | None" = None,
) -> "xr.Dataset":
    """Run image matching against GCP reference.

    This is the preferred name for :func:`image_matching`.
    See :func:`image_matching` for full documentation.
+ + Parameters + ---------- + geolocated_data : xr.Dataset + Geolocated scene data with latitude/longitude. + gcp_reference_file : Path + Path to the GCP reference image (.mat file). + telemetry : pd.DataFrame + Telemetry DataFrame with spacecraft state. + calibration_dir : Path + Directory containing calibration files. + params_info : list + Parameter information for the current iteration. + config : CorrectionConfig + Full correction configuration. + los_vectors_cached : np.ndarray or None, optional + Pre-loaded LOS vectors; loaded from disk if None. + optical_psfs_cached : list or None, optional + Pre-loaded optical PSFs; loaded from disk if None. + + Returns + ------- + xr.Dataset + Image matching results dataset. + """ + return image_matching( + geolocated_data, + gcp_reference_file, + telemetry, + calibration_dir, + params_info, + config, + los_vectors_cached, + optical_psfs_cached, + ) diff --git a/curryer/correction/psf.py b/curryer/correction/psf.py index 57e75a21..8a4ea298 100644 --- a/curryer/correction/psf.py +++ b/curryer/correction/psf.py @@ -14,11 +14,11 @@ from ..compute import constants from ..compute.spatial import ecef_to_geodetic, geodetic_to_ecef from .data_structures import ( - GeolocationConfig, ImageGrid, OpticalPSFEntry, ProjectedPSF, PSFGrid, + PSFSamplingConfig, ) logger = logging.getLogger(__name__) @@ -234,7 +234,7 @@ def convolve_gcp_with_psf(gcp: ImageGrid, psf: PSFGrid) -> ImageGrid: def convolve_psf_with_spacecraft_motion( psf: ProjectedPSF, composite_img: ImageGrid, - config: GeolocationConfig, + config: PSFSamplingConfig, ) -> PSFGrid: """ Apply spacecraft motion blur to projected PSF. @@ -245,7 +245,7 @@ def convolve_psf_with_spacecraft_motion( Projected PSF on Earth's surface. composite_img : ImageGrid Composite image defining spacecraft motion direction. - config : GeolocationConfig + config : PSFSamplingConfig Configuration with PSF sampling parameters. 
Returns diff --git a/curryer/correction/regrid.py b/curryer/correction/regrid.py new file mode 100644 index 00000000..56f630c9 --- /dev/null +++ b/curryer/correction/regrid.py @@ -0,0 +1,752 @@ +"""GCP chip regridding algorithms. + +This module provides functionality to transform GCP chips from irregular +geodetic grids (derived from ECEF coordinates) to regular latitude/longitude +grids. The regridding process is mission-agnostic and configurable via RegridConfig. + +The main workflow is: +1. Load raw GCP chip with ECEF coordinates (from HDF file) +2. Convert ECEF → geodetic (lon, lat, h) +3. Determine output grid bounds and spacing +4. Interpolate data onto regular grid using bilinear interpolation +5. Return ImageGrid with regular lat/lon coordinates +""" + +from __future__ import annotations + +import logging + +import numpy as np + +from ..compute.spatial import ecef_to_geodetic +from .data_structures import ImageGrid, RegridConfig + +logger = logging.getLogger(__name__) + + +def compute_regular_grid_bounds( + lon_irregular: np.ndarray, + lat_irregular: np.ndarray, + conservative: bool = True, +) -> tuple[float, float, float, float]: + """ + Compute bounds for regular output grid from irregular input. + + Parameters + ---------- + lon_irregular, lat_irregular : np.ndarray + 2D arrays of irregular grid coordinates (degrees). + conservative : bool, default=True + If True, shrink bounds to ensure all output points are within input. + Conservative bounds avoid extrapolation at edges by taking the maximum + of left/bottom edges and minimum of right/top edges. + + Returns + ------- + minlon, maxlon, minlat, maxlat : float + Bounding box for regular grid (degrees). 

    Notes
    -----
    Conservative bounds (default) follow MATLAB behavior:
    - minlon = max(bottom_left_lon, top_left_lon)
    - maxlon = min(bottom_right_lon, top_right_lon)
    - minlat = max(bottom_left_lat, bottom_right_lat)
    - maxlat = min(top_left_lat, top_right_lat)

    This ensures the regular grid lies entirely within the irregular grid.
    """
    if conservative:
        # Get corner coordinates (assuming row increases south, col increases east)
        # Corners: [0,0]=top_left, [0,-1]=top_right, [-1,0]=bottom_left, [-1,-1]=bottom_right
        minlon = max(lon_irregular[-1, 0], lon_irregular[0, 0])  # bottom-left, top-left
        maxlon = min(lon_irregular[0, -1], lon_irregular[-1, -1])  # top-right, bottom-right
        minlat = max(lat_irregular[-1, 0], lat_irregular[-1, -1])  # bottom corners
        maxlat = min(lat_irregular[0, 0], lat_irregular[0, -1])  # top corners
    else:
        # Use full extent
        minlon = float(lon_irregular.min())
        maxlon = float(lon_irregular.max())
        minlat = float(lat_irregular.min())
        maxlat = float(lat_irregular.max())

    return minlon, maxlon, minlat, maxlat


def create_regular_grid(
    bounds: tuple[float, float, float, float],
    grid_size: tuple[int, int] | None = None,
    resolution: tuple[float, float] | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Create regular lat/lon grid.

    Parameters
    ----------
    bounds : tuple[float, float, float, float]
        (minlon, maxlon, minlat, maxlat) in degrees.
    grid_size : tuple[int, int], optional
        (nrows, ncols). If None, derive from resolution.
    resolution : tuple[float, float], optional
        (dlat, dlon) in degrees. If None, use grid_size.

    Returns
    -------
    lon_regular, lat_regular : np.ndarray
        2D arrays of regular grid coordinates (degrees), shape (nrows, ncols).

    Notes
    -----
    Exactly one of grid_size or resolution must be provided.

    Grid follows MATLAB convention:
    - Row index increases going south (latitude decreases)
    - Column index increases going east (longitude increases)
    """
    minlon, maxlon, minlat, maxlat = bounds

    if grid_size is not None and resolution is not None:
        raise ValueError("Specify only one of grid_size or resolution, not both")
    if grid_size is None and resolution is None:
        raise ValueError("Must specify either grid_size or resolution")

    if grid_size is not None:
        nrows, ncols = grid_size
        dlat = (maxlat - minlat) / (nrows - 1)
        dlon = (maxlon - minlon) / (ncols - 1)
    else:
        # Derive the grid size from the requested spacing (both endpoints inclusive).
        dlat, dlon = resolution
        nrows = round((maxlat - minlat) / dlat) + 1
        ncols = round((maxlon - minlon) / dlon) + 1

    # Create 1D coordinate arrays using linspace so that the endpoints are
    # exact regardless of floating-point accumulation in step arithmetic.
    # (Using arange-based approach risks the last element overshooting the
    # bound by a tiny amount, placing output points just outside the input
    # grid and producing spurious NaN fill values at the edges.)
    lat_1d = np.linspace(maxlat, minlat, nrows)  # north → south
    lon_1d = np.linspace(minlon, maxlon, ncols)  # west → east

    # Create 2D meshgrid
    lon_regular, lat_regular = np.meshgrid(lon_1d, lat_1d)

    logger.debug(f"Created regular grid: {nrows}×{ncols}, resolution: ({dlat:.6f}°, {dlon:.6f}°)")

    return lon_regular, lat_regular


def _cross2(a: np.ndarray, b: np.ndarray) -> float:
    """2D cross product: a[0]*b[1] - a[1]*b[0]."""
    return a[0] * b[1] - a[1] * b[0]


def point_in_triangle(
    point: np.ndarray,
    triangle: np.ndarray,
) -> tuple[bool, np.ndarray]:
    """
    Check if point is inside triangle using barycentric coordinates.

    Parameters
    ----------
    point : np.ndarray
        Point [x, y] to test.
    triangle : np.ndarray
        Triangle vertices, shape (3, 2): [[x1, y1], [x2, y2], [x3, y3]].

    Returns
    -------
    inside : bool
        True if point is inside triangle (barycentric coords all in (0, 1)).
    barycentric_coords : np.ndarray
        Barycentric coordinates [w1, w2, w3].

    Notes
    -----
    Uses the cross product method from MATLAB bandval function.
    Point P is inside triangle ABC if all barycentric weights are in (0, 1).
    """
    A, B, C = triangle

    # Compute denominator (area of triangle * 2)
    # MATLAB: d=cross2(A,B)+cross2(B,C)+cross2(C,A)
    d = _cross2(A - C, B - C)  # Simplified: same as above

    if abs(d) < 1e-14:  # Degenerate triangle
        return False, np.array([0.0, 0.0, 0.0])

    # Compute barycentric coordinates following MATLAB bandval logic
    # MATLAB: wA=(cross2(B,C)+cross2(P,B-C))/d
    wA = _cross2(point - C, B - C) / d
    wB = _cross2(A - C, point - C) / d
    wC = 1.0 - wA - wB

    # Check if point is inside (all weights in [0, 1])
    # Use small tolerance for boundary cases
    # MATLAB uses strict inequalities (0 < w < 1), but for numerical stability
    # and to handle edge cases in regridding, we use tolerant checks
    tol = 1e-10
    inside = (-tol <= wA <= 1 + tol) and (-tol <= wB <= 1 + tol) and (-tol <= wC <= 1 + tol)

    return inside, np.array([wA, wB, wC])


def bilinear_interpolate_quad(
    point: np.ndarray,
    corners_lon: np.ndarray,
    corners_lat: np.ndarray,
    corner_values: np.ndarray,
) -> float:
    """
    Bilinear interpolation within an irregular quadrilateral.

    Parameters
    ----------
    point : np.ndarray
        Target point [lon, lat] (degrees).
    corners_lon, corners_lat : np.ndarray
        Coordinates of 4 corners, ordered clockwise from top-left:
        [top-left, top-right, bottom-right, bottom-left].
    corner_values : np.ndarray
        Values at the 4 corners.

    Returns
    -------
    interpolated_value : float
        Interpolated value at target point.

    Notes
    -----
    Uses matrix inversion method from MATLAB code:
    Solves [1, lon, lat, lon*lat]^T = M * [w1, w2, w3, w4]^T
    where M is constructed from corner coordinates.
+ """ + lon_p, lat_p = point + + # Build interpolation matrix (4x4) + # M = [[1, 1, 1, 1 ] + # [lon1, lon2, lon3, lon4] + # [lat1, lat2, lat3, lat4] + # [lon1*lat1, lon2*lat2, lon3*lat3, lon4*lat4]] + M = np.ones((4, 4)) + M[1, :] = corners_lon + M[2, :] = corners_lat + M[3, :] = corners_lon * corners_lat + + # Right-hand side vector + E = np.array([1.0, lon_p, lat_p, lon_p * lat_p]) + + # Solve for weights + try: + weights = np.linalg.solve(M, E) + except np.linalg.LinAlgError: + # Singular matrix - degenerate quadrilateral + # Fall back to simple average + logger.warning("Singular matrix in bilinear interpolation, using average") + return float(np.mean(corner_values)) + + # Compute interpolated value + value = np.dot(weights, corner_values) + + return float(value) + + +def _check_point_in_cell( + point_lon: float, + point_lat: float, + lon_grid: np.ndarray, + lat_grid: np.ndarray, + i: int, + j: int, +) -> bool: + """Fast check if point is in cell [i,j]. Returns True/False.""" + tol = 1e-10 + + # Get corner coordinates + lon_tl, lat_tl = lon_grid[i, j], lat_grid[i, j] + lon_tr, lat_tr = lon_grid[i, j + 1], lat_grid[i, j + 1] + lon_br, lat_br = lon_grid[i + 1, j + 1], lat_grid[i + 1, j + 1] + lon_bl, lat_bl = lon_grid[i + 1, j], lat_grid[i + 1, j] + + # Test upper-left triangle (TL, TR, BL) + d_ul = (lon_tl - lon_bl) * (lat_tr - lat_bl) - (lat_tl - lat_bl) * (lon_tr - lon_bl) + + if abs(d_ul) > 1e-14: + wA = ((point_lon - lon_bl) * (lat_tr - lat_bl) - (point_lat - lat_bl) * (lon_tr - lon_bl)) / d_ul + wB = ((lon_tl - lon_bl) * (point_lat - lat_bl) - (lat_tl - lat_bl) * (point_lon - lon_bl)) / d_ul + wC = 1.0 - wA - wB + + if (-tol <= wA <= 1 + tol) and (-tol <= wB <= 1 + tol) and (-tol <= wC <= 1 + tol): + return True + + # Test lower-right triangle (TR, BR, BL) + d_lr = (lon_tr - lon_bl) * (lat_br - lat_bl) - (lat_tr - lat_bl) * (lon_br - lon_bl) + + if abs(d_lr) > 1e-14: + wA = ((point_lon - lon_bl) * (lat_br - lat_bl) - (point_lat - lat_bl) * (lon_br - lon_bl)) 
/ d_lr + wB = ((lon_tr - lon_bl) * (point_lat - lat_bl) - (lat_tr - lat_bl) * (point_lon - lon_bl)) / d_lr + wC = 1.0 - wA - wB + + if (-tol <= wA <= 1 + tol) and (-tol <= wB <= 1 + tol) and (-tol <= wC <= 1 + tol): + return True + + return False + + +def find_containing_cell( + point: np.ndarray, + lon_grid: np.ndarray, + lat_grid: np.ndarray, + start_cell: tuple[int, int] | None = None, +) -> tuple[int, int] | None: + """ + Find which cell in irregular grid contains the target point. + + Parameters + ---------- + point : np.ndarray + Target point [lon, lat] (degrees). + lon_grid, lat_grid : np.ndarray + 2D arrays of irregular grid coordinates (degrees). + start_cell : tuple[int, int], optional + Starting cell (i, j) for search (optimization hint). + + Returns + ------- + cell_indices : tuple[int, int] or None + (i, j) of cell containing point, or None if not found. + + Notes + ----- + Uses barycentric coordinate test to check if point is inside + quadrilateral. For each cell, tests two triangles (upper-left and + lower-right) that together form the quadrilateral. + + Search strategy follows MATLAB optimization: + - Start from hint if provided + - Check cells near last found cell (spatial locality) + - If not found, search all cells + + Optimization: Inline triangle test to avoid array allocations. 
+ """ + nrows, ncols = lon_grid.shape + max_i, max_j = nrows - 1, ncols - 1 + + point_lon, point_lat = point[0], point[1] + + def check_cell(i: int, j: int) -> bool: + """Check if point is in cell [i,j] by delegating to the shared helper.""" + return _check_point_in_cell( + lon_grid=lon_grid, + lat_grid=lat_grid, + point_lon=point_lon, + point_lat=point_lat, + i=i, + j=j, + ) + + # Determine search start + if start_cell is not None: + start_i = max(0, min(start_cell[0] - 1, max_i - 1)) + start_j = max(0, min(start_cell[1] - 1, max_j - 1)) + + # Search in a small window around hint first + window_size = 3 + for di in range(-window_size, window_size + 1): + for dj in range(-window_size, window_size + 1): + i = start_i + di + j = start_j + dj + if 0 <= i < max_i and 0 <= j < max_j: + if check_cell(i, j): + return (i, j) + + # Not found in window - return None (caller will use spatial index or expand search) + return None + else: + # No hint provided - do full search (only happens for very first point or edge cases) + # This is expensive but necessary for correctness when no hint is available + for i in range(max_i): + for j in range(max_j): + if check_cell(i, j): + return (i, j) + + # Not found + return None + + +def regrid_irregular_to_regular( + data_irregular: np.ndarray, + lon_irregular: np.ndarray, + lat_irregular: np.ndarray, + lon_regular: np.ndarray, + lat_regular: np.ndarray, + method: str = "bilinear", + fill_value: float = np.nan, + use_spatial_index: bool = True, +) -> np.ndarray: + """ + Regrid data from irregular geodetic grid to regular lat/lon grid. + + This is the core algorithm: for each point in the regular output grid, + find the corresponding quadrilateral cell in the irregular input grid + and interpolate the value. + + Parameters + ---------- + data_irregular : np.ndarray + 2D array of values on irregular grid. + lon_irregular, lat_irregular : np.ndarray + 2D arrays of irregular grid coordinates (degrees). 
    lon_regular, lat_regular : np.ndarray
        2D arrays of regular grid coordinates (degrees).
    method : str, default="bilinear"
        Interpolation method: "bilinear" or "nearest".
    fill_value : float, default=np.nan
        Value for output points that fall outside input grid.
    use_spatial_index : bool, default=True
        If True, build a spatial index (KD-tree) for faster cell finding.
        Recommended for large grids (>100×100). Adds ~0.1s overhead.

    Returns
    -------
    data_regular : np.ndarray
        2D array of interpolated values on regular grid.

    Notes
    -----
    Algorithm (follows MATLAB Chip_regrid2.m):
    1. For each point P in regular grid:
       a. Search for containing quadrilateral in irregular grid
       b. Perform bilinear interpolation using 4 corner values
    2. Optimization: Use spatial locality (start search near last found cell)
    3. Points outside irregular grid are filled with fill_value

    Performance: O(n²) worst case, O(n²/k) typical with spatial locality.

    Optimizations applied:
    - Minimize array allocations in inner loop
    - Extract corner data once per cell
    - Use scalar operations where possible
    - Optional spatial index for O(log n) nearest neighbor queries
    """
    nrows_out, ncols_out = lon_regular.shape
    data_regular = np.full((nrows_out, ncols_out), fill_value)

    # Build spatial index if enabled (speeds up cell finding for large grids)
    kdtree = None
    cell_centers = None
    # Use total grid size rather than only row count to gate KD-tree construction.
    # The threshold of 2500 corresponds to a 50x50 grid.
    if use_spatial_index and lon_irregular.size > 2500:
        try:
            # scipy is an optional accelerator here; ImportError falls back
            # to the sequential windowed/full search below.
            from scipy.spatial import cKDTree

            # Compute cell centers for spatial index using vectorized operations
            nrows_in, ncols_in = lon_irregular.shape
            if nrows_in > 1 and ncols_in > 1:
                # Four corners of each input cell for longitude
                lon00 = lon_irregular[:-1, :-1]
                lon01 = lon_irregular[:-1, 1:]
                lon10 = lon_irregular[1:, :-1]
                lon11 = lon_irregular[1:, 1:]
                # Four corners of each input cell for latitude
                lat00 = lat_irregular[:-1, :-1]
                lat01 = lat_irregular[:-1, 1:]
                lat10 = lat_irregular[1:, :-1]
                lat11 = lat_irregular[1:, 1:]

                # Cell center (approximate) for all cells at once
                center_lon = 0.25 * (lon00 + lon01 + lon10 + lon11)
                center_lat = 0.25 * (lat00 + lat01 + lat10 + lat11)

                # (n_cells, 2) array of [lon, lat] centers
                cell_centers = np.stack([center_lon.ravel(), center_lat.ravel()], axis=-1)

                # Corresponding (row, col) indices for each cell
                row_idx, col_idx = np.indices((nrows_in - 1, ncols_in - 1))
                cell_indices_map = list(zip(row_idx.ravel(), col_idx.ravel()))

                kdtree = cKDTree(cell_centers)
                logger.debug(f"Built spatial index with {len(cell_indices_map)} cells")
            else:
                # Grid too small to form any cells
                kdtree = None
        except ImportError:
            logger.debug("scipy.spatial.cKDTree not available, using sequential search")
            kdtree = None

    # Track last found cell for optimization (spatial locality)
    last_cell = None
    first_cell_of_row = None

    logger.info(f"Regridding {nrows_out}×{ncols_out} points using {method} interpolation...")

    # Iterate through regular grid points
    for ii in range(nrows_out):
        # Progress logging every 50 rows.
        if ii % 50 == 0:
            logger.debug(f" Processing row {ii + 1}/{nrows_out}")

        # Reset search hint at start of each row (follow MATLAB pattern)
        if first_cell_of_row is not None:
            last_cell = first_cell_of_row

        for jj in range(ncols_out):
            point_lon = lon_regular[ii, jj]
            point_lat = lat_regular[ii, jj]
            point = np.array([point_lon, point_lat])

            cell = None

            # Strategy:
            # 1. If we have a hint from previous point, use windowed search
            # 2. If windowed search fails or no hint, use spatial index
            # 3. If no spatial index, do full search (slow, but correct)

            if last_cell is not None:
                # Try windowed search around last cell
                cell = find_containing_cell(point, lon_irregular, lat_irregular, last_cell)

            # If windowed search failed or no hint, use spatial index
            nearest_fallback_cell = None  # track for boundary-snap fallback
            if cell is None and kdtree is not None:
                # Query k nearest neighbors to find containing cell
                distances, indices = kdtree.query(point, k=min(9, len(cell_centers)))

                for rank, idx in enumerate(indices):
                    candidate_i, candidate_j = cell_indices_map[idx]
                    # Direct check if point is in this candidate cell (fast)
                    if _check_point_in_cell(
                        point_lon, point_lat, lon_irregular, lat_irregular, candidate_i, candidate_j
                    ):
                        cell = (candidate_i, candidate_j)
                        break
                    # Remember the geometrically nearest cell in case all strict
                    # tests fail (boundary-snap fallback, see below).
                    if rank == 0:
                        nearest_fallback_cell = (candidate_i, candidate_j)

            # Last resort: full search (only if no spatial index available)
            if cell is None and kdtree is None:
                cell = find_containing_cell(point, lon_irregular, lat_irregular, None)

            if cell is None:
                # Boundary-snap fallback: the point may be marginally outside the
                # input grid due to floating-point differences between the output
                # bounds (e.g. from MATLAB) and Python's ECEF-to-geodetic results.
                # Use the nearest cell (slight extrapolation) rather than NaN so
                # that sub-pixel boundary mismatches don't create artefacts.
                if nearest_fallback_cell is not None:
                    cell = nearest_fallback_cell
                    logger.debug(
                        f"Boundary snap at [{ii},{jj}] (lon={point_lon:.7f}, lat={point_lat:.7f}) → cell {cell}"
                    )
                else:
                    # Genuinely outside the grid – leave as fill_value
                    continue

            i, j = cell

            # Get corner coordinates and values (extract once)
            # Clockwise from top-left: TL, TR, BR, BL
            lon_tl, lat_tl = lon_irregular[i, j], lat_irregular[i, j]
            lon_tr, lat_tr = lon_irregular[i, j + 1], lat_irregular[i, j + 1]
            lon_br, lat_br = lon_irregular[i + 1, j + 1], lat_irregular[i + 1, j + 1]
            lon_bl, lat_bl = lon_irregular[i + 1, j], lat_irregular[i + 1, j]

            val_tl = data_irregular[i, j]
            val_tr = data_irregular[i, j + 1]
            val_br = data_irregular[i + 1, j + 1]
            val_bl = data_irregular[i + 1, j]

            # Interpolate
            if method not in ("bilinear", "nearest"):
                raise ValueError(f"Unsupported interpolation method: {method!r}")
            elif method == "bilinear":
                # Inline bilinear interpolation (avoid function call overhead)
                # Build interpolation matrix (4x4) and solve
                M = np.array(
                    [
                        [1.0, 1.0, 1.0, 1.0],
                        [lon_tl, lon_tr, lon_br, lon_bl],
                        [lat_tl, lat_tr, lat_br, lat_bl],
                        [lon_tl * lat_tl, lon_tr * lat_tr, lon_br * lat_br, lon_bl * lat_bl],
                    ]
                )

                E = np.array([1.0, point_lon, point_lat, point_lon * point_lat])

                try:
                    weights = np.linalg.solve(M, E)
                    data_regular[ii, jj] = (
                        weights[0] * val_tl + weights[1] * val_tr + weights[2] * val_br + weights[3] * val_bl
                    )
                except np.linalg.LinAlgError:
                    # Singular matrix - use simple average
                    data_regular[ii, jj] = 0.25 * (val_tl + val_tr + val_br + val_bl)

            elif method == "nearest":
                # Use nearest corner (avoid extra array allocations)
                dist_tl = (lon_tl - point_lon) ** 2 + (lat_tl - point_lat) ** 2
                dist_tr = (lon_tr - point_lon) ** 2 + (lat_tr - point_lat) ** 2
                dist_br = (lon_br - point_lon) ** 2 + (lat_br - point_lat) ** 2
                dist_bl = (lon_bl - point_lon) ** 2 + (lat_bl - point_lat) ** 2

                min_dist = min(dist_tl, dist_tr, dist_br, dist_bl)
                if min_dist == dist_tl:
                    data_regular[ii, jj] = val_tl
                elif min_dist == dist_tr:
                    data_regular[ii, jj] = val_tr
                elif min_dist == dist_br:
                    data_regular[ii, jj] = val_br
                else:
                    data_regular[ii, jj] = val_bl

            # Update search hint
            last_cell = cell
            if jj == 0:
                first_cell_of_row = cell

    logger.info("Regridding complete")

    return data_regular


def regrid_gcp_chip(
    band_data: np.ndarray,
    ecef_coords: tuple[np.ndarray, np.ndarray, np.ndarray],
    config: RegridConfig,
    output_file: str | None = None,
    output_metadata: dict[str, str] | None = None,
) -> ImageGrid:
    """
    High-level function: Regrid GCP chip from ECEF to regular lat/lon grid.

    This is the main entry point for GCP chip regridding. It handles the complete
    workflow from ECEF coordinates to a regular geodetic grid.

    Parameters
    ----------
    band_data : np.ndarray
        2D array of radiometric values.
    ecef_coords : tuple[np.ndarray, np.ndarray, np.ndarray]
        (X, Y, Z) ECEF coordinate arrays (meters), each shape (nrows, ncols).
    config : RegridConfig
        Regridding configuration.
    output_file : str, optional
        If provided, save the regridded chip to this NetCDF file path.
        File will be created with CF-compliant metadata.
    output_metadata : dict[str, str], optional
        Additional metadata to include in the NetCDF file (only used if output_file
        is specified). Common keys: 'source_file', 'mission', 'sensor', 'band'.

    Returns
    -------
    regridded_chip : ImageGrid
        Regridded data on regular lat/lon grid.

    Workflow
    --------
    1. Convert ECEF → geodetic (lon, lat, h) using curryer.compute.spatial
    2. Compute regular grid bounds (conservative or full extent)
    3. Create regular output grid (from resolution or size)
    4. Regrid data using bilinear interpolation
    5. Return ImageGrid with (data, lat, lon, h)
    6.
Optionally save to NetCDF file + + Examples + -------- + Basic usage (return in-memory): + + >>> from curryer.correction.image_io import load_gcp_chip_from_hdf + >>> from curryer.correction.regrid import regrid_gcp_chip, RegridConfig + >>> band, x, y, z = load_gcp_chip_from_hdf("chip.hdf") + >>> config = RegridConfig(output_resolution_deg=(0.001, 0.001)) + >>> regridded = regrid_gcp_chip(band, (x, y, z), config) + + With NetCDF output: + + >>> regridded = regrid_gcp_chip( + ... band, (x, y, z), config, + ... output_file="regridded_chip.nc", + ... output_metadata={ + ... 'source_file': 'LT08CHP.20140803.p002r071.c01.v001.hdf', + ... 'mission': 'CLARREO Pathfinder', + ... 'band': 'red', + ... } + ... ) + """ + ecef_x, ecef_y, ecef_z = ecef_coords + + # Validate input shapes + if not (band_data.shape == ecef_x.shape == ecef_y.shape == ecef_z.shape): + raise ValueError( + f"Shape mismatch: band={band_data.shape}, x={ecef_x.shape}, y={ecef_y.shape}, z={ecef_z.shape}" + ) + + logger.info(f"Regridding GCP chip: input shape {band_data.shape}") + + # Step 1: Convert ECEF → geodetic + logger.debug("Converting ECEF → geodetic coordinates...") + nrows, ncols = band_data.shape + + # Flatten for vectorized conversion + ecef_flat = np.stack([ecef_x.ravel(), ecef_y.ravel(), ecef_z.ravel()], axis=1) + + # Convert using Curryer's spatial module (vectorized, uses WGS84) + lla_flat = ecef_to_geodetic(ecef_flat, meters=True, degrees=True) + + # Reshape back to 2D grids + lon_irregular = lla_flat[:, 0].reshape(nrows, ncols) + lat_irregular = lla_flat[:, 1].reshape(nrows, ncols) + + logger.debug( + f"Geodetic range: lon=[{lon_irregular.min():.4f}, {lon_irregular.max():.4f}], " + f"lat=[{lat_irregular.min():.4f}, {lat_irregular.max():.4f}]" + ) + + # Step 2: Determine output grid bounds + if config.output_bounds is not None: + bounds = config.output_bounds + logger.debug(f"Using explicit bounds: {bounds}") + else: + bounds = compute_regular_grid_bounds(lon_irregular, lat_irregular, 
config.conservative_bounds) + logger.debug(f"Computed bounds: {bounds}") + + # Step 3: Create regular grid + if config.output_resolution_deg is not None: + lon_regular, lat_regular = create_regular_grid(bounds, resolution=config.output_resolution_deg) + elif config.output_grid_size is not None: + lon_regular, lat_regular = create_regular_grid(bounds, grid_size=config.output_grid_size) + else: + # Auto mode: use input size + logger.debug("Auto mode: using input grid size") + lon_regular, lat_regular = create_regular_grid(bounds, grid_size=band_data.shape) + + # Step 4: Regrid data + data_regular = regrid_irregular_to_regular( + band_data, + lon_irregular, + lat_irregular, + lon_regular, + lat_regular, + method=config.interpolation_method, + fill_value=config.fill_value, + ) + + # Step 5: Create ImageGrid + # Note: h is not regridded (would need separate interpolation), set to None + regridded_chip = ImageGrid( + data=data_regular, + lat=lat_regular, + lon=lon_regular, + h=None, # Height not interpolated + ) + + logger.info(f"Regridding complete: output shape {data_regular.shape}") + + # Step 6: Optionally save to NetCDF + if output_file is not None: + from .image_io import save_image_grid_to_netcdf + + save_image_grid_to_netcdf(output_file, regridded_chip, metadata=output_metadata) + + return regridded_chip diff --git a/curryer/correction/results.py b/curryer/correction/results.py new file mode 100644 index 00000000..33e3f14e --- /dev/null +++ b/curryer/correction/results.py @@ -0,0 +1,377 @@ +"""Structured result models for the correction pipeline. + +Provides :class:`CorrectionResult` and :class:`ParameterSetResult`, +returned by :func:`~curryer.correction.pipeline.run_correction`. 
+""" + +from __future__ import annotations + +import math +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np +from pydantic import BaseModel, ConfigDict, Field + +if TYPE_CHECKING: + from curryer.correction.config import CorrectionConfig + + +class ParameterSetResult(BaseModel): + """Results for a single parameter set in a correction sweep. + + Attributes + ---------- + index : int + Zero-based index of this parameter set in the sweep. + parameter_values : dict[str, float] + Parameter name → sampled value for this set. + mean_rms_m : float + Mean RMS geolocation error across all GCP pairs (metres). + best_pair_rms_m : float + RMS error of the best-performing GCP pair (metres). + worst_pair_rms_m : float + RMS error of the worst-performing GCP pair (metres). + """ + + index: int + parameter_values: dict[str, float] + mean_rms_m: float + best_pair_rms_m: float + worst_pair_rms_m: float + + +class CorrectionResult(BaseModel): + """Structured result from :func:`~curryer.correction.pipeline.run_correction`. + + Provides the instrument engineer with a clear answer: what is the best + parameter set, does it meet requirements, and what should they do next. + + Serialisation + ------------- + Most fields are JSON-serialisable via ``model_dump()`` / ``model_dump_json()``. + The ``results`` and ``netcdf_data`` fields are excluded from serialisation + by default because they contain non-JSON-serialisable types (xr.Dataset, + numpy arrays). Access them directly on the object when raw data is needed:: + + result = run_correction(config, work_dir, inputs) + result.best_parameter_set # dict[str, float] — use this + result.results # raw per-iteration dicts — advanced use only + result.netcdf_data # raw numpy arrays — advanced use only + + Attributes + ---------- + best_parameter_set : dict[str, float] + Parameter values that produced the lowest aggregate RMS. 
+ best_rms_m : float + Best aggregate RMS achieved across all parameter sets (metres). + best_index : int + Index of the best parameter set (for cross-referencing with NetCDF output). + worst_rms_m : float + Worst aggregate RMS across all parameter sets (metres). + mean_rms_m : float + Mean of all aggregate RMS values (metres). + n_parameter_sets : int + Number of parameter sets tested in the sweep. + n_gcp_pairs : int + Number of GCP pairs used. + all_parameter_sets : list[ParameterSetResult] + All tested parameter sets sorted by mean RMS (ascending). + met_threshold : bool + Whether the best parameter set met the mission performance requirements. + recommendation : str + Human-readable next-step guidance for the instrument engineer. + summary_table : str + Human-readable ASCII table of the top results. + netcdf_path : Path or None + Path to the saved NetCDF output file. + config_snapshot : dict + Key config values used, for reproducibility records. + elapsed_time_s : float + Total wall-clock processing time in seconds. + timestamp : datetime + UTC time when the run completed. + results : list + Raw per-iteration result dicts from :func:`~curryer.correction.pipeline.loop`. + Excluded from JSON serialisation. + netcdf_data : dict + Raw NetCDF numpy arrays from :func:`~curryer.correction.pipeline.loop`. + Excluded from JSON serialisation. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + best_parameter_set: dict[str, float] + best_rms_m: float + best_index: int + worst_rms_m: float + mean_rms_m: float + n_parameter_sets: int + n_gcp_pairs: int + all_parameter_sets: list[ParameterSetResult] + met_threshold: bool + recommendation: str + summary_table: str + netcdf_path: Path | None = None + config_snapshot: dict = Field(default_factory=dict) + elapsed_time_s: float = 0.0 + timestamp: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) + + # Raw data — accessible on the result object, excluded from JSON serialisation. 
+    # exclude=True keeps these out of model_dump()/model_dump_json(); read them
+    # directly off the instance when the raw loop outputs are needed.
+    results: list = Field(default_factory=list, exclude=True)
+    netcdf_data: dict = Field(default_factory=dict, exclude=True)
+
+
+# ============================================================================
+# Internal table formatting helpers
+# ============================================================================
+
+
+def _fmt_rms(value: float) -> str:
+    """Format an RMS value for display; returns ``'N/A'`` for non-finite values."""
+    # math.isfinite is False for NaN and ±inf, so failed/missing sweeps render as "N/A".
+    return f"{value:.1f}m" if math.isfinite(value) else "N/A"
+
+
+def _format_correction_summary_table(
+    top_sets: list[ParameterSetResult],
+    total_sets: int,
+    n_gcp_pairs: int,
+    met_threshold: bool,
+) -> str:
+    """Generate a human-readable correction sweep summary table.
+
+    Uses the same box-drawing and ljust/rjust pattern as
+    :func:`~curryer.correction.verification._format_summary_table`.
+
+    Parameters
+    ----------
+    top_sets : list[ParameterSetResult]
+        Top-ranked parameter sets to display (typically ``all_parameter_sets[:10]``).
+    total_sets : int
+        Total number of parameter sets evaluated in the sweep.
+    n_gcp_pairs : int
+        Number of GCP pairs used.
+    met_threshold : bool
+        Whether any parameter set met performance requirements.
+
+    Returns
+    -------
+    str
+        Multi-line box-drawn ASCII table.
+ """ + # Column widths + w_rank = 6 + w_mean = 14 + w_best = 14 + w_worst = 14 + w_idx = 8 + w_mark = 4 + # col_inner = sum of column widths + number of separators (one between each column) + col_inner = w_rank + w_mean + w_best + w_worst + w_idx + w_mark + 5 + + title = f" Correction Sweep Summary ({total_sets} sets × {n_gcp_pairs} pairs)" + + if top_sets: + best_line = f" Best: Set #{top_sets[0].index} — {_fmt_rms(top_sets[0].mean_rms_m)} mean RMS" + else: + best_line = " No results available" + + verdict_text = "MET REQUIREMENTS ✓" if met_threshold else "DID NOT MEET REQUIREMENTS ✗" + verdict_line = f" {verdict_text}" + + inner_width = max(col_inner, len(title) + 2, len(best_line) + 2, len(verdict_line) + 2) + + def _h_sep(mid: str, fill: str = "─") -> str: + """Build a column separator row padded to inner_width.""" + core = ( + fill * w_rank + + mid + + fill * w_mean + + mid + + fill * w_best + + mid + + fill * w_worst + + mid + + fill * w_idx + + mid + + fill * w_mark + ) + return core + fill * max(0, inner_width - len(core)) + + lines: list[str] = [] + lines.append("┌" + "─" * inner_width + "┐") + lines.append("│" + title.ljust(inner_width) + "│") + lines.append("├" + _h_sep("┬") + "┤") + + header = ( + "Rank".center(w_rank) + + "│" + + "Mean RMS".center(w_mean) + + "│" + + "Best Pair".center(w_best) + + "│" + + "Worst Pair".center(w_worst) + + "│" + + "Index".center(w_idx) + + "│" + + "".center(w_mark) + ) + lines.append("│" + header.ljust(inner_width) + "│") + lines.append("├" + _h_sep("┼") + "┤") + + for rank, ps in enumerate(top_sets, 1): + marker = " ★" if rank == 1 else " " + row = ( + str(rank).center(w_rank) + + "│" + + _fmt_rms(ps.mean_rms_m).center(w_mean) + + "│" + + _fmt_rms(ps.best_pair_rms_m).center(w_best) + + "│" + + _fmt_rms(ps.worst_pair_rms_m).center(w_worst) + + "│" + + str(ps.index).center(w_idx) + + "│" + + marker.center(w_mark) + ) + lines.append("│" + row.ljust(inner_width) + "│") + + lines.append("├" + "─" * inner_width + "┤") + 
lines.append("│" + best_line.ljust(inner_width) + "│") + lines.append("│" + verdict_line.ljust(inner_width) + "│") + lines.append("└" + "─" * inner_width + "┘") + + return "\n".join(lines) + + +# ============================================================================ +# Public factory +# ============================================================================ + + +def build_correction_result( + config: CorrectionConfig, + results: list, + netcdf_data: dict, + netcdf_path: Path | None, + elapsed_time_s: float, +) -> CorrectionResult: + """Build a :class:`CorrectionResult` from raw :func:`~curryer.correction.pipeline.loop` outputs. + + Parameters + ---------- + config : CorrectionConfig + The correction configuration used for the run. + results : list + Per-iteration result dicts from :func:`~curryer.correction.pipeline.loop`. + netcdf_data : dict + Raw NetCDF data dict from :func:`~curryer.correction.pipeline.loop`. + netcdf_path : Path or None + Path to the saved NetCDF file, if any. + elapsed_time_s : float + Total wall-clock time of the run in seconds. 
+ + Returns + ------- + CorrectionResult + """ + all_mean_rms: np.ndarray = netcdf_data.get("mean_rms_all_pairs", np.array([])) + rms_grid: np.ndarray = netcdf_data.get("rms_error_m", np.empty((0, 0))) + best_pair_rms_arr: np.ndarray = netcdf_data.get("best_pair_rms", np.array([])) + worst_pair_rms_arr: np.ndarray = netcdf_data.get("worst_pair_rms", np.array([])) + + n_param_sets = len(all_mean_rms) + n_gcp_pairs = int(rms_grid.shape[1]) if rms_grid.ndim == 2 and rms_grid.shape[0] > 0 else 0 + + # Best / worst parameter sets + valid_mask = ~np.isnan(all_mean_rms) if n_param_sets > 0 else np.array([], dtype=bool) + if np.any(valid_mask): + best_idx = int(np.nanargmin(all_mean_rms)) + best_rms = float(all_mean_rms[best_idx]) + worst_rms = float(all_mean_rms[int(np.nanargmax(all_mean_rms))]) + mean_rms = float(np.nanmean(all_mean_rms)) + else: + best_idx = 0 + best_rms = float("nan") + worst_rms = float("nan") + mean_rms = float("nan") + + # Extract 1-D parameter arrays (param_* keys) + param_keys = [ + k for k, v in netcdf_data.items() if k.startswith("param_") and isinstance(v, np.ndarray) and v.ndim == 1 + ] + + # Build per-set results + all_sets: list[ParameterSetResult] = [] + for idx in range(n_param_sets): + pvals = {k: float(netcdf_data[k][idx]) for k in param_keys} + all_sets.append( + ParameterSetResult( + index=idx, + parameter_values=pvals, + mean_rms_m=float(all_mean_rms[idx]), + best_pair_rms_m=float(best_pair_rms_arr[idx]) if idx < len(best_pair_rms_arr) else float("nan"), + worst_pair_rms_m=float(worst_pair_rms_arr[idx]) if idx < len(worst_pair_rms_arr) else float("nan"), + ) + ) + all_sets.sort(key=lambda s: (float("inf") if math.isnan(s.mean_rms_m) else s.mean_rms_m)) + + best_params: dict[str, float] = ( + {k: float(netcdf_data[k][best_idx]) for k in param_keys} if param_keys and n_param_sets > 0 else {} + ) + + # Evaluate requirements using legacy performance_threshold_m / performance_spec_percent. 
+ # (The new-style Requirement.evaluate_all() path is tracked by TODO(#151).) + met_threshold = False + if n_gcp_pairs > 0 and math.isfinite(best_rms) and rms_grid.shape[0] > best_idx: + pair_errors = [float(rms_grid[best_idx, pi]) for pi in range(n_gcp_pairs)] + valid_errors = [e for e in pair_errors if math.isfinite(e)] + if valid_errors: + pct_below = sum(1 for e in valid_errors if e < config.performance_threshold_m) / len(valid_errors) * 100 + met_threshold = pct_below >= config.performance_spec_percent + + # Human-readable recommendation + if met_threshold: + recommendation = ( + f"Best parameter set (#{best_idx}) achieved {best_rms:.1f}m mean RMS " + f"and meets performance requirements. " + f"Update kernel files with these values." + ) + else: + rms_str = f"{best_rms:.1f}m" if math.isfinite(best_rms) else "N/A" + recommendation = ( + f"No parameter set met performance requirements. " + f"Best achieved: {rms_str} mean RMS (set #{best_idx}). " + f"Consider widening parameter bounds or increasing iterations." 
+ ) + + summary_table = _format_correction_summary_table(all_sets[:10], n_param_sets, n_gcp_pairs, met_threshold) + + search_strategy = config.search_strategy + config_snapshot = { + "seed": config.seed, + "n_iterations": config.n_iterations, + "search_strategy": search_strategy.value if hasattr(search_strategy, "value") else str(search_strategy), + "performance_threshold_m": config.performance_threshold_m, + "performance_spec_percent": config.performance_spec_percent, + } + + return CorrectionResult( + best_parameter_set=best_params, + best_rms_m=best_rms, + best_index=best_idx, + worst_rms_m=worst_rms, + mean_rms_m=mean_rms, + n_parameter_sets=n_param_sets, + n_gcp_pairs=n_gcp_pairs, + all_parameter_sets=all_sets, + met_threshold=met_threshold, + recommendation=recommendation, + summary_table=summary_table, + netcdf_path=netcdf_path, + config_snapshot=config_snapshot, + elapsed_time_s=elapsed_time_s, + results=results, + netcdf_data=netcdf_data, + ) diff --git a/curryer/correction/results_io.py b/curryer/correction/results_io.py new file mode 100644 index 00000000..976465cb --- /dev/null +++ b/curryer/correction/results_io.py @@ -0,0 +1,370 @@ +"""NetCDF I/O helpers for the correction pipeline. + +This module owns all read/write operations on the NetCDF result file: + +- :func:`_build_netcdf_structure` -- initialise the in-memory data dict. +- :func:`_save_netcdf_results` -- write the final NetCDF output file. +- :func:`_save_netcdf_checkpoint` -- write a partial checkpoint after each GCP pair. +- :func:`_load_checkpoint` -- reload a checkpoint to resume an interrupted run. +- :func:`_cleanup_checkpoint` -- delete the checkpoint file after a successful run. 
+""" + +import logging + +import numpy as np +import pandas as pd +import xarray as xr + +from curryer.correction.config import CorrectionConfig, ParameterType + +logger = logging.getLogger(__name__) + + +def _build_netcdf_structure(config: CorrectionConfig, n_param_sets: int, n_gcp_pairs: int) -> dict: + """ + Build NetCDF data structure dynamically from configuration. + + This creates the netcdf_data dictionary with proper variable names based on + the parameters defined in the configuration, avoiding hardcoded mission-specific names. + + Args: + config: CorrectionConfig with parameters and optional NetCDF config + n_param_sets: Number of parameter sets (iterations) + n_gcp_pairs: Number of GCP pairs + + Returns: + Dictionary with initialized arrays for all NetCDF variables + """ + logger.info(f"Building NetCDF data structure for {n_param_sets} parameter sets × {n_gcp_pairs} GCP pairs") + + # Ensure NetCDFConfig exists + config.ensure_netcdf_config() + + # Start with coordinate dimensions + netcdf_data = { + "parameter_set_id": np.arange(n_param_sets), + "gcp_pair_id": np.arange(n_gcp_pairs), + } + + # Add parameter variables dynamically based on config.parameters + param_count = 0 + for param in config.parameters: + if param.ptype == ParameterType.CONSTANT_KERNEL: + # CONSTANT_KERNEL parameters have roll, pitch, yaw components + for angle in ["roll", "pitch", "yaw"]: + metadata = config.netcdf.get_parameter_netcdf_metadata(param, angle) + var_name = metadata.variable_name + netcdf_data[var_name] = np.full(n_param_sets, np.nan) + logger.debug(f" Added parameter variable: {var_name} ({metadata.long_name})") + param_count += 1 + else: + # OFFSET_KERNEL and OFFSET_TIME are single values + metadata = config.netcdf.get_parameter_netcdf_metadata(param) + var_name = metadata.variable_name + netcdf_data[var_name] = np.full(n_param_sets, np.nan) + logger.debug(f" Added parameter variable: {var_name} ({metadata.long_name})") + param_count += 1 + + logger.info(f" Created 
{param_count} parameter variables from {len(config.parameters)} parameter configs") + + # Add standard error statistics (2D: parameter_set_id × gcp_pair_id) + error_metrics = { + "rms_error_m": "RMS geolocation error", + "mean_error_m": "Mean geolocation error", + "max_error_m": "Maximum geolocation error", + "std_error_m": "Standard deviation of geolocation error", + "n_measurements": "Number of measurement points", + } + + for var_name, description in error_metrics.items(): + if var_name == "n_measurements": + netcdf_data[var_name] = np.full((n_param_sets, n_gcp_pairs), 0, dtype=int) + else: + netcdf_data[var_name] = np.full((n_param_sets, n_gcp_pairs), np.nan) + logger.debug(f" Added error metric: {var_name}") + + # Add image matching results (2D: parameter_set_id × gcp_pair_id) + image_match_vars = { + "im_lat_error_km": "Image matching latitude error", + "im_lon_error_km": "Image matching longitude error", + "im_ccv": "Image matching correlation coefficient", + "im_grid_step_m": "Image matching final grid step size", + } + + for var_name, description in image_match_vars.items(): + netcdf_data[var_name] = np.full((n_param_sets, n_gcp_pairs), np.nan) + logger.debug(f" Added image matching variable: {var_name}") + + # Add overall performance metrics (1D: parameter_set_id) + # Use dynamic threshold metric name + threshold_metric = config.netcdf.get_threshold_metric_name() + overall_metrics = { + threshold_metric: f"Percentage of pairs with error < {config.performance_threshold_m}m", + "mean_rms_all_pairs": "Mean RMS error across all GCP pairs", + "worst_pair_rms": "Worst performing GCP pair RMS error", + "best_pair_rms": "Best performing GCP pair RMS error", + } + + for var_name, description in overall_metrics.items(): + netcdf_data[var_name] = np.full(n_param_sets, np.nan) + logger.debug(f" Added overall metric: {var_name}") + + logger.info(f"NetCDF data structure created with {len(netcdf_data)} variables") + + return netcdf_data + + +def 
_save_netcdf_checkpoint(netcdf_data, output_file, config, pair_idx_completed): + """ + Save NetCDF checkpoint with partial results after each GCP pair completes. + + This enables resuming correction runs if they are interrupted. + Adapted for pair-outer loop order where each pair processes all parameters. + + Args: + netcdf_data: Dictionary with current NetCDF data + output_file: Path to final output file (checkpoint uses .checkpoint.nc suffix) + config: CorrectionConfig with metadata + pair_idx_completed: Index of the last completed GCP pair (for pair-outer loop) + """ + + checkpoint_file = output_file.parent / f"{output_file.stem}_checkpoint.nc" + + # Ensure NetCDFConfig exists + config.ensure_netcdf_config() + + # Create coordinate arrays + coords = { + "parameter_set_id": netcdf_data["parameter_set_id"], + "gcp_pair_id": netcdf_data["gcp_pair_id"], + } + + # Build variable list dynamically from netcdf_data keys + data_vars = {} + for var_name, var_data in netcdf_data.items(): + if var_name not in coords: + if isinstance(var_data, np.ndarray): + if var_data.ndim == 1: + data_vars[var_name] = (["parameter_set_id"], var_data) + elif var_data.ndim == 2: + data_vars[var_name] = (["parameter_set_id", "gcp_pair_id"], var_data) + + # Create dataset + ds = xr.Dataset(data_vars, coords=coords) + + # Add regular metadata + ds.attrs.update( + { + "title": config.netcdf.title, + "description": config.netcdf.description, + "created": pd.Timestamp.now().isoformat(), + "correction_iterations": config.n_iterations, + "performance_threshold_m": config.netcdf.performance_threshold_m, + "parameter_count": len(config.parameters), + "random_seed": str(config.seed) if config.seed is not None else "None", + } + ) + + # Add checkpoint-specific metadata (NetCDF-compatible types) + ds.attrs["checkpoint"] = 1 # Use integer instead of boolean for NetCDF compatibility + ds.attrs["completed_gcp_pairs"] = int(pair_idx_completed + 1) + ds.attrs["total_gcp_pairs"] = 
int(len(netcdf_data["gcp_pair_id"])) + ds.attrs["checkpoint_timestamp"] = pd.Timestamp.now().isoformat() + + # Add parameter variable attributes from config + for param in config.parameters: + if param.ptype == ParameterType.CONSTANT_KERNEL: + for angle in ["roll", "pitch", "yaw"]: + metadata = config.netcdf.get_parameter_netcdf_metadata(param, angle) + if metadata.variable_name in ds.data_vars: + ds[metadata.variable_name].attrs.update({"units": metadata.units, "long_name": metadata.long_name}) + else: + metadata = config.netcdf.get_parameter_netcdf_metadata(param) + if metadata.variable_name in ds.data_vars: + ds[metadata.variable_name].attrs.update({"units": metadata.units, "long_name": metadata.long_name}) + + # Add standard metric attributes + standard_attrs = config.netcdf.get_standard_attributes() + threshold_metric = config.netcdf.get_threshold_metric_name() + standard_attrs[threshold_metric] = { + "units": "percent", + "long_name": f"Percentage of pairs with error < {config.performance_threshold_m}m", + } + for var, attrs in standard_attrs.items(): + if var in ds.data_vars: + ds[var].attrs.update(attrs) + + # Save to file in one operation + checkpoint_file.parent.mkdir(parents=True, exist_ok=True) + ds.to_netcdf(checkpoint_file, mode="w") # Force overwrite mode + ds.close() + + logger.info(f" Checkpoint saved: {pair_idx_completed + 1}/{len(netcdf_data['gcp_pair_id'])} GCP pairs complete") + + +def _load_checkpoint(output_file, config): + """ + Load checkpoint if it exists and convert back to netcdf_data dict. 
+ + Args: + output_file: Path to final output file (will check for .checkpoint.nc) + config: CorrectionConfig for structure information + + Returns: + Tuple of (netcdf_data dict, start_idx) or (None, 0) if no checkpoint + """ + + checkpoint_file = output_file.parent / f"{output_file.stem}_checkpoint.nc" + + if not checkpoint_file.exists(): + return None, 0 + + logger.info(f"Found checkpoint file: {checkpoint_file}") + + try: + ds = xr.open_dataset(checkpoint_file, decode_timedelta=False) + + # Verify this is actually a checkpoint (checkpoint attribute is 1 for true, 0 or missing for false) + checkpoint_flag = ds.attrs.get("checkpoint", 0) + if not checkpoint_flag: # Will be True if checkpoint=1, False if checkpoint=0 or missing + logger.warning("File exists but is not marked as checkpoint, ignoring") + ds.close() + return None, 0 + + completed = ds.attrs.get("completed_gcp_pairs", 0) + total = ds.attrs.get("total_gcp_pairs", 0) + timestamp = ds.attrs.get("checkpoint_timestamp", "unknown") + + logger.info(f"Checkpoint from {timestamp}: {completed}/{total} GCP pairs complete") + + # Convert xarray.Dataset back to netcdf_data dictionary + netcdf_data = {} + + # Add coordinates + netcdf_data["parameter_set_id"] = ds.coords["parameter_set_id"].values + netcdf_data["gcp_pair_id"] = ds.coords["gcp_pair_id"].values + + # Add all data variables + for var_name in ds.data_vars: + netcdf_data[var_name] = ds[var_name].values + + ds.close() + + logger.info(f"Checkpoint loaded successfully, resuming from GCP pair {completed}") + + return netcdf_data, completed + + except Exception as e: + logger.error(f"Failed to load checkpoint: {e}") + return None, 0 + + +def _cleanup_checkpoint(output_file): + """ + Remove checkpoint file after successful completion. 
+ + Args: + output_file: Path to final output file (will remove .checkpoint.nc) + """ + checkpoint_file = output_file.parent / f"{output_file.stem}_checkpoint.nc" + + if checkpoint_file.exists(): + try: + checkpoint_file.unlink() + logger.info(f"Checkpoint file cleaned up: {checkpoint_file}") + except Exception as e: + logger.warning(f"Failed to remove checkpoint file: {e}") + + +def _save_netcdf_results(netcdf_data, output_file, config): + """ + Save results to NetCDF file using config-driven metadata. + + This function dynamically builds the NetCDF file structure from the + netcdf_data dictionary, using configuration for all metadata rather + than hardcoding mission-specific values. + + Args: + netcdf_data: Dictionary with all NetCDF variables and data + output_file: Path to output NetCDF file + config: CorrectionConfig with NetCDF metadata + """ + + logger.info(f"Saving NetCDF results to: {output_file}") + + # Ensure NetCDFConfig exists + config.ensure_netcdf_config() + + # Create coordinate arrays + coords = { + "parameter_set_id": netcdf_data["parameter_set_id"], + "gcp_pair_id": netcdf_data["gcp_pair_id"], + } + + # Build variable list dynamically from netcdf_data keys + data_vars = {} + + # Add all non-coordinate variables, determining dimensions from array shape + for var_name, var_data in netcdf_data.items(): + if var_name not in coords: + # Determine dimensions from array shape + if isinstance(var_data, np.ndarray): + if var_data.ndim == 1: + data_vars[var_name] = (["parameter_set_id"], var_data) + elif var_data.ndim == 2: + data_vars[var_name] = (["parameter_set_id", "gcp_pair_id"], var_data) + + logger.info(f" Creating dataset with {len(data_vars)} data variables") + + # Create dataset + ds = xr.Dataset(data_vars, coords=coords) + + # Add global metadata from config + ds.attrs.update( + { + "title": config.netcdf.title, + "description": config.netcdf.description, + "created": pd.Timestamp.now().isoformat(), + "correction_iterations": 
config.n_iterations, + "performance_threshold_m": config.netcdf.performance_threshold_m, + "parameter_count": len(config.parameters), + "random_seed": str(config.seed) if config.seed is not None else "None", + } + ) + + # Add parameter variable attributes from config + for param in config.parameters: + if param.ptype == ParameterType.CONSTANT_KERNEL: + # Add metadata for roll, pitch, yaw components + for angle in ["roll", "pitch", "yaw"]: + metadata = config.netcdf.get_parameter_netcdf_metadata(param, angle) + if metadata.variable_name in ds.data_vars: + ds[metadata.variable_name].attrs.update({"units": metadata.units, "long_name": metadata.long_name}) + else: + # Add metadata for single-value parameters + metadata = config.netcdf.get_parameter_netcdf_metadata(param) + if metadata.variable_name in ds.data_vars: + ds[metadata.variable_name].attrs.update({"units": metadata.units, "long_name": metadata.long_name}) + + # Add standard metric attributes from config (allows mission overrides) + standard_attrs = config.netcdf.get_standard_attributes() + + # Add dynamic threshold metric + threshold_metric = config.netcdf.get_threshold_metric_name() + standard_attrs[threshold_metric] = { + "units": "percent", + "long_name": f"Percentage of pairs with error < {config.performance_threshold_m}m", + } + + for var, attrs in standard_attrs.items(): + if var in ds.data_vars: + ds[var].attrs.update(attrs) + + # Save to file + output_file.parent.mkdir(parents=True, exist_ok=True) + ds.to_netcdf(output_file) + + logger.info(f" NetCDF file saved successfully") + logger.info(f" Dimensions: {dict(ds.sizes)}") + logger.info(f" Data variables: {len(list(ds.data_vars.keys()))}") + logger.info(f" File: {output_file}") diff --git a/curryer/correction/verification.py b/curryer/correction/verification.py new file mode 100644 index 00000000..2d6c676f --- /dev/null +++ b/curryer/correction/verification.py @@ -0,0 +1,813 @@ +"""Verification module for geolocation requirements compliance. 
+
+Provides :func:`verify`, a standalone entry point that evaluates the current
+set of SPICE kernels and alignment parameters against mission geolocation
+requirements — without running the iterative correction loop.
+
+Typical use-cases
+-----------------
+Weekly automated check (CLARREO)
+    Pass pre-computed ``image_matching_results`` (the most common path).
+    Note ``work_dir`` must be passed by keyword — the second positional
+    parameter of :func:`verify` is ``gcp_pairs``, which is not yet implemented:
+
+    >>> result = verify(config, image_matching_results=weekly_datasets, work_dir=work_dir)
+    >>> if not result.passed:
+    ...     send_alert(result.summary_table)
+
+Post-correction validation
+    After a full GCS run, verify the optimised parameter set:
+
+    >>> result = verify(config, image_matching_results=post_correction_datasets, work_dir=work_dir)
+
+One-off compliance check
+    Provide already-geolocated data and let verification run image matching:
+
+    >>> result = verify(config, geolocated_data=raw_dataset, work_dir=work_dir)
+
+Models
+------
+:class:`RequirementsConfig`
+    Verification thresholds (performance limit and pass-rate).
+:class:`GCPError`
+    Per-measurement/GCP error detail.
+:class:`VerificationResult`
+    Structured pass/fail result; serialisable via Pydantic JSON methods.
+
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+from datetime import datetime, timezone
+from pathlib import Path
+
+import numpy as np
+import xarray as xr
+from pydantic import BaseModel, ConfigDict, Field
+
+from curryer.correction.config import CorrectionConfig, RequirementsConfig
+from curryer.correction.error_stats import ErrorStatsConfig, ErrorStatsProcessor
+
+logger = logging.getLogger(__name__)
+
+# ============================================================================
+# Pydantic models
+# ============================================================================
+
+
+class GCPError(BaseModel):
+    """Per-measurement/GCP error detail.
+
+    Each instance corresponds to one row in the aggregated image-matching
+    output — typically one measurement from a single GCP pair.
+
+    Attributes
+    ----------
+    gcp_index : int
+        Zero-based measurement index in the aggregated dataset.
+    science_key : str
+        Identifier for the science data segment (dataset label or index).
+    gcp_key : str
+        Identifier for the ground-control-point source.
+    lat_error_deg : float
+        Latitude error in degrees (positive = northward shift).
+    lon_error_deg : float
+        Longitude error in degrees (positive = eastward shift).
+    nadir_equiv_error_m : float or None
+        Nadir-equivalent total geolocation error in metres, or ``None`` when
+        error-stats processing was not performed.
+    correlation : float or None
+        Image-matching correlation score, or ``None`` when not available.
+    passed : bool
+        Whether this measurement satisfies the per-measurement threshold.
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    gcp_index: int
+    science_key: str
+    gcp_key: str
+    lat_error_deg: float
+    lon_error_deg: float
+    nadir_equiv_error_m: float | None = None
+    correlation: float | None = None
+    passed: bool
+
+
+class VerificationResult(BaseModel):
+    """Structured result from a :func:`verify` call.
+
+    Most fields are JSON-serialisable via Pydantic's ``model_dump()`` /
+    ``model_dump_json()``.  The :attr:`aggregate_stats` field (an
+    ``xr.Dataset``) must be excluded when serialising to JSON — persist it
+    separately (e.g. via ``aggregate_stats.to_netcdf(path)``)::
+
+        json_str = result.model_dump_json(exclude={"aggregate_stats"})
+        result.aggregate_stats.to_netcdf("verification_stats.nc")
+
+    Attributes
+    ----------
+    passed : bool
+        ``True`` when :attr:`percent_within_threshold` ≥
+        :attr:`requirements.performance_spec_percent`.
+    per_gcp_errors : list[GCPError]
+        One entry per measurement in the aggregated dataset.
+    aggregate_stats : xr.Dataset
+        Full output from
+        :meth:`~curryer.correction.error_stats.ErrorStatsProcessor.process_geolocation_errors`.
+    requirements : RequirementsConfig
+        The thresholds used for this verification run.
+    summary_table : str
+        Human-readable ASCII table suitable for logging or reports.
+    percent_within_threshold : float
+        Percentage of measurements with nadir-equivalent error below
+        :attr:`requirements.performance_threshold_m`.
+    warnings : list[str]
+        Non-empty when :attr:`passed` is ``False``.
+    timestamp : datetime
+        UTC wall-clock time when :func:`verify` was called.
+    files_processed : list[str]
+        Science/GCP key pairs that were processed, as ``"{sci}+{gcp}"``
+        strings. Empty when the source mapping is unavailable.
+    elapsed_time_s : float or None
+        Wall-clock time for the verify call in seconds, or ``None`` when not
+        measured.
+    config_snapshot : dict or None
+        Key config fields used for this run (threshold, spec percent,
+        instrument name), for reproducibility records.
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    passed: bool
+    per_gcp_errors: list[GCPError]
+    aggregate_stats: xr.Dataset
+    requirements: RequirementsConfig
+    summary_table: str
+    percent_within_threshold: float
+    warnings: list[str]
+    timestamp: datetime
+
+    # Provenance fields — all optional with defaults so existing callers are unaffected.
+    files_processed: list[str] = Field(default_factory=list)
+    elapsed_time_s: float | None = None
+    config_snapshot: dict | None = None
+
+
+# ============================================================================
+# Internal helpers
+# ============================================================================
+
+
+def _build_requirements(config: CorrectionConfig) -> RequirementsConfig:
+    """Extract or construct :class:`RequirementsConfig` from *config*.
+
+    If *config* carries a ``verification`` attribute (a
+    :class:`RequirementsConfig` instance), that object is returned directly.
+    Otherwise the top-level :attr:`~CorrectionConfig.performance_threshold_m`
+    and :attr:`~CorrectionConfig.performance_spec_percent` fields are used.
+
+    Parameters
+    ----------
+    config : CorrectionConfig
+        Correction configuration from which to extract thresholds.
+
+    Returns
+    -------
+    RequirementsConfig
+    """
+    # TODO(#131): Add `verification: RequirementsConfig | None` as an optional
+    # field on CorrectionConfig so this override path works without __setattr__.
+    existing = getattr(config, "verification", None)
+    if isinstance(existing, RequirementsConfig):
+        logger.debug("Using RequirementsConfig from config.verification")
+        return existing
+    return RequirementsConfig(
+        performance_threshold_m=config.performance_threshold_m,
+        performance_spec_percent=config.performance_spec_percent,
+    )
+
+
+def _aggregate_results(
+    image_matching_results: list[xr.Dataset],
+    config: CorrectionConfig,
+) -> xr.Dataset:
+    """Aggregate a list of per-GCP-pair image-matching datasets into one.
+
+    For multiple input datasets, this delegates to the same aggregation logic
+    used by the correction pipeline
+    (:func:`~curryer.correction.pipeline._aggregate_image_matching_results`).
+    For a single input dataset, it is returned directly after ensuring that
+    the ``measurement`` coordinate is present and consists of sequential
+    integer indices.
+
+    Parameters
+    ----------
+    image_matching_results : list[xr.Dataset]
+        One element per GCP pair. Each dataset must have a ``measurement``
+        dimension and at minimum ``lat_error_deg`` / ``lon_error_deg`` variables.
+    config : CorrectionConfig
+        Used for mission-specific variable names
+        (``spacecraft_position_name``, ``boresight_name``,
+        ``transformation_matrix_name``).
+
+    Returns
+    -------
+    xr.Dataset
+        Combined dataset with a single ``measurement`` dimension.
+    """
+    from curryer.correction.pipeline import _aggregate_image_matching_results
+
+    if len(image_matching_results) == 1:
+        ds = image_matching_results[0]
+        # Always normalize the measurement coordinate to sequential integers
+        # so that downstream gcp_index values are predictable.
+        # NOTE(review): if the dataset lacks a "measurement" *dimension*
+        # entirely, assign_coords cannot create it — presumably inputs always
+        # carry the dimension; confirm against the image-matching output schema.
+        n = ds.sizes.get("measurement", len(ds["lat_error_deg"]))
+        ds = ds.assign_coords(measurement=np.arange(n))
+        return ds
+
+    return _aggregate_image_matching_results(image_matching_results, config)
+
+
+def _run_error_stats(
+    aggregated: xr.Dataset,
+    config: CorrectionConfig,
+) -> xr.Dataset:
+    """Run :class:`~curryer.correction.error_stats.ErrorStatsProcessor` on *aggregated*.
+
+    Parameters
+    ----------
+    aggregated : xr.Dataset
+        Combined image-matching result with a ``measurement`` dimension.
+    config : CorrectionConfig
+        Used to build :class:`~curryer.correction.error_stats.ErrorStatsConfig`.
+
+    Returns
+    -------
+    xr.Dataset
+        Processed dataset with ``nadir_equiv_total_error_m`` and related
+        intermediate variables.
+    """
+    error_config = ErrorStatsConfig.from_correction_config(config)
+    processor = ErrorStatsProcessor(config=error_config)
+    return processor.process_geolocation_errors(aggregated)
+
+
+def _check_threshold(
+    aggregate_stats: xr.Dataset,
+    requirements: RequirementsConfig,
+) -> tuple[bool, float]:
+    """Evaluate whether performance meets the threshold requirement.
+
+    Uses ``nadir_equiv_total_error_m`` from the
+    :class:`~curryer.correction.error_stats.ErrorStatsProcessor` output.
+
+    Parameters
+    ----------
+    aggregate_stats : xr.Dataset
+        Output of ``ErrorStatsProcessor.process_geolocation_errors()``.
+    requirements : RequirementsConfig
+        Performance limits to evaluate against.
+
+    Returns
+    -------
+    tuple[bool, float]
+        ``(passed, percent_within_threshold)``
+    """
+    nadir_errors = aggregate_stats["nadir_equiv_total_error_m"].values
+    if len(nadir_errors) == 0:
+        return False, 0.0
+    # NaN errors compare False against the threshold, so they count as failures
+    # rather than being silently dropped from the denominator.
+    count_below = int(np.sum(nadir_errors < requirements.performance_threshold_m))
+    percent_below = float(count_below / len(nadir_errors) * 100.0)
+    passed = percent_below >= requirements.performance_spec_percent
+    return passed, percent_below
+
+
+def _generate_warnings(
+    passed: bool,
+    percent_below: float,
+    requirements: RequirementsConfig,
+) -> list[str]:
+    """Generate user-facing warning messages when verification fails.
+
+    Parameters
+    ----------
+    passed : bool
+        Overall pass/fail result.
+    percent_below : float
+        Percentage of measurements within the threshold.
+    requirements : RequirementsConfig
+        Performance limits used for the check.
+
+    Returns
+    -------
+    list[str]
+        Empty list when *passed* is ``True``; otherwise one warning string.
+    """
+    if not passed:
+        return [
+            f"⚠️ VERIFICATION FAILED: Only {percent_below:.1f}% of observations "
+            f"meet the {requirements.performance_threshold_m}m nadir-equivalent error threshold "
+            f"(required: {requirements.performance_spec_percent}%). "
+            f"Recommend running the correction module to optimise calibration parameters."
+        ]
+    return []
+
+
+def _build_per_gcp_errors(
+    aggregate_stats: xr.Dataset,
+    source_mapping: list[tuple[str, str]],
+    requirements: RequirementsConfig,
+) -> list[GCPError]:
+    """Build a :class:`GCPError` for every measurement in *aggregate_stats*.
+
+    Parameters
+    ----------
+    aggregate_stats : xr.Dataset
+        Processed dataset from
+        :func:`~curryer.correction.error_stats.ErrorStatsProcessor.process_geolocation_errors`.
+    source_mapping : list[tuple[str, str]]
+        ``[(science_key, gcp_key), ...]`` parallel to the measurement dimension.
+        If shorter than the number of measurements the remainder fall back to
+        ``("sci_{i}", "gcp_{i}")``.
+    requirements : RequirementsConfig
+        Used for per-measurement pass/fail evaluation.
+
+    Returns
+    -------
+    list[GCPError]
+    """
+    n = aggregate_stats.sizes.get("measurement", 0)
+    if n == 0:
+        return []
+
+    nadir_errors = aggregate_stats["nadir_equiv_total_error_m"].values
+    lat_errors = aggregate_stats["lat_error_deg"].values
+    lon_errors = aggregate_stats["lon_error_deg"].values
+
+    # Optional correlation variable (several possible names)
+    correlation_values: np.ndarray | None = None
+    for corr_var in ("correlation", "ccv", "im_ccv"):
+        if corr_var in aggregate_stats.data_vars:
+            correlation_values = aggregate_stats[corr_var].values
+            break
+
+    errors: list[GCPError] = []
+    for i in range(n):
+        if i < len(source_mapping):
+            sci_key, gcp_key = source_mapping[i]
+        else:
+            sci_key, gcp_key = f"sci_{i}", f"gcp_{i}"
+
+        # Non-finite correlation scores are reported as None (not available).
+        corr: float | None = None
+        if correlation_values is not None:
+            raw = float(correlation_values[i])
+            if np.isfinite(raw):  # type: ignore[arg-type]
+                corr = raw
+
+        errors.append(
+            GCPError(
+                gcp_index=i,
+                science_key=sci_key,
+                gcp_key=gcp_key,
+                lat_error_deg=float(lat_errors[i]),
+                lon_error_deg=float(lon_errors[i]),
+                nadir_equiv_error_m=float(nadir_errors[i]),
+                correlation=corr,
+                passed=float(nadir_errors[i]) < requirements.performance_threshold_m,
+            )
+        )
+    return errors
+
+
+def _build_source_mapping(
+    image_matching_results: list[xr.Dataset],
+) -> list[tuple[str, str]]:
+    """Map every measurement back to its (science_key, gcp_key) pair.
+
+    The mapping is derived from dataset attributes (``sci_key`` / ``gcp_key``)
+    when present, otherwise falls back to ``"result_{i}"`` labels.
+
+    Parameters
+    ----------
+    image_matching_results : list[xr.Dataset]
+        Raw per-GCP-pair datasets before aggregation.
+
+    Returns
+    -------
+    list[tuple[str, str]]
+        Parallel to the ``measurement`` dimension of the aggregated dataset.
+ """ + mapping: list[tuple[str, str]] = [] + for pair_idx, ds in enumerate(image_matching_results): + # Prefer explicit / richer identifiers when available, with stable fallbacks. + sci_key_attr = ( + ds.attrs.get("sci_key") + or ds.attrs.get("science_key") + or ds.attrs.get("science_file") + or f"result_{pair_idx}" + ) + gcp_key_attr = ( + ds.attrs.get("gcp_pair_id") + or ds.attrs.get("gcp_file") + or ds.attrs.get("gcp_key") + or ds.attrs.get("gcp_pair_index") + or f"gcp_{pair_idx}" + ) + sci_key = str(sci_key_attr) + gcp_key = str(gcp_key_attr) + n_meas = ds.sizes.get("measurement", len(ds["lat_error_deg"])) + for _ in range(n_meas): + mapping.append((sci_key, gcp_key)) + return mapping + + +def _format_summary_table( + per_gcp_errors: list[GCPError], + requirements: RequirementsConfig, + percent_within: float, + passed: bool, +) -> str: + """Generate a human-readable summary table. + + Example output:: + + ┌──────────────────────────────────────────────────────┐ + │ Verification Summary │ + ├──────┬────────────┬────────────┬──────────┬──────────┤ + │ GCP │ Lat err(°) │ Lon err(°) │ Nadir(m) │ Status │ + ├──────┼────────────┼────────────┼──────────┼──────────┤ + │ 0 │ 0.00123 │ -0.00045 │ 145.2 │ ✓ │ + │ 1 │ 0.00567 │ 0.00234 │ 312.8 │ ✗ │ + ├──────┴────────────┴────────────┴──────────┴──────────┤ + │ Result: PASSED — 60.0% within 250.0m (req: 39.0%) │ + └──────────────────────────────────────────────────────┘ + + Parameters + ---------- + per_gcp_errors : list[GCPError] + Per-measurement detail. + requirements : RequirementsConfig + Thresholds used for evaluation. + percent_within : float + Percentage of measurements within the threshold. + passed : bool + Overall pass/fail result. + + Returns + ------- + str + Multi-line formatted table. 
+    """
+    # Column widths
+    w_gcp = 6
+    w_lat = 12
+    w_lon = 12
+    w_nadir = 10
+    w_status = 8
+    col_inner = w_gcp + w_lat + w_lon + w_nadir + w_status + 4  # 4 column separators
+
+    title = " Verification Summary"
+    verdict = "PASSED" if passed else "FAILED"
+    footer_text = (
+        f" Result: {verdict} — {percent_within:.1f}% within "
+        f"{requirements.performance_threshold_m}m "
+        f"(req: {requirements.performance_spec_percent}%)"
+    )
+
+    # inner_width must accommodate columns, title, AND footer
+    inner_width = max(col_inner, len(title) + 2, len(footer_text))
+
+    def _h_sep(left, mid, right, fill="─"):
+        """Build a column-width separator, then pad to inner_width."""
+        core = (
+            left
+            + fill * w_gcp
+            + mid
+            + fill * w_lat
+            + mid
+            + fill * w_lon
+            + mid
+            + fill * w_nadir
+            + mid
+            + fill * w_status
+            + right
+        )
+        # Extend to full inner_width if footer/title made the table wider
+        return core + fill * (inner_width - len(core))
+
+    lines: list[str] = []
+
+    lines.append("┌" + "─" * inner_width + "┐")
+    lines.append("│" + title.ljust(inner_width) + "│")
+    lines.append("├" + _h_sep("", "┬", "", "─") + "┤")
+    # Header row
+    # NOTE(review): header/data rows are col_inner wide and are NOT padded to
+    # inner_width — when the footer widens the box, the right border of these
+    # rows will not line up with the top/bottom borders. Cosmetic only.
+    h_gcp = " GCP".center(w_gcp)
+    h_lat = "Lat err(°)".center(w_lat)
+    h_lon = "Lon err(°)".center(w_lon)
+    h_nadir = "Nadir(m)".center(w_nadir)
+    h_status = "Status".center(w_status)
+    lines.append(f"│{h_gcp}│{h_lat}│{h_lon}│{h_nadir}│{h_status}│")
+    lines.append("├" + _h_sep("", "┼", "", "─") + "┤")
+
+    for err in per_gcp_errors:
+        c_gcp = str(err.gcp_index).rjust(w_gcp - 1).ljust(w_gcp)
+        c_lat = f"{err.lat_error_deg:+.5f}".center(w_lat)
+        c_lon = f"{err.lon_error_deg:+.5f}".center(w_lon)
+        if err.nadir_equiv_error_m is not None:
+            c_nadir = f"{err.nadir_equiv_error_m:.1f}".center(w_nadir)
+        else:
+            c_nadir = "N/A".center(w_nadir)
+        c_status = (" ✓ " if err.passed else " ✗ ").center(w_status)
+        lines.append(f"│{c_gcp}│{c_lat}│{c_lon}│{c_nadir}│{c_status}│")
+
+    # Footer
+    lines.append("├" + "─" * inner_width + "┤")
+    lines.append("│" + footer_text.ljust(inner_width) + "│")
+    lines.append("└" + "─" * inner_width + "┘")
+
+    return "\n".join(lines)
+
+
+def _log_pairing_summary(pairs: list[tuple[Path, Path]], unpaired: list[Path] | None = None) -> None:
+    """Log a human-readable GCP pairing summary.
+
+    NOTE(review): no caller in this module yet — presumably reserved for the
+    not-yet-implemented auto-pairing input modes of :func:`verify`; confirm
+    before removing.
+
+    Parameters
+    ----------
+    pairs : list of (Path, Path)
+        Successfully paired (observation, gcp) paths.
+    unpaired : list of Path or None, optional
+        Observation paths for which no matching GCP was found.
+    """
+    lines = ["GCP Pairing Summary:"]
+    for obs, gcp in pairs:
+        lines.append(f"  ✓ {obs.name} → {gcp.name}")
+    if unpaired:
+        for obs in unpaired:
+            lines.append(f"  ✗ {obs.name} → No matching GCP found")
+    lines.append(f"Proceeding with {len(pairs)} observation(s).")
+    logger.info("\n".join(lines))
+
+
+# ============================================================================
+# Public API
+# ============================================================================
+
+
+def verify(
+    config: CorrectionConfig,
+    # NEW: File-path-based input modes (signature established; body raises NotImplementedError)
+    gcp_pairs: list[tuple[str | Path, str | Path]] | None = None,
+    observation_paths: list[str | Path] | None = None,
+    gcp_directory: str | Path | None = None,
+    # EXISTING: Pre-computed input modes (backward-compatible)
+    image_matching_results: list[xr.Dataset] | None = None,
+    geolocated_data: xr.Dataset | None = None,
+    work_dir: Path | None = None,
+) -> VerificationResult:
+    """Evaluate current alignment against mission requirements.
+
+    No parameter variation, no kernel creation, no iteration loop. This
+    function checks whether a **given** set of alignment parameters meets
+    geolocation requirements.
+
+    Input priority (first match wins)
+    ----------------------------------
+    1. *image_matching_results* — pre-computed outputs from image matching;
+       the most common entry point for weekly automated checks.
+    2. *geolocated_data* — raw geolocated data; requires
+       ``config._image_matching_override`` to be set.
+    3. *gcp_pairs* — explicit (observation, gcp) file-path pairs.
+       **Not yet implemented** — raises ``NotImplementedError``.
+    4. *observation_paths* + *gcp_directory* — auto-paired via spatial overlap.
+       **Not yet implemented** — raises ``NotImplementedError``.
+    5. None of the above provided — raises :class:`ValueError`.
+
+    Parameters
+    ----------
+    config : CorrectionConfig
+        Configuration with all mission-specific settings:
+        - Performance thresholds (``performance_threshold_m``, ``performance_spec_percent``)
+        - Spacecraft variable names (``spacecraft_position_name``, ``boresight_name``, etc.)
+        - Geolocation settings (SPICE kernels, instrument configuration)
+        - Optional ``_image_matching_override`` (for *geolocated_data* path)
+        - Optional ``verification`` override (:class:`RequirementsConfig`)
+    gcp_pairs : list of (str | Path, str | Path) or None
+        Explicit (observation_path, gcp_path) pairs.
+        **Not yet implemented** — raises ``NotImplementedError``.
+    observation_paths : list of str | Path or None
+        Observation file paths for automatic GCP pairing.
+        Requires *gcp_directory*.
+        **Not yet implemented** — raises ``NotImplementedError``.
+    gcp_directory : str | Path or None
+        Directory of GCP reference images for automatic pairing with
+        *observation_paths*.
+        **Not yet implemented** — raises ``NotImplementedError``.
+    image_matching_results : list[xr.Dataset] or None
+        Pre-computed image-matching datasets, one per GCP pair. Each must
+        have a ``measurement`` dimension and ``lat_error_deg`` /
+        ``lon_error_deg`` variables plus the spacecraft-state variables
+        expected by
+        :class:`~curryer.correction.error_stats.ErrorStatsProcessor`.
+    geolocated_data : xr.Dataset or None
+        Already-geolocated data on which image matching will be run using
+        ``config._image_matching_override``. Ignored when
+        *image_matching_results* is provided.
+    work_dir : Path or None, optional
+        Working directory for outputs. Created if absent.
+        If None (default), uses ``./verification_output``.
+
+    Returns
+    -------
+    VerificationResult
+        Structured pass/fail result with per-GCP detail and a
+        human-readable :attr:`~VerificationResult.summary_table`.
+
+    Raises
+    ------
+    NotImplementedError
+        When *gcp_pairs* or (*observation_paths* + *gcp_directory*) is
+        provided — these file-path modes are not yet implemented.
+    ValueError
+        When none of the input modes is provided, or when *geolocated_data*
+        is supplied but ``config._image_matching_override`` is not set.
+    """
+    # ------------------------------------------------------------------
+    # File-path input modes: API established; implementation deferred
+    # ------------------------------------------------------------------
+    if gcp_pairs is not None:
+        raise NotImplementedError(
+            "File-path-based verify() via gcp_pairs is not yet implemented. "
+            "Pre-compute image_matching_results and pass them directly. "
+            "See examples/correction/ for the recommended workflow."
+        )
+
+    if observation_paths is not None or gcp_directory is not None:
+        raise NotImplementedError(
+            "Auto-pairing verify() via observation_paths + gcp_directory is not yet implemented. "
+            "Pre-compute image_matching_results and pass them directly. "
+            "See examples/correction/ for the recommended workflow."
+        )
+    # Handle optional work_dir with sensible default
+    if work_dir is None:
+        work_dir = Path("verification_output")
+
+    work_dir = Path(work_dir)
+    work_dir.mkdir(parents=True, exist_ok=True)
+
+    verify_start = time.time()
+    timestamp = datetime.now(tz=timezone.utc)
+    requirements = _build_requirements(config)
+
+    logger.info(
+        "Starting verification: threshold=%.1fm, spec=%.1f%%",
+        requirements.performance_threshold_m,
+        requirements.performance_spec_percent,
+    )
+
+    # ------------------------------------------------------------------
+    # Step 1: Obtain image-matching results
+    # ------------------------------------------------------------------
+    source_mapping: list[tuple[str, str]] = []
+
+    if image_matching_results is not None:
+        if not image_matching_results:
+            raise ValueError("image_matching_results must not be empty.")
+        logger.info("Using %d pre-computed image-matching result(s)", len(image_matching_results))
+        source_mapping = _build_source_mapping(image_matching_results)
+        aggregated = _aggregate_results(image_matching_results, config)
+
+    elif geolocated_data is not None:
+        im_override = getattr(config, "_image_matching_override", None)
+        if im_override is None:
+            raise ValueError(
+                "geolocated_data was provided but config._image_matching_override is not set. "
+                "Either supply pre-computed image_matching_results or set "
+                "config._image_matching_override = your_func."
+            )
+        logger.info("Running image matching on provided geolocated_data")
+        matched = im_override(geolocated_data)
+        # A single-dataset return is normalized to a one-element list.
+        if not isinstance(matched, list):
+            matched = [matched]
+        source_mapping = _build_source_mapping(matched)
+        aggregated = _aggregate_results(matched, config)
+
+    else:
+        raise ValueError(
+            "Neither image_matching_results nor geolocated_data was provided. Supply at least one of them to verify()."
+        )
+
+    # ------------------------------------------------------------------
+    # Step 2: Compute nadir-equivalent error statistics
+    # ------------------------------------------------------------------
+    logger.info("Computing nadir-equivalent error statistics")
+    aggregate_stats = _run_error_stats(aggregated, config)
+
+    # ------------------------------------------------------------------
+    # Step 3: Threshold check
+    # ------------------------------------------------------------------
+    passed, percent_within = _check_threshold(aggregate_stats, requirements)
+
+    # ------------------------------------------------------------------
+    # Step 4: Per-GCP detail
+    # ------------------------------------------------------------------
+    per_gcp_errors = _build_per_gcp_errors(aggregate_stats, source_mapping, requirements)
+
+    # ------------------------------------------------------------------
+    # Step 5: Warnings + summary table
+    # ------------------------------------------------------------------
+    warnings = _generate_warnings(passed, percent_within, requirements)
+    summary_table = _format_summary_table(per_gcp_errors, requirements, percent_within, passed)
+
+    if warnings:
+        for w in warnings:
+            logger.warning(w)
+
+    logger.info(
+        "Verification %s — %.1f%% within %.1fm threshold (requirement: %.1f%%)",
+        "PASSED" if passed else "FAILED",
+        percent_within,
+        requirements.performance_threshold_m,
+        requirements.performance_spec_percent,
+    )
+    logger.info("\n%s", summary_table)
+
+    # Build provenance fields
+    files_processed = [f"{sci}+{gcp}" for sci, gcp in source_mapping]
+    config_snapshot = {
+        "performance_threshold_m": requirements.performance_threshold_m,
+        "performance_spec_percent": requirements.performance_spec_percent,
+        # NOTE(review): assumes config.geo always exists — confirm
+        # CorrectionConfig defines it unconditionally, else guard with getattr.
+        "instrument_name": getattr(config.geo, "instrument_name", None),
+    }
+    elapsed_time_s = time.time() - verify_start
+
+    return VerificationResult(
+        passed=passed,
+        per_gcp_errors=per_gcp_errors,
+        aggregate_stats=aggregate_stats,
+        requirements=requirements,
+        summary_table=summary_table,
+        percent_within_threshold=percent_within,
+        warnings=warnings,
+        timestamp=timestamp,
+        files_processed=files_processed,
+        elapsed_time_s=elapsed_time_s,
+        config_snapshot=config_snapshot,
+    )
+
+
+def compare_results(before: VerificationResult, after: VerificationResult) -> str:
+    """Generate a side-by-side comparison of two verification results.
+
+    Useful for evaluating whether a correction run improved geolocation
+    accuracy relative to a baseline.
+
+    Parameters
+    ----------
+    before : VerificationResult
+        Baseline verification result (e.g., pre-correction).
+    after : VerificationResult
+        Updated verification result (e.g., post-correction).
+
+    Returns
+    -------
+    str
+        Human-readable side-by-side comparison table.
+    """
+    lines = [
+        "Verification Comparison",
+        "=" * 55,
+        f"{'Metric':<30} {'Before':>12} {'After':>12}",
+        "-" * 55,
+    ]
+
+    # Summary statistics are read from the datasets' attrs; missing keys
+    # render as "N/A" rather than raising.
+    b_stats = dict(before.aggregate_stats.attrs) if before.aggregate_stats is not None else {}
+    a_stats = dict(after.aggregate_stats.attrs) if after.aggregate_stats is not None else {}
+
+    stat_keys = [
+        "mean_error_m",
+        "median_error_m",
+        "rms_error_m",
+        "max_error_m",
+        "percent_below_250m",
+        "percent_below_500m",
+    ]
+    for key in stat_keys:
+        b_val = b_stats.get(key)
+        a_val = a_stats.get(key)
+        b_str = f"{b_val:.1f}" if isinstance(b_val, (int, float)) else "N/A"
+        a_str = f"{a_val:.1f}" if isinstance(a_val, (int, float)) else "N/A"
+        lines.append(f"{key:<30} {b_str:>12} {a_str:>12}")
+
+    lines.append("-" * 55)
+    lines.append(
+        f"{'percent_within_threshold':<30} "
+        f"{before.percent_within_threshold:>11.1f}% "
+        f"{after.percent_within_threshold:>11.1f}%"
+    )
+    lines.append("-" * 55)
+    b_verdict = "PASS" if before.passed else "FAIL"
+    a_verdict = "PASS" if after.passed else "FAIL"
+    lines.append(f"{'Overall':<30} {b_verdict:>12} {a_verdict:>12}")
+
+    return "\n".join(lines)
diff --git a/docs/source/contents.rst index
a11ae7c9..03686b08 100644 --- a/docs/source/contents.rst +++ b/docs/source/contents.rst @@ -3,8 +3,10 @@ :caption: Contents: users.md + gcp_regridding.md developers.md acknowledgements.md repo_data.md changelog.md - autoapi/index.rst \ No newline at end of file + autoapi/index.rst + diff --git a/docs/source/gcp_regridding.md b/docs/source/gcp_regridding.md new file mode 100644 index 00000000..651f06f0 --- /dev/null +++ b/docs/source/gcp_regridding.md @@ -0,0 +1,242 @@ +# GCP Chip Regridding + +Ground Control Point (GCP) reference imagery from missions such as CLARREO +Pathfinder arrives as raw HDF files (Landsat format) in which each pixel's +position is stored as Earth-Centered, Earth-Fixed (ECEF) X/Y/Z coordinates on +an irregular geodetic grid. Before these chips can be matched against L1A +science images, they must be resampled onto a regular latitude/longitude grid +that matches the spatial resolution of the mission's detector. + +The `curryer.correction` package provides a complete pipeline for this: + +``` +HDF chip → ECEF → WGS84 geodetic → bilinear regrid → regular NetCDF grid +``` + +--- + +## Quick start — single file (Python) + +```python +from pathlib import Path +from curryer.correction.data_structures import RegridConfig +from curryer.correction.image_io import load_gcp_chip_from_hdf +from curryer.correction.regrid import regrid_gcp_chip + +# 1. Load raw chip (returns band data + ECEF X/Y/Z arrays) +band, ecef_x, ecef_y, ecef_z = load_gcp_chip_from_hdf( + Path("LT08CHP.20140803.p002r071.c01.v001.hdf") +) + +# 2. Configure regridding (~100 m resolution for CLARREO) +config = RegridConfig(output_resolution_deg=(0.0009, 0.0009)) + +# 3. Regrid and save in one call +regridded = regrid_gcp_chip( + band, + (ecef_x, ecef_y, ecef_z), + config, + output_file="regridded_chip.nc", + output_metadata={ + "source_file": "LT08CHP.20140803.p002r071.c01.v001.hdf", + "mission": "CLARREO Pathfinder", + "sensor": "Landsat-8", + "band": "red", + }, +) + +# 4. 
regridded is an ImageGrid — ready for image matching +print(regridded.data.shape) # e.g. (421, 433) +print(regridded.lat[0, 0]) # top-left latitude +``` + +--- + +## Batch processing — command-line script + +For a directory of 100 Landsat chips use the provided script directly: + +```bash +# Regrid all *.hdf files in /data/landsat_gcps/ and write NetCDF to /data/regridded/ +python scripts/regrid_gcp_chips.py /data/landsat_gcps/ /data/regridded/ +``` + +Output: + +``` +Processing 100 file(s) → /data/regridded/ + +[ 1/100] START LT08CHP.20140101.p002r071.c01.v001.hdf + ✓ LT08CHP.20140101.p002r071.c01.v001_regridded.nc (1823 KB, 4.2s) +[ 2/100] START LT08CHP.20140116.p002r071.c01.v001.hdf + ✓ LT08CHP.20140116.p002r071.c01.v001_regridded.nc (1791 KB, 4.0s) +... +──────────────────────────────────────────────────────────── +Finished: all 100 file(s) processed successfully. +``` + +### Common options + +| Flag | Default | Description | +| -------------------------- | -------------------- | ------------------------------------------------- | +| `--resolution DLAT DLON` | `0.0009 0.0009` | Output resolution in degrees (~100 m) | +| `--pattern GLOB` | `*.hdf` | Filename pattern when source is a directory | +| `--mission NAME` | `CLARREO Pathfinder` | Written to NetCDF `mission` attribute | +| `--band DATASET` | `Band_1` | HDF dataset name for the radiometric channel | +| `--skip-existing` | off | Skip files whose output `.nc` already exists | +| `--dry-run` | off | Print what would be done, write nothing | +| `--no-conservative-bounds` | off | Use full ECEF extent (may include edge artefacts) | +| `-v` / `--verbose` | off | Show per-row progress and DEBUG log output | + +### Resume an interrupted run + +```bash +# Already processed 60/100 — pick up from where it stopped +python scripts/regrid_gcp_chips.py /data/landsat_gcps/ /data/regridded/ --skip-existing +``` + +### Preview before committing + +```bash +python scripts/regrid_gcp_chips.py /data/landsat_gcps/ 
/data/regridded/ --dry-run +``` + +### Non-standard band or file pattern + +```bash +# Band_4 (near-IR), only files matching a date range +python scripts/regrid_gcp_chips.py /data/ /out/ \ + --pattern "LT08CHP.2016*.hdf" \ + --band Band_4 \ + --resolution 0.001 0.001 +``` + +--- + +## Batch processing — Python API + +Use this when you need custom logic (filtering, parallel execution, etc.): + +```python +from pathlib import Path +from curryer.correction.data_structures import RegridConfig +from curryer.correction.image_io import load_gcp_chip_from_hdf +from curryer.correction.regrid import regrid_gcp_chip + +input_dir = Path("/data/landsat_gcps") +output_dir = Path("/data/regridded") +output_dir.mkdir(parents=True, exist_ok=True) + +config = RegridConfig(output_resolution_deg=(0.0009, 0.0009)) + +hdf_files = sorted(input_dir.glob("LT08CHP.*.hdf")) +print(f"Found {len(hdf_files)} chips") + +errors = {} +for hdf_file in hdf_files: + nc_file = output_dir / f"{hdf_file.stem}_regridded.nc" + + try: + band, ecef_x, ecef_y, ecef_z = load_gcp_chip_from_hdf(hdf_file) + regrid_gcp_chip( + band, + (ecef_x, ecef_y, ecef_z), + config, + output_file=str(nc_file), + output_metadata={"source_file": hdf_file.name, "mission": "CLARREO Pathfinder"}, + ) + print(f" ✓ {hdf_file.name}") + except Exception as exc: + errors[hdf_file.name] = str(exc) + print(f" ✗ {hdf_file.name}: {exc}") + +if errors: + print(f"\n{len(errors)} file(s) failed:", *errors, sep="\n ") +``` + +### Loading the regridded output + +The output NetCDF files are `ImageGrid`-compatible and plug directly into the +correction pipeline: + +```python +from curryer.correction.image_io import load_gcp_chip_from_netcdf + +gcp = load_gcp_chip_from_netcdf(Path("regridded_chip.nc")) +# gcp.data — 2-D radiometric values +# gcp.lat — 2-D latitude (regular grid, decreasing from top to bottom) +# gcp.lon — 2-D longitude (regular grid, increasing left to right) +# gcp.h — 2-D height above WGS84 ellipsoid (metres), or None +``` + 
+--- + +## Configuration reference + +```python +from curryer.correction.data_structures import RegridConfig + +# Resolution-based (most common) +config = RegridConfig(output_resolution_deg=(0.0009, 0.0009)) + +# Fixed output size +config = RegridConfig(output_grid_size=(500, 500)) + +# Explicit geographic extent + resolution +config = RegridConfig( + output_bounds=(-116.5, -115.5, 38.0, 39.0), # (minlon, maxlon, minlat, maxlat) + output_resolution_deg=(0.001, 0.001), +) + +# Disable conservative clipping (use full ECEF extent) +config = RegridConfig( + output_resolution_deg=(0.0009, 0.0009), + conservative_bounds=False, +) +``` + +| Parameter | Type | Description | +| ----------------------- | ---------------------------------- | ------------------------------------------------------------------------- | +| `output_resolution_deg` | `(dlat, dlon)` | Grid spacing in degrees. Mutually exclusive with `output_grid_size`. | +| `output_grid_size` | `(nrows, ncols)` | Fixed output dimensions. Mutually exclusive with `output_resolution_deg`. | +| `output_bounds` | `(minlon, maxlon, minlat, maxlat)` | Explicit geographic extent. Requires `output_resolution_deg`. | +| `conservative_bounds` | `bool` (default `True`) | Shrink bounds to grid interior to avoid edge extrapolation. | +| `interpolation_method` | `"bilinear"` \| `"nearest"` | Interpolation algorithm (default `"bilinear"`). | +| `fill_value` | `float` (default `NaN`) | Value assigned to output points outside the input footprint. 
| + +--- + +## Output NetCDF structure + +``` +dimensions: y (rows), x (cols) +variables: + band_data(y, x) — radiometric values [digital_number] + lat(y, x) — latitude [degrees_north] + lon(y, x) — longitude [degrees_east] + h(y, x) — height above WGS84 [metres] (present when height is available) +global attributes: + title, Conventions (CF-1.8), source_file, mission, band, processing_software, … +``` + +> **Coordinate convention:** row 0 is northernmost (latitude decreases down), +> column 0 is westernmost (longitude increases right) — consistent with the +> MATLAB `Chip_regrid2.m` output format. + +--- + +## Notes + +- **Ellipsoid:** ECEF → geodetic conversion always uses **WGS84** + (`curryer.compute.spatial.ecef_to_geodetic`). This is correct for + Landsat-8/9 GCP chips. + +- **HDF format:** `load_gcp_chip_from_hdf` tries HDF4 (`pyhdf`) first, then + falls back to HDF5 (`h5py`). Both are handled transparently. + +- **Memory:** a 1400 × 1400 Landsat chip uses ~30 MB RAM; 100 chips processed + sequentially stay well within a 4 GB budget. If memory is tight, process + files one at a time (the CLI script already does this). + +- **Performance:** each chip typically takes 3–6 seconds on a single CPU core. + 100 chips ≈ 5–10 minutes. diff --git a/pyproject.toml b/pyproject.toml index ea930c69..aabe1a76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,8 @@ dependencies = [ "openpyxl>=3", "numpy >=1.23,<2.0", "boto3", - "typing-extensions" + "typing-extensions", + "pydantic>=2.0", ] [project.optional-dependencies] diff --git a/scripts/clarreo_preprocess.py b/scripts/clarreo_preprocess.py new file mode 100644 index 00000000..f95549dc --- /dev/null +++ b/scripts/clarreo_preprocess.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +"""CLARREO Pathfinder telemetry and science preprocessing script. 
+ +This script converts CLARREO-specific raw telemetry CSV files into clean, +pipeline-ready CSVs that the correction pipeline can load directly via +:class:`~curryer.correction.config.DataConfig`. + +CLARREO-specific preprocessing steps +-------------------------------------- +Telemetry: + - Load four raw CSVs: ``SC_SPK``, ``SC_CK``, ``ST_CK``, ``AZEL_CK`` + - Reverse the azimuth sign (``hps.az_ang_nonlin *= -1``) + - Convert star-tracker DCM columns to quaternion columns + - Outer-join and sort all four sources on ``ert`` + - Compute combined second + sub-second timetags + +Science: + - Load science-frame timing CSV + - **No** time scaling is applied here — the pipeline handles that via + ``DataConfig.time_scale_factor``. The output ``corrected_timestamp`` + column contains GPS seconds (the raw unit from the instrument). + +Usage +----- +As a library (called from tests or other scripts):: + + from scripts.clarreo_preprocess import preprocess_clarreo_telemetry, preprocess_clarreo_science + import pandas as pd, tempfile, pathlib + + data_dir = pathlib.Path("tests/data/clarreo/gcs") + tlm_df = preprocess_clarreo_telemetry(data_dir) # → pd.DataFrame + sci_df = preprocess_clarreo_science(data_dir) # → pd.DataFrame + + # Save to CSVs for the pipeline + tlm_df.to_csv("telemetry.csv") + sci_df.to_csv("science.csv") + +As a CLI:: + + python scripts/clarreo_preprocess.py --data-dir tests/data/clarreo/gcs --output-dir /tmp/clarreo_out +""" + +from __future__ import annotations + +import argparse +import logging +from pathlib import Path + +import numpy as np +import pandas as pd + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Public preprocessing functions +# --------------------------------------------------------------------------- + + +def preprocess_clarreo_telemetry(data_dir: Path | str) -> pd.DataFrame: + """Preprocess CLARREO raw telemetry CSVs into a single clean DataFrame. 
+ + Parameters + ---------- + data_dir : Path + Directory containing the four CLARREO telemetry CSV files. + + Returns + ------- + pd.DataFrame + Merged and preprocessed telemetry DataFrame ready for the correction + pipeline. + """ + try: + from curryer import spicierpy as sp + except ImportError as exc: # pragma: no cover + raise ImportError("curryer.spicierpy is required for DCM→quaternion conversion.") from exc + + data_dir = Path(data_dir) + logger.info(f"Loading CLARREO telemetry CSVs from: {data_dir}") + + sc_spk_df = pd.read_csv(data_dir / "openloop_tlm_5a_sc_spk_20250521T225242.csv", index_col=0) + sc_ck_df = pd.read_csv(data_dir / "openloop_tlm_5a_sc_ck_20250521T225242.csv", index_col=0) + st_ck_df = pd.read_csv(data_dir / "openloop_tlm_5a_st_ck_20250521T225242.csv", index_col=0) + azel_ck_df = pd.read_csv(data_dir / "openloop_tlm_5a_azel_ck_20250521T225242.csv", index_col=0) + + logger.info( + f"Loaded – SC_SPK: {sc_spk_df.shape}, SC_CK: {sc_ck_df.shape}, " + f"ST_CK: {st_ck_df.shape}, AZEL_CK: {azel_ck_df.shape}" + ) + + # Reverse the direction of the Azimuth element (CLARREO-specific) + azel_ck_df["hps.az_ang_nonlin"] = azel_ck_df["hps.az_ang_nonlin"] * -1 + + # Convert star-tracker DCM to quaternion (CLARREO-specific) + tlm_st_rot = np.vstack( + [ + st_ck_df["hps.dcm_base_iss_1_1"].values, + st_ck_df["hps.dcm_base_iss_1_2"].values, + st_ck_df["hps.dcm_base_iss_1_3"].values, + st_ck_df["hps.dcm_base_iss_2_1"].values, + st_ck_df["hps.dcm_base_iss_2_2"].values, + st_ck_df["hps.dcm_base_iss_2_3"].values, + st_ck_df["hps.dcm_base_iss_3_1"].values, + st_ck_df["hps.dcm_base_iss_3_2"].values, + st_ck_df["hps.dcm_base_iss_3_3"].values, + ] + ).T + tlm_st_rot = np.reshape(tlm_st_rot, (-1, 3, 3)).copy() + tlm_st_rot_q = np.vstack([sp.m2q(tlm_st_rot[i, :, :]) for i in range(tlm_st_rot.shape[0])]) + + st_ck_df["hps.dcm_base_iss_s"] = tlm_st_rot_q[:, 0] + st_ck_df["hps.dcm_base_iss_i"] = tlm_st_rot_q[:, 1] + st_ck_df["hps.dcm_base_iss_j"] = tlm_st_rot_q[:, 2] + 
st_ck_df["hps.dcm_base_iss_k"] = tlm_st_rot_q[:, 3] + + # Outer-join all four sources and sort by ERT + left_df = sc_spk_df + for right_df in [sc_ck_df, st_ck_df, azel_ck_df]: + left_df = pd.merge(left_df, right_df, on="ert", how="outer") + left_df = left_df.sort_values("ert") + + # Compute combined second + sub-second timetags (CLARREO-specific) + for col in list(left_df): + if col in ("hps.bad_ps_tms", "hps.corrected_tms", "hps.resolver_tms", "hps.st_quat_coi_tms"): + if col + "s" not in left_df.columns: + raise ValueError(f"Missing sub-second column for {col}") + if col == "hps.bad_ps_tms": + left_df[col + "_tmss"] = left_df[col] + left_df[col + "s"] / 256 + elif col in ("hps.corrected_tms", "hps.resolver_tms", "hps.st_quat_coi_tms"): + left_df[col + "_tmss"] = left_df[col] + left_df[col + "s"] / 2**32 + else: + raise ValueError(f"Missing conversion for expected column: {col}") + + logger.info(f"Final telemetry shape: {left_df.shape}") + return left_df + + +def preprocess_clarreo_science(data_dir: Path | str) -> pd.DataFrame: + """Preprocess CLARREO science frame timing CSV into a clean DataFrame. + + The ``corrected_timestamp`` column is returned in GPS seconds (the raw + instrument unit). Set ``DataConfig.time_scale_factor = 1e6`` in your + :class:`~curryer.correction.config.CorrectionConfig` to convert to uGPS + during pipeline loading. + + Parameters + ---------- + data_dir : Path + Directory containing the science timing CSV. + + Returns + ------- + pd.DataFrame + Science DataFrame with ``corrected_timestamp`` in GPS seconds. 
+ """ + data_dir = Path(data_dir) + logger.info(f"Loading CLARREO science CSV from: {data_dir}") + + sci_time_df = pd.read_csv(data_dir / "openloop_tlm_5a_sci_times_20250521T225242.csv", index_col=0) + + logger.info( + f"Science data shape: {sci_time_df.shape}, " + f"corrected_timestamp range: " + f"{sci_time_df['corrected_timestamp'].min():.3f} – " + f"{sci_time_df['corrected_timestamp'].max():.3f} GPS sec" + ) + return sci_time_df + + +# --------------------------------------------------------------------------- +# CLI entry point +# --------------------------------------------------------------------------- + + +def _build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + description="Preprocess CLARREO Pathfinder raw telemetry/science CSVs.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p.add_argument( + "--data-dir", + type=Path, + required=True, + help="Directory containing the raw CLARREO CSV files.", + ) + p.add_argument( + "--output-dir", + type=Path, + default=Path("."), + help="Directory where preprocessed CSVs will be written.", + ) + p.add_argument( + "--log-level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + ) + return p + + +def main() -> None: + args = _build_parser().parse_args() + logging.basicConfig(level=getattr(logging, args.log_level), format="%(levelname)s %(name)s: %(message)s") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + tlm_path = output_dir / "clarreo_telemetry_preprocessed.csv" + sci_path = output_dir / "clarreo_science_preprocessed.csv" + + logger.info("Preprocessing CLARREO telemetry…") + tlm_df = preprocess_clarreo_telemetry(args.data_dir) + tlm_df.to_csv(tlm_path) + logger.info(f" → {tlm_path}") + + logger.info("Preprocessing CLARREO science frame timing…") + sci_df = preprocess_clarreo_science(args.data_dir) + sci_df.to_csv(sci_path) + logger.info(f" → {sci_path}") + + logger.info("Done.") + print(f"Telemetry: {tlm_path}") 
+ print(f"Science: {sci_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/example_regrid_to_netcdf.py b/scripts/example_regrid_to_netcdf.py new file mode 100644 index 00000000..ec09564c --- /dev/null +++ b/scripts/example_regrid_to_netcdf.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python +""" +Example: Regrid GCP chip and save to NetCDF + +This script demonstrates the new NetCDF output functionality for GCP regridding. +It can be used as a template for batch processing GCP chips. +""" + +from pathlib import Path + +import numpy as np + +from curryer.correction.image_io import load_gcp_chip_from_hdf +from curryer.correction.regrid import RegridConfig, regrid_gcp_chip + + +def regrid_and_save_gcp_chip( + input_file: Path, + output_file: Path, + mission: str = "CLARREO Pathfinder", + resolution_deg: tuple[float, float] = (0.0009, 0.0009), +) -> None: + """ + Load, regrid, and save a GCP chip to NetCDF. + + Parameters + ---------- + input_file : Path + Input HDF file containing raw GCP chip. + output_file : Path + Output NetCDF file path. + mission : str + Mission name for metadata. + resolution_deg : tuple[float, float] + Output resolution (dlat, dlon) in degrees. 
+ """ + print(f"Processing: {input_file.name}") + + # Load raw GCP chip + print(" Loading raw chip...") + band, x, y, z = load_gcp_chip_from_hdf(input_file) + print(f" Input shape: {band.shape}") + + # Configure regridding + config = RegridConfig( + output_resolution_deg=resolution_deg, + conservative_bounds=True, + interpolation_method="bilinear", + ) + + # Prepare metadata + metadata = { + "source_file": input_file.name, + "mission": mission, + "sensor": "Landsat-8", + "band": "red (Band 1)", + "resolution_deg": f"{resolution_deg[0]}°, {resolution_deg[1]}°", + "processing_software": "curryer", + } + + # Regrid and save + print(" Regridding...") + regridded = regrid_gcp_chip( + band, + (x, y, z), + config, + output_file=str(output_file), + output_metadata=metadata, + ) + + print(f" Output shape: {regridded.data.shape}") + print(f" Valid pixels: {(~np.isnan(regridded.data)).sum()}/{regridded.data.size}") + print(f" Saved to: {output_file}") + print(f" File size: {output_file.stat().st_size / (1024**2):.2f} MB") + + +def main(): + """Example usage: Process a single GCP chip.""" + # Example file paths + input_file = Path("tests/data/clarreo/landsat_gcp/LT08CHP.20140803.p002r071.c01.v001.hdf") + output_dir = Path("output/regridded_gcps") + output_dir.mkdir(parents=True, exist_ok=True) + + output_file = output_dir / "LT08CHP.20140803.p002r071.c01.v001_regridded.nc" + + if not input_file.exists(): + print(f"Error: Input file not found: {input_file}") + print("\nUpdate the script with your GCP chip path.") + return + + # Process with CLARREO-specific resolution (~100m) + regrid_and_save_gcp_chip( + input_file, + output_file, + mission="CLARREO Pathfinder", + resolution_deg=(0.0009, 0.0009), + ) + + print("\n✓ Processing complete!") + + +def batch_process_example(): + """Example: Batch process multiple GCP chips.""" + input_dir = Path("tests/data/clarreo/landsat_gcp") + output_dir = Path("output/regridded_gcps") + output_dir.mkdir(parents=True, exist_ok=True) + + # Find 
all HDF files + hdf_files = list(input_dir.glob("LT08CHP.*.hdf")) + + if not hdf_files: + print(f"No HDF files found in {input_dir}") + return + + print(f"Found {len(hdf_files)} GCP chips to process\n") + + for input_file in hdf_files: + output_file = output_dir / f"{input_file.stem}_regridded.nc" + + try: + regrid_and_save_gcp_chip(input_file, output_file) + except Exception as e: + print(f" ERROR: {e}") + continue + + print() + + print(f"✓ Batch processing complete! Processed {len(hdf_files)} chips.") + + +if __name__ == "__main__": + # Run single file example + main() + + # Uncomment to run batch processing + # batch_process_example() diff --git a/scripts/example_verification_minimal.py b/scripts/example_verification_minimal.py new file mode 100755 index 00000000..6697d05e --- /dev/null +++ b/scripts/example_verification_minimal.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python +""" +Minimal example: Run full verification on geolocated observations. + +This example demonstrates the production verification workflow: +1. Load geolocated science data (output of geolocation pipeline) +2. Run image matching against ground truth +3. Compute error statistics +4. Check against geolocation requirements + +This example uses synthetic data for demonstration. In production: +- Load real geolocated data from your pipeline +- Use actual ground truth GCP locations +- Run with real kernels and parameters + +Use this to understand the verification API before deploying to production. +""" + +import tempfile +from pathlib import Path + +import numpy as np +import xarray as xr + +from curryer.correction.config import ( + CorrectionConfig, + GeolocationConfig, + ParameterConfig, + ParameterType, +) +from curryer.correction.verification import verify + + +def create_synthetic_geolocated_data(n_measurements: int = 50, seed: int = 42) -> xr.Dataset: + """Create a synthetic geolocated science dataset. + + In production, this would be the output from your geolocation pipeline. 
+ It should contain: + - Geolocated pixel observations (latitude, longitude, altitude) + - Spacecraft state data + - Timing information + + Parameters + ---------- + n_measurements : int + Number of pixel measurements. + seed : int + Random seed for reproducibility. + + Returns + ------- + xr.Dataset + Geolocated science data with spacecraft state. + """ + rng = np.random.RandomState(seed) + + # Simulated geolocated pixel observations + # (What your geolocation algorithm produced using current kernels/parameters) + geolocated_lat = rng.uniform(-90, 90, n_measurements) + geolocated_lon = rng.uniform(-180, 180, n_measurements) + geolocated_alt = rng.uniform(0, 5000, n_measurements) + + # Ground truth reference GCP locations + # (Fixed reference points, never change) + gcp_lat = rng.uniform(-90, 90, n_measurements) + gcp_lon = rng.uniform(-180, 180, n_measurements) + gcp_alt = rng.uniform(0, 5000, n_measurements) + + # Spacecraft state (from SPICE kernels) + spacecraft_pos = rng.normal(6.371e6, 1e5, (n_measurements, 3)) # ECEF meters + boresight = rng.normal(0, 1, (n_measurements, 3)) + boresight = boresight / np.linalg.norm(boresight, axis=1, keepdims=True) + rotation_matrices = rng.normal(0, 0.1, (n_measurements, 3, 3)) + + # Sensor/timing info + band_data = rng.uniform(0, 65535, n_measurements).astype(np.uint16) + + ds = xr.Dataset( + { + # Geolocated measurements (from geolocation pipeline) + "latitude": (["measurement"], geolocated_lat), + "longitude": (["measurement"], geolocated_lon), + "altitude": (["measurement"], geolocated_alt), + # Ground truth for image matching + "gcp_lat_deg": (["measurement"], gcp_lat), + "gcp_lon_deg": (["measurement"], gcp_lon), + "gcp_alt": (["measurement"], gcp_alt), + # Spacecraft state (from SPICE kernels) + "riss_ctrs": (["measurement", "component"], spacecraft_pos), + "bhat_hs": (["measurement", "component"], boresight), + "t_hs2ctrs": (["measurement", "matrix_row", "matrix_col"], rotation_matrices), + # Sensor data + 
"radiance": (["measurement"], band_data), + }, + coords={"measurement": np.arange(n_measurements)}, + ) + + ds.attrs["mission"] = "CLARREO" + ds.attrs["date"] = "2024-03-17" + + return ds + + +def create_minimal_config() -> CorrectionConfig: + """Create a minimal CorrectionConfig with a stub image matching override. + + The matching function is attached via ``config._image_matching_override`` + after construction. This is the correct pattern for test/example injection; + ``image_matching_func`` (the old public field) is deprecated. + """ + + def stub_image_matching(data: xr.Dataset) -> xr.Dataset: + """Stub image matching function. + + In production, this would run sophisticated image matching between + geolocated observations and reference imagery. + + For this example, we just compute simple lat/lon differences. + """ + n = len(data.measurement) + + # Compute errors (difference between geolocated and ground truth) + lat_err = data["latitude"].values - data["gcp_lat_deg"].values + lon_err = data["longitude"].values - data["gcp_lon_deg"].values + + result = xr.Dataset( + { + "lat_error_deg": (["measurement"], lat_err), + "lon_error_deg": (["measurement"], lon_err), + # Include spacecraft state for error_stats processor + "riss_ctrs": (["measurement", "component"], data["riss_ctrs"].values), + "bhat_hs": (["measurement", "component"], data["bhat_hs"].values), + "t_hs2ctrs": (["measurement", "matrix_row", "matrix_col"], data["t_hs2ctrs"].values), + "gcp_lat_deg": (["measurement"], data["gcp_lat_deg"].values), + "gcp_lon_deg": (["measurement"], data["gcp_lon_deg"].values), + "gcp_alt": (["measurement"], data["gcp_alt"].values), + }, + coords={"measurement": np.arange(n)}, + ) + return result + + config = CorrectionConfig( + n_iterations=1, + parameters=[ + ParameterConfig( + ptype=ParameterType.CONSTANT_KERNEL, + data={"current_value": [0.0, 0.0, 0.0], "bounds": [-300.0, 300.0]}, + ) + ], + geo=GeolocationConfig( + meta_kernel_file=Path("test.tm.json"), + 
generic_kernel_dir=Path("data/generic"), + instrument_name="TEST_INSTRUMENT", + time_field="corrected_timestamp", + ), + performance_threshold_m=250.0, + performance_spec_percent=39.0, + spacecraft_position_name="riss_ctrs", + boresight_name="bhat_hs", + transformation_matrix_name="t_hs2ctrs", + ) + # Attach matching override after construction (PrivateAttr – not a constructor arg). + # In production, omit this: pipeline.image_matching() runs automatically. + config._image_matching_override = stub_image_matching + return config + + +def main(): + """Run the production verification workflow.""" + print("=" * 70) + print("CLARREO Weekly Verification (Production Workflow)") + print("=" * 70) + + # Step 1: Load geolocated observations + print("\n1. Loading geolocated observations...") + geolocated = create_synthetic_geolocated_data(n_measurements=50, seed=42) + print(f" Loaded {len(geolocated.measurement)} measurements") + print(f" Variables: {list(geolocated.data_vars)}") + + # Step 2: Create config with image matching function + print("\n2. Setting up verification config...") + config = create_minimal_config() + print(f" _image_matching_override: {config._image_matching_override.__name__}") + print(f" Threshold: {config.performance_threshold_m}m") + print(f" Spec: {config.performance_spec_percent}%") + + # Step 3: Run full verification + print("\n3. Running verification...") + print(" - Running image matching on geolocated data") + print(" - Computing error statistics") + print(" - Checking against requirements") + print() + + # Use secure temporary directory (or omit work_dir to use default ./verification_output) + with tempfile.TemporaryDirectory() as tmpdir: + output_dir = Path(tmpdir) / "verification_example" + output_dir.mkdir(exist_ok=True) + + result = verify( + config=config, + geolocated_data=geolocated, + work_dir=output_dir, # Optional - defaults to ./verification_output if not provided + ) + + # Display results + print("\n4. 
Verification Summary") + print(result.summary_table) + + print("\n5. Result Details") + print(f" Status: {'PASSED ✓' if result.passed else 'FAILED ✗'}") + print(f" Percent within threshold: {result.percent_within_threshold:.1f}%") + print(f" Requirement: {result.requirements.performance_spec_percent}%") + print(f" Threshold: {result.requirements.performance_threshold_m}m") + print(f" Measurements analyzed: {len(result.per_gcp_errors)}") + + if result.warnings: + print("\n6. Warnings") + for warning in result.warnings: + print(f" ⚠️ {warning}") + + # Show per-measurement sample + print("\n7. Sample Per-Measurement Errors (first 5)") + for err in result.per_gcp_errors[:5]: + status = "✓ PASS" if err.passed else "✗ FAIL" + print( + f" #{err.gcp_index}: lat={err.lat_error_deg:+.5f}°, " + f"lon={err.lon_error_deg:+.5f}°, nadir={err.nadir_equiv_error_m:.1f}m {status}" + ) + + # Show output files + print("\n8. Output Files") + print(f" Work directory: {output_dir}") + if output_dir.exists(): + files = sorted(output_dir.glob("*")) + if files: + print(" Saved:") + for f in files: + if f.is_file(): + size_mb = f.stat().st_size / (1024**2) + print(f" - {f.name} ({size_mb:.2f} MB)") + else: + print(" (No files saved - use --save with CLI)") + + print("\n" + "=" * 70) + print(f"Verification {'PASSED ✓' if result.passed else 'FAILED ✗'}") + print("=" * 70) + + return 0 if result.passed else 1 + + +if __name__ == "__main__": + import sys + + sys.exit(main()) diff --git a/scripts/example_weekly_verification.py b/scripts/example_weekly_verification.py new file mode 100755 index 00000000..ba6f6b69 --- /dev/null +++ b/scripts/example_weekly_verification.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python +""" +Example: Weekly verification on geolocated observations. + +This script demonstrates the production verification workflow: + +1. Geolocation pipeline runs (produces geolocated observations) +2. Load previously-geolocated science data +3. 
Run full verification (image matching against ground truth → error stats → threshold check)
+4. Report compliance results
+
+Note: The geolocation itself is NOT done by the verification module.
+The verification module runs image matching and error analysis on
+output from your geolocation pipeline to check compliance.
+
+Setup
+-----
+Before running, ensure you have:
+- A CorrectionConfig with:
+  - Proper kernel paths (SPICE kernels for geolocation)
+  - Mission parameters (spacecraft, instrument names, etc.)
+  - An image-matching function (required - compares geolocated vs. ground truth);
+    set ``image_matching_func`` or attach the ``_image_matching_override``
+    attribute (see ``example_verification_minimal.py``)
+  - Performance thresholds (performance_threshold_m, performance_spec_percent)
+- Geolocated science data (output of your geolocation pipeline)
+
+Example Usage
+-------------
+As a library::
+
+    from scripts.example_weekly_verification import run_weekly_verification
+    from pathlib import Path
+    import xarray as xr
+
+    # Load this week's geolocated observations (from your geolocation pipeline)
+    geolocated = xr.open_dataset("weekly_2024-03-17.nc")
+
+    # Run full verification (image matching + error stats + threshold check)
+    report = run_weekly_verification(
+        config=CorrectionConfig.from_json(Path("config/my_mission.json")),
+        geolocated_data=geolocated,
+        output_dir=Path("reports/2024-03-17/")
+    )
+
+    # Check result
+    if not report.passed:
+        send_alert(report.summary_table)
+
+As a CLI::
+
+    python scripts/example_weekly_verification.py \\
+        --config config/my_mission.json \\
+        --geolocated weekly_2024-03-17.nc \\
+        --output-dir reports/2024-03-17/
+
+    # Or with multiple files (concatenated):
+    python scripts/example_weekly_verification.py \\
+        --config config/my_mission.json \\
+        --geolocated-dir weekly_data/ \\
+        --pattern "*.nc" \\
+        --output-dir reports/2024-03-17/
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import logging
+from pathlib import Path
+
+import xarray as xr
+
+from curryer.correction.config import CorrectionConfig
+from
curryer.correction.verification import VerificationResult, verify + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Typical workflow: Load and verify pre-computed image-matching results +# ============================================================================ + + +def run_weekly_verification( + config: CorrectionConfig, + geolocated_data: xr.Dataset | None = None, + geolocated_files: list[Path] | None = None, + output_dir: Path | None = None, + save_report: bool = True, +) -> VerificationResult: + """ + Run full verification on geolocated observations. + + This is the production workflow: load geolocated science data from your + geolocation pipeline, run full verification (image matching → error stats → + threshold check), and report compliance. + + The verification module does NOT perform geolocation. It verifies the + quality of geolocation output by comparing geolocated measurements against + ground truth using your image_matching_func. + + Parameters + ---------- + config : CorrectionConfig + Mission/instrument configuration with: + - Proper SPICE kernels for geolocation + - image_matching_func (required - compares geolocated vs. ground truth) + - performance_threshold_m and performance_spec_percent + geolocated_data : xr.Dataset, optional + Pre-loaded geolocated science data (output of your geolocation pipeline). + This is the recommended entry point when you have a single NetCDF file. + geolocated_files : list[Path], optional + List of NetCDF files to load and concatenate. Use when data spans + multiple files. Ignored if geolocated_data is provided. + output_dir : Path, optional + Directory for saving the verification report. Created if missing. + Defaults to ./verification_output. + save_report : bool, default=True + Whether to save JSON report and summary table to output_dir. 
+ + Returns + ------- + VerificationResult + Structured result with pass/fail status, per-GCP errors, and + human-readable summary table. + + Raises + ------ + ValueError + If neither geolocated_data nor geolocated_files is provided, or if + config.image_matching_func is not set. + """ + if output_dir is None: + output_dir = Path.cwd() / "verification_output" + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info("=" * 70) + logger.info("Verification: Running on Geolocated Observations") + logger.info("=" * 70) + + # ---- + # Load geolocated data + # ---- + if geolocated_data is None: + if geolocated_files is None or not geolocated_files: + raise ValueError( + "Either geolocated_data or geolocated_files must be provided. " + "Supply NetCDF file(s) with geolocated science observations." + ) + logger.info(f"Loading geolocated data from {len(geolocated_files)} file(s)") + datasets = [xr.open_dataset(f) for f in geolocated_files] + # Concatenate along measurement dimension if multiple files + if len(datasets) == 1: + geolocated_data = datasets[0] + else: + geolocated_data = xr.concat(datasets, dim="measurement") + logger.info(f"Combined {len(datasets)} file(s): total shape {geolocated_data.dims}") + + # ---- + # Run verification + # ---- + result = verify( + config=config, + geolocated_data=geolocated_data, + work_dir=output_dir if output_dir else None, # Optional, defaults to ./verification_output + ) + + # ---- + # Print summary + # ---- + print("\n" + result.summary_table) + + if result.warnings: + print("\n⚠️ Warnings:") + for w in result.warnings: + print(f" {w}") + + # ---- + # Save report if requested + # ---- + if save_report: + _save_verification_report(result, output_dir) + + return result + + +def _save_verification_report(result: VerificationResult, output_dir: Path) -> None: + """Save verification result to JSON file and summary table text file.""" + # Save JSON report (excluding xarray Dataset which is not 
JSON-serializable) + json_file = output_dir / "verification_report.json" + dumped = result.model_dump(exclude={"aggregate_stats"}) + with open(json_file, "w") as f: + json.dump(dumped, f, indent=2, default=str) # default=str for datetime + logger.info(f"Saved JSON report: {json_file}") + + # Save summary table + table_file = output_dir / "verification_summary.txt" + with open(table_file, "w") as f: + f.write(result.summary_table) + logger.info(f"Saved summary table: {table_file}") + + # Save xarray aggregate stats if available (useful for post-processing) + if result.aggregate_stats is not None: + stats_file = output_dir / "aggregate_stats.nc" + result.aggregate_stats.to_netcdf(stats_file) + logger.info(f"Saved aggregate statistics: {stats_file}") + + +# ============================================================================ +# CLI entry point +# ============================================================================ + + +def main(): + parser = argparse.ArgumentParser( + description="Verification on geolocated observations (mission-agnostic).", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + parser.add_argument( + "--config", + type=Path, + required=True, + help="Path to CorrectionConfig JSON file.", + ) + + input_group = parser.add_mutually_exclusive_group(required=True) + input_group.add_argument( + "--geolocated", + type=Path, + help="NetCDF file with geolocated science data from this week.", + ) + input_group.add_argument( + "--geolocated-dir", + type=Path, + help="Directory containing geolocated NetCDF files to concatenate.", + ) + + parser.add_argument( + "--pattern", + default="*.nc", + help="Glob pattern for matching geolocated files (default: *.nc).", + ) + + parser.add_argument( + "--output-dir", + type=Path, + default=None, + help="Directory for verification report (default: ./verification_output).", + ) + + parser.add_argument( + "--no-save", + action="store_true", + help="Do not save JSON report and summary 
table.", + ) + + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable debug logging.", + ) + + args = parser.parse_args() + + # Setup logging + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig( + level=log_level, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + ) + + # Load config + logger.info(f"Loading config from {args.config}") + with open(args.config) as f: + config_dict = json.load(f) + config = CorrectionConfig(**config_dict) + + # Check that image_matching_func is set + if config.image_matching_func is None: + logger.error( + "config.image_matching_func is not set. " + "Verification requires an image-matching function to run on geolocated data." + ) + return 1 + + # Load geolocated data + if args.geolocated: + logger.info(f"Loading geolocated data from {args.geolocated}") + geolocated = xr.open_dataset(args.geolocated) + + # Run verification + result = run_weekly_verification( + config=config, + geolocated_data=geolocated, + output_dir=args.output_dir, + save_report=not args.no_save, + ) + + else: # args.geolocated_dir + logger.info(f"Loading geolocated files from {args.geolocated_dir}") + geolocated_dir = Path(args.geolocated_dir) + geolocated_files = sorted(geolocated_dir.glob(args.pattern)) + logger.info(f"Found {len(geolocated_files)} matching files") + + if not geolocated_files: + logger.error(f"No files matching pattern '{args.pattern}' in {args.geolocated_dir}") + return 1 + + # Run verification + result = run_weekly_verification( + config=config, + geolocated_files=geolocated_files, + output_dir=args.output_dir, + save_report=not args.no_save, + ) + + # Exit with appropriate code + return 0 if result.passed else 1 + + +if __name__ == "__main__": + import sys + + sys.exit(main()) diff --git a/scripts/regrid_gcp_chips.py b/scripts/regrid_gcp_chips.py new file mode 100644 index 00000000..81902910 --- /dev/null +++ b/scripts/regrid_gcp_chips.py @@ -0,0 +1,302 @@ +#!/usr/bin/env 
python +"""Batch-regrid Landsat (or other HDF) GCP chips to regular lat/lon NetCDF files. + +Usage +----- +Single file:: + + python scripts/regrid_gcp_chips.py input.hdf output_dir/ + +Directory of HDF files:: + + python scripts/regrid_gcp_chips.py /data/landsat_gcps/ /data/regridded/ + +Custom resolution and glob pattern:: + + python scripts/regrid_gcp_chips.py /data/ /out/ \\ + --pattern "LT08CHP.*.hdf" \\ + --resolution 0.001 0.001 + +Resume an interrupted run (skip already-converted files):: + + python scripts/regrid_gcp_chips.py /data/ /out/ --skip-existing + +Preview what would be processed without writing any files:: + + python scripts/regrid_gcp_chips.py /data/ /out/ --dry-run + +See all options:: + + python scripts/regrid_gcp_chips.py --help +""" + +from __future__ import annotations + +import argparse +import logging +import sys +import time +from pathlib import Path + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Core per-file processing +# --------------------------------------------------------------------------- + + +def regrid_one( + input_file: Path, + output_file: Path, + resolution_deg: tuple[float, float], + conservative_bounds: bool, + mission: str, + band_name: str, +) -> None: + """Load one HDF chip, regrid it, and write the result to a NetCDF file. + + Parameters + ---------- + input_file : Path + Source HDF4 or HDF5 chip file. + output_file : Path + Destination NetCDF file. Parent directory must already exist. + resolution_deg : tuple[float, float] + Output grid resolution as ``(dlat, dlon)`` in degrees. + conservative_bounds : bool + Shrink output bounds to the interior of the input grid when ``True`` + (avoids edge extrapolation artefacts). + mission : str + Mission label written to the ``mission`` global attribute. + band_name : str + HDF dataset name for the radiometric band (default ``"Band_1"``). 
+ """ + from curryer.correction.data_structures import RegridConfig + from curryer.correction.image_io import load_gcp_chip_from_hdf + from curryer.correction.regrid import regrid_gcp_chip + + band, ecef_x, ecef_y, ecef_z = load_gcp_chip_from_hdf(input_file, band_name=band_name) + + config = RegridConfig( + output_resolution_deg=resolution_deg, + conservative_bounds=conservative_bounds, + interpolation_method="bilinear", + ) + + metadata = { + "source_file": input_file.name, + "mission": mission, + "band": band_name, + "resolution_deg": f"{resolution_deg[0]}, {resolution_deg[1]}", + "processing_software": "curryer", + } + + regrid_gcp_chip( + band, + (ecef_x, ecef_y, ecef_z), + config, + output_file=str(output_file), + output_metadata=metadata, + ) + + +# --------------------------------------------------------------------------- +# Batch driver +# --------------------------------------------------------------------------- + + +def collect_inputs(source: Path, pattern: str) -> list[Path]: + """Return a sorted list of HDF files to process. + + Parameters + ---------- + source : Path + Either a single HDF file or a directory to search. + pattern : str + Glob pattern applied when *source* is a directory. 
+ """ + if source.is_file(): + return [source] + files = sorted(source.glob(pattern)) + if not files: + logger.warning("No files matched pattern %r in %s", pattern, source) + return files + + +def output_path_for(input_file: Path, output_dir: Path) -> Path: + """Derive the output NetCDF path from an input HDF filename.""" + return output_dir / f"{input_file.stem}_regridded.nc" + + +def run_batch( + input_files: list[Path], + output_dir: Path, + resolution_deg: tuple[float, float], + conservative_bounds: bool, + skip_existing: bool, + dry_run: bool, + mission: str, + band_name: str, +) -> int: + """Process *input_files* and return the number of failures.""" + total = len(input_files) + failures: list[tuple[Path, str]] = [] + + print(f"{'DRY RUN — ' if dry_run else ''}Processing {total} file(s) → {output_dir}\n") + + for idx, input_file in enumerate(input_files, start=1): + output_file = output_path_for(input_file, output_dir) + prefix = f"[{idx:>{len(str(total))}}/{total}]" + + if skip_existing and output_file.exists(): + print(f"{prefix} SKIP {input_file.name} (already exists)") + continue + + print(f"{prefix} START {input_file.name}") + + if dry_run: + print(f" → {output_file}") + continue + + t0 = time.monotonic() + try: + regrid_one(input_file, output_file, resolution_deg, conservative_bounds, mission, band_name) + elapsed = time.monotonic() - t0 + size_kb = output_file.stat().st_size / 1024 + print(f" ✓ {output_file.name} ({size_kb:.0f} KB, {elapsed:.1f}s)") + except Exception as exc: # noqa: BLE001 + elapsed = time.monotonic() - t0 + msg = str(exc) + failures.append((input_file, msg)) + print(f" ✗ FAILED ({elapsed:.1f}s): {msg}") + logger.debug("Traceback for %s:", input_file.name, exc_info=True) + + # Summary + print(f"\n{'─' * 60}") + if dry_run: + print(f"Dry run complete — {total} file(s) would be processed.") + elif failures: + print(f"Finished: {total - len(failures)}/{total} succeeded, {len(failures)} failed.") + print("\nFailed files:") + for path, 
reason in failures: + print(f" {path.name}: {reason}") + else: + print(f"Finished: all {total} file(s) processed successfully.") + + return len(failures) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="regrid_gcp_chips", + description="Batch-regrid HDF GCP chips to regular lat/lon NetCDF files.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "source", + type=Path, + help="Single HDF file or directory containing HDF files.", + ) + parser.add_argument( + "output_dir", + type=Path, + help="Directory where regridded NetCDF files are written (created if needed).", + ) + parser.add_argument( + "--pattern", + default="*.hdf", + metavar="GLOB", + help="Glob pattern for HDF files when source is a directory (default: '*.hdf').", + ) + parser.add_argument( + "--resolution", + nargs=2, + type=float, + default=[0.0009, 0.0009], + metavar=("DLAT", "DLON"), + help="Output grid resolution in degrees (default: 0.0009 0.0009 ≈ 100 m).", + ) + parser.add_argument( + "--no-conservative-bounds", + dest="conservative_bounds", + action="store_false", + default=True, + help="Use full ECEF extent instead of shrinking bounds to avoid edge extrapolation.", + ) + parser.add_argument( + "--mission", + default="CLARREO Pathfinder", + help="Mission name written to NetCDF metadata (default: 'CLARREO Pathfinder').", + ) + parser.add_argument( + "--band", + default="Band_1", + metavar="DATASET", + help="HDF dataset name for the radiometric band (default: 'Band_1').", + ) + parser.add_argument( + "--skip-existing", + action="store_true", + help="Skip files whose output NetCDF already exists (useful for resuming).", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print what would be done without writing any 
files.", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable DEBUG logging (shows per-row regridding progress).", + ) + return parser + + +def main(argv: list[str] | None = None) -> None: + parser = build_parser() + args = parser.parse_args(argv) + + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.WARNING, + format="%(levelname)s %(name)s: %(message)s", + ) + + source: Path = args.source + output_dir: Path = args.output_dir + resolution_deg = (args.resolution[0], args.resolution[1]) + + if not source.exists(): + parser.error(f"source does not exist: {source}") + + if not args.dry_run: + output_dir.mkdir(parents=True, exist_ok=True) + + input_files = collect_inputs(source, args.pattern) + if not input_files: + print(f"No files found. Check --pattern (currently {args.pattern!r}).") + sys.exit(1) + + n_failures = run_batch( + input_files=input_files, + output_dir=output_dir, + resolution_deg=resolution_deg, + conservative_bounds=args.conservative_bounds, + skip_existing=args.skip_existing, + dry_run=args.dry_run, + mission=args.mission, + band_name=args.band, + ) + + sys.exit(1 if n_failures else 0) + + +if __name__ == "__main__": + main() diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..9a9c14bf --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1,2 @@ +# Integration tests package + diff --git a/tests/test_correction/_synthetic_helpers.py b/tests/test_correction/_synthetic_helpers.py new file mode 100644 index 00000000..367c273b --- /dev/null +++ b/tests/test_correction/_synthetic_helpers.py @@ -0,0 +1,159 @@ +"""Synthetic test-data helpers shared by test_pipeline.py and clarreo e2e tests. + +These functions generate realistic-looking but entirely synthetic sensor data +(boresight vectors, spacecraft positions, transformation matrices, GCP pairs) +so that upstream pipeline tests can run without any real instrument data. 
logger = logging.getLogger(__name__)


class _PlaceholderConfig:
    """Fallback knobs for synthetic data generation.

    Used when the correction config carries no ``placeholder`` section.
    """

    base_error_m: float = 50.0          # baseline geolocation error magnitude
    param_error_scale: float = 10.0     # extra error per unit of parameter offset
    max_measurements: int = 100
    min_measurements: int = 10
    orbit_radius_mean_m: float = 6.78e6
    orbit_radius_std_m: float = 4e3
    latitude_range: tuple = (-60.0, 60.0)
    longitude_range: tuple = (-180.0, 180.0)
    altitude_range: tuple = (0.0, 1000.0)
    max_off_nadir_rad: float = 0.1


def synthetic_gcp_pairing(science_data_files):
    """Pair each science file with a fabricated Landsat GCP name (synthetic only)."""
    logger.warning("USING SYNTHETIC GCP PAIRING - FAKE DATA!")
    pairs = []
    for idx, sci_file in enumerate(science_data_files):
        pairs.append((str(sci_file), f"landsat_gcp_{idx:03d}.tif"))
    return pairs


def synthetic_image_matching(
    geolocated_data,
    gcp_reference_file,
    telemetry,
    calibration_dir,
    params_info,
    config,
    los_vectors_cached=None,
    optical_psfs_cached=None,
):
    """Produce a SYNTHETIC image-matching result Dataset (upstream testing only).

    Mirrors the signature of the real ``image_matching`` function — unused
    arguments are accepted and ignored — so the pipeline loop can call it
    transparently.
    """
    logger.warning("USING SYNTHETIC IMAGE MATCHING - FAKE DATA!")

    if hasattr(config, "placeholder") and config.placeholder:
        defaults = config.placeholder
    else:
        defaults = _PlaceholderConfig()
    pos_key = getattr(config, "spacecraft_position_name", "sc_position")
    bore_key = getattr(config, "boresight_name", "boresight")
    tmat_key = getattr(config, "transformation_matrix_name", "t_inst2ref")

    # Rows whose geolocation contains no NaNs are usable "measurements".
    good_rows = ~np.isnan(geolocated_data["latitude"].values).any(axis=1)
    n_good = int(good_rows.sum())
    if n_good == 0:
        n_meas = defaults.min_measurements
    else:
        n_meas = min(n_good, defaults.max_measurements)

    sc_positions = _generate_spherical_positions(
        n_meas, defaults.orbit_radius_mean_m, defaults.orbit_radius_std_m
    )
    pointing = _generate_synthetic_boresights(n_meas, defaults.max_off_nadir_rad)
    rotations = _generate_nadir_aligned_transforms(n_meas, sc_positions, pointing)

    def _magnitude(value):
        # Scalars contribute |value|; array-likes contribute their L2 norm.
        return abs(value) if isinstance(value, (int, float)) else np.linalg.norm(value)

    param_term = sum(_magnitude(val) for _, val in params_info) * defaults.param_error_scale
    err_m = defaults.base_error_m + param_term
    dlat = np.random.normal(0, err_m / 111_000, n_meas)
    dlon = np.random.normal(0, err_m / 111_000, n_meas)

    if n_good > 0:
        picked = np.where(good_rows)[0][:n_meas]
        ref_lat = geolocated_data["latitude"].values[picked, 0]
        ref_lon = geolocated_data["longitude"].values[picked, 0]
    else:
        ref_lat = np.random.uniform(*defaults.latitude_range, n_meas)
        ref_lon = np.random.uniform(*defaults.longitude_range, n_meas)

    ref_alt = np.random.uniform(*defaults.altitude_range, n_meas)

    return xr.Dataset(
        {
            "lat_error_deg": (["measurement"], dlat),
            "lon_error_deg": (["measurement"], dlon),
            pos_key: (["measurement", "xyz"], sc_positions),
            bore_key: (["measurement", "xyz"], pointing),
            tmat_key: (["measurement", "xyz_from", "xyz_to"], rotations),
            "gcp_lat_deg": (["measurement"], ref_lat),
            "gcp_lon_deg": (["measurement"], ref_lon),
            "gcp_alt": (["measurement"], ref_alt),
        },
        coords={
            "measurement": range(n_meas),
            "xyz": ["x", "y", "z"],
            "xyz_from": ["x", "y", "z"],
            "xyz_to": ["x", "y", "z"],
        },
    )


def _generate_synthetic_boresights(n, max_off_nadir_rad=0.07):
    """Return *n* unit boresight vectors tilted slightly off the +z axis."""
    vecs = np.zeros((n, 3))
    for row in vecs:
        tilt = np.random.uniform(-max_off_nadir_rad, max_off_nadir_rad)
        row[1] = np.sin(tilt)
        row[2] = np.cos(tilt)
    return vecs


def _generate_spherical_positions(n, radius_mean_m, radius_std_m):
    """Return *n* random points on a (noisy-radius) sphere — orbit positions."""
    pos = np.zeros((n, 3))
    for row in pos:
        radius = np.random.normal(radius_mean_m, radius_std_m)
        azimuth = np.random.uniform(0, 2 * np.pi)
        cos_t = np.random.uniform(-1, 1)
        sin_t = np.sqrt(max(0.0, 1 - cos_t**2))
        row[0] = radius * sin_t * np.cos(azimuth)
        row[1] = radius * sin_t * np.sin(azimuth)
        row[2] = radius * cos_t
    return pos


def _generate_nadir_aligned_transforms(n, riss_ctrs, boresights_hs):
    """Return *n* rotation matrices mapping each boresight onto local nadir."""
    mats = np.zeros((n, 3, 3))
    for i in range(n):
        nadir_hat = -riss_ctrs[i] / np.linalg.norm(riss_ctrs[i])
        b_hat = boresights_hs[i] / np.linalg.norm(boresights_hs[i])
        axis = np.cross(b_hat, nadir_hat)
        axis_len = np.linalg.norm(axis)
        if axis_len >= 1e-6:
            # Generic case: Rodrigues rotation about the (normalized) cross axis.
            axis = axis / axis_len
            angle = np.arccos(np.clip(np.dot(b_hat, nadir_hat), -1.0, 1.0))
            K = _skew(axis)
            mats[i] = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)
        elif np.dot(b_hat, nadir_hat) > 0:
            # Already aligned.
            mats[i] = np.eye(3)
        else:
            # Anti-parallel: rotate 180° about any axis perpendicular to b_hat
            # (Rodrigues with sin=0, 1-cos=2 reduces to I + 2*K@K).
            helper = np.array([1, 0, 0]) if abs(b_hat[0]) < 0.9 else np.array([0, 1, 0])
            axis = np.cross(b_hat, helper)
            axis = axis / np.linalg.norm(axis)
            K = _skew(axis)
            mats[i] = np.eye(3) + 2 * K @ K
    return mats


def _skew(v):
    """Return the 3x3 cross-product (skew-symmetric) matrix of *v*."""
    x, y, z = v
    return np.array([[0, -z, y], [z, 0, -x], [-y, x, 0]])
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Test-case metadata
# ---------------------------------------------------------------------------

# Registry of validated image-matching cases: each entry names the GCP chip,
# the spacecraft ancillary file, the injected (lat, lon) error in km, and the
# sub-image variants (binned variants use the binned optical PSF).
_TEST_CASE_METADATA: dict[str, dict] = {
    "1": {
        "name": "Dili",
        "gcp_file": "GCP12055Dili_resampled.mat",
        "ancil_file": "R_ISS_midframe_TestCase1.mat",
        "expected_error_km": (3.0, -3.0),
        "cases": [
            {"subimage": "TestCase1a_subimage.mat", "binned": False},
            {"subimage": "TestCase1b_subimage.mat", "binned": False},
            {"subimage": "TestCase1c_subimage_binned.mat", "binned": True},
            {"subimage": "TestCase1d_subimage_binned.mat", "binned": True},
        ],
    },
    "2": {
        "name": "Maracaibo",
        "gcp_file": "GCP10121Maracaibo_resampled.mat",
        "ancil_file": "R_ISS_midframe_TestCase2.mat",
        "expected_error_km": (-3.0, 2.0),
        "cases": [
            {"subimage": "TestCase2a_subimage.mat", "binned": False},
            {"subimage": "TestCase2b_subimage.mat", "binned": False},
            {"subimage": "TestCase2c_subimage_binned.mat", "binned": True},
        ],
    },
    "3": {
        "name": "Algeria3",
        "gcp_file": "GCP10181Algeria3_resampled.mat",
        "ancil_file": "R_ISS_midframe_TestCase3.mat",
        "expected_error_km": (2.0, 3.0),
        "cases": [
            {"subimage": "TestCase3a_subimage.mat", "binned": False},
            {"subimage": "TestCase3b_subimage_binned.mat", "binned": True},
        ],
    },
    "4": {
        "name": "Dunhuang",
        "gcp_file": "GCP10142Dunhuang_resampled.mat",
        "ancil_file": "R_ISS_midframe_TestCase4.mat",
        "expected_error_km": (-2.0, -3.0),
        "cases": [
            {"subimage": "TestCase4a_subimage.mat", "binned": False},
            {"subimage": "TestCase4b_subimage_binned.mat", "binned": True},
        ],
    },
    "5": {
        "name": "Algeria5",
        "gcp_file": "GCP10071Algeria5_resampled.mat",
        "ancil_file": "R_ISS_midframe_TestCase5.mat",
        "expected_error_km": (1.0, -1.0),
        "cases": [
            {"subimage": "TestCase5a_subimage.mat", "binned": False},
        ],
    },
}


def discover_test_image_match_cases(
    test_data_dir: Path,
    test_cases: list[str] | None = None,
) -> list[dict]:
    """Scan *test_data_dir* for validated image-matching test cases.

    Sub-cases whose required files (sub-image, GCP chip, ancillary data, or
    the PSF file for their binning mode) are missing are skipped with a
    warning rather than returned as broken entries.

    Parameters
    ----------
    test_data_dir:
        Root directory (e.g. ``tests/data/clarreo/image_match/``).
    test_cases:
        Specific case IDs to include (e.g. ``['1', '2']``). ``None`` means
        all available cases.

    Returns
    -------
    list[dict]
        One dict per sub-case variant with keys: ``case_id``, ``case_name``,
        ``subcase_name``, ``subimage_file``, ``gcp_file``, ``ancil_file``,
        ``los_file``, ``psf_file``, ``expected_lat_error_km``,
        ``expected_lon_error_km``, ``binned``.

    Raises
    ------
    FileNotFoundError
        If the shared LOS-vectors file or the unbinned PSF file is missing.
    """
    logger.info("Discovering image-matching test cases in: %s", test_data_dir)

    los_file = test_data_dir / "b_HS.mat"
    psf_file_unbinned = test_data_dir / "optical_PSF_675nm_upsampled.mat"
    psf_file_binned = test_data_dir / "optical_PSF_675nm_3_pix_binned_upsampled.mat"

    if not los_file.exists():
        raise FileNotFoundError(f"LOS vectors file not found: {los_file}")
    if not psf_file_unbinned.exists():
        raise FileNotFoundError(f"PSF file not found: {psf_file_unbinned}")

    if test_cases is None:
        test_cases = sorted(_TEST_CASE_METADATA.keys())

    discovered: list[dict] = []
    for case_id in test_cases:
        if case_id not in _TEST_CASE_METADATA:
            logger.warning("Test case '%s' not in metadata, skipping", case_id)
            continue
        meta = _TEST_CASE_METADATA[case_id]
        case_dir = test_data_dir / case_id
        if not case_dir.is_dir():
            logger.warning("Test case directory not found: %s, skipping", case_dir)
            continue

        for subcase in meta["cases"]:
            subimage_file = case_dir / subcase["subimage"]
            gcp_file = case_dir / meta["gcp_file"]
            ancil_file = case_dir / meta["ancil_file"]
            psf_file = psf_file_binned if subcase["binned"] else psf_file_unbinned

            if not subimage_file.exists():
                logger.warning("Subimage not found: %s, skipping", subimage_file)
                continue
            if not gcp_file.exists():
                logger.warning("GCP file not found: %s, skipping", gcp_file)
                continue
            if not ancil_file.exists():
                logger.warning("Ancil file not found: %s, skipping", ancil_file)
                continue
            # BUGFIX: previously only the unbinned PSF was validated up front,
            # so a binned sub-case with a missing binned PSF file was returned
            # and failed much later at load time. Check per sub-case instead.
            if not psf_file.exists():
                logger.warning("PSF file not found: %s, skipping", psf_file)
                continue

            discovered.append(
                {
                    "case_id": case_id,
                    "case_name": meta["name"],
                    "subcase_name": subcase["subimage"],
                    "subimage_file": subimage_file,
                    "gcp_file": gcp_file,
                    "ancil_file": ancil_file,
                    "los_file": los_file,
                    "psf_file": psf_file,
                    "expected_lat_error_km": meta["expected_error_km"][0],
                    "expected_lon_error_km": meta["expected_error_km"][1],
                    "binned": subcase["binned"],
                }
            )

    logger.info("Discovered %d test-case variants", len(discovered))
    return discovered
def apply_error_variation_for_testing(
    base_result: xr.Dataset,
    param_idx: int,
    error_variation_percent: float = 3.0,
) -> xr.Dataset:
    """Perturb a cached image-matching result to simulate parameter effects.

    Seeding the generator with *param_idx* makes each parameter set's
    perturbation distinct yet fully reproducible.
    """
    perturbed = base_result.copy(deep=True)
    rng = np.random.default_rng(param_idx)

    frac = error_variation_percent / 100.0
    # Draw order matters for reproducibility: lat factor, lon factor, CCV factor.
    lat_scale = 1.0 + rng.normal(0, frac)
    lon_scale = 1.0 + rng.normal(0, frac)
    ccv_scale = 1.0 + rng.normal(0, frac / 10.0)

    new_lat_km = base_result.attrs["lat_error_km"] * lat_scale
    new_lon_km = base_result.attrs["lon_error_km"] * lon_scale
    new_ccv = float(np.clip(base_result.attrs["correlation_ccv"] * ccv_scale, 0.0, 1.0))

    # Latitude of the GCP centre is needed to convert the lon error to degrees.
    if "gcp_center_lat" in base_result.attrs:
        center_lat = base_result.attrs["gcp_center_lat"]
    elif "gcp_lat_deg" in base_result:
        center_lat = float(base_result["gcp_lat_deg"].values[0])
    else:
        center_lat = 45.0

    east_radius_km = 6378.0 * np.cos(np.deg2rad(center_lat))
    perturbed["lat_error_deg"].values[0] = new_lat_km / 111.0
    perturbed["lon_error_deg"].values[0] = new_lon_km / (east_radius_km * np.pi / 180.0)
    perturbed.attrs.update(
        {
            "lat_error_km": new_lat_km,
            "lon_error_km": new_lon_km,
            "correlation_ccv": new_ccv,
            "param_idx": param_idx,
            "variation_applied": True,
        }
    )
    return perturbed


def apply_geolocation_error_to_subimage(
    subimage: ImageGrid,
    gcp: ImageGrid,
    lat_error_km: float,
    lon_error_km: float,
) -> ImageGrid:
    """Return a copy of *subimage* shifted by the given km offsets (testing only)."""
    from curryer.compute import constants

    rows, cols = gcp.lat.shape
    center_lat = float(gcp.lat[rows // 2, cols // 2])
    r_earth_km = constants.WGS84_SEMI_MAJOR_AXIS_KM
    dlat_deg = lat_error_km / r_earth_km * (180.0 / np.pi)
    # East-west degrees shrink with latitude: divide by the parallel's radius.
    dlon_deg = lon_error_km / (r_earth_km * np.cos(np.deg2rad(center_lat))) * (180.0 / np.pi)

    shifted_h = None if subimage.h is None else subimage.h.copy()
    return ImageGrid(
        data=subimage.data.copy(),
        lat=subimage.lat + dlat_deg,
        lon=subimage.lon + dlon_deg,
        h=shifted_h,
    )
test_case["expected_lon_error_km"] + ) + + los_vectors = load_los_vectors_from_mat(test_case["los_file"]) + optical_psfs = load_optical_psf_from_mat(test_case["psf_file"]) + ancil_data = loadmat(test_case["ancil_file"], squeeze_me=True) + r_iss_midframe = ancil_data["R_ISS_midframe"].ravel() + + result = integrated_image_match( + subimage=subimage_with_error, + gcp=gcp, + r_iss_midframe_m=r_iss_midframe, + los_vectors_hs=los_vectors, + optical_psfs=optical_psfs, + geolocation_config=PSFSamplingConfig(), + search_config=SearchConfig(), + ) + processing_time = time.time() - start + + lat_error_deg = result.lat_error_km / 111.0 + lon_radius_km = 6378.0 * np.cos(np.deg2rad(gcp_center_lat)) + lon_error_deg = result.lon_error_km / (lon_radius_km * np.pi / 180.0) + + t_matrix = np.array( + [ + [-0.418977524967338, 0.748005379751721, 0.514728846515064], + [-0.421890284446342, 0.341604851993858, -0.839830169131854], + [-0.804031356019172, -0.569029065124742, 0.172451447025628], + ] + ) + boresight = np.array([0.0, 0.0625969755450201, 0.99803888634292]) + + output = xr.Dataset( + { + "lat_error_deg": (["measurement"], [lat_error_deg]), + "lon_error_deg": (["measurement"], [lon_error_deg]), + "riss_ctrs": (["measurement", "xyz"], [r_iss_midframe]), + "bhat_hs": (["measurement", "xyz"], [boresight]), + "t_hs2ctrs": (["measurement", "xyz_from", "xyz_to"], t_matrix[np.newaxis, :, :]), + "gcp_lat_deg": (["measurement"], [gcp_center_lat]), + "gcp_lon_deg": (["measurement"], [gcp_center_lon]), + "gcp_alt": (["measurement"], [0.0]), + }, + coords={"measurement": [0], "xyz": ["x", "y", "z"], "xyz_from": ["x", "y", "z"], "xyz_to": ["x", "y", "z"]}, + ) + output.attrs.update( + { + "lat_error_km": result.lat_error_km, + "lon_error_km": result.lon_error_km, + "correlation_ccv": result.ccv_final, + "final_grid_step_m": result.final_grid_step_m, + "processing_time_s": processing_time, + "test_mode": True, + "param_idx": param_idx, + "gcp_center_lat": gcp_center_lat, + "gcp_center_lon": 
gcp_center_lon, + } + ) + return output diff --git a/tests/test_correction/clarreo/_pipeline_helpers.py b/tests/test_correction/clarreo/_pipeline_helpers.py new file mode 100644 index 00000000..4bf41243 --- /dev/null +++ b/tests/test_correction/clarreo/_pipeline_helpers.py @@ -0,0 +1,280 @@ +""" +Pipeline runner helpers for CLARREO integration tests. + +Provides ``run_upstream_pipeline`` and ``run_downstream_pipeline``, +which exercise the upstream (kernel creation + geolocation) and +downstream (GCP pairing + image matching + error statistics) halves +of the Correction pipeline respectively. + +These are *test infrastructure helpers*, not pytest tests themselves. +""" + +from __future__ import annotations + +import atexit +import logging +import shutil +import tempfile +from pathlib import Path + +import numpy as np +import xarray as xr + +from curryer.correction import correction +from curryer.correction.config import DataConfig + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Lazy imports (avoid hard-coding sys.path assumptions at module import time) +# --------------------------------------------------------------------------- + + +def _load_clarreo_loaders(): + from _image_match_helpers import discover_test_image_match_cases, run_image_matching_with_applied_errors + from clarreo_config import create_clarreo_correction_config + from clarreo_data_loaders import load_clarreo_science, load_clarreo_telemetry + + return ( + create_clarreo_correction_config, + load_clarreo_telemetry, + load_clarreo_science, + discover_test_image_match_cases, + run_image_matching_with_applied_errors, + ) + + +# --------------------------------------------------------------------------- +# Upstream pipeline +# --------------------------------------------------------------------------- + + +def run_upstream_pipeline( + n_iterations: int = 5, + work_dir: Path | None = None, +) -> tuple[list, dict, Path]: + """Test 
the upstream segment: parameter generation → kernel creation → geolocation. + + Uses ``synthetic_image_matching`` so it does NOT require valid GCP pairs. + + Returns + ------- + (results_list, results_summary_dict, output_file_path) + """ + from _synthetic_helpers import synthetic_image_matching + + ( + create_clarreo_correction_config, + load_clarreo_telemetry, + load_clarreo_science, + _, + _, + ) = _load_clarreo_loaders() + + logger.info("=== UPSTREAM PIPELINE TEST ===") + + root_dir = Path(__file__).parents[3] + generic_dir = root_dir / "data" / "generic" + data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" + + if work_dir is None: + _tmp = tempfile.mkdtemp(prefix="curryer_upstream_") + work_dir = Path(_tmp) + atexit.register(shutil.rmtree, work_dir, True) + logger.info("Temporary work dir: %s", work_dir) + else: + work_dir.mkdir(parents=True, exist_ok=True) + + config = create_clarreo_correction_config(data_dir, generic_dir) + config.n_iterations = n_iterations + config.output_filename = "upstream_results.nc" + + tlm_df = load_clarreo_telemetry(data_dir) + sci_df = load_clarreo_science(data_dir) + + tlm_csv = work_dir / "clarreo_telemetry.csv" + sci_csv = work_dir / "clarreo_science.csv" + tlm_df.to_csv(tlm_csv) + sci_df.to_csv(sci_csv) + + config.data = DataConfig(file_format="csv", time_scale_factor=1e6) + config._image_matching_override = synthetic_image_matching + + tlm_sci_gcp_sets = [(str(tlm_csv), str(sci_csv), "synthetic_gcp.mat")] + + logger.info("Executing Correction upstream workflow (%d iterations)…", n_iterations) + results, netcdf_data = correction.loop(config, work_dir, tlm_sci_gcp_sets) + + output_file = work_dir / config.get_output_filename() + logger.info("Upstream pipeline complete. 
Output: %s", output_file) + + summary = { + "mode": "upstream", + "iterations": n_iterations, + "parameter_sets": len(netcdf_data["parameter_set_id"]), + "status": "complete", + } + return results, summary, output_file + + +# --------------------------------------------------------------------------- +# Downstream pipeline +# --------------------------------------------------------------------------- + + +def run_downstream_pipeline( + n_iterations: int = 5, + test_cases: list[str] | None = None, + work_dir: Path | None = None, +) -> tuple[list, dict, Path]: + """Test the downstream segment: GCP pairing → image matching → error statistics. + + Uses pre-geolocated test data (.mat files) with known errors applied. + Does NOT test kernel creation or SPICE geolocation. + + Returns + ------- + (results_list, results_summary_dict, output_file_path) + """ + from _image_match_helpers import discover_test_image_match_cases, run_image_matching_with_applied_errors + from clarreo_config import create_clarreo_correction_config + + from curryer.correction.image_match import load_image_grid_from_mat + from curryer.correction.pairing import find_l1a_gcp_pairs + + logger.info("=== DOWNSTREAM PIPELINE TEST ===") + + root_dir = Path(__file__).parents[3] + generic_dir = root_dir / "data" / "generic" + data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" + test_data_dir = root_dir / "tests" / "data" / "clarreo" / "image_match" + + if work_dir is None: + _tmp = tempfile.mkdtemp(prefix="curryer_downstream_") + work_dir = Path(_tmp) + atexit.register(shutil.rmtree, work_dir, True) + logger.info("Temporary work dir: %s", work_dir) + else: + work_dir.mkdir(parents=True, exist_ok=True) + + # --- STEP 1: discover test cases --- + discovered_cases = discover_test_image_match_cases(test_data_dir, test_cases) + if not discovered_cases: + raise RuntimeError(f"No test cases found in {test_data_dir} (test_cases={test_cases})") + + # --- STEP 2: GCP spatial pairing --- + l1a_images = [] + 
l1a_to_testcase: dict[str, dict] = {} + for tc in discovered_cases: + img = load_image_grid_from_mat( + tc["subimage_file"], + key="subimage", + as_named=True, + name=str(tc["subimage_file"].relative_to(test_data_dir)), + ) + l1a_images.append(img) + l1a_to_testcase[img.name] = tc + + gcp_files_seen: set = set() + gcp_images = [] + for tc in discovered_cases: + if tc["gcp_file"] not in gcp_files_seen: + gcp_img = load_image_grid_from_mat( + tc["gcp_file"], + key="GCP", + as_named=True, + name=str(tc["gcp_file"].relative_to(test_data_dir)), + ) + gcp_images.append(gcp_img) + gcp_files_seen.add(tc["gcp_file"]) + + pairing_result = find_l1a_gcp_pairs(l1a_images, gcp_images, max_distance_m=0.0) + paired_test_cases = [l1a_to_testcase[pairing_result.l1a_images[m.l1a_index].name] for m in pairing_result.matches] + n_gcp_pairs = len(paired_test_cases) + logger.info("Pairing complete: %d valid pairs", n_gcp_pairs) + + # --- STEP 3: build config --- + base_config = create_clarreo_correction_config(data_dir, generic_dir) + config = correction.CorrectionConfig( + seed=42, + n_iterations=n_iterations, + parameters=[ + correction.ParameterConfig( + ptype=correction.ParameterType.CONSTANT_KERNEL, + config_file=data_dir / "cprs_hysics_v01.attitude.ck.json", + data={ + "current_value": [0.0, 0.0, 0.0], + "sigma": 0.0, + "units": "arcseconds", + "transformation_type": "dcm_rotation", + "coordinate_frames": ["HYSICS_SLIT", "CRADLE_ELEVATION"], + }, + ) + ], + geo=base_config.geo, + performance_threshold_m=base_config.performance_threshold_m, + performance_spec_percent=base_config.performance_spec_percent, + netcdf=base_config.netcdf, + calibration_file_names=base_config.calibration_file_names, + spacecraft_position_name=base_config.spacecraft_position_name, + boresight_name=base_config.boresight_name, + transformation_matrix_name=base_config.transformation_matrix_name, + ) + config.data = DataConfig(file_format="csv", time_scale_factor=1e6) + config.validate() + + # --- STEP 4: 
iterate --- + netcdf_data = correction._build_netcdf_structure(config, n_iterations, n_gcp_pairs) + threshold_metric = config.netcdf.get_threshold_metric_name() + image_match_cache: dict[str, xr.Dataset] = {} + + for param_idx in range(n_iterations): + logger.info("Iteration %d/%d", param_idx + 1, n_iterations) + pair_errors = [] + + for pair_idx, tc in enumerate(paired_test_cases): + cache_key = f"{tc['case_id']}_{tc.get('subcase_name', '')}" + cached = image_match_cache.get(cache_key) + + output = run_image_matching_with_applied_errors( + tc, + param_idx, + randomize_errors=True, + error_variation_percent=3.0, + cache_results=True, + cached_result=cached, + ) + if cache_key not in image_match_cache: + image_match_cache[cache_key] = output + + rms_m = np.sqrt((output.attrs["lat_error_km"] * 1000) ** 2 + (output.attrs["lon_error_km"] * 1000) ** 2) + pair_errors.append(rms_m) + + netcdf_data["rms_error_m"][param_idx, pair_idx] = rms_m + netcdf_data["mean_error_m"][param_idx, pair_idx] = rms_m + netcdf_data["max_error_m"][param_idx, pair_idx] = rms_m + netcdf_data["n_measurements"][param_idx, pair_idx] = 1 + netcdf_data["im_lat_error_km"][param_idx, pair_idx] = output.attrs["lat_error_km"] + netcdf_data["im_lon_error_km"][param_idx, pair_idx] = output.attrs["lon_error_km"] + netcdf_data["im_ccv"][param_idx, pair_idx] = output.attrs["correlation_ccv"] + netcdf_data["im_grid_step_m"][param_idx, pair_idx] = output.attrs["final_grid_step_m"] + + valid = np.array([e for e in pair_errors if not np.isnan(e)]) + if len(valid) > 0: + pct = (valid < config.performance_threshold_m).sum() / len(valid) * 100 + netcdf_data[threshold_metric][param_idx] = pct + netcdf_data["mean_rms_all_pairs"][param_idx] = np.mean(valid) + netcdf_data["best_pair_rms"][param_idx] = np.min(valid) + netcdf_data["worst_pair_rms"][param_idx] = np.max(valid) + + # --- STEP 5: error statistics --- + image_matching_results = list(image_match_cache.values()) + 
correction.call_error_stats_module(image_matching_results, correction_config=config) + + # --- STEP 6: save --- + output_file = work_dir / "downstream_results.nc" + correction._save_netcdf_results(netcdf_data, output_file, config) + + logger.info("Downstream pipeline complete. Output: %s", output_file) + summary = {"mode": "downstream", "iterations": n_iterations, "test_pairs": n_gcp_pairs, "status": "complete"} + return [], summary, output_file diff --git a/tests/test_correction/clarreo_config.py b/tests/test_correction/clarreo/clarreo_config.py similarity index 90% rename from tests/test_correction/clarreo_config.py rename to tests/test_correction/clarreo/clarreo_config.py index d3d086de..7cc8d73b 100644 --- a/tests/test_correction/clarreo_config.py +++ b/tests/test_correction/clarreo/clarreo_config.py @@ -55,22 +55,15 @@ def create_clarreo_correction_config(data_dir, generic_dir, config_output_path=N CorrectionConfig: THE single config object containing all CLARREO settings. IMPORTANT - Config-Centric Design: - This function creates the base CorrectionConfig (THE one config) with all - CLARREO-specific parameters and settings. After creation, you MUST add - data loaders and processing functions before using: + This function creates the base CorrectionConfig with all CLARREO-specific + parameters and settings. 
- config = create_clarreo_correction_config(data_dir, generic_dir) - - # Add required loaders - config.telemetry_loader = load_clarreo_telemetry - config.science_loader = load_clarreo_science + After creation, attach a :class:`~curryer.correction.config.DataConfig` + and (optionally) an ``_image_matching_override`` for test injection:: - # Add optional processing functions (test or production) - config.gcp_pairing_func = your_pairing_function - config.image_matching_func = your_matching_function - - # Validate and use - config.validate(check_loaders=True) + config = create_clarreo_correction_config(data_dir, generic_dir) + config.data = DataConfig(file_format="csv", time_scale_factor=1e6) + config._image_matching_override = your_matching_func # optional (tests only) results = correction.loop(config, work_dir, data_sets) Args: @@ -84,13 +77,12 @@ def create_clarreo_correction_config(data_dir, generic_dir, config_output_path=N Example: config = create_clarreo_correction_config(data_dir, generic_dir) - # Add required loaders - config.telemetry_loader = load_clarreo_telemetry - config.science_loader = load_clarreo_science + # Attach data loading config + from curryer.correction.config import DataConfig + config.data = DataConfig(file_format="csv", time_scale_factor=1e6) - # Add optional processing functions - config.gcp_pairing_func = synthetic_gcp_pairing - config.image_matching_func = synthetic_image_matching + # Optionally override image matching (tests only) + config._image_matching_override = synthetic_image_matching # Now ready to use results = correction.loop(config, work_dir, data_sets) @@ -222,8 +214,6 @@ def create_clarreo_correction_config(data_dir, generic_dir, config_output_path=N # Performance metrics (CLARREO requirements) performance_threshold_m=250.0, # CLARREO accuracy requirement (meters) performance_spec_percent=39.0, # CLARREO requirement: 39% of measurements under threshold - # Geodetic parameters (WGS84) - earth_radius_m=6378140.0, # WGS84 Earth 
radius in meters # NetCDF output configuration netcdf=netcdf_config, # Calibration file names (CLARREO/HySICS specific) @@ -242,9 +232,9 @@ def create_clarreo_correction_config(data_dir, generic_dir, config_output_path=N param_name = param.config_file.name if param.config_file else "time_correction" logger.info(f" {i + 1}. {param_name} ({param.ptype.name})") - # Validate configuration (without checking loaders - those are added later) - config.validate(check_loaders=False) - logger.info("✓ Configuration validation passed (loaders must be added before use)") + # Validate configuration + config.validate() + logger.info("✓ Configuration validation passed") # Save configuration to file if path is provided if config_output_path is not None: @@ -273,7 +263,6 @@ def create_clarreo_correction_config(data_dir, generic_dir, config_output_path=N "n_iterations": config.n_iterations, "performance_threshold_m": config.performance_threshold_m, "performance_spec_percent": config.performance_spec_percent, - "earth_radius_m": config.earth_radius_m, "parameters": [], }, "geolocation": { diff --git a/tests/test_correction/clarreo/clarreo_data_loaders.py b/tests/test_correction/clarreo/clarreo_data_loaders.py new file mode 100644 index 00000000..0afa7fe5 --- /dev/null +++ b/tests/test_correction/clarreo/clarreo_data_loaders.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +"""CLARREO-specific data PREPROCESSING scripts. + +These functions join multiple raw CSV files, flip azimuth signs, and convert +DCM to quaternion. They are NOT part of the correction pipeline — run them +BEFORE using the correction module to produce standard single-file inputs. + +See examples/correction/ for the recommended workflow. + +.. deprecated:: + The Protocol-based loader pattern (``TelemetryLoader``, ``ScienceLoader``, + ``GCPLoader``) has been removed. Data loading is now config-driven via + :class:`~curryer.correction.config.DataConfig`. 
This file is kept as a shim so existing ``from clarreo_data_loaders import …``
    imports in tests continue to resolve — but note the call signatures differ
    from the removed loaders: these functions now take a data directory rather
    than a ``(key, config)`` pair.
tlm_st_rot_q[:, 2] + st_ck_df["hps.dcm_base_iss_k"] = tlm_st_rot_q[:, 3] + + # Outer-join all four sources and sort by ERT + left_df = sc_spk_df + for right_df in [sc_ck_df, st_ck_df, azel_ck_df]: + left_df = pd.merge(left_df, right_df, on="ert", how="outer") + left_df = left_df.sort_values("ert") + + # Compute combined second + sub-second timetags + for col in list(left_df): + if col in ("hps.bad_ps_tms", "hps.corrected_tms", "hps.resolver_tms", "hps.st_quat_coi_tms"): + sub_col = col + "s" + if sub_col not in left_df.columns: + raise ValueError(f"Missing sub-second column for {col}") + if col == "hps.bad_ps_tms": + left_df[col + "_tmss"] = left_df[col] + left_df[sub_col] / 256 + else: + left_df[col + "_tmss"] = left_df[col] + left_df[sub_col] / 2**32 + + logger.info(f"Final CLARREO telemetry shape: {left_df.shape}") + return left_df + + +def load_clarreo_science(data_dir: Path | str) -> pd.DataFrame: + """Load CLARREO science frame timing CSV. + + Returns ``corrected_timestamp`` in GPS seconds (the raw instrument unit). + Set ``DataConfig.time_scale_factor = 1e6`` so the pipeline converts to uGPS. + + Parameters + ---------- + data_dir : Path + Directory containing the science timing CSV. + """ + data_dir = Path(data_dir) + logger.info(f"Loading CLARREO science from: {data_dir}") + sci_df = pd.read_csv(data_dir / "openloop_tlm_5a_sci_times_20250521T225242.csv", index_col=0) + logger.info(f"Science shape: {sci_df.shape}") + return sci_df + + +def load_clarreo_gcp(gcp_key: str, config=None) -> None: # noqa: ANN001 + """No-op placeholder – GCPLoader protocol has been removed. + + Pass GCP file paths directly as the third element of each + ``tlm_sci_gcp_sets`` tuple instead. 
+ """ + logger.info("load_clarreo_gcp is a no-op placeholder (GCPLoader protocol removed).") + return None + + +__all__ = ["load_clarreo_telemetry", "load_clarreo_science", "load_clarreo_gcp"] diff --git a/tests/test_correction/clarreo/conftest.py b/tests/test_correction/clarreo/conftest.py new file mode 100644 index 00000000..db32fd7f --- /dev/null +++ b/tests/test_correction/clarreo/conftest.py @@ -0,0 +1,32 @@ +"""Pytest configuration for CLARREO integration tests. + +Exposes session-scoped path fixtures for the CLARREO test data directories. +The ``clarreo/`` directory is already on ``sys.path`` via the parent +``test_correction/conftest.py``, so test files here can import +``clarreo_config`` and ``clarreo_data_loaders`` without any additional +``sys.path`` manipulation. +""" + +from pathlib import Path + +import pytest + + +@pytest.fixture(scope="session") +def clarreo_root(): + return Path(__file__).parents[3] + + +@pytest.fixture(scope="session") +def clarreo_gcs_data_dir(clarreo_root): + return clarreo_root / "tests" / "data" / "clarreo" / "gcs" + + +@pytest.fixture(scope="session") +def clarreo_image_match_data_dir(clarreo_root): + return clarreo_root / "tests" / "data" / "clarreo" / "image_match" + + +@pytest.fixture(scope="session") +def clarreo_generic_dir(clarreo_root): + return clarreo_root / "data" / "generic" diff --git a/tests/test_correction/clarreo/test_clarreo_config.py b/tests/test_correction/clarreo/test_clarreo_config.py new file mode 100644 index 00000000..c2ee101c --- /dev/null +++ b/tests/test_correction/clarreo/test_clarreo_config.py @@ -0,0 +1,91 @@ +"""Tests for CLARREO correction configuration generation and JSON serialisation.""" + +from __future__ import annotations + +import json +import logging + +import pytest +from clarreo_config import create_clarreo_correction_config + +from curryer.correction import correction +from curryer.correction.config import DataConfig + +logger = logging.getLogger(__name__) + + +def 
test_generate_clarreo_config_json(tmp_path, clarreo_gcs_data_dir, clarreo_generic_dir): + """Generate the CLARREO config JSON and validate structure end-to-end.""" + output_path = tmp_path / "configs" / "clarreo_correction_config.json" + + config = create_clarreo_correction_config(clarreo_gcs_data_dir, clarreo_generic_dir, config_output_path=output_path) + + assert output_path.exists() + + with open(output_path) as fh: + config_data = json.load(fh) + + assert "mission_config" in config_data + assert "correction" in config_data + assert "geolocation" in config_data + assert config_data["mission_config"]["mission_name"] == "CLARREO_Pathfinder" + + corr = config_data["correction"] + assert isinstance(corr.get("parameters"), list) + assert len(corr["parameters"]) > 0 + assert corr["performance_threshold_m"] == 250.0 + assert corr["performance_spec_percent"] == 39.0 + assert config_data["geolocation"]["instrument_name"] == "CPRS_HYSICS" + + reloaded = correction.load_config_from_json(output_path) + assert reloaded.n_iterations == config.n_iterations + assert len(reloaded.parameters) == len(config.parameters) + reloaded.validate() + + # Verify each reloaded parameter has the expected ptype, config_file (for + # kernel-based params), and data.field (for OFFSET_KERNEL / OFFSET_TIME), + # so this test will fail if the JSON can't be used to run correction.loop(). 
+ valid_ptypes = { + correction.ParameterType.CONSTANT_KERNEL, + correction.ParameterType.OFFSET_KERNEL, + correction.ParameterType.OFFSET_TIME, + } + for param in reloaded.parameters: + assert param.ptype in valid_ptypes, f"Unexpected ptype: {param.ptype}" + + if param.ptype in (correction.ParameterType.CONSTANT_KERNEL, correction.ParameterType.OFFSET_KERNEL): + assert param.config_file is not None, f"{param.ptype.name} parameter must have a config_file, got None" + + if param.ptype in (correction.ParameterType.OFFSET_KERNEL, correction.ParameterType.OFFSET_TIME): + assert param.data.field is not None, f"{param.ptype.name} parameter must have data.field set, got None" + + # Count parameters by type to ensure the expected composition is preserved. + ptypes = [p.ptype for p in reloaded.parameters] + assert ptypes.count(correction.ParameterType.CONSTANT_KERNEL) >= 1 + assert ptypes.count(correction.ParameterType.OFFSET_KERNEL) >= 1 + assert ptypes.count(correction.ParameterType.OFFSET_TIME) >= 1 + + +class TestClarreoConfiguration: + """Smoke tests for the CLARREO CorrectionConfig object.""" + + @pytest.fixture(autouse=True) + def _setup(self, clarreo_gcs_data_dir, clarreo_generic_dir): + self.data_dir = clarreo_gcs_data_dir + self.generic_dir = clarreo_generic_dir + + def test_config_validates(self): + config = create_clarreo_correction_config(self.data_dir, self.generic_dir) + config.data = DataConfig(file_format="csv", time_scale_factor=1e6) + config.validate() + assert config.geo.instrument_name == "CPRS_HYSICS" + assert config.seed == 42 + + def test_parameter_count(self): + config = create_clarreo_correction_config(self.data_dir, self.generic_dir) + assert len(config.parameters) == 6 + + def test_performance_thresholds(self): + config = create_clarreo_correction_config(self.data_dir, self.generic_dir) + assert config.performance_threshold_m == 250.0 + assert config.performance_spec_percent == 39.0 diff --git 
a/tests/test_correction/clarreo/test_clarreo_dataio.py b/tests/test_correction/clarreo/test_clarreo_dataio.py new file mode 100644 index 00000000..e6ab8fdc --- /dev/null +++ b/tests/test_correction/clarreo/test_clarreo_dataio.py @@ -0,0 +1,38 @@ +"""CLARREO-specific data I/O integration tests (require AWS credentials).""" + +from __future__ import annotations + +import datetime as dt +import logging +import os + +import pytest + +from curryer.correction.dataio import S3Configuration, find_netcdf_objects + +logger = logging.getLogger(__name__) + +_NEEDS_AWS = pytest.mark.skipif( + not ( + (os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY") and os.getenv("AWS_SESSION_TOKEN")) + or os.getenv("C9_USER") + ), + reason="Requires AWS credentials or Cloud9 environment.", +) + + +@_NEEDS_AWS +def test_clarreo_find_l0_objects(tmp_path): + """Find L0 telemetry objects in CSDS S3 bucket.""" + config = S3Configuration("clarreo", "L0/telemetry/hps_navigation/") + keys = find_netcdf_objects(config, start_date=dt.date(2017, 1, 15), end_date=dt.date(2017, 1, 15)) + assert keys == ["L0/telemetry/hps_navigation/20170115/CPF_TLM_L0.V00-000.hps_navigation-20170115-0.0.0.nc"] + + +@_NEEDS_AWS +def test_clarreo_find_l1a_objects(tmp_path): + """Find L1a science objects in CSDS S3 bucket.""" + config = S3Configuration("clarreo", "L1a/nadir/") + keys = find_netcdf_objects(config, start_date=dt.date(2022, 6, 3), end_date=dt.date(2022, 6, 3)) + assert len(keys) == 34 + assert "L1a/nadir/20220603/nadir-20220603T235952-step22-geolocation_creation-0.0.0.nc" in keys diff --git a/tests/test_correction/clarreo/test_clarreo_e2e.py b/tests/test_correction/clarreo/test_clarreo_e2e.py new file mode 100644 index 00000000..9d26b6f2 --- /dev/null +++ b/tests/test_correction/clarreo/test_clarreo_e2e.py @@ -0,0 +1,129 @@ +"""CLARREO end-to-end integration tests. 
+ +Exercises the upstream (kernel creation + geolocation) and downstream +(GCP pairing + image matching + error statistics) pipelines using +CLARREO test data. + +Extra tests (require GMTED data or SPICE binaries): ``pytest --run-extra`` +""" + +from __future__ import annotations + +import logging + +import numpy as np +import pytest +import xarray as xr +from _image_match_helpers import ( + apply_error_variation_for_testing, + discover_test_image_match_cases, + run_image_matching_with_applied_errors, +) +from _pipeline_helpers import run_downstream_pipeline, run_upstream_pipeline +from _synthetic_helpers import ( + _generate_nadir_aligned_transforms, + _generate_spherical_positions, + _generate_synthetic_boresights, + synthetic_gcp_pairing, +) +from clarreo_config import create_clarreo_correction_config + +from curryer.correction.config import DataConfig + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def work_dir(tmp_path): + d = tmp_path / "work" + d.mkdir() + return d + + +def test_upstream_configuration(clarreo_gcs_data_dir, clarreo_generic_dir): + """Upstream configuration loads and validates correctly.""" + config = create_clarreo_correction_config(clarreo_gcs_data_dir, clarreo_generic_dir) + config.data = DataConfig(file_format="csv", time_scale_factor=1e6) + config.validate() + assert config.geo.instrument_name == "CPRS_HYSICS" + assert len(config.parameters) > 0 + assert config.seed == 42 + assert config.data is not None + + +def test_downstream_test_case_discovery(clarreo_image_match_data_dir): + """Downstream test cases can be discovered from the data directory.""" + test_cases = discover_test_image_match_cases(clarreo_image_match_data_dir) + assert len(test_cases) > 0 + for tc in test_cases: + assert "case_id" in tc + assert "subimage_file" in tc + assert "gcp_file" in tc + assert tc["subimage_file"].exists() + assert tc["gcp_file"].exists() + + +def test_downstream_image_matching(clarreo_image_match_data_dir): + """Downstream image 
matching runs successfully on test case 1.""" + test_cases = discover_test_image_match_cases(clarreo_image_match_data_dir, test_cases=["1"]) + assert len(test_cases) > 0 + result = run_image_matching_with_applied_errors( + test_cases[0], param_idx=0, randomize_errors=False, cache_results=True + ) + assert isinstance(result, xr.Dataset) + assert "lat_error_km" in result.attrs + assert "lon_error_km" in result.attrs + + +@pytest.mark.extra +def test_upstream_quick(clarreo_gcs_data_dir, clarreo_generic_dir, work_dir): + """Quick upstream pipeline (2 iterations). Requires GMTED – ``--run-extra``.""" + results_list, results_dict, output_file = run_upstream_pipeline(n_iterations=2, work_dir=work_dir) + assert results_dict["status"] == "complete" + assert results_dict["iterations"] == 2 + assert results_dict["mode"] == "upstream" + assert results_dict["parameter_sets"] > 0 + assert output_file.exists() + + +def test_downstream_quick(work_dir, clarreo_image_match_data_dir): + """Quick downstream pipeline (2 iterations, test case 1).""" + results_list, results_dict, output_file = run_downstream_pipeline( + n_iterations=2, test_cases=["1"], work_dir=work_dir + ) + assert results_dict["status"] == "complete" + assert results_dict["iterations"] == 2 + assert output_file.exists() + + +def test_downstream_helpers_basic(clarreo_image_match_data_dir): + """Downstream helper functions work correctly.""" + test_cases = discover_test_image_match_cases(clarreo_image_match_data_dir, test_cases=["1"]) + assert len(test_cases) > 0 + assert "case_id" in test_cases[0] + + base = xr.Dataset( + {"lat_error_deg": (["m"], [0.001]), "lon_error_deg": (["m"], [0.002])}, + attrs={"lat_error_km": 0.1, "lon_error_km": 0.2, "correlation_ccv": 0.95}, + ) + varied = apply_error_variation_for_testing(base, param_idx=1, error_variation_percent=3.0) + assert isinstance(varied, xr.Dataset) + assert varied.attrs["lat_error_km"] != base.attrs["lat_error_km"] + + +def test_synthetic_helpers_basic(): + 
"""Generic synthetic helper functions produce correctly-shaped outputs.""" + pairs = synthetic_gcp_pairing(["science_1.nc", "science_2.nc"]) + assert len(pairs) == 2 + + boresights = _generate_synthetic_boresights(5, max_off_nadir_rad=0.07) + assert boresights.shape == (5, 3) + assert np.all(np.abs(boresights[:, 0]) < 0.01) + + positions = _generate_spherical_positions(5, 6.78e6, 4e3) + assert positions.shape == (5, 3) + assert np.all(np.linalg.norm(positions, axis=1) > 6.7e6) + + transforms = _generate_nadir_aligned_transforms(5, positions, boresights) + assert transforms.shape == (5, 3, 3) + assert abs(abs(np.linalg.det(transforms[0])) - 1.0) < 0.2 diff --git a/tests/test_correction/clarreo_data_loaders.py b/tests/test_correction/clarreo_data_loaders.py deleted file mode 100644 index 93a64a4e..00000000 --- a/tests/test_correction/clarreo_data_loaders.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python3 -""" -CLARREO-specific data loading functions for Correction testing. - -This module contains all CLARREO/HySICS-specific data loading logic that was -previously in correction.py. These functions handle the specific file formats, -naming conventions, and data transformations needed for CLARREO Pathfinder data. - -Other missions should create similar modules with their own data loading logic. - -Usage: - from clarreo_data_loaders import load_clarreo_telemetry, load_clarreo_science - - tlm_data = load_clarreo_telemetry(tlm_key, config) - sci_data = load_clarreo_science(sci_key, config) - -Date: October 28, 2025 -""" - -import logging -from pathlib import Path - -import numpy as np -import pandas as pd - -from curryer import spicierpy as sp - -logger = logging.getLogger(__name__) - - -def load_clarreo_telemetry(tlm_key: str, config) -> pd.DataFrame: - """ - Load CLARREO telemetry data from multiple CSV files. 
- - CLARREO-specific implementation that: - - Loads 4 separate CSV files (SC_SPK, SC_CK, ST_CK, AZEL_CK) - - Reverses azimuth direction - - Converts star-tracker DCM to quaternions - - Merges all telemetry sources - - Computes combined time tags - - Args: - tlm_key: Path to telemetry file or identifier (used to construct paths) - config: Correction configuration - - Returns: - DataFrame with merged telemetry data - """ - # Extract the base path from config or construct from tlm_key - # For test cases, tlm_key is often just a string identifier like 'telemetry_5a' - # The actual data location comes from config.geo.meta_kernel_file or we use default - if hasattr(config.geo, "meta_kernel_file") and config.geo.meta_kernel_file: - # Get directory from meta kernel file path - base_path = config.geo.meta_kernel_file.parent - elif isinstance(tlm_key, Path) and tlm_key.is_dir(): - base_path = tlm_key - elif isinstance(tlm_key, Path) and tlm_key.parent.exists(): - base_path = tlm_key.parent - else: - # Fallback: construct absolute path to test data - script_dir = Path(__file__).parent.parent.parent - base_path = script_dir / "tests" / "data" / "clarreo" / "gcs" - - logger.info(f"Loading CLARREO telemetry data from: {base_path}") - - # Verify the directory exists and has data - if not base_path.exists(): - raise FileNotFoundError(f"Telemetry data directory not found: {base_path}") - - # Load the 4 telemetry CSVs (CLARREO-specific files) - sc_spk_df = pd.read_csv(base_path / "openloop_tlm_5a_sc_spk_20250521T225242.csv", index_col=0) - sc_ck_df = pd.read_csv(base_path / "openloop_tlm_5a_sc_ck_20250521T225242.csv", index_col=0) - st_ck_df = pd.read_csv(base_path / "openloop_tlm_5a_st_ck_20250521T225242.csv", index_col=0) - azel_ck_df = pd.read_csv(base_path / "openloop_tlm_5a_azel_ck_20250521T225242.csv", index_col=0) - - logger.info( - f"Loaded telemetry CSVs - SC_SPK: {sc_spk_df.shape}, SC_CK: {sc_ck_df.shape}, " - f"ST_CK: {st_ck_df.shape}, AZEL_CK: {azel_ck_df.shape}" - ) - - 
# CLARREO-specific: Reverse the direction of the Azimuth element - azel_ck_df["hps.az_ang_nonlin"] = azel_ck_df["hps.az_ang_nonlin"] * -1 - - # CLARREO-specific: Convert star-tracker from rotation matrix to quaternion - tlm_st_rot = np.vstack( - [ - st_ck_df["hps.dcm_base_iss_1_1"].values, - st_ck_df["hps.dcm_base_iss_1_2"].values, - st_ck_df["hps.dcm_base_iss_1_3"].values, - st_ck_df["hps.dcm_base_iss_2_1"].values, - st_ck_df["hps.dcm_base_iss_2_2"].values, - st_ck_df["hps.dcm_base_iss_2_3"].values, - st_ck_df["hps.dcm_base_iss_3_1"].values, - st_ck_df["hps.dcm_base_iss_3_2"].values, - st_ck_df["hps.dcm_base_iss_3_3"].values, - ] - ).T - tlm_st_rot = np.reshape(tlm_st_rot, (-1, 3, 3)).copy() - - tlm_st_rot_q = np.vstack([sp.m2q(tlm_st_rot[i, :, :]) for i in range(tlm_st_rot.shape[0])]) - st_ck_df["hps.dcm_base_iss_s"] = tlm_st_rot_q[:, 0] - st_ck_df["hps.dcm_base_iss_i"] = tlm_st_rot_q[:, 1] - st_ck_df["hps.dcm_base_iss_j"] = tlm_st_rot_q[:, 2] - st_ck_df["hps.dcm_base_iss_k"] = tlm_st_rot_q[:, 3] - - # CLARREO-specific: Merge all telemetry sources with outer joins - left_df = sc_spk_df - for right_df in [sc_ck_df, st_ck_df, azel_ck_df]: - left_df = pd.merge(left_df, right_df, on="ert", how="outer") - left_df = left_df.sort_values("ert") - - # CLARREO-specific: Compute combined second and subsecond timetags - for col in list(left_df): - if col in ("hps.bad_ps_tms", "hps.corrected_tms", "hps.resolver_tms", "hps.st_quat_coi_tms"): - assert col + "s" in left_df.columns, f"Missing subsecond column for {col}" - - if col == "hps.bad_ps_tms": - left_df[col + "_tmss"] = left_df[col] + left_df[col + "s"] / 256 - elif col in ("hps.corrected_tms", "hps.resolver_tms", "hps.st_quat_coi_tms"): - left_df[col + "_tmss"] = left_df[col] + left_df[col + "s"] / 2**32 - else: - raise ValueError(f"Missing conversion for expected column: {col}") - - logger.info(f"Final CLARREO telemetry shape: {left_df.shape}") - - # Validate output format - from curryer.correction.dataio import 
validate_telemetry_output - - validate_telemetry_output(left_df, config) - - return left_df - - -def load_clarreo_science(sci_key: str, config) -> pd.DataFrame: - """ - Load CLARREO science frame timing data. - - CLARREO-specific implementation that: - - Loads science frame timestamps from CSV - - Converts GPS seconds to uGPS (microseconds) - - Args: - sci_key: Path to science file or identifier - config: Correction configuration - - Returns: - DataFrame with science frame timestamps - """ - # Extract the base path from config or construct from sci_key - if hasattr(config.geo, "meta_kernel_file") and config.geo.meta_kernel_file: - base_path = config.geo.meta_kernel_file.parent - elif isinstance(sci_key, Path) and sci_key.is_dir(): - base_path = sci_key - elif isinstance(sci_key, Path) and sci_key.parent.exists(): - base_path = sci_key.parent - else: - # Fallback: construct absolute path to test data - script_dir = Path(__file__).parent.parent.parent - base_path = script_dir / "tests" / "data" / "clarreo" / "gcs" - - logger.info(f"Loading CLARREO science data from: {base_path}") - - # CLARREO-specific: Load science frame timing CSV - sci_time_df = pd.read_csv(base_path / "openloop_tlm_5a_sci_times_20250521T225242.csv", index_col=0) - - # CLARREO-specific: Frame times are GPS seconds, geolocation expects uGPS (microseconds) - sci_time_df["corrected_timestamp"] *= 1e6 - - logger.info(f"CLARREO science data shape: {sci_time_df.shape}") - logger.info( - f"corrected_timestamp range: {sci_time_df['corrected_timestamp'].min():.2e} to " - f"{sci_time_df['corrected_timestamp'].max():.2e} uGPS" - ) - - # Validate output format - from curryer.correction.dataio import validate_science_output - - validate_science_output(sci_time_df, config) - - return sci_time_df - - -def load_clarreo_gcp(gcp_key: str, config): - """ - Load CLARREO Ground Control Point (GCP) reference data. 
- - PLACEHOLDER - Real implementation will: - - Load Landsat GCP reference images/coordinates - - Extract georeferenced control points - - Return spatially/temporally matched reference data - - Args: - gcp_key: Path to GCP file or identifier - config: Correction configuration - - Returns: - GCP reference data (format TBD based on GCP pairing module requirements) - """ - logger.info(f"Loading CLARREO GCP data from: {gcp_key} (PLACEHOLDER)") - - # For testing purposes, return None - the GCP pairing module will handle this - # In real implementation, this would load and process GCP reference data - return diff --git a/tests/test_correction/conftest.py b/tests/test_correction/conftest.py index 9405493e..76163452 100644 --- a/tests/test_correction/conftest.py +++ b/tests/test_correction/conftest.py @@ -1,111 +1,41 @@ -""" -Pytest configuration and shared fixtures for correction module tests. +"""Generic pytest configuration for test_correction. -This module provides shared fixtures used across multiple test files to ensure -consistent configuration and data loading. +- Adds ``clarreo/`` sub-directory to ``sys.path`` so test files in this package + can import ``clarreo_config``, ``clarreo_data_loaders``, etc. without + repeating ``sys.path`` manipulation in every file. +- Adds this directory itself to ``sys.path`` so ``_synthetic_helpers`` is + importable by the ``clarreo/`` sub-package helpers. """ +import sys from pathlib import Path import pytest -from clarreo_config import create_clarreo_correction_config - -from curryer.correction import correction - - -@pytest.fixture(scope="session") -def clarreo_config_from_json(): - """Load CLARREO config from JSON (session-scoped for efficiency). - - This fixture loads the canonical CLARREO configuration from JSON file. - The JSON file should be generated by test_generate_clarreo_config_json() - and committed to version control as the single source of truth. 
- - Session scope ensures the config is loaded once and reused across all tests. - """ - config_path = Path(__file__).parent / "configs/clarreo_correction_config.json" - - if not config_path.exists(): - pytest.skip(f"Config file not found: {config_path}. Run test_generate_clarreo_config_json() first.") - - return correction.load_config_from_json(config_path) - - -@pytest.fixture(scope="session") -def clarreo_config_programmatic(): - """Generate CLARREO config programmatically (for comparison tests). - - This fixture creates the configuration programmatically using the - create_clarreo_correction_config() function. Useful for testing - that programmatic and JSON configs produce equivalent results. - - Session scope for efficiency across multiple tests. - """ - data_dir = Path(__file__).parent.parent / "data/clarreo/gcs" - generic_dir = Path("data/generic") - return create_clarreo_correction_config(data_dir, generic_dir) +_here = str(Path(__file__).parent) +_clarreo_dir = str(Path(__file__).parent / "clarreo") -@pytest.fixture(scope="session") -def clarreo_data_dir(): - """Path to CLARREO test data directory. - - Returns: - Path: tests/data/clarreo/ - """ - return Path(__file__).parent.parent / "data/clarreo" - - -@pytest.fixture(scope="session") -def clarreo_gcs_data_dir(clarreo_data_dir): - """Path to CLARREO GCS test data directory. - - Returns: - Path: tests/data/clarreo/gcs/ - """ - return clarreo_data_dir / "gcs" - - -@pytest.fixture(scope="session") -def clarreo_image_match_data_dir(clarreo_data_dir): - """Path to CLARREO image matching test data directory. - - Returns: - Path: tests/data/clarreo/image_match/ - """ - return clarreo_data_dir / "image_match" +for _p in (_here, _clarreo_dir): + if _p not in sys.path: + sys.path.insert(0, _p) @pytest.fixture(scope="session") -def generic_kernel_dir(): - """Path to generic SPICE kernels directory. 
- - Returns: - Path: data/generic/ - """ - return Path("data/generic") +def root_dir(): + """Repository root directory (two levels above ``tests/test_correction/``).""" + return Path(__file__).parents[2] @pytest.fixture def temp_work_dir(tmp_path): - """Create temporary working directory for test outputs. - - This fixture provides a clean temporary directory for each test - that needs to write files (NetCDF outputs, intermediate results, etc.) - - Function scope ensures each test gets a fresh directory. - - Returns: - Path: Temporary directory path - """ - work_dir = tmp_path / "correction_work" - work_dir.mkdir(parents=True, exist_ok=True) - return work_dir + """Clean temporary working directory for each test.""" + work = tmp_path / "correction_work" + work.mkdir(parents=True, exist_ok=True) + return work -# Configuration for pytest def pytest_configure(config): - """Configure pytest with custom markers.""" - config.addinivalue_line("markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')") + """Register custom markers.""" + config.addinivalue_line("markers", "slow: marks tests as slow") config.addinivalue_line("markers", "integration: marks tests as integration tests") - config.addinivalue_line("markers", "requires_gcs: marks tests that require GCS credentials") + config.addinivalue_line("markers", "requires_gcs: marks tests requiring GCS credentials") diff --git a/tests/test_correction/test_config.py b/tests/test_correction/test_config.py new file mode 100644 index 00000000..14c9bcc6 --- /dev/null +++ b/tests/test_correction/test_config.py @@ -0,0 +1,791 @@ +"""Tests for the Pydantic-based correction configuration models. 
+ +Covers: +- Construction and field validation for all BaseModel subclasses +- ``DataConfig`` config-driven data loading configuration +- ``ParameterData`` backward-compatible dict-style access +- JSON round-trip: ``config == CorrectionConfig.model_validate_json(config.model_dump_json())`` +- ``ValidationError`` raised with field-level messages for invalid inputs +- Callable / loader fields excluded from JSON serialisation +""" + +from pathlib import Path + +import pytest +from pydantic import ValidationError + +from curryer.correction.config import ( + CorrectionConfig, + DataConfig, + GeolocationConfig, + NetCDFConfig, + NetCDFParameterMetadata, + ParameterConfig, + ParameterData, + ParameterType, + load_config_from_json, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def geo() -> GeolocationConfig: + return GeolocationConfig( + meta_kernel_file=Path("tests/data/test.kernels.tm.json"), + generic_kernel_dir=Path("data/generic"), + dynamic_kernels=[Path("tests/data/sc.spk.json"), Path("tests/data/sc.ck.json")], + instrument_name="TEST_INSTRUMENT", + time_field="corrected_timestamp", + ) + + +@pytest.fixture +def param_constant(geo) -> ParameterConfig: + return ParameterConfig( + ptype=ParameterType.CONSTANT_KERNEL, + config_file=Path("tests/data/test_base.attitude.ck.json"), + data={ + "current_value": [0.0, 0.0, 0.0], + "bounds": [-300.0, 300.0], + "sigma": 30.0, + "units": "arcseconds", + "distribution": "normal", + "transformation_type": "dcm_rotation", + "coordinate_frames": ["FRAME_A", "FRAME_B"], + }, + ) + + +@pytest.fixture +def param_offset_kernel() -> ParameterConfig: + return ParameterConfig( + ptype=ParameterType.OFFSET_KERNEL, + config_file=Path("tests/data/test_az.attitude.ck.json"), + data={ + "field": "hps.az_ang_nonlin", + "current_value": 0.0, + "bounds": [-300.0, 300.0], + "sigma": 30.0, + "units": 
"arcseconds", + }, + ) + + +@pytest.fixture +def param_offset_time() -> ParameterConfig: + return ParameterConfig( + ptype=ParameterType.OFFSET_TIME, + config_file=None, + data={ + "field": "corrected_timestamp", + "current_value": 0.0, + "bounds": [-50.0, 50.0], + "sigma": 7.0, + "units": "milliseconds", + }, + ) + + +@pytest.fixture +def netcdf_cfg() -> NetCDFConfig: + return NetCDFConfig( + performance_threshold_m=250.0, + title="Test Geolocation Analysis", + description="Unit-test run", + ) + + +@pytest.fixture +def minimal_config(geo, param_constant) -> CorrectionConfig: + """Minimal CorrectionConfig with no loaders (fully serialisable).""" + return CorrectionConfig( + seed=42, + n_iterations=5, + parameters=[param_constant], + geo=geo, + performance_threshold_m=250.0, + performance_spec_percent=39.0, + ) + + +@pytest.fixture +def full_config(geo, param_constant, param_offset_kernel, param_offset_time, netcdf_cfg) -> CorrectionConfig: + """Full CorrectionConfig with all optional fields populated.""" + return CorrectionConfig( + seed=0, + n_iterations=10, + parameters=[param_constant, param_offset_kernel, param_offset_time], + geo=geo, + performance_threshold_m=250.0, + performance_spec_percent=39.0, + netcdf=netcdf_cfg, + output_filename="test_results.nc", + calibration_dir=Path("tests/data/calibration"), + calibration_file_names={"los_vectors": "b_HS.mat", "optical_psf": "psf.mat"}, + spacecraft_position_name="riss_ctrs", + boresight_name="bhat_hs", + transformation_matrix_name="t_hs2ctrs", + ) + + +# =========================================================================== +# DataConfig – construction and typed fields +# =========================================================================== + + +class TestDataConfig: + def test_defaults(self): + dc = DataConfig() + assert dc.file_format == "csv" + assert dc.time_scale_factor == 1.0 + + def test_custom_values(self): + dc = DataConfig(file_format="netcdf", time_scale_factor=1e6) + assert dc.file_format 
== "netcdf" + assert dc.time_scale_factor == 1e6 + + def test_invalid_file_format(self): + with pytest.raises(ValidationError): + DataConfig(file_format="xml") + + def test_json_round_trip(self): + dc = DataConfig(file_format="hdf5", time_scale_factor=1.0) + restored = DataConfig.model_validate_json(dc.model_dump_json()) + assert restored.file_format == "hdf5" + assert restored.time_scale_factor == 1.0 + + def test_embedded_in_correction_config(self, geo, param_constant): + """DataConfig round-trips through CorrectionConfig serialisation.""" + cfg = CorrectionConfig( + n_iterations=1, + parameters=[param_constant], + geo=geo, + performance_threshold_m=250.0, + performance_spec_percent=39.0, + data=DataConfig(file_format="csv", time_scale_factor=1e6), + ) + json_str = cfg.model_dump_json() + restored = CorrectionConfig.model_validate_json(json_str) + assert restored.data is not None + assert restored.data.file_format == "csv" + assert restored.data.time_scale_factor == 1e6 + + def test_none_data_field_is_valid(self, geo, param_constant): + """CorrectionConfig.data may be None (backward compat).""" + cfg = CorrectionConfig( + n_iterations=1, + parameters=[param_constant], + geo=geo, + performance_threshold_m=250.0, + performance_spec_percent=39.0, + ) + assert cfg.data is None + + +# =========================================================================== +# ParameterData – construction and typed fields +# =========================================================================== + + +class TestParameterData: + def test_construction_with_all_fields(self): + pd = ParameterData( + current_value=[1.0, 2.0, 3.0], + bounds=[-100.0, 100.0], + sigma=10.0, + units="arcseconds", + distribution="normal", + field="some_field", + transformation_type="dcm_rotation", + coordinate_frames=["F1", "F2"], + ) + assert pd.current_value == [1.0, 2.0, 3.0] + assert pd.sigma == 10.0 + assert pd.units == "arcseconds" + + def test_defaults(self): + pd = ParameterData() + assert 
pd.current_value == 0.0 + assert pd.sigma is None + assert pd.units is None + assert pd.distribution == "normal" + + # -- dict-style backward compat ------------------------------------------- + + def test_get_returns_value(self): + pd = ParameterData(sigma=30.0, units="arcseconds") + assert pd.get("sigma") == 30.0 + assert pd.get("units") == "arcseconds" + + def test_get_returns_default_for_none_field(self): + pd = ParameterData() # sigma=None by default + assert pd.get("sigma", "N/A") == "N/A" + assert pd.get("sigma") is None + + def test_get_nonexistent_key_returns_default(self): + pd = ParameterData() + assert pd.get("no_such_key", "MISSING") == "MISSING" + + def test_contains_true_for_non_none_field(self): + pd = ParameterData(sigma=30.0) + assert "sigma" in pd + + def test_contains_false_for_none_field(self): + pd = ParameterData() # sigma=None + assert "sigma" not in pd + + def test_contains_true_for_zero_sigma(self): + """sigma=0.0 is explicitly set and must be found by 'in'.""" + pd = ParameterData(sigma=0.0) + assert "sigma" in pd + + def test_getitem_returns_value(self): + pd = ParameterData(sigma=5.0) + assert pd["sigma"] == 5.0 + + def test_getitem_raises_keyerror_for_missing_key(self): + pd = ParameterData() + with pytest.raises(KeyError): + _ = pd["totally_missing"] + + def test_extra_fields_allowed_and_accessible(self): + pd = ParameterData(my_custom_field="hello") + assert pd.get("my_custom_field") == "hello" + assert "my_custom_field" in pd + assert pd["my_custom_field"] == "hello" + + def test_validation_error_for_non_numeric_sigma(self): + with pytest.raises(ValidationError) as exc_info: + ParameterData(sigma="not-a-float") + assert "sigma" in str(exc_info.value) + + +# =========================================================================== +# ParameterConfig +# =========================================================================== + + +class TestParameterConfig: + def test_dict_coercion(self, param_constant): + """Passing data as a 
plain dict must produce a ParameterData instance.""" + assert isinstance(param_constant.data, ParameterData) + assert param_constant.data.sigma == 30.0 + + def test_none_data_becomes_empty_parameter_data(self): + """data=None (old API) must be accepted and become a default ParameterData.""" + pc = ParameterConfig( + ptype=ParameterType.CONSTANT_KERNEL, + config_file=Path("kernel.json"), + data=None, + ) + assert isinstance(pc.data, ParameterData) + + def test_no_config_file(self): + pc = ParameterConfig(ptype=ParameterType.OFFSET_TIME, config_file=None) + assert pc.config_file is None + + def test_invalid_ptype_raises_validation_error(self): + with pytest.raises(ValidationError) as exc_info: + ParameterConfig(ptype="INVALID_TYPE", config_file=None) + assert "ptype" in str(exc_info.value) + + def test_all_parameter_types_accepted(self): + for ptype in ParameterType: + pc = ParameterConfig(ptype=ptype) + assert pc.ptype == ptype + + +# =========================================================================== +# GeolocationConfig +# =========================================================================== + + +class TestGeolocationConfig: + def test_basic_construction(self, geo): + assert geo.instrument_name == "TEST_INSTRUMENT" + assert geo.time_field == "corrected_timestamp" + assert len(geo.dynamic_kernels) == 2 + + def test_dynamic_kernels_default_empty(self): + g = GeolocationConfig( + meta_kernel_file=Path("x.json"), + generic_kernel_dir=Path("data"), + instrument_name="INST", + time_field="ts", + ) + assert g.dynamic_kernels == [] + + def test_path_fields_coerce_strings(self): + g = GeolocationConfig( + meta_kernel_file="path/to/mk.json", + generic_kernel_dir="data/generic", + instrument_name="I", + time_field="t", + ) + assert isinstance(g.meta_kernel_file, Path) + assert isinstance(g.generic_kernel_dir, Path) + + def test_missing_required_field_raises(self): + with pytest.raises(ValidationError) as exc_info: + GeolocationConfig( + 
meta_kernel_file=Path("x.json"), + generic_kernel_dir=Path("data"), + # instrument_name missing + time_field="ts", + ) + assert "instrument_name" in str(exc_info.value) + + def test_minimum_correlation_optional(self, geo): + assert geo.minimum_correlation is None + geo_with_corr = GeolocationConfig( + meta_kernel_file=Path("x.json"), + generic_kernel_dir=Path("d"), + instrument_name="I", + time_field="t", + minimum_correlation=0.7, + ) + assert geo_with_corr.minimum_correlation == 0.7 + + +# =========================================================================== +# NetCDFParameterMetadata +# =========================================================================== + + +class TestNetCDFParameterMetadata: + def test_construction(self): + m = NetCDFParameterMetadata( + variable_name="param_hysics_roll", + units="arcseconds", + long_name="HySICS roll correction", + ) + assert m.variable_name == "param_hysics_roll" + + def test_missing_field_raises(self): + with pytest.raises(ValidationError) as exc_info: + NetCDFParameterMetadata(variable_name="x", units="m") # long_name missing + assert "long_name" in str(exc_info.value) + + +# =========================================================================== +# NetCDFConfig +# =========================================================================== + + +class TestNetCDFConfig: + def test_threshold_metric_name(self, netcdf_cfg): + assert netcdf_cfg.get_threshold_metric_name() == "percent_under_250m" + + def test_threshold_metric_name_round(self): + nc = NetCDFConfig(performance_threshold_m=500.0) + assert nc.get_threshold_metric_name() == "percent_under_500m" + + def test_get_standard_attributes_defaults(self, netcdf_cfg): + attrs = netcdf_cfg.get_standard_attributes() + assert "rms_error_m" in attrs + assert attrs["rms_error_m"]["units"] == "meters" + + def test_get_standard_attributes_override(self): + custom = {"my_var": {"units": "km", "long_name": "My Variable"}} + nc = 
NetCDFConfig(performance_threshold_m=100.0, standard_attributes=custom) + assert nc.get_standard_attributes() == custom + + def test_auto_generate_metadata_constant_kernel(self, netcdf_cfg, param_constant): + meta = netcdf_cfg.get_parameter_netcdf_metadata(param_constant, angle_type="roll") + assert "roll" in meta.long_name + assert meta.units == "arcseconds" + assert meta.variable_name.startswith("param_") + + def test_auto_generate_metadata_offset_time(self, netcdf_cfg, param_offset_time): + meta = netcdf_cfg.get_parameter_netcdf_metadata(param_offset_time) + assert meta.units == "milliseconds" + + def test_missing_threshold_raises(self): + with pytest.raises(ValidationError) as exc_info: + NetCDFConfig() # performance_threshold_m required + assert "performance_threshold_m" in str(exc_info.value) + + +# =========================================================================== +# CorrectionConfig +# =========================================================================== + + +class TestCorrectionConfig: + def test_basic_construction(self, minimal_config): + assert minimal_config.n_iterations == 5 + assert minimal_config.seed == 42 + assert len(minimal_config.parameters) == 1 + + def test_callable_fields_default_none(self, minimal_config): + """_image_matching_override defaults to None.""" + assert minimal_config._image_matching_override is None + + def test_image_matching_override_can_be_set(self, minimal_config): + """_image_matching_override accepts any callable.""" + + def my_func(*args, **kwargs): + return None + + minimal_config._image_matching_override = my_func + assert minimal_config._image_matching_override is my_func + minimal_config._image_matching_override = None + + def test_image_matching_func_deprecated_getter(self, minimal_config): + """Accessing image_matching_func property emits DeprecationWarning.""" + import warnings + + minimal_config._image_matching_override = lambda: "x" + with warnings.catch_warnings(record=True) as caught: + 
warnings.simplefilter("always") + val = minimal_config.image_matching_func + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + assert val is minimal_config._image_matching_override + minimal_config._image_matching_override = None + + def test_image_matching_func_deprecated_setter(self, minimal_config): + """Setting image_matching_func via deprecated property emits DeprecationWarning.""" + import warnings + + def my_func(*args, **kwargs): + return None + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + minimal_config.image_matching_func = my_func + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + assert minimal_config._image_matching_override is my_func + minimal_config._image_matching_override = None + + def test_mutable_fields(self, minimal_config): + minimal_config.n_iterations = 99 + assert minimal_config.n_iterations == 99 + minimal_config.n_iterations = 5 + + def test_missing_required_field_raises(self, geo, param_constant): + with pytest.raises(ValidationError) as exc_info: + CorrectionConfig( + n_iterations=5, + parameters=[param_constant], + geo=geo, + # performance_threshold_m missing + performance_spec_percent=39.0, + ) + assert "performance_threshold_m" in str(exc_info.value) + + def test_invalid_n_iterations_type_raises(self, geo, param_constant): + with pytest.raises(ValidationError) as exc_info: + CorrectionConfig( + n_iterations="not-an-int", + parameters=[param_constant], + geo=geo, + performance_threshold_m=250.0, + performance_spec_percent=39.0, + ) + assert "n_iterations" in str(exc_info.value) + + def test_validate_method_passes_for_valid_config(self, minimal_config): + minimal_config.validate() # must not raise + + def test_validate_method_accepts_legacy_check_loaders_kwarg(self, minimal_config): + """check_loaders is accepted for backward compat but has no effect.""" + minimal_config.validate(check_loaders=False) # must not raise + 
minimal_config.validate(check_loaders=True) # must also not raise + + def test_validate_method_raises_for_bad_n_iterations(self, minimal_config): + minimal_config.n_iterations = -1 + with pytest.raises(ValueError, match="n_iterations"): + minimal_config.validate() + minimal_config.n_iterations = 5 + + def test_correction_config_requires_no_earth_radius_m(self, geo, param_constant): + """CorrectionConfig constructs successfully without earth_radius_m (field removed).""" + config = CorrectionConfig( + n_iterations=5, + parameters=[param_constant], + geo=geo, + performance_threshold_m=250.0, + performance_spec_percent=39.0, + ) + assert not hasattr(config, "earth_radius_m") or config.model_fields.get("earth_radius_m") is None + + def test_ensure_netcdf_config_creates_default(self, minimal_config): + assert minimal_config.netcdf is None + minimal_config.ensure_netcdf_config() + assert isinstance(minimal_config.netcdf, NetCDFConfig) + assert minimal_config.netcdf.performance_threshold_m == 250.0 + + def test_ensure_netcdf_config_idempotent(self, minimal_config): + minimal_config.ensure_netcdf_config() + first = minimal_config.netcdf + minimal_config.ensure_netcdf_config() + assert minimal_config.netcdf is first # same object + + def test_get_output_filename_default(self, minimal_config): + assert minimal_config.get_output_filename() == "correction_results.nc" + + def test_get_output_filename_custom(self, minimal_config): + minimal_config.output_filename = "my_results.nc" + assert minimal_config.get_output_filename() == "my_results.nc" + minimal_config.output_filename = None + + def test_get_calibration_file(self, full_config): + assert full_config.get_calibration_file("los_vectors") == "b_HS.mat" + + def test_get_calibration_file_missing_raises(self, full_config): + with pytest.raises(ValueError, match="No calibration file configured"): + full_config.get_calibration_file("nonexistent_type") + + def test_get_calibration_file_default_fallback(self, full_config): + assert 
full_config.get_calibration_file("nonexistent_type", default="fallback.mat") == "fallback.mat" + + +# =========================================================================== +# JSON Round-Trip (acceptance criterion) +# =========================================================================== + + +class TestJsonRoundTrip: + """config == CorrectionConfig.model_validate_json(config.model_dump_json())""" + + def test_minimal_config_roundtrip(self, minimal_config): + json_str = minimal_config.model_dump_json() + reloaded = CorrectionConfig.model_validate_json(json_str) + assert minimal_config == reloaded + + def test_full_config_roundtrip(self, full_config): + json_str = full_config.model_dump_json() + reloaded = CorrectionConfig.model_validate_json(json_str) + assert full_config == reloaded + + def test_callable_fields_excluded_from_json(self, minimal_config): + """_image_matching_override (PrivateAttr) is always excluded from JSON serialisation.""" + minimal_config._image_matching_override = lambda: None + json_str = minimal_config.model_dump_json() + assert "image_matching_func" not in json_str + assert "_image_matching_override" not in json_str + # clean up + minimal_config._image_matching_override = None + + def test_json_contains_expected_keys(self, minimal_config): + import json + + data = json.loads(minimal_config.model_dump_json()) + assert "n_iterations" in data + assert "parameters" in data + assert "geo" in data + assert "performance_threshold_m" in data + assert "performance_spec_percent" in data + assert "earth_radius_m" not in data + + def test_path_fields_survive_roundtrip(self, minimal_config): + reloaded = CorrectionConfig.model_validate_json(minimal_config.model_dump_json()) + assert isinstance(reloaded.geo.meta_kernel_file, Path) + assert reloaded.geo.meta_kernel_file == minimal_config.geo.meta_kernel_file + assert reloaded.geo.generic_kernel_dir == minimal_config.geo.generic_kernel_dir + + def test_parameter_type_enum_survives_roundtrip(self, 
full_config): + reloaded = CorrectionConfig.model_validate_json(full_config.model_dump_json()) + for orig, reld in zip(full_config.parameters, reloaded.parameters): + assert orig.ptype == reld.ptype + + def test_parameter_data_fields_survive_roundtrip(self, full_config): + reloaded = CorrectionConfig.model_validate_json(full_config.model_dump_json()) + for orig, reld in zip(full_config.parameters, reloaded.parameters): + assert orig.data.sigma == reld.data.sigma + assert orig.data.units == reld.data.units + assert orig.data.bounds == reld.data.bounds + assert orig.data.field == reld.data.field + + def test_netcdf_config_survives_roundtrip(self, full_config): + reloaded = CorrectionConfig.model_validate_json(full_config.model_dump_json()) + assert reloaded.netcdf is not None + assert reloaded.netcdf.performance_threshold_m == full_config.netcdf.performance_threshold_m + assert reloaded.netcdf.title == full_config.netcdf.title + + def test_geo_model_roundtrip_standalone(self, geo): + json_str = geo.model_dump_json() + reloaded = GeolocationConfig.model_validate_json(json_str) + assert geo == reloaded + + def test_netcdf_config_roundtrip_standalone(self, netcdf_cfg): + json_str = netcdf_cfg.model_dump_json() + reloaded = NetCDFConfig.model_validate_json(json_str) + assert netcdf_cfg == reloaded + + def test_parameter_config_roundtrip_standalone(self, param_constant): + json_str = param_constant.model_dump_json() + reloaded = ParameterConfig.model_validate_json(json_str) + assert param_constant == reloaded + + def test_parameter_data_roundtrip_standalone(self): + pd = ParameterData( + current_value=[1.0, 2.0, 3.0], + bounds=[-300.0, 300.0], + sigma=30.0, + units="arcseconds", + coordinate_frames=["F1", "F2"], + ) + reloaded = ParameterData.model_validate_json(pd.model_dump_json()) + assert pd == reloaded + + def test_roundtrip_with_none_seed(self, geo, param_constant): + config = CorrectionConfig( + seed=None, + n_iterations=3, + parameters=[param_constant], + geo=geo, 
+ performance_threshold_m=100.0, + performance_spec_percent=50.0, + ) + reloaded = CorrectionConfig.model_validate_json(config.model_dump_json()) + assert reloaded.seed is None + assert config == reloaded + + +# =========================================================================== +# load_config_from_json – earth_radius_m deprecation +# =========================================================================== + + +class TestLoadConfigFromJsonEarthRadius: + """Verify that earth_radius_m in JSON is accepted (with a warning) and ignored.""" + + def _minimal_json(self, tmp_path, *, include_earth_radius: bool) -> Path: + """Write a minimal valid correction config JSON to a temp file.""" + import json + + # Use the parameter format that load_config_from_json expects + # (name + parameter_type, not ptype). + payload = { + "mission_config": { + "mission_name": "TEST", + "kernel_mappings": { + "constant_kernel": {}, + "offset_kernel": {}, + }, + }, + "geolocation": { + "instrument_name": "TEST_INST", + "time_field": "ugps_time", + "meta_kernel_file": str(tmp_path / "test.kernels.tm.json"), + "generic_kernel_dir": str(tmp_path), + }, + "correction": { + "n_iterations": 2, + "performance_threshold_m": 250.0, + "performance_spec_percent": 39.0, + "parameters": [ + { + "name": "time_correction", + "parameter_type": "OFFSET_TIME", + "initial_value": 0.0, + "bounds": [-50.0, 50.0], + "sigma": 7.0, + "units": "milliseconds", + "field": "ugps_time", + } + ], + }, + } + if include_earth_radius: + payload["correction"]["earth_radius_m"] = 6_378_140.0 + + path = tmp_path / "config.json" + path.write_text(json.dumps(payload)) + return path + + def test_json_without_earth_radius_loads_fine(self, tmp_path): + """Config without earth_radius_m should load without error.""" + config_path = self._minimal_json(tmp_path, include_earth_radius=False) + config = load_config_from_json(config_path) + assert config.performance_threshold_m == 250.0 + + def 
test_json_with_earth_radius_loads_with_warning(self, tmp_path, caplog): + """Config with legacy earth_radius_m loads but emits a deprecation warning.""" + import logging + + config_path = self._minimal_json(tmp_path, include_earth_radius=True) + with caplog.at_level(logging.WARNING, logger="curryer.correction.config"): + config = load_config_from_json(config_path) + + assert config.performance_threshold_m == 250.0 + assert any("earth_radius_m" in msg and "deprecated" in msg for msg in caplog.messages) + + +# --------------------------------------------------------------------------- +# CorrectionInput +# --------------------------------------------------------------------------- + + +class TestCorrectionInput: + """Tests for the typed CorrectionInput model.""" + + def test_basic_construction(self, tmp_path): + from curryer.correction.config import CorrectionInput + + inp = CorrectionInput( + telemetry_file=tmp_path / "tlm.csv", + science_file=tmp_path / "sci.csv", + gcp_file=tmp_path / "gcp.mat", + ) + assert inp.telemetry_file == tmp_path / "tlm.csv" + assert inp.science_file == tmp_path / "sci.csv" + assert inp.gcp_file == tmp_path / "gcp.mat" + + def test_string_paths_coerced_to_path(self, tmp_path): + from curryer.correction.config import CorrectionInput + + inp = CorrectionInput( + telemetry_file="data/tlm.csv", + science_file="data/sci.csv", + gcp_file="gcps/chip.mat", + ) + assert isinstance(inp.telemetry_file, Path) + assert isinstance(inp.science_file, Path) + assert isinstance(inp.gcp_file, Path) + + def test_run_correction_accepts_correction_input(self, minimal_config, tmp_path): + """run_correction() normalises CorrectionInput to tuples before calling loop().""" + from unittest.mock import patch + + from curryer.correction.config import CorrectionInput + from curryer.correction.pipeline import run_correction + + inp = CorrectionInput( + telemetry_file=tmp_path / "tlm.csv", + science_file=tmp_path / "sci.csv", + gcp_file=tmp_path / "gcp.mat", + ) + + 
with patch("curryer.correction.pipeline.loop") as mock_loop: + mock_loop.return_value = ([], {}) + run_correction(minimal_config, tmp_path, [inp]) + + mock_loop.assert_called_once() + call_args = mock_loop.call_args + normalized_inputs = call_args[0][2] # third positional arg + assert len(normalized_inputs) == 1 + assert normalized_inputs[0] == ( + str(tmp_path / "tlm.csv"), + str(tmp_path / "sci.csv"), + str(tmp_path / "gcp.mat"), + ) + + def test_run_correction_accepts_legacy_tuples(self, minimal_config, tmp_path): + """run_correction() passes legacy tuples through unchanged.""" + from unittest.mock import patch + + from curryer.correction.pipeline import run_correction + + tuples = [("tlm.csv", "sci.csv", "gcp.mat")] + + with patch("curryer.correction.pipeline.loop") as mock_loop: + mock_loop.return_value = ([], {}) + run_correction(minimal_config, tmp_path, tuples) + + call_args = mock_loop.call_args + assert call_args[0][2] == tuples diff --git a/tests/test_correction/test_correction.py b/tests/test_correction/test_correction.py deleted file mode 100644 index bbb528a8..00000000 --- a/tests/test_correction/test_correction.py +++ /dev/null @@ -1,2515 +0,0 @@ -#!/usr/bin/env python3 -""" -Unified Correction Test Suite - -This module consolidates two complementary Correction test approaches: - -1. UPSTREAM Testing (run_upstream_pipeline): - - Tests kernel creation and geolocation with parameter variations - - Uses real telemetry data - - Validates parameter modification and kernel generation - - Stops before pairing (no valid GCP pairs available) - -2. DOWNSTREAM Testing (run_downstream_pipeline): - - Tests GCP pairing, image matching, and error statistics - - Uses pre-geolocated test images with known GCP pairs - - Validates spatial pairing, image matching algorithms, and error metrics - - Skips kernel/geolocation (uses pre-computed test data) - -Both tests share the same CLARREO configuration base but configure differently -for their specific testing needs. 
- -Running Tests: -------------- -# Via pytest (recommended) -pytest tests/test_correction/test_correction.py -v - -# Run specific test -pytest tests/test_correction/test_correction.py::test_generate_clarreo_config_json -v - -# Standalone execution with arguments (for pipeline runs) -python tests/test_correction/test_correction.py --mode downstream --quick - -Requirements: ------------------ -These tests validate the complete Correction geolocation pipeline, -demonstrating parameter sensitivity analysis and error statistics -computation for mission requirements validation. - -""" - -import argparse -import atexit -import json -import logging -import shutil -import sys -import tempfile -import time -import unittest -from pathlib import Path - -import numpy as np -import pandas as pd -import pytest -import xarray as xr -from scipy.io import loadmat - -from curryer import meta, utils -from curryer import spicierpy as sp -from curryer.compute import constants -from curryer.correction import correction -from curryer.correction.data_structures import ( - GeolocationConfig as ImageMatchGeolocationConfig, -) -from curryer.correction.data_structures import ( - ImageGrid, - SearchConfig, -) -from curryer.correction.image_match import ( - integrated_image_match, - load_image_grid_from_mat, - load_los_vectors_from_mat, - load_optical_psf_from_mat, -) -from curryer.correction.pairing import find_l1a_gcp_pairs -from curryer.kernels import create - -# Import CLARREO config and data loaders -sys.path.insert(0, str(Path(__file__).parent)) -from clarreo_config import create_clarreo_correction_config -from clarreo_data_loaders import load_clarreo_gcp, load_clarreo_science, load_clarreo_telemetry - -logger = logging.getLogger(__name__) -utils.enable_logging(log_level=logging.INFO, extra_loggers=[__name__]) - - -# ============================================================================= -# TEST PLACEHOLDER FUNCTIONS (For Synthetic Data Generation) -# 
============================================================================= -# These functions generate SYNTHETIC test data for testing the Correction pipeline - - -class _PlaceholderConfig: - """Configuration for placeholder test data generation.""" - - base_error_m: float = 50.0 - param_error_scale: float = 10.0 - max_measurements: int = 100 - min_measurements: int = 10 - orbit_radius_mean_m: float = 6.78e6 - orbit_radius_std_m: float = 4e3 - latitude_range: tuple[float, float] = (-60.0, 60.0) - longitude_range: tuple[float, float] = (-180.0, 180.0) - altitude_range: tuple[float, float] = (0.0, 1000.0) - max_off_nadir_rad: float = 0.1 - - -def synthetic_gcp_pairing(science_data_files): - """Generate SYNTHETIC GCP pairs for testing (TEST ONLY - not a test itself).""" - logger.warning("=" * 80) - logger.warning("!!!!️ USING SYNTHETIC GCP PAIRING - FAKE DATA! !!!!️") - logger.warning("=" * 80) - synthetic_pairs = [(f"{sci_file}", f"landsat_gcp_{i:03d}.tif") for i, sci_file in enumerate(science_data_files)] - return synthetic_pairs - - -def synthetic_image_matching( - geolocated_data, - gcp_reference_file, - telemetry, - calibration_dir, - params_info, - config, - los_vectors_cached=None, - optical_psfs_cached=None, -): - """ - Generate SYNTHETIC image matching error data (TEST ONLY - not a test itself). - - This function matches the signature of the real image_matching() function - but only uses a subset of parameters for upstream testing of the Correction. 
- - Used parameters: - geolocated_data: For generating realistic synthetic errors - params_info: To scale errors based on parameter variations - config: For coordinate names and placeholder configuration - - Ignored parameters (accepted for compatibility): - gcp_reference_file: Not needed for synthetic data - telemetry: Not needed for synthetic data - calibration_dir: Not needed for synthetic data - los_vectors_cached: Not needed for synthetic data - optical_psfs_cached: Not needed for synthetic data - """ - logger.warning("=" * 80) - logger.warning("!!!!️ USING SYNTHETIC IMAGE MATCHING - FAKE DATA! !!!!️") - logger.warning("=" * 80) - - placeholder_cfg = ( - config.placeholder if hasattr(config, "placeholder") and config.placeholder else _PlaceholderConfig() - ) - sc_pos_name = getattr(config, "spacecraft_position_name", "sc_position") - boresight_name = getattr(config, "boresight_name", "boresight") - transform_name = getattr(config, "transformation_matrix_name", "t_inst2ref") - - valid_mask = ~np.isnan(geolocated_data["latitude"].values).any(axis=1) - n_valid = valid_mask.sum() - n_measurements = ( - placeholder_cfg.min_measurements if n_valid == 0 else min(n_valid, placeholder_cfg.max_measurements) - ) - - # Generate realistic synthetic data - riss_ctrs = _generate_spherical_positions( - n_measurements, placeholder_cfg.orbit_radius_mean_m, placeholder_cfg.orbit_radius_std_m - ) - boresights = _generate_synthetic_boresights(n_measurements, placeholder_cfg.max_off_nadir_rad) - t_matrices = _generate_nadir_aligned_transforms(n_measurements, riss_ctrs, boresights) - - # Generate synthetic errors - base_error = placeholder_cfg.base_error_m - param_contribution = ( - sum(abs(p) if isinstance(p, int | float) else np.linalg.norm(p) for _, p in params_info) - * placeholder_cfg.param_error_scale - ) - error_magnitude = base_error + param_contribution - lat_errors = np.random.normal(0, error_magnitude / 111000, n_measurements) - lon_errors = np.random.normal(0, 
error_magnitude / 111000, n_measurements) - - if n_valid > 0: - valid_indices = np.where(valid_mask)[0][:n_measurements] - gcp_lat = geolocated_data["latitude"].values[valid_indices, 0] - gcp_lon = geolocated_data["longitude"].values[valid_indices, 0] - else: - gcp_lat = np.random.uniform(*placeholder_cfg.latitude_range, n_measurements) - gcp_lon = np.random.uniform(*placeholder_cfg.longitude_range, n_measurements) - - gcp_alt = np.random.uniform(*placeholder_cfg.altitude_range, n_measurements) - - return xr.Dataset( - { - "lat_error_deg": (["measurement"], lat_errors), - "lon_error_deg": (["measurement"], lon_errors), - sc_pos_name: (["measurement", "xyz"], riss_ctrs), - boresight_name: (["measurement", "xyz"], boresights), - transform_name: (["measurement", "xyz_from", "xyz_to"], t_matrices), - "gcp_lat_deg": (["measurement"], gcp_lat), - "gcp_lon_deg": (["measurement"], gcp_lon), - "gcp_alt": (["measurement"], gcp_alt), - }, - coords={ - "measurement": range(n_measurements), - "xyz": ["x", "y", "z"], - "xyz_from": ["x", "y", "z"], - "xyz_to": ["x", "y", "z"], - }, - ) - - -def _generate_synthetic_boresights(n_measurements, max_off_nadir_rad=0.07): - """Generate synthetic boresight vectors (test helper).""" - boresights = np.zeros((n_measurements, 3)) - for i in range(n_measurements): - theta = np.random.uniform(-max_off_nadir_rad, max_off_nadir_rad) - boresights[i] = [0.0, np.sin(theta), np.cos(theta)] - return boresights - - -def _generate_spherical_positions(n_measurements, radius_mean_m, radius_std_m): - """Generate synthetic spacecraft positions on sphere (test helper).""" - positions = np.zeros((n_measurements, 3)) - for i in range(n_measurements): - radius = np.random.normal(radius_mean_m, radius_std_m) - phi = np.random.uniform(0, 2 * np.pi) - cos_theta = np.random.uniform(-1, 1) - sin_theta = np.sqrt(1 - cos_theta**2) - positions[i] = [radius * sin_theta * np.cos(phi), radius * sin_theta * np.sin(phi), radius * cos_theta] - return positions - - -def 
_generate_nadir_aligned_transforms(n_measurements, riss_ctrs, boresights_hs): - """Generate transformation matrices aligning boresights with nadir (test helper).""" - t_matrices = np.zeros((n_measurements, 3, 3)) - for i in range(n_measurements): - nadir_ctrs = -riss_ctrs[i] / np.linalg.norm(riss_ctrs[i]) - bhat_hs_norm = boresights_hs[i] / np.linalg.norm(boresights_hs[i]) - rotation_axis = np.cross(bhat_hs_norm, nadir_ctrs) - axis_norm = np.linalg.norm(rotation_axis) - - if axis_norm < 1e-6: - if np.dot(bhat_hs_norm, nadir_ctrs) > 0: - t_matrices[i] = np.eye(3) - else: - perp = np.array([1, 0, 0]) if abs(bhat_hs_norm[0]) < 0.9 else np.array([0, 1, 0]) - rotation_axis = np.cross(bhat_hs_norm, perp) / np.linalg.norm(np.cross(bhat_hs_norm, perp)) - K = np.array( - [ - [0, -rotation_axis[2], rotation_axis[1]], - [rotation_axis[2], 0, -rotation_axis[0]], - [-rotation_axis[1], rotation_axis[0], 0], - ] - ) - t_matrices[i] = np.eye(3) + 2 * K @ K - else: - rotation_axis = rotation_axis / axis_norm - angle = np.arccos(np.clip(np.dot(bhat_hs_norm, nadir_ctrs), -1.0, 1.0)) - K = np.array( - [ - [0, -rotation_axis[2], rotation_axis[1]], - [rotation_axis[2], 0, -rotation_axis[0]], - [-rotation_axis[1], rotation_axis[0], 0], - ] - ) - t_matrices[i] = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K) - - return t_matrices - - -# ============================================================================= -# TEST MODE FUNCTIONS (Extracted from correction.py) -# ============================================================================= -# These functions were moved from the core correction module to keep test-specific -# code separate from mission-agnostic core functionality. - - -def discover_test_image_match_cases(test_data_dir: Path, test_cases: list[str] | None = None) -> list[dict]: - """ - Discover available image matching test cases. 
- - This function scans the test data directory for validated image matching - test cases and returns metadata about available test files. - - Args: - test_data_dir: Root directory for test data (tests/data/clarreo/image_match/) - test_cases: Specific test cases to use (e.g., ['1', '2']) or None for all - - Returns: - List of test case dictionaries with file paths and metadata - """ - logger.info(f"Discovering image matching test cases in: {test_data_dir}") - - # Shared calibration files (same for all test cases) - los_file = test_data_dir / "b_HS.mat" - psf_file_unbinned = test_data_dir / "optical_PSF_675nm_upsampled.mat" - psf_file_binned = test_data_dir / "optical_PSF_675nm_3_pix_binned_upsampled.mat" - - if not los_file.exists(): - raise FileNotFoundError(f"LOS vectors file not found: {los_file}") - if not psf_file_unbinned.exists(): - raise FileNotFoundError(f"Optical PSF file not found: {psf_file_unbinned}") - - # Test case metadata (from test_image_match.py) - test_case_metadata = { - "1": { - "name": "Dili", - "gcp_file": "GCP12055Dili_resampled.mat", - "ancil_file": "R_ISS_midframe_TestCase1.mat", - "expected_error_km": (3.0, -3.0), # (lat, lon) - "cases": [ - {"subimage": "TestCase1a_subimage.mat", "binned": False}, - {"subimage": "TestCase1b_subimage.mat", "binned": False}, - {"subimage": "TestCase1c_subimage_binned.mat", "binned": True}, - {"subimage": "TestCase1d_subimage_binned.mat", "binned": True}, - ], - }, - "2": { - "name": "Maracaibo", - "gcp_file": "GCP10121Maracaibo_resampled.mat", - "ancil_file": "R_ISS_midframe_TestCase2.mat", - "expected_error_km": (-3.0, 2.0), - "cases": [ - {"subimage": "TestCase2a_subimage.mat", "binned": False}, - {"subimage": "TestCase2b_subimage.mat", "binned": False}, - {"subimage": "TestCase2c_subimage_binned.mat", "binned": True}, - ], - }, - "3": { - "name": "Algeria3", - "gcp_file": "GCP10181Algeria3_resampled.mat", - "ancil_file": "R_ISS_midframe_TestCase3.mat", - "expected_error_km": (2.0, 3.0), - "cases": [ - 
{"subimage": "TestCase3a_subimage.mat", "binned": False}, - {"subimage": "TestCase3b_subimage_binned.mat", "binned": True}, - ], - }, - "4": { - "name": "Dunhuang", - "gcp_file": "GCP10142Dunhuang_resampled.mat", - "ancil_file": "R_ISS_midframe_TestCase4.mat", - "expected_error_km": (-2.0, -3.0), - "cases": [ - {"subimage": "TestCase4a_subimage.mat", "binned": False}, - {"subimage": "TestCase4b_subimage_binned.mat", "binned": True}, - ], - }, - "5": { - "name": "Algeria5", - "gcp_file": "GCP10071Algeria5_resampled.mat", - "ancil_file": "R_ISS_midframe_TestCase5.mat", - "expected_error_km": (1.0, -1.0), - "cases": [ - {"subimage": "TestCase5a_subimage.mat", "binned": False}, - ], - }, - } - - # Filter to requested test cases - if test_cases is None: - test_cases = sorted(test_case_metadata.keys()) - - discovered_cases = [] - - for case_id in test_cases: - if case_id not in test_case_metadata: - logger.warning(f"Test case '{case_id}' not found in metadata, skipping") - continue - - metadata = test_case_metadata[case_id] - case_dir = test_data_dir / case_id - - if not case_dir.is_dir(): - logger.warning(f"Test case directory not found: {case_dir}, skipping") - continue - - # Add each subcase variant (a, b, c, d) - for subcase in metadata["cases"]: - subimage_file = case_dir / subcase["subimage"] - gcp_file = case_dir / metadata["gcp_file"] - ancil_file = case_dir / metadata["ancil_file"] - psf_file = psf_file_binned if subcase["binned"] else psf_file_unbinned - - # Validate all files exist - if not subimage_file.exists(): - logger.warning(f"Subimage file not found: {subimage_file}, skipping") - continue - if not gcp_file.exists(): - logger.warning(f"GCP file not found: {gcp_file}, skipping") - continue - if not ancil_file.exists(): - logger.warning(f"Ancillary file not found: {ancil_file}, skipping") - continue - - discovered_cases.append( - { - "case_id": case_id, - "case_name": metadata["name"], - "subcase_name": subcase["subimage"], - "subimage_file": 
subimage_file, - "gcp_file": gcp_file, - "ancil_file": ancil_file, - "los_file": los_file, - "psf_file": psf_file, - "expected_lat_error_km": metadata["expected_error_km"][0], - "expected_lon_error_km": metadata["expected_error_km"][1], - "binned": subcase["binned"], - } - ) - - logger.info(f"Discovered {len(discovered_cases)} test case variants from {len(test_cases)} test case groups") - for case in discovered_cases: - logger.info( - f" - {case['case_id']}/{case['subcase_name']}: {case['case_name']}, " - f"expected error=({case['expected_lat_error_km']:.1f}, {case['expected_lon_error_km']:.1f}) km" - ) - - return discovered_cases - - -def apply_error_variation_for_testing( - base_result: xr.Dataset, param_idx: int, error_variation_percent: float = 3.0 -) -> xr.Dataset: - """ - Apply random variation to image matching results to simulate parameter effects. - - This is used in test mode to simulate how different parameter values would - affect geolocation errors, without actually re-running image matching. 
- - Args: - base_result: Original image matching result - param_idx: Parameter set index (used as random seed) - error_variation_percent: Percentage variation to apply (e.g., 3.0 = ±3%) - - Returns: - New Dataset with varied error values - """ - # Create copy - output = base_result.copy(deep=True) - - # Set reproducible random seed based on param_idx - np.random.seed(param_idx) - - # Generate variation factors (centered at 1.0, with specified percentage variation) - variation_fraction = error_variation_percent / 100.0 - lat_factor = 1.0 + np.random.normal(0, variation_fraction) - lon_factor = 1.0 + np.random.normal(0, variation_fraction) - ccv_factor = 1.0 + np.random.normal(0, variation_fraction / 10.0) # Smaller variation for correlation - - # Apply variations to error values - original_lat_km = base_result.attrs["lat_error_km"] - original_lon_km = base_result.attrs["lon_error_km"] - original_ccv = base_result.attrs["correlation_ccv"] - - varied_lat_km = original_lat_km * lat_factor - varied_lon_km = original_lon_km * lon_factor - varied_ccv = np.clip(original_ccv * ccv_factor, 0.0, 1.0) # Keep correlation in valid range - - # Update dataset values - # Get GCP center latitude from multiple possible sources - if "gcp_center_lat" in base_result.attrs: - gcp_center_lat = base_result.attrs["gcp_center_lat"] - elif "gcp_lat_deg" in base_result: - gcp_center_lat = float(base_result["gcp_lat_deg"].values[0]) - else: - # Fallback to a reasonable default (mid-latitude) - logger.warning("GCP center latitude not found in dataset, using default 45.0°") - gcp_center_lat = 45.0 - - lat_error_deg = varied_lat_km / 111.0 - lon_radius_km = 6378.0 * np.cos(np.deg2rad(gcp_center_lat)) - lon_error_deg = varied_lon_km / (lon_radius_km * np.pi / 180.0) - - output["lat_error_deg"].values[0] = lat_error_deg - output["lon_error_deg"].values[0] = lon_error_deg - - # Update attributes - output.attrs["lat_error_km"] = varied_lat_km - output.attrs["lon_error_km"] = varied_lon_km - 
output.attrs["correlation_ccv"] = varied_ccv - output.attrs["param_idx"] = param_idx - output.attrs["variation_applied"] = True - output.attrs["variation_lat_factor"] = lat_factor - output.attrs["variation_lon_factor"] = lon_factor - - logger.info( - f" Applied variation: lat {original_lat_km:.3f} → {varied_lat_km:.3f} km ({(lat_factor - 1) * 100:+.1f}%), " - f"lon {original_lon_km:.3f} → {varied_lon_km:.3f} km ({(lon_factor - 1) * 100:+.1f}%)" - ) - - return output - - -# ============================================================================= -# SHARED UTILITY FUNCTIONS -# ============================================================================= - - -def apply_geolocation_error_to_subimage( - subimage: ImageGrid, gcp: ImageGrid, lat_error_km: float, lon_error_km: float -) -> ImageGrid: - """ - Apply artificial geolocation error to a subimage for testing. - - This creates a misaligned subimage that the image matching algorithm - should detect and measure. - """ - mid_lat = float(gcp.lat[gcp.lat.shape[0] // 2, gcp.lat.shape[1] // 2]) - - # WGS84 Earth radius - use semi-major axis for latitude/longitude conversions - earth_radius_km = constants.WGS84_SEMI_MAJOR_AXIS_KM - lat_offset_deg = lat_error_km / earth_radius_km * (180.0 / np.pi) - lon_radius_km = earth_radius_km * np.cos(np.deg2rad(mid_lat)) - lon_offset_deg = lon_error_km / lon_radius_km * (180.0 / np.pi) - - return ImageGrid( - data=subimage.data.copy(), - lat=subimage.lat + lat_offset_deg, - lon=subimage.lon + lon_offset_deg, - h=subimage.h.copy() if subimage.h is not None else None, - ) - - -def run_image_matching_with_applied_errors( - test_case: dict, - param_idx: int, - randomize_errors: bool = True, - error_variation_percent: float = 3.0, - cache_results: bool = True, - cached_result: xr.Dataset | None = None, -) -> xr.Dataset: - """ - Run image matching with artificial errors applied to test data. 
- - This loads test data, applies known geolocation errors, then runs - image matching to verify it can detect those errors. - - Args: - test_case: Test case dictionary with file paths and expected errors - param_idx: Parameter set index (for variation seed) - randomize_errors: Whether to apply random variations - error_variation_percent: Percentage variation (default 3.0%) - cache_results: Whether to use cached results with variation - cached_result: Previously cached result to vary - """ - # Use cached result with variation if available - if cached_result is not None and cache_results and param_idx > 0: - if randomize_errors: - logger.info(f" Applying ±{error_variation_percent}% variation to cached result") - return apply_error_variation_for_testing(cached_result, param_idx, error_variation_percent) - else: - return cached_result.copy() - - logger.info(f" Running image matching with applied errors: {test_case['case_name']}") - start_time = time.time() - - # Load subimage - subimage_struct = loadmat(test_case["subimage_file"], squeeze_me=True, struct_as_record=False)["subimage"] - subimage = ImageGrid( - data=np.asarray(subimage_struct.data), - lat=np.asarray(subimage_struct.lat), - lon=np.asarray(subimage_struct.lon), - h=np.asarray(subimage_struct.h) if hasattr(subimage_struct, "h") else None, - ) - - # Load GCP - gcp = load_image_grid_from_mat(test_case["gcp_file"], key="GCP") - gcp_center_lat = float(gcp.lat[gcp.lat.shape[0] // 2, gcp.lat.shape[1] // 2]) - gcp_center_lon = float(gcp.lon[gcp.lon.shape[0] // 2, gcp.lon.shape[1] // 2]) - - # Apply expected error - expected_lat_error = test_case["expected_lat_error_km"] - expected_lon_error = test_case["expected_lon_error_km"] - - subimage_with_error = apply_geolocation_error_to_subimage(subimage, gcp, expected_lat_error, expected_lon_error) - - # Load calibration data - los_vectors = load_los_vectors_from_mat(test_case["los_file"]) - optical_psfs = load_optical_psf_from_mat(test_case["psf_file"]) - ancil_data = 
loadmat(test_case["ancil_file"], squeeze_me=True) - r_iss_midframe = ancil_data["R_ISS_midframe"].ravel() - - # Run image matching - result = integrated_image_match( - subimage=subimage_with_error, - gcp=gcp, - r_iss_midframe_m=r_iss_midframe, - los_vectors_hs=los_vectors, - optical_psfs=optical_psfs, - geolocation_config=ImageMatchGeolocationConfig(), - search_config=SearchConfig(), - ) - - processing_time = time.time() - start_time - - # Convert to dataset format - lat_error_deg = result.lat_error_km / 111.0 - lon_radius_km = 6378.0 * np.cos(np.deg2rad(gcp_center_lat)) - lon_error_deg = result.lon_error_km / (lon_radius_km * np.pi / 180.0) - - t_matrix = np.array( - [ - [-0.418977524967338, 0.748005379751721, 0.514728846515064], - [-0.421890284446342, 0.341604851993858, -0.839830169131854], - [-0.804031356019172, -0.569029065124742, 0.172451447025628], - ] - ) - boresight = np.array([0.0, 0.0625969755450201, 0.99803888634292]) - - output = xr.Dataset( - { - "lat_error_deg": (["measurement"], [lat_error_deg]), - "lon_error_deg": (["measurement"], [lon_error_deg]), - "riss_ctrs": (["measurement", "xyz"], [r_iss_midframe]), - "bhat_hs": (["measurement", "xyz"], [boresight]), - "t_hs2ctrs": (["measurement", "xyz_from", "xyz_to"], t_matrix[np.newaxis, :, :]), - "gcp_lat_deg": (["measurement"], [gcp_center_lat]), - "gcp_lon_deg": (["measurement"], [gcp_center_lon]), - "gcp_alt": (["measurement"], [0.0]), - }, - coords={"measurement": [0], "xyz": ["x", "y", "z"], "xyz_from": ["x", "y", "z"], "xyz_to": ["x", "y", "z"]}, - ) - - output.attrs.update( - { - "lat_error_km": result.lat_error_km, - "lon_error_km": result.lon_error_km, - "correlation_ccv": result.ccv_final, - "final_grid_step_m": result.final_grid_step_m, - "processing_time_s": processing_time, - "test_mode": True, - "param_idx": param_idx, - "gcp_center_lat": gcp_center_lat, - "gcp_center_lon": gcp_center_lon, - } - ) - - return output - - -# 
============================================================================= -# CONFIGURATION GENERATION TEST -# ============================================================================= - - -def test_generate_clarreo_config_json(tmp_path): - """Generate CLARREO config JSON and validate structure. - - This test generates the canonical CLARREO configuration JSON file - that is used by all other CLARREO tests. The generated JSON is - saved to configs/ and can be committed for version control. - - This ensures: - - Single source of truth for CLARREO configuration - - Programmatic config matches JSON config - - JSON structure is valid and complete - """ - - logger.info("=" * 80) - logger.info("TEST: Generate CLARREO Configuration JSON") - logger.info("=" * 80) - - # Define paths - data_dir = Path(__file__).parent.parent / "data/clarreo/gcs" - generic_dir = Path("data/generic") - output_path = tmp_path / "configs/clarreo_correction_config.json" - - logger.info(f"Data directory: {data_dir}") - logger.info(f"Generic kernels: {generic_dir}") - logger.info(f"Output path: {output_path}") - - # Generate config programmatically - logger.info("\n1. Generating config programmatically...") - config = create_clarreo_correction_config(data_dir, generic_dir, config_output_path=output_path) - - logger.info(f"✓ Config created: {len(config.parameters)} parameter groups, {config.n_iterations} iterations") - - # Validate the generated JSON exists - logger.info("\n2. Validating generated JSON file...") - assert output_path.exists(), f"Config JSON not created: {output_path}" - logger.info(f"✓ JSON file exists: {output_path}") - - # Reload and verify structure - logger.info("\n3. 
Reloading and validating JSON structure...") - with open(output_path) as f: - config_data = json.load(f) - - # Validate top-level sections - assert "mission_config" in config_data, "Missing 'mission_config' section" - assert "correction" in config_data, "Missing 'correction' section" - assert "geolocation" in config_data, "Missing 'geolocation' section" - logger.info("✓ All required top-level sections present") - - # Validate mission config - mission_cfg = config_data["mission_config"] - assert mission_cfg["mission_name"] == "CLARREO_Pathfinder" - assert "kernel_mappings" in mission_cfg - logger.info(f"✓ Mission: {mission_cfg['mission_name']}") - - # Validate correction config - corr_cfg = config_data["correction"] - assert "parameters" in corr_cfg - assert isinstance(corr_cfg["parameters"], list) - assert len(corr_cfg["parameters"]) > 0 - assert "seed" in corr_cfg - assert "n_iterations" in corr_cfg - - # NEW: Validate required fields are present - assert "earth_radius_m" in corr_cfg, "Missing 'earth_radius_m' in correction config" - assert "performance_threshold_m" in corr_cfg, "Missing 'performance_threshold_m'" - assert "performance_spec_percent" in corr_cfg, "Missing 'performance_spec_percent'" - - assert corr_cfg["earth_radius_m"] == 6378140.0 - assert corr_cfg["performance_threshold_m"] == 250.0 - assert corr_cfg["performance_spec_percent"] == 39.0 - - logger.info(f"✓ Correction config: {len(corr_cfg['parameters'])} parameters, {corr_cfg['n_iterations']} iterations") - logger.info( - f"✓ Required fields: earth_radius={corr_cfg['earth_radius_m']}, " - f"threshold={corr_cfg['performance_threshold_m']}m, " - f"spec={corr_cfg['performance_spec_percent']}%" - ) - - # Validate geolocation config - geo_cfg = config_data["geolocation"] - assert "meta_kernel_file" in geo_cfg - assert "instrument_name" in geo_cfg - assert geo_cfg["instrument_name"] == "CPRS_HYSICS" - logger.info(f"✓ Geolocation config: instrument={geo_cfg['instrument_name']}") - - # Test that JSON can 
be loaded back into CorrectionConfig - logger.info("\n4. Testing JSON → CorrectionConfig loading...") - reloaded_config = correction.load_config_from_json(output_path) - assert reloaded_config.n_iterations == config.n_iterations - assert len(reloaded_config.parameters) == len(config.parameters) - assert reloaded_config.earth_radius_m == 6378140.0 - assert reloaded_config.performance_threshold_m == 250.0 - assert reloaded_config.performance_spec_percent == 39.0 - logger.info("✓ JSON successfully loads into CorrectionConfig") - - # Validate reloaded config - logger.info("\n5. Validating reloaded config...") - reloaded_config.validate() - logger.info("✓ Reloaded config passes validation") - - logger.info("\n" + "=" * 80) - logger.info("✓ CONFIG GENERATION TEST PASSED") - logger.info(f"✓ Canonical config saved: {output_path}") - logger.info(f"✓ File size: {output_path.stat().st_size / 1024:.1f} KB") - logger.info("=" * 80) - - # Note: Test functions should not return values per pytest best practices - # All validation is performed via assert statements above - - -# ============================================================================= -# UPSTREAM TESTING (Kernel Creation + Geolocation) -# ============================================================================= - - -def run_upstream_pipeline(n_iterations: int = 5, work_dir: Path | None = None) -> tuple[list, dict, Path]: - """ - Test UPSTREAM segment of Correction pipeline. - - This tests: - - Parameter set generation - - Kernel creation from parameters - - Geolocation with varied parameters - - This does NOT test: - - GCP pairing (no valid pairs available) - - Image matching (no valid data) - - Error statistics (no matched data) - - This is focused on testing the upstream kernel creation - and geolocation part of the pipeline. - - Args: - n_iterations: Number of Correction iterations - work_dir: Working directory for outputs. 
If None, uses a temporary - directory that will be cleaned up when the process exits. - - Returns: - Tuple of (results, netcdf_data, output_path) - """ - logger.info("=" * 80) - logger.info("UPSTREAM PIPELINE TEST") - logger.info("Tests: Kernel Creation + Geolocation") - logger.info("=" * 80) - - root_dir = Path(__file__).parents[2] - generic_dir = root_dir / "data" / "generic" - data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" - - # Use temporary directory if work_dir not specified - if work_dir is None: - _tmp_dir = tempfile.mkdtemp(prefix="curryer_upstream_") - work_dir = Path(_tmp_dir) - - # Register cleanup to run on process exit - def cleanup_temp_dir(): - if work_dir.exists(): - try: - shutil.rmtree(work_dir) - logger.debug(f"Cleaned up temporary directory: {work_dir}") - except Exception as e: - logger.warning(f"Failed to cleanup {work_dir}: {e}") - - atexit.register(cleanup_temp_dir) - - logger.info(f"Using temporary directory: {work_dir}") - logger.info("(will be cleaned up on process exit)") - else: - work_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"Work directory: {work_dir}") - logger.info(f"Iterations: {n_iterations}") - - # Create configuration using CLARREO config - config = create_clarreo_correction_config(data_dir, generic_dir) - config.n_iterations = n_iterations - # Set output filename for test (consistent name for version control) - config.output_filename = "upstream_results.nc" - - # Add loaders and processing functions to config (Config-Centric Design) - config.telemetry_loader = load_clarreo_telemetry - config.science_loader = load_clarreo_science - config.gcp_loader = load_clarreo_gcp - config.gcp_pairing_func = synthetic_gcp_pairing # Test helper from this file - config.image_matching_func = synthetic_image_matching # Test helper from this file - - logger.info(f"Configuration loaded:") - logger.info(f" Mission: CLARREO Pathfinder") - logger.info(f" Instrument: {config.geo.instrument_name}") - logger.info(f" 
Parameters: {len(config.parameters)}") - logger.info(f" Iterations: {n_iterations}") - logger.info(f" Data loaders: telemetry, science, gcp") - logger.info(f" Processing: synthetic pairing and image matching") - - # Prepare data sets (synthetic GCP pairs since we don't have real data) - # For upstream testing, we just need telemetry and science keys - tlm_sci_gcp_sets = [ - ("telemetry_5a", "science_5a", "synthetic_gcp_1"), - ] - - logger.info(f"Data sets: {len(tlm_sci_gcp_sets)} (synthetic for upstream testing)") - - # Execute the Correction loop - all config comes from config object! - # This will test parameter generation, kernel creation, and geolocation - logger.info("=" * 80) - logger.info("EXECUTING CORRECTION UPSTREAM WORKFLOW") - logger.info("=" * 80) - - results, netcdf_data = correction.loop(config, work_dir, tlm_sci_gcp_sets) - - logger.info("=" * 80) - logger.info("UPSTREAM PIPELINE COMPLETE") - logger.info(f"Processed {len(results)} total iterations") - logger.info(f"Generated results for {len(netcdf_data['parameter_set_id'])} parameter sets") - logger.info("=" * 80) - - # Output file is determined by config and saved by loop() - output_file = work_dir / config.get_output_filename() - logger.info(f"Results saved to: {output_file}") - - # Create results summary dict for consistency with downstream - results_summary = { - "mode": "upstream", - "iterations": n_iterations, - "parameter_sets": len(netcdf_data["parameter_set_id"]), - "status": "complete", - } - - return results, results_summary, output_file - - -# ============================================================================= -# DOWNSTREAM TESTING (Pairing + Image Matching + Error Statistics) -# ============================================================================= - - -def run_downstream_pipeline( - n_iterations: int = 5, test_cases: list[str] | None = None, work_dir: Path | None = None -) -> tuple[list, dict, Path]: - """ - Test DOWNSTREAM segment of Correction pipeline. 
- - IMPORTANT: This test uses a CUSTOM LOOP (not correction.loop()) because it works with - pre-geolocated test data that doesn't have the telemetry/parameters needed for - the normal upstream pipeline. - - Pipeline Comparison: - Normal correction.loop(): Parameters → Kernels → Geolocation → Matching → Stats - This test: Pre-geolocated Test Data → Pairing → Matching → Stats - - Parameter effects are simulated by varying the geolocation errors directly - (bumping lat/lon values), rather than varying parameters and re-running SPICE. - This is the correct approach for testing with pre-geolocated imagery! - - Tests (Real): - - GCP spatial pairing algorithms - - Image matching with real correlation - - Error statistics computation - - Does NOT Test (No Data Available): - - Kernel creation (no telemetry) - - SPICE geolocation (test data is pre-geolocated) - - True parameter sensitivity (simulated via error variation) - - Args: - n_iterations: Number of Correction iterations - test_cases: Specific test cases to use (e.g., ['1', '2']) - work_dir: Working directory for outputs. If None, uses a temporary - directory that will be cleaned up when the process exits. 
- - Returns: - Tuple of (results, netcdf_data, output_path) - """ - logger.info("=" * 80) - logger.info("DOWNSTREAM PIPELINE TEST") - logger.info("Tests: GCP Pairing + Image Matching + Error Statistics") - logger.info("=" * 80) - - root_dir = Path(__file__).parents[2] - generic_dir = root_dir / "data" / "generic" - data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" - test_data_dir = root_dir / "tests" / "data" / "clarreo" / "image_match" - - # Use temporary directory if work_dir not specified - if work_dir is None: - _tmp_dir = tempfile.mkdtemp(prefix="curryer_downstream_") - work_dir = Path(_tmp_dir) - - # Register cleanup to run on process exit - def cleanup_temp_dir(): - if work_dir.exists(): - try: - shutil.rmtree(work_dir) - logger.debug(f"Cleaned up temporary directory: {work_dir}") - except Exception as e: - logger.warning(f"Failed to cleanup {work_dir}: {e}") - - atexit.register(cleanup_temp_dir) - - logger.info(f"Using temporary directory: {work_dir}") - logger.info("(will be cleaned up on process exit)") - else: - work_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"Work directory: {work_dir}") - logger.info(f"Test data directory: {test_data_dir}") - logger.info(f"Iterations: {n_iterations}") - logger.info(f"Test cases: {test_cases or 'all'}") - - # Test configuration (simple parameters - no config object needed) - randomize_errors = True - error_variation_percent = 3.0 - cache_results = True - - # Discover test cases - discovered_cases = discover_test_image_match_cases(test_data_dir, test_cases) - logger.info(f"Discovered {len(discovered_cases)} test case variants") - - # ========================================================================== - # STEP 1: GCP PAIRING - # ========================================================================== - logger.info("\n" + "=" * 80) - logger.info("STEP 1: GCP SPATIAL PAIRING") - logger.info("=" * 80) - - # Load L1A images - l1a_images = [] - l1a_to_testcase = {} - - for test_case in 
discovered_cases: - l1a_img = load_image_grid_from_mat( - test_case["subimage_file"], - key="subimage", - as_named=True, - name=str(test_case["subimage_file"].relative_to(test_data_dir)), - ) - l1a_images.append(l1a_img) - l1a_to_testcase[l1a_img.name] = test_case - - # Load unique GCP references - gcp_files_seen = set() - gcp_images = [] - - for test_case in discovered_cases: - gcp_file = test_case["gcp_file"] - if gcp_file not in gcp_files_seen: - gcp_img = load_image_grid_from_mat( - gcp_file, key="GCP", as_named=True, name=str(gcp_file.relative_to(test_data_dir)) - ) - gcp_images.append(gcp_img) - gcp_files_seen.add(gcp_file) - - logger.info(f"Loaded {len(l1a_images)} L1A images") - logger.info(f"Loaded {len(gcp_images)} unique GCP references") - - # Run spatial pairing - pairing_result = find_l1a_gcp_pairs(l1a_images, gcp_images, max_distance_m=0.0) - logger.info(f"Pairing complete: Found {len(pairing_result.matches)} valid pairs") - - for match in pairing_result.matches: - l1a_name = pairing_result.l1a_images[match.l1a_index].name - gcp_name = pairing_result.gcp_images[match.gcp_index].name - logger.info(f" {l1a_name} ↔ {gcp_name} (distance: {match.distance_m:.1f}m)") - - # Convert to paired test cases - paired_test_cases = [] - for match in pairing_result.matches: - l1a_img = pairing_result.l1a_images[match.l1a_index] - test_case = l1a_to_testcase[l1a_img.name] - paired_test_cases.append(test_case) - - logger.info(f"Created {len(paired_test_cases)} paired test cases") - - # ========================================================================== - # STEP 2: CONFIGURATION - # ========================================================================== - logger.info("\n" + "=" * 80) - logger.info("STEP 2: CONFIGURATION") - logger.info("=" * 80) - - # Create base CLARREO config to get standard settings - base_config = create_clarreo_correction_config(data_dir, generic_dir) - - # Create minimal test config (CorrectionConfig = THE one config) - # For downstream 
testing, we use minimal parameters (sigma=0) because - # variations come from test_mode_config randomization, not parameter tweaking - config = correction.CorrectionConfig( - # Core settings - seed=42, - n_iterations=n_iterations, - # Minimal parameter (no real variation - sigma=0) - parameters=[ - correction.ParameterConfig( - ptype=correction.ParameterType.CONSTANT_KERNEL, - config_file=data_dir / "cprs_hysics_v01.attitude.ck.json", - data={ - "current_value": [0.0, 0.0, 0.0], - "sigma": 0.0, # No parameter variation (test variations applied differently) - "units": "arcseconds", - "transformation_type": "dcm_rotation", - "coordinate_frames": ["HYSICS_SLIT", "CRADLE_ELEVATION"], - }, - ) - ], - # Copy required fields from base_config - geo=base_config.geo, - performance_threshold_m=base_config.performance_threshold_m, - performance_spec_percent=base_config.performance_spec_percent, - earth_radius_m=base_config.earth_radius_m, - # Copy optional fields from base_config - netcdf=base_config.netcdf, - calibration_file_names=base_config.calibration_file_names, - spacecraft_position_name=base_config.spacecraft_position_name, - boresight_name=base_config.boresight_name, - transformation_matrix_name=base_config.transformation_matrix_name, - ) - - # Add loaders to config (Config-Centric Design) - config.telemetry_loader = load_clarreo_telemetry - config.science_loader = load_clarreo_science - config.gcp_loader = load_clarreo_gcp - - # Validate complete config - config.validate(check_loaders=True) - - logger.info(f"Configuration created:") - logger.info(f" Mission: CLARREO (from clarreo_config)") - logger.info(f" Instrument: {config.geo.instrument_name}") - logger.info(f" Iterations: {config.n_iterations}") - logger.info(f" Parameters: {len(config.parameters)} (minimal for test mode)") - logger.info(f" Sigma: 0.0 (variations from randomization, not parameters)") - logger.info(f" Performance threshold: {config.performance_threshold_m}m") - - # 
========================================================================== - # STEP 3: CORRECTION ITERATIONS - # ========================================================================== - logger.info("\n" + "=" * 80) - logger.info("STEP 3: CORRECTION ITERATIONS") - logger.info("=" * 80) - - n_param_sets = n_iterations - n_gcp_pairs = len(paired_test_cases) - - # Use dynamic NetCDF structure builder instead of hardcoded - netcdf_data = correction._build_netcdf_structure(config, n_param_sets, n_gcp_pairs) - logger.info(f"NetCDF structure built dynamically with {len(netcdf_data)} variables") - - # Get threshold metric name dynamically - threshold_metric = config.netcdf.get_threshold_metric_name() - logger.info(f"Using threshold metric: {threshold_metric}") - - image_match_cache = {} - - for param_idx in range(n_iterations): - logger.info(f"\n=== Iteration {param_idx + 1}/{n_iterations} ===") - - image_matching_results = [] - pair_errors = [] - - for pair_idx, test_case in enumerate(paired_test_cases): - logger.info(f"Processing pair {pair_idx + 1}/{len(paired_test_cases)}: {test_case['case_name']}") - - cache_key = f"{test_case['case_id']}_{test_case.get('subcase_name', '')}" - cached_result = image_match_cache.get(cache_key) - - # Run image matching - image_matching_output = run_image_matching_with_applied_errors( - test_case, - param_idx, - randomize_errors=randomize_errors, - error_variation_percent=error_variation_percent, - cache_results=cache_results, - cached_result=cached_result, - ) - - # Cache first result - if cache_key not in image_match_cache: - image_match_cache[cache_key] = image_matching_output - - image_matching_output.attrs["gcp_pair_index"] = pair_idx - image_matching_results.append(image_matching_output) - - # Extract and store metrics - lat_error_m = abs(image_matching_output.attrs["lat_error_km"] * 1000) - lon_error_m = abs(image_matching_output.attrs["lon_error_km"] * 1000) - rms_error_m = np.sqrt(lat_error_m**2 + lon_error_m**2) - 
pair_errors.append(rms_error_m) - - netcdf_data["rms_error_m"][param_idx, pair_idx] = rms_error_m - netcdf_data["mean_error_m"][param_idx, pair_idx] = rms_error_m - netcdf_data["max_error_m"][param_idx, pair_idx] = rms_error_m - netcdf_data["n_measurements"][param_idx, pair_idx] = 1 - netcdf_data["im_lat_error_km"][param_idx, pair_idx] = image_matching_output.attrs["lat_error_km"] - netcdf_data["im_lon_error_km"][param_idx, pair_idx] = image_matching_output.attrs["lon_error_km"] - netcdf_data["im_ccv"][param_idx, pair_idx] = image_matching_output.attrs["correlation_ccv"] - netcdf_data["im_grid_step_m"][param_idx, pair_idx] = image_matching_output.attrs["final_grid_step_m"] - - # Compute aggregate metrics (use dynamic threshold) - pair_errors = np.array(pair_errors) - valid_errors = pair_errors[~np.isnan(pair_errors)] - - if len(valid_errors) > 0: - threshold_value = config.performance_threshold_m - percent_under_threshold = (valid_errors < threshold_value).sum() / len(valid_errors) * 100 - netcdf_data[threshold_metric][param_idx] = percent_under_threshold - netcdf_data["mean_rms_all_pairs"][param_idx] = np.mean(valid_errors) - netcdf_data["best_pair_rms"][param_idx] = np.min(valid_errors) - netcdf_data["worst_pair_rms"][param_idx] = np.max(valid_errors) - - logger.info(f"Iteration {param_idx + 1} complete:") - logger.info(f" {percent_under_threshold:.1f}% under {threshold_value}m threshold") - logger.info(f" Mean RMS: {np.mean(valid_errors):.2f}m") - - # ========================================================================== - # STEP 4: ERROR STATISTICS - # ========================================================================== - logger.info("\n" + "=" * 80) - logger.info("STEP 4: ERROR STATISTICS") - logger.info("=" * 80) - - error_stats = correction.call_error_stats_module(image_matching_results, correction_config=config) - logger.info(f"Error statistics computed: {len(error_stats)} metrics") - - # 
========================================================================== - # STEP 5: SAVE RESULTS - # ========================================================================== - logger.info("\n" + "=" * 80) - logger.info("STEP 5: SAVE RESULTS") - logger.info("=" * 80) - - output_file = work_dir / "downstream_results.nc" - correction._save_netcdf_results(netcdf_data, output_file, config) - - logger.info("=" * 80) - logger.info("DOWNSTREAM TEST COMPLETE") - logger.info("=" * 80) - logger.info(f"Output: {output_file}") - logger.info(f"Iterations: {n_iterations}") - logger.info(f"Test pairs: {n_gcp_pairs}") - - results = {"mode": "downstream", "iterations": n_iterations, "test_pairs": n_gcp_pairs, "status": "complete"} - - return [], results, output_file - - -# ============================================================================= -# UNITTEST TEST CASES -# ============================================================================= - - -class CorrectionUnifiedTests(unittest.TestCase): - """Unified test cases for both upstream and downstream pipelines. - - Note: Only test_upstream_quick requires GMTED elevation data and is marked - with @pytest.mark.extra. All other tests use either config-only validation - or pre-geolocated test data and will run in CI without GMTED files. 
- """ - - def setUp(self): - """Set up test environment.""" - self.root_dir = Path(__file__).parents[2] - self.test_data_dir = self.root_dir / "tests" / "data" / "clarreo" / "image_match" - - self.__tmp_dir = tempfile.TemporaryDirectory() - self.addCleanup(self.__tmp_dir.cleanup) - self.work_dir = Path(self.__tmp_dir.name) - - def test_upstream_configuration(self): - """Test that upstream configuration loads correctly.""" - logger.info("Testing upstream configuration...") - - data_dir = self.root_dir / "tests" / "data" / "clarreo" / "gcs" - generic_dir = self.root_dir / "data" / "generic" - - config = create_clarreo_correction_config(data_dir, generic_dir) - - # Add required loaders (Config-Centric Design) - config.telemetry_loader = load_clarreo_telemetry - config.science_loader = load_clarreo_science - - # Validate config is complete - config.validate(check_loaders=True) - - self.assertEqual(config.geo.instrument_name, "CPRS_HYSICS") - self.assertGreater(len(config.parameters), 0) - self.assertEqual(config.seed, 42) - self.assertIsNotNone(config.telemetry_loader) - self.assertIsNotNone(config.science_loader) - - logger.info(f"✓ Configuration valid: {len(config.parameters)} parameters") - - def test_downstream_test_case_discovery(self): - """Test that downstream test cases can be discovered.""" - logger.info("Testing test case discovery...") - - test_cases = discover_test_image_match_cases(self.test_data_dir) - - self.assertGreater(len(test_cases), 0, "No test cases discovered") - - for tc in test_cases: - self.assertIn("case_id", tc) - self.assertIn("subimage_file", tc) - self.assertIn("gcp_file", tc) - self.assertTrue(tc["subimage_file"].exists()) - self.assertTrue(tc["gcp_file"].exists()) - - logger.info(f"✓ Discovered {len(test_cases)} test cases") - - def test_downstream_image_matching(self): - """Test that downstream image matching works.""" - logger.info("Testing image matching...") - - test_cases = discover_test_image_match_cases(self.test_data_dir, 
test_cases=["1"]) - self.assertGreater(len(test_cases), 0) - - test_case = test_cases[0] - - result = run_image_matching_with_applied_errors( - test_case, param_idx=0, randomize_errors=False, cache_results=True - ) - - self.assertIsInstance(result, xr.Dataset) - self.assertIn("lat_error_km", result.attrs) - self.assertIn("lon_error_km", result.attrs) - - logger.info(f"✓ Image matching successful") - - @pytest.mark.extra - def test_upstream_quick(self): - """Run quick upstream test. - - This test requires GMTED elevation data which is not available in CI. - Run with: pytest --run-extra - """ - logger.info("Running quick upstream test...") - - results_list, results_dict, output_file = run_upstream_pipeline(n_iterations=2, work_dir=self.work_dir) - - self.assertEqual(results_dict["status"], "complete") - self.assertEqual(results_dict["iterations"], 2) - self.assertEqual(results_dict["mode"], "upstream") - self.assertGreater(results_dict["parameter_sets"], 0) - self.assertTrue(output_file.exists()) - - logger.info(f"✓ Quick upstream test complete: {output_file}") - - def test_downstream_quick(self): - """Run quick downstream test.""" - logger.info("Running quick downstream test...") - - results_list, results_dict, output_file = run_downstream_pipeline( - n_iterations=2, test_cases=["1"], work_dir=self.work_dir - ) - - self.assertEqual(results_dict["status"], "complete") - self.assertEqual(results_dict["iterations"], 2) - self.assertTrue(output_file.exists()) - - logger.info(f"✓ Quick downstream test complete: {output_file}") - - def test_synthetic_helpers_basic(self): - """Test that synthetic helper functions work correctly (for coverage).""" - logger.info("Testing synthetic helper functions...") - - # Test synthetic GCP pairing - science_files = ["science_1.nc", "science_2.nc"] - pairs = synthetic_gcp_pairing(science_files) - self.assertEqual(len(pairs), 2) - self.assertIsInstance(pairs, list) - logger.info("✓ synthetic_gcp_pairing works") - - # Test synthetic 
boresights generation - boresights = _generate_synthetic_boresights(5, max_off_nadir_rad=0.07) - self.assertEqual(boresights.shape, (5, 3)) - self.assertTrue(np.all(np.abs(boresights[:, 0]) < 0.01)) # Small x component - logger.info("✓ _generate_synthetic_boresights works") - - # Test synthetic positions generation - positions = _generate_spherical_positions(5, 6.78e6, 4e3) - self.assertEqual(positions.shape, (5, 3)) - radii = np.linalg.norm(positions, axis=1) - self.assertTrue(np.all(radii > 6.7e6)) # Reasonable orbit altitude - logger.info("✓ _generate_spherical_positions works") - - # Test transform generation - transforms = _generate_nadir_aligned_transforms(5, positions, boresights) - self.assertEqual(transforms.shape, (5, 3, 3)) - # Check it's a valid rotation matrix (det should be close to 1) - det = np.linalg.det(transforms[0]) - self.assertAlmostEqual(abs(det), 1.0, places=1) - logger.info("✓ _generate_nadir_aligned_transforms works") - - logger.info("✓ All synthetic helpers validated") - - def test_downstream_helpers_basic(self): - """Test downstream helper functions (for coverage).""" - logger.info("Testing downstream helper functions...") - - # Test test case discovery - test_cases = discover_test_image_match_cases(self.test_data_dir, test_cases=["1"]) - self.assertGreater(len(test_cases), 0) - self.assertIn("case_id", test_cases[0]) - self.assertIn("subimage_file", test_cases[0]) - logger.info(f"✓ discover_test_image_match_cases found {len(test_cases)} cases") - - # Test error variation (create a simple test dataset) - base_result = xr.Dataset( - { - "lat_error_deg": (["measurement"], [0.001]), - "lon_error_deg": (["measurement"], [0.002]), - }, - attrs={"lat_error_km": 0.1, "lon_error_km": 0.2, "correlation_ccv": 0.95}, - ) - - varied_result = apply_error_variation_for_testing(base_result, param_idx=1, error_variation_percent=3.0) - self.assertIsInstance(varied_result, xr.Dataset) - self.assertIn("lat_error_deg", varied_result) - # Check that 
variation was applied (should be different from base) - self.assertNotEqual(varied_result.attrs["lat_error_km"], base_result.attrs["lat_error_km"]) - logger.info("✓ apply_error_variation_for_testing works") - - logger.info("✓ All downstream helpers validated") - - @pytest.mark.extra - def test_loop_optimized(self): - """ - Test loop() function (optimized pair-outer implementation). - - This test validates the main loop() function which is now the optimized - default implementation. It covers the core Correction workflow. - """ - logger.info("=" * 80) - logger.info("TEST: loop() (OPTIMIZED IMPLEMENTATION)") - logger.info("=" * 80) - - # Setup configuration - root_dir = Path(__file__).parents[2] - generic_dir = root_dir / "data" / "generic" - data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" - - config = create_clarreo_correction_config(data_dir, generic_dir) - config.n_iterations = 2 # Small for fast testing - config.output_filename = "test_loop_optimized.nc" - - # Add loaders and processing functions - config.telemetry_loader = load_clarreo_telemetry - config.science_loader = load_clarreo_science - config.gcp_loader = load_clarreo_gcp - config.gcp_pairing_func = synthetic_gcp_pairing - config.image_matching_func = synthetic_image_matching - - # Prepare data sets - tlm_sci_gcp_sets = [ - ("telemetry_5a", "science_5a", "synthetic_gcp_1"), - ] - - work_dir = self.work_dir / "test_loop_optimized" - work_dir.mkdir(exist_ok=True) - - # Run loop() - logger.info("Running loop()...") - np.random.seed(42) - results, netcdf_data = correction.loop(config, work_dir, tlm_sci_gcp_sets, resume_from_checkpoint=False) - - # Validate results structure - self.assertIsInstance(results, list) - self.assertGreater(len(results), 0) - expected_count = config.n_iterations * len(tlm_sci_gcp_sets) - self.assertEqual(len(results), expected_count) - logger.info(f"✓ loop() returned {len(results)} results") - - # Validate NetCDF structure - self.assertIsInstance(netcdf_data, dict) - 
self.assertIn("rms_error_m", netcdf_data) - self.assertIn("parameter_set_id", netcdf_data) - self.assertEqual(netcdf_data["rms_error_m"].shape, (config.n_iterations, len(tlm_sci_gcp_sets))) - logger.info(f"✓ NetCDF structure valid") - - # Validate result contents - for result in results: - self.assertIn("param_index", result) - self.assertIn("pair_index", result) - self.assertIn("rms_error_m", result) - self.assertIn("aggregate_rms_error_m", result) - # Verify aggregate_rms_error_m is populated (not None) - self.assertIsNotNone( - result["aggregate_rms_error_m"], - f"aggregate_rms_error_m should be populated for result {result['iteration']}", - ) - # Verify it's a valid numeric value - self.assertIsInstance(result["aggregate_rms_error_m"], (int, float, np.number)) - logger.info(f"✓ All result entries have required fields") - - logger.info("=" * 80) - logger.info("✓ loop() TEST PASSED") - logger.info("=" * 80) - - def test_helper_extract_parameter_values(self): - """Test _extract_parameter_values helper function.""" - logger.info("Testing _extract_parameter_values...") - - # Create sample params with proper structure for CONSTANT_KERNEL type - param_config = correction.ParameterConfig( - ptype=correction.ParameterType.CONSTANT_KERNEL, config_file=Path("test_kernel.json"), data=None - ) - # Create DataFrame with angle data as expected by the function - param_data = pd.DataFrame( - { - "angle_x": [np.radians(1.0 / 3600)], # 1 arcsec in radians - "angle_y": [np.radians(2.0 / 3600)], # 2 arcsec in radians - "angle_z": [np.radians(3.0 / 3600)], # 3 arcsec in radians - } - ) - params = [(param_config, param_data)] - - result = correction._extract_parameter_values(params) - - self.assertIsInstance(result, dict) - # Should extract 3 values: roll, pitch, yaw - self.assertEqual(len(result), 3) - self.assertIn("test_kernel_roll", result) - self.assertIn("test_kernel_pitch", result) - self.assertIn("test_kernel_yaw", result) - logger.info(f"✓ _extract_parameter_values works 
correctly") - - def test_helper_build_netcdf_structure(self): - """Test _build_netcdf_structure helper function.""" - logger.info("Testing _build_netcdf_structure...") - - # Create minimal config - root_dir = Path(__file__).parents[2] - generic_dir = root_dir / "data" / "generic" - data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" - config = create_clarreo_correction_config(data_dir, generic_dir) - - n_params = 3 - n_pairs = 2 - - netcdf_data = correction._build_netcdf_structure(config, n_params, n_pairs) - - # Validate structure - self.assertIsInstance(netcdf_data, dict) - self.assertIn("rms_error_m", netcdf_data) - self.assertIn("parameter_set_id", netcdf_data) - self.assertEqual(netcdf_data["rms_error_m"].shape, (n_params, n_pairs)) - self.assertEqual(len(netcdf_data["parameter_set_id"]), n_params) - logger.info(f"✓ _build_netcdf_structure creates correct structure") - - def test_helper_extract_error_metrics(self): - """Test _extract_error_metrics helper function.""" - logger.info("Testing _extract_error_metrics...") - - # Create sample error stats dataset with correct attribute name - stats_dataset = xr.Dataset( - { - "measurement": (["point"], [0, 1, 2]), - "lat_error_deg": (["point"], [0.001, 0.002, 0.001]), - "lon_error_deg": (["point"], [0.001, 0.002, 0.001]), - } - ) - stats_dataset.attrs["rms_error_m"] = 150.0 - stats_dataset.attrs["mean_error_m"] = 140.0 - stats_dataset.attrs["max_error_m"] = 200.0 - stats_dataset.attrs["std_error_m"] = 10.0 - stats_dataset.attrs["total_measurements"] = 3 # Correct attribute name - - metrics = correction._extract_error_metrics(stats_dataset) - - # Validate metrics - self.assertIsInstance(metrics, dict) - self.assertIn("rms_error_m", metrics) - self.assertIn("mean_error_m", metrics) - self.assertIn("n_measurements", metrics) - self.assertEqual(metrics["n_measurements"], 3) - self.assertEqual(metrics["rms_error_m"], 150.0) - logger.info(f"✓ _extract_error_metrics extracts metrics correctly") - - def 
test_helper_store_parameter_values(self): - """Test _store_parameter_values helper function.""" - logger.info("Testing _store_parameter_values...") - - # Create netcdf structure with parameter arrays pre-created - # (as _build_netcdf_structure would do) - netcdf_data = { - "parameter_set_id": np.zeros(3, dtype=int), - "param_test_param": np.zeros(3), # Must match naming convention - } - - param_values = {"test_param": 1.5} - param_idx = 1 - - correction._store_parameter_values(netcdf_data, param_idx, param_values) - - # Validate storage - self.assertEqual(netcdf_data["param_test_param"][param_idx], 1.5) - logger.info(f"✓ _store_parameter_values stores correctly") - - def test_helper_store_gcp_pair_results(self): - """Test _store_gcp_pair_results helper function.""" - logger.info("Testing _store_gcp_pair_results...") - - # Create netcdf structure with all required fields - netcdf_data = { - "rms_error_m": np.zeros((2, 2)), - "mean_error_m": np.zeros((2, 2)), - "max_error_m": np.zeros((2, 2)), - "std_error_m": np.zeros((2, 2)), # Must include this - "n_measurements": np.zeros((2, 2), dtype=int), - } - - error_metrics = { - "rms_error_m": 150.0, - "mean_error_m": 140.0, - "max_error_m": 200.0, - "std_error_m": 10.0, # Must include this - "n_measurements": 10, - } - - param_idx = 0 - pair_idx = 1 - - correction._store_gcp_pair_results(netcdf_data, param_idx, pair_idx, error_metrics) - - # Validate storage - self.assertEqual(netcdf_data["rms_error_m"][param_idx, pair_idx], 150.0) - self.assertEqual(netcdf_data["std_error_m"][param_idx, pair_idx], 10.0) - self.assertEqual(netcdf_data["n_measurements"][param_idx, pair_idx], 10) - logger.info(f"✓ _store_gcp_pair_results stores correctly") - - def test_helper_compute_parameter_set_metrics(self): - """Test _compute_parameter_set_metrics helper function.""" - logger.info("Testing _compute_parameter_set_metrics...") - - # Create netcdf structure - netcdf_data = { - "percent_under_250m": np.zeros(2), - "mean_rms_all_pairs": 
np.zeros(2), - "best_pair_rms": np.zeros(2), - "worst_pair_rms": np.zeros(2), - } - - pair_errors = [100.0, 200.0, 300.0] - param_idx = 0 - threshold_m = 250.0 - - correction._compute_parameter_set_metrics(netcdf_data, param_idx, pair_errors, threshold_m) - - # Validate computed metrics - self.assertGreater(netcdf_data["percent_under_250m"][param_idx], 0) - self.assertGreater(netcdf_data["mean_rms_all_pairs"][param_idx], 0) - self.assertEqual(netcdf_data["best_pair_rms"][param_idx], 100.0) - self.assertEqual(netcdf_data["worst_pair_rms"][param_idx], 300.0) - logger.info(f"✓ _compute_parameter_set_metrics computes correctly") - - def test_helper_load_image_pair_data(self): - """Test _load_image_pair_data helper function.""" - logger.info("Testing _load_image_pair_data...") - - # Setup configuration - root_dir = Path(__file__).parents[2] - generic_dir = root_dir / "data" / "generic" - data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" - config = create_clarreo_correction_config(data_dir, generic_dir) - - tlm_dataset, sci_dataset, ugps_times = correction._load_image_pair_data( - "telemetry_5a", "science_5a", config, load_clarreo_telemetry, load_clarreo_science - ) - - # Validate return types - self.assertIsInstance(tlm_dataset, pd.DataFrame) - self.assertIsInstance(sci_dataset, pd.DataFrame) - self.assertIsNotNone(ugps_times) - logger.info(f"✓ _load_image_pair_data loads data correctly") - - def test_helper_create_dynamic_kernels(self): - """Test _create_dynamic_kernels helper function.""" - logger.info("Testing _create_dynamic_kernels...") - - # Setup - root_dir = Path(__file__).parents[2] - generic_dir = root_dir / "data" / "generic" - data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" - config = create_clarreo_correction_config(data_dir, generic_dir) - - work_dir = self.work_dir / "test_dynamic_kernels" - work_dir.mkdir(exist_ok=True) - - # Load data - tlm_dataset = load_clarreo_telemetry("telemetry_5a", config) - creator = 
create.KernelCreator(overwrite=True, append=False) - - # Load SPICE kernels needed for kernel creation - # (frame kernel defines ISS_SC body which is needed by ephemeris writer) - mkrn = meta.MetaKernel.from_json( - config.geo.meta_kernel_file, - relative=True, - sds_dir=config.geo.generic_kernel_dir, - ) - with sp.ext.load_kernel([mkrn.sds_kernels, mkrn.mission_kernels]): - dynamic_kernels = correction._create_dynamic_kernels(config, work_dir, tlm_dataset, creator) - - # Validate - self.assertIsInstance(dynamic_kernels, list) - logger.info(f"✓ _create_dynamic_kernels creates {len(dynamic_kernels)} kernels") - - def test_helper_load_calibration_data(self): - """Test _load_calibration_data helper function.""" - logger.info("Testing _load_calibration_data...") - - # Setup minimal config without calibration dir - root_dir = Path(__file__).parents[2] - generic_dir = root_dir / "data" / "generic" - data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" - config = create_clarreo_correction_config(data_dir, generic_dir) - config.calibration_dir = None - - # Load LOS vectors and PSF data into calibration data. 
- calibration_data = correction._load_calibration_data(config) - - # Should return CalibrationData with None values when no calibration dir - self.assertIsNone(calibration_data.los_vectors) - self.assertIsNone(calibration_data.optical_psfs) - logger.info(f"✓ _load_calibration_data handles None calibration_dir") - - def test_checkpoint_save_load(self): - """Test checkpoint save and load functionality.""" - logger.info("=" * 80) - logger.info("TEST: Checkpoint Save/Load") - logger.info("=" * 80) - - # Setup configuration - root_dir = Path(__file__).parents[2] - generic_dir = root_dir / "data" / "generic" - data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" - - config = create_clarreo_correction_config(data_dir, generic_dir) - config.n_iterations = 2 - config.output_filename = "test_checkpoint.nc" - - # Add loaders - config.telemetry_loader = load_clarreo_telemetry - config.science_loader = load_clarreo_science - config.gcp_loader = load_clarreo_gcp - config.gcp_pairing_func = synthetic_gcp_pairing - config.image_matching_func = synthetic_image_matching - - work_dir = self.work_dir / "test_checkpoint" - work_dir.mkdir(exist_ok=True) - - output_file = work_dir / config.output_filename - - # Build simple netcdf structure for testing - netcdf_data = correction._build_netcdf_structure(config, 2, 2) - netcdf_data["rms_error_m"][0, 0] = 100.0 - netcdf_data["rms_error_m"][1, 0] = 150.0 - - # Save checkpoint - logger.info("Saving checkpoint...") - correction._save_netcdf_checkpoint(netcdf_data, output_file, config, pair_idx_completed=0) - - checkpoint_file = output_file.parent / f"{output_file.stem}_checkpoint.nc" - self.assertTrue(checkpoint_file.exists()) - logger.info(f"✓ Checkpoint file created: {checkpoint_file}") - - # Load checkpoint - logger.info("Loading checkpoint...") - loaded_data, completed_pairs = correction._load_checkpoint(output_file, config) - - self.assertIsNotNone(loaded_data) - self.assertEqual(completed_pairs, 1) # 0-indexed, so pair 0 completed 
= 1 - self.assertEqual(loaded_data["rms_error_m"][0, 0], 100.0) - self.assertEqual(loaded_data["rms_error_m"][1, 0], 150.0) - logger.info(f"✓ Checkpoint loaded correctly, completed pairs = {completed_pairs}") - - # Cleanup - correction._cleanup_checkpoint(output_file) - self.assertFalse(checkpoint_file.exists()) - logger.info(f"✓ Checkpoint cleanup successful") - - logger.info("=" * 80) - logger.info("✓ CHECKPOINT SAVE/LOAD TEST PASSED") - logger.info("=" * 80) - - def test_apply_offset_function(self): - """Test apply_offset function for all parameter types.""" - logger.info("=" * 80) - logger.info("TEST: apply_offset() Function") - logger.info("=" * 80) - - # ========== Test 1: OFFSET_KERNEL with arcseconds ========== - logger.info("\nTest 1: OFFSET_KERNEL with arcseconds unit conversion") - - # Create realistic telemetry data matching CLARREO structure - telemetry_data = pd.DataFrame( - { - "frame": range(5), - "hps.az_ang_nonlin": [1.14252] * 5, - "hps.el_ang_nonlin": [-0.55009] * 5, - "hps.resolver_tms": [1168477154.0 + i for i in range(5)], - "ert": [1431903180.58 + i for i in range(5)], - } - ) - - # Create parameter config for azimuth angle offset - az_param_config = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("cprs_az_v01.attitude.ck.json"), - data=dict( - field="hps.az_ang_nonlin", - units="arcseconds", - ), - ) - - # Apply offset of 100 arcseconds - offset_arcsec = 100.0 - original_mean = telemetry_data["hps.az_ang_nonlin"].mean() - modified_data = correction.apply_offset(az_param_config, offset_arcsec, telemetry_data) - - # Verify offset was applied correctly - expected_offset_rad = np.deg2rad(offset_arcsec / 3600.0) - actual_delta = modified_data["hps.az_ang_nonlin"].mean() - original_mean - self.assertAlmostEqual(actual_delta, expected_offset_rad, places=9) - self.assertIsInstance(modified_data, pd.DataFrame) - logger.info(f"✓ OFFSET_KERNEL (arcseconds): {offset_arcsec} arcsec = 
{expected_offset_rad:.9f} rad") - - # ========== Test 2: OFFSET_KERNEL with elevation angle ========== - logger.info("\nTest 2: OFFSET_KERNEL with elevation angle (negative offset)") - - el_param_config = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("cprs_el_v01.attitude.ck.json"), - data=dict( - field="hps.el_ang_nonlin", - units="arcseconds", - ), - ) - - # Apply negative offset - offset_arcsec = -50.0 - original_mean = telemetry_data["hps.el_ang_nonlin"].mean() - modified_data = correction.apply_offset(el_param_config, offset_arcsec, telemetry_data) - - # Verify - expected_offset_rad = np.deg2rad(offset_arcsec / 3600.0) - actual_delta = modified_data["hps.el_ang_nonlin"].mean() - original_mean - self.assertAlmostEqual(actual_delta, expected_offset_rad, places=9) - logger.info(f"✓ OFFSET_KERNEL (negative): {offset_arcsec} arcsec = {expected_offset_rad:.9f} rad") - - # ========== Test 3: OFFSET_KERNEL with non-existent field ========== - logger.info("\nTest 3: OFFSET_KERNEL with non-existent field (should warn)") - - bad_param_config = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("dummy.json"), - data=dict( - field="nonexistent_field", - units="arcseconds", - ), - ) - - # Should return unmodified data when field not found - modified_data = correction.apply_offset(bad_param_config, 10.0, telemetry_data) - self.assertIsInstance(modified_data, pd.DataFrame) - # Data should be unchanged - pd.testing.assert_frame_equal(modified_data, telemetry_data) - logger.info("✓ OFFSET_KERNEL correctly handles missing field") - - # ========== Test 4: OFFSET_TIME with milliseconds ========== - logger.info("\nTest 4: OFFSET_TIME with milliseconds unit conversion") - - science_data = pd.DataFrame( - { - "corrected_timestamp": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0], - "measurement": [1.0, 2.0, 3.0, 4.0, 5.0], - } - ) - - time_param_config = correction.ParameterConfig( - 
ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="milliseconds", - ), - ) - - # Apply time offset - value must be in seconds - # This simulates the production flow where load_param_sets converts to seconds - offset_ms = 10.0 - offset_seconds = offset_ms / 1000.0 # Convert to internal units (seconds) - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(time_param_config, offset_seconds, science_data) - - # Verify offset was applied correctly (seconds -> microseconds) - expected_offset_us = offset_ms * 1000.0 # 10 ms = 10000 µs - actual_delta = modified_data["corrected_timestamp"].mean() - original_mean - self.assertAlmostEqual(actual_delta, expected_offset_us, places=6) - logger.info(f"✓ OFFSET_TIME: {offset_ms} ms = {expected_offset_us:.6f} µs") - - # ========== Test 5: OFFSET_TIME with negative offset ========== - logger.info("\nTest 5: OFFSET_TIME with negative offset") - - offset_ms = -5.5 - offset_seconds = offset_ms / 1000.0 # Convert to internal units (seconds) - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(time_param_config, offset_seconds, science_data) - - # Verify - expected_offset_us = offset_ms * 1000.0 - actual_delta = modified_data["corrected_timestamp"].mean() - original_mean - self.assertAlmostEqual(actual_delta, expected_offset_us, places=6) - logger.info(f"✓ OFFSET_TIME (negative): {offset_ms} ms = {expected_offset_us:.6f} µs") - - # ========== Test 6: CONSTANT_KERNEL (pass-through) ========== - logger.info("\nTest 6: CONSTANT_KERNEL (pass-through, no modification)") - - constant_kernel_data = pd.DataFrame( - { - "ugps": [1000000, 2000000], - "angle_x": [0.001, 0.001], - "angle_y": [0.002, 0.002], - "angle_z": [0.003, 0.003], - } - ) - - constant_param_config = correction.ParameterConfig( - ptype=correction.ParameterType.CONSTANT_KERNEL, - 
config_file=Path("cprs_base_v01.attitude.ck.json"), - data=dict( - field="cprs_base", - ), - ) - - # For CONSTANT_KERNEL, param_data is already the kernel data - modified_data = correction.apply_offset(constant_param_config, constant_kernel_data, pd.DataFrame()) - - # Should return the constant kernel data unchanged - self.assertIsInstance(modified_data, pd.DataFrame) - pd.testing.assert_frame_equal(modified_data, constant_kernel_data) - logger.info("✓ CONSTANT_KERNEL returns data unchanged") - - # ========== Test 7: OFFSET_KERNEL without units ========== - logger.info("\nTest 7: OFFSET_KERNEL without unit conversion") - - param_no_units = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("test.json"), - data=dict( - field="hps.az_ang_nonlin", - # No units specified - ), - ) - - # Apply offset directly (no conversion) - offset_value = 0.001 # radians - original_mean = telemetry_data["hps.az_ang_nonlin"].mean() - modified_data = correction.apply_offset(param_no_units, offset_value, telemetry_data) - - # Verify offset applied directly without conversion - actual_delta = modified_data["hps.az_ang_nonlin"].mean() - original_mean - self.assertAlmostEqual(actual_delta, offset_value, places=9) - logger.info(f"✓ OFFSET_KERNEL (no units): {offset_value} rad applied directly") - - # ========== Test 8: Data is not modified in place ========== - logger.info("\nTest 8: Original data is not modified in place") - - original_telemetry = telemetry_data.copy() - modified_data = correction.apply_offset(az_param_config, 100.0, telemetry_data) - - # Verify original data unchanged - pd.testing.assert_frame_equal(telemetry_data, original_telemetry) - # Verify modified data is different - self.assertFalse(modified_data["hps.az_ang_nonlin"].equals(original_telemetry["hps.az_ang_nonlin"])) - logger.info("✓ Original data not modified (proper copy made)") - - # ========== Test 9: Multiple columns preserved ========== - logger.info("\nTest 9: All 
DataFrame columns preserved after offset") - - modified_data = correction.apply_offset(az_param_config, 50.0, telemetry_data) - - # Verify all columns still present - self.assertEqual(set(modified_data.columns), set(telemetry_data.columns)) - # Verify only target column modified - self.assertTrue(modified_data["frame"].equals(telemetry_data["frame"])) - self.assertTrue(modified_data["ert"].equals(telemetry_data["ert"])) - self.assertFalse(modified_data["hps.az_ang_nonlin"].equals(telemetry_data["hps.az_ang_nonlin"])) - logger.info("✓ All DataFrame columns preserved, only target modified") - - logger.info("\n" + "=" * 80) - logger.info("✓ apply_offset() TEST PASSED") - logger.info(" - OFFSET_KERNEL with arcseconds conversion ✓") - logger.info(" - OFFSET_KERNEL with negative values ✓") - logger.info(" - OFFSET_KERNEL missing field handling ✓") - logger.info(" - OFFSET_KERNEL without units ✓") - logger.info(" - OFFSET_TIME with milliseconds conversion ✓") - logger.info(" - OFFSET_TIME with negative values ✓") - logger.info(" - CONSTANT_KERNEL pass-through ✓") - logger.info(" - Data not modified in place ✓") - logger.info(" - All columns preserved ✓") - logger.info("=" * 80) - - def test_helper_load_param_sets(self): - """Test load_param_sets function for parameter set generation.""" - logger.info("=" * 80) - logger.info("TEST: load_param_sets() Function") - logger.info("=" * 80) - - # Create minimal config with different parameter types - data_dir = self.root_dir / "tests" / "data" / "clarreo" / "gcs" - generic_dir = self.root_dir / "data" / "generic" - - # ========== Test 1: OFFSET_KERNEL parameter with arcseconds ========== - logger.info("\nTest 1: OFFSET_KERNEL parameter generation with arcseconds") - - offset_kernel_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("cprs_az_v01.attitude.ck.json"), - data=dict( - field="hps.az_ang_nonlin", - units="arcseconds", - current_value=0.0, # Starting at zero - 
sigma=100.0, # ±100 arcseconds standard deviation - bounds=[-200.0, 200.0], # ±200 arcseconds limits - ), - ) - - config = correction.CorrectionConfig( - seed=42, # For reproducibility - n_iterations=3, - parameters=[offset_kernel_param], - geo=correction.GeolocationConfig( - meta_kernel_file=data_dir / "meta_kernel.tm", - generic_kernel_dir=generic_dir, - dynamic_kernels=[], - instrument_name="CPRS_HYSICS", - time_field="corrected_timestamp", - ), - performance_threshold_m=250.0, - performance_spec_percent=39.0, - earth_radius_m=6378137.0, - ) - - param_sets = correction.load_param_sets(config) - - # Validate structure - self.assertEqual(len(param_sets), 3, "Should generate 3 parameter sets") - self.assertEqual(len(param_sets[0]), 1, "Each set should have 1 parameter") - - # Check each parameter set - for i, param_set in enumerate(param_sets): - param_config, param_value = param_set[0] - self.assertEqual(param_config.ptype, correction.ParameterType.OFFSET_KERNEL) - self.assertIsInstance(param_value, float, "OFFSET_KERNEL should produce float value") - # Value should be in radians (converted from arcseconds) - self.assertLess(abs(param_value), np.deg2rad(200.0 / 3600.0), "Value should be within bounds") - logger.info(f" Set {i}: {param_value:.9f} rad ({np.rad2deg(param_value) * 3600.0:.3f} arcsec)") - - logger.info("✓ OFFSET_KERNEL parameter generation works correctly") - - # ========== Test 2: OFFSET_TIME parameter with milliseconds ========== - logger.info("\nTest 2: OFFSET_TIME parameter generation with milliseconds") - - offset_time_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="milliseconds", - current_value=0.0, - sigma=10.0, # ±10 ms standard deviation - bounds=[-50.0, 50.0], # ±50 ms limits - ), - ) - - config.parameters = [offset_time_param] - config.n_iterations = 3 - - param_sets = correction.load_param_sets(config) - - 
self.assertEqual(len(param_sets), 3) - for i, param_set in enumerate(param_sets): - param_config, param_value = param_set[0] - self.assertEqual(param_config.ptype, correction.ParameterType.OFFSET_TIME) - self.assertIsInstance(param_value, float, "OFFSET_TIME should produce float value") - # Value should be in seconds (converted from milliseconds) - self.assertLess(abs(param_value), 0.050, "Value should be within bounds (50 ms = 0.050 s)") - logger.info(f" Set {i}: {param_value:.6f} s ({param_value * 1000.0:.3f} ms)") - - logger.info("✓ OFFSET_TIME parameter generation works correctly") - - # ========== Test 3: CONSTANT_KERNEL parameter with 3D angles ========== - logger.info("\nTest 3: CONSTANT_KERNEL parameter generation with 3D angles") - - constant_kernel_param = correction.ParameterConfig( - ptype=correction.ParameterType.CONSTANT_KERNEL, - config_file=data_dir / "cprs_base_v01.attitude.ck.json", - data=dict( - field="cprs_base", - units="arcseconds", - current_value=[0.0, 0.0, 0.0], # [roll, pitch, yaw] - sigma=50.0, # ±50 arcseconds for each axis - bounds=[-100.0, 100.0], # ±100 arcseconds limits - ), - ) - - config.parameters = [constant_kernel_param] - config.n_iterations = 2 - - param_sets = correction.load_param_sets(config) - - self.assertEqual(len(param_sets), 2) - for i, param_set in enumerate(param_sets): - param_config, param_value = param_set[0] - self.assertEqual(param_config.ptype, correction.ParameterType.CONSTANT_KERNEL) - self.assertIsInstance(param_value, pd.DataFrame, "CONSTANT_KERNEL should produce DataFrame") - self.assertIn("angle_x", param_value.columns) - self.assertIn("angle_y", param_value.columns) - self.assertIn("angle_z", param_value.columns) - self.assertIn("ugps", param_value.columns) - - # Check each angle is within bounds (in radians) - max_bound_rad = np.deg2rad(100.0 / 3600.0) - for angle_col in ["angle_x", "angle_y", "angle_z"]: - angle_val = param_value[angle_col].iloc[0] - self.assertLess(abs(angle_val), max_bound_rad, 
f"{angle_col} should be within bounds") - - logger.info( - f" Set {i}: roll={param_value['angle_x'].iloc[0]:.9f}, " - f"pitch={param_value['angle_y'].iloc[0]:.9f}, " - f"yaw={param_value['angle_z'].iloc[0]:.9f} rad" - ) - - logger.info("✓ CONSTANT_KERNEL parameter generation works correctly") - - # ========== Test 4: Multiple parameters together ========== - logger.info("\nTest 4: Multiple parameters in single config") - - config.parameters = [offset_kernel_param, offset_time_param, constant_kernel_param] - config.n_iterations = 2 - - param_sets = correction.load_param_sets(config) - - self.assertEqual(len(param_sets), 2, "Should generate 2 parameter sets") - self.assertEqual(len(param_sets[0]), 3, "Each set should have 3 parameters") - - # Verify each parameter type is present - for i, param_set in enumerate(param_sets): - types_found = [p[0].ptype for p in param_set] - self.assertIn(correction.ParameterType.OFFSET_KERNEL, types_found) - self.assertIn(correction.ParameterType.OFFSET_TIME, types_found) - self.assertIn(correction.ParameterType.CONSTANT_KERNEL, types_found) - logger.info(f" Set {i}: Contains all 3 parameter types ✓") - - logger.info("✓ Multiple parameters handled correctly") - - # ========== Test 5: Fixed parameter (sigma=0) ========== - logger.info("\nTest 5: Fixed parameter with sigma=0") - - fixed_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("fixed.json"), - data=dict( - field="fixed_field", - units="arcseconds", - current_value=25.0, # Fixed at 25 arcseconds - sigma=0.0, # No variation - bounds=[-100.0, 100.0], - ), - ) - - config.parameters = [fixed_param] - config.n_iterations = 3 - - param_sets = correction.load_param_sets(config) - - expected_value_rad = np.deg2rad(25.0 / 3600.0) - for i, param_set in enumerate(param_sets): - param_config, param_value = param_set[0] - self.assertAlmostEqual(param_value, expected_value_rad, places=12, msg="Fixed parameter should not vary") - 
logger.info(f" Set {i}: {param_value:.9f} rad (constant)") - - logger.info("✓ Fixed parameter (sigma=0) works correctly") - - # ========== Test 6: Seed reproducibility ========== - logger.info("\nTest 6: Random seed reproducibility") - - config.parameters = [offset_kernel_param] - config.n_iterations = 3 - config.seed = 123 - - param_sets_1 = correction.load_param_sets(config) - - # Reset and generate again with same seed - config.seed = 123 - param_sets_2 = correction.load_param_sets(config) - - # Should produce identical values - for i in range(len(param_sets_1)): - val_1 = param_sets_1[i][0][1] - val_2 = param_sets_2[i][0][1] - self.assertAlmostEqual(val_1, val_2, places=12, msg=f"Set {i} should be identical with same seed") - logger.info(f" Set {i}: {val_1:.9f} rad (reproducible)") - - logger.info("✓ Random seed reproducibility verified") - - # ========== Test 7: Parameter without sigma (should use current_value) ========== - logger.info("\nTest 7: Parameter without sigma field") - - no_sigma_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("no_sigma.json"), - data=dict( - field="test_field", - units="arcseconds", - current_value=15.0, - # No sigma specified - bounds=[-100.0, 100.0], - ), - ) - - config.parameters = [no_sigma_param] - config.n_iterations = 3 - - param_sets = correction.load_param_sets(config) - - expected_value_rad = np.deg2rad(15.0 / 3600.0) - for i, param_set in enumerate(param_sets): - param_config, param_value = param_set[0] - self.assertAlmostEqual( - param_value, expected_value_rad, places=12, msg="Parameter without sigma should use current_value" - ) - - logger.info("✓ Parameter without sigma uses current_value correctly") - - logger.info("\n" + "=" * 80) - logger.info("✓ load_param_sets() TEST PASSED") - logger.info(" - OFFSET_KERNEL generation ✓") - logger.info(" - OFFSET_TIME generation ✓") - logger.info(" - CONSTANT_KERNEL generation ✓") - logger.info(" - Multiple parameters ✓") - 
logger.info(" - Fixed parameters (sigma=0) ✓") - logger.info(" - Seed reproducibility ✓") - logger.info(" - Parameters without sigma ✓") - logger.info("=" * 80) - - def test_offset_time_unit_conversion_integration(self): - """Test the full integration of load_param_sets -> apply_offset for OFFSET_TIME with all unit types. - - This test verifies that: - 1. load_param_sets correctly converts milliseconds/microseconds -> seconds - 2. apply_offset correctly converts seconds -> microseconds for the timestamp field - 3. The end-to-end pipeline produces correct results - 4. All unit conversion paths are exercised (including uncovered lines) - """ - logger.info("=" * 80) - logger.info("TEST: OFFSET_TIME Unit Conversion Integration (load_param_sets -> apply_offset)") - logger.info("=" * 80) - - data_dir = self.root_dir / "tests" / "data" / "clarreo" / "gcs" - generic_dir = self.root_dir / "data" / "generic" - - # Test data with timestamps in microseconds (typical format) - science_data = pd.DataFrame( - { - "corrected_timestamp": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0], - "measurement": [1.0, 2.0, 3.0, 4.0, 5.0], - } - ) - - # ========== Test 1: Milliseconds with sigma ========== - logger.info("\nTest 1: OFFSET_TIME with milliseconds unit (with sigma)") - - offset_time_ms_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="milliseconds", - current_value=10.0, # 10 milliseconds - sigma=2.0, # ±2 ms variation - bounds=[-50.0, 50.0], # ±50 ms limits - ), - ) - - config = correction.CorrectionConfig( - seed=42, - n_iterations=1, - parameters=[offset_time_ms_param], - geo=correction.GeolocationConfig( - meta_kernel_file=data_dir / "meta_kernel.tm", - generic_kernel_dir=generic_dir, - dynamic_kernels=[], - instrument_name="CPRS_HYSICS", - time_field="corrected_timestamp", - ), - performance_threshold_m=250.0, - performance_spec_percent=39.0, - 
earth_radius_m=6378137.0, - ) - - # Step 1: load_param_sets converts milliseconds -> seconds - param_sets = correction.load_param_sets(config) - self.assertEqual(len(param_sets), 1) - param_config, param_value_seconds = param_sets[0][0] - - # Verify conversion to seconds - self.assertLess(abs(param_value_seconds), 0.050, "Value should be in seconds, within 50ms bound") - logger.info(f" load_param_sets output: {param_value_seconds:.6f} s = {param_value_seconds * 1000.0:.3f} ms") - - # Step 2: apply_offset converts seconds -> microseconds - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(param_config, param_value_seconds, science_data) - - # Verify the offset was applied correctly - actual_delta_us = modified_data["corrected_timestamp"].mean() - original_mean - expected_delta_us = param_value_seconds * 1000000.0 # seconds -> microseconds - self.assertAlmostEqual(actual_delta_us, expected_delta_us, places=3) - logger.info(f" Expected delta: {expected_delta_us:.3f} µs") - logger.info(f" Actual delta: {actual_delta_us:.3f} µs") - logger.info("✓ Milliseconds path works correctly (load_param_sets -> apply_offset)") - - # ========== Test 2: Microseconds with sigma (covers lines 1442-1446) ========== - logger.info("\nTest 2: OFFSET_TIME with microseconds unit (with sigma)") - - offset_time_us_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="microseconds", - current_value=5000.0, # 5000 microseconds = 5 ms - sigma=1000.0, # ±1000 µs variation - bounds=[-10000.0, 10000.0], # ±10000 µs limits - ), - ) - - config.parameters = [offset_time_us_param] - - # Step 1: load_param_sets converts microseconds -> seconds - param_sets = correction.load_param_sets(config) - param_config, param_value_seconds = param_sets[0][0] - - # Verify conversion to seconds - self.assertLess(abs(param_value_seconds), 0.010, "Value should be in 
seconds, within 10ms bound") - logger.info(f" load_param_sets output: {param_value_seconds:.6f} s = {param_value_seconds * 1000000.0:.1f} µs") - - # Step 2: apply_offset converts seconds -> microseconds - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(param_config, param_value_seconds, science_data) - - # Verify the offset was applied correctly - actual_delta_us = modified_data["corrected_timestamp"].mean() - original_mean - expected_delta_us = param_value_seconds * 1000000.0 - self.assertAlmostEqual(actual_delta_us, expected_delta_us, places=3) - logger.info(f" Expected delta: {expected_delta_us:.3f} µs") - logger.info(f" Actual delta: {actual_delta_us:.3f} µs") - logger.info("✓ Microseconds path works correctly (load_param_sets -> apply_offset)") - - # ========== Test 3: Seconds unit (baseline) ========== - logger.info("\nTest 3: OFFSET_TIME with seconds unit (baseline)") - - offset_time_s_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="seconds", - current_value=0.008, # 8 milliseconds - sigma=0.002, # ±2 ms variation - bounds=[-0.050, 0.050], # ±50 ms limits - ), - ) - - config.parameters = [offset_time_s_param] - - # Step 1: load_param_sets (no conversion needed, already in seconds) - param_sets = correction.load_param_sets(config) - param_config, param_value_seconds = param_sets[0][0] - - self.assertLess(abs(param_value_seconds), 0.050, "Value should be in seconds") - logger.info(f" load_param_sets output: {param_value_seconds:.6f} s") - - # Step 2: apply_offset converts seconds -> microseconds - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(param_config, param_value_seconds, science_data) - - # Verify the offset was applied correctly - actual_delta_us = modified_data["corrected_timestamp"].mean() - original_mean - expected_delta_us = 
param_value_seconds * 1000000.0 - self.assertAlmostEqual(actual_delta_us, expected_delta_us, places=3) - logger.info(f" Expected delta: {expected_delta_us:.3f} µs") - logger.info(f" Actual delta: {actual_delta_us:.3f} µs") - logger.info("✓ Seconds path works correctly (load_param_sets -> apply_offset)") - - # ========== Test 4: Fixed offset (sigma=0) with milliseconds ========== - logger.info("\nTest 4: OFFSET_TIME fixed offset (sigma=0) with milliseconds") - - fixed_time_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="milliseconds", - current_value=15.0, # Fixed 15 ms - sigma=0.0, # No variation - bounds=[-50.0, 50.0], - ), - ) - - config.parameters = [fixed_time_param] - config.n_iterations = 3 - - # Generate multiple sets - all should be identical - param_sets = correction.load_param_sets(config) - self.assertEqual(len(param_sets), 3) - - expected_seconds = 15.0 / 1000.0 # 15 ms = 0.015 s - for i, param_set in enumerate(param_sets): - param_config, param_value_seconds = param_set[0] - self.assertAlmostEqual(param_value_seconds, expected_seconds, places=9) - logger.info(f" Set {i}: {param_value_seconds:.6f} s (constant)") - - # Apply to data and verify - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(param_config, expected_seconds, science_data) - - actual_delta_us = modified_data["corrected_timestamp"].mean() - original_mean - expected_delta_us = 15000.0 # 15 ms = 15000 µs - self.assertAlmostEqual(actual_delta_us, expected_delta_us, places=3) - logger.info(f" Applied offset: {actual_delta_us:.3f} µs (expected {expected_delta_us:.3f} µs)") - logger.info("✓ Fixed offset with milliseconds works correctly") - - # ========== Test 5: Fixed offset (sigma=0) with microseconds ========== - logger.info("\nTest 5: OFFSET_TIME fixed offset (sigma=0) with microseconds") - - fixed_time_us_param = 
correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="microseconds", - current_value=7500.0, # Fixed 7500 µs = 7.5 ms - sigma=0.0, # No variation - bounds=[-50000.0, 50000.0], - ), - ) - - config.parameters = [fixed_time_us_param] - config.n_iterations = 2 - - param_sets = correction.load_param_sets(config) - expected_seconds = 7500.0 / 1000000.0 # 7500 µs = 0.0075 s - - for i, param_set in enumerate(param_sets): - param_config, param_value_seconds = param_set[0] - self.assertAlmostEqual(param_value_seconds, expected_seconds, places=9) - logger.info(f" Set {i}: {param_value_seconds:.6f} s (constant)") - - # Apply and verify - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(param_config, expected_seconds, science_data) - - actual_delta_us = modified_data["corrected_timestamp"].mean() - original_mean - expected_delta_us = 7500.0 # 7500 µs - self.assertAlmostEqual(actual_delta_us, expected_delta_us, places=3) - logger.info(f" Applied offset: {actual_delta_us:.3f} µs (expected {expected_delta_us:.3f} µs)") - logger.info("✓ Fixed offset with microseconds works correctly") - - # ========== Summary ========== - logger.info("\n" + "=" * 80) - logger.info("✓ OFFSET_TIME UNIT CONVERSION INTEGRATION TEST PASSED") - logger.info(" - Integration: load_param_sets -> apply_offset ✓") - logger.info("=" * 80) - - -# ============================================================================= -# MAIN ENTRY POINT -# ============================================================================= - - -def main(): - """Main entry point for standalone execution.""" - parser = argparse.ArgumentParser( - description="Unified Correction Test Suite", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Test Modes: ------------ -upstream - Test kernel creation + geolocation -downstream - Test pairing + matching + error statistics 
-unittest - Run all unit tests - -Examples: ---------- -# Run downstream test -python test_correction.py --mode downstream --quick - -# Run with specific test cases -python test_correction.py --mode downstream --test-cases 1 2 --iterations 10 - -# Run unit tests -python test_correction.py --mode unittest -pytest test_correction.py -v - -# Run upstream test (when implemented) -python test_correction.py --mode upstream --iterations 5 - """, - ) - - parser.add_argument( - "--mode", type=str, choices=["upstream", "downstream", "unittest"], required=True, help="Test mode to run" - ) - parser.add_argument("--quick", action="store_true", help="Quick test (2 iterations, test case 1 only)") - parser.add_argument("--iterations", type=int, default=5, help="Number of correction iterations (default: 5)") - parser.add_argument( - "--test-cases", nargs="+", default=None, help="Specific test cases for downstream mode (e.g., 1 2 3)" - ) - parser.add_argument("--output-dir", type=str, default=None, help="Output directory for results") - - args = parser.parse_args() - - # Setup logging - utils.enable_logging(log_level=logging.INFO, extra_loggers=[__name__]) - - if args.mode == "unittest": - # Run unit tests - unittest.main(argv=[""], exit=True) - - elif args.mode == "upstream": - # Run upstream test - if args.quick: - n_iterations = 2 - else: - n_iterations = args.iterations - - work_dir = Path(args.output_dir) if args.output_dir else None - - results_list, results_dict, output_file = run_upstream_pipeline(n_iterations=n_iterations, work_dir=work_dir) - - logger.info(f"\n✅ Upstream test complete!") - logger.info(f"Status: {results_dict['status']}") - logger.info(f"Iterations: {results_dict['iterations']}") - logger.info(f"Parameter sets: {results_dict['parameter_sets']}") - logger.info(f"Output file: {output_file}") - - elif args.mode == "downstream": - # Run downstream test - if args.quick: - n_iterations = 2 - test_cases = ["1"] - else: - n_iterations = args.iterations - test_cases 
= args.test_cases - - work_dir = Path(args.output_dir) if args.output_dir else None - - results_list, results_dict, output_file = run_downstream_pipeline( - n_iterations=n_iterations, test_cases=test_cases, work_dir=work_dir - ) - - logger.info(f"\n✅ Downstream test complete!") - logger.info(f"Output: {output_file}") - - # Validate results - ds = xr.open_dataset(output_file) - logger.info(f"NetCDF dimensions: {dict(ds.sizes)}") - assert not np.all(np.isnan(ds["im_lat_error_km"].values)), "No data stored" - logger.info("✅ Validation passed!") - - -if __name__ == "__main__": - main() diff --git a/tests/test_correction/test_dataio.py b/tests/test_correction/test_dataio.py index b2dc2d22..81910525 100644 --- a/tests/test_correction/test_dataio.py +++ b/tests/test_correction/test_dataio.py @@ -1,42 +1,17 @@ -""" -Tests for dataio.py module - -This module tests data I/O functionality: -- S3 object discovery and download -- NetCDF file handling -- Configuration management -- File path operations - -Running Tests: -------------- -# Via pytest (recommended) -pytest tests/test_correction/test_dataio.py -v - -# Run specific test -pytest tests/test_correction/test_dataio.py::DataIOTestCase::test_find_objects -v +"""Tests for ``curryer.correction.dataio`` (generic, no AWS credentials needed). -# Standalone execution -python tests/test_correction/test_dataio.py - -Notes: ------ -These tests use mock S3 clients to avoid requiring AWS credentials -or network access during testing. +CLARREO-specific S3 integration tests live in ``clarreo/test_clarreo_dataio.py``. 
""" from __future__ import annotations import datetime as dt import logging -import os -import tempfile -import unittest from pathlib import Path import pandas as pd import pytest -from curryer import utils from curryer.correction.dataio import ( S3Configuration, download_netcdf_objects, @@ -46,23 +21,24 @@ ) logger = logging.getLogger(__name__) -utils.enable_logging(log_level=logging.INFO, extra_loggers=[__name__]) + + +# ── FakeS3Client ────────────────────────────────────────────────────────────── class FakeS3Client: - def __init__(self, objects): - self.objects = objects # dict: key -> bytes - self.list_calls = [] - self.download_calls = [] + """Minimal in-memory S3 mock.""" + + def __init__(self, objects: dict[str, bytes]): + self.objects = objects + self.list_calls: list = [] + self.download_calls: list = [] def list_objects_v2(self, **kwargs): bucket = kwargs["Bucket"] prefix = kwargs.get("Prefix", "") self.list_calls.append((bucket, prefix)) - contents = [] - for key in sorted(self.objects): - if key.startswith(prefix): - contents.append({"Key": key}) + contents = [{"Key": k} for k in sorted(self.objects) if k.startswith(prefix)] return {"Contents": contents, "IsTruncated": False} def download_file(self, bucket, key, filename): @@ -70,254 +46,97 @@ def download_file(self, bucket, key, filename): Path(filename).write_bytes(self.objects[key]) -class DataIOTestCase(unittest.TestCase): - def setUp(self) -> None: - self.__tmp_dir = tempfile.TemporaryDirectory() - self.addCleanup(self.__tmp_dir.cleanup) - self.tmp_dir = Path(self.__tmp_dir.name) - - def test_find_netcdf_objects_filters_and_matches_prefix(self): - config = S3Configuration("test-bucket", "L1a/nadir") - objects = { - "L1a/nadir/20181225/file1.nc": b"data1", - "L1a/nadir/20181225/file2.txt": b"ignored", - "L1a/nadir/20181226/file3.nc": b"data3", - "L1a/nadir/20181227/file4.nc": b"out_of_range", - } - client = FakeS3Client(objects) - - logger.info(f"Testing NetCDF discovery with mock S3 - bucket: 
{config.bucket}, prefix: {config.base_prefix}") - keys = find_netcdf_objects( - config, - start_date=dt.date(2018, 12, 25), - end_date=dt.date(2018, 12, 26), - s3_client=client, - ) - logger.info(f"Found {len(keys)} NetCDF files matching date range and .nc extension") - - self.assertListEqual( - keys, - [ - "L1a/nadir/20181225/file1.nc", - "L1a/nadir/20181226/file3.nc", - ], - ) - self.assertListEqual( - client.list_calls, - [ - ("test-bucket", "L1a/nadir/20181225/"), - ("test-bucket", "L1a/nadir/20181226/"), - ], - ) - logger.info("Verified correct S3 prefix queries and file filtering") - - def test_download_netcdf_objects_writes_files(self): - config = S3Configuration("test-bucket", "L1a/nadir") - objects = { - "L1a/nadir/20181225/file1.nc": b"data1", - "L1a/nadir/20181225/file2.nc": b"data2", - } - client = FakeS3Client(objects) - - logger.info(f"Testing NetCDF download with mock S3 - {len(objects)} objects to destination: {self.tmp_dir}") - output_paths = download_netcdf_objects( - config, - objects.keys(), - self.tmp_dir, - s3_client=client, - ) - - logger.info(f"Successfully downloaded {len(output_paths)} files: {[p.name for p in output_paths]}") - self.assertSetEqual({p.name for p in output_paths}, {"file1.nc", "file2.nc"}) - for path in output_paths: - self.assertEqual(path.read_bytes(), objects[f"L1a/nadir/20181225/{path.name}"]) - logger.info("Verified file contents match S3 objects") - - -@unittest.skipUnless( - ( - os.getenv("AWS_ACCESS_KEY_ID", "") - and os.getenv("AWS_SECRET_ACCESS_KEY", "") - and os.getenv("AWS_SESSION_TOKEN", "") +# ── find / download ─────────────────────────────────────────────────────────── + + +def test_find_netcdf_objects_filters_and_matches_prefix(): + config = S3Configuration("test-bucket", "L1a/nadir") + objects = { + "L1a/nadir/20181225/file1.nc": b"data1", + "L1a/nadir/20181225/file2.txt": b"ignored", + "L1a/nadir/20181226/file3.nc": b"data3", + "L1a/nadir/20181227/file4.nc": b"out_of_range", + } + client = 
FakeS3Client(objects) + keys = find_netcdf_objects( + config, start_date=dt.date(2018, 12, 25), end_date=dt.date(2018, 12, 26), s3_client=client ) - or os.getenv("C9_USER"), - "Requires tester to set AWS access key environment variables or run in Cloud9.", -) -class ClarreoDataIOTestCase(unittest.TestCase): - def setUp(self) -> None: - self.__tmp_dir = tempfile.TemporaryDirectory() - self.addCleanup(self.__tmp_dir.cleanup) - self.tmp_dir = Path(self.__tmp_dir.name) + assert keys == ["L1a/nadir/20181225/file1.nc", "L1a/nadir/20181226/file3.nc"] + assert client.list_calls == [ + ("test-bucket", "L1a/nadir/20181225/"), + ("test-bucket", "L1a/nadir/20181226/"), + ] + + +def test_download_netcdf_objects_writes_files(tmp_path): + config = S3Configuration("test-bucket", "L1a/nadir") + objects = { + "L1a/nadir/20181225/file1.nc": b"data1", + "L1a/nadir/20181225/file2.nc": b"data2", + } + client = FakeS3Client(objects) + output_paths = download_netcdf_objects(config, objects.keys(), tmp_path, s3_client=client) + assert {p.name for p in output_paths} == {"file1.nc", "file2.nc"} + for p in output_paths: + assert p.read_bytes() == objects[f"L1a/nadir/20181225/{p.name}"] + + +# ── validation ──────────────────────────────────────────────────────────────── - def test_l0(self): - config = S3Configuration("clarreo", "L0/telemetry/hps_navigation/") - start_date = dt.date(2017, 1, 15) - end_date = dt.date(2017, 1, 15) - logger.info(f"Querying CSDS S3 bucket '{config.bucket}' for L0 telemetry data") - logger.info(f"Date range: {start_date} to {end_date}, prefix: {config.base_prefix}") +class MockConfig: + class MockGeo: + time_field = "corrected_timestamp" - keys = find_netcdf_objects( - config, - start_date=start_date, - end_date=end_date, - ) + def __init__(self): + self.geo = self.MockGeo() - logger.info(f"Successfully retrieved {len(keys)} L0 telemetry files from CSDS S3") - if keys: - logger.info(f"Example file: {keys[0]}") - self.assertListEqual( - keys, 
["L0/telemetry/hps_navigation/20170115/CPF_TLM_L0.V00-000.hps_navigation-20170115-0.0.0.nc"] - ) +@pytest.fixture +def mock_config(): + return MockConfig() - def test_l1a(self): - config = S3Configuration("clarreo", "L1a/nadir/") - start_date = dt.date(2022, 6, 3) - end_date = dt.date(2022, 6, 3) - logger.info(f"Querying CSDS S3 bucket '{config.bucket}' for L1a nadir science data") - logger.info(f"Date range: {start_date} to {end_date}, prefix: {config.base_prefix}") +def test_validate_telemetry_valid(mock_config): + df = pd.DataFrame({"time": [1.0, 2.0], "position_x": [100.0, 200.0]}) + validate_telemetry_output(df, mock_config) # must not raise - keys = find_netcdf_objects( - config, - start_date=start_date, - end_date=end_date, - ) - logger.info(f"Successfully retrieved {len(keys)} L1a science files from CSDS S3") - if keys: - logger.info(f"Example file: {keys[0]}") +def test_validate_telemetry_not_dataframe(mock_config): + with pytest.raises(TypeError, match="must return pd.DataFrame"): + validate_telemetry_output({"not": "a dataframe"}, mock_config) - self.assertEqual(34, len(keys)) - self.assertIn("L1a/nadir/20220603/nadir-20220603T235952-step22-geolocation_creation-0.0.0.nc", keys) +def test_validate_telemetry_empty(mock_config): + with pytest.raises(ValueError, match="empty DataFrame"): + validate_telemetry_output(pd.DataFrame(), mock_config) -class MockConfig: - """Mock config for validation tests.""" - class MockGeo: - time_field = "corrected_timestamp" +def test_validate_science_valid(mock_config): + df = pd.DataFrame({"corrected_timestamp": [1e6, 2e6], "frame_id": [1, 2]}) + validate_science_output(df, mock_config) # must not raise - def __init__(self): - self.geo = self.MockGeo() + +def test_validate_science_not_dataframe(mock_config): + with pytest.raises(TypeError, match="must return pd.DataFrame"): + validate_science_output([1, 2, 3], mock_config) + + +def test_validate_science_empty(mock_config): + with pytest.raises(ValueError, match="empty 
DataFrame"): + validate_science_output(pd.DataFrame(), mock_config) + + +def test_validate_science_missing_time_field(mock_config): + df = pd.DataFrame({"frame_id": [1, 2], "other": [100, 200]}) + with pytest.raises(ValueError, match="must include time field 'corrected_timestamp'"): + validate_science_output(df, mock_config) -class TestDataIOValidation(unittest.TestCase): - """Test validation functions for data loaders.""" - - def setUp(self): - """Set up test fixtures.""" - self.config = MockConfig() - - def test_validate_telemetry_output_valid(self): - """Test that valid telemetry passes validation.""" - valid_df = pd.DataFrame( - { - "time": [1.0, 2.0, 3.0], - "position_x": [100.0, 200.0, 300.0], - } - ) - - logger.info(f"Testing telemetry validation with DataFrame shape: {valid_df.shape}") - # Should not raise - validate_telemetry_output(valid_df, self.config) - logger.info("Telemetry validation passed for valid DataFrame") - - def test_validate_telemetry_output_not_dataframe(self): - """Test that non-DataFrame input raises TypeError.""" - logger.info("Testing telemetry validation rejects non-DataFrame input") - with pytest.raises(TypeError, match="must return pd.DataFrame"): - validate_telemetry_output({"not": "a dataframe"}, self.config) - logger.info("Correctly raised TypeError for non-DataFrame telemetry") - - def test_validate_telemetry_output_empty(self): - """Test that empty DataFrame raises ValueError.""" - empty_df = pd.DataFrame() - - logger.info("Testing telemetry validation rejects empty DataFrame") - with pytest.raises(ValueError, match="empty DataFrame"): - validate_telemetry_output(empty_df, self.config) - logger.info("Correctly raised ValueError for empty telemetry DataFrame") - - def test_validate_science_output_valid(self): - """Test that valid science output passes validation.""" - valid_df = pd.DataFrame( - { - "corrected_timestamp": [1e6, 2e6, 3e6], - "frame_id": [1, 2, 3], - } - ) - - logger.info( - f"Testing science validation with 
DataFrame shape: {valid_df.shape}, time field: {self.config.geo.time_field}" - ) - # Should not raise - validate_science_output(valid_df, self.config) - logger.info("Science validation passed for valid DataFrame with required time field") - - def test_validate_science_output_not_dataframe(self): - """Test that non-DataFrame input raises TypeError.""" - logger.info("Testing science validation rejects non-DataFrame input") - with pytest.raises(TypeError, match="must return pd.DataFrame"): - validate_science_output([1, 2, 3], self.config) - logger.info("Correctly raised TypeError for non-DataFrame science output") - - def test_validate_science_output_empty(self): - """Test that empty DataFrame raises ValueError.""" - empty_df = pd.DataFrame() - - logger.info("Testing science validation rejects empty DataFrame") - with pytest.raises(ValueError, match="empty DataFrame"): - validate_science_output(empty_df, self.config) - logger.info("Correctly raised ValueError for empty science DataFrame") - - def test_validate_science_output_missing_time_field(self): - """Test that DataFrame missing time field raises ValueError.""" - df_no_time = pd.DataFrame( - { - "frame_id": [1, 2, 3], - "other_field": [100, 200, 300], - } - ) - - logger.info( - f"Testing science validation rejects DataFrame missing required time field '{self.config.geo.time_field}'" - ) - with pytest.raises(ValueError, match="must include time field 'corrected_timestamp'"): - validate_science_output(df_no_time, self.config) - logger.info("Correctly raised ValueError for missing time field") - - def test_validate_science_output_custom_time_field(self): - """Test that validation respects custom time field name.""" - # Change config to use different time field - self.config.geo.time_field = "custom_time" - - logger.info(f"Testing science validation with custom time field: {self.config.geo.time_field}") - df_custom_time = pd.DataFrame( - { - "custom_time": [1.0, 2.0, 3.0], - "data": [100, 200, 300], - } - ) - - # 
Should pass with custom time field - validate_science_output(df_custom_time, self.config) - logger.info("Validation passed for DataFrame with custom time field") - - # Should fail without custom time field - df_wrong_time = pd.DataFrame( - { - "corrected_timestamp": [1.0, 2.0, 3.0], - "data": [100, 200, 300], - } - ) - - logger.info("Testing validation rejects DataFrame with wrong time field name") - with pytest.raises(ValueError, match="must include time field 'custom_time'"): - validate_science_output(df_wrong_time, self.config) - logger.info("Correctly raised ValueError for wrong time field name") - - -if __name__ == "__main__": - unittest.main() +def test_validate_science_custom_time_field(mock_config): + mock_config.geo.time_field = "custom_time" + df_ok = pd.DataFrame({"custom_time": [1.0, 2.0], "data": [1, 2]}) + validate_science_output(df_ok, mock_config) # must not raise + df_bad = pd.DataFrame({"corrected_timestamp": [1.0, 2.0], "data": [1, 2]}) + with pytest.raises(ValueError, match="must include time field 'custom_time'"): + validate_science_output(df_bad, mock_config) diff --git a/tests/test_correction/test_geolocation_error_stats.py b/tests/test_correction/test_error_stats.py similarity index 83% rename from tests/test_correction/test_geolocation_error_stats.py rename to tests/test_correction/test_error_stats.py index 4b1ce25c..0e8c28eb 100644 --- a/tests/test_correction/test_geolocation_error_stats.py +++ b/tests/test_correction/test_error_stats.py @@ -1,5 +1,5 @@ """ -Unit tests for geolocation_error_stats.py +Unit tests for error_stats.py This module contains comprehensive unit tests for the ErrorStatsProcessor class and related functionality, including edge cases, validation, and numerical accuracy. 
@@ -13,18 +13,18 @@ Running Tests: ------------- # Via pytest (recommended for CI/CD) -pytest test_geolocation_error_stats.py -v +pytest test_error_stats.py -v # Run only the 13 test cases -pytest test_geolocation_error_stats.py::TestErrorStats13Cases -v +pytest test_error_stats.py::TestErrorStats13Cases -v # Run specific test case -pytest test_geolocation_error_stats.py::TestErrorStats13Cases::test_case_01_dili_region -v +pytest test_error_stats.py::TestErrorStats13Cases::test_case_01_dili_region -v Standalone Execution: -------------------- # Generate NASA demonstration report -python test_geolocation_error_stats.py +python test_error_stats.py This runs all 13 test cases and prints a comprehensive validation report showing individual case results and overall performance metrics. @@ -41,9 +41,10 @@ import xarray as xr from curryer import utils -from curryer.correction.geolocation_error_stats import ( +from curryer.correction.error_stats import ( + ErrorStatsConfig, ErrorStatsProcessor, - GeolocationConfig, + compute_percent_below, ) logger = logging.getLogger(__name__) @@ -94,24 +95,24 @@ def _sample_from_validated_test_cases(n_measurements: int, seed: int | None = No def _create_test_config(**overrides): """ - Create GeolocationConfig for testing with CLARREO defaults. + Create ErrorStatsConfig for testing with CLARREO defaults. - These values come from the CLARREO mission configuration and are appropriate - for testing CLARREO-specific scenarios. Tests can override individual values - to test edge cases or alternative configurations. + Provides the variable name mappings used by the CLARREO mission; tests can + override individual values to exercise edge-cases. - Args: - **overrides: Override any default values + Parameters + ---------- + **overrides + Override any default values. - Returns: - GeolocationConfig with CLARREO test values + Returns + ------- + ErrorStatsConfig + ErrorStatsConfig with CLARREO variable-name defaults. 
""" - # CLARREO mission defaults (from clarreo_correction_config.json) - # These values should match the canonical CLARREO configuration + # CLARREO mission variable names — no pass/fail thresholds here. + # Those belong in RequirementsConfig / CorrectionConfig, not ErrorStatsConfig. defaults = { - "earth_radius_m": 6378140.0, # WGS84 Earth radius (CLARREO standard) - "performance_threshold_m": 250.0, # CLARREO accuracy requirement - "performance_spec_percent": 39.0, # CLARREO performance spec "variable_names": { "spacecraft_position": "riss_ctrs", # CLARREO/ISS variable name "boresight": "bhat_hs", # HySICS boresight variable name @@ -119,7 +120,7 @@ def _create_test_config(**overrides): }, } defaults.update(overrides) - return GeolocationConfig(**defaults) + return ErrorStatsConfig(**defaults) def create_test_dataset_13_cases() -> xr.Dataset: @@ -370,18 +371,14 @@ def process_test_data(display_results: bool = True) -> xr.Dataset: results = processor.process_geolocation_errors(test_data) if display_results: - print(f"Processing Results Summary:") - print(f"=" * 50) + print("Processing Results Summary:") + print("=" * 50) print(f"Total measurements: {results.attrs['total_measurements']}") - print(f"Mean error distance: {results.attrs['mean_error_distance_m']:.2f} m") - print(f"Std error distance: {results.attrs['std_error_distance_m']:.2f} m") - print( - f"Min/Max error: {results.attrs['min_error_distance_m']:.2f} / {results.attrs['max_error_distance_m']:.2f} m" - ) - print(f"Errors < 250m: {results.attrs['num_below_250m']} ({results.attrs['percent_below_250m']:.1f}%)") - - spec_status = "✓ PASS" if results.attrs["performance_spec_met"] else "✗ FAIL" - print(f"Performance spec (>39% < 250m): {spec_status}") + print(f"Mean error: {results.attrs['mean_error_m']:.2f} m") + print(f"RMS error: {results.attrs['rms_error_m']:.2f} m") + print(f"Std error: {results.attrs['std_error_m']:.2f} m") + print(f"Min/Max: {results.attrs['min_error_m']:.2f} / 
{results.attrs['max_error_m']:.2f} m") + print(f"Errors < 250 m: {results.attrs['percent_below_250m']:.1f}%") return results @@ -404,10 +401,10 @@ class TestErrorStats13Cases: Usage: # Run all 13 test cases - pytest test_geolocation_error_stats.py::TestErrorStats13Cases -v + pytest test_error_stats.py::TestErrorStats13Cases -v # Run specific test case - pytest test_geolocation_error_stats.py::TestErrorStats13Cases::test_case_01_dili_region -v + pytest test_error_stats.py::TestErrorStats13Cases::test_case_01_dili_region -v """ @pytest.fixture(scope="class") @@ -648,59 +645,44 @@ def test_case_13_south_atlantic_region(self, test_dataset, processor): logger.info(f"Case 13 (S Atlantic): Processing produced NaN") def test_performance_spec_all_cases(self, test_dataset, processor): - """Test CLARREO performance spec (>39% < 250m) on all 13 cases. + """Test CLARREO performance (>39% of measurements < 250 m) on all 13 cases. - Validates that: - 1. The expected number of measurements fall below 250m threshold - 2. The CLARREO performance requirement (>39%) is met - 3. Results are consistent with original MATLAB implementation + Validates that the percentage of measurements below 250 m matches the + Engineering baseline and that the standard threshold table is populated. + Pass/fail evaluation is deliberately left to the caller — this test + checks that the *statistics* are computed correctly. 
""" results = processor.process_geolocation_errors(test_dataset) - # CLARREO requirement: 39% of measurements < 250m - percent_below_threshold = results.attrs["percent_below_250m"] - spec_met = results.attrs["performance_spec_met"] - num_below = results.attrs["num_below_250m"] + # CLARREO baseline: 8 out of 13 cases are below 250 m + percent_below_250 = results.attrs["percent_below_250m"] + total = results.attrs["total_measurements"] logger.info(f"\n{'=' * 60}") - logger.info(f"CLARREO Performance Spec Validation") + logger.info("CLARREO Error Statistics Validation") logger.info(f"{'=' * 60}") - logger.info(f"Measurements < 250m: {num_below}/13") - logger.info(f"Percentage: {percent_below_threshold:.1f}%") - logger.info(f"Requirement: >39%") - logger.info(f"Result: {'✓ PASS' if spec_met else '✗ FAIL'}") + logger.info(f"Total measurements: {total}") + logger.info(f"Percentage < 250 m: {percent_below_250:.1f}%%") logger.info(f"{'=' * 60}") - # Expected results from Engineering baseline - # Based on the actual computed values from the 13 test cases: - # Cases 6, 7, 8, 9, 10, 11, 12, 13 are below 250m - expected_num_below = 8 # Cases 6, 7, 8, 9, 10, 11, 12, 13 - expected_percent = 61.5 # 8/13 * 100 ≈ 61.54% - - # Validate against expected values - assert num_below == expected_num_below, f"Expected {expected_num_below} measurements < 250m, got {num_below}" - assert abs(percent_below_threshold - expected_percent) < 0.5, ( - f"Expected {expected_percent:.1f}%, got {percent_below_threshold:.1f}%" + # 8/13 ≈ 61.54 % below 250 m (Engineering baseline) + expected_percent = 100.0 * 8 / 13 + assert abs(percent_below_250 - expected_percent) < 0.5, ( + f"Expected ~{expected_percent:.1f}%, got {percent_below_250:.1f}%" ) + assert percent_below_250 > 39.0, "Should exceed 39 % (CLARREO mission spec)" - # Validate CLARREO spec is met - assert spec_met is True, "CLARREO performance spec should be met with these test cases" - assert percent_below_threshold > 39.0, "Should exceed 39% 
threshold" + # Standard threshold table must be present + for key in ( + "percent_below_100m", + "percent_below_250m", + "percent_below_500m", + "percent_below_750m", + "percent_below_1000m", + ): + assert key in results.attrs, f"Missing threshold table entry: {key}" - # This assertion documents whether the spec is met with these test cases - # The test data is fixed, so this will consistently pass - assert isinstance(spec_met, bool), "Performance spec result should be boolean" - - logger.info("✓ Performance metrics match expected MATLAB results") - logger.info(f"Percentage: {percent_below_threshold:.1f}%") - logger.info(f"Requirement: >39%") - logger.info(f"Result: {'✓ PASS' if spec_met else '✗ FAIL'}") - logger.info(f"{'=' * 60}") - - # Note: This assertion documents whether the spec is met with these test cases - # The test data is fixed, so this will consistently pass or fail - # Keeping as assertion to ensure we're aware of the performance level - assert isinstance(spec_met, bool), "Performance spec result should be boolean" + logger.info("✓ Statistics match Engineering baseline") # ============================================================================ @@ -764,39 +746,33 @@ def _create_minimal_test_data(self) -> xr.Dataset: ) def test_geolocation_config_default(self): - """Test default configuration values. + """Test default configuration. - Validates standard Earth radius (WGS84) and performance specs: - - 250m threshold: nadir equivalent accuracy requirement - - 39%: project performance requirement (>39% of measurements must be <250m) + After removing pass/fail fields from ErrorStatsConfig, the only + significant defaults are ``minimum_correlation`` (None) and the + ``variable_names`` mapping. 
""" config = _create_test_config() - self.assertEqual(config.earth_radius_m, 6378140.0) - self.assertEqual(config.performance_threshold_m, 250.0) - self.assertEqual(config.performance_spec_percent, 39.0) + self.assertIsNone(config.minimum_correlation) + self.assertIsNotNone(config.variable_names) def test_geolocation_config_custom(self): """Test custom configuration values.""" - config = _create_test_config( - earth_radius_m=6371000.0, performance_threshold_m=200.0, performance_spec_percent=40.0 - ) - self.assertEqual(config.earth_radius_m, 6371000.0) - self.assertEqual(config.performance_threshold_m, 200.0) - self.assertEqual(config.performance_spec_percent, 40.0) + config = _create_test_config(minimum_correlation=0.7) + self.assertEqual(config.minimum_correlation, 0.7) def test_processor_initialization_default(self): """Test processor initialization with default config.""" config = _create_test_config() processor = ErrorStatsProcessor(config=config) - self.assertIsInstance(processor.config, GeolocationConfig) - self.assertEqual(processor.config.earth_radius_m, 6378140.0) + self.assertIsInstance(processor.config, ErrorStatsConfig) + self.assertIsNone(processor.config.minimum_correlation) def test_processor_initialization_custom(self): """Test processor initialization with custom config.""" - config = _create_test_config(earth_radius_m=6371000.0, performance_threshold_m=200.0) + config = _create_test_config(minimum_correlation=0.5) processor = ErrorStatsProcessor(config=config) - self.assertEqual(processor.config.earth_radius_m, 6371000.0) - self.assertEqual(processor.config.performance_threshold_m, 200.0) + self.assertEqual(processor.config.minimum_correlation, 0.5) def test_validate_input_data_success(self): """Test successful input validation.""" @@ -906,32 +882,34 @@ def test_calculate_scaling_factors_off_nadir(self): self.assertFalse(np.isnan(xvp_factor)) def test_calculate_statistics_basic(self): - """Test basic statistics calculation.""" + """Test basic 
statistics calculation with new mission-agnostic keys.""" errors = np.array([100.0, 200.0, 300.0, 400.0, 500.0]) stats = self.processor._calculate_statistics(errors) - self.assertEqual(stats["mean_error_distance_m"], 300.0) - self.assertEqual(stats["min_error_distance_m"], 100.0) - self.assertEqual(stats["max_error_distance_m"], 500.0) + self.assertAlmostEqual(stats["mean_error_m"], 300.0) + self.assertEqual(stats["min_error_m"], 100.0) + self.assertEqual(stats["max_error_m"], 500.0) self.assertEqual(stats["total_measurements"], 5) - self.assertEqual(stats["num_below_250m"], 2) - self.assertEqual(stats["percent_below_250m"], 40.0) - self.assertTrue(stats["performance_spec_met"]) # 40% > 39% + # 2 out of 5 values are < 250 → 40 % + self.assertAlmostEqual(stats["percent_below_250m"], 40.0) + # Keys that must NOT be present (removed pass/fail) + self.assertNotIn("performance_spec_met", stats) + self.assertNotIn("num_below_250m", stats) def test_calculate_statistics_edge_cases(self): """Test statistics with edge cases.""" # All errors below threshold errors_low = np.array([50.0, 100.0, 150.0]) stats_low = self.processor._calculate_statistics(errors_low) - self.assertEqual(stats_low["percent_below_250m"], 100.0) - self.assertTrue(stats_low["performance_spec_met"]) + self.assertAlmostEqual(stats_low["percent_below_250m"], 100.0) + self.assertNotIn("performance_spec_met", stats_low) # All errors above threshold errors_high = np.array([300.0, 400.0, 500.0]) stats_high = self.processor._calculate_statistics(errors_high) - self.assertEqual(stats_high["percent_below_250m"], 0.0) - self.assertFalse(stats_high["performance_spec_met"]) + self.assertAlmostEqual(stats_high["percent_below_250m"], 0.0) + self.assertNotIn("performance_spec_met", stats_high) def test_create_output_dataset(self): """Test output dataset creation.""" @@ -968,7 +946,7 @@ def test_create_output_dataset(self): self.assertIn("lat_error_deg", output_ds.data_vars) self.assertIn("lon_error_deg", 
output_ds.data_vars) - # Check attributes + # Check attributes – earth_radius_m is sourced from constants, not config self.assertIn("title", output_ds.attrs) self.assertIn("earth_radius_m", output_ds.attrs) @@ -983,37 +961,56 @@ def test_end_to_end_processing(self): self.assertIn("nadir_equiv_total_error_m", results.data_vars) self.assertEqual(len(results.measurement), 13) - # Check statistics are computed - self.assertIn("mean_error_distance_m", results.attrs) - self.assertIn("percent_below_250m", results.attrs) - self.assertIn("performance_spec_met", results.attrs) + # New statistics keys + for key in ( + "mean_error_m", + "rms_error_m", + "std_error_m", + "min_error_m", + "max_error_m", + "total_measurements", + "percent_below_250m", + ): + self.assertIn(key, results.attrs, f"Missing stats key: {key}") + + # Pass/fail keys must NOT be present + self.assertNotIn("performance_spec_met", results.attrs) + self.assertNotIn("num_below_250m", results.attrs) def test_regression_against_known_values(self): """Test against known good values from original implementation.""" results = process_test_data(display_results=False) - # These are the expected values from the original implementation - expected_mean = 1203.26 # meters - expected_percent_below_250 = 61.5 # percent - expected_num_below_250 = 8 + # Expected values from the Engineering baseline (13 test cases) + expected_mean = 1203.26 # metres + expected_percent_below_250 = 100.0 * 8 / 13 # ≈ 61.54 % - # Allow small numerical differences - self.assertLess(abs(results.attrs["mean_error_distance_m"] - expected_mean), 0.1) + self.assertLess(abs(results.attrs["mean_error_m"] - expected_mean), 0.1) self.assertLess(abs(results.attrs["percent_below_250m"] - expected_percent_below_250), 0.1) - self.assertEqual(results.attrs["num_below_250m"], expected_num_below_250) - self.assertTrue(results.attrs["performance_spec_met"]) + # Verify new keys present, old pass/fail key absent + self.assertIn("rms_error_m", results.attrs) + 
self.assertNotIn("performance_spec_met", results.attrs) + self.assertNotIn("num_below_250m", results.attrs) def test_custom_config_processing(self): - """Test processing with custom configuration.""" - custom_config = _create_test_config(performance_threshold_m=300.0) + """Test processing with custom minimum_correlation config.""" + custom_config = _create_test_config(minimum_correlation=0.3) processor = ErrorStatsProcessor(custom_config) test_data = create_test_dataset_13_cases() results = processor.process_geolocation_errors(test_data) - # With higher threshold, more errors should be below threshold - self.assertLessEqual(results.attrs["num_below_250m"], results.attrs["total_measurements"]) - self.assertEqual(results.attrs["performance_threshold_m"], 300.0) + # All 13 measurements should be present (no correlation variable in dataset → no filtering) + self.assertEqual(results.attrs["total_measurements"], 13) + # Standard threshold table entries must be present + for key in ( + "percent_below_100m", + "percent_below_250m", + "percent_below_500m", + "percent_below_750m", + "percent_below_1000m", + ): + self.assertIn(key, results.attrs) def test_invalid_input_types(self): """Test handling of invalid input types.""" @@ -1067,6 +1064,78 @@ def test_view_plane_vector_calculation(self): # Check orthogonality self.assertLess(abs(np.dot(v_uen, x_uen)), 1e-10) + def test_compute_nadir_equivalent_errors_no_stats(self): + """compute_nadir_equivalent_errors() must NOT populate aggregate stats. + + This is the inner-loop path — computing mean/std on a single GCP pair + is not meaningful, so the output dataset should have no statistical attrs. 
+ """ + test_data = create_test_dataset_13_cases() + result = self.processor.compute_nadir_equivalent_errors(test_data) + + # Core output variable present + self.assertIn("nadir_equiv_total_error_m", result.data_vars) + self.assertEqual(len(result.measurement), 13) + + # Stats attributes must NOT be present + for key in ( + "mean_error_m", + "rms_error_m", + "std_error_m", + "percent_below_250m", + "total_measurements", + "performance_spec_met", + ): + self.assertNotIn(key, result.attrs, f"Stats key should not be present: {key}") + + def test_compute_nadir_equivalent_errors_values_match_full_pipeline(self): + """Per-measurement values from the two-stage path must equal the one-stage path.""" + test_data = create_test_dataset_13_cases() + + nadir_only = self.processor.compute_nadir_equivalent_errors(test_data) + full = self.processor.process_geolocation_errors(test_data) + + npt.assert_allclose( + nadir_only["nadir_equiv_total_error_m"].values, + full["nadir_equiv_total_error_m"].values, + rtol=1e-12, + ) + + +class TestComputePercentBelow(unittest.TestCase): + """Tests for the module-level compute_percent_below() helper.""" + + def test_basic(self): + errors = np.array([50.0, 100.0, 200.0, 400.0, 500.0]) + # 3 values < 250 m → 60 % + self.assertAlmostEqual(compute_percent_below(errors, 250.0), 60.0) + + def test_all_below(self): + errors = np.array([10.0, 20.0, 30.0]) + self.assertAlmostEqual(compute_percent_below(errors, 100.0), 100.0) + + def test_none_below(self): + errors = np.array([300.0, 400.0]) + self.assertAlmostEqual(compute_percent_below(errors, 250.0), 0.0) + + def test_empty_array(self): + self.assertEqual(compute_percent_below(np.array([]), 250.0), 0.0) + + def test_custom_threshold(self): + errors = np.array([100.0, 200.0, 300.0, 400.0, 500.0]) + self.assertAlmostEqual(compute_percent_below(errors, 350.0), 60.0) + + def test_consistency_with_standard_table(self): + """compute_percent_below at 250 m must match percent_below_250m from stats.""" + 
test_data = create_test_dataset_13_cases() + config = _create_test_config() + processor = ErrorStatsProcessor(config=config) + results = processor.process_geolocation_errors(test_data) + + errors = results["nadir_equiv_total_error_m"].values + manual = compute_percent_below(errors, 250.0) + self.assertAlmostEqual(manual, results.attrs["percent_below_250m"], places=10) + class TestCorrelationFiltering(unittest.TestCase): """Test correlation-based filtering functionality.""" @@ -1352,9 +1421,7 @@ def run_13_case_validation(): # Create processor config = _create_test_config() processor = ErrorStatsProcessor(config=config) - print( - f"✓ Created processor with earth_radius={config.earth_radius_m}m, threshold={config.performance_threshold_m}m" - ) + print("✓ Created processor") # Process all cases print("\nProcessing all test cases...") @@ -1384,31 +1451,32 @@ def run_13_case_validation(): for idx, name, lat, lon in case_info: nadir_error = results["nadir_equiv_total_error_m"].values[idx] - below_threshold = nadir_error < config.performance_threshold_m - status = "✓" if below_threshold else "✗" - print(f"\nCase {idx + 1:2d}: {name:20s} ({lat:6.1f}°, {lon:7.1f}°)") - print(f" Nadir-equiv error: {nadir_error:6.1f} m {status}") + print(f" Nadir-equiv error: {nadir_error:6.1f} m") # Display summary statistics print("\n" + "=" * 80) print("SUMMARY STATISTICS") print("=" * 80) - print(f"Total measurements: {results.attrs['total_measurements']}") - print(f"Mean error distance: {results.attrs['mean_error_distance_m']:.2f} m") - print(f"Std error distance: {results.attrs['std_error_distance_m']:.2f} m") + print(f"Total measurements: {results.attrs['total_measurements']}") + print(f"Mean error: {results.attrs['mean_error_m']:.2f} m") + print(f"RMS error: {results.attrs['rms_error_m']:.2f} m") + print(f"Std error: {results.attrs['std_error_m']:.2f} m") + print(f"Min / Max error: {results.attrs['min_error_m']:.2f} / {results.attrs['max_error_m']:.2f} m") print( - f"Min/Max error: 
{results.attrs['min_error_distance_m']:.2f} / " - f"{results.attrs['max_error_distance_m']:.2f} m" + f"P90 / P95 / P99: {results.attrs['p90_error_m']:.1f} / " + f"{results.attrs['p95_error_m']:.1f} / {results.attrs['p99_error_m']:.1f} m" ) - print(f"Measurements < 250m: {results.attrs['num_below_250m']}") - print(f"Percentage < 250m: {results.attrs['percent_below_250m']:.1f}%") - - # Display performance spec result - spec_met = results.attrs["performance_spec_met"] - spec_status = "✓ PASS" if spec_met else "✗ FAIL" - print(f"\nCLARREO Performance Spec: >39% of measurements < 250m") - print(f"Result: {spec_status}") + print("\nThreshold table:") + for key in ( + "percent_below_100m", + "percent_below_250m", + "percent_below_500m", + "percent_below_750m", + "percent_below_1000m", + ): + threshold = key.replace("percent_below_", "").replace("m", "") + print(f" < {threshold:>6s} m : {results.attrs[key]:.1f}%") print("=" * 80) diff --git a/tests/test_correction/test_image_io.py b/tests/test_correction/test_image_io.py new file mode 100644 index 00000000..8e2e4c63 --- /dev/null +++ b/tests/test_correction/test_image_io.py @@ -0,0 +1,442 @@ +"""Unit tests for image_io module. + +Tests for MATLAB, HDF, and NetCDF I/O functions. 
+""" + +import tempfile +from pathlib import Path + +import numpy as np +import pytest + +from curryer.correction.data_structures import ImageGrid, NamedImageGrid +from curryer.correction.image_io import ( + load_gcp_chip_from_hdf, + load_gcp_chip_from_netcdf, + load_image_grid_from_mat, + load_image_grid_from_netcdf, + save_image_grid, + save_image_grid_to_netcdf, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_grid(rows: int = 10, cols: int = 10, with_height: bool = False) -> ImageGrid: + """Return a small deterministic ImageGrid for use in tests.""" + rng = np.random.default_rng(0) + data = rng.random((rows, cols)) + lat = np.linspace(38.0, 39.0, rows) + lon = np.linspace(-116.0, -115.0, cols) + lat_grid, lon_grid = np.meshgrid(lat, lon, indexing="ij") + h = np.zeros((rows, cols)) if with_height else None + return ImageGrid(data=data, lat=lat_grid, lon=lon_grid, h=h) + + +# --------------------------------------------------------------------------- +# save_image_grid / load round-trips +# --------------------------------------------------------------------------- + + +class TestImageGridSaveLoad: + """Test save/load round-trips for ImageGrid.""" + + def test_netcdf_round_trip(self): + """Test save and load ImageGrid from NetCDF.""" + rng = np.random.default_rng(0) + data = rng.random((50, 50)) + lat = np.linspace(38.0, 39.0, 50) + lon = np.linspace(-116.0, -115.0, 50) + lat_grid, lon_grid = np.meshgrid(lat, lon, indexing="ij") + + original_grid = ImageGrid(data=data, lat=lat_grid, lon=lon_grid) + + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + save_image_grid(tmp_path, original_grid, format="netcdf") + + loaded_grid = load_gcp_chip_from_netcdf(tmp_path) + + np.testing.assert_array_almost_equal(loaded_grid.data, original_grid.data) + 
np.testing.assert_array_almost_equal(loaded_grid.lat, original_grid.lat) + np.testing.assert_array_almost_equal(loaded_grid.lon, original_grid.lon) + finally: + tmp_path.unlink(missing_ok=True) + + def test_netcdf_round_trip_with_height(self): + """Height field is preserved through the xarray-based NetCDF round-trip.""" + original_grid = _make_grid(with_height=True) + + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + save_image_grid(tmp_path, original_grid, format="netcdf") + + loaded_grid = load_gcp_chip_from_netcdf(tmp_path) + + assert loaded_grid.h is not None + np.testing.assert_array_almost_equal(loaded_grid.h, original_grid.h) + finally: + tmp_path.unlink(missing_ok=True) + + def test_mat_round_trip(self): + """Test save and load ImageGrid from MATLAB .mat file.""" + rng = np.random.default_rng(0) + data = rng.random((50, 50)) + lat = np.linspace(38.0, 39.0, 50) + lon = np.linspace(-116.0, -115.0, 50) + lat_grid, lon_grid = np.meshgrid(lat, lon, indexing="ij") + + original_grid = ImageGrid(data=data, lat=lat_grid, lon=lon_grid) + + with tempfile.NamedTemporaryFile(suffix=".mat", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + save_image_grid(tmp_path, original_grid, format="mat") + + loaded_grid = load_image_grid_from_mat(tmp_path, key="GCP") + + np.testing.assert_array_almost_equal(loaded_grid.data, original_grid.data) + np.testing.assert_array_almost_equal(loaded_grid.lat, original_grid.lat) + np.testing.assert_array_almost_equal(loaded_grid.lon, original_grid.lon) + finally: + tmp_path.unlink(missing_ok=True) + + def test_mat_round_trip_with_height(self): + """Height field is preserved through the MATLAB .mat round-trip.""" + original_grid = _make_grid(with_height=True) + + with tempfile.NamedTemporaryFile(suffix=".mat", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + save_image_grid(tmp_path, original_grid, format="mat") + + loaded_grid = 
load_image_grid_from_mat(tmp_path, key="GCP") + + assert loaded_grid.h is not None + np.testing.assert_array_almost_equal(loaded_grid.h, original_grid.h) + finally: + tmp_path.unlink(missing_ok=True) + + def test_save_with_metadata(self): + """Test saving ImageGrid with custom metadata.""" + rng = np.random.default_rng(0) + data = rng.random((10, 10)) + lat = np.linspace(38.0, 39.0, 10) + lon = np.linspace(-116.0, -115.0, 10) + lat_grid, lon_grid = np.meshgrid(lat, lon, indexing="ij") + + grid = ImageGrid(data=data, lat=lat_grid, lon=lon_grid) + + metadata = { + "source": "test_chip.hdf", + "mission": "test", + "creation_date": "2025-01-22", + } + + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + save_image_grid(tmp_path, grid, format="netcdf", metadata=metadata) + + import xarray as xr + + ds = xr.open_dataset(tmp_path) + assert "source" in ds.attrs + assert ds.attrs["source"] == "test_chip.hdf" + ds.close() + finally: + tmp_path.unlink(missing_ok=True) + + def test_invalid_format(self): + """Test that invalid format raises ValueError.""" + rng = np.random.default_rng(0) + data = rng.random((10, 10)) + lat = np.linspace(38.0, 39.0, 10) + lon = np.linspace(-116.0, -115.0, 10) + lat_grid, lon_grid = np.meshgrid(lat, lon, indexing="ij") + + grid = ImageGrid(data=data, lat=lat_grid, lon=lon_grid) + + with tempfile.NamedTemporaryFile(suffix=".xyz") as tmp: + tmp_path = Path(tmp.name) + with pytest.raises(ValueError, match="Unsupported format"): + save_image_grid(tmp_path, grid, format="invalid_format") + + +# --------------------------------------------------------------------------- +# CF-1.8 save_image_grid_to_netcdf / load_image_grid_from_netcdf pair +# --------------------------------------------------------------------------- + + +class TestNetCDF4DirectIO: + """Test the lower-level CF-1.8 netCDF4 save/load pair.""" + + def test_round_trip_regular_grid(self): + """save_image_grid_to_netcdf + 
load_image_grid_from_netcdf: regular grid.""" + original_grid = _make_grid(rows=20, cols=25) + + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + save_image_grid_to_netcdf(tmp_path, original_grid) + + loaded_grid = load_image_grid_from_netcdf(tmp_path) + + assert loaded_grid.data.shape == original_grid.data.shape + np.testing.assert_array_almost_equal(loaded_grid.data, original_grid.data) + # Coordinates reconstructed from 1-D arrays via meshgrid — values must match + np.testing.assert_array_almost_equal(loaded_grid.lat, original_grid.lat) + np.testing.assert_array_almost_equal(loaded_grid.lon, original_grid.lon) + finally: + tmp_path.unlink(missing_ok=True) + + def test_round_trip_with_height(self): + """Height is preserved through the netCDF4 round-trip.""" + original_grid = _make_grid(rows=10, cols=10, with_height=True) + + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + save_image_grid_to_netcdf(tmp_path, original_grid) + + loaded_grid = load_image_grid_from_netcdf(tmp_path) + + assert loaded_grid.h is not None + np.testing.assert_array_almost_equal(loaded_grid.h, original_grid.h) + finally: + tmp_path.unlink(missing_ok=True) + + def test_round_trip_with_metadata(self): + """Metadata attributes are written to the file.""" + grid = _make_grid() + metadata = {"mission": "CLARREO", "band": "red"} + + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + save_image_grid_to_netcdf(tmp_path, grid, metadata=metadata) + + import xarray as xr + + ds = xr.open_dataset(tmp_path) + assert ds.attrs.get("mission") == "CLARREO" + assert ds.attrs.get("band") == "red" + ds.close() + finally: + tmp_path.unlink(missing_ok=True) + + def test_round_trip_irregular_grid(self): + """2-D (irregular) coordinate arrays are stored and recovered correctly.""" + rng = np.random.default_rng(1) + nrows, ncols = 8, 10 + 
data = rng.random((nrows, ncols)) + # Distorted grid — coordinates are NOT separable + lat_base = np.linspace(38.0, 39.0, nrows) + lon_base = np.linspace(-116.0, -115.0, ncols) + lon_grid, lat_grid = np.meshgrid(lon_base, lat_base) + lat_grid += 0.05 * rng.standard_normal((nrows, ncols)) + lon_grid += 0.05 * rng.standard_normal((nrows, ncols)) + + original_grid = ImageGrid(data=data, lat=lat_grid, lon=lon_grid) + + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + save_image_grid_to_netcdf(tmp_path, original_grid) + + loaded_grid = load_image_grid_from_netcdf(tmp_path) + + assert loaded_grid.data.shape == original_grid.data.shape + np.testing.assert_array_almost_equal(loaded_grid.data, original_grid.data) + finally: + tmp_path.unlink(missing_ok=True) + + def test_missing_file_raises(self): + """load_image_grid_from_netcdf raises FileNotFoundError for absent files.""" + with pytest.raises(FileNotFoundError): + load_image_grid_from_netcdf(Path("does_not_exist.nc")) + + +# --------------------------------------------------------------------------- +# HDF file loading +# --------------------------------------------------------------------------- + + +class TestHDFLoading: + """Test HDF file loading.""" + + def test_missing_file(self): + """Test that missing file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError): + load_gcp_chip_from_hdf(Path("nonexistent.hdf")) + + def test_hdf5_round_trip(self): + """Load a synthetic HDF5 file via the h5py fallback path.""" + h5py = pytest.importorskip("h5py") + + rng = np.random.default_rng(0) + nrows, ncols = 12, 15 + band = rng.random((nrows, ncols)) + ecef_x = rng.random((nrows, ncols)) * 1e6 + ecef_y = rng.random((nrows, ncols)) * 1e6 + ecef_z = rng.random((nrows, ncols)) * 1e6 + + with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + with h5py.File(tmp_path, "w") as hdf: + 
hdf.create_dataset("Band_1", data=band) + hdf.create_dataset("ECR_x_coordinate_array", data=ecef_x) + hdf.create_dataset("ECR_y_coordinate_array", data=ecef_y) + hdf.create_dataset("ECR_z_coordinate_array", data=ecef_z) + + loaded_band, loaded_x, loaded_y, loaded_z = load_gcp_chip_from_hdf(tmp_path) + + assert loaded_band.shape == (nrows, ncols) + np.testing.assert_array_almost_equal(loaded_band, band) + np.testing.assert_array_almost_equal(loaded_x, ecef_x) + finally: + tmp_path.unlink(missing_ok=True) + + def test_shape_validation(self): + """Shape mismatch across HDF5 datasets raises ValueError.""" + h5py = pytest.importorskip("h5py") + + rng = np.random.default_rng(0) + + with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + with h5py.File(tmp_path, "w") as hdf: + hdf.create_dataset("Band_1", data=rng.random((10, 10))) + hdf.create_dataset("ECR_x_coordinate_array", data=rng.random((10, 10))) + hdf.create_dataset("ECR_y_coordinate_array", data=rng.random((10, 10))) + # Deliberately wrong shape for Z + hdf.create_dataset("ECR_z_coordinate_array", data=rng.random((5, 10))) + + with pytest.raises(ValueError, match="shape mismatch"): + load_gcp_chip_from_hdf(tmp_path) + finally: + tmp_path.unlink(missing_ok=True) + + +# --------------------------------------------------------------------------- +# load_gcp_chip_from_netcdf (xarray-based) +# --------------------------------------------------------------------------- + + +class TestNetCDFLoading: + """Test NetCDF file loading.""" + + def test_missing_file(self): + """Test that missing file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError): + load_gcp_chip_from_netcdf(Path("nonexistent.nc")) + + def test_round_trip_with_height(self): + """Height variable is loaded when present in file.""" + original_grid = _make_grid(with_height=True) + + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + 
save_image_grid(tmp_path, original_grid, format="netcdf") + + loaded_grid = load_gcp_chip_from_netcdf(tmp_path) + + assert loaded_grid.h is not None + np.testing.assert_array_almost_equal(loaded_grid.h, original_grid.h) + finally: + tmp_path.unlink(missing_ok=True) + + def test_missing_variable_raises(self): + """Missing band_data variable in NetCDF raises OSError.""" + import xarray as xr + + nrows, ncols = 5, 5 + lat = np.linspace(38.0, 39.0, nrows) + lon = np.linspace(-116.0, -115.0, ncols) + lat_grid, lon_grid = np.meshgrid(lat, lon, indexing="ij") + + # Build a file that is valid NetCDF but lacks "band_data" + ds = xr.Dataset( + {"lat": (["y", "x"], lat_grid), "lon": (["y", "x"], lon_grid)}, + coords={"y": np.arange(nrows), "x": np.arange(ncols)}, + ) + + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + ds.to_netcdf(tmp_path) + + with pytest.raises(OSError, match="band_data"): + load_gcp_chip_from_netcdf(tmp_path) + finally: + tmp_path.unlink(missing_ok=True) + + +# --------------------------------------------------------------------------- +# load_image_grid_from_mat +# --------------------------------------------------------------------------- + + +class TestMATLoading: + """Test MATLAB file loading.""" + + def test_missing_file(self): + """Test that missing file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError): + load_image_grid_from_mat(Path("nonexistent.mat")) + + def test_missing_key_raises(self): + """KeyError is raised when the requested struct key is absent.""" + from scipy.io import savemat + + rng = np.random.default_rng(0) + data = rng.random((5, 5)) + + with tempfile.NamedTemporaryFile(suffix=".mat", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + # Save under a different key than will be requested + savemat(str(tmp_path), {"other_key": {"data": data}}) + + with pytest.raises(KeyError, match="subimage"): + load_image_grid_from_mat(tmp_path, key="subimage") 
+ finally: + tmp_path.unlink(missing_ok=True) + + def test_as_named_returns_named_image_grid(self): + """as_named=True returns a NamedImageGrid with the correct name.""" + original_grid = _make_grid() + + with tempfile.NamedTemporaryFile(suffix=".mat", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + save_image_grid(tmp_path, original_grid, format="mat") + + loaded = load_image_grid_from_mat(tmp_path, key="GCP", as_named=True) + + assert isinstance(loaded, NamedImageGrid) + assert loaded.name is not None + finally: + tmp_path.unlink(missing_ok=True) diff --git a/tests/test_correction/test_image_match.py b/tests/test_correction/test_image_match.py index d204ccf9..fbdb1ebb 100644 --- a/tests/test_correction/test_image_match.py +++ b/tests/test_correction/test_image_match.py @@ -1,57 +1,38 @@ -""" -Tests for image_match.py module - -This module tests the integrated image matching functionality, including: -- Image grid loading from MATLAB files -- Geolocation error application for testing -- Image matching algorithm validation -- Cross-correlation and grid search - -Running Tests: -------------- -# Via pytest (recommended) -pytest tests/test_correction/test_image_match.py -v - -# Run specific test -pytest tests/test_correction/test_image_match.py::TestImageMatch::test_integrated_match_case_1 -v - -# Standalone execution -python tests/test_correction/test_image_match.py - -Requirements: ------------------ -These tests validate image matching algorithms against known test cases, -demonstrating that the Python implementation correctly identifies geolocation -errors through image cross-correlation. +"""Tests for ``curryer.correction.image_match``. + +Covers ``integrated_image_match`` against 20 CLARREO test cases (5 scenes × +4 sub-cases each: unbinned a/b, 3-pixel-binned c/d). 
""" +from __future__ import annotations + import logging -import tempfile -import unittest from pathlib import Path import numpy as np +import pytest import xarray as xr from scipy.io import loadmat -from curryer import utils from curryer.compute import constants from curryer.correction.data_structures import ( - GeolocationConfig, ImageGrid, OpticalPSFEntry, + PSFSamplingConfig, SearchConfig, ) from curryer.correction.image_match import integrated_image_match logger = logging.getLogger(__name__) -utils.enable_logging(log_level=logging.INFO, extra_loggers=[__name__]) xr.set_options(display_width=120, display_max_rows=30) np.set_printoptions(linewidth=120) -def image_grid_from_struct(mat_struct): +# ── module-level helpers ────────────────────────────────────────────────────── + + +def image_grid_from_struct(mat_struct) -> ImageGrid: return ImageGrid( data=np.asarray(mat_struct.data), lat=np.asarray(mat_struct.lat), @@ -69,35 +50,42 @@ def great_circle_displacement_deg(lat_km: float, lon_km: float, reference_lat_de return lat_offset_deg, lon_offset_deg -def apply_geolocation_error(subimage: ImageGrid, gcp: ImageGrid, lat_error_km: float, lon_error_km: float): - """Return a copy of the subimage with imposed geolocation error.""" +def apply_geolocation_error(subimage: ImageGrid, gcp: ImageGrid, lat_error_km: float, lon_error_km: float) -> ImageGrid: + """Return a copy of *subimage* with an imposed lat/lon geolocation error.""" mid_lat = float(gcp.lat[gcp.lat.shape[0] // 2, gcp.lat.shape[1] // 2]) - lat_offset_deg, lon_offset_deg = great_circle_displacement_deg(lat_error_km, lon_error_km, mid_lat) + lat_off, lon_off = great_circle_displacement_deg(lat_error_km, lon_error_km, mid_lat) return ImageGrid( data=subimage.data.copy(), - lat=subimage.lat + lat_offset_deg, - lon=subimage.lon + lon_offset_deg, + lat=subimage.lat + lat_off, + lon=subimage.lon + lon_off, h=subimage.h.copy() if subimage.h is not None else None, ) -class ImageMatchTestCase(unittest.TestCase): - def 
setUp(self) -> None: - root_dir = Path(__file__).parent.parent.parent - self.test_dir = root_dir / "tests" / "data" / "clarreo" / "image_match" - self.assertTrue(self.test_dir.is_dir(), self.test_dir) +# ── fixtures ────────────────────────────────────────────────────────────────── + + +@pytest.fixture(scope="module") +def image_match_dir(): + root = Path(__file__).parent.parent.parent + d = root / "tests" / "data" / "clarreo" / "image_match" + assert d.is_dir(), str(d) + return d - self.__tmp_dir = tempfile.TemporaryDirectory() - self.addCleanup(self.__tmp_dir.cleanup) - self.tmp_dir = Path(self.__tmp_dir.name) - @staticmethod - def run_image_match(subimg_file, gcp_file, ancil_file, psf_file, pix_vec_file, lat_lon_err): - subimage_struct = loadmat(subimg_file, squeeze_me=True, struct_as_record=False)["subimage"] - subimage = image_grid_from_struct(subimage_struct) +# ── test class ──────────────────────────────────────────────────────────────── - gcp_struct = loadmat(gcp_file, squeeze_me=True, struct_as_record=False)["GCP"] - gcp = image_grid_from_struct(gcp_struct) + +class TestImageMatch: + """Validates ``integrated_image_match`` against 20 known test cases.""" + + @pytest.fixture(autouse=True) + def _setup(self, image_match_dir): + self.test_dir = image_match_dir + + def _run(self, subimg_file, gcp_file, ancil_file, psf_file, pix_vec_file, lat_lon_err): + subimage = image_grid_from_struct(loadmat(subimg_file, squeeze_me=True, struct_as_record=False)["subimage"]) + gcp = image_grid_from_struct(loadmat(gcp_file, squeeze_me=True, struct_as_record=False)["GCP"]) los_vectors = loadmat(pix_vec_file, squeeze_me=True)["b_HS"] r_iss = loadmat(ancil_file, squeeze_me=True)["R_ISS_midframe"].ravel() @@ -120,24 +108,26 @@ def run_image_match(subimg_file, gcp_file, ancil_file, psf_file, pix_vec_file, l r_iss_midframe_m=r_iss, los_vectors_hs=los_vectors, optical_psfs=psf_entries, - # lat_error_km=lat_lon_err[0], - # lon_error_km=lat_lon_err[1], - 
geolocation_config=GeolocationConfig(), + geolocation_config=PSFSamplingConfig(), search_config=SearchConfig(), ) - logger.info(f"Image file: {subimg_file.name}") - logger.info(f"GCP file: {gcp_file.name}") - logger.info(f"Lat Error (km): {result.lat_error_km:+6.3f} (exp={lat_lon_err[0]:+6.3f})") - logger.info(f"Lon Error (km): {result.lon_error_km:+6.3f} (exp={lat_lon_err[1]:+6.3f})") - logger.info(f" CCV (final): {result.ccv_final}") - logger.info(f" Pixel (row): {result.final_index_row}") - logger.info(f" Pixel (col): {result.final_index_col}") + logger.info( + "case=%s lat=%+.3f exp=%+.3f lon=%+.3f exp=%+.3f ccv=%.4f row=%d col=%d", + subimg_file.name, + result.lat_error_km, + lat_lon_err[0], + result.lon_error_km, + lat_lon_err[1], + result.ccv_final, + result.final_index_row, + result.final_index_col, + ) return result def test_case_1a_unbinned(self): lat_lon_err = (3.0, -3.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "1" / "TestCase1a_subimage.mat", gcp_file=self.test_dir / "1" / "GCP12055Dili_resampled.mat", ancil_file=self.test_dir / "1" / "R_ISS_midframe_TestCase1.mat", @@ -145,15 +135,15 @@ def test_case_1a_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 22) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 22) def test_case_1b_unbinned(self): lat_lon_err = (3.0, -3.0) - result = self.run_image_match( + r = 
self._run( subimg_file=self.test_dir / "1" / "TestCase1b_subimage.mat", gcp_file=self.test_dir / "1" / "GCP12055Dili_resampled.mat", ancil_file=self.test_dir / "1" / "R_ISS_midframe_TestCase1.mat", @@ -161,15 +151,15 @@ def test_case_1b_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 21) - np.testing.assert_allclose(result.final_index_col, 22) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 21) + np.testing.assert_allclose(r.final_index_col, 22) def test_case_1c_binned(self): lat_lon_err = (3.0, -3.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "1" / "TestCase1c_subimage_binned.mat", gcp_file=self.test_dir / "1" / "GCP12055Dili_resampled.mat", ancil_file=self.test_dir / "1" / "R_ISS_midframe_TestCase1.mat", @@ -177,15 +167,15 @@ def test_case_1c_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 21) - np.testing.assert_allclose(result.final_index_col, 23) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 21) + 
np.testing.assert_allclose(r.final_index_col, 23) def test_case_1d_binned(self): lat_lon_err = (3.0, -3.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "1" / "TestCase1d_subimage_binned.mat", gcp_file=self.test_dir / "1" / "GCP12055Dili_resampled.mat", ancil_file=self.test_dir / "1" / "R_ISS_midframe_TestCase1.mat", @@ -193,15 +183,15 @@ def test_case_1d_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 21) def test_case_2a_unbinned(self): lat_lon_err = (-3.0, 2.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "2" / "TestCase2a_subimage.mat", gcp_file=self.test_dir / "2" / "GCP10121Maracaibo_resampled.mat", ancil_file=self.test_dir / "2" / "R_ISS_midframe_TestCase2.mat", @@ -209,15 +199,15 @@ def test_case_2a_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], 
atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 21) def test_case_2b_unbinned(self): lat_lon_err = (-3.0, 2.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "2" / "TestCase2b_subimage.mat", gcp_file=self.test_dir / "2" / "GCP10121Maracaibo_resampled.mat", ancil_file=self.test_dir / "2" / "R_ISS_midframe_TestCase2.mat", @@ -225,15 +215,15 @@ def test_case_2b_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 23) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 23) def test_case_2c_binned(self): lat_lon_err = (-3.0, 2.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "2" / "TestCase2c_subimage_binned.mat", gcp_file=self.test_dir / "2" / "GCP10121Maracaibo_resampled.mat", ancil_file=self.test_dir / "2" / "R_ISS_midframe_TestCase2.mat", @@ -241,15 +231,15 @@ def test_case_2c_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 23) + 
np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 23) def test_case_2d_binned(self): lat_lon_err = (-3.0, 2.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "2" / "TestCase2d_subimage_binned.mat", gcp_file=self.test_dir / "2" / "GCP10121Maracaibo_resampled.mat", ancil_file=self.test_dir / "2" / "R_ISS_midframe_TestCase2.mat", @@ -257,15 +247,15 @@ def test_case_2d_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 23) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 23) def test_case_3a_unbinned(self): lat_lon_err = (1.0, 1.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "3" / "TestCase3a_subimage.mat", gcp_file=self.test_dir / "3" / "GCP10665SantaRosa_resampled.mat", ancil_file=self.test_dir / "3" / "R_ISS_midframe_TestCase3.mat", @@ -273,15 +263,15 @@ def test_case_3a_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, 
atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 21) def test_case_3b_unbinned(self): lat_lon_err = (1.0, 1.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "3" / "TestCase3b_subimage.mat", gcp_file=self.test_dir / "3" / "GCP10665SantaRosa_resampled.mat", ancil_file=self.test_dir / "3" / "R_ISS_midframe_TestCase3.mat", @@ -289,15 +279,15 @@ def test_case_3b_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 21) def test_case_3c_binned(self): lat_lon_err = (1.0, 1.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "3" / "TestCase3c_subimage_binned.mat", gcp_file=self.test_dir / "3" / "GCP10665SantaRosa_resampled.mat", ancil_file=self.test_dir / "3" / "R_ISS_midframe_TestCase3.mat", @@ -305,15 +295,15 @@ def test_case_3c_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.07) - 
np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.07) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 23) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.07) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.07) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 23) def test_case_3d_binned(self): lat_lon_err = (1.0, 1.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "3" / "TestCase3d_subimage_binned.mat", gcp_file=self.test_dir / "3" / "GCP10665SantaRosa_resampled.mat", ancil_file=self.test_dir / "3" / "R_ISS_midframe_TestCase3.mat", @@ -321,15 +311,15 @@ def test_case_3d_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.07) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.07) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 20) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.07) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.07) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 20) def test_case_4a_unbinned(self): lat_lon_err = (-1.0, -2.5) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "4" / "TestCase4a_subimage.mat", gcp_file=self.test_dir / "4" / "GCP20484Morocco_resampled.mat", ancil_file=self.test_dir / "4" / "R_ISS_midframe_TestCase4.mat", @@ -337,15 +327,15 @@ def test_case_4a_unbinned(self): 
pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 22) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 22) def test_case_4b_unbinned(self): lat_lon_err = (-1.0, -2.5) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "4" / "TestCase4b_subimage.mat", gcp_file=self.test_dir / "4" / "GCP20484Morocco_resampled.mat", ancil_file=self.test_dir / "4" / "R_ISS_midframe_TestCase4.mat", @@ -353,15 +343,15 @@ def test_case_4b_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 23) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 23) def test_case_4c_binned(self): lat_lon_err = (-1.0, -2.5) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "4" / "TestCase4c_subimage_binned.mat", gcp_file=self.test_dir / "4" / 
"GCP20484Morocco_resampled.mat", ancil_file=self.test_dir / "4" / "R_ISS_midframe_TestCase4.mat", @@ -369,15 +359,15 @@ def test_case_4c_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 21) - np.testing.assert_allclose(result.final_index_col, 21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 21) + np.testing.assert_allclose(r.final_index_col, 21) def test_case_4d_binned(self): lat_lon_err = (-1.0, -2.5) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "4" / "TestCase4d_subimage_binned.mat", gcp_file=self.test_dir / "4" / "GCP20484Morocco_resampled.mat", ancil_file=self.test_dir / "4" / "R_ISS_midframe_TestCase4.mat", @@ -385,15 +375,15 @@ def test_case_4d_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 23) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 23) def test_case_5a_unbinned(self): lat_lon_err = (2.5, 0.1) - result = 
self.run_image_match( + r = self._run( subimg_file=self.test_dir / "5" / "TestCase5a_subimage.mat", gcp_file=self.test_dir / "5" / "GCP10087Titicaca_resampled.mat", ancil_file=self.test_dir / "5" / "R_ISS_midframe_TestCase5.mat", @@ -401,15 +391,15 @@ def test_case_5a_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 21) def test_case_5b_unbinned(self): lat_lon_err = (2.5, 0.1) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "5" / "TestCase5b_subimage.mat", gcp_file=self.test_dir / "5" / "GCP10087Titicaca_resampled.mat", ancil_file=self.test_dir / "5" / "R_ISS_midframe_TestCase5.mat", @@ -417,15 +407,15 @@ def test_case_5b_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + 
np.testing.assert_allclose(r.final_index_col, 21) def test_case_5c_binned(self): lat_lon_err = (2.5, 0.1) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "5" / "TestCase5c_subimage_binned.mat", gcp_file=self.test_dir / "5" / "GCP10087Titicaca_resampled.mat", ancil_file=self.test_dir / "5" / "R_ISS_midframe_TestCase5.mat", @@ -433,15 +423,15 @@ def test_case_5c_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 22) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 22) def test_case_5d_binned(self): lat_lon_err = (2.5, 0.1) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "5" / "TestCase5d_subimage_binned.mat", gcp_file=self.test_dir / "5" / "GCP10087Titicaca_resampled.mat", ancil_file=self.test_dir / "5" / "R_ISS_midframe_TestCase5.mat", @@ -449,12 +439,8 @@ def test_case_5d_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 22) - - -if __name__ == "__main__": - unittest.main() + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + 
np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 22) diff --git a/tests/test_correction/test_io.py b/tests/test_correction/test_io.py new file mode 100644 index 00000000..280780db --- /dev/null +++ b/tests/test_correction/test_io.py @@ -0,0 +1,141 @@ +"""Tests for curryer.correction.io — unified path resolution.""" + +from __future__ import annotations + +import builtins +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from curryer.correction.io import _download_from_s3, _temp_files, resolve_path + + +class TestResolvePathLocal: + """Tests for local file resolution.""" + + def test_existing_file_returns_path(self, tmp_path): + f = tmp_path / "data.csv" + f.write_text("a,b\n1,2\n") + result = resolve_path(f) + assert result == f + + def test_existing_file_from_string(self, tmp_path): + f = tmp_path / "data.csv" + f.write_text("a,b\n1,2\n") + result = resolve_path(str(f)) + assert result == f + + def test_returns_path_type(self, tmp_path): + f = tmp_path / "data.csv" + f.write_text("x") + result = resolve_path(f) + assert isinstance(result, Path) + + def test_nonexistent_local_file_raises_filenotfounderror(self): + with pytest.raises(FileNotFoundError, match="File not found"): + resolve_path("/nonexistent/path/to/file.csv") + + +class TestResolvePathS3: + """Tests for S3 URI resolution with injected client.""" + + def _make_client(self, content: str = "downloaded content") -> MagicMock: + """Return a mock S3 client that writes *content* to the target path.""" + mock_client = MagicMock() + + def fake_download(bucket, key, local_path): + Path(local_path).write_text(content) + + mock_client.download_file.side_effect = fake_download + return mock_client + + def test_s3_uri_downloads_to_temp_file(self): + mock_client = self._make_client("downloaded content") + + 
result = resolve_path("s3://my-bucket/path/to/file.mat", s3_client=mock_client) + + try: + assert result.exists() + assert result.read_text() == "downloaded content" + assert result.suffix == ".mat" + mock_client.download_file.assert_called_once_with("my-bucket", "path/to/file.mat", str(result)) + finally: + result.unlink(missing_ok=True) + + def test_s3_uri_preserves_file_extension(self): + mock_client = self._make_client("") + + result = resolve_path("s3://bucket/data/telemetry.nc", s3_client=mock_client) + try: + assert result.suffix == ".nc" + finally: + result.unlink(missing_ok=True) + + def test_s3_uri_no_key_raises_valueerror(self): + mock_client = MagicMock() + with pytest.raises(ValueError, match="must include an object key"): + resolve_path("s3://bucket-only", s3_client=mock_client) + + def test_s3_uri_empty_key_raises_valueerror(self): + mock_client = MagicMock() + with pytest.raises(ValueError, match="must include an object key"): + resolve_path("s3://bucket/", s3_client=mock_client) + + def test_s3_uri_without_boto3_raises_importerror(self, monkeypatch): + """When no s3_client is injected AND boto3 is missing, ImportError is raised.""" + real_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "boto3": + raise ImportError("No module named 'boto3'") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mock_import) + + with pytest.raises(ImportError, match="pip install boto3"): + resolve_path("s3://bucket/key/file.csv") + + def test_s3_download_failure_cleans_up_partial_temp_file(self): + """On download failure the temp file must not be left on disk.""" + mock_client = MagicMock() + mock_client.download_file.side_effect = RuntimeError("S3 failure") + + # Capture any temp path that would have been created + created_paths: list[Path] = [] + import tempfile as _tempfile + + real_ntf = _tempfile.NamedTemporaryFile + + def capturing_ntf(*args, **kwargs): + f = real_ntf(*args, **kwargs) + 
created_paths.append(Path(f.name)) + return f + + import curryer.correction.io as _io + + original = _io.tempfile.NamedTemporaryFile + _io.tempfile.NamedTemporaryFile = capturing_ntf + try: + with pytest.raises(RuntimeError, match="S3 failure"): + _download_from_s3("s3://bucket/key/file.csv", s3_client=mock_client) + finally: + _io.tempfile.NamedTemporaryFile = original + + # Temp file should have been deleted by the except block + for path in created_paths: + assert not path.exists(), f"Temp file was not cleaned up: {path}" + + def test_temp_files_registered_for_atexit_cleanup(self): + """Successful downloads register the temp path in _temp_files.""" + mock_client = self._make_client("") + + initial_count = len(_temp_files) + result = resolve_path("s3://bucket/key/file.csv", s3_client=mock_client) + try: + assert len(_temp_files) == initial_count + 1 + assert _temp_files[-1] == result + finally: + result.unlink(missing_ok=True) + if result in _temp_files: + _temp_files.remove(result) diff --git a/tests/test_correction/test_kernel_ops.py b/tests/test_correction/test_kernel_ops.py new file mode 100644 index 00000000..51e3d916 --- /dev/null +++ b/tests/test_correction/test_kernel_ops.py @@ -0,0 +1,222 @@ +"""Tests for ``curryer.correction.kernel_ops``. 
+ +Covers: +- ``apply_offset`` – all parameter types and unit-conversion paths +- ``_load_calibration_data`` +- ``_create_dynamic_kernels`` (``@pytest.mark.extra``, requires ``mkspk``) +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +from clarreo_config import create_clarreo_correction_config +from clarreo_data_loaders import load_clarreo_telemetry + +from curryer import meta +from curryer import spicierpy as sp +from curryer.correction import correction +from curryer.kernels import create + +logger = logging.getLogger(__name__) + +# ── shared sample data ──────────────────────────────────────────────────────── + +_TLM = pd.DataFrame( + { + "frame": range(5), + "hps.az_ang_nonlin": [1.14252] * 5, + "hps.el_ang_nonlin": [-0.55009] * 5, + "hps.resolver_tms": [1168477154.0 + i for i in range(5)], + "ert": [1431903180.58 + i for i in range(5)], + } +) + +_SCI = pd.DataFrame( + { + "corrected_timestamp": [1_000_000.0, 2_000_000.0, 3_000_000.0, 4_000_000.0, 5_000_000.0], + "measurement": [1.0, 2.0, 3.0, 4.0, 5.0], + } +) + + +@pytest.fixture(scope="module") +def clarreo_cfg(root_dir): + return create_clarreo_correction_config( + root_dir / "tests" / "data" / "clarreo" / "gcs", + root_dir / "data" / "generic", + ) + + +# ── apply_offset tests ──────────────────────────────────────────────────────── + + +def test_apply_offset_kernel_arcseconds(): + """OFFSET_KERNEL converts arcseconds to radians and adds the offset.""" + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_KERNEL, + config_file=Path("cprs_az.json"), + data=dict(field="hps.az_ang_nonlin", units="arcseconds"), + ) + original = _TLM["hps.az_ang_nonlin"].mean() + modified = correction.apply_offset(p, 100.0, _TLM) + assert modified["hps.az_ang_nonlin"].mean() - original == pytest.approx(np.deg2rad(100.0 / 3600.0), rel=1e-6) + assert isinstance(modified, pd.DataFrame) + + +def 
test_apply_offset_kernel_negative(): + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_KERNEL, + config_file=Path("cprs_el.json"), + data=dict(field="hps.el_ang_nonlin", units="arcseconds"), + ) + original = _TLM["hps.el_ang_nonlin"].mean() + modified = correction.apply_offset(p, -50.0, _TLM) + assert modified["hps.el_ang_nonlin"].mean() - original == pytest.approx(np.deg2rad(-50.0 / 3600.0), rel=1e-6) + + +def test_apply_offset_kernel_missing_field(): + """Non-existent field: returns original DataFrame unchanged.""" + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_KERNEL, + config_file=Path("dummy.json"), + data=dict(field="nonexistent_field", units="arcseconds"), + ) + modified = correction.apply_offset(p, 10.0, _TLM) + pd.testing.assert_frame_equal(modified, _TLM) + + +def test_apply_offset_time_milliseconds(): + """OFFSET_TIME: seconds input → microsecond output on timestamp column.""" + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_TIME, + config_file=None, + data=dict(field="corrected_timestamp", units="milliseconds"), + ) + original = _SCI["corrected_timestamp"].mean() + modified = correction.apply_offset(p, 10.0 / 1000.0, _SCI) # 10 ms in seconds + assert modified["corrected_timestamp"].mean() - original == pytest.approx(10_000.0, rel=1e-6) + + +def test_apply_offset_time_negative(): + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_TIME, + config_file=None, + data=dict(field="corrected_timestamp", units="milliseconds"), + ) + original = _SCI["corrected_timestamp"].mean() + modified = correction.apply_offset(p, -5.5 / 1000.0, _SCI) + assert modified["corrected_timestamp"].mean() - original == pytest.approx(-5500.0, rel=1e-6) + + +def test_apply_offset_constant_kernel_passthrough(): + """CONSTANT_KERNEL: data is returned unchanged.""" + kernel_data = pd.DataFrame({"ugps": [1_000_000], "angle_x": [0.001], "angle_y": [0.002], "angle_z": [0.003]}) + p = 
correction.ParameterConfig( + ptype=correction.ParameterType.CONSTANT_KERNEL, + config_file=Path("base.json"), + data=dict(field="base"), + ) + modified = correction.apply_offset(p, kernel_data, pd.DataFrame()) + pd.testing.assert_frame_equal(modified, kernel_data) + + +def test_apply_offset_no_units(): + """OFFSET_KERNEL without units: offset applied in raw (radian) units.""" + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_KERNEL, + config_file=Path("test.json"), + data=dict(field="hps.az_ang_nonlin"), + ) + original = _TLM["hps.az_ang_nonlin"].mean() + modified = correction.apply_offset(p, 0.001, _TLM) + assert modified["hps.az_ang_nonlin"].mean() - original == pytest.approx(0.001, rel=1e-6) + + +def test_apply_offset_not_inplace(): + """Original DataFrame is not mutated.""" + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_KERNEL, + config_file=Path("cprs_az.json"), + data=dict(field="hps.az_ang_nonlin", units="arcseconds"), + ) + original = _TLM.copy() + correction.apply_offset(p, 100.0, _TLM) + pd.testing.assert_frame_equal(_TLM, original) + + +def test_apply_offset_preserves_columns(): + """All columns are present in the returned DataFrame.""" + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_KERNEL, + config_file=Path("cprs_az.json"), + data=dict(field="hps.az_ang_nonlin", units="arcseconds"), + ) + modified = correction.apply_offset(p, 50.0, _TLM) + assert set(modified.columns) == set(_TLM.columns) + assert modified["frame"].equals(_TLM["frame"]) + assert not modified["hps.az_ang_nonlin"].equals(_TLM["hps.az_ang_nonlin"]) + + +# ── _load_calibration_data ──────────────────────────────────────────────────── + + +def test_load_calibration_data_no_dir(clarreo_cfg): + """When calibration_dir is None and no direct paths, returned data contains no vectors.""" + cfg = clarreo_cfg.model_copy(deep=True) + cfg.calibration_dir = None + cfg.los_vectors_file = None + cfg.psf_file = None + cal = 
correction._load_calibration_data(cfg) + assert cal.los_vectors is None + assert cal.optical_psfs is None + + +def test_load_calibration_data_direct_los_missing(clarreo_cfg, tmp_path): + """FileNotFoundError when los_vectors_file points to a non-existent file.""" + cfg = clarreo_cfg.model_copy(deep=True) + cfg.calibration_dir = None + cfg.los_vectors_file = tmp_path / "nonexistent_los.mat" + cfg.psf_file = None + with pytest.raises(FileNotFoundError, match="LOS vectors"): + correction._load_calibration_data(cfg) + + +def test_load_calibration_data_direct_psf_missing(clarreo_cfg, tmp_path): + """FileNotFoundError when psf_file points to a non-existent file.""" + from unittest.mock import patch + + cfg = clarreo_cfg.model_copy(deep=True) + cfg.calibration_dir = None + # Provide a fake LOS file so the LOS loading succeeds + fake_los = tmp_path / "los.mat" + fake_los.touch() + cfg.los_vectors_file = fake_los + cfg.psf_file = tmp_path / "nonexistent_psf.mat" + # Mock the actual loader so we don't need a real .mat file + with patch("curryer.correction.pipeline.load_los_vectors_from_mat", return_value=[[0.0, 0.0, 1.0]]): + with pytest.raises(FileNotFoundError, match="PSF"): + correction._load_calibration_data(cfg) + + +# ── _create_dynamic_kernels ─────────────────────────────────────────────────── + + +@pytest.mark.extra +def test_create_dynamic_kernels(root_dir, clarreo_cfg, tmp_path): + """_create_dynamic_kernels builds kernel files. 
Needs ``mkspk`` – ``--run-extra``.""" + data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" + work = tmp_path / "kernels" + work.mkdir() + tlm = load_clarreo_telemetry(data_dir) + creator = create.KernelCreator(overwrite=True, append=False) + mkrn = meta.MetaKernel.from_json( + clarreo_cfg.geo.meta_kernel_file, relative=True, sds_dir=clarreo_cfg.geo.generic_kernel_dir + ) + with sp.ext.load_kernel([mkrn.sds_kernels, mkrn.mission_kernels]): + dynamic_kernels = correction._create_dynamic_kernels(clarreo_cfg, work, tlm, creator) + assert isinstance(dynamic_kernels, list) diff --git a/tests/test_correction/test_pairing.py b/tests/test_correction/test_pairing.py index 3317d2af..98ef6fb0 100644 --- a/tests/test_correction/test_pairing.py +++ b/tests/test_correction/test_pairing.py @@ -1,45 +1,20 @@ -""" -Tests for pairing.py module - -This module tests the GCP (Ground Control Point) pairing functionality: -- Spatial pairing of L1A science data with GCP reference imagery -- File discovery and matching -- Geographic overlap detection -- Pairing validation - -Running Tests: -------------- -# Via pytest (recommended) -pytest tests/test_correction/test_pairing.py -v - -# Run specific test -pytest tests/test_correction/test_pairing.py::PairingTestCase::test_find_l1a_gcp_pairs -v - -# Standalone execution -python tests/test_correction/test_pairing.py - -Requirements: ------------------ -These tests validate that GCP pairing correctly identifies which reference -images overlap with science data, ensuring accurate geolocation validation. 
-""" +"""Tests for ``curryer.correction.pairing`` (GCP spatial pairing).""" from __future__ import annotations import logging -import unittest from pathlib import Path import numpy as np +import pytest from scipy.io import loadmat -from curryer import utils from curryer.correction.data_structures import NamedImageGrid from curryer.correction.pairing import find_l1a_gcp_pairs logger = logging.getLogger(__name__) -utils.enable_logging(log_level=logging.DEBUG, extra_loggers=[__name__]) +# ── test-case metadata ──────────────────────────────────────────────────────── L1A_FILES = [ ("1/TestCase1a_subimage.mat", "subimage"), @@ -75,183 +50,86 @@ "5/TestCase5b_subimage.mat": "5/GCP10087Titicaca_resampled.mat", } +# ── fixtures / helpers ──────────────────────────────────────────────────────── -class PairingTestCase(unittest.TestCase): - def setUp(self) -> None: - root_dir = Path(__file__).parent.parent.parent - self.test_dir = root_dir / "tests" / "data" / "clarreo" / "image_match" - self.assertTrue(self.test_dir.is_dir(), self.test_dir) - - def _load_image_grid(self, relative_path: str, key: str) -> NamedImageGrid: - mat_path = self.test_dir / relative_path - mat = loadmat(mat_path, squeeze_me=True, struct_as_record=False)[key] - h = getattr(mat, "h", None) - return NamedImageGrid( - data=np.asarray(mat.data), - lat=np.asarray(mat.lat), - lon=np.asarray(mat.lon), - h=np.asarray(h) if h is not None else None, - name=relative_path, - ) - - def _prepare_inputs(self, file_list): - for rel_path, key in file_list: - yield self._load_image_grid(rel_path, key) - - def test_find_l1a_gcp_pairs(self): - l1a_inputs = list(self._prepare_inputs(L1A_FILES)) - gcp_inputs = list(self._prepare_inputs(GCP_FILES)) - - result = find_l1a_gcp_pairs(l1a_inputs, gcp_inputs, max_distance_m=0.0) - - matches_by_l1a = {} - for match in result.matches: - l1a_name = result.l1a_images[match.l1a_index].name - gcp_name = result.gcp_images[match.gcp_index].name - distance = match.distance_m - if l1a_name 
not in matches_by_l1a or distance < matches_by_l1a[l1a_name][1]: - matches_by_l1a[l1a_name] = (gcp_name, distance) - - assert set(matches_by_l1a.keys()) == set(expected for expected in EXPECTED_MATCHES) - - for l1a_name, expected_gcp in EXPECTED_MATCHES.items(): - assert l1a_name in matches_by_l1a, f"Missing match for {l1a_name}" - gcp_name, distance = matches_by_l1a[l1a_name] - assert gcp_name == expected_gcp, f"Unexpected match for {l1a_name}: {gcp_name}" - assert distance >= 0.0, f"Expected non-negative margin for {l1a_name}: {distance}" - - @staticmethod - def _make_rect_image(name: str, lon_min: float, lon_max: float, lat_min: float, lat_max: float) -> NamedImageGrid: - lat = np.array([[lat_max, lat_max], [lat_min, lat_min]], dtype=float) - lon = np.array([[lon_min, lon_max], [lon_min, lon_max]], dtype=float) - data = np.zeros_like(lat) - return NamedImageGrid(data=data, lat=lat, lon=lon, name=name) - - @staticmethod - def _make_point_gcp(name: str, lon: float, lat: float) -> NamedImageGrid: - data = np.array([[1.0]]) - lat_arr = np.array([[lat]]) - lon_arr = np.array([[lon]]) - return NamedImageGrid(data=data, lat=lat_arr, lon=lon_arr, name=name) - - def test_synthetic_pairing_no_overlap(self): - l1a = [self._make_rect_image("L1A", lon_min=1.0, lon_max=2.0, lat_min=-1.0, lat_max=1.0)] - gcp = [self._make_point_gcp("GCP", lon=0.0, lat=0.0)] - - result = find_l1a_gcp_pairs(l1a, gcp, max_distance_m=100_000.0) - assert result.matches == [] - - def test_synthetic_pairing_partial_less_than_threshold(self): - l1a = [self._make_rect_image("L1A", lon_min=0.0, lon_max=1.0, lat_min=-1.0, lat_max=1.0)] - gcp = [self._make_point_gcp("GCP", lon=0.0, lat=0.0)] - - result = find_l1a_gcp_pairs(l1a, gcp, max_distance_m=100_000.0) - assert result.matches == [] - - def test_synthetic_pairing_complete_above_threshold(self): - l1a = [self._make_rect_image("L1A", lon_min=-1.0, lon_max=1.0, lat_min=-1.0, lat_max=1.0)] - gcp = [self._make_point_gcp("GCP", lon=0.0, lat=0.0)] - - 
result = find_l1a_gcp_pairs(l1a, gcp, max_distance_m=100_000.0) - assert len(result.matches) == 1 - match = result.matches[0] - assert match.l1a_index == 0 - assert match.gcp_index == 0 - assert match.distance_m >= 100_000.0 - - def test_synthetic_pairing_partial_threshold_not_met(self): - l1a = [self._make_rect_image("L1A", lon_min=-1.0, lon_max=1.0, lat_min=-1.0, lat_max=1.0)] - gcp = [self._make_point_gcp("GCP", lon=0.0, lat=0.0)] - - result = find_l1a_gcp_pairs(l1a, gcp, max_distance_m=200_000.0) - assert result.matches == [] - - -class TestPairingValidation(unittest.TestCase): - """Test validation function for pairing output.""" - - def test_validate_pairing_output_valid(self): - """Test that valid pairing output passes validation.""" - from curryer.correction.pairing import validate_pairing_output - - valid_pairs = [ - ("science_001", "landsat_gcp_001.tif"), - ("science_002", "landsat_gcp_002.tif"), - ] - - # Should not raise - validate_pairing_output(valid_pairs) - - def test_validate_pairing_output_valid_with_path(self): - """Test that pairing output with Path objects passes validation.""" - from curryer.correction.pairing import validate_pairing_output - - valid_pairs = [ - ("science_001", Path("/data/gcp_001.tif")), - ("science_002", "/data/gcp_002.tif"), - ] - - # Should not raise - validate_pairing_output(valid_pairs) - - def test_validate_pairing_output_empty_list(self): - """Test that empty list is valid (no pairs found).""" - from curryer.correction.pairing import validate_pairing_output - empty_pairs = [] +@pytest.fixture(scope="module") +def image_match_dir(): + root = Path(__file__).parent.parent.parent + d = root / "tests" / "data" / "clarreo" / "image_match" + assert d.is_dir(), str(d) + return d - # Should not raise - validate_pairing_output(empty_pairs) - def test_validate_pairing_output_not_list(self): - """Test that non-list input raises TypeError.""" - import pytest +def _load(path: Path, key: str, name: str) -> NamedImageGrid: + mat = 
loadmat(str(path), squeeze_me=True, struct_as_record=False)[key] + h = getattr(mat, "h", None) + return NamedImageGrid( + data=np.asarray(mat.data), + lat=np.asarray(mat.lat), + lon=np.asarray(mat.lon), + h=np.asarray(h) if h is not None else None, + name=name, + ) - from curryer.correction.pairing import validate_pairing_output - with pytest.raises(TypeError, match="must return list"): - validate_pairing_output(("science", "gcp")) # Tuple instead of list +def _rect(name, lon_min, lon_max, lat_min, lat_max) -> NamedImageGrid: + lat = np.array([[lat_max, lat_max], [lat_min, lat_min]], dtype=float) + lon = np.array([[lon_min, lon_max], [lon_min, lon_max]], dtype=float) + return NamedImageGrid(data=np.zeros_like(lat), lat=lat, lon=lon, name=name) - def test_validate_pairing_output_not_tuple(self): - """Test that list elements that aren't tuples raise ValueError.""" - import pytest - from curryer.correction.pairing import validate_pairing_output +def _point(name, lon, lat) -> NamedImageGrid: + return NamedImageGrid(data=np.array([[1.0]]), lat=np.array([[lat]]), lon=np.array([[lon]]), name=name) - invalid_pairs = [ - ("science_001", "gcp_001.tif"), - ["science_002", "gcp_002.tif"], # List instead of tuple - ] - with pytest.raises(ValueError, match=r"output\[1\] must be \(str, str\) tuple"): - validate_pairing_output(invalid_pairs) +# ── tests ───────────────────────────────────────────────────────────────────── - def test_validate_pairing_output_wrong_tuple_length(self): - """Test that tuples with wrong length raise ValueError.""" - import pytest - from curryer.correction.pairing import validate_pairing_output +def test_find_l1a_gcp_pairs(image_match_dir): + l1a = [_load(image_match_dir / rel, key, rel) for rel, key in L1A_FILES] + gcp = [_load(image_match_dir / rel, key, rel) for rel, key in GCP_FILES] + result = find_l1a_gcp_pairs(l1a, gcp, max_distance_m=0.0) + by_l1a = {} + for m in result.matches: + name = result.l1a_images[m.l1a_index].name + gname = 
result.gcp_images[m.gcp_index].name + if name not in by_l1a or m.distance_m < by_l1a[name][1]: + by_l1a[name] = (gname, m.distance_m) + assert set(by_l1a) == set(EXPECTED_MATCHES) + for l1a_name, expected_gcp in EXPECTED_MATCHES.items(): + gname, dist = by_l1a[l1a_name] + assert gname == expected_gcp + assert dist >= 0.0 - invalid_pairs = [ - ("science_001", "gcp_001.tif", "extra"), # 3 elements - ] - with pytest.raises(ValueError, match=r"output\[0\] must be \(str, str\) tuple"): - validate_pairing_output(invalid_pairs) +def test_synthetic_pairing_no_overlap(): + result = find_l1a_gcp_pairs( + [_rect("L1A", 1.0, 2.0, -1.0, 1.0)], [_point("GCP", 0.0, 0.0)], max_distance_m=100_000.0 + ) + assert result.matches == [] - def test_validate_pairing_output_wrong_types(self): - """Test that tuple elements with wrong types raise ValueError.""" - import pytest - from curryer.correction.pairing import validate_pairing_output +def test_synthetic_pairing_partial_less_than_threshold(): + result = find_l1a_gcp_pairs( + [_rect("L1A", 0.0, 1.0, -1.0, 1.0)], [_point("GCP", 0.0, 0.0)], max_distance_m=100_000.0 + ) + assert result.matches == [] - invalid_pairs = [ - (123, "gcp_001.tif"), # First element not string - ] - with pytest.raises(ValueError, match=r"output\[0\].*expected \(str, str\)"): - validate_pairing_output(invalid_pairs) +def test_synthetic_pairing_complete_above_threshold(): + result = find_l1a_gcp_pairs( + [_rect("L1A", -1.0, 1.0, -1.0, 1.0)], [_point("GCP", 0.0, 0.0)], max_distance_m=100_000.0 + ) + assert len(result.matches) == 1 + m = result.matches[0] + assert m.l1a_index == 0 + assert m.gcp_index == 0 + assert m.distance_m >= 100_000.0 -if __name__ == "__main__": - unittest.main() +def test_synthetic_pairing_partial_threshold_not_met(): + result = find_l1a_gcp_pairs( + [_rect("L1A", -1.0, 1.0, -1.0, 1.0)], [_point("GCP", 0.0, 0.0)], max_distance_m=200_000.0 + ) + assert result.matches == [] diff --git a/tests/test_correction/test_parameters.py 
b/tests/test_correction/test_parameters.py new file mode 100644 index 00000000..c98d8ed0 --- /dev/null +++ b/tests/test_correction/test_parameters.py @@ -0,0 +1,668 @@ +"""Unit tests for parameter-set generation strategies. + +Covers: +- ``SearchStrategy.RANDOM`` – default Monte Carlo random walk (exact behaviour preserved) +- ``SearchStrategy.GRID_SEARCH`` – cartesian-product sweep over evenly-spaced offsets +- ``SearchStrategy.SINGLE_OFFSET`` – one-parameter-at-a-time sweep (others held at nominal) + +For every strategy the three parameter types are exercised: + - ``CONSTANT_KERNEL`` – 3-axis rotation (returns a pandas DataFrame) + - ``OFFSET_KERNEL`` – single angle bias (float, radians) + - ``OFFSET_TIME`` – timing correction (float, seconds) + +Config validation: +- ``grid_points_per_param < 2`` rejected for GRID_SEARCH +- JSON round-trip preserves strategy fields +""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +from pydantic import ValidationError + +from curryer.correction.config import ( + CorrectionConfig, + GeolocationConfig, + ParameterConfig, + ParameterType, + SearchStrategy, +) +from curryer.correction.parameters import ( + _get_grid_values, + _get_nominal_value, + load_param_sets, +) + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def geo() -> GeolocationConfig: + return GeolocationConfig( + meta_kernel_file=Path("tests/data/test.kernels.tm.json"), + generic_kernel_dir=Path("data/generic"), + instrument_name="TEST_INSTRUMENT", + time_field="ugps", + ) + + +@pytest.fixture +def param_constant() -> ParameterConfig: + """CONSTANT_KERNEL: roll/pitch/yaw in arcseconds.""" + return ParameterConfig( + ptype=ParameterType.CONSTANT_KERNEL, + config_file=Path("tests/data/test_base.attitude.ck.json"), + data={ + "current_value": 
[10.0, 20.0, 30.0], + "bounds": [-60.0, 60.0], + "sigma": 6.0, + "units": "arcseconds", + }, + ) + + +@pytest.fixture +def param_constant_zero() -> ParameterConfig: + """CONSTANT_KERNEL: all axes at zero with no sigma → always returns [0, 0, 0].""" + return ParameterConfig( + ptype=ParameterType.CONSTANT_KERNEL, + config_file=Path("tests/data/test_base.attitude.ck.json"), + data={ + "current_value": [0.0, 0.0, 0.0], + "bounds": [-10.0, 10.0], + "sigma": None, + "units": "arcseconds", + }, + ) + + +@pytest.fixture +def param_offset_kernel() -> ParameterConfig: + """OFFSET_KERNEL: angle bias in arcseconds.""" + return ParameterConfig( + ptype=ParameterType.OFFSET_KERNEL, + config_file=Path("tests/data/test_az.attitude.ck.json"), + data={ + "field": "hps.az_ang_nonlin", + "current_value": 0.0, + "bounds": [-3600.0, 3600.0], + "sigma": 360.0, + "units": "arcseconds", + }, + ) + + +@pytest.fixture +def param_offset_time() -> ParameterConfig: + """OFFSET_TIME: timing bias in milliseconds.""" + return ParameterConfig( + ptype=ParameterType.OFFSET_TIME, + config_file=None, + data={ + "field": "corrected_timestamp", + "current_value": 0.0, + "bounds": [-50.0, 50.0], + "sigma": 7.0, + "units": "milliseconds", + }, + ) + + +def _make_config( + geo, + params, + *, + strategy: SearchStrategy = SearchStrategy.RANDOM, + n_iterations: int = 5, + seed: int | None = 0, + grid_points_per_param: int = 4, +) -> CorrectionConfig: + return CorrectionConfig( + seed=seed, + n_iterations=n_iterations, + parameters=params, + search_strategy=strategy, + grid_points_per_param=grid_points_per_param, + geo=geo, + performance_threshold_m=250.0, + performance_spec_percent=39.0, + ) + + +# =========================================================================== +# _get_nominal_value +# =========================================================================== + + +class TestGetNominalValue: + def test_constant_kernel_list(self, param_constant): + """Nominal CONSTANT_KERNEL returns DataFrame 
with converted arcsecond values.""" + df = _get_nominal_value(param_constant) + assert isinstance(df, pd.DataFrame) + assert set(df.columns) >= {"ugps", "angle_x", "angle_y", "angle_z"} + # current_value = [10, 20, 30] arcsec → radians + expected_x = np.deg2rad(10.0 / 3600.0) + expected_y = np.deg2rad(20.0 / 3600.0) + expected_z = np.deg2rad(30.0 / 3600.0) + np.testing.assert_allclose(df["angle_x"].iloc[0], expected_x, rtol=1e-10) + np.testing.assert_allclose(df["angle_y"].iloc[0], expected_y, rtol=1e-10) + np.testing.assert_allclose(df["angle_z"].iloc[0], expected_z, rtol=1e-10) + + def test_constant_kernel_zero(self, param_constant_zero): + """Nominal CONSTANT_KERNEL with zeros returns zero-angle DataFrame.""" + df = _get_nominal_value(param_constant_zero) + assert df["angle_x"].iloc[0] == 0.0 + assert df["angle_y"].iloc[0] == 0.0 + assert df["angle_z"].iloc[0] == 0.0 + + def test_offset_kernel(self, param_offset_kernel): + """Nominal OFFSET_KERNEL: current_value=0 arcsec → 0.0 rad.""" + val = _get_nominal_value(param_offset_kernel) + assert isinstance(val, float) + assert val == pytest.approx(0.0) + + def test_offset_time_ms(self, param_offset_time): + """Nominal OFFSET_TIME: current_value=0 ms → 0.0 s.""" + val = _get_nominal_value(param_offset_time) + assert isinstance(val, float) + assert val == pytest.approx(0.0) + + def test_offset_time_nonzero(self): + """Non-zero current_value is correctly converted ms → s.""" + p = ParameterConfig( + ptype=ParameterType.OFFSET_TIME, + data={"current_value": 500.0, "bounds": [-100.0, 100.0], "units": "milliseconds"}, + ) + val = _get_nominal_value(p) + assert val == pytest.approx(0.5) + + +# =========================================================================== +# _get_grid_values +# =========================================================================== + + +class TestGetGridValues: + def test_offset_time_count(self, param_offset_time): + vals = _get_grid_values(param_offset_time, 6) + assert len(vals) == 6 + + 
def test_offset_time_endpoints(self, param_offset_time): + """Endpoints must be current_value + bounds[0] and current_value + bounds[1] in seconds.""" + vals = _get_grid_values(param_offset_time, 5) + # current_value=0, bounds=[-50, 50] ms → [-0.05, 0.05] s + assert vals[0] == pytest.approx(-0.05) + assert vals[-1] == pytest.approx(0.05) + + def test_offset_time_evenly_spaced(self, param_offset_time): + vals = _get_grid_values(param_offset_time, 10) + diffs = np.diff(vals) + np.testing.assert_allclose(diffs, diffs[0], rtol=1e-10) + + def test_offset_kernel_arcseconds(self, param_offset_kernel): + """OFFSET_KERNEL with arcsecond units: bounds converted to radians.""" + vals = _get_grid_values(param_offset_kernel, 3) + assert len(vals) == 3 + low_rad = np.deg2rad(-3600.0 / 3600.0) # = -π/180 rad + high_rad = np.deg2rad(3600.0 / 3600.0) # = +π/180 rad + assert vals[0] == pytest.approx(low_rad) + assert vals[-1] == pytest.approx(high_rad) + + def test_constant_kernel_returns_dataframes(self, param_constant): + """CONSTANT_KERNEL grid returns a list of DataFrames.""" + vals = _get_grid_values(param_constant, 4) + assert len(vals) == 4 + for df in vals: + assert isinstance(df, pd.DataFrame) + assert "angle_x" in df.columns + + def test_constant_kernel_offset_applied_uniformly(self, param_constant_zero): + """Uniform offset applied to all 3 axes for CONSTANT_KERNEL.""" + # param_constant_zero: current=[0,0,0], bounds=[-10,10] arcsec + vals = _get_grid_values(param_constant_zero, 3) + for df in vals: + x = df["angle_x"].iloc[0] + y = df["angle_y"].iloc[0] + z = df["angle_z"].iloc[0] + assert x == pytest.approx(y), "All 3 axes should share the same offset for zero current_value" + assert y == pytest.approx(z), "All 3 axes should share the same offset for zero current_value" + + def test_constant_kernel_endpoint_magnitudes(self, param_constant_zero): + """First and last grid DataFrames have angles matching the converted bounds.""" + vals = 
_get_grid_values(param_constant_zero, 2) + low_rad = np.deg2rad(-10.0 / 3600.0) + high_rad = np.deg2rad(10.0 / 3600.0) + assert vals[0]["angle_x"].iloc[0] == pytest.approx(low_rad) + assert vals[-1]["angle_x"].iloc[0] == pytest.approx(high_rad) + + def test_offset_time_microseconds(self): + """Microsecond units are converted correctly.""" + p = ParameterConfig( + ptype=ParameterType.OFFSET_TIME, + data={"current_value": 0.0, "bounds": [-1_000_000.0, 1_000_000.0], "units": "microseconds"}, + ) + vals = _get_grid_values(p, 3) + assert vals[0] == pytest.approx(-1.0) + assert vals[-1] == pytest.approx(1.0) + + +# =========================================================================== +# SearchStrategy.RANDOM (default behaviour) +# =========================================================================== + + +class TestRandomStrategy: + def test_output_length(self, geo, param_offset_time): + config = _make_config(geo, [param_offset_time], strategy=SearchStrategy.RANDOM, n_iterations=7) + sets = load_param_sets(config) + assert len(sets) == 7 + + def test_inner_length(self, geo, param_offset_kernel, param_offset_time): + config = _make_config( + geo, [param_offset_kernel, param_offset_time], strategy=SearchStrategy.RANDOM, n_iterations=3 + ) + sets = load_param_sets(config) + assert all(len(s) == 2 for s in sets) + + def test_reproducible_with_seed(self, geo, param_offset_time): + config = _make_config(geo, [param_offset_time], strategy=SearchStrategy.RANDOM, n_iterations=5, seed=42) + sets_a = load_param_sets(config) + sets_b = load_param_sets(config) + assert len(sets_a) == len(sets_b) + for param_set_a, param_set_b in zip(sets_a, sets_b): + for (_, a), (_, b) in zip(param_set_a, param_set_b): + assert a == pytest.approx(b) + + def test_random_values_within_bounds(self, geo, param_offset_time): + """All sampled time offsets lie within [bounds_low, bounds_high] in seconds.""" + config = _make_config(geo, [param_offset_time], strategy=SearchStrategy.RANDOM, 
n_iterations=50, seed=7) + sets = load_param_sets(config) + low_s, high_s = -0.05, 0.05 # bounds=[-50, 50] ms → seconds + for param_set in sets: + _, val = param_set[0] + assert low_s <= val <= high_s + + def test_constant_kernel_returns_dataframe(self, geo, param_constant): + config = _make_config(geo, [param_constant], strategy=SearchStrategy.RANDOM, n_iterations=2) + sets = load_param_sets(config) + for param_set in sets: + _, df = param_set[0] + assert isinstance(df, pd.DataFrame) + assert "angle_x" in df.columns + assert "angle_y" in df.columns + assert "angle_z" in df.columns + + def test_offset_kernel_returns_float(self, geo, param_offset_kernel): + config = _make_config(geo, [param_offset_kernel], strategy=SearchStrategy.RANDOM, n_iterations=3) + sets = load_param_sets(config) + for param_set in sets: + _, val = param_set[0] + assert isinstance(val, (float, np.floating)) + + def test_no_sigma_returns_fixed_value(self, geo, param_constant_zero): + """Parameter with sigma=None stays fixed at nominal across all iterations.""" + config = _make_config(geo, [param_constant_zero], strategy=SearchStrategy.RANDOM, n_iterations=10, seed=0) + sets = load_param_sets(config) + first_x = sets[0][0][1]["angle_x"].iloc[0] + for param_set in sets: + _, df = param_set[0] + assert df["angle_x"].iloc[0] == pytest.approx(first_x) + + +# =========================================================================== +# SearchStrategy.GRID_SEARCH +# =========================================================================== + + +class TestGridSearchStrategy: + def test_single_param_count(self, geo, param_offset_time): + """1 parameter × 5 grid points → 5 parameter sets.""" + config = _make_config(geo, [param_offset_time], strategy=SearchStrategy.GRID_SEARCH, grid_points_per_param=5) + sets = load_param_sets(config) + assert len(sets) == 5 + + def test_two_params_cartesian_product(self, geo, param_offset_kernel, param_offset_time): + """2 parameters × 4 grid points → 4² = 16 parameter 
sets.""" + config = _make_config( + geo, + [param_offset_kernel, param_offset_time], + strategy=SearchStrategy.GRID_SEARCH, + grid_points_per_param=4, + ) + sets = load_param_sets(config) + assert len(sets) == 16 + + def test_three_params_cartesian_product(self, geo, param_constant, param_offset_kernel, param_offset_time): + """3 parameters × 3 grid points → 3³ = 27 parameter sets.""" + config = _make_config( + geo, + [param_constant, param_offset_kernel, param_offset_time], + strategy=SearchStrategy.GRID_SEARCH, + grid_points_per_param=3, + ) + sets = load_param_sets(config) + assert len(sets) == 27 + + def test_inner_set_length(self, geo, param_offset_kernel, param_offset_time): + config = _make_config( + geo, + [param_offset_kernel, param_offset_time], + strategy=SearchStrategy.GRID_SEARCH, + grid_points_per_param=3, + ) + sets = load_param_sets(config) + assert all(len(s) == 2 for s in sets) + + def test_values_span_full_bounds(self, geo, param_offset_time): + """First and last values in single-param grid span the full converted bounds.""" + config = _make_config(geo, [param_offset_time], strategy=SearchStrategy.GRID_SEARCH, grid_points_per_param=5) + sets = load_param_sets(config) + vals = [s[0][1] for s in sets] + assert min(vals) == pytest.approx(-0.05) + assert max(vals) == pytest.approx(0.05) + + def test_values_are_evenly_spaced(self, geo, param_offset_time): + config = _make_config(geo, [param_offset_time], strategy=SearchStrategy.GRID_SEARCH, grid_points_per_param=6) + sets = load_param_sets(config) + vals = [s[0][1] for s in sets] + diffs = np.diff(vals) + np.testing.assert_allclose(diffs, diffs[0], rtol=1e-10) + + def test_deterministic_no_seed_needed(self, geo, param_offset_time): + """GRID_SEARCH is deterministic regardless of seed.""" + config_a = _make_config( + geo, [param_offset_time], strategy=SearchStrategy.GRID_SEARCH, grid_points_per_param=4, seed=None + ) + config_b = _make_config( + geo, [param_offset_time], 
strategy=SearchStrategy.GRID_SEARCH, grid_points_per_param=4, seed=99 + ) + sets_a = load_param_sets(config_a) + sets_b = load_param_sets(config_b) + assert len(sets_a) == len(sets_b) + for param_set_a, param_set_b in zip(sets_a, sets_b): + for (_, a), (_, b) in zip(param_set_a, param_set_b): + assert a == pytest.approx(b) + + def test_constant_kernel_in_grid(self, geo, param_constant_zero): + """GRID_SEARCH on CONSTANT_KERNEL yields DataFrames with monotone angles.""" + config = _make_config(geo, [param_constant_zero], strategy=SearchStrategy.GRID_SEARCH, grid_points_per_param=4) + sets = load_param_sets(config) + assert len(sets) == 4 + angle_xs = [s[0][1]["angle_x"].iloc[0] for s in sets] + # Values should be monotonically increasing (linspace low→high) + assert all(angle_xs[i] <= angle_xs[i + 1] for i in range(len(angle_xs) - 1)) + + def test_n_iterations_ignored(self, geo, param_offset_time): + """n_iterations has no effect on GRID_SEARCH output count.""" + config = _make_config( + geo, + [param_offset_time], + strategy=SearchStrategy.GRID_SEARCH, + grid_points_per_param=5, + n_iterations=1000, # ignored + ) + sets = load_param_sets(config) + assert len(sets) == 5 + + def test_guardrail_raises_when_total_exceeds_limit(self, geo, param_offset_kernel, param_offset_time): + """GRID_SEARCH raises ValueError when cartesian product would exceed max_grid_sets.""" + # 10 points × 2 params = 100 sets → set limit to 99 to trigger the guard + config = _make_config( + geo, + [param_offset_kernel, param_offset_time], + strategy=SearchStrategy.GRID_SEARCH, + grid_points_per_param=10, + ) + # Override the limit below what the sweep would produce (100 sets) + config = config.model_copy(update={"max_grid_sets": 99}) + with pytest.raises(ValueError, match="exceeds the safety limit"): + load_param_sets(config) + + def test_guardrail_passes_when_limit_raised(self, geo, param_offset_time): + """Explicitly raising max_grid_sets allows larger sweeps through.""" + # 5 points × 1 
param = 5 sets; set limit to 5 exactly — should succeed + config = _make_config( + geo, + [param_offset_time], + strategy=SearchStrategy.GRID_SEARCH, + grid_points_per_param=5, + ) + config = config.model_copy(update={"max_grid_sets": 5}) + sets = load_param_sets(config) + assert len(sets) == 5 + + +# =========================================================================== +# SearchStrategy.SINGLE_OFFSET +# =========================================================================== + + +class TestSingleOffsetStrategy: + def test_single_param_count(self, geo, param_offset_time): + """1 parameter × n_iterations values → n_iterations parameter sets.""" + config = _make_config(geo, [param_offset_time], strategy=SearchStrategy.SINGLE_OFFSET, n_iterations=8) + sets = load_param_sets(config) + assert len(sets) == 8 + + def test_two_params_count(self, geo, param_offset_kernel, param_offset_time): + """2 parameters × 5 values each → 10 total parameter sets.""" + config = _make_config( + geo, + [param_offset_kernel, param_offset_time], + strategy=SearchStrategy.SINGLE_OFFSET, + n_iterations=5, + ) + sets = load_param_sets(config) + assert len(sets) == 10 + + def test_time_offset_sweep_spans_bounds(self, geo, param_offset_time): + """SINGLE_OFFSET sweep of OFFSET_TIME spans the full converted bounds.""" + config = _make_config(geo, [param_offset_time], strategy=SearchStrategy.SINGLE_OFFSET, n_iterations=5) + sets = load_param_sets(config) + vals = [s[0][1] for s in sets] + assert min(vals) == pytest.approx(-0.05) + assert max(vals) == pytest.approx(0.05) + + def test_time_offset_sweep_evenly_spaced(self, geo, param_offset_time): + config = _make_config(geo, [param_offset_time], strategy=SearchStrategy.SINGLE_OFFSET, n_iterations=7) + sets = load_param_sets(config) + vals = [s[0][1] for s in sets] + diffs = np.diff(vals) + np.testing.assert_allclose(diffs, diffs[0], rtol=1e-10) + + def test_non_swept_params_held_at_nominal(self, geo, param_offset_kernel, param_offset_time): 
+ """While sweeping param_0, param_1 must equal its nominal value in every set.""" + config = _make_config( + geo, + [param_offset_kernel, param_offset_time], + strategy=SearchStrategy.SINGLE_OFFSET, + n_iterations=4, + ) + sets = load_param_sets(config) + nominal_time = _get_nominal_value(param_offset_time) + # First 4 sets sweep param_0 (OFFSET_KERNEL); param_1 (time) should be nominal + for param_set in sets[:4]: + _, time_val = param_set[1] + assert time_val == pytest.approx(nominal_time) + + def test_swept_param_changes_others_fixed(self, geo, param_offset_kernel, param_offset_time): + """While sweeping param_1 (time), param_0 (kernel) stays at nominal in every set.""" + config = _make_config( + geo, + [param_offset_kernel, param_offset_time], + strategy=SearchStrategy.SINGLE_OFFSET, + n_iterations=4, + ) + sets = load_param_sets(config) + nominal_kernel = _get_nominal_value(param_offset_kernel) + # Last 4 sets sweep param_1 (time); param_0 (kernel) should be nominal + for param_set in sets[4:]: + _, kernel_val = param_set[0] + assert kernel_val == pytest.approx(nominal_kernel) + + def test_deterministic(self, geo, param_offset_time): + """SINGLE_OFFSET is deterministic: two calls with same config return identical results.""" + config = _make_config( + geo, [param_offset_time], strategy=SearchStrategy.SINGLE_OFFSET, n_iterations=5, seed=None + ) + sets_a = load_param_sets(config) + sets_b = load_param_sets(config) + assert len(sets_a) == len(sets_b) + for param_set_a, param_set_b in zip(sets_a, sets_b): + for (_, a), (_, b) in zip(param_set_a, param_set_b): + assert a == pytest.approx(b) + + def test_constant_kernel_sweep(self, geo, param_constant_zero): + """SINGLE_OFFSET on CONSTANT_KERNEL sweeps angle magnitudes monotonically.""" + config = _make_config(geo, [param_constant_zero], strategy=SearchStrategy.SINGLE_OFFSET, n_iterations=5) + sets = load_param_sets(config) + assert len(sets) == 5 + angle_xs = [s[0][1]["angle_x"].iloc[0] for s in sets] + assert 
all(angle_xs[i] <= angle_xs[i + 1] for i in range(len(angle_xs) - 1)) + + +# =========================================================================== +# Config validation +# =========================================================================== + + +class TestConfigValidation: + def test_grid_points_per_param_minimum(self, geo, param_offset_time): + """grid_points_per_param must be >= 2.""" + with pytest.raises(ValidationError) as exc_info: + _make_config( + geo, + [param_offset_time], + strategy=SearchStrategy.GRID_SEARCH, + grid_points_per_param=1, + ) + errors = exc_info.value.errors() + assert any("grid_points_per_param" in err.get("loc", ()) for err in errors) + + def test_search_strategy_default_is_random(self, geo, param_offset_time): + config = CorrectionConfig( + seed=0, + n_iterations=3, + parameters=[param_offset_time], + geo=geo, + performance_threshold_m=250.0, + performance_spec_percent=39.0, + ) + assert config.search_strategy == SearchStrategy.RANDOM + + def test_search_strategy_enum_values(self): + assert SearchStrategy("random") is SearchStrategy.RANDOM + assert SearchStrategy("grid") is SearchStrategy.GRID_SEARCH + assert SearchStrategy("single") is SearchStrategy.SINGLE_OFFSET + + def test_invalid_strategy_string_rejected(self, geo, param_offset_time): + with pytest.raises(ValidationError): + _make_config( + geo, + [param_offset_time], + strategy="not_a_strategy", # type: ignore[arg-type] + ) + + def test_json_round_trip_random(self, geo, param_offset_time): + """RANDOM config survives model_dump_json / model_validate_json round-trip.""" + config = _make_config(geo, [param_offset_time], strategy=SearchStrategy.RANDOM) + json_str = config.model_dump_json() + restored = CorrectionConfig.model_validate_json(json_str) + assert restored.search_strategy == SearchStrategy.RANDOM + assert restored.n_iterations == config.n_iterations + + def test_json_round_trip_grid(self, geo, param_offset_time): + config = _make_config(geo, [param_offset_time], 
strategy=SearchStrategy.GRID_SEARCH, grid_points_per_param=7) + restored = CorrectionConfig.model_validate_json(config.model_dump_json()) + assert restored.search_strategy == SearchStrategy.GRID_SEARCH + assert restored.grid_points_per_param == 7 + assert restored.max_grid_sets == config.max_grid_sets + + def test_json_round_trip_single_offset(self, geo, param_offset_time): + config = _make_config(geo, [param_offset_time], strategy=SearchStrategy.SINGLE_OFFSET, n_iterations=12) + restored = CorrectionConfig.model_validate_json(config.model_dump_json()) + assert restored.search_strategy == SearchStrategy.SINGLE_OFFSET + assert restored.n_iterations == 12 + + +# =========================================================================== +# Strategy ↔ output type consistency +# =========================================================================== + + +class TestOutputTypeConsistency: + """Ensure every strategy returns the correct element types for all param types.""" + + @pytest.mark.parametrize( + "strategy", + [SearchStrategy.RANDOM, SearchStrategy.GRID_SEARCH, SearchStrategy.SINGLE_OFFSET], + ) + def test_constant_kernel_always_dataframe(self, strategy, geo, param_constant_zero): + config = _make_config(geo, [param_constant_zero], strategy=strategy, n_iterations=3, grid_points_per_param=3) + sets = load_param_sets(config) + for param_set in sets: + _, val = param_set[0] + assert isinstance(val, pd.DataFrame), f"Expected DataFrame for {strategy}, got {type(val)}" + assert {"ugps", "angle_x", "angle_y", "angle_z"}.issubset(val.columns) + + @pytest.mark.parametrize( + "strategy", + [SearchStrategy.RANDOM, SearchStrategy.GRID_SEARCH, SearchStrategy.SINGLE_OFFSET], + ) + def test_offset_kernel_always_float(self, strategy, geo, param_offset_kernel): + config = _make_config(geo, [param_offset_kernel], strategy=strategy, n_iterations=3, grid_points_per_param=3) + sets = load_param_sets(config) + for param_set in sets: + _, val = param_set[0] + assert isinstance(val, 
(float, np.floating)), f"Expected float for {strategy}, got {type(val)}" + + @pytest.mark.parametrize( + "strategy", + [SearchStrategy.RANDOM, SearchStrategy.GRID_SEARCH, SearchStrategy.SINGLE_OFFSET], + ) + def test_offset_time_always_float(self, strategy, geo, param_offset_time): + config = _make_config(geo, [param_offset_time], strategy=strategy, n_iterations=3, grid_points_per_param=3) + sets = load_param_sets(config) + for param_set in sets: + _, val = param_set[0] + assert isinstance(val, (float, np.floating)), f"Expected float for {strategy}, got {type(val)}" + + +# =========================================================================== +# Logging behaviour +# =========================================================================== + + +class TestLogging: + """Verify _log_param_set_summary emits per-set detail only at DEBUG.""" + + def test_per_set_detail_suppressed_at_info(self, geo, param_offset_time, caplog): + """With log level INFO, per-set lines must NOT appear in the log output.""" + import logging + + config = _make_config(geo, [param_offset_time], strategy=SearchStrategy.GRID_SEARCH, grid_points_per_param=4) + with caplog.at_level(logging.INFO, logger="curryer.correction.parameters"): + load_param_sets(config) + + # The high-level count line should be present + assert any("Generated 4 parameter sets" in r.message for r in caplog.records) + # Individual "Set N:" detail lines must NOT appear at INFO + assert not any(r.message.startswith(" Set ") for r in caplog.records) + + def test_per_set_detail_present_at_debug(self, geo, param_offset_time, caplog): + """With log level DEBUG, per-set detail lines DO appear.""" + import logging + + config = _make_config(geo, [param_offset_time], strategy=SearchStrategy.GRID_SEARCH, grid_points_per_param=3) + with caplog.at_level(logging.DEBUG, logger="curryer.correction.parameters"): + load_param_sets(config) + + # Expect exactly 3 "Set N:" lines (one per grid point) + set_lines = [r for r in caplog.records 
if r.message.strip().startswith("Set ")] + assert len(set_lines) == 3 diff --git a/tests/test_correction/test_pipeline.py b/tests/test_correction/test_pipeline.py new file mode 100644 index 00000000..1afb7a5a --- /dev/null +++ b/tests/test_correction/test_pipeline.py @@ -0,0 +1,250 @@ +"""Tests for ``curryer.correction.pipeline``. + +Covers: +- ``_extract_parameter_values`` +- ``_extract_error_metrics`` +- ``_store_parameter_values`` +- ``_store_gcp_pair_results`` +- ``_compute_parameter_set_metrics`` +- ``_load_image_pair_data`` +- ``_extract_spacecraft_position_midframe`` (position_columns feature) +- ``loop`` (optimised pair-outer, ``@pytest.mark.extra``) +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from unittest.mock import MagicMock + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from _synthetic_helpers import synthetic_image_matching +from clarreo_config import create_clarreo_correction_config +from clarreo_data_loaders import load_clarreo_science, load_clarreo_telemetry + +from curryer.correction import correction +from curryer.correction.config import DataConfig +from curryer.correction.pipeline import _extract_spacecraft_position_midframe + +logger = logging.getLogger(__name__) + + +# ── shared fixtures ─────────────────────────────────────────────────────────── + + +@pytest.fixture(scope="module") +def clarreo_cfg(root_dir): + data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" + generic_dir = root_dir / "data" / "generic" + return create_clarreo_correction_config(data_dir, generic_dir) + + +# ── tests ───────────────────────────────────────────────────────────────────── + + +def test_extract_parameter_values(): + """_extract_parameter_values returns roll/pitch/yaw keys.""" + param_config = correction.ParameterConfig( + ptype=correction.ParameterType.CONSTANT_KERNEL, config_file=Path("test_kernel.json"), data=None + ) + param_data = pd.DataFrame( + { + "angle_x": 
[np.radians(1.0 / 3600)], + "angle_y": [np.radians(2.0 / 3600)], + "angle_z": [np.radians(3.0 / 3600)], + } + ) + result = correction._extract_parameter_values([(param_config, param_data)]) + assert isinstance(result, dict) + assert len(result) == 3 + assert "test_kernel_roll" in result + assert "test_kernel_pitch" in result + assert "test_kernel_yaw" in result + + +def test_extract_error_metrics(): + """_extract_error_metrics pulls named metrics from a Dataset.""" + ds = xr.Dataset({"lat_error_deg": (["pt"], [0.001, 0.002])}) + ds.attrs.update( + { + "rms_error_m": 150.0, + "mean_error_m": 140.0, + "max_error_m": 200.0, + "std_error_m": 10.0, + "total_measurements": 2, + } + ) + m = correction._extract_error_metrics(ds) + assert m["rms_error_m"] == 150.0 + assert m["n_measurements"] == 2 + + +def test_store_parameter_values(): + """_store_parameter_values writes values at the correct index.""" + netcdf_data = {"parameter_set_id": np.zeros(3, dtype=int), "param_foo": np.zeros(3)} + correction._store_parameter_values(netcdf_data, param_idx=1, param_values={"foo": 2.5}) + assert netcdf_data["param_foo"][1] == pytest.approx(2.5) + + +def test_store_gcp_pair_results(): + """_store_gcp_pair_results populates all metric arrays correctly.""" + nc = {k: np.zeros((2, 2)) for k in ("rms_error_m", "mean_error_m", "max_error_m", "std_error_m")} + nc["n_measurements"] = np.zeros((2, 2), dtype=int) + metrics = { + "rms_error_m": 150.0, + "mean_error_m": 140.0, + "max_error_m": 200.0, + "std_error_m": 10.0, + "n_measurements": 10, + } + correction._store_gcp_pair_results(nc, param_idx=0, pair_idx=1, error_metrics=metrics) + assert nc["rms_error_m"][0, 1] == 150.0 + assert nc["std_error_m"][0, 1] == 10.0 + assert nc["n_measurements"][0, 1] == 10 + + +def test_compute_parameter_set_metrics(): + """_compute_parameter_set_metrics populates aggregate stats.""" + nc = { + "percent_under_250m": np.zeros(2), + "mean_rms_all_pairs": np.zeros(2), + "best_pair_rms": np.zeros(2), + 
"worst_pair_rms": np.zeros(2), + } + correction._compute_parameter_set_metrics(nc, param_idx=0, pair_errors=[100.0, 200.0, 300.0], threshold_m=250.0) + assert nc["percent_under_250m"][0] > 0 + assert nc["best_pair_rms"][0] == 100.0 + assert nc["worst_pair_rms"][0] == 300.0 + + +def test_load_image_pair_data(root_dir, clarreo_cfg, tmp_path): + """_load_image_pair_data returns DataFrames for tlm and sci.""" + data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" + tlm_csv = tmp_path / "tlm.csv" + sci_csv = tmp_path / "sci.csv" + load_clarreo_telemetry(data_dir).to_csv(tlm_csv) + load_clarreo_science(data_dir).to_csv(sci_csv) + cfg = clarreo_cfg.model_copy(deep=True) + cfg.data = DataConfig(file_format="csv", time_scale_factor=1e6) + tlm_ds, sci_ds, ugps = correction._load_image_pair_data(str(tlm_csv), str(sci_csv), cfg) + assert isinstance(tlm_ds, pd.DataFrame) + assert isinstance(sci_ds, pd.DataFrame) + assert ugps is not None + + +@pytest.mark.extra +def test_loop_optimized(root_dir, tmp_path): + """loop() produces correct result structure. 
Requires GMTED – ``--run-extra``.""" + data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" + generic_dir = root_dir / "data" / "generic" + config = create_clarreo_correction_config(data_dir, generic_dir) + config.n_iterations = 2 + config.output_filename = "test_loop.nc" + work = tmp_path / "loop" + work.mkdir() + tlm_csv, sci_csv = work / "tlm.csv", work / "sci.csv" + load_clarreo_telemetry(data_dir).to_csv(tlm_csv) + load_clarreo_science(data_dir).to_csv(sci_csv) + config.data = DataConfig(file_format="csv", time_scale_factor=1e6) + config._image_matching_override = synthetic_image_matching + sets = [(str(tlm_csv), str(sci_csv), "synthetic_gcp.mat")] + np.random.seed(42) + results, nc = correction.loop(config, work, sets, resume_from_checkpoint=False) + assert isinstance(results, list) + assert len(results) > 0 + assert len(results) == config.n_iterations * len(sets) + assert nc["rms_error_m"].shape == (config.n_iterations, len(sets)) + for r in results: + assert "param_index" in r + assert "pair_index" in r + assert "rms_error_m" in r + assert r["aggregate_rms_error_m"] is not None + assert isinstance(r["aggregate_rms_error_m"], (int, float, np.number)) + + +# ── _extract_spacecraft_position_midframe ───────────────────────────────────── + + +def _make_telemetry() -> pd.DataFrame: + """Return a 3-row telemetry DataFrame with standard column names.""" + return pd.DataFrame( + { + "sc_pos_x": [1.0, 2.0, 3.0], + "sc_pos_y": [4.0, 5.0, 6.0], + "sc_pos_z": [7.0, 8.0, 9.0], + } + ) + + +class TestExtractSpacecraftPositionMidframe: + """Tests for _extract_spacecraft_position_midframe with position_columns.""" + + def test_explicit_position_columns_used(self): + """config.data.position_columns should be used directly.""" + telemetry = pd.DataFrame( + { + "my_x": [1.0, 2.0, 3.0], + "my_y": [4.0, 5.0, 6.0], + "my_z": [7.0, 8.0, 9.0], + } + ) + config = MagicMock() + config.data = DataConfig(position_columns=["my_x", "my_y", "my_z"]) + + result = 
_extract_spacecraft_position_midframe(telemetry, config=config) + + np.testing.assert_array_equal(result, [2.0, 5.0, 8.0]) # mid_idx = 1 + + def test_explicit_position_columns_returns_float64(self): + """Result should be a float64 ndarray of shape (3,).""" + telemetry = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + config = MagicMock() + config.data = DataConfig(position_columns=["a", "b", "c"]) + + result = _extract_spacecraft_position_midframe(telemetry, config=config) + + assert result.shape == (3,) + assert result.dtype == np.float64 + + def test_position_columns_wrong_length_raises_valueerror(self): + """position_columns with != 3 entries should raise ValueError.""" + telemetry = pd.DataFrame({"x": [1.0], "y": [2.0]}) + config = MagicMock() + config.data = DataConfig(position_columns=["x", "y"]) + + with pytest.raises(ValueError, match="exactly 3 entries"): + _extract_spacecraft_position_midframe(telemetry, config=config) + + def test_position_columns_missing_column_raises_valueerror(self): + """position_columns referencing nonexistent columns should raise ValueError.""" + telemetry = pd.DataFrame({"x": [1.0], "y": [2.0], "z": [3.0]}) + config = MagicMock() + config.data = DataConfig(position_columns=["x", "y", "MISSING"]) + + with pytest.raises(ValueError, match="not found in telemetry"): + _extract_spacecraft_position_midframe(telemetry, config=config) + + def test_no_position_columns_falls_back_with_warning(self, caplog): + """When position_columns is None, fall back to pattern-guessing with warning.""" + telemetry = _make_telemetry() + config = MagicMock() + config.data = None # position_columns not configured + + with caplog.at_level(logging.WARNING, logger="curryer.correction.pipeline"): + result = _extract_spacecraft_position_midframe(telemetry, config=config) + + assert "position_columns not configured" in caplog.text + np.testing.assert_array_equal(result, [2.0, 5.0, 8.0]) + + def test_no_config_falls_back_to_pattern_guessing(self, 
caplog): + """When config=None entirely, pattern-guessing is used (backward compat).""" + telemetry = _make_telemetry() + + with caplog.at_level(logging.WARNING, logger="curryer.correction.pipeline"): + result = _extract_spacecraft_position_midframe(telemetry, config=None) + + assert "position_columns not configured" in caplog.text + np.testing.assert_array_equal(result, [2.0, 5.0, 8.0]) diff --git a/tests/test_correction/test_regrid.py b/tests/test_correction/test_regrid.py new file mode 100644 index 00000000..e9c00f31 --- /dev/null +++ b/tests/test_correction/test_regrid.py @@ -0,0 +1,349 @@ +"""Unit tests for regrid module. + +Tests for GCP chip regridding algorithms. +""" + +import numpy as np +import pytest + +from curryer.correction.data_structures import ImageGrid, RegridConfig +from curryer.correction.regrid import ( + bilinear_interpolate_quad, + compute_regular_grid_bounds, + create_regular_grid, + find_containing_cell, + point_in_triangle, + regrid_gcp_chip, + regrid_irregular_to_regular, +) + + +class TestRegridConfig: + """Test RegridConfig validation.""" + + def test_resolution_based_config(self): + """Test resolution-based configuration.""" + config = RegridConfig(output_resolution_deg=(0.001, 0.001)) + assert config.output_resolution_deg == (0.001, 0.001) + assert config.output_grid_size is None + assert config.conservative_bounds is True + + def test_size_based_config(self): + """Test size-based configuration.""" + config = RegridConfig(output_grid_size=(500, 500)) + assert config.output_grid_size == (500, 500) + assert config.output_resolution_deg is None + + def test_bounds_plus_resolution(self): + """Test bounds + resolution configuration.""" + config = RegridConfig(output_bounds=(-116.0, -115.0, 38.0, 39.0), output_resolution_deg=(0.001, 0.001)) + assert config.output_bounds is not None + assert config.output_resolution_deg is not None + + def test_invalid_size_and_resolution(self): + """Test that size + resolution raises error.""" + with 
pytest.raises(ValueError, match="Cannot specify both"): + RegridConfig(output_grid_size=(500, 500), output_resolution_deg=(0.001, 0.001)) + + def test_invalid_bounds_without_resolution(self): + """Test that bounds without resolution raises error.""" + with pytest.raises(ValueError, match="requires output_resolution_deg"): + RegridConfig(output_bounds=(-116.0, -115.0, 38.0, 39.0)) + + def test_invalid_interpolation_method(self): + """Test that invalid interpolation method raises error.""" + with pytest.raises(ValueError, match="interpolation_method must be"): + RegridConfig(output_resolution_deg=(0.001, 0.001), interpolation_method="invalid") + + def test_invalid_grid_size(self): + """Test that too-small grid size raises error.""" + with pytest.raises(ValueError, match="at least 2 rows"): + RegridConfig(output_grid_size=(1, 100)) + + def test_invalid_resolution(self): + """Test that negative resolution raises error.""" + with pytest.raises(ValueError, match="must be positive"): + RegridConfig(output_resolution_deg=(-0.001, 0.001)) + + def test_invalid_bounds(self): + """Test that invalid bounds raise error.""" + with pytest.raises(ValueError, match="minlon must be < maxlon"): + RegridConfig( + output_bounds=(-115.0, -116.0, 38.0, 39.0), # Swapped lon + output_resolution_deg=(0.001, 0.001), + ) + + +class TestGeometricPrimitives: + """Test geometric helper functions.""" + + def test_point_in_triangle_inside(self): + """Test point inside triangle.""" + triangle = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]]) + point = np.array([0.5, 0.3]) + + inside, weights = point_in_triangle(point, triangle) + + assert inside + assert len(weights) == 3 + assert np.abs(np.sum(weights) - 1.0) < 1e-10 # Weights sum to 1 + + def test_point_in_triangle_outside(self): + """Test point outside triangle.""" + triangle = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]]) + point = np.array([2.0, 2.0]) + + inside, weights = point_in_triangle(point, triangle) + + assert not inside + + def 
test_point_in_triangle_on_edge(self): + """Test point on triangle edge.""" + triangle = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]]) + point = np.array([0.5, 0.0]) # On base edge + + inside, weights = point_in_triangle(point, triangle) + + # With tolerant boundary check, point on edge should be inside + assert inside + + def test_bilinear_interpolate_square(self): + """Test bilinear interpolation on a square.""" + # Unit square with values at corners: 0, 1, 2, 3 + corners_lon = np.array([0.0, 1.0, 1.0, 0.0]) + corners_lat = np.array([1.0, 1.0, 0.0, 0.0]) + corner_values = np.array([0.0, 1.0, 2.0, 3.0]) + + # Test center point + point = np.array([0.5, 0.5]) + value = bilinear_interpolate_quad(point, corners_lon, corners_lat, corner_values) + + # At center, should be average of corners + assert np.abs(value - 1.5) < 1e-10 + + def test_bilinear_interpolate_corner(self): + """Test bilinear interpolation at corner.""" + corners_lon = np.array([0.0, 1.0, 1.0, 0.0]) + corners_lat = np.array([1.0, 1.0, 0.0, 0.0]) + corner_values = np.array([10.0, 20.0, 30.0, 40.0]) + + # Test at corner (should equal corner value) + point = np.array([0.0, 1.0]) # Top-left corner + value = bilinear_interpolate_quad(point, corners_lon, corners_lat, corner_values) + + assert np.abs(value - 10.0) < 1e-6 + + +class TestGridOperations: + """Test grid creation and bounds computation.""" + + def test_compute_bounds_conservative(self): + """Test conservative bounds computation.""" + # Create a slightly distorted grid + nrows, ncols = 10, 10 + lat_base = np.linspace(39.0, 38.0, nrows) + lon_base = np.linspace(-116.0, -115.0, ncols) + lon_grid, lat_grid = np.meshgrid(lon_base, lat_base) + + # Add small, deterministic distortion using a fixed RNG seed + rng = np.random.default_rng(0) + lon_grid += 0.01 * rng.standard_normal(size=(nrows, ncols)) + lat_grid += 0.01 * rng.standard_normal(size=(nrows, ncols)) + + minlon, maxlon, minlat, maxlat = compute_regular_grid_bounds(lon_grid, lat_grid, 
conservative=True) + + # Conservative bounds should not extend beyond the full extent + assert minlon >= lon_grid.min() + assert maxlon <= lon_grid.max() + assert minlat >= lat_grid.min() + assert maxlat <= lat_grid.max() + + def test_compute_bounds_full_extent(self): + """Test full extent bounds computation.""" + nrows, ncols = 10, 10 + lat_base = np.linspace(39.0, 38.0, nrows) + lon_base = np.linspace(-116.0, -115.0, ncols) + lon_grid, lat_grid = np.meshgrid(lon_base, lat_base) + + minlon, maxlon, minlat, maxlat = compute_regular_grid_bounds(lon_grid, lat_grid, conservative=False) + + # Full extent should match min/max + assert np.abs(minlon - lon_grid.min()) < 1e-10 + assert np.abs(maxlon - lon_grid.max()) < 1e-10 + assert np.abs(minlat - lat_grid.min()) < 1e-10 + assert np.abs(maxlat - lat_grid.max()) < 1e-10 + + def test_create_regular_grid_from_size(self): + """Test regular grid creation from size.""" + bounds = (-116.0, -115.0, 38.0, 39.0) + grid_size = (100, 100) + + lon_grid, lat_grid = create_regular_grid(bounds, grid_size=grid_size) + + assert lon_grid.shape == grid_size + assert lat_grid.shape == grid_size + + # Check latitude decreases (row index increases going south) + assert lat_grid[0, 0] > lat_grid[-1, 0] + + # Check longitude increases (col index increases going east) + assert lon_grid[0, 0] < lon_grid[0, -1] + + def test_create_regular_grid_from_resolution(self): + """Test regular grid creation from resolution.""" + bounds = (-116.0, -115.0, 38.0, 39.0) + resolution = (0.01, 0.01) # 0.01 degree resolution + + lon_grid, lat_grid = create_regular_grid(bounds, resolution=resolution) + + # Check expected dimensions (1 degree / 0.01 + 1) + assert lon_grid.shape[0] == 101 # (39-38)/0.01 + 1 + assert lon_grid.shape[1] == 101 # (-115-(-116))/0.01 + 1 + + def test_create_regular_grid_requires_one_param(self): + """Test that exactly one of size/resolution is required.""" + bounds = (-116.0, -115.0, 38.0, 39.0) + + # Neither provided + with 
pytest.raises(ValueError, match="Must specify either"): + create_regular_grid(bounds) + + # Both provided + with pytest.raises(ValueError, match="only one of"): + create_regular_grid(bounds, grid_size=(100, 100), resolution=(0.01, 0.01)) + + +class TestRegridding: + """Test core regridding functionality.""" + + def test_find_containing_cell_simple(self): + """Test finding containing cell in regular grid.""" + # Create simple regular grid + nrows, ncols = 5, 5 + lat = np.linspace(39.0, 38.0, nrows) + lon = np.linspace(-116.0, -115.0, ncols) + lon_grid, lat_grid = np.meshgrid(lon, lat) + + # Test point in center of cell [1, 1] (not on boundary) + # Cell [1,1] has corners: (-115.75, 38.75), (-115.5, 38.75), + # (-115.5, 38.5), (-115.75, 38.5) + point = np.array([-115.625, 38.625]) # Center of cell [1,1] + cell = find_containing_cell(point, lon_grid, lat_grid) + + assert cell is not None + assert cell == (1, 1) + + def test_find_containing_cell_outside(self): + """Test point outside grid returns None.""" + nrows, ncols = 5, 5 + lat = np.linspace(39.0, 38.0, nrows) + lon = np.linspace(-116.0, -115.0, ncols) + lon_grid, lat_grid = np.meshgrid(lon, lat) + + # Test point way outside + point = np.array([-120.0, 40.0]) + cell = find_containing_cell(point, lon_grid, lat_grid) + + assert cell is None + + def test_regrid_identity(self): + """Test that regridding to same grid preserves values.""" + # Create regular grid + nrows, ncols = 10, 10 + lat = np.linspace(39.0, 38.0, nrows) + lon = np.linspace(-116.0, -115.0, ncols) + lon_grid, lat_grid = np.meshgrid(lon, lat) + + # Create simple data + data = np.arange(nrows * ncols).reshape(nrows, ncols).astype(float) + + # Regrid to same grid + data_regridded = regrid_irregular_to_regular(data, lon_grid, lat_grid, lon_grid, lat_grid) + + # Should be nearly identical + np.testing.assert_array_almost_equal(data, data_regridded, decimal=6) + + def test_regrid_coarsen(self): + """Test regridding to coarser grid.""" + # Create fine regular 
grid + nrows_fine, ncols_fine = 20, 20 + lat_fine = np.linspace(39.0, 38.0, nrows_fine) + lon_fine = np.linspace(-116.0, -115.0, ncols_fine) + lon_grid_fine, lat_grid_fine = np.meshgrid(lon_fine, lat_fine) + + # Create data with gradient + data_fine = lon_grid_fine + lat_grid_fine + + # Create coarse grid + nrows_coarse, ncols_coarse = 5, 5 + lat_coarse = np.linspace(39.0, 38.0, nrows_coarse) + lon_coarse = np.linspace(-116.0, -115.0, ncols_coarse) + lon_grid_coarse, lat_grid_coarse = np.meshgrid(lon_coarse, lat_coarse) + + # Regrid + data_coarse = regrid_irregular_to_regular( + data_fine, lon_grid_fine, lat_grid_fine, lon_grid_coarse, lat_grid_coarse + ) + + # Check shape + assert data_coarse.shape == (nrows_coarse, ncols_coarse) + + # Check no NaNs in interior + assert not np.any(np.isnan(data_coarse)) + + # Check values are in reasonable range (allow small numerical errors) + # Bilinear interpolation can produce values slightly outside original range + tol = 1e-10 + assert data_coarse.min() >= data_fine.min() - tol + assert data_coarse.max() <= data_fine.max() + tol + + +class TestEndToEnd: + """Test complete regridding workflow.""" + + def test_regrid_gcp_chip_synthetic(self): + """Test regrid_gcp_chip with synthetic data.""" + # Create synthetic chip with ECEF coordinates + nrows, ncols = 50, 50 + + # Generate regular lat/lon grid + lat = np.linspace(38.5, 38.0, nrows) + lon = np.linspace(-116.0, -115.5, ncols) + lon_grid, lat_grid = np.meshgrid(lon, lat) + + # Convert to ECEF (using simplified approach for testing) + # For real test, would use actual geodetic_to_ecef + from curryer.compute.spatial import geodetic_to_ecef + + lla = np.stack([lon_grid.ravel(), lat_grid.ravel(), np.zeros(nrows * ncols)], axis=1) + ecef = geodetic_to_ecef(lla, meters=True, degrees=True) + + ecef_x = ecef[:, 0].reshape(nrows, ncols) + ecef_y = ecef[:, 1].reshape(nrows, ncols) + ecef_z = ecef[:, 2].reshape(nrows, ncols) + + # Create simple data pattern + band_data = np.arange(nrows 
* ncols).reshape(nrows, ncols).astype(float) + + # Regrid to coarser resolution + config = RegridConfig(output_resolution_deg=(0.02, 0.02)) # Coarser + + result = regrid_gcp_chip(band_data, (ecef_x, ecef_y, ecef_z), config) + + # Check result is ImageGrid + assert isinstance(result, ImageGrid) + + # Check output is smaller (coarser resolution) + assert result.data.shape[0] < nrows + assert result.data.shape[1] < ncols + + # Check no NaNs in interior + assert not np.all(np.isnan(result.data)) + + # Check lat/lon grids are regular + lat_spacing = np.diff(result.lat[:, 0]) + assert np.allclose(lat_spacing, lat_spacing[0], rtol=1e-6) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_correction/test_results.py b/tests/test_correction/test_results.py new file mode 100644 index 00000000..c441a635 --- /dev/null +++ b/tests/test_correction/test_results.py @@ -0,0 +1,580 @@ +"""Tests for ``curryer.correction.results`` and Prompt-5 additions. + +Covers +------ +- :class:`ParameterSetResult` — construction and field access +- :class:`CorrectionResult` — construction, JSON serialisation, + raw-data access, ``results``/``netcdf_data`` exclusion from JSON +- :func:`_format_correction_summary_table` — well-formed box-drawn output +- :func:`build_correction_result` — met/not-met threshold, all-NaN fallback +- :func:`compare_results` — side-by-side output format +- :class:`VerificationResult` provenance fields — defaults and population +- Backward-compat: ``loop()`` still returns a 2-tuple +""" + +from __future__ import annotations + +import json +import math +from datetime import datetime, timezone +from pathlib import Path + +import numpy as np +import pytest +import xarray as xr + +from curryer.correction.config import ( + CorrectionConfig, + GeolocationConfig, + ParameterConfig, + ParameterType, + RequirementsConfig, +) +from curryer.correction.results import ( + CorrectionResult, + ParameterSetResult, + _fmt_rms, + 
_format_correction_summary_table, + build_correction_result, +) +from curryer.correction.verification import ( + VerificationResult, + compare_results, +) + +# =========================================================================== +# Shared helpers +# =========================================================================== + +_THRESHOLD_M = 250.0 +_SPEC_PCT = 39.0 + + +def _make_geo() -> GeolocationConfig: + return GeolocationConfig( + meta_kernel_file=Path("tests/data/test.kernels.tm.json"), + generic_kernel_dir=Path("data/generic"), + instrument_name="TEST_INSTRUMENT", + time_field="corrected_timestamp", + ) + + +def _make_config(**overrides) -> CorrectionConfig: + defaults = dict( + n_iterations=3, + parameters=[ + ParameterConfig( + ptype=ParameterType.CONSTANT_KERNEL, + data={"current_value": [0.0, 0.0, 0.0], "bounds": [-300.0, 300.0]}, + ) + ], + geo=_make_geo(), + performance_threshold_m=_THRESHOLD_M, + performance_spec_percent=_SPEC_PCT, + ) + defaults.update(overrides) + return CorrectionConfig(**defaults) + + +def _make_netcdf_data(n_params: int = 3, n_pairs: int = 2, threshold_m: float = 250.0) -> dict: + """Synthetic netcdf_data mimicking the structure built by loop().""" + rms_grid = np.array([[100.0 + i * 50 + j * 10 for j in range(n_pairs)] for i in range(n_params)]) + mean_rms = rms_grid.mean(axis=1) + best_pair_rms = rms_grid.min(axis=1) + worst_pair_rms = rms_grid.max(axis=1) + return { + "parameter_set_id": np.arange(n_params), + "gcp_pair_id": np.arange(n_pairs), + "param_test_kernel_roll": np.linspace(-1.0, 1.0, n_params), + "param_test_kernel_pitch": np.linspace(-0.5, 0.5, n_params), + "rms_error_m": rms_grid, + "mean_error_m": rms_grid, + "max_error_m": rms_grid * 1.2, + "std_error_m": np.zeros((n_params, n_pairs)), + "n_measurements": np.ones((n_params, n_pairs), dtype=int) * 5, + "mean_rms_all_pairs": mean_rms, + "best_pair_rms": best_pair_rms, + "worst_pair_rms": worst_pair_rms, + f"percent_under_{int(threshold_m)}m": 
np.zeros(n_params), + "im_lat_error_km": np.zeros((n_params, n_pairs)), + "im_lon_error_km": np.zeros((n_params, n_pairs)), + "im_ccv": np.ones((n_params, n_pairs)) * 0.9, + "im_grid_step_m": np.ones((n_params, n_pairs)) * 30.0, + } + + +def _make_verification_result( + passed: bool = True, + pct: float = 60.0, + stats: dict | None = None, +) -> VerificationResult: + ds = xr.Dataset( + {"nadir_equiv_total_error_m": (["measurement"], np.array([100.0, 200.0, 300.0]))}, + coords={"measurement": np.arange(3)}, + ) + if stats: + ds.attrs.update(stats) + return VerificationResult( + passed=passed, + per_gcp_errors=[], + aggregate_stats=ds, + requirements=RequirementsConfig( + performance_threshold_m=_THRESHOLD_M, + performance_spec_percent=_SPEC_PCT, + ), + summary_table="", + percent_within_threshold=pct, + warnings=[] if passed else ["FAILED"], + timestamp=datetime.now(tz=timezone.utc), + ) + + +# =========================================================================== +# _fmt_rms +# =========================================================================== + + +def test_fmt_rms_finite(): + assert _fmt_rms(123.456) == "123.5m" + + +def test_fmt_rms_nan(): + assert _fmt_rms(float("nan")) == "N/A" + + +def test_fmt_rms_inf(): + assert _fmt_rms(float("inf")) == "N/A" + + +# =========================================================================== +# ParameterSetResult +# =========================================================================== + + +class TestParameterSetResult: + def test_construction(self): + ps = ParameterSetResult( + index=2, + parameter_values={"roll": 1.5, "pitch": -0.3}, + mean_rms_m=150.0, + best_pair_rms_m=100.0, + worst_pair_rms_m=200.0, + ) + assert ps.index == 2 + assert ps.parameter_values == {"roll": 1.5, "pitch": -0.3} + assert ps.mean_rms_m == pytest.approx(150.0) + assert ps.best_pair_rms_m == pytest.approx(100.0) + assert ps.worst_pair_rms_m == pytest.approx(200.0) + + def test_json_round_trip(self): + ps = ParameterSetResult( + 
index=0, + parameter_values={"roll": 0.5}, + mean_rms_m=200.0, + best_pair_rms_m=180.0, + worst_pair_rms_m=220.0, + ) + restored = ParameterSetResult.model_validate_json(ps.model_dump_json()) + assert restored.index == 0 + assert restored.mean_rms_m == pytest.approx(200.0) + + +# =========================================================================== +# CorrectionResult +# =========================================================================== + + +class TestCorrectionResult: + def _make(self, **overrides) -> CorrectionResult: + defaults = dict( + best_parameter_set={"roll": 0.1}, + best_rms_m=120.0, + best_index=0, + worst_rms_m=300.0, + mean_rms_m=200.0, + n_parameter_sets=3, + n_gcp_pairs=2, + all_parameter_sets=[ + ParameterSetResult( + index=0, + parameter_values={"roll": 0.1}, + mean_rms_m=120.0, + best_pair_rms_m=100.0, + worst_pair_rms_m=140.0, + ) + ], + met_threshold=True, + recommendation="Update kernel files.", + summary_table="(table)", + ) + defaults.update(overrides) + return CorrectionResult(**defaults) + + def test_construction_minimal(self): + result = self._make() + assert result.best_rms_m == pytest.approx(120.0) + assert result.met_threshold is True + assert result.n_parameter_sets == 3 + + def test_optional_fields_defaults(self): + result = self._make() + assert result.netcdf_path is None + assert result.config_snapshot == {} + assert result.elapsed_time_s == pytest.approx(0.0) + assert isinstance(result.timestamp, datetime) + + def test_raw_data_fields_accessible(self): + raw_results = [{"iteration": 0, "rms_error_m": 100.0}] + raw_netcdf = {"mean_rms_all_pairs": np.array([100.0, 150.0])} + result = self._make(results=raw_results, netcdf_data=raw_netcdf) + # Pydantic copies list/dict on model init; verify by value not identity + assert result.results == raw_results + assert list(result.netcdf_data.keys()) == list(raw_netcdf.keys()) + + def test_json_excludes_raw_data(self): + """results and netcdf_data must not appear in JSON 
output.""" + result = self._make( + results=[{"iteration": 0}], + netcdf_data={"mean_rms_all_pairs": np.array([100.0])}, + ) + d = result.model_dump() + assert "results" not in d + assert "netcdf_data" not in d + + def test_json_round_trip_scalar_fields(self): + result = self._make( + netcdf_path=Path("/tmp/out.nc"), + config_snapshot={"seed": 42, "n_iterations": 10}, + elapsed_time_s=3.14, + ) + json_str = result.model_dump_json() + data = json.loads(json_str) + assert data["best_rms_m"] == pytest.approx(120.0) + assert data["met_threshold"] is True + assert data["elapsed_time_s"] == pytest.approx(3.14) + assert data["netcdf_path"] == "/tmp/out.nc" + + def test_timestamp_is_utc_datetime(self): + result = self._make() + assert isinstance(result.timestamp, datetime) + # default_factory uses UTC + assert result.timestamp.tzinfo is not None + + +# =========================================================================== +# _format_correction_summary_table +# =========================================================================== + + +def _make_ps(idx: int, rms: float) -> ParameterSetResult: + return ParameterSetResult( + index=idx, + parameter_values={"roll": float(idx)}, + mean_rms_m=rms, + best_pair_rms_m=rms * 0.9, + worst_pair_rms_m=rms * 1.1, + ) + + +class TestFormatCorrectionSummaryTable: + def test_basic_structure(self): + sets = [_make_ps(0, 120.0), _make_ps(1, 200.0)] + table = _format_correction_summary_table(sets, total_sets=5, n_gcp_pairs=2, met_threshold=True) + lines = table.splitlines() + # First line starts with ┌ and last with └ + assert lines[0].startswith("┌") + assert lines[-1].startswith("└") + # All lines have same length + lengths = {len(line) for line in lines} + assert len(lengths) == 1, f"Inconsistent line widths: {sorted(lengths)}" + + def test_met_threshold_shows_in_footer(self): + sets = [_make_ps(0, 100.0)] + table = _format_correction_summary_table(sets, 1, 1, met_threshold=True) + assert "MET REQUIREMENTS" in table + assert "✓" 
in table + + def test_not_met_shows_in_footer(self): + sets = [_make_ps(0, 500.0)] + table = _format_correction_summary_table(sets, 1, 1, met_threshold=False) + assert "DID NOT MEET" in table + assert "✗" in table + + def test_empty_sets_no_crash(self): + table = _format_correction_summary_table([], total_sets=0, n_gcp_pairs=0, met_threshold=False) + assert "No results available" in table + lines = table.splitlines() + lengths = {len(line) for line in lines} + assert len(lengths) == 1 + + def test_best_set_has_star_marker(self): + sets = [_make_ps(0, 80.0), _make_ps(1, 150.0)] + table = _format_correction_summary_table(sets, 2, 1, met_threshold=True) + assert "★" in table + + def test_nan_values_display_as_na(self): + sets = [_make_ps(0, float("nan"))] + table = _format_correction_summary_table(sets, 1, 1, met_threshold=False) + assert "N/A" in table + + def test_long_title_widens_table_consistently(self): + """Wide title (many GCP pairs) must not break line-width consistency.""" + sets = [_make_ps(i, 100.0 + i * 10) for i in range(5)] + table = _format_correction_summary_table(sets, total_sets=10000, n_gcp_pairs=99, met_threshold=True) + lines = table.splitlines() + lengths = {len(line) for line in lines} + assert len(lengths) == 1 + + def test_title_appears_in_output(self): + sets = [_make_ps(0, 200.0)] + table = _format_correction_summary_table(sets, total_sets=7, n_gcp_pairs=3, met_threshold=False) + assert "7 sets" in table + assert "3 pairs" in table + + +# =========================================================================== +# build_correction_result +# =========================================================================== + + +class TestBuildCorrectionResult: + def test_basic_construction(self): + config = _make_config() + nc = _make_netcdf_data(n_params=3, n_pairs=2) + result = build_correction_result(config, [], nc, Path("/tmp/out.nc"), elapsed_time_s=1.5) + + assert isinstance(result, CorrectionResult) + assert result.n_parameter_sets == 3 + 
assert result.n_gcp_pairs == 2 + assert result.elapsed_time_s == pytest.approx(1.5) + assert result.netcdf_path == Path("/tmp/out.nc") + + def test_best_index_is_lowest_mean_rms(self): + config = _make_config() + nc = _make_netcdf_data(n_params=3, n_pairs=2) + result = build_correction_result(config, [], nc, None, 0.0) + # Mean RMS values are 105, 155, 205 (from the helper); best is index 0 + assert result.best_index == 0 + assert result.best_rms_m < result.worst_rms_m + + def test_all_sets_sorted_ascending(self): + config = _make_config() + nc = _make_netcdf_data(n_params=5, n_pairs=2) + result = build_correction_result(config, [], nc, None, 0.0) + rms_values = [ps.mean_rms_m for ps in result.all_parameter_sets] + assert rms_values == sorted(rms_values) + + def test_met_threshold_true_when_all_pairs_below(self): + """If all pair RMS < threshold, pct_below = 100% ≥ spec → met.""" + # Use very large threshold so all errors are "below" + config = _make_config( + performance_threshold_m=10_000.0, + performance_spec_percent=39.0, + ) + nc = _make_netcdf_data(n_params=2, n_pairs=2, threshold_m=10_000.0) + result = build_correction_result(config, [], nc, None, 0.0) + assert result.met_threshold is True + assert "meets performance requirements" in result.recommendation + + def test_met_threshold_false_when_no_pairs_below(self): + """If no pair RMS < threshold, pct_below = 0% < spec → not met.""" + config = _make_config( + performance_threshold_m=0.001, # impossibly tight + performance_spec_percent=39.0, + ) + nc = _make_netcdf_data(n_params=2, n_pairs=2, threshold_m=0) + result = build_correction_result(config, [], nc, None, 0.0) + assert result.met_threshold is False + assert "No parameter set met performance requirements" in result.recommendation + + def test_all_nan_rms_does_not_crash(self): + """Degenerate case: all RMS values are NaN — build should not raise.""" + config = _make_config() + nc = _make_netcdf_data(n_params=2, n_pairs=2) + nc["mean_rms_all_pairs"] = 
np.array([float("nan"), float("nan")]) + nc["rms_error_m"][:] = float("nan") + result = build_correction_result(config, [], nc, None, 0.0) + assert math.isnan(result.best_rms_m) + assert math.isnan(result.worst_rms_m) + # Table should still render without crash + assert "┌" in result.summary_table + + def test_raw_data_preserved(self): + config = _make_config() + nc = _make_netcdf_data() + raw = [{"iteration": 0, "rms_error_m": 100.0}] + result = build_correction_result(config, raw, nc, None, 0.0) + # Pydantic copies list/dict on model init; verify by value not identity + assert result.results == raw + assert set(result.netcdf_data.keys()) == set(nc.keys()) + + def test_config_snapshot_contains_key_fields(self): + config = _make_config() + nc = _make_netcdf_data() + result = build_correction_result(config, [], nc, None, 0.0) + snap = result.config_snapshot + assert "n_iterations" in snap + assert "performance_threshold_m" in snap + assert "performance_spec_percent" in snap + assert "search_strategy" in snap + + def test_summary_table_included(self): + config = _make_config() + nc = _make_netcdf_data() + result = build_correction_result(config, [], nc, None, 0.0) + assert "┌" in result.summary_table + assert "Correction Sweep Summary" in result.summary_table + + +# =========================================================================== +# compare_results +# =========================================================================== + + +class TestCompareResults: + def test_basic_output_format(self): + before = _make_verification_result(passed=False, pct=25.0) + after = _make_verification_result(passed=True, pct=65.0) + output = compare_results(before, after) + + assert "Verification Comparison" in output + assert "Before" in output + assert "After" in output + assert "Overall" in output + assert "PASS" in output + assert "FAIL" in output + + def test_percent_within_threshold_shown(self): + before = _make_verification_result(pct=30.0) + after = 
_make_verification_result(pct=70.0) + output = compare_results(before, after) + assert "percent_within_threshold" in output + # Both values should appear + assert "30.0%" in output + assert "70.0%" in output + + def test_aggregate_stats_attrs_used(self): + before = _make_verification_result(stats={"mean_error_m": 350.0, "rms_error_m": 400.0}) + after = _make_verification_result(stats={"mean_error_m": 120.0, "rms_error_m": 150.0}) + output = compare_results(before, after) + assert "mean_error_m" in output + assert "350.0" in output + assert "120.0" in output + + def test_missing_stat_shows_na(self): + """Stats not in aggregate_stats.attrs should display as N/A.""" + before = _make_verification_result(stats={}) + after = _make_verification_result(stats={}) + output = compare_results(before, after) + # All stat rows should show N/A since attrs is empty + assert "N/A" in output + + def test_returns_string(self): + before = _make_verification_result() + after = _make_verification_result() + result = compare_results(before, after) + assert isinstance(result, str) + assert len(result) > 0 + + +# =========================================================================== +# VerificationResult provenance fields +# =========================================================================== + + +class TestVerificationResultProvenanceFields: + def test_default_files_processed_is_empty_list(self): + vr = _make_verification_result() + assert vr.files_processed == [] + + def test_default_elapsed_time_s_is_none(self): + vr = _make_verification_result() + assert vr.elapsed_time_s is None + + def test_default_config_snapshot_is_none(self): + vr = _make_verification_result() + assert vr.config_snapshot is None + + def test_provenance_fields_set_explicitly(self): + vr = _make_verification_result() + vr_with_prov = vr.model_copy( + update={ + "files_processed": ["sci_0+gcp_0", "sci_1+gcp_1"], + "elapsed_time_s": 2.5, + "config_snapshot": {"instrument_name": "TEST"}, + } + ) + assert 
vr_with_prov.files_processed == ["sci_0+gcp_0", "sci_1+gcp_1"] + assert vr_with_prov.elapsed_time_s == pytest.approx(2.5) + assert vr_with_prov.config_snapshot == {"instrument_name": "TEST"} + + def test_existing_fields_unaffected_by_new_fields(self): + """Old construction sites that don't supply provenance fields still work.""" + vr = VerificationResult( + passed=True, + per_gcp_errors=[], + aggregate_stats=xr.Dataset( + {"nadir_equiv_total_error_m": (["measurement"], np.array([100.0]))}, + coords={"measurement": np.arange(1)}, + ), + requirements=RequirementsConfig( + performance_threshold_m=250.0, + performance_spec_percent=39.0, + ), + summary_table="", + percent_within_threshold=100.0, + warnings=[], + timestamp=datetime.now(tz=timezone.utc), + # provenance fields intentionally omitted + ) + assert vr.passed is True + assert vr.files_processed == [] + assert vr.elapsed_time_s is None + assert vr.config_snapshot is None + + def test_json_round_trip_with_provenance_fields(self): + vr = _make_verification_result() + vr = vr.model_copy( + update={ + "files_processed": ["sci+gcp"], + "elapsed_time_s": 1.23, + "config_snapshot": {"instrument_name": "CLARREO"}, + } + ) + json_str = vr.model_dump_json(exclude={"aggregate_stats"}) + data = json.loads(json_str) + assert data["files_processed"] == ["sci+gcp"] + assert data["elapsed_time_s"] == pytest.approx(1.23) + assert data["config_snapshot"]["instrument_name"] == "CLARREO" + + +# =========================================================================== +# Backward compatibility: loop() still returns (results, netcdf_data) +# =========================================================================== + + +def test_loop_return_annotation_is_not_correction_result(): + """loop() must NOT be annotated to return CorrectionResult. + + This is a static check — loop() is the internal workhorse and its + 2-tuple return type must remain stable. 
+ """ + import inspect + + from curryer.correction.pipeline import loop + + sig = inspect.signature(loop) + ann = sig.return_annotation + # No annotation (empty) is fine; CorrectionResult annotation is wrong + assert ann is inspect.Parameter.empty or "CorrectionResult" not in str(ann) + + +def test_run_correction_return_annotation_is_correction_result(): + """run_correction() must be annotated to return CorrectionResult.""" + import inspect + + from curryer.correction.pipeline import run_correction + + sig = inspect.signature(run_correction) + ann = sig.return_annotation + assert "CorrectionResult" in str(ann) diff --git a/tests/test_correction/test_results_io.py b/tests/test_correction/test_results_io.py new file mode 100644 index 00000000..d26fd51c --- /dev/null +++ b/tests/test_correction/test_results_io.py @@ -0,0 +1,83 @@ +"""Tests for ``curryer.correction.results_io``. + +Covers: +- ``_build_netcdf_structure`` +- ``_save_netcdf_checkpoint`` / ``_load_checkpoint`` / ``_cleanup_checkpoint`` +""" + +from __future__ import annotations + +import logging + +import pytest +from clarreo_config import create_clarreo_correction_config + +from curryer.correction import correction + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module") +def clarreo_cfg(root_dir): + return create_clarreo_correction_config( + root_dir / "tests" / "data" / "clarreo" / "gcs", + root_dir / "data" / "generic", + ) + + +# ── _build_netcdf_structure ─────────────────────────────────────────────────── + + +def test_build_netcdf_structure_shapes(clarreo_cfg): + """_build_netcdf_structure returns arrays with the requested dimensions.""" + nc = correction._build_netcdf_structure(clarreo_cfg, n_param_sets=3, n_gcp_pairs=2) + assert isinstance(nc, dict) + assert "rms_error_m" in nc + assert "parameter_set_id" in nc + assert nc["rms_error_m"].shape == (3, 2) + assert len(nc["parameter_set_id"]) == 3 + + +def test_build_netcdf_structure_zero_initialised(clarreo_cfg): + """All 
numeric arrays are initialised to zero (or NaN for float arrays).""" + nc = correction._build_netcdf_structure(clarreo_cfg, n_param_sets=2, n_gcp_pairs=1) + # rms_error_m starts as zeros or NaN – just check it is numeric + assert nc["rms_error_m"].dtype.kind in ("f", "i", "u") + + +# ── checkpoint round-trip ───────────────────────────────────────────────────── + + +def test_checkpoint_save_load_cleanup(clarreo_cfg, tmp_path): + """Save → load → cleanup round-trip preserves data and removes the file.""" + cfg = clarreo_cfg.model_copy(deep=True) + cfg.n_iterations = 2 + cfg.output_filename = "ckpt_test.nc" + output_file = tmp_path / cfg.output_filename + + nc = correction._build_netcdf_structure(cfg, n_param_sets=2, n_gcp_pairs=2) + nc["rms_error_m"][0, 0] = 100.0 + nc["rms_error_m"][1, 0] = 150.0 + + correction._save_netcdf_checkpoint(nc, output_file, cfg, pair_idx_completed=0) + ckpt = output_file.parent / f"{output_file.stem}_checkpoint.nc" + assert ckpt.exists() + + loaded, completed = correction._load_checkpoint(output_file, cfg) + assert loaded is not None + assert completed == 1 # pair index 0 completed → 1 pair done + assert loaded["rms_error_m"][0, 0] == pytest.approx(100.0) + assert loaded["rms_error_m"][1, 0] == pytest.approx(150.0) + + correction._cleanup_checkpoint(output_file) + assert not ckpt.exists() + + +def test_checkpoint_load_missing_returns_none(clarreo_cfg, tmp_path): + """_load_checkpoint returns (None, 0) when no checkpoint file exists.""" + cfg = clarreo_cfg.model_copy(deep=True) + cfg.output_filename = "no_ckpt.nc" + output_file = tmp_path / cfg.output_filename + loaded, completed = correction._load_checkpoint(output_file, cfg) + assert loaded is None + assert completed == 0 diff --git a/tests/test_correction/test_verification.py b/tests/test_correction/test_verification.py new file mode 100644 index 00000000..10a41f9f --- /dev/null +++ b/tests/test_correction/test_verification.py @@ -0,0 +1,711 @@ +"""Unit tests for 
curryer.correction.verification. + +Covers +------ +- :class:`RequirementsConfig` – Pydantic model construction and validation +- :class:`GCPError` – typed per-measurement detail +- :class:`VerificationResult` – JSON round-trip serialisation +- :func:`_check_threshold` – 0 %, 39 %, 100 % edge cases +- :func:`_generate_warnings` – pass / fail messaging +- :func:`_format_summary_table` – structure and content checks +- :func:`_build_per_gcp_errors` – correct passed flag and coordinate fallback +- :func:`verify` – end-to-end with pre-computed ``image_matching_results`` +- :func:`verify` – error paths (empty list, missing inputs, missing func) + +All ``CorrectionConfig`` fixtures avoid deleted fields (``telemetry_loader``, +``science_loader``, ``gcp_loader``, ``gcp_pairing_func``). +""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from pathlib import Path + +import numpy as np +import pytest +import xarray as xr +from pydantic import ValidationError + +from curryer.correction.config import ( + CorrectionConfig, + GeolocationConfig, + ParameterConfig, + ParameterType, +) +from curryer.correction.verification import ( + GCPError, + RequirementsConfig, + VerificationResult, + _build_per_gcp_errors, + _check_threshold, + _format_summary_table, + _generate_warnings, + verify, +) + +# =========================================================================== +# Helpers / factories +# =========================================================================== + +_THRESHOLD_M = 250.0 +_SPEC_PCT = 39.0 + + +def _make_geo() -> GeolocationConfig: + """Minimal GeolocationConfig; files need not exist for verification tests.""" + return GeolocationConfig( + meta_kernel_file=Path("tests/data/test.kernels.tm.json"), + generic_kernel_dir=Path("data/generic"), + instrument_name="TEST_INSTRUMENT", + time_field="corrected_timestamp", + ) + + +def _make_config(**overrides) -> CorrectionConfig: + """Return a minimal CorrectionConfig 
suitable for verification tests. + + Provides CLARREO-style variable name mappings and does **not** set any + deleted fields (``telemetry_loader``, ``science_loader``, ``gcp_loader``, + ``gcp_pairing_func``). + """ + defaults = dict( + n_iterations=1, + parameters=[ + ParameterConfig( + ptype=ParameterType.CONSTANT_KERNEL, + data={"current_value": [0.0, 0.0, 0.0], "bounds": [-300.0, 300.0]}, + ) + ], + geo=_make_geo(), + performance_threshold_m=_THRESHOLD_M, + performance_spec_percent=_SPEC_PCT, + # CLARREO-style names so the 13-case dataset validates cleanly + spacecraft_position_name="riss_ctrs", + boresight_name="bhat_hs", + transformation_matrix_name="t_hs2ctrs", + ) + defaults.update(overrides) + return CorrectionConfig(**defaults) + + +def _make_aggregate_stats_dataset(nadir_errors_m: list[float]) -> xr.Dataset: + """Minimal dataset with ``nadir_equiv_total_error_m`` for threshold tests.""" + n = len(nadir_errors_m) + return xr.Dataset( + { + "nadir_equiv_total_error_m": (["measurement"], np.array(nadir_errors_m, dtype=float)), + "lat_error_deg": (["measurement"], np.zeros(n)), + "lon_error_deg": (["measurement"], np.zeros(n)), + }, + coords={"measurement": np.arange(n)}, + ) + + +def _make_full_image_matching_dataset(n: int = 5, seed: int = 0) -> xr.Dataset: + """Create a self-contained image-matching dataset usable by verify(). + + Uses the validated 13-case geometry sampled with replacement so that + :class:`~curryer.correction.error_stats.ErrorStatsProcessor` can compute + nadir-equivalent errors without triggering geometry warnings. 
+ """ + from tests.test_correction.test_error_stats import ( + create_test_dataset_13_cases, + ) + + rng = np.random.default_rng(seed) + base = create_test_dataset_13_cases() + indices = rng.integers(0, 13, n) + sampled = base.isel(measurement=indices).assign_coords(measurement=np.arange(n)) + return sampled + + +# =========================================================================== +# RequirementsConfig +# =========================================================================== + + +class TestRequirementsConfig: + def test_construction(self): + req = RequirementsConfig(performance_threshold_m=250.0, performance_spec_percent=39.0) + assert req.performance_threshold_m == 250.0 + assert req.performance_spec_percent == 39.0 + + def test_json_round_trip(self): + req = RequirementsConfig(performance_threshold_m=500.0, performance_spec_percent=80.0) + restored = RequirementsConfig.model_validate_json(req.model_dump_json()) + assert restored.performance_threshold_m == 500.0 + assert restored.performance_spec_percent == 80.0 + + def test_missing_fields_raise(self): + with pytest.raises(ValidationError, match="performance_threshold_m"): + RequirementsConfig() + + +# =========================================================================== +# GCPError +# =========================================================================== + + +class TestGCPError: + def test_construction_full(self): + err = GCPError( + gcp_index=0, + science_key="sci_0", + gcp_key="gcp_0", + lat_error_deg=0.001, + lon_error_deg=-0.002, + nadir_equiv_error_m=120.5, + correlation=0.87, + passed=True, + ) + assert err.passed is True + assert err.nadir_equiv_error_m == pytest.approx(120.5) + + def test_optional_fields_default_to_none(self): + err = GCPError( + gcp_index=1, + science_key="s", + gcp_key="g", + lat_error_deg=0.0, + lon_error_deg=0.0, + passed=False, + ) + assert err.nadir_equiv_error_m is None + assert err.correlation is None + + def test_json_round_trip(self): + err = GCPError( + 
gcp_index=2, + science_key="sci_2", + gcp_key="gcp_2", + lat_error_deg=0.005, + lon_error_deg=0.003, + nadir_equiv_error_m=300.0, + correlation=0.65, + passed=False, + ) + raw = json.loads(err.model_dump_json()) + assert raw["gcp_index"] == 2 + assert raw["passed"] is False + + +# =========================================================================== +# VerificationResult +# =========================================================================== + + +class TestVerificationResult: + def _make_result(self, passed: bool = True) -> VerificationResult: + req = RequirementsConfig(performance_threshold_m=250.0, performance_spec_percent=39.0) + errors = [ + GCPError( + gcp_index=0, + science_key="s0", + gcp_key="g0", + lat_error_deg=0.001, + lon_error_deg=0.001, + nadir_equiv_error_m=100.0, + passed=True, + ) + ] + stats = _make_aggregate_stats_dataset([100.0]) + return VerificationResult( + passed=passed, + per_gcp_errors=errors, + aggregate_stats=stats, + requirements=req, + summary_table="table", + percent_within_threshold=100.0, + warnings=[], + timestamp=datetime.now(tz=timezone.utc), + ) + + def test_construction(self): + result = self._make_result() + assert result.passed is True + assert len(result.per_gcp_errors) == 1 + assert isinstance(result.aggregate_stats, xr.Dataset) + + def test_failed_result_has_warnings(self): + result = self._make_result(passed=False) + result = VerificationResult( + passed=False, + per_gcp_errors=result.per_gcp_errors, + aggregate_stats=result.aggregate_stats, + requirements=result.requirements, + summary_table="t", + percent_within_threshold=10.0, + warnings=["⚠️ VERIFICATION FAILED: ..."], + timestamp=result.timestamp, + ) + assert len(result.warnings) == 1 + assert "FAILED" in result.warnings[0] + + def test_model_dump_json_excludes_dataset(self): + """xr.Dataset is arbitrary type — model_dump_json should not crash.""" + result = self._make_result() + # Pydantic with arbitrary_types_allowed may not be JSON-serialisable for + 
# xr.Dataset, but other fields should dump cleanly. + dumped = result.model_dump(exclude={"aggregate_stats"}) + assert "passed" in dumped + assert "percent_within_threshold" in dumped + + +# =========================================================================== +# _check_threshold – edge cases 0 %, 39 %, 100 % +# =========================================================================== + + +class TestCheckThreshold: + """Validate _check_threshold against boundary conditions.""" + + def _req(self, spec_pct: float = _SPEC_PCT) -> RequirementsConfig: + return RequirementsConfig( + performance_threshold_m=_THRESHOLD_M, + performance_spec_percent=spec_pct, + ) + + def test_zero_percent_within_threshold_fails(self): + """All errors above threshold → 0 % pass → FAILED.""" + stats = _make_aggregate_stats_dataset([300.0, 400.0, 500.0]) + passed, pct = _check_threshold(stats, self._req()) + assert passed is False + assert pct == pytest.approx(0.0) + + def test_exactly_at_spec_percent_passes(self): + """When exactly spec_percent of measurements pass the threshold.""" + # 39 out of 100 below 250m → 39 % → should pass (>= 39 %) + errors = [100.0] * 39 + [300.0] * 61 + stats = _make_aggregate_stats_dataset(errors) + passed, pct = _check_threshold(stats, self._req(spec_pct=39.0)) + assert passed is True + assert pct == pytest.approx(39.0) + + def test_one_below_spec_percent_fails(self): + """One fewer passing measurement → should fail.""" + errors = [100.0] * 38 + [300.0] * 62 + stats = _make_aggregate_stats_dataset(errors) + passed, pct = _check_threshold(stats, self._req(spec_pct=39.0)) + assert passed is False + assert pct == pytest.approx(38.0) + + def test_hundred_percent_within_threshold_passes(self): + """All errors well below threshold → 100 % pass → PASSED.""" + stats = _make_aggregate_stats_dataset([50.0, 100.0, 150.0, 200.0]) + passed, pct = _check_threshold(stats, self._req()) + assert passed is True + assert pct == pytest.approx(100.0) + + def 
# ===========================================================================
# _generate_warnings
# ===========================================================================


class TestGenerateWarnings:
    def _req(self) -> RequirementsConfig:
        return RequirementsConfig(performance_threshold_m=250.0, performance_spec_percent=39.0)

    def test_no_warnings_when_passed(self):
        msgs = _generate_warnings(passed=True, percent_below=60.0, requirements=self._req())
        assert msgs == []

    def test_warning_emitted_when_failed(self):
        msgs = _generate_warnings(passed=False, percent_below=20.0, requirements=self._req())
        assert len(msgs) == 1
        # The single warning names the failure and echoes all the numbers involved.
        assert "VERIFICATION FAILED" in msgs[0]
        assert "20.0%" in msgs[0]
        assert "250.0m" in msgs[0]
        assert "39.0%" in msgs[0]

    def test_warning_contains_recommendation(self):
        msgs = _generate_warnings(passed=False, percent_below=5.0, requirements=self._req())
        assert "correction module" in msgs[0].lower() or "Recommend" in msgs[0]


# ===========================================================================
# _format_summary_table
# ===========================================================================


class TestFormatSummaryTable:
    def _req(self) -> RequirementsConfig:
        return RequirementsConfig(performance_threshold_m=250.0, performance_spec_percent=39.0)

    def _errors(self) -> list[GCPError]:
        # One passing and one failing measurement so both glyph paths render.
        return [
            GCPError(
                gcp_index=0,
                science_key="sci_0",
                gcp_key="gcp_0",
                lat_error_deg=0.00123,
                lon_error_deg=-0.00045,
                nadir_equiv_error_m=145.2,
                passed=True,
            ),
            GCPError(
                gcp_index=1,
                science_key="sci_1",
                gcp_key="gcp_1",
                lat_error_deg=0.00567,
                lon_error_deg=0.00234,
                nadir_equiv_error_m=312.8,
                passed=False,
            ),
        ]

    def test_returns_string(self):
        text = _format_summary_table(self._errors(), self._req(), 50.0, False)
        assert isinstance(text, str)
        assert len(text) > 0

    def test_contains_header_and_footer(self):
        text = _format_summary_table(self._errors(), self._req(), 50.0, False)
        assert "Verification Summary" in text
        assert "Result:" in text

    def test_pass_verdict_appears(self):
        text = _format_summary_table(self._errors(), self._req(), 60.0, True)
        assert "PASSED" in text

    def test_fail_verdict_appears(self):
        text = _format_summary_table(self._errors(), self._req(), 20.0, False)
        assert "FAILED" in text

    def test_threshold_and_spec_in_footer(self):
        text = _format_summary_table(self._errors(), self._req(), 50.0, False)
        assert "250.0m" in text
        assert "39.0%" in text

    def test_checkmark_and_cross_present(self):
        text = _format_summary_table(self._errors(), self._req(), 50.0, False)
        assert "✓" in text
        assert "✗" in text

    def test_empty_errors_list(self):
        """Should not raise with zero measurements."""
        text = _format_summary_table([], self._req(), 0.0, False)
        assert "Result:" in text

    def test_nadir_none_shows_na(self):
        missing_nadir = [
            GCPError(
                gcp_index=0,
                science_key="s",
                gcp_key="g",
                lat_error_deg=0.0,
                lon_error_deg=0.0,
                nadir_equiv_error_m=None,
                passed=False,
            )
        ]
        text = _format_summary_table(missing_nadir, self._req(), 0.0, False)
        assert "N/A" in text
# ===========================================================================
# _build_per_gcp_errors
# ===========================================================================


class TestBuildPerGcpErrors:
    def _req(self) -> RequirementsConfig:
        return RequirementsConfig(performance_threshold_m=250.0, performance_spec_percent=39.0)

    def _stats_with_errors(self, nadir_errors: list[float]) -> xr.Dataset:
        count = len(nadir_errors)
        return xr.Dataset(
            {
                "nadir_equiv_total_error_m": (["measurement"], np.array(nadir_errors)),
                "lat_error_deg": (["measurement"], np.linspace(0.001, 0.005, count)),
                "lon_error_deg": (["measurement"], np.linspace(-0.001, 0.001, count)),
            },
            coords={"measurement": np.arange(count)},
        )

    def test_length_matches_measurements(self):
        details = _build_per_gcp_errors(self._stats_with_errors([100.0, 300.0, 200.0]), [], self._req())
        assert len(details) == 3

    def test_passed_flag_set_correctly(self):
        details = _build_per_gcp_errors(self._stats_with_errors([100.0, 300.0]), [], self._req())
        assert details[0].passed is True  # 100 < 250
        assert details[1].passed is False  # 300 >= 250

    def test_source_mapping_applied(self):
        mapping = [("my_science", "my_gcp")]
        details = _build_per_gcp_errors(self._stats_with_errors([100.0]), mapping, self._req())
        assert details[0].science_key == "my_science"
        assert details[0].gcp_key == "my_gcp"

    def test_fallback_keys_when_mapping_too_short(self):
        details = _build_per_gcp_errors(self._stats_with_errors([100.0, 200.0]), [], self._req())
        assert details[0].science_key == "sci_0"
        assert details[1].gcp_key == "gcp_1"

    def test_correlation_extracted_when_present(self):
        ds = xr.Dataset(
            {
                "nadir_equiv_total_error_m": (["measurement"], [100.0, 200.0]),
                "lat_error_deg": (["measurement"], [0.001, 0.002]),
                "lon_error_deg": (["measurement"], [0.001, 0.002]),
                "correlation": (["measurement"], [0.85, 0.92]),
            },
            coords={"measurement": np.arange(2)},
        )
        details = _build_per_gcp_errors(ds, [], self._req())
        assert details[0].correlation == pytest.approx(0.85)
        assert details[1].correlation == pytest.approx(0.92)

    def test_no_correlation_variable_gives_none(self):
        details = _build_per_gcp_errors(self._stats_with_errors([100.0]), [], self._req())
        assert details[0].correlation is None

    def test_empty_dataset_returns_empty_list(self):
        details = _build_per_gcp_errors(self._stats_with_errors([]), [], self._req())
        assert details == []


# ===========================================================================
# verify() – integration tests
# ===========================================================================


class TestVerify:
    """End-to-end tests for :func:`verify` using synthetic image-matching data."""

    @pytest.fixture
    def config(self) -> CorrectionConfig:
        return _make_config()

    @pytest.fixture
    def image_matching_dataset(self) -> xr.Dataset:
        """Single-pair dataset built from the validated 13-case geometry."""
        return _make_full_image_matching_dataset(n=13, seed=42)

    @pytest.fixture
    def multi_pair_results(self) -> list[xr.Dataset]:
        """Two GCP pairs with different sci/gcp labels."""
        first = _make_full_image_matching_dataset(n=7, seed=0)
        first.attrs["sci_key"] = "scene_A"
        first.attrs["gcp_key"] = "gcp_site_1"
        second = _make_full_image_matching_dataset(n=6, seed=1)
        second.attrs["sci_key"] = "scene_B"
        second.attrs["gcp_key"] = "gcp_site_2"
        return [first, second]

    # -------------------------------------------------------------------
    # Happy path – single GCP pair
    # -------------------------------------------------------------------

    def test_returns_verification_result(self, config, image_matching_dataset, tmp_path):
        outcome = verify(config, image_matching_results=[image_matching_dataset], work_dir=tmp_path)
        assert isinstance(outcome, VerificationResult)

    def test_result_has_all_fields(self, config, image_matching_dataset, tmp_path):
        outcome = verify(config, image_matching_results=[image_matching_dataset], work_dir=tmp_path)
        assert isinstance(outcome.passed, bool)
        assert isinstance(outcome.per_gcp_errors, list)
        assert isinstance(outcome.aggregate_stats, xr.Dataset)
        assert isinstance(outcome.summary_table, str)
        assert isinstance(outcome.percent_within_threshold, float)
        assert isinstance(outcome.warnings, list)
        assert isinstance(outcome.timestamp, datetime)

    def test_per_gcp_errors_count_matches_measurements(self, config, image_matching_dataset, tmp_path):
        expected = image_matching_dataset.sizes["measurement"]
        outcome = verify(config, image_matching_results=[image_matching_dataset], work_dir=tmp_path)
        assert len(outcome.per_gcp_errors) == expected

    def test_all_per_gcp_have_nadir_error(self, config, image_matching_dataset, tmp_path):
        outcome = verify(config, image_matching_results=[image_matching_dataset], work_dir=tmp_path)
        for detail in outcome.per_gcp_errors:
            assert detail.nadir_equiv_error_m is not None
            assert detail.nadir_equiv_error_m >= 0.0

    def test_passed_flag_consistent_with_percent(self, config, image_matching_dataset, tmp_path):
        outcome = verify(config, image_matching_results=[image_matching_dataset], work_dir=tmp_path)
        # Pass/fail verdict must agree with the reported percentage.
        if outcome.passed:
            assert outcome.percent_within_threshold >= config.performance_spec_percent
        else:
            assert outcome.percent_within_threshold < config.performance_spec_percent

    def test_summary_table_is_non_empty_string(self, config, image_matching_dataset, tmp_path):
        outcome = verify(config, image_matching_results=[image_matching_dataset], work_dir=tmp_path)
        assert len(outcome.summary_table) > 0
        assert "Verification Summary" in outcome.summary_table

    def test_warnings_empty_when_passed(self, config, image_matching_dataset, tmp_path):
        outcome = verify(config, image_matching_results=[image_matching_dataset], work_dir=tmp_path)
        if outcome.passed:
            assert outcome.warnings == []

    def test_warnings_non_empty_when_failed(self, config, tmp_path):
        """Force a FAILED result by using a very tight spec (100 %)."""
        strict_config = _make_config(performance_spec_percent=100.0)
        ds = _make_full_image_matching_dataset(n=13, seed=0)
        outcome = verify(strict_config, image_matching_results=[ds], work_dir=tmp_path)
        # With 100 % required, any imperfect measurement causes failure
        if not outcome.passed:
            assert len(outcome.warnings) >= 1
            assert "VERIFICATION FAILED" in outcome.warnings[0]

    # -------------------------------------------------------------------
    # Happy path – multiple GCP pairs
    # -------------------------------------------------------------------

    def test_multi_pair_aggregates_all_measurements(self, config, multi_pair_results, tmp_path):
        total = sum(ds.sizes["measurement"] for ds in multi_pair_results)
        outcome = verify(config, image_matching_results=multi_pair_results, work_dir=tmp_path)
        assert len(outcome.per_gcp_errors) == total

    def test_multi_pair_science_keys_from_attrs(self, config, multi_pair_results, tmp_path):
        outcome = verify(config, image_matching_results=multi_pair_results, work_dir=tmp_path)
        seen = {detail.science_key for detail in outcome.per_gcp_errors}
        assert "scene_A" in seen
        assert "scene_B" in seen

    def test_requirements_reflect_config(self, config, image_matching_dataset, tmp_path):
        outcome = verify(config, image_matching_results=[image_matching_dataset], work_dir=tmp_path)
        assert outcome.requirements.performance_threshold_m == _THRESHOLD_M
        assert outcome.requirements.performance_spec_percent == _SPEC_PCT

    def test_work_dir_created_if_missing(self, config, image_matching_dataset, tmp_path):
        new_dir = tmp_path / "nonexistent" / "subdir"
        assert not new_dir.exists()
        verify(config, image_matching_results=[image_matching_dataset], work_dir=new_dir)
        assert new_dir.exists()

    # -------------------------------------------------------------------
    # RequirementsConfig override via config.verification
    # -------------------------------------------------------------------

    def test_custom_requirements_override_used(self, image_matching_dataset, tmp_path):
        """Attach RequirementsConfig to config.verification; verify() should use it."""
        base_config = _make_config()
        # Inject a very lenient requirement so it almost certainly passes
        req = RequirementsConfig(performance_threshold_m=1_000_000.0, performance_spec_percent=0.0)
        # Monkeypatch the config object (verification is not a declared field but
        # _build_requirements uses getattr with a fallback)
        object.__setattr__(base_config, "verification", req)
        outcome = verify(base_config, image_matching_results=[image_matching_dataset], work_dir=tmp_path)
        assert outcome.requirements.performance_threshold_m == 1_000_000.0

    # -------------------------------------------------------------------
    # Error paths
    # -------------------------------------------------------------------

    def test_empty_image_matching_list_raises(self, config, tmp_path):
        with pytest.raises(ValueError, match="must not be empty"):
            verify(config, image_matching_results=[], work_dir=tmp_path)

    def test_neither_input_raises_value_error(self, config, tmp_path):
        with pytest.raises(ValueError, match="Neither image_matching_results nor geolocated_data"):
            verify(config, work_dir=tmp_path)

    def test_geolocated_data_without_func_raises(self, config, tmp_path):
        dummy_ds = xr.Dataset({"dummy": (["x"], [1, 2, 3])})
        with pytest.raises(ValueError, match="_image_matching_override is not set"):
            verify(config, geolocated_data=dummy_ds, work_dir=tmp_path)

    def test_geolocated_data_with_func_called(self, image_matching_dataset, tmp_path):
        """_image_matching_override should be called when geolocated_data is supplied."""
        calls = {"count": 0}

        def mock_matching_func(data):
            calls["count"] += 1
            return [image_matching_dataset]

        config = _make_config()
        config._image_matching_override = mock_matching_func
        dummy_geolocated = xr.Dataset({"placeholder": (["x"], [1, 2])})
        outcome = verify(config, geolocated_data=dummy_geolocated, work_dir=tmp_path)
        assert calls["count"] == 1
        assert isinstance(outcome, VerificationResult)

    def test_gcp_pairs_raises_not_implemented(self, config, tmp_path):
        """gcp_pairs mode raises NotImplementedError with a helpful message."""
        with pytest.raises(NotImplementedError, match="gcp_pairs"):
            verify(config, gcp_pairs=[("obs.mat", "gcp.mat")], work_dir=tmp_path)

    def test_observation_paths_raises_not_implemented(self, config, tmp_path):
        """observation_paths + gcp_directory mode raises NotImplementedError."""
        with pytest.raises(NotImplementedError, match="observation_paths"):
            verify(config, observation_paths=["obs.mat"], gcp_directory=tmp_path, work_dir=tmp_path)

    def test_gcp_directory_alone_raises_not_implemented(self, config, tmp_path):
        """gcp_directory without observation_paths also raises NotImplementedError."""
        with pytest.raises(NotImplementedError):
            verify(config, gcp_directory=tmp_path, work_dir=tmp_path)


# ===========================================================================
# _log_pairing_summary
# ===========================================================================


class TestLogPairingSummary:
    """Tests for the _log_pairing_summary logging helper."""

    def test_all_paired(self, caplog):
        import logging

        from curryer.correction.verification import _log_pairing_summary

        pairs = [(Path("obs_001.mat"), Path("gcp_001.mat")), (Path("obs_002.mat"), Path("gcp_002.mat"))]
        with caplog.at_level(logging.INFO, logger="curryer.correction.verification"):
            _log_pairing_summary(pairs)

        joined = "\n".join(caplog.messages)
        assert "obs_001.mat" in joined
        assert "gcp_001.mat" in joined
        assert "Proceeding with 2 observation(s)" in joined

    def test_with_unpaired(self, caplog):
        import logging

        from curryer.correction.verification import _log_pairing_summary

        pairs = [(Path("obs_001.mat"), Path("gcp_001.mat"))]
        leftovers = [Path("obs_002.mat")]
        with caplog.at_level(logging.INFO, logger="curryer.correction.verification"):
            _log_pairing_summary(pairs, unpaired=leftovers)

        joined = "\n".join(caplog.messages)
        assert "obs_002.mat" in joined
        assert "No matching GCP" in joined
        assert "Proceeding with 1 observation(s)" in joined

    def test_empty_pairs(self, caplog):
        import logging

        from curryer.correction.verification import _log_pairing_summary

        with caplog.at_level(logging.INFO, logger="curryer.correction.verification"):
            _log_pairing_summary([])

        assert "Proceeding with 0 observation(s)" in "\n".join(caplog.messages)