Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
754f28a
Initial shot
frodre Mar 12, 2026
7c5b4c5
Make fine coords required
frodre Mar 13, 2026
c7f58d9
Fine coords required for paired data
frodre Mar 13, 2026
563b57a
Mesh with previous updates in refactor pr
frodre Mar 17, 2026
f4218dd
Simplify event downscaler coordinate in run()
frodre Mar 17, 2026
68cab61
use batch latlon coordinates for coarse
frodre Mar 17, 2026
83ca043
Make fine coord loader public
frodre Mar 17, 2026
794a7d4
BatchLatLon coord access consistency
frodre Mar 17, 2026
b542080
linting
frodre Mar 17, 2026
5add727
Add no coords checkpoint with path test
frodre Mar 17, 2026
36e80db
Small tweaks
frodre Mar 17, 2026
b19c9d6
Add load_fine_coords_from_path test
frodre Mar 17, 2026
7acf87c
Update fme/downscaling/models.py
frodre Mar 18, 2026
eba2749
Update fme/downscaling/models.py
frodre Mar 18, 2026
2c25f1d
Update fme/downscaling/models.py
frodre Mar 18, 2026
81e4eb8
use latlon coords .to method for device fix
frodre Mar 18, 2026
325d324
Redo based on Anna's comments
frodre Mar 19, 2026
1fac745
Remove from_state docstring
frodre Mar 19, 2026
3ef1613
Move all state loading cases into static inputs code
frodre Mar 19, 2026
36e8101
Remove unused function from fme.downscaling.data
frodre Mar 19, 2026
3f95457
fine_coords -> full_fine_coords
frodre Mar 19, 2026
6c9fcf7
Remove duplicated tests from models.py
frodre Mar 19, 2026
4b58b86
Minor fixes
frodre Mar 19, 2026
7e36f55
Fix imports
frodre Mar 19, 2026
a85fb05
Final cleanup
frodre Mar 19, 2026
b1ee3d4
Merge branch 'main' into feature/downscaling-model-fine-coords
frodre Mar 19, 2026
bedd276
Add coordinate validation and use training data as a fallback vs. sta…
frodre Mar 20, 2026
a9dbae3
Updates based on discussion w/ Anna
frodre Mar 21, 2026
46670ff
Merge branch 'main' into feature/downscaling-model-fine-coords
frodre Mar 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion fme/downscaling/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
PairedBatchData,
PairedGriddedData,
)
from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range
from fme.downscaling.data.utils import (
ClosedInterval,
adjust_fine_coord_range,
get_latlon_coords_from_properties,
)
from fme.downscaling.requirements import DataRequirements


Expand Down Expand Up @@ -529,6 +533,7 @@ def build(
dims=example.fine.latlon_coordinates.dims,
variable_metadata=variable_metadata,
all_times=all_times,
fine_coords=get_latlon_coords_from_properties(properties_fine),
)

def _get_sampler(
Expand Down
7 changes: 5 additions & 2 deletions fme/downscaling/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,13 @@ def __init__(
f"expected lon_min < {self.lon_interval.start + 360.0}"
)

# Used to subset the data in __getitem__
self._lats_slice = self.lat_interval.slice_of(self._orig_coords.lat)
self._lons_slice = self.lon_interval.slice_of(self._orig_coords.lon)

self._latlon_coordinates = LatLonCoordinates(
lat=self._orig_coords.lat[self._lats_slice],
lon=self._orig_coords.lon[self._lons_slice],
lat=self.lat_interval.subset_of(self._orig_coords.lat),
lon=self.lon_interval.subset_of(self._orig_coords.lon),
)
self._area_weights = self._latlon_coordinates.area_weights

Expand Down Expand Up @@ -337,6 +339,7 @@ class PairedGriddedData:
dims: list[str]
variable_metadata: Mapping[str, VariableMetadata]
all_times: xr.CFTimeIndex
fine_coords: LatLonCoordinates

@property
def loader(self) -> DataLoader[PairedBatchItem]:
Expand Down
78 changes: 15 additions & 63 deletions fme/downscaling/data/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,17 @@
import torch
import xarray as xr

from fme.core.coordinates import LatLonCoordinates
from fme.core.device import get_device
from fme.downscaling.data.utils import ClosedInterval
from fme.downscaling.data.patching import Patch


@dataclasses.dataclass
class StaticInput:
data: torch.Tensor
coords: LatLonCoordinates

def __post_init__(self):
if len(self.data.shape) != 2:
raise ValueError(f"Topography data must be 2D. Got shape {self.data.shape}")
if self.data.shape[0] != len(self.coords.lat) or self.data.shape[1] != len(
self.coords.lon
):
raise ValueError(
f"Static inputs data shape {self.data.shape} does not match lat/lon "
f"coordinates shape {(len(self.coords.lat), len(self.coords.lon))}"
)

@property
def dim(self) -> int:
Expand All @@ -32,44 +23,23 @@ def dim(self) -> int:
def shape(self) -> tuple[int, int]:
return self.data.shape

def subset_latlon(
def subset(
self,
lat_interval: ClosedInterval,
lon_interval: ClosedInterval,
lat_slice: slice,
lon_slice: slice,
) -> "StaticInput":
lat_slice = lat_interval.slice_of(self.coords.lat)
lon_slice = lon_interval.slice_of(self.coords.lon)
return self._latlon_index_slice(lat_slice=lat_slice, lon_slice=lon_slice)
return StaticInput(data=self.data[lat_slice, lon_slice])

def to_device(self) -> "StaticInput":
device = get_device()
return StaticInput(
data=self.data.to(device),
coords=LatLonCoordinates(
lat=self.coords.lat.to(device),
lon=self.coords.lon.to(device),
),
)
return StaticInput(data=self.data.to(device))

def _latlon_index_slice(
self,
lat_slice: slice,
lon_slice: slice,
) -> "StaticInput":
sliced_data = self.data[lat_slice, lon_slice]
sliced_latlon = LatLonCoordinates(
lat=self.coords.lat[lat_slice],
lon=self.coords.lon[lon_slice],
)
return StaticInput(
data=sliced_data,
coords=sliced_latlon,
)
def _apply_patch(self, patch: Patch):
return self.subset(lat_slice=patch.input_slice.y, lon_slice=patch.input_slice.x)

def get_state(self) -> dict:
return {
"data": self.data.cpu(),
"coords": self.coords.get_state(),
}


Expand All @@ -93,17 +63,11 @@ def _get_normalized_static_input(path: str, field_name: str):
f"unexpected shape {static_input.shape} for static input."
"Currently, only lat/lon static input is supported."
)
lat_name, lon_name = static_input.dims[-2:]
coords = LatLonCoordinates(
lon=torch.tensor(static_input[lon_name].values),
lat=torch.tensor(static_input[lat_name].values),
)

static_input_normalized = (static_input - static_input.mean()) / static_input.std()

return StaticInput(
data=torch.tensor(static_input_normalized.values, dtype=torch.float32),
coords=coords,
)


Expand All @@ -113,36 +77,28 @@ class StaticInputs:

def __post_init__(self):
for i, field in enumerate(self.fields[1:]):
if field.coords != self.fields[0].coords:
if field.shape != self.fields[0].shape:
raise ValueError(
f"All StaticInput fields must have the same coordinates. "
f"Fields {i} and 0 do not match coordinates."
f"All StaticInput fields must have the same shape. "
f"Fields {i + 1} and 0 do not match shapes."
)

def __getitem__(self, index: int):
return self.fields[index]

@property
def coords(self) -> LatLonCoordinates:
if len(self.fields) == 0:
raise ValueError("No fields in StaticInputs to get coordinates from.")
return self.fields[0].coords

@property
def shape(self) -> tuple[int, int]:
if len(self.fields) == 0:
raise ValueError("No fields in StaticInputs to get shape from.")
return self.fields[0].shape

def subset_latlon(
def subset(
self,
lat_interval: ClosedInterval,
lon_interval: ClosedInterval,
lat_slice: slice,
lon_slice: slice,
) -> "StaticInputs":
return StaticInputs(
fields=[
field.subset_latlon(lat_interval, lon_interval) for field in self.fields
]
fields=[field.subset(lat_slice, lon_slice) for field in self.fields]
)

def to_device(self) -> "StaticInputs":
Expand All @@ -159,10 +115,6 @@ def from_state(cls, state: dict) -> "StaticInputs":
fields=[
StaticInput(
data=field_state["data"],
coords=LatLonCoordinates(
lat=field_state["coords"]["lat"],
lon=field_state["coords"]["lon"],
),
)
for field_state in state["fields"]
]
Expand Down
64 changes: 30 additions & 34 deletions fme/downscaling/data/test_static.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,60 @@
import pytest
import torch

from fme.core.coordinates import LatLonCoordinates

from .static import StaticInput, StaticInputs
from .utils import ClosedInterval


@pytest.mark.parametrize(
"init_args",
[
pytest.param(
[
torch.randn((1, 2, 2)),
LatLonCoordinates(torch.arange(2), torch.arange(2)),
],
[torch.randn((1, 2, 2))],
id="3d_data",
),
pytest.param(
[torch.randn((2, 2)), LatLonCoordinates(torch.arange(2), torch.arange(5))],
id="dim_size_mismatch",
),
],
)
def test_Topography_error_cases(init_args):
with pytest.raises(ValueError):
StaticInput(*init_args)


def test_subset_latlon():
def test_subset():
full_data_shape = (10, 10)
expected_slices = [slice(2, 6), slice(3, 8)]
data = torch.randn(*full_data_shape)
coords = LatLonCoordinates(
lat=torch.linspace(0, 9, 10), lon=torch.linspace(0, 9, 10)
)
topo = StaticInput(data=data, coords=coords)
lat_interval = ClosedInterval(2, 5)
lon_interval = ClosedInterval(3, 7)
subset_topo = topo.subset_latlon(lat_interval, lon_interval)
expected_lats = torch.tensor([2, 3, 4, 5], dtype=coords.lat.dtype)
expected_lons = torch.tensor([3, 4, 5, 6, 7], dtype=coords.lon.dtype)
expected_data = data[*expected_slices]
assert torch.equal(subset_topo.coords.lat, expected_lats)
assert torch.equal(subset_topo.coords.lon, expected_lons)
assert torch.allclose(subset_topo.data, expected_data)
topo = StaticInput(data=data)
lat_slice = slice(2, 6)
lon_slice = slice(3, 8)
subset_topo = topo.subset(lat_slice, lon_slice)
assert torch.allclose(subset_topo.data, data[lat_slice, lon_slice])


def test_StaticInputs_serialize():
data = torch.arange(16).reshape(4, 4)
topography = StaticInput(
data,
LatLonCoordinates(torch.arange(4), torch.arange(4)),
)
land_frac = StaticInput(
data * -1.0,
LatLonCoordinates(torch.arange(4), torch.arange(4)),
)
topography = StaticInput(data)
land_frac = StaticInput(data * -1.0)
static_inputs = StaticInputs([topography, land_frac])
state = static_inputs.get_state()
# Verify coords are NOT stored in state
assert "coords" not in state["fields"][0]
static_inputs_reconstructed = StaticInputs.from_state(state)
assert static_inputs_reconstructed[0].data.equal(static_inputs[0].data)
assert static_inputs_reconstructed[1].data.equal(static_inputs[1].data)


def test_StaticInputs_serialize_backward_compat_with_coords():
    """State dicts that still carry a per-field 'coords' entry must load."""
    values = torch.arange(16, dtype=torch.float32).reshape(4, 4)
    # Mimic the older serialized layout, where each field stored lat/lon
    # coordinates next to its data tensor.
    legacy_coords = {
        axis: torch.arange(4, dtype=torch.float32) for axis in ("lat", "lon")
    }
    legacy_field = {"data": values, "coords": legacy_coords}
    restored = StaticInputs.from_state({"fields": [legacy_field]})
    assert torch.equal(restored[0].data, values)
21 changes: 21 additions & 0 deletions fme/downscaling/data/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,27 @@ def test_ClosedInterval_slice_of(interval, expected_slice):
assert result_slice == expected_slice


@pytest.mark.parametrize(
    ("interval", "expected_values"),
    [
        pytest.param(ClosedInterval(2, 4), torch.tensor([2, 3, 4]), id="middle"),
        pytest.param(
            ClosedInterval(-float("inf"), 2),
            torch.tensor([0, 1, 2]),
            id="start_inf",
        ),
        pytest.param(
            ClosedInterval(4, float("inf")),
            torch.tensor([4]),
            id="end_inf",
        ),
        pytest.param(
            ClosedInterval(-float("inf"), float("inf")),
            torch.arange(5),
            id="all_inf",
        ),
    ],
)
def test_ClosedInterval_subset_of(interval, expected_values):
    """subset_of returns exactly the coordinate values inside the interval."""
    # Coordinates 0..4; each case above lists the values expected to survive.
    grid = torch.arange(5)
    assert torch.equal(interval.subset_of(grid), expected_values)


def test_ClosedInterval_fail_on_empty_slice():
coords = torch.arange(5)
with pytest.raises(ValueError):
Expand Down
8 changes: 8 additions & 0 deletions fme/downscaling/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ def slice_of(self, coords: torch.Tensor) -> slice:
indices = mask.nonzero(as_tuple=True)[0]
return slice(indices[0].item(), indices[-1].item() + 1)

def subset_of(self, coords: torch.Tensor) -> torch.Tensor:
    """
    Return the subset of `coords` that falls within this interval.

    The index range is obtained from `self.slice_of(coords)` and applied
    directly to `coords`, so the result is a contiguous run of the input.
    This assumes `coords` is monotonically increasing.

    Args:
        coords: Coordinate values to select from.

    Returns:
        `coords` indexed by the slice computed for this interval.
    """
    # Named `index_range` rather than `slice` so the builtin `slice` type is
    # not shadowed inside this method.
    index_range = self.slice_of(coords)
    return coords[index_range]


def scale_slice(slice_: slice, scale: int) -> slice:
if slice_ == slice(None):
Expand Down
20 changes: 13 additions & 7 deletions fme/downscaling/inference/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
LossConfig,
NormalizationConfig,
PairedNormalizationConfig,
load_fine_coords_from_path,
)
from fme.downscaling.predictors import PatchPredictionConfig, PatchPredictor
from fme.downscaling.test_evaluator import LinearDownscalingDiffusion
Expand Down Expand Up @@ -64,8 +65,7 @@ def mock_output_target():

def get_static_inputs(shape=(16, 16)):
data = torch.randn(shape)
coords = LatLonCoordinates(lat=torch.arange(shape[0]), lon=torch.arange(shape[1]))
return StaticInputs([StaticInput(data=data, coords=coords)])
return StaticInputs([StaticInput(data=data)])


# Tests for Downscaler initialization
Expand Down Expand Up @@ -201,9 +201,11 @@ def test_run_target_generation_skips_padding_items(
mock_output_target.data.get_generator.return_value = iter([mock_work_item])

mock_model.downscale_factor = 2
mock_model.static_inputs.coords.lat = torch.arange(0, 18).float()
mock_model.static_inputs.coords.lon = torch.arange(0, 18).float()
mock_model.static_inputs.subset_latlon.return_value.fields[0].data = torch.zeros(1)
mock_model.fine_coords = LatLonCoordinates(
lat=torch.arange(0, 18).float(),
lon=torch.arange(0, 18).float(),
)
mock_model.static_inputs = None
mock_model.generate_on_batch_no_target.return_value = {
"var1": torch.zeros(1, 4, 16, 16),
}
Expand Down Expand Up @@ -273,8 +275,12 @@ def checkpointed_model_config(

# loader_config is passed in to add static inputs into model
# that correspond to the dataset coordinates
static_inputs = load_static_inputs({"HGTsfc": f"{data_paths.fine}/data.nc"})
model = model_config.build(coarse_shape, 2, static_inputs=static_inputs)
fine_data_path = f"{data_paths.fine}/data.nc"
static_inputs = load_static_inputs({"HGTsfc": fine_data_path})
fine_coords = load_fine_coords_from_path(fine_data_path)
model = model_config.build(
coarse_shape, 2, static_inputs=static_inputs, fine_coords=fine_coords
)

checkpoint_path = tmp_path / "model_checkpoint.pth"
model.get_state()
Expand Down
Loading
Loading