[data, trainer] fix: batch padding for multi-trajectory #5969
ZhentaoFan wants to merge 2 commits into verl-project:main from
Conversation
Code Review
This pull request implements batch upsampling using synthetic padding sequences to ensure global batch sizes align with PPO mini-batch requirements for actors and critics. Key additions include methods for calculating the required batch multiple, constructing minimal padding templates, and filtering these samples during metric computation. Review feedback highlights potential compatibility issues with VLM and MoE models, specifically regarding the shape of response masks, the dimensionality of position IDs, and the removal of multi-modal or expert-related fields in padding samples.
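The "required batch multiple" mentioned above can be sketched as a least common multiple of the divisibility constraints. This is a hypothetical helper for illustration; `required_batch_multiple` and `padded_batch_size` are not names taken from the PR:

```python
import math

def required_batch_multiple(dp_size, mini_batch_size, critic_mini_batch_size=None):
    """Smallest number the global batch size must be divisible by."""
    m = math.lcm(dp_size, mini_batch_size)
    if critic_mini_batch_size is not None:
        m = math.lcm(m, critic_mini_batch_size)
    return m

def padded_batch_size(num_samples, multiple):
    """Round num_samples up to the nearest multiple."""
    return math.ceil(num_samples / multiple) * multiple

print(required_batch_multiple(2, 15))      # 30
print(padded_batch_size(45, 30))           # 60
```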
```python
prompts = torch.full((1,), token_id, dtype=torch.int64)
input_ids = prompts.repeat(2)
attention_mask = torch.ones_like(input_ids, dtype=torch.int64)
response_mask = torch.zeros_like(prompts)
```
The response_mask should have the same shape as input_ids (length 2) to ensure consistency with real samples. In verl, sequence-level masks like response_mask and loss_mask typically match the full sequence length (input_ids). A shape mismatch here will cause list_of_dict_to_tensordict to fail when stacking padding samples into a batch, or lead to errors in the model's forward pass.
Suggested change:

```diff
- response_mask = torch.zeros_like(prompts)
+ response_mask = torch.zeros_like(input_ids)
```
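A minimal standalone repro of the failure mode the reviewer describes (illustrative only, outside verl): stacking per-sample tensors of unequal lengths fails, which is why response_mask must match the input_ids length.

```python
import torch

token_id = 0
prompts = torch.full((1,), token_id, dtype=torch.int64)
input_ids = prompts.repeat(2)

# Buggy: mask has length 1 while input_ids has length 2.
bad_mask = torch.zeros_like(prompts)
# Fixed: mask matches the full sequence length.
good_mask = torch.zeros_like(input_ids)

try:
    torch.stack([bad_mask, good_mask])  # mimics batch collation
except RuntimeError as e:
    print("stack failed:", e)

batch = torch.stack([good_mask, torch.zeros_like(input_ids)])
print(batch.shape)  # torch.Size([2, 2])
```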
verl/trainer/main_ppo_sync.py
Outdated
```python
responses=prompts.clone(),
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=compute_position_id_with_mask(attention_mask.unsqueeze(0)).squeeze(0),
```
Manually computing position_ids as a 1D tensor may cause crashes with models that expect multi-dimensional position IDs (e.g., VLMs with 2D RoPE or specific multimodal architectures). If the padding samples have 1D position_ids while the real samples in the batch have 2D position_ids, batch collation or the model's forward pass will fail. It is safer to derive the structure from the source_td or use the model's specific position ID computation logic.
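One way to follow the reviewer's suggestion is to derive the padding tensor's shape from a real sample instead of hard-coding a 1D layout. This is a hedged sketch (the helper name and the exact layouts are assumptions; some multimodal models use a (3, seq_len) position-ID layout while text-only models use (seq_len,)):

```python
import torch

def make_padding_position_ids(real_position_ids: torch.Tensor, pad_len: int) -> torch.Tensor:
    """Zero position IDs whose rank matches a real sample's position_ids,
    so padding samples collate cleanly with real ones."""
    shape = real_position_ids.shape[:-1] + (pad_len,)
    return torch.zeros(shape, dtype=real_position_ids.dtype)

# Text-only: 1D in, 1D out.
print(make_padding_position_ids(torch.arange(8), 2).shape)    # torch.Size([2])
# Multi-dimensional layout: (3, seq_len) in, (3, pad_len) out.
print(make_padding_position_ids(torch.zeros(3, 8), 2).shape)  # torch.Size([3, 2])
```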
verl/trainer/main_ppo_sync.py
Outdated
```python
template_sample.pop("multi_modal_inputs", None)
template_sample.pop("routed_experts", None)
```
Popping multi_modal_inputs and routed_experts makes the padding samples incompatible with VLM and MoE models. These fields are typically required by the model during the forward pass (e.g., in compute_log_prob). If these keys are present in real samples but missing in padding samples, tq.kv_batch_get will fail to retrieve them for the padding keys, or the model will crash due to missing inputs. Instead of popping them, consider providing dummy values (e.g., zeros or empty structures) that match the padding sequence length.
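A sketch of the dummy-value alternative the reviewer suggests (hypothetical; real multi_modal_inputs structures are model-specific, and this helper name is not from the PR). The idea is to keep the keys present but zero out their tensor contents:

```python
import torch

def make_dummy_fields(template_sample: dict) -> dict:
    """Keep model-specific extras present with zeroed contents
    instead of dropping the keys entirely."""
    out = {}
    for key in ("multi_modal_inputs", "routed_experts"):
        if key not in template_sample:
            continue
        value = template_sample[key]
        if isinstance(value, torch.Tensor):
            out[key] = torch.zeros_like(value)
        elif isinstance(value, dict):
            out[key] = {k: torch.zeros_like(v) for k, v in value.items()}
        else:
            out[key] = value  # leave unknown structures untouched
    return out

sample = {"routed_experts": torch.ones(4, 2)}
dummy = make_dummy_fields(sample)
print(dummy["routed_experts"].sum().item())  # 0.0
```

Note the sketch does not resize the dummies to the padding sequence length; a real fix would also have to match whatever per-token shape the model expects.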
@eric-haibin-lin @vermouth1992 @tongyx361 @PeterSH6 Ready for review.
What does this PR do?
Background
Inside the current tq_trainer, AgentLoopWorkerTQ already supports the multi-trajectory feature. However, during actual training, sample (trajectory)-level padding is still required for each batch so that the number of samples is divisible by both dp_size and mini_batch_size; otherwise, an error will be thrown. This PR fixes the bug and addresses the following considerations:

Upsampling: the global batch size is padded up so that % dp_size == 0 and % mini_batch_size == 0 (and % critic_mini_batch_size == 0 if training the critic).

Padding: an is_padding flag is added to the tags of padded samples to avoid impacting accuracy metrics such as score, reward, and response length (while performance metrics still include padded samples).

Verification Experiments with Multi-Trajectory Agent:
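The metric exclusion described above can be sketched as follows (a hypothetical helper; the PR stores the flag in each sample's tags, but the exact data layout here is an assumption):

```python
def mean_reward_excluding_padding(samples: list) -> float:
    """Average reward over real samples only; synthetic samples
    tagged is_padding=True are excluded from accuracy metrics."""
    real = [s for s in samples if not s.get("tags", {}).get("is_padding", False)]
    if not real:
        return 0.0
    return sum(s["reward"] for s in real) / len(real)

batch = [
    {"reward": 1.0, "tags": {}},
    {"reward": 0.5, "tags": {}},
    {"reward": 0.0, "tags": {"is_padding": True}},  # synthetic padding sample
]
print(mean_reward_excluding_padding(batch))  # 0.75
```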
The three smallest primes — 2, 3, and 5 — are chosen to form the relevant hyperparameters:
[dp=2, batch_size=45, mini_batch_size=15, rollout.n=8].
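Assuming the per-batch sample count must be divisible by lcm(dp_size, mini_batch_size), these hyperparameters do force padding, which is what makes them a useful test case. A quick arithmetic check (illustrative only; with multi-trajectory rollouts the actual trajectory count per batch varies):

```python
import math

dp_size, batch_size, mini_batch_size = 2, 45, 15

multiple = math.lcm(dp_size, mini_batch_size)         # 30
padded = math.ceil(batch_size / multiple) * multiple  # 60
print(multiple, padded, padded - batch_size)          # 30 60 15
```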