Add preliminary Muon+M-FSDP support #4486
@@ -64,7 +64,10 @@

```python
from .emerging_optimizers import (
    _EMERGING_OPTIMIZERS,
    HAVE_EMERGING_OPTIMIZERS,
    FSDPMuonChainedOptimizer,
    FSDPZeROTensorParallelMuon,
    _create_emerging_optimizer,
    _get_mfsdp_models,
)
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .layer_wise_optimizer import LayerWiseDistributedOptimizer
```

@@ -721,6 +724,193 @@ def check_config_overrides_consistency(

```python
    return True


def _build_megatron_fsdp_emerging_optimizer(
    config: OptimizerConfig,
    model_chunks: List[MegatronModule],
    config_overrides: Dict[ParamKey, ParamGroupOverride],
    pg_collection: ProcessGroupCollection,
    eopt_name: str,
    use_layer_wise: bool,
) -> "FSDPMuonChainedOptimizer":
    """Build an emerging optimizer (currently only Muon) for Megatron-FSDP.

    The standard emerging-optimizer flow assumes `main_grad` buffers populated
    by DDP. With Megatron-FSDP, gradients arrive via `finish_grad_sync()` on
    DTensor parameters, which makes `Float16OptimizerWithFloat16Params`
    incompatible. This helper:

    1. Builds the linear-only Muon optimizer (plain `TensorParallelMuon` for
       `no_shard`; `FSDPZeROTensorParallelMuon` for sharded strategies, which
       gathers only split rank-boundary DTensor parameters before NS) over a
       frozen-grad view of the model so `_get_param_groups` only emits
       Muon-managed params.
    2. Falls back to `get_megatron_optimizer` for the non-linear params,
       which routes through the standard Megatron-FSDP path for Adam.
    3. Wraps the chained result with `FSDPMuonChainedOptimizer`, which drives
       `finish_grad_sync` and `install_optimized_model_weights`.
    """
    # Lazy import to avoid pulling Muon-specific symbols when emerging_optimizers
    # is unavailable; entry already validated upstream.
    entry = _EMERGING_OPTIMIZERS[eopt_name]
    init_state_fn = entry.init_state_fn

    # Choose Muon variant based on the inner-DP sharding strategy.
    # - "no_shard" (ZeRO-0): params/grads are full Replicate DTensors, so the
    #   plain TensorParallelMuon works without any DP communication.
    # - Sharded strategies (ZeRO-1/2/3): FSDPZeROTensorParallelMuon gathers
    #   only split boundary parameters across the DP group before NS and
    #   re-shards the result.
    fsdp_strategy = getattr(
        model_chunks[0].ddp_config, "data_parallel_sharding_strategy", "no_shard"
    )
    use_fsdp_zero_muon = fsdp_strategy != "no_shard"
    optimizer_cls = FSDPZeROTensorParallelMuon if use_fsdp_zero_muon else entry.optimizer_cls

    log_single_rank(
        logger,
        logging.INFO,
        f"Setting up Megatron-FSDP emerging optimizer ({eopt_name}) with config {config}",
    )

    # Sort parameters into linear (Muon-managed) vs nonlinear (Adam-managed).
    linear_params: List[torch.nn.Parameter] = []
    nonlinear_params: List[torch.nn.Parameter] = []
    for model_chunk in model_chunks:
        for _, param in model_chunk.named_parameters():
            if not param.requires_grad:
                continue
            if (
                not getattr(param, "is_embedding_or_output_parameter", False)
                and len(param.shape) == 2
            ):
                linear_params.append(param)
            else:
                nonlinear_params.append(param)

    # Build the Muon optimizer for linear params.
    for p in nonlinear_params:
        p.requires_grad = False
    linear_param_groups = _get_param_groups(model_chunks, config, config_overrides)

    expert_param_groups: List[Dict[str, Any]] = []
    if not use_layer_wise:
        # Build a fresh list rather than calling `linear_param_groups.remove(group)`:
        # `list.remove` falls back to `dict.__eq__` when the elements are dicts,
        # which compares values and trips on tensor params with
        # `RuntimeError: Boolean value of Tensor with more than one value is ambiguous`.
        non_expert_groups: List[Dict[str, Any]] = []
        for group in linear_param_groups:
            if group["is_expert_parallel"]:
                expert_param_groups.append(group)
            else:
                non_expert_groups.append(group)
        linear_param_groups = non_expert_groups

    eopt_kwargs: Dict[str, Any] = {}
    if entry.config_to_kwargs is not None:
        eopt_kwargs = entry.config_to_kwargs(config, model_chunks, pg_collection)

    optimizers: List[Any] = []
    init_fns: List[Callable] = []
    if linear_param_groups:
        # ZeRO-1/2/3: split dense boundary params are sharded over `dp_cp`, so
        # the FSDPZero variant must gather over the same group.
        dense_kwargs = dict(eopt_kwargs)
        if use_fsdp_zero_muon:
            dense_kwargs["dp_group"] = pg_collection.dp_cp
        muon_base = optimizer_cls(linear_param_groups, **dense_kwargs)
        muon_opt = FP32Optimizer(muon_base, config, init_state_fn)
        setattr(muon_opt, "grad_stats_parallel_group", pg_collection.mp)
        setattr(muon_opt, "tp_group", pg_collection.tp)
        optimizers.append(muon_opt)
        init_fns.append(init_state_fn)

    if expert_param_groups:
        # Expert params reduce-scatter over `expt_dp` (the expert
        # data-parallel group), not `dp_cp`.
        expert_kwargs = dict(eopt_kwargs)
        if use_fsdp_zero_muon:
            expert_kwargs["dp_group"] = pg_collection.expt_dp
        expert_muon_base = optimizer_cls(expert_param_groups, **expert_kwargs)
        expert_muon_opt = FP32Optimizer(expert_muon_base, config, init_state_fn)
        setattr(expert_muon_opt, "grad_stats_parallel_group", pg_collection.tp_ep_pp)
        setattr(expert_muon_opt, "tp_group", pg_collection.tp)
        optimizers.append(expert_muon_opt)
        init_fns.append(init_state_fn)

    # Build Adam for non-linear params via the standard FSDP path.
    for p in nonlinear_params:
        p.requires_grad = True
    for p in linear_params:
        p.requires_grad = False

    saved_optimizer = config.optimizer
    saved_use_lwd = config.use_layer_wise_distributed_optimizer
    config.optimizer = "adam"
    config.use_layer_wise_distributed_optimizer = False
    try:
        # `pg_collection` is always provided here, so Gloo process groups
        # must be disabled on the recursive call (setup_process_groups_for_optimizer
        # rejects the combination).
        chained_adam = get_megatron_optimizer(
            config,
            model_chunks,
            config_overrides=config_overrides,
            pg_collection=pg_collection,
            use_gloo_process_groups=False,
        )
    finally:
        config.optimizer = saved_optimizer
        config.use_layer_wise_distributed_optimizer = saved_use_lwd

    for p in linear_params:
        p.requires_grad = True

    adam_optimizers = list(getattr(chained_adam, "chained_optimizers", [chained_adam]))
    optimizers += adam_optimizers

    def _adam_init_state_fn(opt: Any, config: Any = None) -> None:
        for group in opt.param_groups:
            for p in group["params"]:
                if len(opt.state[p]) == 0:
                    if config is None or not config.use_precision_aware_optimizer:
                        opt.state[p]["exp_avg"] = torch.zeros_like(p.data)
                        opt.state[p]["exp_avg_sq"] = torch.zeros_like(p.data)
```
Comment on lines +878 to +880 (Member): Also, what does this case capture? With MFSDP, I think we need DTensors, so we usually do a "dummy step" to get DTensor state based on the DTensor params. Two examples:
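As a rough illustration of the "dummy step" idea mentioned above (a sketch only: the helper name and the zero-gradient trick below are assumptions, not code from this PR or from the reviewer's examples, and it presumes `weight_decay == 0` so the step leaves parameters unchanged):

```python
import torch


def materialize_adam_state_with_dummy_step(opt: torch.optim.Optimizer) -> None:
    """Illustrative sketch (not PR code): force Adam to allocate its state.

    Adam creates exp_avg / exp_avg_sq lazily on the first step via
    torch.zeros_like(param), so when the params are DTensors the state comes
    out as DTensors with matching placements. Assumes weight_decay == 0 so a
    zero-gradient step does not move the parameters.
    """
    for group in opt.param_groups:
        for p in group["params"]:
            if p.grad is None:
                p.grad = torch.zeros_like(p)  # zero grad -> zero Adam update
    opt.step()  # state tensors are allocated here
    opt.zero_grad(set_to_none=True)
```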
```python
                    else:
                        opt.initialize_state(p)

    init_fns += [_adam_init_state_fn] * len(adam_optimizers)

    # Combine and wrap with the FSDP-protocol adapter.
    if use_layer_wise:
        # Float16OptimizerWithFloat16Params is incompatible with FSDP DTensor
        # params (no .main_grad), so temporarily clear bf16 to prevent the
        # LayerWiseDistributedOptimizer from re-wrapping each sub-optimizer.
```
Comment on lines +889 to +891 (Member): Either we should extend the logic in

Note: Layer-wise appears to be a permutation of optimizer params into different layer-specific groups distributed across DP-CP (or EDP). I am not sure if this is compatible with Megatron-FSDP, since we have uneven shards. Is it possible that, say, non-empty params are all allocated to a different DP rank, but that rank's equivalent params are empty, so no optimizer update?
```python
        log_single_rank(
            logger,
            logging.INFO,
            f"Using LayerWiseDistributedOptimizer for {eopt_name} + Megatron-FSDP",
        )
        saved_bf16 = config.bf16
        config.bf16 = False
        try:
            inner: Any = LayerWiseDistributedOptimizer(
                optimizers,
                config,
                pg_collection,
                init_state_fn_list=init_fns,
            )
        finally:
            config.bf16 = saved_bf16
    else:
        inner = ChainedOptimizer(optimizers)

    return FSDPMuonChainedOptimizer(inner, _get_mfsdp_models(model_chunks))
```
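A quick aside on the `list.remove` pitfall called out in the comment inside this function: it reproduces with plain tensors, independent of this PR (illustrative snippet, not PR code):

```python
import torch

groups = [
    {"params": [torch.zeros(4, 4)], "is_expert_parallel": False},
    {"params": [torch.ones(4, 4)], "is_expert_parallel": True},
]

# list.remove compares elements with ==; for dicts that means comparing
# values, and comparing two tensors yields an element-wise bool tensor whose
# truth value is ambiguous.
try:
    groups.remove(groups[1])
except RuntimeError as err:
    print(err)  # Boolean value of Tensor with more than one value is ambiguous

# Rebuilding the list, as the PR does, avoids the comparison entirely.
kept = [g for g in groups if not g["is_expert_parallel"]]
```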
```python
def _get_megatron_emerging_optimizer(
    config: OptimizerConfig,
    model_chunks: List[MegatronModule],
```

@@ -778,7 +968,23 @@ def _get_megatron_emerging_optimizer(

```python
            if 'linear_qkv.weight' in name and len(param.shape) == 2:
                param.is_qkv = True

    # Apply optimizer-specific default param overrides (e.g. muon: non-linear -> adam).
    # Megatron-FSDP needs a separate code path: gradients are attached via
    # `finish_grad_sync()` rather than via `main_grad` buffers, so the
    # standard Float16OptimizerWithFloat16Params wrapper is incompatible.
    # In addition, the resulting optimizer must be wrapped to invoke FSDP's
    # `finish_grad_sync()` / `install_optimized_model_weights()` hooks
    # around each step.
    if getattr(model_chunks[0].ddp_config, 'use_megatron_fsdp', False):
        return _build_megatron_fsdp_emerging_optimizer(
            config=config,
            model_chunks=model_chunks,
            config_overrides=config_overrides,
            pg_collection=pg_collection,
            eopt_name=eopt_name,
            use_layer_wise=use_layer_wise,
        )

    # Apply optimizer-specific default param overrides (e.g., muon: non-linear -> adam).
    config_overrides.update(_EMERGING_OPTIMIZERS[eopt_name].default_param_overrides)

    # Build param groups and bucket by (optimizer_name, is_expert_parallel).
```
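The wrapping contract described in the comment above (call `finish_grad_sync()` before the inner step and `install_optimized_model_weights()` after it) can be pictured roughly as below. This is a hedged sketch, not the actual `FSDPMuonChainedOptimizer` from this PR; the two hook names come from the PR's comments, while the class name, attributes, and step signature here are assumptions.

```python
from typing import Any, List


class FSDPStepProtocolWrapperSketch:
    """Sketch of wrapping a chained optimizer with Megatron-FSDP's step hooks.

    Assumption: each item in `mfsdp_models` exposes finish_grad_sync() and
    install_optimized_model_weights(); the real FSDPMuonChainedOptimizer may
    differ in details.
    """

    def __init__(self, inner_optimizer: Any, mfsdp_models: List[Any]) -> None:
        self.inner = inner_optimizer
        self.mfsdp_models = mfsdp_models

    def step(self) -> Any:
        # 1. Make the reduced gradients visible on the (DTensor) parameters.
        for model in self.mfsdp_models:
            model.finish_grad_sync()
        # 2. Run the chained update: Muon for linear params, Adam for the rest.
        result = self.inner.step()
        # 3. Push the updated weights back into the sharded model buffers.
        for model in self.mfsdp_models:
            model.install_optimized_model_weights()
        return result
```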
Comment (Member): Is this because layer-wise <> EP isn't compatible, so we just use Adam? (Gathering expert params is too heavyweight? Also, another complexity is EP-DP sharding, so we need to gather both.)