Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 165 additions & 10 deletions verl/trainer/main_ppo_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
"""

import asyncio
import copy
import logging
import math
import os
import threading
import time
Expand Down Expand Up @@ -85,6 +87,7 @@
from verl.utils.fs import copy_to_local
from verl.utils.import_utils import load_class_from_fqn
from verl.utils.metric import reduce_metrics
from verl.utils.model import compute_position_id_with_mask
from verl.utils.py_functional import rename_dict
from verl.utils.ray_utils import auto_await
from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance
Expand Down Expand Up @@ -995,11 +998,156 @@ def _compute_reward_colocate(self, batch: KVBatchMeta, metrics: dict) -> KVBatch
# TODO: add reward model
raise NotImplementedError

def _get_required_batch_multiple(self, dp_size: int) -> int:
"""Return the global batch multiple required by downstream train steps(e.g. critics, actors)."""
required_multiple = dp_size

# If enabled with critic training, the batch should align with critic PPO mini-batches.
if self.use_critic:
critic_global_mini_batch_size = self.config.critic.ppo_mini_batch_size
critic_global_mini_batch_size *= self.config.actor_rollout_ref.rollout.n
required_multiple = math.lcm(required_multiple, critic_global_mini_batch_size)

# If there is an actor update, the batch should align with actor PPO mini-batches too.
if self.config.trainer.critic_warmup <= self.global_steps:
actor_global_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
actor_global_mini_batch_size *= self.config.actor_rollout_ref.rollout.n
required_multiple = math.lcm(required_multiple, actor_global_mini_batch_size)

# Notice lcm(a, b, c) == lcm(lcm(a, b), c), so it is optimal.
return required_multiple

def _construct_minimal_padding_template(self, source_td, source_tag: dict) -> tuple[dict, dict]:
    """Construct a minimal text-only padding template of one prompt token and one response token.

    Args:
        source_td: TensorDict-like sample used as the field template; tensors
            are cloned and non-tensor values deep-copied so the source sample
            is never mutated.
        source_tag: Tag dict of the source sample, deep-copied for the template.

    Returns:
        A ``(template_sample, template_tag)`` pair describing a synthetic no-op
        sample whose masks and reward fields are all zero so it cannot
        contribute to PPO, entropy, or KL losses.
    """

    # Iterate through the keys and copy the sample template from an existing sample.
    template_sample = {}
    for key in source_td.keys():
        value = source_td[key]
        template_sample[key] = value.clone() if isinstance(value, torch.Tensor) else copy.deepcopy(value)

    # Deep copy the template tag from an existing sample.
    template_tag = copy.deepcopy(source_tag)

    # Build minimal sequence: one EOS prompt token followed by one EOS response token.
    token_id = self.tokenizer.eos_token_id
    prompts = torch.full((1,), token_id, dtype=torch.int64)
    input_ids = prompts.repeat(2)
    attention_mask = torch.ones_like(input_ids, dtype=torch.int64)
    # Fix: response_mask (and loss_mask below) must span the FULL sequence,
    # i.e. the same length as input_ids, to match the shape of real samples.
    # The previous zeros_like(prompts) (length 1) broke stacking of padding
    # samples in list_of_dict_to_tensordict and the model forward pass.
    response_mask = torch.zeros_like(input_ids)
    position_ids = self._build_padding_position_ids(template_sample.get("position_ids"), attention_mask)
    routed_experts = self._build_padding_routed_experts(template_sample.get("routed_experts"), input_ids.size(0))

    # Update the fields and remove redundant parts. Reward/log-prob fields are
    # all-zero so the padding sample is a no-op for every loss term.
    # NOTE(review): rm_scores/rollout_log_probs now inherit the full-sequence
    # length from response_mask — confirm downstream consumers expect
    # full-length rather than response-length tensors here.
    template_sample.update(
        prompts=prompts,
        responses=prompts.clone(),
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        num_turns=0,
        response_mask=response_mask,
        loss_mask=response_mask,
        rm_scores=torch.zeros_like(response_mask, dtype=torch.float32),
        rollout_log_probs=torch.zeros_like(response_mask, dtype=torch.float32),
    )
    if "multi_modal_inputs" in template_sample:
        template_sample["multi_modal_inputs"] = {}
    if routed_experts is not None:
        template_sample["routed_experts"] = routed_experts
    else:
        template_sample.pop("routed_experts", None)

    # Padding flag is deployed to protect metrics calculation (e.g. response length, score, reward).
    template_tag.update(is_padding=True, prompt_len=1, response_len=1, seq_len=2)
    return template_sample, template_tag

@staticmethod
def _build_padding_position_ids(source_position_ids, attention_mask: torch.Tensor) -> torch.Tensor:
    """Build padding position ids with the same rank/prefix shape as the source sample."""
    # The helper expects a batch dimension; add one and strip it afterwards.
    flat_ids = compute_position_id_with_mask(attention_mask.unsqueeze(0)).squeeze(0)

    # Without a tensor template there is nothing to mimic — return the 1-D ids.
    if not isinstance(source_position_ids, torch.Tensor):
        return flat_ids

    flat_ids = flat_ids.to(device=source_position_ids.device, dtype=source_position_ids.dtype)
    if source_position_ids.dim() <= 1:
        return flat_ids

    # Multi-rank templates: broadcast the 1-D ids across every leading dim of
    # the source so only the final (sequence) dimension is replaced.
    lead_dims = source_position_ids.shape[:-1]
    seq_len = flat_ids.size(-1)
    broadcast = flat_ids.reshape((1,) * len(lead_dims) + (seq_len,))
    return broadcast.expand(*lead_dims, -1).clone()

@staticmethod
def _build_padding_routed_experts(source_routed_experts, seq_len: int) -> torch.Tensor | None:
"""Build a zero routed-experts tensor matching the source per-token expert shape."""
if not isinstance(source_routed_experts, torch.Tensor):
return None
if source_routed_experts.dim() == 0:
return torch.zeros_like(source_routed_experts)
return torch.zeros(
(seq_len, *source_routed_experts.shape[1:]),
dtype=source_routed_experts.dtype,
device=source_routed_experts.device,
)

def _upsample_batch_to_divisible_size(self, batch: KVBatchMeta, batch_multiple: int) -> KVBatchMeta:
    """Append synthetic no-op samples so the batch size becomes divisible by batch_multiple.

    The synthetic samples reuse the first real sample as a metadata template,
    but manually construct a minimal prompt_len=1 / response_len=1 sequence and
    zero out reward-related fields so they do not contribute to PPO, entropy, or
    KL losses. An is_padding flag is added for the future metrics calculation.

    Args:
        batch: Batch metadata whose keys/tags reference samples in the transfer queue.
        batch_multiple: Required divisor of the global batch size.

    Returns:
        ``batch`` unchanged when already divisible, otherwise a new
        ``KVBatchMeta`` with the padding keys/tags appended.
    """
    remainder = len(batch) % batch_multiple
    if remainder == 0:
        return batch

    # Take the first trajectory as the metadata template for padding data.
    source_idx = 0
    source_key = batch.keys[source_idx]
    source_td = tq.kv_batch_get(keys=[source_key], partition_id=batch.partition_id)[0]

    # Construct the minimal padding template of one prompt token and one response token.
    template_sample, template_tag = self._construct_minimal_padding_template(source_td, batch.tags[source_idx])

    # All padding data use the same uid (also the same trajectory_id 0 but with ascending session_ids).
    # This uid is not identical to any of the actual data, so it won't affect the grpo advantage value.
    pad_uid = f"pad{uuid.uuid4().hex}"
    template_sample["uid"] = pad_uid

    # Construct the padding samples in a for-loop; each clone differs only in its session id.
    pad_keys = []
    pad_tags = []
    pad_fields = []
    pad_size = batch_multiple - remainder
    for local_idx in range(pad_size):
        sample = copy.deepcopy(template_sample)
        # Use incremental local_idx as different session_ids
        pad_keys.append(f"{pad_uid}_{local_idx}_0")
        if "session_id" in sample:
            sample["session_id"] = local_idx
        pad_fields.append(sample)
        pad_tags.append(copy.deepcopy(template_tag))

    tq.kv_batch_put(
        keys=pad_keys,
        partition_id=batch.partition_id,
        fields=list_of_dict_to_tensordict(pad_fields),
        tags=pad_tags,
    )
    # Fix: replace the leftover "[DEBUG]" print with module logging so verbosity
    # is governed by the trainer's logging configuration instead of stdout.
    logging.getLogger(__name__).debug(
        "Upsampled batch from %d to %d with %d synthetic padding samples for required_multiple=%d",
        len(batch),
        len(batch) + pad_size,
        pad_size,
        batch_multiple,
    )
    return KVBatchMeta(
        keys=batch.keys + pad_keys,
        tags=batch.tags + pad_tags,
        partition_id=batch.partition_id,
        fields=batch.fields,
        extra_info=batch.extra_info,
    )

def _balance_batch(self, batch: KVBatchMeta, metrics, logging_prefix="global_seqlen", keep_minibatch=False):
"""Reorder the data on single controller such that each dp rank gets similar total tokens."""
global_seqlen_lst = torch.tensor([tag["seq_len"] for tag in batch.tags], dtype=torch.int64)
workload_lst = calculate_workload(global_seqlen_lst)

# get actor dp size
role, worker_group = "actor", self.actor_rollout_wg
if role not in worker_group._dispatch_info:
Expand All @@ -1009,9 +1157,11 @@ def _balance_batch(self, batch: KVBatchMeta, metrics, logging_prefix="global_seq
dp_rank_mapping = worker_group._dispatch_info[role]
dp_size = max(dp_rank_mapping) + 1

# TODO: up sampling if batch is not divisible by dp_size
if len(batch) % dp_size != 0:
raise ValueError(f"Batch size {len(batch)} is not divisible by dp_size {dp_size}")
# Upsampling the batch with padding sequences
batch_multiple = self._get_required_batch_multiple(dp_size)
batch = self._upsample_batch_to_divisible_size(batch, batch_multiple)
global_seqlen_lst = torch.tensor([tag["seq_len"] for tag in batch.tags], dtype=torch.int64)
workload_lst = calculate_workload(global_seqlen_lst)

# reorder based on index. The data will be automatically equally partitioned by dispatch function
global_partition_lst = get_seqlen_balanced_partitions(workload_lst, k_partitions=dp_size, equal_size=True)
Expand All @@ -1020,6 +1170,7 @@ def _balance_batch(self, batch: KVBatchMeta, metrics, logging_prefix="global_seq
seqlen_list=global_seqlen_lst.tolist(), partitions=global_partition_lst, prefix=logging_prefix
)
metrics.update(global_balance_stats)
return batch

def _compute_old_log_prob(self, batch: KVBatchMeta, metrics: dict) -> KVBatchMeta:
"""Compute the old log prob of the batch."""
Expand Down Expand Up @@ -1244,6 +1395,7 @@ def _update_actor(self, batch: KVBatchMeta, metrics: dict) -> KVBatchMeta:

def _compute_metrics(self, batch: KVBatchMeta, metrics, timing_raw, global_steps, epoch):
# 1. collect necessary fields from TransferQueue for computing metrics
non_padding_mask = np.array([not tag.get("is_padding", False) for tag in batch.tags], dtype=bool)
fields = [
"prompts",
"responses",
Expand All @@ -1256,6 +1408,7 @@ def _compute_metrics(self, batch: KVBatchMeta, metrics, timing_raw, global_steps
"num_turns",
]
data = tq.kv_batch_get(keys=batch.keys, partition_id=batch.partition_id, select_fields=fields)
num_turns = np.array(data.pop("num_turns").tolist())
prompt_length = data["prompts"].offsets().diff()
response_length = data["responses"].offsets().diff()
global_token_num = (prompt_length + response_length).tolist()
Expand All @@ -1266,18 +1419,20 @@ def _compute_metrics(self, batch: KVBatchMeta, metrics, timing_raw, global_steps
data["prompt_length"] = prompt_length.float()
data["response_length"] = response_length.float()
batch = DataProto(batch=data, meta_info={"global_token_num": global_token_num})
metrics_batch = batch.select_idxs(non_padding_mask) if non_padding_mask.any() else batch

# 2. compute metrics
metrics.update({"training/global_step": global_steps, "training/epoch": epoch})
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_data_metrics(batch=metrics_batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
gradient_norm = metrics.get("actor/grad_norm", None)
metrics.update(compute_variance_proxy_metrics(batch=batch, gradient_norm=gradient_norm))
metrics.update(compute_variance_proxy_metrics(batch=metrics_batch, gradient_norm=gradient_norm))

# 3. other auxiliary metrics
num_turns = np.array(data.pop("num_turns").tolist())
if non_padding_mask.any():
num_turns = num_turns[non_padding_mask]
metrics.update(
{
"training/num_turns/mean": num_turns.mean(),
Expand Down Expand Up @@ -1391,7 +1546,7 @@ def step(self, batch_dict: dict, metrics: dict, timing_raw: dict) -> KVBatchMeta
batch = self._compute_reward_colocate(batch)

# 4. balance batch across data parallel groups
self._balance_batch(batch, metrics=metrics)
batch = self._balance_batch(batch, metrics=metrics)

# 5. compute old_log_prob
with marked_timer("old_log_prob", timing_raw, color="blue"):
Expand Down