70 changes: 70 additions & 0 deletions megatron/core/transformer/attention.py
@@ -387,6 +387,71 @@ def __init__(
# the quantized tensor.
set_save_original_input(self.linear_proj)

# Per-layer RotaryEmbedding (used when rotary_base_per_layer is set in config).
self.rotary_pos_emb = None
if getattr(self.config, 'rotary_base_per_layer', None):
rotary_base = self.config.rotary_base_per_layer[self.layer_number - 1]
self._build_per_layer_rotary_pos_emb(rotary_base)

def _build_per_layer_rotary_pos_emb(self, rotary_base: float) -> None:
"""Build self.rotary_pos_emb using a layer-specific rotary base."""
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding

# NOTE: YarnRotaryEmbedding and MultimodalRotaryEmbedding are referenced in the 'yarn'
# and 'mrope' branches below, so they must be importable here as well; the module paths
# below are assumed to mirror the RotaryEmbedding import above.
from megatron.core.models.common.embeddings.yarn_rotary_pos_embedding import YarnRotaryEmbedding
from megatron.core.models.common.embeddings.multimodal_rotary_pos_embedding import MultimodalRotaryEmbedding

seq_len_interpolation_factor = self.config.rotary_scaling_factor
if self.config.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=self.config.rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
rope_scaling=self.config.rope_scaling,
rope_scaling_factor=self.config.rope_scaling_factor,
use_cpu_initialization=self.config.use_cpu_initialization,
cp_group=self.pg_collection.cp,
)
elif self.config.position_embedding_type == 'yarn':
self.rotary_pos_emb = YarnRotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=self.config.rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
scaling_factor=self.config.yarn_rotary_scaling_factor,
original_max_position_embeddings=self.config.yarn_original_max_position_embeddings,
beta_fast=self.config.yarn_beta_fast,
beta_slow=self.config.yarn_beta_slow,
mscale=self.config.yarn_mscale,
mscale_all_dim=self.config.yarn_mscale_all_dim,
correction_range_round_to_int=self.config.yarn_correction_range_round_to_int,
use_cpu_initialization=self.config.use_cpu_initialization,
)
elif (
self.config.position_embedding_type == 'mrope'
and not self.config.multi_latent_attention
):
self.rotary_pos_emb = MultimodalRotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=self.config.rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
)
self.mrope_section = self.config.mrope_section
assert (
self.mrope_section is not None
), "mrope require mrope_section setting, but we got None from TransformerConfig"
else:
raise NotImplementedError(
f"rotary_base_per_layer does not support "
f"position_embedding_type={self.config.position_embedding_type!r} "
f"(only 'rope' / 'yarn' / 'mrope' are supported)."
)

def _checkpointed_attention_forward(
self,
query,
@@ -1024,6 +1089,11 @@ def forward(
if no_rope:
rotary_pos_emb = None

# Per-layer theta: override the model-level RoPE with this layer's own embedding.
if self.rotary_pos_emb is not None and rotary_pos_emb is not None:
seq_len = rotary_pos_emb.shape[0]
rotary_pos_emb = self.rotary_pos_emb(seq_len)

inference_context = deprecate_inference_params(inference_context, inference_params)

if inference_context and inference_context.is_dynamic_batching():
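Taken together, the constructor and forward changes above mean a layer's own RoPE table is used only when both the per-layer embedding was built and a model-level embedding is passed into forward. Below is a minimal, self-contained sketch of that selection logic; PerLayerRope and pick_rotary_emb are illustrative names, not Megatron APIs, and the sketch deliberately ignores rotary_percent, interpolation, and context-parallel groups.

# Minimal sketch of the per-layer RoPE override, under the assumption that the
# model-level embedding is still computed and passed in (its length sets seq_len).
from typing import List, Optional

import torch


class PerLayerRope:
    """Toy stand-in for a layer-local rotary embedding keyed by a per-layer base."""

    def __init__(self, rotary_base_per_layer: Optional[List[float]], layer_number: int, dim: int):
        self.inv_freq = None
        if rotary_base_per_layer:
            base = rotary_base_per_layer[layer_number - 1]  # layer_number is 1-indexed
            self.inv_freq = 1.0 / (
                base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
            )

    def __call__(self, seq_len: int) -> torch.Tensor:
        # Standard RoPE angle table: freqs[s, d] = s * inv_freq[d].
        positions = torch.arange(seq_len, dtype=torch.float32)
        return torch.outer(positions, self.inv_freq)


def pick_rotary_emb(per_layer: PerLayerRope, model_level_emb: Optional[torch.Tensor]):
    # Mirrors the forward-path override: the per-layer table replaces the incoming
    # model-level embedding only when both exist.
    if per_layer.inv_freq is not None and model_level_emb is not None:
        return per_layer(model_level_emb.shape[0])
    return model_level_emb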
11 changes: 11 additions & 0 deletions megatron/core/transformer/transformer_config.py
Expand Up @@ -244,6 +244,11 @@ class TransformerConfig(ModelParallelConfig):
attention_output_gate: bool = False
"""Whether to apply output gate to the attention layers."""

rotary_base_per_layer: Optional[List[float]] = None
"""Per-layer RoPE theta values. Length must equal num_layers. When set, each
SelfAttention layer creates its own RotaryEmbedding with the corresponding base;
the shared model-level rotary_pos_emb is not created."""

test_mode: bool = False
"""Whether to run real-time tests."""

@@ -2523,6 +2528,12 @@ def __post_init__(self):
"2.3.0.dev0+39c0e70"
), "Must have at least TE version 2.3 or higher to use symmetric memory all reduce"

if self.rotary_base_per_layer is not None:
assert len(self.rotary_base_per_layer) == self.num_layers, (
f"rotary_base_per_layer length ({len(self.rotary_base_per_layer)}) "
f"must equal num_layers ({self.num_layers})"
)

if self.no_rope_freq:
assert not self.flash_decode, "flash_decode cannot be used with no_rope."
if isinstance(self.no_rope_freq, int):
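For reference, a usage sketch of the new field (not part of the PR): the validation above only requires one base per layer, so any schedule of positive floats works. The halving schedule below is purely illustrative, and the other constructor arguments are copied from the unit test's own minimal config.

# Illustrative only: build a config with one RoPE base per layer.
from megatron.core.transformer import TransformerConfig

num_layers = 4
bases = [10000.0 / (2 ** i) for i in range(num_layers)]  # [10000.0, 5000.0, 2500.0, 1250.0]

config = TransformerConfig(
    num_layers=num_layers,
    hidden_size=128,
    num_attention_heads=4,
    use_cpu_initialization=True,
    rotary_base_per_layer=bases,  # length must equal num_layers or __post_init__ asserts
)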
90 changes: 90 additions & 0 deletions tests/unit_tests/transformer/test_rotary_base_per_layer.py
@@ -0,0 +1,90 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

"""Tests for per-layer RoPE base (rotary_base_per_layer) wiring in SelfAttention."""

import pytest
import torch

from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.attention import SelfAttention
from tests.unit_tests.test_utilities import Utils

SEQ_LEN = 16
BATCH_SIZE = 2
HIDDEN_SIZE = 128
NUM_HEADS = 4
NUM_LAYERS = 2
ROTARY_BASE_L1 = 10000.0
ROTARY_BASE_L2 = 5000.0


def _make_config(rotary_base_per_layer=None) -> TransformerConfig:
config = TransformerConfig(
num_layers=NUM_LAYERS,
hidden_size=HIDDEN_SIZE,
num_attention_heads=NUM_HEADS,
use_cpu_initialization=True,
bf16=True,
params_dtype=torch.bfloat16,
rotary_base_per_layer=rotary_base_per_layer,
)
# _build_per_layer_rotary_pos_emb reads these attributes from config; they are
# normally injected by GPTModel but must be set manually in unit tests.
config.position_embedding_type = 'rope'
config.rotary_scaling_factor = None # seq_len_interpolation_factor
config.rotary_percent = 1.0
config.rope_scaling = False
config.rope_scaling_factor = 8.0
return config


def _make_attention(config: TransformerConfig, layer_number: int = 1) -> SelfAttention:
submodules = get_gpt_layer_local_spec().submodules.self_attention.submodules
return SelfAttention(config, submodules, layer_number=layer_number)


class TestRotaryBasePerLayerInit:
"""Verify that SelfAttention builds the correct per-layer RotaryEmbedding."""

@pytest.fixture(autouse=True)
def setup_teardown(self):
Utils.initialize_model_parallel(1, 1)
model_parallel_cuda_manual_seed(42)
yield
Utils.destroy_model_parallel()

def test_rotary_pos_emb_is_rope_instance(self):
"""rotary_pos_emb is a RotaryEmbedding when rotary_base_per_layer is set."""
config = _make_config([ROTARY_BASE_L1, ROTARY_BASE_L2])
attn = _make_attention(config, layer_number=1)
assert isinstance(attn.rotary_pos_emb, RotaryEmbedding)

def test_rotary_pos_emb_none_without_per_layer_config(self):
"""rotary_pos_emb stays None when rotary_base_per_layer is not set."""
config = TransformerConfig(
num_layers=NUM_LAYERS,
hidden_size=HIDDEN_SIZE,
num_attention_heads=NUM_HEADS,
use_cpu_initialization=True,
bf16=True,
params_dtype=torch.bfloat16,
)
attn = _make_attention(config, layer_number=1)
assert attn.rotary_pos_emb is None

def test_different_bases_produce_different_inv_freq(self):
"""Layers with distinct bases must have different inv_freq tensors."""
config = _make_config([ROTARY_BASE_L1, ROTARY_BASE_L2])
attn1 = _make_attention(config, layer_number=1)
attn2 = _make_attention(config, layer_number=2)
assert not torch.allclose(attn1.rotary_pos_emb.inv_freq, attn2.rotary_pos_emb.inv_freq)

def test_same_base_produces_identical_inv_freq(self):
"""Layers sharing the same base must have identical inv_freq tensors."""
config = _make_config([ROTARY_BASE_L1, ROTARY_BASE_L1])
attn1 = _make_attention(config, layer_number=1)
attn2 = _make_attention(config, layer_number=2)
torch.testing.assert_close(attn1.rotary_pos_emb.inv_freq, attn2.rotary_pos_emb.inv_freq)