Skip to content
Open
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
754f28a
Initial shot
frodre Mar 12, 2026
7c5b4c5
Make fine coords required
frodre Mar 13, 2026
c7f58d9
Fine coords required for paired data
frodre Mar 13, 2026
563b57a
Mesh with previous updates in refactor pr
frodre Mar 17, 2026
f4218dd
Simplify event downscaler coordinate in run()
frodre Mar 17, 2026
68cab61
use batch latlon coordinates for coarse
frodre Mar 17, 2026
83ca043
Make fine coord loader public
frodre Mar 17, 2026
794a7d4
BatchLatLon coord access consistency
frodre Mar 17, 2026
b542080
linting
frodre Mar 17, 2026
5add727
Add no coords checkpoint with path test
frodre Mar 17, 2026
36e80db
Small tweaks
frodre Mar 17, 2026
b19c9d6
Add load_fine_coords_from_path test
frodre Mar 17, 2026
7acf87c
Update fme/downscaling/models.py
frodre Mar 18, 2026
eba2749
Update fme/downscaling/models.py
frodre Mar 18, 2026
2c25f1d
Update fme/downscaling/models.py
frodre Mar 18, 2026
81e4eb8
use latlon coords .to method for device fix
frodre Mar 18, 2026
325d324
Redo based on Anna's comments
frodre Mar 19, 2026
1fac745
Remove from_state docstring
frodre Mar 19, 2026
3ef1613
Move all state loading cases into static inputs code
frodre Mar 19, 2026
36e8101
Remove unused function from fme.downscaling.data
frodre Mar 19, 2026
3f95457
fine_coords -> full_fine_coords
frodre Mar 19, 2026
6c9fcf7
Remove duplicated tests from models.py
frodre Mar 19, 2026
4b58b86
Minor fixes
frodre Mar 19, 2026
7e36f55
Fix imports
frodre Mar 19, 2026
a85fb05
Final cleanup
frodre Mar 19, 2026
b1ee3d4
Merge branch 'main' into feature/downscaling-model-fine-coords
frodre Mar 19, 2026
bedd276
Add coordinate validation and use training data as a fallback vs. sta…
frodre Mar 20, 2026
a9dbae3
Updates based on discussion w/ Anna
frodre Mar 21, 2026
46670ff
Merge branch 'main' into feature/downscaling-model-fine-coords
frodre Mar 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fme/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def __eq__(self, other) -> bool:
def __repr__(self) -> str:
return f"LatLonCoordinates(\n lat={self.lat},\n lon={self.lon}\n"

def to(self, device: str) -> "LatLonCoordinates":
def to(self, device: str | torch.device) -> "LatLonCoordinates":
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make my VSCode linter happy

return LatLonCoordinates(
lon=self.lon.to(device),
lat=self.lat.to(device),
Expand Down
7 changes: 6 additions & 1 deletion fme/downscaling/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
PairedBatchItem,
PairedGriddedData,
)
from .static import StaticInput, StaticInputs, load_static_inputs
from .static import (
StaticInput,
StaticInputs,
load_fine_coords_from_path,
load_static_inputs,
)
from .utils import (
BatchedLatLonCoordinates,
ClosedInterval,
Expand Down
7 changes: 6 additions & 1 deletion fme/downscaling/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
PairedBatchData,
PairedGriddedData,
)
from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range
from fme.downscaling.data.utils import (
ClosedInterval,
adjust_fine_coord_range,
get_latlon_coords_from_properties,
)
from fme.downscaling.requirements import DataRequirements


Expand Down Expand Up @@ -529,6 +533,7 @@ def build(
dims=example.fine.latlon_coordinates.dims,
variable_metadata=variable_metadata,
all_times=all_times,
fine_coords=get_latlon_coords_from_properties(properties_fine),
)

def _get_sampler(
Expand Down
11 changes: 7 additions & 4 deletions fme/downscaling/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,13 @@ def __init__(
f"expected lon_min < {self.lon_interval.start + 360.0}"
)

self._lats_slice = self.lat_interval.slice_of(self._orig_coords.lat)
self._lons_slice = self.lon_interval.slice_of(self._orig_coords.lon)
# Used to subset the data in __getitem__
self._lats_slice = self.lat_interval.slice_from(self._orig_coords.lat)
self._lons_slice = self.lon_interval.slice_from(self._orig_coords.lon)

self._latlon_coordinates = LatLonCoordinates(
lat=self._orig_coords.lat[self._lats_slice],
lon=self._orig_coords.lon[self._lons_slice],
lat=self.lat_interval.subset_of(self._orig_coords.lat),
lon=self.lon_interval.subset_of(self._orig_coords.lon),
)
self._area_weights = self._latlon_coordinates.area_weights

Expand Down Expand Up @@ -337,6 +339,7 @@ class PairedGriddedData:
dims: list[str]
variable_metadata: Mapping[str, VariableMetadata]
all_times: xr.CFTimeIndex
fine_coords: LatLonCoordinates

@property
def loader(self) -> DataLoader[PairedBatchItem]:
Expand Down
216 changes: 137 additions & 79 deletions fme/downscaling/data/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,67 +11,42 @@
@dataclasses.dataclass
class StaticInput:
data: torch.Tensor
coords: LatLonCoordinates

def __post_init__(self):
if len(self.data.shape) != 2:
raise ValueError(f"Topography data must be 2D. Got shape {self.data.shape}")
if self.data.shape[0] != len(self.coords.lat) or self.data.shape[1] != len(
self.coords.lon
):
raise ValueError(
f"Static inputs data shape {self.data.shape} does not match lat/lon "
f"coordinates shape {(len(self.coords.lat), len(self.coords.lon))}"
f"StaticInput data must be 2D. Got shape {self.data.shape}"
)
self._shape = (self.data.shape[0], self.data.shape[1])

@property
def dim(self) -> int:
return len(self.shape)

@property
def shape(self) -> tuple[int, int]:
return self.data.shape
return self._shape

def subset_latlon(
def subset(
self,
lat_interval: ClosedInterval,
lon_interval: ClosedInterval,
lat_slice: slice,
lon_slice: slice,
) -> "StaticInput":
lat_slice = lat_interval.slice_of(self.coords.lat)
lon_slice = lon_interval.slice_of(self.coords.lon)
return self._latlon_index_slice(lat_slice=lat_slice, lon_slice=lon_slice)
return StaticInput(data=self.data[lat_slice, lon_slice])

def to_device(self) -> "StaticInput":
device = get_device()
return StaticInput(
data=self.data.to(device),
coords=LatLonCoordinates(
lat=self.coords.lat.to(device),
lon=self.coords.lon.to(device),
),
)

def _latlon_index_slice(
self,
lat_slice: slice,
lon_slice: slice,
) -> "StaticInput":
sliced_data = self.data[lat_slice, lon_slice]
sliced_latlon = LatLonCoordinates(
lat=self.coords.lat[lat_slice],
lon=self.coords.lon[lon_slice],
)
return StaticInput(
data=sliced_data,
coords=sliced_latlon,
)
return StaticInput(data=self.data.to(device))

def get_state(self) -> dict:
return {
"data": self.data.cpu(),
"coords": self.coords.get_state(),
}

@classmethod
def from_state(cls, state: dict) -> "StaticInput":
return cls(data=state["data"])


def _get_normalized_static_input(path: str, field_name: str):
"""
Expand All @@ -93,96 +68,179 @@ def _get_normalized_static_input(path: str, field_name: str):
f"unexpected shape {static_input.shape} for static input."
"Currently, only lat/lon static input is supported."
)
lat_name, lon_name = static_input.dims[-2:]
coords = LatLonCoordinates(
lon=torch.tensor(static_input[lon_name].values),
lat=torch.tensor(static_input[lat_name].values),
)

static_input_normalized = (static_input - static_input.mean()) / static_input.std()

return StaticInput(
data=torch.tensor(static_input_normalized.values, dtype=torch.float32),
coords=coords,
)


def _has_legacy_coords_in_state(state: dict) -> bool:
    """Return True if ``state`` stores coordinates in the legacy per-field
    location (``state["fields"][0]["coords"]``) instead of at the top level.
    """
    # bool() is required: with `and` short-circuiting, an empty
    # state["fields"] list would otherwise be returned as-is, violating
    # the annotated -> bool return type.
    return bool(
        "fields" in state and state["fields"] and "coords" in state["fields"][0]
    )


def _sync_state_coordinates(state: dict) -> dict:
    """Normalize a serialized StaticInputs state to the current layout.

    Legacy checkpoints kept coordinates on each individual field; when that
    layout is detected, copy the first field's coordinates up to the
    top-level ``"coords"`` key expected by the current loader. The input
    mapping is not mutated; a shallow copy is returned.
    """
    normalized = state.copy()
    if _has_legacy_coords_in_state(normalized):
        normalized["coords"] = normalized["fields"][0]["coords"]
    return normalized


def _has_coords_in_state(state: dict) -> bool:
    """Return True if coordinates are present in ``state``, in either the
    current top-level location or the legacy per-field location.
    """
    # Collapses the `if ...: return True / else: return False` anti-pattern.
    # bool() guards against `or` short-circuiting returning a non-bool
    # truthy value from the legacy check.
    return bool("coords" in state or _has_legacy_coords_in_state(state))


def load_fine_coords_from_path(path: str) -> LatLonCoordinates:
    """Load fine-grid lat/lon coordinates from a zarr store or netCDF file.

    Args:
        path: Location of the dataset. A ``.zarr`` suffix selects the zarr
            reader; any other path is opened with xarray's generic reader.

    Returns:
        LatLonCoordinates built from the first recognized coordinate names,
        with values converted to float32 tensors.

    Raises:
        ValueError: If no recognized lat/lon coordinate names are found.
    """
    lat_candidates = ("lat", "latitude", "grid_yt")
    lon_candidates = ("lon", "longitude", "grid_xt")
    open_fn = xr.open_zarr if path.endswith(".zarr") else xr.open_dataset
    # Context manager ensures the underlying file handle is released once
    # the coordinate values have been read into memory (the original left
    # the dataset open).
    with open_fn(path) as ds:
        lat_name = next((n for n in lat_candidates if n in ds.coords), None)
        lon_name = next((n for n in lon_candidates if n in ds.coords), None)
        if lat_name is None or lon_name is None:
            raise ValueError(
                f"Could not find lat/lon coordinates in {path}. "
                "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'."
            )
        return LatLonCoordinates(
            lat=torch.tensor(ds[lat_name].values, dtype=torch.float32),
            lon=torch.tensor(ds[lon_name].values, dtype=torch.float32),
        )


@dataclasses.dataclass
class StaticInputs:
fields: list[StaticInput]
coords: LatLonCoordinates
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not named full_coords because we do produce subsets with this class.


def __post_init__(self):
for i, field in enumerate(self.fields[1:]):
if field.coords != self.fields[0].coords:
if field.shape != self.fields[0].shape:
raise ValueError(
f"All StaticInput fields must have the same coordinates. "
f"Fields {i} and 0 do not match coordinates."
f"All StaticInput fields must have the same shape. "
f"Fields {i + 1} and 0 do not match shapes."
)
if self.fields and self.coords.shape != self.fields[0].shape:
raise ValueError(
f"Coordinates shape {self.coords.shape} does not match fields shape "
f"{self.fields[0].shape} for StaticInputs."
)

def __getitem__(self, index: int):
return self.fields[index]

@property
def coords(self) -> LatLonCoordinates:
if len(self.fields) == 0:
raise ValueError("No fields in StaticInputs to get coordinates from.")
return self.fields[0].coords

@property
def shape(self) -> tuple[int, int]:
if len(self.fields) == 0:
raise ValueError("No fields in StaticInputs to get shape from.")
return self.fields[0].shape

def subset_latlon(
def subset(
self,
lat_interval: ClosedInterval,
lon_interval: ClosedInterval,
) -> "StaticInputs":
lat_slice = lat_interval.slice_from(self.coords.lat)
lon_slice = lon_interval.slice_from(self.coords.lon)
return StaticInputs(
fields=[
field.subset_latlon(lat_interval, lon_interval) for field in self.fields
]
fields=[field.subset(lat_slice, lon_slice) for field in self.fields],
coords=LatLonCoordinates(
lat=lat_interval.subset_of(self.coords.lat),
lon=lon_interval.subset_of(self.coords.lon),
),
)

def to_device(self) -> "StaticInputs":
return StaticInputs(fields=[field.to_device() for field in self.fields])
return StaticInputs(
fields=[field.to_device() for field in self.fields],
coords=self.coords.to(get_device()),
)

def get_state(self) -> dict:
return {
"fields": [field.get_state() for field in self.fields],
"coords": self.coords.get_state(),
}

@classmethod
def from_state(cls, state: dict) -> "StaticInputs":
if not _has_coords_in_state(state):
raise ValueError(
"No coordinates found in state for StaticInputs. Load with "
"from_state_backwards_compatible if loading from a checkpoint "
"saved prior to current coordinate serialization format."
)
state = _sync_state_coordinates(state)
return cls(
fields=[
StaticInput(
data=field_state["data"],
coords=LatLonCoordinates(
lat=field_state["coords"]["lat"],
lon=field_state["coords"]["lon"],
),
)
for field_state in state["fields"]
]
StaticInput.from_state(field_state) for field_state in state["fields"]
],
coords=LatLonCoordinates(
lat=state["coords"]["lat"],
lon=state["coords"]["lon"],
),
)

@classmethod
def from_state_backwards_compatible(
    cls,
    state: dict,
    static_inputs_config: dict[str, str],
    fine_coordinates_path: str | None,
) -> "StaticInputs":
    """Build StaticInputs from checkpoints that may predate the current
    coordinate serialization format.

    Exactly one source of field data (serialized ``state`` vs.
    ``static_inputs_config``) and exactly one source of coordinates
    (coordinates embedded in ``state`` vs. ``fine_coordinates_path``)
    may be supplied; anything else raises.

    Args:
        state: Serialized StaticInputs state, possibly empty. May carry
            coordinates at the top level or in the legacy per-field layout.
        static_inputs_config: Mapping of field name to file path used to
            load fields when the checkpoint has none.
        fine_coordinates_path: Path to a dataset providing fine-grid
            coordinates when ``state`` has none.

    Raises:
        ValueError: If both field sources, both coordinate sources, or
            neither coordinate source is provided.
    """
    if state and static_inputs_config:
        raise ValueError(
            "Checkpoint contains static inputs but static_inputs_config is "
            "also provided. Backwards compatibility loading only supports "
            "a single source of StaticInputs info."
        )

    if fine_coordinates_path and _has_coords_in_state(state):
        raise ValueError(
            "State contains coordinates but fine_coordinates_path is also provided."
            " Only one source of coordinate info can be used for backwards "
            "compatibility loading of StaticInputs."
        )
    elif not _has_coords_in_state(state) and not fine_coordinates_path:
        raise ValueError(
            "No coordinates found in state and no fine_coordinates_path provided. "
            "Cannot load StaticInputs without coordinates."
        )

    # All compatibility cases:
    # Serialized StaticInputs exist, which always had coordinates stored
    # No serialized static inputs or specified inputs, load coordinates
    # Specified static input fields and specified coordinates

    if _has_coords_in_state(state):
        # from_state handles both the current and legacy coordinate layouts.
        return cls.from_state(state)
    else:
        assert fine_coordinates_path is not None  # for type checker
        coords = load_fine_coords_from_path(fine_coordinates_path)

        if static_inputs_config:
            return load_static_inputs(static_inputs_config, coords)
        else:
            # Coordinates only; no static input fields configured.
            return cls(fields=[], coords=coords)


def load_static_inputs(
static_inputs_config: dict[str, str] | None,
) -> StaticInputs | None:
static_inputs_config: dict[str, str], coords: LatLonCoordinates
) -> StaticInputs:
"""
Load normalized static inputs from a mapping of field names to file paths.
Returns None if the input config is empty.
Returns an empty StaticInputs (no fields) if the config is empty.
"""
# TODO: consolidate/simplify empty StaticInputs vs. None handling in
# downscaling code
if not static_inputs_config:
return None
return StaticInputs(
fields=[
_get_normalized_static_input(path, field_name)
for field_name, path in static_inputs_config.items()
]
)
fields = [
_get_normalized_static_input(path, field_name)
for field_name, path in static_inputs_config.items()
]
return StaticInputs(fields=fields, coords=coords)
Loading
Loading