diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index c3e5e6e90..bc5ef20cb 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -11,17 +11,11 @@ @dataclasses.dataclass class StaticInput: data: torch.Tensor - coords: LatLonCoordinates def __post_init__(self): if len(self.data.shape) != 2: - raise ValueError(f"Topography data must be 2D. Got shape {self.data.shape}") - if self.data.shape[0] != len(self.coords.lat) or self.data.shape[1] != len( - self.coords.lon - ): raise ValueError( - f"Static inputs data shape {self.data.shape} does not match lat/lon " - f"coordinates shape {(len(self.coords.lat), len(self.coords.lon))}" + f"StaticInput data must be 2D. Got shape {self.data.shape}" ) @property @@ -33,101 +27,119 @@ def shape(self) -> tuple[int, int]: return self.data.shape def subset( - self, - lat_interval: ClosedInterval, - lon_interval: ClosedInterval, - ) -> "StaticInput": - lat_slice = lat_interval.slice_from(self.coords.lat) - lon_slice = lon_interval.slice_from(self.coords.lon) - return self._latlon_index_slice(lat_slice=lat_slice, lon_slice=lon_slice) - - def to_device(self) -> "StaticInput": - device = get_device() - return StaticInput( - data=self.data.to(device), - coords=LatLonCoordinates( - lat=self.coords.lat.to(device), - lon=self.coords.lon.to(device), - ), - ) - - def _latlon_index_slice( self, lat_slice: slice, lon_slice: slice, ) -> "StaticInput": - sliced_data = self.data[lat_slice, lon_slice] - sliced_latlon = LatLonCoordinates( - lat=self.coords.lat[lat_slice], - lon=self.coords.lon[lon_slice], - ) - return StaticInput( - data=sliced_data, - coords=sliced_latlon, - ) + return StaticInput(data=self.data[lat_slice, lon_slice]) + + def to_device(self) -> "StaticInput": + return StaticInput(data=self.data.to(get_device())) def get_state(self) -> dict: - return { - "data": self.data.cpu(), - "coords": self.coords.get_state(), - } + return {"data": self.data.cpu()} + + @classmethod + def from_state(cls, state: dict) -> "StaticInput": + return cls(data=state["data"]) -def _get_normalized_static_input(path: str, field_name: str): +_LAT_NAMES = ("lat", "latitude", "grid_yt") +_LON_NAMES = ("lon", "longitude", "grid_xt") + + +def _load_coords_from_ds(ds: xr.Dataset) -> LatLonCoordinates: + lat_name = next((n for n in _LAT_NAMES if n in ds.coords), None) + lon_name = next((n for n in _LON_NAMES if n in ds.coords), None) + if lat_name is None or lon_name is None: + raise ValueError( + "Could not find lat/lon coordinates in dataset. " + f"Expected one of {_LAT_NAMES} for lat and {_LON_NAMES} for lon." + ) + return LatLonCoordinates( + lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), + lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), + ) + + +def _get_normalized_static_input( + path: str, field_name: str +) -> tuple["StaticInput", LatLonCoordinates]: """ Load a static input field from a given file path and field name and normalize it. Only supports 2D lat/lon static inputs. If the input has a time dimension, it is - squeezed by taking the first time step. The lat/lon coordinates are - assumed to be the last two dimensions of the loaded dataset dimensions. + squeezed by taking the first time step. + + Raises ValueError if lat/lon coordinates are not found in the dataset. """ if path.endswith(".zarr"): - static_input = xr.open_zarr(path, mask_and_scale=False)[field_name] + ds = xr.open_zarr(path, mask_and_scale=False) else: - static_input = xr.open_dataset(path, mask_and_scale=False)[field_name] - if "time" in static_input.dims: - static_input = static_input.isel(time=0).squeeze() - if len(static_input.shape) != 2: + ds = xr.open_dataset(path, mask_and_scale=False) + + coords = _load_coords_from_ds(ds) + da = ds[field_name] + + if "time" in da.dims: + da = da.isel(time=0).squeeze() + if len(da.shape) != 2: raise ValueError( - f"unexpected shape {static_input.shape} for static input." + f"unexpected shape {da.shape} for static input. " "Currently, only lat/lon static input is supported." ) - lat_name, lon_name = static_input.dims[-2:] - coords = LatLonCoordinates( - lon=torch.tensor(static_input[lon_name].values), - lat=torch.tensor(static_input[lat_name].values), - ) - - static_input_normalized = (static_input - static_input.mean()) / static_input.std() + static_input_normalized = (da - da.mean()) / da.std() return StaticInput( data=torch.tensor(static_input_normalized.values, dtype=torch.float32), - coords=coords, - ) + ), coords + + +def _has_legacy_coords_in_state(state: dict) -> bool: + return bool(state.get("fields")) and "coords" in state["fields"][0] + + +def _has_coords_in_state(state: dict) -> bool: + return "coords" in state or _has_legacy_coords_in_state(state) + + +def _sync_state_coordinates(state: dict) -> dict: + """Migrate old per-field coord format to top-level coords format.""" + if _has_legacy_coords_in_state(state): + state = dict(state) + state["coords"] = state["fields"][0]["coords"] + return state + + +def _validate_coords( + case: str, coord1: LatLonCoordinates, coord2: LatLonCoordinates +) -> None: + if coord1 != coord2: + raise ValueError(f"Coordinates do not match between static inputs: {case}") @dataclasses.dataclass class StaticInputs: fields: list[StaticInput] + coords: LatLonCoordinates def __post_init__(self): for i, field in enumerate(self.fields[1:]): - if field.coords != self.fields[0].coords: + if field.shape != self.fields[0].shape: raise ValueError( - f"All StaticInput fields must have the same coordinates. " - f"Fields {i} and 0 do not match coordinates." + f"All StaticInput fields must have the same shape. " + f"Fields {i + 1} and 0 do not match shapes." ) + if self.fields and self.coords.shape != self.fields[0].shape: + raise ValueError( + f"Coordinates shape {self.coords.shape} does not match " + f"fields shape {self.fields[0].shape}." + ) def __getitem__(self, index: int): return self.fields[index] - @property - def coords(self) -> LatLonCoordinates: - if len(self.fields) == 0: - raise ValueError("No fields in StaticInputs to get coordinates from.") - return self.fields[0].coords - @property def shape(self) -> tuple[int, int]: if len(self.fields) == 0: @@ -139,48 +151,88 @@ def subset( lat_interval: ClosedInterval, lon_interval: ClosedInterval, ) -> "StaticInputs": + lat_slice = lat_interval.slice_from(self.coords.lat) + lon_slice = lon_interval.slice_from(self.coords.lon) return StaticInputs( - fields=[field.subset(lat_interval, lon_interval) for field in self.fields] + fields=[field.subset(lat_slice, lon_slice) for field in self.fields], + coords=LatLonCoordinates( + lat=lat_interval.subset_of(self.coords.lat), + lon=lon_interval.subset_of(self.coords.lon), + ), ) def to_device(self) -> "StaticInputs": - return StaticInputs(fields=[field.to_device() for field in self.fields]) + return StaticInputs( + fields=[field.to_device() for field in self.fields], + coords=self.coords.to(get_device()), + ) def get_state(self) -> dict: return { "fields": [field.get_state() for field in self.fields], + "coords": self.coords.get_state(), } @classmethod def from_state(cls, state: dict) -> "StaticInputs": + if not _has_coords_in_state(state): + raise ValueError( + "No coordinates found in state for StaticInputs. Load with " + "from_state_backwards_compatible if loading from a checkpoint " + "saved prior to current coordinate serialization format." + ) + state = _sync_state_coordinates(state) return cls( fields=[ - StaticInput( - data=field_state["data"], - coords=LatLonCoordinates( - lat=field_state["coords"]["lat"], - lon=field_state["coords"]["lon"], - ), - ) - for field_state in state["fields"] - ] + StaticInput.from_state(field_state) for field_state in state["fields"] + ], + coords=LatLonCoordinates( + lat=state["coords"]["lat"], + lon=state["coords"]["lon"], + ), ) + @classmethod + def from_state_backwards_compatible( + cls, + state: dict, + static_inputs_config: dict[str, str], + ) -> "StaticInputs | None": + if state and static_inputs_config: + raise ValueError( + "Checkpoint contains static inputs but static_inputs_config is " + "also provided. Backwards compatibility loading only supports " + "a single source of StaticInputs info." + ) + if _has_coords_in_state(state): + return cls.from_state(state) + elif static_inputs_config: + return load_static_inputs(static_inputs_config) + else: + return None + def load_static_inputs( - static_inputs_config: dict[str, str] | None, -) -> StaticInputs | None: + static_inputs_config: dict[str, str], +) -> StaticInputs: """ Load normalized static inputs from a mapping of field names to file paths. - Returns None if the input config is empty. + + Coordinates are read from each field's source dataset and validated to be + consistent between fields. Raises ValueError if any field's dataset lacks + lat/lon coordinates or if coordinates differ between fields. """ - # TODO: consolidate/simplify empty StaticInputs vs. None handling in - # downscaling code - if not static_inputs_config: - return None - return StaticInputs( - fields=[ - _get_normalized_static_input(path, field_name) - for field_name, path in static_inputs_config.items() - ] - ) + coords_to_use = None + fields = [] + for field_name, path in static_inputs_config.items(): + si, coords = _get_normalized_static_input(path, field_name) + fields.append(si) + if coords_to_use is None: + coords_to_use = coords + else: + _validate_coords(field_name, coords, coords_to_use) + + if coords_to_use is None: + raise ValueError("load_static_inputs requires at least one field.") + + return StaticInputs(fields=fields, coords=coords_to_use) diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index 827cdfb49..aba525c48 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -1,64 +1,142 @@ +import numpy as np import pytest import torch +import xarray as xr -from fme.core.coordinates import LatLonCoordinates - -from .static import StaticInput, StaticInputs -from .utils import ClosedInterval - - -@pytest.mark.parametrize( - "init_args", - [ - pytest.param( - [ - torch.randn((1, 2, 2)), - LatLonCoordinates(torch.arange(2), torch.arange(2)), - ], - id="3d_data", - ), - pytest.param( - [torch.randn((2, 2)), LatLonCoordinates(torch.arange(2), torch.arange(5))], - id="dim_size_mismatch", - ), - ], -) -def test_Topography_error_cases(init_args): +from .static import StaticInput, StaticInputs, _load_coords_from_ds + + +def _make_coords(n=4): + from fme.core.coordinates import LatLonCoordinates + + return LatLonCoordinates( + lat=torch.arange(n, dtype=torch.float32), + lon=torch.arange(n, dtype=torch.float32), + ) + + +def test_StaticInput_error_cases(): + data = torch.randn(1, 2, 2) with pytest.raises(ValueError): - StaticInput(*init_args) + StaticInput(data=data) def test_subset(): full_data_shape = (10, 10) - expected_slices = [slice(2, 6), slice(3, 8)] data = torch.randn(*full_data_shape) - coords = LatLonCoordinates( - lat=torch.linspace(0, 9, 10), lon=torch.linspace(0, 9, 10) - ) - topo = StaticInput(data=data, coords=coords) - lat_interval = ClosedInterval(2, 5) - lon_interval = ClosedInterval(3, 7) - subset_topo = topo.subset(lat_interval, lon_interval) - expected_lats = torch.tensor([2, 3, 4, 5], dtype=coords.lat.dtype) - expected_lons = torch.tensor([3, 4, 5, 6, 7], dtype=coords.lon.dtype) - expected_data = data[*expected_slices] - assert torch.equal(subset_topo.coords.lat, expected_lats) - assert torch.equal(subset_topo.coords.lon, expected_lons) - assert torch.allclose(subset_topo.data, expected_data) + topo = StaticInput(data=data) + lat_slice = slice(2, 6) + lon_slice = slice(3, 8) + subset_topo = topo.subset(lat_slice, lon_slice) + assert torch.allclose(subset_topo.data, data[lat_slice, lon_slice]) def test_StaticInputs_serialize(): - data = torch.arange(16).reshape(4, 4) - topography = StaticInput( - data, - LatLonCoordinates(torch.arange(4), torch.arange(4)), + dim_len = 4 + data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( + dim_len, dim_len ) - land_frac = StaticInput( - data * -1.0, - LatLonCoordinates(torch.arange(4), torch.arange(4)), + coords = _make_coords(n=dim_len) + static_inputs = StaticInputs( + [StaticInput(data), StaticInput(data * -1.0)], coords=coords ) - static_inputs = StaticInputs([topography, land_frac]) state = static_inputs.get_state() - static_inputs_reconstructed = StaticInputs.from_state(state) - assert static_inputs_reconstructed[0].data.equal(static_inputs[0].data) - assert static_inputs_reconstructed[1].data.equal(static_inputs[1].data) + assert "coords" in state + reconstructed = StaticInputs.from_state(state) + assert reconstructed[0].data.equal(static_inputs[0].data) + assert reconstructed[1].data.equal(static_inputs[1].data) + assert torch.equal(reconstructed.coords.lat, static_inputs.coords.lat) + assert torch.equal(reconstructed.coords.lon, static_inputs.coords.lon) + + +def test_StaticInputs_from_state_raises_on_missing_coords(): + data = torch.arange(16, dtype=torch.float32).reshape(4, 4) + with pytest.raises(ValueError, match="No coordinates"): + StaticInputs.from_state({"fields": [{"data": data}]}) + + +def test_StaticInputs_from_state_legacy_coords_in_fields(): + """from_state handles old format where coords were stored only inside each field.""" + dim_len = 4 + data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( + dim_len, dim_len + ) + coords = _make_coords(n=dim_len) + old_state = { + "fields": [{"data": data, "coords": {"lat": coords.lat, "lon": coords.lon}}], + } + result = StaticInputs.from_state(old_state) + assert torch.equal(result[0].data, data) + assert torch.equal(result.coords.lat, coords.lat) + assert torch.equal(result.coords.lon, coords.lon) + + +def test_from_state_backwards_compatible_has_coords(): + """When state has coords, delegates to from_state.""" + data = torch.arange(16, dtype=torch.float32).reshape(4, 4) + coords = _make_coords() + state = StaticInputs([StaticInput(data)], coords=coords).get_state() + result = StaticInputs.from_state_backwards_compatible( + state=state, static_inputs_config={} + ) + assert result is not None + assert torch.equal(result[0].data, data) + + +def test_from_state_backwards_compatible_no_state_no_config(): + """No static inputs in checkpoint and no config: returns None.""" + result = StaticInputs.from_state_backwards_compatible( + state={}, static_inputs_config={} + ) + assert result is None + + +def test_from_state_backwards_compatible_with_config(tmp_path): + """Checkpoint with static_inputs_config: loads fields and coords from files.""" + dim_len = 4 + lat = np.linspace(0, 1, dim_len, dtype=np.float32) + lon = np.linspace(0, 1, dim_len, dtype=np.float32) + field_data = np.random.rand(dim_len, dim_len).astype(np.float32) + field_path = str(tmp_path / "field.nc") + xr.Dataset( + {"HGTsfc": (["lat", "lon"], field_data)}, + coords={"lat": lat, "lon": lon}, + ).to_netcdf(field_path) + result = StaticInputs.from_state_backwards_compatible( + state={}, + static_inputs_config={"HGTsfc": field_path}, + ) + assert result is not None + assert len(result.fields) == 1 + + +def test_from_state_backwards_compatible_raises_state_and_config(): + """ + Errors if checkpoint state has fields and static_inputs_config is also provided. + """ + dim_len = 4 + data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( + dim_len, dim_len + ) + coords = _make_coords(n=dim_len) + state = StaticInputs([StaticInput(data)], coords=coords).get_state() + with pytest.raises(ValueError, match="static_inputs_config"): + StaticInputs.from_state_backwards_compatible( + state=state, + static_inputs_config={"HGTsfc": "some/path"}, + ) + + +def test__load_coords_from_ds(): + lat = [0.0, 1.0, 2.0] + lon = [10.0, 20.0, 30.0, 40.0] + ds = xr.Dataset(coords={"lat": lat, "lon": lon}) + + coords = _load_coords_from_ds(ds) + assert torch.allclose(coords.lat, torch.tensor(lat, dtype=torch.float32)) + assert torch.allclose(coords.lon, torch.tensor(lon, dtype=torch.float32)) + + # expected coord names missing + ds = xr.Dataset(coords={"x": lon, "y": lat}) + with pytest.raises(ValueError): + _load_coords_from_ds(ds) diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index aad4071f6..c96733b31 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -64,8 +64,11 @@ def mock_output_target(): def get_static_inputs(shape=(16, 16)): data = torch.randn(shape) - coords = LatLonCoordinates(lat=torch.arange(shape[0]), lon=torch.arange(shape[1])) - return StaticInputs([StaticInput(data=data, coords=coords)]) + coords = LatLonCoordinates( + lat=torch.arange(shape[0], dtype=torch.float32), + lon=torch.arange(shape[1], dtype=torch.float32), + ) + return StaticInputs([StaticInput(data=data)], coords=coords) # Tests for Downscaler initialization diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 864f511f3..703af10d5 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -127,15 +127,11 @@ def make_static_inputs(fine_shape: tuple[int, int]) -> StaticInputs: """Create StaticInputs with proper monotonic coordinates for given shape.""" lat_size, lon_size = fine_shape return StaticInputs( - fields=[ - StaticInput( - torch.ones(*fine_shape, device=get_device()), - LatLonCoordinates( - lat=_get_monotonic_coordinate(lat_size, stop=lat_size), - lon=_get_monotonic_coordinate(lon_size, stop=lon_size), - ), - ) - ] + fields=[StaticInput(torch.ones(*fine_shape, device=get_device()))], + coords=LatLonCoordinates( + lat=_get_monotonic_coordinate(lat_size, stop=lat_size), + lon=_get_monotonic_coordinate(lon_size, stop=lon_size), + ), ) diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index a1dd08fde..1a2c9e46f 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -421,7 +421,9 @@ def checkpoint_dir(self) -> str: return os.path.join(self.experiment_dir, "checkpoints") def build(self) -> Trainer: - static_inputs = load_static_inputs(self.static_inputs) + static_inputs = ( + load_static_inputs(self.static_inputs) if self.static_inputs else None + ) train_data: PairedGriddedData = self.train_data.build( train=True,