@@ -0,0 +1,148 @@
#!/usr/bin/env bash
set -xeuo pipefail
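
# Example launch (script filename hypothetical; RAY_DATA_HOME, CKPTS_DIR and the
# node counts can all be overridden from the environment):
#   RAY_DATA_HOME=/path/to/verl NNODES_TRAIN=8 NNODES_ROLLOUT=8 NGPUS_PER_NODE=16 \
#     bash run_grpo_qwen3_235b_fully_async.sh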

project_name='GRPO'
exp_name='GRPO-qwen3-235b-megatron-fully-async'

RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=Qwen3-235B-A22B
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=gsm8k/train.parquet
TEST_FILE=gsm8k/test.parquet

rollout_mode="async"
rollout_name="vllm" # sglang or vllm
return_raw_chat="False" # default, so set -u does not trip in sync mode
if [ "$rollout_mode" = "async" ]; then
    export VLLM_USE_V1=1
    return_raw_chat="True"
fi
# Algorithm parameters
adv_estimator=grpo
use_kl_in_reward=False

# Response length parameters
max_prompt_length=$((1024 * 8))
max_response_length=$((1024 * 4))
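# 8192 prompt tokens + 4096 response tokens = 12288 tokens per sequence at most,
# comfortably under the max_position_embeddings=32768 override passed below.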
# Sampling parameters
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance-related parameters
use_dynamic_bsz=False
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
offload=True
train_ppo_micro_batch_size_per_gpu=1
infer_ppo_micro_batch_size_per_gpu=1
USE_MBRIDGE=True
USE_DIST_CKPT=False

gen_tp=8
gen_dp=8
rollout_max_num_seqs=64
max_num_batched_tokens=1024
train_tp=4
train_ep=4
train_pp=8
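# Layout sanity check (assuming NNODES_TRAIN=8 x NGPUS_PER_NODE=16 = 128 training
# devices): tp=4 x pp=8 is 32-way model parallelism, leaving dp=4; ep=4 shards the
# MoE experts. The first/last pipeline stages are overridden to 11 layers below so
# that the 94 layers of Qwen3-235B-A22B split as 11 + 6x12 + 11 across pp=8 stages.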

# Fully async specific parameters
NNODES_ROLLOUT=${NNODES_ROLLOUT:-8}
NNODES_TRAIN=${NNODES_TRAIN:-8}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-16}

train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=128
total_rollout_steps=$((512 * 400))
staleness_threshold=0.5
trigger_parameter_sync_step=1
require_batches=1
partial_rollout=True
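# Rough notes on the knobs above (see verl's fully_async_policy recipe docs for
# the authoritative definitions):
# - train_prompt_bsz=0: the fully async trainer derives its batch from
#   require_batches x train_prompt_mini_bsz rather than data.train_batch_size.
# - staleness_threshold bounds how stale rollout samples may be relative to the
#   current policy.
# - trigger_parameter_sync_step: sync weights to the rollout workers every N
#   training steps.
# - partial_rollout=True resumes in-flight generations after a parameter sync
#   instead of discarding them.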

mkdir -p logs

python -m verl.experimental.fully_async_policy.fully_async_main \
--config-path=config \
--config-name='fully_async_ppo_megatron_trainer.yaml' \
algorithm.adv_estimator=${adv_estimator} \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
actor_rollout_ref.rollout.free_cache_engine=True \
data.train_batch_size=${train_prompt_bsz} \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.filter_overlong_prompts=False \
data.truncation='error' \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
+actor_rollout_ref.model.override_config.model_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu} \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep} \
actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.enable_prefix_caching=True \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.megatron.use_mbridge=$USE_MBRIDGE \
+actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=11 \
+actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=11 \
actor_rollout_ref.actor.megatron.use_dist_checkpointing=$USE_DIST_CKPT \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.data_parallel_size=${gen_dp} \
actor_rollout_ref.rollout.expert_parallel_size=64 \
actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens} \
actor_rollout_ref.rollout.max_num_seqs=${rollout_max_num_seqs} \
actor_rollout_ref.actor.megatron.param_offload=${offload} \
actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \
actor_rollout_ref.actor.megatron.grad_offload=${offload} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.75 \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep} \
actor_rollout_ref.ref.megatron.param_offload=${offload} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
trainer.val_before_train=False \
trainer.test_freq=-1 \
trainer.save_freq=100 \
trainer.device='npu' \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \
actor_rollout_ref.rollout.enforce_eager=False \
rollout.nnodes="${NNODES_ROLLOUT}" \
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
trainer.total_epochs=10 \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
async_training.partial_rollout="${partial_rollout}" \
actor_rollout_ref.actor.optim.lr_decay_style='constant' \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.optim.lr_decay_steps=${total_rollout_steps} \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_capture_sizes="[8, 16, 32, 64, 128]" \
+actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_mode="FULL_DECODE_ONLY" 2>&1 | tee "logs/verl_qwen3_235b_fully_async_$(date +%Y%m%d_%H%M).log"
9 changes: 8 additions & 1 deletion verl/workers/engine_workers.py
@@ -25,7 +25,10 @@
from omegaconf import DictConfig, open_dict
from tensordict import NonTensorData, TensorDict
from torch.distributed.device_mesh import init_device_mesh

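# MindSpeed's NPU transformer patch is an optional dependency; fall back to None
# (no-op) on installs without MindSpeed, e.g. GPU-only environments.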
try:
from verl.workers.engine.mindspeed.transformer_impl import repatch
except ImportError:
repatch = None
from verl.checkpoint_engine import CheckpointEngineRegistry
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
@@ -116,6 +119,10 @@ def __init__(self, config: TrainingWorkerConfig):
# TODO: this is not elegant and should be refactored later
self.engine_config.use_remove_padding = self.model_config.get("use_remove_padding", False)
self.engine_config.use_fused_kernels = self.model_config.get("use_fused_kernels", False)

if repatch is not None:
# NPU MindSpeed patch; will be refactored together with MindSpeedEngine.
repatch(self.engine_config.get("override_transformer_config", {}))

self.profiler_config = self.config.profiler_config
if self.profiler_config is not None: