Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
2e02152
Add separated channel module registry
mcgibbon Mar 11, 2026
fbcc20f
Add SeparatedModuleStep and fix LegacyWrapper state_dict handling
mcgibbon Mar 11, 2026
b447008
Require non-empty prognostic_names_ and test optional forcing/diagnos…
mcgibbon Mar 11, 2026
e91b808
Refactor prognostic_names from property to get_prognostic_names() method
mcgibbon Mar 11, 2026
1d88b0b
Rename prognostic_names_ to prognostic_names in SeparatedModuleStepCo…
mcgibbon Mar 11, 2026
4afeb0d
Address PR review comments on separated module code
mcgibbon Mar 11, 2026
231f3a9
Merge branch 'main' into feature/separated-module-registry
mcgibbon Mar 11, 2026
6a5bb2c
Remove LegacyWrapper/LegacyModuleAdapter, simplify conditional handling
mcgibbon Mar 11, 2026
56d838e
Merge branch 'feature/separated-module-registry' of github.com:ai2cm/…
mcgibbon Mar 11, 2026
237f03f
Clean up unused NullOptimization and move test module out of test file
mcgibbon Mar 11, 2026
9858f11
Merge branch 'main' into feature/separated-module-registry
mcgibbon Mar 12, 2026
5bd9e08
Update fme/core/step/separated_module.py
mcgibbon Mar 12, 2026
6d6df2e
reorder arguments
mcgibbon Mar 20, 2026
2f2e468
Address PR review comments for separated module registry
mcgibbon Mar 20, 2026
91c2910
Fix register_test_types import in test_step.py
mcgibbon Mar 20, 2026
c67142f
Merge branch 'main' into feature/separated-module-registry
mcgibbon Mar 20, 2026
b4e49c7
Merge branch 'main' into feature/separated-module-registry
mcgibbon Mar 21, 2026
a9fbef2
Fix corrector building to use new get_corrector API
mcgibbon Mar 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fme/core/registry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .corrector import CorrectorSelector
from .module import ModuleSelector
from .registry import Registry
from .separated_module import SeparatedModuleSelector
271 changes: 271 additions & 0 deletions fme/core/registry/separated_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
import abc
import dataclasses
from collections.abc import Callable, Mapping

# we use Type to distinguish from type attr of SeparatedModuleSelector
from typing import Any, ClassVar, Type # noqa: UP035

import dacite
import torch
from torch import nn

from fme.core.dataset_info import DatasetInfo
from fme.core.labels import BatchLabels, LabelEncoding

from .module import CONDITIONAL_BUILDERS, ModuleSelector
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Is it appropriate to re-use the CONDITIONAL_BUILDERS from .module? Or should we be defining these independently for SeparatedModule? Or should all new module types here support conditioning (we don't need backwards compatibility)? At a minimum labels can always be given as broadcasted input variables, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Answer: Good question. Right now we reuse it because the only way to get a conditional separated module is through LegacyModuleAdapter, which delegates to ModuleSelector, so the same builder types apply. When we register native separated modules that support conditioning, we should define an independent CONDITIONAL_BUILDERS list (or a different mechanism) for this registry. For now, reusing it is correct since LegacyModuleAdapter is the only registered type.

from .registry import Registry


@dataclasses.dataclass
class SeparatedModuleConfig(abc.ABC):
    """
    Abstract builder for modules using the separated-channel interface.

    Concrete subclasses produce an nn.Module that takes separate forcing and
    prognostic tensors as input and returns separate prognostic and diagnostic
    tensors as output. The built nn.Module must have the forward signature:

        forward(forcing: Tensor, prognostic: Tensor) -> tuple[Tensor, Tensor]

    where the return value is (prognostic_out, diagnostic_out).
    """

    @abc.abstractmethod
    def build(
        self,
        n_forcing_channels: int,
        n_prognostic_channels: int,
        n_diagnostic_channels: int,
        dataset_info: DatasetInfo,
    ) -> nn.Module:
        """
        Construct an nn.Module with separated forcing/prognostic/diagnostic
        channels.

        Args:
            n_forcing_channels: number of input-only (forcing) channels
            n_prognostic_channels: number of input-output (prognostic) channels
            n_diagnostic_channels: number of output-only (diagnostic) channels
            dataset_info: Information about the dataset, including img_shape,
                horizontal coordinates, vertical coordinate, etc.

        Returns:
            An nn.Module whose forward method takes (forcing, prognostic)
            tensors and returns (prognostic_out, diagnostic_out) tensors.
        """
        ...

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> "SeparatedModuleConfig":
        """Reconstruct a config instance from a serialized state mapping.

        Strict dacite config so unknown keys in ``state`` raise instead of
        being silently dropped.
        """
        strict = dacite.Config(strict=True)
        return dacite.from_dict(data_class=cls, data=state, config=strict)


class SeparatedModule:
    """
    Strongly typed wrapper around an nn.Module with the separated channel
    interface.

    The wrapped module takes (forcing, prognostic) tensors and returns
    (prognostic_out, diagnostic_out) tensors. This wrapper additionally
    handles optional label encoding for conditional models, converting
    BatchLabels into the tensor form the underlying module expects.
    """

    def __init__(self, module: nn.Module, label_encoding: LabelEncoding | None):
        """
        Args:
            module: torch module with signature
                forward(forcing, prognostic[, labels]).
            label_encoding: label encoding for conditional models, or None
                for unconditional models.
        """
        self._module = module
        self._label_encoding = label_encoding

    def __call__(
        self,
        forcing: torch.Tensor,
        prognostic: torch.Tensor,
        labels: BatchLabels | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the wrapped module on one step of inputs.

        Args:
            forcing: input-only channels.
            prognostic: input-output channels.
            labels: batch labels, required if and only if the model is
                conditional (i.e. was built with a label encoding).

        Returns:
            (prognostic_out, diagnostic_out) tensors.

        Raises:
            TypeError: if labels are passed to an unconditional model, or
                omitted for a conditional one.
        """
        if self._label_encoding is None:
            if labels is not None:
                raise TypeError("Labels are not allowed for unconditional models")
            return self._module(forcing, prognostic)
        if labels is None:
            raise TypeError("Labels are required for conditional models")
        encoded_labels = labels.conform_to_encoding(self._label_encoding)
        return self._module(forcing, prognostic, labels=encoded_labels.tensor)

    @property
    def torch_module(self) -> nn.Module:
        """The underlying torch module."""
        return self._module

    def get_state(self) -> dict[str, Any]:
        """Serialize the module weights together with the label encoding."""
        if self._label_encoding is not None:
            label_encoder_state = self._label_encoding.get_state()
        else:
            label_encoder_state = None
        # NOTE(review): a module parameter literally named "label_encoding"
        # would collide with this key — assumed not to occur in practice.
        return {
            **self._module.state_dict(),
            "label_encoding": label_encoder_state,
        }

    def load_state(self, state: dict[str, Any]) -> None:
        """Restore state previously produced by get_state().

        Pops the label encoding entry once; the remaining keys form the
        torch state_dict.
        """
        state = state.copy()
        encoding_state = state.pop("label_encoding", None)
        if encoding_state is not None:
            if self._label_encoding is None:
                self._label_encoding = LabelEncoding.from_state(encoding_state)
            else:
                self._label_encoding.conform_to_state(encoding_state)
        self._module.load_state_dict(state)

    def wrap_module(
        self, callable: Callable[[nn.Module], nn.Module]
    ) -> "SeparatedModule":
        """Return a new wrapper whose torch module is callable(module).

        The parameter name ``callable`` (shadowing the builtin) is kept for
        backward compatibility with keyword callers.
        """
        return SeparatedModule(callable(self._module), self._label_encoding)

    def to(self, device: torch.device) -> "SeparatedModule":
        """Return a new wrapper with the module moved to the given device."""
        return SeparatedModule(self._module.to(device), self._label_encoding)


@dataclasses.dataclass
class SeparatedModuleSelector:
    """
    A dataclass for building SeparatedModuleConfig instances from config dicts.

    Mirrors ModuleSelector but for the separated channel interface.

    Parameters:
        type: the type of the SeparatedModuleConfig
        config: data for a SeparatedModuleConfig instance of the indicated type
        conditional: whether to condition the predictions on batch labels.
    """

    type: str
    config: Mapping[str, Any]
    conditional: bool = False
    # Shared registry of all SeparatedModuleConfig types; populated through
    # the register() decorator, never set per-instance.
    registry: ClassVar[Registry[SeparatedModuleConfig]] = Registry[
        SeparatedModuleConfig
    ]()

    def __post_init__(self):
        # dacite may pass `registry` through from config dicts; reject any
        # instance-level override of the class-level registry.
        if not isinstance(self.registry, Registry):
            raise ValueError(
                "SeparatedModuleSelector.registry should not be set manually"
            )
        if self.conditional and self.type not in CONDITIONAL_BUILDERS:
            raise ValueError(
                "Conditional predictions require a conditional builder, "
                f"got {self.type} (available: {CONDITIONAL_BUILDERS})"
            )
        # Fail fast: resolve the config class at construction time.
        self._instance = self.registry.get(self.type, self.config)

    @property
    def module_config(self) -> SeparatedModuleConfig:
        """The resolved SeparatedModuleConfig instance."""
        return self._instance

    @classmethod
    def register(
        cls, type_name: str
    ) -> Callable[[Type[SeparatedModuleConfig]], Type[SeparatedModuleConfig]]:  # noqa: UP006
        """Return a decorator registering a config class under ``type_name``.

        ``typing.Type`` is used because the dataclass field ``type`` shadows
        the builtin within this class body.
        """
        return cls.registry.register(type_name)

    def build(
        self,
        n_forcing_channels: int,
        n_prognostic_channels: int,
        n_diagnostic_channels: int,
        dataset_info: DatasetInfo,
    ) -> SeparatedModule:
        """
        Build a SeparatedModule with separated forcing/prognostic/diagnostic
        channels.

        Args:
            n_forcing_channels: number of input-only (forcing) channels
            n_prognostic_channels: number of input-output (prognostic) channels
            n_diagnostic_channels: number of output-only (diagnostic) channels
            dataset_info: Information about the dataset, including img_shape.

        Returns:
            a SeparatedModule object

        Raises:
            ValueError: if conditional and the dataset provides no labels.
        """
        if self.conditional:
            if not dataset_info.all_labels:
                raise ValueError("Conditional predictions require labels")
            # Sort for a deterministic label ordering in the encoding.
            label_encoding = LabelEncoding(sorted(dataset_info.all_labels))
        else:
            label_encoding = None
        module = self._instance.build(
            n_forcing_channels=n_forcing_channels,
            n_prognostic_channels=n_prognostic_channels,
            n_diagnostic_channels=n_diagnostic_channels,
            dataset_info=dataset_info,
        )
        return SeparatedModule(module, label_encoding)

    @classmethod
    def get_available_types(cls):
        """Names of all registered SeparatedModuleConfig types.

        NOTE(review): reaches into Registry's private ``_types`` attribute,
        mirroring ModuleSelector; a public accessor on Registry would be
        cleaner.
        """
        return cls.registry._types.keys()


class LegacyWrapper(nn.Module):
    """Adapts an old-style (single tensor in/out) module to the separated
    channel interface.

    Input channels are ordered as [forcing, prognostic] and output channels
    are ordered as [prognostic, diagnostic]. The inner module sees one
    concatenated input tensor and emits one concatenated output tensor.
    """

    def __init__(
        self,
        inner: nn.Module,
        n_prognostic_channels: int,
    ):
        super().__init__()
        self.inner = inner
        self._n_prognostic_channels = n_prognostic_channels

    def forward(
        self, forcing: torch.Tensor, prognostic: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Channel axis is dim -3, i.e. tensors shaped (..., channel, y, x).
        stacked = torch.cat((forcing, prognostic), dim=-3)
        combined_out = self.inner(stacked)
        n_prog = self._n_prognostic_channels
        n_diag = combined_out.shape[-3] - n_prog
        prognostic_out, diagnostic_out = torch.split(
            combined_out, [n_prog, n_diag], dim=-3
        )
        return prognostic_out, diagnostic_out


@SeparatedModuleSelector.register("legacy")
@dataclasses.dataclass
class LegacyModuleAdapter(SeparatedModuleConfig):
    """
    Adapter exposing a legacy ModuleSelector module through the separated
    channel interface.

    Input channels are ordered as [forcing, prognostic] and output channels
    are ordered as [prognostic, diagnostic].

    Parameters:
        legacy_builder: The legacy ModuleSelector for building the inner module.
    """

    legacy_builder: ModuleSelector

    def build(
        self,
        n_forcing_channels: int,
        n_prognostic_channels: int,
        n_diagnostic_channels: int,
        dataset_info: DatasetInfo,
    ) -> nn.Module:
        # The legacy interface sees concatenated channels on both sides.
        total_in = n_forcing_channels + n_prognostic_channels
        total_out = n_prognostic_channels + n_diagnostic_channels

        built = self.legacy_builder.build(
            n_in_channels=total_in,
            n_out_channels=total_out,
            dataset_info=dataset_info,
        )

        return LegacyWrapper(
            inner=built.torch_module,
            n_prognostic_channels=n_prognostic_channels,
        )
Loading
Loading