diff --git a/fme/core/coordinates.py b/fme/core/coordinates.py index 3e061fc6a..57d31fdf0 100644 --- a/fme/core/coordinates.py +++ b/fme/core/coordinates.py @@ -619,7 +619,7 @@ def __eq__(self, other) -> bool: def __repr__(self) -> str: return f"LatLonCoordinates(\n lat={self.lat},\n lon={self.lon}\n" - def to(self, device: str) -> "LatLonCoordinates": + def to(self, device: str | torch.device) -> "LatLonCoordinates": return LatLonCoordinates( lon=self.lon.to(device), lat=self.lat.to(device), diff --git a/fme/downscaling/data/__init__.py b/fme/downscaling/data/__init__.py index 5790cdc11..334e1a949 100644 --- a/fme/downscaling/data/__init__.py +++ b/fme/downscaling/data/__init__.py @@ -12,7 +12,12 @@ PairedBatchItem, PairedGriddedData, ) -from .static import StaticInput, StaticInputs, load_static_inputs +from .static import ( + StaticInput, + StaticInputs, + load_fine_coords_from_path, + load_static_inputs, +) from .utils import ( BatchedLatLonCoordinates, ClosedInterval, diff --git a/fme/downscaling/data/config.py b/fme/downscaling/data/config.py index eda213fbc..14396a8e7 100644 --- a/fme/downscaling/data/config.py +++ b/fme/downscaling/data/config.py @@ -23,7 +23,11 @@ PairedBatchData, PairedGriddedData, ) -from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range +from fme.downscaling.data.utils import ( + ClosedInterval, + adjust_fine_coord_range, + get_latlon_coords_from_properties, +) from fme.downscaling.requirements import DataRequirements @@ -529,6 +533,7 @@ def build( dims=example.fine.latlon_coordinates.dims, variable_metadata=variable_metadata, all_times=all_times, + fine_coords=get_latlon_coords_from_properties(properties_fine), ) def _get_sampler( diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index bce3d59c4..61f94f2ed 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -138,11 +138,13 @@ def __init__( f"expected lon_min < {self.lon_interval.start + 360.0}" ) - self._lats_slice = self.lat_interval.slice_of(self._orig_coords.lat) - self._lons_slice = self.lon_interval.slice_of(self._orig_coords.lon) + # Used to subset the data in __getitem__ + self._lats_slice = self.lat_interval.slice_from(self._orig_coords.lat) + self._lons_slice = self.lon_interval.slice_from(self._orig_coords.lon) + self._latlon_coordinates = LatLonCoordinates( - lat=self._orig_coords.lat[self._lats_slice], - lon=self._orig_coords.lon[self._lons_slice], + lat=self.lat_interval.subset_of(self._orig_coords.lat), + lon=self.lon_interval.subset_of(self._orig_coords.lon), ) self._area_weights = self._latlon_coordinates.area_weights @@ -337,6 +339,7 @@ class PairedGriddedData: dims: list[str] variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex + fine_coords: LatLonCoordinates @property def loader(self) -> DataLoader[PairedBatchItem]: diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index 4a860174c..8fee89e24 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -11,18 +11,13 @@ @dataclasses.dataclass class StaticInput: data: torch.Tensor - coords: LatLonCoordinates def __post_init__(self): if len(self.data.shape) != 2: - raise ValueError(f"Topography data must be 2D. Got shape {self.data.shape}") - if self.data.shape[0] != len(self.coords.lat) or self.data.shape[1] != len( - self.coords.lon - ): raise ValueError( - f"Static inputs data shape {self.data.shape} does not match lat/lon " - f"coordinates shape {(len(self.coords.lat), len(self.coords.lon))}" + f"StaticInput data must be 2D. Got shape {self.data.shape}" ) + self._shape = (self.data.shape[0], self.data.shape[1]) @property def dim(self) -> int: @@ -30,50 +25,57 @@ def dim(self) -> int: @property def shape(self) -> tuple[int, int]: - return self.data.shape + return self._shape - def subset_latlon( + def subset( self, - lat_interval: ClosedInterval, - lon_interval: ClosedInterval, + lat_slice: slice, + lon_slice: slice, ) -> "StaticInput": - lat_slice = lat_interval.slice_of(self.coords.lat) - lon_slice = lon_interval.slice_of(self.coords.lon) - return self._latlon_index_slice(lat_slice=lat_slice, lon_slice=lon_slice) + return StaticInput(data=self.data[lat_slice, lon_slice]) def to_device(self) -> "StaticInput": device = get_device() - return StaticInput( - data=self.data.to(device), - coords=LatLonCoordinates( - lat=self.coords.lat.to(device), - lon=self.coords.lon.to(device), - ), - ) - - def _latlon_index_slice( - self, - lat_slice: slice, - lon_slice: slice, - ) -> "StaticInput": - sliced_data = self.data[lat_slice, lon_slice] - sliced_latlon = LatLonCoordinates( - lat=self.coords.lat[lat_slice], - lon=self.coords.lon[lon_slice], - ) - return StaticInput( - data=sliced_data, - coords=sliced_latlon, - ) + return StaticInput(data=self.data.to(device)) def get_state(self) -> dict: return { "data": self.data.cpu(), - "coords": self.coords.get_state(), } + @classmethod + def from_state(cls, state: dict) -> "StaticInput": + return cls(data=state["data"]) + -def _get_normalized_static_input(path: str, field_name: str): +def _load_coords_from_ds(ds: xr.Dataset) -> LatLonCoordinates: + lat_name = next((n for n in ["lat", "latitude", "grid_yt"] if n in ds.coords), None) + lon_name = next( + (n for n in ["lon", "longitude", "grid_xt"] if n in ds.coords), None + ) + if lat_name is None or lon_name is None: + raise ValueError( + f"Could not find lat/lon coordinates in dataset. " + "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'." + ) + return LatLonCoordinates( + lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), + lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), + ) + + +def load_fine_coords_from_path(path: str) -> LatLonCoordinates: + if path.endswith(".zarr"): + ds = xr.open_zarr(path) + else: + ds = xr.open_dataset(path) + + return _load_coords_from_ds(ds) + + +def _get_normalized_static_input( + path: str, field_name: str +) -> tuple[StaticInput, LatLonCoordinates | None]: """ Load a static input field from a given file path and field name and normalize it. @@ -81,108 +83,175 @@ def _get_normalized_static_input(path: str, field_name: str): Only supports 2D lat/lon static inputs. If the input has a time dimension, it is squeezed by taking the first time step. The lat/lon coordinates are assumed to be the last two dimensions of the loaded dataset dimensions. + + Raises ValueError if lat/lon coordinates are not found in the dataset. """ if path.endswith(".zarr"): - static_input = xr.open_zarr(path, mask_and_scale=False)[field_name] + ds = xr.open_zarr(path, mask_and_scale=False) else: - static_input = xr.open_dataset(path, mask_and_scale=False)[field_name] - if "time" in static_input.dims: - static_input = static_input.isel(time=0).squeeze() - if len(static_input.shape) != 2: + ds = xr.open_dataset(path, mask_and_scale=False) + + da = ds[field_name] + coords = _load_coords_from_ds(ds) + + if "time" in da.dims: + da = da.isel(time=0).squeeze() + if len(da.shape) != 2: raise ValueError( - f"unexpected shape {static_input.shape} for static input." + f"unexpected shape {da.shape} for static input." "Currently, only lat/lon static input is supported." ) - lat_name, lon_name = static_input.dims[-2:] - coords = LatLonCoordinates( - lon=torch.tensor(static_input[lon_name].values), - lat=torch.tensor(static_input[lat_name].values), - ) - static_input_normalized = (static_input - static_input.mean()) / static_input.std() + static_input_normalized = (da - da.mean()) / da.std() return StaticInput( data=torch.tensor(static_input_normalized.values, dtype=torch.float32), - coords=coords, - ) + ), coords + + +def _has_legacy_coords_in_state(state: dict) -> bool: + return "fields" in state and state["fields"] and "coords" in state["fields"][0] + + +def _sync_state_coordinates(state: dict) -> dict: + # if necessary adjusts legacy coordinate to expected + # format for state loading + state = state.copy() + if _has_legacy_coords_in_state(state): + state["coords"] = state["fields"][0]["coords"] + return state + + +def _has_coords_in_state(state: dict) -> bool: + if "coords" in state or _has_legacy_coords_in_state(state): + return True + else: + return False @dataclasses.dataclass class StaticInputs: fields: list[StaticInput] + coords: LatLonCoordinates def __post_init__(self): for i, field in enumerate(self.fields[1:]): - if field.coords != self.fields[0].coords: + if field.shape != self.fields[0].shape: raise ValueError( - f"All StaticInput fields must have the same coordinates. " - f"Fields {i} and 0 do not match coordinates." + f"All StaticInput fields must have the same shape. " + f"Fields {i + 1} and 0 do not match shapes." ) + if self.fields and self.coords.shape != self.fields[0].shape: + raise ValueError( + f"Coordinates shape {self.coords.shape} does not match fields shape " + f"{self.fields[0].shape} for StaticInputs." + ) def __getitem__(self, index: int): return self.fields[index] - @property - def coords(self) -> LatLonCoordinates: - if len(self.fields) == 0: - raise ValueError("No fields in StaticInputs to get coordinates from.") - return self.fields[0].coords - @property def shape(self) -> tuple[int, int]: if len(self.fields) == 0: raise ValueError("No fields in StaticInputs to get shape from.") return self.fields[0].shape - def subset_latlon( + def subset( self, lat_interval: ClosedInterval, lon_interval: ClosedInterval, ) -> "StaticInputs": + lat_slice = lat_interval.slice_from(self.coords.lat) + lon_slice = lon_interval.slice_from(self.coords.lon) return StaticInputs( - fields=[ - field.subset_latlon(lat_interval, lon_interval) for field in self.fields - ] + fields=[field.subset(lat_slice, lon_slice) for field in self.fields], + coords=LatLonCoordinates( + lat=lat_interval.subset_of(self.coords.lat), + lon=lon_interval.subset_of(self.coords.lon), + ), ) def to_device(self) -> "StaticInputs": - return StaticInputs(fields=[field.to_device() for field in self.fields]) + return StaticInputs( + fields=[field.to_device() for field in self.fields], + coords=self.coords.to(get_device()), + ) def get_state(self) -> dict: return { "fields": [field.get_state() for field in self.fields], + "coords": self.coords.get_state(), } @classmethod def from_state(cls, state: dict) -> "StaticInputs": + if not _has_coords_in_state(state): + raise ValueError( + "No coordinates found in state for StaticInputs. Load with " + "from_state_backwards_compatible if loading from a checkpoint " + "saved prior to current coordinate serialization format." + ) + state = _sync_state_coordinates(state) return cls( fields=[ - StaticInput( - data=field_state["data"], - coords=LatLonCoordinates( - lat=field_state["coords"]["lat"], - lon=field_state["coords"]["lon"], - ), - ) - for field_state in state["fields"] - ] + StaticInput.from_state(field_state) for field_state in state["fields"] + ], + coords=LatLonCoordinates( + lat=state["coords"]["lat"], + lon=state["coords"]["lon"], + ), ) + @classmethod + def from_state_backwards_compatible( + cls, + state: dict, + static_inputs_config: dict[str, str], + ) -> "StaticInputs | None": + if state and static_inputs_config: + raise ValueError( + "Checkpoint contains static inputs but static_inputs_config is " + "also provided. Backwards compatibility loading only supports " + "a single source of StaticInputs info." + ) + + if _has_coords_in_state(state): + return cls.from_state(state) + elif static_inputs_config: + return load_static_inputs(static_inputs_config) + else: + return None + + +def _validate_coords( + case: str, coord1: LatLonCoordinates, coord2: LatLonCoordinates +) -> None: + if not coord1 == coord2: + raise ValueError(f"Coordinates do not match between static inputs: {case}") + def load_static_inputs( - static_inputs_config: dict[str, str] | None, -) -> StaticInputs | None: + static_inputs_config: dict[str, str], +) -> StaticInputs: """ Load normalized static inputs from a mapping of field names to file paths. - Returns None if the input config is empty. + + Coordinates are read from each field's source dataset and validated to be + consistent between fields. Raises ValueError if any field's dataset lacks + lat/lon coordinates or if coordinates differ between fields. """ - # TODO: consolidate/simplify empty StaticInputs vs. None handling in - # downscaling code - if not static_inputs_config: - return None - return StaticInputs( - fields=[ - _get_normalized_static_input(path, field_name) - for field_name, path in static_inputs_config.items() - ] - ) + coords_to_use = None + fields = [] + for field_name, path in static_inputs_config.items(): + si, coords = _get_normalized_static_input(path, field_name) + fields.append(si) + + if coords_to_use is None: + coords_to_use = coords + elif coords is not None and coords_to_use is not None: + _validate_coords(field_name, coords, coords_to_use) + + if coords_to_use is None: + raise ValueError("load_static_inputs requires at least one field.") + + return StaticInputs(fields=fields, coords=coords_to_use) diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index edfc0e080..955da241d 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -1,64 +1,166 @@ +import numpy as np import pytest import torch +import xarray as xr -from fme.core.coordinates import LatLonCoordinates +from .static import StaticInput, StaticInputs, load_fine_coords_from_path -from .static import StaticInput, StaticInputs -from .utils import ClosedInterval +def _make_coords(n=4): + from fme.core.coordinates import LatLonCoordinates -@pytest.mark.parametrize( - "init_args", - [ - pytest.param( - [ - torch.randn((1, 2, 2)), - LatLonCoordinates(torch.arange(2), torch.arange(2)), - ], - id="3d_data", - ), - pytest.param( - [torch.randn((2, 2)), LatLonCoordinates(torch.arange(2), torch.arange(5))], - id="dim_size_mismatch", - ), - ], -) -def test_Topography_error_cases(init_args): + return LatLonCoordinates( + lat=torch.arange(n, dtype=torch.float32), + lon=torch.arange(n, dtype=torch.float32), + ) + + +def _write_coords_netcdf(path, n=4): + xr.Dataset( + coords={ + "lat": np.arange(n, dtype=np.float32), + "lon": np.arange(n, dtype=np.float32), + } + ).to_netcdf(path) + + +def test_StaticInput_error_cases(): + data = torch.randn(1, 2, 2) with pytest.raises(ValueError): - StaticInput(*init_args) + StaticInput(data=data) -def test_subset_latlon(): +def test_subset(): full_data_shape = (10, 10) - expected_slices = [slice(2, 6), slice(3, 8)] data = torch.randn(*full_data_shape) - coords = LatLonCoordinates( - lat=torch.linspace(0, 9, 10), lon=torch.linspace(0, 9, 10) - ) - topo = StaticInput(data=data, coords=coords) - lat_interval = ClosedInterval(2, 5) - lon_interval = ClosedInterval(3, 7) - subset_topo = topo.subset_latlon(lat_interval, lon_interval) - expected_lats = torch.tensor([2, 3, 4, 5], dtype=coords.lat.dtype) - expected_lons = torch.tensor([3, 4, 5, 6, 7], dtype=coords.lon.dtype) - expected_data = data[*expected_slices] - assert torch.equal(subset_topo.coords.lat, expected_lats) - assert torch.equal(subset_topo.coords.lon, expected_lons) - assert torch.allclose(subset_topo.data, expected_data) + topo = StaticInput(data=data) + lat_slice = slice(2, 6) + lon_slice = slice(3, 8) + subset_topo = topo.subset(lat_slice, lon_slice) + assert torch.allclose(subset_topo.data, data[lat_slice, lon_slice]) def test_StaticInputs_serialize(): - data = torch.arange(16).reshape(4, 4) - topography = StaticInput( - data, - LatLonCoordinates(torch.arange(4), torch.arange(4)), + dim_len = 4 + data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( + dim_len, dim_len ) - land_frac = StaticInput( - data * -1.0, - LatLonCoordinates(torch.arange(4), torch.arange(4)), + coords = _make_coords(n=dim_len) + static_inputs = StaticInputs( + [StaticInput(data), StaticInput(data * -1.0)], coords=coords ) - static_inputs = StaticInputs([topography, land_frac]) state = static_inputs.get_state() - static_inputs_reconstructed = StaticInputs.from_state(state) - assert static_inputs_reconstructed[0].data.equal(static_inputs[0].data) - assert static_inputs_reconstructed[1].data.equal(static_inputs[1].data) + assert "coords" in state + reconstructed = StaticInputs.from_state(state) + assert reconstructed[0].data.equal(static_inputs[0].data) + assert reconstructed[1].data.equal(static_inputs[1].data) + assert torch.equal(reconstructed.coords.lat, static_inputs.coords.lat) + assert torch.equal(reconstructed.coords.lon, static_inputs.coords.lon) + + +def test_StaticInputs_from_state_raises_on_missing_coords(): + data = torch.arange(16, dtype=torch.float32).reshape(4, 4) + with pytest.raises(ValueError, match="No coordinates"): + StaticInputs.from_state({"fields": [{"data": data}]}) + + +def test_StaticInputs_from_state_legacy_coords_in_fields(): + """from_state handles old format where coords were stored only inside each field.""" + dim_len = 4 + data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( + dim_len, dim_len + ) + coords = _make_coords(n=dim_len) + old_state = { + "fields": [{"data": data, "coords": {"lat": coords.lat, "lon": coords.lon}}], + } + result = StaticInputs.from_state(old_state) + assert torch.equal(result[0].data, data) + assert torch.equal(result.coords.lat, coords.lat) + assert torch.equal(result.coords.lon, coords.lon) + + +def test_from_state_backwards_compatible_has_coords(): + """When state has coords, delegates to from_state.""" + data = torch.arange(16, dtype=torch.float32).reshape(4, 4) + coords = _make_coords() + state = StaticInputs([StaticInput(data)], coords=coords).get_state() + result = StaticInputs.from_state_backwards_compatible( + state=state, static_inputs_config={} + ) + assert result is not None + assert torch.equal(result[0].data, data) + + +def test_from_state_backwards_compatible_no_state_no_config(): + """No static inputs in checkpoint and no config: returns None.""" + result = StaticInputs.from_state_backwards_compatible( + state={}, static_inputs_config={} + ) + assert result is None + + +def test_from_state_backwards_compatible_with_config(tmp_path): + """Checkpoint with static_inputs_config: loads fields and coords from files.""" + dim_len = 4 + lat = np.linspace(0, 1, dim_len, dtype=np.float32) + lon = np.linspace(0, 1, dim_len, dtype=np.float32) + field_data = np.random.rand(dim_len, dim_len).astype(np.float32) + field_path = str(tmp_path / "field.nc") + xr.Dataset( + {"HGTsfc": (["lat", "lon"], field_data)}, + coords={"lat": lat, "lon": lon}, + ).to_netcdf(field_path) + result = StaticInputs.from_state_backwards_compatible( + state={}, + static_inputs_config={"HGTsfc": field_path}, + ) + assert result is not None + assert len(result.fields) == 1 + + +def test_from_state_backwards_compatible_raises_state_and_config(): + """ + Errors if checkpoint state has fields and static_inputs_config is also provided. + """ + dim_len = 4 + data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( + dim_len, dim_len + ) + coords = _make_coords(n=dim_len) + state = StaticInputs([StaticInput(data)], coords=coords).get_state() + with pytest.raises(ValueError, match="static_inputs_config"): + StaticInputs.from_state_backwards_compatible( + state=state, + static_inputs_config={"HGTsfc": "some/path"}, + ) + + +@pytest.mark.parametrize( + "lat_name,lon_name", + [ + pytest.param("lat", "lon", id="standard"), + pytest.param("latitude", "longitude", id="long_names"), + pytest.param("grid_yt", "grid_xt", id="fv3_names"), + ], +) +def test_load_fine_coords_from_path(tmp_path, lat_name, lon_name): + lat = [0.0, 1.0, 2.0] + lon = [10.0, 20.0, 30.0, 40.0] + ds = xr.Dataset(coords={lat_name: lat, lon_name: lon}) + path = str(tmp_path / "coords.nc") + ds.to_netcdf(path) + + coords = load_fine_coords_from_path(path) + + assert torch.allclose(coords.lat, torch.tensor(lat, dtype=torch.float32)) + assert torch.allclose(coords.lon, torch.tensor(lon, dtype=torch.float32)) + + +def test_load_fine_coords_from_path_raises_on_missing_coords(tmp_path): + ds = xr.Dataset(coords={"x": [0.0, 1.0], "y": [10.0, 20.0]}) + path = str(tmp_path / "no_latlon.nc") + ds.to_netcdf(path) + + with pytest.raises(ValueError, match="Could not find lat/lon coordinates"): + load_fine_coords_from_path(path) diff --git a/fme/downscaling/data/test_utils.py b/fme/downscaling/data/test_utils.py index 662231147..8f6e7806b 100644 --- a/fme/downscaling/data/test_utils.py +++ b/fme/downscaling/data/test_utils.py @@ -92,13 +92,34 @@ def test_scale_slice(input_slice, expected): ), ], ) -def test_ClosedInterval_slice_of(interval, expected_slice): +def test_ClosedInterval_slice_from(interval, expected_slice): coords = torch.arange(5) - result_slice = interval.slice_of(coords) + result_slice = interval.slice_from(coords) assert result_slice == expected_slice +@pytest.mark.parametrize( + "interval,expected_values", + [ + pytest.param(ClosedInterval(2, 4), torch.tensor([2, 3, 4]), id="middle"), + pytest.param( + ClosedInterval(float("-inf"), 2), torch.tensor([0, 1, 2]), id="start_inf" + ), + pytest.param(ClosedInterval(4, float("inf")), torch.tensor([4]), id="end_inf"), + pytest.param( + ClosedInterval(float("-inf"), float("inf")), + torch.arange(5), + id="all_inf", + ), + ], +) +def test_ClosedInterval_subset_of(interval, expected_values): + coords = torch.arange(5) + result = interval.subset_of(coords) + assert torch.equal(result, expected_values) + + def test_ClosedInterval_fail_on_empty_slice(): coords = torch.arange(5) with pytest.raises(ValueError): - ClosedInterval(5.5, 7).slice_of(coords) + ClosedInterval(5.5, 7).slice_from(coords) diff --git a/fme/downscaling/data/utils.py b/fme/downscaling/data/utils.py index 8cecf8c09..bdd34e188 100644 --- a/fme/downscaling/data/utils.py +++ b/fme/downscaling/data/utils.py @@ -37,9 +37,9 @@ def __post_init__(self): def __contains__(self, value: float): return self.start <= value <= self.stop - def slice_of(self, coords: torch.Tensor) -> slice: + def slice_from(self, coords: torch.Tensor) -> slice: """ - Return a slice that selects all elements of `coords` within this + Return a slice that selects all elements from `coords` within this specified interval. This assumes `coords` is monotonically increasing. Args: @@ -63,6 +63,13 @@ def slice_of(self, coords: torch.Tensor) -> slice: indices = mask.nonzero(as_tuple=True)[0] return slice(indices[0].item(), indices[-1].item() + 1) + def subset_of(self, coords: torch.Tensor) -> torch.Tensor: + """ + Return a subset of `coords` that falls within this specified interval. + This assumes `coords` is monotonically increasing. + """ + return coords[self.slice_from(coords)] + def scale_slice(slice_: slice, scale: int) -> slice: if slice_ == slice(None): diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index ed97d2b60..aadd9fd68 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -16,6 +16,7 @@ LatLonCoordinates, StaticInput, StaticInputs, + load_fine_coords_from_path, load_static_inputs, ) from fme.downscaling.inference.constants import ENSEMBLE_NAME, TIME_NAME @@ -64,8 +65,11 @@ def mock_output_target(): def get_static_inputs(shape=(16, 16)): data = torch.randn(shape) - coords = LatLonCoordinates(lat=torch.arange(shape[0]), lon=torch.arange(shape[1])) - return StaticInputs([StaticInput(data=data, coords=coords)]) + coords = LatLonCoordinates( + lat=torch.arange(shape[0], dtype=torch.float32), + lon=torch.arange(shape[1], dtype=torch.float32), + ) + return StaticInputs([StaticInput(data=data)], coords=coords) # Tests for Downscaler initialization @@ -201,9 +205,11 @@ def test_run_target_generation_skips_padding_items( mock_output_target.data.get_generator.return_value = iter([mock_work_item]) mock_model.downscale_factor = 2 - mock_model.static_inputs.coords.lat = torch.arange(0, 18).float() - mock_model.static_inputs.coords.lon = torch.arange(0, 18).float() - mock_model.static_inputs.subset_latlon.return_value.fields[0].data = torch.zeros(1) + mock_model.fine_coords = LatLonCoordinates( + lat=torch.arange(0, 18).float(), + lon=torch.arange(0, 18).float(), + ) + mock_model.static_inputs = None mock_model.generate_on_batch_no_target.return_value = { "var1": torch.zeros(1, 4, 16, 16), } @@ -273,8 +279,12 @@ def checkpointed_model_config( # loader_config is passed in to add static inputs into model # that correspond to the dataset coordinates - static_inputs = load_static_inputs({"HGTsfc": f"{data_paths.fine}/data.nc"}) - model = model_config.build(coarse_shape, 2, static_inputs=static_inputs) + fine_data_path = f"{data_paths.fine}/data.nc" + fine_coords = load_fine_coords_from_path(fine_data_path) + static_inputs = load_static_inputs({"HGTsfc": fine_data_path}) + model = model_config.build( + coarse_shape, 2, full_fine_coords=fine_coords, static_inputs=static_inputs + ) checkpoint_path = tmp_path / "model_checkpoint.pth" model.get_state() diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 2d0d1057e..84287f502 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -20,7 +20,7 @@ PairedBatchData, StaticInputs, adjust_fine_coord_range, - load_static_inputs, + load_fine_coords_from_path, ) from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector @@ -181,8 +181,9 @@ def build( self, coarse_shape: tuple[int, int], downscale_factor: int, - rename: dict[str, str] | None = None, + full_fine_coords: LatLonCoordinates, static_inputs: StaticInputs | None = None, + rename: dict[str, str] | None = None, ) -> "DiffusionModel": invert_rename = {v: k for k, v in (rename or {}).items()} orig_in_names = [invert_rename.get(name, name) for name in self.in_names] @@ -194,14 +195,16 @@ def build( # https://en.wikipedia.org/wiki/Standard_score sigma_data = 1.0 - n_in_channels = len(self.in_names) - if static_inputs is not None: - n_in_channels += len(static_inputs.fields) - elif self.use_fine_topography: - # Old checkpoints may not have static inputs serialized, but if - # use_fine_topography is True, we still need to account for the topography - # channel, which was the only static input at the time - n_in_channels += 1 + num_static_in_channels = len(static_inputs.fields) if static_inputs else 0 + n_in_channels = len(self.in_names) + num_static_in_channels + if self.use_fine_topography and ( + not static_inputs or len(static_inputs.fields) == 0 + ): + raise ValueError( + "use_fine_topography is enabled but no static input fields were found. " + "At least one static input field must be provided when using fine " + "topography." + ) module = self.module.build( n_in_channels=n_in_channels, @@ -212,6 +215,12 @@ def build( use_amp_bf16=self.use_amp_bf16, ) + if static_inputs and static_inputs.coords != full_fine_coords: + raise ValueError( + "static_inputs coordinates do not match full_fine_coords. " + "Static inputs must be defined on the same grid as the model output." + ) + return DiffusionModel( config=self, module=module, @@ -220,6 +229,7 @@ def build( coarse_shape=coarse_shape, downscale_factor=downscale_factor, sigma_data=sigma_data, + full_fine_coords=full_fine_coords, static_inputs=static_inputs, ) @@ -276,6 +286,7 @@ def __init__( coarse_shape: tuple[int, int], downscale_factor: int, sigma_data: float, + full_fine_coords: LatLonCoordinates, static_inputs: StaticInputs | None = None, ) -> None: """ @@ -293,8 +304,11 @@ def __init__( coarse to fine. sigma_data: The standard deviation of the data, used for diffusion model preconditioning. - static_inputs: Static inputs to the model, loaded from the trainer - config or checkpoint. Must be set when use_fine_topography is True. + full_fine_coords: The full fine-resolution domain coordinates. + Serves as the canonical source of truth for the model output grid. + static_inputs: Static inputs to the model. May be None when + no static data is needed. If present, coordinates + must match full_fine_coords. """ self.coarse_shape = coarse_shape self.downscale_factor = downscale_factor @@ -307,62 +321,49 @@ def __init__( self.out_packer = Packer(config.out_names) self.config = config self._channel_axis = -3 - self.static_inputs = ( - static_inputs.to_device() if static_inputs is not None else None - ) + self.full_fine_coords = full_fine_coords.to(get_device()) + self.static_inputs = static_inputs.to_device() if static_inputs else None @property def modules(self) -> torch.nn.ModuleList: return torch.nn.ModuleList([self.module]) - def _subset_static_inputs( - self, - lat_interval: ClosedInterval, - lon_interval: ClosedInterval, - ) -> StaticInputs | None: - """Subset self.static_inputs to the given fine lat/lon interval. - - Returns None if use_fine_topography is False. - Raises ValueError if use_fine_topography is True but self.static_inputs is None. - """ - if not self.config.use_fine_topography: - return None - if self.static_inputs is None: - raise ValueError( - "Static inputs must be provided for each batch when use of fine " - "static inputs is enabled." - ) - return self.static_inputs.subset_latlon(lat_interval, lon_interval) - - def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: - """Return fine-resolution coordinates matching the spatial extent of batch.""" - if self.static_inputs is None: - raise ValueError( - "Model is missing static inputs, which are required to determine " - "the coordinate information for the output dataset." - ) + def _get_fine_interval_from_batch( + self, batch: BatchData + ) -> tuple[ClosedInterval, ClosedInterval]: coarse_lat = batch.latlon_coordinates.lat[0] coarse_lon = batch.latlon_coordinates.lon[0] + fine_lat_interval = adjust_fine_coord_range( batch.lat_interval, full_coarse_coord=coarse_lat, - full_fine_coord=self.static_inputs.coords.lat, + full_fine_coord=self.full_fine_coords.lat, downscale_factor=self.downscale_factor, ) fine_lon_interval = adjust_fine_coord_range( batch.lon_interval, full_coarse_coord=coarse_lon, - full_fine_coord=self.static_inputs.coords.lon, + full_fine_coord=self.full_fine_coords.lon, downscale_factor=self.downscale_factor, ) + return fine_lat_interval, fine_lon_interval + + def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: + lat_interval, lon_interval = self._get_fine_interval_from_batch(batch) return LatLonCoordinates( - lat=self.static_inputs.coords.lat[ - fine_lat_interval.slice_of(self.static_inputs.coords.lat) - ], - lon=self.static_inputs.coords.lon[ - fine_lon_interval.slice_of(self.static_inputs.coords.lon) - ], + lat=lat_interval.subset_of(self.full_fine_coords.lat), + lon=lon_interval.subset_of(self.full_fine_coords.lon), + ) + + def _subset_static_if_available(self, batch: BatchData) -> StaticInputs | None: + if self.static_inputs is None: + return None + fine_lat_interval, fine_lon_interval = self._get_fine_interval_from_batch(batch) + subset_static_inputs = self.static_inputs.subset( + lat_interval=fine_lat_interval, + lon_interval=fine_lon_interval, ) + return subset_static_inputs @property def fine_shape(self) -> tuple[int, int]: @@ -387,11 +388,13 @@ def _get_input_from_coarse( ) interpolated = interpolate(normalized, self.downscale_factor) + # TODO: update when removing use_fine_topography flag if self.config.use_fine_topography: - if static_inputs is None: + if static_inputs is None or not static_inputs.fields: raise ValueError( - "Static inputs must be provided for each batch when use of fine " - "static inputs is enabled." + "Static inputs must be provided for each batch when flag " + "use_fine_topography is enabled, but no static input fields " + "were found." ) else: expected_shape = interpolated.shape[-2:] @@ -402,12 +405,13 @@ def _get_input_from_coarse( ) n_batches = normalized.shape[0] # Join normalized static inputs to input (see dataset for details) + fields: list[torch.Tensor] = [interpolated] for field in static_inputs.fields: - topo = field.data.unsqueeze(0).repeat(n_batches, 1, 1) - topo = topo.unsqueeze(self._channel_axis) - interpolated = torch.concat( - [interpolated, topo], axis=self._channel_axis - ) + static_field = field.data.unsqueeze(0).repeat(n_batches, 1, 1) + static_field = static_field.unsqueeze(self._channel_axis) + fields.append(static_field) + + interpolated = torch.concat(fields, dim=self._channel_axis) if self.config._interpolate_input: return interpolated @@ -419,9 +423,7 @@ def train_on_batch( optimizer: Optimization | NullOptimization, ) -> ModelOutputs: """Performs a denoising training step on a batch of data.""" - _static_inputs = self._subset_static_inputs( - batch.fine.lat_interval, batch.fine.lon_interval - ) + _static_inputs = self._subset_static_if_available(batch.coarse) coarse, fine = batch.coarse.data, batch.fine.data inputs_norm = self._get_input_from_coarse(coarse, _static_inputs) targets_norm = self.out_packer.pack( @@ -535,31 +537,7 @@ def generate_on_batch_no_target( batch: BatchData, n_samples: int = 1, ) -> TensorDict: - if self.config.use_fine_topography: - if self.static_inputs is None: - raise ValueError( - "Static inputs must be provided for each batch when use of fine " - "static inputs is enabled." - ) - coarse_lat = batch.latlon_coordinates.lat[0] - coarse_lon = batch.latlon_coordinates.lon[0] - fine_lat_interval = adjust_fine_coord_range( - batch.lat_interval, - full_coarse_coord=coarse_lat, - full_fine_coord=self.static_inputs.coords.lat, - downscale_factor=self.downscale_factor, - ) - fine_lon_interval = adjust_fine_coord_range( - batch.lon_interval, - full_coarse_coord=coarse_lon, - full_fine_coord=self.static_inputs.coords.lon, - downscale_factor=self.downscale_factor, - ) - _static_inputs = self.static_inputs.subset_latlon( - fine_lat_interval, fine_lon_interval - ) - else: - _static_inputs = None + _static_inputs = self._subset_static_if_available(batch) generated, _, _ = self.generate(batch.data, _static_inputs, n_samples) return generated @@ -569,9 +547,7 @@ def generate_on_batch( batch: PairedBatchData, n_samples: int = 1, ) -> ModelOutputs: - _static_inputs = self._subset_static_inputs( - batch.fine.lat_interval, batch.fine.lon_interval - ) + _static_inputs = self._subset_static_if_available(batch.coarse) coarse, fine = batch.coarse.data, batch.fine.data generated, generated_norm, latent_steps = self.generate( coarse, _static_inputs, n_samples @@ -591,17 +567,15 @@ def generate_on_batch( ) def get_state(self) -> Mapping[str, Any]: - if self.static_inputs is not None: - static_inputs_state = self.static_inputs.get_state() - else: - static_inputs_state = None - return { "config": self.config.get_state(), "module": self.module.state_dict(), "coarse_shape": self.coarse_shape, "downscale_factor": self.downscale_factor, - "static_inputs": static_inputs_state, + "full_fine_coords": self.full_fine_coords.get_state(), + "static_inputs": self.static_inputs.get_state() + if self.static_inputs is not None + else None, } @classmethod @@ -609,15 +583,36 @@ def from_state( cls, state: Mapping[str, Any], ) -> "DiffusionModel": - config = DiffusionModelConfig.from_state(state["config"]) - # backwards compatibility for models before static inputs serialization - if state.get("static_inputs") is not None: - static_inputs = StaticInputs.from_state(state["static_inputs"]).to_device() + """ + Model-level from state is used to reconstruct the full model during training via + loading a checkpoint. This requires recent checkpoint state fields, as + opposed to the CheckpointModelConfig where we build with the option to provide + static_inputs and fine_coordinates for backwards compatibility. + """ + static_inputs_state = state.get("static_inputs") + static_inputs = ( + StaticInputs.from_state(static_inputs_state) + if static_inputs_state + else None + ) + full_fine_coords_state = state.get("full_fine_coords") + if full_fine_coords_state is not None: + full_fine_coords = LatLonCoordinates( + lat=full_fine_coords_state["lat"], + lon=full_fine_coords_state["lon"], + ) else: - static_inputs = None + raise ValueError( + "No full_fine_coords found in loaded state for DiffusionModel. " + "Must use CheckpointModelConfig with fine_coordinates_path provided " + "for backwards compatibility loading of old checkpoints without " + "full_fine_coords in state." + ) + config = DiffusionModelConfig.from_state(state["config"]) model = config.build( state["coarse_shape"], state["downscale_factor"], + full_fine_coords=full_fine_coords, static_inputs=static_inputs, ) model.module.load_state_dict(state["module"], strict=True) @@ -648,6 +643,9 @@ class CheckpointModelConfig: but the model requires static input data. Raises an error if the checkpoint already has static inputs from training. fine_topography_path: Deprecated. Use static_inputs instead. + fine_coordinates_path: Optional path to a netCDF/zarr file containing lat/lon + coordinates for the full fine domain. Used for old checkpoints that have + no static_inputs and no stored fine_coords. model_updates: Optional mapping of {key: new_value} model config updates to apply when loading the model. This is useful for running evaluation with updated parameters than at training time. Use with caution; not all @@ -658,6 +656,7 @@ class CheckpointModelConfig: rename: dict[str, str] | None = None static_inputs: dict[str, str] | None = None fine_topography_path: str | None = None + fine_coordinates_path: str | None = None model_updates: dict[str, Any] | None = None def __post_init__(self) -> None: @@ -686,8 +685,6 @@ def _checkpoint(self) -> Mapping[str, Any]: self._rename.get(name, name) for name in checkpoint_data["model"]["config"]["out_names"] ] - # backwards compatibility for models before static inputs serialization - checkpoint_data["model"].setdefault("static_inputs", None) self._checkpoint_data = checkpoint_data self._checkpoint_is_loaded = True @@ -696,29 +693,52 @@ def _checkpoint(self) -> Mapping[str, Any]: checkpoint_data["model"]["config"][k] = v return self._checkpoint_data + @staticmethod + def _get_coords_backwards_compatible( + coords_from_state: dict | None, + fine_coordinates_path: str | None, + ) -> LatLonCoordinates: + if coords_from_state and fine_coordinates_path: + raise ValueError( + "Checkpoint contains fine coordinates but fine_coordinates_path is also" + " provided. Backwards compatibility loading only supports a single " + "source of fine coordinates info." + ) + + if coords_from_state is not None: + return LatLonCoordinates( + lat=coords_from_state["lat"], + lon=coords_from_state["lon"], + ) + elif fine_coordinates_path is not None: + return load_fine_coords_from_path(fine_coordinates_path) + else: + raise ValueError( + "No fine coordinates found in checkpoint state and no " + " fine_coordinates_path provided. One of these must be provided to " + "load the model." + ) + def build( self, ) -> DiffusionModel: - static_inputs: StaticInputs | None - if self._checkpoint["model"]["static_inputs"] is not None: - if self.static_inputs is not None: - raise ValueError( - "The model checkpoint already has static inputs from training. " - "static_inputs should not be provided in checkpoint model config." - "static inputs from training." - ) - static_inputs = StaticInputs.from_state( - self._checkpoint["model"]["static_inputs"] - ) - elif self.static_inputs is not None: - static_inputs = load_static_inputs(self.static_inputs) - else: - static_inputs = None + checkpoint_model: dict = self._checkpoint["model"] + checkpoint_static_inputs = checkpoint_model.get("static_inputs", {}) + full_fine_coords = self._get_coords_backwards_compatible( + checkpoint_model.get("full_fine_coords"), + self.fine_coordinates_path, + ) + + static_inputs = StaticInputs.from_state_backwards_compatible( + state=checkpoint_static_inputs, + static_inputs_config=self.static_inputs or {}, + ) model = _CheckpointModelConfigSelector.from_state( self._checkpoint["model"]["config"] ).build( coarse_shape=self._checkpoint["model"]["coarse_shape"], downscale_factor=self._checkpoint["model"]["downscale_factor"], + full_fine_coords=full_fine_coords, rename=self._rename, static_inputs=static_inputs, ) diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index b4da331aa..c6b0a3677 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -9,7 +9,6 @@ import yaml from fme.core.cli import prepare_directory -from fme.core.coordinates import LatLonCoordinates from fme.core.dataset.time import TimeSlice from fme.core.dicts import to_flat_dict from fme.core.distributed import Distributed @@ -29,31 +28,6 @@ from fme.downscaling.typing_ import FineResCoarseResPair -def _downscale_coord(coord: torch.tensor, downscale_factor: int): - """ - This is a bandaid fix for the issue where BatchData does not - contain coords for the topography, which is fine-res in the no-target - generation case. The SampleAggregator requires the fine-res coords - for the predictions. - - TODO: remove after topography refactors to have its own data container. - """ - if len(coord.shape) != 1: - raise ValueError("coord tensor to downscale must be 1d") - spacing = coord[1] - coord[0] - # Compute edges from midpoints - first_edge = coord[0] - spacing / 2 - last_edge = coord[-1] + spacing / 2 - - # Subdivide edges - step = spacing / downscale_factor - new_edges = torch.arange(first_edge, last_edge + step / 2, step) - - # Compute new midpoints - coord_new = (new_edges[:-1] + new_edges[1:]) / 2 - return coord_new.to(device=coord.device, dtype=coord.dtype) - - @dataclasses.dataclass class EventConfig: name: str @@ -144,11 +118,8 @@ def generation_model(self): def run(self): logging.info(f"Running {self.event_name} event downscaling...") batch = next(iter(self.data.get_generator())) - coarse_coords = batch[0].latlon_coordinates - fine_coords = LatLonCoordinates( - lat=_downscale_coord(coarse_coords.lat, self.model.downscale_factor), - lon=_downscale_coord(coarse_coords.lon, self.model.downscale_factor), - ) + coarse_coords = batch.latlon_coordinates[0] + fine_coords = self.model.get_fine_coords_for_batch(batch) sample_agg = SampleAggregator( coarse=batch[0].data, latlon_coordinates=FineResCoarseResPair( diff --git a/fme/downscaling/predictors/composite.py b/fme/downscaling/predictors/composite.py index 05add74a0..45256264f 100644 --- a/fme/downscaling/predictors/composite.py +++ b/fme/downscaling/predictors/composite.py @@ -86,6 +86,10 @@ def static_inputs(self): def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: return self.model.get_fine_coords_for_batch(batch) + @property + def full_fine_coords(self): + return self.model.full_fine_coords + def _get_patches( self, coarse_yx_extent, fine_yx_extent ) -> tuple[list[Patch], list[Patch]]: diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 864f511f3..b8fe01cc8 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -124,25 +124,29 @@ def make_paired_batch_data( def make_static_inputs(fine_shape: tuple[int, int]) -> StaticInputs: - """Create StaticInputs with proper monotonic coordinates for given shape.""" - lat_size, lon_size = fine_shape + """Create StaticInputs with one field and matching coords for the given shape.""" return StaticInputs( - fields=[ - StaticInput( - torch.ones(*fine_shape, device=get_device()), - LatLonCoordinates( - lat=_get_monotonic_coordinate(lat_size, stop=lat_size), - lon=_get_monotonic_coordinate(lon_size, stop=lon_size), - ), - ) - ] + fields=[StaticInput(torch.ones(*fine_shape, device=get_device()))], + coords=make_fine_coords(fine_shape), + ) + + +def make_fine_coords(fine_shape: tuple[int, int]) -> LatLonCoordinates: + """Create LatLonCoordinates for given fine shape.""" + lat_size, lon_size = fine_shape + return LatLonCoordinates( + lat=_get_monotonic_coordinate(lat_size, stop=lat_size), + lon=_get_monotonic_coordinate(lon_size, stop=lon_size), ) def test_module_serialization(tmp_path): coarse_shape = (8, 16) - static_inputs = make_static_inputs((16, 32)) + fine_shape = (16, 32) + static_inputs = make_static_inputs(fine_shape) + fine_coords = static_inputs.coords model = _get_diffusion_model( + full_fine_coords=fine_coords, coarse_shape=coarse_shape, downscale_factor=2, predict_residual=True, @@ -158,6 +162,13 @@ def test_module_serialization(tmp_path): model.module.parameters(), model_from_state.module.parameters() ) ) + assert model_from_state.full_fine_coords is not None + assert torch.equal( + model_from_state.full_fine_coords.lat.cpu(), fine_coords.lat.cpu() + ) + assert torch.equal( + model_from_state.full_fine_coords.lon.cpu(), fine_coords.lon.cpu() + ) torch.save(model.get_state(), tmp_path / "test.ckpt") model_from_disk = DiffusionModel.from_state( @@ -174,53 +185,74 @@ def test_module_serialization(tmp_path): assert torch.equal( loaded_static_inputs.fields[0].data, static_inputs.fields[0].data ) + assert model_from_disk.full_fine_coords is not None + assert torch.equal( + model_from_disk.full_fine_coords.lat.cpu(), fine_coords.lat.cpu() + ) + assert torch.equal( + model_from_disk.full_fine_coords.lon.cpu(), fine_coords.lon.cpu() + ) + + +def test_model_raises_when_no_static_fields_but_topography_required(): + coarse_shape = (8, 16) + fine_shape = (16, 32) + fine_coords = make_fine_coords(fine_shape) + with pytest.raises(ValueError): + _ = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=fine_coords, + use_fine_topography=True, + static_inputs=StaticInputs(fields=[], coords=fine_coords), + ) -def test_from_state_backward_compat_fine_topography(): +def test_from_state_backward_compat_migrates_fine_coords_from_old_static_inputs(): + """Old checkpoints that stored coords in static_inputs fields should have + fine_coords auto-migrated on from_state.""" coarse_shape = (8, 16) fine_shape = (16, 32) downscale_factor = 2 static_inputs = make_static_inputs(fine_shape) + fine_coords = static_inputs.coords model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, + full_fine_coords=fine_coords, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, ) - - # Simulate old checkpoint format: static_inputs not serialized state = model.get_state() - state["static_inputs"] = None + # Simulate old format: coords embedded in fields, not in static_inputs state + del state["static_inputs"]["coords"] + state["static_inputs"]["fields"][0]["coords"] = fine_coords.to( + get_device() + ).get_state() - # Should load correctly via the elif use_fine_topography branch (+1 channel) model_from_old_state = DiffusionModel.from_state(state) - assert model_from_old_state.static_inputs is None - assert all( - torch.equal(p1, p2) - for p1, p2 in zip( - model.module.parameters(), model_from_old_state.module.parameters() - ) + assert model_from_old_state.full_fine_coords is not None + assert torch.equal( + model_from_old_state.full_fine_coords.lat.cpu(), fine_coords.lat.cpu() + ) + assert torch.equal( + model_from_old_state.full_fine_coords.lon.cpu(), fine_coords.lon.cpu() ) - - # At runtime, omitting static inputs must raise a clear error - batch = get_mock_paired_batch([2, *coarse_shape], [2, *fine_shape]) - with pytest.raises(ValueError, match="Static inputs must be provided"): - model_from_old_state.generate_on_batch(batch) def _get_diffusion_model( coarse_shape, downscale_factor, + full_fine_coords: LatLonCoordinates, predict_residual=True, use_fine_topography=True, - static_inputs=None, + static_inputs: StaticInputs | None = None, ): normalizer = PairedNormalizationConfig( NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), ) - return DiffusionModelConfig( module=DiffusionModuleRegistrySelector( "unet_diffusion_song", {"model_channels": 4} @@ -237,7 +269,12 @@ def _get_diffusion_model( num_diffusion_generation_steps=3, predict_residual=predict_residual, use_fine_topography=use_fine_topography, - ).build(coarse_shape, downscale_factor, static_inputs=static_inputs) + ).build( + coarse_shape, + downscale_factor, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, + ) @pytest.mark.parametrize("predict_residual", [True, False]) @@ -246,15 +283,15 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph coarse_shape = (8, 16) fine_shape = (16, 32) batch_size = 2 + fine_coords = make_fine_coords(fine_shape) if use_fine_topography: static_inputs = make_static_inputs(fine_shape) batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) else: - static_inputs = None - batch = get_mock_paired_batch( - [batch_size, *coarse_shape], [batch_size, *fine_shape] - ) + static_inputs = StaticInputs(fields=[], coords=fine_coords) + batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) model = _get_diffusion_model( + full_fine_coords=fine_coords, coarse_shape=coarse_shape, downscale_factor=2, predict_residual=predict_residual, @@ -303,9 +340,11 @@ def test_normalizer_serialization(tmp_path): stds = xr.Dataset({"x": 1.0}) means.to_netcdf(tmp_path / "means.nc") stds.to_netcdf(tmp_path / "stds.nc") + fine_shape = (coarse_shape[0] * 2, coarse_shape[1] * 2) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, + full_fine_coords=make_fine_coords(fine_shape), predict_residual=False, use_fine_topography=False, ) @@ -330,7 +369,6 @@ def test_model_error_cases(): fine_shape = (8, 16) coarse_shape = (4, 8) upscaling_factor = 2 - batch_size = 3 normalization_config = PairedNormalizationConfig( NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), @@ -366,32 +404,27 @@ def test_model_error_cases(): **extra_kwargs, # type: ignore ) - # Compatible init, but no topography provided during prediction + # use_fine_topography=True requires non-empty static input fields at build time module_selector = selector( "prebuilt", {"module": DummyModule()}, expects_interpolated_input=True, ) - model = model_class( # type: ignore - module_selector, # type: ignore - LossConfig(type="MSE"), - ["x"], - ["x"], - normalization_config, - use_fine_topography=True, - **extra_kwargs, # type: ignore - ).build( - coarse_shape, - upscaling_factor, - ) - batch = get_mock_paired_batch( - [batch_size, *coarse_shape], [batch_size, *fine_shape] - ) - - # missing fine topography when model requires it - batch.fine.topography = None with pytest.raises(ValueError): - model.generate_on_batch(batch) + model_class( # type: ignore + module_selector, # type: ignore + LossConfig(type="MSE"), + ["x"], + ["x"], + normalization_config, + use_fine_topography=True, + **extra_kwargs, # type: ignore + ).build( + coarse_shape, + upscaling_factor, + full_fine_coords=make_fine_coords(fine_shape), + static_inputs=StaticInputs(fields=[], coords=make_fine_coords(fine_shape)), + ) def test_DiffusionModel_generate_on_batch_no_target(): @@ -402,6 +435,7 @@ def test_DiffusionModel_generate_on_batch_no_target(): model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, + full_fine_coords=static_inputs.coords, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, @@ -435,11 +469,13 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): # Full fine domain: 64x64 covers inputs for both (8,8) and (32,32) coarse inputs # with a downscaling factor of 2 full_fine_size = 64 - static_inputs = make_static_inputs((full_fine_size, full_fine_size)) + full_fine_shape = (full_fine_size, full_fine_size) + static_inputs = make_static_inputs(full_fine_shape) # need to build with static inputs to get the correct n_in_channels model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, + full_fine_coords=static_inputs.coords, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, @@ -494,6 +530,8 @@ def test_lognorm_noise_backwards_compatibility(): model = model_config.build( (32, 32), 2, + full_fine_coords=make_fine_coords((64, 64)), + static_inputs=StaticInputs(fields=[], coords=make_fine_coords((64, 64))), ) state = model.get_state() @@ -541,6 +579,7 @@ def test_get_fine_coords_for_batch(): model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, + full_fine_coords=static_inputs.coords, use_fine_topography=True, static_inputs=static_inputs, ) @@ -554,55 +593,45 @@ def test_get_fine_coords_for_batch(): result = model.get_fine_coords_for_batch(batch) - expected_lat = model.static_inputs.coords.lat[4:12] - expected_lon = model.static_inputs.coords.lon[8:24] - # model.static_inputs has been moved to device; index into it directly - # to match devices + expected_lat = model.full_fine_coords.lat[4:12] + expected_lon = model.full_fine_coords.lon[8:24] assert torch.allclose(result.lat, expected_lat) assert torch.allclose(result.lon, expected_lon) -def test_get_fine_coords_for_batch_raises_without_static_inputs(): - model = _get_diffusion_model( - coarse_shape=(16, 16), - downscale_factor=2, - use_fine_topography=False, - static_inputs=None, - ) - batch = make_batch_data( - (1, 16, 16), - _get_monotonic_coordinate(16, stop=16).tolist(), - _get_monotonic_coordinate(16, stop=16).tolist(), - ) - with pytest.raises(ValueError, match="missing static inputs"): - model.get_fine_coords_for_batch(batch) - - -def test_checkpoint_config_topography_raises(): - with pytest.raises(ValueError): - CheckpointModelConfig( - checkpoint_path="/any/path.ckpt", - fine_topography_path="/topo/path.nc", - ) - - -def test_checkpoint_model_build_raises_when_checkpoint_has_static_inputs(tmp_path): +def test_checkpoint_model_build_with_fine_coordinates_path(tmp_path): + """Old-format checkpoint (no fine_coords key, no coords in static_inputs) + should load correctly when fine_coordinates_path is provided.""" coarse_shape = (8, 16) - fine_shape = (16, 32) - static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords((coarse_shape[0] * 2, coarse_shape[1] * 2)) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, - predict_residual=True, - use_fine_topography=True, - static_inputs=static_inputs, + full_fine_coords=fine_coords, + use_fine_topography=False, + static_inputs=StaticInputs(fields=[], coords=fine_coords), ) + # Simulate old checkpoint: no fine_coords key, no coords in static_inputs + state = model.get_state() + del state["full_fine_coords"] + del state["static_inputs"]["coords"] checkpoint_path = tmp_path / "test.ckpt" - torch.save({"model": model.get_state()}, checkpoint_path) + torch.save({"model": state}, checkpoint_path) - config = CheckpointModelConfig( - checkpoint_path=str(checkpoint_path), - static_inputs={"HGTsfc": "/any/path.nc"}, + # Write fine coords to a netCDF file for the loader to read + coords_path = tmp_path / "fine_coords.nc" + ds = xr.Dataset( + coords={ + "lat": fine_coords.lat.cpu().numpy(), + "lon": fine_coords.lon.cpu().numpy(), + } ) - with pytest.raises(ValueError): - config.build() + ds.to_netcdf(coords_path) + + loaded_model = CheckpointModelConfig( + checkpoint_path=str(checkpoint_path), + fine_coordinates_path=str(coords_path), + ).build() + + assert torch.equal(loaded_model.full_fine_coords.lat.cpu(), fine_coords.lat.cpu()) + assert torch.equal(loaded_model.full_fine_coords.lon.cpu(), fine_coords.lon.cpu()) diff --git a/fme/downscaling/test_predict.py b/fme/downscaling/test_predict.py index ce253ab14..505a69b5d 100644 --- a/fme/downscaling/test_predict.py +++ b/fme/downscaling/test_predict.py @@ -8,7 +8,11 @@ from fme.core.normalizer import NormalizationConfig from fme.core.testing.wandb import mock_wandb from fme.downscaling import predict -from fme.downscaling.data import load_static_inputs +from fme.downscaling.data import ( + StaticInputs, + load_fine_coords_from_path, + load_static_inputs, +) from fme.downscaling.models import DiffusionModelConfig, PairedNormalizationConfig from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector from fme.downscaling.test_models import LinearDownscaling @@ -97,7 +101,7 @@ def create_predictor_config( out_path = tmp_path / "predictor-config.yaml" with open(out_path, "w") as file: yaml.dump(config, file) - return out_path, f"{paths.fine}/data.nc" + return out_path, paths def test_predictor_runs(tmp_path, very_fast_only: bool): @@ -106,15 +110,19 @@ def test_predictor_runs(tmp_path, very_fast_only: bool): n_samples = 2 coarse_shape = (4, 4) downscale_factor = 2 - predictor_config_path, fine_data_path = create_predictor_config( + predictor_config_path, paths = create_predictor_config( tmp_path, n_samples, ) + fine_data_path = f"{paths.fine}/data.nc" + fine_coords = load_fine_coords_from_path(fine_data_path) model_config = get_model_config(coarse_shape, downscale_factor=downscale_factor) + static_inputs = load_static_inputs({"HGTsfc": fine_data_path}) model = model_config.build( coarse_shape=coarse_shape, downscale_factor=downscale_factor, - static_inputs=load_static_inputs({"HGTsfc": fine_data_path}), + full_fine_coords=fine_coords, + static_inputs=static_inputs, ) with open(predictor_config_path) as f: predictor_config = yaml.safe_load(f) @@ -147,7 +155,7 @@ def test_predictor_renaming( coarse_shape = (4, 4) downscale_factor = 2 renaming = {"var0": "var0_renamed", "var1": "var1_renamed"} - predictor_config_path, fine_data_path = create_predictor_config( + predictor_config_path, paths = create_predictor_config( tmp_path, n_samples, model_renaming=renaming, @@ -155,13 +163,15 @@ def test_predictor_renaming( "rename": {"var0": "var0_renamed", "var1": "var1_renamed"} }, ) + fine_coords = load_fine_coords_from_path(f"{paths.fine}/data.nc") model_config = get_model_config( coarse_shape, downscale_factor, use_fine_topography=False ) model = model_config.build( coarse_shape=coarse_shape, downscale_factor=2, - static_inputs=load_static_inputs({"HGTsfc": fine_data_path}), + full_fine_coords=fine_coords, + static_inputs=StaticInputs(fields=[], coords=fine_coords), ) with open(predictor_config_path) as f: predictor_config = yaml.safe_load(f) diff --git a/fme/downscaling/test_utils.py b/fme/downscaling/test_utils.py index 4f34a4451..b1546e044 100644 --- a/fme/downscaling/test_utils.py +++ b/fme/downscaling/test_utils.py @@ -78,4 +78,5 @@ def data_paths_helper( create_test_data_on_disk( coarse_path / "data.nc", dim_sizes.coarse, variable_names, coords ) + # TODO: should this return the full filename instead of just the directory? return FineResCoarseResPair[str](fine=fine_path, coarse=coarse_path) diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index a1dd08fde..0b5a4ff66 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -393,7 +393,7 @@ class TrainerConfig: experiment_dir: str save_checkpoints: bool logging: LoggingConfig - static_inputs: dict[str, str] | None = None + static_inputs: dict[str, str] = dataclasses.field(default_factory=dict) ema: EMAConfig = dataclasses.field(default_factory=EMAConfig) validate_using_ema: bool = False generate_n_samples: int = 1 @@ -421,8 +421,6 @@ def checkpoint_dir(self) -> str: return os.path.join(self.experiment_dir, "checkpoints") def build(self) -> Trainer: - static_inputs = load_static_inputs(self.static_inputs) - train_data: PairedGriddedData = self.train_data.build( train=True, requirements=self.model.data_requirements, @@ -431,6 +429,11 @@ def build(self) -> Trainer: train=False, requirements=self.model.data_requirements, ) + if self.static_inputs: + static_inputs = load_static_inputs(self.static_inputs) + else: + static_inputs = None + if self.coarse_patch_extent_lat and self.coarse_patch_extent_lon: model_coarse_shape = ( self.coarse_patch_extent_lat, @@ -442,6 +445,7 @@ def build(self) -> Trainer: downscaling_model = self.model.build( model_coarse_shape, train_data.downscale_factor, + full_fine_coords=train_data.fine_coords, static_inputs=static_inputs, )