
Add fine coordinates to the model for easier inference handling#971

Open
frodre wants to merge 29 commits into main from feature/downscaling-model-fine-coords

Conversation


@frodre frodre commented Mar 13, 2026

DiffusionModel previously relied on StaticInput.coords to store the fine-resolution lat/lon grid, coupling spatial metadata to individual topography fields. This made coordinate handling awkward: models without any static inputs had no coordinate information and would either fail or require a _downscale_coord band-aid to approximate fine-resolution coordinates from the coarse ones.

Changes:

  • StaticInputs now carries a required coords: LatLonCoordinates field representing the full fine-resolution domain. coords is removed from individual StaticInput fields.

  • DiffusionModel always receives a non-optional StaticInputs (fields may be empty when no static data is needed). full_fine_coords is a property delegating to static_inputs.coords.

  • DiffusionModelConfig.build and DiffusionModel.__init__ no longer accept None for static_inputs.

  • StaticInputs.subset(lat_interval, lon_interval) replaces subset_latlon, computing slices internally and subsetting both fields and coords together.

  • ClosedInterval.slice_of renamed to slice_from; new subset_of convenience method returns the coordinate values within the interval.

  • All checkpoint backwards-compatibility loading logic is consolidated in StaticInputs.from_state_backwards_compatible in static.py. CheckpointModelConfig gains an optional fine_coordinates_path for old checkpoints with no stored coordinates.

  • load_fine_coords_from_path added to fme.downscaling.data and tested.

  • PairedGriddedData carries fine_coords, which is passed to the model at build time as a fallback when the dataset specifies no static inputs with coordinates.

  • Removed _downscale_coord from predict.py; model.get_fine_coords_for_batch is used instead.

  • Tests added
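A minimal sketch of the reworked interfaces described in the bullets above; the real classes in fme hold torch tensors and live alongside static.py, so the plain-list stand-ins and the field handling here are assumptions, not the actual implementation:

```python
# Illustrative stand-ins for StaticInputs.subset and the renamed
# ClosedInterval.slice_from / new subset_of methods from this PR.
import dataclasses


@dataclasses.dataclass
class ClosedInterval:
    start: float
    stop: float

    def slice_from(self, values: list[float]) -> slice:
        # Slice covering the values that fall inside [start, stop].
        idx = [i for i, v in enumerate(values) if self.start <= v <= self.stop]
        return slice(idx[0], idx[-1] + 1) if idx else slice(0, 0)

    def subset_of(self, values: list[float]) -> list[float]:
        # Convenience: the coordinate values within the interval.
        return values[self.slice_from(values)]


@dataclasses.dataclass
class LatLonCoordinates:
    lat: list[float]
    lon: list[float]


@dataclasses.dataclass
class StaticInputs:
    fields: list  # may be empty when no static data is needed
    coords: LatLonCoordinates  # required: the full fine-resolution domain

    def subset(
        self, lat_interval: ClosedInterval, lon_interval: ClosedInterval
    ) -> "StaticInputs":
        # Slices are computed internally; fields and coords subset together.
        return StaticInputs(
            fields=self.fields,  # the real code would also slice each field
            coords=LatLonCoordinates(
                lat=lat_interval.subset_of(self.coords.lat),
                lon=lon_interval.subset_of(self.coords.lon),
            ),
        )
```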

@frodre frodre force-pushed the feature/downscaling-model-fine-coords branch 3 times, most recently from c7ff073 to a97e954 on March 16, 2026 22:59
@frodre frodre changed the base branch from refactor/remove-static-input-from-data-and-call-sigs to main March 16, 2026 23:07
@frodre frodre changed the base branch from main to refactor/remove-static-input-from-data-and-call-sigs March 16, 2026 23:08
@frodre frodre force-pushed the feature/downscaling-model-fine-coords branch 2 times, most recently from 35c2234 to 229e858 on March 16, 2026 23:36
@frodre frodre changed the base branch from refactor/remove-static-input-from-data-and-call-sigs to main March 16, 2026 23:44
@frodre frodre changed the base branch from main to refactor/remove-static-input-from-data-and-call-sigs March 16, 2026 23:44
Base automatically changed from refactor/remove-static-input-from-data-and-call-sigs to main March 17, 2026 21:01
frodre added a commit that referenced this pull request Mar 17, 2026
This PR finalizes the removal of `StaticInput` handling by the data
pipeline. The passing of static_input objects is removed from the data
configuration, batch iteration, and model call signatures in favor of
the direct model handling introduced in the previous downscaling PR
(#954).

Changes:
- add `get_fine_coords_for_batch` to facilitate translation of an input
batch domain to output coordinates via the model's stored information.
For now, this relies on the model's `static_inputs`, but will be
switched to the model's stored coordinates in (#971)
- inference `Downscaler` now takes the batch `input_shape` instead of
`static_inputs` to check the domain size and model type (regular
`DiffusionModel` or `PatchPredictor`)
- downscaling `torch.datasets` generators for `BatchData` no longer
include `StaticInputs`
- removed `_apply_patch` and `_generate_from_patches` from
`StaticInputs`
- `config.py` no longer references static inputs as an argument

- [x] Tests added
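As a rough illustration of what `get_fine_coords_for_batch` does per the commit message above (translate an input batch's domain into fine output coordinates from the model's stored grid), here is a hypothetical stand-alone version; the body, argument names, and list-based coordinates are assumptions rather than the actual fme implementation:

```python
# Hypothetical sketch: keep the fine-grid points that fall inside the
# batch's coarse extent, yielding the output coordinates for that batch.
def get_fine_coords_for_batch(
    fine_lat: list[float],
    fine_lon: list[float],
    batch_lat_bounds: tuple[float, float],
    batch_lon_bounds: tuple[float, float],
) -> tuple[list[float], list[float]]:
    lat_lo, lat_hi = batch_lat_bounds
    lon_lo, lon_hi = batch_lon_bounds
    lat = [v for v in fine_lat if lat_lo <= v <= lat_hi]
    lon = [v for v in fine_lon if lon_lo <= v <= lon_hi]
    return lat, lon
```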
@frodre frodre force-pushed the feature/downscaling-model-fine-coords branch from ae4b09e to 5add727 on March 17, 2026 21:24
# Load fine_coords: new checkpoints store it directly; old checkpoints
# that had static_inputs with coords can auto-migrate from raw state.
fine_coords = state.get("fine_coords")
if fine_coords is not None:
Collaborator Author

This block is the pathway executed during training resumption and will fail if we try to resume training for a model without any static inputs or fine coordinates. Not totally sure if this sort of backwards compat is really necessary, since we'll likely just be training new models w/ fine_coords.

Contributor

I'm ok with breaking backwards compatibility here. I think the oldest checkpoint we'd possibly want to continue to use would be the released precip-only model checkpoint, which has static inputs saved.

@frodre frodre marked this pull request as ready for review March 17, 2026 22:58
Contributor

@AnnaKwa AnnaKwa left a comment

I agree with adding a fine coords attribute to the model and saving it as well. I do prefer to keep the coordinates on the StaticInputs class, though, since the main use of the fine coords in the model is to subset the static inputs; having them in that class keeps the subsetting of that tensor cleaner while making it 100% clear that the coordinates are associated with that set of static inputs. Otherwise my worry in separating the coordinates from the static inputs is that down the line it may be easier to introduce a bug where the coordinates don't match.

Could the model attribute instead default to point to the static inputs fine coordinates and be set by an optional checkpoint config path if required for a model w/o static inputs?

coarse to fine.
sigma_data: The standard deviation of the data, used for diffusion
model preconditioning.
fine_coords: the full-domain fine-resolution coordinates to use
Contributor

Suggestion: call this full_fine_coords or something to that effect so it's obvious that this is the full domain and doesn't need to be updated if training is resumed on some different domain.

self,
coarse_shape: tuple[int, int],
downscale_factor: int,
fine_coords: LatLonCoordinates,
Contributor

For this and the other places where this arg is added, see the comment about naming it to describe that it's the full domain.

coarse_shape: tuple[int, int],
downscale_factor: int,
sigma_data: float,
fine_coords: LatLonCoordinates,
Contributor

Suggestion: make this arg optional, only to be used if building from a checkpoint that has an optional fine_coords_path arg to allow for models w/o static inputs to have correct coords in the saved predict/evaluate output. Otherwise for the standard case where the model uses static inputs, set the attribute self.fine_coords using the static inputs coords.

frodre and others added 2 commits March 18, 2026 13:20
Co-authored-by: Anna Kwa <annak@allenai.org>
Co-authored-by: Anna Kwa <annak@allenai.org>
return f"LatLonCoordinates(\n lat={self.lat},\n lon={self.lon}\n)"

-    def to(self, device: str) -> "LatLonCoordinates":
+    def to(self, device: str | torch.device) -> "LatLonCoordinates":
Collaborator Author

To make my VSCode linter happy

@dataclasses.dataclass
class StaticInputs:
fields: list[StaticInput]
coords: LatLonCoordinates
Collaborator Author

Not named full_coords because we do produce subsets with this class.

@frodre frodre requested a review from AnnaKwa March 19, 2026 23:22
Contributor

AnnaKwa commented Mar 20, 2026

Ah, I'm very sorry for the confusion: when I said "I do prefer to keep the coordinates on the StaticInputs" I meant I preferred them kept where they currently were (within the static inputs, but on individual StaticInput objects, not the higher-level StaticInputs). The potential mismatch in coords I was concerned about in the previous iteration was between the static inputs coords (post-init ensures all StaticInput objects have the same coords) and the DiffusionModel's coords attribute getting out of sync through mixing of saved checkpoints and updated configs. I think the static input coords are best kept in the lowest-level object, StaticInput, so it's clear they are associated with the data information there.

PairedGriddedData carries fine_coords, passed to the model at build time.

Is this so that models without static inputs have the fine coords information? Could we instead set a _fine_coords_from_gridded_data attribute and fall back to using this if there are no static inputs (rather than setting the static inputs coords from the fine dataset coords)? See comment on the full_fine_coords property.
It should amount to the same thing since the fine coords are usually the same from both sources, but this way it's 100% guaranteed that the static inputs coords come directly from their underlying datasets.

return self.static_inputs.subset_latlon(lat_interval, lon_interval)
@property
def full_fine_coords(self) -> LatLonCoordinates:
return self.static_inputs.coords
Contributor

Could we do something like have an attribute self._full_fine_coords_from_gridded_data: None | LatLonCoordinates that is set in the build method, and then

if len(self.static_inputs.fields) > 0: # or maybe add a len method to the class
    return self.static_inputs.coords
else:
    return self._full_fine_coords_from_gridded_data
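A runnable version of this suggestion might look like the following; the stand-in classes and the error raised when neither source is available are assumptions beyond what the comment proposes:

```python
# Sketch of the suggested full_fine_coords fallback, using stand-in
# classes in place of the real fme DiffusionModel and StaticInputs.
import dataclasses


@dataclasses.dataclass
class StaticInputs:
    fields: list
    coords: object  # LatLonCoordinates in the real code


class DiffusionModel:
    def __init__(self, static_inputs, fine_coords_from_gridded_data=None):
        self.static_inputs = static_inputs
        self._full_fine_coords_from_gridded_data = fine_coords_from_gridded_data

    @property
    def full_fine_coords(self):
        # Prefer coords attached to static inputs; otherwise fall back to
        # the coords captured from the gridded data at build time.
        if len(self.static_inputs.fields) > 0:
            return self.static_inputs.coords
        if self._full_fine_coords_from_gridded_data is None:
            raise ValueError("no fine coordinates available for this model")
        return self._full_fine_coords_from_gridded_data
```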

Contributor

Can the fallback be put here, rather than within the StaticInputs?

return self.static_inputs.subset_latlon(lat_interval, lon_interval)
@property
def full_fine_coords(self) -> LatLonCoordinates:
return self.static_inputs.coords
Contributor

Can the fallback be put here, rather than within the StaticInputs?

# no coords found with static inputs, use provided fallback
coords_to_use = fallback_coords
elif validate_coords:
_validate_coords("fallback", coords_to_use, fallback_coords)
Contributor

Could this validation also get moved up to be done at the model level when it gets built?

try:
coords = _load_coords_from_ds(ds)
except ValueError:
# no coords available
Contributor

Up to you, but in the data loading (get_horizontal_coordinates) it's assumed the last two dims are lat, lon so I think it's ok to assume this here as well. If for some reason there are no usable coords I think we should just raise the error here (i.e. if there are static inputs, we expect them to have valid coords).
