
feat: add TargetPO trainer #5591

Draft

JeanKaddour wants to merge 14 commits into huggingface:main from JeanKaddour:codex/add-tpo-trainer-template

Conversation

JeanKaddour commented Apr 18, 2026

What does this PR do?

Adds Target Policy Optimization (TPO) as a GRPO-family loss with a thin first-class trainer wrapper.

This PR:

  • adds loss_type="tpo" to GRPOTrainer, using the target q_i ∝ p_i^old · exp(u_i / eta) and a sequence-level target-matching loss (see the sketch after this list)
  • builds TPO targets from old-policy sequence logprobs and TPO-whitened group rewards, matching the reference tpo_skill population-standardization behavior
  • adds TPOConfig and TPOTrainer as convenience wrappers with TPO defaults and metadata
  • documents the minimal TPOTrainer usage, prompt-only dataset format, and paper-index entry
  • adds focused tests for TPO config defaults, target construction, TPO score whitening, the p - q gradient, and public imports
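
A minimal sketch of the target construction and matching loss described above, assuming per-group tensors of old-policy sequence log-probs and whitened rewards; the helper names are illustrative, not the PR's actual API:

```python
import torch.nn.functional as F

def tpo_targets(old_logps, rewards, eta=1.0):
    """Sketch of q_i ∝ p_i^old * exp(u_i / eta), computed in log space.

    old_logps: (num_groups, num_generations) old-policy sequence log-probs
    rewards:   (num_groups, num_generations) TPO-whitened group rewards u_i
    """
    logits = old_logps + rewards / eta   # log p_i^old + u_i / eta (the constant shift cancels)
    return F.softmax(logits, dim=-1)     # normalize within each prompt group

def tpo_loss(policy_seq_logps, targets):
    """Sequence-level target matching: cross-entropy between the fixed target q
    and the policy's group softmax over sequence log-probs."""
    log_p = F.log_softmax(policy_seq_logps, dim=-1)
    return -(targets.detach() * log_p).sum(dim=-1).mean()
```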

Validation

  • uv run --extra test pytest tests/test_tpo_trainer.py
  • uvx ruff check trl/trainer/grpo_trainer.py tests/test_tpo_trainer.py
  • one-step TPOTrainer smoke training with uv run --extra test python

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Maintainers familiar with GRPO internals and online trainer APIs would be especially helpful reviewers.


Note

Medium Risk
Adds a new online loss path (loss_type="tpo") and distributed, autograd-aware gathering in the GRPO-family training loop, which can affect training correctness and performance across multi-GPU setups. Risk is mitigated by extensive new unit/distributed tests and input validation, but this still touches core optimization logic.

Overview
Adds Target Policy Optimization (TPO) support to the GRPO family by introducing loss_type="tpo" with new config knobs (tpo_target_temperature, tpo_length_normalize_logps), target construction from rollout-time sequence logprobs plus whitened group rewards, and a sequence-level cross-entropy training objective (optionally with KL regularization).
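
A minimal configuration sketch for those knobs, assuming they are plain GRPOConfig fields as the overview describes (values are illustrative, not the PR's defaults):

```python
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="targetpo-demo",
    loss_type="tpo",                  # new GRPO-family loss added by this PR
    tpo_target_temperature=1.0,       # eta in q_i ∝ p_i^old * exp(u_i / eta)
    tpo_length_normalize_logps=True,  # per-token mean instead of summed sequence logprobs
    num_generations=8,
)
```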

Introduces first-class TargetPOConfig/TargetPOTrainer, exports them publicly, and provides an example training script. Updates docs to list TargetPO across the trainer index, dataset-format expectations, paper index, and performance/memory guides.

Adds comprehensive tests covering config validation, target/score math, gradient correctness (including prompt groups spanning distributed ranks via autograd-aware all-gather), and Liger-kernel loss parity where available.

Reviewed by Cursor Bugbot for commit 46d1fd1.

Outdated review thread on trl/scripts/tpo.py.
JeanKaddour force-pushed the codex/add-tpo-trainer-template branch from 572a27a to a0f11d7 on April 18, 2026 at 12:54.
JeanKaddour changed the title from "[codex] Add TPO trainer" to "feat: add TPO trainer" on Apr 18, 2026.

cursor (bot) left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Reviewed by Cursor Bugbot for commit a0f11d7.

Outdated review thread on docs/source/tpo_trainer.md.
JeanKaddour force-pushed the codex/add-tpo-trainer-template branch from a0f11d7 to 5db4cb3 on April 18, 2026 at 13:04.

kashif (Collaborator) commented Apr 18, 2026

Unfortunate name for the method, as there is already a TPO method (#5506). Perhaps this method can be called TargetPO, or another variation?

JeanKaddour (Author) commented:

Oh, I wasn't aware! Thanks for letting me know, will rename soon

The acronym "TPO" clashes with the experimental Triple Preference
Optimization trainer (huggingface#5506). Rename the Python classes and module
paths to TargetPO while keeping loss_type="tpo" inside GRPOConfig
untouched - scoped to GRPO, the loss identifier is unambiguous.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
JeanKaddour changed the title from "feat: add TPO trainer" to "feat: add TargetPO trainer" on Apr 20, 2026.
kashif and others added 11 commits April 21, 2026 15:45
TPO's target is normalized per prompt group, so the loss is well-defined
whenever each optimization step contains whole groups. Replace the hard
`steps_per_generation == 1` requirement with a divisibility check on
`generation_batch_size // steps_per_generation`, and update the
TargetPOConfig docs and tests accordingly.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
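
A rough sketch of the relaxed check this commit describes; attribute names follow GRPOConfig, but the exact wording and placement are assumptions:

```python
# Each optimization step must contain whole prompt groups for the per-group
# TPO target normalization to be well-defined.
per_step_batch = self.generation_batch_size // self.steps_per_generation
if per_step_batch % self.num_generations != 0:
    raise ValueError(
        f"loss_type='tpo' needs whole prompt groups per optimization step, but the "
        f"per-step batch ({per_step_batch}) is not divisible by "
        f"num_generations={self.num_generations}."
    )
```
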
Move torch.distributed.nn.functional.all_gather to module-level imports
(aliased _all_gather_with_grad) instead of importing inside the hot path,
and add a 2-rank gloo+mp.spawn test that verifies the autograd-aware gather
routes gradients correctly when a TPO prompt group spans DP ranks.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
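
A sketch of the module-level import and autograd-aware gather this commit refers to; the helper function below is illustrative:

```python
import torch
# Unlike torch.distributed.all_gather, this functional variant is differentiable:
# gradients flow back to the rank that produced each gathered slice.
from torch.distributed.nn.functional import all_gather as _all_gather_with_grad

def gather_group_logps(local_seq_logps: torch.Tensor) -> torch.Tensor:
    # Concatenate per-rank sequence log-probs so a prompt group that spans
    # data-parallel ranks can be softmaxed as a single group.
    return torch.cat(_all_gather_with_grad(local_seq_logps), dim=0)
```
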
Mirror the GSPO example layout for examples/scripts/target_po.py so users
have a canonical launch template, and add a tiny-model smoke test that
trains TargetPOTrainer end-to-end and asserts parameters move.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
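
A minimal launch sketch in the spirit of that example; it assumes the TargetPOConfig/TargetPOTrainer API mirrors GRPOConfig/GRPOTrainer, and the dataset and reward function are placeholders:

```python
from datasets import load_dataset
from trl import TargetPOConfig, TargetPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 50 characters.
    return [-abs(50 - len(c)) for c in completions]

trainer = TargetPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=TargetPOConfig(output_dir="targetpo-tldr"),
    train_dataset=dataset,
)
trainer.train()
```
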
Without normalization, summed sequence log-probabilities span orders of
magnitude across a prompt group, so log_softmax(old_sequence_logps) collapses
to a near-one-hot on the longest/most-likely completion and the reward term
in q_i ∝ p_i^old · exp(u_i / eta) cannot compete. Length-normalizing (per-token
mean instead of sum) keeps the old-policy term bounded so rewards actually
shape the target.

Add tpo_length_normalize_logps (default True) to control this; set False to
recover the paper-literal sequence-probability formulation. Parametrize the
gradient test over the flag so both paths stay correct.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
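
A sketch of the two reductions the flag toggles; the function name is illustrative:

```python
import torch

def sequence_logps(per_token_logps: torch.Tensor, mask: torch.Tensor,
                   length_normalize: bool = True) -> torch.Tensor:
    """Collapse per-token log-probs into one score per sequence.

    tpo_length_normalize_logps=True  -> per-token mean, keeping the old-policy
    term in q_i ∝ p_i^old * exp(u_i / eta) on a bounded scale.
    tpo_length_normalize_logps=False -> paper-literal summed log-probability.
    """
    mask = mask.float()
    summed = (per_token_logps * mask).sum(dim=-1)
    if length_normalize:
        return summed / mask.sum(dim=-1).clamp(min=1.0)
    return summed
```
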
When tool masking or truncation zeros out every token in a completion, that
completion has no log-probability gradient to give and shouldn't shape the
target distribution. Plumb a tpo_valid_mask (loss_mask.any(dim=-1), gathered
across processes) through get_tpo_scores and get_tpo_targets, and apply it in
the loss-time group softmax so masked completions get zero target probability
and are excluded from the log-softmax normalizer.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
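
A sketch of the masked group softmax described above; names are illustrative:

```python
import torch
import torch.nn.functional as F

def masked_group_softmax(scores: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
    # Completions whose loss mask is all zeros contribute neither probability
    # mass nor normalization, so they receive exactly zero target probability.
    scores = scores.masked_fill(~valid_mask, float("-inf"))
    probs = F.softmax(scores, dim=-1)
    return probs.masked_fill(~valid_mask, 0.0)  # also covers an all-masked group
```
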
When num_processes > 1 and steps_per_generation > 1, the per-step batch is
sliced both across ranks (per_device_train_batch_size) and across sub-steps.
For TPO's group softmax to stay coherent, each rank must hold whole groups
within a sub-step — i.e. per_device_train_batch_size must be divisible by
num_generations. Otherwise prompt groups straddle rank boundaries within a
loss step and the view(-1, num_generations) reshape is wrong.

Add the check to GRPOConfig.__post_init__ and cover both the rejection path
and the legal one-step-per-generation case where groups may still span ranks
(handled by the autograd-aware all_gather).

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
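
A rough sketch of the added __post_init__ guard; the process-count lookup and the error message are assumptions:

```python
# Assumes world_size is available at config time (e.g. via TrainingArguments).
if (
    self.loss_type == "tpo"
    and self.world_size > 1
    and self.steps_per_generation > 1
    and self.per_device_train_batch_size % self.num_generations != 0
):
    raise ValueError(
        "With loss_type='tpo', multiple processes and steps_per_generation > 1, "
        "per_device_train_batch_size must be divisible by num_generations so each "
        "rank holds whole prompt groups within a sub-step."
    )
```
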
Per the repo convention that each trainer is independently readable, replace
the thin GRPOTrainer subclass with a full self-contained TargetPOTrainer that
carries its own copy of the online rollout, reward, generation, vLLM, and
loss machinery. Behavior is unchanged — the TPO target construction, length
normalization, valid-mask handling, and distributed gradient routing all
match the GRPOTrainer(loss_type='tpo') path.

Update the test files to call TargetPOTrainer.* helpers directly instead of
reaching into GRPOTrainer.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>