diff --git a/curryer/correction/config.py b/curryer/correction/config.py index 87ba3c97..fb43a856 100644 --- a/curryer/correction/config.py +++ b/curryer/correction/config.py @@ -734,6 +734,9 @@ def load_config_from_json(config_path: Path) -> "CorrectionConfig": if kernel_file: param_groups[param_name]["config_file"] = Path(kernel_file) _config_logger.debug(f"Mapped OFFSET_KERNEL '{param_name}' → {kernel_file}") + elif param_dict.get("config_file"): + param_groups[param_name]["config_file"] = Path(param_dict["config_file"]) + _config_logger.debug(f"Using explicit config_file for OFFSET_KERNEL '{param_name}'") else: _config_logger.warning(f"No kernel mapping found for OFFSET_KERNEL parameter: {param_name}") @@ -754,12 +757,12 @@ def load_config_from_json(config_path: Path) -> "CorrectionConfig": else: param_dict = group_data["param_dict"] param_data = { - "current_value": param_dict.get("initial_value", 0.0), + "current_value": param_dict.get("current_value", param_dict.get("initial_value", 0.0)), "bounds": param_dict.get("bounds", [-100, 100]), "sigma": param_dict.get("sigma"), "units": param_dict.get("units", "radians"), "distribution": param_dict.get("distribution_type", "normal"), - "field": param_dict.get("application_target", {}).get("field_name", None), + "field": (param_dict.get("field") or param_dict.get("application_target", {}).get("field_name", None)), } parameters.append( diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..9a9c14bf --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1,2 @@ +# Integration tests package + diff --git a/tests/test_correction/_synthetic_helpers.py b/tests/test_correction/_synthetic_helpers.py new file mode 100644 index 00000000..367c273b --- /dev/null +++ b/tests/test_correction/_synthetic_helpers.py @@ -0,0 +1,159 @@ +"""Synthetic test-data helpers shared by test_pipeline.py and clarreo e2e tests. 
+ +These functions generate realistic-looking but entirely synthetic sensor data +(boresight vectors, spacecraft positions, transformation matrices, GCP pairs) +so that upstream pipeline tests can run without any real instrument data. + +**These are test infrastructure helpers – not pytest tests.** +""" + +from __future__ import annotations + +import logging + +import numpy as np +import xarray as xr + +logger = logging.getLogger(__name__) + + +class _PlaceholderConfig: + """Default parameters controlling synthetic data generation.""" + + base_error_m: float = 50.0 + param_error_scale: float = 10.0 + max_measurements: int = 100 + min_measurements: int = 10 + orbit_radius_mean_m: float = 6.78e6 + orbit_radius_std_m: float = 4e3 + latitude_range: tuple = (-60.0, 60.0) + longitude_range: tuple = (-180.0, 180.0) + altitude_range: tuple = (0.0, 1000.0) + max_off_nadir_rad: float = 0.1 + + +def synthetic_gcp_pairing(science_data_files): + """Return SYNTHETIC GCP pairs for upstream testing (no real GCP data needed).""" + logger.warning("USING SYNTHETIC GCP PAIRING - FAKE DATA!") + return [(str(f), f"landsat_gcp_{i:03d}.tif") for i, f in enumerate(science_data_files)] + + +def synthetic_image_matching( + geolocated_data, + gcp_reference_file, + telemetry, + calibration_dir, + params_info, + config, + los_vectors_cached=None, + optical_psfs_cached=None, +): + """Return SYNTHETIC image-matching results for upstream testing. + + Accepts (and ignores) the same signature as the real ``image_matching`` + function so the pipeline loop can call it transparently. 
+ """ + logger.warning("USING SYNTHETIC IMAGE MATCHING - FAKE DATA!") + placeholder_cfg = ( + config.placeholder if hasattr(config, "placeholder") and config.placeholder else _PlaceholderConfig() + ) + sc_pos_name = getattr(config, "spacecraft_position_name", "sc_position") + boresight_name = getattr(config, "boresight_name", "boresight") + transform_name = getattr(config, "transformation_matrix_name", "t_inst2ref") + + valid_mask = ~np.isnan(geolocated_data["latitude"].values).any(axis=1) + n_valid = int(valid_mask.sum()) + n_meas = placeholder_cfg.min_measurements if n_valid == 0 else min(n_valid, placeholder_cfg.max_measurements) + + riss_ctrs = _generate_spherical_positions( + n_meas, placeholder_cfg.orbit_radius_mean_m, placeholder_cfg.orbit_radius_std_m + ) + boresights = _generate_synthetic_boresights(n_meas, placeholder_cfg.max_off_nadir_rad) + t_matrices = _generate_nadir_aligned_transforms(n_meas, riss_ctrs, boresights) + + param_contribution = ( + sum(abs(p) if isinstance(p, (int, float)) else np.linalg.norm(p) for _, p in params_info) + * placeholder_cfg.param_error_scale + ) + error_magnitude = placeholder_cfg.base_error_m + param_contribution + lat_errors = np.random.normal(0, error_magnitude / 111_000, n_meas) + lon_errors = np.random.normal(0, error_magnitude / 111_000, n_meas) + + if n_valid > 0: + idx = np.where(valid_mask)[0][:n_meas] + gcp_lat = geolocated_data["latitude"].values[idx, 0] + gcp_lon = geolocated_data["longitude"].values[idx, 0] + else: + gcp_lat = np.random.uniform(*placeholder_cfg.latitude_range, n_meas) + gcp_lon = np.random.uniform(*placeholder_cfg.longitude_range, n_meas) + + gcp_alt = np.random.uniform(*placeholder_cfg.altitude_range, n_meas) + + return xr.Dataset( + { + "lat_error_deg": (["measurement"], lat_errors), + "lon_error_deg": (["measurement"], lon_errors), + sc_pos_name: (["measurement", "xyz"], riss_ctrs), + boresight_name: (["measurement", "xyz"], boresights), + transform_name: (["measurement", "xyz_from", 
"xyz_to"], t_matrices), + "gcp_lat_deg": (["measurement"], gcp_lat), + "gcp_lon_deg": (["measurement"], gcp_lon), + "gcp_alt": (["measurement"], gcp_alt), + }, + coords={ + "measurement": range(n_meas), + "xyz": ["x", "y", "z"], + "xyz_from": ["x", "y", "z"], + "xyz_to": ["x", "y", "z"], + }, + ) + + +def _generate_synthetic_boresights(n, max_off_nadir_rad=0.07): + """Return *n* unit boresight vectors with small off-nadir angles.""" + b = np.zeros((n, 3)) + for i in range(n): + th = np.random.uniform(-max_off_nadir_rad, max_off_nadir_rad) + b[i] = [0.0, np.sin(th), np.cos(th)] + return b + + +def _generate_spherical_positions(n, radius_mean_m, radius_std_m): + """Return *n* random points on a sphere (spacecraft orbit positions).""" + pos = np.zeros((n, 3)) + for i in range(n): + r = np.random.normal(radius_mean_m, radius_std_m) + phi = np.random.uniform(0, 2 * np.pi) + ct = np.random.uniform(-1, 1) + st = np.sqrt(max(0.0, 1 - ct**2)) + pos[i] = [r * st * np.cos(phi), r * st * np.sin(phi), r * ct] + return pos + + +def _generate_nadir_aligned_transforms(n, riss_ctrs, boresights_hs): + """Return *n* rotation matrices aligning ``boresights_hs`` toward nadir.""" + T = np.zeros((n, 3, 3)) + for i in range(n): + nadir = -riss_ctrs[i] / np.linalg.norm(riss_ctrs[i]) + bhat = boresights_hs[i] / np.linalg.norm(boresights_hs[i]) + ax = np.cross(bhat, nadir) + ax_norm = np.linalg.norm(ax) + if ax_norm < 1e-6: + if np.dot(bhat, nadir) > 0: + T[i] = np.eye(3) + else: + perp = np.array([1, 0, 0]) if abs(bhat[0]) < 0.9 else np.array([0, 1, 0]) + ax = np.cross(bhat, perp) + ax /= np.linalg.norm(ax) + K = _skew(ax) + T[i] = np.eye(3) + 2 * K @ K + else: + ax /= ax_norm + angle = np.arccos(np.clip(np.dot(bhat, nadir), -1.0, 1.0)) + K = _skew(ax) + T[i] = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K) + return T + + +def _skew(v): + return np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) diff --git a/tests/test_correction/clarreo/__init__.py 
b/tests/test_correction/clarreo/__init__.py new file mode 100644 index 00000000..ffb6047f --- /dev/null +++ b/tests/test_correction/clarreo/__init__.py @@ -0,0 +1 @@ +"""CLARREO-specific correction tests.""" diff --git a/tests/test_correction/clarreo/_image_match_helpers.py b/tests/test_correction/clarreo/_image_match_helpers.py new file mode 100644 index 00000000..f240493e --- /dev/null +++ b/tests/test_correction/clarreo/_image_match_helpers.py @@ -0,0 +1,342 @@ +""" +Image-matching helpers for CLARREO integration tests. + +Provides utilities for discovering test cases, applying artificial +geolocation errors, and running image matching with those errors. +These are *test infrastructure helpers*, not pytest tests themselves. +""" + +from __future__ import annotations + +import logging +import time +from pathlib import Path + +import numpy as np +import xarray as xr +from scipy.io import loadmat + +from curryer.correction.data_structures import ImageGrid, PSFSamplingConfig, SearchConfig +from curryer.correction.image_match import ( + integrated_image_match, + load_image_grid_from_mat, + load_los_vectors_from_mat, + load_optical_psf_from_mat, +) + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Test-case metadata +# --------------------------------------------------------------------------- + +_TEST_CASE_METADATA: dict[str, dict] = { + "1": { + "name": "Dili", + "gcp_file": "GCP12055Dili_resampled.mat", + "ancil_file": "R_ISS_midframe_TestCase1.mat", + "expected_error_km": (3.0, -3.0), + "cases": [ + {"subimage": "TestCase1a_subimage.mat", "binned": False}, + {"subimage": "TestCase1b_subimage.mat", "binned": False}, + {"subimage": "TestCase1c_subimage_binned.mat", "binned": True}, + {"subimage": "TestCase1d_subimage_binned.mat", "binned": True}, + ], + }, + "2": { + "name": "Maracaibo", + "gcp_file": "GCP10121Maracaibo_resampled.mat", + "ancil_file": "R_ISS_midframe_TestCase2.mat", + 
"expected_error_km": (-3.0, 2.0), + "cases": [ + {"subimage": "TestCase2a_subimage.mat", "binned": False}, + {"subimage": "TestCase2b_subimage.mat", "binned": False}, + {"subimage": "TestCase2c_subimage_binned.mat", "binned": True}, + ], + }, + "3": { + "name": "Algeria3", + "gcp_file": "GCP10181Algeria3_resampled.mat", + "ancil_file": "R_ISS_midframe_TestCase3.mat", + "expected_error_km": (2.0, 3.0), + "cases": [ + {"subimage": "TestCase3a_subimage.mat", "binned": False}, + {"subimage": "TestCase3b_subimage_binned.mat", "binned": True}, + ], + }, + "4": { + "name": "Dunhuang", + "gcp_file": "GCP10142Dunhuang_resampled.mat", + "ancil_file": "R_ISS_midframe_TestCase4.mat", + "expected_error_km": (-2.0, -3.0), + "cases": [ + {"subimage": "TestCase4a_subimage.mat", "binned": False}, + {"subimage": "TestCase4b_subimage_binned.mat", "binned": True}, + ], + }, + "5": { + "name": "Algeria5", + "gcp_file": "GCP10071Algeria5_resampled.mat", + "ancil_file": "R_ISS_midframe_TestCase5.mat", + "expected_error_km": (1.0, -1.0), + "cases": [ + {"subimage": "TestCase5a_subimage.mat", "binned": False}, + ], + }, +} + + +def discover_test_image_match_cases( + test_data_dir: Path, + test_cases: list[str] | None = None, +) -> list[dict]: + """Scan *test_data_dir* for validated image-matching test cases. + + Parameters + ---------- + test_data_dir: + Root directory (e.g. ``tests/data/clarreo/image_match/``). + test_cases: + Specific case IDs to include (e.g. ``['1', '2']``). ``None`` means + all available cases. + + Returns + ------- + list[dict] + One dict per sub-case variant with keys: ``case_id``, ``case_name``, + ``subcase_name``, ``subimage_file``, ``gcp_file``, ``ancil_file``, + ``los_file``, ``psf_file``, ``expected_lat_error_km``, + ``expected_lon_error_km``, ``binned``. 
+ """ + logger.info("Discovering image-matching test cases in: %s", test_data_dir) + + los_file = test_data_dir / "b_HS.mat" + psf_file_unbinned = test_data_dir / "optical_PSF_675nm_upsampled.mat" + psf_file_binned = test_data_dir / "optical_PSF_675nm_3_pix_binned_upsampled.mat" + + if not los_file.exists(): + raise FileNotFoundError(f"LOS vectors file not found: {los_file}") + if not psf_file_unbinned.exists(): + raise FileNotFoundError(f"PSF file not found: {psf_file_unbinned}") + + if test_cases is None: + test_cases = sorted(_TEST_CASE_METADATA.keys()) + + discovered: list[dict] = [] + for case_id in test_cases: + if case_id not in _TEST_CASE_METADATA: + logger.warning("Test case '%s' not in metadata, skipping", case_id) + continue + meta = _TEST_CASE_METADATA[case_id] + case_dir = test_data_dir / case_id + if not case_dir.is_dir(): + logger.warning("Test case directory not found: %s, skipping", case_dir) + continue + + for subcase in meta["cases"]: + subimage_file = case_dir / subcase["subimage"] + gcp_file = case_dir / meta["gcp_file"] + ancil_file = case_dir / meta["ancil_file"] + psf_file = psf_file_binned if subcase["binned"] else psf_file_unbinned + + if not subimage_file.exists(): + logger.warning("Subimage not found: %s, skipping", subimage_file) + continue + if not gcp_file.exists(): + logger.warning("GCP file not found: %s, skipping", gcp_file) + continue + if not ancil_file.exists(): + logger.warning("Ancil file not found: %s, skipping", ancil_file) + continue + + discovered.append( + { + "case_id": case_id, + "case_name": meta["name"], + "subcase_name": subcase["subimage"], + "subimage_file": subimage_file, + "gcp_file": gcp_file, + "ancil_file": ancil_file, + "los_file": los_file, + "psf_file": psf_file, + "expected_lat_error_km": meta["expected_error_km"][0], + "expected_lon_error_km": meta["expected_error_km"][1], + "binned": subcase["binned"], + } + ) + + logger.info("Discovered %d test-case variants", len(discovered)) + return discovered + + +# 
--------------------------------------------------------------------------- +# Error-variation helpers +# --------------------------------------------------------------------------- + + +def apply_error_variation_for_testing( + base_result: xr.Dataset, + param_idx: int, + error_variation_percent: float = 3.0, +) -> xr.Dataset: + """Apply random variation to image-matching results to simulate parameter effects. + + Uses *param_idx* as a reproducible random seed so each parameter set gets + a distinct but deterministic perturbation. + """ + output = base_result.copy(deep=True) + rng = np.random.default_rng(param_idx) + + vf = error_variation_percent / 100.0 + lat_factor = 1.0 + rng.normal(0, vf) + lon_factor = 1.0 + rng.normal(0, vf) + ccv_factor = 1.0 + rng.normal(0, vf / 10.0) + + orig_lat = base_result.attrs["lat_error_km"] + orig_lon = base_result.attrs["lon_error_km"] + orig_ccv = base_result.attrs["correlation_ccv"] + + varied_lat = orig_lat * lat_factor + varied_lon = orig_lon * lon_factor + varied_ccv = float(np.clip(orig_ccv * ccv_factor, 0.0, 1.0)) + + if "gcp_center_lat" in base_result.attrs: + gcp_center_lat = base_result.attrs["gcp_center_lat"] + elif "gcp_lat_deg" in base_result: + gcp_center_lat = float(base_result["gcp_lat_deg"].values[0]) + else: + gcp_center_lat = 45.0 + + lat_error_deg = varied_lat / 111.0 + lon_radius_km = 6378.0 * np.cos(np.deg2rad(gcp_center_lat)) + lon_error_deg = varied_lon / (lon_radius_km * np.pi / 180.0) + + output["lat_error_deg"].values[0] = lat_error_deg + output["lon_error_deg"].values[0] = lon_error_deg + output.attrs.update( + { + "lat_error_km": varied_lat, + "lon_error_km": varied_lon, + "correlation_ccv": varied_ccv, + "param_idx": param_idx, + "variation_applied": True, + } + ) + return output + + +def apply_geolocation_error_to_subimage( + subimage: ImageGrid, + gcp: ImageGrid, + lat_error_km: float, + lon_error_km: float, +) -> ImageGrid: + """Shift *subimage* coordinates by (lat_error_km, lon_error_km) for 
testing.""" + from curryer.compute import constants + + mid_lat = float(gcp.lat[gcp.lat.shape[0] // 2, gcp.lat.shape[1] // 2]) + earth_radius_km = constants.WGS84_SEMI_MAJOR_AXIS_KM + lat_offset_deg = lat_error_km / earth_radius_km * (180.0 / np.pi) + lon_radius_km = earth_radius_km * np.cos(np.deg2rad(mid_lat)) + lon_offset_deg = lon_error_km / lon_radius_km * (180.0 / np.pi) + + return ImageGrid( + data=subimage.data.copy(), + lat=subimage.lat + lat_offset_deg, + lon=subimage.lon + lon_offset_deg, + h=subimage.h.copy() if subimage.h is not None else None, + ) + + +def run_image_matching_with_applied_errors( + test_case: dict, + param_idx: int, + randomize_errors: bool = True, + error_variation_percent: float = 3.0, + cache_results: bool = True, + cached_result: xr.Dataset | None = None, +) -> xr.Dataset: + """Run image matching with artificial errors; return result Dataset. + + Uses *cached_result* + random variation for param_idx > 0 when + *cache_results* is True. + """ + if cached_result is not None and cache_results and param_idx > 0: + if randomize_errors: + return apply_error_variation_for_testing(cached_result, param_idx, error_variation_percent) + return cached_result.copy() + + logger.info("Running image matching: %s", test_case["case_name"]) + start = time.time() + + subimage_struct = loadmat(test_case["subimage_file"], squeeze_me=True, struct_as_record=False)["subimage"] + subimage = ImageGrid( + data=np.asarray(subimage_struct.data), + lat=np.asarray(subimage_struct.lat), + lon=np.asarray(subimage_struct.lon), + h=np.asarray(subimage_struct.h) if hasattr(subimage_struct, "h") else None, + ) + gcp = load_image_grid_from_mat(test_case["gcp_file"], key="GCP") + gcp_center_lat = float(gcp.lat[gcp.lat.shape[0] // 2, gcp.lat.shape[1] // 2]) + gcp_center_lon = float(gcp.lon[gcp.lon.shape[0] // 2, gcp.lon.shape[1] // 2]) + + subimage_with_error = apply_geolocation_error_to_subimage( + subimage, gcp, test_case["expected_lat_error_km"], 
test_case["expected_lon_error_km"] + ) + + los_vectors = load_los_vectors_from_mat(test_case["los_file"]) + optical_psfs = load_optical_psf_from_mat(test_case["psf_file"]) + ancil_data = loadmat(test_case["ancil_file"], squeeze_me=True) + r_iss_midframe = ancil_data["R_ISS_midframe"].ravel() + + result = integrated_image_match( + subimage=subimage_with_error, + gcp=gcp, + r_iss_midframe_m=r_iss_midframe, + los_vectors_hs=los_vectors, + optical_psfs=optical_psfs, + geolocation_config=PSFSamplingConfig(), + search_config=SearchConfig(), + ) + processing_time = time.time() - start + + lat_error_deg = result.lat_error_km / 111.0 + lon_radius_km = 6378.0 * np.cos(np.deg2rad(gcp_center_lat)) + lon_error_deg = result.lon_error_km / (lon_radius_km * np.pi / 180.0) + + t_matrix = np.array( + [ + [-0.418977524967338, 0.748005379751721, 0.514728846515064], + [-0.421890284446342, 0.341604851993858, -0.839830169131854], + [-0.804031356019172, -0.569029065124742, 0.172451447025628], + ] + ) + boresight = np.array([0.0, 0.0625969755450201, 0.99803888634292]) + + output = xr.Dataset( + { + "lat_error_deg": (["measurement"], [lat_error_deg]), + "lon_error_deg": (["measurement"], [lon_error_deg]), + "riss_ctrs": (["measurement", "xyz"], [r_iss_midframe]), + "bhat_hs": (["measurement", "xyz"], [boresight]), + "t_hs2ctrs": (["measurement", "xyz_from", "xyz_to"], t_matrix[np.newaxis, :, :]), + "gcp_lat_deg": (["measurement"], [gcp_center_lat]), + "gcp_lon_deg": (["measurement"], [gcp_center_lon]), + "gcp_alt": (["measurement"], [0.0]), + }, + coords={"measurement": [0], "xyz": ["x", "y", "z"], "xyz_from": ["x", "y", "z"], "xyz_to": ["x", "y", "z"]}, + ) + output.attrs.update( + { + "lat_error_km": result.lat_error_km, + "lon_error_km": result.lon_error_km, + "correlation_ccv": result.ccv_final, + "final_grid_step_m": result.final_grid_step_m, + "processing_time_s": processing_time, + "test_mode": True, + "param_idx": param_idx, + "gcp_center_lat": gcp_center_lat, + "gcp_center_lon": 
gcp_center_lon, + } + ) + return output diff --git a/tests/test_correction/clarreo/_pipeline_helpers.py b/tests/test_correction/clarreo/_pipeline_helpers.py new file mode 100644 index 00000000..98debf72 --- /dev/null +++ b/tests/test_correction/clarreo/_pipeline_helpers.py @@ -0,0 +1,281 @@ +""" +Pipeline runner helpers for CLARREO integration tests. + +Provides ``run_upstream_pipeline`` and ``run_downstream_pipeline``, +which exercise the upstream (kernel creation + geolocation) and +downstream (GCP pairing + image matching + error statistics) halves +of the Correction pipeline respectively. + +These are *test infrastructure helpers*, not pytest tests themselves. +""" + +from __future__ import annotations + +import atexit +import logging +import shutil +import tempfile +from pathlib import Path + +import numpy as np +import xarray as xr + +from curryer.correction import correction +from curryer.correction.config import DataConfig + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Lazy imports (avoid hard-coding sys.path assumptions at module import time) +# --------------------------------------------------------------------------- + + +def _load_clarreo_loaders(): + from _image_match_helpers import discover_test_image_match_cases, run_image_matching_with_applied_errors + from clarreo_config import create_clarreo_correction_config + from clarreo_data_loaders import load_clarreo_science, load_clarreo_telemetry + + return ( + create_clarreo_correction_config, + load_clarreo_telemetry, + load_clarreo_science, + discover_test_image_match_cases, + run_image_matching_with_applied_errors, + ) + + +# --------------------------------------------------------------------------- +# Upstream pipeline +# --------------------------------------------------------------------------- + + +def run_upstream_pipeline( + n_iterations: int = 5, + work_dir: Path | None = None, +) -> tuple[list, dict, Path]: + """Test 
the upstream segment: parameter generation → kernel creation → geolocation. + + Uses ``synthetic_image_matching`` so it does NOT require valid GCP pairs. + + Returns + ------- + (results_list, results_summary_dict, output_file_path) + """ + from _synthetic_helpers import synthetic_image_matching + + ( + create_clarreo_correction_config, + load_clarreo_telemetry, + load_clarreo_science, + _, + _, + ) = _load_clarreo_loaders() + + logger.info("=== UPSTREAM PIPELINE TEST ===") + + root_dir = Path(__file__).parents[3] + generic_dir = root_dir / "data" / "generic" + data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" + + if work_dir is None: + _tmp = tempfile.mkdtemp(prefix="curryer_upstream_") + work_dir = Path(_tmp) + atexit.register(shutil.rmtree, work_dir, True) + logger.info("Temporary work dir: %s", work_dir) + else: + work_dir.mkdir(parents=True, exist_ok=True) + + config = create_clarreo_correction_config(data_dir, generic_dir) + config.n_iterations = n_iterations + config.output_filename = "upstream_results.nc" + + tlm_df = load_clarreo_telemetry(data_dir) + sci_df = load_clarreo_science(data_dir) + + tlm_csv = work_dir / "clarreo_telemetry.csv" + sci_csv = work_dir / "clarreo_science.csv" + tlm_df.to_csv(tlm_csv) + sci_df.to_csv(sci_csv) + + config.data = DataConfig(file_format="csv", time_scale_factor=1e6) + config.image_matching_func = synthetic_image_matching + + tlm_sci_gcp_sets = [(str(tlm_csv), str(sci_csv), "synthetic_gcp.mat")] + + logger.info("Executing Correction upstream workflow (%d iterations)…", n_iterations) + results, netcdf_data = correction.loop(config, work_dir, tlm_sci_gcp_sets) + + output_file = work_dir / config.get_output_filename() + logger.info("Upstream pipeline complete. 
Output: %s", output_file) + + summary = { + "mode": "upstream", + "iterations": n_iterations, + "parameter_sets": len(netcdf_data["parameter_set_id"]), + "status": "complete", + } + return results, summary, output_file + + +# --------------------------------------------------------------------------- +# Downstream pipeline +# --------------------------------------------------------------------------- + + +def run_downstream_pipeline( + n_iterations: int = 5, + test_cases: list[str] | None = None, + work_dir: Path | None = None, +) -> tuple[list, dict, Path]: + """Test the downstream segment: GCP pairing → image matching → error statistics. + + Uses pre-geolocated test data (.mat files) with known errors applied. + Does NOT test kernel creation or SPICE geolocation. + + Returns + ------- + (results_list, results_summary_dict, output_file_path) + """ + from _image_match_helpers import discover_test_image_match_cases, run_image_matching_with_applied_errors + from clarreo_config import create_clarreo_correction_config + + from curryer.correction.image_match import load_image_grid_from_mat + from curryer.correction.pairing import find_l1a_gcp_pairs + + logger.info("=== DOWNSTREAM PIPELINE TEST ===") + + root_dir = Path(__file__).parents[3] + generic_dir = root_dir / "data" / "generic" + data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" + test_data_dir = root_dir / "tests" / "data" / "clarreo" / "image_match" + + if work_dir is None: + _tmp = tempfile.mkdtemp(prefix="curryer_downstream_") + work_dir = Path(_tmp) + atexit.register(shutil.rmtree, work_dir, True) + logger.info("Temporary work dir: %s", work_dir) + else: + work_dir.mkdir(parents=True, exist_ok=True) + + # --- STEP 1: discover test cases --- + discovered_cases = discover_test_image_match_cases(test_data_dir, test_cases) + if not discovered_cases: + raise RuntimeError(f"No test cases found in {test_data_dir} (test_cases={test_cases})") + + # --- STEP 2: GCP spatial pairing --- + l1a_images = [] + 
l1a_to_testcase: dict[str, dict] = {} + for tc in discovered_cases: + img = load_image_grid_from_mat( + tc["subimage_file"], + key="subimage", + as_named=True, + name=str(tc["subimage_file"].relative_to(test_data_dir)), + ) + l1a_images.append(img) + l1a_to_testcase[img.name] = tc + + gcp_files_seen: set = set() + gcp_images = [] + for tc in discovered_cases: + if tc["gcp_file"] not in gcp_files_seen: + gcp_img = load_image_grid_from_mat( + tc["gcp_file"], + key="GCP", + as_named=True, + name=str(tc["gcp_file"].relative_to(test_data_dir)), + ) + gcp_images.append(gcp_img) + gcp_files_seen.add(tc["gcp_file"]) + + pairing_result = find_l1a_gcp_pairs(l1a_images, gcp_images, max_distance_m=0.0) + paired_test_cases = [l1a_to_testcase[pairing_result.l1a_images[m.l1a_index].name] for m in pairing_result.matches] + n_gcp_pairs = len(paired_test_cases) + logger.info("Pairing complete: %d valid pairs", n_gcp_pairs) + + # --- STEP 3: build config --- + base_config = create_clarreo_correction_config(data_dir, generic_dir) + config = correction.CorrectionConfig( + seed=42, + n_iterations=n_iterations, + parameters=[ + correction.ParameterConfig( + ptype=correction.ParameterType.CONSTANT_KERNEL, + config_file=data_dir / "cprs_hysics_v01.attitude.ck.json", + data={ + "current_value": [0.0, 0.0, 0.0], + "sigma": 0.0, + "units": "arcseconds", + "transformation_type": "dcm_rotation", + "coordinate_frames": ["HYSICS_SLIT", "CRADLE_ELEVATION"], + }, + ) + ], + geo=base_config.geo, + performance_threshold_m=base_config.performance_threshold_m, + performance_spec_percent=base_config.performance_spec_percent, + earth_radius_m=base_config.earth_radius_m, + netcdf=base_config.netcdf, + calibration_file_names=base_config.calibration_file_names, + spacecraft_position_name=base_config.spacecraft_position_name, + boresight_name=base_config.boresight_name, + transformation_matrix_name=base_config.transformation_matrix_name, + ) + config.data = DataConfig(file_format="csv", 
time_scale_factor=1e6) + config.validate() + + # --- STEP 4: iterate --- + netcdf_data = correction._build_netcdf_structure(config, n_iterations, n_gcp_pairs) + threshold_metric = config.netcdf.get_threshold_metric_name() + image_match_cache: dict[str, xr.Dataset] = {} + + for param_idx in range(n_iterations): + logger.info("Iteration %d/%d", param_idx + 1, n_iterations) + pair_errors = [] + + for pair_idx, tc in enumerate(paired_test_cases): + cache_key = f"{tc['case_id']}_{tc.get('subcase_name', '')}" + cached = image_match_cache.get(cache_key) + + output = run_image_matching_with_applied_errors( + tc, + param_idx, + randomize_errors=True, + error_variation_percent=3.0, + cache_results=True, + cached_result=cached, + ) + if cache_key not in image_match_cache: + image_match_cache[cache_key] = output + + rms_m = np.sqrt((output.attrs["lat_error_km"] * 1000) ** 2 + (output.attrs["lon_error_km"] * 1000) ** 2) + pair_errors.append(rms_m) + + netcdf_data["rms_error_m"][param_idx, pair_idx] = rms_m + netcdf_data["mean_error_m"][param_idx, pair_idx] = rms_m + netcdf_data["max_error_m"][param_idx, pair_idx] = rms_m + netcdf_data["n_measurements"][param_idx, pair_idx] = 1 + netcdf_data["im_lat_error_km"][param_idx, pair_idx] = output.attrs["lat_error_km"] + netcdf_data["im_lon_error_km"][param_idx, pair_idx] = output.attrs["lon_error_km"] + netcdf_data["im_ccv"][param_idx, pair_idx] = output.attrs["correlation_ccv"] + netcdf_data["im_grid_step_m"][param_idx, pair_idx] = output.attrs["final_grid_step_m"] + + valid = np.array([e for e in pair_errors if not np.isnan(e)]) + if len(valid) > 0: + pct = (valid < config.performance_threshold_m).sum() / len(valid) * 100 + netcdf_data[threshold_metric][param_idx] = pct + netcdf_data["mean_rms_all_pairs"][param_idx] = np.mean(valid) + netcdf_data["best_pair_rms"][param_idx] = np.min(valid) + netcdf_data["worst_pair_rms"][param_idx] = np.max(valid) + + # --- STEP 5: error statistics --- + image_matching_results = 
list(image_match_cache.values()) + correction.call_error_stats_module(image_matching_results, correction_config=config) + + # --- STEP 6: save --- + output_file = work_dir / "downstream_results.nc" + correction._save_netcdf_results(netcdf_data, output_file, config) + + logger.info("Downstream pipeline complete. Output: %s", output_file) + summary = {"mode": "downstream", "iterations": n_iterations, "test_pairs": n_gcp_pairs, "status": "complete"} + return [], summary, output_file diff --git a/tests/test_correction/clarreo_config.py b/tests/test_correction/clarreo/clarreo_config.py similarity index 100% rename from tests/test_correction/clarreo_config.py rename to tests/test_correction/clarreo/clarreo_config.py diff --git a/tests/test_correction/clarreo_data_loaders.py b/tests/test_correction/clarreo/clarreo_data_loaders.py similarity index 100% rename from tests/test_correction/clarreo_data_loaders.py rename to tests/test_correction/clarreo/clarreo_data_loaders.py diff --git a/tests/test_correction/clarreo/conftest.py b/tests/test_correction/clarreo/conftest.py new file mode 100644 index 00000000..db32fd7f --- /dev/null +++ b/tests/test_correction/clarreo/conftest.py @@ -0,0 +1,32 @@ +"""Pytest configuration for CLARREO integration tests. + +Exposes session-scoped path fixtures for the CLARREO test data directories. +The ``clarreo/`` directory is already on ``sys.path`` via the parent +``test_correction/conftest.py``, so test files here can import +``clarreo_config`` and ``clarreo_data_loaders`` without any additional +``sys.path`` manipulation. 
+""" + +from pathlib import Path + +import pytest + + +@pytest.fixture(scope="session") +def clarreo_root(): + return Path(__file__).parents[3] + + +@pytest.fixture(scope="session") +def clarreo_gcs_data_dir(clarreo_root): + return clarreo_root / "tests" / "data" / "clarreo" / "gcs" + + +@pytest.fixture(scope="session") +def clarreo_image_match_data_dir(clarreo_root): + return clarreo_root / "tests" / "data" / "clarreo" / "image_match" + + +@pytest.fixture(scope="session") +def clarreo_generic_dir(clarreo_root): + return clarreo_root / "data" / "generic" diff --git a/tests/test_correction/clarreo/test_clarreo_config.py b/tests/test_correction/clarreo/test_clarreo_config.py new file mode 100644 index 00000000..5b7c597b --- /dev/null +++ b/tests/test_correction/clarreo/test_clarreo_config.py @@ -0,0 +1,94 @@ +"""Tests for CLARREO correction configuration generation and JSON serialisation.""" + +from __future__ import annotations + +import json +import logging + +import pytest +from clarreo_config import create_clarreo_correction_config + +from curryer.correction import correction +from curryer.correction.config import DataConfig + +logger = logging.getLogger(__name__) + + +def test_generate_clarreo_config_json(tmp_path, clarreo_gcs_data_dir, clarreo_generic_dir): + """Generate the CLARREO config JSON and validate structure end-to-end.""" + output_path = tmp_path / "configs" / "clarreo_correction_config.json" + + config = create_clarreo_correction_config(clarreo_gcs_data_dir, clarreo_generic_dir, config_output_path=output_path) + + assert output_path.exists() + + with open(output_path) as fh: + config_data = json.load(fh) + + assert "mission_config" in config_data + assert "correction" in config_data + assert "geolocation" in config_data + assert config_data["mission_config"]["mission_name"] == "CLARREO_Pathfinder" + + corr = config_data["correction"] + assert isinstance(corr.get("parameters"), list) + assert len(corr["parameters"]) > 0 + assert corr["earth_radius_m"] 
== 6378140.0 + assert corr["performance_threshold_m"] == 250.0 + assert corr["performance_spec_percent"] == 39.0 + assert config_data["geolocation"]["instrument_name"] == "CPRS_HYSICS" + + reloaded = correction.load_config_from_json(output_path) + assert reloaded.n_iterations == config.n_iterations + assert len(reloaded.parameters) == len(config.parameters) + assert reloaded.earth_radius_m == 6378140.0 + reloaded.validate() + + # Verify each reloaded parameter has the expected ptype, config_file (for + # kernel-based params), and data.field (for OFFSET_KERNEL / OFFSET_TIME), + # so this test will fail if the JSON can't be used to run correction.loop(). + valid_ptypes = { + correction.ParameterType.CONSTANT_KERNEL, + correction.ParameterType.OFFSET_KERNEL, + correction.ParameterType.OFFSET_TIME, + } + for param in reloaded.parameters: + assert param.ptype in valid_ptypes, f"Unexpected ptype: {param.ptype}" + + if param.ptype in (correction.ParameterType.CONSTANT_KERNEL, correction.ParameterType.OFFSET_KERNEL): + assert param.config_file is not None, f"{param.ptype.name} parameter must have a config_file, got None" + + if param.ptype in (correction.ParameterType.OFFSET_KERNEL, correction.ParameterType.OFFSET_TIME): + assert param.data.field is not None, f"{param.ptype.name} parameter must have data.field set, got None" + + # Count parameters by type to ensure the expected composition is preserved. 
+ ptypes = [p.ptype for p in reloaded.parameters] + assert ptypes.count(correction.ParameterType.CONSTANT_KERNEL) >= 1 + assert ptypes.count(correction.ParameterType.OFFSET_KERNEL) >= 1 + assert ptypes.count(correction.ParameterType.OFFSET_TIME) >= 1 + + +class TestClarreoConfiguration: + """Smoke tests for the CLARREO CorrectionConfig object.""" + + @pytest.fixture(autouse=True) + def _setup(self, clarreo_gcs_data_dir, clarreo_generic_dir): + self.data_dir = clarreo_gcs_data_dir + self.generic_dir = clarreo_generic_dir + + def test_config_validates(self): + config = create_clarreo_correction_config(self.data_dir, self.generic_dir) + config.data = DataConfig(file_format="csv", time_scale_factor=1e6) + config.validate() + assert config.geo.instrument_name == "CPRS_HYSICS" + assert config.seed == 42 + + def test_parameter_count(self): + config = create_clarreo_correction_config(self.data_dir, self.generic_dir) + assert len(config.parameters) == 6 + + def test_performance_thresholds(self): + config = create_clarreo_correction_config(self.data_dir, self.generic_dir) + assert config.performance_threshold_m == 250.0 + assert config.performance_spec_percent == 39.0 + assert config.earth_radius_m == 6378140.0 diff --git a/tests/test_correction/clarreo/test_clarreo_dataio.py b/tests/test_correction/clarreo/test_clarreo_dataio.py new file mode 100644 index 00000000..e6ab8fdc --- /dev/null +++ b/tests/test_correction/clarreo/test_clarreo_dataio.py @@ -0,0 +1,38 @@ +"""CLARREO-specific data I/O integration tests (require AWS credentials).""" + +from __future__ import annotations + +import datetime as dt +import logging +import os + +import pytest + +from curryer.correction.dataio import S3Configuration, find_netcdf_objects + +logger = logging.getLogger(__name__) + +_NEEDS_AWS = pytest.mark.skipif( + not ( + (os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY") and os.getenv("AWS_SESSION_TOKEN")) + or os.getenv("C9_USER") + ), + reason="Requires AWS credentials 
or Cloud9 environment.", +) + + +@_NEEDS_AWS +def test_clarreo_find_l0_objects(tmp_path): + """Find L0 telemetry objects in CSDS S3 bucket.""" + config = S3Configuration("clarreo", "L0/telemetry/hps_navigation/") + keys = find_netcdf_objects(config, start_date=dt.date(2017, 1, 15), end_date=dt.date(2017, 1, 15)) + assert keys == ["L0/telemetry/hps_navigation/20170115/CPF_TLM_L0.V00-000.hps_navigation-20170115-0.0.0.nc"] + + +@_NEEDS_AWS +def test_clarreo_find_l1a_objects(tmp_path): + """Find L1a science objects in CSDS S3 bucket.""" + config = S3Configuration("clarreo", "L1a/nadir/") + keys = find_netcdf_objects(config, start_date=dt.date(2022, 6, 3), end_date=dt.date(2022, 6, 3)) + assert len(keys) == 34 + assert "L1a/nadir/20220603/nadir-20220603T235952-step22-geolocation_creation-0.0.0.nc" in keys diff --git a/tests/test_correction/clarreo/test_clarreo_e2e.py b/tests/test_correction/clarreo/test_clarreo_e2e.py new file mode 100644 index 00000000..9d26b6f2 --- /dev/null +++ b/tests/test_correction/clarreo/test_clarreo_e2e.py @@ -0,0 +1,129 @@ +"""CLARREO end-to-end integration tests. + +Exercises the upstream (kernel creation + geolocation) and downstream +(GCP pairing + image matching + error statistics) pipelines using +CLARREO test data. 
# --- tests/test_correction/clarreo/test_clarreo_e2e.py (tests) ---


@pytest.fixture
def work_dir(tmp_path):
    """Fresh per-test working directory under pytest's ``tmp_path``."""
    path = tmp_path / "work"
    path.mkdir()
    return path


def test_upstream_configuration(clarreo_gcs_data_dir, clarreo_generic_dir):
    """Upstream configuration loads and validates correctly."""
    cfg = create_clarreo_correction_config(clarreo_gcs_data_dir, clarreo_generic_dir)
    cfg.data = DataConfig(file_format="csv", time_scale_factor=1e6)
    cfg.validate()
    assert cfg.geo.instrument_name == "CPRS_HYSICS"
    assert len(cfg.parameters) > 0
    assert cfg.seed == 42
    assert cfg.data is not None


def test_downstream_test_case_discovery(clarreo_image_match_data_dir):
    """Downstream test cases can be discovered from the data directory."""
    cases = discover_test_image_match_cases(clarreo_image_match_data_dir)
    assert len(cases) > 0
    for case in cases:
        # Each discovered case must expose its id and both input files.
        assert "case_id" in case
        assert "subimage_file" in case
        assert "gcp_file" in case
        assert case["subimage_file"].exists()
        assert case["gcp_file"].exists()


def test_downstream_image_matching(clarreo_image_match_data_dir):
    """Downstream image matching runs successfully on test case 1."""
    cases = discover_test_image_match_cases(clarreo_image_match_data_dir, test_cases=["1"])
    assert len(cases) > 0
    matched = run_image_matching_with_applied_errors(
        cases[0], param_idx=0, randomize_errors=False, cache_results=True
    )
    assert isinstance(matched, xr.Dataset)
    assert "lat_error_km" in matched.attrs
    assert "lon_error_km" in matched.attrs


@pytest.mark.extra
def test_upstream_quick(clarreo_gcs_data_dir, clarreo_generic_dir, work_dir):
    """Quick upstream pipeline (2 iterations). Needs GMTED data; run via ``--run-extra``."""
    _, summary, out_file = run_upstream_pipeline(n_iterations=2, work_dir=work_dir)
    assert summary["status"] == "complete"
    assert summary["iterations"] == 2
    assert summary["mode"] == "upstream"
    assert summary["parameter_sets"] > 0
    assert out_file.exists()


def test_downstream_quick(work_dir, clarreo_image_match_data_dir):
    """Quick downstream pipeline (2 iterations, test case 1)."""
    _, summary, out_file = run_downstream_pipeline(n_iterations=2, test_cases=["1"], work_dir=work_dir)
    assert summary["status"] == "complete"
    assert summary["iterations"] == 2
    assert out_file.exists()


def test_downstream_helpers_basic(clarreo_image_match_data_dir):
    """Downstream helper functions work correctly."""
    cases = discover_test_image_match_cases(clarreo_image_match_data_dir, test_cases=["1"])
    assert len(cases) > 0
    assert "case_id" in cases[0]

    baseline = xr.Dataset(
        {"lat_error_deg": (["m"], [0.001]), "lon_error_deg": (["m"], [0.002])},
        attrs={"lat_error_km": 0.1, "lon_error_km": 0.2, "correlation_ccv": 0.95},
    )
    perturbed = apply_error_variation_for_testing(baseline, param_idx=1, error_variation_percent=3.0)
    assert isinstance(perturbed, xr.Dataset)
    # The variation must actually move the headline error metric.
    assert perturbed.attrs["lat_error_km"] != baseline.attrs["lat_error_km"]


def test_synthetic_helpers_basic():
    """Generic synthetic helper functions produce correctly-shaped outputs."""
    gcp_pairs = synthetic_gcp_pairing(["science_1.nc", "science_2.nc"])
    assert len(gcp_pairs) == 2

    los = _generate_synthetic_boresights(5, max_off_nadir_rad=0.07)
    assert los.shape == (5, 3)
    # Boresights are generated in the y-z plane, so x stays ~0.
    assert np.all(np.abs(los[:, 0]) < 0.01)

    sc_pos = _generate_spherical_positions(5, 6.78e6, 4e3)
    assert sc_pos.shape == (5, 3)
    assert np.all(np.linalg.norm(sc_pos, axis=1) > 6.7e6)

    mats = _generate_nadir_aligned_transforms(5, sc_pos, los)
    assert mats.shape == (5, 3, 3)
    # Rotation matrices should have |det| ~ 1 (loose tolerance by design).
    assert abs(abs(np.linalg.det(mats[0])) - 1.0) < 0.2
- """ - config_path = Path(__file__).parent / "configs/clarreo_correction_config.json" - - if not config_path.exists(): - pytest.skip(f"Config file not found: {config_path}. Run test_generate_clarreo_config_json() first.") - - return correction.load_config_from_json(config_path) - - -@pytest.fixture(scope="session") -def clarreo_config_programmatic(): - """Generate CLARREO config programmatically (for comparison tests). - - This fixture creates the configuration programmatically using the - create_clarreo_correction_config() function. Useful for testing - that programmatic and JSON configs produce equivalent results. - - Session scope for efficiency across multiple tests. - """ - data_dir = Path(__file__).parent.parent / "data/clarreo/gcs" - generic_dir = Path("data/generic") - return create_clarreo_correction_config(data_dir, generic_dir) +_here = str(Path(__file__).parent) +_clarreo_dir = str(Path(__file__).parent / "clarreo") -@pytest.fixture(scope="session") -def clarreo_data_dir(): - """Path to CLARREO test data directory. - - Returns: - Path: tests/data/clarreo/ - """ - return Path(__file__).parent.parent / "data/clarreo" - - -@pytest.fixture(scope="session") -def clarreo_gcs_data_dir(clarreo_data_dir): - """Path to CLARREO GCS test data directory. - - Returns: - Path: tests/data/clarreo/gcs/ - """ - return clarreo_data_dir / "gcs" - - -@pytest.fixture(scope="session") -def clarreo_image_match_data_dir(clarreo_data_dir): - """Path to CLARREO image matching test data directory. - - Returns: - Path: tests/data/clarreo/image_match/ - """ - return clarreo_data_dir / "image_match" +for _p in (_here, _clarreo_dir): + if _p not in sys.path: + sys.path.insert(0, _p) @pytest.fixture(scope="session") -def generic_kernel_dir(): - """Path to generic SPICE kernels directory. 
- - Returns: - Path: data/generic/ - """ - return Path("data/generic") +def root_dir(): + """Repository root directory (two levels above ``tests/test_correction/``).""" + return Path(__file__).parents[2] @pytest.fixture def temp_work_dir(tmp_path): - """Create temporary working directory for test outputs. - - This fixture provides a clean temporary directory for each test - that needs to write files (NetCDF outputs, intermediate results, etc.) - - Function scope ensures each test gets a fresh directory. - - Returns: - Path: Temporary directory path - """ - work_dir = tmp_path / "correction_work" - work_dir.mkdir(parents=True, exist_ok=True) - return work_dir + """Clean temporary working directory for each test.""" + work = tmp_path / "correction_work" + work.mkdir(parents=True, exist_ok=True) + return work -# Configuration for pytest def pytest_configure(config): - """Configure pytest with custom markers.""" - config.addinivalue_line("markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')") + """Register custom markers.""" + config.addinivalue_line("markers", "slow: marks tests as slow") config.addinivalue_line("markers", "integration: marks tests as integration tests") - config.addinivalue_line("markers", "requires_gcs: marks tests that require GCS credentials") + config.addinivalue_line("markers", "requires_gcs: marks tests requiring GCS credentials") diff --git a/tests/test_correction/test_correction.py b/tests/test_correction/test_correction.py deleted file mode 100644 index c743616d..00000000 --- a/tests/test_correction/test_correction.py +++ /dev/null @@ -1,2544 +0,0 @@ -#!/usr/bin/env python3 -""" -Unified Correction Test Suite - -This module consolidates two complementary Correction test approaches: - -1. 
UPSTREAM Testing (run_upstream_pipeline): - - Tests kernel creation and geolocation with parameter variations - - Uses real telemetry data - - Validates parameter modification and kernel generation - - Stops before pairing (no valid GCP pairs available) - -2. DOWNSTREAM Testing (run_downstream_pipeline): - - Tests GCP pairing, image matching, and error statistics - - Uses pre-geolocated test images with known GCP pairs - - Validates spatial pairing, image matching algorithms, and error metrics - - Skips kernel/geolocation (uses pre-computed test data) - -Both tests share the same CLARREO configuration base but configure differently -for their specific testing needs. - -Running Tests: -------------- -# Via pytest (recommended) -pytest tests/test_correction/test_correction.py -v - -# Run specific test -pytest tests/test_correction/test_correction.py::test_generate_clarreo_config_json -v - -# Standalone execution with arguments (for pipeline runs) -python tests/test_correction/test_correction.py --mode downstream --quick - -Requirements: ------------------ -These tests validate the complete Correction geolocation pipeline, -demonstrating parameter sensitivity analysis and error statistics -computation for mission requirements validation. 
- -""" - -import argparse -import atexit -import json -import logging -import shutil -import sys -import tempfile -import time -import unittest -from pathlib import Path - -import numpy as np -import pandas as pd -import pytest -import xarray as xr -from scipy.io import loadmat - -from curryer import meta, utils -from curryer import spicierpy as sp -from curryer.compute import constants -from curryer.correction import correction -from curryer.correction.config import DataConfig -from curryer.correction.data_structures import ( - ImageGrid, - PSFSamplingConfig, - SearchConfig, -) -from curryer.correction.image_match import ( - integrated_image_match, - load_image_grid_from_mat, - load_los_vectors_from_mat, - load_optical_psf_from_mat, -) -from curryer.correction.pairing import find_l1a_gcp_pairs -from curryer.kernels import create - -# Import CLARREO config and data loaders -sys.path.insert(0, str(Path(__file__).parent)) -from clarreo_config import create_clarreo_correction_config -from clarreo_data_loaders import load_clarreo_science, load_clarreo_telemetry - -logger = logging.getLogger(__name__) -utils.enable_logging(log_level=logging.INFO, extra_loggers=[__name__]) - - -# ============================================================================= -# TEST PLACEHOLDER FUNCTIONS (For Synthetic Data Generation) -# ============================================================================= -# These functions generate SYNTHETIC test data for testing the Correction pipeline - - -class _PlaceholderConfig: - """Configuration for placeholder test data generation.""" - - base_error_m: float = 50.0 - param_error_scale: float = 10.0 - max_measurements: int = 100 - min_measurements: int = 10 - orbit_radius_mean_m: float = 6.78e6 - orbit_radius_std_m: float = 4e3 - latitude_range: tuple[float, float] = (-60.0, 60.0) - longitude_range: tuple[float, float] = (-180.0, 180.0) - altitude_range: tuple[float, float] = (0.0, 1000.0) - max_off_nadir_rad: float = 0.1 - - -def 
synthetic_gcp_pairing(science_data_files): - """Generate SYNTHETIC GCP pairs for testing (TEST ONLY - not a test itself).""" - logger.warning("=" * 80) - logger.warning("!!!!️ USING SYNTHETIC GCP PAIRING - FAKE DATA! !!!!️") - logger.warning("=" * 80) - synthetic_pairs = [(f"{sci_file}", f"landsat_gcp_{i:03d}.tif") for i, sci_file in enumerate(science_data_files)] - return synthetic_pairs - - -def synthetic_image_matching( - geolocated_data, - gcp_reference_file, - telemetry, - calibration_dir, - params_info, - config, - los_vectors_cached=None, - optical_psfs_cached=None, -): - """ - Generate SYNTHETIC image matching error data (TEST ONLY - not a test itself). - - This function matches the signature of the real image_matching() function - but only uses a subset of parameters for upstream testing of the Correction. - - Used parameters: - geolocated_data: For generating realistic synthetic errors - params_info: To scale errors based on parameter variations - config: For coordinate names and placeholder configuration - - Ignored parameters (accepted for compatibility): - gcp_reference_file: Not needed for synthetic data - telemetry: Not needed for synthetic data - calibration_dir: Not needed for synthetic data - los_vectors_cached: Not needed for synthetic data - optical_psfs_cached: Not needed for synthetic data - """ - logger.warning("=" * 80) - logger.warning("!!!!️ USING SYNTHETIC IMAGE MATCHING - FAKE DATA! 
!!!!️") - logger.warning("=" * 80) - - placeholder_cfg = ( - config.placeholder if hasattr(config, "placeholder") and config.placeholder else _PlaceholderConfig() - ) - sc_pos_name = getattr(config, "spacecraft_position_name", "sc_position") - boresight_name = getattr(config, "boresight_name", "boresight") - transform_name = getattr(config, "transformation_matrix_name", "t_inst2ref") - - valid_mask = ~np.isnan(geolocated_data["latitude"].values).any(axis=1) - n_valid = valid_mask.sum() - n_measurements = ( - placeholder_cfg.min_measurements if n_valid == 0 else min(n_valid, placeholder_cfg.max_measurements) - ) - - # Generate realistic synthetic data - riss_ctrs = _generate_spherical_positions( - n_measurements, placeholder_cfg.orbit_radius_mean_m, placeholder_cfg.orbit_radius_std_m - ) - boresights = _generate_synthetic_boresights(n_measurements, placeholder_cfg.max_off_nadir_rad) - t_matrices = _generate_nadir_aligned_transforms(n_measurements, riss_ctrs, boresights) - - # Generate synthetic errors - base_error = placeholder_cfg.base_error_m - param_contribution = ( - sum(abs(p) if isinstance(p, int | float) else np.linalg.norm(p) for _, p in params_info) - * placeholder_cfg.param_error_scale - ) - error_magnitude = base_error + param_contribution - lat_errors = np.random.normal(0, error_magnitude / 111000, n_measurements) - lon_errors = np.random.normal(0, error_magnitude / 111000, n_measurements) - - if n_valid > 0: - valid_indices = np.where(valid_mask)[0][:n_measurements] - gcp_lat = geolocated_data["latitude"].values[valid_indices, 0] - gcp_lon = geolocated_data["longitude"].values[valid_indices, 0] - else: - gcp_lat = np.random.uniform(*placeholder_cfg.latitude_range, n_measurements) - gcp_lon = np.random.uniform(*placeholder_cfg.longitude_range, n_measurements) - - gcp_alt = np.random.uniform(*placeholder_cfg.altitude_range, n_measurements) - - return xr.Dataset( - { - "lat_error_deg": (["measurement"], lat_errors), - "lon_error_deg": (["measurement"], 
lon_errors), - sc_pos_name: (["measurement", "xyz"], riss_ctrs), - boresight_name: (["measurement", "xyz"], boresights), - transform_name: (["measurement", "xyz_from", "xyz_to"], t_matrices), - "gcp_lat_deg": (["measurement"], gcp_lat), - "gcp_lon_deg": (["measurement"], gcp_lon), - "gcp_alt": (["measurement"], gcp_alt), - }, - coords={ - "measurement": range(n_measurements), - "xyz": ["x", "y", "z"], - "xyz_from": ["x", "y", "z"], - "xyz_to": ["x", "y", "z"], - }, - ) - - -def _generate_synthetic_boresights(n_measurements, max_off_nadir_rad=0.07): - """Generate synthetic boresight vectors (test helper).""" - boresights = np.zeros((n_measurements, 3)) - for i in range(n_measurements): - theta = np.random.uniform(-max_off_nadir_rad, max_off_nadir_rad) - boresights[i] = [0.0, np.sin(theta), np.cos(theta)] - return boresights - - -def _generate_spherical_positions(n_measurements, radius_mean_m, radius_std_m): - """Generate synthetic spacecraft positions on sphere (test helper).""" - positions = np.zeros((n_measurements, 3)) - for i in range(n_measurements): - radius = np.random.normal(radius_mean_m, radius_std_m) - phi = np.random.uniform(0, 2 * np.pi) - cos_theta = np.random.uniform(-1, 1) - sin_theta = np.sqrt(1 - cos_theta**2) - positions[i] = [radius * sin_theta * np.cos(phi), radius * sin_theta * np.sin(phi), radius * cos_theta] - return positions - - -def _generate_nadir_aligned_transforms(n_measurements, riss_ctrs, boresights_hs): - """Generate transformation matrices aligning boresights with nadir (test helper).""" - t_matrices = np.zeros((n_measurements, 3, 3)) - for i in range(n_measurements): - nadir_ctrs = -riss_ctrs[i] / np.linalg.norm(riss_ctrs[i]) - bhat_hs_norm = boresights_hs[i] / np.linalg.norm(boresights_hs[i]) - rotation_axis = np.cross(bhat_hs_norm, nadir_ctrs) - axis_norm = np.linalg.norm(rotation_axis) - - if axis_norm < 1e-6: - if np.dot(bhat_hs_norm, nadir_ctrs) > 0: - t_matrices[i] = np.eye(3) - else: - perp = np.array([1, 0, 0]) if 
abs(bhat_hs_norm[0]) < 0.9 else np.array([0, 1, 0]) - rotation_axis = np.cross(bhat_hs_norm, perp) / np.linalg.norm(np.cross(bhat_hs_norm, perp)) - K = np.array( - [ - [0, -rotation_axis[2], rotation_axis[1]], - [rotation_axis[2], 0, -rotation_axis[0]], - [-rotation_axis[1], rotation_axis[0], 0], - ] - ) - t_matrices[i] = np.eye(3) + 2 * K @ K - else: - rotation_axis = rotation_axis / axis_norm - angle = np.arccos(np.clip(np.dot(bhat_hs_norm, nadir_ctrs), -1.0, 1.0)) - K = np.array( - [ - [0, -rotation_axis[2], rotation_axis[1]], - [rotation_axis[2], 0, -rotation_axis[0]], - [-rotation_axis[1], rotation_axis[0], 0], - ] - ) - t_matrices[i] = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K) - - return t_matrices - - -# ============================================================================= -# TEST MODE FUNCTIONS (Extracted from correction.py) -# ============================================================================= -# These functions were moved from the core correction module to keep test-specific -# code separate from mission-agnostic core functionality. - - -def discover_test_image_match_cases(test_data_dir: Path, test_cases: list[str] | None = None) -> list[dict]: - """ - Discover available image matching test cases. - - This function scans the test data directory for validated image matching - test cases and returns metadata about available test files. 
- - Args: - test_data_dir: Root directory for test data (tests/data/clarreo/image_match/) - test_cases: Specific test cases to use (e.g., ['1', '2']) or None for all - - Returns: - List of test case dictionaries with file paths and metadata - """ - logger.info(f"Discovering image matching test cases in: {test_data_dir}") - - # Shared calibration files (same for all test cases) - los_file = test_data_dir / "b_HS.mat" - psf_file_unbinned = test_data_dir / "optical_PSF_675nm_upsampled.mat" - psf_file_binned = test_data_dir / "optical_PSF_675nm_3_pix_binned_upsampled.mat" - - if not los_file.exists(): - raise FileNotFoundError(f"LOS vectors file not found: {los_file}") - if not psf_file_unbinned.exists(): - raise FileNotFoundError(f"Optical PSF file not found: {psf_file_unbinned}") - - # Test case metadata (from test_image_match.py) - test_case_metadata = { - "1": { - "name": "Dili", - "gcp_file": "GCP12055Dili_resampled.mat", - "ancil_file": "R_ISS_midframe_TestCase1.mat", - "expected_error_km": (3.0, -3.0), # (lat, lon) - "cases": [ - {"subimage": "TestCase1a_subimage.mat", "binned": False}, - {"subimage": "TestCase1b_subimage.mat", "binned": False}, - {"subimage": "TestCase1c_subimage_binned.mat", "binned": True}, - {"subimage": "TestCase1d_subimage_binned.mat", "binned": True}, - ], - }, - "2": { - "name": "Maracaibo", - "gcp_file": "GCP10121Maracaibo_resampled.mat", - "ancil_file": "R_ISS_midframe_TestCase2.mat", - "expected_error_km": (-3.0, 2.0), - "cases": [ - {"subimage": "TestCase2a_subimage.mat", "binned": False}, - {"subimage": "TestCase2b_subimage.mat", "binned": False}, - {"subimage": "TestCase2c_subimage_binned.mat", "binned": True}, - ], - }, - "3": { - "name": "Algeria3", - "gcp_file": "GCP10181Algeria3_resampled.mat", - "ancil_file": "R_ISS_midframe_TestCase3.mat", - "expected_error_km": (2.0, 3.0), - "cases": [ - {"subimage": "TestCase3a_subimage.mat", "binned": False}, - {"subimage": "TestCase3b_subimage_binned.mat", "binned": True}, - ], - }, - 
"4": { - "name": "Dunhuang", - "gcp_file": "GCP10142Dunhuang_resampled.mat", - "ancil_file": "R_ISS_midframe_TestCase4.mat", - "expected_error_km": (-2.0, -3.0), - "cases": [ - {"subimage": "TestCase4a_subimage.mat", "binned": False}, - {"subimage": "TestCase4b_subimage_binned.mat", "binned": True}, - ], - }, - "5": { - "name": "Algeria5", - "gcp_file": "GCP10071Algeria5_resampled.mat", - "ancil_file": "R_ISS_midframe_TestCase5.mat", - "expected_error_km": (1.0, -1.0), - "cases": [ - {"subimage": "TestCase5a_subimage.mat", "binned": False}, - ], - }, - } - - # Filter to requested test cases - if test_cases is None: - test_cases = sorted(test_case_metadata.keys()) - - discovered_cases = [] - - for case_id in test_cases: - if case_id not in test_case_metadata: - logger.warning(f"Test case '{case_id}' not found in metadata, skipping") - continue - - metadata = test_case_metadata[case_id] - case_dir = test_data_dir / case_id - - if not case_dir.is_dir(): - logger.warning(f"Test case directory not found: {case_dir}, skipping") - continue - - # Add each subcase variant (a, b, c, d) - for subcase in metadata["cases"]: - subimage_file = case_dir / subcase["subimage"] - gcp_file = case_dir / metadata["gcp_file"] - ancil_file = case_dir / metadata["ancil_file"] - psf_file = psf_file_binned if subcase["binned"] else psf_file_unbinned - - # Validate all files exist - if not subimage_file.exists(): - logger.warning(f"Subimage file not found: {subimage_file}, skipping") - continue - if not gcp_file.exists(): - logger.warning(f"GCP file not found: {gcp_file}, skipping") - continue - if not ancil_file.exists(): - logger.warning(f"Ancillary file not found: {ancil_file}, skipping") - continue - - discovered_cases.append( - { - "case_id": case_id, - "case_name": metadata["name"], - "subcase_name": subcase["subimage"], - "subimage_file": subimage_file, - "gcp_file": gcp_file, - "ancil_file": ancil_file, - "los_file": los_file, - "psf_file": psf_file, - "expected_lat_error_km": 
metadata["expected_error_km"][0], - "expected_lon_error_km": metadata["expected_error_km"][1], - "binned": subcase["binned"], - } - ) - - logger.info(f"Discovered {len(discovered_cases)} test case variants from {len(test_cases)} test case groups") - for case in discovered_cases: - logger.info( - f" - {case['case_id']}/{case['subcase_name']}: {case['case_name']}, " - f"expected error=({case['expected_lat_error_km']:.1f}, {case['expected_lon_error_km']:.1f}) km" - ) - - return discovered_cases - - -def apply_error_variation_for_testing( - base_result: xr.Dataset, param_idx: int, error_variation_percent: float = 3.0 -) -> xr.Dataset: - """ - Apply random variation to image matching results to simulate parameter effects. - - This is used in test mode to simulate how different parameter values would - affect geolocation errors, without actually re-running image matching. - - Args: - base_result: Original image matching result - param_idx: Parameter set index (used as random seed) - error_variation_percent: Percentage variation to apply (e.g., 3.0 = ±3%) - - Returns: - New Dataset with varied error values - """ - # Create copy - output = base_result.copy(deep=True) - - # Set reproducible random seed based on param_idx - np.random.seed(param_idx) - - # Generate variation factors (centered at 1.0, with specified percentage variation) - variation_fraction = error_variation_percent / 100.0 - lat_factor = 1.0 + np.random.normal(0, variation_fraction) - lon_factor = 1.0 + np.random.normal(0, variation_fraction) - ccv_factor = 1.0 + np.random.normal(0, variation_fraction / 10.0) # Smaller variation for correlation - - # Apply variations to error values - original_lat_km = base_result.attrs["lat_error_km"] - original_lon_km = base_result.attrs["lon_error_km"] - original_ccv = base_result.attrs["correlation_ccv"] - - varied_lat_km = original_lat_km * lat_factor - varied_lon_km = original_lon_km * lon_factor - varied_ccv = np.clip(original_ccv * ccv_factor, 0.0, 1.0) # Keep 
correlation in valid range - - # Update dataset values - # Get GCP center latitude from multiple possible sources - if "gcp_center_lat" in base_result.attrs: - gcp_center_lat = base_result.attrs["gcp_center_lat"] - elif "gcp_lat_deg" in base_result: - gcp_center_lat = float(base_result["gcp_lat_deg"].values[0]) - else: - # Fallback to a reasonable default (mid-latitude) - logger.warning("GCP center latitude not found in dataset, using default 45.0°") - gcp_center_lat = 45.0 - - lat_error_deg = varied_lat_km / 111.0 - lon_radius_km = 6378.0 * np.cos(np.deg2rad(gcp_center_lat)) - lon_error_deg = varied_lon_km / (lon_radius_km * np.pi / 180.0) - - output["lat_error_deg"].values[0] = lat_error_deg - output["lon_error_deg"].values[0] = lon_error_deg - - # Update attributes - output.attrs["lat_error_km"] = varied_lat_km - output.attrs["lon_error_km"] = varied_lon_km - output.attrs["correlation_ccv"] = varied_ccv - output.attrs["param_idx"] = param_idx - output.attrs["variation_applied"] = True - output.attrs["variation_lat_factor"] = lat_factor - output.attrs["variation_lon_factor"] = lon_factor - - logger.info( - f" Applied variation: lat {original_lat_km:.3f} → {varied_lat_km:.3f} km ({(lat_factor - 1) * 100:+.1f}%), " - f"lon {original_lon_km:.3f} → {varied_lon_km:.3f} km ({(lon_factor - 1) * 100:+.1f}%)" - ) - - return output - - -# ============================================================================= -# SHARED UTILITY FUNCTIONS -# ============================================================================= - - -def apply_geolocation_error_to_subimage( - subimage: ImageGrid, gcp: ImageGrid, lat_error_km: float, lon_error_km: float -) -> ImageGrid: - """ - Apply artificial geolocation error to a subimage for testing. - - This creates a misaligned subimage that the image matching algorithm - should detect and measure. 
- """ - mid_lat = float(gcp.lat[gcp.lat.shape[0] // 2, gcp.lat.shape[1] // 2]) - - # WGS84 Earth radius - use semi-major axis for latitude/longitude conversions - earth_radius_km = constants.WGS84_SEMI_MAJOR_AXIS_KM - lat_offset_deg = lat_error_km / earth_radius_km * (180.0 / np.pi) - lon_radius_km = earth_radius_km * np.cos(np.deg2rad(mid_lat)) - lon_offset_deg = lon_error_km / lon_radius_km * (180.0 / np.pi) - - return ImageGrid( - data=subimage.data.copy(), - lat=subimage.lat + lat_offset_deg, - lon=subimage.lon + lon_offset_deg, - h=subimage.h.copy() if subimage.h is not None else None, - ) - - -def run_image_matching_with_applied_errors( - test_case: dict, - param_idx: int, - randomize_errors: bool = True, - error_variation_percent: float = 3.0, - cache_results: bool = True, - cached_result: xr.Dataset | None = None, -) -> xr.Dataset: - """ - Run image matching with artificial errors applied to test data. - - This loads test data, applies known geolocation errors, then runs - image matching to verify it can detect those errors. 

    Args:
        test_case: Test case dictionary with file paths and expected errors
        param_idx: Parameter set index (for variation seed)
        randomize_errors: Whether to apply random variations
        error_variation_percent: Percentage variation (default 3.0%)
        cache_results: Whether to use cached results with variation
        cached_result: Previously cached result to vary
    """
    # Use cached result with variation if available
    if cached_result is not None and cache_results and param_idx > 0:
        if randomize_errors:
            logger.info(f" Applying ±{error_variation_percent}% variation to cached result")
            return apply_error_variation_for_testing(cached_result, param_idx, error_variation_percent)
        else:
            return cached_result.copy()

    logger.info(f" Running image matching with applied errors: {test_case['case_name']}")
    start_time = time.time()

    # Load subimage
    # NOTE(review): struct_as_record=False exposes the MATLAB struct fields as
    # Python attributes (.data/.lat/.lon/.h) — assumed schema of the .mat fixture.
    subimage_struct = loadmat(test_case["subimage_file"], squeeze_me=True, struct_as_record=False)["subimage"]
    subimage = ImageGrid(
        data=np.asarray(subimage_struct.data),
        lat=np.asarray(subimage_struct.lat),
        lon=np.asarray(subimage_struct.lon),
        h=np.asarray(subimage_struct.h) if hasattr(subimage_struct, "h") else None,
    )

    # Load GCP
    gcp = load_image_grid_from_mat(test_case["gcp_file"], key="GCP")
    gcp_center_lat = float(gcp.lat[gcp.lat.shape[0] // 2, gcp.lat.shape[1] // 2])
    gcp_center_lon = float(gcp.lon[gcp.lon.shape[0] // 2, gcp.lon.shape[1] // 2])

    # Apply expected error
    expected_lat_error = test_case["expected_lat_error_km"]
    expected_lon_error = test_case["expected_lon_error_km"]

    subimage_with_error = apply_geolocation_error_to_subimage(subimage, gcp, expected_lat_error, expected_lon_error)

    # Load calibration data
    los_vectors = load_los_vectors_from_mat(test_case["los_file"])
    optical_psfs = load_optical_psf_from_mat(test_case["psf_file"])
    ancil_data = loadmat(test_case["ancil_file"], squeeze_me=True)
    r_iss_midframe = ancil_data["R_ISS_midframe"].ravel()

    # Run image matching
    result = integrated_image_match(
        subimage=subimage_with_error,
        gcp=gcp,
        r_iss_midframe_m=r_iss_midframe,
        los_vectors_hs=los_vectors,
        optical_psfs=optical_psfs,
        geolocation_config=PSFSamplingConfig(),
        search_config=SearchConfig(),
    )

    processing_time = time.time() - start_time

    # Convert to dataset format
    # NOTE(review): 111 km/deg and 6378 km are rough spherical conversions, kept
    # intentionally consistent with apply_error_variation_for_testing above.
    lat_error_deg = result.lat_error_km / 111.0
    lon_radius_km = 6378.0 * np.cos(np.deg2rad(gcp_center_lat))
    lon_error_deg = result.lon_error_km / (lon_radius_km * np.pi / 180.0)

    # Fixed HySICS→CTRS transform and boresight constants used to populate the
    # geometry variables of the output dataset (not derived from this match).
    t_matrix = np.array(
        [
            [-0.418977524967338, 0.748005379751721, 0.514728846515064],
            [-0.421890284446342, 0.341604851993858, -0.839830169131854],
            [-0.804031356019172, -0.569029065124742, 0.172451447025628],
        ]
    )
    boresight = np.array([0.0, 0.0625969755450201, 0.99803888634292])

    # Single-measurement dataset mirroring the layout produced by the real
    # image-matching pipeline.
    output = xr.Dataset(
        {
            "lat_error_deg": (["measurement"], [lat_error_deg]),
            "lon_error_deg": (["measurement"], [lon_error_deg]),
            "riss_ctrs": (["measurement", "xyz"], [r_iss_midframe]),
            "bhat_hs": (["measurement", "xyz"], [boresight]),
            "t_hs2ctrs": (["measurement", "xyz_from", "xyz_to"], t_matrix[np.newaxis, :, :]),
            "gcp_lat_deg": (["measurement"], [gcp_center_lat]),
            "gcp_lon_deg": (["measurement"], [gcp_center_lon]),
            "gcp_alt": (["measurement"], [0.0]),
        },
        coords={"measurement": [0], "xyz": ["x", "y", "z"], "xyz_from": ["x", "y", "z"], "xyz_to": ["x", "y", "z"]},
    )

    output.attrs.update(
        {
            "lat_error_km": result.lat_error_km,
            "lon_error_km": result.lon_error_km,
            "correlation_ccv": result.ccv_final,
            "final_grid_step_m": result.final_grid_step_m,
            "processing_time_s": processing_time,
            "test_mode": True,
            "param_idx": param_idx,
            "gcp_center_lat": gcp_center_lat,
            "gcp_center_lon": gcp_center_lon,
        }
    )

    return output


# =============================================================================
# CONFIGURATION GENERATION TEST
# =============================================================================
def test_generate_clarreo_config_json(tmp_path):
    """Generate CLARREO config JSON and validate structure.

    This test generates the canonical CLARREO configuration JSON file
    that is used by all other CLARREO tests. The generated JSON is
    saved to configs/ and can be committed for version control.

    This ensures:
    - Single source of truth for CLARREO configuration
    - Programmatic config matches JSON config
    - JSON structure is valid and complete
    """

    logger.info("=" * 80)
    logger.info("TEST: Generate CLARREO Configuration JSON")
    logger.info("=" * 80)

    # Define paths
    data_dir = Path(__file__).parent.parent / "data/clarreo/gcs"
    generic_dir = Path("data/generic")
    output_path = tmp_path / "configs/clarreo_correction_config.json"

    logger.info(f"Data directory: {data_dir}")
    logger.info(f"Generic kernels: {generic_dir}")
    logger.info(f"Output path: {output_path}")

    # Generate config programmatically (also writes the JSON to output_path)
    logger.info("\n1. Generating config programmatically...")
    config = create_clarreo_correction_config(data_dir, generic_dir, config_output_path=output_path)

    logger.info(f"✓ Config created: {len(config.parameters)} parameter groups, {config.n_iterations} iterations")

    # Validate the generated JSON exists
    logger.info("\n2. Validating generated JSON file...")
    assert output_path.exists(), f"Config JSON not created: {output_path}"
    logger.info(f"✓ JSON file exists: {output_path}")

    # Reload and verify structure
    logger.info("\n3. Reloading and validating JSON structure...")
    with open(output_path) as f:
        config_data = json.load(f)

    # Validate top-level sections
    assert "mission_config" in config_data, "Missing 'mission_config' section"
    assert "correction" in config_data, "Missing 'correction' section"
    assert "geolocation" in config_data, "Missing 'geolocation' section"
    logger.info("✓ All required top-level sections present")

    # Validate mission config
    mission_cfg = config_data["mission_config"]
    assert mission_cfg["mission_name"] == "CLARREO_Pathfinder"
    assert "kernel_mappings" in mission_cfg
    logger.info(f"✓ Mission: {mission_cfg['mission_name']}")

    # Validate correction config
    corr_cfg = config_data["correction"]
    assert "parameters" in corr_cfg
    assert isinstance(corr_cfg["parameters"], list)
    assert len(corr_cfg["parameters"]) > 0
    assert "seed" in corr_cfg
    assert "n_iterations" in corr_cfg

    # NEW: Validate required fields are present
    assert "earth_radius_m" in corr_cfg, "Missing 'earth_radius_m' in correction config"
    assert "performance_threshold_m" in corr_cfg, "Missing 'performance_threshold_m'"
    assert "performance_spec_percent" in corr_cfg, "Missing 'performance_spec_percent'"

    # Expected CLARREO constants (radius in meters, threshold in meters, spec in percent)
    assert corr_cfg["earth_radius_m"] == 6378140.0
    assert corr_cfg["performance_threshold_m"] == 250.0
    assert corr_cfg["performance_spec_percent"] == 39.0

    logger.info(f"✓ Correction config: {len(corr_cfg['parameters'])} parameters, {corr_cfg['n_iterations']} iterations")
    logger.info(
        f"✓ Required fields: earth_radius={corr_cfg['earth_radius_m']}, "
        f"threshold={corr_cfg['performance_threshold_m']}m, "
        f"spec={corr_cfg['performance_spec_percent']}%"
    )

    # Validate geolocation config
    geo_cfg = config_data["geolocation"]
    assert "meta_kernel_file" in geo_cfg
    assert "instrument_name" in geo_cfg
    assert geo_cfg["instrument_name"] == "CPRS_HYSICS"
    logger.info(f"✓ Geolocation config: instrument={geo_cfg['instrument_name']}")

    # Test that JSON can be loaded back into CorrectionConfig
    logger.info("\n4. Testing JSON → CorrectionConfig loading...")
    reloaded_config = correction.load_config_from_json(output_path)
    assert reloaded_config.n_iterations == config.n_iterations
    assert len(reloaded_config.parameters) == len(config.parameters)
    assert reloaded_config.earth_radius_m == 6378140.0
    assert reloaded_config.performance_threshold_m == 250.0
    assert reloaded_config.performance_spec_percent == 39.0
    logger.info("✓ JSON successfully loads into CorrectionConfig")

    # Validate reloaded config
    logger.info("\n5. Validating reloaded config...")
    reloaded_config.validate()
    logger.info("✓ Reloaded config passes validation")

    logger.info("\n" + "=" * 80)
    logger.info("✓ CONFIG GENERATION TEST PASSED")
    logger.info(f"✓ Canonical config saved: {output_path}")
    logger.info(f"✓ File size: {output_path.stat().st_size / 1024:.1f} KB")
    logger.info("=" * 80)

    # Note: Test functions should not return values per pytest best practices
    # All validation is performed via assert statements above


# =============================================================================
# UPSTREAM TESTING (Kernel Creation + Geolocation)
# =============================================================================


def run_upstream_pipeline(n_iterations: int = 5, work_dir: Path | None = None) -> tuple[list, dict, Path]:
    """
    Test UPSTREAM segment of Correction pipeline.

    This tests:
    - Parameter set generation
    - Kernel creation from parameters
    - Geolocation with varied parameters

    This does NOT test:
    - GCP pairing (no valid pairs available)
    - Image matching (no valid data)
    - Error statistics (no matched data)

    This is focused on testing the upstream kernel creation
    and geolocation part of the pipeline.

    Args:
        n_iterations: Number of Correction iterations
        work_dir: Working directory for outputs.
            If None, uses a temporary
            directory that will be cleaned up when the process exits.

    Returns:
        Tuple of (results, netcdf_data, output_path)
    """
    logger.info("=" * 80)
    logger.info("UPSTREAM PIPELINE TEST")
    logger.info("Tests: Kernel Creation + Geolocation")
    logger.info("=" * 80)

    root_dir = Path(__file__).parents[2]
    generic_dir = root_dir / "data" / "generic"
    data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs"

    # Use temporary directory if work_dir not specified
    if work_dir is None:
        _tmp_dir = tempfile.mkdtemp(prefix="curryer_upstream_")
        work_dir = Path(_tmp_dir)

        # Register cleanup to run on process exit
        # (atexit rather than immediate cleanup so callers can inspect outputs)
        def cleanup_temp_dir():
            if work_dir.exists():
                try:
                    shutil.rmtree(work_dir)
                    logger.debug(f"Cleaned up temporary directory: {work_dir}")
                except Exception as e:
                    logger.warning(f"Failed to cleanup {work_dir}: {e}")

        atexit.register(cleanup_temp_dir)

        logger.info(f"Using temporary directory: {work_dir}")
        logger.info("(will be cleaned up on process exit)")
    else:
        work_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Work directory: {work_dir}")
    logger.info(f"Iterations: {n_iterations}")

    # Create configuration using CLARREO config
    config = create_clarreo_correction_config(data_dir, generic_dir)
    config.n_iterations = n_iterations
    # Set output filename for test (consistent name for version control)
    config.output_filename = "upstream_results.nc"

    # Preprocess CLARREO raw CSVs into clean files the pipeline can load directly.
    # This replaces the old telemetry_loader / science_loader callables.
    logger.info("Preprocessing CLARREO telemetry and science data...")
    tlm_df = load_clarreo_telemetry(data_dir)
    sci_df_gps = load_clarreo_science(data_dir)  # GPS seconds; pipeline scales to uGPS

    tlm_csv = work_dir / "clarreo_telemetry.csv"
    sci_csv = work_dir / "clarreo_science.csv"
    tlm_df.to_csv(tlm_csv)
    sci_df_gps.to_csv(sci_csv)
    logger.info(f" Telemetry → {tlm_csv}")
    logger.info(f" Science → {sci_csv}")

    # Point the pipeline at the clean files via DataConfig
    config.data = DataConfig(
        file_format="csv",
        time_scale_factor=1e6,  # GPS sec → uGPS
    )
    # Use synthetic image matching for this upstream-only test
    config.image_matching_func = synthetic_image_matching

    logger.info(f"Configuration loaded:")
    logger.info(f" Mission: CLARREO Pathfinder")
    logger.info(f" Instrument: {config.geo.instrument_name}")
    logger.info(f" Parameters: {len(config.parameters)}")
    logger.info(f" Iterations: {n_iterations}")
    logger.info(f" Data loading: config-driven (DataConfig)")
    logger.info(f" Image matching: synthetic (upstream test)")

    # Each tuple is (telemetry_csv_path, science_csv_path, gcp_file_path).
    # The gcp_file path is passed straight to image_matching_func; synthetic_image_matching ignores it.
    tlm_sci_gcp_sets = [
        (str(tlm_csv), str(sci_csv), "synthetic_gcp.mat"),
    ]

    logger.info(f"Data sets: {len(tlm_sci_gcp_sets)} (preprocessed CSVs, synthetic GCP)")

    # Execute the Correction loop - all config comes from config object!
    # This will test parameter generation, kernel creation, and geolocation
    logger.info("=" * 80)
    logger.info("EXECUTING CORRECTION UPSTREAM WORKFLOW")
    logger.info("=" * 80)

    results, netcdf_data = correction.loop(config, work_dir, tlm_sci_gcp_sets)

    logger.info("=" * 80)
    logger.info("UPSTREAM PIPELINE COMPLETE")
    logger.info(f"Processed {len(results)} total iterations")
    logger.info(f"Generated results for {len(netcdf_data['parameter_set_id'])} parameter sets")
    logger.info("=" * 80)

    # Output file is determined by config and saved by loop()
    output_file = work_dir / config.get_output_filename()
    logger.info(f"Results saved to: {output_file}")

    # Create results summary dict for consistency with downstream
    results_summary = {
        "mode": "upstream",
        "iterations": n_iterations,
        "parameter_sets": len(netcdf_data["parameter_set_id"]),
        "status": "complete",
    }

    return results, results_summary, output_file


# =============================================================================
# DOWNSTREAM TESTING (Pairing + Image Matching + Error Statistics)
# =============================================================================


def run_downstream_pipeline(
    n_iterations: int = 5, test_cases: list[str] | None = None, work_dir: Path | None = None
) -> tuple[list, dict, Path]:
    """
    Test DOWNSTREAM segment of Correction pipeline.

    IMPORTANT: This test uses a CUSTOM LOOP (not correction.loop()) because it works with
    pre-geolocated test data that doesn't have the telemetry/parameters needed for
    the normal upstream pipeline.

    Pipeline Comparison:
        Normal correction.loop(): Parameters → Kernels → Geolocation → Matching → Stats
        This test: Pre-geolocated Test Data → Pairing → Matching → Stats

    Parameter effects are simulated by varying the geolocation errors directly
    (bumping lat/lon values), rather than varying parameters and re-running SPICE.
    This is the correct approach for testing with pre-geolocated imagery!

    Tests (Real):
    - GCP spatial pairing algorithms
    - Image matching with real correlation
    - Error statistics computation

    Does NOT Test (No Data Available):
    - Kernel creation (no telemetry)
    - SPICE geolocation (test data is pre-geolocated)
    - True parameter sensitivity (simulated via error variation)

    Args:
        n_iterations: Number of Correction iterations
        test_cases: Specific test cases to use (e.g., ['1', '2'])
        work_dir: Working directory for outputs. If None, uses a temporary
            directory that will be cleaned up when the process exits.

    Returns:
        Tuple of (results, netcdf_data, output_path)
    """
    logger.info("=" * 80)
    logger.info("DOWNSTREAM PIPELINE TEST")
    logger.info("Tests: GCP Pairing + Image Matching + Error Statistics")
    logger.info("=" * 80)

    root_dir = Path(__file__).parents[2]
    generic_dir = root_dir / "data" / "generic"
    data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs"
    test_data_dir = root_dir / "tests" / "data" / "clarreo" / "image_match"

    # Use temporary directory if work_dir not specified
    if work_dir is None:
        _tmp_dir = tempfile.mkdtemp(prefix="curryer_downstream_")
        work_dir = Path(_tmp_dir)

        # Register cleanup to run on process exit
        def cleanup_temp_dir():
            if work_dir.exists():
                try:
                    shutil.rmtree(work_dir)
                    logger.debug(f"Cleaned up temporary directory: {work_dir}")
                except Exception as e:
                    logger.warning(f"Failed to cleanup {work_dir}: {e}")

        atexit.register(cleanup_temp_dir)

        logger.info(f"Using temporary directory: {work_dir}")
        logger.info("(will be cleaned up on process exit)")
    else:
        work_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Work directory: {work_dir}")
    logger.info(f"Test data directory: {test_data_dir}")
    logger.info(f"Iterations: {n_iterations}")
    logger.info(f"Test cases: {test_cases or 'all'}")

    # Test configuration (simple parameters - no config object needed)
    randomize_errors = True
    error_variation_percent = 3.0
    cache_results = True

    # Discover test cases
    discovered_cases = discover_test_image_match_cases(test_data_dir, test_cases)
    logger.info(f"Discovered {len(discovered_cases)} test case variants")

    # ==========================================================================
    # STEP 1: GCP PAIRING
    # ==========================================================================
    logger.info("\n" + "=" * 80)
    logger.info("STEP 1: GCP SPATIAL PAIRING")
    logger.info("=" * 80)

    # Load L1A images
    l1a_images = []
    l1a_to_testcase = {}

    for test_case in discovered_cases:
        l1a_img = load_image_grid_from_mat(
            test_case["subimage_file"],
            key="subimage",
            as_named=True,
            name=str(test_case["subimage_file"].relative_to(test_data_dir)),
        )
        l1a_images.append(l1a_img)
        l1a_to_testcase[l1a_img.name] = test_case

    # Load unique GCP references (several test cases may share one GCP file)
    gcp_files_seen = set()
    gcp_images = []

    for test_case in discovered_cases:
        gcp_file = test_case["gcp_file"]
        if gcp_file not in gcp_files_seen:
            gcp_img = load_image_grid_from_mat(
                gcp_file, key="GCP", as_named=True, name=str(gcp_file.relative_to(test_data_dir))
            )
            gcp_images.append(gcp_img)
            gcp_files_seen.add(gcp_file)

    logger.info(f"Loaded {len(l1a_images)} L1A images")
    logger.info(f"Loaded {len(gcp_images)} unique GCP references")

    # Run spatial pairing
    pairing_result = find_l1a_gcp_pairs(l1a_images, gcp_images, max_distance_m=0.0)
    logger.info(f"Pairing complete: Found {len(pairing_result.matches)} valid pairs")

    for match in pairing_result.matches:
        l1a_name = pairing_result.l1a_images[match.l1a_index].name
        gcp_name = pairing_result.gcp_images[match.gcp_index].name
        logger.info(f" {l1a_name} ↔ {gcp_name} (distance: {match.distance_m:.1f}m)")

    # Convert to paired test cases
    paired_test_cases = []
    for match in pairing_result.matches:
        l1a_img = pairing_result.l1a_images[match.l1a_index]
        test_case = l1a_to_testcase[l1a_img.name]
        paired_test_cases.append(test_case)

    logger.info(f"Created {len(paired_test_cases)} paired test cases")

    # ==========================================================================
    # STEP 2: CONFIGURATION
    # ==========================================================================
    logger.info("\n" + "=" * 80)
    logger.info("STEP 2: CONFIGURATION")
    logger.info("=" * 80)

    # Create base CLARREO config to get standard settings
    base_config = create_clarreo_correction_config(data_dir, generic_dir)

    # Create minimal test config (CorrectionConfig = THE one config)
    # For downstream testing, we use minimal parameters (sigma=0) because
    # variations come from test_mode_config randomization, not parameter tweaking
    config = correction.CorrectionConfig(
        # Core settings
        seed=42,
        n_iterations=n_iterations,
        # Minimal parameter (no real variation - sigma=0)
        parameters=[
            correction.ParameterConfig(
                ptype=correction.ParameterType.CONSTANT_KERNEL,
                config_file=data_dir / "cprs_hysics_v01.attitude.ck.json",
                data={
                    "current_value": [0.0, 0.0, 0.0],
                    "sigma": 0.0,  # No parameter variation (test variations applied differently)
                    "units": "arcseconds",
                    "transformation_type": "dcm_rotation",
                    "coordinate_frames": ["HYSICS_SLIT", "CRADLE_ELEVATION"],
                },
            )
        ],
        # Copy required fields from base_config
        geo=base_config.geo,
        performance_threshold_m=base_config.performance_threshold_m,
        performance_spec_percent=base_config.performance_spec_percent,
        earth_radius_m=base_config.earth_radius_m,
        # Copy optional fields from base_config
        netcdf=base_config.netcdf,
        calibration_file_names=base_config.calibration_file_names,
        spacecraft_position_name=base_config.spacecraft_position_name,
        boresight_name=base_config.boresight_name,
        transformation_matrix_name=base_config.transformation_matrix_name,
    )

    # Add DataConfig for file-based loading (loaders no longer needed)
    config.data = DataConfig(
        file_format="csv",
        time_scale_factor=1e6,
    )

    # Validate complete config
    config.validate()

    logger.info(f"Configuration created:")
    logger.info(f" Mission: CLARREO (from clarreo_config)")
    logger.info(f" Instrument: {config.geo.instrument_name}")
    logger.info(f" Iterations: {config.n_iterations}")
    logger.info(f" Parameters: {len(config.parameters)} (minimal for test mode)")
    logger.info(f" Sigma: 0.0 (variations from randomization, not parameters)")
    logger.info(f" Performance threshold: {config.performance_threshold_m}m")

    # ==========================================================================
    # STEP 3: CORRECTION ITERATIONS
    # ==========================================================================
    logger.info("\n" + "=" * 80)
    logger.info("STEP 3: CORRECTION ITERATIONS")
    logger.info("=" * 80)

    n_param_sets = n_iterations
    n_gcp_pairs = len(paired_test_cases)

    # Use dynamic NetCDF structure builder instead of hardcoded
    netcdf_data = correction._build_netcdf_structure(config, n_param_sets, n_gcp_pairs)
    logger.info(f"NetCDF structure built dynamically with {len(netcdf_data)} variables")

    # Get threshold metric name dynamically
    threshold_metric = config.netcdf.get_threshold_metric_name()
    logger.info(f"Using threshold metric: {threshold_metric}")

    image_match_cache = {}

    for param_idx in range(n_iterations):
        logger.info(f"\n=== Iteration {param_idx + 1}/{n_iterations} ===")

        image_matching_results = []
        pair_errors = []

        for pair_idx, test_case in enumerate(paired_test_cases):
            logger.info(f"Processing pair {pair_idx + 1}/{len(paired_test_cases)}: {test_case['case_name']}")

            cache_key = f"{test_case['case_id']}_{test_case.get('subcase_name', '')}"
            cached_result = image_match_cache.get(cache_key)

            # Run image matching (cached result reused with variation after iteration 0)
            image_matching_output = run_image_matching_with_applied_errors(
                test_case,
                param_idx,
                randomize_errors=randomize_errors,
                error_variation_percent=error_variation_percent,
                cache_results=cache_results,
                cached_result=cached_result,
            )

            # Cache first result
            if cache_key not in image_match_cache:
                image_match_cache[cache_key] = image_matching_output

            image_matching_output.attrs["gcp_pair_index"] = pair_idx
            image_matching_results.append(image_matching_output)

            # Extract and store metrics
            lat_error_m = abs(image_matching_output.attrs["lat_error_km"] * 1000)
            lon_error_m = abs(image_matching_output.attrs["lon_error_km"] * 1000)
            rms_error_m = np.sqrt(lat_error_m**2 + lon_error_m**2)
            pair_errors.append(rms_error_m)

            netcdf_data["rms_error_m"][param_idx, pair_idx] = rms_error_m
            netcdf_data["mean_error_m"][param_idx, pair_idx] = rms_error_m
            netcdf_data["max_error_m"][param_idx, pair_idx] = rms_error_m
            netcdf_data["n_measurements"][param_idx, pair_idx] = 1
            netcdf_data["im_lat_error_km"][param_idx, pair_idx] = image_matching_output.attrs["lat_error_km"]
            netcdf_data["im_lon_error_km"][param_idx, pair_idx] = image_matching_output.attrs["lon_error_km"]
            netcdf_data["im_ccv"][param_idx, pair_idx] = image_matching_output.attrs["correlation_ccv"]
            netcdf_data["im_grid_step_m"][param_idx, pair_idx] = image_matching_output.attrs["final_grid_step_m"]

        # Compute aggregate metrics (use dynamic threshold)
        pair_errors = np.array(pair_errors)
        valid_errors = pair_errors[~np.isnan(pair_errors)]

        if len(valid_errors) > 0:
            threshold_value = config.performance_threshold_m
            percent_under_threshold = (valid_errors < threshold_value).sum() / len(valid_errors) * 100
            netcdf_data[threshold_metric][param_idx] = percent_under_threshold
            netcdf_data["mean_rms_all_pairs"][param_idx] = np.mean(valid_errors)
            netcdf_data["best_pair_rms"][param_idx] = np.min(valid_errors)
            netcdf_data["worst_pair_rms"][param_idx] = np.max(valid_errors)

            logger.info(f"Iteration {param_idx + 1} complete:")
            logger.info(f" {percent_under_threshold:.1f}% under {threshold_value}m threshold")
            logger.info(f" Mean RMS: {np.mean(valid_errors):.2f}m")

    # ==========================================================================
    # STEP 4: ERROR STATISTICS
    # ==========================================================================
    logger.info("\n" + "=" * 80)
    logger.info("STEP 4: ERROR STATISTICS")
    logger.info("=" * 80)

    # NOTE(review): uses image_matching_results from the LAST iteration only —
    # confirm that is the intended input for the aggregate statistics.
    error_stats = correction.call_error_stats_module(image_matching_results, correction_config=config)
    logger.info(f"Error statistics computed: {len(error_stats)} metrics")

    # ==========================================================================
    # STEP 5: SAVE RESULTS
    # ==========================================================================
    logger.info("\n" + "=" * 80)
    logger.info("STEP 5: SAVE RESULTS")
    logger.info("=" * 80)

    output_file = work_dir / "downstream_results.nc"
    correction._save_netcdf_results(netcdf_data, output_file, config)

    logger.info("=" * 80)
    logger.info("DOWNSTREAM TEST COMPLETE")
    logger.info("=" * 80)
    logger.info(f"Output: {output_file}")
    logger.info(f"Iterations: {n_iterations}")
    logger.info(f"Test pairs: {n_gcp_pairs}")

    results = {"mode": "downstream", "iterations": n_iterations, "test_pairs": n_gcp_pairs, "status": "complete"}

    return [], results, output_file


# =============================================================================
# UNITTEST TEST CASES
# =============================================================================


class CorrectionUnifiedTests(unittest.TestCase):
    """Unified test cases for both upstream and downstream pipelines.

    Note: Only test_upstream_quick requires GMTED elevation data and is marked
    with @pytest.mark.extra. All other tests use either config-only validation
    or pre-geolocated test data and will run in CI without GMTED files.
    """

    def setUp(self):
        """Set up test environment."""
        self.root_dir = Path(__file__).parents[2]
        self.test_data_dir = self.root_dir / "tests" / "data" / "clarreo" / "image_match"

        # Per-test temp dir; addCleanup guarantees removal even on failure
        self.__tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.__tmp_dir.cleanup)
        self.work_dir = Path(self.__tmp_dir.name)

    def test_upstream_configuration(self):
        """Test that upstream configuration loads correctly."""
        logger.info("Testing upstream configuration...")

        data_dir = self.root_dir / "tests" / "data" / "clarreo" / "gcs"
        generic_dir = self.root_dir / "data" / "generic"

        config = create_clarreo_correction_config(data_dir, generic_dir)

        # Attach DataConfig so the pipeline knows how to load files
        config.data = DataConfig(
            file_format="csv",
            time_scale_factor=1e6,
        )

        # Validate config is complete (no loader args needed any more)
        config.validate()

        self.assertEqual(config.geo.instrument_name, "CPRS_HYSICS")
        self.assertGreater(len(config.parameters), 0)
        self.assertEqual(config.seed, 42)
        self.assertIsNotNone(config.data)

        logger.info(f"✓ Configuration valid: {len(config.parameters)} parameters")

    def test_downstream_test_case_discovery(self):
        """Test that downstream test cases can be discovered."""
        logger.info("Testing test case discovery...")

        test_cases = discover_test_image_match_cases(self.test_data_dir)

        self.assertGreater(len(test_cases), 0, "No test cases discovered")

        for tc in test_cases:
            self.assertIn("case_id", tc)
            self.assertIn("subimage_file", tc)
            self.assertIn("gcp_file", tc)
            self.assertTrue(tc["subimage_file"].exists())
            self.assertTrue(tc["gcp_file"].exists())

        logger.info(f"✓ Discovered {len(test_cases)} test cases")

    def test_downstream_image_matching(self):
        """Test that downstream image matching works."""
        logger.info("Testing image matching...")

        test_cases = discover_test_image_match_cases(self.test_data_dir, test_cases=["1"])
        self.assertGreater(len(test_cases), 0)

        test_case = test_cases[0]

        result = run_image_matching_with_applied_errors(
            test_case, param_idx=0, randomize_errors=False, cache_results=True
        )

        self.assertIsInstance(result, xr.Dataset)
        self.assertIn("lat_error_km", result.attrs)
        self.assertIn("lon_error_km", result.attrs)

        logger.info(f"✓ Image matching successful")

    @pytest.mark.extra
    def test_upstream_quick(self):
        """Run quick upstream test.

        This test requires GMTED elevation data which is not available in CI.
        Run with: pytest --run-extra
        """
        logger.info("Running quick upstream test...")

        results_list, results_dict, output_file = run_upstream_pipeline(n_iterations=2, work_dir=self.work_dir)

        self.assertEqual(results_dict["status"], "complete")
        self.assertEqual(results_dict["iterations"], 2)
        self.assertEqual(results_dict["mode"], "upstream")
        self.assertGreater(results_dict["parameter_sets"], 0)
        self.assertTrue(output_file.exists())

        logger.info(f"✓ Quick upstream test complete: {output_file}")

    def test_downstream_quick(self):
        """Run quick downstream test."""
        logger.info("Running quick downstream test...")

        results_list, results_dict, output_file = run_downstream_pipeline(
            n_iterations=2, test_cases=["1"], work_dir=self.work_dir
        )

        self.assertEqual(results_dict["status"], "complete")
        self.assertEqual(results_dict["iterations"], 2)
        self.assertTrue(output_file.exists())

        logger.info(f"✓ Quick downstream test complete: {output_file}")

    def test_synthetic_helpers_basic(self):
        """Test that synthetic helper functions work correctly (for coverage)."""
        logger.info("Testing synthetic helper functions...")

        # Test synthetic GCP pairing
        science_files = ["science_1.nc", "science_2.nc"]
        pairs = synthetic_gcp_pairing(science_files)
        self.assertEqual(len(pairs), 2)
        self.assertIsInstance(pairs, list)
        logger.info("✓ synthetic_gcp_pairing works")

        # Test synthetic boresights generation
        boresights = _generate_synthetic_boresights(5, max_off_nadir_rad=0.07)
        self.assertEqual(boresights.shape, (5, 3))
        self.assertTrue(np.all(np.abs(boresights[:, 0]) < 0.01))  # Small x component
        logger.info("✓ _generate_synthetic_boresights works")

        # Test synthetic positions generation
        positions = _generate_spherical_positions(5, 6.78e6, 4e3)
        self.assertEqual(positions.shape, (5, 3))
        radii = np.linalg.norm(positions, axis=1)
        self.assertTrue(np.all(radii > 6.7e6))  # Reasonable orbit altitude
        logger.info("✓ _generate_spherical_positions works")

        # Test transform generation
        transforms = _generate_nadir_aligned_transforms(5, positions, boresights)
        self.assertEqual(transforms.shape, (5, 3, 3))
        # Check it's a valid rotation matrix (det should be close to 1)
        det = np.linalg.det(transforms[0])
        self.assertAlmostEqual(abs(det), 1.0, places=1)
        logger.info("✓ _generate_nadir_aligned_transforms works")

        logger.info("✓ All synthetic helpers validated")

    def test_downstream_helpers_basic(self):
        """Test downstream helper functions (for coverage)."""
        logger.info("Testing downstream helper functions...")

        # Test test case discovery
        test_cases = discover_test_image_match_cases(self.test_data_dir, test_cases=["1"])
        self.assertGreater(len(test_cases), 0)
        self.assertIn("case_id", test_cases[0])
        self.assertIn("subimage_file", test_cases[0])
        logger.info(f"✓ discover_test_image_match_cases found {len(test_cases)} cases")

        # Test error variation (create a simple test dataset)
        base_result = xr.Dataset(
            {
                "lat_error_deg": (["measurement"], [0.001]),
                "lon_error_deg": (["measurement"], [0.002]),
            },
            attrs={"lat_error_km": 0.1, "lon_error_km": 0.2, "correlation_ccv": 0.95},
        )

        varied_result = apply_error_variation_for_testing(base_result, param_idx=1, error_variation_percent=3.0)
        self.assertIsInstance(varied_result, xr.Dataset)
        self.assertIn("lat_error_deg", varied_result)
        # Check that variation was applied (should be different from base)
        self.assertNotEqual(varied_result.attrs["lat_error_km"], base_result.attrs["lat_error_km"])
        logger.info("✓ apply_error_variation_for_testing works")

        logger.info("✓ All downstream helpers validated")

    @pytest.mark.extra
    def test_loop_optimized(self):
        """
        Test loop() function (optimized pair-outer implementation).

        This test validates the main loop() function which is now the optimized
        default implementation. It covers the core Correction workflow.
        """
        logger.info("=" * 80)
        logger.info("TEST: loop() (OPTIMIZED IMPLEMENTATION)")
        logger.info("=" * 80)

        # Setup configuration
        root_dir = Path(__file__).parents[2]
        generic_dir = root_dir / "data" / "generic"
        data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs"

        config = create_clarreo_correction_config(data_dir, generic_dir)
        config.n_iterations = 2  # Small for fast testing
        config.output_filename = "test_loop_optimized.nc"

        work_dir = self.work_dir / "test_loop_optimized"
        work_dir.mkdir(exist_ok=True)

        # Preprocess raw CSVs into clean files
        tlm_df = load_clarreo_telemetry(data_dir)
        sci_df = load_clarreo_science(data_dir)
        tlm_csv = work_dir / "tlm.csv"
        sci_csv = work_dir / "sci.csv"
        tlm_df.to_csv(tlm_csv)
        sci_df.to_csv(sci_csv)

        # Attach DataConfig and synthetic image matching override
        config.data = DataConfig(
            file_format="csv",
            time_scale_factor=1e6,
        )
        config.image_matching_func = synthetic_image_matching

        # Each tuple: (telemetry_csv, science_csv, gcp_path)
        # synthetic_image_matching ignores gcp_path
        tlm_sci_gcp_sets = [
            (str(tlm_csv), str(sci_csv), "synthetic_gcp.mat"),
        ]

        # Run loop()
        logger.info("Running loop()...")
        np.random.seed(42)
        results, netcdf_data = correction.loop(config, work_dir, tlm_sci_gcp_sets, resume_from_checkpoint=False)

        # Validate results structure
        self.assertIsInstance(results, list)
        self.assertGreater(len(results), 0)
        expected_count = config.n_iterations * len(tlm_sci_gcp_sets)
        self.assertEqual(len(results), expected_count)
        logger.info(f"✓ loop() returned {len(results)} results")

        # Validate NetCDF structure
        self.assertIsInstance(netcdf_data, dict)
        self.assertIn("rms_error_m", netcdf_data)
        self.assertIn("parameter_set_id", netcdf_data)
        self.assertEqual(netcdf_data["rms_error_m"].shape, (config.n_iterations, len(tlm_sci_gcp_sets)))
        logger.info(f"✓ NetCDF structure valid")

        # Validate result contents
        for result in results:
            self.assertIn("param_index", result)
            self.assertIn("pair_index", result)
            self.assertIn("rms_error_m", result)
            self.assertIn("aggregate_rms_error_m", result)
            # Verify aggregate_rms_error_m is populated (not None)
            self.assertIsNotNone(
                result["aggregate_rms_error_m"],
                f"aggregate_rms_error_m should be populated for result {result['iteration']}",
            )
            # Verify it's a valid numeric value
            self.assertIsInstance(result["aggregate_rms_error_m"], (int, float, np.number))
        logger.info(f"✓ All result entries have required fields")

        logger.info("=" * 80)
        logger.info("✓ loop() TEST PASSED")
        logger.info("=" * 80)

    def test_helper_extract_parameter_values(self):
        """Test _extract_parameter_values helper function."""
        logger.info("Testing _extract_parameter_values...")

        # Create sample params with proper structure for CONSTANT_KERNEL type
        param_config = correction.ParameterConfig(
            ptype=correction.ParameterType.CONSTANT_KERNEL, config_file=Path("test_kernel.json"), data=None
        )
        # Create DataFrame with angle data as expected by the function
        param_data = pd.DataFrame(
            {
                "angle_x": [np.radians(1.0 / 3600)],  # 1 arcsec in radians
                "angle_y": [np.radians(2.0 / 3600)],  # 2 arcsec in radians
                "angle_z": [np.radians(3.0 / 3600)],  # 3 arcsec in radians
            }
        )
        params = [(param_config, param_data)]

        result = correction._extract_parameter_values(params)

        self.assertIsInstance(result, dict)
        # Should extract 3 values: roll, pitch, yaw
self.assertEqual(len(result), 3) - self.assertIn("test_kernel_roll", result) - self.assertIn("test_kernel_pitch", result) - self.assertIn("test_kernel_yaw", result) - logger.info(f"✓ _extract_parameter_values works correctly") - - def test_helper_build_netcdf_structure(self): - """Test _build_netcdf_structure helper function.""" - logger.info("Testing _build_netcdf_structure...") - - # Create minimal config - root_dir = Path(__file__).parents[2] - generic_dir = root_dir / "data" / "generic" - data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" - config = create_clarreo_correction_config(data_dir, generic_dir) - - n_params = 3 - n_pairs = 2 - - netcdf_data = correction._build_netcdf_structure(config, n_params, n_pairs) - - # Validate structure - self.assertIsInstance(netcdf_data, dict) - self.assertIn("rms_error_m", netcdf_data) - self.assertIn("parameter_set_id", netcdf_data) - self.assertEqual(netcdf_data["rms_error_m"].shape, (n_params, n_pairs)) - self.assertEqual(len(netcdf_data["parameter_set_id"]), n_params) - logger.info(f"✓ _build_netcdf_structure creates correct structure") - - def test_helper_extract_error_metrics(self): - """Test _extract_error_metrics helper function.""" - logger.info("Testing _extract_error_metrics...") - - # Create sample error stats dataset with correct attribute name - stats_dataset = xr.Dataset( - { - "measurement": (["point"], [0, 1, 2]), - "lat_error_deg": (["point"], [0.001, 0.002, 0.001]), - "lon_error_deg": (["point"], [0.001, 0.002, 0.001]), - } - ) - stats_dataset.attrs["rms_error_m"] = 150.0 - stats_dataset.attrs["mean_error_m"] = 140.0 - stats_dataset.attrs["max_error_m"] = 200.0 - stats_dataset.attrs["std_error_m"] = 10.0 - stats_dataset.attrs["total_measurements"] = 3 # Correct attribute name - - metrics = correction._extract_error_metrics(stats_dataset) - - # Validate metrics - self.assertIsInstance(metrics, dict) - self.assertIn("rms_error_m", metrics) - self.assertIn("mean_error_m", metrics) - 
self.assertIn("n_measurements", metrics) - self.assertEqual(metrics["n_measurements"], 3) - self.assertEqual(metrics["rms_error_m"], 150.0) - logger.info(f"✓ _extract_error_metrics extracts metrics correctly") - - def test_helper_store_parameter_values(self): - """Test _store_parameter_values helper function.""" - logger.info("Testing _store_parameter_values...") - - # Create netcdf structure with parameter arrays pre-created - # (as _build_netcdf_structure would do) - netcdf_data = { - "parameter_set_id": np.zeros(3, dtype=int), - "param_test_param": np.zeros(3), # Must match naming convention - } - - param_values = {"test_param": 1.5} - param_idx = 1 - - correction._store_parameter_values(netcdf_data, param_idx, param_values) - - # Validate storage - self.assertEqual(netcdf_data["param_test_param"][param_idx], 1.5) - logger.info(f"✓ _store_parameter_values stores correctly") - - def test_helper_store_gcp_pair_results(self): - """Test _store_gcp_pair_results helper function.""" - logger.info("Testing _store_gcp_pair_results...") - - # Create netcdf structure with all required fields - netcdf_data = { - "rms_error_m": np.zeros((2, 2)), - "mean_error_m": np.zeros((2, 2)), - "max_error_m": np.zeros((2, 2)), - "std_error_m": np.zeros((2, 2)), # Must include this - "n_measurements": np.zeros((2, 2), dtype=int), - } - - error_metrics = { - "rms_error_m": 150.0, - "mean_error_m": 140.0, - "max_error_m": 200.0, - "std_error_m": 10.0, # Must include this - "n_measurements": 10, - } - - param_idx = 0 - pair_idx = 1 - - correction._store_gcp_pair_results(netcdf_data, param_idx, pair_idx, error_metrics) - - # Validate storage - self.assertEqual(netcdf_data["rms_error_m"][param_idx, pair_idx], 150.0) - self.assertEqual(netcdf_data["std_error_m"][param_idx, pair_idx], 10.0) - self.assertEqual(netcdf_data["n_measurements"][param_idx, pair_idx], 10) - logger.info(f"✓ _store_gcp_pair_results stores correctly") - - def test_helper_compute_parameter_set_metrics(self): - """Test 
_compute_parameter_set_metrics helper function.""" - logger.info("Testing _compute_parameter_set_metrics...") - - # Create netcdf structure - netcdf_data = { - "percent_under_250m": np.zeros(2), - "mean_rms_all_pairs": np.zeros(2), - "best_pair_rms": np.zeros(2), - "worst_pair_rms": np.zeros(2), - } - - pair_errors = [100.0, 200.0, 300.0] - param_idx = 0 - threshold_m = 250.0 - - correction._compute_parameter_set_metrics(netcdf_data, param_idx, pair_errors, threshold_m) - - # Validate computed metrics - self.assertGreater(netcdf_data["percent_under_250m"][param_idx], 0) - self.assertGreater(netcdf_data["mean_rms_all_pairs"][param_idx], 0) - self.assertEqual(netcdf_data["best_pair_rms"][param_idx], 100.0) - self.assertEqual(netcdf_data["worst_pair_rms"][param_idx], 300.0) - logger.info(f"✓ _compute_parameter_set_metrics computes correctly") - - def test_helper_load_image_pair_data(self): - """Test _load_image_pair_data helper function.""" - logger.info("Testing _load_image_pair_data...") - - # Setup configuration - root_dir = Path(__file__).parents[2] - generic_dir = root_dir / "data" / "generic" - data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" - config = create_clarreo_correction_config(data_dir, generic_dir) - - # Preprocess raw CLARREO data → clean CSVs - tlm_csv = self.work_dir / "tlm_pair.csv" - sci_csv = self.work_dir / "sci_pair.csv" - load_clarreo_telemetry(data_dir).to_csv(tlm_csv) - load_clarreo_science(data_dir).to_csv(sci_csv) - - config.data = DataConfig( - file_format="csv", - time_scale_factor=1e6, - ) - - tlm_dataset, sci_dataset, ugps_times = correction._load_image_pair_data(str(tlm_csv), str(sci_csv), config) - - # Validate return types - self.assertIsInstance(tlm_dataset, pd.DataFrame) - self.assertIsInstance(sci_dataset, pd.DataFrame) - self.assertIsNotNone(ugps_times) - logger.info(f"✓ _load_image_pair_data loads data correctly") - - def test_helper_create_dynamic_kernels(self): - """Test _create_dynamic_kernels helper function.""" - 
logger.info("Testing _create_dynamic_kernels...") - - # Setup - root_dir = Path(__file__).parents[2] - generic_dir = root_dir / "data" / "generic" - data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" - config = create_clarreo_correction_config(data_dir, generic_dir) - - work_dir = self.work_dir / "test_dynamic_kernels" - work_dir.mkdir(exist_ok=True) - - # Load data (preprocess CLARREO raw CSVs) - tlm_dataset = load_clarreo_telemetry(data_dir) - creator = create.KernelCreator(overwrite=True, append=False) - - # Load SPICE kernels needed for kernel creation - # (frame kernel defines ISS_SC body which is needed by ephemeris writer) - mkrn = meta.MetaKernel.from_json( - config.geo.meta_kernel_file, - relative=True, - sds_dir=config.geo.generic_kernel_dir, - ) - with sp.ext.load_kernel([mkrn.sds_kernels, mkrn.mission_kernels]): - dynamic_kernels = correction._create_dynamic_kernels(config, work_dir, tlm_dataset, creator) - - # Validate - self.assertIsInstance(dynamic_kernels, list) - logger.info(f"✓ _create_dynamic_kernels creates {len(dynamic_kernels)} kernels") - - def test_helper_load_calibration_data(self): - """Test _load_calibration_data helper function.""" - logger.info("Testing _load_calibration_data...") - - # Setup minimal config without calibration dir - root_dir = Path(__file__).parents[2] - generic_dir = root_dir / "data" / "generic" - data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" - config = create_clarreo_correction_config(data_dir, generic_dir) - config.calibration_dir = None - - # Load LOS vectors and PSF data into calibration data. 
- calibration_data = correction._load_calibration_data(config) - - # Should return CalibrationData with None values when no calibration dir - self.assertIsNone(calibration_data.los_vectors) - self.assertIsNone(calibration_data.optical_psfs) - logger.info(f"✓ _load_calibration_data handles None calibration_dir") - - def test_checkpoint_save_load(self): - """Test checkpoint save and load functionality.""" - logger.info("=" * 80) - logger.info("TEST: Checkpoint Save/Load") - logger.info("=" * 80) - - # Setup configuration - root_dir = Path(__file__).parents[2] - generic_dir = root_dir / "data" / "generic" - data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" - - config = create_clarreo_correction_config(data_dir, generic_dir) - config.n_iterations = 2 - config.output_filename = "test_checkpoint.nc" - - # Use DataConfig (loaders no longer needed on config) - config.image_matching_func = synthetic_image_matching - - work_dir = self.work_dir / "test_checkpoint" - work_dir.mkdir(exist_ok=True) - - output_file = work_dir / config.output_filename - - # Build simple netcdf structure for testing - netcdf_data = correction._build_netcdf_structure(config, 2, 2) - netcdf_data["rms_error_m"][0, 0] = 100.0 - netcdf_data["rms_error_m"][1, 0] = 150.0 - - # Save checkpoint - logger.info("Saving checkpoint...") - correction._save_netcdf_checkpoint(netcdf_data, output_file, config, pair_idx_completed=0) - - checkpoint_file = output_file.parent / f"{output_file.stem}_checkpoint.nc" - self.assertTrue(checkpoint_file.exists()) - logger.info(f"✓ Checkpoint file created: {checkpoint_file}") - - # Load checkpoint - logger.info("Loading checkpoint...") - loaded_data, completed_pairs = correction._load_checkpoint(output_file, config) - - self.assertIsNotNone(loaded_data) - self.assertEqual(completed_pairs, 1) # 0-indexed, so pair 0 completed = 1 - self.assertEqual(loaded_data["rms_error_m"][0, 0], 100.0) - self.assertEqual(loaded_data["rms_error_m"][1, 0], 150.0) - logger.info(f"✓ 
Checkpoint loaded correctly, completed pairs = {completed_pairs}") - - # Cleanup - correction._cleanup_checkpoint(output_file) - self.assertFalse(checkpoint_file.exists()) - logger.info(f"✓ Checkpoint cleanup successful") - - logger.info("=" * 80) - logger.info("✓ CHECKPOINT SAVE/LOAD TEST PASSED") - logger.info("=" * 80) - - def test_apply_offset_function(self): - """Test apply_offset function for all parameter types.""" - logger.info("=" * 80) - logger.info("TEST: apply_offset() Function") - logger.info("=" * 80) - - # ========== Test 1: OFFSET_KERNEL with arcseconds ========== - logger.info("\nTest 1: OFFSET_KERNEL with arcseconds unit conversion") - - # Create realistic telemetry data matching CLARREO structure - telemetry_data = pd.DataFrame( - { - "frame": range(5), - "hps.az_ang_nonlin": [1.14252] * 5, - "hps.el_ang_nonlin": [-0.55009] * 5, - "hps.resolver_tms": [1168477154.0 + i for i in range(5)], - "ert": [1431903180.58 + i for i in range(5)], - } - ) - - # Create parameter config for azimuth angle offset - az_param_config = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("cprs_az_v01.attitude.ck.json"), - data=dict( - field="hps.az_ang_nonlin", - units="arcseconds", - ), - ) - - # Apply offset of 100 arcseconds - offset_arcsec = 100.0 - original_mean = telemetry_data["hps.az_ang_nonlin"].mean() - modified_data = correction.apply_offset(az_param_config, offset_arcsec, telemetry_data) - - # Verify offset was applied correctly - expected_offset_rad = np.deg2rad(offset_arcsec / 3600.0) - actual_delta = modified_data["hps.az_ang_nonlin"].mean() - original_mean - self.assertAlmostEqual(actual_delta, expected_offset_rad, places=9) - self.assertIsInstance(modified_data, pd.DataFrame) - logger.info(f"✓ OFFSET_KERNEL (arcseconds): {offset_arcsec} arcsec = {expected_offset_rad:.9f} rad") - - # ========== Test 2: OFFSET_KERNEL with elevation angle ========== - logger.info("\nTest 2: OFFSET_KERNEL with elevation angle 
(negative offset)") - - el_param_config = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("cprs_el_v01.attitude.ck.json"), - data=dict( - field="hps.el_ang_nonlin", - units="arcseconds", - ), - ) - - # Apply negative offset - offset_arcsec = -50.0 - original_mean = telemetry_data["hps.el_ang_nonlin"].mean() - modified_data = correction.apply_offset(el_param_config, offset_arcsec, telemetry_data) - - # Verify - expected_offset_rad = np.deg2rad(offset_arcsec / 3600.0) - actual_delta = modified_data["hps.el_ang_nonlin"].mean() - original_mean - self.assertAlmostEqual(actual_delta, expected_offset_rad, places=9) - logger.info(f"✓ OFFSET_KERNEL (negative): {offset_arcsec} arcsec = {expected_offset_rad:.9f} rad") - - # ========== Test 3: OFFSET_KERNEL with non-existent field ========== - logger.info("\nTest 3: OFFSET_KERNEL with non-existent field (should warn)") - - bad_param_config = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("dummy.json"), - data=dict( - field="nonexistent_field", - units="arcseconds", - ), - ) - - # Should return unmodified data when field not found - modified_data = correction.apply_offset(bad_param_config, 10.0, telemetry_data) - self.assertIsInstance(modified_data, pd.DataFrame) - # Data should be unchanged - pd.testing.assert_frame_equal(modified_data, telemetry_data) - logger.info("✓ OFFSET_KERNEL correctly handles missing field") - - # ========== Test 4: OFFSET_TIME with milliseconds ========== - logger.info("\nTest 4: OFFSET_TIME with milliseconds unit conversion") - - science_data = pd.DataFrame( - { - "corrected_timestamp": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0], - "measurement": [1.0, 2.0, 3.0, 4.0, 5.0], - } - ) - - time_param_config = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="milliseconds", - ), - ) - - # Apply time 
offset - value must be in seconds - # This simulates the production flow where load_param_sets converts to seconds - offset_ms = 10.0 - offset_seconds = offset_ms / 1000.0 # Convert to internal units (seconds) - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(time_param_config, offset_seconds, science_data) - - # Verify offset was applied correctly (seconds -> microseconds) - expected_offset_us = offset_ms * 1000.0 # 10 ms = 10000 µs - actual_delta = modified_data["corrected_timestamp"].mean() - original_mean - self.assertAlmostEqual(actual_delta, expected_offset_us, places=6) - logger.info(f"✓ OFFSET_TIME: {offset_ms} ms = {expected_offset_us:.6f} µs") - - # ========== Test 5: OFFSET_TIME with negative offset ========== - logger.info("\nTest 5: OFFSET_TIME with negative offset") - - offset_ms = -5.5 - offset_seconds = offset_ms / 1000.0 # Convert to internal units (seconds) - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(time_param_config, offset_seconds, science_data) - - # Verify - expected_offset_us = offset_ms * 1000.0 - actual_delta = modified_data["corrected_timestamp"].mean() - original_mean - self.assertAlmostEqual(actual_delta, expected_offset_us, places=6) - logger.info(f"✓ OFFSET_TIME (negative): {offset_ms} ms = {expected_offset_us:.6f} µs") - - # ========== Test 6: CONSTANT_KERNEL (pass-through) ========== - logger.info("\nTest 6: CONSTANT_KERNEL (pass-through, no modification)") - - constant_kernel_data = pd.DataFrame( - { - "ugps": [1000000, 2000000], - "angle_x": [0.001, 0.001], - "angle_y": [0.002, 0.002], - "angle_z": [0.003, 0.003], - } - ) - - constant_param_config = correction.ParameterConfig( - ptype=correction.ParameterType.CONSTANT_KERNEL, - config_file=Path("cprs_base_v01.attitude.ck.json"), - data=dict( - field="cprs_base", - ), - ) - - # For CONSTANT_KERNEL, param_data is already the kernel data - modified_data = 
correction.apply_offset(constant_param_config, constant_kernel_data, pd.DataFrame()) - - # Should return the constant kernel data unchanged - self.assertIsInstance(modified_data, pd.DataFrame) - pd.testing.assert_frame_equal(modified_data, constant_kernel_data) - logger.info("✓ CONSTANT_KERNEL returns data unchanged") - - # ========== Test 7: OFFSET_KERNEL without units ========== - logger.info("\nTest 7: OFFSET_KERNEL without unit conversion") - - param_no_units = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("test.json"), - data=dict( - field="hps.az_ang_nonlin", - # No units specified - ), - ) - - # Apply offset directly (no conversion) - offset_value = 0.001 # radians - original_mean = telemetry_data["hps.az_ang_nonlin"].mean() - modified_data = correction.apply_offset(param_no_units, offset_value, telemetry_data) - - # Verify offset applied directly without conversion - actual_delta = modified_data["hps.az_ang_nonlin"].mean() - original_mean - self.assertAlmostEqual(actual_delta, offset_value, places=9) - logger.info(f"✓ OFFSET_KERNEL (no units): {offset_value} rad applied directly") - - # ========== Test 8: Data is not modified in place ========== - logger.info("\nTest 8: Original data is not modified in place") - - original_telemetry = telemetry_data.copy() - modified_data = correction.apply_offset(az_param_config, 100.0, telemetry_data) - - # Verify original data unchanged - pd.testing.assert_frame_equal(telemetry_data, original_telemetry) - # Verify modified data is different - self.assertFalse(modified_data["hps.az_ang_nonlin"].equals(original_telemetry["hps.az_ang_nonlin"])) - logger.info("✓ Original data not modified (proper copy made)") - - # ========== Test 9: Multiple columns preserved ========== - logger.info("\nTest 9: All DataFrame columns preserved after offset") - - modified_data = correction.apply_offset(az_param_config, 50.0, telemetry_data) - - # Verify all columns still present - 
self.assertEqual(set(modified_data.columns), set(telemetry_data.columns)) - # Verify only target column modified - self.assertTrue(modified_data["frame"].equals(telemetry_data["frame"])) - self.assertTrue(modified_data["ert"].equals(telemetry_data["ert"])) - self.assertFalse(modified_data["hps.az_ang_nonlin"].equals(telemetry_data["hps.az_ang_nonlin"])) - logger.info("✓ All DataFrame columns preserved, only target modified") - - logger.info("\n" + "=" * 80) - logger.info("✓ apply_offset() TEST PASSED") - logger.info(" - OFFSET_KERNEL with arcseconds conversion ✓") - logger.info(" - OFFSET_KERNEL with negative values ✓") - logger.info(" - OFFSET_KERNEL missing field handling ✓") - logger.info(" - OFFSET_KERNEL without units ✓") - logger.info(" - OFFSET_TIME with milliseconds conversion ✓") - logger.info(" - OFFSET_TIME with negative values ✓") - logger.info(" - CONSTANT_KERNEL pass-through ✓") - logger.info(" - Data not modified in place ✓") - logger.info(" - All columns preserved ✓") - logger.info("=" * 80) - - def test_helper_load_param_sets(self): - """Test load_param_sets function for parameter set generation.""" - logger.info("=" * 80) - logger.info("TEST: load_param_sets() Function") - logger.info("=" * 80) - - # Create minimal config with different parameter types - data_dir = self.root_dir / "tests" / "data" / "clarreo" / "gcs" - generic_dir = self.root_dir / "data" / "generic" - - # ========== Test 1: OFFSET_KERNEL parameter with arcseconds ========== - logger.info("\nTest 1: OFFSET_KERNEL parameter generation with arcseconds") - - offset_kernel_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("cprs_az_v01.attitude.ck.json"), - data=dict( - field="hps.az_ang_nonlin", - units="arcseconds", - current_value=0.0, # Starting at zero - sigma=100.0, # ±100 arcseconds standard deviation - bounds=[-200.0, 200.0], # ±200 arcseconds limits - ), - ) - - config = correction.CorrectionConfig( - seed=42, # For 
reproducibility - n_iterations=3, - parameters=[offset_kernel_param], - geo=correction.GeolocationConfig( - meta_kernel_file=data_dir / "meta_kernel.tm", - generic_kernel_dir=generic_dir, - dynamic_kernels=[], - instrument_name="CPRS_HYSICS", - time_field="corrected_timestamp", - ), - performance_threshold_m=250.0, - performance_spec_percent=39.0, - earth_radius_m=6378137.0, - ) - - param_sets = correction.load_param_sets(config) - - # Validate structure - self.assertEqual(len(param_sets), 3, "Should generate 3 parameter sets") - self.assertEqual(len(param_sets[0]), 1, "Each set should have 1 parameter") - - # Check each parameter set - for i, param_set in enumerate(param_sets): - param_config, param_value = param_set[0] - self.assertEqual(param_config.ptype, correction.ParameterType.OFFSET_KERNEL) - self.assertIsInstance(param_value, float, "OFFSET_KERNEL should produce float value") - # Value should be in radians (converted from arcseconds) - self.assertLess(abs(param_value), np.deg2rad(200.0 / 3600.0), "Value should be within bounds") - logger.info(f" Set {i}: {param_value:.9f} rad ({np.rad2deg(param_value) * 3600.0:.3f} arcsec)") - - logger.info("✓ OFFSET_KERNEL parameter generation works correctly") - - # ========== Test 2: OFFSET_TIME parameter with milliseconds ========== - logger.info("\nTest 2: OFFSET_TIME parameter generation with milliseconds") - - offset_time_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="milliseconds", - current_value=0.0, - sigma=10.0, # ±10 ms standard deviation - bounds=[-50.0, 50.0], # ±50 ms limits - ), - ) - - config.parameters = [offset_time_param] - config.n_iterations = 3 - - param_sets = correction.load_param_sets(config) - - self.assertEqual(len(param_sets), 3) - for i, param_set in enumerate(param_sets): - param_config, param_value = param_set[0] - self.assertEqual(param_config.ptype, 
correction.ParameterType.OFFSET_TIME) - self.assertIsInstance(param_value, float, "OFFSET_TIME should produce float value") - # Value should be in seconds (converted from milliseconds) - self.assertLess(abs(param_value), 0.050, "Value should be within bounds (50 ms = 0.050 s)") - logger.info(f" Set {i}: {param_value:.6f} s ({param_value * 1000.0:.3f} ms)") - - logger.info("✓ OFFSET_TIME parameter generation works correctly") - - # ========== Test 3: CONSTANT_KERNEL parameter with 3D angles ========== - logger.info("\nTest 3: CONSTANT_KERNEL parameter generation with 3D angles") - - constant_kernel_param = correction.ParameterConfig( - ptype=correction.ParameterType.CONSTANT_KERNEL, - config_file=data_dir / "cprs_base_v01.attitude.ck.json", - data=dict( - field="cprs_base", - units="arcseconds", - current_value=[0.0, 0.0, 0.0], # [roll, pitch, yaw] - sigma=50.0, # ±50 arcseconds for each axis - bounds=[-100.0, 100.0], # ±100 arcseconds limits - ), - ) - - config.parameters = [constant_kernel_param] - config.n_iterations = 2 - - param_sets = correction.load_param_sets(config) - - self.assertEqual(len(param_sets), 2) - for i, param_set in enumerate(param_sets): - param_config, param_value = param_set[0] - self.assertEqual(param_config.ptype, correction.ParameterType.CONSTANT_KERNEL) - self.assertIsInstance(param_value, pd.DataFrame, "CONSTANT_KERNEL should produce DataFrame") - self.assertIn("angle_x", param_value.columns) - self.assertIn("angle_y", param_value.columns) - self.assertIn("angle_z", param_value.columns) - self.assertIn("ugps", param_value.columns) - - # Check each angle is within bounds (in radians) - max_bound_rad = np.deg2rad(100.0 / 3600.0) - for angle_col in ["angle_x", "angle_y", "angle_z"]: - angle_val = param_value[angle_col].iloc[0] - self.assertLess(abs(angle_val), max_bound_rad, f"{angle_col} should be within bounds") - - logger.info( - f" Set {i}: roll={param_value['angle_x'].iloc[0]:.9f}, " - f"pitch={param_value['angle_y'].iloc[0]:.9f}, " - 
f"yaw={param_value['angle_z'].iloc[0]:.9f} rad" - ) - - logger.info("✓ CONSTANT_KERNEL parameter generation works correctly") - - # ========== Test 4: Multiple parameters together ========== - logger.info("\nTest 4: Multiple parameters in single config") - - config.parameters = [offset_kernel_param, offset_time_param, constant_kernel_param] - config.n_iterations = 2 - - param_sets = correction.load_param_sets(config) - - self.assertEqual(len(param_sets), 2, "Should generate 2 parameter sets") - self.assertEqual(len(param_sets[0]), 3, "Each set should have 3 parameters") - - # Verify each parameter type is present - for i, param_set in enumerate(param_sets): - types_found = [p[0].ptype for p in param_set] - self.assertIn(correction.ParameterType.OFFSET_KERNEL, types_found) - self.assertIn(correction.ParameterType.OFFSET_TIME, types_found) - self.assertIn(correction.ParameterType.CONSTANT_KERNEL, types_found) - logger.info(f" Set {i}: Contains all 3 parameter types ✓") - - logger.info("✓ Multiple parameters handled correctly") - - # ========== Test 5: Fixed parameter (sigma=0) ========== - logger.info("\nTest 5: Fixed parameter with sigma=0") - - fixed_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("fixed.json"), - data=dict( - field="fixed_field", - units="arcseconds", - current_value=25.0, # Fixed at 25 arcseconds - sigma=0.0, # No variation - bounds=[-100.0, 100.0], - ), - ) - - config.parameters = [fixed_param] - config.n_iterations = 3 - - param_sets = correction.load_param_sets(config) - - expected_value_rad = np.deg2rad(25.0 / 3600.0) - for i, param_set in enumerate(param_sets): - param_config, param_value = param_set[0] - self.assertAlmostEqual(param_value, expected_value_rad, places=12, msg="Fixed parameter should not vary") - logger.info(f" Set {i}: {param_value:.9f} rad (constant)") - - logger.info("✓ Fixed parameter (sigma=0) works correctly") - - # ========== Test 6: Seed reproducibility ========== 
- logger.info("\nTest 6: Random seed reproducibility") - - config.parameters = [offset_kernel_param] - config.n_iterations = 3 - config.seed = 123 - - param_sets_1 = correction.load_param_sets(config) - - # Reset and generate again with same seed - config.seed = 123 - param_sets_2 = correction.load_param_sets(config) - - # Should produce identical values - for i in range(len(param_sets_1)): - val_1 = param_sets_1[i][0][1] - val_2 = param_sets_2[i][0][1] - self.assertAlmostEqual(val_1, val_2, places=12, msg=f"Set {i} should be identical with same seed") - logger.info(f" Set {i}: {val_1:.9f} rad (reproducible)") - - logger.info("✓ Random seed reproducibility verified") - - # ========== Test 7: Parameter without sigma (should use current_value) ========== - logger.info("\nTest 7: Parameter without sigma field") - - no_sigma_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_KERNEL, - config_file=Path("no_sigma.json"), - data=dict( - field="test_field", - units="arcseconds", - current_value=15.0, - # No sigma specified - bounds=[-100.0, 100.0], - ), - ) - - config.parameters = [no_sigma_param] - config.n_iterations = 3 - - param_sets = correction.load_param_sets(config) - - expected_value_rad = np.deg2rad(15.0 / 3600.0) - for i, param_set in enumerate(param_sets): - param_config, param_value = param_set[0] - self.assertAlmostEqual( - param_value, expected_value_rad, places=12, msg="Parameter without sigma should use current_value" - ) - - logger.info("✓ Parameter without sigma uses current_value correctly") - - logger.info("\n" + "=" * 80) - logger.info("✓ load_param_sets() TEST PASSED") - logger.info(" - OFFSET_KERNEL generation ✓") - logger.info(" - OFFSET_TIME generation ✓") - logger.info(" - CONSTANT_KERNEL generation ✓") - logger.info(" - Multiple parameters ✓") - logger.info(" - Fixed parameters (sigma=0) ✓") - logger.info(" - Seed reproducibility ✓") - logger.info(" - Parameters without sigma ✓") - logger.info("=" * 80) - - def 
test_offset_time_unit_conversion_integration(self): - """Test the full integration of load_param_sets -> apply_offset for OFFSET_TIME with all unit types. - - This test verifies that: - 1. load_param_sets correctly converts milliseconds/microseconds -> seconds - 2. apply_offset correctly converts seconds -> microseconds for the timestamp field - 3. The end-to-end pipeline produces correct results - 4. All unit conversion paths are exercised (including uncovered lines) - """ - logger.info("=" * 80) - logger.info("TEST: OFFSET_TIME Unit Conversion Integration (load_param_sets -> apply_offset)") - logger.info("=" * 80) - - data_dir = self.root_dir / "tests" / "data" / "clarreo" / "gcs" - generic_dir = self.root_dir / "data" / "generic" - - # Test data with timestamps in microseconds (typical format) - science_data = pd.DataFrame( - { - "corrected_timestamp": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0], - "measurement": [1.0, 2.0, 3.0, 4.0, 5.0], - } - ) - - # ========== Test 1: Milliseconds with sigma ========== - logger.info("\nTest 1: OFFSET_TIME with milliseconds unit (with sigma)") - - offset_time_ms_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="milliseconds", - current_value=10.0, # 10 milliseconds - sigma=2.0, # ±2 ms variation - bounds=[-50.0, 50.0], # ±50 ms limits - ), - ) - - config = correction.CorrectionConfig( - seed=42, - n_iterations=1, - parameters=[offset_time_ms_param], - geo=correction.GeolocationConfig( - meta_kernel_file=data_dir / "meta_kernel.tm", - generic_kernel_dir=generic_dir, - dynamic_kernels=[], - instrument_name="CPRS_HYSICS", - time_field="corrected_timestamp", - ), - performance_threshold_m=250.0, - performance_spec_percent=39.0, - earth_radius_m=6378137.0, - ) - - # Step 1: load_param_sets converts milliseconds -> seconds - param_sets = correction.load_param_sets(config) - self.assertEqual(len(param_sets), 1) - 
param_config, param_value_seconds = param_sets[0][0] - - # Verify conversion to seconds - self.assertLess(abs(param_value_seconds), 0.050, "Value should be in seconds, within 50ms bound") - logger.info(f" load_param_sets output: {param_value_seconds:.6f} s = {param_value_seconds * 1000.0:.3f} ms") - - # Step 2: apply_offset converts seconds -> microseconds - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(param_config, param_value_seconds, science_data) - - # Verify the offset was applied correctly - actual_delta_us = modified_data["corrected_timestamp"].mean() - original_mean - expected_delta_us = param_value_seconds * 1000000.0 # seconds -> microseconds - self.assertAlmostEqual(actual_delta_us, expected_delta_us, places=3) - logger.info(f" Expected delta: {expected_delta_us:.3f} µs") - logger.info(f" Actual delta: {actual_delta_us:.3f} µs") - logger.info("✓ Milliseconds path works correctly (load_param_sets -> apply_offset)") - - # ========== Test 2: Microseconds with sigma (covers lines 1442-1446) ========== - logger.info("\nTest 2: OFFSET_TIME with microseconds unit (with sigma)") - - offset_time_us_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="microseconds", - current_value=5000.0, # 5000 microseconds = 5 ms - sigma=1000.0, # ±1000 µs variation - bounds=[-10000.0, 10000.0], # ±10000 µs limits - ), - ) - - config.parameters = [offset_time_us_param] - - # Step 1: load_param_sets converts microseconds -> seconds - param_sets = correction.load_param_sets(config) - param_config, param_value_seconds = param_sets[0][0] - - # Verify conversion to seconds - self.assertLess(abs(param_value_seconds), 0.010, "Value should be in seconds, within 10ms bound") - logger.info(f" load_param_sets output: {param_value_seconds:.6f} s = {param_value_seconds * 1000000.0:.1f} µs") - - # Step 2: apply_offset converts seconds 
-> microseconds - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(param_config, param_value_seconds, science_data) - - # Verify the offset was applied correctly - actual_delta_us = modified_data["corrected_timestamp"].mean() - original_mean - expected_delta_us = param_value_seconds * 1000000.0 - self.assertAlmostEqual(actual_delta_us, expected_delta_us, places=3) - logger.info(f" Expected delta: {expected_delta_us:.3f} µs") - logger.info(f" Actual delta: {actual_delta_us:.3f} µs") - logger.info("✓ Microseconds path works correctly (load_param_sets -> apply_offset)") - - # ========== Test 3: Seconds unit (baseline) ========== - logger.info("\nTest 3: OFFSET_TIME with seconds unit (baseline)") - - offset_time_s_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="seconds", - current_value=0.008, # 8 milliseconds - sigma=0.002, # ±2 ms variation - bounds=[-0.050, 0.050], # ±50 ms limits - ), - ) - - config.parameters = [offset_time_s_param] - - # Step 1: load_param_sets (no conversion needed, already in seconds) - param_sets = correction.load_param_sets(config) - param_config, param_value_seconds = param_sets[0][0] - - self.assertLess(abs(param_value_seconds), 0.050, "Value should be in seconds") - logger.info(f" load_param_sets output: {param_value_seconds:.6f} s") - - # Step 2: apply_offset converts seconds -> microseconds - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(param_config, param_value_seconds, science_data) - - # Verify the offset was applied correctly - actual_delta_us = modified_data["corrected_timestamp"].mean() - original_mean - expected_delta_us = param_value_seconds * 1000000.0 - self.assertAlmostEqual(actual_delta_us, expected_delta_us, places=3) - logger.info(f" Expected delta: {expected_delta_us:.3f} µs") - logger.info(f" Actual delta: 
{actual_delta_us:.3f} µs") - logger.info("✓ Seconds path works correctly (load_param_sets -> apply_offset)") - - # ========== Test 4: Fixed offset (sigma=0) with milliseconds ========== - logger.info("\nTest 4: OFFSET_TIME fixed offset (sigma=0) with milliseconds") - - fixed_time_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="milliseconds", - current_value=15.0, # Fixed 15 ms - sigma=0.0, # No variation - bounds=[-50.0, 50.0], - ), - ) - - config.parameters = [fixed_time_param] - config.n_iterations = 3 - - # Generate multiple sets - all should be identical - param_sets = correction.load_param_sets(config) - self.assertEqual(len(param_sets), 3) - - expected_seconds = 15.0 / 1000.0 # 15 ms = 0.015 s - for i, param_set in enumerate(param_sets): - param_config, param_value_seconds = param_set[0] - self.assertAlmostEqual(param_value_seconds, expected_seconds, places=9) - logger.info(f" Set {i}: {param_value_seconds:.6f} s (constant)") - - # Apply to data and verify - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(param_config, expected_seconds, science_data) - - actual_delta_us = modified_data["corrected_timestamp"].mean() - original_mean - expected_delta_us = 15000.0 # 15 ms = 15000 µs - self.assertAlmostEqual(actual_delta_us, expected_delta_us, places=3) - logger.info(f" Applied offset: {actual_delta_us:.3f} µs (expected {expected_delta_us:.3f} µs)") - logger.info("✓ Fixed offset with milliseconds works correctly") - - # ========== Test 5: Fixed offset (sigma=0) with microseconds ========== - logger.info("\nTest 5: OFFSET_TIME fixed offset (sigma=0) with microseconds") - - fixed_time_us_param = correction.ParameterConfig( - ptype=correction.ParameterType.OFFSET_TIME, - config_file=None, - data=dict( - field="corrected_timestamp", - units="microseconds", - current_value=7500.0, # Fixed 7500 µs = 7.5 ms - 
sigma=0.0, # No variation - bounds=[-50000.0, 50000.0], - ), - ) - - config.parameters = [fixed_time_us_param] - config.n_iterations = 2 - - param_sets = correction.load_param_sets(config) - expected_seconds = 7500.0 / 1000000.0 # 7500 µs = 0.0075 s - - for i, param_set in enumerate(param_sets): - param_config, param_value_seconds = param_set[0] - self.assertAlmostEqual(param_value_seconds, expected_seconds, places=9) - logger.info(f" Set {i}: {param_value_seconds:.6f} s (constant)") - - # Apply and verify - original_mean = science_data["corrected_timestamp"].mean() - modified_data = correction.apply_offset(param_config, expected_seconds, science_data) - - actual_delta_us = modified_data["corrected_timestamp"].mean() - original_mean - expected_delta_us = 7500.0 # 7500 µs - self.assertAlmostEqual(actual_delta_us, expected_delta_us, places=3) - logger.info(f" Applied offset: {actual_delta_us:.3f} µs (expected {expected_delta_us:.3f} µs)") - logger.info("✓ Fixed offset with microseconds works correctly") - - # ========== Summary ========== - logger.info("\n" + "=" * 80) - logger.info("✓ OFFSET_TIME UNIT CONVERSION INTEGRATION TEST PASSED") - logger.info(" - Integration: load_param_sets -> apply_offset ✓") - logger.info("=" * 80) - - -# ============================================================================= -# MAIN ENTRY POINT -# ============================================================================= - - -def main(): - """Main entry point for standalone execution.""" - parser = argparse.ArgumentParser( - description="Unified Correction Test Suite", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Test Modes: ------------ -upstream - Test kernel creation + geolocation -downstream - Test pairing + matching + error statistics -unittest - Run all unit tests - -Examples: ---------- -# Run downstream test -python test_correction.py --mode downstream --quick - -# Run with specific test cases -python test_correction.py --mode downstream 
--test-cases 1 2 --iterations 10 - -# Run unit tests -python test_correction.py --mode unittest -pytest test_correction.py -v - -# Run upstream test (when implemented) -python test_correction.py --mode upstream --iterations 5 - """, - ) - - parser.add_argument( - "--mode", type=str, choices=["upstream", "downstream", "unittest"], required=True, help="Test mode to run" - ) - parser.add_argument("--quick", action="store_true", help="Quick test (2 iterations, test case 1 only)") - parser.add_argument("--iterations", type=int, default=5, help="Number of correction iterations (default: 5)") - parser.add_argument( - "--test-cases", nargs="+", default=None, help="Specific test cases for downstream mode (e.g., 1 2 3)" - ) - parser.add_argument("--output-dir", type=str, default=None, help="Output directory for results") - - args = parser.parse_args() - - # Setup logging - utils.enable_logging(log_level=logging.INFO, extra_loggers=[__name__]) - - if args.mode == "unittest": - # Run unit tests - unittest.main(argv=[""], exit=True) - - elif args.mode == "upstream": - # Run upstream test - if args.quick: - n_iterations = 2 - else: - n_iterations = args.iterations - - work_dir = Path(args.output_dir) if args.output_dir else None - - results_list, results_dict, output_file = run_upstream_pipeline(n_iterations=n_iterations, work_dir=work_dir) - - logger.info(f"\n✅ Upstream test complete!") - logger.info(f"Status: {results_dict['status']}") - logger.info(f"Iterations: {results_dict['iterations']}") - logger.info(f"Parameter sets: {results_dict['parameter_sets']}") - logger.info(f"Output file: {output_file}") - - elif args.mode == "downstream": - # Run downstream test - if args.quick: - n_iterations = 2 - test_cases = ["1"] - else: - n_iterations = args.iterations - test_cases = args.test_cases - - work_dir = Path(args.output_dir) if args.output_dir else None - - results_list, results_dict, output_file = run_downstream_pipeline( - n_iterations=n_iterations, test_cases=test_cases, 
work_dir=work_dir - ) - - logger.info(f"\n✅ Downstream test complete!") - logger.info(f"Output: {output_file}") - - # Validate results - ds = xr.open_dataset(output_file) - logger.info(f"NetCDF dimensions: {dict(ds.sizes)}") - assert not np.all(np.isnan(ds["im_lat_error_km"].values)), "No data stored" - logger.info("✅ Validation passed!") - - -if __name__ == "__main__": - main() diff --git a/tests/test_correction/test_dataio.py b/tests/test_correction/test_dataio.py index b2dc2d22..81910525 100644 --- a/tests/test_correction/test_dataio.py +++ b/tests/test_correction/test_dataio.py @@ -1,42 +1,17 @@ -""" -Tests for dataio.py module - -This module tests data I/O functionality: -- S3 object discovery and download -- NetCDF file handling -- Configuration management -- File path operations - -Running Tests: -------------- -# Via pytest (recommended) -pytest tests/test_correction/test_dataio.py -v - -# Run specific test -pytest tests/test_correction/test_dataio.py::DataIOTestCase::test_find_objects -v +"""Tests for ``curryer.correction.dataio`` (generic, no AWS credentials needed). -# Standalone execution -python tests/test_correction/test_dataio.py - -Notes: ------ -These tests use mock S3 clients to avoid requiring AWS credentials -or network access during testing. +CLARREO-specific S3 integration tests live in ``clarreo/test_clarreo_dataio.py``. 
""" from __future__ import annotations import datetime as dt import logging -import os -import tempfile -import unittest from pathlib import Path import pandas as pd import pytest -from curryer import utils from curryer.correction.dataio import ( S3Configuration, download_netcdf_objects, @@ -46,23 +21,24 @@ ) logger = logging.getLogger(__name__) -utils.enable_logging(log_level=logging.INFO, extra_loggers=[__name__]) + + +# ── FakeS3Client ────────────────────────────────────────────────────────────── class FakeS3Client: - def __init__(self, objects): - self.objects = objects # dict: key -> bytes - self.list_calls = [] - self.download_calls = [] + """Minimal in-memory S3 mock.""" + + def __init__(self, objects: dict[str, bytes]): + self.objects = objects + self.list_calls: list = [] + self.download_calls: list = [] def list_objects_v2(self, **kwargs): bucket = kwargs["Bucket"] prefix = kwargs.get("Prefix", "") self.list_calls.append((bucket, prefix)) - contents = [] - for key in sorted(self.objects): - if key.startswith(prefix): - contents.append({"Key": key}) + contents = [{"Key": k} for k in sorted(self.objects) if k.startswith(prefix)] return {"Contents": contents, "IsTruncated": False} def download_file(self, bucket, key, filename): @@ -70,254 +46,97 @@ def download_file(self, bucket, key, filename): Path(filename).write_bytes(self.objects[key]) -class DataIOTestCase(unittest.TestCase): - def setUp(self) -> None: - self.__tmp_dir = tempfile.TemporaryDirectory() - self.addCleanup(self.__tmp_dir.cleanup) - self.tmp_dir = Path(self.__tmp_dir.name) - - def test_find_netcdf_objects_filters_and_matches_prefix(self): - config = S3Configuration("test-bucket", "L1a/nadir") - objects = { - "L1a/nadir/20181225/file1.nc": b"data1", - "L1a/nadir/20181225/file2.txt": b"ignored", - "L1a/nadir/20181226/file3.nc": b"data3", - "L1a/nadir/20181227/file4.nc": b"out_of_range", - } - client = FakeS3Client(objects) - - logger.info(f"Testing NetCDF discovery with mock S3 - bucket: 
{config.bucket}, prefix: {config.base_prefix}") - keys = find_netcdf_objects( - config, - start_date=dt.date(2018, 12, 25), - end_date=dt.date(2018, 12, 26), - s3_client=client, - ) - logger.info(f"Found {len(keys)} NetCDF files matching date range and .nc extension") - - self.assertListEqual( - keys, - [ - "L1a/nadir/20181225/file1.nc", - "L1a/nadir/20181226/file3.nc", - ], - ) - self.assertListEqual( - client.list_calls, - [ - ("test-bucket", "L1a/nadir/20181225/"), - ("test-bucket", "L1a/nadir/20181226/"), - ], - ) - logger.info("Verified correct S3 prefix queries and file filtering") - - def test_download_netcdf_objects_writes_files(self): - config = S3Configuration("test-bucket", "L1a/nadir") - objects = { - "L1a/nadir/20181225/file1.nc": b"data1", - "L1a/nadir/20181225/file2.nc": b"data2", - } - client = FakeS3Client(objects) - - logger.info(f"Testing NetCDF download with mock S3 - {len(objects)} objects to destination: {self.tmp_dir}") - output_paths = download_netcdf_objects( - config, - objects.keys(), - self.tmp_dir, - s3_client=client, - ) - - logger.info(f"Successfully downloaded {len(output_paths)} files: {[p.name for p in output_paths]}") - self.assertSetEqual({p.name for p in output_paths}, {"file1.nc", "file2.nc"}) - for path in output_paths: - self.assertEqual(path.read_bytes(), objects[f"L1a/nadir/20181225/{path.name}"]) - logger.info("Verified file contents match S3 objects") - - -@unittest.skipUnless( - ( - os.getenv("AWS_ACCESS_KEY_ID", "") - and os.getenv("AWS_SECRET_ACCESS_KEY", "") - and os.getenv("AWS_SESSION_TOKEN", "") +# ── find / download ─────────────────────────────────────────────────────────── + + +def test_find_netcdf_objects_filters_and_matches_prefix(): + config = S3Configuration("test-bucket", "L1a/nadir") + objects = { + "L1a/nadir/20181225/file1.nc": b"data1", + "L1a/nadir/20181225/file2.txt": b"ignored", + "L1a/nadir/20181226/file3.nc": b"data3", + "L1a/nadir/20181227/file4.nc": b"out_of_range", + } + client = 
FakeS3Client(objects) + keys = find_netcdf_objects( + config, start_date=dt.date(2018, 12, 25), end_date=dt.date(2018, 12, 26), s3_client=client ) - or os.getenv("C9_USER"), - "Requires tester to set AWS access key environment variables or run in Cloud9.", -) -class ClarreoDataIOTestCase(unittest.TestCase): - def setUp(self) -> None: - self.__tmp_dir = tempfile.TemporaryDirectory() - self.addCleanup(self.__tmp_dir.cleanup) - self.tmp_dir = Path(self.__tmp_dir.name) + assert keys == ["L1a/nadir/20181225/file1.nc", "L1a/nadir/20181226/file3.nc"] + assert client.list_calls == [ + ("test-bucket", "L1a/nadir/20181225/"), + ("test-bucket", "L1a/nadir/20181226/"), + ] + + +def test_download_netcdf_objects_writes_files(tmp_path): + config = S3Configuration("test-bucket", "L1a/nadir") + objects = { + "L1a/nadir/20181225/file1.nc": b"data1", + "L1a/nadir/20181225/file2.nc": b"data2", + } + client = FakeS3Client(objects) + output_paths = download_netcdf_objects(config, objects.keys(), tmp_path, s3_client=client) + assert {p.name for p in output_paths} == {"file1.nc", "file2.nc"} + for p in output_paths: + assert p.read_bytes() == objects[f"L1a/nadir/20181225/{p.name}"] + + +# ── validation ──────────────────────────────────────────────────────────────── - def test_l0(self): - config = S3Configuration("clarreo", "L0/telemetry/hps_navigation/") - start_date = dt.date(2017, 1, 15) - end_date = dt.date(2017, 1, 15) - logger.info(f"Querying CSDS S3 bucket '{config.bucket}' for L0 telemetry data") - logger.info(f"Date range: {start_date} to {end_date}, prefix: {config.base_prefix}") +class MockConfig: + class MockGeo: + time_field = "corrected_timestamp" - keys = find_netcdf_objects( - config, - start_date=start_date, - end_date=end_date, - ) + def __init__(self): + self.geo = self.MockGeo() - logger.info(f"Successfully retrieved {len(keys)} L0 telemetry files from CSDS S3") - if keys: - logger.info(f"Example file: {keys[0]}") - self.assertListEqual( - keys, 
["L0/telemetry/hps_navigation/20170115/CPF_TLM_L0.V00-000.hps_navigation-20170115-0.0.0.nc"] - ) +@pytest.fixture +def mock_config(): + return MockConfig() - def test_l1a(self): - config = S3Configuration("clarreo", "L1a/nadir/") - start_date = dt.date(2022, 6, 3) - end_date = dt.date(2022, 6, 3) - logger.info(f"Querying CSDS S3 bucket '{config.bucket}' for L1a nadir science data") - logger.info(f"Date range: {start_date} to {end_date}, prefix: {config.base_prefix}") +def test_validate_telemetry_valid(mock_config): + df = pd.DataFrame({"time": [1.0, 2.0], "position_x": [100.0, 200.0]}) + validate_telemetry_output(df, mock_config) # must not raise - keys = find_netcdf_objects( - config, - start_date=start_date, - end_date=end_date, - ) - logger.info(f"Successfully retrieved {len(keys)} L1a science files from CSDS S3") - if keys: - logger.info(f"Example file: {keys[0]}") +def test_validate_telemetry_not_dataframe(mock_config): + with pytest.raises(TypeError, match="must return pd.DataFrame"): + validate_telemetry_output({"not": "a dataframe"}, mock_config) - self.assertEqual(34, len(keys)) - self.assertIn("L1a/nadir/20220603/nadir-20220603T235952-step22-geolocation_creation-0.0.0.nc", keys) +def test_validate_telemetry_empty(mock_config): + with pytest.raises(ValueError, match="empty DataFrame"): + validate_telemetry_output(pd.DataFrame(), mock_config) -class MockConfig: - """Mock config for validation tests.""" - class MockGeo: - time_field = "corrected_timestamp" +def test_validate_science_valid(mock_config): + df = pd.DataFrame({"corrected_timestamp": [1e6, 2e6], "frame_id": [1, 2]}) + validate_science_output(df, mock_config) # must not raise - def __init__(self): - self.geo = self.MockGeo() + +def test_validate_science_not_dataframe(mock_config): + with pytest.raises(TypeError, match="must return pd.DataFrame"): + validate_science_output([1, 2, 3], mock_config) + + +def test_validate_science_empty(mock_config): + with pytest.raises(ValueError, match="empty 
DataFrame"): + validate_science_output(pd.DataFrame(), mock_config) + + +def test_validate_science_missing_time_field(mock_config): + df = pd.DataFrame({"frame_id": [1, 2], "other": [100, 200]}) + with pytest.raises(ValueError, match="must include time field 'corrected_timestamp'"): + validate_science_output(df, mock_config) -class TestDataIOValidation(unittest.TestCase): - """Test validation functions for data loaders.""" - - def setUp(self): - """Set up test fixtures.""" - self.config = MockConfig() - - def test_validate_telemetry_output_valid(self): - """Test that valid telemetry passes validation.""" - valid_df = pd.DataFrame( - { - "time": [1.0, 2.0, 3.0], - "position_x": [100.0, 200.0, 300.0], - } - ) - - logger.info(f"Testing telemetry validation with DataFrame shape: {valid_df.shape}") - # Should not raise - validate_telemetry_output(valid_df, self.config) - logger.info("Telemetry validation passed for valid DataFrame") - - def test_validate_telemetry_output_not_dataframe(self): - """Test that non-DataFrame input raises TypeError.""" - logger.info("Testing telemetry validation rejects non-DataFrame input") - with pytest.raises(TypeError, match="must return pd.DataFrame"): - validate_telemetry_output({"not": "a dataframe"}, self.config) - logger.info("Correctly raised TypeError for non-DataFrame telemetry") - - def test_validate_telemetry_output_empty(self): - """Test that empty DataFrame raises ValueError.""" - empty_df = pd.DataFrame() - - logger.info("Testing telemetry validation rejects empty DataFrame") - with pytest.raises(ValueError, match="empty DataFrame"): - validate_telemetry_output(empty_df, self.config) - logger.info("Correctly raised ValueError for empty telemetry DataFrame") - - def test_validate_science_output_valid(self): - """Test that valid science output passes validation.""" - valid_df = pd.DataFrame( - { - "corrected_timestamp": [1e6, 2e6, 3e6], - "frame_id": [1, 2, 3], - } - ) - - logger.info( - f"Testing science validation with 
DataFrame shape: {valid_df.shape}, time field: {self.config.geo.time_field}" - ) - # Should not raise - validate_science_output(valid_df, self.config) - logger.info("Science validation passed for valid DataFrame with required time field") - - def test_validate_science_output_not_dataframe(self): - """Test that non-DataFrame input raises TypeError.""" - logger.info("Testing science validation rejects non-DataFrame input") - with pytest.raises(TypeError, match="must return pd.DataFrame"): - validate_science_output([1, 2, 3], self.config) - logger.info("Correctly raised TypeError for non-DataFrame science output") - - def test_validate_science_output_empty(self): - """Test that empty DataFrame raises ValueError.""" - empty_df = pd.DataFrame() - - logger.info("Testing science validation rejects empty DataFrame") - with pytest.raises(ValueError, match="empty DataFrame"): - validate_science_output(empty_df, self.config) - logger.info("Correctly raised ValueError for empty science DataFrame") - - def test_validate_science_output_missing_time_field(self): - """Test that DataFrame missing time field raises ValueError.""" - df_no_time = pd.DataFrame( - { - "frame_id": [1, 2, 3], - "other_field": [100, 200, 300], - } - ) - - logger.info( - f"Testing science validation rejects DataFrame missing required time field '{self.config.geo.time_field}'" - ) - with pytest.raises(ValueError, match="must include time field 'corrected_timestamp'"): - validate_science_output(df_no_time, self.config) - logger.info("Correctly raised ValueError for missing time field") - - def test_validate_science_output_custom_time_field(self): - """Test that validation respects custom time field name.""" - # Change config to use different time field - self.config.geo.time_field = "custom_time" - - logger.info(f"Testing science validation with custom time field: {self.config.geo.time_field}") - df_custom_time = pd.DataFrame( - { - "custom_time": [1.0, 2.0, 3.0], - "data": [100, 200, 300], - } - ) - - # 
Should pass with custom time field - validate_science_output(df_custom_time, self.config) - logger.info("Validation passed for DataFrame with custom time field") - - # Should fail without custom time field - df_wrong_time = pd.DataFrame( - { - "corrected_timestamp": [1.0, 2.0, 3.0], - "data": [100, 200, 300], - } - ) - - logger.info("Testing validation rejects DataFrame with wrong time field name") - with pytest.raises(ValueError, match="must include time field 'custom_time'"): - validate_science_output(df_wrong_time, self.config) - logger.info("Correctly raised ValueError for wrong time field name") - - -if __name__ == "__main__": - unittest.main() +def test_validate_science_custom_time_field(mock_config): + mock_config.geo.time_field = "custom_time" + df_ok = pd.DataFrame({"custom_time": [1.0, 2.0], "data": [1, 2]}) + validate_science_output(df_ok, mock_config) # must not raise + df_bad = pd.DataFrame({"corrected_timestamp": [1.0, 2.0], "data": [1, 2]}) + with pytest.raises(ValueError, match="must include time field 'custom_time'"): + validate_science_output(df_bad, mock_config) diff --git a/tests/test_correction/test_image_match.py b/tests/test_correction/test_image_match.py index 2be5c15d..fbdb1ebb 100644 --- a/tests/test_correction/test_image_match.py +++ b/tests/test_correction/test_image_match.py @@ -1,40 +1,19 @@ -""" -Tests for image_match.py module - -This module tests the integrated image matching functionality, including: -- Image grid loading from MATLAB files -- Geolocation error application for testing -- Image matching algorithm validation -- Cross-correlation and grid search - -Running Tests: -------------- -# Via pytest (recommended) -pytest tests/test_correction/test_image_match.py -v - -# Run specific test -pytest tests/test_correction/test_image_match.py::TestImageMatch::test_integrated_match_case_1 -v - -# Standalone execution -python tests/test_correction/test_image_match.py - -Requirements: ------------------ -These tests validate image 
matching algorithms against known test cases, -demonstrating that the Python implementation correctly identifies geolocation -errors through image cross-correlation. +"""Tests for ``curryer.correction.image_match``. + +Covers ``integrated_image_match`` against 20 CLARREO test cases (5 scenes × +4 sub-cases each: unbinned a/b, 3-pixel-binned c/d). """ +from __future__ import annotations + import logging -import tempfile -import unittest from pathlib import Path import numpy as np +import pytest import xarray as xr from scipy.io import loadmat -from curryer import utils from curryer.compute import constants from curryer.correction.data_structures import ( ImageGrid, @@ -45,13 +24,15 @@ from curryer.correction.image_match import integrated_image_match logger = logging.getLogger(__name__) -utils.enable_logging(log_level=logging.INFO, extra_loggers=[__name__]) xr.set_options(display_width=120, display_max_rows=30) np.set_printoptions(linewidth=120) -def image_grid_from_struct(mat_struct): +# ── module-level helpers ────────────────────────────────────────────────────── + + +def image_grid_from_struct(mat_struct) -> ImageGrid: return ImageGrid( data=np.asarray(mat_struct.data), lat=np.asarray(mat_struct.lat), @@ -69,35 +50,42 @@ def great_circle_displacement_deg(lat_km: float, lon_km: float, reference_lat_de return lat_offset_deg, lon_offset_deg -def apply_geolocation_error(subimage: ImageGrid, gcp: ImageGrid, lat_error_km: float, lon_error_km: float): - """Return a copy of the subimage with imposed geolocation error.""" +def apply_geolocation_error(subimage: ImageGrid, gcp: ImageGrid, lat_error_km: float, lon_error_km: float) -> ImageGrid: + """Return a copy of *subimage* with an imposed lat/lon geolocation error.""" mid_lat = float(gcp.lat[gcp.lat.shape[0] // 2, gcp.lat.shape[1] // 2]) - lat_offset_deg, lon_offset_deg = great_circle_displacement_deg(lat_error_km, lon_error_km, mid_lat) + lat_off, lon_off = great_circle_displacement_deg(lat_error_km, lon_error_km, 
mid_lat) return ImageGrid( data=subimage.data.copy(), - lat=subimage.lat + lat_offset_deg, - lon=subimage.lon + lon_offset_deg, + lat=subimage.lat + lat_off, + lon=subimage.lon + lon_off, h=subimage.h.copy() if subimage.h is not None else None, ) -class ImageMatchTestCase(unittest.TestCase): - def setUp(self) -> None: - root_dir = Path(__file__).parent.parent.parent - self.test_dir = root_dir / "tests" / "data" / "clarreo" / "image_match" - self.assertTrue(self.test_dir.is_dir(), self.test_dir) +# ── fixtures ────────────────────────────────────────────────────────────────── + + +@pytest.fixture(scope="module") +def image_match_dir(): + root = Path(__file__).parent.parent.parent + d = root / "tests" / "data" / "clarreo" / "image_match" + assert d.is_dir(), str(d) + return d - self.__tmp_dir = tempfile.TemporaryDirectory() - self.addCleanup(self.__tmp_dir.cleanup) - self.tmp_dir = Path(self.__tmp_dir.name) - @staticmethod - def run_image_match(subimg_file, gcp_file, ancil_file, psf_file, pix_vec_file, lat_lon_err): - subimage_struct = loadmat(subimg_file, squeeze_me=True, struct_as_record=False)["subimage"] - subimage = image_grid_from_struct(subimage_struct) +# ── test class ──────────────────────────────────────────────────────────────── - gcp_struct = loadmat(gcp_file, squeeze_me=True, struct_as_record=False)["GCP"] - gcp = image_grid_from_struct(gcp_struct) + +class TestImageMatch: + """Validates ``integrated_image_match`` against 20 known test cases.""" + + @pytest.fixture(autouse=True) + def _setup(self, image_match_dir): + self.test_dir = image_match_dir + + def _run(self, subimg_file, gcp_file, ancil_file, psf_file, pix_vec_file, lat_lon_err): + subimage = image_grid_from_struct(loadmat(subimg_file, squeeze_me=True, struct_as_record=False)["subimage"]) + gcp = image_grid_from_struct(loadmat(gcp_file, squeeze_me=True, struct_as_record=False)["GCP"]) los_vectors = loadmat(pix_vec_file, squeeze_me=True)["b_HS"] r_iss = loadmat(ancil_file, 
squeeze_me=True)["R_ISS_midframe"].ravel() @@ -120,24 +108,26 @@ def run_image_match(subimg_file, gcp_file, ancil_file, psf_file, pix_vec_file, l r_iss_midframe_m=r_iss, los_vectors_hs=los_vectors, optical_psfs=psf_entries, - # lat_error_km=lat_lon_err[0], - # lon_error_km=lat_lon_err[1], geolocation_config=PSFSamplingConfig(), search_config=SearchConfig(), ) - logger.info(f"Image file: {subimg_file.name}") - logger.info(f"GCP file: {gcp_file.name}") - logger.info(f"Lat Error (km): {result.lat_error_km:+6.3f} (exp={lat_lon_err[0]:+6.3f})") - logger.info(f"Lon Error (km): {result.lon_error_km:+6.3f} (exp={lat_lon_err[1]:+6.3f})") - logger.info(f" CCV (final): {result.ccv_final}") - logger.info(f" Pixel (row): {result.final_index_row}") - logger.info(f" Pixel (col): {result.final_index_col}") + logger.info( + "case=%s lat=%+.3f exp=%+.3f lon=%+.3f exp=%+.3f ccv=%.4f row=%d col=%d", + subimg_file.name, + result.lat_error_km, + lat_lon_err[0], + result.lon_error_km, + lat_lon_err[1], + result.ccv_final, + result.final_index_row, + result.final_index_col, + ) return result def test_case_1a_unbinned(self): lat_lon_err = (3.0, -3.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "1" / "TestCase1a_subimage.mat", gcp_file=self.test_dir / "1" / "GCP12055Dili_resampled.mat", ancil_file=self.test_dir / "1" / "R_ISS_midframe_TestCase1.mat", @@ -145,15 +135,15 @@ def test_case_1a_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 22) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + 
np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 22) def test_case_1b_unbinned(self): lat_lon_err = (3.0, -3.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "1" / "TestCase1b_subimage.mat", gcp_file=self.test_dir / "1" / "GCP12055Dili_resampled.mat", ancil_file=self.test_dir / "1" / "R_ISS_midframe_TestCase1.mat", @@ -161,15 +151,15 @@ def test_case_1b_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 21) - np.testing.assert_allclose(result.final_index_col, 22) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 21) + np.testing.assert_allclose(r.final_index_col, 22) def test_case_1c_binned(self): lat_lon_err = (3.0, -3.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "1" / "TestCase1c_subimage_binned.mat", gcp_file=self.test_dir / "1" / "GCP12055Dili_resampled.mat", ancil_file=self.test_dir / "1" / "R_ISS_midframe_TestCase1.mat", @@ -177,15 +167,15 @@ def test_case_1c_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 21) - np.testing.assert_allclose(result.final_index_col, 23) + 
np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 21) + np.testing.assert_allclose(r.final_index_col, 23) def test_case_1d_binned(self): lat_lon_err = (3.0, -3.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "1" / "TestCase1d_subimage_binned.mat", gcp_file=self.test_dir / "1" / "GCP12055Dili_resampled.mat", ancil_file=self.test_dir / "1" / "R_ISS_midframe_TestCase1.mat", @@ -193,15 +183,15 @@ def test_case_1d_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 21) def test_case_2a_unbinned(self): lat_lon_err = (-3.0, 2.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "2" / "TestCase2a_subimage.mat", gcp_file=self.test_dir / "2" / "GCP10121Maracaibo_resampled.mat", ancil_file=self.test_dir / "2" / "R_ISS_midframe_TestCase2.mat", @@ -209,15 +199,15 @@ def test_case_2a_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, 
atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 21) def test_case_2b_unbinned(self): lat_lon_err = (-3.0, 2.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "2" / "TestCase2b_subimage.mat", gcp_file=self.test_dir / "2" / "GCP10121Maracaibo_resampled.mat", ancil_file=self.test_dir / "2" / "R_ISS_midframe_TestCase2.mat", @@ -225,15 +215,15 @@ def test_case_2b_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 23) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 23) def test_case_2c_binned(self): lat_lon_err = (-3.0, 2.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "2" / "TestCase2c_subimage_binned.mat", gcp_file=self.test_dir / "2" / "GCP10121Maracaibo_resampled.mat", ancil_file=self.test_dir / "2" / "R_ISS_midframe_TestCase2.mat", @@ -241,15 +231,15 @@ def test_case_2c_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - 
np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 23) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 23) def test_case_2d_binned(self): lat_lon_err = (-3.0, 2.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "2" / "TestCase2d_subimage_binned.mat", gcp_file=self.test_dir / "2" / "GCP10121Maracaibo_resampled.mat", ancil_file=self.test_dir / "2" / "R_ISS_midframe_TestCase2.mat", @@ -257,15 +247,15 @@ def test_case_2d_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 23) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 23) def test_case_3a_unbinned(self): lat_lon_err = (1.0, 1.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "3" / "TestCase3a_subimage.mat", gcp_file=self.test_dir / "3" / "GCP10665SantaRosa_resampled.mat", ancil_file=self.test_dir / "3" / "R_ISS_midframe_TestCase3.mat", @@ -273,15 +263,15 @@ def test_case_3a_unbinned(self): 
pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 21) def test_case_3b_unbinned(self): lat_lon_err = (1.0, 1.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "3" / "TestCase3b_subimage.mat", gcp_file=self.test_dir / "3" / "GCP10665SantaRosa_resampled.mat", ancil_file=self.test_dir / "3" / "R_ISS_midframe_TestCase3.mat", @@ -289,15 +279,15 @@ def test_case_3b_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 21) def test_case_3c_binned(self): lat_lon_err = (1.0, 1.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "3" / "TestCase3c_subimage_binned.mat", gcp_file=self.test_dir / "3" / 
"GCP10665SantaRosa_resampled.mat", ancil_file=self.test_dir / "3" / "R_ISS_midframe_TestCase3.mat", @@ -305,15 +295,15 @@ def test_case_3c_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.07) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.07) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 23) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.07) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.07) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 23) def test_case_3d_binned(self): lat_lon_err = (1.0, 1.0) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "3" / "TestCase3d_subimage_binned.mat", gcp_file=self.test_dir / "3" / "GCP10665SantaRosa_resampled.mat", ancil_file=self.test_dir / "3" / "R_ISS_midframe_TestCase3.mat", @@ -321,15 +311,15 @@ def test_case_3d_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.07) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.07) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 20) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.07) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.07) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 20) def test_case_4a_unbinned(self): lat_lon_err = (-1.0, -2.5) - result = 
self.run_image_match( + r = self._run( subimg_file=self.test_dir / "4" / "TestCase4a_subimage.mat", gcp_file=self.test_dir / "4" / "GCP20484Morocco_resampled.mat", ancil_file=self.test_dir / "4" / "R_ISS_midframe_TestCase4.mat", @@ -337,15 +327,15 @@ def test_case_4a_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 22) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 22) def test_case_4b_unbinned(self): lat_lon_err = (-1.0, -2.5) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "4" / "TestCase4b_subimage.mat", gcp_file=self.test_dir / "4" / "GCP20484Morocco_resampled.mat", ancil_file=self.test_dir / "4" / "R_ISS_midframe_TestCase4.mat", @@ -353,15 +343,15 @@ def test_case_4b_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 23) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + 
np.testing.assert_allclose(r.final_index_col, 23) def test_case_4c_binned(self): lat_lon_err = (-1.0, -2.5) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "4" / "TestCase4c_subimage_binned.mat", gcp_file=self.test_dir / "4" / "GCP20484Morocco_resampled.mat", ancil_file=self.test_dir / "4" / "R_ISS_midframe_TestCase4.mat", @@ -369,15 +359,15 @@ def test_case_4c_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 21) - np.testing.assert_allclose(result.final_index_col, 21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 21) + np.testing.assert_allclose(r.final_index_col, 21) def test_case_4d_binned(self): lat_lon_err = (-1.0, -2.5) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "4" / "TestCase4d_subimage_binned.mat", gcp_file=self.test_dir / "4" / "GCP20484Morocco_resampled.mat", ancil_file=self.test_dir / "4" / "R_ISS_midframe_TestCase4.mat", @@ -385,15 +375,15 @@ def test_case_4d_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 23) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, 
lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 23) def test_case_5a_unbinned(self): lat_lon_err = (2.5, 0.1) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "5" / "TestCase5a_subimage.mat", gcp_file=self.test_dir / "5" / "GCP10087Titicaca_resampled.mat", ancil_file=self.test_dir / "5" / "R_ISS_midframe_TestCase5.mat", @@ -401,15 +391,15 @@ def test_case_5a_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 23) - np.testing.assert_allclose(result.final_index_col, 21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 23) + np.testing.assert_allclose(r.final_index_col, 21) def test_case_5b_unbinned(self): lat_lon_err = (2.5, 0.1) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "5" / "TestCase5b_subimage.mat", gcp_file=self.test_dir / "5" / "GCP10087Titicaca_resampled.mat", ancil_file=self.test_dir / "5" / "R_ISS_midframe_TestCase5.mat", @@ -417,15 +407,15 @@ def test_case_5b_unbinned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 
21) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 21) def test_case_5c_binned(self): lat_lon_err = (2.5, 0.1) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "5" / "TestCase5c_subimage_binned.mat", gcp_file=self.test_dir / "5" / "GCP10087Titicaca_resampled.mat", ancil_file=self.test_dir / "5" / "R_ISS_midframe_TestCase5.mat", @@ -433,15 +423,15 @@ def test_case_5c_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 22) + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 22) def test_case_5d_binned(self): lat_lon_err = (2.5, 0.1) - result = self.run_image_match( + r = self._run( subimg_file=self.test_dir / "5" / "TestCase5d_subimage_binned.mat", gcp_file=self.test_dir / "5" / "GCP10087Titicaca_resampled.mat", ancil_file=self.test_dir / "5" / "R_ISS_midframe_TestCase5.mat", @@ -449,12 +439,8 @@ def test_case_5d_binned(self): pix_vec_file=self.test_dir / "b_HS.mat", lat_lon_err=lat_lon_err, ) - np.testing.assert_allclose(result.lat_error_km, lat_lon_err[0], atol=0.05) - np.testing.assert_allclose(result.lon_error_km, lat_lon_err[1], atol=0.05) - np.testing.assert_allclose(result.ccv_final, 
1.0, atol=0.01) - np.testing.assert_allclose(result.final_index_row, 22) - np.testing.assert_allclose(result.final_index_col, 22) - - -if __name__ == "__main__": - unittest.main() + np.testing.assert_allclose(r.lat_error_km, lat_lon_err[0], atol=0.05) + np.testing.assert_allclose(r.lon_error_km, lat_lon_err[1], atol=0.05) + np.testing.assert_allclose(r.ccv_final, 1.0, atol=0.01) + np.testing.assert_allclose(r.final_index_row, 22) + np.testing.assert_allclose(r.final_index_col, 22) diff --git a/tests/test_correction/test_kernel_ops.py b/tests/test_correction/test_kernel_ops.py new file mode 100644 index 00000000..6cfb55dc --- /dev/null +++ b/tests/test_correction/test_kernel_ops.py @@ -0,0 +1,193 @@ +"""Tests for ``curryer.correction.kernel_ops``. + +Covers: +- ``apply_offset`` – all parameter types and unit-conversion paths +- ``_load_calibration_data`` +- ``_create_dynamic_kernels`` (``@pytest.mark.extra``, requires ``mkspk``) +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +from clarreo_config import create_clarreo_correction_config +from clarreo_data_loaders import load_clarreo_telemetry + +from curryer import meta +from curryer import spicierpy as sp +from curryer.correction import correction +from curryer.kernels import create + +logger = logging.getLogger(__name__) + +# ── shared sample data ──────────────────────────────────────────────────────── + +_TLM = pd.DataFrame( + { + "frame": range(5), + "hps.az_ang_nonlin": [1.14252] * 5, + "hps.el_ang_nonlin": [-0.55009] * 5, + "hps.resolver_tms": [1168477154.0 + i for i in range(5)], + "ert": [1431903180.58 + i for i in range(5)], + } +) + +_SCI = pd.DataFrame( + { + "corrected_timestamp": [1_000_000.0, 2_000_000.0, 3_000_000.0, 4_000_000.0, 5_000_000.0], + "measurement": [1.0, 2.0, 3.0, 4.0, 5.0], + } +) + + +@pytest.fixture(scope="module") +def clarreo_cfg(root_dir): + return create_clarreo_correction_config( + 
root_dir / "tests" / "data" / "clarreo" / "gcs", + root_dir / "data" / "generic", + ) + + +# ── apply_offset tests ──────────────────────────────────────────────────────── + + +def test_apply_offset_kernel_arcseconds(): + """OFFSET_KERNEL converts arcseconds to radians and adds the offset.""" + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_KERNEL, + config_file=Path("cprs_az.json"), + data=dict(field="hps.az_ang_nonlin", units="arcseconds"), + ) + original = _TLM["hps.az_ang_nonlin"].mean() + modified = correction.apply_offset(p, 100.0, _TLM) + assert modified["hps.az_ang_nonlin"].mean() - original == pytest.approx(np.deg2rad(100.0 / 3600.0), rel=1e-6) + assert isinstance(modified, pd.DataFrame) + + +def test_apply_offset_kernel_negative(): + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_KERNEL, + config_file=Path("cprs_el.json"), + data=dict(field="hps.el_ang_nonlin", units="arcseconds"), + ) + original = _TLM["hps.el_ang_nonlin"].mean() + modified = correction.apply_offset(p, -50.0, _TLM) + assert modified["hps.el_ang_nonlin"].mean() - original == pytest.approx(np.deg2rad(-50.0 / 3600.0), rel=1e-6) + + +def test_apply_offset_kernel_missing_field(): + """Non-existent field: returns original DataFrame unchanged.""" + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_KERNEL, + config_file=Path("dummy.json"), + data=dict(field="nonexistent_field", units="arcseconds"), + ) + modified = correction.apply_offset(p, 10.0, _TLM) + pd.testing.assert_frame_equal(modified, _TLM) + + +def test_apply_offset_time_milliseconds(): + """OFFSET_TIME: seconds input → microsecond output on timestamp column.""" + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_TIME, + config_file=None, + data=dict(field="corrected_timestamp", units="milliseconds"), + ) + original = _SCI["corrected_timestamp"].mean() + modified = correction.apply_offset(p, 10.0 / 1000.0, _SCI) # 10 ms in seconds + assert 
modified["corrected_timestamp"].mean() - original == pytest.approx(10_000.0, rel=1e-6) + + +def test_apply_offset_time_negative(): + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_TIME, + config_file=None, + data=dict(field="corrected_timestamp", units="milliseconds"), + ) + original = _SCI["corrected_timestamp"].mean() + modified = correction.apply_offset(p, -5.5 / 1000.0, _SCI) + assert modified["corrected_timestamp"].mean() - original == pytest.approx(-5500.0, rel=1e-6) + + +def test_apply_offset_constant_kernel_passthrough(): + """CONSTANT_KERNEL: data is returned unchanged.""" + kernel_data = pd.DataFrame({"ugps": [1_000_000], "angle_x": [0.001], "angle_y": [0.002], "angle_z": [0.003]}) + p = correction.ParameterConfig( + ptype=correction.ParameterType.CONSTANT_KERNEL, + config_file=Path("base.json"), + data=dict(field="base"), + ) + modified = correction.apply_offset(p, kernel_data, pd.DataFrame()) + pd.testing.assert_frame_equal(modified, kernel_data) + + +def test_apply_offset_no_units(): + """OFFSET_KERNEL without units: offset applied in raw (radian) units.""" + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_KERNEL, + config_file=Path("test.json"), + data=dict(field="hps.az_ang_nonlin"), + ) + original = _TLM["hps.az_ang_nonlin"].mean() + modified = correction.apply_offset(p, 0.001, _TLM) + assert modified["hps.az_ang_nonlin"].mean() - original == pytest.approx(0.001, rel=1e-6) + + +def test_apply_offset_not_inplace(): + """Original DataFrame is not mutated.""" + p = correction.ParameterConfig( + ptype=correction.ParameterType.OFFSET_KERNEL, + config_file=Path("cprs_az.json"), + data=dict(field="hps.az_ang_nonlin", units="arcseconds"), + ) + original = _TLM.copy() + correction.apply_offset(p, 100.0, _TLM) + pd.testing.assert_frame_equal(_TLM, original) + + +def test_apply_offset_preserves_columns(): + """All columns are present in the returned DataFrame.""" + p = correction.ParameterConfig( + 
ptype=correction.ParameterType.OFFSET_KERNEL, + config_file=Path("cprs_az.json"), + data=dict(field="hps.az_ang_nonlin", units="arcseconds"), + ) + modified = correction.apply_offset(p, 50.0, _TLM) + assert set(modified.columns) == set(_TLM.columns) + assert modified["frame"].equals(_TLM["frame"]) + assert not modified["hps.az_ang_nonlin"].equals(_TLM["hps.az_ang_nonlin"]) + + +# ── _load_calibration_data ──────────────────────────────────────────────────── + + +def test_load_calibration_data_no_dir(clarreo_cfg): + """When calibration_dir is None, returned data contains no vectors.""" + cfg = clarreo_cfg.model_copy(deep=True) + cfg.calibration_dir = None + cal = correction._load_calibration_data(cfg) + assert cal.los_vectors is None + assert cal.optical_psfs is None + + +# ── _create_dynamic_kernels ─────────────────────────────────────────────────── + + +@pytest.mark.extra +def test_create_dynamic_kernels(root_dir, clarreo_cfg, tmp_path): + """_create_dynamic_kernels builds kernel files. 
Needs ``mkspk`` – ``--run-extra``.""" + data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" + work = tmp_path / "kernels" + work.mkdir() + tlm = load_clarreo_telemetry(data_dir) + creator = create.KernelCreator(overwrite=True, append=False) + mkrn = meta.MetaKernel.from_json( + clarreo_cfg.geo.meta_kernel_file, relative=True, sds_dir=clarreo_cfg.geo.generic_kernel_dir + ) + with sp.ext.load_kernel([mkrn.sds_kernels, mkrn.mission_kernels]): + dynamic_kernels = correction._create_dynamic_kernels(clarreo_cfg, work, tlm, creator) + assert isinstance(dynamic_kernels, list) diff --git a/tests/test_correction/test_pairing.py b/tests/test_correction/test_pairing.py index 9cc5520e..98ef6fb0 100644 --- a/tests/test_correction/test_pairing.py +++ b/tests/test_correction/test_pairing.py @@ -1,45 +1,20 @@ -""" -Tests for pairing.py module - -This module tests the GCP (Ground Control Point) pairing functionality: -- Spatial pairing of L1A science data with GCP reference imagery -- File discovery and matching -- Geographic overlap detection -- Pairing validation - -Running Tests: -------------- -# Via pytest (recommended) -pytest tests/test_correction/test_pairing.py -v - -# Run specific test -pytest tests/test_correction/test_pairing.py::PairingTestCase::test_find_l1a_gcp_pairs -v - -# Standalone execution -python tests/test_correction/test_pairing.py - -Requirements: ------------------ -These tests validate that GCP pairing correctly identifies which reference -images overlap with science data, ensuring accurate geolocation validation. 
-""" +"""Tests for ``curryer.correction.pairing`` (GCP spatial pairing).""" from __future__ import annotations import logging -import unittest from pathlib import Path import numpy as np +import pytest from scipy.io import loadmat -from curryer import utils from curryer.correction.data_structures import NamedImageGrid from curryer.correction.pairing import find_l1a_gcp_pairs logger = logging.getLogger(__name__) -utils.enable_logging(log_level=logging.DEBUG, extra_loggers=[__name__]) +# ── test-case metadata ──────────────────────────────────────────────────────── L1A_FILES = [ ("1/TestCase1a_subimage.mat", "subimage"), @@ -75,97 +50,86 @@ "5/TestCase5b_subimage.mat": "5/GCP10087Titicaca_resampled.mat", } +# ── fixtures / helpers ──────────────────────────────────────────────────────── + + +@pytest.fixture(scope="module") +def image_match_dir(): + root = Path(__file__).parent.parent.parent + d = root / "tests" / "data" / "clarreo" / "image_match" + assert d.is_dir(), str(d) + return d + + +def _load(path: Path, key: str, name: str) -> NamedImageGrid: + mat = loadmat(str(path), squeeze_me=True, struct_as_record=False)[key] + h = getattr(mat, "h", None) + return NamedImageGrid( + data=np.asarray(mat.data), + lat=np.asarray(mat.lat), + lon=np.asarray(mat.lon), + h=np.asarray(h) if h is not None else None, + name=name, + ) + + +def _rect(name, lon_min, lon_max, lat_min, lat_max) -> NamedImageGrid: + lat = np.array([[lat_max, lat_max], [lat_min, lat_min]], dtype=float) + lon = np.array([[lon_min, lon_max], [lon_min, lon_max]], dtype=float) + return NamedImageGrid(data=np.zeros_like(lat), lat=lat, lon=lon, name=name) + + +def _point(name, lon, lat) -> NamedImageGrid: + return NamedImageGrid(data=np.array([[1.0]]), lat=np.array([[lat]]), lon=np.array([[lon]]), name=name) + + +# ── tests ───────────────────────────────────────────────────────────────────── + + +def test_find_l1a_gcp_pairs(image_match_dir): + l1a = [_load(image_match_dir / rel, key, rel) for rel, key in 
L1A_FILES] + gcp = [_load(image_match_dir / rel, key, rel) for rel, key in GCP_FILES] + result = find_l1a_gcp_pairs(l1a, gcp, max_distance_m=0.0) + by_l1a = {} + for m in result.matches: + name = result.l1a_images[m.l1a_index].name + gname = result.gcp_images[m.gcp_index].name + if name not in by_l1a or m.distance_m < by_l1a[name][1]: + by_l1a[name] = (gname, m.distance_m) + assert set(by_l1a) == set(EXPECTED_MATCHES) + for l1a_name, expected_gcp in EXPECTED_MATCHES.items(): + gname, dist = by_l1a[l1a_name] + assert gname == expected_gcp + assert dist >= 0.0 + + +def test_synthetic_pairing_no_overlap(): + result = find_l1a_gcp_pairs( + [_rect("L1A", 1.0, 2.0, -1.0, 1.0)], [_point("GCP", 0.0, 0.0)], max_distance_m=100_000.0 + ) + assert result.matches == [] + + +def test_synthetic_pairing_partial_less_than_threshold(): + result = find_l1a_gcp_pairs( + [_rect("L1A", 0.0, 1.0, -1.0, 1.0)], [_point("GCP", 0.0, 0.0)], max_distance_m=100_000.0 + ) + assert result.matches == [] + + +def test_synthetic_pairing_complete_above_threshold(): + result = find_l1a_gcp_pairs( + [_rect("L1A", -1.0, 1.0, -1.0, 1.0)], [_point("GCP", 0.0, 0.0)], max_distance_m=100_000.0 + ) + assert len(result.matches) == 1 + m = result.matches[0] + assert m.l1a_index == 0 + assert m.gcp_index == 0 + assert m.distance_m >= 100_000.0 + -class PairingTestCase(unittest.TestCase): - def setUp(self) -> None: - root_dir = Path(__file__).parent.parent.parent - self.test_dir = root_dir / "tests" / "data" / "clarreo" / "image_match" - self.assertTrue(self.test_dir.is_dir(), self.test_dir) - - def _load_image_grid(self, relative_path: str, key: str) -> NamedImageGrid: - mat_path = self.test_dir / relative_path - mat = loadmat(mat_path, squeeze_me=True, struct_as_record=False)[key] - h = getattr(mat, "h", None) - return NamedImageGrid( - data=np.asarray(mat.data), - lat=np.asarray(mat.lat), - lon=np.asarray(mat.lon), - h=np.asarray(h) if h is not None else None, - name=relative_path, - ) - - def 
_prepare_inputs(self, file_list): - for rel_path, key in file_list: - yield self._load_image_grid(rel_path, key) - - def test_find_l1a_gcp_pairs(self): - l1a_inputs = list(self._prepare_inputs(L1A_FILES)) - gcp_inputs = list(self._prepare_inputs(GCP_FILES)) - - result = find_l1a_gcp_pairs(l1a_inputs, gcp_inputs, max_distance_m=0.0) - - matches_by_l1a = {} - for match in result.matches: - l1a_name = result.l1a_images[match.l1a_index].name - gcp_name = result.gcp_images[match.gcp_index].name - distance = match.distance_m - if l1a_name not in matches_by_l1a or distance < matches_by_l1a[l1a_name][1]: - matches_by_l1a[l1a_name] = (gcp_name, distance) - - assert set(matches_by_l1a.keys()) == set(expected for expected in EXPECTED_MATCHES) - - for l1a_name, expected_gcp in EXPECTED_MATCHES.items(): - assert l1a_name in matches_by_l1a, f"Missing match for {l1a_name}" - gcp_name, distance = matches_by_l1a[l1a_name] - assert gcp_name == expected_gcp, f"Unexpected match for {l1a_name}: {gcp_name}" - assert distance >= 0.0, f"Expected non-negative margin for {l1a_name}: {distance}" - - @staticmethod - def _make_rect_image(name: str, lon_min: float, lon_max: float, lat_min: float, lat_max: float) -> NamedImageGrid: - lat = np.array([[lat_max, lat_max], [lat_min, lat_min]], dtype=float) - lon = np.array([[lon_min, lon_max], [lon_min, lon_max]], dtype=float) - data = np.zeros_like(lat) - return NamedImageGrid(data=data, lat=lat, lon=lon, name=name) - - @staticmethod - def _make_point_gcp(name: str, lon: float, lat: float) -> NamedImageGrid: - data = np.array([[1.0]]) - lat_arr = np.array([[lat]]) - lon_arr = np.array([[lon]]) - return NamedImageGrid(data=data, lat=lat_arr, lon=lon_arr, name=name) - - def test_synthetic_pairing_no_overlap(self): - l1a = [self._make_rect_image("L1A", lon_min=1.0, lon_max=2.0, lat_min=-1.0, lat_max=1.0)] - gcp = [self._make_point_gcp("GCP", lon=0.0, lat=0.0)] - - result = find_l1a_gcp_pairs(l1a, gcp, max_distance_m=100_000.0) - assert result.matches 
== [] - - def test_synthetic_pairing_partial_less_than_threshold(self): - l1a = [self._make_rect_image("L1A", lon_min=0.0, lon_max=1.0, lat_min=-1.0, lat_max=1.0)] - gcp = [self._make_point_gcp("GCP", lon=0.0, lat=0.0)] - - result = find_l1a_gcp_pairs(l1a, gcp, max_distance_m=100_000.0) - assert result.matches == [] - - def test_synthetic_pairing_complete_above_threshold(self): - l1a = [self._make_rect_image("L1A", lon_min=-1.0, lon_max=1.0, lat_min=-1.0, lat_max=1.0)] - gcp = [self._make_point_gcp("GCP", lon=0.0, lat=0.0)] - - result = find_l1a_gcp_pairs(l1a, gcp, max_distance_m=100_000.0) - assert len(result.matches) == 1 - match = result.matches[0] - assert match.l1a_index == 0 - assert match.gcp_index == 0 - assert match.distance_m >= 100_000.0 - - def test_synthetic_pairing_partial_threshold_not_met(self): - l1a = [self._make_rect_image("L1A", lon_min=-1.0, lon_max=1.0, lat_min=-1.0, lat_max=1.0)] - gcp = [self._make_point_gcp("GCP", lon=0.0, lat=0.0)] - - result = find_l1a_gcp_pairs(l1a, gcp, max_distance_m=200_000.0) - assert result.matches == [] - - -if __name__ == "__main__": - unittest.main() +def test_synthetic_pairing_partial_threshold_not_met(): + result = find_l1a_gcp_pairs( + [_rect("L1A", -1.0, 1.0, -1.0, 1.0)], [_point("GCP", 0.0, 0.0)], max_distance_m=200_000.0 + ) + assert result.matches == [] diff --git a/tests/test_correction/test_pipeline.py b/tests/test_correction/test_pipeline.py new file mode 100644 index 00000000..ba7b73f9 --- /dev/null +++ b/tests/test_correction/test_pipeline.py @@ -0,0 +1,162 @@ +"""Tests for ``curryer.correction.pipeline``. 
+ +Covers: +- ``_extract_parameter_values`` +- ``_extract_error_metrics`` +- ``_store_parameter_values`` +- ``_store_gcp_pair_results`` +- ``_compute_parameter_set_metrics`` +- ``_load_image_pair_data`` +- ``loop`` (optimised pair-outer, ``@pytest.mark.extra``) +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from _synthetic_helpers import synthetic_image_matching +from clarreo_config import create_clarreo_correction_config +from clarreo_data_loaders import load_clarreo_science, load_clarreo_telemetry + +from curryer.correction import correction +from curryer.correction.config import DataConfig + +logger = logging.getLogger(__name__) + + +# ── shared fixtures ─────────────────────────────────────────────────────────── + + +@pytest.fixture(scope="module") +def clarreo_cfg(root_dir): + data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" + generic_dir = root_dir / "data" / "generic" + return create_clarreo_correction_config(data_dir, generic_dir) + + +# ── tests ───────────────────────────────────────────────────────────────────── + + +def test_extract_parameter_values(): + """_extract_parameter_values returns roll/pitch/yaw keys.""" + param_config = correction.ParameterConfig( + ptype=correction.ParameterType.CONSTANT_KERNEL, config_file=Path("test_kernel.json"), data=None + ) + param_data = pd.DataFrame( + { + "angle_x": [np.radians(1.0 / 3600)], + "angle_y": [np.radians(2.0 / 3600)], + "angle_z": [np.radians(3.0 / 3600)], + } + ) + result = correction._extract_parameter_values([(param_config, param_data)]) + assert isinstance(result, dict) + assert len(result) == 3 + assert "test_kernel_roll" in result + assert "test_kernel_pitch" in result + assert "test_kernel_yaw" in result + + +def test_extract_error_metrics(): + """_extract_error_metrics pulls named metrics from a Dataset.""" + ds = xr.Dataset({"lat_error_deg": (["pt"], [0.001, 0.002])}) 
+ ds.attrs.update( + { + "rms_error_m": 150.0, + "mean_error_m": 140.0, + "max_error_m": 200.0, + "std_error_m": 10.0, + "total_measurements": 2, + } + ) + m = correction._extract_error_metrics(ds) + assert m["rms_error_m"] == 150.0 + assert m["n_measurements"] == 2 + + +def test_store_parameter_values(): + """_store_parameter_values writes values at the correct index.""" + netcdf_data = {"parameter_set_id": np.zeros(3, dtype=int), "param_foo": np.zeros(3)} + correction._store_parameter_values(netcdf_data, param_idx=1, param_values={"foo": 2.5}) + assert netcdf_data["param_foo"][1] == pytest.approx(2.5) + + +def test_store_gcp_pair_results(): + """_store_gcp_pair_results populates all metric arrays correctly.""" + nc = {k: np.zeros((2, 2)) for k in ("rms_error_m", "mean_error_m", "max_error_m", "std_error_m")} + nc["n_measurements"] = np.zeros((2, 2), dtype=int) + metrics = { + "rms_error_m": 150.0, + "mean_error_m": 140.0, + "max_error_m": 200.0, + "std_error_m": 10.0, + "n_measurements": 10, + } + correction._store_gcp_pair_results(nc, param_idx=0, pair_idx=1, error_metrics=metrics) + assert nc["rms_error_m"][0, 1] == 150.0 + assert nc["std_error_m"][0, 1] == 10.0 + assert nc["n_measurements"][0, 1] == 10 + + +def test_compute_parameter_set_metrics(): + """_compute_parameter_set_metrics populates aggregate stats.""" + nc = { + "percent_under_250m": np.zeros(2), + "mean_rms_all_pairs": np.zeros(2), + "best_pair_rms": np.zeros(2), + "worst_pair_rms": np.zeros(2), + } + correction._compute_parameter_set_metrics(nc, param_idx=0, pair_errors=[100.0, 200.0, 300.0], threshold_m=250.0) + assert nc["percent_under_250m"][0] > 0 + assert nc["best_pair_rms"][0] == 100.0 + assert nc["worst_pair_rms"][0] == 300.0 + + +def test_load_image_pair_data(root_dir, clarreo_cfg, tmp_path): + """_load_image_pair_data returns DataFrames for tlm and sci.""" + data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" + tlm_csv = tmp_path / "tlm.csv" + sci_csv = tmp_path / "sci.csv" + 
load_clarreo_telemetry(data_dir).to_csv(tlm_csv) + load_clarreo_science(data_dir).to_csv(sci_csv) + cfg = clarreo_cfg.model_copy(deep=True) + cfg.data = DataConfig(file_format="csv", time_scale_factor=1e6) + tlm_ds, sci_ds, ugps = correction._load_image_pair_data(str(tlm_csv), str(sci_csv), cfg) + assert isinstance(tlm_ds, pd.DataFrame) + assert isinstance(sci_ds, pd.DataFrame) + assert ugps is not None + + +@pytest.mark.extra +def test_loop_optimized(root_dir, tmp_path): + """loop() produces correct result structure. Requires GMTED – ``--run-extra``.""" + data_dir = root_dir / "tests" / "data" / "clarreo" / "gcs" + generic_dir = root_dir / "data" / "generic" + config = create_clarreo_correction_config(data_dir, generic_dir) + config.n_iterations = 2 + config.output_filename = "test_loop.nc" + work = tmp_path / "loop" + work.mkdir() + tlm_csv, sci_csv = work / "tlm.csv", work / "sci.csv" + load_clarreo_telemetry(data_dir).to_csv(tlm_csv) + load_clarreo_science(data_dir).to_csv(sci_csv) + config.data = DataConfig(file_format="csv", time_scale_factor=1e6) + config.image_matching_func = synthetic_image_matching + sets = [(str(tlm_csv), str(sci_csv), "synthetic_gcp.mat")] + np.random.seed(42) + results, nc = correction.loop(config, work, sets, resume_from_checkpoint=False) + assert isinstance(results, list) + assert len(results) > 0 + assert len(results) == config.n_iterations * len(sets) + assert nc["rms_error_m"].shape == (config.n_iterations, len(sets)) + for r in results: + assert "param_index" in r + assert "pair_index" in r + assert "rms_error_m" in r + assert r["aggregate_rms_error_m"] is not None + assert isinstance(r["aggregate_rms_error_m"], (int, float, np.number)) diff --git a/tests/test_correction/test_results_io.py b/tests/test_correction/test_results_io.py new file mode 100644 index 00000000..d26fd51c --- /dev/null +++ b/tests/test_correction/test_results_io.py @@ -0,0 +1,83 @@ +"""Tests for ``curryer.correction.results_io``. 
+ +Covers: +- ``_build_netcdf_structure`` +- ``_save_netcdf_checkpoint`` / ``_load_checkpoint`` / ``_cleanup_checkpoint`` +""" + +from __future__ import annotations + +import logging + +import pytest +from clarreo_config import create_clarreo_correction_config + +from curryer.correction import correction + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module") +def clarreo_cfg(root_dir): + return create_clarreo_correction_config( + root_dir / "tests" / "data" / "clarreo" / "gcs", + root_dir / "data" / "generic", + ) + + +# ── _build_netcdf_structure ─────────────────────────────────────────────────── + + +def test_build_netcdf_structure_shapes(clarreo_cfg): + """_build_netcdf_structure returns arrays with the requested dimensions.""" + nc = correction._build_netcdf_structure(clarreo_cfg, n_param_sets=3, n_gcp_pairs=2) + assert isinstance(nc, dict) + assert "rms_error_m" in nc + assert "parameter_set_id" in nc + assert nc["rms_error_m"].shape == (3, 2) + assert len(nc["parameter_set_id"]) == 3 + + +def test_build_netcdf_structure_zero_initialised(clarreo_cfg): + """All numeric arrays are initialised to zero (or NaN for float arrays).""" + nc = correction._build_netcdf_structure(clarreo_cfg, n_param_sets=2, n_gcp_pairs=1) + # rms_error_m starts as zeros or NaN – just check it is numeric + assert nc["rms_error_m"].dtype.kind in ("f", "i", "u") + + +# ── checkpoint round-trip ───────────────────────────────────────────────────── + + +def test_checkpoint_save_load_cleanup(clarreo_cfg, tmp_path): + """Save → load → cleanup round-trip preserves data and removes the file.""" + cfg = clarreo_cfg.model_copy(deep=True) + cfg.n_iterations = 2 + cfg.output_filename = "ckpt_test.nc" + output_file = tmp_path / cfg.output_filename + + nc = correction._build_netcdf_structure(cfg, n_param_sets=2, n_gcp_pairs=2) + nc["rms_error_m"][0, 0] = 100.0 + nc["rms_error_m"][1, 0] = 150.0 + + correction._save_netcdf_checkpoint(nc, output_file, cfg, pair_idx_completed=0) + 
ckpt = output_file.parent / f"{output_file.stem}_checkpoint.nc" + assert ckpt.exists() + + loaded, completed = correction._load_checkpoint(output_file, cfg) + assert loaded is not None + assert completed == 1 # pair index 0 completed → 1 pair done + assert loaded["rms_error_m"][0, 0] == pytest.approx(100.0) + assert loaded["rms_error_m"][1, 0] == pytest.approx(150.0) + + correction._cleanup_checkpoint(output_file) + assert not ckpt.exists() + + +def test_checkpoint_load_missing_returns_none(clarreo_cfg, tmp_path): + """_load_checkpoint returns (None, 0) when no checkpoint file exists.""" + cfg = clarreo_cfg.model_copy(deep=True) + cfg.output_filename = "no_ckpt.nc" + output_file = tmp_path / cfg.output_filename + loaded, completed = correction._load_checkpoint(output_file, cfg) + assert loaded is None + assert completed == 0