feat(attention): Add attention_per_head_gate and rotary_base_per_laye… #4473
base: dev
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -379,6 +379,80 @@ def __init__( | |
| # the quantized tensor. | ||
| set_save_original_input(self.linear_proj) | ||
|
|
||
| # Per-layer RotaryEmbedding (used when rotary_base_per_layer is set in config). | ||
| self.rotary_pos_emb = None | ||
| if getattr(self.config, 'rotary_base_per_layer', None): | ||
| rotary_base = self.config.rotary_base_per_layer[self.layer_number - 1] | ||
| self._build_per_layer_rotary_pos_emb(rotary_base) | ||
|
|
||
| # Per-head scalar output gate (e.g., Step-3.5-Flash g_proj). | ||
| # Separate ColumnParallelLinear so gate weights are independent of QKV. | ||
| if self.config.use_head_wise_attn_gate: | ||
| self.g_proj = submodules.linear_qkv( | ||
| self.config.hidden_size, | ||
| self.config.num_attention_heads, | ||
| config=self.config, | ||
| init_method=not_none(self.config.init_method), | ||
| gather_output=False, | ||
| bias=False, | ||
| skip_bias_add=False, | ||
| is_expert=False, | ||
| tp_comm_buffer_name='gate', | ||
| tp_group=self.pg_collection.tp, | ||
|
Comment on lines +390 to +401
Contributor
[CRITICAL Architecture] Reason:
Suggestion:
|
||
| ) | ||
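As a reading aid, here is a plain-PyTorch sketch of what this projection computes end to end; shapes follow this hunk and the gating hunk further down, and `torch.nn.Linear` stands in for the `ColumnParallelLinear`, so tensor parallelism and the spec machinery are ignored (this is an illustration, not the PR's implementation):

```python
import torch

sq, b, hidden, num_heads, head_dim = 16, 2, 128, 4, 32

# Stand-in for g_proj: one scalar gate logit per attention head, no bias,
# mirroring the hidden_size -> num_attention_heads projection configured above.
g_proj = torch.nn.Linear(hidden, num_heads, bias=False)

hidden_states = torch.randn(sq, b, hidden)
core_attn_out = torch.randn(sq, b, num_heads * head_dim)

gate = torch.sigmoid(g_proj(hidden_states))              # [sq, b, num_heads]
gated = core_attn_out.view(sq, b, num_heads, head_dim)   # one scalar broadcast per head
gated = (gated * gate.unsqueeze(-1)).view(sq, b, -1)     # back to [sq, b, num_heads*head_dim]
```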
|
|
||
| def _build_per_layer_rotary_pos_emb(self, rotary_base: float) -> None: | ||
| """Build self.rotary_pos_emb using a layer-specific rotary base.""" | ||
| from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding | ||
|
|
||
| seq_len_interpolation_factor = self.config.rotary_scaling_factor | ||
| if self.config.position_embedding_type == 'rope' and not self.config.multi_latent_attention: | ||
| self.rotary_pos_emb = RotaryEmbedding( | ||
| kv_channels=self.config.kv_channels, | ||
| rotary_percent=self.config.rotary_percent, | ||
| rotary_interleaved=self.config.rotary_interleaved, | ||
| seq_len_interpolation_factor=seq_len_interpolation_factor, | ||
| rotary_base=rotary_base, | ||
| rope_scaling=self.config.rope_scaling, | ||
| rope_scaling_factor=self.config.rope_scaling_factor, | ||
| use_cpu_initialization=self.config.use_cpu_initialization, | ||
| cp_group=self.pg_collection.cp, | ||
| ) | ||
| elif self.config.position_embedding_type == 'yarn': | ||
| self.rotary_pos_emb = YarnRotaryEmbedding( | ||
| kv_channels=self.config.kv_channels, | ||
| rotary_percent=self.config.rotary_percent, | ||
| rotary_interleaved=self.config.rotary_interleaved, | ||
| seq_len_interpolation_factor=seq_len_interpolation_factor, | ||
| rotary_base=rotary_base, | ||
| scaling_factor=getattr(self.config, "yarn_rotary_scaling_factor"), | ||
| original_max_position_embeddings=getattr( | ||
| self.config, "yarn_original_max_position_embeddings" | ||
| ), | ||
| beta_fast=getattr(self.config, "yarn_beta_fast"), | ||
| beta_slow=getattr(self.config, "yarn_beta_slow"), | ||
| mscale=getattr(self.config, "yarn_mscale"), | ||
| mscale_all_dim=getattr(self.config, "yarn_mscale_all_dim"), | ||
| correction_range_round_to_int=getattr( | ||
| self.config, "yarn_correction_range_round_to_int" | ||
| ), | ||
| use_cpu_initialization=self.config.use_cpu_initialization, | ||
| ) | ||
| elif self.config.position_embedding_type == 'mrope' and not self.config.multi_latent_attention: | ||
| self.rotary_pos_emb = MultimodalRotaryEmbedding( | ||
| kv_channels=self.config.kv_channels, | ||
| rotary_percent=self.config.rotary_percent, | ||
| rotary_interleaved=self.config.rotary_interleaved, | ||
| seq_len_interpolation_factor=seq_len_interpolation_factor, | ||
| rotary_base=rotary_base, | ||
| ) | ||
| self.mrope_section = self.config.mrope_section | ||
| assert ( | ||
| self.mrope_section is not None | ||
| ), "mrope require mrope_section setting, but we got None from TransformerConfig" | ||
| else: | ||
| assert False, "Invalid position embedding type" | ||
|
Comment on lines +404 to +454
Contributor
[IMPORTANT Maintenance] Reason: When
Suggestion: Extract a factory function (e.g. in

def build_rotary_pos_emb(
    config,
    rotary_base: Optional[float] = None,
    cp_group=None,
) -> RotaryEmbedding | YarnRotaryEmbedding | MultimodalRotaryEmbedding: ...

called from both
Contributor
[SUGGESTION Code] Suggestion:

else:
    raise NotImplementedError(
        f"rotary_base_per_layer does not support "
        f"position_embedding_type={self.config.position_embedding_type!r} "
        f"(only 'rope' / 'yarn' / 'mrope' are supported)."
    )
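A rough sketch of what the factory proposed above could look like, assembled only from the constructor calls already visible in this hunk; the function name and signature come from the comment, while the default base, the error branch, and the rest are illustrative assumptions rather than part of the PR:

```python
from typing import Optional

from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding


def build_rotary_pos_emb(config, rotary_base: Optional[float] = None, cp_group=None):
    """Build a rotary embedding from config, optionally overriding the rotary base."""
    if rotary_base is None:
        rotary_base = 10000.0  # assumed default; a real caller would pass the model-level base
    if config.position_embedding_type == 'rope' and not config.multi_latent_attention:
        return RotaryEmbedding(
            kv_channels=config.kv_channels,
            rotary_percent=config.rotary_percent,
            rotary_interleaved=config.rotary_interleaved,
            seq_len_interpolation_factor=config.rotary_scaling_factor,
            rotary_base=rotary_base,
            rope_scaling=config.rope_scaling,
            rope_scaling_factor=config.rope_scaling_factor,
            use_cpu_initialization=config.use_cpu_initialization,
            cp_group=cp_group,
        )
    # The 'yarn' and 'mrope' branches would mirror the corresponding calls in this hunk.
    raise NotImplementedError(
        f"Unsupported position_embedding_type={config.position_embedding_type!r}"
    )
```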
||
|
|
||
| def _checkpointed_attention_forward( | ||
| self, | ||
| query, | ||
|
|
@@ -975,6 +1049,11 @@ def forward( | |
| if no_rope: | ||
| rotary_pos_emb = None | ||
|
|
||
| # Per-layer theta: override the model-level RoPE with this layer's own embedding. | ||
| if self.rotary_pos_emb is not None and rotary_pos_emb is not None: | ||
| seq_len = rotary_pos_emb.shape[0] | ||
| rotary_pos_emb = self.rotary_pos_emb(seq_len) | ||
|
Comment on lines +1052 to +1055
Contributor
[IMPORTANT Correctness] The per-layer rotary override depends on the model-level rotary_pos_emb argument being non-None. Reason:
Suggestion:
|
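Since the Reason/Suggestion text above is not visible, here is one illustrative reading of the concern, purely as an assumption: the override only fires when a model-level rotary_pos_emb was computed, so a layer configured with its own base can otherwise be silently left on the default RoPE. A drop-in variant of the hunk at +1052 that fails loudly instead (a sketch, not the reviewer's suggestion):

```python
# Sketch only (an assumption): make the dependency on the model-level rotary_pos_emb
# explicit instead of silently skipping the per-layer base.
if self.rotary_pos_emb is not None and not no_rope:
    if rotary_pos_emb is None:
        raise RuntimeError(
            "rotary_base_per_layer is set for this layer, but no model-level rotary_pos_emb "
            "was provided, so the per-layer rotary base would be silently ignored."
        )
    seq_len = rotary_pos_emb.shape[0]
    rotary_pos_emb = self.rotary_pos_emb(seq_len)
```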
||
|
|
||
| inference_context = deprecate_inference_params(inference_context, inference_params) | ||
|
|
||
| if inference_context and inference_context.is_dynamic_batching(): | ||
|
|
@@ -1252,7 +1331,19 @@ def forward( | |
| core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) | ||
| nvtx_range_pop(suffix="core_attention") | ||
|
|
||
| # Output gate | ||
| # Per-head scalar gate (attention_per_head_gate: separate per_head_gate module) | ||
| if self.config.use_head_wise_attn_gate: | ||
| nvtx_range_push(suffix="head_wise_attn_gate") | ||
| gate_states, _ = self.g_proj(hidden_states) # [sq, b, np_per_rank] | ||
| gate_states = gate_states.view(*gate_states.shape[:2], -1, 1) # [sq, b, np, 1] | ||
| core_attn_out = core_attn_out.view(*gate_states.shape[:3], -1) # [sq, b, np, hn] | ||
| core_attn_out = ( | ||
| core_attn_out * torch.sigmoid(gate_states.float()).to(core_attn_out.dtype) | ||
| ) | ||
| core_attn_out = core_attn_out.view(*gate_states.shape[:2], -1) # [sq, b, np*hn] | ||
| nvtx_range_pop(suffix="head_wise_attn_gate") | ||
|
|
||
| # Output gate (attention_output_gate: full head_dim gate fused into QKV) | ||
| if gate is not None: | ||
| nvtx_range_push(suffix="output_gate") | ||
| core_attn_out = self._apply_output_gate(core_attn_out, gate) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,212 @@ | ||
| # Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
|
Contributor
[IMPORTANT Tests] Tests cover only use_head_wise_attn_gate. Reason:
Suggestion: Add at minimum:
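One illustration of coverage the new file lacks (a sketch under assumptions, not the reviewer's elided list): a pure-PyTorch check in the style of the numerics class below, verifying the property that rotary_base_per_layer relies on, namely that different bases yield different inverse-frequency schedules.

```python
import torch


def _rope_inv_freq(base: float, dim: int = 64) -> torch.Tensor:
    """Standard RoPE inverse-frequency schedule: 1 / base**(2i/dim)."""
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))


def test_distinct_rotary_bases_give_distinct_frequencies():
    freq_small = _rope_inv_freq(10_000.0)
    freq_large = _rope_inv_freq(1_000_000.0)
    assert freq_small.shape == freq_large.shape
    # A larger base shrinks every non-leading frequency, so the schedules must differ.
    assert not torch.allclose(freq_small, freq_large)
    assert torch.all(freq_large[1:] < freq_small[1:])
```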
|
||
|
|
||
| """Tests for per-head scalar attention gate (use_head_wise_attn_gate / g_proj).""" | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from megatron.core.models.gpt.gpt_layer_specs import ( | ||
| get_gpt_layer_local_spec, | ||
| get_gpt_layer_with_transformer_engine_submodules, | ||
| ) | ||
| from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed | ||
| from megatron.core.transformer import TransformerConfig | ||
| from megatron.core.transformer.attention import SelfAttention | ||
| from tests.unit_tests.test_utilities import Utils | ||
|
|
||
|
|
||
| SEQ_LEN = 16 | ||
| BATCH_SIZE = 2 | ||
| HIDDEN_SIZE = 128 | ||
| NUM_HEADS = 4 | ||
|
|
||
|
|
||
| def _make_config(transformer_impl: str, use_head_wise_attn_gate: bool) -> TransformerConfig: | ||
| return TransformerConfig( | ||
| num_layers=1, | ||
| hidden_size=HIDDEN_SIZE, | ||
| num_attention_heads=NUM_HEADS, | ||
| use_cpu_initialization=True, | ||
| bf16=True, | ||
| params_dtype=torch.bfloat16, | ||
| transformer_impl=transformer_impl, | ||
| use_head_wise_attn_gate=use_head_wise_attn_gate, | ||
| ) | ||
|
|
||
|
|
||
| def _make_attention(config: TransformerConfig, transformer_impl: str) -> SelfAttention: | ||
| if transformer_impl == "transformer_engine": | ||
| submodules = get_gpt_layer_with_transformer_engine_submodules().self_attention.submodules | ||
| else: | ||
| submodules = get_gpt_layer_local_spec().submodules.self_attention.submodules | ||
| return SelfAttention(config, submodules, layer_number=1) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("transformer_impl", ["transformer_engine", "native"]) | ||
| class TestHeadWiseAttnGateInit: | ||
| """Verify that g_proj is created iff use_head_wise_attn_gate=True.""" | ||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def setup_teardown(self, transformer_impl): | ||
| Utils.initialize_model_parallel(1, 1) | ||
| model_parallel_cuda_manual_seed(42) | ||
| self.transformer_impl = transformer_impl | ||
| yield | ||
| Utils.destroy_model_parallel() | ||
|
|
||
| def test_g_proj_exists_when_enabled(self): | ||
| config = _make_config(self.transformer_impl, use_head_wise_attn_gate=True) | ||
| attn = _make_attention(config, self.transformer_impl) | ||
| assert hasattr(attn, "g_proj"), "g_proj should be created when use_head_wise_attn_gate=True" | ||
|
|
||
| def test_g_proj_absent_when_disabled(self): | ||
| config = _make_config(self.transformer_impl, use_head_wise_attn_gate=False) | ||
| attn = _make_attention(config, self.transformer_impl) | ||
| assert not hasattr( | ||
| attn, "g_proj" | ||
| ), "g_proj should not be created when use_head_wise_attn_gate=False" | ||
|
|
||
| def test_g_proj_output_size(self): | ||
| """g_proj maps hidden_size → num_attention_heads (no bias).""" | ||
| config = _make_config(self.transformer_impl, use_head_wise_attn_gate=True) | ||
| attn = _make_attention(config, self.transformer_impl) | ||
| # ColumnParallelLinear stores weight as (output, input) | ||
| weight = attn.g_proj.weight | ||
| assert weight.shape == ( | ||
| NUM_HEADS, | ||
| HIDDEN_SIZE, | ||
| ), f"Unexpected g_proj weight shape: {weight.shape}" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("transformer_impl", ["transformer_engine", "native"]) | ||
| class TestHeadWiseAttnGateForward: | ||
| """Verify forward-pass behaviour of the head-wise gate.""" | ||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def setup_teardown(self, transformer_impl): | ||
| Utils.initialize_model_parallel(1, 1) | ||
| model_parallel_cuda_manual_seed(42) | ||
| self.transformer_impl = transformer_impl | ||
| yield | ||
| Utils.destroy_model_parallel() | ||
|
|
||
| def _run_forward(self, use_gate: bool): | ||
| config = _make_config(self.transformer_impl, use_head_wise_attn_gate=use_gate) | ||
| attn = _make_attention(config, self.transformer_impl).cuda() | ||
| hidden_states = torch.randn( | ||
| SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda" | ||
| ) | ||
| attention_mask = torch.ones(BATCH_SIZE, 1, 1, SEQ_LEN, dtype=bool, device="cuda") | ||
| output, bias = attn(hidden_states, attention_mask) | ||
| return output, bias | ||
|
|
||
| def test_output_shape_with_gate(self): | ||
| output, bias = self._run_forward(use_gate=True) | ||
| assert output.shape == (SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE) | ||
| assert bias.shape == (HIDDEN_SIZE,) | ||
|
|
||
| def test_output_shape_without_gate(self): | ||
| output, bias = self._run_forward(use_gate=False) | ||
| assert output.shape == (SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE) | ||
|
|
||
| def test_gate_changes_output(self): | ||
| """With identical weights/inputs, gating should change the output.""" | ||
| torch.manual_seed(0) | ||
| out_gated, _ = self._run_forward(use_gate=True) | ||
| torch.manual_seed(0) | ||
| out_plain, _ = self._run_forward(use_gate=False) | ||
| assert not torch.allclose(out_gated, out_plain), ( | ||
| "Gated and plain outputs should differ" | ||
| ) | ||
|
|
||
| def test_zero_gate_suppresses_attn(self): | ||
| """When g_proj weights are zero the gate is sigmoid(0)=0.5, not zero; | ||
| confirm that zeroing the bias (if any) and weight gives a 0.5-scaled output.""" | ||
| config = _make_config(self.transformer_impl, use_head_wise_attn_gate=True) | ||
| attn = _make_attention(config, self.transformer_impl).cuda() | ||
|
|
||
| # Drive g_proj to produce exactly 0 pre-activation → sigmoid = 0.5 | ||
| torch.nn.init.zeros_(attn.g_proj.weight) | ||
|
|
||
| hidden_states = torch.randn( | ||
| SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda" | ||
| ) | ||
| attention_mask = torch.ones(BATCH_SIZE, 1, 1, SEQ_LEN, dtype=bool, device="cuda") | ||
|
|
||
| # Reference: disable the gate and run with same weights for linear_proj | ||
| config_no_gate = _make_config(self.transformer_impl, use_head_wise_attn_gate=False) | ||
| attn_no_gate = _make_attention(config_no_gate, self.transformer_impl).cuda() | ||
| # Copy all shared weights so the only difference is the gate scaling | ||
| attn_no_gate.load_state_dict( | ||
| {k: v for k, v in attn.state_dict().items() if k in attn_no_gate.state_dict()}, | ||
| strict=False, | ||
| ) | ||
|
|
||
| with torch.no_grad(): | ||
| out_gated, _ = attn(hidden_states, attention_mask) | ||
| out_plain, _ = attn_no_gate(hidden_states, attention_mask) | ||
|
|
||
| # Gate = sigmoid(0) = 0.5; gated output ≈ 0.5 * plain output | ||
| # Use a loose tolerance because of bfloat16 rounding | ||
| torch.testing.assert_close( | ||
| out_gated.float(), | ||
| (out_plain * 0.5).float(), | ||
| atol=1e-2, | ||
| rtol=1e-2, | ||
| ) | ||
|
|
||
|
|
||
| class TestHeadWiseAttnGateNumerics: | ||
| """Low-level tensor-math tests (no model-parallel overhead, pure PyTorch).""" | ||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def setup_teardown(self): | ||
| Utils.initialize_model_parallel(1, 1) | ||
| yield | ||
| Utils.destroy_model_parallel() | ||
|
|
||
| def test_gate_reshape_correctness(self): | ||
| """Replicate the reshape logic and verify sigmoid is applied per-head.""" | ||
| sq, b, np, hn = 4, 2, NUM_HEADS, 32 | ||
| # Simulate core_attn_out: [sq, b, np*hn] | ||
| core_attn_out = torch.arange( | ||
| sq * b * np * hn, dtype=torch.float32 | ||
| ).reshape(sq, b, np * hn) | ||
| # Simulate gate_states: [sq, b, np] (pre-sigmoid raw scores) | ||
| gate_scores = torch.zeros(sq, b, np) # sigmoid(0) = 0.5 | ||
|
|
||
| gate_states = gate_scores.view(sq, b, np, 1) | ||
| out = core_attn_out.view(sq, b, np, hn) | ||
| out = out * torch.sigmoid(gate_states) | ||
| out = out.view(sq, b, np * hn) | ||
|
|
||
| expected = core_attn_out * 0.5 | ||
| torch.testing.assert_close(out, expected) | ||
|
|
||
| def test_gate_dtype_cast(self): | ||
| """Gate computation upcast to float32, result cast back to input dtype.""" | ||
| sq, b, np, hn = 4, 2, NUM_HEADS, 32 | ||
| core_attn_out = torch.randn(sq, b, np * hn, dtype=torch.bfloat16) | ||
| gate_scores = torch.randn(sq, b, np, dtype=torch.bfloat16) | ||
|
|
||
| gate_states = gate_scores.view(sq, b, np, 1) | ||
| out = core_attn_out.view(sq, b, np, hn) | ||
| # Mirrors the production code: float cast for sigmoid, then cast back | ||
| out = out * torch.sigmoid(gate_states.float()).to(out.dtype) | ||
| out = out.view(sq, b, np * hn) | ||
|
|
||
| assert out.dtype == torch.bfloat16 | ||
|
|
||
| @pytest.mark.parametrize("gate_value,expected_scale", [(-1e6, 0.0), (1e6, 1.0), (0.0, 0.5)]) | ||
| def test_gate_saturation(self, gate_value: float, expected_scale: float): | ||
| """Extreme gate values saturate sigmoid to 0 or 1; midpoint gives 0.5.""" | ||
| sq, b, np, hn = 2, 1, 2, 8 | ||
| core_attn_out = torch.ones(sq, b, np * hn) | ||
| gate_scores = torch.full((sq, b, np), gate_value) | ||
|
|
||
| gate_states = gate_scores.view(sq, b, np, 1) | ||
| out = core_attn_out.view(sq, b, np, hn) | ||
| out = out * torch.sigmoid(gate_states.float()).to(out.dtype) | ||
| out = out.view(sq, b, np * hn) | ||
|
|
||
| torch.testing.assert_close(out, torch.full_like(out, expected_scale), atol=1e-5, rtol=1e-5) | ||
[SUGGESTION Code]
getattr(self.config, 'rotary_base_per_layer', None) uses getattr, which suggests the field might not exist, but the field is added by this PR to TransformerConfig, so it always exists.
Suggestion: read self.config.rotary_base_per_layer directly instead of going through getattr.
This avoids unnecessary cognitive overhead and prevents readers from wondering whether some other config type lacks this field.
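A minimal sketch of the direct access this comment points toward; it mirrors the construction earlier in the diff, and the truthiness guard is an assumption about the intended wording:

```python
# Direct attribute access; rotary_base_per_layer always exists on TransformerConfig in this PR.
self.rotary_pos_emb = None
if self.config.rotary_base_per_layer:
    rotary_base = self.config.rotary_base_per_layer[self.layer_number - 1]
    self._build_per_layer_rotary_pos_emb(rotary_base)
```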