Add mHC support for HybridModel on dsv4 (stacks on #4483) #4529
Closed
Connor-XY wants to merge 10 commits into NVIDIA:dsv4 from
Conversation
Adds initial mHC support for `HybridModel` / `HybridStack` via a layer-boundary wrapper that treats each hybrid layer as a single function. Stacks on the transformer mHC reference impl in NVIDIA#4483.

Implementation

- `HyperConnectionHybridLayer`: wraps an inner hybrid layer (Mamba / GDN / TransformerLayer / DSA / MoE / MLP), aggregates the n-stream input down to a single stream, runs the inner layer, then feeds only the layer delta `f(aggregated) - aggregated` back through the n-stream H_res / H_post BDA so the wrapped layer's own residual update is not double-counted (see the sketch below).
- `HybridStack`: expands at the first-process boundary (`HyperConnectionModule.input_expand`) and contracts at the final layernorm (`output_contract`). Plumbs `mhc_recompute_manager` through the per-layer forward; caches the deterministic block-end plan via `_compute_mhc_block_end_plan`.

Strict-review fixes carried over from the original PR

- shape-preservation assert on `layer_output.shape == aggregated.shape` so future inner-layer types that drop the residual contract fail loudly.
- upcast of `aggregated` to `layer_output.dtype` before the delta subtraction when `fp32_residual_connection=True`.
- `training=self.training` (not hard-coded `False`) on the BDA call.
- explicit downcast of the BDA result to `params_dtype` after `h_res_h_post_bda` when `fp32_residual_connection=True`, so fp32 n-stream hidden states do not silently propagate to subsequent layers (~2x activation memory).
- the non-transformer inner-layer branch comment names the `rotary_pos_emb` / `sequence_len_offset` / `padding_mask` args that are intentionally dropped.

Tests

- `test_hybrid_model.py`: dummy HybridModel + Mamba + attention + MLP + GDN + DSA + DeepSeek-style proxy patterns.
- `test_dsa_gpt_mamba_equivalence.py`: mamba <-> hybrid wrapping parity check.
- `test_hybrid_block.py`: forward/backward + recompute coverage.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
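For reviewers, a minimal self-contained sketch of the wrapper's forward contract described above. It mirrors `HyperConnectionHybridLayer` but is only an illustration of the aggregate -> inner layer -> delta -> BDA flow; the constructor arguments, the stream-mixing parameters, and the tensor layout here are assumptions, not the actual module API.

```python
import torch
from torch import nn


class HyperConnectionHybridLayerSketch(nn.Module):
    """Illustrative wrapper: n-stream hidden states in, n-stream hidden states out."""

    def __init__(self, inner_layer: nn.Module, n_streams: int):
        super().__init__()
        self.inner_layer = inner_layer
        # Stand-ins for the learned H_res / H_post mixing weights.
        self.h_res = nn.Parameter(torch.eye(n_streams))
        self.h_post = nn.Parameter(torch.ones(n_streams))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [n_streams, seq, batch, hidden]
        residual = hidden_states
        # 1) Aggregate the n streams down to the single stream the inner layer expects.
        aggregated = hidden_states.mean(dim=0)
        # 2) Run the wrapped hybrid layer (Mamba / GDN / TransformerLayer / ...).
        layer_output = self.inner_layer(aggregated)
        assert layer_output.shape == aggregated.shape, "inner layer must preserve shape"
        # 3) Keep only the function delta so the inner layer's own residual add
        #    is not counted twice.
        delta = layer_output - aggregated
        # 4) n-stream BDA: H_res mixes the residual streams, H_post scales the
        #    broadcast delta into each stream.
        mixed_residual = torch.einsum("ij,jsbh->isbh", self.h_res, residual)
        return mixed_residual + self.h_post[:, None, None, None] * delta.unsqueeze(0)
```

The real module additionally threads the attention-specific kwargs and the `mhc_recompute_manager` through the inner call; those are omitted here.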
Summary
Adds initial mHC support for `HybridModel` / `HybridStack`. Carved out of PR #4469 so the hybrid wrapper, expansion / contraction, and hybrid-specific tests can be reviewed by @NVIDIA/hybrid-model separately from the transformer core mHC reference impl.

This PR's content (the only files this PR is asking to add or modify):
- `megatron/core/models/hybrid/hybrid_block.py`:
  - `HyperConnectionHybridLayer`: layer-boundary wrapper. Aggregates the n-stream input to a single stream, runs the inner hybrid layer, feeds only the function delta back through the `H_res @ residual + H_post * delta` BDA so the inner layer's own residual update is not double-counted.
  - `HybridStack`: `input_expand` at the first stage, `output_contract` at the final layernorm (boundary placement sketched below); plumbs `mhc_recompute_manager` per-layer; caches the deterministic block-end plan.
- `tests/unit_tests/models/test_hybrid_model.py`: dummy HybridModel + Mamba / attention / MLP / GDN / DSA / DeepSeek-style proxy patterns.
- `tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py`: mamba ↔ hybrid wrapping parity.
- `tests/unit_tests/ssm/test_hybrid_block.py`: forward/backward + recompute coverage.
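A minimal sketch of where the expansion and contraction boundaries sit in the stack forward, as referenced above. The replicate-and-mean expand / contract, the constructor, and the layer list are illustrative assumptions; only the boundary placement (expand at the first stage, contract at the final layernorm) and the per-layer `mhc_recompute_manager` plumbing follow the description.

```python
import torch
from torch import nn


class HybridStackSketch(nn.Module):
    """Illustrative stack: single-stream in, single-stream out, n streams inside."""

    def __init__(self, layers: list[nn.Module], hidden_size: int, n_streams: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(layers)      # already mHC-wrapped hybrid layers
        self.final_norm = nn.LayerNorm(hidden_size)
        self.n_streams = n_streams
        self._mhc_block_end_plan = None          # cached deterministic block-end plan

    def forward(self, hidden_states: torch.Tensor, mhc_recompute_manager=None) -> torch.Tensor:
        # First-stage boundary: expand the single stream to n streams
        # (stands in for HyperConnectionModule.input_expand).
        hidden_states = hidden_states.unsqueeze(0).expand(self.n_streams, *hidden_states.shape)
        for layer in self.layers:
            # The recompute manager would be threaded into each wrapped layer here.
            hidden_states = layer(hidden_states)
        # Final-layernorm boundary: contract n streams back to one
        # (stands in for output_contract), then apply the final norm.
        return self.final_norm(hidden_states.mean(dim=0))
```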
Stacking note

This branch is built on top of #4483 (`yxu1/mhc-transformer-core-code-dsv4`) so that `HyperConnectionModule`, `enable_hyper_connections`, `mhc_recompute_layer_num`, etc. resolve. Cannot merge before #4483; "Files changed" shows #4483's content as ancestry, so please review only the four files listed above.
Origin

Carved out of PR #4469 at commit e3d0102ad. Includes the hybrid-specific strict-review fixes from passes 3–11 of `/claude strict-review`:

- `fp32_residual_connection`-aware dtype handling on the `layer_delta` subtraction and on the BDA result (closes an fp32 propagation issue introduced by an earlier dtype-alignment fix; without this, the n-stream output silently flips to fp32 and propagates to every subsequent layer); see the sketch after this list.
- `training=self.training` (not hard-coded `False`) on the BDA call.
- a comment on the non-transformer inner-layer branch naming the intentionally dropped args (`rotary_pos_emb` / `sequence_len_offset` / `padding_mask`).
- caching of `_mhc_block_end_plan` on the instance instead of recomputing it each forward.
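A minimal sketch of the dtype handling in the first bullet. The helper name, its arguments, and the placeholder BDA call are illustrative; only the two rules (align `aggregated` with `layer_output.dtype` before the subtraction, downcast the BDA result back to `params_dtype`) come from the fix itself.

```python
import torch


def delta_with_fp32_residual(aggregated: torch.Tensor,
                             layer_output: torch.Tensor,
                             fp32_residual_connection: bool,
                             params_dtype: torch.dtype) -> torch.Tensor:
    """Illustrative dtype discipline around the layer_delta subtraction and BDA result."""
    if fp32_residual_connection and aggregated.dtype != layer_output.dtype:
        # Align aggregated with layer_output.dtype (the PR describes this as an
        # upcast) so the subtraction happens in a single dtype.
        aggregated = aggregated.to(layer_output.dtype)
    layer_delta = layer_output - aggregated
    # ... the n-stream H_res / H_post BDA runs here; placeholder below ...
    bda_result = layer_delta
    if fp32_residual_connection:
        # Downcast back to params_dtype so fp32 n-stream hidden states do not
        # propagate into every subsequent layer (~2x activation memory otherwise).
        bda_result = bda_result.to(params_dtype)
    return bda_result
```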
Validation

`python3 -m compileall` on the touched files. The dummy HybridModel forward/backward test passed in the strict-review CI; functional CI is in a separate stacked PR.
Checkpoint compatibility note

`HyperConnectionHybridLayer` is a wrapper (the inner layer is `self.inner_layer`), so wrapped-layer state_dict keys are nested under `inner_layer.` (e.g. `layers.0.inner_layer.input_layernorm.weight`). HybridStack checkpoints saved with `enable_hyper_connections=False` are not directly loadable into a model with mHC enabled (and vice versa) without a key-mapping migration (sketched below). Documented in the class docstring; a `sharded_state_dict` shim is left as future work for users toggling mHC on/off mid-training.

🤖 Generated by Claude Opus 4.7 (1M context).
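A minimal sketch of the key-mapping migration mentioned in the checkpoint note, assuming plain (non-sharded) state_dict keys and the `layers.<i>.` prefix shown above; the keys it skips and the exact prefix are assumptions, not the real checkpoint layout.

```python
def nest_hybrid_layer_keys(state_dict: dict) -> dict:
    """Remap a non-mHC HybridStack checkpoint so per-layer parameters sit under
    the wrapper's `inner_layer.` prefix, e.g.
    `layers.0.input_layernorm.weight` -> `layers.0.inner_layer.input_layernorm.weight`.
    """
    remapped = {}
    for key, value in state_dict.items():
        parts = key.split(".")
        # Only per-layer parameters move; anything outside `layers.<i>.` (final norm,
        # embeddings, the mHC module's own parameters) is left untouched here.
        if len(parts) > 2 and parts[0] == "layers" and parts[2] != "inner_layer":
            parts.insert(2, "inner_layer")
            key = ".".join(parts)
        remapped[key] = value
    return remapped
```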