-
Notifications
You must be signed in to change notification settings - Fork 38
Add fine coordinates to the model for easier inference handling #971
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
frodre
wants to merge
29
commits into
main
Choose a base branch
from
feature/downscaling-model-fine-coords
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 26 commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
754f28a
Initial shot
frodre 7c5b4c5
Make fine coords required
frodre c7f58d9
Fine coords required for paired data
frodre 563b57a
Mesh with previous updates in refactor pr
frodre f4218dd
Simplify event downscaler coordinate in run()
frodre 68cab61
use batch latlon coordinates for coarse
frodre 83ca043
Make fine coord loader public
frodre 794a7d4
BatchLatLon coord access consistency
frodre b542080
linting
frodre 5add727
Add no coords checkpoint with path test
frodre 36e80db
Small tweaks
frodre b19c9d6
Add load_fine_coords_from_path test
frodre 7acf87c
Update fme/downscaling/models.py
frodre eba2749
Update fme/downscaling/models.py
frodre 2c25f1d
Update fme/downscaling/models.py
frodre 81e4eb8
use latlon coords .to method for device fix
frodre 325d324
Redo based on Anna's comments
frodre 1fac745
Remove from_state docstring
frodre 3ef1613
Move all state loading cases into static inputs code
frodre 36e8101
Remove unused function from fme.downscaling.data
frodre 3f95457
fine_coords -> full_fine_coords
frodre 6c9fcf7
Remove duplicated tests from models.py
frodre 4b58b86
Minor fixes
frodre 7e36f55
Fix imports
frodre a85fb05
Final cleanup
frodre b1ee3d4
Merge branch 'main' into feature/downscaling-model-fine-coords
frodre bedd276
Add coordinate validation and use training data as a fallback vs. sta…
frodre a9dbae3
Updates based on discussion w/ Anna
frodre 46670ff
Merge branch 'main' into feature/downscaling-model-fine-coords
frodre File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,67 +11,42 @@ | |
| @dataclasses.dataclass | ||
| class StaticInput: | ||
| data: torch.Tensor | ||
| coords: LatLonCoordinates | ||
|
|
||
| def __post_init__(self): | ||
| if len(self.data.shape) != 2: | ||
| raise ValueError(f"Topography data must be 2D. Got shape {self.data.shape}") | ||
| if self.data.shape[0] != len(self.coords.lat) or self.data.shape[1] != len( | ||
| self.coords.lon | ||
| ): | ||
| raise ValueError( | ||
| f"Static inputs data shape {self.data.shape} does not match lat/lon " | ||
| f"coordinates shape {(len(self.coords.lat), len(self.coords.lon))}" | ||
| f"StaticInput data must be 2D. Got shape {self.data.shape}" | ||
| ) | ||
| self._shape = (self.data.shape[0], self.data.shape[1]) | ||
|
|
||
| @property | ||
| def dim(self) -> int: | ||
| return len(self.shape) | ||
|
|
||
| @property | ||
| def shape(self) -> tuple[int, int]: | ||
| return self.data.shape | ||
| return self._shape | ||
|
|
||
| def subset_latlon( | ||
| def subset( | ||
| self, | ||
| lat_interval: ClosedInterval, | ||
| lon_interval: ClosedInterval, | ||
| lat_slice: slice, | ||
| lon_slice: slice, | ||
| ) -> "StaticInput": | ||
| lat_slice = lat_interval.slice_of(self.coords.lat) | ||
| lon_slice = lon_interval.slice_of(self.coords.lon) | ||
| return self._latlon_index_slice(lat_slice=lat_slice, lon_slice=lon_slice) | ||
| return StaticInput(data=self.data[lat_slice, lon_slice]) | ||
|
|
||
| def to_device(self) -> "StaticInput": | ||
| device = get_device() | ||
| return StaticInput( | ||
| data=self.data.to(device), | ||
| coords=LatLonCoordinates( | ||
| lat=self.coords.lat.to(device), | ||
| lon=self.coords.lon.to(device), | ||
| ), | ||
| ) | ||
|
|
||
| def _latlon_index_slice( | ||
| self, | ||
| lat_slice: slice, | ||
| lon_slice: slice, | ||
| ) -> "StaticInput": | ||
| sliced_data = self.data[lat_slice, lon_slice] | ||
| sliced_latlon = LatLonCoordinates( | ||
| lat=self.coords.lat[lat_slice], | ||
| lon=self.coords.lon[lon_slice], | ||
| ) | ||
| return StaticInput( | ||
| data=sliced_data, | ||
| coords=sliced_latlon, | ||
| ) | ||
| return StaticInput(data=self.data.to(device)) | ||
|
|
||
| def get_state(self) -> dict: | ||
| return { | ||
| "data": self.data.cpu(), | ||
| "coords": self.coords.get_state(), | ||
| } | ||
|
|
||
| @classmethod | ||
| def from_state(cls, state: dict) -> "StaticInput": | ||
| return cls(data=state["data"]) | ||
|
|
||
|
|
||
| def _get_normalized_static_input(path: str, field_name: str): | ||
| """ | ||
|
|
@@ -93,96 +68,179 @@ def _get_normalized_static_input(path: str, field_name: str): | |
| f"unexpected shape {static_input.shape} for static input." | ||
| "Currently, only lat/lon static input is supported." | ||
| ) | ||
| lat_name, lon_name = static_input.dims[-2:] | ||
| coords = LatLonCoordinates( | ||
| lon=torch.tensor(static_input[lon_name].values), | ||
| lat=torch.tensor(static_input[lat_name].values), | ||
| ) | ||
|
|
||
| static_input_normalized = (static_input - static_input.mean()) / static_input.std() | ||
|
|
||
| return StaticInput( | ||
| data=torch.tensor(static_input_normalized.values, dtype=torch.float32), | ||
| coords=coords, | ||
| ) | ||
|
|
||
|
|
||
| def _has_legacy_coords_in_state(state: dict) -> bool: | ||
| return "fields" in state and state["fields"] and "coords" in state["fields"][0] | ||
|
|
||
|
|
||
| def _sync_state_coordinates(state: dict) -> dict: | ||
| # if necessary adjusts legacy coordinate to expected | ||
| # format for state loading | ||
| state = state.copy() | ||
| if _has_legacy_coords_in_state(state): | ||
| state["coords"] = state["fields"][0]["coords"] | ||
| return state | ||
|
|
||
|
|
||
| def _has_coords_in_state(state: dict) -> bool: | ||
| if "coords" in state or _has_legacy_coords_in_state(state): | ||
| return True | ||
| else: | ||
| return False | ||
|
|
||
|
|
||
| def load_fine_coords_from_path(path: str) -> LatLonCoordinates: | ||
| if path.endswith(".zarr"): | ||
| ds = xr.open_zarr(path) | ||
| else: | ||
| ds = xr.open_dataset(path) | ||
| lat_name = next((n for n in ["lat", "latitude", "grid_yt"] if n in ds.coords), None) | ||
| lon_name = next( | ||
| (n for n in ["lon", "longitude", "grid_xt"] if n in ds.coords), None | ||
| ) | ||
| if lat_name is None or lon_name is None: | ||
| raise ValueError( | ||
| f"Could not find lat/lon coordinates in {path}. " | ||
| "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'." | ||
| ) | ||
| return LatLonCoordinates( | ||
| lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), | ||
| lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), | ||
| ) | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class StaticInputs: | ||
| fields: list[StaticInput] | ||
| coords: LatLonCoordinates | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not named |
||
|
|
||
| def __post_init__(self): | ||
| for i, field in enumerate(self.fields[1:]): | ||
| if field.coords != self.fields[0].coords: | ||
| if field.shape != self.fields[0].shape: | ||
| raise ValueError( | ||
| f"All StaticInput fields must have the same coordinates. " | ||
| f"Fields {i} and 0 do not match coordinates." | ||
| f"All StaticInput fields must have the same shape. " | ||
| f"Fields {i + 1} and 0 do not match shapes." | ||
| ) | ||
| if self.fields and self.coords.shape != self.fields[0].shape: | ||
| raise ValueError( | ||
| f"Coordinates shape {self.coords.shape} does not match fields shape " | ||
| f"{self.fields[0].shape} for StaticInputs." | ||
| ) | ||
|
|
||
| def __getitem__(self, index: int): | ||
| return self.fields[index] | ||
|
|
||
| @property | ||
| def coords(self) -> LatLonCoordinates: | ||
| if len(self.fields) == 0: | ||
| raise ValueError("No fields in StaticInputs to get coordinates from.") | ||
| return self.fields[0].coords | ||
|
|
||
| @property | ||
| def shape(self) -> tuple[int, int]: | ||
| if len(self.fields) == 0: | ||
| raise ValueError("No fields in StaticInputs to get shape from.") | ||
| return self.fields[0].shape | ||
|
|
||
| def subset_latlon( | ||
| def subset( | ||
| self, | ||
| lat_interval: ClosedInterval, | ||
| lon_interval: ClosedInterval, | ||
| ) -> "StaticInputs": | ||
| lat_slice = lat_interval.slice_from(self.coords.lat) | ||
| lon_slice = lon_interval.slice_from(self.coords.lon) | ||
| return StaticInputs( | ||
| fields=[ | ||
| field.subset_latlon(lat_interval, lon_interval) for field in self.fields | ||
| ] | ||
| fields=[field.subset(lat_slice, lon_slice) for field in self.fields], | ||
| coords=LatLonCoordinates( | ||
| lat=lat_interval.subset_of(self.coords.lat), | ||
| lon=lon_interval.subset_of(self.coords.lon), | ||
| ), | ||
| ) | ||
|
|
||
| def to_device(self) -> "StaticInputs": | ||
| return StaticInputs(fields=[field.to_device() for field in self.fields]) | ||
| return StaticInputs( | ||
| fields=[field.to_device() for field in self.fields], | ||
| coords=self.coords.to(get_device()), | ||
| ) | ||
|
|
||
| def get_state(self) -> dict: | ||
| return { | ||
| "fields": [field.get_state() for field in self.fields], | ||
| "coords": self.coords.get_state(), | ||
| } | ||
|
|
||
| @classmethod | ||
| def from_state(cls, state: dict) -> "StaticInputs": | ||
| if not _has_coords_in_state(state): | ||
| raise ValueError( | ||
| "No coordinates found in state for StaticInputs. Load with " | ||
| "from_state_backwards_compatible if loading from a checkpoint " | ||
| "saved prior to current coordinate serialization format." | ||
| ) | ||
| state = _sync_state_coordinates(state) | ||
| return cls( | ||
| fields=[ | ||
| StaticInput( | ||
| data=field_state["data"], | ||
| coords=LatLonCoordinates( | ||
| lat=field_state["coords"]["lat"], | ||
| lon=field_state["coords"]["lon"], | ||
| ), | ||
| ) | ||
| for field_state in state["fields"] | ||
| ] | ||
| StaticInput.from_state(field_state) for field_state in state["fields"] | ||
| ], | ||
| coords=LatLonCoordinates( | ||
| lat=state["coords"]["lat"], | ||
| lon=state["coords"]["lon"], | ||
| ), | ||
| ) | ||
|
|
||
| @classmethod | ||
| def from_state_backwards_compatible( | ||
| cls, | ||
| state: dict, | ||
| static_inputs_config: dict[str, str], | ||
| fine_coordinates_path: str | None, | ||
| ) -> "StaticInputs": | ||
| if state and static_inputs_config: | ||
| raise ValueError( | ||
| "Checkpoint contains static inputs but static_inputs_config is " | ||
| "also provided. Backwards compatibility loading only supports " | ||
| "a single source of StaticInputs info." | ||
| ) | ||
|
|
||
| if fine_coordinates_path and _has_coords_in_state(state): | ||
| raise ValueError( | ||
| "State contains coordinates but fine_coordinates_path is also provided." | ||
| " Only one source of coordinate info can be used for backwards " | ||
| "compatibility loading of StaticInputs." | ||
| ) | ||
| elif not _has_coords_in_state(state) and not fine_coordinates_path: | ||
| raise ValueError( | ||
| "No coordinates found in state and no fine_coordinates_path provided. " | ||
| "Cannot load StaticInputs without coordinates." | ||
| ) | ||
|
|
||
| # All compatibility cases: | ||
| # Serialized StaticInputs exist, which always had coordinates stored | ||
| # No serialized static inputs or specified inputs, load coordinates | ||
| # Specified static input fields and specified coordinates | ||
|
|
||
| if _has_coords_in_state(state): | ||
| return cls.from_state(state) | ||
| else: | ||
| assert fine_coordinates_path is not None # for type checker | ||
| coords = load_fine_coords_from_path(fine_coordinates_path) | ||
|
|
||
| if static_inputs_config: | ||
| return load_static_inputs(static_inputs_config, coords) | ||
| else: | ||
| return cls(fields=[], coords=coords) | ||
|
|
||
|
|
||
| def load_static_inputs( | ||
| static_inputs_config: dict[str, str] | None, | ||
| ) -> StaticInputs | None: | ||
| static_inputs_config: dict[str, str], coords: LatLonCoordinates | ||
| ) -> StaticInputs: | ||
| """ | ||
| Load normalized static inputs from a mapping of field names to file paths. | ||
| Returns None if the input config is empty. | ||
| Returns an empty StaticInputs (no fields) if the config is empty. | ||
| """ | ||
| # TODO: consolidate/simplify empty StaticInputs vs. None handling in | ||
| # downscaling code | ||
| if not static_inputs_config: | ||
| return None | ||
| return StaticInputs( | ||
| fields=[ | ||
| _get_normalized_static_input(path, field_name) | ||
| for field_name, path in static_inputs_config.items() | ||
| ] | ||
| ) | ||
| fields = [ | ||
| _get_normalized_static_input(path, field_name) | ||
| for field_name, path in static_inputs_config.items() | ||
| ] | ||
| return StaticInputs(fields=fields, coords=coords) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To make my VSCode linter happy