-
Notifications
You must be signed in to change notification settings - Fork 38
Add separated channel module registry and step implementation #957
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2e02152
fbcc20f
b447008
e91b808
1d88b0b
4afeb0d
231f3a9
6a5bb2c
56d838e
237f03f
9858f11
5bd9e08
6d6df2e
2f2e468
91c2910
c67142f
b4e49c7
a9fbef2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,6 @@ | |
| from fme.core.distributed import Distributed | ||
| from fme.core.normalizer import NetworkAndLossNormalizationConfig, StandardNormalizer | ||
| from fme.core.ocean import Ocean, OceanConfig | ||
| from fme.core.optimization import NullOptimization | ||
| from fme.core.packer import Packer | ||
| from fme.core.registry import CorrectorSelector | ||
| from fme.core.step.args import StepArgs | ||
|
|
@@ -229,7 +228,8 @@ def get_loss_normalizer( | |
| extra_residual_scaled_names = [] | ||
| return self.normalization.get_loss_normalizer( | ||
| names=self._normalize_names + extra_names, | ||
| residual_scaled_names=self.prognostic_names + extra_residual_scaled_names, | ||
| residual_scaled_names=self.get_prognostic_names() | ||
| + extra_residual_scaled_names, | ||
| ) | ||
|
|
||
| @classmethod | ||
|
|
@@ -386,7 +386,6 @@ def __init__( | |
| self.module = dist.wrap_module(module) | ||
| self._img_shape = dataset_info.img_shape | ||
| self._config = config | ||
| self._no_optimization = NullOptimization() | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This was unused. |
||
|
|
||
| self._timestep = timestep | ||
|
|
||
|
|
@@ -483,7 +482,7 @@ def network_call(input_norm: TensorDict) -> TensorDict: | |
| corrector=self._corrector, | ||
| ocean=self.ocean, | ||
| residual_prediction=self._config.residual_prediction, | ||
| prognostic_names=self.prognostic_names, | ||
| prognostic_names=self.get_prognostic_names(), | ||
| prescribed_prognostic_names=self._config.prescribed_prognostic_names, | ||
| ) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| from .corrector import CorrectorSelector | ||
| from .module import ModuleSelector | ||
| from .registry import Registry | ||
| from .separated_module import SeparatedModuleSelector |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,204 @@ | ||
| import abc | ||
| import dataclasses | ||
| from collections.abc import Callable, Mapping | ||
| from typing import Any, ClassVar | ||
|
|
||
| import dacite | ||
| import torch | ||
| from torch import nn | ||
|
|
||
| from fme.core.dataset_info import DatasetInfo | ||
| from fme.core.labels import BatchLabels, LabelEncoding | ||
|
|
||
| from .registry import Registry | ||
|
|
||
| SeparatedModuleConfigType = type["SeparatedModuleConfig"] | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class SeparatedModuleConfig(abc.ABC): | ||
| """ | ||
| Builds an nn.Module that takes separate forcing and prognostic tensors | ||
| as input, and returns separate prognostic and diagnostic tensors as output. | ||
|
|
||
| The built nn.Module must have the forward signature:: | ||
|
|
||
| forward( | ||
| forcing: Tensor, | ||
| prognostic: Tensor, | ||
| labels: Tensor | None = None, | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I went with requiring all modules of this new type to take in a labels argument, since there's no need for backwards compatibility with ones that don't. For network types that don't have conditioning, we can broadcast these to the domain size and just treat them as additional input variables. |
||
| ) -> tuple[Tensor, Tensor] | ||
|
|
||
| where the return value is (prognostic_out, diagnostic_out). | ||
| """ | ||
|
|
||
| @abc.abstractmethod | ||
| def build( | ||
| self, | ||
| n_forcing_channels: int, | ||
| n_prognostic_channels: int, | ||
| n_diagnostic_channels: int, | ||
| dataset_info: DatasetInfo, | ||
| ) -> nn.Module: | ||
| """ | ||
| Build a nn.Module with separated forcing/prognostic/diagnostic channels. | ||
|
|
||
| Args: | ||
| n_forcing_channels: number of input-only (forcing) channels | ||
| n_prognostic_channels: number of input-output (prognostic) channels | ||
| n_diagnostic_channels: number of output-only (diagnostic) channels | ||
| dataset_info: Information about the dataset, including img_shape, | ||
| horizontal coordinates, vertical coordinate, etc. | ||
|
|
||
| Returns: | ||
| An nn.Module whose forward method takes | ||
| (forcing, prognostic, labels=None) and returns | ||
| (prognostic_out, diagnostic_out) tensors. | ||
| """ | ||
| ... | ||
|
|
||
| @classmethod | ||
| def from_state(cls, state: Mapping[str, Any]) -> "SeparatedModuleConfig": | ||
| return dacite.from_dict( | ||
| data_class=cls, data=state, config=dacite.Config(strict=True) | ||
| ) | ||
|
|
||
|
|
||
| class SeparatedModule: | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This handles conversion of BatchLabels to Tensor, and also provides a strongly typed interface (unlike nn.Module which is treated as Any by mypy). We do the same thing for Module. |
||
| """ | ||
| Wrapper around an nn.Module with separated channel interface. | ||
|
|
||
| The wrapped module takes (forcing, prognostic) tensors and returns | ||
| (prognostic_out, diagnostic_out) tensors. This wrapper handles | ||
| optional label encoding for conditional models. | ||
| """ | ||
|
|
||
| def __init__(self, module: nn.Module, label_encoding: LabelEncoding | None): | ||
| self._module = module | ||
| self._label_encoding = label_encoding | ||
|
|
||
| def __call__( | ||
| self, | ||
| forcing: torch.Tensor, | ||
| prognostic: torch.Tensor, | ||
| labels: BatchLabels | None = None, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| if labels is not None and self._label_encoding is None: | ||
| raise ValueError( | ||
| "labels were provided but the module has no label encoding" | ||
| ) | ||
| if labels is None and self._label_encoding is not None: | ||
| raise ValueError( | ||
| "labels were not provided but the module has a label encoding" | ||
| ) | ||
| if labels is not None and self._label_encoding is not None: | ||
| encoded_labels = labels.conform_to_encoding(self._label_encoding) | ||
| return self._module(forcing, prognostic, labels=encoded_labels.tensor) | ||
|
Comment on lines
+94
to
+96
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to [inline code reference lost in extraction]. Also, I think this branch is untested.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Noting that the |
||
| else: | ||
| return self._module(forcing, prognostic, labels=None) | ||
|
|
||
| @property | ||
| def torch_module(self) -> nn.Module: | ||
| return self._module | ||
|
|
||
| def get_state(self) -> dict[str, Any]: | ||
| if self._label_encoding is not None: | ||
| label_encoder_state = self._label_encoding.get_state() | ||
| else: | ||
| label_encoder_state = None | ||
| return { | ||
| **self._module.state_dict(), | ||
| "label_encoding": label_encoder_state, | ||
| } | ||
|
|
||
| def load_state(self, state: dict[str, Any]) -> None: | ||
| state = state.copy() | ||
| if state.get("label_encoding") is not None: | ||
| if self._label_encoding is None: | ||
| self._label_encoding = LabelEncoding.from_state( | ||
| state.pop("label_encoding") | ||
| ) | ||
| else: | ||
| self._label_encoding.conform_to_state(state.pop("label_encoding")) | ||
| state.pop("label_encoding", None) | ||
| self._module.load_state_dict(state) | ||
|
|
||
| def wrap_module( | ||
| self, callable: Callable[[nn.Module], nn.Module] | ||
| ) -> "SeparatedModule": | ||
| return SeparatedModule(callable(self._module), self._label_encoding) | ||
|
|
||
| def to(self, device: torch.device) -> "SeparatedModule": | ||
| return SeparatedModule(self._module.to(device), self._label_encoding) | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class SeparatedModuleSelector: | ||
| """ | ||
| A dataclass for building SeparatedModuleConfig instances from config dicts. | ||
|
|
||
| Mirrors ModuleSelector but for the separated channel interface. | ||
|
|
||
| Parameters: | ||
| type: the type of the SeparatedModuleConfig | ||
| config: data for a SeparatedModuleConfig instance of the indicated type | ||
| """ | ||
|
|
||
| type: str | ||
| config: Mapping[str, Any] | ||
| registry: ClassVar[Registry[SeparatedModuleConfig]] = Registry[ | ||
| SeparatedModuleConfig | ||
| ]() | ||
|
|
||
| def __post_init__(self): | ||
| if not isinstance(self.registry, Registry): | ||
| raise ValueError( | ||
| "SeparatedModuleSelector.registry should not be set manually" | ||
| ) | ||
| self._instance = self.registry.get(self.type, self.config) | ||
|
|
||
| @property | ||
| def module_config(self) -> SeparatedModuleConfig: | ||
| return self._instance | ||
|
|
||
| @classmethod | ||
| def register( | ||
| cls, type_name: str | ||
| ) -> Callable[[SeparatedModuleConfigType], SeparatedModuleConfigType]: | ||
| return cls.registry.register(type_name) | ||
|
|
||
| def build( | ||
| self, | ||
| n_forcing_channels: int, | ||
| n_prognostic_channels: int, | ||
| n_diagnostic_channels: int, | ||
| dataset_info: DatasetInfo, | ||
| ) -> SeparatedModule: | ||
| """ | ||
| Build a SeparatedModule with separated forcing/prognostic/diagnostic | ||
| channels. | ||
|
|
||
| Args: | ||
| n_forcing_channels: number of input-only (forcing) channels | ||
| n_prognostic_channels: number of input-output (prognostic) channels | ||
| n_diagnostic_channels: number of output-only (diagnostic) channels | ||
| dataset_info: Information about the dataset, including img_shape. | ||
|
|
||
| Returns: | ||
| a SeparatedModule object | ||
| """ | ||
| if len(dataset_info.all_labels) > 0: | ||
| label_encoding = LabelEncoding(sorted(list(dataset_info.all_labels))) | ||
| else: | ||
| label_encoding = None | ||
| module = self._instance.build( | ||
| n_forcing_channels=n_forcing_channels, | ||
| n_prognostic_channels=n_prognostic_channels, | ||
| n_diagnostic_channels=n_diagnostic_channels, | ||
| dataset_info=dataset_info, | ||
| ) | ||
| return SeparatedModule(module, label_encoding) | ||
|
|
||
| @classmethod | ||
| def get_available_types(cls): | ||
| return cls.registry._types.keys() | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This property method needed to be refactored to a getter so that
prognostic_names could be a dataclass attribute instead, on the new Step.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think about making this refactor a separate PR? Fairly minor change, but does touch a lot of files.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feel free to ignore, I don't think it's really necessary to separate off.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also went back and forth on this. Will keep it.