diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index c6d3e41aed5..c22b50df51a 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -10,6 +10,28 @@ from torch.optim import SGD as CPUSGD from torch.optim import AdamW as CPUAdam +from megatron.core import parallel_state +from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer +from megatron.core.optimizer_param_scheduler import (ParamGroupOverride, + combine_param_group_overrides, + param_group_override_to_tuple) +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.fsdp_dtensor_checkpoint import get_global_unique_param_name + +from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer +from ..transformer.module import MegatronModule +from ..utils import get_model_config, get_pg_rank, get_pg_size, is_te_min_version, log_single_rank +from .distrib_optimizer import DistributedOptimizer +from .emerging_optimizers import (_EMERGING_OPTIMIZERS, HAVE_EMERGING_OPTIMIZERS, + _create_emerging_optimizer) +from .grad_scaler import ConstantGradScaler, DynamicGradScaler +from .layer_wise_optimizer import LayerWiseDistributedOptimizer +from .optimizer import (ChainedOptimizer, Float16OptimizerWithFloat16Params, FP32Optimizer, + MegatronOptimizer, param_group_identifier_keys) +# Subclass aliases kept for backward compatibility; all are OptimizerConfig. +from .optimizer_config import (AdamOptimizerConfig, OptimizerConfig, ParamKey, ParamPredicate, + ParamWithNamePredicate, SGDOptimizerConfig) + try: from transformer_engine.pytorch.optimizers import FusedAdam as Adam from transformer_engine.pytorch.optimizers import FusedSGD as SGD @@ -47,44 +69,8 @@ if HAVE_EMERGING_OPTIMIZERS: from emerging_optimizers.scalar_optimizers import Lion -from megatron.core import parallel_state -from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer -from megatron.core.optimizer_param_scheduler import ( - ParamGroupOverride, - combine_param_group_overrides, - param_group_override_to_tuple, -) -from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.fsdp_dtensor_checkpoint import get_global_unique_param_name -from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer -from ..transformer.module import MegatronModule -from ..utils import get_model_config, get_pg_rank, get_pg_size, is_te_min_version, log_single_rank -from .distrib_optimizer import DistributedOptimizer -from .emerging_optimizers import ( - _EMERGING_OPTIMIZERS, - HAVE_EMERGING_OPTIMIZERS, - _create_emerging_optimizer, -) -from .grad_scaler import ConstantGradScaler, DynamicGradScaler -from .layer_wise_optimizer import LayerWiseDistributedOptimizer -from .optimizer import ( - ChainedOptimizer, - Float16OptimizerWithFloat16Params, - FP32Optimizer, - MegatronOptimizer, - param_group_identifier_keys, -) -# Subclass aliases kept for backward compatibility; all are OptimizerConfig. -from .optimizer_config import ( - AdamOptimizerConfig, - OptimizerConfig, - ParamKey, - ParamPredicate, - ParamWithNamePredicate, - SGDOptimizerConfig, -) logger = logging.getLogger(__name__) @@ -492,8 +478,8 @@ def _get_megatron_optimizer_based_on_param_groups( raise ValueError( "skip_megatron_wrapping=True is incompatible with use_precision_aware_optimizer." 
) - if skip_megatron_wrapping and config.optimizer_cpu_offload: - raise ValueError("skip_megatron_wrapping=True is incompatible with optimizer_cpu_offload.") + # NOTE: skip_megatron_wrapping + optimizer_cpu_offload is allowed for Muon, + # where LayerWiseDistributedOptimizer handles CPU offloading itself. # When freezing sub-models we may have no trainable parameters on a rank and # hence an empty param_groups. However, we still need to create an optimizer @@ -822,6 +808,12 @@ def _get_megatron_emerging_optimizer( fallback_config = copy.copy(config) fallback_config.optimizer = opt_name fallback_config.use_distributed_optimizer = False + if use_layer_wise: + # Disable per-optimizer CPU offload (HybridDeviceOptimizer) for the + # Adam fallback when LayerWiseDistributedOptimizer is active. + # CPU offloading is handled uniformly by LayerWiseDistributedOptimizer + # for all sub-optimizers (Muon + Adam), preventing double-offloading. + fallback_config.optimizer_cpu_offload = False result = _get_megatron_optimizer_based_on_param_groups( config=fallback_config, model_chunks=model_chunks, diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index d0f64010bad..d91e51cdfe9 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -1,7 +1,7 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import logging -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Tuple import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors @@ -12,12 +12,8 @@ from megatron.core.utils import get_pg_rank, get_pg_size from .clip_grads import count_zeros_fp32, get_grad_norm_fp32 -from .optimizer import ( - ChainedOptimizer, - Float16OptimizerWithFloat16Params, - FP32Optimizer, - MegatronOptimizer, -) +from .optimizer import (ChainedOptimizer, Float16OptimizerWithFloat16Params, FP32Optimizer, + MegatronOptimizer) from .optimizer_config import OptimizerConfig logger = logging.getLogger(__name__) @@ -27,16 +23,40 @@ class LayerWiseDistributedOptimizer(ChainedOptimizer): """Layer-wise distributed optimizer for Megatron-core models. Experimental distributed optimizer wrapper that distributes weight to DP ranks by layer. - Implemented as ChainedOptimizer to support multiple optimizers (e.g. muon + adamW) + Implemented as ChainedOptimizer to support multiple optimizers (e.g. muon + adamW). When using, keep all megatron distributed-optimizer related options OFF. - How LayerWiseDistributedOptimizer work: - 1. weights are splited into lists and each rank only keep its shard in its optimizer - 2. Megatron DDP handle allreduce grad, note that each rank have full model and grad - 3. optimizer is already modified so only param belong to this DP rank is updated - 4. grad_norm and zero counting will reduce metrics globally in step function - 5. Do regular update with chained optimizers, modified optimizer only update shard - 6. allgather updated params to every rank + How LayerWiseDistributedOptimizer works: + + 1. Weights are split into lists and each rank only keeps its shard in its optimizer. + 2. Megatron DDP handles allreduce grad; each rank has full model and grad. + 3. Optimizer is modified so only params belonging to this DP rank are updated. + 4. grad_norm and zero counting reduce metrics globally in step function. + 5. Regular update with chained optimizers; modified optimizer only updates shard. + 6. 
All-gather (or broadcast) updated params to every rank. + + CPU Offloading: + + When ``optimizer_cpu_offload=True`` in the config, this optimizer manages a + host-device-host (H2D/D2H) cycle for the fp32 master weights and momentum + buffers owned by the wrapped ``Float16OptimizerWithFloat16Params`` sub-optimizers. + This is particularly beneficial for Muon, where most model parameters are + "muonable" and their fp32 master weights + momentum constitute the majority + of optimizer memory. + + The offload lifecycle per training step: + + 1. **reload_optimizer_states()**: Move fp32 master weights and optimizer state + tensors from CPU pinned memory back to GPU before the optimizer step. + 2. **super().step()**: Run the actual optimizer update (e.g. Newton-Schulz for + Muon, Adam for fallback params) on GPU. + 3. **broadcast_params()**: Synchronize updated bf16 model params across DP ranks. + 4. **offload_optimizer_states()**: Move fp32 master weights and optimizer state + tensors back to CPU pinned memory to free GPU memory. + + Note: The Adam fallback optimizer's ``optimizer_cpu_offload`` is set to False + when ``use_layer_wise_distributed_optimizer=True``, preventing double-offloading + via ``HybridDeviceOptimizer``. All offloading is unified through this class. """ def __init__( @@ -91,6 +111,11 @@ def __init__( super().__init__(optimizers) + self._cpu_offload = getattr(config, 'optimizer_cpu_offload', False) + if self._cpu_offload: + self.offload_optimizer_states() + logger.info('[layerwise] optimizer states CPU offloading enabled') + # TODO(kunlun, deyuf): potential future perf optimization # since allreduce is unchanged and handled by megatron DDP, they're already in # contiguous gbuf. So instead of shard param by layer randomly, we can shard by @@ -242,7 +267,6 @@ def _allgather_helper(params_list, group): @torch.no_grad() def broadcast_params(self): """All rank broadcast updated local params.""" - # Broadcast linear layer weights to all other ranks. Kept as reference test. if self.dp_cp_params_list is None: return for i, params in enumerate(self.dp_cp_params_list): @@ -277,8 +301,18 @@ def count_zeros(self): ) @torch.no_grad() - def step(self): # type: ignore[no-untyped-def] - """step function for layer-wise optimizer.""" + def step(self) -> Tuple[bool, Optional[float], Optional[int]]: + """Perform a single optimization step with optional CPU offloading. + + When CPU offloading is enabled, this method orchestrates the full cycle: + reload states to GPU -> optimizer step -> broadcast params -> offload states to CPU. + + Returns: + Tuple of (update_successful, grad_norm, num_zeros_in_grad). + """ + if self._cpu_offload: + self.reload_optimizer_states() + update_successful, grad_norm, num_zeros_in_grad = super().step() # All gather updated params. If overlap_param_gather is True, the allgather @@ -286,8 +320,64 @@ def step(self): # type: ignore[no-untyped-def] if not self.overlap_param_gather: self.allgather_params() + if self._cpu_offload: + self.offload_optimizer_states() + return update_successful, grad_norm, num_zeros_in_grad + @torch.no_grad() + def offload_optimizer_states(self) -> None: + """Move fp32 master weights and optimizer state tensors to CPU pinned memory. + + This transfers all fp32 master weight parameters (``fp32_from_float16_groups``) + and optimizer state tensors (e.g. momentum buffers) from GPU to CPU pinned memory. + Pinned memory enables faster H2D transfers on the next reload. + + Called after each optimizer step to free GPU memory for the next forward pass. 
+ A ``torch.cuda.synchronize()`` at entry ensures any pending GPU work (e.g. + param broadcasts) completes before tensors are moved off-device. + """ + torch.cuda.synchronize() + for opt in self.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if not isinstance(opt, Float16OptimizerWithFloat16Params): + continue + for group in opt.fp32_from_float16_groups: + for param in group: + if param.data.is_cuda: + param.data = param.data.cpu().pin_memory() + for state_vals in opt.optimizer.state.values(): + for key, val in state_vals.items(): + if isinstance(val, torch.Tensor) and val.is_cuda: + state_vals[key] = val.cpu().pin_memory() + + @torch.no_grad() + def reload_optimizer_states(self) -> None: + """Move fp32 master weights and optimizer state tensors back to GPU. + + This transfers all fp32 master weight parameters and optimizer state tensors + from CPU pinned memory to the current CUDA device. A ``torch.cuda.synchronize()`` + at exit ensures all H2D transfers complete before the optimizer step proceeds. + + Called at the start of each optimizer step so that the Newton-Schulz iterations + (Muon) or Adam updates can operate on GPU tensors. + """ + for opt in self.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if not isinstance(opt, Float16OptimizerWithFloat16Params): + continue + for group in opt.fp32_from_float16_groups: + for param in group: + if not param.data.is_cuda: + param.data = param.data.to('cuda') + for state_vals in opt.optimizer.state.values(): + for key, val in state_vals.items(): + if isinstance(val, torch.Tensor) and not val.is_cuda: + state_vals[key] = val.to('cuda') + torch.cuda.synchronize() + # TODO(deyuf): need to improve dist checkpointing design to properly handle this # fp32_from_fp16_params is list, each sub list could be empty if group is empty # this breaks dist checkpointing assumption since extract_sharded_base drop list structure diff --git a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py new file mode 100644 index 00000000000..87245967610 --- /dev/null +++ b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py @@ -0,0 +1,282 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Tests for Muon optimizer CPU offloading in LayerWiseDistributedOptimizer. + +This module verifies that the CPU offloading mechanism correctly: +- Moves fp32 master weights and optimizer state (momentum) to CPU pinned memory. +- Preserves tensor values exactly through offload/reload round-trips. +- Produces numerically identical results to the non-offloaded code path. + +These tests require multi-GPU execution (via torchrun or pytest with distributed +launcher) since LayerWiseDistributedOptimizer shards parameters across DP ranks. 
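+
+Example invocation (assumed command; adjust the process count to the available GPUs,
+the parametrized cases below need at least 4):
+``torchrun --nproc_per_node=8 -m pytest tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py``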
+""" + +from typing import Generator + +import pytest +import torch +from tests.unit_tests.test_utilities import Utils + +from megatron.core import parallel_state +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.optimizer import get_megatron_optimizer +from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer +from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params +from megatron.core.optimizer.optimizer_config import OptimizerConfig +from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed +from megatron.core.transformer import TransformerConfig + + +def _create_model(seed: int, tp: int, pp: int) -> GPTModel: + """Create a small GPT model for testing. + + Args: + seed: Random seed for reproducibility. + tp: Tensor parallel size (already initialized via Utils). + pp: Pipeline parallel size (already initialized via Utils). + + Returns: + A GPTModel instance on the current CUDA device. + """ + torch.manual_seed(seed) + model_parallel_cuda_manual_seed(seed) + config = TransformerConfig( + num_layers=6, + hidden_size=16, + num_attention_heads=8, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + ) + layer_spec = get_gpt_layer_with_transformer_engine_spec() + model = GPTModel( + config=config, + transformer_layer_spec=layer_spec, + vocab_size=128, + max_sequence_length=4, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + ) + model.cuda(torch.cuda.current_device()) + return model + + +def _create_optimizer(model: GPTModel, cpu_offload: bool = True) -> LayerWiseDistributedOptimizer: + """Create a Muon LayerWise optimizer with optional CPU offloading. + + Args: + model: The GPT model whose parameters will be optimized. + cpu_offload: Whether to enable CPU offloading of optimizer states. + + Returns: + A LayerWiseDistributedOptimizer wrapping Muon + Adam fallback. + """ + config = OptimizerConfig( + bf16=True, + params_dtype=torch.bfloat16, + use_distributed_optimizer=False, + use_layer_wise_distributed_optimizer=True, + optimizer='muon', + lr=0.0, + optimizer_cpu_offload=cpu_offload, + ) + optimizer = get_megatron_optimizer(config, [model]) + + if isinstance(optimizer, LayerWiseDistributedOptimizer): + for opt in optimizer.chained_optimizers: + init_fn = getattr(opt, 'init_state_fn', None) + if init_fn is None: + continue + if hasattr(opt, 'optimizer'): + init_fn(opt.optimizer) + else: + init_fn(opt) + if cpu_offload: + optimizer.offload_optimizer_states() + return optimizer + + +def _iter_fp16_opts( + optimizer: LayerWiseDistributedOptimizer, +) -> Generator[Float16OptimizerWithFloat16Params, None, None]: + """Yield Float16OptimizerWithFloat16Params sub-optimizers, skipping stubs.""" + for opt in optimizer.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if isinstance(opt, Float16OptimizerWithFloat16Params): + yield opt + + +class TestMuonCPUOffload: + """Tests for Muon CPU offloading in LayerWiseDistributedOptimizer. + + Verifies the correctness of the CPU offload mechanism that moves fp32 master + weights and momentum buffers between GPU and CPU pinned memory each step. + Tests cover state placement, round-trip fidelity, step execution, and + bit-exact numerical equivalence against the non-offloaded baseline. 
+ """ + + def setup_method(self, method) -> None: + pass + + def teardown_method(self, method) -> None: + Utils.destroy_model_parallel() + + @pytest.mark.parametrize('tp,pp', [(2, 2), (1, 4), (4, 1)]) + def test_states_on_cpu(self, tp: int, pp: int) -> None: + """After init with cpu_offload=True, fp32 master weights and state are on CPU.""" + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") + + Utils.initialize_model_parallel(tp, pp) + model = _create_model(seed=2, tp=tp, pp=pp) + optimizer = _create_optimizer(model, cpu_offload=True) + + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + assert optimizer._cpu_offload + + for opt in _iter_fp16_opts(optimizer): + for group in opt.fp32_from_float16_groups: + for param in group: + assert ( + not param.data.is_cuda + ), f"fp32 master weight should be on CPU, got {param.data.device}" + for state_vals in opt.optimizer.state.values(): + for key, val in state_vals.items(): + if isinstance(val, torch.Tensor): + assert ( + not val.is_cuda + ), f"optimizer state '{key}' should be on CPU, got {val.device}" + + @pytest.mark.parametrize('tp,pp', [(2, 2), (1, 4), (4, 1)]) + def test_roundtrip_correctness(self, tp: int, pp: int) -> None: + """Offload -> reload preserves fp32 master weight values exactly.""" + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") + + Utils.initialize_model_parallel(tp, pp) + model = _create_model(seed=2, tp=tp, pp=pp) + optimizer = _create_optimizer(model, cpu_offload=True) + + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + + snapshots = {} + for opt in _iter_fp16_opts(optimizer): + for gidx, group in enumerate(opt.fp32_from_float16_groups): + for pidx, param in enumerate(group): + snapshots[(id(opt), gidx, pidx)] = param.data.clone() + + optimizer.reload_optimizer_states() + + for opt in _iter_fp16_opts(optimizer): + for gidx, group in enumerate(opt.fp32_from_float16_groups): + for pidx, param in enumerate(group): + assert param.data.is_cuda, "After reload, param should be on GPU" + expected = snapshots[(id(opt), gidx, pidx)].to(param.data.device) + assert torch.equal( + param.data, expected + ), "Master weight mismatch after offload->reload roundtrip" + + optimizer.offload_optimizer_states() + + for opt in _iter_fp16_opts(optimizer): + for gidx, group in enumerate(opt.fp32_from_float16_groups): + for pidx, param in enumerate(group): + assert not param.data.is_cuda, "After offload, param should be on CPU" + + @pytest.mark.parametrize('tp,pp', [(2, 2), (1, 4), (4, 1)]) + def test_step_runs(self, tp: int, pp: int) -> None: + """A full optimizer.step() succeeds with CPU offloading.""" + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") + + Utils.initialize_model_parallel(tp, pp) + model = _create_model(seed=2, tp=tp, pp=pp) + optimizer = _create_optimizer(model, cpu_offload=True) + + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + + for param in model.parameters(): + if param.requires_grad: + g = torch.randn_like(param.data) + param.grad = g + param.main_grad = g + + update_successful, grad_norm, num_zeros = optimizer.step() + assert isinstance(update_successful, bool) + + for opt in _iter_fp16_opts(optimizer): + for group in opt.fp32_from_float16_groups: + for param in group: + assert ( + not param.data.is_cuda + ), "After step, fp32 master weights should be back on CPU" + + @pytest.mark.parametrize('tp,pp', [(2, 2), (4, 1)]) + @pytest.mark.parametrize('n_steps', [3, 5]) + def test_numerical_equivalence(self, tp: int, 
pp: int, n_steps: int) -> None: + """Offloaded and non-offloaded optimizers produce bit-identical results. + + Runs both an offloaded and a non-offloaded optimizer for ``n_steps`` + with identical random gradients, then verifies that fp32 master weights + and optimizer state tensors match exactly. This ensures the offload/reload + cycle introduces zero numerical drift. + """ + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") + + Utils.initialize_model_parallel(tp, pp) + + model_off = _create_model(seed=42, tp=tp, pp=pp) + model_ref = _create_model(seed=42, tp=tp, pp=pp) + + opt_off = _create_optimizer(model_off, cpu_offload=True) + opt_ref = _create_optimizer(model_ref, cpu_offload=False) + + assert isinstance(opt_off, LayerWiseDistributedOptimizer) + assert isinstance(opt_ref, LayerWiseDistributedOptimizer) + + rank = torch.distributed.get_rank() + + for step_i in range(n_steps): + torch.manual_seed(1000 + step_i + rank) + + for p_off, p_ref in zip(model_off.parameters(), model_ref.parameters()): + if not p_off.requires_grad: + continue + g = torch.randn_like(p_off.data) + p_off.grad = g.clone() + p_off.main_grad = p_off.grad + p_ref.grad = g.clone() + p_ref.main_grad = p_ref.grad + + opt_off.step() + opt_ref.step() + + opt_off.reload_optimizer_states() + + for opt_o, opt_r in zip(_iter_fp16_opts(opt_off), _iter_fp16_opts(opt_ref)): + for grp_o, grp_r in zip(opt_o.fp32_from_float16_groups, opt_r.fp32_from_float16_groups): + for pidx, (p_o, p_r) in enumerate(zip(grp_o, grp_r)): + p_o_gpu = p_o.data.to('cuda') if not p_o.data.is_cuda else p_o.data + assert torch.equal(p_o_gpu, p_r.data), ( + f"fp32 master weight mismatch at param {pidx} after {n_steps} steps, " + f"max diff = {(p_o_gpu - p_r.data).abs().max().item()}" + ) + + for (key_o, state_o), (key_r, state_r) in zip( + opt_o.optimizer.state.items(), opt_r.optimizer.state.items() + ): + common_keys = set(state_o.keys()) & set(state_r.keys()) + for skey in common_keys: + v_o, v_r = state_o[skey], state_r[skey] + if not isinstance(v_o, torch.Tensor): + continue + v_o_gpu = v_o.to('cuda') if not v_o.is_cuda else v_o + assert torch.equal(v_o_gpu, v_r), ( + f"optimizer state '{skey}' mismatch after {n_steps} steps, " + f"max diff = {(v_o_gpu - v_r).abs().max().item()}" + ) + + opt_off.offload_optimizer_states()