diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 4c78c9d3bb5..a74f85a5e75 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -1641,8 +1641,8 @@ from trl.experimental.sdpo import SDPOConfig, SDPOTrainer training_args = SDPOConfig( distillation_alpha=0.5, # Jensen-Shannon divergence (recommended) - distillation_topk=100, # Top-K logit distillation approximation - full_logit_distillation=True, # Required for top-K logit-level SDPO + distillation_mode="topk_logits", # Explicitly select top-K logit distillation + distillation_topk=100, # Required for top-K logit distillation distillation_is_clip=2.0, # Importance sampling clipping distillation_weight=1.0, # Weight for self-distillation loss sdpo_policy_loss_mode="distillation_only", @@ -1689,6 +1689,7 @@ dataset = Dataset.from_dict( training_args = SDFTConfig( distillation_alpha=0.5, + distillation_mode="topk_logits", distillation_topk=5, max_completion_length=64, ) @@ -1739,6 +1740,67 @@ Expected dataset columns: For more details, see the [SSD Trainer documentation](ssd_trainer). +### Self-Distillation Zero: Self-Revision Turns Binary Rewards into Dense Supervision + +**📜 Paper**: https://huggingface.co/papers/2604.12002 + +SD-ZERO turns binary verifier rewards into dense supervision in two phases. Phase 1 — **Self-Revision Training (SRT)** — first has a model answer a problem `x` with an initial attempt `y_init`. A binary verifier then decides whether `y_init` is correct and chooses a control prompt `P_r`: rephrase the solution if the attempt is correct, or restart if it is not. Conditioned on `(x, y_init, P_r)`, the model samples revised answers and keeps only revisions `y_revised` that verify correct. Those accepted self-revision traces are then used for supervised learning with a joint objective: predict `y_revised` given `(x, y_init, P_r)`, and predict the full assistant trace `[y_init, P_r, y_revised]` from `x`. Phase 1 is implemented as [`experimental.sdzero.SRTTrainer`], and the companion collection script [`trl/experimental/sdzero/srt_collect.py`] is the recommended way to build the offline revision dataset. + +```python +from datasets import load_from_disk + +from trl.experimental.sdzero import SRTConfig, SRTTrainer + +training_args = SRTConfig( + include_revision_loss=True, # L_revision term + include_generation_loss=True, # L_generation term + assistant_turn_template="{y_init}\n\n{control_prompt}\n\n{y_revised}", +) + +trainer = SRTTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + args=training_args, + train_dataset=load_from_disk("/path/to/revision_dataset"), +) +trainer.train() +``` + +Expected dataset columns: + +- `problem` +- `y_init` +- `control_prompt` +- `y_revised` + +Phase 2 — **On-Policy Self-Distillation** — distills the reviser back into the generator. At each training step, the student generates a response `y` on-policy. The teacher context is the on-policy generation `y_init`, the problem `x`, and the verifier-selected `P_r`. By default, [`experimental.sdzero.SDZeroTrainer`] matches the paper's frozen SRT teacher and full-vocabulary `D_KL(student || teacher)` objective. + +```python +from datasets import Dataset + +from trl.experimental.sdzero import SDZeroConfig, SDZeroTrainer + +dataset = Dataset.from_list([ + {"prompt": [{"role": "user", "content": "...problem..."}], "answer": "...gold answer..."}, +]) + +training_args = SDZeroConfig( + max_completion_length=512, + assistant_turn_template="{y}\n\n{control_prompt}\n\n", +) + +trainer = SDZeroTrainer( + model="path/to/srt-checkpoint", + args=training_args, + train_dataset=dataset, +) +trainer.train() +``` + +Expected dataset columns: + +- `prompt` (conversational list or plain string) +- `answer` (gold answer; passed to the binary verifier) + ## Distributed Training ### ZeRO: Memory Optimizations Toward Training Trillion Parameter Models diff --git a/docs/source/sdft_trainer.md b/docs/source/sdft_trainer.md index 17f6433fd78..23f29966458 100644 --- a/docs/source/sdft_trainer.md +++ b/docs/source/sdft_trainer.md @@ -32,6 +32,7 @@ dataset = Dataset.from_dict( training_args = SDFTConfig( output_dir="sdft-model", distillation_alpha=0.5, + distillation_mode="topk_logits", distillation_topk=5, max_completion_length=64, ) diff --git a/docs/source/sdpo_trainer.md b/docs/source/sdpo_trainer.md index 3910f674baf..7f82f565bcb 100644 --- a/docs/source/sdpo_trainer.md +++ b/docs/source/sdpo_trainer.md @@ -11,8 +11,9 @@ In the current TRL implementation: - the default SDPO policy loss mode is `distillation_only` - `hybrid` mode is also available to combine the base policy loss with the self-distillation loss - supported teacher regularization modes are `ema` and `none` -- `distillation_topk` is only valid when `full_logit_distillation=True` -- when `full_logit_distillation=False`, SDPO uses token-level reverse KL and requires `distillation_alpha=1.0` +- `distillation_mode` selects between `sampled_token`, `full_logits`, and `topk_logits` +- `distillation_topk` is only valid when `distillation_mode="topk_logits"` +- when `distillation_mode="sampled_token"`, SDPO uses token-level reverse KL and requires `distillation_alpha=1.0` - environment feedback can be injected into teacher reprompts when the dataset exposes a `privileged_context` column ## Expected dataset columns @@ -38,8 +39,8 @@ dataset = Dataset.from_dict( training_args = SDPOConfig( output_dir="sdpo-model", - distillation_topk=100, # Top-K logit distillation approximation - full_logit_distillation=True, # Required for top-K; enables non-reverse divergences + distillation_mode="topk_logits", # Explicitly select top-K logit distillation + distillation_topk=100, # Required when using top-K logit distillation include_environment_feedback=True, # Use dataset privileged_context for teacher reprompts ) @@ -88,7 +89,7 @@ python trl/experimental/sdpo/sdpo.py \ --num_generations 8 \ --generation_batch_size 32 \ --distillation_alpha 1.0 \ - --full_logit_distillation false \ + --distillation_mode sampled_token \ --sdpo_policy_loss_mode hybrid \ --report_to none \ --eval_strategy steps \ diff --git a/tests/experimental/test_base_self_distillation_trainer.py b/tests/experimental/test_base_self_distillation_trainer.py new file mode 100644 index 00000000000..bd7c7e880b1 --- /dev/null +++ b/tests/experimental/test_base_self_distillation_trainer.py @@ -0,0 +1,276 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import defaultdict +from types import SimpleNamespace + +import pytest +import torch +from datasets import Dataset +from transformers import AutoModelForCausalLM +from transformers.utils import is_peft_available + +from trl.experimental.self_distillation.base_self_distillation_trainer import ( + BaseSelfDistillationTrainer, + DistillationLogits, +) +from trl.experimental.self_distillation.self_distillation_config import SelfDistillationConfig + +from ..testing_utils import TrlTestCase + + +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + +class MinimalSelfDistillationTrainer(BaseSelfDistillationTrainer): + def finalize_batch(self, inputs, rollout_batch): + del inputs + batch = rollout_batch.as_dict() + batch["teacher_input_ids"] = rollout_batch.prompt_ids + batch["teacher_attention_mask"] = rollout_batch.prompt_mask + return batch + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + del inputs, num_items_in_batch + anchor = next(model.parameters()) + return anchor.sum() * 0.0 + + +class FakeTextTokenizer: + def __call__(self, text, **kwargs): + del kwargs + token_map = { + "short prompt": [1, 2, 3], + "long prompt": [10, 11, 12, 13, 14], + } + return {"input_ids": [token_map[prompt] for prompt in text]} + + +class FakeChatProcessor: + def __init__(self): + self.calls = [] + + def apply_chat_template(self, conversation, add_generation_prompt, tokenize, return_dict, **kwargs): + self.calls.append( + { + "conversation": conversation, + "add_generation_prompt": add_generation_prompt, + "tokenize": tokenize, + "return_dict": return_dict, + "kwargs": kwargs, + } + ) + return {"input_ids": [[21, 22, 23, 24]]} + + +class TestBaseSelfDistillationTrainer(TrlTestCase): + @staticmethod + def _make_loss_test_trainer(**args_overrides): + trainer = object.__new__(MinimalSelfDistillationTrainer) + args = { + "distillation_mode": "sampled_token", + "distillation_topk": None, + "distillation_alpha": 1.0, + "distillation_add_tail": False, + "distillation_is_clip": None, + } + args.update(args_overrides) + trainer.args = SimpleNamespace(**args) + trainer.loss_type = "dapo" + trainer.max_completion_length = 2 + trainer.accelerator = SimpleNamespace(gather=lambda tensor: tensor) + trainer._metrics = { + "train": defaultdict(list), + "eval": defaultdict(list), + } + trainer._name = "Minimal Self Distillation" + return trainer + + def test_teacher_model_kind_live_uses_student_model(self): + dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]}) + training_args = SelfDistillationConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=1, + num_generations=1, + teacher_model_kind="live", + report_to="none", + ) + + trainer = MinimalSelfDistillationTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + assert trainer.teacher_model is trainer.model + + @pytest.mark.skipif(not is_peft_available(), reason="PEFT is required for this test") + def test_warns_when_initial_student_already_has_a_peft_adapter(self, caplog): + dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]}) + training_args = SelfDistillationConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=1, + num_generations=1, + teacher_model_kind="base", + report_to="none", + ) + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + model = get_peft_model( + model, + LoraConfig( + r=4, + lora_alpha=8, + target_modules=["q_proj", "v_proj"], + bias="none", + task_type="CAUSAL_LM", + ), + ) + + with caplog.at_level( + logging.WARNING, logger="trl.experimental.self_distillation.base_self_distillation_trainer" + ): + MinimalSelfDistillationTrainer( + model=model, + args=training_args, + train_dataset=dataset, + ) + + assert "already contains a PEFT adapter" in caplog.text + assert "`teacher_model_kind='base'` may refer to the underlying base weights" in caplog.text + + @pytest.mark.parametrize("teacher_model_kind", ["base", "ema"]) + def test_teacher_model_kind_base_and_ema_use_frozen_teacher_copy(self, teacher_model_kind): + dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]}) + training_args = SelfDistillationConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=1, + num_generations=1, + teacher_model_kind=teacher_model_kind, + report_to="none", + ) + + trainer = MinimalSelfDistillationTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + assert trainer.teacher_model is not trainer.model + assert trainer.teacher_model.training is False + + student_param = next(trainer.model.parameters()) + teacher_param = next(trainer.teacher_model.parameters()) + assert teacher_param.requires_grad is False + assert teacher_param.data_ptr() != student_param.data_ptr() + + def test_tokenize_prompts_truncates_text_prompts_from_left(self): + trainer = object.__new__(MinimalSelfDistillationTrainer) + trainer.processing_class = FakeTextTokenizer() + trainer.max_prompt_length = 3 + + prompt_ids = trainer._tokenize_prompts(["long prompt", "short prompt"]) + + assert prompt_ids == [[12, 13, 14], [1, 2, 3]] + + def test_tokenize_prompts_for_conversational_prompts_forwards_chat_template_kwargs(self): + trainer = object.__new__(MinimalSelfDistillationTrainer) + trainer.processing_class = FakeChatProcessor() + trainer.max_prompt_length = 2 + trainer.chat_template_kwargs = {"enable_thinking": False} + + prompt_ids = trainer._tokenize_prompts([[{"role": "user", "content": "Solve 2+2."}]]) + + assert prompt_ids == [[23, 24]] + assert trainer.processing_class.calls == [ + { + "conversation": [[{"role": "user", "content": "Solve 2+2."}]], + "add_generation_prompt": True, + "tokenize": True, + "return_dict": True, + "kwargs": {"enable_thinking": False}, + } + ] + + def test_compute_self_distillation_loss_ignores_masked_completion_tokens(self): + trainer = self._make_loss_test_trainer( + distillation_mode="full_logits", + distillation_alpha=0.0, + ) + model = SimpleNamespace(training=True) + + # Token 0 is active and has a known non-zero divergence. + # Token 1 is intentionally very different but masked out, so it must not affect the loss. + student_probs = torch.tensor([[[0.8, 0.2], [0.01, 0.99]]], dtype=torch.float32) + teacher_probs = torch.tensor([[[0.5, 0.5], [0.99, 0.01]]], dtype=torch.float32) + distillation_logits = DistillationLogits( + completion_ids=torch.tensor([[0, 1]], dtype=torch.long), + completion_mask=torch.tensor([[1, 1]], dtype=torch.long), + response_mask=torch.tensor([[1, 0]], dtype=torch.long), + student_logits=student_probs.log(), + teacher_logits=teacher_probs.log(), + ) + + loss = trainer._compute_self_distillation_loss(model, {}, distillation_logits) + + expected_active_token_loss = teacher_probs[0, 0, 0] * ( + teacher_probs[0, 0, 0].log() - student_probs[0, 0, 0].log() + ) + teacher_probs[0, 0, 1] * (teacher_probs[0, 0, 1].log() - student_probs[0, 0, 1].log()) + torch.testing.assert_close(loss, expected_active_token_loss) + torch.testing.assert_close( + torch.tensor(trainer._metrics["train"]["self_distillation/distillation_loss"]), + expected_active_token_loss.unsqueeze(0), + ) + + def test_compute_self_distillation_loss_applies_importance_sampling_clip(self): + trainer = self._make_loss_test_trainer(distillation_is_clip=2.0) + model = SimpleNamespace(training=True) + + student_token_probs = torch.tensor([[0.2, 0.4]], dtype=torch.float32) + teacher_token_probs = torch.tensor([[0.5, 0.5]], dtype=torch.float32) + old_token_probs = torch.tensor([[0.05, 0.4]], dtype=torch.float32) + clip_coeff = trainer.args.distillation_is_clip + + distillation_logits = DistillationLogits( + completion_ids=torch.tensor([[0, 1]], dtype=torch.long), + completion_mask=torch.tensor([[1, 1]], dtype=torch.long), + response_mask=torch.tensor([[1, 1]], dtype=torch.long), + student_logits=torch.log(torch.tensor([[[0.2, 0.8], [0.6, 0.4]]], dtype=torch.float32)), + teacher_logits=torch.log(torch.tensor([[[0.5, 0.5], [0.5, 0.5]]], dtype=torch.float32)), + ) + + loss = trainer._compute_self_distillation_loss( + model, + {"old_per_token_logps": old_token_probs.log()}, + distillation_logits, + ) + + raw_per_token_loss = (student_token_probs.log() - teacher_token_probs.log()) * student_token_probs.log() + clipped_ratio = torch.minimum( + student_token_probs / old_token_probs, torch.full_like(student_token_probs, clip_coeff) + ) + expected_loss = (raw_per_token_loss * clipped_ratio).mean() + + torch.testing.assert_close(loss, expected_loss) + torch.testing.assert_close( + torch.tensor(trainer._metrics["train"]["self_distillation/distillation_loss"]), + expected_loss.unsqueeze(0), + ) diff --git a/tests/experimental/test_sdft_trainer.py b/tests/experimental/test_sdft_trainer.py index 9ca9b6c579a..f307c5a96dd 100644 --- a/tests/experimental/test_sdft_trainer.py +++ b/tests/experimental/test_sdft_trainer.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch + import pytest import torch from datasets import Dataset from transformers import AutoModelForCausalLM, TrainerCallback, TrainerControl, TrainerState, TrainingArguments from transformers.utils import is_peft_available -from trl.data_utils import maybe_apply_chat_template from trl.experimental.sdft import SDFTConfig, SDFTTrainer from ..testing_utils import TrlTestCase, require_peft @@ -27,18 +28,18 @@ if is_peft_available(): from peft import LoraConfig, get_peft_model, get_peft_model_state_dict - from trl.experimental.self_distillation.peft_adapter_ema_callback import PEFTAdapterEMACallback + from trl.experimental.self_distillation.teacher_sync import PEFTAdapterEMACallback class SelfDistillationCaptureCallback(TrainerCallback): def __init__(self): - self.captured_generation_prompt_text = None + self.captured_generation_prompts = None self.captured_old_per_token_logps = None self.generation_batch_build_count = 0 - def on_generation_prompts_selected(self, generation_prompt_text=None, **kwargs): - if self.captured_generation_prompt_text is None and generation_prompt_text is not None: - self.captured_generation_prompt_text = generation_prompt_text[0] + def on_generation_prompts_selected(self, generation_prompts=None, **kwargs): + if self.captured_generation_prompts is None and generation_prompts is not None: + self.captured_generation_prompts = generation_prompts def on_self_distillation_batch_prepared(self, old_per_token_logps=None, **kwargs): if self.captured_old_per_token_logps is None and old_per_token_logps is not None: @@ -74,6 +75,7 @@ def test_training_rejects_none_privileged_context(self): with pytest.raises(ValueError, match="`privileged_context` must not be None"): trainer.train() + @pytest.mark.skip(reason="`generate_from_teacher` is not ported yet") def test_training_with_generate_from_teacher(self): dataset = Dataset.from_dict( { @@ -105,9 +107,8 @@ def test_training_with_generate_from_teacher(self): trainer.train() - assert capture_callback.captured_generation_prompt_text is not None - assert "Solve 2+2." in capture_callback.captured_generation_prompt_text - assert "Teacher hint" in capture_callback.captured_generation_prompt_text + assert capture_callback.captured_generation_prompts is not None + assert capture_callback.captured_generation_prompts[0] != dataset[0]["prompt"] def test_training_with_chat_template_kwargs(self): dataset = Dataset.from_dict( @@ -141,15 +142,15 @@ def test_training_with_chat_template_kwargs(self): callbacks=[capture_callback], ) - expected_prompt = maybe_apply_chat_template( - {"prompt": dataset[0]["prompt"]}, + with patch.object( trainer.processing_class, - **training_args.chat_template_kwargs, - )["prompt"] - - trainer.train() + "apply_chat_template", + wraps=trainer.processing_class.apply_chat_template, + ) as mock_apply_chat_template: + trainer.train() - assert capture_callback.captured_generation_prompt_text == expected_prompt + assert mock_apply_chat_template.call_count > 0 + assert any(call.kwargs.get("enable_thinking") is False for call in mock_apply_chat_template.call_args_list) @require_peft def test_training_with_peft_model(self): @@ -205,9 +206,9 @@ def test_training_with_peft_model_and_sync_ref_model(self): max_completion_length=8, max_steps=2, num_generations=1, - sync_ref_model=True, - ref_model_mixup_alpha=0.05, - ref_model_sync_steps=1, + teacher_model_kind="ema", + teacher_update_rate=0.05, + teacher_sync_steps=1, ) trainer = SDFTTrainer( diff --git a/tests/experimental/test_sdpo_trainer.py b/tests/experimental/test_sdpo_trainer.py index 1bb666a7dcc..81a29af4822 100644 --- a/tests/experimental/test_sdpo_trainer.py +++ b/tests/experimental/test_sdpo_trainer.py @@ -100,14 +100,6 @@ def test_vllm_config_defaults_match_reference_trainers(self): assert config.vllm_model_impl == "vllm" def test_generate_vllm_syncs_on_step_change_and_uses_mode_specific_num_generations(self): - class FakeTokenizer: - def __call__(self, text, **kwargs): - token_map = { - "Solve 2+2.": [11, 12], - "Check 3+3.": [21, 22], - } - return {"input_ids": [token_map[prompt] for prompt in text]} - class FakeVLLMGeneration: def __init__(self): self.sync_weights_call_count = 0 @@ -135,11 +127,9 @@ def generate(self, prompts, images, num_generations): trainer.model = SimpleNamespace(training=True) trainer.state = SimpleNamespace(global_step=4) trainer._last_loaded_step = 3 - trainer.processing_class = FakeTokenizer() trainer.vllm_generation = FakeVLLMGeneration() - trainer._apply_prompt_template = lambda prompts: prompts - prompt_ids, completion_ids = trainer._generate(["Solve 2+2.", "Solve 2+2."]) + prompt_ids, completion_ids = trainer._generate([[11, 12], [11, 12]]) assert prompt_ids == [[11, 12], [11, 12]] assert completion_ids == [[100], [101]] @@ -154,7 +144,7 @@ def generate(self, prompts, images, num_generations): ] trainer.model.training = False - eval_prompt_ids, eval_completion_ids = trainer._generate(["Check 3+3.", "Check 3+3.", "Check 3+3."]) + eval_prompt_ids, eval_completion_ids = trainer._generate([[21, 22], [21, 22], [21, 22]]) assert eval_prompt_ids == [[21, 22], [21, 22], [21, 22]] assert eval_completion_ids == [[100], [101], [102]] @@ -167,7 +157,7 @@ def generate(self, prompts, images, num_generations): trainer.model.training = True trainer.state.global_step = 5 - trainer._generate(["Solve 2+2.", "Solve 2+2."]) + trainer._generate([[11, 12], [11, 12]]) assert trainer.vllm_generation.sync_weights_call_count == 2 assert trainer._last_loaded_step == 5 @@ -181,8 +171,8 @@ def test_training(self): per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=8, # reduce the completion length to reduce memory usage + distillation_mode="topk_logits", distillation_topk=5, - full_logit_distillation=True, distillation_is_clip=None, ) trainer = SDPOTrainer( diff --git a/trl/experimental/sdft/sdft.py b/trl/experimental/sdft/sdft.py index 958a24fc683..6c9aa8d6179 100644 --- a/trl/experimental/sdft/sdft.py +++ b/trl/experimental/sdft/sdft.py @@ -46,10 +46,9 @@ --learning_rate 2e-5 \ --max_prompt_length 1024 \ --max_completion_length 512 \ - --generate_from_teacher \ - --sync_ref_model \ - --ref_model_sync_steps 1 \ - --ref_model_mixup_alpha 0.01 \ + --teacher_model_kind ema \ + --teacher_sync_steps 1 \ + --teacher_update_rate 0.01 \ --eval_strategy steps \ --eval_steps 50 \ --report_to wandb @@ -86,10 +85,6 @@ @dataclass class SDFTScriptArguments(ScriptArguments): - ref_model_name_or_path: str | None = field( - default=None, - metadata={"help": "Reference teacher model. Optional for PEFT runs, where the base model is used as teacher."}, - ) dataset_path: str | None = field( default=None, metadata={"help": "Optional local dataset path to load with `load_from_disk`. Overrides `dataset_name`."}, @@ -122,14 +117,6 @@ class SDFTScriptArguments(ScriptArguments): ) -@dataclass -class ExampleSDFTConfig(SDFTConfig): - scale_rewards: str = field( - default="group", - metadata={"help": "Reward normalization mode. Supported: `group`, `batch`, `none`."}, - ) - - def _extract_prompt_text(prompt: Any) -> str: if isinstance(prompt, str): return prompt @@ -323,14 +310,15 @@ def _run_tooluse_eval( if __name__ == "__main__": - parser = TrlParser((SDFTScriptArguments, ExampleSDFTConfig, ModelConfig)) + parser = TrlParser((SDFTScriptArguments, SDFTConfig, ModelConfig)) script_args, training_args, model_args = parser.parse_args_and_config() if model_args.model_name_or_path is None: raise ValueError("`model_name_or_path` is required.") - if script_args.ref_model_name_or_path is None and not model_args.use_peft: - script_args.ref_model_name_or_path = model_args.model_name_or_path - + if training_args.generate_from_teacher: + raise ValueError( + "`generate_from_teacher` is not yet supported by the transitioned SDFT trainer used in this script." + ) if model_args.dtype in ["auto", None]: if training_args.bf16: dtype = torch.bfloat16 @@ -385,16 +373,10 @@ def _run_tooluse_eval( eval_dataset = _prepare_split(raw_eval_dataset, script_args) model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) - ref_model = None - if script_args.ref_model_name_or_path is not None: - ref_model = AutoModelForCausalLM.from_pretrained(script_args.ref_model_name_or_path, **model_kwargs) model.config.use_cache = False if training_args.gradient_checkpointing else True - if ref_model is not None: - ref_model.config.use_cache = True trainer = SDFTTrainer( model=model, - ref_model=ref_model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, diff --git a/trl/experimental/sdft/sdft_config.py b/trl/experimental/sdft/sdft_config.py index 84227e43cbf..21ec6b2a814 100644 --- a/trl/experimental/sdft/sdft_config.py +++ b/trl/experimental/sdft/sdft_config.py @@ -20,14 +20,14 @@ @dataclass class SDFTConfig(SelfDistillationConfig): r""" - Configuration class for [`SDFTTrainer`]. - - This adapts the official SDFT implementation to the TRL trainer API while reusing the common self-distillation - configuration shared with SDPO. + Configuration class for [`SDFTTrainer`].. Parameters: disable_dropout (`bool`, *optional*, defaults to `True`): Whether to disable dropout in the student and teacher models. + teacher_model_kind (`str`, *optional*, defaults to `"base"`): + Semantic teacher choice for SDFT. `base` uses the initial student, `live` uses the current student, and + `ema` uses an exponentially averaged teacher. generate_from_teacher (`bool`, *optional*, defaults to `False`): Whether on-policy generation should use the teacher-conditioned prompt instead of the student prompt. teacher_prompt_template (`str`, *optional*, defaults to `"{prompt}\n\n{privileged_context}"`): @@ -40,6 +40,13 @@ class SDFTConfig(SelfDistillationConfig): default=True, metadata={"help": "Whether to disable dropout in the student and teacher models."}, ) + teacher_model_kind: str = field( + default="base", + metadata={ + "help": "Semantic teacher choice for SDFT. `base` uses the initial student, `live` uses the current " + "student, and `ema` uses an exponentially averaged teacher." + }, + ) generate_from_teacher: bool = field( default=False, metadata={"help": "Whether on-policy generation should use the teacher-conditioned prompt."}, diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index 5bf6095c2a0..24959a36bc0 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -14,66 +14,43 @@ from __future__ import annotations -import copy -import inspect import textwrap -from collections import defaultdict -from functools import partial from typing import Any -import datasets import torch from accelerate.logging import get_logger -from accelerate.utils import is_peft_model from datasets import Dataset, IterableDataset from torch import nn -from torch.utils.data import DataLoader, Sampler from transformers import ( - AutoProcessor, - GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, ) -from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_peft_available +from transformers.utils import is_peft_available -from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation -from ...trainer.base_trainer import _BaseTrainer -from ...trainer.callbacks import SyncRefModelCallback -from ...trainer.utils import ( - RepeatSampler, - create_model_from_path, - disable_dropout_in_model, - get_config_model_id, - identity, - pad, - split_tensor_dict, - use_adapter, +from ...trainer.utils import pad +from ..self_distillation.base_self_distillation_trainer import ( + BaseSelfDistillationTrainer, + RolloutBatch, + TrainingBatch, ) -from ..self_distillation.self_distillation_mixin import SelfDistillationMixin -from ..self_distillation.teacher_context import PromptTokenizer, extract_last_user_text -from ..utils import prepare_peft_model +from ..self_distillation.prompt_utils import extract_last_user_text from .sdft_config import SDFTConfig if is_peft_available(): from peft import PeftConfig - from peft.peft_model import PeftModel - - from ..self_distillation.peft_adapter_ema_callback import PEFTAdapterEMACallback logger = get_logger(__name__) class DemonstrationTeacherContextBuilder: - """Builds student and teacher contexts from prompts plus privileged context, as in SDFT.""" + """Builds student and teacher contexts from prompts plus privileged context""" def __init__(self, trainer): self.trainer = trainer - self.prompt_tokenizer = PromptTokenizer(trainer) def _stringify_privileged_context(self, privileged_context: Any) -> str: if privileged_context is None: @@ -122,23 +99,27 @@ def build( completion_ids: torch.Tensor, completion_mask: torch.Tensor, ) -> dict[str, torch.Tensor]: - student_batch = self.prompt_tokenizer.tokenize_prompts(prompts) teacher_prompts = [ self._compose_teacher_prompt(prompt, privileged_context) for prompt, privileged_context in zip(prompts, privileged_contexts, strict=True) ] - teacher_batch = self.prompt_tokenizer.tokenize_prompts(teacher_prompts) - teacher_input_ids = torch.cat([teacher_batch.prompt_ids, completion_ids], dim=1) - teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, completion_mask], dim=1) + teacher_prompt_ids_list = self.trainer._tokenize_prompts(teacher_prompts) + device = completion_ids.device + teacher_prompt_ids = [torch.tensor(ids) for ids in teacher_prompt_ids_list] + teacher_prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in teacher_prompt_ids] + teacher_prompt_ids = pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left").to( + device=device + ) + teacher_prompt_mask = pad(teacher_prompt_mask, padding_value=0, padding_side="left").to(device=device) + teacher_input_ids = torch.cat([teacher_prompt_ids, completion_ids], dim=1) + teacher_attention_mask = torch.cat([teacher_prompt_mask, completion_mask], dim=1) return { - "prompt_ids": student_batch.prompt_ids, - "prompt_mask": student_batch.prompt_mask, "teacher_input_ids": teacher_input_ids, "teacher_attention_mask": teacher_attention_mask, } -class SDFTTrainer(SelfDistillationMixin, _BaseTrainer): +class SDFTTrainer(BaseSelfDistillationTrainer): """Trainer for SDFT-style on-policy self-distillation with explicit teacher prompts.""" _tag_names = ["trl", "sdft"] @@ -168,319 +149,71 @@ def __init__( optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), peft_config: PeftConfig | None = None, ): - if train_dataset is None: - raise ValueError("`train_dataset` is required") if isinstance(train_dataset, IterableDataset): raise NotImplementedError("Iterable datasets are not yet supported in SDFTTrainer.") if isinstance(eval_dataset, IterableDataset) or ( isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) ): raise NotImplementedError("Iterable eval datasets are not yet supported in SDFTTrainer.") - if args.use_vllm: - raise NotImplementedError("SDFTTrainer does not support `use_vllm=True` yet.") - if isinstance(model, str): - model_init_kwargs = args.model_init_kwargs or {} - if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: - model_init_kwargs["device_map"] = None - model = create_model_from_path(model, **model_init_kwargs) - elif args.model_init_kwargs is not None: - logger.warning( - "You passed `model_init_kwargs` to `SDFTConfig`, but `model` is already instantiated. " - "The `model_init_kwargs` will be ignored." - ) - self.model_kwarg_keys = ( - inspect.signature(model.forward).parameters.keys() - if not hasattr(model, "get_base_model") - else inspect.signature(model.get_base_model().forward).parameters.keys() - ) - - if is_peft_available() and is_peft_model(model) and peft_config is not None: - raise ValueError( - "You passed a `PeftModel` instance together with a `peft_config` to SDFTTrainer. Pass either a base " - "model with `peft_config`, or a pre-wrapped PEFT model." - ) - if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): - model = prepare_peft_model(model, peft_config, args) - - if processing_class is None: - processing_class = AutoProcessor.from_pretrained( - get_config_model_id(model.config), truncation_side="left", padding_side="left" - ) - - if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer - elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class - else: - raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id - self.max_prompt_length = args.max_prompt_length - self.max_completion_length = args.max_completion_length - self.num_generations = args.num_generations - self.num_iterations = args.num_iterations - self.temperature = args.temperature - self.loss_type = args.loss_type - self.shuffle_dataset = args.shuffle_dataset - self.generate_from_teacher = args.generate_from_teacher self.num_loss_tokens_to_skip = args.num_loss_tokens_to_skip - self.chat_template_kwargs = args.chat_template_kwargs or {} - self._step = 0 - self._buffered_inputs = None - self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} - self.prompt_tokenizer = PromptTokenizer(self) self.teacher_context_builder = DemonstrationTeacherContextBuilder(self) - generation_kwargs = { - "max_new_tokens": self.max_completion_length, - "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, - "temperature": args.temperature, - "top_p": args.top_p, - "top_k": args.top_k, - "min_p": args.min_p, - "repetition_penalty": args.repetition_penalty, - "cache_implementation": args.cache_implementation, - } - if args.generation_kwargs is not None: - generation_kwargs.update(args.generation_kwargs) - self.generation_config = GenerationConfig(**generation_kwargs, disable_compile=True) - - if hasattr(model, "warnings_issued"): - model.warnings_issued["estimate_tokens"] = True - super().__init__( model=model, args=args, - data_collator=identity, train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, callbacks=callbacks, optimizers=optimizers, - compute_loss_func="non-None value to disable scaling", + peft_config=peft_config, ) - if args.disable_dropout: - disable_dropout_in_model(self.model) - - self.model.add_model_tags(self._tag_names) - - # In self-distillation the teacher is always derived from the student: - # - PEFT: base model with adapter disabled (or EMA teacher adapter when sync_ref_model=True) - # - Non-PEFT: same model (or deep-copied EMA model when sync_ref_model=True) - self.teacher_model = None - - if args.sync_ref_model: - if is_peft_available() and is_peft_model(self.model): - self.add_callback( - PEFTAdapterEMACallback( - model=self.model, - teacher_adapter_name="teacher", - update_rate=args.ref_model_mixup_alpha, - sync_steps=args.ref_model_sync_steps, - accelerator=self.accelerator, - ) - ) - else: - student_model = self.accelerator.unwrap_model(self.model) - self.teacher_model = copy.deepcopy(student_model) - self.teacher_model.requires_grad_(False) - self.teacher_model.eval() - if self.is_deepspeed_enabled: - self.teacher_model = prepare_deepspeed(self.teacher_model, self.accelerator) - elif self.is_fsdp_enabled: - self.teacher_model = prepare_fsdp(self.teacher_model, self.accelerator) - else: - self.teacher_model = self.accelerator.prepare_model(self.teacher_model, evaluation_mode=True) - self.add_callback(SyncRefModelCallback(ref_model=self.teacher_model, accelerator=self.accelerator)) - - self.model_accepts_loss_kwargs = False - - def get_train_dataloader(self): - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - - train_dataset = self.train_dataset - data_collator = self.data_collator - if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): - train_dataset = self._remove_unused_columns(train_dataset, description="training") - else: - data_collator = self._get_collator_with_removed_columns(data_collator, description="training") - - dataloader_params = { - "batch_size": self._train_batch_size * self.args.steps_per_generation, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - } - if not isinstance(train_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_train_sampler() - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = partial( - seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index - ) - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) - - def _get_train_sampler(self, dataset=None) -> Sampler: - if dataset is None: - dataset = self.train_dataset - return RepeatSampler( - data_source=dataset, - mini_repeat_count=self.num_generations, - batch_size=self.args.generation_batch_size // self.num_generations, - repeat_count=self.num_iterations * self.args.steps_per_generation, - shuffle=self.shuffle_dataset, - seed=self.args.seed, - ) - - def _get_eval_sampler(self, eval_dataset) -> Sampler: - return RepeatSampler( - data_source=eval_dataset, - mini_repeat_count=self.num_generations, - seed=self.args.seed, - ) - - def training_step(self, model, inputs, num_items_in_batch): - output = super().training_step(model, inputs, num_items_in_batch) - self._step += 1 - return output - - def _prepare_inputs(self, generation_batch): - mode = "train" if self.model.training else "eval" - if mode == "train": - generate_every = self.args.steps_per_generation * self.num_iterations - if self._step % generate_every == 0 or self._buffered_inputs is None: - generation_batch = self._build_buffered_batch(generation_batch) - self._buffered_inputs = split_tensor_dict(generation_batch, self.args.steps_per_generation) - self._dispatch_self_distillation_callback( - "on_generation_batch_built", - generate_every=generate_every, - steps_per_generation=self.args.steps_per_generation, - ) - return self._buffered_inputs[self._step % self.args.steps_per_generation] - return self._build_buffered_batch(generation_batch) - - def _generate_completion_ids(self, prompts: list[Any]) -> tuple[torch.Tensor, torch.Tensor]: - generate_inputs = self.processing_class( - text=self.prompt_tokenizer.apply_prompt_template(prompts), - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - # This generation helper builds tokenized model inputs directly, so use the base Trainer tensor preparation - # instead of re-entering the buffered outer training hook. - generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) - - with ( - unwrap_model_for_generation( - self.model_wrapped, - self.accelerator, - gather_deepspeed3_params=self.args.ds3_gather_for_generation, - ) as unwrapped_model, - torch.no_grad(), - ): - prompt_completion_ids = unwrapped_model.generate( - **generate_inputs, generation_config=self.generation_config - ) - - prompt_length = generate_inputs["input_ids"].size(1) - completion_ids = prompt_completion_ids[:, prompt_length:] - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) - completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).long() - - completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=True)] - completion_ids = [torch.tensor(ids, device=self.accelerator.device) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] - return ( - pad(completion_ids, padding_value=self.pad_token_id, padding_side="right"), - pad(completion_mask, padding_value=0, padding_side="right"), - ) - - def _build_buffered_batch(self, inputs: list[dict[str, Any]]) -> dict[str, torch.Tensor | Any]: + def finalize_batch( + self, + inputs: list[dict[str, Any]], + rollout_batch: RolloutBatch, + ) -> TrainingBatch: prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) - generation_prompts = self.teacher_context_builder.select_generation_prompts(prompts, privileged_contexts) - generation_prompt_text = self.prompt_tokenizer.apply_prompt_template(generation_prompts) - self._dispatch_self_distillation_callback( - "on_generation_prompts_selected", - generation_prompts=generation_prompts, - generation_prompt_text=generation_prompt_text, - ) - completion_ids, completion_mask = self._generate_completion_ids(generation_prompts) - teacher_batch = self.teacher_context_builder.build( - prompts, privileged_contexts, completion_ids, completion_mask + prompts, + privileged_contexts, + rollout_batch.completion_ids, + rollout_batch.completion_mask, ) - prompt_completion_ids = torch.cat([teacher_batch["prompt_ids"], completion_ids], dim=1) - attention_mask = torch.cat([teacher_batch["prompt_mask"], completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) - - with torch.no_grad(): - generate_every = self.args.steps_per_generation * self.num_iterations - if not self.generate_from_teacher and self.args.gradient_accumulation_steps % generate_every != 0: - old_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ) - else: - old_per_token_logps = None - - self._dispatch_self_distillation_callback( - "on_self_distillation_batch_prepared", - old_per_token_logps=old_per_token_logps, - prompt_ids=teacher_batch["prompt_ids"], - completion_ids=completion_ids, + batch = rollout_batch.as_dict() + batch.update( + { + "teacher_input_ids": teacher_batch["teacher_input_ids"], + "teacher_attention_mask": teacher_batch["teacher_attention_mask"], + } ) - output = { - "prompt_ids": teacher_batch["prompt_ids"], - "prompt_mask": teacher_batch["prompt_mask"], - "completion_ids": completion_ids, - "completion_mask": completion_mask, - "teacher_input_ids": teacher_batch["teacher_input_ids"], - "teacher_attention_mask": teacher_batch["teacher_attention_mask"], - } - if old_per_token_logps is not None: - output["old_per_token_logps"] = old_per_token_logps - return output + return batch def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: raise ValueError("The SDFTTrainer does not support returning outputs") - if self.num_loss_tokens_to_skip > 0: - inputs = dict(inputs) - completion_mask = inputs["completion_mask"].clone() - token_positions = torch.arange(completion_mask.size(1), device=completion_mask.device).unsqueeze(0) - completion_mask = completion_mask * (token_positions >= self.num_loss_tokens_to_skip).long() - inputs["completion_mask"] = completion_mask - - loss = self._compute_self_distillation_loss(model, inputs) + distillation_logits = self._compute_teacher_student_logits(model, self.teacher_model, inputs) + loss = self._compute_self_distillation_loss(model, inputs, distillation_logits) accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 return loss / accumulation_scale - def _get_teacher_context_for_self_distillation(self, model): - if is_peft_available() and isinstance(self.model, PeftModel): - model = self.accelerator.unwrap_model(self.model) - if self.args.sync_ref_model and "teacher" in model.peft_config: - return use_adapter(model, adapter_name="teacher") - return use_adapter(model, adapter_name=None) - return super()._get_teacher_context_for_self_distillation(model) + def _build_self_distillation_response_mask( + self, + completion_mask: torch.Tensor, + self_distillation_mask: torch.Tensor | None, + ) -> torch.Tensor: + response_mask = BaseSelfDistillationTrainer._build_self_distillation_response_mask( + completion_mask, + self_distillation_mask, + ) + if self.num_loss_tokens_to_skip <= 0: + return response_mask + + # SDFT skips the first few completion tokens only in the distillation loss to suppress teacher-prompt artifacts. + token_positions = torch.arange(response_mask.size(1), device=response_mask.device).unsqueeze(0) + skip_mask = (token_positions >= self.num_loss_tokens_to_skip).long() + return response_mask * skip_mask diff --git a/trl/experimental/sdpo/sdpo.py b/trl/experimental/sdpo/sdpo.py index 7a65a00cc5d..cd752cc6855 100644 --- a/trl/experimental/sdpo/sdpo.py +++ b/trl/experimental/sdpo/sdpo.py @@ -43,7 +43,7 @@ --num_generations 8 \ --generation_batch_size 32 \ --distillation_alpha 1.0 \ - --full_logit_distillation false \ + --distillation_mode sampled_token \ --sdpo_policy_loss_mode hybrid \ --report_to none \ --eval_strategy steps \ diff --git a/trl/experimental/sdpo/sdpo_config.py b/trl/experimental/sdpo/sdpo_config.py index 1cc8c1510b7..92c4a181f39 100644 --- a/trl/experimental/sdpo/sdpo_config.py +++ b/trl/experimental/sdpo/sdpo_config.py @@ -13,6 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field +from typing import Literal from ..self_distillation import SelfDistillationConfig @@ -22,29 +23,46 @@ class SDPOConfig(SelfDistillationConfig): r""" Configuration class for the [`SDPOTrainer`]. - This class extends [`experimental.self_distillation.SelfDistillationConfig`] with the online teacher-construction - parameters used by Self-Distillation Policy Optimization (SDPO). - Parameters: + > Parameters that control the online policy objective + + beta (`float`, *optional*, defaults to `0.0`): + Reference-model KL coefficient for online policy optimization. + epsilon (`float`, *optional*, defaults to `0.2`): + Lower clipping coefficient for GRPO-style policy loss. + epsilon_high (`float` or `None`, *optional*): + Upper clipping coefficient. Defaults to `epsilon` when unset. + importance_sampling_level (`str`, *optional*, defaults to `"token"`): + Importance-sampling granularity. Supported: `token`, `sequence`. + reward_weights (`list[float]` or `None`, *optional*): + Optional weights for multiple reward functions. + scale_rewards (`str` or `bool`, *optional*, defaults to `"group"`): + Reward normalization mode. Supported: `group`, `batch`, `none`. + > Parameters that control the SDPO loss sdpo_policy_loss_mode (`str`, *optional*, defaults to `"distillation_only"`): How SDPO combines the online policy loss and self-distillation loss. Supported: `distillation_only`, - `hybrid`. + `policy_only`, `hybrid`. distillation_alpha (`float`, *optional*, defaults to `1.0`): - Divergence interpolation coefficient. Token-level SDPO requires the official reverse-KL setting + Divergence interpolation coefficient. Sampled-token SDPO requires the official reverse-KL setting + `distillation_alpha=1.0`. + distillation_mode (`Literal["sampled_token", "full_logits", "topk_logits"]`, *optional*, defaults to `"sampled_token"`): + Distillation objective mode. `"sampled_token"` is the default SDPO mode and requires `distillation_alpha=1.0`. distillation_topk (`int` or `None`, *optional*): - Top-k approximation for logit-level SDPO. Requires `full_logit_distillation=True`. + Top-k approximation for logit-level SDPO. Must be set when `distillation_mode="topk_logits"` and left + unset otherwise. > Parameters that control the teacher - teacher_regularization (`str`, *optional*, defaults to `"ema"`): - Teacher update strategy. Supported: `ema`, `none`. - teacher_update_rate (`float` or `None`, *optional*): - EMA update rate used when `teacher_regularization="ema"`. - ema_update_rate (`float`, *optional*, defaults to `0.05`): - Deprecated alias for `teacher_update_rate`. + teacher_model_kind (`str`, *optional*, defaults to `"ema"`): + Semantic teacher choice. `base` uses the initial student, `live` uses the current student, and `ema` + uses an exponentially averaged teacher. + teacher_update_rate (`float`, *optional*, defaults to `0.05`): + EMA update rate used when `teacher_model_kind="ema"`. + teacher_sync_steps (`int`, *optional*, defaults to `1`): + Number of optimizer steps between teacher EMA updates. > Parameters that control reprompting @@ -60,31 +78,69 @@ class SDPOConfig(SelfDistillationConfig): default=True, metadata={"help": "Skip reprompting when model generates correct response."}, ) + beta: float = field( + default=0.0, + metadata={"help": "Reference-model KL coefficient for online policy optimization."}, + ) + epsilon: float = field( + default=0.2, + metadata={"help": "Lower clipping coefficient for GRPO-style policy loss."}, + ) + epsilon_high: float | None = field( + default=None, + metadata={"help": "Upper clipping coefficient. Defaults to `epsilon` when unset."}, + ) + importance_sampling_level: str = field( + default="token", + metadata={"help": "Importance-sampling granularity. Supported: `token`, `sequence`."}, + ) + reward_weights: list[float] | None = field( + default=None, + metadata={"help": "Optional weights for multiple reward functions."}, + ) + scale_rewards: str | bool = field( + default="group", + metadata={"help": "Reward normalization mode. Supported: `group`, `batch`, `none`."}, + ) distillation_alpha: float = field( default=1.0, metadata={ - "help": "KL divergence direction for SDPO. Token-level SDPO requires reverse KL (`distillation_alpha=1.0`)." + "help": "Divergence interpolation coefficient. Sampled-token SDPO requires the official reverse-KL setting " + "`distillation_alpha=1.0`." + }, + ) + distillation_mode: Literal["sampled_token", "full_logits", "topk_logits"] = field( + default="sampled_token", + metadata={ + "help": "Distillation objective mode. `sampled_token` is the default SDPO mode and requires " + "`distillation_alpha=1.0`." }, ) distillation_topk: int | None = field( default=None, - metadata={"help": "Top-K approximation for logit-level SDPO. Requires `full_logit_distillation=True`."}, + metadata={ + "help": "Top-k approximation for logit-level SDPO. Must be set when `distillation_mode=topk_logits` and left " + "unset otherwise." + }, ) sdpo_policy_loss_mode: str = field( default="distillation_only", - metadata={"help": "SDPO policy loss mode. Supported: `distillation_only`, `hybrid`."}, + metadata={"help": "SDPO policy loss mode. Supported: `distillation_only`, `policy_only`, `hybrid`."}, ) - teacher_regularization: str = field( + teacher_model_kind: str = field( default="ema", - metadata={"help": "Teacher regularization mode. Supported: `ema`, `none`."}, + metadata={ + "help": "Semantic teacher choice. `base` uses the initial student, `live` uses the current student, " + "and `ema` uses an exponentially averaged teacher." + }, ) - teacher_update_rate: float | None = field( - default=None, + teacher_update_rate: float = field( + default=0.05, metadata={"help": "Teacher update rate used for EMA teacher synchronization."}, ) - ema_update_rate: float = field( - default=0.05, - metadata={"help": "Deprecated alias for `teacher_update_rate`."}, + teacher_sync_steps: int = field( + default=1, + metadata={"help": "How often to synchronize the EMA teacher model."}, ) max_reprompt_len: int = field( default=10240, @@ -126,23 +182,26 @@ class SDPOConfig(SelfDistillationConfig): def __post_init__(self): super().__post_init__() - if self.teacher_update_rate is None: - self.teacher_update_rate = self.ema_update_rate + self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards) + if self.scale_rewards not in ["group", "batch", "none"]: + raise ValueError("scale_rewards must be one of: 'group', 'batch', 'none'") + + if self.importance_sampling_level not in ["token", "sequence"]: + raise ValueError("importance_sampling_level must be either 'token' or 'sequence'") + + if self.epsilon_high is None: + self.epsilon_high = self.epsilon - if self.teacher_regularization not in {"ema", "none"}: - raise ValueError("teacher_regularization must be one of: 'ema', 'none'") - if not 0.0 <= self.teacher_update_rate <= 1.0: - raise ValueError("teacher_update_rate must be in [0, 1]") - if self.sdpo_policy_loss_mode not in {"distillation_only", "hybrid"}: - raise ValueError("sdpo_policy_loss_mode must be one of: 'distillation_only', 'hybrid'") + if self.sdpo_policy_loss_mode not in {"distillation_only", "policy_only", "hybrid"}: + raise ValueError("sdpo_policy_loss_mode must be one of: 'distillation_only', 'policy_only', 'hybrid'") if self.sdpo_policy_loss_mode == "distillation_only" and self.distillation_weight <= 0: raise ValueError("distillation_only mode requires `distillation_weight > 0`.") + if self.sdpo_policy_loss_mode == "hybrid" and self.distillation_weight <= 0: + raise ValueError("hybrid mode requires `distillation_weight > 0`.") if self.max_reprompt_len <= 0: raise ValueError("max_reprompt_len must be positive") - if not self.full_logit_distillation and self.distillation_alpha != 1.0: + if self.distillation_mode == "sampled_token" and self.distillation_alpha != 1.0: raise ValueError( - "SDPO token-level distillation requires `distillation_alpha=1.0`. " - "Set `full_logit_distillation=True` to use other divergence settings." + "SDPO sampled-token distillation requires `distillation_alpha=1.0`. " + "Set `distillation_mode='full_logits'` or `distillation_mode='topk_logits'` to use other divergence settings." ) - if self.distillation_topk is not None and not self.full_logit_distillation: - raise ValueError("SDPO `distillation_topk` requires `full_logit_distillation=True`.") diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py index ef84a17a44c..b0d3afc4e08 100644 --- a/trl/experimental/sdpo/sdpo_trainer.py +++ b/trl/experimental/sdpo/sdpo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import re import textwrap from typing import Any @@ -21,31 +20,36 @@ from accelerate.utils import gather_object from datasets import Dataset, IterableDataset from torch import nn -from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback - -from ...trainer.callbacks import SyncRefModelCallback -from ...trainer.utils import pad -from ..self_distillation.base_self_distillation_trainer import BaseSelfDistillationTrainer -from ..self_distillation.teacher_context import TokenizedPromptBatch, extract_last_user_text +from transformers import ( + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.utils import logging + +from ...data_utils import apply_chat_template, is_conversational +from ...models import prepare_deepspeed, prepare_fsdp +from ...trainer.base_trainer import _BaseTrainer +from ...trainer.utils import get_config_model_id, pad +from ..self_distillation.base_self_distillation_trainer import ( + BaseSelfDistillationTrainer, + DistillationLogits, + RolloutBatch, + TrainingBatch, +) +from ..self_distillation.loss_utils import aggregate_loss, select_token_log_probs +from ..self_distillation.prompt_utils import extract_last_user_text from .sdpo_config import SDPOConfig -class EMATeacherSyncCallback(SyncRefModelCallback): - """Synchronize an EMA teacher model with the student model on each step.""" - - def __init__(self, teacher_model, update_rate: float, accelerator=None): - super().__init__(ref_model=teacher_model, accelerator=accelerator) - self.update_rate = update_rate - - def on_step_end(self, args, state, control, **kwargs): - model = kwargs["model"] - if self.accelerator is not None: - model = self.accelerator.unwrap_model(model) - self.sync_target_model(model, self.ref_model, self.update_rate) +logger = logging.get_logger(__name__) class SuccessfulRolloutTeacherContextBuilder: - """Builds SDPO teacher contexts from successful rollouts, following the official online implementation.""" + """Builds teacher contexts from successful rollouts""" def __init__(self, trainer): self.trainer = trainer @@ -60,36 +64,18 @@ def _build_reprompt_text(self, prompt_text: str, solution_text: str, feedback_te def _tokenize_teacher_messages( self, teacher_messages_list: list[str | list[dict[str, Any]]] - ) -> TokenizedPromptBatch: - teacher_prompt_ids_list = [] + ) -> dict[str, torch.Tensor]: device = self.trainer.accelerator.device - chat_template_kwargs = getattr(self.trainer, "chat_template_kwargs", {}) - for msg in teacher_messages_list: - if isinstance(msg, list) and isinstance(msg[0], dict): - tokenized = self.trainer.processing_class.apply_chat_template( - msg, - tokenize=True, - add_generation_prompt=True, - return_tensors="pt", - **chat_template_kwargs, - ) - if isinstance(tokenized, torch.Tensor): - ids = tokenized.squeeze(0) - else: - ids = tokenized["input_ids"].squeeze(0) - else: - ids = self.trainer.processing_class.encode(msg, return_tensors="pt").squeeze(0) - - if ids.shape[0] > self.trainer.args.max_reprompt_len: - ids = ids[-self.trainer.args.max_reprompt_len :] - teacher_prompt_ids_list.append(ids) - - teacher_prompt_ids = [ids.to(device) for ids in teacher_prompt_ids_list] + teacher_prompt_ids_list = self.trainer._tokenize_prompts_untruncated(teacher_messages_list) + teacher_prompt_ids = [ + torch.as_tensor(ids[-self.trainer.args.max_reprompt_len :], device=device) + for ids in teacher_prompt_ids_list + ] teacher_prompt_mask = [torch.ones(len(ids), dtype=torch.long, device=device) for ids in teacher_prompt_ids] - return TokenizedPromptBatch( - prompt_ids=pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), - prompt_mask=pad(teacher_prompt_mask, padding_value=0, padding_side="left"), - ) + return { + "prompt_ids": pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), + "prompt_mask": pad(teacher_prompt_mask, padding_value=0, padding_side="left"), + } def build( self, @@ -214,8 +200,8 @@ def build( local_teacher_messages.append(self._build_reprompt_text(original_prompt, solution_text, feedback_text)) teacher_batch = self._tokenize_teacher_messages(local_teacher_messages) - teacher_input_ids = torch.cat([teacher_batch.prompt_ids, completion_ids], dim=1) - teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, completion_mask], dim=1) + teacher_input_ids = torch.cat([teacher_batch["prompt_ids"], completion_ids], dim=1) + teacher_attention_mask = torch.cat([teacher_batch["prompt_mask"], completion_mask], dim=1) batch_size = total_samples if total_samples > 0 else 1 num_groups = max(1, total_samples // max(1, num_generations)) @@ -277,44 +263,193 @@ def __init__( raise ValueError("`reward_funcs` is required for SDPOTrainer because SDPO must score rollouts.") super().__init__( model=model, - reward_funcs=reward_funcs, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, - reward_processing_classes=reward_processing_classes, callbacks=callbacks, optimizers=optimizers, peft_config=peft_config, ) + self.importance_sampling_level = args.importance_sampling_level + self.scale_rewards = args.scale_rewards + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high + self.beta = args.beta + + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + self.reward_func_names = [] + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + reward_model_init_kwargs = args.model_init_kwargs or {} + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + reward_model_init_kwargs["device_map"] = None + reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( + reward_func, + num_labels=1, + **reward_model_init_kwargs, + ) + if isinstance(reward_funcs[i], nn.Module): + self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) + else: + self.reward_func_names.append(reward_funcs[i].__name__) + self.reward_funcs = reward_funcs + + if args.reward_weights is not None: + if len(args.reward_weights) != len(self.reward_funcs): + raise ValueError("Number of reward weights must match number of reward functions") + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(self.reward_funcs), dtype=torch.float32) + + if reward_processing_classes is None: + reward_processing_classes = [None] * len(self.reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + if len(reward_processing_classes) != len(self.reward_funcs): + raise ValueError("Number of reward processing classes must match number of reward functions") + + for i, (reward_processing_class, reward_func) in enumerate( + zip(reward_processing_classes, self.reward_funcs, strict=True) + ): + if isinstance(reward_func, PreTrainedModel): + if reward_processing_class is None: + reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config)) + if reward_processing_class.pad_token_id is None: + reward_processing_class.pad_token = reward_processing_class.eos_token + reward_func.config.pad_token_id = reward_processing_class.pad_token_id + reward_processing_classes[i] = reward_processing_class + self.reward_processing_classes = reward_processing_classes + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, nn.Module): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + elif self.is_fsdp_enabled: + self.reward_funcs[i] = prepare_fsdp(reward_func, self.accelerator) + else: + self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) + self.teacher_context_builder = SuccessfulRolloutTeacherContextBuilder(self) - if self.args.teacher_regularization == "ema": - # `self.model` may already be accelerator-wrapped after the shared base constructor. Build the EMA - # teacher from the unwrapped student model first, then prepare it as an auxiliary eval-only module. - student_model = self.accelerator.unwrap_model(self.model) - self.teacher_model = copy.deepcopy(student_model) - self.teacher_model.requires_grad_(False) - self.teacher_model.eval() - self.teacher_model = self._prepare_auxiliary_model_for_eval(self.teacher_model) - self.add_callback( - EMATeacherSyncCallback( - teacher_model=self.teacher_model, - update_rate=self.args.teacher_update_rate, - accelerator=self.accelerator, + + def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): + device = self.accelerator.device + if len(self.reward_funcs) == 0: + return torch.zeros((len(prompts), 0), device=device) + + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + reward_kwargs["trainer_state"] = self.state + + for i, (reward_func, reward_processing_class) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes, strict=True) + ): + if isinstance(reward_func, nn.Module): + if is_conversational(inputs[0]): + messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)] + texts = [ + apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"] + for x in messages + ] + else: + texts = [p + c for p, c in zip(prompts, completions, strict=True)] + reward_inputs = reward_processing_class( + text=texts, + return_tensors="pt", + padding=True, + padding_side="right", + add_special_tokens=False, ) - ) + reward_inputs = _BaseTrainer._prepare_inputs(self, reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] + else: + output_reward_func = reward_func( + prompts=prompts, + completions=completions, + completion_ids=completion_ids_list, + **reward_kwargs, + ) + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) - def _allow_topk_without_full_logit_distillation(self) -> bool: - return False + return self.accelerator.gather(rewards_per_func) - def _generate_and_score_completions( - self, inputs: list[dict[str, torch.Tensor | Any]] - ) -> dict[str, torch.Tensor | Any]: + def finalize_batch( + self, + inputs: list[dict[str, Any]], + rollout_batch: RolloutBatch, + ) -> TrainingBatch: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) + raw_completion_lengths = rollout_batch.raw_completion_lengths.detach().cpu().tolist() + completion_ids_list = [ + ids[:length].tolist() + for ids, length in zip(rollout_batch.completion_ids.detach().cpu(), raw_completion_lengths, strict=True) + ] + if is_conversational({"prompt": prompts[0]}): + completions_text = self.processing_class.batch_decode( + rollout_batch.completion_ids, skip_special_tokens=True + ) + completions = [[{"role": "assistant", "content": content}] for content in completions_text] + else: + completions = self.processing_class.batch_decode(rollout_batch.completion_ids, skip_special_tokens=True) - output = super()._generate_and_score_completions(inputs) - output.update( - self.teacher_context_builder.build(output, prompts, output["rewards"], feedbacks=privileged_contexts) + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + if rewards_per_func.numel() == 0: + rewards = torch.zeros(self.accelerator.num_processes * len(prompts), device=device) + else: + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0) + if self.scale_rewards == "batch": + std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) + group_std_rewards = rewards.view(-1, num_generations).std(dim=1) + elif self.scale_rewards == "none": + std_rewards = torch.ones_like(rewards) + group_std_rewards = torch.ones(rewards.numel() // num_generations, device=device, dtype=rewards.dtype) + else: + group_std_rewards = rewards.view(-1, num_generations).std(dim=1) + std_rewards = group_std_rewards.repeat_interleave(num_generations, dim=0) + advantages = (rewards - mean_grouped_rewards) / (std_rewards + 1e-4) + self._record_reward_diagnostics(mode, rewards, rewards_per_func, group_std_rewards) + + local_batch_size = rollout_batch.completion_ids.size(0) + process_start = self.accelerator.process_index * local_batch_size + process_slice = slice(process_start, process_start + local_batch_size) + local_rewards = rewards[process_slice] + local_advantages = advantages[process_slice] + + agg_completion_lengths = self.accelerator.gather( + torch.tensor([len(ids) for ids in completion_ids_list], device=device) + ) + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] + if len(term_completion_lengths) == 0: + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + rollout_dict = rollout_batch.as_dict() + rollout_dict["rewards"] = local_rewards + rollout_dict["advantages"] = local_advantages + teacher_context = self.teacher_context_builder.build( + rollout_dict, + prompts, + rollout_dict["rewards"], + feedbacks=privileged_contexts, ) mode = "train" if self.model.training else "eval" @@ -324,13 +459,23 @@ def _generate_and_score_completions( self._dispatch_self_distillation_callback( "on_teacher_context_built", - teacher_input_ids=output["teacher_input_ids"], - teacher_attention_mask=output["teacher_attention_mask"], - completion_mask=output["completion_mask"], - self_distillation_mask=output["self_distillation_mask"], + teacher_input_ids=teacher_context["teacher_input_ids"], + teacher_attention_mask=teacher_context["teacher_attention_mask"], + completion_mask=rollout_batch.completion_mask, + self_distillation_mask=teacher_context["self_distillation_mask"], ) - return output + batch = rollout_batch.as_dict() + batch.update( + { + "teacher_input_ids": teacher_context["teacher_input_ids"], + "teacher_attention_mask": teacher_context["teacher_attention_mask"], + "self_distillation_mask": teacher_context["self_distillation_mask"], + "rewards": local_rewards, + "advantages": local_advantages, + } + ) + return batch def _warn_on_inactive_self_distillation(self, mode: str) -> None: metrics = self.teacher_context_builder.last_metrics @@ -365,23 +510,154 @@ def _warn_on_inactive_self_distillation(self, mode: str) -> None: else: self._diagnostic_counters[mode]["no_successful_rollouts"] = 0 - def _compute_loss( + def _record_reward_diagnostics( + self, + mode: str, + rewards: torch.Tensor, + rewards_per_func: torch.Tensor, + group_std_rewards: torch.Tensor, + ) -> None: + tolerance = self.args.diagnostics_flat_tolerance + + reward_mean = rewards.mean() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + reward_std = rewards.std() if rewards.numel() > 1 else torch.tensor(0.0, device=self.accelerator.device) + reward_min = rewards.min() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + reward_max = rewards.max() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + flat_group_fraction = ( + (group_std_rewards <= tolerance).float().mean() + if group_std_rewards.numel() > 0 + else torch.tensor(1.0, device=self.accelerator.device) + ) + + self._metrics[mode]["self_distillation/reward_mean"].append(self.accelerator.gather(reward_mean).mean().item()) + self._metrics[mode]["self_distillation/reward_std"].append(self.accelerator.gather(reward_std).mean().item()) + self._metrics[mode]["self_distillation/reward_min"].append(self.accelerator.gather(reward_min).min().item()) + self._metrics[mode]["self_distillation/reward_max"].append(self.accelerator.gather(reward_max).max().item()) + self._metrics[mode]["self_distillation/group_reward_std_mean"].append( + self.accelerator.gather(group_std_rewards.mean() if group_std_rewards.numel() > 0 else reward_std) + .mean() + .item() + ) + self._metrics[mode]["self_distillation/flat_group_fraction"].append( + self.accelerator.gather(flat_group_fraction).mean().item() + ) + + if rewards_per_func.numel() > 0: + reward_func_means = rewards_per_func.nanmean(dim=0) + gathered_means = self.accelerator.gather(reward_func_means).view(-1, reward_func_means.numel()).mean(dim=0) + for reward_name, reward_func_mean in zip(self.reward_func_names, gathered_means.tolist(), strict=True): + self._metrics[mode][f"self_distillation/rewards/{reward_name}"].append(reward_func_mean) + + reward_is_flat = reward_std.item() <= tolerance + grouped_rewards_are_flat = flat_group_fraction.item() >= 1.0 - tolerance + if reward_is_flat and grouped_rewards_are_flat: + self._warn_on_degenerate_diagnostics( + mode=mode, + counter_key="flat_rewards", + message=( + "Observed flat SDPO rewards across all sampled generations. " + "Policy advantages will collapse to zero, and SDPO will not learn. " + "Check reward density, reward shaping, or `success_reward_threshold`." + ), + ) + else: + self._diagnostic_counters[mode]["flat_rewards"] = 0 + + def _warn_on_degenerate_diagnostics(self, mode: str, counter_key: str, message: str) -> None: + interval = self.args.diagnostics_warning_interval + if interval == 0: + return + + self._diagnostic_counters[mode][counter_key] += 1 + count = self._diagnostic_counters[mode][counter_key] + if count == 1 or count % interval == 0: + logger.warning("%s Consecutive degenerate steps: %s.", message, count) + + def _compute_policy_loss( + self, + inputs, + student_logits, + ) -> torch.Tensor: + completion_ids = inputs["completion_ids"] + completion_mask = inputs["completion_mask"] + per_token_logps = select_token_log_probs(student_logits, completion_ids) + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + advantages = inputs["advantages"] + if advantages.dim() == 1: + advantages = advantages.unsqueeze(1) + log_ratio = per_token_logps - old_per_token_logps + if self.importance_sampling_level == "sequence": + log_ratio = (log_ratio * completion_mask).sum(-1, keepdim=True) / completion_mask.sum( + -1, keepdim=True + ).clamp(min=1.0) + coef_1 = torch.exp(log_ratio) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + per_token_loss = -torch.min(coef_1 * advantages, coef_2 * advantages) + + loss = aggregate_loss( + per_token_loss, + completion_mask, + loss_type=self.loss_type, + max_completion_length=self.max_completion_length, + ) + + mode = "train" if self.model.training else "eval" + self._metrics[mode]["self_distillation/policy_loss"].append( + self.accelerator.gather(loss.detach()).mean().item() + ) + + accumulation_scale = self.current_gradient_accumulation_steps if mode == "train" else 1.0 + return loss / accumulation_scale + + def _compute_weighted_self_distillation_loss( self, model, inputs, + distillation_logits: DistillationLogits, ) -> torch.Tensor: accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 + distillation_loss = ( + self._compute_self_distillation_loss( + model, + inputs, + distillation_logits, + ) + / accumulation_scale + ) + return self.args.distillation_weight * distillation_loss + + def _compute_hybrid_loss(self, model, inputs) -> torch.Tensor: + distillation_logits = self._compute_teacher_student_logits(model, self.teacher_model, inputs) + policy_loss = self._compute_policy_loss(inputs, distillation_logits.student_logits) + weighted_distillation_loss = self._compute_weighted_self_distillation_loss( + model, + inputs, + distillation_logits, + ) + return policy_loss + weighted_distillation_loss - if self.args.sdpo_policy_loss_mode == "hybrid": - base_policy_loss = super()._compute_loss(model, inputs) - if self.args.distillation_weight <= 0.0: - return base_policy_loss - - sdpo_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale - return base_policy_loss + self.args.distillation_weight * sdpo_loss + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The SDPOTrainer does not support returning outputs") - if self.args.distillation_weight <= 0.0: - return super()._compute_loss(model, inputs) + if self.args.sdpo_policy_loss_mode == "hybrid": + return self._compute_hybrid_loss(model, inputs) + if self.args.sdpo_policy_loss_mode == "distillation_only": + distillation_logits = self._compute_teacher_student_logits(model, self.teacher_model, inputs) + return self._compute_weighted_self_distillation_loss(model, inputs, distillation_logits) + if self.args.sdpo_policy_loss_mode == "policy_only": + student_logits = self._compute_student_distillation_logits( + model=model, + prompt_ids=inputs["prompt_ids"], + prompt_mask=inputs["prompt_mask"], + completion_ids=inputs["completion_ids"], + completion_mask=inputs["completion_mask"], + logits_to_keep=inputs["completion_ids"].size(1), + ) + return self._compute_policy_loss(inputs, student_logits) - sdpo_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale - return self.args.distillation_weight * sdpo_loss + raise ValueError( + "Unsupported `sdpo_policy_loss_mode`: " + f"{self.args.sdpo_policy_loss_mode!r}. Expected one of: 'hybrid', 'distillation_only', 'policy_only'." + ) diff --git a/trl/experimental/sdzero/__init__.py b/trl/experimental/sdzero/__init__.py new file mode 100644 index 00000000000..b59fbc8c0dd --- /dev/null +++ b/trl/experimental/sdzero/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .sdzero_config import SDZeroConfig +from .sdzero_trainer import SDZeroTrainer +from .srt_config import SRTConfig +from .srt_trainer import SRTTrainer + + +__all__ = ["SDZeroConfig", "SDZeroTrainer", "SRTConfig", "SRTTrainer"] diff --git a/trl/experimental/sdzero/sdzero.py b/trl/experimental/sdzero/sdzero.py new file mode 100644 index 00000000000..0ace4544a60 --- /dev/null +++ b/trl/experimental/sdzero/sdzero.py @@ -0,0 +1,137 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "math-verify>=0.5.2", +# ] +# /// + +"""Training script for [`SDZeroTrainer`] — SD-Zero Phase 2 (On-Policy Self-Distillation). + +Trains a model with on-policy self-distillation via revision feedback. The model (typically a Phase 1 SRT +checkpoint) acts as both student and frozen teacher. The student generates responses on-policy, a binary verifier +determines correctness, and the student is trained to match the teacher's revision distribution. By default, +[`SDZeroConfig`] uses a frozen base teacher, full-logit distillation, `distillation_alpha=1.0`, no +importance-sampling clipping, and one rollout per prompt. + +The dataset must expose a `problem` column (the question) and an `answer` column (the gold final answer). + +Example: + +```bash +python trl/experimental/sdzero/sdzero.py \\ + --model_name_or_path path/to/srt-checkpoint \\ + --dataset_name open-r1/OpenR1-Math-220k \\ + --output_dir outputs/sdzero-qwen2.5-0.5b \\ + --per_device_train_batch_size 1 \\ + --max_completion_length 256 \\ + --max_steps 100 \\ + --logging_steps 1 +``` +""" + +from dataclasses import dataclass, field + +import torch +from datasets import load_dataset, load_from_disk +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl import ModelConfig, ScriptArguments, TrlParser, get_kbit_device_map, get_peft_config, get_quantization_config +from trl.experimental.sdzero import SDZeroConfig, SDZeroTrainer + + +@dataclass +class SDZeroScriptArguments(ScriptArguments): + dataset_path: str | None = field( + default=None, + metadata={"help": "Local path to a dataset saved with `datasets.save_to_disk`. Overrides `dataset_name`."}, + ) + problem_column: str = field( + default="problem", + metadata={"help": "Column name containing the problem / question text."}, + ) + answer_column: str = field( + default="answer", + metadata={"help": "Column name containing the gold final answer."}, + ) + + +def _prepare_dataset(dataset, problem_column: str, answer_column: str): + """Convert dataset rows to the `{"prompt": [...], "answer": ...}` format expected by SDZeroTrainer.""" + missing = [c for c in [problem_column, answer_column] if c not in dataset.column_names] + if missing: + raise ValueError(f"Dataset is missing required columns: {missing}. Present: {dataset.column_names}") + + def _to_prompt_answer(example): + return { + "prompt": [{"role": "user", "content": example[problem_column]}], + "answer": example[answer_column], + } + + return dataset.map(_to_prompt_answer, remove_columns=dataset.column_names) + + +if __name__ == "__main__": + parser = TrlParser((SDZeroScriptArguments, SDZeroConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + + dtype = model_args.dtype if model_args.dtype in ("auto", None) else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + model_kwargs["device_map"] = get_kbit_device_map() + model_kwargs["quantization_config"] = quantization_config + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + padding_side="left", + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + if script_args.dataset_path is not None: + raw_dataset = load_from_disk(script_args.dataset_path) + else: + raw_dataset = load_dataset( + script_args.dataset_name, + name=script_args.dataset_config, + split=script_args.dataset_train_split, + ) + + dataset = _prepare_dataset(raw_dataset, script_args.problem_column, script_args.answer_column) + + training_args.model_init_kwargs = model_kwargs + model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) + if training_args.gradient_checkpointing: + model.config.use_cache = False + + trainer = SDZeroTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + trainer.train() + trainer.save_model(training_args.output_dir) diff --git a/trl/experimental/sdzero/sdzero_config.py b/trl/experimental/sdzero/sdzero_config.py new file mode 100644 index 00000000000..daea9038cba --- /dev/null +++ b/trl/experimental/sdzero/sdzero_config.py @@ -0,0 +1,95 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Literal + +from ..self_distillation.self_distillation_config import SelfDistillationConfig + + +@dataclass +class SDZeroConfig(SelfDistillationConfig): + r""" + Configuration class for [`SDZeroTrainer`]. + + Parameters: + assistant_turn_template (`str`, *optional*, defaults to `"{y}\n\n{control_prompt}\n\n"`): + Template used to compose the teacher-side assistant prefix from the sampled response `y` and the + control prompt. + + > Parameters that control the teacher + + teacher_model_kind (`str`, *optional*, defaults to `"base"`): + Semantic teacher choice. Defaults to `"base"` so the teacher stays fixed throughout training. + teacher_update_rate (`float`, *optional*, defaults to `1.0`): + EMA update rate used when `teacher_model_kind="ema"`. Defaults to `1.0` so opting into EMA gives + periodic hard teacher resync. + teacher_sync_steps (`int`, *optional*, defaults to `512`): + Number of optimizer steps between EMA teacher updates. + + > Parameters that control the loss + + distillation_mode (`str`, *optional*, defaults to `"full_logits"`): + Distillation objective mode. Defaults to `"full_logits"`. + distillation_alpha (`float`, *optional*, defaults to `1.0`): + KL direction. Defaults to `1.0`. + distillation_is_clip (`float` or `None`, *optional*): + Importance-sampling clip. Defaults to `None`, which disables clipping. + + > Parameters that control the student generation + + num_generations (`int`, *optional*, defaults to `1`): + Number of rollouts sampled per prompt per training step. + """ + + assistant_turn_template: str = field( + default="{y}\n\n{control_prompt}\n\n", + metadata={ + "help": "Template used to compose the teacher-side assistant prefix from the sampled response `y` " + "and the control prompt." + }, + ) + teacher_model_kind: str = field( + default="base", + metadata={ + "help": "Semantic teacher choice. Defaults to 'base' so the teacher stays fixed throughout training." + }, + ) + teacher_update_rate: float = field( + default=1.0, + metadata={ + "help": "EMA update rate used when teacher_model_kind='ema'. Defaults to 1.0 so opting into EMA " + "gives periodic hard teacher resync." + }, + ) + teacher_sync_steps: int = field( + default=512, + metadata={"help": "Number of optimizer steps between EMA teacher updates."}, + ) + distillation_mode: Literal["sampled_token", "full_logits", "topk_logits"] = field( + default="full_logits", + metadata={"help": "Distillation objective mode. Defaults to 'full_logits'."}, + ) + distillation_alpha: float = field( + default=1.0, + metadata={"help": "KL direction. Defaults to 1.0."}, + ) + distillation_is_clip: float | None = field( + default=None, + metadata={"help": "Importance-sampling clip. Defaults to `None`, which disables clipping."}, + ) + num_generations: int = field( + default=1, + metadata={"help": "Number of rollouts sampled per prompt per training step."}, + ) diff --git a/trl/experimental/sdzero/sdzero_trainer.py b/trl/experimental/sdzero/sdzero_trainer.py new file mode 100644 index 00000000000..968593ede47 --- /dev/null +++ b/trl/experimental/sdzero/sdzero_trainer.py @@ -0,0 +1,244 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import textwrap +from collections.abc import Callable +from typing import Any + +import torch +from datasets import Dataset, IterableDataset +from torch import nn +from transformers import ( + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.utils import is_peft_available + +from ...data_utils import is_conversational +from ...rewards import accuracy_reward +from ...trainer.utils import get_config_model_id, pad +from ..self_distillation.base_self_distillation_trainer import ( + BaseSelfDistillationTrainer, + RolloutBatch, + TrainingBatch, +) +from .sdzero_config import SDZeroConfig + + +if is_peft_available(): + from peft import PeftConfig + + +REPHRASE_PROMPT = "Let me rephrase the above solution." +RESTART_PROMPT = "Wait, this response is not correct, let me start over." + + +class SDZeroTrainer(BaseSelfDistillationTrainer): + """ + On-policy self-distillation via revision feedback. See + [Self-Distillation Zero](https://huggingface.co/papers/2604.12002). + + At each step, the student generates a response, a binary verifier judges it, and a control prompt is + selected accordingly. The teacher provides a next-token distribution over the student's response, + conditioned on the response and the control prompt. The student is updated via KL divergence to match + the teacher's distribution. + + The dataset must contain two columns: `prompt` (the problem as a conversational list or plain string) and + `answer` (the gold answer used by the binary verifier). + + Example: + + ```python + from datasets import Dataset + from trl.experimental.sdzero import SDZeroConfig, SDZeroTrainer + + dataset = Dataset.from_list([ + {"prompt": [{"role": "user", "content": "What is 2+2?"}], "answer": "4"}, + ]) + trainer = SDZeroTrainer( + model="model-id-or-path", + args=SDZeroConfig(output_dir="sdzero-model", max_steps=100), + train_dataset=dataset, + ) + trainer.train() + ``` + + Args: + model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): + Model to be trained. Can be a model id string, a local directory path, or a pre-instantiated + model object. + args ([`SDZeroConfig`], *optional*): + Configuration for this trainer. If `None`, a default configuration is used. + train_dataset ([`~datasets.Dataset`]): + Training dataset. Must contain columns `prompt` and `answer`. + eval_dataset ([`~datasets.Dataset`] or `dict[str, Dataset]`, *optional*): + Evaluation dataset. Must meet the same column requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*): + Tokenizer or processor. If `None`, loaded from the model. + reward_fn (`Callable`, *optional*): + Binary reward function with signature `(completions, solution) -> list[float | None]`, where + `completions` is a list of `[{"role": "assistant", "content": ...}]` lists and `solution` is a + list of gold answer strings. Return values of `1.0` are treated as correct, anything else as + incorrect. Defaults to [`~trl.rewards.accuracy_reward`], which parses `\\boxed{}` LaTeX format. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + Callbacks to customize the training loop. + optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`): + Optimizer and scheduler. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration. If `None`, the model is not wrapped. + """ + + _tag_names = ["trl", "sdzero", "sd-zero"] + _name = "SDZero" + config_cls = SDZeroConfig + # docstyle-ignore + _paper = { + "title": "Self-Distillation Zero: Self-Revision Turns Binary Rewards into Dense Supervision", + "id": "2604.12002", + "citation": textwrap.dedent("""\ + @article{sdzero2026, + title = {{Self-Distillation Zero: Self-Revision Turns Binary Rewards into Dense Supervision}}, + year = 2026, + eprint = {arXiv:2604.12002} + }"""), + } + + def __init__( + self, + model: str | PreTrainedModel | nn.Module, + args: SDZeroConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + reward_fn: Callable | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + peft_config: PeftConfig | None = None, + ): + if isinstance(train_dataset, IterableDataset): + raise NotImplementedError("Iterable datasets are not yet supported in SDZeroTrainer.") + if isinstance(eval_dataset, IterableDataset) or ( + isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) + ): + raise NotImplementedError("Iterable eval datasets are not yet supported in SDZeroTrainer.") + if args is None: + model_name = model if isinstance(model, str) else get_config_model_id(model.config) + model_name = model_name.split("/")[-1] + args = SDZeroConfig(f"{model_name}-SDZero") + + self.reward_fn = reward_fn if reward_fn is not None else accuracy_reward + + super().__init__( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + peft_config=peft_config, + ) + + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + self._signature_columns = ["prompt", "answer"] + + def finalize_batch( + self, + inputs: list[dict[str, Any]], + rollout_batch: RolloutBatch, + ) -> TrainingBatch: + r""" + Build the teacher context for the shared student rollout and assemble the training batch. + + For each example, the student's rollout `y_init` is scored by `reward_fn`. A control prompt + `p_r` is selected from the outcome: the rephrase nudge when `y_init` verifies + as correct, the restart nudge otherwise. The teacher input is then assembled as + + ``` + teacher_input_ids = T(x, y, p_r) ++ completion_ids + ``` + + with optional chat-template applied. + """ + tokenizer = ( + self.processing_class.tokenizer + if isinstance(self.processing_class, ProcessorMixin) + else self.processing_class + ) + + # Decode student completions `y_init` + completions = [ + tokenizer.decode(ids[mask.bool()], skip_special_tokens=True) + for ids, mask in zip(rollout_batch.completion_ids, rollout_batch.completion_mask, strict=False) + ] + + answers = [inp["answer"] for inp in inputs] + chat_completions = [[{"role": "assistant", "content": c}] for c in completions] + rewards = [r if r is not None else 0.0 for r in self.reward_fn(chat_completions, solution=answers)] + control_prompts = [REPHRASE_PROMPT if r == 1.0 else RESTART_PROMPT for r in rewards] + + mode = "train" if self.model.training else "eval" + self._metrics[mode]["sdzero/reward"].append(sum(rewards) / max(len(rewards), 1)) + + # Build the teacher prompt + prompts, _ = self._split_prompt_and_privileged_context(inputs) + teacher_prompt_ids_list = [] + for prompt, y, control_prompt in zip(prompts, completions, control_prompts, strict=False): + assistant_turn_prefix = self.args.assistant_turn_template.format( + y=y, + control_prompt=control_prompt, + ) + if is_conversational({"prompt": prompt}): + teacher_prompt_ids = tokenizer.apply_chat_template( + prompt + [{"role": "assistant", "content": assistant_turn_prefix}], + tokenize=True, + add_generation_prompt=False, + continue_final_message=True, + **self.chat_template_kwargs, + ) + else: + teacher_prompt_ids = tokenizer(prompt + assistant_turn_prefix)["input_ids"] + if self.max_prompt_length is not None: + teacher_prompt_ids = teacher_prompt_ids[-self.max_prompt_length :] + teacher_prompt_ids_list.append(teacher_prompt_ids) + + device = rollout_batch.completion_ids.device + teacher_prompt_ids = [torch.tensor(ids) for ids in teacher_prompt_ids_list] + teacher_prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in teacher_prompt_ids] + teacher_prompt_ids = pad(teacher_prompt_ids, padding_value=self.pad_token_id, padding_side="left").to( + device=device + ) + teacher_prompt_mask = pad(teacher_prompt_mask, padding_value=0, padding_side="left").to(device=device) + + teacher_input_ids = torch.cat([teacher_prompt_ids, rollout_batch.completion_ids], dim=1) + teacher_attention_mask = torch.cat([teacher_prompt_mask, rollout_batch.completion_mask], dim=1) + + batch = rollout_batch.as_dict() + batch["teacher_input_ids"] = teacher_input_ids + batch["teacher_attention_mask"] = teacher_attention_mask + return batch + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("SDZeroTrainer does not support returning outputs") + + distillation_logits = self._compute_teacher_student_logits(model, self.teacher_model, inputs) + loss = self._compute_self_distillation_loss(model, inputs, distillation_logits) + accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 + return loss / accumulation_scale diff --git a/trl/experimental/sdzero/srt.py b/trl/experimental/sdzero/srt.py new file mode 100644 index 00000000000..20fd89df4c6 --- /dev/null +++ b/trl/experimental/sdzero/srt.py @@ -0,0 +1,98 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# ] +# /// + +"""Training script for [`SRTTrainer`]. + +Trains a model with the self-revision objective via [`SRTTrainer`]. The dataset must be saved locally with +`datasets.save_to_disk` and contain columns `problem`, `y_init`, `control_prompt`, and `y_revised`. + +Example: + +```bash +uv run python trl/experimental/sdzero/srt.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --dataset_path /tmp/sdzero_gsm8k_srt \ + --output_dir outputs/sdzero-srt-qwen2.5-0.5b \ + --per_device_train_batch_size 2 --gradient_accumulation_steps 4 \ + --learning_rate 5e-6 --max_steps 50 --logging_steps 1 +``` +""" + +from dataclasses import dataclass, field + +import torch +from datasets import load_from_disk +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl import ModelConfig, ScriptArguments, TrlParser, get_kbit_device_map, get_peft_config, get_quantization_config +from trl.experimental.sdzero import SRTConfig, SRTTrainer + + +@dataclass +class SRTScriptArguments(ScriptArguments): + dataset_path: str | None = field( + default=None, + metadata={"help": "Local path to a self-revision dataset saved with `datasets.save_to_disk`."}, + ) + + +if __name__ == "__main__": + parser = TrlParser((SRTScriptArguments, SRTConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + + if script_args.dataset_path is None: + raise ValueError("`--dataset_path` is required (pointing to a self-revision dataset).") + + dtype = model_args.dtype if model_args.dtype in ("auto", None) else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + model_kwargs["device_map"] = get_kbit_device_map() + model_kwargs["quantization_config"] = quantization_config + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + dataset = load_from_disk(script_args.dataset_path) + + model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) + if training_args.gradient_checkpointing: + model.config.use_cache = False + + trainer = SRTTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + trainer.train() + trainer.save_model(training_args.output_dir) diff --git a/trl/experimental/sdzero/srt_collect.py b/trl/experimental/sdzero/srt_collect.py new file mode 100644 index 00000000000..65941b5b092 --- /dev/null +++ b/trl/experimental/sdzero/srt_collect.py @@ -0,0 +1,375 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "math-verify>=0.5.2", +# ] +# /// + +"""Example pipeline to build a self-revision dataset for [`SRTTrainer`]. + +For each `(problem, gold_answer)` pair in the seed dataset: + + 1. Sample one initial response `y_init` from the model. + 2. Verify `y_init` (correct vs. incorrect). + 3. Select a control prompt based on the outcome: + - correct → "Let me rephrase the above solution." + - incorrect → "Wait, this response is not correct, let me start over." + 4. Sample `num_revisions` revised responses conditioned on `[problem, y_init, control_prompt]`. + 5. Keep only revisions that verify as correct. + +The resulting dataset has one row per kept revision with columns: + +- `problem` (str): the original problem statement. +- `y_init` (str): the model's initial response. +- `r_init` (int): 1 if `y_init` was correct, 0 otherwise. +- `control_prompt` (str): the rephrase/restart nudge used to elicit the revision. +- `y_revised` (str): a verified-correct revised response. + +Saved via `datasets.save_to_disk` for direct consumption by [`SRTTrainer`]. + +Generation backend is selectable at the CLI: + +- default (transformers): loads the model with `AutoModelForCausalLM` and calls `.generate()` in batches. +- `--use_vllm`: loads the model via `vllm.LLM` (requires the `vllm` optional dep). + +Example: + + uv run python trl/experimental/sdzero/srt_collect.py \\ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \\ + --num_problems 256 --num_revisions 3 \\ + --output_dir srt_revision_data +""" + +import argparse +import json +import os + +import torch +from datasets import Dataset, load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl.import_utils import is_vllm_available +from trl.rewards import accuracy_reward + + +REPHRASE_PROMPT = "Let me rephrase the above solution." +RESTART_PROMPT = "Wait, this response is not correct, let me start over." + + +def build_control_prompt(is_correct: bool) -> str: + return REPHRASE_PROMPT if is_correct else RESTART_PROMPT + + +def render_initial_prompt(tokenizer, problem: str, chat_template_kwargs: dict | None = None) -> str: + """Render the initial-response prompt. Matches what `SRTTrainer` tokenizes at train time.""" + return tokenizer.apply_chat_template( + [{"role": "user", "content": problem}], + tokenize=False, + add_generation_prompt=True, + **(chat_template_kwargs or {}), + ) + + +def render_revision_prompt( + tokenizer, + problem: str, + y_init: str, + control_prompt: str, + assistant_turn_prefix_template: str, + chat_template_kwargs: dict | None = None, +) -> str: + assistant_turn_prefix = assistant_turn_prefix_template.format( + y_init=y_init, + control_prompt=control_prompt, + ) + return tokenizer.apply_chat_template( + [ + {"role": "user", "content": problem}, + {"role": "assistant", "content": assistant_turn_prefix}, + ], + tokenize=False, + add_generation_prompt=False, + continue_final_message=True, + **(chat_template_kwargs or {}), + ) + + +def verify_batch(completions: list[str], references: list[str]) -> list[bool]: + """Binary verifier: wraps `trl.rewards.accuracy_reward` (math_verify + LaTeX boxed parsing).""" + chat_completions = [[{"role": "assistant", "content": c}] for c in completions] + rewards = accuracy_reward(chat_completions, solution=references) + return [r == 1.0 for r in rewards] + + +def load_seed(dataset_name: str, dataset_split: str, num_problems: int) -> list[dict]: + split = f"{dataset_split}[:{num_problems}]" if num_problems > 0 else dataset_split + ds = load_dataset(dataset_name, split=split) + return [{"problem": ex["problem"], "reference": ex["answer"]} for ex in ds] + + +class Generator: + """Backend-agnostic generator for data collection.""" + + def __init__(self, model_name_or_path: str, *, use_vllm: bool, batch_size: int = 16): + self.use_vllm = use_vllm + self.batch_size = batch_size + self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left") + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + if use_vllm: + if not is_vllm_available(): + raise ImportError("vLLM is not installed; install `trl[vllm]` or drop `--use_vllm`.") + from vllm import LLM + + self.llm = LLM(model=model_name_or_path) + else: + self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to("cuda").eval() + + def generate( + self, + prompts: list[str], + *, + num_return_sequences: int, + max_new_tokens: int, + temperature: float, + top_p: float, + seed: int, + ) -> list[list[str]]: + if self.use_vllm: + return self._generate_vllm(prompts, num_return_sequences, max_new_tokens, temperature, top_p, seed) + return self._generate_hf(prompts, num_return_sequences, max_new_tokens, temperature, top_p, seed) + + def _generate_vllm(self, prompts, n, max_tokens, temperature, top_p, seed): + from vllm import SamplingParams + + params = SamplingParams(n=n, max_tokens=max_tokens, temperature=temperature, top_p=top_p, seed=seed) + outputs = self.llm.generate(prompts, sampling_params=params, use_tqdm=False) + # vLLM returns outputs in the same order as input prompts; each has `n` CompletionOutputs. + return [[o.text for o in r.outputs] for r in outputs] + + def _generate_hf(self, prompts, n, max_tokens, temperature, top_p, seed): + torch.manual_seed(seed) + device = next(self.model.parameters()).device + all_completions: list[list[str]] = [] + for start in range(0, len(prompts), self.batch_size): + batch_prompts = prompts[start : start + self.batch_size] + enc = self.tokenizer( + batch_prompts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=2048, + add_special_tokens=False, + ).to(device) + with torch.inference_mode(): + out = self.model.generate( + **enc, + do_sample=True, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_tokens, + num_return_sequences=n, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + ) + prompt_len = enc["input_ids"].shape[1] + decoded = self.tokenizer.batch_decode(out[:, prompt_len:], skip_special_tokens=True) + for j in range(len(batch_prompts)): + all_completions.append(decoded[j * n : (j + 1) * n]) + return all_completions + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Collect a self-revision dataset for SRTTrainer.") + parser.add_argument( + "--model_name_or_path", + required=True, + help="HF repo or local path of the model used to generate both initial and revised responses.", + ) + parser.add_argument( + "--dataset_name", + default="open-r1/OpenR1-Math-220k", + help="Seed dataset. Must expose a `problem` column (problem statement) and an `answer` column " + "(gold final answer, LaTeX-compatible for the default verifier).", + ) + parser.add_argument( + "--dataset_split", + default="train", + help="Split to load from `--dataset_name`.", + ) + parser.add_argument( + "--num_problems", + type=int, + default=256, + help="Number of seed problems to load from the dataset split. Use a small value for quick runs; " + "scale up for real training data. `<= 0` loads the entire split.", + ) + parser.add_argument( + "--num_revisions", + type=int, + default=3, + help="Number of revised responses to sample per problem. More revisions raise the chance that at least " + "one verifies as correct (and thus survives filtering) at the cost of more generations.", + ) + parser.add_argument( + "--max_init_tokens", + type=int, + default=512, + help="Max new tokens when sampling the initial response `y_init`. Low defaults are for smoke tests; " + "raise this for non-trivial problems.", + ) + parser.add_argument( + "--max_revised_tokens", + type=int, + default=512, + help="Max new tokens when sampling each revised response `y_revised`. Same guidance as `--max_init_tokens`.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="Sampling temperature used for both initial and revision generation.", + ) + parser.add_argument( + "--top_p", + type=float, + default=0.95, + help="Top-p nucleus sampling cutoff used for both initial and revision generation.", + ) + parser.add_argument( + "--seed", + type=int, + default=0, + help="Base RNG seed. Initial-response generation uses `seed`, revision generation uses `seed + 1`.", + ) + parser.add_argument( + "--assistant_turn_prefix_template", + default="{y_init}\n\n{control_prompt}\n\n", + help="Template used to compose the assistant-side revision prefix from `y_init` and `control_prompt`.", + ) + parser.add_argument( + "--chat_template_kwargs", + type=json.loads, + default=None, + help="JSON dictionary of keyword arguments forwarded to `apply_chat_template` during prompt rendering.", + ) + parser.add_argument( + "--output_dir", + required=True, + help="Destination directory for the collected dataset. Written via `datasets.save_to_disk`.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=16, + help="Prompts per generation batch. Only used by the transformers backend; vLLM batches internally.", + ) + parser.add_argument( + "--use_vllm", + action="store_true", + help="Use vLLM for generation (requires the `vllm` optional dependency). Defaults to transformers.", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + os.makedirs(args.output_dir, exist_ok=True) + + generator = Generator(args.model_name_or_path, use_vllm=args.use_vllm, batch_size=args.batch_size) + tokenizer = generator.tokenizer + + seed_rows = load_seed(args.dataset_name, args.dataset_split, args.num_problems) + + # Step 1: sample one initial response per problem. + init_prompts = [render_initial_prompt(tokenizer, row["problem"], args.chat_template_kwargs) for row in seed_rows] + init_completions = generator.generate( + init_prompts, + num_return_sequences=1, + max_new_tokens=args.max_init_tokens, + temperature=args.temperature, + top_p=args.top_p, + seed=args.seed, + ) + y_inits = [c[0] for c in init_completions] + + # Step 2: verify each initial response and pick the control prompt (rephrase vs. restart). + init_correct = verify_batch(y_inits, [row["reference"] for row in seed_rows]) + rows_for_revision = [ + { + "problem": row["problem"], + "reference": row["reference"], + "y_init": y_init, + "r_init": int(is_correct), + "control_prompt": build_control_prompt(is_correct), + } + for row, y_init, is_correct in zip(seed_rows, y_inits, init_correct, strict=True) + ] + + # Step 3: sample `num_revisions` revised responses per problem. + revision_prompts = [ + render_revision_prompt( + tokenizer, + r["problem"], + r["y_init"], + r["control_prompt"], + args.assistant_turn_prefix_template, + args.chat_template_kwargs, + ) + for r in rows_for_revision + ] + revision_completions = generator.generate( + revision_prompts, + num_return_sequences=args.num_revisions, + max_new_tokens=args.max_revised_tokens, + temperature=args.temperature, + top_p=args.top_p, + seed=args.seed + 1, + ) + + # Step 4: keep only revisions that verify as correct. + flat_completions: list[str] = [] + flat_references: list[str] = [] + flat_source_row: list[dict] = [] + for row, completions in zip(rows_for_revision, revision_completions, strict=True): + for y_revised in completions: + flat_completions.append(y_revised) + flat_references.append(row["reference"]) + flat_source_row.append(row) + revised_correct = verify_batch(flat_completions, flat_references) + + collected = [ + { + "problem": row["problem"], + "y_init": row["y_init"], + "r_init": row["r_init"], + "control_prompt": row["control_prompt"], + "y_revised": y_revised, + } + for row, y_revised, is_correct in zip(flat_source_row, flat_completions, revised_correct, strict=True) + if is_correct + ] + + if not collected: + raise RuntimeError("No verified revisions were produced; try increasing `num_problems` or sampling budget.") + + Dataset.from_list(collected).save_to_disk(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/trl/experimental/sdzero/srt_config.py b/trl/experimental/sdzero/srt_config.py new file mode 100644 index 00000000000..dd49a2a3071 --- /dev/null +++ b/trl/experimental/sdzero/srt_config.py @@ -0,0 +1,75 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from ...trainer.sft_config import SFTConfig + + +@dataclass +class SRTConfig(SFTConfig): + r""" + Configuration class for [`SRTTrainer`]. + + Parameters: + assistant_turn_template (`str`, *optional*, defaults to `"{y_init}\n\n{control_prompt}\n\n{y_revised}"`): + Template used to compose the assistant turn from the initial answer, control prompt, and revised answer. + Must end with `{y_revised}` so the revised-answer supervision boundary is well-defined. + chat_template_kwargs (`dict[str, Any]` or `None`, *optional*): + Extra keyword arguments forwarded to `apply_chat_template` when rendering SRT prompts and assistant turns. + include_generation_loss (`bool`, *optional*, defaults to `True`): + Whether to include the generation loss term, which supervises the model on the full assistant turn + (initial answer, control prompt, and revised answer) given only the problem. + include_revision_loss (`bool`, *optional*, defaults to `True`): + Whether to include the revision loss term, which supervises the model only on the revised answer given + the full context (problem, initial answer, control prompt). + """ + + _VALID_DICT_FIELDS = SFTConfig._VALID_DICT_FIELDS + ["chat_template_kwargs"] + + assistant_turn_template: str = field( + default="{y_init}\n\n{control_prompt}\n\n{y_revised}", + metadata={ + "help": "Template used to compose the assistant turn from the initial answer, control prompt, and revised " + "answer. Must end with `{y_revised}` so the revised-answer supervision boundary is well-defined." + }, + ) + chat_template_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Extra keyword arguments forwarded to `apply_chat_template` when rendering SRT prompts and assistant turns." + }, + ) + include_generation_loss: bool = field( + default=True, + metadata={ + "help": "Whether to include the generation loss term, which supervises the model on the full assistant turn " + "(initial answer, control prompt, and revised answer) given only the problem" + }, + ) + include_revision_loss: bool = field( + default=True, + metadata={ + "help": "Whether to include the revision loss term, which supervises the model only on the revised " + "answer given the full context (problem, initial answer, control prompt)." + }, + ) + + def __post_init__(self): + super().__post_init__() + if not self.assistant_turn_template.endswith("{y_revised}"): + raise ValueError("`assistant_turn_template` must end with `{y_revised}`.") + if not (self.include_revision_loss or self.include_generation_loss): + raise ValueError("At least one of `include_revision_loss` or `include_generation_loss` must be True.") diff --git a/trl/experimental/sdzero/srt_trainer.py b/trl/experimental/sdzero/srt_trainer.py new file mode 100644 index 00000000000..769e3bbc611 --- /dev/null +++ b/trl/experimental/sdzero/srt_trainer.py @@ -0,0 +1,284 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import textwrap +from collections.abc import Callable +from typing import Any + +import torch +from datasets import Dataset, IterableDataset +from torch import nn +from transformers import ( + AutoProcessor, + DataCollator, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_peft_available + +from ...trainer.sft_trainer import SFTTrainer, get_dataset_column_names +from ...trainer.utils import get_config_model_id +from .srt_config import SRTConfig + + +if is_peft_available(): + from peft import PeftConfig + + +class SRTTrainer(SFTTrainer): + """ + Trainer for Self-Revision Training (SRT) from [Self-Distillation Zero](https://huggingface.co/papers/2604.12002). + + SRT trains a model with a joint objective combining two complementary loss terms. Each dataset row is + expanded into two training records that share the same token sequence but differ in which tokens are + supervised: + + - **Revision record**: loss computed only on the revised answer, conditioned on the full context + (problem, initial answer, control prompt) as input. + - **Generation record**: loss computed on the entire assistant turn — initial answer, control prompt, + and revised answer — conditioned on only the problem as input. + + The dataset must contain four string columns: `problem`, `y_init`, `control_prompt`, and `y_revised`. + + Args: + model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using `.from_pretrained` with the keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + args ([`SRTConfig`], *optional*): + Configuration for this trainer. If `None`, a default configuration is used. + data_collator ([`~transformers.DataCollator`], *optional*): + Function to use to form a batch from a list of elements of the processed dataset. + train_dataset ([`~datasets.Dataset`]): + Dataset for training. Must contain columns `problem`, `y_init`, `control_prompt`, and `y_revised`. + eval_dataset ([`~datasets.Dataset`] or `dict[str, Dataset]`, *optional*): + Dataset for evaluation. Must meet the same column requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*): + Processing class used to tokenize the data. If `None`, loaded from the model name. A padding token + must be set; if absent, `eos_token` is used. + compute_loss_func (`Callable`, *optional*): + Custom loss function. See [`SFTTrainer`] for details. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + Function to compute metrics at evaluation. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. + optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`): + Tuple containing the optimizer and scheduler to use. + optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*): + Tuple containing the optimizer class and keyword arguments. Overrides `optim` and `optim_args` in `args`. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): + Function to preprocess logits before caching them at each evaluation step. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + """ + + _tag_names = ["trl", "sdzero", "srt"] + _name = "SRT" + config_cls = SRTConfig + # docstyle-ignore + _paper = { + "title": "Self-Distillation Zero: Self-Revision Turns Binary Rewards into Dense Supervision", + "id": "2604.12002", + "citation": textwrap.dedent("""\ + @article{sdzero2026, + title = {{Self-Distillation Zero: Self-Revision Turns Binary Rewards into Dense Supervision}}, + year = 2026, + eprint = {arXiv:2604.12002} + }"""), + } + + def __init__( + self, + model: str | PreTrainedModel | nn.Module, + args: SRTConfig | None = None, + data_collator: DataCollator | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + compute_loss_func: Callable | None = None, + compute_metrics: Callable[[EvalPrediction], dict] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None, + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: PeftConfig | None = None, + ): + if isinstance(train_dataset, IterableDataset) or isinstance(eval_dataset, IterableDataset): + raise NotImplementedError("Iterable datasets are not supported by `SRTTrainer`.") + if args is None: + model_name = model if isinstance(model, str) else get_config_model_id(model.config) + model_name = model_name.split("/")[-1] + args = SRTConfig(f"{model_name}-SRT") + + if processing_class is None: + model_id = model if isinstance(model, str) else get_config_model_id(model.config) + processing_class = AutoProcessor.from_pretrained(model_id) + if processing_class.pad_token is None: + processing_class.pad_token = processing_class.eos_token + + if train_dataset is None: + raise ValueError("`train_dataset` is required for `SRTTrainer`.") + + train_dataset = self._expand_srt_dataset(train_dataset, processing_class, args) + if eval_dataset is not None: + if isinstance(eval_dataset, dict): + eval_dataset = { + name: self._expand_srt_dataset(ds, processing_class, args) for name, ds in eval_dataset.items() + } + else: + eval_dataset = self._expand_srt_dataset(eval_dataset, processing_class, args) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_loss_func=compute_loss_func, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + peft_config=peft_config, + ) + + @staticmethod + def _expand_srt_dataset( + dataset: Dataset, + processing_class: PreTrainedTokenizerBase | ProcessorMixin, + args: SRTConfig, + ) -> Dataset: + """Expand each dataset row into tokenized supervised training records. + + Each input row must contain: + - `problem`: the user problem `x` + - `y_init`: the model's initial attempt + - `control_prompt`: the verifier-derived revision cue `P_r` + - `y_revised`: the revised answer + + From each row, this function emits up to two tokenized samples over the same + serialized chat: + + user: problem + assistant: + + with different completion masks: + + - Revision sample: loss is applied only to the `y_revised` suffix. + - Generation sample: loss is applied to the full assistant trace + `y_init + control_prompt + y_revised`. + + To support arbitrary chat templates, token boundaries are computed from + structured chat renders rather than manual token concatenation: + + - The generation boundary comes from the canonical prompt-only render with + `add_generation_prompt=True`. + - The revision boundary comes from rendering the same chat while continuing + the assistant message immediately before the `y_revised` suffix. + """ + + tokenizer = processing_class.tokenizer if isinstance(processing_class, ProcessorMixin) else processing_class + + columns = get_dataset_column_names(dataset) + required = ["problem", "y_init", "control_prompt", "y_revised"] + missing = [c for c in required if c not in columns] + if missing: + raise ValueError(f"SRT dataset is missing required columns: {missing}. Present columns: {columns}.") + + assistant_turn_prefix_template = args.assistant_turn_template.removesuffix("{y_revised}") + chat_template_kwargs = args.chat_template_kwargs or {} + + def _tokenize_messages( + messages: list[dict[str, str]], + *, + add_generation_prompt: bool = False, + continue_final_message: bool = False, + ) -> list[int]: + return tokenizer.apply_chat_template( + conversation=messages, + tokenize=True, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + **chat_template_kwargs, + ) + + def _row_to_records(example: dict[str, Any]) -> dict[str, list]: + problem = example["problem"] + y_init = example["y_init"] + control_prompt = example["control_prompt"] + y_revised = example["y_revised"] + + prompt_messages = [{"role": "user", "content": problem}] + assistant_full = args.assistant_turn_template.format( + y_init=y_init, + control_prompt=control_prompt, + y_revised=y_revised, + ) + assistant_before_revision = assistant_turn_prefix_template.format( + y_init=y_init, + control_prompt=control_prompt, + ) + + input_ids = _tokenize_messages( + prompt_messages + [{"role": "assistant", "content": assistant_full}], + ) + generation_prefix_ids = _tokenize_messages( + prompt_messages, + add_generation_prompt=True, + ) + revision_prefix_ids = _tokenize_messages( + prompt_messages + [{"role": "assistant", "content": assistant_before_revision}], + continue_final_message=True, + ) + + if input_ids[: len(generation_prefix_ids)] != generation_prefix_ids: + raise ValueError("Unexpected tokenization: generation prefix is not a prefix of the full input") + if input_ids[: len(revision_prefix_ids)] != revision_prefix_ids: + raise ValueError("Unexpected tokenization: revision prefix is not a prefix of the full input") + + generation_mask = [0] * len(generation_prefix_ids) + [1] * (len(input_ids) - len(generation_prefix_ids)) + revision_mask = [0] * len(revision_prefix_ids) + [1] * (len(input_ids) - len(revision_prefix_ids)) + + input_ids_list, completion_masks = [], [] + if args.include_revision_loss: + input_ids_list.append(input_ids) + completion_masks.append(revision_mask) + if args.include_generation_loss: + input_ids_list.append(input_ids) + completion_masks.append(generation_mask) + + return {"input_ids": input_ids_list, "completion_mask": completion_masks} + + expanded = dataset.map( + _row_to_records, + batched=False, + remove_columns=dataset.column_names, + ) + return Dataset.from_dict( + { + "input_ids": [input_ids for row in expanded["input_ids"] for input_ids in row], + "completion_mask": [mask for row in expanded["completion_mask"] for mask in row], + } + ) diff --git a/trl/experimental/self_distillation/__init__.py b/trl/experimental/self_distillation/__init__.py index 1449db2f7a3..f5d94b41edb 100644 --- a/trl/experimental/self_distillation/__init__.py +++ b/trl/experimental/self_distillation/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. from .self_distillation_config import SelfDistillationConfig -from .self_distillation_mixin import SelfDistillationMixin -__all__ = ["SelfDistillationConfig", "SelfDistillationMixin"] +__all__ = ["SelfDistillationConfig"] diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index bd9abb95164..95fe72669c3 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -12,29 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Shared online self-distillation trainer scaffold. - -This base combines the generic Trainer setup for self-distillation with the online rollout utilities used by SDPO-like -methods. Offline methods such as SDFT stay on `_BaseTrainer` directly and only reuse the shared distillation mixin. -""" - from __future__ import annotations +import copy import inspect +from abc import ABC, abstractmethod from collections import defaultdict +from contextlib import nullcontext +from dataclasses import dataclass from functools import partial from typing import Any import datasets import torch from accelerate.logging import get_logger +from accelerate.utils import is_peft_model from datasets import Dataset, IterableDataset from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data import DataLoader, Sampler from transformers import ( - AutoModelForSequenceClassification, AutoProcessor, - AutoTokenizer, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase, @@ -44,7 +42,8 @@ from transformers.trainer_utils import seed_worker from transformers.utils import is_datasets_available, is_peft_available -from ...models import prepare_deepspeed, prepare_fsdp +from ...data_utils import is_conversational +from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation from ...trainer.base_trainer import _BaseTrainer from ...trainer.utils import ( RepeatSampler, @@ -52,12 +51,21 @@ disable_dropout_in_model, get_config_model_id, identity, + pad, split_tensor_dict, + use_adapter, ) from ..utils import prepare_peft_model -from .online_rollout_mixin import OnlineRolloutMixin +from .loss_utils import ( + aggregate_loss, + apply_importance_sampling_clipping, + compute_full_logit_self_distillation_loss, + compute_sampled_token_self_distillation_loss, + compute_topk_self_distillation_loss, + select_token_log_probs, +) from .self_distillation_config import SelfDistillationConfig -from .self_distillation_mixin import SelfDistillationMixin +from .teacher_sync import PEFTAdapterEMACallback, SyncTeacherModelCallback if is_peft_available(): @@ -67,18 +75,59 @@ logger = get_logger(__name__) -class BaseSelfDistillationTrainer(OnlineRolloutMixin, SelfDistillationMixin, _BaseTrainer): - """Shared scaffold for experimental self-distillation trainers without GRPO inheritance.""" +@dataclass +class RolloutBatch: + """Student-side rollout produced by `sample_rollouts`, consumed by `finalize_batch` to form a `TrainingBatch`.""" + + prompt_ids: torch.Tensor + prompt_mask: torch.Tensor + completion_ids: torch.Tensor + completion_mask: torch.Tensor + old_per_token_logps: torch.Tensor | None = None + raw_completion_lengths: torch.Tensor | None = None + + def as_dict(self) -> dict[str, torch.Tensor | Any]: + batch: dict[str, torch.Tensor | Any] = { + "prompt_ids": self.prompt_ids, + "prompt_mask": self.prompt_mask, + "completion_ids": self.completion_ids, + "completion_mask": self.completion_mask, + } + if self.old_per_token_logps is not None: + batch["old_per_token_logps"] = self.old_per_token_logps + if self.raw_completion_lengths is not None: + batch["raw_completion_lengths"] = self.raw_completion_lengths + return batch + + +TrainingBatch = dict[str, torch.Tensor | Any] + + +@dataclass +class DistillationLogits: + """Aligned logits and masks used to compute a self-distillation objective.""" + + completion_ids: torch.Tensor + completion_mask: torch.Tensor + response_mask: torch.Tensor + student_logits: torch.Tensor + teacher_logits: torch.Tensor + + +class BaseSelfDistillationTrainer(_BaseTrainer, ABC): + """Base that centralizes shared self-distillation trainer lifecycle.""" + + config_cls = SelfDistillationConfig + _tag_names = ["trl", "self-distillation"] + _name = "Self-Distillation" def __init__( self, model: str | PreTrainedModel | nn.Module, - reward_funcs: Any | list[Any] | None = None, args: SelfDistillationConfig | None = None, train_dataset: Dataset | IterableDataset | None = None, eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, - reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), peft_config: PeftConfig | None = None, @@ -104,7 +153,20 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) - if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): + if peft_config is None and getattr(model, "peft_config", None) is not None: + logger.warning( + "The provided self-distillation student model already contains a PEFT adapter. " + "This setup is accepted but not directly supported. In particular, `teacher_model_kind='base'` " + "may refer to the underlying base weights rather than the exact initially loaded student state " + "including its adapter. For unambiguous teacher behavior, start from a merged/non-adapter model " + "or manage separate adapters explicitly." + ) + if is_peft_model(model) and peft_config is not None: + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config`. Pass either a base " + "model with `peft_config`, or a pre-wrapped PEFT model." + ) + if peft_config is not None or getattr(model, "peft_config", None) is not None: model = prepare_peft_model(model, peft_config, args) if processing_class is None: @@ -124,7 +186,6 @@ def __init__( self.pad_token_id = tokenizer.pad_token_id self.eos_token_id = tokenizer.eos_token_id - self.temperature = args.temperature self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length self.num_generations = args.num_generations @@ -132,12 +193,8 @@ def __init__( self.num_iterations = args.num_iterations self.shuffle_dataset = args.shuffle_dataset self.loss_type = args.loss_type - self.importance_sampling_level = args.importance_sampling_level - self.scale_rewards = args.scale_rewards - self.epsilon_low = args.epsilon - self.epsilon_high = args.epsilon_high - self.beta = args.beta self.mask_truncated_completions = args.mask_truncated_completions + self.temperature = args.temperature self.chat_template_kwargs = args.chat_template_kwargs or {} self._step = 0 self._last_loaded_step = 0 @@ -148,7 +205,7 @@ def __init__( "eval": defaultdict(int), } - generation_kwargs = { + self.generation_kwargs = { "max_new_tokens": self.max_completion_length, "do_sample": True, "pad_token_id": tokenizer.pad_token_id, @@ -162,8 +219,8 @@ def __init__( "cache_implementation": args.cache_implementation, } if args.generation_kwargs is not None: - generation_kwargs.update(args.generation_kwargs) - self.generation_config = GenerationConfig(**generation_kwargs, disable_compile=True) + self.generation_kwargs.update(args.generation_kwargs) + self.generation_config = GenerationConfig(**self.generation_kwargs, disable_compile=True) if hasattr(model, "warnings_issued"): model.warnings_issued["estimate_tokens"] = True @@ -211,75 +268,113 @@ def __init__( logprobs=None, generation_kwargs=args.generation_kwargs, ) - self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation - - if reward_funcs is None: - reward_funcs = [] - if not isinstance(reward_funcs, list): - reward_funcs = [reward_funcs] - self.reward_func_names = [] - for i, reward_func in enumerate(reward_funcs): - if isinstance(reward_func, str): - reward_model_init_kwargs = args.model_init_kwargs or {} - if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: - reward_model_init_kwargs["device_map"] = None - reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( - reward_func, - num_labels=1, - **reward_model_init_kwargs, - ) - if isinstance(reward_funcs[i], nn.Module): - self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) - else: - self.reward_func_names.append(reward_funcs[i].__name__) - self.reward_funcs = reward_funcs - - if args.reward_weights is not None: - if len(args.reward_weights) != len(self.reward_funcs): - raise ValueError("Number of reward weights must match number of reward functions") - self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) - else: - self.reward_weights = torch.ones(len(self.reward_funcs), dtype=torch.float32) - - if reward_processing_classes is None: - reward_processing_classes = [None] * len(self.reward_funcs) - elif not isinstance(reward_processing_classes, list): - reward_processing_classes = [reward_processing_classes] - if len(reward_processing_classes) != len(self.reward_funcs): - raise ValueError("Number of reward processing classes must match number of reward functions") - - for i, (reward_processing_class, reward_func) in enumerate( - zip(reward_processing_classes, self.reward_funcs, strict=True) - ): - if isinstance(reward_func, PreTrainedModel): - if reward_processing_class is None: - reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config)) - if reward_processing_class.pad_token_id is None: - reward_processing_class.pad_token = reward_processing_class.eos_token - reward_func.config.pad_token_id = reward_processing_class.pad_token_id - reward_processing_classes[i] = reward_processing_class - self.reward_processing_classes = reward_processing_classes + self._last_loaded_step = -1 if args.disable_dropout: disable_dropout_in_model(self.model) - for i, reward_func in enumerate(self.reward_funcs): - if isinstance(reward_func, nn.Module): - if self.is_deepspeed_enabled: - self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) - elif self.is_fsdp_enabled: - self.reward_funcs[i] = prepare_fsdp(reward_func, self.accelerator) - else: - self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) - self.model.add_model_tags(self._tag_names) + self._setup_teacher_model() self.model_accepts_loss_kwargs = False - self.ref_model = None - self.teacher_model = None - if args.sync_ref_model: - raise ValueError( - "sync_ref_model is not supported on the shared online self-distillation base without `ref_model`." + + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + self._signature_columns = ["prompt", "privileged_context"] + + def _dispatch_self_distillation_callback(self, event_name: str, **payload) -> None: + for callback in self.callback_handler.callbacks: + callback_fn = getattr(callback, event_name, None) + if callback_fn is not None: + callback_fn( + args=self.args, + state=self.state, + control=self.control, + model=self.model, + processing_class=self.processing_class, + **payload, + ) + + def _setup_teacher_model(self) -> None: + """Prepare teacher state according to the semantic teacher choice. + + Resolve `teacher_model_kind` × PEFT state into the effective teacher: + + - `"live"` (any model): + Teacher is the student. No divergence, no callback. + - `"base"` + PEFT model: + Teacher reuses `self.model`; the base weights are recovered downstream by disabling the adapter + via `use_adapter` during teacher forward. + - `"base"` + non-PEFT model: + Teacher is a frozen deepcopy of the initial student (falls through to the copy branch below). + - `"ema"` + pure-LoRA training: + Teacher reuses `self.model`; a dedicated `"teacher"` LoRA adapter is attached and updated by + `PEFTAdapterEMACallback`. Teacher forward switches to that adapter downstream. + - `"ema"` (otherwise): + Teacher is a frozen deepcopy synchronized each step by `SyncTeacherModelCallback`. + + Must be called after `super().__init__` so that `self.callback_handler` is available. + """ + + teacher_model_kind = self.args.teacher_model_kind + + if teacher_model_kind == "live": + self.teacher_model = self.model + return + + if teacher_model_kind == "base" and is_peft_model(self.model): + self.teacher_model = self.model + return + + if self._use_peft_ema_teacher_adapter(): + # Must run after super().__init__ so self.callback_handler exists. + self.add_callback( + PEFTAdapterEMACallback( + model=self.model, + teacher_adapter_name="teacher", + update_rate=self.args.teacher_update_rate, + sync_steps=self.args.teacher_sync_steps, + accelerator=self.accelerator, + ) ) + self.teacher_model = self.model + return + + # create teacher model from student copy + student_model = self.accelerator.unwrap_model(self.model) + self.teacher_model = copy.deepcopy(student_model) + self.teacher_model.requires_grad_(False) + self.teacher_model.eval() + if self.is_deepspeed_enabled: + self.teacher_model = prepare_deepspeed(self.teacher_model, self.accelerator) + elif self.is_fsdp_enabled: + self.teacher_model = prepare_fsdp(self.teacher_model, self.accelerator) + else: + self.teacher_model = self.accelerator.prepare_model(self.teacher_model, evaluation_mode=True) + + if teacher_model_kind == "ema": + self.add_callback(SyncTeacherModelCallback(teacher_model=self.teacher_model, accelerator=self.accelerator)) + + def _use_peft_ema_teacher_adapter(self) -> bool: + return self.args.teacher_model_kind == "ema" and self._is_pure_lora_training() + + def _is_pure_lora_training(self) -> bool: + """Return `True` when the active adapter is LoRA and every trainable parameter is a LoRA parameter.""" + if not is_peft_model(self.model): + return False + + model = self.accelerator.unwrap_model(self.model) + adapter_name = getattr(model, "active_adapter", None) or "default" + adapter_config = model.peft_config.get(adapter_name) + peft_type = getattr(adapter_config, "peft_type", None) + if peft_type is None or str(peft_type).split(".")[-1] != "LORA": + return False + + for name, param in model.named_parameters(): + if param.requires_grad and "lora_" not in name: + return False + return True def get_train_dataloader(self): if self.train_dataset is None: @@ -323,7 +418,7 @@ def _get_train_sampler(self, dataset=None) -> Sampler: def _get_eval_sampler(self, eval_dataset) -> Sampler: return RepeatSampler( data_source=eval_dataset, - mini_repeat_count=getattr(self, "num_generations_eval", self.num_generations), + mini_repeat_count=self.num_generations_eval, seed=self.args.seed, ) @@ -332,24 +427,439 @@ def training_step(self, model, inputs, num_items_in_batch): self._step += 1 return output + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): + if not isinstance(inputs, dict): + inputs = self._prepare_inputs(inputs) + with torch.no_grad(): + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + return loss.detach(), None, None + def _prepare_inputs(self, generation_batch): + """Return the per-step training batch, regenerating rollouts and buffering them for reuse in train mode. + + In train mode, rollouts are generated once every `steps_per_generation * num_iterations` steps and split + into per-step slices reused until the next regeneration. In eval mode, every batch is freshly prepared. + """ mode = "train" if self.model.training else "eval" if mode == "train": generate_every = self.args.steps_per_generation * self.num_iterations if self._step % generate_every == 0 or self._buffered_inputs is None: - generation_batch = self._build_buffered_batch(generation_batch) - self._buffered_inputs = split_tensor_dict(generation_batch, self.args.steps_per_generation) + buffered_batch = self._prepare_training_batch(generation_batch) + self._buffered_inputs = split_tensor_dict(buffered_batch, self.args.steps_per_generation) self._dispatch_self_distillation_callback( "on_generation_batch_built", generate_every=generate_every, steps_per_generation=self.args.steps_per_generation, ) return self._buffered_inputs[self._step % self.args.steps_per_generation] - return self._build_buffered_batch(generation_batch) + return self._prepare_training_batch(generation_batch) + + def _prepare_training_batch(self, inputs: list[dict[str, Any]]) -> TrainingBatch: + """Sample student rollouts and let the subclass finalize them into the final `TrainingBatch`.""" + rollout_batch = self.sample_rollouts(inputs) + + batch = self.finalize_batch(inputs, rollout_batch) + + self._dispatch_self_distillation_callback( + "on_self_distillation_batch_prepared", + old_per_token_logps=batch.get("old_per_token_logps"), + prompt_ids=batch["prompt_ids"], + completion_ids=batch["completion_ids"], + teacher_input_ids=batch["teacher_input_ids"], + teacher_attention_mask=batch["teacher_attention_mask"], + self_distillation_mask=batch.get("self_distillation_mask"), + ) + return batch + + def _tokenize_prompts_untruncated(self, prompts: list[Any]) -> list[list[int]]: + if is_conversational({"prompt": prompts[0]}): + tokenized = self.processing_class.apply_chat_template( + conversation=prompts, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + **self.chat_template_kwargs, + ) + prompt_ids = tokenized["input_ids"] + else: + prompt_ids = self.processing_class(text=prompts)["input_ids"] + return prompt_ids + + def _tokenize_prompts(self, prompts: list[Any]) -> list[list[int]]: + prompt_ids = self._tokenize_prompts_untruncated(prompts) + if self.max_prompt_length is not None: + prompt_ids = [ids[-self.max_prompt_length :] for ids in prompt_ids] + return prompt_ids + + def sample_rollouts(self, inputs: list[dict[str, Any]]) -> RolloutBatch: + prompts, _ = self._split_prompt_and_privileged_context(inputs) + prompt_ids = self._tokenize_prompts(prompts) + self._dispatch_self_distillation_callback( + "on_generation_prompts_selected", + generation_prompts=prompts, + generation_prompt_text=None, + ) - def _prepare_auxiliary_model_for_eval(self, aux_model: nn.Module): - if self.is_deepspeed_enabled: - return prepare_deepspeed(aux_model, self.accelerator) - if self.is_fsdp_enabled: - return prepare_fsdp(aux_model, self.accelerator) - return self.accelerator.prepare_model(aux_model, evaluation_mode=True) + prompt_ids_list, completion_ids_list = self._generate(prompt_ids) + device = self.accelerator.device + prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left").to(device=device) + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left").to(device=device) + completion_ids = [torch.tensor(ids) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right").to(device=device) + completion_mask = pad(completion_mask, padding_value=0, padding_side="right").to(device=device) + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() + old_per_token_logps = self._compute_rollout_logps( + prompt_ids=prompt_ids, + prompt_mask=prompt_mask, + completion_ids=completion_ids, + completion_mask=completion_mask, + ) + + return RolloutBatch( + prompt_ids=prompt_ids, + prompt_mask=prompt_mask, + completion_ids=completion_ids, + completion_mask=completion_mask, + old_per_token_logps=old_per_token_logps, + raw_completion_lengths=torch.tensor( + [len(ids) for ids in completion_ids_list], device=device, dtype=torch.long + ), + ) + + def _split_prompt_and_privileged_context(self, inputs: list[dict[str, Any]]) -> tuple[list[Any], list[Any]]: + prompts = [example["prompt"] for example in inputs] + privileged_contexts = [example.get("privileged_context") for example in inputs] + return prompts, privileged_contexts + + def _generate(self, prompt_ids: list[list[int]]) -> tuple[list[list[int]], list[list[int]]]: + if self.use_vllm: + return self._generate_vllm(prompt_ids) + return self._generate_transformers(prompt_ids) + + def _generate_vllm(self, prompt_ids: list[list[int]]) -> tuple[list[list[int]], list[list[int]]]: + if self.state.global_step != self._last_loaded_step: + self.vllm_generation.sync_weights() + self._last_loaded_step = self.state.global_step + + mode = "train" if self.model.training else "eval" + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + prompt_ids_out, completion_ids_list, _, _ = self.vllm_generation.generate( + prompts=prompt_ids, + images=None, + num_generations=num_generations, + ) + return prompt_ids_out, completion_ids_list + + def _generate_transformers(self, prompt_ids: list[list[int]]) -> tuple[list[list[int]], list[list[int]]]: + device = self.accelerator.device + prompt_tensors = [torch.tensor(ids) for ids in prompt_ids] + padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left") + attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left") + generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask} + generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) + + with ( + unwrap_model_for_generation( + self.model_wrapped, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + generation_kwargs=self.generation_kwargs, + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + prompt_completion_ids = unwrapped_model.generate( + **generate_inputs, generation_config=self.generation_config + ) + + prompt_length = generate_inputs["input_ids"].size(1) + completion_ids = prompt_completion_ids[:, prompt_length:] + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + seq_idx = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).int() + completion_ids_list = [ + c[m].tolist() for c, m in zip(completion_ids.cpu(), completion_mask.bool().cpu(), strict=True) + ] + return prompt_ids, completion_ids_list + + def _compute_rollout_logps( + self, + prompt_ids: torch.Tensor, + prompt_mask: torch.Tensor, + completion_ids: torch.Tensor, + completion_mask: torch.Tensor, + ) -> torch.Tensor | None: + generate_every = self.args.steps_per_generation * self.num_iterations + old_per_token_logps = None + + if self.args.gradient_accumulation_steps % generate_every != 0: + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) + with torch.no_grad(): + logits = self._forward_logits( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + ) + old_per_token_logps = select_token_log_probs(logits, completion_ids) + + return old_per_token_logps + + def _compute_self_distillation_loss( + self, + model, + inputs: TrainingBatch, + distillation_logits: DistillationLogits, + ) -> torch.Tensor: + """Compute the per-token distillation loss and aggregate it according to `loss_type`. + + Dispatches between three objectives based on `distillation_mode`: + + - `"topk_logits"`: top-k approximation of the divergence, optionally with a tail bucket for the + remaining probability mass (`distillation_add_tail`). + - `"full_logits"`: full-vocab divergence. + - `"sampled_token"`: token-level (reverse-KL) distillation on sampled `completion_ids`. + + When `distillation_is_clip` is set and `old_per_token_logps` are available, the loss is corrected by a + clipped importance-sampling ratio between the current student and the student at rollout time. + """ + if distillation_logits.response_mask.sum() == 0: + mode = "train" if model.training else "eval" + self._log_self_distillation_metric(mode, 0.0) + # Keep the zero loss attached to the student graph so backward produces zero gradients instead of stopping. + return distillation_logits.student_logits.sum() * 0.0 + + if self.args.distillation_mode == "topk_logits": + if self.args.distillation_topk is None: + raise ValueError("`distillation_mode='topk_logits'` requires `distillation_topk` to be set.") + per_token_loss = compute_topk_self_distillation_loss( + distillation_logits.student_logits, + distillation_logits.teacher_logits, + distillation_topk=self.args.distillation_topk, + distillation_alpha=self.args.distillation_alpha, + distillation_add_tail=self.args.distillation_add_tail, + ) + elif self.args.distillation_mode == "full_logits": + per_token_loss = compute_full_logit_self_distillation_loss( + distillation_logits.student_logits, + distillation_logits.teacher_logits, + distillation_alpha=self.args.distillation_alpha, + ) + elif self.args.distillation_mode == "sampled_token": + per_token_loss = compute_sampled_token_self_distillation_loss( + distillation_logits.student_logits, + distillation_logits.teacher_logits, + distillation_logits.completion_ids, + distillation_alpha=self.args.distillation_alpha, + ) + else: + raise ValueError( + "distillation_mode must be one of: 'sampled_token', 'full_logits', 'topk_logits', " + f"got {self.args.distillation_mode!r}" + ) + + old_per_token_logps = inputs.get("old_per_token_logps") + if self.args.distillation_is_clip is not None and old_per_token_logps is not None: + student_per_token_logps = select_token_log_probs( + distillation_logits.student_logits, + distillation_logits.completion_ids, + ) + per_token_loss = apply_importance_sampling_clipping( + per_token_loss, + student_per_token_logps, + old_per_token_logps, + self.args.distillation_is_clip, + ) + + loss = aggregate_loss( + per_token_loss, + distillation_logits.response_mask, + loss_type=self.loss_type, + max_completion_length=self.max_completion_length, + ) + + mode = "train" if model.training else "eval" + mean_distill_loss = ( + per_token_loss * distillation_logits.response_mask + ).sum() / distillation_logits.response_mask.sum().clamp(min=1.0) + self._log_self_distillation_metric( + mode, + self.accelerator.gather(mean_distill_loss).mean().item(), + ) + return loss + + def _compute_teacher_student_logits( + self, + model, + teacher_model, + inputs: TrainingBatch, + ) -> DistillationLogits: + """Run student and teacher forwards on their respective inputs and pack aligned logits into a `DistillationLogits`. + + The teacher forward runs under the teacher context resolved by `_get_teacher_context_for_self_distillation`. + """ + prompt_ids = inputs["prompt_ids"] + prompt_mask = inputs["prompt_mask"] + completion_ids = inputs["completion_ids"] + completion_mask = inputs["completion_mask"] + logits_to_keep = completion_ids.size(1) + + response_mask = self._build_self_distillation_response_mask( + completion_mask, + inputs.get("self_distillation_mask"), + ) + student_logits = self._compute_student_distillation_logits( + model=model, + prompt_ids=prompt_ids, + prompt_mask=prompt_mask, + completion_ids=completion_ids, + completion_mask=completion_mask, + logits_to_keep=logits_to_keep, + ) + + teacher_logits = self._compute_teacher_distillation_logits( + teacher_model=teacher_model, + teacher_input_ids=inputs["teacher_input_ids"], + teacher_attention_mask=inputs["teacher_attention_mask"], + logits_to_keep=logits_to_keep, + ) + + return DistillationLogits( + completion_ids=completion_ids, + completion_mask=completion_mask, + response_mask=response_mask, + student_logits=student_logits, + teacher_logits=teacher_logits, + ) + + @staticmethod + def _build_self_distillation_response_mask( + completion_mask: torch.Tensor, + self_distillation_mask: torch.Tensor | None, + ) -> torch.Tensor: + if self_distillation_mask is None: + return completion_mask + return completion_mask * self_distillation_mask.unsqueeze(1) + + def _compute_student_distillation_logits( + self, + model, + prompt_ids: torch.Tensor, + prompt_mask: torch.Tensor, + completion_ids: torch.Tensor, + completion_mask: torch.Tensor, + logits_to_keep: int, + ) -> torch.Tensor: + student_input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + student_attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + return self._forward_logits( + model=model, + input_ids=student_input_ids, + attention_mask=student_attention_mask, + logits_to_keep=logits_to_keep, + ) + + def _compute_teacher_distillation_logits( + self, + teacher_model, + teacher_input_ids: torch.Tensor, + teacher_attention_mask: torch.Tensor, + logits_to_keep: int, + ) -> torch.Tensor: + with torch.no_grad(), self._get_teacher_context_for_self_distillation(): + return self._forward_logits( + model=teacher_model, + input_ids=teacher_input_ids, + attention_mask=teacher_attention_mask, + logits_to_keep=logits_to_keep, + ) + + def _forward_logits( + self, + model, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + logits_to_keep: int, + ) -> torch.Tensor: + """Forward the model and return temperature-scaled logits aligned to the completion tokens.""" + model_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "use_cache": False, + } + if "logits_to_keep" in self.model_kwarg_keys: + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + logits = model(**model_inputs).logits + logits = logits[:, :-1, :] + logits = logits[:, -logits_to_keep:, :] + return logits / self.temperature + + @abstractmethod + def finalize_batch( + self, + inputs: list[dict[str, Any]], + rollout_batch: RolloutBatch, + ) -> TrainingBatch: + """Build the final training batch from a shared student rollout batch. + + Subclasses must return a `dict` with at least the following keys, all first-dim-aligned to the student + batch size `B`: + + - `prompt_ids`: student prompt ids. + - `prompt_mask`: student prompt mask. + - `completion_ids`: student completion ids. + - `completion_mask`: student completion mask. + - `teacher_input_ids`: teacher input ids. + - `teacher_attention_mask`: teacher attention mask. + """ + + def _get_teacher_context_for_self_distillation(self): + """Return the context manager that routes the teacher forward to the correct weights. + + For non-PEFT models this is a no-op. For PEFT models: + + - `teacher_model_kind == "base"`: disable the student adapter so the teacher forward uses the base weights. + - `teacher_model_kind == "ema"` under pure-LoRA training: switch to the `"teacher"` LoRA adapter. + - otherwise: no-op; the teacher is a separate deepcopy. + """ + teacher_model_kind = self.args.teacher_model_kind + if not is_peft_model(self.model): + return nullcontext() + + target_model = self.accelerator.unwrap_model(self.teacher_model) + + if teacher_model_kind == "base": + return use_adapter(target_model, adapter_name=None) + if teacher_model_kind == "ema" and self._use_peft_ema_teacher_adapter(): + return use_adapter(target_model, adapter_name="teacher") + return nullcontext() + + def _log_self_distillation_metric(self, mode: str, value: float) -> None: + metric_prefix = self._name.lower().replace(" ", "_") + self._metrics[mode]["self_distillation/distillation_loss"].append(value) + self._metrics[mode][f"{metric_prefix}/distillation_loss"].append(value) + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {k: sum(v) / len(v) for k, v in self._metrics[mode].items() if v} + if mode == "eval": + metrics = {f"eval_{k}": v for k, v in metrics.items()} + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + + @abstractmethod + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """Subclasses own algorithm-specific loss composition on the final batch contract.""" diff --git a/trl/experimental/self_distillation/loss_utils.py b/trl/experimental/self_distillation/loss_utils.py new file mode 100644 index 00000000000..ac74a751099 --- /dev/null +++ b/trl/experimental/self_distillation/loss_utils.py @@ -0,0 +1,178 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pure helper functions for self-distillation loss computation.""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F + + +def select_token_log_probs( + logits: torch.Tensor, + token_ids: torch.Tensor, +) -> torch.Tensor: + logsumexp = torch.logsumexp(logits, dim=-1, keepdim=True) + indices = token_ids.unsqueeze(-1) + return (torch.gather(logits, dim=-1, index=indices) - logsumexp).squeeze(-1) + + +def compute_divergence( + student_log_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, + alpha: float, +) -> torch.Tensor: + if alpha == 0.0: + kl = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) + elif alpha == 1.0: + kl = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) + else: + alpha_t = torch.tensor(alpha, dtype=student_log_probs.dtype, device=student_log_probs.device) + mixture = torch.logsumexp( + torch.stack([student_log_probs + torch.log(1 - alpha_t), teacher_log_probs + torch.log(alpha_t)]), + dim=0, + ) + kl_teacher = F.kl_div(mixture, teacher_log_probs, reduction="none", log_target=True) + kl_student = F.kl_div(mixture, student_log_probs, reduction="none", log_target=True) + kl = torch.lerp(kl_student, kl_teacher, alpha) + return kl.sum(-1) + + +def add_tail(log_probs: torch.Tensor) -> torch.Tensor: + log_s = torch.logsumexp(log_probs, dim=-1, keepdim=True) + log_s = torch.clamp(log_s, max=-1e-7) + tail_log = torch.log(-torch.expm1(log_s)) + return torch.cat([log_probs, tail_log], dim=-1) + + +def renorm_topk_log_probs(log_probs: torch.Tensor) -> torch.Tensor: + return log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True) + + +def compute_token_level_distillation_loss( + student_log_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, +) -> torch.Tensor: + log_ratio = student_log_probs - teacher_log_probs + return log_ratio.detach() * student_log_probs + + +def apply_importance_sampling_clipping( + per_token_loss: torch.Tensor, + student_log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + clip_coeff: float, +) -> torch.Tensor: + negative_approx_kl = (student_log_probs - old_log_probs).detach() + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl).clamp(max=clip_coeff) + return per_token_loss * ratio + + +def aggregate_loss( + per_token_loss: torch.Tensor, + response_mask: torch.Tensor, + *, + loss_type: str, + max_completion_length: int, +) -> torch.Tensor: + """Reduce a per-token loss tensor according to the configured reduction. + + Args: + per_token_loss: + Per-token loss values of shape `(batch_size, seq_len)`. + response_mask: + Mask selecting which completion-token positions contribute to the loss. + loss_type: + Reduction mode. Uses the same loss-type conventions as the GRPO-family trainers. + max_completion_length: + Used by the `dr_grpo` reduction, which normalizes by a fixed completion budget. + """ + if loss_type == "grpo": + loss = (per_token_loss * response_mask).sum(-1) / response_mask.sum(-1).clamp(min=1.0) + return loss.mean() + if loss_type == "bnpo": + return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + if loss_type == "dr_grpo": + return (per_token_loss * response_mask).sum() / (per_token_loss.size(0) * max_completion_length) + if loss_type in ["dapo", "luspo", "cispo", "sapo"]: + return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + raise ValueError(f"Unsupported loss_type: {loss_type}") + + +def compute_topk_self_distillation_loss( + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + *, + distillation_topk: int, + distillation_alpha: float, + distillation_add_tail: bool, +) -> torch.Tensor: + """Compute distillation loss on the student's top-k token support. + + The student's top-k logits define the support. The teacher distribution is projected onto the same token indices. + The selected support is then either renormalized or augmented with a tail bucket before the divergence is computed. + """ + student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) + topk_student_logits, topk_indices = torch.topk(student_logits, k=distillation_topk, dim=-1) + topk_student_log_probs = topk_student_logits - student_logsumexp + + teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True) + topk_teacher_logits = torch.gather(teacher_logits, dim=-1, index=topk_indices) + topk_teacher_log_probs = topk_teacher_logits - teacher_logsumexp + + if distillation_add_tail: + topk_student_log_probs = add_tail(topk_student_log_probs) + topk_teacher_log_probs = add_tail(topk_teacher_log_probs) + else: + topk_student_log_probs = renorm_topk_log_probs(topk_student_log_probs) + topk_teacher_log_probs = renorm_topk_log_probs(topk_teacher_log_probs) + + return compute_divergence(topk_student_log_probs, topk_teacher_log_probs, distillation_alpha) + + +def compute_full_logit_self_distillation_loss( + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + *, + distillation_alpha: float, +) -> torch.Tensor: + """Compute full-vocabulary self-distillation loss between student and teacher logits.""" + student_log_probs = torch.log_softmax(student_logits, dim=-1) + teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1) + return compute_divergence(student_log_probs, teacher_log_probs, distillation_alpha) + + +def compute_sampled_token_self_distillation_loss( + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + completion_ids: torch.Tensor, + *, + distillation_alpha: float, +) -> torch.Tensor: + """Compute token-level self-distillation loss only on the sampled completion tokens. + + This path compares student and teacher log-probabilities on the realized completion tokens rather than over a + larger token support. + """ + if distillation_alpha != 1.0: + raise ValueError( + "Only reverse KL (alpha=1.0) is supported for token-level distillation when " + f"`distillation_mode='sampled_token'`, got alpha={distillation_alpha}" + ) + + student_per_token_logps = select_token_log_probs(student_logits, completion_ids) + teacher_per_token_logps = select_token_log_probs(teacher_logits, completion_ids) + return compute_token_level_distillation_loss(student_per_token_logps, teacher_per_token_logps) diff --git a/trl/experimental/self_distillation/online_rollout_mixin.py b/trl/experimental/self_distillation/online_rollout_mixin.py deleted file mode 100644 index 490724582dc..00000000000 --- a/trl/experimental/self_distillation/online_rollout_mixin.py +++ /dev/null @@ -1,385 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Online rollout helpers for experimental self-distillation trainers. - -This mixin owns generation, reward scoring, grouped reward normalization, and online policy-loss plumbing. It is paired -with `BaseSelfDistillationTrainer` for SDPO-style methods and intentionally kept separate from the generic distillation -loss logic in `self_distillation_mixin.py`. -""" - -from __future__ import annotations - -import torch -from torch import nn -from transformers.utils import logging - -from ...data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template -from ...models import unwrap_model_for_generation -from ...trainer.base_trainer import _BaseTrainer -from ...trainer.utils import pad - - -logger = logging.get_logger(__name__) - - -class OnlineRolloutMixin: - """Online rollout, reward, and policy-loss utilities shared by SDPO-like trainers.""" - - def _apply_prompt_template(self, prompts): - return [ - maybe_apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"] - for prompt in prompts - ] - - def _build_buffered_batch(self, generation_batch): - return self._generate_and_score_completions(generation_batch) - - def _generate(self, prompts): - if self.use_vllm: - return self._generate_vllm(prompts) - return self._generate_transformers(prompts) - - def _generate_vllm(self, prompts): - # Sync weights if training step changed - if self.state.global_step != self._last_loaded_step: - self.vllm_generation.sync_weights() - self._last_loaded_step = self.state.global_step - - # Tokenize prompts to token IDs - prompts_text = self._apply_prompt_template(prompts) - tokenized = self.processing_class( - text=prompts_text, - return_tensors=None, - padding=False, - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - prompt_ids = tokenized["input_ids"] # list of list[int] - - # Generate via vLLM — it deduplicates repeated prompts from RepeatSampler internally - mode = "train" if self.model.training else "eval" - num_generations = self.num_generations if mode == "train" else self.num_generations_eval - prompt_ids_out, completion_ids_list, _, _ = self.vllm_generation.generate( - prompts=prompt_ids, - images=None, - num_generations=num_generations, - ) - return prompt_ids_out, completion_ids_list - - def _generate_transformers(self, prompts): - # Keep the generation path aligned with the reference trainers: generate from left-padded prompts, - # then recover completion token spans by trimming prompt tokens and stopping at the first EOS. - prompts_text = self._apply_prompt_template(prompts) - generate_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - # This path already receives tokenized model inputs. Bypass the buffered trainer hook and use the plain - # tensor/device preparation from `_BaseTrainer`. - generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) - with ( - unwrap_model_for_generation( - self.model_wrapped, - self.accelerator, - gather_deepspeed3_params=self.args.ds3_gather_for_generation, - ) as unwrapped_model, - torch.no_grad(), - ): - prompt_completion_ids = unwrapped_model.generate( - **generate_inputs, generation_config=self.generation_config - ) - prompt_ids = generate_inputs["input_ids"] - prompt_mask = generate_inputs["attention_mask"] - prompt_length = prompt_ids.size(1) - completion_ids = prompt_completion_ids[:, prompt_length:] - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) - completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).int() - prompt_ids_list = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=False)] - completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=False)] - return prompt_ids_list, completion_ids_list - - def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): - device = self.accelerator.device - if len(self.reward_funcs) == 0: - return torch.zeros((len(prompts), 0), device=device) - - rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) - keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] - reward_kwargs = {key: [example[key] for example in inputs] for key in keys} - reward_kwargs["trainer_state"] = self.state - - for i, (reward_func, reward_processing_class) in enumerate( - zip(self.reward_funcs, self.reward_processing_classes, strict=True) - ): - if isinstance(reward_func, nn.Module): - if is_conversational(inputs[0]): - messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)] - texts = [ - apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"] - for x in messages - ] - else: - texts = [p + c for p, c in zip(prompts, completions, strict=True)] - reward_inputs = reward_processing_class( - text=texts, - return_tensors="pt", - padding=True, - padding_side="right", - add_special_tokens=False, - ) - # Reward functions operate on tokenized tensors too, so they need the base Trainer input preparation - # rather than the outer buffered generation hook. - reward_inputs = _BaseTrainer._prepare_inputs(self, reward_inputs) - with torch.inference_mode(): - rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] - else: - output_reward_func = reward_func( - prompts=prompts, - completions=completions, - completion_ids=completion_ids_list, - **reward_kwargs, - ) - output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] - rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) - - return self.accelerator.gather(rewards_per_func) - - def _generate_and_score_completions(self, inputs): - device = self.accelerator.device - mode = "train" if self.model.training else "eval" - prompts = [x["prompt"] for x in inputs] - prompt_ids_list, completion_ids_list = self._generate(prompts) - - prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] - prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left").to(device=device) - prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left").to(device=device) - completion_ids = [torch.tensor(ids) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right").to(device=device) - completion_mask = pad(completion_mask, padding_value=0, padding_side="right").to(device=device) - - if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) - completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() - - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) - - with torch.no_grad(): - generate_every = self.args.steps_per_generation * self.num_iterations - if self.args.gradient_accumulation_steps % generate_every != 0: - old_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ) - else: - old_per_token_logps = None - - if is_conversational({"prompt": prompts[0]}): - completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - completions = [[{"role": "assistant", "content": content}] for content in completions_text] - else: - completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - - rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) - if rewards_per_func.numel() == 0: - rewards = torch.zeros(self.accelerator.num_processes * len(prompts), device=device) - else: - rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) - num_generations = self.num_generations if mode == "train" else self.num_generations_eval - mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0) - if self.scale_rewards == "batch": - std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) - group_std_rewards = rewards.view(-1, num_generations).std(dim=1) - elif self.scale_rewards == "none": - std_rewards = torch.ones_like(rewards) - group_std_rewards = torch.ones(rewards.numel() // num_generations, device=device, dtype=rewards.dtype) - else: - group_std_rewards = rewards.view(-1, num_generations).std(dim=1) - std_rewards = group_std_rewards.repeat_interleave(num_generations, dim=0) - advantages = (rewards - mean_grouped_rewards) / (std_rewards + 1e-4) - self._record_reward_diagnostics(mode, rewards, rewards_per_func, group_std_rewards) - - local_batch_size = completion_ids.size(0) - process_start = self.accelerator.process_index * local_batch_size - process_slice = slice(process_start, process_start + local_batch_size) - rewards = rewards[process_slice] - advantages = advantages[process_slice] - - agg_completion_lengths = self.accelerator.gather( - torch.tensor([len(ids) for ids in completion_ids_list], device=device) - ) - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) - agg_is_truncated = self.accelerator.gather(is_truncated) - self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) - term_completion_lengths = agg_completion_lengths[~agg_is_truncated] - if len(term_completion_lengths) == 0: - term_completion_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - - output = { - "prompt_ids": prompt_ids, - "prompt_mask": prompt_mask, - "completion_ids": completion_ids, - "completion_mask": completion_mask, - "rewards": rewards, - "advantages": advantages, - "num_items_in_batch": completion_mask.sum().detach(), - } - if old_per_token_logps is not None: - output["old_per_token_logps"] = old_per_token_logps - - self._dispatch_self_distillation_callback( - "on_self_distillation_batch_prepared", - old_per_token_logps=old_per_token_logps, - prompt_ids=prompt_ids, - completion_ids=completion_ids, - ) - return output - - def _record_reward_diagnostics( - self, - mode: str, - rewards: torch.Tensor, - rewards_per_func: torch.Tensor, - group_std_rewards: torch.Tensor, - ) -> None: - tolerance = self.args.diagnostics_flat_tolerance - - reward_mean = rewards.mean() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) - reward_std = rewards.std() if rewards.numel() > 1 else torch.tensor(0.0, device=self.accelerator.device) - reward_min = rewards.min() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) - reward_max = rewards.max() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) - flat_group_fraction = ( - (group_std_rewards <= tolerance).float().mean() - if group_std_rewards.numel() > 0 - else torch.tensor(1.0, device=self.accelerator.device) - ) - - self._metrics[mode]["self_distillation/reward_mean"].append(self.accelerator.gather(reward_mean).mean().item()) - self._metrics[mode]["self_distillation/reward_std"].append(self.accelerator.gather(reward_std).mean().item()) - self._metrics[mode]["self_distillation/reward_min"].append(self.accelerator.gather(reward_min).min().item()) - self._metrics[mode]["self_distillation/reward_max"].append(self.accelerator.gather(reward_max).max().item()) - self._metrics[mode]["self_distillation/group_reward_std_mean"].append( - self.accelerator.gather(group_std_rewards.mean() if group_std_rewards.numel() > 0 else reward_std) - .mean() - .item() - ) - self._metrics[mode]["self_distillation/flat_group_fraction"].append( - self.accelerator.gather(flat_group_fraction).mean().item() - ) - - if rewards_per_func.numel() > 0: - reward_func_means = rewards_per_func.nanmean(dim=0) - gathered_means = self.accelerator.gather(reward_func_means).view(-1, reward_func_means.numel()).mean(dim=0) - for reward_name, reward_func_mean in zip(self.reward_func_names, gathered_means.tolist(), strict=True): - self._metrics[mode][f"self_distillation/rewards/{reward_name}"].append(reward_func_mean) - - reward_is_flat = reward_std.item() <= tolerance - grouped_rewards_are_flat = flat_group_fraction.item() >= 1.0 - tolerance - if reward_is_flat and grouped_rewards_are_flat: - self._warn_on_degenerate_diagnostics( - mode=mode, - counter_key="flat_rewards", - message=( - "Observed flat SDPO rewards across all sampled generations. " - "Policy advantages will collapse to zero, and SDPO will not learn. " - "Check reward density, reward shaping, or `success_reward_threshold`." - ), - ) - else: - self._diagnostic_counters[mode]["flat_rewards"] = 0 - - def _warn_on_degenerate_diagnostics(self, mode: str, counter_key: str, message: str) -> None: - interval = self.args.diagnostics_warning_interval - if interval == 0: - return - - self._diagnostic_counters[mode][counter_key] += 1 - count = self._diagnostic_counters[mode][counter_key] - if count == 1 or count % interval == 0: - logger.warning("%s Consecutive degenerate steps: %s.", message, count) - - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - if return_outputs: - raise ValueError(f"The {self.__class__.__name__} does not support returning outputs") - return self._compute_loss(model, inputs) - - def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): - if not isinstance(inputs, dict): - inputs = self._prepare_inputs(inputs) - with torch.no_grad(): - with self.compute_loss_context_manager(): - loss = self.compute_loss(model, inputs) - return loss.detach(), None, None - - def _compute_loss(self, model, inputs): - prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] - input_ids = torch.cat([prompt_ids, completion_ids], dim=1) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) - per_token_logps, _ = self._get_per_token_logps_and_entropies( - model, - input_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ) - old_per_token_logps = inputs.get("old_per_token_logps") - old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps - advantages = inputs["advantages"] - if advantages.dim() == 1: - advantages = advantages.unsqueeze(1) - log_ratio = per_token_logps - old_per_token_logps - if self.importance_sampling_level == "sequence": - log_ratio = (log_ratio * completion_mask).sum(-1, keepdim=True) / completion_mask.sum( - -1, keepdim=True - ).clamp(min=1.0) - coef_1 = torch.exp(log_ratio) - coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) - per_token_loss = -torch.min(coef_1 * advantages, coef_2 * advantages) - - loss = self._aggregate_self_distillation_loss(per_token_loss, completion_mask) - - mode = "train" if self.model.training else "eval" - self._metrics[mode]["self_distillation/policy_loss"].append( - self.accelerator.gather(loss.detach()).mean().item() - ) - - accumulation_scale = self.current_gradient_accumulation_steps if mode == "train" else 1.0 - return loss / accumulation_scale diff --git a/trl/experimental/self_distillation/prompt_utils.py b/trl/experimental/self_distillation/prompt_utils.py new file mode 100644 index 00000000000..f5decce6d7a --- /dev/null +++ b/trl/experimental/self_distillation/prompt_utils.py @@ -0,0 +1,32 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + + +def extract_last_user_text(messages: list[dict[str, Any]]) -> str: + """Extract the text content from the last user message in a conversational prompt.""" + last_message = messages[-1] + if last_message.get("role") != "user": + raise ValueError( + f"Self-distillation teacher prompt construction expects the conversation to end with a user turn, " + f"but the last message has role '{last_message.get('role')}'. " + f"Prompts ending with assistant prefills or tool turns are not supported." + ) + content = last_message.get("content", "") + if isinstance(content, list): + return " ".join(part.get("text", "") for part in content if part.get("type") == "text") + return content diff --git a/trl/experimental/self_distillation/self_distillation_config.py b/trl/experimental/self_distillation/self_distillation_config.py index c28c0f6be0b..4fe20c00b95 100644 --- a/trl/experimental/self_distillation/self_distillation_config.py +++ b/trl/experimental/self_distillation/self_distillation_config.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Any +from typing import Any, Literal from transformers import TrainingArguments @@ -30,39 +30,120 @@ class SelfDistillationConfig(_BaseConfig): [`trl.trainer.base_config._BaseConfig`]. Parameters: - > Parameters that control generation and rollout reuse + > Parameters that control the model model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments used when the `model` argument is passed as a string. + Keyword arguments for model initialization when the `model` argument is passed as a string. + disable_dropout (`bool`, *optional*, defaults to `False`): + Whether to disable dropout in the student model. + remove_unused_columns (`bool`, *optional*, defaults to `False`): + Whether to drop dataset columns unused by the trainer. + + > Parameters that control generation and rollout reuse + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): Maximum prompt length. Longer prompts are truncated from the left. num_generations (`int`, *optional*, defaults to `8`): Number of sampled generations per prompt. + num_generations_eval (`int` or `None`, *optional*): + Number of sampled generations per prompt during evaluation. + max_completion_length (`int` or `None`, *optional*, defaults to `256`): + Maximum generated completion length. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + Whether to gather ZeRO-3 weights for generation. + shuffle_dataset (`bool`, *optional*, defaults to `True`): + Whether to shuffle the training dataset. generation_batch_size (`int` or `None`, *optional*): Global batch size used for generation. Mutually exclusive with `steps_per_generation`. steps_per_generation (`int` or `None`, *optional*): Number of optimizer steps that reuse one generated batch. Mutually exclusive with `generation_batch_size`. + > Parameters that control sampling + + temperature (`float`, *optional*, defaults to `1.0`): + Sampling temperature. + top_p (`float`, *optional*, defaults to `1.0`): + Top-p sampling parameter. + top_k (`int`, *optional*, defaults to `0`): + Top-k sampling parameter. `0` disables top-k filtering. + min_p (`float` or `None`, *optional*): + Minimum token probability for sampling. + generation_kwargs (`dict[str, Any]` or `None`, *optional*): + Extra generation kwargs passed to `GenerationConfig`. + chat_template_kwargs (`dict[str, Any]` or `None`, *optional*): + Extra kwargs forwarded to chat template application. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Repetition penalty used during generation. + use_transformers_paged (`bool`, *optional*, defaults to `False`): + Reserved for paged generation support. + cache_implementation (`str` or `None`, *optional*): + Cache implementation used by transformers generation. + + > Parameters that control vLLM generation + + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generation. + vllm_mode (`str`, *optional*, defaults to `"colocate"`): + vLLM mode: `"colocate"` (shared GPU) or `"server"` (separate vLLM server). + vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): + Model implementation for vLLM: `"vllm"` or `"transformers"`. + vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): + Whether to enable sleep mode for colocated vLLM engine. + vllm_server_base_url (`str` or `None`, *optional*): + Base URL for the vLLM server. If provided, `vllm_server_host` and `vllm_server_port` are ignored. + vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host of the vLLM server (server mode only). + vllm_server_port (`int`, *optional*, defaults to `8000`): + Port of the vLLM server (server mode only). + vllm_group_port (`int`, *optional*, defaults to `51216`): + Port for the weight update group (server mode only). + vllm_server_timeout (`float`, *optional*, defaults to `240.0`): + Timeout in seconds to wait for the vLLM server. + vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): + Tensor parallel size for colocated vLLM. + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`): + GPU memory utilization ratio for colocated vLLM. + vllm_max_model_length (`int` or `None`, *optional*): + Model context length for vLLM. Inferred from model config if not set. + > Parameters that control the online policy objective - beta (`float`, *optional*, defaults to `0.0`): - Reference-model KL coefficient for online policy optimization. + num_iterations (`int`, *optional*, defaults to `1`): + Number of optimization iterations per generated batch. loss_type (`str`, *optional*, defaults to `"dapo"`): Policy-loss aggregation mode. Supported: `grpo`, `bnpo`, `dr_grpo`, `dapo`. - scale_rewards (`str` or `bool`, *optional*, defaults to `"group"`): - Reward normalization mode. Supported: `group`, `batch`, `none`. + mask_truncated_completions (`bool`, *optional*, defaults to `False`): + Whether to exclude truncated completions from the loss. + top_entropy_quantile (`float`, *optional*, defaults to `1.0`): + Reserved for entropy-based token filtering. + + > Parameters that control teacher construction + + teacher_model_kind (`str`, *optional*, defaults to `"live"`): + Semantic teacher choice. `live` uses the current student, `base` uses the initial student, and `ema` uses + an exponentially averaged teacher. + teacher_update_rate (`float`, *optional*, defaults to `0.6`): + EMA update rate used when `teacher_model_kind="ema"`. A value of `1.0` reduces the update to a hard + overwrite, periodically resyncing the teacher to the current student weights. + teacher_sync_steps (`int`, *optional*, defaults to `512`): + Number of optimizer steps between teacher updates. > Parameters that control self-distillation distillation_alpha (`float`, *optional*, defaults to `0.5`): - Divergence interpolation coefficient using the official SDPO/SDFT convention: `0.0=forward KL`, `0.5=JSD`, + KL divergence direction: `0.0=forward KL`, `0.5=JSD`, `1.0=reverse KL`. - distillation_topk (`int` or `None`, *optional*, defaults to `100`): - Number of top tokens to keep for top-k distillation. If `None`, all logits are used. - full_logit_distillation (`bool`, *optional*, defaults to `False`): - Whether to use full-logit distillation instead of token-level distillation. + distillation_mode (`Literal["sampled_token", "full_logits", "topk_logits"]`, *optional*, defaults to `"sampled_token"`): + Distillation objective mode. `"sampled_token"` uses token-level distillation on the sampled completion + tokens, `"full_logits"` uses full-vocabulary divergence, and `"topk_logits"` uses a top-k approximation + over the student support. + distillation_topk (`int` or `None`, *optional*): + Number of top tokens for `"topk_logits"` distillation. Must be set when + `distillation_mode="topk_logits"` and left unset otherwise. distillation_is_clip (`float` or `None`, *optional*, defaults to `2.0`): - Importance-sampling clip used by the official SDPO-style correction. `None` disables clipping. + Clipping coefficient for importance sampling in self-distillation. `None` disables clipping. + distillation_add_tail (`bool`, *optional*, defaults to `False`): + Whether to add a tail bucket for non-top-k probability mass. distillation_weight (`float`, *optional*, defaults to `1.0`): Weight applied to the self-distillation loss term. @@ -118,7 +199,9 @@ class SelfDistillationConfig(_BaseConfig): ) steps_per_generation: int | None = field( default=None, - metadata={"help": "Number of optimizer steps that reuse one generated batch."}, + metadata={ + "help": "Number of optimizer steps that reuse one generated batch. Mutually exclusive with `generation_batch_size`." + }, ) temperature: float = field( default=1.0, @@ -202,34 +285,10 @@ class SelfDistillationConfig(_BaseConfig): default=None, metadata={"help": "Model context length for vLLM. Inferred from model config if not set."}, ) - beta: float = field( - default=0.0, - metadata={"help": "Reference-model KL coefficient for online policy optimization."}, - ) num_iterations: int = field( default=1, metadata={"help": "Number of optimization iterations per generated batch."}, ) - epsilon: float = field( - default=0.2, - metadata={"help": "Lower clipping coefficient for GRPO-style policy loss."}, - ) - epsilon_high: float | None = field( - default=None, - metadata={"help": "Upper clipping coefficient. Defaults to `epsilon` when unset."}, - ) - importance_sampling_level: str = field( - default="token", - metadata={"help": "Importance-sampling granularity. Supported: `token`, `sequence`."}, - ) - reward_weights: list[float] | None = field( - default=None, - metadata={"help": "Optional weights for multiple reward functions."}, - ) - scale_rewards: str | bool = field( - default="group", - metadata={"help": "Reward normalization mode. Supported: `group`, `batch`, `none`."}, - ) loss_type: str = field( default="dapo", metadata={"help": "Policy loss aggregation. Supported: `grpo`, `bnpo`, `dr_grpo`, `dapo`."}, @@ -238,17 +297,23 @@ class SelfDistillationConfig(_BaseConfig): default=False, metadata={"help": "Whether to exclude truncated completions from the loss."}, ) - sync_ref_model: bool = field( - default=False, - metadata={"help": "Whether to synchronize the reference model with the student model."}, + teacher_model_kind: str = field( + default="live", + metadata={ + "help": "Semantic teacher choice. `live` uses the current student, `base` uses the initial student, " + "and `ema` uses an exponentially averaged teacher." + }, ) - ref_model_mixup_alpha: float = field( + teacher_update_rate: float = field( default=0.6, - metadata={"help": "EMA mix coefficient used when syncing the reference model."}, + metadata={ + "help": 'EMA update rate used when `teacher_model_kind="ema"`. A value of `1.0` reduces the update ' + "to a hard overwrite, periodically resyncing the teacher to the current student weights." + }, ) - ref_model_sync_steps: int = field( + teacher_sync_steps: int = field( default=512, - metadata={"help": "How often to synchronize the reference model."}, + metadata={"help": "Number of optimizer steps between teacher updates."}, ) top_entropy_quantile: float = field( default=1.0, @@ -256,19 +321,28 @@ class SelfDistillationConfig(_BaseConfig): ) distillation_alpha: float = field( default=0.5, - metadata={"help": "KL divergence direction: 0.0=forward KL, 0.5=JSD, 1.0=reverse KL."}, + metadata={"help": "KL divergence direction: `0.0=forward KL`, `0.5=JSD`, `1.0=reverse KL`."}, ) - distillation_topk: int | None = field( - default=100, - metadata={"help": "Number of top tokens for top-k distillation. If None, uses all tokens."}, + distillation_mode: Literal["sampled_token", "full_logits", "topk_logits"] = field( + default="sampled_token", + metadata={ + "help": "Distillation objective mode. `sampled_token` uses token-level distillation on the sampled " + "completion tokens, `full_logits` uses full-vocabulary divergence, and `topk_logits` uses a top-k " + "approximation over the student support." + }, ) - full_logit_distillation: bool = field( - default=False, - metadata={"help": "Whether to use full-logit distillation instead of token-level distillation."}, + distillation_topk: int | None = field( + default=None, + metadata={ + "help": "Number of top tokens for `topk_logits` distillation. Must be set when " + "`distillation_mode='topk_logits'` and left unset otherwise." + }, ) distillation_is_clip: float | None = field( default=2.0, - metadata={"help": "Clipping coefficient for importance sampling in self-distillation."}, + metadata={ + "help": "Clipping coefficient for importance sampling in self-distillation. `None` disables clipping." + }, ) distillation_add_tail: bool = field( default=False, @@ -294,20 +368,27 @@ class SelfDistillationConfig(_BaseConfig): def __post_init__(self): super().__post_init__() - self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards) - if self.scale_rewards not in ["group", "batch", "none"]: - raise ValueError("scale_rewards must be one of: 'group', 'batch', 'none'") - - if self.importance_sampling_level not in ["token", "sequence"]: - raise ValueError("importance_sampling_level must be either 'token' or 'sequence'") if self.loss_type not in ["grpo", "bnpo", "dr_grpo", "dapo"]: raise ValueError("loss_type must be one of: 'grpo', 'bnpo', 'dr_grpo', 'dapo'") + if self.teacher_model_kind not in {"base", "live", "ema"}: + raise ValueError("teacher_model_kind must be one of: 'base', 'live', 'ema'") + if not 0.0 <= self.teacher_update_rate <= 1.0: + raise ValueError("teacher_update_rate must be in [0, 1]") + if self.teacher_sync_steps <= 0: + raise ValueError("teacher_sync_steps must be positive") if self.num_generations < 1: raise ValueError("num_generations must be at least 1") if not 0.0 <= self.distillation_alpha <= 1.0: raise ValueError("distillation_alpha must be in [0, 1]") + if self.distillation_mode not in {"sampled_token", "full_logits", "topk_logits"}: + raise ValueError("distillation_mode must be one of: 'sampled_token', 'full_logits', 'topk_logits'") if self.distillation_topk is not None and self.distillation_topk <= 0: raise ValueError("distillation_topk must be positive when provided") + if self.distillation_mode == "topk_logits": + if self.distillation_topk is None: + raise ValueError("`distillation_mode='topk_logits'` requires `distillation_topk` to be set.") + elif self.distillation_topk is not None: + raise ValueError("`distillation_topk` is only valid when `distillation_mode='topk_logits'`.") if self.distillation_is_clip is not None and self.distillation_is_clip <= 0: raise ValueError("distillation_is_clip must be positive when provided") if self.distillation_weight < 0: @@ -345,6 +426,3 @@ def __post_init__(self): f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be " f"divisible by the number of generations used for evaluation ({num_generations_eval})." ) - - if self.epsilon_high is None: - self.epsilon_high = self.epsilon diff --git a/trl/experimental/self_distillation/self_distillation_mixin.py b/trl/experimental/self_distillation/self_distillation_mixin.py deleted file mode 100644 index fb2a8808de1..00000000000 --- a/trl/experimental/self_distillation/self_distillation_mixin.py +++ /dev/null @@ -1,295 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Shared self-distillation loss utilities used by experimental trainers. - -This module intentionally holds only the reusable distillation mechanics: callback dispatch, common prompt/context -helpers, and the student-vs-teacher loss computation. Trainer lifecycle and online rollout concerns live in the trainer -classes or their online-specific base. -""" - -from __future__ import annotations - -from contextlib import nullcontext -from typing import Any - -import torch -import torch.nn.functional as F - -from ...trainer.utils import entropy_from_logits, selective_log_softmax -from .self_distillation_config import SelfDistillationConfig - - -class SelfDistillationMixin: - """Reusable self-distillation helpers shared across experimental trainers.""" - - config_cls = SelfDistillationConfig - - def _set_signature_columns_if_needed(self): - if self._signature_columns is None: - self._signature_columns = ["prompt", "privileged_context"] - - def _dispatch_self_distillation_callback(self, event_name: str, **payload) -> None: - for callback in self.callback_handler.callbacks: - callback_fn = getattr(callback, event_name, None) - if callback_fn is not None: - callback_fn( - args=self.args, - state=self.state, - control=self.control, - model=self.model, - processing_class=self.processing_class, - **payload, - ) - - @staticmethod - def _split_prompt_and_privileged_context(inputs: list[dict[str, Any]]) -> tuple[list[Any], list[Any]]: - prompts = [example["prompt"] for example in inputs] - privileged_contexts = [example.get("privileged_context") for example in inputs] - return prompts, privileged_contexts - - def _allow_topk_without_full_logit_distillation(self) -> bool: - return True - - def _get_per_token_logps_and_entropies( - self, - model, - input_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ): - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} - if "logits_to_keep" in self.model_kwarg_keys: - model_inputs["logits_to_keep"] = logits_to_keep + 1 - logits = model(**model_inputs).logits - logits = logits[:, :-1, :] - logits = logits[:, -logits_to_keep:, :] - logits = logits / self.temperature - completion_ids = input_ids[:, -logits_to_keep:] - selected_logps = selective_log_softmax(logits, completion_ids) - entropies = entropy_from_logits(logits) if compute_entropy else None - return selected_logps, entropies - - def _compute_self_distillation_loss( - self, - model, - inputs: dict[str, Any], - ) -> torch.Tensor: - # Expected batch contract: - # - required: `prompt_ids`, `prompt_mask`, `completion_ids`, `completion_mask`, - # `teacher_input_ids`, `teacher_attention_mask` - # - optional: `self_distillation_mask` to zero-out samples without teacher supervision, - # `old_per_token_logps` to enable IS clipping when generation and optimization are misaligned - prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] - logits_to_keep = completion_ids.size(1) - - self_distillation_mask = inputs.get("self_distillation_mask") - if self_distillation_mask is not None: - response_mask = completion_mask * self_distillation_mask.unsqueeze(1) - else: - response_mask = completion_mask - - if response_mask.sum() == 0: - mode = "train" if model.training else "eval" - self._log_self_distillation_metric(mode, "distillation_loss", 0.0) - return torch.tensor(0.0, device=completion_ids.device, requires_grad=True) - - student_input_ids = torch.cat([prompt_ids, completion_ids], dim=1) - student_attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - student_model_inputs = { - "input_ids": student_input_ids, - "attention_mask": student_attention_mask, - "use_cache": False, - } - if "logits_to_keep" in self.model_kwarg_keys: - student_model_inputs["logits_to_keep"] = logits_to_keep + 1 - - student_logits = model(**student_model_inputs).logits - student_logits = student_logits[:, :-1, :] - student_logits = student_logits[:, -logits_to_keep:, :] - student_logits = student_logits / self.temperature - - teacher_input_ids = inputs["teacher_input_ids"] - teacher_attention_mask = inputs["teacher_attention_mask"] - teacher_model_inputs = { - "input_ids": teacher_input_ids, - "attention_mask": teacher_attention_mask, - "use_cache": False, - } - if "logits_to_keep" in self.model_kwarg_keys: - teacher_model_inputs["logits_to_keep"] = logits_to_keep + 1 - - teacher_model = self._get_teacher_model_for_self_distillation(model) - with torch.no_grad(), self._get_teacher_context_for_self_distillation(model): - teacher_logits = teacher_model(**teacher_model_inputs).logits - teacher_logits = teacher_logits[:, :-1, :] - teacher_logits = teacher_logits[:, -logits_to_keep:, :] - teacher_logits = teacher_logits / self.temperature - - use_topk_distillation = self.args.distillation_topk is not None and ( - self.args.full_logit_distillation or self._allow_topk_without_full_logit_distillation() - ) - if use_topk_distillation: - student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) - topk_student_logits, topk_indices = torch.topk(student_logits, k=self.args.distillation_topk, dim=-1) - topk_student_log_probs = topk_student_logits - student_logsumexp - - teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True) - topk_teacher_logits = torch.gather(teacher_logits, dim=-1, index=topk_indices) - topk_teacher_log_probs = topk_teacher_logits - teacher_logsumexp - - if self.args.distillation_add_tail: - topk_student_log_probs = self._add_tail(topk_student_log_probs) - topk_teacher_log_probs = self._add_tail(topk_teacher_log_probs) - else: - topk_student_log_probs = self._renorm_topk_log_probs(topk_student_log_probs) - topk_teacher_log_probs = self._renorm_topk_log_probs(topk_teacher_log_probs) - - per_token_loss = self._compute_divergence( - topk_student_log_probs, topk_teacher_log_probs, self.args.distillation_alpha - ) - elif self.args.full_logit_distillation: - student_log_probs = F.log_softmax(student_logits, dim=-1) - teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) - per_token_loss = self._compute_divergence( - student_log_probs, teacher_log_probs, self.args.distillation_alpha - ) - else: - if self.args.distillation_alpha != 1.0: - raise ValueError( - "Only reverse KL (alpha=1.0) is supported for token-level distillation when " - "`full_logit_distillation=False`, " - f"got alpha={self.args.distillation_alpha}" - ) - student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) - teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True) - idx = completion_ids.unsqueeze(-1) - student_per_token_logps = (torch.gather(student_logits, dim=-1, index=idx) - student_logsumexp).squeeze(-1) - teacher_per_token_logps = (torch.gather(teacher_logits, dim=-1, index=idx) - teacher_logsumexp).squeeze(-1) - per_token_loss = self._compute_token_level_distillation_loss( - student_per_token_logps, teacher_per_token_logps - ) - - if self.args.distillation_is_clip is not None: - old_log_probs = inputs.get("old_per_token_logps") - if old_log_probs is not None: - with torch.no_grad(): - student_lse = torch.logsumexp(student_logits, dim=-1, keepdim=True) - idx = completion_ids.unsqueeze(-1) - student_per_token_logps = (torch.gather(student_logits, dim=-1, index=idx) - student_lse).squeeze( - -1 - ) - per_token_loss = self._apply_importance_sampling_clipping( - per_token_loss, student_per_token_logps, old_log_probs, self.args.distillation_is_clip - ) - - loss = self._aggregate_self_distillation_loss(per_token_loss, response_mask) - - mode = "train" if model.training else "eval" - mean_distill_loss = (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - self._log_self_distillation_metric( - mode, - "distillation_loss", - self.accelerator.gather(mean_distill_loss).mean().item(), - ) - - return loss - - def _get_teacher_model_for_self_distillation(self, model): - teacher_model = getattr(self, "teacher_model", None) - if teacher_model is None: - return model - return teacher_model - - def _get_teacher_context_for_self_distillation(self, model): - return nullcontext() - - def _log_self_distillation_metric(self, mode: str, metric_name: str, value: float) -> None: - metric_prefix = getattr(self, "_name", "self_distillation").lower().replace(" ", "_") - self._metrics[mode][f"self_distillation/{metric_name}"].append(value) - self._metrics[mode][f"{metric_prefix}/{metric_name}"].append(value) - - @staticmethod - def _compute_divergence( - student_log_probs: torch.Tensor, - teacher_log_probs: torch.Tensor, - alpha: float, - ) -> torch.Tensor: - if alpha == 0.0: - kl = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) - elif alpha == 1.0: - kl = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) - else: - alpha_t = torch.tensor(alpha, dtype=student_log_probs.dtype, device=student_log_probs.device) - mixture = torch.logsumexp( - torch.stack([student_log_probs + torch.log(1 - alpha_t), teacher_log_probs + torch.log(alpha_t)]), - dim=0, - ) - kl_teacher = F.kl_div(mixture, teacher_log_probs, reduction="none", log_target=True) - kl_student = F.kl_div(mixture, student_log_probs, reduction="none", log_target=True) - kl = torch.lerp(kl_student, kl_teacher, alpha) - return kl.sum(-1) - - @staticmethod - def _add_tail(log_probs: torch.Tensor) -> torch.Tensor: - log_s = torch.logsumexp(log_probs, dim=-1, keepdim=True) - log_s = torch.clamp(log_s, max=-1e-7) - tail_log = torch.log(-torch.expm1(log_s)) - return torch.cat([log_probs, tail_log], dim=-1) - - @staticmethod - def _renorm_topk_log_probs(log_probs: torch.Tensor) -> torch.Tensor: - return log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True) - - @staticmethod - def _compute_token_level_distillation_loss( - student_log_probs: torch.Tensor, - teacher_log_probs: torch.Tensor, - ) -> torch.Tensor: - # This is the token-level reverse-KL surrogate used by the official SDPO implementation for - # `full_logit_distillation=False`. It intentionally treats the teacher log-probs as fixed targets - # and keeps only the score-function term for the sampled student tokens. - log_ratio = student_log_probs - teacher_log_probs - return log_ratio.detach() * student_log_probs - - @staticmethod - def _apply_importance_sampling_clipping( - per_token_loss: torch.Tensor, - student_log_probs: torch.Tensor, - old_log_probs: torch.Tensor, - clip_coeff: float, - ) -> torch.Tensor: - negative_approx_kl = (student_log_probs - old_log_probs).detach() - negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) - ratio = torch.exp(negative_approx_kl).clamp(max=clip_coeff) - return per_token_loss * ratio - - def _aggregate_self_distillation_loss( - self, - per_token_loss: torch.Tensor, - response_mask: torch.Tensor, - ) -> torch.Tensor: - loss_type = self.loss_type - if loss_type == "grpo": - loss = (per_token_loss * response_mask).sum(-1) / response_mask.sum(-1).clamp(min=1.0) - return loss.mean() - if loss_type == "bnpo": - return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - if loss_type == "dr_grpo": - return (per_token_loss * response_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) - if loss_type in ["dapo", "luspo", "cispo", "sapo"]: - return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - raise ValueError(f"Unsupported loss_type for self-distillation: {loss_type}") diff --git a/trl/experimental/self_distillation/teacher_context.py b/trl/experimental/self_distillation/teacher_context.py deleted file mode 100644 index 5e1020c91a7..00000000000 --- a/trl/experimental/self_distillation/teacher_context.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any - -import torch - -from ...data_utils import maybe_apply_chat_template -from ...trainer.base_trainer import _BaseTrainer -from ...trainer.utils import pad - - -def extract_last_user_text(prompt: list[dict[str, Any]]) -> str: - """Extract the text content from the last user message in a conversational prompt.""" - last_message = prompt[-1] - if last_message.get("role") != "user": - raise ValueError( - f"Self-distillation teacher prompt construction expects the conversation to end with a user turn, " - f"but the last message has role '{last_message.get('role')}'. " - f"Prompts ending with assistant prefills or tool turns are not supported." - ) - content = last_message.get("content", "") - if isinstance(content, list): - return " ".join(part.get("text", "") for part in content if part.get("type") == "text") - return content - - -@dataclass -class TokenizedPromptBatch: - prompt_ids: torch.Tensor - prompt_mask: torch.Tensor - - -class PromptTokenizer: - """Internal helper to tokenize prompt-like inputs consistently across self-distillation trainers.""" - - def __init__(self, trainer): - self.trainer = trainer - - def apply_prompt_template(self, prompts: list[Any]) -> list[str]: - return [ - maybe_apply_chat_template( - {"prompt": prompt}, - self.trainer.processing_class, - **getattr(self.trainer, "chat_template_kwargs", {}), - )["prompt"] - for prompt in prompts - ] - - def tokenize_prompts(self, prompts: list[Any]) -> TokenizedPromptBatch: - prompt_text = self.apply_prompt_template(prompts) - prompt_inputs = self.trainer.processing_class( - text=prompt_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.trainer.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - prompt_inputs = super(_BaseTrainer, self.trainer)._prepare_inputs(prompt_inputs) - prompt_ids = [ - p[m].tolist() - for p, m in zip(prompt_inputs["input_ids"], prompt_inputs["attention_mask"].bool(), strict=False) - ] - prompt_ids = [torch.tensor(ids, device=self.trainer.accelerator.device) for ids in prompt_ids] - prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - return TokenizedPromptBatch( - prompt_ids=pad(prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), - prompt_mask=pad(prompt_mask, padding_value=0, padding_side="left"), - ) diff --git a/trl/experimental/self_distillation/peft_adapter_ema_callback.py b/trl/experimental/self_distillation/teacher_sync.py similarity index 88% rename from trl/experimental/self_distillation/peft_adapter_ema_callback.py rename to trl/experimental/self_distillation/teacher_sync.py index e252bb512a4..267ef4d5854 100644 --- a/trl/experimental/self_distillation/peft_adapter_ema_callback.py +++ b/trl/experimental/self_distillation/teacher_sync.py @@ -22,10 +22,26 @@ TrainingArguments, ) +from ...trainer.callbacks import SyncRefModelCallback + logger = logging.getLogger(__name__) +class SyncTeacherModelCallback(SyncRefModelCallback): + """Synchronize an EMA teacher model with the student model on each configured sync step.""" + + def __init__(self, teacher_model, accelerator=None): + super().__init__(ref_model=teacher_model, accelerator=accelerator) + + def on_step_end(self, args, state, control, **kwargs): + model = kwargs["model"] + if self.ref_model is not None and state.global_step % args.teacher_sync_steps == 0: + if self.accelerator: + model = self.accelerator.unwrap_model(model) + self.sync_target_model(model, self.ref_model, args.teacher_update_rate) + + class PEFTAdapterEMACallback(TrainerCallback): """ Callback that maintains an EMA copy of PEFT adapter weights for use as a teacher model in self-distillation.