
Add mHC transformer reference implementation #4483

Closed

Connor-XY wants to merge 10 commits into NVIDIA:dsv4 from Connor-XY:yxu1/mhc-transformer-core-code-dsv4

Conversation

@Connor-XY

What does this PR do?

This is PR 1 in a fine-grained split of #4469. It ports the core/reference Transformer support for Manifold Hyper-Connections (mHC) onto dsv4.

The original mHC PR targeting main is #3430. This split targets dsv4 per the DeepSeek/DSv4 workflow discussion, so reviewers can inspect the change in smaller pieces.

This PR includes the code path only:

  • Transformer config/argument plumbing for mHC options
  • HyperConnectionModule reference implementation
  • HyperConnectionTransformerLayer integration
  • TransformerBlock input expansion/output contraction
  • mHC selective recompute plumbing through CheckpointManager
  • GPT layer spec plumbing for enabling hyper connections

Follow-up split PRs are stacked after this one:

  1. mHC transformer unit tests
  2. fused mHC kernel support and tests
  3. pipeline-parallel mHC compatibility
  4. functional CI config and golden values
  5. HybridModel support on top of the transformer stack

Testing

  • git diff --check passed for the split stack.
  • python -m compileall passed for the touched Python files in the full split stack.
  • Unit/functional tests are separated into follow-up PRs to keep this PR focused on the core implementation surface.

@Connor-XY Connor-XY requested review from a team as code owners April 27, 2026 16:20
@copy-pr-bot

copy-pr-bot Bot commented Apr 27, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Author

Connor-XY commented Apr 27, 2026

Update after the latest push: this first split PR is now intended to be self-contained, not code-only.

It includes both the core/reference Transformer mHC implementation and focused unit-test coverage for that implementation:

  • mHC config / argument plumbing
  • HyperConnectionModule reference implementation
  • HyperConnectionTransformerLayer integration
  • TransformerBlock expand/contract handling
  • mHC selective recompute / CheckpointManager plumbing
  • GPT layer spec integration for enabling hyper connections
  • unit tests covering GPT layer specs, Hybrid MoE config compatibility, FP8 compatibility plumbing, transformer mHC recompute, mHC block-manager behavior, CUDA graph shape/capture behavior, and offloading behavior

Validation on hsg at head 00c333927:

  • git diff --check upstream/dsv4...yxu1/mhc-transformer-core-code-dsv4 passed
  • python -m compileall passed over the touched transformer and unit-test areas
  • Focused Slurm/container mHC unit subset passed: 51 passed, 8 skipped in 228.07s

I also tried a broader file-level pytest invocation first; it stopped on an existing non-mHC test_sharded_state_dict[tp_pp0-tp-pp-dp] case because the interactive hsg allocation has 4 GPUs and that parametrization requires world_size=8 (tp=4, pp=2). The rerun above targets the mHC-specific tests and compatibility files.

Follow-up split PRs should now start with fused mHC kernel support, pipeline-parallel compatibility, functional CI/golden coverage, and HybridModel support. The earlier PR description text saying "code path only" / "unit tests are separated" is stale.

@Connor-XY
Author

/ok to test 00c3339

@Connor-XY
Author

/claude strict-review

Comment on lines +1552 to +1556
if self.enable_hyper_connections and self.mtp_num_layers is not None:
    raise ValueError(
        "enable_hyper_connections is not compatible with Multi-Token Prediction (MTP). "
        "Please disable MTP (set mtp_num_layers=None) when using hyper connections."
    )
Contributor

[CRITICAL Correctness] Missing validation: enable_hyper_connections with pipeline_model_parallel_size > 1 will silently produce wrong results or crash.

The code in transformer_block.py has PP-aware branching (self.pre_process for expand, has_final_layernorm_in_this_stage() for contract), implying n-stream hidden states flow between PP stages. However, get_tensor_shapes() in schedules.py computes the P2P tensor shape as (seq_len, micro_batch_size, config.hidden_size) — it does not account for num_residual_streams. This means P2P communication will allocate receive buffers of the wrong size, causing either a crash or silent data corruption.

The PR description notes "pipeline-parallel mHC compatibility" as a follow-up, but users can still configure enable_hyper_connections=True with PP > 1 today. Please add a guard here:

if self.enable_hyper_connections and self.pipeline_model_parallel_size > 1:
    raise ValueError(
        "enable_hyper_connections is not yet compatible with pipeline parallelism "
        "(pipeline_model_parallel_size > 1). Pipeline-parallel support is planned "
        "for a follow-up PR."
    )

Initialize CheckpointWithoutOutput.

Args:
    fp8: Whether to use FP8 mode. Defaults to False.
Contributor

[IMPORTANT Compatibility] Behavior change in fp8 flag: fp8 is not None → bool(fp8).

The old code self.fp8 = fp8 is not None evaluated to True for the default fp8=False, which was a latent bug — it would attempt to set up FP8 recompute context even when FP8 wasn't in use. The new bool(fp8) correctly evaluates to False for the default case.

This fix is correct, but it changes runtime behavior for all existing callers that use the default CheckpointWithoutOutput() (no fp8 arg). Previously they would enter the FP8 context path during recompute; now they won't. If TE's get_fp8_context_for_recompute() happened to be a no-op when FP8 wasn't globally enabled, this change is safe. But if any caller relied on the old behavior, this could subtly break.

Please verify that no existing callers are affected and consider adding a brief comment or changelog note that this is an intentional bug fix.
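For illustration, the truthiness change in isolation (a standalone snippet, not the PR's code):

fp8 = False                  # the default CheckpointWithoutOutput() argument
old_flag = fp8 is not None   # True  -> old code armed the FP8 recompute path
new_flag = bool(fp8)         # False -> new code skips the FP8 context, as intended
assert old_flag is True and new_flag is False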

ctx.outputs = None
ctx.inputs = None
-grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs)
+grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
Contributor

[IMPORTANT Compatibility] Changed non-tensor gradient return from inp to None.

The old code returned the non-tensor value itself as a "gradient" (else inp), which is technically incorrect — autograd expects None for non-differentiable inputs. The new code (else None) is correct.

However, this is a behavioral change to CheckpointWithoutOutputFunction.backward that affects ALL callers of CheckpointWithoutOutput, not just the new mHC code path. The old code happened to work because non-tensor args were never passed to save_for_backward (which only accepts tensors), so this path was likely never exercised. The new _save_args_to_ctx now properly handles non-tensor args, making this fix necessary and correct.

Consider splitting this fix into a separate commit to isolate the backward-compatibility change from the mHC feature addition.

Comment on lines +798 to +803
# Expand hidden states for hyper connections at the start of the block
# Only expand at the first PP stage; subsequent stages receive n-stream from previous stage
if self.config.enable_hyper_connections and self.pre_process:
    hidden_states = HyperConnectionModule.input_expand(
        hidden_states, self.num_residual_streams
    )  # [s, b, C] -> [s, b, n*C]
Contributor

[IMPORTANT Correctness] The self.pre_process guard assumes PP-aware n-stream passing, but PP is not yet supported.

This check implies that non-first PP stages receive n-stream tensors from the previous stage and should not expand. But since PP support is not implemented (no shape adaptation in get_tensor_shapes() / p2p_communication), this guard is premature and misleading.

Until PP support is added, this should either:

  1. Use a simpler unconditional expand (since PP>1 should be blocked by config validation), or
  2. At minimum, add an assertion to catch the case where not self.pre_process with HC enabled, since it would be an invalid state.

If you add the PP>1 validation guard I suggested in transformer_config.py, then self.pre_process will always be True when HC is enabled, making this condition redundant but harmless.
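For illustration, a minimal sketch of option (2), assuming the block attributes quoted above (self.config, self.pre_process, self.num_residual_streams):

# Sketch only: fail loudly on a non-first stage instead of silently skipping the expand.
if self.config.enable_hyper_connections:
    assert self.pre_process, (
        "enable_hyper_connections requires expansion on the first stage; "
        "n-stream hand-off between PP stages is not implemented yet."
    )
    hidden_states = HyperConnectionModule.input_expand(
        hidden_states, self.num_residual_streams
    )  # [s, b, C] -> [s, b, n*C]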

Comment on lines +924 to +928
# Only contract if the final layer norm is in this stage
if self.config.enable_hyper_connections and self.has_final_layernorm_in_this_stage():
    hidden_states = HyperConnectionModule.output_contract(
        hidden_states, self.num_residual_streams
    )  # [s, b, n*C] -> [s, b, C]
Contributor

[IMPORTANT Correctness] has_final_layernorm_in_this_stage() may not be the right guard for output contraction.

has_final_layernorm_in_this_stage() checks whether this stage has the final layernorm (based on post_process, post_layer_norm, and MTP config). But output_contract converts [s, b, n*C] → [s, b, C], which must happen exactly once at the model boundary — specifically right before the final layernorm, which expects [s, b, C].

For configurations without a final layernorm in this stage (e.g., post_layer_norm=False, or MTP configurations where the final layernorm is in a different stage), this contraction would be skipped, and the n-stream tensor would be passed to subsequent processing (output layer, loss computation) with the wrong hidden dimension.

This is mitigated if PP>1 is blocked by validation (as I suggested), but even with PP=1, post_layer_norm=False could skip this contraction. Please verify that the guard covers all valid configurations.
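One possible stricter boundary check, sketched against the attributes quoted above (self.post_process is an assumption borrowed from other Megatron blocks):

# Sketch only: contract exactly once at the model boundary, and fail loudly
# if an n-stream tensor would leave the block uncontracted.
if self.config.enable_hyper_connections:
    if self.has_final_layernorm_in_this_stage():
        hidden_states = HyperConnectionModule.output_contract(
            hidden_states, self.num_residual_streams
        )  # [s, b, n*C] -> [s, b, C]
    elif self.post_process:
        raise RuntimeError(
            "n-stream hidden states reached post-processing without contraction; "
            "check post_layer_norm / MTP configuration."
        )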

Comment on lines +783 to +797
def discard_all_outputs_and_register_unified_recompute(self, hook_tensor):
    """Discard all checkpoint outputs to save memory and register unified recompute hook."""
    for ckpt in self.checkpoints:
        for output in ckpt.outputs:
            output.untyped_storage().resize_(0)

    # Register unified recompute hook
    if hook_tensor.requires_grad:
        hook_tensor.register_hook(self._unified_recompute_hook)

def _unified_recompute_hook(self, grad_output):
    for ckpt in self.checkpoints:
        # Call _recompute for each checkpoint in forward order
        # The _recompute method will restore the output tensor storage
        ckpt._recompute(None)
Contributor

[IMPORTANT Correctness] discard_all_outputs_and_register_unified_recompute relies on forward-order recomputation — the ordering assumption should be documented.

When chained checkpoints are used (checkpoint 2's input = checkpoint 1's output), discarding ALL outputs and then recomputing in forward order is correct because:

  1. Checkpoint 1's _recompute() restores its output tensor storage in-place via untyped_storage().resize_() + copy_()
  2. Checkpoint 2's saved input (which shares storage with checkpoint 1's output) is thereby restored before checkpoint 2's _recompute() needs it

This is subtle and depends on:

  • _recompute() restoring storage in-place on the original tensor objects
  • save_for_backward() preserving storage identity (not copying)
  • Iterating self.checkpoints in forward (append) order

Consider adding a comment explaining this invariant, since breaking any of these assumptions (e.g., recomputing in reverse order, or PyTorch changing save_for_backward internals) would cause silent data corruption.
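The comment could look roughly like this (a sketch over the hook quoted above, not the PR's actual text):

def _unified_recompute_hook(self, grad_output):
    # Invariant (do not break): self.checkpoints is iterated in forward
    # (append) order, each _recompute() restores its output storage in place,
    # and save_for_backward() preserved storage identity. A later checkpoint's
    # saved input aliases an earlier checkpoint's restored output, so
    # reordering or copying here would silently corrupt recomputed values.
    for ckpt in self.checkpoints:
        ckpt._recompute(None)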

Comment on lines +74 to +76
residual_batched = original_residual.view(s * b, n, C)
mixed = torch.bmm(h_res_batched, residual_batched).view(s, b, n, C)
x_expanded = h_post.unsqueeze(-1) * x.unsqueeze(2)
Contributor

[SUGGESTION Naming] native_h_aggregate name is ambiguous — "native" could mean "PyTorch native" or "CPU-native".

Throughout this file, native_ prefix is used to distinguish reference PyTorch implementations from fused cuTile kernels. Consider reference_ or unfused_ instead, which is clearer about the distinction:

Suggested change
-def native_h_aggregate(x: Tensor, h_pre: Tensor) -> Tensor:
+def reference_h_aggregate(x: Tensor, h_pre: Tensor) -> Tensor:

Or alternatively, adopt a naming convention consistent with the rest of the codebase (e.g., the existing bias_dropout_add_unfused pattern).

Comment on lines +124 to +126
self.hidden_size = config.hidden_size
self.sinkhorn_iterations = config.mhc_sinkhorn_iterations

Contributor

[SUGGESTION Simplification] TODO: keep hyper connection in fp32 computation — this TODO flags a real concern.

The mapping projection, Sinkhorn iterations, and gating computations all run in the model's default dtype (typically bf16/fp16). Since these involve:

  • Small scalar parameters (alpha_pre/post/res, bias) that can lose precision
  • Iterative normalization (Sinkhorn) with .clamp(min=eps) that is sensitive to underflow
  • RMS normalization with rsqrt-equivalent computation

Running these in reduced precision could cause training instability, especially at scale. Consider at minimum upcasting the Sinkhorn iterations to fp32 (the .clamp(min=eps) with eps=1e-6 is dangerously close to fp16's smallest subnormal of ~6e-8).

If this is intentionally deferred, please convert this to a more explicit TODO with a tracking reference.
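For illustration, a minimal fp32-upcast sketch of the Sinkhorn loop (generic Sinkhorn-Knopp with illustrative names, not the PR's exact function):

import torch

def sinkhorn_fp32(logits: torch.Tensor, n_iters: int = 20, eps: float = 1e-6) -> torch.Tensor:
    """Run the row/column normalization in float32 so clamp(min=1e-6) is not
    swallowed by reduced-precision underflow, then cast back."""
    orig_dtype = logits.dtype
    m = torch.exp(logits.float())  # upcast once, before iterating
    for _ in range(n_iters):
        m = m / m.sum(dim=-1, keepdim=True).clamp(min=eps)  # normalize rows
        m = m / m.sum(dim=-2, keepdim=True).clamp(min=eps)  # normalize columns
    return m.to(orig_dtype)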

use_kitchen_attention: bool = False,
kitchen_attention_backend: str = "sdpa",
mla_down_proj_fusion: bool = False,
enable_hyper_connection: bool = False,
Contributor

[SUGGESTION Naming] Inconsistent naming: enable_hyper_connection (singular) here vs enable_hyper_connections (plural) in TransformerConfig.

The config field is config.enable_hyper_connections (plural), but the layer spec parameter is enable_hyper_connection (singular). This is confusing and error-prone — a typo in either direction would silently use the default False.

Consider aligning to one name. Since the config uses plural (enable_hyper_connections), the spec functions should match:

Suggested change
-enable_hyper_connection: bool = False,
+enable_hyper_connections: bool = False,

Or rename to enable_hyper_connections throughout the spec functions to match the config.

Comment on lines +711 to +716

Returns:
    block_size: Optimal block size for checkpointing
"""
block_size = int(math.sqrt(num_streams * num_layers / (num_streams + 2)))
return max(1, block_size)
Contributor

[SUGGESTION Simplification] HyperConnectionCheckpoint.compute_optimal_block_size is defined but never called in this PR.

This static method implements Eq. (20) from the paper but has no callers. If it's intended for follow-up PRs, consider deferring it to the PR that actually uses it. If it's meant to be called from _build_mhc_recompute_layer_plan as a default for mhc_recompute_layer_num, please wire it up or add a TODO referencing the follow-up.
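If it is wired up, the call site could be as simple as this sketch (hypothetical wiring; the parameter names are inferred from the quoted body):

# Hypothetical default for mhc_recompute_layer_num from the paper's Eq. (20).
block_size = config.mhc_recompute_layer_num or HyperConnectionCheckpoint.compute_optimal_block_size(
    num_streams=config.num_residual_streams,
    num_layers=config.num_layers,
)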

Comment on lines +1301 to +1305
def __call__(self, *args, **kwargs):
    # Extract mhc_recompute_manager before CUDA graph manager processes kwargs,
    # since CheckpointManager is not a CUDA-graph-supported type.
    self._mhc_recompute_manager = kwargs.pop("mhc_recompute_manager", None)
    kwargs.pop("is_last_layer_in_recompute_block", None)
Contributor

[IMPORTANT Compatibility] __call__ override on the base TransformerLayer changes the call interface for ALL layer types, not just mHC.

This __call__ override is on TransformerLayer (the base class), meaning every layer type (including MoETransformerLayer and any downstream custom layers) now has mhc_recompute_manager and is_last_layer_in_recompute_block silently popped from kwargs. While the pop with default None is safe, storing self._mhc_recompute_manager on every layer instance is unnecessary overhead for non-HC layers.

Consider either:

  1. Moving this __call__ override to HyperConnectionTransformerLayer only, or
  2. At minimum, guarding the attribute assignment: if self.config.enable_hyper_connections:

Additionally, kwargs.pop("is_last_layer_in_recompute_block", None) pops a kwarg that is never passed from transformer_block.py (the block sets it as an attribute on the manager instead). This is dead code.
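Option (1) could look roughly like this (a sketch, assuming the class names used in this PR):

class HyperConnectionTransformerLayer(TransformerLayer):
    def __call__(self, *args, **kwargs):
        # Pop the manager only on HC layers so the base TransformerLayer (and
        # MoE or downstream subclasses) keep their call interface untouched.
        self._mhc_recompute_manager = kwargs.pop("mhc_recompute_manager", None)
        return super().__call__(*args, **kwargs)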

Comment on lines +707 to +711
# Injected by __call__ for cuda graph keying; not a real forward arg.
kwargs.pop("dynamic_inference_decode_only", None)
assert (
    not self.config.enable_hyper_connections
), "Please use HyperConnectionTransformerLayer instead"
Contributor

[IMPORTANT Compatibility] Hard assertion in base TransformerLayer.forward blocks HC-enabled configs from using the base class.

This assertion fires for any code path that calls TransformerLayer.forward() when enable_hyper_connections=True — even if hyper connections are not actually needed (e.g., MTP layers that strip HC, or custom model code).

The MTP path in gpt_layer_specs.py correctly downgrades the layer module to TransformerLayer, but the config still has enable_hyper_connections=True. This assertion would fire in that case.

Consider either:

  1. Checking isinstance(self, HyperConnectionTransformerLayer) instead of the config flag, or
  2. Removing the assertion and letting the spec/module selection handle correctness.

Comment on lines +1499 to +1504
raise ValueError(
    "'mhc' and 'mlp' in recompute_modules cannot be used together. "
    "They use different checkpoint mechanisms that may conflict."
)
if self.mhc_recompute_layer_num is not None and (
    isinstance(self.mhc_recompute_layer_num, bool)
Contributor

[SUGGESTION Simplification] The isinstance(self.mhc_recompute_layer_num, bool) check before the int check is needed because bool is a subclass of int in Python. This is correct but non-obvious — a reader might wonder why bool is checked separately.

Consider simplifying with a type annotation or a clearer guard:

if self.mhc_recompute_layer_num is not None:
    if not isinstance(self.mhc_recompute_layer_num, int) or isinstance(
        self.mhc_recompute_layer_num, bool
    ) or self.mhc_recompute_layer_num < 1:

Or adding a one-line comment: # bool is a subclass of int, reject it explicitly.

Contributor

@claude Bot left a comment

Review Summary

PR: Add mHC transformer reference implementation

Reviewed the full diff (3489 additions, 79 deletions across 18 files). This PR ports the core Manifold Hyper-Connections (mHC) implementation onto dsv4, adding:

  • HyperConnectionModule with Sinkhorn-projected dynamic mappings
  • HyperConnectionTransformerLayer extending TransformerLayer with n-stream pre/post processing
  • CheckpointManager for cross-layer mHC activation recompute
  • Mixed tensor/non-tensor _save_args_to_ctx / _load_args_from_ctx helpers
  • Comprehensive config validation and layer spec plumbing
  • Strong test coverage (unit tests for checkpointing, block manager, layer integration, memory savings, CUDA graphs, and offloading)

Findings

Severity Count
CRITICAL 0
IMPORTANT 3
SUGGESTION 4

Most Impactful Findings

  1. [IMPORTANT Implementation] reference_proj_inv_rms computes L2 norm in input dtype without fp32 upcast. For fp16 with n*C = 16384, intermediate sum(x²) overflows. This is the same class of bug that motivated fp32 upcast in RMSNorm / LayerNorm throughout Megatron. (comment)

  2. [IMPORTANT Compatibility] CheckpointWithoutOutput.__init__ changes self.fp8 = fp8 is not None → self.fp8 = bool(fp8). This is a correct bug fix (the old code entered FP8 recompute context for every default CheckpointWithoutOutput() caller), but it silently changes behavior for all existing layernorm/MLA checkpoint sites. Should be called out separately. (comment)

  3. [IMPORTANT Correctness] HyperConnectionTransformerLayer.__call__ stores the mhc_recompute_manager on self to bypass CUDA graph argument filtering. Works today due to synchronous forward execution, but the pattern is fragile. (comment)

Overall Assessment

Risk level: Medium. The mHC implementation is well-structured with thorough validation guards (PP blocked, MTP blocked, fused TP blocked, MoE CUDA graphs blocked) and good test coverage. The core mHC math (Sinkhorn, aggregate, expand/contract) looks correct, and the CheckpointManager design for cross-layer recompute is sound.

The primary risk is the fp16 overflow in reference_proj_inv_rms (#1), which could cause silent NaN/Inf during mixed-precision training. The fp8 flag fix (#2) is correct but should be validated independently with FP8 training to confirm no regressions. The remaining suggestions are quality improvements, not blockers.

The test suite covers the important cases: checkpoint correctness, memory savings, CUDA graph capture+replay, activation offloading, and config validation. Nice work on the defensive programming — the config __post_init__ guards are comprehensive and will prevent users from hitting unsupported combinations.

Author

Follow-up after Claude's latest strict review: pushed 6562e52.

Addressed the latest important/suggestion items:

  • Upcasted the reference mHC projection/RMS path to fp32 for the norm/projection math, and added a fp16 overflow regression test for reference_proj_inv_rms.
  • Kept _compute_h's multiply/add/sigmoid path in fp32 before casting mapping outputs back to the projection dtype.
  • Replaced the CUDA-graph mhc_recompute_manager module-instance handoff with a context-local handoff, with the synchronous capture invariant documented.
  • Renamed the block-level recompute manager to MHCRecomputeManager while leaving a compatibility alias for the previous generic name.
  • Added/updated comments for the reference implementation's non-TP mapping_proj, input_expand materialization cost, explicit bool rejection in config validation, and the intentional CheckpointWithoutOutput(fp8=False) behavior.

Validation:

  • local git diff --check
  • local python3 -m compileall -q over the touched implementation/tests
  • hsg Slurm/container focused pytest: 5 passed, 23 warnings in 29.35s
    • test_checkpoint_without_output_fp8_flag_is_explicit
    • test_reference_proj_inv_rms_upcasts_norm_for_fp16
    • test_mcore_cudagraph_manager_with_mhc_recompute_manager
    • test_cuda_graph_fwd_bwd_with_hyper_connection_and_recompute
    • test_forward_with_hyper_connection_recompute

Author

/claude strict-review

Comment on lines +836 to +839
# Intentional bug fix: the default fp8=False must not enter the TE FP8
# recompute context. The old "fp8 is not None" behavior did that for
# every default CheckpointWithoutOutput() caller.
self.fp8 = bool(fp8)
Contributor

[IMPORTANT Compatibility] This fixes a real bug (False is not None → True), but it silently changes behavior for all existing default callers of CheckpointWithoutOutput() — including layernorm recompute (transformer_layer.py:590, 728), MoE activation recompute (experts.py:380), and fine-grained callables (fine_grained_callables.py:483).

Previously these callers entered TE's activation_recompute_forward context during forward and fp8_autocast during recompute, even when FP8 was not enabled. After this fix they correctly use nullcontext().

The fix is correct, and the new test (test_checkpoint_without_output_fp8_flag_is_explicit) covers it. However, because activation_recompute_forward may set TE-internal state even when FP8 is globally disabled, I'd recommend:

  1. Verifying that existing selective-recompute tests (layernorm, mla_up_proj, moe_act) still pass unchanged — confirming the old TE context was truly a no-op.
  2. Mentioning this fix in the PR description since it affects non-mHC code paths.

Comment on lines +1294 to +1301
def backward_dw_cudagraph(self, microbatch_idx):
    """
    CUDA Graph backward weight gradient computation for this layer.
    """
    cg_index = microbatch_idx % len(self.cuda_graphs)
    if not hasattr(self.cuda_graphs[cg_index], 'backward_dw'):
        return
    self.cuda_graphs[cg_index].backward_dw()
Contributor

[IMPORTANT Compatibility] backward_dw_cudagraph is defined but never called anywhere in the codebase — grep -rn "backward_dw_cudagraph" megatron/core/ returns only this definition. This is dead code that should either be connected to a caller (e.g., the pipeline schedule for deferred weight gradients) or removed.

If it's intended for a follow-up PR, please add a # TODO: comment referencing that work.

Comment thread on megatron/core/tensor_parallel/random.py (Outdated)


# Backward-compatible alias for earlier internal users of the generic name.
CheckpointManager = MHCRecomputeManager
Contributor

[IMPORTANT Compatibility] CheckpointManager alias is defined but never imported or used anywhere in the codebase. The comment says "Backward-compatible alias for earlier internal users" but no such users exist. This is dead code.

Suggested change
-CheckpointManager = MHCRecomputeManager

Comment on lines +815 to +818
if self.config.enable_hyper_connections:
    hidden_states = HyperConnectionModule.input_expand(
        hidden_states, self.num_residual_streams
    )  # [s, b, C] -> [s, b, n*C]
Contributor

[IMPORTANT Correctness] When recompute_granularity == 'full', _checkpointed_forward does not pass mhc_recompute_manager to layers — this is correct (full recompute already wraps the forward). However, there is no config validation blocking the combination of enable_hyper_connections=True with recompute_granularity='full'.

_checkpointed_forward also omits rotary_pos_cos, rotary_pos_sin, rotary_pos_cos_sin, and sequence_len_offset kwargs (pre-existing issue), which could silently produce wrong results for positional encoding features combined with full recompute.

Consider adding a validation or warning in TransformerConfig.__post_init__ for this combination, or a TODO noting the limitation.
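A minimal sketch of such a guard, in the style of the __post_init__ validations quoted elsewhere in this thread:

# Sketch only: block the unvalidated combination up front.
if self.enable_hyper_connections and self.recompute_granularity == 'full':
    raise ValueError(
        "enable_hyper_connections is not supported with "
        "recompute_granularity='full'; use the mHC selective recompute path "
        "('mhc' in recompute_modules) instead."
    )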

Comment on lines 747 to 753
torch.autograd.backward(outputs, args)
ctx.outputs = None
ctx.inputs = None
-grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs)
+# Autograd expects None for non-tensor inputs; returning the original
+# non-tensor value as a "gradient" is invalid once args are restored.
+grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
return (None, None) + grads
Contributor

[IMPORTANT Correctness] Good fix — the old code inp.grad if isinstance(inp, torch.Tensor) else inp returned the original non-tensor value as a "gradient", which is invalid for autograd (it expects None for non-differentiable inputs). The new code correctly returns None.

However, this is a behavioral change for existing callers of CheckpointWithoutOutputFunction. If any caller was (incorrectly) relying on getting the non-tensor value back through the gradient tuple, this would break. Given that autograd should never consume non-tensor "gradients", this change is safe, but worth noting in the PR description alongside the fp8 fix since both modify the shared CheckpointWithoutOutput infrastructure.

Comment on lines +524 to +535
3. output = dropout(expanded + bias) + mixed (bias-dropout-add)

Args:
    h_res: [s, b, n, n] - residual mixing matrix
    original_residual: [s, b, n*C] - n-stream hidden states (before H_res applied)
    h_post: [s, b, n] - expansion weights
    layer_output_with_bias: Tuple of (x, bias) where:
        - x: [s, b, C] - layer output (attention or MLP output)
        - bias: [C] or None - optional bias tensor
    dropout_prob: Dropout probability
    training: Whether in training mode
    fused: Whether to use fused BDA implementation
Contributor

[SUGGESTION Naming] fused_h_res_h_post_bda is misleading — this is a sequential implementation, not a fused kernel. The docstring itself says "This is a placeholder for future kernel fusion optimization."

Consider renaming to h_res_h_post_bda (dropping the "fused" prefix) and reserving "fused" for the actual kernel variant when it lands. The current name may confuse readers inspecting profiles or code paths into thinking this is already an optimized fused operation.

Comment on lines +158 to +163
from megatron.core.fusions.fused_mhc_kernels import (
    fused_h_aggregate,
    fused_h_post_bda,
    fused_proj_rms,
    fused_sinkhorn,
)
Contributor

[SUGGESTION Simplification] The sequence_parallel attribute is set via setattr(self.mapping_proj.weight, 'sequence_parallel', True) — this can be simplified to direct attribute assignment:

Suggested change
self.mapping_proj.weight.sequence_parallel = True
self.alpha_pre.sequence_parallel = True
self.alpha_post.sequence_parallel = True
self.alpha_res.sequence_parallel = True
self.bias.sequence_parallel = True

Comment on lines 240 to 244
pre_cross_attn_layernorm: LayerNormBuilder = IdentityOp
# Reserved for future cross-attention hyper-connection support.
cross_attention_hyper_connection: Union[ModuleSpec, type] = IdentityOp
cross_attention: Union[ModuleSpec, type] = IdentityOp
cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
Contributor

[SUGGESTION Simplification] cross_attention_hyper_connection is declared here but has no runtime use path — it's only validated to be IdentityOp in HyperConnectionTransformerLayer.__init__ and stripped to IdentityOp in get_gpt_mtp_block_spec_for_backend. It's never built or called.

If this is reserved for a future cross-attention HC implementation, consider adding a brief TODO reference to the planned follow-up. Otherwise it could confuse readers into thinking it's wired up.

Comment on lines +878 to 919
####################
# Hyper-Connection Configuration
####################
enable_hyper_connections: bool = False
"""Enable mHC residual connections."""

num_residual_streams: int = 4
"""Number of residual streams (n in paper)."""

mhc_sinkhorn_iterations: int = 20
"""Number of Sinkhorn-Knopp iterations for doubly stochastic projection."""

mhc_init_gating_factor: float = 0.01
"""Initial value of Gating Factor (alpha in paper)."""

use_fused_mhc: bool = False
"""Use cuTile fused kernels for mHC operations.

When True, attempts to replace the reference mHC modules (SinkhornKnopp,
H_aggregate, H_post_bda, ProjRms) with fused cuda.tile (cuTile) autograd
functions for better performance on supported GPUs. Requires cuTile to be
installed; if cuTile is unavailable the flag is silently reset to False and
a warning is emitted.
"""

mhc_recompute_layer_num: Optional[int] = None
"""Number of layers per MHC recompute block.

When set, every `mhc_recompute_layer_num` layers form a recompute block. The last layer
in each recompute block (i.e., layer_number % mhc_recompute_layer_num == 0 or the final
layer in the transformer block) will:
- NOT checkpoint its final MLP BDA
- Register the unified recompute hook on its MLP BDA output
- A new MHCRecomputeManager is created for subsequent layers

If None, all layers in the transformer block share a single recompute block.

Must be a positive integer when set."""

####################
# miscellaneous
####################
Contributor

[SUGGESTION Naming] Consider documenting the relationship between num_residual_streams and hidden dimension impact. The n-stream architecture multiplies the effective hidden dimension by n within the transformer block ([s, b, n*C]), which has significant memory implications. A brief note like "memory scales linearly with this value within the transformer block" would help users set this parameter appropriately.
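For instance, a back-of-the-envelope illustration of that scaling (assumed shapes, bf16):

# One n-stream activation inside the block is [s, b, n*C].
s, b, C, n = 4096, 1, 4096, 4
bytes_per_elem = 2                           # bf16
n_stream_bytes = s * b * n * C * bytes_per_elem
baseline_bytes = s * b * C * bytes_per_elem  # the n=1 equivalent
print(n_stream_bytes // 2**20, baseline_bytes // 2**20)  # 128 MiB vs 32 MiB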

Comment on lines +271 to +285
        - [C] - bias tensor (will be broadcast)
    h_post: [s, b, n] - expansion weights

Returns:
    output: [s, b, n*C] - expanded tensor
"""
n = self.n
s, b, _ = h_post.shape

if x.dim() == 1:
    # x is bias with shape [C], need to broadcast to [s, b, 1, C]
    C = x.shape[0]
    x_expanded = x.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(s, b, 1, C)
else:
    # x is [s, b, C]
Contributor

[SUGGESTION Simplification] _apply_h_post is @torch.compile-decorated but branches on x.dim(). It's called with both 1D (bias) and 3D (hidden_states) inputs within the same forward pass, which will trigger two separate graph specializations and recompilation on the first occurrence of each shape rank.

This is not a correctness issue, but for cleaner compilation consider splitting into two methods (_apply_h_post_hidden for 3D and _apply_h_post_bias for 1D) to avoid the dynamic branching under @torch.compile.
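The split could look like this standalone sketch (free functions here for brevity; the names follow the review's suggestion and the shape math mirrors the quoted branch):

import torch

@torch.compile
def _apply_h_post_hidden(x: torch.Tensor, h_post: torch.Tensor) -> torch.Tensor:
    # x: [s, b, C], h_post: [s, b, n] -> [s, b, n*C]
    return (h_post.unsqueeze(-1) * x.unsqueeze(2)).flatten(start_dim=2)

@torch.compile
def _apply_h_post_bias(bias: torch.Tensor, h_post: torch.Tensor) -> torch.Tensor:
    # bias: [C], h_post: [s, b, n] -> [s, b, n*C]
    s, b, _ = h_post.shape
    x = bias.view(1, 1, 1, -1).expand(s, b, 1, bias.shape[0])
    return (h_post.unsqueeze(-1) * x).flatten(start_dim=2)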

@claude
Contributor

claude Bot commented Apr 27, 2026

Code Review Summary

PR: Add mHC transformer reference implementation

Reviewed the full diff (3529 additions, 79 deletions across 19 files) covering the core Manifold Hyper-Connections (mHC) transformer implementation on dsv4.

Findings

Severity Count
CRITICAL 0
IMPORTANT 4
SUGGESTION 4

Key Findings

IMPORTANT:

  1. fp8 flag bug fix in CheckpointWithoutOutput — Changes behavior for all existing default callers (layernorm recompute, MoE activation recompute, etc.) by no longer entering TE's FP8 context managers when FP8 is disabled. The fix is correct, but should be validated against existing selective-recompute tests and documented in the PR description since it affects non-mHC code paths.

  2. backward_dw_cudagraph is dead code — Method defined on TransformerLayer but never called anywhere in the codebase. Should be connected to a caller or removed.

  3. CheckpointManager alias is unused — Alias defined but never imported or referenced. Dead code that should be removed.

  4. Full recompute + hyper connections lacks validation — _checkpointed_forward doesn't pass mhc_recompute_manager (correct) but also omits rotary_pos_cos/sin (pre-existing). No config validation warns about this combination.

SUGGESTION:

  1. fused_h_res_h_post_bda name is misleading — it's a sequential implementation, not a fused kernel.
  2. cross_attention_hyper_connection field is placeholder-only — needs a TODO reference.
  3. _apply_h_post with @torch.compile may trigger recompilation due to dynamic shape branching.
  4. num_residual_streams docstring should note memory scaling implications.

What Looks Good

  • Correctness: The mHC math (Sinkhorn projection, H_pre aggregation, H_post expansion, H_res residual mixing) is correctly implemented and matches the paper formulation. The reference_h_post_bda fast path correctly fuses H_res @ residual + H_post * (x + bias) without dropout.
  • RNG handling: CheckpointWithoutOutput properly saves/restores RNG state via _get_all_rng_states / _set_all_rng_states, so dropout masks are consistent during recomputation.
  • CUDA graph compatibility: The ContextVar handoff for mhc_recompute_manager in HyperConnectionTransformerLayer.__call__ correctly avoids exposing unsupported types to the CudaGraphManager.
  • Config validation: Thorough validation blocks incompatible combinations (pipeline parallelism, MTP, fused TP inference, fine-grained offloading + mHC recompute).
  • MTP spec stripping: get_gpt_mtp_block_spec_for_backend correctly uses copy.copy and issubclass to handle HC layer specs.
  • Non-tensor arg handling: _save_args_to_ctx / _load_args_from_ctx properly separates tensor and non-tensor args for CheckpointWithoutOutput recomputation.
  • Test coverage: Comprehensive tests for CUDA graph capture/replay, activation offloading numerical equivalence, recompute blocks, and layer spec generation.

Overall Assessment

Low risk. The core mHC implementation is mathematically sound and well-integrated with Megatron's checkpoint, CUDA graph, and parallelism infrastructure. The important findings are about dead code and a behavior change in shared infrastructure (CheckpointWithoutOutput.fp8) rather than correctness issues in the mHC code itself. The config validations correctly block unsupported combinations. The primary concern is ensuring the fp8 flag fix doesn't cause regressions in non-mHC selective recompute paths.

Author

Addressed Claude's latest low-risk review items in b03220133e47538133fb5f6195f7b7bba0ccfa5b:

  • Removed the unused CheckpointManager = MHCRecomputeManager alias.
  • Removed the unused TransformerLayer.backward_dw_cudagraph method.
  • Added config validation rejecting enable_hyper_connections=True with full activation recompute, plus a focused unit test.
  • Renamed fused_h_res_h_post_bda to h_res_h_post_bda so the reference path is not described as a fused kernel.
  • Split _apply_h_post into hidden-state and bias variants to avoid rank branching under @torch.compile.
  • Applied the direct sequence_parallel attribute assignment suggestion.
  • Added cross-attention HC TODO/doc wording and expanded the num_residual_streams memory-scaling docstring.

Validation completed:

  • git diff --check passed locally.
  • python3 -m compileall -q passed for the touched core/test files locally and in a temporary hsg worktree.
  • Stale-symbol grep passed for the removed alias/method/old fused name/old _apply_h_post entry point locally and in the hsg temp worktree.

I also attempted the focused hsg Slurm/container pytest subset, including the selective-recompute FP8 cases Claude called out, but the 4-GPU allocations stayed pending or returned "Requested nodes are busy"; I did not count those runs as passed.

Author

/claude strict-review

# Intentional bug fix: the default fp8=False must not enter the TE FP8
# recompute context. The old "fp8 is not None" behavior did that for
# every default CheckpointWithoutOutput() caller.
self.fp8 = bool(fp8)
Contributor

[IMPORTANT Compatibility] This changes the default behavior of CheckpointWithoutOutput() for all callers, not just mHC callers.

Previously, the default fp8=False resulted in self.fp8 = (False is not None) → True, meaning every default CheckpointWithoutOutput() instance (including layernorm and MLA recompute) incorrectly entered the TE FP8 recompute context during _recompute(). The fix to bool(fp8) is correct, but it silently changes behavior for the existing selective-recompute code paths.

In practice this is likely safe — the TE FP8 context is probably a no-op when TE FP8 is not actually enabled — but it's worth calling out in the PR description as a cross-cutting behavioral change bundled into an mHC feature PR. If any existing test relied on the old (buggy) default, it would break here.

FineGrainedActivationOffloadingInterface as off_interface,
)

hidden_states = off_interface.group_commit(hidden_states, name="mlp_norm")
Contributor

[SUGGESTION Performance] Missing forced_released_tensors=[residual] — the base class _forward_post_mlp (line 901-903) passes forced_released_tensors=[residual] to group_commit, which explicitly releases the residual tensor at the offloading sync point. The HC version omits this.

For mHC layers the residual is [s, b, n*C] — n times larger than the standard residual — so the delayed release has a proportionally larger memory cost. The tensor is eventually collected when _forward_post_mlp_with_fused_hyper_connection returns, so the impact window is small, but adding the explicit release matches the base class behavior.

The same issue exists for the attention offloading path at line 1562 (missing forced_released_tensors=[residual] vs. base line 666-668).

Suggested change
-hidden_states = off_interface.group_commit(hidden_states, name="mlp_norm")
+hidden_states = off_interface.group_commit(
+    hidden_states, name="mlp_norm", forced_released_tensors=[residual]
+)

nvtx_range_pop(suffix="self_attention_h_res_h_post_bda")

if self.offload_attn_norm:
    hidden_states = off_interface.group_commit(hidden_states, name="attn_norm")
Contributor

[SUGGESTION Performance] Same as the MLP offloading path: base class (line 666-668) passes forced_released_tensors=[residual] here. The n-stream residual [s, b, n*C] is worth releasing explicitly.

Suggested change
-hidden_states = off_interface.group_commit(hidden_states, name="attn_norm")
+hidden_states = off_interface.group_commit(
+    hidden_states, name="attn_norm", forced_released_tensors=[residual]
+)

mlp_h_res (Tensor): [s, b, n, n] - residual mixing matrix from hyper connection.
residual (Tensor): [s, b, n*C] - original residual (n-stream hidden states).
mlp_hc_h_post (Tensor): [s, b, n] - expansion weights from hyper connection.
mhc_recompute_manager: Optional MHCRecomputeManager for checkpoint management.
Contributor

[SUGGESTION Naming] Minor docstring mismatch: this says mhc_recompute_manager but the parameter is mhc_mlp_bda_recompute_manager.

Suggested change
-mhc_recompute_manager: Optional MHCRecomputeManager for checkpoint management.
+mhc_mlp_bda_recompute_manager: Optional MHCRecomputeManager for checkpoint management.

Contributor

@claude Bot left a comment

Review Summary

CRITICAL: 0 | IMPORTANT: 1 | SUGGESTION: 3

Overall Assessment

This PR ports the core mHC (Manifold-Constrained Hyper-Connections) transformer support onto dsv4. The implementation is well-structured and carefully guarded with comprehensive config validation — pipeline parallelism, full activation recompute, fused TP inference, and fine-grained offloading + mhc recompute are all properly blocked or validated. The new MHCRecomputeManager with forward-order unified recomputation is correctly designed and thoroughly tested (including CUDA graph capture+replay, memory savings, and numerical equivalence tests).

Key Findings

  1. [IMPORTANT Compatibility] CheckpointWithoutOutput.fp8 behavioral change (random.py:835): The fix from fp8 is not None to bool(fp8) corrects a bug where the default fp8=False incorrectly entered the TE FP8 recompute context. This is a cross-cutting fix that affects all existing CheckpointWithoutOutput() callers (layernorm recompute, MLA recompute, etc.), not just mHC. Worth noting in the PR description.

  2. [SUGGESTION Performance] Missing forced_released_tensors=[residual] in the HC activation offloading paths (transformer_layer.py:1562, 1745). The base class explicitly releases the residual tensor at the offloading sync point; the HC version does not. Since the HC residual is larger, the omission has proportionally more impact.

  3. [SUGGESTION Naming] Minor docstring/parameter name mismatch in _forward_post_mlp_with_fused_hyper_connection.

What looks good

  • Config validation is thorough — all unsupported combinations (PP, full recompute, fused TP inference, mhc+offloading, mhc+mlp recompute) are caught with clear error messages
  • MHCRecomputeManager correctly enforces forward-order recomputation for chained checkpoints, with proper ctx=None guard against double-recompute
  • _save_args_to_ctx / _load_args_from_ctx properly generalizes CheckpointWithoutOutput to handle mixed tensor/non-tensor arguments
  • CUDA graph compatibility is well-handled via the ContextVar handoff pattern for mhc_recompute_manager, and the get_layer_static_inputs / _get_submodules_under_cudagraphs overrides
  • MTP spec stripping correctly uses issubclass(spec.module, TransformerLayer) (changed from ==) and copy.copy to avoid mutating the original spec
  • Test coverage is excellent — checkpoint correctness, memory savings, CUDA graph capture+replay, offloading numerical equivalence, config validation edge cases

Risk Level: Low-Medium

The only behavioral change to shared code is the fp8 flag fix, which corrects a latent bug. All mHC code paths are gated behind enable_hyper_connections=True (default False), so existing training configurations are unaffected.

@Connor-XY Connor-XY closed this May 1, 2026