
[trainer, perf] feat: wire use_mask_nesting end-to-end, nest prompts/responses#6009

Draft
tongyx361 wants to merge 13 commits into verl-project:main from
tongyx361:tyx/feat/nested-tensor

Conversation

@tongyx361 (Collaborator) commented Apr 15, 2026

What does this PR do?

Make the use_mask_nesting=True path usable end-to-end and push its
dispatch payload compression close to the theoretical ceiling.

The base commit b6e6562e feat: nested tensor utilities landed the
library-level utilities (nest_batch_by_mask / MaskNestingSpec /
extract_response) but running with trainer.use_mask_nesting=True
crashed at the first _compress_batch call. This PR:

  1. Fixes the bug chain blocking end-to-end execution of the nesting
    path (loss_mask / response_mask plumbing, response-axis shape
    mismatches).
  2. Introduces a canonical max_prompt_length / max_response_length
    stashed on the batch so extract_response and
    TensorDict.to_padded_tensor agree on the response-axis width.
  3. Extends KNOWN_FIELD_TO_MASK_AND_PAD with prompts / responses
    (which were silently staying dense and accounting for ~78% of
    post-nest wire bytes) by deriving a prompt_mask from
    attention_mask[:, :max_prompt] at compress time.
  4. Adds a coverage check inside nest_batch_by_mask that reports (or,
    with trainer.strict_mask_nesting=True, raises on) any tensor
    field still dense after nesting.

No related issue; the Slack / GitHub discussion on feature commit
b6e6562e provides the context. A duplicate search
(gh pr list --repo verl-project/verl --state open --search "use_mask_nesting"
/ "nest_batch_by_mask") returned nothing relevant.

Checklist Before Starting

  • Search for similar PRs — none open that touch this path:
    gh pr list --repo verl-project/verl --state open --search "use_mask_nesting"
  • PR title follows [{modules}] {type}: {description}:
    [trainer, perf] feat: make use_mask_nesting usable end-to-end, nest prompts/responses
    (no [BREAKING] — defaults unchanged; new trainer.strict_mask_nesting
    defaults to False).

Test

Not CI-testable in its full form (requires Ray + 16×A100 + vLLM
rollout). Validated on a live RL run: Qwen2.5-Math-7B, 2 nodes × 8×A100,
DAPO-Math-17k, max_prompt_length=2048, max_response_length=14336,
train_prompt_bsz=512, rollout.n=16 (flat batch 8192), adv=rloo,
Ulysses SP=4, FSDP=8, vLLM TP=1.

Dispatch timings, steady-state averages over steps 2-4:

| Metric | baseline (use_mask_nesting=False) | first end-to-end-working nest (pre-registry-fix) | this PR |
| --- | --- | --- | --- |
| wg_dispatch/actor_rollout_compute_log_prob | 8.52 s | 4.67 s | 3.33 s |
| wg_dispatch/actor_rollout_update_actor | 11.26 s | 5.40 s | 4.19 s |
| combined dispatch | 19.78 s | 10.07 s | 7.52 s |
| dispatch speedup vs baseline | — | 1.96× | 2.63× |
| step total | 557.6 s | 557.1 s | 526.8 s (−5.5%) |

Modelling dispatch time as T = α + β · bytes gives α ≈ 6 s (fixed
Ray RPC / serialization cost) and a payload-proportional term of
≈ 14 s at baseline payload size. The observed 2.63× speedup
back-solves to a payload compression of ≈ 11.2×, against a ~11.9×
theoretical ceiling set by the batch's 8.4% valid-token fill ratio.
The coverage check reports no remaining dense tensor fields on this
task.
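As a consistency check on that affine model, the two measured dispatch times plus the back-solved ~11.2× compression pin down α and β, and the fill ratio gives the ceiling. A minimal sketch in plain Python (numbers taken from the measurements above; the exact constants depend on which two points are fitted):

```python
def solve_affine(t1, c1, t2, c2):
    # Solve T = alpha + beta / c from two (dispatch time, payload compression)
    # observations, where beta is the payload-proportional term at compression 1x.
    beta = (t1 - t2) / (1.0 / c1 - 1.0 / c2)
    alpha = t1 - beta / c1
    return alpha, beta

# Baseline dense dispatch (compression 1x) vs this PR (back-solved ~11.2x).
alpha, beta = solve_affine(19.78, 1.0, 7.52, 11.2)  # alpha ~ 6.3 s, beta ~ 13.5 s

# Theoretical compression ceiling set by the 8.4% valid-token fill ratio.
ceiling = 1.0 / 0.084  # ~ 11.9x
```

The fitted α ≈ 6.3 s and β ≈ 13.5 s round to the α ≈ 6 s / β ≈ 14 s quoted above.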

Training loss / grad-norm / MFU remain stable across observed steps
(step 4 throughput 1216 tok/s, actor MFU 39.5%, actor_infer MFU 60.6%).

Validation runs:

  • Baseline (use_mask_nesting=False): job fd2a521f2e3e80d0, 100+
    steps completed — dispatch reference numbers.
  • Pre-PR (nesting on, no fix): job b09b7e4e0efee766 run_times=0 —
    crashes at forward_backward_batch with KeyError('loss_mask').
  • First end-to-end run (fixes only, no registry expansion):
    b09b7e4e0efee766 run_times=13 — training proceeds, dispatch 1.96×.
  • With full PR: forked job 23498c720a9d1e9c — 4 training steps
    complete, dispatch 2.63×, no warning, loss stable.

Numerical equivalence vs baseline (first 4 steps, same config fork,
same model init):

| step | baseline critic/score/mean | ours critic/score/mean | Δ |
| --- | --- | --- | --- |
| 1 | −0.9841 | −0.9849 | 8 × 10⁻⁴ |
| 2 | −0.9880 | −0.9856 | 2.4 × 10⁻³ |
| 3 | −0.9883 | −0.9863 | 2.0 × 10⁻³ |
| 4 | −0.9849 | −0.9849 | 0 |

score/{max,min} = ±1 match exactly on all four steps; entropy,
grad_norm, pg_loss, and response_length/mean all land in the same
distribution. The 10⁻³-level drift on score/mean corresponds to a
few dozen of the 8192 samples flipping binary correctness — well
within vLLM rollout non-determinism (PagedAttention scheduling order
perturbs logits at ULP scale, which under temperature=1.0 sampling
diverges trajectories even with identical prompts and model weights).
No evidence of systematic bias; consistent with nesting being purely
a transport-layer change.

Local checks:

pre-commit install
pre-commit run --all-files --show-diff-on-failure --color=always   # passes

CI unit tests still cover the utilities via the existing
tests/utils/test_nested_tensor_on_cpu.py /
tests/utils/test_padding_on_cpu.py from the base feature commit.

API and Usage Example

No required user-facing changes. Defaults preserve existing behaviour.

Two new trainer.* knobs:

trainer:
  use_mask_nesting: True          # existing; this PR makes it functional
  strict_mask_nesting: False      # NEW — set True in CI / dev to fail fast
                                  # when a new batch tensor field bypasses
                                  # nesting (i.e. absent from the registry).

Registry extension (library-side, no user action required):

# verl/workers/utils/padding.py
KNOWN_FIELD_TO_MASK_AND_PAD = {
    ...
    "prompts":   ("prompt_mask",   PAD_TOKEN_ID),  # NEW — prompt_mask is
                                                   # derived from
                                                   # attention_mask[:, :max_prompt]
                                                   # in _compress_batch
    "responses": ("response_mask", PAD_TOKEN_ID),  # NEW
    # log-prob / entropy fields moved from ("attention_mask", ...) to
    # ("response_mask", ...) because _decompress_model_outputs stores
    # them as (bsz, max_response_len).
    "old_log_probs":     ("response_mask", 0.0),   # CHANGED
    "ref_log_prob":      ("response_mask", 0.0),   # CHANGED
    "rollout_log_probs": ("response_mask", 0.0),   # CHANGED
    "entropys":          ("response_mask", 0.0),   # CHANGED
    ...
}
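The prompt_mask derivation is simple enough to illustrate standalone. A schematic with plain lists in place of tensors (the real code slices the batch's attention_mask tensor inside _compress_batch; derive_prompt_mask is an illustrative name, not actual verl API):

```python
# Schematic: prompts are left-padded, responses right-padded, and
# attention_mask covers the concatenated (prompt + response) axis,
# so the prompt mask is just the first max_prompt_length columns.
def derive_prompt_mask(attention_mask, max_prompt_length):
    return [row[:max_prompt_length] for row in attention_mask]

# Two samples, max_prompt_length=4, max_response_length=3:
# sample 0 has a 2-token prompt (left pad) and a 2-token response (right pad).
attention_mask = [
    [0, 0, 1, 1, 1, 1, 0],
    [0, 1, 1, 1, 1, 1, 1],
]
prompt_mask = derive_prompt_mask(attention_mask, 4)
# -> [[0, 0, 1, 1], [0, 1, 1, 1]]
```

Per-row mask sums then equal the real prompt token counts (2 and 3 here), which is what the nesting spec needs to drop the pad columns.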

New nest_batch_by_mask kwargs (both have backward-compatible defaults):

nest_batch_by_mask(
    data,
    pad_token_id=tok.pad_token_id,
    # new, default False — raise instead of warn on un-nested tensor fields
    strict=False,
    # new, default None — names of tensor fields that are intentionally
    # expected to stay dense and should be excluded from the coverage check
    ignore_dense_fields={"my_scalar_field"},
)
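The coverage check's warn-vs-raise behaviour can be sketched independently of TensorDict. The function name and the `name -> (is_nested, nbytes)` summary below are illustrative stand-ins, not the actual verl internals:

```python
def check_nesting_coverage(fields, ignore=frozenset(), strict=False):
    # fields: name -> (is_nested, nbytes). Collect tensor fields still dense
    # after nesting, minus any intentionally-dense names in `ignore`.
    offenders = {
        name: nbytes
        for name, (is_nested, nbytes) in fields.items()
        if not is_nested and name not in ignore
    }
    if offenders:
        msg = "dense fields after nesting: " + ", ".join(
            f"{k} ({v} bytes)" for k, v in offenders.items()
        )
        if strict:
            raise RuntimeError(msg)  # strict_mask_nesting=True behaviour
        print("WARNING:", msg)       # default: report and continue
    return offenders

fields = {
    "responses": (True, 0),          # nested — covered
    "my_scalar_field": (False, 64),  # intentionally dense
}
assert check_nesting_coverage(fields, ignore={"my_scalar_field"}) == {}
```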

New library helper used by ppo_loss / value_loss:

from verl.workers.utils.padding import select_and_pad_to_response

# replaces  data.select(*fields).to_padded_tensor()
# and right-pads the response dim to the batch-stashed
# max_response_length so log_prob (from extract_response) and
# old_log_probs / advantages / response_mask all share the same width.
data = select_and_pad_to_response(data, "response_mask", "old_log_probs", "advantages")
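The width reconciliation the helper performs amounts to right-padding every selected field's jagged response rows to one canonical width. A schematic with plain lists standing in for tensors (illustrative only — the real helper pads TensorDict fields and reads max_response_length off the batch):

```python
def right_pad_rows(rows, width, pad_value=0.0):
    # Right-pad each jagged row to the canonical response width so that
    # response-axis arithmetic sees matching shapes on every worker chunk.
    return [row + [pad_value] * (width - len(row)) for row in rows]

# A micro-batch whose local max length (3) is far below the canonical width (6):
log_prob     = right_pad_rows([[0.1, 0.2], [0.3, 0.4, 0.5]], 6)
old_log_prob = right_pad_rows([[0.1, 0.2], [0.3, 0.4, 0.5]], 6)

# log_prob - old_log_prob now broadcasts: both sides are (bs, 6).
deltas = [[a - b for a, b in zip(r1, r2)] for r1, r2 in zip(log_prob, old_log_prob)]
```

Without the shared width, one side would be padded to the local max and the other to the full-batch max, which is exactly the broadcast failure described below.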

Design & Code Changes

Bug fixes required to execute use_mask_nesting=True

Each bug surfaced only after the previous one was fixed; the fixes
are cumulative in the final diff.

  • _compress_batch: alias loss_mask = response_mask before
    nest_batch_by_mask (and list loss_mask in
    field_to_mask_and_pad) so it becomes a nested all-ones jagged
    field. Workers' forward_backward_batch was hitting
    KeyError('loss_mask') because nest_in_td pops response_mask,
    making the old post-nest alias branch dead.
  • KNOWN_FIELD_TO_MASK_AND_PAD: re-pair old_log_probs /
    ref_log_prob / rollout_log_probs / entropys with
    response_mask. _decompress_model_outputs stores them as
    (bsz, max_response_len), so the shape-prefix assertion against
    attention_mask (max_total) was firing in _update_actor.
  • _compress_batch: restore response_mask as a nested alias of
    loss_mask after nesting so worker-side ppo_loss / value_loss
    data.select("response_mask", ...) keeps working;
    _decompress_batch drops the alias before unnest_batch_by_mask.
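The shape-prefix invariant behind the second fix can be stated compactly. A schematic (mask pairings per the registry, shapes from the failing run; check_shape_prefix is an illustrative name, not the actual assertion helper):

```python
def check_shape_prefix(field_shape, mask_shape):
    # A field may only be nested by a mask whose shape is a prefix of
    # the field's own shape (same batch dims, same sample-mask dims).
    return field_shape[: len(mask_shape)] == mask_shape

# old_log_probs is stored as (bsz, max_response_len) after decompression,
# so pairing it with the full prompt+response attention_mask fails...
assert not check_shape_prefix((8192, 14336), (8192, 16384))
# ...while pairing it with response_mask matches.
assert check_shape_prefix((8192, 14336), (8192, 14336))
```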

Canonical max-length carried on the batch

Two independent code paths were deriving the response-axis output
width from tensor shapes and drifting after worker chunking:

  • extract_response / slice_response padded to the nesting spec's
    stashed mask_shape[-1] (original full-batch max).
  • TensorDict.to_padded_tensor() in loss fns padded to the
    per-micro-batch local max.

The two collided in compute_policy_loss_vanilla
(log_prob - old_log_prob: (bs, 14336) vs (bs, 301)). Fix:

  • _compress_batch writes
    config.data.{max_prompt_length, max_response_length} onto the batch
    via tu.assign_non_tensor.
  • prepare_response_slice reads the stashed value first, falls back
    to shape inference.
  • New select_and_pad_to_response in verl/workers/utils/padding.py
    replaces data.select(...).to_padded_tensor() in ppo_loss /
    value_loss and right-pads each selected field's response dim up to
    the canonical width, so every response-axis arithmetic sees matching
    shapes regardless of chunking.

Coverage / strict mode

  • prompts paired with a new prompt_mask (derived from
    attention_mask[:, :max_prompt_length] in _compress_batch, popped
    by _decompress_batch after unnest) and responses paired with
    response_mask.
  • nest_batch_by_mask walks the TD once after nesting and reports any
    tensor field that stayed dense, excluding
    _DEFAULT_IGNORE_DENSE = {"dummy_tensor"} and spec-owned RLE
    offsets/lengths. Offenders are printed with shape / dtype /
    bytes.
  • New trainer.strict_mask_nesting flag (default False): when
    True, the coverage check raises RuntimeError instead of warning.

Files touched

 verl/trainer/config/ppo_trainer.yaml               |  +9
 verl/trainer/config/_generated_ppo_*_trainer.yaml  |  +8 (autogen)
 verl/trainer/ppo/ray_trainer.py                    |  +40 −3
 verl/workers/utils/padding.py                      | +142 −12
 verl/workers/utils/losses.py                       |  +4 −3

Checklist Before Submitting

  • Read the Contribute Guide.
  • pre-commit run --all-files passes.
  • Add / Update documentation — not yet; the feature commit
    b6e6562e did not add user docs for use_mask_nesting; a
    follow-up doc PR will describe the switch and the new
    strict_mask_nesting flag once this lands.
  • Add CI test — existing CPU tests in
    tests/utils/test_nested_tensor_on_cpu.py and
    tests/utils/test_padding_on_cpu.py continue to cover the
    library utilities. The end-to-end behaviour (training-step
    convergence with use_mask_nesting=True) requires a GPU-backed
    e2e test; I can extend
    .github/workflows/e2e_ppo_trainer.yml with a
    use_mask_nesting=True variant in a follow-up once reviewers
    are okay with the test-hour budget.
  • ci-request Slack ping — will do after addressing review.
  • No recipe submodule change.

AI-assisted: written with Claude (Claude Opus 4.6 1M context, via
Claude Code). Every changed line was reviewed by the submitter and the
rationale was validated end-to-end on the running RL workload before
being committed. Commits carry
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
trailers.

tongyx361 and others added 12 commits April 14, 2026 00:04
- MaskNestingSpec-based nest/unnest for TensorDict (mask-driven RLE)
- Field registry (KNOWN_FIELD_TO_MASK_AND_PAD) with DynamicPadValue for token IDs
- Dtype compression (KNOWN_FIELD_DTYPE_COMPRESSIONS) and shape permutation (KNOWN_FIELD_PERMUTATIONS)
- Two-layer response extraction: UnnestContext + ResponseSliceContext (library vs PPO)
- extract_response() dispatching over legacy/new path for A/B experiments
- data_io timing metrics: compress, decompress, wg_dispatch/execute/collect per method
- Tests: nest/unnest roundtrip, dispatch compat, response extraction, legacy equivalence

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
`nest_batch_by_mask` pops response_mask, so the post-nest
`if "response_mask" in batch_td` check never fired and loss_mask was
never aliased in the nesting path. Workers then hit
KeyError('loss_mask') in forward_backward_batch.

Alias before nesting and register loss_mask in the nesting spec so it
gets nested alongside response_mask-paired fields as an all-ones
jagged tensor with per-row length equal to the response token count.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
`old_log_probs`, `ref_log_prob`, `rollout_log_probs`, `entropys` are
stored back into the batch as (bsz, max_response_len) after
`_decompress_model_outputs` -> `extract_response`. Pairing them with
`attention_mask` (full prompt+response axis) in
KNOWN_FIELD_TO_MASK_AND_PAD tripped the shape-prefix assertion in
`make_mask_nesting_specs` the next time `_compress_batch` was called
(from `_update_actor`):

  AssertionError: field 'old_log_probs' has shape (8192, 14336);
  expected leading (*batch_dims, *sample_mask_dims) to match mask
  'attention_mask' shape (8192, 16384).

Move them to `response_mask` alongside the other response-only
quantities (advantages / returns / ...). The `teacher_*` distillation
pair and `routed_experts` remain on `attention_mask` since they are
full-length.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Worker-side loss fns (ppo_loss, value_loss) still select "response_mask"
via `data.select(...)`. In the nested path, nest_in_td pops
response_mask, so train_batch hit:

  KeyError: 'key "response_mask" not found in TensorDict ...'
  (in losses.py ppo_loss: data = data.select(*fields).to_padded_tensor())

Alias `response_mask = loss_mask` post-nest — both refer to the same
all-ones nested tensor in this mode. Drop the alias before
unnest_batch_by_mask in _decompress_batch so the real 2D mask
rehydrates correctly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
`prepare_response_slice` pinned ``max_response_len`` to the spec-stashed
full-batch ``orig_response_len`` (e.g. 14336), but worker-side code
paths such as ``ppo_loss`` unpack nested response-axis fields via
``TensorDict.to_padded_tensor()`` which pads to the *per-micro-batch*
local max (e.g. 301). The two shapes then refused to broadcast:

  RuntimeError: The size of tensor a (14336) must match the size of
  tensor b (301) at non-singleton dimension 1
  (in compute_policy_loss_vanilla: log_prob - old_log_prob)

Use ``response_lens.max()`` so extract_response's output dense tensor
has the same local-max width as the other response-axis fields after
``.to_padded_tensor()``. The mask still carries exact per-row lengths,
so nothing downstream is lost by the tighter padding.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…cal width

Response-axis shapes were derived from tensor shapes (MaskNestingSpec
``mask_shape``, or per-micro-batch jagged max) in two independent places
and drifted apart after worker chunking:

  * ``extract_response``/``slice_response`` pad to the nesting spec's
    stashed ``mask_shape[-1]`` — the *original full-batch* max (e.g.
    14336).
  * ``TensorDict.to_padded_tensor()`` in worker loss code pads to the
    *per-micro-batch local* max (e.g. 301).

When the two widths collide downstream (``log_prob - old_log_prob``
inside ``compute_policy_loss_vanilla``) PyTorch refuses to broadcast:

  RuntimeError: The size of tensor a (14336) must match the size of
  tensor b (301) at non-singleton dimension 1

Record the task-level ``max_prompt_length`` / ``max_response_length``
straight from ``config.data`` onto the batch (via ``tu.assign_non_tensor``)
in ``_compress_batch``. Read them back as the single source of truth:

  * ``prepare_response_slice`` prefers the stashed config width over
    shape inference.
  * New ``select_and_pad_to_response`` helper replaces
    ``.select(...).to_padded_tensor()`` in ``ppo_loss`` / ``value_loss``
    and right-pads the response dim of every selected field up to the
    same canonical width.

Both sides of every response-axis arithmetic now produce
``(bs, max_response_length)``, independent of whichever chunk landed on
a worker. Reverts 568e93e's narrower local-max workaround.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Snapshots the real pre-compress TensorDict, runs the legacy
(left_right_2_no_padding) and nesting (nest_batch_by_mask) paths
side-by-side on identical clones, and prints per-field
dtype/shape/nested flag/bytes plus total raw bytes and cloudpickle
size. Fires exactly once per trainer instance (guarded by
``_bench_compress_done``), so later steps are untouched.

Temporary — revert after collecting the numbers.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Payload bench on a real training batch (8192 samples, fill ratio 8.4%)
showed that ``nest_batch_by_mask`` was achieving only ~3.4× compression
over the legacy dense path, far below the ~12× ceiling implied by the
actual token occupancy. The culprit was that ``prompts`` (134 MB) and
``responses`` (940 MB) were missing from ``KNOWN_FIELD_TO_MASK_AND_PAD``
and silently stayed dense — together accounting for ~78% of the
post-nest bytes on the wire.

Fix it in three parts:

1. Registry entries for both:

   * ``"responses": ("response_mask", PAD_TOKEN_ID)`` — right-padded
     response token IDs, already response-axis aligned.
   * ``"prompts": ("prompt_mask", PAD_TOKEN_ID)`` — left-padded prompt
     token IDs, paired with a new ``prompt_mask`` field derived from
     ``attention_mask[:, :max_prompt_length]`` in ``_compress_batch``.
     ``_decompress_batch`` drops ``prompt_mask`` post-unnest since the
     trainer never had it to begin with.

2. Coverage check inside ``nest_batch_by_mask``: after all specs run,
   scan the batch for any tensor field that is still dense (ignoring
   RLE offsets/lengths and a small default-ignore set). Report each
   offender with ``shape`` / ``dtype`` / ``bytes`` so the reader can
   either add a registry entry or opt out via ``ignore_dense_fields``.

3. New ``strict`` kwarg on ``nest_batch_by_mask`` (wired to
   ``trainer.strict_mask_nesting``, default False). On ``True`` the
   coverage check raises ``RuntimeError`` instead of warning — useful
   in CI / dev setups to catch regressions when new fields are added
   to the batch without a matching registry entry.

Expected post-change compression on the same batch: ~11.9× vs legacy
(close to the 12× theoretical upper bound set by the fill ratio).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a mask-driven Run-Length Encoding (RLE) nesting mechanism for variable-length sequences, providing an alternative to the legacy flash_attn-dependent padding logic. Key additions include RLE utilities in verl/utils/nested_tensor.py, integration into the RayPPOTrainer for batch compression/decompression, and updated loss functions. Review feedback points out a potential IndexError in the response slicing logic when handling empty rows and advises using public PyTorch APIs instead of internal NestedTensor attributes to maintain compatibility with future versions.

Resolved review threads (now outdated): verl/utils/nested_tensor.py, verl/workers/utils/padding.py
Addresses gemini-code-assist review on verl-project#6009:

* ``_rle_scatter_indices`` in ``verl/utils/nested_tensor.py`` was
  reaching into the private ``NestedTensor._offsets`` attribute.
  Switch to the public ``offsets()`` method, which returns the same
  jagged-layout offsets and is stable across PyTorch releases.

* ``prepare_response_slice`` in ``verl/workers/utils/padding.py`` did
  the same, and additionally would pick up the next non-empty row's
  first segment when a row's mask had zero True positions (pre-LLM
  edge cases, defence-in-depth for reuse outside RL). Switch to the
  public API and gate the scatter on ``has_segments`` so empty rows
  keep ``abs_start = 0`` and don't corrupt the slice bounds for rows
  that follow them.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>