2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -25,6 +25,8 @@
title: RLOO
- local: sft_trainer
title: SFT
- local: target_po_trainer
title: TargetPO
title: Trainers
- sections:
- local: clis
1 change: 1 addition & 0 deletions docs/source/dataset_formats.md
@@ -411,6 +411,7 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
| [`TargetPOTrainer`] | [Prompt-only](#prompt-only) |
| [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`experimental.cpo.CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) |
1 change: 1 addition & 0 deletions docs/source/index.md
@@ -25,6 +25,7 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL

- [`GRPOTrainer`](grpo_trainer) ⚡️
- [`RLOOTrainer`](rloo_trainer) ⚡️
- [`TargetPOTrainer`](target_po_trainer) ⚡️
- [`OnlineDPOTrainer`](online_dpo_trainer) 🧪 ⚡️
- [`NashMDTrainer`](nash_md_trainer) 🧪 ⚡️
- [`PPOTrainer`](ppo_trainer) 🧪
10 changes: 10 additions & 0 deletions docs/source/liger_kernel_integration.md
@@ -14,6 +14,7 @@ Liger Kernel is supported in the following TRL trainers:
- **SFT** (Supervised Fine-Tuning)
- **DPO** (Direct Preference Optimization)
- **GRPO** (Group Relative Policy Optimization)
- **TargetPO** (Target Policy Optimization)
- **KTO** (Kahneman-Tversky Optimization)
- **GKD** (Generalized Knowledge Distillation)

@@ -54,6 +55,15 @@ from trl import GRPOConfig
training_args = GRPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="TargetPO">

```python
from trl import TargetPOConfig

training_args = TargetPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="KTO">

24 changes: 24 additions & 0 deletions docs/source/paper_index.md
@@ -668,6 +668,30 @@ trainer.train()

The official code [sail-sg/Stable-RL](https://github.com/sail-sg/Stable-RL)

### Target Policy Optimization

**📜 Paper**: https://huggingface.co/papers/2604.06159

Target Policy Optimization (TPO) builds a target distribution over each prompt's sampled completions using rollout
policy probabilities and normalized rewards, then trains the policy to match that target with sequence-level
cross-entropy. To use TPO in TRL, use [`TargetPOTrainer`] or set `loss_type="tpo"` in [`GRPOConfig`]. The Python
class is named `TargetPO` to avoid collision with the experimental Triple Preference Optimization trainer that
shares the same acronym.

```python
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
loss_type="tpo",
tpo_target_temperature=1.0,
)

trainer = GRPOTrainer(
...,
args=training_args,
)
```

## Direct Preference Optimization

Papers relating to the [`DPOTrainer`]
9 changes: 9 additions & 0 deletions docs/source/reducing_memory_usage.md
@@ -159,6 +159,15 @@ from trl import GRPOConfig
training_args = GRPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="TargetPO">

```python
from trl import TargetPOConfig

training_args = TargetPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="KTO">

9 changes: 9 additions & 0 deletions docs/source/speeding_up_training.md
@@ -168,6 +168,15 @@ from trl import GRPOConfig
training_args = GRPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="TargetPO">

```python
from trl import TargetPOConfig

training_args = TargetPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="KTO">

47 changes: 47 additions & 0 deletions docs/source/target_po_trainer.md
@@ -0,0 +1,47 @@
# TargetPO Trainer

## Overview

[`TargetPOTrainer`] implements Target Policy Optimization (TPO), an online post-training algorithm from [Target Policy Optimization](https://huggingface.co/papers/2604.06159).

[`TargetPOTrainer`] follows the same online rollout and reward flow as [`GRPOTrainer`], but trains the policy
against a sequence-level cross-entropy target:

$$
q_i = \frac{p_i^{\text{old}} \exp(u_i / \eta)}{\sum_j p_j^{\text{old}} \exp(u_j / \eta)}
$$

Here \\(p_i^{\text{old}}\\) is a *length-normalized* proxy for the rollout policy probability of completion \\(i\\) in
the prompt group (per-token mean log-probability by default, controlled by `tpo_length_normalize_logps`), \\(u_i\\) is
the population-whitened group reward, and \\(\eta\\) is `tpo_target_temperature`. Length-normalization prevents the
old-policy term from dominating the target when completions in a group have different lengths; set
`tpo_length_normalize_logps=False` to recover the paper's literal sequence-probability formulation.
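
A minimal PyTorch sketch of this target construction, written directly from the formula above; the helper name, the example values, and the epsilon guard on the reward standard deviation are illustrative and not part of the trainer's API:

```python
import torch


def tpo_group_targets(old_logps, completion_lengths, rewards, eta=1.0, length_normalize=True):
    """Build the TPO target distribution q over one prompt group of G completions.

    old_logps: (G,) summed token log-probabilities under the rollout policy.
    completion_lengths: (G,) token counts of the completions.
    rewards: (G,) scalar rewards for the group.
    """
    # Length-normalized proxy for the rollout probability (per-token mean log-prob),
    # analogous to `tpo_length_normalize_logps=True`.
    logp_old = old_logps / completion_lengths if length_normalize else old_logps
    # Population-whitened group reward u_i (epsilon guard added for this example).
    u = (rewards - rewards.mean()) / rewards.std(correction=0).clamp_min(1e-6)
    # q_i ∝ p_i^old * exp(u_i / eta), computed in log space for numerical stability.
    return torch.softmax(logp_old + u / eta, dim=0)


# A group of 4 completions: high-reward completions receive most of the target mass.
q = tpo_group_targets(
    old_logps=torch.tensor([-12.0, -30.0, -8.0, -15.0]),
    completion_lengths=torch.tensor([10.0, 40.0, 8.0, 12.0]),
    rewards=torch.tensor([1.0, 0.0, 1.0, 0.0]),
)
print(q)  # non-negative and sums to 1
```

Lower values of `tpo_target_temperature` concentrate the target on the highest-reward completions, while higher values flatten it back toward the rollout distribution.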

## Quick Start

```python
from datasets import load_dataset
from trl import TargetPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = TargetPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=accuracy_reward,
train_dataset=dataset,
)
trainer.train()
```

## Configuration

[`TargetPOConfig`] inherits the online rollout, reward, vLLM, tool-calling, and logging arguments from [`GRPOConfig`]. Because TargetPO uses a sequence-level softmax over every completion in a prompt group, [`TargetPOConfig`] defaults `steps_per_generation` to `1` when the user does not specify a generation schedule. Larger values are supported as long as each optimization step still contains whole prompt groups, i.e. `(generation_batch_size // steps_per_generation) % num_generations == 0`.
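
For instance, the following illustrative schedule satisfies that divisibility check on a single process; the relationship `generation_batch_size = per_device_train_batch_size * num_processes * steps_per_generation` (when `generation_batch_size` is not set explicitly) is the one inherited from [`GRPOConfig`]:

```python
from trl import TargetPOConfig

# Illustrative values on a single GPU: generation_batch_size is derived as
# 8 * 1 * 2 = 16, so each optimization step consumes 16 // 2 = 8 completions,
# i.e. two whole prompt groups of num_generations = 4.
training_args = TargetPOConfig(
    per_device_train_batch_size=8,
    num_generations=4,
    steps_per_generation=2,
)
```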

## TargetPOConfig

[[autodoc]] TargetPOConfig

## TargetPOTrainer

[[autodoc]] TargetPOTrainer
123 changes: 123 additions & 0 deletions examples/scripts/target_po.py
@@ -0,0 +1,123 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
# "trl[peft]",
# "math-verify",
# "latex2sympy2_extended",
# "trackio",
# "kernels",
# ]
# ///

"""
pip install math_verify

accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/target_po.py \
--model_name_or_path Qwen/Qwen3-0.6B \
--output_dir target_po-Qwen3-0.6B \
--learning_rate 1e-5 \
--dtype bfloat16 \
--max_completion_length 1024 \
--use_peft \
    --lora_target_modules "q_proj" "v_proj" \
--log_completions \
--per_device_train_batch_size 8 \
--num_generations 8 \
--beta 0.0 \
--tpo_target_temperature 1.0

"""

import torch
from datasets import load_dataset

from trl import (
ModelConfig,
ScriptArguments,
TargetPOConfig,
TargetPOTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl.rewards import accuracy_reward, think_format_reward


if __name__ == "__main__":
parser = TrlParser((ScriptArguments, TargetPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
################
# Model & Processor
################
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
training_args.model_init_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

################
# Dataset
################
train_dataset, eval_dataset = load_dataset("AI-MO/NuminaMath-TIR", split=["train[:5%]", "test[:5%]"])

SYSTEM_PROMPT = (
"A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
"assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
"The reasoning process and answer are enclosed within <think></think> tags, i.e., <think>\nThis is my "
"reasoning.\n</think>\nThis is my answer."
)

def make_conversation(example):
return {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["problem"]},
],
}

train_dataset = train_dataset.map(make_conversation)
eval_dataset = eval_dataset.map(make_conversation)

train_dataset = train_dataset.remove_columns(["messages", "problem"])
eval_dataset = eval_dataset.remove_columns(["messages", "problem"])

################
# Training
################
trainer = TargetPOTrainer(
model=model_args.model_name_or_path,
args=training_args,
reward_funcs=[think_format_reward, accuracy_reward],
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=get_peft_config(model_args),
)

trainer.train()

# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
78 changes: 78 additions & 0 deletions tests/test_target_po_distributed.py
@@ -0,0 +1,78 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from trl import TargetPOTrainer


WORLD_SIZE = 2
NUM_GENERATIONS = 4 # one group of 4 split across 2 ranks (2 per rank)
LOCAL_SEQ_LOGPS = [
torch.tensor([0.1, -0.3]), # rank 0
torch.tensor([0.5, -0.2]), # rank 1
]
LOCAL_TARGETS = [
torch.tensor([0.1, 0.2]), # rank 0
torch.tensor([0.4, 0.3]), # rank 1; global sums to 1.0
]


def _tpo_worker(rank: int, world_size: int, init_file: str) -> None:
dist.init_process_group(
backend="gloo",
init_method=f"file://{init_file}",
world_size=world_size,
rank=rank,
)
try:
local = LOCAL_SEQ_LOGPS[rank].clone().requires_grad_(True)

gathered = TargetPOTrainer._gather_tensor_with_grad(local)
logps = torch.log_softmax(gathered.view(-1, NUM_GENERATIONS), dim=1).view(-1)

process_slice = slice(rank * local.size(0), (rank + 1) * local.size(0))
local_logps = logps[process_slice]
local_targets = LOCAL_TARGETS[rank]

loss = -(local_targets * local_logps).sum() * NUM_GENERATIONS / local_targets.numel()
loss.backward()

global_logps = torch.cat(LOCAL_SEQ_LOGPS)
global_targets = torch.cat(LOCAL_TARGETS)
global_softmax = torch.softmax(global_logps, dim=0)
scale = NUM_GENERATIONS / local_targets.numel()
expected = scale * (global_softmax[process_slice] - global_targets[process_slice])

torch.testing.assert_close(local.grad, expected)
finally:
dist.destroy_process_group()


@pytest.mark.skipif(not torch.distributed.is_available(), reason="torch.distributed not available")
def test_tpo_gradient_across_ranks_with_group_spanning_ranks():
"""
A TPO prompt group of size 4 split 2/2 across DP ranks. The group's log-softmax normalizer depends on
all four completions, so the autograd-aware all_gather must route gradient from each rank's loss back
to the owning rank's local tensor. Expected local gradient is scale * (softmax - target) at local positions.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
init_file = os.path.join(tmp_dir, "rendezvous")
mp.spawn(_tpo_worker, args=(WORLD_SIZE, init_file), nprocs=WORLD_SIZE, join=True)