Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
#!/usr/bin/env bash
# Experiment identity, input/output paths and rollout-engine selection for the
# fully-async GRPO launch further down in this script.
#
# Environment overrides (each falls back to the default shown here):
#   RAY_DATA_HOME  base directory used to derive CKPTS_DIR
#   CKPTS_DIR      checkpoint directory
#   MODEL_PATH     value passed as actor_rollout_ref.model.path
#   TRAIN_FILE     value passed as data.train_files
#   TEST_FILE      value passed as data.val_files
set -xeuo pipefail

project_name='GRPO'
exp_name='GRPO-qwen3-235b-megatron-fully-async'

RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"Qwen3-235B-A22B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"gsm8k/train.parquet"}
TEST_FILE=${TEST_FILE:-"gsm8k/test.parquet"}

rollout_mode="async"
rollout_name="vllm" # sglang or vllm

# Always define return_raw_chat: the launch command expands it unconditionally
# and `set -u` would abort the script if a non-async mode left it unset.
# NOTE(review): "False" is assumed to be the right value for non-async
# rollouts -- confirm against the trainer config defaults.
return_raw_chat="False"
if [[ "${rollout_mode}" == "async" ]]; then
  export VLLM_USE_V1=1
  return_raw_chat="True"
fi
# ---- Algorithm selection ----
adv_estimator='grpo'
use_kl_in_reward='False'

# ---- Prompt / response length limits ----
max_prompt_length=$((8 * 1024))
max_response_length=$((4 * 1024))

# ---- Sampling settings ----
temperature='1.0'
top_p='1.0'
top_k='-1' # 0 for HF rollout, -1 for vLLM rollout
val_top_p='0.7'

# ---- Performance-related settings ----
use_dynamic_bsz='False'
# Both per-GPU token limits are the sum of the prompt and response limits.
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=${actor_ppo_max_token_len}
offload='True'
train_ppo_micro_batch_size_per_gpu=1
infer_ppo_micro_batch_size_per_gpu=1
USE_MBRIDGE='True'
USE_DIST_CKPT='False'

# ---- Parallel layout: gen_* for rollout, train_* for Megatron ----
gen_tp=8
gen_dp=8
rollout_max_num_seqs=64
max_num_batched_tokens=1024
train_tp=4
train_ep=4
train_pp=8

# ---- Fully-async specific settings ----
# Cluster shape: each value may be supplied via the environment; an unset or
# empty variable falls back to the default given here.
: "${NNODES_ROLLOUT:=8}"
: "${NNODES_TRAIN:=8}"
: "${NGPUS_PER_NODE:=16}"

# Batch sizes and rollout budget.
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=128
total_rollout_steps=$((512 * 400))

# Values forwarded to the async_training.* overrides of the launch command.
staleness_threshold='0.5'
trigger_parameter_sync_step=1
require_batches=1
partial_rollout='True'

# Launch verl.experimental.fully_async_policy.fully_async_main with the
# settings defined above. stdout and stderr are merged and copied to a
# timestamped log file under ./logs while still being shown on the terminal.
#
# The overrides live in an array so every element reaches Python as exactly
# one argument, independent of line-continuation spacing or globbing.

# tee cannot create its parent directory; without this the pipeline would
# report failure (pipefail) whenever ./logs is missing.
mkdir -p logs
log_file="logs/verl_qwen3_235b_fully_async_$(date +%Y%m%d_%H%M).log"

hydra_overrides=(
  # --- data.* ---
  "data.train_files=${TRAIN_FILE}"
  "data.val_files=${TEST_FILE}"
  "data.train_batch_size=${train_prompt_bsz}"
  "data.max_prompt_length=${max_prompt_length}"
  "data.max_response_length=${max_response_length}"
  "data.filter_overlong_prompts=False"
  "data.truncation=error"
  "data.return_raw_chat=${return_raw_chat}"

  # --- algorithm.* ---
  "algorithm.adv_estimator=${adv_estimator}"
  "algorithm.use_kl_in_reward=${use_kl_in_reward}"

  # --- actor_rollout_ref (top level) and model ---
  "actor_rollout_ref.hybrid_engine=False"
  "actor_rollout_ref.model.path=${MODEL_PATH}"
  "+actor_rollout_ref.model.override_config.model_config.max_position_embeddings=32768"

  # --- actor: optimizer, batching and loss ---
  "actor_rollout_ref.actor.optim.lr=1e-6"
  "actor_rollout_ref.actor.optim.lr_decay_style=constant"
  "actor_rollout_ref.actor.optim.weight_decay=0.1"
  "actor_rollout_ref.actor.optim.lr_decay_steps=${total_rollout_steps}"
  "actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}"
  "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu}"
  "actor_rollout_ref.actor.use_kl_loss=True"
  "actor_rollout_ref.actor.kl_loss_coef=0.001"
  "actor_rollout_ref.actor.kl_loss_type=low_var_kl"
  "actor_rollout_ref.actor.entropy_coeff=0"

  # --- actor: Megatron layout, offload and recompute ---
  "actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp}"
  "actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp}"
  "actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep}"
  "actor_rollout_ref.actor.megatron.use_mbridge=${USE_MBRIDGE}"
  "actor_rollout_ref.actor.megatron.use_dist_checkpointing=${USE_DIST_CKPT}"
  "actor_rollout_ref.actor.megatron.param_offload=${offload}"
  "actor_rollout_ref.actor.megatron.optimizer_offload=${offload}"
  "actor_rollout_ref.actor.megatron.grad_offload=${offload}"
  "+actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=11"
  "+actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=11"
  "+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform"
  "+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full"
  "+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1"

  # --- ref: uses the same train_* layout values as the actor ---
  "actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT}"
  "actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu}"
  "actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}"
  "actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp}"
  "actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp}"
  "actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep}"
  "actor_rollout_ref.ref.megatron.param_offload=${offload}"

  # --- rollout engine ---
  "actor_rollout_ref.rollout.name=${rollout_name}"
  "actor_rollout_ref.rollout.mode=${rollout_mode}"
  "actor_rollout_ref.rollout.n=${n_resp_per_prompt}"
  "actor_rollout_ref.rollout.free_cache_engine=True"
  "actor_rollout_ref.rollout.enable_chunked_prefill=True"
  "actor_rollout_ref.rollout.enable_prefix_caching=True"
  "actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4"
  "actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp}"
  "actor_rollout_ref.rollout.data_parallel_size=${gen_dp}"
  # NOTE(review): 64 matches gen_tp * gen_dp for the values set in this
  # script; confirm whether it should be derived from them instead.
  "actor_rollout_ref.rollout.expert_parallel_size=64"
  "actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens}"
  "actor_rollout_ref.rollout.gpu_memory_utilization=0.75"
  "actor_rollout_ref.rollout.enforce_eager=False"
  "+actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_capture_sizes=[8, 16, 32, 64, 128]"
  "+actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_mode=FULL_DECODE_ONLY"

  # --- trainer.* ---
  "trainer.critic_warmup=0"
  # Quoted so the shell never treats [...] as a glob pattern.
  "trainer.logger=[console]"
  "trainer.project_name=${project_name}"
  "trainer.experiment_name=${exp_name}"
  "trainer.nnodes=${NNODES_TRAIN}"
  "trainer.n_gpus_per_node=${NGPUS_PER_NODE}"
  "trainer.val_before_train=False"
  "trainer.test_freq=-1"
  "trainer.save_freq=100"
  "trainer.total_epochs=10"
  "trainer.device=npu"

  # --- rollout.* resources and async_training.* ---
  "rollout.nnodes=${NNODES_ROLLOUT}"
  "rollout.n_gpus_per_node=${NGPUS_PER_NODE}"
  "rollout.total_rollout_steps=${total_rollout_steps}"
  "async_training.staleness_threshold=${staleness_threshold}"
  "async_training.trigger_parameter_sync_step=${trigger_parameter_sync_step}"
  "async_training.require_batches=${require_batches}"
  "async_training.partial_rollout=${partial_rollout}"
)

# NOTE(review): CKPTS_DIR is defined earlier in this script but is not passed
# to the trainer here; confirm whether a checkpoint-directory override is
# intended.
python -m verl.experimental.fully_async_policy.fully_async_main \
  --config-path=config \
  --config-name='fully_async_ppo_megatron_trainer.yaml' \
  "${hydra_overrides[@]}" 2>&1 | tee "${log_file}"
Loading