diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index 8bc16826d..38e7f9191 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -441,6 +441,16 @@ def __post_init__(self): def horizontal_shape(self) -> tuple[int, int]: return self._horizontal_shape + @property + def lat_interval(self) -> ClosedInterval: + lat = self.latlon_coordinates.lat[0] # all batch members identical; use first + return ClosedInterval(lat.min().item(), lat.max().item()) + + @property + def lon_interval(self) -> ClosedInterval: + lon = self.latlon_coordinates.lon[0] # all batch members identical; use first + return ClosedInterval(lon.min().item(), lon.max().item()) + @classmethod def from_sequence( cls, @@ -464,7 +474,7 @@ def to_device(self) -> "BatchData": def __getitem__(self, k): return BatchItem( - {key: value[k].squeeze() for key, value in self.data.items()}, + {key: value[k] for key, value in self.data.items()}, self.time[k], self.latlon_coordinates[k], ) diff --git a/fme/downscaling/data/test_datasets.py b/fme/downscaling/data/test_datasets.py index 5be3e38ed..f3ac263de 100644 --- a/fme/downscaling/data/test_datasets.py +++ b/fme/downscaling/data/test_datasets.py @@ -15,6 +15,7 @@ LatLonCoordinates, PairedBatchItem, ) +from fme.downscaling.data.patching import Patch, _HorizontalSlice from fme.downscaling.data.utils import BatchedLatLonCoordinates, ClosedInterval @@ -401,3 +402,63 @@ def test_BatchData_slice_latlon(): batch_slice.data["x"], batch.data["x"][:, lat_slice, lon_slice], ) + + +def _make_batch_data_for_patching(batch_size=2): + """Create a 2×4×4 BatchData with known arange values for patch testing.""" + n_lat, n_lon = 4, 4 + lat = torch.arange(n_lat, dtype=torch.float32) + lon = torch.arange(n_lon, dtype=torch.float32) + data = { + "x": torch.arange(batch_size * n_lat * n_lon, dtype=torch.float32).reshape( + batch_size, n_lat, n_lon + ) + } + time = xr.DataArray(list(range(batch_size)), dims=["batch"]) + latlon_coordinates = BatchedLatLonCoordinates( + lat=lat.unsqueeze(0).expand(batch_size, -1).clone(), + lon=lon.unsqueeze(0).expand(batch_size, -1).clone(), + ) + return BatchData(data=data, time=time, latlon_coordinates=latlon_coordinates) + + +def test_batch_data_generate_from_patches(): + batch = _make_batch_data_for_patching() + patches = [ + Patch( + input_slice=_HorizontalSlice(y=slice(1, 3), x=slice(None)), + output_slice=_HorizontalSlice(y=slice(None), x=slice(None)), + ), + Patch( + input_slice=_HorizontalSlice(y=slice(0, 2), x=slice(2, 3)), + output_slice=_HorizontalSlice(y=slice(None), x=slice(None)), + ), + ] + generated = list(batch.generate_from_patches(patches)) + + assert len(generated) == 2 + + # Patch 0: rows 1-2, all columns + expected_lat = torch.tensor([[1.0, 2.0], [1.0, 2.0]]) + expected_lon = torch.tensor([[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 3.0]]) + assert torch.equal(generated[0].latlon_coordinates.lat, expected_lat) + assert torch.equal(generated[0].latlon_coordinates.lon, expected_lon) + assert torch.equal(generated[0].data["x"], batch.data["x"][:, 1:3, :]) + + # Patch 1: rows 0-1, column 2 + expected_lat = torch.tensor([[0.0, 1.0], [0.0, 1.0]]) + expected_lon = torch.tensor([[2.0], [2.0]]) + assert torch.equal(generated[1].latlon_coordinates.lat, expected_lat) + assert torch.equal(generated[1].latlon_coordinates.lon, expected_lon) + assert torch.equal(generated[1].data["x"], batch.data["x"][:, 0:2, 2:3]) + + +def test_batch_data_apply_patch_already_patched_raises(): + batch = _make_batch_data_for_patching() + patch = Patch( + input_slice=_HorizontalSlice(y=slice(1, 3), x=slice(None)), + output_slice=_HorizontalSlice(y=slice(None), x=slice(None)), + ) + (patched,) = list(batch.generate_from_patches([patch])) + with pytest.raises(ValueError, match="previously patched"): + list(patched.generate_from_patches([patch])) diff --git a/fme/downscaling/data/test_utils.py b/fme/downscaling/data/test_utils.py index b92bd0ed7..662231147 100644 --- a/fme/downscaling/data/test_utils.py +++ b/fme/downscaling/data/test_utils.py @@ -49,6 +49,23 @@ def test_adjust_fine_coord_range(downscale_factor, lat_range): assert len(subsel_fine_lat) / len(subsel_coarse_lat) == downscale_factor +def test_adjust_fine_coord_range_raises_near_domain_boundary(): + downscale_factor = 4 # n_half_fine = 2 + coarse_edges = torch.linspace(0, 6, 7) + coarse_lat = _fine_midpoints(coarse_edges, 1) + fine_lat = _fine_midpoints(coarse_edges, downscale_factor) + # Drop the first fine point so only 1 fine point exists below coarse_min=0.5, + # but n_half_fine=2 are required — simulating a grid truncated at the domain edge. + fine_lat_truncated = fine_lat[1:] + with pytest.raises(ValueError): + adjust_fine_coord_range( + ClosedInterval(0, 4), + full_coarse_coord=coarse_lat, + full_fine_coord=fine_lat_truncated, + downscale_factor=downscale_factor, + ) + + @pytest.mark.parametrize( "input_slice,expected", [ diff --git a/fme/downscaling/data/utils.py b/fme/downscaling/data/utils.py index f3d49f7a3..8cecf8c09 100644 --- a/fme/downscaling/data/utils.py +++ b/fme/downscaling/data/utils.py @@ -116,6 +116,13 @@ def adjust_fine_coord_range( If downscale factor is not provided, it is assumed that the coarse and fine coordinate tensors correspond to the same region bounds. + + Raises: + ValueError: If coord_range is too close to the boundary of full_fine_coord + such that fewer than downscale_factor // 2 fine points exist beyond the + outermost selected coarse point on either side. For global latitude grids, + this is avoided by restricting coord_range to within ±88° (i.e. away from + the poles). """ if downscale_factor is None: if full_fine_coord.shape[0] % full_coarse_coord.shape[0] != 0: @@ -132,6 +139,19 @@ def adjust_fine_coord_range( n_half_fine = downscale_factor // 2 coarse_min = full_coarse_coord[full_coarse_coord >= coord_range.start][0] coarse_max = full_coarse_coord[full_coarse_coord <= coord_range.stop][-1] + + n_fine_below = int((full_fine_coord < coarse_min).sum()) + n_fine_above = int((full_fine_coord > coarse_max).sum()) + if n_fine_below < n_half_fine or n_fine_above < n_half_fine: + raise ValueError( + f"coord_range {coord_range} is too close to the boundary of " + f"full_fine_coord [{full_fine_coord.min():.2f}, " + f"{full_fine_coord.max():.2f}]. Need at least {n_half_fine} fine " + f"point(s) beyond each coarse boundary; got {n_fine_below} below " + f"and {n_fine_above} above. Restrict the coordinate range away from " + f"the domain edges." + ) + fine_min = full_fine_coord[full_fine_coord < coarse_min][-n_half_fine] fine_max = full_fine_coord[full_fine_coord > coarse_max][n_half_fine - 1] diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 5897d3acb..8ffa8c22b 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -15,8 +15,10 @@ from fme.core.typing_ import TensorDict, TensorMapping from fme.downscaling.data import ( BatchData, + ClosedInterval, PairedBatchData, StaticInputs, + adjust_fine_coord_range, load_static_inputs, ) from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate @@ -303,12 +305,33 @@ def __init__( self.out_packer = Packer(config.out_names) self.config = config self._channel_axis = -3 - self.static_inputs = static_inputs + self.static_inputs = ( + static_inputs.to_device() if static_inputs is not None else None + ) @property def modules(self) -> torch.nn.ModuleList: return torch.nn.ModuleList([self.module]) + def _subset_static_inputs( + self, + lat_interval: ClosedInterval, + lon_interval: ClosedInterval, + ) -> StaticInputs | None: + """Subset self.static_inputs to the given fine lat/lon interval. + + Returns None if use_fine_topography is False. + Raises ValueError if use_fine_topography is True but self.static_inputs is None. + """ + if not self.config.use_fine_topography: + return None + if self.static_inputs is None: + raise ValueError( + "Static inputs must be provided for each batch when use of fine " + "static inputs is enabled." + ) + return self.static_inputs.subset_latlon(lat_interval, lon_interval) + @property def fine_shape(self) -> tuple[int, int]: return self._get_fine_shape(self.coarse_shape) @@ -338,6 +361,12 @@ def _get_input_from_coarse( "static inputs is enabled." ) else: + expected_shape = interpolated.shape[-2:] + if static_inputs.shape != expected_shape: + raise ValueError( + f"Subsetted static input shape {static_inputs.shape} does not " + f"match expected fine spatial shape {expected_shape}." + ) n_batches = normalized.shape[0] # Join normalized static inputs to input (see dataset for details) for field in static_inputs.fields: @@ -354,12 +383,17 @@ def _get_input_from_coarse( def train_on_batch( self, batch: PairedBatchData, - static_inputs: StaticInputs | None, + static_inputs: StaticInputs | None, # TODO: remove in follow-on PR optimizer: Optimization | NullOptimization, ) -> ModelOutputs: """Performs a denoising training step on a batch of data.""" + # Ignore the passed static_inputs; subset self.static_inputs using fine batch + # coordinates. The caller-provided value is kept for signature compatibility. + _static_inputs = self._subset_static_inputs( + batch.fine.lat_interval, batch.fine.lon_interval + ) coarse, fine = batch.coarse.data, batch.fine.data - inputs_norm = self._get_input_from_coarse(coarse, static_inputs) + inputs_norm = self._get_input_from_coarse(coarse, _static_inputs) targets_norm = self.out_packer.pack( self.normalizer.fine.normalize(dict(fine)), axis=self._channel_axis ) @@ -418,6 +452,8 @@ def generate( static_inputs: StaticInputs | None, n_samples: int = 1, ) -> tuple[TensorDict, torch.Tensor, list[torch.Tensor]]: + # static_inputs receives an internally-subsetted value from the calling method; + # external callers should use generate_on_batch / generate_on_batch_no_target. inputs_ = self._get_input_from_coarse(coarse_data, static_inputs) # expand samples and fold to # [batch * n_samples, output_channels, height, width] @@ -467,22 +503,54 @@ def generate( def generate_on_batch_no_target( self, batch: BatchData, - static_inputs: StaticInputs | None, + static_inputs: StaticInputs | None, # TODO: remove in follow-on PR n_samples: int = 1, ) -> TensorDict: - generated, _, _ = self.generate(batch.data, static_inputs, n_samples) + # Ignore the passed static_inputs; derive the fine lat/lon interval from coarse + # batch coordinates via adjust_fine_coord_range, then subset self.static_inputs. + if self.config.use_fine_topography: + if self.static_inputs is None: + raise ValueError( + "Static inputs must be provided for each batch when use of fine " + "static inputs is enabled." + ) + coarse_lat = batch.latlon_coordinates.lat[0] + coarse_lon = batch.latlon_coordinates.lon[0] + fine_lat_interval = adjust_fine_coord_range( + batch.lat_interval, + full_coarse_coord=coarse_lat, + full_fine_coord=self.static_inputs.coords.lat, + downscale_factor=self.downscale_factor, + ) + fine_lon_interval = adjust_fine_coord_range( + batch.lon_interval, + full_coarse_coord=coarse_lon, + full_fine_coord=self.static_inputs.coords.lon, + downscale_factor=self.downscale_factor, + ) + _static_inputs = self.static_inputs.subset_latlon( + fine_lat_interval, fine_lon_interval + ) + else: + _static_inputs = None + generated, _, _ = self.generate(batch.data, _static_inputs, n_samples) return generated @torch.no_grad() def generate_on_batch( self, batch: PairedBatchData, - static_inputs: StaticInputs | None, + static_inputs: StaticInputs | None, # TODO: remove in follow-on PR n_samples: int = 1, ) -> ModelOutputs: + # Ignore the passed static_inputs; subset self.static_inputs using fine batch + # coordinates. The caller-provided value is kept for signature compatibility. + _static_inputs = self._subset_static_inputs( + batch.fine.lat_interval, batch.fine.lon_interval + ) coarse, fine = batch.coarse.data, batch.fine.data generated, generated_norm, latent_steps = self.generate( - coarse, static_inputs, n_samples + coarse, _static_inputs, n_samples ) targets_norm = self.out_packer.pack( diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index dd2d772ac..d6798e66f 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -10,7 +10,13 @@ from fme.core.loss import LossConfig from fme.core.normalizer import NormalizationConfig from fme.core.optimization import OptimizationConfig -from fme.downscaling.data import StaticInput, StaticInputs +from fme.downscaling.data import ( + BatchData, + BatchedLatLonCoordinates, + PairedBatchData, + StaticInput, + StaticInputs, +) from fme.downscaling.models import ( CheckpointModelConfig, DiffusionModel, @@ -72,18 +78,70 @@ def get_mock_paired_batch(coarse_shape, fine_shape): return FineResCoarseResPair(fine=fine, coarse=coarse) -def test_module_serialization(tmp_path): - coarse_shape = (8, 16) - static_inputs = StaticInputs( +def make_batch_data( + shape: tuple[int, int, int], + lat_values: list[float], + lon_values: list[float], +) -> BatchData: + """Create a BatchData with proper monotonic coordinates.""" + batch_size, lat_size, lon_size = shape + assert lat_size == len(lat_values) + assert lon_size == len(lon_values) + data = {"x": torch.ones(batch_size, lat_size, lon_size, device=get_device())} + time = xr.DataArray(range(batch_size), dims=["batch"]) + lat = torch.tensor(lat_values, dtype=torch.float32) + lon = torch.tensor(lon_values, dtype=torch.float32) + latlon = BatchedLatLonCoordinates( + lat=lat.unsqueeze(0).expand(batch_size, -1), + lon=lon.unsqueeze(0).expand(batch_size, -1), + ) + return BatchData(data=data, time=time, latlon_coordinates=latlon) + + +def _get_monotonic_coordinate(size: int, stop: float) -> torch.Tensor: + bounds = torch.linspace(0, stop, size + 1) + coord = (bounds[:-1] + bounds[1:]) / 2 + return coord + + +def make_paired_batch_data( + coarse_shape: tuple[int, int], + fine_shape: tuple[int, int], + batch_size: int = 2, +) -> PairedBatchData: + """ + Create a PairedBatchData with consistent monotonic coordinates. + """ + lat_c, lon_c = coarse_shape + lat_f, lon_f = fine_shape + fine_lat = _get_monotonic_coordinate(lat_f, stop=lat_f) + fine_lon = _get_monotonic_coordinate(lon_f, stop=lon_f) + coarse_lat = _get_monotonic_coordinate(lat_c, stop=lat_f) + coarse_lon = _get_monotonic_coordinate(lon_c, stop=lon_f) + fine = make_batch_data((batch_size, lat_f, lon_f), fine_lat, fine_lon) + coarse = make_batch_data((batch_size, lat_c, lon_c), coarse_lat, coarse_lon) + return PairedBatchData(fine=fine, coarse=coarse) + + +def make_static_inputs(fine_shape: tuple[int, int]) -> StaticInputs: + """Create StaticInputs with proper monotonic coordinates for given shape.""" + lat_size, lon_size = fine_shape + return StaticInputs( fields=[ StaticInput( - torch.rand(*coarse_shape, device=get_device()), + torch.ones(*fine_shape, device=get_device()), LatLonCoordinates( - lat=torch.ones(coarse_shape[0]), lon=torch.ones(coarse_shape[1]) + lat=_get_monotonic_coordinate(lat_size, stop=lat_size), + lon=_get_monotonic_coordinate(lon_size, stop=lon_size), ), ) ] ) + + +def test_module_serialization(tmp_path): + coarse_shape = (8, 16) + static_inputs = make_static_inputs((16, 32)) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, @@ -122,16 +180,7 @@ def test_from_state_backward_compat_fine_topography(): coarse_shape = (8, 16) fine_shape = (16, 32) downscale_factor = 2 - static_inputs = StaticInputs( - fields=[ - StaticInput( - torch.ones(*fine_shape, device=get_device()), - LatLonCoordinates( - lat=torch.ones(fine_shape[0]), lon=torch.ones(fine_shape[1]) - ), - ) - ] - ) + static_inputs = make_static_inputs(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, @@ -196,19 +245,15 @@ def _get_diffusion_model( def test_diffusion_model_train_and_generate(predict_residual, use_fine_topography): coarse_shape = (8, 16) fine_shape = (16, 32) + batch_size = 2 if use_fine_topography: - static_inputs = StaticInputs( - fields=[ - StaticInput( - torch.ones(*fine_shape, device=get_device()), - LatLonCoordinates( - lat=torch.ones(fine_shape[0]), lon=torch.ones(fine_shape[1]) - ), - ) - ] - ) + static_inputs = make_static_inputs(fine_shape) + batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) else: static_inputs = None + batch = get_mock_paired_batch( + [batch_size, *coarse_shape], [batch_size, *fine_shape] + ) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, @@ -219,11 +264,6 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph assert model._get_fine_shape(coarse_shape) == fine_shape - batch_size = 2 - - batch = get_mock_paired_batch( - [batch_size, *coarse_shape], [batch_size, *fine_shape] - ) optimization = OptimizationConfig().build(modules=[model.module], max_epochs=2) train_outputs = model.train_on_batch(batch, static_inputs, optimization) assert torch.allclose(train_outputs.target["x"], batch.fine.data["x"]) @@ -359,16 +399,7 @@ def test_DiffusionModel_generate_on_batch_no_target(): fine_shape = (32, 32) coarse_shape = (16, 16) downscale_factor = 2 - static_inputs = StaticInputs( - fields=[ - StaticInput( - torch.rand(*fine_shape, device=get_device()), - LatLonCoordinates( - lat=torch.ones(fine_shape[0]), lon=torch.ones(fine_shape[1]) - ), - ) - ] - ) + static_inputs = make_static_inputs(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, @@ -378,12 +409,11 @@ def test_DiffusionModel_generate_on_batch_no_target(): ) batch_size = 2 - n_generated_samples = 2 - coarse_batch = get_mock_batch( - [batch_size, *coarse_shape], topography_scale_factor=downscale_factor - ) + coarse_lat = _get_monotonic_coordinate(coarse_shape[0], stop=fine_shape[0]) + coarse_lon = _get_monotonic_coordinate(coarse_shape[1], stop=fine_shape[1]) + coarse_batch = make_batch_data((batch_size, *coarse_shape), coarse_lat, coarse_lon) samples = model.generate_on_batch_no_target( coarse_batch, @@ -399,18 +429,15 @@ def test_DiffusionModel_generate_on_batch_no_target(): def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): - # We currently require an input coarse shape for accounting, but the model - # can handle arbitrary input sizes + # The model subsets its own stored static_inputs based on coarse batch + # coordinates. The stored static_inputs must cover the full fine domain + # for all tested batch sizes. coarse_shape = (16, 16) downscale_factor = 2 - static_inputs = StaticInputs( - fields=[ - StaticInput( - torch.rand(32, 32, device=get_device()), - LatLonCoordinates(torch.ones(32), torch.ones(32)), - ) - ] - ) + # Full fine domain: 64x64 covers inputs for both (8,8) and (32,32) coarse inputs + # with a downscaling factor of 2 + full_fine_size = 64 + static_inputs = make_static_inputs((full_fine_size, full_fine_size)) # need to build with static inputs to get the correct n_in_channels model = _get_diffusion_model( coarse_shape=coarse_shape, @@ -424,22 +451,14 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): for alternative_input_shape in [(8, 8), (32, 32)]: fine_shape = tuple(dim * downscale_factor for dim in alternative_input_shape) - coarse_batch = get_mock_batch( - [batch_size, *alternative_input_shape], - topography_scale_factor=downscale_factor, - ) - static_inputs = StaticInputs( - fields=[ - StaticInput( - torch.rand(*fine_shape, device=get_device()), - LatLonCoordinates( - torch.ones(fine_shape[0]), torch.ones(fine_shape[1]) - ), - ) - ] + alt_y, alt_x = alternative_input_shape + coarse_lat = _get_monotonic_coordinate(alt_y, stop=alt_y * downscale_factor) + coarse_lon = _get_monotonic_coordinate(alt_x, stop=alt_x * downscale_factor) + coarse_batch = make_batch_data( + (batch_size, *alternative_input_shape), coarse_lat, coarse_lon ) samples = model.generate_on_batch_no_target( - coarse_batch, n_samples=n_ensemble, static_inputs=static_inputs + coarse_batch, n_samples=n_ensemble, static_inputs=None ) assert samples["x"].shape == ( @@ -528,16 +547,7 @@ def test_checkpoint_config_topography_raises(): def test_checkpoint_model_build_raises_when_checkpoint_has_static_inputs(tmp_path): coarse_shape = (8, 16) fine_shape = (16, 32) - static_inputs = StaticInputs( - fields=[ - StaticInput( - torch.ones(*fine_shape, device=get_device()), - LatLonCoordinates( - lat=torch.ones(fine_shape[0]), lon=torch.ones(fine_shape[1]) - ), - ) - ] - ) + static_inputs = make_static_inputs(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2,