Add fine coordinates to the model for easier inference handling#971
Add fine coordinates to the model for easier inference handling#971
Conversation
c7ff073 to
a97e954
Compare
35c2234 to
229e858
Compare
This PR finalizes the removal of the `StaticInput` handling by the data pipeline. The passing of static_input objects are removed from the data configuration, batch iteration, and model call signatures in favor of the direct model handling introduced in the previous downscaling PR (#954). Changes: - add `get_fine_coords_for_batch` to facilitate translation of an input batch domain to output coordinates via the models stored information. For now, this relies on the model's `static_inputs`, but will be switched to model's stored coordinates in (#971) - inference `Downscaler` now takes the batch `input_shape` instead of `static_inputs` to check the domain size and model type (regular `DiffusionModel` or `PatchPredictor` - downscaling `torch.datasets` generators for `BatchData` no longer include `StaticInputs` - removed `_apply_patch` and `_generate_from_patches` from `StaticInputs` - `config.py` no longer references static inputs as an argument - [x] Tests added
ae4b09e to
5add727
Compare
fme/downscaling/models.py
Outdated
| # Load fine_coords: new checkpoints store it directly; old checkpoints | ||
| # that had static_inputs with coords can auto-migrate from raw state. | ||
| fine_coords = state.get("fine_coords") | ||
| if fine_coords is not None: |
There was a problem hiding this comment.
This block is the pathway executed during training resumption and will fail if we try and resume training for a model without any static inputs or fine coordinates. Not totally sure if this sort of backwards compat is really necessary, since we'll likely just be training new models w/ fine_coords.
There was a problem hiding this comment.
I'm ok with breaking backwards compatibility here. I think the oldest checkpoint we'd possibly want to continue to use would be the released precip-only model checkpoint, which has static inputs saved.
AnnaKwa
left a comment
There was a problem hiding this comment.
I agree with adding a fine coords attribute to the model and saving it as well. I do prefer to keep the the coordinates on the StaticInputs class though as the main use of the fine coords in the model is to subset the static inputs, and having them in that class this keeps the subsetting of that tensor cleaner while making it 100% clear that the coordinates are associated with that set of static inputs. Otherwise my worry in separating the coordinates from the static inputs is that down the line it may be easier to introduce a bug where the coordinates don't match.
Could the model attribute instead default to point to the static inputs fine coordinates and be set by an optional checkpoint config path if required for a model w/o static inputs?
fme/downscaling/models.py
Outdated
| coarse to fine. | ||
| sigma_data: The standard deviation of the data, used for diffusion | ||
| model preconditioning. | ||
| fine_coords: the full-domain fine-resolution coordinates to use |
There was a problem hiding this comment.
Suggestion: call this full_fine_coords or something to that effect so it's obvious that this is the full domain and doesn't need to be updated if training is resumed on some different domain.
fme/downscaling/models.py
Outdated
| self, | ||
| coarse_shape: tuple[int, int], | ||
| downscale_factor: int, | ||
| fine_coords: LatLonCoordinates, |
There was a problem hiding this comment.
For this and other places where this arg is added, see comment about naming to describe it's the full domain.
fme/downscaling/models.py
Outdated
| coarse_shape: tuple[int, int], | ||
| downscale_factor: int, | ||
| sigma_data: float, | ||
| fine_coords: LatLonCoordinates, |
There was a problem hiding this comment.
Suggestion: make this arg optional, only to be used if building from a checkpoint that has an optional fine_coords_path arg to allow for models w/o static inputs to have correct coords in the saved predict/evaluate output. Otherwise for the standard case where the model uses static inputs, set the attribute self.fine_coords using the static inputs coords.
Co-authored-by: Anna Kwa <annak@allenai.org>
Co-authored-by: Anna Kwa <annak@allenai.org>
Co-authored-by: Anna Kwa <annak@allenai.org>
| return f"LatLonCoordinates(\n lat={self.lat},\n lon={self.lon}\n" | ||
|
|
||
| def to(self, device: str) -> "LatLonCoordinates": | ||
| def to(self, device: str | torch.device) -> "LatLonCoordinates": |
There was a problem hiding this comment.
To make my VSCode linter happy
| @dataclasses.dataclass | ||
| class StaticInputs: | ||
| fields: list[StaticInput] | ||
| coords: LatLonCoordinates |
There was a problem hiding this comment.
Not named full_coords because we do produce subsets with this class.
|
Ah, I'm very sorry for the confusion- when I said "I do prefer to keep the the coordinates on the StaticInputs" I meant I preferred them kept where there were currently were (within the static inputs but on individual
Is this so that models without static inputs have the fine coords information? Could we instead set a |
fme/downscaling/models.py
Outdated
| return self.static_inputs.subset_latlon(lat_interval, lon_interval) | ||
| @property | ||
| def full_fine_coords(self) -> LatLonCoordinates: | ||
| return self.static_inputs.coords |
There was a problem hiding this comment.
Could we do something like have an attribute self._full_fine_coords_from_gridded_data: None | LatLonCoordinates that is set in the build method, and then
if len(self.static_inputs.fields) > 0: # or maybe add a len method to the class
return self.static_inputs.coords
else:
return self._full_fine_coords_from_gridded_data
There was a problem hiding this comment.
Can the fallback be put here, rather than within the StaticInputs?
fme/downscaling/models.py
Outdated
| return self.static_inputs.subset_latlon(lat_interval, lon_interval) | ||
| @property | ||
| def full_fine_coords(self) -> LatLonCoordinates: | ||
| return self.static_inputs.coords |
There was a problem hiding this comment.
Can the fallback be put here, rather than within the StaticInputs?
fme/downscaling/data/static.py
Outdated
| # no coords found with static inputs, use provided fallback | ||
| coords_to_use = fallback_coords | ||
| elif validate_coords: | ||
| _validate_coords("fallback", coords_to_use, fallback_coords) |
There was a problem hiding this comment.
Could this validation also get moved up to be done at the model level when it gets built?
fme/downscaling/data/static.py
Outdated
| try: | ||
| coords = _load_coords_from_ds(ds) | ||
| except ValueError: | ||
| # no coords available |
There was a problem hiding this comment.
Up to you, but in the data loading (get_horizontal_coordinates) it's assumed the last two dims are lat, lon so I think it's ok to assume this here as well. If for some reason there are no usable coords I think we should just raise the error here (i.e. if there are static inputs, we expect them to have valid coords).
DiffusionModelpreviously relied onStaticInput.coordsto store the fine-resolution lat/lon grid, coupling spatial metadata to individual topography fields. This made coordinate handling awkward since models without any static inputs had no coordinate information and would either fail or required a_downscale_coordbandaid to approximate fine-resolution coordinates from coarse ones.Changes:
StaticInputsnow carries a requiredcoords: LatLonCoordinatesfield representing the full fine-resolution domain.coordsis removed from individualStaticInputfields.DiffusionModelalways receives a non-optionalStaticInputs(fields may be empty when no static data is needed).full_fine_coordsis a property delegating tostatic_inputs.coords.DiffusionModelConfig.buildandDiffusionModel.__init__no longer acceptNoneforstatic_inputs.StaticInputs.subset(lat_interval, lon_interval)replacessubset_latlon, computing slices internally and subsetting both fields and coords together.ClosedInterval.slice_ofrenamed toslice_from; newsubset_ofconvenience method returns the coordinate values within the interval.All checkpoint backwards-compatibility loading logic is consolidated in
StaticInputs.from_state_backwards_compatibleinstatic.py.CheckpointModelConfiggains an optionalfine_coordinates_pathfor old checkpoints with no stored coordinates.load_fine_coords_from_pathadded tofme.downscaling.dataand tested.PairedGriddedDatacarriesfine_coords, passed to the model at build time to use as a fallback if no static input's with coordinates in the dataset are specified.Removed
_downscale_coordfrompredict.py;model.get_fine_coords_for_batchis used instead.Tests added