From c4f98f15116825865adc30b3b2f7711bf73245f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=B1=E5=85=83=E8=B1=AA?= <146086744+edenfunf@users.noreply.github.com> Date: Sat, 25 Apr 2026 09:02:17 +0800 Subject: [PATCH] docs(moe): correct moe_router_topk_scaling_factor docstring (#1875) The docstring stated the scaling factor "only works when moe_router_pre_softmax enabled", but topk_routing_with_score_function applies it unconditionally on the post-top-k probabilities for every score_function and pre_softmax setting. All shipped recipes (DeepSeek-R1, Kimi-K2, Llama-4, Nemotron, etc.) rely on this by combining a non-None scaling factor with the default moe_router_pre_softmax=False. Update the docstring to describe the actual behavior, and add a parametrized unit test in test_routers.py that pins the contract across {pre_softmax, score_function}. --- .../core/transformer/transformer_config.py | 6 ++- .../transformer/moe/test_routers.py | 47 ++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 40c1a745493..e1236ddeaca 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -692,8 +692,10 @@ class TransformerConfig(ModelParallelConfig): By default, softmax is done after top-k.""" moe_router_topk_scaling_factor: Optional[float] = None - """Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax - enabled. Defaults to None, which means no scaling.""" + """Scaling factor applied to the routing probabilities after the top-k selection. + The scaling is applied unconditionally whenever this value is not None, for every + `moe_router_score_function` (softmax/sigmoid/sqrtsoftplus) and regardless of whether + `moe_router_pre_softmax` is enabled. Defaults to None, which means no scaling.""" moe_router_score_function: Literal['softmax', 'sigmoid', 'sqrtsoftplus'] = "softmax" """Score function for MoE routing. Can be "softmax", "sigmoid" or "sqrtsoftplus".""" diff --git a/tests/unit_tests/transformer/moe/test_routers.py b/tests/unit_tests/transformer/moe/test_routers.py index 8f3dbbe96e0..edaf9ca6666 100644 --- a/tests/unit_tests/transformer/moe/test_routers.py +++ b/tests/unit_tests/transformer/moe/test_routers.py @@ -8,7 +8,11 @@ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_submodules from megatron.core.transformer.moe.moe_layer import MoELayer -from megatron.core.transformer.moe.moe_utils import get_updated_expert_bias, router_gating_linear +from megatron.core.transformer.moe.moe_utils import ( + get_updated_expert_bias, + router_gating_linear, + topk_routing_with_score_function, +) from megatron.core.transformer.moe.router import Router from megatron.core.transformer.transformer_config import TransformerConfig from megatron.training.initialize import _set_random_seed @@ -563,3 +567,44 @@ def test_router_gating_linear_bias(router_dtype): assert torch.allclose(inp.grad, ref_inp.grad, **tols) assert torch.allclose(weight.grad, ref_weight.grad, **tols) assert torch.allclose(bias.grad, ref_bias.grad, **tols) + + +@pytest.mark.internal +@pytest.mark.parametrize("use_pre_softmax", [True, False]) +@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) +def test_topk_scaling_factor_applies_for_all_pre_softmax_settings(use_pre_softmax, score_function): + """`moe_router_topk_scaling_factor` should multiply the post-top-k probabilities for + every combination of `score_function` and `use_pre_softmax`. + + Pins the documented behavior of `moe_router_topk_scaling_factor` (issue #1875): the + scaling is applied unconditionally inside `topk_routing_with_score_function`, not + only when `moe_router_pre_softmax=True`.
+ """ + torch.manual_seed(0) + logits = torch.randn(8, 4) + scaling = 2.5 + + probs_unscaled, map_unscaled = topk_routing_with_score_function( + logits, + topk=2, + use_pre_softmax=use_pre_softmax, + scaling_factor=None, + score_function=score_function, + ) + probs_scaled, map_scaled = topk_routing_with_score_function( + logits, + topk=2, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling, + score_function=score_function, + ) + + # The selected experts must be identical: scaling is monotonic and only changes + # magnitude of probabilities, not which experts win the top-k selection. + assert torch.equal(map_scaled, map_unscaled) + selected = map_unscaled + assert selected.any(), "Sanity: at least one expert should be selected" + + torch.testing.assert_close( + probs_scaled[selected], probs_unscaled[selected] * scaling, rtol=1e-6, atol=1e-6 + )