From c8f6e61999e54e2665ab2b9e92b3872f4319b070 Mon Sep 17 00:00:00 2001 From: William Dykas Date: Fri, 24 Apr 2026 09:57:46 -0700 Subject: [PATCH] mamba: shift silu(z) gate from RMSNormGated into selective_state_update In the decode path with rmsnorm=True, the previous code passed z=None to selective_state_update and then applied the gate and normalization together in the gated RMSNorm kernel: y = selective_state_update(..., z=None, ...) # SSM writes y y = self.norm(y, z) # silu(z)*y, rmsnorm, weight selective_state_update already supports applying the silu(z) gate to y inline via its HAS_Z path -- the work then happens in fp32 registers right after the state-C reduction, before y's bf16 round-trip out to HBM. Swapping the gate site lets the downstream RMSNormGated take its z=None fast path (skipping the gate step entirely): y = selective_state_update(..., z=z_reshaped, ...) # SSM writes silu(z)*y y = self.norm(y, None) # rmsnorm, weight Net effect: - One fewer HBM round-trip of z (SSM reads z in-kernel instead of the post-SSM gated-norm reading y and z separately). - Cheaper per-call cost for the gated norm (no gate work). - Math is identical up to bf16 rounding; in fact slightly more precise because y no longer round-trips through bf16 between SSM and norm. (NOTE: this equivalence relies on RMSNormGated applying the gate before its rmsnorm reduction, i.e. gate-then-norm ordering as in norm_before_gate=False; if the norm were configured to gate after the reduction, rmsnorm(silu(z)*y) and rmsnorm(y)*silu(z) would differ and this transformation would not be math-preserving.)
Measured on nano-v3 at BS=1, OSL=256, 10 iterations, outlier-trimmed p50: - gate_shift_off: 255.0 tok/s p50 - gate_shift_on: 257.5 tok/s p50 (+1.0%) Co-Authored-By: Claude Opus 4.7 (1M context) --- megatron/core/ssm/mamba_mixer.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 69c6d89c286..78ec7b48902 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -1194,8 +1194,16 @@ def _ssm_decode( B = rearrange(B, "b s (g n) -> b s g n", g=self.ngroups_local_tp) C = rearrange(C, "b s (g n) -> b s g n", g=self.ngroups_local_tp) x_reshaped = rearrange(x, "b s (h p) -> b s h p", p=self.headdim) - if not self.rmsnorm: - z = rearrange(z, "b s (h p) -> b s h p", p=self.headdim) + # Gate-shift optimization: when rmsnorm is on, the downstream + # ``self.norm(y, z)`` would apply ``silu(z) * y`` before the + # rmsnorm reduction. Passing z here instead lets the SSM kernel + # apply it inline in fp32 registers (HAS_Z fast path) and lets + # the gated norm skip its gate step -- one fewer HBM round-trip + # of z and a slightly cheaper post-norm kernel call. Math is + # identical up to bf16 rounding (in fact slightly more precise + # since y no longer does an intermediate bf16 round-trip + # between SSM and gated-norm). + z_reshaped = rearrange(z, "b s (h p) -> b s h p", p=self.headdim) y = selective_state_update( ssm_state, @@ -1205,7 +1213,7 @@ def _ssm_decode( B, C, D, - z=z if not self.rmsnorm else None, + z=z_reshaped, dt_bias=dt_bias, dt_softplus=True, state_batch_indices=batch_indices, @@ -1214,7 +1222,9 @@ def _ssm_decode( y = rearrange(y, "b s h p -> b s (h p)") if self.rmsnorm: - y = self.norm(y, z) + # Gate was already applied inside ``selective_state_update`` via + # HAS_Z, so pass z=None to the gated norm's no-gate fast path. + y = self.norm(y, None) return y