feat(attention): Add attention_per_head_gate and rotary_base_per_laye… #4473
shifangx wants to merge 2 commits into NVIDIA:dev from
Conversation
feat(attention): Add attention_per_head_gate and rotary_base_per_layer for Step-3.5-Flash

Adds two new optional, off-by-default features to TransformerConfig and SelfAttention to faithfully represent the Step-3.5-Flash architecture.

- attention_per_head_gate: adds a separate ColumnParallelLinear(hidden_size -> num_attention_heads) whose sigmoid output gates each head independently (Step-3.5-Flash g_proj). Applied after core attention, before linear_proj.
- rotary_base_per_layer: Optional[List[float]] -- per-layer RoPE theta values. When set, each SelfAttention creates its own RotaryEmbedding; the shared model-level rotary_pos_emb in GPTModel is not created.

Both features default to False/None and have no effect on existing models.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
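For orientation, here is a minimal standalone sketch of the per-head gating described above (plain PyTorch; the module wrapper and shapes are illustrative, while the actual PR routes g_proj through ColumnParallelLinear inside SelfAttention):

    import torch
    import torch.nn as nn

    class PerHeadGate(nn.Module):
        # Illustrative per-head sigmoid gate (Step-3.5-Flash style g_proj).
        def __init__(self, hidden_size: int, num_heads: int):
            super().__init__()
            # One scalar gate per attention head, no bias.
            self.g_proj = nn.Linear(hidden_size, num_heads, bias=False)

        def forward(self, hidden_states: torch.Tensor, core_attn_out: torch.Tensor) -> torch.Tensor:
            # hidden_states: [sq, b, h]; core_attn_out: [sq, b, np, hn]
            gate = torch.sigmoid(self.g_proj(hidden_states))  # [sq, b, np]
            # Broadcast the per-head scalar over the per-head channel dim hn.
            return core_attn_out * gate.unsqueeze(-1)

The gated output is then fed to linear_proj exactly as the ungated output would be.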
PR adds two opt-in features (use_head_wise_attn_gate + rotary_base_per_layer) for Step-3.5-Flash; defaults are off, existing models unaffected.

Main suggestions (see inline comments for details):

1. [CRITICAL] Fold use_head_wise_attn_gate into linear_qkv and merge it with the existing attention_output_gate path. Share _split_qkv / _apply_output_gate (broadcasting handles both gate shapes, so no new branch is needed). Benefits:
   - Eliminates the TE/local backend divergence in the g_proj math (submodules.linear_qkv resolves to TELayerNormColumnParallelLinear under TE, which adds an independent learnable LayerNorm).
   - Avoids two near-parallel "output gating" implementations coexisting long-term and drifting.
   - Saves one GEMM kernel launch.
   - Makes the two gate flags naturally mutually exclusive.
2. [IMPORTANT] The rotary_base_per_layer forward override depends on the model-level rotary_pos_emb already existing, contradicting the PR description's claim that the "model-level rotary_pos_emb is not created"; this PR does not modify GPTModel, so both rotaries actually coexist.
3. [IMPORTANT] _build_per_layer_rotary_pos_emb duplicates the rotary-construction logic from gpt_model.py; recommend extracting a shared factory function.
4. [IMPORTANT] No test coverage at all for rotary_base_per_layer.
5/6. [SUGGESTION] assert False, "Invalid position embedding type" is uninformative; getattr(self.config, 'rotary_base_per_layer', None) is unnecessary (the field is added by this PR).
if self.config.use_head_wise_attn_gate:
    self.g_proj = submodules.linear_qkv(
        self.config.hidden_size,
        self.config.num_attention_heads,
        config=self.config,
        init_method=not_none(self.config.init_method),
        gather_output=False,
        bias=False,
        skip_bias_add=False,
        is_expert=False,
        tp_comm_buffer_name='gate',
        tp_group=self.pg_collection.tp,
[CRITICAL Architecture] g_proj should be folded into linear_qkv and merged with the existing attention_output_gate path, rather than introducing a separate ColumnParallelLinear and a new forward branch.
Reason:
- Code duplication: attention_output_gate (per-head, per-channel; gate weights fused into linear_qkv, sliced out by _split_qkv, applied by _apply_output_gate) and use_head_wise_attn_gate (per-head scalar; independent g_proj; a new sigmoid-multiply block) are functionally near-parallel: they differ only in whether the gate granularity is hn or 1. Yet their implementation paths are completely independent. Future changes (GQA AG handling, CUDA graphs, FP8/FP4 compatibility) on either path will need to be replicated and will drift.
- TE/local backend divergence: g_proj = submodules.linear_qkv(...) resolves to TELayerNormColumnParallelLinear (with a fused LN that has its own learnable parameters) under the TE backend, but plain ColumnParallelLinear under the local backend, so the same config produces different mathematics on the two backends. Folding into linear_qkv makes QKV and gate share the same LN, which is consistent across backends and matches the pure-linear g_proj of the Step-3.5-Flash paper.
- One fewer GEMM kernel launch: the per-head scalar gate occupies only num_heads columns; folding it into linear_qkv is essentially free.
- _apply_output_gate does not need a branch: replace gate.view(*x.shape) with view(*x.shape[:-1], -1) (or remove it altogether), and PyTorch broadcasting naturally handles both gate shapes: [sq, b, np, hn] (existing) and [sq, b, np, 1] (new) both multiply core_attn_out [sq, b, np, hn] correctly element-wise (see the sketch after this list).
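To make the broadcasting point concrete, a tiny self-contained check (tensor shapes only; the names are illustrative, not the actual attention code):

    import torch

    sq, b, np_, hn = 4, 2, 8, 64
    core_attn_out = torch.randn(sq, b, np_, hn)

    # Existing attention_output_gate: per-head, per-channel gate.
    gate_per_channel = torch.sigmoid(torch.randn(sq, b, np_, hn))
    # New use_head_wise_attn_gate: per-head scalar gate.
    gate_per_head = torch.sigmoid(torch.randn(sq, b, np_, 1))

    # The same element-wise multiply covers both shapes via broadcasting,
    # so _apply_output_gate needs no extra branch.
    assert (core_attn_out * gate_per_channel).shape == (sq, b, np_, hn)
    assert (core_attn_out * gate_per_head).shape == (sq, b, np_, hn)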
Suggestion:
- Branch in the linear_qkv_out_dim calculation:

      if self.config.attention_output_gate:
          self.linear_qkv_out_dim += self.config.kv_channels * self.config.num_attention_heads  # per-head per-channel
      elif self.config.use_head_wise_attn_gate:
          self.linear_qkv_out_dim += self.config.num_attention_heads  # per-head scalar

- In _split_qkv, slice the trailing num_heads scalars as a new tail segment. Since this is not a multiple of hn, slice it off before the group reshape, or implement an unequal-width split inside the group. The GQA AG path needs corresponding updates: after the all-gather, slice QKV by group and slice the scalar segment by q_head separately.
- Change _apply_output_gate to rely on broadcasting so both gates share the same function.
- Delete the g_proj module, tp_comm_buffer_name='gate', and the new per-head gate forward block.
- Since only one gate type can occupy the linear_qkv tail, the two flags become naturally mutually exclusive (you may add an assert in __post_init__ to make this explicit). A rough sketch of the fused layout follows below.
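The sketch below illustrates the suggested fused layout; attribute names such as linear_qkv_out_dim follow the existing Attention code, but the helpers themselves are hypothetical and untested:

    # Illustrative only: widen linear_qkv by the gate columns instead of adding g_proj.
    def _gate_extra_columns(config) -> int:
        if config.attention_output_gate:
            # Per-head, per-channel gate: hn columns per head.
            return config.kv_channels * config.num_attention_heads
        if config.use_head_wise_attn_gate:
            # Per-head scalar gate: one column per head.
            return config.num_attention_heads
        return 0

    def _split_scalar_gate(mixed_qkv, num_heads, use_scalar_gate):
        # Slice the per-head scalar gate off the tail *before* the grouped QKV
        # reshape, since num_heads is not a multiple of hn.
        if not use_scalar_gate:
            return mixed_qkv, None
        gate = mixed_qkv[..., -num_heads:]  # last dim = num_heads (per-head scalars)
        return mixed_qkv[..., :-num_heads], gate

With this layout, _apply_output_gate can stay a single sigmoid-and-multiply that relies on broadcasting for both gate shapes.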
# Per-layer theta: override the model-level RoPE with this layer's own embedding.
if self.rotary_pos_emb is not None and rotary_pos_emb is not None:
    seq_len = rotary_pos_emb.shape[0]
    rotary_pos_emb = self.rotary_pos_emb(seq_len)
[IMPORTANT Correctness] The per-layer rotary override depends on the model-level rotary_pos_emb already existing, but the PR description states "shared model-level rotary_pos_emb in GPTModel is not created" — and this PR does not modify GPTModel.
Reason:
- Description does not match implementation: the model-level rotary is still created by GPTModel, and each layer just overrides it. Both rotaries coexist; the model-level one is computed every forward and immediately discarded, which is wasted compute on every iteration.
- Silent skip: if GPTModel ever stops passing rotary_pos_emb on some path (cross-attention, fully no-rope layers, user-disabled), the per-layer rotary will be silently skipped; there is no assert, and the behavior becomes unpredictable.
- seq_len = rotary_pos_emb.shape[0] implicitly couples to the external rotary tensor shape. That variable is wrapped into (rotary_pos_emb,) * 2 later in the forward; the override happens before the tuple wrap, so it is currently correct, but the coupling is fragile.
Suggestion:
- Modify GPTModel in tandem: when config.rotary_base_per_layer is not None, skip creating the model-level rotary_pos_emb and just pass seq_len to each layer so each layer generates its own rotary.
- Or keep the override-style implementation but obtain seq_len from inference_context / hidden_states.shape[0] instead of relying on the external rotary shape (sketched below).
- Add an assert: when per-layer rotary is enabled and rotary_pos_emb is None, raise explicitly instead of silently skipping.
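If the override style is kept, a minimal sketch of the guarded version (assuming hidden_states is laid out [sq, b, h] at this point in forward, as elsewhere in the attention code):

    # Illustrative guard for the per-layer rotary override.
    if self.rotary_pos_emb is not None:
        assert rotary_pos_emb is not None, (
            "rotary_base_per_layer is set but no rotary_pos_emb was passed to "
            "forward(); the per-layer rotary would otherwise be silently skipped."
        )
        # Derive seq_len from the activations rather than the external rotary tensor.
        seq_len = hidden_states.shape[0]
        rotary_pos_emb = self.rotary_pos_emb(seq_len)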
def _build_per_layer_rotary_pos_emb(self, rotary_base: float) -> None:
    """Build self.rotary_pos_emb using a layer-specific rotary base."""
    from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding

    seq_len_interpolation_factor = self.config.rotary_scaling_factor
    if self.config.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
        self.rotary_pos_emb = RotaryEmbedding(
            kv_channels=self.config.kv_channels,
            rotary_percent=self.config.rotary_percent,
            rotary_interleaved=self.config.rotary_interleaved,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
            rotary_base=rotary_base,
            rope_scaling=self.config.rope_scaling,
            rope_scaling_factor=self.config.rope_scaling_factor,
            use_cpu_initialization=self.config.use_cpu_initialization,
            cp_group=self.pg_collection.cp,
        )
    elif self.config.position_embedding_type == 'yarn':
        self.rotary_pos_emb = YarnRotaryEmbedding(
            kv_channels=self.config.kv_channels,
            rotary_percent=self.config.rotary_percent,
            rotary_interleaved=self.config.rotary_interleaved,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
            rotary_base=rotary_base,
            scaling_factor=getattr(self.config, "yarn_rotary_scaling_factor"),
            original_max_position_embeddings=getattr(
                self.config, "yarn_original_max_position_embeddings"
            ),
            beta_fast=getattr(self.config, "yarn_beta_fast"),
            beta_slow=getattr(self.config, "yarn_beta_slow"),
            mscale=getattr(self.config, "yarn_mscale"),
            mscale_all_dim=getattr(self.config, "yarn_mscale_all_dim"),
            correction_range_round_to_int=getattr(
                self.config, "yarn_correction_range_round_to_int"
            ),
            use_cpu_initialization=self.config.use_cpu_initialization,
        )
    elif self.config.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
        self.rotary_pos_emb = MultimodalRotaryEmbedding(
            kv_channels=self.config.kv_channels,
            rotary_percent=self.config.rotary_percent,
            rotary_interleaved=self.config.rotary_interleaved,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
            rotary_base=rotary_base,
        )
        self.mrope_section = self.config.mrope_section
        assert (
            self.mrope_section is not None
        ), "mrope require mrope_section setting, but we got None from TransformerConfig"
    else:
        assert False, "Invalid position embedding type"
[IMPORTANT Maintenance] _build_per_layer_rotary_pos_emb duplicates the RotaryEmbedding / YarnRotaryEmbedding / MultimodalRotaryEmbedding construction logic from gpt_model.py.
Reason: When gpt_model.py adds new rotary parameters (new rope_scaling parameters, new rotary types), this per-layer path will silently fall out of sync, causing per-layer rotary behavior to diverge from model-level rotary.
Suggestion: Extract a factory function (e.g. in models/common/embeddings/rotary_pos_embedding.py) with a signature like:
def build_rotary_pos_emb(
    config,
    rotary_base: Optional[float] = None,
    cp_group=None,
) -> RotaryEmbedding | YarnRotaryEmbedding | MultimodalRotaryEmbedding: ...

called from both GPTModel and here, so the two sites stay in lockstep.
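Hypothetical call sites for such a factory (build_rotary_pos_emb does not exist yet; argument names follow the suggested signature, and the layer indexing is illustrative):

    # In GPTModel.__init__: model-level rotary with the default base.
    self.rotary_pos_emb = build_rotary_pos_emb(self.config, cp_group=cp_group)

    # In SelfAttention.__init__: per-layer rotary with a layer-specific base.
    # layer_number is assumed 1-indexed here.
    if self.config.rotary_base_per_layer is not None:
        self.rotary_pos_emb = build_rotary_pos_emb(
            self.config,
            rotary_base=self.config.rotary_base_per_layer[self.layer_number - 1],
            cp_group=self.pg_collection.cp,
        )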
@@ -0,0 +1,212 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
[IMPORTANT Tests] Tests cover only use_head_wise_attn_gate; there are no tests for rotary_base_per_layer.
Reason: rotary_base_per_layer involves two independent paths (per-layer rotary construction, forward-time override of the model-level rotary), each with three branches (rope / yarn / mrope). Zero test coverage means any future refactor (including adopting the factory-function suggestion) has no safety net.
Suggestion: Add at minimum:
- A test that triggers the __post_init__ assert when len(rotary_base_per_layer) != num_layers (sketched below).
- A multi-layer model with distinct rotary_base per layer; verify each layer's self.rotary_pos_emb is non-None and that inv_freq matches the corresponding rotary_base.
- A forward-pass test where each layer's rotary output numerically matches a standalone RotaryEmbedding constructed with the same rotary_base.
- A test that with rotary_base_per_layer=None, self.rotary_pos_emb is None and forward runs without error.
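A rough pytest sketch of the first two checks (the fixtures and helper names are assumptions and would need adapting to the existing attention test utilities):

    import pytest
    import torch

    def test_rotary_base_per_layer_length_mismatch(make_config):
        # make_config: hypothetical helper that builds a TransformerConfig.
        with pytest.raises(AssertionError):
            make_config(num_layers=4, rotary_base_per_layer=[10000.0, 500000.0])  # wrong length

    def test_rotary_base_per_layer_inv_freq(two_layer_model):
        # two_layer_model: hypothetical fixture built with
        # rotary_base_per_layer=[10000.0, 500000.0].
        for layer, base in zip(two_layer_model.decoder.layers, [10000.0, 500000.0]):
            rope = layer.self_attention.rotary_pos_emb
            assert rope is not None
            dim = rope.inv_freq.numel() * 2  # assumes rotary_percent == 1.0
            expected = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
            torch.testing.assert_close(rope.inv_freq.cpu(), expected)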
        self.mrope_section is not None
    ), "mrope require mrope_section setting, but we got None from TransformerConfig"
else:
    assert False, "Invalid position embedding type"
[SUGGESTION Code] assert False, "Invalid position embedding type" does not include the actual position_embedding_type value, and assert False is semantically off — this is not an invariant violation, it's an unsupported configuration.
Suggestion:
else:
    raise NotImplementedError(
        f"rotary_base_per_layer does not support "
        f"position_embedding_type={self.config.position_embedding_type!r} "
        f"(only 'rope' / 'yarn' / 'mrope' are supported)."
    )
# Per-layer RotaryEmbedding (used when rotary_base_per_layer is set in config).
self.rotary_pos_emb = None
if getattr(self.config, 'rotary_base_per_layer', None):
[SUGGESTION Code] getattr(self.config, 'rotary_base_per_layer', None) uses getattr, which suggests the field might not exist — but the field is added by this PR to TransformerConfig, so it always exists.
Suggestion:
if self.config.rotary_base_per_layer is not None:

This avoids unnecessary cognitive overhead and prevents readers from wondering whether some other config type lacks this field.
What does this PR do?

feat(attention): Add attention_per_head_gate and rotary_base_per_layer for Step-3.5-Flash

Adds two new optional, off-by-default features to TransformerConfig and Attention to faithfully represent the Step-3.5-Flash architecture.

- attention_per_head_gate: adds a separate ColumnParallelLinear(hidden_size -> num_attention_heads) whose sigmoid output gates each head independently (Step-3.5-Flash g_proj). Applied after core attention, before linear_proj.
- rotary_base_per_layer: Optional[List[float]] -- per-layer RoPE theta values. When set, each SelfAttention creates its own RotaryEmbedding.

Both features default to False/None and have no effect on existing models.
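For illustration, here is how the two new options might be enabled on a config (a sketch assuming the field names used in the diff; other fields are left at their usual values):

    from megatron.core.transformer.transformer_config import TransformerConfig

    config = TransformerConfig(
        num_layers=4,
        hidden_size=1024,
        num_attention_heads=8,
        use_head_wise_attn_gate=True,                 # per-head sigmoid gate (g_proj)
        rotary_base_per_layer=[1e4, 5e5, 1e4, 5e5],   # one RoPE theta per layer
    )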
Issue tracking
For PRs from open-source community contributors:
Linked issue:
Contribution process
Pre-checks
Code review
Feel free to message or tag @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.
Step 1: Mark PR as "Ready for Review"
.github/CODEOWNERS. Final Review might get declined if these requirements are not fulfilled.
Step 2: Final Review
For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned. For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge
Any member of mcore-engineers will be able to merge your PR.
For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.