Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fme/downscaling/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
PairedBatchItem,
PairedGriddedData,
)
from .static import StaticInput, StaticInputs, load_static_inputs
from .static import StaticInput, StaticInputs, load_coords_from_path, load_static_inputs
from .utils import (
BatchedLatLonCoordinates,
ClosedInterval,
Expand Down
7 changes: 6 additions & 1 deletion fme/downscaling/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
PairedBatchData,
PairedGriddedData,
)
from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range
from fme.downscaling.data.utils import (
ClosedInterval,
adjust_fine_coord_range,
get_latlon_coords_from_properties,
)
from fme.downscaling.requirements import DataRequirements


Expand Down Expand Up @@ -529,6 +533,7 @@ def build(
dims=example.fine.latlon_coordinates.dims,
variable_metadata=variable_metadata,
all_times=all_times,
fine_coords=get_latlon_coords_from_properties(properties_fine),
)

def _get_sampler(
Expand Down
1 change: 1 addition & 0 deletions fme/downscaling/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ class PairedGriddedData:
dims: list[str]
variable_metadata: Mapping[str, VariableMetadata]
all_times: xr.CFTimeIndex
fine_coords: LatLonCoordinates

@property
def loader(self) -> DataLoader[PairedBatchItem]:
Expand Down
21 changes: 15 additions & 6 deletions fme/downscaling/data/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,22 @@ def _load_coords_from_ds(ds: xr.Dataset) -> LatLonCoordinates:
)


def _open_ds_from_path(path: str) -> xr.Dataset:
if path.endswith(".zarr"):
ds = xr.open_zarr(path, mask_and_scale=False)
else:
ds = xr.open_dataset(path, mask_and_scale=False)
return ds


def load_coords_from_path(path: str) -> LatLonCoordinates:
Copy link
Copy Markdown
Collaborator Author

@frodre frodre Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-added to use in CheckpointModelConfig. Added the small helper _open_ds_from_path to de-dupe code

ds = _open_ds_from_path(path)
return _load_coords_from_ds(ds)


def _get_normalized_static_input(
path: str, field_name: str
) -> tuple["StaticInput", LatLonCoordinates]:
) -> tuple[StaticInput, LatLonCoordinates]:
"""
Load a static input field from a given file path and field name and
normalize it.
Expand All @@ -74,11 +87,7 @@ def _get_normalized_static_input(

Raises ValueError if lat/lon coordinates are not found in the dataset.
"""
if path.endswith(".zarr"):
ds = xr.open_zarr(path, mask_and_scale=False)
else:
ds = xr.open_dataset(path, mask_and_scale=False)

ds = _open_ds_from_path(path)
coords = _load_coords_from_ds(ds)
da = ds[field_name]

Expand Down
7 changes: 6 additions & 1 deletion fme/downscaling/inference/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,12 @@ def checkpointed_model_config(
# loader_config is passed in to add static inputs into model
# that correspond to the dataset coordinates
static_inputs = load_static_inputs({"HGTsfc": f"{data_paths.fine}/data.nc"})
model = model_config.build(coarse_shape, 2, static_inputs=static_inputs)
model = model_config.build(
coarse_shape,
2,
full_fine_coords=static_inputs.coords,
static_inputs=static_inputs,
)

checkpoint_path = tmp_path / "model_checkpoint.pth"
model.get_state()
Expand Down
Loading
Loading