Add MambaLayer FSDP support and fix TP annotation for SSM parameters#4329
Closed
Phlip79 wants to merge 2 commits intoNVIDIA:mainfrom
Closed
Add MambaLayer FSDP support and fix TP annotation for SSM parameters#4329Phlip79 wants to merge 2 commits intoNVIDIA:mainfrom
Phlip79 wants to merge 2 commits intoNVIDIA:mainfrom
Conversation
MambaLayer (GraphableMegatronModule) was not recognized as an FSDP sharding unit, causing its parameters to remain in the root group and defeating ZeRO-3 param sharding for Mamba and hybrid models. Additionally, MambaMixer sets tensor_model_parallel and partition_dim directly on parameters (conv1d, A_log, dt_bias, D, norm.weight) rather than on the owning module. The TP annotation logic only checked module-level attributes, so these parameters were either unclassified or misclassified by the norm-name fallback (e.g. ExtendedRMSNorm treated as replicated when actually TP-sharded). Changes: - Register MambaLayer in default fsdp_unit_modules (mcore_fsdp_adapter) and sub_modules_to_wrap (torch_fully_sharded_data_parallel) - Add param-level TP attribute fallback in _detect_parallelism_type, placed before the norm-name fallback so TP-sharded norm weights are correctly classified - Pass param through from _annotate_tensor_parallelism - Add tests for param-level TP detection, norm override, and a MambaMixer-like end-to-end annotation test Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Member
Author
|
/ok to test 0ef2848 |
ilml
approved these changes
Apr 21, 2026
shjwudp
approved these changes
Apr 22, 2026
Contributor
shjwudp
left a comment
There was a problem hiding this comment.
LGTM, thanks @Phlip79 !
- This PR sets MambaLayer as the default sharding layer.
- IIt adds MFSDP-TP compatibility for MambaLayer — the layer now sets
partition_dimon parameters directly, rather than on modules.
It would be great to make Mamba + MoE the default model for E2E tests if you're up for it — this would help us catch and track Mamba-related regressions more reliably. Reference: tests/unit_tests/distributed/megatron_fsdp/utils.py#L55
Member
Author
|
Closing this PR in favor of #4467 which solves the same issue. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
MambaLayer (GraphableMegatronModule) was not recognized as an FSDP sharding unit, causing its parameters to remain in the root group and defeating ZeRO-3 param sharding for Mamba and hybrid models.
Additionally, MambaMixer sets tensor_model_parallel and partition_dim directly on parameters (conv1d, A_log, dt_bias, D, norm.weight) rather than on the owning module. The TP annotation logic only checked module-level attributes, so these parameters were either unclassified or misclassified by the norm-name fallback (e.g. ExtendedRMSNorm treated as replicated when actually TP-sharded).
Changes:
Testing
Tested using Qwen3-Next
HybridModelimplementation: wandb. Set the following flags: