Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
2e02152
Add separated channel module registry
mcgibbon Mar 11, 2026
fbcc20f
Add SeparatedModuleStep and fix LegacyWrapper state_dict handling
mcgibbon Mar 11, 2026
b447008
Require non-empty prognostic_names_ and test optional forcing/diagnos…
mcgibbon Mar 11, 2026
e91b808
Refactor prognostic_names from property to get_prognostic_names() method
mcgibbon Mar 11, 2026
1d88b0b
Rename prognostic_names_ to prognostic_names in SeparatedModuleStepCo…
mcgibbon Mar 11, 2026
4afeb0d
Address PR review comments on separated module code
mcgibbon Mar 11, 2026
231f3a9
Merge branch 'main' into feature/separated-module-registry
mcgibbon Mar 11, 2026
6a5bb2c
Remove LegacyWrapper/LegacyModuleAdapter, simplify conditional handling
mcgibbon Mar 11, 2026
56d838e
Merge branch 'feature/separated-module-registry' of github.com:ai2cm/…
mcgibbon Mar 11, 2026
237f03f
Clean up unused NullOptimization and move test module out of test file
mcgibbon Mar 11, 2026
9858f11
Merge branch 'main' into feature/separated-module-registry
mcgibbon Mar 12, 2026
5bd9e08
Update fme/core/step/separated_module.py
mcgibbon Mar 12, 2026
6d6df2e
reorder arguments
mcgibbon Mar 20, 2026
2f2e468
Address PR review comments for separated module registry
mcgibbon Mar 20, 2026
91c2910
Fix register_test_types import in test_step.py
mcgibbon Mar 20, 2026
c67142f
Merge branch 'main' into feature/separated-module-registry
mcgibbon Mar 20, 2026
b4e49c7
Merge branch 'main' into feature/separated-module-registry
mcgibbon Mar 21, 2026
a9fbef2
Fix corrector building to use new get_corrector API
mcgibbon Mar 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fme/ace/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def run_inference_from_config(config: InferenceConfig):
logging.info("Loading initial condition data")
initial_condition = get_initial_condition(
config.initial_condition.get_dataset(),
stepper_config.prognostic_names,
stepper_config.get_prognostic_names(),
Copy link
Contributor Author

@mcgibbon mcgibbon Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This property needed to be refactored into a getter method so that prognostic_names could instead be a dataclass attribute on the new Step.

Copy link
Member

@jpdunc23 jpdunc23 Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about making this refactor a separate PR? Fairly minor change, but does touch a lot of files.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to ignore, I don't think it's really necessary to separate off.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also went back and forth on this. Will keep it.

labels=config.labels,
n_ensemble=config.n_ensemble_per_ic,
)
Expand Down
7 changes: 3 additions & 4 deletions fme/ace/step/fcn3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from fme.core.distributed import Distributed
from fme.core.normalizer import NetworkAndLossNormalizationConfig, StandardNormalizer
from fme.core.ocean import Ocean, OceanConfig
from fme.core.optimization import NullOptimization
from fme.core.packer import Packer
from fme.core.registry import CorrectorSelector
from fme.core.step.args import StepArgs
Expand Down Expand Up @@ -229,7 +228,8 @@ def get_loss_normalizer(
extra_residual_scaled_names = []
return self.normalization.get_loss_normalizer(
names=self._normalize_names + extra_names,
residual_scaled_names=self.prognostic_names + extra_residual_scaled_names,
residual_scaled_names=self.get_prognostic_names()
+ extra_residual_scaled_names,
)

@classmethod
Expand Down Expand Up @@ -386,7 +386,6 @@ def __init__(
self.module = dist.wrap_module(module)
self._img_shape = dataset_info.img_shape
self._config = config
self._no_optimization = NullOptimization()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was unused.


self._timestep = timestep

Expand Down Expand Up @@ -483,7 +482,7 @@ def network_call(input_norm: TensorDict) -> TensorDict:
corrector=self._corrector,
ocean=self.ocean,
residual_prediction=self._config.residual_prediction,
prognostic_names=self.prognostic_names,
prognostic_names=self.get_prognostic_names(),
prescribed_prognostic_names=self._config.prescribed_prognostic_names,
)

Expand Down
18 changes: 9 additions & 9 deletions fme/ace/stepper/single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def get_evaluation_window_data_requirements(

def get_prognostic_state_data_requirements(self) -> PrognosticStateDataRequirements:
return PrognosticStateDataRequirements(
names=self.prognostic_names,
names=self.get_prognostic_names(),
n_timesteps=self.n_ic_timesteps,
)

Expand Down Expand Up @@ -672,10 +672,9 @@ def next_step_forcing_names(self) -> list[str]:
"""
return self.step.get_next_step_forcing_names()

@property
def prognostic_names(self) -> list[str]:
def get_prognostic_names(self) -> list[str]:
"""Names of variables which both inputs and outputs."""
return self.step.prognostic_names
return self.step.get_prognostic_names()

@property
def output_names(self) -> list[str]:
Expand Down Expand Up @@ -971,9 +970,8 @@ def get_base_weights(self) -> Weights | None:
"""
return self._parameter_initializer.base_weights

@property
def prognostic_names(self) -> list[str]:
return self._step_obj.prognostic_names
def get_prognostic_names(self) -> list[str]:
return self._step_obj.get_prognostic_names()

@property
def out_names(self) -> list[str]:
Expand Down Expand Up @@ -1174,7 +1172,9 @@ def predict(
)
.remove_initial_condition(self.n_ic_timesteps)
)
prognostic_state = data.get_end(self.prognostic_names, self.n_ic_timesteps)
prognostic_state = data.get_end(
self.get_prognostic_names(), self.n_ic_timesteps
)
data = BatchData.new_on_device(
data=data.data,
time=data.time,
Expand Down Expand Up @@ -1494,7 +1494,7 @@ def __init__(

self._epoch: int | None = None # to keep track of cached values

self._prognostic_names = self._stepper.prognostic_names
self._prognostic_names = self._stepper.get_prognostic_names()
self._derive_func = self._stepper.derive_func
self._loss_obj = self._stepper.build_loss(config.loss)

Expand Down
10 changes: 5 additions & 5 deletions fme/core/generics/test_looper.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def test_looper():
data={n: spherical_data.data[n][:, :1] for n in spherical_data.data},
time=time[:, 0:1],
).get_start(
prognostic_names=stepper.prognostic_names,
prognostic_names=stepper.get_prognostic_names(),
n_ic_timesteps=1,
)
loader = MockLoader(shape, forcing_names, 3, time=time)
Expand All @@ -197,7 +197,7 @@ def test_looper_paired():
data={n: spherical_data.data[n][:, :1] for n in spherical_data.data},
time=time[:, 0:1],
).get_start(
prognostic_names=stepper.prognostic_names,
prognostic_names=stepper.get_prognostic_names(),
n_ic_timesteps=1,
)
loader = MockLoader(shape, forcing_names, 3, time=time)
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_looper_paired_with_derived_variables():
data={n: spherical_data.data[n][:, :1] for n in spherical_data.data},
time=time[:, 0:1],
).get_start(
prognostic_names=stepper.prognostic_names,
prognostic_names=stepper.get_prognostic_names(),
n_ic_timesteps=1,
)
loader = MockLoader(shape, forcing_names, 2, time=time)
Expand All @@ -258,7 +258,7 @@ def test_looper_paired_with_target_data():
data={n: spherical_data.data[n][:, :1] for n in spherical_data.data},
time=time[:, 0:1],
).get_start(
prognostic_names=stepper.prognostic_names,
prognostic_names=stepper.get_prognostic_names(),
n_ic_timesteps=1,
)
loader = MockLoader(shape, all_names, 2, time=time)
Expand All @@ -284,7 +284,7 @@ def test_looper_paired_with_target_data_and_derived_variables():
data={n: spherical_data.data[n][:, :1] for n in spherical_data.data},
time=time[:, 0:1],
).get_start(
prognostic_names=stepper.prognostic_names,
prognostic_names=stepper.get_prognostic_names(),
n_ic_timesteps=1,
)
loader = MockLoader(shape, all_names, 2, time=time)
Expand Down
1 change: 1 addition & 0 deletions fme/core/registry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .corrector import CorrectorSelector
from .module import ModuleSelector
from .registry import Registry
from .separated_module import SeparatedModuleSelector
204 changes: 204 additions & 0 deletions fme/core/registry/separated_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import abc
import dataclasses
from collections.abc import Callable, Mapping
from typing import Any, ClassVar

import dacite
import torch
from torch import nn

from fme.core.dataset_info import DatasetInfo
from fme.core.labels import BatchLabels, LabelEncoding

from .registry import Registry

SeparatedModuleConfigType = type["SeparatedModuleConfig"]


@dataclasses.dataclass
class SeparatedModuleConfig(abc.ABC):
"""
Builds an nn.Module that takes separate forcing and prognostic tensors
as input, and returns separate prognostic and diagnostic tensors as output.

The built nn.Module must have the forward signature::

forward(
forcing: Tensor,
prognostic: Tensor,
labels: Tensor | None = None,
Copy link
Contributor Author

@mcgibbon mcgibbon Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went with requiring that all modules of this new type take a labels argument, since there's no need for backwards compatibility with ones that don't. For network types that don't have conditioning, we can broadcast the labels to the domain size and just treat them as additional input variables.

) -> tuple[Tensor, Tensor]

where the return value is (prognostic_out, diagnostic_out).
"""

@abc.abstractmethod
def build(
    self,
    n_forcing_channels: int,
    n_prognostic_channels: int,
    n_diagnostic_channels: int,
    dataset_info: DatasetInfo,
) -> nn.Module:
    """
    Construct the nn.Module for the requested channel layout.

    Subclasses must implement this; the base method has no behavior.

    Args:
        n_forcing_channels: count of forcing channels, which are inputs only
        n_prognostic_channels: count of prognostic channels, which are both
            inputs and outputs
        n_diagnostic_channels: count of diagnostic channels, which are
            outputs only
        dataset_info: description of the dataset (img_shape, horizontal
            coordinates, vertical coordinate, etc.)

    Returns:
        An nn.Module whose forward accepts
        (forcing, prognostic, labels=None) and produces the tensor pair
        (prognostic_out, diagnostic_out).
    """
    ...

@classmethod
def from_state(cls, state: Mapping[str, Any]) -> "SeparatedModuleConfig":
    """
    Create an instance of this config class from a mapping of field values.

    Args:
        state: mapping handed to dacite as the data for this dataclass

    Returns:
        the instance produced by dacite.from_dict with strict mode enabled
    """
    strict_config = dacite.Config(strict=True)
    return dacite.from_dict(data_class=cls, data=state, config=strict_config)


class SeparatedModule:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This handles conversion of BatchLabels to Tensor, and also provides a strongly typed interface (unlike nn.Module which is treated as Any by mypy). We do the same thing for Module.

"""
Wrapper around an nn.Module with separated channel interface.

The wrapped module takes (forcing, prognostic) tensors and returns
(prognostic_out, diagnostic_out) tensors. This wrapper handles
optional label encoding for conditional models.
"""

def __init__(self, module: nn.Module, label_encoding: LabelEncoding | None):
    """
    Store the module to wrap and its optional label encoding.

    Args:
        module: the nn.Module this wrapper delegates to
        label_encoding: encoding used for labels, or None if there is none
    """
    self._label_encoding = label_encoding
    self._module = module

def __call__(
self,
forcing: torch.Tensor,
prognostic: torch.Tensor,
labels: BatchLabels | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if labels is not None and self._label_encoding is None:
raise ValueError(
"labels were provided but the module has no label encoding"
)
if labels is None and self._label_encoding is not None:
raise ValueError(
"labels were not provided but the module has a label encoding"
)
if labels is not None and self._label_encoding is not None:
encoded_labels = labels.conform_to_encoding(self._label_encoding)
return self._module(forcing, prognostic, labels=encoded_labels.tensor)
Comment on lines +94 to +96
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to fme.core.registry.module.Module, should this raise an error if labels is not None and self._label_encoding is None or vice versa?

Also, I think this branch is untested.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noting that the if labels is not None and self._label_encoding is not None: path is still untested. Maybe this could be covered in fme/core/step/test_separated_module.py?

else:
return self._module(forcing, prognostic, labels=None)

@property
def torch_module(self) -> nn.Module:
    """The nn.Module held by this wrapper."""
    wrapped: nn.Module = self._module
    return wrapped

def get_state(self) -> dict[str, Any]:
    """
    Collect the state of the wrapped module and of the label encoding.

    Returns:
        a dict holding every entry of the wrapped module's state_dict plus
        a "label_encoding" entry, which is the encoding's state or None
        when this wrapper has no label encoding
    """
    encoding = self._label_encoding
    encoding_state = None if encoding is None else encoding.get_state()
    combined: dict[str, Any] = dict(self._module.state_dict())
    # Assigned last so it takes precedence over a same-named module entry.
    combined["label_encoding"] = encoding_state
    return combined

def load_state(self, state: dict[str, Any]) -> None:
    """
    Restore the wrapped module, and the label encoding if one was saved.

    The caller's dict is not modified; a shallow copy is used.

    Args:
        state: module state_dict entries, optionally with a
            "label_encoding" entry. A missing or None entry leaves the
            current label encoding untouched.
    """
    module_state = state.copy()
    # Always strip the key: the module's load_state_dict should only see
    # its own entries.
    saved_encoding = module_state.pop("label_encoding", None)
    if saved_encoding is not None:
        if self._label_encoding is None:
            self._label_encoding = LabelEncoding.from_state(saved_encoding)
        else:
            self._label_encoding.conform_to_state(saved_encoding)
    self._module.load_state_dict(module_state)

def wrap_module(
    self, callable: Callable[[nn.Module], nn.Module]
) -> "SeparatedModule":
    """
    Apply a module-to-module function to the wrapped module.

    Args:
        callable: function that receives the wrapped nn.Module and returns
            the nn.Module to use in its place

    Returns:
        a new SeparatedModule holding the returned module together with
        this wrapper's label encoding
    """
    transformed = callable(self._module)
    return SeparatedModule(transformed, self._label_encoding)

def to(self, device: torch.device) -> "SeparatedModule":
    """
    Call ``.to(device)`` on the wrapped module and re-wrap the result.

    Args:
        device: the device passed to the wrapped module's ``to``

    Returns:
        a new SeparatedModule holding the result together with this
        wrapper's label encoding
    """
    moved = self._module.to(device)
    return SeparatedModule(moved, self._label_encoding)


@dataclasses.dataclass
class SeparatedModuleSelector:
    """
    A dataclass for building SeparatedModuleConfig instances from config dicts.

    Mirrors ModuleSelector but for the separated channel interface.

    Parameters:
        type: the type of the SeparatedModuleConfig
        config: data for a SeparatedModuleConfig instance of the indicated type
    """

    type: str
    config: Mapping[str, Any]
    # Class-level registry shared by every selector; not a dataclass field.
    registry: ClassVar[Registry[SeparatedModuleConfig]] = Registry[
        SeparatedModuleConfig
    ]()

    def __post_init__(self):
        """
        Look up the configured type in the registry.

        Raises:
            ValueError: if ``registry`` has been replaced by something that
                is not a Registry.
        """
        if not isinstance(self.registry, Registry):
            raise ValueError(
                "SeparatedModuleSelector.registry should not be set manually"
            )
        self._instance = self.registry.get(self.type, self.config)

    @property
    def module_config(self) -> SeparatedModuleConfig:
        """The object obtained from the registry for ``type`` and ``config``."""
        return self._instance

    @classmethod
    def register(
        cls, type_name: str
    ) -> Callable[[SeparatedModuleConfigType], SeparatedModuleConfigType]:
        """
        Return the registry's decorator for registering a config class.

        Args:
            type_name: the name to register the config class under
        """
        return cls.registry.register(type_name)

    def build(
        self,
        n_forcing_channels: int,
        n_prognostic_channels: int,
        n_diagnostic_channels: int,
        dataset_info: DatasetInfo,
    ) -> SeparatedModule:
        """
        Build a SeparatedModule with separated forcing/prognostic/diagnostic
        channels.

        Args:
            n_forcing_channels: number of input-only (forcing) channels
            n_prognostic_channels: number of input-output (prognostic) channels
            n_diagnostic_channels: number of output-only (diagnostic) channels
            dataset_info: Information about the dataset, including img_shape.

        Returns:
            a SeparatedModule object
        """
        all_labels = dataset_info.all_labels
        if len(all_labels) > 0:
            # Sort so the encoding order does not depend on the iteration
            # order of the label collection.
            label_encoding = LabelEncoding(sorted(all_labels))
        else:
            label_encoding = None
        module = self._instance.build(
            n_forcing_channels=n_forcing_channels,
            n_prognostic_channels=n_prognostic_channels,
            n_diagnostic_channels=n_diagnostic_channels,
            dataset_info=dataset_info,
        )
        return SeparatedModule(module, label_encoding)

    @classmethod
    def get_available_types(cls):
        """
        Return the keys of the registry's type table.

        NOTE(review): presumably these are the names passed to ``register``;
        confirm against Registry.
        """
        return cls.registry._types.keys()
Loading
Loading