Add preliminary Muon+M-FSDP support#4486
Draft
janEbert wants to merge 5 commits intoNVIDIA:mainfrom
Draft
Conversation
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
Route the emerging-optimizer factory through a Megatron-FSDP-specific
path when `ddp_config.use_megatron_fsdp` is set. Megatron-FSDP attaches
grads via `finish_grad_sync()` on DTensor params instead of via DDP's
main_grad buffers, so the standard `Float16OptimizerWithFloat16Params`
wrapper does not apply; we always wrap with `FP32Optimizer` instead and
drive the FSDP step contract from a thin `FSDPMuonChainedOptimizer`
adapter that calls `finish_grad_sync()` and
`install_optimized_model_weights()` around the inner step.
For now this supports ZeRO-0 ("no_shard") only; ZeRO-1/2/3 will work
without errors on the wiring but require a sharding-aware Muon variant
for numerical correctness, added in a follow-up.
Also patch `LayerWiseDistributedOptimizer._allgather_helper` to read
DTensor-backed params via `_local_tensor`, so the layer-wise + FSDP
combination can flatten the local shard rather than the global DTensor.
Add `FSDPZeROTensorParallelMuon`, a TensorParallelMuon subclass that: 1. Extracts the `Shard(0)` local tensor from each gradient DTensor: (`finish_grad_sync` produces a row-shard per DP rank for `optim`, `optim_grads` and `optim_grads_params`). 2. Allgathers the shards across the DP group to reconstruct the TP-local, DP-full gradient matrix. 3. Trims FSDP bucket-padding rows using the DTensor's declared global shape. 4. Delegates Newton-Schulz to the parent class (which handles the TP dimension via `newton_schulz_tp`). 5. Re-shards the orthogonalized result back to a `Shard(0)` DTensor with matching placements so the in-place update in `OrthogonalizedOptimizer.step` does not promote to `Replicate` and trip the global-shape check. The FSDP factory in `_build_megatron_fsdp_emerging_optimizer` now picks `FSDPZeROTensorParallelMuon` for any sharded inner-DP strategy and passes `pg_collection.dp_cp` for dense params and `pg_collection.expt_dp` for expert params (since expert grads reduce-scatter over a different group). "no_shard" continues to use plain `TensorParallelMuon`. DTensor is imported at module scope with a `_HAVE_DTENSOR` guard so the isinstance checks stay cheap and the module still imports on stacks without `torch.distributed.tensor`.
Three phases of tests for the Muon + Megatron-FSDP integration: - Phase 1: `FSDPMuonChainedOptimizer` adapter (single-rank, mock-based). Verifies the step contract – finish_grad_sync -> inner step -> install_optimized_model_weights – and attribute delegation. - Phase 2: `FSDPZeROTensorParallelMuon.orthogonalize` (multi-rank). Asserts the allgather -> Newton-Schulz -> reshard cycle is numerically equivalent to running NS on the full gradient and extracting the local row-shard, including FSDP padding edge cases. Includes a DTensor round-trip test that catches the `p.add_(orthogonalized_dtensor)` placement-promotion bug. - Phase 3: `_build_megatron_fsdp_emerging_optimizer` factory. Confirms the factory dispatches plain `TensorParallelMuon` for `no_shard` and `FSDPZeROTensorParallelMuon` for sharded strategies, and that expert vs. non-expert Muon instances receive `expt_dp` vs. `dp_cp` as their allgather group.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Introduce Muon support to M-FSDP. Currently 1.5×–2.7× as slow compared to an Adam baseline with a 1B–8B DeepSeek-V3 proxy model. Peak memory slightly lower than with Adam (4–7 % less).