23 commits
06f02a8
v0.1 transition sdft into unified base
LeonEricsson Apr 15, 2026
be1bcbc
sdft transition v1 complete, starting on sdpo
LeonEricsson Apr 15, 2026
0628701
sdpo transitioned, needs testing
LeonEricsson Apr 15, 2026
55111ff
remove legacy trainers
LeonEricsson Apr 15, 2026
81def8a
sdft and sdpo transitioned and tested with new base
LeonEricsson Apr 16, 2026
bad6b62
restructure training batch builder
LeonEricsson Apr 16, 2026
ef43c95
nits
LeonEricsson Apr 16, 2026
efe0eda
wip removing mixin
LeonEricsson Apr 16, 2026
fa1a8f3
remove mixin, refactoring and cleanup
LeonEricsson Apr 16, 2026
6a7d5a8
always set teacher_model
LeonEricsson Apr 16, 2026
56b2fd1
align generation tokenization with grpotrainer
LeonEricsson Apr 16, 2026
4a9d527
fix: generation_kwargs bug
LeonEricsson Apr 16, 2026
196feee
fix: incorrect import source
LeonEricsson Apr 16, 2026
3c87400
fixes: cleanup, standardized tokenization, distill loss=0 fix, sdpo c…
LeonEricsson Apr 17, 2026
d2a78e2
tests: ported old tests + new tests for base class
LeonEricsson Apr 17, 2026
8807088
couple more tests and test cleanup
LeonEricsson Apr 18, 2026
0612699
test: nit fix
LeonEricsson Apr 18, 2026
3d0cd72
move loss aggregation to loss_util + a few docstrings
LeonEricsson Apr 18, 2026
aa36955
fix: emit accumulated _metrics via log() override
LeonEricsson Apr 20, 2026
a432c20
fix: minor cursor issues + config docstrings
LeonEricsson Apr 20, 2026
e30ca04
fix: rename full logit distillation+topk into explicit flags
LeonEricsson Apr 21, 2026
3a9ecb2
fix(self-distillation): warn on preloaded peft students
LeonEricsson Apr 21, 2026
03718eb
docs: cleanup
LeonEricsson Apr 22, 2026
5 changes: 3 additions & 2 deletions docs/source/paper_index.md
@@ -1641,8 +1641,8 @@ from trl.experimental.sdpo import SDPOConfig, SDPOTrainer

 training_args = SDPOConfig(
     distillation_alpha=0.5, # Jensen-Shannon divergence (recommended)
-    distillation_topk=100, # Top-K logit distillation approximation
-    full_logit_distillation=True, # Required for top-K logit-level SDPO
+    distillation_mode="topk_logits", # Explicitly select top-K logit distillation
+    distillation_topk=100, # Required for top-K logit distillation
     distillation_is_clip=2.0, # Importance sampling clipping
     distillation_weight=1.0, # Weight for self-distillation loss
     sdpo_policy_loss_mode="distillation_only",
@@ -1689,6 +1689,7 @@ dataset = Dataset.from_dict(

 training_args = SDFTConfig(
     distillation_alpha=0.5,
+    distillation_mode="topk_logits",
     distillation_topk=5,
     max_completion_length=64,
 )
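For context on the flag rename in the hunks above: `distillation_mode="topk_logits"` approximates full-logit distillation by matching student and teacher only on the teacher's top-K vocabulary entries at each position, with `distillation_topk` setting K. A minimal sketch of the idea, with hypothetical names rather than the trainer's actual helper (it renormalizes over the top-K slice; the config's `distillation_add_tail` flag suggests the real implementation can also account for the leftover tail mass):

```python
import torch
import torch.nn.functional as F


def topk_logit_distillation(student_logits, teacher_logits, k=100):
    # Per position, keep only the teacher's K most likely vocabulary entries.
    teacher_topk, topk_idx = teacher_logits.topk(k, dim=-1)
    student_topk = student_logits.gather(-1, topk_idx)
    # Renormalize both distributions over the selected slice.
    teacher_logp = F.log_softmax(teacher_topk, dim=-1)
    student_logp = F.log_softmax(student_topk, dim=-1)
    # Forward KL(teacher || student) over the slice (distillation_alpha=0.0 in
    # the config's parametrization), averaged over positions.
    return (teacher_logp.exp() * (teacher_logp - student_logp)).sum(-1).mean()
```

Intermediate `distillation_alpha` values interpolate toward other divergences; the config example above recommends 0.5 (Jensen-Shannon).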
1 change: 1 addition & 0 deletions docs/source/sdft_trainer.md
@@ -32,6 +32,7 @@ dataset = Dataset.from_dict(
 training_args = SDFTConfig(
     output_dir="sdft-model",
     distillation_alpha=0.5,
+    distillation_mode="topk_logits",
     distillation_topk=5,
     max_completion_length=64,
 )
11 changes: 6 additions & 5 deletions docs/source/sdpo_trainer.md
@@ -11,8 +11,9 @@ In the current TRL implementation:
 - the default SDPO policy loss mode is `distillation_only`
 - `hybrid` mode is also available to combine the base policy loss with the self-distillation loss
 - supported teacher regularization modes are `ema` and `none`
-- `distillation_topk` is only valid when `full_logit_distillation=True`
-- when `full_logit_distillation=False`, SDPO uses token-level reverse KL and requires `distillation_alpha=1.0`
+- `distillation_mode` selects between `sampled_token`, `full_logits`, and `topk_logits`
+- `distillation_topk` is only valid when `distillation_mode="topk_logits"`
+- when `distillation_mode="sampled_token"`, SDPO uses token-level reverse KL and requires `distillation_alpha=1.0`
 - environment feedback can be injected into teacher reprompts when the dataset exposes a `privileged_context` column

 ## Expected dataset columns
@@ -38,8 +39,8 @@ dataset = Dataset.from_dict(

 training_args = SDPOConfig(
     output_dir="sdpo-model",
-    distillation_topk=100, # Top-K logit distillation approximation
-    full_logit_distillation=True, # Required for top-K; enables non-reverse divergences
+    distillation_mode="topk_logits", # Explicitly select top-K logit distillation
+    distillation_topk=100, # Required when using top-K logit distillation
     include_environment_feedback=True, # Use dataset privileged_context for teacher reprompts
 )

@@ -88,7 +89,7 @@ python trl/experimental/sdpo/sdpo.py \
     --num_generations 8 \
     --generation_batch_size 32 \
     --distillation_alpha 1.0 \
-    --full_logit_distillation false \
+    --distillation_mode sampled_token \
     --sdpo_policy_loss_mode hybrid \
     --report_to none \
     --eval_strategy steps \
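As the feature list above notes, `sampled_token` mode only sees the log-probabilities of the tokens the student actually generated, not full distributions, so the only divergence it can estimate is token-level reverse KL — hence the `distillation_alpha=1.0` requirement. A rough sketch of that estimator (hypothetical helper and masking convention, not the trainer's exact code):

```python
import torch


def sampled_token_reverse_kl(student_logps, teacher_logps, response_mask):
    # Monte Carlo estimate of reverse KL(student || teacher) using only the
    # sampled tokens: E_{y ~ student}[log p_student(y) - log p_teacher(y)].
    per_token = student_logps - teacher_logps
    # Average over unmasked response positions.
    return (per_token * response_mask).sum() / response_mask.sum().clamp(min=1)
```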
276 changes: 276 additions & 0 deletions tests/experimental/test_base_self_distillation_trainer.py
@@ -0,0 +1,276 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections import defaultdict
from types import SimpleNamespace

import pytest
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM
from transformers.utils import is_peft_available

from trl.experimental.self_distillation.base_self_distillation_trainer import (
    BaseSelfDistillationTrainer,
    DistillationLogits,
)
from trl.experimental.self_distillation.self_distillation_config import SelfDistillationConfig

from ..testing_utils import TrlTestCase


if is_peft_available():
    from peft import LoraConfig, get_peft_model


class MinimalSelfDistillationTrainer(BaseSelfDistillationTrainer):
    def finalize_batch(self, inputs, rollout_batch):
        del inputs
        batch = rollout_batch.as_dict()
        batch["teacher_input_ids"] = rollout_batch.prompt_ids
        batch["teacher_attention_mask"] = rollout_batch.prompt_mask
        return batch

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        del inputs, num_items_in_batch
        anchor = next(model.parameters())
        return anchor.sum() * 0.0


class FakeTextTokenizer:
    def __call__(self, text, **kwargs):
        del kwargs
        token_map = {
            "short prompt": [1, 2, 3],
            "long prompt": [10, 11, 12, 13, 14],
        }
        return {"input_ids": [token_map[prompt] for prompt in text]}


class FakeChatProcessor:
    def __init__(self):
        self.calls = []

    def apply_chat_template(self, conversation, add_generation_prompt, tokenize, return_dict, **kwargs):
        self.calls.append(
            {
                "conversation": conversation,
                "add_generation_prompt": add_generation_prompt,
                "tokenize": tokenize,
                "return_dict": return_dict,
                "kwargs": kwargs,
            }
        )
        return {"input_ids": [[21, 22, 23, 24]]}


class TestBaseSelfDistillationTrainer(TrlTestCase):
    @staticmethod
    def _make_loss_test_trainer(**args_overrides):
        trainer = object.__new__(MinimalSelfDistillationTrainer)
        args = {
            "distillation_mode": "sampled_token",
            "distillation_topk": None,
            "distillation_alpha": 1.0,
            "distillation_add_tail": False,
            "distillation_is_clip": None,
        }
        args.update(args_overrides)
        trainer.args = SimpleNamespace(**args)
        trainer.loss_type = "dapo"
        trainer.max_completion_length = 2
        trainer.accelerator = SimpleNamespace(gather=lambda tensor: tensor)
        trainer._metrics = {
            "train": defaultdict(list),
            "eval": defaultdict(list),
        }
        trainer._name = "Minimal Self Distillation"
        return trainer

    def test_teacher_model_kind_live_uses_student_model(self):
        dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]})
        training_args = SelfDistillationConfig(
            output_dir=self.tmp_dir,
            per_device_train_batch_size=1,
            max_completion_length=8,
            max_steps=1,
            num_generations=1,
            teacher_model_kind="live",
            report_to="none",
        )

        trainer = MinimalSelfDistillationTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            args=training_args,
            train_dataset=dataset,
        )

        assert trainer.teacher_model is trainer.model

    @pytest.mark.skipif(not is_peft_available(), reason="PEFT is required for this test")
    def test_warns_when_initial_student_already_has_a_peft_adapter(self, caplog):
        dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]})
        training_args = SelfDistillationConfig(
            output_dir=self.tmp_dir,
            per_device_train_batch_size=1,
            max_completion_length=8,
            max_steps=1,
            num_generations=1,
            teacher_model_kind="base",
            report_to="none",
        )
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        model = get_peft_model(
            model,
            LoraConfig(
                r=4,
                lora_alpha=8,
                target_modules=["q_proj", "v_proj"],
                bias="none",
                task_type="CAUSAL_LM",
            ),
        )

        with caplog.at_level(
            logging.WARNING, logger="trl.experimental.self_distillation.base_self_distillation_trainer"
        ):
            MinimalSelfDistillationTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
            )

        assert "already contains a PEFT adapter" in caplog.text
        assert "`teacher_model_kind='base'` may refer to the underlying base weights" in caplog.text

    @pytest.mark.parametrize("teacher_model_kind", ["base", "ema"])
    def test_teacher_model_kind_base_and_ema_use_frozen_teacher_copy(self, teacher_model_kind):
        dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]})
        training_args = SelfDistillationConfig(
            output_dir=self.tmp_dir,
            per_device_train_batch_size=1,
            max_completion_length=8,
            max_steps=1,
            num_generations=1,
            teacher_model_kind=teacher_model_kind,
            report_to="none",
        )

        trainer = MinimalSelfDistillationTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            args=training_args,
            train_dataset=dataset,
        )

        assert trainer.teacher_model is not trainer.model
        assert trainer.teacher_model.training is False

        student_param = next(trainer.model.parameters())
        teacher_param = next(trainer.teacher_model.parameters())
        assert teacher_param.requires_grad is False
        assert teacher_param.data_ptr() != student_param.data_ptr()

    def test_tokenize_prompts_truncates_text_prompts_from_left(self):
        trainer = object.__new__(MinimalSelfDistillationTrainer)
        trainer.processing_class = FakeTextTokenizer()
        trainer.max_prompt_length = 3

        prompt_ids = trainer._tokenize_prompts(["long prompt", "short prompt"])

        assert prompt_ids == [[12, 13, 14], [1, 2, 3]]

    def test_tokenize_prompts_for_conversational_prompts_forwards_chat_template_kwargs(self):
        trainer = object.__new__(MinimalSelfDistillationTrainer)
        trainer.processing_class = FakeChatProcessor()
        trainer.max_prompt_length = 2
        trainer.chat_template_kwargs = {"enable_thinking": False}

        prompt_ids = trainer._tokenize_prompts([[{"role": "user", "content": "Solve 2+2."}]])

        assert prompt_ids == [[23, 24]]
        assert trainer.processing_class.calls == [
            {
                "conversation": [[{"role": "user", "content": "Solve 2+2."}]],
                "add_generation_prompt": True,
                "tokenize": True,
                "return_dict": True,
                "kwargs": {"enable_thinking": False},
            }
        ]

    def test_compute_self_distillation_loss_ignores_masked_completion_tokens(self):
        trainer = self._make_loss_test_trainer(
            distillation_mode="full_logits",
            distillation_alpha=0.0,
        )
        model = SimpleNamespace(training=True)

        # Token 0 is active and has a known non-zero divergence.
        # Token 1 is intentionally very different but masked out, so it must not affect the loss.
        student_probs = torch.tensor([[[0.8, 0.2], [0.01, 0.99]]], dtype=torch.float32)
        teacher_probs = torch.tensor([[[0.5, 0.5], [0.99, 0.01]]], dtype=torch.float32)
        distillation_logits = DistillationLogits(
            completion_ids=torch.tensor([[0, 1]], dtype=torch.long),
            completion_mask=torch.tensor([[1, 1]], dtype=torch.long),
            response_mask=torch.tensor([[1, 0]], dtype=torch.long),
            student_logits=student_probs.log(),
            teacher_logits=teacher_probs.log(),
        )

        loss = trainer._compute_self_distillation_loss(model, {}, distillation_logits)

        expected_active_token_loss = teacher_probs[0, 0, 0] * (
            teacher_probs[0, 0, 0].log() - student_probs[0, 0, 0].log()
        ) + teacher_probs[0, 0, 1] * (teacher_probs[0, 0, 1].log() - student_probs[0, 0, 1].log())
        torch.testing.assert_close(loss, expected_active_token_loss)
        torch.testing.assert_close(
            torch.tensor(trainer._metrics["train"]["self_distillation/distillation_loss"]),
            expected_active_token_loss.unsqueeze(0),
        )

    def test_compute_self_distillation_loss_applies_importance_sampling_clip(self):
        trainer = self._make_loss_test_trainer(distillation_is_clip=2.0)
        model = SimpleNamespace(training=True)

        student_token_probs = torch.tensor([[0.2, 0.4]], dtype=torch.float32)
        teacher_token_probs = torch.tensor([[0.5, 0.5]], dtype=torch.float32)
        old_token_probs = torch.tensor([[0.05, 0.4]], dtype=torch.float32)
        clip_coeff = trainer.args.distillation_is_clip

        distillation_logits = DistillationLogits(
            completion_ids=torch.tensor([[0, 1]], dtype=torch.long),
            completion_mask=torch.tensor([[1, 1]], dtype=torch.long),
            response_mask=torch.tensor([[1, 1]], dtype=torch.long),
            student_logits=torch.log(torch.tensor([[[0.2, 0.8], [0.6, 0.4]]], dtype=torch.float32)),
            teacher_logits=torch.log(torch.tensor([[[0.5, 0.5], [0.5, 0.5]]], dtype=torch.float32)),
        )

        loss = trainer._compute_self_distillation_loss(
            model,
            {"old_per_token_logps": old_token_probs.log()},
            distillation_logits,
        )

        raw_per_token_loss = (student_token_probs.log() - teacher_token_probs.log()) * student_token_probs.log()
        clipped_ratio = torch.minimum(
            student_token_probs / old_token_probs, torch.full_like(student_token_probs, clip_coeff)
        )
        expected_loss = (raw_per_token_loss * clipped_ratio).mean()

        torch.testing.assert_close(loss, expected_loss)
        torch.testing.assert_close(
            torch.tensor(trainer._metrics["train"]["self_distillation/distillation_loss"]),
            expected_loss.unsqueeze(0),
        )
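The final test above pins down the `distillation_is_clip` behavior: per-token distillation losses are reweighted by the student/old-policy probability ratio, clipped from above, which corrects for completions generated by an earlier policy snapshot. Extracted from the test's expected-value computation, the correction looks roughly like this (hypothetical standalone helper, not the trainer's actual function):

```python
import torch


def apply_is_clip(per_token_loss, student_logps, old_logps, clip_coeff):
    # Importance-sampling ratio between the current student and the policy
    # that generated the completions, clipped from above at clip_coeff.
    ratio = (student_logps - old_logps).exp()
    clipped_ratio = torch.minimum(ratio, torch.full_like(ratio, clip_coeff))
    return (per_token_loss * clipped_ratio).mean()
```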