Bug
_truncate_sample_output() in openai_endpoint_utils.py truncates tokens, rollout_log_probs, and loss_mask, but does not truncate rollout_routed_experts.
After truncation, len(sample.tokens) - 1 != rollout_routed_experts.shape[0], which causes Sample.validate() to fail with an assertion error.
When it triggers
agentic_tool_call.generate() calls truncate_samples_by_total_tokens(samples, max_seq_len, tokenizer), which calls _truncate_sample_output() when a sample's total token count exceeds max_seq_len. This is a realistic scenario in long multi-turn agentic sessions with --use-rollout-routing-replay enabled.
Comparison
Sample.strip_last_output_tokens() in types.py handles this correctly:
    if self.rollout_routed_experts is not None:
        self.rollout_routed_experts = self.rollout_routed_experts[:-n]
Suggested fix
Add the corresponding truncation in _truncate_sample_output:
    if sample.rollout_routed_experts is not None:
        prompt_experts = prompt_len - 1
        sample.rollout_routed_experts = sample.rollout_routed_experts[:prompt_experts + keep_tokens]
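(rollout_routed_experts has shape (len(tokens) - 1, num_layers, moe_router_topk). Assuming the truncated sample keeps prompt_len + keep_tokens tokens, the first dimension must shrink to prompt_len - 1 + keep_tokens, which is why the slice starts from prompt_len - 1 rather than prompt_len.)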
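To sanity-check the off-by-one, here is a minimal self-contained sketch. ToySample and truncate_output are hypothetical stand-ins, not the real Sample or _truncate_sample_output. Nested lists replace the real array, so len() plays the role of shape[0]. The sketch also assumes rollout_log_probs and loss_mask cover only the response tokens.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToySample:
    """Hypothetical stand-in for Sample; only the fields relevant to this issue."""
    tokens: List[int]
    rollout_log_probs: List[float]  # assumed: one entry per response token
    loss_mask: List[int]            # assumed: one entry per response token
    # First dimension is len(tokens) - 1, as stated in the issue.
    rollout_routed_experts: Optional[List[List[List[int]]]] = None

    def validate(self) -> None:
        # Mirrors the invariant that fails in the report.
        if self.rollout_routed_experts is not None:
            assert len(self.rollout_routed_experts) == len(self.tokens) - 1


def truncate_output(sample: ToySample, prompt_len: int, keep_tokens: int) -> None:
    """Keep the prompt plus the first keep_tokens response tokens."""
    sample.tokens = sample.tokens[: prompt_len + keep_tokens]
    sample.rollout_log_probs = sample.rollout_log_probs[:keep_tokens]
    sample.loss_mask = sample.loss_mask[:keep_tokens]
    # The proposed addition: without it, validate() below would fail.
    if sample.rollout_routed_experts is not None:
        prompt_experts = prompt_len - 1
        sample.rollout_routed_experts = sample.rollout_routed_experts[
            : prompt_experts + keep_tokens
        ]


prompt_len, response_len, keep_tokens = 4, 6, 3
num_layers, topk = 2, 2
total = prompt_len + response_len

sample = ToySample(
    tokens=list(range(total)),
    rollout_log_probs=[0.0] * response_len,
    loss_mask=[1] * response_len,
    rollout_routed_experts=[
        [[0] * topk for _ in range(num_layers)] for _ in range(total - 1)
    ],
)

truncate_output(sample, prompt_len, keep_tokens)
sample.validate()  # passes: 6 expert rows for 7 tokens
```

Removing the rollout_routed_experts branch from truncate_output makes validate() raise, which matches the reported failure.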