Add mHC transformer reference implementation #4483
Connor-XY wants to merge 10 commits into NVIDIA:dsv4 from
Conversation
Update after the latest push: this first split PR is now intended to be self-contained, not code-only. It includes both the core/reference Transformer mHC implementation and focused unit-test coverage for that implementation.

Validation on hsg at head:

I also tried a broader file-level pytest invocation first; it stopped on an existing non-mHC failure.

Follow-up split PRs should now start with fused mHC kernel support, pipeline-parallel compatibility, functional CI/golden coverage, and HybridModel support. The earlier PR description text saying "code path only" / "unit tests are separated" is stale.
/ok to test 00c3339
/claude strict-review
if self.enable_hyper_connections and self.mtp_num_layers is not None:
    raise ValueError(
        "enable_hyper_connections is not compatible with Multi-Token Prediction (MTP). "
        "Please disable MTP (set mtp_num_layers=None) when using hyper connections."
    )
[CRITICAL Correctness] Missing validation: enable_hyper_connections with pipeline_model_parallel_size > 1 will silently produce wrong results or crash.
The code in transformer_block.py has PP-aware branching (self.pre_process for expand, has_final_layernorm_in_this_stage() for contract), implying n-stream hidden states flow between PP stages. However, get_tensor_shapes() in schedules.py computes the P2P tensor shape as (seq_len, micro_batch_size, config.hidden_size) — it does not account for num_residual_streams. This means P2P communication will allocate receive buffers of the wrong size, causing either a crash or silent data corruption.
The PR description notes "pipeline-parallel mHC compatibility" as a follow-up, but users can still configure enable_hyper_connections=True with PP > 1 today. Please add a guard here:
if self.enable_hyper_connections and self.pipeline_model_parallel_size > 1:
    raise ValueError(
        "enable_hyper_connections is not yet compatible with pipeline parallelism "
        "(pipeline_model_parallel_size > 1). Pipeline-parallel support is planned "
        "for a follow-up PR."
    )

Initialize CheckpointWithoutOutput.

Args:
    fp8: Whether to use FP8 mode. Defaults to False.
[IMPORTANT Compatibility] Behavior change in fp8 flag: fp8 is not None → bool(fp8).
The old code self.fp8 = fp8 is not None evaluated to True for the default fp8=False, which was a latent bug — it would attempt to set up FP8 recompute context even when FP8 wasn't in use. The new bool(fp8) correctly evaluates to False for the default case.
This fix is correct, but it changes runtime behavior for all existing callers that use the default CheckpointWithoutOutput() (no fp8 arg). Previously they would enter the FP8 context path during recompute; now they won't. If TE's get_fp8_context_for_recompute() happened to be a no-op when FP8 wasn't globally enabled, this change is safe. But if any caller relied on the old behavior, this could subtly break.
Please verify that no existing callers are affected and consider adding a brief comment or changelog note that this is an intentional bug fix.
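A minimal illustration of the semantic difference, in plain Python with no TE dependency:

```python
fp8 = False  # the default argument value

old_flag = fp8 is not None  # True: the latent bug, False still "enabled" the FP8 path
new_flag = bool(fp8)        # False: only an explicitly truthy fp8 enables it

assert old_flag is True and new_flag is False
```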
ctx.outputs = None
ctx.inputs = None
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
[IMPORTANT Compatibility] Changed non-tensor gradient return from inp to None.
The old code returned the non-tensor value itself as a "gradient" (else inp), which is technically incorrect — autograd expects None for non-differentiable inputs. The new code (else None) is correct.
However, this is a behavioral change to CheckpointWithoutOutputFunction.backward that affects ALL callers of CheckpointWithoutOutput, not just the new mHC code path. The old code happened to work because non-tensor args were never passed to save_for_backward (which only accepts tensors), so this path was likely never exercised. The new _save_args_to_ctx now properly handles non-tensor args, making this fix necessary and correct.
Consider splitting this fix into a separate commit to isolate the backward-compatibility change from the mHC feature addition.
# Expand hidden states for hyper connections at the start of the block
# Only expand at the first PP stage; subsequent stages receive n-stream from previous stage
if self.config.enable_hyper_connections and self.pre_process:
    hidden_states = HyperConnectionModule.input_expand(
        hidden_states, self.num_residual_streams
    )  # [s, b, C] -> [s, b, n*C]
[IMPORTANT Correctness] The self.pre_process guard assumes PP-aware n-stream passing, but PP is not yet supported.
This check implies that non-first PP stages receive n-stream tensors from the previous stage and should not expand. But since PP support is not implemented (no shape adaptation in get_tensor_shapes() / p2p_communication), this guard is premature and misleading.
Until PP support is added, this should either:
- Use a simpler unconditional expand (since PP>1 should be blocked by config validation), or
- At minimum, add an assertion to catch the case of not self.pre_process with HC enabled, since it would be an invalid state.
If you add the PP>1 validation guard I suggested in transformer_config.py, then self.pre_process will always be True when HC is enabled, making this condition redundant but harmless.
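A minimal sketch of the assertion variant, reusing the names from the excerpt above (exact placement in transformer_block.py is an assumption):

```python
if self.config.enable_hyper_connections:
    # Until PP-aware n-stream passing exists, the expand must happen here;
    # reaching this point on a non-first stage would be an invalid state.
    assert self.pre_process, (
        "enable_hyper_connections requires the n-stream expand on the first "
        "pipeline stage; pipeline_model_parallel_size > 1 is not supported yet."
    )
    hidden_states = HyperConnectionModule.input_expand(
        hidden_states, self.num_residual_streams
    )  # [s, b, C] -> [s, b, n*C]
```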
# Only contract if the final layer norm is in this stage
if self.config.enable_hyper_connections and self.has_final_layernorm_in_this_stage():
    hidden_states = HyperConnectionModule.output_contract(
        hidden_states, self.num_residual_streams
    )  # [s, b, n*C] -> [s, b, C]
[IMPORTANT Correctness] has_final_layernorm_in_this_stage() may not be the right guard for output contraction.
has_final_layernorm_in_this_stage() checks whether this stage has the final layernorm (based on post_process, post_layer_norm, and MTP config). But output_contract converts [s, b, n*C] → [s, b, C], which must happen exactly once at the model boundary — specifically right before the final layernorm, which expects [s, b, C].
For configurations without a final layernorm in this stage (e.g., post_layer_norm=False, or MTP configurations where the final layernorm is in a different stage), this contraction would be skipped, and the n-stream tensor would be passed to subsequent processing (output layer, loss computation) with the wrong hidden dimension.
This is mitigated if PP>1 is blocked by validation (as I suggested), but even with PP=1, post_layer_norm=False could skip this contraction. Please verify that the guard covers all valid configurations.
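One defensive option, sketched with the attribute names used in the excerpt above; this check is hypothetical and not part of the PR:

```python
# Hypothetical guard: fail loudly if contraction is skipped on the last stage,
# instead of letting an [s, b, n*C] tensor reach the output layer or loss.
if self.config.enable_hyper_connections and not self.has_final_layernorm_in_this_stage():
    assert not self.post_process, (
        "hyper-connection output_contract was skipped; hidden_states would leave "
        "the block with shape [s, b, n*C] instead of [s, b, C]"
    )
```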
def discard_all_outputs_and_register_unified_recompute(self, hook_tensor):
    """Discard all checkpoint outputs to save memory and register unified recompute hook."""
    for ckpt in self.checkpoints:
        for output in ckpt.outputs:
            output.untyped_storage().resize_(0)

    # Register unified recompute hook
    if hook_tensor.requires_grad:
        hook_tensor.register_hook(self._unified_recompute_hook)

def _unified_recompute_hook(self, grad_output):
    for ckpt in self.checkpoints:
        # Call _recompute for each checkpoint in forward order
        # The _recompute method will restore the output tensor storage
        ckpt._recompute(None)
[IMPORTANT Correctness] discard_all_outputs_and_register_unified_recompute relies on forward-order recomputation — the ordering assumption should be documented.
When chained checkpoints are used (checkpoint 2's input = checkpoint 1's output), discarding ALL outputs and then recomputing in forward order is correct because:
- Checkpoint 1's _recompute() restores its output tensor storage in-place via untyped_storage().resize_() + copy_()
- Checkpoint 2's saved input (which shares storage with checkpoint 1's output) is thereby restored before checkpoint 2's _recompute() needs it

This is subtle and depends on:

- _recompute() restoring storage in-place on the original tensor objects
- save_for_backward() preserving storage identity (not copying)
- Iterating self.checkpoints in forward (append) order
Consider adding a comment explaining this invariant, since breaking any of these assumptions (e.g., recomputing in reverse order, or PyTorch changing save_for_backward internals) would cause silent data corruption.
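A possible wording for that comment, attached to the hook shown above (placement is illustrative):

```python
def _unified_recompute_hook(self, grad_output):
    # Invariant: iterate self.checkpoints in forward (append) order. Each
    # _recompute() restores its output storage in-place on the original tensor
    # objects, and save_for_backward preserves storage identity, so checkpoint
    # k+1's saved inputs (which alias checkpoint k's outputs) are restored
    # before checkpoint k+1 recomputes. Reversing the order, or breaking the
    # storage aliasing, would silently corrupt the recomputed activations.
    for ckpt in self.checkpoints:
        ckpt._recompute(None)
```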
residual_batched = original_residual.view(s * b, n, C)
mixed = torch.bmm(h_res_batched, residual_batched).view(s, b, n, C)
x_expanded = h_post.unsqueeze(-1) * x.unsqueeze(2)
[SUGGESTION Naming] native_h_aggregate name is ambiguous — "native" could mean "PyTorch native" or "CPU-native".
Throughout this file, native_ prefix is used to distinguish reference PyTorch implementations from fused cuTile kernels. Consider reference_ or unfused_ instead, which is clearer about the distinction:
def reference_h_aggregate(x: Tensor, h_pre: Tensor) -> Tensor:
Or alternatively, adopt a naming convention consistent with the rest of the codebase (e.g., the existing bias_dropout_add_unfused pattern).
self.hidden_size = config.hidden_size
self.sinkhorn_iterations = config.mhc_sinkhorn_iterations
[SUGGESTION Simplification] TODO: keep hyper connection in fp32 computation — this TODO flags a real concern.
The mapping projection, Sinkhorn iterations, and gating computations all run in the model's default dtype (typically bf16/fp16). Since these involve:
- Small scalar parameters (alpha_pre/post/res, bias) that can lose precision
- Iterative normalization (Sinkhorn) with .clamp(min=eps) that is sensitive to underflow
- RMS normalization with rsqrt-equivalent computation
Running these in reduced precision could cause training instability, especially at scale. Consider at minimum upcasting the Sinkhorn iterations to fp32 (the .clamp(min=eps) with eps=1e-6 is dangerously close to fp16's smallest subnormal value of ~6e-8).
If this is intentionally deferred, please convert this to a more explicit TODO with a tracking reference.
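A minimal sketch of the suggested fp32 upcast, assuming a standard Sinkhorn-Knopp row/column normalization over exponentiated logits; the function name and normalization details are illustrative, not the PR's implementation:

```python
import torch

def sinkhorn_fp32(logits: torch.Tensor, num_iters: int = 20, eps: float = 1e-6) -> torch.Tensor:
    """Doubly stochastic projection with the iterations carried out in fp32."""
    orig_dtype = logits.dtype
    # Upcast once: values near eps=1e-6 sit close to fp16's subnormal range,
    # so the clamps and divisions below are done in fp32.
    mat = torch.exp(logits.float())
    for _ in range(num_iters):
        mat = mat / mat.sum(dim=-1, keepdim=True).clamp(min=eps)  # normalize rows
        mat = mat / mat.sum(dim=-2, keepdim=True).clamp(min=eps)  # normalize columns
    return mat.to(orig_dtype)
```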
use_kitchen_attention: bool = False,
kitchen_attention_backend: str = "sdpa",
mla_down_proj_fusion: bool = False,
enable_hyper_connection: bool = False,
[SUGGESTION Naming] Inconsistent naming: enable_hyper_connection (singular) here vs enable_hyper_connections (plural) in TransformerConfig.
The config field is config.enable_hyper_connections (plural), but the layer spec parameter is enable_hyper_connection (singular). This is confusing and error-prone — a typo in either direction would silently use the default False.
Consider aligning to one name. Since the config uses plural (enable_hyper_connections), the spec functions should match:
enable_hyper_connection: bool = False,
enable_hyper_connections: bool = False,
Or rename to enable_hyper_connections throughout the spec functions to match the config.
    Returns:
        block_size: Optimal block size for checkpointing
    """
    block_size = int(math.sqrt(num_streams * num_layers / (num_streams + 2)))
    return max(1, block_size)
[SUGGESTION Simplification] HyperConnectionCheckpoint.compute_optimal_block_size is defined but never called in this PR.
This static method implements Eq. (20) from the paper but has no callers. If it's intended for follow-up PRs, consider deferring it to the PR that actually uses it. If it's meant to be called from _build_mhc_recompute_layer_plan as a default for mhc_recompute_layer_num, please wire it up or add a TODO referencing the follow-up.
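If the second option is the intent, the wiring could look roughly like this; the builder signature and the config field access are assumptions:

```python
# Hypothetical default: fall back to the Eq. (20) block size when the user
# did not set mhc_recompute_layer_num explicitly.
def _build_mhc_recompute_layer_plan(config, num_layers_in_block):
    block_size = config.mhc_recompute_layer_num
    if block_size is None:
        block_size = HyperConnectionCheckpoint.compute_optimal_block_size(
            config.num_residual_streams, num_layers_in_block
        )
    ...  # build the per-layer recompute plan from block_size
```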
def __call__(self, *args, **kwargs):
    # Extract mhc_recompute_manager before CUDA graph manager processes kwargs,
    # since CheckpointManager is not a CUDA-graph-supported type.
    self._mhc_recompute_manager = kwargs.pop("mhc_recompute_manager", None)
    kwargs.pop("is_last_layer_in_recompute_block", None)
[IMPORTANT Compatibility] __call__ override on the base TransformerLayer changes the call interface for ALL layer types, not just mHC.
This __call__ override is on TransformerLayer (the base class), meaning every layer type (including MoETransformerLayer and any downstream custom layers) now has mhc_recompute_manager and is_last_layer_in_recompute_block silently popped from kwargs. While the pop with default None is safe, storing self._mhc_recompute_manager on every layer instance is unnecessary overhead for non-HC layers.
Consider either:
- Moving this __call__ override to HyperConnectionTransformerLayer only, or
- At minimum, guarding the attribute assignment with if self.config.enable_hyper_connections: (see the sketch below)
Additionally, kwargs.pop("is_last_layer_in_recompute_block", None) pops a kwarg that is never passed from transformer_block.py (the block sets it as an attribute on the manager instead). This is dead code.
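A sketch of the guarded variant; the super() delegation and the exact non-HC behavior are assumptions:

```python
def __call__(self, *args, **kwargs):
    # Always drop the mHC-only kwargs so CUDA graph keying never sees them,
    # but only HC-enabled layers keep a reference to the manager.
    manager = kwargs.pop("mhc_recompute_manager", None)
    kwargs.pop("is_last_layer_in_recompute_block", None)
    if self.config.enable_hyper_connections:
        self._mhc_recompute_manager = manager
    return super().__call__(*args, **kwargs)
```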
# Injected by __call__ for cuda graph keying; not a real forward arg.
kwargs.pop("dynamic_inference_decode_only", None)
assert (
    not self.config.enable_hyper_connections
), "Please use HyperConnectionTransformerLayer instead"
[IMPORTANT Compatibility] Hard assertion in base TransformerLayer.forward blocks HC-enabled configs from using the base class.
This assertion fires for any code path that calls TransformerLayer.forward() when enable_hyper_connections=True — even if hyper connections are not actually needed (e.g., MTP layers that strip HC, or custom model code).
The MTP path in gpt_layer_specs.py correctly downgrades the layer module to TransformerLayer, but the config still has enable_hyper_connections=True. This assertion would fire in that case.
Consider either:
- Checking isinstance(self, HyperConnectionTransformerLayer) instead of the config flag, or
- Removing the assertion and letting the spec/module selection handle correctness.
    raise ValueError(
        "'mhc' and 'mlp' in recompute_modules cannot be used together. "
        "They use different checkpoint mechanisms that may conflict."
    )
if self.mhc_recompute_layer_num is not None and (
    isinstance(self.mhc_recompute_layer_num, bool)
[SUGGESTION Simplification] The isinstance(self.mhc_recompute_layer_num, bool) check before the int check is needed because bool is a subclass of int in Python. This is correct but non-obvious — a reader might wonder why bool is checked separately.
Consider simplifying with a type annotation or a clearer guard:
if self.mhc_recompute_layer_num is not None:
    if not isinstance(self.mhc_recompute_layer_num, int) or isinstance(
        self.mhc_recompute_layer_num, bool
    ) or self.mhc_recompute_layer_num < 1:

Or adding a one-line comment: # bool is a subclass of int, reject it explicitly.
Review Summary
PR: Add mHC transformer reference implementation
Reviewed the full diff (3489 additions, 79 deletions across 18 files). This PR ports the core Manifold Hyper-Connections (mHC) implementation onto dsv4, adding:
- HyperConnectionModule with Sinkhorn-projected dynamic mappings
- HyperConnectionTransformerLayer extending TransformerLayer with n-stream pre/post processing
- CheckpointManager for cross-layer mHC activation recompute
- Mixed tensor/non-tensor _save_args_to_ctx / _load_args_from_ctx helpers
- Comprehensive config validation and layer spec plumbing
- Strong test coverage (unit tests for checkpointing, block manager, layer integration, memory savings, CUDA graphs, and offloading)
Findings
| Severity | Count |
|---|---|
| CRITICAL | 0 |
| IMPORTANT | 3 |
| SUGGESTION | 4 |
Most Impactful Findings
- [IMPORTANT Implementation] reference_proj_inv_rms computes the L2 norm in the input dtype without an fp32 upcast. For fp16 with n*C = 16384, the intermediate sum(x²) overflows. This is the same class of bug that motivated the fp32 upcast in RMSNorm/LayerNorm throughout Megatron. (comment)
- [IMPORTANT Compatibility] CheckpointWithoutOutput.__init__ changes self.fp8 = fp8 is not None → self.fp8 = bool(fp8). This is a correct bug fix (the old code entered the FP8 recompute context for every default CheckpointWithoutOutput() caller), but it silently changes behavior for all existing layernorm/MLA checkpoint sites. Should be called out separately. (comment)
- [IMPORTANT Correctness] HyperConnectionTransformerLayer.__call__ stores the mhc_recompute_manager on self to bypass CUDA graph argument filtering. Works today due to synchronous forward execution, but the pattern is fragile. (comment)
Overall Assessment
Risk level: Medium. The mHC implementation is well-structured with thorough validation guards (PP blocked, MTP blocked, fused TP blocked, MoE CUDA graphs blocked) and good test coverage. The core mHC math (Sinkhorn, aggregate, expand/contract) looks correct, and the CheckpointManager design for cross-layer recompute is sound.
The primary risk is the fp16 overflow in reference_proj_inv_rms (#1), which could cause silent NaN/Inf during mixed-precision training. The fp8 flag fix (#2) is correct but should be validated independently with FP8 training to confirm no regressions. The remaining suggestions are quality improvements, not blockers.
The test suite covers the important cases: checkpoint correctness, memory savings, CUDA graph capture+replay, activation offloading, and config validation. Nice work on the defensive programming — the config __post_init__ guards are comprehensive and will prevent users from hitting unsupported combinations.
Follow-up after Claude's latest strict review: pushed a new commit. Addressed the latest important/suggestion items:
Validation:
/claude strict-review
# Intentional bug fix: the default fp8=False must not enter the TE FP8
# recompute context. The old "fp8 is not None" behavior did that for
# every default CheckpointWithoutOutput() caller.
self.fp8 = bool(fp8)
[IMPORTANT Compatibility] This fixes a real bug (False is not None → True), but it silently changes behavior for all existing default callers of CheckpointWithoutOutput() — including layernorm recompute (transformer_layer.py:590, 728), MoE activation recompute (experts.py:380), and fine-grained callables (fine_grained_callables.py:483).
Previously these callers entered TE's activation_recompute_forward context during forward and fp8_autocast during recompute, even when FP8 was not enabled. After this fix they correctly use nullcontext().
The fix is correct, and the new test (test_checkpoint_without_output_fp8_flag_is_explicit) covers it. However, because activation_recompute_forward may set TE-internal state even when FP8 is globally disabled, I'd recommend:
- Verifying that existing selective-recompute tests (layernorm, mla_up_proj, moe_act) still pass unchanged — confirming the old TE context was truly a no-op.
- Mentioning this fix in the PR description since it affects non-mHC code paths.
def backward_dw_cudagraph(self, microbatch_idx):
    """
    CUDA Graph backward weight gradient computation for this layer.
    """
    cg_index = microbatch_idx % len(self.cuda_graphs)
    if not hasattr(self.cuda_graphs[cg_index], 'backward_dw'):
        return
    self.cuda_graphs[cg_index].backward_dw()
[IMPORTANT Compatibility] backward_dw_cudagraph is defined but never called anywhere in the codebase — grep -rn "backward_dw_cudagraph" megatron/core/ returns only this definition. This is dead code that should either be connected to a caller (e.g., the pipeline schedule for deferred weight gradients) or removed.
If it's intended for a follow-up PR, please add a # TODO: comment referencing that work.
# Backward-compatible alias for earlier internal users of the generic name.
CheckpointManager = MHCRecomputeManager
[IMPORTANT Compatibility] CheckpointManager alias is defined but never imported or used anywhere in the codebase. The comment says "Backward-compatible alias for earlier internal users" but no such users exist. This is dead code.
CheckpointManager = MHCRecomputeManager
if self.config.enable_hyper_connections:
    hidden_states = HyperConnectionModule.input_expand(
        hidden_states, self.num_residual_streams
    )  # [s, b, C] -> [s, b, n*C]
[IMPORTANT Correctness] When recompute_granularity == 'full', _checkpointed_forward does not pass mhc_recompute_manager to layers — this is correct (full recompute already wraps the forward). However, there is no config validation blocking the combination of enable_hyper_connections=True with recompute_granularity='full'.
_checkpointed_forward also omits rotary_pos_cos, rotary_pos_sin, rotary_pos_cos_sin, and sequence_len_offset kwargs (pre-existing issue), which could silently produce wrong results for positional encoding features combined with full recompute.
Consider adding a validation or warning in TransformerConfig.__post_init__ for this combination, or a TODO noting the limitation.
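A sketch of such a guard, following the pattern of the other __post_init__ validations in this PR; the exact message wording is illustrative:

```python
# Hypothetical guard in TransformerConfig.__post_init__: block full-granularity
# recompute with hyper connections until the interaction is validated.
if self.enable_hyper_connections and self.recompute_granularity == 'full':
    raise ValueError(
        "enable_hyper_connections has not been validated with "
        "recompute_granularity='full'; use selective recompute or the mHC "
        "recompute path (mhc_recompute_layer_num) instead."
    )
```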
torch.autograd.backward(outputs, args)
ctx.outputs = None
ctx.inputs = None
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs)
# Autograd expects None for non-tensor inputs; returning the original
# non-tensor value as a "gradient" is invalid once args are restored.
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
return (None, None) + grads
[IMPORTANT Correctness] Good fix — the old code inp.grad if isinstance(inp, torch.Tensor) else inp returned the original non-tensor value as a "gradient", which is invalid for autograd (it expects None for non-differentiable inputs). The new code correctly returns None.
However, this is a behavioral change for existing callers of CheckpointWithoutOutputFunction. If any caller was (incorrectly) relying on getting the non-tensor value back through the gradient tuple, this would break. Given that autograd should never consume non-tensor "gradients", this change is safe, but worth noting in the PR description alongside the fp8 fix since both modify the shared CheckpointWithoutOutput infrastructure.
    3. output = dropout(expanded + bias) + mixed  (bias-dropout-add)

    Args:
        h_res: [s, b, n, n] - residual mixing matrix
        original_residual: [s, b, n*C] - n-stream hidden states (before H_res applied)
        h_post: [s, b, n] - expansion weights
        layer_output_with_bias: Tuple of (x, bias) where:
            - x: [s, b, C] - layer output (attention or MLP output)
            - bias: [C] or None - optional bias tensor
        dropout_prob: Dropout probability
        training: Whether in training mode
        fused: Whether to use fused BDA implementation
[SUGGESTION Naming] fused_h_res_h_post_bda is misleading — this is a sequential implementation, not a fused kernel. The docstring itself says "This is a placeholder for future kernel fusion optimization."
Consider renaming to h_res_h_post_bda (dropping the "fused" prefix) and reserving "fused" for the actual kernel variant when it lands. The current name may confuse readers inspecting profiles or code paths into thinking this is already an optimized fused operation.
from megatron.core.fusions.fused_mhc_kernels import (
    fused_h_aggregate,
    fused_h_post_bda,
    fused_proj_rms,
    fused_sinkhorn,
)
[SUGGESTION Simplification] The sequence_parallel attribute is set via setattr(self.mapping_proj.weight, 'sequence_parallel', True) — this can be simplified to direct attribute assignment:
self.mapping_proj.weight.sequence_parallel = True
self.alpha_pre.sequence_parallel = True
self.alpha_post.sequence_parallel = True
self.alpha_res.sequence_parallel = True
self.bias.sequence_parallel = True
pre_cross_attn_layernorm: LayerNormBuilder = IdentityOp
# Reserved for future cross-attention hyper-connection support.
cross_attention_hyper_connection: Union[ModuleSpec, type] = IdentityOp
cross_attention: Union[ModuleSpec, type] = IdentityOp
cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
[SUGGESTION Simplification] cross_attention_hyper_connection is declared here but has no runtime use path — it's only validated to be IdentityOp in HyperConnectionTransformerLayer.__init__ and stripped to IdentityOp in get_gpt_mtp_block_spec_for_backend. It's never built or called.
If this is reserved for a future cross-attention HC implementation, consider adding a brief TODO reference to the planned follow-up. Otherwise it could confuse readers into thinking it's wired up.
####################
# Hyper-Connection Configuration
####################
enable_hyper_connections: bool = False
"""Enable mHC residual connections."""

num_residual_streams: int = 4
"""Number of residual streams (n in paper)."""

mhc_sinkhorn_iterations: int = 20
"""Number of Sinkhorn-Knopp iterations for doubly stochastic projection."""

mhc_init_gating_factor: float = 0.01
"""Initial value of Gating Factor (alpha in paper)."""

use_fused_mhc: bool = False
"""Use cuTile fused kernels for mHC operations.

When True, attempts to replace the reference mHC modules (SinkhornKnopp,
H_aggregate, H_post_bda, ProjRms) with fused cuda.tile (cuTile) autograd
functions for better performance on supported GPUs. Requires cuTile to be
installed; if cuTile is unavailable the flag is silently reset to False and
a warning is emitted.
"""

mhc_recompute_layer_num: Optional[int] = None
"""Number of layers per MHC recompute block.

When set, every `mhc_recompute_layer_num` layers form a recompute block. The last layer
in each recompute block (i.e., layer_number % mhc_recompute_layer_num == 0 or the final
layer in the transformer block) will:
- NOT checkpoint its final MLP BDA
- Register the unified recompute hook on its MLP BDA output
- A new MHCRecomputeManager is created for subsequent layers

If None, all layers in the transformer block share a single recompute block.

Must be a positive integer when set."""

####################
# miscellaneous
####################
[SUGGESTION Naming] Consider documenting the relationship between num_residual_streams and hidden dimension impact. The n-stream architecture multiplies the effective hidden dimension by n within the transformer block ([s, b, n*C]), which has significant memory implications. A brief note like "memory scales linearly with this value within the transformer block" would help users set this parameter appropriately.
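One possible wording for such a note, shown on the field it would document (the added text is illustrative, not the PR's):

```python
num_residual_streams: int = 4
"""Number of residual streams (n in paper).

Note: within the transformer block the hidden states are carried as
[s, b, n*C], so activation memory for those tensors scales roughly
linearly with this value.
"""
```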
            - [C] - bias tensor (will be broadcast)
        h_post: [s, b, n] - expansion weights

    Returns:
        output: [s, b, n*C] - expanded tensor
    """
    n = self.n
    s, b, _ = h_post.shape

    if x.dim() == 1:
        # x is bias with shape [C], need to broadcast to [s, b, 1, C]
        C = x.shape[0]
        x_expanded = x.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(s, b, 1, C)
    else:
        # x is [s, b, C]
[SUGGESTION Simplification] _apply_h_post is @torch.compile-decorated but branches on x.dim(). It's called with both 1D (bias) and 3D (hidden_states) inputs within the same forward pass, which will trigger two separate graph specializations and recompilation on the first occurrence of each shape rank.
This is not a correctness issue, but for cleaner compilation consider splitting into two methods (_apply_h_post_hidden for 3D and _apply_h_post_bias for 1D) to avoid the dynamic branching under @torch.compile.
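A sketch of that split; the method names follow the suggestion above, the bodies reproduce the broadcasting shown in the excerpts, and the final reshape to [s, b, n*C] is inferred from the documented return shape:

```python
@torch.compile
def _apply_h_post_hidden(self, x, h_post):
    # x: [s, b, C], h_post: [s, b, n] -> [s, b, n*C]; no rank branching
    s, b, _ = h_post.shape
    out = h_post.unsqueeze(-1) * x.unsqueeze(2)  # [s, b, n, C]
    return out.reshape(s, b, -1)

@torch.compile
def _apply_h_post_bias(self, bias, h_post):
    # bias: [C], h_post: [s, b, n] -> [s, b, n*C]
    s, b, _ = h_post.shape
    bias_expanded = bias.view(1, 1, 1, -1)       # broadcast over s, b, n
    out = h_post.unsqueeze(-1) * bias_expanded   # [s, b, n, C]
    return out.reshape(s, b, -1)
```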
Code Review Summary

PR: Add mHC transformer reference implementation

Reviewed the full diff (3529 additions, 79 deletions across 19 files) covering the core Manifold Hyper-Connections (mHC) transformer implementation on dsv4.

Findings
Key Findings

IMPORTANT:

SUGGESTION:

What Looks Good

Overall Assessment

Low risk. The core mHC implementation is mathematically sound and well-integrated with Megatron's checkpoint, CUDA graph, and parallelism infrastructure. The important findings are about dead code and a behavior change in shared infrastructure (CheckpointWithoutOutput).
Addressed Claude's latest low-risk review items in the latest push.
Validation completed:
I also attempted the focused hsg Slurm/container pytest subset, including the selective-recompute FP8 cases Claude called out, but the 4-GPU allocations stayed pending or returned errors.
/claude strict-review
# Intentional bug fix: the default fp8=False must not enter the TE FP8
# recompute context. The old "fp8 is not None" behavior did that for
# every default CheckpointWithoutOutput() caller.
self.fp8 = bool(fp8)
[IMPORTANT Compatibility] This changes the default behavior of CheckpointWithoutOutput() for all callers, not just mHC callers.
Previously, the default fp8=False resulted in self.fp8 = (False is not None) → True, meaning every default CheckpointWithoutOutput() instance (including layernorm and MLA recompute) incorrectly entered the TE FP8 recompute context during _recompute(). The fix to bool(fp8) is correct, but it silently changes behavior for the existing selective-recompute code paths.
In practice this is likely safe — the TE FP8 context is probably a no-op when TE FP8 is not actually enabled — but it's worth calling out in the PR description as a cross-cutting behavioral change bundled into an mHC feature PR. If any existing test relied on the old (buggy) default, it would break here.
    FineGrainedActivationOffloadingInterface as off_interface,
)

hidden_states = off_interface.group_commit(hidden_states, name="mlp_norm")
[SUGGESTION Performance] Missing forced_released_tensors=[residual] — the base class _forward_post_mlp (line 901-903) passes forced_released_tensors=[residual] to group_commit, which explicitly releases the residual tensor at the offloading sync point. The HC version omits this.
For mHC layers the residual is [s, b, n*C] — n times larger than the standard residual — so the delayed release has a proportionally larger memory cost. The tensor is eventually collected when _forward_post_mlp_with_fused_hyper_connection returns, so the impact window is small, but adding the explicit release matches the base class behavior.
The same issue exists for the attention offloading path at line 1562 (missing forced_released_tensors=[residual] vs. base line 666-668).
hidden_states = off_interface.group_commit(hidden_states, name="mlp_norm")
hidden_states = off_interface.group_commit(
    hidden_states, name="mlp_norm", forced_released_tensors=[residual]
)
nvtx_range_pop(suffix="self_attention_h_res_h_post_bda")

if self.offload_attn_norm:
    hidden_states = off_interface.group_commit(hidden_states, name="attn_norm")
[SUGGESTION Performance] Same as the MLP offloading path: base class (line 666-668) passes forced_released_tensors=[residual] here. The n-stream residual [s, b, n*C] is worth releasing explicitly.
hidden_states = off_interface.group_commit(hidden_states, name="attn_norm")
hidden_states = off_interface.group_commit(
    hidden_states, name="attn_norm", forced_released_tensors=[residual]
)
    mlp_h_res (Tensor): [s, b, n, n] - residual mixing matrix from hyper connection.
    residual (Tensor): [s, b, n*C] - original residual (n-stream hidden states).
    mlp_hc_h_post (Tensor): [s, b, n] - expansion weights from hyper connection.
    mhc_recompute_manager: Optional MHCRecomputeManager for checkpoint management.
[SUGGESTION Naming] Minor docstring mismatch: this says mhc_recompute_manager but the parameter is mhc_mlp_bda_recompute_manager.
    mhc_recompute_manager: Optional MHCRecomputeManager for checkpoint management.
    mhc_mlp_bda_recompute_manager: Optional MHCRecomputeManager for checkpoint management.
Review Summary
CRITICAL: 0 | IMPORTANT: 1 | SUGGESTION: 3
Overall Assessment
This PR ports the core mHC (Manifold-Constrained Hyper-Connections) transformer support onto dsv4. The implementation is well-structured and carefully guarded with comprehensive config validation — pipeline parallelism, full activation recompute, fused TP inference, and fine-grained offloading + mhc recompute are all properly blocked or validated. The new MHCRecomputeManager with forward-order unified recomputation is correctly designed and thoroughly tested (including CUDA graph capture+replay, memory savings, and numerical equivalence tests).
Key Findings
- [IMPORTANT Compatibility] CheckpointWithoutOutput.fp8 behavioral change (random.py:835): The fix from fp8 is not None to bool(fp8) corrects a bug where the default fp8=False incorrectly entered the TE FP8 recompute context. This is a cross-cutting fix that affects all existing CheckpointWithoutOutput() callers (layernorm recompute, MLA recompute, etc.), not just mHC. Worth noting in the PR description.
- [SUGGESTION Performance] Missing forced_released_tensors=[residual] in the HC activation offloading paths (transformer_layer.py:1562, 1745). The base class explicitly releases the residual tensor at the offloading sync point; the HC version does not. Since the HC residual is n× larger, the omission has proportionally more impact.
- [SUGGESTION Naming] Minor docstring/parameter name mismatch in _forward_post_mlp_with_fused_hyper_connection.
What looks good
- Config validation is thorough — all unsupported combinations (PP, full recompute, fused TP inference, mhc+offloading, mhc+mlp recompute) are caught with clear error messages
- MHCRecomputeManager correctly enforces forward-order recomputation for chained checkpoints, with a proper ctx=None guard against double-recompute
- _save_args_to_ctx / _load_args_from_ctx properly generalizes CheckpointWithoutOutput to handle mixed tensor/non-tensor arguments
- CUDA graph compatibility is well-handled via the ContextVar handoff pattern for mhc_recompute_manager, and the get_layer_static_inputs / _get_submodules_under_cudagraphs overrides
- MTP spec stripping correctly uses issubclass(spec.module, TransformerLayer) (changed from ==) and copy.copy to avoid mutating the original spec
- Test coverage is excellent — checkpoint correctness, memory savings, CUDA graph capture+replay, offloading numerical equivalence, config validation edge cases
Risk Level: Low-Medium
The only behavioral change to shared code is the fp8 flag fix, which corrects a latent bug. All mHC code paths are gated behind enable_hyper_connections=True (default False), so existing training configurations are unaffected.
What does this PR do?
This is PR 1 in a fine-grained split of #4469. It ports the core/reference Transformer support for Manifold Hyper-Connections (mHC) onto dsv4.

The original mHC PR targeting main is #3430. This split targets dsv4 per the DeepSeek/DSv4 workflow discussion, so reviewers can inspect the change in smaller pieces.

This PR includes the code path only:
- HyperConnectionModule reference implementation
- HyperConnectionTransformerLayer integration
- CheckpointManager

Follow-up split PRs are stacked after this one:
Testing
- git diff --check passed for the split stack.
- python -m compileall passed for the touched Python files in the full split stack.