Changes from 20 commits
Commits
23 commits
06f02a8
v0.1 transition sdft into unified base
LeonEricsson Apr 15, 2026
be1bcbc
sdft transition v1 complete, starting on sdpo
LeonEricsson Apr 15, 2026
0628701
sdpo transitioned, needs testing
LeonEricsson Apr 15, 2026
55111ff
remove legacy trainers
LeonEricsson Apr 15, 2026
81def8a
sdft and sdpo transitioned and tested with new base
LeonEricsson Apr 16, 2026
bad6b62
restructure training batch builder
LeonEricsson Apr 16, 2026
ef43c95
nits
LeonEricsson Apr 16, 2026
efe0eda
wip removing mixin
LeonEricsson Apr 16, 2026
fa1a8f3
remove mixin, refactoring and cleanup
LeonEricsson Apr 16, 2026
6a7d5a8
always set teacher_model
LeonEricsson Apr 16, 2026
56b2fd1
align generation tokenization with grpotrainer
LeonEricsson Apr 16, 2026
4a9d527
fix: generation_kwargs bug
LeonEricsson Apr 16, 2026
196feee
fix: incorrect import source
LeonEricsson Apr 16, 2026
3c87400
fixes: cleanup, standardized tokenization, distill loss=0 fix, sdpo c…
LeonEricsson Apr 17, 2026
d2a78e2
tests: ported old tests + new tests for base class
LeonEricsson Apr 17, 2026
8807088
couple more tests and test cleanup
LeonEricsson Apr 18, 2026
0612699
test: nit fix
LeonEricsson Apr 18, 2026
3d0cd72
move loss aggregation to loss_util + a few docstrings
LeonEricsson Apr 18, 2026
aa36955
fix: emit accumulated _metrics via log() override
LeonEricsson Apr 20, 2026
a432c20
fix: minor cursor issues + config docstrings
LeonEricsson Apr 20, 2026
e30ca04
fix: rename full logit distillation+topk into explicit flags
LeonEricsson Apr 21, 2026
3a9ecb2
fix(self-distillation): warn on preloaded peft students
LeonEricsson Apr 21, 2026
03718eb
docs: cleanup
LeonEricsson Apr 22, 2026
233 changes: 233 additions & 0 deletions tests/experimental/test_base_self_distillation_trainer.py
@@ -0,0 +1,233 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from types import SimpleNamespace

import pytest
import torch
from datasets import Dataset

from trl.experimental.self_distillation.base_self_distillation_trainer import (
BaseSelfDistillationTrainer,
DistillationLogits,
)
from trl.experimental.self_distillation.self_distillation_config import SelfDistillationConfig

from ..testing_utils import TrlTestCase


class MinimalSelfDistillationTrainer(BaseSelfDistillationTrainer):
def finalize_batch(self, inputs, rollout_batch):
del inputs
batch = rollout_batch.as_dict()
batch["teacher_input_ids"] = rollout_batch.prompt_ids
batch["teacher_attention_mask"] = rollout_batch.prompt_mask
return batch

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
del inputs, num_items_in_batch
anchor = next(model.parameters())
return anchor.sum() * 0.0


class FakeTextTokenizer:
def __call__(self, text, **kwargs):
del kwargs
token_map = {
"short prompt": [1, 2, 3],
"long prompt": [10, 11, 12, 13, 14],
}
return {"input_ids": [token_map[prompt] for prompt in text]}


class FakeChatProcessor:
def __init__(self):
self.calls = []

def apply_chat_template(self, conversation, add_generation_prompt, tokenize, return_dict, **kwargs):
self.calls.append(
{
"conversation": conversation,
"add_generation_prompt": add_generation_prompt,
"tokenize": tokenize,
"return_dict": return_dict,
"kwargs": kwargs,
}
)
return {"input_ids": [[21, 22, 23, 24]]}


class TestBaseSelfDistillationTrainer(TrlTestCase):
@staticmethod
def _make_loss_test_trainer(**args_overrides):
trainer = object.__new__(MinimalSelfDistillationTrainer)
args = {
"distillation_topk": None,
"full_logit_distillation": False,
"distillation_alpha": 1.0,
"distillation_add_tail": False,
"distillation_is_clip": None,
}
args.update(args_overrides)
trainer.args = SimpleNamespace(**args)
trainer.loss_type = "dapo"
trainer.max_completion_length = 2
trainer.accelerator = SimpleNamespace(gather=lambda tensor: tensor)
trainer._metrics = {
"train": defaultdict(list),
"eval": defaultdict(list),
}
trainer._name = "Minimal Self Distillation"
return trainer

def test_teacher_model_kind_live_uses_student_model(self):
dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]})
training_args = SelfDistillationConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=1,
max_completion_length=8,
max_steps=1,
num_generations=1,
teacher_model_kind="live",
report_to="none",
)

trainer = MinimalSelfDistillationTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=dataset,
)

assert trainer.teacher_model is trainer.model

@pytest.mark.parametrize("teacher_model_kind", ["base", "ema"])
def test_teacher_model_kind_base_and_ema_use_frozen_teacher_copy(self, teacher_model_kind):
dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]})
training_args = SelfDistillationConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=1,
max_completion_length=8,
max_steps=1,
num_generations=1,
teacher_model_kind=teacher_model_kind,
report_to="none",
)

trainer = MinimalSelfDistillationTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=dataset,
)

assert trainer.teacher_model is not trainer.model
assert trainer.teacher_model.training is False

student_param = next(trainer.model.parameters())
teacher_param = next(trainer.teacher_model.parameters())
assert teacher_param.requires_grad is False
assert teacher_param.data_ptr() != student_param.data_ptr()

def test_tokenize_prompts_truncates_text_prompts_from_left(self):
trainer = object.__new__(MinimalSelfDistillationTrainer)
trainer.processing_class = FakeTextTokenizer()
trainer.max_prompt_length = 3

prompt_ids = trainer._tokenize_prompts(["long prompt", "short prompt"])

assert prompt_ids == [[12, 13, 14], [1, 2, 3]]

def test_tokenize_prompts_for_conversational_prompts_forwards_chat_template_kwargs(self):
trainer = object.__new__(MinimalSelfDistillationTrainer)
trainer.processing_class = FakeChatProcessor()
trainer.max_prompt_length = 2
trainer.chat_template_kwargs = {"enable_thinking": False}

prompt_ids = trainer._tokenize_prompts([[{"role": "user", "content": "Solve 2+2."}]])

assert prompt_ids == [[23, 24]]
assert trainer.processing_class.calls == [
{
"conversation": [[{"role": "user", "content": "Solve 2+2."}]],
"add_generation_prompt": True,
"tokenize": True,
"return_dict": True,
"kwargs": {"enable_thinking": False},
}
]

def test_compute_self_distillation_loss_ignores_masked_completion_tokens(self):
trainer = self._make_loss_test_trainer(
full_logit_distillation=True,
distillation_alpha=0.0,
)
model = SimpleNamespace(training=True)

# Token 0 is active and has a known non-zero divergence.
# Token 1 is intentionally very different but masked out, so it must not affect the loss.
student_probs = torch.tensor([[[0.8, 0.2], [0.01, 0.99]]], dtype=torch.float32)
teacher_probs = torch.tensor([[[0.5, 0.5], [0.99, 0.01]]], dtype=torch.float32)
distillation_logits = DistillationLogits(
completion_ids=torch.tensor([[0, 1]], dtype=torch.long),
completion_mask=torch.tensor([[1, 1]], dtype=torch.long),
response_mask=torch.tensor([[1, 0]], dtype=torch.long),
student_logits=student_probs.log(),
teacher_logits=teacher_probs.log(),
)

loss = trainer._compute_self_distillation_loss(model, {}, distillation_logits)

expected_active_token_loss = teacher_probs[0, 0, 0] * (
teacher_probs[0, 0, 0].log() - student_probs[0, 0, 0].log()
) + teacher_probs[0, 0, 1] * (teacher_probs[0, 0, 1].log() - student_probs[0, 0, 1].log())
torch.testing.assert_close(loss, expected_active_token_loss)
torch.testing.assert_close(
torch.tensor(trainer._metrics["train"]["self_distillation/distillation_loss"]),
expected_active_token_loss.unsqueeze(0),
)

def test_compute_self_distillation_loss_applies_importance_sampling_clip(self):
trainer = self._make_loss_test_trainer(distillation_is_clip=2.0)
model = SimpleNamespace(training=True)

student_token_probs = torch.tensor([[0.2, 0.4]], dtype=torch.float32)
teacher_token_probs = torch.tensor([[0.5, 0.5]], dtype=torch.float32)
old_token_probs = torch.tensor([[0.05, 0.4]], dtype=torch.float32)
clip_coeff = trainer.args.distillation_is_clip

distillation_logits = DistillationLogits(
completion_ids=torch.tensor([[0, 1]], dtype=torch.long),
completion_mask=torch.tensor([[1, 1]], dtype=torch.long),
response_mask=torch.tensor([[1, 1]], dtype=torch.long),
student_logits=torch.log(torch.tensor([[[0.2, 0.8], [0.6, 0.4]]], dtype=torch.float32)),
teacher_logits=torch.log(torch.tensor([[[0.5, 0.5], [0.5, 0.5]]], dtype=torch.float32)),
)

loss = trainer._compute_self_distillation_loss(
model,
{"old_per_token_logps": old_token_probs.log()},
distillation_logits,
)

raw_per_token_loss = (student_token_probs.log() - teacher_token_probs.log()) * student_token_probs.log()
clipped_ratio = torch.minimum(
student_token_probs / old_token_probs, torch.full_like(student_token_probs, clip_coeff)
)
expected_loss = (raw_per_token_loss * clipped_ratio).mean()

torch.testing.assert_close(loss, expected_loss)
torch.testing.assert_close(
torch.tensor(trainer._metrics["train"]["self_distillation/distillation_loss"]),
expected_loss.unsqueeze(0),
)
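Note (reviewer aside, not part of the diff): the two loss tests above fix the numerics of the distillation objective. Below is a minimal torch sketch of the quantities the assertions encode, assuming already-normalized probability inputs; the helper names are illustrative and are not the trainer's API. The first helper is the forward KL over the vocabulary, averaged only over tokens where response_mask is 1; the second is the per-token importance ratio against the sampling policy, clipped from above at distillation_is_clip.

import torch

def masked_forward_kl(student_logprobs, teacher_logprobs, response_mask):
    # KL(teacher || student) summed over the vocabulary for each completion token,
    # then averaged over the tokens selected by response_mask (masked tokens drop out).
    per_token_kl = (teacher_logprobs.exp() * (teacher_logprobs - student_logprobs)).sum(dim=-1)
    mask = response_mask.to(per_token_kl.dtype)
    return (per_token_kl * mask).sum() / mask.sum()

def clipped_importance_weight(student_token_probs, old_token_probs, clip_coeff):
    # Per-token importance ratio against the old (sampling) policy, clipped from above.
    ratio = student_token_probs / old_token_probs
    return torch.minimum(ratio, torch.full_like(ratio, clip_coeff))

# Reproduces the expectation of the masked-token test: only token 0 contributes.
student_probs = torch.tensor([[[0.8, 0.2], [0.01, 0.99]]])
teacher_probs = torch.tensor([[[0.5, 0.5], [0.99, 0.01]]])
response_mask = torch.tensor([[1, 0]])
loss = masked_forward_kl(student_probs.log(), teacher_probs.log(), response_mask)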
39 changes: 20 additions & 19 deletions tests/experimental/test_sdft_trainer.py
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import patch

import pytest
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from transformers.utils import is_peft_available

from trl.data_utils import maybe_apply_chat_template
from trl.experimental.sdft import SDFTConfig, SDFTTrainer

from ..testing_utils import TrlTestCase, require_peft
@@ -27,18 +28,18 @@
if is_peft_available():
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict

from trl.experimental.self_distillation.peft_adapter_ema_callback import PEFTAdapterEMACallback
from trl.experimental.self_distillation.teacher_sync import PEFTAdapterEMACallback


class SelfDistillationCaptureCallback(TrainerCallback):
def __init__(self):
self.captured_generation_prompt_text = None
self.captured_generation_prompts = None
self.captured_old_per_token_logps = None
self.generation_batch_build_count = 0

def on_generation_prompts_selected(self, generation_prompt_text=None, **kwargs):
if self.captured_generation_prompt_text is None and generation_prompt_text is not None:
self.captured_generation_prompt_text = generation_prompt_text[0]
def on_generation_prompts_selected(self, generation_prompts=None, **kwargs):
if self.captured_generation_prompts is None and generation_prompts is not None:
self.captured_generation_prompts = generation_prompts

def on_self_distillation_batch_prepared(self, old_per_token_logps=None, **kwargs):
if self.captured_old_per_token_logps is None and old_per_token_logps is not None:
@@ -74,6 +75,7 @@ def test_training_rejects_none_privileged_context(self):
with pytest.raises(ValueError, match="`privileged_context` must not be None"):
trainer.train()

@pytest.mark.skip(reason="`generate_from_teacher` is not ported yet")
def test_training_with_generate_from_teacher(self):
dataset = Dataset.from_dict(
{
@@ -105,9 +107,8 @@ def test_training_with_generate_from_teacher(self):

trainer.train()

assert capture_callback.captured_generation_prompt_text is not None
assert "Solve 2+2." in capture_callback.captured_generation_prompt_text
assert "Teacher hint" in capture_callback.captured_generation_prompt_text
assert capture_callback.captured_generation_prompts is not None
assert capture_callback.captured_generation_prompts[0] != dataset[0]["prompt"]

def test_training_with_chat_template_kwargs(self):
dataset = Dataset.from_dict(
@@ -141,15 +142,15 @@ def test_training_with_chat_template_kwargs(self):
callbacks=[capture_callback],
)

expected_prompt = maybe_apply_chat_template(
{"prompt": dataset[0]["prompt"]},
**training_args.chat_template_kwargs,
)["prompt"]

trainer.train()

assert capture_callback.captured_generation_prompt_text == expected_prompt
with patch.object(
trainer.processing_class,
"apply_chat_template",
wraps=trainer.processing_class.apply_chat_template,
) as mock_apply_chat_template:
trainer.train()

assert mock_apply_chat_template.call_count > 0
assert any(call.kwargs.get("enable_thinking") is False for call in mock_apply_chat_template.call_args_list)

@require_peft
def test_training_with_peft_model(self):
@@ -205,9 +206,9 @@ def test_training_with_peft_model_and_sync_ref_model(self):
max_completion_length=8,
max_steps=2,
num_generations=1,
sync_ref_model=True,
ref_model_mixup_alpha=0.05,
ref_model_sync_steps=1,
teacher_model_kind="ema",
teacher_update_rate=0.05,
teacher_sync_steps=1,
)

trainer = SDFTTrainer(
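Note (reviewer aside, not part of the diff): this test now selects the EMA teacher through the renamed options teacher_model_kind="ema", teacher_update_rate, and teacher_sync_steps, replacing sync_ref_model, ref_model_mixup_alpha, and ref_model_sync_steps. An EMA teacher update of the kind those names suggest is sketched below; it illustrates the convention and is not the PEFTAdapterEMACallback implementation.

import torch

@torch.no_grad()
def ema_update(teacher, student, update_rate: float):
    # Run every `teacher_sync_steps` optimizer steps:
    # teacher <- (1 - update_rate) * teacher + update_rate * student
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.data.mul_(1.0 - update_rate).add_(s_param.data, alpha=update_rate)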
16 changes: 3 additions & 13 deletions tests/experimental/test_sdpo_trainer.py
@@ -100,14 +100,6 @@ def test_vllm_config_defaults_match_reference_trainers(self):
assert config.vllm_model_impl == "vllm"

def test_generate_vllm_syncs_on_step_change_and_uses_mode_specific_num_generations(self):
class FakeTokenizer:
def __call__(self, text, **kwargs):
token_map = {
"Solve 2+2.": [11, 12],
"Check 3+3.": [21, 22],
}
return {"input_ids": [token_map[prompt] for prompt in text]}

class FakeVLLMGeneration:
def __init__(self):
self.sync_weights_call_count = 0
@@ -135,11 +127,9 @@ def generate(self, prompts, images, num_generations):
trainer.model = SimpleNamespace(training=True)
trainer.state = SimpleNamespace(global_step=4)
trainer._last_loaded_step = 3
trainer.processing_class = FakeTokenizer()
trainer.vllm_generation = FakeVLLMGeneration()
trainer._apply_prompt_template = lambda prompts: prompts

prompt_ids, completion_ids = trainer._generate(["Solve 2+2.", "Solve 2+2."])
prompt_ids, completion_ids = trainer._generate([[11, 12], [11, 12]])

assert prompt_ids == [[11, 12], [11, 12]]
assert completion_ids == [[100], [101]]
@@ -154,7 +144,7 @@ def generate(self, prompts, images, num_generations):
]

trainer.model.training = False
eval_prompt_ids, eval_completion_ids = trainer._generate(["Check 3+3.", "Check 3+3.", "Check 3+3."])
eval_prompt_ids, eval_completion_ids = trainer._generate([[21, 22], [21, 22], [21, 22]])

assert eval_prompt_ids == [[21, 22], [21, 22], [21, 22]]
assert eval_completion_ids == [[100], [101], [102]]
@@ -167,7 +157,7 @@ def generate(self, prompts, images, num_generations):

trainer.model.training = True
trainer.state.global_step = 5
trainer._generate(["Solve 2+2.", "Solve 2+2."])
trainer._generate([[11, 12], [11, 12]])

assert trainer.vllm_generation.sync_weights_call_count == 2
assert trainer._last_loaded_step == 5
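Note (reviewer aside, not part of the diff): the assertions above pin down when _generate pushes fresh student weights to the vLLM engine: at most once per new global_step, so the eval call in between reuses the weights loaded at step 4 and the counter only advances again at step 5. A rough sketch of that guard follows; it reuses the attribute names from the test doubles, and the sync_weights method name is assumed from sync_weights_call_count, so the real trainer may differ.

def maybe_sync_vllm_weights(trainer):
    # Push student weights to vLLM at most once per optimizer step; skip when no
    # step has completed since the last load (e.g. an eval pass at the same step).
    if trainer.state.global_step != trainer._last_loaded_step:
        trainer.vllm_generation.sync_weights()
        trainer._last_loaded_step = trainer.state.global_step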