2 changes: 2 additions & 0 deletions gpt_builders.py
@@ -136,6 +136,7 @@ def _get_transformer_layer_spec(use_te, config):
use_kitchen_attention=config.use_kitchen_attention,
kitchen_attention_backend=config.kitchen_attention_backend,
mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False),
enable_hyper_connections=config.enable_hyper_connections,
)
elif config.transformer_impl == "inference_optimized":
return get_gpt_layer_with_inference_spec(
@@ -154,4 +155,5 @@ def _get_transformer_layer_spec(use_te, config):
use_kitchen=config.use_kitchen,
use_kitchen_attention=config.use_kitchen_attention,
kitchen_attention_backend=config.kitchen_attention_backend,
enable_hyper_connections=config.enable_hyper_connections,
)
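
For orientation, here is a hypothetical call (not part of this PR) showing how the new flag reaches the layer-spec getters that gpt_builders.py invokes. The spec wrappers further down read the flag via kwargs.get, so it has to be passed as a keyword argument; the other arguments and their defaults are assumed from the diff.

```python
# Hypothetical sketch, not part of this PR: requesting the hyper-connection
# layer variant directly through the spec getter.
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec,
)

layer_spec = get_gpt_layer_with_transformer_engine_spec(
    num_experts=None,
    moe_grouped_gemm=False,
    enable_hyper_connections=True,  # must be a keyword: the wrapper uses kwargs.get()
)
# With the flag set, layer_spec.module is HyperConnectionTransformerLayer and the
# *_hyper_connection submodules are HyperConnectionModule; with it unset they stay
# TransformerLayer / IdentityOp.
```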
12 changes: 11 additions & 1 deletion megatron/core/fusions/fused_bias_dropout.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
from typing import Optional, Tuple

import torch
@@ -81,6 +81,16 @@ def bias_dropout_add_fused_inference(


def get_bias_dropout_add(training, fused):
"""
Get the bias-dropout-add function.

Args:
training: Whether in training mode.
fused: Whether to use fused implementation.

Returns:
A callable that performs bias-dropout-add operation.
"""
if fused:
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
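
The new docstring above describes the factory only; as a minimal sketch, the callables it returns compute the unfused form out = residual + dropout(x + bias), where the (x, bias) pair comes from the preceding linear layer and bias may be None. The exact calling convention is assumed from the surrounding module, which this hunk does not show in full.

```python
import torch.nn.functional as F

def bias_dropout_add_reference(x_with_bias, residual, prob, training):
    # Illustrative unfused reference, assuming the (x, bias) tuple convention
    # used elsewhere in fused_bias_dropout.py; not the fused kernel itself.
    x, bias = x_with_bias
    if bias is not None:
        x = x + bias
    out = F.dropout(x, p=prob, training=training)
    return residual + out
```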
@@ -12,6 +12,7 @@
DSAttention,
DSAttentionSubmodules,
)
from megatron.core.transformer.hyper_connection import HyperConnectionModule
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.multi_latent_attention import (
MLASelfAttention,
@@ -24,6 +25,7 @@
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
HyperConnectionTransformerLayer,
TransformerLayer,
TransformerLayerSubmodules,
get_transformer_layer_offset,
@@ -227,6 +229,10 @@ def get_transformer_block_with_experimental_attention_variant_spec(

# Get GPT decoder block layer specs
rms_norm = config.normalization == "RMSNorm"
enable_hc = config.enable_hyper_connections
hc_module = HyperConnectionModule if enable_hc else IdentityOp
layer_module = HyperConnectionTransformerLayer if enable_hc else TransformerLayer

layer_specs = []
for layer_number in range(config.num_layers):
attention = (
@@ -248,14 +254,16 @@

layer_specs.append(
ModuleSpec(
module=TransformerLayer,
module=layer_module,
submodules=TransformerLayerSubmodules(
input_layernorm=input_layernorm,
self_attention=attention,
self_attn_bda=get_bias_dropout_add,
self_attention_hyper_connection=hc_module,
pre_mlp_layernorm=pre_mlp_layernorm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
mlp_hyper_connection=hc_module,
),
)
)
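
The HyperConnectionModule implementation itself is not part of this diff. As a rough, assumption-laden sketch of the general hyper-connection idea it presumably follows (keeping several residual streams with learnable read, mix, and write weights instead of a single residual add), the module below is illustrative only; names, shapes, and the number of streams are guesses, not the repository's API.

```python
import torch
import torch.nn as nn

class StaticHyperConnectionSketch(nn.Module):
    # Illustrative only: a static hyper connection over n residual streams.
    # The actual HyperConnectionModule in this PR may differ substantially.
    def __init__(self, num_streams: int = 4):
        super().__init__()
        self.read = nn.Parameter(torch.full((num_streams,), 1.0 / num_streams))  # streams -> layer input
        self.width = nn.Parameter(torch.eye(num_streams))                        # stream <-> stream mixing
        self.write = nn.Parameter(torch.ones(num_streams))                       # layer output -> streams

    def read_input(self, streams: torch.Tensor) -> torch.Tensor:
        # streams: [n, seq, batch, hidden] -> single layer input [seq, batch, hidden]
        return torch.einsum("n,nsbh->sbh", self.read, streams)

    def write_output(self, streams: torch.Tensor, layer_out: torch.Tensor) -> torch.Tensor:
        # Mix the streams, then add the layer output back into each stream.
        mixed = torch.einsum("mn,nsbh->msbh", self.width, streams)
        return mixed + self.write.view(-1, 1, 1, 1) * layer_out.unsqueeze(0)
```

With enable_hyper_connections off, the spec above substitutes IdentityOp and the layer keeps the ordinary single-stream residual path.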
51 changes: 45 additions & 6 deletions megatron/core/models/gpt/gpt_layer_specs.py
@@ -1,4 +1,5 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import copy
import warnings
from typing import Optional, Union

@@ -12,6 +13,7 @@
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec_for_backend
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType, LayerType
from megatron.core.transformer.hyper_connection import HyperConnectionModule
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.multi_latent_attention import (
@@ -34,6 +36,7 @@
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
HyperConnectionTransformerLayer,
TransformerLayer,
TransformerLayerSubmodules,
get_transformer_layer_offset,
@@ -183,6 +186,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
use_kitchen_attention: bool = False,
kitchen_attention_backend: str = "sdpa",
mla_down_proj_fusion: bool = False,
enable_hyper_connections: bool = False,
) -> TransformerLayerSubmodules:
"""Use these submodules to use lower-level Transformer Engine modules (required for fp8
training).
@@ -200,6 +204,8 @@
mla_down_proj_fusion (bool, optional): Enable fused q/kv down-projection and fused input
layernorm when the backend supports it; otherwise fall
back to the unfused MLA.
enable_hyper_connections (bool): Use HyperConnectionTransformerLayer with
HyperConnectionModule instead of plain TransformerLayer. Defaults to False.

Returns:
TransformerLayerSubmodules: TE modules to construct a TransformerLayer
@@ -233,6 +239,8 @@ def get_gpt_layer_with_transformer_engine_submodules(
use_te_activation_func=use_te_activation_func,
)

hc_module = HyperConnectionModule if enable_hyper_connections else IdentityOp

if multi_latent_attention:
assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
linear_q_up_proj = (
@@ -302,9 +310,11 @@
),
),
self_attn_bda=get_bias_dropout_add,
self_attention_hyper_connection=hc_module,
pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
mlp_hyper_connection=hc_module,
)
else:
qk_norm = backend.layer_norm(for_qk=True)
@@ -325,9 +335,11 @@
),
),
self_attn_bda=get_bias_dropout_add,
self_attention_hyper_connection=hc_module,
pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
mlp_hyper_connection=hc_module,
sharded_state_dict_keys_map={
"mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
"mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
@@ -342,8 +354,10 @@
@copy_signature(get_gpt_layer_with_transformer_engine_submodules)
def get_gpt_layer_with_transformer_engine_spec(*args, **kwargs) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training)."""
enable_hc = kwargs.get('enable_hyper_connections', False)
layer_module = HyperConnectionTransformerLayer if enable_hc else TransformerLayer
return ModuleSpec(
module=TransformerLayer,
module=layer_module,
submodules=get_gpt_layer_with_transformer_engine_submodules(*args, **kwargs),
)

Expand All @@ -359,6 +373,7 @@ def get_gpt_layer_local_submodules(
use_kitchen: bool = False,
use_kitchen_attention: bool = False,
kitchen_attention_backend: str = "sdpa",
enable_hyper_connections: bool = False,
) -> TransformerLayerSubmodules:
"""Use these submodules for an implementation using only modules in Megatron-Core.

@@ -370,6 +385,8 @@
multi_latent_attention (bool, optional): To use MLA. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.
enable_hyper_connections (bool): Use HyperConnectionTransformerLayer with
HyperConnectionModule instead of plain TransformerLayer. Defaults to False.

Returns:
TransformerLayerSubmodules: Megatron-Core modules to construct a TransformerLayer
@@ -402,6 +419,8 @@ def get_gpt_layer_local_submodules(
backend=backend, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)

hc_module = HyperConnectionModule if enable_hyper_connections else IdentityOp

if multi_latent_attention:
assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
return TransformerLayerSubmodules(
@@ -422,9 +441,11 @@
),
),
self_attn_bda=get_bias_dropout_add,
self_attention_hyper_connection=hc_module,
pre_mlp_layernorm=layer_norm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
mlp_hyper_connection=hc_module,
)
else:
return TransformerLayerSubmodules(
@@ -445,9 +466,11 @@
),
),
self_attn_bda=get_bias_dropout_add,
self_attention_hyper_connection=hc_module,
pre_mlp_layernorm=layer_norm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
mlp_hyper_connection=hc_module,
sharded_state_dict_keys_map={
"input_layernorm.": "self_attention.linear_qkv.layer_norm_",
"pre_mlp_layernorm.": "mlp.linear_fc1.layer_norm_",
Expand All @@ -458,8 +481,10 @@ def get_gpt_layer_local_submodules(
@copy_signature(get_gpt_layer_local_submodules)
def get_gpt_layer_local_spec(*args, **kwargs) -> ModuleSpec:
"""Use this spec for an implementation using only modules in Megatron-Core."""
enable_hc = kwargs.get('enable_hyper_connections', False)
layer_module = HyperConnectionTransformerLayer if enable_hc else TransformerLayer
return ModuleSpec(
module=TransformerLayer, submodules=get_gpt_layer_local_submodules(*args, **kwargs)
module=layer_module, submodules=get_gpt_layer_local_submodules(*args, **kwargs)
)


@@ -568,6 +593,7 @@ def get_gpt_decoder_layer_specs(
use_kitchen_attention=config.use_kitchen_attention,
kitchen_attention_backend=config.kitchen_attention_backend,
mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False),
enable_hyper_connections=config.enable_hyper_connections,
)
moe_layer_spec = get_gpt_layer_with_transformer_engine_spec(
num_experts=config.num_moe_experts,
@@ -580,6 +606,7 @@
use_kitchen_attention=config.use_kitchen_attention,
kitchen_attention_backend=config.kitchen_attention_backend,
mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False),
enable_hyper_connections=config.enable_hyper_connections,
)
elif config.transformer_impl == "inference_optimized":
layer_norm_impl = TENorm
@@ -608,6 +635,7 @@ def get_gpt_decoder_layer_specs(
use_kitchen=config.use_kitchen,
use_kitchen_attention=config.use_kitchen_attention,
kitchen_attention_backend=config.kitchen_attention_backend,
enable_hyper_connections=config.enable_hyper_connections,
)
moe_layer_spec = get_gpt_layer_local_spec(
num_experts=config.num_moe_experts,
@@ -619,6 +647,7 @@
use_kitchen=config.use_kitchen,
use_kitchen_attention=config.use_kitchen_attention,
kitchen_attention_backend=config.kitchen_attention_backend,
enable_hyper_connections=config.enable_hyper_connections,
)

# Parse config.moe_layer_freq to determine the pattern of expert/dense layers.
@@ -744,12 +773,22 @@ def get_gpt_mtp_block_spec_for_backend(

if isinstance(spec, TransformerBlockSubmodules):
# get the spec for the last layer of decoder block
transformer_layer_spec = spec.layer_specs[-1]
elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer:
transformer_layer_spec = spec
transformer_layer_spec = copy.copy(spec.layer_specs[-1])
elif isinstance(spec, ModuleSpec) and issubclass(spec.module, TransformerLayer):
transformer_layer_spec = copy.copy(spec)
else:
raise ValueError(f"Invalid spec: {spec}")

transformer_layer_spec.submodules = copy.copy(transformer_layer_spec.submodules)

# MTP does not support hyper connections yet; strip HC modules and
# downgrade the layer class to plain TransformerLayer.
transformer_layer_spec.submodules.self_attention_hyper_connection = IdentityOp
transformer_layer_spec.submodules.cross_attention_hyper_connection = IdentityOp
transformer_layer_spec.submodules.mlp_hyper_connection = IdentityOp
if transformer_layer_spec.module is HyperConnectionTransformerLayer:
transformer_layer_spec.module = TransformerLayer

mtp_layer_spec = get_mtp_layer_spec_for_backend(
mtp_model_layer_spec=transformer_layer_spec, backend=backend
)
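
A note on the two copy.copy calls above: ModuleSpec and TransformerLayerSubmodules behave like plain dataclasses, so without the shallow copies the MTP spec would alias the decoder's last layer spec, and the IdentityOp/TransformerLayer downgrades would leak back into the decoder block. The stand-in classes below are hypothetical, only to illustrate the aliasing behavior being avoided.

```python
import copy
from dataclasses import dataclass, field

@dataclass
class FakeSubmodules:          # stand-in for TransformerLayerSubmodules
    mlp_hyper_connection: str = "HyperConnectionModule"

@dataclass
class FakeSpec:                # stand-in for ModuleSpec
    module: str = "HyperConnectionTransformerLayer"
    submodules: FakeSubmodules = field(default_factory=FakeSubmodules)

decoder_last = FakeSpec()
mtp = copy.copy(decoder_last)                 # first copy: do not mutate the shared spec
mtp.submodules = copy.copy(mtp.submodules)    # second copy: submodules object is still shared
mtp.submodules.mlp_hyper_connection = "IdentityOp"
mtp.module = "TransformerLayer"

assert decoder_last.module == "HyperConnectionTransformerLayer"
assert decoder_last.submodules.mlp_hyper_connection == "HyperConnectionModule"
```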