Add mHC support for HybridModel on dsv4 (stacks on #4483) #4529

Closed

Connor-XY wants to merge 10 commits into NVIDIA:dsv4 from Connor-XY:yxu1/mhc-hybridmodel-stacked-dsv4

Conversation

@Connor-XY

Summary

Adds initial mHC support for HybridModel / HybridStack. Carved out of PR #4469 so the hybrid wrapper, expansion / contraction, and hybrid-specific tests can be reviewed by @NVIDIA/hybrid-model separately from the transformer core mHC reference impl.

This PR's content (the only files it adds or modifies):

  • megatron/core/models/hybrid/hybrid_block.py:
    • HyperConnectionHybridLayer: layer-boundary wrapper. Aggregates the n-stream input to a single stream, runs the inner hybrid layer, and feeds only the function delta back through the H_res @ residual + H_post * delta BDA so the inner layer's own residual update is not double-counted (see the sketch after this list).
    • HybridStack: input_expand at first stage, output_contract at the final layernorm; plumbs mhc_recompute_manager per-layer; caches the deterministic block-end plan.
  • tests/unit_tests/models/test_hybrid_model.py: dummy HybridModel + Mamba / attention / MLP / GDN / DSA / DeepSeek-style proxy patterns.
  • tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py: mamba ↔ hybrid wrapping parity.
  • tests/unit_tests/ssm/test_hybrid_block.py: forward/backward + recompute coverage.
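
For orientation, a minimal sketch of the wrapper forward described above (HyperConnectionHybridLayer, inner_layer, and h_res_h_post_bda are named in this PR; the aggregate method, constructor, and exact signatures below are illustrative placeholders, not the PR's code):

```python
import torch

class HyperConnectionHybridLayer(torch.nn.Module):
    """Sketch of the layer-boundary mHC wrapper (illustrative only)."""

    def __init__(self, inner_layer, hyper_connection):
        super().__init__()
        self.inner_layer = inner_layer            # Mamba / attention / MLP / GDN / DSA layer
        self.hyper_connection = hyper_connection  # provides aggregation + H_res / H_post BDA

    def forward(self, n_stream_hidden, **layer_kwargs):
        # Collapse the n-stream input to a single stream for the inner layer.
        aggregated = self.hyper_connection.aggregate(n_stream_hidden)

        layer_output = self.inner_layer(aggregated, **layer_kwargs)

        # Shape-preservation assert: future inner-layer types that drop the
        # residual contract fail loudly instead of silently corrupting streams.
        assert layer_output.shape == aggregated.shape

        # Feed only the function delta back, so the inner layer's own residual
        # add is not counted twice.
        layer_delta = layer_output - aggregated

        # n-stream update: H_res @ residual + H_post * delta (bias/dropout/add fused).
        return self.hyper_connection.h_res_h_post_bda(
            residual=n_stream_hidden,
            delta=layer_delta,
            training=self.training,
        )
```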

Stacking note

This branch is built on top of #4483 (yxu1/mhc-transformer-core-code-dsv4) so that HyperConnectionModule, enable_hyper_connections, mhc_recompute_layer_num, etc. resolve. Cannot merge before #4483; "Files changed" shows #4483's content as ancestry — please review only the four files listed above.

Origin

Carved out of PR #4469 at commit e3d0102ad. Includes the hybrid-specific strict-review fixes from passes 3–11 of /claude strict-review:

  • shape-preservation assert on inner-layer output (catches future layer types that drop the residual contract).
  • explicit fp32_residual_connection-aware dtype handling on the layer_delta subtraction and on the BDA result (closes an fp32 propagation issue introduced by an earlier dtype-alignment fix; without it, the n-stream output silently flips to fp32 and propagates to every subsequent layer; see the sketch after this list).
  • training=self.training (not hard-coded False) on the BDA call.
  • non-transformer branch comment naming the args that are intentionally dropped (Mamba doesn't accept rotary_pos_emb / sequence_len_offset / padding_mask).
  • caches _mhc_block_end_plan on the instance instead of recomputing each forward.
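
The two fp32-related fixes amount to the dtype policy sketched below (fp32_residual_connection and params_dtype are standard Megatron config fields; the function boundary and argument names are assumptions for illustration):

```python
def apply_mhc_dtype_policy(aggregated, layer_output, n_stream_hidden,
                           hyper_connection, config, training):
    """Illustrative dtype handling around the delta subtraction and the BDA."""
    if config.fp32_residual_connection:
        # The inner layer's residual add leaves layer_output in fp32 while
        # `aggregated` may still be bf16/fp16, so align dtypes before subtracting.
        layer_delta = layer_output - aggregated.to(layer_output.dtype)
    else:
        layer_delta = layer_output - aggregated

    n_stream_out = hyper_connection.h_res_h_post_bda(
        residual=n_stream_hidden, delta=layer_delta, training=training
    )

    if config.fp32_residual_connection:
        # Downcast the BDA result; otherwise fp32 n-stream hidden states silently
        # propagate to every subsequent layer (~2x activation memory).
        n_stream_out = n_stream_out.to(config.params_dtype)
    return n_stream_out
```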

Validation

Ran python3 -m compileall on the touched files. The dummy HybridModel forward/backward test passed in the strict-review CI; the functional CI run is in a separate stacked PR.

Checkpoint compatibility note

HyperConnectionHybridLayer is a wrapper (the inner layer lives at self.inner_layer), so wrapped-layer state_dict keys are nested under inner_layer (e.g. layers.0.inner_layer.input_layernorm.weight). HybridStack checkpoints saved with enable_hyper_connections=False are therefore not directly loadable into a model with mHC enabled (and vice versa) without a key-mapping migration; a minimal example follows. This is documented in the class docstring, and a sharded_state_dict shim is left as future work for users toggling mHC on/off mid-training.
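
For users who do need to load an old checkpoint, the key-mapping migration could look roughly like this (a hypothetical helper, not shipped in this PR; it assumes `layers.<i>.` prefixes as in the example key above, and the mHC-specific parameters would still need fresh initialization):

```python
import re

def migrate_to_mhc_keys(state_dict):
    """Insert the `inner_layer.` segment after each `layers.<i>.` prefix.

    e.g. layers.0.input_layernorm.weight -> layers.0.inner_layer.input_layernorm.weight
    """
    migrated = {}
    for key, value in state_dict.items():
        new_key = re.sub(r"(^|\.)(layers\.\d+\.)", r"\1\2inner_layer.", key, count=1)
        migrated[new_key] = value
    return migrated
```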

🤖 Generated by Claude Opus 4.7 (1M context).

Connor-XY and others added 10 commits April 27, 2026 09:15
Adds initial mHC support for `HybridModel` / `HybridStack` via a
layer-boundary wrapper that treats each hybrid layer as a single
function. Stacks on the transformer mHC reference impl in NVIDIA#4483.

Implementation
- `HyperConnectionHybridLayer`: wraps an inner hybrid layer
  (Mamba / GDN / TransformerLayer / DSA / MoE / MLP), aggregates the
  n-stream input down to single-stream, runs the inner layer, then
  feeds the layer delta `f(aggregated)` back through the n-stream
  H_res / H_post BDA so the wrapped layer's own residual update is
  not double-counted.
- `HybridStack`: expands at the first-process boundary
  (`HyperConnectionModule.input_expand`) and contracts at the
  final layernorm (`output_contract`). Plumbs `mhc_recompute_manager`
  through the per-layer forward; caches the deterministic block-end
  plan via `_compute_mhc_block_end_plan` (see the sketch below).
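
A rough sketch of that stack-level flow (`input_expand`, `output_contract`, and
`mhc_recompute_manager` are from the description above; the `stack` attribute
names such as `pre_process` / `post_process` / `final_norm` are assumptions):

```python
def hybrid_stack_forward(stack, hidden_states, **layer_kwargs):
    """Illustrative HybridStack forward under mHC (not the PR's code)."""
    if stack.config.enable_hyper_connections and stack.pre_process:
        # Expand the single-stream input to n streams at the first pipeline stage.
        hidden_states = stack.hyper_connection.input_expand(hidden_states)

    for layer in stack.layers:
        # Per-layer forward; the recompute manager is plumbed through each layer.
        hidden_states = layer(
            hidden_states,
            mhc_recompute_manager=stack.mhc_recompute_manager,
            **layer_kwargs,
        )

    if stack.config.enable_hyper_connections and stack.post_process:
        # Contract back to a single stream before the final layernorm.
        hidden_states = stack.hyper_connection.output_contract(hidden_states)

    return stack.final_norm(hidden_states)
```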

Strict-review fixes carried over from the original PR
- shape-preservation assert (`layer_output.shape == aggregated.shape`) so
  future inner layer types that drop the residual contract fail loudly.
- `aggregated -> layer_output.dtype` upcast before the delta
  subtraction when `fp32_residual_connection=True`.
- `training=self.training` (not hard-coded `False`) on the BDA call.
- explicit downcast of the BDA result to `params_dtype` after
  `h_res_h_post_bda` when `fp32_residual_connection=True`, so fp32
  n-stream hidden states do not silently propagate to subsequent
  layers (~2x activation memory).
- non-transformer inner-layer branch comment names the
  rotary_pos_emb / sequence_len_offset / padding_mask args that are
  intentionally dropped.

Tests
- `test_hybrid_model.py`: dummy HybridModel + Mamba + attention + MLP
  + GDN + DSA + DeepSeek-style proxy patterns.
- `test_dsa_gpt_mamba_equivalence.py`: mamba <-> hybrid wrapping
  parity check.
- `test_hybrid_block.py`: forward/backward + recompute coverage.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
@copy-pr-bot

copy-pr-bot Bot commented Apr 29, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.
