Skip to content
Merged
102 changes: 7 additions & 95 deletions fme/downscaling/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
PairedBatchData,
PairedGriddedData,
)
from fme.downscaling.data.static import StaticInputs
from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range
from fme.downscaling.requirements import DataRequirements

Expand Down Expand Up @@ -132,18 +131,6 @@ def _full_configs(
return all_configs


def _check_fine_res_static_input_compatibility(
static_input_shape: tuple[int, int], data_coords_shape: tuple[int, int]
) -> None:
for static, coord in zip(static_input_shape, data_coords_shape):
if static != coord:
raise ValueError(
f"Static input shape {static_input_shape} is not compatible with "
f"data coordinates shape {data_coords_shape}. Static input dimensions "
"must match fine resolution coordinate dimensions."
)


@dataclasses.dataclass
class DataLoaderConfig:
"""
Expand All @@ -164,8 +151,9 @@ class DataLoaderConfig:
(For multi-GPU runtime, it's the number of workers per GPU.)
strict_ensemble: Whether to enforce that the datasets to be concatened
have the same dimensions and coordinates.
topography: Deprecated field for specifying the topography dataset. Now
provided via build method's `static_inputs` argument.
topography: Deprecated field for specifying the topography dataset.
StaticInput data are expected to be stored and serialized within a
model through the Trainer build process.
lat_extent: The latitude extent to use for the dataset specified in
degrees, limited to (-88.0, 88.0). The extent is inclusive, so the start and
stop values are included in the extent. Defaults to [-66, 70] which
Expand Down Expand Up @@ -202,8 +190,8 @@ def __post_init__(self):
if self.topography is not None:
raise ValueError(
"The `topography` field on DataLoaderConfig is deprecated and will be "
"removed in a future release. Pass static_inputs via build's "
"`static_inputs` argument instead."
"removed in a future release. `StaticInputs` are now stored within "
" the model when it is first built and trained."
)

@property
Expand Down Expand Up @@ -236,40 +224,6 @@ def get_xarray_dataset(
strict_ensemble=self.strict_ensemble,
)

def build_static_inputs(
self,
coarse_coords: LatLonCoordinates,
requires_topography: bool,
static_inputs: StaticInputs | None = None,
) -> StaticInputs | None:
if requires_topography is False:
return None
if static_inputs is not None:
# TODO: change to use full static inputs list
full_static_inputs = static_inputs
else:
raise ValueError(
"Static inputs required for this model, but no static inputs "
"datasets were specified in the trainer configuration or provided "
"in model checkpoint."
)

# Fine grid boundaries are adjusted to exactly match the coarse grid
fine_lat_interval = adjust_fine_coord_range(
self.lat_extent,
full_coarse_coord=coarse_coords.lat,
full_fine_coord=full_static_inputs.coords.lat,
)
fine_lon_interval = adjust_fine_coord_range(
self.lon_extent,
full_coarse_coord=coarse_coords.lon,
full_fine_coord=full_static_inputs.coords.lon,
)
subset_static_inputs = full_static_inputs.subset_latlon(
lat_interval=fine_lat_interval, lon_interval=fine_lon_interval
)
return subset_static_inputs.to_device()

def build_batchitem_dataset(
self,
dataset: XarrayConcat,
Expand Down Expand Up @@ -301,22 +255,14 @@ def build(
self,
requirements: DataRequirements,
dist: Distributed | None = None,
static_inputs: StaticInputs | None = None,
) -> GriddedData:
# TODO: static_inputs_from_checkpoint is currently passed from the model
# to allow loading fine topography when no fine data is available.
# See PR https://github.com/ai2cm/ace/pull/728
# In the future we could disentangle this dependency between the data loader
# and model by enabling the built GriddedData objects to take in full static
# input fields and subset them to the same coordinate range as data.
xr_dataset, properties = self.get_xarray_dataset(
names=requirements.coarse_names, n_timesteps=1
)
if not isinstance(properties.horizontal_coordinates, LatLonCoordinates):
raise ValueError(
"Downscaling data loader only supports datasets with latlon coords."
)
latlon_coords = properties.horizontal_coordinates
dataset = self.build_batchitem_dataset(
dataset=xr_dataset,
properties=properties,
Expand All @@ -343,14 +289,8 @@ def build(
persistent_workers=True if self.num_data_workers > 0 else False,
)
example = dataset[0]
subset_static_inputs = self.build_static_inputs(
coarse_coords=latlon_coords,
requires_topography=requirements.use_fine_topography,
static_inputs=static_inputs,
)
return GriddedData(
_loader=dataloader,
static_inputs=subset_static_inputs,
shape=example.horizontal_shape,
dims=example.latlon_coordinates.dims,
variable_metadata=dataset.variable_metadata,
Expand Down Expand Up @@ -395,7 +335,6 @@ class PairedDataLoaderConfig:
time dimension. Useful to include longer sequences of small
data for testing.
topography: Deprecated field for specifying the topography dataset.
Now provided via build method's `static_inputs` argument.
sample_with_replacement: If provided, the dataset will be
sampled randomly with replacement to the given size each period,
instead of retrieving each sample once (either shuffled or not).
Expand Down Expand Up @@ -427,8 +366,8 @@ def __post_init__(self):
if self.topography is not None:
raise ValueError(
"The `topography` field on PairedDataLoaderConfig is deprecated and "
"will be removed in a future release. Pass static_inputs via the "
"build method's `static_inputs` argument instead."
"will be removed in a future release. `StaticInputs` are now stored "
"within the model when it is first built and trained."
)

def _first_data_config(
Expand Down Expand Up @@ -468,14 +407,7 @@ def build(
train: bool,
requirements: DataRequirements,
dist: Distributed | None = None,
static_inputs: StaticInputs | None = None,
) -> PairedGriddedData:
# TODO: static_inputs_from_checkpoint is currently passed from the model
# to allow loading fine topography when no fine data is available.
# See PR https://github.com/ai2cm/ace/pull/728
# In the future we could disentangle this dependency between the data loader
# and model by enabling the built GriddedData objects to take in full static
# input fields and subset them to the same coordinate range as data.
if dist is None:
dist = Distributed.get_instance()

Expand Down Expand Up @@ -537,25 +469,6 @@ def build(
full_fine_coord=properties_fine.horizontal_coordinates.lon,
)

if requirements.use_fine_topography:
if static_inputs is None:
raise ValueError(
"Model requires static inputs (use_fine_topography=True),"
" but no static inputs were provided to the data loader's"
" build method."
)

static_inputs = static_inputs.to_device()
_check_fine_res_static_input_compatibility(
static_inputs.shape,
properties_fine.horizontal_coordinates.shape,
)
static_inputs = static_inputs.subset_latlon(
lat_interval=fine_lat_extent, lon_interval=fine_lon_extent
)
else:
static_inputs = None

dataset_fine_subset = HorizontalSubsetDataset(
dataset_fine,
properties=properties_fine,
Expand Down Expand Up @@ -611,7 +524,6 @@ def build(

return PairedGriddedData(
_loader=dataloader,
static_inputs=static_inputs,
coarse_shape=example.coarse.horizontal_shape,
downscale_factor=example.downscale_factor,
dims=example.fine.latlon_coordinates.dims,
Expand Down
Loading
Loading