6 changes: 4 additions & 2 deletions megatron/core/optimizer/__init__.py
@@ -492,8 +492,8 @@ def _get_megatron_optimizer_based_on_param_groups(
         raise ValueError(
             "skip_megatron_wrapping=True is incompatible with use_precision_aware_optimizer."
         )
-    if skip_megatron_wrapping and config.optimizer_cpu_offload:
-        raise ValueError("skip_megatron_wrapping=True is incompatible with optimizer_cpu_offload.")
+    # NOTE: skip_megatron_wrapping + optimizer_cpu_offload is allowed for Muon,
+    # where LayerWiseDistributedOptimizer handles CPU offloading itself.

     # When freezing sub-models we may have no trainable parameters on a rank and
     # hence an empty param_groups. However, we still need to create an optimizer
@@ -822,6 +822,8 @@ def _get_megatron_emerging_optimizer(
     fallback_config = copy.copy(config)
     fallback_config.optimizer = opt_name
     fallback_config.use_distributed_optimizer = False
+    if use_layer_wise:
+        fallback_config.optimizer_cpu_offload = False
     result = _get_megatron_optimizer_based_on_param_groups(
         config=fallback_config,
         model_chunks=model_chunks,
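For context, the fallback path copies the user's config for the inner per-layer optimizers and strips options they should not act on; since LayerWiseDistributedOptimizer now handles CPU offloading itself, the copy also clears optimizer_cpu_offload. A minimal sketch of the pattern with a simplified stand-in config (the dataclass and helper name here are illustrative, not Megatron's actual classes):

import copy
from dataclasses import dataclass

@dataclass
class SimpleOptimizerConfig:
    # Stand-in for megatron.core.optimizer.OptimizerConfig, reduced to the
    # fields touched by the fallback path in the diff above.
    optimizer: str = 'adam'
    use_distributed_optimizer: bool = True
    optimizer_cpu_offload: bool = False

def make_fallback_config(config, opt_name, use_layer_wise):
    # Shallow copy so the caller's config object stays untouched.
    fallback = copy.copy(config)
    fallback.optimizer = opt_name
    # The inner per-layer optimizers are plain, non-distributed optimizers.
    fallback.use_distributed_optimizer = False
    if use_layer_wise:
        # The layer-wise wrapper offloads/reloads states around step() itself,
        # so the inner optimizers must not also try to offload.
        fallback.optimizer_cpu_offload = False
    return fallback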
56 changes: 53 additions & 3 deletions megatron/core/optimizer/layer_wise_optimizer.py
@@ -93,6 +93,11 @@ def __init__(

         super().__init__(optimizers)

+        self._cpu_offload = getattr(config, 'optimizer_cpu_offload', False)
+        if self._cpu_offload:
+            self.offload_optimizer_states()
+            logger.info('[layerwise] optimizer states CPU offloading enabled')
+
         # TODO(kunlun, deyuf): potential future perf optimization
         # since allreduce is unchanged and handled by megatron DDP, they're already in
         # contiguous gbuf. So instead of shard param by layer randomly, we can shard by
@@ -277,15 +282,60 @@ def count_zeros(self):
     @torch.no_grad()
     def step(self):  # type: ignore[no-untyped-def]
         """step function for layer-wise optimizer."""
+        if self._cpu_offload:
+            self.reload_optimizer_states()
+
         update_successful, grad_norm, num_zeros_in_grad = super().step()

-        # All gather updated params. If async_allgather is True, the allgather
-        # is deferred to the forward pre-hooks via DDP bucket infrastructure.
+        # Synchronize updated params across DP ranks. If async_allgather is
+        # True, the sync is deferred to forward pre-hooks via DDP bucket
+        # infrastructure. Otherwise use broadcast_params() which does
+        # per-parameter in-place broadcasts with zero extra GPU memory,
+        # avoiding OOM from allgather_params()'s temporary flat buffers.
         if not self.async_allgather:
-            self.allgather_params()
+            self.broadcast_params()

+        if self._cpu_offload:
+            self.offload_optimizer_states()
+
         return update_successful, grad_norm, num_zeros_in_grad

+    @torch.no_grad()
+    def offload_optimizer_states(self):
+        """Move fp32 master weights and optimizer states to CPU pinned memory."""
+        torch.cuda.synchronize()
+        for opt in self.chained_optimizers:
+            if getattr(opt, 'is_stub_optimizer', False):
+                continue
+            if not isinstance(opt, Float16OptimizerWithFloat16Params):
+                continue
+            for group in opt.fp32_from_float16_groups:
+                for param in group:
+                    if param.data.is_cuda:
+                        param.data = param.data.cpu().pin_memory()
+            for state_vals in opt.optimizer.state.values():
+                for key, val in state_vals.items():
+                    if isinstance(val, torch.Tensor) and val.is_cuda:
+                        state_vals[key] = val.cpu().pin_memory()
+
+    @torch.no_grad()
+    def reload_optimizer_states(self):
+        """Move fp32 master weights and optimizer states back to GPU."""
+        for opt in self.chained_optimizers:
+            if getattr(opt, 'is_stub_optimizer', False):
+                continue
+            if not isinstance(opt, Float16OptimizerWithFloat16Params):
+                continue
+            for group in opt.fp32_from_float16_groups:
+                for param in group:
+                    if not param.data.is_cuda:
+                        param.data = param.data.to('cuda')
+            for state_vals in opt.optimizer.state.values():
+                for key, val in state_vals.items():
+                    if isinstance(val, torch.Tensor) and not val.is_cuda:
+                        state_vals[key] = val.to('cuda')
+        torch.cuda.synchronize()
+
     # TODO(deyuf): need to improve dist checkpointing design to properly handle this
     # fp32_from_fp16_params is list, each sub list could be empty if group is empty
     # this breaks dist checkpointing assumption since extract_sharded_base drop list structure
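The offload/reload pair relies on pinned (page-locked) host memory: it makes the host-to-device copy on reload faster and allows asynchronous copies. A self-contained sketch of the same round trip on a plain state dict (illustrative only; the non_blocking reload is an optional variation that pinned buffers enable, while the diff above uses a plain .to('cuda')):

import torch

def offload_state(state):
    # Move CUDA tensors in an optimizer-style state dict to pinned CPU memory.
    torch.cuda.synchronize()  # ensure kernels that write these tensors have finished
    for key, val in state.items():
        if isinstance(val, torch.Tensor) and val.is_cuda:
            state[key] = val.cpu().pin_memory()

def reload_state(state):
    # Move tensors back to the GPU; pinned sources allow fast, optionally
    # asynchronous (non_blocking=True) host-to-device copies.
    for key, val in state.items():
        if isinstance(val, torch.Tensor) and not val.is_cuda:
            state[key] = val.to('cuda', non_blocking=True)
    torch.cuda.synchronize()  # copies must complete before the optimizer uses them

state = {'exp_avg': torch.randn(1024, device='cuda'), 'step': 10}
offload_state(state)
assert not state['exp_avg'].is_cuda and state['exp_avg'].is_pinned()
reload_state(state)
assert state['exp_avg'].is_cuda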
38 changes: 38 additions & 0 deletions muon_cpu_offloading.slurm
@@ -0,0 +1,38 @@
#!/bin/bash
#SBATCH --job-name=muon-cpu-offload
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=4
#SBATCH --output=./logs/muon_cpu_offload-%j.out
#SBATCH --error=./logs/muon_cpu_offload-%j.err
#SBATCH --exclusive
#SBATCH --time=0:30:00
#SBATCH --partition=h200-hi-preempt
#SBATCH --exclude=ip-10-1-115-124,ip-10-1-45-16
#SBATCH --wait-all-nodes=1

set -euxo pipefail

IMAGE=/fsx/peng/containers/nemo-rl-v0.5.2-te-tilelang.sqsh

MEGATRON_SRC=/fsx/peng/workspace/megatron-muon-cpu-offload/Megatron-LM
MEGATRON_DST=/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM

MOUNTS=${MEGATRON_SRC}:${MEGATRON_DST}

srun \
    --container-image "${IMAGE}" \
    --container-mounts "${MOUNTS}" \
    --container-remap-root \
    bash -lc "
        set -euxo pipefail
        cd ${MEGATRON_DST}

        export NVIDIA_PYTORCH_VERSION=25.06
        export PYTHONPATH=${MEGATRON_DST}:\${PYTHONPATH:-}
        pip install git+https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git@v0.2.0

        torchrun --nproc-per-node=4 tests/test_muon_cpu_offload.py
    "

echo "Test completed."