From 985102b7d8c38779add5f2a580a7ea2e9b61c9c0 Mon Sep 17 00:00:00 2001
From: Andre Perkins
Date: Fri, 20 Mar 2026 20:49:47 -0700
Subject: [PATCH 1/4] Refactor StaticInputs Coordinates
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

fme/downscaling/data/static.py — major refactor:
- StaticInput: removed coords field; now just holds data; subset() takes slices (not intervals)
- Added _load_coords_from_ds(), load_fine_coords_from_path() for loading coords from files
- Added _has_legacy_coords_in_state(), _has_coords_in_state(), _sync_state_coordinates() for backwards-compat state migration
- StaticInputs: added coords: LatLonCoordinates as first-class field; removed per-field coord validation; subset() now returns updated coords; to_device() moves coords too; get_state() includes top-level "coords"; from_state() raises if coords absent; added from_state_backwards_compatible()
- _get_normalized_static_input(): returns (StaticInput, LatLonCoordinates) tuple
- load_static_inputs(): no longer accepts None; validates coord consistency across fields

fme/downscaling/data/test_static.py — full replacement with new test suite covering all new APIs

fme/downscaling/data/__init__.py — added load_fine_coords_from_path export

fme/downscaling/test_models.py — updated make_static_inputs() to new StaticInputs interface

fme/downscaling/inference/test_inference.py — updated get_static_inputs() helper

fme/downscaling/train.py — guarded load_static_inputs call against None config
Lines starting --- fme/downscaling/data/__init__.py | 7 +- fme/downscaling/data/static.py | 237 ++++++++++++-------- fme/downscaling/data/test_static.py | 194 ++++++++++++---- fme/downscaling/inference/test_inference.py | 7 +- fme/downscaling/test_models.py | 14 +- fme/downscaling/train.py | 4 +- 6 files changed, 315 insertions(+), 148 deletions(-) diff --git a/fme/downscaling/data/__init__.py b/fme/downscaling/data/__init__.py index 5790cdc11..334e1a949 100644 --- a/fme/downscaling/data/__init__.py +++ b/fme/downscaling/data/__init__.py @@ -12,7 +12,12 @@ PairedBatchItem, PairedGriddedData, ) -from .static import StaticInput, StaticInputs, load_static_inputs +from .static import ( + StaticInput, + StaticInputs, + load_fine_coords_from_path, + load_static_inputs, +) from .utils import ( BatchedLatLonCoordinates, ClosedInterval, diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index c3e5e6e90..e5f870c97 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -11,17 +11,11 @@ @dataclasses.dataclass class StaticInput: data: torch.Tensor - coords: LatLonCoordinates def __post_init__(self): if len(self.data.shape) != 2: - raise ValueError(f"Topography data must be 2D. Got shape {self.data.shape}") - if self.data.shape[0] != len(self.coords.lat) or self.data.shape[1] != len( - self.coords.lon - ): raise ValueError( - f"Static inputs data shape {self.data.shape} does not match lat/lon " - f"coordinates shape {(len(self.coords.lat), len(self.coords.lon))}" + f"StaticInput data must be 2D. 
Got shape {self.data.shape}" ) @property @@ -34,100 +28,127 @@ def shape(self) -> tuple[int, int]: def subset( self, - lat_interval: ClosedInterval, - lon_interval: ClosedInterval, + lat_slice: slice, + lon_slice: slice, ) -> "StaticInput": - lat_slice = lat_interval.slice_from(self.coords.lat) - lon_slice = lon_interval.slice_from(self.coords.lon) - return self._latlon_index_slice(lat_slice=lat_slice, lon_slice=lon_slice) + return StaticInput(data=self.data[lat_slice, lon_slice]) def to_device(self) -> "StaticInput": - device = get_device() - return StaticInput( - data=self.data.to(device), - coords=LatLonCoordinates( - lat=self.coords.lat.to(device), - lon=self.coords.lon.to(device), - ), - ) + return StaticInput(data=self.data.to(get_device())) - def _latlon_index_slice( - self, - lat_slice: slice, - lon_slice: slice, - ) -> "StaticInput": - sliced_data = self.data[lat_slice, lon_slice] - sliced_latlon = LatLonCoordinates( - lat=self.coords.lat[lat_slice], - lon=self.coords.lon[lon_slice], - ) - return StaticInput( - data=sliced_data, - coords=sliced_latlon, + def get_state(self) -> dict: + return {"data": self.data.cpu()} + + @classmethod + def from_state(cls, state: dict) -> "StaticInput": + return cls(data=state["data"]) + + +_LAT_NAMES = ("lat", "latitude", "grid_yt") +_LON_NAMES = ("lon", "longitude", "grid_xt") + + +def _load_coords_from_ds(ds: xr.Dataset) -> LatLonCoordinates: + lat_name = next((n for n in _LAT_NAMES if n in ds.coords), None) + lon_name = next((n for n in _LON_NAMES if n in ds.coords), None) + if lat_name is None or lon_name is None: + raise ValueError( + "Could not find lat/lon coordinates in dataset. " + f"Expected one of {_LAT_NAMES} for lat and {_LON_NAMES} for lon." 
) + return LatLonCoordinates( + lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), + lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), + ) - def get_state(self) -> dict: - return { - "data": self.data.cpu(), - "coords": self.coords.get_state(), - } + +def load_fine_coords_from_path(path: str) -> LatLonCoordinates: + """Load lat/lon coordinates from a netCDF or zarr file.""" + if path.endswith(".zarr"): + ds = xr.open_zarr(path) + else: + ds = xr.open_dataset(path) + return _load_coords_from_ds(ds) -def _get_normalized_static_input(path: str, field_name: str): +def _get_normalized_static_input( + path: str, field_name: str +) -> tuple["StaticInput", LatLonCoordinates]: """ Load a static input field from a given file path and field name and normalize it. Only supports 2D lat/lon static inputs. If the input has a time dimension, it is - squeezed by taking the first time step. The lat/lon coordinates are - assumed to be the last two dimensions of the loaded dataset dimensions. + squeezed by taking the first time step. + + Raises ValueError if lat/lon coordinates are not found in the dataset. """ if path.endswith(".zarr"): - static_input = xr.open_zarr(path, mask_and_scale=False)[field_name] + ds = xr.open_zarr(path, mask_and_scale=False) else: - static_input = xr.open_dataset(path, mask_and_scale=False)[field_name] - if "time" in static_input.dims: - static_input = static_input.isel(time=0).squeeze() - if len(static_input.shape) != 2: + ds = xr.open_dataset(path, mask_and_scale=False) + + coords = _load_coords_from_ds(ds) + da = ds[field_name] + + if "time" in da.dims: + da = da.isel(time=0).squeeze() + if len(da.shape) != 2: raise ValueError( - f"unexpected shape {static_input.shape} for static input." + f"unexpected shape {da.shape} for static input. " "Currently, only lat/lon static input is supported." 
) - lat_name, lon_name = static_input.dims[-2:] - coords = LatLonCoordinates( - lon=torch.tensor(static_input[lon_name].values), - lat=torch.tensor(static_input[lat_name].values), - ) - - static_input_normalized = (static_input - static_input.mean()) / static_input.std() + static_input_normalized = (da - da.mean()) / da.std() return StaticInput( data=torch.tensor(static_input_normalized.values, dtype=torch.float32), - coords=coords, - ) + ), coords + + +def _has_legacy_coords_in_state(state: dict) -> bool: + return bool(state.get("fields")) and "coords" in state["fields"][0] + + +def _has_coords_in_state(state: dict) -> bool: + return "coords" in state or _has_legacy_coords_in_state(state) + + +def _sync_state_coordinates(state: dict) -> dict: + """Migrate old per-field coord format to top-level coords format.""" + if _has_legacy_coords_in_state(state): + state = dict(state) + state["coords"] = state["fields"][0]["coords"] + return state + + +def _validate_coords( + case: str, coord1: LatLonCoordinates, coord2: LatLonCoordinates +) -> None: + if coord1 != coord2: + raise ValueError(f"Coordinates do not match between static inputs: {case}") @dataclasses.dataclass class StaticInputs: fields: list[StaticInput] + coords: LatLonCoordinates def __post_init__(self): for i, field in enumerate(self.fields[1:]): - if field.coords != self.fields[0].coords: + if field.shape != self.fields[0].shape: raise ValueError( - f"All StaticInput fields must have the same coordinates. " - f"Fields {i} and 0 do not match coordinates." + f"All StaticInput fields must have the same shape. " + f"Fields {i + 1} and 0 do not match shapes." ) + if self.fields and self.coords.shape != self.fields[0].shape: + raise ValueError( + f"Coordinates shape {self.coords.shape} does not match " + f"fields shape {self.fields[0].shape}." 
+ ) def __getitem__(self, index: int): return self.fields[index] - @property - def coords(self) -> LatLonCoordinates: - if len(self.fields) == 0: - raise ValueError("No fields in StaticInputs to get coordinates from.") - return self.fields[0].coords - @property def shape(self) -> tuple[int, int]: if len(self.fields) == 0: @@ -140,47 +161,85 @@ def subset( lon_interval: ClosedInterval, ) -> "StaticInputs": return StaticInputs( - fields=[field.subset(lat_interval, lon_interval) for field in self.fields] + fields=[field.subset(lat_slice, lon_slice) for field in self.fields], + coords=LatLonCoordinates( + lat=lat_interval.subset_of(self.coords.lat), + lon=lon_interval.subset_of(self.coords.lon), + ), ) def to_device(self) -> "StaticInputs": - return StaticInputs(fields=[field.to_device() for field in self.fields]) + return StaticInputs( + fields=[field.to_device() for field in self.fields], + coords=self.coords.to(get_device()), + ) def get_state(self) -> dict: return { "fields": [field.get_state() for field in self.fields], + "coords": self.coords.get_state(), } @classmethod def from_state(cls, state: dict) -> "StaticInputs": + if not _has_coords_in_state(state): + raise ValueError( + "No coordinates found in state for StaticInputs. Load with " + "from_state_backwards_compatible if loading from a checkpoint " + "saved prior to current coordinate serialization format." 
+ ) + state = _sync_state_coordinates(state) return cls( fields=[ - StaticInput( - data=field_state["data"], - coords=LatLonCoordinates( - lat=field_state["coords"]["lat"], - lon=field_state["coords"]["lon"], - ), - ) - for field_state in state["fields"] - ] + StaticInput.from_state(field_state) for field_state in state["fields"] + ], + coords=LatLonCoordinates( + lat=state["coords"]["lat"], + lon=state["coords"]["lon"], + ), ) + @classmethod + def from_state_backwards_compatible( + cls, + state: dict, + static_inputs_config: dict[str, str], + ) -> "StaticInputs | None": + if state and static_inputs_config: + raise ValueError( + "Checkpoint contains static inputs but static_inputs_config is " + "also provided. Backwards compatibility loading only supports " + "a single source of StaticInputs info." + ) + if _has_coords_in_state(state): + return cls.from_state(state) + elif static_inputs_config: + return load_static_inputs(static_inputs_config) + else: + return None + def load_static_inputs( - static_inputs_config: dict[str, str] | None, -) -> StaticInputs | None: + static_inputs_config: dict[str, str], +) -> StaticInputs: """ Load normalized static inputs from a mapping of field names to file paths. - Returns None if the input config is empty. + + Coordinates are read from each field's source dataset and validated to be + consistent between fields. Raises ValueError if any field's dataset lacks + lat/lon coordinates or if coordinates differ between fields. """ - # TODO: consolidate/simplify empty StaticInputs vs. 
None handling in - # downscaling code - if not static_inputs_config: - return None - return StaticInputs( - fields=[ - _get_normalized_static_input(path, field_name) - for field_name, path in static_inputs_config.items() - ] - ) + coords_to_use = None + fields = [] + for field_name, path in static_inputs_config.items(): + si, coords = _get_normalized_static_input(path, field_name) + fields.append(si) + if coords_to_use is None: + coords_to_use = coords + else: + _validate_coords(field_name, coords, coords_to_use) + + if coords_to_use is None: + raise ValueError("load_static_inputs requires at least one field.") + + return StaticInputs(fields=fields, coords=coords_to_use) diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index 827cdfb49..955da241d 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -1,64 +1,166 @@ +import numpy as np import pytest import torch +import xarray as xr -from fme.core.coordinates import LatLonCoordinates +from .static import StaticInput, StaticInputs, load_fine_coords_from_path -from .static import StaticInput, StaticInputs -from .utils import ClosedInterval +def _make_coords(n=4): + from fme.core.coordinates import LatLonCoordinates -@pytest.mark.parametrize( - "init_args", - [ - pytest.param( - [ - torch.randn((1, 2, 2)), - LatLonCoordinates(torch.arange(2), torch.arange(2)), - ], - id="3d_data", - ), - pytest.param( - [torch.randn((2, 2)), LatLonCoordinates(torch.arange(2), torch.arange(5))], - id="dim_size_mismatch", - ), - ], -) -def test_Topography_error_cases(init_args): + return LatLonCoordinates( + lat=torch.arange(n, dtype=torch.float32), + lon=torch.arange(n, dtype=torch.float32), + ) + + +def _write_coords_netcdf(path, n=4): + xr.Dataset( + coords={ + "lat": np.arange(n, dtype=np.float32), + "lon": np.arange(n, dtype=np.float32), + } + ).to_netcdf(path) + + +def test_StaticInput_error_cases(): + data = torch.randn(1, 2, 2) with 
pytest.raises(ValueError): - StaticInput(*init_args) + StaticInput(data=data) def test_subset(): full_data_shape = (10, 10) - expected_slices = [slice(2, 6), slice(3, 8)] data = torch.randn(*full_data_shape) - coords = LatLonCoordinates( - lat=torch.linspace(0, 9, 10), lon=torch.linspace(0, 9, 10) - ) - topo = StaticInput(data=data, coords=coords) - lat_interval = ClosedInterval(2, 5) - lon_interval = ClosedInterval(3, 7) - subset_topo = topo.subset(lat_interval, lon_interval) - expected_lats = torch.tensor([2, 3, 4, 5], dtype=coords.lat.dtype) - expected_lons = torch.tensor([3, 4, 5, 6, 7], dtype=coords.lon.dtype) - expected_data = data[*expected_slices] - assert torch.equal(subset_topo.coords.lat, expected_lats) - assert torch.equal(subset_topo.coords.lon, expected_lons) - assert torch.allclose(subset_topo.data, expected_data) + topo = StaticInput(data=data) + lat_slice = slice(2, 6) + lon_slice = slice(3, 8) + subset_topo = topo.subset(lat_slice, lon_slice) + assert torch.allclose(subset_topo.data, data[lat_slice, lon_slice]) def test_StaticInputs_serialize(): - data = torch.arange(16).reshape(4, 4) - topography = StaticInput( - data, - LatLonCoordinates(torch.arange(4), torch.arange(4)), + dim_len = 4 + data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( + dim_len, dim_len ) - land_frac = StaticInput( - data * -1.0, - LatLonCoordinates(torch.arange(4), torch.arange(4)), + coords = _make_coords(n=dim_len) + static_inputs = StaticInputs( + [StaticInput(data), StaticInput(data * -1.0)], coords=coords ) - static_inputs = StaticInputs([topography, land_frac]) state = static_inputs.get_state() - static_inputs_reconstructed = StaticInputs.from_state(state) - assert static_inputs_reconstructed[0].data.equal(static_inputs[0].data) - assert static_inputs_reconstructed[1].data.equal(static_inputs[1].data) + assert "coords" in state + reconstructed = StaticInputs.from_state(state) + assert reconstructed[0].data.equal(static_inputs[0].data) + assert 
reconstructed[1].data.equal(static_inputs[1].data) + assert torch.equal(reconstructed.coords.lat, static_inputs.coords.lat) + assert torch.equal(reconstructed.coords.lon, static_inputs.coords.lon) + + +def test_StaticInputs_from_state_raises_on_missing_coords(): + data = torch.arange(16, dtype=torch.float32).reshape(4, 4) + with pytest.raises(ValueError, match="No coordinates"): + StaticInputs.from_state({"fields": [{"data": data}]}) + + +def test_StaticInputs_from_state_legacy_coords_in_fields(): + """from_state handles old format where coords were stored only inside each field.""" + dim_len = 4 + data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( + dim_len, dim_len + ) + coords = _make_coords(n=dim_len) + old_state = { + "fields": [{"data": data, "coords": {"lat": coords.lat, "lon": coords.lon}}], + } + result = StaticInputs.from_state(old_state) + assert torch.equal(result[0].data, data) + assert torch.equal(result.coords.lat, coords.lat) + assert torch.equal(result.coords.lon, coords.lon) + + +def test_from_state_backwards_compatible_has_coords(): + """When state has coords, delegates to from_state.""" + data = torch.arange(16, dtype=torch.float32).reshape(4, 4) + coords = _make_coords() + state = StaticInputs([StaticInput(data)], coords=coords).get_state() + result = StaticInputs.from_state_backwards_compatible( + state=state, static_inputs_config={} + ) + assert result is not None + assert torch.equal(result[0].data, data) + + +def test_from_state_backwards_compatible_no_state_no_config(): + """No static inputs in checkpoint and no config: returns None.""" + result = StaticInputs.from_state_backwards_compatible( + state={}, static_inputs_config={} + ) + assert result is None + + +def test_from_state_backwards_compatible_with_config(tmp_path): + """Checkpoint with static_inputs_config: loads fields and coords from files.""" + dim_len = 4 + lat = np.linspace(0, 1, dim_len, dtype=np.float32) + lon = np.linspace(0, 1, dim_len, dtype=np.float32) 
+ field_data = np.random.rand(dim_len, dim_len).astype(np.float32) + field_path = str(tmp_path / "field.nc") + xr.Dataset( + {"HGTsfc": (["lat", "lon"], field_data)}, + coords={"lat": lat, "lon": lon}, + ).to_netcdf(field_path) + result = StaticInputs.from_state_backwards_compatible( + state={}, + static_inputs_config={"HGTsfc": field_path}, + ) + assert result is not None + assert len(result.fields) == 1 + + +def test_from_state_backwards_compatible_raises_state_and_config(): + """ + Errors if checkpoint state has fields and static_inputs_config is also provided. + """ + dim_len = 4 + data = torch.arange(dim_len * dim_len, dtype=torch.float32).reshape( + dim_len, dim_len + ) + coords = _make_coords(n=dim_len) + state = StaticInputs([StaticInput(data)], coords=coords).get_state() + with pytest.raises(ValueError, match="static_inputs_config"): + StaticInputs.from_state_backwards_compatible( + state=state, + static_inputs_config={"HGTsfc": "some/path"}, + ) + + +@pytest.mark.parametrize( + "lat_name,lon_name", + [ + pytest.param("lat", "lon", id="standard"), + pytest.param("latitude", "longitude", id="long_names"), + pytest.param("grid_yt", "grid_xt", id="fv3_names"), + ], +) +def test_load_fine_coords_from_path(tmp_path, lat_name, lon_name): + lat = [0.0, 1.0, 2.0] + lon = [10.0, 20.0, 30.0, 40.0] + ds = xr.Dataset(coords={lat_name: lat, lon_name: lon}) + path = str(tmp_path / "coords.nc") + ds.to_netcdf(path) + + coords = load_fine_coords_from_path(path) + + assert torch.allclose(coords.lat, torch.tensor(lat, dtype=torch.float32)) + assert torch.allclose(coords.lon, torch.tensor(lon, dtype=torch.float32)) + + +def test_load_fine_coords_from_path_raises_on_missing_coords(tmp_path): + ds = xr.Dataset(coords={"x": [0.0, 1.0], "y": [10.0, 20.0]}) + path = str(tmp_path / "no_latlon.nc") + ds.to_netcdf(path) + + with pytest.raises(ValueError, match="Could not find lat/lon coordinates"): + load_fine_coords_from_path(path) diff --git 
a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index aad4071f6..c96733b31 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -64,8 +64,11 @@ def mock_output_target(): def get_static_inputs(shape=(16, 16)): data = torch.randn(shape) - coords = LatLonCoordinates(lat=torch.arange(shape[0]), lon=torch.arange(shape[1])) - return StaticInputs([StaticInput(data=data, coords=coords)]) + coords = LatLonCoordinates( + lat=torch.arange(shape[0], dtype=torch.float32), + lon=torch.arange(shape[1], dtype=torch.float32), + ) + return StaticInputs([StaticInput(data=data)], coords=coords) # Tests for Downscaler initialization diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 864f511f3..703af10d5 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -127,15 +127,11 @@ def make_static_inputs(fine_shape: tuple[int, int]) -> StaticInputs: """Create StaticInputs with proper monotonic coordinates for given shape.""" lat_size, lon_size = fine_shape return StaticInputs( - fields=[ - StaticInput( - torch.ones(*fine_shape, device=get_device()), - LatLonCoordinates( - lat=_get_monotonic_coordinate(lat_size, stop=lat_size), - lon=_get_monotonic_coordinate(lon_size, stop=lon_size), - ), - ) - ] + fields=[StaticInput(torch.ones(*fine_shape, device=get_device()))], + coords=LatLonCoordinates( + lat=_get_monotonic_coordinate(lat_size, stop=lat_size), + lon=_get_monotonic_coordinate(lon_size, stop=lon_size), + ), ) diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index a1dd08fde..1a2c9e46f 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -421,7 +421,9 @@ def checkpoint_dir(self) -> str: return os.path.join(self.experiment_dir, "checkpoints") def build(self) -> Trainer: - static_inputs = load_static_inputs(self.static_inputs) + static_inputs = ( + load_static_inputs(self.static_inputs) if 
self.static_inputs else None + ) train_data: PairedGriddedData = self.train_data.build( train=True, From 665dfb998f806d0151b2d5356188c1aebefe9ef1 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 23 Mar 2026 14:43:14 -0700 Subject: [PATCH 2/4] Add slice definition back into subset operation --- fme/downscaling/data/static.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index e5f870c97..9c3660933 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -160,6 +160,8 @@ def subset( lat_interval: ClosedInterval, lon_interval: ClosedInterval, ) -> "StaticInputs": + lat_slice = lat_interval.slice_from(self.coords.lat) + lon_slice = lon_interval.slice_from(self.coords.lon) return StaticInputs( fields=[field.subset(lat_slice, lon_slice) for field in self.fields], coords=LatLonCoordinates( From 687b3d98c1af0867e30c25fbcebf480836dbb077 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 23 Mar 2026 15:51:51 -0700 Subject: [PATCH 3/4] Remove load_fine_coords_from_path helper --- fme/downscaling/data/__init__.py | 7 +---- fme/downscaling/data/static.py | 9 ------- fme/downscaling/data/test_static.py | 40 ++++++----------------------- 3 files changed, 9 insertions(+), 47 deletions(-) diff --git a/fme/downscaling/data/__init__.py b/fme/downscaling/data/__init__.py index 334e1a949..5790cdc11 100644 --- a/fme/downscaling/data/__init__.py +++ b/fme/downscaling/data/__init__.py @@ -12,12 +12,7 @@ PairedBatchItem, PairedGriddedData, ) -from .static import ( - StaticInput, - StaticInputs, - load_fine_coords_from_path, - load_static_inputs, -) +from .static import StaticInput, StaticInputs, load_static_inputs from .utils import ( BatchedLatLonCoordinates, ClosedInterval, diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index 9c3660933..bc5ef20cb 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -62,15 +62,6 @@ def 
_load_coords_from_ds(ds: xr.Dataset) -> LatLonCoordinates: ) -def load_fine_coords_from_path(path: str) -> LatLonCoordinates: - """Load lat/lon coordinates from a netCDF or zarr file.""" - if path.endswith(".zarr"): - ds = xr.open_zarr(path) - else: - ds = xr.open_dataset(path) - return _load_coords_from_ds(ds) - - def _get_normalized_static_input( path: str, field_name: str ) -> tuple["StaticInput", LatLonCoordinates]: diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index 955da241d..4601985de 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -3,7 +3,7 @@ import torch import xarray as xr -from .static import StaticInput, StaticInputs, load_fine_coords_from_path +from .static import StaticInput, StaticInputs, _load_coords_from_ds def _make_coords(n=4): @@ -15,15 +15,6 @@ def _make_coords(n=4): ) -def _write_coords_netcdf(path, n=4): - xr.Dataset( - coords={ - "lat": np.arange(n, dtype=np.float32), - "lon": np.arange(n, dtype=np.float32), - } - ).to_netcdf(path) - - def test_StaticInput_error_cases(): data = torch.randn(1, 2, 2) with pytest.raises(ValueError): @@ -136,31 +127,16 @@ def test_from_state_backwards_compatible_raises_state_and_config(): ) -@pytest.mark.parametrize( - "lat_name,lon_name", - [ - pytest.param("lat", "lon", id="standard"), - pytest.param("latitude", "longitude", id="long_names"), - pytest.param("grid_yt", "grid_xt", id="fv3_names"), - ], -) -def test_load_fine_coords_from_path(tmp_path, lat_name, lon_name): +def test_load_fine_coords_from_path(): lat = [0.0, 1.0, 2.0] lon = [10.0, 20.0, 30.0, 40.0] - ds = xr.Dataset(coords={lat_name: lat, lon_name: lon}) - path = str(tmp_path / "coords.nc") - ds.to_netcdf(path) - - coords = load_fine_coords_from_path(path) + ds = xr.Dataset(coords={"lat": lat, "lon": lon}) + coords = _load_coords_from_ds(ds) assert torch.allclose(coords.lat, torch.tensor(lat, dtype=torch.float32)) assert torch.allclose(coords.lon, 
torch.tensor(lon, dtype=torch.float32)) - -def test_load_fine_coords_from_path_raises_on_missing_coords(tmp_path): - ds = xr.Dataset(coords={"x": [0.0, 1.0], "y": [10.0, 20.0]}) - path = str(tmp_path / "no_latlon.nc") - ds.to_netcdf(path) - - with pytest.raises(ValueError, match="Could not find lat/lon coordinates"): - load_fine_coords_from_path(path) + # expected coord names missing + ds = xr.Dataset(coords={"x": lon, "y": lat}) + with pytest.raises(ValueError): + _load_coords_from_ds(ds) From 2a8205dda9d3762bab2a69da426de6d1a87a9309 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 23 Mar 2026 15:53:28 -0700 Subject: [PATCH 4/4] Rename test for _load_coords_from_ds --- fme/downscaling/data/test_static.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index 4601985de..aba525c48 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -127,7 +127,7 @@ def test_from_state_backwards_compatible_raises_state_and_config(): ) -def test_load_fine_coords_from_path(): +def test__load_coords_from_ds(): lat = [0.0, 1.0, 2.0] lon = [10.0, 20.0, 30.0, 40.0] ds = xr.Dataset(coords={"lat": lat, "lon": lon})