diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py
index 2ff3eacc071..9cc39f7a197 100644
--- a/megatron/core/transformer/attention.py
+++ b/megatron/core/transformer/attention.py
@@ -387,6 +387,71 @@ def __init__(
         # the quantized tensor.
         set_save_original_input(self.linear_proj)

+        # Per-layer RotaryEmbedding (used when rotary_base_per_layer is set in config).
+        self.rotary_pos_emb = None
+        if getattr(self.config, 'rotary_base_per_layer', None):
+            rotary_base = self.config.rotary_base_per_layer[self.layer_number - 1]
+            self._build_per_layer_rotary_pos_emb(rotary_base)
+
+    def _build_per_layer_rotary_pos_emb(self, rotary_base: float) -> None:
+        """Build self.rotary_pos_emb using a layer-specific rotary base."""
+        from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
+
+        seq_len_interpolation_factor = self.config.rotary_scaling_factor
+        if self.config.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
+            self.rotary_pos_emb = RotaryEmbedding(
+                kv_channels=self.config.kv_channels,
+                rotary_percent=self.config.rotary_percent,
+                rotary_interleaved=self.config.rotary_interleaved,
+                seq_len_interpolation_factor=seq_len_interpolation_factor,
+                rotary_base=rotary_base,
+                rope_scaling=self.config.rope_scaling,
+                rope_scaling_factor=self.config.rope_scaling_factor,
+                use_cpu_initialization=self.config.use_cpu_initialization,
+                cp_group=self.pg_collection.cp,
+            )
+        elif self.config.position_embedding_type == 'yarn':
+            self.rotary_pos_emb = YarnRotaryEmbedding(
+                kv_channels=self.config.kv_channels,
+                rotary_percent=self.config.rotary_percent,
+                rotary_interleaved=self.config.rotary_interleaved,
+                seq_len_interpolation_factor=seq_len_interpolation_factor,
+                rotary_base=rotary_base,
+                scaling_factor=getattr(self.config, "yarn_rotary_scaling_factor"),
+                original_max_position_embeddings=getattr(
+                    self.config, "yarn_original_max_position_embeddings"
+                ),
+                beta_fast=getattr(self.config, "yarn_beta_fast"),
+                beta_slow=getattr(self.config, "yarn_beta_slow"),
+                mscale=getattr(self.config, "yarn_mscale"),
+                mscale_all_dim=getattr(self.config, "yarn_mscale_all_dim"),
+                correction_range_round_to_int=getattr(
+                    self.config, "yarn_correction_range_round_to_int"
+                ),
+                use_cpu_initialization=self.config.use_cpu_initialization,
+            )
+        elif (
+            self.config.position_embedding_type == 'mrope'
+            and not self.config.multi_latent_attention
+        ):
+            self.rotary_pos_emb = MultimodalRotaryEmbedding(
+                kv_channels=self.config.kv_channels,
+                rotary_percent=self.config.rotary_percent,
+                rotary_interleaved=self.config.rotary_interleaved,
+                seq_len_interpolation_factor=seq_len_interpolation_factor,
+                rotary_base=rotary_base,
+            )
+            self.mrope_section = self.config.mrope_section
+            assert (
+                self.mrope_section is not None
+            ), "mrope requires mrope_section to be set, but got None from TransformerConfig"
+        else:
+            raise NotImplementedError(
+                f"rotary_base_per_layer does not support "
+                f"position_embedding_type={self.config.position_embedding_type!r} "
+                f"(only 'rope' / 'yarn' / 'mrope' are supported)."
+            )
+
     def _checkpointed_attention_forward(
         self,
         query,
@@ -1024,6 +1089,11 @@ def forward(
         if no_rope:
             rotary_pos_emb = None

+        # Per-layer theta: override the model-level RoPE with this layer's own embedding.
+        if self.rotary_pos_emb is not None and rotary_pos_emb is not None:
+            seq_len = rotary_pos_emb.shape[0]
+            rotary_pos_emb = self.rotary_pos_emb(seq_len)
+
         inference_context = deprecate_inference_params(inference_context, inference_params)

         if inference_context and inference_context.is_dynamic_batching():
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index b41203eb0a1..90943a49cc8 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -244,6 +244,11 @@ class TransformerConfig(ModelParallelConfig):
     attention_output_gate: bool = False
     """Whether to apply output gate to the attention layers."""

+    rotary_base_per_layer: Optional[List[float]] = None
+    """Per-layer RoPE theta values. Length must equal num_layers. When set, each
+    SelfAttention layer creates its own RotaryEmbedding with the corresponding base;
+    the shared model-level rotary_pos_emb is not created."""
+
     test_mode: bool = False
     """Whether to run real-time tests."""

@@ -2523,6 +2528,12 @@ def __post_init__(self):
                 "2.3.0.dev0+39c0e70"
             ), "Must have at least TE version 2.3 or higher to use symmetric memory all reduce"

+        if self.rotary_base_per_layer is not None:
+            assert len(self.rotary_base_per_layer) == self.num_layers, (
+                f"rotary_base_per_layer length ({len(self.rotary_base_per_layer)}) "
+                f"must equal num_layers ({self.num_layers})"
+            )
+
         if self.no_rope_freq:
             assert not self.flash_decode, "flash_decode cannot be used with no_rope."
             if isinstance(self.no_rope_freq, int):
diff --git a/tests/unit_tests/transformer/test_rotary_base_per_layer.py b/tests/unit_tests/transformer/test_rotary_base_per_layer.py
new file mode 100644
index 00000000000..0a655094cb3
--- /dev/null
+++ b/tests/unit_tests/transformer/test_rotary_base_per_layer.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+"""Tests for per-layer RoPE base (rotary_base_per_layer) wiring in SelfAttention."""
+
+import pytest
+import torch
+
+from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
+from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+from megatron.core.transformer import TransformerConfig
+from megatron.core.transformer.attention import SelfAttention
+from tests.unit_tests.test_utilities import Utils
+
+SEQ_LEN = 16
+BATCH_SIZE = 2
+HIDDEN_SIZE = 128
+NUM_HEADS = 4
+NUM_LAYERS = 2
+ROTARY_BASE_L1 = 10000.0
+ROTARY_BASE_L2 = 5000.0
+
+
+def _make_config(rotary_base_per_layer=None) -> TransformerConfig:
+    config = TransformerConfig(
+        num_layers=NUM_LAYERS,
+        hidden_size=HIDDEN_SIZE,
+        num_attention_heads=NUM_HEADS,
+        use_cpu_initialization=True,
+        bf16=True,
+        params_dtype=torch.bfloat16,
+        rotary_base_per_layer=rotary_base_per_layer,
+    )
+    # _build_per_layer_rotary_pos_emb reads these attributes from config; they are
+    # normally injected by GPTModel but must be set manually in unit tests.
+    config.position_embedding_type = 'rope'
+    config.rotary_scaling_factor = None  # seq_len_interpolation_factor
+    config.rotary_percent = 1.0
+    config.rope_scaling = False
+    config.rope_scaling_factor = 8.0
+    return config
+
+
+def _make_attention(config: TransformerConfig, layer_number: int = 1) -> SelfAttention:
+    submodules = get_gpt_layer_local_spec().submodules.self_attention.submodules
+    return SelfAttention(config, submodules, layer_number=layer_number)
+
+
+class TestRotaryBasePerLayerInit:
+    """Verify that SelfAttention builds the correct per-layer RotaryEmbedding."""
+
+    @pytest.fixture(autouse=True)
+    def setup_teardown(self):
+        Utils.initialize_model_parallel(1, 1)
+        model_parallel_cuda_manual_seed(42)
+        yield
+        Utils.destroy_model_parallel()
+
+    def test_rotary_pos_emb_is_rope_instance(self):
+        """rotary_pos_emb is a RotaryEmbedding when rotary_base_per_layer is set."""
+        config = _make_config([ROTARY_BASE_L1, ROTARY_BASE_L2])
+        attn = _make_attention(config, layer_number=1)
+        assert isinstance(attn.rotary_pos_emb, RotaryEmbedding)
+
+    def test_rotary_pos_emb_none_without_per_layer_config(self):
+        """rotary_pos_emb stays None when rotary_base_per_layer is not set."""
+        config = TransformerConfig(
+            num_layers=NUM_LAYERS,
+            hidden_size=HIDDEN_SIZE,
+            num_attention_heads=NUM_HEADS,
+            use_cpu_initialization=True,
+            bf16=True,
+            params_dtype=torch.bfloat16,
+        )
+        attn = _make_attention(config, layer_number=1)
+        assert attn.rotary_pos_emb is None
+
+    def test_different_bases_produce_different_inv_freq(self):
+        """Layers with distinct bases must have different inv_freq tensors."""
+        config = _make_config([ROTARY_BASE_L1, ROTARY_BASE_L2])
+        attn1 = _make_attention(config, layer_number=1)
+        attn2 = _make_attention(config, layer_number=2)
+        assert not torch.allclose(attn1.rotary_pos_emb.inv_freq, attn2.rotary_pos_emb.inv_freq)
+
+    def test_same_base_produces_identical_inv_freq(self):
+        """Layers sharing the same base must have identical inv_freq tensors."""
+        config = _make_config([ROTARY_BASE_L1, ROTARY_BASE_L1])
+        attn1 = _make_attention(config, layer_number=1)
+        attn2 = _make_attention(config, layer_number=2)
+        torch.testing.assert_close(attn1.rotary_pos_emb.inv_freq, attn2.rotary_pos_emb.inv_freq)
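A minimal usage sketch (illustrative only, not part of the patch above): the geometric theta schedule and the model sizes below are made-up values; the only piece introduced by this change is the rotary_base_per_layer field, whose length __post_init__ checks against num_layers.

    from megatron.core.transformer import TransformerConfig

    num_layers = 4
    # Hypothetical schedule: interpolate theta geometrically from 10k to 1M across layers.
    # Any list of floats with len == num_layers is accepted.
    bases = [
        10_000.0 * (1_000_000.0 / 10_000.0) ** (i / (num_layers - 1)) for i in range(num_layers)
    ]

    config = TransformerConfig(
        num_layers=num_layers,
        hidden_size=128,
        num_attention_heads=4,
        use_cpu_initialization=True,
        rotary_base_per_layer=bases,  # each SelfAttention layer builds its own RotaryEmbedding
    )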