diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 4f69a9efd55..8a22803c33f 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -64,7 +64,10 @@ from .emerging_optimizers import ( _EMERGING_OPTIMIZERS, HAVE_EMERGING_OPTIMIZERS, + FSDPMuonChainedOptimizer, + FSDPZeROTensorParallelMuon, _create_emerging_optimizer, + _get_mfsdp_models, ) from .grad_scaler import ConstantGradScaler, DynamicGradScaler from .layer_wise_optimizer import LayerWiseDistributedOptimizer @@ -721,6 +724,193 @@ def check_config_overrides_consistency( return True +def _build_megatron_fsdp_emerging_optimizer( + config: OptimizerConfig, + model_chunks: List[MegatronModule], + config_overrides: Dict[ParamKey, ParamGroupOverride], + pg_collection: ProcessGroupCollection, + eopt_name: str, + use_layer_wise: bool, +) -> "FSDPMuonChainedOptimizer": + """Build an emerging optimizer (currently only Muon) for Megatron-FSDP. + + The standard emerging-optimizer flow assumes `main_grad` buffers populated + by DDP. With Megatron-FSDP, gradients arrive via `finish_grad_sync()` on + DTensor parameters, which makes `Float16OptimizerWithFloat16Params` + incompatible. This helper: + + 1. Builds the linear-only Muon optimizer (plain `TensorParallelMuon` for + `no_shard`; `FSDPZeROTensorParallelMuon` for sharded strategies, which + gathers only split rank-boundary DTensor parameters before NS) over a + frozen-grad view of the model so `_get_param_groups` only emits + Muon-managed params. + 2. Falls back to `get_megatron_optimizer` for the non-linear params, + which routes through the standard Megatron-FSDP path for Adam. + 3. Wraps the chained result with `FSDPMuonChainedOptimizer`, which drives + `finish_grad_sync` and `install_optimized_model_weights`. + """ + # Lazy import to avoid pulling Muon-specific symbols when emerging_optimizers + # is unavailable; entry already validated upstream. + entry = _EMERGING_OPTIMIZERS[eopt_name] + init_state_fn = entry.init_state_fn + + # Choose Muon variant based on the inner-DP sharding strategy. + # - "no_shard" (ZeRO-0): params/grads are full Replicate DTensors, so the + # plain TensorParallelMuon works without any DP communication. + # - Sharded strategies (ZeRO-1/2/3): FSDPZeROTensorParallelMuon gathers + # only split boundary parameters across the DP group before NS and + # re-shards the result. + fsdp_strategy = getattr( + model_chunks[0].ddp_config, "data_parallel_sharding_strategy", "no_shard" + ) + use_fsdp_zero_muon = fsdp_strategy != "no_shard" + optimizer_cls = FSDPZeROTensorParallelMuon if use_fsdp_zero_muon else entry.optimizer_cls + + log_single_rank( + logger, + logging.INFO, + f"Setting up Megatron-FSDP emerging optimizer ({eopt_name}) with config {config}", + ) + + # Sort parameters into linear (Muon-managed) vs nonlinear (Adam-managed). + linear_params: List[torch.nn.Parameter] = [] + nonlinear_params: List[torch.nn.Parameter] = [] + for model_chunk in model_chunks: + for _, param in model_chunk.named_parameters(): + if not param.requires_grad: + continue + if ( + not getattr(param, "is_embedding_or_output_parameter", False) + and len(param.shape) == 2 + ): + linear_params.append(param) + else: + nonlinear_params.append(param) + + # Build the Muon optimizer for linear params. 
+ + for p in nonlinear_params: + p.requires_grad = False + linear_param_groups = _get_param_groups(model_chunks, config, config_overrides) + + expert_param_groups: List[Dict[str, Any]] = [] + if not use_layer_wise: + # Build a fresh list rather than calling `linear_param_groups.remove(group)`: + # `list.remove` falls back to `dict.__eq__` when the elements are dicts, + # which compares values and trips on tensor params with + # `RuntimeError: Boolean value of Tensor with more than one value is ambiguous`. + non_expert_groups: List[Dict[str, Any]] = [] + for group in linear_param_groups: + if group["is_expert_parallel"]: + expert_param_groups.append(group) + else: + non_expert_groups.append(group) + linear_param_groups = non_expert_groups + + eopt_kwargs: Dict[str, Any] = {} + if entry.config_to_kwargs is not None: + eopt_kwargs = entry.config_to_kwargs(config, model_chunks, pg_collection) + + optimizers: List[Any] = [] + init_fns: List[Callable] = [] + if linear_param_groups: + # ZeRO-1/2/3: split dense boundary params are sharded over `dp_cp`, so + # the FSDPZero variant must gather over the same group. + dense_kwargs = dict(eopt_kwargs) + if use_fsdp_zero_muon: + dense_kwargs["dp_group"] = pg_collection.dp_cp + muon_base = optimizer_cls(linear_param_groups, **dense_kwargs) + muon_opt = FP32Optimizer(muon_base, config, init_state_fn) + setattr(muon_opt, "grad_stats_parallel_group", pg_collection.mp) + setattr(muon_opt, "tp_group", pg_collection.tp) + optimizers.append(muon_opt) + init_fns.append(init_state_fn) + + if expert_param_groups: + # Expert params reduce-scatter over `expt_dp` (the expert + # data-parallel group), not `dp_cp`. + expert_kwargs = dict(eopt_kwargs) + if use_fsdp_zero_muon: + expert_kwargs["dp_group"] = pg_collection.expt_dp + expert_muon_base = optimizer_cls(expert_param_groups, **expert_kwargs) + expert_muon_opt = FP32Optimizer(expert_muon_base, config, init_state_fn) + setattr(expert_muon_opt, "grad_stats_parallel_group", pg_collection.tp_ep_pp) + setattr(expert_muon_opt, "tp_group", pg_collection.tp) + optimizers.append(expert_muon_opt) + init_fns.append(init_state_fn) + + # Build Adam for non-linear params via the standard FSDP path. + + for p in nonlinear_params: + p.requires_grad = True + for p in linear_params: + p.requires_grad = False + + saved_optimizer = config.optimizer + saved_use_lwd = config.use_layer_wise_distributed_optimizer + config.optimizer = "adam" + config.use_layer_wise_distributed_optimizer = False + try: + # `pg_collection` is always provided here, so Gloo process groups + # must be disabled on the recursive call (setup_process_groups_for_optimizer + # rejects the combination). 
+ chained_adam = get_megatron_optimizer( + config, + model_chunks, + config_overrides=config_overrides, + pg_collection=pg_collection, + use_gloo_process_groups=False, + ) + finally: + config.optimizer = saved_optimizer + config.use_layer_wise_distributed_optimizer = saved_use_lwd + + for p in linear_params: + p.requires_grad = True + + adam_optimizers = list(getattr(chained_adam, "chained_optimizers", [chained_adam])) + optimizers += adam_optimizers + + def _adam_init_state_fn(opt: Any, config: Any = None) -> None: + for group in opt.param_groups: + for p in group["params"]: + if len(opt.state[p]) == 0: + if config is None or not config.use_precision_aware_optimizer: + opt.state[p]["exp_avg"] = torch.zeros_like(p.data) + opt.state[p]["exp_avg_sq"] = torch.zeros_like(p.data) + else: + opt.initialize_state(p) + + init_fns += [_adam_init_state_fn] * len(adam_optimizers) + + # Combine and wrap with the FSDP-protocol adapter. + + if use_layer_wise: + # Float16OptimizerWithFloat16Params is incompatible with FSDP DTensor + # params (no .main_grad), so temporarily clear bf16 to prevent the + # LayerWiseDistributedOptimizer from re-wrapping each sub-optimizer. + log_single_rank( + logger, + logging.INFO, + f"Using LayerWiseDistributedOptimizer for {eopt_name} + Megatron-FSDP", + ) + saved_bf16 = config.bf16 + config.bf16 = False + try: + inner: Any = LayerWiseDistributedOptimizer( + optimizers, + config, + pg_collection, + init_state_fn_list=init_fns, + ) + finally: + config.bf16 = saved_bf16 + else: + inner = ChainedOptimizer(optimizers) + + return FSDPMuonChainedOptimizer(inner, _get_mfsdp_models(model_chunks)) + + def _get_megatron_emerging_optimizer( config: OptimizerConfig, model_chunks: List[MegatronModule], @@ -778,7 +968,23 @@ def _get_megatron_emerging_optimizer( if 'linear_qkv.weight' in name and len(param.shape) == 2: param.is_qkv = True - # Apply optimizer-specific default param overrides (e.g. muon: non-linear -> adam). + # Megatron-FSDP needs a separate code path: gradients are attached via + # `finish_grad_sync()` rather than via `main_grad` buffers, so the + # standard Float16OptimizerWithFloat16Params wrapper is incompatible. + # In addition, the resulting optimizer must be wrapped to invoke FSDP's + # `finish_grad_sync()` / `install_optimized_model_weights()` hooks + # around each step. + if getattr(model_chunks[0].ddp_config, 'use_megatron_fsdp', False): + return _build_megatron_fsdp_emerging_optimizer( + config=config, + model_chunks=model_chunks, + config_overrides=config_overrides, + pg_collection=pg_collection, + eopt_name=eopt_name, + use_layer_wise=use_layer_wise, + ) + + # Apply optimizer-specific default param overrides (e.g., muon: non-linear -> adam). config_overrides.update(_EMERGING_OPTIMIZERS[eopt_name].default_param_overrides) # Build param groups and bucket by (optimizer_name, is_expert_parallel). 
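As a reading aid for the factory above, here is a minimal standalone sketch of the two rules it encodes: the linear vs. non-linear parameter split and the sync -> step -> install contract that `FSDPMuonChainedOptimizer` drives. `ToyFSDPModule`, `toy_partition`, and `ToyChainedStep` are hypothetical names invented for this example; it depends only on `torch`, not on Megatron or emerging_optimizers, and plain Adam stands in for the real Muon/Adam pair.

# Standalone sketch (not the Megatron implementation).
import torch
import torch.nn as nn


def toy_partition(module: nn.Module):
    """Split trainable params the way the factory does: 2D, non-embedding/output
    weights go to Muon; everything else (biases, norms, embeddings) goes to Adam."""
    muon_params, adam_params = [], []
    for _, p in module.named_parameters():
        if not p.requires_grad:
            continue
        if p.ndim == 2 and not getattr(p, "is_embedding_or_output_parameter", False):
            muon_params.append(p)
        else:
            adam_params.append(p)
    return muon_params, adam_params


class ToyFSDPModule(nn.Module):
    """Stand-in for MegatronFSDP: exposes the two hooks the wrapper invokes."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.model_auto_sync = False

    def finish_grad_sync(self):
        # Real code waits for the async all-reduce / reduce-scatter and attaches grads.
        pass

    def install_optimized_model_weights(self):
        # Real code copies fp32 main weights back into the (sharded) bf16 buffer.
        pass


class ToyChainedStep:
    """Mirrors FSDPMuonChainedOptimizer.step(): sync grads -> inner step -> install weights."""

    def __init__(self, inner, fsdp_modules):
        self.inner, self.fsdp_modules = inner, fsdp_modules

    @torch.no_grad()
    def step(self):
        for m in self.fsdp_modules:
            if not m.model_auto_sync:
                m.finish_grad_sync()
        result = self.inner.step()
        for m in self.fsdp_modules:
            m.install_optimized_model_weights()
        return result


if __name__ == "__main__":
    model = ToyFSDPModule()
    muon_params, adam_params = toy_partition(model)  # weight -> muon, bias -> adam
    inner = torch.optim.Adam(muon_params + adam_params, lr=1e-3)
    opt = ToyChainedStep(inner, [model])
    model.linear(torch.randn(2, 8)).sum().backward()
    opt.step()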
diff --git a/megatron/core/optimizer/emerging_optimizers.py b/megatron/core/optimizer/emerging_optimizers.py index cc218d6ba40..93802e108dd 100644 --- a/megatron/core/optimizer/emerging_optimizers.py +++ b/megatron/core/optimizer/emerging_optimizers.py @@ -21,6 +21,21 @@ from .optimizer_config import ParamKey, ParamPredicate +try: + from torch.distributed.tensor import DTensor as _DTensor + + from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import ( + gather_uneven_dtensor_to_full_tensor, + update_uneven_dtensor_chunk_metadata, + ) + + _HAVE_DTENSOR = True +except ImportError: + _DTensor = None # type: ignore[assignment,misc] + gather_uneven_dtensor_to_full_tensor = None # type: ignore[assignment] + update_uneven_dtensor_chunk_metadata = None # type: ignore[assignment] + _HAVE_DTENSOR = False + try: from emerging_optimizers import registry from emerging_optimizers.orthogonalized_optimizers import ( @@ -146,6 +161,67 @@ def _get_qkv_split_shapes(model_cfg) -> List[int]: _EMERGING_OPTIMIZERS: Dict[str, EmergingOptimizerEntry] = {} +# =========================================================================== +# Megatron-FSDP integration +# =========================================================================== + + +def _get_mfsdp_models(model_chunks): + """Extract list of MegatronFSDP instances from FSDP-wrapped model chunks.""" + mfsdp_models = [] + for chunk in model_chunks: + # FullyShardedDataParallel delegates finish_grad_sync / start_param_sync + # from its .module (MegatronFSDP). install_optimized_model_weights lives + # directly on MegatronFSDP, so we need the inner module reference. + if hasattr(chunk, "finish_grad_sync") and hasattr(chunk, "module"): + mfsdp_models.append(chunk.module) + if not mfsdp_models: + raise RuntimeError( + "Could not find any MegatronFSDP instances in model_chunks. " + "Ensure the model is wrapped with FullyShardedDataParallel." + ) + return mfsdp_models + + +class FSDPMuonChainedOptimizer: + """Thin FSDP-protocol adapter wrapping a Muon-based MegatronOptimizer. + + Injects the MegatronFSDP step contract around the inner optimizer: + 1. `finish_grad_sync()` – waits for async grad sync, attaches grads + (allreduces for `no_shard`; reduce-scatters for ZeRO-1/2/3). + 2. `inner_optimizer.step()` – Muon NS + weight update + Adam. + 3. `install_optimized_model_weights()` – copies fp32 main weights + back into the model's bf16 (sharded for ZeRO-3) buffer. + + All other attribute accesses are delegated to the inner optimizer via + `__getattr__`, making this class transparent to the training loop. + """ + + def __init__(self, inner: Any, mfsdp_models: list) -> None: + # Use object.__setattr__ to avoid triggering our own __getattr__ during init. + object.__setattr__(self, "inner", inner) + object.__setattr__(self, "_mfsdp_models", mfsdp_models) + + @torch.no_grad() # type: ignore[misc] + def step(self) -> Any: + """FSDP-aware optimizer step: sync grads -> inner step -> install weights.""" + for mfsdp in self._mfsdp_models: + if not mfsdp.model_auto_sync: + mfsdp.finish_grad_sync() + result = self.inner.step() + for mfsdp in self._mfsdp_models: + mfsdp.install_optimized_model_weights() + return result + + def zero_grad(self, set_to_none: bool = True) -> None: + """Zero optimizer gradients. 
FSDP grad buffer is zeroed by the training loop.""" + self.inner.zero_grad(set_to_none) + + def __getattr__(self, name: str) -> Any: + """Delegate all other attribute accesses to the inner optimizer.""" + return getattr(object.__getattribute__(self, "inner"), name) + + # =========================================================================== # Muon # =========================================================================== @@ -276,6 +352,208 @@ def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> t return grad +class FSDPZeROTensorParallelMuon(TensorParallelMuon): + """`TensorParallelMuon` extended for Megatron-FSDP ZeRO-1/2/3. + + Megatron-FSDP shards a flat parameter buffer across DP ranks. Most 2D + parameters are either fully local on one rank or empty on the others; only + rank-boundary parameters are split across two ranks. Muon needs the full + matrix for Newton-Schulz, so this class gathers only those boundary + parameters and runs local NS for fully-local parameters. + """ + + def __init__( + self, + params: ParamsT, + dp_group: Optional[torch.distributed.ProcessGroup] = None, + **kwargs: Any, + ) -> None: + self.dp_group = dp_group + super().__init__(params, **kwargs) + + @torch.no_grad() # type: ignore[misc] + def step(self, closure: Optional[Callable] = None) -> Optional[float]: + """Muon step for Megatron-FSDP ZeRO-1/2/3. + + Ranks must enter collectives in the same order, so local boundary + decisions are unioned across the DP group before the parameter loop. + """ + if closure is None: + loss = None + else: + loss = closure() + + for group in self.param_groups: + self._init_group(group, skip_non_grad_params=False) + gather_param_indices = self._get_boundary_gather_param_indices(group) + + for param_idx, p in enumerate(group["params"]): + self._muon_step_one(p, group, gather_full=param_idx in gather_param_indices) + + return loss + + def _as_dtensor(self, value: Any): + """Return `value` or `value.data` when either is a DTensor.""" + if not _HAVE_DTENSOR: + return None + if isinstance(value, _DTensor): + return value + data = getattr(value, "data", None) + if isinstance(data, _DTensor): + return data + return None + + def _is_nonempty_dtensor_param(self, param: torch.Tensor) -> bool: + dtensor = self._as_dtensor(param) + return dtensor is not None and dtensor.to_local().numel() > 0 + + def _needs_boundary_gather(self, param: torch.Tensor) -> bool: + dtensor = self._as_dtensor(param) + if dtensor is None: + return False + local_tensor = dtensor.to_local() + return local_tensor.numel() > 0 and tuple(dtensor.shape) != tuple(local_tensor.shape) + + def _get_boundary_gather_param_indices(self, group: Dict[str, Any]) -> set[int]: + """Return globally-agreed parameter indices that need a boundary all-gather.""" + params = group["params"] + local_boundary_indices = [ + idx for idx, param in enumerate(params) if self._needs_boundary_gather(param) + ] + + if self.dp_group is None or get_pg_size(self.dp_group) == 1: + return set(local_boundary_indices) + + gathered_indices: list[list[int] | None] = [None] * get_pg_size(self.dp_group) + torch.distributed.all_gather_object( + gathered_indices, local_boundary_indices, group=self.dp_group + ) + return { + idx + for rank_indices in gathered_indices + if rank_indices is not None + for idx in rank_indices + } + + def _copy_dtensor_chunk_metadata(self, dst, src) -> None: + if hasattr(src._local_tensor, "__create_chunk_list__"): + dst._local_tensor.__create_chunk_list__ = src._local_tensor.__create_chunk_list__ + 
if hasattr(src._local_tensor, "__create_write_items__"): + dst._local_tensor.__create_write_items__ = src._local_tensor.__create_write_items__ + + def _dtensor_from_local_like(self, value, local_tensor: torch.Tensor): + dtensor = _DTensor.from_local( + local_tensor=local_tensor, + device_mesh=value.device_mesh, + placements=value.placements, + shape=value.shape, + stride=value.stride(), + ) + self._copy_dtensor_chunk_metadata(dtensor, value) + return dtensor + + def _reshard_full_update_like(self, value, full_update: torch.Tensor): + if not hasattr(value._local_tensor, "__create_chunk_list__"): + if update_uneven_dtensor_chunk_metadata is None: + raise RuntimeError("DTensor support is required for Megatron-FSDP Muon.") + update_uneven_dtensor_chunk_metadata(value) + value_metadata = value._local_tensor.__create_chunk_list__()[0] + slices = tuple( + slice(offset, offset + size) + for offset, size in zip(value_metadata.offsets, value_metadata.sizes) + ) + local_update = full_update[slices].contiguous().to(dtype=value.to_local().dtype) + return self._dtensor_from_local_like(value, local_update) + + @torch.no_grad() # type: ignore[misc] + def _muon_step_one( + self, p: torch.Tensor, group: Dict[str, Any], gather_full: bool = False + ) -> None: + """Per-param Muon update that participates in DP collectives on every rank.""" + # Fall back to the plain parent step for the single-rank / unsharded case. + if self.dp_group is None or get_pg_size(self.dp_group) == 1: + if p.grad is None: + return + return self._local_muon_update(p, p.grad, group) + + value = self._as_dtensor(p) + if value is None: + if p.grad is None: + return + return self._local_muon_update(p, p.grad, group) + + p_local = value.to_local() + if p_local.numel() == 0 and not gather_full: + return + + state = self.state[p] + mom_buffer = state["momentum_buffer"] + mom_dtensor = self._as_dtensor(mom_buffer) + mom_local = mom_dtensor.to_local() if mom_dtensor is not None else mom_buffer + + grad = p.grad + if grad is None: + local_grad = torch.zeros_like(mom_local) + else: + grad_dtensor = self._as_dtensor(grad) + local_grad = grad_dtensor.to_local() if grad_dtensor is not None else grad + + lr = group["lr"] + self._apply_weight_decay_inplace(p_local, local_grad, lr, group["weight_decay"]) + + mom_local.lerp_(local_grad, 1 - group["momentum"]) + if self.nesterov: + update_local = local_grad.lerp(mom_local, group["momentum"]) + else: + update_local = mom_local + + from emerging_optimizers import utils + + with utils.fp32_matmul_precision(self.fp32_matmul_prec): + group_kwargs = {k: v for k, v in group.items() if k != "params"} + if gather_full: + if gather_uneven_dtensor_to_full_tensor is None: + raise RuntimeError("DTensor support is required for Megatron-FSDP Muon.") + update_dtensor = self._dtensor_from_local_like(value, update_local.contiguous()) + full_update = gather_uneven_dtensor_to_full_tensor(update_dtensor).to_local() + orth_full_update = super(FSDPZeROTensorParallelMuon, self).orthogonalize( + p, full_update, **group_kwargs + ) + sharded_update = self._reshard_full_update_like(value, orth_full_update) + self.pre_weight_update_fn_inplace(p, sharded_update) + update_target = p if isinstance(p, _DTensor) else value + update_target.add_(sharded_update, alpha=-lr) + self.post_weight_update_fn_inplace(p) + else: + orth_local_update = super(FSDPZeROTensorParallelMuon, self).orthogonalize( + p, update_local, **group_kwargs + ).to(dtype=p_local.dtype) + self.pre_weight_update_fn_inplace(p_local, orth_local_update) + 
p_local.add_(orth_local_update, alpha=-lr) + self.post_weight_update_fn_inplace(p_local) + + @torch.no_grad() # type: ignore[misc] + def _local_muon_update( + self, p: torch.Tensor, grad: torch.Tensor, group: Dict[str, Any] + ) -> None: + """Local (non-DP) Muon update – identical to OrthogonalizedOptimizer.step body.""" + from emerging_optimizers import utils + + state = self.state[p] + self._apply_weight_decay_inplace(p, grad, group["lr"], group["weight_decay"]) + state["momentum_buffer"].lerp_(grad, 1 - group["momentum"]) + if self.nesterov: + grad = grad.lerp(state["momentum_buffer"], group["momentum"]) + else: + grad = state["momentum_buffer"] + with utils.fp32_matmul_precision(self.fp32_matmul_prec): + group_kwargs = {k: v for k, v in group.items() if k != "params"} + orth_grad = self.orthogonalize(p, grad, **group_kwargs) + self.pre_weight_update_fn_inplace(p, orth_grad) + p.add_(orth_grad, alpha=-group["lr"]) + self.post_weight_update_fn_inplace(p) + + class TensorParallelAdaptiveMuon(TensorParallelMuon, AdaptiveMuon): """Tensor Parallel Adaptive Muon optimizer. diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 5f1968dcc27..1886b0d6e63 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1020,8 +1020,11 @@ def validate_args(args, defaults={}): args.use_distributed_optimizer = True # Optimizer step MXFP8 buffer operation that is not relevant or supported for Megatron-FSDP. args.reuse_grad_buf_for_mxfp8_param_ag = False - # Optimizer compatibility check. - assert args.optimizer in ('sgd', 'adam'), \ + # Optimizer compatibility check. Muon is supported via the + # _build_megatron_fsdp_emerging_optimizer path; other emerging + # optimizers (Lion, SOAP, adaptive_muon) don't yet have the + # FSDP-aware step contract wired up. + assert args.optimizer in ('sgd', 'adam', 'muon'), \ f"Megatron-FSDP does not support the {args.optimizer} optimizer yet." if ( @@ -1468,12 +1471,34 @@ def validate_args(args, defaults={}): args.use_layer_wise_distributed_optimizer = True if args.use_distributed_optimizer: - args.use_layer_wise_distributed_optimizer = True - args.use_distributed_optimizer = False + # Megatron-FSDP already shards optimizer state itself (ZeRO-1/2/3 + # via `--data-parallel-sharding-strategy`). Layering + # `LayerWiseDistributedOptimizer` on top would double-shard and, + # in practice, trips `TypeErrors` in its constructor call from + # `_build_megatron_fsdp_emerging_optimizer`. The M-FSDP factory + # already handles the "distributed" part via `FSDPMuonChainedOptimizer`. + if not args.use_megatron_fsdp: + args.use_layer_wise_distributed_optimizer = True + args.use_distributed_optimizer = False assert not args.use_torch_fsdp2, "Emerging optimizer does not support Torch-FSDP2 for now." - assert not args.use_megatron_fsdp, "Emerging optimizer does not support Megatron-FSDP for now." - assert args.ckpt_format in ["torch", "torch_dist"], "Emerging optimizer supports torch and torch_dist checkpoint format." + if args.use_megatron_fsdp: + assert args.optimizer == "muon", ( + "Emerging optimizer with Megatron-FSDP is currently only supported for Muon." + ) + assert args.outer_dp_sharding_strategy == "no_shard", ( + "Emerging optimizer with Megatron-FSDP does not support HSDP " + "(--outer-dp-sharding-strategy != no_shard) yet." + ) + # Megatron-FSDP itself requires `fsdp_dtensor` (asserted above), so + # the emerging-optimizer path must accept it here to avoid a + # contradictory assertion pair. 
+ assert args.ckpt_format == "fsdp_dtensor", ( + "Emerging optimizer with Megatron-FSDP requires " + "--ckpt-format fsdp_dtensor." + ) + else: + assert args.ckpt_format in ["torch", "torch_dist"], "Emerging optimizer supports torch and torch_dist checkpoint format." # Make sure all functionality that requires Gloo process groups is disabled. diff --git a/tests/unit_tests/test_muon_fsdp_optimizer.py b/tests/unit_tests/test_muon_fsdp_optimizer.py new file mode 100644 index 00000000000..88797b59c8f --- /dev/null +++ b/tests/unit_tests/test_muon_fsdp_optimizer.py @@ -0,0 +1,385 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Tests for Muon + Megatron-FSDP integration.""" + +import os +from unittest.mock import MagicMock, patch + +import pytest +import torch +import torch.nn as nn +from packaging.version import Version + +from megatron.core.optimizer.emerging_optimizers import ( + FSDPMuonChainedOptimizer, + FSDPZeROTensorParallelMuon, + HAVE_EMERGING_OPTIMIZERS, + TensorParallelMuon, + _get_mfsdp_models, +) +from tests.unit_tests.test_utilities import Utils + + +pytestmark = [ + pytest.mark.skipif( + Version(os.getenv("NVIDIA_PYTORCH_VERSION", "99.99")) <= Version("25.05"), + reason="Skip Muon FSDP optimizer tests on LTS containers", + ), + pytest.mark.skipif( + not HAVE_EMERGING_OPTIMIZERS, + reason="Muon tests require the emerging-optimizers package", + ), +] + +WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + + +def _skip_if_single_rank(): + return pytest.mark.skipif(WORLD_SIZE <= 1, reason="Multi-rank test requires WORLD_SIZE > 1") + + +def _skip_if_no_dtensor(): + return pytest.mark.skipif( + Version(torch.__version__.split("+")[0]) < Version("2.4.0"), + reason="DTensor tests require PyTorch >= 2.4.0", + ) + + +def _make_fsdp_muon(params, dp_group=None, **kwargs): + defaults = dict( + lr=0.05, + momentum=0.0, + nesterov=False, + weight_decay=0.0, + num_ns_steps=2, + pg_collection=None, + tp_mode="duplicated", + split_qkv=False, + ) + defaults.update(kwargs) + return FSDPZeROTensorParallelMuon(params=params, dp_group=dp_group, **defaults) + + +def _reference_update(full_grad, **kwargs): + from emerging_optimizers import utils + + defaults = dict( + lr=0.05, + momentum=0.0, + nesterov=False, + weight_decay=0.0, + num_ns_steps=2, + pg_collection=None, + tp_mode="duplicated", + split_qkv=False, + ) + defaults.update(kwargs) + param = nn.Parameter(torch.zeros_like(full_grad)) + optimizer = TensorParallelMuon(params=[param], **defaults) + with utils.fp32_matmul_precision(optimizer.fp32_matmul_prec): + return optimizer.orthogonalize(param, full_grad) + + +def _local_rows(plan, rank): + return plan.get(rank, 0) + + +def _local_slice(full_tensor, plan, rank): + start = sum(_local_rows(plan, r) for r in range(rank)) + rows = _local_rows(plan, rank) + return full_tensor[start : start + rows].clone() + + +def _make_dtensor(local_tensor, global_shape, device_mesh): + from torch.distributed.tensor import DTensor, Shard + + from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import ( + update_uneven_dtensor_chunk_metadata, + ) + + global_stride = torch.empty( + global_shape, dtype=local_tensor.dtype, device=local_tensor.device + ).stride() + dtensor = DTensor.from_local( + local_tensor=local_tensor, + device_mesh=device_mesh, + placements=[Shard(0)], + shape=torch.Size(global_shape), + stride=global_stride, + run_check=False, + ) + update_uneven_dtensor_chunk_metadata(dtensor) + return dtensor + + +class TestFSDPMuonChainedOptimizer: + @staticmethod 
+ def _make_mfsdp_mock(model_auto_sync=False): + mock = MagicMock() + mock.model_auto_sync = model_auto_sync + return mock + + @staticmethod + def _make_inner_mock(): + mock = MagicMock() + mock.step.return_value = (True, 1.0, 0) + return mock + + @pytest.mark.parametrize("model_auto_sync", [True, False]) + def test_step_protocol(self, model_auto_sync): + inner = self._make_inner_mock() + mfsdp = self._make_mfsdp_mock(model_auto_sync=model_auto_sync) + wrapper = FSDPMuonChainedOptimizer(inner, [mfsdp]) + + wrapper.step() + + if model_auto_sync: + mfsdp.finish_grad_sync.assert_not_called() + else: + mfsdp.finish_grad_sync.assert_called_once() + inner.step.assert_called_once() + mfsdp.install_optimized_model_weights.assert_called_once() + + def test_getattr_delegation(self): + inner = self._make_inner_mock() + inner.param_groups = [{"lr": 0.01}] + inner.state_dict.return_value = {"state": {}, "param_groups": []} + wrapper = FSDPMuonChainedOptimizer(inner, [self._make_mfsdp_mock()]) + + assert wrapper.param_groups == [{"lr": 0.01}] + assert wrapper.state_dict() == {"state": {}, "param_groups": []} + + wrapper.load_state_dict({"state": {}, "param_groups": []}) + inner.load_state_dict.assert_called_once() + + @pytest.mark.parametrize("set_to_none", [True, False]) + def test_zero_grad_delegates(self, set_to_none): + inner = self._make_inner_mock() + wrapper = FSDPMuonChainedOptimizer(inner, [self._make_mfsdp_mock()]) + + wrapper.zero_grad(set_to_none=set_to_none) + + inner.zero_grad.assert_called_once_with(set_to_none) + + def test_get_mfsdp_models_error(self): + with pytest.raises(RuntimeError, match="Could not find any MegatronFSDP"): + _get_mfsdp_models([nn.Module()]) + + def test_get_mfsdp_models_success(self): + inner_module = MagicMock() + chunk = MagicMock() + chunk.module = inner_module + chunk.finish_grad_sync = MagicMock() + + assert _get_mfsdp_models([chunk]) == [inner_module] + + +@_skip_if_single_rank() +@_skip_if_no_dtensor() +class TestFSDPZeROTensorParallelMuon: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + Utils.initialize_model_parallel() + yield + Utils.destroy_model_parallel() + + def test_step_gathers_only_split_boundary_params(self): + from torch.distributed.device_mesh import init_device_mesh + + from megatron.core.optimizer import emerging_optimizers as eopt_mod + + dp_size = torch.distributed.get_world_size() + dp_rank = torch.distributed.get_rank() + device_mesh = init_device_mesh("cuda", (dp_size,), mesh_dim_names=("dp",)) + dp_group = device_mesh.get_group("dp") + cols = 4 + + # Param 0 is fully local on rank 0. Param 1 crosses the rank 0/1 + # boundary. Param 2 is fully local on rank 1. Other ranks hold empty + # local shards and still participate in the boundary gather for param 1. 
+ plans = [ + {0: 4}, + {0: 2, 1: 4}, + {1: 5}, + ] + full_params = [ + torch.arange(rows * cols, device="cuda", dtype=torch.float32).view(rows, cols) / 100 + for rows in (4, 6, 5) + ] + full_grads = [ + torch.arange(rows * cols, device="cuda", dtype=torch.float32).view(rows, cols) / 50 + + 0.1 + for rows in (4, 6, 5) + ] + + params = [] + initial_locals = [] + for full_param, plan in zip(full_params, plans): + local_param = _local_slice(full_param, plan, dp_rank).contiguous() + param = nn.Parameter(_make_dtensor(local_param, full_param.shape, device_mesh)) + params.append(param) + initial_locals.append(local_param.clone()) + + for param, full_grad, plan in zip(params, full_grads, plans): + local_grad = _local_slice(full_grad, plan, dp_rank).contiguous() + param.grad = _make_dtensor(local_grad, full_grad.shape, device_mesh) + + optimizer = _make_fsdp_muon(params, dp_group=dp_group) + assert optimizer._get_boundary_gather_param_indices(optimizer.param_groups[0]) == {1} + + real_gather = eopt_mod.gather_uneven_dtensor_to_full_tensor + gather_shapes = [] + + def counting_gather(value): + gather_shapes.append(tuple(value.shape)) + return real_gather(value) + + with patch.object(eopt_mod, "gather_uneven_dtensor_to_full_tensor", counting_gather): + optimizer.step() + + assert gather_shapes == [(6, cols)] + + lr = optimizer.param_groups[0]["lr"] + for idx, (param, full_param, full_grad, plan) in enumerate( + zip(params, full_params, full_grads, plans) + ): + expected_update = _reference_update(full_grad) + expected_local = _local_slice(full_param - lr * expected_update, plan, dp_rank) + local_value = param.data.to_local() + + if local_value.numel() == 0: + assert local_value.shape[0] == 0 + continue + + torch.testing.assert_close( + local_value, + expected_local, + atol=1e-5, + rtol=1e-4, + msg=f"param {idx} local update mismatch on rank {dp_rank}", + ) + assert not torch.equal(local_value, initial_locals[idx]) + + def test_boundary_detector_includes_middle_split_params(self): + from torch.distributed.device_mesh import init_device_mesh + + dp_size = torch.distributed.get_world_size() + dp_rank = torch.distributed.get_rank() + device_mesh = init_device_mesh("cuda", (dp_size,), mesh_dim_names=("dp",)) + dp_group = device_mesh.get_group("dp") + cols = 4 + + # Optimizer groups can exclude Adam-managed tensors that are present in + # the underlying M-FSDP flat buffer. After that filtering, a split Muon + # tensor is not guaranteed to be first or last in the Muon param group. 
+ plans = [ + {0: 2, 1: 2}, + {0: 1, 1: 3}, + {0: 3, 1: 1}, + ] + + params = [] + for plan in plans: + full_param = torch.zeros(sum(plan.values()), cols, device="cuda") + local_param = _local_slice(full_param, plan, dp_rank).contiguous() + params.append(nn.Parameter(_make_dtensor(local_param, full_param.shape, device_mesh))) + + optimizer = _make_fsdp_muon(params, dp_group=dp_group) + + assert optimizer._get_boundary_gather_param_indices(optimizer.param_groups[0]) == { + 0, + 1, + 2, + } + + +@_skip_if_single_rank() +class TestFSDPFactoryIntegration: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + Utils.initialize_model_parallel() + yield + Utils.destroy_model_parallel() + + @pytest.mark.parametrize("strategy", ["no_shard", "optim", "optim_grads", "optim_grads_params"]) + def test_factory_dispatches_correct_muon_cls(self, strategy): + from megatron.core.optimizer import OptimizerConfig, _build_megatron_fsdp_emerging_optimizer + from megatron.core.process_groups_config import ProcessGroupCollection + from megatron.core.optimizer.optimizer import ChainedOptimizer, FP32Optimizer + + model_chunk = MagicMock() + model_chunk.config.num_attention_heads = 8 + model_chunk.config.num_query_groups = 8 + model_chunk.config.kv_channels = 16 + + linear_weight = nn.Parameter(torch.randn(32, 16, device="cuda")) + linear_weight.is_embedding_or_output_parameter = False + bias_param = nn.Parameter(torch.randn(32, device="cuda")) + bias_param.is_embedding_or_output_parameter = False + + model_chunk.named_parameters.return_value = [ + ("layer.linear.weight", linear_weight), + ("layer.linear.bias", bias_param), + ] + model_chunk.parameters.return_value = iter([linear_weight, bias_param]) + model_chunk.ddp_config.use_megatron_fsdp = True + model_chunk.ddp_config.data_parallel_sharding_strategy = strategy + + config = OptimizerConfig( + optimizer="muon", + lr=0.01, + weight_decay=0.01, + bf16=False, + fp16=False, + use_distributed_optimizer=False, + muon_momentum=0.95, + muon_nesterov=True, + muon_fp32_matmul_prec="medium", + muon_num_ns_steps=5, + muon_scale_mode="spectral", + muon_tp_mode="duplicated", + muon_split_qkv=False, + muon_extra_scale_factor=1.0, + ) + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + + def fake_get_param_groups(model_chunks, config, overrides): + return [ + { + "params": [linear_weight], + "is_expert_parallel": False, + "wd_mult": 1.0, + "lr_mult": 1.0, + } + ] + + with patch("megatron.core.optimizer._get_param_groups", side_effect=fake_get_param_groups): + with patch("megatron.core.optimizer.get_megatron_optimizer") as mock_get_opt: + mock_adam = MagicMock() + mock_adam_sub = MagicMock() + mock_adam_sub.config = config + mock_adam.chained_optimizers = [mock_adam_sub] + mock_get_opt.return_value = mock_adam + with patch("megatron.core.optimizer._get_mfsdp_models", return_value=[MagicMock()]): + result = _build_megatron_fsdp_emerging_optimizer( + config=config, + model_chunks=[model_chunk], + config_overrides={}, + pg_collection=pg_collection, + eopt_name="muon", + use_layer_wise=False, + ) + + assert isinstance(result, FSDPMuonChainedOptimizer) + inner = object.__getattribute__(result, "inner") + assert isinstance(inner, ChainedOptimizer) + + muon_wrapper = inner.chained_optimizers[0] + assert isinstance(muon_wrapper, FP32Optimizer) + base_opt = muon_wrapper.optimizer + if strategy == "no_shard": + assert isinstance(base_opt, TensorParallelMuon) + assert not isinstance(base_opt, FSDPZeROTensorParallelMuon) + else: + assert isinstance(base_opt, 
FSDPZeROTensorParallelMuon) + assert base_opt.dp_group is not None
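For readers tracing the sharding tests above, a minimal in-process sketch of the boundary-gather decision they exercise may help. `ShardView` and `union_boundary_indices` are hypothetical names invented for illustration: the real code applies the same predicate to DTensor locals and unions the indices with `torch.distributed.all_gather_object`, whereas this sketch simulates both ranks in a single process. The layout mirrors the unit test's three-parameter plan.

# Standalone sketch of the boundary-detection-and-union rule (no torch.distributed).
from dataclasses import dataclass
from typing import List, Set, Tuple


@dataclass
class ShardView:
    global_shape: Tuple[int, int]
    local_shape: Tuple[int, int]  # rows owned by this rank (possibly 0)

    @property
    def local_numel(self) -> int:
        rows, cols = self.local_shape
        return rows * cols

    def needs_boundary_gather(self) -> bool:
        # Same predicate as _needs_boundary_gather: a non-empty shard whose local
        # view is smaller than the full matrix straddles a rank boundary.
        return self.local_numel > 0 and self.local_shape != self.global_shape


def union_boundary_indices(per_rank_params: List[List[ShardView]]) -> Set[int]:
    """Mimics the all_gather_object union: every rank must agree on which
    parameter indices enter the gather collective, in the same order."""
    agreed: Set[int] = set()
    for rank_params in per_rank_params:
        agreed |= {i for i, p in enumerate(rank_params) if p.needs_boundary_gather()}
    return agreed


if __name__ == "__main__":
    # Param 0 fully local on rank 0, param 1 split across the rank 0/1 boundary,
    # param 2 fully local on rank 1 (mirrors the unit test's layout).
    rank0 = [ShardView((4, 4), (4, 4)), ShardView((6, 4), (2, 4)), ShardView((5, 4), (0, 4))]
    rank1 = [ShardView((4, 4), (0, 4)), ShardView((6, 4), (4, 4)), ShardView((5, 4), (5, 4))]
    assert union_boundary_indices([rank0, rank1]) == {1}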