-
Notifications
You must be signed in to change notification settings - Fork 38
Downscaling model handles static input instead of DataLoader #954
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
5cf4452
000dcf7
76903d7
f144258
cac8ebf
7debcc0
e3192af
ba5502a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -441,6 +441,16 @@ def __post_init__(self): | |
| def horizontal_shape(self) -> tuple[int, int]: | ||
| return self._horizontal_shape | ||
|
|
||
| @property | ||
| def lat_interval(self) -> ClosedInterval: | ||
| lat = self.latlon_coordinates.lat[0] # all batch members identical; use first | ||
| return ClosedInterval(lat.min().item(), lat.max().item()) | ||
|
|
||
| @property | ||
| def lon_interval(self) -> ClosedInterval: | ||
| lon = self.latlon_coordinates.lon[0] # all batch members identical; use first | ||
| return ClosedInterval(lon.min().item(), lon.max().item()) | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you add a test that the coordinates are as expected when
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added tests for the data and coordinates for generate_from_patches under |
||
| @classmethod | ||
| def from_sequence( | ||
| cls, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,8 +15,10 @@ | |
| from fme.core.typing_ import TensorDict, TensorMapping | ||
| from fme.downscaling.data import ( | ||
| BatchData, | ||
| ClosedInterval, | ||
| PairedBatchData, | ||
| StaticInputs, | ||
| adjust_fine_coord_range, | ||
| load_static_inputs, | ||
| ) | ||
| from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate | ||
|
|
@@ -303,12 +305,33 @@ def __init__( | |
| self.out_packer = Packer(config.out_names) | ||
| self.config = config | ||
| self._channel_axis = -3 | ||
| self.static_inputs = static_inputs | ||
| self.static_inputs = ( | ||
| static_inputs.to_device() if static_inputs is not None else None | ||
| ) | ||
|
|
||
| @property | ||
| def modules(self) -> torch.nn.ModuleList: | ||
| return torch.nn.ModuleList([self.module]) | ||
|
|
||
| def _subset_static_inputs( | ||
| self, | ||
| lat_interval: ClosedInterval, | ||
| lon_interval: ClosedInterval, | ||
| ) -> StaticInputs | None: | ||
| """Subset self.static_inputs to the given fine lat/lon interval. | ||
|
|
||
| Returns None if use_fine_topography is False. | ||
| Raises ValueError if use_fine_topography is True but self.static_inputs is None. | ||
| """ | ||
| if not self.config.use_fine_topography: | ||
| return None | ||
| if self.static_inputs is None: | ||
| raise ValueError( | ||
| "Static inputs must be provided for each batch when use of fine " | ||
| "static inputs is enabled." | ||
| ) | ||
| return self.static_inputs.subset_latlon(lat_interval, lon_interval) | ||
|
|
||
| @property | ||
| def fine_shape(self) -> tuple[int, int]: | ||
| return self._get_fine_shape(self.coarse_shape) | ||
|
|
@@ -338,6 +361,12 @@ def _get_input_from_coarse( | |
| "static inputs is enabled." | ||
| ) | ||
| else: | ||
| expected_shape = interpolated.shape[-2:] | ||
| if static_inputs.shape != expected_shape: | ||
| raise ValueError( | ||
| f"Subsetted static input shape {static_inputs.shape} does not " | ||
| f"match expected fine spatial shape {expected_shape}." | ||
| ) | ||
| n_batches = normalized.shape[0] | ||
| # Join normalized static inputs to input (see dataset for details) | ||
| for field in static_inputs.fields: | ||
|
|
@@ -354,12 +383,17 @@ def _get_input_from_coarse( | |
| def train_on_batch( | ||
| self, | ||
| batch: PairedBatchData, | ||
| static_inputs: StaticInputs | None, | ||
| static_inputs: StaticInputs | None, # TODO: remove in follow-on PR | ||
| optimizer: Optimization | NullOptimization, | ||
| ) -> ModelOutputs: | ||
| """Performs a denoising training step on a batch of data.""" | ||
| # Ignore the passed static_inputs; subset self.static_inputs using fine batch | ||
| # coordinates. The caller-provided value is kept for signature compatibility. | ||
| _static_inputs = self._subset_static_inputs( | ||
| batch.fine.lat_interval, batch.fine.lon_interval | ||
| ) | ||
| coarse, fine = batch.coarse.data, batch.fine.data | ||
| inputs_norm = self._get_input_from_coarse(coarse, static_inputs) | ||
| inputs_norm = self._get_input_from_coarse(coarse, _static_inputs) | ||
| targets_norm = self.out_packer.pack( | ||
| self.normalizer.fine.normalize(dict(fine)), axis=self._channel_axis | ||
| ) | ||
|
|
@@ -415,9 +449,11 @@ def train_on_batch( | |
| def generate( | ||
| self, | ||
| coarse_data: TensorMapping, | ||
| static_inputs: StaticInputs | None, | ||
| static_inputs: StaticInputs | None, # TODO: remove in follow-on PR | ||
|
||
| n_samples: int = 1, | ||
| ) -> tuple[TensorDict, torch.Tensor, list[torch.Tensor]]: | ||
| # static_inputs receives an internally-subsetted value from the calling method; | ||
| # external callers should use generate_on_batch / generate_on_batch_no_target. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Out of scope for this PR, but the only external caller is
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I mentioned this in #959, but I ran into the problem that inference needs knowledge of the output sizes, which forces some awkward handling within that module. I think this could be solved by having a richer set of output information (e.g., coordinates) passed along in the generation/prediction. That would also allow for some smarter handling in CascadedModels as well. |
||
| inputs_ = self._get_input_from_coarse(coarse_data, static_inputs) | ||
| # expand samples and fold to | ||
| # [batch * n_samples, output_channels, height, width] | ||
|
|
@@ -467,22 +503,54 @@ def generate( | |
| def generate_on_batch_no_target( | ||
| self, | ||
| batch: BatchData, | ||
| static_inputs: StaticInputs | None, | ||
| static_inputs: StaticInputs | None, # TODO: remove in follow-on PR | ||
| n_samples: int = 1, | ||
| ) -> TensorDict: | ||
| generated, _, _ = self.generate(batch.data, static_inputs, n_samples) | ||
| # Ignore the passed static_inputs; derive the fine lat/lon interval from coarse | ||
| # batch coordinates via adjust_fine_coord_range, then subset self.static_inputs. | ||
| if self.config.use_fine_topography: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this check necessary? Could this instead just check whether static inputs are present? Since the previous PR removed the old option of loading HGTsfc from the fine dataset, this config option can be deprecated and downstream checks could be removed.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, I agree and was thinking this and the |
||
| if self.static_inputs is None: | ||
| raise ValueError( | ||
| "Static inputs must be provided for each batch when use of fine " | ||
| "static inputs is enabled." | ||
| ) | ||
| coarse_lat = batch.latlon_coordinates.lat[0] | ||
| coarse_lon = batch.latlon_coordinates.lon[0] | ||
| fine_lat_interval = adjust_fine_coord_range( | ||
| batch.lat_interval, | ||
| full_coarse_coord=coarse_lat, | ||
| full_fine_coord=self.static_inputs.coords.lat, | ||
| downscale_factor=self.downscale_factor, | ||
| ) | ||
| fine_lon_interval = adjust_fine_coord_range( | ||
| batch.lon_interval, | ||
| full_coarse_coord=coarse_lon, | ||
| full_fine_coord=self.static_inputs.coords.lon, | ||
| downscale_factor=self.downscale_factor, | ||
| ) | ||
| _static_inputs = self.static_inputs.subset_latlon( | ||
| fine_lat_interval, fine_lon_interval | ||
| ) | ||
| else: | ||
| _static_inputs = None | ||
| generated, _, _ = self.generate(batch.data, _static_inputs, n_samples) | ||
| return generated | ||
|
|
||
| @torch.no_grad() | ||
| def generate_on_batch( | ||
| self, | ||
| batch: PairedBatchData, | ||
| static_inputs: StaticInputs | None, | ||
| static_inputs: StaticInputs | None, # TODO: remove in follow-on PR | ||
| n_samples: int = 1, | ||
| ) -> ModelOutputs: | ||
| # Ignore the passed static_inputs; subset self.static_inputs using fine batch | ||
| # coordinates. The caller-provided value is kept for signature compatibility. | ||
| _static_inputs = self._subset_static_inputs( | ||
| batch.fine.lat_interval, batch.fine.lon_interval | ||
| ) | ||
| coarse, fine = batch.coarse.data, batch.fine.data | ||
| generated, generated_norm, latent_steps = self.generate( | ||
| coarse, static_inputs, n_samples | ||
| coarse, _static_inputs, n_samples | ||
| ) | ||
|
|
||
| targets_norm = self.out_packer.pack( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given that we're no longer doing per-item random patch definitions, we could update
`BatchData` to just use `LatLonCoordinates`