Add separated channel module registry and step implementation#957
Add separated channel module registry and step implementation#957
Conversation
Add SeparatedModuleConfig, SeparatedModule, and SeparatedModuleSelector for modules that take separate forcing/prognostic input tensors and return separate prognostic/diagnostic output tensors. Include a LegacyModuleAdapter that wraps old-style single-tensor modules via LegacyWrapper for use in the new interface. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Add SeparatedModuleStepConfig/SeparatedModuleStep as a new StepABC implementation that works with the separated channel module interface. Uses explicit forcing_names, prognostic_names_, and diagnostic_names instead of deriving them from in/out name set operations. Remove state_dict/load_state_dict overrides from LegacyWrapper to fix an asymmetry where PyTorch's parent recursion calls child state_dict() overrides but not child load_state_dict() overrides, causing key mismatches when wrapped by DummyWrapper. Includes parametrized test coverage in test_step.py and dedicated tests with numerical equivalence regression tests vs SingleModuleStep. Co-Authored-By: Claude Opus 4.6 <[email protected]>
| @classmethod | ||
| def register( | ||
| cls, type_name: str | ||
| ) -> Callable[[Type[SeparatedModuleConfig]], Type[SeparatedModuleConfig]]: # noqa: UP006 |
There was a problem hiding this comment.
Question: Why is a noqa needed here?
There was a problem hiding this comment.
It's because of type being an attribute, so the type built-in is no longer accessible.
There was a problem hiding this comment.
Would it work to define SeparatedModuleConfigType = type[SeparatedModuleConfig] at the module level to avoid the clash?
| from fme.core.dataset_info import DatasetInfo | ||
| from fme.core.labels import BatchLabels, LabelEncoding | ||
|
|
||
| from .module import CONDITIONAL_BUILDERS, ModuleSelector |
There was a problem hiding this comment.
Question: Is it appropriate to re-use the CONDITIONAL_BUILDERS from .module? Or should we be defining these independently for SeparatedModule? Or should all new module types here support conditioning (we don't need backwards compatibility)? At a minimum labels can always be given as broadcasted input variables, right?
There was a problem hiding this comment.
Answer: Good question. Right now we reuse it because the only way to get a conditional separated module is through LegacyModuleAdapter, which delegates to ModuleSelector, so the same builder types apply. When we register native separated modules that support conditioning, we should define an independent CONDITIONAL_BUILDERS list (or a different mechanism) for this registry. For now, reusing it is correct since LegacyModuleAdapter is the only registered type.
fme/core/step/separated_module.py
Outdated
| if "secondary_decoder" in state: | ||
| self.secondary_decoder.load_module_state(state["secondary_decoder"]) |
There was a problem hiding this comment.
Question: Do we need this check? There aren't any existing checkpoints to maintain backwards compatibility for.
Suggestion: Remove this check if you can.
| from fme.core.step.step import StepSelector | ||
|
|
||
| IMG_SHAPE = (16, 32) | ||
| TIMESTEP = __import__("datetime").timedelta(hours=6) |
There was a problem hiding this comment.
Suggestion: Properly import timedelta instead of this weird line.
|
|
||
| return single_step, separated_step, in_names | ||
|
|
||
| def test_equivalence(self): |
There was a problem hiding this comment.
Suggestion: test_equivalence_with_(something), "single_step"?
| single_sfno = single_step.module.torch_module.module | ||
| separated_sfno = separated_step.module.torch_module.module.inner |
There was a problem hiding this comment.
Issue: I don't like digging into these internal properties, is it possible to do this with public API like the .modules attribute of StepABC?
There was a problem hiding this comment.
It wasn't really avoidable.
…tics Validate that prognostic_names_ is non-empty in SeparatedModuleStepConfig post-init, since prognostic variables are required. Add test cases for configurations with no forcings, no diagnostics, and both empty together. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Change StepConfigABC.prognostic_names and StepABC.prognostic_names from @Final @Property to @Final getter methods, enabling subclasses to use prognostic_names as a dataclass field name in a follow-up commit. Updates all callers across StepperConfig, Stepper, CoupledStepper, and tests. Co-Authored-By: Claude Opus 4.6 <[email protected]>
…nfig Now that the base class property was refactored to get_prognostic_names(), the field can use the cleaner name without conflict. Co-Authored-By: Claude Opus 4.6 <[email protected]>
- Fix timedelta import in test (use datetime.timedelta properly) - Remove unnecessary backwards-compat check in load_state - Rename test_equivalence to test_equivalence_with_single_step Co-Authored-By: Claude Opus 4.6 <[email protected]>
|
I ended up going the route of enabling prognostic/forcing/output tensors with a new StepABC implementation. Thinking about it, it's just too much of a fundamental change in the sense that it makes no sense to run an existing checkpoint with the new code - the input and output packers are fundamentally changed, and there's no guarantee the variable name ordering for existing checkpoints is valid under the new system. |
|
There are a bunch of weird questions for the Legacy wrapper, like conditional being set in both the inside and outside module selectors. I don't think we actually need this support - I'll just register a new separated builder when we have the local model that uses this feature. So for now this PR will just have a class/config used for testing. |
- Remove LegacyWrapper and LegacyModuleAdapter from separated module registry - Remove conditional field from SeparatedModuleSelector; all separated modules are expected to accept an optional labels argument - Replace legacy adapter usage in tests with SimpleSeparatedBuilder - Remove equivalence test class (was specific to legacy wrapper) Co-Authored-By: Claude Opus 4.6 <[email protected]>
…ace into feature/separated-module-registry
- Remove unused _no_optimization / NullOptimization from SingleModuleStep, SeparateRadiationStep, FCN3Step, and SeparatedModuleStep - Move SimpleSeparatedModule and its builder into separated_module.py with deferred registration via register_test_types(), eliminating test-to-test import dependency - Fix SeparatedModuleConfig.build docstring to document labels parameter Co-Authored-By: Claude Opus 4.6 <[email protected]>
| self.module = dist.wrap_module(module) | ||
| self._img_shape = dataset_info.img_shape | ||
| self._config = config | ||
| self._no_optimization = NullOptimization() |
| initial_condition = get_initial_condition( | ||
| config.initial_condition.get_dataset(), | ||
| stepper_config.prognostic_names, | ||
| stepper_config.get_prognostic_names(), |
There was a problem hiding this comment.
This property method needed to be refactored to a getter so that prognostic_names could be a dataclass attribute instead, on the new Step.
There was a problem hiding this comment.
What do you think about making this refactor a separate PR? Fairly minor change, but does touch a lot of files.
There was a problem hiding this comment.
Feel free to ignore, I don't think it's really necessary to separate off.
There was a problem hiding this comment.
I also went back and forth on this. Will keep it.
| forward( | ||
| forcing: Tensor, | ||
| prognostic: Tensor, | ||
| labels: Tensor | None = None, |
There was a problem hiding this comment.
I went with requiring all modules of this new type take in a labels argument, since there's no need for backwards compatibility with ones that don't. For network types that don't have conditioning, we can broadcast these to the domain size and just treat them as additional input variables.
| ) | ||
|
|
||
|
|
||
| class SeparatedModule: |
There was a problem hiding this comment.
This handles conversion of BatchLabels to Tensor, and also provides a strongly typed interface (unlike nn.Module which is treated as Any by mypy). We do the same thing for Module.
jpdunc23
left a comment
There was a problem hiding this comment.
This is looking good. My main concern is the handling of combinations of missing labels and encoding.
When I heard you talking about this PR I had a somewhat different idea of what the changes would look like. I thought that you would implement a multi-module rather than single-module approach. The multi-module approach might be impractical for certain architectures, e.g. if you want a shared embedding layer for all inputs. But it would also be more readily able to take advantage of existing nn.Modules. The StepABC would just be responsible for piping the appropriate args to the appropriate nn.Module without any single nn.Module having to know about "business logic" (i.e., distinctions between forcing, prognostic, and diagnostic variables).
Anyways, the approach here is also reasonable.
| @classmethod | ||
| def register( | ||
| cls, type_name: str | ||
| ) -> Callable[[Type[SeparatedModuleConfig]], Type[SeparatedModuleConfig]]: # noqa: UP006 |
There was a problem hiding this comment.
Would it work to define SeparatedModuleConfigType = type[SeparatedModuleConfig] at the module level to avoid the clash?
| initial_condition = get_initial_condition( | ||
| config.initial_condition.get_dataset(), | ||
| stepper_config.prognostic_names, | ||
| stepper_config.get_prognostic_names(), |
There was a problem hiding this comment.
What do you think about making this refactor a separate PR? Fairly minor change, but does touch a lot of files.
| if labels is not None and self._label_encoding is not None: | ||
| encoded_labels = labels.conform_to_encoding(self._label_encoding) | ||
| return self._module(forcing, prognostic, labels=encoded_labels.tensor) |
There was a problem hiding this comment.
Similar to fme.core.registry.module.Module, should this raise an error if labels is not None and self._label_encoding is None or vice versa?
Also, I think this branch is untested.
| ): | ||
| return _SimpleSeparatedModule( | ||
| n_forcing_channels, n_prognostic_channels, n_diagnostic_channels | ||
| ) |
There was a problem hiding this comment.
Could this can be moved to fme/core/registry/test_separated_module.py? I think so since we have existing tests where the registry is updated in the test file (e.g., see MockStepConfig in fme/core/step/test_step_registry.py). I think it would be cleaner to keep the testing modules isolated to the test files.
There was a problem hiding this comment.
Claude wanted to avoid one test file importing from another test file, which I agreed with. The issue being this is used in more than one test file.
There was a problem hiding this comment.
Perhaps a separate fme/core/registry/testing.py module?
There was a problem hiding this comment.
I think that would make sense.
| ocean: OceanConfig | None = None | ||
| corrector: AtmosphereCorrectorConfig | CorrectorSelector = dataclasses.field( | ||
| default_factory=lambda: AtmosphereCorrectorConfig() | ||
| ) |
There was a problem hiding this comment.
Since we don't need to support backwards compatibility I wish we could avoid this bit of atmosphere specificity:
| ocean: OceanConfig | None = None | |
| corrector: AtmosphereCorrectorConfig | CorrectorSelector = dataclasses.field( | |
| default_factory=lambda: AtmosphereCorrectorConfig() | |
| ) | |
| corrector: CorrectorSelector | None = None |
I'm working on some refactors now to make this change possible, though I don't necessarily think this PR should be held up by it.
There was a problem hiding this comment.
If you’re able to get those changes in soon, it wouldn’t be unreasonable to break backwards compatibility on this in the very near future before we have many checkpoints.
There was a problem hiding this comment.
At the least I think we can require CorrectorSelector now? I'm less clear on how we can run without an OceanConfig.
The issue where this breaks down is that we want to use the architecture differently for these types of variables. For example, the way Nvidia does residual updates for prognostic variables in the Makani code, or the way forcing variables are used as context instead of normal inputs. These features are impossible to implement without the module knowing about this distinction. Well, maybe not impossible, the Step could do most of it, but these really do feel like architectural choices and not like physical choices. I do think it’s good practice to continue using modules that take single tensors as much as possible when they don’t need to make this distinction, and concatenating to call them. |
Co-authored-by: James Duncan <[email protected]>
|
Mh as I'm working in tandem on the Local network, I'm finding an interesting question that is leading me to see the appeal of doing this stuff in Step... I'll need to think on it some more. |
- Replace typing.Type with type alias to remove noqa comments - Move test utilities to fme/core/registry/testing.py - Add labels/encoding mismatch validation in SeparatedModule - Use public .modules API in tests instead of internal properties - Properly import timedelta in test file Co-Authored-By: Claude Opus 4.6 <[email protected]>
Co-Authored-By: Claude Opus 4.6 <[email protected]>
The build_corrector method was removed from VerticalCoordinate on main. Use self.corrector.get_corrector(dataset_info) instead, matching SingleModuleStepConfig. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Adds a new module interface where modules explicitly take separate forcing and prognostic tensors as input and return separate prognostic and diagnostic tensors as output, instead of a single concatenated tensor. This makes channel semantics explicit at the module level.
Changes:
fme.core.registry.SeparatedModuleConfig,SeparatedModule,SeparatedModuleSelector: new module registry mirroringModuleSelectorbut with separated channel interfacefme.core.registry.LegacyModuleAdapter,LegacyWrapper: adapter wrapping old-style single-tensor modules for backwards compatibilityfme.core.step.SeparatedModuleStepConfig,SeparatedModuleStep: newStepABCimplementation using the separated channel interface, with explicitforcing_names,prognostic_names, anddiagnostic_namesfieldsStepConfigABC.prognostic_namesandStepABC.prognostic_names: refactored from@final @propertytoget_prognostic_names()getter method, enabling subclasses to useprognostic_namesas a dataclass field nameTests added