refactor: self distillation trainers (sdpo/sdft/...) #5573

Open
LeonEricsson wants to merge 23 commits into huggingface:main from LeonEricsson:feature/experimental-self-distillation

Conversation

@LeonEricsson
Collaborator

@LeonEricsson LeonEricsson commented Apr 16, 2026

Summary

Major internal restructuring of the experimental self-distillation trainers. The core change: unify shared lifecycle logic in BaseSelfDistillationTrainer, while concrete trainers (SDFTTrainer, SDPOTrainer) own their algorithmic recipes.

The previous design mixed generic self-distillation lifecycle with SDPO-specific reward/scoring machinery in the same inheritance stack, meaning the base implicitly assumed every method needed reward-based optimization. This PR corrects that with a clean split: base owns infrastructure, trainers own semantics.

No intended change in functionality.

What changed

Base class now owns the common lifecycle:
model/tokenizer/generation setup, student-side rollout generation (transformers + vLLM), and teacher synchronization.

Concrete trainers own the method recipe:
rollout interpretation, teacher input construction, algorithm-specific metadata (rewards, advantages, etc.), and loss composition.

SDPO owns its identity.
Reward scoring and policy optimization are no longer baked into the base. SDPO explicitly manages its own reward functions, normalization, and hybrid loss.

Teacher regularization
is now a single semantic config parameter (teacher_model_kind: "live" / "base" / "ema"), with the base handling PEFT/non-PEFT details internally.
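As a sketch of the new config semantics (parameter names taken from this PR; the class name, defaults, and validation here are illustrative, not the actual TRL values):

```python
from dataclasses import dataclass

VALID_TEACHER_KINDS = ("live", "base", "ema")

@dataclass
class SelfDistillationConfigSketch:
    # "live": teacher shares the student's current weights
    # "base": teacher is a frozen snapshot of the initial model
    # "ema":  teacher tracks an exponential moving average of the student
    teacher_model_kind: str = "base"
    teacher_sync_steps: int = 1          # how often the teacher is refreshed
    teacher_update_rate: float = 0.999   # EMA decay; only meaningful for "ema"

    def __post_init__(self) -> None:
        if self.teacher_model_kind not in VALID_TEACHER_KINDS:
            raise ValueError(
                f"teacher_model_kind must be one of {VALID_TEACHER_KINDS}, "
                f"got {self.teacher_model_kind!r}"
            )
```

A single validated enum-like string keeps the PEFT/non-PEFT mechanics out of user-facing configuration entirely.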

Removed

I’d like to touch on a few abstractions I’ve cleared out and the reasoning behind those calls. In short, PromptTokenizer, SelfDistillationMixin, and OnlineRolloutMixin had become "leaky": they masked the logic they were supposed to encapsulate and made the architecture harder to follow.

PromptTokenizer logic now lives directly in the base, next to the rollout pipeline it serves. The two mixins were splitting orchestration, rollout mechanics, and algorithm-specific tasks (like reward scoring) across too many pieces. That responsibility is now split cleanly: BaseSelfDistillationTrainer handles the "heavy lifting" of the distillation machinery and rollout mechanics. Meanwhile, concrete trainers have full ownership over their specific batch finalization and loss logic.

Caveat

generate_from_teacher isn't ported; the base currently assumes student-sourced rollouts. Dangling code still references it, so we need to decide whether this capability is worth porting before this PR can land.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.


Note

High Risk
Large refactor of BaseSelfDistillationTrainer, SDPO/SDFT training loops, and configuration semantics (teacher selection, distillation modes, reward/policy loss), which could subtly change rollout generation, loss computation, or teacher syncing behavior despite aiming for no functional change.

Overview
Centralizes the experimental self-distillation trainer lifecycle into a redesigned BaseSelfDistillationTrainer (shared rollout sampling via Transformers/vLLM, prompt tokenization, teacher model setup/sync, and unified distillation-loss plumbing), and removes the older SelfDistillationMixin/OnlineRolloutMixin/teacher_context abstractions.

Updates SDPO to explicitly own reward model/function setup, reward normalization/advantage computation, and policy-vs-distillation loss composition (adds policy_only and reworks hybrid/distillation_only paths), while SDFT is ported onto the new base and now finalizes batches by constructing teacher inputs and applying its response-mask skipping logic.

Reworks configuration and docs to use distillation_mode (sampled_token/full_logits/topk_logits) and teacher_model_kind (base/live/ema) with teacher_sync_steps/teacher_update_rate, and adds new unit tests covering teacher selection, prompt tokenization behavior, and distillation loss correctness/IS clipping.

Reviewed by Cursor Bugbot for commit 03718eb. Bugbot is set up for automated code reviews on this repo. Configure here.

**payload,
)

def _setup_teacher_model(self) -> None:

This PR promotes LoRA and teacher regularization into shared capabilities owned by the base, rather than being method-specific. Teacher handling is collapsed into the semantic teacher_model_kind API (live / base / ema), letting the base infer PEFT implementation details.
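One plausible shape for that dispatch (a dependency-free sketch with a toy stand-in model; `resolve_teacher` and `TinyModel` are hypothetical names, not the PR's actual implementation):

```python
import copy

class TinyModel:
    """Stand-in for an nn.Module so the sketch stays dependency-free."""
    def __init__(self):
        self.weights = [0.1, 0.2]
        self.frozen = False

def resolve_teacher(student: TinyModel, kind: str) -> TinyModel:
    if kind == "live":
        # Teacher is the student itself: teacher logits are recomputed from
        # the current weights, and no separate copy is kept.
        return student
    if kind in ("base", "ema"):
        # "base" keeps this snapshot frozen forever; "ema" would additionally
        # blend student weights into it every teacher_sync_steps.
        teacher = copy.deepcopy(student)
        teacher.frozen = True
        return teacher
    raise ValueError(f"unknown teacher_model_kind: {kind!r}")
```

In the PEFT case the "base" teacher can be cheaper still: rather than deep-copying, the base model with adapters disabled already is the frozen snapshot.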

return self._build_buffered_batch(generation_batch)
return self._prepare_training_batch(generation_batch)

def _prepare_training_batch(self, inputs: list[dict[str, Any]]) -> TrainingBatch:

@LeonEricsson LeonEricsson Apr 18, 2026


This is the hinge of the refactor. The base owns sample_rollouts() and only asks subclasses to implement finalize_batch(...), which takes a RolloutBatch (prompts + completions, pre-teacher) and produces a TrainingBatch with teacher-conditioned inputs and any trainer-specific metadata. Everything downstream keys off this split.
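The base/subclass boundary can be sketched like this (RolloutBatch, TrainingBatch, sample_rollouts, and finalize_batch are the PR's names; the bodies and the SDFTLikeTrainer example are illustrative):

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any

@dataclass
class RolloutBatch:        # prompts + completions, before teacher conditioning
    prompts: list[Any]
    completions: list[Any]

@dataclass
class TrainingBatch:       # teacher-conditioned inputs + trainer-specific metadata
    inputs: list[Any]
    metadata: dict[str, Any] = field(default_factory=dict)

class BaseTrainerSketch(ABC):
    def sample_rollouts(self, prompts: list[str]) -> RolloutBatch:
        # The base owns generation (transformers or vLLM); faked here.
        return RolloutBatch(prompts, [p + " <completion>" for p in prompts])

    @abstractmethod
    def finalize_batch(self, rollouts: RolloutBatch) -> TrainingBatch:
        """Subclasses interpret rollouts and build teacher inputs."""

    def step(self, prompts: list[str]) -> TrainingBatch:
        return self.finalize_batch(self.sample_rollouts(prompts))

class SDFTLikeTrainer(BaseTrainerSketch):
    def finalize_batch(self, rollouts: RolloutBatch) -> TrainingBatch:
        # SDFT-style: teacher input is just prompt + completion, no rewards.
        return TrainingBatch(inputs=list(zip(rollouts.prompts, rollouts.completions)))
```

An SDPO-like subclass would attach rewards and advantages in `metadata` instead; the base never inspects it.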

prompt_ids = self.processing_class(text=prompts)["input_ids"]
return prompt_ids

def _tokenize_prompts(self, prompts: list[Any]) -> list[list[int]]:

@LeonEricsson LeonEricsson Apr 18, 2026


Two things regarding prompt tokenization:

  1. Removed PromptTokenizer. Prompt rendering used to live in a separate PromptTokenizer object the trainer held and builders reached back through. It had no stable contract and was effectively just forwarding trainer state. Tokenization now lives on the base.
  2. Correctness fix on the rollout path. Prompts are now tokenized directly into token ID sequences, and generation runs from those for both transformers and vLLM, rather than round-tripping through text. This matches what existing trainers already do and eliminates prompt → text → token mismatches.
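The class of bug fixed by point 2 is easy to demonstrate with a toy tokenizer whose decode step normalizes whitespace (real tokenizers can be similarly lossy around whitespace and special tokens):

```python
# Toy tokenizer: decode(encode(s)) collapses runs of whitespace, so a
# prompt -> text -> tokens round trip silently changes the token sequence.
def encode(text: str) -> list[int]:
    return [ord(c) for c in text]

def decode(ids: list[int]) -> str:
    return " ".join("".join(chr(i) for i in ids).split())  # normalizes whitespace

prompt = "Hello,  world"                        # the double space is meaningful
direct_ids = encode(prompt)                     # tokenize once, generate from IDs
roundtrip_ids = encode(decode(encode(prompt)))  # prompt -> text -> tokens

assert direct_ids != roundtrip_ids  # the round trip changed the prompt
```

Generating directly from `direct_ids` guarantees the model is conditioned on exactly the tokens that were tokenized, for both the transformers and vLLM backends.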


Under the new base, SDFT's trainer becomes very simple. Build the teacher-conditioned batch and define loss composition.

Comment on lines 87 to 93
def select_generation_prompts(self, prompts: list[Any], privileged_contexts: list[Any]) -> list[Any]:
if not self.trainer.generate_from_teacher:
return prompts
return [
self._compose_teacher_prompt(prompt, privileged_context)
for prompt, privileged_context in zip(prompts, privileged_contexts, strict=True)
]

Leftover _generate_from_teacher code that still needs to be addressed.

def _generate_and_score_completions(
self, inputs: list[dict[str, torch.Tensor | Any]]
) -> dict[str, torch.Tensor | Any]:
def finalize_batch(

SDPO's counterpart to the new base boundary. The base provides the generic rollout batch, and SDPO’s finalize_batch(...) owns everything that is actually SDPO-specific.
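For the reward-normalization/advantage piece SDPO now owns, a minimal sketch of one common normalization (whitened rewards; the function name and exact formula are illustrative, not necessarily what SDPO implements):

```python
def whitened_advantages(rewards: list[float], eps: float = 1e-8) -> list[float]:
    """Center and scale raw rewards into advantages for policy optimization."""
    mean = sum(rewards) / len(rewards)
    std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]
```

Keeping this in `finalize_batch(...)` means the base never needs to know rewards exist at all.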

teacher_logits: torch.Tensor


class BaseSelfDistillationTrainer(_BaseTrainer, ABC):

The new base class deliberately owns capability only: model/tokenizer setup, rollout generation, teacher sync infrastructure, and shared distillation math. It does not own any algorithmic choices: no reward scoring, no advantage computation, no loss composition. Those are pushed into subclasses.


return old_per_token_logps

def _compute_self_distillation_loss(

This loss computation is opinionated: it covers top-k / full-logit / sampled-token paths, importance-sampling clipping, and aggregation per loss_type, matching what both SDPO and SDFT need. It lives on the base because both current trainers use it as-is, and any future trainer doing standard self-distillation math gets it for free.

That said, it's not part of the base contract. _compute_self_distillation_loss is only ever called from a trainer's own compute_loss, and nothing in the base touches it. A trainer with substantially different distillation math can just ignore it and compute loss from scratch.

@LeonEricsson

I've provided a bunch of comments to aid reviewers

@LeonEricsson LeonEricsson marked this pull request as ready for review April 18, 2026 12:25
Comment thread trl/experimental/self_distillation/base_self_distillation_trainer.py Outdated
Comment thread trl/experimental/sdpo/sdpo_trainer.py
Comment thread trl/experimental/sdft/sdft_trainer.py
Comment thread trl/experimental/sdpo/sdpo_trainer.py
BaseSelfDistillationTrainer was populating _metrics in
_log_self_distillation_metric but had no log() override, so those
metrics were never forwarded to the Trainer's logging system. The fix
merges _metrics into the log dict, prefixes eval keys, and clears after
each logging step.
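That commit's fix follows a familiar pattern; a minimal sketch of the mechanism (class and attribute names here are illustrative, mirroring the commit message):

```python
class MetricLoggingSketch:
    """Buffered metrics are merged into each log() call, eval keys get a
    prefix, and the buffer is cleared after every logging step."""

    def __init__(self, is_eval: bool = False):
        self._metrics: dict[str, list[float]] = {}
        self.is_eval = is_eval
        self.logged: list[dict] = []

    def _log_self_distillation_metric(self, name: str, value: float) -> None:
        self._metrics.setdefault(name, []).append(value)

    def log(self, logs: dict) -> None:
        prefix = "eval_" if self.is_eval else ""
        for name, values in self._metrics.items():
            logs[prefix + name] = sum(values) / len(values)  # average the buffer
        self._metrics.clear()  # reset so metrics don't leak across steps
        self.logged.append(logs)
```

Without the `log()` override, the buffer fills but nothing ever reads it, which is exactly the silent failure the commit describes.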
@LeonEricsson LeonEricsson force-pushed the feature/experimental-self-distillation branch from 103b9d4 to a432c20 Compare April 20, 2026 20:12
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


@cursor cursor Bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.



"help": "Distillation objective mode. `sampled_token` uses token-level distillation on the sampled "
"completion tokens, `full_logits` uses full-vocabulary divergence, and `topk_logits` uses a top-k "
"approximation over the student support."
},

Incompatible base config defaults break SDFT at runtime

High Severity

SelfDistillationConfig defaults to distillation_mode="sampled_token" with distillation_alpha=0.5, but compute_sampled_token_self_distillation_loss raises a ValueError when distillation_alpha != 1.0. SDFTConfig inherits both defaults without overriding either, so any SDFTTrainer created with default config will crash at loss computation time. The old code avoided this because distillation_topk defaulted to 100, routing through the top-k path instead.

Additional Locations (2)
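The incompatibility Bugbot reports can be reduced to a config guard (a sketch reproducing the described interaction; parameter names and defaults are taken from the report):

```python
def validate_distillation_config(
    distillation_mode: str = "sampled_token",  # inherited base default
    distillation_alpha: float = 0.5,           # inherited base default
) -> None:
    """The sampled_token loss path rejects alpha != 1.0, so the two
    inherited defaults fail when combined — which is the reported crash."""
    if distillation_mode == "sampled_token" and distillation_alpha != 1.0:
        raise ValueError(
            "sampled_token distillation requires distillation_alpha == 1.0, "
            f"got {distillation_alpha}"
        )
```

Either SDFTConfig overrides one of the defaults, or the base validates the pair at construction time instead of at loss time.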

