Skip to content

Qwen3.5-27B: OOM when saving checkpoint #94

@mamazi0131

Description

@mamazi0131

With tp=4, the model trains normally, but an out-of-memory (OOM) error occurs while saving a checkpoint. With tp=8, training instead fails while computing the reference log-probs (`compute_ref_log_prob`), with the following error:

File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/trainer/ppo/ray_trainer.py", line 1437, in fit
ref_log_prob = self._compute_ref_log_prob(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/trainer/ppo/ray_trainer.py", line 1125, in _compute_ref_log_prob
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/single_controller/ray/base.py", line 55, in __call__
output = ray.get(output)
^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^
ray.exceptions.RayTaskError: ray::WorkerDict.ref_compute_ref_log_prob() (pid=193257, ip=29.160.49.68, actor_id=a3a4794062e7929dfc45cd5401000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f6bbdbb1f70>)
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 456, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/single_controller/ray/base.py", line 932, in func
return getattr(self.worker_dict[key], name)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/single_controller/base/decorator.py", line 427, in inner
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/utils/profiler/performance.py", line 105, in f
return self.log(decorated_function, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/utils/profiler/performance.py", line 118, in log
output = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/utils/profiler/profile.py", line 173, in wrapper
return func(self_instance, *args, **kwargs_inner)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/workers/megatron_workers.py", line 858, in compute_ref_log_prob
output, _, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/utils/profiler/performance.py", line 105, in f
return self.log(decorated_function, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/utils/profiler/performance.py", line 118, in log
output = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/workers/actor/megatron_actor.py", line 259, in compute_log_prob
output = self.forward_backward_batch(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/workers/actor/megatron_actor.py", line 733, in forward_backward_batch
losses_reduced = forward_backward_func(
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/pipeline_parallel/schedules.py", line 636, in forward_backward_no_pipelining
output_tensor, num_tokens = forward_step(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/pipeline_parallel/schedules.py", line 423, in forward_step
output_tensor, loss_func = forward_step_func(data_iterator, model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/workers/actor/megatron_actor.py", line 683, in forward_step
output = forward_fn(
^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/models/mcore/model_forward.py", line 141, in model_forward
output_orig = model(
^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/module.py", line 489, in forward
outputs = self.module(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/mbridge/models/qwen3_5/model.py", line 367, in forward
output = self.language_model(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/models/gpt/gpt_model.py", line 525, in forward
hidden_states = self.decoder(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_block.py", line 619, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/module.py", line 352, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_block.py", line 765, in forward
hidden_states, context = layer(
^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_layer.py", line 1217, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/module.py", line 352, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_layer.py", line 513, in forward
hidden_states, context = self._forward_attention(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_layer.py", line 597, in _forward_attention
attention_output_with_bias = self.self_attention(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/mbridge/models/qwen3_5/attention.py", line 360, in forward
core_attn_out = self._apply_output_gate(core_attn_out, gate)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 953, in compile_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 2202, in __call__
result = self._torchdynamo_orig_backend(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1945, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 707, in __call__
result = _compile(
^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1752, in _compile
guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_utils_internal.py", line 97, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1433, in compile_inner
return _compile_inner(code, one_graph, hooks)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1467, in _compile_inner
dynamo_output = compile_frame(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1341, in compile_frame
bytecode, tracer_output = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1600, in transform_code_object
tracer_output = transformations(instructions, code_options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1313, in transform
tracer_output = trace_frame(
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 328, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 838, in trace_frame
run_tracer()
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 819, in run_tracer
tracer.run()
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1654, in run
while self.step():
^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1334, in step
self.dispatch_table[inst.opcode](self, inst)
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 866, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2582, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1240, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/misc.py", line 1148, in call_function
return self.obj.call_method(tx, self.name, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/tensor.py", line 745, in call_method
return wrap_fx_proxy(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2795, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2861, in wrap_fx_proxy_cls
out = _wrap_fx_proxy(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2972, in _wrap_fx_proxy
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3626, in get_fake_value
raise TorchRuntimeError(msg).with_traceback(e.traceback) from None
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3524, in get_fake_value
ret_val = wrap_fake_exception(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 2966, in wrap_fake_exception
return fn()
^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3525, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3735, in run_node
raise RuntimeError(make_error_message(e)).with_traceback(
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3705, in run_node
return getattr(args[0], node.target)(*args[1:], **kwargs) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 29, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1397, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2155, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1544, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2793, in _dispatch_impl
op_impl_out = op_impl(self, func, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_impls.py", line 180, in dispatch_to_op_implementations_dict
return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_impls.py", line 629, in _view_meta
return torch._refs._reshape_view_helper(a, shape, allow_copy=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_refs/__init__.py", line 3950, in _reshape_view_helper
shape = utils.infer_size(shape, a.numel())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_prims_common/__init__.py", line 1064, in infer_size
torch._check(
File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1732, in _check
_check_with(RuntimeError, cond, message) # pyrefly: ignore [bad-argument-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1714, in _check_with
raise error_type(message_evaluated)
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_method view(
(FakeTensor(..., device='cuda:0', size=(432, 1, 6, 256), dtype=torch.bfloat16), 432, 1, 768), **{}): got RuntimeError("shape '[432, 1, 768]' is invalid for input of size 663552")

from user code:
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/attention.py", line 1221, in _apply_output_gate
gate = gate.view(*x.shape)

My experimental setup is as follows:

export PYTHONPATH="/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl:$PYTHONPATH"
export CUDA_DEVICE_MAX_CONNECTIONS=1
export VLLM_USE_V1=1
export VLLM_ALLREDUCE_USE_SYMM_MEM=0
export SWANLAB_API_KEY=xxx
export SWANLAB_MODE=cloud

# Build the host lists from NODE_IP_LIST.
# NOTE(review): assumes the format "ip1:slots1,ip2:slots2,..." — TODO confirm.
#   hostfile   -> one "ip slots=N" line per node
#   pssh.hosts -> one bare "ip" line per node
echo "$NODE_IP_LIST" > env.txt
sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
# Drop ":<slots>" up to the next comma. The previous pattern ":." removed only
# one character, so a two-digit slot count (e.g. ":16") left a stray digit
# glued to the host name.
sed "s/:[^,]*//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
# Add TORCH_CUDA_ARCH_LIST to ~/.bashrc on every node, but only once, so that
# repeated launches do not keep appending the same line.
pssh -i -t 0 -h pssh.hosts "grep -qxF 'export TORCH_CUDA_ARCH_LIST=\"9.0\"' ~/.bashrc || echo 'export TORCH_CUDA_ARCH_LIST=\"9.0\"' >> ~/.bashrc"

source ~/.bashrc
# Export explicitly as well: many default ~/.bashrc files return early in
# non-interactive shells, so sourcing it alone may not set the variable here.
export TORCH_CUDA_ARCH_LIST="9.0"

bash stop_ray.sh
bash start_ray.sh

########################### Quick Config ###########################
# Observed checkpoint-save behaviour in the reported runs:
#   TP=4, PP=2 -> checkpoint save OK (no OOM)
#   TP=4       -> OOM while saving checkpoint
#   TP=2, PP=2 -> OOM while saving checkpoint
# NOTE(review): TP=8 fails earlier, in the reference-model forward pass:
# megatron's _apply_output_gate cannot view a [432, 1, 6, 256] gate tensor as
# [432, 1, 768]. Presumably the attention-head layout does not split evenly
# across 8 TP ranks — TODO confirm before using TP=8.

TP=${TP:-4}
PP=${PP:-2}
CP=${CP:-1}
EP=${EP:-1}
ETP=${ETP:-1}
GEN_TP=${GEN_TP:-4}

ALL_OFFLOAD=${ALL_OFFLOAD:-True}

rollout_name="vllm"
# NOTE(review): "35b_geo3k" does not match the 27B model / gsm8k data used
# below. Kept unchanged so existing output and checkpoint directories are reused.
project_name='verl_grpo_qwen3_5_35b_geo3k'
exp_name='qwen3_5_27b_megatron_text_grpo'
OUTPUT_DIR=xxx
CKPTS_DIR=${OUTPUT_DIR}/"${project_name}"/"${exp_name}/ckpts"
ROLLOUT_DATA_DIR=${OUTPUT_DIR}/"${project_name}"/"${exp_name}/rollout_data"
VALIDATION_DATA_DIR=${OUTPUT_DIR}/"${project_name}"/"${exp_name}/validation_data"
LOG_DIR="${OUTPUT_DIR}/${project_name}/${exp_name}/logs/$(date +%Y%m%d_%H%M%S)"
SWANLAB_LOG_DIR=${OUTPUT_DIR}/"${project_name}"/"${exp_name}/swanlog"

# Quote the paths so directory names containing spaces are created correctly.
mkdir -p "${CKPTS_DIR}"
mkdir -p "${ROLLOUT_DATA_DIR}"
mkdir -p "${VALIDATION_DATA_DIR}"
mkdir -p "${LOG_DIR}"
mkdir -p "${SWANLAB_LOG_DIR}"

adv_estimator=grpo
NNODES=2
max_prompt_length=1024
max_response_length=4096
train_prompt_bsz=64
train_prompt_mini_bsz=${train_prompt_bsz}

# train rollout params
train_temperature=1.2
train_rollot_n=8
train_top_p=1.0
train_top_k=-1

# val rollout params
val_temperature=0
val_rollot_n=1
val_top_p=1.0
val_top_k=-1
val_do_sample=False

# prepare model and data
HF_MODEL_PATH=${HF_MODEL_PATH:-xxx}
DATASET_DIR=${DATASET_DIR:-xxx}
train_path=${train_path:-${DATASET_DIR}/rl/gsm8k/train.parquet}
test_path=${test_path:-${DATASET_DIR}/rl/gsm8k/test.parquet}

########################### Parameter Arrays ###########################

DATA=(
    data.train_files=${train_path}
    data.val_files=${test_path}
    data.train_batch_size=${train_prompt_bsz}
    data.max_prompt_length=${max_prompt_length}
    data.max_response_length=${max_response_length}
    data.truncation='error'
    data.filter_overlong_prompts=True
)

MODEL=(
    actor_rollout_ref.model.path=${HF_MODEL_PATH}
    actor_rollout_ref.model.trust_remote_code=True
    actor_rollout_ref.model.use_remove_padding=False
)

ACTOR=(
    actor_rollout_ref.actor.optim.lr=1e-6
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30720
    actor_rollout_ref.actor.use_dynamic_bsz=False
    actor_rollout_ref.actor.use_kl_loss=False
    actor_rollout_ref.actor.kl_loss_coef=0.01
    actor_rollout_ref.actor.kl_loss_type=low_var_kl
    actor_rollout_ref.actor.entropy_coeff=0
    actor_rollout_ref.actor.megatron.use_mbridge=True
    actor_rollout_ref.actor.megatron.vanilla_mbridge=True
    actor_rollout_ref.actor.megatron.use_remove_padding=False
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${TP}
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${PP}
    actor_rollout_ref.actor.megatron.context_parallel_size=${CP}
    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${EP}
    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ETP}
    actor_rollout_ref.actor.megatron.param_offload=${ALL_OFFLOAD}
    actor_rollout_ref.actor.megatron.optimizer_offload=${ALL_OFFLOAD}
    actor_rollout_ref.actor.megatron.grad_offload=${ALL_OFFLOAD}
    actor_rollout_ref.actor.megatron.dtype=bfloat16
    ++actor_rollout_ref.actor.megatron.override_transformer_config.attention_backend=flash
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_aux_loss_coeff=0.01
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_z_loss_coeff=0.001
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1
    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True
)

ROLLOUT=(
    actor_rollout_ref.rollout.name=${rollout_name}
    actor_rollout_ref.rollout.tensor_model_parallel_size=${GEN_TP}
    # 0.6 caused OOM when saving checkpoints; 0.5 avoids it.
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5
    # Use the train rollout-count variable instead of a duplicated literal 8.
    actor_rollout_ref.rollout.n=${train_rollot_n}
    actor_rollout_ref.rollout.mode=async
    actor_rollout_ref.rollout.enforce_eager=False
    actor_rollout_ref.rollout.temperature=${train_temperature}
    actor_rollout_ref.rollout.top_p=${train_top_p}
    actor_rollout_ref.rollout.top_k=${train_top_k}
    actor_rollout_ref.rollout.dtype=bfloat16
    actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature}
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p}
    actor_rollout_ref.rollout.val_kwargs.top_k=${val_top_k}
    actor_rollout_ref.rollout.val_kwargs.do_sample=${val_do_sample}
    actor_rollout_ref.rollout.val_kwargs.n=${val_rollot_n}
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=False
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=30720
    actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=4096
)

REF=(
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=False
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=30720
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${TP}
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${PP}
    actor_rollout_ref.ref.megatron.context_parallel_size=${CP}
    actor_rollout_ref.ref.megatron.expert_model_parallel_size=${EP}
    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${ETP}
    actor_rollout_ref.ref.megatron.param_offload=${ALL_OFFLOAD}
)

ALGORITHM=(
    algorithm.adv_estimator=${adv_estimator}
    algorithm.use_kl_in_reward=False
)

TRAINER=(
    trainer.critic_warmup=0
    trainer.logger='["console","swanlab"]'
    trainer.project_name=${project_name}
    trainer.experiment_name=${exp_name}
    trainer.default_local_dir=${CKPTS_DIR}
    trainer.n_gpus_per_node=8
    trainer.nnodes=${NNODES}
    trainer.save_freq=1
    trainer.val_before_train=False
    trainer.test_freq=5
    trainer.total_epochs=15
)

RAY_KWARGS=(
    +ray_kwargs.ray_init.runtime_env.env_vars.SWANLAB_LOG_DIR="${SWANLAB_LOG_DIR}"
    +ray_kwargs.ray_init.runtime_env.env_vars.SWANLAB_MODE="${SWANLAB_MODE}"
    +ray_kwargs.ray_init.runtime_env.env_vars.SWANLAB_API_KEY="${SWANLAB_API_KEY}"
)

########################### Launch ###########################

python3 -m verl.trainer.main_ppo \
    --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    "${DATA[@]}" \
    "${ALGORITHM[@]}" \
    "${MODEL[@]}" \
    "${ROLLOUT[@]}" \
    "${ACTOR[@]}" \
    "${REF[@]}" \
    "${TRAINER[@]}" \
    "${RAY_KWARGS[@]}" \
    "$@" 2>&1 | tee "$LOG_DIR/train.log"

bash stop_ray.sh

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions