feat: add TargetPO trainer #5591
Draft
JeanKaddour wants to merge 14 commits into huggingface:main from
Conversation
Collaborator
Unfortunate name for the method, as there is already a TPO method (#5506); perhaps this method can be called …
Author
Oh, I wasn't aware! Thanks for letting me know, will rename soon.
The acronym "TPO" clashes with the experimental Triple Preference Optimization trainer (huggingface#5506). Rename the Python classes and module paths to TargetPO while keeping loss_type="tpo" inside GRPOConfig untouched - scoped to GRPO, the loss identifier is unambiguous. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
TPO's target is normalized per prompt group, so the loss is well-defined whenever each optimization step contains whole groups. Replace the hard `steps_per_generation == 1` requirement with a divisibility check on `generation_batch_size // steps_per_generation`, and update the TargetPOConfig docs and tests accordingly. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
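For reference, a minimal sketch of the relaxed check; the helper name and error wording are illustrative, and the config field names (`generation_batch_size`, `steps_per_generation`, `num_generations`) follow the commit message rather than the actual implementation:

```python
def check_tpo_batch_layout(generation_batch_size: int, steps_per_generation: int,
                           num_generations: int) -> None:
    # Each optimization step must contain whole prompt groups so the TPO
    # group softmax (and its per-group target) is well-defined.
    optim_batch_size = generation_batch_size // steps_per_generation
    if optim_batch_size % num_generations != 0:
        raise ValueError(
            f"loss_type='tpo' needs whole prompt groups per optimization step: "
            f"{optim_batch_size} completions per step is not divisible by "
            f"num_generations={num_generations}."
        )
```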
Move torch.distributed.nn.functional.all_gather to module-level imports (aliased _all_gather_with_grad) instead of importing inside the hot path, and add a 2-rank gloo+mp.spawn test that verifies the autograd-aware gather routes gradient correctly when a TPO prompt group spans DP ranks. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
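A minimal sketch of the differentiable gather; the wrapper function is hypothetical, but `torch.distributed.nn.functional.all_gather` is the real autograd-aware variant:

```python
import torch
from torch.distributed.nn.functional import all_gather as _all_gather_with_grad

def gather_group_logps(local_logps: torch.Tensor) -> torch.Tensor:
    # Unlike torch.distributed.all_gather, this variant keeps autograd edges,
    # so a group softmax over the gathered tensor backpropagates into every
    # rank's local log-probs even when a prompt group spans DP ranks.
    return torch.cat(_all_gather_with_grad(local_logps), dim=0)
```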
Mirror the GSPO example layout for examples/scripts/target_po.py so users have a canonical launch template, and add a tiny-model smoke test that trains TargetPOTrainer end-to-end and asserts parameters move. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
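A hypothetical minimal launch in the spirit of that example; the model name, toy reward, and `TargetPOConfig` fields are illustrative, not the script's actual contents:

```python
from datasets import load_dataset
from trl import TargetPOConfig, TargetPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions near 50 characters.
    return [-abs(50 - len(c)) for c in completions]

trainer = TargetPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=TargetPOConfig(output_dir="targetpo-tldr"),
    train_dataset=dataset,
)
trainer.train()
```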
Without normalization, summed sequence log-probabilities span orders of magnitude across a prompt group, so log_softmax(old_sequence_logps) collapses to a near-one-hot on the longest/most-likely completion and the reward term in q_i ∝ p_i^old · exp(u_i / eta) cannot compete. Length-normalizing (per-token mean instead of sum) keeps the old-policy term bounded so rewards actually shape the target. Add tpo_length_normalize_logps (default True) to control this; set False to recover the paper-literal sequence-probability formulation. Parametrize the gradient test over the flag so both paths stay correct. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
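Sketched in code; the function, argument names, and shapes are illustrative, not the trainer's actual internals:

```python
import torch

def tpo_target(old_logps_sum: torch.Tensor, num_tokens: torch.Tensor,
               rewards: torch.Tensor, eta: float,
               length_normalize: bool = True) -> torch.Tensor:
    # old_logps_sum: (G,) summed old-policy log-probs for one prompt group
    # num_tokens:    (G,) valid completion lengths
    # rewards:       (G,) whitened group rewards u_i
    # Target q_i ∝ p_i^old · exp(u_i / eta), computed in log space. Without
    # length normalization, summed log-probs spanning orders of magnitude
    # make this softmax nearly one-hot on the most likely completion.
    old_term = old_logps_sum / num_tokens.clamp(min=1) if length_normalize else old_logps_sum
    return torch.softmax(old_term + rewards / eta, dim=-1)
```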
When tool masking or truncation zeros out every token in a completion, that completion has no log-probability gradient to give and shouldn't shape the target distribution. Plumb a tpo_valid_mask (loss_mask.any(dim=-1), gathered across processes) through get_tpo_scores and get_tpo_targets, and apply it in the loss-time group softmax so masked completions get zero target probability and are excluded from the log-softmax normalizer. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
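A sketch of the masked group softmax; argument names and shapes are assumptions:

```python
import torch

def masked_group_softmax(scores: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
    # scores:     (num_prompts, num_generations) TPO target logits per group
    # valid_mask: same shape, False where every completion token is masked
    # -inf entries drop out of the softmax normalizer and receive exactly
    # zero target probability, so fully-masked completions can't shape q.
    return torch.softmax(scores.masked_fill(~valid_mask, float("-inf")), dim=-1)
```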
When num_processes > 1 and steps_per_generation > 1, the per-step batch is sliced both across ranks (per_device_train_batch_size) and across sub-steps. For TPO's group softmax to stay coherent, each rank must hold whole groups within a sub-step — i.e. per_device_train_batch_size must be divisible by num_generations. Otherwise prompt groups straddle rank boundaries within a loss step and the view(-1, num_generations) reshape is wrong. Add the check to GRPOConfig.__post_init__ and cover both the rejection path and the legal one-step-per-generation case where groups may still span ranks (handled by the autograd-aware all_gather). Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
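Roughly, the new check amounts to the following sketch; the helper is hypothetical, and how `num_processes` reaches `GRPOConfig.__post_init__` is an implementation detail:

```python
def check_tpo_rank_layout(per_device_train_batch_size: int, num_generations: int,
                          steps_per_generation: int, num_processes: int) -> None:
    # With multiple sub-steps per generation, each rank's per-step slice must
    # hold whole prompt groups, or view(-1, num_generations) mixes groups.
    if num_processes > 1 and steps_per_generation > 1:
        if per_device_train_batch_size % num_generations != 0:
            raise ValueError(
                f"per_device_train_batch_size={per_device_train_batch_size} must "
                f"be divisible by num_generations={num_generations} when "
                f"steps_per_generation > 1 on multiple processes."
            )
```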
Per the repo convention that each trainer is independently readable, replace the thin GRPOTrainer subclass with a full self-contained TargetPOTrainer that carries its own copy of the online rollout, reward, generation, vLLM, and loss machinery. Behavior is unchanged — the TPO target construction, length normalization, valid-mask handling, and distributed gradient routing all match the GRPOTrainer(loss_type='tpo') path. Update the test files to call TargetPOTrainer.* helpers directly instead of reaching into GRPOTrainer. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

What does this PR do?
Adds Target Policy Optimization (TPO) as a GRPO-family loss with a thin first-class trainer wrapper.
This PR:
loss_type="tpo"toGRPOTrainer, using the targetq_i ∝ p_i_old * exp(u_i / eta)and a sequence-level target-matching losstpo_skillpopulation-standardization behaviorTPOConfigandTPOTraineras convenience wrappers with TPO defaults and metadataTPOTrainerusage, prompt-only dataset format, and paper-index entryp - qgradient, and public importsValidation
- `uv run --extra test pytest tests/test_tpo_trainer.py`
- `uvx ruff check trl/trainer/grpo_trainer.py tests/test_tpo_trainer.py`
- `TPOTrainer` smoke training with `uv run --extra test python …`

Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Maintainers familiar with GRPO internals and online trainer APIs would be especially helpful reviewers.
Note
Medium Risk
Adds a new online loss path (`loss_type="tpo"`) and distributed, autograd-aware gathering in the GRPO-family training loop, which can affect training correctness and performance across multi-GPU setups. Risk is mitigated by extensive new unit/distributed tests and input validation, but this still touches core optimization logic.

Overview

Adds Target Policy Optimization (TPO) support to the GRPO family by introducing `loss_type="tpo"` with new config knobs (`tpo_target_temperature`, `tpo_length_normalize_logps`), target construction from rollout-time sequence logprobs plus whitened group rewards, and a sequence-level cross-entropy training objective (optionally with KL regularization).

Introduces first-class `TargetPOConfig`/`TargetPOTrainer`, exports them publicly, and provides an example training script. Updates docs to list TargetPO across the trainer index, dataset-format expectations, paper index, and performance/memory guides.

Adds comprehensive tests covering config validation, target/score math, gradient correctness (including prompt groups spanning distributed ranks via autograd-aware all-gather), and Liger-kernel loss parity where available.
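For intuition, the sequence-level objective is a cross-entropy between the fixed target `q` and the policy's group softmax `p`, whose per-row logit gradient is proportional to `p - q`. A sketch with illustrative names:

```python
import torch

def tpo_loss(policy_group_logits: torch.Tensor, target_q: torch.Tensor) -> torch.Tensor:
    # policy_group_logits: (num_prompts, num_generations) policy sequence scores
    # target_q:            detached TPO target distribution per prompt group
    # Cross-entropy H(q, p); its gradient w.r.t. each row of logits is
    # proportional to p - q, which is what the gradient tests assert.
    log_p = torch.log_softmax(policy_group_logits, dim=-1)
    return -(target_q * log_p).sum(dim=-1).mean()
```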