Add MambaLayer FSDP support and fix TP annotation for SSM parameters #4329

Closed

Phlip79 wants to merge 2 commits into NVIDIA:main from Phlip79:worktree-agent-a7696c2d

Conversation

@Phlip79
Member

@Phlip79 Phlip79 commented Apr 15, 2026

MambaLayer (GraphableMegatronModule) was not recognized as an FSDP sharding unit, causing its parameters to remain in the root group and defeating ZeRO-3 param sharding for Mamba and hybrid models.

Additionally, MambaMixer sets tensor_model_parallel and partition_dim directly on parameters (conv1d, A_log, dt_bias, D, norm.weight) rather than on the owning module. The TP annotation logic only checked module-level attributes, so these parameters were either unclassified or misclassified by the norm-name fallback (e.g. ExtendedRMSNorm treated as replicated when actually TP-sharded).

Changes:

  • Register MambaLayer in default fsdp_unit_modules (mcore_fsdp_adapter) and sub_modules_to_wrap (torch_fully_sharded_data_parallel)
  • Add param-level TP attribute fallback in _detect_parallelism_type, placed before the norm-name fallback so TP-sharded norm weights are correctly classified (see the sketch after this list)
  • Pass param through from _annotate_tensor_parallelism
  • Add tests for param-level TP detection, norm override, and a MambaMixer-like end-to-end annotation test
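
For context, here is a minimal sketch of the fallback ordering, assuming a simplified signature for _detect_parallelism_type; the ParallelismType enum, argument names, and return values below are illustrative stand-ins rather than the repository's actual types, and _annotate_tensor_parallelism is assumed to hand the parameter object through when it walks named parameters.

```python
# Minimal sketch of the param-level TP fallback described above (assumed,
# simplified signature; not the repository's actual implementation).
from enum import Enum, auto
from typing import Optional

import torch
import torch.nn as nn


class ParallelismType(Enum):
    TENSOR_PARALLEL = auto()
    REPLICATED = auto()


def _detect_parallelism_type(
    module: nn.Module, param_name: str, param: Optional[torch.Tensor] = None
) -> ParallelismType:
    # 1) Module-level TP attributes (existing behavior).
    if getattr(module, "tensor_model_parallel", False):
        return ParallelismType.TENSOR_PARALLEL

    # 2) Param-level fallback (this PR): MambaMixer stamps tensor_model_parallel
    #    and partition_dim on the parameters themselves (conv1d, A_log, dt_bias,
    #    D, norm.weight), so consult the parameter before any name heuristic.
    if param is not None and getattr(param, "tensor_model_parallel", False):
        return ParallelismType.TENSOR_PARALLEL

    # 3) Norm-name fallback (existing): on its own it would misclassify a
    #    TP-sharded ExtendedRMSNorm weight as replicated, which is why the
    #    param-level check runs first.
    if "norm" in param_name:
        return ParallelismType.REPLICATED

    return ParallelismType.REPLICATED
```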

Testing

Tested using the Qwen3-Next HybridModel implementation: wandb. Set the following flags (see the override sketch after this list):

  • cfg.ddp.use_megatron_fsdp = True
  • cfg.ddp.fsdp_double_buffer = True
  • cfg.ddp.nccl_ub = False
  • cfg.ddp.fsdp_db_use_persist_buf_on_alloc_fail = True
  • cfg.ddp.num_distributed_optimizer_instances = 1
  • cfg.checkpoint.ckpt_format = "fsdp_dtensor"
  • cfg.model.pipeline_model_parallel_size = 1
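
For reference, a sketch of the same flags as a config override, assuming an OmegaConf-style recipe config; the base Qwen3-Next hybrid recipe and how the override is merged into it are not shown here.

```python
# Illustrative only: the testing flags above expressed as an OmegaConf override.
from omegaconf import OmegaConf

fsdp_overrides = OmegaConf.create(
    {
        "ddp": {
            "use_megatron_fsdp": True,
            "fsdp_double_buffer": True,
            "nccl_ub": False,
            "fsdp_db_use_persist_buf_on_alloc_fail": True,
            "num_distributed_optimizer_instances": 1,
        },
        "checkpoint": {"ckpt_format": "fsdp_dtensor"},
        "model": {"pipeline_model_parallel_size": 1},
    }
)
# cfg = OmegaConf.merge(base_cfg, fsdp_overrides)  # base_cfg: the recipe config (assumed)
```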

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
@copy-pr-bot

copy-pr-bot Bot commented Apr 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Phlip79
Member Author

Phlip79 commented Apr 15, 2026

/ok to test 0ef2848

@Phlip79 Phlip79 marked this pull request as ready for review April 22, 2026 01:32
@Phlip79 Phlip79 requested review from a team as code owners April 22, 2026 01:32
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team April 22, 2026 01:32
Contributor

@shjwudp shjwudp left a comment


LGTM, thanks @Phlip79 !

  1. This PR sets MambaLayer as the default sharding layer.
  2. It adds MFSDP-TP compatibility for MambaLayer: the layer sets partition_dim on parameters directly rather than on modules.

It would be great to make Mamba + MoE the default model for E2E tests if you're up for it — this would help us catch and track Mamba-related regressions more reliably. Reference: tests/unit_tests/distributed/megatron_fsdp/utils.py#L55

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Final Review PR is in the "final review" stage label Apr 22, 2026
@Phlip79 Phlip79 marked this pull request as draft April 23, 2026 04:59
@Phlip79
Copy link
Copy Markdown
Member Author

Phlip79 commented Apr 27, 2026

Closing this PR in favor of #4467, which solves the same issue.

@Phlip79 Phlip79 closed this Apr 27, 2026
