Skip to content

Add separated channel module registry and step implementation#957

Open
mcgibbon wants to merge 18 commits intomainfrom
feature/separated-module-registry
Open

Add separated channel module registry and step implementation#957
mcgibbon wants to merge 18 commits intomainfrom
feature/separated-module-registry

Conversation

@mcgibbon
Copy link
Contributor

@mcgibbon mcgibbon commented Mar 11, 2026

Adds a new module interface where modules explicitly take separate forcing and prognostic tensors as input and return separate prognostic and diagnostic tensors as output, instead of a single concatenated tensor. This makes channel semantics explicit at the module level.

Changes:

  • fme.core.registry.SeparatedModuleConfig, SeparatedModule, SeparatedModuleSelector: new module registry mirroring ModuleSelector but with separated channel interface

  • fme.core.registry.LegacyModuleAdapter, LegacyWrapper: adapter wrapping old-style single-tensor modules for backwards compatibility

  • fme.core.step.SeparatedModuleStepConfig, SeparatedModuleStep: new StepABC implementation using the separated channel interface, with explicit forcing_names, prognostic_names, and diagnostic_names fields

  • StepConfigABC.prognostic_names and StepABC.prognostic_names: refactored from @final @property to get_prognostic_names() getter method, enabling subclasses to use prognostic_names as a dataclass field name

  • Tests added

mcgibbon and others added 2 commits March 11, 2026 18:49
Add SeparatedModuleConfig, SeparatedModule, and SeparatedModuleSelector
for modules that take separate forcing/prognostic input tensors and
return separate prognostic/diagnostic output tensors. Include a
LegacyModuleAdapter that wraps old-style single-tensor modules via
LegacyWrapper for use in the new interface.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Add SeparatedModuleStepConfig/SeparatedModuleStep as a new StepABC
implementation that works with the separated channel module interface.
Uses explicit forcing_names, prognostic_names_, and diagnostic_names
instead of deriving them from in/out name set operations.

Remove state_dict/load_state_dict overrides from LegacyWrapper to fix
an asymmetry where PyTorch's parent recursion calls child state_dict()
overrides but not child load_state_dict() overrides, causing key
mismatches when wrapped by DummyWrapper.

Includes parametrized test coverage in test_step.py and dedicated tests
with numerical equivalence regression tests vs SingleModuleStep.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
@classmethod
def register(
cls, type_name: str
) -> Callable[[Type[SeparatedModuleConfig]], Type[SeparatedModuleConfig]]: # noqa: UP006
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Why is a noqa needed here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because of type being an attribute, so the type built-in is no longer accessible.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it work to define SeparatedModuleConfigType = type[SeparatedModuleConfig] at the module level to avoid the clash?

from fme.core.dataset_info import DatasetInfo
from fme.core.labels import BatchLabels, LabelEncoding

from .module import CONDITIONAL_BUILDERS, ModuleSelector
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Is it appropriate to re-use the CONDITIONAL_BUILDERS from .module? Or should we be defining these independently for SeparatedModule? Or should all new module types here support conditioning (we don't need backwards compatibility)? At a minimum labels can always be given as broadcasted input variables, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Answer: Good question. Right now we reuse it because the only way to get a conditional separated module is through LegacyModuleAdapter, which delegates to ModuleSelector, so the same builder types apply. When we register native separated modules that support conditioning, we should define an independent CONDITIONAL_BUILDERS list (or a different mechanism) for this registry. For now, reusing it is correct since LegacyModuleAdapter is the only registered type.

Comment on lines +394 to +395
if "secondary_decoder" in state:
self.secondary_decoder.load_module_state(state["secondary_decoder"])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Do we need this check? There aren't any existing checkpoints to maintain backwards compatibility for.

Suggestion: Remove this check if you can.

from fme.core.step.step import StepSelector

IMG_SHAPE = (16, 32)
TIMESTEP = __import__("datetime").timedelta(hours=6)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Properly import timedelta instead of this weird line.


return single_step, separated_step, in_names

def test_equivalence(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: test_equivalence_with_(something), "single_step"?

Comment on lines +335 to +336
single_sfno = single_step.module.torch_module.module
separated_sfno = separated_step.module.torch_module.module.inner
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: I don't like digging into these internal properties, is it possible to do this with public API like the .modules attribute of StepABC?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wasn't really avoidable.

mcgibbon and others added 4 commits March 11, 2026 19:53
…tics

Validate that prognostic_names_ is non-empty in SeparatedModuleStepConfig
post-init, since prognostic variables are required. Add test cases for
configurations with no forcings, no diagnostics, and both empty together.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Change StepConfigABC.prognostic_names and StepABC.prognostic_names from
@Final @Property to @Final getter methods, enabling subclasses to use
prognostic_names as a dataclass field name in a follow-up commit. Updates
all callers across StepperConfig, Stepper, CoupledStepper, and tests.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
…nfig

Now that the base class property was refactored to get_prognostic_names(),
the field can use the cleaner name without conflict.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
- Fix timedelta import in test (use datetime.timedelta properly)
- Remove unnecessary backwards-compat check in load_state
- Rename test_equivalence to test_equivalence_with_single_step

Co-Authored-By: Claude Opus 4.6 <[email protected]>
@mcgibbon mcgibbon changed the title Feature/separated module registry Add separated channel module registry and step implementation Mar 11, 2026
@mcgibbon
Copy link
Contributor Author

I ended up going the route of enabling prognostic/forcing/output tensors with a new StepABC implementation. Thinking about it, it's just too much of a fundamental change in the sense that it makes no sense to run an existing checkpoint with the new code - the input and output packers are fundamentally changed, and there's no guarantee the variable name ordering for existing checkpoints is valid under the new system.

@mcgibbon
Copy link
Contributor Author

mcgibbon commented Mar 11, 2026

There are a bunch of weird questions for the Legacy wrapper, like conditional being set in both the inside and outside module selectors. I don't think we actually need this support - I'll just register a new separated builder when we have the local model that uses this feature. So for now this PR will just have a class/config used for testing.

mcgibbon and others added 3 commits March 11, 2026 20:44
- Remove LegacyWrapper and LegacyModuleAdapter from separated module registry
- Remove conditional field from SeparatedModuleSelector; all separated
  modules are expected to accept an optional labels argument
- Replace legacy adapter usage in tests with SimpleSeparatedBuilder
- Remove equivalence test class (was specific to legacy wrapper)

Co-Authored-By: Claude Opus 4.6 <[email protected]>
- Remove unused _no_optimization / NullOptimization from SingleModuleStep,
  SeparateRadiationStep, FCN3Step, and SeparatedModuleStep
- Move SimpleSeparatedModule and its builder into separated_module.py with
  deferred registration via register_test_types(), eliminating test-to-test
  import dependency
- Fix SeparatedModuleConfig.build docstring to document labels parameter

Co-Authored-By: Claude Opus 4.6 <[email protected]>
@mcgibbon mcgibbon marked this pull request as ready for review March 12, 2026 14:34
self.module = dist.wrap_module(module)
self._img_shape = dataset_info.img_shape
self._config = config
self._no_optimization = NullOptimization()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was unused.

initial_condition = get_initial_condition(
config.initial_condition.get_dataset(),
stepper_config.prognostic_names,
stepper_config.get_prognostic_names(),
Copy link
Contributor Author

@mcgibbon mcgibbon Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This property method needed to be refactored to a getter so that prognostic_names could be a dataclass attribute instead, on the new Step.

Copy link
Member

@jpdunc23 jpdunc23 Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about making this refactor a separate PR? Fairly minor change, but does touch a lot of files.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to ignore, I don't think it's really necessary to separate off.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also went back and forth on this. Will keep it.

forward(
forcing: Tensor,
prognostic: Tensor,
labels: Tensor | None = None,
Copy link
Contributor Author

@mcgibbon mcgibbon Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went with requiring all modules of this new type take in a labels argument, since there's no need for backwards compatibility with ones that don't. For network types that don't have conditioning, we can broadcast these to the domain size and just treat them as additional input variables.

)


class SeparatedModule:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This handles conversion of BatchLabels to Tensor, and also provides a strongly typed interface (unlike nn.Module which is treated as Any by mypy). We do the same thing for Module.

Copy link
Member

@jpdunc23 jpdunc23 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking good. My main concern is the handling of combinations of missing labels and encoding.

When I heard you talking about this PR I had a somewhat different idea of what the changes would look like. I thought that you would implement a multi-module rather than single-module approach. The multi-module approach might be impractical for certain architectures, e.g. if you want a shared embedding layer for all inputs. But it would also be more readily able to take advantage of existing nn.Modules. The StepABC would just be responsible for piping the appropriate args to the appropriate nn.Module without any single nn.Module having to know about "business logic" (i.e., distinctions between forcing, prognostic, and diagnostic variables).

Anyways, the approach here is also reasonable.

@classmethod
def register(
cls, type_name: str
) -> Callable[[Type[SeparatedModuleConfig]], Type[SeparatedModuleConfig]]: # noqa: UP006
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it work to define SeparatedModuleConfigType = type[SeparatedModuleConfig] at the module level to avoid the clash?

initial_condition = get_initial_condition(
config.initial_condition.get_dataset(),
stepper_config.prognostic_names,
stepper_config.get_prognostic_names(),
Copy link
Member

@jpdunc23 jpdunc23 Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about making this refactor a separate PR? Fairly minor change, but does touch a lot of files.

Comment on lines +86 to +88
if labels is not None and self._label_encoding is not None:
encoded_labels = labels.conform_to_encoding(self._label_encoding)
return self._module(forcing, prognostic, labels=encoded_labels.tensor)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to fme.core.registry.module.Module, should this raise an error if labels is not None and self._label_encoding is None or vice versa?

Also, I think this branch is untested.

):
return _SimpleSeparatedModule(
n_forcing_channels, n_prognostic_channels, n_diagnostic_channels
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this can be moved to fme/core/registry/test_separated_module.py? I think so since we have existing tests where the registry is updated in the test file (e.g., see MockStepConfig in fme/core/step/test_step_registry.py). I think it would be cleaner to keep the testing modules isolated to the test files.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude wanted to avoid one test file importing from another test file, which I agreed with. The issue being this is used in more than one test file.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps a separate fme/core/registry/testing.py module?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that would make sense.

Comment on lines +62 to +65
ocean: OceanConfig | None = None
corrector: AtmosphereCorrectorConfig | CorrectorSelector = dataclasses.field(
default_factory=lambda: AtmosphereCorrectorConfig()
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we don't need to support backwards compatibility I wish we could avoid this bit of atmosphere specificity:

Suggested change
ocean: OceanConfig | None = None
corrector: AtmosphereCorrectorConfig | CorrectorSelector = dataclasses.field(
default_factory=lambda: AtmosphereCorrectorConfig()
)
corrector: CorrectorSelector | None = None

I'm working on some refactors now to make this change possible, though I don't necessarily think this PR should be held up by it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you’re able to get those changes in soon, it wouldn’t be unreasonable to break backwards compatibility on this in the very near future before we have many checkpoints.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the least I think we can require CorrectorSelector now? I'm less clear on how we can run without an OceanConfig.

@mcgibbon
Copy link
Contributor Author

mcgibbon commented Mar 12, 2026

without any single nn.Module having to know about "business logic" (i.e., distinctions between forcing, prognostic, and diagnostic variables).

The issue where this breaks down is that we want to use the architecture differently for these types of variables. For example, the way Nvidia does residual updates for prognostic variables in the Makani code, or the way forcing variables are used as context instead of normal inputs. These features are impossible to implement without the module knowing about this distinction. Well, maybe not impossible, the Step could do most of it, but these really do feel like architectural choices and not like physical choices.

I do think it’s good practice to continue using modules that take single tensors as much as possible when they don’t need to make this distinction, and concatenating to call them.

@mcgibbon
Copy link
Contributor Author

Mh as I'm working in tandem on the Local network, I'm finding an interesting question that is leading me to see the appeal of doing this stuff in Step... I'll need to think on it some more.

mcgibbon and others added 5 commits March 20, 2026 20:40
- Replace typing.Type with type alias to remove noqa comments
- Move test utilities to fme/core/registry/testing.py
- Add labels/encoding mismatch validation in SeparatedModule
- Use public .modules API in tests instead of internal properties
- Properly import timedelta in test file

Co-Authored-By: Claude Opus 4.6 <[email protected]>
The build_corrector method was removed from VerticalCoordinate on main.
Use self.corrector.get_corrector(dataset_info) instead, matching
SingleModuleStepConfig.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants