From 37d44f6ee359c5196218790e4c77f0c76f0a32bf Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Mon, 27 Apr 2026 09:15:32 -0700 Subject: [PATCH 01/10] Add mHC transformer reference implementation --- gpt_builders.py | 2 + megatron/core/fusions/fused_bias_dropout.py | 93 ++- ...rimental_attention_variant_module_specs.py | 10 +- megatron/core/models/gpt/gpt_layer_specs.py | 51 +- megatron/core/tensor_parallel/random.py | 163 +++- megatron/core/transformer/__init__.py | 8 +- megatron/core/transformer/cuda_graphs.py | 2 +- megatron/core/transformer/hyper_connection.py | 716 ++++++++++++++++++ .../core/transformer/transformer_block.py | 85 ++- .../core/transformer/transformer_config.py | 116 ++- .../core/transformer/transformer_layer.py | 409 +++++++++- megatron/training/initialize.py | 13 +- 12 files changed, 1625 insertions(+), 43 deletions(-) create mode 100644 megatron/core/transformer/hyper_connection.py diff --git a/gpt_builders.py b/gpt_builders.py index 24b5f89d311..59a8942e472 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -136,6 +136,7 @@ def _get_transformer_layer_spec(use_te, config): use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False), + enable_hyper_connection=config.enable_hyper_connections, ) elif config.transformer_impl == "inference_optimized": return get_gpt_layer_with_inference_spec( @@ -154,4 +155,5 @@ def _get_transformer_layer_spec(use_te, config): use_kitchen=config.use_kitchen, use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, + enable_hyper_connection=config.enable_hyper_connections, ) diff --git a/megatron/core/fusions/fused_bias_dropout.py b/megatron/core/fusions/fused_bias_dropout.py index 2eb4007f75c..1f2448d86be 100644 --- a/megatron/core/fusions/fused_bias_dropout.py +++ b/megatron/core/fusions/fused_bias_dropout.py @@ -1,10 +1,13 @@ -# Copyright (c) 
2023, NVIDIA CORPORATION. All rights reserved. -from typing import Optional, Tuple +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +from typing import TYPE_CHECKING, Optional, Tuple import torch from megatron.core.jit import jit_fuser +if TYPE_CHECKING: + from megatron.core.tensor_parallel.random import CheckpointManager + # pylint: disable=missing-function-docstring @@ -80,7 +83,26 @@ def bias_dropout_add_fused_inference( return _bias_dropout_add_func(x_with_bias, residual, prob, False) -def get_bias_dropout_add(training, fused): +def get_bias_dropout_add( + training, fused, mhc_recompute_manager: Optional['CheckpointManager'] = None +): + """ + Get the bias-dropout-add function. + + Args: + training: Whether in training mode. + fused: Whether to use fused implementation. + mhc_recompute_manager: Optional CheckpointManager for checkpoint management. + When provided, the returned function will wrap the BDA operation with + CheckpointWithoutOutput for memory-efficient recomputation. + + Returns: + A callable that performs bias-dropout-add operation. + """ + if mhc_recompute_manager is not None: + # Return a checkpointed version that handles tuple unpacking internally + return _get_checkpointed_bda(training, fused, mhc_recompute_manager) + if fused: # jit scripting for a nn.module (with dropout) is not # triggering the fusion kernel. For now, we use two @@ -92,3 +114,68 @@ def get_bias_dropout_add(training, fused): return bias_dropout_add_fused_inference else: return bias_dropout_add_unfused(training) + + +def _get_checkpointed_bda(training, fused, mhc_recompute_manager: 'CheckpointManager'): + """ + Create a checkpointed bias-dropout-add function. + + This function handles: + 1. Tuple unpacking for x_with_bias (required because save_for_backward can't save tuples) + 2. Non-tensor arguments like dropout probability (handled by CheckpointWithoutOutput) + 3. 
Auto-registration to the CheckpointManager + + Args: + training: Whether in training mode. + fused: Whether to use fused implementation. + mhc_recompute_manager: CheckpointManager for checkpoint management. + + Returns: + A callable that performs checkpointed bias-dropout-add operation. + """ + from megatron.core.tensor_parallel.random import CheckpointWithoutOutput + + # Get the underlying BDA function + if fused: + if training: + bda_func = bias_dropout_add_fused_train + else: + bda_func = bias_dropout_add_fused_inference + else: + bda_func = bias_dropout_add_unfused(training) + + def _checkpointed_bda(x_with_bias, residual, prob): + """ + Checkpointed BDA that handles tuple unpacking internally. + + Args: + x_with_bias: Either a tuple (x, bias) or a single tensor x. + residual: Residual tensor. + prob: Dropout probability. + + Returns: + Output tensor after bias-dropout-add. + """ + # Create checkpoint with manager + ckpt = CheckpointWithoutOutput(ckpt_manager=mhc_recompute_manager) + + # Handle case where x_with_bias might be a single tensor (e.g., from IdentityOp) + if isinstance(x_with_bias, tuple): + x, bias = x_with_bias + else: + x = x_with_bias + bias = None + + # Wrapper function that re-packs the tuple for the actual BDA function + def _bda_wrapper(output, bias, res, dropout): + return bda_func((output, bias), res, dropout) + + # Call checkpoint with unpacked arguments + result = ckpt.checkpoint(_bda_wrapper, x, bias, residual, prob) + + # No-op when manager is set - manager handles all discarding uniformly + ckpt.discard_output_and_register_recompute(result) + + return result + + return _checkpointed_bda diff --git a/megatron/core/models/gpt/experimental_attention_variant_module_specs.py b/megatron/core/models/gpt/experimental_attention_variant_module_specs.py index 1b03b935639..4385f49ca8c 100644 --- a/megatron/core/models/gpt/experimental_attention_variant_module_specs.py +++ b/megatron/core/models/gpt/experimental_attention_variant_module_specs.py 
@@ -12,6 +12,7 @@ DSAttention, DSAttentionSubmodules, ) +from megatron.core.transformer.hyper_connection import HyperConnectionModule from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.multi_latent_attention import ( MLASelfAttention, @@ -24,6 +25,7 @@ ) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import ( + HyperConnectionTransformerLayer, TransformerLayer, TransformerLayerSubmodules, get_transformer_layer_offset, @@ -227,6 +229,10 @@ def get_transformer_block_with_experimental_attention_variant_spec( # Get GPT decoder block layer specs rms_norm = config.normalization == "RMSNorm" + enable_hc = config.enable_hyper_connections + hc_module = HyperConnectionModule if enable_hc else IdentityOp + layer_module = HyperConnectionTransformerLayer if enable_hc else TransformerLayer + layer_specs = [] for layer_number in range(config.num_layers): attention = ( @@ -248,14 +254,16 @@ def get_transformer_block_with_experimental_attention_variant_spec( layer_specs.append( ModuleSpec( - module=TransformerLayer, + module=layer_module, submodules=TransformerLayerSubmodules( input_layernorm=input_layernorm, self_attention=attention, self_attn_bda=get_bias_dropout_add, + self_attention_hyper_connection=hc_module, pre_mlp_layernorm=pre_mlp_layernorm, mlp=mlp, mlp_bda=get_bias_dropout_add, + mlp_hyper_connection=hc_module, ), ) ) diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 5e90f0b36be..a097e966f68 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -1,4 +1,5 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+import copy import warnings from typing import Optional, Union @@ -12,6 +13,7 @@ from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec_for_backend from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.enums import AttnMaskType, LayerType +from megatron.core.transformer.hyper_connection import HyperConnectionModule from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.multi_latent_attention import ( @@ -34,6 +36,7 @@ ) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import ( + HyperConnectionTransformerLayer, TransformerLayer, TransformerLayerSubmodules, get_transformer_layer_offset, @@ -183,6 +186,7 @@ def get_gpt_layer_with_transformer_engine_submodules( use_kitchen_attention: bool = False, kitchen_attention_backend: str = "sdpa", mla_down_proj_fusion: bool = False, + enable_hyper_connection: bool = False, ) -> TransformerLayerSubmodules: """Use these submodules to use lower-level Transformer Engine modules (required for fp8 training). @@ -200,6 +204,8 @@ def get_gpt_layer_with_transformer_engine_submodules( mla_down_proj_fusion (bool, optional): Enable fused q/kv down-projection and fused input layernorm when backend supports. Otherwise fall back to the unfused MLA. + enable_hyper_connection (bool): Use HyperConnectionTransformerLayer with + HyperConnectionModule instead of plain TransformerLayer. Defaults to False. Returns: TransformerLayerSubmodules: TE modules to construct a TransformerLayer @@ -233,6 +239,8 @@ def get_gpt_layer_with_transformer_engine_submodules( use_te_activation_func=use_te_activation_func, ) + hc_module = HyperConnectionModule if enable_hyper_connection else IdentityOp + if multi_latent_attention: assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." 
linear_q_up_proj = ( @@ -302,9 +310,11 @@ def get_gpt_layer_with_transformer_engine_submodules( ), ), self_attn_bda=get_bias_dropout_add, + self_attention_hyper_connection=hc_module, pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp, mlp=mlp, mlp_bda=get_bias_dropout_add, + mlp_hyper_connection=hc_module, ) else: qk_norm = backend.layer_norm(for_qk=True) @@ -325,9 +335,11 @@ def get_gpt_layer_with_transformer_engine_submodules( ), ), self_attn_bda=get_bias_dropout_add, + self_attention_hyper_connection=hc_module, pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp, mlp=mlp, mlp_bda=get_bias_dropout_add, + mlp_hyper_connection=hc_module, sharded_state_dict_keys_map={ "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", @@ -342,8 +354,10 @@ def get_gpt_layer_with_transformer_engine_submodules( @copy_signature(get_gpt_layer_with_transformer_engine_submodules) def get_gpt_layer_with_transformer_engine_spec(*args, **kwargs) -> ModuleSpec: """Use this spec to use lower-level Transformer Engine modules (required for fp8 training).""" + enable_hc = kwargs.get('enable_hyper_connection', False) + layer_module = HyperConnectionTransformerLayer if enable_hc else TransformerLayer return ModuleSpec( - module=TransformerLayer, + module=layer_module, submodules=get_gpt_layer_with_transformer_engine_submodules(*args, **kwargs), ) @@ -359,6 +373,7 @@ def get_gpt_layer_local_submodules( use_kitchen: bool = False, use_kitchen_attention: bool = False, kitchen_attention_backend: str = "sdpa", + enable_hyper_connection: bool = False, ) -> TransformerLayerSubmodules: """Use these submodules for an implementation using only modules in Megatron-Core. @@ -370,6 +385,8 @@ def get_gpt_layer_local_submodules( multi_latent_attention (bool, optional): To use MLA. Defaults to False. fp8 (str, optional): Deprecated. For temporary Nemo compatibility. 
qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False. + enable_hyper_connection (bool): Use HyperConnectionTransformerLayer with + HyperConnectionModule instead of plain TransformerLayer. Defaults to False. Returns: TransformerLayerSubmodules: Megatron-Core modules to construct a TransformerLayer @@ -402,6 +419,8 @@ def get_gpt_layer_local_submodules( backend=backend, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm ) + hc_module = HyperConnectionModule if enable_hyper_connection else IdentityOp + if multi_latent_attention: assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." return TransformerLayerSubmodules( @@ -422,9 +441,11 @@ def get_gpt_layer_local_submodules( ), ), self_attn_bda=get_bias_dropout_add, + self_attention_hyper_connection=hc_module, pre_mlp_layernorm=layer_norm, mlp=mlp, mlp_bda=get_bias_dropout_add, + mlp_hyper_connection=hc_module, ) else: return TransformerLayerSubmodules( @@ -445,9 +466,11 @@ def get_gpt_layer_local_submodules( ), ), self_attn_bda=get_bias_dropout_add, + self_attention_hyper_connection=hc_module, pre_mlp_layernorm=layer_norm, mlp=mlp, mlp_bda=get_bias_dropout_add, + mlp_hyper_connection=hc_module, sharded_state_dict_keys_map={ "input_layernorm.": "self_attention.linear_qkv.layer_norm_", "pre_mlp_layernorm.": "mlp.linear_fc1.layer_norm_", @@ -458,8 +481,10 @@ def get_gpt_layer_local_submodules( @copy_signature(get_gpt_layer_local_submodules) def get_gpt_layer_local_spec(*args, **kwargs) -> ModuleSpec: """Use this spec for an implementation using only modules in Megatron-Core.""" + enable_hc = kwargs.get('enable_hyper_connection', False) + layer_module = HyperConnectionTransformerLayer if enable_hc else TransformerLayer return ModuleSpec( - module=TransformerLayer, submodules=get_gpt_layer_local_submodules(*args, **kwargs) + module=layer_module, submodules=get_gpt_layer_local_submodules(*args, **kwargs) ) @@ -568,6 +593,7 @@ def get_gpt_decoder_layer_specs( 
use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False), + enable_hyper_connection=config.enable_hyper_connections, ) moe_layer_spec = get_gpt_layer_with_transformer_engine_spec( num_experts=config.num_moe_experts, @@ -580,6 +606,7 @@ def get_gpt_decoder_layer_specs( use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False), + enable_hyper_connection=config.enable_hyper_connections, ) elif config.transformer_impl == "inference_optimized": layer_norm_impl = TENorm @@ -608,6 +635,7 @@ def get_gpt_decoder_layer_specs( use_kitchen=config.use_kitchen, use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, + enable_hyper_connection=config.enable_hyper_connections, ) moe_layer_spec = get_gpt_layer_local_spec( num_experts=config.num_moe_experts, @@ -619,6 +647,7 @@ def get_gpt_decoder_layer_specs( use_kitchen=config.use_kitchen, use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, + enable_hyper_connection=config.enable_hyper_connections, ) # Parse config.moe_layer_freq to determine the pattern of expert/dense layers. 
@@ -744,12 +773,22 @@ def get_gpt_mtp_block_spec_for_backend( if isinstance(spec, TransformerBlockSubmodules): # get the spec for the last layer of decoder block - transformer_layer_spec = spec.layer_specs[-1] - elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer: - transformer_layer_spec = spec + transformer_layer_spec = copy.copy(spec.layer_specs[-1]) + elif isinstance(spec, ModuleSpec) and issubclass(spec.module, TransformerLayer): + transformer_layer_spec = copy.copy(spec) else: raise ValueError(f"Invalid spec: {spec}") + transformer_layer_spec.submodules = copy.copy(transformer_layer_spec.submodules) + + # MTP does not support hyper connections yet; strip HC modules and + # downgrade the layer class to plain TransformerLayer. + transformer_layer_spec.submodules.self_attention_hyper_connection = IdentityOp + transformer_layer_spec.submodules.cross_attention_hyper_connection = IdentityOp + transformer_layer_spec.submodules.mlp_hyper_connection = IdentityOp + if transformer_layer_spec.module is HyperConnectionTransformerLayer: + transformer_layer_spec.module = TransformerLayer + mtp_layer_spec = get_mtp_layer_spec_for_backend( mtp_model_layer_spec=transformer_layer_spec, backend=backend ) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index 92d39ba92ef..4516fe10d88 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# Parts of the code here are adapted from PyTorch # repo: https://github.com/pytorch/pytorch @@ -598,7 +598,9 @@ def forward( @staticmethod def backward(ctx, *args): """Backward pass.""" - if not torch.autograd._is_checkpoint_valid(): + from megatron.core.transformer.cuda_graphs import is_graph_capturing + + if not torch.autograd._is_checkpoint_valid() and not is_graph_capturing(): raise RuntimeError( "Checkpointing is not compatible with .grad(), " "please use .backward() if possible" @@ -642,10 +644,67 @@ def checkpoint( return CheckpointFunction.apply(function, distribute_saved_activations, *args) +def _save_args_to_ctx(ctx, args): + """Save mixed tensor/non-tensor arguments into autograd ctx. + + Since save_for_backward only supports tensors, this function separates + tensor and non-tensor arguments, saving tensors via save_for_backward + and storing non-tensor metadata (indices and values) as ctx attributes. + + Use _load_args_from_ctx to reconstruct the original args. + """ + tensor_args = [] + non_tensor_entries = [] + + for index, arg in enumerate(args): + if isinstance(arg, torch.Tensor): + tensor_args.append(arg) + continue + non_tensor_entries.append((index, arg)) + + ctx.save_for_backward(*detach_variable(tuple(tensor_args))) + ctx._non_tensor_entries = tuple(non_tensor_entries) + ctx._total_args_count = len(args) + + +def _load_args_from_ctx(ctx): + """Load and reconstruct mixed tensor/non-tensor arguments from autograd ctx. + + This is the inverse of _save_args_to_ctx. It retrieves tensors from + ctx.saved_tensors and merges them with stored non-tensor arguments + to reconstruct the original args in their original order. + + Returns: + tuple of reconstructed arguments in their original order. 
+ """ + + def _detach_with_grad(tensor): + detached = tensor.detach() + detached.requires_grad_(tensor.requires_grad) + return detached + + tensor_iter = iter(_detach_with_grad(t) for t in ctx.saved_tensors) + total_args_count = ctx._total_args_count + non_tensor_map = dict(ctx._non_tensor_entries) + + reconstructed_args = [] + for index in range(total_args_count): + if index in non_tensor_map: + reconstructed_args.append(non_tensor_map[index]) + else: + reconstructed_args.append(next(tensor_iter)) + return tuple(reconstructed_args) + + class CheckpointWithoutOutputFunction(torch.autograd.Function): """ Checkpoint Function Helper for CheckpointWithoutOutput. Save context for recompute. + + Handles both tensor and non-tensor arguments: + - Tensor arguments are saved via save_for_backward + - Non-tensor arguments (int, float, bool, None, etc.) are stored separately + in ctx attributes and reconstructed during recomputation """ @staticmethod @@ -668,7 +727,10 @@ def forward( with torch.no_grad(), fwd_ctx: outputs = run_function(*args) - ctx.save_for_backward(*detach_variable(args)) + + # Save tensor and non-tensor arguments into ctx for recomputation + _save_args_to_ctx(ctx, args) + # the CheckpointWithoutOutput object is passed in, then it can access the saved input # tensors later for recomputation checkpoint_without_output_obj.ctx = ctx @@ -685,10 +747,56 @@ def backward(ctx, *args): torch.autograd.backward(outputs, args) ctx.outputs = None ctx.inputs = None - grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs) + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs) return (None, None) + grads +class CheckpointManager: + """ + Manages multiple CheckpointWithoutOutput objects within a TransformerBlock + cross layer recomputations, enabling unified recomputation during backward pass. 
+ This is particularly useful for scenarios where multiple checkpoint operations have + sequential dependencies (i.e., the output of one checkpoint is the input of the next). + + Usage: + ckptManager = CheckpointManager() + ckpt_function = CheckpointWithoutOutput(ckpt_manager=ckptManager) + ckpt_function.checkpoint(run_function, *args) + # other checkpointed operations + ckpt_manager.discard_all_outputs_and_register_unified_recompute(final_output) + """ + + def __init__(self): + self.checkpoints = [] + # Set by TransformerBlock before each layer forward. + # When True, the layer should keep block-boundary output uncheckpointed. + self.is_last_layer_in_recompute_block = False + + def add_checkpoint(self, ckpt): + """Add a checkpoint to the manager.""" + if not isinstance(ckpt, CheckpointWithoutOutput): + raise TypeError("Expected CheckpointWithoutOutput object") + if ckpt.outputs is None: + raise ValueError("CheckpointWithoutOutput must call checkpoint() before adding") + self.checkpoints.append(ckpt) + + def discard_all_outputs_and_register_unified_recompute(self, hook_tensor): + """Discard all checkpoint outputs to save memory and register unified recompute hook.""" + for ckpt in self.checkpoints: + for output in ckpt.outputs: + output.untyped_storage().resize_(0) + + # Register unified recompute hook + if hook_tensor.requires_grad: + hook_tensor.register_hook(self._unified_recompute_hook) + + def _unified_recompute_hook(self, grad_output): + for ckpt in self.checkpoints: + # Call _recompute for each checkpoint in forward order + # The _recompute method will restore the output tensor storage + ckpt._recompute(None) + + class CheckpointWithoutOutput(object): """ Checkpoint a model or part of the model and release the output. @@ -703,8 +811,19 @@ class CheckpointWithoutOutput(object): discarded output tensors are directly saved in the following modules for backward computation. 
""" - def __init__(self, fp8=False): - self.fp8 = fp8 is not None + def __init__(self, fp8=False, ckpt_manager=None): + """ + Initialize CheckpointWithoutOutput. + + Args: + fp8: Whether to use FP8 mode. Defaults to False. + ckpt_manager: Optional CheckpointManager instance. When provided, + checkpoint() will auto-register to the manager, and + discard_output_and_register_recompute() will only discard + output without registering individual hooks. + """ + self.fp8 = bool(fp8) + self.ckpt_manager = ckpt_manager self.run_function = None self.fwd_cpu_rng_state = None self.fwd_cuda_rng_state = None @@ -713,7 +832,12 @@ def __init__(self, fp8=False): self.outputs = None def checkpoint(self, run_function: Callable[[Unpack[_Ts]], _R], *args: Unpack[_Ts]) -> _R: - """Checkpoint function.""" + """ + Checkpoint function. + + If ckpt_manager was provided during initialization, this checkpoint + will be automatically registered to the manager after execution. + """ # If in cuda graph warmup, disable checkpointing, as 'discard_output_and_register_recompute' # may be called in a separate graph warmup. @@ -730,6 +854,11 @@ def checkpoint(self, run_function: Callable[[Unpack[_Ts]], _R], *args: Unpack[_T self.outputs = outputs if isinstance(self.outputs, torch.Tensor): self.outputs = (self.outputs,) + + # Auto-register to manager if provided + if self.ckpt_manager is not None: + self.ckpt_manager.add_checkpoint(self) + return outputs def _recompute(self, _): @@ -738,7 +867,7 @@ def _recompute(self, _): from megatron.core.transformer.cuda_graphs import is_graph_capturing, is_graph_warmup # The recomputation has been triggered already. Just return. 
- # Handle cudagraphs, do nothing if currently in graph warmup + # Handle cudagraphs: do nothing if currently in graph warmup if self.ctx is None or is_graph_warmup(): return @@ -760,17 +889,8 @@ def _recompute(self, _): recompute_ctx = contextlib.nullcontext() fp8_ctx = contextlib.nullcontext() - # Store the inputs for backward pass - inputs = self.ctx.saved_tensors - - def detach(t): - if isinstance(t, torch.Tensor): - requires_grad = t.requires_grad - t = t.detach() - t.requires_grad_(requires_grad) - return t - - inputs = tuple(detach(t) for t in inputs) + # Reconstruct full args list from saved ctx + inputs = _load_args_from_ctx(self.ctx) with torch.enable_grad(), fp8_ctx, recompute_ctx: outputs = self.run_function(*inputs) @@ -803,10 +923,11 @@ def discard_output_and_register_recompute(self, hook_tensor): in the forward pass and the gradient of the hook_tensor is computed before the recomputed tensors are used. """ - + # When ckpt_manager is set, this is a no-op. + # Manager handles all discarding and hook registration uniformly. from megatron.core.transformer.cuda_graphs import is_graph_warmup - if is_graph_warmup(): + if self.ckpt_manager is not None or is_graph_warmup(): return # use resize to release the output tensor memory and still keep the metadata in the tensors. diff --git a/megatron/core/transformer/__init__.py b/megatron/core/transformer/__init__.py index 0e3cdcfa57e..75e3b485c4f 100644 --- a/megatron/core/transformer/__init__.py +++ b/megatron/core/transformer/__init__.py @@ -1,6 +1,10 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
from .module import MegatronModule from .spec_utils import ModuleSpec, build_module from .transformer_config import MLATransformerConfig, TransformerConfig -from .transformer_layer import TransformerLayer, TransformerLayerSubmodules +from .transformer_layer import ( + HyperConnectionTransformerLayer, + TransformerLayer, + TransformerLayerSubmodules, +) diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 067f6055015..af5a2e35672 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import dataclasses import gc diff --git a/megatron/core/transformer/hyper_connection.py b/megatron/core/transformer/hyper_connection.py new file mode 100644 index 00000000000..64ec3107213 --- /dev/null +++ b/megatron/core/transformer/hyper_connection.py @@ -0,0 +1,716 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import math +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +from torch import Tensor + +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import nvtx_decorator + +if TYPE_CHECKING: + from megatron.core.tensor_parallel.random import CheckpointManager + + +@torch.compile +def _sinkhorn_iterations(input_logits: Tensor, num_iterations: int, eps: float) -> Tensor: + row_max = input_logits.max(dim=-1, keepdim=True).values + M = torch.exp(input_logits - row_max) + for _ in range(num_iterations): + M = M / M.sum(dim=-1, keepdim=True).clamp(min=eps) + M = M / M.sum(dim=-2, keepdim=True).clamp(min=eps) + return M + + +class SinkhornKnopp(torch.autograd.Function): + """Sinkhorn-Knopp projection to doubly stochastic matrix. 
+ + This is an autograd.Function because the iterative forward is re-executed + during backward (under torch.enable_grad) so that PyTorch's autograd can + differentiate through it without storing all intermediate iteration states. + """ + + @staticmethod + def forward(ctx, input_logits: Tensor, num_iterations: int, eps: float = 1e-6) -> Tensor: + """Run Sinkhorn iterations and save inputs for backward recomputation.""" + M = _sinkhorn_iterations(input_logits, num_iterations, eps) + ctx.save_for_backward(input_logits) + ctx.num_iterations = num_iterations + ctx.eps = eps + return M + + @staticmethod + def backward(ctx, grad_output: Tensor): + """Recompute forward under enable_grad and back-propagate.""" + (input_logits,) = ctx.saved_tensors + with torch.enable_grad(): + logits = input_logits.detach().requires_grad_(True) + M = _sinkhorn_iterations(logits, ctx.num_iterations, ctx.eps) + M.backward(grad_output) + return logits.grad, None, None + + +def native_sinkhorn(input_logits: Tensor, num_iterations: int, eps: float = 1e-6) -> Tensor: + """Native Sinkhorn-Knopp (autograd.Function wrapper).""" + return SinkhornKnopp.apply(input_logits, num_iterations, eps) + + +@torch.compile +def native_h_aggregate(x: Tensor, h_pre: Tensor) -> Tensor: + """Native n-stream weighted aggregation: out = sum_j(h_pre_j * x_j).""" + return (x * h_pre.unsqueeze(-1)).sum(dim=2) + + +@torch.compile +def native_h_post_bda( + h_res: Tensor, original_residual: Tensor, h_post: Tensor, x: Tensor, bias: Optional[Tensor] +) -> Tensor: + """Native H_res @ residual + H_post * (x [+ bias]).""" + s, b, n, C = original_residual.shape + h_res_batched = h_res.view(s * b, n, n) + residual_batched = original_residual.view(s * b, n, C) + mixed = torch.bmm(h_res_batched, residual_batched).view(s, b, n, C) + x_expanded = h_post.unsqueeze(-1) * x.unsqueeze(2) + if bias is not None: + bias_expanded = h_post.unsqueeze(-1) * bias.view(1, 1, 1, C) + return x_expanded + bias_expanded + mixed + return x_expanded + 
@torch.compile
def native_proj_rms(x: Tensor, weight: Tensor, eps: float = 1e-6) -> Tuple[Tensor, Tensor]:
    """Native (reference) projection + RMS normalization, compiled with torch.compile.

    Args:
        x: [N, K] - flattened hidden states.
        weight: [M, K] - projection weight (applied as x @ weight.T).
        eps: Numerical-stability constant. NOTE: added to the RMS value itself,
            not under the square root (r = 1 / (rms + eps)).

    Returns:
        proj: [N, M] - projected hidden states.
        r: [N, 1] - reciprocal RMS scale factor of each row of x.
    """
    proj = torch.matmul(x, weight.t())
    norm = x.norm(dim=-1, keepdim=True)
    K = x.shape[-1]
    # norm / sqrt(K) == sqrt(mean(x^2)) == RMS of the row.
    v = norm / math.sqrt(K) + eps
    r = 1.0 / v
    return proj, r


# ============================================================================
# HyperConnectionModule
# ============================================================================


# TODO: keep hyper connection in fp32 computation
class HyperConnectionModule(MegatronModule):
    """
    Unified mHC (Manifold-Constrained Hyper-Connections) module.

    Implements the complete mHC propagation:
        x_{l+1} = H_res @ x_l + H_post^T @ F(H_pre @ x_l)

    This module handles:
    1. Computing learnable mappings: H_pre, H_post, H_res (with Sinkhorn-Knopp projection)
    2. Aggregation: n-stream → 1-stream (H_pre @ x)
    3. Expansion: 1-stream → n-stream (H_post^T @ output)
    4. Residual merge: H_res @ x + expanded_output
    5. Block-level expand/contract for TransformerBlock boundaries

    Args:
        config: TransformerConfig with hyper-connection fields
        layer_number: Current layer index for initialization
    """

    def __init__(self, config: TransformerConfig, layer_number: int):
        super().__init__(config)
        self.config = config
        self.layer_number = layer_number
        # n residual streams of width hidden_size; the n-stream state is kept
        # flattened as [s, b, n*C] throughout.
        self.n = config.num_residual_streams
        self.hidden_size = config.hidden_size
        self.sinkhorn_iterations = config.mhc_sinkhorn_iterations

        # Projection weights for dynamic mappings
        # Input: [s, b, n*C] -> Output: n^2 + 2n values per token
        # - H_pre: n values
        # - H_post: n values
        # - H_res: n^2 values (before Sinkhorn projection)
        self.mapping_proj = nn.Linear(
            self.n * self.hidden_size, self.n * self.n + 2 * self.n, bias=False
        )

        init_alpha = config.mhc_init_gating_factor
        # Learnable scaling factors (Eq. 5 in paper)
        self.alpha_pre = nn.Parameter(torch.full((1,), init_alpha))
        self.alpha_post = nn.Parameter(torch.full((1,), init_alpha))
        self.alpha_res = nn.Parameter(torch.full((1,), init_alpha))

        # Static bias terms, laid out as [b_pre (n) | b_post (n) | b_res (n^2)]
        # to match the mapping_proj output layout.
        self.bias = nn.Parameter(torch.zeros(self.n * self.n + 2 * self.n))
        self.norm_eps = 1e-6

        # Choose implementation: fused cuTile kernels vs reference modules.
        # Both paths expose the same call signatures so the rest of the code
        # is implementation-agnostic.
        if config.use_fused_mhc:
            from megatron.core.fusions.fused_mhc_kernels import (
                fused_h_aggregate,
                fused_h_post_bda,
                fused_proj_rms,
                fused_sinkhorn,
            )

            self._sinkhorn_op = fused_sinkhorn
            self._h_aggregate_op = fused_h_aggregate
            self._h_post_bda_op = fused_h_post_bda
            self._proj_rms_op = fused_proj_rms
        else:
            self._sinkhorn_op = native_sinkhorn
            self._h_aggregate_op = native_h_aggregate
            self._h_post_bda_op = native_h_post_bda
            self._proj_rms_op = native_proj_rms

        self._init_weights()

    def _init_weights(self) -> None:
        """Initialize weights for stable training."""
        nn.init.xavier_uniform_(self.mapping_proj.weight)

        # Set sequence_parallel attribute on parameters for gradient synchronization
        # across TP ranks when sequence_parallel is enabled.
        # This is required because HyperConnectionModule uses non-TP-aware layers
        # (nn.Linear, nn.RMSNorm) whose gradients need to be all-reduced.
        if self.config.sequence_parallel:
            setattr(self.mapping_proj.weight, 'sequence_parallel', True)
            setattr(self.alpha_pre, 'sequence_parallel', True)
            setattr(self.alpha_post, 'sequence_parallel', True)
            setattr(self.alpha_res, 'sequence_parallel', True)
            setattr(self.bias, 'sequence_parallel', True)

    def _projection_and_get_norm(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Projection + RMS normalization.

        Args:
            x: [s, b, n*C] - n-stream hidden states

        Returns:
            proj: [s, b, n^2 + 2n] - projected (pre-activation) mapping logits
            r: [s, b, 1] - reciprocal RMS scale of each token's n-stream state
        """
        s, b, nC = x.shape
        # Flatten to 2D so the proj+rms op sees one row per token.
        x_2d = x.reshape(s * b, nC)
        proj, r = self._proj_rms_op(x_2d, self.mapping_proj.weight, self.norm_eps)
        return proj.view(s, b, -1), r.view(s, b, 1)

    @torch.compile
    def _compute_h(self, proj: Tensor, r: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Compute h from projected hidden states and scaling factors.

        Args:
            proj: [s, b, n^2 + 2n] - projected hidden states
            r: [s, b, 1] - scaling factors

        Returns:
            h_pre: [s, b, n] - aggregation weights
            h_post: [s, b, n] - expansion weights
            h_res: [s, b, n^2] - residual mixing logits
        """
        # Broadcast the three scalar alphas to the [pre | post | res] layout
        # of the projection output so a single elementwise expression applies
        # gating + bias to all mappings at once.
        alpha_ = torch.cat(
            [
                self.alpha_pre.expand(self.n),
                self.alpha_post.expand(self.n),
                self.alpha_res.expand(self.n * self.n),
            ],
            dim=-1,
        )
        h = r * proj * alpha_ + self.bias
        # H_pre = σ(α_pre * (θ_pre @ x̃) + b_pre)
        h_pre = h[..., : self.n].sigmoid()  # [s, b, n]

        # H_post = 2σ(α_post * (θ_post @ x̃) + b_post)
        h_post = h[..., self.n : 2 * self.n].sigmoid() * 2  # [s, b, n]
        # Residual logits are left raw; Sinkhorn projection happens downstream.
        h_res = h[..., 2 * self.n :]
        return h_pre, h_post, h_res

    @nvtx_decorator(message="HyperConnection::compute_mappings")
    def compute_mappings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Compute mHC mappings from input hidden states.

        Reference: Eq. (5) and (8) in mHC paper

        Args:
            x: [s, b, n*C] - n-stream hidden states

        Returns:
            h_pre: [s, b, n] - aggregation weights (sigmoid activated)
            h_post: [s, b, n] - expansion weights (2*sigmoid activated)
            h_res: [s, b, n, n] - residual mixing matrix (doubly stochastic)
        """
        s, b, _ = x.shape
        with torch.cuda.nvtx.range("HyperConnection::projection_and_get_norm"):
            proj, r = self._projection_and_get_norm(x)
        with torch.cuda.nvtx.range("HyperConnection::compute_h"):
            h_pre, h_post, h_res = self._compute_h(proj, r)
        # Project residual logits onto (approximately) doubly stochastic
        # matrices via iterative Sinkhorn-Knopp row/column normalization.
        h_res = self._sinkhorn_op(
            h_res.view(s, b, self.n, self.n), self.sinkhorn_iterations, self.norm_eps
        )  # [s, b, n, n]

        return h_pre, h_post, h_res

    @torch.compile
    def _apply_h_post(self, x: Tensor, h_post: Tensor) -> Tensor:
        """
        Core implementation of H_post application to a single tensor.

        Computes: H_post^T @ x

        Args:
            x: Input tensor, can be either:
                - [s, b, C] - standard hidden states
                - [C] - bias tensor (will be broadcast)
            h_post: [s, b, n] - expansion weights

        Returns:
            output: [s, b, n*C] - expanded tensor
        """
        n = self.n
        s, b, _ = h_post.shape

        if x.dim() == 1:
            # x is bias with shape [C], need to broadcast to [s, b, 1, C]
            C = x.shape[0]
            x_expanded = x.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(s, b, 1, C)
        else:
            # x is [s, b, C]
            C = x.shape[-1]
            x_expanded = x.unsqueeze(2)  # [s, b, 1, C]

        # h_post^T @ x : [s, b, n, 1] * [s, b, 1, C] -> [s, b, n, C]
        # Using broadcast multiply instead of einsum
        result = h_post.unsqueeze(-1) * x_expanded
        return result.view(s, b, n * C)

    @nvtx_decorator(message="HyperConnection::apply_h_post")
    def apply_h_post(
        self,
        x_with_bias: Tuple[Tensor, Optional[Tensor]],
        h_post: Tensor,
        manager: Optional['CheckpointManager'] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Apply H_post to x and optionally bias, with optional checkpointing.

        This is the unified entry point that handles both normal execution
        and checkpoint-based execution for memory efficiency.

        Args:
            x_with_bias: Tuple of (x, bias) where:
                - x: [s, b, C] - hidden states
                - bias: [C] or None - optional bias tensor
            h_post: [s, b, n] - expansion weights
            manager: Optional CheckpointManager for checkpoint management.
                When provided, wraps _apply_h_post with CheckpointWithoutOutput.

        Returns:
            Tuple of (x_out, bias_out) where:
                - x_out: [s, b, n*C] - expanded hidden states
                - bias_out: [s, b, n*C] or None - expanded bias if input bias was not None
        """
        x, bias = x_with_bias

        if manager is not None:
            from megatron.core.tensor_parallel.random import CheckpointWithoutOutput

            # Checkpoint _apply_h_post to discard the output
            x_out = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(
                self._apply_h_post, x, h_post
            )

            # Checkpoint _apply_h_post for bias if not None
            if bias is not None:
                bias_out = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(
                    self._apply_h_post, bias, h_post
                )
            else:
                bias_out = None
        else:
            # Normal execution without checkpoint
            x_out = self._apply_h_post(x, h_post)
            bias_out = self._apply_h_post(bias, h_post) if bias is not None else None

        return x_out, bias_out

    def aggregate(self, x: Tensor, h_pre: Tensor) -> Tensor:
        """
        Aggregate n-stream to 1-stream.

        Args:
            x: [s, b, n*C] - n-stream hidden states
            h_pre: [s, b, n] - aggregation weights

        Returns:
            aggregated: [s, b, C] - single stream hidden states
        """
        s, b, _ = x.shape
        C = self.hidden_size
        x_streams = x.view(s, b, self.n, C)
        return self._h_aggregate_op(x_streams, h_pre)

    @torch.compile
    def apply_h_res(self, h_res: Tensor, residual: Tensor) -> Tensor:
        """
        Apply H_res to residual using H_res weights.

        Computes: H_res @ residual

        Args:
            h_res: [s, b, n, n] - residual mixing matrix
            residual: [s, b, n*C] - n-stream hidden states

        Returns:
            mixed: [s, b, n*C] - stream-mixed residual
        """
        s, b, _ = residual.shape
        n = self.n
        C = self.hidden_size

        # Reshape for bmm: [s, b, n, n] -> [s*b, n, n]
        h_res_batched = h_res.view(s * b, n, n)
        # [s, b, n*C] -> [s, b, n, C] -> [s*b, n, C]
        residual_batched = residual.view(s, b, n, C).view(s * b, n, C)

        # Batch matrix multiply: [s*b, n, n] @ [s*b, n, C] -> [s*b, n, C]
        mixed = torch.bmm(h_res_batched, residual_batched)

        return mixed.view(s, b, n * C)

    def forward(
        self, hidden_states: Tensor, mhc_recompute_manager: Optional['CheckpointManager'] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Full mHC forward pass.

        Args:
            hidden_states: [s, b, n*C] - n-stream hidden states
            mhc_recompute_manager: Optional CheckpointManager for checkpoint management.
                When provided, uses _forward_with_checkpoint for memory-efficient execution.

        Returns:
            aggregated: [s, b, C] - aggregated input for layer computation
            h_res: [s, b, n, n] - residual mixing matrix (for fused kernel)
            h_post: [s, b, n] - expansion weights
        """
        if mhc_recompute_manager is not None:
            return self._forward_with_checkpoint(hidden_states, mhc_recompute_manager)
        else:
            return self._forward_normal(hidden_states)

    def _forward_normal(self, hidden_states: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Normal forward pass without checkpointing.

        Args:
            hidden_states: [s, b, n*C] - n-stream hidden states

        Returns:
            aggregated: [s, b, C] - aggregated input for layer computation
            h_res: [s, b, n, n] - residual mixing matrix (for fused kernel)
            h_post: [s, b, n] - expansion weights
        """
        # Compute mappings
        h_pre, h_post, h_res = self.compute_mappings(hidden_states)

        # Aggregate for layer input
        with torch.cuda.nvtx.range("HyperConnection::aggregate"):
            aggregated = self.aggregate(hidden_states, h_pre)

        return aggregated, h_res, h_post

    def _forward_with_checkpoint(
        self, hidden_states: Tensor, manager: 'CheckpointManager'
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Forward pass with checkpointing for memory efficiency.

        compute_mappings is called directly (not checkpointed) since its outputs
        (h_pre, h_post, h_res) are needed downstream. Only aggregate is wrapped with
        CheckpointWithoutOutput and auto-registered to the manager.
        apply_h_res is deferred to fused_h_res_h_post_bda for kernel fusion.

        Args:
            hidden_states: [s, b, n*C] - n-stream hidden states
            manager: CheckpointManager for unified recomputation

        Returns:
            aggregated: [s, b, C] - aggregated input for layer computation
            h_res: [s, b, n, n] - residual mixing matrix (for fused kernel)
            h_post: [s, b, n] - expansion weights
        """
        from megatron.core.tensor_parallel.random import CheckpointWithoutOutput

        h_pre, h_post, h_res = self.compute_mappings(hidden_states)

        # Checkpoint aggregate - auto-registers to manager
        aggregated = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(
            self.aggregate, hidden_states, h_pre
        )

        return aggregated, h_res, h_post

    # ==================== Block-level utilities ====================

    @staticmethod
    def input_expand(x: Tensor, n: int) -> Tensor:
        """
        Expand 1-stream to n-stream at TransformerBlock entry.

        Simple replication strategy: each stream initialized as a copy of input.

        Args:
            x: [s, b, C] - single stream hidden states
            n: Number of residual streams

        Returns:
            expanded: [s, b, n*C] - n-stream hidden states
        """
        s, b, C = x.shape
        # Replicate input to n streams
        expanded = x.unsqueeze(2).expand(s, b, n, C).contiguous()
        return expanded.view(s, b, n * C)

    @staticmethod
    def output_contract(x: Tensor, n: int) -> Tensor:
        """
        Contract n-stream to 1-stream at TransformerBlock exit.

        Simple averaging strategy: average all streams.

        Args:
            x: [s, b, n*C] - n-stream hidden states
            n: Number of residual streams

        Returns:
            contracted: [s, b, C] - single stream hidden states
        """
        s, b, nC = x.shape
        C = nC // n
        # Average all streams
        x_streams = x.view(s, b, n, C)
        contracted = x_streams.mean(dim=2)
        return contracted

    # ==================== Fused kernel placeholder ====================

    @nvtx_decorator(message="HyperConnection::fused_h_res_h_post_bda")
    def fused_h_res_h_post_bda(
        self,
        h_res: Tensor,
        original_residual: Tensor,
        h_post: Tensor,
        layer_output_with_bias: Tuple[Tensor, Optional[Tensor]],
        dropout_prob: float,
        training: bool,
        fused: bool,
        manager: Optional['CheckpointManager'] = None,
    ) -> Tensor:
        """
        Combined apply_h_res, apply_h_post and bias-dropout-add.

        When dropout is inactive (prob 0 or inference), the whole computation
        is dispatched to a single h_post_bda op (fused cuTile kernel or native
        reference, per config.use_fused_mhc); otherwise the three steps run
        sequentially in PyTorch.

        The computation flow is:
        1. mixed = H_res @ original_residual (apply_h_res)
        2. expanded = H_post^T @ layer_output (apply_h_post)
        3. output = dropout(expanded + bias) + mixed (bias-dropout-add)

        Args:
            h_res: [s, b, n, n] - residual mixing matrix
            original_residual: [s, b, n*C] - n-stream hidden states (before H_res applied)
            h_post: [s, b, n] - expansion weights
            layer_output_with_bias: Tuple of (x, bias) where:
                - x: [s, b, C] - layer output (attention or MLP output)
                - bias: [C] or None - optional bias tensor
            dropout_prob: Dropout probability
            training: Whether in training mode
            fused: Whether to use fused BDA implementation
            manager: Optional CheckpointManager for checkpoint management.
                When provided, each operation is wrapped with CheckpointWithoutOutput.

        Returns:
            output: [s, b, n*C] - final output after all operations
        """
        if manager is not None:
            return self._fused_h_res_h_post_bda_with_checkpoint(
                h_res,
                original_residual,
                h_post,
                layer_output_with_bias,
                dropout_prob,
                training,
                fused,
                manager,
            )
        else:
            return self._fused_h_res_h_post_bda_native(
                h_res,
                original_residual,
                h_post,
                layer_output_with_bias,
                dropout_prob,
                training,
                fused,
            )

    def _fused_h_res_h_post_bda_native(
        self,
        h_res: Tensor,
        original_residual: Tensor,
        h_post: Tensor,
        layer_output_with_bias: Tuple[Tensor, Optional[Tensor]],
        dropout_prob: float,
        training: bool,
        fused: bool,
    ) -> Tensor:
        """
        h_res, h_post and bda.

        When dropout is zero (or inference), uses a single fused/reference kernel
        for H_res @ residual + H_post * (x + bias). Falls back to unfused
        implementation when dropout is needed.

        Args:
            h_res: [s, b, n, n] - residual mixing matrix
            original_residual: [s, b, n*C] - n-stream hidden states
            h_post: [s, b, n] - expansion weights
            layer_output_with_bias: Tuple of (x, bias)
            dropout_prob: Dropout probability
            training: Whether in training mode
            fused: Whether to use fused BDA implementation

        Returns:
            output: [s, b, n*C] - final output
        """
        x, bias = layer_output_with_bias

        if dropout_prob == 0.0 or not training:
            # Fast path: dropout is a no-op, so the whole chain collapses into
            # one h_post_bda op call.
            s, b, _ = original_residual.shape
            n = self.n
            C = self.hidden_size
            orig_reshaped = original_residual.view(s, b, n, C)
            output = self._h_post_bda_op(h_res, orig_reshaped, h_post, x, bias)
            return output.view(s, b, n * C)

        from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add

        with torch.cuda.nvtx.range("HyperConnection::apply_h_res"):
            mixed = self.apply_h_res(h_res, original_residual)
        with torch.cuda.nvtx.range("HyperConnection::apply_h_post"):
            x_expanded = self._apply_h_post(x, h_post)
            bias_expanded = self._apply_h_post(bias, h_post) if bias is not None else None
        bda_func = get_bias_dropout_add(training, fused)
        with torch.cuda.nvtx.range("HyperConnection::bda"):
            output = bda_func((x_expanded, bias_expanded), mixed, dropout_prob)
        return output

    @nvtx_decorator(message="HyperConnection::fused_h_res_h_post_bda_with_checkpoint")
    def _fused_h_res_h_post_bda_with_checkpoint(
        self,
        h_res: Tensor,
        original_residual: Tensor,
        h_post: Tensor,
        layer_output_with_bias: Tuple[Tensor, Optional[Tensor]],
        dropout_prob: float,
        training: bool,
        fused: bool,
        manager: 'CheckpointManager',
    ) -> Tensor:
        """
        Checkpointed variant of _fused_h_res_h_post_bda_native.

        Wraps compute in CheckpointWithoutOutput for activation memory savings.
        Cannot reuse _native directly because checkpoint requires all args to be
        positional Tensors; tuple/Optional/scalar args are unpacked or captured
        via closure instead.

        Args:
            h_res: [s, b, n, n] - residual mixing matrix
            original_residual: [s, b, n*C] - n-stream hidden states
            h_post: [s, b, n] - expansion weights
            layer_output_with_bias: Tuple of (x, bias)
            dropout_prob: Dropout probability
            training: Whether in training mode
            fused: Whether to use fused BDA implementation
            manager: CheckpointManager for checkpoint management

        Returns:
            output: [s, b, n*C] - final output
        """
        from megatron.core.tensor_parallel.random import CheckpointWithoutOutput

        x, bias = layer_output_with_bias
        n = self.n
        C = self.hidden_size

        # Fast path: no dropout — use fused/reference h_post_bda kernel (same as _native)
        if dropout_prob == 0.0 or not training:

            # bias is passed through *optional_bias so the checkpoint sees only
            # positional Tensor args; None cannot be checkpointed.
            def _fused_wrapper(h_res, original_residual, h_post, x, *optional_bias):
                s, b, _ = original_residual.shape
                orig_reshaped = original_residual.view(s, b, n, C)
                b_arg = optional_bias[0] if optional_bias else None
                return self._h_post_bda_op(h_res, orig_reshaped, h_post, x, b_arg).view(s, b, n * C)

            ckpt = CheckpointWithoutOutput(ckpt_manager=manager)
            if bias is not None:
                output = ckpt.checkpoint(_fused_wrapper, h_res, original_residual, h_post, x, bias)
            else:
                output = ckpt.checkpoint(_fused_wrapper, h_res, original_residual, h_post, x)

        # Slow path: dropout required — fused kernel does not support dropout,
        # fall back to sequential apply_h_res + apply_h_post + bda
        else:
            from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add

            bda_func = get_bias_dropout_add(training, fused)
            has_bias = bias is not None

            def _native_wrapper(h_res, original_residual, h_post, x, *optional_bias):
                with torch.cuda.nvtx.range("HyperConnection::apply_h_res"):
                    mixed = self.apply_h_res(h_res, original_residual)
                with torch.cuda.nvtx.range("HyperConnection::apply_h_post"):
                    x_expanded = self._apply_h_post(x, h_post)
                    if has_bias:
                        bias_expanded = self._apply_h_post(optional_bias[0], h_post)
                    else:
                        bias_expanded = None
                with torch.cuda.nvtx.range("HyperConnection::bda"):
                    output = bda_func((x_expanded, bias_expanded), mixed, dropout_prob)
                return output

            ckpt = CheckpointWithoutOutput(ckpt_manager=manager)
            if has_bias:
                output = ckpt.checkpoint(_native_wrapper, h_res, original_residual, h_post, x, bias)
            else:
                output = ckpt.checkpoint(_native_wrapper, h_res, original_residual, h_post, x)

        return output


# ==================== Checkpoint utilities for mHC ====================


class HyperConnectionCheckpoint:
    """
    Checkpoint utility for mHC intermediate activations.

    Implements the paper's "recomputing strategy" to reduce memory footprint
    by discarding intermediate n-stream activations and recomputing on-the-fly.
    """

    @staticmethod
    def compute_optimal_block_size(num_layers: int, num_streams: int) -> int:
        """
        Compute optimal recomputation block size.

        From paper Eq. (20): L_r^* ≈ sqrt(nL/(n+2))

        Args:
            num_layers: Total number of transformer layers
            num_streams: Number of residual streams (n)

        Returns:
            block_size: Optimal block size for checkpointing (at least 1)
        """
        block_size = int(math.sqrt(num_streams * num_layers / (num_streams + 2)))
        return max(1, block_size)
+ import logging from contextlib import nullcontext from dataclasses import dataclass -from typing import List, Optional, Set, Union, cast +from typing import List, Optional, Set, Tuple, Union, cast import torch from torch import Tensor @@ -19,7 +20,9 @@ from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel.random import CheckpointManager from megatron.core.transformer.enums import CudaGraphScope, LayerType +from megatron.core.transformer.hyper_connection import HyperConnectionModule from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.torch_norm import LayerNormBuilder @@ -319,6 +322,7 @@ def __init__( self.offload_context, self.group_prefetch_offload_commit_async = nullcontext(), None self.config._cpu_offloading_context = None + self.num_residual_streams = config.num_residual_streams self._build_layers() self.num_layers_per_pipeline_rank = len(self.layers) @@ -642,6 +646,46 @@ def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs)[0] return super().__call__(*args, **kwargs) + def _build_mhc_recompute_layer_plan( + self, use_mhc_recompute: bool + ) -> Tuple[List[Optional[CheckpointManager]], List[bool]]: + """Pre-build per-layer MHC recompute managers and block-end markers.""" + num_layers = len(self.layers) + layer_managers: List[Optional[CheckpointManager]] = [None] * num_layers + is_recompute_block_end: List[bool] = [False] * num_layers + + if not use_mhc_recompute or num_layers == 0: + return layer_managers, is_recompute_block_end + + mhc_recompute_layer_num = self.config.mhc_recompute_layer_num + mhc_manager = CheckpointManager() + + for l_no in range(num_layers): + is_last_in_transformer_block = l_no == 
num_layers - 1 + is_last_in_recompute_block = is_last_in_transformer_block + if mhc_recompute_layer_num is not None: + is_last_in_recompute_block = is_last_in_transformer_block or ( + (l_no + 1) % mhc_recompute_layer_num == 0 + ) + + layer_managers[l_no] = mhc_manager + is_recompute_block_end[l_no] = is_last_in_recompute_block + + if is_last_in_recompute_block and not is_last_in_transformer_block: + mhc_manager = CheckpointManager() + + return layer_managers, is_recompute_block_end + + @staticmethod + def _finalize_mhc_recompute_layer( + mhc_manager: Optional[CheckpointManager], + hidden_states: Tensor, + is_last_in_recompute_block: bool, + ) -> None: + """Finalize MHC recompute state for the current layer when block ends.""" + if mhc_manager is not None and is_last_in_recompute_block: + mhc_manager.discard_all_outputs_and_register_unified_recompute(hidden_states) + def forward( self, hidden_states: Union[Tensor, WrappedTensor], @@ -751,6 +795,13 @@ def forward( # is called here to be future-proof and corner-case-proof. 
hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + # Expand hidden states for hyper connections at the start of the block + # Only expand at the first PP stage; subsequent stages receive n-stream from previous stage + if self.config.enable_hyper_connections and self.pre_process: + hidden_states = HyperConnectionModule.input_expand( + hidden_states, self.num_residual_streams + ) # [s, b, C] -> [s, b, n*C] + if self.config.sequence_parallel: rng_context = tensor_parallel.get_cuda_rng_tracker().fork() else: @@ -778,6 +829,18 @@ def forward( use_inner_quantization_context = False outer_quantization_context = nullcontext() + # Determine if MHC recompute should be used + # Only enable when: training mode AND hyper connections AND 'mhc' in recompute_modules + use_mhc_recompute = ( + self.training + and self.config.enable_hyper_connections + and self.config.recompute_granularity == 'selective' + and "mhc" in self.config.recompute_modules + ) + mhc_layer_managers, mhc_is_last_in_recompute_block = self._build_mhc_recompute_layer_plan( + use_mhc_recompute + ) + with rng_context, outer_quantization_context: # Forward pass. 
if self.config.recompute_granularity == 'full' and self.training: @@ -818,6 +881,12 @@ def forward( else: inner_quantization_context = nullcontext() + mhc_manager = mhc_layer_managers[l_no] + if mhc_manager is not None: + mhc_manager.is_last_layer_in_recompute_block = ( + mhc_is_last_in_recompute_block[l_no] + ) + with self.offload_context, inner_quantization_context: hidden_states, context = layer( hidden_states=hidden_states, @@ -833,7 +902,13 @@ def forward( packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, padding_mask=padding_mask, + mhc_recompute_manager=mhc_manager, ) + self._finalize_mhc_recompute_layer( + mhc_manager=mhc_manager, + hidden_states=hidden_states, + is_last_in_recompute_block=mhc_is_last_in_recompute_block[l_no], + ) if ( torch.is_grad_enabled() @@ -846,6 +921,12 @@ def forward( if (l_no + layer_offset) in extract_layer_indices: intermediate_hidden_states.append(hidden_states) + # Only contract if the final layer norm is in this stage + if self.config.enable_hyper_connections and self.has_final_layernorm_in_this_stage(): + hidden_states = HyperConnectionModule.output_contract( + hidden_states, self.num_residual_streams + ) # [s, b, n*C] -> [s, b, C] + # Final layer norm. if self.final_layernorm is not None: hidden_states = apply_module(self.final_layernorm)(cast(Tensor, hidden_states)) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 40c1a745493..7740b09012b 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import logging import math @@ -482,7 +482,8 @@ class TransformerConfig(ModelParallelConfig): recompute_modules: Optional[List[str]] = None """The submodules to recompute. 
- choices: "core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe", "shared_experts". + choices: "core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe", + "shared_experts", "mhc". default: ["core_attn"]. "core_attn": recompute the core attention part of the transformer layer. "moe_act": recompute the MoE MLP activation function. @@ -491,7 +492,10 @@ class TransformerConfig(ModelParallelConfig): "mlp": recompute the dense MLP submodule. "moe": recompute the MoE layer. "shared_experts": recompute the shared experts in the MoE layer. - "moe_act", "layernorm", and "mla_up_proj" use output-discarding checkpointing, + "mhc": recompute HyperConnection intermediate activations via + CheckpointWithoutOutput + CheckpointManager. Requires + enable_hyper_connections=True. Cannot be used with "mlp". + "moe_act", "layernorm", "mla_up_proj", and "mhc" use output-discarding checkpointing, "core_attn", "mlp", "moe", and "shared_experts" use normal checkpointing. """ @@ -871,6 +875,45 @@ class TransformerConfig(ModelParallelConfig): When cuda_graph_impl is set to "local", "full_iteration" can be specified as cuda_graph_scope to enable whole iteration CUDA graph. All other values enable layerwise CUDA graph.""" + #################### + # Hyper-Connection Configuration + #################### + enable_hyper_connections: bool = False + """Enable mHC residual connections.""" + + num_residual_streams: int = 4 + """Number of residual streams (n in paper).""" + + mhc_sinkhorn_iterations: int = 20 + """Number of Sinkhorn-Knopp iterations for doubly stochastic projection.""" + + mhc_init_gating_factor: float = 0.01 + """Initial value of Gating Factor (alpha in paper).""" + + use_fused_mhc: bool = False + """Use cuTile fused kernels for mHC operations. + + When True, attempts to replace the reference mHC modules (SinkhornKnopp, + H_aggregate, H_post_bda, ProjRms) with fused cuda.tile (cuTile) autograd + functions for better performance on supported GPUs. 
Requires cuTile to be + installed; if cuTile is unavailable the flag is silently reset to False and + a warning is emitted. + """ + + mhc_recompute_layer_num: Optional[int] = None + """Number of layers per MHC recompute block. + + When set, every `mhc_recompute_layer_num` layers form a recompute block. The last layer + in each recompute block (i.e., layer_number % mhc_recompute_layer_num == 0 or the final + layer in the transformer block) will: + - NOT checkpoint its final MLP BDA + - Register the unified recompute hook on its MLP BDA output + - A new CheckpointManager is created for subsequent layers + + If None, all layers in the transformer block share a single recompute block. + + Must be a positive integer when set.""" + #################### # miscellaneous #################### @@ -1383,6 +1426,7 @@ def __post_init__(self): "mlp", "moe", "shared_experts", + "mhc", } invalid_modules = set(self.recompute_modules) - allowed_modules assert not invalid_modules, ( @@ -1445,6 +1489,72 @@ def __post_init__(self): if "moe" not in self.recompute_modules: self.recompute_modules.append("moe") + # Validation for "mhc" in recompute_modules + if self.recompute_granularity == "selective" and "mhc" in self.recompute_modules: + if not self.enable_hyper_connections: + raise ValueError( + "'mhc' in recompute_modules requires enable_hyper_connections=True." + ) + if "mlp" in self.recompute_modules: + raise ValueError( + "'mhc' and 'mlp' in recompute_modules cannot be used together. " + "They use different checkpoint mechanisms that may conflict." + ) + if self.mhc_recompute_layer_num is not None and ( + isinstance(self.mhc_recompute_layer_num, bool) + or not isinstance(self.mhc_recompute_layer_num, int) + or self.mhc_recompute_layer_num < 1 + ): + raise ValueError( + "mhc_recompute_layer_num must be a positive integer when " + "'mhc' is in recompute_modules." 
+ ) + if self.fine_grained_activation_offloading: + raise ValueError( + "'mhc' in recompute_modules is incompatible with " + "fine_grained_activation_offloading. The mHC recompute hook fires " + "before the offloading backward chunk is initialized, causing " + "tensor_pop on a None chunk. Disable one of them." + ) + + if self.enable_hyper_connections and not ( + self.recompute_granularity == "selective" and "mhc" in self.recompute_modules + ): + warnings.warn( + "HyperConnections are enabled but 'mhc' is not in " + "recompute_modules with selective recompute. Consider adding 'mhc' to " + "recompute_modules with selective recompute to reduce activation memory." + ) + + # Validation for use_fused_mhc + if self.use_fused_mhc: + if not self.enable_hyper_connections: + raise ValueError("use_fused_mhc requires enable_hyper_connections=True.") + try: + from megatron.core.fusions.fused_mhc_kernels import is_cutile_available + + if not is_cutile_available(): + warnings.warn( + "use_fused_mhc is enabled but cuda.tile (cuTile) is not installed. " + "Falling back to reference mHC implementations.", + UserWarning, + ) + self.use_fused_mhc = False + except ImportError: + warnings.warn( + "use_fused_mhc is enabled but fused_mhc_kernels module could not be " + "imported. Falling back to reference mHC implementations.", + UserWarning, + ) + self.use_fused_mhc = False + + # Validation for hyper_connections with MTP + if self.enable_hyper_connections and self.mtp_num_layers is not None: + raise ValueError( + "enable_hyper_connections is not compatible with Multi-Token Prediction (MTP). " + "Please disable MTP (set mtp_num_layers=None) when using hyper connections." 
+ ) + if self.fine_grained_activation_offloading: assert ( not self.cpu_offloading diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index cf63199347c..437993021d5 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. from __future__ import annotations import functools @@ -8,6 +8,9 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, Optional, Union +if TYPE_CHECKING: + from megatron.core.tensor_parallel.random import CheckpointManager + import torch import torch.distributed from torch import Tensor @@ -228,14 +231,17 @@ class TransformerLayerSubmodules: """ input_layernorm: LayerNormBuilder = IdentityOp + self_attention_hyper_connection: Union[ModuleSpec, type] = IdentityOp self_attention: Union[ModuleSpec, type] = IdentityOp self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp pre_cross_attn_layernorm: LayerNormBuilder = IdentityOp + cross_attention_hyper_connection: Union[ModuleSpec, type] = IdentityOp cross_attention: Union[ModuleSpec, type] = IdentityOp cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp pre_mlp_layernorm: LayerNormBuilder = IdentityOp + mlp_hyper_connection: Union[ModuleSpec, type] = IdentityOp mlp: Union[ModuleSpec, type] = IdentityOp mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp @@ -606,8 +612,6 @@ def _forward_attention( ) if using_fused_tp_inference_kernel: - # Set the residual for fused reduce-scatter + add + layer-norm + all-gather - # operation in attention's out_proj (linear_proj) self._set_proj_residual(residual) # Self attention. 
@@ -700,6 +704,11 @@ def forward(self, *args, **kwargs): This method calls the core computation of a transformer layer, including self-attention, cross-attention (if applicable), and feed-forward operations. """ + # Injected by __call__ for cuda graph keying; not a real forward arg. + kwargs.pop("dynamic_inference_decode_only", None) + assert ( + not self.config.enable_hyper_connections + ), "Please use HyperConnectionTransformerLayer instead" hidden_states, context = self._forward_attention(*args, **kwargs) output = self._forward_mlp( hidden_states, @@ -1280,6 +1289,33 @@ def _should_call_local_cudagraph(self, *args, **kwargs): return True return False + def backward_dw_cudagraph(self, microbatch_idx): + """ + CUDA Graph backward weight gradient computation for this layer. + """ + cg_index = microbatch_idx % len(self.cuda_graphs) + if not hasattr(self.cuda_graphs[cg_index], 'backward_dw'): + return + self.cuda_graphs[cg_index].backward_dw() + + def __call__(self, *args, **kwargs): + # Extract mhc_recompute_manager before CUDA graph manager processes kwargs, + # since CheckpointManager is not a CUDA-graph-supported type. + self._mhc_recompute_manager = kwargs.pop("mhc_recompute_manager", None) + kwargs.pop("is_last_layer_in_recompute_block", None) + + if self._should_call_local_cudagraph(*args, **kwargs): + # Inference mode. + if kwargs.get('inference_context') is not None: + # dynamic_inference_decode_only is not a real argument to forward, it is only used + # to differentiate the cuda graph used for decode from the one used for non-decode + # inference. + kwargs["dynamic_inference_decode_only"] = kwargs[ + 'inference_context' + ].is_decode_only() + + return super().__call__(*args, **kwargs) + def get_layer_norm_weights(self): """ Get the weights of all layernorms (attention and MLP) in the transformer layer. 
@@ -1289,6 +1325,373 @@ def get_layer_norm_weights(self): return +class HyperConnectionTransformerLayer(TransformerLayer): + """A transformer layer with Manifold-Constrained Hyper-Connections (mHC). + + Extends TransformerLayer by adding hyper connection modules around self-attention + and MLP. The n-stream hidden states are aggregated before each sub-layer and + expanded back afterwards using learned mappings (H_pre, H_post, H_res). + + Cross-attention hyper connection is not supported. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: Optional[float] = None, + pg_collection: Optional[ProcessGroupCollection] = None, + vp_stage: Optional[int] = None, + ): + super().__init__( + config=config, + submodules=submodules, + layer_number=layer_number, + hidden_dropout=hidden_dropout, + pg_collection=pg_collection, + vp_stage=vp_stage, + ) + + if submodules.cross_attention_hyper_connection is not IdentityOp: + raise ValueError( + "HyperConnectionTransformerLayer does not support cross-attention " + "hyper connections. Use IdentityOp for cross_attention_hyper_connection." + ) + + assert submodules.self_attention_hyper_connection is not IdentityOp, ( + "HyperConnectionTransformerLayer requires self_attention_hyper_connection. " + "Use TransformerLayer instead if hyper connections are not needed." + ) + assert submodules.mlp_hyper_connection is not IdentityOp, ( + "HyperConnectionTransformerLayer requires mlp_hyper_connection. " + "Use TransformerLayer instead if hyper connections are not needed." 
+ ) + + self.self_attention_hyper_connection = build_module( + submodules.self_attention_hyper_connection, + config=self.config, + layer_number=self.layer_number, + ) + + self.mlp_hyper_connection = build_module( + submodules.mlp_hyper_connection, config=self.config, layer_number=self.layer_number + ) + + # When mHC recompute is active, skip checkpointing if the layernorm + # is IdentityOp (fused into TE linear) — there is nothing to recompute. + self.mhc_checkpoint_input_layernorm = not isinstance(self.input_layernorm, IdentityOp) + self.mhc_checkpoint_pre_mlp_layernorm = not isinstance(self.pre_mlp_layernorm, IdentityOp) + + def get_layer_static_inputs(self, seq_length, micro_batch_size): + """Override to produce n-stream hidden_states of shape [s, b, n*C]. + + CUDA graph capture creates static buffers whose shapes are determined by + this method. The base class returns [s, b, C], but mHC layers operate on + n-stream hidden states of shape [s, b, n*C]. + """ + static_inputs = super().get_layer_static_inputs(seq_length, micro_batch_size) + hs = static_inputs["hidden_states"] + n = self.config.num_residual_streams + static_inputs["hidden_states"] = torch.ones( + (hs.shape[0], hs.shape[1], n * self.config.hidden_size), + dtype=hs.dtype, + requires_grad=hs.requires_grad, + device=hs.device, + ) + return static_inputs + + def _get_submodules_under_cudagraphs(self): + """Override to include hyper connection modules. + + The base TransformerLayer._get_submodules_under_cudagraphs does not include + self_attention_hyper_connection / mlp_hyper_connection. Their learnable + parameters (mapping_proj, alpha_*, bias) need manual pre-forward hooks + during CUDA graph replay so that parameter all-gathers are triggered. 
+ """ + submodules = super()._get_submodules_under_cudagraphs() + + if not self.config.cuda_graph_scope: + return submodules + + if CudaGraphScope.attn in self.config.cuda_graph_scope: + submodules.append(self.self_attention_hyper_connection) + if (not self.is_moe_layer and CudaGraphScope.mlp in self.config.cuda_graph_scope) or ( + self.is_moe_layer and CudaGraphScope.moe in self.config.cuda_graph_scope + ): + submodules.append(self.mlp_hyper_connection) + return submodules + + def forward(self, *args, **kwargs): + """Forward pass with MHC recompute manager support.""" + kwargs.pop("dynamic_inference_decode_only", None) + + mhc_recompute_manager = getattr(self, '_mhc_recompute_manager', None) + + hidden_states, context = self._forward_attention( + *args, mhc_recompute_manager=mhc_recompute_manager, **kwargs + ) + + output = self._forward_mlp( + hidden_states, + kwargs.get("inference_context", None), + padding_mask=kwargs.get("padding_mask", None), + mhc_recompute_manager=mhc_recompute_manager, + ) + return output, context + + def _forward_attention( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + context: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + rotary_pos_cos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + inference_context: Optional[Any] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[Tensor] = None, + padding_mask: Optional[Tensor] = None, + mhc_recompute_manager: Optional['CheckpointManager'] = None, + *, + inference_params: Optional[Any] = None, + ): + """Forward attention with hyper connection pre/post processing on self-attention.""" + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) + + inference_context = 
deprecate_inference_params(inference_context, inference_params) + + residual = hidden_states + + nvtx_range_push(suffix="self_attention_hyper_connection") + hidden_states, self_attn_h_res, self_attn_hc_h_post = self.self_attention_hyper_connection( + hidden_states, mhc_recompute_manager=mhc_recompute_manager + ) + nvtx_range_pop(suffix="self_attention_hyper_connection") + + # Optional Input Layer norm + checkpoint_input_layernorm = self.recompute_input_layernorm or ( + mhc_recompute_manager is not None and self.mhc_checkpoint_input_layernorm + ) + if checkpoint_input_layernorm: + self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput( + ckpt_manager=mhc_recompute_manager + ) + with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: + input_layernorm_output = self.input_layernorm_checkpoint.checkpoint( + self.input_layernorm, hidden_states + ) + else: + with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: + input_layernorm_output = self.input_layernorm(hidden_states) + + # Self attention. 
+ nvtx_range_push(suffix="self_attention") + attention_output_with_bias = self.self_attention( + input_layernorm_output, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + nvtx_range_pop(suffix="self_attention") + + if checkpoint_input_layernorm: + self.input_layernorm_checkpoint.discard_output_and_register_recompute( + attention_output_with_bias[0] + ) + + nvtx_range_push(suffix="self_attention_fused_h_res_h_post_bda") + with self.bias_dropout_add_exec_handler(): + hidden_states = self.self_attention_hyper_connection.fused_h_res_h_post_bda( + self_attn_h_res, + residual, + self_attn_hc_h_post, + attention_output_with_bias, + self.hidden_dropout, + self.training, + self.config.bias_dropout_fusion, + mhc_recompute_manager, + ) + nvtx_range_pop(suffix="self_attention_fused_h_res_h_post_bda") + + if self.offload_attn_norm: + hidden_states = off_interface.group_commit(hidden_states, name="attn_norm") + + # Cross-attention (no hyper connection support). 
+ residual = hidden_states + pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states) + + attention_output_with_bias = self.cross_attention( + pre_cross_attn_layernorm_output, + attention_mask=context_mask, + key_value_states=context, + inference_context=inference_context, + ) + + if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias: + context = attention_output_with_bias["context"] + + with self.bias_dropout_add_exec_handler(): + hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)( + attention_output_with_bias, residual, self.hidden_dropout + ) + + return hidden_states, context + + def _forward_mlp( + self, + hidden_states, + inference_context=None, + padding_mask=None, + mhc_recompute_manager: Optional['CheckpointManager'] = None, + ): + """Forward MLP with hyper connection pre/post processing.""" + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) + + is_last_in_recompute_block = bool( + mhc_recompute_manager is not None + and getattr(mhc_recompute_manager, "is_last_layer_in_recompute_block", False) + ) + mhc_mlp_bda_manager = None if is_last_in_recompute_block else mhc_recompute_manager + + residual = hidden_states + + nvtx_range_push(suffix="mlp_hyper_connection") + hidden_states, mlp_h_res, mlp_hc_h_post = self.mlp_hyper_connection( + hidden_states, mhc_recompute_manager=mhc_recompute_manager + ) + nvtx_range_pop(suffix="mlp_hyper_connection") + + # Optional Layer norm post the cross-attention. 
+ checkpoint_pre_mlp_layernorm = self.recompute_pre_mlp_layernorm or ( + mhc_recompute_manager is not None and self.mhc_checkpoint_pre_mlp_layernorm + ) + if checkpoint_pre_mlp_layernorm: + self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput( + ckpt_manager=mhc_recompute_manager + ) + with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: + pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint( + self.pre_mlp_layernorm, hidden_states + ) + else: + with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: + pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) + + nvtx_range_push(suffix="mlp") + should_chunk_mlp_for_prefill = ( + self.config.mlp_chunks_for_prefill > 1 + and inference_context is not None + and not inference_context.is_decode_only() + and not isinstance(self.mlp, IdentityOp) + and not self.config.transformer_impl == "inference_optimized" + ) + + if self.recompute_mlp: + if self.config.fp8 or self.config.fp4: + from megatron.core.extensions.transformer_engine import te_checkpoint + + mlp_output_with_bias = te_checkpoint( + self.mlp, + False, + tensor_parallel.random.get_cuda_rng_tracker, + self.pg_collection.tp, + pre_mlp_layernorm_output, + padding_mask=padding_mask, + ) + else: + mlp_output_with_bias = tensor_parallel.checkpoint( + functools.partial(self.mlp, padding_mask=padding_mask), + False, + pre_mlp_layernorm_output, + ) + elif should_chunk_mlp_for_prefill: + num_chunks = min(self.config.mlp_chunks_for_prefill, pre_mlp_layernorm_output.shape[0]) + chunks = pre_mlp_layernorm_output.chunk(num_chunks, dim=0) + outputs = [self.mlp(chunk) for chunk in chunks] + mlp_output = torch.cat([out for out, _ in outputs], dim=0) + bias_chunks = [bias for _, bias in outputs if bias is not None] + bias_output = torch.stack(bias_chunks, dim=0).sum(dim=0) if bias_chunks else None + mlp_output_with_bias = (mlp_output, bias_output) + else: + mlp_output_with_bias 
= self.mlp(pre_mlp_layernorm_output, padding_mask=padding_mask)
+
+        nvtx_range_pop(suffix="mlp")
+
+        return self._forward_post_mlp_with_fused_hyper_connection(
+            mlp_output_with_bias, mlp_h_res, residual, mlp_hc_h_post, mhc_mlp_bda_manager
+        )
+
+    def _forward_post_mlp_with_fused_hyper_connection(
+        self,
+        mlp_output_with_bias,
+        mlp_h_res,
+        residual,
+        mlp_hc_h_post,
+        mhc_mlp_bda_recompute_manager: Optional['CheckpointManager'] = None,
+    ):
+        """
+        Perform operations after the MLP computation with fused hyper connection kernel.
+
+        This method uses the fused kernel combining apply_h_res, apply_h_post and bias-dropout-add.
+
+        Args:
+            mlp_output_with_bias (Tensor): Output tensor of the MLP layer with bias.
+            mlp_h_res (Tensor): [s, b, n, n] - residual mixing matrix from hyper connection.
+            residual (Tensor): [s, b, n*C] - original residual (n-stream hidden states).
+            mlp_hc_h_post (Tensor): [s, b, n] - expansion weights from hyper connection.
+            mhc_mlp_bda_recompute_manager: Optional CheckpointManager for checkpoint
+                management; None when this layer is the last in a recompute block.
+
+        Returns:
+            output (Tensor): Transformed hidden states of shape [s, b, n*C]. 
+ """ + if self.recompute_pre_mlp_layernorm or ( + mhc_mlp_bda_recompute_manager is not None and self.mhc_checkpoint_pre_mlp_layernorm + ): + self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute( + mlp_output_with_bias[0] + ) + + nvtx_range_push(suffix="mlp_fused_h_res_h_post_bda") + with self.bias_dropout_add_exec_handler(): + hidden_states = self.mlp_hyper_connection.fused_h_res_h_post_bda( + mlp_h_res, + residual, + mlp_hc_h_post, + mlp_output_with_bias, + self.hidden_dropout, + self.training, + self.config.bias_dropout_fusion, + mhc_mlp_bda_recompute_manager, + ) + nvtx_range_pop(suffix="mlp_fused_h_res_h_post_bda") + + if self.offload_mlp_norm: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) + + hidden_states = off_interface.group_commit(hidden_states, name="mlp_norm") + + output = make_viewless_tensor( + inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True + ) + return output + + class MoETransformerLayer(TransformerLayer): """ A Transformer layer specialized for Mixture-of-Experts (MoE) architectures. diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py index ff655502019..61a795b4754 100644 --- a/megatron/training/initialize.py +++ b/megatron/training/initialize.py @@ -25,7 +25,12 @@ from megatron.core.transformer.custom_layers.batch_invariant_kernels import ( enable_batch_invariant_mode, ) -from megatron.core.utils import get_te_version, is_te_min_version, is_torch_min_version +from megatron.core.utils import ( + configure_nvtx_profiling, + get_te_version, + is_te_min_version, + is_torch_min_version, +) from megatron.training import ( get_adlr_autoresume, get_args, @@ -89,6 +94,12 @@ def state_restore_func(state_dict): print_rank_0("Enabling batch invariant mode globally") enable_batch_invariant_mode() + # Enable NVTX range profiling when profiling is active. 
+ # Must be done before model modules with @nvtx_decorator are imported, + # since the decorator captures _nvtx_enabled at decoration (import) time. + if args.profile: + configure_nvtx_profiling(True) + # torch.distributed initialization def finish_mpu_init(): args = get_args() From 6c58c9504e260750bce77136703804fc8ec92e13 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Mon, 27 Apr 2026 09:15:32 -0700 Subject: [PATCH 02/10] Add mHC transformer unit tests --- .../unit_tests/models/test_gpt_layer_specs.py | 67 ++ .../models/test_hybrid_moe_model.py | 6 + tests/unit_tests/test_fp8_param.py | 8 +- .../test_hyper_connection_recompute.py | 408 +++++++++ .../transformer/test_mhc_block_manager.py | 397 +++++++++ .../transformer/test_transformer_layer.py | 786 +++++++++++++++++- 6 files changed, 1666 insertions(+), 6 deletions(-) create mode 100644 tests/unit_tests/models/test_gpt_layer_specs.py create mode 100644 tests/unit_tests/transformer/test_hyper_connection_recompute.py create mode 100644 tests/unit_tests/transformer/test_mhc_block_manager.py diff --git a/tests/unit_tests/models/test_gpt_layer_specs.py b/tests/unit_tests/models/test_gpt_layer_specs.py new file mode 100644 index 00000000000..bfa86fd0241 --- /dev/null +++ b/tests/unit_tests/models/test_gpt_layer_specs.py @@ -0,0 +1,67 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ +import pytest + +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) +from megatron.core.transformer.hyper_connection import HyperConnectionModule +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.transformer_layer import ( + HyperConnectionTransformerLayer, + TransformerLayer, +) + +_TE = get_gpt_layer_with_transformer_engine_spec +_LOCAL = get_gpt_layer_local_spec +_HC = HyperConnectionTransformerLayer +_HC_MOD = HyperConnectionModule +_TL = TransformerLayer +_ID = IdentityOp + + +class TestGptLayerSpecsHyperConnection: + """Test that enable_hyper_connection controls module types in layer specs.""" + + @pytest.mark.parametrize( + "factory,kwargs,expected_module,expected_hc", + [ + (_TE, {}, _TL, _ID), + (_TE, {"enable_hyper_connection": True}, _HC, _HC_MOD), + (_TE, {"enable_hyper_connection": False}, _TL, _ID), + (_TE, {"multi_latent_attention": True, "enable_hyper_connection": False}, _TL, _ID), + (_TE, {"multi_latent_attention": True, "enable_hyper_connection": True}, _HC, _HC_MOD), + (_LOCAL, {}, _TL, _ID), + (_LOCAL, {"enable_hyper_connection": True}, _HC, _HC_MOD), + (_LOCAL, {"enable_hyper_connection": False}, _TL, _ID), + (_LOCAL, {"multi_latent_attention": True, "enable_hyper_connection": False}, _TL, _ID), + ( + _LOCAL, + {"multi_latent_attention": True, "enable_hyper_connection": True}, + _HC, + _HC_MOD, + ), + (_LOCAL, {"normalization": "RMSNorm", "enable_hyper_connection": False}, _TL, _ID), + (_LOCAL, {"normalization": "RMSNorm", "enable_hyper_connection": True}, _HC, _HC_MOD), + ], + ids=[ + "te_default", + "te_enable", + "te_disable", + "te_mla_disable", + "te_mla_enable", + "local_default", + "local_enable", + "local_disable", + "local_mla_disable", + "local_mla_enable", + "local_rmsnorm_disable", + "local_rmsnorm_enable", + ], + ) + def test_hyper_connection_spec(self, factory, kwargs, expected_module, expected_hc): + 
spec = factory(**kwargs) + assert spec.module is expected_module + assert spec.submodules.self_attention_hyper_connection is expected_hc + assert spec.submodules.mlp_hyper_connection is expected_hc diff --git a/tests/unit_tests/models/test_hybrid_moe_model.py b/tests/unit_tests/models/test_hybrid_moe_model.py index 01a46efe083..56c12076041 100644 --- a/tests/unit_tests/models/test_hybrid_moe_model.py +++ b/tests/unit_tests/models/test_hybrid_moe_model.py @@ -89,6 +89,7 @@ "embedding_init_method_std": 0.014, "enable_autocast": False, "enable_cuda_graph": False, + "enable_hyper_connections": False, "ep_overlap_early_attn_memory_release": False, "experimental_attention_variant": None, "expert_model_parallel_size": 4, @@ -151,6 +152,9 @@ "mamba_state_dim": 128, "masked_softmax_fusion": True, "memory_efficient_layer_norm": False, + "mhc_init_gating_factor": 0.01, + "mhc_recompute_layer_num": None, + "mhc_sinkhorn_iterations": 20, "microbatch_group_size_per_vp_stage": 1, "mlp_chunks_for_prefill": 1, "moe_apply_probs_on_input": False, @@ -219,6 +223,7 @@ "num_microbatches_with_partial_activation_checkpoints": None, "num_moe_experts": 128, "num_query_groups": 2, + "num_residual_streams": 4, "output_layer_init_method": {}, "overlap_moe_expert_parallel_comm": False, "overlap_p2p_comm": False, @@ -265,6 +270,7 @@ "tp_only_amax_red": False, "transformer_impl": "transformer_engine", "use_cpu_initialization": None, + "use_fused_mhc": False, "use_fused_weighted_squared_relu": False, "use_inference_optimized_layers": False, "use_kitchen": False, diff --git a/tests/unit_tests/test_fp8_param.py b/tests/unit_tests/test_fp8_param.py index 34b504e21de..e0a71526297 100644 --- a/tests/unit_tests/test_fp8_param.py +++ b/tests/unit_tests/test_fp8_param.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
import contextlib import gc @@ -72,12 +72,12 @@ def setup_method(self, method): os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' def teardown_method(self, method): - Utils.destroy_model_parallel() - destroy_global_vars() - destroy_num_microbatches_calculator() if self.cuda_graph_helper is not None and self.cuda_graph_helper.graphs_created(): self.cuda_graph_helper.delete_cuda_graphs() self.cuda_graph_helper = None + Utils.destroy_model_parallel() + destroy_global_vars() + destroy_num_microbatches_calculator() gc.collect() def model_provider( diff --git a/tests/unit_tests/transformer/test_hyper_connection_recompute.py b/tests/unit_tests/transformer/test_hyper_connection_recompute.py new file mode 100644 index 00000000000..cf44f2d7cd0 --- /dev/null +++ b/tests/unit_tests/transformer/test_hyper_connection_recompute.py @@ -0,0 +1,408 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +""" +Unit tests for HyperConnection block-level recomputation. + +Tests the following functionality: +1. HyperConnectionModule._forward_with_checkpoint correctness +2. HyperConnectionModule.apply_h_post with CheckpointManager +3. Multiple HyperConnectionModules chained with a single CheckpointManager +4. Partial checkpoint (last layer not checkpointed) +5. 
TransformerConfig 'mhc' in recompute_modules option +""" + +import pytest +import torch + +from megatron.core.tensor_parallel.random import CheckpointManager, model_parallel_cuda_manual_seed +from megatron.core.transformer.hyper_connection import HyperConnectionModule +from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils + + +class TestHyperConnectionCheckpoint: + """Test HyperConnectionModule checkpoint functionality.""" + + def setup_method(self, method): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _create_hyper_connection_module(self, hidden_size=64, num_residual_streams=4): + """Create a HyperConnectionModule for testing.""" + config = TransformerConfig( + num_layers=2, + hidden_size=hidden_size, + num_attention_heads=4, + use_cpu_initialization=True, + enable_hyper_connections=True, + num_residual_streams=num_residual_streams, + mhc_sinkhorn_iterations=5, # Fewer iterations for faster tests + mhc_init_gating_factor=0.01, + ) + module = HyperConnectionModule(config=config, layer_number=1) + module.cuda() + return module + + def test_forward_normal_vs_checkpoint_correctness(self): + """ + Test that _forward_with_checkpoint produces the same outputs as _forward_normal. 
+ """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + + module = self._create_hyper_connection_module(hidden_size, num_streams) + + # Create input tensors + hidden_states = torch.randn( + seq_len, batch_size, num_streams * hidden_size, device='cuda', requires_grad=True + ) + residual = torch.randn( + seq_len, batch_size, num_streams * hidden_size, device='cuda', requires_grad=True + ) + + # Clone inputs for comparison + hidden_states_ckpt = hidden_states.detach().clone().requires_grad_(True) + residual_ckpt = residual.detach().clone().requires_grad_(True) + + # Forward without checkpoint (reference) + torch.manual_seed(42) + torch.cuda.manual_seed(42) + aggregated_ref, h_res_ref, h_post_ref = module._forward_normal(hidden_states) + mixed_ref = module.apply_h_res(h_res_ref, residual) + loss_ref = aggregated_ref.sum() + mixed_ref.sum() + h_post_ref.sum() + loss_ref.backward() + grad_hidden_ref = hidden_states.grad.clone() + grad_residual_ref = residual.grad.clone() + + # Forward with checkpoint + torch.manual_seed(42) + torch.cuda.manual_seed(42) + manager = CheckpointManager() + aggregated_ckpt, h_res_ckpt, h_post_ckpt = module._forward_with_checkpoint( + hidden_states_ckpt, manager + ) + mixed_ckpt = module.apply_h_res(h_res_ckpt, residual_ckpt) + # Calculate loss before discarding outputs + loss_ckpt = aggregated_ckpt.sum() + mixed_ckpt.sum() + h_post_ckpt.sum() + + # Register unified recompute hook + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + # Backward pass + loss_ckpt.backward() + grad_hidden_ckpt = hidden_states_ckpt.grad.clone() + grad_residual_ckpt = residual_ckpt.grad.clone() + + # Verify gradients match + assert torch.allclose(grad_hidden_ckpt, grad_hidden_ref, atol=1e-5), ( + f"Hidden states gradients mismatch:\n" + f"Checkpoint: {grad_hidden_ckpt}\n" + f"Reference: {grad_hidden_ref}" + ) + assert torch.allclose(grad_residual_ckpt, grad_residual_ref, atol=1e-5), ( + f"Residual gradients mismatch:\n" + 
f"Checkpoint: {grad_residual_ckpt}\n" + f"Reference: {grad_residual_ref}" + ) + + def test_apply_h_post_with_checkpoint(self): + """ + Test that apply_h_post with manager produces correct gradients. + """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + + module = self._create_hyper_connection_module(hidden_size, num_streams) + + # Create input tensors + x = torch.randn(seq_len, batch_size, hidden_size, device='cuda', requires_grad=True) + bias = torch.randn(hidden_size, device='cuda') + h_post = torch.randn(seq_len, batch_size, num_streams, device='cuda', requires_grad=True) + + # Clone inputs + x_ckpt = x.detach().clone().requires_grad_(True) + h_post_ckpt = h_post.detach().clone().requires_grad_(True) + + # Reference: without checkpoint (manager=None) + torch.manual_seed(42) + x_out_ref, bias_out_ref = module.apply_h_post((x, bias), h_post, manager=None) + loss_ref = x_out_ref.sum() + if bias_out_ref is not None: + loss_ref = loss_ref + bias_out_ref.sum() + loss_ref.backward() + grad_x_ref = x.grad.clone() + grad_h_post_ref = h_post.grad.clone() + + # With checkpoint (manager provided) + torch.manual_seed(42) + manager = CheckpointManager() + x_out_ckpt, bias_out_ckpt = module.apply_h_post( + (x_ckpt, bias), h_post_ckpt, manager=manager + ) + loss_ckpt = x_out_ckpt.sum() + if bias_out_ckpt is not None: + loss_ckpt = loss_ckpt + bias_out_ckpt.sum() + + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + loss_ckpt.backward() + grad_x_ckpt = x_ckpt.grad.clone() + grad_h_post_ckpt = h_post_ckpt.grad.clone() + + # Verify gradients + assert torch.allclose(grad_x_ckpt, grad_x_ref, atol=1e-5) + assert torch.allclose(grad_h_post_ckpt, grad_h_post_ref, atol=1e-5) + + def test_forward_with_manager_parameter(self): + """ + Test forward() method with mhc_recompute_manager parameter. 
+ """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + + module = self._create_hyper_connection_module(hidden_size, num_streams) + + # Create input tensors + hidden_states = torch.randn( + seq_len, batch_size, num_streams * hidden_size, device='cuda', requires_grad=True + ) + + # Clone inputs + hidden_states_ckpt = hidden_states.detach().clone().requires_grad_(True) + + # Reference: forward without manager (uses _forward_normal) + torch.manual_seed(42) + torch.cuda.manual_seed(42) + aggregated_ref, h_res_ref, h_post_ref = module.forward( + hidden_states, mhc_recompute_manager=None + ) + loss_ref = aggregated_ref.sum() + h_res_ref.sum() + h_post_ref.sum() + loss_ref.backward() + grad_hidden_ref = hidden_states.grad.clone() + + # With manager (uses _forward_with_checkpoint) + torch.manual_seed(42) + torch.cuda.manual_seed(42) + manager = CheckpointManager() + aggregated_ckpt, h_res_ckpt, h_post_ckpt = module.forward( + hidden_states_ckpt, mhc_recompute_manager=manager + ) + loss_ckpt = aggregated_ckpt.sum() + h_res_ckpt.sum() + h_post_ckpt.sum() + + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + loss_ckpt.backward() + grad_hidden_ckpt = hidden_states_ckpt.grad.clone() + + # Verify gradients match + assert torch.allclose(grad_hidden_ckpt, grad_hidden_ref, atol=1e-5) + + +class TestMHCBlockRecomputeIntegration: + """Test CheckpointManager integration with HyperConnection.""" + + def setup_method(self, method): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def test_multiple_hyper_connections_in_chain(self): + """ + Test that multiple HyperConnectionModules can be chained together + with a single CheckpointManager. 
+ """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + n_channels = num_streams * hidden_size + + # Create multiple HyperConnection modules (simulating multiple layers) + config = TransformerConfig( + num_layers=4, + hidden_size=hidden_size, + num_attention_heads=4, + use_cpu_initialization=True, + enable_hyper_connections=True, + num_residual_streams=num_streams, + mhc_sinkhorn_iterations=5, + mhc_init_gating_factor=0.01, + ) + + modules = [ + HyperConnectionModule(config=config, layer_number=i + 1).cuda() for i in range(3) + ] + + # Create input tensors + hidden_states_ref = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + residual_ref = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + + hidden_states_ckpt = hidden_states_ref.detach().clone().requires_grad_(True) + residual_ckpt = residual_ref.detach().clone().requires_grad_(True) + + # Reference: forward without checkpoint + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + h = hidden_states_ref + r = residual_ref + for module in modules: + agg, h_res, h_post = module.forward(h, mhc_recompute_manager=None) + agg, _ = module.apply_h_post((0.1 * agg, None), h_post, manager=None) + mixed = module.apply_h_res(h_res, r) # Apply h_res to get mixed [s, b, n*C] + h = agg + mixed + r = h + + loss_ref = h.sum() + loss_ref.backward() + grad_hidden_ref = hidden_states_ref.grad.clone() + grad_residual_ref = residual_ref.grad.clone() + + # With checkpoint using single manager + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + manager = CheckpointManager() + + h = hidden_states_ckpt + r = residual_ckpt + for module in modules: + agg, h_res, h_post = module.forward(h, mhc_recompute_manager=manager) + agg, _ = module.apply_h_post((0.1 * agg, None), h_post, manager=manager) + mixed = module.apply_h_res(h_res, r) # Apply h_res to get mixed [s, b, n*C] + h = agg + mixed + r = h + + loss_ckpt = h.sum() + 
manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + loss_ckpt.backward() + + grad_hidden_ckpt = hidden_states_ckpt.grad.clone() + grad_residual_ckpt = residual_ckpt.grad.clone() + + # Verify gradients + assert torch.allclose( + grad_hidden_ckpt, grad_hidden_ref, atol=1e-4 + ), f"Chained HyperConnection hidden gradients mismatch" + assert torch.allclose( + grad_residual_ckpt, grad_residual_ref, atol=1e-4 + ), f"Chained HyperConnection residual gradients mismatch" + + def test_partial_checkpoint_last_layer_not_checkpointed(self): + """ + Test that when is_last_layer_in_block=True, the final output is NOT checkpointed. + This simulates the TransformerBlock behavior where the last layer's MLP BDA + serves as the hook_tensor for unified recompute. + """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + + config = TransformerConfig( + num_layers=2, + hidden_size=hidden_size, + num_attention_heads=4, + use_cpu_initialization=True, + enable_hyper_connections=True, + num_residual_streams=num_streams, + mhc_sinkhorn_iterations=5, + mhc_init_gating_factor=0.01, + ) + + module = HyperConnectionModule(config=config, layer_number=1).cuda() + + hidden_states_ref = torch.randn( + seq_len, batch_size, num_streams * hidden_size, device='cuda', requires_grad=True + ) + residual_ref = torch.randn( + seq_len, batch_size, num_streams * hidden_size, device='cuda', requires_grad=True + ) + + hidden_states_ckpt = hidden_states_ref.detach().clone().requires_grad_(True) + residual_ckpt = residual_ref.detach().clone().requires_grad_(True) + + # Reference + torch.manual_seed(42) + torch.cuda.manual_seed(42) + aggregated_ref, h_res_ref, h_post_ref = module.forward( + hidden_states_ref, mhc_recompute_manager=None + ) + aggregated_ref, _ = module.apply_h_post( + (0.1 * aggregated_ref, None), h_post_ref, manager=None + ) + mixed_ref = module.apply_h_res( + h_res_ref, residual_ref + ) # Apply h_res to get mixed [s, b, n*C] + # Simulate BDA that is NOT 
checkpointed (last layer) + output_ref = aggregated_ref + 0.5 * mixed_ref + loss_ref = output_ref.sum() + loss_ref.backward() + grad_hidden_ref = hidden_states_ref.grad.clone() + + # With manager - checkpoint everything except final output + torch.manual_seed(42) + torch.cuda.manual_seed(42) + manager = CheckpointManager() + aggregated_ckpt, h_res_ckpt, h_post_ckpt = module.forward( + hidden_states_ckpt, mhc_recompute_manager=manager + ) + + aggregated_ckpt, _ = module.apply_h_post( + (0.1 * aggregated_ckpt, None), h_post_ckpt, manager=manager + ) + mixed_ckpt = module.apply_h_res( + h_res_ckpt, residual_ckpt + ) # Apply h_res to get mixed [s, b, n*C] + # Simulate BDA that is NOT checkpointed (last layer) - this is the hook_tensor + output_ckpt = aggregated_ckpt + 0.5 * mixed_ckpt + + # Register unified recompute on the output (which is not checkpointed) + manager.discard_all_outputs_and_register_unified_recompute(output_ckpt) + + loss_ckpt = output_ckpt.sum() + loss_ckpt.backward() + grad_hidden_ckpt = hidden_states_ckpt.grad.clone() + + # Verify gradients match + assert torch.allclose(grad_hidden_ckpt, grad_hidden_ref, atol=1e-5) + + +class TestTransformerConfigRecomputeMhc: + """Test 'mhc' in recompute_modules configuration.""" + + def test_config_default_value(self): + """Test that 'mhc' is not in recompute_modules by default.""" + config = TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4) + assert "mhc" not in config.recompute_modules + + def test_config_enable_mhc_recompute(self): + """Test enabling 'mhc' in recompute_modules.""" + config = TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + recompute_modules=["core_attn", "mhc"], + recompute_granularity='selective', + ) + assert "mhc" in config.recompute_modules + assert config.enable_hyper_connections is True + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git 
a/tests/unit_tests/transformer/test_mhc_block_manager.py b/tests/unit_tests/transformer/test_mhc_block_manager.py new file mode 100644 index 00000000000..aab004d6516 --- /dev/null +++ b/tests/unit_tests/transformer/test_mhc_block_manager.py @@ -0,0 +1,397 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import pytest +import torch + +from megatron.core.tensor_parallel.random import ( + CheckpointManager, + CheckpointWithoutOutput, + initialize_rng_tracker, +) +from tests.unit_tests.test_utilities import Utils + + +class TestCheckpointWithoutOutputManagerAPI: + """Test CheckpointWithoutOutput integration with CheckpointManager.""" + + def setup_method(self, method): + Utils.initialize_model_parallel() + initialize_rng_tracker(force_reset=True) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def test_auto_register(self): + """CheckpointWithoutOutput auto-registers to manager when ckpt_manager is provided.""" + manager = CheckpointManager() + + def func(x): + return x * 2 + 1 + + input_t = torch.randn(4, 4, device='cuda', requires_grad=True) + + ckpt = CheckpointWithoutOutput(ckpt_manager=manager) + y = ckpt.checkpoint(func, input_t) + + assert len(manager.checkpoints) == 1 + assert manager.checkpoints[0] is ckpt + + ckpt2 = CheckpointWithoutOutput(ckpt_manager=manager) + y2 = ckpt2.checkpoint(torch.nn.functional.gelu, y) + + assert len(manager.checkpoints) == 2 + assert manager.checkpoints[1] is ckpt2 + + loss = y2.sum() + manager.discard_all_outputs_and_register_unified_recompute(loss) + loss.backward() + + assert input_t.grad is not None + + def test_discard_is_noop_with_manager(self): + """discard_output_and_register_recompute is a NO-OP when ckpt_manager is set.""" + manager = CheckpointManager() + + def func1(x): + return x * 2 + + def func2(x): + return torch.nn.functional.gelu(x) + + input_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + y1_ref = func1(input_ref) + y2_ref = func2(y1_ref) 
+ loss_ref = y2_ref.sum() + loss_ref.backward() + grad_ref = input_ref.grad.clone() + + input_ckpt = input_ref.detach().clone().requires_grad_(True) + + ckpt1 = CheckpointWithoutOutput(ckpt_manager=manager) + y1 = ckpt1.checkpoint(func1, input_ckpt) + ckpt1.discard_output_and_register_recompute(y1) + + ckpt2 = CheckpointWithoutOutput(ckpt_manager=manager) + y2 = ckpt2.checkpoint(func2, y1) + ckpt2.discard_output_and_register_recompute(y2) + + assert y1.untyped_storage().size() > 0, "y1 should NOT be discarded yet" + assert y2.untyped_storage().size() > 0, "y2 should NOT be discarded yet" + + loss_ckpt = y2.sum() + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + assert y1.untyped_storage().size() == 0, "y1 should be discarded after manager call" + assert y2.untyped_storage().size() == 0, "y2 should be discarded after manager call" + + loss_ckpt.backward() + grad_ckpt = input_ckpt.grad.clone() + + assert torch.allclose(grad_ckpt, grad_ref, atol=1e-6) + + def test_backward_compat_without_manager(self): + """CheckpointWithoutOutput without ckpt_manager should work exactly as before.""" + + def func(x): + return torch.nn.functional.gelu(x) + + input_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + y_ref = func(input_ref) + z_ref = y_ref * 2 + loss_ref = z_ref.sum() + loss_ref.backward() + grad_ref = input_ref.grad.clone() + + input_ckpt = input_ref.detach().clone().requires_grad_(True) + + ckpt = CheckpointWithoutOutput() + y = ckpt.checkpoint(func, input_ckpt) + z = y * 2 + ckpt.discard_output_and_register_recompute(z) + + assert y.untyped_storage().size() == 0 + + loss_ckpt = z.sum() + loss_ckpt.backward() + grad_ckpt = input_ckpt.grad.clone() + + assert torch.allclose(grad_ckpt, grad_ref, atol=1e-6) + + def test_error_handling(self): + """CheckpointManager rejects invalid add_checkpoint calls.""" + manager = CheckpointManager() + + with pytest.raises(TypeError): + manager.add_checkpoint("not a checkpoint") + + ckpt = 
CheckpointWithoutOutput() + with pytest.raises(ValueError): + manager.add_checkpoint(ckpt) + + +class TestCheckpointManagerSequentialChain: + """Test CheckpointManager with sequential checkpoint chains.""" + + def setup_method(self, method): + Utils.initialize_model_parallel() + initialize_rng_tracker(force_reset=True) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def test_basic_sequential_chain(self): + """Three sequential checkpoints: gradients match non-checkpointed version.""" + + def func1(x): + return x * 2 + 1 + + def func2(x): + return torch.nn.functional.gelu(x) + + def func3(x): + return x * x + x + + input_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + input_ckpt = input_ref.detach().clone().requires_grad_(True) + + y1_ref = func1(input_ref) + y2_ref = func2(y1_ref) + y3_ref = func3(y2_ref) + loss_ref = y3_ref.sum() + loss_ref.backward() + grad_ref = input_ref.grad.clone() + + manager = CheckpointManager() + + y1 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func1, input_ckpt) + y2 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func2, y1) + y3 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func3, y2) + + loss_ckpt = y3.sum() + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + assert y1.untyped_storage().size() == 0, "y1 storage should be released" + assert y2.untyped_storage().size() == 0, "y2 storage should be released" + assert y3.untyped_storage().size() == 0, "y3 storage should be released" + + loss_ckpt.backward() + grad_ckpt = input_ckpt.grad.clone() + + assert torch.allclose( + grad_ckpt, grad_ref, atol=1e-6 + ), f"Gradients mismatch!\nWith manager: {grad_ckpt}\nReference: {grad_ref}" + + def test_sequential_chain_with_dropout(self): + """RNG state is restored during recompute so dropout gradients match.""" + + def func_with_dropout(x): + return torch.nn.functional.dropout(x, p=0.3, training=True) + + def func2(x): + return 
torch.nn.functional.gelu(x) + + input_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + input_ckpt = input_ref.detach().clone().requires_grad_(True) + + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + y1_ref = func_with_dropout(input_ref) + y2_ref = func2(y1_ref) + loss_ref = y2_ref.sum() + loss_ref.backward() + grad_ref = input_ref.grad.clone() + + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + manager = CheckpointManager() + + y1 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func_with_dropout, input_ckpt) + y2 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func2, y1) + + loss_ckpt = y2.sum() + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + loss_ckpt.backward() + grad_ckpt = input_ckpt.grad.clone() + + assert torch.allclose( + grad_ckpt, grad_ref, atol=1e-6 + ), f"Gradients with dropout mismatch!\nWith manager: {grad_ckpt}\nReference: {grad_ref}" + + def test_multiple_outputs(self): + """CheckpointManager handles functions that return multiple outputs.""" + + def func_multi_output(x): + return x * 2, x + 1 + + def func_combine(a, b): + return a + b + + input_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + input_ckpt = input_ref.detach().clone().requires_grad_(True) + + y1a_ref, y1b_ref = func_multi_output(input_ref) + y2_ref = func_combine(y1a_ref, y1b_ref) + loss_ref = y2_ref.sum() + loss_ref.backward() + grad_ref = input_ref.grad.clone() + + manager = CheckpointManager() + + y1a, y1b = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint( + func_multi_output, input_ckpt + ) + y2 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func_combine, y1a, y1b) + + loss_ckpt = y2.sum() + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + loss_ckpt.backward() + grad_ckpt = input_ckpt.grad.clone() + + assert torch.allclose(grad_ckpt, grad_ref, atol=1e-6), ( + f"Gradients mismatch with multiple outputs!\n" + f"With manager: {grad_ckpt}\nReference: 
{grad_ref}" + ) + + +class TestCheckpointManagerPartialCheckpoint: + """Test CheckpointManager with partial checkpointing (some ops not checkpointed).""" + + def setup_method(self, method): + Utils.initialize_model_parallel() + initialize_rng_tracker(force_reset=True) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def test_partial_checkpoint(self): + """ + Only f and h are checkpointed; g is a regular operation. + + Computation chain: + a --[f]--> b --[g]--> c --[h]--> d --[sum]--> loss + """ + + def func_f(x): + return torch.nn.functional.gelu(x * 2 + 1) + + def func_g(x): + return x * 3 - 2 + + def func_h(x): + return torch.sigmoid(x) + x + + input_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + + b_ref = func_f(input_ref) + c_ref = func_g(b_ref) + d_ref = func_h(c_ref) + loss_ref = d_ref.sum() + loss_ref.backward() + grad_ref = input_ref.grad.clone() + + input_ckpt = input_ref.detach().clone().requires_grad_(True) + + manager = CheckpointManager() + + b = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func_f, input_ckpt) + c = func_g(b) + d = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func_h, c) + + loss_ckpt = d.sum() + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + assert b.untyped_storage().size() == 0, "b storage should be released" + assert d.untyped_storage().size() == 0, "d storage should be released" + assert c.untyped_storage().size() > 0, "c storage should NOT be released (not checkpointed)" + + loss_ckpt.backward() + grad_ckpt = input_ckpt.grad.clone() + + assert torch.allclose(grad_ckpt, grad_ref, atol=1e-6), ( + f"Gradients mismatch with partial checkpoint!\n" + f"With manager: {grad_ckpt}\nReference: {grad_ref}" + ) + + def test_partial_checkpoint_with_tuple_output(self): + """ + Mimics HyperConnection's computation pattern with tuple outputs. 
+ + - compute_mappings: checkpointed, returns tuple (h_pre, h_post, h_res) + - aggregate: NOT checkpointed + - apply_h_res: checkpointed + - apply_h_post: checkpointed + """ + + def compute_mappings(x): + h_pre = torch.sigmoid(x.mean(dim=-1, keepdim=True).expand_as(x)) + h_post = torch.tanh(x.sum(dim=-1, keepdim=True).expand_as(x)) + h_res = torch.relu(x) + return h_pre, h_post, h_res + + def aggregate(x, h_pre): + return x * h_pre + + def apply_h_res(h_res, residual): + return h_res + residual * 0.5 + + def apply_h_post(y, h_post): + return y * h_post + y + + x_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + residual_ref = torch.randn(4, 4, device='cuda', requires_grad=True) + + h_pre_ref, h_post_ref, h_res_ref = compute_mappings(x_ref) + agg_ref = aggregate(x_ref, h_pre_ref) + y_ref = torch.nn.functional.gelu(agg_ref) + mixed_ref = apply_h_res(h_res_ref, residual_ref) + output_ref = apply_h_post(y_ref, h_post_ref) + final_ref = output_ref + mixed_ref + loss_ref = final_ref.sum() + loss_ref.backward() + grad_x_ref = x_ref.grad.clone() + grad_residual_ref = residual_ref.grad.clone() + + x_ckpt = x_ref.detach().clone().requires_grad_(True) + residual_ckpt = residual_ref.detach().clone().requires_grad_(True) + + manager = CheckpointManager() + + h_pre, h_post, h_res = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint( + compute_mappings, x_ckpt + ) + agg = aggregate(x_ckpt, h_pre) + y = torch.nn.functional.gelu(agg) + mixed = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint( + apply_h_res, h_res, residual_ckpt + ) + output = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(apply_h_post, y, h_post) + + final = output + mixed + loss_ckpt = final.sum() + + manager.discard_all_outputs_and_register_unified_recompute(loss_ckpt) + + assert h_pre.untyped_storage().size() == 0, "h_pre storage should be released" + assert h_post.untyped_storage().size() == 0, "h_post storage should be released" + assert h_res.untyped_storage().size() == 0, 
"h_res storage should be released" + assert mixed.untyped_storage().size() == 0, "mixed storage should be released" + assert output.untyped_storage().size() == 0, "output storage should be released" + + assert agg.untyped_storage().size() > 0, "agg storage should NOT be released" + assert y.untyped_storage().size() > 0, "y storage should NOT be released" + + loss_ckpt.backward() + grad_x_ckpt = x_ckpt.grad.clone() + grad_residual_ckpt = residual_ckpt.grad.clone() + + assert torch.allclose( + grad_x_ckpt, grad_x_ref, atol=1e-6 + ), f"Gradients for x mismatch!\nWith manager: {grad_x_ckpt}\nReference: {grad_x_ref}" + assert torch.allclose(grad_residual_ckpt, grad_residual_ref, atol=1e-6), ( + f"Gradients for residual mismatch!\n" + f"With manager: {grad_residual_ckpt}\nReference: {grad_residual_ref}" + ) diff --git a/tests/unit_tests/transformer/test_transformer_layer.py b/tests/unit_tests/transformer/test_transformer_layer.py index da1f9ce5860..995e99d6a24 100644 --- a/tests/unit_tests/transformer/test_transformer_layer.py +++ b/tests/unit_tests/transformer/test_transformer_layer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
import pytest @@ -8,17 +8,41 @@ from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensor from megatron.core.inference.contexts import StaticInferenceContext from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_with_transformer_engine_spec, get_gpt_layer_with_transformer_engine_submodules, ) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.tensor_parallel.random import CheckpointManager, model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import ( + HyperConnectionTransformerLayer, TransformerLayer, get_transformer_layer_offset, ) from tests.unit_tests.test_utilities import Utils +def _make_mhc_config(hidden_size=64, num_streams=4, **extra): + """Build a TransformerConfig with common MHC defaults. + + Any default can be overridden via **extra + (e.g. ``_make_mhc_config(num_layers=8, recompute_modules=["core_attn", "mhc"])``). 
+ """ + base = dict( + num_layers=2, + hidden_size=hidden_size, + num_attention_heads=4, + use_cpu_initialization=True, + enable_hyper_connections=True, + num_residual_streams=num_streams, + mhc_sinkhorn_iterations=5, + mhc_init_gating_factor=0.01, + hidden_dropout=0.0, + attention_dropout=0.0, + ) + base.update(extra) + return TransformerConfig(**base) + + class TestParallelTransformerLayer: def setup_method(self, method): @@ -313,3 +337,761 @@ def get_tensor_shapes_for_tp(transformer_config, tp_size): 'self_attention.linear_qkv.weight': (hs * 3 // tp_size, hs), 'self_attention.linear_qkv.bias': (hs * 3 // tp_size,), } + + +class TestTransformerLayerWithHyperConnectionRecompute: + """Test TransformerLayer with HyperConnection and MHC block recomputation.""" + + def setup_method(self, method): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _create_layer_with_hyper_connection( + self, hidden_size=64, num_streams=4, layer_number=1, **extra + ): + """Create a HyperConnectionTransformerLayer with hyper connection enabled.""" + config = _make_mhc_config( + hidden_size=hidden_size, + num_streams=num_streams, + recompute_modules=["core_attn", "mhc"], + recompute_granularity='selective', + **extra, + ) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connection=True) + layer = HyperConnectionTransformerLayer( + config, layer_spec.submodules, layer_number=layer_number + ) + layer.cuda() + return layer, config + + def test_forward_with_hyper_connection_recompute(self): + """ + Test that TransformerLayer forward works correctly with HyperConnection + and MHC block recomputation enabled. 
+ """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + + layer, config = self._create_layer_with_hyper_connection(hidden_size, num_streams) + layer.train() # Enable training mode for recomputation + + # Input shape: [seq_len, batch_size, n * hidden_size] for hyper connections + n_channels = num_streams * hidden_size + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + # Create manager for MHC block recomputation + manager = CheckpointManager() + + # Forward pass with recompute manager + manager.is_last_layer_in_recompute_block = True + output, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + mhc_recompute_manager=manager, + ) + + # Verify output shape + assert output.shape == ( + seq_len, + batch_size, + n_channels, + ), f"Expected output shape {(seq_len, batch_size, n_channels)}, got {output.shape}" + + # Register unified recompute hook at block boundary. + manager.discard_all_outputs_and_register_unified_recompute(output) + + # Backward pass should work without error + loss = output.sum() + loss.backward() + + # Verify gradients exist + assert hidden_states.grad is not None, "Gradients should be computed for hidden_states" + assert hidden_states.grad.shape == hidden_states.shape + + def test_intermediate_layer_with_recompute(self): + """ + Test TransformerLayer as an intermediate layer (not last in block). + In this case, MLP BDA should also be checkpointed. 
+ """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + + layer, config = self._create_layer_with_hyper_connection(hidden_size, num_streams) + layer.train() + + n_channels = num_streams * hidden_size + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + manager = CheckpointManager() + + # Forward pass - NOT the last layer in block + manager.is_last_layer_in_recompute_block = False + output, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + mhc_recompute_manager=manager, + ) + + # Verify output shape + assert output.shape == (seq_len, batch_size, n_channels) + + # Backward pass should work + loss = output.sum() + # For intermediate layers, we need to pass output to next layer + # Here we just register the recompute hook on output for testing + manager.discard_all_outputs_and_register_unified_recompute(loss) + + loss.backward() + + assert hidden_states.grad is not None + assert hidden_states.grad.shape == hidden_states.shape + + def test_multiple_layers_chain_with_recompute(self): + """ + Test multiple TransformerLayers chained together with a single + CheckpointManager, simulating TransformerBlock behavior. 
+ """ + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + num_layers = 3 + + layers = [ + self._create_layer_with_hyper_connection( + hidden_size, num_streams, layer_number=i + 1, num_layers=num_layers + )[0] + for i in range(num_layers) + ] + + for layer in layers: + layer.train() + + n_channels = num_streams * hidden_size + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + # Single manager for all layers (like TransformerBlock) + manager = CheckpointManager() + + # Forward through all layers + h = hidden_states + for i, layer in enumerate(layers): + is_last = i == num_layers - 1 + manager.is_last_layer_in_recompute_block = is_last + h, _ = layer( + hidden_states=h, attention_mask=attention_mask, mhc_recompute_manager=manager + ) + if is_last: + manager.discard_all_outputs_and_register_unified_recompute(h) + + # Backward pass + loss = h.sum() + loss.backward() + + # Verify gradients + assert hidden_states.grad is not None + assert hidden_states.grad.shape == hidden_states.shape + # Check that gradient is non-trivial (not all zeros) + assert hidden_states.grad.abs().sum() > 0 + + +class TestMHCRecomputeMemorySaving: + """Verify that 'mhc' in recompute_modules actually reduces peak GPU memory.""" + + def setup_method(self, method): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @staticmethod + def _run_forward_backward( + num_layers, + hidden_size, + num_streams, + seq_len, + batch_size, + use_recompute, + recompute_block_size=2, + ): + """Run a full forward + backward pass and return (peak memory, output grad). + + When use_recompute=True, a new CheckpointManager is created every + `recompute_block_size` layers, mirroring TransformerBlock's + _build_mhc_recompute_layer_plan logic. 
+ """ + config = _make_mhc_config( + hidden_size=hidden_size, + num_streams=num_streams, + num_layers=num_layers, + recompute_modules=["core_attn", "mhc"] if use_recompute else None, + recompute_granularity='selective' if use_recompute else None, + ) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connection=True) + layers = [ + HyperConnectionTransformerLayer( + config, layer_spec.submodules, layer_number=i + 1 + ).cuda() + for i in range(num_layers) + ] + for layer in layers: + layer.train() + + n_channels = num_streams * hidden_size + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + manager = CheckpointManager() if use_recompute else None + + h = hidden_states + for i, layer in enumerate(layers): + is_last_in_block = (i == num_layers - 1) or ((i + 1) % recompute_block_size == 0) + kwargs = dict(hidden_states=h, attention_mask=attention_mask) + if manager is not None: + manager.is_last_layer_in_recompute_block = is_last_in_block + kwargs['mhc_recompute_manager'] = manager + h, _ = layer(**kwargs) + if manager is not None and is_last_in_block: + manager.discard_all_outputs_and_register_unified_recompute(h) + if i < num_layers - 1: + manager = CheckpointManager() + + loss = h.sum() + loss.backward() + torch.cuda.synchronize() + + peak_mem = torch.cuda.max_memory_allocated() + grad = hidden_states.grad.clone() + + del layers, hidden_states, h, loss, manager + torch.cuda.empty_cache() + + return peak_mem, grad + + def test_recompute_reduces_peak_memory(self): + """Peak memory with recompute (block_size=2) should be lower than without.""" + num_layers = 8 + hidden_size = 128 + num_streams = 4 + seq_len = 64 + batch_size = 4 + + peak_no_recompute, _ = self._run_forward_backward( + num_layers, hidden_size, num_streams, seq_len, batch_size, 
use_recompute=False + ) + peak_recompute, _ = self._run_forward_backward( + num_layers, + hidden_size, + num_streams, + seq_len, + batch_size, + use_recompute=True, + recompute_block_size=2, + ) + + saving_pct = (peak_no_recompute - peak_recompute) / peak_no_recompute * 100 + + assert peak_recompute < peak_no_recompute, ( + f"Recompute should reduce peak memory, but got " + f"no_recompute={peak_no_recompute / 1e6:.1f}MB vs " + f"recompute={peak_recompute / 1e6:.1f}MB " + f"(saving={saving_pct:.1f}%)" + ) + + +class TestMHCWithCudaGraph: + """Test HyperConnectionTransformerLayer compatibility with CUDA graphs. + + CUDA graph capture requires static computation graphs and fixed tensor shapes. + These tests verify that the mHC layer properly supports the CUDA graph interface + defined in GraphableMegatronModule and TransformerLayer. + """ + + def setup_method(self, method): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123, use_cudagraphable_rng=True, force_reset_rng=True) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _create_mhc_layer(self, hidden_size=64, num_streams=4, **extra_config): + config = _make_mhc_config(hidden_size=hidden_size, num_streams=num_streams, **extra_config) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connection=True) + layer = HyperConnectionTransformerLayer(config, layer_spec.submodules) + layer.cuda() + return layer, config + + def test_get_layer_static_inputs_shape_for_mhc(self): + """get_layer_static_inputs must return [s, b, n*C] for mHC layers. + + CUDA graph capture creates static buffers whose shapes are determined by + this method. If the shape is [s, b, C] instead of [s, b, n*C], the graph + capture will produce a shape mismatch at the first hyper connection module. 
+ """ + layer, config = self._create_mhc_layer() + seq_length = 32 + micro_batch_size = 2 + + static_inputs = layer.get_layer_static_inputs(seq_length, micro_batch_size) + hidden_states = static_inputs["hidden_states"] + + expected_hidden_dim = config.num_residual_streams * config.hidden_size + assert hidden_states.shape[-1] == expected_hidden_dim, ( + f"get_layer_static_inputs returns hidden dim {hidden_states.shape[-1]} " + f"but mHC expects {expected_hidden_dim} (n={config.num_residual_streams} * " + f"C={config.hidden_size}). " + f"HyperConnectionTransformerLayer must override get_layer_static_inputs." + ) + + def test_submodules_under_cudagraphs_includes_hyper_connection(self): + """_get_submodules_under_cudagraphs must include hyper connection modules. + + CUDA graph manual hooks are set up for parameters of submodules returned + by this method. Missing hyper connection modules means their parameters + (mapping_proj, alpha_*, bias) will not get proper pre-forward hooks during + graph replay, leading to stale parameter values. + """ + layer, config = self._create_mhc_layer() + + submodules = layer._get_submodules_under_cudagraphs() + + hc_modules_found = any( + hasattr(m, 'mapping_proj') for submod in submodules for m in submod.modules() + ) + assert hc_modules_found, ( + "_get_submodules_under_cudagraphs does not include HyperConnectionModule. " + "Parameters like mapping_proj, alpha_pre/post/res will not be updated " + "during CUDA graph replay." + ) + + def test_forward_through_te_cuda_graph_capture_path(self): + """_te_cuda_graph_capture must produce correct output shapes for mHC. + + TE CUDA graph capture calls _te_cuda_graph_capture() during warmup. + For mHC layers, the input must be n-stream [s, b, n*C] and output must + also be [s, b, n*C]. 
+ """ + layer, config = self._create_mhc_layer() + layer.eval() + + seq_len = 8 + batch_size = 2 + n_channels = config.num_residual_streams * config.hidden_size + + hidden_states = torch.randn(seq_len, batch_size, n_channels, device='cuda') + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + with torch.no_grad(): + outputs = layer._te_cuda_graph_capture( + hidden_states=hidden_states, attention_mask=attention_mask + ) + + if isinstance(outputs, tuple): + output = outputs[0] + else: + output = outputs + + assert output.shape == (seq_len, batch_size, n_channels), ( + f"_te_cuda_graph_capture output shape {output.shape} != " + f"expected {(seq_len, batch_size, n_channels)}" + ) + + def test_cuda_graph_fwd_bwd_with_hyper_connection(self): + """End-to-end CUDA graph capture and replay for forward+backward with mHC. + + Captures both the forward and backward pass of HyperConnectionTransformerLayer + into a torch.cuda.CUDAGraph and replays it with fresh input data, verifying + that the computation graph is fully static (capturable) and produces correct + output shapes and non-trivial gradients. + """ + layer, config = self._create_mhc_layer() + layer.train() + + seq_len = 8 + batch_size = 2 + n_channels = config.num_residual_streams * config.hidden_size + + static_input = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + # Warmup on side stream to trigger lazy allocations + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + out, _ = layer(hidden_states=static_input, attention_mask=attention_mask) + out.sum().backward() + torch.cuda.current_stream().wait_stream(s) + + # Set .grad to None so backward allocates fresh gradient tensors in the + # graph's private memory pool during capture. 
+ layer.zero_grad(set_to_none=True) + static_input.grad = None + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output, _ = layer(hidden_states=static_input, attention_mask=attention_mask) + output.sum().backward() + + # Replay with new input data. + # Use no_grad because backward inside the captured graph already + # bumped the autograd version counter on static_input, making + # in-place copy_ illegal without disabling grad tracking. + with torch.no_grad(): + static_input.copy_(torch.randn_like(static_input)) + g.replay() + + assert output.shape == ( + seq_len, + batch_size, + n_channels, + ), f"Output shape {output.shape} != expected {(seq_len, batch_size, n_channels)}" + assert ( + static_input.grad is not None + ), "Gradients should be computed for static_input after graph replay" + assert static_input.grad.shape == static_input.shape + assert static_input.grad.abs().sum() > 0, "Gradients should be non-trivial" + + # Verify numerical consistency: graph replay should match eager execution + # with the same input and weights. + test_data = torch.randn(seq_len, batch_size, n_channels, device='cuda') + + with torch.no_grad(): + static_input.copy_(test_data) + g.replay() + graph_out = output.detach().clone() + graph_grad = static_input.grad.detach().clone() + + eager_input = test_data.clone().requires_grad_(True) + eager_output, _ = layer(hidden_states=eager_input, attention_mask=attention_mask) + eager_output.sum().backward() + + assert torch.allclose(graph_out, eager_output.detach(), atol=1e-5), ( + f"Graph vs eager output mismatch: " + f"max diff = {(graph_out - eager_output.detach()).abs().max().item()}" + ) + assert torch.allclose(graph_grad, eager_input.grad, atol=1e-5), ( + f"Graph vs eager gradient mismatch: " + f"max diff = {(graph_grad - eager_input.grad).abs().max().item()}" + ) + + def test_cuda_graph_fwd_bwd_with_hyper_connection_and_recompute(self): + """CUDA graph capture+replay for fwd+bwd with mHC and CheckpointManager. 
+ + When a CheckpointManager is used, additional CheckpointWithoutOutput + objects are created for layernorm and hyper-connection operations. The + manager discards intermediate activations during forward (storage.resize_(0)) + and recomputes them during backward via a unified gradient hook. + This test verifies the full capture+replay still works correctly. + """ + layer, config = self._create_mhc_layer() + layer.train() + + seq_len = 8 + batch_size = 2 + n_channels = config.num_residual_streams * config.hidden_size + + static_input = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + # Warmup on side stream; fresh manager per iteration to avoid stale state. + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + mgr = CheckpointManager() + mgr.is_last_layer_in_recompute_block = True + out, _ = layer( + hidden_states=static_input, + attention_mask=attention_mask, + mhc_recompute_manager=mgr, + ) + mgr.discard_all_outputs_and_register_unified_recompute(out) + out.sum().backward() + torch.cuda.current_stream().wait_stream(s) + + layer.zero_grad(set_to_none=True) + static_input.grad = None + + capture_mgr = CheckpointManager() + capture_mgr.is_last_layer_in_recompute_block = True + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output, _ = layer( + hidden_states=static_input, + attention_mask=attention_mask, + mhc_recompute_manager=capture_mgr, + ) + capture_mgr.discard_all_outputs_and_register_unified_recompute(output) + output.sum().backward() + + # Replay with new input data. 
+ with torch.no_grad(): + static_input.copy_(torch.randn_like(static_input)) + g.replay() + + assert output.shape == ( + seq_len, + batch_size, + n_channels, + ), f"Output shape {output.shape} != expected {(seq_len, batch_size, n_channels)}" + assert ( + static_input.grad is not None + ), "Gradients should be computed for static_input after graph replay" + assert static_input.grad.shape == static_input.shape + assert static_input.grad.abs().sum() > 0, "Gradients should be non-trivial" + + # Numerical consistency: graph replay vs eager with the same input. + test_data = torch.randn(seq_len, batch_size, n_channels, device='cuda') + + with torch.no_grad(): + static_input.copy_(test_data) + g.replay() + graph_out = output.detach().clone() + graph_grad = static_input.grad.detach().clone() + + eager_mgr = CheckpointManager() + eager_mgr.is_last_layer_in_recompute_block = True + eager_input = test_data.clone().requires_grad_(True) + eager_output, _ = layer( + hidden_states=eager_input, + attention_mask=attention_mask, + mhc_recompute_manager=eager_mgr, + ) + eager_mgr.discard_all_outputs_and_register_unified_recompute(eager_output) + eager_output.sum().backward() + + assert torch.allclose(graph_out, eager_output.detach(), atol=1e-5), ( + f"Graph vs eager output mismatch: " + f"max diff = {(graph_out - eager_output.detach()).abs().max().item()}" + ) + assert torch.allclose(graph_grad, eager_input.grad, atol=1e-5), ( + f"Graph vs eager gradient mismatch: " + f"max diff = {(graph_grad - eager_input.grad).abs().max().item()}" + ) + + def test_mcore_cudagraph_manager_with_mhc_recompute_manager(self): + """MCore CudaGraphManager must not crash on mhc_recompute_manager kwarg. + + When cuda_graph_impl="local" is set, TransformerLayer.__call__ routes + through MegatronModule.__call__ → CudaGraphManager.__call__, which + iterates over all kwargs to check supported types. CheckpointManager + (used by mhc_recompute_manager) is not a CUDA-graph-supported type. 
+ + This test verifies that mhc_recompute_manager is properly extracted + from kwargs before the CudaGraphManager sees them, preventing the + AssertionError that would otherwise occur. + """ + layer, config = self._create_mhc_layer(cuda_graph_impl="local", cuda_graph_scope="attn") + layer.train() + + assert hasattr( + layer, 'cudagraph_manager' + ), "Layer should have cudagraph_manager with cuda_graph_impl='local'" + + seq_len = 8 + batch_size = 2 + n_channels = config.num_residual_streams * config.hidden_size + + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + mgr = CheckpointManager() + mgr.is_last_layer_in_recompute_block = True + + output, context = layer( + hidden_states=hidden_states, attention_mask=attention_mask, mhc_recompute_manager=mgr + ) + + assert output.shape == (seq_len, batch_size, n_channels) + + def test_mcore_cudagraph_manager_without_mhc_recompute_manager(self): + """MCore CudaGraphManager path works when mhc_recompute_manager is None.""" + layer, config = self._create_mhc_layer(cuda_graph_impl="local", cuda_graph_scope="attn") + layer.train() + + seq_len = 8 + batch_size = 2 + n_channels = config.num_residual_streams * config.hidden_size + + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + output, context = layer(hidden_states=hidden_states, attention_mask=attention_mask) + + assert output.shape == (seq_len, batch_size, n_channels) + + +class TestMHCWithOffloading: + """Test HyperConnectionTransformerLayer with fine-grained activation offloading. + + Fine-grained activation offloading transfers specific activations (e.g., layernorm + inputs) to CPU during forward and reloads them during backward. 
These tests verify + that the mHC layer's multi-stream architecture works correctly with offloading. + """ + + def setup_method(self, method): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _create_mhc_layer_with_offloading( + self, hidden_size=64, num_streams=4, offload_modules=None + ): + if offload_modules is None: + offload_modules = ["attn_norm", "mlp_norm"] + + config = _make_mhc_config( + hidden_size=hidden_size, + num_streams=num_streams, + fine_grained_activation_offloading=True, + offload_modules=offload_modules, + ) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connection=True) + layer = HyperConnectionTransformerLayer(config, layer_spec.submodules) + layer.cuda() + return layer, config + + def test_forward_backward_with_offloading(self): + """Forward+backward should work with activation offloading enabled. + + This exercises the off_interface context manager around layernorms in + the mHC forward path, including the group_commit that commits the + offloading group for the aggregated 1-stream layernorm input. 
+ """ + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + PipelineOffloadManager, + ) + + layer, config = self._create_mhc_layer_with_offloading() + layer.train() + + seq_len = 8 + batch_size = 2 + n_channels = config.num_residual_streams * config.hidden_size + + hidden_states = torch.randn( + seq_len, batch_size, n_channels, device='cuda', requires_grad=True + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + mgr = PipelineOffloadManager.get_instance() + mgr.init_model_chunk_offload_handler(vp_size=1, vp_stage=0, min_offloaded_tensor_size=0) + + output, context = layer(hidden_states=hidden_states, attention_mask=attention_mask) + + assert output.shape == ( + seq_len, + batch_size, + n_channels, + ), f"Output shape {output.shape} != expected {(seq_len, batch_size, n_channels)}" + + loss = output.sum() + loss.backward() + + assert hidden_states.grad is not None, "Gradients should flow through offloaded path" + assert hidden_states.grad.shape == hidden_states.shape + assert hidden_states.grad.abs().sum() > 0, "Gradients should be non-trivial" + + PipelineOffloadManager.reset_instance() + + def test_offloading_numerical_equivalence(self): + """Offloaded forward+backward must produce the same result as non-offloaded. + + Compares outputs and gradients between a layer with offloading disabled + vs enabled to ensure the offloading path does not corrupt activations. 
+ """ + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + PipelineOffloadManager, + ) + + PipelineOffloadManager.reset_instance() + + hidden_size = 64 + num_streams = 4 + seq_len = 8 + batch_size = 2 + n_channels = num_streams * hidden_size + + torch.manual_seed(42) + input_data = torch.randn(seq_len, batch_size, n_channels, device='cuda') + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + # Run without offloading + config_no_offload = _make_mhc_config(hidden_size=hidden_size, num_streams=num_streams) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connection=True) + layer_no_offload = HyperConnectionTransformerLayer( + config_no_offload, layer_spec.submodules + ).cuda() + layer_no_offload.train() + + h1 = input_data.clone().detach().requires_grad_(True) + out1, _ = layer_no_offload(hidden_states=h1, attention_mask=attention_mask) + out1.sum().backward() + grad_no_offload = h1.grad.clone() + out1_detached = out1.detach().clone() + + # Run with offloading using the same weights + config_offload = _make_mhc_config( + hidden_size=hidden_size, + num_streams=num_streams, + fine_grained_activation_offloading=True, + offload_modules=["attn_norm", "mlp_norm"], + ) + layer_offload = HyperConnectionTransformerLayer( + config_offload, layer_spec.submodules + ).cuda() + layer_offload.load_state_dict(layer_no_offload.state_dict()) + layer_offload.train() + + mgr = PipelineOffloadManager.get_instance() + mgr.init_model_chunk_offload_handler(vp_size=1, vp_stage=0, min_offloaded_tensor_size=0) + + h2 = input_data.clone().detach().requires_grad_(True) + out2, _ = layer_offload(hidden_states=h2, attention_mask=attention_mask) + out2.sum().backward() + grad_offload = h2.grad.clone() + + PipelineOffloadManager.reset_instance() + + assert torch.allclose(out1_detached, out2.detach(), atol=1e-5), ( + f"Forward outputs differ: max diff = " + f"{(out1_detached - out2.detach()).abs().max().item()}" + ) 
+ assert torch.allclose(grad_no_offload, grad_offload, atol=1e-5), ( + f"Gradients differ: max diff = " + f"{(grad_no_offload - grad_offload).abs().max().item()}" + ) From 00c333927b864236f36e6e9284d5d5e7c6eb6c35 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Mon, 27 Apr 2026 09:25:09 -0700 Subject: [PATCH 03/10] Remove mHC transformer config trailing whitespace --- .../core/transformer/transformer_config.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 7740b09012b..a127792f684 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -60,8 +60,8 @@ class TransformerConfig(ModelParallelConfig): mtp_loss_scaling_factor: Optional[float] = 0.1 """Weighting factor of Multi-Token Prediction (MTP) loss. - We compute the average of the MTP losses across all depths, - and multiply it the scaling factor to obtain the overall MTP loss, + We compute the average of the MTP losses across all depths, + and multiply it the scaling factor to obtain the overall MTP loss, which serves as an additional training objective. """ @@ -90,8 +90,8 @@ class TransformerConfig(ModelParallelConfig): - list: e.g., [['embedding', 'decoder'], ['decoder', 'decoder', 'decoder', 'loss']]. - PipelineParallelLayerLayout: a PipelineParallelLayerLayout object. If given either a string or a list, it will be transferred into a PipelineParallelLayerLayout - in post init. Let i = a * pp_size + b, then layout[i] gives a list of the layers - in the a-th vpp stage and the b-th pp stage, i.e., vpp(0)pp(0), vpp(0)pp(1), ..., + in post init. Let i = a * pp_size + b, then layout[i] gives a list of the layers + in the a-th vpp stage and the b-th pp stage, i.e., vpp(0)pp(0), vpp(0)pp(1), ..., vpp(i)pp(j), vpp(i)pp(j+1), ..., vpp(-1)pp(-2), vpp(-1)pp(-1). 
In the inner lists of layers, 'embedding' or 'E' denotes the embedding layer, 'loss' or 'L' denotes the loss function, and 'decoder' or 't' denotes the transformer decoder layer. @@ -132,8 +132,8 @@ class TransformerConfig(ModelParallelConfig): """Softmax scale for attention scaling.""" softmax_type: Literal['vanilla', 'off-by-one', 'learnable'] = 'vanilla' - """Applies modified softmax from https://www.evanmiller.org/attention-is-off-by-one.html. - Supports both TE FusedAttention and local unfused attention. Supports both a fixed offset and + """Applies modified softmax from https://www.evanmiller.org/attention-is-off-by-one.html. + Supports both TE FusedAttention and local unfused attention. Supports both a fixed offset and and learnable offset.""" num_query_groups: Optional[int] = field( @@ -193,7 +193,7 @@ class TransformerConfig(ModelParallelConfig): The stored input is casted back to the original precision before backprop compuatation.""" glu_linear_offset: float = 0.0 - """Offset term in the GLU activation function: activation_func(x[0]) * (x[1] + offset). Only + """Offset term in the GLU activation function: activation_func(x[0]) * (x[1] + offset). Only used when gated_linear_unit is True""" activation_func_clamp_value: Optional[float] = None @@ -288,7 +288,7 @@ class TransformerConfig(ModelParallelConfig): # linear attention #################### linear_attention_freq: Optional[Union[int, List[int]]] = None - """Frequency between LA (linear attention) layers + """Frequency between LA (linear attention) layers and SDPA (scaled dot-product attention) layers. Accepts either: - An integer N: Represents a (N-1):N ratio, meaning (N-1) LA layers for every 1 SDPA layer @@ -330,13 +330,13 @@ class TransformerConfig(ModelParallelConfig): embedding_init_method: Optional[Callable] = None """ - Method to initialize weights of the embedding layer. If None, will be set as described + Method to initialize weights of the embedding layer. 
If None, will be set as described in init_method above. """ embedding_init_method_std: Optional[float] = None """ - Standard deviation of the zero mean normal for the default initialization method for the + Standard deviation of the zero mean normal for the default initialization method for the embedding layer. If None, will be set to init_method_std. Setting this to a value around 1.0 may avoid loss spikes in training. Setting this to any value will also skip applying weight decay on embedding weights to avoid shrinkage towards zero. @@ -586,7 +586,7 @@ class TransformerConfig(ModelParallelConfig): fp4: Optional[Literal['e2m1']] = field( default=None, metadata={"argparse_meta": {"arg_names": ["--fp4-format"]}} ) - """If set, enables the use of FP4 precision through Transformer Engine. Currently only + """If set, enables the use of FP4 precision through Transformer Engine. Currently only supports 'nvfp4' which uses NVFP4BlockScaling recipe (requires TE >= 2.7.0.dev0).""" fp4_recipe: Optional[Literal['nvfp4', 'custom']] = "nvfp4" @@ -620,12 +620,12 @@ class TransformerConfig(ModelParallelConfig): in the hidden_states gradient.""" moe_shared_expert_gate: bool = False - """Enable gate for shared expert. Only effective when + """Enable gate for shared expert. Only effective when moe-shared-expert-intermediate-size is set.""" moe_shared_expert_overlap: bool = False """Enable overlapping between shared expert computations and dispatcher communications. - Without this, the shared experts execute before the router. + Without this, the shared experts execute before the router. Only effective when moe-shared-expert-intermediate-size is set. 
""" @@ -719,7 +719,7 @@ class TransformerConfig(ModelParallelConfig): The default value 1e-3 is same as that used in DeepSeekV3.""" moe_router_force_load_balancing: bool = False - """[Experimental] Force load balancing with random logits for MoE router, supports naive topk + """[Experimental] Force load balancing with random logits for MoE router, supports naive topk and group-limited topk. This is an experimental feature and only for benchmark.""" moe_router_force_biased: Optional[float] = None @@ -760,7 +760,7 @@ class TransformerConfig(ModelParallelConfig): moe_flex_dispatcher_backend: Literal['deepep', 'hybridep'] = "deepep" """[Experimental] The backend to use for flex token dispatcher. The default is "deepep". - Options are "deepep" and "hybridep". Currently only "hybridep" backend supports + Options are "deepep" and "hybridep". Currently only "hybridep" backend supports the MNNVL case.""" moe_per_layer_logging: bool = False @@ -902,14 +902,14 @@ class TransformerConfig(ModelParallelConfig): mhc_recompute_layer_num: Optional[int] = None """Number of layers per MHC recompute block. - + When set, every `mhc_recompute_layer_num` layers form a recompute block. The last layer in each recompute block (i.e., layer_number % mhc_recompute_layer_num == 0 or the final layer in the transformer block) will: - NOT checkpoint its final MLP BDA - Register the unified recompute hook on its MLP BDA output - A new CheckpointManager is created for subsequent layers - + If None, all layers in the transformer block share a single recompute block. Must be a positive integer when set.""" @@ -933,7 +933,7 @@ class TransformerConfig(ModelParallelConfig): batch_invariant_mode: bool = False """If true, uses batch-invariant kernels that provide deterministic forward execution regardless of batch size. This ensures bitwise identical results when the same inputs are processed - in different batch configurations. This will significantly affect speed of + in different batch configurations. 
This will significantly affect speed of training and inference as the kernels are not full optimized. Defaults to False.""" @@ -2438,7 +2438,7 @@ class MLATransformerConfig(TransformerConfig): cache_mla_latents: bool = False """Cache the low dimensional tensors for MLA rather than full KV cache. - This is only for the dynamic inference backend and requires that + This is only for the dynamic inference backend and requires that Flash MLA is installed.""" mla_down_proj_fusion: bool = False From 3793a59ab99f66d65e760bb8e815509c33b99a2a Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Mon, 27 Apr 2026 13:49:59 -0700 Subject: [PATCH 04/10] Address mHC review feedback --- gpt_builders.py | 4 +- megatron/core/models/gpt/gpt_layer_specs.py | 24 +++---- megatron/core/tensor_parallel/random.py | 10 ++- megatron/core/transformer/hyper_connection.py | 68 ++++++------------- .../core/transformer/transformer_block.py | 12 ++-- .../core/transformer/transformer_config.py | 13 +++- .../core/transformer/transformer_layer.py | 17 ++--- .../unit_tests/models/test_gpt_layer_specs.py | 28 +++++--- .../test_hyper_connection_recompute.py | 44 ++++++++++++ .../transformer/test_transformer_layer.py | 10 +-- 10 files changed, 136 insertions(+), 94 deletions(-) diff --git a/gpt_builders.py b/gpt_builders.py index 59a8942e472..72e3bb8c550 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -136,7 +136,7 @@ def _get_transformer_layer_spec(use_te, config): use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False), - enable_hyper_connection=config.enable_hyper_connections, + enable_hyper_connections=config.enable_hyper_connections, ) elif config.transformer_impl == "inference_optimized": return get_gpt_layer_with_inference_spec( @@ -155,5 +155,5 @@ def _get_transformer_layer_spec(use_te, config): use_kitchen=config.use_kitchen, use_kitchen_attention=config.use_kitchen_attention, 
kitchen_attention_backend=config.kitchen_attention_backend, - enable_hyper_connection=config.enable_hyper_connections, + enable_hyper_connections=config.enable_hyper_connections, ) diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index a097e966f68..9ca0e9494c7 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -186,7 +186,7 @@ def get_gpt_layer_with_transformer_engine_submodules( use_kitchen_attention: bool = False, kitchen_attention_backend: str = "sdpa", mla_down_proj_fusion: bool = False, - enable_hyper_connection: bool = False, + enable_hyper_connections: bool = False, ) -> TransformerLayerSubmodules: """Use these submodules to use lower-level Transformer Engine modules (required for fp8 training). @@ -204,7 +204,7 @@ def get_gpt_layer_with_transformer_engine_submodules( mla_down_proj_fusion (bool, optional): Enable fused q/kv down-projection and fused input layernorm when backend supports. Otherwise fall back to the unfused MLA. - enable_hyper_connection (bool): Use HyperConnectionTransformerLayer with + enable_hyper_connections (bool): Use HyperConnectionTransformerLayer with HyperConnectionModule instead of plain TransformerLayer. Defaults to False. Returns: @@ -239,7 +239,7 @@ def get_gpt_layer_with_transformer_engine_submodules( use_te_activation_func=use_te_activation_func, ) - hc_module = HyperConnectionModule if enable_hyper_connection else IdentityOp + hc_module = HyperConnectionModule if enable_hyper_connections else IdentityOp if multi_latent_attention: assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." 
@@ -354,7 +354,7 @@ def get_gpt_layer_with_transformer_engine_submodules( @copy_signature(get_gpt_layer_with_transformer_engine_submodules) def get_gpt_layer_with_transformer_engine_spec(*args, **kwargs) -> ModuleSpec: """Use this spec to use lower-level Transformer Engine modules (required for fp8 training).""" - enable_hc = kwargs.get('enable_hyper_connection', False) + enable_hc = kwargs.get('enable_hyper_connections', False) layer_module = HyperConnectionTransformerLayer if enable_hc else TransformerLayer return ModuleSpec( module=layer_module, @@ -373,7 +373,7 @@ def get_gpt_layer_local_submodules( use_kitchen: bool = False, use_kitchen_attention: bool = False, kitchen_attention_backend: str = "sdpa", - enable_hyper_connection: bool = False, + enable_hyper_connections: bool = False, ) -> TransformerLayerSubmodules: """Use these submodules for an implementation using only modules in Megatron-Core. @@ -385,7 +385,7 @@ def get_gpt_layer_local_submodules( multi_latent_attention (bool, optional): To use MLA. Defaults to False. fp8 (str, optional): Deprecated. For temporary Nemo compatibility. qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False. - enable_hyper_connection (bool): Use HyperConnectionTransformerLayer with + enable_hyper_connections (bool): Use HyperConnectionTransformerLayer with HyperConnectionModule instead of plain TransformerLayer. Defaults to False. Returns: @@ -419,7 +419,7 @@ def get_gpt_layer_local_submodules( backend=backend, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm ) - hc_module = HyperConnectionModule if enable_hyper_connection else IdentityOp + hc_module = HyperConnectionModule if enable_hyper_connections else IdentityOp if multi_latent_attention: assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." 
@@ -481,7 +481,7 @@ def get_gpt_layer_local_submodules( @copy_signature(get_gpt_layer_local_submodules) def get_gpt_layer_local_spec(*args, **kwargs) -> ModuleSpec: """Use this spec for an implementation using only modules in Megatron-Core.""" - enable_hc = kwargs.get('enable_hyper_connection', False) + enable_hc = kwargs.get('enable_hyper_connections', False) layer_module = HyperConnectionTransformerLayer if enable_hc else TransformerLayer return ModuleSpec( module=layer_module, submodules=get_gpt_layer_local_submodules(*args, **kwargs) @@ -593,7 +593,7 @@ def get_gpt_decoder_layer_specs( use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False), - enable_hyper_connection=config.enable_hyper_connections, + enable_hyper_connections=config.enable_hyper_connections, ) moe_layer_spec = get_gpt_layer_with_transformer_engine_spec( num_experts=config.num_moe_experts, @@ -606,7 +606,7 @@ def get_gpt_decoder_layer_specs( use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False), - enable_hyper_connection=config.enable_hyper_connections, + enable_hyper_connections=config.enable_hyper_connections, ) elif config.transformer_impl == "inference_optimized": layer_norm_impl = TENorm @@ -635,7 +635,7 @@ def get_gpt_decoder_layer_specs( use_kitchen=config.use_kitchen, use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, - enable_hyper_connection=config.enable_hyper_connections, + enable_hyper_connections=config.enable_hyper_connections, ) moe_layer_spec = get_gpt_layer_local_spec( num_experts=config.num_moe_experts, @@ -647,7 +647,7 @@ def get_gpt_decoder_layer_specs( use_kitchen=config.use_kitchen, use_kitchen_attention=config.use_kitchen_attention, 
kitchen_attention_backend=config.kitchen_attention_backend, - enable_hyper_connection=config.enable_hyper_connections, + enable_hyper_connections=config.enable_hyper_connections, ) # Parse config.moe_layer_freq to determine the pattern of expert/dense layers. diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index 4516fe10d88..a43b1636799 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -747,6 +747,8 @@ def backward(ctx, *args): torch.autograd.backward(outputs, args) ctx.outputs = None ctx.inputs = None + # Autograd expects None for non-tensor inputs; returning the original + # non-tensor value as a "gradient" is invalid once args are restored. grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs) return (None, None) + grads @@ -791,9 +793,11 @@ def discard_all_outputs_and_register_unified_recompute(self, hook_tensor): hook_tensor.register_hook(self._unified_recompute_hook) def _unified_recompute_hook(self, grad_output): + # Chained checkpoints rely on forward-order recomputation. Each + # _recompute() restores output storage in-place on the original tensor + # objects, so a later checkpoint sees its saved input storage restored + # before it recomputes. for ckpt in self.checkpoints: - # Call _recompute for each checkpoint in forward order - # The _recompute method will restore the output tensor storage ckpt._recompute(None) @@ -822,6 +826,8 @@ def __init__(self, fp8=False, ckpt_manager=None): discard_output_and_register_recompute() will only discard output without registering individual hooks. """ + # Treat the default fp8=False as disabled. The old "fp8 is not None" + # behavior entered the TE FP8 recompute path for non-FP8 callers. 
self.fp8 = bool(fp8) self.ckpt_manager = ckpt_manager self.run_function = None diff --git a/megatron/core/transformer/hyper_connection.py b/megatron/core/transformer/hyper_connection.py index 64ec3107213..e15ea881ffd 100644 --- a/megatron/core/transformer/hyper_connection.py +++ b/megatron/core/transformer/hyper_connection.py @@ -17,12 +17,14 @@ @torch.compile def _sinkhorn_iterations(input_logits: Tensor, num_iterations: int, eps: float) -> Tensor: + input_dtype = input_logits.dtype + input_logits = input_logits.float() row_max = input_logits.max(dim=-1, keepdim=True).values M = torch.exp(input_logits - row_max) for _ in range(num_iterations): M = M / M.sum(dim=-1, keepdim=True).clamp(min=eps) M = M / M.sum(dim=-2, keepdim=True).clamp(min=eps) - return M + return M.to(dtype=input_dtype) class SinkhornKnopp(torch.autograd.Function): @@ -53,22 +55,22 @@ def backward(ctx, grad_output: Tensor): return logits.grad, None, None -def native_sinkhorn(input_logits: Tensor, num_iterations: int, eps: float = 1e-6) -> Tensor: - """Native Sinkhorn-Knopp (autograd.Function wrapper).""" +def reference_sinkhorn(input_logits: Tensor, num_iterations: int, eps: float = 1e-6) -> Tensor: + """Reference Sinkhorn-Knopp (autograd.Function wrapper).""" return SinkhornKnopp.apply(input_logits, num_iterations, eps) @torch.compile -def native_h_aggregate(x: Tensor, h_pre: Tensor) -> Tensor: - """Native n-stream weighted aggregation: out = sum_j(h_pre_j * x_j).""" +def reference_h_aggregate(x: Tensor, h_pre: Tensor) -> Tensor: + """Reference n-stream weighted aggregation: out = sum_j(h_pre_j * x_j).""" return (x * h_pre.unsqueeze(-1)).sum(dim=2) @torch.compile -def native_h_post_bda( +def reference_h_post_bda( h_res: Tensor, original_residual: Tensor, h_post: Tensor, x: Tensor, bias: Optional[Tensor] ) -> Tensor: - """Native H_res @ residual + H_post * (x [+ bias]).""" + """Reference H_res @ residual + H_post * (x [+ bias]).""" s, b, n, C = original_residual.shape h_res_batched = h_res.view(s 
* b, n, n) residual_batched = original_residual.view(s * b, n, C) @@ -81,8 +83,8 @@ def native_h_post_bda( @torch.compile -def native_proj_rms(x: Tensor, weight: Tensor, eps: float = 1e-6) -> Tuple[Tensor, Tensor]: - """Native fused projection + RMS normalization.""" +def reference_proj_rms(x: Tensor, weight: Tensor, eps: float = 1e-6) -> Tuple[Tensor, Tensor]: + """Reference fused projection + RMS normalization.""" proj = torch.matmul(x, weight.t()) norm = x.norm(dim=-1, keepdim=True) K = x.shape[-1] @@ -96,7 +98,6 @@ def native_proj_rms(x: Tensor, weight: Tensor, eps: float = 1e-6) -> Tuple[Tenso # ============================================================================ -# TODO: keep hyper connection in fp32 computation class HyperConnectionModule(MegatronModule): """ Unified mHC (Manifold-Constrained Hyper-Connections) module. @@ -124,7 +125,9 @@ def __init__(self, config: TransformerConfig, layer_number: int): self.hidden_size = config.hidden_size self.sinkhorn_iterations = config.mhc_sinkhorn_iterations - # Projection weights for dynamic mappings + # Projection weights for dynamic mappings. The reference implementation + # keeps this as a full, non-TP-partitioned linear projection over n*C; + # fused/partitioned variants are expected to replace it in follow-up work. 
# Input: [s, b, n*C] -> Output: n^2 + 2n values per token # - H_pre: n values # - H_post: n values @@ -159,10 +162,10 @@ def __init__(self, config: TransformerConfig, layer_number: int): self._h_post_bda_op = fused_h_post_bda self._proj_rms_op = fused_proj_rms else: - self._sinkhorn_op = native_sinkhorn - self._h_aggregate_op = native_h_aggregate - self._h_post_bda_op = native_h_post_bda - self._proj_rms_op = native_proj_rms + self._sinkhorn_op = reference_sinkhorn + self._h_aggregate_op = reference_h_aggregate + self._h_post_bda_op = reference_h_post_bda + self._proj_rms_op = reference_proj_rms self._init_weights() @@ -665,7 +668,7 @@ def _fused_wrapper(h_res, original_residual, h_post, x, *optional_bias): bda_func = get_bias_dropout_add(training, fused) has_bias = bias is not None - def _native_wrapper(h_res, original_residual, h_post, x, *optional_bias): + def _reference_wrapper(h_res, original_residual, h_post, x, *optional_bias): with torch.cuda.nvtx.range("HyperConnection::apply_h_res"): mixed = self.apply_h_res(h_res, original_residual) with torch.cuda.nvtx.range("HyperConnection::apply_h_post"): @@ -680,37 +683,8 @@ def _native_wrapper(h_res, original_residual, h_post, x, *optional_bias): ckpt = CheckpointWithoutOutput(ckpt_manager=manager) if has_bias: - output = ckpt.checkpoint(_native_wrapper, h_res, original_residual, h_post, x, bias) + output = ckpt.checkpoint(_reference_wrapper, h_res, original_residual, h_post, x, bias) else: - output = ckpt.checkpoint(_native_wrapper, h_res, original_residual, h_post, x) + output = ckpt.checkpoint(_reference_wrapper, h_res, original_residual, h_post, x) return output - - -# ==================== Checkpoint utilities for mHC ==================== - - -class HyperConnectionCheckpoint: - """ - Checkpoint utility for mHC intermediate activations. - - Implements the paper's "recomputing strategy" to reduce memory footprint - by discarding intermediate n-stream activations and recomputing on-the-fly. 
- """ - - @staticmethod - def compute_optimal_block_size(num_layers: int, num_streams: int) -> int: - """ - Compute optimal recomputation block size. - - From paper Eq. (20): L_r^* ≈ sqrt(nL/(n+2)) - - Args: - num_layers: Total number of transformer layers - num_streams: Number of residual streams (n) - - Returns: - block_size: Optimal block size for checkpointing - """ - block_size = int(math.sqrt(num_streams * num_layers / (num_streams + 2))) - return max(1, block_size) diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 0048d18c3db..67d3da0e076 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -795,9 +795,10 @@ def forward( # is called here to be future-proof and corner-case-proof. hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) - # Expand hidden states for hyper connections at the start of the block - # Only expand at the first PP stage; subsequent stages receive n-stream from previous stage - if self.config.enable_hyper_connections and self.pre_process: + # Expand hidden states for hyper connections at the start of the block. + # Pipeline-parallel HC support is blocked in TransformerConfig until P2P + # tensor shapes carry the n-stream hidden dimension. + if self.config.enable_hyper_connections: hidden_states = HyperConnectionModule.input_expand( hidden_states, self.num_residual_streams ) # [s, b, C] -> [s, b, n*C] @@ -921,8 +922,9 @@ def forward( if (l_no + layer_offset) in extract_layer_indices: intermediate_hidden_states.append(hidden_states) - # Only contract if the final layer norm is in this stage - if self.config.enable_hyper_connections and self.has_final_layernorm_in_this_stage(): + # Contract n-stream hidden states at the block boundary even when the + # final layernorm is disabled; downstream output layers expect [s, b, C]. 
+ if self.config.enable_hyper_connections: hidden_states = HyperConnectionModule.output_contract( hidden_states, self.num_residual_streams ) # [s, b, n*C] -> [s, b, C] diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index a127792f684..eea6e6aa04a 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1517,8 +1517,10 @@ def __post_init__(self): "tensor_pop on a None chunk. Disable one of them." ) - if self.enable_hyper_connections and not ( - self.recompute_granularity == "selective" and "mhc" in self.recompute_modules + if ( + self.enable_hyper_connections + and self.recompute_granularity is not None + and not (self.recompute_granularity == "selective" and "mhc" in self.recompute_modules) ): warnings.warn( "HyperConnections are enabled but 'mhc' is not in " @@ -1555,6 +1557,13 @@ def __post_init__(self): "Please disable MTP (set mtp_num_layers=None) when using hyper connections." ) + if self.enable_hyper_connections and self.pipeline_model_parallel_size > 1: + raise ValueError( + "enable_hyper_connections is not yet compatible with pipeline parallelism " + "(pipeline_model_parallel_size > 1). Pipeline-parallel support is planned " + "for a follow-up PR." + ) + if self.fine_grained_activation_offloading: assert ( not self.cpu_offloading diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 437993021d5..214aec813f5 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -706,9 +706,6 @@ def forward(self, *args, **kwargs): """ # Injected by __call__ for cuda graph keying; not a real forward arg. 
kwargs.pop("dynamic_inference_decode_only", None) - assert ( - not self.config.enable_hyper_connections - ), "Please use HyperConnectionTransformerLayer instead" hidden_states, context = self._forward_attention(*args, **kwargs) output = self._forward_mlp( hidden_states, @@ -1299,11 +1296,6 @@ def backward_dw_cudagraph(self, microbatch_idx): self.cuda_graphs[cg_index].backward_dw() def __call__(self, *args, **kwargs): - # Extract mhc_recompute_manager before CUDA graph manager processes kwargs, - # since CheckpointManager is not a CUDA-graph-supported type. - self._mhc_recompute_manager = kwargs.pop("mhc_recompute_manager", None) - kwargs.pop("is_last_layer_in_recompute_block", None) - if self._should_call_local_cudagraph(*args, **kwargs): # Inference mode. if kwargs.get('inference_context') is not None: @@ -1422,6 +1414,15 @@ def _get_submodules_under_cudagraphs(self): submodules.append(self.mlp_hyper_connection) return submodules + def __call__(self, *args, **kwargs): + # Extract before CUDA graph manager processes kwargs; CheckpointManager + # is not a CUDA-graph-supported argument type. 
+ self._mhc_recompute_manager = kwargs.pop("mhc_recompute_manager", None) + try: + return super().__call__(*args, **kwargs) + finally: + self._mhc_recompute_manager = None + def forward(self, *args, **kwargs): """Forward pass with MHC recompute manager support.""" kwargs.pop("dynamic_inference_decode_only", None) diff --git a/tests/unit_tests/models/test_gpt_layer_specs.py b/tests/unit_tests/models/test_gpt_layer_specs.py index bfa86fd0241..649f4a98c8d 100644 --- a/tests/unit_tests/models/test_gpt_layer_specs.py +++ b/tests/unit_tests/models/test_gpt_layer_specs.py @@ -22,28 +22,28 @@ class TestGptLayerSpecsHyperConnection: - """Test that enable_hyper_connection controls module types in layer specs.""" + """Test that enable_hyper_connections controls module types in layer specs.""" @pytest.mark.parametrize( "factory,kwargs,expected_module,expected_hc", [ (_TE, {}, _TL, _ID), - (_TE, {"enable_hyper_connection": True}, _HC, _HC_MOD), - (_TE, {"enable_hyper_connection": False}, _TL, _ID), - (_TE, {"multi_latent_attention": True, "enable_hyper_connection": False}, _TL, _ID), - (_TE, {"multi_latent_attention": True, "enable_hyper_connection": True}, _HC, _HC_MOD), + (_TE, {"enable_hyper_connections": True}, _HC, _HC_MOD), + (_TE, {"enable_hyper_connections": False}, _TL, _ID), + (_TE, {"multi_latent_attention": True, "enable_hyper_connections": False}, _TL, _ID), + (_TE, {"multi_latent_attention": True, "enable_hyper_connections": True}, _HC, _HC_MOD), (_LOCAL, {}, _TL, _ID), - (_LOCAL, {"enable_hyper_connection": True}, _HC, _HC_MOD), - (_LOCAL, {"enable_hyper_connection": False}, _TL, _ID), - (_LOCAL, {"multi_latent_attention": True, "enable_hyper_connection": False}, _TL, _ID), + (_LOCAL, {"enable_hyper_connections": True}, _HC, _HC_MOD), + (_LOCAL, {"enable_hyper_connections": False}, _TL, _ID), + (_LOCAL, {"multi_latent_attention": True, "enable_hyper_connections": False}, _TL, _ID), ( _LOCAL, - {"multi_latent_attention": True, "enable_hyper_connection": True}, + 
{"multi_latent_attention": True, "enable_hyper_connections": True}, _HC, _HC_MOD, ), - (_LOCAL, {"normalization": "RMSNorm", "enable_hyper_connection": False}, _TL, _ID), - (_LOCAL, {"normalization": "RMSNorm", "enable_hyper_connection": True}, _HC, _HC_MOD), + (_LOCAL, {"normalization": "RMSNorm", "enable_hyper_connections": False}, _TL, _ID), + (_LOCAL, {"normalization": "RMSNorm", "enable_hyper_connections": True}, _HC, _HC_MOD), ], ids=[ "te_default", @@ -65,3 +65,9 @@ def test_hyper_connection_spec(self, factory, kwargs, expected_module, expected_ assert spec.module is expected_module assert spec.submodules.self_attention_hyper_connection is expected_hc assert spec.submodules.mlp_hyper_connection is expected_hc + + @pytest.mark.parametrize("factory", [_TE, _LOCAL], ids=["te", "local"]) + def test_singular_hyper_connection_keyword_rejected(self, factory): + """The spec keyword should match TransformerConfig.enable_hyper_connections.""" + with pytest.raises(TypeError, match="enable_hyper_connection"): + factory(enable_hyper_connection=True) diff --git a/tests/unit_tests/transformer/test_hyper_connection_recompute.py b/tests/unit_tests/transformer/test_hyper_connection_recompute.py index cf44f2d7cd0..4e8c51cd938 100644 --- a/tests/unit_tests/transformer/test_hyper_connection_recompute.py +++ b/tests/unit_tests/transformer/test_hyper_connection_recompute.py @@ -11,6 +11,8 @@ 5. 
TransformerConfig 'mhc' in recompute_modules option """ +import warnings + import pytest import torch @@ -403,6 +405,48 @@ def test_config_enable_mhc_recompute(self): assert "mhc" in config.recompute_modules assert config.enable_hyper_connections is True + def test_config_rejects_pipeline_parallel_hyper_connections(self): + """Pipeline-parallel tensor shapes do not support n-stream hidden states yet.""" + with pytest.raises( + ValueError, + match="enable_hyper_connections is not yet compatible with pipeline parallelism", + ): + TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + pipeline_model_parallel_size=2, + pipeline_dtype=torch.float32, + ) + + def test_hyper_connection_recompute_warning_requires_recompute(self): + """Do not warn about missing 'mhc' recompute when recompute is disabled.""" + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + ) + + assert not any("HyperConnections are enabled" in str(w.message) for w in caught) + + def test_hyper_connection_recompute_warning_for_selective_without_mhc(self): + """Still warn when selective recompute is on but 'mhc' is omitted.""" + with pytest.warns(UserWarning, match="HyperConnections are enabled"): + TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + recompute_granularity="selective", + ) + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/unit_tests/transformer/test_transformer_layer.py b/tests/unit_tests/transformer/test_transformer_layer.py index 995e99d6a24..741f465ea17 100644 --- a/tests/unit_tests/transformer/test_transformer_layer.py +++ b/tests/unit_tests/transformer/test_transformer_layer.py @@ -360,7 +360,7 @@ def 
_create_layer_with_hyper_connection( recompute_granularity='selective', **extra, ) - layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connection=True) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connections=True) layer = HyperConnectionTransformerLayer( config, layer_spec.submodules, layer_number=layer_number ) @@ -544,7 +544,7 @@ def _run_forward_backward( recompute_modules=["core_attn", "mhc"] if use_recompute else None, recompute_granularity='selective' if use_recompute else None, ) - layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connection=True) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connections=True) layers = [ HyperConnectionTransformerLayer( config, layer_spec.submodules, layer_number=i + 1 @@ -638,7 +638,7 @@ def teardown_method(self, method): def _create_mhc_layer(self, hidden_size=64, num_streams=4, **extra_config): config = _make_mhc_config(hidden_size=hidden_size, num_streams=num_streams, **extra_config) - layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connection=True) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connections=True) layer = HyperConnectionTransformerLayer(config, layer_spec.submodules) layer.cuda() return layer, config @@ -980,7 +980,7 @@ def _create_mhc_layer_with_offloading( fine_grained_activation_offloading=True, offload_modules=offload_modules, ) - layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connection=True) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connections=True) layer = HyperConnectionTransformerLayer(config, layer_spec.submodules) layer.cuda() return layer, config @@ -1052,7 +1052,7 @@ def test_offloading_numerical_equivalence(self): # Run without offloading config_no_offload = _make_mhc_config(hidden_size=hidden_size, num_streams=num_streams) - layer_spec = 
get_gpt_layer_with_transformer_engine_spec(enable_hyper_connection=True) + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connections=True) layer_no_offload = HyperConnectionTransformerLayer( config_no_offload, layer_spec.submodules ).cuda() From e551a52c7f7a00fa1773d07e2a8e77b72a20fa7c Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Mon, 27 Apr 2026 14:17:14 -0700 Subject: [PATCH 05/10] Address follow-up Claude feedback --- megatron/core/tensor_parallel/random.py | 12 ++++---- megatron/core/transformer/hyper_connection.py | 28 +++++++++---------- .../core/transformer/transformer_config.py | 6 ++++ .../core/transformer/transformer_layer.py | 27 +++++++++++++----- megatron/training/initialize.py | 2 -- .../test_hyper_connection_recompute.py | 16 +++++++++++ .../transformer/test_mhc_block_manager.py | 16 +++++++++++ .../transformer/test_transformer_layer.py | 12 ++++++++ 8 files changed, 91 insertions(+), 28 deletions(-) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index a43b1636799..6aff0128b6c 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -784,13 +784,14 @@ def add_checkpoint(self, ckpt): def discard_all_outputs_and_register_unified_recompute(self, hook_tensor): """Discard all checkpoint outputs to save memory and register unified recompute hook.""" + if not hook_tensor.requires_grad: + return + for ckpt in self.checkpoints: for output in ckpt.outputs: output.untyped_storage().resize_(0) - # Register unified recompute hook - if hook_tensor.requires_grad: - hook_tensor.register_hook(self._unified_recompute_hook) + hook_tensor.register_hook(self._unified_recompute_hook) def _unified_recompute_hook(self, grad_output): # Chained checkpoints rely on forward-order recomputation. 
Each @@ -826,8 +827,9 @@ def __init__(self, fp8=False, ckpt_manager=None): discard_output_and_register_recompute() will only discard output without registering individual hooks. """ - # Treat the default fp8=False as disabled. The old "fp8 is not None" - # behavior entered the TE FP8 recompute path for non-FP8 callers. + # Intentional bug fix: the default fp8=False must not enter the TE FP8 + # recompute context. The old "fp8 is not None" behavior did that for + # every default CheckpointWithoutOutput() caller. self.fp8 = bool(fp8) self.ckpt_manager = ckpt_manager self.run_function = None diff --git a/megatron/core/transformer/hyper_connection.py b/megatron/core/transformer/hyper_connection.py index e15ea881ffd..8ac9ce1ca81 100644 --- a/megatron/core/transformer/hyper_connection.py +++ b/megatron/core/transformer/hyper_connection.py @@ -83,14 +83,14 @@ def reference_h_post_bda( @torch.compile -def reference_proj_rms(x: Tensor, weight: Tensor, eps: float = 1e-6) -> Tuple[Tensor, Tensor]: - """Reference fused projection + RMS normalization.""" +def reference_proj_inv_rms(x: Tensor, weight: Tensor, eps: float = 1e-6) -> Tuple[Tensor, Tensor]: + """Reference fused projection + inverse-RMS normalization scale.""" proj = torch.matmul(x, weight.t()) norm = x.norm(dim=-1, keepdim=True) K = x.shape[-1] - v = norm / math.sqrt(K) + eps - r = 1.0 / v - return proj, r + rms = norm / math.sqrt(K) + eps + inv_rms = 1.0 / rms + return proj, inv_rms # ============================================================================ @@ -160,12 +160,12 @@ def __init__(self, config: TransformerConfig, layer_number: int): self._sinkhorn_op = fused_sinkhorn self._h_aggregate_op = fused_h_aggregate self._h_post_bda_op = fused_h_post_bda - self._proj_rms_op = fused_proj_rms + self._proj_inv_rms_op = fused_proj_rms else: self._sinkhorn_op = reference_sinkhorn self._h_aggregate_op = reference_h_aggregate self._h_post_bda_op = reference_h_post_bda - self._proj_rms_op = reference_proj_rms + 
self._proj_inv_rms_op = reference_proj_inv_rms self._init_weights() @@ -193,17 +193,17 @@ def _projection_and_get_norm(self, x: Tensor) -> Tuple[Tensor, Tensor]: """ s, b, nC = x.shape x_2d = x.reshape(s * b, nC) - proj, r = self._proj_rms_op(x_2d, self.mapping_proj.weight, self.norm_eps) - return proj.view(s, b, -1), r.view(s, b, 1) + proj, inv_rms = self._proj_inv_rms_op(x_2d, self.mapping_proj.weight, self.norm_eps) + return proj.view(s, b, -1), inv_rms.view(s, b, 1) @torch.compile - def _compute_h(self, proj: Tensor, r: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def _compute_h(self, proj: Tensor, inv_rms: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """ Compute h from projected hidden states and scaling factors. Args: proj: [s, b, n^2 + 2n] - projected hidden states - r: [s, b, 1] - scaling factors + inv_rms: [s, b, 1] - inverse-RMS scaling factors Returns: h_pre: [s, b, n] - aggregation weights @@ -218,7 +218,7 @@ def _compute_h(self, proj: Tensor, r: Tensor) -> Tuple[Tensor, Tensor, Tensor]: ], dim=-1, ) - h = r * proj * alpha_ + self.bias + h = inv_rms * proj * alpha_ + self.bias # H_pre = σ(α_pre * (θ_pre @ x̃) + b_pre) h_pre = h[..., : self.n].sigmoid() # [s, b, n] @@ -244,9 +244,9 @@ def compute_mappings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """ s, b, _ = x.shape with torch.cuda.nvtx.range("HyperConnection::projection_and_get_norm"): - proj, r = self._projection_and_get_norm(x) + proj, inv_rms = self._projection_and_get_norm(x) with torch.cuda.nvtx.range("HyperConnection::compute_h"): - h_pre, h_post, h_res = self._compute_h(proj, r) + h_pre, h_post, h_res = self._compute_h(proj, inv_rms) h_res = self._sinkhorn_op( h_res.view(s, b, self.n, self.n), self.sinkhorn_iterations, self.norm_eps ) # [s, b, n, n] diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index eea6e6aa04a..10fd36d72b7 100644 --- a/megatron/core/transformer/transformer_config.py +++ 
b/megatron/core/transformer/transformer_config.py @@ -1564,6 +1564,12 @@ def __post_init__(self): "for a follow-up PR." ) + if self.enable_hyper_connections and self.inference_fuse_tp_communication: + raise ValueError( + "enable_hyper_connections is not compatible with " + "inference_fuse_tp_communication." + ) + if self.fine_grained_activation_offloading: assert ( not self.cpu_offloading diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 214aec813f5..d12d5fe84bf 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1336,6 +1336,17 @@ def __init__( pg_collection: Optional[ProcessGroupCollection] = None, vp_stage: Optional[int] = None, ): + if submodules.cross_attention is not IdentityOp: + raise ValueError( + "HyperConnectionTransformerLayer does not support cross-attention. " + "Use IdentityOp for cross_attention." + ) + if submodules.cross_attention_hyper_connection is not IdentityOp: + raise ValueError( + "HyperConnectionTransformerLayer does not support cross-attention " + "hyper connections. Use IdentityOp for cross_attention_hyper_connection." + ) + super().__init__( config=config, submodules=submodules, @@ -1345,12 +1356,6 @@ def __init__( vp_stage=vp_stage, ) - if submodules.cross_attention_hyper_connection is not IdentityOp: - raise ValueError( - "HyperConnectionTransformerLayer does not support cross-attention " - "hyper connections. Use IdentityOp for cross_attention_hyper_connection." - ) - assert submodules.self_attention_hyper_connection is not IdentityOp, ( "HyperConnectionTransformerLayer requires self_attention_hyper_connection. " "Use TransformerLayer instead if hyper connections are not needed." 
@@ -1427,7 +1432,9 @@ def forward(self, *args, **kwargs): """Forward pass with MHC recompute manager support.""" kwargs.pop("dynamic_inference_decode_only", None) - mhc_recompute_manager = getattr(self, '_mhc_recompute_manager', None) + mhc_recompute_manager = kwargs.pop("mhc_recompute_manager", None) + if mhc_recompute_manager is None: + mhc_recompute_manager = getattr(self, '_mhc_recompute_manager', None) hidden_states, context = self._forward_attention( *args, mhc_recompute_manager=mhc_recompute_manager, **kwargs @@ -1468,6 +1475,8 @@ def _forward_attention( inference_context = deprecate_inference_params(inference_context, inference_params) residual = hidden_states + if self.config.fp32_residual_connection: + residual = residual.float() nvtx_range_push(suffix="self_attention_hyper_connection") hidden_states, self_attn_h_res, self_attn_hc_h_post = self.self_attention_hyper_connection( @@ -1531,6 +1540,8 @@ def _forward_attention( # Cross-attention (no hyper connection support). residual = hidden_states + if self.config.fp32_residual_connection: + residual = residual.float() pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states) attention_output_with_bias = self.cross_attention( @@ -1569,6 +1580,8 @@ def _forward_mlp( mhc_mlp_bda_manager = None if is_last_in_recompute_block else mhc_recompute_manager residual = hidden_states + if self.config.fp32_residual_connection: + residual = residual.float() nvtx_range_push(suffix="mlp_hyper_connection") hidden_states, mlp_h_res, mlp_hc_h_post = self.mlp_hyper_connection( diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py index 61a795b4754..713bce571d4 100644 --- a/megatron/training/initialize.py +++ b/megatron/training/initialize.py @@ -95,8 +95,6 @@ def state_restore_func(state_dict): enable_batch_invariant_mode() # Enable NVTX range profiling when profiling is active. 
- # Must be done before model modules with @nvtx_decorator are imported, - # since the decorator captures _nvtx_enabled at decoration (import) time. if args.profile: configure_nvtx_profiling(True) diff --git a/tests/unit_tests/transformer/test_hyper_connection_recompute.py b/tests/unit_tests/transformer/test_hyper_connection_recompute.py index 4e8c51cd938..25640ead10e 100644 --- a/tests/unit_tests/transformer/test_hyper_connection_recompute.py +++ b/tests/unit_tests/transformer/test_hyper_connection_recompute.py @@ -421,6 +421,22 @@ def test_config_rejects_pipeline_parallel_hyper_connections(self): pipeline_dtype=torch.float32, ) + def test_config_rejects_fused_tp_inference_hyper_connections(self): + """mHC does not implement the fused TP inference residual path.""" + with pytest.raises( + ValueError, + match="enable_hyper_connections is not compatible with " + "inference_fuse_tp_communication", + ): + TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + inference_fuse_tp_communication=True, + ) + def test_hyper_connection_recompute_warning_requires_recompute(self): """Do not warn about missing 'mhc' recompute when recompute is disabled.""" with warnings.catch_warnings(record=True) as caught: diff --git a/tests/unit_tests/transformer/test_mhc_block_manager.py b/tests/unit_tests/transformer/test_mhc_block_manager.py index aab004d6516..f02ea1c555e 100644 --- a/tests/unit_tests/transformer/test_mhc_block_manager.py +++ b/tests/unit_tests/transformer/test_mhc_block_manager.py @@ -128,6 +128,22 @@ def test_error_handling(self): with pytest.raises(ValueError): manager.add_checkpoint(ckpt) + def test_unified_recompute_keeps_outputs_when_hook_has_no_grad(self): + """Do not discard outputs if no hook can be registered for recompute.""" + manager = CheckpointManager() + + def func(x): + return x * 2 + + input_t = torch.randn(4, 4, device='cuda', requires_grad=True) + ckpt = 
CheckpointWithoutOutput(ckpt_manager=manager) + y = ckpt.checkpoint(func, input_t) + + hook_tensor = torch.zeros((), device='cuda', requires_grad=False) + manager.discard_all_outputs_and_register_unified_recompute(hook_tensor) + + assert y.untyped_storage().size() > 0 + class TestCheckpointManagerSequentialChain: """Test CheckpointManager with sequential checkpoint chains.""" diff --git a/tests/unit_tests/transformer/test_transformer_layer.py b/tests/unit_tests/transformer/test_transformer_layer.py index 741f465ea17..23a18c09050 100644 --- a/tests/unit_tests/transformer/test_transformer_layer.py +++ b/tests/unit_tests/transformer/test_transformer_layer.py @@ -12,6 +12,7 @@ get_gpt_layer_with_transformer_engine_submodules, ) from megatron.core.tensor_parallel.random import CheckpointManager, model_parallel_cuda_manual_seed +from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import ( HyperConnectionTransformerLayer, @@ -367,6 +368,17 @@ def _create_layer_with_hyper_connection( layer.cuda() return layer, config + def test_rejects_cross_attention_with_hyper_connection(self): + """mHC transformer layers are decoder-only until cross-attention is implemented.""" + config = _make_mhc_config() + layer_spec = get_gpt_layer_with_transformer_engine_spec(enable_hyper_connections=True) + layer_spec.submodules.cross_attention = torch.nn.Linear + + with pytest.raises(ValueError, match="does not support cross-attention"): + HyperConnectionTransformerLayer(config, layer_spec.submodules) + + layer_spec.submodules.cross_attention = IdentityOp + def test_forward_with_hyper_connection_recompute(self): """ Test that TransformerLayer forward works correctly with HyperConnection From 2604c25026214a02cb4db5793836b169e9b28b75 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Mon, 27 Apr 2026 14:55:05 -0700 Subject: [PATCH 06/10] Address latest Claude review feedback 
--- megatron/core/fusions/fused_bias_dropout.py | 81 +------------------ megatron/core/transformer/hyper_connection.py | 13 ++- .../core/transformer/transformer_block.py | 36 ++++++--- .../core/transformer/transformer_config.py | 10 +++ .../core/transformer/transformer_layer.py | 4 + .../test_hyper_connection_recompute.py | 18 +++++ 6 files changed, 65 insertions(+), 97 deletions(-) diff --git a/megatron/core/fusions/fused_bias_dropout.py b/megatron/core/fusions/fused_bias_dropout.py index 1f2448d86be..db11fafd5cb 100644 --- a/megatron/core/fusions/fused_bias_dropout.py +++ b/megatron/core/fusions/fused_bias_dropout.py @@ -1,13 +1,10 @@ # Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -from typing import TYPE_CHECKING, Optional, Tuple +from typing import Optional, Tuple import torch from megatron.core.jit import jit_fuser -if TYPE_CHECKING: - from megatron.core.tensor_parallel.random import CheckpointManager - # pylint: disable=missing-function-docstring @@ -83,26 +80,17 @@ def bias_dropout_add_fused_inference( return _bias_dropout_add_func(x_with_bias, residual, prob, False) -def get_bias_dropout_add( - training, fused, mhc_recompute_manager: Optional['CheckpointManager'] = None -): +def get_bias_dropout_add(training, fused): """ Get the bias-dropout-add function. Args: training: Whether in training mode. fused: Whether to use fused implementation. - mhc_recompute_manager: Optional CheckpointManager for checkpoint management. - When provided, the returned function will wrap the BDA operation with - CheckpointWithoutOutput for memory-efficient recomputation. Returns: A callable that performs bias-dropout-add operation. """ - if mhc_recompute_manager is not None: - # Return a checkpointed version that handles tuple unpacking internally - return _get_checkpointed_bda(training, fused, mhc_recompute_manager) - if fused: # jit scripting for a nn.module (with dropout) is not # triggering the fusion kernel. 
For now, we use two @@ -114,68 +102,3 @@ def get_bias_dropout_add( return bias_dropout_add_fused_inference else: return bias_dropout_add_unfused(training) - - -def _get_checkpointed_bda(training, fused, mhc_recompute_manager: 'CheckpointManager'): - """ - Create a checkpointed bias-dropout-add function. - - This function handles: - 1. Tuple unpacking for x_with_bias (required because save_for_backward can't save tuples) - 2. Non-tensor arguments like dropout probability (handled by CheckpointWithoutOutput) - 3. Auto-registration to the CheckpointManager - - Args: - training: Whether in training mode. - fused: Whether to use fused implementation. - mhc_recompute_manager: CheckpointManager for checkpoint management. - - Returns: - A callable that performs checkpointed bias-dropout-add operation. - """ - from megatron.core.tensor_parallel.random import CheckpointWithoutOutput - - # Get the underlying BDA function - if fused: - if training: - bda_func = bias_dropout_add_fused_train - else: - bda_func = bias_dropout_add_fused_inference - else: - bda_func = bias_dropout_add_unfused(training) - - def _checkpointed_bda(x_with_bias, residual, prob): - """ - Checkpointed BDA that handles tuple unpacking internally. - - Args: - x_with_bias: Either a tuple (x, bias) or a single tensor x. - residual: Residual tensor. - prob: Dropout probability. - - Returns: - Output tensor after bias-dropout-add. 
- """ - # Create checkpoint with manager - ckpt = CheckpointWithoutOutput(ckpt_manager=mhc_recompute_manager) - - # Handle case where x_with_bias might be a single tensor (e.g., from IdentityOp) - if isinstance(x_with_bias, tuple): - x, bias = x_with_bias - else: - x = x_with_bias - bias = None - - # Wrapper function that re-packs the tuple for the actual BDA function - def _bda_wrapper(output, bias, res, dropout): - return bda_func((output, bias), res, dropout) - - # Call checkpoint with unpacked arguments - result = ckpt.checkpoint(_bda_wrapper, x, bias, residual, prob) - - # No-op when manager is set - manager handles all discarding uniformly - ckpt.discard_output_and_register_recompute(result) - - return result - - return _checkpointed_bda diff --git a/megatron/core/transformer/hyper_connection.py b/megatron/core/transformer/hyper_connection.py index 8ac9ce1ca81..78beec10a83 100644 --- a/megatron/core/transformer/hyper_connection.py +++ b/megatron/core/transformer/hyper_connection.py @@ -70,7 +70,7 @@ def reference_h_aggregate(x: Tensor, h_pre: Tensor) -> Tensor: def reference_h_post_bda( h_res: Tensor, original_residual: Tensor, h_post: Tensor, x: Tensor, bias: Optional[Tensor] ) -> Tensor: - """Reference H_res @ residual + H_post * (x [+ bias]).""" + """Reference H_res @ residual + H_post * (x [+ bias]), flattened to [s, b, n*C].""" s, b, n, C = original_residual.shape h_res_batched = h_res.view(s * b, n, n) residual_batched = original_residual.view(s * b, n, C) @@ -78,8 +78,8 @@ def reference_h_post_bda( x_expanded = h_post.unsqueeze(-1) * x.unsqueeze(2) if bias is not None: bias_expanded = h_post.unsqueeze(-1) * bias.view(1, 1, 1, C) - return x_expanded + bias_expanded + mixed - return x_expanded + mixed + return (x_expanded + bias_expanded + mixed).view(s, b, n * C) + return (x_expanded + mixed).view(s, b, n * C) @torch.compile @@ -127,7 +127,7 @@ def __init__(self, config: TransformerConfig, layer_number: int): # Projection weights for dynamic mappings. 
The reference implementation # keeps this as a full, non-TP-partitioned linear projection over n*C; - # fused/partitioned variants are expected to replace it in follow-up work. + # TODO: replace with fused/partitioned variants in the fused mHC follow-up. # Input: [s, b, n*C] -> Output: n^2 + 2n values per token # - H_pre: n values # - H_post: n values @@ -591,8 +591,7 @@ def _fused_h_res_h_post_bda_native( n = self.n C = self.hidden_size orig_reshaped = original_residual.view(s, b, n, C) - output = self._h_post_bda_op(h_res, orig_reshaped, h_post, x, bias) - return output.view(s, b, n * C) + return self._h_post_bda_op(h_res, orig_reshaped, h_post, x, bias) from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add @@ -652,7 +651,7 @@ def _fused_wrapper(h_res, original_residual, h_post, x, *optional_bias): s, b, _ = original_residual.shape orig_reshaped = original_residual.view(s, b, n, C) b_arg = optional_bias[0] if optional_bias else None - return self._h_post_bda_op(h_res, orig_reshaped, h_post, x, b_arg).view(s, b, n * C) + return self._h_post_bda_op(h_res, orig_reshaped, h_post, x, b_arg) ckpt = CheckpointWithoutOutput(ckpt_manager=manager) if bias is not None: diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 67d3da0e076..09ff2aa5633 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -371,6 +371,7 @@ def build_layer(layer_spec, layer_number): for i, layer_spec in enumerate(self.submodules.layer_specs) ] ) + self._mhc_recompute_block_end_plan = self._build_mhc_recompute_block_end_plan() # @TODO: add back account_for_embedding_in_pipeline_split (see issue #293) # In pipeline parallelism, we want to add this LN only to the last stage of the pipeline @@ -646,20 +647,15 @@ def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs)[0] return super().__call__(*args, **kwargs) - def 
_build_mhc_recompute_layer_plan( - self, use_mhc_recompute: bool - ) -> Tuple[List[Optional[CheckpointManager]], List[bool]]: - """Pre-build per-layer MHC recompute managers and block-end markers.""" + def _build_mhc_recompute_block_end_plan(self) -> List[bool]: + """Precompute deterministic mHC recompute block-end markers.""" num_layers = len(self.layers) - layer_managers: List[Optional[CheckpointManager]] = [None] * num_layers - is_recompute_block_end: List[bool] = [False] * num_layers + is_recompute_block_end: List[bool] = [] - if not use_mhc_recompute or num_layers == 0: - return layer_managers, is_recompute_block_end + if num_layers == 0: + return is_recompute_block_end mhc_recompute_layer_num = self.config.mhc_recompute_layer_num - mhc_manager = CheckpointManager() - for l_no in range(num_layers): is_last_in_transformer_block = l_no == num_layers - 1 is_last_in_recompute_block = is_last_in_transformer_block @@ -668,8 +664,26 @@ def _build_mhc_recompute_layer_plan( (l_no + 1) % mhc_recompute_layer_num == 0 ) + is_recompute_block_end.append(is_last_in_recompute_block) + + return is_recompute_block_end + + def _build_mhc_recompute_layer_plan( + self, use_mhc_recompute: bool + ) -> Tuple[List[Optional[CheckpointManager]], List[bool]]: + """Build fresh per-forward MHC managers using cached block-end topology.""" + num_layers = len(self.layers) + layer_managers: List[Optional[CheckpointManager]] = [None] * num_layers + is_recompute_block_end = self._mhc_recompute_block_end_plan + + if not use_mhc_recompute or num_layers == 0: + return layer_managers, is_recompute_block_end + + mhc_manager = CheckpointManager() + + for l_no, is_last_in_recompute_block in enumerate(is_recompute_block_end): + is_last_in_transformer_block = l_no == num_layers - 1 layer_managers[l_no] = mhc_manager - is_recompute_block_end[l_no] = is_last_in_recompute_block if is_last_in_recompute_block and not is_last_in_transformer_block: mhc_manager = CheckpointManager() diff --git 
a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 10fd36d72b7..4342fd12775 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -2111,6 +2111,16 @@ def __post_init__(self): "To use full iteration cuda graph, please use " "cuda_graph_impl=local instead of cuda_graph_impl=transformer_engine." ) + if ( + self.enable_hyper_connections + and self.num_moe_experts is not None + and self.num_moe_experts > 1 + and CudaGraphScope.moe_router in self.cuda_graph_scope + ): + raise ValueError( + "enable_hyper_connections is not yet compatible with " + "MoE router CUDA graphs." + ) assert ( CudaGraphScope.moe not in self.cuda_graph_scope or CudaGraphScope.moe_router not in self.cuda_graph_scope diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index d12d5fe84bf..8bacccbdf8e 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -236,6 +236,7 @@ class TransformerLayerSubmodules: self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp pre_cross_attn_layernorm: LayerNormBuilder = IdentityOp + # Reserved for future cross-attention hyper-connection support. cross_attention_hyper_connection: Union[ModuleSpec, type] = IdentityOp cross_attention: Union[ModuleSpec, type] = IdentityOp cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp @@ -1432,6 +1433,9 @@ def forward(self, *args, **kwargs): """Forward pass with MHC recompute manager support.""" kwargs.pop("dynamic_inference_decode_only", None) + # Direct forward() calls can pass the manager normally. __call__ stores + # it on self first so CUDA graph argument processing never sees the + # unsupported CheckpointManager object. 
mhc_recompute_manager = kwargs.pop("mhc_recompute_manager", None) if mhc_recompute_manager is None: mhc_recompute_manager = getattr(self, '_mhc_recompute_manager', None) diff --git a/tests/unit_tests/transformer/test_hyper_connection_recompute.py b/tests/unit_tests/transformer/test_hyper_connection_recompute.py index 25640ead10e..ebc149e8912 100644 --- a/tests/unit_tests/transformer/test_hyper_connection_recompute.py +++ b/tests/unit_tests/transformer/test_hyper_connection_recompute.py @@ -437,6 +437,24 @@ def test_config_rejects_fused_tp_inference_hyper_connections(self): inference_fuse_tp_communication=True, ) + def test_config_rejects_moe_router_cuda_graph_hyper_connections(self): + """mHC MoE layers do not implement the TE moe_router CUDA graph path yet.""" + with pytest.raises( + ValueError, + match="enable_hyper_connections is not yet compatible with " + "MoE router CUDA graphs", + ): + TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + num_moe_experts=4, + cuda_graph_impl="transformer_engine", + cuda_graph_scope=["moe_router"], + ) + def test_hyper_connection_recompute_warning_requires_recompute(self): """Do not warn about missing 'mhc' recompute when recompute is disabled.""" with warnings.catch_warnings(record=True) as caught: From e6ef250abfdaa0489bfc26229f7cc02692c1ede1 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Mon, 27 Apr 2026 15:14:44 -0700 Subject: [PATCH 07/10] Follow up on Claude transformer layer review --- .../core/transformer/transformer_block.py | 37 +++++++------ .../core/transformer/transformer_layer.py | 48 +++++++++++++++-- .../transformer/test_transformer_layer.py | 52 +++++++++++++++++++ 3 files changed, 116 insertions(+), 21 deletions(-) diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 09ff2aa5633..85bbe0744ae 100755 --- a/megatron/core/transformer/transformer_block.py +++ 
b/megatron/core/transformer/transformer_block.py @@ -902,23 +902,28 @@ def forward( mhc_is_last_in_recompute_block[l_no] ) + layer_kwargs = dict( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + padding_mask=padding_mask, + ) + if mhc_manager is not None and hasattr( + layer, "self_attention_hyper_connection" + ): + layer_kwargs["mhc_recompute_manager"] = mhc_manager + with self.offload_context, inner_quantization_context: - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - rotary_pos_cos_sin=rotary_pos_cos_sin, - attention_bias=attention_bias, - inference_context=inference_context, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - padding_mask=padding_mask, - mhc_recompute_manager=mhc_manager, - ) + hidden_states, context = layer(**layer_kwargs) self._finalize_mhc_recompute_layer( mhc_manager=mhc_manager, hidden_states=hidden_states, diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 8bacccbdf8e..ed1e58f8a1b 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -707,6 +707,8 @@ def forward(self, *args, **kwargs): """ # Injected by __call__ for cuda graph keying; not a real forward arg. kwargs.pop("dynamic_inference_decode_only", None) + # mHC recompute is consumed by HyperConnectionTransformerLayer only. 
+ kwargs.pop("mhc_recompute_manager", None) hidden_states, context = self._forward_attention(*args, **kwargs) output = self._forward_mlp( hidden_states, @@ -1336,6 +1338,9 @@ def __init__( hidden_dropout: Optional[float] = None, pg_collection: Optional[ProcessGroupCollection] = None, vp_stage: Optional[int] = None, + is_mtp_layer: bool = False, + add_layer_offset: bool = True, + pp_layer_offset: Optional[int] = None, ): if submodules.cross_attention is not IdentityOp: raise ValueError( @@ -1355,6 +1360,9 @@ def __init__( hidden_dropout=hidden_dropout, pg_collection=pg_collection, vp_stage=vp_stage, + is_mtp_layer=is_mtp_layer, + add_layer_offset=add_layer_offset, + pp_layer_offset=pp_layer_offset, ) assert submodules.self_attention_hyper_connection is not IdentityOp, ( @@ -1498,11 +1506,20 @@ def _forward_attention( ) with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: input_layernorm_output = self.input_layernorm_checkpoint.checkpoint( - self.input_layernorm, hidden_states + apply_module(self.input_layernorm), hidden_states ) else: with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: - input_layernorm_output = self.input_layernorm(hidden_states) + input_layernorm_output = apply_module(self.input_layernorm)(hidden_states) + + if isinstance(input_layernorm_output, tuple): + if len(input_layernorm_output) != 2: + raise ValueError( + f"When the output of input_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(input_layernorm_output)}" + ) + input_layernorm_output, _ = input_layernorm_output # Self attention. 
nvtx_range_push(suffix="self_attention") @@ -1546,7 +1563,19 @@ def _forward_attention( residual = hidden_states if self.config.fp32_residual_connection: residual = residual.float() - pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states) + pre_cross_attn_layernorm_output = apply_module(self.pre_cross_attn_layernorm)( + hidden_states + ) + + if isinstance(pre_cross_attn_layernorm_output, tuple): + if len(pre_cross_attn_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_cross_attn_layernorm_output " + f"is a tuple, it is expected to have 2 elements " + f"(output, residual), but " + f"got {len(pre_cross_attn_layernorm_output)}" + ) + pre_cross_attn_layernorm_output, _ = pre_cross_attn_layernorm_output attention_output_with_bias = self.cross_attention( pre_cross_attn_layernorm_output, @@ -1603,11 +1632,20 @@ def _forward_mlp( ) with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint( - self.pre_mlp_layernorm, hidden_states + apply_module(self.pre_mlp_layernorm), hidden_states ) else: with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: - pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) + pre_mlp_layernorm_output = apply_module(self.pre_mlp_layernorm)(hidden_states) + + if isinstance(pre_mlp_layernorm_output, tuple): + if len(pre_mlp_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(pre_mlp_layernorm_output)}" + ) + pre_mlp_layernorm_output, _ = pre_mlp_layernorm_output nvtx_range_push(suffix="mlp") should_chunk_mlp_for_prefill = ( diff --git a/tests/unit_tests/transformer/test_transformer_layer.py b/tests/unit_tests/transformer/test_transformer_layer.py index 23a18c09050..4fcb6fa9997 100644 --- a/tests/unit_tests/transformer/test_transformer_layer.py +++ 
b/tests/unit_tests/transformer/test_transformer_layer.py @@ -87,6 +87,30 @@ def test_gpu_forward(self): assert hidden_states.shape[1] == micro_batch_size assert hidden_states.shape[2] == config.hidden_size + def test_gpu_forward_ignores_mhc_recompute_manager_kwarg(self): + """Non-HC layers must tolerate TransformerBlock's mHC recompute kwarg.""" + parallel_transformer_layer = self.parallel_transformer_layer + config: TransformerConfig = parallel_transformer_layer.config + sequence_length = 8 + micro_batch_size = 2 + parallel_transformer_layer.cuda() + + hidden_states = torch.ones( + (sequence_length, micro_batch_size, config.hidden_size), device='cuda' + ) + attention_mask = torch.ones( + (1, 1, sequence_length, sequence_length), dtype=bool, device='cuda' + ) + + hidden_states, context = parallel_transformer_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + mhc_recompute_manager=None, + ) + + assert context is None + assert hidden_states.shape == (sequence_length, micro_batch_size, config.hidden_size) + def test_chunked_mlp(self): with torch.no_grad(): @@ -379,6 +403,34 @@ def test_rejects_cross_attention_with_hyper_connection(self): layer_spec.submodules.cross_attention = IdentityOp + def test_layernorm_tuple_outputs_are_unpacked(self): + """HC layers should mirror base-layer tuple handling for fused layernorms.""" + + class TupleLayerNorm(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states, hidden_states + + layer, config = self._create_layer_with_hyper_connection() + layer.input_layernorm = TupleLayerNorm().cuda() + layer.pre_mlp_layernorm = TupleLayerNorm().cuda() + layer.train() + + seq_len = 8 + batch_size = 2 + hidden_states = torch.randn( + seq_len, + batch_size, + config.num_residual_streams * config.hidden_size, + device='cuda', + requires_grad=True, + ) + attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') + + output, context = layer(hidden_states=hidden_states, 
attention_mask=attention_mask) + + assert context is None + assert output.shape == hidden_states.shape + def test_forward_with_hyper_connection_recompute(self): """ Test that TransformerLayer forward works correctly with HyperConnection From 6562e52e9fc479979daa2acf4232f7a77328fb57 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Mon, 27 Apr 2026 16:32:45 -0700 Subject: [PATCH 08/10] Address Claude mHC follow-up review --- megatron/core/tensor_parallel/random.py | 24 ++++++---- megatron/core/transformer/hyper_connection.py | 48 +++++++++++-------- .../core/transformer/transformer_block.py | 12 ++--- .../core/transformer/transformer_config.py | 5 +- .../core/transformer/transformer_layer.py | 32 ++++++++----- .../unit_tests/tensor_parallel/test_random.py | 7 +++ .../test_hyper_connection_recompute.py | 36 +++++++++----- .../transformer/test_mhc_block_manager.py | 34 ++++++------- .../transformer/test_transformer_layer.py | 30 ++++++------ 9 files changed, 134 insertions(+), 94 deletions(-) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index 6aff0128b6c..d4c05bf3cae 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -753,19 +753,21 @@ def backward(ctx, *args): return (None, None) + grads -class CheckpointManager: +class MHCRecomputeManager: """ - Manages multiple CheckpointWithoutOutput objects within a TransformerBlock - cross layer recomputations, enabling unified recomputation during backward pass. - This is particularly useful for scenarios where multiple checkpoint operations have - sequential dependencies (i.e., the output of one checkpoint is the input of the next). + Manages cross-layer mHC CheckpointWithoutOutput recomputation blocks. + + This is particularly useful when multiple checkpointed operations have sequential + dependencies (i.e., the output of one checkpoint is the input of the next). 
+ The manager recomputes checkpoints in forward order so each checkpoint restores + its output storage before the next checkpoint consumes it during recomputation. Usage: - ckptManager = CheckpointManager() - ckpt_function = CheckpointWithoutOutput(ckpt_manager=ckptManager) + mhc_manager = MHCRecomputeManager() + ckpt_function = CheckpointWithoutOutput(ckpt_manager=mhc_manager) ckpt_function.checkpoint(run_function, *args) # other checkpointed operations - ckpt_manager.discard_all_outputs_and_register_unified_recompute(final_output) + mhc_manager.discard_all_outputs_and_register_unified_recompute(final_output) """ def __init__(self): @@ -802,6 +804,10 @@ def _unified_recompute_hook(self, grad_output): ckpt._recompute(None) +# Backward-compatible alias for earlier internal users of the generic name. +CheckpointManager = MHCRecomputeManager + + class CheckpointWithoutOutput(object): """ Checkpoint a model or part of the model and release the output. @@ -822,7 +828,7 @@ def __init__(self, fp8=False, ckpt_manager=None): Args: fp8: Whether to use FP8 mode. Defaults to False. - ckpt_manager: Optional CheckpointManager instance. When provided, + ckpt_manager: Optional MHCRecomputeManager instance. When provided, checkpoint() will auto-register to the manager, and discard_output_and_register_recompute() will only discard output without registering individual hooks. 
diff --git a/megatron/core/transformer/hyper_connection.py b/megatron/core/transformer/hyper_connection.py index 78beec10a83..e5b611371dd 100644 --- a/megatron/core/transformer/hyper_connection.py +++ b/megatron/core/transformer/hyper_connection.py @@ -12,7 +12,7 @@ from megatron.core.utils import nvtx_decorator if TYPE_CHECKING: - from megatron.core.tensor_parallel.random import CheckpointManager + from megatron.core.tensor_parallel.random import MHCRecomputeManager @torch.compile @@ -85,11 +85,14 @@ def reference_h_post_bda( @torch.compile def reference_proj_inv_rms(x: Tensor, weight: Tensor, eps: float = 1e-6) -> Tuple[Tensor, Tensor]: """Reference fused projection + inverse-RMS normalization scale.""" - proj = torch.matmul(x, weight.t()) - norm = x.norm(dim=-1, keepdim=True) + input_dtype = x.dtype + x_float = x.float() + weight_float = weight.float() + proj = torch.matmul(x_float, weight_float.t()).to(dtype=input_dtype) + norm = x_float.norm(dim=-1, keepdim=True) K = x.shape[-1] rms = norm / math.sqrt(K) + eps - inv_rms = 1.0 / rms + inv_rms = (1.0 / rms).to(dtype=input_dtype) return proj, inv_rms @@ -126,8 +129,10 @@ def __init__(self, config: TransformerConfig, layer_number: int): self.sinkhorn_iterations = config.mhc_sinkhorn_iterations # Projection weights for dynamic mappings. The reference implementation - # keeps this as a full, non-TP-partitioned linear projection over n*C; - # TODO: replace with fused/partitioned variants in the fused mHC follow-up. + # keeps this as a full, non-TP-partitioned linear projection over n*C, so + # TP ranks duplicate the projection compute and store the n*C activation + # for the weight gradient. TODO: replace with fused/partitioned variants + # in the fused mHC follow-up. 
# Input: [s, b, n*C] -> Output: n^2 + 2n values per token # - H_pre: n values # - H_post: n values @@ -218,13 +223,13 @@ def _compute_h(self, proj: Tensor, inv_rms: Tensor) -> Tuple[Tensor, Tensor, Ten ], dim=-1, ) - h = inv_rms * proj * alpha_ + self.bias + h = inv_rms.float() * proj.float() * alpha_.float() + self.bias.float() # H_pre = σ(α_pre * (θ_pre @ x̃) + b_pre) - h_pre = h[..., : self.n].sigmoid() # [s, b, n] + h_pre = h[..., : self.n].sigmoid().to(dtype=proj.dtype) # [s, b, n] # H_post = 2σ(α_post * (θ_post @ x̃) + b_post) - h_post = h[..., self.n : 2 * self.n].sigmoid() * 2 # [s, b, n] - h_res = h[..., 2 * self.n :] + h_post = (h[..., self.n : 2 * self.n].sigmoid() * 2).to(dtype=proj.dtype) # [s, b, n] + h_res = h[..., 2 * self.n :].to(dtype=proj.dtype) return h_pre, h_post, h_res @nvtx_decorator(message="HyperConnection::compute_mappings") @@ -291,7 +296,7 @@ def apply_h_post( self, x_with_bias: Tuple[Tensor, Optional[Tensor]], h_post: Tensor, - manager: Optional['CheckpointManager'] = None, + manager: Optional['MHCRecomputeManager'] = None, ) -> Tuple[Tensor, Optional[Tensor]]: """ Apply H_post to x and optionally bias, with optional checkpointing. @@ -304,7 +309,7 @@ def apply_h_post( - x: [s, b, C] - hidden states - bias: [C] or None - optional bias tensor h_post: [s, b, n] - expansion weights - manager: Optional CheckpointManager for checkpoint management. + manager: Optional MHCRecomputeManager for checkpoint management. When provided, wraps _apply_h_post with CheckpointWithoutOutput. Returns: @@ -378,14 +383,14 @@ def apply_h_res(self, h_res: Tensor, residual: Tensor) -> Tensor: return mixed.view(s, b, n * C) def forward( - self, hidden_states: Tensor, mhc_recompute_manager: Optional['CheckpointManager'] = None + self, hidden_states: Tensor, mhc_recompute_manager: Optional['MHCRecomputeManager'] = None ) -> Tuple[Tensor, Tensor, Tensor]: """ Full mHC forward pass. 
Args: hidden_states: [s, b, n*C] - n-stream hidden states - mhc_recompute_manager: Optional CheckpointManager for checkpoint management. + mhc_recompute_manager: Optional MHCRecomputeManager for checkpoint management. When provided, uses _forward_with_checkpoint for memory-efficient execution. Returns: @@ -420,7 +425,7 @@ def _forward_normal(self, hidden_states: Tensor) -> Tuple[Tensor, Tensor, Tensor return aggregated, h_res, h_post def _forward_with_checkpoint( - self, hidden_states: Tensor, manager: 'CheckpointManager' + self, hidden_states: Tensor, manager: 'MHCRecomputeManager' ) -> Tuple[Tensor, Tensor, Tensor]: """ Forward pass with checkpointing for memory efficiency. @@ -432,7 +437,7 @@ def _forward_with_checkpoint( Args: hidden_states: [s, b, n*C] - n-stream hidden states - manager: CheckpointManager for unified recomputation + manager: MHCRecomputeManager for unified recomputation Returns: aggregated: [s, b, C] - aggregated input for layer computation @@ -467,7 +472,8 @@ def input_expand(x: Tensor, n: int) -> Tensor: expanded: [s, b, n*C] - n-stream hidden states """ s, b, C = x.shape - # Replicate input to n streams + # expand() is a view; contiguous() intentionally materializes the n + # streams, so the reference path allocates an n*C activation here. expanded = x.unsqueeze(2).expand(s, b, n, C).contiguous() return expanded.view(s, b, n * C) @@ -504,7 +510,7 @@ def fused_h_res_h_post_bda( dropout_prob: float, training: bool, fused: bool, - manager: Optional['CheckpointManager'] = None, + manager: Optional['MHCRecomputeManager'] = None, ) -> Tensor: """ Fused kernel combining apply_h_res, apply_h_post and bias-dropout-add. @@ -527,7 +533,7 @@ def fused_h_res_h_post_bda( dropout_prob: Dropout probability training: Whether in training mode fused: Whether to use fused BDA implementation - manager: Optional CheckpointManager for checkpoint management. + manager: Optional MHCRecomputeManager for checkpoint management. 
When provided, each operation is wrapped with CheckpointWithoutOutput. Returns: @@ -615,7 +621,7 @@ def _fused_h_res_h_post_bda_with_checkpoint( dropout_prob: float, training: bool, fused: bool, - manager: 'CheckpointManager', + manager: 'MHCRecomputeManager', ) -> Tensor: """ Checkpointed variant of _fused_h_res_h_post_bda_native. @@ -633,7 +639,7 @@ def _fused_h_res_h_post_bda_with_checkpoint( dropout_prob: Dropout probability training: Whether in training mode fused: Whether to use fused BDA implementation - manager: CheckpointManager for checkpoint management + manager: MHCRecomputeManager for checkpoint management Returns: output: [s, b, n*C] - final output diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 85bbe0744ae..6da42004e2e 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -20,7 +20,7 @@ from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.tensor_parallel.random import CheckpointManager +from megatron.core.tensor_parallel.random import MHCRecomputeManager from megatron.core.transformer.enums import CudaGraphScope, LayerType from megatron.core.transformer.hyper_connection import HyperConnectionModule from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule @@ -670,29 +670,29 @@ def _build_mhc_recompute_block_end_plan(self) -> List[bool]: def _build_mhc_recompute_layer_plan( self, use_mhc_recompute: bool - ) -> Tuple[List[Optional[CheckpointManager]], List[bool]]: + ) -> Tuple[List[Optional[MHCRecomputeManager]], List[bool]]: """Build fresh per-forward MHC managers using cached block-end topology.""" num_layers = len(self.layers) - layer_managers: List[Optional[CheckpointManager]] = [None] * num_layers + 
layer_managers: List[Optional[MHCRecomputeManager]] = [None] * num_layers is_recompute_block_end = self._mhc_recompute_block_end_plan if not use_mhc_recompute or num_layers == 0: return layer_managers, is_recompute_block_end - mhc_manager = CheckpointManager() + mhc_manager = MHCRecomputeManager() for l_no, is_last_in_recompute_block in enumerate(is_recompute_block_end): is_last_in_transformer_block = l_no == num_layers - 1 layer_managers[l_no] = mhc_manager if is_last_in_recompute_block and not is_last_in_transformer_block: - mhc_manager = CheckpointManager() + mhc_manager = MHCRecomputeManager() return layer_managers, is_recompute_block_end @staticmethod def _finalize_mhc_recompute_layer( - mhc_manager: Optional[CheckpointManager], + mhc_manager: Optional[MHCRecomputeManager], hidden_states: Tensor, is_last_in_recompute_block: bool, ) -> None: diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 4342fd12775..11e6af6d235 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -493,7 +493,7 @@ class TransformerConfig(ModelParallelConfig): "moe": recompute the MoE layer. "shared_experts": recompute the shared experts in the MoE layer. "mhc": recompute HyperConnection intermediate activations via - CheckpointWithoutOutput + CheckpointManager. Requires + CheckpointWithoutOutput + MHCRecomputeManager. Requires enable_hyper_connections=True. Cannot be used with "mlp". "moe_act", "layernorm", "mla_up_proj", and "mhc" use output-discarding checkpointing, "core_attn", "mlp", "moe", and "shared_experts" use normal checkpointing. 
@@ -908,7 +908,7 @@ class TransformerConfig(ModelParallelConfig): layer in the transformer block) will: - NOT checkpoint its final MLP BDA - Register the unified recompute hook on its MLP BDA output - - A new CheckpointManager is created for subsequent layers + - A new MHCRecomputeManager is created for subsequent layers If None, all layers in the transformer block share a single recompute block. @@ -1500,6 +1500,7 @@ def __post_init__(self): "'mhc' and 'mlp' in recompute_modules cannot be used together. " "They use different checkpoint mechanisms that may conflict." ) + # bool is a subclass of int, so reject it explicitly. if self.mhc_recompute_layer_num is not None and ( isinstance(self.mhc_recompute_layer_num, bool) or not isinstance(self.mhc_recompute_layer_num, int) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index ed1e58f8a1b..06bbea74779 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -5,11 +5,12 @@ import logging import warnings from abc import ABC +from contextvars import ContextVar from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, Optional, Union if TYPE_CHECKING: - from megatron.core.tensor_parallel.random import CheckpointManager + from megatron.core.tensor_parallel.random import MHCRecomputeManager import torch import torch.distributed @@ -43,6 +44,7 @@ from megatron.core.inference.contexts import BaseInferenceContext logger = logging.getLogger(__name__) +_MHC_RECOMPUTE_MANAGER_CONTEXT = ContextVar("mhc_recompute_manager", default=None) def get_transformer_layer_offset( @@ -1429,24 +1431,28 @@ def _get_submodules_under_cudagraphs(self): return submodules def __call__(self, *args, **kwargs): - # Extract before CUDA graph manager processes kwargs; CheckpointManager - # is not a CUDA-graph-supported argument type. 
- self._mhc_recompute_manager = kwargs.pop("mhc_recompute_manager", None) + # Extract before CUDA graph manager processes kwargs; MHCRecomputeManager + # is not a CUDA-graph-supported argument type. MCore CUDA graph capture + # executes forward synchronously inside __call__, so this context-local + # handoff is visible during capture without storing per-call state on the + # module instance. + mhc_recompute_manager = kwargs.pop("mhc_recompute_manager", None) + token = _MHC_RECOMPUTE_MANAGER_CONTEXT.set(mhc_recompute_manager) try: return super().__call__(*args, **kwargs) finally: - self._mhc_recompute_manager = None + _MHC_RECOMPUTE_MANAGER_CONTEXT.reset(token) def forward(self, *args, **kwargs): """Forward pass with MHC recompute manager support.""" kwargs.pop("dynamic_inference_decode_only", None) - # Direct forward() calls can pass the manager normally. __call__ stores - # it on self first so CUDA graph argument processing never sees the - # unsupported CheckpointManager object. + # Direct forward() calls can pass the manager normally. __call__ uses a + # context-local handoff so CUDA graph argument processing never sees the + # unsupported MHCRecomputeManager object. 
mhc_recompute_manager = kwargs.pop("mhc_recompute_manager", None) if mhc_recompute_manager is None: - mhc_recompute_manager = getattr(self, '_mhc_recompute_manager', None) + mhc_recompute_manager = _MHC_RECOMPUTE_MANAGER_CONTEXT.get() hidden_states, context = self._forward_attention( *args, mhc_recompute_manager=mhc_recompute_manager, **kwargs @@ -1475,7 +1481,7 @@ def _forward_attention( packed_seq_params: Optional[PackedSeqParams] = None, sequence_len_offset: Optional[Tensor] = None, padding_mask: Optional[Tensor] = None, - mhc_recompute_manager: Optional['CheckpointManager'] = None, + mhc_recompute_manager: Optional['MHCRecomputeManager'] = None, *, inference_params: Optional[Any] = None, ): @@ -1599,7 +1605,7 @@ def _forward_mlp( hidden_states, inference_context=None, padding_mask=None, - mhc_recompute_manager: Optional['CheckpointManager'] = None, + mhc_recompute_manager: Optional['MHCRecomputeManager'] = None, ): """Forward MLP with hyper connection pre/post processing.""" from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( @@ -1697,7 +1703,7 @@ def _forward_post_mlp_with_fused_hyper_connection( mlp_h_res, residual, mlp_hc_h_post, - mhc_mlp_bda_recompute_manager: Optional['CheckpointManager'] = None, + mhc_mlp_bda_recompute_manager: Optional['MHCRecomputeManager'] = None, ): """ Perform operations after the MLP computation with fused hyper connection kernel. @@ -1709,7 +1715,7 @@ def _forward_post_mlp_with_fused_hyper_connection( mlp_h_res (Tensor): [s, b, n, n] - residual mixing matrix from hyper connection. residual (Tensor): [s, b, n*C] - original residual (n-stream hidden states). mlp_hc_h_post (Tensor): [s, b, n] - expansion weights from hyper connection. - mhc_recompute_manager: Optional CheckpointManager for checkpoint management. + mhc_recompute_manager: Optional MHCRecomputeManager for checkpoint management. Returns: output (Tensor): Transformed hidden states of shape [s, b, n*C]. 
diff --git a/tests/unit_tests/tensor_parallel/test_random.py b/tests/unit_tests/tensor_parallel/test_random.py index 4fa79733d55..6ba6a0598e9 100644 --- a/tests/unit_tests/tensor_parallel/test_random.py +++ b/tests/unit_tests/tensor_parallel/test_random.py @@ -200,6 +200,13 @@ def test_forward(*input): Utils.destroy_model_parallel() +def test_checkpoint_without_output_fp8_flag_is_explicit(): + assert CheckpointWithoutOutput().fp8 is False + assert CheckpointWithoutOutput(fp8=False).fp8 is False + assert CheckpointWithoutOutput(fp8=None).fp8 is False + assert CheckpointWithoutOutput(fp8="e4m3").fp8 is True + + def test_checkpoint_without_output(): def normal_forward(input): x = torch.nn.functional.gelu(input) diff --git a/tests/unit_tests/transformer/test_hyper_connection_recompute.py b/tests/unit_tests/transformer/test_hyper_connection_recompute.py index ebc149e8912..a82bc92075c 100644 --- a/tests/unit_tests/transformer/test_hyper_connection_recompute.py +++ b/tests/unit_tests/transformer/test_hyper_connection_recompute.py @@ -5,8 +5,8 @@ Tests the following functionality: 1. HyperConnectionModule._forward_with_checkpoint correctness -2. HyperConnectionModule.apply_h_post with CheckpointManager -3. Multiple HyperConnectionModules chained with a single CheckpointManager +2. HyperConnectionModule.apply_h_post with MHCRecomputeManager +3. Multiple HyperConnectionModules chained with a single MHCRecomputeManager 4. Partial checkpoint (last layer not checkpointed) 5. 
TransformerConfig 'mhc' in recompute_modules option """ @@ -16,8 +16,11 @@ import pytest import torch -from megatron.core.tensor_parallel.random import CheckpointManager, model_parallel_cuda_manual_seed -from megatron.core.transformer.hyper_connection import HyperConnectionModule +from megatron.core.tensor_parallel.random import ( + MHCRecomputeManager, + model_parallel_cuda_manual_seed, +) +from megatron.core.transformer.hyper_connection import HyperConnectionModule, reference_proj_inv_rms from megatron.core.transformer.transformer_config import TransformerConfig from tests.unit_tests.test_utilities import Utils @@ -48,6 +51,17 @@ def _create_hyper_connection_module(self, hidden_size=64, num_residual_streams=4 module.cuda() return module + def test_reference_proj_inv_rms_upcasts_norm_for_fp16(self): + width = 16384 + x = torch.full((1, width), 256.0, device='cuda', dtype=torch.float16) + weight = torch.zeros((1, width), device='cuda', dtype=torch.float16) + + _, inv_rms = reference_proj_inv_rms(x, weight) + + assert inv_rms.dtype == torch.float16 + assert torch.isfinite(inv_rms).all() + assert inv_rms.item() > 0 + def test_forward_normal_vs_checkpoint_correctness(self): """ Test that _forward_with_checkpoint produces the same outputs as _forward_normal. 
@@ -84,7 +98,7 @@ def test_forward_normal_vs_checkpoint_correctness(self): # Forward with checkpoint torch.manual_seed(42) torch.cuda.manual_seed(42) - manager = CheckpointManager() + manager = MHCRecomputeManager() aggregated_ckpt, h_res_ckpt, h_post_ckpt = module._forward_with_checkpoint( hidden_states_ckpt, manager ) @@ -144,7 +158,7 @@ def test_apply_h_post_with_checkpoint(self): # With checkpoint (manager provided) torch.manual_seed(42) - manager = CheckpointManager() + manager = MHCRecomputeManager() x_out_ckpt, bias_out_ckpt = module.apply_h_post( (x_ckpt, bias), h_post_ckpt, manager=manager ) @@ -193,7 +207,7 @@ def test_forward_with_manager_parameter(self): # With manager (uses _forward_with_checkpoint) torch.manual_seed(42) torch.cuda.manual_seed(42) - manager = CheckpointManager() + manager = MHCRecomputeManager() aggregated_ckpt, h_res_ckpt, h_post_ckpt = module.forward( hidden_states_ckpt, mhc_recompute_manager=manager ) @@ -208,7 +222,7 @@ def test_forward_with_manager_parameter(self): class TestMHCBlockRecomputeIntegration: - """Test CheckpointManager integration with HyperConnection.""" + """Test MHCRecomputeManager integration with HyperConnection.""" def setup_method(self, method): Utils.initialize_model_parallel(1, 1) @@ -220,7 +234,7 @@ def teardown_method(self, method): def test_multiple_hyper_connections_in_chain(self): """ Test that multiple HyperConnectionModules can be chained together - with a single CheckpointManager. + with a single MHCRecomputeManager. 
""" hidden_size = 64 num_streams = 4 @@ -277,7 +291,7 @@ def test_multiple_hyper_connections_in_chain(self): torch.manual_seed(42) torch.cuda.manual_seed(42) - manager = CheckpointManager() + manager = MHCRecomputeManager() h = hidden_states_ckpt r = residual_ckpt @@ -358,7 +372,7 @@ def test_partial_checkpoint_last_layer_not_checkpointed(self): # With manager - checkpoint everything except final output torch.manual_seed(42) torch.cuda.manual_seed(42) - manager = CheckpointManager() + manager = MHCRecomputeManager() aggregated_ckpt, h_res_ckpt, h_post_ckpt = module.forward( hidden_states_ckpt, mhc_recompute_manager=manager ) diff --git a/tests/unit_tests/transformer/test_mhc_block_manager.py b/tests/unit_tests/transformer/test_mhc_block_manager.py index f02ea1c555e..f3d09695169 100644 --- a/tests/unit_tests/transformer/test_mhc_block_manager.py +++ b/tests/unit_tests/transformer/test_mhc_block_manager.py @@ -4,7 +4,7 @@ import torch from megatron.core.tensor_parallel.random import ( - CheckpointManager, + MHCRecomputeManager, CheckpointWithoutOutput, initialize_rng_tracker, ) @@ -12,7 +12,7 @@ class TestCheckpointWithoutOutputManagerAPI: - """Test CheckpointWithoutOutput integration with CheckpointManager.""" + """Test CheckpointWithoutOutput integration with MHCRecomputeManager.""" def setup_method(self, method): Utils.initialize_model_parallel() @@ -23,7 +23,7 @@ def teardown_method(self, method): def test_auto_register(self): """CheckpointWithoutOutput auto-registers to manager when ckpt_manager is provided.""" - manager = CheckpointManager() + manager = MHCRecomputeManager() def func(x): return x * 2 + 1 @@ -50,7 +50,7 @@ def func(x): def test_discard_is_noop_with_manager(self): """discard_output_and_register_recompute is a NO-OP when ckpt_manager is set.""" - manager = CheckpointManager() + manager = MHCRecomputeManager() def func1(x): return x * 2 @@ -118,8 +118,8 @@ def func(x): assert torch.allclose(grad_ckpt, grad_ref, atol=1e-6) def 
test_error_handling(self): - """CheckpointManager rejects invalid add_checkpoint calls.""" - manager = CheckpointManager() + """MHCRecomputeManager rejects invalid add_checkpoint calls.""" + manager = MHCRecomputeManager() with pytest.raises(TypeError): manager.add_checkpoint("not a checkpoint") @@ -130,7 +130,7 @@ def test_error_handling(self): def test_unified_recompute_keeps_outputs_when_hook_has_no_grad(self): """Do not discard outputs if no hook can be registered for recompute.""" - manager = CheckpointManager() + manager = MHCRecomputeManager() def func(x): return x * 2 @@ -145,8 +145,8 @@ def func(x): assert y.untyped_storage().size() > 0 -class TestCheckpointManagerSequentialChain: - """Test CheckpointManager with sequential checkpoint chains.""" +class TestMHCRecomputeManagerSequentialChain: + """Test MHCRecomputeManager with sequential checkpoint chains.""" def setup_method(self, method): Utils.initialize_model_parallel() @@ -177,7 +177,7 @@ def func3(x): loss_ref.backward() grad_ref = input_ref.grad.clone() - manager = CheckpointManager() + manager = MHCRecomputeManager() y1 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func1, input_ckpt) y2 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func2, y1) @@ -221,7 +221,7 @@ def func2(x): torch.manual_seed(42) torch.cuda.manual_seed(42) - manager = CheckpointManager() + manager = MHCRecomputeManager() y1 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func_with_dropout, input_ckpt) y2 = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func2, y1) @@ -237,7 +237,7 @@ def func2(x): ), f"Gradients with dropout mismatch!\nWith manager: {grad_ckpt}\nReference: {grad_ref}" def test_multiple_outputs(self): - """CheckpointManager handles functions that return multiple outputs.""" + """MHCRecomputeManager handles functions that return multiple outputs.""" def func_multi_output(x): return x * 2, x + 1 @@ -254,7 +254,7 @@ def func_combine(a, b): loss_ref.backward() grad_ref = 
input_ref.grad.clone() - manager = CheckpointManager() + manager = MHCRecomputeManager() y1a, y1b = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint( func_multi_output, input_ckpt @@ -273,8 +273,8 @@ def func_combine(a, b): ) -class TestCheckpointManagerPartialCheckpoint: - """Test CheckpointManager with partial checkpointing (some ops not checkpointed).""" +class TestMHCRecomputeManagerPartialCheckpoint: + """Test MHCRecomputeManager with partial checkpointing (some ops not checkpointed).""" def setup_method(self, method): Utils.initialize_model_parallel() @@ -311,7 +311,7 @@ def func_h(x): input_ckpt = input_ref.detach().clone().requires_grad_(True) - manager = CheckpointManager() + manager = MHCRecomputeManager() b = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint(func_f, input_ckpt) c = func_g(b) @@ -374,7 +374,7 @@ def apply_h_post(y, h_post): x_ckpt = x_ref.detach().clone().requires_grad_(True) residual_ckpt = residual_ref.detach().clone().requires_grad_(True) - manager = CheckpointManager() + manager = MHCRecomputeManager() h_pre, h_post, h_res = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint( compute_mappings, x_ckpt diff --git a/tests/unit_tests/transformer/test_transformer_layer.py b/tests/unit_tests/transformer/test_transformer_layer.py index 4fcb6fa9997..08c830f66e4 100644 --- a/tests/unit_tests/transformer/test_transformer_layer.py +++ b/tests/unit_tests/transformer/test_transformer_layer.py @@ -11,7 +11,7 @@ get_gpt_layer_with_transformer_engine_spec, get_gpt_layer_with_transformer_engine_submodules, ) -from megatron.core.tensor_parallel.random import CheckpointManager, model_parallel_cuda_manual_seed +from megatron.core.tensor_parallel.random import MHCRecomputeManager, model_parallel_cuda_manual_seed from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import ( @@ -452,7 +452,7 @@ def 
test_forward_with_hyper_connection_recompute(self): attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') # Create manager for MHC block recomputation - manager = CheckpointManager() + manager = MHCRecomputeManager() # Forward pass with recompute manager manager.is_last_layer_in_recompute_block = True @@ -499,7 +499,7 @@ def test_intermediate_layer_with_recompute(self): ) attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') - manager = CheckpointManager() + manager = MHCRecomputeManager() # Forward pass - NOT the last layer in block manager.is_last_layer_in_recompute_block = False @@ -526,7 +526,7 @@ def test_intermediate_layer_with_recompute(self): def test_multiple_layers_chain_with_recompute(self): """ Test multiple TransformerLayers chained together with a single - CheckpointManager, simulating TransformerBlock behavior. + MHCRecomputeManager, simulating TransformerBlock behavior. """ hidden_size = 64 num_streams = 4 @@ -551,7 +551,7 @@ def test_multiple_layers_chain_with_recompute(self): attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') # Single manager for all layers (like TransformerBlock) - manager = CheckpointManager() + manager = MHCRecomputeManager() # Forward through all layers h = hidden_states @@ -597,7 +597,7 @@ def _run_forward_backward( ): """Run a full forward + backward pass and return (peak memory, output grad). - When use_recompute=True, a new CheckpointManager is created every + When use_recompute=True, a new MHCRecomputeManager is created every `recompute_block_size` layers, mirroring TransformerBlock's _build_mhc_recompute_layer_plan logic. 
""" @@ -627,7 +627,7 @@ def _run_forward_backward( torch.cuda.reset_peak_memory_stats() torch.cuda.synchronize() - manager = CheckpointManager() if use_recompute else None + manager = MHCRecomputeManager() if use_recompute else None h = hidden_states for i, layer in enumerate(layers): @@ -640,7 +640,7 @@ def _run_forward_backward( if manager is not None and is_last_in_block: manager.discard_all_outputs_and_register_unified_recompute(h) if i < num_layers - 1: - manager = CheckpointManager() + manager = MHCRecomputeManager() loss = h.sum() loss.backward() @@ -864,9 +864,9 @@ def test_cuda_graph_fwd_bwd_with_hyper_connection(self): ) def test_cuda_graph_fwd_bwd_with_hyper_connection_and_recompute(self): - """CUDA graph capture+replay for fwd+bwd with mHC and CheckpointManager. + """CUDA graph capture+replay for fwd+bwd with mHC and MHCRecomputeManager. - When a CheckpointManager is used, additional CheckpointWithoutOutput + When a MHCRecomputeManager is used, additional CheckpointWithoutOutput objects are created for layernorm and hyper-connection operations. The manager discards intermediate activations during forward (storage.resize_(0)) and recomputes them during backward via a unified gradient hook. 
@@ -889,7 +889,7 @@ def test_cuda_graph_fwd_bwd_with_hyper_connection_and_recompute(self): s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): - mgr = CheckpointManager() + mgr = MHCRecomputeManager() mgr.is_last_layer_in_recompute_block = True out, _ = layer( hidden_states=static_input, @@ -903,7 +903,7 @@ def test_cuda_graph_fwd_bwd_with_hyper_connection_and_recompute(self): layer.zero_grad(set_to_none=True) static_input.grad = None - capture_mgr = CheckpointManager() + capture_mgr = MHCRecomputeManager() capture_mgr.is_last_layer_in_recompute_block = True g = torch.cuda.CUDAGraph() @@ -941,7 +941,7 @@ def test_cuda_graph_fwd_bwd_with_hyper_connection_and_recompute(self): graph_out = output.detach().clone() graph_grad = static_input.grad.detach().clone() - eager_mgr = CheckpointManager() + eager_mgr = MHCRecomputeManager() eager_mgr.is_last_layer_in_recompute_block = True eager_input = test_data.clone().requires_grad_(True) eager_output, _ = layer( @@ -966,7 +966,7 @@ def test_mcore_cudagraph_manager_with_mhc_recompute_manager(self): When cuda_graph_impl="local" is set, TransformerLayer.__call__ routes through MegatronModule.__call__ → CudaGraphManager.__call__, which - iterates over all kwargs to check supported types. CheckpointManager + iterates over all kwargs to check supported types. MHCRecomputeManager (used by mhc_recompute_manager) is not a CUDA-graph-supported type. 
This test verifies that mhc_recompute_manager is properly extracted @@ -989,7 +989,7 @@ def test_mcore_cudagraph_manager_with_mhc_recompute_manager(self): ) attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda') - mgr = CheckpointManager() + mgr = MHCRecomputeManager() mgr.is_last_layer_in_recompute_block = True output, context = layer( From b03220133e47538133fb5f6195f7b7bba0ccfa5b Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Mon, 27 Apr 2026 17:29:15 -0700 Subject: [PATCH 09/10] Address Claude mHC low-risk review items --- megatron/core/tensor_parallel/random.py | 4 - megatron/core/transformer/hyper_connection.py | 98 +++++++++++-------- .../core/transformer/transformer_config.py | 13 ++- .../core/transformer/transformer_layer.py | 28 +++--- .../test_hyper_connection_recompute.py | 17 ++++ 5 files changed, 96 insertions(+), 64 deletions(-) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index d4c05bf3cae..e0c243b26d2 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -804,10 +804,6 @@ def _unified_recompute_hook(self, grad_output): ckpt._recompute(None) -# Backward-compatible alias for earlier internal users of the generic name. -CheckpointManager = MHCRecomputeManager - - class CheckpointWithoutOutput(object): """ Checkpoint a model or part of the model and release the output. diff --git a/megatron/core/transformer/hyper_connection.py b/megatron/core/transformer/hyper_connection.py index e5b611371dd..a2f2d013977 100644 --- a/megatron/core/transformer/hyper_connection.py +++ b/megatron/core/transformer/hyper_connection.py @@ -183,11 +183,11 @@ def _init_weights(self) -> None: # This is required because HyperConnectionModule uses non-TP-aware layers # (nn.Linear, nn.RMSNorm) whose gradients need to be all-reduced. 
if self.config.sequence_parallel: - setattr(self.mapping_proj.weight, 'sequence_parallel', True) - setattr(self.alpha_pre, 'sequence_parallel', True) - setattr(self.alpha_post, 'sequence_parallel', True) - setattr(self.alpha_res, 'sequence_parallel', True) - setattr(self.bias, 'sequence_parallel', True) + self.mapping_proj.weight.sequence_parallel = True + self.alpha_pre.sequence_parallel = True + self.alpha_post.sequence_parallel = True + self.alpha_res.sequence_parallel = True + self.bias.sequence_parallel = True def _projection_and_get_norm(self, x: Tensor) -> Tuple[Tensor, Tensor]: """ @@ -259,16 +259,14 @@ def compute_mappings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: return h_pre, h_post, h_res @torch.compile - def _apply_h_post(self, x: Tensor, h_post: Tensor) -> Tensor: + def _apply_h_post_hidden(self, x: Tensor, h_post: Tensor) -> Tensor: """ - Core implementation of H_post application to a single tensor. + Apply H_post to hidden states. Computes: H_post^T @ x Args: - x: Input tensor, can be either: - - [s, b, C] - standard hidden states - - [C] - bias tensor (will be broadcast) + x: [s, b, C] - standard hidden states h_post: [s, b, n] - expansion weights Returns: @@ -276,21 +274,34 @@ def _apply_h_post(self, x: Tensor, h_post: Tensor) -> Tensor: """ n = self.n s, b, _ = h_post.shape - - if x.dim() == 1: - # x is bias with shape [C], need to broadcast to [s, b, 1, C] - C = x.shape[0] - x_expanded = x.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(s, b, 1, C) - else: - # x is [s, b, C] - C = x.shape[-1] - x_expanded = x.unsqueeze(2) # [s, b, 1, C] + C = x.shape[-1] + x_expanded = x.unsqueeze(2) # [s, b, 1, C] # h_post^T @ x : [s, b, n, 1] * [s, b, 1, C] -> [s, b, n, C] - # Using broadcast multiply instead of einsum result = h_post.unsqueeze(-1) * x_expanded return result.view(s, b, n * C) + @torch.compile + def _apply_h_post_bias(self, bias: Tensor, h_post: Tensor) -> Tensor: + """ + Apply H_post to a bias vector. 
+ + Args: + bias: [C] - bias tensor broadcast across sequence and batch + h_post: [s, b, n] - expansion weights + + Returns: + output: [s, b, n*C] - expanded bias + """ + n = self.n + s, b, _ = h_post.shape + C = bias.shape[0] + bias_expanded = bias.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(s, b, 1, C) + + # h_post^T @ x : [s, b, n, 1] * [s, b, 1, C] -> [s, b, n, C] + result = h_post.unsqueeze(-1) * bias_expanded + return result.view(s, b, n * C) + @nvtx_decorator(message="HyperConnection::apply_h_post") def apply_h_post( self, @@ -322,22 +333,22 @@ def apply_h_post( if manager is not None: from megatron.core.tensor_parallel.random import CheckpointWithoutOutput - # Checkpoint _apply_h_post to discard the output + # Checkpoint H_post application to discard the output x_out = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint( - self._apply_h_post, x, h_post + self._apply_h_post_hidden, x, h_post ) - # Checkpoint _apply_h_post for bias if not None + # Checkpoint H_post bias expansion if not None if bias is not None: bias_out = CheckpointWithoutOutput(ckpt_manager=manager).checkpoint( - self._apply_h_post, bias, h_post + self._apply_h_post_bias, bias, h_post ) else: bias_out = None else: # Normal execution without checkpoint - x_out = self._apply_h_post(x, h_post) - bias_out = self._apply_h_post(bias, h_post) if bias is not None else None + x_out = self._apply_h_post_hidden(x, h_post) + bias_out = self._apply_h_post_bias(bias, h_post) if bias is not None else None return x_out, bias_out @@ -433,7 +444,7 @@ def _forward_with_checkpoint( compute_mappings is called directly (not checkpointed) since its outputs (h_pre, h_post, h_res) are needed downstream. Only aggregate is wrapped with CheckpointWithoutOutput and auto-registered to the manager. - apply_h_res is deferred to fused_h_res_h_post_bda for kernel fusion. + apply_h_res is deferred to h_res_h_post_bda for kernel fusion. 
Args: hidden_states: [s, b, n*C] - n-stream hidden states @@ -498,10 +509,10 @@ def output_contract(x: Tensor, n: int) -> Tensor: contracted = x_streams.mean(dim=2) return contracted - # ==================== Fused kernel placeholder ==================== + # ==================== Combined H_res + H_post + BDA path ==================== - @nvtx_decorator(message="HyperConnection::fused_h_res_h_post_bda") - def fused_h_res_h_post_bda( + @nvtx_decorator(message="HyperConnection::h_res_h_post_bda") + def h_res_h_post_bda( self, h_res: Tensor, original_residual: Tensor, @@ -513,10 +524,11 @@ def fused_h_res_h_post_bda( manager: Optional['MHCRecomputeManager'] = None, ) -> Tensor: """ - Fused kernel combining apply_h_res, apply_h_post and bias-dropout-add. + Combine apply_h_res, apply_h_post and bias-dropout-add. - This is a placeholder for future kernel fusion optimization. - Currently implements the operations sequentially using native PyTorch. + This is a reference implementation that uses native PyTorch for the + dropout path. Actual fused kernels are selected through _h_post_bda_op + when dropout is disabled or training is off. The computation flow is: 1. 
mixed = H_res @ original_residual (apply_h_res) @@ -540,7 +552,7 @@ def fused_h_res_h_post_bda( output: [s, b, n*C] - final output after all operations """ if manager is not None: - return self._fused_h_res_h_post_bda_with_checkpoint( + return self._h_res_h_post_bda_with_checkpoint( h_res, original_residual, h_post, @@ -551,7 +563,7 @@ def fused_h_res_h_post_bda( manager, ) else: - return self._fused_h_res_h_post_bda_native( + return self._h_res_h_post_bda_native( h_res, original_residual, h_post, @@ -561,7 +573,7 @@ def fused_h_res_h_post_bda( fused, ) - def _fused_h_res_h_post_bda_native( + def _h_res_h_post_bda_native( self, h_res: Tensor, original_residual: Tensor, @@ -604,15 +616,15 @@ def _fused_h_res_h_post_bda_native( with torch.cuda.nvtx.range("HyperConnection::apply_h_res"): mixed = self.apply_h_res(h_res, original_residual) with torch.cuda.nvtx.range("HyperConnection::apply_h_post"): - x_expanded = self._apply_h_post(x, h_post) - bias_expanded = self._apply_h_post(bias, h_post) if bias is not None else None + x_expanded = self._apply_h_post_hidden(x, h_post) + bias_expanded = self._apply_h_post_bias(bias, h_post) if bias is not None else None bda_func = get_bias_dropout_add(training, fused) with torch.cuda.nvtx.range("HyperConnection::bda"): output = bda_func((x_expanded, bias_expanded), mixed, dropout_prob) return output - @nvtx_decorator(message="HyperConnection::fused_h_res_h_post_bda_with_checkpoint") - def _fused_h_res_h_post_bda_with_checkpoint( + @nvtx_decorator(message="HyperConnection::h_res_h_post_bda_with_checkpoint") + def _h_res_h_post_bda_with_checkpoint( self, h_res: Tensor, original_residual: Tensor, @@ -624,7 +636,7 @@ def _fused_h_res_h_post_bda_with_checkpoint( manager: 'MHCRecomputeManager', ) -> Tensor: """ - Checkpointed variant of _fused_h_res_h_post_bda_native. + Checkpointed variant of _h_res_h_post_bda_native. Wraps compute in CheckpointWithoutOutput for activation memory savings. 
Cannot reuse _native directly because checkpoint requires all args to be @@ -677,9 +689,9 @@ def _reference_wrapper(h_res, original_residual, h_post, x, *optional_bias): with torch.cuda.nvtx.range("HyperConnection::apply_h_res"): mixed = self.apply_h_res(h_res, original_residual) with torch.cuda.nvtx.range("HyperConnection::apply_h_post"): - x_expanded = self._apply_h_post(x, h_post) + x_expanded = self._apply_h_post_hidden(x, h_post) if has_bias: - bias_expanded = self._apply_h_post(optional_bias[0], h_post) + bias_expanded = self._apply_h_post_bias(optional_bias[0], h_post) else: bias_expanded = None with torch.cuda.nvtx.range("HyperConnection::bda"): diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 11e6af6d235..ea4c4ccd15e 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -882,7 +882,12 @@ class TransformerConfig(ModelParallelConfig): """Enable mHC residual connections.""" num_residual_streams: int = 4 - """Number of residual streams (n in paper).""" + """Number of residual streams (n in paper). + + Within each hyper-connection transformer block, hidden states are expanded + from [s, b, C] to [s, b, n*C], so activation memory in the block scales + roughly linearly with this value. + """ mhc_sinkhorn_iterations: int = 20 """Number of Sinkhorn-Knopp iterations for doubly stochastic projection.""" @@ -1523,6 +1528,12 @@ def __post_init__(self): and self.recompute_granularity is not None and not (self.recompute_granularity == "selective" and "mhc" in self.recompute_modules) ): + if self.recompute_granularity == "full": + raise ValueError( + "enable_hyper_connections is not yet compatible with full activation " + "recompute. Use selective recompute with 'mhc' in recompute_modules " + "or disable activation recompute." + ) warnings.warn( "HyperConnections are enabled but 'mhc' is not in " "recompute_modules with selective recompute. 
Consider adding 'mhc' to " diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 06bbea74779..c4ffa0734ac 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -215,16 +215,21 @@ class TransformerLayerSubmodules: Args: input_layernorm: Specification for the input layer normalization. + self_attention_hyper_connection: Specification for the hyper-connection module + before/after self-attention. self_attention (Union[ModuleSpec, type]): Specification for the self-attention mechanism. self_attn_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation after self-attention. pre_cross_attn_layernorm: Specification for the layer normalization before cross-attention. + cross_attention_hyper_connection: Reserved for future cross-attention + hyper-connection support and must remain IdentityOp for now. cross_attention (Union[ModuleSpec, type]): Specification for the cross-attention mechanism. cross_attn_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation after cross-attention. pre_mlp_layernorm: Specification for the layer normalization before the MLP. + mlp_hyper_connection: Specification for the hyper-connection module before/after the MLP. mlp (Union[ModuleSpec, type]): Specification for the MLP in Dense layer. mlp_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation after the MLP. @@ -238,7 +243,7 @@ class TransformerLayerSubmodules: self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp pre_cross_attn_layernorm: LayerNormBuilder = IdentityOp - # Reserved for future cross-attention hyper-connection support. + # TODO: wire this when cross-attention hyper-connection support is added. 
cross_attention_hyper_connection: Union[ModuleSpec, type] = IdentityOp cross_attention: Union[ModuleSpec, type] = IdentityOp cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp @@ -1291,15 +1296,6 @@ def _should_call_local_cudagraph(self, *args, **kwargs): return True return False - def backward_dw_cudagraph(self, microbatch_idx): - """ - CUDA Graph backward weight gradient computation for this layer. - """ - cg_index = microbatch_idx % len(self.cuda_graphs) - if not hasattr(self.cuda_graphs[cg_index], 'backward_dw'): - return - self.cuda_graphs[cg_index].backward_dw() - def __call__(self, *args, **kwargs): if self._should_call_local_cudagraph(*args, **kwargs): # Inference mode. @@ -1548,9 +1544,9 @@ def _forward_attention( attention_output_with_bias[0] ) - nvtx_range_push(suffix="self_attention_fused_h_res_h_post_bda") + nvtx_range_push(suffix="self_attention_h_res_h_post_bda") with self.bias_dropout_add_exec_handler(): - hidden_states = self.self_attention_hyper_connection.fused_h_res_h_post_bda( + hidden_states = self.self_attention_hyper_connection.h_res_h_post_bda( self_attn_h_res, residual, self_attn_hc_h_post, @@ -1560,7 +1556,7 @@ def _forward_attention( self.config.bias_dropout_fusion, mhc_recompute_manager, ) - nvtx_range_pop(suffix="self_attention_fused_h_res_h_post_bda") + nvtx_range_pop(suffix="self_attention_h_res_h_post_bda") if self.offload_attn_norm: hidden_states = off_interface.group_commit(hidden_states, name="attn_norm") @@ -1727,9 +1723,9 @@ def _forward_post_mlp_with_fused_hyper_connection( mlp_output_with_bias[0] ) - nvtx_range_push(suffix="mlp_fused_h_res_h_post_bda") + nvtx_range_push(suffix="mlp_h_res_h_post_bda") with self.bias_dropout_add_exec_handler(): - hidden_states = self.mlp_hyper_connection.fused_h_res_h_post_bda( + hidden_states = self.mlp_hyper_connection.h_res_h_post_bda( mlp_h_res, residual, mlp_hc_h_post, @@ -1739,7 +1735,7 @@ def _forward_post_mlp_with_fused_hyper_connection( self.config.bias_dropout_fusion, 
mhc_mlp_bda_recompute_manager, ) - nvtx_range_pop(suffix="mlp_fused_h_res_h_post_bda") + nvtx_range_pop(suffix="mlp_h_res_h_post_bda") if self.offload_mlp_norm: from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( diff --git a/tests/unit_tests/transformer/test_hyper_connection_recompute.py b/tests/unit_tests/transformer/test_hyper_connection_recompute.py index a82bc92075c..0b599f990a4 100644 --- a/tests/unit_tests/transformer/test_hyper_connection_recompute.py +++ b/tests/unit_tests/transformer/test_hyper_connection_recompute.py @@ -495,6 +495,23 @@ def test_hyper_connection_recompute_warning_for_selective_without_mhc(self): recompute_granularity="selective", ) + def test_config_rejects_full_recompute_hyper_connections(self): + """Full activation recompute is not wired for hyper-connection blocks yet.""" + with pytest.raises( + ValueError, + match="enable_hyper_connections is not yet compatible with full activation recompute", + ): + TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + enable_hyper_connections=True, + num_residual_streams=4, + recompute_granularity="full", + recompute_method="block", + recompute_num_layers=1, + ) + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 05c434389448398f8036b0eb126fb257eef2d64b Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Wed, 29 Apr 2026 11:24:49 -0700 Subject: [PATCH 10/10] Add mHC support for HybridModel on dsv4 Adds initial mHC support for `HybridModel` / `HybridStack` via a layer-boundary wrapper that treats each hybrid layer as a single function. Stacks on the transformer mHC reference impl in #4483. Implementation - `HyperConnectionHybridLayer`: wraps an inner hybrid layer (Mamba / GDN / TransformerLayer / DSA / MoE / MLP), aggregates the n-stream input down to single-stream, runs the inner layer, then feeds the layer delta `f(aggregated)` back through the n-stream H_res / H_post BDA so the wrapped layer's own residual update is not double-counted. 
- `HybridStack`: expands at the first-process boundary (`HyperConnectionModule.input_expand`) and contracts at the final layernorm (`output_contract`). Plumbs `mhc_recompute_manager` through the per-layer forward; caches the deterministic block-end plan via `_compute_mhc_block_end_plan`. Strict-review fixes carried over from the original PR - shape-preservation assert on `layer_output.shape == aggregated.shape` so future inner layer types that drop the residual contract fail loud. - `aggregated -> layer_output.dtype` upcast before the delta subtraction when `fp32_residual_connection=True`. - `training=self.training` (not hard-coded `False`) on the BDA call. - explicit downcast of the BDA result to `params_dtype` after `h_res_h_post_bda` when `fp32_residual_connection=True`, so fp32 n-stream hidden states do not silently propagate to subsequent layers (~2x activation memory). - non-transformer inner-layer branch comment names the rotary_pos_emb / sequence_len_offset / padding_mask args that are intentionally dropped. Tests - `test_hybrid_model.py`: dummy HybridModel + Mamba + attention + MLP + GDN + DSA + DeepSeek-style proxy patterns. - `test_dsa_gpt_mamba_equivalence.py`: mamba <-> hybrid wrapping parity check. - `test_hybrid_block.py`: forward/backward + recompute coverage. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- megatron/core/models/hybrid/hybrid_block.py | 258 +++++++++++++++++- .../models/test_dsa_gpt_mamba_equivalence.py | 50 ++++ tests/unit_tests/models/test_hybrid_model.py | 149 +++++++++- tests/unit_tests/ssm/test_hybrid_block.py | 173 +++++++++++- 4 files changed, 622 insertions(+), 8 deletions(-) diff --git a/megatron/core/models/hybrid/hybrid_block.py b/megatron/core/models/hybrid/hybrid_block.py index 5494d531e52..6d991dad786 100644 --- a/megatron/core/models/hybrid/hybrid_block.py +++ b/megatron/core/models/hybrid/hybrid_block.py @@ -7,7 +7,7 @@ from contextlib import nullcontext from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import Tensor, nn @@ -22,8 +22,10 @@ from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols as LayerSymbols from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel.random import CheckpointManager from megatron.core.transformer import TransformerConfig from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.hyper_connection import HyperConnectionModule from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module @@ -47,6 +49,158 @@ class HybridStackSubmodules: mtp_block_spec: Optional[ModuleSpec] = None +class HyperConnectionHybridLayer(MegatronModule): + """Layer-boundary mHC wrapper for HybridStack layers. + + Hybrid layers already own their local residual paths. 
For this initial + integration we treat each hybrid layer as a single function by aggregating + n streams to the layer input, running the existing layer, and feeding only + the layer delta back through mHC expansion. The expansion path intentionally + uses zero additional dropout because the wrapped hybrid layer has already + applied its local dropout/residual update before the delta is computed. + + Checkpoint compatibility: this is a *wrapper* (the inner layer is held as + `self.inner_layer`), so wrapped-layer state_dict keys are nested under + `inner_layer.` (e.g. `layers.0.inner_layer.input_layernorm.weight` instead + of `layers.0.input_layernorm.weight`). HybridStack checkpoints saved with + `enable_hyper_connections=False` cannot be loaded into a model with + `enable_hyper_connections=True` (and vice versa) without a key-mapping + migration. Note: this differs from `HyperConnectionTransformerLayer`, + which subclasses `TransformerLayer` and only adds new sibling fields, + keeping all base keys stable. 
+ """ + + def __init__(self, config: TransformerConfig, layer: MegatronModule) -> None: + super().__init__(config=config) + self.inner_layer = layer + self.layer_number = layer.layer_number + self.hyper_connection = HyperConnectionModule(config=config, layer_number=self.layer_number) + if config.params_dtype is not None: + self.hyper_connection.to(dtype=config.params_dtype) + if hasattr(layer, 'tp_group'): + self.tp_group = layer.tp_group + + def mamba_state_shapes_per_request(self) -> Optional[Tuple[Tuple[int], Tuple[int]]]: + """Delegate Mamba inference state shape requests to the wrapped layer.""" + if not hasattr(self.inner_layer, 'mamba_state_shapes_per_request'): + return None + return self.inner_layer.mamba_state_shapes_per_request() + + def _call_inner_layer( + self, + hidden_states: Tensor, + attention_mask: Tensor, + inference_context: Optional[BaseInferenceContext], + rotary_pos_emb: Optional[Tensor], + sequence_len_offset: Optional[Tensor], + packed_seq_params: Optional[PackedSeqParams], + padding_mask: Optional[Tensor], + ) -> Tuple[Tensor, Optional[Tensor]]: + if isinstance(self.inner_layer, TransformerLayer): + output = self.inner_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + sequence_len_offset=sequence_len_offset, + packed_seq_params=packed_seq_params, + padding_mask=padding_mask, + _called_from_hybrid_mhc_wrapper=True, + ) + else: + # Non-transformer layers (e.g. MambaLayer; GatedDeltaNet which does + # accept `sequence_len_offset` is currently always wrapped inside a + # TransformerLayer spec, so it takes the branch above) do not accept + # rotary_pos_emb / sequence_len_offset / padding_mask — pass only + # the common arguments. New layer types that consume any of these + # must add explicit handling here. 
+ output = self.inner_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + ) + + if isinstance(output, tuple): + context = output[1] if len(output) > 1 else None + return output[0], context + return output, None + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Tensor] = None, + sequence_len_offset: Optional[Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + padding_mask: Optional[Tensor] = None, + mhc_recompute_manager=None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Run the wrapped hybrid layer through one layer-boundary mHC update.""" + residual = hidden_states + aggregated, h_res, h_post = self.hyper_connection( + hidden_states, mhc_recompute_manager=mhc_recompute_manager + ) + layer_output, context = self._call_inner_layer( + aggregated, + attention_mask, + inference_context, + rotary_pos_emb, + sequence_len_offset, + packed_seq_params, + padding_mask, + ) + # The inner hybrid layer already applied its own local residual/dropout, so + # it returns `aggregated + f(aggregated)`. We feed only the function + # delta `f(aggregated)` into the n-stream BDA so it does not double-count + # the residual that mHC owns. The temporary [s, b, C] tensor here is the + # simplest correct form; a future optimization could fuse the subtraction + # into `h_res_h_post_bda` to avoid the allocation. + # Sanity check: this contract requires the inner layer to preserve shape; + # any mismatch indicates a future layer type is breaking the residual + # assumption and would silently corrupt the n-stream state. + assert layer_output.shape == aggregated.shape, ( + "HyperConnectionHybridLayer requires inner layers to preserve " + f"hidden-state shape. 
Got {tuple(layer_output.shape)} from inner layer " + f"vs {tuple(aggregated.shape)} input — layer must add its own residual." + ) + # `fp32_residual_connection=True` may cause some inner layers (e.g., + # MambaLayer) to return `layer_output` in fp32 while `aggregated` is in + # compute dtype; explicitly upcast `aggregated` so the subtraction stays + # in fp32 instead of relying on PyTorch's implicit promotion. + if self.config.fp32_residual_connection and aggregated.dtype != layer_output.dtype: + aggregated = aggregated.to(layer_output.dtype) + layer_delta = layer_output - aggregated + # `dropout_prob=0.0` already disables dropout regardless of training mode; + # `training=self.training` is more semantically accurate than hard-coding + # False during a training-mode forward. + hidden_states = self.hyper_connection.h_res_h_post_bda( + h_res, + residual, + h_post, + (layer_delta, None), + dropout_prob=0.0, + training=self.training, + fused=False, + manager=mhc_recompute_manager, + ) + # In `HyperConnectionTransformerLayer` the n-stream output stays in compute + # dtype because the post-attention `x` is in compute dtype. In the hybrid + # wrapper, `layer_delta` may be fp32 (when `fp32_residual_connection=True` + # or an inner layer upcasts), so `h_post_bda`'s `output.to(x.dtype)` would + # leave the result in fp32 and silently propagate fp32 n-stream hidden + # states to every subsequent layer (~2x activation memory). Restore the + # compute-dtype contract here. + if ( + self.config.fp32_residual_connection + and self.config.params_dtype is not None + and hidden_states.dtype != self.config.params_dtype + ): + hidden_states = hidden_states.to(self.config.params_dtype) + return hidden_states, context + + class HybridStack(GraphableMegatronModule, MegatronModule): """ Constructor for the HybridStack class. 
@@ -101,6 +255,10 @@ def __init__( self.input_tensor = None self.pg_collection = pg_collection + # Lazily populated mHC recompute layout cache (deterministic from config + # and num_layers); see `_build_mhc_recompute_layer_plan`. + self._mhc_block_end_plan: Optional[List[bool]] = None + assert layer_type_list is not None, ( "layer_type_list must be provided. It should be pre-computed from " "--hybrid-layer-pattern by HybridModel." @@ -173,6 +331,8 @@ def __init__( ) else: raise ValueError("unexpected layer_type") + if self.config.enable_hyper_connections: + layer = HyperConnectionHybridLayer(config=self.config, layer=layer) self.layers.append(layer) # Required for activation recomputation @@ -239,6 +399,59 @@ def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs)[0] return super().__call__(*args, **kwargs) + def _compute_mhc_block_end_plan(self) -> List[bool]: + """Compute per-layer block-end markers (deterministic from config).""" + num_layers = len(self.layers) + is_recompute_block_end: List[bool] = [False] * num_layers + if num_layers == 0: + return is_recompute_block_end + mhc_recompute_layer_num = self.config.mhc_recompute_layer_num + for l_no in range(num_layers): + is_last_in_stack = l_no == num_layers - 1 + is_last_in_recompute_block = is_last_in_stack + if mhc_recompute_layer_num is not None: + is_last_in_recompute_block = is_last_in_stack or ( + (l_no + 1) % mhc_recompute_layer_num == 0 + ) + is_recompute_block_end[l_no] = is_last_in_recompute_block + return is_recompute_block_end + + def _build_mhc_recompute_layer_plan( + self, use_mhc_recompute: bool + ) -> Tuple[List[Optional[CheckpointManager]], List[bool]]: + """Pre-build per-layer MHC recompute managers and block-end markers. + + The block-end plan is deterministic from config and cached on the + instance; only the per-block ``CheckpointManager`` instances are + allocated fresh per forward pass (managers are single-use). 
Mirrors + the caching scheme used by ``TransformerBlock``. + """ + num_layers = len(self.layers) + if not use_mhc_recompute or num_layers == 0: + return [None] * num_layers, [False] * num_layers + + if self._mhc_block_end_plan is None: + self._mhc_block_end_plan = self._compute_mhc_block_end_plan() + is_recompute_block_end = self._mhc_block_end_plan + + layer_managers: List[Optional[CheckpointManager]] = [None] * num_layers + mhc_manager = CheckpointManager() + for l_no in range(num_layers): + layer_managers[l_no] = mhc_manager + if is_recompute_block_end[l_no] and l_no != num_layers - 1: + mhc_manager = CheckpointManager() + return layer_managers, is_recompute_block_end + + @staticmethod + def _finalize_mhc_recompute_layer( + mhc_manager: Optional[CheckpointManager], + hidden_states: Tensor, + is_last_in_recompute_block: bool, + ) -> None: + """Finalize MHC recompute state for the current layer when a block ends.""" + if mhc_manager is not None and is_last_in_recompute_block: + mhc_manager.discard_all_outputs_and_register_unified_recompute(hidden_states) + def forward( self, hidden_states: Union[Tensor, WrappedTensor], @@ -278,6 +491,11 @@ def forward( if isinstance(hidden_states, WrappedTensor): hidden_states = hidden_states.unwrap() + if self.config.enable_hyper_connections and self.pre_process: + hidden_states = HyperConnectionModule.input_expand( + hidden_states, self.config.num_residual_streams + ) + if inference_context and inference_context.is_static_batching(): # NOTE(bnorick): match BaseInferenceContext attributes for # mamba_ssm.utils.generation.BaseInferenceContext, @@ -331,13 +549,29 @@ def get_inner_quant_context(config, layer_number): def get_inner_quant_context(config, layer_number): return nullcontext() + use_mhc_recompute = ( + self.training + and self.config.enable_hyper_connections + and self.config.recompute_granularity == 'selective' + and "mhc" in self.config.recompute_modules + ) + mhc_layer_managers, mhc_is_last_in_recompute_block = 
self._build_mhc_recompute_layer_plan( + use_mhc_recompute + ) + with outer_fp8_context: - for layer in self.layers: + for l_no, layer in enumerate(self.layers): # Layers have 1-indexed layer numbers attribute. inner_quant_context = get_inner_quant_context(self.config, layer.layer_number - 1) + mhc_manager = mhc_layer_managers[l_no] + if mhc_manager is not None: + mhc_manager.is_last_layer_in_recompute_block = ( + mhc_is_last_in_recompute_block[l_no] + ) + with inner_quant_context: - if isinstance(layer, TransformerLayer): - hidden_states, _ = layer( + if isinstance(layer, (TransformerLayer, HyperConnectionHybridLayer)): + layer_kwargs = dict( hidden_states=hidden_states, attention_mask=attention_mask, inference_context=inference_context, @@ -346,6 +580,11 @@ def get_inner_quant_context(config, layer_number): packed_seq_params=packed_seq_params, padding_mask=padding_mask, ) + if mhc_manager is not None and isinstance( + layer, HyperConnectionHybridLayer + ): + layer_kwargs["mhc_recompute_manager"] = mhc_manager + hidden_states, _ = layer(**layer_kwargs) else: # MambaLayer, Expert, or MLP hidden_states = layer( hidden_states=hidden_states, @@ -360,6 +599,17 @@ def get_inner_quant_context(config, layer_number): if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] + self._finalize_mhc_recompute_layer( + mhc_manager=mhc_manager, + hidden_states=hidden_states, + is_last_in_recompute_block=mhc_is_last_in_recompute_block[l_no], + ) + + if self.config.enable_hyper_connections and self.post_process: + hidden_states = HyperConnectionModule.output_contract( + hidden_states, self.config.num_residual_streams + ) + # Final layer norm. 
if self.post_process and self.post_layer_norm: hidden_states = self.final_norm(hidden_states) diff --git a/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py b/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py index 229af268a79..9255e4794d5 100644 --- a/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py +++ b/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py @@ -35,6 +35,7 @@ get_transformer_block_with_experimental_attention_variant_spec, ) from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.hybrid.hybrid_block import HyperConnectionHybridLayer from megatron.core.models.hybrid.hybrid_layer_allocation import validate_segment_layers from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec from megatron.core.models.hybrid.hybrid_model import HybridModel @@ -642,3 +643,52 @@ def test_moe_record_and_compare_golden_values(self, tp: int, pp: int) -> None: # Verify HybridModel matches golden values _compare_against_golden_values(mamba_logprobs, gpt_logprobs, abs_tol=1e-3) + + +# --------------------------------------------------------------------------- +# mHC HybridModel smoke tests for DeepSeek proxy patterns +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +class TestDSAHybridMHCProxy: + """Smoke-test mHC on DeepSeek-style HybridModel patterns. + + These do not assert GPT/Hybrid numerical equivalence because the current + HybridModel implementation wraps each split hybrid layer at the boundary, + whereas GPT mHC has separate attention and MLP hyper-connections inside a + TransformerLayer. 
+ """ + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _enable_mhc(self, config: MLATransformerConfig) -> MLATransformerConfig: + config.enable_hyper_connections = True + config.num_residual_streams = 4 + config.mhc_sinkhorn_iterations = 5 + config.mhc_init_gating_factor = 0.01 + config.hidden_dropout = 0.0 + return config + + def _assert_mhc_model_forward(self, config: MLATransformerConfig, pattern: str) -> None: + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(42) + model = _build_mamba_model(self._enable_mhc(config), pattern) + assert all(isinstance(layer, HyperConnectionHybridLayer) for layer in model.decoder.layers) + + torch.manual_seed(99) + tokens = torch.randint(0, _VOCAB_SIZE, (_BATCH_SIZE, _SEQ_LEN), device='cuda') + logprobs = _forward_logprobs_pp1(model, tokens) + assert logprobs.shape == (_BATCH_SIZE, _SEQ_LEN - 1) + assert torch.isfinite(logprobs).all() + + def test_dsa_dense_hybrid_mhc_forward(self) -> None: + """DeepSeek-V3.2-style DSA + MLP split pattern runs with mHC.""" + config = _make_dsa_config(num_layers=_NUM_GPT_LAYERS, tp=1, pp=1) + self._assert_mhc_model_forward(config, _MAMBA_PATTERN) + + def test_dsa_moe_hybrid_mhc_forward(self) -> None: + """DeepSeek-V3-style DSA + MoE split pattern runs with mHC.""" + config = _make_dsa_moe_config(num_layers=_NUM_GPT_LAYERS, tp=1, pp=1) + self._assert_mhc_model_forward(config, _MOE_MAMBA_PATTERN) diff --git a/tests/unit_tests/models/test_hybrid_model.py b/tests/unit_tests/models/test_hybrid_model.py index 98a53da0314..0d214605f47 100644 --- a/tests/unit_tests/models/test_hybrid_model.py +++ b/tests/unit_tests/models/test_hybrid_model.py @@ -16,17 +16,61 @@ from megatron.core.inference.inference_request import DynamicInferenceRequest from megatron.core.inference.sampling_params import SamplingParams from megatron.core.models.common.embeddings.yarn_rotary_pos_embedding import YarnRotaryEmbedding +from megatron.core.models.hybrid.hybrid_block import 
( + HybridStack, + HybridStackSubmodules, + HyperConnectionHybridLayer, +) from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec from megatron.core.models.hybrid.hybrid_model import HybridModel from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer import TransformerConfig from megatron.core.transformer.enums import AttnBackend -from megatron.core.transformer.module import Float16Module +from megatron.core.transformer.module import Float16Module, MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.utils import divide, is_fa_min_version, is_torch_min_version from tests.unit_tests.test_utilities import Utils +class _DummyHybridLayer(MegatronModule): + """Minimal same-shape layer used to test HybridModel/mHC plumbing.""" + + def __init__(self, config: TransformerConfig, layer_number: int, **_kwargs): + super().__init__(config=config) + self.layer_number = layer_number + self.proj = torch.nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.seen_hidden_shapes = [] + + def forward( + self, + hidden_states, + attention_mask=None, + inference_context=None, + packed_seq_params=None, + **_kwargs, + ): + self.seen_hidden_shapes.append(tuple(hidden_states.shape)) + return hidden_states + 0.125 * self.proj(hidden_states) + + +def _get_dummy_hybrid_stack_spec() -> ModuleSpec: + """Build a HybridStack spec whose layer symbols all resolve to dummy layers.""" + dummy_layer_spec = ModuleSpec(module=_DummyHybridLayer) + return ModuleSpec( + module=HybridStack, + params={"post_layer_norm": False}, + submodules=HybridStackSubmodules( + mamba_layer=dummy_layer_spec, + gdn_layer=dummy_layer_spec, + attention_layer=dummy_layer_spec, + dsa_layer=dummy_layer_spec, + mlp_layer=dummy_layer_spec, + moe_layer=dummy_layer_spec, + ), + ) + + class TestHybridModel: def setup_method(self, method): @@ 
-57,6 +101,109 @@ def test_constructor(self): num_weights = sum([p.numel() for p in self.model.parameters()]) assert num_weights == 1774872 + def test_constructor_with_hyper_connections(self): + model_config = TransformerConfig( + num_layers=3, + hidden_size=256, + num_attention_heads=4, + use_cpu_initialization=True, + enable_hyper_connections=True, + hidden_dropout=0.0, + ) + model = HybridModel( + config=model_config, + hybrid_stack_spec=hybrid_stack_spec, + vocab_size=100, + max_sequence_length=4, + hybrid_layer_pattern="M*-", + ) + + assert all(isinstance(layer, HyperConnectionHybridLayer) for layer in model.decoder.layers) + num_weights = sum([p.numel() for p in model.parameters()]) + assert num_weights > sum([p.numel() for p in self.model.parameters()]) + + def test_forward_with_hyper_connections(self): + model_config = TransformerConfig( + num_layers=3, + hidden_size=256, + num_attention_heads=4, + use_cpu_initialization=True, + enable_hyper_connections=True, + hidden_dropout=0.0, + ) + model = HybridModel( + config=model_config, + hybrid_stack_spec=hybrid_stack_spec, + vocab_size=100, + max_sequence_length=4, + hybrid_layer_pattern="M*-", + ) + model.cuda() + + sequence_length = model.max_sequence_length + micro_batch_size = 2 + data = list(range(sequence_length)) + input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + attention_mask = torch.ones( + (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool + ).cuda() + + logits = model.forward( + input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask + ) + + assert logits.shape[0] == micro_batch_size + assert logits.shape[1] == sequence_length + assert logits.shape[2] == model.vocab_size + + def test_dummy_hybrid_model_with_hyper_connections_forward_backward(self): + model_config = TransformerConfig( + num_layers=3, + hidden_size=32, + 
num_attention_heads=4, + use_cpu_initialization=True, + enable_hyper_connections=True, + hidden_dropout=0.0, + mhc_sinkhorn_iterations=3, + ) + model = HybridModel( + config=model_config, + hybrid_stack_spec=_get_dummy_hybrid_stack_spec(), + vocab_size=64, + max_sequence_length=8, + hybrid_layer_pattern="M*-", + parallel_output=False, + ) + + assert all(isinstance(layer, HyperConnectionHybridLayer) for layer in model.decoder.layers) + assert all( + isinstance(layer.inner_layer, _DummyHybridLayer) for layer in model.decoder.layers + ) + + model.cuda() + sequence_length = model.max_sequence_length + micro_batch_size = 2 + data = torch.arange(sequence_length, dtype=torch.int64, device='cuda') + input_ids = data.repeat((micro_batch_size, 1)) + position_ids = data.repeat((micro_batch_size, 1)) + + logits = model.forward(input_ids=input_ids, position_ids=position_ids, attention_mask=None) + + assert logits.shape == (micro_batch_size, sequence_length, model.vocab_size) + assert torch.isfinite(logits).all() + + logits.float().mean().backward() + + for layer in model.decoder.layers: + assert layer.inner_layer.seen_hidden_shapes == [ + (sequence_length, micro_batch_size, model_config.hidden_size) + ] + assert layer.inner_layer.proj.weight.grad is not None + assert layer.hyper_connection.mapping_proj.weight.grad is not None + assert torch.isfinite(layer.inner_layer.proj.weight.grad).all() + assert torch.isfinite(layer.hyper_connection.mapping_proj.weight.grad).all() + def test_set_input_tensor(self): config: TransformerConfig = self.model.config sequence_length = self.model.max_sequence_length diff --git a/tests/unit_tests/ssm/test_hybrid_block.py b/tests/unit_tests/ssm/test_hybrid_block.py index 08bf7f2bc28..14caa55aa0a 100644 --- a/tests/unit_tests/ssm/test_hybrid_block.py +++ b/tests/unit_tests/ssm/test_hybrid_block.py @@ -3,7 +3,7 @@ import pytest import torch -from megatron.core.models.hybrid.hybrid_block import HybridStack +from megatron.core.models.hybrid.hybrid_block 
import HybridStack, HyperConnectionHybridLayer from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols, validate_segment_layers from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec from megatron.core.process_groups_config import ProcessGroupCollection @@ -30,8 +30,13 @@ def setup_method(self, method): def get_pg_collection(self): return ProcessGroupCollection.use_mpu_process_groups(required_pgs=['tp', 'pp', 'cp']) - def get_mamba_block(self, layer_pattern): + def get_mamba_block(self, layer_pattern, enable_hyper_connections=False): layer_type_list = validate_segment_layers(layer_pattern) + mhc_kwargs = ( + {"enable_hyper_connections": True, "hidden_dropout": 0.0, "mhc_sinkhorn_iterations": 5} + if enable_hyper_connections + else {} + ) transformer_config = TransformerConfig( hidden_size=256, # The Mamba layer places several constraints on this # Need to specify num_attention_heads and num_layers or TransformerConfig @@ -39,6 +44,7 @@ def get_mamba_block(self, layer_pattern): num_layers=len(layer_type_list), num_attention_heads=4, use_cpu_initialization=True, + **mhc_kwargs, ) modules = hybrid_stack_spec.submodules return HybridStack( @@ -49,8 +55,13 @@ def get_mamba_block(self, layer_pattern): pg_collection=self.get_pg_collection(), ) - def get_dsa_mamba_block(self, layer_pattern): + def get_dsa_mamba_block(self, layer_pattern, enable_hyper_connections=False): layer_type_list = validate_segment_layers(layer_pattern) + mhc_kwargs = ( + {"enable_hyper_connections": True, "hidden_dropout": 0.0, "mhc_sinkhorn_iterations": 5} + if enable_hyper_connections + else {} + ) transformer_config = MLATransformerConfig( hidden_size=256, # The Mamba layer places several constraints on this # Need to specify num_attention_heads and num_layers or TransformerConfig @@ -71,6 +82,7 @@ def get_dsa_mamba_block(self, layer_pattern): dsa_indexer_n_heads=8, dsa_indexer_head_dim=64, dsa_indexer_topk=32, + **mhc_kwargs, ) modules = 
hybrid_stack_spec.submodules return HybridStack( @@ -118,6 +130,161 @@ def test_layer_types(self): assert isinstance(layers[2], TransformerLayer) assert isinstance(layers[2].mlp, MLP) + def test_hyper_connection_layer_wrappers(self): + """mHC wraps each hybrid layer while preserving the layer type underneath.""" + layer_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP + block = self.get_mamba_block(layer_pattern, enable_hyper_connections=True) + layers = block.layers + assert all(isinstance(layer, HyperConnectionHybridLayer) for layer in layers) + assert isinstance(layers[0].inner_layer, MambaLayer) + assert isinstance(layers[1].inner_layer, TransformerLayer) + assert isinstance(layers[1].inner_layer.self_attention, SelfAttention) + assert isinstance(layers[2].inner_layer, TransformerLayer) + assert isinstance(layers[2].inner_layer.mlp, MLP) + + def test_hyper_connection_recompute_plan_for_hybrid_layers(self): + """HybridStack creates per-layer mHC recompute managers when requested.""" + layer_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP + layer_type_list = validate_segment_layers(layer_pattern) + transformer_config = TransformerConfig( + hidden_size=256, + num_layers=len(layer_type_list), + num_attention_heads=4, + use_cpu_initialization=True, + enable_hyper_connections=True, + hidden_dropout=0.0, + mhc_sinkhorn_iterations=5, + recompute_granularity="selective", + recompute_modules=["core_attn", "mhc"], + ) + block = HybridStack( + transformer_config, + hybrid_stack_spec.submodules, + layer_type_list=layer_type_list, + pp_layer_offset=0, + pg_collection=self.get_pg_collection(), + ) + + managers, block_ends = block._build_mhc_recompute_layer_plan(use_mhc_recompute=True) + assert len(managers) == len(block.layers) + assert all(manager is not None for manager in managers) + assert block_ends[-1] is True + + def test_hyper_connection_gpu_forward(self): + """mHC-enabled HybridStack expands internally and contracts back at the output.""" + 
layer_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP + block = self.get_mamba_block(layer_pattern, enable_hyper_connections=True) + block.cuda() + micro_batch_size = 2 + sequence_length = 32 + hidden_states = torch.ones((sequence_length, micro_batch_size, block.config.hidden_size)) + hidden_states = hidden_states.cuda() + attention_mask = torch.ones( + (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool + ) + attention_mask = attention_mask.cuda() + output = block(hidden_states, attention_mask=attention_mask) + assert output.shape[0] == sequence_length + assert output.shape[1] == micro_batch_size + assert output.shape[2] == block.config.hidden_size + assert output.dtype == torch.float32 + + def test_hyper_connection_gdn_gpu_forward(self): + """mHC runs through GDN, attention, and Mamba hybrid layers.""" + layer_pattern = Symbols.GDN + Symbols.ATTENTION + Symbols.MAMBA + layer_type_list = validate_segment_layers(layer_pattern) + transformer_config = TransformerConfig( + hidden_size=256, + num_layers=len(layer_type_list), + num_attention_heads=4, + use_cpu_initialization=True, + activation_func=torch.nn.functional.silu, + enable_hyper_connections=True, + hidden_dropout=0.0, + mhc_sinkhorn_iterations=5, + ) + block = HybridStack( + transformer_config, + hybrid_stack_spec.submodules, + layer_type_list=layer_type_list, + pp_layer_offset=0, + pg_collection=self.get_pg_collection(), + ) + block.cuda() + micro_batch_size = 2 + sequence_length = 32 + hidden_states = torch.ones((sequence_length, micro_batch_size, block.config.hidden_size)) + hidden_states = hidden_states.cuda() + attention_mask = torch.ones( + (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool + ).cuda() + output = block(hidden_states, attention_mask=attention_mask) + assert output.shape == (sequence_length, micro_batch_size, block.config.hidden_size) + + def test_hyper_connection_dsa_layer_wrappers(self): + """mHC wraps DeepSeek-style DSA and MLP split layers.""" 
+ layer_pattern = Symbols.MAMBA + Symbols.DS_ATTENTION + Symbols.MLP + block = self.get_dsa_mamba_block(layer_pattern, enable_hyper_connections=True) + layers = block.layers + assert all(isinstance(layer, HyperConnectionHybridLayer) for layer in layers) + assert isinstance(layers[0].inner_layer, MambaLayer) + assert isinstance(layers[1].inner_layer, TransformerLayer) + assert isinstance(layers[1].inner_layer.self_attention, MLASelfAttention) + assert isinstance(layers[1].inner_layer.self_attention.core_attention, DSAttention) + assert isinstance(layers[2].inner_layer, TransformerLayer) + assert isinstance(layers[2].inner_layer.mlp, MLP) + + def test_hyper_connection_pipeline_boundary_shapes(self): + """HybridStack keeps n-stream tensors between PP stages and contracts at the end.""" + layer_type_list = validate_segment_layers(Symbols.MAMBA) + transformer_config = TransformerConfig( + hidden_size=256, + num_layers=len(layer_type_list), + num_attention_heads=4, + use_cpu_initialization=True, + enable_hyper_connections=True, + hidden_dropout=0.0, + mhc_sinkhorn_iterations=5, + ) + modules = hybrid_stack_spec.submodules + first_stage = HybridStack( + transformer_config, + modules, + layer_type_list=layer_type_list, + pp_layer_offset=0, + post_process=False, + pg_collection=self.get_pg_collection(), + ).cuda() + last_stage = HybridStack( + transformer_config, + modules, + pre_process=False, + layer_type_list=layer_type_list, + pp_layer_offset=1, + post_process=True, + pg_collection=self.get_pg_collection(), + ).cuda() + + micro_batch_size = 2 + sequence_length = 32 + hidden_states = torch.ones( + (sequence_length, micro_batch_size, transformer_config.hidden_size), device='cuda' + ) + attention_mask = torch.ones( + (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool, device='cuda' + ) + + pp_hidden = first_stage(hidden_states, attention_mask=attention_mask) + assert pp_hidden.shape == ( + sequence_length, + micro_batch_size, + 
transformer_config.hidden_size * transformer_config.num_residual_streams, + ) + + last_stage.set_input_tensor(pp_hidden.detach()) + output = last_stage(hidden_states, attention_mask=attention_mask) + assert output.shape == (sequence_length, micro_batch_size, transformer_config.hidden_size) + def test_invalid_layer_types_cause_failure(self): invalid_symbol = '+' assert invalid_symbol not in Symbols.VALID_LAYERS # sanity check.