Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 13 additions & 96 deletions fme/downscaling/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@
PairedBatchData,
PairedGriddedData,
)
from fme.downscaling.data.static import StaticInputs
from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range
from fme.downscaling.data.utils import (
ClosedInterval,
adjust_fine_coord_range,
get_latlon_coords_from_properties,
)
from fme.downscaling.requirements import DataRequirements


Expand Down Expand Up @@ -132,18 +135,6 @@ def _full_configs(
return all_configs


def _check_fine_res_static_input_compatibility(
static_input_shape: tuple[int, int], data_coords_shape: tuple[int, int]
) -> None:
for static, coord in zip(static_input_shape, data_coords_shape):
if static != coord:
raise ValueError(
f"Static input shape {static_input_shape} is not compatible with "
f"data coordinates shape {data_coords_shape}. Static input dimensions "
"must match fine resolution coordinate dimensions."
)


@dataclasses.dataclass
class DataLoaderConfig:
"""
Expand All @@ -164,8 +155,9 @@ class DataLoaderConfig:
(For multi-GPU runtime, it's the number of workers per GPU.)
strict_ensemble: Whether to enforce that the datasets to be concatened
have the same dimensions and coordinates.
topography: Deprecated field for specifying the topography dataset. Now
provided via build method's `static_inputs` argument.
topography: Deprecated field for specifying the topography dataset.
StaticInput data are expected to be stored and serialized within a
model through the Trainer build process.
lat_extent: The latitude extent to use for the dataset specified in
degrees, limited to (-88.0, 88.0). The extent is inclusive, so the start and
stop values are included in the extent. Defaults to [-66, 70] which
Expand Down Expand Up @@ -202,8 +194,8 @@ def __post_init__(self):
if self.topography is not None:
raise ValueError(
"The `topography` field on DataLoaderConfig is deprecated and will be "
"removed in a future release. Pass static_inputs via build's "
"`static_inputs` argument instead."
"removed in a future release. `StaticInputs` are now stored within "
" the model when it is first built and trained."
)

@property
Expand Down Expand Up @@ -236,40 +228,6 @@ def get_xarray_dataset(
strict_ensemble=self.strict_ensemble,
)

def build_static_inputs(
self,
coarse_coords: LatLonCoordinates,
requires_topography: bool,
static_inputs: StaticInputs | None = None,
) -> StaticInputs | None:
if requires_topography is False:
return None
if static_inputs is not None:
# TODO: change to use full static inputs list
full_static_inputs = static_inputs
else:
raise ValueError(
"Static inputs required for this model, but no static inputs "
"datasets were specified in the trainer configuration or provided "
"in model checkpoint."
)

# Fine grid boundaries are adjusted to exactly match the coarse grid
fine_lat_interval = adjust_fine_coord_range(
self.lat_extent,
full_coarse_coord=coarse_coords.lat,
full_fine_coord=full_static_inputs.coords.lat,
)
fine_lon_interval = adjust_fine_coord_range(
self.lon_extent,
full_coarse_coord=coarse_coords.lon,
full_fine_coord=full_static_inputs.coords.lon,
)
subset_static_inputs = full_static_inputs.subset_latlon(
lat_interval=fine_lat_interval, lon_interval=fine_lon_interval
)
return subset_static_inputs.to_device()

def build_batchitem_dataset(
self,
dataset: XarrayConcat,
Expand Down Expand Up @@ -301,22 +259,14 @@ def build(
self,
requirements: DataRequirements,
dist: Distributed | None = None,
static_inputs: StaticInputs | None = None,
) -> GriddedData:
# TODO: static_inputs_from_checkpoint is currently passed from the model
# to allow loading fine topography when no fine data is available.
# See PR https://github.com/ai2cm/ace/pull/728
# In the future we could disentangle this dependency between the data loader
# and model by enabling the built GriddedData objects to take in full static
# input fields and subset them to the same coordinate range as data.
xr_dataset, properties = self.get_xarray_dataset(
names=requirements.coarse_names, n_timesteps=1
)
if not isinstance(properties.horizontal_coordinates, LatLonCoordinates):
raise ValueError(
"Downscaling data loader only supports datasets with latlon coords."
)
latlon_coords = properties.horizontal_coordinates
dataset = self.build_batchitem_dataset(
dataset=xr_dataset,
properties=properties,
Expand All @@ -343,14 +293,8 @@ def build(
persistent_workers=True if self.num_data_workers > 0 else False,
)
example = dataset[0]
subset_static_inputs = self.build_static_inputs(
coarse_coords=latlon_coords,
requires_topography=requirements.use_fine_topography,
static_inputs=static_inputs,
)
return GriddedData(
_loader=dataloader,
static_inputs=subset_static_inputs,
shape=example.horizontal_shape,
dims=example.latlon_coordinates.dims,
variable_metadata=dataset.variable_metadata,
Expand Down Expand Up @@ -395,7 +339,6 @@ class PairedDataLoaderConfig:
time dimension. Useful to include longer sequences of small
data for testing.
topography: Deprecated field for specifying the topography dataset.
Now provided via build method's `static_inputs` argument.
sample_with_replacement: If provided, the dataset will be
sampled randomly with replacement to the given size each period,
instead of retrieving each sample once (either shuffled or not).
Expand Down Expand Up @@ -427,8 +370,8 @@ def __post_init__(self):
if self.topography is not None:
raise ValueError(
"The `topography` field on PairedDataLoaderConfig is deprecated and "
"will be removed in a future release. Pass static_inputs via the "
"build method's `static_inputs` argument instead."
"will be removed in a future release. `StaticInputs` are now stored "
"within the model when it is first built and trained."
)

def _first_data_config(
Expand Down Expand Up @@ -468,14 +411,7 @@ def build(
train: bool,
requirements: DataRequirements,
dist: Distributed | None = None,
static_inputs: StaticInputs | None = None,
) -> PairedGriddedData:
# TODO: static_inputs_from_checkpoint is currently passed from the model
# to allow loading fine topography when no fine data is available.
# See PR https://github.com/ai2cm/ace/pull/728
# In the future we could disentangle this dependency between the data loader
# and model by enabling the built GriddedData objects to take in full static
# input fields and subset them to the same coordinate range as data.
if dist is None:
dist = Distributed.get_instance()

Expand Down Expand Up @@ -537,25 +473,6 @@ def build(
full_fine_coord=properties_fine.horizontal_coordinates.lon,
)

if requirements.use_fine_topography:
if static_inputs is None:
raise ValueError(
"Model requires static inputs (use_fine_topography=True),"
" but no static inputs were provided to the data loader's"
" build method."
)

static_inputs = static_inputs.to_device()
_check_fine_res_static_input_compatibility(
static_inputs.shape,
properties_fine.horizontal_coordinates.shape,
)
static_inputs = static_inputs.subset_latlon(
lat_interval=fine_lat_extent, lon_interval=fine_lon_extent
)
else:
static_inputs = None

dataset_fine_subset = HorizontalSubsetDataset(
dataset_fine,
properties=properties_fine,
Expand Down Expand Up @@ -611,12 +528,12 @@ def build(

return PairedGriddedData(
_loader=dataloader,
static_inputs=static_inputs,
coarse_shape=example.coarse.horizontal_shape,
downscale_factor=example.downscale_factor,
dims=example.fine.latlon_coordinates.dims,
variable_metadata=variable_metadata,
all_times=all_times,
fine_coords=get_latlon_coords_from_properties(properties_fine),
)

def _get_sampler(
Expand Down
Loading