mamba: shift silu(z) gate from RMSNormGated into selective_state_update#4461
Draft
wdykas wants to merge 1 commit intoNVIDIA:mainfrom
Draft
mamba: shift silu(z) gate from RMSNormGated into selective_state_update#4461wdykas wants to merge 1 commit intoNVIDIA:mainfrom
wdykas wants to merge 1 commit intoNVIDIA:mainfrom
Conversation
In the decode path with rmsnorm=True, the previous code passed z=None to
selective_state_update and then applied the gate and normalization together
in the gated RMSNorm kernel:
y = selective_state_update(..., z=None, ...) # SSM writes y
y = self.norm(y, z) # silu(z)*y, rmsnorm, weight
selective_state_update already supports applying silu(z)*y inline via its
HAS_Z path -- the work then happens in fp32 registers right after the
state-C reduction, before y's bf16 round-trip out to HBM. Swapping the
gate site lets the downstream RMSNormGated take its z=None fast path
(skipping the gate step entirely):
y = selective_state_update(..., z=z_reshaped, ...) # SSM writes silu(z)*y
y = self.norm(y, None) # rmsnorm, weight
Net effect:
- One fewer HBM round-trip of z (SSM reads z in-kernel instead of the
post-SSM gated-norm reading y and z separately).
- Cheaper per-call cost for the gated norm (no gate work).
- Math is identical up to bf16 rounding; in fact slightly more precise
because y no longer round-trips through bf16 between SSM and norm.
Measured on nano-v3 at BS=1, OSL=256, 10 iterations, outlier-trimmed p50:
- gate_shift_off: 255.0 tok/s p50
- gate_shift_on: 257.5 tok/s p50 (+1.0%)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Contributor
|
This PR has been automatically converted to draft because all PRs must start as drafts. When you are ready for review, click Ready for Review to begin the review process. This will:
See the contribution guide for more details. |
Contributor
Author
|
This will need to be reviewed and tested much more. |
| y = self.norm(y, z) | ||
| # Gate was already applied inside ``selective_state_update`` via | ||
| # HAS_Z, so pass z=None to the gated norm's no-gate fast path. | ||
| y = self.norm(y, None) |
Contributor
There was a problem hiding this comment.
would it be too crazy to subsume the norm in ssm update as well?
Contributor
Author
There was a problem hiding this comment.
I can try that this week as well
Contributor
Author
There was a problem hiding this comment.
Ill keep trying but its slower for some reason every time I try. But that could just be a skill issue
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
In the decode path with rmsnorm=True, the previous code passed z=None to selective_state_update and then applied the gate and normalization together in the gated RMSNorm kernel:
selective_state_update already supports applying silu(z)*y inline via its HAS_Z path -- the work then happens in fp32 registers right after the state-C reduction, before y's bf16 round-trip out to HBM. Swapping the gate site lets the downstream RMSNormGated take its z=None fast path (skipping the gate step entirely):
Net effect:
Measured on nano-v3 at BS=1, OSL=256, 10 iterations, outlier-trimmed p50:
What does this PR do ?
Issue tracking
For PRs from open-source community contributors:
Linked issue:
Contribution process
Pre-checks
Code review
Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.
Step 1: Mark PR as "Ready for Review"
.github/CODEOWNERS.Final Review might get declined if these requirements are not fulfilled.
Step 2: Final Review
For PRs that change
megatron/core, once all expert reviewers have approved, theFinal Reviewlabel is applied automatically and final reviewers are assigned.For PRs outside
megatron/core, this step is skipped.Step 3: Approved
Once all required reviewers have approved, the
Approvedlabel is applied automatically.Merge
Any member of mcore-engineers will be able to merge your PR.
For MRs into `dev` branch
The proposed review process for `dev` branch is under active discussion.MRs are mergable after one approval by either
eharper@nvidia.comorzijiey@nvidia.com.