208 changes: 207 additions & 1 deletion megatron/core/optimizer/__init__.py
@@ -64,7 +64,10 @@
from .emerging_optimizers import (
_EMERGING_OPTIMIZERS,
HAVE_EMERGING_OPTIMIZERS,
FSDPMuonChainedOptimizer,
FSDPZeROTensorParallelMuon,
_create_emerging_optimizer,
_get_mfsdp_models,
)
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .layer_wise_optimizer import LayerWiseDistributedOptimizer
@@ -721,6 +724,193 @@ def check_config_overrides_consistency(
return True


def _build_megatron_fsdp_emerging_optimizer(
config: OptimizerConfig,
model_chunks: List[MegatronModule],
config_overrides: Dict[ParamKey, ParamGroupOverride],
pg_collection: ProcessGroupCollection,
eopt_name: str,
use_layer_wise: bool,
) -> "FSDPMuonChainedOptimizer":
"""Build an emerging optimizer (currently only Muon) for Megatron-FSDP.

The standard emerging-optimizer flow assumes `main_grad` buffers populated
by DDP. With Megatron-FSDP, gradients arrive via `finish_grad_sync()` on
DTensor parameters, which makes `Float16OptimizerWithFloat16Params`
incompatible. This helper:

1. Builds the linear-only Muon optimizer (plain `TensorParallelMuon` for
`no_shard`; `FSDPZeROTensorParallelMuon` for sharded strategies, which
gathers only split rank-boundary DTensor parameters before NS) over a
frozen-grad view of the model so `_get_param_groups` only emits
Muon-managed params.
2. Falls back to `get_megatron_optimizer` for the non-linear params,
which routes through the standard Megatron-FSDP path for Adam.
3. Wraps the chained result with `FSDPMuonChainedOptimizer`, which drives
`finish_grad_sync` and `install_optimized_model_weights`.
"""
# Lazy import to avoid pulling Muon-specific symbols when emerging_optimizers
# is unavailable; entry already validated upstream.
entry = _EMERGING_OPTIMIZERS[eopt_name]
init_state_fn = entry.init_state_fn

# Choose Muon variant based on the inner-DP sharding strategy.
# - "no_shard" (ZeRO-0): params/grads are full Replicate DTensors, so the
# plain TensorParallelMuon works without any DP communication.
# - Sharded strategies (ZeRO-1/2/3): FSDPZeROTensorParallelMuon gathers
# only split boundary parameters across the DP group before NS and
# re-shards the result.
fsdp_strategy = getattr(
model_chunks[0].ddp_config, "data_parallel_sharding_strategy", "no_shard"
)
use_fsdp_zero_muon = fsdp_strategy != "no_shard"
optimizer_cls = FSDPZeROTensorParallelMuon if use_fsdp_zero_muon else entry.optimizer_cls

log_single_rank(
logger,
logging.INFO,
f"Setting up Megatron-FSDP emerging optimizer ({eopt_name}) with config {config}",
)

# Sort parameters into linear (Muon-managed) vs nonlinear (Adam-managed).
linear_params: List[torch.nn.Parameter] = []
nonlinear_params: List[torch.nn.Parameter] = []
for model_chunk in model_chunks:
for _, param in model_chunk.named_parameters():
if not param.requires_grad:
continue
if (
not getattr(param, "is_embedding_or_output_parameter", False)
and len(param.shape) == 2
):
linear_params.append(param)
else:
nonlinear_params.append(param)

# Build the Muon optimizer for linear params.
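# Temporarily freeze the non-linear params so `_get_param_groups` below only
# emits Muon-managed groups; `requires_grad` is restored before the Adam build.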

for p in nonlinear_params:
p.requires_grad = False
linear_param_groups = _get_param_groups(model_chunks, config, config_overrides)

expert_param_groups: List[Dict[str, Any]] = []
if not use_layer_wise:
Reviewer comment: Is this because layer-wise <> EP isn't compatible, so we just use Adam? (Gathering expert params is too heavyweight? Also, another complexity is EP-DP sharding, so we need to gather both.)

# Build a fresh list rather than calling `linear_param_groups.remove(group)`:
# `list.remove` falls back to `dict.__eq__` when the elements are dicts,
# which compares values and trips on tensor params with
# `RuntimeError: Boolean value of Tensor with more than one value is ambiguous`.
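# Illustration of the failure mode (not executed here):
#   groups = [{"params": [torch.ones(2)]}, {"params": [torch.zeros(2)]}]
#   groups.remove(groups[1])  # dict __eq__ compares the tensor lists elementwise,
#                             # and truth-testing that result raises the RuntimeError.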
non_expert_groups: List[Dict[str, Any]] = []
for group in linear_param_groups:
if group["is_expert_parallel"]:
expert_param_groups.append(group)
else:
non_expert_groups.append(group)
linear_param_groups = non_expert_groups

eopt_kwargs: Dict[str, Any] = {}
if entry.config_to_kwargs is not None:
eopt_kwargs = entry.config_to_kwargs(config, model_chunks, pg_collection)

optimizers: List[Any] = []
init_fns: List[Callable] = []
if linear_param_groups:
# ZeRO-1/2/3: split dense boundary params are sharded over `dp_cp`, so
# the FSDPZero variant must gather over the same group.
dense_kwargs = dict(eopt_kwargs)
if use_fsdp_zero_muon:
dense_kwargs["dp_group"] = pg_collection.dp_cp
muon_base = optimizer_cls(linear_param_groups, **dense_kwargs)
muon_opt = FP32Optimizer(muon_base, config, init_state_fn)
setattr(muon_opt, "grad_stats_parallel_group", pg_collection.mp)
setattr(muon_opt, "tp_group", pg_collection.tp)
optimizers.append(muon_opt)
init_fns.append(init_state_fn)

if expert_param_groups:
# Expert params reduce-scatter over `expt_dp` (the expert
# data-parallel group), not `dp_cp`.
expert_kwargs = dict(eopt_kwargs)
if use_fsdp_zero_muon:
expert_kwargs["dp_group"] = pg_collection.expt_dp
expert_muon_base = optimizer_cls(expert_param_groups, **expert_kwargs)
expert_muon_opt = FP32Optimizer(expert_muon_base, config, init_state_fn)
setattr(expert_muon_opt, "grad_stats_parallel_group", pg_collection.tp_ep_pp)
setattr(expert_muon_opt, "tp_group", pg_collection.tp)
optimizers.append(expert_muon_opt)
init_fns.append(init_state_fn)

# Build Adam for non-linear params via the standard FSDP path.
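# Mirror of the freeze trick above: re-enable grads on the Adam-managed params
# and hide the Muon-managed ones from the recursive optimizer build.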

for p in nonlinear_params:
p.requires_grad = True
for p in linear_params:
p.requires_grad = False

saved_optimizer = config.optimizer
saved_use_lwd = config.use_layer_wise_distributed_optimizer
config.optimizer = "adam"
config.use_layer_wise_distributed_optimizer = False
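# Temporarily point the config at plain Adam (with layer-wise disabled) so the
# recursive call below takes the standard Megatron-FSDP path for the remaining
# params; both fields are restored in the `finally` block.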
try:
# `pg_collection` is always provided here, so Gloo process groups
# must be disabled on the recursive call (setup_process_groups_for_optimizer
# rejects the combination).
chained_adam = get_megatron_optimizer(
config,
model_chunks,
config_overrides=config_overrides,
pg_collection=pg_collection,
use_gloo_process_groups=False,
)
finally:
config.optimizer = saved_optimizer
config.use_layer_wise_distributed_optimizer = saved_use_lwd

for p in linear_params:
p.requires_grad = True

adam_optimizers = list(getattr(chained_adam, "chained_optimizers", [chained_adam]))
optimizers += adam_optimizers

def _adam_init_state_fn(opt: Any, config: Any = None) -> None:
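# Lazily create Adam state (exp_avg / exp_avg_sq) for params that have none yet;
# precision-aware optimizers allocate their own state via `initialize_state`.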
for group in opt.param_groups:
for p in group["params"]:
if len(opt.state[p]) == 0:
if config is None or not config.use_precision_aware_optimizer:
opt.state[p]["exp_avg"] = torch.zeros_like(p.data)
opt.state[p]["exp_avg_sq"] = torch.zeros_like(p.data)
Reviewer comment on lines +878 to +880:
Also what does this case capture? With MFSDP, I think we need DTensors so we usually do a "dummy step" to get DTensor state based on the DTensor params.

Two examples:

else:
opt.initialize_state(p)

init_fns += [_adam_init_state_fn] * len(adam_optimizers)

# Combine and wrap with the FSDP-protocol adapter.

if use_layer_wise:
# Float16OptimizerWithFloat16Params is incompatible with FSDP DTensor
# params (no .main_grad), so temporarily clear bf16 to prevent the
# LayerWiseDistributedOptimizer from re-wrapping each sub-optimizer.
Reviewer comment on lines +889 to +891:
Either we should extend the logic in LayerWiseDistributedOptimizer or the comment should be simpler, like this:

LayerWiseDistributedOptimizer(config.bf16) wraps the optimizer with a Float16 Megatron optimizer if BF16 is used. Megatron-FSDP only supports DistributedOptimizer and LayerWiseDistributedOptimizer.

Note: Layer-wise appears to be a permutation of optimizer params into different layer-specific groups distributed across DP-CP (or EDP). I am not sure this is compatible with Megatron-FSDP, since we have uneven shards. Is it possible that, say, the non-empty params are all allocated to a different DP rank, while that rank's own equivalent params are empty, so no optimizer update happens?

log_single_rank(
logger,
logging.INFO,
f"Using LayerWiseDistributedOptimizer for {eopt_name} + Megatron-FSDP",
)
saved_bf16 = config.bf16
config.bf16 = False
try:
inner: Any = LayerWiseDistributedOptimizer(
optimizers,
config,
pg_collection,
init_state_fn_list=init_fns,
)
finally:
config.bf16 = saved_bf16
else:
inner = ChainedOptimizer(optimizers)

return FSDPMuonChainedOptimizer(inner, _get_mfsdp_models(model_chunks))


def _get_megatron_emerging_optimizer(
config: OptimizerConfig,
model_chunks: List[MegatronModule],
@@ -778,7 +968,23 @@ def _get_megatron_emerging_optimizer(
if 'linear_qkv.weight' in name and len(param.shape) == 2:
param.is_qkv = True

# Apply optimizer-specific default param overrides (e.g. muon: non-linear -> adam).
# Megatron-FSDP needs a separate code path: gradients are attached via
# `finish_grad_sync()` rather than via `main_grad` buffers, so the
# standard Float16OptimizerWithFloat16Params wrapper is incompatible.
# In addition, the resulting optimizer must be wrapped to invoke FSDP's
# `finish_grad_sync()` / `install_optimized_model_weights()` hooks
# around each step.
if getattr(model_chunks[0].ddp_config, 'use_megatron_fsdp', False):
return _build_megatron_fsdp_emerging_optimizer(
config=config,
model_chunks=model_chunks,
config_overrides=config_overrides,
pg_collection=pg_collection,
eopt_name=eopt_name,
use_layer_wise=use_layer_wise,
)

# Apply optimizer-specific default param overrides (e.g., muon: non-linear -> adam).
config_overrides.update(_EMERGING_OPTIMIZERS[eopt_name].default_param_overrides)

# Build param groups and bucket by (optimizer_name, is_expert_parallel).