From 812349399d0e11e849d79a59201f848ebf8c3568 Mon Sep 17 00:00:00 2001 From: pengdurice Date: Sun, 26 Apr 2026 16:22:10 +0000 Subject: [PATCH 01/15] initial commit: add muon cpu offload --- megatron/core/optimizer/__init__.py | 6 +- .../core/optimizer/layer_wise_optimizer.py | 47 +++ muon_cpu_offloading.slurm | 38 +++ tests/test_muon_cpu_offload.py | 304 ++++++++++++++++++ .../test_layer_wise_optimizer.py | 121 +++++++ 5 files changed, 514 insertions(+), 2 deletions(-) create mode 100644 muon_cpu_offloading.slurm create mode 100644 tests/test_muon_cpu_offload.py diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index c6d3e41aed5..caa60ce23ea 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -492,8 +492,8 @@ def _get_megatron_optimizer_based_on_param_groups( raise ValueError( "skip_megatron_wrapping=True is incompatible with use_precision_aware_optimizer." ) - if skip_megatron_wrapping and config.optimizer_cpu_offload: - raise ValueError("skip_megatron_wrapping=True is incompatible with optimizer_cpu_offload.") + # NOTE: skip_megatron_wrapping + optimizer_cpu_offload is allowed for Muon, + # where LayerWiseDistributedOptimizer handles CPU offloading itself. # When freezing sub-models we may have no trainable parameters on a rank and # hence an empty param_groups. However, we still need to create an optimizer @@ -822,6 +822,8 @@ def _get_megatron_emerging_optimizer( fallback_config = copy.copy(config) fallback_config.optimizer = opt_name fallback_config.use_distributed_optimizer = False + if use_layer_wise: + fallback_config.optimizer_cpu_offload = False result = _get_megatron_optimizer_based_on_param_groups( config=fallback_config, model_chunks=model_chunks, diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index d0f64010bad..0459c42c4b4 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -91,6 +91,11 @@ def __init__( super().__init__(optimizers) + self._cpu_offload = getattr(config, 'optimizer_cpu_offload', False) + if self._cpu_offload: + self.offload_optimizer_states() + logger.info('[layerwise] optimizer states CPU offloading enabled') + # TODO(kunlun, deyuf): potential future perf optimization # since allreduce is unchanged and handled by megatron DDP, they're already in # contiguous gbuf. So instead of shard param by layer randomly, we can shard by @@ -279,6 +284,9 @@ def count_zeros(self): @torch.no_grad() def step(self): # type: ignore[no-untyped-def] """step function for layer-wise optimizer.""" + if self._cpu_offload: + self.reload_optimizer_states() + update_successful, grad_norm, num_zeros_in_grad = super().step() # All gather updated params. If overlap_param_gather is True, the allgather @@ -286,8 +294,47 @@ def step(self): # type: ignore[no-untyped-def] if not self.overlap_param_gather: self.allgather_params() + if self._cpu_offload: + self.offload_optimizer_states() + return update_successful, grad_norm, num_zeros_in_grad + @torch.no_grad() + def offload_optimizer_states(self): + """Move fp32 master weights and optimizer states to CPU pinned memory.""" + torch.cuda.synchronize() + for opt in self.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if not isinstance(opt, Float16OptimizerWithFloat16Params): + continue + for group in opt.fp32_from_float16_groups: + for param in group: + if param.data.is_cuda: + param.data = param.data.cpu().pin_memory() + for state_vals in opt.optimizer.state.values(): + for key, val in state_vals.items(): + if isinstance(val, torch.Tensor) and val.is_cuda: + state_vals[key] = val.cpu().pin_memory() + + @torch.no_grad() + def reload_optimizer_states(self): + """Move fp32 master weights and optimizer states back to GPU.""" + for opt in self.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if not isinstance(opt, Float16OptimizerWithFloat16Params): + continue + for group in opt.fp32_from_float16_groups: + for param in group: + if not param.data.is_cuda: + param.data = param.data.to('cuda') + for state_vals in opt.optimizer.state.values(): + for key, val in state_vals.items(): + if isinstance(val, torch.Tensor) and not val.is_cuda: + state_vals[key] = val.to('cuda') + torch.cuda.synchronize() + # TODO(deyuf): need to improve dist checkpointing design to properly handle this # fp32_from_fp16_params is list, each sub list could be empty if group is empty # this breaks dist checkpointing assumption since extract_sharded_base drop list structure diff --git a/muon_cpu_offloading.slurm b/muon_cpu_offloading.slurm new file mode 100644 index 00000000000..5c74aa8903e --- /dev/null +++ b/muon_cpu_offloading.slurm @@ -0,0 +1,38 @@ +#!/bin/bash +#SBATCH --job-name=muon-cpu-offload +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=4 +#SBATCH --output=./logs/muon_cpu_offload-%j.out +#SBATCH --error=./logs/muon_cpu_offload-%j.err +#SBATCH --exclusive +#SBATCH --time=0:30:00 +#SBATCH --partition=h200-hi-preempt +#SBATCH --exclude=ip-10-1-115-124,ip-10-1-45-16 +#SBATCH --wait-all-nodes=1 + +set -euxo pipefail + +IMAGE=/fsx/peng/containers/nemo-rl-v0.5.2-te-tilelang.sqsh + +MEGATRON_SRC=/fsx/peng/workspace/megatron-muon-cpu-offload/Megatron-LM +MEGATRON_DST=/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM + +MOUNTS=${MEGATRON_SRC}:${MEGATRON_DST} + +srun \ + --container-image "${IMAGE}" \ + --container-mounts "${MOUNTS}" \ + --container-remap-root \ + bash -lc " + set -euxo pipefail + cd ${MEGATRON_DST} + + export NVIDIA_PYTORCH_VERSION=25.06 + export PYTHONPATH=${MEGATRON_DST}:\${PYTHONPATH:-} + pip install git+https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git@v0.2.0 + + torchrun --nproc-per-node=4 tests/test_muon_cpu_offload.py + " + +echo "Test completed." diff --git a/tests/test_muon_cpu_offload.py b/tests/test_muon_cpu_offload.py new file mode 100644 index 00000000000..cebcaf3198b --- /dev/null +++ b/tests/test_muon_cpu_offload.py @@ -0,0 +1,304 @@ +"""Standalone tests for Muon CPU offloading in LayerWiseDistributedOptimizer. + +Run with: + torchrun --nproc-per-node=4 tests/test_muon_cpu_offload.py + +Avoids the pytest conftest circular-import issue by running as a plain script. +""" + +import os +import sys +import traceback +from datetime import timedelta + +import torch +import torch.distributed + +from megatron.core import parallel_state +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.optimizer import get_megatron_optimizer +from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer +from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params +from megatron.core.optimizer.optimizer_config import OptimizerConfig +from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed +from megatron.core.transformer import TransformerConfig + + +def init_distributed(): + rank = int(os.environ['LOCAL_RANK']) + world_size = int(os.environ['WORLD_SIZE']) + torch.cuda.set_device(rank) + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + backend='nccl', world_size=world_size, rank=rank, + timeout=timedelta(minutes=2), + ) + return rank, world_size + + +def create_model(seed, tp, pp): + torch.manual_seed(seed) + model_parallel_cuda_manual_seed(seed) + config = TransformerConfig( + num_layers=6, + hidden_size=16, + num_attention_heads=8, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + ) + layer_spec = get_gpt_layer_with_transformer_engine_spec() + model = GPTModel( + config=config, + transformer_layer_spec=layer_spec, + vocab_size=128, + max_sequence_length=4, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + ) + model.cuda(torch.cuda.current_device()) + return model + + +def create_optimizer_with_cpu_offload(model, cpu_offload=True): + config = OptimizerConfig( + bf16=True, + params_dtype=torch.bfloat16, + use_distributed_optimizer=False, + use_layer_wise_distributed_optimizer=True, + optimizer='muon', + lr=0.0, + optimizer_cpu_offload=cpu_offload, + ) + optimizer = get_megatron_optimizer(config, [model]) + + if isinstance(optimizer, LayerWiseDistributedOptimizer): + for opt in optimizer.chained_optimizers: + if getattr(opt, 'init_state_fn', None) is None: + continue + if not hasattr(opt, 'optimizer'): + opt.init_state_fn(opt) + else: + opt.init_state_fn(opt.optimizer) + if cpu_offload: + optimizer.offload_optimizer_states() + return optimizer + + +def _iter_fp16_opts(optimizer): + for opt in optimizer.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if isinstance(opt, Float16OptimizerWithFloat16Params): + yield opt + + +def test_states_on_cpu(rank): + """After init with cpu_offload=True, fp32 master weights and state are on CPU.""" + parallel_state.initialize_model_parallel(2, 2) + try: + model = create_model(seed=2, tp=2, pp=2) + optimizer = create_optimizer_with_cpu_offload(model, cpu_offload=True) + + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + assert optimizer._cpu_offload + + for opt in _iter_fp16_opts(optimizer): + for group in opt.fp32_from_float16_groups: + for param in group: + assert not param.data.is_cuda, ( + f"[rank {rank}] fp32 master weight should be on CPU, " + f"got {param.data.device}" + ) + for state_vals in opt.optimizer.state.values(): + for key, val in state_vals.items(): + if isinstance(val, torch.Tensor): + assert not val.is_cuda, ( + f"[rank {rank}] optimizer state '{key}' should be on CPU, " + f"got {val.device}" + ) + + print(f" [rank {rank}] PASSED: test_states_on_cpu") + finally: + parallel_state.destroy_model_parallel() + + +def test_roundtrip_correctness(rank): + """Offload -> reload preserves fp32 master weight values exactly.""" + parallel_state.initialize_model_parallel(2, 2) + try: + model = create_model(seed=2, tp=2, pp=2) + optimizer = create_optimizer_with_cpu_offload(model, cpu_offload=True) + + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + + snapshots = {} + for opt in _iter_fp16_opts(optimizer): + for gidx, group in enumerate(opt.fp32_from_float16_groups): + for pidx, param in enumerate(group): + snapshots[(id(opt), gidx, pidx)] = param.data.clone() + + optimizer.reload_optimizer_states() + + for opt in _iter_fp16_opts(optimizer): + for gidx, group in enumerate(opt.fp32_from_float16_groups): + for pidx, param in enumerate(group): + assert param.data.is_cuda, ( + f"[rank {rank}] After reload, param should be on GPU" + ) + expected = snapshots[(id(opt), gidx, pidx)].to(param.data.device) + assert torch.equal(param.data, expected), ( + f"[rank {rank}] Master weight mismatch after roundtrip" + ) + + optimizer.offload_optimizer_states() + + for opt in _iter_fp16_opts(optimizer): + for gidx, group in enumerate(opt.fp32_from_float16_groups): + for pidx, param in enumerate(group): + assert not param.data.is_cuda, ( + f"[rank {rank}] After offload, param should be on CPU" + ) + + print(f" [rank {rank}] PASSED: test_roundtrip_correctness") + finally: + parallel_state.destroy_model_parallel() + + +def test_step_runs(rank): + """A full optimizer.step() succeeds with CPU offloading.""" + parallel_state.initialize_model_parallel(2, 2) + try: + model = create_model(seed=2, tp=2, pp=2) + optimizer = create_optimizer_with_cpu_offload(model, cpu_offload=True) + + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + + for param in model.parameters(): + if param.requires_grad: + g = torch.randn_like(param.data) + param.grad = g + param.main_grad = g + + update_successful, grad_norm, num_zeros = optimizer.step() + assert isinstance(update_successful, bool), ( + f"[rank {rank}] update_successful should be bool, got {type(update_successful)}" + ) + + for opt in _iter_fp16_opts(optimizer): + for group in opt.fp32_from_float16_groups: + for param in group: + assert not param.data.is_cuda, ( + f"[rank {rank}] After step, fp32 master weights should be on CPU" + ) + + print(f" [rank {rank}] PASSED: test_step_runs") + finally: + parallel_state.destroy_model_parallel() + + +def test_numerical_equivalence(rank, n_steps=5): + """Offloaded and non-offloaded optimizers produce identical fp32 master weights.""" + parallel_state.initialize_model_parallel(2, 2) + try: + model_off = create_model(seed=42, tp=2, pp=2) + model_ref = create_model(seed=42, tp=2, pp=2) + + opt_off = create_optimizer_with_cpu_offload(model_off, cpu_offload=True) + opt_ref = create_optimizer_with_cpu_offload(model_ref, cpu_offload=False) + + assert isinstance(opt_off, LayerWiseDistributedOptimizer) + assert isinstance(opt_ref, LayerWiseDistributedOptimizer) + + for step_i in range(n_steps): + torch.manual_seed(1000 + step_i + rank) + + for p_off, p_ref in zip(model_off.parameters(), model_ref.parameters()): + if not p_off.requires_grad: + continue + g = torch.randn_like(p_off.data) + p_off.grad = g.clone() + p_off.main_grad = p_off.grad + p_ref.grad = g.clone() + p_ref.main_grad = p_ref.grad + + opt_off.step() + opt_ref.step() + + opt_off.reload_optimizer_states() + + for opt_o, opt_r in zip( + _iter_fp16_opts(opt_off), _iter_fp16_opts(opt_ref) + ): + for grp_o, grp_r in zip( + opt_o.fp32_from_float16_groups, opt_r.fp32_from_float16_groups + ): + for pidx, (p_o, p_r) in enumerate(zip(grp_o, grp_r)): + p_o_gpu = p_o.data.to('cuda') if not p_o.data.is_cuda else p_o.data + assert torch.equal(p_o_gpu, p_r.data), ( + f"[rank {rank}] fp32 master weight mismatch at param {pidx} " + f"after {n_steps} steps, " + f"max diff = {(p_o_gpu - p_r.data).abs().max().item()}" + ) + + for (key_o, state_o), (key_r, state_r) in zip( + opt_o.optimizer.state.items(), opt_r.optimizer.state.items() + ): + common_keys = set(state_o.keys()) & set(state_r.keys()) + for skey in common_keys: + v_o, v_r = state_o[skey], state_r[skey] + if not isinstance(v_o, torch.Tensor): + continue + v_o_gpu = v_o.to('cuda') if not v_o.is_cuda else v_o + assert torch.equal(v_o_gpu, v_r), ( + f"[rank {rank}] optimizer state '{skey}' mismatch " + f"after {n_steps} steps, " + f"max diff = {(v_o_gpu - v_r).abs().max().item()}" + ) + + opt_off.offload_optimizer_states() + + print(f" [rank {rank}] PASSED: test_numerical_equivalence ({n_steps} steps)") + finally: + parallel_state.destroy_model_parallel() + + +def main(): + rank, world_size = init_distributed() + + tests = [ + ("test_states_on_cpu", test_states_on_cpu), + ("test_roundtrip_correctness", test_roundtrip_correctness), + ("test_step_runs", test_step_runs), + ("test_numerical_equivalence", test_numerical_equivalence), + ] + + passed, failed = 0, 0 + for name, fn in tests: + torch.distributed.barrier() + if rank == 0: + print(f"\n{'='*60}") + print(f"Running: {name}") + print(f"{'='*60}") + try: + fn(rank) + passed += 1 + except Exception: + failed += 1 + if rank == 0: + traceback.print_exc() + print(f" [rank {rank}] FAILED: {name}") + + torch.distributed.barrier() + if rank == 0: + print(f"\n{'='*60}") + print(f"Results: {passed} passed, {failed} failed out of {len(tests)}") + print(f"{'='*60}") + + torch.distributed.destroy_process_group() + sys.exit(1 if failed > 0 else 0) + + +if __name__ == '__main__': + main() diff --git a/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py b/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py index 3f60658a005..8d36fde3069 100644 --- a/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py +++ b/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py @@ -17,6 +17,7 @@ from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.optimizer import ChainedOptimizer from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer +from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed from megatron.core.transformer import MLATransformerConfig, TransformerConfig @@ -604,3 +605,123 @@ def test_optimizer_common_state_dict( check_equal(optim_param_state_A, optim_param_state_B) Utils.destroy_model_parallel() + + +class TestMuonCPUOffload: + """Tests for Muon CPU offloading in LayerWiseDistributedOptimizer.""" + + def setup_method(self, method): + pass + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _create_cpu_offload_optimizer(self, tp=2, pp=2, seed=2): + """Helper: build a Muon LayerWise optimizer with CPU offloading.""" + from megatron.core.optimizer import get_megatron_optimizer + from megatron.core.optimizer.optimizer_config import OptimizerConfig + + Utils.initialize_model_parallel(tp, pp) + model = initialize_real_model( + seed=seed, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + ) + model.cuda(torch.cuda.current_device()) + + config = OptimizerConfig( + bf16=True, + params_dtype=torch.bfloat16, + use_distributed_optimizer=False, + use_layer_wise_distributed_optimizer=True, + optimizer='muon', + lr=0.0, + optimizer_cpu_offload=True, + ) + + optimizer = get_megatron_optimizer(config, [model]) + + if isinstance(optimizer, LayerWiseDistributedOptimizer): + for opt in optimizer.chained_optimizers: + if not hasattr(opt, 'optimizer'): + opt.init_state_fn(opt) + else: + opt.init_state_fn(opt.optimizer) + return model, optimizer + + def test_cpu_offload_states_on_cpu(self): + """After init, Muon fp32 master weights and momentum are on CPU.""" + model, optimizer = self._create_cpu_offload_optimizer() + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + assert optimizer._cpu_offload + + for opt in optimizer.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if not isinstance(opt, Float16OptimizerWithFloat16Params): + continue + for group in opt.fp32_from_float16_groups: + for param in group: + assert not param.data.is_cuda, ( + f"fp32 master weight should be on CPU, got {param.data.device}" + ) + for state_vals in opt.optimizer.state.values(): + for key, val in state_vals.items(): + if isinstance(val, torch.Tensor): + assert not val.is_cuda, ( + f"optimizer state '{key}' should be on CPU, got {val.device}" + ) + + def test_cpu_offload_roundtrip_correctness(self): + """Offload -> reload preserves fp32 master weight values.""" + model, optimizer = self._create_cpu_offload_optimizer() + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + + snapshots = {} + for opt in optimizer.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if not isinstance(opt, Float16OptimizerWithFloat16Params): + continue + for gidx, group in enumerate(opt.fp32_from_float16_groups): + for pidx, param in enumerate(group): + snapshots[(id(opt), gidx, pidx)] = param.data.clone() + + optimizer.reload_optimizer_states() + + for opt in optimizer.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if not isinstance(opt, Float16OptimizerWithFloat16Params): + continue + for gidx, group in enumerate(opt.fp32_from_float16_groups): + for pidx, param in enumerate(group): + assert param.data.is_cuda, "After reload, param should be on GPU" + key = (id(opt), gidx, pidx) + expected = snapshots[key].to(param.data.device) + assert torch.equal(param.data, expected), ( + "Master weight mismatch after offload->reload roundtrip" + ) + + optimizer.offload_optimizer_states() + + for opt in optimizer.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if not isinstance(opt, Float16OptimizerWithFloat16Params): + continue + for gidx, group in enumerate(opt.fp32_from_float16_groups): + for pidx, param in enumerate(group): + assert not param.data.is_cuda, "After offload, param should be on CPU" + + def test_cpu_offload_step_runs(self): + """A full optimizer step works with CPU offloading enabled.""" + model, optimizer = self._create_cpu_offload_optimizer() + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + + for param in model.parameters(): + if param.requires_grad: + param.grad = torch.randn_like(param.data) + + update_successful, grad_norm, num_zeros = optimizer.step() + assert isinstance(update_successful, bool) From f1c4c187ca5687509443e2b1d9fd4de19f04890d Mon Sep 17 00:00:00 2001 From: pengdurice Date: Thu, 30 Apr 2026 20:14:39 +0000 Subject: [PATCH 02/15] move test cases Signed-off-by: pengdurice --- muon_cpu_offloading.slurm | 38 ---- .../test_muon_cpu_offload.py | 208 ++++++------------ 2 files changed, 72 insertions(+), 174 deletions(-) delete mode 100644 muon_cpu_offloading.slurm rename tests/{ => unit_tests/dist_checkpointing}/test_muon_cpu_offload.py (53%) diff --git a/muon_cpu_offloading.slurm b/muon_cpu_offloading.slurm deleted file mode 100644 index 5c74aa8903e..00000000000 --- a/muon_cpu_offloading.slurm +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=muon-cpu-offload -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=4 -#SBATCH --output=./logs/muon_cpu_offload-%j.out -#SBATCH --error=./logs/muon_cpu_offload-%j.err -#SBATCH --exclusive -#SBATCH --time=0:30:00 -#SBATCH --partition=h200-hi-preempt -#SBATCH --exclude=ip-10-1-115-124,ip-10-1-45-16 -#SBATCH --wait-all-nodes=1 - -set -euxo pipefail - -IMAGE=/fsx/peng/containers/nemo-rl-v0.5.2-te-tilelang.sqsh - -MEGATRON_SRC=/fsx/peng/workspace/megatron-muon-cpu-offload/Megatron-LM -MEGATRON_DST=/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM - -MOUNTS=${MEGATRON_SRC}:${MEGATRON_DST} - -srun \ - --container-image "${IMAGE}" \ - --container-mounts "${MOUNTS}" \ - --container-remap-root \ - bash -lc " - set -euxo pipefail - cd ${MEGATRON_DST} - - export NVIDIA_PYTORCH_VERSION=25.06 - export PYTHONPATH=${MEGATRON_DST}:\${PYTHONPATH:-} - pip install git+https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git@v0.2.0 - - torchrun --nproc-per-node=4 tests/test_muon_cpu_offload.py - " - -echo "Test completed." diff --git a/tests/test_muon_cpu_offload.py b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py similarity index 53% rename from tests/test_muon_cpu_offload.py rename to tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py index cebcaf3198b..18cc09682e9 100644 --- a/tests/test_muon_cpu_offload.py +++ b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py @@ -1,43 +1,22 @@ -"""Standalone tests for Muon CPU offloading in LayerWiseDistributedOptimizer. - -Run with: - torchrun --nproc-per-node=4 tests/test_muon_cpu_offload.py - -Avoids the pytest conftest circular-import issue by running as a plain script. -""" - -import os -import sys -import traceback -from datetime import timedelta +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +import pytest import torch -import torch.distributed from megatron.core import parallel_state -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.optimizer import get_megatron_optimizer from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params from megatron.core.optimizer.optimizer_config import OptimizerConfig from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.transformer import TransformerConfig +from tests.unit_tests.test_utilities import Utils -def init_distributed(): - rank = int(os.environ['LOCAL_RANK']) - world_size = int(os.environ['WORLD_SIZE']) - torch.cuda.set_device(rank) - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group( - backend='nccl', world_size=world_size, rank=rank, - timeout=timedelta(minutes=2), - ) - return rank, world_size - - -def create_model(seed, tp, pp): +def _create_model(seed, tp, pp): + """Create a small GPT model for testing.""" torch.manual_seed(seed) model_parallel_cuda_manual_seed(seed) config = TransformerConfig( @@ -61,7 +40,8 @@ def create_model(seed, tp, pp): return model -def create_optimizer_with_cpu_offload(model, cpu_offload=True): +def _create_optimizer(model, cpu_offload=True): + """Create a Muon LayerWise optimizer with optional CPU offloading.""" config = OptimizerConfig( bf16=True, params_dtype=torch.bfloat16, @@ -75,18 +55,20 @@ def create_optimizer_with_cpu_offload(model, cpu_offload=True): if isinstance(optimizer, LayerWiseDistributedOptimizer): for opt in optimizer.chained_optimizers: - if getattr(opt, 'init_state_fn', None) is None: + init_fn = getattr(opt, 'init_state_fn', None) + if init_fn is None: continue - if not hasattr(opt, 'optimizer'): - opt.init_state_fn(opt) + if hasattr(opt, 'optimizer'): + init_fn(opt.optimizer) else: - opt.init_state_fn(opt.optimizer) + init_fn(opt) if cpu_offload: optimizer.offload_optimizer_states() return optimizer def _iter_fp16_opts(optimizer): + """Yield Float16OptimizerWithFloat16Params sub-optimizers.""" for opt in optimizer.chained_optimizers: if getattr(opt, 'is_stub_optimizer', False): continue @@ -94,12 +76,24 @@ def _iter_fp16_opts(optimizer): yield opt -def test_states_on_cpu(rank): - """After init with cpu_offload=True, fp32 master weights and state are on CPU.""" - parallel_state.initialize_model_parallel(2, 2) - try: - model = create_model(seed=2, tp=2, pp=2) - optimizer = create_optimizer_with_cpu_offload(model, cpu_offload=True) +class TestMuonCPUOffload: + """Tests for Muon CPU offloading in LayerWiseDistributedOptimizer.""" + + def setup_method(self, method): + pass + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.parametrize('tp,pp', [(2, 2), (1, 4), (4, 1)]) + def test_states_on_cpu(self, tp, pp): + """After init with cpu_offload=True, fp32 master weights and state are on CPU.""" + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") + + Utils.initialize_model_parallel(tp, pp) + model = _create_model(seed=2, tp=tp, pp=pp) + optimizer = _create_optimizer(model, cpu_offload=True) assert isinstance(optimizer, LayerWiseDistributedOptimizer) assert optimizer._cpu_offload @@ -108,28 +102,24 @@ def test_states_on_cpu(rank): for group in opt.fp32_from_float16_groups: for param in group: assert not param.data.is_cuda, ( - f"[rank {rank}] fp32 master weight should be on CPU, " - f"got {param.data.device}" + f"fp32 master weight should be on CPU, got {param.data.device}" ) for state_vals in opt.optimizer.state.values(): for key, val in state_vals.items(): if isinstance(val, torch.Tensor): assert not val.is_cuda, ( - f"[rank {rank}] optimizer state '{key}' should be on CPU, " - f"got {val.device}" + f"optimizer state '{key}' should be on CPU, got {val.device}" ) - print(f" [rank {rank}] PASSED: test_states_on_cpu") - finally: - parallel_state.destroy_model_parallel() + @pytest.mark.parametrize('tp,pp', [(2, 2), (1, 4), (4, 1)]) + def test_roundtrip_correctness(self, tp, pp): + """Offload -> reload preserves fp32 master weight values exactly.""" + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") - -def test_roundtrip_correctness(rank): - """Offload -> reload preserves fp32 master weight values exactly.""" - parallel_state.initialize_model_parallel(2, 2) - try: - model = create_model(seed=2, tp=2, pp=2) - optimizer = create_optimizer_with_cpu_offload(model, cpu_offload=True) + Utils.initialize_model_parallel(tp, pp) + model = _create_model(seed=2, tp=tp, pp=pp) + optimizer = _create_optimizer(model, cpu_offload=True) assert isinstance(optimizer, LayerWiseDistributedOptimizer) @@ -144,12 +134,10 @@ def test_roundtrip_correctness(rank): for opt in _iter_fp16_opts(optimizer): for gidx, group in enumerate(opt.fp32_from_float16_groups): for pidx, param in enumerate(group): - assert param.data.is_cuda, ( - f"[rank {rank}] After reload, param should be on GPU" - ) + assert param.data.is_cuda, "After reload, param should be on GPU" expected = snapshots[(id(opt), gidx, pidx)].to(param.data.device) assert torch.equal(param.data, expected), ( - f"[rank {rank}] Master weight mismatch after roundtrip" + "Master weight mismatch after offload->reload roundtrip" ) optimizer.offload_optimizer_states() @@ -157,21 +145,17 @@ def test_roundtrip_correctness(rank): for opt in _iter_fp16_opts(optimizer): for gidx, group in enumerate(opt.fp32_from_float16_groups): for pidx, param in enumerate(group): - assert not param.data.is_cuda, ( - f"[rank {rank}] After offload, param should be on CPU" - ) - - print(f" [rank {rank}] PASSED: test_roundtrip_correctness") - finally: - parallel_state.destroy_model_parallel() + assert not param.data.is_cuda, "After offload, param should be on CPU" + @pytest.mark.parametrize('tp,pp', [(2, 2), (1, 4), (4, 1)]) + def test_step_runs(self, tp, pp): + """A full optimizer.step() succeeds with CPU offloading.""" + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") -def test_step_runs(rank): - """A full optimizer.step() succeeds with CPU offloading.""" - parallel_state.initialize_model_parallel(2, 2) - try: - model = create_model(seed=2, tp=2, pp=2) - optimizer = create_optimizer_with_cpu_offload(model, cpu_offload=True) + Utils.initialize_model_parallel(tp, pp) + model = _create_model(seed=2, tp=tp, pp=pp) + optimizer = _create_optimizer(model, cpu_offload=True) assert isinstance(optimizer, LayerWiseDistributedOptimizer) @@ -182,35 +166,35 @@ def test_step_runs(rank): param.main_grad = g update_successful, grad_norm, num_zeros = optimizer.step() - assert isinstance(update_successful, bool), ( - f"[rank {rank}] update_successful should be bool, got {type(update_successful)}" - ) + assert isinstance(update_successful, bool) for opt in _iter_fp16_opts(optimizer): for group in opt.fp32_from_float16_groups: for param in group: assert not param.data.is_cuda, ( - f"[rank {rank}] After step, fp32 master weights should be on CPU" + "After step, fp32 master weights should be back on CPU" ) - print(f" [rank {rank}] PASSED: test_step_runs") - finally: - parallel_state.destroy_model_parallel() + @pytest.mark.parametrize('tp,pp', [(2, 2), (4, 1)]) + @pytest.mark.parametrize('n_steps', [3, 5]) + def test_numerical_equivalence(self, tp, pp, n_steps): + """Offloaded and non-offloaded optimizers produce identical results.""" + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") + Utils.initialize_model_parallel(tp, pp) -def test_numerical_equivalence(rank, n_steps=5): - """Offloaded and non-offloaded optimizers produce identical fp32 master weights.""" - parallel_state.initialize_model_parallel(2, 2) - try: - model_off = create_model(seed=42, tp=2, pp=2) - model_ref = create_model(seed=42, tp=2, pp=2) + model_off = _create_model(seed=42, tp=tp, pp=pp) + model_ref = _create_model(seed=42, tp=tp, pp=pp) - opt_off = create_optimizer_with_cpu_offload(model_off, cpu_offload=True) - opt_ref = create_optimizer_with_cpu_offload(model_ref, cpu_offload=False) + opt_off = _create_optimizer(model_off, cpu_offload=True) + opt_ref = _create_optimizer(model_ref, cpu_offload=False) assert isinstance(opt_off, LayerWiseDistributedOptimizer) assert isinstance(opt_ref, LayerWiseDistributedOptimizer) + rank = torch.distributed.get_rank() + for step_i in range(n_steps): torch.manual_seed(1000 + step_i + rank) @@ -228,17 +212,14 @@ def test_numerical_equivalence(rank, n_steps=5): opt_off.reload_optimizer_states() - for opt_o, opt_r in zip( - _iter_fp16_opts(opt_off), _iter_fp16_opts(opt_ref) - ): + for opt_o, opt_r in zip(_iter_fp16_opts(opt_off), _iter_fp16_opts(opt_ref)): for grp_o, grp_r in zip( opt_o.fp32_from_float16_groups, opt_r.fp32_from_float16_groups ): for pidx, (p_o, p_r) in enumerate(zip(grp_o, grp_r)): p_o_gpu = p_o.data.to('cuda') if not p_o.data.is_cuda else p_o.data assert torch.equal(p_o_gpu, p_r.data), ( - f"[rank {rank}] fp32 master weight mismatch at param {pidx} " - f"after {n_steps} steps, " + f"fp32 master weight mismatch at param {pidx} after {n_steps} steps, " f"max diff = {(p_o_gpu - p_r.data).abs().max().item()}" ) @@ -252,53 +233,8 @@ def test_numerical_equivalence(rank, n_steps=5): continue v_o_gpu = v_o.to('cuda') if not v_o.is_cuda else v_o assert torch.equal(v_o_gpu, v_r), ( - f"[rank {rank}] optimizer state '{skey}' mismatch " - f"after {n_steps} steps, " + f"optimizer state '{skey}' mismatch after {n_steps} steps, " f"max diff = {(v_o_gpu - v_r).abs().max().item()}" ) opt_off.offload_optimizer_states() - - print(f" [rank {rank}] PASSED: test_numerical_equivalence ({n_steps} steps)") - finally: - parallel_state.destroy_model_parallel() - - -def main(): - rank, world_size = init_distributed() - - tests = [ - ("test_states_on_cpu", test_states_on_cpu), - ("test_roundtrip_correctness", test_roundtrip_correctness), - ("test_step_runs", test_step_runs), - ("test_numerical_equivalence", test_numerical_equivalence), - ] - - passed, failed = 0, 0 - for name, fn in tests: - torch.distributed.barrier() - if rank == 0: - print(f"\n{'='*60}") - print(f"Running: {name}") - print(f"{'='*60}") - try: - fn(rank) - passed += 1 - except Exception: - failed += 1 - if rank == 0: - traceback.print_exc() - print(f" [rank {rank}] FAILED: {name}") - - torch.distributed.barrier() - if rank == 0: - print(f"\n{'='*60}") - print(f"Results: {passed} passed, {failed} failed out of {len(tests)}") - print(f"{'='*60}") - - torch.distributed.destroy_process_group() - sys.exit(1 if failed > 0 else 0) - - -if __name__ == '__main__': - main() From 791c34e675ac76516c56f181e3049c79ea19d3b1 Mon Sep 17 00:00:00 2001 From: pengdurice Date: Thu, 30 Apr 2026 20:30:07 +0000 Subject: [PATCH 03/15] add style related changes Signed-off-by: pengdurice --- megatron/core/optimizer/__init__.py | 4 + .../core/optimizer/layer_wise_optimizer.py | 96 ++++-- muon_slurm.slurm | 38 +++ tests/test_muon_cpu_offload.py | 304 ++++++++++++++++++ .../test_layer_wise_optimizer.py | 120 ------- .../test_muon_cpu_offload.py | 74 ++++- 6 files changed, 478 insertions(+), 158 deletions(-) create mode 100644 muon_slurm.slurm create mode 100644 tests/test_muon_cpu_offload.py diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index caa60ce23ea..1da120a77bd 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -823,6 +823,10 @@ def _get_megatron_emerging_optimizer( fallback_config.optimizer = opt_name fallback_config.use_distributed_optimizer = False if use_layer_wise: + # Disable per-optimizer CPU offload (HybridDeviceOptimizer) for the + # Adam fallback when LayerWiseDistributedOptimizer is active. + # CPU offloading is handled uniformly by LayerWiseDistributedOptimizer + # for all sub-optimizers (Muon + Adam), preventing double-offloading. fallback_config.optimizer_cpu_offload = False result = _get_megatron_optimizer_based_on_param_groups( config=fallback_config, diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index 0459c42c4b4..d5ba13fb9c5 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -1,7 +1,7 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import logging -from typing import Callable, List, Optional +from typing import Callable, Dict, List, Optional, Tuple import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors @@ -27,16 +27,40 @@ class LayerWiseDistributedOptimizer(ChainedOptimizer): """Layer-wise distributed optimizer for Megatron-core models. Experimental distributed optimizer wrapper that distributes weight to DP ranks by layer. - Implemented as ChainedOptimizer to support multiple optimizers (e.g. muon + adamW) + Implemented as ChainedOptimizer to support multiple optimizers (e.g. muon + adamW). When using, keep all megatron distributed-optimizer related options OFF. - How LayerWiseDistributedOptimizer work: - 1. weights are splited into lists and each rank only keep its shard in its optimizer - 2. Megatron DDP handle allreduce grad, note that each rank have full model and grad - 3. optimizer is already modified so only param belong to this DP rank is updated - 4. grad_norm and zero counting will reduce metrics globally in step function - 5. Do regular update with chained optimizers, modified optimizer only update shard - 6. allgather updated params to every rank + How LayerWiseDistributedOptimizer works: + + 1. Weights are split into lists and each rank only keeps its shard in its optimizer. + 2. Megatron DDP handles allreduce grad; each rank has full model and grad. + 3. Optimizer is modified so only params belonging to this DP rank are updated. + 4. grad_norm and zero counting reduce metrics globally in step function. + 5. Regular update with chained optimizers; modified optimizer only updates shard. + 6. All-gather (or broadcast) updated params to every rank. + + CPU Offloading: + + When ``optimizer_cpu_offload=True`` in the config, this optimizer manages a + host-device-host (H2D/D2H) cycle for the fp32 master weights and momentum + buffers owned by the wrapped ``Float16OptimizerWithFloat16Params`` sub-optimizers. + This is particularly beneficial for Muon, where most model parameters are + "muonable" and their fp32 master weights + momentum constitute the majority + of optimizer memory. + + The offload lifecycle per training step: + + 1. **reload_optimizer_states()**: Move fp32 master weights and optimizer state + tensors from CPU pinned memory back to GPU before the optimizer step. + 2. **super().step()**: Run the actual optimizer update (e.g. Newton-Schulz for + Muon, Adam for fallback params) on GPU. + 3. **broadcast_params()**: Synchronize updated bf16 model params across DP ranks. + 4. **offload_optimizer_states()**: Move fp32 master weights and optimizer state + tensors back to CPU pinned memory to free GPU memory. + + Note: The Adam fallback optimizer's ``optimizer_cpu_offload`` is set to False + when ``use_layer_wise_distributed_optimizer=True``, preventing double-offloading + via ``HybridDeviceOptimizer``. All offloading is unified through this class. """ def __init__( @@ -103,7 +127,7 @@ def __init__( # This way each rank do some duplicated work but allgather_v is no longer needed # All current distopt optimization can also be potentially applied - def shard_params(self, optimizers): + def shard_params(self, optimizers: List[MegatronOptimizer]) -> None: """Shard all params into lists by rank.""" # list of parameter are sorted by numel and assigned to ranks in ping-pong style # example of 4 ranks and 10 parameters p0-p9 after sorting, then dp_cp_params_list will be @@ -157,7 +181,7 @@ def shard_params(self, optimizers): if expt_dp_size == 1 or len(self.expt_dp_params_list[0]) == 0: self.expt_dp_params_list = None - def set_bucket_layerwise_params_list(self, model_chunks): + def set_bucket_layerwise_params_list(self, model_chunks: List) -> None: """Map sharded params to DDP buckets for async all-gather. For each bucket in each model chunk's bucket groups, build per-rank param lists @@ -245,9 +269,8 @@ def _allgather_helper(params_list, group): _allgather_helper(self.expt_dp_params_list, self.pg_collection.expt_dp) @torch.no_grad() - def broadcast_params(self): - """All rank broadcast updated local params.""" - # Broadcast linear layer weights to all other ranks. Kept as reference test. + def broadcast_params(self) -> None: + """Broadcast updated params from owning rank to all other DP ranks.""" if self.dp_cp_params_list is None: return for i, params in enumerate(self.dp_cp_params_list): @@ -262,8 +285,8 @@ def broadcast_params(self): torch.distributed.broadcast(p, src_global_rank, self.pg_collection.expt_dp) @torch.no_grad() - def get_grad_norm(self): - # similar to dist opt, always aggregate globally + def get_grad_norm(self) -> torch.Tensor: + """Compute global grad norm aggregated across all DP ranks.""" grads_for_norm = [] for optimizer in self.chained_optimizers: grads_for_norm += optimizer.get_main_grads_for_grad_norm() @@ -271,7 +294,8 @@ def get_grad_norm(self): return grad_norm @torch.no_grad() - def count_zeros(self): + def count_zeros(self) -> torch.Tensor: + """Count zero-valued gradients aggregated across all DP ranks.""" params = [] for optimizer in self.chained_optimizers: params += optimizer.get_parameters() @@ -282,8 +306,15 @@ def count_zeros(self): ) @torch.no_grad() - def step(self): # type: ignore[no-untyped-def] - """step function for layer-wise optimizer.""" + def step(self) -> Tuple[bool, Optional[float], Optional[int]]: + """Perform a single optimization step with optional CPU offloading. + + When CPU offloading is enabled, this method orchestrates the full cycle: + reload states to GPU -> optimizer step -> broadcast params -> offload states to CPU. + + Returns: + Tuple of (update_successful, grad_norm, num_zeros_in_grad). + """ if self._cpu_offload: self.reload_optimizer_states() @@ -300,8 +331,17 @@ def step(self): # type: ignore[no-untyped-def] return update_successful, grad_norm, num_zeros_in_grad @torch.no_grad() - def offload_optimizer_states(self): - """Move fp32 master weights and optimizer states to CPU pinned memory.""" + def offload_optimizer_states(self) -> None: + """Move fp32 master weights and optimizer state tensors to CPU pinned memory. + + This transfers all fp32 master weight parameters (``fp32_from_float16_groups``) + and optimizer state tensors (e.g. momentum buffers) from GPU to CPU pinned memory. + Pinned memory enables faster H2D transfers on the next reload. + + Called after each optimizer step to free GPU memory for the next forward pass. + A ``torch.cuda.synchronize()`` at entry ensures any pending GPU work (e.g. + param broadcasts) completes before tensors are moved off-device. + """ torch.cuda.synchronize() for opt in self.chained_optimizers: if getattr(opt, 'is_stub_optimizer', False): @@ -318,8 +358,16 @@ def offload_optimizer_states(self): state_vals[key] = val.cpu().pin_memory() @torch.no_grad() - def reload_optimizer_states(self): - """Move fp32 master weights and optimizer states back to GPU.""" + def reload_optimizer_states(self) -> None: + """Move fp32 master weights and optimizer state tensors back to GPU. + + This transfers all fp32 master weight parameters and optimizer state tensors + from CPU pinned memory to the current CUDA device. A ``torch.cuda.synchronize()`` + at exit ensures all H2D transfers complete before the optimizer step proceeds. + + Called at the start of each optimizer step so that the Newton-Schulz iterations + (Muon) or Adam updates can operate on GPU tensors. + """ for opt in self.chained_optimizers: if getattr(opt, 'is_stub_optimizer', False): continue @@ -339,7 +387,7 @@ def reload_optimizer_states(self): # fp32_from_fp16_params is list, each sub list could be empty if group is empty # this breaks dist checkpointing assumption since extract_sharded_base drop list structure # for now, we convert it to dict with index as key and convert back in load_state_dict - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict: Dict) -> None: if len(self.chained_optimizers) == 1: wrapped_state_dict = {1: state_dict} else: diff --git a/muon_slurm.slurm b/muon_slurm.slurm new file mode 100644 index 00000000000..5c74aa8903e --- /dev/null +++ b/muon_slurm.slurm @@ -0,0 +1,38 @@ +#!/bin/bash +#SBATCH --job-name=muon-cpu-offload +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=4 +#SBATCH --output=./logs/muon_cpu_offload-%j.out +#SBATCH --error=./logs/muon_cpu_offload-%j.err +#SBATCH --exclusive +#SBATCH --time=0:30:00 +#SBATCH --partition=h200-hi-preempt +#SBATCH --exclude=ip-10-1-115-124,ip-10-1-45-16 +#SBATCH --wait-all-nodes=1 + +set -euxo pipefail + +IMAGE=/fsx/peng/containers/nemo-rl-v0.5.2-te-tilelang.sqsh + +MEGATRON_SRC=/fsx/peng/workspace/megatron-muon-cpu-offload/Megatron-LM +MEGATRON_DST=/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM + +MOUNTS=${MEGATRON_SRC}:${MEGATRON_DST} + +srun \ + --container-image "${IMAGE}" \ + --container-mounts "${MOUNTS}" \ + --container-remap-root \ + bash -lc " + set -euxo pipefail + cd ${MEGATRON_DST} + + export NVIDIA_PYTORCH_VERSION=25.06 + export PYTHONPATH=${MEGATRON_DST}:\${PYTHONPATH:-} + pip install git+https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git@v0.2.0 + + torchrun --nproc-per-node=4 tests/test_muon_cpu_offload.py + " + +echo "Test completed." diff --git a/tests/test_muon_cpu_offload.py b/tests/test_muon_cpu_offload.py new file mode 100644 index 00000000000..cebcaf3198b --- /dev/null +++ b/tests/test_muon_cpu_offload.py @@ -0,0 +1,304 @@ +"""Standalone tests for Muon CPU offloading in LayerWiseDistributedOptimizer. + +Run with: + torchrun --nproc-per-node=4 tests/test_muon_cpu_offload.py + +Avoids the pytest conftest circular-import issue by running as a plain script. +""" + +import os +import sys +import traceback +from datetime import timedelta + +import torch +import torch.distributed + +from megatron.core import parallel_state +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.optimizer import get_megatron_optimizer +from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer +from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params +from megatron.core.optimizer.optimizer_config import OptimizerConfig +from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed +from megatron.core.transformer import TransformerConfig + + +def init_distributed(): + rank = int(os.environ['LOCAL_RANK']) + world_size = int(os.environ['WORLD_SIZE']) + torch.cuda.set_device(rank) + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + backend='nccl', world_size=world_size, rank=rank, + timeout=timedelta(minutes=2), + ) + return rank, world_size + + +def create_model(seed, tp, pp): + torch.manual_seed(seed) + model_parallel_cuda_manual_seed(seed) + config = TransformerConfig( + num_layers=6, + hidden_size=16, + num_attention_heads=8, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + ) + layer_spec = get_gpt_layer_with_transformer_engine_spec() + model = GPTModel( + config=config, + transformer_layer_spec=layer_spec, + vocab_size=128, + max_sequence_length=4, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + ) + model.cuda(torch.cuda.current_device()) + return model + + +def create_optimizer_with_cpu_offload(model, cpu_offload=True): + config = OptimizerConfig( + bf16=True, + params_dtype=torch.bfloat16, + use_distributed_optimizer=False, + use_layer_wise_distributed_optimizer=True, + optimizer='muon', + lr=0.0, + optimizer_cpu_offload=cpu_offload, + ) + optimizer = get_megatron_optimizer(config, [model]) + + if isinstance(optimizer, LayerWiseDistributedOptimizer): + for opt in optimizer.chained_optimizers: + if getattr(opt, 'init_state_fn', None) is None: + continue + if not hasattr(opt, 'optimizer'): + opt.init_state_fn(opt) + else: + opt.init_state_fn(opt.optimizer) + if cpu_offload: + optimizer.offload_optimizer_states() + return optimizer + + +def _iter_fp16_opts(optimizer): + for opt in optimizer.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if isinstance(opt, Float16OptimizerWithFloat16Params): + yield opt + + +def test_states_on_cpu(rank): + """After init with cpu_offload=True, fp32 master weights and state are on CPU.""" + parallel_state.initialize_model_parallel(2, 2) + try: + model = create_model(seed=2, tp=2, pp=2) + optimizer = create_optimizer_with_cpu_offload(model, cpu_offload=True) + + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + assert optimizer._cpu_offload + + for opt in _iter_fp16_opts(optimizer): + for group in opt.fp32_from_float16_groups: + for param in group: + assert not param.data.is_cuda, ( + f"[rank {rank}] fp32 master weight should be on CPU, " + f"got {param.data.device}" + ) + for state_vals in opt.optimizer.state.values(): + for key, val in state_vals.items(): + if isinstance(val, torch.Tensor): + assert not val.is_cuda, ( + f"[rank {rank}] optimizer state '{key}' should be on CPU, " + f"got {val.device}" + ) + + print(f" [rank {rank}] PASSED: test_states_on_cpu") + finally: + parallel_state.destroy_model_parallel() + + +def test_roundtrip_correctness(rank): + """Offload -> reload preserves fp32 master weight values exactly.""" + parallel_state.initialize_model_parallel(2, 2) + try: + model = create_model(seed=2, tp=2, pp=2) + optimizer = create_optimizer_with_cpu_offload(model, cpu_offload=True) + + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + + snapshots = {} + for opt in _iter_fp16_opts(optimizer): + for gidx, group in enumerate(opt.fp32_from_float16_groups): + for pidx, param in enumerate(group): + snapshots[(id(opt), gidx, pidx)] = param.data.clone() + + optimizer.reload_optimizer_states() + + for opt in _iter_fp16_opts(optimizer): + for gidx, group in enumerate(opt.fp32_from_float16_groups): + for pidx, param in enumerate(group): + assert param.data.is_cuda, ( + f"[rank {rank}] After reload, param should be on GPU" + ) + expected = snapshots[(id(opt), gidx, pidx)].to(param.data.device) + assert torch.equal(param.data, expected), ( + f"[rank {rank}] Master weight mismatch after roundtrip" + ) + + optimizer.offload_optimizer_states() + + for opt in _iter_fp16_opts(optimizer): + for gidx, group in enumerate(opt.fp32_from_float16_groups): + for pidx, param in enumerate(group): + assert not param.data.is_cuda, ( + f"[rank {rank}] After offload, param should be on CPU" + ) + + print(f" [rank {rank}] PASSED: test_roundtrip_correctness") + finally: + parallel_state.destroy_model_parallel() + + +def test_step_runs(rank): + """A full optimizer.step() succeeds with CPU offloading.""" + parallel_state.initialize_model_parallel(2, 2) + try: + model = create_model(seed=2, tp=2, pp=2) + optimizer = create_optimizer_with_cpu_offload(model, cpu_offload=True) + + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + + for param in model.parameters(): + if param.requires_grad: + g = torch.randn_like(param.data) + param.grad = g + param.main_grad = g + + update_successful, grad_norm, num_zeros = optimizer.step() + assert isinstance(update_successful, bool), ( + f"[rank {rank}] update_successful should be bool, got {type(update_successful)}" + ) + + for opt in _iter_fp16_opts(optimizer): + for group in opt.fp32_from_float16_groups: + for param in group: + assert not param.data.is_cuda, ( + f"[rank {rank}] After step, fp32 master weights should be on CPU" + ) + + print(f" [rank {rank}] PASSED: test_step_runs") + finally: + parallel_state.destroy_model_parallel() + + +def test_numerical_equivalence(rank, n_steps=5): + """Offloaded and non-offloaded optimizers produce identical fp32 master weights.""" + parallel_state.initialize_model_parallel(2, 2) + try: + model_off = create_model(seed=42, tp=2, pp=2) + model_ref = create_model(seed=42, tp=2, pp=2) + + opt_off = create_optimizer_with_cpu_offload(model_off, cpu_offload=True) + opt_ref = create_optimizer_with_cpu_offload(model_ref, cpu_offload=False) + + assert isinstance(opt_off, LayerWiseDistributedOptimizer) + assert isinstance(opt_ref, LayerWiseDistributedOptimizer) + + for step_i in range(n_steps): + torch.manual_seed(1000 + step_i + rank) + + for p_off, p_ref in zip(model_off.parameters(), model_ref.parameters()): + if not p_off.requires_grad: + continue + g = torch.randn_like(p_off.data) + p_off.grad = g.clone() + p_off.main_grad = p_off.grad + p_ref.grad = g.clone() + p_ref.main_grad = p_ref.grad + + opt_off.step() + opt_ref.step() + + opt_off.reload_optimizer_states() + + for opt_o, opt_r in zip( + _iter_fp16_opts(opt_off), _iter_fp16_opts(opt_ref) + ): + for grp_o, grp_r in zip( + opt_o.fp32_from_float16_groups, opt_r.fp32_from_float16_groups + ): + for pidx, (p_o, p_r) in enumerate(zip(grp_o, grp_r)): + p_o_gpu = p_o.data.to('cuda') if not p_o.data.is_cuda else p_o.data + assert torch.equal(p_o_gpu, p_r.data), ( + f"[rank {rank}] fp32 master weight mismatch at param {pidx} " + f"after {n_steps} steps, " + f"max diff = {(p_o_gpu - p_r.data).abs().max().item()}" + ) + + for (key_o, state_o), (key_r, state_r) in zip( + opt_o.optimizer.state.items(), opt_r.optimizer.state.items() + ): + common_keys = set(state_o.keys()) & set(state_r.keys()) + for skey in common_keys: + v_o, v_r = state_o[skey], state_r[skey] + if not isinstance(v_o, torch.Tensor): + continue + v_o_gpu = v_o.to('cuda') if not v_o.is_cuda else v_o + assert torch.equal(v_o_gpu, v_r), ( + f"[rank {rank}] optimizer state '{skey}' mismatch " + f"after {n_steps} steps, " + f"max diff = {(v_o_gpu - v_r).abs().max().item()}" + ) + + opt_off.offload_optimizer_states() + + print(f" [rank {rank}] PASSED: test_numerical_equivalence ({n_steps} steps)") + finally: + parallel_state.destroy_model_parallel() + + +def main(): + rank, world_size = init_distributed() + + tests = [ + ("test_states_on_cpu", test_states_on_cpu), + ("test_roundtrip_correctness", test_roundtrip_correctness), + ("test_step_runs", test_step_runs), + ("test_numerical_equivalence", test_numerical_equivalence), + ] + + passed, failed = 0, 0 + for name, fn in tests: + torch.distributed.barrier() + if rank == 0: + print(f"\n{'='*60}") + print(f"Running: {name}") + print(f"{'='*60}") + try: + fn(rank) + passed += 1 + except Exception: + failed += 1 + if rank == 0: + traceback.print_exc() + print(f" [rank {rank}] FAILED: {name}") + + torch.distributed.barrier() + if rank == 0: + print(f"\n{'='*60}") + print(f"Results: {passed} passed, {failed} failed out of {len(tests)}") + print(f"{'='*60}") + + torch.distributed.destroy_process_group() + sys.exit(1 if failed > 0 else 0) + + +if __name__ == '__main__': + main() diff --git a/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py b/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py index 8d36fde3069..57ed6f6263c 100644 --- a/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py +++ b/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py @@ -605,123 +605,3 @@ def test_optimizer_common_state_dict( check_equal(optim_param_state_A, optim_param_state_B) Utils.destroy_model_parallel() - - -class TestMuonCPUOffload: - """Tests for Muon CPU offloading in LayerWiseDistributedOptimizer.""" - - def setup_method(self, method): - pass - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def _create_cpu_offload_optimizer(self, tp=2, pp=2, seed=2): - """Helper: build a Muon LayerWise optimizer with CPU offloading.""" - from megatron.core.optimizer import get_megatron_optimizer - from megatron.core.optimizer.optimizer_config import OptimizerConfig - - Utils.initialize_model_parallel(tp, pp) - model = initialize_real_model( - seed=seed, - pre_process=parallel_state.is_pipeline_first_stage(), - post_process=parallel_state.is_pipeline_last_stage(), - ) - model.cuda(torch.cuda.current_device()) - - config = OptimizerConfig( - bf16=True, - params_dtype=torch.bfloat16, - use_distributed_optimizer=False, - use_layer_wise_distributed_optimizer=True, - optimizer='muon', - lr=0.0, - optimizer_cpu_offload=True, - ) - - optimizer = get_megatron_optimizer(config, [model]) - - if isinstance(optimizer, LayerWiseDistributedOptimizer): - for opt in optimizer.chained_optimizers: - if not hasattr(opt, 'optimizer'): - opt.init_state_fn(opt) - else: - opt.init_state_fn(opt.optimizer) - return model, optimizer - - def test_cpu_offload_states_on_cpu(self): - """After init, Muon fp32 master weights and momentum are on CPU.""" - model, optimizer = self._create_cpu_offload_optimizer() - assert isinstance(optimizer, LayerWiseDistributedOptimizer) - assert optimizer._cpu_offload - - for opt in optimizer.chained_optimizers: - if getattr(opt, 'is_stub_optimizer', False): - continue - if not isinstance(opt, Float16OptimizerWithFloat16Params): - continue - for group in opt.fp32_from_float16_groups: - for param in group: - assert not param.data.is_cuda, ( - f"fp32 master weight should be on CPU, got {param.data.device}" - ) - for state_vals in opt.optimizer.state.values(): - for key, val in state_vals.items(): - if isinstance(val, torch.Tensor): - assert not val.is_cuda, ( - f"optimizer state '{key}' should be on CPU, got {val.device}" - ) - - def test_cpu_offload_roundtrip_correctness(self): - """Offload -> reload preserves fp32 master weight values.""" - model, optimizer = self._create_cpu_offload_optimizer() - assert isinstance(optimizer, LayerWiseDistributedOptimizer) - - snapshots = {} - for opt in optimizer.chained_optimizers: - if getattr(opt, 'is_stub_optimizer', False): - continue - if not isinstance(opt, Float16OptimizerWithFloat16Params): - continue - for gidx, group in enumerate(opt.fp32_from_float16_groups): - for pidx, param in enumerate(group): - snapshots[(id(opt), gidx, pidx)] = param.data.clone() - - optimizer.reload_optimizer_states() - - for opt in optimizer.chained_optimizers: - if getattr(opt, 'is_stub_optimizer', False): - continue - if not isinstance(opt, Float16OptimizerWithFloat16Params): - continue - for gidx, group in enumerate(opt.fp32_from_float16_groups): - for pidx, param in enumerate(group): - assert param.data.is_cuda, "After reload, param should be on GPU" - key = (id(opt), gidx, pidx) - expected = snapshots[key].to(param.data.device) - assert torch.equal(param.data, expected), ( - "Master weight mismatch after offload->reload roundtrip" - ) - - optimizer.offload_optimizer_states() - - for opt in optimizer.chained_optimizers: - if getattr(opt, 'is_stub_optimizer', False): - continue - if not isinstance(opt, Float16OptimizerWithFloat16Params): - continue - for gidx, group in enumerate(opt.fp32_from_float16_groups): - for pidx, param in enumerate(group): - assert not param.data.is_cuda, "After offload, param should be on CPU" - - def test_cpu_offload_step_runs(self): - """A full optimizer step works with CPU offloading enabled.""" - model, optimizer = self._create_cpu_offload_optimizer() - assert isinstance(optimizer, LayerWiseDistributedOptimizer) - - for param in model.parameters(): - if param.requires_grad: - param.grad = torch.randn_like(param.data) - - update_successful, grad_norm, num_zeros = optimizer.step() - assert isinstance(update_successful, bool) diff --git a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py index 18cc09682e9..2a4685583d0 100644 --- a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py +++ b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py @@ -1,5 +1,18 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +"""Tests for Muon optimizer CPU offloading in LayerWiseDistributedOptimizer. + +This module verifies that the CPU offloading mechanism correctly: +- Moves fp32 master weights and optimizer state (momentum) to CPU pinned memory. +- Preserves tensor values exactly through offload/reload round-trips. +- Produces numerically identical results to the non-offloaded code path. + +These tests require multi-GPU execution (via torchrun or pytest with distributed +launcher) since LayerWiseDistributedOptimizer shards parameters across DP ranks. +""" + +from typing import Generator + import pytest import torch @@ -15,8 +28,17 @@ from tests.unit_tests.test_utilities import Utils -def _create_model(seed, tp, pp): - """Create a small GPT model for testing.""" +def _create_model(seed: int, tp: int, pp: int) -> GPTModel: + """Create a small GPT model for testing. + + Args: + seed: Random seed for reproducibility. + tp: Tensor parallel size (already initialized via Utils). + pp: Pipeline parallel size (already initialized via Utils). + + Returns: + A GPTModel instance on the current CUDA device. + """ torch.manual_seed(seed) model_parallel_cuda_manual_seed(seed) config = TransformerConfig( @@ -40,8 +62,18 @@ def _create_model(seed, tp, pp): return model -def _create_optimizer(model, cpu_offload=True): - """Create a Muon LayerWise optimizer with optional CPU offloading.""" +def _create_optimizer( + model: GPTModel, cpu_offload: bool = True +) -> LayerWiseDistributedOptimizer: + """Create a Muon LayerWise optimizer with optional CPU offloading. + + Args: + model: The GPT model whose parameters will be optimized. + cpu_offload: Whether to enable CPU offloading of optimizer states. + + Returns: + A LayerWiseDistributedOptimizer wrapping Muon + Adam fallback. + """ config = OptimizerConfig( bf16=True, params_dtype=torch.bfloat16, @@ -67,8 +99,10 @@ def _create_optimizer(model, cpu_offload=True): return optimizer -def _iter_fp16_opts(optimizer): - """Yield Float16OptimizerWithFloat16Params sub-optimizers.""" +def _iter_fp16_opts( + optimizer: LayerWiseDistributedOptimizer, +) -> Generator[Float16OptimizerWithFloat16Params, None, None]: + """Yield Float16OptimizerWithFloat16Params sub-optimizers, skipping stubs.""" for opt in optimizer.chained_optimizers: if getattr(opt, 'is_stub_optimizer', False): continue @@ -77,16 +111,22 @@ def _iter_fp16_opts(optimizer): class TestMuonCPUOffload: - """Tests for Muon CPU offloading in LayerWiseDistributedOptimizer.""" + """Tests for Muon CPU offloading in LayerWiseDistributedOptimizer. + + Verifies the correctness of the CPU offload mechanism that moves fp32 master + weights and momentum buffers between GPU and CPU pinned memory each step. + Tests cover state placement, round-trip fidelity, step execution, and + bit-exact numerical equivalence against the non-offloaded baseline. + """ - def setup_method(self, method): + def setup_method(self, method) -> None: pass - def teardown_method(self, method): + def teardown_method(self, method) -> None: Utils.destroy_model_parallel() @pytest.mark.parametrize('tp,pp', [(2, 2), (1, 4), (4, 1)]) - def test_states_on_cpu(self, tp, pp): + def test_states_on_cpu(self, tp: int, pp: int) -> None: """After init with cpu_offload=True, fp32 master weights and state are on CPU.""" if tp * pp > torch.cuda.device_count(): pytest.skip("Not enough GPUs") @@ -112,7 +152,7 @@ def test_states_on_cpu(self, tp, pp): ) @pytest.mark.parametrize('tp,pp', [(2, 2), (1, 4), (4, 1)]) - def test_roundtrip_correctness(self, tp, pp): + def test_roundtrip_correctness(self, tp: int, pp: int) -> None: """Offload -> reload preserves fp32 master weight values exactly.""" if tp * pp > torch.cuda.device_count(): pytest.skip("Not enough GPUs") @@ -148,7 +188,7 @@ def test_roundtrip_correctness(self, tp, pp): assert not param.data.is_cuda, "After offload, param should be on CPU" @pytest.mark.parametrize('tp,pp', [(2, 2), (1, 4), (4, 1)]) - def test_step_runs(self, tp, pp): + def test_step_runs(self, tp: int, pp: int) -> None: """A full optimizer.step() succeeds with CPU offloading.""" if tp * pp > torch.cuda.device_count(): pytest.skip("Not enough GPUs") @@ -177,8 +217,14 @@ def test_step_runs(self, tp, pp): @pytest.mark.parametrize('tp,pp', [(2, 2), (4, 1)]) @pytest.mark.parametrize('n_steps', [3, 5]) - def test_numerical_equivalence(self, tp, pp, n_steps): - """Offloaded and non-offloaded optimizers produce identical results.""" + def test_numerical_equivalence(self, tp: int, pp: int, n_steps: int) -> None: + """Offloaded and non-offloaded optimizers produce bit-identical results. + + Runs both an offloaded and a non-offloaded optimizer for ``n_steps`` + with identical random gradients, then verifies that fp32 master weights + and optimizer state tensors match exactly. This ensures the offload/reload + cycle introduces zero numerical drift. + """ if tp * pp > torch.cuda.device_count(): pytest.skip("Not enough GPUs") From dabe37041a191c132eaafa9d2385df8f064cc14e Mon Sep 17 00:00:00 2001 From: pengdurice Date: Thu, 30 Apr 2026 21:17:23 +0000 Subject: [PATCH 04/15] reverrt ct some style changes Signed-off-by: pengdurice --- .../core/optimizer/layer_wise_optimizer.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index d5ba13fb9c5..b8395fa5c76 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -1,7 +1,7 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import logging -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors @@ -127,7 +127,7 @@ def __init__( # This way each rank do some duplicated work but allgather_v is no longer needed # All current distopt optimization can also be potentially applied - def shard_params(self, optimizers: List[MegatronOptimizer]) -> None: + def shard_params(self, optimizers): """Shard all params into lists by rank.""" # list of parameter are sorted by numel and assigned to ranks in ping-pong style # example of 4 ranks and 10 parameters p0-p9 after sorting, then dp_cp_params_list will be @@ -181,7 +181,7 @@ def shard_params(self, optimizers: List[MegatronOptimizer]) -> None: if expt_dp_size == 1 or len(self.expt_dp_params_list[0]) == 0: self.expt_dp_params_list = None - def set_bucket_layerwise_params_list(self, model_chunks: List) -> None: + def set_bucket_layerwise_params_list(self, model_chunks): """Map sharded params to DDP buckets for async all-gather. For each bucket in each model chunk's bucket groups, build per-rank param lists @@ -269,8 +269,8 @@ def _allgather_helper(params_list, group): _allgather_helper(self.expt_dp_params_list, self.pg_collection.expt_dp) @torch.no_grad() - def broadcast_params(self) -> None: - """Broadcast updated params from owning rank to all other DP ranks.""" + def broadcast_params(self): + """All rank broadcast updated local params.""" if self.dp_cp_params_list is None: return for i, params in enumerate(self.dp_cp_params_list): @@ -285,8 +285,8 @@ def broadcast_params(self) -> None: torch.distributed.broadcast(p, src_global_rank, self.pg_collection.expt_dp) @torch.no_grad() - def get_grad_norm(self) -> torch.Tensor: - """Compute global grad norm aggregated across all DP ranks.""" + def get_grad_norm(self): + # similar to dist opt, always aggregate globally grads_for_norm = [] for optimizer in self.chained_optimizers: grads_for_norm += optimizer.get_main_grads_for_grad_norm() @@ -294,8 +294,7 @@ def get_grad_norm(self) -> torch.Tensor: return grad_norm @torch.no_grad() - def count_zeros(self) -> torch.Tensor: - """Count zero-valued gradients aggregated across all DP ranks.""" + def count_zeros(self): params = [] for optimizer in self.chained_optimizers: params += optimizer.get_parameters() @@ -387,7 +386,7 @@ def reload_optimizer_states(self) -> None: # fp32_from_fp16_params is list, each sub list could be empty if group is empty # this breaks dist checkpointing assumption since extract_sharded_base drop list structure # for now, we convert it to dict with index as key and convert back in load_state_dict - def load_state_dict(self, state_dict: Dict) -> None: + def load_state_dict(self, state_dict): if len(self.chained_optimizers) == 1: wrapped_state_dict = {1: state_dict} else: From de6afccdedfc9d306dc4245d042e02037bfca83c Mon Sep 17 00:00:00 2001 From: pengdurice Date: Thu, 30 Apr 2026 21:25:30 +0000 Subject: [PATCH 05/15] Remove standalone test script and slurm file from PR --- .gitignore | 3 +- muon_slurm.slurm | 38 ----- tests/test_muon_cpu_offload.py | 304 --------------------------------- 3 files changed, 2 insertions(+), 343 deletions(-) delete mode 100644 muon_slurm.slurm delete mode 100644 tests/test_muon_cpu_offload.py diff --git a/.gitignore b/.gitignore index 5556d1d5a4a..b8b740584bf 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,5 @@ docs/_build docs/apidocs # Git worktrees -.worktrees/ \ No newline at end of file +.worktrees/tests/test_muon_cpu_offload.py +muon_slurm.slurm diff --git a/muon_slurm.slurm b/muon_slurm.slurm deleted file mode 100644 index 5c74aa8903e..00000000000 --- a/muon_slurm.slurm +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=muon-cpu-offload -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=4 -#SBATCH --output=./logs/muon_cpu_offload-%j.out -#SBATCH --error=./logs/muon_cpu_offload-%j.err -#SBATCH --exclusive -#SBATCH --time=0:30:00 -#SBATCH --partition=h200-hi-preempt -#SBATCH --exclude=ip-10-1-115-124,ip-10-1-45-16 -#SBATCH --wait-all-nodes=1 - -set -euxo pipefail - -IMAGE=/fsx/peng/containers/nemo-rl-v0.5.2-te-tilelang.sqsh - -MEGATRON_SRC=/fsx/peng/workspace/megatron-muon-cpu-offload/Megatron-LM -MEGATRON_DST=/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM - -MOUNTS=${MEGATRON_SRC}:${MEGATRON_DST} - -srun \ - --container-image "${IMAGE}" \ - --container-mounts "${MOUNTS}" \ - --container-remap-root \ - bash -lc " - set -euxo pipefail - cd ${MEGATRON_DST} - - export NVIDIA_PYTORCH_VERSION=25.06 - export PYTHONPATH=${MEGATRON_DST}:\${PYTHONPATH:-} - pip install git+https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git@v0.2.0 - - torchrun --nproc-per-node=4 tests/test_muon_cpu_offload.py - " - -echo "Test completed." diff --git a/tests/test_muon_cpu_offload.py b/tests/test_muon_cpu_offload.py deleted file mode 100644 index cebcaf3198b..00000000000 --- a/tests/test_muon_cpu_offload.py +++ /dev/null @@ -1,304 +0,0 @@ -"""Standalone tests for Muon CPU offloading in LayerWiseDistributedOptimizer. - -Run with: - torchrun --nproc-per-node=4 tests/test_muon_cpu_offload.py - -Avoids the pytest conftest circular-import issue by running as a plain script. -""" - -import os -import sys -import traceback -from datetime import timedelta - -import torch -import torch.distributed - -from megatron.core import parallel_state -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.optimizer import get_megatron_optimizer -from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer -from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params -from megatron.core.optimizer.optimizer_config import OptimizerConfig -from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed -from megatron.core.transformer import TransformerConfig - - -def init_distributed(): - rank = int(os.environ['LOCAL_RANK']) - world_size = int(os.environ['WORLD_SIZE']) - torch.cuda.set_device(rank) - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group( - backend='nccl', world_size=world_size, rank=rank, - timeout=timedelta(minutes=2), - ) - return rank, world_size - - -def create_model(seed, tp, pp): - torch.manual_seed(seed) - model_parallel_cuda_manual_seed(seed) - config = TransformerConfig( - num_layers=6, - hidden_size=16, - num_attention_heads=8, - use_cpu_initialization=True, - pipeline_dtype=torch.bfloat16, - bf16=True, - ) - layer_spec = get_gpt_layer_with_transformer_engine_spec() - model = GPTModel( - config=config, - transformer_layer_spec=layer_spec, - vocab_size=128, - max_sequence_length=4, - pre_process=parallel_state.is_pipeline_first_stage(), - post_process=parallel_state.is_pipeline_last_stage(), - ) - model.cuda(torch.cuda.current_device()) - return model - - -def create_optimizer_with_cpu_offload(model, cpu_offload=True): - config = OptimizerConfig( - bf16=True, - params_dtype=torch.bfloat16, - use_distributed_optimizer=False, - use_layer_wise_distributed_optimizer=True, - optimizer='muon', - lr=0.0, - optimizer_cpu_offload=cpu_offload, - ) - optimizer = get_megatron_optimizer(config, [model]) - - if isinstance(optimizer, LayerWiseDistributedOptimizer): - for opt in optimizer.chained_optimizers: - if getattr(opt, 'init_state_fn', None) is None: - continue - if not hasattr(opt, 'optimizer'): - opt.init_state_fn(opt) - else: - opt.init_state_fn(opt.optimizer) - if cpu_offload: - optimizer.offload_optimizer_states() - return optimizer - - -def _iter_fp16_opts(optimizer): - for opt in optimizer.chained_optimizers: - if getattr(opt, 'is_stub_optimizer', False): - continue - if isinstance(opt, Float16OptimizerWithFloat16Params): - yield opt - - -def test_states_on_cpu(rank): - """After init with cpu_offload=True, fp32 master weights and state are on CPU.""" - parallel_state.initialize_model_parallel(2, 2) - try: - model = create_model(seed=2, tp=2, pp=2) - optimizer = create_optimizer_with_cpu_offload(model, cpu_offload=True) - - assert isinstance(optimizer, LayerWiseDistributedOptimizer) - assert optimizer._cpu_offload - - for opt in _iter_fp16_opts(optimizer): - for group in opt.fp32_from_float16_groups: - for param in group: - assert not param.data.is_cuda, ( - f"[rank {rank}] fp32 master weight should be on CPU, " - f"got {param.data.device}" - ) - for state_vals in opt.optimizer.state.values(): - for key, val in state_vals.items(): - if isinstance(val, torch.Tensor): - assert not val.is_cuda, ( - f"[rank {rank}] optimizer state '{key}' should be on CPU, " - f"got {val.device}" - ) - - print(f" [rank {rank}] PASSED: test_states_on_cpu") - finally: - parallel_state.destroy_model_parallel() - - -def test_roundtrip_correctness(rank): - """Offload -> reload preserves fp32 master weight values exactly.""" - parallel_state.initialize_model_parallel(2, 2) - try: - model = create_model(seed=2, tp=2, pp=2) - optimizer = create_optimizer_with_cpu_offload(model, cpu_offload=True) - - assert isinstance(optimizer, LayerWiseDistributedOptimizer) - - snapshots = {} - for opt in _iter_fp16_opts(optimizer): - for gidx, group in enumerate(opt.fp32_from_float16_groups): - for pidx, param in enumerate(group): - snapshots[(id(opt), gidx, pidx)] = param.data.clone() - - optimizer.reload_optimizer_states() - - for opt in _iter_fp16_opts(optimizer): - for gidx, group in enumerate(opt.fp32_from_float16_groups): - for pidx, param in enumerate(group): - assert param.data.is_cuda, ( - f"[rank {rank}] After reload, param should be on GPU" - ) - expected = snapshots[(id(opt), gidx, pidx)].to(param.data.device) - assert torch.equal(param.data, expected), ( - f"[rank {rank}] Master weight mismatch after roundtrip" - ) - - optimizer.offload_optimizer_states() - - for opt in _iter_fp16_opts(optimizer): - for gidx, group in enumerate(opt.fp32_from_float16_groups): - for pidx, param in enumerate(group): - assert not param.data.is_cuda, ( - f"[rank {rank}] After offload, param should be on CPU" - ) - - print(f" [rank {rank}] PASSED: test_roundtrip_correctness") - finally: - parallel_state.destroy_model_parallel() - - -def test_step_runs(rank): - """A full optimizer.step() succeeds with CPU offloading.""" - parallel_state.initialize_model_parallel(2, 2) - try: - model = create_model(seed=2, tp=2, pp=2) - optimizer = create_optimizer_with_cpu_offload(model, cpu_offload=True) - - assert isinstance(optimizer, LayerWiseDistributedOptimizer) - - for param in model.parameters(): - if param.requires_grad: - g = torch.randn_like(param.data) - param.grad = g - param.main_grad = g - - update_successful, grad_norm, num_zeros = optimizer.step() - assert isinstance(update_successful, bool), ( - f"[rank {rank}] update_successful should be bool, got {type(update_successful)}" - ) - - for opt in _iter_fp16_opts(optimizer): - for group in opt.fp32_from_float16_groups: - for param in group: - assert not param.data.is_cuda, ( - f"[rank {rank}] After step, fp32 master weights should be on CPU" - ) - - print(f" [rank {rank}] PASSED: test_step_runs") - finally: - parallel_state.destroy_model_parallel() - - -def test_numerical_equivalence(rank, n_steps=5): - """Offloaded and non-offloaded optimizers produce identical fp32 master weights.""" - parallel_state.initialize_model_parallel(2, 2) - try: - model_off = create_model(seed=42, tp=2, pp=2) - model_ref = create_model(seed=42, tp=2, pp=2) - - opt_off = create_optimizer_with_cpu_offload(model_off, cpu_offload=True) - opt_ref = create_optimizer_with_cpu_offload(model_ref, cpu_offload=False) - - assert isinstance(opt_off, LayerWiseDistributedOptimizer) - assert isinstance(opt_ref, LayerWiseDistributedOptimizer) - - for step_i in range(n_steps): - torch.manual_seed(1000 + step_i + rank) - - for p_off, p_ref in zip(model_off.parameters(), model_ref.parameters()): - if not p_off.requires_grad: - continue - g = torch.randn_like(p_off.data) - p_off.grad = g.clone() - p_off.main_grad = p_off.grad - p_ref.grad = g.clone() - p_ref.main_grad = p_ref.grad - - opt_off.step() - opt_ref.step() - - opt_off.reload_optimizer_states() - - for opt_o, opt_r in zip( - _iter_fp16_opts(opt_off), _iter_fp16_opts(opt_ref) - ): - for grp_o, grp_r in zip( - opt_o.fp32_from_float16_groups, opt_r.fp32_from_float16_groups - ): - for pidx, (p_o, p_r) in enumerate(zip(grp_o, grp_r)): - p_o_gpu = p_o.data.to('cuda') if not p_o.data.is_cuda else p_o.data - assert torch.equal(p_o_gpu, p_r.data), ( - f"[rank {rank}] fp32 master weight mismatch at param {pidx} " - f"after {n_steps} steps, " - f"max diff = {(p_o_gpu - p_r.data).abs().max().item()}" - ) - - for (key_o, state_o), (key_r, state_r) in zip( - opt_o.optimizer.state.items(), opt_r.optimizer.state.items() - ): - common_keys = set(state_o.keys()) & set(state_r.keys()) - for skey in common_keys: - v_o, v_r = state_o[skey], state_r[skey] - if not isinstance(v_o, torch.Tensor): - continue - v_o_gpu = v_o.to('cuda') if not v_o.is_cuda else v_o - assert torch.equal(v_o_gpu, v_r), ( - f"[rank {rank}] optimizer state '{skey}' mismatch " - f"after {n_steps} steps, " - f"max diff = {(v_o_gpu - v_r).abs().max().item()}" - ) - - opt_off.offload_optimizer_states() - - print(f" [rank {rank}] PASSED: test_numerical_equivalence ({n_steps} steps)") - finally: - parallel_state.destroy_model_parallel() - - -def main(): - rank, world_size = init_distributed() - - tests = [ - ("test_states_on_cpu", test_states_on_cpu), - ("test_roundtrip_correctness", test_roundtrip_correctness), - ("test_step_runs", test_step_runs), - ("test_numerical_equivalence", test_numerical_equivalence), - ] - - passed, failed = 0, 0 - for name, fn in tests: - torch.distributed.barrier() - if rank == 0: - print(f"\n{'='*60}") - print(f"Running: {name}") - print(f"{'='*60}") - try: - fn(rank) - passed += 1 - except Exception: - failed += 1 - if rank == 0: - traceback.print_exc() - print(f" [rank {rank}] FAILED: {name}") - - torch.distributed.barrier() - if rank == 0: - print(f"\n{'='*60}") - print(f"Results: {passed} passed, {failed} failed out of {len(tests)}") - print(f"{'='*60}") - - torch.distributed.destroy_process_group() - sys.exit(1 if failed > 0 else 0) - - -if __name__ == '__main__': - main() From a5714155ee0f62ac8ff014c01cbdc04572f3c9c6 Mon Sep 17 00:00:00 2001 From: pengdurice Date: Thu, 30 Apr 2026 21:25:55 +0000 Subject: [PATCH 06/15] Remove standalone test script and slurm file from PR Signed-off-by: pengdurice --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index b8b740584bf..296ba9ae23a 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,5 @@ docs/apidocs # Git worktrees .worktrees/tests/test_muon_cpu_offload.py muon_slurm.slurm +tests/test_muon_cpu_offload.py +muon_slurm.slurm From 86f0a939872b48df92d8c6d997734fcef8c8c64c Mon Sep 17 00:00:00 2001 From: pengdurice Date: Thu, 30 Apr 2026 21:28:10 +0000 Subject: [PATCH 07/15] restore gitignore Signed-off-by: pengdurice --- .gitignore | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 296ba9ae23a..5556d1d5a4a 100644 --- a/.gitignore +++ b/.gitignore @@ -22,7 +22,4 @@ docs/_build docs/apidocs # Git worktrees -.worktrees/tests/test_muon_cpu_offload.py -muon_slurm.slurm -tests/test_muon_cpu_offload.py -muon_slurm.slurm +.worktrees/ \ No newline at end of file From f26a8da731f13a9bf71ec18966e9f1e1ddbefb04 Mon Sep 17 00:00:00 2001 From: pengdurice Date: Thu, 30 Apr 2026 21:29:16 +0000 Subject: [PATCH 08/15] remove unnecessary load Signed-off-by: pengdurice --- tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py b/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py index 57ed6f6263c..3f60658a005 100644 --- a/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py +++ b/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py @@ -17,7 +17,6 @@ from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.optimizer import ChainedOptimizer from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer -from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed from megatron.core.transformer import MLATransformerConfig, TransformerConfig From fc7771ecde8999b72d5e934d8b3856321e356920 Mon Sep 17 00:00:00 2001 From: pengdurice Date: Thu, 30 Apr 2026 21:33:42 +0000 Subject: [PATCH 09/15] raun autoformat Signed-off-by: pengdurice --- megatron/core/optimizer/__init__.py | 58 +++++++------------ .../core/optimizer/layer_wise_optimizer.py | 8 +-- .../test_muon_cpu_offload.py | 38 ++++++------ 3 files changed, 41 insertions(+), 63 deletions(-) diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 1da120a77bd..c22b50df51a 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -10,6 +10,28 @@ from torch.optim import SGD as CPUSGD from torch.optim import AdamW as CPUAdam +from megatron.core import parallel_state +from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer +from megatron.core.optimizer_param_scheduler import (ParamGroupOverride, + combine_param_group_overrides, + param_group_override_to_tuple) +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.fsdp_dtensor_checkpoint import get_global_unique_param_name + +from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer +from ..transformer.module import MegatronModule +from ..utils import get_model_config, get_pg_rank, get_pg_size, is_te_min_version, log_single_rank +from .distrib_optimizer import DistributedOptimizer +from .emerging_optimizers import (_EMERGING_OPTIMIZERS, HAVE_EMERGING_OPTIMIZERS, + _create_emerging_optimizer) +from .grad_scaler import ConstantGradScaler, DynamicGradScaler +from .layer_wise_optimizer import LayerWiseDistributedOptimizer +from .optimizer import (ChainedOptimizer, Float16OptimizerWithFloat16Params, FP32Optimizer, + MegatronOptimizer, param_group_identifier_keys) +# Subclass aliases kept for backward compatibility; all are OptimizerConfig. +from .optimizer_config import (AdamOptimizerConfig, OptimizerConfig, ParamKey, ParamPredicate, + ParamWithNamePredicate, SGDOptimizerConfig) + try: from transformer_engine.pytorch.optimizers import FusedAdam as Adam from transformer_engine.pytorch.optimizers import FusedSGD as SGD @@ -47,44 +69,8 @@ if HAVE_EMERGING_OPTIMIZERS: from emerging_optimizers.scalar_optimizers import Lion -from megatron.core import parallel_state -from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer -from megatron.core.optimizer_param_scheduler import ( - ParamGroupOverride, - combine_param_group_overrides, - param_group_override_to_tuple, -) -from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.fsdp_dtensor_checkpoint import get_global_unique_param_name -from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer -from ..transformer.module import MegatronModule -from ..utils import get_model_config, get_pg_rank, get_pg_size, is_te_min_version, log_single_rank -from .distrib_optimizer import DistributedOptimizer -from .emerging_optimizers import ( - _EMERGING_OPTIMIZERS, - HAVE_EMERGING_OPTIMIZERS, - _create_emerging_optimizer, -) -from .grad_scaler import ConstantGradScaler, DynamicGradScaler -from .layer_wise_optimizer import LayerWiseDistributedOptimizer -from .optimizer import ( - ChainedOptimizer, - Float16OptimizerWithFloat16Params, - FP32Optimizer, - MegatronOptimizer, - param_group_identifier_keys, -) -# Subclass aliases kept for backward compatibility; all are OptimizerConfig. -from .optimizer_config import ( - AdamOptimizerConfig, - OptimizerConfig, - ParamKey, - ParamPredicate, - ParamWithNamePredicate, - SGDOptimizerConfig, -) logger = logging.getLogger(__name__) diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index b8395fa5c76..d91e51cdfe9 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -12,12 +12,8 @@ from megatron.core.utils import get_pg_rank, get_pg_size from .clip_grads import count_zeros_fp32, get_grad_norm_fp32 -from .optimizer import ( - ChainedOptimizer, - Float16OptimizerWithFloat16Params, - FP32Optimizer, - MegatronOptimizer, -) +from .optimizer import (ChainedOptimizer, Float16OptimizerWithFloat16Params, FP32Optimizer, + MegatronOptimizer) from .optimizer_config import OptimizerConfig logger = logging.getLogger(__name__) diff --git a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py index 2a4685583d0..87245967610 100644 --- a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py +++ b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py @@ -15,17 +15,17 @@ import pytest import torch +from tests.unit_tests.test_utilities import Utils from megatron.core import parallel_state +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.optimizer import get_megatron_optimizer from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params from megatron.core.optimizer.optimizer_config import OptimizerConfig from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.transformer import TransformerConfig -from tests.unit_tests.test_utilities import Utils def _create_model(seed: int, tp: int, pp: int) -> GPTModel: @@ -62,9 +62,7 @@ def _create_model(seed: int, tp: int, pp: int) -> GPTModel: return model -def _create_optimizer( - model: GPTModel, cpu_offload: bool = True -) -> LayerWiseDistributedOptimizer: +def _create_optimizer(model: GPTModel, cpu_offload: bool = True) -> LayerWiseDistributedOptimizer: """Create a Muon LayerWise optimizer with optional CPU offloading. Args: @@ -141,15 +139,15 @@ def test_states_on_cpu(self, tp: int, pp: int) -> None: for opt in _iter_fp16_opts(optimizer): for group in opt.fp32_from_float16_groups: for param in group: - assert not param.data.is_cuda, ( - f"fp32 master weight should be on CPU, got {param.data.device}" - ) + assert ( + not param.data.is_cuda + ), f"fp32 master weight should be on CPU, got {param.data.device}" for state_vals in opt.optimizer.state.values(): for key, val in state_vals.items(): if isinstance(val, torch.Tensor): - assert not val.is_cuda, ( - f"optimizer state '{key}' should be on CPU, got {val.device}" - ) + assert ( + not val.is_cuda + ), f"optimizer state '{key}' should be on CPU, got {val.device}" @pytest.mark.parametrize('tp,pp', [(2, 2), (1, 4), (4, 1)]) def test_roundtrip_correctness(self, tp: int, pp: int) -> None: @@ -176,9 +174,9 @@ def test_roundtrip_correctness(self, tp: int, pp: int) -> None: for pidx, param in enumerate(group): assert param.data.is_cuda, "After reload, param should be on GPU" expected = snapshots[(id(opt), gidx, pidx)].to(param.data.device) - assert torch.equal(param.data, expected), ( - "Master weight mismatch after offload->reload roundtrip" - ) + assert torch.equal( + param.data, expected + ), "Master weight mismatch after offload->reload roundtrip" optimizer.offload_optimizer_states() @@ -211,9 +209,9 @@ def test_step_runs(self, tp: int, pp: int) -> None: for opt in _iter_fp16_opts(optimizer): for group in opt.fp32_from_float16_groups: for param in group: - assert not param.data.is_cuda, ( - "After step, fp32 master weights should be back on CPU" - ) + assert ( + not param.data.is_cuda + ), "After step, fp32 master weights should be back on CPU" @pytest.mark.parametrize('tp,pp', [(2, 2), (4, 1)]) @pytest.mark.parametrize('n_steps', [3, 5]) @@ -259,9 +257,7 @@ def test_numerical_equivalence(self, tp: int, pp: int, n_steps: int) -> None: opt_off.reload_optimizer_states() for opt_o, opt_r in zip(_iter_fp16_opts(opt_off), _iter_fp16_opts(opt_ref)): - for grp_o, grp_r in zip( - opt_o.fp32_from_float16_groups, opt_r.fp32_from_float16_groups - ): + for grp_o, grp_r in zip(opt_o.fp32_from_float16_groups, opt_r.fp32_from_float16_groups): for pidx, (p_o, p_r) in enumerate(zip(grp_o, grp_r)): p_o_gpu = p_o.data.to('cuda') if not p_o.data.is_cuda else p_o.data assert torch.equal(p_o_gpu, p_r.data), ( From 4335d277468bbe826ce8171c8dacf504c2507bd8 Mon Sep 17 00:00:00 2001 From: pengdurice Date: Fri, 29 May 2026 15:27:03 +0000 Subject: [PATCH 10/15] address comments Signed-off-by: pengdurice --- megatron/core/optimizer/__init__.py | 9 +- .../core/optimizer/layer_wise_optimizer.py | 18 ++- megatron/training/arguments.py | 7 +- .../test_muon_cpu_offload.py | 149 +++++++++++++++--- 4 files changed, 148 insertions(+), 35 deletions(-) diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 279aa3b15e5..68b8cd35dd4 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -891,12 +891,9 @@ def _get_megatron_emerging_optimizer( "the legacy LayerWise ping-pong path for MoE models." ) fallback_config.use_distributed_optimizer = True - # Disable per-optimizer CPU offload (HybridDeviceOptimizer) for the - # Adam fallback when LayerWiseDistributedOptimizer is active. - # CPU offloading is handled uniformly by LayerWiseDistributedOptimizer - # for all sub-optimizers (Muon + Adam), preventing double-offloading. - if use_layer_wise: - fallback_config.optimizer_cpu_offload = False + # The separate DistributedOptimizer manages its own CPU offloading + # (via HybridDeviceOptimizer) independently of LayerWise — do NOT + # disable optimizer_cpu_offload here. result = _get_megatron_optimizer_based_on_param_groups( config=fallback_config, model_chunks=model_chunks, diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index 6736c1dda79..3343167c168 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -763,6 +763,19 @@ def start_param_sync_for_bucket_group_subset(self) -> None: ): model_chunk._start_bucket_group_param_sync(bucket_group, force_sync=False) + @torch.no_grad() + def prepare_grads(self) -> bool: + """Reload offloaded states before gradient preprocessing. + + ``ChainedOptimizer.step()`` calls ``prepare_grads()`` before + ``step_with_ready_grads()``. The child optimizers' + ``_copy_model_grads_to_main_grads`` assigns CUDA gradients to the fp32 + master params — which requires those params to be on GPU already. + """ + if self._cpu_offload: + self.reload_optimizer_states() + return super().prepare_grads() + @torch.no_grad() def step_with_ready_grads(self) -> bool: """Step then all-gather LayerWise-managed param buffers. @@ -772,8 +785,9 @@ def step_with_ready_grads(self) -> bool: which calls ``step_with_ready_grads`` directly on each child and bypasses ``step``. - When CPU offloading is enabled, orchestrates the full cycle: - reload states to GPU -> optimizer step -> param sync -> offload states to CPU. + When CPU offloading is enabled and this method is called directly + (without a preceding ``prepare_grads``), reload states first. + ``reload_optimizer_states`` is idempotent (skips already-CUDA tensors). """ if self._cpu_offload: self.reload_optimizer_states() diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index cd3ce44c3a4..d6cfc6b6d2b 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1662,9 +1662,10 @@ def validate_args(args, defaults={}): # Optimizer CPU offload check if args.optimizer_cpu_offload: - assert args.use_precision_aware_optimizer, ( - "The optimizer cpu offload must be used in conjunction with `--use-precision-aware-optimizer`, " - "as the hybrid device optimizer reuses the code path of this flag." + assert args.use_precision_aware_optimizer or args.use_layer_wise_distributed_optimizer, ( + "The optimizer cpu offload must be used in conjunction with " + "`--use-precision-aware-optimizer` (for HybridDeviceOptimizer) or " + "`--use-layer-wise-distributed-optimizer` (for Muon/emerging optimizer offloading)." ) assert not args.fp8_param_gather or args.fp8_recipe == "delayed", ( "When `--fp8-param-gather` is enabled, the optimizer cpu offload " diff --git a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py index 87245967610..d18fc933c43 100644 --- a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py +++ b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py @@ -28,13 +28,15 @@ from megatron.core.transformer import TransformerConfig -def _create_model(seed: int, tp: int, pp: int) -> GPTModel: +def _create_model(seed: int, tp: int, pp: int, bf16_params: bool = True) -> GPTModel: """Create a small GPT model for testing. Args: seed: Random seed for reproducibility. tp: Tensor parallel size (already initialized via Utils). pp: Pipeline parallel size (already initialized via Utils). + bf16_params: If True, model params are bf16 (exercises float16_groups path). + If False, params are fp32 (exercises fp32_from_fp32_groups path). Returns: A GPTModel instance on the current CUDA device. @@ -45,7 +47,7 @@ def _create_model(seed: int, tp: int, pp: int) -> GPTModel: num_layers=6, hidden_size=16, num_attention_heads=8, - use_cpu_initialization=True, + use_cpu_initialization=not bf16_params, pipeline_dtype=torch.bfloat16, bf16=True, ) @@ -58,7 +60,10 @@ def _create_model(seed: int, tp: int, pp: int) -> GPTModel: pre_process=parallel_state.is_pipeline_first_stage(), post_process=parallel_state.is_pipeline_last_stage(), ) - model.cuda(torch.cuda.current_device()) + if bf16_params: + model.bfloat16().cuda(torch.cuda.current_device()) + else: + model.cuda(torch.cuda.current_device()) return model @@ -125,6 +130,7 @@ def teardown_method(self, method) -> None: @pytest.mark.parametrize('tp,pp', [(2, 2), (1, 4), (4, 1)]) def test_states_on_cpu(self, tp: int, pp: int) -> None: + print(f"test_states_on_cpu: tp={tp}, pp={pp}") """After init with cpu_offload=True, fp32 master weights and state are on CPU.""" if tp * pp > torch.cuda.device_count(): pytest.skip("Not enough GPUs") @@ -152,6 +158,7 @@ def test_states_on_cpu(self, tp: int, pp: int) -> None: @pytest.mark.parametrize('tp,pp', [(2, 2), (1, 4), (4, 1)]) def test_roundtrip_correctness(self, tp: int, pp: int) -> None: """Offload -> reload preserves fp32 master weight values exactly.""" + print(f"test_roundtrip_correctness: tp={tp}, pp={pp}") if tp * pp > torch.cuda.device_count(): pytest.skip("Not enough GPUs") @@ -186,22 +193,41 @@ def test_roundtrip_correctness(self, tp: int, pp: int) -> None: assert not param.data.is_cuda, "After offload, param should be on CPU" @pytest.mark.parametrize('tp,pp', [(2, 2), (1, 4), (4, 1)]) - def test_step_runs(self, tp: int, pp: int) -> None: - """A full optimizer.step() succeeds with CPU offloading.""" + @pytest.mark.parametrize('bf16_params', [True, False]) + def test_step_runs(self, tp: int, pp: int, bf16_params: bool) -> None: + """A full optimizer.step() succeeds with CPU offloading. + + Verifies the full prepare_grads → step_with_ready_grads cycle works + when states start on CPU. Sets main_grad on the params tracked by + Float16OptimizerWithFloat16Params.float16_groups to exercise the + _copy_model_grads_to_main_grads path that assigns CUDA grads to + the fp32 master params (which are on CPU before reload). + + When bf16_params=True, model params are bf16 and the float16_groups → + fp32_from_float16_groups path is exercised (the device mismatch path). + When bf16_params=False, params are fp32 and fp32_from_fp32_groups is used. + """ if tp * pp > torch.cuda.device_count(): pytest.skip("Not enough GPUs") Utils.initialize_model_parallel(tp, pp) - model = _create_model(seed=2, tp=tp, pp=pp) + model = _create_model(seed=2, tp=tp, pp=pp, bf16_params=bf16_params) optimizer = _create_optimizer(model, cpu_offload=True) assert isinstance(optimizer, LayerWiseDistributedOptimizer) - for param in model.parameters(): - if param.requires_grad: - g = torch.randn_like(param.data) - param.grad = g - param.main_grad = g + # Set main_grad on all params tracked by the sub-optimizers so that + # _copy_model_grads_to_main_grads is exercised during prepare_grads(). + for opt in optimizer.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if isinstance(opt, Float16OptimizerWithFloat16Params): + for group in opt.float16_groups: + for model_param in group: + model_param.main_grad = torch.randn_like(model_param.data) + for group in opt.fp32_from_fp32_groups: + for model_param in group: + model_param.main_grad = torch.randn_like(model_param.data) update_successful, grad_norm, num_zeros = optimizer.step() assert isinstance(update_successful, bool) @@ -215,21 +241,23 @@ def test_step_runs(self, tp: int, pp: int) -> None: @pytest.mark.parametrize('tp,pp', [(2, 2), (4, 1)]) @pytest.mark.parametrize('n_steps', [3, 5]) - def test_numerical_equivalence(self, tp: int, pp: int, n_steps: int) -> None: + @pytest.mark.parametrize('bf16_params', [True, False]) + def test_numerical_equivalence( + self, tp: int, pp: int, n_steps: int, bf16_params: bool + ) -> None: """Offloaded and non-offloaded optimizers produce bit-identical results. Runs both an offloaded and a non-offloaded optimizer for ``n_steps`` with identical random gradients, then verifies that fp32 master weights - and optimizer state tensors match exactly. This ensures the offload/reload - cycle introduces zero numerical drift. + and optimizer state tensors match exactly. """ if tp * pp > torch.cuda.device_count(): pytest.skip("Not enough GPUs") Utils.initialize_model_parallel(tp, pp) - model_off = _create_model(seed=42, tp=tp, pp=pp) - model_ref = _create_model(seed=42, tp=tp, pp=pp) + model_off = _create_model(seed=42, tp=tp, pp=pp, bf16_params=bf16_params) + model_ref = _create_model(seed=42, tp=tp, pp=pp, bf16_params=bf16_params) opt_off = _create_optimizer(model_off, cpu_offload=True) opt_ref = _create_optimizer(model_ref, cpu_offload=False) @@ -242,14 +270,24 @@ def test_numerical_equivalence(self, tp: int, pp: int, n_steps: int) -> None: for step_i in range(n_steps): torch.manual_seed(1000 + step_i + rank) - for p_off, p_ref in zip(model_off.parameters(), model_ref.parameters()): - if not p_off.requires_grad: - continue - g = torch.randn_like(p_off.data) - p_off.grad = g.clone() - p_off.main_grad = p_off.grad - p_ref.grad = g.clone() - p_ref.main_grad = p_ref.grad + # Set main_grad on all tracked params (same path as DDP grad buffers). + for fp16_opt_off, fp16_opt_ref in zip( + _iter_fp16_opts(opt_off), _iter_fp16_opts(opt_ref) + ): + for grp_off, grp_ref in zip( + fp16_opt_off.float16_groups, fp16_opt_ref.float16_groups + ): + for p_off, p_ref in zip(grp_off, grp_ref): + g = torch.randn_like(p_off.data) + p_off.main_grad = g.clone() + p_ref.main_grad = g.clone() + for grp_off, grp_ref in zip( + fp16_opt_off.fp32_from_fp32_groups, fp16_opt_ref.fp32_from_fp32_groups + ): + for p_off, p_ref in zip(grp_off, grp_ref): + g = torch.randn_like(p_off.data) + p_off.main_grad = g.clone() + p_ref.main_grad = g.clone() opt_off.step() opt_ref.step() @@ -280,3 +318,66 @@ def test_numerical_equivalence(self, tp: int, pp: int, n_steps: int) -> None: ) opt_off.offload_optimizer_states() + + @pytest.mark.parametrize('tp,pp', [(2, 2), (4, 1)]) + @pytest.mark.parametrize('bf16_params', [True, False]) + def test_prepare_grads_reloads_before_grad_copy( + self, tp: int, pp: int, bf16_params: bool + ) -> None: + """prepare_grads() must reload states to GPU before gradient assignment. + + This directly tests the fix for the reviewer-identified bug where + _copy_model_grads_to_main_grads assigns a CUDA grad tensor to the + fp32 main_param — which requires main_param.data to be on GPU (since + nn.Parameter forbids cross-device .data/.grad). + + With bf16_params=True, the float16_groups → fp32_from_float16_groups + path is exercised — this is the path that would RuntimeError without + the prepare_grads() fix. + """ + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") + + Utils.initialize_model_parallel(tp, pp) + model = _create_model(seed=7, tp=tp, pp=pp, bf16_params=bf16_params) + optimizer = _create_optimizer(model, cpu_offload=True) + + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + + # Verify states start on CPU after construction. + for opt in _iter_fp16_opts(optimizer): + for group in opt.fp32_from_float16_groups: + for param in group: + assert not param.data.is_cuda + + # Set main_grad (CUDA tensors) on all tracked params — this is what DDP does. + for opt in _iter_fp16_opts(optimizer): + for group in opt.float16_groups: + for model_param in group: + model_param.main_grad = torch.randn_like(model_param.data) + for group in opt.fp32_from_fp32_groups: + for model_param in group: + model_param.main_grad = torch.randn_like(model_param.data) + + # Call prepare_grads() — should reload states and then copy grads. + # Without the fix, this raises RuntimeError (cross-device grad assignment). + result = optimizer.prepare_grads() + assert isinstance(result, bool) + + # After prepare_grads, fp32 master params should be on GPU with grads set. + for opt in _iter_fp16_opts(optimizer): + for group in opt.fp32_from_float16_groups: + for param in group: + assert param.data.is_cuda, ( + "After prepare_grads, fp32 master weight must be on GPU" + ) + + # Now step_with_ready_grads should work and offload back. + optimizer.step_with_ready_grads() + + for opt in _iter_fp16_opts(optimizer): + for group in opt.fp32_from_float16_groups: + for param in group: + assert not param.data.is_cuda, ( + "After step_with_ready_grads, fp32 master weight must be on CPU" + ) From 6583dd4223280d87376b6bf1e0f561213748c588 Mon Sep 17 00:00:00 2001 From: pengdurice Date: Fri, 29 May 2026 21:18:01 +0000 Subject: [PATCH 11/15] additional tests Signed-off-by: pengdurice --- .../test_muon_cpu_offload.py | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py index d18fc933c43..933d492b5a8 100644 --- a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py +++ b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py @@ -113,6 +113,119 @@ def _iter_fp16_opts( yield opt +class TestAdamOffloadConfig: + """Verify Adam fallback offloading behavior in the legacy LayerWise path. + + In the legacy LayerWise path (use_distributed_optimizer=False), Adam params + feed INTO LayerWise as a child — LayerWise handles offloading, so Adam's own + HybridDeviceOptimizer must be disabled (optimizer_cpu_offload=False on the + fallback config). This prevents creating a HybridDeviceOptimizer for Adam + and avoids double-offloading. + + In the separate DistributedOptimizer path (use_distributed_optimizer=True), + Adam's DistOpt is a sibling of LayerWise — it manages its own offloading. + That path requires full DDP setup and is validated via integration tests. + """ + + def setup_method(self, method) -> None: + pass + + def teardown_method(self, method) -> None: + Utils.destroy_model_parallel() + + @pytest.mark.parametrize('tp,pp', [(2, 2)]) + def test_legacy_path_adam_not_hybrid_device_optimizer(self, tp: int, pp: int) -> None: + """In legacy path, Adam child is NOT a HybridDeviceOptimizer.""" + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") + + from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer + + Utils.initialize_model_parallel(tp, pp) + model = _create_model(seed=1, tp=tp, pp=pp, bf16_params=True) + + config = OptimizerConfig( + bf16=True, + params_dtype=torch.bfloat16, + use_distributed_optimizer=False, + use_layer_wise_distributed_optimizer=True, + optimizer='muon', + lr=0.0, + optimizer_cpu_offload=True, + ) + optimizer = get_megatron_optimizer(config, [model]) + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + + # In the legacy path, Adam params go inside LayerWise. The fallback config + # has optimizer_cpu_offload=False, so no HybridDeviceOptimizer is created. + # LayerWise handles offloading for all children uniformly. + for opt in optimizer.chained_optimizers: + if isinstance(opt, Float16OptimizerWithFloat16Params): + inner_opt = opt.optimizer + else: + inner_opt = opt + assert not isinstance(inner_opt, HybridDeviceOptimizer), ( + f"In legacy LayerWise path, Adam should NOT use " + f"HybridDeviceOptimizer (LayerWise manages offloading). " + f"Got {type(inner_opt).__name__}" + ) + + @pytest.mark.parametrize('tp,pp', [(2, 2)]) + def test_legacy_path_layerwise_manages_adam_offload(self, tp: int, pp: int) -> None: + """LayerWise offload/reload cycle covers Adam params in legacy path.""" + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") + + Utils.initialize_model_parallel(tp, pp) + model = _create_model(seed=1, tp=tp, pp=pp, bf16_params=True) + + config = OptimizerConfig( + bf16=True, + params_dtype=torch.bfloat16, + use_distributed_optimizer=False, + use_layer_wise_distributed_optimizer=True, + optimizer='muon', + lr=0.0, + optimizer_cpu_offload=True, + ) + optimizer = get_megatron_optimizer(config, [model]) + assert isinstance(optimizer, LayerWiseDistributedOptimizer) + assert optimizer._cpu_offload + + # After init, LayerWise offloads ALL children's fp32 master weights — + # including Adam-managed params (biases, layernorms, embeddings). + # Init state so optimizer.state has tensors. + for opt in optimizer.chained_optimizers: + init_fn = getattr(opt, 'init_state_fn', None) + if init_fn is None: + continue + if hasattr(opt, 'optimizer'): + init_fn(opt.optimizer) + else: + init_fn(opt) + optimizer.offload_optimizer_states() + + # Verify ALL fp32_from_float16_groups params across all children are on CPU. + found_any = False + for opt in _iter_fp16_opts(optimizer): + for group in opt.fp32_from_float16_groups: + for param in group: + found_any = True + assert not param.data.is_cuda, ( + "LayerWise should offload ALL children's master weights" + ) + assert found_any, "Expected at least some fp32 master weights to verify" + + # Reload and verify they're back on GPU. + optimizer.reload_optimizer_states() + for opt in _iter_fp16_opts(optimizer): + for group in opt.fp32_from_float16_groups: + for param in group: + assert param.data.is_cuda, "After reload, all should be on GPU" + + optimizer.offload_optimizer_states() + + class TestMuonCPUOffload: """Tests for Muon CPU offloading in LayerWiseDistributedOptimizer. From de640b858f3d994e0e63bd7167cd6130d6289149 Mon Sep 17 00:00:00 2001 From: pengdurice Date: Mon, 1 Jun 2026 16:40:22 +0000 Subject: [PATCH 12/15] address separate adam path test Signed-off-by: pengdurice --- .../test_muon_cpu_offload.py | 169 +++++++++++++++++- 1 file changed, 165 insertions(+), 4 deletions(-) diff --git a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py index 933d492b5a8..8c6c7dc8858 100644 --- a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py +++ b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py @@ -11,7 +11,9 @@ launcher) since LayerWiseDistributedOptimizer shards parameters across DP ranks. """ +from functools import partial from typing import Generator +from unittest import mock import pytest import torch @@ -22,7 +24,7 @@ from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.optimizer import get_megatron_optimizer from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer -from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params +from megatron.core.optimizer.optimizer import ChainedOptimizer, Float16OptimizerWithFloat16Params from megatron.core.optimizer.optimizer_config import OptimizerConfig from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed from megatron.core.transformer import TransformerConfig @@ -122,9 +124,8 @@ class TestAdamOffloadConfig: fallback config). This prevents creating a HybridDeviceOptimizer for Adam and avoids double-offloading. - In the separate DistributedOptimizer path (use_distributed_optimizer=True), - Adam's DistOpt is a sibling of LayerWise — it manages its own offloading. - That path requires full DDP setup and is validated via integration tests. + The separate DistributedOptimizer path (use_distributed_optimizer=True) is + tested in TestSeparateDistOptPath, which requires full DDP setup. """ def setup_method(self, method) -> None: @@ -494,3 +495,163 @@ def test_prepare_grads_reloads_before_grad_copy( assert not param.data.is_cuda, ( "After step_with_ready_grads, fp32 master weight must be on CPU" ) + + +class TestSeparateDistOptPath: + """Verify Adam offloading in the use_separate_distributed_optimizer=True path. + + In this path, LayerWise only owns Muon params, and Adam's DistributedOptimizer + is a sibling that manages its own CPU offloading via HybridDeviceOptimizer. + This test class requires DDP-wrapped models (via get_model) to trigger the + use_separate_distributed_optimizer codepath. + """ + + def setup_method(self, method) -> None: + pass + + def teardown_method(self, method) -> None: + Utils.destroy_model_parallel() + + def _setup_hybrid_optimizer(self, tp: int, pp: int): + """Create optimizer via the use_separate_distributed_optimizer=True path. + + Uses get_model (DDP wrapping) so model_chunks[0].ddp_config exists with + use_distributed_optimizer=True, which triggers the new path. + """ + from tests.unit_tests.dist_checkpointing.utils import ( + init_basic_mock_args, + initialize_gpt_model, + ) + from megatron.training.arguments import parse_args + from megatron.training.training import get_model + + Utils.initialize_model_parallel(tp, pp) + + mock_args = parse_args(ignore_unknown_args=True) + with mock.patch('megatron.training.training.get_args', new=lambda: mock_args): + init_basic_mock_args(mock_args, tp, pp, bf16=True) + mock_args.use_distributed_optimizer = True + mock_args.use_layer_wise_distributed_optimizer = True + mock_args.optimizer = 'muon' + model = get_model( + partial( + initialize_gpt_model, + seed=2, + tensor_model_parallel_size=tp, + pipeline_model_parallel_size=pp, + pipeline_dtype=torch.bfloat16, + bf16=True, + ) + ) + + config = OptimizerConfig( + bf16=True, + params_dtype=torch.bfloat16, + use_distributed_optimizer=True, + use_layer_wise_distributed_optimizer=True, + optimizer='muon', + lr=0.0, + optimizer_cpu_offload=True, + ) + optimizer = get_megatron_optimizer(config, model) + return model, optimizer + + @pytest.mark.parametrize('tp,pp', [(2, 2)]) + def test_new_path_adam_uses_hybrid_device_optimizer(self, tp: int, pp: int) -> None: + """In new path, Adam's DistOpt wraps a HybridDeviceOptimizer.""" + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") + + from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer + from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer + + _, optimizer = self._setup_hybrid_optimizer(tp, pp) + + # New path returns ChainedOptimizer([LayerWise, DistributedOptimizer]) + assert isinstance(optimizer, ChainedOptimizer) + + found_layer_wise = False + found_dist_opt_with_hybrid = False + for child in optimizer.chained_optimizers: + if isinstance(child, LayerWiseDistributedOptimizer): + found_layer_wise = True + # LayerWise should NOT contain HybridDeviceOptimizer children + for sub_opt in child.chained_optimizers: + if isinstance(sub_opt, Float16OptimizerWithFloat16Params): + assert not isinstance(sub_opt.optimizer, HybridDeviceOptimizer), ( + "In new path, LayerWise children should NOT use " + "HybridDeviceOptimizer (LayerWise only owns Muon params)" + ) + elif isinstance(child, DistributedOptimizer): + # Adam's DistOpt should use HybridDeviceOptimizer for CPU offload + if isinstance(child.optimizer, HybridDeviceOptimizer): + found_dist_opt_with_hybrid = True + + assert found_layer_wise, "Expected a LayerWiseDistributedOptimizer in the chain" + assert found_dist_opt_with_hybrid, ( + "Expected Adam's DistributedOptimizer to use HybridDeviceOptimizer " + "for CPU offloading in the use_separate_distributed_optimizer=True path" + ) + + @pytest.mark.parametrize('tp,pp', [(2, 2)]) + def test_new_path_layerwise_only_offloads_muon_params(self, tp: int, pp: int) -> None: + """In new path, LayerWise offload/reload only touches Muon params.""" + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") + + from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer + + _, optimizer = self._setup_hybrid_optimizer(tp, pp) + assert isinstance(optimizer, ChainedOptimizer) + + # Find LayerWise and DistOpt children + layer_wise = None + dist_opt = None + for child in optimizer.chained_optimizers: + if isinstance(child, LayerWiseDistributedOptimizer): + layer_wise = child + elif isinstance(child, DistributedOptimizer): + dist_opt = child + + assert layer_wise is not None + assert dist_opt is not None + + # Init states for LayerWise children + for opt in layer_wise.chained_optimizers: + init_fn = getattr(opt, 'init_state_fn', None) + if init_fn is None: + continue + if hasattr(opt, 'optimizer'): + init_fn(opt.optimizer) + else: + init_fn(opt) + + # Offload LayerWise states + layer_wise.offload_optimizer_states() + + # LayerWise's children (Muon) should now be on CPU + found_muon_on_cpu = False + for opt in layer_wise.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if isinstance(opt, Float16OptimizerWithFloat16Params): + for group in opt.fp32_from_float16_groups: + for param in group: + found_muon_on_cpu = True + assert not param.data.is_cuda, ( + "After LayerWise offload, Muon master weights should be on CPU" + ) + + assert found_muon_on_cpu, "Expected Muon master weights in LayerWise" + + # Reload and verify back on GPU + layer_wise.reload_optimizer_states() + for opt in layer_wise.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if isinstance(opt, Float16OptimizerWithFloat16Params): + for group in opt.fp32_from_float16_groups: + for param in group: + assert param.data.is_cuda, ( + "After LayerWise reload, Muon master weights should be on GPU" + ) From 4532e68313cb47e1ec24e53426c8aea239402fba Mon Sep 17 00:00:00 2001 From: pengdurice Date: Mon, 1 Jun 2026 22:54:12 +0000 Subject: [PATCH 13/15] address comments Signed-off-by: pengdurice --- .../test_muon_cpu_offload.py | 63 ++++++++++++++++++- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py index 8c6c7dc8858..33adcb2cf02 100644 --- a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py +++ b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py @@ -502,8 +502,10 @@ class TestSeparateDistOptPath: In this path, LayerWise only owns Muon params, and Adam's DistributedOptimizer is a sibling that manages its own CPU offloading via HybridDeviceOptimizer. - This test class requires DDP-wrapped models (via get_model) to trigger the - use_separate_distributed_optimizer codepath. + + These tests use get_model (DDP wrapping) to trigger the new codepath. + optimizer_offload_fraction=1.0 ensures all Adam params are physically + offloaded to CPU, allowing behavioral (not just structural) assertions. """ def setup_method(self, method) -> None: @@ -517,6 +519,8 @@ def _setup_hybrid_optimizer(self, tp: int, pp: int): Uses get_model (DDP wrapping) so model_chunks[0].ddp_config exists with use_distributed_optimizer=True, which triggers the new path. + optimizer_offload_fraction=1.0 so HybridDeviceOptimizer offloads all + Adam params to CPU (not just structurally present). """ from tests.unit_tests.dist_checkpointing.utils import ( init_basic_mock_args, @@ -552,6 +556,7 @@ def _setup_hybrid_optimizer(self, tp: int, pp: int): optimizer='muon', lr=0.0, optimizer_cpu_offload=True, + optimizer_offload_fraction=1.0, ) optimizer = get_megatron_optimizer(config, model) return model, optimizer @@ -593,6 +598,60 @@ def test_new_path_adam_uses_hybrid_device_optimizer(self, tp: int, pp: int) -> N "for CPU offloading in the use_separate_distributed_optimizer=True path" ) + @pytest.mark.parametrize('tp,pp', [(2, 2)]) + def test_new_path_adam_params_physically_on_cpu(self, tp: int, pp: int) -> None: + """Adam params are physically offloaded to CPU by HybridDeviceOptimizer. + + With optimizer_offload_fraction=1.0, all Adam params managed by + HybridDeviceOptimizer should have CPU copies. This is a behavioral + assertion (not just structural) that the offloading actually happens. + """ + if tp * pp > torch.cuda.device_count(): + pytest.skip("Not enough GPUs") + + from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer + from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer + + _, optimizer = self._setup_hybrid_optimizer(tp, pp) + assert isinstance(optimizer, ChainedOptimizer) + + # Find Adam's DistOpt with HybridDeviceOptimizer + hybrid_opt = None + for child in optimizer.chained_optimizers: + if isinstance(child, DistributedOptimizer): + if isinstance(child.optimizer, HybridDeviceOptimizer): + hybrid_opt = child.optimizer + break + + assert hybrid_opt is not None, "HybridDeviceOptimizer not found" + + # With offload_fraction=1.0, all params should be routed to CPU optimizers. + # gpu_params_map_cpu_copy maps GPU param → CPU copy for offloaded params. + assert len(hybrid_opt.gpu_params_map_cpu_copy) > 0, ( + "With offload_fraction=1.0, HybridDeviceOptimizer should have " + "offloaded params (gpu_params_map_cpu_copy is empty)" + ) + + # Verify the CPU copies are actually on CPU + for gpu_param, cpu_copy in hybrid_opt.gpu_params_map_cpu_copy.items(): + assert not cpu_copy.is_cuda, ( + f"CPU copy of offloaded Adam param should be on CPU, got {cpu_copy.device}" + ) + assert gpu_param.is_cuda, ( + f"Original Adam param should remain on GPU, got {gpu_param.device}" + ) + + # Verify CPU optimizers exist and have params + assert len(hybrid_opt.cpu_optimizers) > 0, ( + "With offload_fraction=1.0, there should be CPU optimizer(s)" + ) + cpu_param_count = sum( + sum(1 for _ in group['params']) + for opt in hybrid_opt.cpu_optimizers + for group in opt.param_groups + ) + assert cpu_param_count > 0, "CPU optimizers should have params assigned" + @pytest.mark.parametrize('tp,pp', [(2, 2)]) def test_new_path_layerwise_only_offloads_muon_params(self, tp: int, pp: int) -> None: """In new path, LayerWise offload/reload only touches Muon params.""" From 742db5eac072da03d6beb566eb3f90d61ec26f3e Mon Sep 17 00:00:00 2001 From: pengdurice Date: Wed, 3 Jun 2026 17:44:12 +0000 Subject: [PATCH 14/15] fix claude comments Signed-off-by: pengdurice --- .../core/optimizer/layer_wise_optimizer.py | 10 +- .../test_layer_wise_optimizer.py | 147 ++++++++++++++++++ 2 files changed, 156 insertions(+), 1 deletion(-) diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index 3343167c168..1a5469bf1cd 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -849,9 +849,14 @@ def reload_optimizer_states(self) -> None: from CPU pinned memory to the current CUDA device. A ``torch.cuda.synchronize()`` at exit ensures all H2D transfers complete before the optimizer step proceeds. + This method is idempotent: if all tensors are already on GPU (e.g. when called + a second time from ``step_with_ready_grads`` after ``prepare_grads`` already + reloaded), no synchronization barrier is issued. + Called at the start of each optimizer step so that the Newton-Schulz iterations (Muon) or Adam updates can operate on GPU tensors. """ + any_moved = False for opt in self.chained_optimizers: if getattr(opt, 'is_stub_optimizer', False): continue @@ -861,11 +866,14 @@ def reload_optimizer_states(self) -> None: for param in group: if not param.data.is_cuda: param.data = param.data.to('cuda') + any_moved = True for state_vals in opt.optimizer.state.values(): for key, val in state_vals.items(): if isinstance(val, torch.Tensor) and not val.is_cuda: state_vals[key] = val.to('cuda') - torch.cuda.synchronize() + any_moved = True + if any_moved: + torch.cuda.synchronize() # TODO(deyuf): need to improve dist checkpointing design to properly handle this # fp32_from_fp16_params is list, each sub list could be empty if group is empty diff --git a/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py b/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py index 42ef0a401ee..93a420ac370 100644 --- a/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py +++ b/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py @@ -675,3 +675,150 @@ def test_optimizer_common_state_dict_hybrid(self, tmp_path_dist_ckpt, tp, pp): check_equal(optim_param_state_A, optim_param_state_B) Utils.destroy_model_parallel() + + @pytest.mark.parametrize('tp', [1, 2]) + @pytest.mark.parametrize('pp', [1, 2]) + @pytest.mark.parametrize('bf16', [True]) + def test_layer_wise_optimizer_save_load_cpu_offload( + self, tmp_path_dist_ckpt, tp, pp, bf16 + ): + """Test save/load of LayerWiseDistributedOptimizer with cpu_offload=True. + + Runs an optimizer step to produce non-trivial post-update momentum + states, then saves a checkpoint with states offloaded to CPU. Loads + into a fresh optimizer, verifies states are correctly placed on CPU, + confirms a subsequent step succeeds after load, and checks checkpoint + A == B equivalence. + """ + if tp * pp > 8: + pytest.skip(f"TP*PP > 8 is larger than world size") + + Utils.initialize_model_parallel(tp, pp) + + with TempNamedDir( + tmp_path_dist_ckpt / 'test_layer_wise_cpu_offload_A', sync=True + ) as ckpt_dir_A: + with TempNamedDir( + tmp_path_dist_ckpt / 'test_layer_wise_cpu_offload_B', sync=True + ) as ckpt_dir_B: + from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params + + # Create model and optimizer A with cpu_offload + model_A, optimizer_A = setup_model_and_optimizer( + seed=2, + tp=tp, + pp=pp, + bf16=bf16, + dist_opt=False, + initialize_fn=initialize_gpt_model, + optimizer='dist_muon', + ) + assert isinstance(optimizer_A, LayerWiseDistributedOptimizer) + optimizer_A._cpu_offload = True + + # Run a step to produce non-trivial optimizer state (momentum) + for opt in optimizer_A.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if isinstance(opt, Float16OptimizerWithFloat16Params): + for group in opt.float16_groups: + for p in group: + p.main_grad = torch.randn_like(p.data) + for group in opt.fp32_from_fp32_groups: + for p in group: + p.main_grad = torch.randn_like(p.data) + optimizer_A.step() + + # After step, states should be on CPU (offloaded by step) + for opt in optimizer_A.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if isinstance(opt, Float16OptimizerWithFloat16Params): + for group in opt.fp32_from_float16_groups: + for param in group: + assert not param.data.is_cuda + + # Reload to GPU for saving + optimizer_A.reload_optimizer_states() + + # Save checkpoint A + model_sharded_sd_A = model_A[0].sharded_state_dict() + optim_sd_A = optimizer_A.sharded_state_dict(model_sharded_sd_A) + save(optim_sd_A, ckpt_dir_A) + + # Offload again after save + optimizer_A.offload_optimizer_states() + + # Create model and optimizer B with different seed + model_B, optimizer_B = setup_model_and_optimizer( + seed=3, + tp=tp, + pp=pp, + bf16=bf16, + dist_opt=False, + initialize_fn=initialize_gpt_model, + optimizer='dist_muon', + ) + assert isinstance(optimizer_B, LayerWiseDistributedOptimizer) + optimizer_B._cpu_offload = True + + # Load checkpoint A into optimizer B + model_sharded_sd_B = model_B[0].sharded_state_dict() + load_sharded_sd = optimizer_B.sharded_state_dict( + model_sharded_sd_B, is_loading=True + ) + state_dict = load(load_sharded_sd, ckpt_dir_A) + optimizer_B.load_state_dict(state_dict) + + # Offload after load — key test: load_state_dict + offload works + optimizer_B.offload_optimizer_states() + for opt in optimizer_B.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if isinstance(opt, Float16OptimizerWithFloat16Params): + for group in opt.fp32_from_float16_groups: + for param in group: + assert not param.data.is_cuda, ( + "After load + offload, master weights must be on CPU" + ) + + # Reload and save as checkpoint B (before any extra step) + optimizer_B.reload_optimizer_states() + optim_sd_B = optimizer_B.sharded_state_dict(model_sharded_sd_B) + save(optim_sd_B, ckpt_dir_B) + + # Run a step after load to verify training can continue + for opt in optimizer_B.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if isinstance(opt, Float16OptimizerWithFloat16Params): + for group in opt.float16_groups: + for p in group: + p.main_grad = torch.randn_like(p.data) + for group in opt.fp32_from_fp32_groups: + for p in group: + p.main_grad = torch.randn_like(p.data) + update_successful, _, _ = optimizer_B.step() + assert update_successful, "Step after checkpoint load should succeed" + + # After post-load step, states should be offloaded to CPU + for opt in optimizer_B.chained_optimizers: + if getattr(opt, 'is_stub_optimizer', False): + continue + if isinstance(opt, Float16OptimizerWithFloat16Params): + for group in opt.fp32_from_float16_groups: + for param in group: + assert not param.data.is_cuda, ( + "After post-load step, states must be on CPU" + ) + + Utils.destroy_model_parallel() + + # Compare checkpoints A==B (same state, different paths) + Utils.initialize_model_parallel(1, 1) + from megatron.core.dist_checkpointing import load_plain_tensors + + plain_sd_A = load_plain_tensors(ckpt_dir_A) + plain_sd_B = load_plain_tensors(ckpt_dir_B) + + check_equal(plain_sd_A, plain_sd_B) From 08bedf65199945ee1235ae11afd79ebabebbdeaf Mon Sep 17 00:00:00 2001 From: janEbert Date: Fri, 5 Jun 2026 11:30:47 +0200 Subject: [PATCH 15/15] Fix linting errors Ran `tools/autoformat.sh`. --- .../test_layer_wise_optimizer.py | 16 +++-- .../test_muon_cpu_offload.py | 58 +++++++++---------- 2 files changed, 35 insertions(+), 39 deletions(-) diff --git a/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py b/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py index 93a420ac370..070f73b0b55 100644 --- a/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py +++ b/tests/unit_tests/dist_checkpointing/test_layer_wise_optimizer.py @@ -679,9 +679,7 @@ def test_optimizer_common_state_dict_hybrid(self, tmp_path_dist_ckpt, tp, pp): @pytest.mark.parametrize('tp', [1, 2]) @pytest.mark.parametrize('pp', [1, 2]) @pytest.mark.parametrize('bf16', [True]) - def test_layer_wise_optimizer_save_load_cpu_offload( - self, tmp_path_dist_ckpt, tp, pp, bf16 - ): + def test_layer_wise_optimizer_save_load_cpu_offload(self, tmp_path_dist_ckpt, tp, pp, bf16): """Test save/load of LayerWiseDistributedOptimizer with cpu_offload=True. Runs an optimizer step to produce non-trivial post-update momentum @@ -778,9 +776,9 @@ def test_layer_wise_optimizer_save_load_cpu_offload( if isinstance(opt, Float16OptimizerWithFloat16Params): for group in opt.fp32_from_float16_groups: for param in group: - assert not param.data.is_cuda, ( - "After load + offload, master weights must be on CPU" - ) + assert ( + not param.data.is_cuda + ), "After load + offload, master weights must be on CPU" # Reload and save as checkpoint B (before any extra step) optimizer_B.reload_optimizer_states() @@ -808,9 +806,9 @@ def test_layer_wise_optimizer_save_load_cpu_offload( if isinstance(opt, Float16OptimizerWithFloat16Params): for group in opt.fp32_from_float16_groups: for param in group: - assert not param.data.is_cuda, ( - "After post-load step, states must be on CPU" - ) + assert ( + not param.data.is_cuda + ), "After post-load step, states must be on CPU" Utils.destroy_model_parallel() diff --git a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py index 33adcb2cf02..976a6af9433 100644 --- a/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py +++ b/tests/unit_tests/dist_checkpointing/test_muon_cpu_offload.py @@ -17,7 +17,6 @@ import pytest import torch -from tests.unit_tests.test_utilities import Utils from megatron.core import parallel_state from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec @@ -28,6 +27,7 @@ from megatron.core.optimizer.optimizer_config import OptimizerConfig from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed from megatron.core.transformer import TransformerConfig +from tests.unit_tests.test_utilities import Utils def _create_model(seed: int, tp: int, pp: int, bf16_params: bool = True) -> GPTModel: @@ -212,9 +212,9 @@ def test_legacy_path_layerwise_manages_adam_offload(self, tp: int, pp: int) -> N for group in opt.fp32_from_float16_groups: for param in group: found_any = True - assert not param.data.is_cuda, ( - "LayerWise should offload ALL children's master weights" - ) + assert ( + not param.data.is_cuda + ), "LayerWise should offload ALL children's master weights" assert found_any, "Expected at least some fp32 master weights to verify" # Reload and verify they're back on GPU. @@ -356,9 +356,7 @@ def test_step_runs(self, tp: int, pp: int, bf16_params: bool) -> None: @pytest.mark.parametrize('tp,pp', [(2, 2), (4, 1)]) @pytest.mark.parametrize('n_steps', [3, 5]) @pytest.mark.parametrize('bf16_params', [True, False]) - def test_numerical_equivalence( - self, tp: int, pp: int, n_steps: int, bf16_params: bool - ) -> None: + def test_numerical_equivalence(self, tp: int, pp: int, n_steps: int, bf16_params: bool) -> None: """Offloaded and non-offloaded optimizers produce bit-identical results. Runs both an offloaded and a non-offloaded optimizer for ``n_steps`` @@ -482,9 +480,9 @@ def test_prepare_grads_reloads_before_grad_copy( for opt in _iter_fp16_opts(optimizer): for group in opt.fp32_from_float16_groups: for param in group: - assert param.data.is_cuda, ( - "After prepare_grads, fp32 master weight must be on GPU" - ) + assert ( + param.data.is_cuda + ), "After prepare_grads, fp32 master weight must be on GPU" # Now step_with_ready_grads should work and offload back. optimizer.step_with_ready_grads() @@ -492,9 +490,9 @@ def test_prepare_grads_reloads_before_grad_copy( for opt in _iter_fp16_opts(optimizer): for group in opt.fp32_from_float16_groups: for param in group: - assert not param.data.is_cuda, ( - "After step_with_ready_grads, fp32 master weight must be on CPU" - ) + assert ( + not param.data.is_cuda + ), "After step_with_ready_grads, fp32 master weight must be on CPU" class TestSeparateDistOptPath: @@ -522,12 +520,12 @@ def _setup_hybrid_optimizer(self, tp: int, pp: int): optimizer_offload_fraction=1.0 so HybridDeviceOptimizer offloads all Adam params to CPU (not just structurally present). """ + from megatron.training.arguments import parse_args + from megatron.training.training import get_model from tests.unit_tests.dist_checkpointing.utils import ( init_basic_mock_args, initialize_gpt_model, ) - from megatron.training.arguments import parse_args - from megatron.training.training import get_model Utils.initialize_model_parallel(tp, pp) @@ -634,17 +632,17 @@ def test_new_path_adam_params_physically_on_cpu(self, tp: int, pp: int) -> None: # Verify the CPU copies are actually on CPU for gpu_param, cpu_copy in hybrid_opt.gpu_params_map_cpu_copy.items(): - assert not cpu_copy.is_cuda, ( - f"CPU copy of offloaded Adam param should be on CPU, got {cpu_copy.device}" - ) - assert gpu_param.is_cuda, ( - f"Original Adam param should remain on GPU, got {gpu_param.device}" - ) + assert ( + not cpu_copy.is_cuda + ), f"CPU copy of offloaded Adam param should be on CPU, got {cpu_copy.device}" + assert ( + gpu_param.is_cuda + ), f"Original Adam param should remain on GPU, got {gpu_param.device}" # Verify CPU optimizers exist and have params - assert len(hybrid_opt.cpu_optimizers) > 0, ( - "With offload_fraction=1.0, there should be CPU optimizer(s)" - ) + assert ( + len(hybrid_opt.cpu_optimizers) > 0 + ), "With offload_fraction=1.0, there should be CPU optimizer(s)" cpu_param_count = sum( sum(1 for _ in group['params']) for opt in hybrid_opt.cpu_optimizers @@ -697,9 +695,9 @@ def test_new_path_layerwise_only_offloads_muon_params(self, tp: int, pp: int) -> for group in opt.fp32_from_float16_groups: for param in group: found_muon_on_cpu = True - assert not param.data.is_cuda, ( - "After LayerWise offload, Muon master weights should be on CPU" - ) + assert ( + not param.data.is_cuda + ), "After LayerWise offload, Muon master weights should be on CPU" assert found_muon_on_cpu, "Expected Muon master weights in LayerWise" @@ -711,6 +709,6 @@ def test_new_path_layerwise_only_offloads_muon_params(self, tp: int, pp: int) -> if isinstance(opt, Float16OptimizerWithFloat16Params): for group in opt.fp32_from_float16_groups: for param in group: - assert param.data.is_cuda, ( - "After LayerWise reload, Muon master weights should be on GPU" - ) + assert ( + param.data.is_cuda + ), "After LayerWise reload, Muon master weights should be on GPU"