From 754f28ab225636b7e592b57cd8504e30c4651d82 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Thu, 12 Mar 2026 16:44:16 -0700 Subject: [PATCH 01/27] Initial shot --- fme/downscaling/data/config.py | 7 +- fme/downscaling/data/datasets.py | 1 + fme/downscaling/data/static.py | 80 +++----------- fme/downscaling/data/test_static.py | 64 ++++++----- fme/downscaling/inference/test_inference.py | 23 ++-- fme/downscaling/models.py | 111 ++++++++++++++++++-- fme/downscaling/predict.py | 53 +++++----- fme/downscaling/predictors/composite.py | 4 + fme/downscaling/test_models.py | 78 ++++++++++++-- fme/downscaling/test_predict.py | 28 +++-- fme/downscaling/train.py | 1 + 11 files changed, 296 insertions(+), 154 deletions(-) diff --git a/fme/downscaling/data/config.py b/fme/downscaling/data/config.py index eda213fbc..14396a8e7 100644 --- a/fme/downscaling/data/config.py +++ b/fme/downscaling/data/config.py @@ -23,7 +23,11 @@ PairedBatchData, PairedGriddedData, ) -from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range +from fme.downscaling.data.utils import ( + ClosedInterval, + adjust_fine_coord_range, + get_latlon_coords_from_properties, +) from fme.downscaling.requirements import DataRequirements @@ -529,6 +533,7 @@ def build( dims=example.fine.latlon_coordinates.dims, variable_metadata=variable_metadata, all_times=all_times, + fine_coords=get_latlon_coords_from_properties(properties_fine), ) def _get_sampler( diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index bce3d59c4..2b6798f04 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -337,6 +337,7 @@ class PairedGriddedData: dims: list[str] variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex + fine_coords: LatLonCoordinates | None = None @property def loader(self) -> DataLoader[PairedBatchItem]: diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index 4a860174c..ae4721f8d 100644 
--- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -3,26 +3,17 @@ import torch import xarray as xr -from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device -from fme.downscaling.data.utils import ClosedInterval +from fme.downscaling.data.patching import Patch @dataclasses.dataclass class StaticInput: data: torch.Tensor - coords: LatLonCoordinates def __post_init__(self): if len(self.data.shape) != 2: raise ValueError(f"Topography data must be 2D. Got shape {self.data.shape}") - if self.data.shape[0] != len(self.coords.lat) or self.data.shape[1] != len( - self.coords.lon - ): - raise ValueError( - f"Static inputs data shape {self.data.shape} does not match lat/lon " - f"coordinates shape {(len(self.coords.lat), len(self.coords.lon))}" - ) @property def dim(self) -> int: @@ -32,44 +23,23 @@ def dim(self) -> int: def shape(self) -> tuple[int, int]: return self.data.shape - def subset_latlon( + def subset( self, - lat_interval: ClosedInterval, - lon_interval: ClosedInterval, + lat_slice: slice, + lon_slice: slice, ) -> "StaticInput": - lat_slice = lat_interval.slice_of(self.coords.lat) - lon_slice = lon_interval.slice_of(self.coords.lon) - return self._latlon_index_slice(lat_slice=lat_slice, lon_slice=lon_slice) + return StaticInput(data=self.data[lat_slice, lon_slice]) def to_device(self) -> "StaticInput": device = get_device() - return StaticInput( - data=self.data.to(device), - coords=LatLonCoordinates( - lat=self.coords.lat.to(device), - lon=self.coords.lon.to(device), - ), - ) + return StaticInput(data=self.data.to(device)) - def _latlon_index_slice( - self, - lat_slice: slice, - lon_slice: slice, - ) -> "StaticInput": - sliced_data = self.data[lat_slice, lon_slice] - sliced_latlon = LatLonCoordinates( - lat=self.coords.lat[lat_slice], - lon=self.coords.lon[lon_slice], - ) - return StaticInput( - data=sliced_data, - coords=sliced_latlon, - ) + def _apply_patch(self, patch: Patch): + return 
self.subset(lat_slice=patch.input_slice.y, lon_slice=patch.input_slice.x) def get_state(self) -> dict: return { "data": self.data.cpu(), - "coords": self.coords.get_state(), } @@ -93,17 +63,11 @@ def _get_normalized_static_input(path: str, field_name: str): f"unexpected shape {static_input.shape} for static input." "Currently, only lat/lon static input is supported." ) - lat_name, lon_name = static_input.dims[-2:] - coords = LatLonCoordinates( - lon=torch.tensor(static_input[lon_name].values), - lat=torch.tensor(static_input[lat_name].values), - ) static_input_normalized = (static_input - static_input.mean()) / static_input.std() return StaticInput( data=torch.tensor(static_input_normalized.values, dtype=torch.float32), - coords=coords, ) @@ -113,36 +77,28 @@ class StaticInputs: def __post_init__(self): for i, field in enumerate(self.fields[1:]): - if field.coords != self.fields[0].coords: + if field.shape != self.fields[0].shape: raise ValueError( - f"All StaticInput fields must have the same coordinates. " - f"Fields {i} and 0 do not match coordinates." + f"All StaticInput fields must have the same shape. " + f"Fields {i + 1} and 0 do not match shapes." 
) def __getitem__(self, index: int): return self.fields[index] - @property - def coords(self) -> LatLonCoordinates: - if len(self.fields) == 0: - raise ValueError("No fields in StaticInputs to get coordinates from.") - return self.fields[0].coords - @property def shape(self) -> tuple[int, int]: if len(self.fields) == 0: raise ValueError("No fields in StaticInputs to get shape from.") return self.fields[0].shape - def subset_latlon( + def subset( self, - lat_interval: ClosedInterval, - lon_interval: ClosedInterval, + lat_slice: slice, + lon_slice: slice, ) -> "StaticInputs": return StaticInputs( - fields=[ - field.subset_latlon(lat_interval, lon_interval) for field in self.fields - ] + fields=[field.subset(lat_slice, lon_slice) for field in self.fields] ) def to_device(self) -> "StaticInputs": @@ -159,10 +115,6 @@ def from_state(cls, state: dict) -> "StaticInputs": fields=[ StaticInput( data=field_state["data"], - coords=LatLonCoordinates( - lat=field_state["coords"]["lat"], - lon=field_state["coords"]["lon"], - ), ) for field_state in state["fields"] ] @@ -171,7 +123,7 @@ def from_state(cls, state: dict) -> "StaticInputs": def load_static_inputs( static_inputs_config: dict[str, str] | None, -) -> StaticInputs | None: +) -> "StaticInputs | None": """ Load normalized static inputs from a mapping of field names to file paths. Returns None if the input config is empty. 
diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index edfc0e080..104ffdc9f 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -1,26 +1,16 @@ import pytest import torch -from fme.core.coordinates import LatLonCoordinates - from .static import StaticInput, StaticInputs -from .utils import ClosedInterval @pytest.mark.parametrize( "init_args", [ pytest.param( - [ - torch.randn((1, 2, 2)), - LatLonCoordinates(torch.arange(2), torch.arange(2)), - ], + [torch.randn((1, 2, 2))], id="3d_data", ), - pytest.param( - [torch.randn((2, 2)), LatLonCoordinates(torch.arange(2), torch.arange(5))], - id="dim_size_mismatch", - ), ], ) def test_Topography_error_cases(init_args): @@ -28,37 +18,43 @@ def test_Topography_error_cases(init_args): StaticInput(*init_args) -def test_subset_latlon(): +def test_subset(): full_data_shape = (10, 10) - expected_slices = [slice(2, 6), slice(3, 8)] data = torch.randn(*full_data_shape) - coords = LatLonCoordinates( - lat=torch.linspace(0, 9, 10), lon=torch.linspace(0, 9, 10) - ) - topo = StaticInput(data=data, coords=coords) - lat_interval = ClosedInterval(2, 5) - lon_interval = ClosedInterval(3, 7) - subset_topo = topo.subset_latlon(lat_interval, lon_interval) - expected_lats = torch.tensor([2, 3, 4, 5], dtype=coords.lat.dtype) - expected_lons = torch.tensor([3, 4, 5, 6, 7], dtype=coords.lon.dtype) - expected_data = data[*expected_slices] - assert torch.equal(subset_topo.coords.lat, expected_lats) - assert torch.equal(subset_topo.coords.lon, expected_lons) - assert torch.allclose(subset_topo.data, expected_data) + topo = StaticInput(data=data) + lat_slice = slice(2, 6) + lon_slice = slice(3, 8) + subset_topo = topo.subset(lat_slice, lon_slice) + assert torch.allclose(subset_topo.data, data[lat_slice, lon_slice]) def test_StaticInputs_serialize(): data = torch.arange(16).reshape(4, 4) - topography = StaticInput( - data, - LatLonCoordinates(torch.arange(4), 
torch.arange(4)), - ) - land_frac = StaticInput( - data * -1.0, - LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) + topography = StaticInput(data) + land_frac = StaticInput(data * -1.0) static_inputs = StaticInputs([topography, land_frac]) state = static_inputs.get_state() + # Verify coords are NOT stored in state + assert "coords" not in state["fields"][0] static_inputs_reconstructed = StaticInputs.from_state(state) assert static_inputs_reconstructed[0].data.equal(static_inputs[0].data) assert static_inputs_reconstructed[1].data.equal(static_inputs[1].data) + + +def test_StaticInputs_serialize_backward_compat_with_coords(): + """from_state should silently ignore 'coords' key for old state dicts.""" + data = torch.arange(16, dtype=torch.float32).reshape(4, 4) + # Simulate old state dict format that included coords + old_state = { + "fields": [ + { + "data": data, + "coords": { + "lat": torch.arange(4, dtype=torch.float32), + "lon": torch.arange(4, dtype=torch.float32), + }, + } + ] + } + static_inputs = StaticInputs.from_state(old_state) + assert torch.equal(static_inputs[0].data, data) diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index ed97d2b60..c06b7b0cc 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -64,8 +64,7 @@ def mock_output_target(): def get_static_inputs(shape=(16, 16)): data = torch.randn(shape) - coords = LatLonCoordinates(lat=torch.arange(shape[0]), lon=torch.arange(shape[1])) - return StaticInputs([StaticInput(data=data, coords=coords)]) + return StaticInputs([StaticInput(data=data)]) # Tests for Downscaler initialization @@ -201,9 +200,11 @@ def test_run_target_generation_skips_padding_items( mock_output_target.data.get_generator.return_value = iter([mock_work_item]) mock_model.downscale_factor = 2 - mock_model.static_inputs.coords.lat = torch.arange(0, 18).float() - mock_model.static_inputs.coords.lon = torch.arange(0, 
18).float() - mock_model.static_inputs.subset_latlon.return_value.fields[0].data = torch.zeros(1) + mock_model.fine_coords = LatLonCoordinates( + lat=torch.arange(0, 18).float(), + lon=torch.arange(0, 18).float(), + ) + mock_model.static_inputs = None mock_model.generate_on_batch_no_target.return_value = { "var1": torch.zeros(1, 4, 16, 16), } @@ -273,8 +274,16 @@ def checkpointed_model_config( # loader_config is passed in to add static inputs into model # that correspond to the dataset coordinates - static_inputs = load_static_inputs({"HGTsfc": f"{data_paths.fine}/data.nc"}) - model = model_config.build(coarse_shape, 2, static_inputs=static_inputs) + fine_data_path = f"{data_paths.fine}/data.nc" + static_inputs = load_static_inputs({"HGTsfc": fine_data_path}) + ds = xr.open_dataset(fine_data_path) + fine_coords = LatLonCoordinates( + lat=torch.tensor(ds["lat"].values, dtype=torch.float32), + lon=torch.tensor(ds["lon"].values, dtype=torch.float32), + ) + model = model_config.build( + coarse_shape, 2, static_inputs=static_inputs, fine_coords=fine_coords + ) checkpoint_path = tmp_path / "model_checkpoint.pth" model.get_state() diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 2d0d1057e..1b5108b74 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -5,6 +5,7 @@ import dacite import torch +import xarray as xr from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device @@ -183,6 +184,7 @@ def build( downscale_factor: int, rename: dict[str, str] | None = None, static_inputs: StaticInputs | None = None, + fine_coords: LatLonCoordinates | None = None, ) -> "DiffusionModel": invert_rename = {v: k for k, v in (rename or {}).items()} orig_in_names = [invert_rename.get(name, name) for name in self.in_names] @@ -221,6 +223,7 @@ def build( downscale_factor=downscale_factor, sigma_data=sigma_data, static_inputs=static_inputs, + fine_coords=fine_coords, ) def get_state(self) -> Mapping[str, Any]: @@ -277,6 
+280,7 @@ def __init__( downscale_factor: int, sigma_data: float, static_inputs: StaticInputs | None = None, + fine_coords: LatLonCoordinates | None = None, ) -> None: """ Args: @@ -295,6 +299,8 @@ def __init__( model preconditioning. static_inputs: Static inputs to the model, loaded from the trainer config or checkpoint. Must be set when use_fine_topography is True. + fine_coords: Full-domain fine-resolution coordinates. Used as the + single coordinate authority for output spatial metadata. """ self.coarse_shape = coarse_shape self.downscale_factor = downscale_factor @@ -310,6 +316,14 @@ def __init__( self.static_inputs = ( static_inputs.to_device() if static_inputs is not None else None ) + self.fine_coords = fine_coords + if fine_coords is not None and static_inputs is not None: + expected = (len(fine_coords.lat), len(fine_coords.lon)) + if static_inputs.shape != expected: + raise ValueError( + f"static_inputs shape {static_inputs.shape} does not match " + f"fine_coords grid {expected}" + ) @property def modules(self) -> torch.nn.ModuleList: @@ -332,7 +346,13 @@ def _subset_static_inputs( "Static inputs must be provided for each batch when use of fine " "static inputs is enabled." ) - return self.static_inputs.subset_latlon(lat_interval, lon_interval) + if self.fine_coords is None: + raise ValueError( + "fine_coords must be set on the model to subset static inputs." + ) + lat_slice = lat_interval.slice_of(self.fine_coords.lat) + lon_slice = lon_interval.slice_of(self.fine_coords.lon) + return self.static_inputs.subset(lat_slice, lon_slice) def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: """Return fine-resolution coordinates matching the spatial extent of batch.""" @@ -541,23 +561,28 @@ def generate_on_batch_no_target( "Static inputs must be provided for each batch when use of fine " "static inputs is enabled." 
) + if self.fine_coords is None: + raise ValueError( + "fine_coords must be set on the model when use_fine_topography " + "is enabled." + ) coarse_lat = batch.latlon_coordinates.lat[0] coarse_lon = batch.latlon_coordinates.lon[0] fine_lat_interval = adjust_fine_coord_range( batch.lat_interval, full_coarse_coord=coarse_lat, - full_fine_coord=self.static_inputs.coords.lat, + full_fine_coord=self.fine_coords.lat, downscale_factor=self.downscale_factor, ) fine_lon_interval = adjust_fine_coord_range( batch.lon_interval, full_coarse_coord=coarse_lon, - full_fine_coord=self.static_inputs.coords.lon, + full_fine_coord=self.fine_coords.lon, downscale_factor=self.downscale_factor, ) - _static_inputs = self.static_inputs.subset_latlon( - fine_lat_interval, fine_lon_interval - ) + lat_slice = fine_lat_interval.slice_of(self.fine_coords.lat) + lon_slice = fine_lon_interval.slice_of(self.fine_coords.lon) + _static_inputs = self.static_inputs.subset(lat_slice, lon_slice) else: _static_inputs = None generated, _, _ = self.generate(batch.data, _static_inputs, n_samples) @@ -602,6 +627,9 @@ def get_state(self) -> Mapping[str, Any]: "coarse_shape": self.coarse_shape, "downscale_factor": self.downscale_factor, "static_inputs": static_inputs_state, + "fine_coords": ( + self.fine_coords.get_state() if self.fine_coords is not None else None + ), } @classmethod @@ -615,10 +643,33 @@ def from_state( static_inputs = StaticInputs.from_state(state["static_inputs"]).to_device() else: static_inputs = None + + # Load fine_coords: new checkpoints store it directly; old checkpoints + # that had static_inputs with coords can auto-migrate from raw state. 
+ if state.get("fine_coords") is not None: + fine_coords = LatLonCoordinates( + lat=state["fine_coords"]["lat"], + lon=state["fine_coords"]["lon"], + ) + elif ( + state.get("static_inputs") is not None + and len(state["static_inputs"].get("fields", [])) > 0 + and "coords" in state["static_inputs"]["fields"][0] + ): + # Backward compat: old checkpoints stored coords inside static_inputs fields + coords_state = state["static_inputs"]["fields"][0]["coords"] + fine_coords = LatLonCoordinates( + lat=coords_state["lat"], + lon=coords_state["lon"], + ) + else: + fine_coords = None + model = config.build( state["coarse_shape"], state["downscale_factor"], static_inputs=static_inputs, + fine_coords=fine_coords, ) model.module.load_state_dict(state["module"], strict=True) return model @@ -648,6 +699,9 @@ class CheckpointModelConfig: but the model requires static input data. Raises an error if the checkpoint already has static inputs from training. fine_topography_path: Deprecated. Use static_inputs instead. + fine_coordinates_path: Optional path to a netCDF/zarr file containing lat/lon + coordinates for the full fine domain. Used for old checkpoints that have + no static_inputs and no stored fine_coords. model_updates: Optional mapping of {key: new_value} model config updates to apply when loading the model. This is useful for running evaluation with updated parameters than at training time. 
Use with caution; not all @@ -658,6 +712,7 @@ class CheckpointModelConfig: rename: dict[str, str] | None = None static_inputs: dict[str, str] | None = None fine_topography_path: str | None = None + fine_coordinates_path: str | None = None model_updates: dict[str, Any] | None = None def __post_init__(self) -> None: @@ -688,6 +743,8 @@ def _checkpoint(self) -> Mapping[str, Any]: ] # backwards compatibility for models before static inputs serialization checkpoint_data["model"].setdefault("static_inputs", None) + # backwards compatibility for models before fine_coords serialization + checkpoint_data["model"].setdefault("fine_coords", None) self._checkpoint_data = checkpoint_data self._checkpoint_is_loaded = True @@ -696,6 +753,23 @@ def _checkpoint(self) -> Mapping[str, Any]: checkpoint_data["model"]["config"][k] = v return self._checkpoint_data + def _load_fine_coords_from_path(self, path: str) -> LatLonCoordinates: + if path.endswith(".zarr"): + ds = xr.open_zarr(path) + else: + ds = xr.open_dataset(path) + lat_name = next((n for n in ["lat", "latitude"] if n in ds.coords), None) + lon_name = next((n for n in ["lon", "longitude"] if n in ds.coords), None) + if lat_name is None or lon_name is None: + raise ValueError( + f"Could not find lat/lon coordinates in {path}. " + "Expected 'lat'/'latitude' and 'lon'/'longitude'." + ) + return LatLonCoordinates( + lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), + lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), + ) + def build( self, ) -> DiffusionModel: @@ -723,6 +797,31 @@ def build( static_inputs=static_inputs, ) model.module.load_state_dict(self._checkpoint["model"]["module"]) + + # Restore fine_coords: new checkpoints have it stored directly; old + # checkpoints may have coords embedded in static_inputs fields. 
+ model_state = self._checkpoint["model"] + if model_state.get("fine_coords") is not None: + fine_coords_state = model_state["fine_coords"] + model.fine_coords = LatLonCoordinates( + lat=fine_coords_state["lat"], + lon=fine_coords_state["lon"], + ) + elif ( + model_state.get("static_inputs") is not None + and len(model_state["static_inputs"].get("fields", [])) > 0 + and "coords" in model_state["static_inputs"]["fields"][0] + ): + coords_state = model_state["static_inputs"]["fields"][0]["coords"] + model.fine_coords = LatLonCoordinates( + lat=coords_state["lat"], + lon=coords_state["lon"], + ) + elif self.fine_coordinates_path is not None: + model.fine_coords = self._load_fine_coords_from_path( + self.fine_coordinates_path + ) + return model @property diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index b4da331aa..a93361aa0 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -21,6 +21,7 @@ ClosedInterval, DataLoaderConfig, GriddedData, + adjust_fine_coord_range, enforce_lat_bounds, ) from fme.downscaling.models import CheckpointModelConfig, DiffusionModel @@ -29,31 +30,6 @@ from fme.downscaling.typing_ import FineResCoarseResPair -def _downscale_coord(coord: torch.tensor, downscale_factor: int): - """ - This is a bandaid fix for the issue where BatchData does not - contain coords for the topography, which is fine-res in the no-target - generation case. The SampleAggregator requires the fine-res coords - for the predictions. - - TODO: remove after topography refactors to have its own data container. 
- """ - if len(coord.shape) != 1: - raise ValueError("coord tensor to downscale must be 1d") - spacing = coord[1] - coord[0] - # Compute edges from midpoints - first_edge = coord[0] - spacing / 2 - last_edge = coord[-1] + spacing / 2 - - # Subdivide edges - step = spacing / downscale_factor - new_edges = torch.arange(first_edge, last_edge + step / 2, step) - - # Compute new midpoints - coord_new = (new_edges[:-1] + new_edges[1:]) / 2 - return coord_new.to(device=coord.device, dtype=coord.dtype) - - @dataclasses.dataclass class EventConfig: name: str @@ -145,9 +121,32 @@ def run(self): logging.info(f"Running {self.event_name} event downscaling...") batch = next(iter(self.data.get_generator())) coarse_coords = batch[0].latlon_coordinates + if self.model.fine_coords is None: + raise ValueError( + "Model fine_coords must be set for event downscaling output " + "coordinates." + ) + coarse_lat = coarse_coords.lat + coarse_lon = coarse_coords.lon + lat_interval = ClosedInterval(coarse_lat.min().item(), coarse_lat.max().item()) + lon_interval = ClosedInterval(coarse_lon.min().item(), coarse_lon.max().item()) + fine_lat_interval = adjust_fine_coord_range( + lat_interval, + full_coarse_coord=coarse_lat, + full_fine_coord=self.model.fine_coords.lat, + downscale_factor=self.model.downscale_factor, + ) + fine_lon_interval = adjust_fine_coord_range( + lon_interval, + full_coarse_coord=coarse_lon, + full_fine_coord=self.model.fine_coords.lon, + downscale_factor=self.model.downscale_factor, + ) + lat_slice = fine_lat_interval.slice_of(self.model.fine_coords.lat) + lon_slice = fine_lon_interval.slice_of(self.model.fine_coords.lon) fine_coords = LatLonCoordinates( - lat=_downscale_coord(coarse_coords.lat, self.model.downscale_factor), - lon=_downscale_coord(coarse_coords.lon, self.model.downscale_factor), + lat=self.model.fine_coords.lat[lat_slice], + lon=self.model.fine_coords.lon[lon_slice], ) sample_agg = SampleAggregator( coarse=batch[0].data, diff --git 
a/fme/downscaling/predictors/composite.py b/fme/downscaling/predictors/composite.py index 05add74a0..405ac1eed 100644 --- a/fme/downscaling/predictors/composite.py +++ b/fme/downscaling/predictors/composite.py @@ -85,6 +85,10 @@ def static_inputs(self): def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: return self.model.get_fine_coords_for_batch(batch) + + @property + def fine_coords(self): + return self.model.fine_coords def _get_patches( self, coarse_yx_extent, fine_yx_extent diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 864f511f3..dc916014c 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -124,30 +124,37 @@ def make_paired_batch_data( def make_static_inputs(fine_shape: tuple[int, int]) -> StaticInputs: - """Create StaticInputs with proper monotonic coordinates for given shape.""" - lat_size, lon_size = fine_shape + """Create StaticInputs for given shape.""" return StaticInputs( fields=[ StaticInput( torch.ones(*fine_shape, device=get_device()), - LatLonCoordinates( - lat=_get_monotonic_coordinate(lat_size, stop=lat_size), - lon=_get_monotonic_coordinate(lon_size, stop=lon_size), - ), ) ] ) +def make_fine_coords(fine_shape: tuple[int, int]) -> LatLonCoordinates: + """Create LatLonCoordinates for given fine shape.""" + lat_size, lon_size = fine_shape + return LatLonCoordinates( + lat=_get_monotonic_coordinate(lat_size, stop=lat_size), + lon=_get_monotonic_coordinate(lon_size, stop=lon_size), + ) + + def test_module_serialization(tmp_path): coarse_shape = (8, 16) - static_inputs = make_static_inputs((16, 32)) + fine_shape = (16, 32) + static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, predict_residual=True, use_fine_topography=False, static_inputs=static_inputs, + fine_coords=fine_coords, ) model_from_state = DiffusionModel.from_state( 
model.get_state(), @@ -158,6 +165,9 @@ def test_module_serialization(tmp_path): model.module.parameters(), model_from_state.module.parameters() ) ) + assert model_from_state.fine_coords is not None + assert torch.equal(model_from_state.fine_coords.lat, fine_coords.lat) + assert torch.equal(model_from_state.fine_coords.lon, fine_coords.lon) torch.save(model.get_state(), tmp_path / "test.ckpt") model_from_disk = DiffusionModel.from_state( @@ -174,6 +184,9 @@ def test_module_serialization(tmp_path): assert torch.equal( loaded_static_inputs.fields[0].data, static_inputs.fields[0].data ) + assert model_from_disk.fine_coords is not None + assert torch.equal(model_from_disk.fine_coords.lat, fine_coords.lat) + assert torch.equal(model_from_disk.fine_coords.lon, fine_coords.lon) def test_from_state_backward_compat_fine_topography(): @@ -181,21 +194,25 @@ def test_from_state_backward_compat_fine_topography(): fine_shape = (16, 32) downscale_factor = 2 static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, + fine_coords=fine_coords, ) # Simulate old checkpoint format: static_inputs not serialized state = model.get_state() state["static_inputs"] = None + state["fine_coords"] = None # Should load correctly via the elif use_fine_topography branch (+1 channel) model_from_old_state = DiffusionModel.from_state(state) assert model_from_old_state.static_inputs is None + assert model_from_old_state.fine_coords is None assert all( torch.equal(p1, p2) for p1, p2 in zip( @@ -209,12 +226,40 @@ def test_from_state_backward_compat_fine_topography(): model_from_old_state.generate_on_batch(batch) +def test_from_state_backward_compat_migrates_fine_coords_from_old_static_inputs(): + """Old checkpoints that stored coords in static_inputs fields should have + fine_coords auto-migrated on 
from_state.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + downscale_factor = 2 + static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=downscale_factor, + predict_residual=True, + use_fine_topography=True, + static_inputs=static_inputs, + fine_coords=fine_coords, + ) + state = model.get_state() + # Simulate old format: fine_coords absent, but static_inputs fields have coords + del state["fine_coords"] + state["static_inputs"]["fields"][0]["coords"] = fine_coords.get_state() + + model_from_old_state = DiffusionModel.from_state(state) + assert model_from_old_state.fine_coords is not None + assert torch.equal(model_from_old_state.fine_coords.lat, fine_coords.lat) + assert torch.equal(model_from_old_state.fine_coords.lon, fine_coords.lon) + + def _get_diffusion_model( coarse_shape, downscale_factor, predict_residual=True, use_fine_topography=True, static_inputs=None, + fine_coords=None, ): normalizer = PairedNormalizationConfig( NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), @@ -237,7 +282,12 @@ def _get_diffusion_model( num_diffusion_generation_steps=3, predict_residual=predict_residual, use_fine_topography=use_fine_topography, - ).build(coarse_shape, downscale_factor, static_inputs=static_inputs) + ).build( + coarse_shape, + downscale_factor, + static_inputs=static_inputs, + fine_coords=fine_coords, + ) @pytest.mark.parametrize("predict_residual", [True, False]) @@ -248,9 +298,11 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph batch_size = 2 if use_fine_topography: static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) else: static_inputs = None + fine_coords = None batch = get_mock_paired_batch( [batch_size, *coarse_shape], [batch_size, *fine_shape] ) @@ -260,6 +312,7 @@ def 
test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph predict_residual=predict_residual, use_fine_topography=use_fine_topography, static_inputs=static_inputs, + fine_coords=fine_coords, ) assert model._get_fine_shape(coarse_shape) == fine_shape @@ -399,12 +452,14 @@ def test_DiffusionModel_generate_on_batch_no_target(): coarse_shape = (16, 16) downscale_factor = 2 static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, + fine_coords=fine_coords, ) batch_size = 2 @@ -435,7 +490,9 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): # Full fine domain: 64x64 covers inputs for both (8,8) and (32,32) coarse inputs # with a downscaling factor of 2 full_fine_size = 64 - static_inputs = make_static_inputs((full_fine_size, full_fine_size)) + full_fine_shape = (full_fine_size, full_fine_size) + static_inputs = make_static_inputs(full_fine_shape) + fine_coords = make_fine_coords(full_fine_shape) # need to build with static inputs to get the correct n_in_channels model = _get_diffusion_model( coarse_shape=coarse_shape, @@ -443,6 +500,7 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, + fine_coords=fine_coords, ) n_ensemble = 2 batch_size = 2 @@ -590,12 +648,14 @@ def test_checkpoint_model_build_raises_when_checkpoint_has_static_inputs(tmp_pat coarse_shape = (8, 16) fine_shape = (16, 32) static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, + fine_coords=fine_coords, ) checkpoint_path = tmp_path / "test.ckpt" torch.save({"model": 
model.get_state()}, checkpoint_path) diff --git a/fme/downscaling/test_predict.py b/fme/downscaling/test_predict.py index ce253ab14..5054fb538 100644 --- a/fme/downscaling/test_predict.py +++ b/fme/downscaling/test_predict.py @@ -2,8 +2,10 @@ import pytest import torch +import xarray as xr import yaml +from fme.core.coordinates import LatLonCoordinates from fme.core.loss import LossConfig from fme.core.normalizer import NormalizationConfig from fme.core.testing.wandb import mock_wandb @@ -64,6 +66,18 @@ def get_model_config( ) +def load_fine_coords_from_path(path: str) -> LatLonCoordinates: + """Load lat/lon coordinates from a netCDF or zarr data file.""" + if path.endswith(".zarr"): + ds = xr.open_zarr(path) + else: + ds = xr.open_dataset(path) + return LatLonCoordinates( + lat=torch.tensor(ds["lat"].values, dtype=torch.float32), + lon=torch.tensor(ds["lon"].values, dtype=torch.float32), + ) + + def create_predictor_config( tmp_path, n_samples: int, @@ -97,7 +111,7 @@ def create_predictor_config( out_path = tmp_path / "predictor-config.yaml" with open(out_path, "w") as file: yaml.dump(config, file) - return out_path, f"{paths.fine}/data.nc" + return out_path, paths def test_predictor_runs(tmp_path, very_fast_only: bool): @@ -106,15 +120,18 @@ def test_predictor_runs(tmp_path, very_fast_only: bool): n_samples = 2 coarse_shape = (4, 4) downscale_factor = 2 - predictor_config_path, fine_data_path = create_predictor_config( + predictor_config_path, paths = create_predictor_config( tmp_path, n_samples, ) + fine_data_path = f"{paths.fine}/data.nc" + fine_coords = load_fine_coords_from_path(fine_data_path) model_config = get_model_config(coarse_shape, downscale_factor=downscale_factor) model = model_config.build( coarse_shape=coarse_shape, downscale_factor=downscale_factor, static_inputs=load_static_inputs({"HGTsfc": fine_data_path}), + fine_coords=fine_coords, ) with open(predictor_config_path) as f: predictor_config = yaml.safe_load(f) @@ -147,7 +164,7 @@ def 
test_predictor_renaming( coarse_shape = (4, 4) downscale_factor = 2 renaming = {"var0": "var0_renamed", "var1": "var1_renamed"} - predictor_config_path, fine_data_path = create_predictor_config( + predictor_config_path, paths = create_predictor_config( tmp_path, n_samples, model_renaming=renaming, @@ -155,13 +172,12 @@ def test_predictor_renaming( "rename": {"var0": "var0_renamed", "var1": "var1_renamed"} }, ) + fine_coords = load_fine_coords_from_path(f"{paths.fine}/data.nc") model_config = get_model_config( coarse_shape, downscale_factor, use_fine_topography=False ) model = model_config.build( - coarse_shape=coarse_shape, - downscale_factor=2, - static_inputs=load_static_inputs({"HGTsfc": fine_data_path}), + coarse_shape=coarse_shape, downscale_factor=2, fine_coords=fine_coords ) with open(predictor_config_path) as f: predictor_config = yaml.safe_load(f) diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index a1dd08fde..5a0c25666 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -443,6 +443,7 @@ def build(self) -> Trainer: model_coarse_shape, train_data.downscale_factor, static_inputs=static_inputs, + fine_coords=train_data.fine_coords, ) optimization = self.optimization.build( From 7c5b4c5ae5ae37abce7e98f76e1f679c7eb657e1 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 13 Mar 2026 14:58:38 -0700 Subject: [PATCH 02/27] Make fine coords required --- fme/downscaling/models.py | 101 ++++++++++++++++++++++++--------- fme/downscaling/predict.py | 5 -- fme/downscaling/test_models.py | 19 +++++-- 3 files changed, 87 insertions(+), 38 deletions(-) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 1b5108b74..94f0c5a81 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -182,9 +182,9 @@ def build( self, coarse_shape: tuple[int, int], downscale_factor: int, + fine_coords: LatLonCoordinates, rename: dict[str, str] | None = None, static_inputs: StaticInputs | None = None, - fine_coords: 
LatLonCoordinates | None = None, ) -> "DiffusionModel": invert_rename = {v: k for k, v in (rename or {}).items()} orig_in_names = [invert_rename.get(name, name) for name in self.in_names] @@ -279,8 +279,8 @@ def __init__( coarse_shape: tuple[int, int], downscale_factor: int, sigma_data: float, + fine_coords: LatLonCoordinates, static_inputs: StaticInputs | None = None, - fine_coords: LatLonCoordinates | None = None, ) -> None: """ Args: @@ -297,10 +297,12 @@ def __init__( coarse to fine. sigma_data: The standard deviation of the data, used for diffusion model preconditioning. + fine_coords: the full-domain fine-resolution coordinates to use + for spatial metadata in the model output. static_inputs: Static inputs to the model, loaded from the trainer config or checkpoint. Must be set when use_fine_topography is True. - fine_coords: Full-domain fine-resolution coordinates. Used as the - single coordinate authority for output spatial metadata. + Expected to be on the same coordinate grid as fine_coords for now, + but this may be relaxed in the future. """ self.coarse_shape = coarse_shape self.downscale_factor = downscale_factor @@ -317,12 +319,13 @@ def __init__( static_inputs.to_device() if static_inputs is not None else None ) self.fine_coords = fine_coords - if fine_coords is not None and static_inputs is not None: + if static_inputs is not None: expected = (len(fine_coords.lat), len(fine_coords.lon)) if static_inputs.shape != expected: raise ValueError( - f"static_inputs shape {static_inputs.shape} does not match " - f"fine_coords grid {expected}" + f"static_inputs are expected to be on the same coordinate grid as " + f"fine_coords. StaticInputs shape {static_inputs.shape} does not " + f"match fine_coords grid {expected}." ) @property @@ -346,10 +349,6 @@ def _subset_static_inputs( "Static inputs must be provided for each batch when use of fine " "static inputs is enabled." 
) - if self.fine_coords is None: - raise ValueError( - "fine_coords must be set on the model to subset static inputs." - ) lat_slice = lat_interval.slice_of(self.fine_coords.lat) lon_slice = lon_interval.slice_of(self.fine_coords.lon) return self.static_inputs.subset(lat_slice, lon_slice) @@ -561,11 +560,6 @@ def generate_on_batch_no_target( "Static inputs must be provided for each batch when use of fine " "static inputs is enabled." ) - if self.fine_coords is None: - raise ValueError( - "fine_coords must be set on the model when use_fine_topography " - "is enabled." - ) coarse_lat = batch.latlon_coordinates.lat[0] coarse_lon = batch.latlon_coordinates.lon[0] fine_lat_interval = adjust_fine_coord_range( @@ -627,9 +621,7 @@ def get_state(self) -> Mapping[str, Any]: "coarse_shape": self.coarse_shape, "downscale_factor": self.downscale_factor, "static_inputs": static_inputs_state, - "fine_coords": ( - self.fine_coords.get_state() if self.fine_coords is not None else None - ), + "fine_coords": self.fine_coords.get_state(), } @classmethod @@ -646,24 +638,31 @@ def from_state( # Load fine_coords: new checkpoints store it directly; old checkpoints # that had static_inputs with coords can auto-migrate from raw state. - if state.get("fine_coords") is not None: + fine_coords = state.get("fine_coords") + if fine_coords is not None: + # TODO: Why doesn't LatLonCoordinates have a from_state method? 
fine_coords = LatLonCoordinates( lat=state["fine_coords"]["lat"], lon=state["fine_coords"]["lon"], ) elif ( - state.get("static_inputs") is not None - and len(state["static_inputs"].get("fields", [])) > 0 + static_inputs is not None + and static_inputs.fields and "coords" in state["static_inputs"]["fields"][0] ): - # Backward compat: old checkpoints stored coords inside static_inputs fields + # Backward compat: old checkpoints with static inputs stored coords inside + # static_inputs fields[0]["coords"] coords_state = state["static_inputs"]["fields"][0]["coords"] fine_coords = LatLonCoordinates( lat=coords_state["lat"], lon=coords_state["lon"], ) else: - fine_coords = None + raise ValueError( + "No fine coordinates found in checkpoint state and no static inputs " + " were available to infer them. fine_coords must be serialized with the" + " checkpoint to resume from training." + ) model = config.build( state["coarse_shape"], @@ -758,12 +757,16 @@ def _load_fine_coords_from_path(self, path: str) -> LatLonCoordinates: ds = xr.open_zarr(path) else: ds = xr.open_dataset(path) - lat_name = next((n for n in ["lat", "latitude"] if n in ds.coords), None) - lon_name = next((n for n in ["lon", "longitude"] if n in ds.coords), None) + lat_name = next( + (n for n in ["lat", "latitude", "grid_yt"] if n in ds.coords), None + ) + lon_name = next( + (n for n in ["lon", "longitude", "grid_xt"] if n in ds.coords), None + ) if lat_name is None or lon_name is None: raise ValueError( f"Could not find lat/lon coordinates in {path}. " - "Expected 'lat'/'latitude' and 'lon'/'longitude'." + "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'." 
) return LatLonCoordinates( lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), @@ -788,11 +791,55 @@ def build( static_inputs = load_static_inputs(self.static_inputs) else: static_inputs = None + + fine_coords: LatLonCoordinates + has_fine_coords = self._checkpoint["model"]["fine_coords"] is not None + has_static_input_coords = ( + self._checkpoint["model"]["static_inputs"] is not None + and self._checkpoint["model"]["static_inputs"]["fields"][0].get("coords") + is not None + ) + # TODO: simplify with static input refactor that deisables empty StaticInputs + if ( + has_fine_coords + or has_static_input_coords + and self.fine_coordinates_path is not None + ): + raise ValueError( + "The model checkpoint already has fine coordinates from training. " + "fine_coordinates_path should not be provided in checkpoint model " + "config." + ) + elif has_fine_coords: + fine_coords_state = self._checkpoint["model"]["fine_coords"] + fine_coords = LatLonCoordinates( + lat=fine_coords_state["lat"], + lon=fine_coords_state["lon"], + ) + elif has_static_input_coords: + coords_state = self._checkpoint["model"]["static_inputs"]["fields"][0][ + "coords" + ] + fine_coords = LatLonCoordinates( + lat=coords_state["lat"], + lon=coords_state["lon"], + ) + elif self.fine_coordinates_path is not None: + fine_coords = self._load_fine_coords_from_path(self.fine_coordinates_path) + else: + raise ValueError( + "No fine coordinates found in checkpoint state and no static inputs " + " were available to infer them. fine_coordinates_path must be provided " + "in the checkpoint model configuration to load fine coordinates from " + "the provided path." 
+ ) + model = _CheckpointModelConfigSelector.from_state( self._checkpoint["model"]["config"] ).build( coarse_shape=self._checkpoint["model"]["coarse_shape"], downscale_factor=self._checkpoint["model"]["downscale_factor"], + fine_coords=fine_coords, rename=self._rename, static_inputs=static_inputs, ) diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index a93361aa0..449e84368 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -121,11 +121,6 @@ def run(self): logging.info(f"Running {self.event_name} event downscaling...") batch = next(iter(self.data.get_generator())) coarse_coords = batch[0].latlon_coordinates - if self.model.fine_coords is None: - raise ValueError( - "Model fine_coords must be set for event downscaling output " - "coordinates." - ) coarse_lat = coarse_coords.lat coarse_lon = coarse_coords.lon lat_interval = ClosedInterval(coarse_lat.min().item(), coarse_lat.max().item()) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index dc916014c..d801eec25 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -204,15 +204,15 @@ def test_from_state_backward_compat_fine_topography(): fine_coords=fine_coords, ) - # Simulate old checkpoint format: static_inputs not serialized + # Simulate old checkpoint format: static_inputs not serialized, fine_coords still + # present state = model.get_state() state["static_inputs"] = None - state["fine_coords"] = None # Should load correctly via the elif use_fine_topography branch (+1 channel) model_from_old_state = DiffusionModel.from_state(state) assert model_from_old_state.static_inputs is None - assert model_from_old_state.fine_coords is None + assert torch.equal(model_from_old_state.fine_coords.lat, fine_coords.lat) assert all( torch.equal(p1, p2) for p1, p2 in zip( @@ -259,12 +259,18 @@ def _get_diffusion_model( predict_residual=True, use_fine_topography=True, static_inputs=None, - fine_coords=None, + fine_coords: 
LatLonCoordinates | None = None, ): normalizer = PairedNormalizationConfig( NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), ) + if fine_coords is None: + fine_shape = ( + coarse_shape[0] * downscale_factor, + coarse_shape[1] * downscale_factor, + ) + fine_coords = make_fine_coords(fine_shape) return DiffusionModelConfig( module=DiffusionModuleRegistrySelector( @@ -296,13 +302,12 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph coarse_shape = (8, 16) fine_shape = (16, 32) batch_size = 2 + fine_coords = make_fine_coords(fine_shape) if use_fine_topography: static_inputs = make_static_inputs(fine_shape) - fine_coords = make_fine_coords(fine_shape) batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) else: static_inputs = None - fine_coords = None batch = get_mock_paired_batch( [batch_size, *coarse_shape], [batch_size, *fine_shape] ) @@ -436,6 +441,7 @@ def test_model_error_cases(): ).build( coarse_shape, upscaling_factor, + fine_coords=make_fine_coords(fine_shape), ) batch = get_mock_paired_batch( [batch_size, *coarse_shape], [batch_size, *fine_shape] @@ -552,6 +558,7 @@ def test_lognorm_noise_backwards_compatibility(): model = model_config.build( (32, 32), 2, + fine_coords=make_fine_coords((64, 64)), ) state = model.get_state() From c7f58d9430eae8f95f1a9d6ece66e4858d944d06 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 13 Mar 2026 15:12:27 -0700 Subject: [PATCH 03/27] Fine coords required for paired data --- fme/downscaling/data/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index 2b6798f04..b11d0e49d 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -337,7 +337,7 @@ class PairedGriddedData: dims: list[str] variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex - fine_coords: 
LatLonCoordinates | None = None + fine_coords: LatLonCoordinates @property def loader(self) -> DataLoader[PairedBatchItem]: From 563b57ac00c6dcbdc4d7c82ad6aafed54f6bcf88 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 20:39:08 -0700 Subject: [PATCH 04/27] Mesh with previous updates in refactor pr --- fme/downscaling/data/static.py | 2 +- fme/downscaling/models.py | 60 ++++++++-------------------------- fme/downscaling/test_models.py | 42 +++++++++--------------- 3 files changed, 30 insertions(+), 74 deletions(-) diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index ae4721f8d..48ab54ed7 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -123,7 +123,7 @@ def from_state(cls, state: dict) -> "StaticInputs": def load_static_inputs( static_inputs_config: dict[str, str] | None, -) -> "StaticInputs | None": +) -> StaticInputs | None: """ Load normalized static inputs from a mapping of field names to file paths. Returns None if the input config is empty. 
diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 94f0c5a81..11ec1b1db 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -318,7 +318,11 @@ def __init__( self.static_inputs = ( static_inputs.to_device() if static_inputs is not None else None ) - self.fine_coords = fine_coords + device = get_device() + self.fine_coords = LatLonCoordinates( + lat=fine_coords.lat.to(device), + lon=fine_coords.lon.to(device), + ) if static_inputs is not None: expected = (len(fine_coords.lat), len(fine_coords.lon)) if static_inputs.shape != expected: @@ -355,32 +359,23 @@ def _subset_static_inputs( def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: """Return fine-resolution coordinates matching the spatial extent of batch.""" - if self.static_inputs is None: - raise ValueError( - "Model is missing static inputs, which are required to determine " - "the coordinate information for the output dataset." - ) coarse_lat = batch.latlon_coordinates.lat[0] coarse_lon = batch.latlon_coordinates.lon[0] fine_lat_interval = adjust_fine_coord_range( batch.lat_interval, full_coarse_coord=coarse_lat, - full_fine_coord=self.static_inputs.coords.lat, + full_fine_coord=self.fine_coords.lat, downscale_factor=self.downscale_factor, ) fine_lon_interval = adjust_fine_coord_range( batch.lon_interval, full_coarse_coord=coarse_lon, - full_fine_coord=self.static_inputs.coords.lon, + full_fine_coord=self.fine_coords.lon, downscale_factor=self.downscale_factor, ) return LatLonCoordinates( - lat=self.static_inputs.coords.lat[ - fine_lat_interval.slice_of(self.static_inputs.coords.lat) - ], - lon=self.static_inputs.coords.lon[ - fine_lon_interval.slice_of(self.static_inputs.coords.lon) - ], + lat=self.fine_coords.lat[fine_lat_interval.slice_of(self.fine_coords.lat)], + lon=self.fine_coords.lon[fine_lon_interval.slice_of(self.fine_coords.lon)], ) @property @@ -574,9 +569,9 @@ def generate_on_batch_no_target( full_fine_coord=self.fine_coords.lon, 
downscale_factor=self.downscale_factor, ) - lat_slice = fine_lat_interval.slice_of(self.fine_coords.lat) - lon_slice = fine_lon_interval.slice_of(self.fine_coords.lon) - _static_inputs = self.static_inputs.subset(lat_slice, lon_slice) + _static_inputs = self._subset_static_inputs( + fine_lat_interval, fine_lon_interval + ) else: _static_inputs = None generated, _, _ = self.generate(batch.data, _static_inputs, n_samples) @@ -801,10 +796,8 @@ def build( ) # TODO: simplify with static input refactor that deisables empty StaticInputs if ( - has_fine_coords - or has_static_input_coords - and self.fine_coordinates_path is not None - ): + has_fine_coords or has_static_input_coords + ) and self.fine_coordinates_path is not None: raise ValueError( "The model checkpoint already has fine coordinates from training. " "fine_coordinates_path should not be provided in checkpoint model " @@ -844,31 +837,6 @@ def build( static_inputs=static_inputs, ) model.module.load_state_dict(self._checkpoint["model"]["module"]) - - # Restore fine_coords: new checkpoints have it stored directly; old - # checkpoints may have coords embedded in static_inputs fields. 
- model_state = self._checkpoint["model"] - if model_state.get("fine_coords") is not None: - fine_coords_state = model_state["fine_coords"] - model.fine_coords = LatLonCoordinates( - lat=fine_coords_state["lat"], - lon=fine_coords_state["lon"], - ) - elif ( - model_state.get("static_inputs") is not None - and len(model_state["static_inputs"].get("fields", [])) > 0 - and "coords" in model_state["static_inputs"]["fields"][0] - ): - coords_state = model_state["static_inputs"]["fields"][0]["coords"] - model.fine_coords = LatLonCoordinates( - lat=coords_state["lat"], - lon=coords_state["lon"], - ) - elif self.fine_coordinates_path is not None: - model.fine_coords = self._load_fine_coords_from_path( - self.fine_coordinates_path - ) - return model @property diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index d801eec25..803baefda 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -166,8 +166,8 @@ def test_module_serialization(tmp_path): ) ) assert model_from_state.fine_coords is not None - assert torch.equal(model_from_state.fine_coords.lat, fine_coords.lat) - assert torch.equal(model_from_state.fine_coords.lon, fine_coords.lon) + assert torch.equal(model_from_state.fine_coords.lat.cpu(), fine_coords.lat.cpu()) + assert torch.equal(model_from_state.fine_coords.lon.cpu(), fine_coords.lon.cpu()) torch.save(model.get_state(), tmp_path / "test.ckpt") model_from_disk = DiffusionModel.from_state( @@ -185,8 +185,8 @@ def test_module_serialization(tmp_path): loaded_static_inputs.fields[0].data, static_inputs.fields[0].data ) assert model_from_disk.fine_coords is not None - assert torch.equal(model_from_disk.fine_coords.lat, fine_coords.lat) - assert torch.equal(model_from_disk.fine_coords.lon, fine_coords.lon) + assert torch.equal(model_from_disk.fine_coords.lat.cpu(), fine_coords.lat.cpu()) + assert torch.equal(model_from_disk.fine_coords.lon.cpu(), fine_coords.lon.cpu()) def 
test_from_state_backward_compat_fine_topography(): @@ -212,7 +212,9 @@ def test_from_state_backward_compat_fine_topography(): # Should load correctly via the elif use_fine_topography branch (+1 channel) model_from_old_state = DiffusionModel.from_state(state) assert model_from_old_state.static_inputs is None - assert torch.equal(model_from_old_state.fine_coords.lat, fine_coords.lat) + assert torch.equal( + model_from_old_state.fine_coords.lat.cpu(), fine_coords.lat.cpu() + ) assert all( torch.equal(p1, p2) for p1, p2 in zip( @@ -249,8 +251,12 @@ def test_from_state_backward_compat_migrates_fine_coords_from_old_static_inputs( model_from_old_state = DiffusionModel.from_state(state) assert model_from_old_state.fine_coords is not None - assert torch.equal(model_from_old_state.fine_coords.lat, fine_coords.lat) - assert torch.equal(model_from_old_state.fine_coords.lon, fine_coords.lon) + assert torch.equal( + model_from_old_state.fine_coords.lat.cpu(), fine_coords.lat.cpu() + ) + assert torch.equal( + model_from_old_state.fine_coords.lon.cpu(), fine_coords.lon.cpu() + ) def _get_diffusion_model( @@ -619,30 +625,12 @@ def test_get_fine_coords_for_batch(): result = model.get_fine_coords_for_batch(batch) - expected_lat = model.static_inputs.coords.lat[4:12] - expected_lon = model.static_inputs.coords.lon[8:24] - # model.static_inputs has been moved to device; index into it directly - # to match devices + expected_lat = model.fine_coords.lat[4:12] + expected_lon = model.fine_coords.lon[8:24] assert torch.allclose(result.lat, expected_lat) assert torch.allclose(result.lon, expected_lon) -def test_get_fine_coords_for_batch_raises_without_static_inputs(): - model = _get_diffusion_model( - coarse_shape=(16, 16), - downscale_factor=2, - use_fine_topography=False, - static_inputs=None, - ) - batch = make_batch_data( - (1, 16, 16), - _get_monotonic_coordinate(16, stop=16).tolist(), - _get_monotonic_coordinate(16, stop=16).tolist(), - ) - with pytest.raises(ValueError, match="missing 
static inputs"): - model.get_fine_coords_for_batch(batch) - - def test_checkpoint_config_topography_raises(): with pytest.raises(ValueError): CheckpointModelConfig( From f4218dd66b8afcb3d3350fedb84d821ab6256437 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 21:19:02 -0700 Subject: [PATCH 05/27] Simplify event downscaler coordinate in run() --- fme/downscaling/predict.py | 27 ++++----------------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index 449e84368..d56a16e56 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -21,7 +21,6 @@ ClosedInterval, DataLoaderConfig, GriddedData, - adjust_fine_coord_range, enforce_lat_bounds, ) from fme.downscaling.models import CheckpointModelConfig, DiffusionModel @@ -120,29 +119,11 @@ def generation_model(self): def run(self): logging.info(f"Running {self.event_name} event downscaling...") batch = next(iter(self.data.get_generator())) - coarse_coords = batch[0].latlon_coordinates - coarse_lat = coarse_coords.lat - coarse_lon = coarse_coords.lon - lat_interval = ClosedInterval(coarse_lat.min().item(), coarse_lat.max().item()) - lon_interval = ClosedInterval(coarse_lon.min().item(), coarse_lon.max().item()) - fine_lat_interval = adjust_fine_coord_range( - lat_interval, - full_coarse_coord=coarse_lat, - full_fine_coord=self.model.fine_coords.lat, - downscale_factor=self.model.downscale_factor, - ) - fine_lon_interval = adjust_fine_coord_range( - lon_interval, - full_coarse_coord=coarse_lon, - full_fine_coord=self.model.fine_coords.lon, - downscale_factor=self.model.downscale_factor, - ) - lat_slice = fine_lat_interval.slice_of(self.model.fine_coords.lat) - lon_slice = fine_lon_interval.slice_of(self.model.fine_coords.lon) - fine_coords = LatLonCoordinates( - lat=self.model.fine_coords.lat[lat_slice], - lon=self.model.fine_coords.lon[lon_slice], + coarse_coords = LatLonCoordinates( + 
lat=batch[0].latlon_coordinates.lat, + lon=batch[0].latlon_coordinates.lon, ) + fine_coords = self.model.get_fine_coords_for_batch(batch) sample_agg = SampleAggregator( coarse=batch[0].data, latlon_coordinates=FineResCoarseResPair( From 68cab616b230dde8263cc3504a2bc6fc1ba55ecf Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 21:20:00 -0700 Subject: [PATCH 06/27] use batch latlon coardinates for coarse --- fme/downscaling/predict.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index d56a16e56..d35d0bc1a 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -9,7 +9,6 @@ import yaml from fme.core.cli import prepare_directory -from fme.core.coordinates import LatLonCoordinates from fme.core.dataset.time import TimeSlice from fme.core.dicts import to_flat_dict from fme.core.distributed import Distributed @@ -119,10 +118,7 @@ def generation_model(self): def run(self): logging.info(f"Running {self.event_name} event downscaling...") batch = next(iter(self.data.get_generator())) - coarse_coords = LatLonCoordinates( - lat=batch[0].latlon_coordinates.lat, - lon=batch[0].latlon_coordinates.lon, - ) + coarse_coords = batch[0].latlon_coordinates fine_coords = self.model.get_fine_coords_for_batch(batch) sample_agg = SampleAggregator( coarse=batch[0].data, From 83ca0438cfbc2bf8ddafd5b3f6dd413cfb5de643 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 21:26:29 -0700 Subject: [PATCH 07/27] Make fine coord loader public --- fme/downscaling/models.py | 43 ++++++++++++++++----------------- fme/downscaling/test_predict.py | 20 ++++----------- fme/downscaling/test_utils.py | 1 + 3 files changed, 27 insertions(+), 37 deletions(-) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 11ec1b1db..f50071cad 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -680,6 +680,26 @@ def from_state(cls, state: 
Mapping[str, Any]) -> DiffusionModelConfig: ).wrapper +def load_fine_coords_from_path(path: str) -> LatLonCoordinates: + if path.endswith(".zarr"): + ds = xr.open_zarr(path) + else: + ds = xr.open_dataset(path) + lat_name = next((n for n in ["lat", "latitude", "grid_yt"] if n in ds.coords), None) + lon_name = next( + (n for n in ["lon", "longitude", "grid_xt"] if n in ds.coords), None + ) + if lat_name is None or lon_name is None: + raise ValueError( + f"Could not find lat/lon coordinates in {path}. " + "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'." + ) + return LatLonCoordinates( + lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), + lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), + ) + + @dataclasses.dataclass class CheckpointModelConfig: """ @@ -747,27 +767,6 @@ def _checkpoint(self) -> Mapping[str, Any]: checkpoint_data["model"]["config"][k] = v return self._checkpoint_data - def _load_fine_coords_from_path(self, path: str) -> LatLonCoordinates: - if path.endswith(".zarr"): - ds = xr.open_zarr(path) - else: - ds = xr.open_dataset(path) - lat_name = next( - (n for n in ["lat", "latitude", "grid_yt"] if n in ds.coords), None - ) - lon_name = next( - (n for n in ["lon", "longitude", "grid_xt"] if n in ds.coords), None - ) - if lat_name is None or lon_name is None: - raise ValueError( - f"Could not find lat/lon coordinates in {path}. " - "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'." 
- ) - return LatLonCoordinates( - lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), - lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), - ) - def build( self, ) -> DiffusionModel: @@ -818,7 +817,7 @@ def build( lon=coords_state["lon"], ) elif self.fine_coordinates_path is not None: - fine_coords = self._load_fine_coords_from_path(self.fine_coordinates_path) + fine_coords = load_fine_coords_from_path(self.fine_coordinates_path) else: raise ValueError( "No fine coordinates found in checkpoint state and no static inputs " diff --git a/fme/downscaling/test_predict.py b/fme/downscaling/test_predict.py index 5054fb538..8101b293e 100644 --- a/fme/downscaling/test_predict.py +++ b/fme/downscaling/test_predict.py @@ -2,16 +2,18 @@ import pytest import torch -import xarray as xr import yaml -from fme.core.coordinates import LatLonCoordinates from fme.core.loss import LossConfig from fme.core.normalizer import NormalizationConfig from fme.core.testing.wandb import mock_wandb from fme.downscaling import predict from fme.downscaling.data import load_static_inputs -from fme.downscaling.models import DiffusionModelConfig, PairedNormalizationConfig +from fme.downscaling.models import ( + DiffusionModelConfig, + PairedNormalizationConfig, + load_fine_coords_from_path, +) from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector from fme.downscaling.test_models import LinearDownscaling from fme.downscaling.test_utils import data_paths_helper @@ -66,18 +68,6 @@ def get_model_config( ) -def load_fine_coords_from_path(path: str) -> LatLonCoordinates: - """Load lat/lon coordinates from a netCDF or zarr data file.""" - if path.endswith(".zarr"): - ds = xr.open_zarr(path) - else: - ds = xr.open_dataset(path) - return LatLonCoordinates( - lat=torch.tensor(ds["lat"].values, dtype=torch.float32), - lon=torch.tensor(ds["lon"].values, dtype=torch.float32), - ) - - def create_predictor_config( tmp_path, n_samples: int, diff --git 
a/fme/downscaling/test_utils.py b/fme/downscaling/test_utils.py index 4f34a4451..b1546e044 100644 --- a/fme/downscaling/test_utils.py +++ b/fme/downscaling/test_utils.py @@ -78,4 +78,5 @@ def data_paths_helper( create_test_data_on_disk( coarse_path / "data.nc", dim_sizes.coarse, variable_names, coords ) + # TODO: should this return the full filename instead of just the directory? return FineResCoarseResPair[str](fine=fine_path, coarse=coarse_path) From 794a7d40a00145f6ba99fea30e150b3b5314a7a4 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 21:40:17 -0700 Subject: [PATCH 08/27] BatchLatLon coord access consistency --- fme/downscaling/predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index d35d0bc1a..c6b0a3677 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -118,7 +118,7 @@ def generation_model(self): def run(self): logging.info(f"Running {self.event_name} event downscaling...") batch = next(iter(self.data.get_generator())) - coarse_coords = batch[0].latlon_coordinates + coarse_coords = batch.latlon_coordinates[0] fine_coords = self.model.get_fine_coords_for_batch(batch) sample_agg = SampleAggregator( coarse=batch[0].data, From b542080861455241f2b8315a124e0015c14d5e75 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 21:41:35 -0700 Subject: [PATCH 09/27] linting --- fme/downscaling/predictors/composite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fme/downscaling/predictors/composite.py b/fme/downscaling/predictors/composite.py index 405ac1eed..5d0a4be8f 100644 --- a/fme/downscaling/predictors/composite.py +++ b/fme/downscaling/predictors/composite.py @@ -85,7 +85,7 @@ def static_inputs(self): def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: return self.model.get_fine_coords_for_batch(batch) - + @property def fine_coords(self): return self.model.fine_coords From 
5add7272aec664240428468d6cfdd5035f047dd4 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 21:45:45 -0700 Subject: [PATCH 10/27] Add no coords checkpoint with path test --- fme/downscaling/test_models.py | 37 ++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 803baefda..98bb76b33 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -631,6 +631,43 @@ def test_get_fine_coords_for_batch(): assert torch.allclose(result.lon, expected_lon) +def test_checkpoint_model_build_with_fine_coordinates_path(tmp_path): + """Old-format checkpoint (no fine_coords key, no coords in static_inputs) + should load correctly when fine_coordinates_path is provided.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + fine_coords = make_fine_coords(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + use_fine_topography=False, + fine_coords=fine_coords, + ) + # Simulate old checkpoint: no fine_coords stored + state = model.get_state() + del state["fine_coords"] + checkpoint_path = tmp_path / "test.ckpt" + torch.save({"model": state}, checkpoint_path) + + # Write fine coords to a netCDF file for the loader to read + coords_path = tmp_path / "fine_coords.nc" + ds = xr.Dataset( + coords={ + "lat": fine_coords.lat.cpu().numpy(), + "lon": fine_coords.lon.cpu().numpy(), + } + ) + ds.to_netcdf(coords_path) + + loaded_model = CheckpointModelConfig( + checkpoint_path=str(checkpoint_path), + fine_coordinates_path=str(coords_path), + ).build() + + assert torch.equal(loaded_model.fine_coords.lat.cpu(), fine_coords.lat.cpu()) + assert torch.equal(loaded_model.fine_coords.lon.cpu(), fine_coords.lon.cpu()) + + def test_checkpoint_config_topography_raises(): with pytest.raises(ValueError): CheckpointModelConfig( From 36e80db4c185b5fd6257b185a43d664daf96b771 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Tue, 17 
Mar 2026 15:37:30 -0700 Subject: [PATCH 11/27] Small tweaks --- fme/downscaling/data/datasets.py | 6 ++++-- fme/downscaling/data/test_utils.py | 21 +++++++++++++++++++++ fme/downscaling/data/utils.py | 8 ++++++++ fme/downscaling/inference/test_inference.py | 7 ++----- fme/downscaling/models.py | 6 ++++-- 5 files changed, 39 insertions(+), 9 deletions(-) diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index b11d0e49d..186133d2a 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -138,11 +138,13 @@ def __init__( f"expected lon_min < {self.lon_interval.start + 360.0}" ) + # Used to subset the data in __getitem__ self._lats_slice = self.lat_interval.slice_of(self._orig_coords.lat) self._lons_slice = self.lon_interval.slice_of(self._orig_coords.lon) + self._latlon_coordinates = LatLonCoordinates( - lat=self._orig_coords.lat[self._lats_slice], - lon=self._orig_coords.lon[self._lons_slice], + lat=self.lat_interval.subset_of(self._orig_coords.lat), + lon=self.lon_interval.subset_of(self._orig_coords.lon), ) self._area_weights = self._latlon_coordinates.area_weights diff --git a/fme/downscaling/data/test_utils.py b/fme/downscaling/data/test_utils.py index 662231147..c0d66521b 100644 --- a/fme/downscaling/data/test_utils.py +++ b/fme/downscaling/data/test_utils.py @@ -98,6 +98,27 @@ def test_ClosedInterval_slice_of(interval, expected_slice): assert result_slice == expected_slice +@pytest.mark.parametrize( + "interval,expected_values", + [ + pytest.param(ClosedInterval(2, 4), torch.tensor([2, 3, 4]), id="middle"), + pytest.param( + ClosedInterval(float("-inf"), 2), torch.tensor([0, 1, 2]), id="start_inf" + ), + pytest.param(ClosedInterval(4, float("inf")), torch.tensor([4]), id="end_inf"), + pytest.param( + ClosedInterval(float("-inf"), float("inf")), + torch.arange(5), + id="all_inf", + ), + ], +) +def test_ClosedInterval_subset_of(interval, expected_values): + coords = torch.arange(5) + result = 
interval.subset_of(coords) + assert torch.equal(result, expected_values) + + def test_ClosedInterval_fail_on_empty_slice(): coords = torch.arange(5) with pytest.raises(ValueError): diff --git a/fme/downscaling/data/utils.py b/fme/downscaling/data/utils.py index 8cecf8c09..146375ef8 100644 --- a/fme/downscaling/data/utils.py +++ b/fme/downscaling/data/utils.py @@ -63,6 +63,14 @@ def slice_of(self, coords: torch.Tensor) -> slice: indices = mask.nonzero(as_tuple=True)[0] return slice(indices[0].item(), indices[-1].item() + 1) + def subset_of(self, coords: torch.Tensor) -> torch.Tensor: + """ + Return a subset of `coords` that falls within this specified interval. + This assumes `coords` is monotonically increasing. + """ + slice = self.slice_of(coords) + return coords[slice] + def scale_slice(slice_: slice, scale: int) -> slice: if slice_ == slice(None): diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index c06b7b0cc..481bc2973 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -32,6 +32,7 @@ LossConfig, NormalizationConfig, PairedNormalizationConfig, + load_fine_coords_from_path, ) from fme.downscaling.predictors import PatchPredictionConfig, PatchPredictor from fme.downscaling.test_evaluator import LinearDownscalingDiffusion @@ -276,11 +277,7 @@ def checkpointed_model_config( # that correspond to the dataset coordinates fine_data_path = f"{data_paths.fine}/data.nc" static_inputs = load_static_inputs({"HGTsfc": fine_data_path}) - ds = xr.open_dataset(fine_data_path) - fine_coords = LatLonCoordinates( - lat=torch.tensor(ds["lat"].values, dtype=torch.float32), - lon=torch.tensor(ds["lon"].values, dtype=torch.float32), - ) + fine_coords = load_fine_coords_from_path(fine_data_path) model = model_config.build( coarse_shape, 2, static_inputs=static_inputs, fine_coords=fine_coords ) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 
f50071cad..b0e586cf9 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -353,6 +353,8 @@ def _subset_static_inputs( "Static inputs must be provided for each batch when use of fine " "static inputs is enabled." ) + # Fine coords and static inputs are expected to be on the same grid + # as fine_coords so use the model's fine coords to subset StaticInputs lat_slice = lat_interval.slice_of(self.fine_coords.lat) lon_slice = lon_interval.slice_of(self.fine_coords.lon) return self.static_inputs.subset(lat_slice, lon_slice) @@ -374,8 +376,8 @@ def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: downscale_factor=self.downscale_factor, ) return LatLonCoordinates( - lat=self.fine_coords.lat[fine_lat_interval.slice_of(self.fine_coords.lat)], - lon=self.fine_coords.lon[fine_lon_interval.slice_of(self.fine_coords.lon)], + lat=fine_lat_interval.subset_of(self.fine_coords.lat), + lon=fine_lon_interval.subset_of(self.fine_coords.lon), ) @property From b19c9d6e791b3f2efdca8c5e49605d89fcef28a9 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Tue, 17 Mar 2026 15:50:21 -0700 Subject: [PATCH 12/27] Add load_fine_coords_from_path test --- fme/downscaling/test_models.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 98bb76b33..1c4254bcf 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -24,6 +24,7 @@ PairedNormalizationConfig, _repeat_batch_by_samples, _separate_interleaved_samples, + load_fine_coords_from_path, ) from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector from fme.downscaling.noise import LogNormalNoiseDistribution @@ -698,3 +699,33 @@ def test_checkpoint_model_build_raises_when_checkpoint_has_static_inputs(tmp_pat ) with pytest.raises(ValueError): config.build() + + +@pytest.mark.parametrize( + "lat_name,lon_name", + [ + pytest.param("lat", "lon", 
id="standard"), + pytest.param("latitude", "longitude", id="long_names"), + pytest.param("grid_yt", "grid_xt", id="fv3_names"), + ], +) +def test_load_fine_coords_from_path(tmp_path, lat_name, lon_name): + lat = [0.0, 1.0, 2.0] + lon = [10.0, 20.0, 30.0, 40.0] + ds = xr.Dataset(coords={lat_name: lat, lon_name: lon}) + path = str(tmp_path / "coords.nc") + ds.to_netcdf(path) + + coords = load_fine_coords_from_path(path) + + assert torch.allclose(coords.lat, torch.tensor(lat, dtype=torch.float32)) + assert torch.allclose(coords.lon, torch.tensor(lon, dtype=torch.float32)) + + +def test_load_fine_coords_from_path_raises_on_missing_coords(tmp_path): + ds = xr.Dataset(coords={"x": [0.0, 1.0], "y": [10.0, 20.0]}) + path = str(tmp_path / "no_latlon.nc") + ds.to_netcdf(path) + + with pytest.raises(ValueError, match="Could not find lat/lon coordinates"): + load_fine_coords_from_path(path) From 7acf87c148dbbe48c2769057feef0a5390b5218a Mon Sep 17 00:00:00 2001 From: "W. Andre Perkins" Date: Wed, 18 Mar 2026 13:20:11 -0700 Subject: [PATCH 13/27] Update fme/downscaling/models.py Co-authored-by: Anna Kwa --- fme/downscaling/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index b0e586cf9..3932c8905 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -319,7 +319,7 @@ def __init__( static_inputs.to_device() if static_inputs is not None else None ) device = get_device() - self.fine_coords = LatLonCoordinates( + self.fine_coords = fine_coords.to(device) lat=fine_coords.lat.to(device), lon=fine_coords.lon.to(device), ) From eba2749addebe6d4dd6488980a09017c3e69a07a Mon Sep 17 00:00:00 2001 From: "W. 
Andre Perkins" Date: Wed, 18 Mar 2026 13:20:24 -0700 Subject: [PATCH 14/27] Update fme/downscaling/models.py Co-authored-by: Anna Kwa --- fme/downscaling/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 3932c8905..fb1766770 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -324,7 +324,7 @@ def __init__( lon=fine_coords.lon.to(device), ) if static_inputs is not None: - expected = (len(fine_coords.lat), len(fine_coords.lon)) + expected = fine_coords.shape if static_inputs.shape != expected: raise ValueError( f"static_inputs are expected to be on the same coordinate grid as " From 2c25f1d0857bfdec53e082b25fd068f61974fad9 Mon Sep 17 00:00:00 2001 From: "W. Andre Perkins" Date: Wed, 18 Mar 2026 13:20:52 -0700 Subject: [PATCH 15/27] Update fme/downscaling/models.py Co-authored-by: Anna Kwa --- fme/downscaling/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index fb1766770..0648249c1 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -353,7 +353,7 @@ def _subset_static_inputs( "Static inputs must be provided for each batch when use of fine " "static inputs is enabled." 
) - # Fine coords and static inputs are expected to be on the same grid + # Static inputs are expected to be on the same grid # as fine_coords so use the model's fine coords to subset StaticInputs lat_slice = lat_interval.slice_of(self.fine_coords.lat) lon_slice = lon_interval.slice_of(self.fine_coords.lon) From 81e4eb84ebf1a07e03be6ba530366ce0b89b2acf Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Wed, 18 Mar 2026 13:25:13 -0700 Subject: [PATCH 16/27] use latlon coords .to method for device fix --- fme/core/coordinates.py | 2 +- fme/downscaling/models.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/fme/core/coordinates.py b/fme/core/coordinates.py index 68c2fbb29..bf27cd996 100644 --- a/fme/core/coordinates.py +++ b/fme/core/coordinates.py @@ -736,7 +736,7 @@ def __eq__(self, other) -> bool: def __repr__(self) -> str: return f"LatLonCoordinates(\n lat={self.lat},\n lon={self.lon}\n" - def to(self, device: str) -> "LatLonCoordinates": + def to(self, device: str | torch.device) -> "LatLonCoordinates": return LatLonCoordinates( lon=self.lon.to(device), lat=self.lat.to(device), diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 0648249c1..e89ddab89 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -318,11 +318,7 @@ def __init__( self.static_inputs = ( static_inputs.to_device() if static_inputs is not None else None ) - device = get_device() - self.fine_coords = fine_coords.to(device) - lat=fine_coords.lat.to(device), - lon=fine_coords.lon.to(device), - ) + self.fine_coords = fine_coords.to(get_device()) if static_inputs is not None: expected = fine_coords.shape if static_inputs.shape != expected: From 325d32473673e2f38b29be16c177ebeb1b886fac Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Wed, 18 Mar 2026 22:09:30 -0700 Subject: [PATCH 17/27] Redo based on Anna's comments --- fme/downscaling/data/datasets.py | 4 +- fme/downscaling/data/static.py | 81 ++++-- 
fme/downscaling/data/test_static.py | 27 +- fme/downscaling/data/test_utils.py | 6 +- fme/downscaling/data/utils.py | 6 +- fme/downscaling/inference/test_inference.py | 14 +- fme/downscaling/models.py | 304 ++++++++------------ fme/downscaling/test_models.py | 110 +++---- fme/downscaling/test_predict.py | 11 +- fme/downscaling/train.py | 9 +- 10 files changed, 257 insertions(+), 315 deletions(-) diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index 186133d2a..61f94f2ed 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -139,8 +139,8 @@ def __init__( ) # Used to subset the data in __getitem__ - self._lats_slice = self.lat_interval.slice_of(self._orig_coords.lat) - self._lons_slice = self.lon_interval.slice_of(self._orig_coords.lon) + self._lats_slice = self.lat_interval.slice_from(self._orig_coords.lat) + self._lons_slice = self.lon_interval.slice_from(self._orig_coords.lon) self._latlon_coordinates = LatLonCoordinates( lat=self.lat_interval.subset_of(self._orig_coords.lat), diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index 48ab54ed7..34d73baa2 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -3,8 +3,9 @@ import torch import xarray as xr +from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device -from fme.downscaling.data.patching import Patch +from fme.downscaling.data.utils import ClosedInterval @dataclasses.dataclass @@ -13,7 +14,10 @@ class StaticInput: def __post_init__(self): if len(self.data.shape) != 2: - raise ValueError(f"Topography data must be 2D. Got shape {self.data.shape}") + raise ValueError( + f"StaticInput data must be 2D. 
Got shape {self.data.shape}" + ) + self._shape = (self.data.shape[0], self.data.shape[1]) @property def dim(self) -> int: @@ -21,7 +25,7 @@ def dim(self) -> int: @property def shape(self) -> tuple[int, int]: - return self.data.shape + return self._shape def subset( self, @@ -34,14 +38,15 @@ def to_device(self) -> "StaticInput": device = get_device() return StaticInput(data=self.data.to(device)) - def _apply_patch(self, patch: Patch): - return self.subset(lat_slice=patch.input_slice.y, lon_slice=patch.input_slice.x) - def get_state(self) -> dict: return { "data": self.data.cpu(), } + @classmethod + def from_state(cls, state: dict) -> "StaticInput": + return cls(data=state["data"]) + def _get_normalized_static_input(path: str, field_name: str): """ @@ -74,6 +79,7 @@ def _get_normalized_static_input(path: str, field_name: str): @dataclasses.dataclass class StaticInputs: fields: list[StaticInput] + coords: LatLonCoordinates def __post_init__(self): for i, field in enumerate(self.fields[1:]): @@ -82,6 +88,11 @@ def __post_init__(self): f"All StaticInput fields must have the same shape. " f"Fields {i + 1} and 0 do not match shapes." ) + if self.fields and self.coords.shape != self.fields[0].shape: + raise ValueError( + f"Coordinates shape {self.coords.shape} does not match fields shape " + f"{self.fields[0].shape} for StaticInputs." 
+ ) def __getitem__(self, index: int): return self.fields[index] @@ -94,47 +105,61 @@ def shape(self) -> tuple[int, int]: def subset( self, - lat_slice: slice, - lon_slice: slice, + lat_interval: ClosedInterval, + lon_interval: ClosedInterval, ) -> "StaticInputs": + lat_slice = lat_interval.slice_from(self.coords.lat) + lon_slice = lon_interval.slice_from(self.coords.lon) return StaticInputs( - fields=[field.subset(lat_slice, lon_slice) for field in self.fields] + fields=[field.subset(lat_slice, lon_slice) for field in self.fields], + coords=LatLonCoordinates( + lat=lat_interval.subset_of(self.coords.lat), + lon=lon_interval.subset_of(self.coords.lon), + ), ) def to_device(self) -> "StaticInputs": - return StaticInputs(fields=[field.to_device() for field in self.fields]) + return StaticInputs( + fields=[field.to_device() for field in self.fields], + coords=self.coords.to(get_device()), + ) def get_state(self) -> dict: return { "fields": [field.get_state() for field in self.fields], + "coords": self.coords.get_state(), } @classmethod def from_state(cls, state: dict) -> "StaticInputs": + """Reconstruct StaticInputs from a state dict. + + Args: + state: State dict from get_state(). + coords: Override coordinates. If None, reads coords from the state dict. + Pass explicitly when loading old-format checkpoints that stored coords + outside of the StaticInputs state. + """ return cls( fields=[ - StaticInput( - data=field_state["data"], - ) - for field_state in state["fields"] - ] + StaticInput.from_state(field_state) for field_state in state["fields"] + ], + coords=LatLonCoordinates( + lat=state["coords"]["lat"], + lon=state["coords"]["lon"], + ), ) def load_static_inputs( - static_inputs_config: dict[str, str] | None, -) -> StaticInputs | None: + static_inputs_config: dict[str, str], coords: LatLonCoordinates +) -> StaticInputs: """ Load normalized static inputs from a mapping of field names to file paths. - Returns None if the input config is empty. 
+ Returns an empty StaticInputs (no fields) if the config is None or empty. """ - # TODO: consolidate/simplify empty StaticInputs vs. None handling in - # downscaling code - if not static_inputs_config: - return None - return StaticInputs( - fields=[ - _get_normalized_static_input(path, field_name) - for field_name, path in static_inputs_config.items() - ] - ) + fields = [ + _get_normalized_static_input(path, field_name) + for field_name, path in static_inputs_config.items() + ] + return StaticInputs(fields=fields, coords=coords) diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index 104ffdc9f..fd6465e17 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -29,12 +29,19 @@ def test_subset(): def test_StaticInputs_serialize(): - data = torch.arange(16).reshape(4, 4) + from fme.core.coordinates import LatLonCoordinates + + data = torch.arange(16, dtype=torch.float32).reshape(4, 4) + coords = LatLonCoordinates( + lat=torch.arange(4, dtype=torch.float32), + lon=torch.arange(4, dtype=torch.float32), + ) topography = StaticInput(data) land_frac = StaticInput(data * -1.0) - static_inputs = StaticInputs([topography, land_frac]) + static_inputs = StaticInputs([topography, land_frac], coords=coords) state = static_inputs.get_state() - # Verify coords are NOT stored in state + # Verify coords are stored at the top level, not inside each field + assert "coords" in state assert "coords" not in state["fields"][0] static_inputs_reconstructed = StaticInputs.from_state(state) assert static_inputs_reconstructed[0].data.equal(static_inputs[0].data) @@ -42,9 +49,16 @@ def test_StaticInputs_serialize(): def test_StaticInputs_serialize_backward_compat_with_coords(): - """from_state should silently ignore 'coords' key for old state dicts.""" + """from_state should silently ignore 'coords' key in fields for old state dicts.""" + from fme.core.coordinates import LatLonCoordinates + data = torch.arange(16, 
dtype=torch.float32).reshape(4, 4) - # Simulate old state dict format that included coords + coords = LatLonCoordinates( + lat=torch.arange(4, dtype=torch.float32), + lon=torch.arange(4, dtype=torch.float32), + ) + # Simulate old state dict format that included coords inside fields. + # from_state should silently ignore extra keys (like 'coords') in field dicts. old_state = { "fields": [ { @@ -54,7 +68,8 @@ def test_StaticInputs_serialize_backward_compat_with_coords(): "lon": torch.arange(4, dtype=torch.float32), }, } - ] + ], + "coords": coords.get_state(), } static_inputs = StaticInputs.from_state(old_state) assert torch.equal(static_inputs[0].data, data) diff --git a/fme/downscaling/data/test_utils.py b/fme/downscaling/data/test_utils.py index c0d66521b..8f6e7806b 100644 --- a/fme/downscaling/data/test_utils.py +++ b/fme/downscaling/data/test_utils.py @@ -92,9 +92,9 @@ def test_scale_slice(input_slice, expected): ), ], ) -def test_ClosedInterval_slice_of(interval, expected_slice): +def test_ClosedInterval_slice_from(interval, expected_slice): coords = torch.arange(5) - result_slice = interval.slice_of(coords) + result_slice = interval.slice_from(coords) assert result_slice == expected_slice @@ -122,4 +122,4 @@ def test_ClosedInterval_subset_of(interval, expected_values): def test_ClosedInterval_fail_on_empty_slice(): coords = torch.arange(5) with pytest.raises(ValueError): - ClosedInterval(5.5, 7).slice_of(coords) + ClosedInterval(5.5, 7).slice_from(coords) diff --git a/fme/downscaling/data/utils.py b/fme/downscaling/data/utils.py index 146375ef8..ca3a4359a 100644 --- a/fme/downscaling/data/utils.py +++ b/fme/downscaling/data/utils.py @@ -37,9 +37,9 @@ def __post_init__(self): def __contains__(self, value: float): return self.start <= value <= self.stop - def slice_of(self, coords: torch.Tensor) -> slice: + def slice_from(self, coords: torch.Tensor) -> slice: """ - Return a slice that selects all elements of `coords` within this + Return a slice that selects all 
elements from `coords` within this specified interval. This assumes `coords` is monotonically increasing. Args: @@ -68,7 +68,7 @@ def subset_of(self, coords: torch.Tensor) -> torch.Tensor: Return a subset of `coords` that falls within this specified interval. This assumes `coords` is monotonically increasing. """ - slice = self.slice_of(coords) + slice = self.slice_from(coords) return coords[slice] diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index 481bc2973..9ba54e580 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -64,8 +64,14 @@ def mock_output_target(): def get_static_inputs(shape=(16, 16)): + from fme.core.coordinates import LatLonCoordinates + data = torch.randn(shape) - return StaticInputs([StaticInput(data=data)]) + coords = LatLonCoordinates( + lat=torch.arange(shape[0], dtype=torch.float32), + lon=torch.arange(shape[1], dtype=torch.float32), + ) + return StaticInputs([StaticInput(data=data)], coords=coords) # Tests for Downscaler initialization @@ -276,11 +282,9 @@ def checkpointed_model_config( # loader_config is passed in to add static inputs into model # that correspond to the dataset coordinates fine_data_path = f"{data_paths.fine}/data.nc" - static_inputs = load_static_inputs({"HGTsfc": fine_data_path}) fine_coords = load_fine_coords_from_path(fine_data_path) - model = model_config.build( - coarse_shape, 2, static_inputs=static_inputs, fine_coords=fine_coords - ) + static_inputs = load_static_inputs({"HGTsfc": fine_data_path}, coords=fine_coords) + model = model_config.build(coarse_shape, 2, static_inputs=static_inputs) checkpoint_path = tmp_path / "model_checkpoint.pth" model.get_state() diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index e89ddab89..421cd84f4 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -17,7 +17,6 @@ from fme.core.typing_ import TensorDict, TensorMapping from 
fme.downscaling.data import ( BatchData, - ClosedInterval, PairedBatchData, StaticInputs, adjust_fine_coord_range, @@ -182,9 +181,8 @@ def build( self, coarse_shape: tuple[int, int], downscale_factor: int, - fine_coords: LatLonCoordinates, + static_inputs: StaticInputs, rename: dict[str, str] | None = None, - static_inputs: StaticInputs | None = None, ) -> "DiffusionModel": invert_rename = {v: k for k, v in (rename or {}).items()} orig_in_names = [invert_rename.get(name, name) for name in self.in_names] @@ -196,10 +194,10 @@ def build( # https://en.wikipedia.org/wiki/Standard_score sigma_data = 1.0 - n_in_channels = len(self.in_names) - if static_inputs is not None: - n_in_channels += len(static_inputs.fields) - elif self.use_fine_topography: + n_in_channels = len(self.in_names) + len(static_inputs.fields) + if self.use_fine_topography and len(static_inputs.fields) == 0: + # TODO: remove this when forcing static inputs to be provided + # . via removing use_fine_topography # Old checkpoints may not have static inputs serialized, but if # use_fine_topography is True, we still need to account for the topography # channel, which was the only static input at the time @@ -223,7 +221,6 @@ def build( downscale_factor=downscale_factor, sigma_data=sigma_data, static_inputs=static_inputs, - fine_coords=fine_coords, ) def get_state(self) -> Mapping[str, Any]: @@ -279,8 +276,7 @@ def __init__( coarse_shape: tuple[int, int], downscale_factor: int, sigma_data: float, - fine_coords: LatLonCoordinates, - static_inputs: StaticInputs | None = None, + static_inputs: StaticInputs, ) -> None: """ Args: @@ -297,12 +293,9 @@ def __init__( coarse to fine. sigma_data: The standard deviation of the data, used for diffusion model preconditioning. - fine_coords: the full-domain fine-resolution coordinates to use - for spatial metadata in the model output. - static_inputs: Static inputs to the model, loaded from the trainer - config or checkpoint. Must be set when use_fine_topography is True. 
- Expected to be on the same coordinate grid as fine_coords for now, - but this may be relaxed in the future. + static_inputs: Static inputs to the model. Always provided, carrying + the full-domain fine-resolution coordinates. Fields may be empty + when no static data is needed. """ self.coarse_shape = coarse_shape self.downscale_factor = downscale_factor @@ -315,48 +308,20 @@ def __init__( self.out_packer = Packer(config.out_names) self.config = config self._channel_axis = -3 - self.static_inputs = ( - static_inputs.to_device() if static_inputs is not None else None - ) - self.fine_coords = fine_coords.to(get_device()) - if static_inputs is not None: - expected = fine_coords.shape - if static_inputs.shape != expected: - raise ValueError( - f"static_inputs are expected to be on the same coordinate grid as " - f"fine_coords. StaticInputs shape {static_inputs.shape} does not " - f"match fine_coords grid {expected}." - ) + self.static_inputs = static_inputs.to_device() @property def modules(self) -> torch.nn.ModuleList: return torch.nn.ModuleList([self.module]) - def _subset_static_inputs( - self, - lat_interval: ClosedInterval, - lon_interval: ClosedInterval, - ) -> StaticInputs | None: - """Subset self.static_inputs to the given fine lat/lon interval. - - Returns None if use_fine_topography is False. - Raises ValueError if use_fine_topography is True but self.static_inputs is None. - """ - if not self.config.use_fine_topography: - return None - if self.static_inputs is None: - raise ValueError( - "Static inputs must be provided for each batch when use of fine " - "static inputs is enabled." 
- ) - # Static inputs are expected to be on the same grid - # as fine_coords so use the model's fine coords to subset StaticInputs - lat_slice = lat_interval.slice_of(self.fine_coords.lat) - lon_slice = lon_interval.slice_of(self.fine_coords.lon) - return self.static_inputs.subset(lat_slice, lon_slice) + @property + def fine_coords(self) -> LatLonCoordinates: + return self.static_inputs.coords def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: - """Return fine-resolution coordinates matching the spatial extent of batch.""" + return self._get_subset_for_coarse_batch(batch).coords + + def _get_subset_for_coarse_batch(self, batch: BatchData) -> StaticInputs: coarse_lat = batch.latlon_coordinates.lat[0] coarse_lon = batch.latlon_coordinates.lon[0] fine_lat_interval = adjust_fine_coord_range( @@ -371,10 +336,11 @@ def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: full_fine_coord=self.fine_coords.lon, downscale_factor=self.downscale_factor, ) - return LatLonCoordinates( - lat=fine_lat_interval.subset_of(self.fine_coords.lat), - lon=fine_lon_interval.subset_of(self.fine_coords.lon), + subset_static_inputs = self.static_inputs.subset( + lat_interval=fine_lat_interval, + lon_interval=fine_lon_interval, ) + return subset_static_inputs @property def fine_shape(self) -> tuple[int, int]: @@ -391,7 +357,7 @@ def _get_fine_shape(self, coarse_shape: tuple[int, int]) -> tuple[int, int]: ) def _get_input_from_coarse( - self, coarse: TensorMapping, static_inputs: StaticInputs | None + self, coarse: TensorMapping, static_inputs: StaticInputs ) -> torch.Tensor: inputs = filter_tensor_mapping(coarse, self.in_packer.names) normalized = self.in_packer.pack( @@ -399,11 +365,13 @@ def _get_input_from_coarse( ) interpolated = interpolate(normalized, self.downscale_factor) + # TODO: update when removing use_fine_topography flag if self.config.use_fine_topography: - if static_inputs is None: + if not static_inputs.fields: raise ValueError( - 
"Static inputs must be provided for each batch when use of fine " - "static inputs is enabled." + "Static inputs must be provided for each batch when flag " + "use_fine_topography is enabled, but no static input fields " + "were found." ) else: expected_shape = interpolated.shape[-2:] @@ -414,12 +382,13 @@ def _get_input_from_coarse( ) n_batches = normalized.shape[0] # Join normalized static inputs to input (see dataset for details) + fields: list[torch.Tensor] = [interpolated] for field in static_inputs.fields: - topo = field.data.unsqueeze(0).repeat(n_batches, 1, 1) - topo = topo.unsqueeze(self._channel_axis) - interpolated = torch.concat( - [interpolated, topo], axis=self._channel_axis - ) + static_field = field.data.unsqueeze(0).repeat(n_batches, 1, 1) + static_field = static_field.unsqueeze(self._channel_axis) + fields.append(static_field) + + interpolated = torch.concat(fields, dim=self._channel_axis) if self.config._interpolate_input: return interpolated @@ -431,9 +400,7 @@ def train_on_batch( optimizer: Optimization | NullOptimization, ) -> ModelOutputs: """Performs a denoising training step on a batch of data.""" - _static_inputs = self._subset_static_inputs( - batch.fine.lat_interval, batch.fine.lon_interval - ) + _static_inputs = self._get_subset_for_coarse_batch(batch.coarse) coarse, fine = batch.coarse.data, batch.fine.data inputs_norm = self._get_input_from_coarse(coarse, _static_inputs) targets_norm = self.out_packer.pack( @@ -491,7 +458,7 @@ def train_on_batch( def generate( self, coarse_data: TensorMapping, - static_inputs: StaticInputs | None, + static_inputs: StaticInputs, n_samples: int = 1, ) -> tuple[TensorDict, torch.Tensor, list[torch.Tensor]]: # Internal method; external callers should use generate_on_batch / @@ -547,31 +514,7 @@ def generate_on_batch_no_target( batch: BatchData, n_samples: int = 1, ) -> TensorDict: - if self.config.use_fine_topography: - if self.static_inputs is None: - raise ValueError( - "Static inputs must be provided 
for each batch when use of fine " - "static inputs is enabled." - ) - coarse_lat = batch.latlon_coordinates.lat[0] - coarse_lon = batch.latlon_coordinates.lon[0] - fine_lat_interval = adjust_fine_coord_range( - batch.lat_interval, - full_coarse_coord=coarse_lat, - full_fine_coord=self.fine_coords.lat, - downscale_factor=self.downscale_factor, - ) - fine_lon_interval = adjust_fine_coord_range( - batch.lon_interval, - full_coarse_coord=coarse_lon, - full_fine_coord=self.fine_coords.lon, - downscale_factor=self.downscale_factor, - ) - _static_inputs = self._subset_static_inputs( - fine_lat_interval, fine_lon_interval - ) - else: - _static_inputs = None + _static_inputs = self._get_subset_for_coarse_batch(batch) generated, _, _ = self.generate(batch.data, _static_inputs, n_samples) return generated @@ -581,9 +524,7 @@ def generate_on_batch( batch: PairedBatchData, n_samples: int = 1, ) -> ModelOutputs: - _static_inputs = self._subset_static_inputs( - batch.fine.lat_interval, batch.fine.lon_interval - ) + _static_inputs = self._get_subset_for_coarse_batch(batch.coarse) coarse, fine = batch.coarse.data, batch.fine.data generated, generated_norm, latent_steps = self.generate( coarse, _static_inputs, n_samples @@ -603,65 +544,69 @@ def generate_on_batch( ) def get_state(self) -> Mapping[str, Any]: - if self.static_inputs is not None: - static_inputs_state = self.static_inputs.get_state() - else: - static_inputs_state = None - return { "config": self.config.get_state(), "module": self.module.state_dict(), "coarse_shape": self.coarse_shape, "downscale_factor": self.downscale_factor, - "static_inputs": static_inputs_state, - "fine_coords": self.fine_coords.get_state(), + "static_inputs": self.static_inputs.get_state(), } + @staticmethod + def _legacy_coord_in_state( + static_inputs_state: Mapping[str, Any], + ) -> bool: + return ( + "fields" in static_inputs_state + and static_inputs_state["fields"] + and "coords" in static_inputs_state["fields"][0] + ) + @classmethod - def 
from_state( + def coords_in_checkpoint(cls, static_inputs_state: Mapping[str, Any]) -> bool: + return "coords" in static_inputs_state or cls._legacy_coord_in_state( + static_inputs_state + ) + + @classmethod + def _update_state_with_legacy_coords( cls, - state: Mapping[str, Any], - ) -> "DiffusionModel": - config = DiffusionModelConfig.from_state(state["config"]) - # backwards compatibility for models before static inputs serialization - if state.get("static_inputs") is not None: - static_inputs = StaticInputs.from_state(state["static_inputs"]).to_device() - else: - static_inputs = None - - # Load fine_coords: new checkpoints store it directly; old checkpoints - # that had static_inputs with coords can auto-migrate from raw state. - fine_coords = state.get("fine_coords") - if fine_coords is not None: - # TODO: Why doesn't LatLonCoordinates have a from_state method? - fine_coords = LatLonCoordinates( - lat=state["fine_coords"]["lat"], - lon=state["fine_coords"]["lon"], - ) - elif ( - static_inputs is not None - and static_inputs.fields - and "coords" in state["static_inputs"]["fields"][0] - ): - # Backward compat: old checkpoints with static inputs stored coords inside - # static_inputs fields[0]["coords"] - coords_state = state["static_inputs"]["fields"][0]["coords"] - fine_coords = LatLonCoordinates( - lat=coords_state["lat"], - lon=coords_state["lon"], - ) + static_inputs_state: dict[str, Any], + ) -> None: + if cls._legacy_coord_in_state(static_inputs_state): + coords = static_inputs_state["fields"][0]["coords"] + static_inputs_state["coords"] = coords else: raise ValueError( - "No fine coordinates found in checkpoint state and no static inputs " - " were available to infer them. fine_coords must be serialized with the" - " checkpoint to resume from training." + "No coordinates found in static_inputs state. 
" + "Must have a 'coords' field, or at least one field in static_inputs" + " must have a 'coords' field to infer coordinates from for model " + " reconstruction. Coordinates must be serialized with the " + "checkpoint to reload from state during training." ) + @classmethod + def from_state( + cls, + state: Mapping[str, Any], + ) -> "DiffusionModel": + """ + Model-level from state is used to reconstruct the full model during training via + loading a checkpoint. This requires recent checkpoint state fields, as + opposed to the CheckpointModelConfig where we build with the option to provide + static_inputs and fine_coordinates for backwards compatibility. + """ + static_inputs_state: dict = state.get("static_inputs") or {} + + if "coords" not in static_inputs_state: + cls._update_state_with_legacy_coords(static_inputs_state) + + static_inputs = StaticInputs.from_state(static_inputs_state) + config = DiffusionModelConfig.from_state(state["config"]) model = config.build( state["coarse_shape"], state["downscale_factor"], static_inputs=static_inputs, - fine_coords=fine_coords, ) model.module.load_state_dict(state["module"], strict=True) return model @@ -754,9 +699,7 @@ def _checkpoint(self) -> Mapping[str, Any]: for name in checkpoint_data["model"]["config"]["out_names"] ] # backwards compatibility for models before static inputs serialization - checkpoint_data["model"].setdefault("static_inputs", None) - # backwards compatibility for models before fine_coords serialization - checkpoint_data["model"].setdefault("fine_coords", None) + checkpoint_data["model"].setdefault("static_inputs", {}) self._checkpoint_data = checkpoint_data self._checkpoint_is_loaded = True @@ -768,68 +711,53 @@ def _checkpoint(self) -> Mapping[str, Any]: def build( self, ) -> DiffusionModel: - static_inputs: StaticInputs | None - if self._checkpoint["model"]["static_inputs"] is not None: - if self.static_inputs is not None: - raise ValueError( - "The model checkpoint already has static inputs from 
training. " - "static_inputs should not be provided in checkpoint model config." - "static inputs from training." - ) - static_inputs = StaticInputs.from_state( - self._checkpoint["model"]["static_inputs"] - ) - elif self.static_inputs is not None: - static_inputs = load_static_inputs(self.static_inputs) - else: - static_inputs = None - - fine_coords: LatLonCoordinates - has_fine_coords = self._checkpoint["model"]["fine_coords"] is not None - has_static_input_coords = ( - self._checkpoint["model"]["static_inputs"] is not None - and self._checkpoint["model"]["static_inputs"]["fields"][0].get("coords") - is not None + checkpoint_model = self._checkpoint["model"] + checkpoint_static_inputs = checkpoint_model["static_inputs"] + + has_coords_in_checkpoint = DiffusionModel.coords_in_checkpoint( + checkpoint_static_inputs ) - # TODO: simplify with static input refactor that deisables empty StaticInputs - if ( - has_fine_coords or has_static_input_coords - ) and self.fine_coordinates_path is not None: + if has_coords_in_checkpoint and self.fine_coordinates_path is not None: raise ValueError( "The model checkpoint already has fine coordinates from training. " "fine_coordinates_path should not be provided in checkpoint model " "config." ) - elif has_fine_coords: - fine_coords_state = self._checkpoint["model"]["fine_coords"] - fine_coords = LatLonCoordinates( - lat=fine_coords_state["lat"], - lon=fine_coords_state["lon"], - ) - elif has_static_input_coords: - coords_state = self._checkpoint["model"]["static_inputs"]["fields"][0][ - "coords" - ] - fine_coords = LatLonCoordinates( - lat=coords_state["lat"], - lon=coords_state["lon"], - ) - elif self.fine_coordinates_path is not None: - fine_coords = load_fine_coords_from_path(self.fine_coordinates_path) - else: + elif checkpoint_static_inputs and self.static_inputs is not None: raise ValueError( - "No fine coordinates found in checkpoint state and no static inputs " - " were available to infer them. 
fine_coordinates_path must be provided " - "in the checkpoint model configuration to load fine coordinates from " - "the provided path." + "The model checkpoint already has static inputs from training. " + "static_inputs should not be provided in checkpoint model config." ) + if has_coords_in_checkpoint: + if DiffusionModel._legacy_coord_in_state(checkpoint_static_inputs): + DiffusionModel._update_state_with_legacy_coords( + checkpoint_static_inputs + ) + static_inputs = StaticInputs.from_state(checkpoint_static_inputs) + else: + if self.fine_coordinates_path is None: + raise ValueError( + "No fine coordinates found in checkpoint state. " + "fine_coordinates_path must be provided in the checkpoint model " + "configuration to load this model." + ) + fine_coords = load_fine_coords_from_path(self.fine_coordinates_path) + if checkpoint_static_inputs: + checkpoint_static_inputs["coords"] = fine_coords.get_state() + static_inputs = StaticInputs.from_state(checkpoint_static_inputs) + elif self.static_inputs is not None: + static_inputs = load_static_inputs( + self.static_inputs, coords=fine_coords + ) + else: + static_inputs = StaticInputs(fields=[], coords=fine_coords) + model = _CheckpointModelConfigSelector.from_state( self._checkpoint["model"]["config"] ).build( coarse_shape=self._checkpoint["model"]["coarse_shape"], downscale_factor=self._checkpoint["model"]["downscale_factor"], - fine_coords=fine_coords, rename=self._rename, static_inputs=static_inputs, ) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 1c4254bcf..575cc5b91 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -125,13 +125,10 @@ def make_paired_batch_data( def make_static_inputs(fine_shape: tuple[int, int]) -> StaticInputs: - """Create StaticInputs for given shape.""" + """Create StaticInputs with one field and matching coords for the given shape.""" return StaticInputs( - fields=[ - StaticInput( - torch.ones(*fine_shape, 
device=get_device()), - ) - ] + fields=[StaticInput(torch.ones(*fine_shape, device=get_device()))], + coords=make_fine_coords(fine_shape), ) @@ -148,14 +145,13 @@ def test_module_serialization(tmp_path): coarse_shape = (8, 16) fine_shape = (16, 32) static_inputs = make_static_inputs(fine_shape) - fine_coords = make_fine_coords(fine_shape) + fine_coords = static_inputs.coords model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, predict_residual=True, use_fine_topography=False, static_inputs=static_inputs, - fine_coords=fine_coords, ) model_from_state = DiffusionModel.from_state( model.get_state(), @@ -190,43 +186,31 @@ def test_module_serialization(tmp_path): assert torch.equal(model_from_disk.fine_coords.lon.cpu(), fine_coords.lon.cpu()) -def test_from_state_backward_compat_fine_topography(): - coarse_shape = (8, 16) - fine_shape = (16, 32) - downscale_factor = 2 - static_inputs = make_static_inputs(fine_shape) - fine_coords = make_fine_coords(fine_shape) +def test_from_state_raises_when_no_coords_available(): + """from_state must raise when the checkpoint has neither top-level coords + nor legacy per-field coords to migrate from.""" model = _get_diffusion_model( - coarse_shape=coarse_shape, - downscale_factor=downscale_factor, - predict_residual=True, - use_fine_topography=True, - static_inputs=static_inputs, - fine_coords=fine_coords, + coarse_shape=(8, 16), downscale_factor=2, use_fine_topography=False ) - - # Simulate old checkpoint format: static_inputs not serialized, fine_coords still - # present state = model.get_state() + # Wipe coords entirely — no way to recover them state["static_inputs"] = None + with pytest.raises(ValueError): + DiffusionModel.from_state(state) - # Should load correctly via the elif use_fine_topography branch (+1 channel) - model_from_old_state = DiffusionModel.from_state(state) - assert model_from_old_state.static_inputs is None - assert torch.equal( - model_from_old_state.fine_coords.lat.cpu(), 
fine_coords.lat.cpu() - ) - assert all( - torch.equal(p1, p2) - for p1, p2 in zip( - model.module.parameters(), model_from_old_state.module.parameters() - ) - ) - # At runtime, omitting static inputs must raise a clear error - batch = get_mock_paired_batch([2, *coarse_shape], [2, *fine_shape]) +def test_generate_raises_when_no_static_fields_but_topography_required(): + coarse_shape = (8, 16) + fine_shape = (16, 32) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + use_fine_topography=True, + static_inputs=StaticInputs(fields=[], coords=make_fine_coords(fine_shape)), + ) + batch = make_paired_batch_data(coarse_shape, fine_shape) with pytest.raises(ValueError, match="Static inputs must be provided"): - model_from_old_state.generate_on_batch(batch) + model.generate_on_batch(batch) def test_from_state_backward_compat_migrates_fine_coords_from_old_static_inputs(): @@ -236,18 +220,17 @@ def test_from_state_backward_compat_migrates_fine_coords_from_old_static_inputs( fine_shape = (16, 32) downscale_factor = 2 static_inputs = make_static_inputs(fine_shape) - fine_coords = make_fine_coords(fine_shape) + fine_coords = static_inputs.coords model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, - fine_coords=fine_coords, ) state = model.get_state() - # Simulate old format: fine_coords absent, but static_inputs fields have coords - del state["fine_coords"] + # Simulate old format: coords embedded in fields, not in static_inputs state + del state["static_inputs"]["coords"] state["static_inputs"]["fields"][0]["coords"] = fine_coords.get_state() model_from_old_state = DiffusionModel.from_state(state) @@ -265,19 +248,18 @@ def _get_diffusion_model( downscale_factor, predict_residual=True, use_fine_topography=True, - static_inputs=None, - fine_coords: LatLonCoordinates | None = None, + static_inputs: StaticInputs | None = None, ): 
normalizer = PairedNormalizationConfig( NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), ) - if fine_coords is None: + if static_inputs is None: fine_shape = ( coarse_shape[0] * downscale_factor, coarse_shape[1] * downscale_factor, ) - fine_coords = make_fine_coords(fine_shape) + static_inputs = StaticInputs(fields=[], coords=make_fine_coords(fine_shape)) return DiffusionModelConfig( module=DiffusionModuleRegistrySelector( @@ -299,7 +281,6 @@ def _get_diffusion_model( coarse_shape, downscale_factor, static_inputs=static_inputs, - fine_coords=fine_coords, ) @@ -314,17 +295,14 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph static_inputs = make_static_inputs(fine_shape) batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) else: - static_inputs = None - batch = get_mock_paired_batch( - [batch_size, *coarse_shape], [batch_size, *fine_shape] - ) + static_inputs = StaticInputs(fields=[], coords=fine_coords) + batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, predict_residual=predict_residual, use_fine_topography=use_fine_topography, static_inputs=static_inputs, - fine_coords=fine_coords, ) assert model._get_fine_shape(coarse_shape) == fine_shape @@ -448,14 +426,11 @@ def test_model_error_cases(): ).build( coarse_shape, upscaling_factor, - fine_coords=make_fine_coords(fine_shape), - ) - batch = get_mock_paired_batch( - [batch_size, *coarse_shape], [batch_size, *fine_shape] + static_inputs=StaticInputs(fields=[], coords=make_fine_coords(fine_shape)), ) + batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) - # missing fine topography when model requires it - batch.fine.topography = None + # missing fine topography when model requires it — empty fields raises ValueError with pytest.raises(ValueError): model.generate_on_batch(batch) @@ -465,14 +440,12 
@@ def test_DiffusionModel_generate_on_batch_no_target(): coarse_shape = (16, 16) downscale_factor = 2 static_inputs = make_static_inputs(fine_shape) - fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, - fine_coords=fine_coords, ) batch_size = 2 @@ -505,7 +478,6 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): full_fine_size = 64 full_fine_shape = (full_fine_size, full_fine_size) static_inputs = make_static_inputs(full_fine_shape) - fine_coords = make_fine_coords(full_fine_shape) # need to build with static inputs to get the correct n_in_channels model = _get_diffusion_model( coarse_shape=coarse_shape, @@ -513,7 +485,6 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, - fine_coords=fine_coords, ) n_ensemble = 2 batch_size = 2 @@ -565,7 +536,7 @@ def test_lognorm_noise_backwards_compatibility(): model = model_config.build( (32, 32), 2, - fine_coords=make_fine_coords((64, 64)), + static_inputs=StaticInputs(fields=[], coords=make_fine_coords((64, 64))), ) state = model.get_state() @@ -609,12 +580,11 @@ def test_get_fine_coords_for_batch(): coarse_shape = (8, 16) fine_shape = (16, 32) downscale_factor = 2 - static_inputs = make_static_inputs(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, use_fine_topography=True, - static_inputs=static_inputs, + static_inputs=make_static_inputs(fine_shape), ) # Build a batch covering a spatial patch: middle 4 coarse lats and 8 coarse lons. 
@@ -636,17 +606,15 @@ def test_checkpoint_model_build_with_fine_coordinates_path(tmp_path): """Old-format checkpoint (no fine_coords key, no coords in static_inputs) should load correctly when fine_coordinates_path is provided.""" coarse_shape = (8, 16) - fine_shape = (16, 32) - fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, use_fine_topography=False, - fine_coords=fine_coords, ) - # Simulate old checkpoint: no fine_coords stored + fine_coords = model.fine_coords + # Simulate old checkpoint: no coords in static_inputs state state = model.get_state() - del state["fine_coords"] + del state["static_inputs"]["coords"] checkpoint_path = tmp_path / "test.ckpt" torch.save({"model": state}, checkpoint_path) @@ -681,14 +649,12 @@ def test_checkpoint_model_build_raises_when_checkpoint_has_static_inputs(tmp_pat coarse_shape = (8, 16) fine_shape = (16, 32) static_inputs = make_static_inputs(fine_shape) - fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, - fine_coords=fine_coords, ) checkpoint_path = tmp_path / "test.ckpt" torch.save({"model": model.get_state()}, checkpoint_path) diff --git a/fme/downscaling/test_predict.py b/fme/downscaling/test_predict.py index 8101b293e..50411a0fd 100644 --- a/fme/downscaling/test_predict.py +++ b/fme/downscaling/test_predict.py @@ -8,7 +8,7 @@ from fme.core.normalizer import NormalizationConfig from fme.core.testing.wandb import mock_wandb from fme.downscaling import predict -from fme.downscaling.data import load_static_inputs +from fme.downscaling.data import StaticInputs, load_static_inputs from fme.downscaling.models import ( DiffusionModelConfig, PairedNormalizationConfig, @@ -120,8 +120,9 @@ def test_predictor_runs(tmp_path, very_fast_only: bool): model = model_config.build( coarse_shape=coarse_shape, 
downscale_factor=downscale_factor, - static_inputs=load_static_inputs({"HGTsfc": fine_data_path}), - fine_coords=fine_coords, + static_inputs=load_static_inputs( + {"HGTsfc": fine_data_path}, coords=fine_coords + ), ) with open(predictor_config_path) as f: predictor_config = yaml.safe_load(f) @@ -167,7 +168,9 @@ def test_predictor_renaming( coarse_shape, downscale_factor, use_fine_topography=False ) model = model_config.build( - coarse_shape=coarse_shape, downscale_factor=2, fine_coords=fine_coords + coarse_shape=coarse_shape, + downscale_factor=2, + static_inputs=StaticInputs(fields=[], coords=fine_coords), ) with open(predictor_config_path) as f: predictor_config = yaml.safe_load(f) diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index 5a0c25666..d91b8b4f4 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -393,7 +393,7 @@ class TrainerConfig: experiment_dir: str save_checkpoints: bool logging: LoggingConfig - static_inputs: dict[str, str] | None = None + static_inputs: dict[str, str] = dataclasses.field(default_factory=dict) ema: EMAConfig = dataclasses.field(default_factory=EMAConfig) validate_using_ema: bool = False generate_n_samples: int = 1 @@ -421,8 +421,6 @@ def checkpoint_dir(self) -> str: return os.path.join(self.experiment_dir, "checkpoints") def build(self) -> Trainer: - static_inputs = load_static_inputs(self.static_inputs) - train_data: PairedGriddedData = self.train_data.build( train=True, requirements=self.model.data_requirements, @@ -431,6 +429,10 @@ def build(self) -> Trainer: train=False, requirements=self.model.data_requirements, ) + static_inputs = load_static_inputs( + self.static_inputs, coords=train_data.fine_coords + ) + if self.coarse_patch_extent_lat and self.coarse_patch_extent_lon: model_coarse_shape = ( self.coarse_patch_extent_lat, @@ -443,7 +445,6 @@ def build(self) -> Trainer: model_coarse_shape, train_data.downscale_factor, static_inputs=static_inputs, - fine_coords=train_data.fine_coords, ) 
optimization = self.optimization.build( From 1fac745c4952266de060906b627186fa442ad9db Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Thu, 19 Mar 2026 11:28:17 -0700 Subject: [PATCH 18/27] Remove from_state docstring --- fme/downscaling/data/static.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index 34d73baa2..b9c1372c3 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -132,14 +132,6 @@ def get_state(self) -> dict: @classmethod def from_state(cls, state: dict) -> "StaticInputs": - """Reconstruct StaticInputs from a state dict. - - Args: - state: State dict from get_state(). - coords: Override coordinates. If None, reads coords from the state dict. - Pass explicitly when loading old-format checkpoints that stored coords - outside of the StaticInputs state. - """ return cls( fields=[ StaticInput.from_state(field_state) for field_state in state["fields"] From 3ef1613f4eb98dc365827e6cb0025d3c1ddf049d Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Thu, 19 Mar 2026 15:00:35 -0700 Subject: [PATCH 19/27] Move all state loading cases into static inputs code --- fme/downscaling/data/__init__.py | 7 +- fme/downscaling/data/static.py | 89 ++++++++++++++ fme/downscaling/data/test_static.py | 178 +++++++++++++++++++++------- fme/downscaling/models.py | 107 +---------------- 4 files changed, 233 insertions(+), 148 deletions(-) diff --git a/fme/downscaling/data/__init__.py b/fme/downscaling/data/__init__.py index 5790cdc11..334e1a949 100644 --- a/fme/downscaling/data/__init__.py +++ b/fme/downscaling/data/__init__.py @@ -12,7 +12,12 @@ PairedBatchItem, PairedGriddedData, ) -from .static import StaticInput, StaticInputs, load_static_inputs +from .static import ( + StaticInput, + StaticInputs, + load_fine_coords_from_path, + load_static_inputs, +) from .utils import ( BatchedLatLonCoordinates, ClosedInterval, diff --git a/fme/downscaling/data/static.py 
b/fme/downscaling/data/static.py index b9c1372c3..6f34b4a26 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -76,6 +76,46 @@ def _get_normalized_static_input(path: str, field_name: str): ) +def _has_legacy_coords_in_state(state: dict) -> bool: + return "fields" in state and state["fields"] and "coords" in state["fields"][0] + + +def _sync_state_coordinates(state: dict) -> dict: + # if necessary adjusts legacy coordinate to expected + # format for state loading + state = state.copy() + if _has_legacy_coords_in_state(state): + state["coords"] = state["fields"][0]["coords"] + return state + + +def _has_coords_in_state(state: dict) -> bool: + if "coords" in state or _has_legacy_coords_in_state(state): + return True + else: + return False + + +def load_fine_coords_from_path(path: str) -> LatLonCoordinates: + if path.endswith(".zarr"): + ds = xr.open_zarr(path) + else: + ds = xr.open_dataset(path) + lat_name = next((n for n in ["lat", "latitude", "grid_yt"] if n in ds.coords), None) + lon_name = next( + (n for n in ["lon", "longitude", "grid_xt"] if n in ds.coords), None + ) + if lat_name is None or lon_name is None: + raise ValueError( + f"Could not find lat/lon coordinates in {path}. " + "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'." + ) + return LatLonCoordinates( + lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), + lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), + ) + + @dataclasses.dataclass class StaticInputs: fields: list[StaticInput] @@ -132,6 +172,13 @@ def get_state(self) -> dict: @classmethod def from_state(cls, state: dict) -> "StaticInputs": + if not _has_coords_in_state(state): + raise ValueError( + "No coordinates found in state for StaticInputs. Load with " "from_state_backwards_compatible if loading from a checkpoint " + "saved prior to current coordinate serialization format." 
+ ) + state = _sync_state_coordinates(state) return cls( fields=[ StaticInput.from_state(field_state) for field_state in state["fields"] @@ -142,6 +189,48 @@ def from_state(cls, state: dict) -> "StaticInputs": ), ) + @classmethod + def from_state_backwards_compatible( + cls, + state: dict, + static_inputs_config: dict[str, str], + fine_coordinates_path: str | None, + ) -> "StaticInputs": + if state and static_inputs_config: + raise ValueError( + "Checkpoint contains static inputs but static_inputs_config is " + "also provided. Backwards compatibility loading only supports " + "a single source of StaticInputs info." + ) + + if fine_coordinates_path and _has_coords_in_state(state): + raise ValueError( + "State contains coordinates but fine_coordinates_path is also provided." + " Only one source of coordinate info can be used for backwards " + "compatibility loading of StaticInputs." + ) + elif not _has_coords_in_state(state) and not fine_coordinates_path: + raise ValueError( + "No coordinates found in state and no fine_coordinates_path provided. " + "Cannot load StaticInputs without coordinates." 
+ ) + + # All compatibility cases: + # Serialized StaticInputs exist, which always had coordinates stored + # No serialized static inputs or specified inputs, load coordinates + # Specified static input fields and specified coordinates + + if _has_coords_in_state(state): + return cls.from_state(state) + else: + assert fine_coordinates_path is not None # for type checker + coords = load_fine_coords_from_path(fine_coordinates_path) + + if static_inputs_config: + return load_static_inputs(static_inputs_config, coords) + else: + return cls(fields=[], coords=coords) + def load_static_inputs( static_inputs_config: dict[str, str], coords: LatLonCoordinates diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index fd6465e17..f20dfe667 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -1,21 +1,33 @@ +import numpy as np import pytest import torch +import xarray as xr from .static import StaticInput, StaticInputs -@pytest.mark.parametrize( - "init_args", - [ - pytest.param( - [torch.randn((1, 2, 2))], - id="3d_data", - ), - ], -) -def test_Topography_error_cases(init_args): +def _make_coords(n=4): + from fme.core.coordinates import LatLonCoordinates + + return LatLonCoordinates( + lat=torch.arange(n, dtype=torch.float32), + lon=torch.arange(n, dtype=torch.float32), + ) + + +def _write_coords_netcdf(path, n=4): + xr.Dataset( + coords={ + "lat": np.arange(n, dtype=np.float32), + "lon": np.arange(n, dtype=np.float32), + } + ).to_netcdf(path) + + +def test_StaticInput_error_cases(): + data = torch.randn(1, 2, 2) with pytest.raises(ValueError): - StaticInput(*init_args) + StaticInput(data=data) def test_subset(): @@ -29,47 +41,121 @@ def test_subset(): def test_StaticInputs_serialize(): - from fme.core.coordinates import LatLonCoordinates - - data = torch.arange(16, dtype=torch.float32).reshape(4, 4) - coords = LatLonCoordinates( - lat=torch.arange(4, dtype=torch.float32), - lon=torch.arange(4, 
dtype=torch.float32), + dim_len = 4 + data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( + dim_len, dim_len + ) + coords = _make_coords(n=dim_len) + static_inputs = StaticInputs( + [StaticInput(data), StaticInput(data * -1.0)], coords=coords ) - topography = StaticInput(data) - land_frac = StaticInput(data * -1.0) - static_inputs = StaticInputs([topography, land_frac], coords=coords) state = static_inputs.get_state() - # Verify coords are stored at the top level, not inside each field assert "coords" in state - assert "coords" not in state["fields"][0] - static_inputs_reconstructed = StaticInputs.from_state(state) - assert static_inputs_reconstructed[0].data.equal(static_inputs[0].data) - assert static_inputs_reconstructed[1].data.equal(static_inputs[1].data) + reconstructed = StaticInputs.from_state(state) + assert reconstructed[0].data.equal(static_inputs[0].data) + assert reconstructed[1].data.equal(static_inputs[1].data) + assert torch.equal(reconstructed.coords.lat, static_inputs.coords.lat) + assert torch.equal(reconstructed.coords.lon, static_inputs.coords.lon) -def test_StaticInputs_serialize_backward_compat_with_coords(): - """from_state should silently ignore 'coords' key in fields for old state dicts.""" - from fme.core.coordinates import LatLonCoordinates - +def test_StaticInputs_from_state_raises_on_missing_coords(): data = torch.arange(16, dtype=torch.float32).reshape(4, 4) - coords = LatLonCoordinates( - lat=torch.arange(4, dtype=torch.float32), - lon=torch.arange(4, dtype=torch.float32), + with pytest.raises(ValueError, match="No coordinates"): + StaticInputs.from_state({"fields": [{"data": data}]}) + + +def test_StaticInputs_from_state_legacy_coords_in_fields(): + """from_state handles old format where coords were stored only inside each field.""" + dim_len = 4 + data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( + dim_len, dim_len ) - # Simulate old state dict format that included coords inside fields. 
- # from_state should silently ignore extra keys (like 'coords') in field dicts. + coords = _make_coords(n=dim_len) old_state = { - "fields": [ - { - "data": data, - "coords": { - "lat": torch.arange(4, dtype=torch.float32), - "lon": torch.arange(4, dtype=torch.float32), - }, - } - ], - "coords": coords.get_state(), + "fields": [{"data": data, "coords": {"lat": coords.lat, "lon": coords.lon}}], } - static_inputs = StaticInputs.from_state(old_state) - assert torch.equal(static_inputs[0].data, data) + result = StaticInputs.from_state(old_state) + assert torch.equal(result[0].data, data) + assert torch.equal(result.coords.lat, coords.lat) + assert torch.equal(result.coords.lon, coords.lon) + + +def test_from_state_backwards_compatible_has_coords(): + """When state has coords, delegates to from_state.""" + data = torch.arange(16, dtype=torch.float32).reshape(4, 4) + coords = _make_coords() + state = StaticInputs([StaticInput(data)], coords=coords).get_state() + result = StaticInputs.from_state_backwards_compatible( + state=state, static_inputs_config={}, fine_coordinates_path=None + ) + assert torch.equal(result[0].data, data) + + +def test_from_state_backwards_compatible_no_state_no_config(tmp_path): + """Old checkpoint with no static inputs: empty StaticInputs with loaded coords.""" + coords_path = str(tmp_path / "coords.nc") + dim_len = 4 + _write_coords_netcdf(coords_path, n=dim_len) + result = StaticInputs.from_state_backwards_compatible( + state={}, static_inputs_config={}, fine_coordinates_path=coords_path + ) + assert result.fields == [] + assert result.coords.lat.shape == (dim_len,) + + +def test_from_state_backwards_compatible_with_config(tmp_path): + """Old checkpoint with static_inputs_config: loads fields from paths.""" + coords_path = str(tmp_path / "coords.nc") + dim_len = 4 + _write_coords_netcdf(coords_path, n=dim_len) + field_data = np.random.rand(dim_len, dim_len).astype(np.float32) + field_path = str(tmp_path / "field.nc") + xr.Dataset({"HGTsfc": 
(["lat", "lon"], field_data)}).to_netcdf(field_path) + result = StaticInputs.from_state_backwards_compatible( + state={}, + static_inputs_config={"HGTsfc": field_path}, + fine_coordinates_path=coords_path, + ) + assert len(result.fields) == 1 + + +def test_from_state_backwards_compatible_raises_state_and_config(): + """ + Errors if checkpoint state has fields and static_inputs_config is also provided. + """ + dim_len = 4 + data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( + dim_len, dim_len + ) + coords = _make_coords(n=dim_len) + state = StaticInputs([StaticInput(data)], coords=coords).get_state() + with pytest.raises(ValueError, match="static_inputs_config"): + StaticInputs.from_state_backwards_compatible( + state=state, + static_inputs_config={"HGTsfc": "some/path"}, + fine_coordinates_path=None, + ) + + +def test_from_state_backwards_compatible_raises_coords_in_state_and_path(): + """Errors if state has coords and fine_coordinates_path is also provided.""" + dim_len = 4 + data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( + dim_len, dim_len + ) + coords = _make_coords(n=dim_len) + state = StaticInputs([StaticInput(data)], coords=coords).get_state() + with pytest.raises(ValueError, match="fine_coordinates_path"): + StaticInputs.from_state_backwards_compatible( + state=state, + static_inputs_config={}, + fine_coordinates_path="some/path", + ) + + +def test_from_state_backwards_compatible_raises_no_coords(): + """Errors if no coords in state and no fine_coordinates_path.""" + with pytest.raises(ValueError, match="No coordinates"): + StaticInputs.from_state_backwards_compatible( + state={}, static_inputs_config={}, fine_coordinates_path=None + ) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 421cd84f4..078befbbd 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -5,7 +5,6 @@ import dacite import torch -import xarray as xr from fme.core.coordinates import LatLonCoordinates from 
fme.core.device import get_device @@ -20,7 +19,6 @@ PairedBatchData, StaticInputs, adjust_fine_coord_range, - load_static_inputs, ) from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector @@ -552,39 +550,6 @@ def get_state(self) -> Mapping[str, Any]: "static_inputs": self.static_inputs.get_state(), } - @staticmethod - def _legacy_coord_in_state( - static_inputs_state: Mapping[str, Any], - ) -> bool: - return ( - "fields" in static_inputs_state - and static_inputs_state["fields"] - and "coords" in static_inputs_state["fields"][0] - ) - - @classmethod - def coords_in_checkpoint(cls, static_inputs_state: Mapping[str, Any]) -> bool: - return "coords" in static_inputs_state or cls._legacy_coord_in_state( - static_inputs_state - ) - - @classmethod - def _update_state_with_legacy_coords( - cls, - static_inputs_state: dict[str, Any], - ) -> None: - if cls._legacy_coord_in_state(static_inputs_state): - coords = static_inputs_state["fields"][0]["coords"] - static_inputs_state["coords"] = coords - else: - raise ValueError( - "No coordinates found in static_inputs state. " - "Must have a 'coords' field, or at least one field in static_inputs" - " must have a 'coords' field to infer coordinates from for model " - " reconstruction. Coordinates must be serialized with the " - "checkpoint to reload from state during training." - ) - @classmethod def from_state( cls, @@ -597,10 +562,6 @@ def from_state( static_inputs and fine_coordinates for backwards compatibility. 
""" static_inputs_state: dict = state.get("static_inputs") or {} - - if "coords" not in static_inputs_state: - cls._update_state_with_legacy_coords(static_inputs_state) - static_inputs = StaticInputs.from_state(static_inputs_state) config = DiffusionModelConfig.from_state(state["config"]) model = config.build( @@ -623,26 +584,6 @@ def from_state(cls, state: Mapping[str, Any]) -> DiffusionModelConfig: ).wrapper -def load_fine_coords_from_path(path: str) -> LatLonCoordinates: - if path.endswith(".zarr"): - ds = xr.open_zarr(path) - else: - ds = xr.open_dataset(path) - lat_name = next((n for n in ["lat", "latitude", "grid_yt"] if n in ds.coords), None) - lon_name = next( - (n for n in ["lon", "longitude", "grid_xt"] if n in ds.coords), None - ) - if lat_name is None or lon_name is None: - raise ValueError( - f"Could not find lat/lon coordinates in {path}. " - "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'." - ) - return LatLonCoordinates( - lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), - lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), - ) - - @dataclasses.dataclass class CheckpointModelConfig: """ @@ -698,8 +639,6 @@ def _checkpoint(self) -> Mapping[str, Any]: self._rename.get(name, name) for name in checkpoint_data["model"]["config"]["out_names"] ] - # backwards compatibility for models before static inputs serialization - checkpoint_data["model"].setdefault("static_inputs", {}) self._checkpoint_data = checkpoint_data self._checkpoint_is_loaded = True @@ -711,48 +650,14 @@ def _checkpoint(self) -> Mapping[str, Any]: def build( self, ) -> DiffusionModel: - checkpoint_model = self._checkpoint["model"] - checkpoint_static_inputs = checkpoint_model["static_inputs"] + checkpoint_model: dict = self._checkpoint["model"] + checkpoint_static_inputs = checkpoint_model.get("static_inputs", {}) - has_coords_in_checkpoint = DiffusionModel.coords_in_checkpoint( - checkpoint_static_inputs + static_inputs = 
StaticInputs.from_state_backwards_compatible( + state=checkpoint_static_inputs, + static_inputs_config=self.static_inputs or {}, + fine_coordinates_path=self.fine_coordinates_path, ) - if has_coords_in_checkpoint and self.fine_coordinates_path is not None: - raise ValueError( - "The model checkpoint already has fine coordinates from training. " - "fine_coordinates_path should not be provided in checkpoint model " - "config." - ) - elif checkpoint_static_inputs and self.static_inputs is not None: - raise ValueError( - "The model checkpoint already has static inputs from training. " - "static_inputs should not be provided in checkpoint model config." - ) - - if has_coords_in_checkpoint: - if DiffusionModel._legacy_coord_in_state(checkpoint_static_inputs): - DiffusionModel._update_state_with_legacy_coords( - checkpoint_static_inputs - ) - static_inputs = StaticInputs.from_state(checkpoint_static_inputs) - else: - if self.fine_coordinates_path is None: - raise ValueError( - "No fine coordinates found in checkpoint state. " - "fine_coordinates_path must be provided in the checkpoint model " - "configuration to load this model." 
- ) - fine_coords = load_fine_coords_from_path(self.fine_coordinates_path) - if checkpoint_static_inputs: - checkpoint_static_inputs["coords"] = fine_coords.get_state() - static_inputs = StaticInputs.from_state(checkpoint_static_inputs) - elif self.static_inputs is not None: - static_inputs = load_static_inputs( - self.static_inputs, coords=fine_coords - ) - else: - static_inputs = StaticInputs(fields=[], coords=fine_coords) - model = _CheckpointModelConfigSelector.from_state( self._checkpoint["model"]["config"] ).build( From 36e810102331e11b7a2d0454a0f56381d89b22e1 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Thu, 19 Mar 2026 15:06:16 -0700 Subject: [PATCH 20/27] Remove unused function from fme.downscaling.data --- fme/downscaling/data/__init__.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/fme/downscaling/data/__init__.py b/fme/downscaling/data/__init__.py index 334e1a949..5790cdc11 100644 --- a/fme/downscaling/data/__init__.py +++ b/fme/downscaling/data/__init__.py @@ -12,12 +12,7 @@ PairedBatchItem, PairedGriddedData, ) -from .static import ( - StaticInput, - StaticInputs, - load_fine_coords_from_path, - load_static_inputs, -) +from .static import StaticInput, StaticInputs, load_static_inputs from .utils import ( BatchedLatLonCoordinates, ClosedInterval, From 3f95457ad297c719d5d7359bab8b40941673f73e Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Thu, 19 Mar 2026 15:10:22 -0700 Subject: [PATCH 21/27] fine_coords -> full_fine_coords --- fme/downscaling/data/test_static.py | 32 +++++++++++- fme/downscaling/models.py | 6 +-- fme/downscaling/predictors/composite.py | 2 +- fme/downscaling/test_models.py | 67 ++++++++----------------- 4 files changed, 57 insertions(+), 50 deletions(-) diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index f20dfe667..36a68c1ff 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -3,7 +3,7 @@ import torch import xarray as 
xr -from .static import StaticInput, StaticInputs +from .static import StaticInput, StaticInputs, load_fine_coords_from_path def _make_coords(n=4): @@ -159,3 +159,33 @@ def test_from_state_backwards_compatible_raises_no_coords(): StaticInputs.from_state_backwards_compatible( state={}, static_inputs_config={}, fine_coordinates_path=None ) + + +@pytest.mark.parametrize( + "lat_name,lon_name", + [ + pytest.param("lat", "lon", id="standard"), + pytest.param("latitude", "longitude", id="long_names"), + pytest.param("grid_yt", "grid_xt", id="fv3_names"), + ], +) +def test_load_fine_coords_from_path(tmp_path, lat_name, lon_name): + lat = [0.0, 1.0, 2.0] + lon = [10.0, 20.0, 30.0, 40.0] + ds = xr.Dataset(coords={lat_name: lat, lon_name: lon}) + path = str(tmp_path / "coords.nc") + ds.to_netcdf(path) + + coords = load_fine_coords_from_path(path) + + assert torch.allclose(coords.lat, torch.tensor(lat, dtype=torch.float32)) + assert torch.allclose(coords.lon, torch.tensor(lon, dtype=torch.float32)) + + +def test_load_fine_coords_from_path_raises_on_missing_coords(tmp_path): + ds = xr.Dataset(coords={"x": [0.0, 1.0], "y": [10.0, 20.0]}) + path = str(tmp_path / "no_latlon.nc") + ds.to_netcdf(path) + + with pytest.raises(ValueError, match="Could not find lat/lon coordinates"): + load_fine_coords_from_path(path) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 078befbbd..187133895 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -313,7 +313,7 @@ def modules(self) -> torch.nn.ModuleList: return torch.nn.ModuleList([self.module]) @property - def fine_coords(self) -> LatLonCoordinates: + def full_fine_coords(self) -> LatLonCoordinates: return self.static_inputs.coords def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: @@ -325,13 +325,13 @@ def _get_subset_for_coarse_batch(self, batch: BatchData) -> StaticInputs: fine_lat_interval = adjust_fine_coord_range( batch.lat_interval, full_coarse_coord=coarse_lat, - 
full_fine_coord=self.fine_coords.lat, + full_fine_coord=self.full_fine_coords.lat, downscale_factor=self.downscale_factor, ) fine_lon_interval = adjust_fine_coord_range( batch.lon_interval, full_coarse_coord=coarse_lon, - full_fine_coord=self.fine_coords.lon, + full_fine_coord=self.full_fine_coords.lon, downscale_factor=self.downscale_factor, ) subset_static_inputs = self.static_inputs.subset( diff --git a/fme/downscaling/predictors/composite.py b/fme/downscaling/predictors/composite.py index 5d0a4be8f..3ce1c3e02 100644 --- a/fme/downscaling/predictors/composite.py +++ b/fme/downscaling/predictors/composite.py @@ -88,7 +88,7 @@ def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: @property def fine_coords(self): - return self.model.fine_coords + return self.model.full_fine_coords def _get_patches( self, coarse_yx_extent, fine_yx_extent diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 575cc5b91..14a0fe141 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -24,7 +24,6 @@ PairedNormalizationConfig, _repeat_batch_by_samples, _separate_interleaved_samples, - load_fine_coords_from_path, ) from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector from fme.downscaling.noise import LogNormalNoiseDistribution @@ -162,9 +161,13 @@ def test_module_serialization(tmp_path): model.module.parameters(), model_from_state.module.parameters() ) ) - assert model_from_state.fine_coords is not None - assert torch.equal(model_from_state.fine_coords.lat.cpu(), fine_coords.lat.cpu()) - assert torch.equal(model_from_state.fine_coords.lon.cpu(), fine_coords.lon.cpu()) + assert model_from_state.full_fine_coords is not None + assert torch.equal( + model_from_state.full_fine_coords.lat.cpu(), fine_coords.lat.cpu() + ) + assert torch.equal( + model_from_state.full_fine_coords.lon.cpu(), fine_coords.lon.cpu() + ) torch.save(model.get_state(), tmp_path / "test.ckpt") 
model_from_disk = DiffusionModel.from_state( @@ -181,9 +184,13 @@ def test_module_serialization(tmp_path): assert torch.equal( loaded_static_inputs.fields[0].data, static_inputs.fields[0].data ) - assert model_from_disk.fine_coords is not None - assert torch.equal(model_from_disk.fine_coords.lat.cpu(), fine_coords.lat.cpu()) - assert torch.equal(model_from_disk.fine_coords.lon.cpu(), fine_coords.lon.cpu()) + assert model_from_disk.full_fine_coords is not None + assert torch.equal( + model_from_disk.full_fine_coords.lat.cpu(), fine_coords.lat.cpu() + ) + assert torch.equal( + model_from_disk.full_fine_coords.lon.cpu(), fine_coords.lon.cpu() + ) def test_from_state_raises_when_no_coords_available(): @@ -234,12 +241,12 @@ def test_from_state_backward_compat_migrates_fine_coords_from_old_static_inputs( state["static_inputs"]["fields"][0]["coords"] = fine_coords.get_state() model_from_old_state = DiffusionModel.from_state(state) - assert model_from_old_state.fine_coords is not None + assert model_from_old_state.full_fine_coords is not None assert torch.equal( - model_from_old_state.fine_coords.lat.cpu(), fine_coords.lat.cpu() + model_from_old_state.full_fine_coords.lat.cpu(), fine_coords.lat.cpu() ) assert torch.equal( - model_from_old_state.fine_coords.lon.cpu(), fine_coords.lon.cpu() + model_from_old_state.full_fine_coords.lon.cpu(), fine_coords.lon.cpu() ) @@ -596,8 +603,8 @@ def test_get_fine_coords_for_batch(): result = model.get_fine_coords_for_batch(batch) - expected_lat = model.fine_coords.lat[4:12] - expected_lon = model.fine_coords.lon[8:24] + expected_lat = model.full_fine_coords.lat[4:12] + expected_lon = model.full_fine_coords.lon[8:24] assert torch.allclose(result.lat, expected_lat) assert torch.allclose(result.lon, expected_lon) @@ -611,7 +618,7 @@ def test_checkpoint_model_build_with_fine_coordinates_path(tmp_path): downscale_factor=2, use_fine_topography=False, ) - fine_coords = model.fine_coords + fine_coords = model.full_fine_coords # Simulate old 
checkpoint: no coords in static_inputs state state = model.get_state() del state["static_inputs"]["coords"] @@ -633,8 +640,8 @@ def test_checkpoint_model_build_with_fine_coordinates_path(tmp_path): fine_coordinates_path=str(coords_path), ).build() - assert torch.equal(loaded_model.fine_coords.lat.cpu(), fine_coords.lat.cpu()) - assert torch.equal(loaded_model.fine_coords.lon.cpu(), fine_coords.lon.cpu()) + assert torch.equal(loaded_model.full_fine_coords.lat.cpu(), fine_coords.lat.cpu()) + assert torch.equal(loaded_model.full_fine_coords.lon.cpu(), fine_coords.lon.cpu()) def test_checkpoint_config_topography_raises(): @@ -665,33 +672,3 @@ def test_checkpoint_model_build_raises_when_checkpoint_has_static_inputs(tmp_pat ) with pytest.raises(ValueError): config.build() - - -@pytest.mark.parametrize( - "lat_name,lon_name", - [ - pytest.param("lat", "lon", id="standard"), - pytest.param("latitude", "longitude", id="long_names"), - pytest.param("grid_yt", "grid_xt", id="fv3_names"), - ], -) -def test_load_fine_coords_from_path(tmp_path, lat_name, lon_name): - lat = [0.0, 1.0, 2.0] - lon = [10.0, 20.0, 30.0, 40.0] - ds = xr.Dataset(coords={lat_name: lat, lon_name: lon}) - path = str(tmp_path / "coords.nc") - ds.to_netcdf(path) - - coords = load_fine_coords_from_path(path) - - assert torch.allclose(coords.lat, torch.tensor(lat, dtype=torch.float32)) - assert torch.allclose(coords.lon, torch.tensor(lon, dtype=torch.float32)) - - -def test_load_fine_coords_from_path_raises_on_missing_coords(tmp_path): - ds = xr.Dataset(coords={"x": [0.0, 1.0], "y": [10.0, 20.0]}) - path = str(tmp_path / "no_latlon.nc") - ds.to_netcdf(path) - - with pytest.raises(ValueError, match="Could not find lat/lon coordinates"): - load_fine_coords_from_path(path) From 6c9fcf78606d5ff226dd64dba6f53af814e3db6a Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Thu, 19 Mar 2026 15:10:49 -0700 Subject: [PATCH 22/27] Remove duplicated tests from models.py --- fme/downscaling/test_models.py | 30 
------------------------------ 1 file changed, 30 deletions(-) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 14a0fe141..1ebf6f278 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -642,33 +642,3 @@ def test_checkpoint_model_build_with_fine_coordinates_path(tmp_path): assert torch.equal(loaded_model.full_fine_coords.lat.cpu(), fine_coords.lat.cpu()) assert torch.equal(loaded_model.full_fine_coords.lon.cpu(), fine_coords.lon.cpu()) - - -def test_checkpoint_config_topography_raises(): - with pytest.raises(ValueError): - CheckpointModelConfig( - checkpoint_path="/any/path.ckpt", - fine_topography_path="/topo/path.nc", - ) - - -def test_checkpoint_model_build_raises_when_checkpoint_has_static_inputs(tmp_path): - coarse_shape = (8, 16) - fine_shape = (16, 32) - static_inputs = make_static_inputs(fine_shape) - model = _get_diffusion_model( - coarse_shape=coarse_shape, - downscale_factor=2, - predict_residual=True, - use_fine_topography=True, - static_inputs=static_inputs, - ) - checkpoint_path = tmp_path / "test.ckpt" - torch.save({"model": model.get_state()}, checkpoint_path) - - config = CheckpointModelConfig( - checkpoint_path=str(checkpoint_path), - static_inputs={"HGTsfc": "/any/path.nc"}, - ) - with pytest.raises(ValueError): - config.build() From 4b58b86e2c283b528d5344343e168520f837fbeb Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Thu, 19 Mar 2026 15:19:22 -0700 Subject: [PATCH 23/27] Minor fixes --- fme/downscaling/data/static.py | 2 +- fme/downscaling/data/utils.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index 6f34b4a26..0eb10e9af 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -174,7 +174,7 @@ def get_state(self) -> dict: def from_state(cls, state: dict) -> "StaticInputs": if not _has_coords_in_state(state): raise ValueError( - "No coordinates found in state 
for StaticInputs. Load with" + "No coordinates found in state for StaticInputs. Load with " "from_state_backwards_compatible if loading from a checkpoint " "saved prior to current coordinate serialization format." ) diff --git a/fme/downscaling/data/utils.py b/fme/downscaling/data/utils.py index ca3a4359a..bdd34e188 100644 --- a/fme/downscaling/data/utils.py +++ b/fme/downscaling/data/utils.py @@ -68,8 +68,7 @@ def subset_of(self, coords: torch.Tensor) -> torch.Tensor: Return a subset of `coords` that falls within this specified interval. This assumes `coords` is monotonically increasing. """ - slice = self.slice_from(coords) - return coords[slice] + return coords[self.slice_from(coords)] def scale_slice(slice_: slice, scale: int) -> slice: From 7e36f554ab13da5e00bfd11791716e24f2e9c7fd Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Thu, 19 Mar 2026 15:26:40 -0700 Subject: [PATCH 24/27] Fix imports --- fme/downscaling/data/__init__.py | 7 ++++++- fme/downscaling/inference/test_inference.py | 2 +- fme/downscaling/test_predict.py | 8 ++++---- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/fme/downscaling/data/__init__.py b/fme/downscaling/data/__init__.py index 5790cdc11..334e1a949 100644 --- a/fme/downscaling/data/__init__.py +++ b/fme/downscaling/data/__init__.py @@ -12,7 +12,12 @@ PairedBatchItem, PairedGriddedData, ) -from .static import StaticInput, StaticInputs, load_static_inputs +from .static import ( + StaticInput, + StaticInputs, + load_fine_coords_from_path, + load_static_inputs, +) from .utils import ( BatchedLatLonCoordinates, ClosedInterval, diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index 9ba54e580..2f89ba7de 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -16,6 +16,7 @@ LatLonCoordinates, StaticInput, StaticInputs, + load_fine_coords_from_path, load_static_inputs, ) from fme.downscaling.inference.constants import 
ENSEMBLE_NAME, TIME_NAME @@ -32,7 +33,6 @@ LossConfig, NormalizationConfig, PairedNormalizationConfig, - load_fine_coords_from_path, ) from fme.downscaling.predictors import PatchPredictionConfig, PatchPredictor from fme.downscaling.test_evaluator import LinearDownscalingDiffusion diff --git a/fme/downscaling/test_predict.py b/fme/downscaling/test_predict.py index 50411a0fd..01739afe5 100644 --- a/fme/downscaling/test_predict.py +++ b/fme/downscaling/test_predict.py @@ -8,12 +8,12 @@ from fme.core.normalizer import NormalizationConfig from fme.core.testing.wandb import mock_wandb from fme.downscaling import predict -from fme.downscaling.data import StaticInputs, load_static_inputs -from fme.downscaling.models import ( - DiffusionModelConfig, - PairedNormalizationConfig, +from fme.downscaling.data import ( + StaticInputs, load_fine_coords_from_path, + load_static_inputs, ) +from fme.downscaling.models import DiffusionModelConfig, PairedNormalizationConfig from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector from fme.downscaling.test_models import LinearDownscaling from fme.downscaling.test_utils import data_paths_helper From a85fb05650a7df898ec92e87bb4f1bacf6598f1c Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Thu, 19 Mar 2026 16:13:37 -0700 Subject: [PATCH 25/27] Final cleanup --- fme/downscaling/data/static.py | 2 +- fme/downscaling/inference/test_inference.py | 2 -- fme/downscaling/predictors/composite.py | 2 +- fme/downscaling/test_models.py | 13 ------------- 4 files changed, 2 insertions(+), 17 deletions(-) diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index 0eb10e9af..296067c1e 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -237,7 +237,7 @@ def load_static_inputs( ) -> StaticInputs: """ Load normalized static inputs from a mapping of field names to file paths. - Returns an empty StaticInputs (no fields) if the config is None or empty. 
+ Returns an empty StaticInputs (no fields) if the config is empty. """ fields = [ _get_normalized_static_input(path, field_name) diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index 2f89ba7de..e15c3ec09 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -64,8 +64,6 @@ def mock_output_target(): def get_static_inputs(shape=(16, 16)): - from fme.core.coordinates import LatLonCoordinates - data = torch.randn(shape) coords = LatLonCoordinates( lat=torch.arange(shape[0], dtype=torch.float32), diff --git a/fme/downscaling/predictors/composite.py b/fme/downscaling/predictors/composite.py index 3ce1c3e02..45256264f 100644 --- a/fme/downscaling/predictors/composite.py +++ b/fme/downscaling/predictors/composite.py @@ -87,7 +87,7 @@ def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: return self.model.get_fine_coords_for_batch(batch) @property - def fine_coords(self): + def full_fine_coords(self): return self.model.full_fine_coords def _get_patches( diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 1ebf6f278..005b3f8bd 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -193,19 +193,6 @@ def test_module_serialization(tmp_path): ) -def test_from_state_raises_when_no_coords_available(): - """from_state must raise when the checkpoint has neither top-level coords - nor legacy per-field coords to migrate from.""" - model = _get_diffusion_model( - coarse_shape=(8, 16), downscale_factor=2, use_fine_topography=False - ) - state = model.get_state() - # Wipe coords entirely — no way to recover them - state["static_inputs"] = None - with pytest.raises(ValueError): - DiffusionModel.from_state(state) - - def test_generate_raises_when_no_static_fields_but_topography_required(): coarse_shape = (8, 16) fine_shape = (16, 32) From bedd276f831df9900290e459a1fc1bb708704c3a Mon Sep 17 00:00:00 
2001 From: Andre Perkins Date: Fri, 20 Mar 2026 12:22:28 -0700 Subject: [PATCH 26/27] Add coordinate validation and use training data as a fallback vs. static inputs --- fme/downscaling/data/static.py | 112 ++++++++++++++------ fme/downscaling/inference/test_inference.py | 4 +- fme/downscaling/test_predict.py | 4 +- fme/downscaling/train.py | 2 +- 4 files changed, 84 insertions(+), 38 deletions(-) diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index 296067c1e..c5a741125 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -48,7 +48,34 @@ def from_state(cls, state: dict) -> "StaticInput": return cls(data=state["data"]) -def _get_normalized_static_input(path: str, field_name: str): +def _load_coords_from_ds(ds: xr.Dataset) -> LatLonCoordinates: + lat_name = next((n for n in ["lat", "latitude", "grid_yt"] if n in ds.coords), None) + lon_name = next( + (n for n in ["lon", "longitude", "grid_xt"] if n in ds.coords), None + ) + if lat_name is None or lon_name is None: + raise ValueError( + f"Could not find lat/lon coordinates in dataset. " + "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'." + ) + return LatLonCoordinates( + lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), + lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), + ) + + +def load_fine_coords_from_path(path: str) -> LatLonCoordinates: + if path.endswith(".zarr"): + ds = xr.open_zarr(path) + else: + ds = xr.open_dataset(path) + + return _load_coords_from_ds(ds) + + +def _get_normalized_static_input( + path: str, field_name: str +) -> tuple[StaticInput, LatLonCoordinates | None]: """ Load a static input field from a given file path and field name and normalize it. @@ -58,22 +85,30 @@ def _get_normalized_static_input(path: str, field_name: str): assumed to be the last two dimensions of the loaded dataset dimensions. 
""" if path.endswith(".zarr"): - static_input = xr.open_zarr(path, mask_and_scale=False)[field_name] + ds = xr.open_zarr(path, mask_and_scale=False) else: - static_input = xr.open_dataset(path, mask_and_scale=False)[field_name] - if "time" in static_input.dims: - static_input = static_input.isel(time=0).squeeze() - if len(static_input.shape) != 2: + ds = xr.open_dataset(path, mask_and_scale=False) + + da = ds[field_name] + try: + coords = _load_coords_from_ds(ds) + except ValueError: + # no coords available + coords = None + + if "time" in da.dims: + da = da.isel(time=0).squeeze() + if len(da.shape) != 2: raise ValueError( - f"unexpected shape {static_input.shape} for static input." + f"unexpected shape {da.shape} for static input." "Currently, only lat/lon static input is supported." ) - static_input_normalized = (static_input - static_input.mean()) / static_input.std() + static_input_normalized = (da - da.mean()) / da.std() return StaticInput( data=torch.tensor(static_input_normalized.values, dtype=torch.float32), - ) + ), coords def _has_legacy_coords_in_state(state: dict) -> bool: @@ -96,26 +131,6 @@ def _has_coords_in_state(state: dict) -> bool: return False -def load_fine_coords_from_path(path: str) -> LatLonCoordinates: - if path.endswith(".zarr"): - ds = xr.open_zarr(path) - else: - ds = xr.open_dataset(path) - lat_name = next((n for n in ["lat", "latitude", "grid_yt"] if n in ds.coords), None) - lon_name = next( - (n for n in ["lon", "longitude", "grid_xt"] if n in ds.coords), None - ) - if lat_name is None or lon_name is None: - raise ValueError( - f"Could not find lat/lon coordinates in {path}. " - "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'." 
- ) - return LatLonCoordinates( - lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), - lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), - ) - - @dataclasses.dataclass class StaticInputs: fields: list[StaticInput] @@ -232,15 +247,42 @@ def from_state_backwards_compatible( return cls(fields=[], coords=coords) +def _validate_coords( + case: str, coord1: LatLonCoordinates, coord2: LatLonCoordinates +) -> None: + if not coord1 == coord2: + raise ValueError(f"Coordinates do not match between static inputs: {case}") + + def load_static_inputs( - static_inputs_config: dict[str, str], coords: LatLonCoordinates + static_inputs_config: dict[str, str], + fallback_coords: LatLonCoordinates, + validate_coords: bool = True, ) -> StaticInputs: """ Load normalized static inputs from a mapping of field names to file paths. Returns an empty StaticInputs (no fields) if the config is empty. + + Coordinates are inferred from the static input field datasets and verified + to match between each field. If no static inputs are provided + coordinates are used from fallback_coords. 
""" - fields = [ - _get_normalized_static_input(path, field_name) - for field_name, path in static_inputs_config.items() - ] - return StaticInputs(fields=fields, coords=coords) + coords_to_use = None + fields = [] + for field_name, path in static_inputs_config.items(): + si, coords = _get_normalized_static_input(path, field_name) + fields.append(si) + + if coords is not None and coords_to_use is None: + coords_to_use = coords + elif coords is not None and validate_coords: + assert coords_to_use is not None # for type checker + _validate_coords(field_name, coords, coords_to_use) + + if coords_to_use is None: + # no coords found with static inputs, use provided fallback + coords_to_use = fallback_coords + elif validate_coords: + _validate_coords("fallback", coords_to_use, fallback_coords) + + return StaticInputs(fields=fields, coords=coords_to_use) diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index e15c3ec09..8df3a2263 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -281,7 +281,9 @@ def checkpointed_model_config( # that correspond to the dataset coordinates fine_data_path = f"{data_paths.fine}/data.nc" fine_coords = load_fine_coords_from_path(fine_data_path) - static_inputs = load_static_inputs({"HGTsfc": fine_data_path}, coords=fine_coords) + static_inputs = load_static_inputs( + {"HGTsfc": fine_data_path}, fallback_coords=fine_coords + ) model = model_config.build(coarse_shape, 2, static_inputs=static_inputs) checkpoint_path = tmp_path / "model_checkpoint.pth" diff --git a/fme/downscaling/test_predict.py b/fme/downscaling/test_predict.py index 01739afe5..626001b4d 100644 --- a/fme/downscaling/test_predict.py +++ b/fme/downscaling/test_predict.py @@ -121,7 +121,9 @@ def test_predictor_runs(tmp_path, very_fast_only: bool): coarse_shape=coarse_shape, downscale_factor=downscale_factor, static_inputs=load_static_inputs( - {"HGTsfc": fine_data_path}, 
coords=fine_coords + {"HGTsfc": fine_data_path}, + fallback_coords=fine_coords, + validate_coords=False, ), ) with open(predictor_config_path) as f: diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index d91b8b4f4..f51b61f94 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -430,7 +430,7 @@ def build(self) -> Trainer: requirements=self.model.data_requirements, ) static_inputs = load_static_inputs( - self.static_inputs, coords=train_data.fine_coords + self.static_inputs, fallback_coords=train_data.fine_coords ) if self.coarse_patch_extent_lat and self.coarse_patch_extent_lon: From a9dbae3c9bad3a845d4a0123d6f709fc06f4aa54 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 20 Mar 2026 17:43:08 -0700 Subject: [PATCH 27/27] Updates based on discussion w/ Anna --- fme/downscaling/data/static.py | 57 ++------ fme/downscaling/data/test_static.py | 53 ++------ fme/downscaling/inference/test_inference.py | 6 +- fme/downscaling/models.py | 140 +++++++++++++++----- fme/downscaling/test_models.py | 86 ++++++------ fme/downscaling/test_predict.py | 9 +- fme/downscaling/train.py | 8 +- 7 files changed, 193 insertions(+), 166 deletions(-) diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index c5a741125..8fee89e24 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -83,6 +83,8 @@ def _get_normalized_static_input( Only supports 2D lat/lon static inputs. If the input has a time dimension, it is squeezed by taking the first time step. The lat/lon coordinates are assumed to be the last two dimensions of the loaded dataset dimensions. + + Raises ValueError if lat/lon coordinates are not found in the dataset. 
""" if path.endswith(".zarr"): ds = xr.open_zarr(path, mask_and_scale=False) @@ -90,11 +92,7 @@ def _get_normalized_static_input( ds = xr.open_dataset(path, mask_and_scale=False) da = ds[field_name] - try: - coords = _load_coords_from_ds(ds) - except ValueError: - # no coords available - coords = None + coords = _load_coords_from_ds(ds) if "time" in da.dims: da = da.isel(time=0).squeeze() @@ -209,8 +207,7 @@ def from_state_backwards_compatible( cls, state: dict, static_inputs_config: dict[str, str], - fine_coordinates_path: str | None, - ) -> "StaticInputs": + ) -> "StaticInputs | None": if state and static_inputs_config: raise ValueError( "Checkpoint contains static inputs but static_inputs_config is " @@ -218,33 +215,12 @@ def from_state_backwards_compatible( "a single source of StaticInputs info." ) - if fine_coordinates_path and _has_coords_in_state(state): - raise ValueError( - "State contains coordinates but fine_coordinates_path is also provided." - " Only one source of coordinate info can be used for backwards " - "compatibility loading of StaticInputs." - ) - elif not _has_coords_in_state(state) and not fine_coordinates_path: - raise ValueError( - "No coordinates found in state and no fine_coordinates_path provided. " - "Cannot load StaticInputs without coordinates." 
- ) - - # All compatibility cases: - # Serialized StaticInputs exist, which always had coordinates stored - # No serialized static inputs or specified inputs, load coordinates - # Specified static input fields and specified coordinates - if _has_coords_in_state(state): return cls.from_state(state) + elif static_inputs_config: + return load_static_inputs(static_inputs_config) else: - assert fine_coordinates_path is not None # for type checker - coords = load_fine_coords_from_path(fine_coordinates_path) - - if static_inputs_config: - return load_static_inputs(static_inputs_config, coords) - else: - return cls(fields=[], coords=coords) + return None def _validate_coords( @@ -256,16 +232,13 @@ def _validate_coords( def load_static_inputs( static_inputs_config: dict[str, str], - fallback_coords: LatLonCoordinates, - validate_coords: bool = True, ) -> StaticInputs: """ Load normalized static inputs from a mapping of field names to file paths. - Returns an empty StaticInputs (no fields) if the config is empty. - Coordinates are inferred from the static input field datasets and verified - to match between each field. If no static inputs are provided - coordinates are used from fallback_coords. + Coordinates are read from each field's source dataset and validated to be + consistent between fields. Raises ValueError if any field's dataset lacks + lat/lon coordinates or if coordinates differ between fields. 
""" coords_to_use = None fields = [] @@ -273,16 +246,12 @@ def load_static_inputs( si, coords = _get_normalized_static_input(path, field_name) fields.append(si) - if coords is not None and coords_to_use is None: + if coords_to_use is None: coords_to_use = coords - elif coords is not None and validate_coords: - assert coords_to_use is not None # for type checker + elif coords is not None and coords_to_use is not None: _validate_coords(field_name, coords, coords_to_use) if coords_to_use is None: - # no coords found with static inputs, use provided fallback - coords_to_use = fallback_coords - elif validate_coords: - _validate_coords("fallback", coords_to_use, fallback_coords) + raise ValueError("load_static_inputs requires at least one field.") return StaticInputs(fields=fields, coords=coords_to_use) diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index 36a68c1ff..955da241d 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -86,36 +86,36 @@ def test_from_state_backwards_compatible_has_coords(): coords = _make_coords() state = StaticInputs([StaticInput(data)], coords=coords).get_state() result = StaticInputs.from_state_backwards_compatible( - state=state, static_inputs_config={}, fine_coordinates_path=None + state=state, static_inputs_config={} ) + assert result is not None assert torch.equal(result[0].data, data) -def test_from_state_backwards_compatible_no_state_no_config(tmp_path): - """Old checkpoint with no static inputs: empty StaticInputs with loaded coords.""" - coords_path = str(tmp_path / "coords.nc") - dim_len = 4 - _write_coords_netcdf(coords_path, n=dim_len) +def test_from_state_backwards_compatible_no_state_no_config(): + """No static inputs in checkpoint and no config: returns None.""" result = StaticInputs.from_state_backwards_compatible( - state={}, static_inputs_config={}, fine_coordinates_path=coords_path + state={}, static_inputs_config={} ) - assert result.fields == 
[] - assert result.coords.lat.shape == (dim_len,) + assert result is None def test_from_state_backwards_compatible_with_config(tmp_path): - """Old checkpoint with static_inputs_config: loads fields from paths.""" - coords_path = str(tmp_path / "coords.nc") + """Checkpoint with static_inputs_config: loads fields and coords from files.""" dim_len = 4 - _write_coords_netcdf(coords_path, n=dim_len) + lat = np.linspace(0, 1, dim_len, dtype=np.float32) + lon = np.linspace(0, 1, dim_len, dtype=np.float32) field_data = np.random.rand(dim_len, dim_len).astype(np.float32) field_path = str(tmp_path / "field.nc") - xr.Dataset({"HGTsfc": (["lat", "lon"], field_data)}).to_netcdf(field_path) + xr.Dataset( + {"HGTsfc": (["lat", "lon"], field_data)}, + coords={"lat": lat, "lon": lon}, + ).to_netcdf(field_path) result = StaticInputs.from_state_backwards_compatible( state={}, static_inputs_config={"HGTsfc": field_path}, - fine_coordinates_path=coords_path, ) + assert result is not None assert len(result.fields) == 1 @@ -133,31 +133,6 @@ def test_from_state_backwards_compatible_raises_state_and_config(): StaticInputs.from_state_backwards_compatible( state=state, static_inputs_config={"HGTsfc": "some/path"}, - fine_coordinates_path=None, - ) - - -def test_from_state_backwards_compatible_raises_coords_in_state_and_path(): - """Errors if state has coords and fine_coordinates_path is also provided.""" - dim_len = 4 - data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( - dim_len, dim_len - ) - coords = _make_coords(n=dim_len) - state = StaticInputs([StaticInput(data)], coords=coords).get_state() - with pytest.raises(ValueError, match="fine_coordinates_path"): - StaticInputs.from_state_backwards_compatible( - state=state, - static_inputs_config={}, - fine_coordinates_path="some/path", - ) - - -def test_from_state_backwards_compatible_raises_no_coords(): - """Errors if no coords in state and no fine_coordinates_path.""" - with pytest.raises(ValueError, match="No 
coordinates"): - StaticInputs.from_state_backwards_compatible( - state={}, static_inputs_config={}, fine_coordinates_path=None ) diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index 8df3a2263..aadd9fd68 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -281,10 +281,10 @@ def checkpointed_model_config( # that correspond to the dataset coordinates fine_data_path = f"{data_paths.fine}/data.nc" fine_coords = load_fine_coords_from_path(fine_data_path) - static_inputs = load_static_inputs( - {"HGTsfc": fine_data_path}, fallback_coords=fine_coords + static_inputs = load_static_inputs({"HGTsfc": fine_data_path}) + model = model_config.build( + coarse_shape, 2, full_fine_coords=fine_coords, static_inputs=static_inputs ) - model = model_config.build(coarse_shape, 2, static_inputs=static_inputs) checkpoint_path = tmp_path / "model_checkpoint.pth" model.get_state() diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 187133895..84287f502 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -16,9 +16,11 @@ from fme.core.typing_ import TensorDict, TensorMapping from fme.downscaling.data import ( BatchData, + ClosedInterval, PairedBatchData, StaticInputs, adjust_fine_coord_range, + load_fine_coords_from_path, ) from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector @@ -179,7 +181,8 @@ def build( self, coarse_shape: tuple[int, int], downscale_factor: int, - static_inputs: StaticInputs, + full_fine_coords: LatLonCoordinates, + static_inputs: StaticInputs | None = None, rename: dict[str, str] | None = None, ) -> "DiffusionModel": invert_rename = {v: k for k, v in (rename or {}).items()} @@ -192,14 +195,16 @@ def build( # https://en.wikipedia.org/wiki/Standard_score sigma_data = 1.0 - n_in_channels = len(self.in_names) 
+ len(static_inputs.fields) - if self.use_fine_topography and len(static_inputs.fields) == 0: - # TODO: remove this when forcing static inputs to be provided - # . via removing use_fine_topography - # Old checkpoints may not have static inputs serialized, but if - # use_fine_topography is True, we still need to account for the topography - # channel, which was the only static input at the time - n_in_channels += 1 + num_static_in_channels = len(static_inputs.fields) if static_inputs else 0 + n_in_channels = len(self.in_names) + num_static_in_channels + if self.use_fine_topography and ( + not static_inputs or len(static_inputs.fields) == 0 + ): + raise ValueError( + "use_fine_topography is enabled but no static input fields were found. " + "At least one static input field must be provided when using fine " + "topography." + ) module = self.module.build( n_in_channels=n_in_channels, @@ -210,6 +215,12 @@ def build( use_amp_bf16=self.use_amp_bf16, ) + if static_inputs and static_inputs.coords != full_fine_coords: + raise ValueError( + "static_inputs coordinates do not match full_fine_coords. " + "Static inputs must be defined on the same grid as the model output." + ) + return DiffusionModel( config=self, module=module, @@ -218,6 +229,7 @@ def build( coarse_shape=coarse_shape, downscale_factor=downscale_factor, sigma_data=sigma_data, + full_fine_coords=full_fine_coords, static_inputs=static_inputs, ) @@ -274,7 +286,8 @@ def __init__( coarse_shape: tuple[int, int], downscale_factor: int, sigma_data: float, - static_inputs: StaticInputs, + full_fine_coords: LatLonCoordinates, + static_inputs: StaticInputs | None = None, ) -> None: """ Args: @@ -291,9 +304,11 @@ def __init__( coarse to fine. sigma_data: The standard deviation of the data, used for diffusion model preconditioning. - static_inputs: Static inputs to the model. Always provided, carrying - the full-domain fine-resolution coordinates. Fields may be empty - when no static data is needed. 
+ full_fine_coords: The full fine-resolution domain coordinates. + Serves as the canonical source of truth for the model output grid. + static_inputs: Static inputs to the model. May be None when + no static data is needed. If present, coordinates + must match full_fine_coords. """ self.coarse_shape = coarse_shape self.downscale_factor = downscale_factor @@ -306,22 +321,19 @@ def __init__( self.out_packer = Packer(config.out_names) self.config = config self._channel_axis = -3 - self.static_inputs = static_inputs.to_device() + self.full_fine_coords = full_fine_coords.to(get_device()) + self.static_inputs = static_inputs.to_device() if static_inputs else None @property def modules(self) -> torch.nn.ModuleList: return torch.nn.ModuleList([self.module]) - @property - def full_fine_coords(self) -> LatLonCoordinates: - return self.static_inputs.coords - - def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: - return self._get_subset_for_coarse_batch(batch).coords - - def _get_subset_for_coarse_batch(self, batch: BatchData) -> StaticInputs: + def _get_fine_interval_from_batch( + self, batch: BatchData + ) -> tuple[ClosedInterval, ClosedInterval]: coarse_lat = batch.latlon_coordinates.lat[0] coarse_lon = batch.latlon_coordinates.lon[0] + fine_lat_interval = adjust_fine_coord_range( batch.lat_interval, full_coarse_coord=coarse_lat, @@ -334,6 +346,19 @@ def _get_subset_for_coarse_batch(self, batch: BatchData) -> StaticInputs: full_fine_coord=self.full_fine_coords.lon, downscale_factor=self.downscale_factor, ) + return fine_lat_interval, fine_lon_interval + + def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: + lat_interval, lon_interval = self._get_fine_interval_from_batch(batch) + return LatLonCoordinates( + lat=lat_interval.subset_of(self.full_fine_coords.lat), + lon=lon_interval.subset_of(self.full_fine_coords.lon), + ) + + def _subset_static_if_available(self, batch: BatchData) -> StaticInputs | None: + if 
self.static_inputs is None: + return None + fine_lat_interval, fine_lon_interval = self._get_fine_interval_from_batch(batch) subset_static_inputs = self.static_inputs.subset( lat_interval=fine_lat_interval, lon_interval=fine_lon_interval, @@ -355,7 +380,7 @@ def _get_fine_shape(self, coarse_shape: tuple[int, int]) -> tuple[int, int]: ) def _get_input_from_coarse( - self, coarse: TensorMapping, static_inputs: StaticInputs + self, coarse: TensorMapping, static_inputs: StaticInputs | None ) -> torch.Tensor: inputs = filter_tensor_mapping(coarse, self.in_packer.names) normalized = self.in_packer.pack( @@ -365,7 +390,7 @@ def _get_input_from_coarse( # TODO: update when removing use_fine_topography flag if self.config.use_fine_topography: - if not static_inputs.fields: + if static_inputs is None or not static_inputs.fields: raise ValueError( "Static inputs must be provided for each batch when flag " "use_fine_topography is enabled, but no static input fields " @@ -398,7 +423,7 @@ def train_on_batch( optimizer: Optimization | NullOptimization, ) -> ModelOutputs: """Performs a denoising training step on a batch of data.""" - _static_inputs = self._get_subset_for_coarse_batch(batch.coarse) + _static_inputs = self._subset_static_if_available(batch.coarse) coarse, fine = batch.coarse.data, batch.fine.data inputs_norm = self._get_input_from_coarse(coarse, _static_inputs) targets_norm = self.out_packer.pack( @@ -456,7 +481,7 @@ def train_on_batch( def generate( self, coarse_data: TensorMapping, - static_inputs: StaticInputs, + static_inputs: StaticInputs | None, n_samples: int = 1, ) -> tuple[TensorDict, torch.Tensor, list[torch.Tensor]]: # Internal method; external callers should use generate_on_batch / @@ -512,7 +537,7 @@ def generate_on_batch_no_target( batch: BatchData, n_samples: int = 1, ) -> TensorDict: - _static_inputs = self._get_subset_for_coarse_batch(batch) + _static_inputs = self._subset_static_if_available(batch) generated, _, _ = self.generate(batch.data, 
_static_inputs, n_samples) return generated @@ -522,7 +547,7 @@ def generate_on_batch( batch: PairedBatchData, n_samples: int = 1, ) -> ModelOutputs: - _static_inputs = self._get_subset_for_coarse_batch(batch.coarse) + _static_inputs = self._subset_static_if_available(batch.coarse) coarse, fine = batch.coarse.data, batch.fine.data generated, generated_norm, latent_steps = self.generate( coarse, _static_inputs, n_samples @@ -547,7 +572,10 @@ def get_state(self) -> Mapping[str, Any]: "module": self.module.state_dict(), "coarse_shape": self.coarse_shape, "downscale_factor": self.downscale_factor, - "static_inputs": self.static_inputs.get_state(), + "full_fine_coords": self.full_fine_coords.get_state(), + "static_inputs": self.static_inputs.get_state() + if self.static_inputs is not None + else None, } @classmethod @@ -561,12 +589,30 @@ def from_state( opposed to the CheckpointModelConfig where we build with the option to provide static_inputs and fine_coordinates for backwards compatibility. """ - static_inputs_state: dict = state.get("static_inputs") or {} - static_inputs = StaticInputs.from_state(static_inputs_state) + static_inputs_state = state.get("static_inputs") + static_inputs = ( + StaticInputs.from_state(static_inputs_state) + if static_inputs_state + else None + ) + full_fine_coords_state = state.get("full_fine_coords") + if full_fine_coords_state is not None: + full_fine_coords = LatLonCoordinates( + lat=full_fine_coords_state["lat"], + lon=full_fine_coords_state["lon"], + ) + else: + raise ValueError( + "No full_fine_coords found in loaded state for DiffusionModel. " + "Must use CheckpointModelConfig with fine_coordinates_path provided " + "for backwards compatibility loading of old checkpoints without " + "full_fine_coords in state." 
+            )
        config = DiffusionModelConfig.from_state(state["config"])
        model = config.build(
            state["coarse_shape"],
            state["downscale_factor"],
+            full_fine_coords=full_fine_coords,
            static_inputs=static_inputs,
        )
        model.module.load_state_dict(state["module"], strict=True)
@@ -647,22 +693,52 @@ def _checkpoint(self) -> Mapping[str, Any]:
                checkpoint_data["model"]["config"][k] = v
        return self._checkpoint_data
 
+    @staticmethod
+    def _get_coords_backwards_compatible(
+        coords_from_state: dict | None,
+        fine_coordinates_path: str | None,
+    ) -> LatLonCoordinates:
+        if coords_from_state and fine_coordinates_path:
+            raise ValueError(
+                "Checkpoint contains fine coordinates but fine_coordinates_path is also"
+                " provided. Backwards compatibility loading only supports a single "
+                "source of fine coordinates info."
+            )
+
+        if coords_from_state is not None:
+            return LatLonCoordinates(
+                lat=coords_from_state["lat"],
+                lon=coords_from_state["lon"],
+            )
+        elif fine_coordinates_path is not None:
+            return load_fine_coords_from_path(fine_coordinates_path)
+        else:
+            raise ValueError(
+                "No fine coordinates found in checkpoint state and no "
+                "fine_coordinates_path provided. One of these must be provided to "
+                "load the model."
+ ) + def build( self, ) -> DiffusionModel: checkpoint_model: dict = self._checkpoint["model"] checkpoint_static_inputs = checkpoint_model.get("static_inputs", {}) + full_fine_coords = self._get_coords_backwards_compatible( + checkpoint_model.get("full_fine_coords"), + self.fine_coordinates_path, + ) static_inputs = StaticInputs.from_state_backwards_compatible( state=checkpoint_static_inputs, static_inputs_config=self.static_inputs or {}, - fine_coordinates_path=self.fine_coordinates_path, ) model = _CheckpointModelConfigSelector.from_state( self._checkpoint["model"]["config"] ).build( coarse_shape=self._checkpoint["model"]["coarse_shape"], downscale_factor=self._checkpoint["model"]["downscale_factor"], + full_fine_coords=full_fine_coords, rename=self._rename, static_inputs=static_inputs, ) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 005b3f8bd..b8fe01cc8 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -146,6 +146,7 @@ def test_module_serialization(tmp_path): static_inputs = make_static_inputs(fine_shape) fine_coords = static_inputs.coords model = _get_diffusion_model( + full_fine_coords=fine_coords, coarse_shape=coarse_shape, downscale_factor=2, predict_residual=True, @@ -193,18 +194,18 @@ def test_module_serialization(tmp_path): ) -def test_generate_raises_when_no_static_fields_but_topography_required(): +def test_model_raises_when_no_static_fields_but_topography_required(): coarse_shape = (8, 16) fine_shape = (16, 32) - model = _get_diffusion_model( - coarse_shape=coarse_shape, - downscale_factor=2, - use_fine_topography=True, - static_inputs=StaticInputs(fields=[], coords=make_fine_coords(fine_shape)), - ) - batch = make_paired_batch_data(coarse_shape, fine_shape) - with pytest.raises(ValueError, match="Static inputs must be provided"): - model.generate_on_batch(batch) + fine_coords = make_fine_coords(fine_shape) + with pytest.raises(ValueError): + _ = _get_diffusion_model( + 
coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=fine_coords, + use_fine_topography=True, + static_inputs=StaticInputs(fields=[], coords=fine_coords), + ) def test_from_state_backward_compat_migrates_fine_coords_from_old_static_inputs(): @@ -218,6 +219,7 @@ def test_from_state_backward_compat_migrates_fine_coords_from_old_static_inputs( model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, + full_fine_coords=fine_coords, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, @@ -225,7 +227,9 @@ def test_from_state_backward_compat_migrates_fine_coords_from_old_static_inputs( state = model.get_state() # Simulate old format: coords embedded in fields, not in static_inputs state del state["static_inputs"]["coords"] - state["static_inputs"]["fields"][0]["coords"] = fine_coords.get_state() + state["static_inputs"]["fields"][0]["coords"] = fine_coords.to( + get_device() + ).get_state() model_from_old_state = DiffusionModel.from_state(state) assert model_from_old_state.full_fine_coords is not None @@ -240,6 +244,7 @@ def test_from_state_backward_compat_migrates_fine_coords_from_old_static_inputs( def _get_diffusion_model( coarse_shape, downscale_factor, + full_fine_coords: LatLonCoordinates, predict_residual=True, use_fine_topography=True, static_inputs: StaticInputs | None = None, @@ -248,13 +253,6 @@ def _get_diffusion_model( NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), ) - if static_inputs is None: - fine_shape = ( - coarse_shape[0] * downscale_factor, - coarse_shape[1] * downscale_factor, - ) - static_inputs = StaticInputs(fields=[], coords=make_fine_coords(fine_shape)) - return DiffusionModelConfig( module=DiffusionModuleRegistrySelector( "unet_diffusion_song", {"model_channels": 4} @@ -274,6 +272,7 @@ def _get_diffusion_model( ).build( coarse_shape, downscale_factor, + full_fine_coords=full_fine_coords, 
static_inputs=static_inputs, ) @@ -292,6 +291,7 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph static_inputs = StaticInputs(fields=[], coords=fine_coords) batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) model = _get_diffusion_model( + full_fine_coords=fine_coords, coarse_shape=coarse_shape, downscale_factor=2, predict_residual=predict_residual, @@ -340,9 +340,11 @@ def test_normalizer_serialization(tmp_path): stds = xr.Dataset({"x": 1.0}) means.to_netcdf(tmp_path / "means.nc") stds.to_netcdf(tmp_path / "stds.nc") + fine_shape = (coarse_shape[0] * 2, coarse_shape[1] * 2) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, + full_fine_coords=make_fine_coords(fine_shape), predict_residual=False, use_fine_topography=False, ) @@ -367,7 +369,6 @@ def test_model_error_cases(): fine_shape = (8, 16) coarse_shape = (4, 8) upscaling_factor = 2 - batch_size = 3 normalization_config = PairedNormalizationConfig( NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), @@ -403,30 +404,27 @@ def test_model_error_cases(): **extra_kwargs, # type: ignore ) - # Compatible init, but no topography provided during prediction + # use_fine_topography=True requires non-empty static input fields at build time module_selector = selector( "prebuilt", {"module": DummyModule()}, expects_interpolated_input=True, ) - model = model_class( # type: ignore - module_selector, # type: ignore - LossConfig(type="MSE"), - ["x"], - ["x"], - normalization_config, - use_fine_topography=True, - **extra_kwargs, # type: ignore - ).build( - coarse_shape, - upscaling_factor, - static_inputs=StaticInputs(fields=[], coords=make_fine_coords(fine_shape)), - ) - batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) - - # missing fine topography when model requires it — empty fields raises ValueError with pytest.raises(ValueError): - model.generate_on_batch(batch) + 
model_class( # type: ignore + module_selector, # type: ignore + LossConfig(type="MSE"), + ["x"], + ["x"], + normalization_config, + use_fine_topography=True, + **extra_kwargs, # type: ignore + ).build( + coarse_shape, + upscaling_factor, + full_fine_coords=make_fine_coords(fine_shape), + static_inputs=StaticInputs(fields=[], coords=make_fine_coords(fine_shape)), + ) def test_DiffusionModel_generate_on_batch_no_target(): @@ -437,6 +435,7 @@ def test_DiffusionModel_generate_on_batch_no_target(): model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, + full_fine_coords=static_inputs.coords, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, @@ -476,6 +475,7 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, + full_fine_coords=static_inputs.coords, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, @@ -530,6 +530,7 @@ def test_lognorm_noise_backwards_compatibility(): model = model_config.build( (32, 32), 2, + full_fine_coords=make_fine_coords((64, 64)), static_inputs=StaticInputs(fields=[], coords=make_fine_coords((64, 64))), ) state = model.get_state() @@ -574,11 +575,13 @@ def test_get_fine_coords_for_batch(): coarse_shape = (8, 16) fine_shape = (16, 32) downscale_factor = 2 + static_inputs = make_static_inputs(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, + full_fine_coords=static_inputs.coords, use_fine_topography=True, - static_inputs=make_static_inputs(fine_shape), + static_inputs=static_inputs, ) # Build a batch covering a spatial patch: middle 4 coarse lats and 8 coarse lons. 
@@ -600,14 +603,17 @@ def test_checkpoint_model_build_with_fine_coordinates_path(tmp_path): """Old-format checkpoint (no fine_coords key, no coords in static_inputs) should load correctly when fine_coordinates_path is provided.""" coarse_shape = (8, 16) + fine_coords = make_fine_coords((coarse_shape[0] * 2, coarse_shape[1] * 2)) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, + full_fine_coords=fine_coords, use_fine_topography=False, + static_inputs=StaticInputs(fields=[], coords=fine_coords), ) - fine_coords = model.full_fine_coords - # Simulate old checkpoint: no coords in static_inputs state + # Simulate old checkpoint: no fine_coords key, no coords in static_inputs state = model.get_state() + del state["full_fine_coords"] del state["static_inputs"]["coords"] checkpoint_path = tmp_path / "test.ckpt" torch.save({"model": state}, checkpoint_path) diff --git a/fme/downscaling/test_predict.py b/fme/downscaling/test_predict.py index 626001b4d..505a69b5d 100644 --- a/fme/downscaling/test_predict.py +++ b/fme/downscaling/test_predict.py @@ -117,14 +117,12 @@ def test_predictor_runs(tmp_path, very_fast_only: bool): fine_data_path = f"{paths.fine}/data.nc" fine_coords = load_fine_coords_from_path(fine_data_path) model_config = get_model_config(coarse_shape, downscale_factor=downscale_factor) + static_inputs = load_static_inputs({"HGTsfc": fine_data_path}) model = model_config.build( coarse_shape=coarse_shape, downscale_factor=downscale_factor, - static_inputs=load_static_inputs( - {"HGTsfc": fine_data_path}, - fallback_coords=fine_coords, - validate_coords=False, - ), + full_fine_coords=fine_coords, + static_inputs=static_inputs, ) with open(predictor_config_path) as f: predictor_config = yaml.safe_load(f) @@ -172,6 +170,7 @@ def test_predictor_renaming( model = model_config.build( coarse_shape=coarse_shape, downscale_factor=2, + full_fine_coords=fine_coords, static_inputs=StaticInputs(fields=[], coords=fine_coords), ) with 
open(predictor_config_path) as f: diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index f51b61f94..0b5a4ff66 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -429,9 +429,10 @@ def build(self) -> Trainer: train=False, requirements=self.model.data_requirements, ) - static_inputs = load_static_inputs( - self.static_inputs, fallback_coords=train_data.fine_coords - ) + if self.static_inputs: + static_inputs = load_static_inputs(self.static_inputs) + else: + static_inputs = None if self.coarse_patch_extent_lat and self.coarse_patch_extent_lon: model_coarse_shape = ( @@ -444,6 +445,7 @@ def build(self) -> Trainer: downscaling_model = self.model.build( model_coarse_shape, train_data.downscale_factor, + full_fine_coords=train_data.fine_coords, static_inputs=static_inputs, )