diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index b5a7442e026..c8eb147ba78 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -25,6 +25,8 @@ title: RLOO - local: sft_trainer title: SFT + - local: target_po_trainer + title: TargetPO title: Trainers - sections: - local: clis diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md index f2214bbf75f..4ea34a38b11 100644 --- a/docs/source/dataset_formats.md +++ b/docs/source/dataset_formats.md @@ -411,6 +411,7 @@ Choosing the right dataset type depends on the task you are working on and the s | [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) | | [`RLOOTrainer`] | [Prompt-only](#prompt-only) | | [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) | +| [`TargetPOTrainer`] | [Prompt-only](#prompt-only) | | [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | | [`experimental.cpo.CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | | [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) | diff --git a/docs/source/index.md b/docs/source/index.md index bbbdd9acde3..76497d2b15c 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -25,6 +25,7 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL - [`GRPOTrainer`](grpo_trainer) ⚡️ - [`RLOOTrainer`](rloo_trainer) ⚡️ +- [`TargetPOTrainer`](target_po_trainer) ⚡️ - [`OnlineDPOTrainer`](online_dpo_trainer) 🧪 ⚡️ - [`NashMDTrainer`](nash_md_trainer) 🧪 ⚡️ - [`PPOTrainer`](ppo_trainer) 🧪 diff --git a/docs/source/liger_kernel_integration.md b/docs/source/liger_kernel_integration.md index 7a387c813fd..0c688a4dac1 100644 --- a/docs/source/liger_kernel_integration.md +++ b/docs/source/liger_kernel_integration.md @@ -14,6 +14,7 @@ Liger Kernel is supported in the following TRL trainers: - **SFT** (Supervised Fine-Tuning) - **DPO** (Direct Preference Optimization) - **GRPO** (Group Relative Policy Optimization) +- **TargetPO** (Target Policy Optimization) - **KTO** (Kahneman-Tversky Optimization) - **GKD** (Generalized Knowledge Distillation) @@ -54,6 +55,15 @@ from trl import GRPOConfig training_args = GRPOConfig(..., use_liger_kernel=True) ``` + + + +```python +from trl import TargetPOConfig + +training_args = TargetPOConfig(..., use_liger_kernel=True) +``` + diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 3fdccefff30..b573679346b 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -668,6 +668,30 @@ trainer.train() The official code [sail-sg/Stable-RL](https://github.com/sail-sg/Stable-RL) +### Target Policy Optimization + +**📜 Paper**: https://huggingface.co/papers/2604.06159 + +Target Policy Optimization (TPO) builds a target distribution over each prompt's sampled completions using rollout +policy probabilities and normalized rewards, then trains the policy to match that target with sequence-level +cross-entropy. To use TPO in TRL, use [`TargetPOTrainer`] or set `loss_type="tpo"` in [`GRPOConfig`]. The Python +class is named `TargetPO` to avoid collision with the experimental Triple Preference Optimization trainer that +shares the same acronym. 
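+
+As a small numeric sketch of the target construction (illustrative values only; `old_logps` stands in for the rollout policy's sequence log-probabilities over one prompt group and `rewards` for the whitened group rewards `u_i`):
+
+```python
+import torch
+
+eta = 1.0  # tpo_target_temperature
+old_logps = torch.log(torch.tensor([0.7, 0.2, 0.1]))  # rollout-policy probabilities of 3 sampled completions
+rewards = torch.tensor([1.0, -1.0, 0.0])  # whitened group rewards u_i
+
+# q_i ∝ p_i_old * exp(u_i / eta); the softmax over the group performs the normalization
+targets = torch.softmax(old_logps + rewards / eta, dim=0)
+```
+
+In TRL, the `loss_type="tpo"` route through [`GRPOConfig`] looks like this: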
+ +```python +from trl import GRPOConfig, GRPOTrainer + +training_args = GRPOConfig( + loss_type="tpo", + tpo_target_temperature=1.0, +) + +trainer = GRPOTrainer( + ..., + args=training_args, +) +``` + ## Direct Policy Optimization Papers relating to the [`DPOTrainer`] diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md index db78aded07e..609d56b36ae 100644 --- a/docs/source/reducing_memory_usage.md +++ b/docs/source/reducing_memory_usage.md @@ -159,6 +159,15 @@ from trl import GRPOConfig training_args = GRPOConfig(..., use_liger_kernel=True) ``` + + + +```python +from trl import TargetPOConfig + +training_args = TargetPOConfig(..., use_liger_kernel=True) +``` + diff --git a/docs/source/speeding_up_training.md b/docs/source/speeding_up_training.md index c855cc06233..b42401daf1d 100644 --- a/docs/source/speeding_up_training.md +++ b/docs/source/speeding_up_training.md @@ -168,6 +168,15 @@ from trl import GRPOConfig training_args = GRPOConfig(..., use_liger_kernel=True) ``` + + + +```python +from trl import TargetPOConfig + +training_args = TargetPOConfig(..., use_liger_kernel=True) +``` + diff --git a/docs/source/target_po_trainer.md b/docs/source/target_po_trainer.md new file mode 100644 index 00000000000..9bb7c3c3d3e --- /dev/null +++ b/docs/source/target_po_trainer.md @@ -0,0 +1,47 @@ +# TargetPO Trainer + +## Overview + +[`TargetPOTrainer`] implements Target Policy Optimization (TPO), an online post-training algorithm from [Target Policy Optimization](https://huggingface.co/papers/2604.06159). + +[`TargetPOTrainer`] keeps an aligned copy of the online rollout and reward flow used by [`GRPOTrainer`], but trains +with a sequence-level cross-entropy target: + +$$ +q_i = \frac{p_i^{\text{old}} \exp(u_i / \eta)}{\sum_j p_j^{\text{old}} \exp(u_j / \eta)} +$$ + +Here \\(p_i^{\text{old}}\\) is a *length-normalized* proxy for the rollout policy probability of completion \\(i\\) in +the prompt group (per-token mean log-probability by default, controlled by `tpo_length_normalize_logps`), \\(u_i\\) is +the population-whitened group reward, and \\(\eta\\) is `tpo_target_temperature`. Length-normalization prevents the +old-policy term from dominating the target when completions in a group have different lengths; set +`tpo_length_normalize_logps=False` to recover the paper's literal sequence-probability formulation. + +## Quick Start + +```python +from datasets import load_dataset +from trl import TargetPOTrainer +from trl.rewards import accuracy_reward + +dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + +trainer = TargetPOTrainer( + model="Qwen/Qwen2-0.5B-Instruct", + reward_funcs=accuracy_reward, + train_dataset=dataset, +) +trainer.train() +``` + +## Configuration + +[`TargetPOConfig`] inherits the online rollout, reward, vLLM, tool-calling, and logging arguments from [`GRPOConfig`]. Because TargetPO uses a sequence-level softmax over every completion in a prompt group, [`TargetPOConfig`] defaults `steps_per_generation` to `1` when the user does not specify a generation schedule. Larger values are supported as long as each optimization step still contains whole prompt groups, i.e. `(generation_batch_size // steps_per_generation) % num_generations == 0`. 
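+
+For instance, a single-process setup that satisfies this constraint might look like the following (the values are illustrative, not recommendations):
+
+```python
+from trl import TargetPOConfig
+
+# 8 completions per prompt, generated once per optimization step.
+# With one process, generation_batch_size = per_device_train_batch_size * steps_per_generation = 8,
+# so each step contains whole prompt groups: (8 // 1) % 8 == 0.
+training_args = TargetPOConfig(
+    output_dir="TargetPO-Qwen3-0.6B",
+    per_device_train_batch_size=8,
+    num_generations=8,
+    steps_per_generation=1,
+    tpo_target_temperature=1.0,
+)
+```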
+ +## TargetPOConfig + +[[autodoc]] TargetPOConfig + +## TargetPOTrainer + +[[autodoc]] TargetPOTrainer diff --git a/examples/scripts/target_po.py b/examples/scripts/target_po.py new file mode 100644 index 00000000000..f0566efff3e --- /dev/null +++ b/examples/scripts/target_po.py @@ -0,0 +1,123 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl[peft]", +# "math-verify", +# "latex2sympy2_extended", +# "trackio", +# "kernels", +# ] +# /// + +""" +pip install math_verify + +accelerate launch \ + --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/target_po.py \ + --model_name_or_path Qwen/Qwen3-0.6B \ + --output_dir target_po-Qwen3-0.6B \ + --learning_rate 1e-5 \ + --dtype bfloat16 \ + --max_completion_length 1024 \ + --use_peft \ + --lora_target_modules "q_proj", "v_proj" \ + --log_completions \ + --per_device_train_batch_size 8 \ + --num_generations 8 \ + --beta 0.0 \ + --tpo_target_temperature 1.0 + +""" + +import torch +from datasets import load_dataset + +from trl import ( + ModelConfig, + ScriptArguments, + TargetPOConfig, + TargetPOTrainer, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.rewards import accuracy_reward, think_format_reward + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, TargetPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + ################ + # Model & Processor + ################ + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + training_args.model_init_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. + training_args.model_init_kwargs["device_map"] = get_kbit_device_map() + training_args.model_init_kwargs["quantization_config"] = quantization_config + + ################ + # Dataset + ################ + train_dataset, eval_dataset = load_dataset("AI-MO/NuminaMath-TIR", split=["train[:5%]", "test[:5%]"]) + + SYSTEM_PROMPT = ( + "A conversation between user and assistant. The user asks a question, and the assistant solves it. The " + "assistant first thinks about the reasoning process in the mind and then provides the user with the answer. " + "The reasoning process and answer are enclosed within tags, i.e., \nThis is my " + "reasoning.\n\nThis is my answer." 
+ ) + + def make_conversation(example): + return { + "prompt": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": example["problem"]}, + ], + } + + train_dataset = train_dataset.map(make_conversation) + eval_dataset = eval_dataset.map(make_conversation) + + train_dataset = train_dataset.remove_columns(["messages", "problem"]) + eval_dataset = eval_dataset.remove_columns(["messages", "problem"]) + + ################ + # Training + ################ + trainer = TargetPOTrainer( + model=model_args.model_name_or_path, + args=training_args, + reward_funcs=[think_format_reward, accuracy_reward], + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=get_peft_config(model_args), + ) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/tests/test_target_po_distributed.py b/tests/test_target_po_distributed.py new file mode 100644 index 00000000000..f21c8389c84 --- /dev/null +++ b/tests/test_target_po_distributed.py @@ -0,0 +1,78 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from trl import TargetPOTrainer + + +WORLD_SIZE = 2 +NUM_GENERATIONS = 4 # one group of 4 split across 2 ranks (2 per rank) +LOCAL_SEQ_LOGPS = [ + torch.tensor([0.1, -0.3]), # rank 0 + torch.tensor([0.5, -0.2]), # rank 1 +] +LOCAL_TARGETS = [ + torch.tensor([0.1, 0.2]), # rank 0 + torch.tensor([0.4, 0.3]), # rank 1; global sums to 1.0 +] + + +def _tpo_worker(rank: int, world_size: int, init_file: str) -> None: + dist.init_process_group( + backend="gloo", + init_method=f"file://{init_file}", + world_size=world_size, + rank=rank, + ) + try: + local = LOCAL_SEQ_LOGPS[rank].clone().requires_grad_(True) + + gathered = TargetPOTrainer._gather_tensor_with_grad(local) + logps = torch.log_softmax(gathered.view(-1, NUM_GENERATIONS), dim=1).view(-1) + + process_slice = slice(rank * local.size(0), (rank + 1) * local.size(0)) + local_logps = logps[process_slice] + local_targets = LOCAL_TARGETS[rank] + + loss = -(local_targets * local_logps).sum() * NUM_GENERATIONS / local_targets.numel() + loss.backward() + + global_logps = torch.cat(LOCAL_SEQ_LOGPS) + global_targets = torch.cat(LOCAL_TARGETS) + global_softmax = torch.softmax(global_logps, dim=0) + scale = NUM_GENERATIONS / local_targets.numel() + expected = scale * (global_softmax[process_slice] - global_targets[process_slice]) + + torch.testing.assert_close(local.grad, expected) + finally: + dist.destroy_process_group() + + +@pytest.mark.skipif(not torch.distributed.is_available(), reason="torch.distributed not available") +def test_tpo_gradient_across_ranks_with_group_spanning_ranks(): + """ + A TPO prompt group of size 4 split 2/2 across DP ranks. 
The group's log-softmax normalizer depends on + all four completions, so the autograd-aware all_gather must route gradient from each rank's loss back + to the owning rank's local tensor. Expected local gradient is scale * (softmax - target) at local positions. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + init_file = os.path.join(tmp_dir, "rendezvous") + mp.spawn(_tpo_worker, args=(WORLD_SIZE, init_file), nprocs=WORLD_SIZE, join=True) diff --git a/tests/test_target_po_trainer.py b/tests/test_target_po_trainer.py new file mode 100644 index 00000000000..275cf0c8caa --- /dev/null +++ b/tests/test_target_po_trainer.py @@ -0,0 +1,437 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from types import SimpleNamespace +from unittest.mock import patch + +import pytest +import torch +from accelerate.utils.memory import release_memory +from datasets import Dataset, load_dataset +from transformers import TrainingArguments + +from trl import GRPOTrainer, TargetPOConfig, TargetPOTrainer + +from .testing_utils import TrlTestCase, require_liger_kernel + + +class TestTargetPOConfig(TrlTestCase): + def test_defaults_to_one_step_per_generation(self): + config = TargetPOConfig(output_dir=self.tmp_dir, gradient_accumulation_steps=4) + + assert config.loss_type == "tpo" + assert config.steps_per_generation == 1 + + def test_allows_multi_step_generation_when_groups_are_whole(self): + config = TargetPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=4, + num_generations=2, + steps_per_generation=2, + ) + + assert config.steps_per_generation == 2 + per_step_batch = config.generation_batch_size // config.steps_per_generation + assert per_step_batch % config.num_generations == 0 + + def test_rejects_multi_step_generation_when_groups_are_cleaved(self): + with pytest.raises(ValueError, match="whole prompt groups"): + TargetPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + num_generations=4, + steps_per_generation=2, + ) + + def test_rejects_distributed_multi_step_generation_when_local_step_splits_groups(self): + with patch.object(TrainingArguments, "world_size", new=property(lambda self: 2)): + with pytest.raises(ValueError, match="per_device_train_batch_size"): + TargetPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=3, + num_generations=2, + steps_per_generation=2, + ) + + def test_allows_distributed_single_step_generation_with_groups_spanning_ranks(self): + with patch.object(TrainingArguments, "world_size", new=property(lambda self: 2)): + config = TargetPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=3, + num_generations=2, + steps_per_generation=1, + ) + + assert config.generation_batch_size == 6 + assert config.steps_per_generation == 1 + + def test_trainer_metadata(self): + assert TargetPOTrainer._name == "TargetPO" + assert TargetPOTrainer._tag_names == ["trl", "tpo"] + + def test_allows_liger_kernel(self): + config = 
TargetPOConfig(output_dir=self.tmp_dir, use_liger_kernel=True) + + assert config.use_liger_kernel + + +class TestTargetPOTrainer(TrlTestCase): + def test_training(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = TargetPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=8, + report_to="none", + ) + trainer = TargetPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + @require_liger_kernel + def test_training_with_liger_kernel(self): + from trl.trainer.target_po_trainer import LigerFusedLinearTargetPOLoss + + def reward_func(completions, **kwargs): + return [float((len(completion) % 3) - 1) for completion in completions] + + dataset = Dataset.from_dict( + { + "prompt": [ + "Say hello.", + "Name a color.", + "What is 1+1?", + "Write ok.", + "Pick a letter.", + "Say bye.", + ] + } + ) + training_args = TargetPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=8, + max_steps=1, + use_liger_kernel=True, + report_to="none", + logging_strategy="no", + save_strategy="no", + ) + trainer = TargetPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + assert isinstance(trainer.liger_tpo_loss, LigerFusedLinearTargetPOLoss) + + train_result = trainer.train() + + assert torch.isfinite(torch.tensor(train_result.training_loss)) + + release_memory(trainer.model, trainer) + + +class TestTPOLoss: + def test_tpo_scores_match_population_whitened_skill(self): + scores = torch.tensor([0.0, 1.0, 0.0, 0.0]) + + tpo_scores = TargetPOTrainer.get_tpo_scores(scores, num_generations=2) + + expected = torch.tensor([-1.0, 1.0, 0.0, 0.0]) + torch.testing.assert_close(tpo_scores, expected) + + def test_tpo_scores_exclude_invalid_completions(self): + scores = torch.tensor([0.0, 100.0, 1.0]) + valid_mask = torch.tensor([True, False, True]) + + tpo_scores = TargetPOTrainer.get_tpo_scores(scores, num_generations=3, valid_mask=valid_mask) + + expected = torch.tensor([-1.0, 0.0, 1.0]) + torch.testing.assert_close(tpo_scores, expected) + + def test_tpo_targets_use_population_whitened_scores(self): + old_sequence_logps = torch.zeros(2) + scores = torch.tensor([0.0, 1.0]) + tpo_scores = TargetPOTrainer.get_tpo_scores(scores, num_generations=2) + + targets = TargetPOTrainer.get_tpo_targets(old_sequence_logps, tpo_scores, num_generations=2) + + expected = torch.softmax(torch.tensor([-1.0, 1.0]), dim=0) + torch.testing.assert_close(targets, expected) + + def test_tpo_targets_match_closed_form(self): + old_sequence_logps = torch.log(torch.tensor([0.7, 0.3, 0.2, 0.8])) + scores = torch.tensor([1.0, -1.0, 0.0, 0.0]) + + targets = TargetPOTrainer.get_tpo_targets(old_sequence_logps, scores, num_generations=2) + + expected_first_group = 
torch.softmax(torch.log(torch.tensor([0.7, 0.3])) + torch.tensor([1.0, -1.0]), dim=0) + expected_second_group = torch.tensor([0.2, 0.8]) + expected = torch.cat([expected_first_group, expected_second_group]) + torch.testing.assert_close(targets, expected) + + def test_tpo_targets_exclude_invalid_completions(self): + old_sequence_logps = torch.log(torch.tensor([0.7, 0.3, 0.2, 0.8])) + scores = torch.tensor([1.0, -1.0, 0.0, 0.0]) + valid_mask = torch.tensor([True, False, True, True]) + + targets = TargetPOTrainer.get_tpo_targets(old_sequence_logps, scores, num_generations=2, valid_mask=valid_mask) + + expected = torch.tensor([1.0, 0.0, 0.2, 0.8]) + torch.testing.assert_close(targets, expected) + + @pytest.mark.parametrize("trainer_cls", [GRPOTrainer, TargetPOTrainer]) + def test_tpo_kl_uses_per_step_token_normalizer(self, trainer_cls): + trainer = object.__new__(trainer_cls) + trainer.loss_type = "tpo" + trainer.off_policy_mask_threshold = None + trainer.top_entropy_quantile = 1.0 + trainer.beta = 0.25 + trainer.num_generations = 2 + trainer.num_generations_eval = 2 + trainer.current_gradient_accumulation_steps = 2 + trainer.tpo_length_normalize_logps = False + trainer.args = SimpleNamespace(use_bias_correction_kl=False) + trainer.model = SimpleNamespace(training=True) + trainer.accelerator = SimpleNamespace( + gather=lambda tensor: tensor, + num_processes=1, + process_index=0, + ) + trainer._metrics = {"train": defaultdict(list)} + + per_token_logps = torch.zeros((2, 3)) + + def fake_get_per_token_logps_and_entropies(*args, **kwargs): + return per_token_logps, torch.zeros_like(per_token_logps) + + trainer._get_per_token_logps_and_entropies = fake_get_per_token_logps_and_entropies + + completion_mask = torch.tensor([[1, 1, 0], [1, 0, 0]]) + ref_log_ratio = torch.log(torch.tensor(2.0)) + ref_per_token_logps = per_token_logps + ref_log_ratio + inputs = { + "prompt_ids": torch.zeros((2, 1), dtype=torch.long), + "prompt_mask": torch.ones((2, 1), dtype=torch.long), + "completion_ids": torch.ones((2, 3), dtype=torch.long), + "completion_mask": completion_mask, + "advantages": torch.zeros(2), + "old_per_token_logps": per_token_logps, + "ref_per_token_logps": ref_per_token_logps, + "tpo_targets": torch.tensor([0.5, 0.5]), + "num_items_in_batch": torch.tensor(6), + } + + loss = trainer_cls._compute_loss(trainer, model=None, inputs=inputs) + + normalizer = torch.tensor(float(trainer.current_gradient_accumulation_steps)) + tpo_loss = ref_log_ratio / normalizer + expected_kl_loss = torch.exp(ref_log_ratio) - ref_log_ratio - 1 + expected_loss = tpo_loss + trainer.beta * expected_kl_loss / normalizer + torch.testing.assert_close(loss, expected_loss) + torch.testing.assert_close(torch.tensor(trainer._metrics["train"]["kl"][0]), expected_kl_loss) + + @pytest.mark.parametrize("length_normalize", [True, False]) + def test_tpo_loss_gradient_matches_policy_minus_target(self, length_normalize): + trainer = object.__new__(TargetPOTrainer) + trainer.loss_type = "tpo" + trainer.off_policy_mask_threshold = None + trainer.top_entropy_quantile = 1.0 + trainer.beta = 0.0 + trainer.num_generations = 2 + trainer.num_generations_eval = 2 + trainer.current_gradient_accumulation_steps = 1 + trainer.tpo_length_normalize_logps = length_normalize + trainer.model = SimpleNamespace(training=True) + trainer.accelerator = SimpleNamespace( + gather=lambda tensor: tensor, + num_processes=1, + process_index=0, + ) + trainer._metrics = {"train": defaultdict(list)} + + per_token_logps = torch.tensor([[-0.1, -0.2], [-0.7, -0.3]], 
requires_grad=True) + + def fake_get_per_token_logps_and_entropies(*args, **kwargs): + return per_token_logps, torch.zeros_like(per_token_logps) + + trainer._get_per_token_logps_and_entropies = fake_get_per_token_logps_and_entropies + + tpo_targets = torch.tensor([0.25, 0.75]) + inputs = { + "prompt_ids": torch.zeros((2, 1), dtype=torch.long), + "prompt_mask": torch.ones((2, 1), dtype=torch.long), + "completion_ids": torch.ones((2, 2), dtype=torch.long), + "completion_mask": torch.ones((2, 2), dtype=torch.long), + "advantages": torch.zeros(2), + "old_per_token_logps": per_token_logps.detach(), + "tpo_targets": tpo_targets, + "num_items_in_batch": torch.tensor(4), + } + + loss = TargetPOTrainer._compute_loss(trainer, model=None, inputs=inputs) + loss.backward() + + completion_mask = inputs["completion_mask"].to(per_token_logps.dtype) + lengths = completion_mask.sum(dim=-1).clamp(min=1) + summed = (per_token_logps.detach() * completion_mask).sum(dim=-1) + sequence_logps = summed / lengths if length_normalize else summed + expected_sequence_grad = torch.softmax(sequence_logps, dim=0) - tpo_targets + per_token_scale = (1.0 / lengths) if length_normalize else torch.ones_like(lengths) + expected_grad = (expected_sequence_grad * per_token_scale).unsqueeze(1) * completion_mask + torch.testing.assert_close(per_token_logps.grad, expected_grad) + + def test_tpo_loss_excludes_invalid_completions_from_group_softmax(self): + trainer = object.__new__(TargetPOTrainer) + trainer.loss_type = "tpo" + trainer.off_policy_mask_threshold = None + trainer.top_entropy_quantile = 1.0 + trainer.beta = 0.0 + trainer.num_generations = 3 + trainer.num_generations_eval = 3 + trainer.current_gradient_accumulation_steps = 1 + trainer.tpo_length_normalize_logps = True + trainer.model = SimpleNamespace(training=True) + trainer.accelerator = SimpleNamespace( + gather=lambda tensor: tensor, + num_processes=1, + process_index=0, + ) + trainer._metrics = {"train": defaultdict(list)} + + per_token_logps = torch.tensor([[-0.1, -0.2], [5.0, 5.0], [-0.4, -0.6]], requires_grad=True) + + def fake_get_per_token_logps_and_entropies(*args, **kwargs): + return per_token_logps, torch.zeros_like(per_token_logps) + + trainer._get_per_token_logps_and_entropies = fake_get_per_token_logps_and_entropies + + tpo_targets = torch.tensor([0.6, 0.0, 0.4]) + tpo_valid_mask = torch.tensor([True, False, True]) + inputs = { + "prompt_ids": torch.zeros((3, 1), dtype=torch.long), + "prompt_mask": torch.ones((3, 1), dtype=torch.long), + "completion_ids": torch.ones((3, 2), dtype=torch.long), + "completion_mask": torch.tensor([[1, 1], [0, 0], [1, 1]]), + "advantages": torch.zeros(3), + "old_per_token_logps": per_token_logps.detach(), + "tpo_targets": tpo_targets, + "tpo_valid_mask": tpo_valid_mask, + "num_items_in_batch": torch.tensor(4), + } + + loss = TargetPOTrainer._compute_loss(trainer, model=None, inputs=inputs) + loss.backward() + + completion_mask = inputs["completion_mask"].to(per_token_logps.dtype) + lengths = completion_mask.sum(dim=-1).clamp(min=1) + sequence_logps = (per_token_logps.detach() * completion_mask).sum(dim=-1) / lengths + valid_sequence_logps = sequence_logps[tpo_valid_mask] + expected_valid_grad = torch.softmax(valid_sequence_logps, dim=0) - tpo_targets[tpo_valid_mask] + expected_grad = torch.zeros_like(per_token_logps) + expected_grad[0] = (expected_valid_grad[0] / lengths[0]) * completion_mask[0] + expected_grad[2] = (expected_valid_grad[1] / lengths[2]) * completion_mask[2] + torch.testing.assert_close(per_token_logps.grad, 
expected_grad) + + @pytest.mark.parametrize("chunk_size", [1, 2]) + @pytest.mark.parametrize("length_normalize", [True, False]) + @pytest.mark.parametrize("use_bias", [True, False]) + @require_liger_kernel + def test_liger_tpo_loss_matches_dense_loss(self, chunk_size, length_normalize, use_bias): + from trl.trainer.target_po_trainer import LigerFusedLinearTargetPOLoss + + torch.manual_seed(0) + hidden_states = torch.randn(3, 2, 4, requires_grad=True) + weight = torch.randn(5, 4, requires_grad=True) + bias = torch.randn(5, requires_grad=True) if use_bias else None + completion_ids = torch.tensor([[0, 1], [2, 3], [4, 0]]) + completion_mask = torch.tensor([[1, 1], [0, 0], [1, 1]], dtype=torch.float32) + tpo_targets = torch.tensor([0.6, 0.0, 0.4]) + tpo_valid_mask = torch.tensor([True, False, True]) + ref_per_token_logps = torch.tensor([[-0.2, -0.4], [-0.1, -0.3], [-0.5, -0.7]]) + old_per_token_logps = torch.tensor([[-0.3, -0.5], [-0.2, -0.4], [-0.6, -0.8]]) + beta = 0.25 + normalizer = 2.0 + + dense_hidden_states = hidden_states.detach().clone().requires_grad_(True) + dense_weight = weight.detach().clone().requires_grad_(True) + dense_bias = bias.detach().clone().requires_grad_(True) if use_bias else None + + liger_loss = LigerFusedLinearTargetPOLoss( + beta=beta, + num_generations=3, + tpo_length_normalize_logps=length_normalize, + temperature=1.0, + chunk_size=chunk_size, + ) + loss, mean_kl, mean_entropy = liger_loss( + hidden_states, + weight, + completion_ids, + completion_mask, + tpo_targets, + tpo_valid_mask, + bias=bias, + ref_per_token_logps=ref_per_token_logps, + old_per_token_logps=old_per_token_logps, + normalizer=normalizer, + ) + loss.backward() + + logits = dense_hidden_states @ dense_weight.t() + if dense_bias is not None: + logits = logits + dense_bias + log_probs = torch.log_softmax(logits.float(), dim=-1) + per_token_logps = log_probs.gather(dim=-1, index=completion_ids.unsqueeze(-1)).squeeze(-1) + sequence_logps = (per_token_logps * completion_mask).sum(dim=-1) + if length_normalize: + sequence_logps = sequence_logps / completion_mask.sum(dim=-1).clamp(min=1.0) + sequence_logps = sequence_logps.masked_fill(~tpo_valid_mask, torch.finfo(sequence_logps.dtype).min) + logps = torch.log_softmax(sequence_logps.view(-1, 3), dim=1).view(-1) + dense_loss = -(tpo_targets * logps).sum() + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + ) + dense_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + dense_loss = dense_loss / normalizer + beta * dense_kl / normalizer + dense_entropy = (-(log_probs.exp() * log_probs).sum(dim=-1) * completion_mask).sum() / completion_mask.sum() + dense_loss.backward() + + torch.testing.assert_close(loss, dense_loss.detach()) + torch.testing.assert_close(mean_kl, dense_kl.detach()) + torch.testing.assert_close(mean_entropy, dense_entropy.detach()) + torch.testing.assert_close(hidden_states.grad, dense_hidden_states.grad) + torch.testing.assert_close(weight.grad, dense_weight.grad) + if use_bias: + torch.testing.assert_close(bias.grad, dense_bias.grad) diff --git a/trl/__init__.py b/trl/__init__.py index 4947fc16c56..de60a700288 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -66,6 +66,8 @@ "SFTConfig", "SFTTrainer", "SyncRefModelCallback", + "TargetPOConfig", + "TargetPOTrainer", "WeaveCallback", "get_kbit_device_map", "get_peft_config", @@ -114,6 +116,8 @@ SFTConfig, SFTTrainer, SyncRefModelCallback, + TargetPOConfig, + TargetPOTrainer, WeaveCallback, 
get_kbit_device_map, get_peft_config, diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 1d79d365056..ecf59e35f00 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -38,6 +38,8 @@ "rloo_trainer": ["RLOOTrainer"], "sft_config": ["SFTConfig"], "sft_trainer": ["SFTTrainer"], + "target_po_config": ["TargetPOConfig"], + "target_po_trainer": ["TargetPOTrainer"], "utils": [ "disable_dropout_in_model", "ensure_master_addr_port", @@ -69,6 +71,8 @@ from .rloo_trainer import RLOOTrainer from .sft_config import SFTConfig from .sft_trainer import SFTTrainer + from .target_po_config import TargetPOConfig + from .target_po_trainer import TargetPOTrainer from .utils import ( disable_dropout_in_model, ensure_master_addr_port, diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index e5f373c19f0..1406a6ede7d 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -256,6 +256,17 @@ class GRPOConfig(_BaseConfig): - `"vespo"`: Variational Sequence-Level Soft Policy Optimization. Replaces hard clipping with a smooth, asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in the [VESPO paper](https://huggingface.co/papers/2602.10693). + - `"tpo"`: Target Policy Optimization loss. Builds a target distribution over each prompt's sampled + completions from the rollout policy probabilities and normalized rewards, then fits the current policy + to that target with cross-entropy. + tpo_target_temperature (`float`, *optional*, defaults to `1.0`): + Temperature used to build the Target Policy Optimization target distribution when `loss_type="tpo"`. + Lower values make the target more concentrated on high-scoring completions. + tpo_length_normalize_logps (`bool`, *optional*, defaults to `True`): + Whether to length-normalize sequence log-probabilities (per-token mean instead of sum) when building + the TPO target and computing the TPO loss. Recommended for real-length generations, because raw sum + sequence-logps vary over orders of magnitude across a group and make the old-policy term dominate the + target. Set to `False` to reproduce the paper's literal `p_i^old` formulation. mask_truncated_completions (`bool`, *optional*, defaults to `False`): When enabled, truncated completions are excluded from the loss calculation, preventing them from being incorrectly penalized and introducing noise during training. According to the @@ -738,7 +749,28 @@ class GRPOConfig(_BaseConfig): "paper](https://huggingface.co/papers/2602.05261)." "'vespo': Variational Sequence-Level Soft Policy Optimization. Replaces hard clipping with a smooth, " "asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in " - "the [VESPO paper](https://huggingface.co/papers/2602.10693)." + "the [VESPO paper](https://huggingface.co/papers/2602.10693). " + "'tpo': Target Policy Optimization loss. Builds a target distribution over each prompt's sampled " + "completions from rollout policy probabilities and normalized rewards, then fits the current policy to " + "that target with cross-entropy." + }, + ) + tpo_target_temperature: float = field( + default=1.0, + metadata={ + "help": "Temperature used to build the Target Policy Optimization target distribution when " + "`loss_type='tpo'`. Lower values make the target more concentrated on high-scoring completions." 
+ }, + ) + tpo_length_normalize_logps: bool = field( + default=True, + metadata={ + "help": "Whether to length-normalize sequence log-probabilities (use per-token mean instead of sum) " + "when building the TPO target distribution and computing the TPO loss. Without this, sequences of " + "different lengths have log-probabilities spanning orders of magnitude, causing the old-policy term " + "in `q_i ∝ p_i^old * exp(u_i / eta)` to dominate and collapse the target to ~one-hot on the " + "highest-old-logp completion. Set to `False` to reproduce the paper's literal sequence-probability " + "formulation." }, ) mask_truncated_completions: bool = field( @@ -943,3 +975,34 @@ def __post_init__(self): if self.delta is not None and self.use_liger_kernel: raise ValueError("Liger kernel does not support two-sided GRPO loss yet.") + + if self.tpo_target_temperature <= 0.0: + raise ValueError( + f"tpo_target_temperature must be greater than 0.0. You provided {self.tpo_target_temperature}." + ) + + if self.loss_type == "tpo": + if self.use_liger_kernel: + raise ValueError("Liger kernel does not support the TPO loss yet.") + # TPO's target distribution is normalized per prompt group, so each optimization step must contain a whole + # number of groups. That holds whenever the per-step batch (generation_batch_size // steps_per_generation) + # is divisible by num_generations. + per_step_batch = self.generation_batch_size // self.steps_per_generation + if per_step_batch % self.num_generations != 0: + raise ValueError( + f"TPO requires each optimization step to contain whole prompt groups. With " + f"steps_per_generation={self.steps_per_generation} and num_generations={self.num_generations}, " + f"the per-step batch of {per_step_batch} is not divisible by num_generations. Increase " + f"per_device_train_batch_size or reduce steps_per_generation." + ) + if ( + num_processes > 1 + and self.steps_per_generation > 1 + and self.per_device_train_batch_size % self.num_generations != 0 + ): + raise ValueError( + f"TPO with distributed multi-step generation requires each rank's per-step batch to contain " + f"whole prompt groups. With per_device_train_batch_size={self.per_device_train_batch_size} and " + f"num_generations={self.num_generations}, the per-rank batch is not divisible by " + f"num_generations. Increase per_device_train_batch_size or reduce steps_per_generation." 
+ ) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0f400e14fcd..a0579a1824d 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -41,6 +41,7 @@ from packaging.version import Version from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.nn.functional import all_gather as _all_gather_with_grad from torch.utils.data import Sampler from transformers import ( AutoModelForSequenceClassification, @@ -557,6 +558,8 @@ def __init__( self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap self.use_liger_kernel = args.use_liger_kernel self.loss_type = args.loss_type + self.tpo_target_temperature = args.tpo_target_temperature + self.tpo_length_normalize_logps = args.tpo_length_normalize_logps self.multi_objective_aggregation = args.multi_objective_aggregation self.scale_rewards = args.scale_rewards self.importance_sampling_level = args.importance_sampling_level @@ -1138,7 +1141,8 @@ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> di # self._buffered_inputs=None can occur when resuming from a checkpoint generation_batch = self._generate_and_score_completions(generation_batch) generation_batch = split_pixel_values_by_grid(generation_batch) - generation_batch = shuffle_sequence_dict(generation_batch) + if self.loss_type != "tpo": + generation_batch = shuffle_sequence_dict(generation_batch) generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation) self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches] inputs = self._buffered_inputs[self._step % self.args.steps_per_generation] @@ -2026,8 +2030,10 @@ def _generate_and_score_completions( # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the # distribution mismatch between vLLM and the training model can be large and harm the training. generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency - if self.args.gradient_accumulation_steps % generate_every != 0 or ( - self.use_vllm and self.vllm_importance_sampling_correction + if ( + self.loss_type == "tpo" + or self.args.gradient_accumulation_steps % generate_every != 0 + or (self.use_vllm and self.vllm_importance_sampling_correction) ): old_per_token_logps, _ = self._get_per_token_logps_and_entropies( self.model, @@ -2172,6 +2178,34 @@ def _generate_and_score_completions( (self.accelerator.process_index + 1) * len(prompts), ) all_process_advantages = advantages.clone() # keep the aggregated advantages for logging + tpo_targets = None + tpo_valid_mask = None + if self.loss_type == "tpo": + if old_per_token_logps is None: + raise RuntimeError("TPO requires rollout-time log probabilities to build the target distribution.") + loss_mask = completion_mask if tool_mask is None else completion_mask * tool_mask + tpo_valid_mask = loss_mask.any(dim=-1) + # Length-normalize sequence logps so the target distribution isn't dominated by length variance + # (without this, log_softmax(old_sequence_logps) collapses to ~one-hot on the longest/most-likely + # completion, and rewards can't compete). 
+ old_sequence_logps = (old_per_token_logps * loss_mask).sum(dim=-1) + if self.tpo_length_normalize_logps: + old_sequence_logps = old_sequence_logps / loss_mask.sum(dim=-1).clamp(min=1) + all_process_old_sequence_logps = gather(old_sequence_logps) + all_process_tpo_valid_mask = gather(tpo_valid_mask) + tpo_scores = self.get_tpo_scores(rewards, num_generations, valid_mask=all_process_tpo_valid_mask) + all_process_tpo_targets = self.get_tpo_targets( + all_process_old_sequence_logps, + tpo_scores, + num_generations=num_generations, + temperature=self.tpo_target_temperature, + valid_mask=all_process_tpo_valid_mask, + ) + tpo_targets = all_process_tpo_targets[process_slice] + target_groups = all_process_tpo_targets.view(-1, num_generations) + target_entropy = -(target_groups * target_groups.clamp_min(torch.finfo(target_groups.dtype).tiny).log()) + target_entropy = target_entropy.sum(dim=1).mean() + self._metrics[mode]["tpo/target_entropy"].append(target_entropy.item()) advantages = advantages[process_slice] # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) @@ -2258,6 +2292,9 @@ def _generate_and_score_completions( } if old_per_token_logps is not None: output["old_per_token_logps"] = old_per_token_logps + if tpo_targets is not None: + output["tpo_targets"] = tpo_targets + output["tpo_valid_mask"] = tpo_valid_mask if self.use_vllm and self.vllm_importance_sampling_correction: output["importance_sampling_ratio"] = vllm_importance_sampling_ratio if sampling_per_token_logps is not None: @@ -2365,6 +2402,63 @@ def get_off_policy_mask( is_low_kl = avg_seq_kl <= off_policy_threshold return (is_pos_adv | is_low_kl).to(dtype=mask.dtype) # (B, 1) + @staticmethod + def get_tpo_targets( + old_sequence_logps: torch.Tensor, + scores: torch.Tensor, + num_generations: int, + temperature: float = 1.0, + valid_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Build the Target Policy Optimization target distribution for each prompt group. + + The target is q_i ∝ p_i_old * exp(score_i / temperature), where p_i_old is the rollout policy distribution + over the sampled completions in a prompt group. Completions with `valid_mask=False` (e.g. all tokens + masked out) get zero target probability and don't contribute to the group softmax. + """ + if temperature <= 0.0: + raise ValueError(f"temperature must be greater than 0.0. 
You provided {temperature}.") + + old_sequence_logps = old_sequence_logps.view(-1, num_generations) + scores = scores.view(-1, num_generations) + if valid_mask is not None: + valid_mask = valid_mask.view(-1, num_generations).bool() + old_sequence_logps = old_sequence_logps.masked_fill(~valid_mask, torch.finfo(old_sequence_logps.dtype).min) + target_logits = torch.log_softmax(old_sequence_logps, dim=1) + scores / temperature + if valid_mask is not None: + target_logits = target_logits.masked_fill(~valid_mask, torch.finfo(target_logits.dtype).min) + targets = torch.softmax(target_logits, dim=1) + if valid_mask is not None: + targets = torch.where(valid_mask, targets, torch.zeros_like(targets)) + return targets.view(-1).detach() + + @staticmethod + def get_tpo_scores( + scores: torch.Tensor, num_generations: int, valid_mask: torch.Tensor | None = None + ) -> torch.Tensor: + scores = scores.view(-1, num_generations) + if valid_mask is not None: + valid_mask = valid_mask.view(-1, num_generations).bool() + valid_count = valid_mask.sum(dim=1, keepdim=True).clamp(min=1) + mean_scores = scores.masked_fill(~valid_mask, 0.0).sum(dim=1, keepdim=True) / valid_count + centered_scores = torch.where(valid_mask, scores - mean_scores, torch.zeros_like(scores)) + std_scores = (centered_scores.square().sum(dim=1, keepdim=True) / valid_count).sqrt() + else: + mean_scores = scores.mean(dim=1, keepdim=True) + std_scores = scores.std(dim=1, unbiased=False, keepdim=True) + centered_scores = scores - mean_scores + scores = torch.where(std_scores > 1e-6, centered_scores / std_scores, centered_scores) + return scores.view(-1) + + @staticmethod + def _gather_tensor_with_grad(tensor: torch.Tensor) -> torch.Tensor: + # Autograd-aware all_gather: required when a TPO prompt group spans DP ranks, so the log-softmax + # normalizer's gradient routes back to the owning rank. + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return tensor + return torch.cat(_all_gather_with_grad(tensor), dim=0) + @staticmethod @torch.no_grad() def get_gamma_weights( @@ -2460,6 +2554,62 @@ def _compute_loss(self, model, inputs): old_per_token_logps = inputs.get("old_per_token_logps") old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + if self.loss_type == "tpo": + if "tpo_targets" not in inputs: + raise RuntimeError("TPO loss requires `tpo_targets` in the prepared inputs.") + if self.off_policy_mask_threshold is not None: + raise ValueError("TPO loss does not support `off_policy_mask_threshold`.") + if self.top_entropy_quantile < 1.0: + raise ValueError("TPO loss does not support `top_entropy_quantile < 1.0`.") + + mode = "train" if self.model.training else "eval" + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + # Length-normalized sequence logp (see _generate_and_score_completions for why). 
+ sequence_logps = (per_token_logps * mask).sum(dim=-1) + if self.tpo_length_normalize_logps: + sequence_logps = sequence_logps / mask.sum(dim=-1).clamp(min=1) + all_sequence_logps = self._gather_tensor_with_grad(sequence_logps) + tpo_valid_mask = inputs.get("tpo_valid_mask") + if tpo_valid_mask is not None: + tpo_valid_mask = tpo_valid_mask.to(device=sequence_logps.device, dtype=torch.bool) + all_tpo_valid_mask = gather(tpo_valid_mask) + all_sequence_logps = all_sequence_logps.masked_fill( + ~all_tpo_valid_mask, torch.finfo(all_sequence_logps.dtype).min + ) + all_logps = torch.log_softmax(all_sequence_logps.view(-1, num_generations), dim=1).view(-1) + process_slice = slice( + self.accelerator.process_index * sequence_logps.size(0), + (self.accelerator.process_index + 1) * sequence_logps.size(0), + ) + logps = all_logps[process_slice] + tpo_targets = inputs["tpo_targets"].to(logps.dtype) + if tpo_valid_mask is not None: + tpo_targets = torch.where(tpo_valid_mask, tpo_targets, torch.zeros_like(tpo_targets)) + loss = -(tpo_targets * logps).sum() * num_generations / tpo_targets.numel() + + normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 + loss = loss / normalizer + + if self.beta != 0.0: + ref_per_token_logps = inputs["ref_per_token_logps"] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + ) + if self.args.use_bias_correction_kl: + per_token_kl = per_token_kl * torch.exp(per_token_logps - old_per_token_logps) + kl_normalizer = ( + self.accelerator.gather(mask.sum()).sum().clamp(min=1.0) / self.accelerator.num_processes + ) + kl_loss = (per_token_kl * mask).sum() / kl_normalizer + loss = loss + self.beta * kl_loss / normalizer + self._metrics[mode]["kl"].append(self.accelerator.gather(kl_loss).nanmean().item()) + + completion_token_count = mask.sum().clamp(min=1.0) + mean_entropy = (entropies * mask).sum() / completion_token_count + self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) + + return loss + if self.off_policy_mask_threshold is not None: # OPSM should use inference-time logprobs to detect both sources of off-policyness: # 1. Drift from gradient updates (always present) diff --git a/trl/trainer/target_po_config.py b/trl/trainer/target_po_config.py new file mode 100644 index 00000000000..0879102f3b6 --- /dev/null +++ b/trl/trainer/target_po_config.py @@ -0,0 +1,55 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from .grpo_config import GRPOConfig + + +@dataclass +class TargetPOConfig(GRPOConfig): + r""" + Configuration class for the [`TargetPOTrainer`]. + + This class extends [`GRPOConfig`] with defaults for Target Policy Optimization. For a full list of training + arguments, please refer to [`~transformers.TrainingArguments`] and [`GRPOConfig`]. 
+ + Unless the user passes `generation_batch_size` or `steps_per_generation`, `TargetPOConfig` defaults + `steps_per_generation` to `1` as a safe starting point. Values greater than `1` are supported as long as each + optimization step still contains whole prompt groups (see [`GRPOConfig`] for the divisibility check). + + Parameters: + loss_type (`str`, *optional*, defaults to `"tpo"`): + Loss formulation. `TargetPOConfig` requires this to stay set to `"tpo"`. + tpo_target_temperature (`float`, *optional*, defaults to `1.0`): + Temperature used to build the TPO target distribution. Lower values make the target more concentrated on + high-scoring completions. + """ + + loss_type: str = field( + default="tpo", + metadata={"help": "Loss formulation. `TargetPOConfig` requires this to stay set to `tpo`."}, + ) + + def __post_init__(self): + if self.generation_batch_size is None and self.steps_per_generation is None: + self.steps_per_generation = 1 + + use_liger_kernel = self.use_liger_kernel + self.use_liger_kernel = False + super().__post_init__() + self.use_liger_kernel = use_liger_kernel + + if self.loss_type != "tpo": + raise ValueError(f"TargetPOConfig requires loss_type='tpo'. You provided {self.loss_type!r}.") diff --git a/trl/trainer/target_po_trainer.py b/trl/trainer/target_po_trainer.py new file mode 100644 index 00000000000..fef993b6ea8 --- /dev/null +++ b/trl/trainer/target_po_trainer.py @@ -0,0 +1,3233 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import atexit +import copy +import importlib.resources as pkg_resources +import inspect +import math +import os +import sys +import textwrap +import time +import warnings +from collections import defaultdict, deque +from collections.abc import Callable +from contextlib import nullcontext +from pathlib import Path +from typing import Any, Protocol + +import numpy as np +import pandas as pd +import torch +import torch.utils.data +import transformers +from accelerate.logging import get_logger +from accelerate.utils import gather, gather_object, is_peft_model, set_seed +from datasets import Dataset, IterableDataset +from huggingface_hub import CommitScheduler, DatasetCard, DatasetCardData, create_repo +from packaging.version import Version +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.nn.functional import all_gather as _all_gather_with_grad +from torch.utils.data import Sampler +from transformers import ( + AutoModelForSequenceClassification, + AutoProcessor, + AutoTokenizer, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, + is_trackio_available, + is_wandb_available, +) +from transformers.utils import is_peft_available, is_rich_available + +from ..chat_template_utils import ( + add_response_schema, + get_training_chat_template, + is_chat_template_prefix_preserving, + parse_response, + supports_tool_calling, +) +from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages +from ..extras.profiling import profiling_context, profiling_decorator +from ..generation.vllm_generation import VLLMGeneration +from ..import_utils import is_jmespath_available, is_liger_kernel_available +from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation +from ..models.utils import _ForwardRedirection, disable_gradient_checkpointing +from .base_trainer import _BaseTrainer +from .callbacks import SyncRefModelCallback +from .target_po_config import TargetPOConfig +from .utils import ( + RepeatSampler, + create_model_from_path, + disable_dropout_in_model, + entropy_from_logits, + get_config_model_id, + identity, + nanmax, + nanmin, + nanstd, + pad, + print_prompt_completions_sample, + selective_log_softmax, + shuffle_sequence_dict, + shutdown_event_loop_in_daemon, + split_pixel_values_by_grid, + split_tensor_dict, + start_event_loop_in_daemon, + unsplit_pixel_values_by_grid, + use_adapter, +) + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model + +if is_liger_kernel_available(): + from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase + + + class LigerFusedLinearTargetPOFunction(torch.autograd.Function): + @staticmethod + def _get_process_info(): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return torch.distributed.get_rank(), torch.distributed.get_world_size() + return 0, 1 + + @staticmethod + def _gather_tensor(tensor: torch.Tensor) -> torch.Tensor: + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return tensor + gathered = [torch.empty_like(tensor) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(gathered, tensor) + return torch.stack(gathered) if tensor.dim() == 0 else torch.cat(gathered, dim=0) + + @staticmethod + def _chunk_logps_and_entropies( + input_chunk, weight, selected_token_ids_chunk, bias, temperature, compute_entropy + ): + log_probs, _ = 
LigerFusedLinearPPOBase.chunk_forward( + input_chunk, weight, bias=bias, temperature=temperature + ) + per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids_chunk.unsqueeze(-1)).squeeze(-1) + entropies = -(log_probs.exp() * log_probs).sum(dim=-1) if compute_entropy else None + return per_token_logps, entropies + + @staticmethod + def _compute_surrogate_loss( + input_chunk, + weight, + selected_token_ids_chunk, + attention_mask_chunk, + tpo_per_token_grads_chunk, + bias, + ref_per_token_logps_chunk, + old_per_token_logps_chunk, + beta, + kl_normalizer, + use_bias_correction_kl, + temperature, + ): + per_token_logps, _ = LigerFusedLinearTargetPOFunction._chunk_logps_and_entropies( + input_chunk, weight, selected_token_ids_chunk, bias, temperature, compute_entropy=False + ) + loss = (per_token_logps * tpo_per_token_grads_chunk).sum() + + if beta != 0.0: + if ref_per_token_logps_chunk is None: + raise RuntimeError("TargetPO Liger loss requires `ref_per_token_logps` when beta is non-zero.") + ref_per_token_logps_chunk = ref_per_token_logps_chunk.float() + old_per_token_logps_chunk = ( + per_token_logps.detach() + if old_per_token_logps_chunk is None + else old_per_token_logps_chunk.float() + ) + per_token_kl = ( + torch.exp(ref_per_token_logps_chunk - per_token_logps) + - (ref_per_token_logps_chunk - per_token_logps) + - 1 + ) + if use_bias_correction_kl: + per_token_kl = per_token_kl * torch.exp(per_token_logps - old_per_token_logps_chunk) + loss = loss + beta * (per_token_kl * attention_mask_chunk).sum() / kl_normalizer + + return loss + + @staticmethod + def forward( + ctx, + _input, + weight, + selected_token_ids, + attention_mask, + tpo_targets, + tpo_valid_mask, + bias=None, + ref_per_token_logps=None, + old_per_token_logps=None, + beta=0.0, + num_generations=1, + tpo_length_normalize_logps=True, + use_bias_correction_kl=False, + normalizer=1.0, + temperature=1.0, + chunk_size=1, + ): + chunk_size = max(1, int(chunk_size)) + chunks = max(1, math.ceil(_input.shape[0] / chunk_size)) + input_chunks = torch.chunk(_input, chunks=chunks, dim=0) + selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0) + attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0) + ref_per_token_logps_chunks = ( + torch.chunk(ref_per_token_logps, chunks=chunks, dim=0) + if ref_per_token_logps is not None + else [None] * chunks + ) + old_per_token_logps_chunks = ( + torch.chunk(old_per_token_logps, chunks=chunks, dim=0) + if old_per_token_logps is not None + else [None] * chunks + ) + + per_token_logps_chunks = [] + entropy_sum = torch.zeros((), device=_input.device, dtype=torch.float32) + with torch.no_grad(): + for input_chunk, selected_token_ids_chunk, attention_mask_chunk in zip( + input_chunks, selected_token_ids_chunks, attention_mask_chunks, strict=True + ): + per_token_logps_chunk, entropies_chunk = ( + LigerFusedLinearTargetPOFunction._chunk_logps_and_entropies( + input_chunk, weight, selected_token_ids_chunk, bias, temperature, compute_entropy=True + ) + ) + loss_mask_chunk = attention_mask_chunk.to(per_token_logps_chunk.dtype) + per_token_logps_chunks.append(per_token_logps_chunk.float()) + entropy_sum.add_((entropies_chunk * loss_mask_chunk).sum()) + + per_token_logps = torch.cat(per_token_logps_chunks, dim=0) + loss_mask = attention_mask.to(per_token_logps.dtype) + tpo_targets = torch.where(tpo_valid_mask, tpo_targets, torch.zeros_like(tpo_targets)) + tpo_targets = tpo_targets.to(per_token_logps.dtype) + + sequence_lengths = 
loss_mask.sum(dim=-1).clamp(min=1.0) + sequence_logps = (per_token_logps * loss_mask).sum(dim=-1) + if tpo_length_normalize_logps: + sequence_logps = sequence_logps / sequence_lengths + + rank, world_size = LigerFusedLinearTargetPOFunction._get_process_info() + all_sequence_logps = LigerFusedLinearTargetPOFunction._gather_tensor(sequence_logps) + all_tpo_targets = LigerFusedLinearTargetPOFunction._gather_tensor(tpo_targets) + all_tpo_valid_mask = LigerFusedLinearTargetPOFunction._gather_tensor(tpo_valid_mask) + all_sequence_logps = all_sequence_logps.masked_fill( + ~all_tpo_valid_mask, torch.finfo(all_sequence_logps.dtype).min + ) + + all_logps = torch.log_softmax(all_sequence_logps.view(-1, num_generations), dim=1).view(-1) + process_slice = slice(rank * sequence_logps.size(0), (rank + 1) * sequence_logps.size(0)) + logps = all_logps[process_slice] + loss_scale = num_generations / tpo_targets.numel() + tpo_loss = -(tpo_targets * logps).sum() * loss_scale + + all_probs = torch.softmax(all_sequence_logps.view(-1, num_generations), dim=1) + target_groups = all_tpo_targets.view(-1, num_generations) + target_group_sums = target_groups.sum(dim=1, keepdim=True) + all_sequence_grads = (all_probs * target_group_sums - target_groups) * loss_scale + sequence_grads = all_sequence_grads.view(-1)[process_slice] + if tpo_length_normalize_logps: + tpo_per_token_grads = sequence_grads.unsqueeze(1) * loss_mask / sequence_lengths.unsqueeze(1) + else: + tpo_per_token_grads = sequence_grads.unsqueeze(1) * loss_mask + tpo_per_token_grads = tpo_per_token_grads / normalizer + + kl_normalizer = loss_mask.sum() + kl_normalizer = LigerFusedLinearTargetPOFunction._gather_tensor(kl_normalizer).sum() / world_size + kl_normalizer = kl_normalizer.clamp(min=1.0) + + kl_loss = torch.zeros((), device=_input.device, dtype=torch.float32) + if beta != 0.0: + if ref_per_token_logps is None: + raise RuntimeError("TargetPO Liger loss requires `ref_per_token_logps` when beta is non-zero.") + old_per_token_logps_for_kl = ( + per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps.float() + ) + per_token_kl = ( + torch.exp(ref_per_token_logps.float() - per_token_logps) + - (ref_per_token_logps.float() - per_token_logps) + - 1 + ) + if use_bias_correction_kl: + per_token_kl = per_token_kl * torch.exp(per_token_logps - old_per_token_logps_for_kl) + kl_loss = (per_token_kl * loss_mask).sum() / kl_normalizer + + loss = tpo_loss / normalizer + beta * kl_loss / normalizer + mean_entropy = entropy_sum / loss_mask.sum().clamp(min=1.0) + + tpo_per_token_grad_chunks = torch.chunk(tpo_per_token_grads, chunks=chunks, dim=0) + grad_weight = torch.zeros_like(weight) + grad_inputs = [] + grad_bias = torch.zeros_like(bias) if bias is not None else None + + for ( + input_chunk, + selected_token_ids_chunk, + attention_mask_chunk, + tpo_per_token_grads_chunk, + ref_per_token_logps_chunk, + old_per_token_logps_chunk, + ) in zip( + input_chunks, + selected_token_ids_chunks, + attention_mask_chunks, + tpo_per_token_grad_chunks, + ref_per_token_logps_chunks, + old_per_token_logps_chunks, + strict=True, + ): + loss_mask_chunk = attention_mask_chunk.to(tpo_per_token_grads_chunk.dtype) + if bias is not None: + + def compute_surrogate( + input_chunk, + weight, + bias, + selected_token_ids_chunk, + loss_mask_chunk, + tpo_per_token_grads_chunk, + ref_per_token_logps_chunk, + old_per_token_logps_chunk, + ): + return LigerFusedLinearTargetPOFunction._compute_surrogate_loss( + input_chunk, + weight, + selected_token_ids_chunk, + 
loss_mask_chunk, + tpo_per_token_grads_chunk, + bias, + ref_per_token_logps_chunk, + old_per_token_logps_chunk, + beta / normalizer, + kl_normalizer, + use_bias_correction_kl, + temperature, + ) + + chunk_grad_input, chunk_grad_weight, chunk_grad_bias = torch.func.grad( + compute_surrogate, argnums=(0, 1, 2) + )( + input_chunk, + weight, + bias, + selected_token_ids_chunk, + loss_mask_chunk, + tpo_per_token_grads_chunk, + ref_per_token_logps_chunk, + old_per_token_logps_chunk, + ) + grad_bias.add_(chunk_grad_bias) + else: + + def compute_surrogate( + input_chunk, + weight, + selected_token_ids_chunk, + loss_mask_chunk, + tpo_per_token_grads_chunk, + ref_per_token_logps_chunk, + old_per_token_logps_chunk, + ): + return LigerFusedLinearTargetPOFunction._compute_surrogate_loss( + input_chunk, + weight, + selected_token_ids_chunk, + loss_mask_chunk, + tpo_per_token_grads_chunk, + None, + ref_per_token_logps_chunk, + old_per_token_logps_chunk, + beta / normalizer, + kl_normalizer, + use_bias_correction_kl, + temperature, + ) + + chunk_grad_input, chunk_grad_weight = torch.func.grad(compute_surrogate, argnums=(0, 1))( + input_chunk, + weight, + selected_token_ids_chunk, + loss_mask_chunk, + tpo_per_token_grads_chunk, + ref_per_token_logps_chunk, + old_per_token_logps_chunk, + ) + + grad_inputs.append(chunk_grad_input) + grad_weight.add_(chunk_grad_weight) + + grad_input = torch.cat(grad_inputs, dim=0) + if bias is not None: + ctx.save_for_backward(grad_input, grad_weight, grad_bias) + else: + ctx.save_for_backward(grad_input, grad_weight) + ctx.has_bias = bias is not None + return loss, kl_loss, mean_entropy + + @staticmethod + def backward(ctx, grad_output, *grad_metrics): + if ctx.has_bias: + grad_input, grad_weight, grad_bias = ctx.saved_tensors + else: + grad_input, grad_weight = ctx.saved_tensors + grad_bias = None + + grad_input = grad_input * grad_output + grad_weight = grad_weight * grad_output + if grad_bias is not None: + grad_bias = grad_bias * grad_output + + return ( + grad_input, + grad_weight, + None, # selected_token_ids + None, # attention_mask + None, # tpo_targets + None, # tpo_valid_mask + grad_bias, + None, # ref_per_token_logps + None, # old_per_token_logps + None, # beta + None, # num_generations + None, # tpo_length_normalize_logps + None, # use_bias_correction_kl + None, # normalizer + None, # temperature + None, # chunk_size + ) + + + class LigerFusedLinearTargetPOLoss(nn.Module): + def __init__( + self, + beta: float = 0.0, + num_generations: int = 1, + tpo_length_normalize_logps: bool = True, + use_bias_correction_kl: bool = False, + temperature: float = 1.0, + chunk_size: int = 1, + ): + super().__init__() + self.beta = beta + self.num_generations = num_generations + self.tpo_length_normalize_logps = tpo_length_normalize_logps + self.use_bias_correction_kl = use_bias_correction_kl + self.temperature = temperature + self.chunk_size = chunk_size + + def forward( + self, + _input, + lin_weight, + selected_token_ids, + attention_mask, + tpo_targets, + tpo_valid_mask=None, + bias=None, + ref_per_token_logps=None, + old_per_token_logps=None, + normalizer=1.0, + ): + if tpo_valid_mask is None: + tpo_valid_mask = torch.ones_like(tpo_targets, dtype=torch.bool) + return LigerFusedLinearTargetPOFunction.apply( + _input, + lin_weight, + selected_token_ids, + attention_mask, + tpo_targets, + tpo_valid_mask, + bias, + ref_per_token_logps, + old_per_token_logps, + self.beta, + self.num_generations, + self.tpo_length_normalize_logps, + self.use_bias_correction_kl, + normalizer, + 
self.temperature,
+            self.chunk_size,
+        )
+
+
+if is_wandb_available():
+    import wandb
+
+if is_trackio_available():
+    import trackio
+
+
+logger = get_logger(__name__)
+
+# A reward function can be a string, interpreted as a model ID and loaded as a pretrained model, a pretrained model, or
+# a callable that returns a list of floats (the rewards). The callable receives prompts, completions, and additional
+# arguments from the trainer (refer to the trainer's source for details). To ensure forward compatibility, it should
+# accept **kwargs.
+RewardFunc = str | PreTrainedModel | Callable[..., list[float | None]]
+
+# What we call a rollout function is a callable that takes prompts (list) and the trainer instance as parameters and
+# returns a dict of generation results. Those results must include "prompt_ids", "completion_ids", and "logprobs"
+# fields. Any extra fields (per-completion) are forwarded to the reward functions.
+RolloutFunc = Callable[[list[str], "TargetPOTrainer"], dict[str, Any]]
+
+
+class _SupportsReset(Protocol):
+    def reset(self, **kwargs) -> str | None: ...
+
+
+EnvironmentFactory = Callable[[], _SupportsReset]
+
+
+class TargetPOTrainer(_BaseTrainer):
+    """
+    Trainer for the Target Policy Optimization (TargetPO) method, from the paper [Target Policy
+    Optimization](https://huggingface.co/papers/2604.06159).
+
+    Example:
+
+    ```python
+    from trl import TargetPOTrainer
+    from trl.rewards import accuracy_reward
+    from datasets import load_dataset
+
+    dataset = load_dataset("trl-lib/DeepMath-103K", split="train")
+
+    trainer = TargetPOTrainer(
+        model="Qwen/Qwen2.5-0.5B-Instruct",
+        reward_funcs=accuracy_reward,
+        train_dataset=dataset,
+    )
+    trainer.train()
+    ```
+
+    Args:
+        model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]):
+            Model to be trained. Can be either:
+
+            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
+              path to a *directory* containing model weights saved using
+              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
+              using the `from_pretrained` method of the model class derived from the model config, with the keyword
+              arguments in `args.model_init_kwargs`.
+            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
+            - A [`~peft.PeftModel`] object. Only causal language models are supported.
+        reward_funcs (`RewardFunc | list[RewardFunc]`):
+            Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
+            functions with the prompts and completions and sum the rewards. Can be either:
+
+            - A single reward function, such as:
+                - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
+                  path to a *directory* containing model weights saved using
+                  [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
+                  using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
+                  keyword arguments in `args.model_init_kwargs`.
+                - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
+                - A custom reward function: The function is provided with the prompts and the generated completions,
+                  plus any additional columns in the dataset. It should return a list of rewards. Custom reward
+                  functions can be either synchronous or asynchronous and can also return `None` when the reward is
+                  not applicable to those samples. This is useful for multi-task training where different reward
+                  functions apply to different types of samples. When a reward function returns `None` for a sample,
+                  that reward function is excluded from the reward calculation for that sample. For more details, see
+                  [Using a custom reward function](#using-a-custom-reward-function).
+
+                  The trainer's state is also passed to the reward function. The trainer's state is an instance of
+                  [`~transformers.TrainerState`] and can be accessed via the `trainer_state` argument in the reward
+                  function's signature.
+            - A list of reward functions, where each item can independently be any of the above types. Mixing different
+              types within the list (e.g., a string model ID and a custom reward function) is allowed.
+        args ([`TargetPOConfig`], *optional*):
+            Configuration for this trainer. If `None`, a default configuration is used.
+        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
+            Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
+            ignored. The format of the samples can be either:
+
+            - [Standard](dataset_formats#standard): Each sample contains plain text.
+            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
+              and content).
+        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`):
+            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
+        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
+            Processing class used to process the data. The padding side must be set to "left". If `None`, the
+            processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
+            padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
+            `tokenizer.eos_token` will be used as the default.
+        reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*):
+            Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
+
+            - A single processing class: Used when `reward_funcs` contains only one reward function.
+            - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
+            If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
+            `None`, the tokenizer for the model is automatically loaded using
+            [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward
+            functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in `reward_processing_classes`
+            are ignored.
+        callbacks (list of [`~transformers.TrainerCallback`], *optional*):
+            List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
+            in [here](https://huggingface.co/docs/transformers/main_classes/callback).
+
+            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
+            method.
+        optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`):
+            A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
+            model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
+ peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + tools (list of `Callable`, *optional*): + A list of callable tool functions (sync or async) that the model can invoke during generation. Each tool + should be a standard Python function with properly type-hinted arguments and return values, and a + Google-style docstring describing its purpose, arguments, and return value. For more details, see: + https://huggingface.co/docs/transformers/en/chat_extras#passing-tools. The model uses the function's name, + type hints, and docstring to determine how to call it. Ensure that the model's chat template supports tool + use and that it has been fine-tuned for tool calling. + rollout_func (`RolloutFunc`, *optional*): + Function to use for generating completions. It receives the list of prompts allocated to the current + process and the trainer instance. It must return a dict with `"prompt_ids"`, `"completion_ids"`, and + `"logprobs"` fields, and can optionally return `"logprob_token_ids"` (same shape as `"logprobs"`). Any + other fields are forwarded to the reward functions. The function receives the raw per-process prompt slice + with no duplication; it is responsible for returning the correct number of completions per prompt (see + `num_generations` / `num_generations_eval` on the trainer). This feature is experimental and may change or + be removed at any time without prior notice. + environment_factory (`EnvironmentFactory`, *optional*): + A callable that creates and returns an environment instance. The environment class should define methods + that can be invoked as tools during generation. Each method should comply with the same requirements as the + `tools` described above. If `environment_factory` is provided, an instance of the environment is created + for each generation in the batch, allowing for parallel and independent interactions. The environment must + also implement a callable `reset` method that can be used to reset state between generations. The `reset` + method should return either `None` or a string: when it returns a string, that string is appended to the + last user message before generation. This feature is experimental and may change or be removed at any time + without prior notice. 
+ """ + + _tag_names = ["trl", "tpo"] + _name = "TargetPO" + _paper = { + "title": "Target Policy Optimization", + "id": "2604.06159", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @misc{kaddour2026targetpolicyoptimization, + title = {{Target Policy Optimization}}, + author = {Jean Kaddour}, + year = 2026, + eprint = {arXiv:2604.06159}, + }"""), + } + + def __init__( + self, + model: "str | PreTrainedModel | PeftModel", + reward_funcs: RewardFunc | list[RewardFunc], + args: TargetPOConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + peft_config: "PeftConfig | None" = None, + tools: list[Callable] | None = None, + rollout_func: RolloutFunc | None = None, + environment_factory: EnvironmentFactory | None = None, + ): + # Args + if args is None: + model_name = model if isinstance(model, str) else get_config_model_id(model.config) + model_name = model_name.split("/")[-1] + args = TargetPOConfig(f"{model_name}-TargetPO") + + # Model + if isinstance(model, str): + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) + else: + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `TargetPOConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + + # Some models (SmolVLM/Idefics3) don't support `logits_to_keep` argument and error out if we pass it + # Inspect the forward method before we wrap the model with PEFT + self.model_kwarg_keys = ( + inspect.signature(model.forward).parameters.keys() + if not hasattr(model, "get_base_model") + else inspect.signature(model.get_base_model().forward).parameters.keys() + ) + + # Processing class + if processing_class is None: + processing_class = AutoProcessor.from_pretrained( + get_config_model_id(model.config), truncation_side="left", padding_side="left" + ) + + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + self._is_vlm = True + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + self._is_vlm = False + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + # Resolve vision placeholder token IDs once. Used by the forward pass to rebuild mm_token_type_ids + # when tool responses inject images into the completion (see _generate forward_kwargs block). 
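+        # For example, Qwen2-VL-style processors use "<|image_pad|>" (and "<|video_pad|>") as placeholders, while some
+        # other processors expose a plain "<|image|>" token, which is why both image candidates are tried below.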
+ self._image_pad_token_id = None + self._video_pad_token_id = None + if self._is_vlm: + for candidate in ("<|image_pad|>", "<|image|>"): + tid = tokenizer.convert_tokens_to_ids(candidate) + if tid != tokenizer.unk_token_id: + self._image_pad_token_id = tid + break + tid = tokenizer.convert_tokens_to_ids("<|video_pad|>") + if tid != tokenizer.unk_token_id: + self._video_pad_token_id = tid + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + self.pad_token_id = tokenizer.pad_token_id + self.eos_token_id = tokenizer.eos_token_id + + if is_peft_available() and is_peft_model(model) and peft_config is not None: + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) + if is_peft_available() and is_peft_model(model) and args.beta != 0.0: + # If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy + # of the "default" adapter, so that we can use it as the reference model during GRPO training. + model.add_adapter("ref", model.peft_config["default"]) + for name, param in model.named_parameters(): + if ".default." in name: + ref_name = name.replace(".default.", ".ref.") + ref_param = model.get_parameter(ref_name) + ref_param.data.copy_(param.data) + + # Create PEFT model + if peft_config is not None: + model = get_peft_model(model, peft_config) + + # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally + # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 + if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing: + model.enable_input_require_grads() + + # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the + # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by + # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for + # quantized models. 
See: https://github.com/huggingface/peft/issues/2889 + # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do + if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): + for param in model.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.bfloat16) + + # Reward functions + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + self.reward_func_names = [] + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( + reward_func, num_labels=1, **model_init_kwargs + ) + if isinstance(reward_funcs[i], nn.Module): # Use Module over PretrainedModel for compat w/ compiled models + self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) + else: + self.reward_func_names.append(reward_funcs[i].__name__) + self.reward_funcs = reward_funcs + + # Reward weights + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError( + f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " + f"functions ({len(reward_funcs)})" + ) + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + + # Reward processing class + if reward_processing_classes is None: + reward_processing_classes = [None] * len(reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + if len(reward_processing_classes) != len(reward_funcs): + raise ValueError( + f"The number of reward processing classes ({len(reward_processing_classes)}) must match the number of " + f"reward functions ({len(reward_funcs)})." + ) + + for i, (reward_processing_class, reward_func) in enumerate( + zip(reward_processing_classes, reward_funcs, strict=True) + ): + if isinstance(reward_func, PreTrainedModel): + if reward_processing_class is None: + reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config)) + if reward_processing_class.pad_token_id is None: + reward_processing_class.pad_token = reward_processing_class.eos_token + # The reward model computes the reward for the latest non-padded token in the input sequence. + # So it's important to set the pad token ID to the padding token ID of the processing class. + reward_func.config.pad_token_id = reward_processing_class.pad_token_id + reward_processing_classes[i] = reward_processing_class + + self.reward_processing_classes = reward_processing_classes + + # Rollout function + if rollout_func is not None and os.environ.get("TRL_EXPERIMENTAL_SILENCE", "0") != "1": + warnings.warn( + "You are using 'rollout_func', which is an experimental feature. This API may change or be removed at " + "any time without prior notice. 
Silence this warning by setting environment variable " + "TRL_EXPERIMENTAL_SILENCE=1.", + UserWarning, + stacklevel=2, + ) + self.rollout_func = rollout_func + if environment_factory is not None and os.environ.get("TRL_EXPERIMENTAL_SILENCE", "0") != "1": + warnings.warn( + "You are using 'environment_factory', which is an experimental feature. This API may change or be " + "removed at any time without prior notice. Silence this warning by setting environment variable " + "TRL_EXPERIMENTAL_SILENCE=1.", + UserWarning, + stacklevel=2, + ) + + # Tools + if tools: + if not Version(transformers.__version__) >= Version("5.0.0"): + raise ImportError( + "Using tools with TargetPOTrainer requires transformers version 5.0.0 or higher. Please upgrade " + "transformers with `pip install --upgrade transformers` to use this feature." + ) + if environment_factory: + if not Version(transformers.__version__) >= Version("5.2.0"): + raise ImportError( + "Using `environment_factory` with TargetPOTrainer requires transformers version 5.2.0 or higher. " + "Please install transformers from the main branch with `pip install " + "git+https://github.com/huggingface/transformers.git@main` to use this feature." + ) + if tools or environment_factory: + if not is_jmespath_available(): + raise ImportError( + "Using tools with TargetPOTrainer requires the jmespath library for response parsing. Please install " + "it with `pip install jmespath` to use this feature." + ) + if not supports_tool_calling(processing_class): + raise ValueError( + "The provided chat template does not support tool calling. The template must be able to render a " + "full tool-calling conversation (user -> assistant with tool_calls -> tool)." + ) + + # Create the environments and extract their methods to be used as tools. 
We create one environment per rollout + generation_batch_size = args.per_device_train_batch_size * args.steps_per_generation + if environment_factory is not None: + self.environments = [environment_factory() for _ in range(generation_batch_size)] + environment_methods = [[] for _ in range(generation_batch_size)] + for i, environment in enumerate(self.environments): + has_reset = False + for name, member in inspect.getmembers(environment, predicate=inspect.ismethod): + if name == "reset": + has_reset = True + elif not name.startswith("_"): + environment_methods[i].append(member) + if not has_reset: + raise ValueError( + "Each environment instance returned by `environment_factory` must define a callable `reset` " + ) + else: + self.environments = None + + tools = tools or [] + self._sync_tool_dicts = [{} for _ in range(generation_batch_size)] + self._async_tool_dicts = [{} for _ in range(generation_batch_size)] + for i in range(generation_batch_size): + for tool in tools + (environment_methods[i] if self.environments is not None else []): + if inspect.iscoroutinefunction(tool): + self._async_tool_dicts[i][tool.__name__] = tool + else: + self._sync_tool_dicts[i][tool.__name__] = tool + + self.tools = tools + (environment_methods[0] if self.environments is not None else []) + + # Check for async functions to start an event loop on a daemon thread + self._has_async_funcs = any(inspect.iscoroutinefunction(func) for func in self.reward_funcs + self.tools) + + if self._has_async_funcs: + self.async_loop_thread, self.async_loop, self.async_loop_ready_event = start_event_loop_in_daemon( + name="TargetPOTrainer-AsyncLoop" + ) + # wait until the event loop is running in the daemon thread + self.async_loop_ready_event.wait() + atexit.register(shutdown_event_loop_in_daemon, self.async_loop_thread, self.async_loop) + + # At the time of initial implementation, most tokenizers do not have built-in support for response schemas. + # While waiting for broader adoption, we provide this utility function to manually set the response schema for + # known chat templates. `response_schema` lives on the (inner) tokenizer, since `parse_response` is a tokenizer + # method that reads `self.response_schema`. + tokenizer = processing_class.tokenizer if self._is_vlm else processing_class + if self.tools and getattr(tokenizer, "response_schema", None) is None: + processing_class = add_response_schema(processing_class) + # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template + # isn't, we replace it at initialization with a training-safe, prefix-preserving template. 
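+        # (Prefix-preserving here means that re-rendering a conversation extended with new turns reproduces the
+        # previous rendering as an exact string prefix, so token ids generated in earlier tool-calling turns stay
+        # aligned with the re-tokenized context.)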
+ if self.tools and not is_chat_template_prefix_preserving(processing_class): + self.chat_template = get_training_chat_template(processing_class) + else: + self.chat_template = None + + # Training arguments + self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper + self.num_generations = args.num_generations # = G in the GRPO paper + self.max_tool_calling_iterations = args.max_tool_calling_iterations or sys.maxsize + self.num_generations_eval = args.num_generations_eval or self.num_generations + self.chat_template_kwargs = args.chat_template_kwargs or {} + self.temperature = args.temperature + self.top_p = args.top_p + self.top_k = args.top_k + self.min_p = args.min_p + self.repetition_penalty = args.repetition_penalty + self.use_transformers_paged = args.use_transformers_paged + self.pad_to_multiple_of = args.pad_to_multiple_of + self.use_vllm = args.use_vllm + self.vllm_mode = args.vllm_mode + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode + self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction + self.vllm_importance_sampling_mode = args.vllm_importance_sampling_mode + self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap + self.use_liger_kernel = args.use_liger_kernel + self.loss_type = args.loss_type + self.tpo_target_temperature = args.tpo_target_temperature + self.tpo_length_normalize_logps = args.tpo_length_normalize_logps + self.multi_objective_aggregation = args.multi_objective_aggregation + self.scale_rewards = args.scale_rewards + self.importance_sampling_level = args.importance_sampling_level + self.off_policy_mask_threshold = args.off_policy_mask_threshold + if self.use_liger_kernel and self.off_policy_mask_threshold is not None: + raise ValueError("Liger kernel does not support off-policy sequence masking yet.") + self.mask_truncated_completions = args.mask_truncated_completions + self.top_entropy_quantile = args.top_entropy_quantile + if self.use_liger_kernel and self.top_entropy_quantile < 1.0: + raise NotImplementedError( + "Liger Kernels don't currently support masking token positions based on entropy." + ) + if self.use_liger_kernel and self.importance_sampling_level not in ("token", "sequence"): + raise ValueError( + f"Unknown importance sampling level: {self.importance_sampling_level}. " + "Possible values are 'token' and 'sequence'." + ) + + # Datasets + self.shuffle_dataset = args.shuffle_dataset + + if train_dataset is None: + raise ValueError("`train_dataset` is required") + elif ( + isinstance(train_dataset, IterableDataset) + or isinstance(eval_dataset, IterableDataset) + or ( + isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) + ) + ): + # See https://github.com/huggingface/trl/issues/3213 + raise NotImplementedError( + "Iterable datasets are not yet supported in TargetPOTrainer. Please use a standard dataset instead." + ) + + if args.loss_type == "luspo" and args.importance_sampling_level != "sequence": + logger.warning( + "When using `'luspo'` loss, `importance_sampling_level` should be set to `'sequence'` to mirror the " + "paper's setup." + ) + + if args.loss_type == "vespo" and args.importance_sampling_level != "token": + logger.warning( + "VESPO computes sequence-level importance weights internally. `importance_sampling_level` should be " + "set to `'token'` (the default)." 
+ ) + + if self.loss_type == "vespo" and self.use_vllm and self.vllm_importance_sampling_correction: + if self.vllm_importance_sampling_mode not in ["token_truncate", "token_mask"]: + raise ValueError( + f"VESPO loss requires `vllm_importance_sampling_mode` to be either 'token_truncate' or " + f"'token_mask'. Got: {self.vllm_importance_sampling_mode}." + ) + + # Multi-step + self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle + self._step = 0 + # Buffer the batch to reuse generated outputs across multiple updates. For more details, see + # `_get_train_sampler` and `_prepare_inputs`. + self._buffered_inputs = None + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model=model, + args=args, + data_collator=identity, # No data collation is needed in GRPO + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + # In Trainer, `training_step` scales the loss by `gradient_accumulation_steps` only if `compute_loss_func` + # is None. For DAPO, loss scaling instead depends on the total number of completions tokens across the + # global accumulated batch. To control scaling ourselves, we must disable Trainer’s built-in scaling. The + # simplest (though a bit hacky) way is to set `compute_loss_func` to any non-None value, which bypasses + # that behavior without rewriting `training_step`. + compute_loss_func="non-None value to disable scaling", + ) + + # Reference model + self.beta = args.beta + if self.beta == 0.0: + # If beta is 0.0, the reference model is not needed + self.ref_model = None + elif is_peft_model(model): + # If PEFT is used, the reference model is not needed since the adapter can be disabled + # to revert to the initial model. 
+ self.ref_model = None + else: + # For deepspeed, fsdp or non-distributed models, create a reference model from scratch + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if self.args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs) + + # Disable dropout in the models + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + # Cast LM Head To FP32 + if args.cast_lm_head_to_fp32: + + def _cast_lm_head_to_fp32(target_model: PreTrainedModel): + """Cast lm_head to fp32 while preserving embedding output dtype if tied.""" + + def cast_inputs_to_fp32(module, inputs): + # Preserve other positional args and kwargs untouched + if not inputs: + return inputs + return (inputs[0].to(torch.float32),) + inputs[1:] + + original_dtype_local = target_model.lm_head.weight.dtype + target_model.lm_head = target_model.lm_head.float() + target_model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32) + + if target_model.config.tie_word_embeddings: + + def cast_outputs_to_original_dtype(module, args, output): + return output.to(original_dtype_local) + + # Only cast activations; weights are now fp32 (intentional for numerical stability of logits) + target_model.model.embed_tokens.register_forward_hook(cast_outputs_to_original_dtype) + + _cast_lm_head_to_fp32(model) + if self.ref_model is not None: + _cast_lm_head_to_fp32(self.ref_model) + + # Liger loss + if self.use_liger_kernel: + if not is_liger_kernel_available(): + raise ImportError( + "Liger is required to use `use_liger_kernel` as the TargetPO loss. Run `pip install liger-kernel`." + ) + # redirect the model.module forward to the model forward to ensure pre-forward hooks are called + self._forward_redirection = _ForwardRedirection() + + self.liger_tpo_loss = LigerFusedLinearTargetPOLoss( + beta=self.beta, + temperature=self.temperature, + num_generations=self.num_generations, + tpo_length_normalize_logps=self.tpo_length_normalize_logps, + use_bias_correction_kl=args.use_bias_correction_kl, + ) + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + self._current_train_step_time = 0.0 + self.log_completions = args.log_completions + self.log_unique_prompts = args.log_unique_prompts + self.num_completions_to_print = args.num_completions_to_print + # Keep logs sized to the generation batch to record only outputs from the latest model update. + self._logs = { + "images": deque(maxlen=args.generation_batch_size), + "prompt": deque(maxlen=args.generation_batch_size), + "completion": deque(maxlen=args.generation_batch_size), + "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), + "advantages": deque(maxlen=args.generation_batch_size), + "extra": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), + } + # Buffers for user-logged data from reward functions, flushed after gathering + self._pending_extra_logs = defaultdict(list) + self._pending_metrics = defaultdict(list) + + # Ensure each process receives a unique seed to prevent duplicate completions when generating with + # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but + # it's safer to set it in all cases. 
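+        # `device_specific=True` offsets the seed by the process index, so each rank draws a different random stream.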
+ set_seed(args.seed, device_specific=True) + + if self.use_vllm: + # Initialize vLLM generation backend + self.vllm_generation = VLLMGeneration( + model=self.model, + accelerator=self.accelerator, + is_fsdp_enabled=self.is_fsdp_enabled, + processing_class=self.processing_class, + # vLLM configuration + mode=args.vllm_mode, + structured_outputs_regex=args.vllm_structured_outputs_regex, + # Server mode configuration + server_base_url=args.vllm_server_base_url, + server_host=args.vllm_server_host, + server_port=args.vllm_server_port, + group_port=args.vllm_group_port, + server_timeout=args.vllm_server_timeout, + # Colocate mode configuration + tensor_parallel_size=args.vllm_tensor_parallel_size, + gpu_memory_utilization=args.vllm_gpu_memory_utilization, + max_model_length=args.vllm_max_model_length, + max_num_seqs=args.per_device_train_batch_size + * args.vllm_tensor_parallel_size + * args.steps_per_generation, + enable_sleep_mode=args.vllm_enable_sleep_mode, + model_impl=args.vllm_model_impl, + # Generation configuration + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + max_completion_length=self.max_completion_length, + logprobs=0, # we only need the generated token logprobs for the importance sampling correction + generation_kwargs=args.generation_kwargs, + ) + self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation + else: + generation_kwargs = { + "max_new_tokens": self.max_completion_length, + "do_sample": True, + "pad_token_id": tokenizer.pad_token_id, + "bos_token_id": tokenizer.bos_token_id, + "eos_token_id": tokenizer.eos_token_id, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "min_p": self.min_p, + "repetition_penalty": self.repetition_penalty, + "cache_implementation": args.cache_implementation, + } + if args.generation_kwargs is not None: + generation_kwargs.update(args.generation_kwargs) + self.generation_config = GenerationConfig(**generation_kwargs, disable_compile=True) + # Keep training-specific generation kwargs to overwrite model's original generation config + self.generation_kwargs = generation_kwargs + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags to the model + self.model.add_model_tags(self._tag_names) + + if self.ref_model is not None: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + elif self.is_fsdp_enabled: + self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + if args.sync_ref_model: + if self.beta == 0.0: + raise ValueError( + "You passed `sync_ref_model=True` while `beta=0.0`, which means the reference model is not used " + "during training. Consequently, TargetPOTrainer does not create a `ref_model` instance, and there is " + "nothing to synchronize. Please set `sync_ref_model=False`, or set `beta` to a non-zero value." + ) + if is_peft_model(model): + raise NotImplementedError( + "You passed `sync_ref_model=True` while using a PEFT model, which is currently not supported. 
" + "With PEFT, TargetPOTrainer does not keep a separate reference model in memory; instead, it recovers " + "reference behavior by temporarily disabling the adapter. As a result, there is no standalone " + "`ref_model` instance to synchronize. Use `sync_ref_model=False`, or opt for full fine-tuning if " + "you need a synced reference model. If you need `sync_ref_model` to work with PEFT, please open a " + "feature request at https://github.com/huggingface/trl/issues." + ) + self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + else: + # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp + self.reward_funcs[i] = self.accelerator.prepare_model( + reward_func, evaluation_mode=True, device_placement=True + ) + + if self.accelerator.is_main_process and self.log_completions: + os.makedirs(os.path.join(self.args.output_dir, "completions"), exist_ok=True) + if self.args.log_completions_hub_repo is not None: + repo_id = self.args.log_completions_hub_repo + create_repo(repo_id, private=self.args.hub_private_repo, repo_type="dataset", exist_ok=True) + template_path = pkg_resources.files("trl").joinpath("templates/completions_dataset_card.md") + card_data = DatasetCardData( + pretty_name="TRL Completion logs", + tags=["trl", "trl-logs", "completions"], + ) + card = DatasetCard.from_template( + card_data=card_data, + template_path=str(template_path), + repo_id=repo_id, + hub_model_id=self.args.hub_model_id, + ) + card.push_to_hub(repo_id) + self.commit_scheduler = CommitScheduler( + repo_id=repo_id, + repo_type="dataset", + folder_path=f"{self.args.output_dir}/completions", + every=2, # minutes + allow_patterns=["*.parquet"], + ) + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" + # and "attention_mask"). In TargetPOTrainer, we preprocess data, so using the model's signature columns doesn't + # work. Instead, we set them to the columns expected by the `training_step` method, hence the override. + if self._signature_columns is None: + self._signature_columns = ["prompt", "image", "images"] + + # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. + # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an + # *generation* batch (i.e., `per_device_batch_size × steps_per_generation`). This allows us to generate completions + # once every steps_per_generation step—rather than once per accumulation step—which is significantly more + # efficient. The only change from the original implementation is multiplying the batch size by + # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the + # splitting internally. + # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line + # modification. 
+ def get_train_dataloader(self): + return self._get_dataloader( + dataset=self.train_dataset, + description="Training", + batch_size=self._train_batch_size * self.args.steps_per_generation, # < this is the change + sampler_fn=self._get_train_sampler, + is_training=True, + ) + + def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: + # Returns a sampler that + # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are + # distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt + # group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies + # in group formation. + # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to + # _prepare_inputs to see how the generations are stored and reused. + + # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the + # second row shows the second sampled batch, and so on. + # + # | GPU 0 | GPU 1 | + # + # global_step step <-───> num_generations=2 + # <-───────> per_device_train_batch_size=3 + # grad_accum ▲ ▲ 0 0 0 0 1 1 2 2 <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss + # =2 ▼ | 0 1 3 3 4 4 5 5 <- Take the stored generations and use the second slice to compute the loss + # | + # | 1 2 6 6 7 7 8 8 <- Take the stored generations and use the third slice to compute the loss + # steps_per_gen=4 ▼ 1 3 9 9 10 10 11 11 <- Take the stored generations and use the fourth slice to compute the loss + # + # 2 4 12 12 13 13 14 14 <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss + # 2 5 15 15 16 16 17 17 <- Take the stored generations and use the second slice to compute the loss + # ... + if dataset is None: + dataset = self.train_dataset + return RepeatSampler( + data_source=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size // self.num_generations, + repeat_count=self.num_iterations * self.args.steps_per_generation, + shuffle=self.shuffle_dataset, + seed=self.args.seed, + ) + + def _get_eval_sampler(self, eval_dataset) -> Sampler: + # See _get_train_sampler for an explanation of the sampler. 
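+        # Evaluation performs no buffering across steps and no multiple iterations, so only the per-prompt repetition
+        # (`num_generations_eval`) is needed here.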
+ return RepeatSampler( + data_source=eval_dataset, + mini_repeat_count=self.num_generations_eval, + seed=self.args.seed, + ) + + @profiling_decorator + def _get_last_hidden_state( + self, + unwrapped_model, + input_ids, + attention_mask, + logits_to_keep, + pixel_values=None, + image_grid_thw=None, + pixel_attention_mask=None, + image_sizes=None, + image_position_ids=None, + ): + if is_peft_model(unwrapped_model): + unwrapped_model = unwrapped_model.base_model.model + + # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} + + # For Qwen models: + if image_grid_thw is not None and pixel_values is not None: + model_inputs["image_grid_thw"] = image_grid_thw + # For Gemma, SmolVLM2, LLaVa-Next etc.: + if pixel_values is not None: + model_inputs["pixel_values"] = pixel_values + # For SmolVLM2 + if pixel_attention_mask is not None: + model_inputs["pixel_attention_mask"] = pixel_attention_mask + # For LLaVa-Next + if image_sizes is not None: + model_inputs["image_sizes"] = image_sizes + if image_position_ids is not None: + model_inputs["image_position_ids"] = image_position_ids + + # Only add logits_to_keep if the model supports it + if "logits_to_keep" in self.model_kwarg_keys: + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings + + last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state + # Exclude the last value: it corresponds to the next token pred + last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H) + # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. + last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H) + return last_hidden_state + + def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor: + """ + Returns a binary mask identifying tokens whose entropy exceeds a given quantile threshold. + + Args: + entropies (`torch.Tensor`): + Tensor of shape (batch_size, seq_len) with per-token entropy values. + mask (`torch.Tensor`): + Binary mask of the same shape as `entropies`, where `1` indicates valid tokens and `0` padding. + threshold (`float`): + Quantile threshold between `0.0` and `1.0` to select high-entropy tokens. + + Returns: + `torch.Tensor`: + Boolean mask of shape (batch_size, seq_len), where `True` indicates tokens with entropy >= threshold + and `False` otherwise. + """ + local = entropies[mask.bool()].float() + + # Use a negative pad_value as a sentinel because entropy values are always >= 0. + # This guarantees that the sentinel cannot collide with any real entropy value. 
+ pad_value = -1e9 + + # Pad across processes so that every rank has the same tensor length + padded = self.accelerator.pad_across_processes(local, dim=0, pad_index=pad_value) + gathered = self.accelerator.gather(padded) + + # Drop sentinel values (safe because no entropy can be negative) + gathered = gathered[gathered != pad_value] + + if gathered.numel() == 0: + return torch.zeros_like(entropies, dtype=torch.bool) + + entropy_threshold = torch.quantile(gathered, threshold) + masked_entropies = entropies * mask.float() + entropy_mask = masked_entropies >= entropy_threshold + return entropy_mask & mask.bool() # ensure padding tokens are always masked out + + @profiling_decorator + def _get_per_token_logps_and_entropies( + self, + model, + input_ids, + attention_mask, + logits_to_keep, + batch_size=None, + compute_entropy=False, + pixel_values=None, + image_grid_thw=None, + num_images=None, + pixel_attention_mask=None, + image_sizes=None, + token_type_ids=None, + mm_token_type_ids=None, + image_position_ids=None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Compute log-probs and (optionally) entropies for each token.""" + batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak + all_logps = [] + all_entropies = [] + for start in range(0, input_ids.size(0), batch_size): + input_ids_batch = input_ids[start : start + batch_size] + attention_mask_batch = attention_mask[start : start + batch_size] + + # Build model inputs + model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} + if image_grid_thw is not None and pixel_values is not None: + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) + row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() + model_inputs["pixel_values"] = pixel_values[row_start:row_end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] + elif image_position_ids is not None and pixel_values is not None: + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["pixel_values"] = pixel_values[img_start:img_end] + model_inputs["image_position_ids"] = image_position_ids[img_start:img_end] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + if pixel_attention_mask is not None: + model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] + if image_sizes is not None: + model_inputs["image_sizes"] = image_sizes[start : start + batch_size] + if token_type_ids is not None: + model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size] + if mm_token_type_ids is not None: + model_inputs["mm_token_type_ids"] = mm_token_type_ids[start : start + batch_size] + + # Only add logits_to_keep if the model supports it + if "logits_to_keep" in self.model_kwarg_keys: + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings + + logits = model(**model_inputs).logits + # 
Exclude the last value: it corresponds to the next token pred + logits = logits[:, :-1, :] # (B, L-1, H) + # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. + logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, H) + # Divide logits by sampling temperature. + # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details + logits.div_(self.temperature) + completion_ids = input_ids_batch[:, -logits_to_keep:] + logps = selective_log_softmax(logits, completion_ids) # compute logprobs + all_logps.append(logps) + + if compute_entropy: + with torch.no_grad(): + entropies = entropy_from_logits(logits) + all_entropies.append(entropies) + + logps = torch.cat(all_logps, dim=0) + entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None + return logps, entropies + + def training_step(self, model, inputs, num_items_in_batch): + time_before = time.perf_counter() + output = super().training_step(model, inputs, num_items_in_batch) + self._step += 1 + time_after = time.perf_counter() + self._current_train_step_time += time_after - time_before + if self._step % self.current_gradient_accumulation_steps == 0: + self._metrics["train"]["step_time"].append(self._current_train_step_time) + self._current_train_step_time = 0.0 + return output + + @profiling_decorator + def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: + # Prepares inputs for model training/evaluation by managing completion generation and batch handling. + # During training: + # - Receives the local generation batch (Per-GPU batch size × steps per generation) + # from the modified training dataloader instead of the standard local batch + # - Generates completions once for the entire generation batch and splits it into batches of size + # `per_device_train_batch_size` + # - Buffers these completions and returns the appropriate slice for the current accumulation step + # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations) + # During evaluation: + # - The input is treated as a standard local batch (no accumulation, no multiple iterations) + # - Completions are generated for each batch without buffering or reuse + # Returns a single local batch in both cases. + + mode = "train" if self.model.training else "eval" + if mode == "train": + generate_every = self.args.steps_per_generation * self.num_iterations + if self._step % generate_every == 0 or self._buffered_inputs is None: + # self._buffered_inputs=None can occur when resuming from a checkpoint + generation_batch = self._generate_and_score_completions(generation_batch) + generation_batch = split_pixel_values_by_grid(generation_batch) + if self.loss_type != "tpo": + generation_batch = shuffle_sequence_dict(generation_batch) + generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation) + self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches] + inputs = self._buffered_inputs[self._step % self.args.steps_per_generation] + else: + # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence + # local generation batch == local eval batch + inputs = self._generate_and_score_completions(generation_batch) + return inputs + + def _log_completion_extra(self, column: str, values: list): + """ + Log extra columns to the completions table. 
Called from reward functions via the `log_extra` kwarg. + + Args: + column (`str`): + Name of the column to add. + values (`list`): + Values for the column, one per sample in the batch. + """ + self._pending_extra_logs[column].extend(values) + + def _log_metric(self, name: str, value: float): + """ + Log a scalar metric from a reward function. Called via the `log_metric` kwarg. Values are averaged over each + logging step and reported alongside built-in metrics like `kl` and `entropy`. + + Args: + name (`str`): + Name of the metric. + value (`float`): + Scalar value for this batch. + """ + self._pending_metrics[name].append(value) + + @profiling_decorator + def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): + device = self.accelerator.device + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + + # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations + keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + + # This allows for dynamic reward shaping based on training progress. + reward_kwargs["trainer_state"] = self.state + + # Allow reward functions to log extra columns to the completions table. + reward_kwargs["log_extra"] = self._log_completion_extra + + # Allow reward functions to log additional scalar metrics. + reward_kwargs["log_metric"] = self._log_metric + + async_funcs_info = [] # async custom functions for asyncio.gather + + for i, (reward_func, reward_processing_class, reward_func_name) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names, strict=True) + ): + if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models + with profiling_context(self, reward_func_name): + if is_conversational(inputs[0]): + messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)] + texts = [ + apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"] + for x in messages + ] + else: + texts = [p + c for p, c in zip(prompts, completions, strict=True)] + reward_inputs = reward_processing_class( + text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False + ) + reward_inputs = super()._prepare_inputs(reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) + elif inspect.iscoroutinefunction(reward_func): # Separate async reward funcs to run them in parallel later + async_funcs_info.append((i, reward_func, reward_func_name)) + else: + # Run synchronous reward function + with profiling_context(self, reward_func_name): + if self.environments is not None: + reward_kwargs["environments"] = self.environments + output_reward_func = reward_func( + prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs + ) + # Convert None values to NaN + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # Execute async custom functions in parallel using asyncio.gather + if async_funcs_info: + + async def _invoke_async(index, func, func_name): + with profiling_context(self, func_name): + output = await func( + prompts=prompts, completions=completions, 
completion_ids=completion_ids_list, **reward_kwargs + ) + output = [r if r is not None else torch.nan for r in output] + return index, output + + async def _run_async_funcs(): + coros = [_invoke_async(i, func, func_name) for (i, func, func_name) in async_funcs_info] + return await asyncio.gather(*coros) + + async_results = asyncio.run_coroutine_threadsafe(_run_async_funcs(), self.async_loop).result() + for idx, output_reward_func in async_results: + rewards_per_func[:, idx] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # If all reward functions return None for a given row, issue a detailed warning + if torch.isnan(rewards_per_func).all(dim=1).any(): + nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] + row_reward_kwargs = { + key: value[nan_row_idx] + for key, value in reward_kwargs.items() + if key not in ("trainer_state", "log_extra", "log_metric") + } + row_reward_kwargs["prompt"] = prompts[nan_row_idx] + row_reward_kwargs["completion"] = completions[nan_row_idx] + logger.warning( + f"All reward functions returned None for the following kwargs:\n{row_reward_kwargs}\n" + "Please ensure that at least one reward function returns a valid reward." + ) + + # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the + # completions may be distributed across processes + rewards_per_func = gather(rewards_per_func) + return rewards_per_func + + def _tokenize_prompts(self, prompts: list): + """Tokenize prompts and extract images/multimodal fields for generation.""" + if is_conversational({"prompt": prompts[0]}): + # Normalize string content to content blocks for VLM processors that don't handle plain strings. + if self._is_vlm: + prompts = [prepare_multimodal_messages(prompt) for prompt in prompts] + + # Extract images from messages for VLM support + images = [] + has_images = False + for prompt in prompts: + prompt_images = [] + for message in prompt: + if isinstance(message["content"], list): + for part in message["content"]: + if part["type"] == "image": + prompt_images.append(part["image"]) + has_images = True + images.append(prompt_images if prompt_images else None) + images = images if has_images else None + + # Workaround for a bug in transformers 5.3.0 where some processors (e.g. Qwen2.5-VL) crash on + # batched unpadded input (transformers#44514). + # Fixed in transformers 5.4.0 (transformers#44563). + needs_padding_workaround = Version("5.3.0") <= Version(transformers.__version__) < Version("5.4.0") + tokenized = self.processing_class.apply_chat_template( + conversation=prompts, + tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] + chat_template=self.chat_template, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + **({"padding": True} if needs_padding_workaround else {}), + **self.chat_template_kwargs, + ) + if needs_padding_workaround: + # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists + prompt_ids = [ + [tok for tok, m in zip(ids, mask, strict=True) if m] + for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True) + ] + else: + prompt_ids = tokenized["input_ids"] + # For VLMs, the processor returns extra multimodal fields (pixel_values, image_grid_thw, etc.) 
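+            # These extra fields are returned alongside prompt_ids and forwarded to the generation step for VLM inputs.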
+ multimodal_fields = {k: v for k, v in tokenized.items() if k not in ("input_ids", "attention_mask")} + else: + prompt_ids = self.processing_class(text=prompts)["input_ids"] + images = None + multimodal_fields = {} + return prompt_ids, images, multimodal_fields + + def _generate_single_turn(self, prompt_ids, images, multimodal_fields): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + # Generate completions using either vLLM or regular generation + if self.use_vllm: + # Sync weights if training step changed + if self.state.global_step != self._last_loaded_step: + with profiling_context(self, "sync_weights"): + self.vllm_generation.sync_weights() + self._last_loaded_step = self.state.global_step + + # Generate using vLLM with raw token IDs + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + _, completion_ids, logprobs, _ = self.vllm_generation.generate( + prompts=prompt_ids, + images=images, + num_generations=num_generations, + profiler=profiling_context(self, "vLLM.generate"), + ) + # vLLM returns per-token top-k logprobs; keep only the top-1 (sampled token) logprob + logprobs = [[lp[0] for lp in seq] for seq in logprobs] + + elif self.use_transformers_paged: + with ( + profiling_context(self, "transformers.generate_batch"), + unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + # Cast to the appropriate dtype based on training configuration + if self.args.bf16: + unwrapped_model.to(torch.bfloat16) + elif self.args.fp16: + unwrapped_model.to(torch.float16) + if self.args.cast_lm_head_to_fp32: + unwrapped_model.lm_head.to(torch.float32) + with torch.inference_mode(): + # Continuous batching API expects 'inputs' arg only + all_outputs = unwrapped_model.generate_batch( + prompt_ids, generation_config=self.generation_config, progress_bar=False + ) + unwrapped_model.train() # restore training mode, as generate_batch forces eval mode + completion_ids = [output.generated_tokens for output in all_outputs.values()] + logprobs = None # not used in this case + + else: + # Regular generation path: left-pad token IDs into tensors + prompt_tensors = [torch.tensor(ids) for ids in prompt_ids] + padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left") + attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left") + generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask} + # For VLMs, include multimodal fields as tensors (pixel_values, image_grid_thw, etc.) 
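+            # Fields may be tensors, per-token lists (padded like input_ids), or per-sample values; each case is handled below.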
+ for k, v in multimodal_fields.items(): + if isinstance(v, torch.Tensor): + generate_inputs[k] = v + elif isinstance(v, list) and v and isinstance(v[0], list): + # Per-token field (e.g., token_type_ids): left-pad like input_ids + generate_inputs[k] = pad([torch.tensor(x) for x in v], padding_value=0, padding_side="left") + else: + generate_inputs[k] = torch.tensor(np.array(v)) + generate_inputs = super()._prepare_inputs(generate_inputs) + + with ( + profiling_context(self, "transformers.generate"), + unwrap_model_for_generation( + self.model_wrapped, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + prompt_completion_ids = unwrapped_model.generate( + **generate_inputs, generation_config=self.generation_config + ) + # Compute prompt length and extract completion ids + prompt_length = generate_inputs["input_ids"].size(1) + completion_ids = prompt_completion_ids[:, prompt_length:] + + # Mask everything after the first EOS token + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + completion_ids = [ + c[m].tolist() for c, m in zip(completion_ids.cpu(), completion_mask.bool().cpu(), strict=True) + ] + logprobs = None # not used in this case + + return completion_ids, logprobs + + def _get_tool_suffix_ids(self, tool_messages): + """Get token IDs for tool result formatting by using a minimal dummy conversation.""" + # Use the real tool name instead of a dummy: some templates (e.g. GPT-OSS) derive the tool response + # header from the assistant's tool call name. + dummy_tool_calls = [{"type": "function", "function": {"name": tool_messages[0]["name"], "arguments": {}}}] + dummy_messages = [ + {"role": "user", "content": "dummy"}, + { + "role": "assistant", + # "content" is required here because VLM processors crash on tokenize=True without it + # (KeyError in processing_utils.py). See huggingface/transformers#45290. + "content": "", + "tool_calls": dummy_tool_calls, + }, + ] + if self._is_vlm: + dummy_messages = prepare_multimodal_messages(dummy_messages) + tool_messages = prepare_multimodal_messages(tool_messages) + + prefix_ids = self.processing_class.apply_chat_template( + dummy_messages, + add_generation_prompt=False, + tokenize=True, + chat_template=self.chat_template, + return_dict=False, + **self.chat_template_kwargs, + ) + full_ids = self.processing_class.apply_chat_template( + dummy_messages + tool_messages, + add_generation_prompt=True, + tokenize=True, + chat_template=self.chat_template, + return_dict=False, + **self.chat_template_kwargs, + ) + # VLM processors return batched output (list of lists), unbatch for single conversation + if self._is_vlm: + prefix_ids = prefix_ids[0] + full_ids = full_ids[0] + + # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. + # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to + # EOS (not EOS + newline). 
Templates that don't use EOS as end-of-turn (e.g. Gemma uses
+        # <end_of_turn>) skip this trimming.
+        eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.eos_token_id]
+        if eos_positions:
+            prefix_ids = prefix_ids[: eos_positions[-1] + 1]
+
+        if full_ids[: len(prefix_ids)] != prefix_ids:
+            raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.")
+        return full_ids[len(prefix_ids) :]
+
+    def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields):
+        # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt
+        tool_calls = [completion[0].get("tool_calls") for completion in completions]
+        idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call]
+        tool_calls = [tool_calls[idx] for idx in idxs_with_tool]
+        tool_mask = [[1] * len(ids) for ids in completion_ids]  # 0 for tool result tokens, 1 elsewhere
+        # Collect images from multimodal tool responses for the forward pass
+        tool_images = [[] for _ in completion_ids]
+        tool_call_count = 0
+        tool_failure_count = 0
+        iteration_num = 0
+
+        while idxs_with_tool and iteration_num < self.max_tool_calling_iterations:
+            prompt_completion_tools = [prompts[i] for i in idxs_with_tool]  # select only prompts that need tool calls
+            # Snapshot state so we can rollback tool results that would exceed max_completion_length
+            completions_len_before = [len(completions[i]) for i in idxs_with_tool]
+            tool_images_len_before = [len(tool_images[i]) for i in idxs_with_tool]
+            prompts_len_before = [len(prompts[i]) for i in idxs_with_tool]
+
+            # Call the tools, and build the new prompt for generation
+            for idx in range(len(idxs_with_tool)):
+                idx_with_tool = idxs_with_tool[idx]
+                tool_call_list = tool_calls[idx]
+                prompt_completion_tool = prompt_completion_tools[idx]
+                sync_tool_dict = self._sync_tool_dicts[idx_with_tool]
+                async_tool_dict = self._async_tool_dicts[idx_with_tool]
+                # Append the last assistant message (which triggered tool_calls) to the prompt
+                prompt_completion_tool.append(completions[idx_with_tool][-1])
+                async_coros = []
+                tool_call_results = []
+                for tool_call in tool_call_list:
+                    tool_call_count += 1
+                    if tool_call["type"] == "function":
+                        function = tool_call["function"]
+                        name = function["name"]
+                        try:
+                            if name in sync_tool_dict:
+                                tool_call_results.append((name, sync_tool_dict[name](**function["arguments"])))
+                            elif name in async_tool_dict:
+                                async_coros.append((name, async_tool_dict[name](**function["arguments"])))
+                            else:
+                                raise ValueError(f"Tool {name} not found.")
+                        except Exception as e:
+                            tool_failure_count += 1
+                            result = {"error": str(e)}
+                            tool_call_results.append((name, result))
+                    else:
+                        tool_failure_count += 1
+                        name = tool_call.get("name", "unknown")
+                        tool_call_results.append((name, {"error": f"Unsupported tool call type: {tool_call['type']}"}))
+
+                if async_coros:
+
+                    async def _run_async_tools(async_coros):
+                        coros = [coro for _, coro in async_coros]
+                        results = await asyncio.gather(*coros, return_exceptions=True)
+                        return [(name, result) for (name, _), result in zip(async_coros, results, strict=False)]
+
+                    async_results = asyncio.run_coroutine_threadsafe(
+                        _run_async_tools(async_coros), self.async_loop
+                    ).result()
+
+                    for name, result in async_results:
+                        if isinstance(result, Exception):
+                            tool_failure_count += 1
+                            tool_call_results.append((name, {"error": str(result)}))
+                        else:
+                            tool_call_results.append((name, result))
+
+                for name, result in 
tool_call_results: + # Support multimodal tool responses: if the tool returns a list of content blocks + # (e.g., [{"type": "image", "image": ...}, {"type": "text", "text": "..."}]), + # pass them through directly so _tokenize_prompts can extract images for VLMs. + content = result if isinstance(result, list) else str(result) + tool_message = {"role": "tool", "name": name, "content": content} + # Collect images from multimodal tool responses + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "image": + tool_images[idx_with_tool].append(part["image"]) + prompt_completion_tool.append(tool_message) + completions[idx_with_tool].append(tool_message) + + # Build token IDs by concatenation: prompt + completion + tool_suffix. + prompt_completion_tool_ids = [] + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + # Extract trailing tool messages from completions + tool_messages = [] + for message in reversed(completions[idx_with_tool]): + if message["role"] == "tool": + tool_messages.insert(0, message) + else: + break + suffix_ids = self._get_tool_suffix_ids(tool_messages) + prompt_completion_tool_ids.append( + prompt_ids[idx_with_tool] + completion_ids[idx_with_tool] + suffix_ids + ) + + # Drop tool results whose addition would push the sequence past max_completion_length (the completion + # budget) or past the backend context ceiling (vLLM and transformers will error out on inputs longer than + # the model's max length). The sample exits the loop with its completion as-is, and the tool + # messages/images appended this iteration are rolled back so completions and tool_images stay consistent + # with completion_ids. + if self.use_vllm and self.vllm_mode == "colocate": + max_model_len = self.vllm_generation.llm.llm_engine.model_config.max_model_len + else: + config = self.model.config.text_config if self._is_vlm else self.model.config + max_model_len = config.max_position_embeddings + overlong = [ + len(pct) - len(prompt_ids[i]) > self.max_completion_length or len(pct) >= max_model_len + for i, pct in zip(idxs_with_tool, prompt_completion_tool_ids, strict=True) + ] + for idx in range(len(idxs_with_tool)): + if overlong[idx]: + idx_with_tool = idxs_with_tool[idx] + del completions[idx_with_tool][completions_len_before[idx] :] + del tool_images[idx_with_tool][tool_images_len_before[idx] :] + del prompts[idx_with_tool][prompts_len_before[idx] :] + # Keep only non-overlong items for further processing + idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o] + prompt_completion_tool_ids = [ + pct for pct, o in zip(prompt_completion_tool_ids, overlong, strict=True) if not o + ] + if not idxs_with_tool: + break # all overlong, exit tool loop + + # Filter images and multimodal fields to match the current subset (index into full batch). + # Merge tool response images so the model can see visual feedback during generation. + merged_images = images + if any(imgs for imgs in tool_images): + if merged_images is None: + merged_images = [imgs if imgs else None for imgs in tool_images] + else: + merged_images = [ + (existing or []) + new for existing, new in zip(merged_images, tool_images, strict=True) + ] + loop_images = [merged_images[i] for i in idxs_with_tool] if merged_images else None + if multimodal_fields: + loop_multimodal_fields = {} + for k, v in multimodal_fields.items(): + selected = [v[i] for i in idxs_with_tool] + # Per-token fields (e.g. 
token_type_ids) need zero-padding to match extended prompt length + if isinstance(selected[0], list): + selected = [ + s + [0] * (len(pct) - len(s)) + for s, pct in zip(selected, prompt_completion_tool_ids, strict=True) + ] + loop_multimodal_fields[k] = selected + else: + loop_multimodal_fields = {} + + # Generate new completions after tool execution (using concatenated IDs, no re-tokenization) + post_tool_ids, post_tool_logprobs = self._generate_single_turn( + prompt_completion_tool_ids, loop_images, loop_multimodal_fields + ) + + # Truncate so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length. + # The pre-regen check guarantees len(completion_tool_ids) <= max_completion_length, so any + # excess can only come from post_tool_ids. post_tool_ids is model-generated text and never + # contains image tokens, so a plain slice is safe. + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + completion_tool_length = len(prompt_completion_tool_ids[idx]) - len(prompt_ids[idx_with_tool]) + excess_length = completion_tool_length + len(post_tool_ids[idx]) - self.max_completion_length + if excess_length > 0: + new_len = len(post_tool_ids[idx]) - excess_length + post_tool_ids[idx] = post_tool_ids[idx][:new_len] + if logprobs is not None: + post_tool_logprobs[idx] = post_tool_logprobs[idx][:new_len] + + # Update tool_mask: the tool result should be 0 and the post-tool 1 + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + prompt_completion_tool_length = len(prompt_completion_tool_ids[idx]) + prompt_length = len(prompt_ids[idx_with_tool]) + completion_length = len(completion_ids[idx_with_tool]) + post_tool_length = len(post_tool_ids[idx]) + tool_length = prompt_completion_tool_length - prompt_length - completion_length + tool_mask[idx_with_tool] += [0] * tool_length + [1] * post_tool_length + if logprobs is not None: + logprobs[idx_with_tool] += [0.0] * tool_length + post_tool_logprobs[idx] + + # Update completion_ids with the new completions (after tool execution) + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + prompt_length = len(prompt_ids[idx_with_tool]) + pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool + completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] + + # Decode post-tool completions. + post_tool_completions = [ + parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids + ] + + # Add post-tool completions to the existing completions + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + if post_tool_completions[idx]: # {} if post-tool completions completely truncated + completions[idx_with_tool].append(post_tool_completions[idx]) + + # Check for further tool calls + tool_calls = [completion.get("tool_calls") for completion in post_tool_completions] + idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call] + tool_calls = [tool_call for tool_call in tool_calls if tool_call] + iteration_num += 1 + + return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count, tool_images + + def _generate(self, prompts: list): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + # Copy the prompts to avoid modifying the original list + prompts = copy.deepcopy(prompts) + + if self.rollout_func is not None: + # Keep vLLM weights in sync for custom rollouts that rely on vLLM utilities. 
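+            # Same cadence as the built-in generation path: weights are pushed at most once per global step.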
+ if self.use_vllm and self.state.global_step != self._last_loaded_step: + with profiling_context(self, "sync_weights"): + self.vllm_generation.sync_weights() + self._last_loaded_step = self.state.global_step + + # Pass prompts to rollout_func preserving structured messages. + # Chat templating must happen inside rollout_func, at the backend boundary, so that + # multimodal content (images, typed content blocks) is not lost before rollout logic runs. + output = self.rollout_func(prompts, self) + required_keys = {"prompt_ids", "completion_ids", "logprobs"} + missing_keys = required_keys - output.keys() + if missing_keys: + missing_keys_list = sorted(missing_keys) + raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.") + extra_fields = {k: v for k, v in output.items() if k not in required_keys} + prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"] + images = None + multimodal_fields = {} + else: + prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts) + completion_ids, logprobs = self._generate_single_turn(prompt_ids, images, multimodal_fields) + extra_fields = {} + + # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. + if is_conversational({"prompt": prompts[0]}): + tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class + if ( + Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 + and hasattr(tokenizer, "response_schema") # attribute not set by default for now + and tokenizer.response_schema is not None # only works if the tokenizer has a schema + ): + completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] + else: + contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + completions = [[{"role": "assistant", "content": content}] for content in contents] + else: + completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + + # Extract tool calls from the completions and (possibly) execute them + tool_images = [] + if self.tools: + ( + tool_mask, + completions, + completion_ids, + logprobs, + tool_call_count, + tool_failure_count, + tool_images, + ) = self._tool_call_loop( + prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields + ) + # Merge tool response images into the images list for the forward pass + if any(imgs for imgs in tool_images): + if images is None: + images = [imgs if imgs else None for imgs in tool_images] + else: + images = [(existing or []) + new for existing, new in zip(images, tool_images, strict=True)] + else: + # Support custom env_mask from rollout_func (e.g., for environment feedback masking) + # Internally treated as tool_mask - marks model tokens (1) vs external tokens (0) + tool_mask = extra_fields.pop("env_mask", None) + + # Get completion length per sequence, used for logging + prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) + if tool_mask is not None: # count only model-generated tokens (tool_mask=1) + completion_lengths = torch.tensor([sum(mask) for mask in tool_mask], device=device) + else: + completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) + agg_prompt_lengths = self.accelerator.gather(prompt_lengths) + agg_completion_lengths = self.accelerator.gather(completion_lengths) + total_prompt_tokens = agg_prompt_lengths.sum() + 
total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss + + # Log the metrics + if mode == "train": + self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # Log completion lengths, mean, min, max + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + # Identify sequences that terminated with EOS and log their lengths + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] + if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + if self.tools: + agg_tool_call_count = self.accelerator.gather(torch.tensor(tool_call_count, device=device)).sum() + tool_call_frequency = (agg_tool_call_count / len(agg_prompt_lengths)).item() + self._metrics[mode]["tools/call_frequency"].append(tool_call_frequency) + agg_tool_failure_count = self.accelerator.gather(torch.tensor(tool_failure_count, device=device)).sum() + failure_frequency = ( + (agg_tool_failure_count / agg_tool_call_count).item() if agg_tool_call_count > 0 else 0.0 + ) + self._metrics[mode]["tools/failure_frequency"].append(failure_frequency) + + return ( + prompt_ids, + completion_ids, + tool_mask, + completions, + total_completion_tokens, + logprobs, + extra_fields, + images, + tool_images, + ) + + def _generate_and_score_completions( + self, inputs: list[dict[str, torch.Tensor | Any]] + ) -> dict[str, torch.Tensor | Any]: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompts = [x["prompt"] for x in inputs] + + if self.environments: + for prompt, environment, reset_kwargs in zip(prompts, self.environments, inputs, strict=True): + observation = environment.reset(**reset_kwargs) + if observation is None: + continue + if isinstance(observation, list) and isinstance(prompt[-1]["content"], str): + prompt[-1]["content"] = [{"type": "text", "text": prompt[-1]["content"]}] + if isinstance(observation, str) and isinstance(prompt[-1]["content"], list): + observation = [{"type": "text", "text": observation}] + prompt[-1]["content"] += observation + + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + # Transformers requires at least one image in the batch, otherwise it throws an error + if images is not None and all(img_list == [] for img_list in images): + images = 
None + + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What color is the sky?"}]}] + if images is not None: + if not is_conversational(inputs[0]): + raise ValueError( + "Multimodal training requires conversational prompts. It looks like the dataset contains " + "non-conversational inputs, likely because a chat template was applied before passing the dataset " + "to the trainer. Please provide the raw conversational prompts and let the trainer apply the chat " + "template internally." + ) + prompts = [ + prepare_multimodal_messages(prompt, images=image_list) + for prompt, image_list in zip(prompts, images, strict=True) + ] + + dataset_images = images # preserve dataset images before _generate may overwrite + ( + prompt_ids_list, + completion_ids_list, + tool_mask_list, + completions, + num_items_in_batch, + sampling_per_token_logps_list, + extra_fields, + images, + tool_images, + ) = self._generate(prompts) + if images is None: + images = dataset_images # restore dataset images (rollout_func path returns None) + + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad( + prompt_ids, + padding_value=self.pad_token_id, + padding_side="left", + pad_to_multiple_of=self.pad_to_multiple_of, + ).to(device=device) + prompt_mask = pad( + prompt_mask, padding_value=0, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of + ).to(device=device) + completion_ids = [torch.tensor(ids) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad( + completion_ids, + padding_value=self.pad_token_id, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ).to(device=device) + completion_mask = pad( + completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ).to(device=device) + if sampling_per_token_logps_list is not None: + sampling_per_token_logps = [torch.tensor(logps) for logps in sampling_per_token_logps_list] + sampling_per_token_logps = pad( + sampling_per_token_logps, + padding_value=0.0, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ).to(device=device) + else: + sampling_per_token_logps = None + if tool_mask_list is not None: + tool_mask = [torch.tensor(mask) for mask in tool_mask_list] + tool_mask = pad( + tool_mask, padding_value=1, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ).to(device=device) + else: + tool_mask = None + + # If mask_truncated_completions is enabled, zero out truncated completions for attention and loss masking + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + # Mask completion_mask for attention masking + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() + # Also mask tool_mask for consistency in multi-turn training + if tool_mask is not None: + tool_mask = tool_mask * (~is_truncated).unsqueeze(1).int() + + # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) 
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + num_images = [len(img_list) if img_list else 0 for img_list in images] if images is not None else None + + # Get forward_kwargs for models with multimodal inputs. + # When tool images are present (from _tool_call_loop), use image_processor directly and build + # mm_token_type_ids from prompt_completion_ids. Otherwise, use the full processor pipeline + # which returns model-specific keys (image_sizes, pixel_attention_mask, etc.). + if self.tools and any(imgs for imgs in tool_images) and self._is_vlm: + flat_images = [img for img_list in images if img_list for img in img_list] + image_inputs = self.processing_class.image_processor(images=flat_images, return_tensors="pt") + image_inputs = super()._prepare_inputs(image_inputs) + forward_kwargs = dict(image_inputs) + elif images is not None: + prompts_text = [ + apply_chat_template( + {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs + )["prompt"] + for prompt in prompts + ] + prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + + # If token_type_ids are used, extend them with zeros for the completion part + if "token_type_ids" in forward_kwargs: + token_type_ids = forward_kwargs["token_type_ids"] + if self.pad_to_multiple_of is not None: + # Needed only with pad_to_multiple_of: otherwise prompt_ids and token_type_ids must have equal len + padding_size = prompt_ids.size(1) - token_type_ids.size(1) + if padding_size > 0: + token_type_ids = torch.cat( + [token_type_ids.new_zeros((token_type_ids.size(0), padding_size)), token_type_ids], dim=1 + ) + forward_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 + ) + # If mm_token_type_ids are used, extend them with zeros for the completion part + if "mm_token_type_ids" in forward_kwargs: + mm_token_type_ids = forward_kwargs["mm_token_type_ids"] + if self.pad_to_multiple_of is not None: + # Needed only with pad_to_multiple_of: otherwise prompt_ids and mm_token_type_ids must have equal len + padding_size = prompt_ids.size(1) - mm_token_type_ids.size(1) + if padding_size > 0: + mm_token_type_ids = torch.cat( + [mm_token_type_ids.new_zeros((mm_token_type_ids.size(0), padding_size)), mm_token_type_ids], + dim=1, + ) + forward_kwargs["mm_token_type_ids"] = torch.cat( + [mm_token_type_ids, mm_token_type_ids.new_zeros(completion_ids.shape)], dim=1 + ) + + # For VLM tool images: build token type IDs from the full prompt_completion_ids. + # This must happen AFTER the token_type_ids/mm_token_type_ids extension blocks above, + # because our version already covers the full sequence (images are in the completion, + # not just the prompt). 
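+        # Encoding used for the token-type IDs built below: 0 = text token, 1 = image pad token, 2 = video pad token.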
+ if self.tools and any(imgs for imgs in tool_images) and self._is_vlm: + mm_ids = torch.zeros_like(prompt_completion_ids) + if self._image_pad_token_id is not None: + mm_ids[prompt_completion_ids == self._image_pad_token_id] = 1 + if self._video_pad_token_id is not None: + mm_ids[prompt_completion_ids == self._video_pad_token_id] = 2 + + # Use the same key the model expects: token_type_ids for models like Gemma, + # mm_token_type_ids for models like Qwen. + image_grid_thw = forward_kwargs.get("image_grid_thw") + if image_grid_thw is not None: + forward_kwargs["mm_token_type_ids"] = mm_ids + else: + forward_kwargs["token_type_ids"] = mm_ids + + # Truncation safety (Qwen-style models with image_grid_thw only): if + # max_completion_length truncated some image tokens, the number of image pad tokens + # in input_ids won't match pixel_values features. Check per-sample and drop ALL + # images for any sample with a mismatch (safe fallback). + if image_grid_thw is not None and num_images is not None: + merge_length = getattr(self.processing_class.image_processor, "merge_size", 2) ** 2 + img_offset = 0 + has_mismatch = False + for b in range(mm_ids.shape[0]): + sample_tokens = (mm_ids[b] == 1).sum().item() + sample_features = 0 + for i in range(num_images[b]): + grid_idx = img_offset + i + if grid_idx < image_grid_thw.shape[0]: + sample_features += image_grid_thw[grid_idx].prod().item() // merge_length + if sample_tokens != sample_features: + has_mismatch = True + break + img_offset += num_images[b] + + if has_mismatch: + # Drop all images: safer than partial trim which is error-prone + forward_kwargs.pop("pixel_values", None) + forward_kwargs.pop("image_grid_thw", None) + mm_ids.zero_() + forward_kwargs["mm_token_type_ids"] = mm_ids + num_images = None + + # When gradient checkpointing is enabled with use_reentrant=True (non default), calling the model inside a + # torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True"). + # Temporarily disable checkpointing to avoid this warning during inference. + with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): + # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of + # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the + # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps + # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set + # old_per_token_logps to None. + # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the + # distribution mismatch between vLLM and the training model can be large and harm the training. 
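+            # Worked example (hypothetical values): with steps_per_generation=4, num_iterations=1 and
+            # gradient_accumulation_steps=2, generate_every=4 but the optimizer steps every 2 micro-batches, so
+            # rollouts generated at micro-step 0 are still optimized at micro-steps 2-3 after a weight update;
+            # 2 % 4 != 0, so old_per_token_logps is computed below for importance sampling.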
+ generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency + if ( + self.loss_type == "tpo" + or self.args.gradient_accumulation_steps % generate_every != 0 + or (self.use_vllm and self.vllm_importance_sampling_correction) + ): + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, image_position_ids + ) + else: + old_per_token_logps = None + + # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch + if self.use_vllm and self.vllm_importance_sampling_correction: + mask = completion_mask if tool_mask is None else completion_mask * tool_mask + per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask + + sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"] + if sequence_level_is: + per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True) + logps_diff = per_sequence_logps_diff + else: + logps_diff = per_token_logps_diff + + vllm_importance_sampling_ratio = torch.exp(logps_diff) + + # vllm_importance_sampling_ratio.shape: + # token_* modes: (B, T) (per-token ratio) + # sequence_* modes: (B, 1) (per-sequence ratio) + + if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]: + vllm_importance_sampling_ratio = torch.clamp( + vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap + ) + elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]: + vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill( + vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0 + ) + else: + raise ValueError( + f"Unknown vLLM importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'." + ) + + # Compute the per-token log probabilities for the reference model + if self.beta != 0.0: + if self.ref_model is not None: + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, image_position_ids + ) + else: + # When training a PEFT adapter, how we obtain the reference depends on the setup: + # - New adapter: disabling adapters yields the base model. + # - Re-training an existing adapter: an initial copy is loaded under the name "ref". 
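+                # use_adapter selects the "ref" adapter when it exists; otherwise adapter_name=None disables the
+                # active adapter so the base model serves as the reference.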
+ model = self.accelerator.unwrap_model(self.model) + with use_adapter(model, adapter_name="ref" if "ref" in model.peft_config else None): + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, image_position_ids + ) + else: + ref_per_token_logps = None + + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + + # Merge extra_fields from rollout_func into inputs for reward functions + if extra_fields: + for i, inp in enumerate(inputs): + for key, values in extra_fields.items(): + if isinstance(values, list) and i < len(values): + inp[key] = values[i] + elif not isinstance(values, list): + inp[key] = values + + # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is + # important because rewards will be normalized per group, and completions are distributed. We will later slice + # rewards_per_func to extract each process's subset. + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + + if self.multi_objective_aggregation == "sum_then_normalize": + # Apply weights to each reward function's output and sum + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1) + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0) + if self.scale_rewards in ["group", "none"]: + # If self.scale_rewards = "none", we'll only use std_rewards to check for zero std for logging + if num_generations > 1: + std_rewards = rewards.view(-1, num_generations).std(dim=1) + std_rewards = std_rewards.repeat_interleave(num_generations, dim=0) + else: # doesn't occur during training, but could occur in eval when num_generations_eval=1 + std_rewards = torch.zeros_like(rewards) + elif self.scale_rewards == "batch": + # Compute global std + if rewards.numel() > 1: + std_rewards = rewards.std().expand_as(rewards) + else: # doesn't occur during training, but could occur in eval when num_generations_eval=batch_size=1 + std_rewards = torch.zeros_like(rewards) + else: + raise ValueError( + f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'." 
+ ) + + advantages = rewards - mean_grouped_rewards + if self.scale_rewards != "none": + advantages = advantages / (std_rewards + 1e-4) + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging + + elif self.multi_objective_aggregation == "normalize_then_sum": + grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs)) + mean_k = torch.nanmean(grouped, dim=1, keepdim=True) + std_k = nanstd(grouped, dim=1, keepdim=True) if num_generations > 1 else torch.zeros_like(mean_k) + reward_k = (grouped - mean_k) / (std_k + 1e-4) + reward_k = reward_k.view(-1, len(self.reward_funcs)) + rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) + advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4) + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging + + else: + raise ValueError( + f"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}. Must be " + "'sum_then_normalize' or 'normalize_then_sum'." + ) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + all_process_advantages = advantages.clone() # keep the aggregated advantages for logging + tpo_targets = None + tpo_valid_mask = None + if self.loss_type == "tpo": + if old_per_token_logps is None: + raise RuntimeError("TPO requires rollout-time log probabilities to build the target distribution.") + loss_mask = completion_mask if tool_mask is None else completion_mask * tool_mask + tpo_valid_mask = loss_mask.any(dim=-1) + # Length-normalize sequence logps so the target distribution isn't dominated by length variance + # (without this, log_softmax(old_sequence_logps) collapses to ~one-hot on the longest/most-likely + # completion, and rewards can't compete). 
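+            # Per prompt group, get_tpo_targets builds
+            #     q_i = softmax_i(log_softmax(old_sequence_logps)_i + u_i / tpo_target_temperature),
+            # i.e. q_i ∝ p_i_old * exp(u_i / η), with u_i the whitened group reward from get_tpo_scores.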
+ old_sequence_logps = (old_per_token_logps * loss_mask).sum(dim=-1) + if self.tpo_length_normalize_logps: + old_sequence_logps = old_sequence_logps / loss_mask.sum(dim=-1).clamp(min=1) + all_process_old_sequence_logps = gather(old_sequence_logps) + all_process_tpo_valid_mask = gather(tpo_valid_mask) + tpo_scores = self.get_tpo_scores(rewards, num_generations, valid_mask=all_process_tpo_valid_mask) + all_process_tpo_targets = self.get_tpo_targets( + all_process_old_sequence_logps, + tpo_scores, + num_generations=num_generations, + temperature=self.tpo_target_temperature, + valid_mask=all_process_tpo_valid_mask, + ) + tpo_targets = all_process_tpo_targets[process_slice] + target_groups = all_process_tpo_targets.view(-1, num_generations) + target_entropy = -(target_groups * target_groups.clamp_min(torch.finfo(target_groups.dtype).tiny).log()) + target_entropy = target_entropy.sum(dim=1).mean() + self._metrics[mode]["tpo/target_entropy"].append(target_entropy.item()) + advantages = advantages[process_slice] + + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) + std_func_rewards = nanstd(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards) + rewards = (rewards_per_func * self.reward_weights.to(rewards_per_func.device).unsqueeze(0)).nansum(dim=1) + self._metrics[mode]["reward"].append(rewards.mean().item()) + self._metrics[mode]["reward_std"].append(rewards.std().item()) + self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) + + # Log prompt and completion texts + self._logs["prompt"].extend(gather_object(prompts_text)) + self._logs["completion"].extend(gather_object(completions_text)) + for i, name in enumerate(self.reward_func_names): + self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) + self._logs["advantages"].extend(all_process_advantages.tolist()) + + # Flush user-logged extra columns (from log_extra), gathering across processes. + # Keys must be sorted so that all ranks call gather_object in the same order, otherwise values + # get mis-attributed across columns (dict insertion order may differ between processes). + for column in sorted(self._pending_extra_logs): + self._logs["extra"][column].extend(gather_object(self._pending_extra_logs[column])) + self._pending_extra_logs.clear() + + # Flush user-logged metrics (from log_metric), averaging across processes. + # Keys must be sorted so that all ranks call accelerator.gather in the same order, otherwise values + # get mis-attributed across metrics (dict insertion order may differ between processes). 
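+        # Each metric is first averaged over the local batch, then averaged across processes via accelerator.gather.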
+ for name in sorted(self._pending_metrics): + values = self._pending_metrics[name] + local_mean = sum(values) / len(values) + global_mean = self.accelerator.gather(torch.tensor(local_mean, device=device)).mean().item() + self._metrics[mode][name].append(global_mean) + self._pending_metrics.clear() + + if images is not None: + self._logs["images"].extend(gather_object(images)) + + if self.use_vllm and self.vllm_importance_sampling_correction: + delta = torch.abs(old_per_token_logps - sampling_per_token_logps) + mask = completion_mask.bool() if tool_mask is None else (completion_mask * tool_mask).bool() + delta = delta[mask] + mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( + self.accelerator.gather(mean_delta).mean().item() + ) + self._metrics[mode]["sampling/sampling_logp_difference/max"].append( + self.accelerator.gather(max_delta).max().item() + ) + if sequence_level_is: + flat_is_ratio = vllm_importance_sampling_ratio.flatten() + else: + flat_is_ratio = vllm_importance_sampling_ratio[mask] + + min_importance_sampling_ratio = ( + torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + mean_importance_sampling_ratio = ( + torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + max_importance_sampling_ratio = ( + torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( + nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( + self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( + nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item() + ) + + output = { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "advantages": advantages, + "num_items_in_batch": num_items_in_batch, + } + if old_per_token_logps is not None: + output["old_per_token_logps"] = old_per_token_logps + if tpo_targets is not None: + output["tpo_targets"] = tpo_targets + output["tpo_valid_mask"] = tpo_valid_mask + if self.use_vllm and self.vllm_importance_sampling_correction: + output["importance_sampling_ratio"] = vllm_importance_sampling_ratio + if sampling_per_token_logps is not None: + output["sampling_per_token_logps"] = sampling_per_token_logps + if ref_per_token_logps is not None: + output["ref_per_token_logps"] = ref_per_token_logps + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] + if "token_type_ids" in forward_kwargs: + output["token_type_ids"] = forward_kwargs["token_type_ids"] + if "mm_token_type_ids" in forward_kwargs: + output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"] + if "image_position_ids" in forward_kwargs: + output["image_position_ids"] = 
forward_kwargs["image_position_ids"] + if images is not None: + output["num_images"] = num_images + if tool_mask is not None: + output["tool_mask"] = tool_mask + return output + + def compute_liger_loss(self, unwrapped_model, inputs): + # Compute the per-token log probabilities for the model + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + # Get the last hidden state of the model + last_hidden_state = self._get_last_hidden_state( + unwrapped_model, + input_ids, + attention_mask, + logits_to_keep, + inputs.get("pixel_values"), + inputs.get("image_grid_thw"), + inputs.get("pixel_attention_mask"), + inputs.get("image_sizes"), + inputs.get("image_position_ids"), + ) + + # Apply tool_mask (from env_mask) for loss computation in multi-turn training scenarios + loss_mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"] + mode = "train" if self.model.training else "eval" + normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval + self.liger_tpo_loss.num_generations = self.num_generations if mode == "train" else self.num_generations_eval + + # Compute loss and metrics using liger TargetPO loss + loss, mean_kl, mean_entropy = self.liger_tpo_loss( + _input=last_hidden_state, + lin_weight=unwrapped_model.lm_head.weight, + selected_token_ids=completion_ids, + # The attention_mask parameter in liger loss is actually used as a loss mask (not model attention) + attention_mask=loss_mask, + tpo_targets=inputs["tpo_targets"], + tpo_valid_mask=inputs.get("tpo_valid_mask"), + bias=unwrapped_model.lm_head.bias, + old_per_token_logps=inputs.get("old_per_token_logps"), + ref_per_token_logps=inputs.get("ref_per_token_logps"), + normalizer=normalizer, + ) + + if self.beta != 0.0: + self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).mean().item()) + self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) + return loss + + @profiling_decorator + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The TargetPOTrainer does not support returning outputs") + if self.use_liger_kernel: + # Compute the loss using the liger TargetPO loss + unwrapped_model = self.accelerator.unwrap_model(model) + return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs) + else: + return self._compute_loss(model, inputs) + + @staticmethod + def get_off_policy_mask( + advantages: torch.Tensor, + per_token_logps: torch.Tensor, + sampling_per_token_logps: torch.Tensor, + mask: torch.Tensor, + off_policy_threshold: float, + ) -> torch.Tensor: + """ + Computes the Off-Policy Sequence Mask from DeepSeek-V3.2 paper. Returns a (B, 1) tensor where 1.0 indicates + "Keep" and 0.0 indicates "Drop". 
+ """ + # forward KL div: log(pi_old) - log(pi_theta) + kl_div = sampling_per_token_logps - per_token_logps.detach() + # Sequence-level Mean KL (ignoring prompt+padding) + seq_kl_sum = (kl_div * mask).sum(dim=1, keepdim=True) + avg_seq_kl = seq_kl_sum / mask.sum(dim=1, keepdim=True).clamp(min=1.0) + # Keep if (Advantage >= 0) OR (KL <= delta) + is_pos_adv = advantages >= 0 + is_low_kl = avg_seq_kl <= off_policy_threshold + return (is_pos_adv | is_low_kl).to(dtype=mask.dtype) # (B, 1) + + @staticmethod + def get_tpo_targets( + old_sequence_logps: torch.Tensor, + scores: torch.Tensor, + num_generations: int, + temperature: float = 1.0, + valid_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Build the Target Policy Optimization target distribution for each prompt group. + + The target is q_i ∝ p_i_old * exp(score_i / temperature), where p_i_old is the rollout policy distribution + over the sampled completions in a prompt group. + """ + if temperature <= 0.0: + raise ValueError(f"temperature must be greater than 0.0. You provided {temperature}.") + + old_sequence_logps = old_sequence_logps.view(-1, num_generations) + scores = scores.view(-1, num_generations) + if valid_mask is not None: + valid_mask = valid_mask.view(-1, num_generations).bool() + old_sequence_logps = old_sequence_logps.masked_fill(~valid_mask, torch.finfo(old_sequence_logps.dtype).min) + target_logits = torch.log_softmax(old_sequence_logps, dim=1) + scores / temperature + if valid_mask is not None: + target_logits = target_logits.masked_fill(~valid_mask, torch.finfo(target_logits.dtype).min) + targets = torch.softmax(target_logits, dim=1) + if valid_mask is not None: + targets = torch.where(valid_mask, targets, torch.zeros_like(targets)) + return targets.view(-1).detach() + + @staticmethod + def get_tpo_scores( + scores: torch.Tensor, num_generations: int, valid_mask: torch.Tensor | None = None + ) -> torch.Tensor: + scores = scores.view(-1, num_generations) + if valid_mask is not None: + valid_mask = valid_mask.view(-1, num_generations).bool() + valid_count = valid_mask.sum(dim=1, keepdim=True).clamp(min=1) + mean_scores = scores.masked_fill(~valid_mask, 0.0).sum(dim=1, keepdim=True) / valid_count + centered_scores = torch.where(valid_mask, scores - mean_scores, torch.zeros_like(scores)) + std_scores = (centered_scores.square().sum(dim=1, keepdim=True) / valid_count).sqrt() + else: + mean_scores = scores.mean(dim=1, keepdim=True) + std_scores = scores.std(dim=1, unbiased=False, keepdim=True) + centered_scores = scores - mean_scores + scores = torch.where(std_scores > 1e-6, centered_scores / std_scores, centered_scores) + return scores.view(-1) + + @staticmethod + def _gather_tensor_with_grad(tensor: torch.Tensor) -> torch.Tensor: + # Autograd-aware all_gather: required when a TPO prompt group spans DP ranks, so the log-softmax + # normalizer's gradient routes back to the owning rank. + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return tensor + return torch.cat(_all_gather_with_grad(tensor), dim=0) + + @staticmethod + @torch.no_grad() + def get_gamma_weights( + advantages: torch.Tensor, + log_ratio_per_token: torch.Tensor, + mask: torch.Tensor, + importance_sampling_ratio: torch.Tensor | None, # (B, T) + k_pos: float = 2.0, + lambda_pos: float = 3.0, + k_neg: float = 3.0, + lambda_neg: float = 2.0, + ) -> torch.Tensor: + """ + Computes the Gamma weights for the VESPO loss. 
+        if temperature <= 0.0:
+            raise ValueError(f"temperature must be greater than 0.0. You provided {temperature}.")
+
+        old_sequence_logps = old_sequence_logps.view(-1, num_generations)
+        scores = scores.view(-1, num_generations)
+        if valid_mask is not None:
+            valid_mask = valid_mask.view(-1, num_generations).bool()
+            old_sequence_logps = old_sequence_logps.masked_fill(~valid_mask, torch.finfo(old_sequence_logps.dtype).min)
+        target_logits = torch.log_softmax(old_sequence_logps, dim=1) + scores / temperature
+        if valid_mask is not None:
+            target_logits = target_logits.masked_fill(~valid_mask, torch.finfo(target_logits.dtype).min)
+        targets = torch.softmax(target_logits, dim=1)
+        if valid_mask is not None:
+            targets = torch.where(valid_mask, targets, torch.zeros_like(targets))
+        return targets.view(-1).detach()
+
+    @staticmethod
+    def get_tpo_scores(
+        scores: torch.Tensor, num_generations: int, valid_mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        scores = scores.view(-1, num_generations)
+        if valid_mask is not None:
+            valid_mask = valid_mask.view(-1, num_generations).bool()
+            valid_count = valid_mask.sum(dim=1, keepdim=True).clamp(min=1)
+            mean_scores = scores.masked_fill(~valid_mask, 0.0).sum(dim=1, keepdim=True) / valid_count
+            centered_scores = torch.where(valid_mask, scores - mean_scores, torch.zeros_like(scores))
+            std_scores = (centered_scores.square().sum(dim=1, keepdim=True) / valid_count).sqrt()
+        else:
+            mean_scores = scores.mean(dim=1, keepdim=True)
+            std_scores = scores.std(dim=1, unbiased=False, keepdim=True)
+            centered_scores = scores - mean_scores
+        scores = torch.where(std_scores > 1e-6, centered_scores / std_scores, centered_scores)
+        return scores.view(-1)
+
+    @staticmethod
+    def _gather_tensor_with_grad(tensor: torch.Tensor) -> torch.Tensor:
+        # Autograd-aware all_gather: required when a TPO prompt group spans DP ranks, so the log-softmax
+        # normalizer's gradient routes back to the owning rank.
+        if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+            return tensor
+        return torch.cat(_all_gather_with_grad(tensor), dim=0)
+
+    @staticmethod
+    @torch.no_grad()
+    def get_gamma_weights(
+        advantages: torch.Tensor,
+        log_ratio_per_token: torch.Tensor,
+        mask: torch.Tensor,
+        importance_sampling_ratio: torch.Tensor | None,  # (B, T)
+        k_pos: float = 2.0,
+        lambda_pos: float = 3.0,
+        k_neg: float = 3.0,
+        lambda_neg: float = 2.0,
+    ) -> torch.Tensor:
+        """
+        Computes the Gamma weights for the VESPO loss. For reference:
+            φ(w) = e^λ × w^k × e^{-λw} is the gamma weighting (normalized so φ(1)=1)
+            with w = sequence-level importance sampling ratio
+            note: we compute φ(w) in log space
+
+        φ(w) is detached via @torch.no_grad() and only acts as a gradient scaling coefficient.
+
+        VESPO loss = -φ(w) × A × log_prob, so the gradient naturally gives φ(w) × A × ∇log π.
+        """
+        # Reduce the clamp range directly to log(1e-8) ~ -18.42, to avoid recomputing log_w=log(w.clamp(min=1e-8))
+        # later. This is solely to faithfully match the original implementation; otherwise keeping -20 would be fine.
+        lower_clamp = math.log(1e-8)
+
+        # Sequence-level log ratio Σ log(π_θ/π_old) (not a mean like for `log_importance_weights`)
+        log_ratio_clamped = torch.clamp(log_ratio_per_token, -20.0, 20.0)
+        seq_log_ratio = torch.sum(log_ratio_clamped * mask, dim=-1, keepdim=True)  # (B, 1)
+
+        # Apply token-level TIS or MIS correction (in log space)
+        if importance_sampling_ratio is not None:
+            log_is_ratio = torch.clamp(torch.log(importance_sampling_ratio), lower_clamp, 20.0)
+            # log(w) = log(π_θ/π_old) + log(π_old/π_sampler)
+            seq_log_ratio += torch.sum(log_is_ratio, dim=-1, keepdim=True)
+
+        log_w_seq = torch.clamp(seq_log_ratio, lower_clamp, 20.0)
+        w_seq = torch.exp(log_w_seq)
+
+        # compute k and lambda based on advantage sign
+        is_nonneg_adv = advantages >= 0
+        k_seq = torch.where(is_nonneg_adv, k_pos, k_neg)
+        lambda_seq = torch.where(is_nonneg_adv, lambda_pos, lambda_neg).clamp(min=1e-4)
+
+        # log(φ(w)) = λ + k × log(w) - λ × w
+        log_phi = lambda_seq + k_seq * log_w_seq - lambda_seq * w_seq
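+        # Sanity check: at w = 1 (fully on-policy), log(φ) = λ + k·log(1) - λ·1 = 0, so φ(1) = 1 and the weighting
+        # leaves the plain policy gradient unchanged.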
+        phi_seq = torch.exp(log_phi).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)
+
+        return phi_seq  # (B, 1)
+
+    def _compute_loss(self, model, inputs):
+        # Compute the per-token log probabilities for the model
+        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"]
+
+        # Compute the per_token_logps and the entropy at each position in the completion
+        per_token_logps, entropies = self._get_per_token_logps_and_entropies(
+            model,
+            input_ids,
+            attention_mask,
+            logits_to_keep,
+            compute_entropy=True,
+            pixel_values=inputs.get("pixel_values"),
+            image_grid_thw=inputs.get("image_grid_thw"),
+            num_images=inputs.get("num_images"),
+            pixel_attention_mask=inputs.get("pixel_attention_mask"),
+            image_sizes=inputs.get("image_sizes"),
+            token_type_ids=inputs.get("token_type_ids"),
+            mm_token_type_ids=inputs.get("mm_token_type_ids"),
+            image_position_ids=inputs.get("image_position_ids"),
+        )
+
+        if self.top_entropy_quantile < 1.0:
+            entropy_mask = self.get_high_entropy_mask(entropies, mask, 1 - self.top_entropy_quantile)
+        else:
+            entropy_mask = None
+
+        # Compute the loss
+        advantages = inputs["advantages"]
+        # In the base GRPO implementation, advantages are expected to have shape (B,). To support subclasses that
+        # provide advantages with shape (B, T) (e.g., MiniLLM), we *conditionally* unsqueeze the tensor.
+        if advantages.dim() == 1:
+            advantages = advantages.unsqueeze(1)
+        # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps,
+        # old_per_token_logps == per_token_logps. In this case we can skip its computation
+        # (see _generate_and_score_completions) and instead use per_token_logps.detach().
+        # The exception is when using vLLM, where we always compute old_per_token_logps
+        # for importance sampling
+        old_per_token_logps = inputs.get("old_per_token_logps")
+        old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps
+
+        if self.loss_type == "tpo":
+            if "tpo_targets" not in inputs:
+                raise RuntimeError("TPO loss requires `tpo_targets` in the prepared inputs.")
+            if self.off_policy_mask_threshold is not None:
+                raise ValueError("TPO loss does not support `off_policy_mask_threshold`.")
+            if self.top_entropy_quantile < 1.0:
+                raise ValueError("TPO loss does not support `top_entropy_quantile < 1.0`.")
+
+            mode = "train" if self.model.training else "eval"
+            num_generations = self.num_generations if mode == "train" else self.num_generations_eval
+            # Length-normalized sequence logp (see _generate_and_score_completions for why).
+            sequence_logps = (per_token_logps * mask).sum(dim=-1)
+            if self.tpo_length_normalize_logps:
+                sequence_logps = sequence_logps / mask.sum(dim=-1).clamp(min=1)
+            all_sequence_logps = self._gather_tensor_with_grad(sequence_logps)
+            tpo_valid_mask = inputs.get("tpo_valid_mask")
+            if tpo_valid_mask is not None:
+                tpo_valid_mask = tpo_valid_mask.to(device=sequence_logps.device, dtype=torch.bool)
+                all_tpo_valid_mask = gather(tpo_valid_mask)
+                all_sequence_logps = all_sequence_logps.masked_fill(
+                    ~all_tpo_valid_mask, torch.finfo(all_sequence_logps.dtype).min
+                )
+            all_logps = torch.log_softmax(all_sequence_logps.view(-1, num_generations), dim=1).view(-1)
+            process_slice = slice(
+                self.accelerator.process_index * sequence_logps.size(0),
+                (self.accelerator.process_index + 1) * sequence_logps.size(0),
+            )
+            logps = all_logps[process_slice]
+            tpo_targets = inputs["tpo_targets"].to(logps.dtype)
+            if tpo_valid_mask is not None:
+                tpo_targets = torch.where(tpo_valid_mask, tpo_targets, torch.zeros_like(tpo_targets))
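+            # Sequence-level cross-entropy between the target distribution q and the policy's group-level softmax;
+            # scaling by num_generations / numel() averages the per-group cross-entropy over this process's
+            # prompt groups.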
+            loss = -(tpo_targets * logps).sum() * num_generations / tpo_targets.numel()
+
+            normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
+            loss = loss / normalizer
+
+            if self.beta != 0.0:
+                ref_per_token_logps = inputs["ref_per_token_logps"]
+                per_token_kl = (
+                    torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+                )
+                if self.args.use_bias_correction_kl:
+                    per_token_kl = per_token_kl * torch.exp(per_token_logps - old_per_token_logps)
+                kl_normalizer = (
+                    self.accelerator.gather(mask.sum()).sum().clamp(min=1.0) / self.accelerator.num_processes
+                )
+                kl_loss = (per_token_kl * mask).sum() / kl_normalizer
+                loss = loss + self.beta * kl_loss / normalizer
+                self._metrics[mode]["kl"].append(self.accelerator.gather(kl_loss).nanmean().item())
+
+            completion_token_count = mask.sum().clamp(min=1.0)
+            mean_entropy = (entropies * mask).sum() / completion_token_count
+            self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
+
+            return loss
+
+        if self.off_policy_mask_threshold is not None:
+            # OPSM should use inference-time logprobs to detect both sources of off-policyness:
+            # 1. Drift from gradient updates (always present)
+            # 2. Drift from training-inference mismatch (when using vLLM)
+            # When using vLLM, prioritize sampling_per_token_logps, otherwise use old_per_token_logps
+            sampling_per_token_logps = inputs.get("sampling_per_token_logps", old_per_token_logps)
+
+            off_policy_mask = self.get_off_policy_mask(
+                advantages=advantages,
+                per_token_logps=per_token_logps,
+                sampling_per_token_logps=sampling_per_token_logps,
+                mask=mask,
+                off_policy_threshold=self.off_policy_mask_threshold,
+            )
+
+        log_ratio = per_token_logps - old_per_token_logps
+        if self.importance_sampling_level == "token":
+            log_importance_weights = log_ratio
+        elif self.importance_sampling_level == "sequence":
+            log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
+            log_importance_weights = log_importance_weights.unsqueeze(-1)
+        else:
+            raise ValueError(
+                f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
+                "and 'sequence'."
+            )
+
+        coef_1 = torch.exp(log_importance_weights)
+
+        # Compute the KL divergence between the model and the reference model
+        if self.beta != 0.0:
+            ref_per_token_logps = inputs["ref_per_token_logps"]
+            per_token_kl = (
+                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+            )
+            # Importance sampling correction for the KL divergence
+            if self.args.use_bias_correction_kl:
+                per_token_kl = per_token_kl * coef_1
+
+        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+        if self.loss_type == "cispo":
+            clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
+            per_token_loss = -clamped_ratios * advantages * per_token_logps
+        elif self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]:
+            coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+            # Two-sided clipping
+            if self.args.delta is not None:
+                coef_1 = torch.clamp(coef_1, max=self.args.delta)
+
+            per_token_loss1 = coef_1 * advantages
+            per_token_loss2 = coef_2 * advantages
+            per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+        elif self.loss_type == "sapo":
+            temperatures = torch.where(advantages > 0, self.args.sapo_temperature_pos, self.args.sapo_temperature_neg)
+            soft_coef_1 = torch.sigmoid(temperatures * (coef_1 - 1)) * 4 / temperatures
+            per_token_loss = -soft_coef_1 * advantages
+        elif self.loss_type == "vespo":
+            phi_seq = self.get_gamma_weights(
+                advantages=advantages,
+                log_ratio_per_token=log_ratio,
+                mask=mask,
+                importance_sampling_ratio=inputs.get("importance_sampling_ratio"),
+                k_pos=self.args.vespo_k_pos,
+                lambda_pos=self.args.vespo_lambda_pos,
+                k_neg=self.args.vespo_k_neg,
+                lambda_neg=self.args.vespo_lambda_neg,
+            )
+            per_token_loss = -phi_seq * advantages * per_token_logps
+        else:
+            raise ValueError(f"Unknown loss type: {self.loss_type}")
+
+        if self.off_policy_mask_threshold is not None:
+            per_token_loss = per_token_loss * off_policy_mask
+
+        if entropy_mask is not None:
+            per_token_loss = per_token_loss * entropy_mask
+
+        if self.use_vllm and self.vllm_importance_sampling_correction and self.loss_type != "vespo":
+            per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]
+
+        if self.beta != 0.0:
+            per_token_loss = per_token_loss + self.beta * per_token_kl
+
+        mode = "train" if self.model.training else "eval"
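+        # Aggregation differs per loss type: grpo/sapo take a per-sequence mean then a batch mean, bnpo averages
+        # over unmasked tokens, dr_grpo normalizes by max_completion_length, cispo/dapo/vespo normalize by the
+        # per-process share of num_items_in_batch, and luspo weights each sequence by its token count.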
+        if self.loss_type in ["grpo", "sapo"]:
+            loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
+            normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0  # no accum in eval
+            loss = loss / normalizer
+        elif self.loss_type == "bnpo":
+            loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
+            normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0  # no accum in eval
+            loss = loss / normalizer
+        elif self.loss_type == "dr_grpo":
+            loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
+            normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0  # no accum in eval
+            loss = loss / normalizer
+        elif self.loss_type in ["cispo", "dapo", "vespo"]:
+            normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
+            loss = (per_token_loss * mask).sum() / normalizer
+        elif self.loss_type == "luspo":
+            # Unless importance_sampling_level="token" (not recommended here), per_token_loss is expected to be (B, 1)
+            loss = (per_token_loss * mask.sum(1, keepdim=True)).mean()
+            normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
+            loss = loss / normalizer
+        else:
+            raise ValueError(f"Unknown loss type: {self.loss_type}")
+
+        # Log the metrics
+        completion_token_count = mask.sum().clamp(min=1.0)
+
+        def masked_batch_mean(x):
+            if x.shape[1] == 1:  # when importance_sampling_level == "sequence"
+                return x.mean()
+            else:
+                return (x * mask).sum() / completion_token_count
+
+        if self.beta != 0.0:
+            mean_kl = masked_batch_mean(per_token_kl)
+            self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item())
+
+        mean_entropy = masked_batch_mean(entropies)
+        self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
+
+        if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]:
+            # Compute the clipped probability ratios
+            is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)
+            is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)
+            is_region_clipped = is_low_clipped | is_high_clipped
+
+            low_clip = masked_batch_mean(is_low_clipped.float())
+            high_clip = masked_batch_mean(is_high_clipped.float())
+            clip_ratio = masked_batch_mean(is_region_clipped.float())
+
+            gathered_low_clip = self.accelerator.gather(low_clip)
+            self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
+            self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
+            gathered_high_clip = self.accelerator.gather(high_clip)
+            self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
+            self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
+            gathered_clip_ratio = self.accelerator.gather(clip_ratio)
+            self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
+        elif self.loss_type == "cispo":
+            is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0)
+            cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
+            gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
+            self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item())
+        elif self.loss_type == "vespo":
+            gathered_phi_seq = self.accelerator.gather(phi_seq)
+            self._metrics[mode]["vespo/phi_seq_mean"].append(gathered_phi_seq.nanmean().item())
+
+        return loss
+
+    # During eval, Trainer calls prediction_step. If no labels are present in the inputs, it only runs forward and
+    # returns logits. We override prediction_step to force compute_loss, because this trainer doesn't involve labels.
+    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None):
+        inputs = self._prepare_inputs(inputs)
+        with torch.no_grad():
+            with self.compute_loss_context_manager():
+                loss = self.compute_loss(model, inputs)
+            loss = loss.mean().detach()
+        return loss, None, None
+
+    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
+        mode = "train" if self.model.training else "eval"
+        # Average the metrics
+        metrics = {}
+        for key, val in self._metrics[mode].items():
+            # Filter out NaN values before averaging. A reward function that returns None for all samples
+            # in a batch produces NaN for that batch's metric. With logging_steps > 1, a naive sum()/len()
+            # would let a single NaN contaminate valid data from other batches. Only return None when no
+            # valid values remain (e.g. JSON loggers crash on float NaN).
+            valid = [v for v in val if not math.isnan(v)]
+            metrics[key] = sum(valid) / len(valid) if valid else None
+
+        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+        if mode == "eval":
+            metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+        logs = {**logs, **metrics}
+        super().log(logs, start_time)
+        self._metrics[mode].clear()
+
+        if self.accelerator.is_main_process and self.log_completions:
+            if is_rich_available():
+                print_prompt_completions_sample(
+                    self._logs["prompt"],
+                    self._logs["completion"],
+                    self._logs["rewards"],
+                    self._logs["advantages"],
+                    self.state.global_step,
+                    self.num_completions_to_print,
+                    extra=dict(self._logs["extra"]),
+                )
+
+            logging_backends = []
+            if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
+                logging_backends.append(wandb)
+            if self.args.report_to and "trackio" in self.args.report_to:
+                logging_backends.append(trackio)
+
+            table = {
+                "step": [self.state.global_step] * len(self._logs["prompt"]),
+                "prompt": self._logs["prompt"],
+                "completion": self._logs["completion"],
+                **self._logs["rewards"],
+                **self._logs["extra"],
+                "advantage": self._logs["advantages"],
+            }
+
+            df_base = pd.DataFrame(table)
+            df_base.to_parquet(
+                os.path.join(
+                    self.args.output_dir,
+                    "completions",
+                    f"completions_{self.state.global_step:05d}.parquet",
+                )
+            )
+
+            images_raw = self._logs["images"] or []
+
+            for logging_backend in logging_backends:
+                if images_raw:
+                    images = []
+                    for image_list in self._logs["images"]:
+                        if image_list:
+                            images.append([logging_backend.Image(image) for image in image_list])
+                        else:
+                            images.append([])
+                    df = pd.concat(
+                        [df_base, pd.Series(images, name="image")],
+                        axis=1,
+                        copy=False,
+                    )
+                else:
+                    df = df_base
+
+                if self.log_unique_prompts:
+                    df = df.drop_duplicates(subset=["prompt"])
+
+                logging_backend.log({"completions": logging_backend.Table(dataframe=df)})
+
+    # Ensure the model card is saved along with the checkpoint
+    def _save_checkpoint(self, model, trial):
+        if self.args.hub_model_id is None:
+            model_name = Path(self.args.output_dir).name
+        else:
+            model_name = self.args.hub_model_id.split("/")[-1]
+        self.create_model_card(model_name=model_name)
+        super()._save_checkpoint(model, trial)