diff --git a/fme/downscaling/data/config.py b/fme/downscaling/data/config.py index aec2ff7c0..14396a8e7 100644 --- a/fme/downscaling/data/config.py +++ b/fme/downscaling/data/config.py @@ -23,8 +23,11 @@ PairedBatchData, PairedGriddedData, ) -from fme.downscaling.data.static import StaticInputs -from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range +from fme.downscaling.data.utils import ( + ClosedInterval, + adjust_fine_coord_range, + get_latlon_coords_from_properties, +) from fme.downscaling.requirements import DataRequirements @@ -132,18 +135,6 @@ def _full_configs( return all_configs -def _check_fine_res_static_input_compatibility( - static_input_shape: tuple[int, int], data_coords_shape: tuple[int, int] -) -> None: - for static, coord in zip(static_input_shape, data_coords_shape): - if static != coord: - raise ValueError( - f"Static input shape {static_input_shape} is not compatible with " - f"data coordinates shape {data_coords_shape}. Static input dimensions " - "must match fine resolution coordinate dimensions." - ) - - @dataclasses.dataclass class DataLoaderConfig: """ @@ -164,8 +155,9 @@ class DataLoaderConfig: (For multi-GPU runtime, it's the number of workers per GPU.) strict_ensemble: Whether to enforce that the datasets to be concatened have the same dimensions and coordinates. - topography: Deprecated field for specifying the topography dataset. Now - provided via build method's `static_inputs` argument. + topography: Deprecated field for specifying the topography dataset. + StaticInput data are expected to be stored and serialized within a + model through the Trainer build process. lat_extent: The latitude extent to use for the dataset specified in degrees, limited to (-88.0, 88.0). The extent is inclusive, so the start and stop values are included in the extent. 
Defaults to [-66, 70] which @@ -202,8 +194,8 @@ def __post_init__(self): if self.topography is not None: raise ValueError( "The `topography` field on DataLoaderConfig is deprecated and will be " - "removed in a future release. Pass static_inputs via build's " - "`static_inputs` argument instead." + "removed in a future release. `StaticInputs` are now stored within " + " the model when it is first built and trained." ) @property @@ -236,40 +228,6 @@ def get_xarray_dataset( strict_ensemble=self.strict_ensemble, ) - def build_static_inputs( - self, - coarse_coords: LatLonCoordinates, - requires_topography: bool, - static_inputs: StaticInputs | None = None, - ) -> StaticInputs | None: - if requires_topography is False: - return None - if static_inputs is not None: - # TODO: change to use full static inputs list - full_static_inputs = static_inputs - else: - raise ValueError( - "Static inputs required for this model, but no static inputs " - "datasets were specified in the trainer configuration or provided " - "in model checkpoint." - ) - - # Fine grid boundaries are adjusted to exactly match the coarse grid - fine_lat_interval = adjust_fine_coord_range( - self.lat_extent, - full_coarse_coord=coarse_coords.lat, - full_fine_coord=full_static_inputs.coords.lat, - ) - fine_lon_interval = adjust_fine_coord_range( - self.lon_extent, - full_coarse_coord=coarse_coords.lon, - full_fine_coord=full_static_inputs.coords.lon, - ) - subset_static_inputs = full_static_inputs.subset_latlon( - lat_interval=fine_lat_interval, lon_interval=fine_lon_interval - ) - return subset_static_inputs.to_device() - def build_batchitem_dataset( self, dataset: XarrayConcat, @@ -301,14 +259,7 @@ def build( self, requirements: DataRequirements, dist: Distributed | None = None, - static_inputs: StaticInputs | None = None, ) -> GriddedData: - # TODO: static_inputs_from_checkpoint is currently passed from the model - # to allow loading fine topography when no fine data is available. 
- # See PR https://github.com/ai2cm/ace/pull/728 - # In the future we could disentangle this dependency between the data loader - # and model by enabling the built GriddedData objects to take in full static - # input fields and subset them to the same coordinate range as data. xr_dataset, properties = self.get_xarray_dataset( names=requirements.coarse_names, n_timesteps=1 ) @@ -316,7 +267,6 @@ def build( raise ValueError( "Downscaling data loader only supports datasets with latlon coords." ) - latlon_coords = properties.horizontal_coordinates dataset = self.build_batchitem_dataset( dataset=xr_dataset, properties=properties, @@ -343,14 +293,8 @@ def build( persistent_workers=True if self.num_data_workers > 0 else False, ) example = dataset[0] - subset_static_inputs = self.build_static_inputs( - coarse_coords=latlon_coords, - requires_topography=requirements.use_fine_topography, - static_inputs=static_inputs, - ) return GriddedData( _loader=dataloader, - static_inputs=subset_static_inputs, shape=example.horizontal_shape, dims=example.latlon_coordinates.dims, variable_metadata=dataset.variable_metadata, @@ -395,7 +339,6 @@ class PairedDataLoaderConfig: time dimension. Useful to include longer sequences of small data for testing. topography: Deprecated field for specifying the topography dataset. - Now provided via build method's `static_inputs` argument. sample_with_replacement: If provided, the dataset will be sampled randomly with replacement to the given size each period, instead of retrieving each sample once (either shuffled or not). @@ -427,8 +370,8 @@ def __post_init__(self): if self.topography is not None: raise ValueError( "The `topography` field on PairedDataLoaderConfig is deprecated and " - "will be removed in a future release. Pass static_inputs via the " - "build method's `static_inputs` argument instead." + "will be removed in a future release. `StaticInputs` are now stored " + "within the model when it is first built and trained." 
) def _first_data_config( @@ -468,14 +411,7 @@ def build( train: bool, requirements: DataRequirements, dist: Distributed | None = None, - static_inputs: StaticInputs | None = None, ) -> PairedGriddedData: - # TODO: static_inputs_from_checkpoint is currently passed from the model - # to allow loading fine topography when no fine data is available. - # See PR https://github.com/ai2cm/ace/pull/728 - # In the future we could disentangle this dependency between the data loader - # and model by enabling the built GriddedData objects to take in full static - # input fields and subset them to the same coordinate range as data. if dist is None: dist = Distributed.get_instance() @@ -537,25 +473,6 @@ def build( full_fine_coord=properties_fine.horizontal_coordinates.lon, ) - if requirements.use_fine_topography: - if static_inputs is None: - raise ValueError( - "Model requires static inputs (use_fine_topography=True)," - " but no static inputs were provided to the data loader's" - " build method." - ) - - static_inputs = static_inputs.to_device() - _check_fine_res_static_input_compatibility( - static_inputs.shape, - properties_fine.horizontal_coordinates.shape, - ) - static_inputs = static_inputs.subset_latlon( - lat_interval=fine_lat_extent, lon_interval=fine_lon_extent - ) - else: - static_inputs = None - dataset_fine_subset = HorizontalSubsetDataset( dataset_fine, properties=properties_fine, @@ -611,12 +528,12 @@ def build( return PairedGriddedData( _loader=dataloader, - static_inputs=static_inputs, coarse_shape=example.coarse.horizontal_shape, downscale_factor=example.downscale_factor, dims=example.fine.latlon_coordinates.dims, variable_metadata=variable_metadata, all_times=all_times, + fine_coords=get_latlon_coords_from_properties(properties_fine), ) def _get_sampler( diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index 38e7f9191..b11d0e49d 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -20,14 
+20,12 @@ from fme.core.generics.data import SizedMap from fme.core.typing_ import TensorMapping from fme.downscaling.data.patching import Patch, get_patches -from fme.downscaling.data.static import StaticInputs from fme.downscaling.data.utils import ( BatchedLatLonCoordinates, ClosedInterval, check_leading_dim, expand_and_fold_tensor, get_offset, - null_generator, paired_shuffle, scale_tuple, ) @@ -298,7 +296,6 @@ class GriddedData: dims: list[str] variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex - static_inputs: StaticInputs | None @property def loader(self) -> DataLoader[BatchItem]: @@ -307,27 +304,8 @@ def on_device(batch: BatchItem) -> BatchItem: return SizedMap(on_device, self._loader) - @property - def topography_downscale_factor(self) -> int | None: - if self.static_inputs: - if ( - self.static_inputs.shape[0] % self.shape[0] != 0 - or self.static_inputs.shape[1] % self.shape[1] != 0 - ): - raise ValueError( - "Static inputs shape must be evenly divisible by data shape. 
" - f"Got static inputs with shape {self.static_inputs.shape} " - f"and data with shape {self.shape}" - ) - return self.static_inputs.shape[0] // self.shape[0] - else: - return None - - def get_generator( - self, - ) -> Iterator[tuple["BatchData", StaticInputs | None]]: - for batch in self.loader: - yield (batch, self.static_inputs) + def get_generator(self) -> Iterator["BatchData"]: + yield from self.loader def get_patched_generator( self, @@ -335,20 +313,18 @@ def get_patched_generator( overlap: int = 0, drop_partial_patches: bool = True, random_offset: bool = False, - ) -> Iterator[tuple["BatchData", StaticInputs | None]]: + ) -> Iterator["BatchData"]: patched_generator = patched_batch_gen_from_loader( loader=self.loader, - static_inputs=self.static_inputs, coarse_yx_extent=self.shape, coarse_yx_patch_extent=yx_patch_extent, - downscale_factor=self.topography_downscale_factor, coarse_overlap=overlap, drop_partial_patches=drop_partial_patches, random_offset=random_offset, ) return cast( - Iterator[tuple[BatchData, StaticInputs | None]], + Iterator[BatchData], patched_generator, ) @@ -361,7 +337,7 @@ class PairedGriddedData: dims: list[str] variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex - static_inputs: StaticInputs | None + fine_coords: LatLonCoordinates @property def loader(self) -> DataLoader[PairedBatchItem]: @@ -370,11 +346,8 @@ def on_device(batch: PairedBatchItem) -> PairedBatchItem: return SizedMap(on_device, self._loader) - def get_generator( - self, - ) -> Iterator[tuple["PairedBatchData", StaticInputs | None]]: - for batch in self.loader: - yield (batch, self.static_inputs) + def get_generator(self) -> Iterator["PairedBatchData"]: + yield from self.loader def get_patched_generator( self, @@ -383,10 +356,9 @@ def get_patched_generator( drop_partial_patches: bool = True, random_offset: bool = False, shuffle: bool = False, - ) -> Iterator[tuple["PairedBatchData", StaticInputs | None]]: + ) -> Iterator["PairedBatchData"]: 
patched_generator = patched_batch_gen_from_paired_loader( self.loader, - self.static_inputs, coarse_yx_extent=self.coarse_shape, coarse_yx_patch_extent=coarse_yx_patch_extent, downscale_factor=self.downscale_factor, @@ -396,7 +368,7 @@ def get_patched_generator( shuffle=shuffle, ) return cast( - Iterator[tuple[PairedBatchData, StaticInputs | None]], + Iterator[PairedBatchData], patched_generator, ) @@ -669,6 +641,9 @@ def __iter__(self): return iter(indices[start:end]) +# downscale_factor=None means fine patches not needed here, but reusing +# _get_paired_patches in both paired and no-target cases to share the +# coincident offset logic. def _get_paired_patches( coarse_yx_extent: tuple[int, int], coarse_yx_patch_extent: tuple[int, int], @@ -713,44 +688,28 @@ def _get_paired_patches( def patched_batch_gen_from_loader( loader: DataLoader[BatchItem], - static_inputs: StaticInputs | None, coarse_yx_extent: tuple[int, int], coarse_yx_patch_extent: tuple[int, int], - downscale_factor: int | None, coarse_overlap: int = 0, drop_partial_patches: bool = True, random_offset: bool = False, shuffle: bool = False, -) -> Iterator[tuple[BatchData, StaticInputs | None]]: +) -> Iterator[BatchData]: for batch in loader: - coarse_patches, fine_patches = _get_paired_patches( + coarse_patches, _ = _get_paired_patches( coarse_yx_extent=coarse_yx_extent, coarse_yx_patch_extent=coarse_yx_patch_extent, coarse_overlap=coarse_overlap, - downscale_factor=downscale_factor, + downscale_factor=None, random_offset=random_offset, shuffle=shuffle, drop_partial_patches=drop_partial_patches, ) - batch_data_patches = batch.generate_from_patches(coarse_patches) - - if static_inputs is not None: - if fine_patches is None: - raise ValueError( - "Topography provided but downscale_factor is None, cannot " - "generate fine patches." 
- ) - static_inputs_patches = static_inputs.generate_from_patches(fine_patches) - else: - static_inputs_patches = null_generator(len(coarse_patches)) - - # Combine outputs from both generators - yield from zip(batch_data_patches, static_inputs_patches) + yield from batch.generate_from_patches(coarse_patches) def patched_batch_gen_from_paired_loader( loader: DataLoader[PairedBatchItem], - static_inputs: StaticInputs | None, coarse_yx_extent: tuple[int, int], coarse_yx_patch_extent: tuple[int, int], downscale_factor: int, @@ -758,7 +717,7 @@ def patched_batch_gen_from_paired_loader( drop_partial_patches: bool = True, random_offset: bool = False, shuffle: bool = False, -) -> Iterator[tuple[PairedBatchData, StaticInputs | None]]: +) -> Iterator[PairedBatchData]: for batch in loader: coarse_patches, fine_patches = _get_paired_patches( coarse_yx_extent=coarse_yx_extent, @@ -769,17 +728,4 @@ def patched_batch_gen_from_paired_loader( shuffle=shuffle, drop_partial_patches=drop_partial_patches, ) - batch_data_patches = batch.generate_from_patches(coarse_patches, fine_patches) - - if static_inputs is not None: - if fine_patches is None: - raise ValueError( - "Static inputs provided but downscale_factor is None, cannot " - "generate fine patches." 
- ) - static_inputs_patches = static_inputs.generate_from_patches(fine_patches) - else: - static_inputs_patches = null_generator(len(coarse_patches)) - - # Combine outputs from both generators - yield from zip(batch_data_patches, static_inputs_patches) + yield from batch.generate_from_patches(coarse_patches, fine_patches) diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index 2c7199cf9..48ab54ed7 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -1,30 +1,19 @@ import dataclasses -from collections.abc import Generator, Iterator import torch import xarray as xr -from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device from fme.downscaling.data.patching import Patch -from fme.downscaling.data.utils import ClosedInterval @dataclasses.dataclass class StaticInput: data: torch.Tensor - coords: LatLonCoordinates def __post_init__(self): if len(self.data.shape) != 2: raise ValueError(f"Topography data must be 2D. 
Got shape {self.data.shape}") - if self.data.shape[0] != len(self.coords.lat) or self.data.shape[1] != len( - self.coords.lon - ): - raise ValueError( - f"Static inputs data shape {self.data.shape} does not match lat/lon " - f"coordinates shape {(len(self.coords.lat), len(self.coords.lon))}" - ) @property def dim(self) -> int: @@ -34,56 +23,23 @@ def dim(self) -> int: def shape(self) -> tuple[int, int]: return self.data.shape - def subset_latlon( + def subset( self, - lat_interval: ClosedInterval, - lon_interval: ClosedInterval, + lat_slice: slice, + lon_slice: slice, ) -> "StaticInput": - lat_slice = lat_interval.slice_of(self.coords.lat) - lon_slice = lon_interval.slice_of(self.coords.lon) - return self._latlon_index_slice(lat_slice=lat_slice, lon_slice=lon_slice) + return StaticInput(data=self.data[lat_slice, lon_slice]) def to_device(self) -> "StaticInput": device = get_device() - return StaticInput( - data=self.data.to(device), - coords=LatLonCoordinates( - lat=self.coords.lat.to(device), - lon=self.coords.lon.to(device), - ), - ) + return StaticInput(data=self.data.to(device)) def _apply_patch(self, patch: Patch): - return self._latlon_index_slice( - lat_slice=patch.input_slice.y, lon_slice=patch.input_slice.x - ) - - def _latlon_index_slice( - self, - lat_slice: slice, - lon_slice: slice, - ) -> "StaticInput": - sliced_data = self.data[lat_slice, lon_slice] - sliced_latlon = LatLonCoordinates( - lat=self.coords.lat[lat_slice], - lon=self.coords.lon[lon_slice], - ) - return StaticInput( - data=sliced_data, - coords=sliced_latlon, - ) - - def generate_from_patches( - self, - patches: list[Patch], - ) -> Generator["StaticInput", None, None]: - for patch in patches: - yield self._apply_patch(patch) + return self.subset(lat_slice=patch.input_slice.y, lon_slice=patch.input_slice.x) def get_state(self) -> dict: return { "data": self.data.cpu(), - "coords": self.coords.get_state(), } @@ -107,17 +63,11 @@ def _get_normalized_static_input(path: str, field_name: str): 
f"unexpected shape {static_input.shape} for static input." "Currently, only lat/lon static input is supported." ) - lat_name, lon_name = static_input.dims[-2:] - coords = LatLonCoordinates( - lon=torch.tensor(static_input[lon_name].values), - lat=torch.tensor(static_input[lat_name].values), - ) static_input_normalized = (static_input - static_input.mean()) / static_input.std() return StaticInput( data=torch.tensor(static_input_normalized.values, dtype=torch.float32), - coords=coords, ) @@ -127,50 +77,33 @@ class StaticInputs: def __post_init__(self): for i, field in enumerate(self.fields[1:]): - if field.coords != self.fields[0].coords: + if field.shape != self.fields[0].shape: raise ValueError( - f"All StaticInput fields must have the same coordinates. " - f"Fields {i} and 0 do not match coordinates." + f"All StaticInput fields must have the same shape. " + f"Fields {i + 1} and 0 do not match shapes." ) def __getitem__(self, index: int): return self.fields[index] - @property - def coords(self) -> LatLonCoordinates: - if len(self.fields) == 0: - raise ValueError("No fields in StaticInputs to get coordinates from.") - return self.fields[0].coords - @property def shape(self) -> tuple[int, int]: if len(self.fields) == 0: raise ValueError("No fields in StaticInputs to get shape from.") return self.fields[0].shape - def subset_latlon( + def subset( self, - lat_interval: ClosedInterval, - lon_interval: ClosedInterval, + lat_slice: slice, + lon_slice: slice, ) -> "StaticInputs": return StaticInputs( - fields=[ - field.subset_latlon(lat_interval, lon_interval) for field in self.fields - ] + fields=[field.subset(lat_slice, lon_slice) for field in self.fields] ) def to_device(self) -> "StaticInputs": return StaticInputs(fields=[field.to_device() for field in self.fields]) - def generate_from_patches( - self, - patches: list[Patch], - ) -> Iterator["StaticInputs"]: - for patch in patches: - yield StaticInputs( - fields=[field._apply_patch(patch) for field in self.fields] - ) 
- def get_state(self) -> dict: return { "fields": [field.get_state() for field in self.fields], @@ -182,10 +115,6 @@ def from_state(cls, state: dict) -> "StaticInputs": fields=[ StaticInput( data=field_state["data"], - coords=LatLonCoordinates( - lat=field_state["coords"]["lat"], - lon=field_state["coords"]["lon"], - ), ) for field_state in state["fields"] ] diff --git a/fme/downscaling/data/test_config.py b/fme/downscaling/data/test_config.py index 1aaff8577..f916abadd 100644 --- a/fme/downscaling/data/test_config.py +++ b/fme/downscaling/data/test_config.py @@ -1,7 +1,6 @@ import dataclasses import pytest -import torch from fme.core.dataset.merged import MergeNoConcatDatasetConfig from fme.core.dataset.xarray import XarrayDataConfig @@ -10,25 +9,11 @@ PairedDataLoaderConfig, XarrayEnsembleDataConfig, ) -from fme.downscaling.data.static import StaticInput, StaticInputs -from fme.downscaling.data.utils import ClosedInterval, LatLonCoordinates +from fme.downscaling.data.utils import ClosedInterval from fme.downscaling.requirements import DataRequirements from fme.downscaling.test_utils import data_paths_helper -def get_static_inputs(shape=(8, 8)): - return StaticInputs( - fields=[ - StaticInput( - data=torch.ones(shape), - coords=LatLonCoordinates( - lat=torch.ones(shape[0]), lon=torch.ones(shape[1]) - ), - ) - ] - ) - - @pytest.mark.parametrize( "fine_engine, coarse_engine, num_data_workers, expected", [ @@ -78,9 +63,7 @@ def test_DataLoaderConfig_build(tmp_path, very_fast_only: bool): lat_extent=ClosedInterval(1, 4), lon_extent=ClosedInterval(0, 3), ) - data = data_config.build( - requirements=requirements, static_inputs=get_static_inputs(shape=(8, 8)) - ) + data = data_config.build(requirements=requirements) batch = next(iter(data.loader)) # lat/lon midpoints are on (0.5, 1.5, ...) 
assert batch.data["var0"].shape == (2, 3, 3) @@ -152,9 +135,7 @@ def test_DataLoaderConfig_includes_merge(tmp_path, very_fast_only: bool): lon_extent=ClosedInterval(0, 3), ) - data = data_config.build( - requirements=requirements, static_inputs=get_static_inputs(shape=(8, 8)) - ) + data = data_config.build(requirements=requirements) # XarrayDataConfig + MergeNoConcatDatasetConfig each # contribute 4 timesteps = 8 total assert len(data.loader) == 4 # 8 samples / batch_size 2 diff --git a/fme/downscaling/data/test_patching.py b/fme/downscaling/data/test_patching.py index 7630722c4..fb6cbc7de 100644 --- a/fme/downscaling/data/test_patching.py +++ b/fme/downscaling/data/test_patching.py @@ -2,8 +2,7 @@ import pytest import torch -from fme.core.device import get_device -from fme.downscaling.data import PairedBatchData, StaticInput, StaticInputs +from fme.downscaling.data import PairedBatchData from fme.downscaling.data.datasets import patched_batch_gen_from_paired_loader from fme.downscaling.data.patching import ( _divide_into_slices, @@ -115,19 +114,10 @@ def test_paired_patches_with_random_offset_consistent(overlap): full_coarse_coords = full_data.coarse.latlon_coordinates full_fine_coords = full_data.fine.latlon_coordinates - topography_data = torch.randn( - coarse_shape[0] * downscale_factor, - coarse_shape[1] * downscale_factor, - device=get_device(), - ) - topography = StaticInputs( - fields=[StaticInput(data=topography_data, coords=full_fine_coords[0])] - ) y_offsets = [] x_offsets = [] batch_generator = patched_batch_gen_from_paired_loader( loader=loader, - static_inputs=topography, coarse_yx_extent=coarse_shape, coarse_yx_patch_extent=(10, 10), downscale_factor=downscale_factor, @@ -136,7 +126,7 @@ def test_paired_patches_with_random_offset_consistent(overlap): random_offset=True, ) paired_batch: PairedBatchData - for paired_batch, _ in batch_generator: # type: ignore + for paired_batch in batch_generator: assert paired_batch.coarse.data["x"].shape == 
(batch_size, 10, 10) assert paired_batch.fine.data["x"].shape == (batch_size, 20, 20) @@ -177,19 +167,9 @@ def test_paired_patches_shuffle(shuffle): loader = _mock_data_loader( 10, *coarse_shape, downscale_factor=downscale_factor, batch_size=batch_size ) - topography_data = torch.randn( - coarse_shape[0] * downscale_factor, - coarse_shape[1] * downscale_factor, - device=get_device(), - ) - fine_coords = next(iter(loader)).fine.latlon_coordinates[0] - static_inputs = StaticInputs( - fields=[StaticInput(data=topography_data, coords=fine_coords)] - ) generator0 = patched_batch_gen_from_paired_loader( loader=loader, - static_inputs=static_inputs, coarse_yx_extent=coarse_shape, coarse_yx_patch_extent=(2, 2), downscale_factor=downscale_factor, @@ -200,7 +180,6 @@ def test_paired_patches_shuffle(shuffle): ) generator1 = patched_batch_gen_from_paired_loader( loader=loader, - static_inputs=static_inputs, coarse_yx_extent=coarse_shape, coarse_yx_patch_extent=(2, 2), downscale_factor=downscale_factor, @@ -212,28 +191,14 @@ def test_paired_patches_shuffle(shuffle): patches0: list[PairedBatchData] = [] patches1: list[PairedBatchData] = [] - topography0: list[torch.Tensor] = [] - topography1: list[torch.Tensor] = [] for i in range(4): - p0, t0 = next(generator0) - patches0.append(p0) # type: ignore - topography0.append(t0) - p1, t1 = next(generator1) - patches1.append(p1) # type: ignore - topography1.append(t1) + patches0.append(next(generator0)) # type: ignore + patches1.append(next(generator1)) # type: ignore data0 = torch.concat([patch.coarse.data["x"] for patch in patches0], dim=0) data1 = torch.concat([patch.coarse.data["x"] for patch in patches1], dim=0) - topo_concat_0 = torch.concat( - [t0.fields[0].data for t0 in topography0 if t0 is not None], dim=0 - ) - topo_concat_1 = torch.concat( - [t1.fields[0].data for t1 in topography1 if t1 is not None], dim=0 - ) if shuffle: assert not torch.equal(data0, data1) - assert not torch.equal(topo_concat_0, topo_concat_1) else: 
assert torch.equal(data0, data1) - assert torch.equal(topo_concat_0, topo_concat_1) diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index 8aeace714..104ffdc9f 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -1,27 +1,16 @@ import pytest import torch -from fme.core.coordinates import LatLonCoordinates -from fme.downscaling.data.patching import Patch, _HorizontalSlice - from .static import StaticInput, StaticInputs -from .utils import ClosedInterval @pytest.mark.parametrize( "init_args", [ pytest.param( - [ - torch.randn((1, 2, 2)), - LatLonCoordinates(torch.arange(2), torch.arange(2)), - ], + [torch.randn((1, 2, 2))], id="3d_data", ), - pytest.param( - [torch.randn((2, 2)), LatLonCoordinates(torch.arange(2), torch.arange(5))], - id="dim_size_mismatch", - ), ], ) def test_Topography_error_cases(init_args): @@ -29,106 +18,43 @@ def test_Topography_error_cases(init_args): StaticInput(*init_args) -def test_subset_latlon(): +def test_subset(): full_data_shape = (10, 10) - expected_slices = [slice(2, 6), slice(3, 8)] data = torch.randn(*full_data_shape) - coords = LatLonCoordinates( - lat=torch.linspace(0, 9, 10), lon=torch.linspace(0, 9, 10) - ) - topo = StaticInput(data=data, coords=coords) - lat_interval = ClosedInterval(2, 5) - lon_interval = ClosedInterval(3, 7) - subset_topo = topo.subset_latlon(lat_interval, lon_interval) - expected_lats = torch.tensor([2, 3, 4, 5], dtype=coords.lat.dtype) - expected_lons = torch.tensor([3, 4, 5, 6, 7], dtype=coords.lon.dtype) - expected_data = data[*expected_slices] - assert torch.equal(subset_topo.coords.lat, expected_lats) - assert torch.equal(subset_topo.coords.lon, expected_lons) - assert torch.allclose(subset_topo.data, expected_data) - - -def test_Topography_generate_from_patches(): - output_slice = _HorizontalSlice(y=slice(None), x=slice(None)) - patches = [ - Patch( - input_slice=_HorizontalSlice(y=slice(1, 3), x=slice(None, None)), - 
output_slice=output_slice, - ), - Patch( - input_slice=_HorizontalSlice(y=slice(0, 2), x=slice(2, 3)), - output_slice=output_slice, - ), - ] - topography = StaticInput( - torch.arange(16).reshape(4, 4), - LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) - topo_patch_generator = topography.generate_from_patches(patches) - generated_patches = [] - for topo_patch in topo_patch_generator: - generated_patches.append(topo_patch) - assert len(generated_patches) == 2 - assert torch.equal( - generated_patches[0].data, torch.tensor([[4, 5, 6, 7], [8, 9, 10, 11]]) - ) - assert torch.equal(generated_patches[1].data, torch.tensor([[2], [6]])) - - -def test_StaticInputs_generate_from_patches(): - output_slice = _HorizontalSlice(y=slice(None), x=slice(None)) - patches = [ - Patch( - input_slice=_HorizontalSlice(y=slice(1, 3), x=slice(None, None)), - output_slice=output_slice, - ), - Patch( - input_slice=_HorizontalSlice(y=slice(0, 2), x=slice(2, 3)), - output_slice=output_slice, - ), - ] - data = torch.arange(16).reshape(4, 4) - topography = StaticInput( - data, - LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) - land_frac = StaticInput( - data * -1.0, - LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) - static_inputs = StaticInputs([topography, land_frac]) - static_inputs_patch_generator = static_inputs.generate_from_patches(patches) - generated_patches = [] - for static_inputs_patch in static_inputs_patch_generator: - generated_patches.append(static_inputs_patch) - - assert len(generated_patches) == 2 - - expected_topography_patch_0 = torch.tensor([[4, 5, 6, 7], [8, 9, 10, 11]]) - expected_topography_patch_1 = torch.tensor([[2], [6]]) - - # first index is the patch, second is the static input field within - # the StaticInputs container - assert torch.equal(generated_patches[0][0].data, expected_topography_patch_0) - assert torch.equal(generated_patches[1][0].data, expected_topography_patch_1) - - # land_frac field values are -1 * topography - assert 
torch.equal(generated_patches[0][1].data, expected_topography_patch_0 * -1.0) - assert torch.equal(generated_patches[1][1].data, expected_topography_patch_1 * -1.0) + topo = StaticInput(data=data) + lat_slice = slice(2, 6) + lon_slice = slice(3, 8) + subset_topo = topo.subset(lat_slice, lon_slice) + assert torch.allclose(subset_topo.data, data[lat_slice, lon_slice]) def test_StaticInputs_serialize(): data = torch.arange(16).reshape(4, 4) - topography = StaticInput( - data, - LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) - land_frac = StaticInput( - data * -1.0, - LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) + topography = StaticInput(data) + land_frac = StaticInput(data * -1.0) static_inputs = StaticInputs([topography, land_frac]) state = static_inputs.get_state() + # Verify coords are NOT stored in state + assert "coords" not in state["fields"][0] static_inputs_reconstructed = StaticInputs.from_state(state) assert static_inputs_reconstructed[0].data.equal(static_inputs[0].data) assert static_inputs_reconstructed[1].data.equal(static_inputs[1].data) + + +def test_StaticInputs_serialize_backward_compat_with_coords(): + """from_state should silently ignore 'coords' key for old state dicts.""" + data = torch.arange(16, dtype=torch.float32).reshape(4, 4) + # Simulate old state dict format that included coords + old_state = { + "fields": [ + { + "data": data, + "coords": { + "lat": torch.arange(4, dtype=torch.float32), + "lon": torch.arange(4, dtype=torch.float32), + }, + } + ] + } + static_inputs = StaticInputs.from_state(old_state) + assert torch.equal(static_inputs[0].data, data) diff --git a/fme/downscaling/evaluator.py b/fme/downscaling/evaluator.py index c2813ae37..e7a19727b 100644 --- a/fme/downscaling/evaluator.py +++ b/fme/downscaling/evaluator.py @@ -15,7 +15,6 @@ from fme.downscaling.data import ( PairedDataLoaderConfig, PairedGriddedData, - StaticInputs, enforce_lat_bounds, ) from fme.downscaling.models import CheckpointModelConfig, 
DiffusionModel @@ -56,12 +55,10 @@ def run(self): else: batch_generator = self.data.get_generator() - for i, (batch, static_inputs) in enumerate(batch_generator): + for i, batch in enumerate(batch_generator): with torch.no_grad(): logging.info(f"Generating predictions on batch {i + 1}") - outputs = self.model.generate_on_batch( - batch, static_inputs, n_samples=self.n_samples - ) + outputs = self.model.generate_on_batch(batch, n_samples=self.n_samples) logging.info("Recording diagnostics to aggregator") # Add sample dimension to coarse values for generation comparison coarse = {k: v.unsqueeze(1) for k, v in batch.coarse.data.items()} @@ -107,7 +104,7 @@ def __init__( def run(self): logging.info(f"Running {self.event_name} event evaluation") - batch, static_inputs = next(iter(self.data.get_generator())) + batch = next(iter(self.data.get_generator())) sample_agg = PairedSampleAggregator( target=batch[0].fine.data, coarse=batch[0].coarse.data, @@ -125,9 +122,7 @@ def run(self): f"Generating samples {start_idx} to {end_idx} " f"for event {self.event_name}" ) - outputs = self.model.generate_on_batch( - batch, static_inputs, n_samples=end_idx - start_idx - ) + outputs = self.model.generate_on_batch(batch, n_samples=end_idx - start_idx) sample_agg.record_batch(outputs.prediction) to_log = sample_agg.get_wandb() @@ -152,7 +147,6 @@ def get_paired_gridded_data( self, base_data_config: PairedDataLoaderConfig, requirements: DataRequirements, - static_inputs_from_checkpoint: StaticInputs | None = None, ) -> PairedGriddedData: enforce_lat_bounds(self.lat_extent) time_slice = self._time_selection_slice @@ -173,7 +167,6 @@ def get_paired_gridded_data( return event_data_config.build( train=False, requirements=requirements, - static_inputs=static_inputs_from_checkpoint, ) @@ -200,7 +193,6 @@ def _build_default_evaluator(self) -> Evaluator: dataset = self.data.build( train=False, requirements=self.model.data_requirements, - static_inputs=model.static_inputs, ) evaluator_model: 
DiffusionModel | PatchPredictor if self.patch.divide_generation and self.patch.composite_prediction: @@ -237,7 +229,6 @@ def _build_event_evaluator( dataset = event_config.get_paired_gridded_data( base_data_config=self.data, requirements=self.model.data_requirements, - static_inputs_from_checkpoint=model.static_inputs, ) if (dataset.coarse_shape[0] > model.coarse_shape[0]) or ( diff --git a/fme/downscaling/inference/inference.py b/fme/downscaling/inference/inference.py index 7b5f1a361..c78885e90 100644 --- a/fme/downscaling/inference/inference.py +++ b/fme/downscaling/inference/inference.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, field import dacite +import numpy as np import torch import yaml @@ -10,7 +11,7 @@ from fme.core.generics.trainer import count_parameters from fme.core.logging_utils import LoggingConfig -from ..data import DataLoaderConfig, StaticInputs +from ..data import DataLoaderConfig from ..models import CheckpointModelConfig, DiffusionModel from ..predictors import PatchPredictionConfig, PatchPredictor from .output import DownscalingOutput, EventConfig, TimeRangeConfig @@ -50,7 +51,7 @@ def run_all(self): def _get_generation_model( self, - static_inputs: StaticInputs, + input_shape: tuple[int, int], output: DownscalingOutput, ) -> DiffusionModel | PatchPredictor: """ @@ -60,20 +61,22 @@ def _get_generation_model( the user to use patching for larger domains because that provides better generations. 
""" - model_patch_shape = self.model.fine_shape - actual_shape = tuple(static_inputs.shape) + model_patch_shape = self.model.coarse_shape - if model_patch_shape == actual_shape: + if model_patch_shape == input_shape: # short circuit, no patching necessary return self.model elif any( expected > actual - for expected, actual in zip(model_patch_shape, actual_shape) + for expected, actual in zip(model_patch_shape, input_shape) ): # we don't support generating regions smaller than the model patch size raise ValueError( f"Model coarse shape {model_patch_shape} is larger than " - f"actual topography shape {actual_shape} for output {output.name}." + f"actual input shape {input_shape} for output {output.name}." + "We do not support generating outputs with a smaller spatial extent" + " than the model's trained patch size. Please adjust the spatial extent" + " to be at least as large as the model's input patch size." ) elif output.patch.needs_patch_predictor: # Use a patch predictor @@ -86,14 +89,14 @@ def _get_generation_model( # User should enable patching raise ValueError( f"Model coarse shape {model_patch_shape} does not match " - f"actual input shape {actual_shape} for output {output.name}, " + f"actual input shape {input_shape} for output {output.name}, " "and patch prediction is not configured. Generation for larger domains " "requires patch prediction." 
) def _on_device_generator(self, loader): - for loaded_item, topography in loader: - yield loaded_item.to_device(), topography.to_device() + for loaded_item in loader: + yield loaded_item.to_device() def run_output_generation(self, output: DownscalingOutput): """Execute the generation loop for this output.""" @@ -105,20 +108,20 @@ def run_output_generation(self, output: DownscalingOutput): total_batches = len(output.data.loader) loaded_item: LoadedSliceWorkItem - static_inputs: StaticInputs - for i, (loaded_item, static_inputs) in enumerate(output.data.get_generator()): + for i, loaded_item in enumerate(output.data.get_generator()): + input_shape = loaded_item.batch.horizontal_shape + if model is None: + model = self._get_generation_model( + input_shape=input_shape, output=output + ) + if writer is None: + fine_latlon_coords = model.get_fine_coords_for_batch(loaded_item.batch) writer = output.get_writer( - latlon_coords=static_inputs.coords, + latlon_coords=fine_latlon_coords, output_dir=self.output_dir, ) - writer.initialize_store( - static_inputs.fields[0].data.cpu().numpy().dtype - ) - if model is None: - model = self._get_generation_model( - static_inputs=static_inputs, output=output - ) + writer.initialize_store(np.float32) logging.info( f"[{output.name}] Batch {i+1}/{total_batches}, " @@ -127,7 +130,6 @@ def run_output_generation(self, output: DownscalingOutput): output_data = model.generate_on_batch_no_target( loaded_item.batch, - static_inputs=static_inputs, n_samples=loaded_item.n_ens, ) output_np = {key: value.cpu().numpy() for key, value in output_data.items()} @@ -238,7 +240,6 @@ def build(self) -> Downscaler: loader_config=self.data, requirements=self.model.data_requirements, patch=self.patch, - static_inputs_from_checkpoint=model.static_inputs, ) for output_cfg in self.outputs ] diff --git a/fme/downscaling/inference/output.py b/fme/downscaling/inference/output.py index 012b555ad..e29000679 100644 --- a/fme/downscaling/inference/output.py +++ 
b/fme/downscaling/inference/output.py @@ -18,7 +18,6 @@ ClosedInterval, DataLoaderConfig, LatLonCoordinates, - StaticInputs, enforce_lat_bounds, ) from ..data.config import XarrayEnsembleDataConfig @@ -29,6 +28,13 @@ from .zarr_utils import determine_zarr_chunks +@dataclass +class WriterParams: + chunks: dict[str, int] + shards: dict[str, int] + coords: dict[str, np.ndarray] + + def _identity_collate(batch): """ Collate function that returns the single batch item. @@ -67,8 +73,8 @@ def __init__( max_samples_per_gpu: int, data: SliceWorkItemGriddedData, patch: PatchPredictionConfig, - chunks: dict[str, int], - shards: dict[str, int], + zarr_chunks_override: dict[str, int] | None, + zarr_shards_override: dict[str, int] | None, dims: tuple[str, ...] = DIMS, ) -> None: self.name = name @@ -77,22 +83,20 @@ def __init__( self.max_samples_per_gpu = max_samples_per_gpu self.data = data self.patch = patch - self.chunks = chunks - self.shards = shards + self.zarr_chunks_override = zarr_chunks_override + self.zarr_shards_override = zarr_shards_override self.dims = dims - def get_writer( - self, - latlon_coords: LatLonCoordinates, - output_dir: str, - ) -> ZarrWriter: - """ - Create a ZarrWriter for this target. 
- - Args: - latlon_coords: High-resolution spatial coordinates for outputs - output_dir: Directory to store output zarr file - """ + def _build_writer_params(self, latlon_coords: LatLonCoordinates) -> WriterParams: + lat_size = len(latlon_coords.lat) + lon_size = len(latlon_coords.lon) + n_times, n_ens = self.data.max_output_shape + full_shape = (n_times, n_ens, lat_size, lon_size) + element_size = torch.tensor([], dtype=self.data.dtype).element_size() + chunks = self.zarr_chunks_override or determine_zarr_chunks( + DIMS, full_shape, element_size + ) + shards = self.zarr_shards_override or dict(zip(DIMS, full_shape)) ensemble = list(range(self.n_ens)) coords = dict( zip( @@ -105,15 +109,29 @@ def get_writer( ], ) ) - dims = tuple(coords.keys()) + return WriterParams(chunks=chunks, shards=shards, coords=coords) + def get_writer( + self, + latlon_coords: LatLonCoordinates, + output_dir: str, + ) -> ZarrWriter: + """ + Create a ZarrWriter for this target. + + Args: + latlon_coords: High-resolution spatial coordinates for outputs + output_dir: Directory to store output zarr file + """ + params = self._build_writer_params(latlon_coords) + dims = tuple(params.coords.keys()) return ZarrWriter( path=f"{output_dir}/{self.name}.zarr", dims=dims, - coords=coords, + coords=params.coords, data_vars=self.save_vars, - chunks=self.chunks, - shards=self.shards, + chunks=params.chunks, + shards=params.shards, ) @@ -218,7 +236,6 @@ def _build_gridded_data( loader_config: DataLoaderConfig, requirements: DataRequirements, dist: Distributed | None = None, - static_inputs_from_checkpoint: StaticInputs | None = None, ) -> SliceWorkItemGriddedData: xr_dataset, properties = loader_config.get_xarray_dataset( names=requirements.coarse_names, n_timesteps=1 @@ -229,13 +246,6 @@ def _build_gridded_data( "Downscaling data loader only supports datasets with latlon coords." 
) dataset = loader_config.build_batchitem_dataset(xr_dataset, properties) - topography = loader_config.build_static_inputs( - coords, - requires_topography=requirements.use_fine_topography, - static_inputs=static_inputs_from_checkpoint, - ) - if topography is None: - raise ValueError("Topography is required for downscaling generation.") work_items = get_work_items( n_times=len(dataset), @@ -243,11 +253,9 @@ def _build_gridded_data( max_samples_per_gpu=self.max_samples_per_gpu, ) - # defer topography device placement until after batch generation slice_dataset = SliceItemDataset( slice_items=work_items, dataset=dataset, - spatial_shape=topography.shape, ) # each SliceItemDataset work item loads its own full batch, so batch_size=1 @@ -274,7 +282,6 @@ def _build_gridded_data( all_times=xr_dataset.sample_start_times, dtype=slice_dataset.dtype, max_output_shape=slice_dataset.max_output_shape, - static_inputs=topography, ) def _build( @@ -286,7 +293,6 @@ def _build( requirements: DataRequirements, patch: PatchPredictionConfig, coarse: list[XarrayDataConfig], - static_inputs_from_checkpoint: StaticInputs | None = None, ) -> DownscalingOutput: updated_loader_config = self._replace_loader_config( time, @@ -296,27 +302,7 @@ def _build( loader_config, ) - gridded_data = self._build_gridded_data( - updated_loader_config, - requirements, - static_inputs_from_checkpoint=static_inputs_from_checkpoint, - ) - - if self.zarr_chunks is None: - # Get element size from dtype by creating a dummy tensor - element_size = torch.tensor([], dtype=gridded_data.dtype).element_size() - chunks = determine_zarr_chunks( - dims=DIMS, - data_shape=gridded_data.max_output_shape, - bytes_per_element=element_size, - ) - else: - chunks = self.zarr_chunks - - if self.zarr_shards is None: - shards = dict(zip(DIMS, gridded_data.max_output_shape)) - else: - shards = self.zarr_shards + gridded_data = self._build_gridded_data(updated_loader_config, requirements) return DownscalingOutput( name=self.name, @@ 
-325,8 +311,8 @@ def _build( max_samples_per_gpu=self.max_samples_per_gpu, data=gridded_data, patch=patch, - chunks=chunks, - shards=shards, + zarr_chunks_override=self.zarr_chunks, + zarr_shards_override=self.zarr_shards, dims=DIMS, ) @@ -386,7 +372,6 @@ def build( loader_config: DataLoaderConfig, requirements: DataRequirements, patch: PatchPredictionConfig, - static_inputs_from_checkpoint: StaticInputs | None = None, ) -> DownscalingOutput: # Convert single time to TimeSlice time: Slice | TimeSlice @@ -409,7 +394,6 @@ def build( requirements=requirements, patch=patch, coarse=coarse, - static_inputs_from_checkpoint=static_inputs_from_checkpoint, ) @@ -469,7 +453,6 @@ def build( loader_config: DataLoaderConfig, requirements: DataRequirements, patch: PatchPredictionConfig, - static_inputs_from_checkpoint: StaticInputs | None = None, ) -> DownscalingOutput: coarse = self._single_xarray_config(loader_config.coarse) return self._build( @@ -480,5 +463,4 @@ def build( requirements=requirements, patch=patch, coarse=coarse, - static_inputs_from_checkpoint=static_inputs_from_checkpoint, ) diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index e5cf8c641..91c10daf8 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -12,6 +12,7 @@ from fme.core.dataset.time import TimeSlice from fme.core.logging_utils import LoggingConfig from fme.downscaling.data import ( + ClosedInterval, LatLonCoordinates, StaticInput, StaticInputs, @@ -24,7 +25,6 @@ EventConfig, TimeRangeConfig, ) -from fme.downscaling.inference.work_items import SliceWorkItemGriddedData from fme.downscaling.models import ( CheckpointModelConfig, DiffusionModelConfig, @@ -44,7 +44,6 @@ def mock_model(): """Create a mock model with coarse_shape attribute.""" model = MagicMock() model.coarse_shape = (16, 16) - model.fine_shape = (32, 32) return model @@ -64,8 +63,7 @@ def mock_output_target(): def 
get_static_inputs(shape=(16, 16)): data = torch.randn(shape) - coords = LatLonCoordinates(lat=torch.arange(shape[0]), lon=torch.arange(shape[1])) - return StaticInputs([StaticInput(data=data, coords=coords)]) + return StaticInputs([StaticInput(data=data)]) # Tests for Downscaler initialization @@ -91,8 +89,7 @@ def test_get_generation_model_exact_match(mock_model, mock_output_target): """ Test _get_generation_model returns model unchanged when shapes match exactly. """ - mock_model.fine_shape = (16, 16) - static_inputs = get_static_inputs(shape=(16, 16)) + mock_model.coarse_shape = (16, 16) downscaler = Downscaler( model=mock_model, @@ -100,23 +97,22 @@ def test_get_generation_model_exact_match(mock_model, mock_output_target): ) result = downscaler._get_generation_model( - static_inputs=static_inputs, + input_shape=(16, 16), output=mock_output_target, ) assert result is mock_model -@pytest.mark.parametrize("topo_shape", [(8, 16), (16, 8), (8, 8)]) +@pytest.mark.parametrize("input_shape", [(8, 16), (16, 8), (8, 8)]) def test_get_generation_model_raises_when_domain_too_small( - mock_model, mock_output_target, topo_shape + mock_model, mock_output_target, input_shape ): """ Test _get_generation_model raises ValueError when domain is smaller than model. """ - mock_model.fine_shape = (16, 16) - topo = get_static_inputs(shape=topo_shape) + mock_model.coarse_shape = (16, 16) downscaler = Downscaler( model=mock_model, @@ -125,7 +121,7 @@ def test_get_generation_model_raises_when_domain_too_small( with pytest.raises(ValueError): downscaler._get_generation_model( - static_inputs=topo, + input_shape=input_shape, output=mock_output_target, ) @@ -137,8 +133,7 @@ def test_get_generation_model_creates_patch_predictor_when_needed( Test _get_generation_model creates PatchPredictor for large domains with patching. 
""" - mock_model.fine_shape = (16, 16) - static_inputs = get_static_inputs(shape=(32, 32)) # Larger than model + mock_model.coarse_shape = (16, 16) patch_config = PatchPredictionConfig( divide_generation=True, @@ -152,7 +147,7 @@ def test_get_generation_model_creates_patch_predictor_when_needed( ) model = downscaler._get_generation_model( - static_inputs=static_inputs, + input_shape=(32, 32), output=mock_output_target, ) @@ -167,8 +162,7 @@ def test_get_generation_model_raises_when_large_domain_without_patching( Test _get_generation_model raises when domain is large but patching not configured. """ - mock_model.fine_shape = (16, 16) - topo = get_static_inputs(shape=(32, 32)) # Larger than model + mock_model.coarse_shape = (16, 16) mock_output_target.patch = PatchPredictionConfig(divide_generation=False) downscaler = Downscaler( @@ -178,7 +172,7 @@ def test_get_generation_model_raises_when_large_domain_without_patching( with pytest.raises(ValueError): downscaler._get_generation_model( - static_inputs=topo, + input_shape=(32, 32), output=mock_output_target, ) @@ -187,26 +181,32 @@ def test_run_target_generation_skips_padding_items( mock_model, mock_output_target, ): - """Test run_target_generation skips writing output for padding items.""" - # Create padding work item + """ + Test run_output_generation calls the model but skips writing for padding items. + """ mock_work_item = MagicMock() mock_work_item.is_padding = True - mock_work_item.n_ens = 4 - mock_work_item.batch = MagicMock() - - static_inputs = get_static_inputs(shape=(16, 16)) - - mock_gridded_data = SliceWorkItemGriddedData( - [mock_work_item], {}, [0], torch.float32, (1, 4, 16, 16), static_inputs + mock_work_item.batch.horizontal_shape = (16, 16) + # Coarse coords are interior so fine can have buffer on each side. 
+ mock_work_item.batch.latlon_coordinates.lat = ( + torch.arange(1, 9).float().unsqueeze(0) ) - mock_output_target.data = mock_gridded_data - mock_model.fine_shape = (16, 16) - - mock_output = { - "var1": torch.randn(1, 4, 16, 16), - "var2": torch.randn(1, 4, 16, 16), + mock_work_item.batch.latlon_coordinates.lon = ( + torch.arange(1, 9).float().unsqueeze(0) + ) + mock_work_item.batch.lat_interval = ClosedInterval(1.0, 8.0) + mock_work_item.batch.lon_interval = ClosedInterval(1.0, 8.0) + mock_output_target.data.get_generator.return_value = iter([mock_work_item]) + + mock_model.downscale_factor = 2 + mock_model.fine_coords = LatLonCoordinates( + lat=torch.arange(0, 18).float(), + lon=torch.arange(0, 18).float(), + ) + mock_model.static_inputs = None + mock_model.generate_on_batch_no_target.return_value = { + "var1": torch.zeros(1, 4, 16, 16), } - mock_model.generate_on_batch_no_target.return_value = mock_output mock_writer = MagicMock() mock_output_target.get_writer.return_value = mock_writer @@ -218,11 +218,8 @@ def test_run_target_generation_skips_padding_items( downscaler.run_output_generation(output=mock_output_target) - # Verify model was still called mock_model.generate_on_batch_no_target.assert_called_once() - - # Verify the mock writer was not called - mock_writer.write_batch.assert_not_called() + mock_writer.record_batch.assert_not_called() # Tests for end-to-end generation process @@ -276,8 +273,16 @@ def checkpointed_model_config( # loader_config is passed in to add static inputs into model # that correspond to the dataset coordinates - static_inputs = load_static_inputs({"HGTsfc": f"{data_paths.fine}/data.nc"}) - model = model_config.build(coarse_shape, 2, static_inputs=static_inputs) + fine_data_path = f"{data_paths.fine}/data.nc" + static_inputs = load_static_inputs({"HGTsfc": fine_data_path}) + ds = xr.open_dataset(fine_data_path) + fine_coords = LatLonCoordinates( + lat=torch.tensor(ds["lat"].values, dtype=torch.float32), + 
lon=torch.tensor(ds["lon"].values, dtype=torch.float32), + ) + model = model_config.build( + coarse_shape, 2, static_inputs=static_inputs, fine_coords=fine_coords + ) checkpoint_path = tmp_path / "model_checkpoint.pth" model.get_state() diff --git a/fme/downscaling/inference/test_output.py b/fme/downscaling/inference/test_output.py index f9f041ae4..23bee3d44 100644 --- a/fme/downscaling/inference/test_output.py +++ b/fme/downscaling/inference/test_output.py @@ -1,35 +1,71 @@ from unittest.mock import MagicMock +import numpy as np import pytest import torch from fme.core.dataset.time import TimeSlice from fme.core.dataset.xarray import XarrayDataConfig -from fme.downscaling.data import ClosedInterval, StaticInput, StaticInputs -from fme.downscaling.data.utils import LatLonCoordinates +from fme.downscaling.data import ClosedInterval from fme.downscaling.inference.output import ( DownscalingOutput, DownscalingOutputConfig, EventConfig, TimeRangeConfig, + WriterParams, ) from fme.downscaling.predictors import PatchPredictionConfig from fme.downscaling.requirements import DataRequirements -# Tests for OutputTargetConfig validation +def _make_downscaling_output(zarr_chunks_override=None, zarr_shards_override=None): + mock_data = MagicMock() + mock_data.max_output_shape = (2, 4) + mock_data.dtype = torch.float32 + mock_data.all_times.to_numpy.return_value = np.zeros(2) + return DownscalingOutput( + name="test", + save_vars=None, + n_ens=4, + max_samples_per_gpu=4, + data=mock_data, + patch=MagicMock(), + zarr_chunks_override=zarr_chunks_override, + zarr_shards_override=zarr_shards_override, + ) + + +def _make_latlon(lat_size=10, lon_size=20): + latlon = MagicMock() + latlon.lat = torch.zeros(lat_size) + latlon.lon = torch.zeros(lon_size) + return latlon + + +def test_build_writer_params_default_chunks_and_shards(): + output = _make_downscaling_output() + latlon = _make_latlon(lat_size=10, lon_size=20) + params = output._build_writer_params(latlon) + assert 
isinstance(params, WriterParams) + assert params.shards == {"time": 2, "ensemble": 4, "latitude": 10, "longitude": 20} + assert params.chunks["time"] == 1 + assert params.chunks["ensemble"] == 1 -def _get_static_inputs(shape=(8, 8)): - return StaticInputs( - fields=[ - StaticInput( - data=torch.ones(shape), - coords=LatLonCoordinates( - lat=torch.ones(shape[0]), lon=torch.ones(shape[1]) - ), - ) - ] + +def test_build_writer_params_override_chunks_and_shards(): + zarr_chunks_override = {"time": 5, "ensemble": 5, "latitude": 5, "longitude": 5} + zarr_shards_override = {"time": 10, "ensemble": 10, "latitude": 10, "longitude": 10} + output = _make_downscaling_output( + zarr_chunks_override=zarr_chunks_override, + zarr_shards_override=zarr_shards_override, ) + latlon = _make_latlon(lat_size=10, lon_size=20) + params = output._build_writer_params(latlon) + assert params.chunks == zarr_chunks_override + assert params.shards == zarr_shards_override + + +# Tests for OutputTargetConfig validation def test_single_xarray_config_accepts_single_config(): @@ -116,10 +152,7 @@ def test_event_config_build_creates_output_target_with_single_time( lat_extent=ClosedInterval(0.0, 6.0), lon_extent=ClosedInterval(0.0, 6.0), ) - static_inputs = _get_static_inputs((8, 8)) - output_target = config.build( - loader_config, requirements, patch_config, static_inputs - ) + output_target = config.build(loader_config, requirements, patch_config) # Verify OutputTarget was created assert isinstance(output_target, DownscalingOutput) @@ -130,10 +163,6 @@ def test_event_config_build_creates_output_target_with_single_time( # Verify time dimension - should have exactly 1 timestep assert len(output_target.data.all_times) == 1 assert output_target.data is not None - assert output_target.chunks is not None - assert tuple(output_target.chunks.values())[:2] == (1, 1) - assert output_target.shards is not None - assert tuple(output_target.shards.values()) == output_target.data.max_output_shape 
@pytest.mark.parametrize("loader_config", [True], indirect=True) @@ -147,11 +176,7 @@ def test_region_config_build_creates_output_target_with_time_range( n_ens=4, save_vars=["var0", "var1"], ) - static_inputs = _get_static_inputs((8, 8)) - - output_target = config.build( - loader_config, requirements, patch_config, static_inputs - ) + output_target = config.build(loader_config, requirements, patch_config) # Verify OutputTarget was created assert isinstance(output_target, DownscalingOutput) @@ -159,12 +184,7 @@ def test_region_config_build_creates_output_target_with_time_range( assert output_target.n_ens == 4 assert len(output_target.data.all_times) == 2 - # Verify chunks dict structure assert output_target.data is not None - assert output_target.chunks is not None - assert tuple(output_target.chunks.values())[:2] == (1, 1) - assert output_target.shards is not None - assert tuple(output_target.shards.values()) == output_target.data.max_output_shape def test_time_range_config_raise_error_invalid_lat_extent(): diff --git a/fme/downscaling/inference/test_work_items.py b/fme/downscaling/inference/test_work_items.py index d6bf9c37a..3bdd24cc8 100644 --- a/fme/downscaling/inference/test_work_items.py +++ b/fme/downscaling/inference/test_work_items.py @@ -409,8 +409,8 @@ def test_slice_item_dataset_max_output_shape( shape = dataset.max_output_shape # First item: time_slice=slice(0,2), ens_slice=slice(0,4) - # n_times = 2, n_ens = 4, spatial = (64, 64) - assert shape == (2, 4, 64, 64) + # n_times = 2, n_ens = 4 + assert shape == (2, 4) def test_slice_item_dataset_dtype_property( diff --git a/fme/downscaling/inference/work_items.py b/fme/downscaling/inference/work_items.py index 0b5ef27a5..643fe184f 100644 --- a/fme/downscaling/inference/work_items.py +++ b/fme/downscaling/inference/work_items.py @@ -10,7 +10,7 @@ from fme.core.distributed import Distributed from fme.core.generics.data import SizedMap -from ..data import BatchData, StaticInputs +from ..data import BatchData 
from ..data.config import BatchItemDatasetAdapter from .constants import ENSEMBLE_NAME, TIME_NAME @@ -110,18 +110,11 @@ def __init__( self, slice_items: list[SliceWorkItem], dataset: BatchItemDatasetAdapter, - spatial_shape: tuple[int, int] | None = None, ) -> None: self.slice_items = slice_items self.dataset = dataset self._dtype = None - if spatial_shape is None: - sample_batch_item = self.dataset[0] - self.spatial_shape = sample_batch_item.horizontal_shape - else: - self.spatial_shape = spatial_shape - def __len__(self) -> int: return len(self.slice_items) @@ -133,11 +126,11 @@ def __getitem__(self, idx: int) -> LoadedSliceWorkItem: return loaded_item @property - def max_output_shape(self): + def max_output_shape(self) -> tuple[int, int]: first_item = self.slice_items[0] n_times = first_item.time_slice.stop - first_item.time_slice.start n_ensembles = first_item.ens_slice.stop - first_item.ens_slice.start - return (n_times, n_ensembles, *self.spatial_shape) + return (n_times, n_ensembles) @property def dtype(self) -> torch.dtype: @@ -296,8 +289,7 @@ class SliceWorkItemGriddedData: variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex dtype: torch.dtype - max_output_shape: tuple[int, ...] 
- static_inputs: StaticInputs + max_output_shape: tuple[int, int] # TODO: currently no protocol or ABC for gridded data objects # if we want to unify, we will need one and just raise @@ -310,7 +302,5 @@ def on_device(work_item: LoadedSliceWorkItem) -> LoadedSliceWorkItem: return SizedMap(on_device, self._loader) - def get_generator(self) -> Iterator[tuple[LoadedSliceWorkItem, StaticInputs]]: - work_item: LoadedSliceWorkItem - for work_item in self.loader: - yield work_item, self.static_inputs + def get_generator(self) -> Iterator[LoadedSliceWorkItem]: + yield from self.loader diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 8ffa8c22b..f50071cad 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -5,7 +5,9 @@ import dacite import torch +import xarray as xr +from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device from fme.core.distributed import Distributed from fme.core.loss import LossConfig @@ -180,6 +182,7 @@ def build( self, coarse_shape: tuple[int, int], downscale_factor: int, + fine_coords: LatLonCoordinates, rename: dict[str, str] | None = None, static_inputs: StaticInputs | None = None, ) -> "DiffusionModel": @@ -220,6 +223,7 @@ def build( downscale_factor=downscale_factor, sigma_data=sigma_data, static_inputs=static_inputs, + fine_coords=fine_coords, ) def get_state(self) -> Mapping[str, Any]: @@ -275,6 +279,7 @@ def __init__( coarse_shape: tuple[int, int], downscale_factor: int, sigma_data: float, + fine_coords: LatLonCoordinates, static_inputs: StaticInputs | None = None, ) -> None: """ @@ -286,13 +291,18 @@ def __init__( normalizer: The normalizer object used for data normalization. loss: The loss function used for training the model. coarse_shape: The height (lat) and width (lon) of the - coarse-resolution input data. + coarse-resolution input data used to train the model + (same as patch extent, if training on patches). 
downscale_factor: The factor by which the data is downscaled from coarse to fine. sigma_data: The standard deviation of the data, used for diffusion model preconditioning. + fine_coords: the full-domain fine-resolution coordinates to use + for spatial metadata in the model output. static_inputs: Static inputs to the model, loaded from the trainer config or checkpoint. Must be set when use_fine_topography is True. + Expected to be on the same coordinate grid as fine_coords for now, + but this may be relaxed in the future. """ self.coarse_shape = coarse_shape self.downscale_factor = downscale_factor @@ -308,6 +318,19 @@ def __init__( self.static_inputs = ( static_inputs.to_device() if static_inputs is not None else None ) + device = get_device() + self.fine_coords = LatLonCoordinates( + lat=fine_coords.lat.to(device), + lon=fine_coords.lon.to(device), + ) + if static_inputs is not None: + expected = (len(fine_coords.lat), len(fine_coords.lon)) + if static_inputs.shape != expected: + raise ValueError( + f"static_inputs are expected to be on the same coordinate grid as " + f"fine_coords. StaticInputs shape {static_inputs.shape} does not " + f"match fine_coords grid {expected}." + ) @property def modules(self) -> torch.nn.ModuleList: @@ -330,7 +353,30 @@ def _subset_static_inputs( "Static inputs must be provided for each batch when use of fine " "static inputs is enabled." 
) - return self.static_inputs.subset_latlon(lat_interval, lon_interval) + lat_slice = lat_interval.slice_of(self.fine_coords.lat) + lon_slice = lon_interval.slice_of(self.fine_coords.lon) + return self.static_inputs.subset(lat_slice, lon_slice) + + def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: + """Return fine-resolution coordinates matching the spatial extent of batch.""" + coarse_lat = batch.latlon_coordinates.lat[0] + coarse_lon = batch.latlon_coordinates.lon[0] + fine_lat_interval = adjust_fine_coord_range( + batch.lat_interval, + full_coarse_coord=coarse_lat, + full_fine_coord=self.fine_coords.lat, + downscale_factor=self.downscale_factor, + ) + fine_lon_interval = adjust_fine_coord_range( + batch.lon_interval, + full_coarse_coord=coarse_lon, + full_fine_coord=self.fine_coords.lon, + downscale_factor=self.downscale_factor, + ) + return LatLonCoordinates( + lat=self.fine_coords.lat[fine_lat_interval.slice_of(self.fine_coords.lat)], + lon=self.fine_coords.lon[fine_lon_interval.slice_of(self.fine_coords.lon)], + ) @property def fine_shape(self) -> tuple[int, int]: @@ -338,7 +384,8 @@ def fine_shape(self) -> tuple[int, int]: def _get_fine_shape(self, coarse_shape: tuple[int, int]) -> tuple[int, int]: """ - Calculate the fine shape based on the coarse shape and downscale factor. + Calculate the fine shape based on the coarse shape of data used to train + the model and the downscaling factor. """ return ( coarse_shape[0] * self.downscale_factor, @@ -383,12 +430,9 @@ def _get_input_from_coarse( def train_on_batch( self, batch: PairedBatchData, - static_inputs: StaticInputs | None, # TODO: remove in follow-on PR optimizer: Optimization | NullOptimization, ) -> ModelOutputs: """Performs a denoising training step on a batch of data.""" - # Ignore the passed static_inputs; subset self.static_inputs using fine batch - # coordinates. The caller-provided value is kept for signature compatibility. 
_static_inputs = self._subset_static_inputs( batch.fine.lat_interval, batch.fine.lon_interval ) @@ -452,8 +496,8 @@ def generate( static_inputs: StaticInputs | None, n_samples: int = 1, ) -> tuple[TensorDict, torch.Tensor, list[torch.Tensor]]: - # static_inputs receives an internally-subsetted value from the calling method; - # external callers should use generate_on_batch / generate_on_batch_no_target. + # Internal method; external callers should use generate_on_batch / + # generate_on_batch_no_target. inputs_ = self._get_input_from_coarse(coarse_data, static_inputs) # expand samples and fold to # [batch * n_samples, output_channels, height, width] @@ -503,11 +547,8 @@ def generate( def generate_on_batch_no_target( self, batch: BatchData, - static_inputs: StaticInputs | None, # TODO: remove in follow-on PR n_samples: int = 1, ) -> TensorDict: - # Ignore the passed static_inputs; derive the fine lat/lon interval from coarse - # batch coordinates via adjust_fine_coord_range, then subset self.static_inputs. 
if self.config.use_fine_topography: if self.static_inputs is None: raise ValueError( @@ -519,16 +560,16 @@ def generate_on_batch_no_target( fine_lat_interval = adjust_fine_coord_range( batch.lat_interval, full_coarse_coord=coarse_lat, - full_fine_coord=self.static_inputs.coords.lat, + full_fine_coord=self.fine_coords.lat, downscale_factor=self.downscale_factor, ) fine_lon_interval = adjust_fine_coord_range( batch.lon_interval, full_coarse_coord=coarse_lon, - full_fine_coord=self.static_inputs.coords.lon, + full_fine_coord=self.fine_coords.lon, downscale_factor=self.downscale_factor, ) - _static_inputs = self.static_inputs.subset_latlon( + _static_inputs = self._subset_static_inputs( fine_lat_interval, fine_lon_interval ) else: @@ -540,11 +581,8 @@ def generate_on_batch_no_target( def generate_on_batch( self, batch: PairedBatchData, - static_inputs: StaticInputs | None, # TODO: remove in follow-on PR n_samples: int = 1, ) -> ModelOutputs: - # Ignore the passed static_inputs; subset self.static_inputs using fine batch - # coordinates. The caller-provided value is kept for signature compatibility. _static_inputs = self._subset_static_inputs( batch.fine.lat_interval, batch.fine.lon_interval ) @@ -578,6 +616,7 @@ def get_state(self) -> Mapping[str, Any]: "coarse_shape": self.coarse_shape, "downscale_factor": self.downscale_factor, "static_inputs": static_inputs_state, + "fine_coords": self.fine_coords.get_state(), } @classmethod @@ -591,10 +630,40 @@ def from_state( static_inputs = StaticInputs.from_state(state["static_inputs"]).to_device() else: static_inputs = None + + # Load fine_coords: new checkpoints store it directly; old checkpoints + # that had static_inputs with coords can auto-migrate from raw state. + fine_coords = state.get("fine_coords") + if fine_coords is not None: + # TODO: Why doesn't LatLonCoordinates have a from_state method? 
+ fine_coords = LatLonCoordinates( + lat=state["fine_coords"]["lat"], + lon=state["fine_coords"]["lon"], + ) + elif ( + static_inputs is not None + and static_inputs.fields + and "coords" in state["static_inputs"]["fields"][0] + ): + # Backward compat: old checkpoints with static inputs stored coords inside + # static_inputs fields[0]["coords"] + coords_state = state["static_inputs"]["fields"][0]["coords"] + fine_coords = LatLonCoordinates( + lat=coords_state["lat"], + lon=coords_state["lon"], + ) + else: + raise ValueError( + "No fine coordinates found in checkpoint state and no static inputs " + " were available to infer them. fine_coords must be serialized with the" + " checkpoint to resume from training." + ) + model = config.build( state["coarse_shape"], state["downscale_factor"], static_inputs=static_inputs, + fine_coords=fine_coords, ) model.module.load_state_dict(state["module"], strict=True) return model @@ -611,6 +680,26 @@ def from_state(cls, state: Mapping[str, Any]) -> DiffusionModelConfig: ).wrapper +def load_fine_coords_from_path(path: str) -> LatLonCoordinates: + if path.endswith(".zarr"): + ds = xr.open_zarr(path) + else: + ds = xr.open_dataset(path) + lat_name = next((n for n in ["lat", "latitude", "grid_yt"] if n in ds.coords), None) + lon_name = next( + (n for n in ["lon", "longitude", "grid_xt"] if n in ds.coords), None + ) + if lat_name is None or lon_name is None: + raise ValueError( + f"Could not find lat/lon coordinates in {path}. " + "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'." + ) + return LatLonCoordinates( + lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), + lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), + ) + + @dataclasses.dataclass class CheckpointModelConfig: """ @@ -624,6 +713,9 @@ class CheckpointModelConfig: but the model requires static input data. Raises an error if the checkpoint already has static inputs from training. fine_topography_path: Deprecated. 
Use static_inputs instead. + fine_coordinates_path: Optional path to a netCDF/zarr file containing lat/lon + coordinates for the full fine domain. Used for old checkpoints that have + no static_inputs and no stored fine_coords. model_updates: Optional mapping of {key: new_value} model config updates to apply when loading the model. This is useful for running evaluation with updated parameters than at training time. Use with caution; not all @@ -634,6 +726,7 @@ class CheckpointModelConfig: rename: dict[str, str] | None = None static_inputs: dict[str, str] | None = None fine_topography_path: str | None = None + fine_coordinates_path: str | None = None model_updates: dict[str, Any] | None = None def __post_init__(self) -> None: @@ -664,6 +757,8 @@ def _checkpoint(self) -> Mapping[str, Any]: ] # backwards compatibility for models before static inputs serialization checkpoint_data["model"].setdefault("static_inputs", None) + # backwards compatibility for models before fine_coords serialization + checkpoint_data["model"].setdefault("fine_coords", None) self._checkpoint_data = checkpoint_data self._checkpoint_is_loaded = True @@ -690,11 +785,53 @@ def build( static_inputs = load_static_inputs(self.static_inputs) else: static_inputs = None + + fine_coords: LatLonCoordinates + has_fine_coords = self._checkpoint["model"]["fine_coords"] is not None + has_static_input_coords = ( + self._checkpoint["model"]["static_inputs"] is not None + and self._checkpoint["model"]["static_inputs"]["fields"][0].get("coords") + is not None + ) + # TODO: simplify with static input refactor that disables empty StaticInputs + if ( + has_fine_coords or has_static_input_coords + ) and self.fine_coordinates_path is not None: + raise ValueError( + "The model checkpoint already has fine coordinates from training. " + "fine_coordinates_path should not be provided in checkpoint model " + "config."
+ ) + elif has_fine_coords: + fine_coords_state = self._checkpoint["model"]["fine_coords"] + fine_coords = LatLonCoordinates( + lat=fine_coords_state["lat"], + lon=fine_coords_state["lon"], + ) + elif has_static_input_coords: + coords_state = self._checkpoint["model"]["static_inputs"]["fields"][0][ + "coords" + ] + fine_coords = LatLonCoordinates( + lat=coords_state["lat"], + lon=coords_state["lon"], + ) + elif self.fine_coordinates_path is not None: + fine_coords = load_fine_coords_from_path(self.fine_coordinates_path) + else: + raise ValueError( + "No fine coordinates found in checkpoint state and no static inputs " + " were available to infer them. fine_coordinates_path must be provided " + "in the checkpoint model configuration to load fine coordinates from " + "the provided path." + ) + model = _CheckpointModelConfigSelector.from_state( self._checkpoint["model"]["config"] ).build( coarse_shape=self._checkpoint["model"]["coarse_shape"], downscale_factor=self._checkpoint["model"]["downscale_factor"], + fine_coords=fine_coords, rename=self._rename, static_inputs=static_inputs, ) diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index a78878a10..c6b0a3677 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -9,7 +9,6 @@ import yaml from fme.core.cli import prepare_directory -from fme.core.coordinates import LatLonCoordinates from fme.core.dataset.time import TimeSlice from fme.core.dicts import to_flat_dict from fme.core.distributed import Distributed @@ -21,7 +20,6 @@ ClosedInterval, DataLoaderConfig, GriddedData, - StaticInputs, enforce_lat_bounds, ) from fme.downscaling.models import CheckpointModelConfig, DiffusionModel @@ -30,31 +28,6 @@ from fme.downscaling.typing_ import FineResCoarseResPair -def _downscale_coord(coord: torch.tensor, downscale_factor: int): - """ - This is a bandaid fix for the issue where BatchData does not - contain coords for the topography, which is fine-res in the no-target - generation case. 
The SampleAggregator requires the fine-res coords - for the predictions. - - TODO: remove after topography refactors to have its own data container. - """ - if len(coord.shape) != 1: - raise ValueError("coord tensor to downscale must be 1d") - spacing = coord[1] - coord[0] - # Compute edges from midpoints - first_edge = coord[0] - spacing / 2 - last_edge = coord[-1] + spacing / 2 - - # Subdivide edges - step = spacing / downscale_factor - new_edges = torch.arange(first_edge, last_edge + step / 2, step) - - # Compute new midpoints - coord_new = (new_edges[:-1] + new_edges[1:]) / 2 - return coord_new.to(device=coord.device, dtype=coord.dtype) - - @dataclasses.dataclass class EventConfig: name: str @@ -88,7 +61,6 @@ def get_gridded_data( self, base_data_config: DataLoaderConfig, requirements: DataRequirements, - static_inputs_from_checkpoint: StaticInputs | None = None, ) -> GriddedData: enforce_lat_bounds(self.lat_extent) event_coarse = dataclasses.replace(base_data_config.full_config[0]) @@ -104,7 +76,6 @@ def get_gridded_data( ) return event_data_config.build( requirements=requirements, - static_inputs=static_inputs_from_checkpoint, ) @@ -146,12 +117,9 @@ def generation_model(self): def run(self): logging.info(f"Running {self.event_name} event downscaling...") - batch, static_inputs = next(iter(self.data.get_generator())) - coarse_coords = batch[0].latlon_coordinates - fine_coords = LatLonCoordinates( - lat=_downscale_coord(coarse_coords.lat, self.model.downscale_factor), - lon=_downscale_coord(coarse_coords.lon, self.model.downscale_factor), - ) + batch = next(iter(self.data.get_generator())) + coarse_coords = batch.latlon_coordinates[0] + fine_coords = self.model.get_fine_coords_for_batch(batch) sample_agg = SampleAggregator( coarse=batch[0].data, latlon_coordinates=FineResCoarseResPair( @@ -169,7 +137,7 @@ def run(self): f"for event {self.event_name}" ) outputs = self.model.generate_on_batch_no_target( - batch, static_inputs=static_inputs, n_samples=end_idx - 
start_idx + batch, n_samples=end_idx - start_idx ) sample_agg.record_batch(outputs) to_log = sample_agg.get_wandb() @@ -238,30 +206,28 @@ def save_netcdf_data(self, ds: xr.Dataset): f"{self.experiment_dir}/generated_maps_and_metrics.nc", mode="w" ) - @property - def _fine_latlon_coordinates(self) -> LatLonCoordinates | None: - if self.data.static_inputs is not None: - return self.data.static_inputs.coords - else: - return None - def run(self): - aggregator = NoTargetAggregator( - downscale_factor=self.model.downscale_factor, - latlon_coordinates=self._fine_latlon_coordinates, - ) - for i, (batch, static_inputs) in enumerate(self.batch_generator): + aggregator: NoTargetAggregator | None = None + for i, batch in enumerate(self.batch_generator): + if aggregator is None: + fine_coords = self.model.get_fine_coords_for_batch(batch) + aggregator = NoTargetAggregator( + downscale_factor=self.model.downscale_factor, + latlon_coordinates=fine_coords, + ) with torch.no_grad(): logging.info(f"Generating predictions on batch {i + 1}") prediction = self.generation_model.generate_on_batch_no_target( batch=batch, - static_inputs=static_inputs, n_samples=self.n_samples, ) logging.info("Recording diagnostics to aggregator") # Add sample dimension to coarse values for generation comparison coarse = {k: v.unsqueeze(1) for k, v in batch.data.items()} aggregator.record_batch(prediction, coarse, batch.time) + + # dataset build ensures non-empty batch_generator + assert aggregator is not None logs = aggregator.get_wandb() wandb = WandB.get_instance() wandb.log(logs, step=0) @@ -297,7 +263,6 @@ def build(self) -> list[Downscaler | EventDownscaler]: model = self.model.build() dataset = self.data.build( requirements=self.model.data_requirements, - static_inputs=model.static_inputs, ) downscaler = Downscaler( data=dataset, @@ -311,7 +276,6 @@ def build(self) -> list[Downscaler | EventDownscaler]: event_dataset = event_config.get_gridded_data( base_data_config=self.data, 
requirements=self.model.data_requirements, - static_inputs_from_checkpoint=model.static_inputs, ) event_downscalers.append( EventDownscaler( diff --git a/fme/downscaling/predictors/composite.py b/fme/downscaling/predictors/composite.py index e1e40dc3a..5d0a4be8f 100644 --- a/fme/downscaling/predictors/composite.py +++ b/fme/downscaling/predictors/composite.py @@ -2,10 +2,10 @@ import torch +from fme.core.coordinates import LatLonCoordinates from fme.core.typing_ import TensorDict -from fme.downscaling.data import BatchData, PairedBatchData, StaticInputs, scale_tuple +from fme.downscaling.data import BatchData, PairedBatchData, scale_tuple from fme.downscaling.data.patching import Patch, get_patches -from fme.downscaling.data.utils import null_generator from fme.downscaling.models import DiffusionModel, ModelOutputs @@ -79,6 +79,17 @@ def __init__( def coarse_shape(self): return self.coarse_yx_patch_extent + @property + def static_inputs(self): + return self.model.static_inputs + + def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: + return self.model.get_fine_coords_for_batch(batch) + + @property + def fine_coords(self): + return self.model.fine_coords + def _get_patches( self, coarse_yx_extent, fine_yx_extent ) -> tuple[list[Patch], list[Patch]]: @@ -105,7 +116,6 @@ def _get_patches( def generate_on_batch( self, batch: PairedBatchData, - static_inputs: StaticInputs | None, n_samples: int = 1, ) -> ModelOutputs: predictions = [] @@ -118,17 +128,9 @@ def generate_on_batch( batch_generator = batch.generate_from_patches( coarse_patches=coarse_patches, fine_patches=fine_patches ) - if static_inputs is not None: - static_inputs_generator = static_inputs.generate_from_patches(fine_patches) - else: - static_inputs_generator = null_generator(len(fine_patches)) - - for data_patch, static_inputs_patch in zip( - batch_generator, static_inputs_generator - ): - model_output = self.model.generate_on_batch( - data_patch, static_inputs_patch, n_samples - ) 
+ + for data_patch in batch_generator: + model_output = self.model.generate_on_batch(data_patch, n_samples) predictions.append(model_output.prediction) loss = loss + model_output.loss @@ -146,7 +148,6 @@ def generate_on_batch( def generate_on_batch_no_target( self, batch: BatchData, - static_inputs: StaticInputs | None, n_samples: int = 1, ) -> TensorDict: coarse_yx_extent = batch.horizontal_shape @@ -156,17 +157,10 @@ def generate_on_batch_no_target( ) predictions = [] batch_generator = batch.generate_from_patches(coarse_patches) - if static_inputs is not None: - static_inputs_generator = static_inputs.generate_from_patches(fine_patches) - else: - static_inputs_generator = null_generator(len(fine_patches)) - for data_patch, static_inputs_patch in zip( - batch_generator, static_inputs_generator - ): + for data_patch in batch_generator: predictions.append( self.model.generate_on_batch_no_target( batch=data_patch, - static_inputs=static_inputs_patch, n_samples=n_samples, ) ) diff --git a/fme/downscaling/predictors/test_composite.py b/fme/downscaling/predictors/test_composite.py index f6b6fa60d..9f31ff05d 100644 --- a/fme/downscaling/predictors/test_composite.py +++ b/fme/downscaling/predictors/test_composite.py @@ -6,7 +6,7 @@ from fme.core.device import get_device from fme.core.packer import Packer from fme.downscaling.aggregators.shape_helpers import upsample_tensor -from fme.downscaling.data import BatchData, PairedBatchData, StaticInput, StaticInputs +from fme.downscaling.data import BatchData, PairedBatchData from fme.downscaling.data.patching import get_patches from fme.downscaling.data.utils import BatchedLatLonCoordinates from fme.downscaling.models import ModelOutputs @@ -16,10 +16,6 @@ ) -def _get_static_inputs(shape, coords): - return StaticInputs(fields=[StaticInput(data=torch.randn(shape), coords=coords)]) - - def test_composite_predictions(): patch_yx_size = (2, 2) patches = get_patches((4, 4), patch_yx_size, overlap=0) @@ -54,9 +50,7 @@ def 
__init__(self, coarse_shape, downscale_factor): self.modules = [] self.out_packer = Packer(["x"]) - def generate_on_batch( - self, batch: PairedBatchData, static_inputs: StaticInputs | None, n_samples=1 - ): + def generate_on_batch(self, batch: PairedBatchData, n_samples=1): prediction_data = { k: v.unsqueeze(1).expand(-1, n_samples, -1, -1) for k, v in batch.fine.data.items() @@ -65,9 +59,7 @@ def generate_on_batch( prediction=prediction_data, target=prediction_data, loss=torch.tensor(1.0) ) - def generate_on_batch_no_target( - self, batch: BatchData, static_inputs: StaticInputs | None, n_samples=1 - ): + def generate_on_batch_no_target(self, batch: BatchData, n_samples=1): prediction_data = { k: upsample_tensor( v.unsqueeze(1).expand(-1, n_samples, -1, -1), @@ -137,13 +129,6 @@ def test_SpatialCompositePredictor_generate_on_batch(patch_size_coarse): paired_batch_data = get_paired_test_data( *coarse_extent, downscale_factor=downscale_factor, batch_size=batch_size ) - static_inputs = _get_static_inputs( - shape=( - coarse_extent[0] * downscale_factor, - coarse_extent[1] * downscale_factor, - ), - coords=paired_batch_data.fine.latlon_coordinates[0], - ) predictor = PatchPredictor( DummyModel(coarse_shape=patch_size_coarse, downscale_factor=downscale_factor), # type: ignore @@ -152,7 +137,7 @@ def test_SpatialCompositePredictor_generate_on_batch(patch_size_coarse): ) n_samples_generate = 2 outputs = predictor.generate_on_batch( - paired_batch_data, static_inputs, n_samples=n_samples_generate + paired_batch_data, n_samples=n_samples_generate ) assert outputs.prediction["x"].shape == (batch_size, n_samples_generate, 8, 8) # dummy model predicts same value as fine data for all samples @@ -174,13 +159,6 @@ def test_SpatialCompositePredictor_generate_on_batch_no_target(patch_size_coarse paired_batch_data = get_paired_test_data( *coarse_extent, downscale_factor=downscale_factor, batch_size=batch_size ) - static_inputs = _get_static_inputs( - shape=( - coarse_extent[0] * 
downscale_factor, - coarse_extent[1] * downscale_factor, - ), - coords=paired_batch_data.fine.latlon_coordinates[0], - ) predictor = PatchPredictor( DummyModel(coarse_shape=patch_size_coarse, downscale_factor=2), # type: ignore coarse_extent, @@ -189,6 +167,6 @@ def test_SpatialCompositePredictor_generate_on_batch_no_target(patch_size_coarse n_samples_generate = 2 coarse_batch_data = paired_batch_data.coarse prediction = predictor.generate_on_batch_no_target( - coarse_batch_data, static_inputs, n_samples=n_samples_generate + coarse_batch_data, n_samples=n_samples_generate ) assert prediction["x"].shape == (batch_size, n_samples_generate, 8, 8) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index d6798e66f..98bb76b33 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -124,30 +124,37 @@ def make_paired_batch_data( def make_static_inputs(fine_shape: tuple[int, int]) -> StaticInputs: - """Create StaticInputs with proper monotonic coordinates for given shape.""" - lat_size, lon_size = fine_shape + """Create StaticInputs for given shape.""" return StaticInputs( fields=[ StaticInput( torch.ones(*fine_shape, device=get_device()), - LatLonCoordinates( - lat=_get_monotonic_coordinate(lat_size, stop=lat_size), - lon=_get_monotonic_coordinate(lon_size, stop=lon_size), - ), ) ] ) +def make_fine_coords(fine_shape: tuple[int, int]) -> LatLonCoordinates: + """Create LatLonCoordinates for given fine shape.""" + lat_size, lon_size = fine_shape + return LatLonCoordinates( + lat=_get_monotonic_coordinate(lat_size, stop=lat_size), + lon=_get_monotonic_coordinate(lon_size, stop=lon_size), + ) + + def test_module_serialization(tmp_path): coarse_shape = (8, 16) - static_inputs = make_static_inputs((16, 32)) + fine_shape = (16, 32) + static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, predict_residual=True, 
use_fine_topography=False, static_inputs=static_inputs, + fine_coords=fine_coords, ) model_from_state = DiffusionModel.from_state( model.get_state(), @@ -158,6 +165,9 @@ def test_module_serialization(tmp_path): model.module.parameters(), model_from_state.module.parameters() ) ) + assert model_from_state.fine_coords is not None + assert torch.equal(model_from_state.fine_coords.lat.cpu(), fine_coords.lat.cpu()) + assert torch.equal(model_from_state.fine_coords.lon.cpu(), fine_coords.lon.cpu()) torch.save(model.get_state(), tmp_path / "test.ckpt") model_from_disk = DiffusionModel.from_state( @@ -174,6 +184,9 @@ def test_module_serialization(tmp_path): assert torch.equal( loaded_static_inputs.fields[0].data, static_inputs.fields[0].data ) + assert model_from_disk.fine_coords is not None + assert torch.equal(model_from_disk.fine_coords.lat.cpu(), fine_coords.lat.cpu()) + assert torch.equal(model_from_disk.fine_coords.lon.cpu(), fine_coords.lon.cpu()) def test_from_state_backward_compat_fine_topography(): @@ -181,21 +194,27 @@ def test_from_state_backward_compat_fine_topography(): fine_shape = (16, 32) downscale_factor = 2 static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, + fine_coords=fine_coords, ) - # Simulate old checkpoint format: static_inputs not serialized + # Simulate old checkpoint format: static_inputs not serialized, fine_coords still + # present state = model.get_state() state["static_inputs"] = None # Should load correctly via the elif use_fine_topography branch (+1 channel) model_from_old_state = DiffusionModel.from_state(state) assert model_from_old_state.static_inputs is None + assert torch.equal( + model_from_old_state.fine_coords.lat.cpu(), fine_coords.lat.cpu() + ) assert all( torch.equal(p1, p2) for p1, p2 in zip( @@ -206,7 +225,38 @@ def 
test_from_state_backward_compat_fine_topography(): # At runtime, omitting static inputs must raise a clear error batch = get_mock_paired_batch([2, *coarse_shape], [2, *fine_shape]) with pytest.raises(ValueError, match="Static inputs must be provided"): - model_from_old_state.generate_on_batch(batch, static_inputs=None) + model_from_old_state.generate_on_batch(batch) + + +def test_from_state_backward_compat_migrates_fine_coords_from_old_static_inputs(): + """Old checkpoints that stored coords in static_inputs fields should have + fine_coords auto-migrated on from_state.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + downscale_factor = 2 + static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=downscale_factor, + predict_residual=True, + use_fine_topography=True, + static_inputs=static_inputs, + fine_coords=fine_coords, + ) + state = model.get_state() + # Simulate old format: fine_coords absent, but static_inputs fields have coords + del state["fine_coords"] + state["static_inputs"]["fields"][0]["coords"] = fine_coords.get_state() + + model_from_old_state = DiffusionModel.from_state(state) + assert model_from_old_state.fine_coords is not None + assert torch.equal( + model_from_old_state.fine_coords.lat.cpu(), fine_coords.lat.cpu() + ) + assert torch.equal( + model_from_old_state.fine_coords.lon.cpu(), fine_coords.lon.cpu() + ) def _get_diffusion_model( @@ -215,11 +265,18 @@ def _get_diffusion_model( predict_residual=True, use_fine_topography=True, static_inputs=None, + fine_coords: LatLonCoordinates | None = None, ): normalizer = PairedNormalizationConfig( NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), ) + if fine_coords is None: + fine_shape = ( + coarse_shape[0] * downscale_factor, + coarse_shape[1] * downscale_factor, + ) + fine_coords = make_fine_coords(fine_shape) return 
DiffusionModelConfig( module=DiffusionModuleRegistrySelector( @@ -237,7 +294,12 @@ def _get_diffusion_model( num_diffusion_generation_steps=3, predict_residual=predict_residual, use_fine_topography=use_fine_topography, - ).build(coarse_shape, downscale_factor, static_inputs=static_inputs) + ).build( + coarse_shape, + downscale_factor, + static_inputs=static_inputs, + fine_coords=fine_coords, + ) @pytest.mark.parametrize("predict_residual", [True, False]) @@ -246,6 +308,7 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph coarse_shape = (8, 16) fine_shape = (16, 32) batch_size = 2 + fine_coords = make_fine_coords(fine_shape) if use_fine_topography: static_inputs = make_static_inputs(fine_shape) batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) @@ -260,18 +323,18 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph predict_residual=predict_residual, use_fine_topography=use_fine_topography, static_inputs=static_inputs, + fine_coords=fine_coords, ) assert model._get_fine_shape(coarse_shape) == fine_shape optimization = OptimizationConfig().build(modules=[model.module], max_epochs=2) - train_outputs = model.train_on_batch(batch, static_inputs, optimization) + train_outputs = model.train_on_batch(batch, optimization) assert torch.allclose(train_outputs.target["x"], batch.fine.data["x"]) n_generated_samples = 2 generated_outputs = [ - model.generate_on_batch(batch, static_inputs) - for _ in range(n_generated_samples) + model.generate_on_batch(batch) for _ in range(n_generated_samples) ] for generated_output in generated_outputs: @@ -384,6 +447,7 @@ def test_model_error_cases(): ).build( coarse_shape, upscaling_factor, + fine_coords=make_fine_coords(fine_shape), ) batch = get_mock_paired_batch( [batch_size, *coarse_shape], [batch_size, *fine_shape] @@ -392,7 +456,7 @@ def test_model_error_cases(): # missing fine topography when model requires it batch.fine.topography = None with 
pytest.raises(ValueError): - model.generate_on_batch(batch, static_inputs=None) + model.generate_on_batch(batch) def test_DiffusionModel_generate_on_batch_no_target(): @@ -400,12 +464,14 @@ def test_DiffusionModel_generate_on_batch_no_target(): coarse_shape = (16, 16) downscale_factor = 2 static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, + fine_coords=fine_coords, ) batch_size = 2 @@ -417,7 +483,6 @@ def test_DiffusionModel_generate_on_batch_no_target(): samples = model.generate_on_batch_no_target( coarse_batch, - static_inputs=static_inputs, n_samples=n_generated_samples, ) @@ -437,7 +502,9 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): # Full fine domain: 64x64 covers inputs for both (8,8) and (32,32) coarse inputs # with a downscaling factor of 2 full_fine_size = 64 - static_inputs = make_static_inputs((full_fine_size, full_fine_size)) + full_fine_shape = (full_fine_size, full_fine_size) + static_inputs = make_static_inputs(full_fine_shape) + fine_coords = make_fine_coords(full_fine_shape) # need to build with static inputs to get the correct n_in_channels model = _get_diffusion_model( coarse_shape=coarse_shape, @@ -445,6 +512,7 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, + fine_coords=fine_coords, ) n_ensemble = 2 batch_size = 2 @@ -457,9 +525,7 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): coarse_batch = make_batch_data( (batch_size, *alternative_input_shape), coarse_lat, coarse_lon ) - samples = model.generate_on_batch_no_target( - coarse_batch, n_samples=n_ensemble, static_inputs=None - ) + samples = model.generate_on_batch_no_target(coarse_batch, n_samples=n_ensemble) assert 
samples["x"].shape == ( batch_size, @@ -498,6 +564,7 @@ def test_lognorm_noise_backwards_compatibility(): model = model_config.build( (32, 32), 2, + fine_coords=make_fine_coords((64, 64)), ) state = model.get_state() @@ -536,6 +603,71 @@ def test_noise_config_error(): ) +def test_get_fine_coords_for_batch(): + # Model trained on full coarse (8x16) / fine (16x32) grid + coarse_shape = (8, 16) + fine_shape = (16, 32) + downscale_factor = 2 + static_inputs = make_static_inputs(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=downscale_factor, + use_fine_topography=True, + static_inputs=static_inputs, + ) + + # Build a batch covering a spatial patch: middle 4 coarse lats and 8 coarse lons. + full_coarse_lat = _get_monotonic_coordinate(coarse_shape[0], stop=fine_shape[0]) + full_coarse_lon = _get_monotonic_coordinate(coarse_shape[1], stop=fine_shape[1]) + patch_coarse_lat = full_coarse_lat[2:6].tolist() # [5, 7, 9, 11] + patch_coarse_lon = full_coarse_lon[4:12].tolist() # [9, 11, ..., 23] + batch = make_batch_data((2, 4, 8), patch_coarse_lat, patch_coarse_lon) + + result = model.get_fine_coords_for_batch(batch) + + expected_lat = model.fine_coords.lat[4:12] + expected_lon = model.fine_coords.lon[8:24] + assert torch.allclose(result.lat, expected_lat) + assert torch.allclose(result.lon, expected_lon) + + +def test_checkpoint_model_build_with_fine_coordinates_path(tmp_path): + """Old-format checkpoint (no fine_coords key, no coords in static_inputs) + should load correctly when fine_coordinates_path is provided.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + fine_coords = make_fine_coords(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + use_fine_topography=False, + fine_coords=fine_coords, + ) + # Simulate old checkpoint: no fine_coords stored + state = model.get_state() + del state["fine_coords"] + checkpoint_path = tmp_path / "test.ckpt" + torch.save({"model": state}, 
checkpoint_path) + + # Write fine coords to a netCDF file for the loader to read + coords_path = tmp_path / "fine_coords.nc" + ds = xr.Dataset( + coords={ + "lat": fine_coords.lat.cpu().numpy(), + "lon": fine_coords.lon.cpu().numpy(), + } + ) + ds.to_netcdf(coords_path) + + loaded_model = CheckpointModelConfig( + checkpoint_path=str(checkpoint_path), + fine_coordinates_path=str(coords_path), + ).build() + + assert torch.equal(loaded_model.fine_coords.lat.cpu(), fine_coords.lat.cpu()) + assert torch.equal(loaded_model.fine_coords.lon.cpu(), fine_coords.lon.cpu()) + + def test_checkpoint_config_topography_raises(): with pytest.raises(ValueError): CheckpointModelConfig( @@ -548,12 +680,14 @@ def test_checkpoint_model_build_raises_when_checkpoint_has_static_inputs(tmp_pat coarse_shape = (8, 16) fine_shape = (16, 32) static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, + fine_coords=fine_coords, ) checkpoint_path = tmp_path / "test.ckpt" torch.save({"model": model.get_state()}, checkpoint_path) diff --git a/fme/downscaling/test_predict.py b/fme/downscaling/test_predict.py index 6626d5a01..8101b293e 100644 --- a/fme/downscaling/test_predict.py +++ b/fme/downscaling/test_predict.py @@ -9,7 +9,11 @@ from fme.core.testing.wandb import mock_wandb from fme.downscaling import predict from fme.downscaling.data import load_static_inputs -from fme.downscaling.models import DiffusionModelConfig, PairedNormalizationConfig +from fme.downscaling.models import ( + DiffusionModelConfig, + PairedNormalizationConfig, + load_fine_coords_from_path, +) from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector from fme.downscaling.test_models import LinearDownscaling from fme.downscaling.test_utils import data_paths_helper @@ -97,7 +101,7 @@ def create_predictor_config( 
out_path = tmp_path / "predictor-config.yaml" with open(out_path, "w") as file: yaml.dump(config, file) - return out_path, f"{paths.fine}/data.nc" + return out_path, paths def test_predictor_runs(tmp_path, very_fast_only: bool): @@ -106,15 +110,18 @@ def test_predictor_runs(tmp_path, very_fast_only: bool): n_samples = 2 coarse_shape = (4, 4) downscale_factor = 2 - predictor_config_path, fine_data_path = create_predictor_config( + predictor_config_path, paths = create_predictor_config( tmp_path, n_samples, ) + fine_data_path = f"{paths.fine}/data.nc" + fine_coords = load_fine_coords_from_path(fine_data_path) model_config = get_model_config(coarse_shape, downscale_factor=downscale_factor) model = model_config.build( coarse_shape=coarse_shape, downscale_factor=downscale_factor, static_inputs=load_static_inputs({"HGTsfc": fine_data_path}), + fine_coords=fine_coords, ) with open(predictor_config_path) as f: predictor_config = yaml.safe_load(f) @@ -147,7 +154,7 @@ def test_predictor_renaming( coarse_shape = (4, 4) downscale_factor = 2 renaming = {"var0": "var0_renamed", "var1": "var1_renamed"} - predictor_config_path, _ = create_predictor_config( + predictor_config_path, paths = create_predictor_config( tmp_path, n_samples, model_renaming=renaming, @@ -155,10 +162,13 @@ def test_predictor_renaming( "rename": {"var0": "var0_renamed", "var1": "var1_renamed"} }, ) + fine_coords = load_fine_coords_from_path(f"{paths.fine}/data.nc") model_config = get_model_config( coarse_shape, downscale_factor, use_fine_topography=False ) - model = model_config.build(coarse_shape=coarse_shape, downscale_factor=2) + model = model_config.build( + coarse_shape=coarse_shape, downscale_factor=2, fine_coords=fine_coords + ) with open(predictor_config_path) as f: predictor_config = yaml.safe_load(f) os.makedirs( diff --git a/fme/downscaling/test_utils.py b/fme/downscaling/test_utils.py index 4f34a4451..b1546e044 100644 --- a/fme/downscaling/test_utils.py +++ b/fme/downscaling/test_utils.py @@ 
-78,4 +78,5 @@ def data_paths_helper( create_test_data_on_disk( coarse_path / "data.nc", dim_sizes.coarse, variable_names, coords ) + # TODO: should this return the full filename instead of just the directory? return FineResCoarseResPair[str](fine=fine_path, coarse=coarse_path) diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index 271d884be..5a0c25666 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -174,11 +174,11 @@ def train_one_epoch(self) -> None: self.train_data, random_offset=True, shuffle=True ) outputs = None - for i, (batch, static_inputs) in enumerate(train_batch_generator): + for i, batch in enumerate(train_batch_generator): self.num_batches_seen += 1 if i % 10 == 0: logging.info(f"Training on batch {i + 1}") - outputs = self.model.train_on_batch(batch, static_inputs, self.optimization) + outputs = self.model.train_on_batch(batch, self.optimization) self.ema(self.model.modules) with torch.no_grad(): train_aggregator.record_batch( @@ -250,10 +250,8 @@ def valid_one_epoch(self) -> dict[str, float]: validation_batch_generator = self._get_batch_generator( self.validation_data, random_offset=False, shuffle=False ) - for batch, static_inputs in validation_batch_generator: - outputs = self.model.train_on_batch( - batch, static_inputs, self.null_optimization - ) + for batch in validation_batch_generator: + outputs = self.model.train_on_batch(batch, self.null_optimization) validation_aggregator.record_batch( outputs=outputs, coarse=batch.coarse.data, @@ -261,7 +259,6 @@ def valid_one_epoch(self) -> dict[str, float]: ) generated_outputs = self.model.generate_on_batch( batch, - static_inputs=static_inputs, n_samples=self.config.generate_n_samples, ) # Add sample dimension to coarse values for generation comparison @@ -429,12 +426,10 @@ def build(self) -> Trainer: train_data: PairedGriddedData = self.train_data.build( train=True, requirements=self.model.data_requirements, - static_inputs=static_inputs, ) validation_data: 
PairedGriddedData = self.validation_data.build( train=False, requirements=self.model.data_requirements, - static_inputs=static_inputs, ) if self.coarse_patch_extent_lat and self.coarse_patch_extent_lon: model_coarse_shape = ( @@ -448,6 +443,7 @@ def build(self) -> Trainer: model_coarse_shape, train_data.downscale_factor, static_inputs=static_inputs, + fine_coords=train_data.fine_coords, ) optimization = self.optimization.build(