diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index b5a7442e026..c8eb147ba78 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -25,6 +25,8 @@
title: RLOO
- local: sft_trainer
title: SFT
+ - local: target_po_trainer
+ title: TargetPO
title: Trainers
- sections:
- local: clis
diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md
index f2214bbf75f..4ea34a38b11 100644
--- a/docs/source/dataset_formats.md
+++ b/docs/source/dataset_formats.md
@@ -411,6 +411,7 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
+| [`TargetPOTrainer`] | [Prompt-only](#prompt-only) |
| [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`experimental.cpo.CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) |
diff --git a/docs/source/index.md b/docs/source/index.md
index bbbdd9acde3..76497d2b15c 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -25,6 +25,7 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL
- [`GRPOTrainer`](grpo_trainer) ⚡️
- [`RLOOTrainer`](rloo_trainer) ⚡️
+- [`TargetPOTrainer`](target_po_trainer) ⚡️
- [`OnlineDPOTrainer`](online_dpo_trainer) 🧪 ⚡️
- [`NashMDTrainer`](nash_md_trainer) 🧪 ⚡️
- [`PPOTrainer`](ppo_trainer) 🧪
diff --git a/docs/source/liger_kernel_integration.md b/docs/source/liger_kernel_integration.md
index 7a387c813fd..0c688a4dac1 100644
--- a/docs/source/liger_kernel_integration.md
+++ b/docs/source/liger_kernel_integration.md
@@ -14,6 +14,7 @@ Liger Kernel is supported in the following TRL trainers:
- **SFT** (Supervised Fine-Tuning)
- **DPO** (Direct Preference Optimization)
- **GRPO** (Group Relative Policy Optimization)
+- **TargetPO** (Target Policy Optimization)
- **KTO** (Kahneman-Tversky Optimization)
- **GKD** (Generalized Knowledge Distillation)
@@ -54,6 +55,15 @@ from trl import GRPOConfig
training_args = GRPOConfig(..., use_liger_kernel=True)
```
+
+The same flag is available on [`TargetPOConfig`]:
+
+```python
+from trl import TargetPOConfig
+
+training_args = TargetPOConfig(..., use_liger_kernel=True)
+```
+
diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md
index 3fdccefff30..b573679346b 100644
--- a/docs/source/paper_index.md
+++ b/docs/source/paper_index.md
@@ -668,6 +668,30 @@ trainer.train()
The official code [sail-sg/Stable-RL](https://github.com/sail-sg/Stable-RL)
+### Target Policy Optimization
+
+**📜 Paper**: https://huggingface.co/papers/2604.06159
+
+Target Policy Optimization (TPO) builds a target distribution over each prompt's sampled completions using rollout
+policy probabilities and normalized rewards, then trains the policy to match that target with sequence-level
+cross-entropy. To use TPO in TRL, use [`TargetPOTrainer`] or set `loss_type="tpo"` in [`GRPOConfig`]. The trainer
+class is named `TargetPOTrainer` rather than `TPOTrainer` to avoid a collision with the experimental Triple
+Preference Optimization trainer, which shares the same acronym.
+
+```python
+from trl import GRPOConfig, GRPOTrainer
+
+training_args = GRPOConfig(
+ loss_type="tpo",
+ tpo_target_temperature=1.0,
+)
+
+trainer = GRPOTrainer(
+ ...,
+ args=training_args,
+)
+```
+
## Direct Preference Optimization
Papers relating to the [`DPOTrainer`]
diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md
index db78aded07e..609d56b36ae 100644
--- a/docs/source/reducing_memory_usage.md
+++ b/docs/source/reducing_memory_usage.md
@@ -159,6 +159,15 @@ from trl import GRPOConfig
training_args = GRPOConfig(..., use_liger_kernel=True)
```
+
+The same flag is available on [`TargetPOConfig`]:
+
+```python
+from trl import TargetPOConfig
+
+training_args = TargetPOConfig(..., use_liger_kernel=True)
+```
+
diff --git a/docs/source/speeding_up_training.md b/docs/source/speeding_up_training.md
index c855cc06233..b42401daf1d 100644
--- a/docs/source/speeding_up_training.md
+++ b/docs/source/speeding_up_training.md
@@ -168,6 +168,15 @@ from trl import GRPOConfig
training_args = GRPOConfig(..., use_liger_kernel=True)
```
+
+The same flag is available on [`TargetPOConfig`]:
+
+```python
+from trl import TargetPOConfig
+
+training_args = TargetPOConfig(..., use_liger_kernel=True)
+```
+
diff --git a/docs/source/target_po_trainer.md b/docs/source/target_po_trainer.md
new file mode 100644
index 00000000000..9bb7c3c3d3e
--- /dev/null
+++ b/docs/source/target_po_trainer.md
@@ -0,0 +1,47 @@
+# TargetPO Trainer
+
+## Overview
+
+[`TargetPOTrainer`] implements Target Policy Optimization (TPO), an online post-training algorithm from [Target Policy Optimization](https://huggingface.co/papers/2604.06159).
+
+[`TargetPOTrainer`] reuses the online rollout and reward pipeline of [`GRPOTrainer`], but fits the policy to a
+per-group target distribution with a sequence-level cross-entropy loss:
+
+$$
+q_i = \frac{p_i^{\text{old}} \exp(u_i / \eta)}{\sum_j p_j^{\text{old}} \exp(u_j / \eta)}
+$$
+
+Here \\(p_i^{\text{old}}\\) is a *length-normalized* proxy for the rollout policy probability of completion \\(i\\) in
+the prompt group (per-token mean log-probability by default, controlled by `tpo_length_normalize_logps`), \\(u_i\\) is
+the population-whitened group reward, and \\(\eta\\) is `tpo_target_temperature`. Length-normalization prevents the
+old-policy term from dominating the target when completions in a group have different lengths; set
+`tpo_length_normalize_logps=False` to recover the paper's literal sequence-probability formulation.
+
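+As an illustration of how the target is built, here is a small sketch that uses the trainer's `get_tpo_scores` and
+`get_tpo_targets` helpers on a single prompt group; the log-probabilities and rewards below are made-up toy values.
+
+```python
+import torch
+
+from trl import TargetPOTrainer
+
+# One prompt group of 4 completions: rollout sequence log-probs (length-normalized by default) and raw rewards
+old_seq_logps = torch.log(torch.tensor([0.4, 0.3, 0.2, 0.1]))
+rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])
+
+u = TargetPOTrainer.get_tpo_scores(rewards, num_generations=4)  # population-whitened rewards u_i
+q = TargetPOTrainer.get_tpo_targets(old_seq_logps, u, num_generations=4, temperature=1.0)
+print(q.sum())  # the target distribution sums to 1 within the group
+```
+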
+## Quick Start
+
+```python
+from datasets import load_dataset
+from trl import TargetPOTrainer
+from trl.rewards import accuracy_reward
+
+dataset = load_dataset("trl-lib/DeepMath-103K", split="train")
+
+trainer = TargetPOTrainer(
+ model="Qwen/Qwen2-0.5B-Instruct",
+ reward_funcs=accuracy_reward,
+ train_dataset=dataset,
+)
+trainer.train()
+```
+
+## Configuration
+
+[`TargetPOConfig`] inherits the online rollout, reward, vLLM, tool-calling, and logging arguments from [`GRPOConfig`]. Because TargetPO uses a sequence-level softmax over every completion in a prompt group, [`TargetPOConfig`] defaults `steps_per_generation` to `1` when the user does not specify a generation schedule. Larger values are supported as long as each optimization step still contains whole prompt groups, i.e. `(generation_batch_size // steps_per_generation) % num_generations == 0`.
+
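+For example, here is a minimal sketch of a multi-step generation schedule that keeps whole prompt groups in every
+optimization step (the values below are illustrative):
+
+```python
+from trl import TargetPOConfig
+
+# On a single process: generation_batch_size = 4 * 2 = 8, per-step batch = 8 // 2 = 4, divisible by num_generations
+training_args = TargetPOConfig(
+ output_dir="TargetPO-demo",
+ per_device_train_batch_size=4,
+ num_generations=2,
+ steps_per_generation=2,
+)
+```
+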
+## TargetPOConfig
+
+[[autodoc]] TargetPOConfig
+
+## TargetPOTrainer
+
+[[autodoc]] TargetPOTrainer
diff --git a/examples/scripts/target_po.py b/examples/scripts/target_po.py
new file mode 100644
index 00000000000..f0566efff3e
--- /dev/null
+++ b/examples/scripts/target_po.py
@@ -0,0 +1,123 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# /// script
+# dependencies = [
+# "trl[peft]",
+# "math-verify",
+# "latex2sympy2_extended",
+# "trackio",
+# "kernels",
+# ]
+# ///
+
+"""
+pip install math_verify
+
+accelerate launch \
+ --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
+ examples/scripts/target_po.py \
+ --model_name_or_path Qwen/Qwen3-0.6B \
+ --output_dir target_po-Qwen3-0.6B \
+ --learning_rate 1e-5 \
+ --dtype bfloat16 \
+ --max_completion_length 1024 \
+ --use_peft \
+ --lora_target_modules "q_proj" "v_proj" \
+ --log_completions \
+ --per_device_train_batch_size 8 \
+ --num_generations 8 \
+ --beta 0.0 \
+ --tpo_target_temperature 1.0
+
+"""
+
+import torch
+from datasets import load_dataset
+
+from trl import (
+ ModelConfig,
+ ScriptArguments,
+ TargetPOConfig,
+ TargetPOTrainer,
+ TrlParser,
+ get_kbit_device_map,
+ get_peft_config,
+ get_quantization_config,
+)
+from trl.rewards import accuracy_reward, think_format_reward
+
+
+if __name__ == "__main__":
+ parser = TrlParser((ScriptArguments, TargetPOConfig, ModelConfig))
+ script_args, training_args, model_args = parser.parse_args_and_config()
+ ################
+ # Model & Processor
+ ################
+ dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+ training_args.model_init_kwargs = dict(
+ revision=model_args.model_revision,
+ attn_implementation=model_args.attn_implementation,
+ dtype=dtype,
+ )
+ quantization_config = get_quantization_config(model_args)
+ if quantization_config is not None:
+ # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+ training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
+ training_args.model_init_kwargs["quantization_config"] = quantization_config
+
+ ################
+ # Dataset
+ ################
+ train_dataset, eval_dataset = load_dataset("AI-MO/NuminaMath-TIR", split=["train[:5%]", "test[:5%]"])
+
+ SYSTEM_PROMPT = (
+ "A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
+ "assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
+ "The reasoning process and answer are enclosed within tags, i.e., \nThis is my "
+ "reasoning.\n\nThis is my answer."
+ )
+
+ def make_conversation(example):
+ return {
+ "prompt": [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": example["problem"]},
+ ],
+ }
+
+ train_dataset = train_dataset.map(make_conversation)
+ eval_dataset = eval_dataset.map(make_conversation)
+
+ train_dataset = train_dataset.remove_columns(["messages", "problem"])
+ eval_dataset = eval_dataset.remove_columns(["messages", "problem"])
+
+ ################
+ # Training
+ ################
+ trainer = TargetPOTrainer(
+ model=model_args.model_name_or_path,
+ args=training_args,
+ reward_funcs=[think_format_reward, accuracy_reward],
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ peft_config=get_peft_config(model_args),
+ )
+
+ trainer.train()
+
+ # Save and push to hub
+ trainer.save_model(training_args.output_dir)
+ if training_args.push_to_hub:
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
diff --git a/tests/test_target_po_distributed.py b/tests/test_target_po_distributed.py
new file mode 100644
index 00000000000..f21c8389c84
--- /dev/null
+++ b/tests/test_target_po_distributed.py
@@ -0,0 +1,78 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from trl import TargetPOTrainer
+
+
+WORLD_SIZE = 2
+NUM_GENERATIONS = 4 # one group of 4 split across 2 ranks (2 per rank)
+LOCAL_SEQ_LOGPS = [
+ torch.tensor([0.1, -0.3]), # rank 0
+ torch.tensor([0.5, -0.2]), # rank 1
+]
+LOCAL_TARGETS = [
+ torch.tensor([0.1, 0.2]), # rank 0
+ torch.tensor([0.4, 0.3]), # rank 1; global sums to 1.0
+]
+
+
+def _tpo_worker(rank: int, world_size: int, init_file: str) -> None:
+ dist.init_process_group(
+ backend="gloo",
+ init_method=f"file://{init_file}",
+ world_size=world_size,
+ rank=rank,
+ )
+ try:
+ local = LOCAL_SEQ_LOGPS[rank].clone().requires_grad_(True)
+
+ gathered = TargetPOTrainer._gather_tensor_with_grad(local)
+ logps = torch.log_softmax(gathered.view(-1, NUM_GENERATIONS), dim=1).view(-1)
+
+ process_slice = slice(rank * local.size(0), (rank + 1) * local.size(0))
+ local_logps = logps[process_slice]
+ local_targets = LOCAL_TARGETS[rank]
+
+ loss = -(local_targets * local_logps).sum() * NUM_GENERATIONS / local_targets.numel()
+ loss.backward()
+
+ global_logps = torch.cat(LOCAL_SEQ_LOGPS)
+ global_targets = torch.cat(LOCAL_TARGETS)
+ global_softmax = torch.softmax(global_logps, dim=0)
+ scale = NUM_GENERATIONS / local_targets.numel()
+ expected = scale * (global_softmax[process_slice] - global_targets[process_slice])
+
+ torch.testing.assert_close(local.grad, expected)
+ finally:
+ dist.destroy_process_group()
+
+
+@pytest.mark.skipif(not torch.distributed.is_available(), reason="torch.distributed not available")
+def test_tpo_gradient_across_ranks_with_group_spanning_ranks():
+ """
+ A TPO prompt group of size 4 split 2/2 across DP ranks. The group's log-softmax normalizer depends on
+ all four completions, so the autograd-aware all_gather must route gradient from each rank's loss back
+ to the owning rank's local tensor. Expected local gradient is scale * (softmax - target) at local positions.
+ """
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ init_file = os.path.join(tmp_dir, "rendezvous")
+ mp.spawn(_tpo_worker, args=(WORLD_SIZE, init_file), nprocs=WORLD_SIZE, join=True)
diff --git a/tests/test_target_po_trainer.py b/tests/test_target_po_trainer.py
new file mode 100644
index 00000000000..275cf0c8caa
--- /dev/null
+++ b/tests/test_target_po_trainer.py
@@ -0,0 +1,437 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+from types import SimpleNamespace
+from unittest.mock import patch
+
+import pytest
+import torch
+from accelerate.utils.memory import release_memory
+from datasets import Dataset, load_dataset
+from transformers import TrainingArguments
+
+from trl import GRPOTrainer, TargetPOConfig, TargetPOTrainer
+
+from .testing_utils import TrlTestCase, require_liger_kernel
+
+
+class TestTargetPOConfig(TrlTestCase):
+ def test_defaults_to_one_step_per_generation(self):
+ config = TargetPOConfig(output_dir=self.tmp_dir, gradient_accumulation_steps=4)
+
+ assert config.loss_type == "tpo"
+ assert config.steps_per_generation == 1
+
+ def test_allows_multi_step_generation_when_groups_are_whole(self):
+ config = TargetPOConfig(
+ output_dir=self.tmp_dir,
+ per_device_train_batch_size=4,
+ num_generations=2,
+ steps_per_generation=2,
+ )
+
+ assert config.steps_per_generation == 2
+ per_step_batch = config.generation_batch_size // config.steps_per_generation
+ assert per_step_batch % config.num_generations == 0
+
+ def test_rejects_multi_step_generation_when_groups_are_cleaved(self):
+ with pytest.raises(ValueError, match="whole prompt groups"):
+ TargetPOConfig(
+ output_dir=self.tmp_dir,
+ per_device_train_batch_size=2,
+ num_generations=4,
+ steps_per_generation=2,
+ )
+
+ def test_rejects_distributed_multi_step_generation_when_local_step_splits_groups(self):
+ with patch.object(TrainingArguments, "world_size", new=property(lambda self: 2)):
+ with pytest.raises(ValueError, match="per_device_train_batch_size"):
+ TargetPOConfig(
+ output_dir=self.tmp_dir,
+ per_device_train_batch_size=3,
+ num_generations=2,
+ steps_per_generation=2,
+ )
+
+ def test_allows_distributed_single_step_generation_with_groups_spanning_ranks(self):
+ with patch.object(TrainingArguments, "world_size", new=property(lambda self: 2)):
+ config = TargetPOConfig(
+ output_dir=self.tmp_dir,
+ per_device_train_batch_size=3,
+ num_generations=2,
+ steps_per_generation=1,
+ )
+
+ assert config.generation_batch_size == 6
+ assert config.steps_per_generation == 1
+
+ def test_trainer_metadata(self):
+ assert TargetPOTrainer._name == "TargetPO"
+ assert TargetPOTrainer._tag_names == ["trl", "tpo"]
+
+ def test_allows_liger_kernel(self):
+ config = TargetPOConfig(output_dir=self.tmp_dir, use_liger_kernel=True)
+
+ assert config.use_liger_kernel
+
+
+class TestTargetPOTrainer(TrlTestCase):
+ def test_training(self):
+ dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+ training_args = TargetPOConfig(
+ output_dir=self.tmp_dir,
+ learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates
+ per_device_train_batch_size=3,
+ num_generations=3,
+ max_completion_length=8,
+ report_to="none",
+ )
+ trainer = TargetPOTrainer(
+ model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=training_args,
+ train_dataset=dataset,
+ )
+
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+ trainer.train()
+
+ assert trainer.state.log_history[-1]["train_loss"] is not None
+
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
+ @require_liger_kernel
+ def test_training_with_liger_kernel(self):
+ from trl.trainer.target_po_trainer import LigerFusedLinearTargetPOLoss
+
+ def reward_func(completions, **kwargs):
+ return [float((len(completion) % 3) - 1) for completion in completions]
+
+ dataset = Dataset.from_dict(
+ {
+ "prompt": [
+ "Say hello.",
+ "Name a color.",
+ "What is 1+1?",
+ "Write ok.",
+ "Pick a letter.",
+ "Say bye.",
+ ]
+ }
+ )
+ training_args = TargetPOConfig(
+ output_dir=self.tmp_dir,
+ per_device_train_batch_size=3,
+ num_generations=3,
+ max_completion_length=8,
+ max_steps=1,
+ use_liger_kernel=True,
+ report_to="none",
+ logging_strategy="no",
+ save_strategy="no",
+ )
+ trainer = TargetPOTrainer(
+ model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+ reward_funcs=reward_func,
+ args=training_args,
+ train_dataset=dataset,
+ )
+
+ assert isinstance(trainer.liger_tpo_loss, LigerFusedLinearTargetPOLoss)
+
+ train_result = trainer.train()
+
+ assert torch.isfinite(torch.tensor(train_result.training_loss))
+
+ release_memory(trainer.model, trainer)
+
+
+class TestTPOLoss:
+ def test_tpo_scores_match_population_whitened_skill(self):
+ scores = torch.tensor([0.0, 1.0, 0.0, 0.0])
+
+ tpo_scores = TargetPOTrainer.get_tpo_scores(scores, num_generations=2)
+
+ expected = torch.tensor([-1.0, 1.0, 0.0, 0.0])
+ torch.testing.assert_close(tpo_scores, expected)
+
+ def test_tpo_scores_exclude_invalid_completions(self):
+ scores = torch.tensor([0.0, 100.0, 1.0])
+ valid_mask = torch.tensor([True, False, True])
+
+ tpo_scores = TargetPOTrainer.get_tpo_scores(scores, num_generations=3, valid_mask=valid_mask)
+
+ expected = torch.tensor([-1.0, 0.0, 1.0])
+ torch.testing.assert_close(tpo_scores, expected)
+
+ def test_tpo_targets_use_population_whitened_scores(self):
+ old_sequence_logps = torch.zeros(2)
+ scores = torch.tensor([0.0, 1.0])
+ tpo_scores = TargetPOTrainer.get_tpo_scores(scores, num_generations=2)
+
+ targets = TargetPOTrainer.get_tpo_targets(old_sequence_logps, tpo_scores, num_generations=2)
+
+ expected = torch.softmax(torch.tensor([-1.0, 1.0]), dim=0)
+ torch.testing.assert_close(targets, expected)
+
+ def test_tpo_targets_match_closed_form(self):
+ old_sequence_logps = torch.log(torch.tensor([0.7, 0.3, 0.2, 0.8]))
+ scores = torch.tensor([1.0, -1.0, 0.0, 0.0])
+
+ targets = TargetPOTrainer.get_tpo_targets(old_sequence_logps, scores, num_generations=2)
+
+ expected_first_group = torch.softmax(torch.log(torch.tensor([0.7, 0.3])) + torch.tensor([1.0, -1.0]), dim=0)
+ expected_second_group = torch.tensor([0.2, 0.8])
+ expected = torch.cat([expected_first_group, expected_second_group])
+ torch.testing.assert_close(targets, expected)
+
+ def test_tpo_targets_exclude_invalid_completions(self):
+ old_sequence_logps = torch.log(torch.tensor([0.7, 0.3, 0.2, 0.8]))
+ scores = torch.tensor([1.0, -1.0, 0.0, 0.0])
+ valid_mask = torch.tensor([True, False, True, True])
+
+ targets = TargetPOTrainer.get_tpo_targets(old_sequence_logps, scores, num_generations=2, valid_mask=valid_mask)
+
+ expected = torch.tensor([1.0, 0.0, 0.2, 0.8])
+ torch.testing.assert_close(targets, expected)
+
+ @pytest.mark.parametrize("trainer_cls", [GRPOTrainer, TargetPOTrainer])
+ def test_tpo_kl_uses_per_step_token_normalizer(self, trainer_cls):
+ trainer = object.__new__(trainer_cls)
+ trainer.loss_type = "tpo"
+ trainer.off_policy_mask_threshold = None
+ trainer.top_entropy_quantile = 1.0
+ trainer.beta = 0.25
+ trainer.num_generations = 2
+ trainer.num_generations_eval = 2
+ trainer.current_gradient_accumulation_steps = 2
+ trainer.tpo_length_normalize_logps = False
+ trainer.args = SimpleNamespace(use_bias_correction_kl=False)
+ trainer.model = SimpleNamespace(training=True)
+ trainer.accelerator = SimpleNamespace(
+ gather=lambda tensor: tensor,
+ num_processes=1,
+ process_index=0,
+ )
+ trainer._metrics = {"train": defaultdict(list)}
+
+ per_token_logps = torch.zeros((2, 3))
+
+ def fake_get_per_token_logps_and_entropies(*args, **kwargs):
+ return per_token_logps, torch.zeros_like(per_token_logps)
+
+ trainer._get_per_token_logps_and_entropies = fake_get_per_token_logps_and_entropies
+
+ completion_mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
+ ref_log_ratio = torch.log(torch.tensor(2.0))
+ ref_per_token_logps = per_token_logps + ref_log_ratio
+ inputs = {
+ "prompt_ids": torch.zeros((2, 1), dtype=torch.long),
+ "prompt_mask": torch.ones((2, 1), dtype=torch.long),
+ "completion_ids": torch.ones((2, 3), dtype=torch.long),
+ "completion_mask": completion_mask,
+ "advantages": torch.zeros(2),
+ "old_per_token_logps": per_token_logps,
+ "ref_per_token_logps": ref_per_token_logps,
+ "tpo_targets": torch.tensor([0.5, 0.5]),
+ "num_items_in_batch": torch.tensor(6),
+ }
+
+ loss = trainer_cls._compute_loss(trainer, model=None, inputs=inputs)
+
+ normalizer = torch.tensor(float(trainer.current_gradient_accumulation_steps))
+ tpo_loss = ref_log_ratio / normalizer
+ expected_kl_loss = torch.exp(ref_log_ratio) - ref_log_ratio - 1
+ expected_loss = tpo_loss + trainer.beta * expected_kl_loss / normalizer
+ torch.testing.assert_close(loss, expected_loss)
+ torch.testing.assert_close(torch.tensor(trainer._metrics["train"]["kl"][0]), expected_kl_loss)
+
+ @pytest.mark.parametrize("length_normalize", [True, False])
+ def test_tpo_loss_gradient_matches_policy_minus_target(self, length_normalize):
+ trainer = object.__new__(TargetPOTrainer)
+ trainer.loss_type = "tpo"
+ trainer.off_policy_mask_threshold = None
+ trainer.top_entropy_quantile = 1.0
+ trainer.beta = 0.0
+ trainer.num_generations = 2
+ trainer.num_generations_eval = 2
+ trainer.current_gradient_accumulation_steps = 1
+ trainer.tpo_length_normalize_logps = length_normalize
+ trainer.model = SimpleNamespace(training=True)
+ trainer.accelerator = SimpleNamespace(
+ gather=lambda tensor: tensor,
+ num_processes=1,
+ process_index=0,
+ )
+ trainer._metrics = {"train": defaultdict(list)}
+
+ per_token_logps = torch.tensor([[-0.1, -0.2], [-0.7, -0.3]], requires_grad=True)
+
+ def fake_get_per_token_logps_and_entropies(*args, **kwargs):
+ return per_token_logps, torch.zeros_like(per_token_logps)
+
+ trainer._get_per_token_logps_and_entropies = fake_get_per_token_logps_and_entropies
+
+ tpo_targets = torch.tensor([0.25, 0.75])
+ inputs = {
+ "prompt_ids": torch.zeros((2, 1), dtype=torch.long),
+ "prompt_mask": torch.ones((2, 1), dtype=torch.long),
+ "completion_ids": torch.ones((2, 2), dtype=torch.long),
+ "completion_mask": torch.ones((2, 2), dtype=torch.long),
+ "advantages": torch.zeros(2),
+ "old_per_token_logps": per_token_logps.detach(),
+ "tpo_targets": tpo_targets,
+ "num_items_in_batch": torch.tensor(4),
+ }
+
+ loss = TargetPOTrainer._compute_loss(trainer, model=None, inputs=inputs)
+ loss.backward()
+
+ completion_mask = inputs["completion_mask"].to(per_token_logps.dtype)
+ lengths = completion_mask.sum(dim=-1).clamp(min=1)
+ summed = (per_token_logps.detach() * completion_mask).sum(dim=-1)
+ sequence_logps = summed / lengths if length_normalize else summed
+ expected_sequence_grad = torch.softmax(sequence_logps, dim=0) - tpo_targets
+ per_token_scale = (1.0 / lengths) if length_normalize else torch.ones_like(lengths)
+ expected_grad = (expected_sequence_grad * per_token_scale).unsqueeze(1) * completion_mask
+ torch.testing.assert_close(per_token_logps.grad, expected_grad)
+
+ def test_tpo_loss_excludes_invalid_completions_from_group_softmax(self):
+ trainer = object.__new__(TargetPOTrainer)
+ trainer.loss_type = "tpo"
+ trainer.off_policy_mask_threshold = None
+ trainer.top_entropy_quantile = 1.0
+ trainer.beta = 0.0
+ trainer.num_generations = 3
+ trainer.num_generations_eval = 3
+ trainer.current_gradient_accumulation_steps = 1
+ trainer.tpo_length_normalize_logps = True
+ trainer.model = SimpleNamespace(training=True)
+ trainer.accelerator = SimpleNamespace(
+ gather=lambda tensor: tensor,
+ num_processes=1,
+ process_index=0,
+ )
+ trainer._metrics = {"train": defaultdict(list)}
+
+ per_token_logps = torch.tensor([[-0.1, -0.2], [5.0, 5.0], [-0.4, -0.6]], requires_grad=True)
+
+ def fake_get_per_token_logps_and_entropies(*args, **kwargs):
+ return per_token_logps, torch.zeros_like(per_token_logps)
+
+ trainer._get_per_token_logps_and_entropies = fake_get_per_token_logps_and_entropies
+
+ tpo_targets = torch.tensor([0.6, 0.0, 0.4])
+ tpo_valid_mask = torch.tensor([True, False, True])
+ inputs = {
+ "prompt_ids": torch.zeros((3, 1), dtype=torch.long),
+ "prompt_mask": torch.ones((3, 1), dtype=torch.long),
+ "completion_ids": torch.ones((3, 2), dtype=torch.long),
+ "completion_mask": torch.tensor([[1, 1], [0, 0], [1, 1]]),
+ "advantages": torch.zeros(3),
+ "old_per_token_logps": per_token_logps.detach(),
+ "tpo_targets": tpo_targets,
+ "tpo_valid_mask": tpo_valid_mask,
+ "num_items_in_batch": torch.tensor(4),
+ }
+
+ loss = TargetPOTrainer._compute_loss(trainer, model=None, inputs=inputs)
+ loss.backward()
+
+ completion_mask = inputs["completion_mask"].to(per_token_logps.dtype)
+ lengths = completion_mask.sum(dim=-1).clamp(min=1)
+ sequence_logps = (per_token_logps.detach() * completion_mask).sum(dim=-1) / lengths
+ valid_sequence_logps = sequence_logps[tpo_valid_mask]
+ expected_valid_grad = torch.softmax(valid_sequence_logps, dim=0) - tpo_targets[tpo_valid_mask]
+ expected_grad = torch.zeros_like(per_token_logps)
+ expected_grad[0] = (expected_valid_grad[0] / lengths[0]) * completion_mask[0]
+ expected_grad[2] = (expected_valid_grad[1] / lengths[2]) * completion_mask[2]
+ torch.testing.assert_close(per_token_logps.grad, expected_grad)
+
+ @pytest.mark.parametrize("chunk_size", [1, 2])
+ @pytest.mark.parametrize("length_normalize", [True, False])
+ @pytest.mark.parametrize("use_bias", [True, False])
+ @require_liger_kernel
+ def test_liger_tpo_loss_matches_dense_loss(self, chunk_size, length_normalize, use_bias):
+ from trl.trainer.target_po_trainer import LigerFusedLinearTargetPOLoss
+
+ torch.manual_seed(0)
+ hidden_states = torch.randn(3, 2, 4, requires_grad=True)
+ weight = torch.randn(5, 4, requires_grad=True)
+ bias = torch.randn(5, requires_grad=True) if use_bias else None
+ completion_ids = torch.tensor([[0, 1], [2, 3], [4, 0]])
+ completion_mask = torch.tensor([[1, 1], [0, 0], [1, 1]], dtype=torch.float32)
+ tpo_targets = torch.tensor([0.6, 0.0, 0.4])
+ tpo_valid_mask = torch.tensor([True, False, True])
+ ref_per_token_logps = torch.tensor([[-0.2, -0.4], [-0.1, -0.3], [-0.5, -0.7]])
+ old_per_token_logps = torch.tensor([[-0.3, -0.5], [-0.2, -0.4], [-0.6, -0.8]])
+ beta = 0.25
+ normalizer = 2.0
+
+ dense_hidden_states = hidden_states.detach().clone().requires_grad_(True)
+ dense_weight = weight.detach().clone().requires_grad_(True)
+ dense_bias = bias.detach().clone().requires_grad_(True) if use_bias else None
+
+ liger_loss = LigerFusedLinearTargetPOLoss(
+ beta=beta,
+ num_generations=3,
+ tpo_length_normalize_logps=length_normalize,
+ temperature=1.0,
+ chunk_size=chunk_size,
+ )
+ loss, mean_kl, mean_entropy = liger_loss(
+ hidden_states,
+ weight,
+ completion_ids,
+ completion_mask,
+ tpo_targets,
+ tpo_valid_mask,
+ bias=bias,
+ ref_per_token_logps=ref_per_token_logps,
+ old_per_token_logps=old_per_token_logps,
+ normalizer=normalizer,
+ )
+ loss.backward()
+
+ logits = dense_hidden_states @ dense_weight.t()
+ if dense_bias is not None:
+ logits = logits + dense_bias
+ log_probs = torch.log_softmax(logits.float(), dim=-1)
+ per_token_logps = log_probs.gather(dim=-1, index=completion_ids.unsqueeze(-1)).squeeze(-1)
+ sequence_logps = (per_token_logps * completion_mask).sum(dim=-1)
+ if length_normalize:
+ sequence_logps = sequence_logps / completion_mask.sum(dim=-1).clamp(min=1.0)
+ sequence_logps = sequence_logps.masked_fill(~tpo_valid_mask, torch.finfo(sequence_logps.dtype).min)
+ logps = torch.log_softmax(sequence_logps.view(-1, 3), dim=1).view(-1)
+ dense_loss = -(tpo_targets * logps).sum()
+ per_token_kl = (
+ torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+ )
+ dense_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+ dense_loss = dense_loss / normalizer + beta * dense_kl / normalizer
+ dense_entropy = (-(log_probs.exp() * log_probs).sum(dim=-1) * completion_mask).sum() / completion_mask.sum()
+ dense_loss.backward()
+
+ torch.testing.assert_close(loss, dense_loss.detach())
+ torch.testing.assert_close(mean_kl, dense_kl.detach())
+ torch.testing.assert_close(mean_entropy, dense_entropy.detach())
+ torch.testing.assert_close(hidden_states.grad, dense_hidden_states.grad)
+ torch.testing.assert_close(weight.grad, dense_weight.grad)
+ if use_bias:
+ torch.testing.assert_close(bias.grad, dense_bias.grad)
diff --git a/trl/__init__.py b/trl/__init__.py
index 4947fc16c56..de60a700288 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -66,6 +66,8 @@
"SFTConfig",
"SFTTrainer",
"SyncRefModelCallback",
+ "TargetPOConfig",
+ "TargetPOTrainer",
"WeaveCallback",
"get_kbit_device_map",
"get_peft_config",
@@ -114,6 +116,8 @@
SFTConfig,
SFTTrainer,
SyncRefModelCallback,
+ TargetPOConfig,
+ TargetPOTrainer,
WeaveCallback,
get_kbit_device_map,
get_peft_config,
diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
index 1d79d365056..ecf59e35f00 100644
--- a/trl/trainer/__init__.py
+++ b/trl/trainer/__init__.py
@@ -38,6 +38,8 @@
"rloo_trainer": ["RLOOTrainer"],
"sft_config": ["SFTConfig"],
"sft_trainer": ["SFTTrainer"],
+ "target_po_config": ["TargetPOConfig"],
+ "target_po_trainer": ["TargetPOTrainer"],
"utils": [
"disable_dropout_in_model",
"ensure_master_addr_port",
@@ -69,6 +71,8 @@
from .rloo_trainer import RLOOTrainer
from .sft_config import SFTConfig
from .sft_trainer import SFTTrainer
+ from .target_po_config import TargetPOConfig
+ from .target_po_trainer import TargetPOTrainer
from .utils import (
disable_dropout_in_model,
ensure_master_addr_port,
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index e5f373c19f0..1406a6ede7d 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -256,6 +256,17 @@ class GRPOConfig(_BaseConfig):
- `"vespo"`: Variational Sequence-Level Soft Policy Optimization. Replaces hard clipping with a smooth,
asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in
the [VESPO paper](https://huggingface.co/papers/2602.10693).
+ - `"tpo"`: Target Policy Optimization loss. Builds a target distribution over each prompt's sampled
+ completions from the rollout policy probabilities and normalized rewards, then fits the current policy
+ to that target with cross-entropy.
+ tpo_target_temperature (`float`, *optional*, defaults to `1.0`):
+ Temperature used to build the Target Policy Optimization target distribution when `loss_type="tpo"`.
+ Lower values make the target more concentrated on high-scoring completions.
+ tpo_length_normalize_logps (`bool`, *optional*, defaults to `True`):
+ Whether to length-normalize sequence log-probabilities (per-token mean instead of sum) when building
+ the TPO target and computing the TPO loss. Recommended for realistic, variable-length generations, because
+ raw summed sequence log-probabilities vary over orders of magnitude within a group and let the old-policy
+ term dominate the target. Set to `False` to reproduce the paper's literal `p_i^old` formulation.
mask_truncated_completions (`bool`, *optional*, defaults to `False`):
When enabled, truncated completions are excluded from the loss calculation, preventing them from being
incorrectly penalized and introducing noise during training. According to the
@@ -738,7 +749,28 @@ class GRPOConfig(_BaseConfig):
"paper](https://huggingface.co/papers/2602.05261)."
"'vespo': Variational Sequence-Level Soft Policy Optimization. Replaces hard clipping with a smooth, "
"asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in "
- "the [VESPO paper](https://huggingface.co/papers/2602.10693)."
+ "the [VESPO paper](https://huggingface.co/papers/2602.10693). "
+ "'tpo': Target Policy Optimization loss. Builds a target distribution over each prompt's sampled "
+ "completions from rollout policy probabilities and normalized rewards, then fits the current policy to "
+ "that target with cross-entropy."
+ },
+ )
+ tpo_target_temperature: float = field(
+ default=1.0,
+ metadata={
+ "help": "Temperature used to build the Target Policy Optimization target distribution when "
+ "`loss_type='tpo'`. Lower values make the target more concentrated on high-scoring completions."
+ },
+ )
+ tpo_length_normalize_logps: bool = field(
+ default=True,
+ metadata={
+ "help": "Whether to length-normalize sequence log-probabilities (use per-token mean instead of sum) "
+ "when building the TPO target distribution and computing the TPO loss. Without this, sequences of "
+ "different lengths have log-probabilities spanning orders of magnitude, causing the old-policy term "
+ "in `q_i ∝ p_i^old * exp(u_i / eta)` to dominate and collapse the target to ~one-hot on the "
+ "highest-old-logp completion. Set to `False` to reproduce the paper's literal sequence-probability "
+ "formulation."
},
)
mask_truncated_completions: bool = field(
@@ -943,3 +975,34 @@ def __post_init__(self):
if self.delta is not None and self.use_liger_kernel:
raise ValueError("Liger kernel does not support two-sided GRPO loss yet.")
+
+ if self.tpo_target_temperature <= 0.0:
+ raise ValueError(
+ f"tpo_target_temperature must be greater than 0.0. You provided {self.tpo_target_temperature}."
+ )
+
+ if self.loss_type == "tpo":
+ if self.use_liger_kernel:
+ raise ValueError("Liger kernel does not support the TPO loss yet.")
+ # TPO's target distribution is normalized per prompt group, so each optimization step must contain a whole
+ # number of groups. That holds whenever the per-step batch (generation_batch_size // steps_per_generation)
+ # is divisible by num_generations.
+ per_step_batch = self.generation_batch_size // self.steps_per_generation
+ if per_step_batch % self.num_generations != 0:
+ raise ValueError(
+ f"TPO requires each optimization step to contain whole prompt groups. With "
+ f"steps_per_generation={self.steps_per_generation} and num_generations={self.num_generations}, "
+ f"the per-step batch of {per_step_batch} is not divisible by num_generations. Increase "
+ f"per_device_train_batch_size or reduce steps_per_generation."
+ )
+ if (
+ num_processes > 1
+ and self.steps_per_generation > 1
+ and self.per_device_train_batch_size % self.num_generations != 0
+ ):
+ raise ValueError(
+ f"TPO with distributed multi-step generation requires each rank's per-step batch to contain "
+ f"whole prompt groups. With per_device_train_batch_size={self.per_device_train_batch_size} and "
+ f"num_generations={self.num_generations}, the per-rank batch is not divisible by "
+ f"num_generations. Increase per_device_train_batch_size or reduce steps_per_generation."
+ )
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 0f400e14fcd..a0579a1824d 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -41,6 +41,7 @@
from packaging.version import Version
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.nn.functional import all_gather as _all_gather_with_grad
from torch.utils.data import Sampler
from transformers import (
AutoModelForSequenceClassification,
@@ -557,6 +558,8 @@ def __init__(
self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap
self.use_liger_kernel = args.use_liger_kernel
self.loss_type = args.loss_type
+ self.tpo_target_temperature = args.tpo_target_temperature
+ self.tpo_length_normalize_logps = args.tpo_length_normalize_logps
self.multi_objective_aggregation = args.multi_objective_aggregation
self.scale_rewards = args.scale_rewards
self.importance_sampling_level = args.importance_sampling_level
@@ -1138,7 +1141,8 @@ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> di
# self._buffered_inputs=None can occur when resuming from a checkpoint
generation_batch = self._generate_and_score_completions(generation_batch)
generation_batch = split_pixel_values_by_grid(generation_batch)
- generation_batch = shuffle_sequence_dict(generation_batch)
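+        # TPO requires completions from the same prompt to stay contiguous (the target is a per-group
+        # distribution over consecutive blocks of num_generations), so don't shuffle the generation batch.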
+ if self.loss_type != "tpo":
+ generation_batch = shuffle_sequence_dict(generation_batch)
generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation)
self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches]
inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
@@ -2026,8 +2030,10 @@ def _generate_and_score_completions(
# When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the
# distribution mismatch between vLLM and the training model can be large and harm the training.
generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency
- if self.args.gradient_accumulation_steps % generate_every != 0 or (
- self.use_vllm and self.vllm_importance_sampling_correction
+ if (
+ self.loss_type == "tpo"
+ or self.args.gradient_accumulation_steps % generate_every != 0
+ or (self.use_vllm and self.vllm_importance_sampling_correction)
):
old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
self.model,
@@ -2172,6 +2178,34 @@ def _generate_and_score_completions(
(self.accelerator.process_index + 1) * len(prompts),
)
all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
+ tpo_targets = None
+ tpo_valid_mask = None
+ if self.loss_type == "tpo":
+ if old_per_token_logps is None:
+ raise RuntimeError("TPO requires rollout-time log probabilities to build the target distribution.")
+ loss_mask = completion_mask if tool_mask is None else completion_mask * tool_mask
+ tpo_valid_mask = loss_mask.any(dim=-1)
+ # Length-normalize sequence logps so the target distribution isn't dominated by length variance
+ # (without this, log_softmax(old_sequence_logps) collapses to ~one-hot on the longest/most-likely
+ # completion, and rewards can't compete).
+ old_sequence_logps = (old_per_token_logps * loss_mask).sum(dim=-1)
+ if self.tpo_length_normalize_logps:
+ old_sequence_logps = old_sequence_logps / loss_mask.sum(dim=-1).clamp(min=1)
+ all_process_old_sequence_logps = gather(old_sequence_logps)
+ all_process_tpo_valid_mask = gather(tpo_valid_mask)
+ tpo_scores = self.get_tpo_scores(rewards, num_generations, valid_mask=all_process_tpo_valid_mask)
+ all_process_tpo_targets = self.get_tpo_targets(
+ all_process_old_sequence_logps,
+ tpo_scores,
+ num_generations=num_generations,
+ temperature=self.tpo_target_temperature,
+ valid_mask=all_process_tpo_valid_mask,
+ )
+ tpo_targets = all_process_tpo_targets[process_slice]
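+            # Log the per-group entropy of the target distribution: a low value means the target has collapsed
+            # to ~one-hot (e.g. temperature too low or the old-policy term dominating).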
+ target_groups = all_process_tpo_targets.view(-1, num_generations)
+ target_entropy = -(target_groups * target_groups.clamp_min(torch.finfo(target_groups.dtype).tiny).log())
+ target_entropy = target_entropy.sum(dim=1).mean()
+ self._metrics[mode]["tpo/target_entropy"].append(target_entropy.item())
advantages = advantages[process_slice]
# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
@@ -2258,6 +2292,9 @@ def _generate_and_score_completions(
}
if old_per_token_logps is not None:
output["old_per_token_logps"] = old_per_token_logps
+ if tpo_targets is not None:
+ output["tpo_targets"] = tpo_targets
+ output["tpo_valid_mask"] = tpo_valid_mask
if self.use_vllm and self.vllm_importance_sampling_correction:
output["importance_sampling_ratio"] = vllm_importance_sampling_ratio
if sampling_per_token_logps is not None:
@@ -2365,6 +2402,63 @@ def get_off_policy_mask(
is_low_kl = avg_seq_kl <= off_policy_threshold
return (is_pos_adv | is_low_kl).to(dtype=mask.dtype) # (B, 1)
+ @staticmethod
+ def get_tpo_targets(
+ old_sequence_logps: torch.Tensor,
+ scores: torch.Tensor,
+ num_generations: int,
+ temperature: float = 1.0,
+ valid_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Build the Target Policy Optimization target distribution for each prompt group.
+
+ The target is q_i ∝ p_i_old * exp(score_i / temperature), where p_i_old is the rollout policy distribution
+ over the sampled completions in a prompt group. Completions with `valid_mask=False` (e.g. all tokens
+ masked out) get zero target probability and don't contribute to the group softmax.
+ """
+ if temperature <= 0.0:
+ raise ValueError(f"temperature must be greater than 0.0. You provided {temperature}.")
+
+ old_sequence_logps = old_sequence_logps.view(-1, num_generations)
+ scores = scores.view(-1, num_generations)
+ if valid_mask is not None:
+ valid_mask = valid_mask.view(-1, num_generations).bool()
+ old_sequence_logps = old_sequence_logps.masked_fill(~valid_mask, torch.finfo(old_sequence_logps.dtype).min)
+ target_logits = torch.log_softmax(old_sequence_logps, dim=1) + scores / temperature
+ if valid_mask is not None:
+ target_logits = target_logits.masked_fill(~valid_mask, torch.finfo(target_logits.dtype).min)
+ targets = torch.softmax(target_logits, dim=1)
+ if valid_mask is not None:
+ targets = torch.where(valid_mask, targets, torch.zeros_like(targets))
+ return targets.view(-1).detach()
+
+ @staticmethod
+ def get_tpo_scores(
+ scores: torch.Tensor, num_generations: int, valid_mask: torch.Tensor | None = None
+ ) -> torch.Tensor:
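+        """
+        Whiten rewards within each prompt group using the population (biased) standard deviation; groups with
+        near-zero variance are only mean-centered. Completions with `valid_mask=False` are excluded from the
+        group statistics and receive a score of 0.
+        """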
+ scores = scores.view(-1, num_generations)
+ if valid_mask is not None:
+ valid_mask = valid_mask.view(-1, num_generations).bool()
+ valid_count = valid_mask.sum(dim=1, keepdim=True).clamp(min=1)
+ mean_scores = scores.masked_fill(~valid_mask, 0.0).sum(dim=1, keepdim=True) / valid_count
+ centered_scores = torch.where(valid_mask, scores - mean_scores, torch.zeros_like(scores))
+ std_scores = (centered_scores.square().sum(dim=1, keepdim=True) / valid_count).sqrt()
+ else:
+ mean_scores = scores.mean(dim=1, keepdim=True)
+ std_scores = scores.std(dim=1, unbiased=False, keepdim=True)
+ centered_scores = scores - mean_scores
+ scores = torch.where(std_scores > 1e-6, centered_scores / std_scores, centered_scores)
+ return scores.view(-1)
+
+ @staticmethod
+ def _gather_tensor_with_grad(tensor: torch.Tensor) -> torch.Tensor:
+ # Autograd-aware all_gather: required when a TPO prompt group spans DP ranks, so the log-softmax
+ # normalizer's gradient routes back to the owning rank.
+ if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+ return tensor
+ return torch.cat(_all_gather_with_grad(tensor), dim=0)
+
@staticmethod
@torch.no_grad()
def get_gamma_weights(
@@ -2460,6 +2554,62 @@ def _compute_loss(self, model, inputs):
old_per_token_logps = inputs.get("old_per_token_logps")
old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps
+ if self.loss_type == "tpo":
+ if "tpo_targets" not in inputs:
+ raise RuntimeError("TPO loss requires `tpo_targets` in the prepared inputs.")
+ if self.off_policy_mask_threshold is not None:
+ raise ValueError("TPO loss does not support `off_policy_mask_threshold`.")
+ if self.top_entropy_quantile < 1.0:
+ raise ValueError("TPO loss does not support `top_entropy_quantile < 1.0`.")
+
+ mode = "train" if self.model.training else "eval"
+ num_generations = self.num_generations if mode == "train" else self.num_generations_eval
+ # Length-normalized sequence logp (see _generate_and_score_completions for why).
+ sequence_logps = (per_token_logps * mask).sum(dim=-1)
+ if self.tpo_length_normalize_logps:
+ sequence_logps = sequence_logps / mask.sum(dim=-1).clamp(min=1)
+ all_sequence_logps = self._gather_tensor_with_grad(sequence_logps)
+ tpo_valid_mask = inputs.get("tpo_valid_mask")
+ if tpo_valid_mask is not None:
+ tpo_valid_mask = tpo_valid_mask.to(device=sequence_logps.device, dtype=torch.bool)
+ all_tpo_valid_mask = gather(tpo_valid_mask)
+ all_sequence_logps = all_sequence_logps.masked_fill(
+ ~all_tpo_valid_mask, torch.finfo(all_sequence_logps.dtype).min
+ )
+ all_logps = torch.log_softmax(all_sequence_logps.view(-1, num_generations), dim=1).view(-1)
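+                # The group softmax above spans all completions of each group across ranks; slice this rank's
+                # rows back out so the cross-entropy below only covers locally-owned completions.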
+ process_slice = slice(
+ self.accelerator.process_index * sequence_logps.size(0),
+ (self.accelerator.process_index + 1) * sequence_logps.size(0),
+ )
+ logps = all_logps[process_slice]
+ tpo_targets = inputs["tpo_targets"].to(logps.dtype)
+ if tpo_valid_mask is not None:
+ tpo_targets = torch.where(tpo_valid_mask, tpo_targets, torch.zeros_like(tpo_targets))
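+                # Cross-entropy against the target; the num_generations / tpo_targets.numel() factor normalizes
+                # the sum over local completions to a per-prompt-group average.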
+ loss = -(tpo_targets * logps).sum() * num_generations / tpo_targets.numel()
+
+ normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
+ loss = loss / normalizer
+
+ if self.beta != 0.0:
+ ref_per_token_logps = inputs["ref_per_token_logps"]
+ per_token_kl = (
+ torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+ )
+ if self.args.use_bias_correction_kl:
+ per_token_kl = per_token_kl * torch.exp(per_token_logps - old_per_token_logps)
+ kl_normalizer = (
+ self.accelerator.gather(mask.sum()).sum().clamp(min=1.0) / self.accelerator.num_processes
+ )
+ kl_loss = (per_token_kl * mask).sum() / kl_normalizer
+ loss = loss + self.beta * kl_loss / normalizer
+ self._metrics[mode]["kl"].append(self.accelerator.gather(kl_loss).nanmean().item())
+
+ completion_token_count = mask.sum().clamp(min=1.0)
+ mean_entropy = (entropies * mask).sum() / completion_token_count
+ self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
+
+ return loss
+
if self.off_policy_mask_threshold is not None:
# OPSM should use inference-time logprobs to detect both sources of off-policyness:
# 1. Drift from gradient updates (always present)
diff --git a/trl/trainer/target_po_config.py b/trl/trainer/target_po_config.py
new file mode 100644
index 00000000000..0879102f3b6
--- /dev/null
+++ b/trl/trainer/target_po_config.py
@@ -0,0 +1,55 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+
+from .grpo_config import GRPOConfig
+
+
+@dataclass
+class TargetPOConfig(GRPOConfig):
+ r"""
+ Configuration class for the [`TargetPOTrainer`].
+
+ This class extends [`GRPOConfig`] with defaults for Target Policy Optimization. For a full list of training
+ arguments, please refer to [`~transformers.TrainingArguments`] and [`GRPOConfig`].
+
+ Unless the user passes `generation_batch_size` or `steps_per_generation`, `TargetPOConfig` defaults
+ `steps_per_generation` to `1` as a safe starting point. Values greater than `1` are supported as long as each
+ optimization step still contains whole prompt groups (see [`GRPOConfig`] for the divisibility check).
+
+ Parameters:
+ loss_type (`str`, *optional*, defaults to `"tpo"`):
+ Loss formulation. `TargetPOConfig` requires this to stay set to `"tpo"`.
+ tpo_target_temperature (`float`, *optional*, defaults to `1.0`):
+ Temperature used to build the TPO target distribution. Lower values make the target more concentrated on
+ high-scoring completions.
+ """
+
+ loss_type: str = field(
+ default="tpo",
+ metadata={"help": "Loss formulation. `TargetPOConfig` requires this to stay set to `tpo`."},
+ )
+
+ def __post_init__(self):
+ if self.generation_batch_size is None and self.steps_per_generation is None:
+ self.steps_per_generation = 1
+
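+        # GRPOConfig.__post_init__ rejects `use_liger_kernel` together with `loss_type="tpo"` (GRPOTrainer has no
+        # Liger TPO kernel). TargetPOTrainer ships its own fused loss, so hide the flag from the parent check and
+        # restore it afterwards.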
+ use_liger_kernel = self.use_liger_kernel
+ self.use_liger_kernel = False
+ super().__post_init__()
+ self.use_liger_kernel = use_liger_kernel
+
+ if self.loss_type != "tpo":
+ raise ValueError(f"TargetPOConfig requires loss_type='tpo'. You provided {self.loss_type!r}.")
diff --git a/trl/trainer/target_po_trainer.py b/trl/trainer/target_po_trainer.py
new file mode 100644
index 00000000000..fef993b6ea8
--- /dev/null
+++ b/trl/trainer/target_po_trainer.py
@@ -0,0 +1,3233 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import atexit
+import copy
+import importlib.resources as pkg_resources
+import inspect
+import math
+import os
+import sys
+import textwrap
+import time
+import warnings
+from collections import defaultdict, deque
+from collections.abc import Callable
+from contextlib import nullcontext
+from pathlib import Path
+from typing import Any, Protocol
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.utils.data
+import transformers
+from accelerate.logging import get_logger
+from accelerate.utils import gather, gather_object, is_peft_model, set_seed
+from datasets import Dataset, IterableDataset
+from huggingface_hub import CommitScheduler, DatasetCard, DatasetCardData, create_repo
+from packaging.version import Version
+from torch import nn
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.nn.functional import all_gather as _all_gather_with_grad
+from torch.utils.data import Sampler
+from transformers import (
+ AutoModelForSequenceClassification,
+ AutoProcessor,
+ AutoTokenizer,
+ GenerationConfig,
+ PreTrainedModel,
+ PreTrainedTokenizerBase,
+ ProcessorMixin,
+ TrainerCallback,
+ is_trackio_available,
+ is_wandb_available,
+)
+from transformers.utils import is_peft_available, is_rich_available
+
+from ..chat_template_utils import (
+ add_response_schema,
+ get_training_chat_template,
+ is_chat_template_prefix_preserving,
+ parse_response,
+ supports_tool_calling,
+)
+from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages
+from ..extras.profiling import profiling_context, profiling_decorator
+from ..generation.vllm_generation import VLLMGeneration
+from ..import_utils import is_jmespath_available, is_liger_kernel_available
+from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
+from ..models.utils import _ForwardRedirection, disable_gradient_checkpointing
+from .base_trainer import _BaseTrainer
+from .callbacks import SyncRefModelCallback
+from .target_po_config import TargetPOConfig
+from .utils import (
+ RepeatSampler,
+ create_model_from_path,
+ disable_dropout_in_model,
+ entropy_from_logits,
+ get_config_model_id,
+ identity,
+ nanmax,
+ nanmin,
+ nanstd,
+ pad,
+ print_prompt_completions_sample,
+ selective_log_softmax,
+ shuffle_sequence_dict,
+ shutdown_event_loop_in_daemon,
+ split_pixel_values_by_grid,
+ split_tensor_dict,
+ start_event_loop_in_daemon,
+ unsplit_pixel_values_by_grid,
+ use_adapter,
+)
+
+
+if is_peft_available():
+ from peft import PeftConfig, PeftModel, get_peft_model
+
+if is_liger_kernel_available():
+ from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
+
+
+ class LigerFusedLinearTargetPOFunction(torch.autograd.Function):
+ @staticmethod
+ def _get_process_info():
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ return torch.distributed.get_rank(), torch.distributed.get_world_size()
+ return 0, 1
+
+ @staticmethod
+ def _gather_tensor(tensor: torch.Tensor) -> torch.Tensor:
+ if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+ return tensor
+ gathered = [torch.empty_like(tensor) for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(gathered, tensor)
+ return torch.stack(gathered) if tensor.dim() == 0 else torch.cat(gathered, dim=0)
+
+ @staticmethod
+ def _chunk_logps_and_entropies(
+ input_chunk, weight, selected_token_ids_chunk, bias, temperature, compute_entropy
+ ):
+ log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
+ input_chunk, weight, bias=bias, temperature=temperature
+ )
+ per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids_chunk.unsqueeze(-1)).squeeze(-1)
+ entropies = -(log_probs.exp() * log_probs).sum(dim=-1) if compute_entropy else None
+ return per_token_logps, entropies
+
+ @staticmethod
+ def _compute_surrogate_loss(
+ input_chunk,
+ weight,
+ selected_token_ids_chunk,
+ attention_mask_chunk,
+ tpo_per_token_grads_chunk,
+ bias,
+ ref_per_token_logps_chunk,
+ old_per_token_logps_chunk,
+ beta,
+ kl_normalizer,
+ use_bias_correction_kl,
+ temperature,
+ ):
+ per_token_logps, _ = LigerFusedLinearTargetPOFunction._chunk_logps_and_entropies(
+ input_chunk, weight, selected_token_ids_chunk, bias, temperature, compute_entropy=False
+ )
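+            # Surrogate objective: its gradient w.r.t. per_token_logps equals the precomputed sequence-level TPO
+            # gradients broadcast over tokens, so autograd reproduces the true TPO gradient chunk by chunk.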
+ loss = (per_token_logps * tpo_per_token_grads_chunk).sum()
+
+ if beta != 0.0:
+ if ref_per_token_logps_chunk is None:
+ raise RuntimeError("TargetPO Liger loss requires `ref_per_token_logps` when beta is non-zero.")
+ ref_per_token_logps_chunk = ref_per_token_logps_chunk.float()
+ old_per_token_logps_chunk = (
+ per_token_logps.detach()
+ if old_per_token_logps_chunk is None
+ else old_per_token_logps_chunk.float()
+ )
+ per_token_kl = (
+ torch.exp(ref_per_token_logps_chunk - per_token_logps)
+ - (ref_per_token_logps_chunk - per_token_logps)
+ - 1
+ )
+ if use_bias_correction_kl:
+ per_token_kl = per_token_kl * torch.exp(per_token_logps - old_per_token_logps_chunk)
+ loss = loss + beta * (per_token_kl * attention_mask_chunk).sum() / kl_normalizer
+
+ return loss
+
+ @staticmethod
+ def forward(
+ ctx,
+ _input,
+ weight,
+ selected_token_ids,
+ attention_mask,
+ tpo_targets,
+ tpo_valid_mask,
+ bias=None,
+ ref_per_token_logps=None,
+ old_per_token_logps=None,
+ beta=0.0,
+ num_generations=1,
+ tpo_length_normalize_logps=True,
+ use_bias_correction_kl=False,
+ normalizer=1.0,
+ temperature=1.0,
+ chunk_size=1,
+ ):
+ chunk_size = max(1, int(chunk_size))
+ chunks = max(1, math.ceil(_input.shape[0] / chunk_size))
+ input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+ selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
+ attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
+ ref_per_token_logps_chunks = (
+ torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
+ if ref_per_token_logps is not None
+ else [None] * chunks
+ )
+ old_per_token_logps_chunks = (
+ torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
+ if old_per_token_logps is not None
+ else [None] * chunks
+ )
+
+ per_token_logps_chunks = []
+ entropy_sum = torch.zeros((), device=_input.device, dtype=torch.float32)
+ with torch.no_grad():
+ for input_chunk, selected_token_ids_chunk, attention_mask_chunk in zip(
+ input_chunks, selected_token_ids_chunks, attention_mask_chunks, strict=True
+ ):
+ per_token_logps_chunk, entropies_chunk = (
+ LigerFusedLinearTargetPOFunction._chunk_logps_and_entropies(
+ input_chunk, weight, selected_token_ids_chunk, bias, temperature, compute_entropy=True
+ )
+ )
+ loss_mask_chunk = attention_mask_chunk.to(per_token_logps_chunk.dtype)
+ per_token_logps_chunks.append(per_token_logps_chunk.float())
+ entropy_sum.add_((entropies_chunk * loss_mask_chunk).sum())
+
+ per_token_logps = torch.cat(per_token_logps_chunks, dim=0)
+ loss_mask = attention_mask.to(per_token_logps.dtype)
+ tpo_targets = torch.where(tpo_valid_mask, tpo_targets, torch.zeros_like(tpo_targets))
+ tpo_targets = tpo_targets.to(per_token_logps.dtype)
+
+ sequence_lengths = loss_mask.sum(dim=-1).clamp(min=1.0)
+ sequence_logps = (per_token_logps * loss_mask).sum(dim=-1)
+ if tpo_length_normalize_logps:
+ sequence_logps = sequence_logps / sequence_lengths
+
+ rank, world_size = LigerFusedLinearTargetPOFunction._get_process_info()
+ all_sequence_logps = LigerFusedLinearTargetPOFunction._gather_tensor(sequence_logps)
+ all_tpo_targets = LigerFusedLinearTargetPOFunction._gather_tensor(tpo_targets)
+ all_tpo_valid_mask = LigerFusedLinearTargetPOFunction._gather_tensor(tpo_valid_mask)
+ all_sequence_logps = all_sequence_logps.masked_fill(
+ ~all_tpo_valid_mask, torch.finfo(all_sequence_logps.dtype).min
+ )
+
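+            # Log-softmax within each group of `num_generations` sequence log-probs gives the policy's distribution
+            # over that prompt's sampled completions; the TPO loss below is the cross-entropy between the
+            # precomputed targets and this distribution.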
+ all_logps = torch.log_softmax(all_sequence_logps.view(-1, num_generations), dim=1).view(-1)
+ process_slice = slice(rank * sequence_logps.size(0), (rank + 1) * sequence_logps.size(0))
+ logps = all_logps[process_slice]
+ loss_scale = num_generations / tpo_targets.numel()
+ tpo_loss = -(tpo_targets * logps).sum() * loss_scale
+
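+            # Manual gradient of the cross-entropy above: for L = -(t · log_softmax(s)), dL/ds_i = softmax(s)_i *
+            # sum_j(t_j) - t_i. These per-sequence gradients are spread over tokens below and later fed to a
+            # surrogate loss so that chunked backprop reproduces the gradients of the sequence-level loss.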
+ all_probs = torch.softmax(all_sequence_logps.view(-1, num_generations), dim=1)
+ target_groups = all_tpo_targets.view(-1, num_generations)
+ target_group_sums = target_groups.sum(dim=1, keepdim=True)
+ all_sequence_grads = (all_probs * target_group_sums - target_groups) * loss_scale
+ sequence_grads = all_sequence_grads.view(-1)[process_slice]
+ if tpo_length_normalize_logps:
+ tpo_per_token_grads = sequence_grads.unsqueeze(1) * loss_mask / sequence_lengths.unsqueeze(1)
+ else:
+ tpo_per_token_grads = sequence_grads.unsqueeze(1) * loss_mask
+ tpo_per_token_grads = tpo_per_token_grads / normalizer
+
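+            # Normalize the KL term by the mean number of valid completion tokens per process; after the
+            # all-gather this value is identical on every rank, and it is clamped to avoid division by zero.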
+ kl_normalizer = loss_mask.sum()
+ kl_normalizer = LigerFusedLinearTargetPOFunction._gather_tensor(kl_normalizer).sum() / world_size
+ kl_normalizer = kl_normalizer.clamp(min=1.0)
+
+ kl_loss = torch.zeros((), device=_input.device, dtype=torch.float32)
+ if beta != 0.0:
+ if ref_per_token_logps is None:
+ raise RuntimeError("TargetPO Liger loss requires `ref_per_token_logps` when beta is non-zero.")
+ old_per_token_logps_for_kl = (
+ per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps.float()
+ )
+ per_token_kl = (
+ torch.exp(ref_per_token_logps.float() - per_token_logps)
+ - (ref_per_token_logps.float() - per_token_logps)
+ - 1
+ )
+ if use_bias_correction_kl:
+ per_token_kl = per_token_kl * torch.exp(per_token_logps - old_per_token_logps_for_kl)
+ kl_loss = (per_token_kl * loss_mask).sum() / kl_normalizer
+
+ loss = tpo_loss / normalizer + beta * kl_loss / normalizer
+ mean_entropy = entropy_sum / loss_mask.sum().clamp(min=1.0)
+
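+            # The log-probs above were computed under `torch.no_grad()`, so input/weight gradients are recomputed
+            # here chunk by chunk with `torch.func.grad` on a surrogate loss whose gradients match the TPO + KL
+            # loss, then cached in `ctx` and rescaled by `grad_output` in `backward`.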
+ tpo_per_token_grad_chunks = torch.chunk(tpo_per_token_grads, chunks=chunks, dim=0)
+ grad_weight = torch.zeros_like(weight)
+ grad_inputs = []
+ grad_bias = torch.zeros_like(bias) if bias is not None else None
+
+ for (
+ input_chunk,
+ selected_token_ids_chunk,
+ attention_mask_chunk,
+ tpo_per_token_grads_chunk,
+ ref_per_token_logps_chunk,
+ old_per_token_logps_chunk,
+ ) in zip(
+ input_chunks,
+ selected_token_ids_chunks,
+ attention_mask_chunks,
+ tpo_per_token_grad_chunks,
+ ref_per_token_logps_chunks,
+ old_per_token_logps_chunks,
+ strict=True,
+ ):
+ loss_mask_chunk = attention_mask_chunk.to(tpo_per_token_grads_chunk.dtype)
+ if bias is not None:
+
+ def compute_surrogate(
+ input_chunk,
+ weight,
+ bias,
+ selected_token_ids_chunk,
+ loss_mask_chunk,
+ tpo_per_token_grads_chunk,
+ ref_per_token_logps_chunk,
+ old_per_token_logps_chunk,
+ ):
+ return LigerFusedLinearTargetPOFunction._compute_surrogate_loss(
+ input_chunk,
+ weight,
+ selected_token_ids_chunk,
+ loss_mask_chunk,
+ tpo_per_token_grads_chunk,
+ bias,
+ ref_per_token_logps_chunk,
+ old_per_token_logps_chunk,
+ beta / normalizer,
+ kl_normalizer,
+ use_bias_correction_kl,
+ temperature,
+ )
+
+ chunk_grad_input, chunk_grad_weight, chunk_grad_bias = torch.func.grad(
+ compute_surrogate, argnums=(0, 1, 2)
+ )(
+ input_chunk,
+ weight,
+ bias,
+ selected_token_ids_chunk,
+ loss_mask_chunk,
+ tpo_per_token_grads_chunk,
+ ref_per_token_logps_chunk,
+ old_per_token_logps_chunk,
+ )
+ grad_bias.add_(chunk_grad_bias)
+ else:
+
+ def compute_surrogate(
+ input_chunk,
+ weight,
+ selected_token_ids_chunk,
+ loss_mask_chunk,
+ tpo_per_token_grads_chunk,
+ ref_per_token_logps_chunk,
+ old_per_token_logps_chunk,
+ ):
+ return LigerFusedLinearTargetPOFunction._compute_surrogate_loss(
+ input_chunk,
+ weight,
+ selected_token_ids_chunk,
+ loss_mask_chunk,
+ tpo_per_token_grads_chunk,
+ None,
+ ref_per_token_logps_chunk,
+ old_per_token_logps_chunk,
+ beta / normalizer,
+ kl_normalizer,
+ use_bias_correction_kl,
+ temperature,
+ )
+
+ chunk_grad_input, chunk_grad_weight = torch.func.grad(compute_surrogate, argnums=(0, 1))(
+ input_chunk,
+ weight,
+ selected_token_ids_chunk,
+ loss_mask_chunk,
+ tpo_per_token_grads_chunk,
+ ref_per_token_logps_chunk,
+ old_per_token_logps_chunk,
+ )
+
+ grad_inputs.append(chunk_grad_input)
+ grad_weight.add_(chunk_grad_weight)
+
+ grad_input = torch.cat(grad_inputs, dim=0)
+ if bias is not None:
+ ctx.save_for_backward(grad_input, grad_weight, grad_bias)
+ else:
+ ctx.save_for_backward(grad_input, grad_weight)
+ ctx.has_bias = bias is not None
+ return loss, kl_loss, mean_entropy
+
+ @staticmethod
+ def backward(ctx, grad_output, *grad_metrics):
+ if ctx.has_bias:
+ grad_input, grad_weight, grad_bias = ctx.saved_tensors
+ else:
+ grad_input, grad_weight = ctx.saved_tensors
+ grad_bias = None
+
+ grad_input = grad_input * grad_output
+ grad_weight = grad_weight * grad_output
+ if grad_bias is not None:
+ grad_bias = grad_bias * grad_output
+
+ return (
+ grad_input,
+ grad_weight,
+ None, # selected_token_ids
+ None, # attention_mask
+ None, # tpo_targets
+ None, # tpo_valid_mask
+ grad_bias,
+ None, # ref_per_token_logps
+ None, # old_per_token_logps
+ None, # beta
+ None, # num_generations
+ None, # tpo_length_normalize_logps
+ None, # use_bias_correction_kl
+ None, # normalizer
+ None, # temperature
+ None, # chunk_size
+ )
+
+
+ class LigerFusedLinearTargetPOLoss(nn.Module):
+ def __init__(
+ self,
+ beta: float = 0.0,
+ num_generations: int = 1,
+ tpo_length_normalize_logps: bool = True,
+ use_bias_correction_kl: bool = False,
+ temperature: float = 1.0,
+ chunk_size: int = 1,
+ ):
+ super().__init__()
+ self.beta = beta
+ self.num_generations = num_generations
+ self.tpo_length_normalize_logps = tpo_length_normalize_logps
+ self.use_bias_correction_kl = use_bias_correction_kl
+ self.temperature = temperature
+ self.chunk_size = chunk_size
+
+ def forward(
+ self,
+ _input,
+ lin_weight,
+ selected_token_ids,
+ attention_mask,
+ tpo_targets,
+ tpo_valid_mask=None,
+ bias=None,
+ ref_per_token_logps=None,
+ old_per_token_logps=None,
+ normalizer=1.0,
+ ):
+ if tpo_valid_mask is None:
+ tpo_valid_mask = torch.ones_like(tpo_targets, dtype=torch.bool)
+ return LigerFusedLinearTargetPOFunction.apply(
+ _input,
+ lin_weight,
+ selected_token_ids,
+ attention_mask,
+ tpo_targets,
+ tpo_valid_mask,
+ bias,
+ ref_per_token_logps,
+ old_per_token_logps,
+ self.beta,
+ self.num_generations,
+ self.tpo_length_normalize_logps,
+ self.use_bias_correction_kl,
+ normalizer,
+ self.temperature,
+ self.chunk_size,
+ )
+
+
+if is_wandb_available():
+ import wandb
+
+if is_trackio_available():
+ import trackio
+
+
+logger = get_logger(__name__)
+
+# A reward function can be a string, interpreted as a model ID and loaded as a pretrained model, a pretrained model, or
+# a callable that returns a list of floats (the rewards). The callable receives prompts, completions, and additional
+# arguments from the trainer (refer to the trainer's source for details). To ensure forward compatibility, it should
+# accept **kwargs.
+RewardFunc = str | PreTrainedModel | Callable[..., list[float | None]]
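+
+# For illustration only (hypothetical, not shipped with TRL): a custom reward function could look like the sketch
+# below, returning one float (or `None`) per completion and absorbing any extra dataset columns via `**kwargs`:
+#
+#     def brevity_reward(prompts, completions, **kwargs):
+#         # Reward shorter completions; values are combined with the other reward functions by the trainer.
+#         return [1.0 / (1.0 + len(str(completion))) for completion in completions]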
+
+# What we call a rollout function is a callable that takes prompts (list) and the trainer instance as parameters and
+# returns a dict of generation results. Those results must include "prompt_ids", "completion_ids", and "logprobs"
+# fields. Any extra fields (per-completion) are forwarded to the reward functions.
+RolloutFunc = Callable[[list[str], "TargetPOTrainer"], dict[str, Any]]
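+
+# For illustration only: the dict returned by a rollout function is expected to look roughly like
+#
+#     {
+#         "prompt_ids": [[...], ...],      # token IDs of the prompts
+#         "completion_ids": [[...], ...],  # token IDs of the generated completions
+#         "logprobs": [[...], ...],        # per-token log-probabilities of the completions
+#     }
+#
+# plus any extra per-completion fields, which are forwarded to the reward functions.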
+
+
+class _SupportsReset(Protocol):
+ def reset(self, **kwargs) -> str | None: ...
+
+
+EnvironmentFactory = Callable[[], _SupportsReset]
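+
+# For illustration only: a minimal environment compatible with `EnvironmentFactory` might look like the sketch
+# below. `reset` clears state between generations (and may return a string to append to the last user message),
+# while every other public method is exposed to the model as a tool.
+#
+#     class CounterEnv:
+#         def __init__(self):
+#             self.count = 0
+#
+#         def reset(self, **kwargs) -> str | None:
+#             self.count = 0
+#             return None
+#
+#         def increment(self, amount: int) -> int:
+#             """Increase the counter by `amount` and return the new value."""
+#             self.count += amount
+#             return self.count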
+
+
+class TargetPOTrainer(_BaseTrainer):
+ """
+ Trainer for the Target Policy Optimization (TargetPO) method, from the paper [Target Policy
+ Optimization](https://huggingface.co/papers/2604.06159).
+
+ Example:
+
+ ```python
+ from trl import TargetPOTrainer
+ from trl.rewards import accuracy_reward
+ from datasets import load_dataset
+
+ dataset = load_dataset("trl-lib/DeepMath-103K", split="train")
+
+ trainer = TargetPOTrainer(
+ model="Qwen/Qwen2.5-0.5B-Instruct",
+ reward_funcs=accuracy_reward,
+ train_dataset=dataset,
+ )
+ trainer.train()
+ ```
+
+ Args:
+ model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]):
+ Model to be trained. Can be either:
+
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
+ path to a *directory* containing model weights saved using
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
+              using the `from_pretrained` method of the model class derived from the model config, with the keyword
+              arguments in `args.model_init_kwargs`.
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
+ - A [`~peft.PeftModel`] object. Only causal language models are supported.
+ reward_funcs (`RewardFunc | list[RewardFunc]`):
+ Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
+ functions with the prompts and completions and sum the rewards. Can be either:
+
+ - A single reward function, such as:
+ - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
+ path to a *directory* containing model weights saved using
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
+ using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
+ keyword arguments in `args.model_init_kwargs`.
+ - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
+ - A custom reward function: The function is provided with the prompts and the generated completions,
+ plus any additional columns in the dataset. It should return a list of rewards. Custom reward
+ functions can be either synchronous or asynchronous and can also return `None` when the reward is
+ not applicable to those samples. This is useful for multi-task training where different reward
+ functions apply to different types of samples. When a reward function returns `None` for a sample,
+ that reward function is excluded from the reward calculation for that sample. For more details, see
+ [Using a custom reward
+ function](#using-a-custom-reward-function).
+
+            The trainer's state is also passed to the reward function. It is an instance of
+            [`~transformers.TrainerState`] and can be accessed via the `trainer_state` argument in the reward
+            function's signature.
+ - A list of reward functions, where each item can independently be any of the above types. Mixing different
+ types within the list (e.g., a string model ID and a custom reward function) is allowed.
+ args ([`TargetPOConfig`], *optional*):
+ Configuration for this trainer. If `None`, a default configuration is used.
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
+            Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
+ ignored. The format of the samples can be either:
+
+ - [Standard](dataset_formats#standard): Each sample contains plain text.
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
+ and content).
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`):
+ Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
+ Processing class used to process the data. The padding side must be set to "left". If `None`, the
+ processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
+ padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
+ `tokenizer.eos_token` will be used as the default.
+ reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*):
+ Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
+
+ - A single processing class: Used when `reward_funcs` contains only one reward function.
+ - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
+ If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
+ `None`, the tokenizer for the model is automatically loaded using
+ [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward
+ functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in `reward_processing_classes`
+ are ignored.
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*):
+ List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
+ in [here](https://huggingface.co/docs/transformers/main_classes/callback).
+
+ If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
+ method.
+ optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`):
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
+ model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
+ peft_config ([`~peft.PeftConfig`], *optional*):
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
+ tools (list of `Callable`, *optional*):
+ A list of callable tool functions (sync or async) that the model can invoke during generation. Each tool
+ should be a standard Python function with properly type-hinted arguments and return values, and a
+ Google-style docstring describing its purpose, arguments, and return value. For more details, see:
+ https://huggingface.co/docs/transformers/en/chat_extras#passing-tools. The model uses the function's name,
+ type hints, and docstring to determine how to call it. Ensure that the model's chat template supports tool
+ use and that it has been fine-tuned for tool calling.
+ rollout_func (`RolloutFunc`, *optional*):
+ Function to use for generating completions. It receives the list of prompts allocated to the current
+ process and the trainer instance. It must return a dict with `"prompt_ids"`, `"completion_ids"`, and
+ `"logprobs"` fields, and can optionally return `"logprob_token_ids"` (same shape as `"logprobs"`). Any
+ other fields are forwarded to the reward functions. The function receives the raw per-process prompt slice
+ with no duplication; it is responsible for returning the correct number of completions per prompt (see
+ `num_generations` / `num_generations_eval` on the trainer). This feature is experimental and may change or
+ be removed at any time without prior notice.
+ environment_factory (`EnvironmentFactory`, *optional*):
+ A callable that creates and returns an environment instance. The environment class should define methods
+ that can be invoked as tools during generation. Each method should comply with the same requirements as the
+ `tools` described above. If `environment_factory` is provided, an instance of the environment is created
+ for each generation in the batch, allowing for parallel and independent interactions. The environment must
+ also implement a callable `reset` method that can be used to reset state between generations. The `reset`
+ method should return either `None` or a string: when it returns a string, that string is appended to the
+ last user message before generation. This feature is experimental and may change or be removed at any time
+ without prior notice.
+ """
+
+ _tag_names = ["trl", "tpo"]
+ _name = "TargetPO"
+ _paper = {
+ "title": "Target Policy Optimization",
+ "id": "2604.06159",
+ # docstyle-ignore
+ "citation": textwrap.dedent("""\
+ @misc{kaddour2026targetpolicyoptimization,
+ title = {{Target Policy Optimization}},
+ author = {Jean Kaddour},
+ year = 2026,
+ eprint = {arXiv:2604.06159},
+ }"""),
+ }
+
+ def __init__(
+ self,
+ model: "str | PreTrainedModel | PeftModel",
+ reward_funcs: RewardFunc | list[RewardFunc],
+ args: TargetPOConfig | None = None,
+ train_dataset: Dataset | IterableDataset | None = None,
+ eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
+ processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
+ reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
+ callbacks: list[TrainerCallback] | None = None,
+ optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
+ peft_config: "PeftConfig | None" = None,
+ tools: list[Callable] | None = None,
+ rollout_func: RolloutFunc | None = None,
+ environment_factory: EnvironmentFactory | None = None,
+ ):
+ # Args
+ if args is None:
+ model_name = model if isinstance(model, str) else get_config_model_id(model.config)
+ model_name = model_name.split("/")[-1]
+ args = TargetPOConfig(f"{model_name}-TargetPO")
+
+ # Model
+ if isinstance(model, str):
+ model_init_kwargs = args.model_init_kwargs or {}
+ # Distributed training requires device_map=None ("auto" fails)
+ if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
+ model_init_kwargs["device_map"] = None
+ model = create_model_from_path(model, **model_init_kwargs)
+ else:
+ if args.model_init_kwargs is not None:
+ logger.warning(
+ "You passed `model_init_kwargs` to the `TargetPOConfig`, but your model is already instantiated. "
+ "The `model_init_kwargs` will be ignored."
+ )
+
+        # Some models (SmolVLM/Idefics3) don't support the `logits_to_keep` argument and error out if we pass it
+ # Inspect the forward method before we wrap the model with PEFT
+ self.model_kwarg_keys = (
+ inspect.signature(model.forward).parameters.keys()
+ if not hasattr(model, "get_base_model")
+ else inspect.signature(model.get_base_model().forward).parameters.keys()
+ )
+
+ # Processing class
+ if processing_class is None:
+ processing_class = AutoProcessor.from_pretrained(
+ get_config_model_id(model.config), truncation_side="left", padding_side="left"
+ )
+
+ # Handle pad token for processors or tokenizers
+ if isinstance(processing_class, ProcessorMixin):
+ tokenizer = processing_class.tokenizer
+ self._is_vlm = True
+ elif isinstance(processing_class, PreTrainedTokenizerBase):
+ tokenizer = processing_class
+ self._is_vlm = False
+ else:
+ raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
+
+ # Resolve vision placeholder token IDs once. Used by the forward pass to rebuild mm_token_type_ids
+ # when tool responses inject images into the completion (see _generate forward_kwargs block).
+ self._image_pad_token_id = None
+ self._video_pad_token_id = None
+ if self._is_vlm:
+ for candidate in ("<|image_pad|>", "<|image|>"):
+ tid = tokenizer.convert_tokens_to_ids(candidate)
+ if tid != tokenizer.unk_token_id:
+ self._image_pad_token_id = tid
+ break
+ tid = tokenizer.convert_tokens_to_ids("<|video_pad|>")
+ if tid != tokenizer.unk_token_id:
+ self._video_pad_token_id = tid
+
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ self.pad_token_id = tokenizer.pad_token_id
+ self.eos_token_id = tokenizer.eos_token_id
+
+ if is_peft_available() and is_peft_model(model) and peft_config is not None:
+ raise ValueError(
+ "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge "
+ "and unload the existing adapter, save the resulting base model, and then pass that base model along "
+ "with the new `peft_config` to the trainer."
+ )
+ if is_peft_available() and is_peft_model(model) and args.beta != 0.0:
+ # If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy
+            # of the "default" adapter, so that we can use it as the reference model during TargetPO training.
+ model.add_adapter("ref", model.peft_config["default"])
+ for name, param in model.named_parameters():
+ if ".default." in name:
+ ref_name = name.replace(".default.", ".ref.")
+ ref_param = model.get_parameter(ref_name)
+ ref_param.data.copy_(param.data)
+
+ # Create PEFT model
+ if peft_config is not None:
+ model = get_peft_model(model, peft_config)
+
+ # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally
+ # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489
+ if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing:
+ model.enable_input_require_grads()
+
+ # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the
+ # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by
+ # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for
+ # quantized models. See: https://github.com/huggingface/peft/issues/2889
+ # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do
+ if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):
+ for param in model.parameters():
+ if param.requires_grad:
+ param.data = param.data.to(torch.bfloat16)
+
+ # Reward functions
+ if not isinstance(reward_funcs, list):
+ reward_funcs = [reward_funcs]
+ self.reward_func_names = []
+ for i, reward_func in enumerate(reward_funcs):
+ if isinstance(reward_func, str):
+ model_init_kwargs = args.model_init_kwargs or {}
+ # Distributed training requires device_map=None ("auto" fails)
+ if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
+ model_init_kwargs["device_map"] = None
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+ reward_func, num_labels=1, **model_init_kwargs
+ )
+ if isinstance(reward_funcs[i], nn.Module): # Use Module over PretrainedModel for compat w/ compiled models
+ self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1])
+ else:
+ self.reward_func_names.append(reward_funcs[i].__name__)
+ self.reward_funcs = reward_funcs
+
+ # Reward weights
+ if args.reward_weights is not None:
+ if len(args.reward_weights) != len(reward_funcs):
+ raise ValueError(
+ f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
+ f"functions ({len(reward_funcs)})"
+ )
+ self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
+ else:
+ self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
+
+ # Reward processing class
+ if reward_processing_classes is None:
+ reward_processing_classes = [None] * len(reward_funcs)
+ elif not isinstance(reward_processing_classes, list):
+ reward_processing_classes = [reward_processing_classes]
+ if len(reward_processing_classes) != len(reward_funcs):
+ raise ValueError(
+ f"The number of reward processing classes ({len(reward_processing_classes)}) must match the number of "
+ f"reward functions ({len(reward_funcs)})."
+ )
+
+ for i, (reward_processing_class, reward_func) in enumerate(
+ zip(reward_processing_classes, reward_funcs, strict=True)
+ ):
+ if isinstance(reward_func, PreTrainedModel):
+ if reward_processing_class is None:
+ reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config))
+ if reward_processing_class.pad_token_id is None:
+ reward_processing_class.pad_token = reward_processing_class.eos_token
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
+ reward_processing_classes[i] = reward_processing_class
+
+ self.reward_processing_classes = reward_processing_classes
+
+ # Rollout function
+ if rollout_func is not None and os.environ.get("TRL_EXPERIMENTAL_SILENCE", "0") != "1":
+ warnings.warn(
+ "You are using 'rollout_func', which is an experimental feature. This API may change or be removed at "
+ "any time without prior notice. Silence this warning by setting environment variable "
+ "TRL_EXPERIMENTAL_SILENCE=1.",
+ UserWarning,
+ stacklevel=2,
+ )
+ self.rollout_func = rollout_func
+ if environment_factory is not None and os.environ.get("TRL_EXPERIMENTAL_SILENCE", "0") != "1":
+ warnings.warn(
+ "You are using 'environment_factory', which is an experimental feature. This API may change or be "
+ "removed at any time without prior notice. Silence this warning by setting environment variable "
+ "TRL_EXPERIMENTAL_SILENCE=1.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ # Tools
+ if tools:
+ if not Version(transformers.__version__) >= Version("5.0.0"):
+ raise ImportError(
+ "Using tools with TargetPOTrainer requires transformers version 5.0.0 or higher. Please upgrade "
+ "transformers with `pip install --upgrade transformers` to use this feature."
+ )
+ if environment_factory:
+ if not Version(transformers.__version__) >= Version("5.2.0"):
+ raise ImportError(
+ "Using `environment_factory` with TargetPOTrainer requires transformers version 5.2.0 or higher. "
+ "Please install transformers from the main branch with `pip install "
+ "git+https://github.com/huggingface/transformers.git@main` to use this feature."
+ )
+ if tools or environment_factory:
+ if not is_jmespath_available():
+ raise ImportError(
+ "Using tools with TargetPOTrainer requires the jmespath library for response parsing. Please install "
+ "it with `pip install jmespath` to use this feature."
+ )
+ if not supports_tool_calling(processing_class):
+ raise ValueError(
+ "The provided chat template does not support tool calling. The template must be able to render a "
+ "full tool-calling conversation (user -> assistant with tool_calls -> tool)."
+ )
+
+ # Create the environments and extract their methods to be used as tools. We create one environment per rollout
+ generation_batch_size = args.per_device_train_batch_size * args.steps_per_generation
+ if environment_factory is not None:
+ self.environments = [environment_factory() for _ in range(generation_batch_size)]
+ environment_methods = [[] for _ in range(generation_batch_size)]
+ for i, environment in enumerate(self.environments):
+ has_reset = False
+ for name, member in inspect.getmembers(environment, predicate=inspect.ismethod):
+ if name == "reset":
+ has_reset = True
+ elif not name.startswith("_"):
+ environment_methods[i].append(member)
+ if not has_reset:
+ raise ValueError(
+                        "Each environment instance returned by `environment_factory` must define a callable "
+                        "`reset` method."
+ )
+ else:
+ self.environments = None
+
+ tools = tools or []
+ self._sync_tool_dicts = [{} for _ in range(generation_batch_size)]
+ self._async_tool_dicts = [{} for _ in range(generation_batch_size)]
+ for i in range(generation_batch_size):
+ for tool in tools + (environment_methods[i] if self.environments is not None else []):
+ if inspect.iscoroutinefunction(tool):
+ self._async_tool_dicts[i][tool.__name__] = tool
+ else:
+ self._sync_tool_dicts[i][tool.__name__] = tool
+
+ self.tools = tools + (environment_methods[0] if self.environments is not None else [])
+
+ # Check for async functions to start an event loop on a daemon thread
+ self._has_async_funcs = any(inspect.iscoroutinefunction(func) for func in self.reward_funcs + self.tools)
+
+ if self._has_async_funcs:
+ self.async_loop_thread, self.async_loop, self.async_loop_ready_event = start_event_loop_in_daemon(
+ name="TargetPOTrainer-AsyncLoop"
+ )
+ # wait until the event loop is running in the daemon thread
+ self.async_loop_ready_event.wait()
+ atexit.register(shutdown_event_loop_in_daemon, self.async_loop_thread, self.async_loop)
+
+ # At the time of initial implementation, most tokenizers do not have built-in support for response schemas.
+ # While waiting for broader adoption, we provide this utility function to manually set the response schema for
+ # known chat templates. `response_schema` lives on the (inner) tokenizer, since `parse_response` is a tokenizer
+ # method that reads `self.response_schema`.
+ tokenizer = processing_class.tokenizer if self._is_vlm else processing_class
+ if self.tools and getattr(tokenizer, "response_schema", None) is None:
+ processing_class = add_response_schema(processing_class)
+ # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template
+ # isn't, we replace it at initialization with a training-safe, prefix-preserving template.
+ if self.tools and not is_chat_template_prefix_preserving(processing_class):
+ self.chat_template = get_training_chat_template(processing_class)
+ else:
+ self.chat_template = None
+
+ # Training arguments
+ self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
+ self.num_generations = args.num_generations # = G in the GRPO paper
+ self.max_tool_calling_iterations = args.max_tool_calling_iterations or sys.maxsize
+ self.num_generations_eval = args.num_generations_eval or self.num_generations
+ self.chat_template_kwargs = args.chat_template_kwargs or {}
+ self.temperature = args.temperature
+ self.top_p = args.top_p
+ self.top_k = args.top_k
+ self.min_p = args.min_p
+ self.repetition_penalty = args.repetition_penalty
+ self.use_transformers_paged = args.use_transformers_paged
+ self.pad_to_multiple_of = args.pad_to_multiple_of
+ self.use_vllm = args.use_vllm
+ self.vllm_mode = args.vllm_mode
+ self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode
+ self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode
+ self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction
+ self.vllm_importance_sampling_mode = args.vllm_importance_sampling_mode
+ self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap
+ self.use_liger_kernel = args.use_liger_kernel
+ self.loss_type = args.loss_type
+ self.tpo_target_temperature = args.tpo_target_temperature
+ self.tpo_length_normalize_logps = args.tpo_length_normalize_logps
+ self.multi_objective_aggregation = args.multi_objective_aggregation
+ self.scale_rewards = args.scale_rewards
+ self.importance_sampling_level = args.importance_sampling_level
+ self.off_policy_mask_threshold = args.off_policy_mask_threshold
+ if self.use_liger_kernel and self.off_policy_mask_threshold is not None:
+ raise ValueError("Liger kernel does not support off-policy sequence masking yet.")
+ self.mask_truncated_completions = args.mask_truncated_completions
+ self.top_entropy_quantile = args.top_entropy_quantile
+ if self.use_liger_kernel and self.top_entropy_quantile < 1.0:
+ raise NotImplementedError(
+ "Liger Kernels don't currently support masking token positions based on entropy."
+ )
+ if self.use_liger_kernel and self.importance_sampling_level not in ("token", "sequence"):
+ raise ValueError(
+ f"Unknown importance sampling level: {self.importance_sampling_level}. "
+ "Possible values are 'token' and 'sequence'."
+ )
+
+ # Datasets
+ self.shuffle_dataset = args.shuffle_dataset
+
+ if train_dataset is None:
+ raise ValueError("`train_dataset` is required")
+ elif (
+ isinstance(train_dataset, IterableDataset)
+ or isinstance(eval_dataset, IterableDataset)
+ or (
+ isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values())
+ )
+ ):
+ # See https://github.com/huggingface/trl/issues/3213
+ raise NotImplementedError(
+ "Iterable datasets are not yet supported in TargetPOTrainer. Please use a standard dataset instead."
+ )
+
+ if args.loss_type == "luspo" and args.importance_sampling_level != "sequence":
+ logger.warning(
+ "When using `'luspo'` loss, `importance_sampling_level` should be set to `'sequence'` to mirror the "
+ "paper's setup."
+ )
+
+ if args.loss_type == "vespo" and args.importance_sampling_level != "token":
+ logger.warning(
+ "VESPO computes sequence-level importance weights internally. `importance_sampling_level` should be "
+ "set to `'token'` (the default)."
+ )
+
+ if self.loss_type == "vespo" and self.use_vllm and self.vllm_importance_sampling_correction:
+ if self.vllm_importance_sampling_mode not in ["token_truncate", "token_mask"]:
+ raise ValueError(
+ f"VESPO loss requires `vllm_importance_sampling_mode` to be either 'token_truncate' or "
+ f"'token_mask'. Got: {self.vllm_importance_sampling_mode}."
+ )
+
+ # Multi-step
+ self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
+ self.epsilon_low = args.epsilon
+ self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
+ # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
+ self._step = 0
+ # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
+ # `_get_train_sampler` and `_prepare_inputs`.
+ self._buffered_inputs = None
+
+ # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
+ # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
+ # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
+ # default to the recommended non-reentrant behavior here, while preserving any user-provided value.
+ if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
+ args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
+ args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)
+
+ super().__init__(
+ model=model,
+ args=args,
+            data_collator=identity,  # No data collation is needed in TargetPO
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ processing_class=processing_class,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ # In Trainer, `training_step` scales the loss by `gradient_accumulation_steps` only if `compute_loss_func`
+ # is None. For DAPO, loss scaling instead depends on the total number of completions tokens across the
+            # is None. For DAPO, loss scaling instead depends on the total number of completion tokens across the
+ # simplest (though a bit hacky) way is to set `compute_loss_func` to any non-None value, which bypasses
+ # that behavior without rewriting `training_step`.
+ compute_loss_func="non-None value to disable scaling",
+ )
+
+ # Reference model
+ self.beta = args.beta
+ if self.beta == 0.0:
+ # If beta is 0.0, the reference model is not needed
+ self.ref_model = None
+ elif is_peft_model(model):
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
+ # to revert to the initial model.
+ self.ref_model = None
+ else:
+ # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
+ model_init_kwargs = args.model_init_kwargs or {}
+ # Distributed training requires device_map=None ("auto" fails)
+ if self.args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
+ model_init_kwargs["device_map"] = None
+ self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs)
+
+ # Disable dropout in the models
+ if args.disable_dropout:
+ disable_dropout_in_model(model)
+ if self.ref_model is not None:
+ disable_dropout_in_model(self.ref_model)
+
+ # Cast LM Head To FP32
+ if args.cast_lm_head_to_fp32:
+
+ def _cast_lm_head_to_fp32(target_model: PreTrainedModel):
+ """Cast lm_head to fp32 while preserving embedding output dtype if tied."""
+
+ def cast_inputs_to_fp32(module, inputs):
+ # Preserve other positional args and kwargs untouched
+ if not inputs:
+ return inputs
+ return (inputs[0].to(torch.float32),) + inputs[1:]
+
+ original_dtype_local = target_model.lm_head.weight.dtype
+ target_model.lm_head = target_model.lm_head.float()
+ target_model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
+
+ if target_model.config.tie_word_embeddings:
+
+ def cast_outputs_to_original_dtype(module, args, output):
+ return output.to(original_dtype_local)
+
+ # Only cast activations; weights are now fp32 (intentional for numerical stability of logits)
+ target_model.model.embed_tokens.register_forward_hook(cast_outputs_to_original_dtype)
+
+ _cast_lm_head_to_fp32(model)
+ if self.ref_model is not None:
+ _cast_lm_head_to_fp32(self.ref_model)
+
+ # Liger loss
+ if self.use_liger_kernel:
+ if not is_liger_kernel_available():
+ raise ImportError(
+ "Liger is required to use `use_liger_kernel` as the TargetPO loss. Run `pip install liger-kernel`."
+ )
+ # redirect the model.module forward to the model forward to ensure pre-forward hooks are called
+ self._forward_redirection = _ForwardRedirection()
+
+ self.liger_tpo_loss = LigerFusedLinearTargetPOLoss(
+ beta=self.beta,
+ temperature=self.temperature,
+ num_generations=self.num_generations,
+ tpo_length_normalize_logps=self.tpo_length_normalize_logps,
+ use_bias_correction_kl=args.use_bias_correction_kl,
+ )
+
+ # Initialize the metrics
+ self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
+ self._total_train_tokens = 0
+ self._current_train_step_time = 0.0
+ self.log_completions = args.log_completions
+ self.log_unique_prompts = args.log_unique_prompts
+ self.num_completions_to_print = args.num_completions_to_print
+ # Keep logs sized to the generation batch to record only outputs from the latest model update.
+ self._logs = {
+ "images": deque(maxlen=args.generation_batch_size),
+ "prompt": deque(maxlen=args.generation_batch_size),
+ "completion": deque(maxlen=args.generation_batch_size),
+ "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
+ "advantages": deque(maxlen=args.generation_batch_size),
+ "extra": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
+ }
+ # Buffers for user-logged data from reward functions, flushed after gathering
+ self._pending_extra_logs = defaultdict(list)
+ self._pending_metrics = defaultdict(list)
+
+ # Ensure each process receives a unique seed to prevent duplicate completions when generating with
+ # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
+ # it's safer to set it in all cases.
+ set_seed(args.seed, device_specific=True)
+
+ if self.use_vllm:
+ # Initialize vLLM generation backend
+ self.vllm_generation = VLLMGeneration(
+ model=self.model,
+ accelerator=self.accelerator,
+ is_fsdp_enabled=self.is_fsdp_enabled,
+ processing_class=self.processing_class,
+ # vLLM configuration
+ mode=args.vllm_mode,
+ structured_outputs_regex=args.vllm_structured_outputs_regex,
+ # Server mode configuration
+ server_base_url=args.vllm_server_base_url,
+ server_host=args.vllm_server_host,
+ server_port=args.vllm_server_port,
+ group_port=args.vllm_group_port,
+ server_timeout=args.vllm_server_timeout,
+ # Colocate mode configuration
+ tensor_parallel_size=args.vllm_tensor_parallel_size,
+ gpu_memory_utilization=args.vllm_gpu_memory_utilization,
+ max_model_length=args.vllm_max_model_length,
+ max_num_seqs=args.per_device_train_batch_size
+ * args.vllm_tensor_parallel_size
+ * args.steps_per_generation,
+ enable_sleep_mode=args.vllm_enable_sleep_mode,
+ model_impl=args.vllm_model_impl,
+ # Generation configuration
+ repetition_penalty=self.repetition_penalty,
+ temperature=self.temperature,
+ top_p=self.top_p,
+ top_k=self.top_k,
+ min_p=self.min_p,
+ max_completion_length=self.max_completion_length,
+ logprobs=0, # we only need the generated token logprobs for the importance sampling correction
+ generation_kwargs=args.generation_kwargs,
+ )
+ self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation
+ else:
+ generation_kwargs = {
+ "max_new_tokens": self.max_completion_length,
+ "do_sample": True,
+ "pad_token_id": tokenizer.pad_token_id,
+ "bos_token_id": tokenizer.bos_token_id,
+ "eos_token_id": tokenizer.eos_token_id,
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "top_k": self.top_k,
+ "min_p": self.min_p,
+ "repetition_penalty": self.repetition_penalty,
+ "cache_implementation": args.cache_implementation,
+ }
+ if args.generation_kwargs is not None:
+ generation_kwargs.update(args.generation_kwargs)
+ self.generation_config = GenerationConfig(**generation_kwargs, disable_compile=True)
+ # Keep training-specific generation kwargs to overwrite model's original generation config
+ self.generation_kwargs = generation_kwargs
+
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+ # self.model_accepts_loss_kwargs to False to enable scaling.
+ self.model_accepts_loss_kwargs = False
+
+ # Add tags to the model
+ self.model.add_model_tags(self._tag_names)
+
+ if self.ref_model is not None:
+ if self.is_deepspeed_enabled:
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+ elif self.is_fsdp_enabled:
+ self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
+ else:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+ if args.sync_ref_model:
+ if self.beta == 0.0:
+ raise ValueError(
+ "You passed `sync_ref_model=True` while `beta=0.0`, which means the reference model is not used "
+ "during training. Consequently, TargetPOTrainer does not create a `ref_model` instance, and there is "
+ "nothing to synchronize. Please set `sync_ref_model=False`, or set `beta` to a non-zero value."
+ )
+ if is_peft_model(model):
+ raise NotImplementedError(
+ "You passed `sync_ref_model=True` while using a PEFT model, which is currently not supported. "
+ "With PEFT, TargetPOTrainer does not keep a separate reference model in memory; instead, it recovers "
+ "reference behavior by temporarily disabling the adapter. As a result, there is no standalone "
+ "`ref_model` instance to synchronize. Use `sync_ref_model=False`, or opt for full fine-tuning if "
+ "you need a synced reference model. If you need `sync_ref_model` to work with PEFT, please open a "
+ "feature request at https://github.com/huggingface/trl/issues."
+ )
+ self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
+
+ for i, reward_func in enumerate(self.reward_funcs):
+ if isinstance(reward_func, PreTrainedModel):
+ if self.is_deepspeed_enabled:
+ self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
+ else:
+ # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
+ self.reward_funcs[i] = self.accelerator.prepare_model(
+ reward_func, evaluation_mode=True, device_placement=True
+ )
+
+ if self.accelerator.is_main_process and self.log_completions:
+ os.makedirs(os.path.join(self.args.output_dir, "completions"), exist_ok=True)
+ if self.args.log_completions_hub_repo is not None:
+ repo_id = self.args.log_completions_hub_repo
+ create_repo(repo_id, private=self.args.hub_private_repo, repo_type="dataset", exist_ok=True)
+ template_path = pkg_resources.files("trl").joinpath("templates/completions_dataset_card.md")
+ card_data = DatasetCardData(
+ pretty_name="TRL Completion logs",
+ tags=["trl", "trl-logs", "completions"],
+ )
+ card = DatasetCard.from_template(
+ card_data=card_data,
+ template_path=str(template_path),
+ repo_id=repo_id,
+ hub_model_id=self.args.hub_model_id,
+ )
+ card.push_to_hub(repo_id)
+ self.commit_scheduler = CommitScheduler(
+ repo_id=repo_id,
+ repo_type="dataset",
+ folder_path=f"{self.args.output_dir}/completions",
+ every=2, # minutes
+ allow_patterns=["*.parquet"],
+ )
+
+ def _set_signature_columns_if_needed(self):
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+ # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
+ # and "attention_mask"). In TargetPOTrainer, we preprocess data, so using the model's signature columns doesn't
+ # work. Instead, we set them to the columns expected by the `training_step` method, hence the override.
+ if self._signature_columns is None:
+ self._signature_columns = ["prompt", "image", "images"]
+
+ # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy.
+    # Instead of returning a standard per-step batch (i.e., `per_device_batch_size`), our dataloader loads a
+ # *generation* batch (i.e., `per_device_batch_size × steps_per_generation`). This allows us to generate completions
+ # once every steps_per_generation step—rather than once per accumulation step—which is significantly more
+ # efficient. The only change from the original implementation is multiplying the batch size by
+ # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the
+ # splitting internally.
+ # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line
+ # modification.
+ def get_train_dataloader(self):
+ return self._get_dataloader(
+ dataset=self.train_dataset,
+ description="Training",
+ batch_size=self._train_batch_size * self.args.steps_per_generation, # < this is the change
+ sampler_fn=self._get_train_sampler,
+ is_training=True,
+ )
+
+ def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler:
+ # Returns a sampler that
+ # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are
+ # distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt
+ # group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies
+ # in group formation.
+ # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to
+ # _prepare_inputs to see how the generations are stored and reused.
+
+ # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the
+ # second row shows the second sampled batch, and so on.
+ #
+ # | GPU 0 | GPU 1 |
+ #
+ # global_step step <-───> num_generations=2
+ # <-───────> per_device_train_batch_size=3
+ # grad_accum ▲ ▲ 0 0 0 0 1 1 2 2 <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss
+ # =2 ▼ | 0 1 3 3 4 4 5 5 <- Take the stored generations and use the second slice to compute the loss
+ # |
+ # | 1 2 6 6 7 7 8 8 <- Take the stored generations and use the third slice to compute the loss
+ # steps_per_gen=4 ▼ 1 3 9 9 10 10 11 11 <- Take the stored generations and use the fourth slice to compute the loss
+ #
+ # 2 4 12 12 13 13 14 14 <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss
+ # 2 5 15 15 16 16 17 17 <- Take the stored generations and use the second slice to compute the loss
+ # ...
+ if dataset is None:
+ dataset = self.train_dataset
+ return RepeatSampler(
+ data_source=dataset,
+ mini_repeat_count=self.num_generations,
+ batch_size=self.args.generation_batch_size // self.num_generations,
+ repeat_count=self.num_iterations * self.args.steps_per_generation,
+ shuffle=self.shuffle_dataset,
+ seed=self.args.seed,
+ )
+
+ def _get_eval_sampler(self, eval_dataset) -> Sampler:
+ # See _get_train_sampler for an explanation of the sampler.
+ return RepeatSampler(
+ data_source=eval_dataset,
+ mini_repeat_count=self.num_generations_eval,
+ seed=self.args.seed,
+ )
+
+ @profiling_decorator
+ def _get_last_hidden_state(
+ self,
+ unwrapped_model,
+ input_ids,
+ attention_mask,
+ logits_to_keep,
+ pixel_values=None,
+ image_grid_thw=None,
+ pixel_attention_mask=None,
+ image_sizes=None,
+ image_position_ids=None,
+ ):
+ if is_peft_model(unwrapped_model):
+ unwrapped_model = unwrapped_model.base_model.model
+
+ # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
+ model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
+
+ # For Qwen models:
+ if image_grid_thw is not None and pixel_values is not None:
+ model_inputs["image_grid_thw"] = image_grid_thw
+ # For Gemma, SmolVLM2, LLaVa-Next etc.:
+ if pixel_values is not None:
+ model_inputs["pixel_values"] = pixel_values
+ # For SmolVLM2
+ if pixel_attention_mask is not None:
+ model_inputs["pixel_attention_mask"] = pixel_attention_mask
+ # For LLaVa-Next
+ if image_sizes is not None:
+ model_inputs["image_sizes"] = image_sizes
+ if image_position_ids is not None:
+ model_inputs["image_position_ids"] = image_position_ids
+
+ # Only add logits_to_keep if the model supports it
+ if "logits_to_keep" in self.model_kwarg_keys:
+            # We add 1 to `logits_to_keep` because the logits for the last token are later excluded
+ model_inputs["logits_to_keep"] = logits_to_keep + 1
+
+ model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings
+
+ last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state
+ # Exclude the last value: it corresponds to the next token pred
+ last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H)
+        # Only keep the last logits_to_keep. For models that support logits_to_keep, this is a no-op.
+ last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H)
+ return last_hidden_state
+
+ def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor:
+ """
+ Returns a binary mask identifying tokens whose entropy exceeds a given quantile threshold.
+
+ Args:
+ entropies (`torch.Tensor`):
+ Tensor of shape (batch_size, seq_len) with per-token entropy values.
+ mask (`torch.Tensor`):
+ Binary mask of the same shape as `entropies`, where `1` indicates valid tokens and `0` padding.
+ threshold (`float`):
+ Quantile threshold between `0.0` and `1.0` to select high-entropy tokens.
+
+ Returns:
+ `torch.Tensor`:
+                Boolean mask of shape (batch_size, seq_len), where `True` indicates non-padding tokens whose entropy
+                is greater than or equal to the entropy value at the given quantile, and `False` otherwise.
+ """
+ local = entropies[mask.bool()].float()
+
+ # Use a negative pad_value as a sentinel because entropy values are always >= 0.
+ # This guarantees that the sentinel cannot collide with any real entropy value.
+ pad_value = -1e9
+
+ # Pad across processes so that every rank has the same tensor length
+ padded = self.accelerator.pad_across_processes(local, dim=0, pad_index=pad_value)
+ gathered = self.accelerator.gather(padded)
+
+ # Drop sentinel values (safe because no entropy can be negative)
+ gathered = gathered[gathered != pad_value]
+
+ if gathered.numel() == 0:
+ return torch.zeros_like(entropies, dtype=torch.bool)
+
+ entropy_threshold = torch.quantile(gathered, threshold)
+ masked_entropies = entropies * mask.float()
+ entropy_mask = masked_entropies >= entropy_threshold
+ return entropy_mask & mask.bool() # ensure padding tokens are always masked out
+
+ @profiling_decorator
+ def _get_per_token_logps_and_entropies(
+ self,
+ model,
+ input_ids,
+ attention_mask,
+ logits_to_keep,
+ batch_size=None,
+ compute_entropy=False,
+ pixel_values=None,
+ image_grid_thw=None,
+ num_images=None,
+ pixel_attention_mask=None,
+ image_sizes=None,
+ token_type_ids=None,
+ mm_token_type_ids=None,
+ image_position_ids=None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ """Compute log-probs and (optionally) entropies for each token."""
+ batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak
+ all_logps = []
+ all_entropies = []
+ for start in range(0, input_ids.size(0), batch_size):
+ input_ids_batch = input_ids[start : start + batch_size]
+ attention_mask_batch = attention_mask[start : start + batch_size]
+
+ # Build model inputs
+ model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
+ if image_grid_thw is not None and pixel_values is not None:
+ rows_per_image = image_grid_thw.prod(dim=-1)
+ rows_per_sample = torch.split(rows_per_image, num_images)
+ rows_per_sample = torch.stack([s.sum() for s in rows_per_sample])
+ cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)])
+ row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item()
+ model_inputs["pixel_values"] = pixel_values[row_start:row_end]
+ cum_imgs = torch.tensor([0] + num_images).cumsum(0)
+ img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
+ model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end]
+ elif image_position_ids is not None and pixel_values is not None:
+ cum_imgs = torch.tensor([0] + num_images).cumsum(0)
+ img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
+ model_inputs["pixel_values"] = pixel_values[img_start:img_end]
+ model_inputs["image_position_ids"] = image_position_ids[img_start:img_end]
+ elif pixel_values is not None:
+ model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
+ if pixel_attention_mask is not None:
+ model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
+ if image_sizes is not None:
+ model_inputs["image_sizes"] = image_sizes[start : start + batch_size]
+ if token_type_ids is not None:
+ model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size]
+ if mm_token_type_ids is not None:
+ model_inputs["mm_token_type_ids"] = mm_token_type_ids[start : start + batch_size]
+
+ # Only add logits_to_keep if the model supports it
+ if "logits_to_keep" in self.model_kwarg_keys:
+                # We add 1 to `logits_to_keep` because the logits for the last token are later excluded
+ model_inputs["logits_to_keep"] = logits_to_keep + 1
+
+ model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings
+
+ logits = model(**model_inputs).logits
+ # Exclude the last value: it corresponds to the next token pred
+ logits = logits[:, :-1, :] # (B, L-1, H)
+            # Only keep the last logits_to_keep. For models that support logits_to_keep, this is a no-op.
+ logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, H)
+ # Divide logits by sampling temperature.
+ # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+ logits.div_(self.temperature)
+ completion_ids = input_ids_batch[:, -logits_to_keep:]
+ logps = selective_log_softmax(logits, completion_ids) # compute logprobs
+ all_logps.append(logps)
+
+ if compute_entropy:
+ with torch.no_grad():
+ entropies = entropy_from_logits(logits)
+ all_entropies.append(entropies)
+
+ logps = torch.cat(all_logps, dim=0)
+ entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None
+ return logps, entropies
+
+ def training_step(self, model, inputs, num_items_in_batch):
+ time_before = time.perf_counter()
+ output = super().training_step(model, inputs, num_items_in_batch)
+ self._step += 1
+ time_after = time.perf_counter()
+ self._current_train_step_time += time_after - time_before
+ if self._step % self.current_gradient_accumulation_steps == 0:
+ self._metrics["train"]["step_time"].append(self._current_train_step_time)
+ self._current_train_step_time = 0.0
+ return output
+
+ @profiling_decorator
+ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
+ # Prepares inputs for model training/evaluation by managing completion generation and batch handling.
+ # During training:
+ # - Receives the local generation batch (Per-GPU batch size × steps per generation)
+ # from the modified training dataloader instead of the standard local batch
+ # - Generates completions once for the entire generation batch and splits it into batches of size
+ # `per_device_train_batch_size`
+ # - Buffers these completions and returns the appropriate slice for the current accumulation step
+ # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations)
+ # During evaluation:
+ # - The input is treated as a standard local batch (no accumulation, no multiple iterations)
+ # - Completions are generated for each batch without buffering or reuse
+ # Returns a single local batch in both cases.
+
+ mode = "train" if self.model.training else "eval"
+ if mode == "train":
+ generate_every = self.args.steps_per_generation * self.num_iterations
+ if self._step % generate_every == 0 or self._buffered_inputs is None:
+ # self._buffered_inputs=None can occur when resuming from a checkpoint
+ generation_batch = self._generate_and_score_completions(generation_batch)
+ generation_batch = split_pixel_values_by_grid(generation_batch)
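+ # For TPO, keep each prompt's group of completions contiguous and in generation order (the group-level targets assume this), so skip shuffling.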
+ if self.loss_type != "tpo":
+ generation_batch = shuffle_sequence_dict(generation_batch)
+ generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation)
+ self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches]
+ inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
+ else:
+ # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence
+ # local generation batch == local eval batch
+ inputs = self._generate_and_score_completions(generation_batch)
+ return inputs
+
+ def _log_completion_extra(self, column: str, values: list):
+ """
+ Log extra columns to the completions table. Called from reward functions via the `log_extra` kwarg.
+
+ Args:
+ column (`str`):
+ Name of the column to add.
+ values (`list`):
+ Values for the column, one per sample in the batch.
+ """
+ self._pending_extra_logs[column].extend(values)
+
+ def _log_metric(self, name: str, value: float):
+ """
+ Log a scalar metric from a reward function. Called via the `log_metric` kwarg. Values are averaged over each
+ logging step and reported alongside built-in metrics like `kl` and `entropy`.
+
+ Args:
+ name (`str`):
+ Name of the metric.
+ value (`float`):
+ Scalar value for this batch.
+ """
+ self._pending_metrics[name].append(value)
+
+ @profiling_decorator
+ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
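+ """
+ Compute one reward per reward function for every prompt/completion pair. Reward models run under
+ inference mode, synchronous callables run inline, and async callables are gathered on the shared event
+ loop; the resulting rewards-per-function tensor is gathered across processes before returning.
+ """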
+ device = self.accelerator.device
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+
+ # Repeat all input columns (except "prompt", "completion", and "completion_ids") to match the number of generations
+ keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
+ reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+
+ # This allows for dynamic reward shaping based on training progress.
+ reward_kwargs["trainer_state"] = self.state
+
+ # Allow reward functions to log extra columns to the completions table.
+ reward_kwargs["log_extra"] = self._log_completion_extra
+
+ # Allow reward functions to log additional scalar metrics.
+ reward_kwargs["log_metric"] = self._log_metric
+
+ async_funcs_info = [] # async custom functions for asyncio.gather
+
+ for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
+ zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names, strict=True)
+ ):
+ if isinstance(reward_func, nn.Module): # Module (not PreTrainedModel) for compatibility with compiled models
+ with profiling_context(self, reward_func_name):
+ if is_conversational(inputs[0]):
+ messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)]
+ texts = [
+ apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"]
+ for x in messages
+ ]
+ else:
+ texts = [p + c for p, c in zip(prompts, completions, strict=True)]
+ reward_inputs = reward_processing_class(
+ text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
+ )
+ reward_inputs = super()._prepare_inputs(reward_inputs)
+ with torch.inference_mode():
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
+ elif inspect.iscoroutinefunction(reward_func): # Separate async reward funcs to run them in parallel later
+ async_funcs_info.append((i, reward_func, reward_func_name))
+ else:
+ # Run synchronous reward function
+ with profiling_context(self, reward_func_name):
+ if self.environments is not None:
+ reward_kwargs["environments"] = self.environments
+ output_reward_func = reward_func(
+ prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
+ )
+ # Convert None values to NaN
+ output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+ # Execute async custom functions in parallel using asyncio.gather
+ if async_funcs_info:
+
+ async def _invoke_async(index, func, func_name):
+ with profiling_context(self, func_name):
+ output = await func(
+ prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
+ )
+ output = [r if r is not None else torch.nan for r in output]
+ return index, output
+
+ async def _run_async_funcs():
+ coros = [_invoke_async(i, func, func_name) for (i, func, func_name) in async_funcs_info]
+ return await asyncio.gather(*coros)
+
+ async_results = asyncio.run_coroutine_threadsafe(_run_async_funcs(), self.async_loop).result()
+ for idx, output_reward_func in async_results:
+ rewards_per_func[:, idx] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+ # If all reward functions return None for a given row, issue a detailed warning
+ if torch.isnan(rewards_per_func).all(dim=1).any():
+ nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
+ row_reward_kwargs = {
+ key: value[nan_row_idx]
+ for key, value in reward_kwargs.items()
+ if key not in ("trainer_state", "log_extra", "log_metric")
+ }
+ row_reward_kwargs["prompt"] = prompts[nan_row_idx]
+ row_reward_kwargs["completion"] = completions[nan_row_idx]
+ logger.warning(
+ f"All reward functions returned None for the following kwargs:\n{row_reward_kwargs}\n"
+ "Please ensure that at least one reward function returns a valid reward."
+ )
+
+ # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
+ # completions may be distributed across processes
+ rewards_per_func = gather(rewards_per_func)
+ return rewards_per_func
+
+ def _tokenize_prompts(self, prompts: list):
+ """Tokenize prompts and extract images/multimodal fields for generation."""
+ if is_conversational({"prompt": prompts[0]}):
+ # Normalize string content to content blocks for VLM processors that don't handle plain strings.
+ if self._is_vlm:
+ prompts = [prepare_multimodal_messages(prompt) for prompt in prompts]
+
+ # Extract images from messages for VLM support
+ images = []
+ has_images = False
+ for prompt in prompts:
+ prompt_images = []
+ for message in prompt:
+ if isinstance(message["content"], list):
+ for part in message["content"]:
+ if part["type"] == "image":
+ prompt_images.append(part["image"])
+ has_images = True
+ images.append(prompt_images if prompt_images else None)
+ images = images if has_images else None
+
+ # Workaround for a bug in transformers 5.3.0 where some processors (e.g. Qwen2.5-VL) crash on
+ # batched unpadded input (transformers#44514).
+ # Fixed in transformers 5.4.0 (transformers#44563).
+ needs_padding_workaround = Version("5.3.0") <= Version(transformers.__version__) < Version("5.4.0")
+ tokenized = self.processing_class.apply_chat_template(
+ conversation=prompts,
+ tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[]
+ chat_template=self.chat_template,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ **({"padding": True} if needs_padding_workaround else {}),
+ **self.chat_template_kwargs,
+ )
+ if needs_padding_workaround:
+ # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists
+ prompt_ids = [
+ [tok for tok, m in zip(ids, mask, strict=True) if m]
+ for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True)
+ ]
+ else:
+ prompt_ids = tokenized["input_ids"]
+ # For VLMs, the processor returns extra multimodal fields (pixel_values, image_grid_thw, etc.)
+ multimodal_fields = {k: v for k, v in tokenized.items() if k not in ("input_ids", "attention_mask")}
+ else:
+ prompt_ids = self.processing_class(text=prompts)["input_ids"]
+ images = None
+ multimodal_fields = {}
+ return prompt_ids, images, multimodal_fields
+
+ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
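+ """Generate one assistant turn per prompt via vLLM, transformers continuous batching, or regular generation; returns completion token IDs and, for vLLM, per-token sampling logprobs."""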
+ device = self.accelerator.device
+ mode = "train" if self.model.training else "eval"
+
+ # Generate completions using either vLLM or regular generation
+ if self.use_vllm:
+ # Sync weights if training step changed
+ if self.state.global_step != self._last_loaded_step:
+ with profiling_context(self, "sync_weights"):
+ self.vllm_generation.sync_weights()
+ self._last_loaded_step = self.state.global_step
+
+ # Generate using vLLM with raw token IDs
+ num_generations = self.num_generations if mode == "train" else self.num_generations_eval
+ _, completion_ids, logprobs, _ = self.vllm_generation.generate(
+ prompts=prompt_ids,
+ images=images,
+ num_generations=num_generations,
+ profiler=profiling_context(self, "vLLM.generate"),
+ )
+ # vLLM returns per-token top-k logprobs; keep only the top-1 (sampled token) logprob
+ logprobs = [[lp[0] for lp in seq] for seq in logprobs]
+
+ elif self.use_transformers_paged:
+ with (
+ profiling_context(self, "transformers.generate_batch"),
+ unwrap_model_for_generation(
+ self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+ ) as unwrapped_model,
+ torch.no_grad(),
+ FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
+ ):
+ # Cast to the appropriate dtype based on training configuration
+ if self.args.bf16:
+ unwrapped_model.to(torch.bfloat16)
+ elif self.args.fp16:
+ unwrapped_model.to(torch.float16)
+ if self.args.cast_lm_head_to_fp32:
+ unwrapped_model.lm_head.to(torch.float32)
+ with torch.inference_mode():
+ # Continuous batching API expects 'inputs' arg only
+ all_outputs = unwrapped_model.generate_batch(
+ prompt_ids, generation_config=self.generation_config, progress_bar=False
+ )
+ unwrapped_model.train() # restore training mode, as generate_batch forces eval mode
+ completion_ids = [output.generated_tokens for output in all_outputs.values()]
+ logprobs = None # not used in this case
+
+ else:
+ # Regular generation path: left-pad token IDs into tensors
+ prompt_tensors = [torch.tensor(ids) for ids in prompt_ids]
+ padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left")
+ attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left")
+ generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask}
+ # For VLMs, include multimodal fields as tensors (pixel_values, image_grid_thw, etc.)
+ for k, v in multimodal_fields.items():
+ if isinstance(v, torch.Tensor):
+ generate_inputs[k] = v
+ elif isinstance(v, list) and v and isinstance(v[0], list):
+ # Per-token field (e.g., token_type_ids): left-pad like input_ids
+ generate_inputs[k] = pad([torch.tensor(x) for x in v], padding_value=0, padding_side="left")
+ else:
+ generate_inputs[k] = torch.tensor(np.array(v))
+ generate_inputs = super()._prepare_inputs(generate_inputs)
+
+ with (
+ profiling_context(self, "transformers.generate"),
+ unwrap_model_for_generation(
+ self.model_wrapped,
+ self.accelerator,
+ gather_deepspeed3_params=self.args.ds3_gather_for_generation,
+ generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762
+ ) as unwrapped_model,
+ torch.no_grad(),
+ FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
+ ):
+ prompt_completion_ids = unwrapped_model.generate(
+ **generate_inputs, generation_config=self.generation_config
+ )
+ # Compute prompt length and extract completion ids
+ prompt_length = generate_inputs["input_ids"].size(1)
+ completion_ids = prompt_completion_ids[:, prompt_length:]
+
+ # Mask everything after the first EOS token
+ is_eos = completion_ids == self.eos_token_id
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+ completion_ids = [
+ c[m].tolist() for c, m in zip(completion_ids.cpu(), completion_mask.bool().cpu(), strict=True)
+ ]
+ logprobs = None # not used in this case
+
+ return completion_ids, logprobs
+
+ def _get_tool_suffix_ids(self, tool_messages):
+ """Get token IDs for tool result formatting by using a minimal dummy conversation."""
+ # Use the real tool name instead of a dummy: some templates (e.g. GPT-OSS) derive the tool response
+ # header from the assistant's tool call name.
+ dummy_tool_calls = [{"type": "function", "function": {"name": tool_messages[0]["name"], "arguments": {}}}]
+ dummy_messages = [
+ {"role": "user", "content": "dummy"},
+ {
+ "role": "assistant",
+ # "content" is required here because VLM processors crash on tokenize=True without it
+ # (KeyError in processing_utils.py). See huggingface/transformers#45290.
+ "content": "",
+ "tool_calls": dummy_tool_calls,
+ },
+ ]
+ if self._is_vlm:
+ dummy_messages = prepare_multimodal_messages(dummy_messages)
+ tool_messages = prepare_multimodal_messages(tool_messages)
+
+ prefix_ids = self.processing_class.apply_chat_template(
+ dummy_messages,
+ add_generation_prompt=False,
+ tokenize=True,
+ chat_template=self.chat_template,
+ return_dict=False,
+ **self.chat_template_kwargs,
+ )
+ full_ids = self.processing_class.apply_chat_template(
+ dummy_messages + tool_messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ chat_template=self.chat_template,
+ return_dict=False,
+ **self.chat_template_kwargs,
+ )
+ # VLM processors return batched output (list of lists), unbatch for single conversation
+ if self._is_vlm:
+ prefix_ids = prefix_ids[0]
+ full_ids = full_ids[0]
+
+ # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block.
+ # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to
+ # EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses
+ # <end_of_turn>) skip this trimming.
+ eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.eos_token_id]
+ if eos_positions:
+ prefix_ids = prefix_ids[: eos_positions[-1] + 1]
+
+ if full_ids[: len(prefix_ids)] != prefix_ids:
+ raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.")
+ return full_ids[len(prefix_ids) :]
+
+ def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields):
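+ """
+ Execute requested tool calls and regenerate completions with the tool results appended, repeating for
+ up to `max_tool_calling_iterations`. Returns the tool mask (0 for tool-result tokens, 1 elsewhere),
+ the updated completions, completion IDs and logprobs, tool call/failure counts, and tool-returned images.
+ """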
+ # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt
+ tool_calls = [completion[0].get("tool_calls") for completion in completions]
+ idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call]
+ tool_calls = [tool_calls[idx] for idx in idxs_with_tool]
+ tool_mask = [[1] * len(ids) for ids in completion_ids] # 0 for tool result tokens, 1 elsewhere
+ # Collect images from multimodal tool responses for the forward pass
+ tool_images = [[] for _ in completion_ids]
+ tool_call_count = 0
+ tool_failure_count = 0
+ iteration_num = 0
+
+ while idxs_with_tool and iteration_num < self.max_tool_calling_iterations:
+ prompt_completion_tools = [prompts[i] for i in idxs_with_tool] # select only prompts that need tool calls
+ # Snapshot state so we can rollback tool results that would exceed max_completion_length
+ completions_len_before = [len(completions[i]) for i in idxs_with_tool]
+ tool_images_len_before = [len(tool_images[i]) for i in idxs_with_tool]
+ prompts_len_before = [len(prompts[i]) for i in idxs_with_tool]
+
+ # Call the tools, and build the new prompt for generation
+ for idx in range(len(idxs_with_tool)):
+ idx_with_tool = idxs_with_tool[idx]
+ tool_call_list = tool_calls[idx]
+ prompt_completion_tool = prompt_completion_tools[idx]
+ sync_tool_dict = self._sync_tool_dicts[idx_with_tool]
+ async_tool_dict = self._async_tool_dicts[idx_with_tool]
+ # Append the last assistant message (which triggered tool_calls) to the prompt
+ prompt_completion_tool.append(completions[idx_with_tool][-1])
+ async_coros = []
+ tool_call_results = []
+ for tool_call in tool_call_list:
+ tool_call_count += 1
+ if tool_call["type"] == "function":
+ function = tool_call["function"]
+ name = function["name"]
+ try:
+ if name in sync_tool_dict:
+ tool_call_results.append((name, sync_tool_dict[name](**function["arguments"])))
+ elif name in async_tool_dict:
+ async_coros.append((name, async_tool_dict[name](**function["arguments"])))
+ else:
+ raise ValueError(f"Tool {name} not found.")
+ except Exception as e:
+ tool_failure_count += 1
+ result = {"error": str(e)}
+ tool_call_results.append((name, result))
+ else:
+ tool_failure_count += 1
+ name = tool_call.get("name", "unknown")
+ tool_call_results.append((name, {"error": f"Unsupported tool call type: {tool_call['type']}"}))
+
+ if async_coros:
+
+ async def _run_async_tools(async_coros):
+ coros = [coro for _, coro in async_coros]
+ results = await asyncio.gather(*coros, return_exceptions=True)
+ return [(name, result) for (name, _), result in zip(async_coros, results, strict=False)]
+
+ async_results = asyncio.run_coroutine_threadsafe(
+ _run_async_tools(async_coros), self.async_loop
+ ).result()
+
+ for name, result in async_results:
+ if isinstance(result, Exception):
+ tool_failure_count += 1
+ tool_call_results.append((name, {"error": str(result)}))
+ else:
+ tool_call_results.append((name, result))
+
+ for name, result in tool_call_results:
+ # Support multimodal tool responses: if the tool returns a list of content blocks
+ # (e.g., [{"type": "image", "image": ...}, {"type": "text", "text": "..."}]),
+ # pass them through directly so _tokenize_prompts can extract images for VLMs.
+ content = result if isinstance(result, list) else str(result)
+ tool_message = {"role": "tool", "name": name, "content": content}
+ # Collect images from multimodal tool responses
+ if isinstance(content, list):
+ for part in content:
+ if isinstance(part, dict) and part.get("type") == "image":
+ tool_images[idx_with_tool].append(part["image"])
+ prompt_completion_tool.append(tool_message)
+ completions[idx_with_tool].append(tool_message)
+
+ # Build token IDs by concatenation: prompt + completion + tool_suffix.
+ prompt_completion_tool_ids = []
+ for idx in range(len(idxs_with_tool)):
+ idx_with_tool = idxs_with_tool[idx]
+ # Extract trailing tool messages from completions
+ tool_messages = []
+ for message in reversed(completions[idx_with_tool]):
+ if message["role"] == "tool":
+ tool_messages.insert(0, message)
+ else:
+ break
+ suffix_ids = self._get_tool_suffix_ids(tool_messages)
+ prompt_completion_tool_ids.append(
+ prompt_ids[idx_with_tool] + completion_ids[idx_with_tool] + suffix_ids
+ )
+
+ # Drop tool results whose addition would push the sequence past max_completion_length (the completion
+ # budget) or past the backend context ceiling (vLLM and transformers will error out on inputs longer than
+ # the model's max length). The sample exits the loop with its completion as-is, and the tool
+ # messages/images appended this iteration are rolled back so completions and tool_images stay consistent
+ # with completion_ids.
+ if self.use_vllm and self.vllm_mode == "colocate":
+ max_model_len = self.vllm_generation.llm.llm_engine.model_config.max_model_len
+ else:
+ config = self.model.config.text_config if self._is_vlm else self.model.config
+ max_model_len = config.max_position_embeddings
+ overlong = [
+ len(pct) - len(prompt_ids[i]) > self.max_completion_length or len(pct) >= max_model_len
+ for i, pct in zip(idxs_with_tool, prompt_completion_tool_ids, strict=True)
+ ]
+ for idx in range(len(idxs_with_tool)):
+ if overlong[idx]:
+ idx_with_tool = idxs_with_tool[idx]
+ del completions[idx_with_tool][completions_len_before[idx] :]
+ del tool_images[idx_with_tool][tool_images_len_before[idx] :]
+ del prompts[idx_with_tool][prompts_len_before[idx] :]
+ # Keep only non-overlong items for further processing
+ idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o]
+ prompt_completion_tool_ids = [
+ pct for pct, o in zip(prompt_completion_tool_ids, overlong, strict=True) if not o
+ ]
+ if not idxs_with_tool:
+ break # all overlong, exit tool loop
+
+ # Filter images and multimodal fields to match the current subset (index into full batch).
+ # Merge tool response images so the model can see visual feedback during generation.
+ merged_images = images
+ if any(imgs for imgs in tool_images):
+ if merged_images is None:
+ merged_images = [imgs if imgs else None for imgs in tool_images]
+ else:
+ merged_images = [
+ (existing or []) + new for existing, new in zip(merged_images, tool_images, strict=True)
+ ]
+ loop_images = [merged_images[i] for i in idxs_with_tool] if merged_images else None
+ if multimodal_fields:
+ loop_multimodal_fields = {}
+ for k, v in multimodal_fields.items():
+ selected = [v[i] for i in idxs_with_tool]
+ # Per-token fields (e.g. token_type_ids) need zero-padding to match extended prompt length
+ if isinstance(selected[0], list):
+ selected = [
+ s + [0] * (len(pct) - len(s))
+ for s, pct in zip(selected, prompt_completion_tool_ids, strict=True)
+ ]
+ loop_multimodal_fields[k] = selected
+ else:
+ loop_multimodal_fields = {}
+
+ # Generate new completions after tool execution (using concatenated IDs, no re-tokenization)
+ post_tool_ids, post_tool_logprobs = self._generate_single_turn(
+ prompt_completion_tool_ids, loop_images, loop_multimodal_fields
+ )
+
+ # Truncate so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length.
+ # The pre-regen check guarantees len(completion_tool_ids) <= max_completion_length, so any
+ # excess can only come from post_tool_ids. post_tool_ids is model-generated text and never
+ # contains image tokens, so a plain slice is safe.
+ for idx in range(len(idxs_with_tool)):
+ idx_with_tool = idxs_with_tool[idx]
+ completion_tool_length = len(prompt_completion_tool_ids[idx]) - len(prompt_ids[idx_with_tool])
+ excess_length = completion_tool_length + len(post_tool_ids[idx]) - self.max_completion_length
+ if excess_length > 0:
+ new_len = len(post_tool_ids[idx]) - excess_length
+ post_tool_ids[idx] = post_tool_ids[idx][:new_len]
+ if logprobs is not None:
+ post_tool_logprobs[idx] = post_tool_logprobs[idx][:new_len]
+
+ # Update tool_mask: the tool result should be 0 and the post-tool 1
+ for idx in range(len(idxs_with_tool)):
+ idx_with_tool = idxs_with_tool[idx]
+ prompt_completion_tool_length = len(prompt_completion_tool_ids[idx])
+ prompt_length = len(prompt_ids[idx_with_tool])
+ completion_length = len(completion_ids[idx_with_tool])
+ post_tool_length = len(post_tool_ids[idx])
+ tool_length = prompt_completion_tool_length - prompt_length - completion_length
+ tool_mask[idx_with_tool] += [0] * tool_length + [1] * post_tool_length
+ if logprobs is not None:
+ logprobs[idx_with_tool] += [0.0] * tool_length + post_tool_logprobs[idx]
+
+ # Update completion_ids with the new completions (after tool execution)
+ for idx in range(len(idxs_with_tool)):
+ idx_with_tool = idxs_with_tool[idx]
+ prompt_length = len(prompt_ids[idx_with_tool])
+ pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool
+ completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx]
+
+ # Decode post-tool completions.
+ post_tool_completions = [
+ parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids
+ ]
+
+ # Add post-tool completions to the existing completions
+ for idx in range(len(idxs_with_tool)):
+ idx_with_tool = idxs_with_tool[idx]
+ if post_tool_completions[idx]: # {} if post-tool completions completely truncated
+ completions[idx_with_tool].append(post_tool_completions[idx])
+
+ # Check for further tool calls
+ tool_calls = [completion.get("tool_calls") for completion in post_tool_completions]
+ idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call]
+ tool_calls = [tool_call for tool_call in tool_calls if tool_call]
+ iteration_num += 1
+
+ return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count, tool_images
+
+ def _generate(self, prompts: list):
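+ """
+ Produce completions for a batch of prompts, either through a custom `rollout_func` or the built-in
+ tokenize-and-generate path, run the tool-calling loop when tools are configured, and log length and
+ tool-usage metrics. Returns prompt/completion token IDs, the tool mask, decoded completions, the total
+ completion token count, sampling logprobs, extra rollout fields, and images.
+ """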
+ device = self.accelerator.device
+ mode = "train" if self.model.training else "eval"
+
+ # Copy the prompts to avoid modifying the original list
+ prompts = copy.deepcopy(prompts)
+
+ if self.rollout_func is not None:
+ # Keep vLLM weights in sync for custom rollouts that rely on vLLM utilities.
+ if self.use_vllm and self.state.global_step != self._last_loaded_step:
+ with profiling_context(self, "sync_weights"):
+ self.vllm_generation.sync_weights()
+ self._last_loaded_step = self.state.global_step
+
+ # Pass prompts to rollout_func preserving structured messages.
+ # Chat templating must happen inside rollout_func, at the backend boundary, so that
+ # multimodal content (images, typed content blocks) is not lost before rollout logic runs.
+ output = self.rollout_func(prompts, self)
+ required_keys = {"prompt_ids", "completion_ids", "logprobs"}
+ missing_keys = required_keys - output.keys()
+ if missing_keys:
+ missing_keys_list = sorted(missing_keys)
+ raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.")
+ extra_fields = {k: v for k, v in output.items() if k not in required_keys}
+ prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"]
+ images = None
+ multimodal_fields = {}
+ else:
+ prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts)
+ completion_ids, logprobs = self._generate_single_turn(prompt_ids, images, multimodal_fields)
+ extra_fields = {}
+
+ # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls.
+ if is_conversational({"prompt": prompts[0]}):
+ tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class
+ if (
+ Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5
+ and hasattr(tokenizer, "response_schema") # attribute not set by default for now
+ and tokenizer.response_schema is not None # only works if the tokenizer has a schema
+ ):
+ completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids]
+ else:
+ contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+ completions = [[{"role": "assistant", "content": content}] for content in contents]
+ else:
+ completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+
+ # Extract tool calls from the completions and (possibly) execute them
+ tool_images = []
+ if self.tools:
+ (
+ tool_mask,
+ completions,
+ completion_ids,
+ logprobs,
+ tool_call_count,
+ tool_failure_count,
+ tool_images,
+ ) = self._tool_call_loop(
+ prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields
+ )
+ # Merge tool response images into the images list for the forward pass
+ if any(imgs for imgs in tool_images):
+ if images is None:
+ images = [imgs if imgs else None for imgs in tool_images]
+ else:
+ images = [(existing or []) + new for existing, new in zip(images, tool_images, strict=True)]
+ else:
+ # Support custom env_mask from rollout_func (e.g., for environment feedback masking)
+ # Internally treated as tool_mask: it marks model tokens (1) vs external tokens (0)
+ tool_mask = extra_fields.pop("env_mask", None)
+
+ # Get completion length per sequence, used for logging
+ prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
+ if tool_mask is not None: # count only model-generated tokens (tool_mask=1)
+ completion_lengths = torch.tensor([sum(mask) for mask in tool_mask], device=device)
+ else:
+ completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device)
+ agg_prompt_lengths = self.accelerator.gather(prompt_lengths)
+ agg_completion_lengths = self.accelerator.gather(completion_lengths)
+ total_prompt_tokens = agg_prompt_lengths.sum()
+ total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss
+
+ # Log the metrics
+ if mode == "train":
+ self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item()
+ self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
+
+ # Log completion lengths, mean, min, max
+ self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
+ self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
+ self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
+
+ # Identify sequences that terminated with EOS and log their lengths
+ eos_and_pad = [self.eos_token_id, self.pad_token_id]
+ is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device)
+ agg_is_truncated = self.accelerator.gather(is_truncated)
+ self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item())
+ term_completion_lengths = agg_completion_lengths[~agg_is_truncated]
+ if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found
+ term_completion_lengths = torch.zeros(1, device=device)
+ self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
+ self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
+ self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
+
+ if self.tools:
+ agg_tool_call_count = self.accelerator.gather(torch.tensor(tool_call_count, device=device)).sum()
+ tool_call_frequency = (agg_tool_call_count / len(agg_prompt_lengths)).item()
+ self._metrics[mode]["tools/call_frequency"].append(tool_call_frequency)
+ agg_tool_failure_count = self.accelerator.gather(torch.tensor(tool_failure_count, device=device)).sum()
+ failure_frequency = (
+ (agg_tool_failure_count / agg_tool_call_count).item() if agg_tool_call_count > 0 else 0.0
+ )
+ self._metrics[mode]["tools/failure_frequency"].append(failure_frequency)
+
+ return (
+ prompt_ids,
+ completion_ids,
+ tool_mask,
+ completions,
+ total_completion_tokens,
+ logprobs,
+ extra_fields,
+ images,
+ tool_images,
+ )
+
+ def _generate_and_score_completions(
+ self, inputs: list[dict[str, torch.Tensor | Any]]
+ ) -> dict[str, torch.Tensor | Any]:
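+ """
+ Generate and score completions for the local generation batch: compute rewards, group-normalized
+ advantages (and, for TPO, the per-group target distribution), old/reference per-token log-probs when
+ needed, and return the padded tensors consumed by the loss.
+ """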
+ device = self.accelerator.device
+ mode = "train" if self.model.training else "eval"
+
+ prompts = [x["prompt"] for x in inputs]
+
+ if self.environments:
+ for prompt, environment, reset_kwargs in zip(prompts, self.environments, inputs, strict=True):
+ observation = environment.reset(**reset_kwargs)
+ if observation is None:
+ continue
+ if isinstance(observation, list) and isinstance(prompt[-1]["content"], str):
+ prompt[-1]["content"] = [{"type": "text", "text": prompt[-1]["content"]}]
+ if isinstance(observation, str) and isinstance(prompt[-1]["content"], list):
+ observation = [{"type": "text", "text": observation}]
+ prompt[-1]["content"] += observation
+
+ if "images" in inputs[0]:
+ images = [example.get("images") for example in inputs]
+ elif "image" in inputs[0]:
+ images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
+ else:
+ images = None
+ # Transformers requires at least one image in the batch, otherwise it throws an error
+ if images is not None and all(img_list == [] for img_list in images):
+ images = None
+
+ # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
+ # [{"role": "user", "content": "What color is the sky?"}] to
+ # [{"role": "user", "content": [{"type": "image", "image": <PIL.Image>}, {"type": "text", "text": "What color is the sky?"}]}]
+ if images is not None:
+ if not is_conversational(inputs[0]):
+ raise ValueError(
+ "Multimodal training requires conversational prompts. It looks like the dataset contains "
+ "non-conversational inputs, likely because a chat template was applied before passing the dataset "
+ "to the trainer. Please provide the raw conversational prompts and let the trainer apply the chat "
+ "template internally."
+ )
+ prompts = [
+ prepare_multimodal_messages(prompt, images=image_list)
+ for prompt, image_list in zip(prompts, images, strict=True)
+ ]
+
+ dataset_images = images # preserve dataset images before _generate may overwrite
+ (
+ prompt_ids_list,
+ completion_ids_list,
+ tool_mask_list,
+ completions,
+ num_items_in_batch,
+ sampling_per_token_logps_list,
+ extra_fields,
+ images,
+ tool_images,
+ ) = self._generate(prompts)
+ if images is None:
+ images = dataset_images # restore dataset images (rollout_func path returns None)
+
+ # Convert lists of token IDs to padded tensors
+ prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list]
+ prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
+ prompt_ids = pad(
+ prompt_ids,
+ padding_value=self.pad_token_id,
+ padding_side="left",
+ pad_to_multiple_of=self.pad_to_multiple_of,
+ ).to(device=device)
+ prompt_mask = pad(
+ prompt_mask, padding_value=0, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of
+ ).to(device=device)
+ completion_ids = [torch.tensor(ids) for ids in completion_ids_list]
+ completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
+ completion_ids = pad(
+ completion_ids,
+ padding_value=self.pad_token_id,
+ padding_side="right",
+ pad_to_multiple_of=self.pad_to_multiple_of,
+ ).to(device=device)
+ completion_mask = pad(
+ completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+ ).to(device=device)
+ if sampling_per_token_logps_list is not None:
+ sampling_per_token_logps = [torch.tensor(logps) for logps in sampling_per_token_logps_list]
+ sampling_per_token_logps = pad(
+ sampling_per_token_logps,
+ padding_value=0.0,
+ padding_side="right",
+ pad_to_multiple_of=self.pad_to_multiple_of,
+ ).to(device=device)
+ else:
+ sampling_per_token_logps = None
+ if tool_mask_list is not None:
+ tool_mask = [torch.tensor(mask) for mask in tool_mask_list]
+ tool_mask = pad(
+ tool_mask, padding_value=1, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+ ).to(device=device)
+ else:
+ tool_mask = None
+
+ # If mask_truncated_completions is enabled, zero out truncated completions for attention and loss masking
+ if self.mask_truncated_completions:
+ eos_and_pad = [self.eos_token_id, self.pad_token_id]
+ is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
+ # Mask completion_mask for attention masking
+ completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
+ # Also mask tool_mask for consistency in multi-turn training
+ if tool_mask is not None:
+ tool_mask = tool_mask * (~is_truncated).unsqueeze(1).int()
+
+ # Concatenate prompt_mask with completion_mask for logit computation
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
+
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
+ batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+
+ num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None
+
+ # Get forward_kwargs for models with multimodal inputs.
+ # When tool images are present (from _tool_call_loop), use image_processor directly and build
+ # mm_token_type_ids from prompt_completion_ids. Otherwise, use the full processor pipeline
+ # which returns model-specific keys (image_sizes, pixel_attention_mask, etc.).
+ if self.tools and any(imgs for imgs in tool_images) and self._is_vlm:
+ flat_images = [img for img_list in images if img_list for img in img_list]
+ image_inputs = self.processing_class.image_processor(images=flat_images, return_tensors="pt")
+ image_inputs = super()._prepare_inputs(image_inputs)
+ forward_kwargs = dict(image_inputs)
+ elif images is not None:
+ prompts_text = [
+ apply_chat_template(
+ {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs
+ )["prompt"]
+ for prompt in prompts
+ ]
+ prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
+ forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+ else:
+ forward_kwargs = {}
+
+ # If token_type_ids are used, extend them with zeros for the completion part
+ if "token_type_ids" in forward_kwargs:
+ token_type_ids = forward_kwargs["token_type_ids"]
+ if self.pad_to_multiple_of is not None:
+ # Needed only with pad_to_multiple_of: otherwise prompt_ids and token_type_ids must have equal len
+ padding_size = prompt_ids.size(1) - token_type_ids.size(1)
+ if padding_size > 0:
+ token_type_ids = torch.cat(
+ [token_type_ids.new_zeros((token_type_ids.size(0), padding_size)), token_type_ids], dim=1
+ )
+ forward_kwargs["token_type_ids"] = torch.cat(
+ [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
+ )
+ # If mm_token_type_ids are used, extend them with zeros for the completion part
+ if "mm_token_type_ids" in forward_kwargs:
+ mm_token_type_ids = forward_kwargs["mm_token_type_ids"]
+ if self.pad_to_multiple_of is not None:
+ # Needed only with pad_to_multiple_of: otherwise prompt_ids and mm_token_type_ids must have equal len
+ padding_size = prompt_ids.size(1) - mm_token_type_ids.size(1)
+ if padding_size > 0:
+ mm_token_type_ids = torch.cat(
+ [mm_token_type_ids.new_zeros((mm_token_type_ids.size(0), padding_size)), mm_token_type_ids],
+ dim=1,
+ )
+ forward_kwargs["mm_token_type_ids"] = torch.cat(
+ [mm_token_type_ids, mm_token_type_ids.new_zeros(completion_ids.shape)], dim=1
+ )
+
+ # For VLM tool images: build token type IDs from the full prompt_completion_ids.
+ # This must happen AFTER the token_type_ids/mm_token_type_ids extension blocks above,
+ # because our version already covers the full sequence (images are in the completion,
+ # not just the prompt).
+ if self.tools and any(imgs for imgs in tool_images) and self._is_vlm:
+ mm_ids = torch.zeros_like(prompt_completion_ids)
+ if self._image_pad_token_id is not None:
+ mm_ids[prompt_completion_ids == self._image_pad_token_id] = 1
+ if self._video_pad_token_id is not None:
+ mm_ids[prompt_completion_ids == self._video_pad_token_id] = 2
+
+ # Use the same key the model expects: token_type_ids for models like Gemma,
+ # mm_token_type_ids for models like Qwen.
+ image_grid_thw = forward_kwargs.get("image_grid_thw")
+ if image_grid_thw is not None:
+ forward_kwargs["mm_token_type_ids"] = mm_ids
+ else:
+ forward_kwargs["token_type_ids"] = mm_ids
+
+ # Truncation safety (Qwen-style models with image_grid_thw only): if
+ # max_completion_length truncated some image tokens, the number of image pad tokens
+ # in input_ids won't match pixel_values features. Check per-sample and drop ALL
+ # images for any sample with a mismatch (safe fallback).
+ if image_grid_thw is not None and num_images is not None:
+ merge_length = getattr(self.processing_class.image_processor, "merge_size", 2) ** 2
+ img_offset = 0
+ has_mismatch = False
+ for b in range(mm_ids.shape[0]):
+ sample_tokens = (mm_ids[b] == 1).sum().item()
+ sample_features = 0
+ for i in range(num_images[b]):
+ grid_idx = img_offset + i
+ if grid_idx < image_grid_thw.shape[0]:
+ sample_features += image_grid_thw[grid_idx].prod().item() // merge_length
+ if sample_tokens != sample_features:
+ has_mismatch = True
+ break
+ img_offset += num_images[b]
+
+ if has_mismatch:
+ # Drop all images: safer than partial trim which is error-prone
+ forward_kwargs.pop("pixel_values", None)
+ forward_kwargs.pop("image_grid_thw", None)
+ mm_ids.zero_()
+ forward_kwargs["mm_token_type_ids"] = mm_ids
+ num_images = None
+
+ # When gradient checkpointing is enabled with use_reentrant=True (non default), calling the model inside a
+ # torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True").
+ # Temporarily disable checkpointing to avoid this warning during inference.
+ with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
+ # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
+ # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
+ # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
+ # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
+ # old_per_token_logps to None.
+ # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the
+ # distribution mismatch between vLLM and the training model can be large and harm the training.
+ generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency
+ if (
+ self.loss_type == "tpo"
+ or self.args.gradient_accumulation_steps % generate_every != 0
+ or (self.use_vllm and self.vllm_importance_sampling_correction)
+ ):
+ old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
+ self.model,
+ prompt_completion_ids,
+ attention_mask,
+ logits_to_keep,
+ batch_size,
+ num_images=num_images,
+ **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, image_position_ids
+ )
+ else:
+ old_per_token_logps = None
+
+ # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch
+ if self.use_vllm and self.vllm_importance_sampling_correction:
+ mask = completion_mask if tool_mask is None else completion_mask * tool_mask
+ per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask
+
+ sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"]
+ if sequence_level_is:
+ per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True)
+ logps_diff = per_sequence_logps_diff
+ else:
+ logps_diff = per_token_logps_diff
+
+ vllm_importance_sampling_ratio = torch.exp(logps_diff)
+
+ # vllm_importance_sampling_ratio.shape:
+ # token_* modes: (B, T) (per-token ratio)
+ # sequence_* modes: (B, 1) (per-sequence ratio)
+
+ if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]:
+ vllm_importance_sampling_ratio = torch.clamp(
+ vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap
+ )
+ elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]:
+ vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill(
+ vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0
+ )
+ else:
+ raise ValueError(
+ f"Unknown vLLM importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'."
+ )
+
+ # Compute the per-token log probabilities for the reference model
+ if self.beta != 0.0:
+ if self.ref_model is not None:
+ ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
+ self.ref_model,
+ prompt_completion_ids,
+ attention_mask,
+ logits_to_keep,
+ batch_size=batch_size,
+ num_images=num_images,
+ **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, image_position_ids
+ )
+ else:
+ # When training a PEFT adapter, how we obtain the reference depends on the setup:
+ # - New adapter: disabling adapters yields the base model.
+ # - Re-training an existing adapter: an initial copy is loaded under the name "ref".
+ model = self.accelerator.unwrap_model(self.model)
+ with use_adapter(model, adapter_name="ref" if "ref" in model.peft_config else None):
+ ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
+ self.model,
+ prompt_completion_ids,
+ attention_mask,
+ logits_to_keep,
+ batch_size=batch_size,
+ num_images=num_images,
+ **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, image_position_ids
+ )
+ else:
+ ref_per_token_logps = None
+
+ # Decode
+ prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
+ completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+
+ # Merge extra_fields from rollout_func into inputs for reward functions
+ if extra_fields:
+ for i, inp in enumerate(inputs):
+ for key, values in extra_fields.items():
+ if isinstance(values, list) and i < len(values):
+ inp[key] = values[i]
+ elif not isinstance(values, list):
+ inp[key] = values
+
+ # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
+ # important because rewards will be normalized per group, and completions are distributed. We will later slice
+ # rewards_per_func to extract each process's subset.
+ rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
+ num_generations = self.num_generations if mode == "train" else self.num_generations_eval
+
+ if self.multi_objective_aggregation == "sum_then_normalize":
+ # Apply weights to each reward function's output and sum
+ rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
+ mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1)
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0)
+ if self.scale_rewards in ["group", "none"]:
+ # If self.scale_rewards = "none", std_rewards is only used to check for zero std when logging
+ if num_generations > 1:
+ std_rewards = rewards.view(-1, num_generations).std(dim=1)
+ std_rewards = std_rewards.repeat_interleave(num_generations, dim=0)
+ else: # doesn't occur during training, but could occur in eval when num_generations_eval=1
+ std_rewards = torch.zeros_like(rewards)
+ elif self.scale_rewards == "batch":
+ # Compute global std
+ if rewards.numel() > 1:
+ std_rewards = rewards.std().expand_as(rewards)
+ else: # doesn't occur during training, but could occur in eval when num_generations_eval=batch_size=1
+ std_rewards = torch.zeros_like(rewards)
+ else:
+ raise ValueError(
+ f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'."
+ )
+
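+ # Group-relative advantages: center each reward on its group mean and, unless scale_rewards="none", divide by the (group or batch) std.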
+ advantages = rewards - mean_grouped_rewards
+ if self.scale_rewards != "none":
+ advantages = advantages / (std_rewards + 1e-4)
+ is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging
+
+ elif self.multi_objective_aggregation == "normalize_then_sum":
+ grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs))
+ mean_k = torch.nanmean(grouped, dim=1, keepdim=True)
+ std_k = nanstd(grouped, dim=1, keepdim=True) if num_generations > 1 else torch.zeros_like(mean_k)
+ reward_k = (grouped - mean_k) / (std_k + 1e-4)
+ reward_k = reward_k.view(-1, len(self.reward_funcs))
+ rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
+ std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards)
+ advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4)
+ is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging
+
+ else:
+ raise ValueError(
+ f"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}. Must be "
+ "'sum_then_normalize' or 'normalize_then_sum'."
+ )
+
+ # Slice to keep only the local part of the data
+ process_slice = slice(
+ self.accelerator.process_index * len(prompts),
+ (self.accelerator.process_index + 1) * len(prompts),
+ )
+ all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
+ tpo_targets = None
+ tpo_valid_mask = None
+ if self.loss_type == "tpo":
+ if old_per_token_logps is None:
+ raise RuntimeError("TPO requires rollout-time log probabilities to build the target distribution.")
+ loss_mask = completion_mask if tool_mask is None else completion_mask * tool_mask
+ tpo_valid_mask = loss_mask.any(dim=-1)
+ # Length-normalize sequence logps so the target distribution isn't dominated by length variance
+ # (without this, log_softmax(old_sequence_logps) collapses to ~one-hot on the shortest/most-likely
+ # completion, and rewards can't compete).
+ old_sequence_logps = (old_per_token_logps * loss_mask).sum(dim=-1)
+ if self.tpo_length_normalize_logps:
+ old_sequence_logps = old_sequence_logps / loss_mask.sum(dim=-1).clamp(min=1)
+ all_process_old_sequence_logps = gather(old_sequence_logps)
+ all_process_tpo_valid_mask = gather(tpo_valid_mask)
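+ # Combine the gathered rollout sequence log-probs with group-normalized reward scores to build the per-group target distribution consumed by the TPO loss.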
+ tpo_scores = self.get_tpo_scores(rewards, num_generations, valid_mask=all_process_tpo_valid_mask)
+ all_process_tpo_targets = self.get_tpo_targets(
+ all_process_old_sequence_logps,
+ tpo_scores,
+ num_generations=num_generations,
+ temperature=self.tpo_target_temperature,
+ valid_mask=all_process_tpo_valid_mask,
+ )
+ tpo_targets = all_process_tpo_targets[process_slice]
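+ # Log the mean per-group entropy of the target distribution (values near zero mean the target has collapsed onto a single completion per group).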
+ target_groups = all_process_tpo_targets.view(-1, num_generations)
+ target_entropy = -(target_groups * target_groups.clamp_min(torch.finfo(target_groups.dtype).tiny).log())
+ target_entropy = target_entropy.sum(dim=1).mean()
+ self._metrics[mode]["tpo/target_entropy"].append(target_entropy.item())
+ advantages = advantages[process_slice]
+
+ # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
+ for i, reward_func_name in enumerate(self.reward_func_names):
+ mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
+ self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
+ std_func_rewards = nanstd(rewards_per_func[:, i]).item()
+ self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards)
+ rewards = (rewards_per_func * self.reward_weights.to(rewards_per_func.device).unsqueeze(0)).nansum(dim=1)
+ self._metrics[mode]["reward"].append(rewards.mean().item())
+ self._metrics[mode]["reward_std"].append(rewards.std().item())
+ self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())
+
+ # Log prompt and completion texts
+ self._logs["prompt"].extend(gather_object(prompts_text))
+ self._logs["completion"].extend(gather_object(completions_text))
+ for i, name in enumerate(self.reward_func_names):
+ self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
+ self._logs["advantages"].extend(all_process_advantages.tolist())
+
+ # Flush user-logged extra columns (from log_extra), gathering across processes.
+ # Keys must be sorted so that all ranks call gather_object in the same order, otherwise values
+ # get mis-attributed across columns (dict insertion order may differ between processes).
+ for column in sorted(self._pending_extra_logs):
+ self._logs["extra"][column].extend(gather_object(self._pending_extra_logs[column]))
+ self._pending_extra_logs.clear()
+
+ # Flush user-logged metrics (from log_metric), averaging across processes.
+ # Keys must be sorted so that all ranks call accelerator.gather in the same order, otherwise values
+ # get mis-attributed across metrics (dict insertion order may differ between processes).
+ for name in sorted(self._pending_metrics):
+ values = self._pending_metrics[name]
+ local_mean = sum(values) / len(values)
+ global_mean = self.accelerator.gather(torch.tensor(local_mean, device=device)).mean().item()
+ self._metrics[mode][name].append(global_mean)
+ self._pending_metrics.clear()
+
+ if images is not None:
+ self._logs["images"].extend(gather_object(images))
+
+ if self.use_vllm and self.vllm_importance_sampling_correction:
+ delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
+ mask = completion_mask.bool() if tool_mask is None else (completion_mask * tool_mask).bool()
+ delta = delta[mask]
+ mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
+ max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
+ self._metrics[mode]["sampling/sampling_logp_difference/mean"].append(
+ self.accelerator.gather(mean_delta).mean().item()
+ )
+ self._metrics[mode]["sampling/sampling_logp_difference/max"].append(
+ self.accelerator.gather(max_delta).max().item()
+ )
+ if sequence_level_is:
+ flat_is_ratio = vllm_importance_sampling_ratio.flatten()
+ else:
+ flat_is_ratio = vllm_importance_sampling_ratio[mask]
+
+ min_importance_sampling_ratio = (
+ torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
+ )
+ mean_importance_sampling_ratio = (
+ torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
+ )
+ max_importance_sampling_ratio = (
+ torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
+ )
+ self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
+ nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item()
+ )
+ self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append(
+ self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item()
+ )
+ self._metrics[mode]["sampling/importance_sampling_ratio/max"].append(
+ nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item()
+ )
+
+ output = {
+ "prompt_ids": prompt_ids,
+ "prompt_mask": prompt_mask,
+ "completion_ids": completion_ids,
+ "completion_mask": completion_mask,
+ "advantages": advantages,
+ "num_items_in_batch": num_items_in_batch,
+ }
+ if old_per_token_logps is not None:
+ output["old_per_token_logps"] = old_per_token_logps
+ if tpo_targets is not None:
+ output["tpo_targets"] = tpo_targets
+ output["tpo_valid_mask"] = tpo_valid_mask
+ if self.use_vllm and self.vllm_importance_sampling_correction:
+ output["importance_sampling_ratio"] = vllm_importance_sampling_ratio
+ if sampling_per_token_logps is not None:
+ output["sampling_per_token_logps"] = sampling_per_token_logps
+ if ref_per_token_logps is not None:
+ output["ref_per_token_logps"] = ref_per_token_logps
+ if "pixel_values" in forward_kwargs:
+ output["pixel_values"] = forward_kwargs["pixel_values"]
+ if "image_grid_thw" in forward_kwargs:
+ output["image_grid_thw"] = forward_kwargs["image_grid_thw"]
+ if "pixel_attention_mask" in forward_kwargs:
+ output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
+ if "image_sizes" in forward_kwargs:
+ output["image_sizes"] = forward_kwargs["image_sizes"]
+ if "token_type_ids" in forward_kwargs:
+ output["token_type_ids"] = forward_kwargs["token_type_ids"]
+ if "mm_token_type_ids" in forward_kwargs:
+ output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"]
+ if "image_position_ids" in forward_kwargs:
+ output["image_position_ids"] = forward_kwargs["image_position_ids"]
+ if images is not None:
+ output["num_images"] = num_images
+ if tool_mask is not None:
+ output["tool_mask"] = tool_mask
+ return output
+
+ def compute_liger_loss(self, unwrapped_model, inputs):
+ # Compute the per-token log probabilities for the model
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
+
+ # Get the last hidden state of the model
+ last_hidden_state = self._get_last_hidden_state(
+ unwrapped_model,
+ input_ids,
+ attention_mask,
+ logits_to_keep,
+ inputs.get("pixel_values"),
+ inputs.get("image_grid_thw"),
+ inputs.get("pixel_attention_mask"),
+ inputs.get("image_sizes"),
+ inputs.get("image_position_ids"),
+ )
+
+ # Apply tool_mask (from env_mask) for loss computation in multi-turn training scenarios
+ loss_mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"]
+ mode = "train" if self.model.training else "eval"
+ normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
+ self.liger_tpo_loss.num_generations = self.num_generations if mode == "train" else self.num_generations_eval
+
+ # Compute loss and metrics using liger TargetPO loss
+ loss, mean_kl, mean_entropy = self.liger_tpo_loss(
+ _input=last_hidden_state,
+ lin_weight=unwrapped_model.lm_head.weight,
+ selected_token_ids=completion_ids,
+ # The attention_mask parameter in liger loss is actually used as a loss mask (not model attention)
+ attention_mask=loss_mask,
+ tpo_targets=inputs["tpo_targets"],
+ tpo_valid_mask=inputs.get("tpo_valid_mask"),
+ bias=unwrapped_model.lm_head.bias,
+ old_per_token_logps=inputs.get("old_per_token_logps"),
+ ref_per_token_logps=inputs.get("ref_per_token_logps"),
+ normalizer=normalizer,
+ )
+
+ if self.beta != 0.0:
+ self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).mean().item())
+ self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
+ return loss
+
+ @profiling_decorator
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+ if return_outputs:
+ raise ValueError("The TargetPOTrainer does not support returning outputs")
+ if self.use_liger_kernel:
+ # Compute the loss using the liger TargetPO loss
+ unwrapped_model = self.accelerator.unwrap_model(model)
+ return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
+ else:
+ return self._compute_loss(model, inputs)
+
+ @staticmethod
+ def get_off_policy_mask(
+ advantages: torch.Tensor,
+ per_token_logps: torch.Tensor,
+ sampling_per_token_logps: torch.Tensor,
+ mask: torch.Tensor,
+ off_policy_threshold: float,
+ ) -> torch.Tensor:
+ """
+ Computes the Off-Policy Sequence Mask from the DeepSeek-V3.2 paper. Returns a (B, 1) tensor where 1.0 indicates
+ "Keep" and 0.0 indicates "Drop".
+ """
+ # forward KL div: log(pi_old) - log(pi_theta)
+ kl_div = sampling_per_token_logps - per_token_logps.detach()
+ # Sequence-level Mean KL (ignoring prompt+padding)
+ seq_kl_sum = (kl_div * mask).sum(dim=1, keepdim=True)
+ avg_seq_kl = seq_kl_sum / mask.sum(dim=1, keepdim=True).clamp(min=1.0)
+ # Keep if (Advantage >= 0) OR (KL <= delta)
+ is_pos_adv = advantages >= 0
+ is_low_kl = avg_seq_kl <= off_policy_threshold
+ return (is_pos_adv | is_low_kl).to(dtype=mask.dtype) # (B, 1)
+
+ @staticmethod
+ def get_tpo_targets(
+ old_sequence_logps: torch.Tensor,
+ scores: torch.Tensor,
+ num_generations: int,
+ temperature: float = 1.0,
+ valid_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Build the Target Policy Optimization target distribution for each prompt group.
+
+ The target is q_i ∝ p_i_old * exp(score_i / temperature), where p_i_old is the rollout policy distribution
+ over the sampled completions in a prompt group.
+ """
+ if temperature <= 0.0:
+ raise ValueError(f"temperature must be greater than 0.0. You provided {temperature}.")
+
+ old_sequence_logps = old_sequence_logps.view(-1, num_generations)
+ scores = scores.view(-1, num_generations)
+ if valid_mask is not None:
+ valid_mask = valid_mask.view(-1, num_generations).bool()
+ old_sequence_logps = old_sequence_logps.masked_fill(~valid_mask, torch.finfo(old_sequence_logps.dtype).min)
+ target_logits = torch.log_softmax(old_sequence_logps, dim=1) + scores / temperature
+ if valid_mask is not None:
+ target_logits = target_logits.masked_fill(~valid_mask, torch.finfo(target_logits.dtype).min)
+ targets = torch.softmax(target_logits, dim=1)
+ if valid_mask is not None:
+ targets = torch.where(valid_mask, targets, torch.zeros_like(targets))
+ return targets.view(-1).detach()
+
+ @staticmethod
+ def get_tpo_scores(
+ scores: torch.Tensor, num_generations: int, valid_mask: torch.Tensor | None = None
+ ) -> torch.Tensor:
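+ """
+ Normalize rewards within each prompt group: subtract the group mean and divide by the (population) standard
+ deviation, falling back to plain mean-centering when the std is ~0. When `valid_mask` is provided, invalid
+ completions are excluded from the group statistics and their normalized score is set to 0.
+ """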
+ scores = scores.view(-1, num_generations)
+ if valid_mask is not None:
+ valid_mask = valid_mask.view(-1, num_generations).bool()
+ valid_count = valid_mask.sum(dim=1, keepdim=True).clamp(min=1)
+ mean_scores = scores.masked_fill(~valid_mask, 0.0).sum(dim=1, keepdim=True) / valid_count
+ centered_scores = torch.where(valid_mask, scores - mean_scores, torch.zeros_like(scores))
+ std_scores = (centered_scores.square().sum(dim=1, keepdim=True) / valid_count).sqrt()
+ else:
+ mean_scores = scores.mean(dim=1, keepdim=True)
+ std_scores = scores.std(dim=1, unbiased=False, keepdim=True)
+ centered_scores = scores - mean_scores
+ scores = torch.where(std_scores > 1e-6, centered_scores / std_scores, centered_scores)
+ return scores.view(-1)
+
+ @staticmethod
+ def _gather_tensor_with_grad(tensor: torch.Tensor) -> torch.Tensor:
+ # Autograd-aware all_gather: required when a TPO prompt group spans DP ranks, so the log-softmax
+ # normalizer's gradient routes back to the owning rank.
+ if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+ return tensor
+ return torch.cat(_all_gather_with_grad(tensor), dim=0)
+
+ @staticmethod
+ @torch.no_grad()
+ def get_gamma_weights(
+ advantages: torch.Tensor,
+ log_ratio_per_token: torch.Tensor,
+ mask: torch.Tensor,
+ importance_sampling_ratio: torch.Tensor | None, # (B, T)
+ k_pos: float = 2.0,
+ lambda_pos: float = 3.0,
+ k_neg: float = 3.0,
+ lambda_neg: float = 2.0,
+ ) -> torch.Tensor:
+ """
+ Computes the Gamma weights for the VESPO loss. For reference:
+ φ(w) = e^λ × w^k × e^{-λw} is the gamma weighting (normalized so that φ(1) = 1),
+ with w = the sequence-level importance sampling ratio.
+ Note: φ(w) is computed in log space.
+
+ φ(w) is detached via @torch.no_grad(), so it only acts as a gradient-scaling coefficient.
+
+ VESPO loss = -φ(w) × A × log_prob, whose gradient is naturally φ(w) × A × ∇log π.
+ """
+ # Reduce the lower clamp directly to log(1e-8) ≈ -18.42 to avoid recomputing log_w = log(w.clamp(min=1e-8)) later.
+ # This is solely to faithfully match the original implementation; otherwise keeping -20 would be fine.
+ lower_clamp = math.log(1e-8)
+
+ # Sequence-level log ratio Σ log(π_θ/π_old) (not a mean like for `log_importance_weights`)
+ log_ratio_clamped = torch.clamp(log_ratio_per_token, -20.0, 20.0)
+ seq_log_ratio = torch.sum(log_ratio_clamped * mask, dim=-1, keepdim=True) # (B, 1)
+
+ # Apply token-level TIS or MIS correction (in log space)
+ if importance_sampling_ratio is not None:
+ log_is_ratio = torch.clamp(torch.log(importance_sampling_ratio), lower_clamp, 20.0)
+ # log(w) = log(π_θ/π_old) + log(π_old/π_sampler)
+ seq_log_ratio += torch.sum(log_is_ratio, dim=-1, keepdim=True)
+
+ log_w_seq = torch.clamp(seq_log_ratio, lower_clamp, 20.0)
+ w_seq = torch.exp(log_w_seq)
+
+ # compute k and lambda based on advantage sign
+ is_nonneg_adv = advantages >= 0
+ k_seq = torch.where(is_nonneg_adv, k_pos, k_neg)
+ lambda_seq = torch.where(is_nonneg_adv, lambda_pos, lambda_neg).clamp(min=1e-4)
+
+ # log(φ(w)) = λ + k × log(w) - λ × w
+ log_phi = lambda_seq + k_seq * log_w_seq - lambda_seq * w_seq
+ phi_seq = torch.exp(log_phi).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)
+
+ return phi_seq # (B, 1)
+
+ def _compute_loss(self, model, inputs):
+ # Compute the per-token log probabilities for the model
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
+ mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"]
+
+ # Compute the per_token_logps and the entropy at each position in the completion
+ per_token_logps, entropies = self._get_per_token_logps_and_entropies(
+ model,
+ input_ids,
+ attention_mask,
+ logits_to_keep,
+ compute_entropy=True,
+ pixel_values=inputs.get("pixel_values"),
+ image_grid_thw=inputs.get("image_grid_thw"),
+ num_images=inputs.get("num_images"),
+ pixel_attention_mask=inputs.get("pixel_attention_mask"),
+ image_sizes=inputs.get("image_sizes"),
+ token_type_ids=inputs.get("token_type_ids"),
+ mm_token_type_ids=inputs.get("mm_token_type_ids"),
+ image_position_ids=inputs.get("image_position_ids"),
+ )
+
+ if self.top_entropy_quantile < 1.0:
+ entropy_mask = self.get_high_entropy_mask(entropies, mask, 1 - self.top_entropy_quantile)
+ else:
+ entropy_mask = None
+
+ # Compute the loss
+ advantages = inputs["advantages"]
+ # In the base GRPO implementation, advantages are expected to have shape (B,). To support subclasses that
+ # provide advantages with shape (B, T) (e.g., MiniLLM), we *conditionally* unsqueeze the tensor.
+ if advantages.dim() == 1:
+ advantages = advantages.unsqueeze(1)
+ # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps,
+ # old_per_token_logps == per_token_logps. In this case we can skip its computation
+ # (see _generate_and_score_completions) and instead use per_token_logps.detach().
+ # The exception is when using vLLM, where we always compute old_per_token_logps
+ # for importance sampling
+ old_per_token_logps = inputs.get("old_per_token_logps")
+ old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps
+
+ if self.loss_type == "tpo":
+ if "tpo_targets" not in inputs:
+ raise RuntimeError("TPO loss requires `tpo_targets` in the prepared inputs.")
+ if self.off_policy_mask_threshold is not None:
+ raise ValueError("TPO loss does not support `off_policy_mask_threshold`.")
+ if self.top_entropy_quantile < 1.0:
+ raise ValueError("TPO loss does not support `top_entropy_quantile < 1.0`.")
+
+ mode = "train" if self.model.training else "eval"
+ num_generations = self.num_generations if mode == "train" else self.num_generations_eval
+ # Sequence logp, optionally length-normalized (see _generate_and_score_completions for why).
+ sequence_logps = (per_token_logps * mask).sum(dim=-1)
+ if self.tpo_length_normalize_logps:
+ sequence_logps = sequence_logps / mask.sum(dim=-1).clamp(min=1)
+ all_sequence_logps = self._gather_tensor_with_grad(sequence_logps)
+ tpo_valid_mask = inputs.get("tpo_valid_mask")
+ if tpo_valid_mask is not None:
+ tpo_valid_mask = tpo_valid_mask.to(device=sequence_logps.device, dtype=torch.bool)
+ all_tpo_valid_mask = gather(tpo_valid_mask)
+ all_sequence_logps = all_sequence_logps.masked_fill(
+ ~all_tpo_valid_mask, torch.finfo(all_sequence_logps.dtype).min
+ )
+ all_logps = torch.log_softmax(all_sequence_logps.view(-1, num_generations), dim=1).view(-1)
+ process_slice = slice(
+ self.accelerator.process_index * sequence_logps.size(0),
+ (self.accelerator.process_index + 1) * sequence_logps.size(0),
+ )
+ logps = all_logps[process_slice]
+ tpo_targets = inputs["tpo_targets"].to(logps.dtype)
+ if tpo_valid_mask is not None:
+ tpo_targets = torch.where(tpo_valid_mask, tpo_targets, torch.zeros_like(tpo_targets))
+ loss = -(tpo_targets * logps).sum() * num_generations / tpo_targets.numel()
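+ # Since the targets sum to 1 within each prompt group, this is the group-level cross-entropy
+ # -Σ_i q_i log softmax(sequence_logps)_i, effectively averaged over the prompt groups held by this process.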
+
+ normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
+ loss = loss / normalizer
+
+ if self.beta != 0.0:
+ ref_per_token_logps = inputs["ref_per_token_logps"]
+ per_token_kl = (
+ torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+ )
+ if self.args.use_bias_correction_kl:
+ per_token_kl = per_token_kl * torch.exp(per_token_logps - old_per_token_logps)
+ kl_normalizer = (
+ self.accelerator.gather(mask.sum()).sum().clamp(min=1.0) / self.accelerator.num_processes
+ )
+ kl_loss = (per_token_kl * mask).sum() / kl_normalizer
+ loss = loss + self.beta * kl_loss / normalizer
+ self._metrics[mode]["kl"].append(self.accelerator.gather(kl_loss).nanmean().item())
+
+ completion_token_count = mask.sum().clamp(min=1.0)
+ mean_entropy = (entropies * mask).sum() / completion_token_count
+ self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
+
+ return loss
+
+ if self.off_policy_mask_threshold is not None:
+ # OPSM should use inference-time logprobs to detect both sources of off-policyness:
+ # 1. Drift from gradient updates (always present)
+ # 2. Drift from training-inference mismatch (when using vLLM)
+ # When using vLLM, prioritize sampling_per_token_logps, otherwise use old_per_token_logps
+ sampling_per_token_logps = inputs.get("sampling_per_token_logps", old_per_token_logps)
+
+ off_policy_mask = self.get_off_policy_mask(
+ advantages=advantages,
+ per_token_logps=per_token_logps,
+ sampling_per_token_logps=sampling_per_token_logps,
+ mask=mask,
+ off_policy_threshold=self.off_policy_mask_threshold,
+ )
+
+ log_ratio = per_token_logps - old_per_token_logps
+ if self.importance_sampling_level == "token":
+ log_importance_weights = log_ratio
+ elif self.importance_sampling_level == "sequence":
+ log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
+ log_importance_weights = log_importance_weights.unsqueeze(-1)
+ else:
+ raise ValueError(
+ f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
+ "and 'sequence'."
+ )
+
+ coef_1 = torch.exp(log_importance_weights)
+
+ # Compute the KL divergence between the model and the reference model
+ if self.beta != 0.0:
+ ref_per_token_logps = inputs["ref_per_token_logps"]
+ per_token_kl = (
+ torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+ )
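+ # Non-negative k3 estimator of KL(π_θ ‖ π_ref): exp(r) - r - 1 with r = log π_ref - log π_θ, which is
+ # unbiased for samples drawn from π_θ.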
+ # Importance sampling correction for the KL divergence
+ if self.args.use_bias_correction_kl:
+ per_token_kl = per_token_kl * coef_1
+
+ # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+ # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+ if self.loss_type == "cispo":
+ clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
+ per_token_loss = -clamped_ratios * advantages * per_token_logps
+ elif self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]:
+ coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+ # Two-sided clipping
+ if self.args.delta is not None:
+ coef_1 = torch.clamp(coef_1, max=self.args.delta)
+
+ per_token_loss1 = coef_1 * advantages
+ per_token_loss2 = coef_2 * advantages
+ per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+ elif self.loss_type == "sapo":
+ temperatures = torch.where(advantages > 0, self.args.sapo_temperature_pos, self.args.sapo_temperature_neg)
+ soft_coef_1 = torch.sigmoid(temperatures * (coef_1 - 1)) * 4 / temperatures
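+ # sigmoid(τ(w - 1)) * 4/τ has unit slope at w = 1 for any τ, so its gradient locally matches that of the
+ # raw ratio while saturating smoothly for large |w - 1| instead of hard clipping.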
+ per_token_loss = -soft_coef_1 * advantages
+ elif self.loss_type == "vespo":
+ phi_seq = self.get_gamma_weights(
+ advantages=advantages,
+ log_ratio_per_token=log_ratio,
+ mask=mask,
+ importance_sampling_ratio=inputs.get("importance_sampling_ratio"),
+ k_pos=self.args.vespo_k_pos,
+ lambda_pos=self.args.vespo_lambda_pos,
+ k_neg=self.args.vespo_k_neg,
+ lambda_neg=self.args.vespo_lambda_neg,
+ )
+ per_token_loss = -phi_seq * advantages * per_token_logps
+ else:
+ raise ValueError(f"Unknown loss type: {self.loss_type}")
+
+ if self.off_policy_mask_threshold is not None:
+ per_token_loss = per_token_loss * off_policy_mask
+
+ if entropy_mask is not None:
+ per_token_loss = per_token_loss * entropy_mask
+
+ if self.use_vllm and self.vllm_importance_sampling_correction and self.loss_type != "vespo":
+ per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]
+
+ if self.beta != 0.0:
+ per_token_loss = per_token_loss + self.beta * per_token_kl
+
+ mode = "train" if self.model.training else "eval"
+ if self.loss_type in ["grpo", "sapo"]:
+ loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
+ normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
+ loss = loss / normalizer
+ elif self.loss_type == "bnpo":
+ loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
+ normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
+ loss = loss / normalizer
+ elif self.loss_type == "dr_grpo":
+ loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
+ normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
+ loss = loss / normalizer
+ elif self.loss_type in ["cispo", "dapo", "vespo"]:
+ normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
+ loss = (per_token_loss * mask).sum() / normalizer
+ elif self.loss_type == "luspo":
+ # Unless importance_sampling_level="token" (not recommended here), per_token_loss is expected to be (B, 1)
+ loss = (per_token_loss * mask.sum(1, keepdim=True)).mean()
+ normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
+ loss = loss / normalizer
+ else:
+ raise ValueError(f"Unknown loss type: {self.loss_type}")
+
+ # Log the metrics
+ completion_token_count = mask.sum().clamp(min=1.0)
+
+ def masked_batch_mean(x):
+ if x.shape[1] == 1: # when importance_sampling_level == "sequence"
+ return x.mean()
+ else:
+ return (x * mask).sum() / completion_token_count
+
+ if self.beta != 0.0:
+ mean_kl = masked_batch_mean(per_token_kl)
+ self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item())
+
+ mean_entropy = masked_batch_mean(entropies)
+ self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
+
+ if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]:
+ # Compute the clipped probability ratios
+ is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)
+ is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)
+ is_region_clipped = is_low_clipped | is_high_clipped
+
+ low_clip = masked_batch_mean(is_low_clipped.float())
+ high_clip = masked_batch_mean(is_high_clipped.float())
+ clip_ratio = masked_batch_mean(is_region_clipped.float())
+
+ gathered_low_clip = self.accelerator.gather(low_clip)
+ self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
+ self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
+ gathered_high_clip = self.accelerator.gather(high_clip)
+ self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
+ self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
+ gathered_clip_ratio = self.accelerator.gather(clip_ratio)
+ self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
+ elif self.loss_type == "cispo":
+ is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0)
+ cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
+ gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
+ self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item())
+ elif self.loss_type == "vespo":
+ gathered_phi_seq = self.accelerator.gather(phi_seq)
+ self._metrics[mode]["vespo/phi_seq_mean"].append(gathered_phi_seq.nanmean().item())
+
+ return loss
+
+ # During eval, Trainer calls prediction_step. If no labels are present in the inputs, it only runs forward and
+ # returns logits. We override prediction_step to force compute_loss, because this trainer doesn't involve labels.
+ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None):
+ inputs = self._prepare_inputs(inputs)
+ with torch.no_grad():
+ with self.compute_loss_context_manager():
+ loss = self.compute_loss(model, inputs)
+ loss = loss.mean().detach()
+ return loss, None, None
+
+ def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
+ mode = "train" if self.model.training else "eval"
+ # Average the metrics
+ metrics = {}
+ for key, val in self._metrics[mode].items():
+ # Filter out NaN values before averaging. A reward function that returns None for all samples
+ # in a batch produces NaN for that batch's metric. With logging_steps > 1, a naive sum()/len()
+ # would let a single NaN contaminate valid data from other batches. Fall back to None only when no
+ # valid values remain, since JSON-based loggers crash on float NaN.
+ valid = [v for v in val if not math.isnan(v)]
+ metrics[key] = sum(valid) / len(valid) if valid else None
+
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+ if mode == "eval":
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+ logs = {**logs, **metrics}
+ super().log(logs, start_time)
+ self._metrics[mode].clear()
+
+ if self.accelerator.is_main_process and self.log_completions:
+ if is_rich_available():
+ print_prompt_completions_sample(
+ self._logs["prompt"],
+ self._logs["completion"],
+ self._logs["rewards"],
+ self._logs["advantages"],
+ self.state.global_step,
+ self.num_completions_to_print,
+ extra=dict(self._logs["extra"]),
+ )
+
+ logging_backends = []
+ if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
+ logging_backends.append(wandb)
+ if self.args.report_to and "trackio" in self.args.report_to:
+ logging_backends.append(trackio)
+
+ table = {
+ "step": [self.state.global_step] * len(self._logs["prompt"]),
+ "prompt": self._logs["prompt"],
+ "completion": self._logs["completion"],
+ **self._logs["rewards"],
+ **self._logs["extra"],
+ "advantage": self._logs["advantages"],
+ }
+
+ df_base = pd.DataFrame(table)
+ df_base.to_parquet(
+ os.path.join(
+ self.args.output_dir,
+ "completions",
+ f"completions_{self.state.global_step:05d}.parquet",
+ )
+ )
+
+ images_raw = self._logs["images"] or []
+
+ for logging_backend in logging_backends:
+ if images_raw:
+ images = []
+ for image_list in images_raw:
+ if image_list:
+ images.append([logging_backend.Image(image) for image in image_list])
+ else:
+ images.append([])
+ df = pd.concat(
+ [df_base, pd.Series(images, name="image")],
+ axis=1,
+ copy=False,
+ )
+ else:
+ df = df_base
+
+ if self.log_unique_prompts:
+ df = df.drop_duplicates(subset=["prompt"])
+
+ logging_backend.log({"completions": logging_backend.Table(dataframe=df)})
+
+ # Ensure the model card is saved along with the checkpoint
+ def _save_checkpoint(self, model, trial):
+ if self.args.hub_model_id is None:
+ model_name = Path(self.args.output_dir).name
+ else:
+ model_name = self.args.hub_model_id.split("/")[-1]
+ self.create_model_card(model_name=model_name)
+ super()._save_checkpoint(model, trial)