Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion fme/downscaling/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
PairedBatchItem,
PairedGriddedData,
)
from .static import StaticInput, StaticInputs, load_static_inputs
from .static import (
StaticInput,
StaticInputs,
load_fine_coords_from_path,
load_static_inputs,
)
from .utils import (
BatchedLatLonCoordinates,
ClosedInterval,
Expand Down
239 changes: 150 additions & 89 deletions fme/downscaling/data/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,11 @@
@dataclasses.dataclass
class StaticInput:
data: torch.Tensor
coords: LatLonCoordinates

def __post_init__(self):
if len(self.data.shape) != 2:
raise ValueError(f"Topography data must be 2D. Got shape {self.data.shape}")
if self.data.shape[0] != len(self.coords.lat) or self.data.shape[1] != len(
self.coords.lon
):
raise ValueError(
f"Static inputs data shape {self.data.shape} does not match lat/lon "
f"coordinates shape {(len(self.coords.lat), len(self.coords.lon))}"
f"StaticInput data must be 2D. Got shape {self.data.shape}"
)

@property
Expand All @@ -34,100 +28,127 @@ def shape(self) -> tuple[int, int]:

def subset(
self,
lat_interval: ClosedInterval,
lon_interval: ClosedInterval,
lat_slice: slice,
lon_slice: slice,
) -> "StaticInput":
lat_slice = lat_interval.slice_from(self.coords.lat)
lon_slice = lon_interval.slice_from(self.coords.lon)
return self._latlon_index_slice(lat_slice=lat_slice, lon_slice=lon_slice)
return StaticInput(data=self.data[lat_slice, lon_slice])

def to_device(self) -> "StaticInput":
    """Return a new StaticInput with ``data`` moved to the device given by ``get_device()``."""
    return StaticInput(data=self.data.to(get_device()))

def _latlon_index_slice(
self,
lat_slice: slice,
lon_slice: slice,
) -> "StaticInput":
sliced_data = self.data[lat_slice, lon_slice]
sliced_latlon = LatLonCoordinates(
lat=self.coords.lat[lat_slice],
lon=self.coords.lon[lon_slice],
)
return StaticInput(
data=sliced_data,
coords=sliced_latlon,
def get_state(self) -> dict:
return {"data": self.data.cpu()}

@classmethod
def from_state(cls, state: dict) -> "StaticInput":
return cls(data=state["data"])


_LAT_NAMES = ("lat", "latitude", "grid_yt")
_LON_NAMES = ("lon", "longitude", "grid_xt")


def _load_coords_from_ds(ds: xr.Dataset) -> LatLonCoordinates:
    """Read lat/lon coordinates from ``ds`` as float32 tensors.

    The first entry of ``_LAT_NAMES`` / ``_LON_NAMES`` found in ``ds.coords``
    is used for each axis.

    Raises:
        ValueError: if no known latitude name or no known longitude name is
            present in ``ds.coords``.
    """

    def first_known(names):
        # Candidate order matters: earlier names win.
        for name in names:
            if name in ds.coords:
                return name
        return None

    lat_name = first_known(_LAT_NAMES)
    lon_name = first_known(_LON_NAMES)
    if lat_name is None or lon_name is None:
        raise ValueError(
            "Could not find lat/lon coordinates in dataset. "
            f"Expected one of {_LAT_NAMES} for lat and {_LON_NAMES} for lon."
        )
    lat = torch.tensor(ds[lat_name].values, dtype=torch.float32)
    lon = torch.tensor(ds[lon_name].values, dtype=torch.float32)
    return LatLonCoordinates(lat=lat, lon=lon)

def get_state(self) -> dict:
return {
"data": self.data.cpu(),
"coords": self.coords.get_state(),
}

def load_fine_coords_from_path(path: str) -> LatLonCoordinates:
    """Load lat/lon coordinates from a netCDF or zarr file.

    Paths ending in ``.zarr`` are opened with ``xr.open_zarr``; anything else
    is opened with ``xr.open_dataset``.

    Raises:
        ValueError: propagated from ``_load_coords_from_ds`` when the dataset
            has no recognized lat/lon coordinates.
    """
    opener = xr.open_zarr if path.endswith(".zarr") else xr.open_dataset
    # The coordinate values are copied into tensors inside the ``with`` block,
    # so the dataset (and its file handle) can be closed before returning.
    with opener(path) as ds:
        return _load_coords_from_ds(ds)


def _get_normalized_static_input(
    path: str, field_name: str
) -> tuple["StaticInput", LatLonCoordinates]:
    """
    Load a static input field from a given file path and field name and
    normalize it.

    Only supports 2D lat/lon static inputs. If the input has a time dimension, it is
    squeezed by taking the first time step.

    Returns the normalized field (mean subtracted, divided by the standard
    deviation) together with the lat/lon coordinates of the source dataset.

    Raises ValueError if lat/lon coordinates are not found in the dataset or
    if the field is not 2D after the optional time selection.
    """
    opener = xr.open_zarr if path.endswith(".zarr") else xr.open_dataset
    # Everything that touches the dataset happens inside the ``with`` block so
    # the values are materialized before the underlying store is closed.
    with opener(path, mask_and_scale=False) as ds:
        coords = _load_coords_from_ds(ds)
        da = ds[field_name]

        if "time" in da.dims:
            da = da.isel(time=0).squeeze()
        if len(da.shape) != 2:
            raise ValueError(
                f"unexpected shape {da.shape} for static input. "
                "Currently, only lat/lon static input is supported."
            )

        # NOTE(review): a constant field has std == 0 and would produce
        # non-finite values here; not changed, confirm whether that can occur.
        static_input_normalized = (da - da.mean()) / da.std()
        data = torch.tensor(static_input_normalized.values, dtype=torch.float32)
    return StaticInput(data=data), coords


def _has_legacy_coords_in_state(state: dict) -> bool:
return bool(state.get("fields")) and "coords" in state["fields"][0]


def _has_coords_in_state(state: dict) -> bool:
    """Return True if ``state`` carries coordinates at top level or per field (legacy)."""
    if "coords" in state:
        return True
    return _has_legacy_coords_in_state(state)


def _sync_state_coordinates(state: dict) -> dict:
    """Migrate old per-field coord format to top-level coords format.

    When the first field carries ``"coords"``, a shallow copy of ``state`` is
    returned with ``"coords"`` set to that field's coords; otherwise ``state``
    itself is returned unchanged.
    """
    if not _has_legacy_coords_in_state(state):
        return state
    # Shallow copy so the caller's dict does not gain a new key.
    migrated = dict(state)
    migrated["coords"] = migrated["fields"][0]["coords"]
    return migrated


def _validate_coords(
    case: str, coord1: LatLonCoordinates, coord2: LatLonCoordinates
) -> None:
    """Raise if the two coordinate objects compare unequal.

    Args:
        case: label included in the error message to identify the mismatch.
        coord1: first coordinates to compare.
        coord2: second coordinates to compare.

    Raises:
        ValueError: if ``coord1 != coord2``.
    """
    mismatch = coord1 != coord2
    if mismatch:
        raise ValueError(f"Coordinates do not match between static inputs: {case}")


@dataclasses.dataclass
class StaticInputs:
fields: list[StaticInput]
coords: LatLonCoordinates

def __post_init__(self):
for i, field in enumerate(self.fields[1:]):
if field.coords != self.fields[0].coords:
if field.shape != self.fields[0].shape:
raise ValueError(
f"All StaticInput fields must have the same coordinates. "
f"Fields {i} and 0 do not match coordinates."
f"All StaticInput fields must have the same shape. "
f"Fields {i + 1} and 0 do not match shapes."
)
if self.fields and self.coords.shape != self.fields[0].shape:
raise ValueError(
f"Coordinates shape {self.coords.shape} does not match "
f"fields shape {self.fields[0].shape}."
)

def __getitem__(self, index: int):
return self.fields[index]

@property
def coords(self) -> LatLonCoordinates:
if len(self.fields) == 0:
raise ValueError("No fields in StaticInputs to get coordinates from.")
return self.fields[0].coords

@property
def shape(self) -> tuple[int, int]:
if len(self.fields) == 0:
Expand All @@ -139,48 +160,88 @@ def subset(
lat_interval: ClosedInterval,
lon_interval: ClosedInterval,
) -> "StaticInputs":
lat_slice = lat_interval.slice_from(self.coords.lat)
lon_slice = lon_interval.slice_from(self.coords.lon)
return StaticInputs(
fields=[field.subset(lat_interval, lon_interval) for field in self.fields]
fields=[field.subset(lat_slice, lon_slice) for field in self.fields],
coords=LatLonCoordinates(
lat=lat_interval.subset_of(self.coords.lat),
lon=lon_interval.subset_of(self.coords.lon),
),
)

def to_device(self) -> "StaticInputs":
    """Return a new StaticInputs with every field and the coords moved to ``get_device()``."""
    return StaticInputs(
        fields=[field.to_device() for field in self.fields],
        coords=self.coords.to(get_device()),
    )

def get_state(self) -> dict:
return {
"fields": [field.get_state() for field in self.fields],
"coords": self.coords.get_state(),
}

@classmethod
def from_state(cls, state: dict) -> "StaticInputs":
    """Rebuild StaticInputs from a state dict.

    Accepts both the top-level ``"coords"`` format and the older format where
    coords are stored on each field; the latter is migrated via
    ``_sync_state_coordinates`` before loading.

    Raises:
        ValueError: if ``state`` has coordinates in neither format.
    """
    if not _has_coords_in_state(state):
        raise ValueError(
            "No coordinates found in state for StaticInputs. Load with "
            "from_state_backwards_compatible if loading from a checkpoint "
            "saved prior to current coordinate serialization format."
        )
    state = _sync_state_coordinates(state)
    return cls(
        fields=[
            StaticInput.from_state(field_state) for field_state in state["fields"]
        ],
        coords=LatLonCoordinates(
            lat=state["coords"]["lat"],
            lon=state["coords"]["lon"],
        ),
    )

@classmethod
def from_state_backwards_compatible(
    cls,
    state: dict,
    static_inputs_config: dict[str, str],
) -> "StaticInputs | None":
    """Load StaticInputs from a checkpoint state or, failing that, from a config.

    Args:
        state: serialized StaticInputs state from a checkpoint (may be empty).
        static_inputs_config: mapping of field names to file paths (may be
            empty).

    Returns:
        StaticInputs loaded from ``state`` when it contains coordinates,
        otherwise loaded from ``static_inputs_config`` when that is non-empty,
        otherwise None.

    Raises:
        ValueError: if both ``state`` and ``static_inputs_config`` are
            non-empty.
    """
    # NOTE(review, from PR discussion): this method gets used in #999, but
    # there will be a CheckpointModelConfig gate on calling this without
    # coords available in the state or the configuration.
    if state and static_inputs_config:
        raise ValueError(
            "Checkpoint contains static inputs but static_inputs_config is "
            "also provided. Backwards compatibility loading only supports "
            "a single source of StaticInputs info."
        )
    if _has_coords_in_state(state):
        return cls.from_state(state)
    elif static_inputs_config:
        return load_static_inputs(static_inputs_config)
    else:
        return None


def load_static_inputs(
    static_inputs_config: dict[str, str],
) -> StaticInputs:
    """
    Load normalized static inputs from a mapping of field names to file paths.

    Coordinates are read from each field's source dataset and validated to be
    consistent between fields. Raises ValueError if any field's dataset lacks
    lat/lon coordinates or if coordinates differ between fields.

    Also raises ValueError if ``static_inputs_config`` is empty.
    """
    coords_to_use = None
    fields = []
    for field_name, path in static_inputs_config.items():
        si, coords = _get_normalized_static_input(path, field_name)
        fields.append(si)
        if coords_to_use is None:
            # The first field's coordinates become the reference.
            coords_to_use = coords
        else:
            _validate_coords(field_name, coords, coords_to_use)

    if coords_to_use is None:
        raise ValueError("load_static_inputs requires at least one field.")

    return StaticInputs(fields=fields, coords=coords_to_use)
Loading
Loading