From 06f02a8dd1c17fbfe2b72a2fa3fadca6446748e6 Mon Sep 17 00:00:00 2001 From: Leon Date: Wed, 15 Apr 2026 08:15:34 +0200 Subject: [PATCH 01/23] v0.1 transition sdft into unified base --- trl/experimental/sdft/sdft_trainer.py | 28 + .../sdft/sdft_trainer_transition.py | 256 +++++++++ trl/experimental/sdpo/sdpo_trainer.py | 36 +- .../unified_base_self_distillation_trainer.py | 544 ++++++++++++++++++ 4 files changed, 848 insertions(+), 16 deletions(-) create mode 100644 trl/experimental/sdft/sdft_trainer_transition.py create mode 100644 trl/experimental/self_distillation/unified_base_self_distillation_trainer.py diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index 5bf6095c2a0..057ae851678 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -54,6 +54,10 @@ ) from ..self_distillation.self_distillation_mixin import SelfDistillationMixin from ..self_distillation.teacher_context import PromptTokenizer, extract_last_user_text +from ..self_distillation.unified_base_self_distillation_trainer import ( + SelfDistillationBatch, + SelfDistillationRolloutBatch, +) from ..utils import prepare_peft_model from .sdft_config import SDFTConfig @@ -462,6 +466,30 @@ def _build_buffered_batch(self, inputs: list[dict[str, Any]]) -> dict[str, torch output["old_per_token_logps"] = old_per_token_logps return output + def augment_training_batch( + self, + inputs: list[dict[str, Any]], + rollout_batch: SelfDistillationRolloutBatch, + ) -> SelfDistillationBatch: + prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) + teacher_batch = self.teacher_context_builder.build( + prompts, + privileged_contexts, + rollout_batch.completion_ids, + rollout_batch.completion_mask, + ) + + old_per_token_logps = None if self.generate_from_teacher else rollout_batch.old_per_token_logps + return SelfDistillationBatch( + prompt_ids=teacher_batch["prompt_ids"], + prompt_mask=teacher_batch["prompt_mask"], + completion_ids=rollout_batch.completion_ids, + completion_mask=rollout_batch.completion_mask, + teacher_input_ids=teacher_batch["teacher_input_ids"], + teacher_attention_mask=teacher_batch["teacher_attention_mask"], + old_per_token_logps=old_per_token_logps, + ) + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: raise ValueError("The SDFTTrainer does not support returning outputs") diff --git a/trl/experimental/sdft/sdft_trainer_transition.py b/trl/experimental/sdft/sdft_trainer_transition.py new file mode 100644 index 00000000000..e349a2918ff --- /dev/null +++ b/trl/experimental/sdft/sdft_trainer_transition.py @@ -0,0 +1,256 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
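+
+# Transitional module: reimplements SDFTTrainer on top of the shared
+# UnifiedBaseSelfDistillationTrainer lifecycle; the existing sdft_trainer.py is
+# kept alongside it during the v0.1 transition.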
+ +from __future__ import annotations + +import copy +import textwrap +from typing import Any + +import torch +from accelerate.logging import get_logger +from accelerate.utils import is_peft_model +from datasets import Dataset, IterableDataset +from torch import nn +from transformers import ( + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.utils import is_peft_available + +from ...models import prepare_deepspeed, prepare_fsdp +from ...trainer.callbacks import SyncRefModelCallback +from ...trainer.utils import ( + use_adapter, +) +from ..self_distillation.teacher_context import PromptTokenizer, extract_last_user_text +from ..self_distillation.unified_base_self_distillation_trainer import ( + SelfDistillationBatch, + SelfDistillationRolloutBatch, + UnifiedBaseSelfDistillationTrainer, +) +from .sdft_config import SDFTConfig + + +if is_peft_available(): + from peft import PeftConfig + from peft.peft_model import PeftModel + + from ..self_distillation.peft_adapter_ema_callback import PEFTAdapterEMACallback + + +logger = get_logger(__name__) + + +class DemonstrationTeacherContextBuilder: + """Builds student and teacher contexts from prompts plus privileged context, as in SDFT.""" + + def __init__(self, trainer): + self.trainer = trainer + self.prompt_tokenizer = PromptTokenizer(trainer) + + def _stringify_privileged_context(self, privileged_context: Any) -> str: + if privileged_context is None: + raise ValueError( + "`privileged_context` must not be None for self-distillation teacher prompt construction." + ) + if isinstance(privileged_context, str): + return privileged_context + if isinstance(privileged_context, list) and privileged_context and isinstance(privileged_context[0], dict): + chunks = [] + for message in privileged_context: + content = message.get("content", "") + if isinstance(content, list): + text = " ".join(part.get("text", "") for part in content if part.get("type") == "text") + else: + text = str(content) + if text: + chunks.append(text) + return "\n".join(chunks) + return str(privileged_context) + + def _compose_teacher_prompt(self, prompt: Any, privileged_context: Any) -> Any: + privileged_text = self._stringify_privileged_context(privileged_context) + if isinstance(prompt, list): + system_messages = prompt[:-1] + prompt_text = extract_last_user_text(prompt) + teacher_text = self.trainer.args.teacher_prompt_template.format( + prompt=prompt_text, + privileged_context=privileged_text, + ) + return system_messages + [{"role": "user", "content": teacher_text}] + return self.trainer.args.teacher_prompt_template.format(prompt=prompt, privileged_context=privileged_text) + + def select_generation_prompts(self, prompts: list[Any], privileged_contexts: list[Any]) -> list[Any]: + if not self.trainer.generate_from_teacher: + return prompts + return [ + self._compose_teacher_prompt(prompt, privileged_context) + for prompt, privileged_context in zip(prompts, privileged_contexts, strict=True) + ] + + def build( + self, + prompts: list[Any], + privileged_contexts: list[Any], + completion_ids: torch.Tensor, + completion_mask: torch.Tensor, + ) -> dict[str, torch.Tensor]: + student_batch = self.prompt_tokenizer.tokenize_prompts(prompts) + teacher_prompts = [ + self._compose_teacher_prompt(prompt, privileged_context) + for prompt, privileged_context in zip(prompts, privileged_contexts, strict=True) + ] + teacher_batch = self.prompt_tokenizer.tokenize_prompts(teacher_prompts) + teacher_input_ids = torch.cat([teacher_batch.prompt_ids, 
completion_ids], dim=1) + teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, completion_mask], dim=1) + return { + "prompt_ids": student_batch.prompt_ids, + "prompt_mask": student_batch.prompt_mask, + "teacher_input_ids": teacher_input_ids, + "teacher_attention_mask": teacher_attention_mask, + } + + +class SDFTTrainer(UnifiedBaseSelfDistillationTrainer): + """Trainer for SDFT-style on-policy self-distillation with explicit teacher prompts.""" + + _tag_names = ["trl", "sdft"] + _name = "SDFT" + config_cls = SDFTConfig + # docstyle-ignore + _paper = { + "title": "Self-Training with On-Policy Self-Distillation for Language Model Alignment", + "id": "2601.19897", + "citation": textwrap.dedent("""\ + @article{hubotter2026selftraining, + title = {{Self-Training with On-Policy Self-Distillation for Language Model Alignment}}, + author = {Jonas H\\"ubotter and Frederike L\\"ubeck and Lejs Behric and Anton Baumann and Marco Bagatella and Daniel Marta and Ido Hakimi and Idan Shenfeld and Thomas Kleine Buening and Carlos Guestrin and Andreas Krause}, + year = 2026, + eprint = {arXiv:2601.19897} + }"""), + } + + def __init__( + self, + model: str | PreTrainedModel | nn.Module, + args: SDFTConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + peft_config: PeftConfig | None = None, + ): + if isinstance(train_dataset, IterableDataset): + raise NotImplementedError("Iterable datasets are not yet supported in SDFTTrainer.") + if isinstance(eval_dataset, IterableDataset) or ( + isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) + ): + raise NotImplementedError("Iterable eval datasets are not yet supported in SDFTTrainer.") + + super().init( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + peft_config=peft_config, + ) + + self.num_loss_tokens_to_skip = args.num_loss_tokens_to_skip + self.generate_from_teacher = args.generate_from_teacher + self.teacher_context_builder = DemonstrationTeacherContextBuilder(self) + + # In self-distillation the teacher is always derived from the student: + # - PEFT: base model with adapter disabled (or EMA teacher adapter when sync_ref_model=True) + # - Non-PEFT: same model (or deep-copied EMA model when sync_ref_model=True) + self.teacher_model = None + + if args.sync_ref_model: + if is_peft_available() and is_peft_model(self.model): + self.add_callback( + PEFTAdapterEMACallback( + model=self.model, + teacher_adapter_name="teacher", + update_rate=args.ref_model_mixup_alpha, + sync_steps=args.ref_model_sync_steps, + accelerator=self.accelerator, + ) + ) + else: + student_model = self.accelerator.unwrap_model(self.model) + self.teacher_model = copy.deepcopy(student_model) + self.teacher_model.requires_grad_(False) + self.teacher_model.eval() + if self.is_deepspeed_enabled: + self.teacher_model = prepare_deepspeed(self.teacher_model, self.accelerator) + elif self.is_fsdp_enabled: + self.teacher_model = prepare_fsdp(self.teacher_model, self.accelerator) + else: + self.teacher_model = self.accelerator.prepare_model(self.teacher_model, 
evaluation_mode=True) + self.add_callback(SyncRefModelCallback(ref_model=self.teacher_model, accelerator=self.accelerator)) + + self.model_accepts_loss_kwargs = False + + def augment_training_batch( + self, + inputs: list[dict[str, Any]], + rollout_batch: SelfDistillationRolloutBatch, + ) -> SelfDistillationBatch: + prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) + teacher_batch = self.teacher_context_builder.build( + prompts, + privileged_contexts, + rollout_batch.completion_ids, + rollout_batch.completion_mask, + ) + + old_per_token_logps = None if self.generate_from_teacher else rollout_batch.old_per_token_logps + return SelfDistillationBatch( + prompt_ids=teacher_batch["prompt_ids"], + prompt_mask=teacher_batch["prompt_mask"], + completion_ids=rollout_batch.completion_ids, + completion_mask=rollout_batch.completion_mask, + teacher_input_ids=teacher_batch["teacher_input_ids"], + teacher_attention_mask=teacher_batch["teacher_attention_mask"], + old_per_token_logps=old_per_token_logps, + ) + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The SDFTTrainer does not support returning outputs") + + if self.num_loss_tokens_to_skip > 0: + inputs = dict(inputs) + completion_mask = inputs["completion_mask"].clone() + token_positions = torch.arange(completion_mask.size(1), device=completion_mask.device).unsqueeze(0) + completion_mask = completion_mask * (token_positions >= self.num_loss_tokens_to_skip).long() + inputs["completion_mask"] = completion_mask + + loss = self._compute_self_distillation_loss(model, inputs) + accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 + return loss / accumulation_scale + + def _get_teacher_context_for_self_distillation(self, model): + if is_peft_available() and isinstance(self.model, PeftModel): + model = self.accelerator.unwrap_model(self.model) + if self.args.sync_ref_model and "teacher" in model.peft_config: + return use_adapter(model, adapter_name="teacher") + return use_adapter(model, adapter_name=None) + return super()._get_teacher_context_for_self_distillation(model) diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py index ef84a17a44c..cf346ac3b60 100644 --- a/trl/experimental/sdpo/sdpo_trainer.py +++ b/trl/experimental/sdpo/sdpo_trainer.py @@ -365,23 +365,27 @@ def _warn_on_inactive_self_distillation(self, mode: str) -> None: else: self._diagnostic_counters[mode]["no_successful_rollouts"] = 0 - def _compute_loss( - self, - model, - inputs, - ) -> torch.Tensor: - accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 + def _compute_policy_loss(self, model, inputs) -> torch.Tensor: + return super()._compute_loss(model, inputs) - if self.args.sdpo_policy_loss_mode == "hybrid": - base_policy_loss = super()._compute_loss(model, inputs) - if self.args.distillation_weight <= 0.0: - return base_policy_loss + def _compute_weighted_self_distillation_loss(self, model, inputs) -> torch.Tensor | None: + if self.args.distillation_weight <= 0.0: + return None - sdpo_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale - return base_policy_loss + self.args.distillation_weight * sdpo_loss + accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 + distillation_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale + return self.args.distillation_weight 
* distillation_loss - if self.args.distillation_weight <= 0.0: - return super()._compute_loss(model, inputs) + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The SDPOTrainer does not support returning outputs") - sdpo_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale - return self.args.distillation_weight * sdpo_loss + if self.args.sdpo_policy_loss_mode == "hybrid": + policy_loss = self._compute_policy_loss(model, inputs) + weighted_distillation_loss = self._compute_weighted_self_distillation_loss(model, inputs) + return policy_loss if weighted_distillation_loss is None else policy_loss + weighted_distillation_loss + + weighted_distillation_loss = self._compute_weighted_self_distillation_loss(model, inputs) + if weighted_distillation_loss is not None: + return weighted_distillation_loss + return self._compute_policy_loss(model, inputs) diff --git a/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py b/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py new file mode 100644 index 00000000000..14eb68a0a5a --- /dev/null +++ b/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py @@ -0,0 +1,544 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
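+
+# Shared lifecycle for self-distillation trainers: model/tokenizer setup, rollout
+# generation, and teacher preparation live here. Subclasses turn the common
+# SelfDistillationRolloutBatch into the final SelfDistillationBatch via
+# `augment_training_batch` and own the algorithm-specific loss in `compute_loss`.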
+ +from __future__ import annotations + +import copy +import inspect +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass, field +from functools import partial +from typing import Any + +import datasets +import torch +from accelerate.logging import get_logger +from datasets import Dataset, IterableDataset +from torch import nn +from torch.utils.data import DataLoader, Sampler +from transformers import ( + AutoProcessor, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.trainer_utils import seed_worker +from transformers.utils import is_datasets_available, is_peft_available + +from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation +from ...trainer.base_trainer import _BaseTrainer +from ...trainer.callbacks import SyncRefModelCallback +from ...trainer.utils import ( + RepeatSampler, + create_model_from_path, + disable_dropout_in_model, + get_config_model_id, + identity, + pad, + split_tensor_dict, +) +from ..utils import prepare_peft_model +from .self_distillation_config import SelfDistillationConfig +from .self_distillation_mixin import SelfDistillationMixin +from .teacher_context import PromptTokenizer + + +if is_peft_available(): + from peft import PeftConfig + + +logger = get_logger(__name__) + + +@dataclass +class SelfDistillationRolloutBatch: + """Common student rollout batch produced before algorithm-specific augmentation.""" + + prompt_ids: torch.Tensor + prompt_mask: torch.Tensor + completion_ids: torch.Tensor + completion_mask: torch.Tensor + old_per_token_logps: torch.Tensor | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, torch.Tensor | Any]: + output: dict[str, torch.Tensor | Any] = { + "prompt_ids": self.prompt_ids, + "prompt_mask": self.prompt_mask, + "completion_ids": self.completion_ids, + "completion_mask": self.completion_mask, + } + if self.old_per_token_logps is not None: + output["old_per_token_logps"] = self.old_per_token_logps + output.update(self.metadata) + return output + + +@dataclass +class SelfDistillationBatch: + """Final self-distillation batch contract consumed by `SelfDistillationMixin`.""" + + prompt_ids: torch.Tensor + prompt_mask: torch.Tensor + completion_ids: torch.Tensor + completion_mask: torch.Tensor + teacher_input_ids: torch.Tensor + teacher_attention_mask: torch.Tensor + old_per_token_logps: torch.Tensor | None = None + self_distillation_mask: torch.Tensor | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, torch.Tensor | Any]: + output: dict[str, torch.Tensor | Any] = { + "prompt_ids": self.prompt_ids, + "prompt_mask": self.prompt_mask, + "completion_ids": self.completion_ids, + "completion_mask": self.completion_mask, + "teacher_input_ids": self.teacher_input_ids, + "teacher_attention_mask": self.teacher_attention_mask, + } + if self.old_per_token_logps is not None: + output["old_per_token_logps"] = self.old_per_token_logps + if self.self_distillation_mask is not None: + output["self_distillation_mask"] = self.self_distillation_mask + output.update(self.metadata) + return output + + +class UnifiedBaseSelfDistillationTrainer(SelfDistillationMixin, _BaseTrainer, ABC): + """Prototype base that centralizes shared self-distillation trainer lifecycle.""" + + config_cls = SelfDistillationConfig + _tag_names = ["trl", "self-distillation"] + _name = "Self-Distillation" + + def __init__( + 
self, + model: str | PreTrainedModel | nn.Module, + args: SelfDistillationConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + peft_config: PeftConfig | None = None, + ): + if train_dataset is None: + raise ValueError("`train_dataset` is required") + self.use_vllm = args.use_vllm + + if isinstance(model, str): + model_init_kwargs = args.model_init_kwargs or {} + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) + elif args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the self-distillation config, but `model` is already " + "instantiated. The `model_init_kwargs` will be ignored." + ) + + self.model_kwarg_keys = ( + inspect.signature(model.forward).parameters.keys() + if not hasattr(model, "get_base_model") + else inspect.signature(model.get_base_model().forward).parameters.keys() + ) + + if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): + model = prepare_peft_model(model, peft_config, args) + + if processing_class is None: + processing_class = AutoProcessor.from_pretrained( + get_config_model_id(model.config), truncation_side="left", padding_side="left" + ) + + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + self.pad_token = tokenizer.pad_token + self.pad_token_id = tokenizer.pad_token_id + self.eos_token_id = tokenizer.eos_token_id + self.max_prompt_length = args.max_prompt_length + self.max_completion_length = args.max_completion_length + self.num_generations = args.num_generations + self.num_generations_eval = args.num_generations_eval or args.num_generations + self.num_iterations = args.num_iterations + self.shuffle_dataset = args.shuffle_dataset + self.loss_type = args.loss_type + self.temperature = args.temperature + self.chat_template_kwargs = args.chat_template_kwargs or {} + self._step = 0 + self._last_loaded_step = 0 + self._buffered_inputs = None + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._diagnostic_counters = { + "train": defaultdict(int), + "eval": defaultdict(int), + } + self.prompt_tokenizer = PromptTokenizer(self) + + generation_kwargs = { + "max_new_tokens": self.max_completion_length, + "do_sample": True, + "pad_token_id": tokenizer.pad_token_id, + "bos_token_id": tokenizer.bos_token_id, + "eos_token_id": tokenizer.eos_token_id, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": args.top_k, + "min_p": args.min_p, + "repetition_penalty": args.repetition_penalty, + "cache_implementation": args.cache_implementation, + } + if args.generation_kwargs is not None: + generation_kwargs.update(args.generation_kwargs) + self.generation_config = GenerationConfig(**generation_kwargs, disable_compile=True) + + if hasattr(model, 
"warnings_issued"): + model.warnings_issued["estimate_tokens"] = True + + super().__init__( + model=model, + args=args, + data_collator=identity, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + compute_loss_func="non-None value to disable scaling", + ) + + if self.use_vllm: + from ...generation.vllm_generation import VLLMGeneration + + self.vllm_generation = VLLMGeneration( + model=self.model, + accelerator=self.accelerator, + is_fsdp_enabled=self.is_fsdp_enabled, + processing_class=self.processing_class, + mode=args.vllm_mode, + server_base_url=args.vllm_server_base_url, + server_host=args.vllm_server_host, + server_port=args.vllm_server_port, + group_port=args.vllm_group_port, + server_timeout=args.vllm_server_timeout, + tensor_parallel_size=args.vllm_tensor_parallel_size, + gpu_memory_utilization=args.vllm_gpu_memory_utilization, + max_model_length=args.vllm_max_model_length, + max_num_seqs=args.per_device_train_batch_size + * args.vllm_tensor_parallel_size + * args.steps_per_generation, + enable_sleep_mode=args.vllm_enable_sleep_mode, + model_impl=args.vllm_model_impl, + repetition_penalty=args.repetition_penalty, + temperature=self.temperature, + top_p=args.top_p, + top_k=args.top_k, + min_p=args.min_p, + max_completion_length=self.max_completion_length, + logprobs=None, + generation_kwargs=args.generation_kwargs, + ) + self._last_loaded_step = -1 + + if args.disable_dropout: + disable_dropout_in_model(self.model) + + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + self.teacher_model = None + self._setup_teacher_model() + self.model_accepts_loss_kwargs = False + + def _setup_teacher_model(self) -> None: + """Prepare a generic teacher model derived from the student. + + Subclasses can override this when they need algorithm-specific teacher handling, such as adapter switching. 
+ """ + + if not self.args.sync_ref_model: + return + + student_model = self.accelerator.unwrap_model(self.model) + self.teacher_model = copy.deepcopy(student_model) + self.teacher_model.requires_grad_(False) + self.teacher_model.eval() + + if self.is_deepspeed_enabled: + self.teacher_model = prepare_deepspeed(self.teacher_model, self.accelerator) + elif self.is_fsdp_enabled: + self.teacher_model = prepare_fsdp(self.teacher_model, self.accelerator) + else: + self.teacher_model = self.accelerator.prepare_model(self.teacher_model, evaluation_mode=True) + + self.add_callback(SyncRefModelCallback(ref_model=self.teacher_model, accelerator=self.accelerator)) + + def get_train_dataloader(self): + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size * self.args.steps_per_generation, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = partial( + seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index + ) + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + def _get_train_sampler(self, dataset=None) -> Sampler: + if dataset is None: + dataset = self.train_dataset + return RepeatSampler( + data_source=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size // self.num_generations, + repeat_count=self.num_iterations * self.args.steps_per_generation, + shuffle=self.shuffle_dataset, + seed=self.args.seed, + ) + + def _get_eval_sampler(self, eval_dataset) -> Sampler: + return RepeatSampler( + data_source=eval_dataset, + mini_repeat_count=self.num_generations_eval, + seed=self.args.seed, + ) + + def training_step(self, model, inputs, num_items_in_batch): + output = super().training_step(model, inputs, num_items_in_batch) + self._step += 1 + return output + + def _prepare_inputs(self, generation_batch): + mode = "train" if self.model.training else "eval" + if mode == "train": + generate_every = self.args.steps_per_generation * self.num_iterations + if self._step % generate_every == 0 or self._buffered_inputs is None: + buffered_batch = self._build_buffered_batch(generation_batch) + self._buffered_inputs = split_tensor_dict(buffered_batch, self.args.steps_per_generation) + self._dispatch_self_distillation_callback( + "on_generation_batch_built", + generate_every=generate_every, + steps_per_generation=self.args.steps_per_generation, + ) + return self._buffered_inputs[self._step % self.args.steps_per_generation] + return self._build_buffered_batch(generation_batch) + + def _build_buffered_batch(self, inputs: list[dict[str, Any]]) -> dict[str, torch.Tensor | Any]: + return 
self.build_training_batch(inputs).to_dict() + + def build_training_batch(self, inputs: list[dict[str, Any]]) -> SelfDistillationBatch: + rollout_batch = self.build_rollout_batch(inputs) + self._validate_rollout_batch(rollout_batch) + + batch = self.augment_training_batch(inputs, rollout_batch) + self._validate_training_batch(batch) + + self._dispatch_self_distillation_callback( + "on_self_distillation_batch_prepared", + old_per_token_logps=batch.old_per_token_logps, + prompt_ids=batch.prompt_ids, + completion_ids=batch.completion_ids, + teacher_input_ids=batch.teacher_input_ids, + teacher_attention_mask=batch.teacher_attention_mask, + self_distillation_mask=batch.self_distillation_mask, + ) + return batch + + def build_rollout_batch(self, inputs: list[dict[str, Any]]) -> SelfDistillationRolloutBatch: + prompts, _ = self._split_prompt_and_privileged_context(inputs) + generation_prompts = prompts + generation_prompt_text = self.prompt_tokenizer.apply_prompt_template(generation_prompts) + self._dispatch_self_distillation_callback( + "on_generation_prompts_selected", + generation_prompts=generation_prompts, + generation_prompt_text=generation_prompt_text, + ) + + prompt_ids_list, completion_ids_list = self._generate(generation_prompts) + device = self.accelerator.device + prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left").to(device=device) + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left").to(device=device) + completion_ids = [torch.tensor(ids) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right").to(device=device) + completion_mask = pad(completion_mask, padding_value=0, padding_side="right").to(device=device) + old_per_token_logps = self.compute_rollout_logps( + prompt_ids=prompt_ids, + prompt_mask=prompt_mask, + completion_ids=completion_ids, + completion_mask=completion_mask, + ) + + return SelfDistillationRolloutBatch( + prompt_ids=prompt_ids, + prompt_mask=prompt_mask, + completion_ids=completion_ids, + completion_mask=completion_mask, + old_per_token_logps=old_per_token_logps, + ) + + def _generate(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: + if self.use_vllm: + return self._generate_vllm(prompts) + return self._generate_transformers(prompts) + + def _generate_vllm(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: + if self.state.global_step != self._last_loaded_step: + self.vllm_generation.sync_weights() + self._last_loaded_step = self.state.global_step + + prompts_text = self.prompt_tokenizer.apply_prompt_template(prompts) + tokenized = self.processing_class( + text=prompts_text, + return_tensors=None, + padding=False, + max_length=self.max_prompt_length, + truncation=True, + add_special_tokens=False, + ) + prompt_ids = tokenized["input_ids"] + mode = "train" if self.model.training else "eval" + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + prompt_ids_out, completion_ids_list, _, _ = self.vllm_generation.generate( + prompts=prompt_ids, + images=None, + num_generations=num_generations, + ) + return prompt_ids_out, completion_ids_list + + def _generate_transformers(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: + generate_inputs 
= self.processing_class( + text=self.prompt_tokenizer.apply_prompt_template(prompts), + return_tensors="pt", + padding=True, + padding_side="left", + max_length=self.max_prompt_length, + truncation=True, + add_special_tokens=False, + ) + generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) + + with ( + unwrap_model_for_generation( + self.model_wrapped, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + ) as unwrapped_model, + torch.no_grad(), + ): + prompt_completion_ids = unwrapped_model.generate( + **generate_inputs, generation_config=self.generation_config + ) + + prompt_ids = generate_inputs["input_ids"] + prompt_mask = generate_inputs["attention_mask"] + prompt_length = prompt_ids.size(1) + completion_ids = prompt_completion_ids[:, prompt_length:] + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) + completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).long() + + prompt_ids_list = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=False)] + completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=False)] + return prompt_ids_list, completion_ids_list + + def compute_rollout_logps( + self, + prompt_ids: torch.Tensor, + prompt_mask: torch.Tensor, + completion_ids: torch.Tensor, + completion_mask: torch.Tensor, + ) -> torch.Tensor | None: + generate_every = self.args.steps_per_generation * self.num_iterations + old_per_token_logps = None + + if self.args.gradient_accumulation_steps % generate_every != 0: + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + compute_entropy=False, + ) + + return old_per_token_logps + + def _validate_rollout_batch(self, batch: SelfDistillationRolloutBatch) -> None: + batch_size = batch.prompt_ids.size(0) + if batch.prompt_mask.size(0) != batch_size: + raise ValueError("`prompt_mask` must have the same batch size as `prompt_ids` in the rollout batch.") + if batch.completion_ids.size(0) != batch_size or batch.completion_mask.size(0) != batch_size: + raise ValueError("`completion_ids` and `completion_mask` must match the rollout batch size.") + if batch.old_per_token_logps is not None and batch.old_per_token_logps.size(0) != batch_size: + raise ValueError("`old_per_token_logps` must match the rollout batch size when provided.") + + def _validate_training_batch(self, batch: SelfDistillationBatch) -> None: + batch_size = batch.prompt_ids.size(0) + if batch.prompt_mask.size(0) != batch_size: + raise ValueError("`prompt_mask` must have the same batch size as `prompt_ids`.") + if batch.completion_ids.size(0) != batch_size or batch.completion_mask.size(0) != batch_size: + raise ValueError("`completion_ids` and `completion_mask` must match the student batch size.") + if batch.teacher_input_ids.size(0) != batch_size or batch.teacher_attention_mask.size(0) != batch_size: + raise ValueError("`teacher_input_ids` and `teacher_attention_mask` must match the student batch size.") + if batch.teacher_input_ids.size(1) != 
batch.teacher_attention_mask.size(1): + raise ValueError("`teacher_input_ids` and `teacher_attention_mask` must have the same sequence length.") + if batch.self_distillation_mask is not None and batch.self_distillation_mask.size(0) != batch_size: + raise ValueError("`self_distillation_mask` must match the batch size when provided.") + + @abstractmethod + def augment_training_batch( + self, + inputs: list[dict[str, Any]], + rollout_batch: SelfDistillationRolloutBatch, + ) -> SelfDistillationBatch: + """Inject teacher-side inputs and algorithm-specific fields into a common student rollout batch.""" + + @abstractmethod + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """Subclasses own algorithm-specific loss composition on the final batch contract.""" From be1bcbca4facc1cf3bd7992b1f62849990bc7d7c Mon Sep 17 00:00:00 2001 From: Leon Date: Wed, 15 Apr 2026 09:28:49 +0200 Subject: [PATCH 02/23] sdft transition v1 complete, starting on sdpo --- .../sdft/sdft_trainer_transition.py | 59 +-- .../sdpo/sdpo_trainer_transition.py | 391 ++++++++++++++++++ .../unified_base_self_distillation_trainer.py | 74 +++- 3 files changed, 468 insertions(+), 56 deletions(-) create mode 100644 trl/experimental/sdpo/sdpo_trainer_transition.py diff --git a/trl/experimental/sdft/sdft_trainer_transition.py b/trl/experimental/sdft/sdft_trainer_transition.py index e349a2918ff..2e5e7757ff5 100644 --- a/trl/experimental/sdft/sdft_trainer_transition.py +++ b/trl/experimental/sdft/sdft_trainer_transition.py @@ -14,7 +14,6 @@ from __future__ import annotations -import copy import textwrap from typing import Any @@ -31,11 +30,6 @@ ) from transformers.utils import is_peft_available -from ...models import prepare_deepspeed, prepare_fsdp -from ...trainer.callbacks import SyncRefModelCallback -from ...trainer.utils import ( - use_adapter, -) from ..self_distillation.teacher_context import PromptTokenizer, extract_last_user_text from ..self_distillation.unified_base_self_distillation_trainer import ( SelfDistillationBatch, @@ -47,9 +41,6 @@ if is_peft_available(): from peft import PeftConfig - from peft.peft_model import PeftModel - - from ..self_distillation.peft_adapter_ema_callback import PEFTAdapterEMACallback logger = get_logger(__name__) @@ -162,7 +153,7 @@ def __init__( ): raise NotImplementedError("Iterable eval datasets are not yet supported in SDFTTrainer.") - super().init( + super().__init__( model=model, args=args, train_dataset=train_dataset, @@ -174,40 +165,8 @@ def __init__( ) self.num_loss_tokens_to_skip = args.num_loss_tokens_to_skip - self.generate_from_teacher = args.generate_from_teacher self.teacher_context_builder = DemonstrationTeacherContextBuilder(self) - # In self-distillation the teacher is always derived from the student: - # - PEFT: base model with adapter disabled (or EMA teacher adapter when sync_ref_model=True) - # - Non-PEFT: same model (or deep-copied EMA model when sync_ref_model=True) - self.teacher_model = None - - if args.sync_ref_model: - if is_peft_available() and is_peft_model(self.model): - self.add_callback( - PEFTAdapterEMACallback( - model=self.model, - teacher_adapter_name="teacher", - update_rate=args.ref_model_mixup_alpha, - sync_steps=args.ref_model_sync_steps, - accelerator=self.accelerator, - ) - ) - else: - student_model = self.accelerator.unwrap_model(self.model) - self.teacher_model = copy.deepcopy(student_model) - self.teacher_model.requires_grad_(False) - self.teacher_model.eval() - if self.is_deepspeed_enabled: - self.teacher_model = 
prepare_deepspeed(self.teacher_model, self.accelerator) - elif self.is_fsdp_enabled: - self.teacher_model = prepare_fsdp(self.teacher_model, self.accelerator) - else: - self.teacher_model = self.accelerator.prepare_model(self.teacher_model, evaluation_mode=True) - self.add_callback(SyncRefModelCallback(ref_model=self.teacher_model, accelerator=self.accelerator)) - - self.model_accepts_loss_kwargs = False - def augment_training_batch( self, inputs: list[dict[str, Any]], @@ -221,7 +180,6 @@ def augment_training_batch( rollout_batch.completion_mask, ) - old_per_token_logps = None if self.generate_from_teacher else rollout_batch.old_per_token_logps return SelfDistillationBatch( prompt_ids=teacher_batch["prompt_ids"], prompt_mask=teacher_batch["prompt_mask"], @@ -229,7 +187,7 @@ def augment_training_batch( completion_mask=rollout_batch.completion_mask, teacher_input_ids=teacher_batch["teacher_input_ids"], teacher_attention_mask=teacher_batch["teacher_attention_mask"], - old_per_token_logps=old_per_token_logps, + old_per_token_logps=rollout_batch.old_per_token_logps, ) def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): @@ -247,10 +205,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 return loss / accumulation_scale - def _get_teacher_context_for_self_distillation(self, model): - if is_peft_available() and isinstance(self.model, PeftModel): - model = self.accelerator.unwrap_model(self.model) - if self.args.sync_ref_model and "teacher" in model.peft_config: - return use_adapter(model, adapter_name="teacher") - return use_adapter(model, adapter_name=None) - return super()._get_teacher_context_for_self_distillation(model) + def _get_peft_teacher_mode(self) -> str: + if not (is_peft_available() and is_peft_model(self.model)): + return super()._get_peft_teacher_mode() + if self.args.sync_ref_model: + return "teacher_adapter" + return "disable_adapter" diff --git a/trl/experimental/sdpo/sdpo_trainer_transition.py b/trl/experimental/sdpo/sdpo_trainer_transition.py new file mode 100644 index 00000000000..d2f75aad81b --- /dev/null +++ b/trl/experimental/sdpo/sdpo_trainer_transition.py @@ -0,0 +1,391 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
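+
+# Transitional SDPO trainer built on the unified self-distillation base. A
+# minimal usage sketch (illustrative only; the model id, `my_reward_fn`, and
+# `train_dataset` are placeholders, not part of this module):
+#
+#     trainer = SDPOTrainer(
+#         model="Qwen/Qwen2.5-0.5B-Instruct",
+#         reward_funcs=my_reward_fn,
+#         args=SDPOConfig(output_dir="sdpo-out"),
+#         train_dataset=train_dataset,
+#     )
+#     trainer.train()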
+ +import copy +import re +import textwrap +from typing import Any + +import torch +from accelerate.utils import gather_object +from datasets import Dataset, IterableDataset +from torch import nn +from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback + +from ...trainer.callbacks import SyncRefModelCallback +from ...trainer.utils import pad +from ..self_distillation.teacher_context import TokenizedPromptBatch, extract_last_user_text +from ..self_distillation.unified_base_self_distillation_trainer import UnifiedBaseSelfDistillationTrainer +from .sdpo_config import SDPOConfig + + +class EMATeacherSyncCallback(SyncRefModelCallback): + """Synchronize an EMA teacher model with the student model on each step.""" + + def __init__(self, teacher_model, update_rate: float, accelerator=None): + super().__init__(ref_model=teacher_model, accelerator=accelerator) + self.update_rate = update_rate + + def on_step_end(self, args, state, control, **kwargs): + model = kwargs["model"] + if self.accelerator is not None: + model = self.accelerator.unwrap_model(model) + self.sync_target_model(model, self.ref_model, self.update_rate) + + +class SuccessfulRolloutTeacherContextBuilder: + """Builds SDPO teacher contexts from successful rollouts, following the official online implementation.""" + + def __init__(self, trainer): + self.trainer = trainer + self.last_metrics: dict[str, float] = {} + + def _build_reprompt_text(self, prompt_text: str, solution_text: str, feedback_text: str) -> str: + return self.trainer.args.reprompt_template.format( + prompt=prompt_text, + solution=solution_text, + feedback=feedback_text, + ) + + def _tokenize_teacher_messages( + self, teacher_messages_list: list[str | list[dict[str, Any]]] + ) -> TokenizedPromptBatch: + teacher_prompt_ids_list = [] + device = self.trainer.accelerator.device + chat_template_kwargs = getattr(self.trainer, "chat_template_kwargs", {}) + for msg in teacher_messages_list: + if isinstance(msg, list) and isinstance(msg[0], dict): + tokenized = self.trainer.processing_class.apply_chat_template( + msg, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + **chat_template_kwargs, + ) + if isinstance(tokenized, torch.Tensor): + ids = tokenized.squeeze(0) + else: + ids = tokenized["input_ids"].squeeze(0) + else: + ids = self.trainer.processing_class.encode(msg, return_tensors="pt").squeeze(0) + + if ids.shape[0] > self.trainer.args.max_reprompt_len: + ids = ids[-self.trainer.args.max_reprompt_len :] + teacher_prompt_ids_list.append(ids) + + teacher_prompt_ids = [ids.to(device) for ids in teacher_prompt_ids_list] + teacher_prompt_mask = [torch.ones(len(ids), dtype=torch.long, device=device) for ids in teacher_prompt_ids] + return TokenizedPromptBatch( + prompt_ids=pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), + prompt_mask=pad(teacher_prompt_mask, padding_value=0, padding_side="left"), + ) + + def build( + self, + output: dict[str, torch.Tensor | Any], + prompts: list[Any], + rewards: torch.Tensor, + feedbacks: list[Any] | None = None, + ) -> dict[str, torch.Tensor]: + device = self.trainer.accelerator.device + mode = "train" if self.trainer.model.training else "eval" + num_generations = self.trainer.num_generations if mode == "train" else self.trainer.num_generations_eval + completion_ids = output["completion_ids"] + completion_mask = output["completion_mask"] + + num_local = len(prompts) + process_start = self.trainer.accelerator.process_index * num_local + 
process_slice = slice(process_start, process_start + num_local) + + # Rewards arrive already locally sliced (per-process) from the rollout mixin; re-gather them so + # the mining loop can find successful rollouts across all processes within each generation group. + all_rewards = self.trainer.accelerator.gather(rewards) + # Completion tensors are padded to the local max length per rank; align shapes before gathering. + # Use separate variables so the original completion_ids/completion_mask stay unpadded for the + # teacher concat (they must match the student's sequence length for logits_to_keep alignment). + padded_completion_ids = self.trainer.accelerator.pad_across_processes( + completion_ids, dim=1, pad_index=self.trainer.pad_token_id + ) + all_completion_ids = self.trainer.accelerator.gather(padded_completion_ids) + all_prompts = gather_object(prompts) + total_samples = all_rewards.shape[0] + all_feedbacks = gather_object(feedbacks) if feedbacks is not None else [None] * total_samples + + threshold = self.trainer.args.success_reward_threshold + dont_reprompt_self = self.trainer.args.dont_reprompt_on_self_success + feedback_only_without_solution = self.trainer.args.environment_feedback_only_without_solution + self_distillation_mask = torch.zeros(total_samples, device=device) + num_with_solution = 0 + num_with_feedback_available = 0 + num_with_feedback_used = 0 + success_group_count = 0 + successful_demo_indices: list[int | None] = [None] * total_samples + use_feedback_flags: list[bool] = [False] * total_samples + has_solution_flags: list[bool] = [False] * total_samples + + for i in range(total_samples): + group_start = (i // num_generations) * num_generations + group_end = group_start + num_generations + + successful = [] + if self.trainer.args.use_successful_as_teacher: + for j in range(group_start, group_end): + if dont_reprompt_self and j == i: + continue + if all_rewards[j].item() >= threshold: + successful.append(j) + + if i % num_generations == 0: + # Count groups with any successful rollout, ignoring self-exclusion which only + # affects per-sample teacher assignment, not whether the group has successes. 
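+                # For example, with num_generations=4, samples 0-3 form one group and
+                # samples 4-7 the next, so i=5 scans rewards over the half-open range
+                # [4, 8).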
+                group_has_success = any(all_rewards[j].item() >= threshold for j in range(group_start, group_end))
+                if group_has_success:
+                    success_group_count += 1
+
+            raw_feedback = all_feedbacks[i]
+            has_feedback = isinstance(raw_feedback, str) and raw_feedback.strip() != ""
+            if has_feedback:
+                num_with_feedback_available += 1
+
+            has_solution = len(successful) > 0
+            has_solution_flags[i] = has_solution
+            if has_solution:
+                successful_demo_indices[i] = successful[0]
+            use_feedback = (
+                self.trainer.args.include_environment_feedback
+                and has_feedback
+                and (not feedback_only_without_solution or not has_solution)
+            )
+            use_feedback_flags[i] = use_feedback
+            if use_feedback:
+                num_with_feedback_used += 1
+            if has_solution or use_feedback:
+                self_distillation_mask[i] = 1.0
+                if has_solution:
+                    num_with_solution += 1
+
+        local_teacher_messages = []
+        local_self_distillation_mask = self_distillation_mask[process_slice]
+        for global_idx in range(process_start, process_start + num_local):
+            original_prompt = all_prompts[global_idx]
+            raw_feedback = all_feedbacks[global_idx]
+            has_solution = has_solution_flags[global_idx]
+            use_feedback = use_feedback_flags[global_idx]
+
+            if not has_solution and not use_feedback:
+                local_teacher_messages.append(original_prompt)
+                continue
+
+            solution_text = ""
+            if has_solution:
+                demo_idx = successful_demo_indices[global_idx]
+                if demo_idx is None:
+                    raise RuntimeError("Expected a successful demonstration index for an active SDPO teacher prompt.")
+                demo_ids = all_completion_ids[demo_idx]
+                demo_ids = demo_ids[demo_ids != self.trainer.processing_class.pad_token_id]
+                demo_text = self.trainer.processing_class.decode(demo_ids, skip_special_tokens=True)
+
+                if self.trainer.args.remove_thinking_from_demonstration:
+                    demo_text = re.sub(r"<think>.*?</think>", "", demo_text, flags=re.DOTALL).strip()
+
+                solution_text = self.trainer.args.solution_template.format(successful_previous_attempt=demo_text)
+
+            feedback_text = ""
+            if use_feedback:
+                feedback_text = self.trainer.args.feedback_template.format(feedback_raw=raw_feedback)
+
+            if isinstance(original_prompt, list):
+                system_messages = original_prompt[:-1]
+                prompt_text = extract_last_user_text(original_prompt)
+                reprompt_text = self._build_reprompt_text(prompt_text, solution_text, feedback_text)
+                local_teacher_messages.append(system_messages + [{"role": "user", "content": reprompt_text}])
+            else:
+                local_teacher_messages.append(self._build_reprompt_text(original_prompt, solution_text, feedback_text))
+
+        teacher_batch = self._tokenize_teacher_messages(local_teacher_messages)
+        teacher_input_ids = torch.cat([teacher_batch.prompt_ids, completion_ids], dim=1)
+        teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, completion_mask], dim=1)
+
+        batch_size = total_samples if total_samples > 0 else 1
+        num_groups = max(1, total_samples // max(1, num_generations))
+        self.last_metrics = {
+            "self_distillation/success_group_fraction": success_group_count / num_groups,
+            "self_distillation/success_sample_fraction": num_with_solution / batch_size,
+            "self_distillation/feedback_available_fraction": num_with_feedback_available / batch_size,
+            "self_distillation/feedback_used_fraction": num_with_feedback_used / batch_size,
+            "self_distillation/reprompt_sample_fraction": self_distillation_mask.float().mean().item(),
+        }
+
+        return {
+            "teacher_input_ids": teacher_input_ids,
+            "teacher_attention_mask": teacher_attention_mask,
+            "self_distillation_mask": local_self_distillation_mask,
+        }
+
+
+class 
SDPOTrainer(UnifiedBaseSelfDistillationTrainer): + """ + Trainer for Self-Distillation Policy Optimization (SDPO). + + SDPO augments on-policy optimization with self-distillation from the model's own high-reward trajectories. It + converts tokenized feedback into a dense learning signal without any external teacher or explicit reward model. + SDPO treats the current model conditioned on feedback as a self-teacher and distills its feedback-informed + next-token predictions back into the policy. + """ + + config_cls = SDPOConfig + _tag_names = ["trl", "sdpo"] + _name = "SDPO" + # docstyle-ignore + _paper = { + "title": "Reinforcement Learning via Self-Distillation", + "id": "2601.20802", + "citation": textwrap.dedent("""\ + @article{hubotter2026sdpo, + title = {{Reinforcement Learning via Self-Distillation}}, + author = {Jonas H\\"ubotter and Frederike L\\"ubeck and Lejs Behric and Anton Baumann and Marco Bagatella and Daniel Marta and Ido Hakimi and Idan Shenfeld and Thomas Kleine Buening and Carlos Guestrin and Andreas Krause}, + year = 2026, + eprint = {arXiv:2601.20802} + }"""), + } + + def __init__( + self, + model: str | PreTrainedModel | nn.Module, + reward_funcs: Any | list[Any] | None = None, + args: SDPOConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + peft_config=None, + ): + if reward_funcs is None or (isinstance(reward_funcs, list) and len(reward_funcs) == 0): + raise ValueError("`reward_funcs` is required for SDPOTrainer because SDPO must score rollouts.") + super().__init__( + model=model, + reward_funcs=reward_funcs, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + reward_processing_classes=reward_processing_classes, + callbacks=callbacks, + optimizers=optimizers, + peft_config=peft_config, + ) + self.teacher_context_builder = SuccessfulRolloutTeacherContextBuilder(self) + if self.args.teacher_regularization == "ema": + # `self.model` may already be accelerator-wrapped after the shared base constructor. Build the EMA + # teacher from the unwrapped student model first, then prepare it as an auxiliary eval-only module. 
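+            # Assuming SyncRefModelCallback's usual lerp update, each sync applies
+            # teacher <- (1 - update_rate) * teacher + update_rate * student, i.e.
+            # the teacher tracks an exponential moving average of student weights.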
+ student_model = self.accelerator.unwrap_model(self.model) + self.teacher_model = copy.deepcopy(student_model) + self.teacher_model.requires_grad_(False) + self.teacher_model.eval() + self.teacher_model = self._prepare_auxiliary_model_for_eval(self.teacher_model) + self.add_callback( + EMATeacherSyncCallback( + teacher_model=self.teacher_model, + update_rate=self.args.teacher_update_rate, + accelerator=self.accelerator, + ) + ) + + def _allow_topk_without_full_logit_distillation(self) -> bool: + return False + + def _generate_and_score_completions( + self, inputs: list[dict[str, torch.Tensor | Any]] + ) -> dict[str, torch.Tensor | Any]: + prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) + + output = super()._generate_and_score_completions(inputs) + output.update( + self.teacher_context_builder.build(output, prompts, output["rewards"], feedbacks=privileged_contexts) + ) + + mode = "train" if self.model.training else "eval" + for key, value in self.teacher_context_builder.last_metrics.items(): + self._metrics[mode][key].append(value) + self._warn_on_inactive_self_distillation(mode) + + self._dispatch_self_distillation_callback( + "on_teacher_context_built", + teacher_input_ids=output["teacher_input_ids"], + teacher_attention_mask=output["teacher_attention_mask"], + completion_mask=output["completion_mask"], + self_distillation_mask=output["self_distillation_mask"], + ) + + return output + + def _warn_on_inactive_self_distillation(self, mode: str) -> None: + metrics = self.teacher_context_builder.last_metrics + tolerance = self.args.diagnostics_flat_tolerance + + reprompt_fraction = metrics.get("self_distillation/reprompt_sample_fraction", 0.0) + success_fraction = metrics.get("self_distillation/success_group_fraction", 0.0) + + if reprompt_fraction <= tolerance: + self._warn_on_degenerate_diagnostics( + mode=mode, + counter_key="inactive_self_distillation", + message=( + "SDPO self-distillation is inactive because no reprompted samples were constructed. " + "This usually means no rollout exceeded `success_reward_threshold` and no usable privileged " + "feedback was available." + ), + ) + else: + self._diagnostic_counters[mode]["inactive_self_distillation"] = 0 + + if success_fraction <= tolerance: + self._warn_on_degenerate_diagnostics( + mode=mode, + counter_key="no_successful_rollouts", + message=( + "SDPO did not find any successful rollouts in the current generation groups. " + "If this persists, reduce task difficulty, adjust reward shaping, or lower " + "`success_reward_threshold`." 
+ ), + ) + else: + self._diagnostic_counters[mode]["no_successful_rollouts"] = 0 + + def _compute_policy_loss(self, model, inputs) -> torch.Tensor: + return super()._compute_loss(model, inputs) + + def _compute_weighted_self_distillation_loss(self, model, inputs) -> torch.Tensor | None: + if self.args.distillation_weight <= 0.0: + return None + + accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 + distillation_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale + return self.args.distillation_weight * distillation_loss + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The SDPOTrainer does not support returning outputs") + + if self.args.sdpo_policy_loss_mode == "hybrid": + policy_loss = self._compute_policy_loss(model, inputs) + weighted_distillation_loss = self._compute_weighted_self_distillation_loss(model, inputs) + return policy_loss if weighted_distillation_loss is None else policy_loss + weighted_distillation_loss + + weighted_distillation_loss = self._compute_weighted_self_distillation_loss(model, inputs) + if weighted_distillation_loss is not None: + return weighted_distillation_loss + return self._compute_policy_loss(model, inputs) diff --git a/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py b/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py index 14eb68a0a5a..d8e60101119 100644 --- a/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py @@ -25,6 +25,7 @@ import datasets import torch from accelerate.logging import get_logger +from accelerate.utils import is_peft_model from datasets import Dataset, IterableDataset from torch import nn from torch.utils.data import DataLoader, Sampler @@ -50,6 +51,7 @@ identity, pad, split_tensor_dict, + use_adapter, ) from ..utils import prepare_peft_model from .self_distillation_config import SelfDistillationConfig @@ -59,6 +61,9 @@ if is_peft_available(): from peft import PeftConfig + from peft.peft_model import PeftModel + + from .peft_adapter_ema_callback import PEFTAdapterEMACallback logger = get_logger(__name__) @@ -158,6 +163,11 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) + if is_peft_available() and is_peft_model(model) and peft_config is not None: + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config`. Pass either a base " + "model with `peft_config`, or a pre-wrapped PEFT model." + ) if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): model = prepare_peft_model(model, peft_config, args) @@ -269,17 +279,29 @@ def __init__( if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) - self.teacher_model = None self._setup_teacher_model() self.model_accepts_loss_kwargs = False def _setup_teacher_model(self) -> None: - """Prepare a generic teacher model derived from the student. + """Prepare teacher state according to the shared teacher policy.""" - Subclasses can override this when they need algorithm-specific teacher handling, such as adapter switching. 
- """ + teacher_regularization = self._get_teacher_regularization_mode() + peft_teacher_mode = self._get_peft_teacher_mode() + self._validate_teacher_policy(teacher_regularization, peft_teacher_mode) + + if teacher_regularization == "none": + return - if not self.args.sync_ref_model: + if is_peft_available() and is_peft_model(self.model) and peft_teacher_mode == "teacher_adapter": + self.add_callback( + PEFTAdapterEMACallback( + model=self.model, + teacher_adapter_name=self._get_teacher_adapter_name(), + update_rate=self.args.ref_model_mixup_alpha, + sync_steps=self.args.ref_model_sync_steps, + accelerator=self.accelerator, + ) + ) return student_model = self.accelerator.unwrap_model(self.model) @@ -296,6 +318,25 @@ def _setup_teacher_model(self) -> None: self.add_callback(SyncRefModelCallback(ref_model=self.teacher_model, accelerator=self.accelerator)) + def _get_teacher_regularization_mode(self) -> str: + return "ema" if self.args.sync_ref_model else "none" + + def _get_peft_teacher_mode(self) -> str: + return "inherit_adapter" + + def _get_teacher_adapter_name(self) -> str: + return "teacher" + + def _validate_teacher_policy(self, teacher_regularization: str, peft_teacher_mode: str) -> None: + if teacher_regularization not in {"none", "ema"}: + raise ValueError(f"Unsupported teacher regularization mode: {teacher_regularization}") + if peft_teacher_mode not in {"inherit_adapter", "disable_adapter", "teacher_adapter"}: + raise ValueError(f"Unsupported PEFT teacher mode: {peft_teacher_mode}") + if peft_teacher_mode == "teacher_adapter" and not (is_peft_available() and is_peft_model(self.model)): + raise ValueError("PEFT teacher mode `teacher_adapter` requires a PEFT model.") + if peft_teacher_mode == "teacher_adapter" and teacher_regularization != "ema": + raise ValueError("PEFT teacher mode `teacher_adapter` requires EMA teacher regularization.") + def get_train_dataloader(self): if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") @@ -539,6 +580,29 @@ def augment_training_batch( ) -> SelfDistillationBatch: """Inject teacher-side inputs and algorithm-specific fields into a common student rollout batch.""" + def _get_teacher_context_for_self_distillation(self, model): + peft_teacher_mode = self._get_peft_teacher_mode() + if not (is_peft_available() and isinstance(self.model, PeftModel)): + return super()._get_teacher_context_for_self_distillation(model) + + if peft_teacher_mode == "inherit_adapter": + return super()._get_teacher_context_for_self_distillation(model) + + target_model = self.teacher_model if self.teacher_model is not None else self.model + target_model = self.accelerator.unwrap_model(target_model) + + if peft_teacher_mode == "disable_adapter": + return use_adapter(target_model, adapter_name=None) + if peft_teacher_mode == "teacher_adapter": + teacher_adapter_name = self._get_teacher_adapter_name() + if teacher_adapter_name not in target_model.peft_config: + raise RuntimeError( + f"Expected PEFT teacher adapter `{teacher_adapter_name}` to exist before teacher forward." 
+ ) + return use_adapter(target_model, adapter_name=teacher_adapter_name) + + raise ValueError(f"Unsupported PEFT teacher mode: {peft_teacher_mode}") + @abstractmethod def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """Subclasses own algorithm-specific loss composition on the final batch contract.""" From 0628701f6c42078b3b8d52b9295951f4505d2995 Mon Sep 17 00:00:00 2001 From: Leon Date: Wed, 15 Apr 2026 12:28:07 +0200 Subject: [PATCH 03/23] sdpo transitioned, needs testing --- trl/experimental/sdft/sdft.py | 6 +- trl/experimental/sdft/sdft_config.py | 8 + trl/experimental/sdft/sdft_trainer.py | 46 ++- .../sdft/sdft_trainer_transition.py | 8 - trl/experimental/sdpo/sdpo_config.py | 24 +- trl/experimental/sdpo/sdpo_trainer.py | 6 +- .../sdpo/sdpo_trainer_transition.py | 359 +++++++++++++++--- .../base_self_distillation_trainer.py | 4 - .../self_distillation_config.py | 53 ++- ...dapter_ema_callback.py => teacher_sync.py} | 16 + .../unified_base_self_distillation_trainer.py | 49 ++- 11 files changed, 460 insertions(+), 119 deletions(-) rename trl/experimental/self_distillation/{peft_adapter_ema_callback.py => teacher_sync.py} (88%) diff --git a/trl/experimental/sdft/sdft.py b/trl/experimental/sdft/sdft.py index 958a24fc683..77e1c8c3880 100644 --- a/trl/experimental/sdft/sdft.py +++ b/trl/experimental/sdft/sdft.py @@ -47,9 +47,9 @@ --max_prompt_length 1024 \ --max_completion_length 512 \ --generate_from_teacher \ - --sync_ref_model \ - --ref_model_sync_steps 1 \ - --ref_model_mixup_alpha 0.01 \ + --teacher_regularization ema \ + --teacher_sync_steps 1 \ + --teacher_update_rate 0.01 \ --eval_strategy steps \ --eval_steps 50 \ --report_to wandb diff --git a/trl/experimental/sdft/sdft_config.py b/trl/experimental/sdft/sdft_config.py index 84227e43cbf..6e27bb17216 100644 --- a/trl/experimental/sdft/sdft_config.py +++ b/trl/experimental/sdft/sdft_config.py @@ -28,6 +28,10 @@ class SDFTConfig(SelfDistillationConfig): Parameters: disable_dropout (`bool`, *optional*, defaults to `True`): Whether to disable dropout in the student and teacher models. + peft_teacher_mode (`str`, *optional*, defaults to `"auto"`): + PEFT teacher execution mode. The default `auto` reproduces the original SDFT behavior: use the + adapter-disabled base model without EMA teacher regularization, and the EMA teacher adapter when + `teacher_regularization="ema"`. generate_from_teacher (`bool`, *optional*, defaults to `False`): Whether on-policy generation should use the teacher-conditioned prompt instead of the student prompt. teacher_prompt_template (`str`, *optional*, defaults to `"{prompt}\n\n{privileged_context}"`): @@ -40,6 +44,10 @@ class SDFTConfig(SelfDistillationConfig): default=True, metadata={"help": "Whether to disable dropout in the student and teacher models."}, ) + peft_teacher_mode: str = field( + default="auto", + metadata={"help": "PEFT teacher execution mode. 
`auto` reproduces the original SDFT teacher behavior."}, + ) generate_from_teacher: bool = field( default=False, metadata={"help": "Whether on-policy generation should use the teacher-conditioned prompt."}, diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index 057ae851678..5afc4dff57f 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -41,7 +41,6 @@ from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation from ...trainer.base_trainer import _BaseTrainer -from ...trainer.callbacks import SyncRefModelCallback from ...trainer.utils import ( RepeatSampler, create_model_from_path, @@ -54,6 +53,7 @@ ) from ..self_distillation.self_distillation_mixin import SelfDistillationMixin from ..self_distillation.teacher_context import PromptTokenizer, extract_last_user_text +from ..self_distillation.teacher_sync import PEFTAdapterEMACallback, SyncTeacherModelCallback from ..self_distillation.unified_base_self_distillation_trainer import ( SelfDistillationBatch, SelfDistillationRolloutBatch, @@ -66,8 +66,6 @@ from peft import PeftConfig from peft.peft_model import PeftModel - from ..self_distillation.peft_adapter_ema_callback import PEFTAdapterEMACallback - logger = get_logger(__name__) @@ -278,18 +276,27 @@ def __init__( self.model.add_model_tags(self._tag_names) # In self-distillation the teacher is always derived from the student: - # - PEFT: base model with adapter disabled (or EMA teacher adapter when sync_ref_model=True) - # - Non-PEFT: same model (or deep-copied EMA model when sync_ref_model=True) + # - PEFT: base model with adapter disabled (or EMA teacher adapter when teacher_regularization="ema") + # - Non-PEFT: same model (or deep-copied EMA model when teacher_regularization="ema") self.teacher_model = None - if args.sync_ref_model: - if is_peft_available() and is_peft_model(self.model): + peft_teacher_mode = args.peft_teacher_mode + if is_peft_available() and is_peft_model(self.model): + if peft_teacher_mode == "auto": + peft_teacher_mode = "teacher_adapter" if args.teacher_regularization == "ema" else "disable_adapter" + else: + if peft_teacher_mode in {"disable_adapter", "teacher_adapter"}: + raise ValueError(f"PEFT teacher mode `{peft_teacher_mode}` requires a PEFT model.") + peft_teacher_mode = "inherit_adapter" + + if args.teacher_regularization == "ema": + if peft_teacher_mode == "teacher_adapter": self.add_callback( PEFTAdapterEMACallback( model=self.model, - teacher_adapter_name="teacher", - update_rate=args.ref_model_mixup_alpha, - sync_steps=args.ref_model_sync_steps, + teacher_adapter_name=args.teacher_adapter_name, + update_rate=args.teacher_update_rate, + sync_steps=args.teacher_sync_steps, accelerator=self.accelerator, ) ) @@ -304,7 +311,9 @@ def __init__( self.teacher_model = prepare_fsdp(self.teacher_model, self.accelerator) else: self.teacher_model = self.accelerator.prepare_model(self.teacher_model, evaluation_mode=True) - self.add_callback(SyncRefModelCallback(ref_model=self.teacher_model, accelerator=self.accelerator)) + self.add_callback( + SyncTeacherModelCallback(teacher_model=self.teacher_model, accelerator=self.accelerator) + ) self.model_accepts_loss_kwargs = False @@ -508,7 +517,16 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N def _get_teacher_context_for_self_distillation(self, model): if is_peft_available() and isinstance(self.model, PeftModel): model = self.accelerator.unwrap_model(self.model) - if 
self.args.sync_ref_model and "teacher" in model.peft_config: - return use_adapter(model, adapter_name="teacher") - return use_adapter(model, adapter_name=None) + peft_teacher_mode = self.args.peft_teacher_mode + if peft_teacher_mode == "auto": + peft_teacher_mode = ( + "teacher_adapter" if self.args.teacher_regularization == "ema" else "disable_adapter" + ) + if peft_teacher_mode == "inherit_adapter": + return super()._get_teacher_context_for_self_distillation(model) + if peft_teacher_mode == "teacher_adapter" and self.args.teacher_adapter_name in model.peft_config: + return use_adapter(model, adapter_name=self.args.teacher_adapter_name) + if peft_teacher_mode == "disable_adapter": + return use_adapter(model, adapter_name=None) + raise ValueError(f"Unsupported PEFT teacher mode: {peft_teacher_mode}") return super()._get_teacher_context_for_self_distillation(model) diff --git a/trl/experimental/sdft/sdft_trainer_transition.py b/trl/experimental/sdft/sdft_trainer_transition.py index 2e5e7757ff5..cafce5bf221 100644 --- a/trl/experimental/sdft/sdft_trainer_transition.py +++ b/trl/experimental/sdft/sdft_trainer_transition.py @@ -19,7 +19,6 @@ import torch from accelerate.logging import get_logger -from accelerate.utils import is_peft_model from datasets import Dataset, IterableDataset from torch import nn from transformers import ( @@ -204,10 +203,3 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N loss = self._compute_self_distillation_loss(model, inputs) accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 return loss / accumulation_scale - - def _get_peft_teacher_mode(self) -> str: - if not (is_peft_available() and is_peft_model(self.model)): - return super()._get_peft_teacher_mode() - if self.args.sync_ref_model: - return "teacher_adapter" - return "disable_adapter" diff --git a/trl/experimental/sdpo/sdpo_config.py b/trl/experimental/sdpo/sdpo_config.py index 1cc8c1510b7..707d38e1648 100644 --- a/trl/experimental/sdpo/sdpo_config.py +++ b/trl/experimental/sdpo/sdpo_config.py @@ -41,10 +41,10 @@ class SDPOConfig(SelfDistillationConfig): teacher_regularization (`str`, *optional*, defaults to `"ema"`): Teacher update strategy. Supported: `ema`, `none`. - teacher_update_rate (`float` or `None`, *optional*): + teacher_update_rate (`float`, *optional*, defaults to `0.05`): EMA update rate used when `teacher_regularization="ema"`. - ema_update_rate (`float`, *optional*, defaults to `0.05`): - Deprecated alias for `teacher_update_rate`. + teacher_sync_steps (`int`, *optional*, defaults to `1`): + Number of optimizer steps between teacher EMA updates. > Parameters that control reprompting @@ -78,13 +78,13 @@ class SDPOConfig(SelfDistillationConfig): default="ema", metadata={"help": "Teacher regularization mode. 
Supported: `ema`, `none`."}, ) - teacher_update_rate: float | None = field( - default=None, + teacher_update_rate: float = field( + default=0.05, metadata={"help": "Teacher update rate used for EMA teacher synchronization."}, ) - ema_update_rate: float = field( - default=0.05, - metadata={"help": "Deprecated alias for `teacher_update_rate`."}, + teacher_sync_steps: int = field( + default=1, + metadata={"help": "How often to synchronize the EMA teacher model."}, ) max_reprompt_len: int = field( default=10240, @@ -125,14 +125,6 @@ class SDPOConfig(SelfDistillationConfig): def __post_init__(self): super().__post_init__() - - if self.teacher_update_rate is None: - self.teacher_update_rate = self.ema_update_rate - - if self.teacher_regularization not in {"ema", "none"}: - raise ValueError("teacher_regularization must be one of: 'ema', 'none'") - if not 0.0 <= self.teacher_update_rate <= 1.0: - raise ValueError("teacher_update_rate must be in [0, 1]") if self.sdpo_policy_loss_mode not in {"distillation_only", "hybrid"}: raise ValueError("sdpo_policy_loss_mode must be one of: 'distillation_only', 'hybrid'") if self.sdpo_policy_loss_mode == "distillation_only" and self.distillation_weight <= 0: diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py index cf346ac3b60..5d24484e255 100644 --- a/trl/experimental/sdpo/sdpo_trainer.py +++ b/trl/experimental/sdpo/sdpo_trainer.py @@ -33,11 +33,14 @@ class EMATeacherSyncCallback(SyncRefModelCallback): """Synchronize an EMA teacher model with the student model on each step.""" - def __init__(self, teacher_model, update_rate: float, accelerator=None): + def __init__(self, teacher_model, update_rate: float, sync_steps: int, accelerator=None): super().__init__(ref_model=teacher_model, accelerator=accelerator) self.update_rate = update_rate + self.sync_steps = sync_steps def on_step_end(self, args, state, control, **kwargs): + if state.global_step % self.sync_steps != 0: + return model = kwargs["model"] if self.accelerator is not None: model = self.accelerator.unwrap_model(model) @@ -300,6 +303,7 @@ def __init__( EMATeacherSyncCallback( teacher_model=self.teacher_model, update_rate=self.args.teacher_update_rate, + sync_steps=self.args.teacher_sync_steps, accelerator=self.accelerator, ) ) diff --git a/trl/experimental/sdpo/sdpo_trainer_transition.py b/trl/experimental/sdpo/sdpo_trainer_transition.py index d2f75aad81b..8fe68e7e7aa 100644 --- a/trl/experimental/sdpo/sdpo_trainer_transition.py +++ b/trl/experimental/sdpo/sdpo_trainer_transition.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy import re import textwrap from typing import Any @@ -21,27 +20,30 @@ from accelerate.utils import gather_object from datasets import Dataset, IterableDataset from torch import nn -from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback - -from ...trainer.callbacks import SyncRefModelCallback -from ...trainer.utils import pad +from transformers import ( + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.utils import logging + +from ...data_utils import apply_chat_template, is_conversational +from ...models import prepare_deepspeed, prepare_fsdp +from ...trainer.base_trainer import _BaseTrainer +from ...trainer.utils import get_config_model_id, pad from ..self_distillation.teacher_context import TokenizedPromptBatch, extract_last_user_text -from ..self_distillation.unified_base_self_distillation_trainer import UnifiedBaseSelfDistillationTrainer +from ..self_distillation.unified_base_self_distillation_trainer import ( + SelfDistillationBatch, + SelfDistillationRolloutBatch, + UnifiedBaseSelfDistillationTrainer, +) from .sdpo_config import SDPOConfig -class EMATeacherSyncCallback(SyncRefModelCallback): - """Synchronize an EMA teacher model with the student model on each step.""" - - def __init__(self, teacher_model, update_rate: float, accelerator=None): - super().__init__(ref_model=teacher_model, accelerator=accelerator) - self.update_rate = update_rate - - def on_step_end(self, args, state, control, **kwargs): - model = kwargs["model"] - if self.accelerator is not None: - model = self.accelerator.unwrap_model(model) - self.sync_target_model(model, self.ref_model, self.update_rate) +logger = logging.get_logger(__name__) class SuccessfulRolloutTeacherContextBuilder: @@ -277,44 +279,197 @@ def __init__( raise ValueError("`reward_funcs` is required for SDPOTrainer because SDPO must score rollouts.") super().__init__( model=model, - reward_funcs=reward_funcs, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, - reward_processing_classes=reward_processing_classes, callbacks=callbacks, optimizers=optimizers, peft_config=peft_config, ) - self.teacher_context_builder = SuccessfulRolloutTeacherContextBuilder(self) - if self.args.teacher_regularization == "ema": - # `self.model` may already be accelerator-wrapped after the shared base constructor. Build the EMA - # teacher from the unwrapped student model first, then prepare it as an auxiliary eval-only module. 
- student_model = self.accelerator.unwrap_model(self.model) - self.teacher_model = copy.deepcopy(student_model) - self.teacher_model.requires_grad_(False) - self.teacher_model.eval() - self.teacher_model = self._prepare_auxiliary_model_for_eval(self.teacher_model) - self.add_callback( - EMATeacherSyncCallback( - teacher_model=self.teacher_model, - update_rate=self.args.teacher_update_rate, - accelerator=self.accelerator, + self.importance_sampling_level = args.importance_sampling_level + self.scale_rewards = args.scale_rewards + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high + self.beta = args.beta + + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + self.reward_func_names = [] + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + reward_model_init_kwargs = args.model_init_kwargs or {} + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + reward_model_init_kwargs["device_map"] = None + reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( + reward_func, + num_labels=1, + **reward_model_init_kwargs, ) - ) + if isinstance(reward_funcs[i], nn.Module): + self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) + else: + self.reward_func_names.append(reward_funcs[i].__name__) + self.reward_funcs = reward_funcs + + if args.reward_weights is not None: + if len(args.reward_weights) != len(self.reward_funcs): + raise ValueError("Number of reward weights must match number of reward functions") + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(self.reward_funcs), dtype=torch.float32) + + if reward_processing_classes is None: + reward_processing_classes = [None] * len(self.reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + if len(reward_processing_classes) != len(self.reward_funcs): + raise ValueError("Number of reward processing classes must match number of reward functions") + + for i, (reward_processing_class, reward_func) in enumerate( + zip(reward_processing_classes, self.reward_funcs, strict=True) + ): + if isinstance(reward_func, PreTrainedModel): + if reward_processing_class is None: + reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config)) + if reward_processing_class.pad_token_id is None: + reward_processing_class.pad_token = reward_processing_class.eos_token + reward_func.config.pad_token_id = reward_processing_class.pad_token_id + reward_processing_classes[i] = reward_processing_class + self.reward_processing_classes = reward_processing_classes + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, nn.Module): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + elif self.is_fsdp_enabled: + self.reward_funcs[i] = prepare_fsdp(reward_func, self.accelerator) + else: + self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) + + self.teacher_context_builder = SuccessfulRolloutTeacherContextBuilder(self) def _allow_topk_without_full_logit_distillation(self) -> bool: return False - def _generate_and_score_completions( - self, inputs: list[dict[str, torch.Tensor | Any]] - ) -> dict[str, torch.Tensor | Any]: + def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): + device = self.accelerator.device + if 
len(self.reward_funcs) == 0: + return torch.zeros((len(prompts), 0), device=device) + + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + reward_kwargs["trainer_state"] = self.state + + for i, (reward_func, reward_processing_class) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes, strict=True) + ): + if isinstance(reward_func, nn.Module): + if is_conversational(inputs[0]): + messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)] + texts = [ + apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"] + for x in messages + ] + else: + texts = [p + c for p, c in zip(prompts, completions, strict=True)] + reward_inputs = reward_processing_class( + text=texts, + return_tensors="pt", + padding=True, + padding_side="right", + add_special_tokens=False, + ) + reward_inputs = _BaseTrainer._prepare_inputs(self, reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] + else: + output_reward_func = reward_func( + prompts=prompts, + completions=completions, + completion_ids=completion_ids_list, + **reward_kwargs, + ) + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + return self.accelerator.gather(rewards_per_func) + + def augment_training_batch( + self, + inputs: list[dict[str, Any]], + rollout_batch: SelfDistillationRolloutBatch, + ) -> SelfDistillationBatch: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) + raw_completion_lengths = rollout_batch.metadata["raw_completion_lengths"].detach().cpu().tolist() + completion_ids_list = [ + ids[:length].tolist() + for ids, length in zip(rollout_batch.completion_ids.detach().cpu(), raw_completion_lengths, strict=True) + ] + if is_conversational({"prompt": prompts[0]}): + completions_text = self.processing_class.batch_decode( + rollout_batch.completion_ids, skip_special_tokens=True + ) + completions = [[{"role": "assistant", "content": content}] for content in completions_text] + else: + completions = self.processing_class.batch_decode(rollout_batch.completion_ids, skip_special_tokens=True) - output = super()._generate_and_score_completions(inputs) - output.update( - self.teacher_context_builder.build(output, prompts, output["rewards"], feedbacks=privileged_contexts) + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + if rewards_per_func.numel() == 0: + rewards = torch.zeros(self.accelerator.num_processes * len(prompts), device=device) + else: + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0) + if self.scale_rewards == "batch": + std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) + group_std_rewards = rewards.view(-1, num_generations).std(dim=1) + elif self.scale_rewards == "none": + std_rewards = torch.ones_like(rewards) + 
group_std_rewards = torch.ones(rewards.numel() // num_generations, device=device, dtype=rewards.dtype) + else: + group_std_rewards = rewards.view(-1, num_generations).std(dim=1) + std_rewards = group_std_rewards.repeat_interleave(num_generations, dim=0) + advantages = (rewards - mean_grouped_rewards) / (std_rewards + 1e-4) + self._record_reward_diagnostics(mode, rewards, rewards_per_func, group_std_rewards) + + local_batch_size = rollout_batch.completion_ids.size(0) + process_start = self.accelerator.process_index * local_batch_size + process_slice = slice(process_start, process_start + local_batch_size) + local_rewards = rewards[process_slice] + local_advantages = advantages[process_slice] + + agg_completion_lengths = self.accelerator.gather( + torch.tensor([len(ids) for ids in completion_ids_list], device=device) + ) + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] + if len(term_completion_lengths) == 0: + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + rollout_dict = rollout_batch.to_dict() + rollout_dict["rewards"] = local_rewards + rollout_dict["advantages"] = local_advantages + rollout_dict["num_items_in_batch"] = rollout_batch.completion_mask.sum().detach() + teacher_context = self.teacher_context_builder.build( + rollout_dict, + prompts, + rollout_dict["rewards"], + feedbacks=privileged_contexts, ) mode = "train" if self.model.training else "eval" @@ -324,13 +479,26 @@ def _generate_and_score_completions( self._dispatch_self_distillation_callback( "on_teacher_context_built", - teacher_input_ids=output["teacher_input_ids"], - teacher_attention_mask=output["teacher_attention_mask"], - completion_mask=output["completion_mask"], - self_distillation_mask=output["self_distillation_mask"], + teacher_input_ids=teacher_context["teacher_input_ids"], + teacher_attention_mask=teacher_context["teacher_attention_mask"], + completion_mask=rollout_batch.completion_mask, + self_distillation_mask=teacher_context["self_distillation_mask"], ) - return output + return SelfDistillationBatch( + prompt_ids=rollout_batch.prompt_ids, + prompt_mask=rollout_batch.prompt_mask, + completion_ids=rollout_batch.completion_ids, + completion_mask=rollout_batch.completion_mask, + teacher_input_ids=teacher_context["teacher_input_ids"], + teacher_attention_mask=teacher_context["teacher_attention_mask"], + old_per_token_logps=rollout_batch.old_per_token_logps, + self_distillation_mask=teacher_context["self_distillation_mask"], + metadata={ + "rewards": local_rewards, + "advantages": local_advantages, + }, + ) def 
_warn_on_inactive_self_distillation(self, mode: str) -> None: metrics = self.teacher_context_builder.last_metrics @@ -365,8 +533,105 @@ def _warn_on_inactive_self_distillation(self, mode: str) -> None: else: self._diagnostic_counters[mode]["no_successful_rollouts"] = 0 + def _record_reward_diagnostics( + self, + mode: str, + rewards: torch.Tensor, + rewards_per_func: torch.Tensor, + group_std_rewards: torch.Tensor, + ) -> None: + tolerance = self.args.diagnostics_flat_tolerance + + reward_mean = rewards.mean() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + reward_std = rewards.std() if rewards.numel() > 1 else torch.tensor(0.0, device=self.accelerator.device) + reward_min = rewards.min() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + reward_max = rewards.max() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + flat_group_fraction = ( + (group_std_rewards <= tolerance).float().mean() + if group_std_rewards.numel() > 0 + else torch.tensor(1.0, device=self.accelerator.device) + ) + + self._metrics[mode]["self_distillation/reward_mean"].append(self.accelerator.gather(reward_mean).mean().item()) + self._metrics[mode]["self_distillation/reward_std"].append(self.accelerator.gather(reward_std).mean().item()) + self._metrics[mode]["self_distillation/reward_min"].append(self.accelerator.gather(reward_min).min().item()) + self._metrics[mode]["self_distillation/reward_max"].append(self.accelerator.gather(reward_max).max().item()) + self._metrics[mode]["self_distillation/group_reward_std_mean"].append( + self.accelerator.gather(group_std_rewards.mean() if group_std_rewards.numel() > 0 else reward_std) + .mean() + .item() + ) + self._metrics[mode]["self_distillation/flat_group_fraction"].append( + self.accelerator.gather(flat_group_fraction).mean().item() + ) + + if rewards_per_func.numel() > 0: + reward_func_means = rewards_per_func.nanmean(dim=0) + gathered_means = self.accelerator.gather(reward_func_means).view(-1, reward_func_means.numel()).mean(dim=0) + for reward_name, reward_func_mean in zip(self.reward_func_names, gathered_means.tolist(), strict=True): + self._metrics[mode][f"self_distillation/rewards/{reward_name}"].append(reward_func_mean) + + reward_is_flat = reward_std.item() <= tolerance + grouped_rewards_are_flat = flat_group_fraction.item() >= 1.0 - tolerance + if reward_is_flat and grouped_rewards_are_flat: + self._warn_on_degenerate_diagnostics( + mode=mode, + counter_key="flat_rewards", + message=( + "Observed flat SDPO rewards across all sampled generations. " + "Policy advantages will collapse to zero, and SDPO will not learn. " + "Check reward density, reward shaping, or `success_reward_threshold`." 
+ ), + ) + else: + self._diagnostic_counters[mode]["flat_rewards"] = 0 + + def _warn_on_degenerate_diagnostics(self, mode: str, counter_key: str, message: str) -> None: + interval = self.args.diagnostics_warning_interval + if interval == 0: + return + + self._diagnostic_counters[mode][counter_key] += 1 + count = self._diagnostic_counters[mode][counter_key] + if count == 1 or count % interval == 0: + logger.warning("%s Consecutive degenerate steps: %s.", message, count) + def _compute_policy_loss(self, model, inputs) -> torch.Tensor: - return super()._compute_loss(model, inputs) + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) + per_token_logps, _ = self._get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=False, + ) + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + advantages = inputs["advantages"] + if advantages.dim() == 1: + advantages = advantages.unsqueeze(1) + log_ratio = per_token_logps - old_per_token_logps + if self.importance_sampling_level == "sequence": + log_ratio = (log_ratio * completion_mask).sum(-1, keepdim=True) / completion_mask.sum( + -1, keepdim=True + ).clamp(min=1.0) + coef_1 = torch.exp(log_ratio) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + per_token_loss = -torch.min(coef_1 * advantages, coef_2 * advantages) + + loss = self._aggregate_self_distillation_loss(per_token_loss, completion_mask) + + mode = "train" if self.model.training else "eval" + self._metrics[mode]["self_distillation/policy_loss"].append( + self.accelerator.gather(loss.detach()).mean().item() + ) + + accumulation_scale = self.current_gradient_accumulation_steps if mode == "train" else 1.0 + return loss / accumulation_scale def _compute_weighted_self_distillation_loss(self, model, inputs) -> torch.Tensor | None: if self.args.distillation_weight <= 0.0: diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index bd9abb95164..5b658b19a10 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -276,10 +276,6 @@ def __init__( self.model_accepts_loss_kwargs = False self.ref_model = None self.teacher_model = None - if args.sync_ref_model: - raise ValueError( - "sync_ref_model is not supported on the shared online self-distillation base without `ref_model`." - ) def get_train_dataloader(self): if self.train_dataset is None: diff --git a/trl/experimental/self_distillation/self_distillation_config.py b/trl/experimental/self_distillation/self_distillation_config.py index c28c0f6be0b..4611ddacd7a 100644 --- a/trl/experimental/self_distillation/self_distillation_config.py +++ b/trl/experimental/self_distillation/self_distillation_config.py @@ -52,6 +52,20 @@ class SelfDistillationConfig(_BaseConfig): scale_rewards (`str` or `bool`, *optional*, defaults to `"group"`): Reward normalization mode. Supported: `group`, `batch`, `none`. 
+ > Parameters that control teacher construction + + teacher_regularization (`str`, *optional*, defaults to `"none"`): + Teacher update strategy. Supported: `none`, `ema`. + teacher_update_rate (`float`, *optional*, defaults to `0.6`): + EMA update rate used when `teacher_regularization="ema"`. + teacher_sync_steps (`int`, *optional*, defaults to `512`): + Number of optimizer steps between EMA teacher updates. + peft_teacher_mode (`str`, *optional*, defaults to `"inherit_adapter"`): + How teacher forwards should behave for PEFT models. Supported: `auto`, `inherit_adapter`, + `disable_adapter`, `teacher_adapter`. + teacher_adapter_name (`str`, *optional*, defaults to `"teacher"`): + Adapter name used when `peft_teacher_mode="teacher_adapter"`. + > Parameters that control self-distillation distillation_alpha (`float`, *optional*, defaults to `0.5`): @@ -238,17 +252,28 @@ class SelfDistillationConfig(_BaseConfig): default=False, metadata={"help": "Whether to exclude truncated completions from the loss."}, ) - sync_ref_model: bool = field( - default=False, - metadata={"help": "Whether to synchronize the reference model with the student model."}, + teacher_regularization: str = field( + default="none", + metadata={"help": "Teacher regularization mode. Supported: `none`, `ema`."}, ) - ref_model_mixup_alpha: float = field( + teacher_update_rate: float = field( default=0.6, - metadata={"help": "EMA mix coefficient used when syncing the reference model."}, + metadata={"help": "EMA update rate used when synchronizing the teacher model."}, ) - ref_model_sync_steps: int = field( + teacher_sync_steps: int = field( default=512, - metadata={"help": "How often to synchronize the reference model."}, + metadata={"help": "How often to synchronize the teacher model."}, + ) + peft_teacher_mode: str = field( + default="inherit_adapter", + metadata={ + "help": "Teacher execution mode for PEFT models. Supported: `auto`, `inherit_adapter`, " + "`disable_adapter`, `teacher_adapter`." 
+ }, + ) + teacher_adapter_name: str = field( + default="teacher", + metadata={"help": "Adapter name used for PEFT teacher forwards when `peft_teacher_mode='teacher_adapter'`."}, ) top_entropy_quantile: float = field( default=1.0, @@ -302,6 +327,20 @@ def __post_init__(self): raise ValueError("importance_sampling_level must be either 'token' or 'sequence'") if self.loss_type not in ["grpo", "bnpo", "dr_grpo", "dapo"]: raise ValueError("loss_type must be one of: 'grpo', 'bnpo', 'dr_grpo', 'dapo'") + if self.teacher_regularization not in {"none", "ema"}: + raise ValueError("teacher_regularization must be one of: 'none', 'ema'") + if not 0.0 <= self.teacher_update_rate <= 1.0: + raise ValueError("teacher_update_rate must be in [0, 1]") + if self.teacher_sync_steps <= 0: + raise ValueError("teacher_sync_steps must be positive") + if self.peft_teacher_mode not in {"auto", "inherit_adapter", "disable_adapter", "teacher_adapter"}: + raise ValueError( + "peft_teacher_mode must be one of: 'auto', 'inherit_adapter', 'disable_adapter', 'teacher_adapter'" + ) + if self.peft_teacher_mode == "teacher_adapter" and self.teacher_regularization != "ema": + raise ValueError("peft_teacher_mode='teacher_adapter' requires teacher_regularization='ema'") + if self.teacher_adapter_name == "": + raise ValueError("teacher_adapter_name must be non-empty") if self.num_generations < 1: raise ValueError("num_generations must be at least 1") if not 0.0 <= self.distillation_alpha <= 1.0: diff --git a/trl/experimental/self_distillation/peft_adapter_ema_callback.py b/trl/experimental/self_distillation/teacher_sync.py similarity index 88% rename from trl/experimental/self_distillation/peft_adapter_ema_callback.py rename to trl/experimental/self_distillation/teacher_sync.py index e252bb512a4..267ef4d5854 100644 --- a/trl/experimental/self_distillation/peft_adapter_ema_callback.py +++ b/trl/experimental/self_distillation/teacher_sync.py @@ -22,10 +22,26 @@ TrainingArguments, ) +from ...trainer.callbacks import SyncRefModelCallback + logger = logging.getLogger(__name__) +class SyncTeacherModelCallback(SyncRefModelCallback): + """Synchronize an EMA teacher model with the student model on each configured sync step.""" + + def __init__(self, teacher_model, accelerator=None): + super().__init__(ref_model=teacher_model, accelerator=accelerator) + + def on_step_end(self, args, state, control, **kwargs): + model = kwargs["model"] + if self.ref_model is not None and state.global_step % args.teacher_sync_steps == 0: + if self.accelerator: + model = self.accelerator.unwrap_model(model) + self.sync_target_model(model, self.ref_model, args.teacher_update_rate) + + class PEFTAdapterEMACallback(TrainerCallback): """ Callback that maintains an EMA copy of PEFT adapter weights for use as a teacher model in self-distillation. 
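Both callbacks above implement the same exponential-moving-average update; they differ only in trigger cadence and in which parameters they touch (the full deep-copied teacher vs. the EMA teacher adapter weights). As a reference for the `teacher_update_rate` semantics, here is a minimal standalone sketch, assuming `sync_target_model` keeps TRL's usual Polyak-style mix (the `ema_sync` helper name is illustrative, not part of this patch):

    import torch

    @torch.no_grad()
    def ema_sync(student: torch.nn.Module, teacher: torch.nn.Module, update_rate: float) -> None:
        # teacher <- (1 - update_rate) * teacher + update_rate * student
        # update_rate=1.0 is a hard copy; small rates yield a slowly drifting teacher.
        for t_param, s_param in zip(teacher.parameters(), student.parameters(), strict=True):
            t_param.mul_(1.0 - update_rate).add_(s_param, alpha=update_rate)

Under this reading, `teacher_update_rate=1.0` with `teacher_sync_steps=1` degenerates to copying the student after every optimizer step, while the SDFT example's `--teacher_update_rate 0.01` keeps the teacher close to its own history.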
diff --git a/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py b/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py index d8e60101119..ff738c4a548 100644 --- a/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py @@ -42,7 +42,6 @@ from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation from ...trainer.base_trainer import _BaseTrainer -from ...trainer.callbacks import SyncRefModelCallback from ...trainer.utils import ( RepeatSampler, create_model_from_path, @@ -57,14 +56,13 @@ from .self_distillation_config import SelfDistillationConfig from .self_distillation_mixin import SelfDistillationMixin from .teacher_context import PromptTokenizer +from .teacher_sync import PEFTAdapterEMACallback, SyncTeacherModelCallback if is_peft_available(): from peft import PeftConfig from peft.peft_model import PeftModel - from .peft_adapter_ema_callback import PEFTAdapterEMACallback - logger = get_logger(__name__) @@ -196,6 +194,7 @@ def __init__( self.num_iterations = args.num_iterations self.shuffle_dataset = args.shuffle_dataset self.loss_type = args.loss_type + self.mask_truncated_completions = args.mask_truncated_completions self.temperature = args.temperature self.chat_template_kwargs = args.chat_template_kwargs or {} self._step = 0 @@ -285,8 +284,8 @@ def __init__( def _setup_teacher_model(self) -> None: """Prepare teacher state according to the shared teacher policy.""" - teacher_regularization = self._get_teacher_regularization_mode() - peft_teacher_mode = self._get_peft_teacher_mode() + teacher_regularization = self.args.teacher_regularization + peft_teacher_mode = self._resolve_peft_teacher_mode() self._validate_teacher_policy(teacher_regularization, peft_teacher_mode) if teacher_regularization == "none": @@ -296,9 +295,9 @@ def _setup_teacher_model(self) -> None: self.add_callback( PEFTAdapterEMACallback( model=self.model, - teacher_adapter_name=self._get_teacher_adapter_name(), - update_rate=self.args.ref_model_mixup_alpha, - sync_steps=self.args.ref_model_sync_steps, + teacher_adapter_name=self.args.teacher_adapter_name, + update_rate=self.args.teacher_update_rate, + sync_steps=self.args.teacher_sync_steps, accelerator=self.accelerator, ) ) @@ -316,24 +315,27 @@ def _setup_teacher_model(self) -> None: else: self.teacher_model = self.accelerator.prepare_model(self.teacher_model, evaluation_mode=True) - self.add_callback(SyncRefModelCallback(ref_model=self.teacher_model, accelerator=self.accelerator)) + self.add_callback(SyncTeacherModelCallback(teacher_model=self.teacher_model, accelerator=self.accelerator)) - def _get_teacher_regularization_mode(self) -> str: - return "ema" if self.args.sync_ref_model else "none" + def _resolve_peft_teacher_mode(self) -> str: + peft_teacher_mode = self.args.peft_teacher_mode + if not (is_peft_available() and is_peft_model(self.model)): + if peft_teacher_mode in {"disable_adapter", "teacher_adapter"}: + raise ValueError(f"PEFT teacher mode `{peft_teacher_mode}` requires a PEFT model.") + return "inherit_adapter" - def _get_peft_teacher_mode(self) -> str: - return "inherit_adapter" + if peft_teacher_mode == "auto": + if self.args.teacher_regularization == "ema": + return "teacher_adapter" + return "disable_adapter" - def _get_teacher_adapter_name(self) -> str: - return "teacher" + return peft_teacher_mode def _validate_teacher_policy(self, teacher_regularization: str, 
peft_teacher_mode: str) -> None: if teacher_regularization not in {"none", "ema"}: raise ValueError(f"Unsupported teacher regularization mode: {teacher_regularization}") if peft_teacher_mode not in {"inherit_adapter", "disable_adapter", "teacher_adapter"}: raise ValueError(f"Unsupported PEFT teacher mode: {peft_teacher_mode}") - if peft_teacher_mode == "teacher_adapter" and not (is_peft_available() and is_peft_model(self.model)): - raise ValueError("PEFT teacher mode `teacher_adapter` requires a PEFT model.") if peft_teacher_mode == "teacher_adapter" and teacher_regularization != "ema": raise ValueError("PEFT teacher mode `teacher_adapter` requires EMA teacher regularization.") @@ -444,6 +446,10 @@ def build_rollout_batch(self, inputs: list[dict[str, Any]]) -> SelfDistillationR completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right").to(device=device) completion_mask = pad(completion_mask, padding_value=0, padding_side="right").to(device=device) + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() old_per_token_logps = self.compute_rollout_logps( prompt_ids=prompt_ids, prompt_mask=prompt_mask, @@ -457,6 +463,11 @@ def build_rollout_batch(self, inputs: list[dict[str, Any]]) -> SelfDistillationR completion_ids=completion_ids, completion_mask=completion_mask, old_per_token_logps=old_per_token_logps, + metadata={ + "raw_completion_lengths": torch.tensor( + [len(ids) for ids in completion_ids_list], device=device, dtype=torch.long + ) + }, ) def _generate(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: @@ -581,7 +592,7 @@ def augment_training_batch( """Inject teacher-side inputs and algorithm-specific fields into a common student rollout batch.""" def _get_teacher_context_for_self_distillation(self, model): - peft_teacher_mode = self._get_peft_teacher_mode() + peft_teacher_mode = self._resolve_peft_teacher_mode() if not (is_peft_available() and isinstance(self.model, PeftModel)): return super()._get_teacher_context_for_self_distillation(model) @@ -594,7 +605,7 @@ def _get_teacher_context_for_self_distillation(self, model): if peft_teacher_mode == "disable_adapter": return use_adapter(target_model, adapter_name=None) if peft_teacher_mode == "teacher_adapter": - teacher_adapter_name = self._get_teacher_adapter_name() + teacher_adapter_name = self.args.teacher_adapter_name if teacher_adapter_name not in target_model.peft_config: raise RuntimeError( f"Expected PEFT teacher adapter `{teacher_adapter_name}` to exist before teacher forward." 
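The `auto` case in `_resolve_peft_teacher_mode` collapses to a small decision table: EMA regularization pairs with the EMA teacher adapter, any other regularization falls back to the adapter-disabled base model, and non-PEFT models always run teacher forwards on the student weights. A standalone restatement of that logic (the free function below is illustrative only; the trainer method above is authoritative):

    def resolve_peft_teacher_mode(mode: str, teacher_regularization: str, is_peft: bool) -> str:
        # Non-PEFT models have no adapters to toggle, so teacher forwards use the student weights.
        if not is_peft:
            if mode in {"disable_adapter", "teacher_adapter"}:
                raise ValueError(f"PEFT teacher mode `{mode}` requires a PEFT model.")
            return "inherit_adapter"
        # `auto`: use the EMA teacher adapter when EMA regularization is on, else the frozen base model.
        if mode == "auto":
            return "teacher_adapter" if teacher_regularization == "ema" else "disable_adapter"
        return mode

    assert resolve_peft_teacher_mode("auto", "ema", is_peft=True) == "teacher_adapter"
    assert resolve_peft_teacher_mode("auto", "none", is_peft=True) == "disable_adapter"
    assert resolve_peft_teacher_mode("auto", "none", is_peft=False) == "inherit_adapter"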
From 55111ff9dc7e5bb75536ad4b0024b68e2ce05ef8 Mon Sep 17 00:00:00 2001 From: Leon Date: Wed, 15 Apr 2026 16:19:38 +0200 Subject: [PATCH 04/23] remove legacy trainers --- trl/experimental/sdft/sdft.py | 21 +- trl/experimental/sdft/sdft_trainer.py | 345 +-------- .../sdft/sdft_trainer_transition.py | 205 ------ trl/experimental/sdpo/sdpo.py | 3 +- trl/experimental/sdpo/sdpo_trainer.py | 363 ++++++++-- .../sdpo/sdpo_trainer_transition.py | 656 ------------------ .../base_self_distillation_trainer.py | 452 +++++++++--- .../self_distillation_mixin.py | 25 - .../unified_base_self_distillation_trainer.py | 619 ----------------- 9 files changed, 694 insertions(+), 1995 deletions(-) delete mode 100644 trl/experimental/sdft/sdft_trainer_transition.py delete mode 100644 trl/experimental/sdpo/sdpo_trainer_transition.py delete mode 100644 trl/experimental/self_distillation/unified_base_self_distillation_trainer.py diff --git a/trl/experimental/sdft/sdft.py b/trl/experimental/sdft/sdft.py index 77e1c8c3880..64521a2707e 100644 --- a/trl/experimental/sdft/sdft.py +++ b/trl/experimental/sdft/sdft.py @@ -46,7 +46,6 @@ --learning_rate 2e-5 \ --max_prompt_length 1024 \ --max_completion_length 512 \ - --generate_from_teacher \ --teacher_regularization ema \ --teacher_sync_steps 1 \ --teacher_update_rate 0.01 \ @@ -77,7 +76,8 @@ get_quantization_config, ) from trl.data_utils import maybe_apply_chat_template -from trl.experimental.sdft import SDFTConfig, SDFTTrainer +from trl.experimental.sdft import SDFTConfig +from trl.experimental.sdft.sdft_trainer_transition import SDFTTrainer from trl.models import unwrap_model_for_generation @@ -86,10 +86,6 @@ @dataclass class SDFTScriptArguments(ScriptArguments): - ref_model_name_or_path: str | None = field( - default=None, - metadata={"help": "Reference teacher model. Optional for PEFT runs, where the base model is used as teacher."}, - ) dataset_path: str | None = field( default=None, metadata={"help": "Optional local dataset path to load with `load_from_disk`. Overrides `dataset_name`."}, @@ -328,9 +324,10 @@ def _run_tooluse_eval( if model_args.model_name_or_path is None: raise ValueError("`model_name_or_path` is required.") - if script_args.ref_model_name_or_path is None and not model_args.use_peft: - script_args.ref_model_name_or_path = model_args.model_name_or_path - + if training_args.generate_from_teacher: + raise ValueError( + "`generate_from_teacher` is not yet supported by the transitioned SDFT trainer used in this script." 
+ ) if model_args.dtype in ["auto", None]: if training_args.bf16: dtype = torch.bfloat16 @@ -385,16 +382,10 @@ def _run_tooluse_eval( eval_dataset = _prepare_split(raw_eval_dataset, script_args) model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) - ref_model = None - if script_args.ref_model_name_or_path is not None: - ref_model = AutoModelForCausalLM.from_pretrained(script_args.ref_model_name_or_path, **model_kwargs) model.config.use_cache = False if training_args.gradient_checkpointing else True - if ref_model is not None: - ref_model.config.use_cache = True trainer = SDFTTrainer( model=model, - ref_model=ref_model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index 5afc4dff57f..057433762d2 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -14,57 +14,32 @@ from __future__ import annotations -import copy -import inspect import textwrap -from collections import defaultdict -from functools import partial from typing import Any -import datasets import torch from accelerate.logging import get_logger -from accelerate.utils import is_peft_model from datasets import Dataset, IterableDataset from torch import nn -from torch.utils.data import DataLoader, Sampler from transformers import ( - AutoProcessor, - GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, ) -from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_peft_available +from transformers.utils import is_peft_available -from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation -from ...trainer.base_trainer import _BaseTrainer -from ...trainer.utils import ( - RepeatSampler, - create_model_from_path, - disable_dropout_in_model, - get_config_model_id, - identity, - pad, - split_tensor_dict, - use_adapter, -) -from ..self_distillation.self_distillation_mixin import SelfDistillationMixin -from ..self_distillation.teacher_context import PromptTokenizer, extract_last_user_text -from ..self_distillation.teacher_sync import PEFTAdapterEMACallback, SyncTeacherModelCallback -from ..self_distillation.unified_base_self_distillation_trainer import ( +from ..self_distillation.base_self_distillation_trainer import ( + BaseSelfDistillationTrainer, SelfDistillationBatch, SelfDistillationRolloutBatch, ) -from ..utils import prepare_peft_model +from ..self_distillation.teacher_context import PromptTokenizer, extract_last_user_text from .sdft_config import SDFTConfig if is_peft_available(): from peft import PeftConfig - from peft.peft_model import PeftModel logger = get_logger(__name__) @@ -140,7 +115,7 @@ def build( } -class SDFTTrainer(SelfDistillationMixin, _BaseTrainer): +class SDFTTrainer(BaseSelfDistillationTrainer): """Trainer for SDFT-style on-policy self-distillation with explicit teacher prompts.""" _tag_names = ["trl", "sdft"] @@ -170,310 +145,26 @@ def __init__( optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), peft_config: PeftConfig | None = None, ): - if train_dataset is None: - raise ValueError("`train_dataset` is required") if isinstance(train_dataset, IterableDataset): raise NotImplementedError("Iterable datasets are not yet supported in SDFTTrainer.") if isinstance(eval_dataset, IterableDataset) or ( isinstance(eval_dataset, dict) and any(isinstance(ds, 
IterableDataset) for ds in eval_dataset.values()) ): raise NotImplementedError("Iterable eval datasets are not yet supported in SDFTTrainer.") - if args.use_vllm: - raise NotImplementedError("SDFTTrainer does not support `use_vllm=True` yet.") - if isinstance(model, str): - model_init_kwargs = args.model_init_kwargs or {} - if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: - model_init_kwargs["device_map"] = None - model = create_model_from_path(model, **model_init_kwargs) - elif args.model_init_kwargs is not None: - logger.warning( - "You passed `model_init_kwargs` to `SDFTConfig`, but `model` is already instantiated. " - "The `model_init_kwargs` will be ignored." - ) - - self.model_kwarg_keys = ( - inspect.signature(model.forward).parameters.keys() - if not hasattr(model, "get_base_model") - else inspect.signature(model.get_base_model().forward).parameters.keys() - ) - - if is_peft_available() and is_peft_model(model) and peft_config is not None: - raise ValueError( - "You passed a `PeftModel` instance together with a `peft_config` to SDFTTrainer. Pass either a base " - "model with `peft_config`, or a pre-wrapped PEFT model." - ) - if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): - model = prepare_peft_model(model, peft_config, args) - - if processing_class is None: - processing_class = AutoProcessor.from_pretrained( - get_config_model_id(model.config), truncation_side="left", padding_side="left" - ) - - if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer - elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class - else: - raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id - self.max_prompt_length = args.max_prompt_length - self.max_completion_length = args.max_completion_length - self.num_generations = args.num_generations - self.num_iterations = args.num_iterations - self.temperature = args.temperature - self.loss_type = args.loss_type - self.shuffle_dataset = args.shuffle_dataset - self.generate_from_teacher = args.generate_from_teacher - self.num_loss_tokens_to_skip = args.num_loss_tokens_to_skip - self.chat_template_kwargs = args.chat_template_kwargs or {} - self._step = 0 - self._buffered_inputs = None - self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} - self.prompt_tokenizer = PromptTokenizer(self) - self.teacher_context_builder = DemonstrationTeacherContextBuilder(self) - - generation_kwargs = { - "max_new_tokens": self.max_completion_length, - "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, - "temperature": args.temperature, - "top_p": args.top_p, - "top_k": args.top_k, - "min_p": args.min_p, - "repetition_penalty": args.repetition_penalty, - "cache_implementation": args.cache_implementation, - } - if args.generation_kwargs is not None: - generation_kwargs.update(args.generation_kwargs) - self.generation_config = GenerationConfig(**generation_kwargs, disable_compile=True) - - if hasattr(model, "warnings_issued"): - model.warnings_issued["estimate_tokens"] = True super().__init__( model=model, args=args, - data_collator=identity, train_dataset=train_dataset, eval_dataset=eval_dataset, 
processing_class=processing_class, callbacks=callbacks, optimizers=optimizers, - compute_loss_func="non-None value to disable scaling", + peft_config=peft_config, ) - if args.disable_dropout: - disable_dropout_in_model(self.model) - - self.model.add_model_tags(self._tag_names) - - # In self-distillation the teacher is always derived from the student: - # - PEFT: base model with adapter disabled (or EMA teacher adapter when teacher_regularization="ema") - # - Non-PEFT: same model (or deep-copied EMA model when teacher_regularization="ema") - self.teacher_model = None - - peft_teacher_mode = args.peft_teacher_mode - if is_peft_available() and is_peft_model(self.model): - if peft_teacher_mode == "auto": - peft_teacher_mode = "teacher_adapter" if args.teacher_regularization == "ema" else "disable_adapter" - else: - if peft_teacher_mode in {"disable_adapter", "teacher_adapter"}: - raise ValueError(f"PEFT teacher mode `{peft_teacher_mode}` requires a PEFT model.") - peft_teacher_mode = "inherit_adapter" - - if args.teacher_regularization == "ema": - if peft_teacher_mode == "teacher_adapter": - self.add_callback( - PEFTAdapterEMACallback( - model=self.model, - teacher_adapter_name=args.teacher_adapter_name, - update_rate=args.teacher_update_rate, - sync_steps=args.teacher_sync_steps, - accelerator=self.accelerator, - ) - ) - else: - student_model = self.accelerator.unwrap_model(self.model) - self.teacher_model = copy.deepcopy(student_model) - self.teacher_model.requires_grad_(False) - self.teacher_model.eval() - if self.is_deepspeed_enabled: - self.teacher_model = prepare_deepspeed(self.teacher_model, self.accelerator) - elif self.is_fsdp_enabled: - self.teacher_model = prepare_fsdp(self.teacher_model, self.accelerator) - else: - self.teacher_model = self.accelerator.prepare_model(self.teacher_model, evaluation_mode=True) - self.add_callback( - SyncTeacherModelCallback(teacher_model=self.teacher_model, accelerator=self.accelerator) - ) - - self.model_accepts_loss_kwargs = False - - def get_train_dataloader(self): - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - - train_dataset = self.train_dataset - data_collator = self.data_collator - if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): - train_dataset = self._remove_unused_columns(train_dataset, description="training") - else: - data_collator = self._get_collator_with_removed_columns(data_collator, description="training") - - dataloader_params = { - "batch_size": self._train_batch_size * self.args.steps_per_generation, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - } - if not isinstance(train_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_train_sampler() - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = partial( - seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index - ) - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) - - def _get_train_sampler(self, dataset=None) -> Sampler: - if dataset is None: - dataset = self.train_dataset - return RepeatSampler( - data_source=dataset, - mini_repeat_count=self.num_generations, - batch_size=self.args.generation_batch_size // self.num_generations, - 
repeat_count=self.num_iterations * self.args.steps_per_generation, - shuffle=self.shuffle_dataset, - seed=self.args.seed, - ) - - def _get_eval_sampler(self, eval_dataset) -> Sampler: - return RepeatSampler( - data_source=eval_dataset, - mini_repeat_count=self.num_generations, - seed=self.args.seed, - ) - - def training_step(self, model, inputs, num_items_in_batch): - output = super().training_step(model, inputs, num_items_in_batch) - self._step += 1 - return output - - def _prepare_inputs(self, generation_batch): - mode = "train" if self.model.training else "eval" - if mode == "train": - generate_every = self.args.steps_per_generation * self.num_iterations - if self._step % generate_every == 0 or self._buffered_inputs is None: - generation_batch = self._build_buffered_batch(generation_batch) - self._buffered_inputs = split_tensor_dict(generation_batch, self.args.steps_per_generation) - self._dispatch_self_distillation_callback( - "on_generation_batch_built", - generate_every=generate_every, - steps_per_generation=self.args.steps_per_generation, - ) - return self._buffered_inputs[self._step % self.args.steps_per_generation] - return self._build_buffered_batch(generation_batch) - - def _generate_completion_ids(self, prompts: list[Any]) -> tuple[torch.Tensor, torch.Tensor]: - generate_inputs = self.processing_class( - text=self.prompt_tokenizer.apply_prompt_template(prompts), - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - # This generation helper builds tokenized model inputs directly, so use the base Trainer tensor preparation - # instead of re-entering the buffered outer training hook. - generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) - - with ( - unwrap_model_for_generation( - self.model_wrapped, - self.accelerator, - gather_deepspeed3_params=self.args.ds3_gather_for_generation, - ) as unwrapped_model, - torch.no_grad(), - ): - prompt_completion_ids = unwrapped_model.generate( - **generate_inputs, generation_config=self.generation_config - ) - - prompt_length = generate_inputs["input_ids"].size(1) - completion_ids = prompt_completion_ids[:, prompt_length:] - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) - completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).long() - - completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=True)] - completion_ids = [torch.tensor(ids, device=self.accelerator.device) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] - return ( - pad(completion_ids, padding_value=self.pad_token_id, padding_side="right"), - pad(completion_mask, padding_value=0, padding_side="right"), - ) - - def _build_buffered_batch(self, inputs: list[dict[str, Any]]) -> dict[str, torch.Tensor | Any]: - prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) - generation_prompts = self.teacher_context_builder.select_generation_prompts(prompts, privileged_contexts) - generation_prompt_text = self.prompt_tokenizer.apply_prompt_template(generation_prompts) - self._dispatch_self_distillation_callback( - "on_generation_prompts_selected", - 
generation_prompts=generation_prompts, - generation_prompt_text=generation_prompt_text, - ) - completion_ids, completion_mask = self._generate_completion_ids(generation_prompts) - - teacher_batch = self.teacher_context_builder.build( - prompts, privileged_contexts, completion_ids, completion_mask - ) - - prompt_completion_ids = torch.cat([teacher_batch["prompt_ids"], completion_ids], dim=1) - attention_mask = torch.cat([teacher_batch["prompt_mask"], completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) - - with torch.no_grad(): - generate_every = self.args.steps_per_generation * self.num_iterations - if not self.generate_from_teacher and self.args.gradient_accumulation_steps % generate_every != 0: - old_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ) - else: - old_per_token_logps = None - - self._dispatch_self_distillation_callback( - "on_self_distillation_batch_prepared", - old_per_token_logps=old_per_token_logps, - prompt_ids=teacher_batch["prompt_ids"], - completion_ids=completion_ids, - ) - output = { - "prompt_ids": teacher_batch["prompt_ids"], - "prompt_mask": teacher_batch["prompt_mask"], - "completion_ids": completion_ids, - "completion_mask": completion_mask, - "teacher_input_ids": teacher_batch["teacher_input_ids"], - "teacher_attention_mask": teacher_batch["teacher_attention_mask"], - } - if old_per_token_logps is not None: - output["old_per_token_logps"] = old_per_token_logps - return output + self.num_loss_tokens_to_skip = args.num_loss_tokens_to_skip + self.teacher_context_builder = DemonstrationTeacherContextBuilder(self) def augment_training_batch( self, @@ -488,7 +179,6 @@ def augment_training_batch( rollout_batch.completion_mask, ) - old_per_token_logps = None if self.generate_from_teacher else rollout_batch.old_per_token_logps return SelfDistillationBatch( prompt_ids=teacher_batch["prompt_ids"], prompt_mask=teacher_batch["prompt_mask"], @@ -496,7 +186,7 @@ def augment_training_batch( completion_mask=rollout_batch.completion_mask, teacher_input_ids=teacher_batch["teacher_input_ids"], teacher_attention_mask=teacher_batch["teacher_attention_mask"], - old_per_token_logps=old_per_token_logps, + old_per_token_logps=rollout_batch.old_per_token_logps, ) def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): @@ -513,20 +203,3 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N loss = self._compute_self_distillation_loss(model, inputs) accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 return loss / accumulation_scale - - def _get_teacher_context_for_self_distillation(self, model): - if is_peft_available() and isinstance(self.model, PeftModel): - model = self.accelerator.unwrap_model(self.model) - peft_teacher_mode = self.args.peft_teacher_mode - if peft_teacher_mode == "auto": - peft_teacher_mode = ( - "teacher_adapter" if self.args.teacher_regularization == "ema" else "disable_adapter" - ) - if peft_teacher_mode == "inherit_adapter": - return super()._get_teacher_context_for_self_distillation(model) - if peft_teacher_mode == "teacher_adapter" and self.args.teacher_adapter_name in model.peft_config: - return use_adapter(model, adapter_name=self.args.teacher_adapter_name) - if peft_teacher_mode == "disable_adapter": - return use_adapter(model, adapter_name=None) - raise ValueError(f"Unsupported PEFT teacher mode: {peft_teacher_mode}") - return 
super()._get_teacher_context_for_self_distillation(model) diff --git a/trl/experimental/sdft/sdft_trainer_transition.py b/trl/experimental/sdft/sdft_trainer_transition.py deleted file mode 100644 index cafce5bf221..00000000000 --- a/trl/experimental/sdft/sdft_trainer_transition.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import textwrap -from typing import Any - -import torch -from accelerate.logging import get_logger -from datasets import Dataset, IterableDataset -from torch import nn -from transformers import ( - PreTrainedModel, - PreTrainedTokenizerBase, - ProcessorMixin, - TrainerCallback, -) -from transformers.utils import is_peft_available - -from ..self_distillation.teacher_context import PromptTokenizer, extract_last_user_text -from ..self_distillation.unified_base_self_distillation_trainer import ( - SelfDistillationBatch, - SelfDistillationRolloutBatch, - UnifiedBaseSelfDistillationTrainer, -) -from .sdft_config import SDFTConfig - - -if is_peft_available(): - from peft import PeftConfig - - -logger = get_logger(__name__) - - -class DemonstrationTeacherContextBuilder: - """Builds student and teacher contexts from prompts plus privileged context, as in SDFT.""" - - def __init__(self, trainer): - self.trainer = trainer - self.prompt_tokenizer = PromptTokenizer(trainer) - - def _stringify_privileged_context(self, privileged_context: Any) -> str: - if privileged_context is None: - raise ValueError( - "`privileged_context` must not be None for self-distillation teacher prompt construction." 
- ) - if isinstance(privileged_context, str): - return privileged_context - if isinstance(privileged_context, list) and privileged_context and isinstance(privileged_context[0], dict): - chunks = [] - for message in privileged_context: - content = message.get("content", "") - if isinstance(content, list): - text = " ".join(part.get("text", "") for part in content if part.get("type") == "text") - else: - text = str(content) - if text: - chunks.append(text) - return "\n".join(chunks) - return str(privileged_context) - - def _compose_teacher_prompt(self, prompt: Any, privileged_context: Any) -> Any: - privileged_text = self._stringify_privileged_context(privileged_context) - if isinstance(prompt, list): - system_messages = prompt[:-1] - prompt_text = extract_last_user_text(prompt) - teacher_text = self.trainer.args.teacher_prompt_template.format( - prompt=prompt_text, - privileged_context=privileged_text, - ) - return system_messages + [{"role": "user", "content": teacher_text}] - return self.trainer.args.teacher_prompt_template.format(prompt=prompt, privileged_context=privileged_text) - - def select_generation_prompts(self, prompts: list[Any], privileged_contexts: list[Any]) -> list[Any]: - if not self.trainer.generate_from_teacher: - return prompts - return [ - self._compose_teacher_prompt(prompt, privileged_context) - for prompt, privileged_context in zip(prompts, privileged_contexts, strict=True) - ] - - def build( - self, - prompts: list[Any], - privileged_contexts: list[Any], - completion_ids: torch.Tensor, - completion_mask: torch.Tensor, - ) -> dict[str, torch.Tensor]: - student_batch = self.prompt_tokenizer.tokenize_prompts(prompts) - teacher_prompts = [ - self._compose_teacher_prompt(prompt, privileged_context) - for prompt, privileged_context in zip(prompts, privileged_contexts, strict=True) - ] - teacher_batch = self.prompt_tokenizer.tokenize_prompts(teacher_prompts) - teacher_input_ids = torch.cat([teacher_batch.prompt_ids, completion_ids], dim=1) - teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, completion_mask], dim=1) - return { - "prompt_ids": student_batch.prompt_ids, - "prompt_mask": student_batch.prompt_mask, - "teacher_input_ids": teacher_input_ids, - "teacher_attention_mask": teacher_attention_mask, - } - - -class SDFTTrainer(UnifiedBaseSelfDistillationTrainer): - """Trainer for SDFT-style on-policy self-distillation with explicit teacher prompts.""" - - _tag_names = ["trl", "sdft"] - _name = "SDFT" - config_cls = SDFTConfig - # docstyle-ignore - _paper = { - "title": "Self-Training with On-Policy Self-Distillation for Language Model Alignment", - "id": "2601.19897", - "citation": textwrap.dedent("""\ - @article{hubotter2026selftraining, - title = {{Self-Training with On-Policy Self-Distillation for Language Model Alignment}}, - author = {Jonas H\\"ubotter and Frederike L\\"ubeck and Lejs Behric and Anton Baumann and Marco Bagatella and Daniel Marta and Ido Hakimi and Idan Shenfeld and Thomas Kleine Buening and Carlos Guestrin and Andreas Krause}, - year = 2026, - eprint = {arXiv:2601.19897} - }"""), - } - - def __init__( - self, - model: str | PreTrainedModel | nn.Module, - args: SDFTConfig | None = None, - train_dataset: Dataset | IterableDataset | None = None, - eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, - processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, - callbacks: list[TrainerCallback] | None = None, - optimizers: tuple[torch.optim.Optimizer | None, 
torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
-        peft_config: PeftConfig | None = None,
-    ):
-        if isinstance(train_dataset, IterableDataset):
-            raise NotImplementedError("Iterable datasets are not yet supported in SDFTTrainer.")
-        if isinstance(eval_dataset, IterableDataset) or (
-            isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values())
-        ):
-            raise NotImplementedError("Iterable eval datasets are not yet supported in SDFTTrainer.")
-
-        super().__init__(
-            model=model,
-            args=args,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-            processing_class=processing_class,
-            callbacks=callbacks,
-            optimizers=optimizers,
-            peft_config=peft_config,
-        )
-
-        self.num_loss_tokens_to_skip = args.num_loss_tokens_to_skip
-        self.teacher_context_builder = DemonstrationTeacherContextBuilder(self)
-
-    def augment_training_batch(
-        self,
-        inputs: list[dict[str, Any]],
-        rollout_batch: SelfDistillationRolloutBatch,
-    ) -> SelfDistillationBatch:
-        prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs)
-        teacher_batch = self.teacher_context_builder.build(
-            prompts,
-            privileged_contexts,
-            rollout_batch.completion_ids,
-            rollout_batch.completion_mask,
-        )
-
-        return SelfDistillationBatch(
-            prompt_ids=teacher_batch["prompt_ids"],
-            prompt_mask=teacher_batch["prompt_mask"],
-            completion_ids=rollout_batch.completion_ids,
-            completion_mask=rollout_batch.completion_mask,
-            teacher_input_ids=teacher_batch["teacher_input_ids"],
-            teacher_attention_mask=teacher_batch["teacher_attention_mask"],
-            old_per_token_logps=rollout_batch.old_per_token_logps,
-        )
-
-    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
-        if return_outputs:
-            raise ValueError("The SDFTTrainer does not support returning outputs")
-
-        if self.num_loss_tokens_to_skip > 0:
-            inputs = dict(inputs)
-            completion_mask = inputs["completion_mask"].clone()
-            token_positions = torch.arange(completion_mask.size(1), device=completion_mask.device).unsqueeze(0)
-            completion_mask = completion_mask * (token_positions >= self.num_loss_tokens_to_skip).long()
-            inputs["completion_mask"] = completion_mask
-
-        loss = self._compute_self_distillation_loss(model, inputs)
-        accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0
-        return loss / accumulation_scale
diff --git a/trl/experimental/sdpo/sdpo.py b/trl/experimental/sdpo/sdpo.py
index 7a65a00cc5d..5ab93fbaf5b 100644
--- a/trl/experimental/sdpo/sdpo.py
+++ b/trl/experimental/sdpo/sdpo.py
@@ -78,7 +78,8 @@
     get_quantization_config,
 )
 from trl.data_utils import maybe_apply_chat_template
-from trl.experimental.sdpo import SDPOConfig, SDPOTrainer
+from trl.experimental.sdpo import SDPOConfig
+from trl.experimental.sdpo.sdpo_trainer import SDPOTrainer
 
 
 SYSTEM_PROMPT = (
diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py
index 5d24484e255..d554391f16e 100644
--- a/trl/experimental/sdpo/sdpo_trainer.py
+++ b/trl/experimental/sdpo/sdpo_trainer.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy import re import textwrap from typing import Any @@ -21,30 +20,30 @@ from accelerate.utils import gather_object from datasets import Dataset, IterableDataset from torch import nn -from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback - -from ...trainer.callbacks import SyncRefModelCallback -from ...trainer.utils import pad -from ..self_distillation.base_self_distillation_trainer import BaseSelfDistillationTrainer +from transformers import ( + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.utils import logging + +from ...data_utils import apply_chat_template, is_conversational +from ...models import prepare_deepspeed, prepare_fsdp +from ...trainer.base_trainer import _BaseTrainer +from ...trainer.utils import get_config_model_id, pad +from ..self_distillation.base_self_distillation_trainer import ( + BaseSelfDistillationTrainer, + SelfDistillationBatch, + SelfDistillationRolloutBatch, +) from ..self_distillation.teacher_context import TokenizedPromptBatch, extract_last_user_text from .sdpo_config import SDPOConfig -class EMATeacherSyncCallback(SyncRefModelCallback): - """Synchronize an EMA teacher model with the student model on each step.""" - - def __init__(self, teacher_model, update_rate: float, sync_steps: int, accelerator=None): - super().__init__(ref_model=teacher_model, accelerator=accelerator) - self.update_rate = update_rate - self.sync_steps = sync_steps - - def on_step_end(self, args, state, control, **kwargs): - if state.global_step % self.sync_steps != 0: - return - model = kwargs["model"] - if self.accelerator is not None: - model = self.accelerator.unwrap_model(model) - self.sync_target_model(model, self.ref_model, self.update_rate) +logger = logging.get_logger(__name__) class SuccessfulRolloutTeacherContextBuilder: @@ -280,45 +279,197 @@ def __init__( raise ValueError("`reward_funcs` is required for SDPOTrainer because SDPO must score rollouts.") super().__init__( model=model, - reward_funcs=reward_funcs, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, - reward_processing_classes=reward_processing_classes, callbacks=callbacks, optimizers=optimizers, peft_config=peft_config, ) - self.teacher_context_builder = SuccessfulRolloutTeacherContextBuilder(self) - if self.args.teacher_regularization == "ema": - # `self.model` may already be accelerator-wrapped after the shared base constructor. Build the EMA - # teacher from the unwrapped student model first, then prepare it as an auxiliary eval-only module. 
- student_model = self.accelerator.unwrap_model(self.model) - self.teacher_model = copy.deepcopy(student_model) - self.teacher_model.requires_grad_(False) - self.teacher_model.eval() - self.teacher_model = self._prepare_auxiliary_model_for_eval(self.teacher_model) - self.add_callback( - EMATeacherSyncCallback( - teacher_model=self.teacher_model, - update_rate=self.args.teacher_update_rate, - sync_steps=self.args.teacher_sync_steps, - accelerator=self.accelerator, + self.importance_sampling_level = args.importance_sampling_level + self.scale_rewards = args.scale_rewards + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high + self.beta = args.beta + + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + self.reward_func_names = [] + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + reward_model_init_kwargs = args.model_init_kwargs or {} + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + reward_model_init_kwargs["device_map"] = None + reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( + reward_func, + num_labels=1, + **reward_model_init_kwargs, ) - ) + if isinstance(reward_funcs[i], nn.Module): + self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) + else: + self.reward_func_names.append(reward_funcs[i].__name__) + self.reward_funcs = reward_funcs + + if args.reward_weights is not None: + if len(args.reward_weights) != len(self.reward_funcs): + raise ValueError("Number of reward weights must match number of reward functions") + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(self.reward_funcs), dtype=torch.float32) + + if reward_processing_classes is None: + reward_processing_classes = [None] * len(self.reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + if len(reward_processing_classes) != len(self.reward_funcs): + raise ValueError("Number of reward processing classes must match number of reward functions") + + for i, (reward_processing_class, reward_func) in enumerate( + zip(reward_processing_classes, self.reward_funcs, strict=True) + ): + if isinstance(reward_func, PreTrainedModel): + if reward_processing_class is None: + reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config)) + if reward_processing_class.pad_token_id is None: + reward_processing_class.pad_token = reward_processing_class.eos_token + reward_func.config.pad_token_id = reward_processing_class.pad_token_id + reward_processing_classes[i] = reward_processing_class + self.reward_processing_classes = reward_processing_classes + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, nn.Module): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + elif self.is_fsdp_enabled: + self.reward_funcs[i] = prepare_fsdp(reward_func, self.accelerator) + else: + self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) + + self.teacher_context_builder = SuccessfulRolloutTeacherContextBuilder(self) def _allow_topk_without_full_logit_distillation(self) -> bool: return False - def _generate_and_score_completions( - self, inputs: list[dict[str, torch.Tensor | Any]] - ) -> dict[str, torch.Tensor | Any]: + def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): + 
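# Builds a (num_prompts, num_reward_funcs) matrix of per-function rewards. nn.Module reward
+        # functions score the chat-templated prompt + completion text, while plain callables receive
+        # prompts/completions/completion_ids plus any extra dataset columns and the trainer state as
+        # kwargs; a callable may return None, which is stored as NaN so the reward-weighted nansum in
+        # `augment_training_batch` ignores it. The matrix is gathered across processes, so every rank
+        # sees the rewards for the full generation batch.
+        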
device = self.accelerator.device + if len(self.reward_funcs) == 0: + return torch.zeros((len(prompts), 0), device=device) + + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + reward_kwargs["trainer_state"] = self.state + + for i, (reward_func, reward_processing_class) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes, strict=True) + ): + if isinstance(reward_func, nn.Module): + if is_conversational(inputs[0]): + messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)] + texts = [ + apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"] + for x in messages + ] + else: + texts = [p + c for p, c in zip(prompts, completions, strict=True)] + reward_inputs = reward_processing_class( + text=texts, + return_tensors="pt", + padding=True, + padding_side="right", + add_special_tokens=False, + ) + reward_inputs = _BaseTrainer._prepare_inputs(self, reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] + else: + output_reward_func = reward_func( + prompts=prompts, + completions=completions, + completion_ids=completion_ids_list, + **reward_kwargs, + ) + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + return self.accelerator.gather(rewards_per_func) + + def augment_training_batch( + self, + inputs: list[dict[str, Any]], + rollout_batch: SelfDistillationRolloutBatch, + ) -> SelfDistillationBatch: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) + raw_completion_lengths = rollout_batch.metadata["raw_completion_lengths"].detach().cpu().tolist() + completion_ids_list = [ + ids[:length].tolist() + for ids, length in zip(rollout_batch.completion_ids.detach().cpu(), raw_completion_lengths, strict=True) + ] + if is_conversational({"prompt": prompts[0]}): + completions_text = self.processing_class.batch_decode( + rollout_batch.completion_ids, skip_special_tokens=True + ) + completions = [[{"role": "assistant", "content": content}] for content in completions_text] + else: + completions = self.processing_class.batch_decode(rollout_batch.completion_ids, skip_special_tokens=True) - output = super()._generate_and_score_completions(inputs) - output.update( - self.teacher_context_builder.build(output, prompts, output["rewards"], feedbacks=privileged_contexts) + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + if rewards_per_func.numel() == 0: + rewards = torch.zeros(self.accelerator.num_processes * len(prompts), device=device) + else: + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0) + if self.scale_rewards == "batch": + std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) + group_std_rewards = rewards.view(-1, num_generations).std(dim=1) + elif self.scale_rewards == "none": + 
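# With the std pinned to one, the advantages computed below are only mean-centered within
+            # each generation group and are effectively left unscaled by the reward spread.
+            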
std_rewards = torch.ones_like(rewards) + group_std_rewards = torch.ones(rewards.numel() // num_generations, device=device, dtype=rewards.dtype) + else: + group_std_rewards = rewards.view(-1, num_generations).std(dim=1) + std_rewards = group_std_rewards.repeat_interleave(num_generations, dim=0) + advantages = (rewards - mean_grouped_rewards) / (std_rewards + 1e-4) + self._record_reward_diagnostics(mode, rewards, rewards_per_func, group_std_rewards) + + local_batch_size = rollout_batch.completion_ids.size(0) + process_start = self.accelerator.process_index * local_batch_size + process_slice = slice(process_start, process_start + local_batch_size) + local_rewards = rewards[process_slice] + local_advantages = advantages[process_slice] + + agg_completion_lengths = self.accelerator.gather( + torch.tensor([len(ids) for ids in completion_ids_list], device=device) + ) + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] + if len(term_completion_lengths) == 0: + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + rollout_dict = rollout_batch.to_dict() + rollout_dict["rewards"] = local_rewards + rollout_dict["advantages"] = local_advantages + rollout_dict["num_items_in_batch"] = rollout_batch.completion_mask.sum().detach() + teacher_context = self.teacher_context_builder.build( + rollout_dict, + prompts, + rollout_dict["rewards"], + feedbacks=privileged_contexts, ) mode = "train" if self.model.training else "eval" @@ -328,13 +479,26 @@ def _generate_and_score_completions( self._dispatch_self_distillation_callback( "on_teacher_context_built", - teacher_input_ids=output["teacher_input_ids"], - teacher_attention_mask=output["teacher_attention_mask"], - completion_mask=output["completion_mask"], - self_distillation_mask=output["self_distillation_mask"], + teacher_input_ids=teacher_context["teacher_input_ids"], + teacher_attention_mask=teacher_context["teacher_attention_mask"], + completion_mask=rollout_batch.completion_mask, + self_distillation_mask=teacher_context["self_distillation_mask"], ) - return output + return SelfDistillationBatch( + prompt_ids=rollout_batch.prompt_ids, + prompt_mask=rollout_batch.prompt_mask, + completion_ids=rollout_batch.completion_ids, + completion_mask=rollout_batch.completion_mask, + teacher_input_ids=teacher_context["teacher_input_ids"], + teacher_attention_mask=teacher_context["teacher_attention_mask"], + old_per_token_logps=rollout_batch.old_per_token_logps, + self_distillation_mask=teacher_context["self_distillation_mask"], + metadata={ + "rewards": local_rewards, + "advantages": local_advantages, + }, 
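+            # Per-sample rewards and advantages ride along in `metadata`; the unified base is
+            # expected to surface them to `compute_loss`, which reads `inputs["advantages"]`
+            # when computing the clipped policy loss.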
+ ) def _warn_on_inactive_self_distillation(self, mode: str) -> None: metrics = self.teacher_context_builder.last_metrics @@ -369,8 +533,105 @@ def _warn_on_inactive_self_distillation(self, mode: str) -> None: else: self._diagnostic_counters[mode]["no_successful_rollouts"] = 0 + def _record_reward_diagnostics( + self, + mode: str, + rewards: torch.Tensor, + rewards_per_func: torch.Tensor, + group_std_rewards: torch.Tensor, + ) -> None: + tolerance = self.args.diagnostics_flat_tolerance + + reward_mean = rewards.mean() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + reward_std = rewards.std() if rewards.numel() > 1 else torch.tensor(0.0, device=self.accelerator.device) + reward_min = rewards.min() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + reward_max = rewards.max() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + flat_group_fraction = ( + (group_std_rewards <= tolerance).float().mean() + if group_std_rewards.numel() > 0 + else torch.tensor(1.0, device=self.accelerator.device) + ) + + self._metrics[mode]["self_distillation/reward_mean"].append(self.accelerator.gather(reward_mean).mean().item()) + self._metrics[mode]["self_distillation/reward_std"].append(self.accelerator.gather(reward_std).mean().item()) + self._metrics[mode]["self_distillation/reward_min"].append(self.accelerator.gather(reward_min).min().item()) + self._metrics[mode]["self_distillation/reward_max"].append(self.accelerator.gather(reward_max).max().item()) + self._metrics[mode]["self_distillation/group_reward_std_mean"].append( + self.accelerator.gather(group_std_rewards.mean() if group_std_rewards.numel() > 0 else reward_std) + .mean() + .item() + ) + self._metrics[mode]["self_distillation/flat_group_fraction"].append( + self.accelerator.gather(flat_group_fraction).mean().item() + ) + + if rewards_per_func.numel() > 0: + reward_func_means = rewards_per_func.nanmean(dim=0) + gathered_means = self.accelerator.gather(reward_func_means).view(-1, reward_func_means.numel()).mean(dim=0) + for reward_name, reward_func_mean in zip(self.reward_func_names, gathered_means.tolist(), strict=True): + self._metrics[mode][f"self_distillation/rewards/{reward_name}"].append(reward_func_mean) + + reward_is_flat = reward_std.item() <= tolerance + grouped_rewards_are_flat = flat_group_fraction.item() >= 1.0 - tolerance + if reward_is_flat and grouped_rewards_are_flat: + self._warn_on_degenerate_diagnostics( + mode=mode, + counter_key="flat_rewards", + message=( + "Observed flat SDPO rewards across all sampled generations. " + "Policy advantages will collapse to zero, and SDPO will not learn. " + "Check reward density, reward shaping, or `success_reward_threshold`." 
+ ), + ) + else: + self._diagnostic_counters[mode]["flat_rewards"] = 0 + + def _warn_on_degenerate_diagnostics(self, mode: str, counter_key: str, message: str) -> None: + interval = self.args.diagnostics_warning_interval + if interval == 0: + return + + self._diagnostic_counters[mode][counter_key] += 1 + count = self._diagnostic_counters[mode][counter_key] + if count == 1 or count % interval == 0: + logger.warning("%s Consecutive degenerate steps: %s.", message, count) + def _compute_policy_loss(self, model, inputs) -> torch.Tensor: - return super()._compute_loss(model, inputs) + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) + per_token_logps, _ = self._get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=False, + ) + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + advantages = inputs["advantages"] + if advantages.dim() == 1: + advantages = advantages.unsqueeze(1) + log_ratio = per_token_logps - old_per_token_logps + if self.importance_sampling_level == "sequence": + log_ratio = (log_ratio * completion_mask).sum(-1, keepdim=True) / completion_mask.sum( + -1, keepdim=True + ).clamp(min=1.0) + coef_1 = torch.exp(log_ratio) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + per_token_loss = -torch.min(coef_1 * advantages, coef_2 * advantages) + + loss = self._aggregate_self_distillation_loss(per_token_loss, completion_mask) + + mode = "train" if self.model.training else "eval" + self._metrics[mode]["self_distillation/policy_loss"].append( + self.accelerator.gather(loss.detach()).mean().item() + ) + + accumulation_scale = self.current_gradient_accumulation_steps if mode == "train" else 1.0 + return loss / accumulation_scale def _compute_weighted_self_distillation_loss(self, model, inputs) -> torch.Tensor | None: if self.args.distillation_weight <= 0.0: diff --git a/trl/experimental/sdpo/sdpo_trainer_transition.py b/trl/experimental/sdpo/sdpo_trainer_transition.py deleted file mode 100644 index 8fe68e7e7aa..00000000000 --- a/trl/experimental/sdpo/sdpo_trainer_transition.py +++ /dev/null @@ -1,656 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import re -import textwrap -from typing import Any - -import torch -from accelerate.utils import gather_object -from datasets import Dataset, IterableDataset -from torch import nn -from transformers import ( - AutoModelForSequenceClassification, - AutoTokenizer, - PreTrainedModel, - PreTrainedTokenizerBase, - ProcessorMixin, - TrainerCallback, -) -from transformers.utils import logging - -from ...data_utils import apply_chat_template, is_conversational -from ...models import prepare_deepspeed, prepare_fsdp -from ...trainer.base_trainer import _BaseTrainer -from ...trainer.utils import get_config_model_id, pad -from ..self_distillation.teacher_context import TokenizedPromptBatch, extract_last_user_text -from ..self_distillation.unified_base_self_distillation_trainer import ( - SelfDistillationBatch, - SelfDistillationRolloutBatch, - UnifiedBaseSelfDistillationTrainer, -) -from .sdpo_config import SDPOConfig - - -logger = logging.get_logger(__name__) - - -class SuccessfulRolloutTeacherContextBuilder: - """Builds SDPO teacher contexts from successful rollouts, following the official online implementation.""" - - def __init__(self, trainer): - self.trainer = trainer - self.last_metrics: dict[str, float] = {} - - def _build_reprompt_text(self, prompt_text: str, solution_text: str, feedback_text: str) -> str: - return self.trainer.args.reprompt_template.format( - prompt=prompt_text, - solution=solution_text, - feedback=feedback_text, - ) - - def _tokenize_teacher_messages( - self, teacher_messages_list: list[str | list[dict[str, Any]]] - ) -> TokenizedPromptBatch: - teacher_prompt_ids_list = [] - device = self.trainer.accelerator.device - chat_template_kwargs = getattr(self.trainer, "chat_template_kwargs", {}) - for msg in teacher_messages_list: - if isinstance(msg, list) and isinstance(msg[0], dict): - tokenized = self.trainer.processing_class.apply_chat_template( - msg, - tokenize=True, - add_generation_prompt=True, - return_tensors="pt", - **chat_template_kwargs, - ) - if isinstance(tokenized, torch.Tensor): - ids = tokenized.squeeze(0) - else: - ids = tokenized["input_ids"].squeeze(0) - else: - ids = self.trainer.processing_class.encode(msg, return_tensors="pt").squeeze(0) - - if ids.shape[0] > self.trainer.args.max_reprompt_len: - ids = ids[-self.trainer.args.max_reprompt_len :] - teacher_prompt_ids_list.append(ids) - - teacher_prompt_ids = [ids.to(device) for ids in teacher_prompt_ids_list] - teacher_prompt_mask = [torch.ones(len(ids), dtype=torch.long, device=device) for ids in teacher_prompt_ids] - return TokenizedPromptBatch( - prompt_ids=pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), - prompt_mask=pad(teacher_prompt_mask, padding_value=0, padding_side="left"), - ) - - def build( - self, - output: dict[str, torch.Tensor | Any], - prompts: list[Any], - rewards: torch.Tensor, - feedbacks: list[Any] | None = None, - ) -> dict[str, torch.Tensor]: - device = self.trainer.accelerator.device - mode = "train" if self.trainer.model.training else "eval" - num_generations = self.trainer.num_generations if mode == "train" else self.trainer.num_generations_eval - completion_ids = output["completion_ids"] - completion_mask = output["completion_mask"] - - num_local = len(prompts) - process_start = self.trainer.accelerator.process_index * num_local - process_slice = slice(process_start, process_start + num_local) - - # Rewards arrive already locally sliced (per-process) from the rollout mixin; re-gather them so - # the mining loop can find successful rollouts 
across all processes within each generation group. - all_rewards = self.trainer.accelerator.gather(rewards) - # Completion tensors are padded to the local max length per rank; align shapes before gathering. - # Use separate variables so the original completion_ids/completion_mask stay unpadded for the - # teacher concat (they must match the student's sequence length for logits_to_keep alignment). - padded_completion_ids = self.trainer.accelerator.pad_across_processes( - completion_ids, dim=1, pad_index=self.trainer.pad_token_id - ) - all_completion_ids = self.trainer.accelerator.gather(padded_completion_ids) - all_prompts = gather_object(prompts) - total_samples = all_rewards.shape[0] - all_feedbacks = gather_object(feedbacks) if feedbacks is not None else [None] * total_samples - - threshold = self.trainer.args.success_reward_threshold - dont_reprompt_self = self.trainer.args.dont_reprompt_on_self_success - feedback_only_without_solution = self.trainer.args.environment_feedback_only_without_solution - self_distillation_mask = torch.zeros(total_samples, device=device) - num_with_solution = 0 - num_with_feedback_available = 0 - num_with_feedback_used = 0 - success_group_count = 0 - successful_demo_indices: list[int | None] = [None] * total_samples - use_feedback_flags: list[bool] = [False] * total_samples - has_solution_flags: list[bool] = [False] * total_samples - - for i in range(total_samples): - group_start = (i // num_generations) * num_generations - group_end = group_start + num_generations - - successful = [] - if self.trainer.args.use_successful_as_teacher: - for j in range(group_start, group_end): - if dont_reprompt_self and j == i: - continue - if all_rewards[j].item() >= threshold: - successful.append(j) - - if i % num_generations == 0: - # Count groups with any successful rollout, ignoring self-exclusion which only - # affects per-sample teacher assignment, not whether the group has successes. 
-                group_has_success = any(all_rewards[j].item() >= threshold for j in range(group_start, group_end))
-                if group_has_success:
-                    success_group_count += 1
-
-            raw_feedback = all_feedbacks[i]
-            has_feedback = isinstance(raw_feedback, str) and raw_feedback.strip() != ""
-            if has_feedback:
-                num_with_feedback_available += 1
-
-            has_solution = len(successful) > 0
-            has_solution_flags[i] = has_solution
-            if has_solution:
-                successful_demo_indices[i] = successful[0]
-            use_feedback = (
-                self.trainer.args.include_environment_feedback
-                and has_feedback
-                and (not feedback_only_without_solution or not has_solution)
-            )
-            use_feedback_flags[i] = use_feedback
-            if use_feedback:
-                num_with_feedback_used += 1
-            if has_solution or use_feedback:
-                self_distillation_mask[i] = 1.0
-                if has_solution:
-                    num_with_solution += 1
-
-        local_teacher_messages = []
-        local_self_distillation_mask = self_distillation_mask[process_slice]
-        for global_idx in range(process_start, process_start + num_local):
-            original_prompt = all_prompts[global_idx]
-            raw_feedback = all_feedbacks[global_idx]
-            has_solution = has_solution_flags[global_idx]
-            use_feedback = use_feedback_flags[global_idx]
-
-            if not has_solution and not use_feedback:
-                local_teacher_messages.append(original_prompt)
-                continue
-
-            solution_text = ""
-            if has_solution:
-                demo_idx = successful_demo_indices[global_idx]
-                if demo_idx is None:
-                    raise RuntimeError("Expected a successful demonstration index for an active SDPO teacher prompt.")
-                demo_ids = all_completion_ids[demo_idx]
-                demo_ids = demo_ids[demo_ids != self.trainer.processing_class.pad_token_id]
-                demo_text = self.trainer.processing_class.decode(demo_ids, skip_special_tokens=True)
-
-                if self.trainer.args.remove_thinking_from_demonstration:
-                    demo_text = re.sub(r"<think>.*?</think>", "", demo_text, flags=re.DOTALL).strip()
-
-                solution_text = self.trainer.args.solution_template.format(successful_previous_attempt=demo_text)
-
-            feedback_text = ""
-            if use_feedback:
-                feedback_text = self.trainer.args.feedback_template.format(feedback_raw=raw_feedback)
-
-            if isinstance(original_prompt, list):
-                system_messages = original_prompt[:-1]
-                prompt_text = extract_last_user_text(original_prompt)
-                reprompt_text = self._build_reprompt_text(prompt_text, solution_text, feedback_text)
-                local_teacher_messages.append(system_messages + [{"role": "user", "content": reprompt_text}])
-            else:
-                local_teacher_messages.append(self._build_reprompt_text(original_prompt, solution_text, feedback_text))
-
-        teacher_batch = self._tokenize_teacher_messages(local_teacher_messages)
-        teacher_input_ids = torch.cat([teacher_batch.prompt_ids, completion_ids], dim=1)
-        teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, completion_mask], dim=1)
-
-        batch_size = total_samples if total_samples > 0 else 1
-        num_groups = max(1, total_samples // max(1, num_generations))
-        self.last_metrics = {
-            "self_distillation/success_group_fraction": success_group_count / num_groups,
-            "self_distillation/success_sample_fraction": num_with_solution / batch_size,
-            "self_distillation/feedback_available_fraction": num_with_feedback_available / batch_size,
-            "self_distillation/feedback_used_fraction": num_with_feedback_used / batch_size,
-            "self_distillation/reprompt_sample_fraction": self_distillation_mask.float().mean().item(),
-        }
-
-        return {
-            "teacher_input_ids": teacher_input_ids,
-            "teacher_attention_mask": teacher_attention_mask,
-            "self_distillation_mask": local_self_distillation_mask,
-        }
-
-
-class 
SDPOTrainer(UnifiedBaseSelfDistillationTrainer): - """ - Trainer for Self-Distillation Policy Optimization (SDPO). - - SDPO augments on-policy optimization with self-distillation from the model's own high-reward trajectories. It - converts tokenized feedback into a dense learning signal without any external teacher or explicit reward model. - SDPO treats the current model conditioned on feedback as a self-teacher and distills its feedback-informed - next-token predictions back into the policy. - """ - - config_cls = SDPOConfig - _tag_names = ["trl", "sdpo"] - _name = "SDPO" - # docstyle-ignore - _paper = { - "title": "Reinforcement Learning via Self-Distillation", - "id": "2601.20802", - "citation": textwrap.dedent("""\ - @article{hubotter2026sdpo, - title = {{Reinforcement Learning via Self-Distillation}}, - author = {Jonas H\\"ubotter and Frederike L\\"ubeck and Lejs Behric and Anton Baumann and Marco Bagatella and Daniel Marta and Ido Hakimi and Idan Shenfeld and Thomas Kleine Buening and Carlos Guestrin and Andreas Krause}, - year = 2026, - eprint = {arXiv:2601.20802} - }"""), - } - - def __init__( - self, - model: str | PreTrainedModel | nn.Module, - reward_funcs: Any | list[Any] | None = None, - args: SDPOConfig | None = None, - train_dataset: Dataset | IterableDataset | None = None, - eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, - processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, - reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, - callbacks: list[TrainerCallback] | None = None, - optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), - peft_config=None, - ): - if reward_funcs is None or (isinstance(reward_funcs, list) and len(reward_funcs) == 0): - raise ValueError("`reward_funcs` is required for SDPOTrainer because SDPO must score rollouts.") - super().__init__( - model=model, - args=args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - callbacks=callbacks, - optimizers=optimizers, - peft_config=peft_config, - ) - self.importance_sampling_level = args.importance_sampling_level - self.scale_rewards = args.scale_rewards - self.epsilon_low = args.epsilon - self.epsilon_high = args.epsilon_high - self.beta = args.beta - - if not isinstance(reward_funcs, list): - reward_funcs = [reward_funcs] - self.reward_func_names = [] - for i, reward_func in enumerate(reward_funcs): - if isinstance(reward_func, str): - reward_model_init_kwargs = args.model_init_kwargs or {} - if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: - reward_model_init_kwargs["device_map"] = None - reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( - reward_func, - num_labels=1, - **reward_model_init_kwargs, - ) - if isinstance(reward_funcs[i], nn.Module): - self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) - else: - self.reward_func_names.append(reward_funcs[i].__name__) - self.reward_funcs = reward_funcs - - if args.reward_weights is not None: - if len(args.reward_weights) != len(self.reward_funcs): - raise ValueError("Number of reward weights must match number of reward functions") - self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) - else: - self.reward_weights = torch.ones(len(self.reward_funcs), dtype=torch.float32) - - if reward_processing_classes is None: - reward_processing_classes = 
[None] * len(self.reward_funcs) - elif not isinstance(reward_processing_classes, list): - reward_processing_classes = [reward_processing_classes] - if len(reward_processing_classes) != len(self.reward_funcs): - raise ValueError("Number of reward processing classes must match number of reward functions") - - for i, (reward_processing_class, reward_func) in enumerate( - zip(reward_processing_classes, self.reward_funcs, strict=True) - ): - if isinstance(reward_func, PreTrainedModel): - if reward_processing_class is None: - reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config)) - if reward_processing_class.pad_token_id is None: - reward_processing_class.pad_token = reward_processing_class.eos_token - reward_func.config.pad_token_id = reward_processing_class.pad_token_id - reward_processing_classes[i] = reward_processing_class - self.reward_processing_classes = reward_processing_classes - - for i, reward_func in enumerate(self.reward_funcs): - if isinstance(reward_func, nn.Module): - if self.is_deepspeed_enabled: - self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) - elif self.is_fsdp_enabled: - self.reward_funcs[i] = prepare_fsdp(reward_func, self.accelerator) - else: - self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) - - self.teacher_context_builder = SuccessfulRolloutTeacherContextBuilder(self) - - def _allow_topk_without_full_logit_distillation(self) -> bool: - return False - - def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): - device = self.accelerator.device - if len(self.reward_funcs) == 0: - return torch.zeros((len(prompts), 0), device=device) - - rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) - keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] - reward_kwargs = {key: [example[key] for example in inputs] for key in keys} - reward_kwargs["trainer_state"] = self.state - - for i, (reward_func, reward_processing_class) in enumerate( - zip(self.reward_funcs, self.reward_processing_classes, strict=True) - ): - if isinstance(reward_func, nn.Module): - if is_conversational(inputs[0]): - messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)] - texts = [ - apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"] - for x in messages - ] - else: - texts = [p + c for p, c in zip(prompts, completions, strict=True)] - reward_inputs = reward_processing_class( - text=texts, - return_tensors="pt", - padding=True, - padding_side="right", - add_special_tokens=False, - ) - reward_inputs = _BaseTrainer._prepare_inputs(self, reward_inputs) - with torch.inference_mode(): - rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] - else: - output_reward_func = reward_func( - prompts=prompts, - completions=completions, - completion_ids=completion_ids_list, - **reward_kwargs, - ) - output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] - rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) - - return self.accelerator.gather(rewards_per_func) - - def augment_training_batch( - self, - inputs: list[dict[str, Any]], - rollout_batch: SelfDistillationRolloutBatch, - ) -> SelfDistillationBatch: - device = self.accelerator.device - mode = "train" if self.model.training else "eval" - prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) 
- raw_completion_lengths = rollout_batch.metadata["raw_completion_lengths"].detach().cpu().tolist() - completion_ids_list = [ - ids[:length].tolist() - for ids, length in zip(rollout_batch.completion_ids.detach().cpu(), raw_completion_lengths, strict=True) - ] - if is_conversational({"prompt": prompts[0]}): - completions_text = self.processing_class.batch_decode( - rollout_batch.completion_ids, skip_special_tokens=True - ) - completions = [[{"role": "assistant", "content": content}] for content in completions_text] - else: - completions = self.processing_class.batch_decode(rollout_batch.completion_ids, skip_special_tokens=True) - - rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) - if rewards_per_func.numel() == 0: - rewards = torch.zeros(self.accelerator.num_processes * len(prompts), device=device) - else: - rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) - - num_generations = self.num_generations if mode == "train" else self.num_generations_eval - mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0) - if self.scale_rewards == "batch": - std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) - group_std_rewards = rewards.view(-1, num_generations).std(dim=1) - elif self.scale_rewards == "none": - std_rewards = torch.ones_like(rewards) - group_std_rewards = torch.ones(rewards.numel() // num_generations, device=device, dtype=rewards.dtype) - else: - group_std_rewards = rewards.view(-1, num_generations).std(dim=1) - std_rewards = group_std_rewards.repeat_interleave(num_generations, dim=0) - advantages = (rewards - mean_grouped_rewards) / (std_rewards + 1e-4) - self._record_reward_diagnostics(mode, rewards, rewards_per_func, group_std_rewards) - - local_batch_size = rollout_batch.completion_ids.size(0) - process_start = self.accelerator.process_index * local_batch_size - process_slice = slice(process_start, process_start + local_batch_size) - local_rewards = rewards[process_slice] - local_advantages = advantages[process_slice] - - agg_completion_lengths = self.accelerator.gather( - torch.tensor([len(ids) for ids in completion_ids_list], device=device) - ) - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) - agg_is_truncated = self.accelerator.gather(is_truncated) - self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) - term_completion_lengths = agg_completion_lengths[~agg_is_truncated] - if len(term_completion_lengths) == 0: - term_completion_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - - rollout_dict = rollout_batch.to_dict() - rollout_dict["rewards"] = local_rewards - rollout_dict["advantages"] = local_advantages - 
rollout_dict["num_items_in_batch"] = rollout_batch.completion_mask.sum().detach() - teacher_context = self.teacher_context_builder.build( - rollout_dict, - prompts, - rollout_dict["rewards"], - feedbacks=privileged_contexts, - ) - - mode = "train" if self.model.training else "eval" - for key, value in self.teacher_context_builder.last_metrics.items(): - self._metrics[mode][key].append(value) - self._warn_on_inactive_self_distillation(mode) - - self._dispatch_self_distillation_callback( - "on_teacher_context_built", - teacher_input_ids=teacher_context["teacher_input_ids"], - teacher_attention_mask=teacher_context["teacher_attention_mask"], - completion_mask=rollout_batch.completion_mask, - self_distillation_mask=teacher_context["self_distillation_mask"], - ) - - return SelfDistillationBatch( - prompt_ids=rollout_batch.prompt_ids, - prompt_mask=rollout_batch.prompt_mask, - completion_ids=rollout_batch.completion_ids, - completion_mask=rollout_batch.completion_mask, - teacher_input_ids=teacher_context["teacher_input_ids"], - teacher_attention_mask=teacher_context["teacher_attention_mask"], - old_per_token_logps=rollout_batch.old_per_token_logps, - self_distillation_mask=teacher_context["self_distillation_mask"], - metadata={ - "rewards": local_rewards, - "advantages": local_advantages, - }, - ) - - def _warn_on_inactive_self_distillation(self, mode: str) -> None: - metrics = self.teacher_context_builder.last_metrics - tolerance = self.args.diagnostics_flat_tolerance - - reprompt_fraction = metrics.get("self_distillation/reprompt_sample_fraction", 0.0) - success_fraction = metrics.get("self_distillation/success_group_fraction", 0.0) - - if reprompt_fraction <= tolerance: - self._warn_on_degenerate_diagnostics( - mode=mode, - counter_key="inactive_self_distillation", - message=( - "SDPO self-distillation is inactive because no reprompted samples were constructed. " - "This usually means no rollout exceeded `success_reward_threshold` and no usable privileged " - "feedback was available." - ), - ) - else: - self._diagnostic_counters[mode]["inactive_self_distillation"] = 0 - - if success_fraction <= tolerance: - self._warn_on_degenerate_diagnostics( - mode=mode, - counter_key="no_successful_rollouts", - message=( - "SDPO did not find any successful rollouts in the current generation groups. " - "If this persists, reduce task difficulty, adjust reward shaping, or lower " - "`success_reward_threshold`." 
- ), - ) - else: - self._diagnostic_counters[mode]["no_successful_rollouts"] = 0 - - def _record_reward_diagnostics( - self, - mode: str, - rewards: torch.Tensor, - rewards_per_func: torch.Tensor, - group_std_rewards: torch.Tensor, - ) -> None: - tolerance = self.args.diagnostics_flat_tolerance - - reward_mean = rewards.mean() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) - reward_std = rewards.std() if rewards.numel() > 1 else torch.tensor(0.0, device=self.accelerator.device) - reward_min = rewards.min() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) - reward_max = rewards.max() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) - flat_group_fraction = ( - (group_std_rewards <= tolerance).float().mean() - if group_std_rewards.numel() > 0 - else torch.tensor(1.0, device=self.accelerator.device) - ) - - self._metrics[mode]["self_distillation/reward_mean"].append(self.accelerator.gather(reward_mean).mean().item()) - self._metrics[mode]["self_distillation/reward_std"].append(self.accelerator.gather(reward_std).mean().item()) - self._metrics[mode]["self_distillation/reward_min"].append(self.accelerator.gather(reward_min).min().item()) - self._metrics[mode]["self_distillation/reward_max"].append(self.accelerator.gather(reward_max).max().item()) - self._metrics[mode]["self_distillation/group_reward_std_mean"].append( - self.accelerator.gather(group_std_rewards.mean() if group_std_rewards.numel() > 0 else reward_std) - .mean() - .item() - ) - self._metrics[mode]["self_distillation/flat_group_fraction"].append( - self.accelerator.gather(flat_group_fraction).mean().item() - ) - - if rewards_per_func.numel() > 0: - reward_func_means = rewards_per_func.nanmean(dim=0) - gathered_means = self.accelerator.gather(reward_func_means).view(-1, reward_func_means.numel()).mean(dim=0) - for reward_name, reward_func_mean in zip(self.reward_func_names, gathered_means.tolist(), strict=True): - self._metrics[mode][f"self_distillation/rewards/{reward_name}"].append(reward_func_mean) - - reward_is_flat = reward_std.item() <= tolerance - grouped_rewards_are_flat = flat_group_fraction.item() >= 1.0 - tolerance - if reward_is_flat and grouped_rewards_are_flat: - self._warn_on_degenerate_diagnostics( - mode=mode, - counter_key="flat_rewards", - message=( - "Observed flat SDPO rewards across all sampled generations. " - "Policy advantages will collapse to zero, and SDPO will not learn. " - "Check reward density, reward shaping, or `success_reward_threshold`." 
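# --- Editor's aside: illustrative sketch, not part of the patch --------------
# The flat-reward detector above only fires when *both* the global reward std
# and every per-group std fall within `diagnostics_flat_tolerance`. A minimal
# standalone reproduction of that condition:
import torch

tolerance = 1e-6
rewards = torch.zeros(6)                    # every completion scored identically
group_std = rewards.view(-1, 3).std(dim=1)  # per-prompt std across generations
flat_group_fraction = (group_std <= tolerance).float().mean()
reward_is_flat = rewards.std().item() <= tolerance
grouped_rewards_are_flat = flat_group_fraction.item() >= 1.0 - tolerance
assert reward_is_flat and grouped_rewards_are_flat  # -> "flat rewards" warning path
# ------------------------------------------------------------------------------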
- ), - ) - else: - self._diagnostic_counters[mode]["flat_rewards"] = 0 - - def _warn_on_degenerate_diagnostics(self, mode: str, counter_key: str, message: str) -> None: - interval = self.args.diagnostics_warning_interval - if interval == 0: - return - - self._diagnostic_counters[mode][counter_key] += 1 - count = self._diagnostic_counters[mode][counter_key] - if count == 1 or count % interval == 0: - logger.warning("%s Consecutive degenerate steps: %s.", message, count) - - def _compute_policy_loss(self, model, inputs) -> torch.Tensor: - prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] - input_ids = torch.cat([prompt_ids, completion_ids], dim=1) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) - per_token_logps, _ = self._get_per_token_logps_and_entropies( - model, - input_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ) - old_per_token_logps = inputs.get("old_per_token_logps") - old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps - advantages = inputs["advantages"] - if advantages.dim() == 1: - advantages = advantages.unsqueeze(1) - log_ratio = per_token_logps - old_per_token_logps - if self.importance_sampling_level == "sequence": - log_ratio = (log_ratio * completion_mask).sum(-1, keepdim=True) / completion_mask.sum( - -1, keepdim=True - ).clamp(min=1.0) - coef_1 = torch.exp(log_ratio) - coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) - per_token_loss = -torch.min(coef_1 * advantages, coef_2 * advantages) - - loss = self._aggregate_self_distillation_loss(per_token_loss, completion_mask) - - mode = "train" if self.model.training else "eval" - self._metrics[mode]["self_distillation/policy_loss"].append( - self.accelerator.gather(loss.detach()).mean().item() - ) - - accumulation_scale = self.current_gradient_accumulation_steps if mode == "train" else 1.0 - return loss / accumulation_scale - - def _compute_weighted_self_distillation_loss(self, model, inputs) -> torch.Tensor | None: - if self.args.distillation_weight <= 0.0: - return None - - accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 - distillation_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale - return self.args.distillation_weight * distillation_loss - - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - if return_outputs: - raise ValueError("The SDPOTrainer does not support returning outputs") - - if self.args.sdpo_policy_loss_mode == "hybrid": - policy_loss = self._compute_policy_loss(model, inputs) - weighted_distillation_loss = self._compute_weighted_self_distillation_loss(model, inputs) - return policy_loss if weighted_distillation_loss is None else policy_loss + weighted_distillation_loss - - weighted_distillation_loss = self._compute_weighted_self_distillation_loss(model, inputs) - if weighted_distillation_loss is not None: - return weighted_distillation_loss - return self._compute_policy_loss(model, inputs) diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index 5b658b19a10..d8171cc88e1 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ 
-12,29 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Shared online self-distillation trainer scaffold. - -This base combines the generic Trainer setup for self-distillation with the online rollout utilities used by SDPO-like -methods. Offline methods such as SDFT stay on `_BaseTrainer` directly and only reuse the shared distillation mixin. -""" - from __future__ import annotations +import copy import inspect +from abc import ABC, abstractmethod from collections import defaultdict +from contextlib import nullcontext +from dataclasses import dataclass, field from functools import partial from typing import Any import datasets import torch from accelerate.logging import get_logger +from accelerate.utils import is_peft_model from datasets import Dataset, IterableDataset from torch import nn from torch.utils.data import DataLoader, Sampler from transformers import ( - AutoModelForSequenceClassification, AutoProcessor, - AutoTokenizer, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase, @@ -44,41 +41,104 @@ from transformers.trainer_utils import seed_worker from transformers.utils import is_datasets_available, is_peft_available -from ...models import prepare_deepspeed, prepare_fsdp +from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation from ...trainer.base_trainer import _BaseTrainer from ...trainer.utils import ( RepeatSampler, create_model_from_path, disable_dropout_in_model, + entropy_from_logits, get_config_model_id, identity, + pad, + selective_log_softmax, split_tensor_dict, + use_adapter, ) from ..utils import prepare_peft_model -from .online_rollout_mixin import OnlineRolloutMixin from .self_distillation_config import SelfDistillationConfig from .self_distillation_mixin import SelfDistillationMixin +from .teacher_context import PromptTokenizer +from .teacher_sync import PEFTAdapterEMACallback, SyncTeacherModelCallback if is_peft_available(): from peft import PeftConfig + from peft.peft_model import PeftModel logger = get_logger(__name__) -class BaseSelfDistillationTrainer(OnlineRolloutMixin, SelfDistillationMixin, _BaseTrainer): - """Shared scaffold for experimental self-distillation trainers without GRPO inheritance.""" +@dataclass +class SelfDistillationRolloutBatch: + """Common student rollout batch produced before algorithm-specific augmentation.""" + + prompt_ids: torch.Tensor + prompt_mask: torch.Tensor + completion_ids: torch.Tensor + completion_mask: torch.Tensor + old_per_token_logps: torch.Tensor | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, torch.Tensor | Any]: + output: dict[str, torch.Tensor | Any] = { + "prompt_ids": self.prompt_ids, + "prompt_mask": self.prompt_mask, + "completion_ids": self.completion_ids, + "completion_mask": self.completion_mask, + } + if self.old_per_token_logps is not None: + output["old_per_token_logps"] = self.old_per_token_logps + output.update(self.metadata) + return output + + +@dataclass +class SelfDistillationBatch: + """Final self-distillation batch contract consumed by `SelfDistillationMixin`.""" + + prompt_ids: torch.Tensor + prompt_mask: torch.Tensor + completion_ids: torch.Tensor + completion_mask: torch.Tensor + teacher_input_ids: torch.Tensor + teacher_attention_mask: torch.Tensor + old_per_token_logps: torch.Tensor | None = None + self_distillation_mask: torch.Tensor | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, 
torch.Tensor | Any]: + output: dict[str, torch.Tensor | Any] = { + "prompt_ids": self.prompt_ids, + "prompt_mask": self.prompt_mask, + "completion_ids": self.completion_ids, + "completion_mask": self.completion_mask, + "teacher_input_ids": self.teacher_input_ids, + "teacher_attention_mask": self.teacher_attention_mask, + } + if self.old_per_token_logps is not None: + output["old_per_token_logps"] = self.old_per_token_logps + if self.self_distillation_mask is not None: + output["self_distillation_mask"] = self.self_distillation_mask + output.update(self.metadata) + return output + + +class BaseSelfDistillationTrainer(SelfDistillationMixin, _BaseTrainer, ABC): + """Base that centralizes shared self-distillation trainer lifecycle.""" + + config_cls = SelfDistillationConfig + _tag_names = ["trl", "self-distillation"] + _name = "Self-Distillation" def __init__( self, model: str | PreTrainedModel | nn.Module, - reward_funcs: Any | list[Any] | None = None, args: SelfDistillationConfig | None = None, train_dataset: Dataset | IterableDataset | None = None, eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, - reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), peft_config: PeftConfig | None = None, @@ -104,6 +164,11 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) + if is_peft_available() and is_peft_model(model) and peft_config is not None: + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config`. Pass either a base " + "model with `peft_config`, or a pre-wrapped PEFT model." 
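# --- Editor's aside: illustrative sketch, not part of the patch --------------
# `SelfDistillationBatch.to_dict()` (defined at the top of this hunk) flattens
# optional fields and metadata into one mapping so that `split_tensor_dict`
# can shard it across `steps_per_generation` micro-batches. Toy usage with
# arbitrary shapes; the import path follows the module layout this diff
# introduces:
import torch

from trl.experimental.self_distillation.base_self_distillation_trainer import SelfDistillationBatch

batch = SelfDistillationBatch(
    prompt_ids=torch.zeros(2, 4, dtype=torch.long),
    prompt_mask=torch.ones(2, 4, dtype=torch.long),
    completion_ids=torch.zeros(2, 8, dtype=torch.long),
    completion_mask=torch.ones(2, 8, dtype=torch.long),
    teacher_input_ids=torch.zeros(2, 12, dtype=torch.long),
    teacher_attention_mask=torch.ones(2, 12, dtype=torch.long),
    metadata={"rewards": torch.rand(2)},
)
assert "old_per_token_logps" not in batch.to_dict()  # omitted while None
assert "rewards" in batch.to_dict()                  # metadata keys are inlined
# ------------------------------------------------------------------------------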
+ ) if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): model = prepare_peft_model(model, peft_config, args) @@ -124,7 +189,6 @@ def __init__( self.pad_token_id = tokenizer.pad_token_id self.eos_token_id = tokenizer.eos_token_id - self.temperature = args.temperature self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length self.num_generations = args.num_generations @@ -132,12 +196,8 @@ def __init__( self.num_iterations = args.num_iterations self.shuffle_dataset = args.shuffle_dataset self.loss_type = args.loss_type - self.importance_sampling_level = args.importance_sampling_level - self.scale_rewards = args.scale_rewards - self.epsilon_low = args.epsilon - self.epsilon_high = args.epsilon_high - self.beta = args.beta self.mask_truncated_completions = args.mask_truncated_completions + self.temperature = args.temperature self.chat_template_kwargs = args.chat_template_kwargs or {} self._step = 0 self._last_loaded_step = 0 @@ -147,6 +207,7 @@ def __init__( "train": defaultdict(int), "eval": defaultdict(int), } + self.prompt_tokenizer = PromptTokenizer(self) generation_kwargs = { "max_new_tokens": self.max_completion_length, @@ -211,71 +272,74 @@ def __init__( logprobs=None, generation_kwargs=args.generation_kwargs, ) - self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation - - if reward_funcs is None: - reward_funcs = [] - if not isinstance(reward_funcs, list): - reward_funcs = [reward_funcs] - self.reward_func_names = [] - for i, reward_func in enumerate(reward_funcs): - if isinstance(reward_func, str): - reward_model_init_kwargs = args.model_init_kwargs or {} - if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: - reward_model_init_kwargs["device_map"] = None - reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( - reward_func, - num_labels=1, - **reward_model_init_kwargs, + self._last_loaded_step = -1 + + if args.disable_dropout: + disable_dropout_in_model(self.model) + + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + self._setup_teacher_model() + self.model_accepts_loss_kwargs = False + + def _setup_teacher_model(self) -> None: + """Prepare teacher state according to the shared teacher policy.""" + + teacher_regularization = self.args.teacher_regularization + peft_teacher_mode = self._resolve_peft_teacher_mode() + self._validate_teacher_policy(teacher_regularization, peft_teacher_mode) + + if teacher_regularization == "none": + return + + if is_peft_available() and is_peft_model(self.model) and peft_teacher_mode == "teacher_adapter": + self.add_callback( + PEFTAdapterEMACallback( + model=self.model, + teacher_adapter_name=self.args.teacher_adapter_name, + update_rate=self.args.teacher_update_rate, + sync_steps=self.args.teacher_sync_steps, + accelerator=self.accelerator, ) - if isinstance(reward_funcs[i], nn.Module): - self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) - else: - self.reward_func_names.append(reward_funcs[i].__name__) - self.reward_funcs = reward_funcs - - if args.reward_weights is not None: - if len(args.reward_weights) != len(self.reward_funcs): - raise ValueError("Number of reward weights must match number of reward functions") - self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + ) + return + + student_model = self.accelerator.unwrap_model(self.model) + self.teacher_model = 
copy.deepcopy(student_model) + self.teacher_model.requires_grad_(False) + self.teacher_model.eval() + + if self.is_deepspeed_enabled: + self.teacher_model = prepare_deepspeed(self.teacher_model, self.accelerator) + elif self.is_fsdp_enabled: + self.teacher_model = prepare_fsdp(self.teacher_model, self.accelerator) else: - self.reward_weights = torch.ones(len(self.reward_funcs), dtype=torch.float32) + self.teacher_model = self.accelerator.prepare_model(self.teacher_model, evaluation_mode=True) - if reward_processing_classes is None: - reward_processing_classes = [None] * len(self.reward_funcs) - elif not isinstance(reward_processing_classes, list): - reward_processing_classes = [reward_processing_classes] - if len(reward_processing_classes) != len(self.reward_funcs): - raise ValueError("Number of reward processing classes must match number of reward functions") + self.add_callback(SyncTeacherModelCallback(teacher_model=self.teacher_model, accelerator=self.accelerator)) - for i, (reward_processing_class, reward_func) in enumerate( - zip(reward_processing_classes, self.reward_funcs, strict=True) - ): - if isinstance(reward_func, PreTrainedModel): - if reward_processing_class is None: - reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config)) - if reward_processing_class.pad_token_id is None: - reward_processing_class.pad_token = reward_processing_class.eos_token - reward_func.config.pad_token_id = reward_processing_class.pad_token_id - reward_processing_classes[i] = reward_processing_class - self.reward_processing_classes = reward_processing_classes + def _resolve_peft_teacher_mode(self) -> str: + peft_teacher_mode = self.args.peft_teacher_mode + if not (is_peft_available() and is_peft_model(self.model)): + if peft_teacher_mode in {"disable_adapter", "teacher_adapter"}: + raise ValueError(f"PEFT teacher mode `{peft_teacher_mode}` requires a PEFT model.") + return "inherit_adapter" - if args.disable_dropout: - disable_dropout_in_model(self.model) + if peft_teacher_mode == "auto": + if self.args.teacher_regularization == "ema": + return "teacher_adapter" + return "disable_adapter" - for i, reward_func in enumerate(self.reward_funcs): - if isinstance(reward_func, nn.Module): - if self.is_deepspeed_enabled: - self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) - elif self.is_fsdp_enabled: - self.reward_funcs[i] = prepare_fsdp(reward_func, self.accelerator) - else: - self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) + return peft_teacher_mode - self.model.add_model_tags(self._tag_names) - self.model_accepts_loss_kwargs = False - self.ref_model = None - self.teacher_model = None + def _validate_teacher_policy(self, teacher_regularization: str, peft_teacher_mode: str) -> None: + if teacher_regularization not in {"none", "ema"}: + raise ValueError(f"Unsupported teacher regularization mode: {teacher_regularization}") + if peft_teacher_mode not in {"inherit_adapter", "disable_adapter", "teacher_adapter"}: + raise ValueError(f"Unsupported PEFT teacher mode: {peft_teacher_mode}") + if peft_teacher_mode == "teacher_adapter" and teacher_regularization != "ema": + raise ValueError("PEFT teacher mode `teacher_adapter` requires EMA teacher regularization.") def get_train_dataloader(self): if self.train_dataset is None: @@ -319,7 +383,7 @@ def _get_train_sampler(self, dataset=None) -> Sampler: def _get_eval_sampler(self, eval_dataset) -> Sampler: return RepeatSampler( data_source=eval_dataset, - 
mini_repeat_count=getattr(self, "num_generations_eval", self.num_generations), + mini_repeat_count=self.num_generations_eval, seed=self.args.seed, ) @@ -333,8 +397,8 @@ def _prepare_inputs(self, generation_batch): if mode == "train": generate_every = self.args.steps_per_generation * self.num_iterations if self._step % generate_every == 0 or self._buffered_inputs is None: - generation_batch = self._build_buffered_batch(generation_batch) - self._buffered_inputs = split_tensor_dict(generation_batch, self.args.steps_per_generation) + buffered_batch = self._build_buffered_batch(generation_batch) + self._buffered_inputs = split_tensor_dict(buffered_batch, self.args.steps_per_generation) self._dispatch_self_distillation_callback( "on_generation_batch_built", generate_every=generate_every, @@ -343,9 +407,223 @@ def _prepare_inputs(self, generation_batch): return self._buffered_inputs[self._step % self.args.steps_per_generation] return self._build_buffered_batch(generation_batch) - def _prepare_auxiliary_model_for_eval(self, aux_model: nn.Module): - if self.is_deepspeed_enabled: - return prepare_deepspeed(aux_model, self.accelerator) - if self.is_fsdp_enabled: - return prepare_fsdp(aux_model, self.accelerator) - return self.accelerator.prepare_model(aux_model, evaluation_mode=True) + def _build_buffered_batch(self, inputs: list[dict[str, Any]]) -> dict[str, torch.Tensor | Any]: + return self.build_training_batch(inputs).to_dict() + + def build_training_batch(self, inputs: list[dict[str, Any]]) -> SelfDistillationBatch: + rollout_batch = self.build_rollout_batch(inputs) + + batch = self.augment_training_batch(inputs, rollout_batch) + self._validate_training_batch(batch) + + self._dispatch_self_distillation_callback( + "on_self_distillation_batch_prepared", + old_per_token_logps=batch.old_per_token_logps, + prompt_ids=batch.prompt_ids, + completion_ids=batch.completion_ids, + teacher_input_ids=batch.teacher_input_ids, + teacher_attention_mask=batch.teacher_attention_mask, + self_distillation_mask=batch.self_distillation_mask, + ) + return batch + + def build_rollout_batch(self, inputs: list[dict[str, Any]]) -> SelfDistillationRolloutBatch: + prompts, _ = self._split_prompt_and_privileged_context(inputs) + generation_prompts = prompts + generation_prompt_text = self.prompt_tokenizer.apply_prompt_template(generation_prompts) + self._dispatch_self_distillation_callback( + "on_generation_prompts_selected", + generation_prompts=generation_prompts, + generation_prompt_text=generation_prompt_text, + ) + + prompt_ids_list, completion_ids_list = self._generate(generation_prompts) + device = self.accelerator.device + prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left").to(device=device) + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left").to(device=device) + completion_ids = [torch.tensor(ids) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right").to(device=device) + completion_mask = pad(completion_mask, padding_value=0, padding_side="right").to(device=device) + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + 
completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() + old_per_token_logps = self.compute_rollout_logps( + prompt_ids=prompt_ids, + prompt_mask=prompt_mask, + completion_ids=completion_ids, + completion_mask=completion_mask, + ) + + return SelfDistillationRolloutBatch( + prompt_ids=prompt_ids, + prompt_mask=prompt_mask, + completion_ids=completion_ids, + completion_mask=completion_mask, + old_per_token_logps=old_per_token_logps, + metadata={ + "raw_completion_lengths": torch.tensor( + [len(ids) for ids in completion_ids_list], device=device, dtype=torch.long + ) + }, + ) + + def _generate(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: + if self.use_vllm: + return self._generate_vllm(prompts) + return self._generate_transformers(prompts) + + def _generate_vllm(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: + if self.state.global_step != self._last_loaded_step: + self.vllm_generation.sync_weights() + self._last_loaded_step = self.state.global_step + + prompts_text = self.prompt_tokenizer.apply_prompt_template(prompts) + tokenized = self.processing_class( + text=prompts_text, + return_tensors=None, + padding=False, + max_length=self.max_prompt_length, + truncation=True, + add_special_tokens=False, + ) + prompt_ids = tokenized["input_ids"] + mode = "train" if self.model.training else "eval" + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + prompt_ids_out, completion_ids_list, _, _ = self.vllm_generation.generate( + prompts=prompt_ids, + images=None, + num_generations=num_generations, + ) + return prompt_ids_out, completion_ids_list + + def _generate_transformers(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: + generate_inputs = self.processing_class( + text=self.prompt_tokenizer.apply_prompt_template(prompts), + return_tensors="pt", + padding=True, + padding_side="left", + max_length=self.max_prompt_length, + truncation=True, + add_special_tokens=False, + ) + generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) + + with ( + unwrap_model_for_generation( + self.model_wrapped, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + ) as unwrapped_model, + torch.no_grad(), + ): + prompt_completion_ids = unwrapped_model.generate( + **generate_inputs, generation_config=self.generation_config + ) + + prompt_ids = generate_inputs["input_ids"] + prompt_mask = generate_inputs["attention_mask"] + prompt_length = prompt_ids.size(1) + completion_ids = prompt_completion_ids[:, prompt_length:] + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) + completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).long() + + prompt_ids_list = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=False)] + completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=False)] + return prompt_ids_list, completion_ids_list + + def compute_rollout_logps( + self, + prompt_ids: torch.Tensor, + prompt_mask: torch.Tensor, + completion_ids: torch.Tensor, + completion_mask: torch.Tensor, + ) -> torch.Tensor | None: + generate_every = self.args.steps_per_generation * self.num_iterations + old_per_token_logps = None + + if 
self.args.gradient_accumulation_steps % generate_every != 0: + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) + with torch.no_grad(): + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + compute_entropy=False, + ) + + return old_per_token_logps + + def _get_per_token_logps_and_entropies( + self, + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=False, + ): + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} + if "logits_to_keep" in self.model_kwarg_keys: + model_inputs["logits_to_keep"] = logits_to_keep + 1 + logits = model(**model_inputs).logits + logits = logits[:, :-1, :] + logits = logits[:, -logits_to_keep:, :] + logits = logits / self.temperature + completion_ids = input_ids[:, -logits_to_keep:] + selected_logps = selective_log_softmax(logits, completion_ids) + entropies = entropy_from_logits(logits) if compute_entropy else None + return selected_logps, entropies + + def _validate_training_batch(self, batch: SelfDistillationBatch) -> None: + batch_size = batch.prompt_ids.size(0) + if batch.prompt_mask.size(0) != batch_size: + raise ValueError("`prompt_mask` must have the same batch size as `prompt_ids`.") + if batch.completion_ids.size(0) != batch_size or batch.completion_mask.size(0) != batch_size: + raise ValueError("`completion_ids` and `completion_mask` must match the student batch size.") + if batch.teacher_input_ids.size(0) != batch_size or batch.teacher_attention_mask.size(0) != batch_size: + raise ValueError("`teacher_input_ids` and `teacher_attention_mask` must match the student batch size.") + if batch.teacher_input_ids.size(1) != batch.teacher_attention_mask.size(1): + raise ValueError("`teacher_input_ids` and `teacher_attention_mask` must have the same sequence length.") + if batch.self_distillation_mask is not None and batch.self_distillation_mask.size(0) != batch_size: + raise ValueError("`self_distillation_mask` must match the batch size when provided.") + + @abstractmethod + def augment_training_batch( + self, + inputs: list[dict[str, Any]], + rollout_batch: SelfDistillationRolloutBatch, + ) -> SelfDistillationBatch: + """Inject teacher-side inputs and algorithm-specific fields into a common student rollout batch.""" + + def _get_teacher_context_for_self_distillation(self): + peft_teacher_mode = self._resolve_peft_teacher_mode() + if not (is_peft_available() and isinstance(self.model, PeftModel)) or peft_teacher_mode == "inherit_adapter": + return nullcontext() + + target_model = self.teacher_model if self.teacher_model is not None else self.model + target_model = self.accelerator.unwrap_model(target_model) + + if peft_teacher_mode == "disable_adapter": + return use_adapter(target_model, adapter_name=None) + if peft_teacher_mode == "teacher_adapter": + teacher_adapter_name = self.args.teacher_adapter_name + if teacher_adapter_name not in target_model.peft_config: + raise RuntimeError( + f"Expected PEFT teacher adapter `{teacher_adapter_name}` to exist before teacher forward." 
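# --- Editor's aside: illustrative sketch, not part of the patch --------------
# What `_get_per_token_logps_and_entropies` above computes, in plain torch:
# drop the final-position logits, keep only the completion tail, apply the
# temperature, and gather the log-probability of each realized token (the
# role played by `selective_log_softmax`). Shapes and values are arbitrary.
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab, logits_to_keep, temperature = 2, 10, 32, 4, 0.7
logits = torch.randn(batch_size, seq_len, vocab)
input_ids = torch.randint(vocab, (batch_size, seq_len))

shifted = logits[:, :-1, :][:, -logits_to_keep:, :] / temperature
completion_ids = input_ids[:, -logits_to_keep:]
per_token_logps = torch.gather(
    F.log_softmax(shifted, dim=-1), dim=-1, index=completion_ids.unsqueeze(-1)
).squeeze(-1)
assert per_token_logps.shape == (batch_size, logits_to_keep)
# ------------------------------------------------------------------------------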
+ ) + return use_adapter(target_model, adapter_name=teacher_adapter_name) + + raise ValueError(f"Unsupported PEFT teacher mode: {peft_teacher_mode}") + + @abstractmethod + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """Subclasses own algorithm-specific loss composition on the final batch contract.""" diff --git a/trl/experimental/self_distillation/self_distillation_mixin.py b/trl/experimental/self_distillation/self_distillation_mixin.py index fb2a8808de1..f85dcd236f9 100644 --- a/trl/experimental/self_distillation/self_distillation_mixin.py +++ b/trl/experimental/self_distillation/self_distillation_mixin.py @@ -21,13 +21,11 @@ from __future__ import annotations -from contextlib import nullcontext from typing import Any import torch import torch.nn.functional as F -from ...trainer.utils import entropy_from_logits, selective_log_softmax from .self_distillation_config import SelfDistillationConfig @@ -62,26 +60,6 @@ def _split_prompt_and_privileged_context(inputs: list[dict[str, Any]]) -> tuple[ def _allow_topk_without_full_logit_distillation(self) -> bool: return True - def _get_per_token_logps_and_entropies( - self, - model, - input_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ): - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} - if "logits_to_keep" in self.model_kwarg_keys: - model_inputs["logits_to_keep"] = logits_to_keep + 1 - logits = model(**model_inputs).logits - logits = logits[:, :-1, :] - logits = logits[:, -logits_to_keep:, :] - logits = logits / self.temperature - completion_ids = input_ids[:, -logits_to_keep:] - selected_logps = selective_log_softmax(logits, completion_ids) - entropies = entropy_from_logits(logits) if compute_entropy else None - return selected_logps, entropies - def _compute_self_distillation_loss( self, model, @@ -214,9 +192,6 @@ def _get_teacher_model_for_self_distillation(self, model): return model return teacher_model - def _get_teacher_context_for_self_distillation(self, model): - return nullcontext() - def _log_self_distillation_metric(self, mode: str, metric_name: str, value: float) -> None: metric_prefix = getattr(self, "_name", "self_distillation").lower().replace(" ", "_") self._metrics[mode][f"self_distillation/{metric_name}"].append(value) diff --git a/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py b/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py deleted file mode 100644 index ff738c4a548..00000000000 --- a/trl/experimental/self_distillation/unified_base_self_distillation_trainer.py +++ /dev/null @@ -1,619 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
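# --- Editor's aside: illustrative sketch, not part of the patch --------------
# The EMA teacher referenced throughout this series (via PEFTAdapterEMACallback
# and the deep-copied `teacher_model`) follows the usual exponential moving
# average of the student weights, with `teacher_update_rate` as the
# interpolation factor. A minimal standalone version of that update rule:
import torch

@torch.no_grad()
def ema_update(teacher: torch.nn.Module, student: torch.nn.Module, update_rate: float) -> None:
    for t_param, s_param in zip(teacher.parameters(), student.parameters(), strict=True):
        t_param.mul_(1.0 - update_rate).add_(s_param, alpha=update_rate)

teacher, student = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
ema_update(teacher, student, update_rate=0.01)  # update_rate=1.0 -> hard sync
# ------------------------------------------------------------------------------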
- -from __future__ import annotations - -import copy -import inspect -from abc import ABC, abstractmethod -from collections import defaultdict -from dataclasses import dataclass, field -from functools import partial -from typing import Any - -import datasets -import torch -from accelerate.logging import get_logger -from accelerate.utils import is_peft_model -from datasets import Dataset, IterableDataset -from torch import nn -from torch.utils.data import DataLoader, Sampler -from transformers import ( - AutoProcessor, - GenerationConfig, - PreTrainedModel, - PreTrainedTokenizerBase, - ProcessorMixin, - TrainerCallback, -) -from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_peft_available - -from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation -from ...trainer.base_trainer import _BaseTrainer -from ...trainer.utils import ( - RepeatSampler, - create_model_from_path, - disable_dropout_in_model, - get_config_model_id, - identity, - pad, - split_tensor_dict, - use_adapter, -) -from ..utils import prepare_peft_model -from .self_distillation_config import SelfDistillationConfig -from .self_distillation_mixin import SelfDistillationMixin -from .teacher_context import PromptTokenizer -from .teacher_sync import PEFTAdapterEMACallback, SyncTeacherModelCallback - - -if is_peft_available(): - from peft import PeftConfig - from peft.peft_model import PeftModel - - -logger = get_logger(__name__) - - -@dataclass -class SelfDistillationRolloutBatch: - """Common student rollout batch produced before algorithm-specific augmentation.""" - - prompt_ids: torch.Tensor - prompt_mask: torch.Tensor - completion_ids: torch.Tensor - completion_mask: torch.Tensor - old_per_token_logps: torch.Tensor | None = None - metadata: dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> dict[str, torch.Tensor | Any]: - output: dict[str, torch.Tensor | Any] = { - "prompt_ids": self.prompt_ids, - "prompt_mask": self.prompt_mask, - "completion_ids": self.completion_ids, - "completion_mask": self.completion_mask, - } - if self.old_per_token_logps is not None: - output["old_per_token_logps"] = self.old_per_token_logps - output.update(self.metadata) - return output - - -@dataclass -class SelfDistillationBatch: - """Final self-distillation batch contract consumed by `SelfDistillationMixin`.""" - - prompt_ids: torch.Tensor - prompt_mask: torch.Tensor - completion_ids: torch.Tensor - completion_mask: torch.Tensor - teacher_input_ids: torch.Tensor - teacher_attention_mask: torch.Tensor - old_per_token_logps: torch.Tensor | None = None - self_distillation_mask: torch.Tensor | None = None - metadata: dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> dict[str, torch.Tensor | Any]: - output: dict[str, torch.Tensor | Any] = { - "prompt_ids": self.prompt_ids, - "prompt_mask": self.prompt_mask, - "completion_ids": self.completion_ids, - "completion_mask": self.completion_mask, - "teacher_input_ids": self.teacher_input_ids, - "teacher_attention_mask": self.teacher_attention_mask, - } - if self.old_per_token_logps is not None: - output["old_per_token_logps"] = self.old_per_token_logps - if self.self_distillation_mask is not None: - output["self_distillation_mask"] = self.self_distillation_mask - output.update(self.metadata) - return output - - -class UnifiedBaseSelfDistillationTrainer(SelfDistillationMixin, _BaseTrainer, ABC): - """Prototype base that centralizes shared self-distillation trainer lifecycle.""" - - config_cls 
= SelfDistillationConfig - _tag_names = ["trl", "self-distillation"] - _name = "Self-Distillation" - - def __init__( - self, - model: str | PreTrainedModel | nn.Module, - args: SelfDistillationConfig | None = None, - train_dataset: Dataset | IterableDataset | None = None, - eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, - processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, - callbacks: list[TrainerCallback] | None = None, - optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), - peft_config: PeftConfig | None = None, - ): - if train_dataset is None: - raise ValueError("`train_dataset` is required") - self.use_vllm = args.use_vllm - - if isinstance(model, str): - model_init_kwargs = args.model_init_kwargs or {} - if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: - model_init_kwargs["device_map"] = None - model = create_model_from_path(model, **model_init_kwargs) - elif args.model_init_kwargs is not None: - logger.warning( - "You passed `model_init_kwargs` to the self-distillation config, but `model` is already " - "instantiated. The `model_init_kwargs` will be ignored." - ) - - self.model_kwarg_keys = ( - inspect.signature(model.forward).parameters.keys() - if not hasattr(model, "get_base_model") - else inspect.signature(model.get_base_model().forward).parameters.keys() - ) - - if is_peft_available() and is_peft_model(model) and peft_config is not None: - raise ValueError( - "You passed a `PeftModel` instance together with a `peft_config`. Pass either a base " - "model with `peft_config`, or a pre-wrapped PEFT model." - ) - if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): - model = prepare_peft_model(model, peft_config, args) - - if processing_class is None: - processing_class = AutoProcessor.from_pretrained( - get_config_model_id(model.config), truncation_side="left", padding_side="left" - ) - - if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer - elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class - else: - raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - self.pad_token = tokenizer.pad_token - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id - self.max_prompt_length = args.max_prompt_length - self.max_completion_length = args.max_completion_length - self.num_generations = args.num_generations - self.num_generations_eval = args.num_generations_eval or args.num_generations - self.num_iterations = args.num_iterations - self.shuffle_dataset = args.shuffle_dataset - self.loss_type = args.loss_type - self.mask_truncated_completions = args.mask_truncated_completions - self.temperature = args.temperature - self.chat_template_kwargs = args.chat_template_kwargs or {} - self._step = 0 - self._last_loaded_step = 0 - self._buffered_inputs = None - self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} - self._diagnostic_counters = { - "train": defaultdict(int), - "eval": defaultdict(int), - } - self.prompt_tokenizer = PromptTokenizer(self) - - generation_kwargs = { - "max_new_tokens": self.max_completion_length, - "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": 
tokenizer.eos_token_id, - "temperature": args.temperature, - "top_p": args.top_p, - "top_k": args.top_k, - "min_p": args.min_p, - "repetition_penalty": args.repetition_penalty, - "cache_implementation": args.cache_implementation, - } - if args.generation_kwargs is not None: - generation_kwargs.update(args.generation_kwargs) - self.generation_config = GenerationConfig(**generation_kwargs, disable_compile=True) - - if hasattr(model, "warnings_issued"): - model.warnings_issued["estimate_tokens"] = True - - super().__init__( - model=model, - args=args, - data_collator=identity, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - callbacks=callbacks, - optimizers=optimizers, - compute_loss_func="non-None value to disable scaling", - ) - - if self.use_vllm: - from ...generation.vllm_generation import VLLMGeneration - - self.vllm_generation = VLLMGeneration( - model=self.model, - accelerator=self.accelerator, - is_fsdp_enabled=self.is_fsdp_enabled, - processing_class=self.processing_class, - mode=args.vllm_mode, - server_base_url=args.vllm_server_base_url, - server_host=args.vllm_server_host, - server_port=args.vllm_server_port, - group_port=args.vllm_group_port, - server_timeout=args.vllm_server_timeout, - tensor_parallel_size=args.vllm_tensor_parallel_size, - gpu_memory_utilization=args.vllm_gpu_memory_utilization, - max_model_length=args.vllm_max_model_length, - max_num_seqs=args.per_device_train_batch_size - * args.vllm_tensor_parallel_size - * args.steps_per_generation, - enable_sleep_mode=args.vllm_enable_sleep_mode, - model_impl=args.vllm_model_impl, - repetition_penalty=args.repetition_penalty, - temperature=self.temperature, - top_p=args.top_p, - top_k=args.top_k, - min_p=args.min_p, - max_completion_length=self.max_completion_length, - logprobs=None, - generation_kwargs=args.generation_kwargs, - ) - self._last_loaded_step = -1 - - if args.disable_dropout: - disable_dropout_in_model(self.model) - - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - self._setup_teacher_model() - self.model_accepts_loss_kwargs = False - - def _setup_teacher_model(self) -> None: - """Prepare teacher state according to the shared teacher policy.""" - - teacher_regularization = self.args.teacher_regularization - peft_teacher_mode = self._resolve_peft_teacher_mode() - self._validate_teacher_policy(teacher_regularization, peft_teacher_mode) - - if teacher_regularization == "none": - return - - if is_peft_available() and is_peft_model(self.model) and peft_teacher_mode == "teacher_adapter": - self.add_callback( - PEFTAdapterEMACallback( - model=self.model, - teacher_adapter_name=self.args.teacher_adapter_name, - update_rate=self.args.teacher_update_rate, - sync_steps=self.args.teacher_sync_steps, - accelerator=self.accelerator, - ) - ) - return - - student_model = self.accelerator.unwrap_model(self.model) - self.teacher_model = copy.deepcopy(student_model) - self.teacher_model.requires_grad_(False) - self.teacher_model.eval() - - if self.is_deepspeed_enabled: - self.teacher_model = prepare_deepspeed(self.teacher_model, self.accelerator) - elif self.is_fsdp_enabled: - self.teacher_model = prepare_fsdp(self.teacher_model, self.accelerator) - else: - self.teacher_model = self.accelerator.prepare_model(self.teacher_model, evaluation_mode=True) - - self.add_callback(SyncTeacherModelCallback(teacher_model=self.teacher_model, accelerator=self.accelerator)) - - def _resolve_peft_teacher_mode(self) -> str: - peft_teacher_mode = 
self.args.peft_teacher_mode - if not (is_peft_available() and is_peft_model(self.model)): - if peft_teacher_mode in {"disable_adapter", "teacher_adapter"}: - raise ValueError(f"PEFT teacher mode `{peft_teacher_mode}` requires a PEFT model.") - return "inherit_adapter" - - if peft_teacher_mode == "auto": - if self.args.teacher_regularization == "ema": - return "teacher_adapter" - return "disable_adapter" - - return peft_teacher_mode - - def _validate_teacher_policy(self, teacher_regularization: str, peft_teacher_mode: str) -> None: - if teacher_regularization not in {"none", "ema"}: - raise ValueError(f"Unsupported teacher regularization mode: {teacher_regularization}") - if peft_teacher_mode not in {"inherit_adapter", "disable_adapter", "teacher_adapter"}: - raise ValueError(f"Unsupported PEFT teacher mode: {peft_teacher_mode}") - if peft_teacher_mode == "teacher_adapter" and teacher_regularization != "ema": - raise ValueError("PEFT teacher mode `teacher_adapter` requires EMA teacher regularization.") - - def get_train_dataloader(self): - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - - train_dataset = self.train_dataset - data_collator = self.data_collator - if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): - train_dataset = self._remove_unused_columns(train_dataset, description="training") - else: - data_collator = self._get_collator_with_removed_columns(data_collator, description="training") - - dataloader_params = { - "batch_size": self._train_batch_size * self.args.steps_per_generation, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - } - if not isinstance(train_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_train_sampler() - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = partial( - seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index - ) - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) - - def _get_train_sampler(self, dataset=None) -> Sampler: - if dataset is None: - dataset = self.train_dataset - return RepeatSampler( - data_source=dataset, - mini_repeat_count=self.num_generations, - batch_size=self.args.generation_batch_size // self.num_generations, - repeat_count=self.num_iterations * self.args.steps_per_generation, - shuffle=self.shuffle_dataset, - seed=self.args.seed, - ) - - def _get_eval_sampler(self, eval_dataset) -> Sampler: - return RepeatSampler( - data_source=eval_dataset, - mini_repeat_count=self.num_generations_eval, - seed=self.args.seed, - ) - - def training_step(self, model, inputs, num_items_in_batch): - output = super().training_step(model, inputs, num_items_in_batch) - self._step += 1 - return output - - def _prepare_inputs(self, generation_batch): - mode = "train" if self.model.training else "eval" - if mode == "train": - generate_every = self.args.steps_per_generation * self.num_iterations - if self._step % generate_every == 0 or self._buffered_inputs is None: - buffered_batch = self._build_buffered_batch(generation_batch) - self._buffered_inputs = split_tensor_dict(buffered_batch, self.args.steps_per_generation) - self._dispatch_self_distillation_callback( - "on_generation_batch_built", - 
generate_every=generate_every, - steps_per_generation=self.args.steps_per_generation, - ) - return self._buffered_inputs[self._step % self.args.steps_per_generation] - return self._build_buffered_batch(generation_batch) - - def _build_buffered_batch(self, inputs: list[dict[str, Any]]) -> dict[str, torch.Tensor | Any]: - return self.build_training_batch(inputs).to_dict() - - def build_training_batch(self, inputs: list[dict[str, Any]]) -> SelfDistillationBatch: - rollout_batch = self.build_rollout_batch(inputs) - self._validate_rollout_batch(rollout_batch) - - batch = self.augment_training_batch(inputs, rollout_batch) - self._validate_training_batch(batch) - - self._dispatch_self_distillation_callback( - "on_self_distillation_batch_prepared", - old_per_token_logps=batch.old_per_token_logps, - prompt_ids=batch.prompt_ids, - completion_ids=batch.completion_ids, - teacher_input_ids=batch.teacher_input_ids, - teacher_attention_mask=batch.teacher_attention_mask, - self_distillation_mask=batch.self_distillation_mask, - ) - return batch - - def build_rollout_batch(self, inputs: list[dict[str, Any]]) -> SelfDistillationRolloutBatch: - prompts, _ = self._split_prompt_and_privileged_context(inputs) - generation_prompts = prompts - generation_prompt_text = self.prompt_tokenizer.apply_prompt_template(generation_prompts) - self._dispatch_self_distillation_callback( - "on_generation_prompts_selected", - generation_prompts=generation_prompts, - generation_prompt_text=generation_prompt_text, - ) - - prompt_ids_list, completion_ids_list = self._generate(generation_prompts) - device = self.accelerator.device - prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] - prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left").to(device=device) - prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left").to(device=device) - completion_ids = [torch.tensor(ids) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right").to(device=device) - completion_mask = pad(completion_mask, padding_value=0, padding_side="right").to(device=device) - if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) - completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() - old_per_token_logps = self.compute_rollout_logps( - prompt_ids=prompt_ids, - prompt_mask=prompt_mask, - completion_ids=completion_ids, - completion_mask=completion_mask, - ) - - return SelfDistillationRolloutBatch( - prompt_ids=prompt_ids, - prompt_mask=prompt_mask, - completion_ids=completion_ids, - completion_mask=completion_mask, - old_per_token_logps=old_per_token_logps, - metadata={ - "raw_completion_lengths": torch.tensor( - [len(ids) for ids in completion_ids_list], device=device, dtype=torch.long - ) - }, - ) - - def _generate(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: - if self.use_vllm: - return self._generate_vllm(prompts) - return self._generate_transformers(prompts) - - def _generate_vllm(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: - if self.state.global_step != self._last_loaded_step: - self.vllm_generation.sync_weights() - self._last_loaded_step = self.state.global_step - - 
prompts_text = self.prompt_tokenizer.apply_prompt_template(prompts) - tokenized = self.processing_class( - text=prompts_text, - return_tensors=None, - padding=False, - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - prompt_ids = tokenized["input_ids"] - mode = "train" if self.model.training else "eval" - num_generations = self.num_generations if mode == "train" else self.num_generations_eval - prompt_ids_out, completion_ids_list, _, _ = self.vllm_generation.generate( - prompts=prompt_ids, - images=None, - num_generations=num_generations, - ) - return prompt_ids_out, completion_ids_list - - def _generate_transformers(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: - generate_inputs = self.processing_class( - text=self.prompt_tokenizer.apply_prompt_template(prompts), - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) - - with ( - unwrap_model_for_generation( - self.model_wrapped, - self.accelerator, - gather_deepspeed3_params=self.args.ds3_gather_for_generation, - ) as unwrapped_model, - torch.no_grad(), - ): - prompt_completion_ids = unwrapped_model.generate( - **generate_inputs, generation_config=self.generation_config - ) - - prompt_ids = generate_inputs["input_ids"] - prompt_mask = generate_inputs["attention_mask"] - prompt_length = prompt_ids.size(1) - completion_ids = prompt_completion_ids[:, prompt_length:] - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) - completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).long() - - prompt_ids_list = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=False)] - completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=False)] - return prompt_ids_list, completion_ids_list - - def compute_rollout_logps( - self, - prompt_ids: torch.Tensor, - prompt_mask: torch.Tensor, - completion_ids: torch.Tensor, - completion_mask: torch.Tensor, - ) -> torch.Tensor | None: - generate_every = self.args.steps_per_generation * self.num_iterations - old_per_token_logps = None - - if self.args.gradient_accumulation_steps % generate_every != 0: - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) - old_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ) - - return old_per_token_logps - - def _validate_rollout_batch(self, batch: SelfDistillationRolloutBatch) -> None: - batch_size = batch.prompt_ids.size(0) - if batch.prompt_mask.size(0) != batch_size: - raise ValueError("`prompt_mask` must have the same batch size as `prompt_ids` in the rollout batch.") - if batch.completion_ids.size(0) != batch_size or batch.completion_mask.size(0) != batch_size: - raise ValueError("`completion_ids` and `completion_mask` must match the rollout batch size.") - if batch.old_per_token_logps is not None and batch.old_per_token_logps.size(0) != batch_size: - raise 
ValueError("`old_per_token_logps` must match the rollout batch size when provided.") - - def _validate_training_batch(self, batch: SelfDistillationBatch) -> None: - batch_size = batch.prompt_ids.size(0) - if batch.prompt_mask.size(0) != batch_size: - raise ValueError("`prompt_mask` must have the same batch size as `prompt_ids`.") - if batch.completion_ids.size(0) != batch_size or batch.completion_mask.size(0) != batch_size: - raise ValueError("`completion_ids` and `completion_mask` must match the student batch size.") - if batch.teacher_input_ids.size(0) != batch_size or batch.teacher_attention_mask.size(0) != batch_size: - raise ValueError("`teacher_input_ids` and `teacher_attention_mask` must match the student batch size.") - if batch.teacher_input_ids.size(1) != batch.teacher_attention_mask.size(1): - raise ValueError("`teacher_input_ids` and `teacher_attention_mask` must have the same sequence length.") - if batch.self_distillation_mask is not None and batch.self_distillation_mask.size(0) != batch_size: - raise ValueError("`self_distillation_mask` must match the batch size when provided.") - - @abstractmethod - def augment_training_batch( - self, - inputs: list[dict[str, Any]], - rollout_batch: SelfDistillationRolloutBatch, - ) -> SelfDistillationBatch: - """Inject teacher-side inputs and algorithm-specific fields into a common student rollout batch.""" - - def _get_teacher_context_for_self_distillation(self, model): - peft_teacher_mode = self._resolve_peft_teacher_mode() - if not (is_peft_available() and isinstance(self.model, PeftModel)): - return super()._get_teacher_context_for_self_distillation(model) - - if peft_teacher_mode == "inherit_adapter": - return super()._get_teacher_context_for_self_distillation(model) - - target_model = self.teacher_model if self.teacher_model is not None else self.model - target_model = self.accelerator.unwrap_model(target_model) - - if peft_teacher_mode == "disable_adapter": - return use_adapter(target_model, adapter_name=None) - if peft_teacher_mode == "teacher_adapter": - teacher_adapter_name = self.args.teacher_adapter_name - if teacher_adapter_name not in target_model.peft_config: - raise RuntimeError( - f"Expected PEFT teacher adapter `{teacher_adapter_name}` to exist before teacher forward." 
- ) - return use_adapter(target_model, adapter_name=teacher_adapter_name) - - raise ValueError(f"Unsupported PEFT teacher mode: {peft_teacher_mode}") - - @abstractmethod - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - """Subclasses own algorithm-specific loss composition on the final batch contract.""" From 81def8a6d0b3181dfb23ccf41765ade1db86f026 Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 16 Apr 2026 08:24:41 +0200 Subject: [PATCH 05/23] sdft and sdpo transitioned and tested with new base --- trl/experimental/sdft/sdft.py | 4 +- trl/experimental/sdft/sdft_config.py | 12 ++-- trl/experimental/sdpo/sdpo.py | 2 +- trl/experimental/sdpo/sdpo_config.py | 10 ++-- .../base_self_distillation_trainer.py | 58 +++++++++---------- .../self_distillation_config.py | 40 +++---------- 6 files changed, 47 insertions(+), 79 deletions(-) diff --git a/trl/experimental/sdft/sdft.py b/trl/experimental/sdft/sdft.py index 64521a2707e..e90c4f060ac 100644 --- a/trl/experimental/sdft/sdft.py +++ b/trl/experimental/sdft/sdft.py @@ -46,7 +46,7 @@ --learning_rate 2e-5 \ --max_prompt_length 1024 \ --max_completion_length 512 \ - --teacher_regularization ema \ + --teacher_model_kind ema \ --teacher_sync_steps 1 \ --teacher_update_rate 0.01 \ --eval_strategy steps \ @@ -77,7 +77,7 @@ ) from trl.data_utils import maybe_apply_chat_template from trl.experimental.sdft import SDFTConfig -from trl.experimental.sdft.sdft_trainer_transition import SDFTTrainer +from trl.experimental.sdft.sdft_trainer import SDFTTrainer from trl.models import unwrap_model_for_generation diff --git a/trl/experimental/sdft/sdft_config.py b/trl/experimental/sdft/sdft_config.py index 6e27bb17216..041c85a7dc0 100644 --- a/trl/experimental/sdft/sdft_config.py +++ b/trl/experimental/sdft/sdft_config.py @@ -28,10 +28,8 @@ class SDFTConfig(SelfDistillationConfig): Parameters: disable_dropout (`bool`, *optional*, defaults to `True`): Whether to disable dropout in the student and teacher models. - peft_teacher_mode (`str`, *optional*, defaults to `"auto"`): - PEFT teacher execution mode. The default `auto` reproduces the original SDFT behavior: use the - adapter-disabled base model without EMA teacher regularization, and the EMA teacher adapter when - `teacher_regularization="ema"`. + teacher_model_kind (`str`, *optional*, defaults to `"base"`): + Semantic teacher choice for SDFT. Supported: `base`, `live`, `ema`. generate_from_teacher (`bool`, *optional*, defaults to `False`): Whether on-policy generation should use the teacher-conditioned prompt instead of the student prompt. teacher_prompt_template (`str`, *optional*, defaults to `"{prompt}\n\n{privileged_context}"`): @@ -44,9 +42,9 @@ class SDFTConfig(SelfDistillationConfig): default=True, metadata={"help": "Whether to disable dropout in the student and teacher models."}, ) - peft_teacher_mode: str = field( - default="auto", - metadata={"help": "PEFT teacher execution mode. `auto` reproduces the original SDFT teacher behavior."}, + teacher_model_kind: str = field( + default="base", + metadata={"help": "Semantic teacher choice for SDFT. 
Supported: `base`, `live`, `ema`."}, ) generate_from_teacher: bool = field( default=False, diff --git a/trl/experimental/sdpo/sdpo.py b/trl/experimental/sdpo/sdpo.py index 5ab93fbaf5b..25f75a55bab 100644 --- a/trl/experimental/sdpo/sdpo.py +++ b/trl/experimental/sdpo/sdpo.py @@ -79,7 +79,7 @@ ) from trl.data_utils import maybe_apply_chat_template from trl.experimental.sdpo import SDPOConfig -from trl.experimental.sdpo.sdpo_trainer_transition import SDPOTrainer +from trl.experimental.sdpo.sdpo_trainer import SDPOTrainer SYSTEM_PROMPT = ( diff --git a/trl/experimental/sdpo/sdpo_config.py b/trl/experimental/sdpo/sdpo_config.py index 707d38e1648..87c11ea9046 100644 --- a/trl/experimental/sdpo/sdpo_config.py +++ b/trl/experimental/sdpo/sdpo_config.py @@ -39,10 +39,10 @@ class SDPOConfig(SelfDistillationConfig): > Parameters that control the teacher - teacher_regularization (`str`, *optional*, defaults to `"ema"`): - Teacher update strategy. Supported: `ema`, `none`. + teacher_model_kind (`str`, *optional*, defaults to `"ema"`): + Semantic teacher choice. Supported: `base`, `live`, `ema`. teacher_update_rate (`float`, *optional*, defaults to `0.05`): - EMA update rate used when `teacher_regularization="ema"`. + EMA update rate used when `teacher_model_kind="ema"`. teacher_sync_steps (`int`, *optional*, defaults to `1`): Number of optimizer steps between teacher EMA updates. @@ -74,9 +74,9 @@ class SDPOConfig(SelfDistillationConfig): default="distillation_only", metadata={"help": "SDPO policy loss mode. Supported: `distillation_only`, `hybrid`."}, ) - teacher_regularization: str = field( + teacher_model_kind: str = field( default="ema", - metadata={"help": "Teacher regularization mode. Supported: `ema`, `none`."}, + metadata={"help": "Semantic teacher choice. 
Supported: `base`, `live`, `ema`."}, ) teacher_update_rate: float = field( default=0.05, diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index d8171cc88e1..2b3685339b5 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -284,20 +284,18 @@ def __init__( self.model_accepts_loss_kwargs = False def _setup_teacher_model(self) -> None: - """Prepare teacher state according to the shared teacher policy.""" + """Prepare teacher state according to the semantic teacher choice.""" - teacher_regularization = self.args.teacher_regularization - peft_teacher_mode = self._resolve_peft_teacher_mode() - self._validate_teacher_policy(teacher_regularization, peft_teacher_mode) + teacher_model_kind = self.args.teacher_model_kind - if teacher_regularization == "none": + if teacher_model_kind in {"base", "live"}: return - if is_peft_available() and is_peft_model(self.model) and peft_teacher_mode == "teacher_adapter": + if self._use_peft_ema_teacher_adapter(): self.add_callback( PEFTAdapterEMACallback( model=self.model, - teacher_adapter_name=self.args.teacher_adapter_name, + teacher_adapter_name="teacher", update_rate=self.args.teacher_update_rate, sync_steps=self.args.teacher_sync_steps, accelerator=self.accelerator, @@ -319,27 +317,24 @@ def _setup_teacher_model(self) -> None: self.add_callback(SyncTeacherModelCallback(teacher_model=self.teacher_model, accelerator=self.accelerator)) - def _resolve_peft_teacher_mode(self) -> str: - peft_teacher_mode = self.args.peft_teacher_mode - if not (is_peft_available() and is_peft_model(self.model)): - if peft_teacher_mode in {"disable_adapter", "teacher_adapter"}: - raise ValueError(f"PEFT teacher mode `{peft_teacher_mode}` requires a PEFT model.") - return "inherit_adapter" + def _use_peft_ema_teacher_adapter(self) -> bool: + return self.args.teacher_model_kind == "ema" and self._is_pure_lora_training() - if peft_teacher_mode == "auto": - if self.args.teacher_regularization == "ema": - return "teacher_adapter" - return "disable_adapter" + def _is_pure_lora_training(self) -> bool: + if not (is_peft_available() and is_peft_model(self.model)): + return False - return peft_teacher_mode + model = self.accelerator.unwrap_model(self.model) + adapter_name = getattr(model, "active_adapter", None) or "default" + adapter_config = model.peft_config.get(adapter_name) + peft_type = getattr(adapter_config, "peft_type", None) + if peft_type is None or str(peft_type).split(".")[-1] != "LORA": + return False - def _validate_teacher_policy(self, teacher_regularization: str, peft_teacher_mode: str) -> None: - if teacher_regularization not in {"none", "ema"}: - raise ValueError(f"Unsupported teacher regularization mode: {teacher_regularization}") - if peft_teacher_mode not in {"inherit_adapter", "disable_adapter", "teacher_adapter"}: - raise ValueError(f"Unsupported PEFT teacher mode: {peft_teacher_mode}") - if peft_teacher_mode == "teacher_adapter" and teacher_regularization != "ema": - raise ValueError("PEFT teacher mode `teacher_adapter` requires EMA teacher regularization.") + for name, param in model.named_parameters(): + if param.requires_grad and "lora_" not in name: + return False + return True def get_train_dataloader(self): if self.train_dataset is None: @@ -605,24 +600,23 @@ def augment_training_batch( """Inject teacher-side inputs and algorithm-specific fields into a 
common student rollout batch.""" def _get_teacher_context_for_self_distillation(self): - peft_teacher_mode = self._resolve_peft_teacher_mode() - if not (is_peft_available() and isinstance(self.model, PeftModel)) or peft_teacher_mode == "inherit_adapter": + teacher_model_kind = self.args.teacher_model_kind + if not (is_peft_available() and isinstance(self.model, PeftModel)): return nullcontext() target_model = self.teacher_model if self.teacher_model is not None else self.model target_model = self.accelerator.unwrap_model(target_model) - if peft_teacher_mode == "disable_adapter": + if teacher_model_kind == "base": return use_adapter(target_model, adapter_name=None) - if peft_teacher_mode == "teacher_adapter": - teacher_adapter_name = self.args.teacher_adapter_name + if teacher_model_kind == "ema" and self._use_peft_ema_teacher_adapter(): + teacher_adapter_name = self._get_teacher_adapter_name() if teacher_adapter_name not in target_model.peft_config: raise RuntimeError( f"Expected PEFT teacher adapter `{teacher_adapter_name}` to exist before teacher forward." ) return use_adapter(target_model, adapter_name=teacher_adapter_name) - - raise ValueError(f"Unsupported PEFT teacher mode: {peft_teacher_mode}") + return nullcontext() @abstractmethod def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): diff --git a/trl/experimental/self_distillation/self_distillation_config.py b/trl/experimental/self_distillation/self_distillation_config.py index 4611ddacd7a..a03773f9bc2 100644 --- a/trl/experimental/self_distillation/self_distillation_config.py +++ b/trl/experimental/self_distillation/self_distillation_config.py @@ -54,17 +54,12 @@ class SelfDistillationConfig(_BaseConfig): > Parameters that control teacher construction - teacher_regularization (`str`, *optional*, defaults to `"none"`): - Teacher update strategy. Supported: `none`, `ema`. + teacher_model_kind (`str`, *optional*, defaults to `"live"`): + Semantic teacher choice. Supported: `base`, `live`, `ema`. teacher_update_rate (`float`, *optional*, defaults to `0.6`): - EMA update rate used when `teacher_regularization="ema"`. + EMA update rate used when `teacher_model_kind="ema"`. teacher_sync_steps (`int`, *optional*, defaults to `512`): Number of optimizer steps between EMA teacher updates. - peft_teacher_mode (`str`, *optional*, defaults to `"inherit_adapter"`): - How teacher forwards should behave for PEFT models. Supported: `auto`, `inherit_adapter`, - `disable_adapter`, `teacher_adapter`. - teacher_adapter_name (`str`, *optional*, defaults to `"teacher"`): - Adapter name used when `peft_teacher_mode="teacher_adapter"`. > Parameters that control self-distillation @@ -252,9 +247,9 @@ class SelfDistillationConfig(_BaseConfig): default=False, metadata={"help": "Whether to exclude truncated completions from the loss."}, ) - teacher_regularization: str = field( - default="none", - metadata={"help": "Teacher regularization mode. Supported: `none`, `ema`."}, + teacher_model_kind: str = field( + default="live", + metadata={"help": "Semantic teacher choice. Supported: `base`, `live`, `ema`."}, ) teacher_update_rate: float = field( default=0.6, @@ -264,17 +259,6 @@ class SelfDistillationConfig(_BaseConfig): default=512, metadata={"help": "How often to synchronize the teacher model."}, ) - peft_teacher_mode: str = field( - default="inherit_adapter", - metadata={ - "help": "Teacher execution mode for PEFT models. Supported: `auto`, `inherit_adapter`, " - "`disable_adapter`, `teacher_adapter`." 
- }, - ) - teacher_adapter_name: str = field( - default="teacher", - metadata={"help": "Adapter name used for PEFT teacher forwards when `peft_teacher_mode='teacher_adapter'`."}, - ) top_entropy_quantile: float = field( default=1.0, metadata={"help": "Reserved for entropy-based token filtering."}, @@ -327,20 +311,12 @@ def __post_init__(self): raise ValueError("importance_sampling_level must be either 'token' or 'sequence'") if self.loss_type not in ["grpo", "bnpo", "dr_grpo", "dapo"]: raise ValueError("loss_type must be one of: 'grpo', 'bnpo', 'dr_grpo', 'dapo'") - if self.teacher_regularization not in {"none", "ema"}: - raise ValueError("teacher_regularization must be one of: 'none', 'ema'") + if self.teacher_model_kind not in {"base", "live", "ema"}: + raise ValueError("teacher_model_kind must be one of: 'base', 'live', 'ema'") if not 0.0 <= self.teacher_update_rate <= 1.0: raise ValueError("teacher_update_rate must be in [0, 1]") if self.teacher_sync_steps <= 0: raise ValueError("teacher_sync_steps must be positive") - if self.peft_teacher_mode not in {"auto", "inherit_adapter", "disable_adapter", "teacher_adapter"}: - raise ValueError( - "peft_teacher_mode must be one of: 'auto', 'inherit_adapter', 'disable_adapter', 'teacher_adapter'" - ) - if self.peft_teacher_mode == "teacher_adapter" and self.teacher_regularization != "ema": - raise ValueError("peft_teacher_mode='teacher_adapter' requires teacher_regularization='ema'") - if self.teacher_adapter_name == "": - raise ValueError("teacher_adapter_name must be non-empty") if self.num_generations < 1: raise ValueError("num_generations must be at least 1") if not 0.0 <= self.distillation_alpha <= 1.0: From bad6b62706b5d4fa4a62974f8ceb2b95dd9f9a66 Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 16 Apr 2026 08:38:14 +0200 Subject: [PATCH 06/23] restructure training batch builder --- trl/experimental/sdft/sdft_trainer.py | 27 ++-- trl/experimental/sdpo/sdpo_trainer.py | 33 ++--- .../base_self_distillation_trainer.py | 129 ++++++++---------- 3 files changed, 84 insertions(+), 105 deletions(-) diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index 057433762d2..fe4a1b9bb09 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -31,8 +31,8 @@ from ..self_distillation.base_self_distillation_trainer import ( BaseSelfDistillationTrainer, - SelfDistillationBatch, - SelfDistillationRolloutBatch, + RolloutBatch, + TrainingBatch, ) from ..self_distillation.teacher_context import PromptTokenizer, extract_last_user_text from .sdft_config import SDFTConfig @@ -166,11 +166,11 @@ def __init__( self.num_loss_tokens_to_skip = args.num_loss_tokens_to_skip self.teacher_context_builder = DemonstrationTeacherContextBuilder(self) - def augment_training_batch( + def finalize_batch( self, inputs: list[dict[str, Any]], - rollout_batch: SelfDistillationRolloutBatch, - ) -> SelfDistillationBatch: + rollout_batch: RolloutBatch, + ) -> TrainingBatch: prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) teacher_batch = self.teacher_context_builder.build( prompts, @@ -179,15 +179,16 @@ def augment_training_batch( rollout_batch.completion_mask, ) - return SelfDistillationBatch( - prompt_ids=teacher_batch["prompt_ids"], - prompt_mask=teacher_batch["prompt_mask"], - completion_ids=rollout_batch.completion_ids, - completion_mask=rollout_batch.completion_mask, - teacher_input_ids=teacher_batch["teacher_input_ids"], - 
teacher_attention_mask=teacher_batch["teacher_attention_mask"], - old_per_token_logps=rollout_batch.old_per_token_logps, + batch = super().finalize_batch(inputs, rollout_batch) + batch.update( + { + "prompt_ids": teacher_batch["prompt_ids"], + "prompt_mask": teacher_batch["prompt_mask"], + "teacher_input_ids": teacher_batch["teacher_input_ids"], + "teacher_attention_mask": teacher_batch["teacher_attention_mask"], + } ) + return batch def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py index d554391f16e..a57d946aba7 100644 --- a/trl/experimental/sdpo/sdpo_trainer.py +++ b/trl/experimental/sdpo/sdpo_trainer.py @@ -36,8 +36,8 @@ from ...trainer.utils import get_config_model_id, pad from ..self_distillation.base_self_distillation_trainer import ( BaseSelfDistillationTrainer, - SelfDistillationBatch, - SelfDistillationRolloutBatch, + RolloutBatch, + TrainingBatch, ) from ..self_distillation.teacher_context import TokenizedPromptBatch, extract_last_user_text from .sdpo_config import SDPOConfig @@ -396,15 +396,15 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): return self.accelerator.gather(rewards_per_func) - def augment_training_batch( + def finalize_batch( self, inputs: list[dict[str, Any]], - rollout_batch: SelfDistillationRolloutBatch, - ) -> SelfDistillationBatch: + rollout_batch: RolloutBatch, + ) -> TrainingBatch: device = self.accelerator.device mode = "train" if self.model.training else "eval" prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) - raw_completion_lengths = rollout_batch.metadata["raw_completion_lengths"].detach().cpu().tolist() + raw_completion_lengths = rollout_batch.raw_completion_lengths.detach().cpu().tolist() completion_ids_list = [ ids[:length].tolist() for ids, length in zip(rollout_batch.completion_ids.detach().cpu(), raw_completion_lengths, strict=True) @@ -461,7 +461,7 @@ def augment_training_batch( self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - rollout_dict = rollout_batch.to_dict() + rollout_dict = rollout_batch.as_dict() rollout_dict["rewards"] = local_rewards rollout_dict["advantages"] = local_advantages rollout_dict["num_items_in_batch"] = rollout_batch.completion_mask.sum().detach() @@ -485,20 +485,17 @@ def augment_training_batch( self_distillation_mask=teacher_context["self_distillation_mask"], ) - return SelfDistillationBatch( - prompt_ids=rollout_batch.prompt_ids, - prompt_mask=rollout_batch.prompt_mask, - completion_ids=rollout_batch.completion_ids, - completion_mask=rollout_batch.completion_mask, - teacher_input_ids=teacher_context["teacher_input_ids"], - teacher_attention_mask=teacher_context["teacher_attention_mask"], - old_per_token_logps=rollout_batch.old_per_token_logps, - self_distillation_mask=teacher_context["self_distillation_mask"], - metadata={ + batch = super().finalize_batch(inputs, rollout_batch) + batch.update( + { + "teacher_input_ids": teacher_context["teacher_input_ids"], + "teacher_attention_mask": teacher_context["teacher_attention_mask"], + "self_distillation_mask": teacher_context["self_distillation_mask"], "rewards": local_rewards, "advantages": local_advantages, - }, + } ) + return batch def _warn_on_inactive_self_distillation(self, mode: str) -> 
None: metrics = self.teacher_context_builder.last_metrics diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index 2b3685339b5..9c9c0b627d9 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -19,7 +19,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from contextlib import nullcontext -from dataclasses import dataclass, field +from dataclasses import dataclass from functools import partial from typing import Any @@ -71,58 +71,31 @@ @dataclass -class SelfDistillationRolloutBatch: - """Common student rollout batch produced before algorithm-specific augmentation.""" +class RolloutBatch: + """Common student rollout batch produced before algorithm-specific finalization.""" prompt_ids: torch.Tensor prompt_mask: torch.Tensor completion_ids: torch.Tensor completion_mask: torch.Tensor old_per_token_logps: torch.Tensor | None = None - metadata: dict[str, Any] = field(default_factory=dict) + raw_completion_lengths: torch.Tensor | None = None - def to_dict(self) -> dict[str, torch.Tensor | Any]: - output: dict[str, torch.Tensor | Any] = { + def as_dict(self) -> dict[str, torch.Tensor | Any]: + batch: dict[str, torch.Tensor | Any] = { "prompt_ids": self.prompt_ids, "prompt_mask": self.prompt_mask, "completion_ids": self.completion_ids, "completion_mask": self.completion_mask, } if self.old_per_token_logps is not None: - output["old_per_token_logps"] = self.old_per_token_logps - output.update(self.metadata) - return output - + batch["old_per_token_logps"] = self.old_per_token_logps + if self.raw_completion_lengths is not None: + batch["raw_completion_lengths"] = self.raw_completion_lengths + return batch -@dataclass -class SelfDistillationBatch: - """Final self-distillation batch contract consumed by `SelfDistillationMixin`.""" - prompt_ids: torch.Tensor - prompt_mask: torch.Tensor - completion_ids: torch.Tensor - completion_mask: torch.Tensor - teacher_input_ids: torch.Tensor - teacher_attention_mask: torch.Tensor - old_per_token_logps: torch.Tensor | None = None - self_distillation_mask: torch.Tensor | None = None - metadata: dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> dict[str, torch.Tensor | Any]: - output: dict[str, torch.Tensor | Any] = { - "prompt_ids": self.prompt_ids, - "prompt_mask": self.prompt_mask, - "completion_ids": self.completion_ids, - "completion_mask": self.completion_mask, - "teacher_input_ids": self.teacher_input_ids, - "teacher_attention_mask": self.teacher_attention_mask, - } - if self.old_per_token_logps is not None: - output["old_per_token_logps"] = self.old_per_token_logps - if self.self_distillation_mask is not None: - output["self_distillation_mask"] = self.self_distillation_mask - output.update(self.metadata) - return output +TrainingBatch = dict[str, torch.Tensor | Any] class BaseSelfDistillationTrainer(SelfDistillationMixin, _BaseTrainer, ABC): @@ -392,7 +365,7 @@ def _prepare_inputs(self, generation_batch): if mode == "train": generate_every = self.args.steps_per_generation * self.num_iterations if self._step % generate_every == 0 or self._buffered_inputs is None: - buffered_batch = self._build_buffered_batch(generation_batch) + buffered_batch = self._prepare_training_batch(generation_batch) self._buffered_inputs = split_tensor_dict(buffered_batch, self.args.steps_per_generation) 
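+                # Fresh rollouts are sampled only every `steps_per_generation * num_iterations`
+                # steps; the steps in between replay slices of this cached buffer.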
self._dispatch_self_distillation_callback( "on_generation_batch_built", @@ -400,29 +373,26 @@ def _prepare_inputs(self, generation_batch): steps_per_generation=self.args.steps_per_generation, ) return self._buffered_inputs[self._step % self.args.steps_per_generation] - return self._build_buffered_batch(generation_batch) + return self._prepare_training_batch(generation_batch) - def _build_buffered_batch(self, inputs: list[dict[str, Any]]) -> dict[str, torch.Tensor | Any]: - return self.build_training_batch(inputs).to_dict() + def _prepare_training_batch(self, inputs: list[dict[str, Any]]) -> TrainingBatch: + rollout_batch = self.sample_rollouts(inputs) - def build_training_batch(self, inputs: list[dict[str, Any]]) -> SelfDistillationBatch: - rollout_batch = self.build_rollout_batch(inputs) - - batch = self.augment_training_batch(inputs, rollout_batch) + batch = self.finalize_batch(inputs, rollout_batch) self._validate_training_batch(batch) self._dispatch_self_distillation_callback( "on_self_distillation_batch_prepared", - old_per_token_logps=batch.old_per_token_logps, - prompt_ids=batch.prompt_ids, - completion_ids=batch.completion_ids, - teacher_input_ids=batch.teacher_input_ids, - teacher_attention_mask=batch.teacher_attention_mask, - self_distillation_mask=batch.self_distillation_mask, + old_per_token_logps=batch.get("old_per_token_logps"), + prompt_ids=batch["prompt_ids"], + completion_ids=batch["completion_ids"], + teacher_input_ids=batch["teacher_input_ids"], + teacher_attention_mask=batch["teacher_attention_mask"], + self_distillation_mask=batch.get("self_distillation_mask"), ) return batch - def build_rollout_batch(self, inputs: list[dict[str, Any]]) -> SelfDistillationRolloutBatch: + def sample_rollouts(self, inputs: list[dict[str, Any]]) -> RolloutBatch: prompts, _ = self._split_prompt_and_privileged_context(inputs) generation_prompts = prompts generation_prompt_text = self.prompt_tokenizer.apply_prompt_template(generation_prompts) @@ -446,24 +416,22 @@ def build_rollout_batch(self, inputs: list[dict[str, Any]]) -> SelfDistillationR eos_and_pad = [self.eos_token_id, self.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() - old_per_token_logps = self.compute_rollout_logps( + old_per_token_logps = self._compute_rollout_logps( prompt_ids=prompt_ids, prompt_mask=prompt_mask, completion_ids=completion_ids, completion_mask=completion_mask, ) - return SelfDistillationRolloutBatch( + return RolloutBatch( prompt_ids=prompt_ids, prompt_mask=prompt_mask, completion_ids=completion_ids, completion_mask=completion_mask, old_per_token_logps=old_per_token_logps, - metadata={ - "raw_completion_lengths": torch.tensor( - [len(ids) for ids in completion_ids_list], device=device, dtype=torch.long - ) - }, + raw_completion_lengths=torch.tensor( + [len(ids) for ids in completion_ids_list], device=device, dtype=torch.long + ), ) def _generate(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: @@ -533,7 +501,7 @@ def _generate_transformers(self, prompts: list[Any]) -> tuple[list[list[int]], l completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=False)] return prompt_ids_list, completion_ids_list - def compute_rollout_logps( + def _compute_rollout_logps( self, prompt_ids: torch.Tensor, prompt_mask: torch.Tensor, @@ -578,26 +546,39 @@ def _get_per_token_logps_and_entropies( entropies = 
entropy_from_logits(logits) if compute_entropy else None return selected_logps, entropies - def _validate_training_batch(self, batch: SelfDistillationBatch) -> None: - batch_size = batch.prompt_ids.size(0) - if batch.prompt_mask.size(0) != batch_size: + def _validate_training_batch(self, batch: TrainingBatch) -> None: + required_keys = { + "prompt_ids", + "prompt_mask", + "completion_ids", + "completion_mask", + "teacher_input_ids", + "teacher_attention_mask", + } + missing_keys = required_keys.difference(batch) + if missing_keys: + raise ValueError(f"`finalize_batch` must return all required batch keys. Missing: {sorted(missing_keys)}") + + batch_size = batch["prompt_ids"].size(0) + if batch["prompt_mask"].size(0) != batch_size: raise ValueError("`prompt_mask` must have the same batch size as `prompt_ids`.") - if batch.completion_ids.size(0) != batch_size or batch.completion_mask.size(0) != batch_size: + if batch["completion_ids"].size(0) != batch_size or batch["completion_mask"].size(0) != batch_size: raise ValueError("`completion_ids` and `completion_mask` must match the student batch size.") - if batch.teacher_input_ids.size(0) != batch_size or batch.teacher_attention_mask.size(0) != batch_size: + if batch["teacher_input_ids"].size(0) != batch_size or batch["teacher_attention_mask"].size(0) != batch_size: raise ValueError("`teacher_input_ids` and `teacher_attention_mask` must match the student batch size.") - if batch.teacher_input_ids.size(1) != batch.teacher_attention_mask.size(1): + if batch["teacher_input_ids"].size(1) != batch["teacher_attention_mask"].size(1): raise ValueError("`teacher_input_ids` and `teacher_attention_mask` must have the same sequence length.") - if batch.self_distillation_mask is not None and batch.self_distillation_mask.size(0) != batch_size: - raise ValueError("`self_distillation_mask` must match the batch size when provided.") + if "self_distillation_mask" in batch and batch["self_distillation_mask"] is not None: + if batch["self_distillation_mask"].size(0) != batch_size: + raise ValueError("`self_distillation_mask` must match the batch size when provided.") - @abstractmethod - def augment_training_batch( + def finalize_batch( self, inputs: list[dict[str, Any]], - rollout_batch: SelfDistillationRolloutBatch, - ) -> SelfDistillationBatch: - """Inject teacher-side inputs and algorithm-specific fields into a common student rollout batch.""" + rollout_batch: RolloutBatch, + ) -> TrainingBatch: + """Build the final training batch from a shared student rollout batch.""" + return rollout_batch.as_dict() def _get_teacher_context_for_self_distillation(self): teacher_model_kind = self.args.teacher_model_kind From ef43c9574c45bbc3c8b34c881f88684f8e51ff5a Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 16 Apr 2026 10:21:38 +0200 Subject: [PATCH 07/23] nits --- trl/experimental/sdft/sdft_config.py | 8 ++++++-- trl/experimental/sdpo/sdpo_config.py | 8 ++++++-- .../base_self_distillation_trainer.py | 19 +++++++++++-------- .../self_distillation_config.py | 8 ++++++-- 4 files changed, 29 insertions(+), 14 deletions(-) diff --git a/trl/experimental/sdft/sdft_config.py b/trl/experimental/sdft/sdft_config.py index 041c85a7dc0..08d4f64272a 100644 --- a/trl/experimental/sdft/sdft_config.py +++ b/trl/experimental/sdft/sdft_config.py @@ -29,7 +29,8 @@ class SDFTConfig(SelfDistillationConfig): disable_dropout (`bool`, *optional*, defaults to `True`): Whether to disable dropout in the student and teacher models. 
teacher_model_kind (`str`, *optional*, defaults to `"base"`): - Semantic teacher choice for SDFT. Supported: `base`, `live`, `ema`. + Semantic teacher choice for SDFT. `base` uses the initial student, `live` uses the current student, and + `ema` uses an exponentially averaged teacher. generate_from_teacher (`bool`, *optional*, defaults to `False`): Whether on-policy generation should use the teacher-conditioned prompt instead of the student prompt. teacher_prompt_template (`str`, *optional*, defaults to `"{prompt}\n\n{privileged_context}"`): @@ -44,7 +45,10 @@ class SDFTConfig(SelfDistillationConfig): ) teacher_model_kind: str = field( default="base", - metadata={"help": "Semantic teacher choice for SDFT. Supported: `base`, `live`, `ema`."}, + metadata={ + "help": "Semantic teacher choice for SDFT. `base` uses the initial student, `live` uses the current " + "student, and `ema` uses an exponentially averaged teacher." + }, ) generate_from_teacher: bool = field( default=False, diff --git a/trl/experimental/sdpo/sdpo_config.py b/trl/experimental/sdpo/sdpo_config.py index 87c11ea9046..e40ea69a90c 100644 --- a/trl/experimental/sdpo/sdpo_config.py +++ b/trl/experimental/sdpo/sdpo_config.py @@ -40,7 +40,8 @@ class SDPOConfig(SelfDistillationConfig): > Parameters that control the teacher teacher_model_kind (`str`, *optional*, defaults to `"ema"`): - Semantic teacher choice. Supported: `base`, `live`, `ema`. + Semantic teacher choice. `base` uses the initial student, `live` uses the current student, and `ema` + uses an exponentially averaged teacher. teacher_update_rate (`float`, *optional*, defaults to `0.05`): EMA update rate used when `teacher_model_kind="ema"`. teacher_sync_steps (`int`, *optional*, defaults to `1`): @@ -76,7 +77,10 @@ class SDPOConfig(SelfDistillationConfig): ) teacher_model_kind: str = field( default="ema", - metadata={"help": "Semantic teacher choice. Supported: `base`, `live`, `ema`."}, + metadata={ + "help": "Semantic teacher choice. `base` uses the initial student, `live` uses the current student, " + "and `ema` uses an exponentially averaged teacher." + }, ) teacher_update_rate: float = field( default=0.05, diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index 9c9c0b627d9..3bca6976c99 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -64,7 +64,6 @@ if is_peft_available(): from peft import PeftConfig - from peft.peft_model import PeftModel logger = get_logger(__name__) @@ -137,12 +136,12 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) - if is_peft_available() and is_peft_model(model) and peft_config is not None: + if is_peft_model(model) and peft_config is not None: raise ValueError( "You passed a `PeftModel` instance together with a `peft_config`. Pass either a base " "model with `peft_config`, or a pre-wrapped PEFT model." 
             )
-        if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None):
+        if peft_config is not None or getattr(model, "peft_config", None) is not None:
             model = prepare_peft_model(model, peft_config, args)
 
         if processing_class is None:
@@ -261,7 +260,10 @@ def _setup_teacher_model(self) -> None:
 
         teacher_model_kind = self.args.teacher_model_kind
 
-        if teacher_model_kind in {"base", "live"}:
+        if teacher_model_kind == "live":
+            return
+
+        if teacher_model_kind == "base" and is_peft_model(self.model):
             return
 
         if self._use_peft_ema_teacher_adapter():
@@ -276,11 +278,11 @@ def _setup_teacher_model(self) -> None:
             )
             return
 
+        # Create the teacher as a frozen copy of the current student.
         student_model = self.accelerator.unwrap_model(self.model)
         self.teacher_model = copy.deepcopy(student_model)
         self.teacher_model.requires_grad_(False)
         self.teacher_model.eval()
-
         if self.is_deepspeed_enabled:
             self.teacher_model = prepare_deepspeed(self.teacher_model, self.accelerator)
         elif self.is_fsdp_enabled:
@@ -288,13 +290,14 @@ def _setup_teacher_model(self) -> None:
         else:
             self.teacher_model = self.accelerator.prepare_model(self.teacher_model, evaluation_mode=True)
 
-        self.add_callback(SyncTeacherModelCallback(teacher_model=self.teacher_model, accelerator=self.accelerator))
+        if teacher_model_kind == "ema":
+            self.add_callback(SyncTeacherModelCallback(teacher_model=self.teacher_model, accelerator=self.accelerator))
 
     def _use_peft_ema_teacher_adapter(self) -> bool:
         return self.args.teacher_model_kind == "ema" and self._is_pure_lora_training()
 
     def _is_pure_lora_training(self) -> bool:
-        if not (is_peft_available() and is_peft_model(self.model)):
+        if not is_peft_model(self.model):
             return False
 
         model = self.accelerator.unwrap_model(self.model)
@@ -582,7 +585,7 @@ def finalize_batch(
 
     def _get_teacher_context_for_self_distillation(self):
         teacher_model_kind = self.args.teacher_model_kind
-        if not (is_peft_available() and isinstance(self.model, PeftModel)):
+        if not is_peft_model(self.model):
             return nullcontext()
 
         target_model = self.teacher_model if self.teacher_model is not None else self.model
diff --git a/trl/experimental/self_distillation/self_distillation_config.py b/trl/experimental/self_distillation/self_distillation_config.py
index a03773f9bc2..d4e66dbce88 100644
--- a/trl/experimental/self_distillation/self_distillation_config.py
+++ b/trl/experimental/self_distillation/self_distillation_config.py
@@ -55,7 +55,8 @@ class SelfDistillationConfig(_BaseConfig):
     > Parameters that control teacher construction
 
     teacher_model_kind (`str`, *optional*, defaults to `"live"`):
-        Semantic teacher choice. Supported: `base`, `live`, `ema`.
+        Semantic teacher choice. `live` uses the current student, `base` uses the student as it existed at the
+        start of training, and `ema` uses an exponentially averaged teacher.
     teacher_update_rate (`float`, *optional*, defaults to `0.6`):
        EMA update rate used when `teacher_model_kind="ema"`.
     teacher_sync_steps (`int`, *optional*, defaults to `512`):
@@ -249,7 +250,10 @@ class SelfDistillationConfig(_BaseConfig):
     )
     teacher_model_kind: str = field(
         default="live",
-        metadata={"help": "Semantic teacher choice. Supported: `base`, `live`, `ema`."},
+        metadata={
+            "help": "Semantic teacher choice. `live` uses the current student, `base` uses the initial student, "
+            "and `ema` uses an exponentially averaged teacher."
+ }, ) teacher_update_rate: float = field( default=0.6, From efe0eda4cb35cadfc60ef7a9ebeb645d6a84b8c0 Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 16 Apr 2026 11:25:32 +0200 Subject: [PATCH 08/23] wip removing mixin --- trl/experimental/sdft/sdft_trainer.py | 16 +- trl/experimental/sdpo/sdpo_trainer.py | 18 +- .../self_distillation/__init__.py | 4 +- .../base_self_distillation_trainer.py | 98 +++++- .../self_distillation/online_rollout_mixin.py | 2 +- .../self_distillation_loss.py | 318 ++++++++++++++++++ .../self_distillation_mixin.py | 270 --------------- .../self_distillation/teacher_context.py | 55 +-- 8 files changed, 423 insertions(+), 358 deletions(-) create mode 100644 trl/experimental/self_distillation/self_distillation_loss.py delete mode 100644 trl/experimental/self_distillation/self_distillation_mixin.py diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index fe4a1b9bb09..ee2fc712c29 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -34,7 +34,7 @@ RolloutBatch, TrainingBatch, ) -from ..self_distillation.teacher_context import PromptTokenizer, extract_last_user_text +from ..self_distillation.teacher_context import _split_prompt_and_privileged_context, extract_last_user_text from .sdft_config import SDFTConfig @@ -50,7 +50,6 @@ class DemonstrationTeacherContextBuilder: def __init__(self, trainer): self.trainer = trainer - self.prompt_tokenizer = PromptTokenizer(trainer) def _stringify_privileged_context(self, privileged_context: Any) -> str: if privileged_context is None: @@ -99,17 +98,14 @@ def build( completion_ids: torch.Tensor, completion_mask: torch.Tensor, ) -> dict[str, torch.Tensor]: - student_batch = self.prompt_tokenizer.tokenize_prompts(prompts) teacher_prompts = [ self._compose_teacher_prompt(prompt, privileged_context) for prompt, privileged_context in zip(prompts, privileged_contexts, strict=True) ] - teacher_batch = self.prompt_tokenizer.tokenize_prompts(teacher_prompts) - teacher_input_ids = torch.cat([teacher_batch.prompt_ids, completion_ids], dim=1) - teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, completion_mask], dim=1) + teacher_batch = self.trainer._tokenize_prompts(teacher_prompts) + teacher_input_ids = torch.cat([teacher_batch["prompt_ids"], completion_ids], dim=1) + teacher_attention_mask = torch.cat([teacher_batch["prompt_mask"], completion_mask], dim=1) return { - "prompt_ids": student_batch.prompt_ids, - "prompt_mask": student_batch.prompt_mask, "teacher_input_ids": teacher_input_ids, "teacher_attention_mask": teacher_attention_mask, } @@ -171,7 +167,7 @@ def finalize_batch( inputs: list[dict[str, Any]], rollout_batch: RolloutBatch, ) -> TrainingBatch: - prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) + prompts, privileged_contexts = _split_prompt_and_privileged_context(inputs) teacher_batch = self.teacher_context_builder.build( prompts, privileged_contexts, @@ -182,8 +178,6 @@ def finalize_batch( batch = super().finalize_batch(inputs, rollout_batch) batch.update( { - "prompt_ids": teacher_batch["prompt_ids"], - "prompt_mask": teacher_batch["prompt_mask"], "teacher_input_ids": teacher_batch["teacher_input_ids"], "teacher_attention_mask": teacher_batch["teacher_attention_mask"], } diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py index a57d946aba7..c5ebeb02baa 100644 --- a/trl/experimental/sdpo/sdpo_trainer.py +++ b/trl/experimental/sdpo/sdpo_trainer.py @@ -39,7 +39,7 @@ 
RolloutBatch, TrainingBatch, ) -from ..self_distillation.teacher_context import TokenizedPromptBatch, extract_last_user_text +from ..self_distillation.teacher_context import _split_prompt_and_privileged_context, extract_last_user_text from .sdpo_config import SDPOConfig @@ -62,7 +62,7 @@ def _build_reprompt_text(self, prompt_text: str, solution_text: str, feedback_te def _tokenize_teacher_messages( self, teacher_messages_list: list[str | list[dict[str, Any]]] - ) -> TokenizedPromptBatch: + ) -> dict[str, torch.Tensor]: teacher_prompt_ids_list = [] device = self.trainer.accelerator.device chat_template_kwargs = getattr(self.trainer, "chat_template_kwargs", {}) @@ -88,10 +88,10 @@ def _tokenize_teacher_messages( teacher_prompt_ids = [ids.to(device) for ids in teacher_prompt_ids_list] teacher_prompt_mask = [torch.ones(len(ids), dtype=torch.long, device=device) for ids in teacher_prompt_ids] - return TokenizedPromptBatch( - prompt_ids=pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), - prompt_mask=pad(teacher_prompt_mask, padding_value=0, padding_side="left"), - ) + return { + "prompt_ids": pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), + "prompt_mask": pad(teacher_prompt_mask, padding_value=0, padding_side="left"), + } def build( self, @@ -216,8 +216,8 @@ def build( local_teacher_messages.append(self._build_reprompt_text(original_prompt, solution_text, feedback_text)) teacher_batch = self._tokenize_teacher_messages(local_teacher_messages) - teacher_input_ids = torch.cat([teacher_batch.prompt_ids, completion_ids], dim=1) - teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, completion_mask], dim=1) + teacher_input_ids = torch.cat([teacher_batch["prompt_ids"], completion_ids], dim=1) + teacher_attention_mask = torch.cat([teacher_batch["prompt_mask"], completion_mask], dim=1) batch_size = total_samples if total_samples > 0 else 1 num_groups = max(1, total_samples // max(1, num_generations)) @@ -403,7 +403,7 @@ def finalize_batch( ) -> TrainingBatch: device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) + prompts, privileged_contexts = _split_prompt_and_privileged_context(inputs) raw_completion_lengths = rollout_batch.raw_completion_lengths.detach().cpu().tolist() completion_ids_list = [ ids[:length].tolist() diff --git a/trl/experimental/self_distillation/__init__.py b/trl/experimental/self_distillation/__init__.py index 1449db2f7a3..006587d6d05 100644 --- a/trl/experimental/self_distillation/__init__.py +++ b/trl/experimental/self_distillation/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. 
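+
+# `SelfDistillationLossComputer` supersedes the removed `SelfDistillationMixin`;
+# the shared distillation loss now lives in a standalone helper instead of a mixin.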
from .self_distillation_config import SelfDistillationConfig -from .self_distillation_mixin import SelfDistillationMixin +from .self_distillation_loss import SelfDistillationLossComputer -__all__ = ["SelfDistillationConfig", "SelfDistillationMixin"] +__all__ = ["SelfDistillationConfig", "SelfDistillationLossComputer"] diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index 3bca6976c99..47f6d5fe4f8 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -41,6 +41,7 @@ from transformers.trainer_utils import seed_worker from transformers.utils import is_datasets_available, is_peft_available +from ...data_utils import maybe_apply_chat_template from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation from ...trainer.base_trainer import _BaseTrainer from ...trainer.utils import ( @@ -57,8 +58,7 @@ ) from ..utils import prepare_peft_model from .self_distillation_config import SelfDistillationConfig -from .self_distillation_mixin import SelfDistillationMixin -from .teacher_context import PromptTokenizer +from .self_distillation_loss import SelfDistillationLossComputer from .teacher_sync import PEFTAdapterEMACallback, SyncTeacherModelCallback @@ -97,7 +97,7 @@ def as_dict(self) -> dict[str, torch.Tensor | Any]: TrainingBatch = dict[str, torch.Tensor | Any] -class BaseSelfDistillationTrainer(SelfDistillationMixin, _BaseTrainer, ABC): +class BaseSelfDistillationTrainer(_BaseTrainer, ABC): """Base that centralizes shared self-distillation trainer lifecycle.""" config_cls = SelfDistillationConfig @@ -179,7 +179,6 @@ def __init__( "train": defaultdict(int), "eval": defaultdict(int), } - self.prompt_tokenizer = PromptTokenizer(self) generation_kwargs = { "max_new_tokens": self.max_completion_length, @@ -253,8 +252,26 @@ def __init__( self.model.add_model_tags(self._tag_names) self._setup_teacher_model() + self._self_distillation_loss = SelfDistillationLossComputer(self) self.model_accepts_loss_kwargs = False + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + self._signature_columns = ["prompt", "privileged_context"] + + def _dispatch_self_distillation_callback(self, event_name: str, **payload) -> None: + for callback in self.callback_handler.callbacks: + callback_fn = getattr(callback, event_name, None) + if callback_fn is not None: + callback_fn( + args=self.args, + state=self.state, + control=self.control, + model=self.model, + processing_class=self.processing_class, + **payload, + ) + def _setup_teacher_model(self) -> None: """Prepare teacher state according to the semantic teacher choice.""" @@ -395,17 +412,55 @@ def _prepare_training_batch(self, inputs: list[dict[str, Any]]) -> TrainingBatch ) return batch + def _apply_chat_template_to_prompts(self, prompts: list[Any]) -> list[str]: + return [ + maybe_apply_chat_template( + {"prompt": prompt}, + self.processing_class, + **self.chat_template_kwargs, + )["prompt"] + for prompt in prompts + ] + + def _tokenize_prompt_text(self, prompt_text: list[str]) -> dict[str, torch.Tensor]: + prompt_inputs = self.processing_class( + text=prompt_text, + return_tensors="pt", + padding=True, + padding_side="left", + max_length=self.max_prompt_length, + truncation=True, + add_special_tokens=False, + ) + prompt_inputs = _BaseTrainer._prepare_inputs(self, prompt_inputs) + prompt_ids = [ + 
input_ids[mask].to(device=self.accelerator.device) + for input_ids, mask in zip( + prompt_inputs["input_ids"], + prompt_inputs["attention_mask"].bool(), + strict=False, + ) + ] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + return { + "prompt_ids": pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left"), + "prompt_mask": pad(prompt_mask, padding_value=0, padding_side="left"), + } + + def _tokenize_prompts(self, prompts: list[Any]) -> dict[str, torch.Tensor]: + return self._tokenize_prompt_text(self._apply_chat_template_to_prompts(prompts)) + def sample_rollouts(self, inputs: list[dict[str, Any]]) -> RolloutBatch: prompts, _ = self._split_prompt_and_privileged_context(inputs) generation_prompts = prompts - generation_prompt_text = self.prompt_tokenizer.apply_prompt_template(generation_prompts) + generation_prompt_text = self._apply_chat_template_to_prompts(generation_prompts) self._dispatch_self_distillation_callback( "on_generation_prompts_selected", generation_prompts=generation_prompts, generation_prompt_text=generation_prompt_text, ) - prompt_ids_list, completion_ids_list = self._generate(generation_prompts) + prompt_ids_list, completion_ids_list = self._generate(generation_prompt_text) device = self.accelerator.device prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] @@ -437,19 +492,18 @@ def sample_rollouts(self, inputs: list[dict[str, Any]]) -> RolloutBatch: ), ) - def _generate(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: + def _generate(self, prompt_text: list[str]) -> tuple[list[list[int]], list[list[int]]]: if self.use_vllm: - return self._generate_vllm(prompts) - return self._generate_transformers(prompts) + return self._generate_vllm(prompt_text) + return self._generate_transformers(prompt_text) - def _generate_vllm(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: + def _generate_vllm(self, prompt_text: list[str]) -> tuple[list[list[int]], list[list[int]]]: if self.state.global_step != self._last_loaded_step: self.vllm_generation.sync_weights() self._last_loaded_step = self.state.global_step - prompts_text = self.prompt_tokenizer.apply_prompt_template(prompts) tokenized = self.processing_class( - text=prompts_text, + text=prompt_text, return_tensors=None, padding=False, max_length=self.max_prompt_length, @@ -466,9 +520,9 @@ def _generate_vllm(self, prompts: list[Any]) -> tuple[list[list[int]], list[list ) return prompt_ids_out, completion_ids_list - def _generate_transformers(self, prompts: list[Any]) -> tuple[list[list[int]], list[list[int]]]: + def _generate_transformers(self, prompt_text: list[str]) -> tuple[list[list[int]], list[list[int]]]: generate_inputs = self.processing_class( - text=self.prompt_tokenizer.apply_prompt_template(prompts), + text=prompt_text, return_tensors="pt", padding=True, padding_side="left", @@ -529,6 +583,16 @@ def _compute_rollout_logps( return old_per_token_logps + def _allow_topk_without_full_logit_distillation(self) -> bool: + return True + + def _compute_self_distillation_loss( + self, + model, + inputs: TrainingBatch, + ) -> torch.Tensor: + return self._self_distillation_loss.compute_loss(model, inputs) + def _get_per_token_logps_and_entropies( self, model, @@ -602,6 +666,12 @@ def _get_teacher_context_for_self_distillation(self): return use_adapter(target_model, adapter_name=teacher_adapter_name) return nullcontext() + def 
_get_teacher_model_for_self_distillation(self, model): + teacher_model = getattr(self, "teacher_model", None) + if teacher_model is None: + return model + return teacher_model + @abstractmethod def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """Subclasses own algorithm-specific loss composition on the final batch contract.""" diff --git a/trl/experimental/self_distillation/online_rollout_mixin.py b/trl/experimental/self_distillation/online_rollout_mixin.py index 490724582dc..4054e1233fb 100644 --- a/trl/experimental/self_distillation/online_rollout_mixin.py +++ b/trl/experimental/self_distillation/online_rollout_mixin.py @@ -16,7 +16,7 @@ This mixin owns generation, reward scoring, grouped reward normalization, and online policy-loss plumbing. It is paired with `BaseSelfDistillationTrainer` for SDPO-style methods and intentionally kept separate from the generic distillation -loss logic in `self_distillation_mixin.py`. +loss logic in `self_distillation_loss.py`. """ from __future__ import annotations diff --git a/trl/experimental/self_distillation/self_distillation_loss.py b/trl/experimental/self_distillation/self_distillation_loss.py new file mode 100644 index 00000000000..faeef1d4367 --- /dev/null +++ b/trl/experimental/self_distillation/self_distillation_loss.py @@ -0,0 +1,318 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared self-distillation loss computation. + +This module intentionally contains only the reusable distillation loss mechanics. Trainer lifecycle concerns, +callback dispatch, and batch construction live in the trainer classes. +""" + +from __future__ import annotations + +from typing import Any, Protocol + +import torch +import torch.nn.functional as F + + +class SelfDistillationRuntime(Protocol): + """Minimal trainer surface required by `SelfDistillationLossComputer`.""" + + args: Any + accelerator: Any + model_kwarg_keys: Any + temperature: float + loss_type: str + max_completion_length: int + _metrics: dict[str, Any] + _name: str + + def _allow_topk_without_full_logit_distillation(self) -> bool: ... + + def _get_teacher_model_for_self_distillation(self, model): ... + + def _get_teacher_context_for_self_distillation(self): ... 
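+
+# Illustrative usage sketch (names and shapes below are assumptions, not API):
+# any object satisfying the protocol above can drive the computer, which keeps
+# the loss mechanics testable without a full trainer, e.g.
+#
+#     computer = SelfDistillationLossComputer(trainer)
+#     loss = computer.compute_loss(
+#         model,
+#         {
+#             "prompt_ids": prompt_ids,            # (B, P), left-padded
+#             "prompt_mask": prompt_mask,          # (B, P)
+#             "completion_ids": completion_ids,    # (B, C), shared rollout
+#             "completion_mask": completion_mask,  # (B, C)
+#             "teacher_input_ids": teacher_ids,    # (B, P_teacher + C)
+#             "teacher_attention_mask": teacher_mask,
+#             # optional: "self_distillation_mask" (B,), "old_per_token_logps" (B, C)
+#         },
+#     )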
+ + +class SelfDistillationLossComputer: + """Computes the shared student-vs-teacher self-distillation loss.""" + + def __init__(self, runtime: SelfDistillationRuntime): + self.runtime = runtime + + def compute_loss(self, model, inputs: dict[str, Any]) -> torch.Tensor: + prompt_ids = inputs["prompt_ids"] + prompt_mask = inputs["prompt_mask"] + completion_ids = inputs["completion_ids"] + completion_mask = inputs["completion_mask"] + logits_to_keep = completion_ids.size(1) + + response_mask = self._build_response_mask(completion_mask, inputs.get("self_distillation_mask")) + if response_mask.sum() == 0: + mode = "train" if model.training else "eval" + self._log_distillation_metric(mode, 0.0) + return torch.tensor(0.0, device=completion_ids.device, requires_grad=True) + + student_logits = self._compute_student_logits( + model=model, + prompt_ids=prompt_ids, + prompt_mask=prompt_mask, + completion_ids=completion_ids, + completion_mask=completion_mask, + logits_to_keep=logits_to_keep, + ) + teacher_logits = self._compute_teacher_logits( + model=model, + teacher_input_ids=inputs["teacher_input_ids"], + teacher_attention_mask=inputs["teacher_attention_mask"], + logits_to_keep=logits_to_keep, + ) + + per_token_loss = self._compute_per_token_loss( + student_logits=student_logits, + teacher_logits=teacher_logits, + completion_ids=completion_ids, + ) + + old_log_probs = inputs.get("old_per_token_logps") + if self.runtime.args.distillation_is_clip is not None and old_log_probs is not None: + student_per_token_logps = self._select_token_log_probs(student_logits, completion_ids) + per_token_loss = self._apply_importance_sampling_clipping( + per_token_loss, + student_per_token_logps, + old_log_probs, + self.runtime.args.distillation_is_clip, + ) + + loss = self._aggregate_loss(per_token_loss, response_mask) + + mode = "train" if model.training else "eval" + mean_distill_loss = (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + self._log_distillation_metric(mode, self.runtime.accelerator.gather(mean_distill_loss).mean().item()) + return loss + + @staticmethod + def _build_response_mask( + completion_mask: torch.Tensor, + self_distillation_mask: torch.Tensor | None, + ) -> torch.Tensor: + if self_distillation_mask is None: + return completion_mask + return completion_mask * self_distillation_mask.unsqueeze(1) + + def _compute_student_logits( + self, + model, + prompt_ids: torch.Tensor, + prompt_mask: torch.Tensor, + completion_ids: torch.Tensor, + completion_mask: torch.Tensor, + logits_to_keep: int, + ) -> torch.Tensor: + student_input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + student_attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + return self._forward_logits( + model=model, + input_ids=student_input_ids, + attention_mask=student_attention_mask, + logits_to_keep=logits_to_keep, + ) + + def _compute_teacher_logits( + self, + model, + teacher_input_ids: torch.Tensor, + teacher_attention_mask: torch.Tensor, + logits_to_keep: int, + ) -> torch.Tensor: + teacher_model = self.runtime._get_teacher_model_for_self_distillation(model) + with torch.no_grad(), self.runtime._get_teacher_context_for_self_distillation(): + return self._forward_logits( + model=teacher_model, + input_ids=teacher_input_ids, + attention_mask=teacher_attention_mask, + logits_to_keep=logits_to_keep, + ) + + def _forward_logits( + self, + model, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + logits_to_keep: int, + ) -> torch.Tensor: + model_inputs = { + "input_ids": 
input_ids, + "attention_mask": attention_mask, + "use_cache": False, + } + if "logits_to_keep" in self.runtime.model_kwarg_keys: + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + logits = model(**model_inputs).logits + logits = logits[:, :-1, :] + logits = logits[:, -logits_to_keep:, :] + return logits / self.runtime.temperature + + def _compute_per_token_loss( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + completion_ids: torch.Tensor, + ) -> torch.Tensor: + args = self.runtime.args + use_topk_distillation = args.distillation_topk is not None and ( + args.full_logit_distillation or self.runtime._allow_topk_without_full_logit_distillation() + ) + + if use_topk_distillation: + return self._compute_topk_distillation_loss(student_logits, teacher_logits, args.distillation_topk) + if args.full_logit_distillation: + return self._compute_full_logit_distillation_loss(student_logits, teacher_logits) + return self._compute_sampled_token_distillation_loss(student_logits, teacher_logits, completion_ids) + + def _compute_topk_distillation_loss( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + topk: int, + ) -> torch.Tensor: + args = self.runtime.args + student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) + topk_student_logits, topk_indices = torch.topk(student_logits, k=topk, dim=-1) + topk_student_log_probs = topk_student_logits - student_logsumexp + + teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True) + topk_teacher_logits = torch.gather(teacher_logits, dim=-1, index=topk_indices) + topk_teacher_log_probs = topk_teacher_logits - teacher_logsumexp + + if args.distillation_add_tail: + topk_student_log_probs = self._add_tail(topk_student_log_probs) + topk_teacher_log_probs = self._add_tail(topk_teacher_log_probs) + else: + topk_student_log_probs = self._renorm_topk_log_probs(topk_student_log_probs) + topk_teacher_log_probs = self._renorm_topk_log_probs(topk_teacher_log_probs) + + return self._compute_divergence(topk_student_log_probs, topk_teacher_log_probs, args.distillation_alpha) + + def _compute_full_logit_distillation_loss( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + ) -> torch.Tensor: + args = self.runtime.args + student_log_probs = F.log_softmax(student_logits, dim=-1) + teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) + return self._compute_divergence(student_log_probs, teacher_log_probs, args.distillation_alpha) + + def _compute_sampled_token_distillation_loss( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + completion_ids: torch.Tensor, + ) -> torch.Tensor: + if self.runtime.args.distillation_alpha != 1.0: + raise ValueError( + "Only reverse KL (alpha=1.0) is supported for token-level distillation when " + f"`full_logit_distillation=False`, got alpha={self.runtime.args.distillation_alpha}" + ) + + student_per_token_logps = self._select_token_log_probs(student_logits, completion_ids) + teacher_per_token_logps = self._select_token_log_probs(teacher_logits, completion_ids) + return self._compute_token_level_distillation_loss(student_per_token_logps, teacher_per_token_logps) + + @staticmethod + def _select_token_log_probs( + logits: torch.Tensor, + token_ids: torch.Tensor, + ) -> torch.Tensor: + logsumexp = torch.logsumexp(logits, dim=-1, keepdim=True) + indices = token_ids.unsqueeze(-1) + return (torch.gather(logits, dim=-1, index=indices) - logsumexp).squeeze(-1) + + def _aggregate_loss( + self, + per_token_loss: torch.Tensor, + 
response_mask: torch.Tensor, + ) -> torch.Tensor: + loss_type = self.runtime.loss_type + if loss_type == "grpo": + loss = (per_token_loss * response_mask).sum(-1) / response_mask.sum(-1).clamp(min=1.0) + return loss.mean() + if loss_type == "bnpo": + return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + if loss_type == "dr_grpo": + return (per_token_loss * response_mask).sum() / ( + per_token_loss.size(0) * self.runtime.max_completion_length + ) + if loss_type in ["dapo", "luspo", "cispo", "sapo"]: + return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + raise ValueError(f"Unsupported loss_type for self-distillation: {loss_type}") + + def _log_distillation_metric(self, mode: str, value: float) -> None: + metric_prefix = getattr(self.runtime, "_name", "self_distillation").lower().replace(" ", "_") + self.runtime._metrics[mode]["self_distillation/distillation_loss"].append(value) + self.runtime._metrics[mode][f"{metric_prefix}/distillation_loss"].append(value) + + @staticmethod + def _compute_divergence( + student_log_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, + alpha: float, + ) -> torch.Tensor: + if alpha == 0.0: + kl = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) + elif alpha == 1.0: + kl = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) + else: + alpha_t = torch.tensor(alpha, dtype=student_log_probs.dtype, device=student_log_probs.device) + mixture = torch.logsumexp( + torch.stack([student_log_probs + torch.log(1 - alpha_t), teacher_log_probs + torch.log(alpha_t)]), + dim=0, + ) + kl_teacher = F.kl_div(mixture, teacher_log_probs, reduction="none", log_target=True) + kl_student = F.kl_div(mixture, student_log_probs, reduction="none", log_target=True) + kl = torch.lerp(kl_student, kl_teacher, alpha) + return kl.sum(-1) + + @staticmethod + def _add_tail(log_probs: torch.Tensor) -> torch.Tensor: + log_s = torch.logsumexp(log_probs, dim=-1, keepdim=True) + log_s = torch.clamp(log_s, max=-1e-7) + tail_log = torch.log(-torch.expm1(log_s)) + return torch.cat([log_probs, tail_log], dim=-1) + + @staticmethod + def _renorm_topk_log_probs(log_probs: torch.Tensor) -> torch.Tensor: + return log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True) + + @staticmethod + def _compute_token_level_distillation_loss( + student_log_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, + ) -> torch.Tensor: + log_ratio = student_log_probs - teacher_log_probs + return log_ratio.detach() * student_log_probs + + @staticmethod + def _apply_importance_sampling_clipping( + per_token_loss: torch.Tensor, + student_log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + clip_coeff: float, + ) -> torch.Tensor: + negative_approx_kl = (student_log_probs - old_log_probs).detach() + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl).clamp(max=clip_coeff) + return per_token_loss * ratio diff --git a/trl/experimental/self_distillation/self_distillation_mixin.py b/trl/experimental/self_distillation/self_distillation_mixin.py deleted file mode 100644 index f85dcd236f9..00000000000 --- a/trl/experimental/self_distillation/self_distillation_mixin.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Shared self-distillation loss utilities used by experimental trainers. - -This module intentionally holds only the reusable distillation mechanics: callback dispatch, common prompt/context -helpers, and the student-vs-teacher loss computation. Trainer lifecycle and online rollout concerns live in the trainer -classes or their online-specific base. -""" - -from __future__ import annotations - -from typing import Any - -import torch -import torch.nn.functional as F - -from .self_distillation_config import SelfDistillationConfig - - -class SelfDistillationMixin: - """Reusable self-distillation helpers shared across experimental trainers.""" - - config_cls = SelfDistillationConfig - - def _set_signature_columns_if_needed(self): - if self._signature_columns is None: - self._signature_columns = ["prompt", "privileged_context"] - - def _dispatch_self_distillation_callback(self, event_name: str, **payload) -> None: - for callback in self.callback_handler.callbacks: - callback_fn = getattr(callback, event_name, None) - if callback_fn is not None: - callback_fn( - args=self.args, - state=self.state, - control=self.control, - model=self.model, - processing_class=self.processing_class, - **payload, - ) - - @staticmethod - def _split_prompt_and_privileged_context(inputs: list[dict[str, Any]]) -> tuple[list[Any], list[Any]]: - prompts = [example["prompt"] for example in inputs] - privileged_contexts = [example.get("privileged_context") for example in inputs] - return prompts, privileged_contexts - - def _allow_topk_without_full_logit_distillation(self) -> bool: - return True - - def _compute_self_distillation_loss( - self, - model, - inputs: dict[str, Any], - ) -> torch.Tensor: - # Expected batch contract: - # - required: `prompt_ids`, `prompt_mask`, `completion_ids`, `completion_mask`, - # `teacher_input_ids`, `teacher_attention_mask` - # - optional: `self_distillation_mask` to zero-out samples without teacher supervision, - # `old_per_token_logps` to enable IS clipping when generation and optimization are misaligned - prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] - logits_to_keep = completion_ids.size(1) - - self_distillation_mask = inputs.get("self_distillation_mask") - if self_distillation_mask is not None: - response_mask = completion_mask * self_distillation_mask.unsqueeze(1) - else: - response_mask = completion_mask - - if response_mask.sum() == 0: - mode = "train" if model.training else "eval" - self._log_self_distillation_metric(mode, "distillation_loss", 0.0) - return torch.tensor(0.0, device=completion_ids.device, requires_grad=True) - - student_input_ids = torch.cat([prompt_ids, completion_ids], dim=1) - student_attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - student_model_inputs = { - "input_ids": student_input_ids, - "attention_mask": student_attention_mask, - "use_cache": False, - } - if "logits_to_keep" in self.model_kwarg_keys: - student_model_inputs["logits_to_keep"] = logits_to_keep + 1 - - student_logits = 
model(**student_model_inputs).logits - student_logits = student_logits[:, :-1, :] - student_logits = student_logits[:, -logits_to_keep:, :] - student_logits = student_logits / self.temperature - - teacher_input_ids = inputs["teacher_input_ids"] - teacher_attention_mask = inputs["teacher_attention_mask"] - teacher_model_inputs = { - "input_ids": teacher_input_ids, - "attention_mask": teacher_attention_mask, - "use_cache": False, - } - if "logits_to_keep" in self.model_kwarg_keys: - teacher_model_inputs["logits_to_keep"] = logits_to_keep + 1 - - teacher_model = self._get_teacher_model_for_self_distillation(model) - with torch.no_grad(), self._get_teacher_context_for_self_distillation(model): - teacher_logits = teacher_model(**teacher_model_inputs).logits - teacher_logits = teacher_logits[:, :-1, :] - teacher_logits = teacher_logits[:, -logits_to_keep:, :] - teacher_logits = teacher_logits / self.temperature - - use_topk_distillation = self.args.distillation_topk is not None and ( - self.args.full_logit_distillation or self._allow_topk_without_full_logit_distillation() - ) - if use_topk_distillation: - student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) - topk_student_logits, topk_indices = torch.topk(student_logits, k=self.args.distillation_topk, dim=-1) - topk_student_log_probs = topk_student_logits - student_logsumexp - - teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True) - topk_teacher_logits = torch.gather(teacher_logits, dim=-1, index=topk_indices) - topk_teacher_log_probs = topk_teacher_logits - teacher_logsumexp - - if self.args.distillation_add_tail: - topk_student_log_probs = self._add_tail(topk_student_log_probs) - topk_teacher_log_probs = self._add_tail(topk_teacher_log_probs) - else: - topk_student_log_probs = self._renorm_topk_log_probs(topk_student_log_probs) - topk_teacher_log_probs = self._renorm_topk_log_probs(topk_teacher_log_probs) - - per_token_loss = self._compute_divergence( - topk_student_log_probs, topk_teacher_log_probs, self.args.distillation_alpha - ) - elif self.args.full_logit_distillation: - student_log_probs = F.log_softmax(student_logits, dim=-1) - teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) - per_token_loss = self._compute_divergence( - student_log_probs, teacher_log_probs, self.args.distillation_alpha - ) - else: - if self.args.distillation_alpha != 1.0: - raise ValueError( - "Only reverse KL (alpha=1.0) is supported for token-level distillation when " - "`full_logit_distillation=False`, " - f"got alpha={self.args.distillation_alpha}" - ) - student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) - teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True) - idx = completion_ids.unsqueeze(-1) - student_per_token_logps = (torch.gather(student_logits, dim=-1, index=idx) - student_logsumexp).squeeze(-1) - teacher_per_token_logps = (torch.gather(teacher_logits, dim=-1, index=idx) - teacher_logsumexp).squeeze(-1) - per_token_loss = self._compute_token_level_distillation_loss( - student_per_token_logps, teacher_per_token_logps - ) - - if self.args.distillation_is_clip is not None: - old_log_probs = inputs.get("old_per_token_logps") - if old_log_probs is not None: - with torch.no_grad(): - student_lse = torch.logsumexp(student_logits, dim=-1, keepdim=True) - idx = completion_ids.unsqueeze(-1) - student_per_token_logps = (torch.gather(student_logits, dim=-1, index=idx) - student_lse).squeeze( - -1 - ) - per_token_loss = self._apply_importance_sampling_clipping( - 
per_token_loss, student_per_token_logps, old_log_probs, self.args.distillation_is_clip - ) - - loss = self._aggregate_self_distillation_loss(per_token_loss, response_mask) - - mode = "train" if model.training else "eval" - mean_distill_loss = (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - self._log_self_distillation_metric( - mode, - "distillation_loss", - self.accelerator.gather(mean_distill_loss).mean().item(), - ) - - return loss - - def _get_teacher_model_for_self_distillation(self, model): - teacher_model = getattr(self, "teacher_model", None) - if teacher_model is None: - return model - return teacher_model - - def _log_self_distillation_metric(self, mode: str, metric_name: str, value: float) -> None: - metric_prefix = getattr(self, "_name", "self_distillation").lower().replace(" ", "_") - self._metrics[mode][f"self_distillation/{metric_name}"].append(value) - self._metrics[mode][f"{metric_prefix}/{metric_name}"].append(value) - - @staticmethod - def _compute_divergence( - student_log_probs: torch.Tensor, - teacher_log_probs: torch.Tensor, - alpha: float, - ) -> torch.Tensor: - if alpha == 0.0: - kl = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) - elif alpha == 1.0: - kl = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) - else: - alpha_t = torch.tensor(alpha, dtype=student_log_probs.dtype, device=student_log_probs.device) - mixture = torch.logsumexp( - torch.stack([student_log_probs + torch.log(1 - alpha_t), teacher_log_probs + torch.log(alpha_t)]), - dim=0, - ) - kl_teacher = F.kl_div(mixture, teacher_log_probs, reduction="none", log_target=True) - kl_student = F.kl_div(mixture, student_log_probs, reduction="none", log_target=True) - kl = torch.lerp(kl_student, kl_teacher, alpha) - return kl.sum(-1) - - @staticmethod - def _add_tail(log_probs: torch.Tensor) -> torch.Tensor: - log_s = torch.logsumexp(log_probs, dim=-1, keepdim=True) - log_s = torch.clamp(log_s, max=-1e-7) - tail_log = torch.log(-torch.expm1(log_s)) - return torch.cat([log_probs, tail_log], dim=-1) - - @staticmethod - def _renorm_topk_log_probs(log_probs: torch.Tensor) -> torch.Tensor: - return log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True) - - @staticmethod - def _compute_token_level_distillation_loss( - student_log_probs: torch.Tensor, - teacher_log_probs: torch.Tensor, - ) -> torch.Tensor: - # This is the token-level reverse-KL surrogate used by the official SDPO implementation for - # `full_logit_distillation=False`. It intentionally treats the teacher log-probs as fixed targets - # and keeps only the score-function term for the sampled student tokens. 
- log_ratio = student_log_probs - teacher_log_probs - return log_ratio.detach() * student_log_probs - - @staticmethod - def _apply_importance_sampling_clipping( - per_token_loss: torch.Tensor, - student_log_probs: torch.Tensor, - old_log_probs: torch.Tensor, - clip_coeff: float, - ) -> torch.Tensor: - negative_approx_kl = (student_log_probs - old_log_probs).detach() - negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) - ratio = torch.exp(negative_approx_kl).clamp(max=clip_coeff) - return per_token_loss * ratio - - def _aggregate_self_distillation_loss( - self, - per_token_loss: torch.Tensor, - response_mask: torch.Tensor, - ) -> torch.Tensor: - loss_type = self.loss_type - if loss_type == "grpo": - loss = (per_token_loss * response_mask).sum(-1) / response_mask.sum(-1).clamp(min=1.0) - return loss.mean() - if loss_type == "bnpo": - return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - if loss_type == "dr_grpo": - return (per_token_loss * response_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) - if loss_type in ["dapo", "luspo", "cispo", "sapo"]: - return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - raise ValueError(f"Unsupported loss_type for self-distillation: {loss_type}") diff --git a/trl/experimental/self_distillation/teacher_context.py b/trl/experimental/self_distillation/teacher_context.py index 5e1020c91a7..70e28a92adc 100644 --- a/trl/experimental/self_distillation/teacher_context.py +++ b/trl/experimental/self_distillation/teacher_context.py @@ -14,14 +14,13 @@ from __future__ import annotations -from dataclasses import dataclass from typing import Any -import torch -from ...data_utils import maybe_apply_chat_template -from ...trainer.base_trainer import _BaseTrainer -from ...trainer.utils import pad +def _split_prompt_and_privileged_context(inputs: list[dict[str, Any]]) -> tuple[list[Any], list[Any]]: + prompts = [example["prompt"] for example in inputs] + privileged_contexts = [example.get("privileged_context") for example in inputs] + return prompts, privileged_contexts def extract_last_user_text(prompt: list[dict[str, Any]]) -> str: @@ -37,49 +36,3 @@ def extract_last_user_text(prompt: list[dict[str, Any]]) -> str: if isinstance(content, list): return " ".join(part.get("text", "") for part in content if part.get("type") == "text") return content - - -@dataclass -class TokenizedPromptBatch: - prompt_ids: torch.Tensor - prompt_mask: torch.Tensor - - -class PromptTokenizer: - """Internal helper to tokenize prompt-like inputs consistently across self-distillation trainers.""" - - def __init__(self, trainer): - self.trainer = trainer - - def apply_prompt_template(self, prompts: list[Any]) -> list[str]: - return [ - maybe_apply_chat_template( - {"prompt": prompt}, - self.trainer.processing_class, - **getattr(self.trainer, "chat_template_kwargs", {}), - )["prompt"] - for prompt in prompts - ] - - def tokenize_prompts(self, prompts: list[Any]) -> TokenizedPromptBatch: - prompt_text = self.apply_prompt_template(prompts) - prompt_inputs = self.trainer.processing_class( - text=prompt_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.trainer.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - prompt_inputs = super(_BaseTrainer, self.trainer)._prepare_inputs(prompt_inputs) - prompt_ids = [ - p[m].tolist() - for p, m in zip(prompt_inputs["input_ids"], prompt_inputs["attention_mask"].bool(), strict=False) - ] - prompt_ids = 
[torch.tensor(ids, device=self.trainer.accelerator.device) for ids in prompt_ids] - prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - return TokenizedPromptBatch( - prompt_ids=pad(prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), - prompt_mask=pad(prompt_mask, padding_value=0, padding_side="left"), - ) From fa1a8f31ccab3d49dc86c932d8929de44e343270 Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 16 Apr 2026 14:17:33 +0200 Subject: [PATCH 09/23] remove mixin, refactoring and cleanup --- trl/experimental/sdft/sdft_trainer.py | 7 +- trl/experimental/sdpo/sdpo_config.py | 2 + trl/experimental/sdpo/sdpo_trainer.py | 79 ++-- .../self_distillation/__init__.py | 3 +- .../base_self_distillation_trainer.py | 215 +++++++++- .../self_distillation/loss_utils.py | 136 +++++++ .../self_distillation/online_rollout_mixin.py | 385 ------------------ .../{teacher_context.py => prompt_utils.py} | 10 +- .../self_distillation_loss.py | 318 --------------- 9 files changed, 392 insertions(+), 763 deletions(-) create mode 100644 trl/experimental/self_distillation/loss_utils.py delete mode 100644 trl/experimental/self_distillation/online_rollout_mixin.py rename trl/experimental/self_distillation/{teacher_context.py => prompt_utils.py} (77%) delete mode 100644 trl/experimental/self_distillation/self_distillation_loss.py diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index ee2fc712c29..be1c7e0b9e8 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -34,7 +34,7 @@ RolloutBatch, TrainingBatch, ) -from ..self_distillation.teacher_context import _split_prompt_and_privileged_context, extract_last_user_text +from ..self_distillation.prompt_utils import extract_last_user_text from .sdft_config import SDFTConfig @@ -167,7 +167,7 @@ def finalize_batch( inputs: list[dict[str, Any]], rollout_batch: RolloutBatch, ) -> TrainingBatch: - prompts, privileged_contexts = _split_prompt_and_privileged_context(inputs) + prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) teacher_batch = self.teacher_context_builder.build( prompts, privileged_contexts, @@ -195,6 +195,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N completion_mask = completion_mask * (token_positions >= self.num_loss_tokens_to_skip).long() inputs["completion_mask"] = completion_mask - loss = self._compute_self_distillation_loss(model, inputs) + distillation_logits = self._compute_teacher_student_logits(model, inputs) + loss = self._compute_self_distillation_loss(model, inputs, distillation_logits) accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 return loss / accumulation_scale diff --git a/trl/experimental/sdpo/sdpo_config.py b/trl/experimental/sdpo/sdpo_config.py index e40ea69a90c..031f1e73ce8 100644 --- a/trl/experimental/sdpo/sdpo_config.py +++ b/trl/experimental/sdpo/sdpo_config.py @@ -133,6 +133,8 @@ def __post_init__(self): raise ValueError("sdpo_policy_loss_mode must be one of: 'distillation_only', 'hybrid'") if self.sdpo_policy_loss_mode == "distillation_only" and self.distillation_weight <= 0: raise ValueError("distillation_only mode requires `distillation_weight > 0`.") + if self.sdpo_policy_loss_mode == "hybrid" and self.distillation_weight <= 0: + raise ValueError("hybrid mode requires `distillation_weight > 0`.") if self.max_reprompt_len <= 0: raise ValueError("max_reprompt_len must be positive") if not 
self.full_logit_distillation and self.distillation_alpha != 1.0: diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py index c5ebeb02baa..0482048fbc7 100644 --- a/trl/experimental/sdpo/sdpo_trainer.py +++ b/trl/experimental/sdpo/sdpo_trainer.py @@ -36,10 +36,12 @@ from ...trainer.utils import get_config_model_id, pad from ..self_distillation.base_self_distillation_trainer import ( BaseSelfDistillationTrainer, + DistillationLogits, RolloutBatch, TrainingBatch, ) -from ..self_distillation.teacher_context import _split_prompt_and_privileged_context, extract_last_user_text +from ..self_distillation.loss_utils import select_token_log_probs +from ..self_distillation.prompt_utils import extract_last_user_text from .sdpo_config import SDPOConfig @@ -403,7 +405,7 @@ def finalize_batch( ) -> TrainingBatch: device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompts, privileged_contexts = _split_prompt_and_privileged_context(inputs) + prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) raw_completion_lengths = rollout_batch.raw_completion_lengths.detach().cpu().tolist() completion_ids_list = [ ids[:length].tolist() @@ -593,19 +595,14 @@ def _warn_on_degenerate_diagnostics(self, mode: str, counter_key: str, message: if count == 1 or count % interval == 0: logger.warning("%s Consecutive degenerate steps: %s.", message, count) - def _compute_policy_loss(self, model, inputs) -> torch.Tensor: - prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] - input_ids = torch.cat([prompt_ids, completion_ids], dim=1) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) - per_token_logps, _ = self._get_per_token_logps_and_entropies( - model, - input_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ) + def _compute_policy_loss( + self, + inputs, + student_logits, + ) -> torch.Tensor: + completion_ids = inputs["completion_ids"] + completion_mask = inputs["completion_mask"] + per_token_logps = select_token_log_probs(student_logits, completion_ids) old_per_token_logps = inputs.get("old_per_token_logps") old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps advantages = inputs["advantages"] @@ -630,24 +627,50 @@ def _compute_policy_loss(self, model, inputs) -> torch.Tensor: accumulation_scale = self.current_gradient_accumulation_steps if mode == "train" else 1.0 return loss / accumulation_scale - def _compute_weighted_self_distillation_loss(self, model, inputs) -> torch.Tensor | None: - if self.args.distillation_weight <= 0.0: - return None - + def _compute_weighted_self_distillation_loss( + self, + model, + inputs, + distillation_logits: DistillationLogits, + ) -> torch.Tensor: accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 - distillation_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale + distillation_loss = ( + self._compute_self_distillation_loss( + model, + inputs, + distillation_logits, + ) + / accumulation_scale + ) return self.args.distillation_weight * distillation_loss + def _compute_hybrid_loss(self, model, inputs) -> torch.Tensor: + distillation_logits = self._compute_teacher_student_logits(model, inputs) + policy_loss = self._compute_policy_loss(inputs, distillation_logits.student_logits) + 
weighted_distillation_loss = self._compute_weighted_self_distillation_loss( + model, + inputs, + distillation_logits, + ) + return policy_loss + weighted_distillation_loss + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: raise ValueError("The SDPOTrainer does not support returning outputs") if self.args.sdpo_policy_loss_mode == "hybrid": - policy_loss = self._compute_policy_loss(model, inputs) - weighted_distillation_loss = self._compute_weighted_self_distillation_loss(model, inputs) - return policy_loss if weighted_distillation_loss is None else policy_loss + weighted_distillation_loss - - weighted_distillation_loss = self._compute_weighted_self_distillation_loss(model, inputs) - if weighted_distillation_loss is not None: - return weighted_distillation_loss - return self._compute_policy_loss(model, inputs) + return self._compute_hybrid_loss(model, inputs) + + if self.args.distillation_weight > 0.0: + distillation_logits = self._compute_teacher_student_logits(model, inputs) + return self._compute_weighted_self_distillation_loss(model, inputs, distillation_logits) + else: + student_logits = self._compute_student_distillation_logits( + model=model, + prompt_ids=inputs["prompt_ids"], + prompt_mask=inputs["prompt_mask"], + completion_ids=inputs["completion_ids"], + completion_mask=inputs["completion_mask"], + logits_to_keep=inputs["completion_ids"].size(1), + ) + return self._compute_policy_loss(inputs, student_logits) diff --git a/trl/experimental/self_distillation/__init__.py b/trl/experimental/self_distillation/__init__.py index 006587d6d05..f5d94b41edb 100644 --- a/trl/experimental/self_distillation/__init__.py +++ b/trl/experimental/self_distillation/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. 
 from .self_distillation_config import SelfDistillationConfig
-from .self_distillation_loss import SelfDistillationLossComputer
 
-__all__ = ["SelfDistillationConfig", "SelfDistillationLossComputer"]
+__all__ = ["SelfDistillationConfig"]
diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py
index 47f6d5fe4f8..7d643878c4d 100644
--- a/trl/experimental/self_distillation/base_self_distillation_trainer.py
+++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py
@@ -48,17 +48,21 @@
     RepeatSampler,
     create_model_from_path,
     disable_dropout_in_model,
-    entropy_from_logits,
     get_config_model_id,
     identity,
     pad,
-    selective_log_softmax,
     split_tensor_dict,
     use_adapter,
 )
 from ..utils import prepare_peft_model
+from .loss_utils import (
+    apply_importance_sampling_clipping,
+    compute_full_logit_self_distillation_loss,
+    compute_sampled_token_self_distillation_loss,
+    compute_topk_self_distillation_loss,
+    select_token_log_probs,
+)
 from .self_distillation_config import SelfDistillationConfig
-from .self_distillation_loss import SelfDistillationLossComputer
 from .teacher_sync import PEFTAdapterEMACallback, SyncTeacherModelCallback
@@ -97,6 +101,17 @@ def as_dict(self) -> dict[str, torch.Tensor | Any]:
 TrainingBatch = dict[str, torch.Tensor | Any]
 
 
+@dataclass
+class DistillationLogits:
+    """Aligned logits and masks used to compute a self-distillation objective."""
+
+    completion_ids: torch.Tensor
+    completion_mask: torch.Tensor
+    response_mask: torch.Tensor
+    student_logits: torch.Tensor
+    teacher_logits: torch.Tensor
+
+
 class BaseSelfDistillationTrainer(_BaseTrainer, ABC):
     """Base that centralizes shared self-distillation trainer lifecycle."""
 
@@ -252,7 +267,6 @@ def __init__(
             self.model.add_model_tags(self._tag_names)
 
         self._setup_teacher_model()
-        self._self_distillation_loss = SelfDistillationLossComputer(self)
         self.model_accepts_loss_kwargs = False
 
     def _set_signature_columns_if_needed(self):
@@ -492,6 +506,12 @@ def sample_rollouts(self, inputs: list[dict[str, Any]]) -> RolloutBatch:
             ),
         )
 
+    @staticmethod
+    def _split_prompt_and_privileged_context(inputs: list[dict[str, Any]]) -> tuple[list[Any], list[Any]]:
+        prompts = [example["prompt"] for example in inputs]
+        privileged_contexts = [example.get("privileged_context") for example in inputs]
+        return prompts, privileged_contexts
+
     def _generate(self, prompt_text: list[str]) -> tuple[list[list[int]], list[list[int]]]:
         if self.use_vllm:
             return self._generate_vllm(prompt_text)
@@ -573,13 +593,13 @@ def _compute_rollout_logps(
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
         logits_to_keep = completion_ids.size(1)
         with torch.no_grad():
-            old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
+            logits = self._forward_logits(
                 self.model,
                 prompt_completion_ids,
                 attention_mask,
                 logits_to_keep,
-                compute_entropy=False,
             )
+            old_per_token_logps = select_token_log_probs(logits, completion_ids)
 
         return old_per_token_logps
 
@@ -590,28 +610,164 @@ def _compute_self_distillation_loss(
         self,
         model,
         inputs: TrainingBatch,
+        distillation_logits: DistillationLogits,
     ) -> torch.Tensor:
-        return self._self_distillation_loss.compute_loss(model, inputs)
+        if distillation_logits.response_mask.sum() == 0:
+            mode = "train" if model.training else "eval"
+            self._log_self_distillation_metric(mode, 0.0)
+            return torch.tensor(0.0, device=distillation_logits.completion_ids.device, requires_grad=True)
+
+        use_topk_distillation = self.args.distillation_topk is not 
None and ( + self.args.full_logit_distillation or self._allow_topk_without_full_logit_distillation() + ) + if use_topk_distillation: + per_token_loss = compute_topk_self_distillation_loss( + distillation_logits.student_logits, + distillation_logits.teacher_logits, + distillation_topk=self.args.distillation_topk, + distillation_alpha=self.args.distillation_alpha, + distillation_add_tail=self.args.distillation_add_tail, + ) + elif self.args.full_logit_distillation: + per_token_loss = compute_full_logit_self_distillation_loss( + distillation_logits.student_logits, + distillation_logits.teacher_logits, + distillation_alpha=self.args.distillation_alpha, + ) + else: + per_token_loss = compute_sampled_token_self_distillation_loss( + distillation_logits.student_logits, + distillation_logits.teacher_logits, + distillation_logits.completion_ids, + distillation_alpha=self.args.distillation_alpha, + ) + + old_per_token_logps = inputs.get("old_per_token_logps") + if self.args.distillation_is_clip is not None and old_per_token_logps is not None: + student_per_token_logps = select_token_log_probs( + distillation_logits.student_logits, + distillation_logits.completion_ids, + ) + per_token_loss = apply_importance_sampling_clipping( + per_token_loss, + student_per_token_logps, + old_per_token_logps, + self.args.distillation_is_clip, + ) + + loss = self._aggregate_self_distillation_loss(per_token_loss, distillation_logits.response_mask) + + mode = "train" if model.training else "eval" + mean_distill_loss = ( + per_token_loss * distillation_logits.response_mask + ).sum() / distillation_logits.response_mask.sum().clamp(min=1.0) + self._log_self_distillation_metric( + mode, + self.accelerator.gather(mean_distill_loss).mean().item(), + ) + return loss - def _get_per_token_logps_and_entropies( + def _compute_teacher_student_logits( self, model, - input_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ): - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} + inputs: TrainingBatch, + ) -> DistillationLogits: + prompt_ids = inputs["prompt_ids"] + prompt_mask = inputs["prompt_mask"] + completion_ids = inputs["completion_ids"] + completion_mask = inputs["completion_mask"] + logits_to_keep = completion_ids.size(1) + + response_mask = self._build_self_distillation_response_mask( + completion_mask, + inputs.get("self_distillation_mask"), + ) + student_logits = self._compute_student_distillation_logits( + model=model, + prompt_ids=prompt_ids, + prompt_mask=prompt_mask, + completion_ids=completion_ids, + completion_mask=completion_mask, + logits_to_keep=logits_to_keep, + ) + + teacher_logits = self._compute_teacher_distillation_logits( + model=model, + teacher_input_ids=inputs["teacher_input_ids"], + teacher_attention_mask=inputs["teacher_attention_mask"], + logits_to_keep=logits_to_keep, + ) + + return DistillationLogits( + completion_ids=completion_ids, + completion_mask=completion_mask, + response_mask=response_mask, + student_logits=student_logits, + teacher_logits=teacher_logits, + ) + + @staticmethod + def _build_self_distillation_response_mask( + completion_mask: torch.Tensor, + self_distillation_mask: torch.Tensor | None, + ) -> torch.Tensor: + if self_distillation_mask is None: + return completion_mask + return completion_mask * self_distillation_mask.unsqueeze(1) + + def _compute_student_distillation_logits( + self, + model, + prompt_ids: torch.Tensor, + prompt_mask: torch.Tensor, + completion_ids: torch.Tensor, + completion_mask: torch.Tensor, + 
logits_to_keep: int, + ) -> torch.Tensor: + student_input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + student_attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + return self._forward_logits( + model=model, + input_ids=student_input_ids, + attention_mask=student_attention_mask, + logits_to_keep=logits_to_keep, + ) + + def _compute_teacher_distillation_logits( + self, + model, + teacher_input_ids: torch.Tensor, + teacher_attention_mask: torch.Tensor, + logits_to_keep: int, + ) -> torch.Tensor: + teacher_model = self._get_teacher_model_for_self_distillation(model) + with torch.no_grad(), self._get_teacher_context_for_self_distillation(): + return self._forward_logits( + model=teacher_model, + input_ids=teacher_input_ids, + attention_mask=teacher_attention_mask, + logits_to_keep=logits_to_keep, + ) + + def _forward_logits( + self, + model, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + logits_to_keep: int, + ) -> torch.Tensor: + model_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "use_cache": False, + } if "logits_to_keep" in self.model_kwarg_keys: model_inputs["logits_to_keep"] = logits_to_keep + 1 + logits = model(**model_inputs).logits logits = logits[:, :-1, :] logits = logits[:, -logits_to_keep:, :] - logits = logits / self.temperature - completion_ids = input_ids[:, -logits_to_keep:] - selected_logps = selective_log_softmax(logits, completion_ids) - entropies = entropy_from_logits(logits) if compute_entropy else None - return selected_logps, entropies + return logits / self.temperature def _validate_training_batch(self, batch: TrainingBatch) -> None: required_keys = { @@ -672,6 +827,28 @@ def _get_teacher_model_for_self_distillation(self, model): return model return teacher_model + def _log_self_distillation_metric(self, mode: str, value: float) -> None: + metric_prefix = getattr(self, "_name", "self_distillation").lower().replace(" ", "_") + self._metrics[mode]["self_distillation/distillation_loss"].append(value) + self._metrics[mode][f"{metric_prefix}/distillation_loss"].append(value) + + def _aggregate_self_distillation_loss( + self, + per_token_loss: torch.Tensor, + response_mask: torch.Tensor, + ) -> torch.Tensor: + loss_type = self.loss_type + if loss_type == "grpo": + loss = (per_token_loss * response_mask).sum(-1) / response_mask.sum(-1).clamp(min=1.0) + return loss.mean() + if loss_type == "bnpo": + return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + if loss_type == "dr_grpo": + return (per_token_loss * response_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + if loss_type in ["dapo", "luspo", "cispo", "sapo"]: + return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + raise ValueError(f"Unsupported loss_type for self-distillation: {loss_type}") + @abstractmethod def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """Subclasses own algorithm-specific loss composition on the final batch contract.""" diff --git a/trl/experimental/self_distillation/loss_utils.py b/trl/experimental/self_distillation/loss_utils.py new file mode 100644 index 00000000000..ebdf9689063 --- /dev/null +++ b/trl/experimental/self_distillation/loss_utils.py @@ -0,0 +1,136 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pure helper functions for self-distillation loss computation.""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F + + +def select_token_log_probs( + logits: torch.Tensor, + token_ids: torch.Tensor, +) -> torch.Tensor: + logsumexp = torch.logsumexp(logits, dim=-1, keepdim=True) + indices = token_ids.unsqueeze(-1) + return (torch.gather(logits, dim=-1, index=indices) - logsumexp).squeeze(-1) + + +def compute_divergence( + student_log_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, + alpha: float, +) -> torch.Tensor: + if alpha == 0.0: + kl = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) + elif alpha == 1.0: + kl = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) + else: + alpha_t = torch.tensor(alpha, dtype=student_log_probs.dtype, device=student_log_probs.device) + mixture = torch.logsumexp( + torch.stack([student_log_probs + torch.log(1 - alpha_t), teacher_log_probs + torch.log(alpha_t)]), + dim=0, + ) + kl_teacher = F.kl_div(mixture, teacher_log_probs, reduction="none", log_target=True) + kl_student = F.kl_div(mixture, student_log_probs, reduction="none", log_target=True) + kl = torch.lerp(kl_student, kl_teacher, alpha) + return kl.sum(-1) + + +def add_tail(log_probs: torch.Tensor) -> torch.Tensor: + log_s = torch.logsumexp(log_probs, dim=-1, keepdim=True) + log_s = torch.clamp(log_s, max=-1e-7) + tail_log = torch.log(-torch.expm1(log_s)) + return torch.cat([log_probs, tail_log], dim=-1) + + +def renorm_topk_log_probs(log_probs: torch.Tensor) -> torch.Tensor: + return log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True) + + +def compute_token_level_distillation_loss( + student_log_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, +) -> torch.Tensor: + log_ratio = student_log_probs - teacher_log_probs + return log_ratio.detach() * student_log_probs + + +def apply_importance_sampling_clipping( + per_token_loss: torch.Tensor, + student_log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + clip_coeff: float, +) -> torch.Tensor: + negative_approx_kl = (student_log_probs - old_log_probs).detach() + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl).clamp(max=clip_coeff) + return per_token_loss * ratio + + +def compute_topk_self_distillation_loss( + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + *, + distillation_topk: int, + distillation_alpha: float, + distillation_add_tail: bool, +) -> torch.Tensor: + student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) + topk_student_logits, topk_indices = torch.topk(student_logits, k=distillation_topk, dim=-1) + topk_student_log_probs = topk_student_logits - student_logsumexp + + teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True) + topk_teacher_logits = torch.gather(teacher_logits, dim=-1, index=topk_indices) + topk_teacher_log_probs = topk_teacher_logits - teacher_logsumexp + + if distillation_add_tail: + topk_student_log_probs = add_tail(topk_student_log_probs) + topk_teacher_log_probs = 
add_tail(topk_teacher_log_probs) + else: + topk_student_log_probs = renorm_topk_log_probs(topk_student_log_probs) + topk_teacher_log_probs = renorm_topk_log_probs(topk_teacher_log_probs) + + return compute_divergence(topk_student_log_probs, topk_teacher_log_probs, distillation_alpha) + + +def compute_full_logit_self_distillation_loss( + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + *, + distillation_alpha: float, +) -> torch.Tensor: + student_log_probs = torch.log_softmax(student_logits, dim=-1) + teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1) + return compute_divergence(student_log_probs, teacher_log_probs, distillation_alpha) + + +def compute_sampled_token_self_distillation_loss( + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + completion_ids: torch.Tensor, + *, + distillation_alpha: float, +) -> torch.Tensor: + if distillation_alpha != 1.0: + raise ValueError( + "Only reverse KL (alpha=1.0) is supported for token-level distillation when " + f"`full_logit_distillation=False`, got alpha={distillation_alpha}" + ) + + student_per_token_logps = select_token_log_probs(student_logits, completion_ids) + teacher_per_token_logps = select_token_log_probs(teacher_logits, completion_ids) + return compute_token_level_distillation_loss(student_per_token_logps, teacher_per_token_logps) diff --git a/trl/experimental/self_distillation/online_rollout_mixin.py b/trl/experimental/self_distillation/online_rollout_mixin.py deleted file mode 100644 index 4054e1233fb..00000000000 --- a/trl/experimental/self_distillation/online_rollout_mixin.py +++ /dev/null @@ -1,385 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Online rollout helpers for experimental self-distillation trainers. - -This mixin owns generation, reward scoring, grouped reward normalization, and online policy-loss plumbing. It is paired -with `BaseSelfDistillationTrainer` for SDPO-style methods and intentionally kept separate from the generic distillation -loss logic in `self_distillation_loss.py`. 
-""" - -from __future__ import annotations - -import torch -from torch import nn -from transformers.utils import logging - -from ...data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template -from ...models import unwrap_model_for_generation -from ...trainer.base_trainer import _BaseTrainer -from ...trainer.utils import pad - - -logger = logging.get_logger(__name__) - - -class OnlineRolloutMixin: - """Online rollout, reward, and policy-loss utilities shared by SDPO-like trainers.""" - - def _apply_prompt_template(self, prompts): - return [ - maybe_apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"] - for prompt in prompts - ] - - def _build_buffered_batch(self, generation_batch): - return self._generate_and_score_completions(generation_batch) - - def _generate(self, prompts): - if self.use_vllm: - return self._generate_vllm(prompts) - return self._generate_transformers(prompts) - - def _generate_vllm(self, prompts): - # Sync weights if training step changed - if self.state.global_step != self._last_loaded_step: - self.vllm_generation.sync_weights() - self._last_loaded_step = self.state.global_step - - # Tokenize prompts to token IDs - prompts_text = self._apply_prompt_template(prompts) - tokenized = self.processing_class( - text=prompts_text, - return_tensors=None, - padding=False, - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - prompt_ids = tokenized["input_ids"] # list of list[int] - - # Generate via vLLM — it deduplicates repeated prompts from RepeatSampler internally - mode = "train" if self.model.training else "eval" - num_generations = self.num_generations if mode == "train" else self.num_generations_eval - prompt_ids_out, completion_ids_list, _, _ = self.vllm_generation.generate( - prompts=prompt_ids, - images=None, - num_generations=num_generations, - ) - return prompt_ids_out, completion_ids_list - - def _generate_transformers(self, prompts): - # Keep the generation path aligned with the reference trainers: generate from left-padded prompts, - # then recover completion token spans by trimming prompt tokens and stopping at the first EOS. - prompts_text = self._apply_prompt_template(prompts) - generate_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - # This path already receives tokenized model inputs. Bypass the buffered trainer hook and use the plain - # tensor/device preparation from `_BaseTrainer`. 
- generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) - with ( - unwrap_model_for_generation( - self.model_wrapped, - self.accelerator, - gather_deepspeed3_params=self.args.ds3_gather_for_generation, - ) as unwrapped_model, - torch.no_grad(), - ): - prompt_completion_ids = unwrapped_model.generate( - **generate_inputs, generation_config=self.generation_config - ) - prompt_ids = generate_inputs["input_ids"] - prompt_mask = generate_inputs["attention_mask"] - prompt_length = prompt_ids.size(1) - completion_ids = prompt_completion_ids[:, prompt_length:] - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) - completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).int() - prompt_ids_list = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=False)] - completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=False)] - return prompt_ids_list, completion_ids_list - - def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): - device = self.accelerator.device - if len(self.reward_funcs) == 0: - return torch.zeros((len(prompts), 0), device=device) - - rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) - keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] - reward_kwargs = {key: [example[key] for example in inputs] for key in keys} - reward_kwargs["trainer_state"] = self.state - - for i, (reward_func, reward_processing_class) in enumerate( - zip(self.reward_funcs, self.reward_processing_classes, strict=True) - ): - if isinstance(reward_func, nn.Module): - if is_conversational(inputs[0]): - messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)] - texts = [ - apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"] - for x in messages - ] - else: - texts = [p + c for p, c in zip(prompts, completions, strict=True)] - reward_inputs = reward_processing_class( - text=texts, - return_tensors="pt", - padding=True, - padding_side="right", - add_special_tokens=False, - ) - # Reward functions operate on tokenized tensors too, so they need the base Trainer input preparation - # rather than the outer buffered generation hook. 
- reward_inputs = _BaseTrainer._prepare_inputs(self, reward_inputs) - with torch.inference_mode(): - rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] - else: - output_reward_func = reward_func( - prompts=prompts, - completions=completions, - completion_ids=completion_ids_list, - **reward_kwargs, - ) - output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] - rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) - - return self.accelerator.gather(rewards_per_func) - - def _generate_and_score_completions(self, inputs): - device = self.accelerator.device - mode = "train" if self.model.training else "eval" - prompts = [x["prompt"] for x in inputs] - prompt_ids_list, completion_ids_list = self._generate(prompts) - - prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] - prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left").to(device=device) - prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left").to(device=device) - completion_ids = [torch.tensor(ids) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right").to(device=device) - completion_mask = pad(completion_mask, padding_value=0, padding_side="right").to(device=device) - - if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) - completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() - - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) - - with torch.no_grad(): - generate_every = self.args.steps_per_generation * self.num_iterations - if self.args.gradient_accumulation_steps % generate_every != 0: - old_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ) - else: - old_per_token_logps = None - - if is_conversational({"prompt": prompts[0]}): - completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - completions = [[{"role": "assistant", "content": content}] for content in completions_text] - else: - completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - - rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) - if rewards_per_func.numel() == 0: - rewards = torch.zeros(self.accelerator.num_processes * len(prompts), device=device) - else: - rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) - num_generations = self.num_generations if mode == "train" else self.num_generations_eval - mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0) - if self.scale_rewards == "batch": - std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) - group_std_rewards = rewards.view(-1, num_generations).std(dim=1) - elif self.scale_rewards == "none": - std_rewards = torch.ones_like(rewards) - group_std_rewards = torch.ones(rewards.numel() // 
num_generations, device=device, dtype=rewards.dtype) - else: - group_std_rewards = rewards.view(-1, num_generations).std(dim=1) - std_rewards = group_std_rewards.repeat_interleave(num_generations, dim=0) - advantages = (rewards - mean_grouped_rewards) / (std_rewards + 1e-4) - self._record_reward_diagnostics(mode, rewards, rewards_per_func, group_std_rewards) - - local_batch_size = completion_ids.size(0) - process_start = self.accelerator.process_index * local_batch_size - process_slice = slice(process_start, process_start + local_batch_size) - rewards = rewards[process_slice] - advantages = advantages[process_slice] - - agg_completion_lengths = self.accelerator.gather( - torch.tensor([len(ids) for ids in completion_ids_list], device=device) - ) - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) - agg_is_truncated = self.accelerator.gather(is_truncated) - self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) - term_completion_lengths = agg_completion_lengths[~agg_is_truncated] - if len(term_completion_lengths) == 0: - term_completion_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - - output = { - "prompt_ids": prompt_ids, - "prompt_mask": prompt_mask, - "completion_ids": completion_ids, - "completion_mask": completion_mask, - "rewards": rewards, - "advantages": advantages, - "num_items_in_batch": completion_mask.sum().detach(), - } - if old_per_token_logps is not None: - output["old_per_token_logps"] = old_per_token_logps - - self._dispatch_self_distillation_callback( - "on_self_distillation_batch_prepared", - old_per_token_logps=old_per_token_logps, - prompt_ids=prompt_ids, - completion_ids=completion_ids, - ) - return output - - def _record_reward_diagnostics( - self, - mode: str, - rewards: torch.Tensor, - rewards_per_func: torch.Tensor, - group_std_rewards: torch.Tensor, - ) -> None: - tolerance = self.args.diagnostics_flat_tolerance - - reward_mean = rewards.mean() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) - reward_std = rewards.std() if rewards.numel() > 1 else torch.tensor(0.0, device=self.accelerator.device) - reward_min = rewards.min() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) - reward_max = rewards.max() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) - flat_group_fraction = ( - (group_std_rewards <= tolerance).float().mean() - if group_std_rewards.numel() > 0 - else torch.tensor(1.0, device=self.accelerator.device) - ) - - self._metrics[mode]["self_distillation/reward_mean"].append(self.accelerator.gather(reward_mean).mean().item()) - self._metrics[mode]["self_distillation/reward_std"].append(self.accelerator.gather(reward_std).mean().item()) - 
self._metrics[mode]["self_distillation/reward_min"].append(self.accelerator.gather(reward_min).min().item()) - self._metrics[mode]["self_distillation/reward_max"].append(self.accelerator.gather(reward_max).max().item()) - self._metrics[mode]["self_distillation/group_reward_std_mean"].append( - self.accelerator.gather(group_std_rewards.mean() if group_std_rewards.numel() > 0 else reward_std) - .mean() - .item() - ) - self._metrics[mode]["self_distillation/flat_group_fraction"].append( - self.accelerator.gather(flat_group_fraction).mean().item() - ) - - if rewards_per_func.numel() > 0: - reward_func_means = rewards_per_func.nanmean(dim=0) - gathered_means = self.accelerator.gather(reward_func_means).view(-1, reward_func_means.numel()).mean(dim=0) - for reward_name, reward_func_mean in zip(self.reward_func_names, gathered_means.tolist(), strict=True): - self._metrics[mode][f"self_distillation/rewards/{reward_name}"].append(reward_func_mean) - - reward_is_flat = reward_std.item() <= tolerance - grouped_rewards_are_flat = flat_group_fraction.item() >= 1.0 - tolerance - if reward_is_flat and grouped_rewards_are_flat: - self._warn_on_degenerate_diagnostics( - mode=mode, - counter_key="flat_rewards", - message=( - "Observed flat SDPO rewards across all sampled generations. " - "Policy advantages will collapse to zero, and SDPO will not learn. " - "Check reward density, reward shaping, or `success_reward_threshold`." - ), - ) - else: - self._diagnostic_counters[mode]["flat_rewards"] = 0 - - def _warn_on_degenerate_diagnostics(self, mode: str, counter_key: str, message: str) -> None: - interval = self.args.diagnostics_warning_interval - if interval == 0: - return - - self._diagnostic_counters[mode][counter_key] += 1 - count = self._diagnostic_counters[mode][counter_key] - if count == 1 or count % interval == 0: - logger.warning("%s Consecutive degenerate steps: %s.", message, count) - - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - if return_outputs: - raise ValueError(f"The {self.__class__.__name__} does not support returning outputs") - return self._compute_loss(model, inputs) - - def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): - if not isinstance(inputs, dict): - inputs = self._prepare_inputs(inputs) - with torch.no_grad(): - with self.compute_loss_context_manager(): - loss = self.compute_loss(model, inputs) - return loss.detach(), None, None - - def _compute_loss(self, model, inputs): - prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] - input_ids = torch.cat([prompt_ids, completion_ids], dim=1) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) - per_token_logps, _ = self._get_per_token_logps_and_entropies( - model, - input_ids, - attention_mask, - logits_to_keep, - compute_entropy=False, - ) - old_per_token_logps = inputs.get("old_per_token_logps") - old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps - advantages = inputs["advantages"] - if advantages.dim() == 1: - advantages = advantages.unsqueeze(1) - log_ratio = per_token_logps - old_per_token_logps - if self.importance_sampling_level == "sequence": - log_ratio = (log_ratio * completion_mask).sum(-1, keepdim=True) / completion_mask.sum( - -1, keepdim=True - ).clamp(min=1.0) - coef_1 = torch.exp(log_ratio) - coef_2 = torch.clamp(coef_1, 1 - 
self.epsilon_low, 1 + self.epsilon_high) - per_token_loss = -torch.min(coef_1 * advantages, coef_2 * advantages) - - loss = self._aggregate_self_distillation_loss(per_token_loss, completion_mask) - - mode = "train" if self.model.training else "eval" - self._metrics[mode]["self_distillation/policy_loss"].append( - self.accelerator.gather(loss.detach()).mean().item() - ) - - accumulation_scale = self.current_gradient_accumulation_steps if mode == "train" else 1.0 - return loss / accumulation_scale diff --git a/trl/experimental/self_distillation/teacher_context.py b/trl/experimental/self_distillation/prompt_utils.py similarity index 77% rename from trl/experimental/self_distillation/teacher_context.py rename to trl/experimental/self_distillation/prompt_utils.py index 70e28a92adc..f5decce6d7a 100644 --- a/trl/experimental/self_distillation/teacher_context.py +++ b/trl/experimental/self_distillation/prompt_utils.py @@ -17,15 +17,9 @@ from typing import Any -def _split_prompt_and_privileged_context(inputs: list[dict[str, Any]]) -> tuple[list[Any], list[Any]]: - prompts = [example["prompt"] for example in inputs] - privileged_contexts = [example.get("privileged_context") for example in inputs] - return prompts, privileged_contexts - - -def extract_last_user_text(prompt: list[dict[str, Any]]) -> str: +def extract_last_user_text(messages: list[dict[str, Any]]) -> str: """Extract the text content from the last user message in a conversational prompt.""" - last_message = prompt[-1] + last_message = messages[-1] if last_message.get("role") != "user": raise ValueError( f"Self-distillation teacher prompt construction expects the conversation to end with a user turn, " diff --git a/trl/experimental/self_distillation/self_distillation_loss.py b/trl/experimental/self_distillation/self_distillation_loss.py deleted file mode 100644 index faeef1d4367..00000000000 --- a/trl/experimental/self_distillation/self_distillation_loss.py +++ /dev/null @@ -1,318 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Shared self-distillation loss computation. - -This module intentionally contains only the reusable distillation loss mechanics. Trainer lifecycle concerns, -callback dispatch, and batch construction live in the trainer classes. -""" - -from __future__ import annotations - -from typing import Any, Protocol - -import torch -import torch.nn.functional as F - - -class SelfDistillationRuntime(Protocol): - """Minimal trainer surface required by `SelfDistillationLossComputer`.""" - - args: Any - accelerator: Any - model_kwarg_keys: Any - temperature: float - loss_type: str - max_completion_length: int - _metrics: dict[str, Any] - _name: str - - def _allow_topk_without_full_logit_distillation(self) -> bool: ... - - def _get_teacher_model_for_self_distillation(self, model): ... - - def _get_teacher_context_for_self_distillation(self): ... 
- - -class SelfDistillationLossComputer: - """Computes the shared student-vs-teacher self-distillation loss.""" - - def __init__(self, runtime: SelfDistillationRuntime): - self.runtime = runtime - - def compute_loss(self, model, inputs: dict[str, Any]) -> torch.Tensor: - prompt_ids = inputs["prompt_ids"] - prompt_mask = inputs["prompt_mask"] - completion_ids = inputs["completion_ids"] - completion_mask = inputs["completion_mask"] - logits_to_keep = completion_ids.size(1) - - response_mask = self._build_response_mask(completion_mask, inputs.get("self_distillation_mask")) - if response_mask.sum() == 0: - mode = "train" if model.training else "eval" - self._log_distillation_metric(mode, 0.0) - return torch.tensor(0.0, device=completion_ids.device, requires_grad=True) - - student_logits = self._compute_student_logits( - model=model, - prompt_ids=prompt_ids, - prompt_mask=prompt_mask, - completion_ids=completion_ids, - completion_mask=completion_mask, - logits_to_keep=logits_to_keep, - ) - teacher_logits = self._compute_teacher_logits( - model=model, - teacher_input_ids=inputs["teacher_input_ids"], - teacher_attention_mask=inputs["teacher_attention_mask"], - logits_to_keep=logits_to_keep, - ) - - per_token_loss = self._compute_per_token_loss( - student_logits=student_logits, - teacher_logits=teacher_logits, - completion_ids=completion_ids, - ) - - old_log_probs = inputs.get("old_per_token_logps") - if self.runtime.args.distillation_is_clip is not None and old_log_probs is not None: - student_per_token_logps = self._select_token_log_probs(student_logits, completion_ids) - per_token_loss = self._apply_importance_sampling_clipping( - per_token_loss, - student_per_token_logps, - old_log_probs, - self.runtime.args.distillation_is_clip, - ) - - loss = self._aggregate_loss(per_token_loss, response_mask) - - mode = "train" if model.training else "eval" - mean_distill_loss = (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - self._log_distillation_metric(mode, self.runtime.accelerator.gather(mean_distill_loss).mean().item()) - return loss - - @staticmethod - def _build_response_mask( - completion_mask: torch.Tensor, - self_distillation_mask: torch.Tensor | None, - ) -> torch.Tensor: - if self_distillation_mask is None: - return completion_mask - return completion_mask * self_distillation_mask.unsqueeze(1) - - def _compute_student_logits( - self, - model, - prompt_ids: torch.Tensor, - prompt_mask: torch.Tensor, - completion_ids: torch.Tensor, - completion_mask: torch.Tensor, - logits_to_keep: int, - ) -> torch.Tensor: - student_input_ids = torch.cat([prompt_ids, completion_ids], dim=1) - student_attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - return self._forward_logits( - model=model, - input_ids=student_input_ids, - attention_mask=student_attention_mask, - logits_to_keep=logits_to_keep, - ) - - def _compute_teacher_logits( - self, - model, - teacher_input_ids: torch.Tensor, - teacher_attention_mask: torch.Tensor, - logits_to_keep: int, - ) -> torch.Tensor: - teacher_model = self.runtime._get_teacher_model_for_self_distillation(model) - with torch.no_grad(), self.runtime._get_teacher_context_for_self_distillation(): - return self._forward_logits( - model=teacher_model, - input_ids=teacher_input_ids, - attention_mask=teacher_attention_mask, - logits_to_keep=logits_to_keep, - ) - - def _forward_logits( - self, - model, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - logits_to_keep: int, - ) -> torch.Tensor: - model_inputs = { - "input_ids": 
input_ids, - "attention_mask": attention_mask, - "use_cache": False, - } - if "logits_to_keep" in self.runtime.model_kwarg_keys: - model_inputs["logits_to_keep"] = logits_to_keep + 1 - - logits = model(**model_inputs).logits - logits = logits[:, :-1, :] - logits = logits[:, -logits_to_keep:, :] - return logits / self.runtime.temperature - - def _compute_per_token_loss( - self, - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - completion_ids: torch.Tensor, - ) -> torch.Tensor: - args = self.runtime.args - use_topk_distillation = args.distillation_topk is not None and ( - args.full_logit_distillation or self.runtime._allow_topk_without_full_logit_distillation() - ) - - if use_topk_distillation: - return self._compute_topk_distillation_loss(student_logits, teacher_logits, args.distillation_topk) - if args.full_logit_distillation: - return self._compute_full_logit_distillation_loss(student_logits, teacher_logits) - return self._compute_sampled_token_distillation_loss(student_logits, teacher_logits, completion_ids) - - def _compute_topk_distillation_loss( - self, - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - topk: int, - ) -> torch.Tensor: - args = self.runtime.args - student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) - topk_student_logits, topk_indices = torch.topk(student_logits, k=topk, dim=-1) - topk_student_log_probs = topk_student_logits - student_logsumexp - - teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True) - topk_teacher_logits = torch.gather(teacher_logits, dim=-1, index=topk_indices) - topk_teacher_log_probs = topk_teacher_logits - teacher_logsumexp - - if args.distillation_add_tail: - topk_student_log_probs = self._add_tail(topk_student_log_probs) - topk_teacher_log_probs = self._add_tail(topk_teacher_log_probs) - else: - topk_student_log_probs = self._renorm_topk_log_probs(topk_student_log_probs) - topk_teacher_log_probs = self._renorm_topk_log_probs(topk_teacher_log_probs) - - return self._compute_divergence(topk_student_log_probs, topk_teacher_log_probs, args.distillation_alpha) - - def _compute_full_logit_distillation_loss( - self, - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - ) -> torch.Tensor: - args = self.runtime.args - student_log_probs = F.log_softmax(student_logits, dim=-1) - teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) - return self._compute_divergence(student_log_probs, teacher_log_probs, args.distillation_alpha) - - def _compute_sampled_token_distillation_loss( - self, - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - completion_ids: torch.Tensor, - ) -> torch.Tensor: - if self.runtime.args.distillation_alpha != 1.0: - raise ValueError( - "Only reverse KL (alpha=1.0) is supported for token-level distillation when " - f"`full_logit_distillation=False`, got alpha={self.runtime.args.distillation_alpha}" - ) - - student_per_token_logps = self._select_token_log_probs(student_logits, completion_ids) - teacher_per_token_logps = self._select_token_log_probs(teacher_logits, completion_ids) - return self._compute_token_level_distillation_loss(student_per_token_logps, teacher_per_token_logps) - - @staticmethod - def _select_token_log_probs( - logits: torch.Tensor, - token_ids: torch.Tensor, - ) -> torch.Tensor: - logsumexp = torch.logsumexp(logits, dim=-1, keepdim=True) - indices = token_ids.unsqueeze(-1) - return (torch.gather(logits, dim=-1, index=indices) - logsumexp).squeeze(-1) - - def _aggregate_loss( - self, - per_token_loss: torch.Tensor, - 
response_mask: torch.Tensor, - ) -> torch.Tensor: - loss_type = self.runtime.loss_type - if loss_type == "grpo": - loss = (per_token_loss * response_mask).sum(-1) / response_mask.sum(-1).clamp(min=1.0) - return loss.mean() - if loss_type == "bnpo": - return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - if loss_type == "dr_grpo": - return (per_token_loss * response_mask).sum() / ( - per_token_loss.size(0) * self.runtime.max_completion_length - ) - if loss_type in ["dapo", "luspo", "cispo", "sapo"]: - return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - raise ValueError(f"Unsupported loss_type for self-distillation: {loss_type}") - - def _log_distillation_metric(self, mode: str, value: float) -> None: - metric_prefix = getattr(self.runtime, "_name", "self_distillation").lower().replace(" ", "_") - self.runtime._metrics[mode]["self_distillation/distillation_loss"].append(value) - self.runtime._metrics[mode][f"{metric_prefix}/distillation_loss"].append(value) - - @staticmethod - def _compute_divergence( - student_log_probs: torch.Tensor, - teacher_log_probs: torch.Tensor, - alpha: float, - ) -> torch.Tensor: - if alpha == 0.0: - kl = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) - elif alpha == 1.0: - kl = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) - else: - alpha_t = torch.tensor(alpha, dtype=student_log_probs.dtype, device=student_log_probs.device) - mixture = torch.logsumexp( - torch.stack([student_log_probs + torch.log(1 - alpha_t), teacher_log_probs + torch.log(alpha_t)]), - dim=0, - ) - kl_teacher = F.kl_div(mixture, teacher_log_probs, reduction="none", log_target=True) - kl_student = F.kl_div(mixture, student_log_probs, reduction="none", log_target=True) - kl = torch.lerp(kl_student, kl_teacher, alpha) - return kl.sum(-1) - - @staticmethod - def _add_tail(log_probs: torch.Tensor) -> torch.Tensor: - log_s = torch.logsumexp(log_probs, dim=-1, keepdim=True) - log_s = torch.clamp(log_s, max=-1e-7) - tail_log = torch.log(-torch.expm1(log_s)) - return torch.cat([log_probs, tail_log], dim=-1) - - @staticmethod - def _renorm_topk_log_probs(log_probs: torch.Tensor) -> torch.Tensor: - return log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True) - - @staticmethod - def _compute_token_level_distillation_loss( - student_log_probs: torch.Tensor, - teacher_log_probs: torch.Tensor, - ) -> torch.Tensor: - log_ratio = student_log_probs - teacher_log_probs - return log_ratio.detach() * student_log_probs - - @staticmethod - def _apply_importance_sampling_clipping( - per_token_loss: torch.Tensor, - student_log_probs: torch.Tensor, - old_log_probs: torch.Tensor, - clip_coeff: float, - ) -> torch.Tensor: - negative_approx_kl = (student_log_probs - old_log_probs).detach() - negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) - ratio = torch.exp(negative_approx_kl).clamp(max=clip_coeff) - return per_token_loss * ratio From 6a7d5a87392d6f271fdb3f99e2800045f2a855bb Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 16 Apr 2026 15:22:35 +0200 Subject: [PATCH 10/23] always set teacher_model --- trl/experimental/sdft/sdft_trainer.py | 2 +- trl/experimental/sdpo/sdpo_trainer.py | 4 +-- .../base_self_distillation_trainer.py | 28 ++++++------------- 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index be1c7e0b9e8..71e52ec5d0d 100644 --- 
a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -195,7 +195,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N completion_mask = completion_mask * (token_positions >= self.num_loss_tokens_to_skip).long() inputs["completion_mask"] = completion_mask - distillation_logits = self._compute_teacher_student_logits(model, inputs) + distillation_logits = self._compute_teacher_student_logits(model, self.teacher_model, inputs) loss = self._compute_self_distillation_loss(model, inputs, distillation_logits) accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 return loss / accumulation_scale diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py index 0482048fbc7..fd63fc7ef9c 100644 --- a/trl/experimental/sdpo/sdpo_trainer.py +++ b/trl/experimental/sdpo/sdpo_trainer.py @@ -645,7 +645,7 @@ def _compute_weighted_self_distillation_loss( return self.args.distillation_weight * distillation_loss def _compute_hybrid_loss(self, model, inputs) -> torch.Tensor: - distillation_logits = self._compute_teacher_student_logits(model, inputs) + distillation_logits = self._compute_teacher_student_logits(model, self.teacher_model, inputs) policy_loss = self._compute_policy_loss(inputs, distillation_logits.student_logits) weighted_distillation_loss = self._compute_weighted_self_distillation_loss( model, @@ -662,7 +662,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N return self._compute_hybrid_loss(model, inputs) if self.args.distillation_weight > 0.0: - distillation_logits = self._compute_teacher_student_logits(model, inputs) + distillation_logits = self._compute_teacher_student_logits(model, self.teacher_model, inputs) return self._compute_weighted_self_distillation_loss(model, inputs, distillation_logits) else: student_logits = self._compute_student_distillation_logits( diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index 7d643878c4d..00dde4202b2 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -292,9 +292,11 @@ def _setup_teacher_model(self) -> None: teacher_model_kind = self.args.teacher_model_kind if teacher_model_kind == "live": + self.teacher_model = self.model return if teacher_model_kind == "base" and is_peft_model(self.model): + self.teacher_model = self.model return if self._use_peft_ema_teacher_adapter(): @@ -307,6 +309,7 @@ def _setup_teacher_model(self) -> None: accelerator=self.accelerator, ) ) + self.teacher_model = self.model return # create teacher model from student copy @@ -506,7 +509,7 @@ def sample_rollouts(self, inputs: list[dict[str, Any]]) -> RolloutBatch: ), ) - def _split_prompt_and_privileged_context(inputs: list[dict[str, Any]]) -> tuple[list[Any], list[Any]]: + def _split_prompt_and_privileged_context(self, inputs: list[dict[str, Any]]) -> tuple[list[Any], list[Any]]: prompts = [example["prompt"] for example in inputs] privileged_contexts = [example.get("privileged_context") for example in inputs] return prompts, privileged_contexts @@ -669,6 +672,7 @@ def _compute_self_distillation_loss( def _compute_teacher_student_logits( self, model, + teacher_model, inputs: TrainingBatch, ) -> DistillationLogits: prompt_ids = inputs["prompt_ids"] @@ -691,7 +695,7 @@ def _compute_teacher_student_logits( ) 
teacher_logits = self._compute_teacher_distillation_logits( - model=model, + teacher_model=teacher_model, teacher_input_ids=inputs["teacher_input_ids"], teacher_attention_mask=inputs["teacher_attention_mask"], logits_to_keep=logits_to_keep, @@ -734,15 +738,13 @@ def _compute_student_distillation_logits( def _compute_teacher_distillation_logits( self, - model, teacher_input_ids: torch.Tensor, teacher_attention_mask: torch.Tensor, logits_to_keep: int, ) -> torch.Tensor: - teacher_model = self._get_teacher_model_for_self_distillation(model) with torch.no_grad(), self._get_teacher_context_for_self_distillation(): return self._forward_logits( - model=teacher_model, + model=self.teacher_model, input_ids=teacher_input_ids, attention_mask=teacher_attention_mask, logits_to_keep=logits_to_keep, @@ -807,26 +809,14 @@ def _get_teacher_context_for_self_distillation(self): if not is_peft_model(self.model): return nullcontext() - target_model = self.teacher_model if self.teacher_model is not None else self.model - target_model = self.accelerator.unwrap_model(target_model) + target_model = self.accelerator.unwrap_model(self.teacher_model) if teacher_model_kind == "base": return use_adapter(target_model, adapter_name=None) if teacher_model_kind == "ema" and self._use_peft_ema_teacher_adapter(): - teacher_adapter_name = self._get_teacher_adapter_name() - if teacher_adapter_name not in target_model.peft_config: - raise RuntimeError( - f"Expected PEFT teacher adapter `{teacher_adapter_name}` to exist before teacher forward." - ) - return use_adapter(target_model, adapter_name=teacher_adapter_name) + return use_adapter(target_model, adapter_name="teacher") return nullcontext() - def _get_teacher_model_for_self_distillation(self, model): - teacher_model = getattr(self, "teacher_model", None) - if teacher_model is None: - return model - return teacher_model - def _log_self_distillation_metric(self, mode: str, value: float) -> None: metric_prefix = getattr(self, "_name", "self_distillation").lower().replace(" ", "_") self._metrics[mode]["self_distillation/distillation_loss"].append(value) From 56b2fd1eb6fe72f1758417191efb2317715652a4 Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 16 Apr 2026 16:56:44 +0200 Subject: [PATCH 11/23] align generation tokenization with grpotrainer --- trl/experimental/sdft/sdft_trainer.py | 14 ++- .../base_self_distillation_trainer.py | 118 ++++++------------ 2 files changed, 52 insertions(+), 80 deletions(-) diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index 71e52ec5d0d..8e13fe1004d 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -29,6 +29,7 @@ ) from transformers.utils import is_peft_available +from ...trainer.utils import pad from ..self_distillation.base_self_distillation_trainer import ( BaseSelfDistillationTrainer, RolloutBatch, @@ -102,9 +103,16 @@ def build( self._compose_teacher_prompt(prompt, privileged_context) for prompt, privileged_context in zip(prompts, privileged_contexts, strict=True) ] - teacher_batch = self.trainer._tokenize_prompts(teacher_prompts) - teacher_input_ids = torch.cat([teacher_batch["prompt_ids"], completion_ids], dim=1) - teacher_attention_mask = torch.cat([teacher_batch["prompt_mask"], completion_mask], dim=1) + teacher_prompt_ids_list = self.trainer._tokenize_prompts(teacher_prompts) + device = completion_ids.device + teacher_prompt_ids = [torch.tensor(ids) for ids in teacher_prompt_ids_list] + teacher_prompt_mask = [torch.ones_like(ids, dtype=torch.long) 
for ids in teacher_prompt_ids] + teacher_prompt_ids = pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left").to( + device=device + ) + teacher_prompt_mask = pad(teacher_prompt_mask, padding_value=0, padding_side="left").to(device=device) + teacher_input_ids = torch.cat([teacher_prompt_ids, completion_ids], dim=1) + teacher_attention_mask = torch.cat([teacher_prompt_mask, completion_mask], dim=1) return { "teacher_input_ids": teacher_input_ids, "teacher_attention_mask": teacher_attention_mask, diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index 00dde4202b2..8798b35e3b0 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -29,6 +29,7 @@ from accelerate.utils import is_peft_model from datasets import Dataset, IterableDataset from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data import DataLoader, Sampler from transformers import ( AutoProcessor, @@ -41,7 +42,7 @@ from transformers.trainer_utils import seed_worker from transformers.utils import is_datasets_available, is_peft_available -from ...data_utils import maybe_apply_chat_template +from ...data_utils import is_conversational from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation from ...trainer.base_trainer import _BaseTrainer from ...trainer.utils import ( @@ -75,8 +76,6 @@ @dataclass class RolloutBatch: - """Common student rollout batch produced before algorithm-specific finalization.""" - prompt_ids: torch.Tensor prompt_mask: torch.Tensor completion_ids: torch.Tensor @@ -429,55 +428,32 @@ def _prepare_training_batch(self, inputs: list[dict[str, Any]]) -> TrainingBatch ) return batch - def _apply_chat_template_to_prompts(self, prompts: list[Any]) -> list[str]: - return [ - maybe_apply_chat_template( - {"prompt": prompt}, - self.processing_class, + def _tokenize_prompts(self, prompts: list[Any]) -> list[list[int]]: + if is_conversational({"prompt": prompts[0]}): + tokenized = self.processing_class.apply_chat_template( + conversation=prompts, + add_generation_prompt=True, + tokenize=True, + return_dict=True, **self.chat_template_kwargs, - )["prompt"] - for prompt in prompts - ] - - def _tokenize_prompt_text(self, prompt_text: list[str]) -> dict[str, torch.Tensor]: - prompt_inputs = self.processing_class( - text=prompt_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - prompt_inputs = _BaseTrainer._prepare_inputs(self, prompt_inputs) - prompt_ids = [ - input_ids[mask].to(device=self.accelerator.device) - for input_ids, mask in zip( - prompt_inputs["input_ids"], - prompt_inputs["attention_mask"].bool(), - strict=False, ) - ] - prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - return { - "prompt_ids": pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left"), - "prompt_mask": pad(prompt_mask, padding_value=0, padding_side="left"), - } - - def _tokenize_prompts(self, prompts: list[Any]) -> dict[str, torch.Tensor]: - return self._tokenize_prompt_text(self._apply_chat_template_to_prompts(prompts)) + prompt_ids = tokenized["input_ids"] + else: + prompt_ids = self.processing_class(text=prompts)["input_ids"] + if self.max_prompt_length is not None: + prompt_ids = 
[ids[-self.max_prompt_length :] for ids in prompt_ids] + return prompt_ids def sample_rollouts(self, inputs: list[dict[str, Any]]) -> RolloutBatch: prompts, _ = self._split_prompt_and_privileged_context(inputs) - generation_prompts = prompts - generation_prompt_text = self._apply_chat_template_to_prompts(generation_prompts) + prompt_ids = self._tokenize_prompts(prompts) self._dispatch_self_distillation_callback( "on_generation_prompts_selected", - generation_prompts=generation_prompts, - generation_prompt_text=generation_prompt_text, + generation_prompts=prompts, + generation_prompt_text=None, ) - prompt_ids_list, completion_ids_list = self._generate(generation_prompt_text) + prompt_ids_list, completion_ids_list = self._generate(prompt_ids) device = self.accelerator.device prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] @@ -514,25 +490,16 @@ def _split_prompt_and_privileged_context(self, inputs: list[dict[str, Any]]) -> privileged_contexts = [example.get("privileged_context") for example in inputs] return prompts, privileged_contexts - def _generate(self, prompt_text: list[str]) -> tuple[list[list[int]], list[list[int]]]: + def _generate(self, prompt_ids: list[list[int]]) -> tuple[list[list[int]], list[list[int]]]: if self.use_vllm: - return self._generate_vllm(prompt_text) - return self._generate_transformers(prompt_text) + return self._generate_vllm(prompt_ids) + return self._generate_transformers(prompt_ids) - def _generate_vllm(self, prompt_text: list[str]) -> tuple[list[list[int]], list[list[int]]]: + def _generate_vllm(self, prompt_ids: list[list[int]]) -> tuple[list[list[int]], list[list[int]]]: if self.state.global_step != self._last_loaded_step: self.vllm_generation.sync_weights() self._last_loaded_step = self.state.global_step - tokenized = self.processing_class( - text=prompt_text, - return_tensors=None, - padding=False, - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - prompt_ids = tokenized["input_ids"] mode = "train" if self.model.training else "eval" num_generations = self.num_generations if mode == "train" else self.num_generations_eval prompt_ids_out, completion_ids_list, _, _ = self.vllm_generation.generate( @@ -542,16 +509,12 @@ def _generate_vllm(self, prompt_text: list[str]) -> tuple[list[list[int]], list[ ) return prompt_ids_out, completion_ids_list - def _generate_transformers(self, prompt_text: list[str]) -> tuple[list[list[int]], list[list[int]]]: - generate_inputs = self.processing_class( - text=prompt_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) + def _generate_transformers(self, prompt_ids: list[list[int]]) -> tuple[list[list[int]], list[list[int]]]: + device = self.accelerator.device + prompt_tensors = [torch.tensor(ids) for ids in prompt_ids] + padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left") + attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left") + generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask} generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) with ( @@ -559,26 +522,26 @@ def _generate_transformers(self, prompt_text: list[str]) -> tuple[list[list[int] self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation, + 
generation_kwargs=self.generation_kwargs, ) as unwrapped_model, torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), ): prompt_completion_ids = unwrapped_model.generate( **generate_inputs, generation_config=self.generation_config ) - prompt_ids = generate_inputs["input_ids"] - prompt_mask = generate_inputs["attention_mask"] - prompt_length = prompt_ids.size(1) + prompt_length = generate_inputs["input_ids"].size(1) completion_ids = prompt_completion_ids[:, prompt_length:] is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) - completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).long() - - prompt_ids_list = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=False)] - completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=False)] - return prompt_ids_list, completion_ids_list + seq_idx = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).int() + completion_ids_list = [ + c[m].tolist() for c, m in zip(completion_ids.cpu(), completion_mask.bool().cpu(), strict=True) + ] + return prompt_ids, completion_ids_list def _compute_rollout_logps( self, @@ -738,13 +701,14 @@ def _compute_student_distillation_logits( def _compute_teacher_distillation_logits( self, + teacher_model, teacher_input_ids: torch.Tensor, teacher_attention_mask: torch.Tensor, logits_to_keep: int, ) -> torch.Tensor: with torch.no_grad(), self._get_teacher_context_for_self_distillation(): return self._forward_logits( - model=self.teacher_model, + model=teacher_model, input_ids=teacher_input_ids, attention_mask=teacher_attention_mask, logits_to_keep=logits_to_keep, From 4a9d5275152f4054bf9795765209beabc4970071 Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 16 Apr 2026 17:15:34 +0200 Subject: [PATCH 12/23] fix: generation_kwargs bug --- .../self_distillation/base_self_distillation_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index 8798b35e3b0..a75ef1a0c54 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -194,7 +194,7 @@ def __init__( "eval": defaultdict(int), } - generation_kwargs = { + self.generation_kwargs = { "max_new_tokens": self.max_completion_length, "do_sample": True, "pad_token_id": tokenizer.pad_token_id, @@ -208,8 +208,8 @@ def __init__( "cache_implementation": args.cache_implementation, } if args.generation_kwargs is not None: - generation_kwargs.update(args.generation_kwargs) - self.generation_config = GenerationConfig(**generation_kwargs, disable_compile=True) + self.generation_kwargs.update(args.generation_kwargs) + self.generation_config = GenerationConfig(**self.generation_kwargs, disable_compile=True) if hasattr(model, "warnings_issued"): model.warnings_issued["estimate_tokens"] = True From 196feee28dc074c0cbf60369cabe1849c17a1d20 Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 
16 Apr 2026 20:49:08 +0200 Subject: [PATCH 13/23] fix: incorrect import source --- trl/experimental/sdft/sdft.py | 3 +-- trl/experimental/sdpo/sdpo.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/trl/experimental/sdft/sdft.py b/trl/experimental/sdft/sdft.py index e90c4f060ac..eb6b2914996 100644 --- a/trl/experimental/sdft/sdft.py +++ b/trl/experimental/sdft/sdft.py @@ -76,8 +76,7 @@ get_quantization_config, ) from trl.data_utils import maybe_apply_chat_template -from trl.experimental.sdft import SDFTConfig -from trl.experimental.sdft.sdft_trainer import SDFTTrainer +from trl.experimental.sdft import SDFTConfig, SDFTTrainer from trl.models import unwrap_model_for_generation diff --git a/trl/experimental/sdpo/sdpo.py b/trl/experimental/sdpo/sdpo.py index 25f75a55bab..7a65a00cc5d 100644 --- a/trl/experimental/sdpo/sdpo.py +++ b/trl/experimental/sdpo/sdpo.py @@ -78,8 +78,7 @@ get_quantization_config, ) from trl.data_utils import maybe_apply_chat_template -from trl.experimental.sdpo import SDPOConfig -from trl.experimental.sdpo.sdpo_trainer import SDPOTrainer +from trl.experimental.sdpo import SDPOConfig, SDPOTrainer SYSTEM_PROMPT = ( From 3c87400860124a61ac7350798702622b1a4fb841 Mon Sep 17 00:00:00 2001 From: Leon Date: Fri, 17 Apr 2026 15:27:27 +0200 Subject: [PATCH 14/23] fixes: cleanup, standardized tokenization, distill loss=0 fix, sdpo config parameters moved to sdpoconfig, + other nits --- trl/experimental/sdft/sdft.py | 10 +- trl/experimental/sdft/sdft_trainer.py | 30 +++-- trl/experimental/sdpo/sdpo_config.py | 58 ++++++++- trl/experimental/sdpo/sdpo_trainer.py | 36 ++---- .../base_self_distillation_trainer.py | 114 ++++++++++++------ .../self_distillation_config.py | 37 ------ 6 files changed, 158 insertions(+), 127 deletions(-) diff --git a/trl/experimental/sdft/sdft.py b/trl/experimental/sdft/sdft.py index eb6b2914996..6c9aa8d6179 100644 --- a/trl/experimental/sdft/sdft.py +++ b/trl/experimental/sdft/sdft.py @@ -117,14 +117,6 @@ class SDFTScriptArguments(ScriptArguments): ) -@dataclass -class ExampleSDFTConfig(SDFTConfig): - scale_rewards: str = field( - default="group", - metadata={"help": "Reward normalization mode. 
Supported: `group`, `batch`, `none`."}, - ) - - def _extract_prompt_text(prompt: Any) -> str: if isinstance(prompt, str): return prompt @@ -318,7 +310,7 @@ def _run_tooluse_eval( if __name__ == "__main__": - parser = TrlParser((SDFTScriptArguments, ExampleSDFTConfig, ModelConfig)) + parser = TrlParser((SDFTScriptArguments, SDFTConfig, ModelConfig)) script_args, training_args, model_args = parser.parse_args_and_config() if model_args.model_name_or_path is None: diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index 8e13fe1004d..aaca27575ba 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -156,6 +156,9 @@ def __init__( ): raise NotImplementedError("Iterable eval datasets are not yet supported in SDFTTrainer.") + self.num_loss_tokens_to_skip = args.num_loss_tokens_to_skip + self.teacher_context_builder = DemonstrationTeacherContextBuilder(self) + super().__init__( model=model, args=args, @@ -167,9 +170,6 @@ def __init__( peft_config=peft_config, ) - self.num_loss_tokens_to_skip = args.num_loss_tokens_to_skip - self.teacher_context_builder = DemonstrationTeacherContextBuilder(self) - def finalize_batch( self, inputs: list[dict[str, Any]], @@ -196,14 +196,24 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if return_outputs: raise ValueError("The SDFTTrainer does not support returning outputs") - if self.num_loss_tokens_to_skip > 0: - inputs = dict(inputs) - completion_mask = inputs["completion_mask"].clone() - token_positions = torch.arange(completion_mask.size(1), device=completion_mask.device).unsqueeze(0) - completion_mask = completion_mask * (token_positions >= self.num_loss_tokens_to_skip).long() - inputs["completion_mask"] = completion_mask - distillation_logits = self._compute_teacher_student_logits(model, self.teacher_model, inputs) loss = self._compute_self_distillation_loss(model, inputs, distillation_logits) accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 return loss / accumulation_scale + + def _build_self_distillation_response_mask( + self, + completion_mask: torch.Tensor, + self_distillation_mask: torch.Tensor | None, + ) -> torch.Tensor: + response_mask = BaseSelfDistillationTrainer._build_self_distillation_response_mask( + completion_mask, + self_distillation_mask, + ) + if self.num_loss_tokens_to_skip <= 0: + return response_mask + + # SDFT skips the first few completion tokens only in the distillation loss to suppress teacher-prompt artifacts. + token_positions = torch.arange(response_mask.size(1), device=response_mask.device).unsqueeze(0) + skip_mask = (token_positions >= self.num_loss_tokens_to_skip).long() + return response_mask * skip_mask diff --git a/trl/experimental/sdpo/sdpo_config.py b/trl/experimental/sdpo/sdpo_config.py index 031f1e73ce8..e4c0e6e030d 100644 --- a/trl/experimental/sdpo/sdpo_config.py +++ b/trl/experimental/sdpo/sdpo_config.py @@ -26,11 +26,26 @@ class SDPOConfig(SelfDistillationConfig): parameters used by Self-Distillation Policy Optimization (SDPO). Parameters: + > Parameters that control the online policy objective + + beta (`float`, *optional*, defaults to `0.0`): + Reference-model KL coefficient for online policy optimization. + epsilon (`float`, *optional*, defaults to `0.2`): + Lower clipping coefficient for GRPO-style policy loss. + epsilon_high (`float` or `None`, *optional*): + Upper clipping coefficient. Defaults to `epsilon` when unset. 
+ importance_sampling_level (`str`, *optional*, defaults to `"token"`): + Importance-sampling granularity. Supported: `token`, `sequence`. + reward_weights (`list[float]` or `None`, *optional*): + Optional weights for multiple reward functions. + scale_rewards (`str` or `bool`, *optional*, defaults to `"group"`): + Reward normalization mode. Supported: `group`, `batch`, `none`. + > Parameters that control the SDPO loss sdpo_policy_loss_mode (`str`, *optional*, defaults to `"distillation_only"`): How SDPO combines the online policy loss and self-distillation loss. Supported: `distillation_only`, - `hybrid`. + `policy_only`, `hybrid`. distillation_alpha (`float`, *optional*, defaults to `1.0`): Divergence interpolation coefficient. Token-level SDPO requires the official reverse-KL setting `distillation_alpha=1.0`. @@ -61,6 +76,30 @@ class SDPOConfig(SelfDistillationConfig): default=True, metadata={"help": "Skip reprompting when model generates correct response."}, ) + beta: float = field( + default=0.0, + metadata={"help": "Reference-model KL coefficient for online policy optimization."}, + ) + epsilon: float = field( + default=0.2, + metadata={"help": "Lower clipping coefficient for GRPO-style policy loss."}, + ) + epsilon_high: float | None = field( + default=None, + metadata={"help": "Upper clipping coefficient. Defaults to `epsilon` when unset."}, + ) + importance_sampling_level: str = field( + default="token", + metadata={"help": "Importance-sampling granularity. Supported: `token`, `sequence`."}, + ) + reward_weights: list[float] | None = field( + default=None, + metadata={"help": "Optional weights for multiple reward functions."}, + ) + scale_rewards: str | bool = field( + default="group", + metadata={"help": "Reward normalization mode. Supported: `group`, `batch`, `none`."}, + ) distillation_alpha: float = field( default=1.0, metadata={ @@ -73,7 +112,7 @@ class SDPOConfig(SelfDistillationConfig): ) sdpo_policy_loss_mode: str = field( default="distillation_only", - metadata={"help": "SDPO policy loss mode. Supported: `distillation_only`, `hybrid`."}, + metadata={"help": "SDPO policy loss mode. 
Supported: `distillation_only`, `policy_only`, `hybrid`."}, ) teacher_model_kind: str = field( default="ema", @@ -129,8 +168,19 @@ class SDPOConfig(SelfDistillationConfig): def __post_init__(self): super().__post_init__() - if self.sdpo_policy_loss_mode not in {"distillation_only", "hybrid"}: - raise ValueError("sdpo_policy_loss_mode must be one of: 'distillation_only', 'hybrid'") + + self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards) + if self.scale_rewards not in ["group", "batch", "none"]: + raise ValueError("scale_rewards must be one of: 'group', 'batch', 'none'") + + if self.importance_sampling_level not in ["token", "sequence"]: + raise ValueError("importance_sampling_level must be either 'token' or 'sequence'") + + if self.epsilon_high is None: + self.epsilon_high = self.epsilon + + if self.sdpo_policy_loss_mode not in {"distillation_only", "policy_only", "hybrid"}: + raise ValueError("sdpo_policy_loss_mode must be one of: 'distillation_only', 'policy_only', 'hybrid'") if self.sdpo_policy_loss_mode == "distillation_only" and self.distillation_weight <= 0: raise ValueError("distillation_only mode requires `distillation_weight > 0`.") if self.sdpo_policy_loss_mode == "hybrid" and self.distillation_weight <= 0: diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py index fd63fc7ef9c..98e2d023033 100644 --- a/trl/experimental/sdpo/sdpo_trainer.py +++ b/trl/experimental/sdpo/sdpo_trainer.py @@ -65,30 +65,12 @@ def _build_reprompt_text(self, prompt_text: str, solution_text: str, feedback_te def _tokenize_teacher_messages( self, teacher_messages_list: list[str | list[dict[str, Any]]] ) -> dict[str, torch.Tensor]: - teacher_prompt_ids_list = [] device = self.trainer.accelerator.device - chat_template_kwargs = getattr(self.trainer, "chat_template_kwargs", {}) - for msg in teacher_messages_list: - if isinstance(msg, list) and isinstance(msg[0], dict): - tokenized = self.trainer.processing_class.apply_chat_template( - msg, - tokenize=True, - add_generation_prompt=True, - return_tensors="pt", - **chat_template_kwargs, - ) - if isinstance(tokenized, torch.Tensor): - ids = tokenized.squeeze(0) - else: - ids = tokenized["input_ids"].squeeze(0) - else: - ids = self.trainer.processing_class.encode(msg, return_tensors="pt").squeeze(0) - - if ids.shape[0] > self.trainer.args.max_reprompt_len: - ids = ids[-self.trainer.args.max_reprompt_len :] - teacher_prompt_ids_list.append(ids) - - teacher_prompt_ids = [ids.to(device) for ids in teacher_prompt_ids_list] + teacher_prompt_ids_list = self.trainer._tokenize_prompts_untruncated(teacher_messages_list) + teacher_prompt_ids = [ + torch.as_tensor(ids[-self.trainer.args.max_reprompt_len :], device=device) + for ids in teacher_prompt_ids_list + ] teacher_prompt_mask = [torch.ones(len(ids), dtype=torch.long, device=device) for ids in teacher_prompt_ids] return { "prompt_ids": pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), @@ -351,9 +333,6 @@ def __init__( self.teacher_context_builder = SuccessfulRolloutTeacherContextBuilder(self) - def _allow_topk_without_full_logit_distillation(self) -> bool: - return False - def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): device = self.accelerator.device if len(self.reward_funcs) == 0: @@ -661,10 +640,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if self.args.sdpo_policy_loss_mode == "hybrid": return self._compute_hybrid_loss(model, 
inputs) - if self.args.distillation_weight > 0.0: + elif self.args.sdpo_policy_loss_mode == "distillation_only": distillation_logits = self._compute_teacher_student_logits(model, self.teacher_model, inputs) return self._compute_weighted_self_distillation_loss(model, inputs, distillation_logits) - else: + + elif self.args.sdpo_policy_loss_mode == "policy_only": student_logits = self._compute_student_distillation_logits( model=model, prompt_ids=inputs["prompt_ids"], diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index a75ef1a0c54..72e3f1d2ab0 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -76,6 +76,8 @@ @dataclass class RolloutBatch: + """Student-side rollout produced by `sample_rollouts`, consumed by `finalize_batch` to form a `TrainingBatch`.""" + prompt_ids: torch.Tensor prompt_mask: torch.Tensor completion_ids: torch.Tensor @@ -286,7 +288,25 @@ def _dispatch_self_distillation_callback(self, event_name: str, **payload) -> No ) def _setup_teacher_model(self) -> None: - """Prepare teacher state according to the semantic teacher choice.""" + """Prepare teacher state according to the semantic teacher choice. + + Resolve `teacher_model_kind` × PEFT state into the effective teacher: + + - `"live"` (any model): + Teacher is the student. No divergence, no callback. + - `"base"` + PEFT model: + Teacher reuses `self.model`; the base weights are recovered downstream by disabling the adapter + via `use_adapter` during teacher forward. + - `"base"` + non-PEFT model: + Teacher is a frozen deepcopy of the initial student (falls through to the copy branch below). + - `"ema"` + pure-LoRA training: + Teacher reuses `self.model`; a dedicated `"teacher"` LoRA adapter is attached and updated by + `PEFTAdapterEMACallback`. Teacher forward switches to that adapter downstream. + - `"ema"` (otherwise): + Teacher is a frozen deepcopy synchronized each step by `SyncTeacherModelCallback`. + + Must be called after `super().__init__` so that `self.callback_handler` is available. + """ teacher_model_kind = self.args.teacher_model_kind @@ -299,6 +319,7 @@ def _setup_teacher_model(self) -> None: return if self._use_peft_ema_teacher_adapter(): + # Must run after super().__init__ so self.callback_handler exists. self.add_callback( PEFTAdapterEMACallback( model=self.model, @@ -330,6 +351,7 @@ def _use_peft_ema_teacher_adapter(self) -> bool: return self.args.teacher_model_kind == "ema" and self._is_pure_lora_training() def _is_pure_lora_training(self) -> bool: + """Return `True` when the active adapter is LoRA and every trainable parameter is a LoRA parameter.""" if not is_peft_model(self.model): return False @@ -397,6 +419,11 @@ def training_step(self, model, inputs, num_items_in_batch): return output def _prepare_inputs(self, generation_batch): + """Return the per-step training batch, regenerating rollouts and buffering them for reuse in train mode. + + In train mode, rollouts are generated once every `steps_per_generation * num_iterations` steps and split + into per-step slices reused until the next regeneration. In eval mode, every batch is freshly prepared. 
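+
+        For example (hypothetical settings), with `steps_per_generation=4` and `num_iterations=2`, new
+        rollouts are sampled every 8 training steps, and the seven steps in between consume the buffered
+        per-step slices of that generation batch.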
+ """ mode = "train" if self.model.training else "eval" if mode == "train": generate_every = self.args.steps_per_generation * self.num_iterations @@ -412,10 +439,10 @@ def _prepare_inputs(self, generation_batch): return self._prepare_training_batch(generation_batch) def _prepare_training_batch(self, inputs: list[dict[str, Any]]) -> TrainingBatch: + """Sample student rollouts and let the subclass finalize them into the final `TrainingBatch`.""" rollout_batch = self.sample_rollouts(inputs) batch = self.finalize_batch(inputs, rollout_batch) - self._validate_training_batch(batch) self._dispatch_self_distillation_callback( "on_self_distillation_batch_prepared", @@ -428,7 +455,7 @@ def _prepare_training_batch(self, inputs: list[dict[str, Any]]) -> TrainingBatch ) return batch - def _tokenize_prompts(self, prompts: list[Any]) -> list[list[int]]: + def _tokenize_prompts_untruncated(self, prompts: list[Any]) -> list[list[int]]: if is_conversational({"prompt": prompts[0]}): tokenized = self.processing_class.apply_chat_template( conversation=prompts, @@ -440,6 +467,10 @@ def _tokenize_prompts(self, prompts: list[Any]) -> list[list[int]]: prompt_ids = tokenized["input_ids"] else: prompt_ids = self.processing_class(text=prompts)["input_ids"] + return prompt_ids + + def _tokenize_prompts(self, prompts: list[Any]) -> list[list[int]]: + prompt_ids = self._tokenize_prompts_untruncated(prompts) if self.max_prompt_length is not None: prompt_ids = [ids[-self.max_prompt_length :] for ids in prompt_ids] return prompt_ids @@ -568,24 +599,31 @@ def _compute_rollout_logps( return old_per_token_logps - def _allow_topk_without_full_logit_distillation(self) -> bool: - return True - def _compute_self_distillation_loss( self, model, inputs: TrainingBatch, distillation_logits: DistillationLogits, ) -> torch.Tensor: + """Compute the per-token distillation loss and aggregate it according to `loss_type`. + + Dispatches between three objectives based on config: + + - `distillation_topk` is not `None`: top-k approximation of the divergence, optionally with a tail + bucket for the remaining probability mass (`distillation_add_tail`). + - `full_logit_distillation` is `True`: full-vocab divergence. + - otherwise: token-level (reverse-KL) distillation on sampled `completion_ids`. + + When `distillation_is_clip` is set and `old_per_token_logps` are available, the loss is corrected by a + clipped importance-sampling ratio between the current student and the student at rollout time. + """ if distillation_logits.response_mask.sum() == 0: mode = "train" if model.training else "eval" self._log_self_distillation_metric(mode, 0.0) - return torch.tensor(0.0, device=distillation_logits.completion_ids.device, requires_grad=True) + # Keep the zero loss attached to the student graph so backward produces zero gradients instead of stopping. + return distillation_logits.student_logits.sum() * 0.0 - use_topk_distillation = self.args.distillation_topk is not None and ( - self.args.full_logit_distillation or self._allow_topk_without_full_logit_distillation() - ) - if use_topk_distillation: + if self.args.distillation_topk is not None: per_token_loss = compute_topk_self_distillation_loss( distillation_logits.student_logits, distillation_logits.teacher_logits, @@ -638,6 +676,10 @@ def _compute_teacher_student_logits( teacher_model, inputs: TrainingBatch, ) -> DistillationLogits: + """Run student and teacher forwards on their respective inputs and pack aligned logits into a `DistillationLogits`. 
+ + The teacher forward runs under the teacher context resolved by `_get_teacher_context_for_self_distillation`. + """ prompt_ids = inputs["prompt_ids"] prompt_mask = inputs["prompt_mask"] completion_ids = inputs["completion_ids"] @@ -721,6 +763,7 @@ def _forward_logits( attention_mask: torch.Tensor, logits_to_keep: int, ) -> torch.Tensor: + """Forward the model and return temperature-scaled logits aligned to the completion tokens.""" model_inputs = { "input_ids": input_ids, "attention_mask": attention_mask, @@ -734,41 +777,34 @@ def _forward_logits( logits = logits[:, -logits_to_keep:, :] return logits / self.temperature - def _validate_training_batch(self, batch: TrainingBatch) -> None: - required_keys = { - "prompt_ids", - "prompt_mask", - "completion_ids", - "completion_mask", - "teacher_input_ids", - "teacher_attention_mask", - } - missing_keys = required_keys.difference(batch) - if missing_keys: - raise ValueError(f"`finalize_batch` must return all required batch keys. Missing: {sorted(missing_keys)}") - - batch_size = batch["prompt_ids"].size(0) - if batch["prompt_mask"].size(0) != batch_size: - raise ValueError("`prompt_mask` must have the same batch size as `prompt_ids`.") - if batch["completion_ids"].size(0) != batch_size or batch["completion_mask"].size(0) != batch_size: - raise ValueError("`completion_ids` and `completion_mask` must match the student batch size.") - if batch["teacher_input_ids"].size(0) != batch_size or batch["teacher_attention_mask"].size(0) != batch_size: - raise ValueError("`teacher_input_ids` and `teacher_attention_mask` must match the student batch size.") - if batch["teacher_input_ids"].size(1) != batch["teacher_attention_mask"].size(1): - raise ValueError("`teacher_input_ids` and `teacher_attention_mask` must have the same sequence length.") - if "self_distillation_mask" in batch and batch["self_distillation_mask"] is not None: - if batch["self_distillation_mask"].size(0) != batch_size: - raise ValueError("`self_distillation_mask` must match the batch size when provided.") - def finalize_batch( self, inputs: list[dict[str, Any]], rollout_batch: RolloutBatch, ) -> TrainingBatch: - """Build the final training batch from a shared student rollout batch.""" + """Build the final training batch from a shared student rollout batch. + + Subclasses must return a `dict` with at least the following keys, all first-dim-aligned to the student + batch size `B`: + + - `prompt_ids`: student prompt ids. + - `prompt_mask`: student prompt mask. + - `completion_ids`: student completion ids. + - `completion_mask`: student completion mask. + - `teacher_input_ids`: teacher input ids. + - `teacher_attention_mask`: teacher attention mask. + """ return rollout_batch.as_dict() def _get_teacher_context_for_self_distillation(self): + """Return the context manager that routes the teacher forward to the correct weights. + + For non-PEFT models this is a no-op. For PEFT models: + + - `teacher_model_kind == "base"`: disable the student adapter so the teacher forward uses the base weights. + - `teacher_model_kind == "ema"` under pure-LoRA training: switch to the `"teacher"` LoRA adapter. + - otherwise: no-op; the teacher is a separate deepcopy. 
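+
+        Call sites pair this context with `torch.no_grad()` for the teacher forward, e.g. the pattern used
+        in `_compute_teacher_distillation_logits`:
+
+            with torch.no_grad(), self._get_teacher_context_for_self_distillation():
+                teacher_logits = self._forward_logits(
+                    model=teacher_model,
+                    input_ids=teacher_input_ids,
+                    attention_mask=teacher_attention_mask,
+                    logits_to_keep=logits_to_keep,
+                )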
+ """ teacher_model_kind = self.args.teacher_model_kind if not is_peft_model(self.model): return nullcontext() @@ -782,7 +818,7 @@ def _get_teacher_context_for_self_distillation(self): return nullcontext() def _log_self_distillation_metric(self, mode: str, value: float) -> None: - metric_prefix = getattr(self, "_name", "self_distillation").lower().replace(" ", "_") + metric_prefix = self._name.lower().replace(" ", "_") self._metrics[mode]["self_distillation/distillation_loss"].append(value) self._metrics[mode][f"{metric_prefix}/distillation_loss"].append(value) diff --git a/trl/experimental/self_distillation/self_distillation_config.py b/trl/experimental/self_distillation/self_distillation_config.py index d4e66dbce88..d10b5b18b9e 100644 --- a/trl/experimental/self_distillation/self_distillation_config.py +++ b/trl/experimental/self_distillation/self_distillation_config.py @@ -45,12 +45,8 @@ class SelfDistillationConfig(_BaseConfig): > Parameters that control the online policy objective - beta (`float`, *optional*, defaults to `0.0`): - Reference-model KL coefficient for online policy optimization. loss_type (`str`, *optional*, defaults to `"dapo"`): Policy-loss aggregation mode. Supported: `grpo`, `bnpo`, `dr_grpo`, `dapo`. - scale_rewards (`str` or `bool`, *optional*, defaults to `"group"`): - Reward normalization mode. Supported: `group`, `batch`, `none`. > Parameters that control teacher construction @@ -212,34 +208,10 @@ class SelfDistillationConfig(_BaseConfig): default=None, metadata={"help": "Model context length for vLLM. Inferred from model config if not set."}, ) - beta: float = field( - default=0.0, - metadata={"help": "Reference-model KL coefficient for online policy optimization."}, - ) num_iterations: int = field( default=1, metadata={"help": "Number of optimization iterations per generated batch."}, ) - epsilon: float = field( - default=0.2, - metadata={"help": "Lower clipping coefficient for GRPO-style policy loss."}, - ) - epsilon_high: float | None = field( - default=None, - metadata={"help": "Upper clipping coefficient. Defaults to `epsilon` when unset."}, - ) - importance_sampling_level: str = field( - default="token", - metadata={"help": "Importance-sampling granularity. Supported: `token`, `sequence`."}, - ) - reward_weights: list[float] | None = field( - default=None, - metadata={"help": "Optional weights for multiple reward functions."}, - ) - scale_rewards: str | bool = field( - default="group", - metadata={"help": "Reward normalization mode. Supported: `group`, `batch`, `none`."}, - ) loss_type: str = field( default="dapo", metadata={"help": "Policy loss aggregation. 
Supported: `grpo`, `bnpo`, `dr_grpo`, `dapo`."}, @@ -307,12 +279,6 @@ class SelfDistillationConfig(_BaseConfig): def __post_init__(self): super().__post_init__() - self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards) - if self.scale_rewards not in ["group", "batch", "none"]: - raise ValueError("scale_rewards must be one of: 'group', 'batch', 'none'") - - if self.importance_sampling_level not in ["token", "sequence"]: - raise ValueError("importance_sampling_level must be either 'token' or 'sequence'") if self.loss_type not in ["grpo", "bnpo", "dr_grpo", "dapo"]: raise ValueError("loss_type must be one of: 'grpo', 'bnpo', 'dr_grpo', 'dapo'") if self.teacher_model_kind not in {"base", "live", "ema"}: @@ -364,6 +330,3 @@ def __post_init__(self): f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be " f"divisible by the number of generations used for evaluation ({num_generations_eval})." ) - - if self.epsilon_high is None: - self.epsilon_high = self.epsilon From d2a78e24fd45e23aa42b0f44726c36c7eab83637 Mon Sep 17 00:00:00 2001 From: Leon Date: Fri, 17 Apr 2026 15:51:07 +0200 Subject: [PATCH 15/23] tests: ported old tests + new tests for base class --- .../test_base_self_distillation_trainer.py | 187 ++++++++++++++++++ tests/experimental/test_sdft_trainer.py | 39 ++-- tests/experimental/test_sdpo_trainer.py | 16 +- .../base_self_distillation_trainer.py | 8 + 4 files changed, 218 insertions(+), 32 deletions(-) create mode 100644 tests/experimental/test_base_self_distillation_trainer.py diff --git a/tests/experimental/test_base_self_distillation_trainer.py b/tests/experimental/test_base_self_distillation_trainer.py new file mode 100644 index 00000000000..55224acc649 --- /dev/null +++ b/tests/experimental/test_base_self_distillation_trainer.py @@ -0,0 +1,187 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
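+
+"""Tests for `BaseSelfDistillationTrainer`: teacher-model setup, prompt tokenization, and rollout buffering."""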
+ +from contextlib import nullcontext +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +import torch +from datasets import Dataset + +from trl.experimental.self_distillation.base_self_distillation_trainer import BaseSelfDistillationTrainer +from trl.experimental.self_distillation.self_distillation_config import SelfDistillationConfig + +from ..testing_utils import TrlTestCase + + +class MinimalSelfDistillationTrainer(BaseSelfDistillationTrainer): + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + del inputs, num_items_in_batch + anchor = next(model.parameters()) + return anchor.sum() * 0.0 + + +class FakeTextTokenizer: + def __call__(self, text, **kwargs): + del kwargs + token_map = { + "short prompt": [1, 2, 3], + "long prompt": [10, 11, 12, 13, 14], + } + return {"input_ids": [token_map[prompt] for prompt in text]} + + +class FakeChatProcessor: + def __init__(self): + self.calls = [] + + def apply_chat_template(self, conversation, add_generation_prompt, tokenize, return_dict, **kwargs): + self.calls.append( + { + "conversation": conversation, + "add_generation_prompt": add_generation_prompt, + "tokenize": tokenize, + "return_dict": return_dict, + "kwargs": kwargs, + } + ) + return {"input_ids": [[21, 22, 23, 24]]} + + +class TestBaseSelfDistillationTrainer(TrlTestCase): + def test_teacher_model_kind_live_uses_student_model(self): + dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]}) + training_args = SelfDistillationConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=1, + num_generations=1, + teacher_model_kind="live", + report_to="none", + ) + + trainer = MinimalSelfDistillationTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + assert trainer.teacher_model is trainer.model + + @pytest.mark.parametrize("teacher_model_kind", ["base", "ema"]) + def test_teacher_model_kind_base_and_ema_use_frozen_teacher_copy(self, teacher_model_kind): + dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]}) + training_args = SelfDistillationConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=1, + num_generations=1, + teacher_model_kind=teacher_model_kind, + report_to="none", + ) + + trainer = MinimalSelfDistillationTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + assert trainer.teacher_model is not trainer.model + assert trainer.teacher_model.training is False + + student_param = next(trainer.model.parameters()) + teacher_param = next(trainer.teacher_model.parameters()) + assert teacher_param.requires_grad is False + assert teacher_param.data_ptr() != student_param.data_ptr() + + def test_tokenize_prompts_truncates_text_prompts_from_left(self): + trainer = object.__new__(MinimalSelfDistillationTrainer) + trainer.processing_class = FakeTextTokenizer() + trainer.max_prompt_length = 3 + + prompt_ids = trainer._tokenize_prompts(["long prompt", "short prompt"]) + + assert prompt_ids == [[12, 13, 14], [1, 2, 3]] + + def test_tokenize_prompts_for_conversational_prompts_forwards_chat_template_kwargs(self): + trainer = object.__new__(MinimalSelfDistillationTrainer) + trainer.processing_class = FakeChatProcessor() + trainer.max_prompt_length = 2 + trainer.chat_template_kwargs = {"enable_thinking": False} + + prompt_ids = trainer._tokenize_prompts([[{"role": "user", "content": "Solve 
2+2."}]]) + + assert prompt_ids == [[23, 24]] + assert trainer.processing_class.calls == [ + { + "conversation": [[{"role": "user", "content": "Solve 2+2."}]], + "add_generation_prompt": True, + "tokenize": True, + "return_dict": True, + "kwargs": {"enable_thinking": False}, + } + ] + + def test_prepare_inputs_reuses_buffered_generation_batches_within_window(self): + trainer = object.__new__(MinimalSelfDistillationTrainer) + trainer.model = SimpleNamespace(training=True) + trainer.args = SimpleNamespace(steps_per_generation=2) + trainer.num_iterations = 1 + trainer._step = 0 + trainer._buffered_inputs = None + trainer.callback_handler = SimpleNamespace(callbacks=[]) + trainer.state = SimpleNamespace() + trainer.control = SimpleNamespace() + trainer.processing_class = None + trainer._prepare_training_batch = Mock( + side_effect=[ + {"value": torch.tensor([[1.0], [2.0]])}, + {"value": torch.tensor([[3.0], [4.0]])}, + ] + ) + + first_batch = trainer._prepare_inputs([{"prompt": "first"}]) + + trainer._step = 1 + second_batch = trainer._prepare_inputs([{"prompt": "second"}]) + + trainer._step = 2 + third_batch = trainer._prepare_inputs([{"prompt": "third"}]) + + assert first_batch["value"].item() == 1.0 + assert second_batch["value"].item() == 2.0 + assert third_batch["value"].item() == 3.0 + assert trainer._prepare_training_batch.call_count == 2 + + def test_prediction_step_prepares_list_batches_before_computing_loss(self): + trainer = object.__new__(MinimalSelfDistillationTrainer) + prepared_inputs = {"prompt_ids": torch.tensor([[1, 2]])} + trainer._prepare_inputs = Mock(return_value=prepared_inputs) + trainer.compute_loss = Mock(return_value=torch.tensor(3.5, requires_grad=True)) + trainer.compute_loss_context_manager = lambda: nullcontext() + + loss, logits, labels = trainer.prediction_step( + model=None, + inputs=[{"prompt": "Solve 2+2."}], + prediction_loss_only=True, + ) + + trainer._prepare_inputs.assert_called_once() + trainer.compute_loss.assert_called_once_with(None, prepared_inputs) + assert loss.item() == 3.5 + assert logits is None + assert labels is None diff --git a/tests/experimental/test_sdft_trainer.py b/tests/experimental/test_sdft_trainer.py index 9ca9b6c579a..f307c5a96dd 100644 --- a/tests/experimental/test_sdft_trainer.py +++ b/tests/experimental/test_sdft_trainer.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from unittest.mock import patch + import pytest import torch from datasets import Dataset from transformers import AutoModelForCausalLM, TrainerCallback, TrainerControl, TrainerState, TrainingArguments from transformers.utils import is_peft_available -from trl.data_utils import maybe_apply_chat_template from trl.experimental.sdft import SDFTConfig, SDFTTrainer from ..testing_utils import TrlTestCase, require_peft @@ -27,18 +28,18 @@ if is_peft_available(): from peft import LoraConfig, get_peft_model, get_peft_model_state_dict - from trl.experimental.self_distillation.peft_adapter_ema_callback import PEFTAdapterEMACallback + from trl.experimental.self_distillation.teacher_sync import PEFTAdapterEMACallback class SelfDistillationCaptureCallback(TrainerCallback): def __init__(self): - self.captured_generation_prompt_text = None + self.captured_generation_prompts = None self.captured_old_per_token_logps = None self.generation_batch_build_count = 0 - def on_generation_prompts_selected(self, generation_prompt_text=None, **kwargs): - if self.captured_generation_prompt_text is None and generation_prompt_text is not None: - self.captured_generation_prompt_text = generation_prompt_text[0] + def on_generation_prompts_selected(self, generation_prompts=None, **kwargs): + if self.captured_generation_prompts is None and generation_prompts is not None: + self.captured_generation_prompts = generation_prompts def on_self_distillation_batch_prepared(self, old_per_token_logps=None, **kwargs): if self.captured_old_per_token_logps is None and old_per_token_logps is not None: @@ -74,6 +75,7 @@ def test_training_rejects_none_privileged_context(self): with pytest.raises(ValueError, match="`privileged_context` must not be None"): trainer.train() + @pytest.mark.skip(reason="`generate_from_teacher` is not ported yet") def test_training_with_generate_from_teacher(self): dataset = Dataset.from_dict( { @@ -105,9 +107,8 @@ def test_training_with_generate_from_teacher(self): trainer.train() - assert capture_callback.captured_generation_prompt_text is not None - assert "Solve 2+2." 
in capture_callback.captured_generation_prompt_text - assert "Teacher hint" in capture_callback.captured_generation_prompt_text + assert capture_callback.captured_generation_prompts is not None + assert capture_callback.captured_generation_prompts[0] != dataset[0]["prompt"] def test_training_with_chat_template_kwargs(self): dataset = Dataset.from_dict( @@ -141,15 +142,15 @@ def test_training_with_chat_template_kwargs(self): callbacks=[capture_callback], ) - expected_prompt = maybe_apply_chat_template( - {"prompt": dataset[0]["prompt"]}, + with patch.object( trainer.processing_class, - **training_args.chat_template_kwargs, - )["prompt"] - - trainer.train() + "apply_chat_template", + wraps=trainer.processing_class.apply_chat_template, + ) as mock_apply_chat_template: + trainer.train() - assert capture_callback.captured_generation_prompt_text == expected_prompt + assert mock_apply_chat_template.call_count > 0 + assert any(call.kwargs.get("enable_thinking") is False for call in mock_apply_chat_template.call_args_list) @require_peft def test_training_with_peft_model(self): @@ -205,9 +206,9 @@ def test_training_with_peft_model_and_sync_ref_model(self): max_completion_length=8, max_steps=2, num_generations=1, - sync_ref_model=True, - ref_model_mixup_alpha=0.05, - ref_model_sync_steps=1, + teacher_model_kind="ema", + teacher_update_rate=0.05, + teacher_sync_steps=1, ) trainer = SDFTTrainer( diff --git a/tests/experimental/test_sdpo_trainer.py b/tests/experimental/test_sdpo_trainer.py index 1bb666a7dcc..d800968966f 100644 --- a/tests/experimental/test_sdpo_trainer.py +++ b/tests/experimental/test_sdpo_trainer.py @@ -100,14 +100,6 @@ def test_vllm_config_defaults_match_reference_trainers(self): assert config.vllm_model_impl == "vllm" def test_generate_vllm_syncs_on_step_change_and_uses_mode_specific_num_generations(self): - class FakeTokenizer: - def __call__(self, text, **kwargs): - token_map = { - "Solve 2+2.": [11, 12], - "Check 3+3.": [21, 22], - } - return {"input_ids": [token_map[prompt] for prompt in text]} - class FakeVLLMGeneration: def __init__(self): self.sync_weights_call_count = 0 @@ -135,11 +127,9 @@ def generate(self, prompts, images, num_generations): trainer.model = SimpleNamespace(training=True) trainer.state = SimpleNamespace(global_step=4) trainer._last_loaded_step = 3 - trainer.processing_class = FakeTokenizer() trainer.vllm_generation = FakeVLLMGeneration() - trainer._apply_prompt_template = lambda prompts: prompts - prompt_ids, completion_ids = trainer._generate(["Solve 2+2.", "Solve 2+2."]) + prompt_ids, completion_ids = trainer._generate([[11, 12], [11, 12]]) assert prompt_ids == [[11, 12], [11, 12]] assert completion_ids == [[100], [101]] @@ -154,7 +144,7 @@ def generate(self, prompts, images, num_generations): ] trainer.model.training = False - eval_prompt_ids, eval_completion_ids = trainer._generate(["Check 3+3.", "Check 3+3.", "Check 3+3."]) + eval_prompt_ids, eval_completion_ids = trainer._generate([[21, 22], [21, 22], [21, 22]]) assert eval_prompt_ids == [[21, 22], [21, 22], [21, 22]] assert eval_completion_ids == [[100], [101], [102]] @@ -167,7 +157,7 @@ def generate(self, prompts, images, num_generations): trainer.model.training = True trainer.state.global_step = 5 - trainer._generate(["Solve 2+2.", "Solve 2+2."]) + trainer._generate([[11, 12], [11, 12]]) assert trainer.vllm_generation.sync_weights_call_count == 2 assert trainer._last_loaded_step == 5 diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py 
b/trl/experimental/self_distillation/base_self_distillation_trainer.py index 72e3f1d2ab0..ddbf816d920 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -418,6 +418,14 @@ def training_step(self, model, inputs, num_items_in_batch): self._step += 1 return output + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): + if not isinstance(inputs, dict): + inputs = self._prepare_inputs(inputs) + with torch.no_grad(): + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + return loss.detach(), None, None + def _prepare_inputs(self, generation_batch): """Return the per-step training batch, regenerating rollouts and buffering them for reuse in train mode. From 8807088cde638175d700086d36fec673b96a82ad Mon Sep 17 00:00:00 2001 From: Leon Date: Sat, 18 Apr 2026 12:38:14 +0200 Subject: [PATCH 16/23] couple more tests and test cleanup --- .../test_base_self_distillation_trainer.py | 107 +++++++++++++++--- 1 file changed, 89 insertions(+), 18 deletions(-) diff --git a/tests/experimental/test_base_self_distillation_trainer.py b/tests/experimental/test_base_self_distillation_trainer.py index 55224acc649..7c387e54b35 100644 --- a/tests/experimental/test_base_self_distillation_trainer.py +++ b/tests/experimental/test_base_self_distillation_trainer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext +from collections import defaultdict from types import SimpleNamespace from unittest.mock import Mock @@ -20,7 +20,10 @@ import torch from datasets import Dataset -from trl.experimental.self_distillation.base_self_distillation_trainer import BaseSelfDistillationTrainer +from trl.experimental.self_distillation.base_self_distillation_trainer import ( + BaseSelfDistillationTrainer, + DistillationLogits, +) from trl.experimental.self_distillation.self_distillation_config import SelfDistillationConfig from ..testing_utils import TrlTestCase @@ -61,6 +64,28 @@ def apply_chat_template(self, conversation, add_generation_prompt, tokenize, ret class TestBaseSelfDistillationTrainer(TrlTestCase): + @staticmethod + def _make_loss_test_trainer(**args_overrides): + trainer = object.__new__(MinimalSelfDistillationTrainer) + args = { + "distillation_topk": None, + "full_logit_distillation": False, + "distillation_alpha": 1.0, + "distillation_add_tail": False, + "distillation_is_clip": None, + } + args.update(args_overrides) + trainer.args = SimpleNamespace(**args) + trainer.loss_type = "dapo" + trainer.max_completion_length = 2 + trainer.accelerator = SimpleNamespace(gather=lambda tensor: tensor) + trainer._metrics = { + "train": defaultdict(list), + "eval": defaultdict(list), + } + trainer._name = "Minimal Self Distillation" + return trainer + def test_teacher_model_kind_live_uses_student_model(self): dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]}) training_args = SelfDistillationConfig( @@ -167,21 +192,67 @@ def test_prepare_inputs_reuses_buffered_generation_batches_within_window(self): assert third_batch["value"].item() == 3.0 assert trainer._prepare_training_batch.call_count == 2 - def test_prediction_step_prepares_list_batches_before_computing_loss(self): - trainer = object.__new__(MinimalSelfDistillationTrainer) - prepared_inputs = {"prompt_ids": torch.tensor([[1, 2]])} - trainer._prepare_inputs = Mock(return_value=prepared_inputs) - 
trainer.compute_loss = Mock(return_value=torch.tensor(3.5, requires_grad=True)) - trainer.compute_loss_context_manager = lambda: nullcontext() - - loss, logits, labels = trainer.prediction_step( - model=None, - inputs=[{"prompt": "Solve 2+2."}], - prediction_loss_only=True, + def test_compute_self_distillation_loss_ignores_masked_completion_tokens(self): + trainer = self._make_loss_test_trainer( + full_logit_distillation=True, + distillation_alpha=0.0, + ) + model = SimpleNamespace(training=True) + + # Token 0 is active and has a known non-zero divergence. + # Token 1 is intentionally very different but masked out, so it must not affect the loss. + student_probs = torch.tensor([[[0.8, 0.2], [0.01, 0.99]]], dtype=torch.float32) + teacher_probs = torch.tensor([[[0.5, 0.5], [0.99, 0.01]]], dtype=torch.float32) + distillation_logits = DistillationLogits( + completion_ids=torch.tensor([[0, 1]], dtype=torch.long), + completion_mask=torch.tensor([[1, 1]], dtype=torch.long), + response_mask=torch.tensor([[1, 0]], dtype=torch.long), + student_logits=student_probs.log(), + teacher_logits=teacher_probs.log(), + ) + + loss = trainer._compute_self_distillation_loss(model, {}, distillation_logits) + + expected_active_token_loss = teacher_probs[0, 0, 0] * ( + teacher_probs[0, 0, 0].log() - student_probs[0, 0, 0].log() + ) + teacher_probs[0, 0, 1] * (teacher_probs[0, 0, 1].log() - student_probs[0, 0, 1].log()) + torch.testing.assert_close(loss, expected_active_token_loss) + torch.testing.assert_close( + torch.tensor(trainer._metrics["train"]["self_distillation/distillation_loss"]), + expected_active_token_loss.unsqueeze(0), + ) + + def test_compute_self_distillation_loss_applies_importance_sampling_clip(self): + trainer = self._make_loss_test_trainer(distillation_is_clip=2.0) + model = SimpleNamespace(training=True) + + student_token_probs = torch.tensor([[0.2, 0.4]], dtype=torch.float32) + teacher_token_probs = torch.tensor([[0.5, 0.5]], dtype=torch.float32) + old_token_probs = torch.tensor([[0.05, 0.4]], dtype=torch.float32) + clip_coeff = trainer.args.distillation_is_clip + + distillation_logits = DistillationLogits( + completion_ids=torch.tensor([[0, 1]], dtype=torch.long), + completion_mask=torch.tensor([[1, 1]], dtype=torch.long), + response_mask=torch.tensor([[1, 1]], dtype=torch.long), + student_logits=torch.log(torch.tensor([[[0.2, 0.8], [0.6, 0.4]]], dtype=torch.float32)), + teacher_logits=torch.log(torch.tensor([[[0.5, 0.5], [0.5, 0.5]]], dtype=torch.float32)), + ) + + loss = trainer._compute_self_distillation_loss( + model, + {"old_per_token_logps": old_token_probs.log()}, + distillation_logits, + ) + + raw_per_token_loss = (student_token_probs.log() - teacher_token_probs.log()) * student_token_probs.log() + clipped_ratio = torch.minimum( + student_token_probs / old_token_probs, torch.full_like(student_token_probs, clip_coeff) ) + expected_loss = (raw_per_token_loss * clipped_ratio).mean() - trainer._prepare_inputs.assert_called_once() - trainer.compute_loss.assert_called_once_with(None, prepared_inputs) - assert loss.item() == 3.5 - assert logits is None - assert labels is None + torch.testing.assert_close(loss, expected_loss) + torch.testing.assert_close( + torch.tensor(trainer._metrics["train"]["self_distillation/distillation_loss"]), + expected_loss.unsqueeze(0), + ) From 0612699d18c421cc88be27ca25df37710ef20e34 Mon Sep 17 00:00:00 2001 From: Leon Date: Sat, 18 Apr 2026 12:40:57 +0200 Subject: [PATCH 17/23] test: nit fix --- .../test_base_self_distillation_trainer.py | 32 
------------------- 1 file changed, 32 deletions(-) diff --git a/tests/experimental/test_base_self_distillation_trainer.py b/tests/experimental/test_base_self_distillation_trainer.py index 7c387e54b35..f061cfbe516 100644 --- a/tests/experimental/test_base_self_distillation_trainer.py +++ b/tests/experimental/test_base_self_distillation_trainer.py @@ -14,7 +14,6 @@ from collections import defaultdict from types import SimpleNamespace -from unittest.mock import Mock import pytest import torch @@ -161,37 +160,6 @@ def test_tokenize_prompts_for_conversational_prompts_forwards_chat_template_kwar } ] - def test_prepare_inputs_reuses_buffered_generation_batches_within_window(self): - trainer = object.__new__(MinimalSelfDistillationTrainer) - trainer.model = SimpleNamespace(training=True) - trainer.args = SimpleNamespace(steps_per_generation=2) - trainer.num_iterations = 1 - trainer._step = 0 - trainer._buffered_inputs = None - trainer.callback_handler = SimpleNamespace(callbacks=[]) - trainer.state = SimpleNamespace() - trainer.control = SimpleNamespace() - trainer.processing_class = None - trainer._prepare_training_batch = Mock( - side_effect=[ - {"value": torch.tensor([[1.0], [2.0]])}, - {"value": torch.tensor([[3.0], [4.0]])}, - ] - ) - - first_batch = trainer._prepare_inputs([{"prompt": "first"}]) - - trainer._step = 1 - second_batch = trainer._prepare_inputs([{"prompt": "second"}]) - - trainer._step = 2 - third_batch = trainer._prepare_inputs([{"prompt": "third"}]) - - assert first_batch["value"].item() == 1.0 - assert second_batch["value"].item() == 2.0 - assert third_batch["value"].item() == 3.0 - assert trainer._prepare_training_batch.call_count == 2 - def test_compute_self_distillation_loss_ignores_masked_completion_tokens(self): trainer = self._make_loss_test_trainer( full_logit_distillation=True, From 3d0cd7298d04326b93d0d564dcf190bff0451411 Mon Sep 17 00:00:00 2001 From: Leon Date: Sat, 18 Apr 2026 14:18:57 +0200 Subject: [PATCH 18/23] move loss aggregation to loss_util + a few docstrings --- trl/experimental/sdpo/sdpo_trainer.py | 9 +++- .../base_self_distillation_trainer.py | 25 ++++------- .../self_distillation/loss_utils.py | 42 +++++++++++++++++++ 3 files changed, 56 insertions(+), 20 deletions(-) diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py index 98e2d023033..2bfa38acc0a 100644 --- a/trl/experimental/sdpo/sdpo_trainer.py +++ b/trl/experimental/sdpo/sdpo_trainer.py @@ -40,7 +40,7 @@ RolloutBatch, TrainingBatch, ) -from ..self_distillation.loss_utils import select_token_log_probs +from ..self_distillation.loss_utils import aggregate_loss, select_token_log_probs from ..self_distillation.prompt_utils import extract_last_user_text from .sdpo_config import SDPOConfig @@ -596,7 +596,12 @@ def _compute_policy_loss( coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) per_token_loss = -torch.min(coef_1 * advantages, coef_2 * advantages) - loss = self._aggregate_self_distillation_loss(per_token_loss, completion_mask) + loss = aggregate_loss( + per_token_loss, + completion_mask, + loss_type=self.loss_type, + max_completion_length=self.max_completion_length, + ) mode = "train" if self.model.training else "eval" self._metrics[mode]["self_distillation/policy_loss"].append( diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index ddbf816d920..bb124d16cde 100644 --- 
a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -57,6 +57,7 @@ ) from ..utils import prepare_peft_model from .loss_utils import ( + aggregate_loss, apply_importance_sampling_clipping, compute_full_logit_self_distillation_loss, compute_sampled_token_self_distillation_loss, @@ -666,7 +667,12 @@ def _compute_self_distillation_loss( self.args.distillation_is_clip, ) - loss = self._aggregate_self_distillation_loss(per_token_loss, distillation_logits.response_mask) + loss = aggregate_loss( + per_token_loss, + distillation_logits.response_mask, + loss_type=self.loss_type, + max_completion_length=self.max_completion_length, + ) mode = "train" if model.training else "eval" mean_distill_loss = ( @@ -830,23 +836,6 @@ def _log_self_distillation_metric(self, mode: str, value: float) -> None: self._metrics[mode]["self_distillation/distillation_loss"].append(value) self._metrics[mode][f"{metric_prefix}/distillation_loss"].append(value) - def _aggregate_self_distillation_loss( - self, - per_token_loss: torch.Tensor, - response_mask: torch.Tensor, - ) -> torch.Tensor: - loss_type = self.loss_type - if loss_type == "grpo": - loss = (per_token_loss * response_mask).sum(-1) / response_mask.sum(-1).clamp(min=1.0) - return loss.mean() - if loss_type == "bnpo": - return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - if loss_type == "dr_grpo": - return (per_token_loss * response_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) - if loss_type in ["dapo", "luspo", "cispo", "sapo"]: - return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) - raise ValueError(f"Unsupported loss_type for self-distillation: {loss_type}") - @abstractmethod def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """Subclasses own algorithm-specific loss composition on the final batch contract.""" diff --git a/trl/experimental/self_distillation/loss_utils.py b/trl/experimental/self_distillation/loss_utils.py index ebdf9689063..64b35ff42a1 100644 --- a/trl/experimental/self_distillation/loss_utils.py +++ b/trl/experimental/self_distillation/loss_utils.py @@ -81,6 +81,37 @@ def apply_importance_sampling_clipping( return per_token_loss * ratio +def aggregate_loss( + per_token_loss: torch.Tensor, + response_mask: torch.Tensor, + *, + loss_type: str, + max_completion_length: int, +) -> torch.Tensor: + """Reduce a per-token loss tensor according to the configured reduction. + + Args: + per_token_loss: + Per-token loss values of shape `(batch_size, seq_len)`. + response_mask: + Mask selecting which completion-token positions contribute to the loss. + loss_type: + Reduction mode. Uses the same loss-type conventions as the GRPO-family trainers. + max_completion_length: + Used by the `dr_grpo` reduction, which normalizes by a fixed completion budget. 
+ """ + if loss_type == "grpo": + loss = (per_token_loss * response_mask).sum(-1) / response_mask.sum(-1).clamp(min=1.0) + return loss.mean() + if loss_type == "bnpo": + return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + if loss_type == "dr_grpo": + return (per_token_loss * response_mask).sum() / (per_token_loss.size(0) * max_completion_length) + if loss_type in ["dapo", "luspo", "cispo", "sapo"]: + return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + raise ValueError(f"Unsupported loss_type: {loss_type}") + + def compute_topk_self_distillation_loss( student_logits: torch.Tensor, teacher_logits: torch.Tensor, @@ -89,6 +120,11 @@ def compute_topk_self_distillation_loss( distillation_alpha: float, distillation_add_tail: bool, ) -> torch.Tensor: + """Compute distillation loss on the student's top-k token support. + + The student's top-k logits define the support. The teacher distribution is projected onto the same token indices. + The selected support is then either renormalized or augmented with a tail bucket before the divergence is computed. + """ student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) topk_student_logits, topk_indices = torch.topk(student_logits, k=distillation_topk, dim=-1) topk_student_log_probs = topk_student_logits - student_logsumexp @@ -113,6 +149,7 @@ def compute_full_logit_self_distillation_loss( *, distillation_alpha: float, ) -> torch.Tensor: + """Compute full-vocabulary self-distillation loss between student and teacher logits.""" student_log_probs = torch.log_softmax(student_logits, dim=-1) teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1) return compute_divergence(student_log_probs, teacher_log_probs, distillation_alpha) @@ -125,6 +162,11 @@ def compute_sampled_token_self_distillation_loss( *, distillation_alpha: float, ) -> torch.Tensor: + """Compute token-level self-distillation loss only on the sampled completion tokens. + + This path compares student and teacher log-probabilities on the realized completion tokens rather than over a + larger token support. + """ if distillation_alpha != 1.0: raise ValueError( "Only reverse KL (alpha=1.0) is supported for token-level distillation when " From aa36955cbfda5daf701073cd0a1a47fdf03d9802 Mon Sep 17 00:00:00 2001 From: Leon Date: Mon, 20 Apr 2026 18:03:44 +0200 Subject: [PATCH 19/23] fix: emit accumulated _metrics via log() override BaseSelfDistillationTrainer was populating _metrics in _log_self_distillation_metric but had no log() override, so those metrics were never forwarded to the Trainer's logging system. The fix merges _metrics into the log dict, prefixes eval keys, and clears after each logging step. 
--- .../self_distillation/base_self_distillation_trainer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index bb124d16cde..6a99c5491ec 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -836,6 +836,15 @@ def _log_self_distillation_metric(self, mode: str, value: float) -> None: self._metrics[mode]["self_distillation/distillation_loss"].append(value) self._metrics[mode][f"{metric_prefix}/distillation_loss"].append(value) + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {k: sum(v) / len(v) for k, v in self._metrics[mode].items() if v} + if mode == "eval": + metrics = {f"eval_{k}": v for k, v in metrics.items()} + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + @abstractmethod def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """Subclasses own algorithm-specific loss composition on the final batch contract.""" From a432c208838f560618d854dc0793024cafb6a404 Mon Sep 17 00:00:00 2001 From: Leon Date: Mon, 20 Apr 2026 22:01:56 +0200 Subject: [PATCH 20/23] fix: minor cursor issues + config docstrings --- .../test_base_self_distillation_trainer.py | 7 ++ trl/experimental/sdft/sdft_trainer.py | 2 +- trl/experimental/sdpo/sdpo_trainer.py | 14 ++- .../base_self_distillation_trainer.py | 2 +- .../self_distillation_config.py | 107 +++++++++++++++--- 5 files changed, 110 insertions(+), 22 deletions(-) diff --git a/tests/experimental/test_base_self_distillation_trainer.py b/tests/experimental/test_base_self_distillation_trainer.py index f061cfbe516..880678a0e73 100644 --- a/tests/experimental/test_base_self_distillation_trainer.py +++ b/tests/experimental/test_base_self_distillation_trainer.py @@ -29,6 +29,13 @@ class MinimalSelfDistillationTrainer(BaseSelfDistillationTrainer): + def finalize_batch(self, inputs, rollout_batch): + del inputs + batch = rollout_batch.as_dict() + batch["teacher_input_ids"] = rollout_batch.prompt_ids + batch["teacher_attention_mask"] = rollout_batch.prompt_mask + return batch + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): del inputs, num_items_in_batch anchor = next(model.parameters()) diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index aaca27575ba..7aeee6d2ba6 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -183,7 +183,7 @@ def finalize_batch( rollout_batch.completion_mask, ) - batch = super().finalize_batch(inputs, rollout_batch) + batch = rollout_batch.as_dict() batch.update( { "teacher_input_ids": teacher_batch["teacher_input_ids"], diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py index 2bfa38acc0a..b0d542c2a48 100644 --- a/trl/experimental/sdpo/sdpo_trainer.py +++ b/trl/experimental/sdpo/sdpo_trainer.py @@ -445,7 +445,6 @@ def finalize_batch( rollout_dict = rollout_batch.as_dict() rollout_dict["rewards"] = local_rewards rollout_dict["advantages"] = local_advantages - rollout_dict["num_items_in_batch"] = rollout_batch.completion_mask.sum().detach() teacher_context = self.teacher_context_builder.build( rollout_dict, prompts, @@ -466,7 +465,7 @@ def finalize_batch( 
self_distillation_mask=teacher_context["self_distillation_mask"], ) - batch = super().finalize_batch(inputs, rollout_batch) + batch = rollout_batch.as_dict() batch.update( { "teacher_input_ids": teacher_context["teacher_input_ids"], @@ -644,12 +643,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if self.args.sdpo_policy_loss_mode == "hybrid": return self._compute_hybrid_loss(model, inputs) - - elif self.args.sdpo_policy_loss_mode == "distillation_only": + if self.args.sdpo_policy_loss_mode == "distillation_only": distillation_logits = self._compute_teacher_student_logits(model, self.teacher_model, inputs) return self._compute_weighted_self_distillation_loss(model, inputs, distillation_logits) - - elif self.args.sdpo_policy_loss_mode == "policy_only": + if self.args.sdpo_policy_loss_mode == "policy_only": student_logits = self._compute_student_distillation_logits( model=model, prompt_ids=inputs["prompt_ids"], @@ -659,3 +656,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N logits_to_keep=inputs["completion_ids"].size(1), ) return self._compute_policy_loss(inputs, student_logits) + + raise ValueError( + "Unsupported `sdpo_policy_loss_mode`: " + f"{self.args.sdpo_policy_loss_mode!r}. Expected one of: 'hybrid', 'distillation_only', 'policy_only'." + ) diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index 6a99c5491ec..e76157ae263 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -791,6 +791,7 @@ def _forward_logits( logits = logits[:, -logits_to_keep:, :] return logits / self.temperature + @abstractmethod def finalize_batch( self, inputs: list[dict[str, Any]], @@ -808,7 +809,6 @@ def finalize_batch( - `teacher_input_ids`: teacher input ids. - `teacher_attention_mask`: teacher attention mask. """ - return rollout_batch.as_dict() def _get_teacher_context_for_self_distillation(self): """Return the context manager that routes the teacher forward to the correct weights. diff --git a/trl/experimental/self_distillation/self_distillation_config.py b/trl/experimental/self_distillation/self_distillation_config.py index d10b5b18b9e..7efd525aec9 100644 --- a/trl/experimental/self_distillation/self_distillation_config.py +++ b/trl/experimental/self_distillation/self_distillation_config.py @@ -30,45 +30,117 @@ class SelfDistillationConfig(_BaseConfig): [`trl.trainer.base_config._BaseConfig`]. Parameters: - > Parameters that control generation and rollout reuse + > Parameters that control the model model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments used when the `model` argument is passed as a string. + Keyword arguments for model initialization when the `model` argument is passed as a string. + disable_dropout (`bool`, *optional*, defaults to `False`): + Whether to disable dropout in the student model. + remove_unused_columns (`bool`, *optional*, defaults to `False`): + Whether to drop dataset columns unused by the trainer. + + > Parameters that control generation and rollout reuse + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): Maximum prompt length. Longer prompts are truncated from the left. num_generations (`int`, *optional*, defaults to `8`): Number of sampled generations per prompt. 
+ num_generations_eval (`int` or `None`, *optional*): + Number of sampled generations per prompt during evaluation. + max_completion_length (`int` or `None`, *optional*, defaults to `256`): + Maximum generated completion length. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + Whether to gather ZeRO-3 weights for generation. + shuffle_dataset (`bool`, *optional*, defaults to `True`): + Whether to shuffle the training dataset. generation_batch_size (`int` or `None`, *optional*): Global batch size used for generation. Mutually exclusive with `steps_per_generation`. steps_per_generation (`int` or `None`, *optional*): Number of optimizer steps that reuse one generated batch. Mutually exclusive with `generation_batch_size`. + > Parameters that control sampling + + temperature (`float`, *optional*, defaults to `1.0`): + Sampling temperature. + top_p (`float`, *optional*, defaults to `1.0`): + Top-p sampling parameter. + top_k (`int`, *optional*, defaults to `0`): + Top-k sampling parameter. `0` disables top-k filtering. + min_p (`float` or `None`, *optional*): + Minimum token probability for sampling. + generation_kwargs (`dict[str, Any]` or `None`, *optional*): + Extra generation kwargs passed to `GenerationConfig`. + chat_template_kwargs (`dict[str, Any]` or `None`, *optional*): + Extra kwargs forwarded to chat template application. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Repetition penalty used during generation. + use_transformers_paged (`bool`, *optional*, defaults to `False`): + Reserved for paged generation support. + cache_implementation (`str` or `None`, *optional*): + Cache implementation used by transformers generation. + + > Parameters that control vLLM generation + + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generation. + vllm_mode (`str`, *optional*, defaults to `"colocate"`): + vLLM mode: `"colocate"` (shared GPU) or `"server"` (separate vLLM server). + vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): + Model implementation for vLLM: `"vllm"` or `"transformers"`. + vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): + Whether to enable sleep mode for colocated vLLM engine. + vllm_server_base_url (`str` or `None`, *optional*): + Base URL for the vLLM server. If provided, `vllm_server_host` and `vllm_server_port` are ignored. + vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host of the vLLM server (server mode only). + vllm_server_port (`int`, *optional*, defaults to `8000`): + Port of the vLLM server (server mode only). + vllm_group_port (`int`, *optional*, defaults to `51216`): + Port for the weight update group (server mode only). + vllm_server_timeout (`float`, *optional*, defaults to `240.0`): + Timeout in seconds to wait for the vLLM server. + vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): + Tensor parallel size for colocated vLLM. + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`): + GPU memory utilization ratio for colocated vLLM. + vllm_max_model_length (`int` or `None`, *optional*): + Model context length for vLLM. Inferred from model config if not set. + > Parameters that control the online policy objective + num_iterations (`int`, *optional*, defaults to `1`): + Number of optimization iterations per generated batch. loss_type (`str`, *optional*, defaults to `"dapo"`): Policy-loss aggregation mode. Supported: `grpo`, `bnpo`, `dr_grpo`, `dapo`. 
+ mask_truncated_completions (`bool`, *optional*, defaults to `False`): + Whether to exclude truncated completions from the loss. + top_entropy_quantile (`float`, *optional*, defaults to `1.0`): + Reserved for entropy-based token filtering. > Parameters that control teacher construction teacher_model_kind (`str`, *optional*, defaults to `"live"`): - Semantic teacher choice. `live` uses the current student, `base` uses the student as it existed at the - start of training, and `ema` uses an exponentially averaged teacher. + Semantic teacher choice. `live` uses the current student, `base` uses the initial student, and `ema` uses + an exponentially averaged teacher. teacher_update_rate (`float`, *optional*, defaults to `0.6`): - EMA update rate used when `teacher_model_kind="ema"`. + EMA update rate used when `teacher_model_kind="ema"`. A value of `1.0` reduces the update to a hard + overwrite, periodically resyncing the teacher to the current student weights. teacher_sync_steps (`int`, *optional*, defaults to `512`): - Number of optimizer steps between EMA teacher updates. + Number of optimizer steps between teacher updates. > Parameters that control self-distillation distillation_alpha (`float`, *optional*, defaults to `0.5`): - Divergence interpolation coefficient using the official SDPO/SDFT convention: `0.0=forward KL`, `0.5=JSD`, + KL divergence direction: `0.0=forward KL`, `0.5=JSD`, `1.0=reverse KL`. distillation_topk (`int` or `None`, *optional*, defaults to `100`): - Number of top tokens to keep for top-k distillation. If `None`, all logits are used. + Number of top tokens for top-k distillation. If `None`, uses all tokens. full_logit_distillation (`bool`, *optional*, defaults to `False`): Whether to use full-logit distillation instead of token-level distillation. distillation_is_clip (`float` or `None`, *optional*, defaults to `2.0`): - Importance-sampling clip used by the official SDPO-style correction. `None` disables clipping. + Clipping coefficient for importance sampling in self-distillation. `None` disables clipping. + distillation_add_tail (`bool`, *optional*, defaults to `False`): + Whether to add a tail bucket for non-top-k probability mass. distillation_weight (`float`, *optional*, defaults to `1.0`): Weight applied to the self-distillation loss term. @@ -124,7 +196,9 @@ class SelfDistillationConfig(_BaseConfig): ) steps_per_generation: int | None = field( default=None, - metadata={"help": "Number of optimizer steps that reuse one generated batch."}, + metadata={ + "help": "Number of optimizer steps that reuse one generated batch. Mutually exclusive with `generation_batch_size`." + }, ) temperature: float = field( default=1.0, @@ -229,11 +303,14 @@ class SelfDistillationConfig(_BaseConfig): ) teacher_update_rate: float = field( default=0.6, - metadata={"help": "EMA update rate used when synchronizing the teacher model."}, + metadata={ + "help": 'EMA update rate used when `teacher_model_kind="ema"`. A value of `1.0` reduces the update ' + "to a hard overwrite, periodically resyncing the teacher to the current student weights." + }, ) teacher_sync_steps: int = field( default=512, - metadata={"help": "How often to synchronize the teacher model."}, + metadata={"help": "Number of optimizer steps between teacher updates."}, ) top_entropy_quantile: float = field( default=1.0, @@ -245,7 +322,7 @@ class SelfDistillationConfig(_BaseConfig): ) distillation_topk: int | None = field( default=100, - metadata={"help": "Number of top tokens for top-k distillation. 
If None, uses all tokens."}, + metadata={"help": "Number of top tokens for top-k distillation. If `None`, uses all tokens."}, ) full_logit_distillation: bool = field( default=False, @@ -253,7 +330,9 @@ class SelfDistillationConfig(_BaseConfig): ) distillation_is_clip: float | None = field( default=2.0, - metadata={"help": "Clipping coefficient for importance sampling in self-distillation."}, + metadata={ + "help": "Clipping coefficient for importance sampling in self-distillation. `None` disables clipping." + }, ) distillation_add_tail: bool = field( default=False, From e30ca046073a2431c5ca74f7cdd8788a28aa4d96 Mon Sep 17 00:00:00 2001 From: Leon Date: Tue, 21 Apr 2026 09:24:52 +0200 Subject: [PATCH 21/23] fix: rename full logit distillation+topk into explicit flags --- docs/source/paper_index.md | 5 ++- docs/source/sdft_trainer.md | 1 + docs/source/sdpo_trainer.md | 11 ++--- .../test_base_self_distillation_trainer.py | 4 +- tests/experimental/test_sdpo_trainer.py | 2 +- trl/experimental/sdpo/sdpo.py | 2 +- trl/experimental/sdpo/sdpo_config.py | 32 +++++++++++---- .../base_self_distillation_trainer.py | 23 +++++++---- .../self_distillation/loss_utils.py | 2 +- .../self_distillation_config.py | 41 +++++++++++++------ 10 files changed, 82 insertions(+), 41 deletions(-) diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 4c78c9d3bb5..23893c47e02 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -1641,8 +1641,8 @@ from trl.experimental.sdpo import SDPOConfig, SDPOTrainer training_args = SDPOConfig( distillation_alpha=0.5, # Jensen-Shannon divergence (recommended) - distillation_topk=100, # Top-K logit distillation approximation - full_logit_distillation=True, # Required for top-K logit-level SDPO + distillation_mode="topk_logits", # Explicitly select top-K logit distillation + distillation_topk=100, # Required for top-K logit distillation distillation_is_clip=2.0, # Importance sampling clipping distillation_weight=1.0, # Weight for self-distillation loss sdpo_policy_loss_mode="distillation_only", @@ -1689,6 +1689,7 @@ dataset = Dataset.from_dict( training_args = SDFTConfig( distillation_alpha=0.5, + distillation_mode="topk_logits", distillation_topk=5, max_completion_length=64, ) diff --git a/docs/source/sdft_trainer.md b/docs/source/sdft_trainer.md index 17f6433fd78..23f29966458 100644 --- a/docs/source/sdft_trainer.md +++ b/docs/source/sdft_trainer.md @@ -32,6 +32,7 @@ dataset = Dataset.from_dict( training_args = SDFTConfig( output_dir="sdft-model", distillation_alpha=0.5, + distillation_mode="topk_logits", distillation_topk=5, max_completion_length=64, ) diff --git a/docs/source/sdpo_trainer.md b/docs/source/sdpo_trainer.md index 3910f674baf..7f82f565bcb 100644 --- a/docs/source/sdpo_trainer.md +++ b/docs/source/sdpo_trainer.md @@ -11,8 +11,9 @@ In the current TRL implementation: - the default SDPO policy loss mode is `distillation_only` - `hybrid` mode is also available to combine the base policy loss with the self-distillation loss - supported teacher regularization modes are `ema` and `none` -- `distillation_topk` is only valid when `full_logit_distillation=True` -- when `full_logit_distillation=False`, SDPO uses token-level reverse KL and requires `distillation_alpha=1.0` +- `distillation_mode` selects between `sampled_token`, `full_logits`, and `topk_logits` +- `distillation_topk` is only valid when `distillation_mode="topk_logits"` +- when `distillation_mode="sampled_token"`, SDPO uses token-level reverse KL and requires 
`distillation_alpha=1.0` - environment feedback can be injected into teacher reprompts when the dataset exposes a `privileged_context` column ## Expected dataset columns @@ -38,8 +39,8 @@ dataset = Dataset.from_dict( training_args = SDPOConfig( output_dir="sdpo-model", - distillation_topk=100, # Top-K logit distillation approximation - full_logit_distillation=True, # Required for top-K; enables non-reverse divergences + distillation_mode="topk_logits", # Explicitly select top-K logit distillation + distillation_topk=100, # Required when using top-K logit distillation include_environment_feedback=True, # Use dataset privileged_context for teacher reprompts ) @@ -88,7 +89,7 @@ python trl/experimental/sdpo/sdpo.py \ --num_generations 8 \ --generation_batch_size 32 \ --distillation_alpha 1.0 \ - --full_logit_distillation false \ + --distillation_mode sampled_token \ --sdpo_policy_loss_mode hybrid \ --report_to none \ --eval_strategy steps \ diff --git a/tests/experimental/test_base_self_distillation_trainer.py b/tests/experimental/test_base_self_distillation_trainer.py index 880678a0e73..3846bbe8c03 100644 --- a/tests/experimental/test_base_self_distillation_trainer.py +++ b/tests/experimental/test_base_self_distillation_trainer.py @@ -74,8 +74,8 @@ class TestBaseSelfDistillationTrainer(TrlTestCase): def _make_loss_test_trainer(**args_overrides): trainer = object.__new__(MinimalSelfDistillationTrainer) args = { + "distillation_mode": "sampled_token", "distillation_topk": None, - "full_logit_distillation": False, "distillation_alpha": 1.0, "distillation_add_tail": False, "distillation_is_clip": None, @@ -169,7 +169,7 @@ def test_tokenize_prompts_for_conversational_prompts_forwards_chat_template_kwar def test_compute_self_distillation_loss_ignores_masked_completion_tokens(self): trainer = self._make_loss_test_trainer( - full_logit_distillation=True, + distillation_mode="full_logits", distillation_alpha=0.0, ) model = SimpleNamespace(training=True) diff --git a/tests/experimental/test_sdpo_trainer.py b/tests/experimental/test_sdpo_trainer.py index d800968966f..81a29af4822 100644 --- a/tests/experimental/test_sdpo_trainer.py +++ b/tests/experimental/test_sdpo_trainer.py @@ -171,8 +171,8 @@ def test_training(self): per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=8, # reduce the completion length to reduce memory usage + distillation_mode="topk_logits", distillation_topk=5, - full_logit_distillation=True, distillation_is_clip=None, ) trainer = SDPOTrainer( diff --git a/trl/experimental/sdpo/sdpo.py b/trl/experimental/sdpo/sdpo.py index 7a65a00cc5d..cd752cc6855 100644 --- a/trl/experimental/sdpo/sdpo.py +++ b/trl/experimental/sdpo/sdpo.py @@ -43,7 +43,7 @@ --num_generations 8 \ --generation_batch_size 32 \ --distillation_alpha 1.0 \ - --full_logit_distillation false \ + --distillation_mode sampled_token \ --sdpo_policy_loss_mode hybrid \ --report_to none \ --eval_strategy steps \ diff --git a/trl/experimental/sdpo/sdpo_config.py b/trl/experimental/sdpo/sdpo_config.py index e4c0e6e030d..a45d2aab9da 100644 --- a/trl/experimental/sdpo/sdpo_config.py +++ b/trl/experimental/sdpo/sdpo_config.py @@ -13,6 +13,7 @@ # limitations under the License. 
from dataclasses import dataclass, field +from typing import Literal from ..self_distillation import SelfDistillationConfig @@ -47,10 +48,14 @@ class SDPOConfig(SelfDistillationConfig): How SDPO combines the online policy loss and self-distillation loss. Supported: `distillation_only`, `policy_only`, `hybrid`. distillation_alpha (`float`, *optional*, defaults to `1.0`): - Divergence interpolation coefficient. Token-level SDPO requires the official reverse-KL setting + Divergence interpolation coefficient. Sampled-token SDPO requires the official reverse-KL setting + `distillation_alpha=1.0`. + distillation_mode (`Literal["sampled_token", "full_logits", "topk_logits"]`, *optional*, defaults to `"sampled_token"`): + Distillation objective mode. `"sampled_token"` is the default SDPO mode and requires `distillation_alpha=1.0`. distillation_topk (`int` or `None`, *optional*): - Top-k approximation for logit-level SDPO. Requires `full_logit_distillation=True`. + Top-k approximation for logit-level SDPO. Must be set when `distillation_mode="topk_logits"` and left + unset otherwise. > Parameters that control the teacher @@ -103,12 +108,23 @@ class SDPOConfig(SelfDistillationConfig): distillation_alpha: float = field( default=1.0, metadata={ - "help": "KL divergence direction for SDPO. Token-level SDPO requires reverse KL (`distillation_alpha=1.0`)." + "help": "Divergence interpolation coefficient. Sampled-token SDPO requires the official reverse-KL setting " + "`distillation_alpha=1.0`." + }, + ) + distillation_mode: Literal["sampled_token", "full_logits", "topk_logits"] = field( + default="sampled_token", + metadata={ + "help": "Distillation objective mode. `sampled_token` is the default SDPO mode and requires " + "`distillation_alpha=1.0`." }, ) distillation_topk: int | None = field( default=None, - metadata={"help": "Top-K approximation for logit-level SDPO. Requires `full_logit_distillation=True`."}, + metadata={ + "help": "Top-k approximation for logit-level SDPO. Must be set when `distillation_mode=topk_logits` and left " + "unset otherwise." + }, ) sdpo_policy_loss_mode: str = field( default="distillation_only", @@ -187,10 +203,8 @@ def __post_init__(self): raise ValueError("hybrid mode requires `distillation_weight > 0`.") if self.max_reprompt_len <= 0: raise ValueError("max_reprompt_len must be positive") - if not self.full_logit_distillation and self.distillation_alpha != 1.0: + if self.distillation_mode == "sampled_token" and self.distillation_alpha != 1.0: raise ValueError( - "SDPO token-level distillation requires `distillation_alpha=1.0`. " - "Set `full_logit_distillation=True` to use other divergence settings." + "SDPO sampled-token distillation requires `distillation_alpha=1.0`. " + "Set `distillation_mode='full_logits'` or `distillation_mode='topk_logits'` to use other divergence settings." ) - if self.distillation_topk is not None and not self.full_logit_distillation: - raise ValueError("SDPO `distillation_topk` requires `full_logit_distillation=True`.") diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index e76157ae263..f58563bafcd 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -616,12 +616,12 @@ def _compute_self_distillation_loss( ) -> torch.Tensor: """Compute the per-token distillation loss and aggregate it according to `loss_type`. 
- Dispatches between three objectives based on config: + Dispatches between three objectives based on `distillation_mode`: - - `distillation_topk` is not `None`: top-k approximation of the divergence, optionally with a tail - bucket for the remaining probability mass (`distillation_add_tail`). - - `full_logit_distillation` is `True`: full-vocab divergence. - - otherwise: token-level (reverse-KL) distillation on sampled `completion_ids`. + - `"topk_logits"`: top-k approximation of the divergence, optionally with a tail bucket for the + remaining probability mass (`distillation_add_tail`). + - `"full_logits"`: full-vocab divergence. + - `"sampled_token"`: token-level (reverse-KL) distillation on sampled `completion_ids`. When `distillation_is_clip` is set and `old_per_token_logps` are available, the loss is corrected by a clipped importance-sampling ratio between the current student and the student at rollout time. @@ -632,7 +632,9 @@ def _compute_self_distillation_loss( # Keep the zero loss attached to the student graph so backward produces zero gradients instead of stopping. return distillation_logits.student_logits.sum() * 0.0 - if self.args.distillation_topk is not None: + if self.args.distillation_mode == "topk_logits": + if self.args.distillation_topk is None: + raise ValueError("`distillation_mode='topk_logits'` requires `distillation_topk` to be set.") per_token_loss = compute_topk_self_distillation_loss( distillation_logits.student_logits, distillation_logits.teacher_logits, @@ -640,19 +642,24 @@ def _compute_self_distillation_loss( distillation_alpha=self.args.distillation_alpha, distillation_add_tail=self.args.distillation_add_tail, ) - elif self.args.full_logit_distillation: + elif self.args.distillation_mode == "full_logits": per_token_loss = compute_full_logit_self_distillation_loss( distillation_logits.student_logits, distillation_logits.teacher_logits, distillation_alpha=self.args.distillation_alpha, ) - else: + elif self.args.distillation_mode == "sampled_token": per_token_loss = compute_sampled_token_self_distillation_loss( distillation_logits.student_logits, distillation_logits.teacher_logits, distillation_logits.completion_ids, distillation_alpha=self.args.distillation_alpha, ) + else: + raise ValueError( + "distillation_mode must be one of: 'sampled_token', 'full_logits', 'topk_logits', " + f"got {self.args.distillation_mode!r}" + ) old_per_token_logps = inputs.get("old_per_token_logps") if self.args.distillation_is_clip is not None and old_per_token_logps is not None: diff --git a/trl/experimental/self_distillation/loss_utils.py b/trl/experimental/self_distillation/loss_utils.py index 64b35ff42a1..ac74a751099 100644 --- a/trl/experimental/self_distillation/loss_utils.py +++ b/trl/experimental/self_distillation/loss_utils.py @@ -170,7 +170,7 @@ def compute_sampled_token_self_distillation_loss( if distillation_alpha != 1.0: raise ValueError( "Only reverse KL (alpha=1.0) is supported for token-level distillation when " - f"`full_logit_distillation=False`, got alpha={distillation_alpha}" + f"`distillation_mode='sampled_token'`, got alpha={distillation_alpha}" ) student_per_token_logps = select_token_log_probs(student_logits, completion_ids) diff --git a/trl/experimental/self_distillation/self_distillation_config.py b/trl/experimental/self_distillation/self_distillation_config.py index 7efd525aec9..4fe20c00b95 100644 --- a/trl/experimental/self_distillation/self_distillation_config.py +++ b/trl/experimental/self_distillation/self_distillation_config.py @@ -13,7 +13,7 @@ # 
limitations under the License. from dataclasses import dataclass, field -from typing import Any +from typing import Any, Literal from transformers import TrainingArguments @@ -133,10 +133,13 @@ class SelfDistillationConfig(_BaseConfig): distillation_alpha (`float`, *optional*, defaults to `0.5`): KL divergence direction: `0.0=forward KL`, `0.5=JSD`, `1.0=reverse KL`. - distillation_topk (`int` or `None`, *optional*, defaults to `100`): - Number of top tokens for top-k distillation. If `None`, uses all tokens. - full_logit_distillation (`bool`, *optional*, defaults to `False`): - Whether to use full-logit distillation instead of token-level distillation. + distillation_mode (`Literal["sampled_token", "full_logits", "topk_logits"]`, *optional*, defaults to `"sampled_token"`): + Distillation objective mode. `"sampled_token"` uses token-level distillation on the sampled completion + tokens, `"full_logits"` uses full-vocabulary divergence, and `"topk_logits"` uses a top-k approximation + over the student support. + distillation_topk (`int` or `None`, *optional*): + Number of top tokens for `"topk_logits"` distillation. Must be set when + `distillation_mode="topk_logits"` and left unset otherwise. distillation_is_clip (`float` or `None`, *optional*, defaults to `2.0`): Clipping coefficient for importance sampling in self-distillation. `None` disables clipping. distillation_add_tail (`bool`, *optional*, defaults to `False`): @@ -318,15 +321,22 @@ class SelfDistillationConfig(_BaseConfig): ) distillation_alpha: float = field( default=0.5, - metadata={"help": "KL divergence direction: 0.0=forward KL, 0.5=JSD, 1.0=reverse KL."}, + metadata={"help": "KL divergence direction: `0.0=forward KL`, `0.5=JSD`, `1.0=reverse KL`."}, ) - distillation_topk: int | None = field( - default=100, - metadata={"help": "Number of top tokens for top-k distillation. If `None`, uses all tokens."}, + distillation_mode: Literal["sampled_token", "full_logits", "topk_logits"] = field( + default="sampled_token", + metadata={ + "help": "Distillation objective mode. `sampled_token` uses token-level distillation on the sampled " + "completion tokens, `full_logits` uses full-vocabulary divergence, and `topk_logits` uses a top-k " + "approximation over the student support." + }, ) - full_logit_distillation: bool = field( - default=False, - metadata={"help": "Whether to use full-logit distillation instead of token-level distillation."}, + distillation_topk: int | None = field( + default=None, + metadata={ + "help": "Number of top tokens for `topk_logits` distillation. Must be set when " + "`distillation_mode='topk_logits'` and left unset otherwise." 
+ }, ) distillation_is_clip: float | None = field( default=2.0, @@ -370,8 +380,15 @@ def __post_init__(self): raise ValueError("num_generations must be at least 1") if not 0.0 <= self.distillation_alpha <= 1.0: raise ValueError("distillation_alpha must be in [0, 1]") + if self.distillation_mode not in {"sampled_token", "full_logits", "topk_logits"}: + raise ValueError("distillation_mode must be one of: 'sampled_token', 'full_logits', 'topk_logits'") if self.distillation_topk is not None and self.distillation_topk <= 0: raise ValueError("distillation_topk must be positive when provided") + if self.distillation_mode == "topk_logits": + if self.distillation_topk is None: + raise ValueError("`distillation_mode='topk_logits'` requires `distillation_topk` to be set.") + elif self.distillation_topk is not None: + raise ValueError("`distillation_topk` is only valid when `distillation_mode='topk_logits'`.") if self.distillation_is_clip is not None and self.distillation_is_clip <= 0: raise ValueError("distillation_is_clip must be positive when provided") if self.distillation_weight < 0: From 3a9ecb28e0ce72220a384a3cbb7730cef7e7f97b Mon Sep 17 00:00:00 2001 From: Leon Date: Tue, 21 Apr 2026 19:32:42 +0200 Subject: [PATCH 22/23] fix(self-distillation): warn on preloaded peft students --- .../test_base_self_distillation_trainer.py | 43 +++++++++++++++++++ .../base_self_distillation_trainer.py | 8 ++++ 2 files changed, 51 insertions(+) diff --git a/tests/experimental/test_base_self_distillation_trainer.py b/tests/experimental/test_base_self_distillation_trainer.py index 3846bbe8c03..bd7c7e880b1 100644 --- a/tests/experimental/test_base_self_distillation_trainer.py +++ b/tests/experimental/test_base_self_distillation_trainer.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging from collections import defaultdict from types import SimpleNamespace import pytest import torch from datasets import Dataset +from transformers import AutoModelForCausalLM +from transformers.utils import is_peft_available from trl.experimental.self_distillation.base_self_distillation_trainer import ( BaseSelfDistillationTrainer, @@ -28,6 +31,10 @@ from ..testing_utils import TrlTestCase +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + class MinimalSelfDistillationTrainer(BaseSelfDistillationTrainer): def finalize_batch(self, inputs, rollout_batch): del inputs @@ -112,6 +119,42 @@ def test_teacher_model_kind_live_uses_student_model(self): assert trainer.teacher_model is trainer.model + @pytest.mark.skipif(not is_peft_available(), reason="PEFT is required for this test") + def test_warns_when_initial_student_already_has_a_peft_adapter(self, caplog): + dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]}) + training_args = SelfDistillationConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=1, + num_generations=1, + teacher_model_kind="base", + report_to="none", + ) + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + model = get_peft_model( + model, + LoraConfig( + r=4, + lora_alpha=8, + target_modules=["q_proj", "v_proj"], + bias="none", + task_type="CAUSAL_LM", + ), + ) + + with caplog.at_level( + logging.WARNING, logger="trl.experimental.self_distillation.base_self_distillation_trainer" + ): + MinimalSelfDistillationTrainer( + model=model, + args=training_args, + train_dataset=dataset, + ) + + assert "already contains a PEFT adapter" in caplog.text + assert "`teacher_model_kind='base'` may refer to the underlying base weights" in caplog.text + @pytest.mark.parametrize("teacher_model_kind", ["base", "ema"]) def test_teacher_model_kind_base_and_ema_use_frozen_teacher_copy(self, teacher_model_kind): dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]}) diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index f58563bafcd..95fe72669c3 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -153,6 +153,14 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) + if peft_config is None and getattr(model, "peft_config", None) is not None: + logger.warning( + "The provided self-distillation student model already contains a PEFT adapter. " + "This setup is accepted but not directly supported. In particular, `teacher_model_kind='base'` " + "may refer to the underlying base weights rather than the exact initially loaded student state " + "including its adapter. For unambiguous teacher behavior, start from a merged/non-adapter model " + "or manage separate adapters explicitly." + ) if is_peft_model(model) and peft_config is not None: raise ValueError( "You passed a `PeftModel` instance together with a `peft_config`. 
Pass either a base "

From 03718eb682c1569b3359825976b5a9cedf0ab69 Mon Sep 17 00:00:00 2001
From: Leon
Date: Wed, 22 Apr 2026 07:47:39 +0200
Subject: [PATCH 23/23] docs: cleanup

---
 trl/experimental/sdft/sdft_config.py  | 5 +----
 trl/experimental/sdft/sdft_trainer.py | 2 +-
 trl/experimental/sdpo/sdpo_config.py  | 3 ---
 trl/experimental/sdpo/sdpo_trainer.py | 2 +-
 4 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/trl/experimental/sdft/sdft_config.py b/trl/experimental/sdft/sdft_config.py
index 08d4f64272a..21ec6b2a814 100644
--- a/trl/experimental/sdft/sdft_config.py
+++ b/trl/experimental/sdft/sdft_config.py
@@ -20,10 +20,7 @@
 @dataclass
 class SDFTConfig(SelfDistillationConfig):
     r"""
-    Configuration class for [`SDFTTrainer`].
-
-    This adapts the official SDFT implementation to the TRL trainer API while reusing the common self-distillation
-    configuration shared with SDPO.
+    Configuration class for [`SDFTTrainer`].

     Parameters:
         disable_dropout (`bool`, *optional*, defaults to `True`):
diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py
index 7aeee6d2ba6..24959a36bc0 100644
--- a/trl/experimental/sdft/sdft_trainer.py
+++ b/trl/experimental/sdft/sdft_trainer.py
@@ -47,7 +47,7 @@
 class DemonstrationTeacherContextBuilder:
-    """Builds student and teacher contexts from prompts plus privileged context, as in SDFT."""
+    """Builds student and teacher contexts from prompts plus privileged context."""

     def __init__(self, trainer):
         self.trainer = trainer
diff --git a/trl/experimental/sdpo/sdpo_config.py b/trl/experimental/sdpo/sdpo_config.py
index a45d2aab9da..92c4a181f39 100644
--- a/trl/experimental/sdpo/sdpo_config.py
+++ b/trl/experimental/sdpo/sdpo_config.py
@@ -23,9 +23,6 @@
 class SDPOConfig(SelfDistillationConfig):
     r"""
     Configuration class for the [`SDPOTrainer`].

-    This class extends [`experimental.self_distillation.SelfDistillationConfig`] with the online teacher-construction
-    parameters used by Self-Distillation Policy Optimization (SDPO).
-
     Parameters:

         > Parameters that control the online policy objective
diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py
index b0d542c2a48..b0d3afc4e08 100644
--- a/trl/experimental/sdpo/sdpo_trainer.py
+++ b/trl/experimental/sdpo/sdpo_trainer.py
@@ -49,7 +49,7 @@
 class SuccessfulRolloutTeacherContextBuilder:
-    """Builds SDPO teacher contexts from successful rollouts, following the official online implementation."""
+    """Builds teacher contexts from successful rollouts."""

     def __init__(self, trainer):
         self.trainer = trainer
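
For reference, the `distillation_mode` contract introduced in PATCH 21 and
validated by the patches above can be exercised as follows (a minimal sketch
with illustrative values, not an excerpt from the repository):

    from trl.experimental.sdpo import SDPOConfig

    # Top-k logit-level SDPO: `distillation_mode="topk_logits"` requires an
    # explicit `distillation_topk`; leaving it unset raises a ValueError in
    # __post_init__, as does setting `distillation_topk` in any other mode.
    config = SDPOConfig(
        output_dir="sdpo-model",
        distillation_mode="topk_logits",
        distillation_topk=100,
    )

    # The default sampled-token mode keeps the reverse-KL constraint: pairing
    # `distillation_mode="sampled_token"` with `distillation_alpha != 1.0`
    # raises a ValueError.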