From ec876b450755de828d6431c03582a2d12d00d880 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Wed, 11 Mar 2026 14:46:44 -0700 Subject: [PATCH] create BatchData wrapper in CascadePredictor._generate, make _generate methods private --- fme/downscaling/models.py | 6 ++-- fme/downscaling/predictors/cascade.py | 40 ++++++++++++++++++---- fme/downscaling/predictors/test_cascade.py | 2 +- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 5897d3acb..f7057d3ac 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -412,7 +412,7 @@ def train_on_batch( ) @torch.no_grad() - def generate( + def _generate( self, coarse_data: TensorMapping, static_inputs: StaticInputs | None, @@ -470,7 +470,7 @@ def generate_on_batch_no_target( static_inputs: StaticInputs | None, n_samples: int = 1, ) -> TensorDict: - generated, _, _ = self.generate(batch.data, static_inputs, n_samples) + generated, _, _ = self._generate(batch.data, static_inputs, n_samples) return generated @torch.no_grad() @@ -481,7 +481,7 @@ def generate_on_batch( n_samples: int = 1, ) -> ModelOutputs: coarse, fine = batch.coarse.data, batch.fine.data - generated, generated_norm, latent_steps = self.generate( + generated, generated_norm, latent_steps = self._generate( coarse, static_inputs, n_samples ) diff --git a/fme/downscaling/predictors/cascade.py b/fme/downscaling/predictors/cascade.py index 0f64c73b5..3399c3db6 100644 --- a/fme/downscaling/predictors/cascade.py +++ b/fme/downscaling/predictors/cascade.py @@ -2,6 +2,7 @@ import math import torch +import xarray as xr from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device @@ -15,6 +16,7 @@ adjust_fine_coord_range, scale_tuple, ) +from fme.downscaling.data.utils import BatchedLatLonCoordinates from fme.downscaling.metrics_and_maths import filter_tensor_mapping from fme.downscaling.models import CheckpointModelConfig, DiffusionModel, ModelOutputs from fme.downscaling.requirements import DataRequirements @@ -86,6 +88,26 @@ def _restore_batch_and_sample_dims(data: TensorMapping, n_samples: int): return unfold_ensemble_dim(squeezed, n_samples) +def _batch_data_with_unused_coords(data: TensorMapping) -> BatchData: + # wrapper function so that we can call each level's + # public generate_on_batch_no_target function using tensormapping + # from the previous step. + data_shape = next(iter(data.values())).shape + time = xr.DataArray( + [0 for _ in range(data_shape[0])], + dims=["time"], + ) + latlon_coordinates = BatchedLatLonCoordinates( + lat=torch.zeros((data_shape[0], data_shape[1]), device=get_device()), + lon=torch.zeros((data_shape[0], data_shape[2]), device=get_device()), + ) + return BatchData( + data=data, + time=time, + latlon_coordinates=latlon_coordinates, + ) + + class CascadePredictor: def __init__( self, models: list[DiffusionModel], static_inputs: list[StaticInputs | None] @@ -116,22 +138,26 @@ def modules(self) -> torch.nn.ModuleList: return torch.nn.ModuleList([model.modules for model in self.models]) @torch.no_grad() - def generate( + def _generate( self, coarse: TensorMapping, n_samples: int, static_inputs: list[StaticInputs | None], ): current_coarse = coarse - for i, (model, fine_topography) in enumerate(zip(self.models, static_inputs)): + for i, (model, step_static_inputs) in enumerate( + zip(self.models, static_inputs) + ): sample_data = next(iter(current_coarse.values())) batch_size = sample_data.shape[0] # n_samples are generated for the first step, and subsequent models # generate 1 sample n_samples_cascade_step = n_samples if i == 0 else 1 - generated, generated_norm, latent_steps = model.generate( - current_coarse, fine_topography, n_samples_cascade_step + generated = model.generate_on_batch_no_target( + _batch_data_with_unused_coords(current_coarse), + step_static_inputs, + n_samples_cascade_step, ) generated = { k: v.reshape(batch_size * n_samples_cascade_step, *v.shape[-2:]) @@ -139,7 +165,7 @@ def generate( } current_coarse = generated generated = _restore_batch_and_sample_dims(generated, n_samples) - return generated, generated_norm, latent_steps + return generated @torch.no_grad() def generate_on_batch_no_target( @@ -151,7 +177,7 @@ def generate_on_batch_no_target( subset_static_inputs = self._get_subset_static_inputs( coarse_coords=batch.latlon_coordinates[0] ) - generated, _, _ = self.generate(batch.data, n_samples, subset_static_inputs) + generated = self._generate(batch.data, n_samples, subset_static_inputs) return generated @torch.no_grad() @@ -164,7 +190,7 @@ def generate_on_batch( static_inputs = self._get_subset_static_inputs( coarse_coords=batch.coarse.latlon_coordinates[0] ) - generated, _, latent_steps = self.generate( + generated, _, latent_steps = self._generate( batch.coarse.data, n_samples, static_inputs ) targets = filter_tensor_mapping(batch.fine.data, set(self.out_packer.names)) diff --git a/fme/downscaling/predictors/test_cascade.py b/fme/downscaling/predictors/test_cascade.py index 99511d576..a2f3f5cdb 100644 --- a/fme/downscaling/predictors/test_cascade.py +++ b/fme/downscaling/predictors/test_cascade.py @@ -98,7 +98,7 @@ def test_CascadePredictor_generate(downscale_factors): dtype=torch.float32, ) } - generated, _, _ = cascade_predictor.generate( + generated = cascade_predictor._generate( coarse=coarse_input, n_samples=n_samples_generate, static_inputs=static_inputs_list,