diff --git a/fme/downscaling/data/config.py b/fme/downscaling/data/config.py index aec2ff7c0..eda213fbc 100644 --- a/fme/downscaling/data/config.py +++ b/fme/downscaling/data/config.py @@ -23,7 +23,6 @@ PairedBatchData, PairedGriddedData, ) -from fme.downscaling.data.static import StaticInputs from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range from fme.downscaling.requirements import DataRequirements @@ -132,18 +131,6 @@ def _full_configs( return all_configs -def _check_fine_res_static_input_compatibility( - static_input_shape: tuple[int, int], data_coords_shape: tuple[int, int] -) -> None: - for static, coord in zip(static_input_shape, data_coords_shape): - if static != coord: - raise ValueError( - f"Static input shape {static_input_shape} is not compatible with " - f"data coordinates shape {data_coords_shape}. Static input dimensions " - "must match fine resolution coordinate dimensions." - ) - - @dataclasses.dataclass class DataLoaderConfig: """ @@ -164,8 +151,9 @@ class DataLoaderConfig: (For multi-GPU runtime, it's the number of workers per GPU.) strict_ensemble: Whether to enforce that the datasets to be concatened have the same dimensions and coordinates. - topography: Deprecated field for specifying the topography dataset. Now - provided via build method's `static_inputs` argument. + topography: Deprecated field for specifying the topography dataset. + StaticInput data are expected to be stored and serialized within a + model through the Trainer build process. lat_extent: The latitude extent to use for the dataset specified in degrees, limited to (-88.0, 88.0). The extent is inclusive, so the start and stop values are included in the extent. Defaults to [-66, 70] which @@ -202,8 +190,8 @@ def __post_init__(self): if self.topography is not None: raise ValueError( "The `topography` field on DataLoaderConfig is deprecated and will be " - "removed in a future release. 
Pass static_inputs via build's " - "`static_inputs` argument instead." + "removed in a future release. `StaticInputs` are now stored within " + "the model when it is first built and trained." ) @property @@ -236,40 +224,6 @@ def get_xarray_dataset( strict_ensemble=self.strict_ensemble, ) - def build_static_inputs( - self, - coarse_coords: LatLonCoordinates, - requires_topography: bool, - static_inputs: StaticInputs | None = None, - ) -> StaticInputs | None: - if requires_topography is False: - return None - if static_inputs is not None: - # TODO: change to use full static inputs list - full_static_inputs = static_inputs - else: - raise ValueError( - "Static inputs required for this model, but no static inputs " - "datasets were specified in the trainer configuration or provided " - "in model checkpoint." - ) - - # Fine grid boundaries are adjusted to exactly match the coarse grid - fine_lat_interval = adjust_fine_coord_range( - self.lat_extent, - full_coarse_coord=coarse_coords.lat, - full_fine_coord=full_static_inputs.coords.lat, - ) - fine_lon_interval = adjust_fine_coord_range( - self.lon_extent, - full_coarse_coord=coarse_coords.lon, - full_fine_coord=full_static_inputs.coords.lon, - ) - subset_static_inputs = full_static_inputs.subset_latlon( - lat_interval=fine_lat_interval, lon_interval=fine_lon_interval - ) - return subset_static_inputs.to_device() - def build_batchitem_dataset( self, dataset: XarrayConcat, @@ -301,14 +255,7 @@ def build( self, requirements: DataRequirements, dist: Distributed | None = None, - static_inputs: StaticInputs | None = None, ) -> GriddedData: - # TODO: static_inputs_from_checkpoint is currently passed from the model - # to allow loading fine topography when no fine data is available. 
- # See PR https://github.com/ai2cm/ace/pull/728 - # In the future we could disentangle this dependency between the data loader - # and model by enabling the built GriddedData objects to take in full static - # input fields and subset them to the same coordinate range as data. xr_dataset, properties = self.get_xarray_dataset( names=requirements.coarse_names, n_timesteps=1 ) @@ -316,7 +263,6 @@ def build( raise ValueError( "Downscaling data loader only supports datasets with latlon coords." ) - latlon_coords = properties.horizontal_coordinates dataset = self.build_batchitem_dataset( dataset=xr_dataset, properties=properties, @@ -343,14 +289,8 @@ def build( persistent_workers=True if self.num_data_workers > 0 else False, ) example = dataset[0] - subset_static_inputs = self.build_static_inputs( - coarse_coords=latlon_coords, - requires_topography=requirements.use_fine_topography, - static_inputs=static_inputs, - ) return GriddedData( _loader=dataloader, - static_inputs=subset_static_inputs, shape=example.horizontal_shape, dims=example.latlon_coordinates.dims, variable_metadata=dataset.variable_metadata, @@ -395,7 +335,6 @@ class PairedDataLoaderConfig: time dimension. Useful to include longer sequences of small data for testing. topography: Deprecated field for specifying the topography dataset. - Now provided via build method's `static_inputs` argument. sample_with_replacement: If provided, the dataset will be sampled randomly with replacement to the given size each period, instead of retrieving each sample once (either shuffled or not). @@ -427,8 +366,8 @@ def __post_init__(self): if self.topography is not None: raise ValueError( "The `topography` field on PairedDataLoaderConfig is deprecated and " - "will be removed in a future release. Pass static_inputs via the " - "build method's `static_inputs` argument instead." + "will be removed in a future release. `StaticInputs` are now stored " + "within the model when it is first built and trained." 
) def _first_data_config( @@ -468,14 +407,7 @@ def build( train: bool, requirements: DataRequirements, dist: Distributed | None = None, - static_inputs: StaticInputs | None = None, ) -> PairedGriddedData: - # TODO: static_inputs_from_checkpoint is currently passed from the model - # to allow loading fine topography when no fine data is available. - # See PR https://github.com/ai2cm/ace/pull/728 - # In the future we could disentangle this dependency between the data loader - # and model by enabling the built GriddedData objects to take in full static - # input fields and subset them to the same coordinate range as data. if dist is None: dist = Distributed.get_instance() @@ -537,25 +469,6 @@ def build( full_fine_coord=properties_fine.horizontal_coordinates.lon, ) - if requirements.use_fine_topography: - if static_inputs is None: - raise ValueError( - "Model requires static inputs (use_fine_topography=True)," - " but no static inputs were provided to the data loader's" - " build method." - ) - - static_inputs = static_inputs.to_device() - _check_fine_res_static_input_compatibility( - static_inputs.shape, - properties_fine.horizontal_coordinates.shape, - ) - static_inputs = static_inputs.subset_latlon( - lat_interval=fine_lat_extent, lon_interval=fine_lon_extent - ) - else: - static_inputs = None - dataset_fine_subset = HorizontalSubsetDataset( dataset_fine, properties=properties_fine, @@ -611,7 +524,6 @@ def build( return PairedGriddedData( _loader=dataloader, - static_inputs=static_inputs, coarse_shape=example.coarse.horizontal_shape, downscale_factor=example.downscale_factor, dims=example.fine.latlon_coordinates.dims, diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index 38e7f9191..bce3d59c4 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -20,14 +20,12 @@ from fme.core.generics.data import SizedMap from fme.core.typing_ import TensorMapping from fme.downscaling.data.patching import Patch, 
get_patches -from fme.downscaling.data.static import StaticInputs from fme.downscaling.data.utils import ( BatchedLatLonCoordinates, ClosedInterval, check_leading_dim, expand_and_fold_tensor, get_offset, - null_generator, paired_shuffle, scale_tuple, ) @@ -298,7 +296,6 @@ class GriddedData: dims: list[str] variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex - static_inputs: StaticInputs | None @property def loader(self) -> DataLoader[BatchItem]: @@ -307,27 +304,8 @@ def on_device(batch: BatchItem) -> BatchItem: return SizedMap(on_device, self._loader) - @property - def topography_downscale_factor(self) -> int | None: - if self.static_inputs: - if ( - self.static_inputs.shape[0] % self.shape[0] != 0 - or self.static_inputs.shape[1] % self.shape[1] != 0 - ): - raise ValueError( - "Static inputs shape must be evenly divisible by data shape. " - f"Got static inputs with shape {self.static_inputs.shape} " - f"and data with shape {self.shape}" - ) - return self.static_inputs.shape[0] // self.shape[0] - else: - return None - - def get_generator( - self, - ) -> Iterator[tuple["BatchData", StaticInputs | None]]: - for batch in self.loader: - yield (batch, self.static_inputs) + def get_generator(self) -> Iterator["BatchData"]: + yield from self.loader def get_patched_generator( self, @@ -335,20 +313,18 @@ def get_patched_generator( overlap: int = 0, drop_partial_patches: bool = True, random_offset: bool = False, - ) -> Iterator[tuple["BatchData", StaticInputs | None]]: + ) -> Iterator["BatchData"]: patched_generator = patched_batch_gen_from_loader( loader=self.loader, - static_inputs=self.static_inputs, coarse_yx_extent=self.shape, coarse_yx_patch_extent=yx_patch_extent, - downscale_factor=self.topography_downscale_factor, coarse_overlap=overlap, drop_partial_patches=drop_partial_patches, random_offset=random_offset, ) return cast( - Iterator[tuple[BatchData, StaticInputs | None]], + Iterator[BatchData], patched_generator, ) @@ -361,7 +337,6 @@ class 
PairedGriddedData: dims: list[str] variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex - static_inputs: StaticInputs | None @property def loader(self) -> DataLoader[PairedBatchItem]: @@ -370,11 +345,8 @@ def on_device(batch: PairedBatchItem) -> PairedBatchItem: return SizedMap(on_device, self._loader) - def get_generator( - self, - ) -> Iterator[tuple["PairedBatchData", StaticInputs | None]]: - for batch in self.loader: - yield (batch, self.static_inputs) + def get_generator(self) -> Iterator["PairedBatchData"]: + yield from self.loader def get_patched_generator( self, @@ -383,10 +355,9 @@ def get_patched_generator( drop_partial_patches: bool = True, random_offset: bool = False, shuffle: bool = False, - ) -> Iterator[tuple["PairedBatchData", StaticInputs | None]]: + ) -> Iterator["PairedBatchData"]: patched_generator = patched_batch_gen_from_paired_loader( self.loader, - self.static_inputs, coarse_yx_extent=self.coarse_shape, coarse_yx_patch_extent=coarse_yx_patch_extent, downscale_factor=self.downscale_factor, @@ -396,7 +367,7 @@ def get_patched_generator( shuffle=shuffle, ) return cast( - Iterator[tuple[PairedBatchData, StaticInputs | None]], + Iterator[PairedBatchData], patched_generator, ) @@ -669,6 +640,9 @@ def __iter__(self): return iter(indices[start:end]) +# Passing downscale_factor=None means no fine patches are needed. The no-target +# loader still calls _get_paired_patches, as the paired loader does, so that both +# cases share the same coincident offset logic. 
def _get_paired_patches( coarse_yx_extent: tuple[int, int], coarse_yx_patch_extent: tuple[int, int], @@ -713,44 +687,28 @@ def _get_paired_patches( def patched_batch_gen_from_loader( loader: DataLoader[BatchItem], - static_inputs: StaticInputs | None, coarse_yx_extent: tuple[int, int], coarse_yx_patch_extent: tuple[int, int], - downscale_factor: int | None, coarse_overlap: int = 0, drop_partial_patches: bool = True, random_offset: bool = False, shuffle: bool = False, -) -> Iterator[tuple[BatchData, StaticInputs | None]]: +) -> Iterator[BatchData]: for batch in loader: - coarse_patches, fine_patches = _get_paired_patches( + coarse_patches, _ = _get_paired_patches( coarse_yx_extent=coarse_yx_extent, coarse_yx_patch_extent=coarse_yx_patch_extent, coarse_overlap=coarse_overlap, - downscale_factor=downscale_factor, + downscale_factor=None, random_offset=random_offset, shuffle=shuffle, drop_partial_patches=drop_partial_patches, ) - batch_data_patches = batch.generate_from_patches(coarse_patches) - - if static_inputs is not None: - if fine_patches is None: - raise ValueError( - "Topography provided but downscale_factor is None, cannot " - "generate fine patches." 
- ) - static_inputs_patches = static_inputs.generate_from_patches(fine_patches) - else: - static_inputs_patches = null_generator(len(coarse_patches)) - - # Combine outputs from both generators - yield from zip(batch_data_patches, static_inputs_patches) + yield from batch.generate_from_patches(coarse_patches) def patched_batch_gen_from_paired_loader( loader: DataLoader[PairedBatchItem], - static_inputs: StaticInputs | None, coarse_yx_extent: tuple[int, int], coarse_yx_patch_extent: tuple[int, int], downscale_factor: int, @@ -758,7 +716,7 @@ def patched_batch_gen_from_paired_loader( drop_partial_patches: bool = True, random_offset: bool = False, shuffle: bool = False, -) -> Iterator[tuple[PairedBatchData, StaticInputs | None]]: +) -> Iterator[PairedBatchData]: for batch in loader: coarse_patches, fine_patches = _get_paired_patches( coarse_yx_extent=coarse_yx_extent, @@ -769,17 +727,4 @@ def patched_batch_gen_from_paired_loader( shuffle=shuffle, drop_partial_patches=drop_partial_patches, ) - batch_data_patches = batch.generate_from_patches(coarse_patches, fine_patches) - - if static_inputs is not None: - if fine_patches is None: - raise ValueError( - "Static inputs provided but downscale_factor is None, cannot " - "generate fine patches." 
- ) - static_inputs_patches = static_inputs.generate_from_patches(fine_patches) - else: - static_inputs_patches = null_generator(len(coarse_patches)) - - # Combine outputs from both generators - yield from zip(batch_data_patches, static_inputs_patches) + yield from batch.generate_from_patches(coarse_patches, fine_patches) diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index 2c7199cf9..4a860174c 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -1,12 +1,10 @@ import dataclasses -from collections.abc import Generator, Iterator import torch import xarray as xr from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device -from fme.downscaling.data.patching import Patch from fme.downscaling.data.utils import ClosedInterval @@ -53,11 +51,6 @@ def to_device(self) -> "StaticInput": ), ) - def _apply_patch(self, patch: Patch): - return self._latlon_index_slice( - lat_slice=patch.input_slice.y, lon_slice=patch.input_slice.x - ) - def _latlon_index_slice( self, lat_slice: slice, @@ -73,13 +66,6 @@ def _latlon_index_slice( coords=sliced_latlon, ) - def generate_from_patches( - self, - patches: list[Patch], - ) -> Generator["StaticInput", None, None]: - for patch in patches: - yield self._apply_patch(patch) - def get_state(self) -> dict: return { "data": self.data.cpu(), @@ -162,15 +148,6 @@ def subset_latlon( def to_device(self) -> "StaticInputs": return StaticInputs(fields=[field.to_device() for field in self.fields]) - def generate_from_patches( - self, - patches: list[Patch], - ) -> Iterator["StaticInputs"]: - for patch in patches: - yield StaticInputs( - fields=[field._apply_patch(patch) for field in self.fields] - ) - def get_state(self) -> dict: return { "fields": [field.get_state() for field in self.fields], diff --git a/fme/downscaling/data/test_config.py b/fme/downscaling/data/test_config.py index 1aaff8577..f916abadd 100644 --- a/fme/downscaling/data/test_config.py +++ 
b/fme/downscaling/data/test_config.py @@ -1,7 +1,6 @@ import dataclasses import pytest -import torch from fme.core.dataset.merged import MergeNoConcatDatasetConfig from fme.core.dataset.xarray import XarrayDataConfig @@ -10,25 +9,11 @@ PairedDataLoaderConfig, XarrayEnsembleDataConfig, ) -from fme.downscaling.data.static import StaticInput, StaticInputs -from fme.downscaling.data.utils import ClosedInterval, LatLonCoordinates +from fme.downscaling.data.utils import ClosedInterval from fme.downscaling.requirements import DataRequirements from fme.downscaling.test_utils import data_paths_helper -def get_static_inputs(shape=(8, 8)): - return StaticInputs( - fields=[ - StaticInput( - data=torch.ones(shape), - coords=LatLonCoordinates( - lat=torch.ones(shape[0]), lon=torch.ones(shape[1]) - ), - ) - ] - ) - - @pytest.mark.parametrize( "fine_engine, coarse_engine, num_data_workers, expected", [ @@ -78,9 +63,7 @@ def test_DataLoaderConfig_build(tmp_path, very_fast_only: bool): lat_extent=ClosedInterval(1, 4), lon_extent=ClosedInterval(0, 3), ) - data = data_config.build( - requirements=requirements, static_inputs=get_static_inputs(shape=(8, 8)) - ) + data = data_config.build(requirements=requirements) batch = next(iter(data.loader)) # lat/lon midpoints are on (0.5, 1.5, ...) 
assert batch.data["var0"].shape == (2, 3, 3) @@ -152,9 +135,7 @@ def test_DataLoaderConfig_includes_merge(tmp_path, very_fast_only: bool): lon_extent=ClosedInterval(0, 3), ) - data = data_config.build( - requirements=requirements, static_inputs=get_static_inputs(shape=(8, 8)) - ) + data = data_config.build(requirements=requirements) # XarrayDataConfig + MergeNoConcatDatasetConfig each # contribute 4 timesteps = 8 total assert len(data.loader) == 4 # 8 samples / batch_size 2 diff --git a/fme/downscaling/data/test_patching.py b/fme/downscaling/data/test_patching.py index 7630722c4..fb6cbc7de 100644 --- a/fme/downscaling/data/test_patching.py +++ b/fme/downscaling/data/test_patching.py @@ -2,8 +2,7 @@ import pytest import torch -from fme.core.device import get_device -from fme.downscaling.data import PairedBatchData, StaticInput, StaticInputs +from fme.downscaling.data import PairedBatchData from fme.downscaling.data.datasets import patched_batch_gen_from_paired_loader from fme.downscaling.data.patching import ( _divide_into_slices, @@ -115,19 +114,10 @@ def test_paired_patches_with_random_offset_consistent(overlap): full_coarse_coords = full_data.coarse.latlon_coordinates full_fine_coords = full_data.fine.latlon_coordinates - topography_data = torch.randn( - coarse_shape[0] * downscale_factor, - coarse_shape[1] * downscale_factor, - device=get_device(), - ) - topography = StaticInputs( - fields=[StaticInput(data=topography_data, coords=full_fine_coords[0])] - ) y_offsets = [] x_offsets = [] batch_generator = patched_batch_gen_from_paired_loader( loader=loader, - static_inputs=topography, coarse_yx_extent=coarse_shape, coarse_yx_patch_extent=(10, 10), downscale_factor=downscale_factor, @@ -136,7 +126,7 @@ def test_paired_patches_with_random_offset_consistent(overlap): random_offset=True, ) paired_batch: PairedBatchData - for paired_batch, _ in batch_generator: # type: ignore + for paired_batch in batch_generator: assert paired_batch.coarse.data["x"].shape == 
(batch_size, 10, 10) assert paired_batch.fine.data["x"].shape == (batch_size, 20, 20) @@ -177,19 +167,9 @@ def test_paired_patches_shuffle(shuffle): loader = _mock_data_loader( 10, *coarse_shape, downscale_factor=downscale_factor, batch_size=batch_size ) - topography_data = torch.randn( - coarse_shape[0] * downscale_factor, - coarse_shape[1] * downscale_factor, - device=get_device(), - ) - fine_coords = next(iter(loader)).fine.latlon_coordinates[0] - static_inputs = StaticInputs( - fields=[StaticInput(data=topography_data, coords=fine_coords)] - ) generator0 = patched_batch_gen_from_paired_loader( loader=loader, - static_inputs=static_inputs, coarse_yx_extent=coarse_shape, coarse_yx_patch_extent=(2, 2), downscale_factor=downscale_factor, @@ -200,7 +180,6 @@ def test_paired_patches_shuffle(shuffle): ) generator1 = patched_batch_gen_from_paired_loader( loader=loader, - static_inputs=static_inputs, coarse_yx_extent=coarse_shape, coarse_yx_patch_extent=(2, 2), downscale_factor=downscale_factor, @@ -212,28 +191,14 @@ def test_paired_patches_shuffle(shuffle): patches0: list[PairedBatchData] = [] patches1: list[PairedBatchData] = [] - topography0: list[torch.Tensor] = [] - topography1: list[torch.Tensor] = [] for i in range(4): - p0, t0 = next(generator0) - patches0.append(p0) # type: ignore - topography0.append(t0) - p1, t1 = next(generator1) - patches1.append(p1) # type: ignore - topography1.append(t1) + patches0.append(next(generator0)) # type: ignore + patches1.append(next(generator1)) # type: ignore data0 = torch.concat([patch.coarse.data["x"] for patch in patches0], dim=0) data1 = torch.concat([patch.coarse.data["x"] for patch in patches1], dim=0) - topo_concat_0 = torch.concat( - [t0.fields[0].data for t0 in topography0 if t0 is not None], dim=0 - ) - topo_concat_1 = torch.concat( - [t1.fields[0].data for t1 in topography1 if t1 is not None], dim=0 - ) if shuffle: assert not torch.equal(data0, data1) - assert not torch.equal(topo_concat_0, topo_concat_1) else: 
assert torch.equal(data0, data1) - assert torch.equal(topo_concat_0, topo_concat_1) diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index 8aeace714..edfc0e080 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -2,7 +2,6 @@ import torch from fme.core.coordinates import LatLonCoordinates -from fme.downscaling.data.patching import Patch, _HorizontalSlice from .static import StaticInput, StaticInputs from .utils import ClosedInterval @@ -48,75 +47,6 @@ def test_subset_latlon(): assert torch.allclose(subset_topo.data, expected_data) -def test_Topography_generate_from_patches(): - output_slice = _HorizontalSlice(y=slice(None), x=slice(None)) - patches = [ - Patch( - input_slice=_HorizontalSlice(y=slice(1, 3), x=slice(None, None)), - output_slice=output_slice, - ), - Patch( - input_slice=_HorizontalSlice(y=slice(0, 2), x=slice(2, 3)), - output_slice=output_slice, - ), - ] - topography = StaticInput( - torch.arange(16).reshape(4, 4), - LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) - topo_patch_generator = topography.generate_from_patches(patches) - generated_patches = [] - for topo_patch in topo_patch_generator: - generated_patches.append(topo_patch) - assert len(generated_patches) == 2 - assert torch.equal( - generated_patches[0].data, torch.tensor([[4, 5, 6, 7], [8, 9, 10, 11]]) - ) - assert torch.equal(generated_patches[1].data, torch.tensor([[2], [6]])) - - -def test_StaticInputs_generate_from_patches(): - output_slice = _HorizontalSlice(y=slice(None), x=slice(None)) - patches = [ - Patch( - input_slice=_HorizontalSlice(y=slice(1, 3), x=slice(None, None)), - output_slice=output_slice, - ), - Patch( - input_slice=_HorizontalSlice(y=slice(0, 2), x=slice(2, 3)), - output_slice=output_slice, - ), - ] - data = torch.arange(16).reshape(4, 4) - topography = StaticInput( - data, - LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) - land_frac = StaticInput( - data * -1.0, - 
LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) - static_inputs = StaticInputs([topography, land_frac]) - static_inputs_patch_generator = static_inputs.generate_from_patches(patches) - generated_patches = [] - for static_inputs_patch in static_inputs_patch_generator: - generated_patches.append(static_inputs_patch) - - assert len(generated_patches) == 2 - - expected_topography_patch_0 = torch.tensor([[4, 5, 6, 7], [8, 9, 10, 11]]) - expected_topography_patch_1 = torch.tensor([[2], [6]]) - - # first index is the patch, second is the static input field within - # the StaticInputs container - assert torch.equal(generated_patches[0][0].data, expected_topography_patch_0) - assert torch.equal(generated_patches[1][0].data, expected_topography_patch_1) - - # land_frac field values are -1 * topography - assert torch.equal(generated_patches[0][1].data, expected_topography_patch_0 * -1.0) - assert torch.equal(generated_patches[1][1].data, expected_topography_patch_1 * -1.0) - - def test_StaticInputs_serialize(): data = torch.arange(16).reshape(4, 4) topography = StaticInput( diff --git a/fme/downscaling/evaluator.py b/fme/downscaling/evaluator.py index c2813ae37..e7a19727b 100644 --- a/fme/downscaling/evaluator.py +++ b/fme/downscaling/evaluator.py @@ -15,7 +15,6 @@ from fme.downscaling.data import ( PairedDataLoaderConfig, PairedGriddedData, - StaticInputs, enforce_lat_bounds, ) from fme.downscaling.models import CheckpointModelConfig, DiffusionModel @@ -56,12 +55,10 @@ def run(self): else: batch_generator = self.data.get_generator() - for i, (batch, static_inputs) in enumerate(batch_generator): + for i, batch in enumerate(batch_generator): with torch.no_grad(): logging.info(f"Generating predictions on batch {i + 1}") - outputs = self.model.generate_on_batch( - batch, static_inputs, n_samples=self.n_samples - ) + outputs = self.model.generate_on_batch(batch, n_samples=self.n_samples) logging.info("Recording diagnostics to aggregator") # Add sample dimension to coarse 
values for generation comparison coarse = {k: v.unsqueeze(1) for k, v in batch.coarse.data.items()} @@ -107,7 +104,7 @@ def __init__( def run(self): logging.info(f"Running {self.event_name} event evaluation") - batch, static_inputs = next(iter(self.data.get_generator())) + batch = next(iter(self.data.get_generator())) sample_agg = PairedSampleAggregator( target=batch[0].fine.data, coarse=batch[0].coarse.data, @@ -125,9 +122,7 @@ def run(self): f"Generating samples {start_idx} to {end_idx} " f"for event {self.event_name}" ) - outputs = self.model.generate_on_batch( - batch, static_inputs, n_samples=end_idx - start_idx - ) + outputs = self.model.generate_on_batch(batch, n_samples=end_idx - start_idx) sample_agg.record_batch(outputs.prediction) to_log = sample_agg.get_wandb() @@ -152,7 +147,6 @@ def get_paired_gridded_data( self, base_data_config: PairedDataLoaderConfig, requirements: DataRequirements, - static_inputs_from_checkpoint: StaticInputs | None = None, ) -> PairedGriddedData: enforce_lat_bounds(self.lat_extent) time_slice = self._time_selection_slice @@ -173,7 +167,6 @@ def get_paired_gridded_data( return event_data_config.build( train=False, requirements=requirements, - static_inputs=static_inputs_from_checkpoint, ) @@ -200,7 +193,6 @@ def _build_default_evaluator(self) -> Evaluator: dataset = self.data.build( train=False, requirements=self.model.data_requirements, - static_inputs=model.static_inputs, ) evaluator_model: DiffusionModel | PatchPredictor if self.patch.divide_generation and self.patch.composite_prediction: @@ -237,7 +229,6 @@ def _build_event_evaluator( dataset = event_config.get_paired_gridded_data( base_data_config=self.data, requirements=self.model.data_requirements, - static_inputs_from_checkpoint=model.static_inputs, ) if (dataset.coarse_shape[0] > model.coarse_shape[0]) or ( diff --git a/fme/downscaling/inference/inference.py b/fme/downscaling/inference/inference.py index 7b5f1a361..5ba046140 100644 --- 
a/fme/downscaling/inference/inference.py +++ b/fme/downscaling/inference/inference.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, field import dacite +import numpy as np import torch import yaml @@ -10,7 +11,7 @@ from fme.core.generics.trainer import count_parameters from fme.core.logging_utils import LoggingConfig -from ..data import DataLoaderConfig, StaticInputs +from ..data import DataLoaderConfig from ..models import CheckpointModelConfig, DiffusionModel from ..predictors import PatchPredictionConfig, PatchPredictor from .output import DownscalingOutput, EventConfig, TimeRangeConfig @@ -50,7 +51,7 @@ def run_all(self): def _get_generation_model( self, - static_inputs: StaticInputs, + input_shape: tuple[int, int], output: DownscalingOutput, ) -> DiffusionModel | PatchPredictor: """ @@ -60,20 +61,22 @@ def _get_generation_model( the user to use patching for larger domains because that provides better generations. """ - model_patch_shape = self.model.fine_shape - actual_shape = tuple(static_inputs.shape) + model_patch_shape = self.model.coarse_shape - if model_patch_shape == actual_shape: + if model_patch_shape == input_shape: # short circuit, no patching necessary return self.model elif any( expected > actual - for expected, actual in zip(model_patch_shape, actual_shape) + for expected, actual in zip(model_patch_shape, input_shape) ): # we don't support generating regions smaller than the model patch size raise ValueError( f"Model coarse shape {model_patch_shape} is larger than " - f"actual topography shape {actual_shape} for output {output.name}." + f"actual input shape {input_shape} for output {output.name}." + " We do not support generating outputs with a smaller spatial extent" + " than the model's trained patch size. Please adjust the spatial extent" + " to be at least as large as the model's input patch size." 
) elif output.patch.needs_patch_predictor: # Use a patch predictor @@ -86,14 +89,14 @@ def _get_generation_model( # User should enable patching raise ValueError( f"Model coarse shape {model_patch_shape} does not match " - f"actual input shape {actual_shape} for output {output.name}, " + f"actual input shape {input_shape} for output {output.name}, " "and patch prediction is not configured. Generation for larger domains " "requires patch prediction." ) def _on_device_generator(self, loader): - for loaded_item, topography in loader: - yield loaded_item.to_device(), topography.to_device() + for loaded_item in loader: + yield loaded_item.to_device() def run_output_generation(self, output: DownscalingOutput): """Execute the generation loop for this output.""" @@ -105,20 +108,20 @@ def run_output_generation(self, output: DownscalingOutput): total_batches = len(output.data.loader) loaded_item: LoadedSliceWorkItem - static_inputs: StaticInputs - for i, (loaded_item, static_inputs) in enumerate(output.data.get_generator()): + for i, loaded_item in enumerate(output.data.get_generator()): + input_shape = loaded_item.batch.horizontal_shape + if model is None: + model = self._get_generation_model( + input_shape=input_shape, output=output + ) + if writer is None: + fine_latlon_coords = model.get_fine_coords_for_batch(loaded_item.batch) writer = output.get_writer( - latlon_coords=static_inputs.coords, + latlon_coords=fine_latlon_coords, output_dir=self.output_dir, ) - writer.initialize_store( - static_inputs.fields[0].data.cpu().numpy().dtype - ) - if model is None: - model = self._get_generation_model( - static_inputs=static_inputs, output=output - ) + writer.initialize_store(np.float32) logging.info( f"[{output.name}] Batch {i+1}/{total_batches}, " @@ -127,7 +130,6 @@ def run_output_generation(self, output: DownscalingOutput): output_data = model.generate_on_batch_no_target( loaded_item.batch, - static_inputs=static_inputs, n_samples=loaded_item.n_ens, ) output_np = {key: 
value.cpu().numpy() for key, value in output_data.items()} @@ -238,7 +240,7 @@ def build(self) -> Downscaler: loader_config=self.data, requirements=self.model.data_requirements, patch=self.patch, - static_inputs_from_checkpoint=model.static_inputs, + fine_shape=model.fine_shape, ) for output_cfg in self.outputs ] diff --git a/fme/downscaling/inference/output.py b/fme/downscaling/inference/output.py index 012b555ad..2af28e27e 100644 --- a/fme/downscaling/inference/output.py +++ b/fme/downscaling/inference/output.py @@ -18,7 +18,6 @@ ClosedInterval, DataLoaderConfig, LatLonCoordinates, - StaticInputs, enforce_lat_bounds, ) from ..data.config import XarrayEnsembleDataConfig @@ -153,6 +152,7 @@ def build( loader_config: DataLoaderConfig, requirements: DataRequirements, patch: PatchPredictionConfig, + fine_shape: tuple[int, int], ) -> DownscalingOutput: """ Build an OutputTarget from this configuration. @@ -161,6 +161,8 @@ def build( loader_config: Base data loader configuration to modify requirements: Model's data requirements (variable names, etc.) patch: Default patch prediction configuration + fine_shape: Fine shape of the output used as metadata + for the shape of the output to insert into the dataset """ pass @@ -218,7 +220,7 @@ def _build_gridded_data( loader_config: DataLoaderConfig, requirements: DataRequirements, dist: Distributed | None = None, - static_inputs_from_checkpoint: StaticInputs | None = None, + fine_shape: tuple[int, int] | None = None, ) -> SliceWorkItemGriddedData: xr_dataset, properties = loader_config.get_xarray_dataset( names=requirements.coarse_names, n_timesteps=1 @@ -229,13 +231,6 @@ def _build_gridded_data( "Downscaling data loader only supports datasets with latlon coords." 
) dataset = loader_config.build_batchitem_dataset(xr_dataset, properties) - topography = loader_config.build_static_inputs( - coords, - requires_topography=requirements.use_fine_topography, - static_inputs=static_inputs_from_checkpoint, - ) - if topography is None: - raise ValueError("Topography is required for downscaling generation.") work_items = get_work_items( n_times=len(dataset), @@ -243,11 +238,10 @@ def _build_gridded_data( max_samples_per_gpu=self.max_samples_per_gpu, ) - # defer topography device placement until after batch generation slice_dataset = SliceItemDataset( slice_items=work_items, dataset=dataset, - spatial_shape=topography.shape, + spatial_shape=fine_shape, ) # each SliceItemDataset work item loads its own full batch, so batch_size=1 @@ -274,7 +268,6 @@ def _build_gridded_data( all_times=xr_dataset.sample_start_times, dtype=slice_dataset.dtype, max_output_shape=slice_dataset.max_output_shape, - static_inputs=topography, ) def _build( @@ -286,7 +279,7 @@ def _build( requirements: DataRequirements, patch: PatchPredictionConfig, coarse: list[XarrayDataConfig], - static_inputs_from_checkpoint: StaticInputs | None = None, + fine_shape: tuple[int, int] | None = None, ) -> DownscalingOutput: updated_loader_config = self._replace_loader_config( time, @@ -299,7 +292,7 @@ def _build( gridded_data = self._build_gridded_data( updated_loader_config, requirements, - static_inputs_from_checkpoint=static_inputs_from_checkpoint, + fine_shape=fine_shape, ) if self.zarr_chunks is None: @@ -386,7 +379,7 @@ def build( loader_config: DataLoaderConfig, requirements: DataRequirements, patch: PatchPredictionConfig, - static_inputs_from_checkpoint: StaticInputs | None = None, + fine_shape: tuple[int, int] | None = None, ) -> DownscalingOutput: # Convert single time to TimeSlice time: Slice | TimeSlice @@ -409,7 +402,7 @@ def build( requirements=requirements, patch=patch, coarse=coarse, - static_inputs_from_checkpoint=static_inputs_from_checkpoint, + 
fine_shape=fine_shape, ) @@ -469,7 +462,7 @@ def build( loader_config: DataLoaderConfig, requirements: DataRequirements, patch: PatchPredictionConfig, - static_inputs_from_checkpoint: StaticInputs | None = None, + fine_shape: tuple[int, int] | None = None, ) -> DownscalingOutput: coarse = self._single_xarray_config(loader_config.coarse) return self._build( @@ -480,5 +473,5 @@ def build( requirements=requirements, patch=patch, coarse=coarse, - static_inputs_from_checkpoint=static_inputs_from_checkpoint, + fine_shape=fine_shape, ) diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index e5cf8c641..ed97d2b60 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -12,6 +12,7 @@ from fme.core.dataset.time import TimeSlice from fme.core.logging_utils import LoggingConfig from fme.downscaling.data import ( + ClosedInterval, LatLonCoordinates, StaticInput, StaticInputs, @@ -24,7 +25,6 @@ EventConfig, TimeRangeConfig, ) -from fme.downscaling.inference.work_items import SliceWorkItemGriddedData from fme.downscaling.models import ( CheckpointModelConfig, DiffusionModelConfig, @@ -91,8 +91,7 @@ def test_get_generation_model_exact_match(mock_model, mock_output_target): """ Test _get_generation_model returns model unchanged when shapes match exactly. 
""" - mock_model.fine_shape = (16, 16) - static_inputs = get_static_inputs(shape=(16, 16)) + mock_model.coarse_shape = (16, 16) downscaler = Downscaler( model=mock_model, @@ -100,23 +99,22 @@ def test_get_generation_model_exact_match(mock_model, mock_output_target): ) result = downscaler._get_generation_model( - static_inputs=static_inputs, + input_shape=(16, 16), output=mock_output_target, ) assert result is mock_model -@pytest.mark.parametrize("topo_shape", [(8, 16), (16, 8), (8, 8)]) +@pytest.mark.parametrize("input_shape", [(8, 16), (16, 8), (8, 8)]) def test_get_generation_model_raises_when_domain_too_small( - mock_model, mock_output_target, topo_shape + mock_model, mock_output_target, input_shape ): """ Test _get_generation_model raises ValueError when domain is smaller than model. """ - mock_model.fine_shape = (16, 16) - topo = get_static_inputs(shape=topo_shape) + mock_model.coarse_shape = (16, 16) downscaler = Downscaler( model=mock_model, @@ -125,7 +123,7 @@ def test_get_generation_model_raises_when_domain_too_small( with pytest.raises(ValueError): downscaler._get_generation_model( - static_inputs=topo, + input_shape=input_shape, output=mock_output_target, ) @@ -137,8 +135,7 @@ def test_get_generation_model_creates_patch_predictor_when_needed( Test _get_generation_model creates PatchPredictor for large domains with patching. """ - mock_model.fine_shape = (16, 16) - static_inputs = get_static_inputs(shape=(32, 32)) # Larger than model + mock_model.coarse_shape = (16, 16) patch_config = PatchPredictionConfig( divide_generation=True, @@ -152,7 +149,7 @@ def test_get_generation_model_creates_patch_predictor_when_needed( ) model = downscaler._get_generation_model( - static_inputs=static_inputs, + input_shape=(32, 32), output=mock_output_target, ) @@ -167,8 +164,7 @@ def test_get_generation_model_raises_when_large_domain_without_patching( Test _get_generation_model raises when domain is large but patching not configured. 
""" - mock_model.fine_shape = (16, 16) - topo = get_static_inputs(shape=(32, 32)) # Larger than model + mock_model.coarse_shape = (16, 16) mock_output_target.patch = PatchPredictionConfig(divide_generation=False) downscaler = Downscaler( @@ -178,7 +174,7 @@ def test_get_generation_model_raises_when_large_domain_without_patching( with pytest.raises(ValueError): downscaler._get_generation_model( - static_inputs=topo, + input_shape=(32, 32), output=mock_output_target, ) @@ -187,26 +183,30 @@ def test_run_target_generation_skips_padding_items( mock_model, mock_output_target, ): - """Test run_target_generation skips writing output for padding items.""" - # Create padding work item + """ + Test run_output_generation calls the model but skips writing for padding items. + """ mock_work_item = MagicMock() mock_work_item.is_padding = True - mock_work_item.n_ens = 4 - mock_work_item.batch = MagicMock() - - static_inputs = get_static_inputs(shape=(16, 16)) - - mock_gridded_data = SliceWorkItemGriddedData( - [mock_work_item], {}, [0], torch.float32, (1, 4, 16, 16), static_inputs + mock_work_item.batch.horizontal_shape = (16, 16) + # Coarse coords are interior so fine can have buffer on each side. 
+ mock_work_item.batch.latlon_coordinates.lat = ( + torch.arange(1, 9).float().unsqueeze(0) ) - mock_output_target.data = mock_gridded_data - mock_model.fine_shape = (16, 16) - - mock_output = { - "var1": torch.randn(1, 4, 16, 16), - "var2": torch.randn(1, 4, 16, 16), + mock_work_item.batch.latlon_coordinates.lon = ( + torch.arange(1, 9).float().unsqueeze(0) + ) + mock_work_item.batch.lat_interval = ClosedInterval(1.0, 8.0) + mock_work_item.batch.lon_interval = ClosedInterval(1.0, 8.0) + mock_output_target.data.get_generator.return_value = iter([mock_work_item]) + + mock_model.downscale_factor = 2 + mock_model.static_inputs.coords.lat = torch.arange(0, 18).float() + mock_model.static_inputs.coords.lon = torch.arange(0, 18).float() + mock_model.static_inputs.subset_latlon.return_value.fields[0].data = torch.zeros(1) + mock_model.generate_on_batch_no_target.return_value = { + "var1": torch.zeros(1, 4, 16, 16), } - mock_model.generate_on_batch_no_target.return_value = mock_output mock_writer = MagicMock() mock_output_target.get_writer.return_value = mock_writer @@ -218,11 +218,8 @@ def test_run_target_generation_skips_padding_items( downscaler.run_output_generation(output=mock_output_target) - # Verify model was still called mock_model.generate_on_batch_no_target.assert_called_once() - - # Verify the mock writer was not called - mock_writer.write_batch.assert_not_called() + mock_writer.record_batch.assert_not_called() # Tests for end-to-end generation process diff --git a/fme/downscaling/inference/test_output.py b/fme/downscaling/inference/test_output.py index f9f041ae4..413fcfcb0 100644 --- a/fme/downscaling/inference/test_output.py +++ b/fme/downscaling/inference/test_output.py @@ -1,12 +1,10 @@ from unittest.mock import MagicMock import pytest -import torch from fme.core.dataset.time import TimeSlice from fme.core.dataset.xarray import XarrayDataConfig -from fme.downscaling.data import ClosedInterval, StaticInput, StaticInputs -from fme.downscaling.data.utils 
import LatLonCoordinates +from fme.downscaling.data import ClosedInterval from fme.downscaling.inference.output import ( DownscalingOutput, DownscalingOutputConfig, @@ -19,19 +17,6 @@ # Tests for OutputTargetConfig validation -def _get_static_inputs(shape=(8, 8)): - return StaticInputs( - fields=[ - StaticInput( - data=torch.ones(shape), - coords=LatLonCoordinates( - lat=torch.ones(shape[0]), lon=torch.ones(shape[1]) - ), - ) - ] - ) - - def test_single_xarray_config_accepts_single_config(): """Test that _single_xarray_config accepts a single XarrayDataConfig.""" xarray_config = XarrayDataConfig( @@ -116,10 +101,7 @@ def test_event_config_build_creates_output_target_with_single_time( lat_extent=ClosedInterval(0.0, 6.0), lon_extent=ClosedInterval(0.0, 6.0), ) - static_inputs = _get_static_inputs((8, 8)) - output_target = config.build( - loader_config, requirements, patch_config, static_inputs - ) + output_target = config.build(loader_config, requirements, patch_config) # Verify OutputTarget was created assert isinstance(output_target, DownscalingOutput) @@ -147,11 +129,7 @@ def test_region_config_build_creates_output_target_with_time_range( n_ens=4, save_vars=["var0", "var1"], ) - static_inputs = _get_static_inputs((8, 8)) - - output_target = config.build( - loader_config, requirements, patch_config, static_inputs - ) + output_target = config.build(loader_config, requirements, patch_config) # Verify OutputTarget was created assert isinstance(output_target, DownscalingOutput) diff --git a/fme/downscaling/inference/work_items.py b/fme/downscaling/inference/work_items.py index 0b5ef27a5..25049cf5b 100644 --- a/fme/downscaling/inference/work_items.py +++ b/fme/downscaling/inference/work_items.py @@ -10,7 +10,7 @@ from fme.core.distributed import Distributed from fme.core.generics.data import SizedMap -from ..data import BatchData, StaticInputs +from ..data import BatchData from ..data.config import BatchItemDatasetAdapter from .constants import ENSEMBLE_NAME, TIME_NAME 
@@ -297,7 +297,6 @@ class SliceWorkItemGriddedData: all_times: xr.CFTimeIndex dtype: torch.dtype max_output_shape: tuple[int, ...] - static_inputs: StaticInputs # TODO: currently no protocol or ABC for gridded data objects # if we want to unify, we will need one and just raise @@ -310,7 +309,5 @@ def on_device(work_item: LoadedSliceWorkItem) -> LoadedSliceWorkItem: return SizedMap(on_device, self._loader) - def get_generator(self) -> Iterator[tuple[LoadedSliceWorkItem, StaticInputs]]: - work_item: LoadedSliceWorkItem - for work_item in self.loader: - yield work_item, self.static_inputs + def get_generator(self) -> Iterator[LoadedSliceWorkItem]: + yield from self.loader diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 8ffa8c22b..2d0d1057e 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -6,6 +6,7 @@ import dacite import torch +from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device from fme.core.distributed import Distributed from fme.core.loss import LossConfig @@ -286,7 +287,8 @@ def __init__( normalizer: The normalizer object used for data normalization. loss: The loss function used for training the model. coarse_shape: The height (lat) and width (lon) of the - coarse-resolution input data. + coarse-resolution input data used to train the model + (same as patch extent, if training on patches). downscale_factor: The factor by which the data is downscaled from coarse to fine. 
sigma_data: The standard deviation of the data, used for diffusion @@ -332,13 +334,44 @@ def _subset_static_inputs( ) return self.static_inputs.subset_latlon(lat_interval, lon_interval) + def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: + """Return fine-resolution coordinates matching the spatial extent of batch.""" + if self.static_inputs is None: + raise ValueError( + "Model is missing static inputs, which are required to determine " + "the coordinate information for the output dataset." + ) + coarse_lat = batch.latlon_coordinates.lat[0] + coarse_lon = batch.latlon_coordinates.lon[0] + fine_lat_interval = adjust_fine_coord_range( + batch.lat_interval, + full_coarse_coord=coarse_lat, + full_fine_coord=self.static_inputs.coords.lat, + downscale_factor=self.downscale_factor, + ) + fine_lon_interval = adjust_fine_coord_range( + batch.lon_interval, + full_coarse_coord=coarse_lon, + full_fine_coord=self.static_inputs.coords.lon, + downscale_factor=self.downscale_factor, + ) + return LatLonCoordinates( + lat=self.static_inputs.coords.lat[ + fine_lat_interval.slice_of(self.static_inputs.coords.lat) + ], + lon=self.static_inputs.coords.lon[ + fine_lon_interval.slice_of(self.static_inputs.coords.lon) + ], + ) + @property def fine_shape(self) -> tuple[int, int]: return self._get_fine_shape(self.coarse_shape) def _get_fine_shape(self, coarse_shape: tuple[int, int]) -> tuple[int, int]: """ - Calculate the fine shape based on the coarse shape and downscale factor. + Calculate the fine shape based on the coarse shape of data used to train + the model and the downscaling factor. 
""" return ( coarse_shape[0] * self.downscale_factor, @@ -383,12 +416,9 @@ def _get_input_from_coarse( def train_on_batch( self, batch: PairedBatchData, - static_inputs: StaticInputs | None, # TODO: remove in follow-on PR optimizer: Optimization | NullOptimization, ) -> ModelOutputs: """Performs a denoising training step on a batch of data.""" - # Ignore the passed static_inputs; subset self.static_inputs using fine batch - # coordinates. The caller-provided value is kept for signature compatibility. _static_inputs = self._subset_static_inputs( batch.fine.lat_interval, batch.fine.lon_interval ) @@ -452,8 +482,8 @@ def generate( static_inputs: StaticInputs | None, n_samples: int = 1, ) -> tuple[TensorDict, torch.Tensor, list[torch.Tensor]]: - # static_inputs receives an internally-subsetted value from the calling method; - # external callers should use generate_on_batch / generate_on_batch_no_target. + # Internal method; external callers should use generate_on_batch / + # generate_on_batch_no_target. inputs_ = self._get_input_from_coarse(coarse_data, static_inputs) # expand samples and fold to # [batch * n_samples, output_channels, height, width] @@ -503,11 +533,8 @@ def generate( def generate_on_batch_no_target( self, batch: BatchData, - static_inputs: StaticInputs | None, # TODO: remove in follow-on PR n_samples: int = 1, ) -> TensorDict: - # Ignore the passed static_inputs; derive the fine lat/lon interval from coarse - # batch coordinates via adjust_fine_coord_range, then subset self.static_inputs. if self.config.use_fine_topography: if self.static_inputs is None: raise ValueError( @@ -540,11 +567,8 @@ def generate_on_batch_no_target( def generate_on_batch( self, batch: PairedBatchData, - static_inputs: StaticInputs | None, # TODO: remove in follow-on PR n_samples: int = 1, ) -> ModelOutputs: - # Ignore the passed static_inputs; subset self.static_inputs using fine batch - # coordinates. The caller-provided value is kept for signature compatibility. 
_static_inputs = self._subset_static_inputs( batch.fine.lat_interval, batch.fine.lon_interval ) diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index a78878a10..b4da331aa 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -21,7 +21,6 @@ ClosedInterval, DataLoaderConfig, GriddedData, - StaticInputs, enforce_lat_bounds, ) from fme.downscaling.models import CheckpointModelConfig, DiffusionModel @@ -88,7 +87,6 @@ def get_gridded_data( self, base_data_config: DataLoaderConfig, requirements: DataRequirements, - static_inputs_from_checkpoint: StaticInputs | None = None, ) -> GriddedData: enforce_lat_bounds(self.lat_extent) event_coarse = dataclasses.replace(base_data_config.full_config[0]) @@ -104,7 +102,6 @@ def get_gridded_data( ) return event_data_config.build( requirements=requirements, - static_inputs=static_inputs_from_checkpoint, ) @@ -146,7 +143,7 @@ def generation_model(self): def run(self): logging.info(f"Running {self.event_name} event downscaling...") - batch, static_inputs = next(iter(self.data.get_generator())) + batch = next(iter(self.data.get_generator())) coarse_coords = batch[0].latlon_coordinates fine_coords = LatLonCoordinates( lat=_downscale_coord(coarse_coords.lat, self.model.downscale_factor), @@ -169,7 +166,7 @@ def run(self): f"for event {self.event_name}" ) outputs = self.model.generate_on_batch_no_target( - batch, static_inputs=static_inputs, n_samples=end_idx - start_idx + batch, n_samples=end_idx - start_idx ) sample_agg.record_batch(outputs) to_log = sample_agg.get_wandb() @@ -238,30 +235,28 @@ def save_netcdf_data(self, ds: xr.Dataset): f"{self.experiment_dir}/generated_maps_and_metrics.nc", mode="w" ) - @property - def _fine_latlon_coordinates(self) -> LatLonCoordinates | None: - if self.data.static_inputs is not None: - return self.data.static_inputs.coords - else: - return None - def run(self): - aggregator = NoTargetAggregator( - downscale_factor=self.model.downscale_factor, - 
latlon_coordinates=self._fine_latlon_coordinates, - ) - for i, (batch, static_inputs) in enumerate(self.batch_generator): + aggregator: NoTargetAggregator | None = None + for i, batch in enumerate(self.batch_generator): + if aggregator is None: + fine_coords = self.model.get_fine_coords_for_batch(batch) + aggregator = NoTargetAggregator( + downscale_factor=self.model.downscale_factor, + latlon_coordinates=fine_coords, + ) with torch.no_grad(): logging.info(f"Generating predictions on batch {i + 1}") prediction = self.generation_model.generate_on_batch_no_target( batch=batch, - static_inputs=static_inputs, n_samples=self.n_samples, ) logging.info("Recording diagnostics to aggregator") # Add sample dimension to coarse values for generation comparison coarse = {k: v.unsqueeze(1) for k, v in batch.data.items()} aggregator.record_batch(prediction, coarse, batch.time) + + # dataset build ensures non-empty batch_generator + assert aggregator is not None logs = aggregator.get_wandb() wandb = WandB.get_instance() wandb.log(logs, step=0) @@ -297,7 +292,6 @@ def build(self) -> list[Downscaler | EventDownscaler]: model = self.model.build() dataset = self.data.build( requirements=self.model.data_requirements, - static_inputs=model.static_inputs, ) downscaler = Downscaler( data=dataset, @@ -311,7 +305,6 @@ def build(self) -> list[Downscaler | EventDownscaler]: event_dataset = event_config.get_gridded_data( base_data_config=self.data, requirements=self.model.data_requirements, - static_inputs_from_checkpoint=model.static_inputs, ) event_downscalers.append( EventDownscaler( diff --git a/fme/downscaling/predictors/composite.py b/fme/downscaling/predictors/composite.py index e1e40dc3a..05add74a0 100644 --- a/fme/downscaling/predictors/composite.py +++ b/fme/downscaling/predictors/composite.py @@ -2,10 +2,10 @@ import torch +from fme.core.coordinates import LatLonCoordinates from fme.core.typing_ import TensorDict -from fme.downscaling.data import BatchData, PairedBatchData, 
StaticInputs, scale_tuple +from fme.downscaling.data import BatchData, PairedBatchData, scale_tuple from fme.downscaling.data.patching import Patch, get_patches -from fme.downscaling.data.utils import null_generator from fme.downscaling.models import DiffusionModel, ModelOutputs @@ -79,6 +79,13 @@ def __init__( def coarse_shape(self): return self.coarse_yx_patch_extent + @property + def static_inputs(self): + return self.model.static_inputs + + def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: + return self.model.get_fine_coords_for_batch(batch) + def _get_patches( self, coarse_yx_extent, fine_yx_extent ) -> tuple[list[Patch], list[Patch]]: @@ -105,7 +112,6 @@ def _get_patches( def generate_on_batch( self, batch: PairedBatchData, - static_inputs: StaticInputs | None, n_samples: int = 1, ) -> ModelOutputs: predictions = [] @@ -118,17 +124,9 @@ def generate_on_batch( batch_generator = batch.generate_from_patches( coarse_patches=coarse_patches, fine_patches=fine_patches ) - if static_inputs is not None: - static_inputs_generator = static_inputs.generate_from_patches(fine_patches) - else: - static_inputs_generator = null_generator(len(fine_patches)) - - for data_patch, static_inputs_patch in zip( - batch_generator, static_inputs_generator - ): - model_output = self.model.generate_on_batch( - data_patch, static_inputs_patch, n_samples - ) + + for data_patch in batch_generator: + model_output = self.model.generate_on_batch(data_patch, n_samples) predictions.append(model_output.prediction) loss = loss + model_output.loss @@ -146,7 +144,6 @@ def generate_on_batch( def generate_on_batch_no_target( self, batch: BatchData, - static_inputs: StaticInputs | None, n_samples: int = 1, ) -> TensorDict: coarse_yx_extent = batch.horizontal_shape @@ -156,17 +153,10 @@ def generate_on_batch_no_target( ) predictions = [] batch_generator = batch.generate_from_patches(coarse_patches) - if static_inputs is not None: - static_inputs_generator = 
static_inputs.generate_from_patches(fine_patches) - else: - static_inputs_generator = null_generator(len(fine_patches)) - for data_patch, static_inputs_patch in zip( - batch_generator, static_inputs_generator - ): + for data_patch in batch_generator: predictions.append( self.model.generate_on_batch_no_target( batch=data_patch, - static_inputs=static_inputs_patch, n_samples=n_samples, ) ) diff --git a/fme/downscaling/predictors/test_composite.py b/fme/downscaling/predictors/test_composite.py index f6b6fa60d..9f31ff05d 100644 --- a/fme/downscaling/predictors/test_composite.py +++ b/fme/downscaling/predictors/test_composite.py @@ -6,7 +6,7 @@ from fme.core.device import get_device from fme.core.packer import Packer from fme.downscaling.aggregators.shape_helpers import upsample_tensor -from fme.downscaling.data import BatchData, PairedBatchData, StaticInput, StaticInputs +from fme.downscaling.data import BatchData, PairedBatchData from fme.downscaling.data.patching import get_patches from fme.downscaling.data.utils import BatchedLatLonCoordinates from fme.downscaling.models import ModelOutputs @@ -16,10 +16,6 @@ ) -def _get_static_inputs(shape, coords): - return StaticInputs(fields=[StaticInput(data=torch.randn(shape), coords=coords)]) - - def test_composite_predictions(): patch_yx_size = (2, 2) patches = get_patches((4, 4), patch_yx_size, overlap=0) @@ -54,9 +50,7 @@ def __init__(self, coarse_shape, downscale_factor): self.modules = [] self.out_packer = Packer(["x"]) - def generate_on_batch( - self, batch: PairedBatchData, static_inputs: StaticInputs | None, n_samples=1 - ): + def generate_on_batch(self, batch: PairedBatchData, n_samples=1): prediction_data = { k: v.unsqueeze(1).expand(-1, n_samples, -1, -1) for k, v in batch.fine.data.items() @@ -65,9 +59,7 @@ def generate_on_batch( prediction=prediction_data, target=prediction_data, loss=torch.tensor(1.0) ) - def generate_on_batch_no_target( - self, batch: BatchData, static_inputs: StaticInputs | None, n_samples=1 - 
): + def generate_on_batch_no_target(self, batch: BatchData, n_samples=1): prediction_data = { k: upsample_tensor( v.unsqueeze(1).expand(-1, n_samples, -1, -1), @@ -137,13 +129,6 @@ def test_SpatialCompositePredictor_generate_on_batch(patch_size_coarse): paired_batch_data = get_paired_test_data( *coarse_extent, downscale_factor=downscale_factor, batch_size=batch_size ) - static_inputs = _get_static_inputs( - shape=( - coarse_extent[0] * downscale_factor, - coarse_extent[1] * downscale_factor, - ), - coords=paired_batch_data.fine.latlon_coordinates[0], - ) predictor = PatchPredictor( DummyModel(coarse_shape=patch_size_coarse, downscale_factor=downscale_factor), # type: ignore @@ -152,7 +137,7 @@ def test_SpatialCompositePredictor_generate_on_batch(patch_size_coarse): ) n_samples_generate = 2 outputs = predictor.generate_on_batch( - paired_batch_data, static_inputs, n_samples=n_samples_generate + paired_batch_data, n_samples=n_samples_generate ) assert outputs.prediction["x"].shape == (batch_size, n_samples_generate, 8, 8) # dummy model predicts same value as fine data for all samples @@ -174,13 +159,6 @@ def test_SpatialCompositePredictor_generate_on_batch_no_target(patch_size_coarse paired_batch_data = get_paired_test_data( *coarse_extent, downscale_factor=downscale_factor, batch_size=batch_size ) - static_inputs = _get_static_inputs( - shape=( - coarse_extent[0] * downscale_factor, - coarse_extent[1] * downscale_factor, - ), - coords=paired_batch_data.fine.latlon_coordinates[0], - ) predictor = PatchPredictor( DummyModel(coarse_shape=patch_size_coarse, downscale_factor=2), # type: ignore coarse_extent, @@ -189,6 +167,6 @@ def test_SpatialCompositePredictor_generate_on_batch_no_target(patch_size_coarse n_samples_generate = 2 coarse_batch_data = paired_batch_data.coarse prediction = predictor.generate_on_batch_no_target( - coarse_batch_data, static_inputs, n_samples=n_samples_generate + coarse_batch_data, n_samples=n_samples_generate ) assert prediction["x"].shape 
== (batch_size, n_samples_generate, 8, 8) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index d6798e66f..864f511f3 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -206,7 +206,7 @@ def test_from_state_backward_compat_fine_topography(): # At runtime, omitting static inputs must raise a clear error batch = get_mock_paired_batch([2, *coarse_shape], [2, *fine_shape]) with pytest.raises(ValueError, match="Static inputs must be provided"): - model_from_old_state.generate_on_batch(batch, static_inputs=None) + model_from_old_state.generate_on_batch(batch) def _get_diffusion_model( @@ -265,13 +265,12 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph assert model._get_fine_shape(coarse_shape) == fine_shape optimization = OptimizationConfig().build(modules=[model.module], max_epochs=2) - train_outputs = model.train_on_batch(batch, static_inputs, optimization) + train_outputs = model.train_on_batch(batch, optimization) assert torch.allclose(train_outputs.target["x"], batch.fine.data["x"]) n_generated_samples = 2 generated_outputs = [ - model.generate_on_batch(batch, static_inputs) - for _ in range(n_generated_samples) + model.generate_on_batch(batch) for _ in range(n_generated_samples) ] for generated_output in generated_outputs: @@ -392,7 +391,7 @@ def test_model_error_cases(): # missing fine topography when model requires it batch.fine.topography = None with pytest.raises(ValueError): - model.generate_on_batch(batch, static_inputs=None) + model.generate_on_batch(batch) def test_DiffusionModel_generate_on_batch_no_target(): @@ -417,7 +416,6 @@ def test_DiffusionModel_generate_on_batch_no_target(): samples = model.generate_on_batch_no_target( coarse_batch, - static_inputs=static_inputs, n_samples=n_generated_samples, ) @@ -457,9 +455,7 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): coarse_batch = make_batch_data( (batch_size, 
*alternative_input_shape), coarse_lat, coarse_lon ) - samples = model.generate_on_batch_no_target( - coarse_batch, n_samples=n_ensemble, static_inputs=None - ) + samples = model.generate_on_batch_no_target(coarse_batch, n_samples=n_ensemble) assert samples["x"].shape == ( batch_size, @@ -536,6 +532,52 @@ def test_noise_config_error(): ) +def test_get_fine_coords_for_batch(): + # Model trained on full coarse (8x16) / fine (16x32) grid + coarse_shape = (8, 16) + fine_shape = (16, 32) + downscale_factor = 2 + static_inputs = make_static_inputs(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=downscale_factor, + use_fine_topography=True, + static_inputs=static_inputs, + ) + + # Build a batch covering a spatial patch: middle 4 coarse lats and 8 coarse lons. + full_coarse_lat = _get_monotonic_coordinate(coarse_shape[0], stop=fine_shape[0]) + full_coarse_lon = _get_monotonic_coordinate(coarse_shape[1], stop=fine_shape[1]) + patch_coarse_lat = full_coarse_lat[2:6].tolist() # [5, 7, 9, 11] + patch_coarse_lon = full_coarse_lon[4:12].tolist() # [9, 11, ..., 23] + batch = make_batch_data((2, 4, 8), patch_coarse_lat, patch_coarse_lon) + + result = model.get_fine_coords_for_batch(batch) + + expected_lat = model.static_inputs.coords.lat[4:12] + expected_lon = model.static_inputs.coords.lon[8:24] + # model.static_inputs has been moved to device; index into it directly + # to match devices + assert torch.allclose(result.lat, expected_lat) + assert torch.allclose(result.lon, expected_lon) + + +def test_get_fine_coords_for_batch_raises_without_static_inputs(): + model = _get_diffusion_model( + coarse_shape=(16, 16), + downscale_factor=2, + use_fine_topography=False, + static_inputs=None, + ) + batch = make_batch_data( + (1, 16, 16), + _get_monotonic_coordinate(16, stop=16).tolist(), + _get_monotonic_coordinate(16, stop=16).tolist(), + ) + with pytest.raises(ValueError, match="missing static inputs"): + model.get_fine_coords_for_batch(batch) 
+ + def test_checkpoint_config_topography_raises(): with pytest.raises(ValueError): CheckpointModelConfig( diff --git a/fme/downscaling/test_predict.py b/fme/downscaling/test_predict.py index 6626d5a01..ce253ab14 100644 --- a/fme/downscaling/test_predict.py +++ b/fme/downscaling/test_predict.py @@ -147,7 +147,7 @@ def test_predictor_renaming( coarse_shape = (4, 4) downscale_factor = 2 renaming = {"var0": "var0_renamed", "var1": "var1_renamed"} - predictor_config_path, _ = create_predictor_config( + predictor_config_path, fine_data_path = create_predictor_config( tmp_path, n_samples, model_renaming=renaming, @@ -158,7 +158,11 @@ def test_predictor_renaming( model_config = get_model_config( coarse_shape, downscale_factor, use_fine_topography=False ) - model = model_config.build(coarse_shape=coarse_shape, downscale_factor=2) + model = model_config.build( + coarse_shape=coarse_shape, + downscale_factor=2, + static_inputs=load_static_inputs({"HGTsfc": fine_data_path}), + ) with open(predictor_config_path) as f: predictor_config = yaml.safe_load(f) os.makedirs( diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index 271d884be..a1dd08fde 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -174,11 +174,11 @@ def train_one_epoch(self) -> None: self.train_data, random_offset=True, shuffle=True ) outputs = None - for i, (batch, static_inputs) in enumerate(train_batch_generator): + for i, batch in enumerate(train_batch_generator): self.num_batches_seen += 1 if i % 10 == 0: logging.info(f"Training on batch {i + 1}") - outputs = self.model.train_on_batch(batch, static_inputs, self.optimization) + outputs = self.model.train_on_batch(batch, self.optimization) self.ema(self.model.modules) with torch.no_grad(): train_aggregator.record_batch( @@ -250,10 +250,8 @@ def valid_one_epoch(self) -> dict[str, float]: validation_batch_generator = self._get_batch_generator( self.validation_data, random_offset=False, shuffle=False ) - for batch, static_inputs 
in validation_batch_generator: - outputs = self.model.train_on_batch( - batch, static_inputs, self.null_optimization - ) + for batch in validation_batch_generator: + outputs = self.model.train_on_batch(batch, self.null_optimization) validation_aggregator.record_batch( outputs=outputs, coarse=batch.coarse.data, @@ -261,7 +259,6 @@ def valid_one_epoch(self) -> dict[str, float]: ) generated_outputs = self.model.generate_on_batch( batch, - static_inputs=static_inputs, n_samples=self.config.generate_n_samples, ) # Add sample dimension to coarse values for generation comparison @@ -429,12 +426,10 @@ def build(self) -> Trainer: train_data: PairedGriddedData = self.train_data.build( train=True, requirements=self.model.data_requirements, - static_inputs=static_inputs, ) validation_data: PairedGriddedData = self.validation_data.build( train=False, requirements=self.model.data_requirements, - static_inputs=static_inputs, ) if self.coarse_patch_extent_lat and self.coarse_patch_extent_lon: model_coarse_shape = (