93 changes: 92 additions & 1 deletion megatron/core/transformer/attention.py
@@ -379,6 +379,80 @@ def __init__(
# the quantized tensor.
set_save_original_input(self.linear_proj)

# Per-layer RotaryEmbedding (used when rotary_base_per_layer is set in config).
self.rotary_pos_emb = None
if getattr(self.config, 'rotary_base_per_layer', None):
@Victarry (Contributor) commented on Apr 27, 2026

[SUGGESTION Code] getattr(self.config, 'rotary_base_per_layer', None) uses getattr, which suggests the field might not exist — but the field is added by this PR to TransformerConfig, so it always exists.

Suggestion:

if self.config.rotary_base_per_layer is not None:

This avoids unnecessary cognitive overhead and prevents readers from wondering whether some other config type lacks this field.

rotary_base = self.config.rotary_base_per_layer[self.layer_number - 1]
self._build_per_layer_rotary_pos_emb(rotary_base)

# Per-head scalar output gate (e.g., Step-3.5-Flash g_proj).
# Separate ColumnParallelLinear so gate weights are independent of QKV.
if self.config.use_head_wise_attn_gate:
self.g_proj = submodules.linear_qkv(
self.config.hidden_size,
self.config.num_attention_heads,
config=self.config,
init_method=not_none(self.config.init_method),
gather_output=False,
bias=False,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name='gate',
tp_group=self.pg_collection.tp,
Comment on lines +390 to +401
@Victarry (Contributor) commented on Apr 27, 2026

[CRITICAL Architecture] g_proj should be folded into linear_qkv and merged with the existing attention_output_gate path, rather than introducing a separate ColumnParallelLinear and a new forward branch.

Reason:

  1. Code duplication: attention_output_gate (per-head per-channel; gate weights fused into linear_qkv, sliced out by _split_qkv, applied by _apply_output_gate) and use_head_wise_attn_gate (per-head scalar; independent g_proj; a new sigmoid-multiply block) are functionally near-parallel — they only differ in whether gate granularity is hn or 1. Yet their implementation paths are completely independent. Future changes (GQA AG handling, CUDA graphs, FP8/FP4 compatibility) on either path will need to be replicated and will drift.
  2. TE/local backend divergence: g_proj = submodules.linear_qkv(...) resolves to TELayerNormColumnParallelLinear (with a fused LN that has its own learnable parameters) under the TE backend, but plain ColumnParallelLinear under the local backend — the same config produces different mathematics on the two backends. Folding into linear_qkv makes QKV and gate share the same LN, which is consistent across backends and matches the pure-linear g_proj of the Step-3.5-Flash paper.
  3. One fewer GEMM kernel launch: the per-head scalar gate occupies only num_heads columns; folding it into linear_qkv is essentially free.
  4. _apply_output_gate does not need a branch: replace gate.view(*x.shape) with view(*x.shape[:-1], -1) (or remove it altogether), and PyTorch broadcasting naturally handles both gate shapes — [sq,b,np,hn] (existing) and [sq,b,np,1] (new) both multiply core_attn_out [sq,b,np,hn] correctly element-wise.

Suggestion:

  • Branch in the linear_qkv_out_dim calculation:
    if self.config.attention_output_gate:
        self.linear_qkv_out_dim += self.config.kv_channels * self.config.num_attention_heads  # per-head per-channel
    elif self.config.use_head_wise_attn_gate:
        self.linear_qkv_out_dim += self.config.num_attention_heads                            # per-head scalar
  • In _split_qkv, slice the trailing num_heads scalars as a new tail segment. Since this is not a multiple of hn, slice it off before the group reshape, or implement an unequal-width split inside the group. The GQA AG path needs corresponding updates: after the all-gather, slice QKV by group and slice the scalar segment by q_head separately.
  • Change _apply_output_gate to rely on broadcasting so both gates share the same function (a sketch follows this list).
  • Delete the g_proj module, tp_comm_buffer_name='gate', and the new per-head gate forward block.
  • Since only one gate type can occupy the linear_qkv tail, the two flags become naturally mutually exclusive (you may add an assert in __post_init__ to make this explicit).
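
For the broadcasting point, a rough standalone sketch in plain PyTorch (this is not the existing _apply_output_gate signature; num_heads stands in for the per-partition head count):

import torch

def apply_output_gate(core_attn_out: torch.Tensor, gate: torch.Tensor, num_heads: int) -> torch.Tensor:
    # core_attn_out: [sq, b, np*hn]
    # gate: [sq, b, np*hn] (attention_output_gate) or [sq, b, np] (per-head scalar gate)
    sq, b, _ = core_attn_out.shape
    out = core_attn_out.view(sq, b, num_heads, -1)       # [sq, b, np, hn]
    g = gate.view(sq, b, num_heads, -1)                  # [sq, b, np, hn] or [sq, b, np, 1]
    out = out * torch.sigmoid(g.float()).to(out.dtype)   # broadcasting covers both gate shapes
    return out.view(sq, b, -1)                           # [sq, b, np*hn]

The same view/broadcast path handles both granularities, so the per-head scalar gate would not need its own forward branch.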

)

def _build_per_layer_rotary_pos_emb(self, rotary_base: float) -> None:
"""Build self.rotary_pos_emb using a layer-specific rotary base."""
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding

seq_len_interpolation_factor = self.config.rotary_scaling_factor
if self.config.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=self.config.rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
rope_scaling=self.config.rope_scaling,
rope_scaling_factor=self.config.rope_scaling_factor,
use_cpu_initialization=self.config.use_cpu_initialization,
cp_group=self.pg_collection.cp,
)
elif self.config.position_embedding_type == 'yarn':
self.rotary_pos_emb = YarnRotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=self.config.rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
scaling_factor=getattr(self.config, "yarn_rotary_scaling_factor"),
original_max_position_embeddings=getattr(
self.config, "yarn_original_max_position_embeddings"
),
beta_fast=getattr(self.config, "yarn_beta_fast"),
beta_slow=getattr(self.config, "yarn_beta_slow"),
mscale=getattr(self.config, "yarn_mscale"),
mscale_all_dim=getattr(self.config, "yarn_mscale_all_dim"),
correction_range_round_to_int=getattr(
self.config, "yarn_correction_range_round_to_int"
),
use_cpu_initialization=self.config.use_cpu_initialization,
)
elif self.config.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
self.rotary_pos_emb = MultimodalRotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=self.config.rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
)
self.mrope_section = self.config.mrope_section
assert (
self.mrope_section is not None
), "mrope require mrope_section setting, but we got None from TransformerConfig"
else:
assert False, "Invalid position embedding type"
Comment on lines +404 to +454
@Victarry (Contributor) commented on Apr 27, 2026

[IMPORTANT Maintenance] _build_per_layer_rotary_pos_emb duplicates the RotaryEmbedding / YarnRotaryEmbedding / MultimodalRotaryEmbedding construction logic from gpt_model.py.

Reason: When gpt_model.py adds new rotary parameters (new rope_scaling parameters, new rotary types), this per-layer path will silently fall out of sync, causing per-layer rotary behavior to diverge from model-level rotary.

Suggestion: Extract a factory function (e.g. in models/common/embeddings/rotary_pos_embedding.py) with a signature like:

def build_rotary_pos_emb(
    config,
    rotary_base: Optional[float] = None,
    cp_group=None,
) -> RotaryEmbedding | YarnRotaryEmbedding | MultimodalRotaryEmbedding: ...

called from both GPTModel and here, so the two sites stay in lockstep.
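
A minimal sketch of such a factory, reusing the keyword arguments already present in this diff; only the 'rope' branch is spelled out, and the default rotary_base is an assumption (the real call sites would pass the model-level value):

from typing import Optional

from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding

def build_rotary_pos_emb(config, rotary_base: Optional[float] = None, cp_group=None):
    """Build the rotary embedding described by config; rotary_base, if given, overrides the model-level base."""
    base = rotary_base if rotary_base is not None else 10000.0  # assumed default
    if config.position_embedding_type == 'rope' and not config.multi_latent_attention:
        return RotaryEmbedding(
            kv_channels=config.kv_channels,
            rotary_percent=config.rotary_percent,
            rotary_interleaved=config.rotary_interleaved,
            seq_len_interpolation_factor=config.rotary_scaling_factor,
            rotary_base=base,
            rope_scaling=config.rope_scaling,
            rope_scaling_factor=config.rope_scaling_factor,
            use_cpu_initialization=config.use_cpu_initialization,
            cp_group=cp_group,
        )
    # The 'yarn' and 'mrope' branches would mirror the YarnRotaryEmbedding /
    # MultimodalRotaryEmbedding constructions in this diff.
    raise NotImplementedError(
        f"Unsupported position_embedding_type={config.position_embedding_type!r}"
    )

GPTModel would call it with the model-level base; SelfAttention would pass rotary_base_per_layer[layer_number - 1].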

@Victarry (Contributor) commented on Apr 27, 2026

[SUGGESTION Code] assert False, "Invalid position embedding type" does not include the actual position_embedding_type value, and assert False is semantically off — this is not an invariant violation, it's an unsupported configuration.

Suggestion:

else:
    raise NotImplementedError(
        f"rotary_base_per_layer does not support "
        f"position_embedding_type={self.config.position_embedding_type!r} "
        f"(only 'rope' / 'yarn' / 'mrope' are supported)."
    )


def _checkpointed_attention_forward(
self,
query,
@@ -975,6 +1049,11 @@ def forward(
if no_rope:
rotary_pos_emb = None

# Per-layer theta: override the model-level RoPE with this layer's own embedding.
if self.rotary_pos_emb is not None and rotary_pos_emb is not None:
seq_len = rotary_pos_emb.shape[0]
rotary_pos_emb = self.rotary_pos_emb(seq_len)
Comment on lines +1052 to +1055
@Victarry (Contributor) commented on Apr 27, 2026

[IMPORTANT Correctness] The per-layer rotary override depends on the model-level rotary_pos_emb already existing, but the PR description states "shared model-level rotary_pos_emb in GPTModel is not created" — and this PR does not modify GPTModel.

Reason:

  1. Description does not match implementation: the model-level rotary is still created by GPTModel, and each layer just overrides it. Both rotaries coexist; the model-level one is computed every forward and immediately discarded — wasted compute on every iteration.
  2. Silent skip: if GPTModel ever stops passing rotary_pos_emb on some path (cross-attention, fully no-rope layers, user-disabled), the per-layer rotary will be silently skipped — there's no assert, and the behavior becomes unpredictable.
  3. seq_len = rotary_pos_emb.shape[0] implicitly couples to the external rotary tensor shape. That variable is wrapped into (rotary_pos_emb,) * 2 later in the forward; the override happens before the tuple wrap so it's currently correct, but the coupling is fragile.

Suggestion:

  • Modify GPTModel in tandem: when config.rotary_base_per_layer is not None, skip creating the model-level rotary_pos_emb and just pass seq_len to each layer so each layer generates its own rotary.
  • Or keep the override-style implementation but obtain seq_len from inference_context / hidden_states.shape[0] instead of relying on the external rotary shape (a combined sketch follows this list).
  • Add an assert: when per-layer rotary is enabled and rotary_pos_emb is None, raise explicitly instead of silently skipping.
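
Combining the second and third bullets, roughly (names follow the surrounding forward(); under context parallelism the un-split sequence length would still need to be recovered, which is elided here):

# Per-layer theta: build this layer's RoPE from the input itself.
if self.rotary_pos_emb is not None:
    # Fail loudly instead of silently skipping the per-layer rotary.
    assert rotary_pos_emb is not None, (
        "rotary_base_per_layer is enabled but no model-level rotary_pos_emb was provided"
    )
    seq_len = hidden_states.shape[0]  # [sq, b, h] layout; no coupling to the external rotary tensor shape
    rotary_pos_emb = self.rotary_pos_emb(seq_len)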


inference_context = deprecate_inference_params(inference_context, inference_params)

if inference_context and inference_context.is_dynamic_batching():
@@ -1252,7 +1331,19 @@ def forward(
core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
nvtx_range_pop(suffix="core_attention")

# Output gate
# Per-head scalar gate (use_head_wise_attn_gate: separate g_proj module)
if self.config.use_head_wise_attn_gate:
nvtx_range_push(suffix="head_wise_attn_gate")
gate_states, _ = self.g_proj(hidden_states) # [sq, b, np_per_rank]
gate_states = gate_states.view(*gate_states.shape[:2], -1, 1) # [sq, b, np, 1]
core_attn_out = core_attn_out.view(*gate_states.shape[:3], -1) # [sq, b, np, hn]
core_attn_out = (
core_attn_out * torch.sigmoid(gate_states.float()).to(core_attn_out.dtype)
)
core_attn_out = core_attn_out.view(*gate_states.shape[:2], -1) # [sq, b, np*hn]
nvtx_range_pop(suffix="head_wise_attn_gate")

# Output gate (attention_output_gate: full head_dim gate fused into QKV)
if gate is not None:
nvtx_range_push(suffix="output_gate")
core_attn_out = self._apply_output_gate(core_attn_out, gate)
17 changes: 17 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -244,6 +244,17 @@ class TransformerConfig(ModelParallelConfig):
attention_output_gate: bool = False
"""Whether to apply output gate to the attention layers."""

use_head_wise_attn_gate: bool = False
"""Apply a per-head scalar output gate (e.g., Step-3.5-Flash g_proj).
Adds a separate ColumnParallelLinear(hidden_size → num_attention_heads) whose
sigmoid output gates each attention head independently. Distinct from
attention_output_gate which fuses a full head_dim gate into linear_qkv."""

rotary_base_per_layer: Optional[List[float]] = None
"""Per-layer RoPE theta values. Length must equal num_layers. When set, each
SelfAttention layer creates its own RotaryEmbedding with the corresponding base;
the shared model-level rotary_pos_emb is not created."""

test_mode: bool = False
"""Whether to run real-time tests."""

@@ -2522,6 +2533,12 @@ def __post_init__(self):
"2.3.0.dev0+39c0e70"
), "Must have at least TE version 2.3 or higher to use symmetric memory all reduce"

if self.rotary_base_per_layer is not None:
assert len(self.rotary_base_per_layer) == self.num_layers, (
f"rotary_base_per_layer length ({len(self.rotary_base_per_layer)}) "
f"must equal num_layers ({self.num_layers})"
)

if self.no_rope_freq:
assert not self.flash_decode, "flash_decode cannot be used with no_rope."
if isinstance(self.no_rope_freq, int):
212 changes: 212 additions & 0 deletions tests/unit_tests/transformer/test_head_wise_attn_gate.py
@@ -0,0 +1,212 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
@Victarry (Contributor) commented on Apr 27, 2026

[IMPORTANT Tests] Tests cover only use_head_wise_attn_gate; there are no tests for rotary_base_per_layer.

Reason: rotary_base_per_layer involves two independent paths (per-layer rotary construction, forward-time override of the model-level rotary), each with three branches (rope / yarn / mrope). Zero test coverage means any future refactor (including adopting the factory-function suggestion) has no safety net.

Suggestion: Add at minimum:

  • A test that triggers the __post_init__ assert when len(rotary_base_per_layer) != num_layers (a minimal sketch follows this list).
  • A multi-layer model with distinct rotary_base per layer; verify each layer's self.rotary_pos_emb is non-None and that inv_freq matches the corresponding rotary_base.
  • A forward-pass test where each layer's rotary output numerically matches a standalone RotaryEmbedding constructed with the same rotary_base.
  • A test that with rotary_base_per_layer=None, self.rotary_pos_emb is None and forward runs without error.
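
For the first bullet, a minimal sketch (config values mirror the test file below; the assertion message is matched only loosely, and the exact minimal TransformerConfig arguments are an assumption):

import pytest

from megatron.core.transformer import TransformerConfig

def test_rotary_base_per_layer_length_mismatch_raises():
    """__post_init__ should reject a rotary_base_per_layer whose length != num_layers."""
    with pytest.raises(AssertionError, match="rotary_base_per_layer"):
        TransformerConfig(
            num_layers=2,
            hidden_size=128,
            num_attention_heads=4,
            use_cpu_initialization=True,
            rotary_base_per_layer=[10000.0],  # one entry for two layers
        )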


"""Tests for per-head scalar attention gate (use_head_wise_attn_gate / g_proj)."""

import pytest
import torch

from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_submodules,
)
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.attention import SelfAttention
from tests.unit_tests.test_utilities import Utils


SEQ_LEN = 16
BATCH_SIZE = 2
HIDDEN_SIZE = 128
NUM_HEADS = 4


def _make_config(transformer_impl: str, use_head_wise_attn_gate: bool) -> TransformerConfig:
return TransformerConfig(
num_layers=1,
hidden_size=HIDDEN_SIZE,
num_attention_heads=NUM_HEADS,
use_cpu_initialization=True,
bf16=True,
params_dtype=torch.bfloat16,
transformer_impl=transformer_impl,
use_head_wise_attn_gate=use_head_wise_attn_gate,
)


def _make_attention(config: TransformerConfig, transformer_impl: str) -> SelfAttention:
if transformer_impl == "transformer_engine":
submodules = get_gpt_layer_with_transformer_engine_submodules().self_attention.submodules
else:
submodules = get_gpt_layer_local_spec().submodules.self_attention.submodules
return SelfAttention(config, submodules, layer_number=1)


@pytest.mark.parametrize("transformer_impl", ["transformer_engine", "native"])
class TestHeadWiseAttnGateInit:
"""Verify that g_proj is created iff use_head_wise_attn_gate=True."""

@pytest.fixture(autouse=True)
def setup_teardown(self, transformer_impl):
Utils.initialize_model_parallel(1, 1)
model_parallel_cuda_manual_seed(42)
self.transformer_impl = transformer_impl
yield
Utils.destroy_model_parallel()

def test_g_proj_exists_when_enabled(self):
config = _make_config(self.transformer_impl, use_head_wise_attn_gate=True)
attn = _make_attention(config, self.transformer_impl)
assert hasattr(attn, "g_proj"), "g_proj should be created when use_head_wise_attn_gate=True"

def test_g_proj_absent_when_disabled(self):
config = _make_config(self.transformer_impl, use_head_wise_attn_gate=False)
attn = _make_attention(config, self.transformer_impl)
assert not hasattr(
attn, "g_proj"
), "g_proj should not be created when use_head_wise_attn_gate=False"

def test_g_proj_output_size(self):
"""g_proj maps hidden_size → num_attention_heads (no bias)."""
config = _make_config(self.transformer_impl, use_head_wise_attn_gate=True)
attn = _make_attention(config, self.transformer_impl)
# ColumnParallelLinear stores weight as (output, input)
weight = attn.g_proj.weight
assert weight.shape == (
NUM_HEADS,
HIDDEN_SIZE,
), f"Unexpected g_proj weight shape: {weight.shape}"


@pytest.mark.parametrize("transformer_impl", ["transformer_engine", "native"])
class TestHeadWiseAttnGateForward:
"""Verify forward-pass behaviour of the head-wise gate."""

@pytest.fixture(autouse=True)
def setup_teardown(self, transformer_impl):
Utils.initialize_model_parallel(1, 1)
model_parallel_cuda_manual_seed(42)
self.transformer_impl = transformer_impl
yield
Utils.destroy_model_parallel()

def _run_forward(self, use_gate: bool):
config = _make_config(self.transformer_impl, use_head_wise_attn_gate=use_gate)
attn = _make_attention(config, self.transformer_impl).cuda()
hidden_states = torch.randn(
SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda"
)
attention_mask = torch.ones(BATCH_SIZE, 1, 1, SEQ_LEN, dtype=bool, device="cuda")
output, bias = attn(hidden_states, attention_mask)
return output, bias

def test_output_shape_with_gate(self):
output, bias = self._run_forward(use_gate=True)
assert output.shape == (SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE)
assert bias.shape == (HIDDEN_SIZE,)

def test_output_shape_without_gate(self):
output, bias = self._run_forward(use_gate=False)
assert output.shape == (SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE)

def test_gate_changes_output(self):
"""With identical weights/inputs, gating should change the output."""
torch.manual_seed(0)
out_gated, _ = self._run_forward(use_gate=True)
torch.manual_seed(0)
out_plain, _ = self._run_forward(use_gate=False)
assert not torch.allclose(out_gated, out_plain), (
"Gated and plain outputs should differ"
)

def test_zero_gate_suppresses_attn(self):
"""When g_proj weights are zero the gate is sigmoid(0)=0.5, not zero;
confirm that zeroing the bias (if any) and weight gives a 0.5-scaled output."""
config = _make_config(self.transformer_impl, use_head_wise_attn_gate=True)
attn = _make_attention(config, self.transformer_impl).cuda()

# Drive g_proj to produce exactly 0 pre-activation → sigmoid = 0.5
torch.nn.init.zeros_(attn.g_proj.weight)

hidden_states = torch.randn(
SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda"
)
attention_mask = torch.ones(BATCH_SIZE, 1, 1, SEQ_LEN, dtype=bool, device="cuda")

# Reference: disable the gate and run with same weights for linear_proj
config_no_gate = _make_config(self.transformer_impl, use_head_wise_attn_gate=False)
attn_no_gate = _make_attention(config_no_gate, self.transformer_impl).cuda()
# Copy all shared weights so the only difference is the gate scaling
attn_no_gate.load_state_dict(
{k: v for k, v in attn.state_dict().items() if k in attn_no_gate.state_dict()},
strict=False,
)

with torch.no_grad():
out_gated, _ = attn(hidden_states, attention_mask)
out_plain, _ = attn_no_gate(hidden_states, attention_mask)

# Gate = sigmoid(0) = 0.5; gated output ≈ 0.5 * plain output
# Use a loose tolerance because of bfloat16 rounding
torch.testing.assert_close(
out_gated.float(),
(out_plain * 0.5).float(),
atol=1e-2,
rtol=1e-2,
)


class TestHeadWiseAttnGateNumerics:
"""Low-level tensor-math tests (no model-parallel overhead, pure PyTorch)."""

@pytest.fixture(autouse=True)
def setup_teardown(self):
Utils.initialize_model_parallel(1, 1)
yield
Utils.destroy_model_parallel()

def test_gate_reshape_correctness(self):
"""Replicate the reshape logic and verify sigmoid is applied per-head."""
sq, b, np, hn = 4, 2, NUM_HEADS, 32
# Simulate core_attn_out: [sq, b, np*hn]
core_attn_out = torch.arange(
sq * b * np * hn, dtype=torch.float32
).reshape(sq, b, np * hn)
# Simulate gate_states: [sq, b, np] (pre-sigmoid raw scores)
gate_scores = torch.zeros(sq, b, np) # sigmoid(0) = 0.5

gate_states = gate_scores.view(sq, b, np, 1)
out = core_attn_out.view(sq, b, np, hn)
out = out * torch.sigmoid(gate_states)
out = out.view(sq, b, np * hn)

expected = core_attn_out * 0.5
torch.testing.assert_close(out, expected)

def test_gate_dtype_cast(self):
"""Gate computation upcast to float32, result cast back to input dtype."""
sq, b, np, hn = 4, 2, NUM_HEADS, 32
core_attn_out = torch.randn(sq, b, np * hn, dtype=torch.bfloat16)
gate_scores = torch.randn(sq, b, np, dtype=torch.bfloat16)

gate_states = gate_scores.view(sq, b, np, 1)
out = core_attn_out.view(sq, b, np, hn)
# Mirrors the production code: float cast for sigmoid, then cast back
out = out * torch.sigmoid(gate_states.float()).to(out.dtype)
out = out.view(sq, b, np * hn)

assert out.dtype == torch.bfloat16

@pytest.mark.parametrize("gate_value,expected_scale", [(-1e6, 0.0), (1e6, 1.0), (0.0, 0.5)])
def test_gate_saturation(self, gate_value: float, expected_scale: float):
"""Extreme gate values saturate sigmoid to 0 or 1; midpoint gives 0.5."""
sq, b, np, hn = 2, 1, 2, 8
core_attn_out = torch.ones(sq, b, np * hn)
gate_scores = torch.full((sq, b, np), gate_value)

gate_states = gate_scores.view(sq, b, np, 1)
out = core_attn_out.view(sq, b, np, hn)
out = out * torch.sigmoid(gate_states.float()).to(out.dtype)
out = out.view(sq, b, np * hn)

torch.testing.assert_close(out, torch.full_like(out, expected_scale), atol=1e-5, rtol=1e-5)