diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py
index 727c6ef5fd6..42e549dd170 100644
--- a/megatron/core/ssm/mamba_mixer.py
+++ b/megatron/core/ssm/mamba_mixer.py
@@ -1197,8 +1197,16 @@ def _ssm_decode(
         B = rearrange(B, "b s (g n) -> b s g n", g=self.ngroups_local_tp)
         C = rearrange(C, "b s (g n) -> b s g n", g=self.ngroups_local_tp)
         x_reshaped = rearrange(x, "b s (h p) -> b s h p", p=self.headdim)
-        if not self.rmsnorm:
-            z = rearrange(z, "b s (h p) -> b s h p", p=self.headdim)
+        # Gate-shift optimization: when rmsnorm is on, the downstream
+        # ``self.norm(y, z)`` would apply ``silu(z) * y`` before the
+        # rmsnorm reduction. Passing z here instead lets the SSM kernel
+        # apply it inline in fp32 registers (HAS_Z fast path) and lets
+        # the gated norm skip its gate step -- one fewer HBM round-trip
+        # of z and a slightly cheaper post-norm kernel call. Math is
+        # identical up to bf16 rounding (in fact slightly more precise
+        # since y no longer does an intermediate bf16 round-trip
+        # between SSM and gated-norm).
+        z_reshaped = rearrange(z, "b s (h p) -> b s h p", p=self.headdim)
         y = selective_state_update(
             ssm_state,
@@ -1208,7 +1216,7 @@
             B,
             C,
             D,
-            z=z if not self.rmsnorm else None,
+            z=z_reshaped,
             dt_bias=dt_bias,
             dt_softplus=True,
             state_batch_indices=batch_indices,
@@ -1217,7 +1225,9 @@
         y = rearrange(y, "b s h p -> b s (h p)")
         if self.rmsnorm:
-            y = self.norm(y, z)
+            # Gate was already applied inside ``selective_state_update`` via
+            # HAS_Z, so pass z=None to the gated norm's no-gate fast path.
+            y = self.norm(y, None)
         return y
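
Reviewer note: below is a minimal numerical sketch of the equivalence the new comment claims, assuming, as that comment states, that the gated norm applies `silu(z) * y` before the RMS reduction. The `gated_rmsnorm` helper and the toy shapes are hypothetical stand-ins for illustration, not the module under change.

```python
# Sketch only: checks that gating upstream of the norm (new path, gate
# fused into the SSM kernel) matches gating inside the gated norm (old path).
import torch
import torch.nn.functional as F


def gated_rmsnorm(y, z, weight, eps=1e-5):
    # Hypothetical stand-in for self.norm: gate first, then RMS-normalize,
    # the ordering the patch comment assumes for the gated norm.
    if z is not None:
        y = y * F.silu(z)
    rms = torch.rsqrt(y.pow(2).mean(dim=-1, keepdim=True) + eps)
    return y * rms * weight


torch.manual_seed(0)
b, s, d = 2, 1, 64  # single-token decode step
y = torch.randn(b, s, d)
z = torch.randn(b, s, d)
w = torch.ones(d)

# Old path: SSM kernel returns ungated y; the gated norm applies silu(z) * y.
ref = gated_rmsnorm(y, z, w)

# New path: gate applied upstream (inside selective_state_update via HAS_Z,
# emulated here in fp32), then the norm takes its no-gate fast path.
out = gated_rmsnorm(y * F.silu(z), None, w)

torch.testing.assert_close(ref, out)  # identical in fp32
```

In fp32 the two paths perform the same operations in the same order. In bf16 they differ only in where the cast lands: the old path writes y back in bf16 between the SSM kernel and the norm before gating, while the new path gates in fp32 registers first, which is the "slightly more precise" point the comment makes.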