feat(attention): Add rotary_base_per_layer for Step-3.5-Flash #4473
Merged: shifangx merged 9 commits into NVIDIA:dev from shifangx:shifang/attention-for-step-3.5-flash on May 3, 2026 (+172 −0).
Commits (9 total; the diff below reflects changes from 7 commits):
- 6f80c19 feat(attention): Add attention_per_head_gate and rotary_base_per_laye… (zidanehuang001)
- cc4d1a1 add test (shifangx)
- 633b9e3 Fold use_head_wise_attn_gate into linear_qkv and merge it with the ex… (shifangx)
- 6ad0ceb add test for rotary_base_per_layer (shifangx)
- 7485da4 Merge branch 'dev' into shifang/attention-for-step-3.5-flash (shifangx)
- 5671195 formate (shifangx)
- bea28f6 revert head_wise_attn_gate (shifangx)
- 90857df fix issue with tests/unit_tests/models/test_hybrid_moe_model.py (shifangx)
- 027d8d2 Merge branch 'dev' into shifang/attention-for-step-3.5-flash (shifangx)
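The feature under test is a new `rotary_base_per_layer` config option: a list with one RoPE base per transformer layer, so each layer can build its own rotary frequency table. As a point of reference, here is a minimal sketch of the standard RoPE inverse-frequency computation using the same two bases the new test file exercises; this is an illustration only, not code from the PR.

```python
import torch


def rope_inv_freq(base: float, dim: int) -> torch.Tensor:
    """Standard RoPE inverse frequencies: 1 / base**(2i / dim) for i in [0, dim / 2)."""
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))


head_dim = 128 // 4  # HIDDEN_SIZE // NUM_HEADS from the test file below
inv_freq_layer1 = rope_inv_freq(10000.0, head_dim)  # ROTARY_BASE_L1
inv_freq_layer2 = rope_inv_freq(5000.0, head_dim)   # ROTARY_BASE_L2

# A different base yields a different frequency table, which is what the new
# test_different_bases_produce_different_inv_freq asserts on the real modules.
assert not torch.allclose(inv_freq_layer1, inv_freq_layer2)
```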
90 changes: 90 additions & 0 deletions
90
tests/unit_tests/transformer/test_rotary_base_per_layer.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
```python
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

"""Tests for per-layer RoPE base (rotary_base_per_layer) wiring in SelfAttention."""

import pytest
import torch

from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.attention import SelfAttention
from tests.unit_tests.test_utilities import Utils

SEQ_LEN = 16
BATCH_SIZE = 2
HIDDEN_SIZE = 128
NUM_HEADS = 4
NUM_LAYERS = 2
ROTARY_BASE_L1 = 10000.0
ROTARY_BASE_L2 = 5000.0


def _make_config(rotary_base_per_layer=None) -> TransformerConfig:
    config = TransformerConfig(
        num_layers=NUM_LAYERS,
        hidden_size=HIDDEN_SIZE,
        num_attention_heads=NUM_HEADS,
        use_cpu_initialization=True,
        bf16=True,
        params_dtype=torch.bfloat16,
        rotary_base_per_layer=rotary_base_per_layer,
    )
    # _build_per_layer_rotary_pos_emb reads these attributes from config; they are
    # normally injected by GPTModel but must be set manually in unit tests.
    config.position_embedding_type = 'rope'
    config.rotary_scaling_factor = None  # seq_len_interpolation_factor
    config.rotary_percent = 1.0
    config.rope_scaling = False
    config.rope_scaling_factor = 8.0
    return config


def _make_attention(config: TransformerConfig, layer_number: int = 1) -> SelfAttention:
    submodules = get_gpt_layer_local_spec().submodules.self_attention.submodules
    return SelfAttention(config, submodules, layer_number=layer_number)


class TestRotaryBasePerLayerInit:
    """Verify that SelfAttention builds the correct per-layer RotaryEmbedding."""

    @pytest.fixture(autouse=True)
    def setup_teardown(self):
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(42)
        yield
        Utils.destroy_model_parallel()

    def test_rotary_pos_emb_is_rope_instance(self):
        """rotary_pos_emb is a RotaryEmbedding when rotary_base_per_layer is set."""
        config = _make_config([ROTARY_BASE_L1, ROTARY_BASE_L2])
        attn = _make_attention(config, layer_number=1)
        assert isinstance(attn.rotary_pos_emb, RotaryEmbedding)

    def test_rotary_pos_emb_none_without_per_layer_config(self):
        """rotary_pos_emb stays None when rotary_base_per_layer is not set."""
        config = TransformerConfig(
            num_layers=NUM_LAYERS,
            hidden_size=HIDDEN_SIZE,
            num_attention_heads=NUM_HEADS,
            use_cpu_initialization=True,
            bf16=True,
            params_dtype=torch.bfloat16,
        )
        attn = _make_attention(config, layer_number=1)
        assert attn.rotary_pos_emb is None

    def test_different_bases_produce_different_inv_freq(self):
        """Layers with distinct bases must have different inv_freq tensors."""
        config = _make_config([ROTARY_BASE_L1, ROTARY_BASE_L2])
        attn1 = _make_attention(config, layer_number=1)
        attn2 = _make_attention(config, layer_number=2)
        assert not torch.allclose(attn1.rotary_pos_emb.inv_freq, attn2.rotary_pos_emb.inv_freq)

    def test_same_base_produces_identical_inv_freq(self):
        """Layers sharing the same base must have identical inv_freq tensors."""
        config = _make_config([ROTARY_BASE_L1, ROTARY_BASE_L1])
        attn1 = _make_attention(config, layer_number=1)
        attn2 = _make_attention(config, layer_number=2)
        torch.testing.assert_close(attn1.rotary_pos_emb.inv_freq, attn2.rotary_pos_emb.inv_freq)
```
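For readers skimming the tests: the assertions imply a simple layer-to-base mapping in which the 1-indexed `layer_number` selects an entry from `rotary_base_per_layer`. The sketch below is a hypothetical standalone helper (`select_rotary_base` is not part of the PR); the real wiring lives inside `SelfAttention`, which is what `_make_attention(config, layer_number=...)` exercises above.

```python
from typing import Optional, Sequence


def select_rotary_base(
    rotary_base_per_layer: Optional[Sequence[float]],
    layer_number: int,
    default_base: float = 10000.0,
) -> float:
    """Hypothetical helper: pick the RoPE base for a 1-indexed layer number.

    Without a per-layer list, fall back to a single global base.
    """
    if rotary_base_per_layer is None:
        return default_base
    return rotary_base_per_layer[layer_number - 1]


# Mirrors the test constants: layer 1 -> ROTARY_BASE_L1, layer 2 -> ROTARY_BASE_L2.
assert select_rotary_base([10000.0, 5000.0], layer_number=1) == 10000.0
assert select_rotary_base([10000.0, 5000.0], layer_number=2) == 5000.0
```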