diff --git a/gpt_builders.py b/gpt_builders.py index 24b5f89d311..72e3bb8c550 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -136,6 +136,7 @@ def _get_transformer_layer_spec(use_te, config): use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False), + enable_hyper_connections=config.enable_hyper_connections, ) elif config.transformer_impl == "inference_optimized": return get_gpt_layer_with_inference_spec( @@ -154,4 +155,5 @@ def _get_transformer_layer_spec(use_te, config): use_kitchen=config.use_kitchen, use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, + enable_hyper_connections=config.enable_hyper_connections, ) diff --git a/megatron/core/fusions/fused_bias_dropout.py b/megatron/core/fusions/fused_bias_dropout.py index 2eb4007f75c..db11fafd5cb 100644 --- a/megatron/core/fusions/fused_bias_dropout.py +++ b/megatron/core/fusions/fused_bias_dropout.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. from typing import Optional, Tuple import torch @@ -81,6 +81,16 @@ def bias_dropout_add_fused_inference( def get_bias_dropout_add(training, fused): + """ + Get the bias-dropout-add function. + + Args: + training: Whether in training mode. + fused: Whether to use fused implementation. + + Returns: + A callable that performs bias-dropout-add operation. + """ if fused: # jit scripting for a nn.module (with dropout) is not # triggering the fusion kernel. For now, we use two diff --git a/megatron/core/models/gpt/experimental_attention_variant_module_specs.py b/megatron/core/models/gpt/experimental_attention_variant_module_specs.py index 1b03b935639..4385f49ca8c 100644 --- a/megatron/core/models/gpt/experimental_attention_variant_module_specs.py +++ b/megatron/core/models/gpt/experimental_attention_variant_module_specs.py @@ -12,6 +12,7 @@ DSAttention, DSAttentionSubmodules, ) +from megatron.core.transformer.hyper_connection import HyperConnectionModule from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.multi_latent_attention import ( MLASelfAttention, @@ -24,6 +25,7 @@ ) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import ( + HyperConnectionTransformerLayer, TransformerLayer, TransformerLayerSubmodules, get_transformer_layer_offset, @@ -227,6 +229,10 @@ def get_transformer_block_with_experimental_attention_variant_spec( # Get GPT decoder block layer specs rms_norm = config.normalization == "RMSNorm" + enable_hc = config.enable_hyper_connections + hc_module = HyperConnectionModule if enable_hc else IdentityOp + layer_module = HyperConnectionTransformerLayer if enable_hc else TransformerLayer + layer_specs = [] for layer_number in range(config.num_layers): attention = ( @@ -248,14 +254,16 @@ def get_transformer_block_with_experimental_attention_variant_spec( layer_specs.append( ModuleSpec( - module=TransformerLayer, + module=layer_module, submodules=TransformerLayerSubmodules( input_layernorm=input_layernorm, self_attention=attention, self_attn_bda=get_bias_dropout_add, + self_attention_hyper_connection=hc_module, pre_mlp_layernorm=pre_mlp_layernorm, mlp=mlp, mlp_bda=get_bias_dropout_add, + mlp_hyper_connection=hc_module, ), ) ) diff --git a/megatron/core/models/gpt/gpt_layer_specs.py 
b/megatron/core/models/gpt/gpt_layer_specs.py index 5e90f0b36be..9ca0e9494c7 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -1,4 +1,5 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +import copy import warnings from typing import Optional, Union @@ -12,6 +13,7 @@ from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec_for_backend from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.enums import AttnMaskType, LayerType +from megatron.core.transformer.hyper_connection import HyperConnectionModule from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.multi_latent_attention import ( @@ -34,6 +36,7 @@ ) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import ( + HyperConnectionTransformerLayer, TransformerLayer, TransformerLayerSubmodules, get_transformer_layer_offset, @@ -183,6 +186,7 @@ def get_gpt_layer_with_transformer_engine_submodules( use_kitchen_attention: bool = False, kitchen_attention_backend: str = "sdpa", mla_down_proj_fusion: bool = False, + enable_hyper_connections: bool = False, ) -> TransformerLayerSubmodules: """Use these submodules to use lower-level Transformer Engine modules (required for fp8 training). @@ -200,6 +204,8 @@ def get_gpt_layer_with_transformer_engine_submodules( mla_down_proj_fusion (bool, optional): Enable fused q/kv down-projection and fused input layernorm when backend supports. Otherwise fall back to the unfused MLA. + enable_hyper_connections (bool): Use HyperConnectionTransformerLayer with + HyperConnectionModule instead of plain TransformerLayer. Defaults to False. Returns: TransformerLayerSubmodules: TE modules to construct a TransformerLayer @@ -233,6 +239,8 @@ def get_gpt_layer_with_transformer_engine_submodules( use_te_activation_func=use_te_activation_func, ) + hc_module = HyperConnectionModule if enable_hyper_connections else IdentityOp + if multi_latent_attention: assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." 
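For reference, a minimal usage sketch of the new flag (illustrative only; it assumes Transformer Engine is installed and leaves every other builder argument at its default):

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.transformer.transformer_layer import HyperConnectionTransformerLayer

# With the flag set, the spec swaps in the hyper-connection layer class and
# fills the per-branch hyper-connection slots with HyperConnectionModule.
spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connections=True)
assert spec.module is HyperConnectionTransformerLayer
print(spec.submodules.self_attention_hyper_connection, spec.submodules.mlp_hyper_connection)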
linear_q_up_proj = ( @@ -302,9 +310,11 @@ def get_gpt_layer_with_transformer_engine_submodules( ), ), self_attn_bda=get_bias_dropout_add, + self_attention_hyper_connection=hc_module, pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp, mlp=mlp, mlp_bda=get_bias_dropout_add, + mlp_hyper_connection=hc_module, ) else: qk_norm = backend.layer_norm(for_qk=True) @@ -325,9 +335,11 @@ def get_gpt_layer_with_transformer_engine_submodules( ), ), self_attn_bda=get_bias_dropout_add, + self_attention_hyper_connection=hc_module, pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp, mlp=mlp, mlp_bda=get_bias_dropout_add, + mlp_hyper_connection=hc_module, sharded_state_dict_keys_map={ "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", @@ -342,8 +354,10 @@ def get_gpt_layer_with_transformer_engine_submodules( @copy_signature(get_gpt_layer_with_transformer_engine_submodules) def get_gpt_layer_with_transformer_engine_spec(*args, **kwargs) -> ModuleSpec: """Use this spec to use lower-level Transformer Engine modules (required for fp8 training).""" + enable_hc = kwargs.get('enable_hyper_connections', False) + layer_module = HyperConnectionTransformerLayer if enable_hc else TransformerLayer return ModuleSpec( - module=TransformerLayer, + module=layer_module, submodules=get_gpt_layer_with_transformer_engine_submodules(*args, **kwargs), ) @@ -359,6 +373,7 @@ def get_gpt_layer_local_submodules( use_kitchen: bool = False, use_kitchen_attention: bool = False, kitchen_attention_backend: str = "sdpa", + enable_hyper_connections: bool = False, ) -> TransformerLayerSubmodules: """Use these submodules for an implementation using only modules in Megatron-Core. @@ -370,6 +385,8 @@ def get_gpt_layer_local_submodules( multi_latent_attention (bool, optional): To use MLA. Defaults to False. fp8 (str, optional): Deprecated. For temporary Nemo compatibility. qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False. + enable_hyper_connections (bool): Use HyperConnectionTransformerLayer with + HyperConnectionModule instead of plain TransformerLayer. Defaults to False. Returns: TransformerLayerSubmodules: Megatron-Core modules to construct a TransformerLayer @@ -402,6 +419,8 @@ def get_gpt_layer_local_submodules( backend=backend, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm ) + hc_module = HyperConnectionModule if enable_hyper_connections else IdentityOp + if multi_latent_attention: assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." 
return TransformerLayerSubmodules( @@ -422,9 +441,11 @@ def get_gpt_layer_local_submodules( ), ), self_attn_bda=get_bias_dropout_add, + self_attention_hyper_connection=hc_module, pre_mlp_layernorm=layer_norm, mlp=mlp, mlp_bda=get_bias_dropout_add, + mlp_hyper_connection=hc_module, ) else: return TransformerLayerSubmodules( @@ -445,9 +466,11 @@ def get_gpt_layer_local_submodules( ), ), self_attn_bda=get_bias_dropout_add, + self_attention_hyper_connection=hc_module, pre_mlp_layernorm=layer_norm, mlp=mlp, mlp_bda=get_bias_dropout_add, + mlp_hyper_connection=hc_module, sharded_state_dict_keys_map={ "input_layernorm.": "self_attention.linear_qkv.layer_norm_", "pre_mlp_layernorm.": "mlp.linear_fc1.layer_norm_", @@ -458,8 +481,10 @@ def get_gpt_layer_local_submodules( @copy_signature(get_gpt_layer_local_submodules) def get_gpt_layer_local_spec(*args, **kwargs) -> ModuleSpec: """Use this spec for an implementation using only modules in Megatron-Core.""" + enable_hc = kwargs.get('enable_hyper_connections', False) + layer_module = HyperConnectionTransformerLayer if enable_hc else TransformerLayer return ModuleSpec( - module=TransformerLayer, submodules=get_gpt_layer_local_submodules(*args, **kwargs) + module=layer_module, submodules=get_gpt_layer_local_submodules(*args, **kwargs) ) @@ -568,6 +593,7 @@ def get_gpt_decoder_layer_specs( use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False), + enable_hyper_connections=config.enable_hyper_connections, ) moe_layer_spec = get_gpt_layer_with_transformer_engine_spec( num_experts=config.num_moe_experts, @@ -580,6 +606,7 @@ def get_gpt_decoder_layer_specs( use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False), + enable_hyper_connections=config.enable_hyper_connections, ) elif config.transformer_impl == "inference_optimized": layer_norm_impl = TENorm @@ -608,6 +635,7 @@ def get_gpt_decoder_layer_specs( use_kitchen=config.use_kitchen, use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, + enable_hyper_connections=config.enable_hyper_connections, ) moe_layer_spec = get_gpt_layer_local_spec( num_experts=config.num_moe_experts, @@ -619,6 +647,7 @@ def get_gpt_decoder_layer_specs( use_kitchen=config.use_kitchen, use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, + enable_hyper_connections=config.enable_hyper_connections, ) # Parse config.moe_layer_freq to determine the pattern of expert/dense layers. @@ -744,12 +773,22 @@ def get_gpt_mtp_block_spec_for_backend( if isinstance(spec, TransformerBlockSubmodules): # get the spec for the last layer of decoder block - transformer_layer_spec = spec.layer_specs[-1] - elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer: - transformer_layer_spec = spec + transformer_layer_spec = copy.copy(spec.layer_specs[-1]) + elif isinstance(spec, ModuleSpec) and issubclass(spec.module, TransformerLayer): + transformer_layer_spec = copy.copy(spec) else: raise ValueError(f"Invalid spec: {spec}") + transformer_layer_spec.submodules = copy.copy(transformer_layer_spec.submodules) + + # MTP does not support hyper connections yet; strip HC modules and + # downgrade the layer class to plain TransformerLayer. 
+ transformer_layer_spec.submodules.self_attention_hyper_connection = IdentityOp + transformer_layer_spec.submodules.cross_attention_hyper_connection = IdentityOp + transformer_layer_spec.submodules.mlp_hyper_connection = IdentityOp + if transformer_layer_spec.module is HyperConnectionTransformerLayer: + transformer_layer_spec.module = TransformerLayer + mtp_layer_spec = get_mtp_layer_spec_for_backend( mtp_model_layer_spec=transformer_layer_spec, backend=backend ) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index 92d39ba92ef..e0c243b26d2 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Parts of the code here are adapted from PyTorch # repo: https://github.com/pytorch/pytorch @@ -598,7 +598,9 @@ def forward( @staticmethod def backward(ctx, *args): """Backward pass.""" - if not torch.autograd._is_checkpoint_valid(): + from megatron.core.transformer.cuda_graphs import is_graph_capturing + + if not torch.autograd._is_checkpoint_valid() and not is_graph_capturing(): raise RuntimeError( "Checkpointing is not compatible with .grad(), " "please use .backward() if possible" @@ -642,10 +644,67 @@ def checkpoint( return CheckpointFunction.apply(function, distribute_saved_activations, *args) +def _save_args_to_ctx(ctx, args): + """Save mixed tensor/non-tensor arguments into autograd ctx. + + Since save_for_backward only supports tensors, this function separates + tensor and non-tensor arguments, saving tensors via save_for_backward + and storing non-tensor metadata (indices and values) as ctx attributes. + + Use _load_args_from_ctx to reconstruct the original args. + """ + tensor_args = [] + non_tensor_entries = [] + + for index, arg in enumerate(args): + if isinstance(arg, torch.Tensor): + tensor_args.append(arg) + continue + non_tensor_entries.append((index, arg)) + + ctx.save_for_backward(*detach_variable(tuple(tensor_args))) + ctx._non_tensor_entries = tuple(non_tensor_entries) + ctx._total_args_count = len(args) + + +def _load_args_from_ctx(ctx): + """Load and reconstruct mixed tensor/non-tensor arguments from autograd ctx. + + This is the inverse of _save_args_to_ctx. It retrieves tensors from + ctx.saved_tensors and merges them with stored non-tensor arguments + to reconstruct the original args in their original order. + + Returns: + tuple of reconstructed arguments in their original order. + """ + + def _detach_with_grad(tensor): + detached = tensor.detach() + detached.requires_grad_(tensor.requires_grad) + return detached + + tensor_iter = iter(_detach_with_grad(t) for t in ctx.saved_tensors) + total_args_count = ctx._total_args_count + non_tensor_map = dict(ctx._non_tensor_entries) + + reconstructed_args = [] + for index in range(total_args_count): + if index in non_tensor_map: + reconstructed_args.append(non_tensor_map[index]) + else: + reconstructed_args.append(next(tensor_iter)) + return tuple(reconstructed_args) + + class CheckpointWithoutOutputFunction(torch.autograd.Function): """ Checkpoint Function Helper for CheckpointWithoutOutput. Save context for recompute. + + Handles both tensor and non-tensor arguments: + - Tensor arguments are saved via save_for_backward + - Non-tensor arguments (int, float, bool, None, etc.) 
are stored separately + in ctx attributes and reconstructed during recomputation """ @staticmethod @@ -668,7 +727,10 @@ def forward( with torch.no_grad(), fwd_ctx: outputs = run_function(*args) - ctx.save_for_backward(*detach_variable(args)) + + # Save tensor and non-tensor arguments into ctx for recomputation + _save_args_to_ctx(ctx, args) + # the CheckpointWithoutOutput object is passed in, then it can access the saved input # tensors later for recomputation checkpoint_without_output_obj.ctx = ctx @@ -685,10 +747,63 @@ def backward(ctx, *args): torch.autograd.backward(outputs, args) ctx.outputs = None ctx.inputs = None - grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs) + # Autograd expects None for non-tensor inputs; returning the original + # non-tensor value as a "gradient" is invalid once args are restored. + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs) return (None, None) + grads +class MHCRecomputeManager: + """ + Manages cross-layer mHC CheckpointWithoutOutput recomputation blocks. + + This is particularly useful when multiple checkpointed operations have sequential + dependencies (i.e., the output of one checkpoint is the input of the next). + The manager recomputes checkpoints in forward order so each checkpoint restores + its output storage before the next checkpoint consumes it during recomputation. + + Usage: + mhc_manager = MHCRecomputeManager() + ckpt_function = CheckpointWithoutOutput(ckpt_manager=mhc_manager) + ckpt_function.checkpoint(run_function, *args) + # other checkpointed operations + mhc_manager.discard_all_outputs_and_register_unified_recompute(final_output) + """ + + def __init__(self): + self.checkpoints = [] + # Set by TransformerBlock before each layer forward. + # When True, the layer should keep block-boundary output uncheckpointed. + self.is_last_layer_in_recompute_block = False + + def add_checkpoint(self, ckpt): + """Add a checkpoint to the manager.""" + if not isinstance(ckpt, CheckpointWithoutOutput): + raise TypeError("Expected CheckpointWithoutOutput object") + if ckpt.outputs is None: + raise ValueError("CheckpointWithoutOutput must call checkpoint() before adding") + self.checkpoints.append(ckpt) + + def discard_all_outputs_and_register_unified_recompute(self, hook_tensor): + """Discard all checkpoint outputs to save memory and register unified recompute hook.""" + if not hook_tensor.requires_grad: + return + + for ckpt in self.checkpoints: + for output in ckpt.outputs: + output.untyped_storage().resize_(0) + + hook_tensor.register_hook(self._unified_recompute_hook) + + def _unified_recompute_hook(self, grad_output): + # Chained checkpoints rely on forward-order recomputation. Each + # _recompute() restores output storage in-place on the original tensor + # objects, so a later checkpoint sees its saved input storage restored + # before it recomputes. + for ckpt in self.checkpoints: + ckpt._recompute(None) + + class CheckpointWithoutOutput(object): """ Checkpoint a model or part of the model and release the output. @@ -703,8 +818,22 @@ class CheckpointWithoutOutput(object): discarded output tensors are directly saved in the following modules for backward computation. """ - def __init__(self, fp8=False): - self.fp8 = fp8 is not None + def __init__(self, fp8=False, ckpt_manager=None): + """ + Initialize CheckpointWithoutOutput. + + Args: + fp8: Whether to use FP8 mode. Defaults to False. + ckpt_manager: Optional MHCRecomputeManager instance. 
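A self-contained sketch of the round trip performed by _save_args_to_ctx/_load_args_from_ctx above, using a stand-in ctx object (hypothetical; the real code receives torch.autograd.Function's ctx):

import torch
from megatron.core.tensor_parallel.random import _load_args_from_ctx, _save_args_to_ctx

class _FakeCtx:
    # Minimal stand-in for an autograd ctx: save_for_backward just stores tensors.
    def save_for_backward(self, *tensors):
        self.saved_tensors = tensors

ctx = _FakeCtx()
args = (torch.randn(2, 3, requires_grad=True), 0.1, None, torch.ones(4), True)
_save_args_to_ctx(ctx, args)           # tensors -> saved_tensors, the rest -> ctx attributes
restored = _load_args_from_ctx(ctx)    # original order, types and requires_grad are preserved
assert restored[1] == 0.1 and restored[2] is None and restored[4] is True
assert torch.equal(restored[3], args[3]) and restored[0].requires_grad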
When provided, + checkpoint() will auto-register to the manager, and + discard_output_and_register_recompute() will only discard + output without registering individual hooks. + """ + # Intentional bug fix: the default fp8=False must not enter the TE FP8 + # recompute context. The old "fp8 is not None" behavior did that for + # every default CheckpointWithoutOutput() caller. + self.fp8 = bool(fp8) + self.ckpt_manager = ckpt_manager self.run_function = None self.fwd_cpu_rng_state = None self.fwd_cuda_rng_state = None @@ -713,7 +842,12 @@ def __init__(self, fp8=False): self.outputs = None def checkpoint(self, run_function: Callable[[Unpack[_Ts]], _R], *args: Unpack[_Ts]) -> _R: - """Checkpoint function.""" + """ + Checkpoint function. + + If ckpt_manager was provided during initialization, this checkpoint + will be automatically registered to the manager after execution. + """ # If in cuda graph warmup, disable checkpointing, as 'discard_output_and_register_recompute' # may be called in a separate graph warmup. @@ -730,6 +864,11 @@ def checkpoint(self, run_function: Callable[[Unpack[_Ts]], _R], *args: Unpack[_T self.outputs = outputs if isinstance(self.outputs, torch.Tensor): self.outputs = (self.outputs,) + + # Auto-register to manager if provided + if self.ckpt_manager is not None: + self.ckpt_manager.add_checkpoint(self) + return outputs def _recompute(self, _): @@ -738,7 +877,7 @@ def _recompute(self, _): from megatron.core.transformer.cuda_graphs import is_graph_capturing, is_graph_warmup # The recomputation has been triggered already. Just return. - # Handle cudagraphs, do nothing if currently in graph warmup + # Handle cudagraphs: do nothing if currently in graph warmup if self.ctx is None or is_graph_warmup(): return @@ -760,17 +899,8 @@ def _recompute(self, _): recompute_ctx = contextlib.nullcontext() fp8_ctx = contextlib.nullcontext() - # Store the inputs for backward pass - inputs = self.ctx.saved_tensors - - def detach(t): - if isinstance(t, torch.Tensor): - requires_grad = t.requires_grad - t = t.detach() - t.requires_grad_(requires_grad) - return t - - inputs = tuple(detach(t) for t in inputs) + # Reconstruct full args list from saved ctx + inputs = _load_args_from_ctx(self.ctx) with torch.enable_grad(), fp8_ctx, recompute_ctx: outputs = self.run_function(*inputs) @@ -803,10 +933,11 @@ def discard_output_and_register_recompute(self, hook_tensor): in the forward pass and the gradient of the hook_tensor is computed before the recomputed tensors are used. """ - + # When ckpt_manager is set, this is a no-op. + # Manager handles all discarding and hook registration uniformly. from megatron.core.transformer.cuda_graphs import is_graph_warmup - if is_graph_warmup(): + if self.ckpt_manager is not None or is_graph_warmup(): return # use resize to release the output tensor memory and still keep the metadata in the tensors. diff --git a/megatron/core/transformer/__init__.py b/megatron/core/transformer/__init__.py index 0e3cdcfa57e..75e3b485c4f 100644 --- a/megatron/core/transformer/__init__.py +++ b/megatron/core/transformer/__init__.py @@ -1,6 +1,10 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
from .module import MegatronModule from .spec_utils import ModuleSpec, build_module from .transformer_config import MLATransformerConfig, TransformerConfig -from .transformer_layer import TransformerLayer, TransformerLayerSubmodules +from .transformer_layer import ( + HyperConnectionTransformerLayer, + TransformerLayer, + TransformerLayerSubmodules, +) diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 067f6055015..af5a2e35672 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import dataclasses import gc diff --git a/megatron/core/transformer/hyper_connection.py b/megatron/core/transformer/hyper_connection.py new file mode 100644 index 00000000000..a2f2d013977 --- /dev/null +++ b/megatron/core/transformer/hyper_connection.py @@ -0,0 +1,707 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import math +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +from torch import Tensor + +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import nvtx_decorator + +if TYPE_CHECKING: + from megatron.core.tensor_parallel.random import MHCRecomputeManager + + +@torch.compile +def _sinkhorn_iterations(input_logits: Tensor, num_iterations: int, eps: float) -> Tensor: + input_dtype = input_logits.dtype + input_logits = input_logits.float() + row_max = input_logits.max(dim=-1, keepdim=True).values + M = torch.exp(input_logits - row_max) + for _ in range(num_iterations): + M = M / M.sum(dim=-1, keepdim=True).clamp(min=eps) + M = M / M.sum(dim=-2, keepdim=True).clamp(min=eps) + return M.to(dtype=input_dtype) + + +class SinkhornKnopp(torch.autograd.Function): + """Sinkhorn-Knopp projection to doubly stochastic matrix. + + This is an autograd.Function because the iterative forward is re-executed + during backward (under torch.enable_grad) so that PyTorch's autograd can + differentiate through it without storing all intermediate iteration states. 
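A rough sanity check of the projection (plain PyTorch, small shapes): after enough iterations the result should be close to doubly stochastic, and gradients flow through the recompute-in-backward path. reference_sinkhorn is the wrapper defined just below.

import torch
from megatron.core.transformer.hyper_connection import reference_sinkhorn

logits = torch.randn(2, 3, 4, 4, requires_grad=True)   # [s, b, n, n]
M = reference_sinkhorn(logits, num_iterations=20)
print(M.sum(dim=-1))    # row sums ~ 1
print(M.sum(dim=-2))    # column sums ~ 1
M.sum().backward()      # backward recomputes the iterations under enable_grad
assert logits.grad is not None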
+ """ + + @staticmethod + def forward(ctx, input_logits: Tensor, num_iterations: int, eps: float = 1e-6) -> Tensor: + """Run Sinkhorn iterations and save inputs for backward recomputation.""" + M = _sinkhorn_iterations(input_logits, num_iterations, eps) + ctx.save_for_backward(input_logits) + ctx.num_iterations = num_iterations + ctx.eps = eps + return M + + @staticmethod + def backward(ctx, grad_output: Tensor): + """Recompute forward under enable_grad and back-propagate.""" + (input_logits,) = ctx.saved_tensors + with torch.enable_grad(): + logits = input_logits.detach().requires_grad_(True) + M = _sinkhorn_iterations(logits, ctx.num_iterations, ctx.eps) + M.backward(grad_output) + return logits.grad, None, None + + +def reference_sinkhorn(input_logits: Tensor, num_iterations: int, eps: float = 1e-6) -> Tensor: + """Reference Sinkhorn-Knopp (autograd.Function wrapper).""" + return SinkhornKnopp.apply(input_logits, num_iterations, eps) + + +@torch.compile +def reference_h_aggregate(x: Tensor, h_pre: Tensor) -> Tensor: + """Reference n-stream weighted aggregation: out = sum_j(h_pre_j * x_j).""" + return (x * h_pre.unsqueeze(-1)).sum(dim=2) + + +@torch.compile +def reference_h_post_bda( + h_res: Tensor, original_residual: Tensor, h_post: Tensor, x: Tensor, bias: Optional[Tensor] +) -> Tensor: + """Reference H_res @ residual + H_post * (x [+ bias]), flattened to [s, b, n*C].""" + s, b, n, C = original_residual.shape + h_res_batched = h_res.view(s * b, n, n) + residual_batched = original_residual.view(s * b, n, C) + mixed = torch.bmm(h_res_batched, residual_batched).view(s, b, n, C) + x_expanded = h_post.unsqueeze(-1) * x.unsqueeze(2) + if bias is not None: + bias_expanded = h_post.unsqueeze(-1) * bias.view(1, 1, 1, C) + return (x_expanded + bias_expanded + mixed).view(s, b, n * C) + return (x_expanded + mixed).view(s, b, n * C) + + +@torch.compile +def reference_proj_inv_rms(x: Tensor, weight: Tensor, eps: float = 1e-6) -> Tuple[Tensor, Tensor]: + """Reference fused projection + inverse-RMS normalization scale.""" + input_dtype = x.dtype + x_float = x.float() + weight_float = weight.float() + proj = torch.matmul(x_float, weight_float.t()).to(dtype=input_dtype) + norm = x_float.norm(dim=-1, keepdim=True) + K = x.shape[-1] + rms = norm / math.sqrt(K) + eps + inv_rms = (1.0 / rms).to(dtype=input_dtype) + return proj, inv_rms + + +# ============================================================================ +# HyperConnectionModule +# ============================================================================ + + +class HyperConnectionModule(MegatronModule): + """ + Unified mHC (Manifold-Constrained Hyper-Connections) module. + + Implements the complete mHC propagation: + x_{l+1} = H_res @ x_l + H_post^T @ F(H_pre @ x_l) + + This module handles: + 1. Computing learnable mappings: H_pre, H_post, H_res (with Sinkhorn-Knopp projection) + 2. Aggregation: n-stream → 1-stream (H_pre @ x) + 3. Expansion: 1-stream → n-stream (H_post^T @ output) + 4. Residual merge: H_res @ x + expanded_output + 5. 
Block-level expand/contract for TransformerBlock boundaries + + Args: + config: TransformerConfig with hyper-connection fields + layer_number: Current layer index for initialization + """ + + def __init__(self, config: TransformerConfig, layer_number: int): + super().__init__(config) + self.config = config + self.layer_number = layer_number + self.n = config.num_residual_streams + self.hidden_size = config.hidden_size + self.sinkhorn_iterations = config.mhc_sinkhorn_iterations + + # Projection weights for dynamic mappings. The reference implementation + # keeps this as a full, non-TP-partitioned linear projection over n*C, so + # TP ranks duplicate the projection compute and store the n*C activation + # for the weight gradient. TODO: replace with fused/partitioned variants + # in the fused mHC follow-up. + # Input: [s, b, n*C] -> Output: n^2 + 2n values per token + # - H_pre: n values + # - H_post: n values + # - H_res: n^2 values (before Sinkhorn projection) + self.mapping_proj = nn.Linear( + self.n * self.hidden_size, self.n * self.n + 2 * self.n, bias=False + ) + + init_alpha = config.mhc_init_gating_factor + # Learnable scaling factors (Eq. 5 in paper) + self.alpha_pre = nn.Parameter(torch.full((1,), init_alpha)) + self.alpha_post = nn.Parameter(torch.full((1,), init_alpha)) + self.alpha_res = nn.Parameter(torch.full((1,), init_alpha)) + + # Static bias terms + self.bias = nn.Parameter(torch.zeros(self.n * self.n + 2 * self.n)) + self.norm_eps = 1e-6 + + # Choose implementation: fused cuTile kernels vs reference modules. + # Both paths expose the same call signatures so the rest of the code + # is implementation-agnostic. + if config.use_fused_mhc: + from megatron.core.fusions.fused_mhc_kernels import ( + fused_h_aggregate, + fused_h_post_bda, + fused_proj_rms, + fused_sinkhorn, + ) + + self._sinkhorn_op = fused_sinkhorn + self._h_aggregate_op = fused_h_aggregate + self._h_post_bda_op = fused_h_post_bda + self._proj_inv_rms_op = fused_proj_rms + else: + self._sinkhorn_op = reference_sinkhorn + self._h_aggregate_op = reference_h_aggregate + self._h_post_bda_op = reference_h_post_bda + self._proj_inv_rms_op = reference_proj_inv_rms + + self._init_weights() + + def _init_weights(self) -> None: + """Initialize weights for stable training.""" + nn.init.xavier_uniform_(self.mapping_proj.weight) + + # Set sequence_parallel attribute on parameters for gradient synchronization + # across TP ranks when sequence_parallel is enabled. + # This is required because HyperConnectionModule uses non-TP-aware layers + # (nn.Linear, nn.RMSNorm) whose gradients need to be all-reduced. + if self.config.sequence_parallel: + self.mapping_proj.weight.sequence_parallel = True + self.alpha_pre.sequence_parallel = True + self.alpha_post.sequence_parallel = True + self.alpha_res.sequence_parallel = True + self.bias.sequence_parallel = True + + def _projection_and_get_norm(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """ + Projection + RMS normalization. + + Args: + x: [s, b, n*C] - n-stream hidden states + """ + s, b, nC = x.shape + x_2d = x.reshape(s * b, nC) + proj, inv_rms = self._proj_inv_rms_op(x_2d, self.mapping_proj.weight, self.norm_eps) + return proj.view(s, b, -1), inv_rms.view(s, b, 1) + + @torch.compile + def _compute_h(self, proj: Tensor, inv_rms: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Compute h from projected hidden states and scaling factors. 
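Shape-wise, the split that _compute_h performs on the projected values can be sketched with plain tensors (illustrative; the real code also applies the alpha scaling, bias and inverse-RMS factor first):

import torch

s, b, n = 3, 2, 4
h = torch.randn(s, b, n * n + 2 * n)             # stands in for inv_rms * proj * alpha + bias
h_pre = h[..., :n].sigmoid()                     # [s, b, n]    aggregation weights in (0, 1)
h_post = 2 * h[..., n:2 * n].sigmoid()           # [s, b, n]    expansion weights in (0, 2)
h_res_logits = h[..., 2 * n:].view(s, b, n, n)   # [s, b, n, n] mixing logits, Sinkhorn-projected next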
+ + Args: + proj: [s, b, n^2 + 2n] - projected hidden states + inv_rms: [s, b, 1] - inverse-RMS scaling factors + + Returns: + h_pre: [s, b, n] - aggregation weights + h_post: [s, b, n] - expansion weights + h_res: [s, b, n^2] - residual mixing logits + """ + alpha_ = torch.cat( + [ + self.alpha_pre.expand(self.n), + self.alpha_post.expand(self.n), + self.alpha_res.expand(self.n * self.n), + ], + dim=-1, + ) + h = inv_rms.float() * proj.float() * alpha_.float() + self.bias.float() + # H_pre = σ(α_pre * (θ_pre @ x̃) + b_pre) + h_pre = h[..., : self.n].sigmoid().to(dtype=proj.dtype) # [s, b, n] + + # H_post = 2σ(α_post * (θ_post @ x̃) + b_post) + h_post = (h[..., self.n : 2 * self.n].sigmoid() * 2).to(dtype=proj.dtype) # [s, b, n] + h_res = h[..., 2 * self.n :].to(dtype=proj.dtype) + return h_pre, h_post, h_res + + @nvtx_decorator(message="HyperConnection::compute_mappings") + def compute_mappings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Compute mHC mappings from input hidden states. + + Reference: Eq. (5) and (8) in mHC paper + + Args: + x: [s, b, n*C] - n-stream hidden states + + Returns: + h_pre: [s, b, n] - aggregation weights (sigmoid activated) + h_post: [s, b, n] - expansion weights (2*sigmoid activated) + h_res: [s, b, n, n] - residual mixing matrix (doubly stochastic) + """ + s, b, _ = x.shape + with torch.cuda.nvtx.range("HyperConnection::projection_and_get_norm"): + proj, inv_rms = self._projection_and_get_norm(x) + with torch.cuda.nvtx.range("HyperConnection::compute_h"): + h_pre, h_post, h_res = self._compute_h(proj, inv_rms) + h_res = self._sinkhorn_op( + h_res.view(s, b, self.n, self.n), self.sinkhorn_iterations, self.norm_eps + ) # [s, b, n, n] + + return h_pre, h_post, h_res + + @torch.compile + def _apply_h_post_hidden(self, x: Tensor, h_post: Tensor) -> Tensor: + """ + Apply H_post to hidden states. + + Computes: H_post^T @ x + + Args: + x: [s, b, C] - standard hidden states + h_post: [s, b, n] - expansion weights + + Returns: + output: [s, b, n*C] - expanded tensor + """ + n = self.n + s, b, _ = h_post.shape + C = x.shape[-1] + x_expanded = x.unsqueeze(2) # [s, b, 1, C] + + # h_post^T @ x : [s, b, n, 1] * [s, b, 1, C] -> [s, b, n, C] + result = h_post.unsqueeze(-1) * x_expanded + return result.view(s, b, n * C) + + @torch.compile + def _apply_h_post_bias(self, bias: Tensor, h_post: Tensor) -> Tensor: + """ + Apply H_post to a bias vector. + + Args: + bias: [C] - bias tensor broadcast across sequence and batch + h_post: [s, b, n] - expansion weights + + Returns: + output: [s, b, n*C] - expanded bias + """ + n = self.n + s, b, _ = h_post.shape + C = bias.shape[0] + bias_expanded = bias.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(s, b, 1, C) + + # h_post^T @ x : [s, b, n, 1] * [s, b, 1, C] -> [s, b, n, C] + result = h_post.unsqueeze(-1) * bias_expanded + return result.view(s, b, n * C) + + @nvtx_decorator(message="HyperConnection::apply_h_post") + def apply_h_post( + self, + x_with_bias: Tuple[Tensor, Optional[Tensor]], + h_post: Tensor, + manager: Optional['MHCRecomputeManager'] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Apply H_post to x and optionally bias, with optional checkpointing. + + This is the unified entry point that handles both normal execution + and checkpoint-based execution for memory efficiency. 
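The H_post expansion is a broadcast outer product; a small plain-tensor sketch of the shapes (each output stream i is h_post[..., i] * x):

import torch

s, b, n, C = 2, 3, 4, 8
x = torch.randn(s, b, C)
h_post = torch.rand(s, b, n)
expanded = (h_post.unsqueeze(-1) * x.unsqueeze(2)).view(s, b, n * C)   # [s,b,n,1] * [s,b,1,C]
assert torch.allclose(expanded.view(s, b, n, C)[:, :, 1], h_post[..., 1:2] * x)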
+ + Args: + x_with_bias: Tuple of (x, bias) where: + - x: [s, b, C] - hidden states + - bias: [C] or None - optional bias tensor + h_post: [s, b, n] - expansion weights + manager: Optional MHCRecomputeManager for checkpoint management. + When provided, wraps _apply_h_post with CheckpointWithoutOutput. + + Returns: + Tuple of (x_out, bias_out) where: + - x_out: [s, b, n*C] - expanded hidden states + - bias_out: [s, b, n*C] or None - expanded bias if input bias was not None + """ + x, bias = x_with_bias + + if manager is not None: + from megatron.core.tensor_parallel.random import CheckpointWithoutOutput + + # Checkpoint H_post application to discard the output + x_out = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint( + self._apply_h_post_hidden, x, h_post + ) + + # Checkpoint H_post bias expansion if not None + if bias is not None: + bias_out = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint( + self._apply_h_post_bias, bias, h_post + ) + else: + bias_out = None + else: + # Normal execution without checkpoint + x_out = self._apply_h_post_hidden(x, h_post) + bias_out = self._apply_h_post_bias(bias, h_post) if bias is not None else None + + return x_out, bias_out + + def aggregate(self, x: Tensor, h_pre: Tensor) -> Tensor: + """ + Aggregate n-stream to 1-stream. + + Args: + x: [s, b, n*C] - n-stream hidden states + h_pre: [s, b, n] - aggregation weights + + Returns: + aggregated: [s, b, C] - single stream hidden states + """ + s, b, _ = x.shape + C = self.hidden_size + x_streams = x.view(s, b, self.n, C) + return self._h_aggregate_op(x_streams, h_pre) + + @torch.compile + def apply_h_res(self, h_res: Tensor, residual: Tensor) -> Tensor: + """ + Apply H_res to residual using H_res weights. + + Computes: H_res @ residual + + Args: + h_res: [s, b, n, n] - residual mixing matrix + residual: [s, b, n*C] - n-stream hidden states + """ + s, b, _ = residual.shape + n = self.n + C = self.hidden_size + + # Reshape for bmm: [s, b, n, n] -> [s*b, n, n] + h_res_batched = h_res.view(s * b, n, n) + # [s, b, n*C] -> [s, b, n, C] -> [s*b, n, C] + residual_batched = residual.view(s, b, n, C).view(s * b, n, C) + + # Batch matrix multiply: [s*b, n, n] @ [s*b, n, C] -> [s*b, n, C] + mixed = torch.bmm(h_res_batched, residual_batched) + + return mixed.view(s, b, n * C) + + def forward( + self, hidden_states: Tensor, mhc_recompute_manager: Optional['MHCRecomputeManager'] = None + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Full mHC forward pass. + + Args: + hidden_states: [s, b, n*C] - n-stream hidden states + mhc_recompute_manager: Optional MHCRecomputeManager for checkpoint management. + When provided, uses _forward_with_checkpoint for memory-efficient execution. + + Returns: + aggregated: [s, b, C] - aggregated input for layer computation + h_res: [s, b, n, n] - residual mixing matrix (for fused kernel) + h_post: [s, b, n] - expansion weights + """ + if mhc_recompute_manager is not None: + return self._forward_with_checkpoint(hidden_states, mhc_recompute_manager) + else: + return self._forward_normal(hidden_states) + + def _forward_normal(self, hidden_states: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Normal forward pass without checkpointing. 
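An end-to-end shape walk of one mHC step with plain tensors, following the formula in the class docstring (the layer function F is replaced by an identity placeholder; the real module runs attention/MLP on the aggregated stream):

import torch

s, b, n, C = 3, 2, 4, 8
x = torch.randn(s, b, n * C)                                  # n-stream hidden states
h_pre, h_post = torch.rand(s, b, n), torch.rand(s, b, n)
h_res = torch.rand(s, b, n, n)

aggregated = (x.view(s, b, n, C) * h_pre.unsqueeze(-1)).sum(dim=2)              # H_pre @ x -> [s, b, C]
layer_out = aggregated                                                           # stand-in for F(...)
expanded = (h_post.unsqueeze(-1) * layer_out.unsqueeze(2)).view(s, b, n * C)     # H_post^T @ F(...)
mixed = torch.einsum('sbij,sbjc->sbic', h_res, x.view(s, b, n, C)).reshape(s, b, n * C)  # H_res @ x
x_next = mixed + expanded    # x_{l+1} = H_res @ x_l + H_post^T @ F(H_pre @ x_l)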
+ + Args: + hidden_states: [s, b, n*C] - n-stream hidden states + + Returns: + aggregated: [s, b, C] - aggregated input for layer computation + h_res: [s, b, n, n] - residual mixing matrix (for fused kernel) + h_post: [s, b, n] - expansion weights + """ + # Compute mappings + h_pre, h_post, h_res = self.compute_mappings(hidden_states) + + # Aggregate for layer input + with torch.cuda.nvtx.range("HyperConnection::aggregate"): + aggregated = self.aggregate(hidden_states, h_pre) + + return aggregated, h_res, h_post + + def _forward_with_checkpoint( + self, hidden_states: Tensor, manager: 'MHCRecomputeManager' + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Forward pass with checkpointing for memory efficiency. + + compute_mappings is called directly (not checkpointed) since its outputs + (h_pre, h_post, h_res) are needed downstream. Only aggregate is wrapped with + CheckpointWithoutOutput and auto-registered to the manager. + apply_h_res is deferred to h_res_h_post_bda for kernel fusion. + + Args: + hidden_states: [s, b, n*C] - n-stream hidden states + manager: MHCRecomputeManager for unified recomputation + + Returns: + aggregated: [s, b, C] - aggregated input for layer computation + h_res: [s, b, n, n] - residual mixing matrix (for fused kernel) + h_post: [s, b, n] - expansion weights + """ + from megatron.core.tensor_parallel.random import CheckpointWithoutOutput + + h_pre, h_post, h_res = self.compute_mappings(hidden_states) + + # Checkpoint aggregate - auto-registers to manager + aggregated = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint( + self.aggregate, hidden_states, h_pre + ) + + return aggregated, h_res, h_post + + # ==================== Block-level utilities ==================== + + @staticmethod + def input_expand(x: Tensor, n: int) -> Tensor: + """ + Expand 1-stream to n-stream at TransformerBlock entry. + + Simple replication strategy: each stream initialized as a copy of input. + + Args: + x: [s, b, C] - single stream hidden states + n: Number of residual streams + + Returns: + expanded: [s, b, n*C] - n-stream hidden states + """ + s, b, C = x.shape + # expand() is a view; contiguous() intentionally materializes the n + # streams, so the reference path allocates an n*C activation here. + expanded = x.unsqueeze(2).expand(s, b, n, C).contiguous() + return expanded.view(s, b, n * C) + + @staticmethod + def output_contract(x: Tensor, n: int) -> Tensor: + """ + Contract n-stream to 1-stream at TransformerBlock exit. + + Simple averaging strategy: average all streams. + + Args: + x: [s, b, n*C] - n-stream hidden states + n: Number of residual streams + + Returns: + contracted: [s, b, C] - single stream hidden states + """ + s, b, nC = x.shape + C = nC // n + # Average all streams + x_streams = x.view(s, b, n, C) + contracted = x_streams.mean(dim=2) + return contracted + + # ==================== Combined H_res + H_post + BDA path ==================== + + @nvtx_decorator(message="HyperConnection::h_res_h_post_bda") + def h_res_h_post_bda( + self, + h_res: Tensor, + original_residual: Tensor, + h_post: Tensor, + layer_output_with_bias: Tuple[Tensor, Optional[Tensor]], + dropout_prob: float, + training: bool, + fused: bool, + manager: Optional['MHCRecomputeManager'] = None, + ) -> Tensor: + """ + Combine apply_h_res, apply_h_post and bias-dropout-add. + + This is a reference implementation that uses native PyTorch for the + dropout path. Actual fused kernels are selected through _h_post_bda_op + when dropout is disabled or training is off. + + The computation flow is: + 1. 
mixed = H_res @ original_residual (apply_h_res) + 2. expanded = H_post^T @ layer_output (apply_h_post) + 3. output = dropout(expanded + bias) + mixed (bias-dropout-add) + + Args: + h_res: [s, b, n, n] - residual mixing matrix + original_residual: [s, b, n*C] - n-stream hidden states (before H_res applied) + h_post: [s, b, n] - expansion weights + layer_output_with_bias: Tuple of (x, bias) where: + - x: [s, b, C] - layer output (attention or MLP output) + - bias: [C] or None - optional bias tensor + dropout_prob: Dropout probability + training: Whether in training mode + fused: Whether to use fused BDA implementation + manager: Optional MHCRecomputeManager for checkpoint management. + When provided, each operation is wrapped with CheckpointWithoutOutput. + + Returns: + output: [s, b, n*C] - final output after all operations + """ + if manager is not None: + return self._h_res_h_post_bda_with_checkpoint( + h_res, + original_residual, + h_post, + layer_output_with_bias, + dropout_prob, + training, + fused, + manager, + ) + else: + return self._h_res_h_post_bda_native( + h_res, + original_residual, + h_post, + layer_output_with_bias, + dropout_prob, + training, + fused, + ) + + def _h_res_h_post_bda_native( + self, + h_res: Tensor, + original_residual: Tensor, + h_post: Tensor, + layer_output_with_bias: Tuple[Tensor, Optional[Tensor]], + dropout_prob: float, + training: bool, + fused: bool, + ) -> Tensor: + """ + h_res, h_post and bda. + + When dropout is zero (or inference), uses a single fused/reference kernel + for H_res @ residual + H_post * (x + bias). Falls back to unfused + implementation when dropout is needed. + + Args: + h_res: [s, b, n, n] - residual mixing matrix + original_residual: [s, b, n*C] - n-stream hidden states + h_post: [s, b, n] - expansion weights + layer_output_with_bias: Tuple of (x, bias) + dropout_prob: Dropout probability + training: Whether in training mode + fused: Whether to use fused BDA implementation + + Returns: + output: [s, b, n*C] - final output + """ + x, bias = layer_output_with_bias + + if dropout_prob == 0.0 or not training: + s, b, _ = original_residual.shape + n = self.n + C = self.hidden_size + orig_reshaped = original_residual.view(s, b, n, C) + return self._h_post_bda_op(h_res, orig_reshaped, h_post, x, bias) + + from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add + + with torch.cuda.nvtx.range("HyperConnection::apply_h_res"): + mixed = self.apply_h_res(h_res, original_residual) + with torch.cuda.nvtx.range("HyperConnection::apply_h_post"): + x_expanded = self._apply_h_post_hidden(x, h_post) + bias_expanded = self._apply_h_post_bias(bias, h_post) if bias is not None else None + bda_func = get_bias_dropout_add(training, fused) + with torch.cuda.nvtx.range("HyperConnection::bda"): + output = bda_func((x_expanded, bias_expanded), mixed, dropout_prob) + return output + + @nvtx_decorator(message="HyperConnection::h_res_h_post_bda_with_checkpoint") + def _h_res_h_post_bda_with_checkpoint( + self, + h_res: Tensor, + original_residual: Tensor, + h_post: Tensor, + layer_output_with_bias: Tuple[Tensor, Optional[Tensor]], + dropout_prob: float, + training: bool, + fused: bool, + manager: 'MHCRecomputeManager', + ) -> Tensor: + """ + Checkpointed variant of _h_res_h_post_bda_native. + + Wraps compute in CheckpointWithoutOutput for activation memory savings. + Cannot reuse _native directly because checkpoint requires all args to be + positional Tensors; tuple/Optional/scalar args are unpacked or captured + via closure instead. 
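A quick consistency check (plain PyTorch) that the single-kernel reference path used on the no-dropout branch matches the sequential H_res / H_post / bias-add composition:

import torch
from megatron.core.transformer.hyper_connection import reference_h_post_bda

s, b, n, C = 2, 2, 3, 4
h_res = torch.rand(s, b, n, n)
residual = torch.randn(s, b, n, C)
h_post = torch.rand(s, b, n)
x = torch.randn(s, b, C)
bias = torch.randn(C)

fused = reference_h_post_bda(h_res, residual, h_post, x, bias)
mixed = torch.einsum('sbij,sbjc->sbic', h_res, residual).reshape(s, b, n * C)
expanded = (h_post.unsqueeze(-1) * (x + bias).unsqueeze(2)).reshape(s, b, n * C)
assert torch.allclose(fused, mixed + expanded, atol=1e-5)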
+ + Args: + h_res: [s, b, n, n] - residual mixing matrix + original_residual: [s, b, n*C] - n-stream hidden states + h_post: [s, b, n] - expansion weights + layer_output_with_bias: Tuple of (x, bias) + dropout_prob: Dropout probability + training: Whether in training mode + fused: Whether to use fused BDA implementation + manager: MHCRecomputeManager for checkpoint management + + Returns: + output: [s, b, n*C] - final output + """ + from megatron.core.tensor_parallel.random import CheckpointWithoutOutput + + x, bias = layer_output_with_bias + n = self.n + C = self.hidden_size + + # Fast path: no dropout — use fused/reference h_post_bda kernel (same as _native) + if dropout_prob == 0.0 or not training: + + def _fused_wrapper(h_res, original_residual, h_post, x, *optional_bias): + s, b, _ = original_residual.shape + orig_reshaped = original_residual.view(s, b, n, C) + b_arg = optional_bias[0] if optional_bias else None + return self._h_post_bda_op(h_res, orig_reshaped, h_post, x, b_arg) + + ckpt = CheckpointWithoutOutput(ckpt_manager=manager) + if bias is not None: + output = ckpt.checkpoint(_fused_wrapper, h_res, original_residual, h_post, x, bias) + else: + output = ckpt.checkpoint(_fused_wrapper, h_res, original_residual, h_post, x) + + # Slow path: dropout required — fused kernel does not support dropout, + # fall back to sequential apply_h_res + apply_h_post + bda + else: + from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add + + bda_func = get_bias_dropout_add(training, fused) + has_bias = bias is not None + + def _reference_wrapper(h_res, original_residual, h_post, x, *optional_bias): + with torch.cuda.nvtx.range("HyperConnection::apply_h_res"): + mixed = self.apply_h_res(h_res, original_residual) + with torch.cuda.nvtx.range("HyperConnection::apply_h_post"): + x_expanded = self._apply_h_post_hidden(x, h_post) + if has_bias: + bias_expanded = self._apply_h_post_bias(optional_bias[0], h_post) + else: + bias_expanded = None + with torch.cuda.nvtx.range("HyperConnection::bda"): + output = bda_func((x_expanded, bias_expanded), mixed, dropout_prob) + return output + + ckpt = CheckpointWithoutOutput(ckpt_manager=manager) + if has_bias: + output = ckpt.checkpoint(_reference_wrapper, h_res, original_residual, h_post, x, bias) + else: + output = ckpt.checkpoint(_reference_wrapper, h_res, original_residual, h_post, x) + + return output diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 8bea3b8c94e..6da42004e2e 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -1,8 +1,9 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ import logging from contextlib import nullcontext from dataclasses import dataclass -from typing import List, Optional, Set, Union, cast +from typing import List, Optional, Set, Tuple, Union, cast import torch from torch import Tensor @@ -19,7 +20,9 @@ from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel.random import MHCRecomputeManager from megatron.core.transformer.enums import CudaGraphScope, LayerType +from megatron.core.transformer.hyper_connection import HyperConnectionModule from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.torch_norm import LayerNormBuilder @@ -319,6 +322,7 @@ def __init__( self.offload_context, self.group_prefetch_offload_commit_async = nullcontext(), None self.config._cpu_offloading_context = None + self.num_residual_streams = config.num_residual_streams self._build_layers() self.num_layers_per_pipeline_rank = len(self.layers) @@ -367,6 +371,7 @@ def build_layer(layer_spec, layer_number): for i, layer_spec in enumerate(self.submodules.layer_specs) ] ) + self._mhc_recompute_block_end_plan = self._build_mhc_recompute_block_end_plan() # @TODO: add back account_for_embedding_in_pipeline_split (see issue #293) # In pipeline parallelism, we want to add this LN only to the last stage of the pipeline @@ -642,6 +647,59 @@ def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs)[0] return super().__call__(*args, **kwargs) + def _build_mhc_recompute_block_end_plan(self) -> List[bool]: + """Precompute deterministic mHC recompute block-end markers.""" + num_layers = len(self.layers) + is_recompute_block_end: List[bool] = [] + + if num_layers == 0: + return is_recompute_block_end + + mhc_recompute_layer_num = self.config.mhc_recompute_layer_num + for l_no in range(num_layers): + is_last_in_transformer_block = l_no == num_layers - 1 + is_last_in_recompute_block = is_last_in_transformer_block + if mhc_recompute_layer_num is not None: + is_last_in_recompute_block = is_last_in_transformer_block or ( + (l_no + 1) % mhc_recompute_layer_num == 0 + ) + + is_recompute_block_end.append(is_last_in_recompute_block) + + return is_recompute_block_end + + def _build_mhc_recompute_layer_plan( + self, use_mhc_recompute: bool + ) -> Tuple[List[Optional[MHCRecomputeManager]], List[bool]]: + """Build fresh per-forward MHC managers using cached block-end topology.""" + num_layers = len(self.layers) + layer_managers: List[Optional[MHCRecomputeManager]] = [None] * num_layers + is_recompute_block_end = self._mhc_recompute_block_end_plan + + if not use_mhc_recompute or num_layers == 0: + return layer_managers, is_recompute_block_end + + mhc_manager = MHCRecomputeManager() + + for l_no, is_last_in_recompute_block in enumerate(is_recompute_block_end): + is_last_in_transformer_block = l_no == num_layers - 1 + layer_managers[l_no] = mhc_manager + + if is_last_in_recompute_block and not is_last_in_transformer_block: + mhc_manager = MHCRecomputeManager() + + return layer_managers, is_recompute_block_end + + @staticmethod + def _finalize_mhc_recompute_layer( + mhc_manager: Optional[MHCRecomputeManager], + hidden_states: Tensor, + is_last_in_recompute_block: bool, + ) -> None: + """Finalize MHC recompute state for the current layer when block ends.""" 
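A worked example of the block-end plan computed above: with 8 layers on this rank and mhc_recompute_layer_num=3, layers 2, 5 and 7 (0-indexed) close a recompute block.

num_layers, mhc_recompute_layer_num = 8, 3
plan = [
    (l_no == num_layers - 1) or ((l_no + 1) % mhc_recompute_layer_num == 0)
    for l_no in range(num_layers)
]
assert [l_no for l_no, is_end in enumerate(plan) if is_end] == [2, 5, 7]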
+ if mhc_manager is not None and is_last_in_recompute_block: + mhc_manager.discard_all_outputs_and_register_unified_recompute(hidden_states) + def forward( self, hidden_states: Union[Tensor, WrappedTensor], @@ -751,6 +809,14 @@ def forward( # is called here to be future-proof and corner-case-proof. hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + # Expand hidden states for hyper connections at the start of the block. + # Pipeline-parallel HC support is blocked in TransformerConfig until P2P + # tensor shapes carry the n-stream hidden dimension. + if self.config.enable_hyper_connections: + hidden_states = HyperConnectionModule.input_expand( + hidden_states, self.num_residual_streams + ) # [s, b, C] -> [s, b, n*C] + if self.config.sequence_parallel: rng_context = tensor_parallel.get_cuda_rng_tracker().fork() else: @@ -778,6 +844,18 @@ def forward( use_inner_quantization_context = False outer_quantization_context = nullcontext() + # Determine if MHC recompute should be used + # Only enable when: training mode AND hyper connections AND 'mhc' in recompute_modules + use_mhc_recompute = ( + self.training + and self.config.enable_hyper_connections + and self.config.recompute_granularity == 'selective' + and "mhc" in self.config.recompute_modules + ) + mhc_layer_managers, mhc_is_last_in_recompute_block = self._build_mhc_recompute_layer_plan( + use_mhc_recompute + ) + with rng_context, outer_quantization_context: # Forward pass. if self.config.recompute_granularity == 'full' and self.training: @@ -818,23 +896,40 @@ def forward( else: inner_quantization_context = nullcontext() - with self.offload_context, inner_quantization_context: - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - rotary_pos_cos_sin=rotary_pos_cos_sin, - attention_bias=attention_bias, - inference_context=inference_context, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - padding_mask=padding_mask, + mhc_manager = mhc_layer_managers[l_no] + if mhc_manager is not None: + mhc_manager.is_last_layer_in_recompute_block = ( + mhc_is_last_in_recompute_block[l_no] ) + layer_kwargs = dict( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + padding_mask=padding_mask, + ) + if mhc_manager is not None and hasattr( + layer, "self_attention_hyper_connection" + ): + layer_kwargs["mhc_recompute_manager"] = mhc_manager + + with self.offload_context, inner_quantization_context: + hidden_states, context = layer(**layer_kwargs) + self._finalize_mhc_recompute_layer( + mhc_manager=mhc_manager, + hidden_states=hidden_states, + is_last_in_recompute_block=mhc_is_last_in_recompute_block[l_no], + ) + if ( torch.is_grad_enabled() and self.config.cpu_offloading @@ -846,6 +941,13 @@ def forward( if (l_no + layer_offset) in extract_layer_indices: intermediate_hidden_states.append(hidden_states) + # Contract n-stream hidden states at the block boundary even when the + # final layernorm is disabled; downstream output layers expect [s, b, C]. 
+ if self.config.enable_hyper_connections: + hidden_states = HyperConnectionModule.output_contract( + hidden_states, self.num_residual_streams + ) # [s, b, n*C] -> [s, b, C] + # Final layer norm. if self.final_layernorm is not None: hidden_states = apply_module(self.final_layernorm)(cast(Tensor, hidden_states)) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 40c1a745493..ea4c4ccd15e 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import logging import math @@ -60,8 +60,8 @@ class TransformerConfig(ModelParallelConfig): mtp_loss_scaling_factor: Optional[float] = 0.1 """Weighting factor of Multi-Token Prediction (MTP) loss. - We compute the average of the MTP losses across all depths, - and multiply it the scaling factor to obtain the overall MTP loss, + We compute the average of the MTP losses across all depths, + and multiply it the scaling factor to obtain the overall MTP loss, which serves as an additional training objective. """ @@ -90,8 +90,8 @@ class TransformerConfig(ModelParallelConfig): - list: e.g., [['embedding', 'decoder'], ['decoder', 'decoder', 'decoder', 'loss']]. - PipelineParallelLayerLayout: a PipelineParallelLayerLayout object. If given either a string or a list, it will be transferred into a PipelineParallelLayerLayout - in post init. Let i = a * pp_size + b, then layout[i] gives a list of the layers - in the a-th vpp stage and the b-th pp stage, i.e., vpp(0)pp(0), vpp(0)pp(1), ..., + in post init. Let i = a * pp_size + b, then layout[i] gives a list of the layers + in the a-th vpp stage and the b-th pp stage, i.e., vpp(0)pp(0), vpp(0)pp(1), ..., vpp(i)pp(j), vpp(i)pp(j+1), ..., vpp(-1)pp(-2), vpp(-1)pp(-1). In the inner lists of layers, 'embedding' or 'E' denotes the embedding layer, 'loss' or 'L' denotes the loss function, and 'decoder' or 't' denotes the transformer decoder layer. @@ -132,8 +132,8 @@ class TransformerConfig(ModelParallelConfig): """Softmax scale for attention scaling.""" softmax_type: Literal['vanilla', 'off-by-one', 'learnable'] = 'vanilla' - """Applies modified softmax from https://www.evanmiller.org/attention-is-off-by-one.html. - Supports both TE FusedAttention and local unfused attention. Supports both a fixed offset and + """Applies modified softmax from https://www.evanmiller.org/attention-is-off-by-one.html. + Supports both TE FusedAttention and local unfused attention. Supports both a fixed offset and and learnable offset.""" num_query_groups: Optional[int] = field( @@ -193,7 +193,7 @@ class TransformerConfig(ModelParallelConfig): The stored input is casted back to the original precision before backprop compuatation.""" glu_linear_offset: float = 0.0 - """Offset term in the GLU activation function: activation_func(x[0]) * (x[1] + offset). Only + """Offset term in the GLU activation function: activation_func(x[0]) * (x[1] + offset). Only used when gated_linear_unit is True""" activation_func_clamp_value: Optional[float] = None @@ -288,7 +288,7 @@ class TransformerConfig(ModelParallelConfig): # linear attention #################### linear_attention_freq: Optional[Union[int, List[int]]] = None - """Frequency between LA (linear attention) layers + """Frequency between LA (linear attention) layers and SDPA (scaled dot-product attention) layers. 
Accepts either: - An integer N: Represents a (N-1):N ratio, meaning (N-1) LA layers for every 1 SDPA layer @@ -330,13 +330,13 @@ class TransformerConfig(ModelParallelConfig): embedding_init_method: Optional[Callable] = None """ - Method to initialize weights of the embedding layer. If None, will be set as described + Method to initialize weights of the embedding layer. If None, will be set as described in init_method above. """ embedding_init_method_std: Optional[float] = None """ - Standard deviation of the zero mean normal for the default initialization method for the + Standard deviation of the zero mean normal for the default initialization method for the embedding layer. If None, will be set to init_method_std. Setting this to a value around 1.0 may avoid loss spikes in training. Setting this to any value will also skip applying weight decay on embedding weights to avoid shrinkage towards zero. @@ -482,7 +482,8 @@ class TransformerConfig(ModelParallelConfig): recompute_modules: Optional[List[str]] = None """The submodules to recompute. - choices: "core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe", "shared_experts". + choices: "core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe", + "shared_experts", "mhc". default: ["core_attn"]. "core_attn": recompute the core attention part of the transformer layer. "moe_act": recompute the MoE MLP activation function. @@ -491,7 +492,10 @@ class TransformerConfig(ModelParallelConfig): "mlp": recompute the dense MLP submodule. "moe": recompute the MoE layer. "shared_experts": recompute the shared experts in the MoE layer. - "moe_act", "layernorm", and "mla_up_proj" use output-discarding checkpointing, + "mhc": recompute HyperConnection intermediate activations via + CheckpointWithoutOutput + MHCRecomputeManager. Requires + enable_hyper_connections=True. Cannot be used with "mlp". + "moe_act", "layernorm", "mla_up_proj", and "mhc" use output-discarding checkpointing, "core_attn", "mlp", "moe", and "shared_experts" use normal checkpointing. """ @@ -582,7 +586,7 @@ class TransformerConfig(ModelParallelConfig): fp4: Optional[Literal['e2m1']] = field( default=None, metadata={"argparse_meta": {"arg_names": ["--fp4-format"]}} ) - """If set, enables the use of FP4 precision through Transformer Engine. Currently only + """If set, enables the use of FP4 precision through Transformer Engine. Currently only supports 'nvfp4' which uses NVFP4BlockScaling recipe (requires TE >= 2.7.0.dev0).""" fp4_recipe: Optional[Literal['nvfp4', 'custom']] = "nvfp4" @@ -616,12 +620,12 @@ class TransformerConfig(ModelParallelConfig): in the hidden_states gradient.""" moe_shared_expert_gate: bool = False - """Enable gate for shared expert. Only effective when + """Enable gate for shared expert. Only effective when moe-shared-expert-intermediate-size is set.""" moe_shared_expert_overlap: bool = False """Enable overlapping between shared expert computations and dispatcher communications. - Without this, the shared experts execute before the router. + Without this, the shared experts execute before the router. Only effective when moe-shared-expert-intermediate-size is set. """ @@ -715,7 +719,7 @@ class TransformerConfig(ModelParallelConfig): The default value 1e-3 is same as that used in DeepSeekV3.""" moe_router_force_load_balancing: bool = False - """[Experimental] Force load balancing with random logits for MoE router, supports naive topk + """[Experimental] Force load balancing with random logits for MoE router, supports naive topk and group-limited topk. 
This is an experimental feature and only for benchmark.""" moe_router_force_biased: Optional[float] = None @@ -756,7 +760,7 @@ class TransformerConfig(ModelParallelConfig): moe_flex_dispatcher_backend: Literal['deepep', 'hybridep'] = "deepep" """[Experimental] The backend to use for flex token dispatcher. The default is "deepep". - Options are "deepep" and "hybridep". Currently only "hybridep" backend supports + Options are "deepep" and "hybridep". Currently only "hybridep" backend supports the MNNVL case.""" moe_per_layer_logging: bool = False @@ -871,6 +875,50 @@ class TransformerConfig(ModelParallelConfig): When cuda_graph_impl is set to "local", "full_iteration" can be specified as cuda_graph_scope to enable whole iteration CUDA graph. All other values enable layerwise CUDA graph.""" + #################### + # Hyper-Connection Configuration + #################### + enable_hyper_connections: bool = False + """Enable mHC residual connections.""" + + num_residual_streams: int = 4 + """Number of residual streams (n in paper). + + Within each hyper-connection transformer block, hidden states are expanded + from [s, b, C] to [s, b, n*C], so activation memory in the block scales + roughly linearly with this value. + """ + + mhc_sinkhorn_iterations: int = 20 + """Number of Sinkhorn-Knopp iterations for doubly stochastic projection.""" + + mhc_init_gating_factor: float = 0.01 + """Initial value of Gating Factor (alpha in paper).""" + + use_fused_mhc: bool = False + """Use cuTile fused kernels for mHC operations. + + When True, attempts to replace the reference mHC modules (SinkhornKnopp, + H_aggregate, H_post_bda, ProjRms) with fused cuda.tile (cuTile) autograd + functions for better performance on supported GPUs. Requires cuTile to be + installed; if cuTile is unavailable the flag is silently reset to False and + a warning is emitted. + """ + + mhc_recompute_layer_num: Optional[int] = None + """Number of layers per MHC recompute block. + + When set, every `mhc_recompute_layer_num` layers form a recompute block. The last layer + in each recompute block (i.e., layer_number % mhc_recompute_layer_num == 0 or the final + layer in the transformer block) will: + - NOT checkpoint its final MLP BDA + - Register the unified recompute hook on its MLP BDA output + - A new MHCRecomputeManager is created for subsequent layers + + If None, all layers in the transformer block share a single recompute block. + + Must be a positive integer when set.""" + #################### # miscellaneous #################### @@ -890,7 +938,7 @@ class TransformerConfig(ModelParallelConfig): batch_invariant_mode: bool = False """If true, uses batch-invariant kernels that provide deterministic forward execution regardless of batch size. This ensures bitwise identical results when the same inputs are processed - in different batch configurations. This will significantly affect speed of + in different batch configurations. This will significantly affect speed of training and inference as the kernels are not full optimized. 
Defaults to False.""" @@ -1383,6 +1431,7 @@ def __post_init__(self): "mlp", "moe", "shared_experts", + "mhc", } invalid_modules = set(self.recompute_modules) - allowed_modules assert not invalid_modules, ( @@ -1445,6 +1494,94 @@ def __post_init__(self): if "moe" not in self.recompute_modules: self.recompute_modules.append("moe") + # Validation for "mhc" in recompute_modules + if self.recompute_granularity == "selective" and "mhc" in self.recompute_modules: + if not self.enable_hyper_connections: + raise ValueError( + "'mhc' in recompute_modules requires enable_hyper_connections=True." + ) + if "mlp" in self.recompute_modules: + raise ValueError( + "'mhc' and 'mlp' in recompute_modules cannot be used together. " + "They use different checkpoint mechanisms that may conflict." + ) + # bool is a subclass of int, so reject it explicitly. + if self.mhc_recompute_layer_num is not None and ( + isinstance(self.mhc_recompute_layer_num, bool) + or not isinstance(self.mhc_recompute_layer_num, int) + or self.mhc_recompute_layer_num < 1 + ): + raise ValueError( + "mhc_recompute_layer_num must be a positive integer when " + "'mhc' is in recompute_modules." + ) + if self.fine_grained_activation_offloading: + raise ValueError( + "'mhc' in recompute_modules is incompatible with " + "fine_grained_activation_offloading. The mHC recompute hook fires " + "before the offloading backward chunk is initialized, causing " + "tensor_pop on a None chunk. Disable one of them." + ) + + if ( + self.enable_hyper_connections + and self.recompute_granularity is not None + and not (self.recompute_granularity == "selective" and "mhc" in self.recompute_modules) + ): + if self.recompute_granularity == "full": + raise ValueError( + "enable_hyper_connections is not yet compatible with full activation " + "recompute. Use selective recompute with 'mhc' in recompute_modules " + "or disable activation recompute." + ) + warnings.warn( + "HyperConnections are enabled but 'mhc' is not in " + "recompute_modules with selective recompute. Consider adding 'mhc' to " + "recompute_modules with selective recompute to reduce activation memory." + ) + + # Validation for use_fused_mhc + if self.use_fused_mhc: + if not self.enable_hyper_connections: + raise ValueError("use_fused_mhc requires enable_hyper_connections=True.") + try: + from megatron.core.fusions.fused_mhc_kernels import is_cutile_available + + if not is_cutile_available(): + warnings.warn( + "use_fused_mhc is enabled but cuda.tile (cuTile) is not installed. " + "Falling back to reference mHC implementations.", + UserWarning, + ) + self.use_fused_mhc = False + except ImportError: + warnings.warn( + "use_fused_mhc is enabled but fused_mhc_kernels module could not be " + "imported. Falling back to reference mHC implementations.", + UserWarning, + ) + self.use_fused_mhc = False + + # Validation for hyper_connections with MTP + if self.enable_hyper_connections and self.mtp_num_layers is not None: + raise ValueError( + "enable_hyper_connections is not compatible with Multi-Token Prediction (MTP). " + "Please disable MTP (set mtp_num_layers=None) when using hyper connections." + ) + + if self.enable_hyper_connections and self.pipeline_model_parallel_size > 1: + raise ValueError( + "enable_hyper_connections is not yet compatible with pipeline parallelism " + "(pipeline_model_parallel_size > 1). Pipeline-parallel support is planned " + "for a follow-up PR." 
+ ) + + if self.enable_hyper_connections and self.inference_fuse_tp_communication: + raise ValueError( + "enable_hyper_connections is not compatible with " + "inference_fuse_tp_communication." + ) + if self.fine_grained_activation_offloading: assert ( not self.cpu_offloading @@ -1986,6 +2123,16 @@ def __post_init__(self): "To use full iteration cuda graph, please use " "cuda_graph_impl=local instead of cuda_graph_impl=transformer_engine." ) + if ( + self.enable_hyper_connections + and self.num_moe_experts is not None + and self.num_moe_experts > 1 + and CudaGraphScope.moe_router in self.cuda_graph_scope + ): + raise ValueError( + "enable_hyper_connections is not yet compatible with " + "MoE router CUDA graphs." + ) assert ( CudaGraphScope.moe not in self.cuda_graph_scope or CudaGraphScope.moe_router not in self.cuda_graph_scope @@ -2328,7 +2475,7 @@ class MLATransformerConfig(TransformerConfig): cache_mla_latents: bool = False """Cache the low dimensional tensors for MLA rather than full KV cache. - This is only for the dynamic inference backend and requires that + This is only for the dynamic inference backend and requires that Flash MLA is installed.""" mla_down_proj_fusion: bool = False diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index cf63199347c..c4ffa0734ac 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1,13 +1,17 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. from __future__ import annotations import functools import logging import warnings from abc import ABC +from contextvars import ContextVar from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, Optional, Union +if TYPE_CHECKING: + from megatron.core.tensor_parallel.random import MHCRecomputeManager + import torch import torch.distributed from torch import Tensor @@ -40,6 +44,7 @@ from megatron.core.inference.contexts import BaseInferenceContext logger = logging.getLogger(__name__) +_MHC_RECOMPUTE_MANAGER_CONTEXT = ContextVar("mhc_recompute_manager", default=None) def get_transformer_layer_offset( @@ -210,16 +215,21 @@ class TransformerLayerSubmodules: Args: input_layernorm: Specification for the input layer normalization. + self_attention_hyper_connection: Specification for the hyper-connection module + before/after self-attention. self_attention (Union[ModuleSpec, type]): Specification for the self-attention mechanism. self_attn_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation after self-attention. pre_cross_attn_layernorm: Specification for the layer normalization before cross-attention. + cross_attention_hyper_connection: Reserved for future cross-attention + hyper-connection support and must remain IdentityOp for now. cross_attention (Union[ModuleSpec, type]): Specification for the cross-attention mechanism. cross_attn_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation after cross-attention. pre_mlp_layernorm: Specification for the layer normalization before the MLP. + mlp_hyper_connection: Specification for the hyper-connection module before/after the MLP. mlp (Union[ModuleSpec, type]): Specification for the MLP in Dense layer. mlp_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation after the MLP. 
@@ -228,14 +238,18 @@ class TransformerLayerSubmodules: """ input_layernorm: LayerNormBuilder = IdentityOp + self_attention_hyper_connection: Union[ModuleSpec, type] = IdentityOp self_attention: Union[ModuleSpec, type] = IdentityOp self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp pre_cross_attn_layernorm: LayerNormBuilder = IdentityOp + # TODO: wire this when cross-attention hyper-connection support is added. + cross_attention_hyper_connection: Union[ModuleSpec, type] = IdentityOp cross_attention: Union[ModuleSpec, type] = IdentityOp cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp pre_mlp_layernorm: LayerNormBuilder = IdentityOp + mlp_hyper_connection: Union[ModuleSpec, type] = IdentityOp mlp: Union[ModuleSpec, type] = IdentityOp mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp @@ -606,8 +620,6 @@ def _forward_attention( ) if using_fused_tp_inference_kernel: - # Set the residual for fused reduce-scatter + add + layer-norm + all-gather - # operation in attention's out_proj (linear_proj) self._set_proj_residual(residual) # Self attention. @@ -700,6 +712,10 @@ def forward(self, *args, **kwargs): This method calls the core computation of a transformer layer, including self-attention, cross-attention (if applicable), and feed-forward operations. """ + # Injected by __call__ for cuda graph keying; not a real forward arg. + kwargs.pop("dynamic_inference_decode_only", None) + # mHC recompute is consumed by HyperConnectionTransformerLayer only. + kwargs.pop("mhc_recompute_manager", None) hidden_states, context = self._forward_attention(*args, **kwargs) output = self._forward_mlp( hidden_states, @@ -1280,6 +1296,19 @@ def _should_call_local_cudagraph(self, *args, **kwargs): return True return False + def __call__(self, *args, **kwargs): + if self._should_call_local_cudagraph(*args, **kwargs): + # Inference mode. + if kwargs.get('inference_context') is not None: + # dynamic_inference_decode_only is not a real argument to forward, it is only used + # to differentiate the cuda graph used for decode from the one used for non-decode + # inference. + kwargs["dynamic_inference_decode_only"] = kwargs[ + 'inference_context' + ].is_decode_only() + + return super().__call__(*args, **kwargs) + def get_layer_norm_weights(self): """ Get the weights of all layernorms (attention and MLP) in the transformer layer. @@ -1289,6 +1318,438 @@ def get_layer_norm_weights(self): return +class HyperConnectionTransformerLayer(TransformerLayer): + """A transformer layer with Manifold-Constrained Hyper-Connections (mHC). + + Extends TransformerLayer by adding hyper connection modules around self-attention + and MLP. The n-stream hidden states are aggregated before each sub-layer and + expanded back afterwards using learned mappings (H_pre, H_post, H_res). + + Cross-attention hyper connection is not supported. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: Optional[float] = None, + pg_collection: Optional[ProcessGroupCollection] = None, + vp_stage: Optional[int] = None, + is_mtp_layer: bool = False, + add_layer_offset: bool = True, + pp_layer_offset: Optional[int] = None, + ): + if submodules.cross_attention is not IdentityOp: + raise ValueError( + "HyperConnectionTransformerLayer does not support cross-attention. " + "Use IdentityOp for cross_attention." 
+ ) + if submodules.cross_attention_hyper_connection is not IdentityOp: + raise ValueError( + "HyperConnectionTransformerLayer does not support cross-attention " + "hyper connections. Use IdentityOp for cross_attention_hyper_connection." + ) + + super().__init__( + config=config, + submodules=submodules, + layer_number=layer_number, + hidden_dropout=hidden_dropout, + pg_collection=pg_collection, + vp_stage=vp_stage, + is_mtp_layer=is_mtp_layer, + add_layer_offset=add_layer_offset, + pp_layer_offset=pp_layer_offset, + ) + + assert submodules.self_attention_hyper_connection is not IdentityOp, ( + "HyperConnectionTransformerLayer requires self_attention_hyper_connection. " + "Use TransformerLayer instead if hyper connections are not needed." + ) + assert submodules.mlp_hyper_connection is not IdentityOp, ( + "HyperConnectionTransformerLayer requires mlp_hyper_connection. " + "Use TransformerLayer instead if hyper connections are not needed." + ) + + self.self_attention_hyper_connection = build_module( + submodules.self_attention_hyper_connection, + config=self.config, + layer_number=self.layer_number, + ) + + self.mlp_hyper_connection = build_module( + submodules.mlp_hyper_connection, config=self.config, layer_number=self.layer_number + ) + + # When mHC recompute is active, skip checkpointing if the layernorm + # is IdentityOp (fused into TE linear) — there is nothing to recompute. + self.mhc_checkpoint_input_layernorm = not isinstance(self.input_layernorm, IdentityOp) + self.mhc_checkpoint_pre_mlp_layernorm = not isinstance(self.pre_mlp_layernorm, IdentityOp) + + def get_layer_static_inputs(self, seq_length, micro_batch_size): + """Override to produce n-stream hidden_states of shape [s, b, n*C]. + + CUDA graph capture creates static buffers whose shapes are determined by + this method. The base class returns [s, b, C], but mHC layers operate on + n-stream hidden states of shape [s, b, n*C]. + """ + static_inputs = super().get_layer_static_inputs(seq_length, micro_batch_size) + hs = static_inputs["hidden_states"] + n = self.config.num_residual_streams + static_inputs["hidden_states"] = torch.ones( + (hs.shape[0], hs.shape[1], n * self.config.hidden_size), + dtype=hs.dtype, + requires_grad=hs.requires_grad, + device=hs.device, + ) + return static_inputs + + def _get_submodules_under_cudagraphs(self): + """Override to include hyper connection modules. + + The base TransformerLayer._get_submodules_under_cudagraphs does not include + self_attention_hyper_connection / mlp_hyper_connection. Their learnable + parameters (mapping_proj, alpha_*, bias) need manual pre-forward hooks + during CUDA graph replay so that parameter all-gathers are triggered. + """ + submodules = super()._get_submodules_under_cudagraphs() + + if not self.config.cuda_graph_scope: + return submodules + + if CudaGraphScope.attn in self.config.cuda_graph_scope: + submodules.append(self.self_attention_hyper_connection) + if (not self.is_moe_layer and CudaGraphScope.mlp in self.config.cuda_graph_scope) or ( + self.is_moe_layer and CudaGraphScope.moe in self.config.cuda_graph_scope + ): + submodules.append(self.mlp_hyper_connection) + return submodules + + def __call__(self, *args, **kwargs): + # Extract before CUDA graph manager processes kwargs; MHCRecomputeManager + # is not a CUDA-graph-supported argument type. MCore CUDA graph capture + # executes forward synchronously inside __call__, so this context-local + # handoff is visible during capture without storing per-call state on the + # module instance. 
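+ # contextvars note: set() returns a token and reset(token) restores the
+ # previous value, so the manager cannot leak across nested or re-entrant
+ # calls even if an exception propagates out of forward().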
+ mhc_recompute_manager = kwargs.pop("mhc_recompute_manager", None) + token = _MHC_RECOMPUTE_MANAGER_CONTEXT.set(mhc_recompute_manager) + try: + return super().__call__(*args, **kwargs) + finally: + _MHC_RECOMPUTE_MANAGER_CONTEXT.reset(token) + + def forward(self, *args, **kwargs): + """Forward pass with MHC recompute manager support.""" + kwargs.pop("dynamic_inference_decode_only", None) + + # Direct forward() calls can pass the manager normally. __call__ uses a + # context-local handoff so CUDA graph argument processing never sees the + # unsupported MHCRecomputeManager object. + mhc_recompute_manager = kwargs.pop("mhc_recompute_manager", None) + if mhc_recompute_manager is None: + mhc_recompute_manager = _MHC_RECOMPUTE_MANAGER_CONTEXT.get() + + hidden_states, context = self._forward_attention( + *args, mhc_recompute_manager=mhc_recompute_manager, **kwargs + ) + + output = self._forward_mlp( + hidden_states, + kwargs.get("inference_context", None), + padding_mask=kwargs.get("padding_mask", None), + mhc_recompute_manager=mhc_recompute_manager, + ) + return output, context + + def _forward_attention( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + context: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + rotary_pos_cos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + inference_context: Optional[Any] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[Tensor] = None, + padding_mask: Optional[Tensor] = None, + mhc_recompute_manager: Optional['MHCRecomputeManager'] = None, + *, + inference_params: Optional[Any] = None, + ): + """Forward attention with hyper connection pre/post processing on self-attention.""" + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) + + inference_context = deprecate_inference_params(inference_context, inference_params) + + residual = hidden_states + if self.config.fp32_residual_connection: + residual = residual.float() + + nvtx_range_push(suffix="self_attention_hyper_connection") + hidden_states, self_attn_h_res, self_attn_hc_h_post = self.self_attention_hyper_connection( + hidden_states, mhc_recompute_manager=mhc_recompute_manager + ) + nvtx_range_pop(suffix="self_attention_hyper_connection") + + # Optional Input Layer norm + checkpoint_input_layernorm = self.recompute_input_layernorm or ( + mhc_recompute_manager is not None and self.mhc_checkpoint_input_layernorm + ) + if checkpoint_input_layernorm: + self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput( + ckpt_manager=mhc_recompute_manager + ) + with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: + input_layernorm_output = self.input_layernorm_checkpoint.checkpoint( + apply_module(self.input_layernorm), hidden_states + ) + else: + with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: + input_layernorm_output = apply_module(self.input_layernorm)(hidden_states) + + if isinstance(input_layernorm_output, tuple): + if len(input_layernorm_output) != 2: + raise ValueError( + f"When the output of input_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(input_layernorm_output)}" + ) + input_layernorm_output, _ = input_layernorm_output + + # Self 
attention. + nvtx_range_push(suffix="self_attention") + attention_output_with_bias = self.self_attention( + input_layernorm_output, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + nvtx_range_pop(suffix="self_attention") + + if checkpoint_input_layernorm: + self.input_layernorm_checkpoint.discard_output_and_register_recompute( + attention_output_with_bias[0] + ) + + nvtx_range_push(suffix="self_attention_h_res_h_post_bda") + with self.bias_dropout_add_exec_handler(): + hidden_states = self.self_attention_hyper_connection.h_res_h_post_bda( + self_attn_h_res, + residual, + self_attn_hc_h_post, + attention_output_with_bias, + self.hidden_dropout, + self.training, + self.config.bias_dropout_fusion, + mhc_recompute_manager, + ) + nvtx_range_pop(suffix="self_attention_h_res_h_post_bda") + + if self.offload_attn_norm: + hidden_states = off_interface.group_commit(hidden_states, name="attn_norm") + + # Cross-attention (no hyper connection support). + residual = hidden_states + if self.config.fp32_residual_connection: + residual = residual.float() + pre_cross_attn_layernorm_output = apply_module(self.pre_cross_attn_layernorm)( + hidden_states + ) + + if isinstance(pre_cross_attn_layernorm_output, tuple): + if len(pre_cross_attn_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_cross_attn_layernorm_output " + f"is a tuple, it is expected to have 2 elements " + f"(output, residual), but " + f"got {len(pre_cross_attn_layernorm_output)}" + ) + pre_cross_attn_layernorm_output, _ = pre_cross_attn_layernorm_output + + attention_output_with_bias = self.cross_attention( + pre_cross_attn_layernorm_output, + attention_mask=context_mask, + key_value_states=context, + inference_context=inference_context, + ) + + if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias: + context = attention_output_with_bias["context"] + + with self.bias_dropout_add_exec_handler(): + hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)( + attention_output_with_bias, residual, self.hidden_dropout + ) + + return hidden_states, context + + def _forward_mlp( + self, + hidden_states, + inference_context=None, + padding_mask=None, + mhc_recompute_manager: Optional['MHCRecomputeManager'] = None, + ): + """Forward MLP with hyper connection pre/post processing.""" + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) + + is_last_in_recompute_block = bool( + mhc_recompute_manager is not None + and getattr(mhc_recompute_manager, "is_last_layer_in_recompute_block", False) + ) + mhc_mlp_bda_manager = None if is_last_in_recompute_block else mhc_recompute_manager + + residual = hidden_states + if self.config.fp32_residual_connection: + residual = residual.float() + + nvtx_range_push(suffix="mlp_hyper_connection") + hidden_states, mlp_h_res, mlp_hc_h_post = self.mlp_hyper_connection( + hidden_states, mhc_recompute_manager=mhc_recompute_manager + ) + nvtx_range_pop(suffix="mlp_hyper_connection") + + # Optional Layer norm post the cross-attention. 
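+ # Checkpoint the pre-MLP layernorm either via the standard selective-recompute
+ # flag or, when an mHC recompute manager is active, via a CheckpointWithoutOutput
+ # registered on that manager. Skipped when the norm is an IdentityOp fused into
+ # the TE linear, since there is nothing to recompute.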
+ checkpoint_pre_mlp_layernorm = self.recompute_pre_mlp_layernorm or ( + mhc_recompute_manager is not None and self.mhc_checkpoint_pre_mlp_layernorm + ) + if checkpoint_pre_mlp_layernorm: + self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput( + ckpt_manager=mhc_recompute_manager + ) + with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: + pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint( + apply_module(self.pre_mlp_layernorm), hidden_states + ) + else: + with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: + pre_mlp_layernorm_output = apply_module(self.pre_mlp_layernorm)(hidden_states) + + if isinstance(pre_mlp_layernorm_output, tuple): + if len(pre_mlp_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(pre_mlp_layernorm_output)}" + ) + pre_mlp_layernorm_output, _ = pre_mlp_layernorm_output + + nvtx_range_push(suffix="mlp") + should_chunk_mlp_for_prefill = ( + self.config.mlp_chunks_for_prefill > 1 + and inference_context is not None + and not inference_context.is_decode_only() + and not isinstance(self.mlp, IdentityOp) + and not self.config.transformer_impl == "inference_optimized" + ) + + if self.recompute_mlp: + if self.config.fp8 or self.config.fp4: + from megatron.core.extensions.transformer_engine import te_checkpoint + + mlp_output_with_bias = te_checkpoint( + self.mlp, + False, + tensor_parallel.random.get_cuda_rng_tracker, + self.pg_collection.tp, + pre_mlp_layernorm_output, + padding_mask=padding_mask, + ) + else: + mlp_output_with_bias = tensor_parallel.checkpoint( + functools.partial(self.mlp, padding_mask=padding_mask), + False, + pre_mlp_layernorm_output, + ) + elif should_chunk_mlp_for_prefill: + num_chunks = min(self.config.mlp_chunks_for_prefill, pre_mlp_layernorm_output.shape[0]) + chunks = pre_mlp_layernorm_output.chunk(num_chunks, dim=0) + outputs = [self.mlp(chunk) for chunk in chunks] + mlp_output = torch.cat([out for out, _ in outputs], dim=0) + bias_chunks = [bias for _, bias in outputs if bias is not None] + bias_output = torch.stack(bias_chunks, dim=0).sum(dim=0) if bias_chunks else None + mlp_output_with_bias = (mlp_output, bias_output) + else: + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output, padding_mask=padding_mask) + + nvtx_range_pop(suffix="mlp") + + return self._forward_post_mlp_with_fused_hyper_connection( + mlp_output_with_bias, mlp_h_res, residual, mlp_hc_h_post, mhc_mlp_bda_manager + ) + + def _forward_post_mlp_with_fused_hyper_connection( + self, + mlp_output_with_bias, + mlp_h_res, + residual, + mlp_hc_h_post, + mhc_mlp_bda_recompute_manager: Optional['MHCRecomputeManager'] = None, + ): + """ + Perform operations after the MLP computation with fused hyper connection kernel. + + This method uses the fused kernel combining apply_h_res, apply_h_post and bias-dropout-add. + + Args: + mlp_output_with_bias (Tensor): Output tensor of the MLP layer with bias. + mlp_h_res (Tensor): [s, b, n, n] - residual mixing matrix from hyper connection. + residual (Tensor): [s, b, n*C] - original residual (n-stream hidden states). + mlp_hc_h_post (Tensor): [s, b, n] - expansion weights from hyper connection. + mhc_recompute_manager: Optional MHCRecomputeManager for checkpoint management. + + Returns: + output (Tensor): Transformed hidden states of shape [s, b, n*C]. 
+ """ + if self.recompute_pre_mlp_layernorm or ( + mhc_mlp_bda_recompute_manager is not None and self.mhc_checkpoint_pre_mlp_layernorm + ): + self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute( + mlp_output_with_bias[0] + ) + + nvtx_range_push(suffix="mlp_h_res_h_post_bda") + with self.bias_dropout_add_exec_handler(): + hidden_states = self.mlp_hyper_connection.h_res_h_post_bda( + mlp_h_res, + residual, + mlp_hc_h_post, + mlp_output_with_bias, + self.hidden_dropout, + self.training, + self.config.bias_dropout_fusion, + mhc_mlp_bda_recompute_manager, + ) + nvtx_range_pop(suffix="mlp_h_res_h_post_bda") + + if self.offload_mlp_norm: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) + + hidden_states = off_interface.group_commit(hidden_states, name="mlp_norm") + + output = make_viewless_tensor( + inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True + ) + return output + + class MoETransformerLayer(TransformerLayer): """ A Transformer layer specialized for Mixture-of-Experts (MoE) architectures. diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py index ff655502019..713bce571d4 100644 --- a/megatron/training/initialize.py +++ b/megatron/training/initialize.py @@ -25,7 +25,12 @@ from megatron.core.transformer.custom_layers.batch_invariant_kernels import ( enable_batch_invariant_mode, ) -from megatron.core.utils import get_te_version, is_te_min_version, is_torch_min_version +from megatron.core.utils import ( + configure_nvtx_profiling, + get_te_version, + is_te_min_version, + is_torch_min_version, +) from megatron.training import ( get_adlr_autoresume, get_args, @@ -89,6 +94,10 @@ def state_restore_func(state_dict): print_rank_0("Enabling batch invariant mode globally") enable_batch_invariant_mode() + # Enable NVTX range profiling when profiling is active. + if args.profile: + configure_nvtx_profiling(True) + # torch.distributed initialization def finish_mpu_init(): args = get_args() diff --git a/tests/unit_tests/models/test_gpt_layer_specs.py b/tests/unit_tests/models/test_gpt_layer_specs.py new file mode 100644 index 00000000000..649f4a98c8d --- /dev/null +++ b/tests/unit_tests/models/test_gpt_layer_specs.py @@ -0,0 +1,73 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ +import pytest + +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) +from megatron.core.transformer.hyper_connection import HyperConnectionModule +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.transformer_layer import ( + HyperConnectionTransformerLayer, + TransformerLayer, +) + +_TE = get_gpt_layer_with_transformer_engine_spec +_LOCAL = get_gpt_layer_local_spec +_HC = HyperConnectionTransformerLayer +_HC_MOD = HyperConnectionModule +_TL = TransformerLayer +_ID = IdentityOp + + +class TestGptLayerSpecsHyperConnection: + """Test that enable_hyper_connections controls module types in layer specs.""" + + @pytest.mark.parametrize( + "factory,kwargs,expected_module,expected_hc", + [ + (_TE, {}, _TL, _ID), + (_TE, {"enable_hyper_connections": True}, _HC, _HC_MOD), + (_TE, {"enable_hyper_connections": False}, _TL, _ID), + (_TE, {"multi_latent_attention": True, "enable_hyper_connections": False}, _TL, _ID), + (_TE, {"multi_latent_attention": True, "enable_hyper_connections": True}, _HC, _HC_MOD), + (_LOCAL, {}, _TL, _ID), + (_LOCAL, {"enable_hyper_connections": True}, _HC, _HC_MOD), + (_LOCAL, {"enable_hyper_connections": False}, _TL, _ID), + (_LOCAL, {"multi_latent_attention": True, "enable_hyper_connections": False}, _TL, _ID), + ( + _LOCAL, + {"multi_latent_attention": True, "enable_hyper_connections": True}, + _HC, + _HC_MOD, + ), + (_LOCAL, {"normalization": "RMSNorm", "enable_hyper_connections": False}, _TL, _ID), + (_LOCAL, {"normalization": "RMSNorm", "enable_hyper_connections": True}, _HC, _HC_MOD), + ], + ids=[ + "te_default", + "te_enable", + "te_disable", + "te_mla_disable", + "te_mla_enable", + "local_default", + "local_enable", + "local_disable", + "local_mla_disable", + "local_mla_enable", + "local_rmsnorm_disable", + "local_rmsnorm_enable", + ], + ) + def test_hyper_connection_spec(self, factory, kwargs, expected_module, expected_hc): + spec = factory(**kwargs) + assert spec.module is expected_module + assert spec.submodules.self_attention_hyper_connection is expected_hc + assert spec.submodules.mlp_hyper_connection is expected_hc + + @pytest.mark.parametrize("factory", [_TE, _LOCAL], ids=["te", "local"]) + def test_singular_hyper_connection_keyword_rejected(self, factory): + """The spec keyword should match TransformerConfig.enable_hyper_connections.""" + with pytest.raises(TypeError, match="enable_hyper_connection"): + factory(enable_hyper_connection=True) diff --git a/tests/unit_tests/models/test_hybrid_moe_model.py b/tests/unit_tests/models/test_hybrid_moe_model.py index 01a46efe083..56c12076041 100644 --- a/tests/unit_tests/models/test_hybrid_moe_model.py +++ b/tests/unit_tests/models/test_hybrid_moe_model.py @@ -89,6 +89,7 @@ "embedding_init_method_std": 0.014, "enable_autocast": False, "enable_cuda_graph": False, + "enable_hyper_connections": False, "ep_overlap_early_attn_memory_release": False, "experimental_attention_variant": None, "expert_model_parallel_size": 4, @@ -151,6 +152,9 @@ "mamba_state_dim": 128, "masked_softmax_fusion": True, "memory_efficient_layer_norm": False, + "mhc_init_gating_factor": 0.01, + "mhc_recompute_layer_num": None, + "mhc_sinkhorn_iterations": 20, "microbatch_group_size_per_vp_stage": 1, "mlp_chunks_for_prefill": 1, "moe_apply_probs_on_input": False, @@ -219,6 +223,7 @@ "num_microbatches_with_partial_activation_checkpoints": None, "num_moe_experts": 128, "num_query_groups": 2, + "num_residual_streams": 4, 
"output_layer_init_method": {}, "overlap_moe_expert_parallel_comm": False, "overlap_p2p_comm": False, @@ -265,6 +270,7 @@ "tp_only_amax_red": False, "transformer_impl": "transformer_engine", "use_cpu_initialization": None, + "use_fused_mhc": False, "use_fused_weighted_squared_relu": False, "use_inference_optimized_layers": False, "use_kitchen": False, diff --git a/tests/unit_tests/tensor_parallel/test_random.py b/tests/unit_tests/tensor_parallel/test_random.py index 4fa79733d55..6ba6a0598e9 100644 --- a/tests/unit_tests/tensor_parallel/test_random.py +++ b/tests/unit_tests/tensor_parallel/test_random.py @@ -200,6 +200,13 @@ def test_forward(*input): Utils.destroy_model_parallel() +def test_checkpoint_without_output_fp8_flag_is_explicit(): + assert CheckpointWithoutOutput().fp8 is False + assert CheckpointWithoutOutput(fp8=False).fp8 is False + assert CheckpointWithoutOutput(fp8=None).fp8 is False + assert CheckpointWithoutOutput(fp8="e4m3").fp8 is True + + def test_checkpoint_without_output(): def normal_forward(input): x = torch.nn.functional.gelu(input) diff --git a/tests/unit_tests/test_fp8_param.py b/tests/unit_tests/test_fp8_param.py index 34b504e21de..e0a71526297 100644 --- a/tests/unit_tests/test_fp8_param.py +++ b/tests/unit_tests/test_fp8_param.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import contextlib import gc @@ -72,12 +72,12 @@ def setup_method(self, method): os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' def teardown_method(self, method): - Utils.destroy_model_parallel() - destroy_global_vars() - destroy_num_microbatches_calculator() if self.cuda_graph_helper is not None and self.cuda_graph_helper.graphs_created(): self.cuda_graph_helper.delete_cuda_graphs() self.cuda_graph_helper = None + Utils.destroy_model_parallel() + destroy_global_vars() + destroy_num_microbatches_calculator() gc.collect() def model_provider( diff --git a/tests/unit_tests/transformer/test_hyper_connection_recompute.py b/tests/unit_tests/transformer/test_hyper_connection_recompute.py new file mode 100644 index 00000000000..0b599f990a4 --- /dev/null +++ b/tests/unit_tests/transformer/test_hyper_connection_recompute.py @@ -0,0 +1,517 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +""" +Unit tests for HyperConnection block-level recomputation. + +Tests the following functionality: +1. HyperConnectionModule._forward_with_checkpoint correctness +2. HyperConnectionModule.apply_h_post with MHCRecomputeManager +3. Multiple HyperConnectionModules chained with a single MHCRecomputeManager +4. Partial checkpoint (last layer not checkpointed) +5. 
TransformerConfig 'mhc' in recompute_modules option +""" + +import warnings + +import pytest +import torch + +from megatron.core.tensor_parallel.random import ( + MHCRecomputeManager, + model_parallel_cuda_manual_seed, +) +from megatron.core.transformer.hyper_connection import HyperConnectionModule, reference_proj_inv_rms +from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils + + +class TestHyperConnectionCheckpoint: + """Test HyperConnectionModule checkpoint functionality.""" + + def setup_method(self, method): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _create_hyper_connection_module(self, hidden_size=64, num_residual_streams=4): + """Create a HyperConnectionModule for testing.""" + config = TransformerConfig( + num_layers=2, + hidden_size=hidden_size, + num_attention_heads=4, + use_cpu_initialization=True, + enable_hyper_connections=True, + num_residual_streams=num_residual_streams, + mhc_sinkhorn_iterations=5, # Fewer iterations for faster tests + mhc_init_gating_factor=0.01, + ) + module = HyperConnectionModule(config=config, layer_number=1) + module.cuda() + return module + + def test_reference_proj_inv_rms_upcasts_norm_for_fp16(self): + width = 16384 + x = torch.full((1, width), 256.0, device='cuda', dtype=torch.float16) + weight = torch.zeros((1, width), device='cuda', dtype=torch.float16) + + _, inv_rms = reference_proj_inv_rms(x, weight) + + assert inv_rms.dtype == torch.float16 + assert torch.isfinite(inv_rms).all() + assert inv_rms.item() > 0 + + def test_forward_normal_vs_checkpoint_correctness(self): + """ + Test that _forward_with_checkpoint produces the same outputs as _forward_normal. 
+ """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + + module = self._create_hyper_connection_module(hidden_size, num_streams) + + # Create input tensors + hidden_states = torch.randn( + seq_len, batch_size, num_streams * hidden_size, device='cuda', requires_grad=True + ) + residual = torch.randn( + seq_len, batch_size, num_streams * hidden_size, device='cuda', requires_grad=True + ) + + # Clone inputs for comparison + hidden_states_ckpt = hidden_states.detach().clone().requires_grad_(True) + residual_ckpt = residual.detach().clone().requires_grad_(True) + + # Forward without checkpoint (reference) + torch.manual_seed(42) + torch.cuda.manual_seed(42) + aggregated_ref, h_res_ref, h_post_ref = module._forward_normal(hidden_states) + mixed_ref = module.apply_h_res(h_res_ref, residual) + loss_ref = aggregated_ref.sum() + mixed_ref.sum() + h_post_ref.sum() + loss_ref.backward() + grad_hidden_ref = hidden_states.grad.clone() + grad_residual_ref = residual.grad.clone() + + # Forward with checkpoint + torch.manual_seed(42) + torch.cuda.manual_seed(42) + manager = MHCRecomputeManager() + aggregated_ckpt, h_res_ckpt, h_post_ckpt = module._forward_with_checkpoint( + hidden_states_ckpt, manager + ) + mixed_ckpt = module.apply_h_res(h_res_ckpt, residual_ckpt) + # Calculate loss before discarding outputs + loss_ckpt = aggregated_ckpt.sum() + mixed_ckpt.sum() + h_post_ckpt.sum() + + # Register unified recompute hook + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + # Backward pass + loss_ckpt.backward() + grad_hidden_ckpt = hidden_states_ckpt.grad.clone() + grad_residual_ckpt = residual_ckpt.grad.clone() + + # Verify gradients match + assert torch.allclose(grad_hidden_ckpt, grad_hidden_ref, atol=1e-5), ( + f"Hidden states gradients mismatch:\n" + f"Checkpoint: {grad_hidden_ckpt}\n" + f"Reference: {grad_hidden_ref}" + ) + assert torch.allclose(grad_residual_ckpt, grad_residual_ref, atol=1e-5), ( + f"Residual gradients mismatch:\n" + f"Checkpoint: {grad_residual_ckpt}\n" + f"Reference: {grad_residual_ref}" + ) + + def test_apply_h_post_with_checkpoint(self): + """ + Test that apply_h_post with manager produces correct gradients. 
+ """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + + module = self._create_hyper_connection_module(hidden_size, num_streams) + + # Create input tensors + x = torch.randn(seq_len, batch_size, hidden_size, device='cuda', requires_grad=True) + bias = torch.randn(hidden_size, device='cuda') + h_post = torch.randn(seq_len, batch_size, num_streams, device='cuda', requires_grad=True) + + # Clone inputs + x_ckpt = x.detach().clone().requires_grad_(True) + h_post_ckpt = h_post.detach().clone().requires_grad_(True) + + # Reference: without checkpoint (manager=None) + torch.manual_seed(42) + x_out_ref, bias_out_ref = module.apply_h_post((x, bias), h_post, manager=None) + loss_ref = x_out_ref.sum() + if bias_out_ref is not None: + loss_ref = loss_ref + bias_out_ref.sum() + loss_ref.backward() + grad_x_ref = x.grad.clone() + grad_h_post_ref = h_post.grad.clone() + + # With checkpoint (manager provided) + torch.manual_seed(42) + manager = MHCRecomputeManager() + x_out_ckpt, bias_out_ckpt = module.apply_h_post( + (x_ckpt, bias), h_post_ckpt, manager=manager + ) + loss_ckpt = x_out_ckpt.sum() + if bias_out_ckpt is not None: + loss_ckpt = loss_ckpt + bias_out_ckpt.sum() + + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + loss_ckpt.backward() + grad_x_ckpt = x_ckpt.grad.clone() + grad_h_post_ckpt = h_post_ckpt.grad.clone() + + # Verify gradients + assert torch.allclose(grad_x_ckpt, grad_x_ref, atol=1e-5) + assert torch.allclose(grad_h_post_ckpt, grad_h_post_ref, atol=1e-5) + + def test_forward_with_manager_parameter(self): + """ + Test forward() method with mhc_recompute_manager parameter. + """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + + module = self._create_hyper_connection_module(hidden_size, num_streams) + + # Create input tensors + hidden_states = torch.randn( + seq_len, batch_size, num_streams * hidden_size, device='cuda', requires_grad=True + ) + + # Clone inputs + hidden_states_ckpt = hidden_states.detach().clone().requires_grad_(True) + + # Reference: forward without manager (uses _forward_normal) + torch.manual_seed(42) + torch.cuda.manual_seed(42) + aggregated_ref, h_res_ref, h_post_ref = module.forward( + hidden_states, mhc_recompute_manager=None + ) + loss_ref = aggregated_ref.sum() + h_res_ref.sum() + h_post_ref.sum() + loss_ref.backward() + grad_hidden_ref = hidden_states.grad.clone() + + # With manager (uses _forward_with_checkpoint) + torch.manual_seed(42) + torch.cuda.manual_seed(42) + manager = MHCRecomputeManager() + aggregated_ckpt, h_res_ckpt, h_post_ckpt = module.forward( + hidden_states_ckpt, mhc_recompute_manager=manager + ) + loss_ckpt = aggregated_ckpt.sum() + h_res_ckpt.sum() + h_post_ckpt.sum() + + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + loss_ckpt.backward() + grad_hidden_ckpt = hidden_states_ckpt.grad.clone() + + # Verify gradients match + assert torch.allclose(grad_hidden_ckpt, grad_hidden_ref, atol=1e-5) + + +class TestMHCBlockRecomputeIntegration: + """Test MHCRecomputeManager integration with HyperConnection.""" + + def setup_method(self, method): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def test_multiple_hyper_connections_in_chain(self): + """ + Test that multiple HyperConnectionModules can be chained together + with a single MHCRecomputeManager. 
+ """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + n_channels = num_streams * hidden_size + + # Create multiple HyperConnection modules (simulating multiple layers) + config = TransformerConfig( + num_layers=4, + hidden_size=hidden_size, + num_attention_heads=4, + use_cpu_initialization=True, + enable_hyper_connections=True, + num_residual_streams=num_streams, + mhc_sinkhorn_iterations=5, + mhc_init_gating_factor=0.01, + ) + + modules = [ + HyperConnectionModule(config=config, layer_number=i + 1).cuda() for i in range(3) + ] + + # Create input tensors + hidden_states_ref = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + residual_ref = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + + hidden_states_ckpt = hidden_states_ref.detach().clone().requires_grad_(True) + residual_ckpt = residual_ref.detach().clone().requires_grad_(True) + + # Reference: forward without checkpoint + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + h = hidden_states_ref + r = residual_ref + for module in modules: + agg, h_res, h_post = module.forward(h, mhc_recompute_manager=None) + agg, _ = module.apply_h_post((0.1 * agg, None), h_post, manager=None) + mixed = module.apply_h_res(h_res, r) # Apply h_res to get mixed [s, b, n*C] + h = agg + mixed + r = h + + loss_ref = h.sum() + loss_ref.backward() + grad_hidden_ref = hidden_states_ref.grad.clone() + grad_residual_ref = residual_ref.grad.clone() + + # With checkpoint using single manager + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + manager = MHCRecomputeManager() + + h = hidden_states_ckpt + r = residual_ckpt + for module in modules: + agg, h_res, h_post = module.forward(h, mhc_recompute_manager=manager) + agg, _ = module.apply_h_post((0.1 * agg, None), h_post, manager=manager) + mixed = module.apply_h_res(h_res, r) # Apply h_res to get mixed [s, b, n*C] + h = agg + mixed + r = h + + loss_ckpt = h.sum() + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + loss_ckpt.backward() + + grad_hidden_ckpt = hidden_states_ckpt.grad.clone() + grad_residual_ckpt = residual_ckpt.grad.clone() + + # Verify gradients + assert torch.allclose( + grad_hidden_ckpt, grad_hidden_ref, atol=1e-4 + ), f"Chained HyperConnection hidden gradients mismatch" + assert torch.allclose( + grad_residual_ckpt, grad_residual_ref, atol=1e-4 + ), f"Chained HyperConnection residual gradients mismatch" + + def test_partial_checkpoint_last_layer_not_checkpointed(self): + """ + Test that when is_last_layer_in_block=True, the final output is NOT checkpointed. + This simulates the TransformerBlock behavior where the last layer's MLP BDA + serves as the hook_tensor for unified recompute. 
+ """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + + config = TransformerConfig( + num_layers=2, + hidden_size=hidden_size, + num_attention_heads=4, + use_cpu_initialization=True, + enable_hyper_connections=True, + num_residual_streams=num_streams, + mhc_sinkhorn_iterations=5, + mhc_init_gating_factor=0.01, + ) + + module = HyperConnectionModule(config=config, layer_number=1).cuda() + + hidden_states_ref = torch.randn( + seq_len, batch_size, num_streams * hidden_size, device='cuda', requires_grad=True + ) + residual_ref = torch.randn( + seq_len, batch_size, num_streams * hidden_size, device='cuda', requires_grad=True + ) + + hidden_states_ckpt = hidden_states_ref.detach().clone().requires_grad_(True) + residual_ckpt = residual_ref.detach().clone().requires_grad_(True) + + # Reference + torch.manual_seed(42) + torch.cuda.manual_seed(42) + aggregated_ref, h_res_ref, h_post_ref = module.forward( + hidden_states_ref, mhc_recompute_manager=None + ) + aggregated_ref, _ = module.apply_h_post( + (0.1 * aggregated_ref, None), h_post_ref, manager=None + ) + mixed_ref = module.apply_h_res( + h_res_ref, residual_ref + ) # Apply h_res to get mixed [s, b, n*C] + # Simulate BDA that is NOT checkpointed (last layer) + output_ref = aggregated_ref + 0.5 * mixed_ref + loss_ref = output_ref.sum() + loss_ref.backward() + grad_hidden_ref = hidden_states_ref.grad.clone() + + # With manager - checkpoint everything except final output + torch.manual_seed(42) + torch.cuda.manual_seed(42) + manager = MHCRecomputeManager() + aggregated_ckpt, h_res_ckpt, h_post_ckpt = module.forward( + hidden_states_ckpt, mhc_recompute_manager=manager + ) + + aggregated_ckpt, _ = module.apply_h_post( + (0.1 * aggregated_ckpt, None), h_post_ckpt, manager=manager + ) + mixed_ckpt = module.apply_h_res( + h_res_ckpt, residual_ckpt + ) # Apply h_res to get mixed [s, b, n*C] + # Simulate BDA that is NOT checkpointed (last layer) - this is the hook_tensor + output_ckpt = aggregated_ckpt + 0.5 * mixed_ckpt + + # Register unified recompute on the output (which is not checkpointed) + manager.discard_all_outputs_and_register_unified_recompute(output_ckpt) + + loss_ckpt = output_ckpt.sum() + loss_ckpt.backward() + grad_hidden_ckpt = hidden_states_ckpt.grad.clone() + + # Verify gradients match + assert torch.allclose(grad_hidden_ckpt, grad_hidden_ref, atol=1e-5) + + +class TestTransformerConfigRecomputeMhc: + """Test 'mhc' in recompute_modules configuration.""" + + def test_config_default_value(self): + """Test that 'mhc' is not in recompute_modules by default.""" + config = TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4) + assert "mhc" not in config.recompute_modules + + def test_config_enable_mhc_recompute(self): + """Test enabling 'mhc' in recompute_modules.""" + config = TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + recompute_modules=["core_attn", "mhc"], + recompute_granularity='selective', + ) + assert "mhc" in config.recompute_modules + assert config.enable_hyper_connections is True + + def test_config_rejects_pipeline_parallel_hyper_connections(self): + """Pipeline-parallel tensor shapes do not support n-stream hidden states yet.""" + with pytest.raises( + ValueError, + match="enable_hyper_connections is not yet compatible with pipeline parallelism", + ): + TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + 
pipeline_model_parallel_size=2, + pipeline_dtype=torch.float32, + ) + + def test_config_rejects_fused_tp_inference_hyper_connections(self): + """mHC does not implement the fused TP inference residual path.""" + with pytest.raises( + ValueError, + match="enable_hyper_connections is not compatible with " + "inference_fuse_tp_communication", + ): + TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + inference_fuse_tp_communication=True, + ) + + def test_config_rejects_moe_router_cuda_graph_hyper_connections(self): + """mHC MoE layers do not implement the TE moe_router CUDA graph path yet.""" + with pytest.raises( + ValueError, + match="enable_hyper_connections is not yet compatible with " + "MoE router CUDA graphs", + ): + TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + num_moe_experts=4, + cuda_graph_impl="transformer_engine", + cuda_graph_scope=["moe_router"], + ) + + def test_hyper_connection_recompute_warning_requires_recompute(self): + """Do not warn about missing 'mhc' recompute when recompute is disabled.""" + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + ) + + assert not any("HyperConnections are enabled" in str(w.message) for w in caught) + + def test_hyper_connection_recompute_warning_for_selective_without_mhc(self): + """Still warn when selective recompute is on but 'mhc' is omitted.""" + with pytest.warns(UserWarning, match="HyperConnections are enabled"): + TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + recompute_granularity="selective", + ) + + def test_config_rejects_full_recompute_hyper_connections(self): + """Full activation recompute is not wired for hyper-connection blocks yet.""" + with pytest.raises( + ValueError, + match="enable_hyper_connections is not yet compatible with full activation recompute", + ): + TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + recompute_granularity="full", + recompute_method="block", + recompute_num_layers=1, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit_tests/transformer/test_mhc_block_manager.py b/tests/unit_tests/transformer/test_mhc_block_manager.py new file mode 100644 index 00000000000..f3d09695169 --- /dev/null +++ b/tests/unit_tests/transformer/test_mhc_block_manager.py @@ -0,0 +1,413 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
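+#
+# Manager usage pattern exercised throughout this file (illustrative sketch):
+#     manager = MHCRecomputeManager()
+#     y = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(fn, x)
+#     loss = y.sum()
+#     manager.discard_all_outputs_and_register_unified_recompute(loss)
+#     loss.backward()  # registered checkpoints are recomputed together in backward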
+ +import pytest +import torch + +from megatron.core.tensor_parallel.random import ( + MHCRecomputeManager, + CheckpointWithoutOutput, + initialize_rng_tracker, +) +from tests.unit_tests.test_utilities import Utils + + +class TestCheckpointWithoutOutputManagerAPI: + """Test CheckpointWithoutOutput integration with MHCRecomputeManager.""" + + def setup_method(self, method): + Utils.initialize_model_parallel() + initialize_rng_tracker(force_reset=True) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def test_auto_register(self): + """CheckpointWithoutOutput auto-registers to manager when ckpt_manager is provided.""" + manager = MHCRecomputeManager() + + def func(x): + return x * 2 + 1 + + input_t = torch.randn(4, 4, device='cuda', requires_grad=True) + + ckpt = CheckpointWithoutOutput(ckpt_manager=manager) + y = ckpt.checkpoint(func, input_t) + + assert len(manager.checkpoints) == 1 + assert manager.checkpoints[0] is ckpt + + ckpt2 = CheckpointWithoutOutput(ckpt_manager=manager) + y2 = ckpt2.checkpoint(torch.nn.functional.gelu, y) + + assert len(manager.checkpoints) == 2 + assert manager.checkpoints[1] is ckpt2 + + loss = y2.sum() + manager.discard_all_outputs_and_register_unified_recompute(loss) + loss.backward() + + assert input_t.grad is not None + + def test_discard_is_noop_with_manager(self): + """discard_output_and_register_recompute is a NO-OP when ckpt_manager is set.""" + manager = MHCRecomputeManager() + + def func1(x): + return x * 2 + + def func2(x): + return torch.nn.functional.gelu(x) + + input_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + y1_ref = func1(input_ref) + y2_ref = func2(y1_ref) + loss_ref = y2_ref.sum() + loss_ref.backward() + grad_ref = input_ref.grad.clone() + + input_ckpt = input_ref.detach().clone().requires_grad_(True) + + ckpt1 = CheckpointWithoutOutput(ckpt_manager=manager) + y1 = ckpt1.checkpoint(func1, input_ckpt) + ckpt1.discard_output_and_register_recompute(y1) + + ckpt2 = CheckpointWithoutOutput(ckpt_manager=manager) + y2 = ckpt2.checkpoint(func2, y1) + ckpt2.discard_output_and_register_recompute(y2) + + assert y1.untyped_storage().size() > 0, "y1 should NOT be discarded yet" + assert y2.untyped_storage().size() > 0, "y2 should NOT be discarded yet" + + loss_ckpt = y2.sum() + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + assert y1.untyped_storage().size() == 0, "y1 should be discarded after manager call" + assert y2.untyped_storage().size() == 0, "y2 should be discarded after manager call" + + loss_ckpt.backward() + grad_ckpt = input_ckpt.grad.clone() + + assert torch.allclose(grad_ckpt, grad_ref, atol=1e-6) + + def test_backward_compat_without_manager(self): + """CheckpointWithoutOutput without ckpt_manager should work exactly as before.""" + + def func(x): + return torch.nn.functional.gelu(x) + + input_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + y_ref = func(input_ref) + z_ref = y_ref * 2 + loss_ref = z_ref.sum() + loss_ref.backward() + grad_ref = input_ref.grad.clone() + + input_ckpt = input_ref.detach().clone().requires_grad_(True) + + ckpt = CheckpointWithoutOutput() + y = ckpt.checkpoint(func, input_ckpt) + z = y * 2 + ckpt.discard_output_and_register_recompute(z) + + assert y.untyped_storage().size() == 0 + + loss_ckpt = z.sum() + loss_ckpt.backward() + grad_ckpt = input_ckpt.grad.clone() + + assert torch.allclose(grad_ckpt, grad_ref, atol=1e-6) + + def test_error_handling(self): + """MHCRecomputeManager rejects invalid add_checkpoint calls.""" + manager = 
MHCRecomputeManager() + + with pytest.raises(TypeError): + manager.add_checkpoint("not a checkpoint") + + ckpt = CheckpointWithoutOutput() + with pytest.raises(ValueError): + manager.add_checkpoint(ckpt) + + def test_unified_recompute_keeps_outputs_when_hook_has_no_grad(self): + """Do not discard outputs if no hook can be registered for recompute.""" + manager = MHCRecomputeManager() + + def func(x): + return x * 2 + + input_t = torch.randn(4, 4, device='cuda', requires_grad=True) + ckpt = CheckpointWithoutOutput(ckpt_manager=manager) + y = ckpt.checkpoint(func, input_t) + + hook_tensor = torch.zeros((), device='cuda', requires_grad=False) + manager.discard_all_outputs_and_register_unified_recompute(hook_tensor) + + assert y.untyped_storage().size() > 0 + + +class TestMHCRecomputeManagerSequentialChain: + """Test MHCRecomputeManager with sequential checkpoint chains.""" + + def setup_method(self, method): + Utils.initialize_model_parallel() + initialize_rng_tracker(force_reset=True) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def test_basic_sequential_chain(self): + """Three sequential checkpoints: gradients match non-checkpointed version.""" + + def func1(x): + return x * 2 + 1 + + def func2(x): + return torch.nn.functional.gelu(x) + + def func3(x): + return x * x + x + + input_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + input_ckpt = input_ref.detach().clone().requires_grad_(True) + + y1_ref = func1(input_ref) + y2_ref = func2(y1_ref) + y3_ref = func3(y2_ref) + loss_ref = y3_ref.sum() + loss_ref.backward() + grad_ref = input_ref.grad.clone() + + manager = MHCRecomputeManager() + + y1 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func1, input_ckpt) + y2 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func2, y1) + y3 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func3, y2) + + loss_ckpt = y3.sum() + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + assert y1.untyped_storage().size() == 0, "y1 storage should be released" + assert y2.untyped_storage().size() == 0, "y2 storage should be released" + assert y3.untyped_storage().size() == 0, "y3 storage should be released" + + loss_ckpt.backward() + grad_ckpt = input_ckpt.grad.clone() + + assert torch.allclose( + grad_ckpt, grad_ref, atol=1e-6 + ), f"Gradients mismatch!\nWith manager: {grad_ckpt}\nReference: {grad_ref}" + + def test_sequential_chain_with_dropout(self): + """RNG state is restored during recompute so dropout gradients match.""" + + def func_with_dropout(x): + return torch.nn.functional.dropout(x, p=0.3, training=True) + + def func2(x): + return torch.nn.functional.gelu(x) + + input_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + input_ckpt = input_ref.detach().clone().requires_grad_(True) + + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + y1_ref = func_with_dropout(input_ref) + y2_ref = func2(y1_ref) + loss_ref = y2_ref.sum() + loss_ref.backward() + grad_ref = input_ref.grad.clone() + + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + manager = MHCRecomputeManager() + + y1 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func_with_dropout, input_ckpt) + y2 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func2, y1) + + loss_ckpt = y2.sum() + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + loss_ckpt.backward() + grad_ckpt = input_ckpt.grad.clone() + + assert torch.allclose( + grad_ckpt, grad_ref, atol=1e-6 + ), f"Gradients with dropout mismatch!\nWith 
manager: {grad_ckpt}\nReference: {grad_ref}" + + def test_multiple_outputs(self): + """MHCRecomputeManager handles functions that return multiple outputs.""" + + def func_multi_output(x): + return x * 2, x + 1 + + def func_combine(a, b): + return a + b + + input_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + input_ckpt = input_ref.detach().clone().requires_grad_(True) + + y1a_ref, y1b_ref = func_multi_output(input_ref) + y2_ref = func_combine(y1a_ref, y1b_ref) + loss_ref = y2_ref.sum() + loss_ref.backward() + grad_ref = input_ref.grad.clone() + + manager = MHCRecomputeManager() + + y1a, y1b = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint( + func_multi_output, input_ckpt + ) + y2 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func_combine, y1a, y1b) + + loss_ckpt = y2.sum() + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + loss_ckpt.backward() + grad_ckpt = input_ckpt.grad.clone() + + assert torch.allclose(grad_ckpt, grad_ref, atol=1e-6), ( + f"Gradients mismatch with multiple outputs!\n" + f"With manager: {grad_ckpt}\nReference: {grad_ref}" + ) + + +class TestMHCRecomputeManagerPartialCheckpoint: + """Test MHCRecomputeManager with partial checkpointing (some ops not checkpointed).""" + + def setup_method(self, method): + Utils.initialize_model_parallel() + initialize_rng_tracker(force_reset=True) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def test_partial_checkpoint(self): + """ + Only f and h are checkpointed; g is a regular operation. + + Computation chain: + a --[f]--> b --[g]--> c --[h]--> d --[sum]--> loss + """ + + def func_f(x): + return torch.nn.functional.gelu(x * 2 + 1) + + def func_g(x): + return x * 3 - 2 + + def func_h(x): + return torch.sigmoid(x) + x + + input_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + + b_ref = func_f(input_ref) + c_ref = func_g(b_ref) + d_ref = func_h(c_ref) + loss_ref = d_ref.sum() + loss_ref.backward() + grad_ref = input_ref.grad.clone() + + input_ckpt = input_ref.detach().clone().requires_grad_(True) + + manager = MHCRecomputeManager() + + b = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func_f, input_ckpt) + c = func_g(b) + d = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func_h, c) + + loss_ckpt = d.sum() + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + assert b.untyped_storage().size() == 0, "b storage should be released" + assert d.untyped_storage().size() == 0, "d storage should be released" + assert c.untyped_storage().size() > 0, "c storage should NOT be released (not checkpointed)" + + loss_ckpt.backward() + grad_ckpt = input_ckpt.grad.clone() + + assert torch.allclose(grad_ckpt, grad_ref, atol=1e-6), ( + f"Gradients mismatch with partial checkpoint!\n" + f"With manager: {grad_ckpt}\nReference: {grad_ref}" + ) + + def test_partial_checkpoint_with_tuple_output(self): + """ + Mimics HyperConnection's computation pattern with tuple outputs. 
+ + - compute_mappings: checkpointed, returns tuple (h_pre, h_post, h_res) + - aggregate: NOT checkpointed + - apply_h_res: checkpointed + - apply_h_post: checkpointed + """ + + def compute_mappings(x): + h_pre = torch.sigmoid(x.mean(dim=-1, keepdim=True).expand_as(x)) + h_post = torch.tanh(x.sum(dim=-1, keepdim=True).expand_as(x)) + h_res = torch.relu(x) + return h_pre, h_post, h_res + + def aggregate(x, h_pre): + return x * h_pre + + def apply_h_res(h_res, residual): + return h_res + residual * 0.5 + + def apply_h_post(y, h_post): + return y * h_post + y + + x_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + residual_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + + h_pre_ref, h_post_ref, h_res_ref = compute_mappings(x_ref) + agg_ref = aggregate(x_ref, h_pre_ref) + y_ref = torch.nn.functional.gelu(agg_ref) + mixed_ref = apply_h_res(h_res_ref, residual_ref) + output_ref = apply_h_post(y_ref, h_post_ref) + final_ref = output_ref + mixed_ref + loss_ref = final_ref.sum() + loss_ref.backward() + grad_x_ref = x_ref.grad.clone() + grad_residual_ref = residual_ref.grad.clone() + + x_ckpt = x_ref.detach().clone().requires_grad_(True) + residual_ckpt = residual_ref.detach().clone().requires_grad_(True) + + manager = MHCRecomputeManager() + + h_pre, h_post, h_res = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint( + compute_mappings, x_ckpt + ) + agg = aggregate(x_ckpt, h_pre) + y = torch.nn.functional.gelu(agg) + mixed = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint( + apply_h_res, h_res, residual_ckpt + ) + output = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(apply_h_post, y, h_post) + + final = output + mixed + loss_ckpt = final.sum() + + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + assert h_pre.untyped_storage().size() == 0, "h_pre storage should be released" + assert h_post.untyped_storage().size() == 0, "h_post storage should be released" + assert h_res.untyped_storage().size() == 0, "h_res storage should be released" + assert mixed.untyped_storage().size() == 0, "mixed storage should be released" + assert output.untyped_storage().size() == 0, "output storage should be released" + + assert agg.untyped_storage().size() > 0, "agg storage should NOT be released" + assert y.untyped_storage().size() > 0, "y storage should NOT be released" + + loss_ckpt.backward() + grad_x_ckpt = x_ckpt.grad.clone() + grad_residual_ckpt = residual_ckpt.grad.clone() + + assert torch.allclose( + grad_x_ckpt, grad_x_ref, atol=1e-6 + ), f"Gradients for x mismatch!\nWith manager: {grad_x_ckpt}\nReference: {grad_x_ref}" + assert torch.allclose(grad_residual_ckpt, grad_residual_ref, atol=1e-6), ( + f"Gradients for residual mismatch!\n" + f"With manager: {grad_residual_ckpt}\nReference: {grad_residual_ref}" + ) diff --git a/tests/unit_tests/transformer/test_transformer_layer.py b/tests/unit_tests/transformer/test_transformer_layer.py index da1f9ce5860..08c830f66e4 100644 --- a/tests/unit_tests/transformer/test_transformer_layer.py +++ b/tests/unit_tests/transformer/test_transformer_layer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
import pytest @@ -8,17 +8,42 @@ from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensor from megatron.core.inference.contexts import StaticInferenceContext from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_with_transformer_engine_spec, get_gpt_layer_with_transformer_engine_submodules, ) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.tensor_parallel.random import MHCRecomputeManager, model_parallel_cuda_manual_seed +from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import ( + HyperConnectionTransformerLayer, TransformerLayer, get_transformer_layer_offset, ) from tests.unit_tests.test_utilities import Utils +def _make_mhc_config(hidden_size=64, num_streams=4, **extra): + """Build a TransformerConfig with common MHC defaults. + + Any default can be overridden via **extra + (e.g. ``_make_mhc_config(num_layers=8, recompute_modules=["core_attn", "mhc"])``). + """ + base = dict( + num_layers=2, + hidden_size=hidden_size, + num_attention_heads=4, + use_cpu_initialization=True, + enable_hyper_connections=True, + num_residual_streams=num_streams, + mhc_sinkhorn_iterations=5, + mhc_init_gating_factor=0.01, + hidden_dropout=0.0, + attention_dropout=0.0, + ) + base.update(extra) + return TransformerConfig(**base) + + class TestParallelTransformerLayer: def setup_method(self, method): @@ -62,6 +87,30 @@ def test_gpu_forward(self): assert hidden_states.shape[1] == micro_batch_size assert hidden_states.shape[2] == config.hidden_size + def test_gpu_forward_ignores_mhc_recompute_manager_kwarg(self): + """Non-HC layers must tolerate TransformerBlock's mHC recompute kwarg.""" + parallel_transformer_layer = self.parallel_transformer_layer + config: TransformerConfig = parallel_transformer_layer.config + sequence_length = 8 + micro_batch_size = 2 + parallel_transformer_layer.cuda() + + hidden_states = torch.ones( + (sequence_length, micro_batch_size, config.hidden_size), device='cuda' + ) + attention_mask = torch.ones( + (1, 1, sequence_length, sequence_length), dtype=bool, device='cuda' + ) + + hidden_states, context = parallel_transformer_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + mhc_recompute_manager=None, + ) + + assert context is None + assert hidden_states.shape == (sequence_length, micro_batch_size, config.hidden_size) + def test_chunked_mlp(self): with torch.no_grad(): @@ -313,3 +362,800 @@ def get_tensor_shapes_for_tp(transformer_config, tp_size): 'self_attention.linear_qkv.weight': (hs * 3 // tp_size, hs), 'self_attention.linear_qkv.bias': (hs * 3 // tp_size,), } + + +class TestTransformerLayerWithHyperConnectionRecompute: + """Test TransformerLayer with HyperConnection and MHC block recomputation.""" + + def setup_method(self, method): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _create_layer_with_hyper_connection( + self, hidden_size=64, num_streams=4, layer_number=1, **extra + ): + """Create a HyperConnectionTransformerLayer with hyper connection enabled.""" + config = _make_mhc_config( + hidden_size=hidden_size, + num_streams=num_streams, + recompute_modules=["core_attn", "mhc"], + recompute_granularity='selective', + **extra, + ) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connections=True) + 
layer = HyperConnectionTransformerLayer( + config, layer_spec.submodules, layer_number=layer_number + ) + layer.cuda() + return layer, config + + def test_rejects_cross_attention_with_hyper_connection(self): + """mHC transformer layers are decoder-only until cross-attention is implemented.""" + config = _make_mhc_config() + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connections=True) + layer_spec.submodules.cross_attention = torch.nn.Linear + + with pytest.raises(ValueError, match="does not support cross-attention"): + HyperConnectionTransformerLayer(config, layer_spec.submodules) + + layer_spec.submodules.cross_attention = IdentityOp + + def test_layernorm_tuple_outputs_are_unpacked(self): + """HC layers should mirror base-layer tuple handling for fused layernorms.""" + + class TupleLayerNorm(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states, hidden_states + + layer, config = self._create_layer_with_hyper_connection() + layer.input_layernorm = TupleLayerNorm().cuda() + layer.pre_mlp_layernorm = TupleLayerNorm().cuda() + layer.train() + + seq_len = 8 + batch_size = 2 + hidden_states = torch.randn( + seq_len, + batch_size, + config.num_residual_streams * config.hidden_size, + device='cuda', + requires_grad=True, + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + output, context = layer(hidden_states=hidden_states, attention_mask=attention_mask) + + assert context is None + assert output.shape == hidden_states.shape + + def test_forward_with_hyper_connection_recompute(self): + """ + Test that TransformerLayer forward works correctly with HyperConnection + and MHC block recomputation enabled. + """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + + layer, config = self._create_layer_with_hyper_connection(hidden_size, num_streams) + layer.train() # Enable training mode for recomputation + + # Input shape: [seq_len, batch_size, n * hidden_size] for hyper connections + n_channels = num_streams * hidden_size + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + # Create manager for MHC block recomputation + manager = MHCRecomputeManager() + + # Forward pass with recompute manager + manager.is_last_layer_in_recompute_block = True + output, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + mhc_recompute_manager=manager, + ) + + # Verify output shape + assert output.shape == ( + seq_len, + batch_size, + n_channels, + ), f"Expected output shape {(seq_len, batch_size, n_channels)}, got {output.shape}" + + # Register unified recompute hook at block boundary. + manager.discard_all_outputs_and_register_unified_recompute(output) + + # Backward pass should work without error + loss = output.sum() + loss.backward() + + # Verify gradients exist + assert hidden_states.grad is not None, "Gradients should be computed for hidden_states" + assert hidden_states.grad.shape == hidden_states.shape + + def test_intermediate_layer_with_recompute(self): + """ + Test TransformerLayer as an intermediate layer (not last in block). + In this case, MLP BDA should also be checkpointed. 
+ """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + + layer, config = self._create_layer_with_hyper_connection(hidden_size, num_streams) + layer.train() + + n_channels = num_streams * hidden_size + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + manager = MHCRecomputeManager() + + # Forward pass - NOT the last layer in block + manager.is_last_layer_in_recompute_block = False + output, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + mhc_recompute_manager=manager, + ) + + # Verify output shape + assert output.shape == (seq_len, batch_size, n_channels) + + # Backward pass should work + loss = output.sum() + # For intermediate layers, we need to pass output to next layer + # Here we just register the recompute hook on output for testing + manager.discard_all_outputs_and_register_unified_recompute(loss) + + loss.backward() + + assert hidden_states.grad is not None + assert hidden_states.grad.shape == hidden_states.shape + + def test_multiple_layers_chain_with_recompute(self): + """ + Test multiple TransformerLayers chained together with a single + MHCRecomputeManager, simulating TransformerBlock behavior. + """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + num_layers = 3 + + layers = [ + self._create_layer_with_hyper_connection( + hidden_size, num_streams, layer_number=i + 1, num_layers=num_layers + )[0] + for i in range(num_layers) + ] + + for layer in layers: + layer.train() + + n_channels = num_streams * hidden_size + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + # Single manager for all layers (like TransformerBlock) + manager = MHCRecomputeManager() + + # Forward through all layers + h = hidden_states + for i, layer in enumerate(layers): + is_last = i == num_layers - 1 + manager.is_last_layer_in_recompute_block = is_last + h, _ = layer( + hidden_states=h, attention_mask=attention_mask, mhc_recompute_manager=manager + ) + if is_last: + manager.discard_all_outputs_and_register_unified_recompute(h) + + # Backward pass + loss = h.sum() + loss.backward() + + # Verify gradients + assert hidden_states.grad is not None + assert hidden_states.grad.shape == hidden_states.shape + # Check that gradient is non-trivial (not all zeros) + assert hidden_states.grad.abs().sum() > 0 + + +class TestMHCRecomputeMemorySaving: + """Verify that 'mhc' in recompute_modules actually reduces peak GPU memory.""" + + def setup_method(self, method): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @staticmethod + def _run_forward_backward( + num_layers, + hidden_size, + num_streams, + seq_len, + batch_size, + use_recompute, + recompute_block_size=2, + ): + """Run a full forward + backward pass and return (peak memory, output grad). + + When use_recompute=True, a new MHCRecomputeManager is created every + `recompute_block_size` layers, mirroring TransformerBlock's + _build_mhc_recompute_layer_plan logic. 
+ """ + config = _make_mhc_config( + hidden_size=hidden_size, + num_streams=num_streams, + num_layers=num_layers, + recompute_modules=["core_attn", "mhc"] if use_recompute else None, + recompute_granularity='selective' if use_recompute else None, + ) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connections=True) + layers = [ + HyperConnectionTransformerLayer( + config, layer_spec.submodules, layer_number=i + 1 + ).cuda() + for i in range(num_layers) + ] + for layer in layers: + layer.train() + + n_channels = num_streams * hidden_size + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + manager = MHCRecomputeManager() if use_recompute else None + + h = hidden_states + for i, layer in enumerate(layers): + is_last_in_block = (i == num_layers - 1) or ((i + 1) % recompute_block_size == 0) + kwargs = dict(hidden_states=h, attention_mask=attention_mask) + if manager is not None: + manager.is_last_layer_in_recompute_block = is_last_in_block + kwargs['mhc_recompute_manager'] = manager + h, _ = layer(**kwargs) + if manager is not None and is_last_in_block: + manager.discard_all_outputs_and_register_unified_recompute(h) + if i < num_layers - 1: + manager = MHCRecomputeManager() + + loss = h.sum() + loss.backward() + torch.cuda.synchronize() + + peak_mem = torch.cuda.max_memory_allocated() + grad = hidden_states.grad.clone() + + del layers, hidden_states, h, loss, manager + torch.cuda.empty_cache() + + return peak_mem, grad + + def test_recompute_reduces_peak_memory(self): + """Peak memory with recompute (block_size=2) should be lower than without.""" + num_layers = 8 + hidden_size = 128 + num_streams = 4 + seq_len = 64 + batch_size = 4 + + peak_no_recompute, _ = self._run_forward_backward( + num_layers, hidden_size, num_streams, seq_len, batch_size, use_recompute=False + ) + peak_recompute, _ = self._run_forward_backward( + num_layers, + hidden_size, + num_streams, + seq_len, + batch_size, + use_recompute=True, + recompute_block_size=2, + ) + + saving_pct = (peak_no_recompute - peak_recompute) / peak_no_recompute * 100 + + assert peak_recompute < peak_no_recompute, ( + f"Recompute should reduce peak memory, but got " + f"no_recompute={peak_no_recompute / 1e6:.1f}MB vs " + f"recompute={peak_recompute / 1e6:.1f}MB " + f"(saving={saving_pct:.1f}%)" + ) + + +class TestMHCWithCudaGraph: + """Test HyperConnectionTransformerLayer compatibility with CUDA graphs. + + CUDA graph capture requires static computation graphs and fixed tensor shapes. + These tests verify that the mHC layer properly supports the CUDA graph interface + defined in GraphableMegatronModule and TransformerLayer. 
+ """ + + def setup_method(self, method): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123, use_cudagraphable_rng=True, force_reset_rng=True) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _create_mhc_layer(self, hidden_size=64, num_streams=4, **extra_config): + config = _make_mhc_config(hidden_size=hidden_size, num_streams=num_streams, **extra_config) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connections=True) + layer = HyperConnectionTransformerLayer(config, layer_spec.submodules) + layer.cuda() + return layer, config + + def test_get_layer_static_inputs_shape_for_mhc(self): + """get_layer_static_inputs must return [s, b, n*C] for mHC layers. + + CUDA graph capture creates static buffers whose shapes are determined by + this method. If the shape is [s, b, C] instead of [s, b, n*C], the graph + capture will produce a shape mismatch at the first hyper connection module. + """ + layer, config = self._create_mhc_layer() + seq_length = 32 + micro_batch_size = 2 + + static_inputs = layer.get_layer_static_inputs(seq_length, micro_batch_size) + hidden_states = static_inputs["hidden_states"] + + expected_hidden_dim = config.num_residual_streams * config.hidden_size + assert hidden_states.shape[-1] == expected_hidden_dim, ( + f"get_layer_static_inputs returns hidden dim {hidden_states.shape[-1]} " + f"but mHC expects {expected_hidden_dim} (n={config.num_residual_streams} * " + f"C={config.hidden_size}). " + f"HyperConnectionTransformerLayer must override get_layer_static_inputs." + ) + + def test_submodules_under_cudagraphs_includes_hyper_connection(self): + """_get_submodules_under_cudagraphs must include hyper connection modules. + + CUDA graph manual hooks are set up for parameters of submodules returned + by this method. Missing hyper connection modules means their parameters + (mapping_proj, alpha_*, bias) will not get proper pre-forward hooks during + graph replay, leading to stale parameter values. + """ + layer, config = self._create_mhc_layer() + + submodules = layer._get_submodules_under_cudagraphs() + + hc_modules_found = any( + hasattr(m, 'mapping_proj') for submod in submodules for m in submod.modules() + ) + assert hc_modules_found, ( + "_get_submodules_under_cudagraphs does not include HyperConnectionModule. " + "Parameters like mapping_proj, alpha_pre/post/res will not be updated " + "during CUDA graph replay." + ) + + def test_forward_through_te_cuda_graph_capture_path(self): + """_te_cuda_graph_capture must produce correct output shapes for mHC. + + TE CUDA graph capture calls _te_cuda_graph_capture() during warmup. + For mHC layers, the input must be n-stream [s, b, n*C] and output must + also be [s, b, n*C]. 
+ """ + layer, config = self._create_mhc_layer() + layer.eval() + + seq_len = 8 + batch_size = 2 + n_channels = config.num_residual_streams * config.hidden_size + + hidden_states = torch.randn(seq_len, batch_size, n_channels, device='cuda') + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + with torch.no_grad(): + outputs = layer._te_cuda_graph_capture( + hidden_states=hidden_states, attention_mask=attention_mask + ) + + if isinstance(outputs, tuple): + output = outputs[0] + else: + output = outputs + + assert output.shape == (seq_len, batch_size, n_channels), ( + f"_te_cuda_graph_capture output shape {output.shape} != " + f"expected {(seq_len, batch_size, n_channels)}" + ) + + def test_cuda_graph_fwd_bwd_with_hyper_connection(self): + """End-to-end CUDA graph capture and replay for forward+backward with mHC. + + Captures both the forward and backward pass of HyperConnectionTransformerLayer + into a torch.cuda.CUDAGraph and replays it with fresh input data, verifying + that the computation graph is fully static (capturable) and produces correct + output shapes and non-trivial gradients. + """ + layer, config = self._create_mhc_layer() + layer.train() + + seq_len = 8 + batch_size = 2 + n_channels = config.num_residual_streams * config.hidden_size + + static_input = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + # Warmup on side stream to trigger lazy allocations + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + out, _ = layer(hidden_states=static_input, attention_mask=attention_mask) + out.sum().backward() + torch.cuda.current_stream().wait_stream(s) + + # Set .grad to None so backward allocates fresh gradient tensors in the + # graph's private memory pool during capture. + layer.zero_grad(set_to_none=True) + static_input.grad = None + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output, _ = layer(hidden_states=static_input, attention_mask=attention_mask) + output.sum().backward() + + # Replay with new input data. + # Use no_grad because backward inside the captured graph already + # bumped the autograd version counter on static_input, making + # in-place copy_ illegal without disabling grad tracking. + with torch.no_grad(): + static_input.copy_(torch.randn_like(static_input)) + g.replay() + + assert output.shape == ( + seq_len, + batch_size, + n_channels, + ), f"Output shape {output.shape} != expected {(seq_len, batch_size, n_channels)}" + assert ( + static_input.grad is not None + ), "Gradients should be computed for static_input after graph replay" + assert static_input.grad.shape == static_input.shape + assert static_input.grad.abs().sum() > 0, "Gradients should be non-trivial" + + # Verify numerical consistency: graph replay should match eager execution + # with the same input and weights. 
+ test_data = torch.randn(seq_len, batch_size, n_channels, device='cuda') + + with torch.no_grad(): + static_input.copy_(test_data) + g.replay() + graph_out = output.detach().clone() + graph_grad = static_input.grad.detach().clone() + + eager_input = test_data.clone().requires_grad_(True) + eager_output, _ = layer(hidden_states=eager_input, attention_mask=attention_mask) + eager_output.sum().backward() + + assert torch.allclose(graph_out, eager_output.detach(), atol=1e-5), ( + f"Graph vs eager output mismatch: " + f"max diff = {(graph_out - eager_output.detach()).abs().max().item()}" + ) + assert torch.allclose(graph_grad, eager_input.grad, atol=1e-5), ( + f"Graph vs eager gradient mismatch: " + f"max diff = {(graph_grad - eager_input.grad).abs().max().item()}" + ) + + def test_cuda_graph_fwd_bwd_with_hyper_connection_and_recompute(self): + """CUDA graph capture+replay for fwd+bwd with mHC and MHCRecomputeManager. + + When a MHCRecomputeManager is used, additional CheckpointWithoutOutput + objects are created for layernorm and hyper-connection operations. The + manager discards intermediate activations during forward (storage.resize_(0)) + and recomputes them during backward via a unified gradient hook. + This test verifies the full capture+replay still works correctly. + """ + layer, config = self._create_mhc_layer() + layer.train() + + seq_len = 8 + batch_size = 2 + n_channels = config.num_residual_streams * config.hidden_size + + static_input = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + # Warmup on side stream; fresh manager per iteration to avoid stale state. + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + mgr = MHCRecomputeManager() + mgr.is_last_layer_in_recompute_block = True + out, _ = layer( + hidden_states=static_input, + attention_mask=attention_mask, + mhc_recompute_manager=mgr, + ) + mgr.discard_all_outputs_and_register_unified_recompute(out) + out.sum().backward() + torch.cuda.current_stream().wait_stream(s) + + layer.zero_grad(set_to_none=True) + static_input.grad = None + + capture_mgr = MHCRecomputeManager() + capture_mgr.is_last_layer_in_recompute_block = True + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output, _ = layer( + hidden_states=static_input, + attention_mask=attention_mask, + mhc_recompute_manager=capture_mgr, + ) + capture_mgr.discard_all_outputs_and_register_unified_recompute(output) + output.sum().backward() + + # Replay with new input data. + with torch.no_grad(): + static_input.copy_(torch.randn_like(static_input)) + g.replay() + + assert output.shape == ( + seq_len, + batch_size, + n_channels, + ), f"Output shape {output.shape} != expected {(seq_len, batch_size, n_channels)}" + assert ( + static_input.grad is not None + ), "Gradients should be computed for static_input after graph replay" + assert static_input.grad.shape == static_input.shape + assert static_input.grad.abs().sum() > 0, "Gradients should be non-trivial" + + # Numerical consistency: graph replay vs eager with the same input. 
+ test_data = torch.randn(seq_len, batch_size, n_channels, device='cuda') + + with torch.no_grad(): + static_input.copy_(test_data) + g.replay() + graph_out = output.detach().clone() + graph_grad = static_input.grad.detach().clone() + + eager_mgr = MHCRecomputeManager() + eager_mgr.is_last_layer_in_recompute_block = True + eager_input = test_data.clone().requires_grad_(True) + eager_output, _ = layer( + hidden_states=eager_input, + attention_mask=attention_mask, + mhc_recompute_manager=eager_mgr, + ) + eager_mgr.discard_all_outputs_and_register_unified_recompute(eager_output) + eager_output.sum().backward() + + assert torch.allclose(graph_out, eager_output.detach(), atol=1e-5), ( + f"Graph vs eager output mismatch: " + f"max diff = {(graph_out - eager_output.detach()).abs().max().item()}" + ) + assert torch.allclose(graph_grad, eager_input.grad, atol=1e-5), ( + f"Graph vs eager gradient mismatch: " + f"max diff = {(graph_grad - eager_input.grad).abs().max().item()}" + ) + + def test_mcore_cudagraph_manager_with_mhc_recompute_manager(self): + """MCore CudaGraphManager must not crash on mhc_recompute_manager kwarg. + + When cuda_graph_impl="local" is set, TransformerLayer.__call__ routes + through MegatronModule.__call__ → CudaGraphManager.__call__, which + iterates over all kwargs to check supported types. MHCRecomputeManager + (used by mhc_recompute_manager) is not a CUDA-graph-supported type. + + This test verifies that mhc_recompute_manager is properly extracted + from kwargs before the CudaGraphManager sees them, preventing the + AssertionError that would otherwise occur. + """ + layer, config = self._create_mhc_layer(cuda_graph_impl="local", cuda_graph_scope="attn") + layer.train() + + assert hasattr( + layer, 'cudagraph_manager' + ), "Layer should have cudagraph_manager with cuda_graph_impl='local'" + + seq_len = 8 + batch_size = 2 + n_channels = config.num_residual_streams * config.hidden_size + + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + mgr = MHCRecomputeManager() + mgr.is_last_layer_in_recompute_block = True + + output, context = layer( + hidden_states=hidden_states, attention_mask=attention_mask, mhc_recompute_manager=mgr + ) + + assert output.shape == (seq_len, batch_size, n_channels) + + def test_mcore_cudagraph_manager_without_mhc_recompute_manager(self): + """MCore CudaGraphManager path works when mhc_recompute_manager is None.""" + layer, config = self._create_mhc_layer(cuda_graph_impl="local", cuda_graph_scope="attn") + layer.train() + + seq_len = 8 + batch_size = 2 + n_channels = config.num_residual_streams * config.hidden_size + + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + output, context = layer(hidden_states=hidden_states, attention_mask=attention_mask) + + assert output.shape == (seq_len, batch_size, n_channels) + + +class TestMHCWithOffloading: + """Test HyperConnectionTransformerLayer with fine-grained activation offloading. + + Fine-grained activation offloading transfers specific activations (e.g., layernorm + inputs) to CPU during forward and reloads them during backward. These tests verify + that the mHC layer's multi-stream architecture works correctly with offloading. 
+ """ + + def setup_method(self, method): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _create_mhc_layer_with_offloading( + self, hidden_size=64, num_streams=4, offload_modules=None + ): + if offload_modules is None: + offload_modules = ["attn_norm", "mlp_norm"] + + config = _make_mhc_config( + hidden_size=hidden_size, + num_streams=num_streams, + fine_grained_activation_offloading=True, + offload_modules=offload_modules, + ) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connections=True) + layer = HyperConnectionTransformerLayer(config, layer_spec.submodules) + layer.cuda() + return layer, config + + def test_forward_backward_with_offloading(self): + """Forward+backward should work with activation offloading enabled. + + This exercises the off_interface context manager around layernorms in + the mHC forward path, including the group_commit that commits the + offloading group for the aggregated 1-stream layernorm input. + """ + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + PipelineOffloadManager, + ) + + layer, config = self._create_mhc_layer_with_offloading() + layer.train() + + seq_len = 8 + batch_size = 2 + n_channels = config.num_residual_streams * config.hidden_size + + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + mgr = PipelineOffloadManager.get_instance() + mgr.init_model_chunk_offload_handler(vp_size=1, vp_stage=0, min_offloaded_tensor_size=0) + + output, context = layer(hidden_states=hidden_states, attention_mask=attention_mask) + + assert output.shape == ( + seq_len, + batch_size, + n_channels, + ), f"Output shape {output.shape} != expected {(seq_len, batch_size, n_channels)}" + + loss = output.sum() + loss.backward() + + assert hidden_states.grad is not None, "Gradients should flow through offloaded path" + assert hidden_states.grad.shape == hidden_states.shape + assert hidden_states.grad.abs().sum() > 0, "Gradients should be non-trivial" + + PipelineOffloadManager.reset_instance() + + def test_offloading_numerical_equivalence(self): + """Offloaded forward+backward must produce the same result as non-offloaded. + + Compares outputs and gradients between a layer with offloading disabled + vs enabled to ensure the offloading path does not corrupt activations. 
+ """ + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + PipelineOffloadManager, + ) + + PipelineOffloadManager.reset_instance() + + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + n_channels = num_streams * hidden_size + + torch.manual_seed(42) + input_data = torch.randn(seq_len, batch_size, n_channels, device='cuda') + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + # Run without offloading + config_no_offload = _make_mhc_config(hidden_size=hidden_size, num_streams=num_streams) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connections=True) + layer_no_offload = HyperConnectionTransformerLayer( + config_no_offload, layer_spec.submodules + ).cuda() + layer_no_offload.train() + + h1 = input_data.clone().detach().requires_grad_(True) + out1, _ = layer_no_offload(hidden_states=h1, attention_mask=attention_mask) + out1.sum().backward() + grad_no_offload = h1.grad.clone() + out1_detached = out1.detach().clone() + + # Run with offloading using the same weights + config_offload = _make_mhc_config( + hidden_size=hidden_size, + num_streams=num_streams, + fine_grained_activation_offloading=True, + offload_modules=["attn_norm", "mlp_norm"], + ) + layer_offload = HyperConnectionTransformerLayer( + config_offload, layer_spec.submodules + ).cuda() + layer_offload.load_state_dict(layer_no_offload.state_dict()) + layer_offload.train() + + mgr = PipelineOffloadManager.get_instance() + mgr.init_model_chunk_offload_handler(vp_size=1, vp_stage=0, min_offloaded_tensor_size=0) + + h2 = input_data.clone().detach().requires_grad_(True) + out2, _ = layer_offload(hidden_states=h2, attention_mask=attention_mask) + out2.sum().backward() + grad_offload = h2.grad.clone() + + PipelineOffloadManager.reset_instance() + + assert torch.allclose(out1_detached, out2.detach(), atol=1e-5), ( + f"Forward outputs differ: max diff = " + f"{(out1_detached - out2.detach()).abs().max().item()}" + ) + assert torch.allclose(grad_no_offload, grad_offload, atol=1e-5), ( + f"Gradients differ: max diff = " + f"{(grad_no_offload - grad_offload).abs().max().item()}" + )