_truncate_sample_output does not truncate rollout_routed_experts #861

@DavidBellamy

Description

Bug

_truncate_sample_output() in openai_endpoint_utils.py truncates tokens, rollout_log_probs, and loss_mask, but leaves rollout_routed_experts untouched.

After truncation, len(sample.tokens) - 1 != rollout_routed_experts.shape[0], which causes Sample.validate() to fail with an assertion error.

When it triggers

agentic_tool_call.generate() calls truncate_samples_by_total_tokens(samples, max_seq_len, tokenizer), which in turn calls _truncate_sample_output() whenever a sample's total token count exceeds max_seq_len. This is a realistic scenario in long multi-turn agentic sessions with --use-rollout-routing-replay enabled, so the assertion failure can surface mid-rollout.

Comparison

Sample.strip_last_output_tokens() in types.py handles this correctly:

if self.rollout_routed_experts is not None:
    self.rollout_routed_experts = self.rollout_routed_experts[:-n]

Suggested fix

Add the corresponding truncation in _truncate_sample_output:

if sample.rollout_routed_experts is not None:
    # experts has one row per predicted token, i.e. len(tokens) - 1 rows
    keep_experts = (prompt_len - 1) + keep_tokens
    sample.rollout_routed_experts = sample.rollout_routed_experts[:keep_experts]

(rollout_routed_experts has shape (len(tokens) - 1, num_layers, moe_router_topk), so the slice accounts for the one-row offset between tokens and expert routing records.)
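A minimal standalone sketch of the arithmetic, using made-up shapes (prompt_len, keep_tokens, num_layers, moe_router_topk are placeholder values, not taken from the real codebase) to show how truncating tokens without truncating the experts array breaks the len(tokens) - 1 invariant, and how the suggested slice restores it:

```python
import numpy as np

# Hypothetical sizes for illustration only.
prompt_len, output_len = 4, 10
num_layers, moe_router_topk = 2, 2

tokens = list(range(prompt_len + output_len))                       # 14 tokens
experts = np.zeros((len(tokens) - 1, num_layers, moe_router_topk))  # 13 rows

# Truncate the output to keep_tokens tokens, as _truncate_sample_output does.
keep_tokens = 3
tokens = tokens[:prompt_len + keep_tokens]                          # 7 tokens

# Without the fix, the invariant checked by Sample.validate() breaks:
assert len(tokens) - 1 != experts.shape[0]                          # 6 != 13

# With the suggested fix, it holds again:
experts = experts[:(prompt_len - 1) + keep_tokens]
assert len(tokens) - 1 == experts.shape[0]                          # 6 == 6
```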
