diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 4c78c9d3bb5..23893c47e02 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -1641,8 +1641,8 @@ from trl.experimental.sdpo import SDPOConfig, SDPOTrainer training_args = SDPOConfig( distillation_alpha=0.5, # Jensen-Shannon divergence (recommended) - distillation_topk=100, # Top-K logit distillation approximation - full_logit_distillation=True, # Required for top-K logit-level SDPO + distillation_mode="topk_logits", # Explicitly select top-K logit distillation + distillation_topk=100, # Required for top-K logit distillation distillation_is_clip=2.0, # Importance sampling clipping distillation_weight=1.0, # Weight for self-distillation loss sdpo_policy_loss_mode="distillation_only", @@ -1689,6 +1689,7 @@ dataset = Dataset.from_dict( training_args = SDFTConfig( distillation_alpha=0.5, + distillation_mode="topk_logits", distillation_topk=5, max_completion_length=64, ) diff --git a/docs/source/sdft_trainer.md b/docs/source/sdft_trainer.md index 17f6433fd78..23f29966458 100644 --- a/docs/source/sdft_trainer.md +++ b/docs/source/sdft_trainer.md @@ -32,6 +32,7 @@ dataset = Dataset.from_dict( training_args = SDFTConfig( output_dir="sdft-model", distillation_alpha=0.5, + distillation_mode="topk_logits", distillation_topk=5, max_completion_length=64, ) diff --git a/docs/source/sdpo_trainer.md b/docs/source/sdpo_trainer.md index 3910f674baf..7f82f565bcb 100644 --- a/docs/source/sdpo_trainer.md +++ b/docs/source/sdpo_trainer.md @@ -11,8 +11,9 @@ In the current TRL implementation: - the default SDPO policy loss mode is `distillation_only` - `hybrid` mode is also available to combine the base policy loss with the self-distillation loss - supported teacher regularization modes are `ema` and `none` -- `distillation_topk` is only valid when `full_logit_distillation=True` -- when `full_logit_distillation=False`, SDPO uses token-level reverse KL and requires `distillation_alpha=1.0` +- `distillation_mode` selects between `sampled_token`, `full_logits`, and `topk_logits` +- `distillation_topk` is only valid when `distillation_mode="topk_logits"` +- when `distillation_mode="sampled_token"`, SDPO uses token-level reverse KL and requires `distillation_alpha=1.0` - environment feedback can be injected into teacher reprompts when the dataset exposes a `privileged_context` column ## Expected dataset columns @@ -38,8 +39,8 @@ dataset = Dataset.from_dict( training_args = SDPOConfig( output_dir="sdpo-model", - distillation_topk=100, # Top-K logit distillation approximation - full_logit_distillation=True, # Required for top-K; enables non-reverse divergences + distillation_mode="topk_logits", # Explicitly select top-K logit distillation + distillation_topk=100, # Required when using top-K logit distillation include_environment_feedback=True, # Use dataset privileged_context for teacher reprompts ) @@ -88,7 +89,7 @@ python trl/experimental/sdpo/sdpo.py \ --num_generations 8 \ --generation_batch_size 32 \ --distillation_alpha 1.0 \ - --full_logit_distillation false \ + --distillation_mode sampled_token \ --sdpo_policy_loss_mode hybrid \ --report_to none \ --eval_strategy steps \ diff --git a/tests/experimental/test_base_self_distillation_trainer.py b/tests/experimental/test_base_self_distillation_trainer.py new file mode 100644 index 00000000000..bd7c7e880b1 --- /dev/null +++ b/tests/experimental/test_base_self_distillation_trainer.py @@ -0,0 +1,276 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import defaultdict +from types import SimpleNamespace + +import pytest +import torch +from datasets import Dataset +from transformers import AutoModelForCausalLM +from transformers.utils import is_peft_available + +from trl.experimental.self_distillation.base_self_distillation_trainer import ( + BaseSelfDistillationTrainer, + DistillationLogits, +) +from trl.experimental.self_distillation.self_distillation_config import SelfDistillationConfig + +from ..testing_utils import TrlTestCase + + +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + +class MinimalSelfDistillationTrainer(BaseSelfDistillationTrainer): + def finalize_batch(self, inputs, rollout_batch): + del inputs + batch = rollout_batch.as_dict() + batch["teacher_input_ids"] = rollout_batch.prompt_ids + batch["teacher_attention_mask"] = rollout_batch.prompt_mask + return batch + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + del inputs, num_items_in_batch + anchor = next(model.parameters()) + return anchor.sum() * 0.0 + + +class FakeTextTokenizer: + def __call__(self, text, **kwargs): + del kwargs + token_map = { + "short prompt": [1, 2, 3], + "long prompt": [10, 11, 12, 13, 14], + } + return {"input_ids": [token_map[prompt] for prompt in text]} + + +class FakeChatProcessor: + def __init__(self): + self.calls = [] + + def apply_chat_template(self, conversation, add_generation_prompt, tokenize, return_dict, **kwargs): + self.calls.append( + { + "conversation": conversation, + "add_generation_prompt": add_generation_prompt, + "tokenize": tokenize, + "return_dict": return_dict, + "kwargs": kwargs, + } + ) + return {"input_ids": [[21, 22, 23, 24]]} + + +class TestBaseSelfDistillationTrainer(TrlTestCase): + @staticmethod + def _make_loss_test_trainer(**args_overrides): + trainer = object.__new__(MinimalSelfDistillationTrainer) + args = { + "distillation_mode": "sampled_token", + "distillation_topk": None, + "distillation_alpha": 1.0, + "distillation_add_tail": False, + "distillation_is_clip": None, + } + args.update(args_overrides) + trainer.args = SimpleNamespace(**args) + trainer.loss_type = "dapo" + trainer.max_completion_length = 2 + trainer.accelerator = SimpleNamespace(gather=lambda tensor: tensor) + trainer._metrics = { + "train": defaultdict(list), + "eval": defaultdict(list), + } + trainer._name = "Minimal Self Distillation" + return trainer + + def test_teacher_model_kind_live_uses_student_model(self): + dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]}) + training_args = SelfDistillationConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=1, + num_generations=1, + teacher_model_kind="live", + report_to="none", + ) + + trainer = MinimalSelfDistillationTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + assert trainer.teacher_model is trainer.model + + @pytest.mark.skipif(not is_peft_available(), reason="PEFT is required for this test") + def test_warns_when_initial_student_already_has_a_peft_adapter(self, caplog): + dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]}) + training_args = SelfDistillationConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=1, + num_generations=1, + teacher_model_kind="base", + report_to="none", + ) + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + model = get_peft_model( + model, + LoraConfig( + r=4, + lora_alpha=8, + target_modules=["q_proj", "v_proj"], + bias="none", + task_type="CAUSAL_LM", + ), + ) + + with caplog.at_level( + logging.WARNING, logger="trl.experimental.self_distillation.base_self_distillation_trainer" + ): + MinimalSelfDistillationTrainer( + model=model, + args=training_args, + train_dataset=dataset, + ) + + assert "already contains a PEFT adapter" in caplog.text + assert "`teacher_model_kind='base'` may refer to the underlying base weights" in caplog.text + + @pytest.mark.parametrize("teacher_model_kind", ["base", "ema"]) + def test_teacher_model_kind_base_and_ema_use_frozen_teacher_copy(self, teacher_model_kind): + dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]}) + training_args = SelfDistillationConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=1, + num_generations=1, + teacher_model_kind=teacher_model_kind, + report_to="none", + ) + + trainer = MinimalSelfDistillationTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + assert trainer.teacher_model is not trainer.model + assert trainer.teacher_model.training is False + + student_param = next(trainer.model.parameters()) + teacher_param = next(trainer.teacher_model.parameters()) + assert teacher_param.requires_grad is False + assert teacher_param.data_ptr() != student_param.data_ptr() + + def test_tokenize_prompts_truncates_text_prompts_from_left(self): + trainer = object.__new__(MinimalSelfDistillationTrainer) + trainer.processing_class = FakeTextTokenizer() + trainer.max_prompt_length = 3 + + prompt_ids = trainer._tokenize_prompts(["long prompt", "short prompt"]) + + assert prompt_ids == [[12, 13, 14], [1, 2, 3]] + + def test_tokenize_prompts_for_conversational_prompts_forwards_chat_template_kwargs(self): + trainer = object.__new__(MinimalSelfDistillationTrainer) + trainer.processing_class = FakeChatProcessor() + trainer.max_prompt_length = 2 + trainer.chat_template_kwargs = {"enable_thinking": False} + + prompt_ids = trainer._tokenize_prompts([[{"role": "user", "content": "Solve 2+2."}]]) + + assert prompt_ids == [[23, 24]] + assert trainer.processing_class.calls == [ + { + "conversation": [[{"role": "user", "content": "Solve 2+2."}]], + "add_generation_prompt": True, + "tokenize": True, + "return_dict": True, + "kwargs": {"enable_thinking": False}, + } + ] + + def test_compute_self_distillation_loss_ignores_masked_completion_tokens(self): + trainer = self._make_loss_test_trainer( + distillation_mode="full_logits", + distillation_alpha=0.0, + ) + model = SimpleNamespace(training=True) + + # Token 0 is active and has a known non-zero divergence. + # Token 1 is intentionally very different but masked out, so it must not affect the loss. + student_probs = torch.tensor([[[0.8, 0.2], [0.01, 0.99]]], dtype=torch.float32) + teacher_probs = torch.tensor([[[0.5, 0.5], [0.99, 0.01]]], dtype=torch.float32) + distillation_logits = DistillationLogits( + completion_ids=torch.tensor([[0, 1]], dtype=torch.long), + completion_mask=torch.tensor([[1, 1]], dtype=torch.long), + response_mask=torch.tensor([[1, 0]], dtype=torch.long), + student_logits=student_probs.log(), + teacher_logits=teacher_probs.log(), + ) + + loss = trainer._compute_self_distillation_loss(model, {}, distillation_logits) + + expected_active_token_loss = teacher_probs[0, 0, 0] * ( + teacher_probs[0, 0, 0].log() - student_probs[0, 0, 0].log() + ) + teacher_probs[0, 0, 1] * (teacher_probs[0, 0, 1].log() - student_probs[0, 0, 1].log()) + torch.testing.assert_close(loss, expected_active_token_loss) + torch.testing.assert_close( + torch.tensor(trainer._metrics["train"]["self_distillation/distillation_loss"]), + expected_active_token_loss.unsqueeze(0), + ) + + def test_compute_self_distillation_loss_applies_importance_sampling_clip(self): + trainer = self._make_loss_test_trainer(distillation_is_clip=2.0) + model = SimpleNamespace(training=True) + + student_token_probs = torch.tensor([[0.2, 0.4]], dtype=torch.float32) + teacher_token_probs = torch.tensor([[0.5, 0.5]], dtype=torch.float32) + old_token_probs = torch.tensor([[0.05, 0.4]], dtype=torch.float32) + clip_coeff = trainer.args.distillation_is_clip + + distillation_logits = DistillationLogits( + completion_ids=torch.tensor([[0, 1]], dtype=torch.long), + completion_mask=torch.tensor([[1, 1]], dtype=torch.long), + response_mask=torch.tensor([[1, 1]], dtype=torch.long), + student_logits=torch.log(torch.tensor([[[0.2, 0.8], [0.6, 0.4]]], dtype=torch.float32)), + teacher_logits=torch.log(torch.tensor([[[0.5, 0.5], [0.5, 0.5]]], dtype=torch.float32)), + ) + + loss = trainer._compute_self_distillation_loss( + model, + {"old_per_token_logps": old_token_probs.log()}, + distillation_logits, + ) + + raw_per_token_loss = (student_token_probs.log() - teacher_token_probs.log()) * student_token_probs.log() + clipped_ratio = torch.minimum( + student_token_probs / old_token_probs, torch.full_like(student_token_probs, clip_coeff) + ) + expected_loss = (raw_per_token_loss * clipped_ratio).mean() + + torch.testing.assert_close(loss, expected_loss) + torch.testing.assert_close( + torch.tensor(trainer._metrics["train"]["self_distillation/distillation_loss"]), + expected_loss.unsqueeze(0), + ) diff --git a/tests/experimental/test_sdft_trainer.py b/tests/experimental/test_sdft_trainer.py index 9ca9b6c579a..f307c5a96dd 100644 --- a/tests/experimental/test_sdft_trainer.py +++ b/tests/experimental/test_sdft_trainer.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch + import pytest import torch from datasets import Dataset from transformers import AutoModelForCausalLM, TrainerCallback, TrainerControl, TrainerState, TrainingArguments from transformers.utils import is_peft_available -from trl.data_utils import maybe_apply_chat_template from trl.experimental.sdft import SDFTConfig, SDFTTrainer from ..testing_utils import TrlTestCase, require_peft @@ -27,18 +28,18 @@ if is_peft_available(): from peft import LoraConfig, get_peft_model, get_peft_model_state_dict - from trl.experimental.self_distillation.peft_adapter_ema_callback import PEFTAdapterEMACallback + from trl.experimental.self_distillation.teacher_sync import PEFTAdapterEMACallback class SelfDistillationCaptureCallback(TrainerCallback): def __init__(self): - self.captured_generation_prompt_text = None + self.captured_generation_prompts = None self.captured_old_per_token_logps = None self.generation_batch_build_count = 0 - def on_generation_prompts_selected(self, generation_prompt_text=None, **kwargs): - if self.captured_generation_prompt_text is None and generation_prompt_text is not None: - self.captured_generation_prompt_text = generation_prompt_text[0] + def on_generation_prompts_selected(self, generation_prompts=None, **kwargs): + if self.captured_generation_prompts is None and generation_prompts is not None: + self.captured_generation_prompts = generation_prompts def on_self_distillation_batch_prepared(self, old_per_token_logps=None, **kwargs): if self.captured_old_per_token_logps is None and old_per_token_logps is not None: @@ -74,6 +75,7 @@ def test_training_rejects_none_privileged_context(self): with pytest.raises(ValueError, match="`privileged_context` must not be None"): trainer.train() + @pytest.mark.skip(reason="`generate_from_teacher` is not ported yet") def test_training_with_generate_from_teacher(self): dataset = Dataset.from_dict( { @@ -105,9 +107,8 @@ def test_training_with_generate_from_teacher(self): trainer.train() - assert capture_callback.captured_generation_prompt_text is not None - assert "Solve 2+2." in capture_callback.captured_generation_prompt_text - assert "Teacher hint" in capture_callback.captured_generation_prompt_text + assert capture_callback.captured_generation_prompts is not None + assert capture_callback.captured_generation_prompts[0] != dataset[0]["prompt"] def test_training_with_chat_template_kwargs(self): dataset = Dataset.from_dict( @@ -141,15 +142,15 @@ def test_training_with_chat_template_kwargs(self): callbacks=[capture_callback], ) - expected_prompt = maybe_apply_chat_template( - {"prompt": dataset[0]["prompt"]}, + with patch.object( trainer.processing_class, - **training_args.chat_template_kwargs, - )["prompt"] - - trainer.train() + "apply_chat_template", + wraps=trainer.processing_class.apply_chat_template, + ) as mock_apply_chat_template: + trainer.train() - assert capture_callback.captured_generation_prompt_text == expected_prompt + assert mock_apply_chat_template.call_count > 0 + assert any(call.kwargs.get("enable_thinking") is False for call in mock_apply_chat_template.call_args_list) @require_peft def test_training_with_peft_model(self): @@ -205,9 +206,9 @@ def test_training_with_peft_model_and_sync_ref_model(self): max_completion_length=8, max_steps=2, num_generations=1, - sync_ref_model=True, - ref_model_mixup_alpha=0.05, - ref_model_sync_steps=1, + teacher_model_kind="ema", + teacher_update_rate=0.05, + teacher_sync_steps=1, ) trainer = SDFTTrainer( diff --git a/tests/experimental/test_sdpo_trainer.py b/tests/experimental/test_sdpo_trainer.py index 1bb666a7dcc..81a29af4822 100644 --- a/tests/experimental/test_sdpo_trainer.py +++ b/tests/experimental/test_sdpo_trainer.py @@ -100,14 +100,6 @@ def test_vllm_config_defaults_match_reference_trainers(self): assert config.vllm_model_impl == "vllm" def test_generate_vllm_syncs_on_step_change_and_uses_mode_specific_num_generations(self): - class FakeTokenizer: - def __call__(self, text, **kwargs): - token_map = { - "Solve 2+2.": [11, 12], - "Check 3+3.": [21, 22], - } - return {"input_ids": [token_map[prompt] for prompt in text]} - class FakeVLLMGeneration: def __init__(self): self.sync_weights_call_count = 0 @@ -135,11 +127,9 @@ def generate(self, prompts, images, num_generations): trainer.model = SimpleNamespace(training=True) trainer.state = SimpleNamespace(global_step=4) trainer._last_loaded_step = 3 - trainer.processing_class = FakeTokenizer() trainer.vllm_generation = FakeVLLMGeneration() - trainer._apply_prompt_template = lambda prompts: prompts - prompt_ids, completion_ids = trainer._generate(["Solve 2+2.", "Solve 2+2."]) + prompt_ids, completion_ids = trainer._generate([[11, 12], [11, 12]]) assert prompt_ids == [[11, 12], [11, 12]] assert completion_ids == [[100], [101]] @@ -154,7 +144,7 @@ def generate(self, prompts, images, num_generations): ] trainer.model.training = False - eval_prompt_ids, eval_completion_ids = trainer._generate(["Check 3+3.", "Check 3+3.", "Check 3+3."]) + eval_prompt_ids, eval_completion_ids = trainer._generate([[21, 22], [21, 22], [21, 22]]) assert eval_prompt_ids == [[21, 22], [21, 22], [21, 22]] assert eval_completion_ids == [[100], [101], [102]] @@ -167,7 +157,7 @@ def generate(self, prompts, images, num_generations): trainer.model.training = True trainer.state.global_step = 5 - trainer._generate(["Solve 2+2.", "Solve 2+2."]) + trainer._generate([[11, 12], [11, 12]]) assert trainer.vllm_generation.sync_weights_call_count == 2 assert trainer._last_loaded_step == 5 @@ -181,8 +171,8 @@ def test_training(self): per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=8, # reduce the completion length to reduce memory usage + distillation_mode="topk_logits", distillation_topk=5, - full_logit_distillation=True, distillation_is_clip=None, ) trainer = SDPOTrainer( diff --git a/trl/experimental/sdft/sdft.py b/trl/experimental/sdft/sdft.py index 958a24fc683..6c9aa8d6179 100644 --- a/trl/experimental/sdft/sdft.py +++ b/trl/experimental/sdft/sdft.py @@ -46,10 +46,9 @@ --learning_rate 2e-5 \ --max_prompt_length 1024 \ --max_completion_length 512 \ - --generate_from_teacher \ - --sync_ref_model \ - --ref_model_sync_steps 1 \ - --ref_model_mixup_alpha 0.01 \ + --teacher_model_kind ema \ + --teacher_sync_steps 1 \ + --teacher_update_rate 0.01 \ --eval_strategy steps \ --eval_steps 50 \ --report_to wandb @@ -86,10 +85,6 @@ @dataclass class SDFTScriptArguments(ScriptArguments): - ref_model_name_or_path: str | None = field( - default=None, - metadata={"help": "Reference teacher model. Optional for PEFT runs, where the base model is used as teacher."}, - ) dataset_path: str | None = field( default=None, metadata={"help": "Optional local dataset path to load with `load_from_disk`. Overrides `dataset_name`."}, @@ -122,14 +117,6 @@ class SDFTScriptArguments(ScriptArguments): ) -@dataclass -class ExampleSDFTConfig(SDFTConfig): - scale_rewards: str = field( - default="group", - metadata={"help": "Reward normalization mode. Supported: `group`, `batch`, `none`."}, - ) - - def _extract_prompt_text(prompt: Any) -> str: if isinstance(prompt, str): return prompt @@ -323,14 +310,15 @@ def _run_tooluse_eval( if __name__ == "__main__": - parser = TrlParser((SDFTScriptArguments, ExampleSDFTConfig, ModelConfig)) + parser = TrlParser((SDFTScriptArguments, SDFTConfig, ModelConfig)) script_args, training_args, model_args = parser.parse_args_and_config() if model_args.model_name_or_path is None: raise ValueError("`model_name_or_path` is required.") - if script_args.ref_model_name_or_path is None and not model_args.use_peft: - script_args.ref_model_name_or_path = model_args.model_name_or_path - + if training_args.generate_from_teacher: + raise ValueError( + "`generate_from_teacher` is not yet supported by the transitioned SDFT trainer used in this script." + ) if model_args.dtype in ["auto", None]: if training_args.bf16: dtype = torch.bfloat16 @@ -385,16 +373,10 @@ def _run_tooluse_eval( eval_dataset = _prepare_split(raw_eval_dataset, script_args) model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) - ref_model = None - if script_args.ref_model_name_or_path is not None: - ref_model = AutoModelForCausalLM.from_pretrained(script_args.ref_model_name_or_path, **model_kwargs) model.config.use_cache = False if training_args.gradient_checkpointing else True - if ref_model is not None: - ref_model.config.use_cache = True trainer = SDFTTrainer( model=model, - ref_model=ref_model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, diff --git a/trl/experimental/sdft/sdft_config.py b/trl/experimental/sdft/sdft_config.py index 84227e43cbf..21ec6b2a814 100644 --- a/trl/experimental/sdft/sdft_config.py +++ b/trl/experimental/sdft/sdft_config.py @@ -20,14 +20,14 @@ @dataclass class SDFTConfig(SelfDistillationConfig): r""" - Configuration class for [`SDFTTrainer`]. - - This adapts the official SDFT implementation to the TRL trainer API while reusing the common self-distillation - configuration shared with SDPO. + Configuration class for [`SDFTTrainer`].. Parameters: disable_dropout (`bool`, *optional*, defaults to `True`): Whether to disable dropout in the student and teacher models. + teacher_model_kind (`str`, *optional*, defaults to `"base"`): + Semantic teacher choice for SDFT. `base` uses the initial student, `live` uses the current student, and + `ema` uses an exponentially averaged teacher. generate_from_teacher (`bool`, *optional*, defaults to `False`): Whether on-policy generation should use the teacher-conditioned prompt instead of the student prompt. teacher_prompt_template (`str`, *optional*, defaults to `"{prompt}\n\n{privileged_context}"`): @@ -40,6 +40,13 @@ class SDFTConfig(SelfDistillationConfig): default=True, metadata={"help": "Whether to disable dropout in the student and teacher models."}, ) + teacher_model_kind: str = field( + default="base", + metadata={ + "help": "Semantic teacher choice for SDFT. `base` uses the initial student, `live` uses the current " + "student, and `ema` uses an exponentially averaged teacher." + }, + ) generate_from_teacher: bool = field( default=False, metadata={"help": "Whether on-policy generation should use the teacher-conditioned prompt."}, diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index 5bf6095c2a0..24959a36bc0 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -14,66 +14,43 @@ from __future__ import annotations -import copy -import inspect import textwrap -from collections import defaultdict -from functools import partial from typing import Any -import datasets import torch from accelerate.logging import get_logger -from accelerate.utils import is_peft_model from datasets import Dataset, IterableDataset from torch import nn -from torch.utils.data import DataLoader, Sampler from transformers import ( - AutoProcessor, - GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, ) -from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_peft_available +from transformers.utils import is_peft_available -from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation -from ...trainer.base_trainer import _BaseTrainer -from ...trainer.callbacks import SyncRefModelCallback -from ...trainer.utils import ( - RepeatSampler, - create_model_from_path, - disable_dropout_in_model, - get_config_model_id, - identity, - pad, - split_tensor_dict, - use_adapter, +from ...trainer.utils import pad +from ..self_distillation.base_self_distillation_trainer import ( + BaseSelfDistillationTrainer, + RolloutBatch, + TrainingBatch, ) -from ..self_distillation.self_distillation_mixin import SelfDistillationMixin -from ..self_distillation.teacher_context import PromptTokenizer, extract_last_user_text -from ..utils import prepare_peft_model +from ..self_distillation.prompt_utils import extract_last_user_text from .sdft_config import SDFTConfig if is_peft_available(): from peft import PeftConfig - from peft.peft_model import PeftModel - - from ..self_distillation.peft_adapter_ema_callback import PEFTAdapterEMACallback logger = get_logger(__name__) class DemonstrationTeacherContextBuilder: - """Builds student and teacher contexts from prompts plus privileged context, as in SDFT.""" + """Builds student and teacher contexts from prompts plus privileged context""" def __init__(self, trainer): self.trainer = trainer - self.prompt_tokenizer = PromptTokenizer(trainer) def _stringify_privileged_context(self, privileged_context: Any) -> str: if privileged_context is None: @@ -122,23 +99,27 @@ def build( completion_ids: torch.Tensor, completion_mask: torch.Tensor, ) -> dict[str, torch.Tensor]: - student_batch = self.prompt_tokenizer.tokenize_prompts(prompts) teacher_prompts = [ self._compose_teacher_prompt(prompt, privileged_context) for prompt, privileged_context in zip(prompts, privileged_contexts, strict=True) ] - teacher_batch = self.prompt_tokenizer.tokenize_prompts(teacher_prompts) - teacher_input_ids = torch.cat([teacher_batch.prompt_ids, completion_ids], dim=1) - teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, completion_mask], dim=1) + teacher_prompt_ids_list = self.trainer._tokenize_prompts(teacher_prompts) + device = completion_ids.device + teacher_prompt_ids = [torch.tensor(ids) for ids in teacher_prompt_ids_list] + teacher_prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in teacher_prompt_ids] + teacher_prompt_ids = pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left").to( + device=device + ) + teacher_prompt_mask = pad(teacher_prompt_mask, padding_value=0, padding_side="left").to(device=device) + teacher_input_ids = torch.cat([teacher_prompt_ids, completion_ids], dim=1) + teacher_attention_mask = torch.cat([teacher_prompt_mask, completion_mask], dim=1) return { - "prompt_ids": student_batch.prompt_ids, - "prompt_mask": student_batch.prompt_mask, "teacher_input_ids": teacher_input_ids, "teacher_attention_mask": teacher_attention_mask, } -class SDFTTrainer(SelfDistillationMixin, _BaseTrainer): +class SDFTTrainer(BaseSelfDistillationTrainer): """Trainer for SDFT-style on-policy self-distillation with explicit teacher prompts.""" _tag_names = ["trl", "sdft"] @@ -168,319 +149,71 @@ def __init__( optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), peft_config: PeftConfig | None = None, ): - if train_dataset is None: - raise ValueError("`train_dataset` is required") if isinstance(train_dataset, IterableDataset): raise NotImplementedError("Iterable datasets are not yet supported in SDFTTrainer.") if isinstance(eval_dataset, IterableDataset) or ( isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) ): raise NotImplementedError("Iterable eval datasets are not yet supported in SDFTTrainer.") - if args.use_vllm: - raise NotImplementedError("SDFTTrainer does not support `use_vllm=True` yet.") - if isinstance(model, str): - model_init_kwargs = args.model_init_kwargs or {} - if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: - model_init_kwargs["device_map"] = None - model = create_model_from_path(model, **model_init_kwargs) - elif args.model_init_kwargs is not None: - logger.warning( - "You passed `model_init_kwargs` to `SDFTConfig`, but `model` is already instantiated. " - "The `model_init_kwargs` will be ignored." - ) - self.model_kwarg_keys = ( - inspect.signature(model.forward).parameters.keys() - if not hasattr(model, "get_base_model") - else inspect.signature(model.get_base_model().forward).parameters.keys() - ) - - if is_peft_available() and is_peft_model(model) and peft_config is not None: - raise ValueError( - "You passed a `PeftModel` instance together with a `peft_config` to SDFTTrainer. Pass either a base " - "model with `peft_config`, or a pre-wrapped PEFT model." - ) - if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): - model = prepare_peft_model(model, peft_config, args) - - if processing_class is None: - processing_class = AutoProcessor.from_pretrained( - get_config_model_id(model.config), truncation_side="left", padding_side="left" - ) - - if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer - elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class - else: - raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id - self.max_prompt_length = args.max_prompt_length - self.max_completion_length = args.max_completion_length - self.num_generations = args.num_generations - self.num_iterations = args.num_iterations - self.temperature = args.temperature - self.loss_type = args.loss_type - self.shuffle_dataset = args.shuffle_dataset - self.generate_from_teacher = args.generate_from_teacher self.num_loss_tokens_to_skip = args.num_loss_tokens_to_skip - self.chat_template_kwargs = args.chat_template_kwargs or {} - self._step = 0 - self._buffered_inputs = None - self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} - self.prompt_tokenizer = PromptTokenizer(self) self.teacher_context_builder = DemonstrationTeacherContextBuilder(self) - generation_kwargs = { - "max_new_tokens": self.max_completion_length, - "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, - "temperature": args.temperature, - "top_p": args.top_p, - "top_k": args.top_k, - "min_p": args.min_p, - "repetition_penalty": args.repetition_penalty, - "cache_implementation": args.cache_implementation, - } - if args.generation_kwargs is not None: - generation_kwargs.update(args.generation_kwargs) - self.generation_config = GenerationConfig(**generation_kwargs, disable_compile=True) - - if hasattr(model, "warnings_issued"): - model.warnings_issued["estimate_tokens"] = True - super().__init__( model=model, args=args, - data_collator=identity, train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, callbacks=callbacks, optimizers=optimizers, - compute_loss_func="non-None value to disable scaling", + peft_config=peft_config, ) - if args.disable_dropout: - disable_dropout_in_model(self.model) - - self.model.add_model_tags(self._tag_names) - - # In self-distillation the teacher is always derived from the student: - # - PEFT: base model with adapter disabled (or EMA teacher adapter when sync_ref_model=True) - # - Non-PEFT: same model (or deep-copied EMA model when sync_ref_model=True) - self.teacher_model = None - - if args.sync_ref_model: - if is_peft_available() and is_peft_model(self.model): - self.add_callback( - PEFTAdapterEMACallback( - model=self.model, - teacher_adapter_name="teacher", - update_rate=args.ref_model_mixup_alpha, - sync_steps=args.ref_model_sync_steps, - accelerator=self.accelerator, - ) - ) - else: - student_model = self.accelerator.unwrap_model(self.model) - self.teacher_model = copy.deepcopy(student_model) - self.teacher_model.requires_grad_(False) - self.teacher_model.eval() - if self.is_deepspeed_enabled: - self.teacher_model = prepare_deepspeed(self.teacher_model, self.accelerator) - elif self.is_fsdp_enabled: - self.teacher_model = prepare_fsdp(self.teacher_model, self.accelerator) - else: - self.teacher_model = self.accelerator.prepare_model(self.teacher_model, evaluation_mode=True) - self.add_callback(SyncRefModelCallback(ref_model=self.teacher_model, accelerator=self.accelerator)) - - self.model_accepts_loss_kwargs = False - - def get_train_dataloader(self): - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - - train_dataset = self.train_dataset - data_collator = self.data_collator - if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): - train_dataset = self._remove_unused_columns(train_dataset, description="training") - else: - data_collator = self._get_collator_with_removed_columns(data_collator, description="training") - - dataloader_params = { - "batch_size": self._train_batch_size * self.args.steps_per_generation, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - } - if not isinstance(train_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_train_sampler() - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = partial( - seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index - ) - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) - - def _get_train_sampler(self, dataset=None) -> Sampler: - if dataset is None: - dataset = self.train_dataset - return RepeatSampler( - data_source=dataset, - mini_repeat_count=self.num_generations, - batch_size=self.args.generation_batch_size // self.num_generations, - repeat_count=self.num_iterations * self.args.steps_per_generation, - shuffle=self.shuffle_dataset, - seed=self.args.seed, - ) - - def _get_eval_sampler(self, eval_dataset) -> Sampler: - return RepeatSampler( - data_source=eval_dataset, - mini_repeat_count=self.num_generations, - seed=self.args.seed, - ) - - def training_step(self, model, inputs, num_items_in_batch): - output = super().training_step(model, inputs, num_items_in_batch) - self._step += 1 - return output - - def _prepare_inputs(self, generation_batch): - mode = "train" if self.model.training else "eval" - if mode == "train": - generate_every = self.args.steps_per_generation * self.num_iterations - if self._step % generate_every == 0 or self._buffered_inputs is None: - generation_batch = self._build_buffered_batch(generation_batch) - self._buffered_inputs = split_tensor_dict(generation_batch, self.args.steps_per_generation) - self._dispatch_self_distillation_callback( - "on_generation_batch_built", - generate_every=generate_every, - steps_per_generation=self.args.steps_per_generation, - ) - return self._buffered_inputs[self._step % self.args.steps_per_generation] - return self._build_buffered_batch(generation_batch) - - def _generate_completion_ids(self, prompts: list[Any]) -> tuple[torch.Tensor, torch.Tensor]: - generate_inputs = self.processing_class( - text=self.prompt_tokenizer.apply_prompt_template(prompts), - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - # This generation helper builds tokenized model inputs directly, so use the base Trainer tensor preparation - # instead of re-entering the buffered outer training hook. - generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) - - with ( - unwrap_model_for_generation( - self.model_wrapped, - self.accelerator, - gather_deepspeed3_params=self.args.ds3_gather_for_generation, - ) as unwrapped_model, - torch.no_grad(), - ): - prompt_completion_ids = unwrapped_model.generate( - **generate_inputs, generation_config=self.generation_config - ) - - prompt_length = generate_inputs["input_ids"].size(1) - completion_ids = prompt_completion_ids[:, prompt_length:] - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) - completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).long() - - completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=True)] - completion_ids = [torch.tensor(ids, device=self.accelerator.device) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] - return ( - pad(completion_ids, padding_value=self.pad_token_id, padding_side="right"), - pad(completion_mask, padding_value=0, padding_side="right"), - ) - - def _build_buffered_batch(self, inputs: list[dict[str, Any]]) -> dict[str, torch.Tensor | Any]: + def finalize_batch( + self, + inputs: list[dict[str, Any]], + rollout_batch: RolloutBatch, + ) -> TrainingBatch: prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) - generation_prompts = self.teacher_context_builder.select_generation_prompts(prompts, privileged_contexts) - generation_prompt_text = self.prompt_tokenizer.apply_prompt_template(generation_prompts) - self._dispatch_self_distillation_callback( - "on_generation_prompts_selected", - generation_prompts=generation_prompts, - generation_prompt_text=generation_prompt_text, - ) - completion_ids, completion_mask = self._generate_completion_ids(generation_prompts) - teacher_batch = self.teacher_context_builder.build( - prompts, privileged_contexts, completion_ids, completion_mask + prompts, + privileged_contexts, + rollout_batch.completion_ids, + rollout_batch.completion_mask, ) - prompt_completion_ids = torch.cat([teacher_batch["prompt_ids"], completion_ids], dim=1) - attention_mask = torch.cat([teacher_batch["prompt_mask"], completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) - - with torch.no_grad(): - generate_every = self.args.steps_per_generation * self.num_iterations - if not self.generate_from_teacher and self.args.gradient_accumulation_steps % generate_every != 0: - old_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ) - else: - old_per_token_logps = None - - self._dispatch_self_distillation_callback( - "on_self_distillation_batch_prepared", - old_per_token_logps=old_per_token_logps, - prompt_ids=teacher_batch["prompt_ids"], - completion_ids=completion_ids, + batch = rollout_batch.as_dict() + batch.update( + { + "teacher_input_ids": teacher_batch["teacher_input_ids"], + "teacher_attention_mask": teacher_batch["teacher_attention_mask"], + } ) - output = { - "prompt_ids": teacher_batch["prompt_ids"], - "prompt_mask": teacher_batch["prompt_mask"], - "completion_ids": completion_ids, - "completion_mask": completion_mask, - "teacher_input_ids": teacher_batch["teacher_input_ids"], - "teacher_attention_mask": teacher_batch["teacher_attention_mask"], - } - if old_per_token_logps is not None: - output["old_per_token_logps"] = old_per_token_logps - return output + return batch def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: raise ValueError("The SDFTTrainer does not support returning outputs") - if self.num_loss_tokens_to_skip > 0: - inputs = dict(inputs) - completion_mask = inputs["completion_mask"].clone() - token_positions = torch.arange(completion_mask.size(1), device=completion_mask.device).unsqueeze(0) - completion_mask = completion_mask * (token_positions >= self.num_loss_tokens_to_skip).long() - inputs["completion_mask"] = completion_mask - - loss = self._compute_self_distillation_loss(model, inputs) + distillation_logits = self._compute_teacher_student_logits(model, self.teacher_model, inputs) + loss = self._compute_self_distillation_loss(model, inputs, distillation_logits) accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 return loss / accumulation_scale - def _get_teacher_context_for_self_distillation(self, model): - if is_peft_available() and isinstance(self.model, PeftModel): - model = self.accelerator.unwrap_model(self.model) - if self.args.sync_ref_model and "teacher" in model.peft_config: - return use_adapter(model, adapter_name="teacher") - return use_adapter(model, adapter_name=None) - return super()._get_teacher_context_for_self_distillation(model) + def _build_self_distillation_response_mask( + self, + completion_mask: torch.Tensor, + self_distillation_mask: torch.Tensor | None, + ) -> torch.Tensor: + response_mask = BaseSelfDistillationTrainer._build_self_distillation_response_mask( + completion_mask, + self_distillation_mask, + ) + if self.num_loss_tokens_to_skip <= 0: + return response_mask + + # SDFT skips the first few completion tokens only in the distillation loss to suppress teacher-prompt artifacts. + token_positions = torch.arange(response_mask.size(1), device=response_mask.device).unsqueeze(0) + skip_mask = (token_positions >= self.num_loss_tokens_to_skip).long() + return response_mask * skip_mask diff --git a/trl/experimental/sdpo/sdpo.py b/trl/experimental/sdpo/sdpo.py index 7a65a00cc5d..cd752cc6855 100644 --- a/trl/experimental/sdpo/sdpo.py +++ b/trl/experimental/sdpo/sdpo.py @@ -43,7 +43,7 @@ --num_generations 8 \ --generation_batch_size 32 \ --distillation_alpha 1.0 \ - --full_logit_distillation false \ + --distillation_mode sampled_token \ --sdpo_policy_loss_mode hybrid \ --report_to none \ --eval_strategy steps \ diff --git a/trl/experimental/sdpo/sdpo_config.py b/trl/experimental/sdpo/sdpo_config.py index 1cc8c1510b7..92c4a181f39 100644 --- a/trl/experimental/sdpo/sdpo_config.py +++ b/trl/experimental/sdpo/sdpo_config.py @@ -13,6 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field +from typing import Literal from ..self_distillation import SelfDistillationConfig @@ -22,29 +23,46 @@ class SDPOConfig(SelfDistillationConfig): r""" Configuration class for the [`SDPOTrainer`]. - This class extends [`experimental.self_distillation.SelfDistillationConfig`] with the online teacher-construction - parameters used by Self-Distillation Policy Optimization (SDPO). - Parameters: + > Parameters that control the online policy objective + + beta (`float`, *optional*, defaults to `0.0`): + Reference-model KL coefficient for online policy optimization. + epsilon (`float`, *optional*, defaults to `0.2`): + Lower clipping coefficient for GRPO-style policy loss. + epsilon_high (`float` or `None`, *optional*): + Upper clipping coefficient. Defaults to `epsilon` when unset. + importance_sampling_level (`str`, *optional*, defaults to `"token"`): + Importance-sampling granularity. Supported: `token`, `sequence`. + reward_weights (`list[float]` or `None`, *optional*): + Optional weights for multiple reward functions. + scale_rewards (`str` or `bool`, *optional*, defaults to `"group"`): + Reward normalization mode. Supported: `group`, `batch`, `none`. + > Parameters that control the SDPO loss sdpo_policy_loss_mode (`str`, *optional*, defaults to `"distillation_only"`): How SDPO combines the online policy loss and self-distillation loss. Supported: `distillation_only`, - `hybrid`. + `policy_only`, `hybrid`. distillation_alpha (`float`, *optional*, defaults to `1.0`): - Divergence interpolation coefficient. Token-level SDPO requires the official reverse-KL setting + Divergence interpolation coefficient. Sampled-token SDPO requires the official reverse-KL setting + `distillation_alpha=1.0`. + distillation_mode (`Literal["sampled_token", "full_logits", "topk_logits"]`, *optional*, defaults to `"sampled_token"`): + Distillation objective mode. `"sampled_token"` is the default SDPO mode and requires `distillation_alpha=1.0`. distillation_topk (`int` or `None`, *optional*): - Top-k approximation for logit-level SDPO. Requires `full_logit_distillation=True`. + Top-k approximation for logit-level SDPO. Must be set when `distillation_mode="topk_logits"` and left + unset otherwise. > Parameters that control the teacher - teacher_regularization (`str`, *optional*, defaults to `"ema"`): - Teacher update strategy. Supported: `ema`, `none`. - teacher_update_rate (`float` or `None`, *optional*): - EMA update rate used when `teacher_regularization="ema"`. - ema_update_rate (`float`, *optional*, defaults to `0.05`): - Deprecated alias for `teacher_update_rate`. + teacher_model_kind (`str`, *optional*, defaults to `"ema"`): + Semantic teacher choice. `base` uses the initial student, `live` uses the current student, and `ema` + uses an exponentially averaged teacher. + teacher_update_rate (`float`, *optional*, defaults to `0.05`): + EMA update rate used when `teacher_model_kind="ema"`. + teacher_sync_steps (`int`, *optional*, defaults to `1`): + Number of optimizer steps between teacher EMA updates. > Parameters that control reprompting @@ -60,31 +78,69 @@ class SDPOConfig(SelfDistillationConfig): default=True, metadata={"help": "Skip reprompting when model generates correct response."}, ) + beta: float = field( + default=0.0, + metadata={"help": "Reference-model KL coefficient for online policy optimization."}, + ) + epsilon: float = field( + default=0.2, + metadata={"help": "Lower clipping coefficient for GRPO-style policy loss."}, + ) + epsilon_high: float | None = field( + default=None, + metadata={"help": "Upper clipping coefficient. Defaults to `epsilon` when unset."}, + ) + importance_sampling_level: str = field( + default="token", + metadata={"help": "Importance-sampling granularity. Supported: `token`, `sequence`."}, + ) + reward_weights: list[float] | None = field( + default=None, + metadata={"help": "Optional weights for multiple reward functions."}, + ) + scale_rewards: str | bool = field( + default="group", + metadata={"help": "Reward normalization mode. Supported: `group`, `batch`, `none`."}, + ) distillation_alpha: float = field( default=1.0, metadata={ - "help": "KL divergence direction for SDPO. Token-level SDPO requires reverse KL (`distillation_alpha=1.0`)." + "help": "Divergence interpolation coefficient. Sampled-token SDPO requires the official reverse-KL setting " + "`distillation_alpha=1.0`." + }, + ) + distillation_mode: Literal["sampled_token", "full_logits", "topk_logits"] = field( + default="sampled_token", + metadata={ + "help": "Distillation objective mode. `sampled_token` is the default SDPO mode and requires " + "`distillation_alpha=1.0`." }, ) distillation_topk: int | None = field( default=None, - metadata={"help": "Top-K approximation for logit-level SDPO. Requires `full_logit_distillation=True`."}, + metadata={ + "help": "Top-k approximation for logit-level SDPO. Must be set when `distillation_mode=topk_logits` and left " + "unset otherwise." + }, ) sdpo_policy_loss_mode: str = field( default="distillation_only", - metadata={"help": "SDPO policy loss mode. Supported: `distillation_only`, `hybrid`."}, + metadata={"help": "SDPO policy loss mode. Supported: `distillation_only`, `policy_only`, `hybrid`."}, ) - teacher_regularization: str = field( + teacher_model_kind: str = field( default="ema", - metadata={"help": "Teacher regularization mode. Supported: `ema`, `none`."}, + metadata={ + "help": "Semantic teacher choice. `base` uses the initial student, `live` uses the current student, " + "and `ema` uses an exponentially averaged teacher." + }, ) - teacher_update_rate: float | None = field( - default=None, + teacher_update_rate: float = field( + default=0.05, metadata={"help": "Teacher update rate used for EMA teacher synchronization."}, ) - ema_update_rate: float = field( - default=0.05, - metadata={"help": "Deprecated alias for `teacher_update_rate`."}, + teacher_sync_steps: int = field( + default=1, + metadata={"help": "How often to synchronize the EMA teacher model."}, ) max_reprompt_len: int = field( default=10240, @@ -126,23 +182,26 @@ class SDPOConfig(SelfDistillationConfig): def __post_init__(self): super().__post_init__() - if self.teacher_update_rate is None: - self.teacher_update_rate = self.ema_update_rate + self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards) + if self.scale_rewards not in ["group", "batch", "none"]: + raise ValueError("scale_rewards must be one of: 'group', 'batch', 'none'") + + if self.importance_sampling_level not in ["token", "sequence"]: + raise ValueError("importance_sampling_level must be either 'token' or 'sequence'") + + if self.epsilon_high is None: + self.epsilon_high = self.epsilon - if self.teacher_regularization not in {"ema", "none"}: - raise ValueError("teacher_regularization must be one of: 'ema', 'none'") - if not 0.0 <= self.teacher_update_rate <= 1.0: - raise ValueError("teacher_update_rate must be in [0, 1]") - if self.sdpo_policy_loss_mode not in {"distillation_only", "hybrid"}: - raise ValueError("sdpo_policy_loss_mode must be one of: 'distillation_only', 'hybrid'") + if self.sdpo_policy_loss_mode not in {"distillation_only", "policy_only", "hybrid"}: + raise ValueError("sdpo_policy_loss_mode must be one of: 'distillation_only', 'policy_only', 'hybrid'") if self.sdpo_policy_loss_mode == "distillation_only" and self.distillation_weight <= 0: raise ValueError("distillation_only mode requires `distillation_weight > 0`.") + if self.sdpo_policy_loss_mode == "hybrid" and self.distillation_weight <= 0: + raise ValueError("hybrid mode requires `distillation_weight > 0`.") if self.max_reprompt_len <= 0: raise ValueError("max_reprompt_len must be positive") - if not self.full_logit_distillation and self.distillation_alpha != 1.0: + if self.distillation_mode == "sampled_token" and self.distillation_alpha != 1.0: raise ValueError( - "SDPO token-level distillation requires `distillation_alpha=1.0`. " - "Set `full_logit_distillation=True` to use other divergence settings." + "SDPO sampled-token distillation requires `distillation_alpha=1.0`. " + "Set `distillation_mode='full_logits'` or `distillation_mode='topk_logits'` to use other divergence settings." ) - if self.distillation_topk is not None and not self.full_logit_distillation: - raise ValueError("SDPO `distillation_topk` requires `full_logit_distillation=True`.") diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py index ef84a17a44c..b0d3afc4e08 100644 --- a/trl/experimental/sdpo/sdpo_trainer.py +++ b/trl/experimental/sdpo/sdpo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import re import textwrap from typing import Any @@ -21,31 +20,36 @@ from accelerate.utils import gather_object from datasets import Dataset, IterableDataset from torch import nn -from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback - -from ...trainer.callbacks import SyncRefModelCallback -from ...trainer.utils import pad -from ..self_distillation.base_self_distillation_trainer import BaseSelfDistillationTrainer -from ..self_distillation.teacher_context import TokenizedPromptBatch, extract_last_user_text +from transformers import ( + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.utils import logging + +from ...data_utils import apply_chat_template, is_conversational +from ...models import prepare_deepspeed, prepare_fsdp +from ...trainer.base_trainer import _BaseTrainer +from ...trainer.utils import get_config_model_id, pad +from ..self_distillation.base_self_distillation_trainer import ( + BaseSelfDistillationTrainer, + DistillationLogits, + RolloutBatch, + TrainingBatch, +) +from ..self_distillation.loss_utils import aggregate_loss, select_token_log_probs +from ..self_distillation.prompt_utils import extract_last_user_text from .sdpo_config import SDPOConfig -class EMATeacherSyncCallback(SyncRefModelCallback): - """Synchronize an EMA teacher model with the student model on each step.""" - - def __init__(self, teacher_model, update_rate: float, accelerator=None): - super().__init__(ref_model=teacher_model, accelerator=accelerator) - self.update_rate = update_rate - - def on_step_end(self, args, state, control, **kwargs): - model = kwargs["model"] - if self.accelerator is not None: - model = self.accelerator.unwrap_model(model) - self.sync_target_model(model, self.ref_model, self.update_rate) +logger = logging.get_logger(__name__) class SuccessfulRolloutTeacherContextBuilder: - """Builds SDPO teacher contexts from successful rollouts, following the official online implementation.""" + """Builds teacher contexts from successful rollouts""" def __init__(self, trainer): self.trainer = trainer @@ -60,36 +64,18 @@ def _build_reprompt_text(self, prompt_text: str, solution_text: str, feedback_te def _tokenize_teacher_messages( self, teacher_messages_list: list[str | list[dict[str, Any]]] - ) -> TokenizedPromptBatch: - teacher_prompt_ids_list = [] + ) -> dict[str, torch.Tensor]: device = self.trainer.accelerator.device - chat_template_kwargs = getattr(self.trainer, "chat_template_kwargs", {}) - for msg in teacher_messages_list: - if isinstance(msg, list) and isinstance(msg[0], dict): - tokenized = self.trainer.processing_class.apply_chat_template( - msg, - tokenize=True, - add_generation_prompt=True, - return_tensors="pt", - **chat_template_kwargs, - ) - if isinstance(tokenized, torch.Tensor): - ids = tokenized.squeeze(0) - else: - ids = tokenized["input_ids"].squeeze(0) - else: - ids = self.trainer.processing_class.encode(msg, return_tensors="pt").squeeze(0) - - if ids.shape[0] > self.trainer.args.max_reprompt_len: - ids = ids[-self.trainer.args.max_reprompt_len :] - teacher_prompt_ids_list.append(ids) - - teacher_prompt_ids = [ids.to(device) for ids in teacher_prompt_ids_list] + teacher_prompt_ids_list = self.trainer._tokenize_prompts_untruncated(teacher_messages_list) + teacher_prompt_ids = [ + torch.as_tensor(ids[-self.trainer.args.max_reprompt_len :], device=device) + for ids in teacher_prompt_ids_list + ] teacher_prompt_mask = [torch.ones(len(ids), dtype=torch.long, device=device) for ids in teacher_prompt_ids] - return TokenizedPromptBatch( - prompt_ids=pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), - prompt_mask=pad(teacher_prompt_mask, padding_value=0, padding_side="left"), - ) + return { + "prompt_ids": pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), + "prompt_mask": pad(teacher_prompt_mask, padding_value=0, padding_side="left"), + } def build( self, @@ -214,8 +200,8 @@ def build( local_teacher_messages.append(self._build_reprompt_text(original_prompt, solution_text, feedback_text)) teacher_batch = self._tokenize_teacher_messages(local_teacher_messages) - teacher_input_ids = torch.cat([teacher_batch.prompt_ids, completion_ids], dim=1) - teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, completion_mask], dim=1) + teacher_input_ids = torch.cat([teacher_batch["prompt_ids"], completion_ids], dim=1) + teacher_attention_mask = torch.cat([teacher_batch["prompt_mask"], completion_mask], dim=1) batch_size = total_samples if total_samples > 0 else 1 num_groups = max(1, total_samples // max(1, num_generations)) @@ -277,44 +263,193 @@ def __init__( raise ValueError("`reward_funcs` is required for SDPOTrainer because SDPO must score rollouts.") super().__init__( model=model, - reward_funcs=reward_funcs, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, - reward_processing_classes=reward_processing_classes, callbacks=callbacks, optimizers=optimizers, peft_config=peft_config, ) + self.importance_sampling_level = args.importance_sampling_level + self.scale_rewards = args.scale_rewards + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high + self.beta = args.beta + + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + self.reward_func_names = [] + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + reward_model_init_kwargs = args.model_init_kwargs or {} + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + reward_model_init_kwargs["device_map"] = None + reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( + reward_func, + num_labels=1, + **reward_model_init_kwargs, + ) + if isinstance(reward_funcs[i], nn.Module): + self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) + else: + self.reward_func_names.append(reward_funcs[i].__name__) + self.reward_funcs = reward_funcs + + if args.reward_weights is not None: + if len(args.reward_weights) != len(self.reward_funcs): + raise ValueError("Number of reward weights must match number of reward functions") + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(self.reward_funcs), dtype=torch.float32) + + if reward_processing_classes is None: + reward_processing_classes = [None] * len(self.reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + if len(reward_processing_classes) != len(self.reward_funcs): + raise ValueError("Number of reward processing classes must match number of reward functions") + + for i, (reward_processing_class, reward_func) in enumerate( + zip(reward_processing_classes, self.reward_funcs, strict=True) + ): + if isinstance(reward_func, PreTrainedModel): + if reward_processing_class is None: + reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config)) + if reward_processing_class.pad_token_id is None: + reward_processing_class.pad_token = reward_processing_class.eos_token + reward_func.config.pad_token_id = reward_processing_class.pad_token_id + reward_processing_classes[i] = reward_processing_class + self.reward_processing_classes = reward_processing_classes + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, nn.Module): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + elif self.is_fsdp_enabled: + self.reward_funcs[i] = prepare_fsdp(reward_func, self.accelerator) + else: + self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) + self.teacher_context_builder = SuccessfulRolloutTeacherContextBuilder(self) - if self.args.teacher_regularization == "ema": - # `self.model` may already be accelerator-wrapped after the shared base constructor. Build the EMA - # teacher from the unwrapped student model first, then prepare it as an auxiliary eval-only module. - student_model = self.accelerator.unwrap_model(self.model) - self.teacher_model = copy.deepcopy(student_model) - self.teacher_model.requires_grad_(False) - self.teacher_model.eval() - self.teacher_model = self._prepare_auxiliary_model_for_eval(self.teacher_model) - self.add_callback( - EMATeacherSyncCallback( - teacher_model=self.teacher_model, - update_rate=self.args.teacher_update_rate, - accelerator=self.accelerator, + + def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): + device = self.accelerator.device + if len(self.reward_funcs) == 0: + return torch.zeros((len(prompts), 0), device=device) + + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + reward_kwargs["trainer_state"] = self.state + + for i, (reward_func, reward_processing_class) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes, strict=True) + ): + if isinstance(reward_func, nn.Module): + if is_conversational(inputs[0]): + messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)] + texts = [ + apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"] + for x in messages + ] + else: + texts = [p + c for p, c in zip(prompts, completions, strict=True)] + reward_inputs = reward_processing_class( + text=texts, + return_tensors="pt", + padding=True, + padding_side="right", + add_special_tokens=False, ) - ) + reward_inputs = _BaseTrainer._prepare_inputs(self, reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] + else: + output_reward_func = reward_func( + prompts=prompts, + completions=completions, + completion_ids=completion_ids_list, + **reward_kwargs, + ) + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) - def _allow_topk_without_full_logit_distillation(self) -> bool: - return False + return self.accelerator.gather(rewards_per_func) - def _generate_and_score_completions( - self, inputs: list[dict[str, torch.Tensor | Any]] - ) -> dict[str, torch.Tensor | Any]: + def finalize_batch( + self, + inputs: list[dict[str, Any]], + rollout_batch: RolloutBatch, + ) -> TrainingBatch: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) + raw_completion_lengths = rollout_batch.raw_completion_lengths.detach().cpu().tolist() + completion_ids_list = [ + ids[:length].tolist() + for ids, length in zip(rollout_batch.completion_ids.detach().cpu(), raw_completion_lengths, strict=True) + ] + if is_conversational({"prompt": prompts[0]}): + completions_text = self.processing_class.batch_decode( + rollout_batch.completion_ids, skip_special_tokens=True + ) + completions = [[{"role": "assistant", "content": content}] for content in completions_text] + else: + completions = self.processing_class.batch_decode(rollout_batch.completion_ids, skip_special_tokens=True) - output = super()._generate_and_score_completions(inputs) - output.update( - self.teacher_context_builder.build(output, prompts, output["rewards"], feedbacks=privileged_contexts) + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + if rewards_per_func.numel() == 0: + rewards = torch.zeros(self.accelerator.num_processes * len(prompts), device=device) + else: + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0) + if self.scale_rewards == "batch": + std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) + group_std_rewards = rewards.view(-1, num_generations).std(dim=1) + elif self.scale_rewards == "none": + std_rewards = torch.ones_like(rewards) + group_std_rewards = torch.ones(rewards.numel() // num_generations, device=device, dtype=rewards.dtype) + else: + group_std_rewards = rewards.view(-1, num_generations).std(dim=1) + std_rewards = group_std_rewards.repeat_interleave(num_generations, dim=0) + advantages = (rewards - mean_grouped_rewards) / (std_rewards + 1e-4) + self._record_reward_diagnostics(mode, rewards, rewards_per_func, group_std_rewards) + + local_batch_size = rollout_batch.completion_ids.size(0) + process_start = self.accelerator.process_index * local_batch_size + process_slice = slice(process_start, process_start + local_batch_size) + local_rewards = rewards[process_slice] + local_advantages = advantages[process_slice] + + agg_completion_lengths = self.accelerator.gather( + torch.tensor([len(ids) for ids in completion_ids_list], device=device) + ) + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] + if len(term_completion_lengths) == 0: + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + rollout_dict = rollout_batch.as_dict() + rollout_dict["rewards"] = local_rewards + rollout_dict["advantages"] = local_advantages + teacher_context = self.teacher_context_builder.build( + rollout_dict, + prompts, + rollout_dict["rewards"], + feedbacks=privileged_contexts, ) mode = "train" if self.model.training else "eval" @@ -324,13 +459,23 @@ def _generate_and_score_completions( self._dispatch_self_distillation_callback( "on_teacher_context_built", - teacher_input_ids=output["teacher_input_ids"], - teacher_attention_mask=output["teacher_attention_mask"], - completion_mask=output["completion_mask"], - self_distillation_mask=output["self_distillation_mask"], + teacher_input_ids=teacher_context["teacher_input_ids"], + teacher_attention_mask=teacher_context["teacher_attention_mask"], + completion_mask=rollout_batch.completion_mask, + self_distillation_mask=teacher_context["self_distillation_mask"], ) - return output + batch = rollout_batch.as_dict() + batch.update( + { + "teacher_input_ids": teacher_context["teacher_input_ids"], + "teacher_attention_mask": teacher_context["teacher_attention_mask"], + "self_distillation_mask": teacher_context["self_distillation_mask"], + "rewards": local_rewards, + "advantages": local_advantages, + } + ) + return batch def _warn_on_inactive_self_distillation(self, mode: str) -> None: metrics = self.teacher_context_builder.last_metrics @@ -365,23 +510,154 @@ def _warn_on_inactive_self_distillation(self, mode: str) -> None: else: self._diagnostic_counters[mode]["no_successful_rollouts"] = 0 - def _compute_loss( + def _record_reward_diagnostics( + self, + mode: str, + rewards: torch.Tensor, + rewards_per_func: torch.Tensor, + group_std_rewards: torch.Tensor, + ) -> None: + tolerance = self.args.diagnostics_flat_tolerance + + reward_mean = rewards.mean() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + reward_std = rewards.std() if rewards.numel() > 1 else torch.tensor(0.0, device=self.accelerator.device) + reward_min = rewards.min() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + reward_max = rewards.max() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + flat_group_fraction = ( + (group_std_rewards <= tolerance).float().mean() + if group_std_rewards.numel() > 0 + else torch.tensor(1.0, device=self.accelerator.device) + ) + + self._metrics[mode]["self_distillation/reward_mean"].append(self.accelerator.gather(reward_mean).mean().item()) + self._metrics[mode]["self_distillation/reward_std"].append(self.accelerator.gather(reward_std).mean().item()) + self._metrics[mode]["self_distillation/reward_min"].append(self.accelerator.gather(reward_min).min().item()) + self._metrics[mode]["self_distillation/reward_max"].append(self.accelerator.gather(reward_max).max().item()) + self._metrics[mode]["self_distillation/group_reward_std_mean"].append( + self.accelerator.gather(group_std_rewards.mean() if group_std_rewards.numel() > 0 else reward_std) + .mean() + .item() + ) + self._metrics[mode]["self_distillation/flat_group_fraction"].append( + self.accelerator.gather(flat_group_fraction).mean().item() + ) + + if rewards_per_func.numel() > 0: + reward_func_means = rewards_per_func.nanmean(dim=0) + gathered_means = self.accelerator.gather(reward_func_means).view(-1, reward_func_means.numel()).mean(dim=0) + for reward_name, reward_func_mean in zip(self.reward_func_names, gathered_means.tolist(), strict=True): + self._metrics[mode][f"self_distillation/rewards/{reward_name}"].append(reward_func_mean) + + reward_is_flat = reward_std.item() <= tolerance + grouped_rewards_are_flat = flat_group_fraction.item() >= 1.0 - tolerance + if reward_is_flat and grouped_rewards_are_flat: + self._warn_on_degenerate_diagnostics( + mode=mode, + counter_key="flat_rewards", + message=( + "Observed flat SDPO rewards across all sampled generations. " + "Policy advantages will collapse to zero, and SDPO will not learn. " + "Check reward density, reward shaping, or `success_reward_threshold`." + ), + ) + else: + self._diagnostic_counters[mode]["flat_rewards"] = 0 + + def _warn_on_degenerate_diagnostics(self, mode: str, counter_key: str, message: str) -> None: + interval = self.args.diagnostics_warning_interval + if interval == 0: + return + + self._diagnostic_counters[mode][counter_key] += 1 + count = self._diagnostic_counters[mode][counter_key] + if count == 1 or count % interval == 0: + logger.warning("%s Consecutive degenerate steps: %s.", message, count) + + def _compute_policy_loss( + self, + inputs, + student_logits, + ) -> torch.Tensor: + completion_ids = inputs["completion_ids"] + completion_mask = inputs["completion_mask"] + per_token_logps = select_token_log_probs(student_logits, completion_ids) + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + advantages = inputs["advantages"] + if advantages.dim() == 1: + advantages = advantages.unsqueeze(1) + log_ratio = per_token_logps - old_per_token_logps + if self.importance_sampling_level == "sequence": + log_ratio = (log_ratio * completion_mask).sum(-1, keepdim=True) / completion_mask.sum( + -1, keepdim=True + ).clamp(min=1.0) + coef_1 = torch.exp(log_ratio) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + per_token_loss = -torch.min(coef_1 * advantages, coef_2 * advantages) + + loss = aggregate_loss( + per_token_loss, + completion_mask, + loss_type=self.loss_type, + max_completion_length=self.max_completion_length, + ) + + mode = "train" if self.model.training else "eval" + self._metrics[mode]["self_distillation/policy_loss"].append( + self.accelerator.gather(loss.detach()).mean().item() + ) + + accumulation_scale = self.current_gradient_accumulation_steps if mode == "train" else 1.0 + return loss / accumulation_scale + + def _compute_weighted_self_distillation_loss( self, model, inputs, + distillation_logits: DistillationLogits, ) -> torch.Tensor: accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 + distillation_loss = ( + self._compute_self_distillation_loss( + model, + inputs, + distillation_logits, + ) + / accumulation_scale + ) + return self.args.distillation_weight * distillation_loss + + def _compute_hybrid_loss(self, model, inputs) -> torch.Tensor: + distillation_logits = self._compute_teacher_student_logits(model, self.teacher_model, inputs) + policy_loss = self._compute_policy_loss(inputs, distillation_logits.student_logits) + weighted_distillation_loss = self._compute_weighted_self_distillation_loss( + model, + inputs, + distillation_logits, + ) + return policy_loss + weighted_distillation_loss - if self.args.sdpo_policy_loss_mode == "hybrid": - base_policy_loss = super()._compute_loss(model, inputs) - if self.args.distillation_weight <= 0.0: - return base_policy_loss - - sdpo_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale - return base_policy_loss + self.args.distillation_weight * sdpo_loss + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The SDPOTrainer does not support returning outputs") - if self.args.distillation_weight <= 0.0: - return super()._compute_loss(model, inputs) + if self.args.sdpo_policy_loss_mode == "hybrid": + return self._compute_hybrid_loss(model, inputs) + if self.args.sdpo_policy_loss_mode == "distillation_only": + distillation_logits = self._compute_teacher_student_logits(model, self.teacher_model, inputs) + return self._compute_weighted_self_distillation_loss(model, inputs, distillation_logits) + if self.args.sdpo_policy_loss_mode == "policy_only": + student_logits = self._compute_student_distillation_logits( + model=model, + prompt_ids=inputs["prompt_ids"], + prompt_mask=inputs["prompt_mask"], + completion_ids=inputs["completion_ids"], + completion_mask=inputs["completion_mask"], + logits_to_keep=inputs["completion_ids"].size(1), + ) + return self._compute_policy_loss(inputs, student_logits) - sdpo_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale - return self.args.distillation_weight * sdpo_loss + raise ValueError( + "Unsupported `sdpo_policy_loss_mode`: " + f"{self.args.sdpo_policy_loss_mode!r}. Expected one of: 'hybrid', 'distillation_only', 'policy_only'." + ) diff --git a/trl/experimental/self_distillation/__init__.py b/trl/experimental/self_distillation/__init__.py index 1449db2f7a3..f5d94b41edb 100644 --- a/trl/experimental/self_distillation/__init__.py +++ b/trl/experimental/self_distillation/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. from .self_distillation_config import SelfDistillationConfig -from .self_distillation_mixin import SelfDistillationMixin -__all__ = ["SelfDistillationConfig", "SelfDistillationMixin"] +__all__ = ["SelfDistillationConfig"] diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index bd9abb95164..95fe72669c3 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -12,29 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Shared online self-distillation trainer scaffold. - -This base combines the generic Trainer setup for self-distillation with the online rollout utilities used by SDPO-like -methods. Offline methods such as SDFT stay on `_BaseTrainer` directly and only reuse the shared distillation mixin. -""" - from __future__ import annotations +import copy import inspect +from abc import ABC, abstractmethod from collections import defaultdict +from contextlib import nullcontext +from dataclasses import dataclass from functools import partial from typing import Any import datasets import torch from accelerate.logging import get_logger +from accelerate.utils import is_peft_model from datasets import Dataset, IterableDataset from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data import DataLoader, Sampler from transformers import ( - AutoModelForSequenceClassification, AutoProcessor, - AutoTokenizer, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase, @@ -44,7 +42,8 @@ from transformers.trainer_utils import seed_worker from transformers.utils import is_datasets_available, is_peft_available -from ...models import prepare_deepspeed, prepare_fsdp +from ...data_utils import is_conversational +from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation from ...trainer.base_trainer import _BaseTrainer from ...trainer.utils import ( RepeatSampler, @@ -52,12 +51,21 @@ disable_dropout_in_model, get_config_model_id, identity, + pad, split_tensor_dict, + use_adapter, ) from ..utils import prepare_peft_model -from .online_rollout_mixin import OnlineRolloutMixin +from .loss_utils import ( + aggregate_loss, + apply_importance_sampling_clipping, + compute_full_logit_self_distillation_loss, + compute_sampled_token_self_distillation_loss, + compute_topk_self_distillation_loss, + select_token_log_probs, +) from .self_distillation_config import SelfDistillationConfig -from .self_distillation_mixin import SelfDistillationMixin +from .teacher_sync import PEFTAdapterEMACallback, SyncTeacherModelCallback if is_peft_available(): @@ -67,18 +75,59 @@ logger = get_logger(__name__) -class BaseSelfDistillationTrainer(OnlineRolloutMixin, SelfDistillationMixin, _BaseTrainer): - """Shared scaffold for experimental self-distillation trainers without GRPO inheritance.""" +@dataclass +class RolloutBatch: + """Student-side rollout produced by `sample_rollouts`, consumed by `finalize_batch` to form a `TrainingBatch`.""" + + prompt_ids: torch.Tensor + prompt_mask: torch.Tensor + completion_ids: torch.Tensor + completion_mask: torch.Tensor + old_per_token_logps: torch.Tensor | None = None + raw_completion_lengths: torch.Tensor | None = None + + def as_dict(self) -> dict[str, torch.Tensor | Any]: + batch: dict[str, torch.Tensor | Any] = { + "prompt_ids": self.prompt_ids, + "prompt_mask": self.prompt_mask, + "completion_ids": self.completion_ids, + "completion_mask": self.completion_mask, + } + if self.old_per_token_logps is not None: + batch["old_per_token_logps"] = self.old_per_token_logps + if self.raw_completion_lengths is not None: + batch["raw_completion_lengths"] = self.raw_completion_lengths + return batch + + +TrainingBatch = dict[str, torch.Tensor | Any] + + +@dataclass +class DistillationLogits: + """Aligned logits and masks used to compute a self-distillation objective.""" + + completion_ids: torch.Tensor + completion_mask: torch.Tensor + response_mask: torch.Tensor + student_logits: torch.Tensor + teacher_logits: torch.Tensor + + +class BaseSelfDistillationTrainer(_BaseTrainer, ABC): + """Base that centralizes shared self-distillation trainer lifecycle.""" + + config_cls = SelfDistillationConfig + _tag_names = ["trl", "self-distillation"] + _name = "Self-Distillation" def __init__( self, model: str | PreTrainedModel | nn.Module, - reward_funcs: Any | list[Any] | None = None, args: SelfDistillationConfig | None = None, train_dataset: Dataset | IterableDataset | None = None, eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, - reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), peft_config: PeftConfig | None = None, @@ -104,7 +153,20 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) - if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): + if peft_config is None and getattr(model, "peft_config", None) is not None: + logger.warning( + "The provided self-distillation student model already contains a PEFT adapter. " + "This setup is accepted but not directly supported. In particular, `teacher_model_kind='base'` " + "may refer to the underlying base weights rather than the exact initially loaded student state " + "including its adapter. For unambiguous teacher behavior, start from a merged/non-adapter model " + "or manage separate adapters explicitly." + ) + if is_peft_model(model) and peft_config is not None: + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config`. Pass either a base " + "model with `peft_config`, or a pre-wrapped PEFT model." + ) + if peft_config is not None or getattr(model, "peft_config", None) is not None: model = prepare_peft_model(model, peft_config, args) if processing_class is None: @@ -124,7 +186,6 @@ def __init__( self.pad_token_id = tokenizer.pad_token_id self.eos_token_id = tokenizer.eos_token_id - self.temperature = args.temperature self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length self.num_generations = args.num_generations @@ -132,12 +193,8 @@ def __init__( self.num_iterations = args.num_iterations self.shuffle_dataset = args.shuffle_dataset self.loss_type = args.loss_type - self.importance_sampling_level = args.importance_sampling_level - self.scale_rewards = args.scale_rewards - self.epsilon_low = args.epsilon - self.epsilon_high = args.epsilon_high - self.beta = args.beta self.mask_truncated_completions = args.mask_truncated_completions + self.temperature = args.temperature self.chat_template_kwargs = args.chat_template_kwargs or {} self._step = 0 self._last_loaded_step = 0 @@ -148,7 +205,7 @@ def __init__( "eval": defaultdict(int), } - generation_kwargs = { + self.generation_kwargs = { "max_new_tokens": self.max_completion_length, "do_sample": True, "pad_token_id": tokenizer.pad_token_id, @@ -162,8 +219,8 @@ def __init__( "cache_implementation": args.cache_implementation, } if args.generation_kwargs is not None: - generation_kwargs.update(args.generation_kwargs) - self.generation_config = GenerationConfig(**generation_kwargs, disable_compile=True) + self.generation_kwargs.update(args.generation_kwargs) + self.generation_config = GenerationConfig(**self.generation_kwargs, disable_compile=True) if hasattr(model, "warnings_issued"): model.warnings_issued["estimate_tokens"] = True @@ -211,75 +268,113 @@ def __init__( logprobs=None, generation_kwargs=args.generation_kwargs, ) - self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation - - if reward_funcs is None: - reward_funcs = [] - if not isinstance(reward_funcs, list): - reward_funcs = [reward_funcs] - self.reward_func_names = [] - for i, reward_func in enumerate(reward_funcs): - if isinstance(reward_func, str): - reward_model_init_kwargs = args.model_init_kwargs or {} - if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: - reward_model_init_kwargs["device_map"] = None - reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( - reward_func, - num_labels=1, - **reward_model_init_kwargs, - ) - if isinstance(reward_funcs[i], nn.Module): - self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) - else: - self.reward_func_names.append(reward_funcs[i].__name__) - self.reward_funcs = reward_funcs - - if args.reward_weights is not None: - if len(args.reward_weights) != len(self.reward_funcs): - raise ValueError("Number of reward weights must match number of reward functions") - self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) - else: - self.reward_weights = torch.ones(len(self.reward_funcs), dtype=torch.float32) - - if reward_processing_classes is None: - reward_processing_classes = [None] * len(self.reward_funcs) - elif not isinstance(reward_processing_classes, list): - reward_processing_classes = [reward_processing_classes] - if len(reward_processing_classes) != len(self.reward_funcs): - raise ValueError("Number of reward processing classes must match number of reward functions") - - for i, (reward_processing_class, reward_func) in enumerate( - zip(reward_processing_classes, self.reward_funcs, strict=True) - ): - if isinstance(reward_func, PreTrainedModel): - if reward_processing_class is None: - reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config)) - if reward_processing_class.pad_token_id is None: - reward_processing_class.pad_token = reward_processing_class.eos_token - reward_func.config.pad_token_id = reward_processing_class.pad_token_id - reward_processing_classes[i] = reward_processing_class - self.reward_processing_classes = reward_processing_classes + self._last_loaded_step = -1 if args.disable_dropout: disable_dropout_in_model(self.model) - for i, reward_func in enumerate(self.reward_funcs): - if isinstance(reward_func, nn.Module): - if self.is_deepspeed_enabled: - self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) - elif self.is_fsdp_enabled: - self.reward_funcs[i] = prepare_fsdp(reward_func, self.accelerator) - else: - self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) - self.model.add_model_tags(self._tag_names) + self._setup_teacher_model() self.model_accepts_loss_kwargs = False - self.ref_model = None - self.teacher_model = None - if args.sync_ref_model: - raise ValueError( - "sync_ref_model is not supported on the shared online self-distillation base without `ref_model`." + + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + self._signature_columns = ["prompt", "privileged_context"] + + def _dispatch_self_distillation_callback(self, event_name: str, **payload) -> None: + for callback in self.callback_handler.callbacks: + callback_fn = getattr(callback, event_name, None) + if callback_fn is not None: + callback_fn( + args=self.args, + state=self.state, + control=self.control, + model=self.model, + processing_class=self.processing_class, + **payload, + ) + + def _setup_teacher_model(self) -> None: + """Prepare teacher state according to the semantic teacher choice. + + Resolve `teacher_model_kind` × PEFT state into the effective teacher: + + - `"live"` (any model): + Teacher is the student. No divergence, no callback. + - `"base"` + PEFT model: + Teacher reuses `self.model`; the base weights are recovered downstream by disabling the adapter + via `use_adapter` during teacher forward. + - `"base"` + non-PEFT model: + Teacher is a frozen deepcopy of the initial student (falls through to the copy branch below). + - `"ema"` + pure-LoRA training: + Teacher reuses `self.model`; a dedicated `"teacher"` LoRA adapter is attached and updated by + `PEFTAdapterEMACallback`. Teacher forward switches to that adapter downstream. + - `"ema"` (otherwise): + Teacher is a frozen deepcopy synchronized each step by `SyncTeacherModelCallback`. + + Must be called after `super().__init__` so that `self.callback_handler` is available. + """ + + teacher_model_kind = self.args.teacher_model_kind + + if teacher_model_kind == "live": + self.teacher_model = self.model + return + + if teacher_model_kind == "base" and is_peft_model(self.model): + self.teacher_model = self.model + return + + if self._use_peft_ema_teacher_adapter(): + # Must run after super().__init__ so self.callback_handler exists. + self.add_callback( + PEFTAdapterEMACallback( + model=self.model, + teacher_adapter_name="teacher", + update_rate=self.args.teacher_update_rate, + sync_steps=self.args.teacher_sync_steps, + accelerator=self.accelerator, + ) ) + self.teacher_model = self.model + return + + # create teacher model from student copy + student_model = self.accelerator.unwrap_model(self.model) + self.teacher_model = copy.deepcopy(student_model) + self.teacher_model.requires_grad_(False) + self.teacher_model.eval() + if self.is_deepspeed_enabled: + self.teacher_model = prepare_deepspeed(self.teacher_model, self.accelerator) + elif self.is_fsdp_enabled: + self.teacher_model = prepare_fsdp(self.teacher_model, self.accelerator) + else: + self.teacher_model = self.accelerator.prepare_model(self.teacher_model, evaluation_mode=True) + + if teacher_model_kind == "ema": + self.add_callback(SyncTeacherModelCallback(teacher_model=self.teacher_model, accelerator=self.accelerator)) + + def _use_peft_ema_teacher_adapter(self) -> bool: + return self.args.teacher_model_kind == "ema" and self._is_pure_lora_training() + + def _is_pure_lora_training(self) -> bool: + """Return `True` when the active adapter is LoRA and every trainable parameter is a LoRA parameter.""" + if not is_peft_model(self.model): + return False + + model = self.accelerator.unwrap_model(self.model) + adapter_name = getattr(model, "active_adapter", None) or "default" + adapter_config = model.peft_config.get(adapter_name) + peft_type = getattr(adapter_config, "peft_type", None) + if peft_type is None or str(peft_type).split(".")[-1] != "LORA": + return False + + for name, param in model.named_parameters(): + if param.requires_grad and "lora_" not in name: + return False + return True def get_train_dataloader(self): if self.train_dataset is None: @@ -323,7 +418,7 @@ def _get_train_sampler(self, dataset=None) -> Sampler: def _get_eval_sampler(self, eval_dataset) -> Sampler: return RepeatSampler( data_source=eval_dataset, - mini_repeat_count=getattr(self, "num_generations_eval", self.num_generations), + mini_repeat_count=self.num_generations_eval, seed=self.args.seed, ) @@ -332,24 +427,439 @@ def training_step(self, model, inputs, num_items_in_batch): self._step += 1 return output + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): + if not isinstance(inputs, dict): + inputs = self._prepare_inputs(inputs) + with torch.no_grad(): + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + return loss.detach(), None, None + def _prepare_inputs(self, generation_batch): + """Return the per-step training batch, regenerating rollouts and buffering them for reuse in train mode. + + In train mode, rollouts are generated once every `steps_per_generation * num_iterations` steps and split + into per-step slices reused until the next regeneration. In eval mode, every batch is freshly prepared. + """ mode = "train" if self.model.training else "eval" if mode == "train": generate_every = self.args.steps_per_generation * self.num_iterations if self._step % generate_every == 0 or self._buffered_inputs is None: - generation_batch = self._build_buffered_batch(generation_batch) - self._buffered_inputs = split_tensor_dict(generation_batch, self.args.steps_per_generation) + buffered_batch = self._prepare_training_batch(generation_batch) + self._buffered_inputs = split_tensor_dict(buffered_batch, self.args.steps_per_generation) self._dispatch_self_distillation_callback( "on_generation_batch_built", generate_every=generate_every, steps_per_generation=self.args.steps_per_generation, ) return self._buffered_inputs[self._step % self.args.steps_per_generation] - return self._build_buffered_batch(generation_batch) + return self._prepare_training_batch(generation_batch) + + def _prepare_training_batch(self, inputs: list[dict[str, Any]]) -> TrainingBatch: + """Sample student rollouts and let the subclass finalize them into the final `TrainingBatch`.""" + rollout_batch = self.sample_rollouts(inputs) + + batch = self.finalize_batch(inputs, rollout_batch) + + self._dispatch_self_distillation_callback( + "on_self_distillation_batch_prepared", + old_per_token_logps=batch.get("old_per_token_logps"), + prompt_ids=batch["prompt_ids"], + completion_ids=batch["completion_ids"], + teacher_input_ids=batch["teacher_input_ids"], + teacher_attention_mask=batch["teacher_attention_mask"], + self_distillation_mask=batch.get("self_distillation_mask"), + ) + return batch + + def _tokenize_prompts_untruncated(self, prompts: list[Any]) -> list[list[int]]: + if is_conversational({"prompt": prompts[0]}): + tokenized = self.processing_class.apply_chat_template( + conversation=prompts, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + **self.chat_template_kwargs, + ) + prompt_ids = tokenized["input_ids"] + else: + prompt_ids = self.processing_class(text=prompts)["input_ids"] + return prompt_ids + + def _tokenize_prompts(self, prompts: list[Any]) -> list[list[int]]: + prompt_ids = self._tokenize_prompts_untruncated(prompts) + if self.max_prompt_length is not None: + prompt_ids = [ids[-self.max_prompt_length :] for ids in prompt_ids] + return prompt_ids + + def sample_rollouts(self, inputs: list[dict[str, Any]]) -> RolloutBatch: + prompts, _ = self._split_prompt_and_privileged_context(inputs) + prompt_ids = self._tokenize_prompts(prompts) + self._dispatch_self_distillation_callback( + "on_generation_prompts_selected", + generation_prompts=prompts, + generation_prompt_text=None, + ) - def _prepare_auxiliary_model_for_eval(self, aux_model: nn.Module): - if self.is_deepspeed_enabled: - return prepare_deepspeed(aux_model, self.accelerator) - if self.is_fsdp_enabled: - return prepare_fsdp(aux_model, self.accelerator) - return self.accelerator.prepare_model(aux_model, evaluation_mode=True) + prompt_ids_list, completion_ids_list = self._generate(prompt_ids) + device = self.accelerator.device + prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left").to(device=device) + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left").to(device=device) + completion_ids = [torch.tensor(ids) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right").to(device=device) + completion_mask = pad(completion_mask, padding_value=0, padding_side="right").to(device=device) + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() + old_per_token_logps = self._compute_rollout_logps( + prompt_ids=prompt_ids, + prompt_mask=prompt_mask, + completion_ids=completion_ids, + completion_mask=completion_mask, + ) + + return RolloutBatch( + prompt_ids=prompt_ids, + prompt_mask=prompt_mask, + completion_ids=completion_ids, + completion_mask=completion_mask, + old_per_token_logps=old_per_token_logps, + raw_completion_lengths=torch.tensor( + [len(ids) for ids in completion_ids_list], device=device, dtype=torch.long + ), + ) + + def _split_prompt_and_privileged_context(self, inputs: list[dict[str, Any]]) -> tuple[list[Any], list[Any]]: + prompts = [example["prompt"] for example in inputs] + privileged_contexts = [example.get("privileged_context") for example in inputs] + return prompts, privileged_contexts + + def _generate(self, prompt_ids: list[list[int]]) -> tuple[list[list[int]], list[list[int]]]: + if self.use_vllm: + return self._generate_vllm(prompt_ids) + return self._generate_transformers(prompt_ids) + + def _generate_vllm(self, prompt_ids: list[list[int]]) -> tuple[list[list[int]], list[list[int]]]: + if self.state.global_step != self._last_loaded_step: + self.vllm_generation.sync_weights() + self._last_loaded_step = self.state.global_step + + mode = "train" if self.model.training else "eval" + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + prompt_ids_out, completion_ids_list, _, _ = self.vllm_generation.generate( + prompts=prompt_ids, + images=None, + num_generations=num_generations, + ) + return prompt_ids_out, completion_ids_list + + def _generate_transformers(self, prompt_ids: list[list[int]]) -> tuple[list[list[int]], list[list[int]]]: + device = self.accelerator.device + prompt_tensors = [torch.tensor(ids) for ids in prompt_ids] + padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left") + attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left") + generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask} + generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) + + with ( + unwrap_model_for_generation( + self.model_wrapped, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + generation_kwargs=self.generation_kwargs, + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + prompt_completion_ids = unwrapped_model.generate( + **generate_inputs, generation_config=self.generation_config + ) + + prompt_length = generate_inputs["input_ids"].size(1) + completion_ids = prompt_completion_ids[:, prompt_length:] + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + seq_idx = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).int() + completion_ids_list = [ + c[m].tolist() for c, m in zip(completion_ids.cpu(), completion_mask.bool().cpu(), strict=True) + ] + return prompt_ids, completion_ids_list + + def _compute_rollout_logps( + self, + prompt_ids: torch.Tensor, + prompt_mask: torch.Tensor, + completion_ids: torch.Tensor, + completion_mask: torch.Tensor, + ) -> torch.Tensor | None: + generate_every = self.args.steps_per_generation * self.num_iterations + old_per_token_logps = None + + if self.args.gradient_accumulation_steps % generate_every != 0: + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) + with torch.no_grad(): + logits = self._forward_logits( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + ) + old_per_token_logps = select_token_log_probs(logits, completion_ids) + + return old_per_token_logps + + def _compute_self_distillation_loss( + self, + model, + inputs: TrainingBatch, + distillation_logits: DistillationLogits, + ) -> torch.Tensor: + """Compute the per-token distillation loss and aggregate it according to `loss_type`. + + Dispatches between three objectives based on `distillation_mode`: + + - `"topk_logits"`: top-k approximation of the divergence, optionally with a tail bucket for the + remaining probability mass (`distillation_add_tail`). + - `"full_logits"`: full-vocab divergence. + - `"sampled_token"`: token-level (reverse-KL) distillation on sampled `completion_ids`. + + When `distillation_is_clip` is set and `old_per_token_logps` are available, the loss is corrected by a + clipped importance-sampling ratio between the current student and the student at rollout time. + """ + if distillation_logits.response_mask.sum() == 0: + mode = "train" if model.training else "eval" + self._log_self_distillation_metric(mode, 0.0) + # Keep the zero loss attached to the student graph so backward produces zero gradients instead of stopping. + return distillation_logits.student_logits.sum() * 0.0 + + if self.args.distillation_mode == "topk_logits": + if self.args.distillation_topk is None: + raise ValueError("`distillation_mode='topk_logits'` requires `distillation_topk` to be set.") + per_token_loss = compute_topk_self_distillation_loss( + distillation_logits.student_logits, + distillation_logits.teacher_logits, + distillation_topk=self.args.distillation_topk, + distillation_alpha=self.args.distillation_alpha, + distillation_add_tail=self.args.distillation_add_tail, + ) + elif self.args.distillation_mode == "full_logits": + per_token_loss = compute_full_logit_self_distillation_loss( + distillation_logits.student_logits, + distillation_logits.teacher_logits, + distillation_alpha=self.args.distillation_alpha, + ) + elif self.args.distillation_mode == "sampled_token": + per_token_loss = compute_sampled_token_self_distillation_loss( + distillation_logits.student_logits, + distillation_logits.teacher_logits, + distillation_logits.completion_ids, + distillation_alpha=self.args.distillation_alpha, + ) + else: + raise ValueError( + "distillation_mode must be one of: 'sampled_token', 'full_logits', 'topk_logits', " + f"got {self.args.distillation_mode!r}" + ) + + old_per_token_logps = inputs.get("old_per_token_logps") + if self.args.distillation_is_clip is not None and old_per_token_logps is not None: + student_per_token_logps = select_token_log_probs( + distillation_logits.student_logits, + distillation_logits.completion_ids, + ) + per_token_loss = apply_importance_sampling_clipping( + per_token_loss, + student_per_token_logps, + old_per_token_logps, + self.args.distillation_is_clip, + ) + + loss = aggregate_loss( + per_token_loss, + distillation_logits.response_mask, + loss_type=self.loss_type, + max_completion_length=self.max_completion_length, + ) + + mode = "train" if model.training else "eval" + mean_distill_loss = ( + per_token_loss * distillation_logits.response_mask + ).sum() / distillation_logits.response_mask.sum().clamp(min=1.0) + self._log_self_distillation_metric( + mode, + self.accelerator.gather(mean_distill_loss).mean().item(), + ) + return loss + + def _compute_teacher_student_logits( + self, + model, + teacher_model, + inputs: TrainingBatch, + ) -> DistillationLogits: + """Run student and teacher forwards on their respective inputs and pack aligned logits into a `DistillationLogits`. + + The teacher forward runs under the teacher context resolved by `_get_teacher_context_for_self_distillation`. + """ + prompt_ids = inputs["prompt_ids"] + prompt_mask = inputs["prompt_mask"] + completion_ids = inputs["completion_ids"] + completion_mask = inputs["completion_mask"] + logits_to_keep = completion_ids.size(1) + + response_mask = self._build_self_distillation_response_mask( + completion_mask, + inputs.get("self_distillation_mask"), + ) + student_logits = self._compute_student_distillation_logits( + model=model, + prompt_ids=prompt_ids, + prompt_mask=prompt_mask, + completion_ids=completion_ids, + completion_mask=completion_mask, + logits_to_keep=logits_to_keep, + ) + + teacher_logits = self._compute_teacher_distillation_logits( + teacher_model=teacher_model, + teacher_input_ids=inputs["teacher_input_ids"], + teacher_attention_mask=inputs["teacher_attention_mask"], + logits_to_keep=logits_to_keep, + ) + + return DistillationLogits( + completion_ids=completion_ids, + completion_mask=completion_mask, + response_mask=response_mask, + student_logits=student_logits, + teacher_logits=teacher_logits, + ) + + @staticmethod + def _build_self_distillation_response_mask( + completion_mask: torch.Tensor, + self_distillation_mask: torch.Tensor | None, + ) -> torch.Tensor: + if self_distillation_mask is None: + return completion_mask + return completion_mask * self_distillation_mask.unsqueeze(1) + + def _compute_student_distillation_logits( + self, + model, + prompt_ids: torch.Tensor, + prompt_mask: torch.Tensor, + completion_ids: torch.Tensor, + completion_mask: torch.Tensor, + logits_to_keep: int, + ) -> torch.Tensor: + student_input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + student_attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + return self._forward_logits( + model=model, + input_ids=student_input_ids, + attention_mask=student_attention_mask, + logits_to_keep=logits_to_keep, + ) + + def _compute_teacher_distillation_logits( + self, + teacher_model, + teacher_input_ids: torch.Tensor, + teacher_attention_mask: torch.Tensor, + logits_to_keep: int, + ) -> torch.Tensor: + with torch.no_grad(), self._get_teacher_context_for_self_distillation(): + return self._forward_logits( + model=teacher_model, + input_ids=teacher_input_ids, + attention_mask=teacher_attention_mask, + logits_to_keep=logits_to_keep, + ) + + def _forward_logits( + self, + model, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + logits_to_keep: int, + ) -> torch.Tensor: + """Forward the model and return temperature-scaled logits aligned to the completion tokens.""" + model_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "use_cache": False, + } + if "logits_to_keep" in self.model_kwarg_keys: + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + logits = model(**model_inputs).logits + logits = logits[:, :-1, :] + logits = logits[:, -logits_to_keep:, :] + return logits / self.temperature + + @abstractmethod + def finalize_batch( + self, + inputs: list[dict[str, Any]], + rollout_batch: RolloutBatch, + ) -> TrainingBatch: + """Build the final training batch from a shared student rollout batch. + + Subclasses must return a `dict` with at least the following keys, all first-dim-aligned to the student + batch size `B`: + + - `prompt_ids`: student prompt ids. + - `prompt_mask`: student prompt mask. + - `completion_ids`: student completion ids. + - `completion_mask`: student completion mask. + - `teacher_input_ids`: teacher input ids. + - `teacher_attention_mask`: teacher attention mask. + """ + + def _get_teacher_context_for_self_distillation(self): + """Return the context manager that routes the teacher forward to the correct weights. + + For non-PEFT models this is a no-op. For PEFT models: + + - `teacher_model_kind == "base"`: disable the student adapter so the teacher forward uses the base weights. + - `teacher_model_kind == "ema"` under pure-LoRA training: switch to the `"teacher"` LoRA adapter. + - otherwise: no-op; the teacher is a separate deepcopy. + """ + teacher_model_kind = self.args.teacher_model_kind + if not is_peft_model(self.model): + return nullcontext() + + target_model = self.accelerator.unwrap_model(self.teacher_model) + + if teacher_model_kind == "base": + return use_adapter(target_model, adapter_name=None) + if teacher_model_kind == "ema" and self._use_peft_ema_teacher_adapter(): + return use_adapter(target_model, adapter_name="teacher") + return nullcontext() + + def _log_self_distillation_metric(self, mode: str, value: float) -> None: + metric_prefix = self._name.lower().replace(" ", "_") + self._metrics[mode]["self_distillation/distillation_loss"].append(value) + self._metrics[mode][f"{metric_prefix}/distillation_loss"].append(value) + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {k: sum(v) / len(v) for k, v in self._metrics[mode].items() if v} + if mode == "eval": + metrics = {f"eval_{k}": v for k, v in metrics.items()} + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + + @abstractmethod + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """Subclasses own algorithm-specific loss composition on the final batch contract.""" diff --git a/trl/experimental/self_distillation/loss_utils.py b/trl/experimental/self_distillation/loss_utils.py new file mode 100644 index 00000000000..ac74a751099 --- /dev/null +++ b/trl/experimental/self_distillation/loss_utils.py @@ -0,0 +1,178 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pure helper functions for self-distillation loss computation.""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F + + +def select_token_log_probs( + logits: torch.Tensor, + token_ids: torch.Tensor, +) -> torch.Tensor: + logsumexp = torch.logsumexp(logits, dim=-1, keepdim=True) + indices = token_ids.unsqueeze(-1) + return (torch.gather(logits, dim=-1, index=indices) - logsumexp).squeeze(-1) + + +def compute_divergence( + student_log_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, + alpha: float, +) -> torch.Tensor: + if alpha == 0.0: + kl = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) + elif alpha == 1.0: + kl = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) + else: + alpha_t = torch.tensor(alpha, dtype=student_log_probs.dtype, device=student_log_probs.device) + mixture = torch.logsumexp( + torch.stack([student_log_probs + torch.log(1 - alpha_t), teacher_log_probs + torch.log(alpha_t)]), + dim=0, + ) + kl_teacher = F.kl_div(mixture, teacher_log_probs, reduction="none", log_target=True) + kl_student = F.kl_div(mixture, student_log_probs, reduction="none", log_target=True) + kl = torch.lerp(kl_student, kl_teacher, alpha) + return kl.sum(-1) + + +def add_tail(log_probs: torch.Tensor) -> torch.Tensor: + log_s = torch.logsumexp(log_probs, dim=-1, keepdim=True) + log_s = torch.clamp(log_s, max=-1e-7) + tail_log = torch.log(-torch.expm1(log_s)) + return torch.cat([log_probs, tail_log], dim=-1) + + +def renorm_topk_log_probs(log_probs: torch.Tensor) -> torch.Tensor: + return log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True) + + +def compute_token_level_distillation_loss( + student_log_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, +) -> torch.Tensor: + log_ratio = student_log_probs - teacher_log_probs + return log_ratio.detach() * student_log_probs + + +def apply_importance_sampling_clipping( + per_token_loss: torch.Tensor, + student_log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + clip_coeff: float, +) -> torch.Tensor: + negative_approx_kl = (student_log_probs - old_log_probs).detach() + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl).clamp(max=clip_coeff) + return per_token_loss * ratio + + +def aggregate_loss( + per_token_loss: torch.Tensor, + response_mask: torch.Tensor, + *, + loss_type: str, + max_completion_length: int, +) -> torch.Tensor: + """Reduce a per-token loss tensor according to the configured reduction. + + Args: + per_token_loss: + Per-token loss values of shape `(batch_size, seq_len)`. + response_mask: + Mask selecting which completion-token positions contribute to the loss. + loss_type: + Reduction mode. Uses the same loss-type conventions as the GRPO-family trainers. + max_completion_length: + Used by the `dr_grpo` reduction, which normalizes by a fixed completion budget. + """ + if loss_type == "grpo": + loss = (per_token_loss * response_mask).sum(-1) / response_mask.sum(-1).clamp(min=1.0) + return loss.mean() + if loss_type == "bnpo": + return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + if loss_type == "dr_grpo": + return (per_token_loss * response_mask).sum() / (per_token_loss.size(0) * max_completion_length) + if loss_type in ["dapo", "luspo", "cispo", "sapo"]: + return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + raise ValueError(f"Unsupported loss_type: {loss_type}") + + +def compute_topk_self_distillation_loss( + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + *, + distillation_topk: int, + distillation_alpha: float, + distillation_add_tail: bool, +) -> torch.Tensor: + """Compute distillation loss on the student's top-k token support. + + The student's top-k logits define the support. The teacher distribution is projected onto the same token indices. + The selected support is then either renormalized or augmented with a tail bucket before the divergence is computed. + """ + student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) + topk_student_logits, topk_indices = torch.topk(student_logits, k=distillation_topk, dim=-1) + topk_student_log_probs = topk_student_logits - student_logsumexp + + teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True) + topk_teacher_logits = torch.gather(teacher_logits, dim=-1, index=topk_indices) + topk_teacher_log_probs = topk_teacher_logits - teacher_logsumexp + + if distillation_add_tail: + topk_student_log_probs = add_tail(topk_student_log_probs) + topk_teacher_log_probs = add_tail(topk_teacher_log_probs) + else: + topk_student_log_probs = renorm_topk_log_probs(topk_student_log_probs) + topk_teacher_log_probs = renorm_topk_log_probs(topk_teacher_log_probs) + + return compute_divergence(topk_student_log_probs, topk_teacher_log_probs, distillation_alpha) + + +def compute_full_logit_self_distillation_loss( + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + *, + distillation_alpha: float, +) -> torch.Tensor: + """Compute full-vocabulary self-distillation loss between student and teacher logits.""" + student_log_probs = torch.log_softmax(student_logits, dim=-1) + teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1) + return compute_divergence(student_log_probs, teacher_log_probs, distillation_alpha) + + +def compute_sampled_token_self_distillation_loss( + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + completion_ids: torch.Tensor, + *, + distillation_alpha: float, +) -> torch.Tensor: + """Compute token-level self-distillation loss only on the sampled completion tokens. + + This path compares student and teacher log-probabilities on the realized completion tokens rather than over a + larger token support. + """ + if distillation_alpha != 1.0: + raise ValueError( + "Only reverse KL (alpha=1.0) is supported for token-level distillation when " + f"`distillation_mode='sampled_token'`, got alpha={distillation_alpha}" + ) + + student_per_token_logps = select_token_log_probs(student_logits, completion_ids) + teacher_per_token_logps = select_token_log_probs(teacher_logits, completion_ids) + return compute_token_level_distillation_loss(student_per_token_logps, teacher_per_token_logps) diff --git a/trl/experimental/self_distillation/online_rollout_mixin.py b/trl/experimental/self_distillation/online_rollout_mixin.py deleted file mode 100644 index 490724582dc..00000000000 --- a/trl/experimental/self_distillation/online_rollout_mixin.py +++ /dev/null @@ -1,385 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Online rollout helpers for experimental self-distillation trainers. - -This mixin owns generation, reward scoring, grouped reward normalization, and online policy-loss plumbing. It is paired -with `BaseSelfDistillationTrainer` for SDPO-style methods and intentionally kept separate from the generic distillation -loss logic in `self_distillation_mixin.py`. -""" - -from __future__ import annotations - -import torch -from torch import nn -from transformers.utils import logging - -from ...data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template -from ...models import unwrap_model_for_generation -from ...trainer.base_trainer import _BaseTrainer -from ...trainer.utils import pad - - -logger = logging.get_logger(__name__) - - -class OnlineRolloutMixin: - """Online rollout, reward, and policy-loss utilities shared by SDPO-like trainers.""" - - def _apply_prompt_template(self, prompts): - return [ - maybe_apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"] - for prompt in prompts - ] - - def _build_buffered_batch(self, generation_batch): - return self._generate_and_score_completions(generation_batch) - - def _generate(self, prompts): - if self.use_vllm: - return self._generate_vllm(prompts) - return self._generate_transformers(prompts) - - def _generate_vllm(self, prompts): - # Sync weights if training step changed - if self.state.global_step != self._last_loaded_step: - self.vllm_generation.sync_weights() - self._last_loaded_step = self.state.global_step - - # Tokenize prompts to token IDs - prompts_text = self._apply_prompt_template(prompts) - tokenized = self.processing_class( - text=prompts_text, - return_tensors=None, - padding=False, - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - prompt_ids = tokenized["input_ids"] # list of list[int] - - # Generate via vLLM — it deduplicates repeated prompts from RepeatSampler internally - mode = "train" if self.model.training else "eval" - num_generations = self.num_generations if mode == "train" else self.num_generations_eval - prompt_ids_out, completion_ids_list, _, _ = self.vllm_generation.generate( - prompts=prompt_ids, - images=None, - num_generations=num_generations, - ) - return prompt_ids_out, completion_ids_list - - def _generate_transformers(self, prompts): - # Keep the generation path aligned with the reference trainers: generate from left-padded prompts, - # then recover completion token spans by trimming prompt tokens and stopping at the first EOS. - prompts_text = self._apply_prompt_template(prompts) - generate_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - # This path already receives tokenized model inputs. Bypass the buffered trainer hook and use the plain - # tensor/device preparation from `_BaseTrainer`. - generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) - with ( - unwrap_model_for_generation( - self.model_wrapped, - self.accelerator, - gather_deepspeed3_params=self.args.ds3_gather_for_generation, - ) as unwrapped_model, - torch.no_grad(), - ): - prompt_completion_ids = unwrapped_model.generate( - **generate_inputs, generation_config=self.generation_config - ) - prompt_ids = generate_inputs["input_ids"] - prompt_mask = generate_inputs["attention_mask"] - prompt_length = prompt_ids.size(1) - completion_ids = prompt_completion_ids[:, prompt_length:] - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) - completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).int() - prompt_ids_list = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=False)] - completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=False)] - return prompt_ids_list, completion_ids_list - - def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): - device = self.accelerator.device - if len(self.reward_funcs) == 0: - return torch.zeros((len(prompts), 0), device=device) - - rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) - keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] - reward_kwargs = {key: [example[key] for example in inputs] for key in keys} - reward_kwargs["trainer_state"] = self.state - - for i, (reward_func, reward_processing_class) in enumerate( - zip(self.reward_funcs, self.reward_processing_classes, strict=True) - ): - if isinstance(reward_func, nn.Module): - if is_conversational(inputs[0]): - messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)] - texts = [ - apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"] - for x in messages - ] - else: - texts = [p + c for p, c in zip(prompts, completions, strict=True)] - reward_inputs = reward_processing_class( - text=texts, - return_tensors="pt", - padding=True, - padding_side="right", - add_special_tokens=False, - ) - # Reward functions operate on tokenized tensors too, so they need the base Trainer input preparation - # rather than the outer buffered generation hook. - reward_inputs = _BaseTrainer._prepare_inputs(self, reward_inputs) - with torch.inference_mode(): - rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] - else: - output_reward_func = reward_func( - prompts=prompts, - completions=completions, - completion_ids=completion_ids_list, - **reward_kwargs, - ) - output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] - rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) - - return self.accelerator.gather(rewards_per_func) - - def _generate_and_score_completions(self, inputs): - device = self.accelerator.device - mode = "train" if self.model.training else "eval" - prompts = [x["prompt"] for x in inputs] - prompt_ids_list, completion_ids_list = self._generate(prompts) - - prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] - prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left").to(device=device) - prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left").to(device=device) - completion_ids = [torch.tensor(ids) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right").to(device=device) - completion_mask = pad(completion_mask, padding_value=0, padding_side="right").to(device=device) - - if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) - completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() - - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) - - with torch.no_grad(): - generate_every = self.args.steps_per_generation * self.num_iterations - if self.args.gradient_accumulation_steps % generate_every != 0: - old_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ) - else: - old_per_token_logps = None - - if is_conversational({"prompt": prompts[0]}): - completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - completions = [[{"role": "assistant", "content": content}] for content in completions_text] - else: - completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - - rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) - if rewards_per_func.numel() == 0: - rewards = torch.zeros(self.accelerator.num_processes * len(prompts), device=device) - else: - rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) - num_generations = self.num_generations if mode == "train" else self.num_generations_eval - mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0) - if self.scale_rewards == "batch": - std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) - group_std_rewards = rewards.view(-1, num_generations).std(dim=1) - elif self.scale_rewards == "none": - std_rewards = torch.ones_like(rewards) - group_std_rewards = torch.ones(rewards.numel() // num_generations, device=device, dtype=rewards.dtype) - else: - group_std_rewards = rewards.view(-1, num_generations).std(dim=1) - std_rewards = group_std_rewards.repeat_interleave(num_generations, dim=0) - advantages = (rewards - mean_grouped_rewards) / (std_rewards + 1e-4) - self._record_reward_diagnostics(mode, rewards, rewards_per_func, group_std_rewards) - - local_batch_size = completion_ids.size(0) - process_start = self.accelerator.process_index * local_batch_size - process_slice = slice(process_start, process_start + local_batch_size) - rewards = rewards[process_slice] - advantages = advantages[process_slice] - - agg_completion_lengths = self.accelerator.gather( - torch.tensor([len(ids) for ids in completion_ids_list], device=device) - ) - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) - agg_is_truncated = self.accelerator.gather(is_truncated) - self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) - term_completion_lengths = agg_completion_lengths[~agg_is_truncated] - if len(term_completion_lengths) == 0: - term_completion_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - - output = { - "prompt_ids": prompt_ids, - "prompt_mask": prompt_mask, - "completion_ids": completion_ids, - "completion_mask": completion_mask, - "rewards": rewards, - "advantages": advantages, - "num_items_in_batch": completion_mask.sum().detach(), - } - if old_per_token_logps is not None: - output["old_per_token_logps"] = old_per_token_logps - - self._dispatch_self_distillation_callback( - "on_self_distillation_batch_prepared", - old_per_token_logps=old_per_token_logps, - prompt_ids=prompt_ids, - completion_ids=completion_ids, - ) - return output - - def _record_reward_diagnostics( - self, - mode: str, - rewards: torch.Tensor, - rewards_per_func: torch.Tensor, - group_std_rewards: torch.Tensor, - ) -> None: - tolerance = self.args.diagnostics_flat_tolerance - - reward_mean = rewards.mean() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) - reward_std = rewards.std() if rewards.numel() > 1 else torch.tensor(0.0, device=self.accelerator.device) - reward_min = rewards.min() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) - reward_max = rewards.max() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) - flat_group_fraction = ( - (group_std_rewards <= tolerance).float().mean() - if group_std_rewards.numel() > 0 - else torch.tensor(1.0, device=self.accelerator.device) - ) - - self._metrics[mode]["self_distillation/reward_mean"].append(self.accelerator.gather(reward_mean).mean().item()) - self._metrics[mode]["self_distillation/reward_std"].append(self.accelerator.gather(reward_std).mean().item()) - self._metrics[mode]["self_distillation/reward_min"].append(self.accelerator.gather(reward_min).min().item()) - self._metrics[mode]["self_distillation/reward_max"].append(self.accelerator.gather(reward_max).max().item()) - self._metrics[mode]["self_distillation/group_reward_std_mean"].append( - self.accelerator.gather(group_std_rewards.mean() if group_std_rewards.numel() > 0 else reward_std) - .mean() - .item() - ) - self._metrics[mode]["self_distillation/flat_group_fraction"].append( - self.accelerator.gather(flat_group_fraction).mean().item() - ) - - if rewards_per_func.numel() > 0: - reward_func_means = rewards_per_func.nanmean(dim=0) - gathered_means = self.accelerator.gather(reward_func_means).view(-1, reward_func_means.numel()).mean(dim=0) - for reward_name, reward_func_mean in zip(self.reward_func_names, gathered_means.tolist(), strict=True): - self._metrics[mode][f"self_distillation/rewards/{reward_name}"].append(reward_func_mean) - - reward_is_flat = reward_std.item() <= tolerance - grouped_rewards_are_flat = flat_group_fraction.item() >= 1.0 - tolerance - if reward_is_flat and grouped_rewards_are_flat: - self._warn_on_degenerate_diagnostics( - mode=mode, - counter_key="flat_rewards", - message=( - "Observed flat SDPO rewards across all sampled generations. " - "Policy advantages will collapse to zero, and SDPO will not learn. " - "Check reward density, reward shaping, or `success_reward_threshold`." - ), - ) - else: - self._diagnostic_counters[mode]["flat_rewards"] = 0 - - def _warn_on_degenerate_diagnostics(self, mode: str, counter_key: str, message: str) -> None: - interval = self.args.diagnostics_warning_interval - if interval == 0: - return - - self._diagnostic_counters[mode][counter_key] += 1 - count = self._diagnostic_counters[mode][counter_key] - if count == 1 or count % interval == 0: - logger.warning("%s Consecutive degenerate steps: %s.", message, count) - - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - if return_outputs: - raise ValueError(f"The {self.__class__.__name__} does not support returning outputs") - return self._compute_loss(model, inputs) - - def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): - if not isinstance(inputs, dict): - inputs = self._prepare_inputs(inputs) - with torch.no_grad(): - with self.compute_loss_context_manager(): - loss = self.compute_loss(model, inputs) - return loss.detach(), None, None - - def _compute_loss(self, model, inputs): - prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] - input_ids = torch.cat([prompt_ids, completion_ids], dim=1) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) - per_token_logps, _ = self._get_per_token_logps_and_entropies( - model, - input_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ) - old_per_token_logps = inputs.get("old_per_token_logps") - old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps - advantages = inputs["advantages"] - if advantages.dim() == 1: - advantages = advantages.unsqueeze(1) - log_ratio = per_token_logps - old_per_token_logps - if self.importance_sampling_level == "sequence": - log_ratio = (log_ratio * completion_mask).sum(-1, keepdim=True) / completion_mask.sum( - -1, keepdim=True - ).clamp(min=1.0) - coef_1 = torch.exp(log_ratio) - coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) - per_token_loss = -torch.min(coef_1 * advantages, coef_2 * advantages) - - loss = self._aggregate_self_distillation_loss(per_token_loss, completion_mask) - - mode = "train" if self.model.training else "eval" - self._metrics[mode]["self_distillation/policy_loss"].append( - self.accelerator.gather(loss.detach()).mean().item() - ) - - accumulation_scale = self.current_gradient_accumulation_steps if mode == "train" else 1.0 - return loss / accumulation_scale diff --git a/trl/experimental/self_distillation/prompt_utils.py b/trl/experimental/self_distillation/prompt_utils.py new file mode 100644 index 00000000000..f5decce6d7a --- /dev/null +++ b/trl/experimental/self_distillation/prompt_utils.py @@ -0,0 +1,32 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + + +def extract_last_user_text(messages: list[dict[str, Any]]) -> str: + """Extract the text content from the last user message in a conversational prompt.""" + last_message = messages[-1] + if last_message.get("role") != "user": + raise ValueError( + f"Self-distillation teacher prompt construction expects the conversation to end with a user turn, " + f"but the last message has role '{last_message.get('role')}'. " + f"Prompts ending with assistant prefills or tool turns are not supported." + ) + content = last_message.get("content", "") + if isinstance(content, list): + return " ".join(part.get("text", "") for part in content if part.get("type") == "text") + return content diff --git a/trl/experimental/self_distillation/self_distillation_config.py b/trl/experimental/self_distillation/self_distillation_config.py index c28c0f6be0b..4fe20c00b95 100644 --- a/trl/experimental/self_distillation/self_distillation_config.py +++ b/trl/experimental/self_distillation/self_distillation_config.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Any +from typing import Any, Literal from transformers import TrainingArguments @@ -30,39 +30,120 @@ class SelfDistillationConfig(_BaseConfig): [`trl.trainer.base_config._BaseConfig`]. Parameters: - > Parameters that control generation and rollout reuse + > Parameters that control the model model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments used when the `model` argument is passed as a string. + Keyword arguments for model initialization when the `model` argument is passed as a string. + disable_dropout (`bool`, *optional*, defaults to `False`): + Whether to disable dropout in the student model. + remove_unused_columns (`bool`, *optional*, defaults to `False`): + Whether to drop dataset columns unused by the trainer. + + > Parameters that control generation and rollout reuse + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): Maximum prompt length. Longer prompts are truncated from the left. num_generations (`int`, *optional*, defaults to `8`): Number of sampled generations per prompt. + num_generations_eval (`int` or `None`, *optional*): + Number of sampled generations per prompt during evaluation. + max_completion_length (`int` or `None`, *optional*, defaults to `256`): + Maximum generated completion length. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + Whether to gather ZeRO-3 weights for generation. + shuffle_dataset (`bool`, *optional*, defaults to `True`): + Whether to shuffle the training dataset. generation_batch_size (`int` or `None`, *optional*): Global batch size used for generation. Mutually exclusive with `steps_per_generation`. steps_per_generation (`int` or `None`, *optional*): Number of optimizer steps that reuse one generated batch. Mutually exclusive with `generation_batch_size`. + > Parameters that control sampling + + temperature (`float`, *optional*, defaults to `1.0`): + Sampling temperature. + top_p (`float`, *optional*, defaults to `1.0`): + Top-p sampling parameter. + top_k (`int`, *optional*, defaults to `0`): + Top-k sampling parameter. `0` disables top-k filtering. + min_p (`float` or `None`, *optional*): + Minimum token probability for sampling. + generation_kwargs (`dict[str, Any]` or `None`, *optional*): + Extra generation kwargs passed to `GenerationConfig`. + chat_template_kwargs (`dict[str, Any]` or `None`, *optional*): + Extra kwargs forwarded to chat template application. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Repetition penalty used during generation. + use_transformers_paged (`bool`, *optional*, defaults to `False`): + Reserved for paged generation support. + cache_implementation (`str` or `None`, *optional*): + Cache implementation used by transformers generation. + + > Parameters that control vLLM generation + + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generation. + vllm_mode (`str`, *optional*, defaults to `"colocate"`): + vLLM mode: `"colocate"` (shared GPU) or `"server"` (separate vLLM server). + vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): + Model implementation for vLLM: `"vllm"` or `"transformers"`. + vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): + Whether to enable sleep mode for colocated vLLM engine. + vllm_server_base_url (`str` or `None`, *optional*): + Base URL for the vLLM server. If provided, `vllm_server_host` and `vllm_server_port` are ignored. + vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host of the vLLM server (server mode only). + vllm_server_port (`int`, *optional*, defaults to `8000`): + Port of the vLLM server (server mode only). + vllm_group_port (`int`, *optional*, defaults to `51216`): + Port for the weight update group (server mode only). + vllm_server_timeout (`float`, *optional*, defaults to `240.0`): + Timeout in seconds to wait for the vLLM server. + vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): + Tensor parallel size for colocated vLLM. + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`): + GPU memory utilization ratio for colocated vLLM. + vllm_max_model_length (`int` or `None`, *optional*): + Model context length for vLLM. Inferred from model config if not set. + > Parameters that control the online policy objective - beta (`float`, *optional*, defaults to `0.0`): - Reference-model KL coefficient for online policy optimization. + num_iterations (`int`, *optional*, defaults to `1`): + Number of optimization iterations per generated batch. loss_type (`str`, *optional*, defaults to `"dapo"`): Policy-loss aggregation mode. Supported: `grpo`, `bnpo`, `dr_grpo`, `dapo`. - scale_rewards (`str` or `bool`, *optional*, defaults to `"group"`): - Reward normalization mode. Supported: `group`, `batch`, `none`. + mask_truncated_completions (`bool`, *optional*, defaults to `False`): + Whether to exclude truncated completions from the loss. + top_entropy_quantile (`float`, *optional*, defaults to `1.0`): + Reserved for entropy-based token filtering. + + > Parameters that control teacher construction + + teacher_model_kind (`str`, *optional*, defaults to `"live"`): + Semantic teacher choice. `live` uses the current student, `base` uses the initial student, and `ema` uses + an exponentially averaged teacher. + teacher_update_rate (`float`, *optional*, defaults to `0.6`): + EMA update rate used when `teacher_model_kind="ema"`. A value of `1.0` reduces the update to a hard + overwrite, periodically resyncing the teacher to the current student weights. + teacher_sync_steps (`int`, *optional*, defaults to `512`): + Number of optimizer steps between teacher updates. > Parameters that control self-distillation distillation_alpha (`float`, *optional*, defaults to `0.5`): - Divergence interpolation coefficient using the official SDPO/SDFT convention: `0.0=forward KL`, `0.5=JSD`, + KL divergence direction: `0.0=forward KL`, `0.5=JSD`, `1.0=reverse KL`. - distillation_topk (`int` or `None`, *optional*, defaults to `100`): - Number of top tokens to keep for top-k distillation. If `None`, all logits are used. - full_logit_distillation (`bool`, *optional*, defaults to `False`): - Whether to use full-logit distillation instead of token-level distillation. + distillation_mode (`Literal["sampled_token", "full_logits", "topk_logits"]`, *optional*, defaults to `"sampled_token"`): + Distillation objective mode. `"sampled_token"` uses token-level distillation on the sampled completion + tokens, `"full_logits"` uses full-vocabulary divergence, and `"topk_logits"` uses a top-k approximation + over the student support. + distillation_topk (`int` or `None`, *optional*): + Number of top tokens for `"topk_logits"` distillation. Must be set when + `distillation_mode="topk_logits"` and left unset otherwise. distillation_is_clip (`float` or `None`, *optional*, defaults to `2.0`): - Importance-sampling clip used by the official SDPO-style correction. `None` disables clipping. + Clipping coefficient for importance sampling in self-distillation. `None` disables clipping. + distillation_add_tail (`bool`, *optional*, defaults to `False`): + Whether to add a tail bucket for non-top-k probability mass. distillation_weight (`float`, *optional*, defaults to `1.0`): Weight applied to the self-distillation loss term. @@ -118,7 +199,9 @@ class SelfDistillationConfig(_BaseConfig): ) steps_per_generation: int | None = field( default=None, - metadata={"help": "Number of optimizer steps that reuse one generated batch."}, + metadata={ + "help": "Number of optimizer steps that reuse one generated batch. Mutually exclusive with `generation_batch_size`." + }, ) temperature: float = field( default=1.0, @@ -202,34 +285,10 @@ class SelfDistillationConfig(_BaseConfig): default=None, metadata={"help": "Model context length for vLLM. Inferred from model config if not set."}, ) - beta: float = field( - default=0.0, - metadata={"help": "Reference-model KL coefficient for online policy optimization."}, - ) num_iterations: int = field( default=1, metadata={"help": "Number of optimization iterations per generated batch."}, ) - epsilon: float = field( - default=0.2, - metadata={"help": "Lower clipping coefficient for GRPO-style policy loss."}, - ) - epsilon_high: float | None = field( - default=None, - metadata={"help": "Upper clipping coefficient. Defaults to `epsilon` when unset."}, - ) - importance_sampling_level: str = field( - default="token", - metadata={"help": "Importance-sampling granularity. Supported: `token`, `sequence`."}, - ) - reward_weights: list[float] | None = field( - default=None, - metadata={"help": "Optional weights for multiple reward functions."}, - ) - scale_rewards: str | bool = field( - default="group", - metadata={"help": "Reward normalization mode. Supported: `group`, `batch`, `none`."}, - ) loss_type: str = field( default="dapo", metadata={"help": "Policy loss aggregation. Supported: `grpo`, `bnpo`, `dr_grpo`, `dapo`."}, @@ -238,17 +297,23 @@ class SelfDistillationConfig(_BaseConfig): default=False, metadata={"help": "Whether to exclude truncated completions from the loss."}, ) - sync_ref_model: bool = field( - default=False, - metadata={"help": "Whether to synchronize the reference model with the student model."}, + teacher_model_kind: str = field( + default="live", + metadata={ + "help": "Semantic teacher choice. `live` uses the current student, `base` uses the initial student, " + "and `ema` uses an exponentially averaged teacher." + }, ) - ref_model_mixup_alpha: float = field( + teacher_update_rate: float = field( default=0.6, - metadata={"help": "EMA mix coefficient used when syncing the reference model."}, + metadata={ + "help": 'EMA update rate used when `teacher_model_kind="ema"`. A value of `1.0` reduces the update ' + "to a hard overwrite, periodically resyncing the teacher to the current student weights." + }, ) - ref_model_sync_steps: int = field( + teacher_sync_steps: int = field( default=512, - metadata={"help": "How often to synchronize the reference model."}, + metadata={"help": "Number of optimizer steps between teacher updates."}, ) top_entropy_quantile: float = field( default=1.0, @@ -256,19 +321,28 @@ class SelfDistillationConfig(_BaseConfig): ) distillation_alpha: float = field( default=0.5, - metadata={"help": "KL divergence direction: 0.0=forward KL, 0.5=JSD, 1.0=reverse KL."}, + metadata={"help": "KL divergence direction: `0.0=forward KL`, `0.5=JSD`, `1.0=reverse KL`."}, ) - distillation_topk: int | None = field( - default=100, - metadata={"help": "Number of top tokens for top-k distillation. If None, uses all tokens."}, + distillation_mode: Literal["sampled_token", "full_logits", "topk_logits"] = field( + default="sampled_token", + metadata={ + "help": "Distillation objective mode. `sampled_token` uses token-level distillation on the sampled " + "completion tokens, `full_logits` uses full-vocabulary divergence, and `topk_logits` uses a top-k " + "approximation over the student support." + }, ) - full_logit_distillation: bool = field( - default=False, - metadata={"help": "Whether to use full-logit distillation instead of token-level distillation."}, + distillation_topk: int | None = field( + default=None, + metadata={ + "help": "Number of top tokens for `topk_logits` distillation. Must be set when " + "`distillation_mode='topk_logits'` and left unset otherwise." + }, ) distillation_is_clip: float | None = field( default=2.0, - metadata={"help": "Clipping coefficient for importance sampling in self-distillation."}, + metadata={ + "help": "Clipping coefficient for importance sampling in self-distillation. `None` disables clipping." + }, ) distillation_add_tail: bool = field( default=False, @@ -294,20 +368,27 @@ class SelfDistillationConfig(_BaseConfig): def __post_init__(self): super().__post_init__() - self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards) - if self.scale_rewards not in ["group", "batch", "none"]: - raise ValueError("scale_rewards must be one of: 'group', 'batch', 'none'") - - if self.importance_sampling_level not in ["token", "sequence"]: - raise ValueError("importance_sampling_level must be either 'token' or 'sequence'") if self.loss_type not in ["grpo", "bnpo", "dr_grpo", "dapo"]: raise ValueError("loss_type must be one of: 'grpo', 'bnpo', 'dr_grpo', 'dapo'") + if self.teacher_model_kind not in {"base", "live", "ema"}: + raise ValueError("teacher_model_kind must be one of: 'base', 'live', 'ema'") + if not 0.0 <= self.teacher_update_rate <= 1.0: + raise ValueError("teacher_update_rate must be in [0, 1]") + if self.teacher_sync_steps <= 0: + raise ValueError("teacher_sync_steps must be positive") if self.num_generations < 1: raise ValueError("num_generations must be at least 1") if not 0.0 <= self.distillation_alpha <= 1.0: raise ValueError("distillation_alpha must be in [0, 1]") + if self.distillation_mode not in {"sampled_token", "full_logits", "topk_logits"}: + raise ValueError("distillation_mode must be one of: 'sampled_token', 'full_logits', 'topk_logits'") if self.distillation_topk is not None and self.distillation_topk <= 0: raise ValueError("distillation_topk must be positive when provided") + if self.distillation_mode == "topk_logits": + if self.distillation_topk is None: + raise ValueError("`distillation_mode='topk_logits'` requires `distillation_topk` to be set.") + elif self.distillation_topk is not None: + raise ValueError("`distillation_topk` is only valid when `distillation_mode='topk_logits'`.") if self.distillation_is_clip is not None and self.distillation_is_clip <= 0: raise ValueError("distillation_is_clip must be positive when provided") if self.distillation_weight < 0: @@ -345,6 +426,3 @@ def __post_init__(self): f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be " f"divisible by the number of generations used for evaluation ({num_generations_eval})." ) - - if self.epsilon_high is None: - self.epsilon_high = self.epsilon diff --git a/trl/experimental/self_distillation/self_distillation_mixin.py b/trl/experimental/self_distillation/self_distillation_mixin.py deleted file mode 100644 index fb2a8808de1..00000000000 --- a/trl/experimental/self_distillation/self_distillation_mixin.py +++ /dev/null @@ -1,295 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Shared self-distillation loss utilities used by experimental trainers. - -This module intentionally holds only the reusable distillation mechanics: callback dispatch, common prompt/context -helpers, and the student-vs-teacher loss computation. Trainer lifecycle and online rollout concerns live in the trainer -classes or their online-specific base. -""" - -from __future__ import annotations - -from contextlib import nullcontext -from typing import Any - -import torch -import torch.nn.functional as F - -from ...trainer.utils import entropy_from_logits, selective_log_softmax -from .self_distillation_config import SelfDistillationConfig - - -class SelfDistillationMixin: - """Reusable self-distillation helpers shared across experimental trainers.""" - - config_cls = SelfDistillationConfig - - def _set_signature_columns_if_needed(self): - if self._signature_columns is None: - self._signature_columns = ["prompt", "privileged_context"] - - def _dispatch_self_distillation_callback(self, event_name: str, **payload) -> None: - for callback in self.callback_handler.callbacks: - callback_fn = getattr(callback, event_name, None) - if callback_fn is not None: - callback_fn( - args=self.args, - state=self.state, - control=self.control, - model=self.model, - processing_class=self.processing_class, - **payload, - ) - - @staticmethod - def _split_prompt_and_privileged_context(inputs: list[dict[str, Any]]) -> tuple[list[Any], list[Any]]: - prompts = [example["prompt"] for example in inputs] - privileged_contexts = [example.get("privileged_context") for example in inputs] - return prompts, privileged_contexts - - def _allow_topk_without_full_logit_distillation(self) -> bool: - return True - - def _get_per_token_logps_and_entropies( - self, - model, - input_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ): - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} - if "logits_to_keep" in self.model_kwarg_keys: - model_inputs["logits_to_keep"] = logits_to_keep + 1 - logits = model(**model_inputs).logits - logits = logits[:, :-1, :] - logits = logits[:, -logits_to_keep:, :] - logits = logits / self.temperature - completion_ids = input_ids[:, -logits_to_keep:] - selected_logps = selective_log_softmax(logits, completion_ids) - entropies = entropy_from_logits(logits) if compute_entropy else None - return selected_logps, entropies - - def _compute_self_distillation_loss( - self, - model, - inputs: dict[str, Any], - ) -> torch.Tensor: - # Expected batch contract: - # - required: `prompt_ids`, `prompt_mask`, `completion_ids`, `completion_mask`, - # `teacher_input_ids`, `teacher_attention_mask` - # - optional: `self_distillation_mask` to zero-out samples without teacher supervision, - # `old_per_token_logps` to enable IS clipping when generation and optimization are misaligned - prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] - logits_to_keep = completion_ids.size(1) - - self_distillation_mask = inputs.get("self_distillation_mask") - if self_distillation_mask is not None: - response_mask = completion_mask * self_distillation_mask.unsqueeze(1) - else: - response_mask = completion_mask - - if response_mask.sum() == 0: - mode = "train" if model.training else "eval" - self._log_self_distillation_metric(mode, "distillation_loss", 0.0) - return torch.tensor(0.0, device=completion_ids.device, requires_grad=True) - - student_input_ids = torch.cat([prompt_ids, completion_ids], dim=1) - student_attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - student_model_inputs = { - "input_ids": student_input_ids, - "attention_mask": student_attention_mask, - "use_cache": False, - } - if "logits_to_keep" in self.model_kwarg_keys: - student_model_inputs["logits_to_keep"] = logits_to_keep + 1 - - student_logits = model(**student_model_inputs).logits - student_logits = student_logits[:, :-1, :] - student_logits = student_logits[:, -logits_to_keep:, :] - student_logits = student_logits / self.temperature - - teacher_input_ids = inputs["teacher_input_ids"] - teacher_attention_mask = inputs["teacher_attention_mask"] - teacher_model_inputs = { - "input_ids": teacher_input_ids, - "attention_mask": teacher_attention_mask, - "use_cache": False, - } - if "logits_to_keep" in self.model_kwarg_keys: - teacher_model_inputs["logits_to_keep"] = logits_to_keep + 1 - - teacher_model = self._get_teacher_model_for_self_distillation(model) - with torch.no_grad(), self._get_teacher_context_for_self_distillation(model): - teacher_logits = teacher_model(**teacher_model_inputs).logits - teacher_logits = teacher_logits[:, :-1, :] - teacher_logits = teacher_logits[:, -logits_to_keep:, :] - teacher_logits = teacher_logits / self.temperature - - use_topk_distillation = self.args.distillation_topk is not None and ( - self.args.full_logit_distillation or self._allow_topk_without_full_logit_distillation() - ) - if use_topk_distillation: - student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) - topk_student_logits, topk_indices = torch.topk(student_logits, k=self.args.distillation_topk, dim=-1) - topk_student_log_probs = topk_student_logits - student_logsumexp - - teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True) - topk_teacher_logits = torch.gather(teacher_logits, dim=-1, index=topk_indices) - topk_teacher_log_probs = topk_teacher_logits - teacher_logsumexp - - if self.args.distillation_add_tail: - topk_student_log_probs = self._add_tail(topk_student_log_probs) - topk_teacher_log_probs = self._add_tail(topk_teacher_log_probs) - else: - topk_student_log_probs = self._renorm_topk_log_probs(topk_student_log_probs) - topk_teacher_log_probs = self._renorm_topk_log_probs(topk_teacher_log_probs) - - per_token_loss = self._compute_divergence( - topk_student_log_probs, topk_teacher_log_probs, self.args.distillation_alpha - ) - elif self.args.full_logit_distillation: - student_log_probs = F.log_softmax(student_logits, dim=-1) - teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) - per_token_loss = self._compute_divergence( - student_log_probs, teacher_log_probs, self.args.distillation_alpha - ) - else: - if self.args.distillation_alpha != 1.0: - raise ValueError( - "Only reverse KL (alpha=1.0) is supported for token-level distillation when " - "`full_logit_distillation=False`, " - f"got alpha={self.args.distillation_alpha}" - ) - student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) - teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True) - idx = completion_ids.unsqueeze(-1) - student_per_token_logps = (torch.gather(student_logits, dim=-1, index=idx) - student_logsumexp).squeeze(-1) - teacher_per_token_logps = (torch.gather(teacher_logits, dim=-1, index=idx) - teacher_logsumexp).squeeze(-1) - per_token_loss = self._compute_token_level_distillation_loss( - student_per_token_logps, teacher_per_token_logps - ) - - if self.args.distillation_is_clip is not None: - old_log_probs = inputs.get("old_per_token_logps") - if old_log_probs is not None: - with torch.no_grad(): - student_lse = torch.logsumexp(student_logits, dim=-1, keepdim=True) - idx = completion_ids.unsqueeze(-1) - student_per_token_logps = (torch.gather(student_logits, dim=-1, index=idx) - student_lse).squeeze( - -1 - ) - per_token_loss = self._apply_importance_sampling_clipping( - per_token_loss, student_per_token_logps, old_log_probs, self.args.distillation_is_clip - ) - - loss = self._aggregate_self_distillation_loss(per_token_loss, response_mask) - - mode = "train" if model.training else "eval" - mean_distill_loss = (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - self._log_self_distillation_metric( - mode, - "distillation_loss", - self.accelerator.gather(mean_distill_loss).mean().item(), - ) - - return loss - - def _get_teacher_model_for_self_distillation(self, model): - teacher_model = getattr(self, "teacher_model", None) - if teacher_model is None: - return model - return teacher_model - - def _get_teacher_context_for_self_distillation(self, model): - return nullcontext() - - def _log_self_distillation_metric(self, mode: str, metric_name: str, value: float) -> None: - metric_prefix = getattr(self, "_name", "self_distillation").lower().replace(" ", "_") - self._metrics[mode][f"self_distillation/{metric_name}"].append(value) - self._metrics[mode][f"{metric_prefix}/{metric_name}"].append(value) - - @staticmethod - def _compute_divergence( - student_log_probs: torch.Tensor, - teacher_log_probs: torch.Tensor, - alpha: float, - ) -> torch.Tensor: - if alpha == 0.0: - kl = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) - elif alpha == 1.0: - kl = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) - else: - alpha_t = torch.tensor(alpha, dtype=student_log_probs.dtype, device=student_log_probs.device) - mixture = torch.logsumexp( - torch.stack([student_log_probs + torch.log(1 - alpha_t), teacher_log_probs + torch.log(alpha_t)]), - dim=0, - ) - kl_teacher = F.kl_div(mixture, teacher_log_probs, reduction="none", log_target=True) - kl_student = F.kl_div(mixture, student_log_probs, reduction="none", log_target=True) - kl = torch.lerp(kl_student, kl_teacher, alpha) - return kl.sum(-1) - - @staticmethod - def _add_tail(log_probs: torch.Tensor) -> torch.Tensor: - log_s = torch.logsumexp(log_probs, dim=-1, keepdim=True) - log_s = torch.clamp(log_s, max=-1e-7) - tail_log = torch.log(-torch.expm1(log_s)) - return torch.cat([log_probs, tail_log], dim=-1) - - @staticmethod - def _renorm_topk_log_probs(log_probs: torch.Tensor) -> torch.Tensor: - return log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True) - - @staticmethod - def _compute_token_level_distillation_loss( - student_log_probs: torch.Tensor, - teacher_log_probs: torch.Tensor, - ) -> torch.Tensor: - # This is the token-level reverse-KL surrogate used by the official SDPO implementation for - # `full_logit_distillation=False`. It intentionally treats the teacher log-probs as fixed targets - # and keeps only the score-function term for the sampled student tokens. - log_ratio = student_log_probs - teacher_log_probs - return log_ratio.detach() * student_log_probs - - @staticmethod - def _apply_importance_sampling_clipping( - per_token_loss: torch.Tensor, - student_log_probs: torch.Tensor, - old_log_probs: torch.Tensor, - clip_coeff: float, - ) -> torch.Tensor: - negative_approx_kl = (student_log_probs - old_log_probs).detach() - negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) - ratio = torch.exp(negative_approx_kl).clamp(max=clip_coeff) - return per_token_loss * ratio - - def _aggregate_self_distillation_loss( - self, - per_token_loss: torch.Tensor, - response_mask: torch.Tensor, - ) -> torch.Tensor: - loss_type = self.loss_type - if loss_type == "grpo": - loss = (per_token_loss * response_mask).sum(-1) / response_mask.sum(-1).clamp(min=1.0) - return loss.mean() - if loss_type == "bnpo": - return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - if loss_type == "dr_grpo": - return (per_token_loss * response_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) - if loss_type in ["dapo", "luspo", "cispo", "sapo"]: - return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - raise ValueError(f"Unsupported loss_type for self-distillation: {loss_type}") diff --git a/trl/experimental/self_distillation/teacher_context.py b/trl/experimental/self_distillation/teacher_context.py deleted file mode 100644 index 5e1020c91a7..00000000000 --- a/trl/experimental/self_distillation/teacher_context.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any - -import torch - -from ...data_utils import maybe_apply_chat_template -from ...trainer.base_trainer import _BaseTrainer -from ...trainer.utils import pad - - -def extract_last_user_text(prompt: list[dict[str, Any]]) -> str: - """Extract the text content from the last user message in a conversational prompt.""" - last_message = prompt[-1] - if last_message.get("role") != "user": - raise ValueError( - f"Self-distillation teacher prompt construction expects the conversation to end with a user turn, " - f"but the last message has role '{last_message.get('role')}'. " - f"Prompts ending with assistant prefills or tool turns are not supported." - ) - content = last_message.get("content", "") - if isinstance(content, list): - return " ".join(part.get("text", "") for part in content if part.get("type") == "text") - return content - - -@dataclass -class TokenizedPromptBatch: - prompt_ids: torch.Tensor - prompt_mask: torch.Tensor - - -class PromptTokenizer: - """Internal helper to tokenize prompt-like inputs consistently across self-distillation trainers.""" - - def __init__(self, trainer): - self.trainer = trainer - - def apply_prompt_template(self, prompts: list[Any]) -> list[str]: - return [ - maybe_apply_chat_template( - {"prompt": prompt}, - self.trainer.processing_class, - **getattr(self.trainer, "chat_template_kwargs", {}), - )["prompt"] - for prompt in prompts - ] - - def tokenize_prompts(self, prompts: list[Any]) -> TokenizedPromptBatch: - prompt_text = self.apply_prompt_template(prompts) - prompt_inputs = self.trainer.processing_class( - text=prompt_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.trainer.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - prompt_inputs = super(_BaseTrainer, self.trainer)._prepare_inputs(prompt_inputs) - prompt_ids = [ - p[m].tolist() - for p, m in zip(prompt_inputs["input_ids"], prompt_inputs["attention_mask"].bool(), strict=False) - ] - prompt_ids = [torch.tensor(ids, device=self.trainer.accelerator.device) for ids in prompt_ids] - prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - return TokenizedPromptBatch( - prompt_ids=pad(prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), - prompt_mask=pad(prompt_mask, padding_value=0, padding_side="left"), - ) diff --git a/trl/experimental/self_distillation/peft_adapter_ema_callback.py b/trl/experimental/self_distillation/teacher_sync.py similarity index 88% rename from trl/experimental/self_distillation/peft_adapter_ema_callback.py rename to trl/experimental/self_distillation/teacher_sync.py index e252bb512a4..267ef4d5854 100644 --- a/trl/experimental/self_distillation/peft_adapter_ema_callback.py +++ b/trl/experimental/self_distillation/teacher_sync.py @@ -22,10 +22,26 @@ TrainingArguments, ) +from ...trainer.callbacks import SyncRefModelCallback + logger = logging.getLogger(__name__) +class SyncTeacherModelCallback(SyncRefModelCallback): + """Synchronize an EMA teacher model with the student model on each configured sync step.""" + + def __init__(self, teacher_model, accelerator=None): + super().__init__(ref_model=teacher_model, accelerator=accelerator) + + def on_step_end(self, args, state, control, **kwargs): + model = kwargs["model"] + if self.ref_model is not None and state.global_step % args.teacher_sync_steps == 0: + if self.accelerator: + model = self.accelerator.unwrap_model(model) + self.sync_target_model(model, self.ref_model, args.teacher_update_rate) + + class PEFTAdapterEMACallback(TrainerCallback): """ Callback that maintains an EMA copy of PEFT adapter weights for use as a teacher model in self-distillation.