Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions fme/downscaling/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,16 @@ def __post_init__(self):
def horizontal_shape(self) -> tuple[int, int]:
return self._horizontal_shape

@property
def lat_interval(self) -> ClosedInterval:
lat = self.latlon_coordinates.lat[0] # all batch members identical; use first
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that we're no longer doing per-item random patch definitions, we could update BatchData to just use LatLonCoordinates

return ClosedInterval(lat.min().item(), lat.max().item())

@property
def lon_interval(self) -> ClosedInterval:
    """Closed interval spanning this batch's longitude coordinate."""
    # Coordinates are identical across batch members, so the first one
    # is representative of the whole batch.
    coord = self.latlon_coordinates.lon[0]
    return ClosedInterval(coord.min().item(), coord.max().item())

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test that the coordinates are as expected when BatchData.generate_from_patches is called? It looks like the existing tests for that usage only check data values.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added tests under test_datasets.py covering both the data and the coordinates produced by generate_from_patches. I did not adjust the tests in test_static.py, since patch generation will be removed in #956.

@classmethod
def from_sequence(
cls,
Expand Down
17 changes: 17 additions & 0 deletions fme/downscaling/data/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,23 @@ def test_adjust_fine_coord_range(downscale_factor, lat_range):
assert len(subsel_fine_lat) / len(subsel_coarse_lat) == downscale_factor


def test_adjust_fine_coord_range_raises_near_domain_boundary():
    """adjust_fine_coord_range must raise when the fine grid is truncated
    so close to the domain edge that too few fine points lie outside the
    selected coarse range."""
    factor = 4  # n_half_fine = factor // 2 == 2
    edges = torch.linspace(0, 6, 7)
    coarse = _fine_midpoints(edges, 1)
    fine = _fine_midpoints(edges, factor)
    # Removing the lowest fine point leaves just one fine point below
    # coarse_min=0.5, while n_half_fine=2 are required — this simulates a
    # grid truncated at the domain edge.
    truncated_fine = fine[1:]
    with pytest.raises(ValueError):
        adjust_fine_coord_range(
            ClosedInterval(0, 4),
            full_coarse_coord=coarse,
            full_fine_coord=truncated_fine,
            downscale_factor=factor,
        )


@pytest.mark.parametrize(
"input_slice,expected",
[
Expand Down
20 changes: 20 additions & 0 deletions fme/downscaling/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ def adjust_fine_coord_range(

If downscale factor is not provided, it is assumed that the coarse and fine
coordinate tensors correspond to the same region bounds.

Raises:
ValueError: If coord_range is too close to the boundary of full_fine_coord
such that fewer than downscale_factor // 2 fine points exist beyond the
outermost selected coarse point on either side. For global latitude grids,
this is avoided by restricting coord_range to within ±88° (i.e. away from
the poles).
"""
if downscale_factor is None:
if full_fine_coord.shape[0] % full_coarse_coord.shape[0] != 0:
Expand All @@ -132,6 +139,19 @@ def adjust_fine_coord_range(
n_half_fine = downscale_factor // 2
coarse_min = full_coarse_coord[full_coarse_coord >= coord_range.start][0]
coarse_max = full_coarse_coord[full_coarse_coord <= coord_range.stop][-1]

n_fine_below = int((full_fine_coord < coarse_min).sum())
n_fine_above = int((full_fine_coord > coarse_max).sum())
if n_fine_below < n_half_fine or n_fine_above < n_half_fine:
raise ValueError(
f"coord_range {coord_range} is too close to the boundary of "
f"full_fine_coord [{full_fine_coord.min():.2f}, "
f"{full_fine_coord.max():.2f}]. Need at least {n_half_fine} fine "
f"point(s) beyond each coarse boundary; got {n_fine_below} below "
f"and {n_fine_above} above. Restrict the coordinate range away from "
f"the domain edges."
)

fine_min = full_fine_coord[full_fine_coord < coarse_min][-n_half_fine]
fine_max = full_fine_coord[full_fine_coord > coarse_max][n_half_fine - 1]

Expand Down
84 changes: 76 additions & 8 deletions fme/downscaling/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
from fme.core.typing_ import TensorDict, TensorMapping
from fme.downscaling.data import (
BatchData,
ClosedInterval,
PairedBatchData,
StaticInputs,
adjust_fine_coord_range,
load_static_inputs,
)
from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate
Expand Down Expand Up @@ -303,12 +305,33 @@ def __init__(
self.out_packer = Packer(config.out_names)
self.config = config
self._channel_axis = -3
self.static_inputs = static_inputs
self.static_inputs = (
static_inputs.to_device() if static_inputs is not None else None
)

@property
def modules(self) -> torch.nn.ModuleList:
return torch.nn.ModuleList([self.module])

def _subset_static_inputs(
    self,
    lat_interval: ClosedInterval,
    lon_interval: ClosedInterval,
) -> StaticInputs | None:
    """Restrict self.static_inputs to the given fine lat/lon intervals.

    Returns:
        None when fine-topography use is disabled in the config; otherwise
        the static inputs subset to the requested region.

    Raises:
        ValueError: If fine topography is enabled but no static inputs
            were provided at construction time.
    """
    if not self.config.use_fine_topography:
        return None
    inputs = self.static_inputs
    if inputs is None:
        raise ValueError(
            "Static inputs must be provided for each batch when use of fine "
            "static inputs is enabled."
        )
    return inputs.subset_latlon(lat_interval, lon_interval)

@property
def fine_shape(self) -> tuple[int, int]:
return self._get_fine_shape(self.coarse_shape)
Expand Down Expand Up @@ -338,6 +361,12 @@ def _get_input_from_coarse(
"static inputs is enabled."
)
else:
expected_shape = interpolated.shape[-2:]
if static_inputs.shape != expected_shape:
raise ValueError(
f"Subsetted static input shape {static_inputs.shape} does not "
f"match expected fine spatial shape {expected_shape}."
)
n_batches = normalized.shape[0]
# Join normalized static inputs to input (see dataset for details)
for field in static_inputs.fields:
Expand All @@ -354,12 +383,17 @@ def _get_input_from_coarse(
def train_on_batch(
self,
batch: PairedBatchData,
static_inputs: StaticInputs | None,
static_inputs: StaticInputs | None, # TODO: remove in follow-on PR
optimizer: Optimization | NullOptimization,
) -> ModelOutputs:
"""Performs a denoising training step on a batch of data."""
# Ignore the passed static_inputs; subset self.static_inputs using fine batch
# coordinates. The caller-provided value is kept for signature compatibility.
_static_inputs = self._subset_static_inputs(
batch.fine.lat_interval, batch.fine.lon_interval
)
coarse, fine = batch.coarse.data, batch.fine.data
inputs_norm = self._get_input_from_coarse(coarse, static_inputs)
inputs_norm = self._get_input_from_coarse(coarse, _static_inputs)
targets_norm = self.out_packer.pack(
self.normalizer.fine.normalize(dict(fine)), axis=self._channel_axis
)
Expand Down Expand Up @@ -415,9 +449,11 @@ def train_on_batch(
def generate(
self,
coarse_data: TensorMapping,
static_inputs: StaticInputs | None,
static_inputs: StaticInputs | None, # TODO: remove in follow-on PR
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be kept since generate will be passed subsetted static_inputs from generate_on_batch or generate_on_batch_no_target, right?

n_samples: int = 1,
) -> tuple[TensorDict, torch.Tensor, list[torch.Tensor]]:
# static_inputs receives an internally-subsetted value from the calling method;
# external callers should use generate_on_batch / generate_on_batch_no_target.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of scope for this PR, but the only external caller is CascadePredictor.generate, which uses this instead of generate_on_batch_no_target because it was simpler to just pass the first model's output tensor instead of making a new BatchData object out of it to input to the next model. It would be worth adding a helper function to construct a new BatchData object out of generate's output; then we could make this method private.

Copy link
Collaborator Author

@frodre frodre Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mentioned in the #959 , but I ran into the problem that inference needs knowledge of the output sizes that forces some awkward handling within that module. I think this could be solved by having a richer set of output information (e.g., coordinates) passed along in the generation/prediction. That would also allow for some smarter handling in CascadedModels as well.

inputs_ = self._get_input_from_coarse(coarse_data, static_inputs)
# expand samples and fold to
# [batch * n_samples, output_channels, height, width]
Expand Down Expand Up @@ -467,22 +503,54 @@ def generate(
def generate_on_batch_no_target(
self,
batch: BatchData,
static_inputs: StaticInputs | None,
static_inputs: StaticInputs | None, # TODO: remove in follow-on PR
n_samples: int = 1,
) -> TensorDict:
generated, _, _ = self.generate(batch.data, static_inputs, n_samples)
# Ignore the passed static_inputs; derive the fine lat/lon interval from coarse
# batch coordinates via adjust_fine_coord_range, then subset self.static_inputs.
if self.config.use_fine_topography:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this check necessary? Could this instead just check if self.static_inputs is None?

Since the previous PR removed the old option of loading HGTsfc from the fine dataset, this config option can be deprecated and downstream checks could be removed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree and was thinking this and the DataRequirements.use_fine_topography would be handled in another PR to simplify things.

if self.static_inputs is None:
raise ValueError(
"Static inputs must be provided for each batch when use of fine "
"static inputs is enabled."
)
coarse_lat = batch.latlon_coordinates.lat[0]
coarse_lon = batch.latlon_coordinates.lon[0]
fine_lat_interval = adjust_fine_coord_range(
batch.lat_interval,
full_coarse_coord=coarse_lat,
full_fine_coord=self.static_inputs.coords.lat,
downscale_factor=self.downscale_factor,
)
fine_lon_interval = adjust_fine_coord_range(
batch.lon_interval,
full_coarse_coord=coarse_lon,
full_fine_coord=self.static_inputs.coords.lon,
downscale_factor=self.downscale_factor,
)
_static_inputs = self.static_inputs.subset_latlon(
fine_lat_interval, fine_lon_interval
)
else:
_static_inputs = None
generated, _, _ = self.generate(batch.data, _static_inputs, n_samples)
return generated

@torch.no_grad()
def generate_on_batch(
self,
batch: PairedBatchData,
static_inputs: StaticInputs | None,
static_inputs: StaticInputs | None, # TODO: remove in follow-on PR
n_samples: int = 1,
) -> ModelOutputs:
# Ignore the passed static_inputs; subset self.static_inputs using fine batch
# coordinates. The caller-provided value is kept for signature compatibility.
_static_inputs = self._subset_static_inputs(
batch.fine.lat_interval, batch.fine.lon_interval
)
coarse, fine = batch.coarse.data, batch.fine.data
generated, generated_norm, latent_steps = self.generate(
coarse, static_inputs, n_samples
coarse, _static_inputs, n_samples
)

targets_norm = self.out_packer.pack(
Expand Down
Loading
Loading