Skip to content
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
754f28a
Initial shot
frodre Mar 12, 2026
7c5b4c5
Make fine coords required
frodre Mar 13, 2026
c7f58d9
Fine coords required for paired data
frodre Mar 13, 2026
563b57a
Mesh with previous updates in refactor pr
frodre Mar 17, 2026
f4218dd
Simplify event downscaler coordinate in run()
frodre Mar 17, 2026
68cab61
use batch latlon coordinates for coarse
frodre Mar 17, 2026
83ca043
Make fine coord loader public
frodre Mar 17, 2026
794a7d4
BatchLatLon coord access consistency
frodre Mar 17, 2026
b542080
linting
frodre Mar 17, 2026
5add727
Add no coords checkpoint with path test
frodre Mar 17, 2026
36e80db
Small tweaks
frodre Mar 17, 2026
b19c9d6
Add load_fine_coords_from_path test
frodre Mar 17, 2026
7acf87c
Update fme/downscaling/models.py
frodre Mar 18, 2026
eba2749
Update fme/downscaling/models.py
frodre Mar 18, 2026
2c25f1d
Update fme/downscaling/models.py
frodre Mar 18, 2026
81e4eb8
use latlon coords .to method for device fix
frodre Mar 18, 2026
325d324
Redo based on Anna's comments
frodre Mar 19, 2026
1fac745
Remove from_state docstring
frodre Mar 19, 2026
3ef1613
Move all state loading cases into static inputs code
frodre Mar 19, 2026
36e8101
Remove unused function from fme.downscaling.data
frodre Mar 19, 2026
3f95457
fine_coords -> full_fine_coords
frodre Mar 19, 2026
6c9fcf7
Remove duplicated tests from models.py
frodre Mar 19, 2026
4b58b86
Minor fixes
frodre Mar 19, 2026
7e36f55
Fix imports
frodre Mar 19, 2026
a85fb05
Final cleanup
frodre Mar 19, 2026
b1ee3d4
Merge branch 'main' into feature/downscaling-model-fine-coords
frodre Mar 19, 2026
bedd276
Add coordinate validation and use training data as a fallback vs. sta…
frodre Mar 20, 2026
a9dbae3
Updates based on discussion w/ Anna
frodre Mar 21, 2026
46670ff
Merge branch 'main' into feature/downscaling-model-fine-coords
frodre Mar 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fme/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ def __eq__(self, other) -> bool:
def __repr__(self) -> str:
return f"LatLonCoordinates(\n lat={self.lat},\n lon={self.lon}\n"

def to(self, device: str) -> "LatLonCoordinates":
def to(self, device: str | torch.device) -> "LatLonCoordinates":
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make my VSCode linter happy

return LatLonCoordinates(
lon=self.lon.to(device),
lat=self.lat.to(device),
Expand Down
7 changes: 6 additions & 1 deletion fme/downscaling/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
PairedBatchData,
PairedGriddedData,
)
from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range
from fme.downscaling.data.utils import (
ClosedInterval,
adjust_fine_coord_range,
get_latlon_coords_from_properties,
)
from fme.downscaling.requirements import DataRequirements


Expand Down Expand Up @@ -529,6 +533,7 @@ def build(
dims=example.fine.latlon_coordinates.dims,
variable_metadata=variable_metadata,
all_times=all_times,
fine_coords=get_latlon_coords_from_properties(properties_fine),
)

def _get_sampler(
Expand Down
11 changes: 7 additions & 4 deletions fme/downscaling/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,13 @@ def __init__(
f"expected lon_min < {self.lon_interval.start + 360.0}"
)

self._lats_slice = self.lat_interval.slice_of(self._orig_coords.lat)
self._lons_slice = self.lon_interval.slice_of(self._orig_coords.lon)
# Used to subset the data in __getitem__
self._lats_slice = self.lat_interval.slice_from(self._orig_coords.lat)
self._lons_slice = self.lon_interval.slice_from(self._orig_coords.lon)

self._latlon_coordinates = LatLonCoordinates(
lat=self._orig_coords.lat[self._lats_slice],
lon=self._orig_coords.lon[self._lons_slice],
lat=self.lat_interval.subset_of(self._orig_coords.lat),
lon=self.lon_interval.subset_of(self._orig_coords.lon),
)
self._area_weights = self._latlon_coordinates.area_weights

Expand Down Expand Up @@ -337,6 +339,7 @@ class PairedGriddedData:
dims: list[str]
variable_metadata: Mapping[str, VariableMetadata]
all_times: xr.CFTimeIndex
fine_coords: LatLonCoordinates

@property
def loader(self) -> DataLoader[PairedBatchItem]:
Expand Down
135 changes: 56 additions & 79 deletions fme/downscaling/data/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,67 +11,42 @@
@dataclasses.dataclass
class StaticInput:
data: torch.Tensor
coords: LatLonCoordinates

def __post_init__(self):
if len(self.data.shape) != 2:
raise ValueError(f"Topography data must be 2D. Got shape {self.data.shape}")
if self.data.shape[0] != len(self.coords.lat) or self.data.shape[1] != len(
self.coords.lon
):
raise ValueError(
f"Static inputs data shape {self.data.shape} does not match lat/lon "
f"coordinates shape {(len(self.coords.lat), len(self.coords.lon))}"
f"StaticInput data must be 2D. Got shape {self.data.shape}"
)
self._shape = (self.data.shape[0], self.data.shape[1])

@property
def dim(self) -> int:
return len(self.shape)

@property
def shape(self) -> tuple[int, int]:
return self.data.shape
return self._shape

def subset_latlon(
def subset(
self,
lat_interval: ClosedInterval,
lon_interval: ClosedInterval,
lat_slice: slice,
lon_slice: slice,
) -> "StaticInput":
lat_slice = lat_interval.slice_of(self.coords.lat)
lon_slice = lon_interval.slice_of(self.coords.lon)
return self._latlon_index_slice(lat_slice=lat_slice, lon_slice=lon_slice)
return StaticInput(data=self.data[lat_slice, lon_slice])

def to_device(self) -> "StaticInput":
device = get_device()
return StaticInput(
data=self.data.to(device),
coords=LatLonCoordinates(
lat=self.coords.lat.to(device),
lon=self.coords.lon.to(device),
),
)

def _latlon_index_slice(
self,
lat_slice: slice,
lon_slice: slice,
) -> "StaticInput":
sliced_data = self.data[lat_slice, lon_slice]
sliced_latlon = LatLonCoordinates(
lat=self.coords.lat[lat_slice],
lon=self.coords.lon[lon_slice],
)
return StaticInput(
data=sliced_data,
coords=sliced_latlon,
)
return StaticInput(data=self.data.to(device))

def get_state(self) -> dict:
return {
"data": self.data.cpu(),
"coords": self.coords.get_state(),
}

@classmethod
def from_state(cls, state: dict) -> "StaticInput":
return cls(data=state["data"])


def _get_normalized_static_input(path: str, field_name: str):
"""
Expand All @@ -93,96 +68,98 @@ def _get_normalized_static_input(path: str, field_name: str):
f"unexpected shape {static_input.shape} for static input."
"Currently, only lat/lon static input is supported."
)
lat_name, lon_name = static_input.dims[-2:]
coords = LatLonCoordinates(
lon=torch.tensor(static_input[lon_name].values),
lat=torch.tensor(static_input[lat_name].values),
)

static_input_normalized = (static_input - static_input.mean()) / static_input.std()

return StaticInput(
data=torch.tensor(static_input_normalized.values, dtype=torch.float32),
coords=coords,
)


@dataclasses.dataclass
class StaticInputs:
fields: list[StaticInput]
coords: LatLonCoordinates
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not named full_coords because we do produce subsets with this class.


def __post_init__(self):
for i, field in enumerate(self.fields[1:]):
if field.coords != self.fields[0].coords:
if field.shape != self.fields[0].shape:
raise ValueError(
f"All StaticInput fields must have the same coordinates. "
f"Fields {i} and 0 do not match coordinates."
f"All StaticInput fields must have the same shape. "
f"Fields {i + 1} and 0 do not match shapes."
)
if self.fields and self.coords.shape != self.fields[0].shape:
raise ValueError(
f"Coordinates shape {self.coords.shape} does not match fields shape "
f"{self.fields[0].shape} for StaticInputs."
)

def __getitem__(self, index: int):
return self.fields[index]

@property
def coords(self) -> LatLonCoordinates:
if len(self.fields) == 0:
raise ValueError("No fields in StaticInputs to get coordinates from.")
return self.fields[0].coords

@property
def shape(self) -> tuple[int, int]:
if len(self.fields) == 0:
raise ValueError("No fields in StaticInputs to get shape from.")
return self.fields[0].shape

def subset_latlon(
def subset(
self,
lat_interval: ClosedInterval,
lon_interval: ClosedInterval,
) -> "StaticInputs":
lat_slice = lat_interval.slice_from(self.coords.lat)
lon_slice = lon_interval.slice_from(self.coords.lon)
return StaticInputs(
fields=[
field.subset_latlon(lat_interval, lon_interval) for field in self.fields
]
fields=[field.subset(lat_slice, lon_slice) for field in self.fields],
coords=LatLonCoordinates(
lat=lat_interval.subset_of(self.coords.lat),
lon=lon_interval.subset_of(self.coords.lon),
),
)

def to_device(self) -> "StaticInputs":
return StaticInputs(fields=[field.to_device() for field in self.fields])
return StaticInputs(
fields=[field.to_device() for field in self.fields],
coords=self.coords.to(get_device()),
)

def get_state(self) -> dict:
return {
"fields": [field.get_state() for field in self.fields],
"coords": self.coords.get_state(),
}

@classmethod
def from_state(cls, state: dict) -> "StaticInputs":
"""Reconstruct StaticInputs from a state dict.

Args:
state: State dict from get_state().
coords: Override coordinates. If None, reads coords from the state dict.
Pass explicitly when loading old-format checkpoints that stored coords
outside of the StaticInputs state.
"""
return cls(
fields=[
StaticInput(
data=field_state["data"],
coords=LatLonCoordinates(
lat=field_state["coords"]["lat"],
lon=field_state["coords"]["lon"],
),
)
for field_state in state["fields"]
]
StaticInput.from_state(field_state) for field_state in state["fields"]
],
coords=LatLonCoordinates(
lat=state["coords"]["lat"],
lon=state["coords"]["lon"],
),
)


def load_static_inputs(
static_inputs_config: dict[str, str] | None,
) -> StaticInputs | None:
static_inputs_config: dict[str, str], coords: LatLonCoordinates
) -> StaticInputs:
"""
Load normalized static inputs from a mapping of field names to file paths.
Returns None if the input config is empty.
Returns an empty StaticInputs (no fields) if the config is None or empty.
"""
# TODO: consolidate/simplify empty StaticInputs vs. None handling in
# downscaling code
if not static_inputs_config:
return None
return StaticInputs(
fields=[
_get_normalized_static_input(path, field_name)
for field_name, path in static_inputs_config.items()
]
)
fields = [
_get_normalized_static_input(path, field_name)
for field_name, path in static_inputs_config.items()
]
return StaticInputs(fields=fields, coords=coords)
81 changes: 46 additions & 35 deletions fme/downscaling/data/test_static.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,75 @@
import pytest
import torch

from fme.core.coordinates import LatLonCoordinates

from .static import StaticInput, StaticInputs
from .utils import ClosedInterval


@pytest.mark.parametrize(
"init_args",
[
pytest.param(
[
torch.randn((1, 2, 2)),
LatLonCoordinates(torch.arange(2), torch.arange(2)),
],
[torch.randn((1, 2, 2))],
id="3d_data",
),
pytest.param(
[torch.randn((2, 2)), LatLonCoordinates(torch.arange(2), torch.arange(5))],
id="dim_size_mismatch",
),
],
)
def test_Topography_error_cases(init_args):
with pytest.raises(ValueError):
StaticInput(*init_args)


def test_subset_latlon():
def test_subset():
full_data_shape = (10, 10)
expected_slices = [slice(2, 6), slice(3, 8)]
data = torch.randn(*full_data_shape)
coords = LatLonCoordinates(
lat=torch.linspace(0, 9, 10), lon=torch.linspace(0, 9, 10)
)
topo = StaticInput(data=data, coords=coords)
lat_interval = ClosedInterval(2, 5)
lon_interval = ClosedInterval(3, 7)
subset_topo = topo.subset_latlon(lat_interval, lon_interval)
expected_lats = torch.tensor([2, 3, 4, 5], dtype=coords.lat.dtype)
expected_lons = torch.tensor([3, 4, 5, 6, 7], dtype=coords.lon.dtype)
expected_data = data[*expected_slices]
assert torch.equal(subset_topo.coords.lat, expected_lats)
assert torch.equal(subset_topo.coords.lon, expected_lons)
assert torch.allclose(subset_topo.data, expected_data)
topo = StaticInput(data=data)
lat_slice = slice(2, 6)
lon_slice = slice(3, 8)
subset_topo = topo.subset(lat_slice, lon_slice)
assert torch.allclose(subset_topo.data, data[lat_slice, lon_slice])


def test_StaticInputs_serialize():
data = torch.arange(16).reshape(4, 4)
topography = StaticInput(
data,
LatLonCoordinates(torch.arange(4), torch.arange(4)),
)
land_frac = StaticInput(
data * -1.0,
LatLonCoordinates(torch.arange(4), torch.arange(4)),
from fme.core.coordinates import LatLonCoordinates

data = torch.arange(16, dtype=torch.float32).reshape(4, 4)
coords = LatLonCoordinates(
lat=torch.arange(4, dtype=torch.float32),
lon=torch.arange(4, dtype=torch.float32),
)
static_inputs = StaticInputs([topography, land_frac])
topography = StaticInput(data)
land_frac = StaticInput(data * -1.0)
static_inputs = StaticInputs([topography, land_frac], coords=coords)
state = static_inputs.get_state()
# Verify coords are stored at the top level, not inside each field
assert "coords" in state
assert "coords" not in state["fields"][0]
static_inputs_reconstructed = StaticInputs.from_state(state)
assert static_inputs_reconstructed[0].data.equal(static_inputs[0].data)
assert static_inputs_reconstructed[1].data.equal(static_inputs[1].data)


def test_StaticInputs_serialize_backward_compat_with_coords():
"""from_state should silently ignore 'coords' key in fields for old state dicts."""
from fme.core.coordinates import LatLonCoordinates

data = torch.arange(16, dtype=torch.float32).reshape(4, 4)
coords = LatLonCoordinates(
lat=torch.arange(4, dtype=torch.float32),
lon=torch.arange(4, dtype=torch.float32),
)
# Simulate old state dict format that included coords inside fields.
# from_state should silently ignore extra keys (like 'coords') in field dicts.
old_state = {
"fields": [
{
"data": data,
"coords": {
"lat": torch.arange(4, dtype=torch.float32),
"lon": torch.arange(4, dtype=torch.float32),
},
}
],
"coords": coords.get_state(),
}
static_inputs = StaticInputs.from_state(old_state)
assert torch.equal(static_inputs[0].data, data)
Loading
Loading