Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fme/downscaling/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
PairedBatchItem,
PairedGriddedData,
)
from .static import StaticInput, StaticInputs, load_static_inputs
from .static import StaticInput, StaticInputs, load_coords_from_path, load_static_inputs
from .utils import (
BatchedLatLonCoordinates,
ClosedInterval,
Expand Down
7 changes: 6 additions & 1 deletion fme/downscaling/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
PairedBatchData,
PairedGriddedData,
)
from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range
from fme.downscaling.data.utils import (
ClosedInterval,
adjust_fine_coord_range,
get_latlon_coords_from_properties,
)
from fme.downscaling.requirements import DataRequirements


Expand Down Expand Up @@ -529,6 +533,7 @@ def build(
dims=example.fine.latlon_coordinates.dims,
variable_metadata=variable_metadata,
all_times=all_times,
fine_coords=get_latlon_coords_from_properties(properties_fine),
)

def _get_sampler(
Expand Down
1 change: 1 addition & 0 deletions fme/downscaling/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ class PairedGriddedData:
dims: list[str]
variable_metadata: Mapping[str, VariableMetadata]
all_times: xr.CFTimeIndex
fine_coords: LatLonCoordinates

@property
def loader(self) -> DataLoader[PairedBatchItem]:
Expand Down
21 changes: 15 additions & 6 deletions fme/downscaling/data/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,22 @@ def _load_coords_from_ds(ds: xr.Dataset) -> LatLonCoordinates:
)


def _open_ds_from_path(path: str) -> xr.Dataset:
if path.endswith(".zarr"):
ds = xr.open_zarr(path, mask_and_scale=False)
else:
ds = xr.open_dataset(path, mask_and_scale=False)
return ds


def load_coords_from_path(path: str) -> LatLonCoordinates:
Copy link
Collaborator Author

@frodre frodre Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-added to use in CheckpointModelConfig. Added the small helper _open_ds_from_path to de-dupe code

ds = _open_ds_from_path(path)
return _load_coords_from_ds(ds)


def _get_normalized_static_input(
path: str, field_name: str
) -> tuple["StaticInput", LatLonCoordinates]:
) -> tuple[StaticInput, LatLonCoordinates]:
"""
Load a static input field from a given file path and field name and
normalize it.
Expand All @@ -74,11 +87,7 @@ def _get_normalized_static_input(

Raises ValueError if lat/lon coordinates are not found in the dataset.
"""
if path.endswith(".zarr"):
ds = xr.open_zarr(path, mask_and_scale=False)
else:
ds = xr.open_dataset(path, mask_and_scale=False)

ds = _open_ds_from_path(path)
coords = _load_coords_from_ds(ds)
da = ds[field_name]

Expand Down
7 changes: 6 additions & 1 deletion fme/downscaling/inference/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,12 @@ def checkpointed_model_config(
# loader_config is passed in to add static inputs into model
# that correspond to the dataset coordinates
static_inputs = load_static_inputs({"HGTsfc": f"{data_paths.fine}/data.nc"})
model = model_config.build(coarse_shape, 2, static_inputs=static_inputs)
model = model_config.build(
coarse_shape,
2,
full_fine_coords=static_inputs.coords,
static_inputs=static_inputs,
)

checkpoint_path = tmp_path / "model_checkpoint.pth"
model.get_state()
Expand Down
Loading
Loading