
feat(attention): Add attention_per_head_gate and rotary_base_per_layer for Step-3.5-Flash #4473

Open
shifangx wants to merge 2 commits into NVIDIA:dev from shifangx:shifang/attention-for-step-3.5-flash

Conversation

Contributor

@shifangx shifangx commented Apr 26, 2026

What does this PR do?

feat(attention): Add attention_per_head_gate and rotary_base_per_layer for Step-3.5-Flash
Adds two new optional, off-by-default features to TransformerConfig and
Attention to faithfully represent the Step-3.5-Flash architecture.

  • attention_per_head_gate: adds a separate ColumnParallelLinear(hidden_size
    -> num_attention_heads) whose sigmoid output gates each head independently
    (Step-3.5-Flash g_proj). Applied after core attention, before linear_proj.

  • rotary_base_per_layer: Optional[List[float]] -- per-layer RoPE theta values.
    When set, each SelfAttention creates its own RotaryEmbedding.

Both features default to False/None and have no effect on existing models.
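
A minimal sketch of the per-head gating described above, using a plain torch.nn.Linear as a stand-in for the tensor-parallel g_proj (shapes and names are illustrative, not the actual Megatron module):

import torch

# Illustrative shapes: sq = sequence length, b = batch, heads and head_dim per attention head.
sq, b, num_heads, head_dim = 8, 2, 4, 16
hidden_size = num_heads * head_dim

hidden_states = torch.randn(sq, b, hidden_size)          # layer input
core_attn_out = torch.randn(sq, b, num_heads, head_dim)  # core attention output, per head

# attention_per_head_gate: a hidden_size -> num_heads projection (g_proj in the PR);
# its sigmoid output scales each head independently, after core attention and before linear_proj.
g_proj = torch.nn.Linear(hidden_size, num_heads, bias=False)
gate = torch.sigmoid(g_proj(hidden_states))               # [sq, b, num_heads]
gated = core_attn_out * gate.unsqueeze(-1)                # broadcast over head_dim

# gated is then flattened back to [sq, b, hidden_size] and fed to linear_proj.
print(gated.reshape(sq, b, hidden_size).shape)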

⚠️ For major changes (either in lines of code or in impact), please make sure to first share a design doc with the team. If you're unsure of the best way to do so, contact the @mcore-oncall.

Issue tracking

For PRs from open-source community contributors:

  • New features: a linked issue is required. Please open a feature request and reference it here before submitting the PR.
  • Small updates (bug fixes, minor improvements): a linked issue is recommended and will accelerate the PR review process.

Linked issue:

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into the `dev` branch: the proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

@shifangx shifangx requested review from a team as code owners April 26, 2026 00:06

copy-pr-bot Bot commented Apr 26, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@shifangx shifangx force-pushed the shifang/attention-for-step-3.5-flash branch 4 times, most recently from 3d6565f to 422246b, on April 26, 2026 at 15:17
feat(attention): Add attention_per_head_gate and rotary_base_per_layer for Step-3.5-Flash

Adds two new optional, off-by-default features to TransformerConfig and
SelfAttention to faithfully represent the Step-3.5-Flash architecture.

- attention_per_head_gate: adds a separate ColumnParallelLinear(hidden_size
  -> num_attention_heads) whose sigmoid output gates each head independently
  (Step-3.5-Flash g_proj). Applied after core attention, before linear_proj.

- rotary_base_per_layer: Optional[List[float]] -- per-layer RoPE theta values.
  When set, each SelfAttention creates its own RotaryEmbedding; the shared
  model-level rotary_pos_emb in GPTModel is not created.

Both features default to False/None and have no effect on existing models.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@shifangx shifangx force-pushed the shifang/attention-for-step-3.5-flash branch from 422246b to 6f80c19 on April 26, 2026 at 15:29
Contributor

@Victarry Victarry left a comment


PR adds two opt-in features (use_head_wise_attn_gate + rotary_base_per_layer) for Step-3.5-Flash; defaults are off, existing models unaffected.

Main suggestions (see inline comments for details):

  1. [CRITICAL] Fold use_head_wise_attn_gate into linear_qkv and merge it with the existing attention_output_gate path. Share _split_qkv / _apply_output_gate (broadcasting handles both gate shapes — no new branch needed). Benefits:

    • Eliminates the TE/local backend divergence in g_proj math (submodules.linear_qkv resolves to TELayerNormColumnParallelLinear under TE, which adds an independent learnable LayerNorm).
    • Avoids two near-parallel "output gating" implementations coexisting long-term and drifting.
    • Saves one GEMM kernel launch.
    • Makes the two gate flags naturally mutually exclusive.
  2. [IMPORTANT] The rotary_base_per_layer forward override depends on the model-level rotary_pos_emb already existing, contradicting the PR description's claim that "model-level rotary_pos_emb is not created" — this PR does not modify GPTModel, so both rotaries actually coexist.

  3. [IMPORTANT] _build_per_layer_rotary_pos_emb duplicates the rotary-construction logic from gpt_model.py; recommend extracting a shared factory function.

  4. [IMPORTANT] No test coverage at all for rotary_base_per_layer.

  5. [SUGGESTION] assert False, "Invalid position embedding type" is uninformative.
  6. [SUGGESTION] getattr(self.config, 'rotary_base_per_layer', None) is unnecessary (the field is added by this PR).

Comment on lines +390 to +401
if self.config.use_head_wise_attn_gate:
    self.g_proj = submodules.linear_qkv(
        self.config.hidden_size,
        self.config.num_attention_heads,
        config=self.config,
        init_method=not_none(self.config.init_method),
        gather_output=False,
        bias=False,
        skip_bias_add=False,
        is_expert=False,
        tp_comm_buffer_name='gate',
        tp_group=self.pg_collection.tp,

@Victarry Victarry Apr 27, 2026


[CRITICAL Architecture] g_proj should be folded into linear_qkv and merged with the existing attention_output_gate path, rather than introducing a separate ColumnParallelLinear and a new forward branch.

Reason:

  1. Code duplication: attention_output_gate (per-head per-channel; gate weights fused into linear_qkv, sliced out by _split_qkv, applied by _apply_output_gate) and use_head_wise_attn_gate (per-head scalar; independent g_proj; a new sigmoid-multiply block) are functionally near-parallel — they only differ in whether gate granularity is hn or 1. Yet their implementation paths are completely independent. Future changes (GQA AG handling, CUDA graphs, FP8/FP4 compatibility) on either path will need to be replicated and will drift.
  2. TE/local backend divergence: g_proj = submodules.linear_qkv(...) resolves to TELayerNormColumnParallelLinear (with a fused LN that has its own learnable parameters) under the TE backend, but plain ColumnParallelLinear under the local backend — the same config produces different mathematics on the two backends. Folding into linear_qkv makes QKV and gate share the same LN, which is consistent across backends and matches the pure-linear g_proj of the Step-3.5-Flash paper.
  3. One fewer GEMM kernel launch: the per-head scalar gate occupies only num_heads columns; folding it into linear_qkv is essentially free.
  4. _apply_output_gate does not need a branch: replace gate.view(*x.shape) with view(*x.shape[:-1], -1) (or remove it altogether), and PyTorch broadcasting naturally handles both gate shapes — [sq,b,np,hn] (existing) and [sq,b,np,1] (new) both multiply core_attn_out [sq,b,np,hn] correctly element-wise.

Suggestion:

  • Branch in the linear_qkv_out_dim calculation:
    if self.config.attention_output_gate:
        self.linear_qkv_out_dim += self.config.kv_channels * self.config.num_attention_heads  # per-head per-channel
    elif self.config.use_head_wise_attn_gate:
        self.linear_qkv_out_dim += self.config.num_attention_heads                            # per-head scalar
  • In _split_qkv, slice the trailing num_heads scalars as a new tail segment. Since this is not a multiple of hn, slice it off before the group reshape, or implement an unequal-width split inside the group. The GQA AG path needs corresponding updates: after the all-gather, slice QKV by group and slice the scalar segment by q_head separately.
  • Change _apply_output_gate to rely on broadcasting so both gates share the same function.
  • Delete the g_proj module, tp_comm_buffer_name='gate', and the new per-head gate forward block.
  • Since only one gate type can occupy the linear_qkv tail, the two flags become naturally mutually exclusive (you may add an assert in __post_init__ to make this explicit).
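
A minimal sketch of the broadcasting point in item 4 above, using a simplified stand-in for _apply_output_gate (sigmoid placement and shapes are illustrative, not the actual Megatron code):

import torch

def apply_output_gate(core_attn_out: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    # core_attn_out: [sq, b, np, hn]
    # gate: flattened trailing dim, either [sq, b, np*hn] (per-head per-channel,
    #       existing attention_output_gate) or [sq, b, np] (per-head scalar, new gate).
    # view(..., -1) yields [sq, b, np, hn] or [sq, b, np, 1]; broadcasting then handles
    # both shapes with a single multiply, so no extra branch is needed.
    gate = gate.view(*core_attn_out.shape[:-1], -1)
    return core_attn_out * torch.sigmoid(gate)

sq, b, num_heads, head_dim = 8, 2, 4, 16
x = torch.randn(sq, b, num_heads, head_dim)
per_channel_gate = torch.randn(sq, b, num_heads * head_dim)
per_head_gate = torch.randn(sq, b, num_heads)

assert apply_output_gate(x, per_channel_gate).shape == x.shape
assert apply_output_gate(x, per_head_gate).shape == x.shape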

Comment on lines +1052 to +1055
# Per-layer theta: override the model-level RoPE with this layer's own embedding.
if self.rotary_pos_emb is not None and rotary_pos_emb is not None:
    seq_len = rotary_pos_emb.shape[0]
    rotary_pos_emb = self.rotary_pos_emb(seq_len)

@Victarry Victarry Apr 27, 2026


[IMPORTANT Correctness] The per-layer rotary override depends on the model-level rotary_pos_emb already existing, but the PR description states "shared model-level rotary_pos_emb in GPTModel is not created" — and this PR does not modify GPTModel.

Reason:

  1. Description does not match implementation: the model-level rotary is still created by GPTModel, and each layer just overrides it. Both rotaries coexist; the model-level one is computed every forward and immediately discarded — wasted compute on every iteration.
  2. Silent skip: if GPTModel ever stops passing rotary_pos_emb on some path (cross-attention, fully no-rope layers, user-disabled), the per-layer rotary will be silently skipped — there's no assert, and the behavior becomes unpredictable.
  3. seq_len = rotary_pos_emb.shape[0] implicitly couples to the external rotary tensor shape. That variable is wrapped into (rotary_pos_emb,) * 2 later in the forward; the override happens before the tuple wrap so it's currently correct, but the coupling is fragile.

Suggestion:

  • Modify GPTModel in tandem: when config.rotary_base_per_layer is not None, skip creating the model-level rotary_pos_emb and just pass seq_len to each layer so each layer generates its own rotary.
  • Or keep the override-style implementation but obtain seq_len from inference_context / hidden_states.shape[0] instead of relying on the external rotary shape.
  • Add an assert: when per-layer rotary is enabled and rotary_pos_emb is None, raise explicitly instead of silently skipping.
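
A minimal sketch of the second bullet, with dummy stand-ins for the per-layer RotaryEmbedding and the model-level rotary tensor (function and variable names are hypothetical):

import torch

def resolve_rotary(layer_rotary, model_rotary_pos_emb, hidden_states):
    # layer_rotary: the per-layer rotary module (None when rotary_base_per_layer is unset).
    # Sequence length is taken from the layer input rather than from the external rotary
    # tensor, so the per-layer path no longer depends on GPTModel building its own rotary.
    if layer_rotary is None:
        return model_rotary_pos_emb
    assert hidden_states is not None, "per-layer rotary enabled but no input to size it from"
    seq_len = hidden_states.shape[0]  # [sq, b, h] layout assumed
    return layer_rotary(seq_len)

dummy_rotary = lambda seq_len: torch.zeros(seq_len, 1, 1, 4)  # stands in for RotaryEmbedding
out = resolve_rotary(dummy_rotary, None, torch.randn(8, 2, 64))
print(out.shape)  # torch.Size([8, 1, 1, 4])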

Comment on lines +404 to +454
def _build_per_layer_rotary_pos_emb(self, rotary_base: float) -> None:
    """Build self.rotary_pos_emb using a layer-specific rotary base."""
    from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding

    seq_len_interpolation_factor = self.config.rotary_scaling_factor
    if self.config.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
        self.rotary_pos_emb = RotaryEmbedding(
            kv_channels=self.config.kv_channels,
            rotary_percent=self.config.rotary_percent,
            rotary_interleaved=self.config.rotary_interleaved,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
            rotary_base=rotary_base,
            rope_scaling=self.config.rope_scaling,
            rope_scaling_factor=self.config.rope_scaling_factor,
            use_cpu_initialization=self.config.use_cpu_initialization,
            cp_group=self.pg_collection.cp,
        )
    elif self.config.position_embedding_type == 'yarn':
        self.rotary_pos_emb = YarnRotaryEmbedding(
            kv_channels=self.config.kv_channels,
            rotary_percent=self.config.rotary_percent,
            rotary_interleaved=self.config.rotary_interleaved,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
            rotary_base=rotary_base,
            scaling_factor=getattr(self.config, "yarn_rotary_scaling_factor"),
            original_max_position_embeddings=getattr(
                self.config, "yarn_original_max_position_embeddings"
            ),
            beta_fast=getattr(self.config, "yarn_beta_fast"),
            beta_slow=getattr(self.config, "yarn_beta_slow"),
            mscale=getattr(self.config, "yarn_mscale"),
            mscale_all_dim=getattr(self.config, "yarn_mscale_all_dim"),
            correction_range_round_to_int=getattr(
                self.config, "yarn_correction_range_round_to_int"
            ),
            use_cpu_initialization=self.config.use_cpu_initialization,
        )
    elif self.config.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
        self.rotary_pos_emb = MultimodalRotaryEmbedding(
            kv_channels=self.config.kv_channels,
            rotary_percent=self.config.rotary_percent,
            rotary_interleaved=self.config.rotary_interleaved,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
            rotary_base=rotary_base,
        )
        self.mrope_section = self.config.mrope_section
        assert (
            self.mrope_section is not None
        ), "mrope require mrope_section setting, but we got None from TransformerConfig"
    else:
        assert False, "Invalid position embedding type"

@Victarry Victarry Apr 27, 2026


[IMPORTANT Maintenance] _build_per_layer_rotary_pos_emb duplicates the RotaryEmbedding / YarnRotaryEmbedding / MultimodalRotaryEmbedding construction logic from gpt_model.py.

Reason: When gpt_model.py adds new rotary parameters (new rope_scaling parameters, new rotary types), this per-layer path will silently fall out of sync, causing per-layer rotary behavior to diverge from model-level rotary.

Suggestion: Extract a factory function (e.g. in models/common/embeddings/rotary_pos_embedding.py) with a signature like:

def build_rotary_pos_emb(
    config,
    rotary_base: Optional[float] = None,
    cp_group=None,
) -> RotaryEmbedding | YarnRotaryEmbedding | MultimodalRotaryEmbedding: ...

called from both GPTModel and here, so the two sites stay in lockstep.
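
A hedged sketch of what that factory might look like for the 'rope' branch only (build_rotary_pos_emb is the proposed name, not an existing Megatron API; keyword arguments mirror the quoted snippet above):

from typing import Optional

from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding

def build_rotary_pos_emb(config, rotary_base: Optional[float] = None, cp_group=None):
    """Single dispatch point shared by GPTModel and the per-layer path in Attention.
    Resolving a None rotary_base to the model-level base is left to the real code."""
    if config.position_embedding_type == 'rope' and not config.multi_latent_attention:
        return RotaryEmbedding(
            kv_channels=config.kv_channels,
            rotary_percent=config.rotary_percent,
            rotary_interleaved=config.rotary_interleaved,
            seq_len_interpolation_factor=config.rotary_scaling_factor,
            rotary_base=rotary_base,
            rope_scaling=config.rope_scaling,
            rope_scaling_factor=config.rope_scaling_factor,
            use_cpu_initialization=config.use_cpu_initialization,
            cp_group=cp_group,
        )
    # 'yarn' and 'mrope' branches would mirror the quoted _build_per_layer_rotary_pos_emb.
    raise NotImplementedError(
        f"build_rotary_pos_emb does not support "
        f"position_embedding_type={config.position_embedding_type!r}"
    )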

@@ -0,0 +1,212 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

@Victarry Victarry Apr 27, 2026


[IMPORTANT Tests] Tests cover only use_head_wise_attn_gate; there are no tests for rotary_base_per_layer.

Reason: rotary_base_per_layer involves two independent paths (per-layer rotary construction, forward-time override of the model-level rotary), each with three branches (rope / yarn / mrope). Zero test coverage means any future refactor (including adopting the factory-function suggestion) has no safety net.

Suggestion: Add at minimum:

  • A test that triggers the __post_init__ assert when len(rotary_base_per_layer) != num_layers.
  • A multi-layer model with distinct rotary_base per layer; verify each layer's self.rotary_pos_emb is non-None and that inv_freq matches the corresponding rotary_base.
  • A forward-pass test where each layer's rotary output numerically matches a standalone RotaryEmbedding constructed with the same rotary_base.
  • A test that with rotary_base_per_layer=None, self.rotary_pos_emb is None and forward runs without error.
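
A hedged pytest sketch of the first two bullets, assuming the config field and the length check in __post_init__ land as described in this PR (the fixture name, exception type, and the inv_freq attribute are assumptions):

import pytest
import torch

from megatron.core.transformer.transformer_config import TransformerConfig

def test_rotary_base_per_layer_length_mismatch():
    # Assumes __post_init__ asserts len(rotary_base_per_layer) == num_layers.
    with pytest.raises(AssertionError):
        TransformerConfig(
            num_layers=4,
            hidden_size=64,
            num_attention_heads=4,
            rotary_base_per_layer=[10000.0, 50000.0],  # wrong length: 2 != 4
        )

def test_per_layer_rotary_is_built_per_layer(model_with_per_layer_rope):
    # model_with_per_layer_rope: hypothetical fixture building a two-layer GPT model
    # with rotary_base_per_layer=[10000.0, 50000.0].
    inv_freqs = []
    for layer in model_with_per_layer_rope.decoder.layers:
        rotary = layer.self_attention.rotary_pos_emb
        assert rotary is not None
        inv_freqs.append(rotary.inv_freq)
    # Distinct per-layer bases must yield distinct rotary frequencies.
    assert not torch.allclose(inv_freqs[0], inv_freqs[1])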

        self.mrope_section is not None
    ), "mrope require mrope_section setting, but we got None from TransformerConfig"
else:
    assert False, "Invalid position embedding type"

@Victarry Victarry Apr 27, 2026


[SUGGESTION Code] assert False, "Invalid position embedding type" does not include the actual position_embedding_type value, and assert False is semantically off — this is not an invariant violation, it's an unsupported configuration.

Suggestion:

else:
    raise NotImplementedError(
        f"rotary_base_per_layer does not support "
        f"position_embedding_type={self.config.position_embedding_type!r} "
        f"(only 'rope' / 'yarn' / 'mrope' are supported)."
    )


# Per-layer RotaryEmbedding (used when rotary_base_per_layer is set in config).
self.rotary_pos_emb = None
if getattr(self.config, 'rotary_base_per_layer', None):

@Victarry Victarry Apr 27, 2026


[SUGGESTION Code] getattr(self.config, 'rotary_base_per_layer', None) uses getattr, which suggests the field might not exist — but the field is added by this PR to TransformerConfig, so it always exists.

Suggestion:

if self.config.rotary_base_per_layer is not None:

This avoids unnecessary cognitive overhead and prevents readers from wondering whether some other config type lacks this field.
