Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions megatron/core/ssm/mamba_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,8 +1197,16 @@ def _ssm_decode(
B = rearrange(B, "b s (g n) -> b s g n", g=self.ngroups_local_tp)
C = rearrange(C, "b s (g n) -> b s g n", g=self.ngroups_local_tp)
x_reshaped = rearrange(x, "b s (h p) -> b s h p", p=self.headdim)
if not self.rmsnorm:
z = rearrange(z, "b s (h p) -> b s h p", p=self.headdim)
# Gate-shift optimization: when rmsnorm is on, the downstream
# ``self.norm(y, z)`` would apply ``silu(z) * y`` before the
# rmsnorm reduction. Passing z here instead lets the SSM kernel
# apply it inline in fp32 registers (HAS_Z fast path) and lets
# the gated norm skip its gate step -- one fewer HBM round-trip
# of z and a slightly cheaper post-norm kernel call. Math is
# identical up to bf16 rounding (in fact slightly more precise
# since y no longer does an intermediate bf16 round-trip
# between SSM and gated-norm).
z_reshaped = rearrange(z, "b s (h p) -> b s h p", p=self.headdim)

y = selective_state_update(
ssm_state,
Expand All @@ -1208,7 +1216,7 @@ def _ssm_decode(
B,
C,
D,
z=z if not self.rmsnorm else None,
z=z_reshaped,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=batch_indices,
Expand All @@ -1217,7 +1225,9 @@ def _ssm_decode(
y = rearrange(y, "b s h p -> b s (h p)")

if self.rmsnorm:
y = self.norm(y, z)
# Gate was already applied inside ``selective_state_update`` via
# HAS_Z, so pass z=None to the gated norm's no-gate fast path.
y = self.norm(y, None)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be too crazy to subsume the norm in ssm update as well?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can try that this week as well

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ill keep trying but its slower for some reason every time I try. But that could just be a skill issue


return y

Expand Down
Loading