refactor: self distillation trainers (sdpo/sdft/...) #5573
LeonEricsson wants to merge 23 commits into huggingface:main
Conversation
    **payload,
)

def _setup_teacher_model(self) -> None:
This PR promotes LoRA and teacher regularization into shared capabilities owned by the base, rather than being method-specific. Teacher handling is collapsed into the semantic teacher_model_kind API (live / base / ema), letting the base infer PEFT implementation details.
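To make those semantics concrete, here is a hedged sketch of how teacher resolution by `teacher_model_kind` could look. The live/base/ema behavior is my reading of this PR's description, and `resolve_teacher` / `is_peft` are illustrative names rather than the actual implementation:

```python
import copy

import torch.nn as nn


def resolve_teacher(model: nn.Module, teacher_model_kind: str, is_peft: bool) -> nn.Module | None:
    """Return the module used as teacher, or None if the student itself plays that role."""
    if teacher_model_kind == "live":
        # Teacher is the current student policy; no separate module is needed.
        return None
    if teacher_model_kind == "base":
        # With PEFT, the frozen base weights serve as teacher by disabling
        # adapters at teacher-forward time; otherwise keep a frozen snapshot.
        return None if is_peft else _frozen_copy(model)
    if teacher_model_kind == "ema":
        # EMA teacher: a frozen copy periodically synced toward the student
        # (see teacher_sync_steps / teacher_update_rate).
        return _frozen_copy(model)
    raise ValueError(f"Unknown teacher_model_kind: {teacher_model_kind!r}")


def _frozen_copy(model: nn.Module) -> nn.Module:
    teacher = copy.deepcopy(model)
    for param in teacher.parameters():
        param.requires_grad_(False)
    return teacher.eval()
```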
    return self._build_buffered_batch(generation_batch)
    return self._prepare_training_batch(generation_batch)

def _prepare_training_batch(self, inputs: list[dict[str, Any]]) -> TrainingBatch:
This is the hinge of the refactor. The base owns sample_rollouts() and only asks subclasses to implement finalize_batch(...), which takes a RolloutBatch (prompts + completions, pre-teacher) and produces a TrainingBatch with teacher-conditioned inputs and any trainer-specific metadata. Everything downstream keys off this split.
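A minimal sketch of that boundary, using simplified stand-ins for `RolloutBatch` / `TrainingBatch` (the real dataclasses carry more fields); only `sample_rollouts()`, `finalize_batch(...)`, and `_prepare_training_batch(...)` are names from the PR, the rest is assumed:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any


@dataclass
class RolloutBatch:
    # Student-side rollouts, before any teacher involvement.
    prompt_ids: list[list[int]]
    completion_ids: list[list[int]]
    extra: dict[str, Any] = field(default_factory=dict)


@dataclass
class TrainingBatch:
    # Teacher-conditioned inputs plus trainer-specific metadata
    # (rewards/advantages for SDPO, response masks for SDFT, ...).
    input_ids: list[list[int]]
    metadata: dict[str, Any] = field(default_factory=dict)


class BaseSketch(ABC):
    def _prepare_training_batch(self, inputs: list[dict[str, Any]]) -> TrainingBatch:
        # The base owns rollout generation...
        rollouts = self.sample_rollouts(inputs)
        # ...and only asks the subclass to turn rollouts into a training batch.
        return self.finalize_batch(rollouts)

    def sample_rollouts(self, inputs: list[dict[str, Any]]) -> RolloutBatch:
        # Provided by the real base (transformers and vLLM generation paths).
        raise NotImplementedError

    @abstractmethod
    def finalize_batch(self, rollouts: RolloutBatch) -> TrainingBatch:
        ...
```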
    prompt_ids = self.processing_class(text=prompts)["input_ids"]
    return prompt_ids

def _tokenize_prompts(self, prompts: list[Any]) -> list[list[int]]:
Two things regarding prompt tokenization:

- Removed `PromptTokenizer`. Prompt rendering used to live in a separate `PromptTokenizer` object that the trainer held and builders reached back through. It had no stable contract and was effectively just forwarding trainer state. Tokenization now lives on the base.
- Correctness fix on the rollout path. Prompts are now tokenized directly into token ID sequences, and generation runs from those for both transformers and vLLM, rather than round-tripping through text. This matches what existing trainers already do and eliminates prompt → text → token mismatches (see the sketch below).
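A minimal sketch of that token-ID rollout path, with an arbitrary example model and the vLLM side only indicated in a comment:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # arbitrary example model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompts = ["Explain self-distillation in one sentence."]

# Tokenize prompts once, directly into token-ID sequences.
prompt_ids = tokenizer(prompts)["input_ids"]

# Transformers path: generate from those IDs, so the completion is conditioned
# on exactly the tokens the trainer will later score.
input_ids = torch.tensor(prompt_ids)
out = model.generate(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    max_new_tokens=32,
    do_sample=True,
)
completion_ids = out[:, input_ids.shape[1]:]

# vLLM path (not shown): the same `prompt_ids` would be passed to the engine as
# token IDs rather than re-rendered to text and re-tokenized.
```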
Under the new base, SDFT's trainer becomes very simple: build the teacher-conditioned batch and define the loss composition.
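As a rough illustration of how small that surface is, a hypothetical SDFT-style `finalize_batch` might look like the sketch below. It reuses the `RolloutBatch` / `TrainingBatch` stand-ins from the earlier sketch; `_compose_teacher_prompt` and `_tokenize_prompts` appear in the diff, while the `extra` fields and the mask layout are assumptions:

```python
class SDFTFinalizeSketch:
    """Illustrative only; builds on the RolloutBatch/TrainingBatch sketch above."""

    def finalize_batch(self, rollouts: RolloutBatch) -> TrainingBatch:
        # Teacher-conditioned inputs: privileged context folded into the prompt,
        # student completions reused as-is.
        teacher_prompts = [
            self._compose_teacher_prompt(prompt, ctx)
            for prompt, ctx in zip(
                rollouts.extra["prompts"], rollouts.extra["privileged_contexts"], strict=True
            )
        ]
        teacher_prompt_ids = self._tokenize_prompts(teacher_prompts)
        input_ids = [p + c for p, c in zip(teacher_prompt_ids, rollouts.completion_ids, strict=True)]
        # Response mask so only completion tokens enter the distillation loss.
        response_mask = [
            [0] * len(p) + [1] * len(c)
            for p, c in zip(teacher_prompt_ids, rollouts.completion_ids, strict=True)
        ]
        return TrainingBatch(input_ids=input_ids, metadata={"response_mask": response_mask})
```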
def select_generation_prompts(self, prompts: list[Any], privileged_contexts: list[Any]) -> list[Any]:
    if not self.trainer.generate_from_teacher:
        return prompts
    return [
        self._compose_teacher_prompt(prompt, privileged_context)
        for prompt, privileged_context in zip(prompts, privileged_contexts, strict=True)
    ]
Leftover `generate_from_teacher` handling that still needs to be addressed.
def _generate_and_score_completions(
    self, inputs: list[dict[str, torch.Tensor | Any]]
) -> dict[str, torch.Tensor | Any]:
def finalize_batch(
SDPO's counterpart to the new base boundary. The base provides the generic rollout batch, and SDPO’s finalize_batch(...) owns everything that is actually SDPO-specific.
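A hedged sketch of what that ownership could mean in practice, again reusing the earlier `RolloutBatch` / `TrainingBatch` stand-ins; the reward-function handling and normalization shown here are placeholders, not SDPO's actual code:

```python
import torch


class SDPOFinalizeSketch:
    """Illustrative only; builds on the RolloutBatch/TrainingBatch sketch above."""

    def finalize_batch(self, rollouts: RolloutBatch) -> TrainingBatch:
        completions = self.processing_class.batch_decode(rollouts.completion_ids)
        # SDPO-specific: score completions with the trainer's own reward functions...
        rewards = torch.stack([fn(completions) for fn in self.reward_funcs]).sum(dim=0)
        # ...normalize rewards into advantages (whole-batch normalization here, for brevity)...
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
        # ...and attach them as trainer-specific metadata next to the inputs.
        input_ids = [p + c for p, c in zip(rollouts.prompt_ids, rollouts.completion_ids, strict=True)]
        return TrainingBatch(input_ids=input_ids, metadata={"advantages": advantages})
```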
    teacher_logits: torch.Tensor


class BaseSelfDistillationTrainer(_BaseTrainer, ABC):
The new base class deliberately owns capability only: model/tokenizer setup, rollout generation, teacher sync infrastructure, shared distillation math. It does not own any algorithmic choices: no reward scoring, no advantage computation, no loss composition. Those are pushed into subclasses.
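For example, the EMA flavor of teacher sync the base provides can be thought of as a Polyak-style update run every `teacher_sync_steps`; the config names come from this PR, the function itself is an illustrative sketch:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def sync_ema_teacher(student: nn.Module, teacher: nn.Module, teacher_update_rate: float) -> None:
    """Polyak/EMA update: teacher <- (1 - rate) * teacher + rate * student."""
    for t_param, s_param in zip(teacher.parameters(), student.parameters(), strict=True):
        t_param.lerp_(s_param, teacher_update_rate)
```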
    return old_per_token_logps

def _compute_self_distillation_loss(
This loss computation is opinionated: it covers top-k / full-logit / sampled-token paths, importance-sampling clipping, and aggregation per loss_type, matching what both SDPO and SDFT need. It lives on the base because both current trainers use it as-is, and any future trainer doing standard self-distillation math gets it for free.
That said, it's not part of the base contract. _compute_self_distillation_loss is only ever called from a trainer's own compute_loss, and nothing in the base touches it. A trainer with substantially different distillation math can just ignore it and compute loss from scratch.
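For orientation, here is a self-contained sketch of the kind of math this method covers: a divergence on the student's top-k support combined with a clipped importance-sampling weight on the sampled tokens. Shapes, the divergence direction, and the clipping scheme are assumptions, not the PR's exact formulation:

```python
import torch
import torch.nn.functional as F


def topk_distillation_loss(
    student_logits: torch.Tensor,  # (batch, seq, vocab)
    teacher_logits: torch.Tensor,  # (batch, seq, vocab), assumed detached
    sampled_ids: torch.Tensor,     # (batch, seq) tokens sampled during rollout
    old_logps: torch.Tensor,       # (batch, seq) rollout-policy logprobs of sampled_ids
    mask: torch.Tensor,            # (batch, seq) 1 on completion tokens
    k: int = 100,
    is_clip: float = 2.0,
) -> torch.Tensor:
    student_logps = F.log_softmax(student_logits, dim=-1)
    teacher_logps = F.log_softmax(teacher_logits, dim=-1)

    # Top-k approximation: restrict the divergence to the student's top-k support.
    topk_ids = student_logits.topk(k, dim=-1).indices
    s = student_logps.gather(-1, topk_ids)
    t = teacher_logps.gather(-1, topk_ids)
    per_token_div = (s.exp() * (s - t)).sum(-1)

    # Clipped importance-sampling weight for (slightly) off-policy rollouts.
    cur_logps = student_logps.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1)
    is_weight = (cur_logps - old_logps).exp().clamp(max=is_clip).detach()

    # Token-level aggregation over completion tokens only.
    return (is_weight * per_token_div * mask).sum() / mask.sum().clamp(min=1)
```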
I've provided a bunch of comments to aid reviewers.
…onfig parameters moved to sdpoconfig, + other nits
BaseSelfDistillationTrainer was populating _metrics in _log_self_distillation_metric but had no log() override, so those metrics were never forwarded to the Trainer's logging system. The fix merges _metrics into the log dict, prefixes eval keys, and clears after each logging step.
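The fix is along these lines; the sketch below mirrors how other TRL trainers forward their metric buffers, with `_metrics` keyed by mode, and everything else (attribute names, the `log()` signature) is an assumption rather than code from this PR:

```python
class SelfDistillationLoggingSketch:
    """Stands in for the trainer subclass; only the merge/prefix/clear logic matters."""

    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
        mode = "eval" if self.control.should_evaluate else "train"
        # Average the buffered per-step values collected in _metrics.
        metrics = {key: sum(vals) / len(vals) for key, vals in self._metrics[mode].items()}
        if mode == "eval":
            metrics = {f"eval_{key}": val for key, val in metrics.items()}
        # Merge into the log dict and forward to the Trainer logging system.
        super().log({**logs, **metrics}, start_time)
        # Clear after each logging step so values don't leak across intervals.
        self._metrics[mode].clear()
```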
Force-pushed from 103b9d4 to a432c20.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
| "help": "Distillation objective mode. `sampled_token` uses token-level distillation on the sampled " | ||
| "completion tokens, `full_logits` uses full-vocabulary divergence, and `topk_logits` uses a top-k " | ||
| "approximation over the student support." | ||
| }, |
Incompatible base config defaults break SDFT at runtime
High Severity
SelfDistillationConfig defaults to distillation_mode="sampled_token" with distillation_alpha=0.5, but compute_sampled_token_self_distillation_loss raises a ValueError when distillation_alpha != 1.0. SDFTConfig inherits both defaults without overriding either, so any SDFTTrainer created with default config will crash at loss computation time. The old code avoided this because distillation_topk defaulted to 100, routing through the top-k path instead.
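In concrete terms, and assuming `SDFTConfig` is importable from wherever the experimental trainers live, either of the following hypothetical overrides would avoid the clash until the defaults are reconciled:

```python
# Hypothetical workaround; parameter names are taken from this PR's config.
# Either make distillation_alpha consistent with the sampled_token guard...
config = SDFTConfig(distillation_mode="sampled_token", distillation_alpha=1.0)

# ...or route through the top-k path, as the old defaults effectively did.
config = SDFTConfig(distillation_mode="topk_logits", distillation_topk=100)
```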


Summary
Major internal restructuring of the experimental self-distillation trainers. The core change: unify shared lifecycle logic in `BaseSelfDistillationTrainer`, while concrete trainers (`SDFTTrainer`, `SDPOTrainer`) own their algorithmic recipes.

The previous design mixed generic self-distillation lifecycle with SDPO-specific reward/scoring machinery in the same inheritance stack, meaning the base implicitly assumed every method needed reward-based optimization. This PR corrects that with a clean split: base owns infrastructure, trainers own semantics.
No intended change in functionality.
What changed
Base class now owns the common lifecycle:
model/tokenizer/generation setup, student-side rollout generation (transformers + vLLM), and teacher synchronization.
Concrete trainers own the method recipe:
rollout interpretation, teacher input construction, algorithm-specific metadata (rewards, advantages, etc.), and loss composition.
SDPO owns its identity.
Reward scoring and policy optimization are no longer baked into the base. SDPO explicitly manages its own reward functions, normalization, and hybrid loss.
Teacher regularization
is now a single semantic config parameter (`teacher_model_kind`: `"live"` / `"base"` / `"ema"`), with the base handling PEFT/non-PEFT details internally.

Removed
I’d like to touch on a few abstractions I’ve cleared out and the reasoning behind those calls. In short, I felt that `PromptTokenizer`, `SelfDistillationMixin`, and `OnlineRolloutMixin` were "leaky", masking the logic they were supposed to define and making the architecture harder to follow. `PromptTokenizer` logic now lives directly in the base, next to the rollout pipeline it serves. The two mixins were splitting orchestration, rollout mechanics, and algorithm-specific tasks (like reward scoring) across too many pieces. That responsibility is now split cleanly: `BaseSelfDistillationTrainer` handles the "heavy lifting" of the distillation machinery and rollout mechanics, while concrete trainers have full ownership over their specific batch finalization and loss logic.

Caveat
`generate_from_teacher` isn't ported; the base currently assumes student-sourced rollouts. There's still hanging code that references it, so we need to decide whether this is a capability worth porting before this PR can land.

Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Note
High Risk
Large refactor of `BaseSelfDistillationTrainer`, SDPO/SDFT training loops, and configuration semantics (teacher selection, distillation modes, reward/policy loss), which could subtly change rollout generation, loss computation, or teacher syncing behavior despite aiming for no functional change.

Overview
Centralizes the experimental self-distillation trainer lifecycle into a redesigned `BaseSelfDistillationTrainer` (shared rollout sampling via Transformers/vLLM, prompt tokenization, teacher model setup/sync, and unified distillation-loss plumbing), and removes the older `SelfDistillationMixin` / `OnlineRolloutMixin` / `teacher_context` abstractions.

Updates SDPO to explicitly own reward model/function setup, reward normalization/advantage computation, and policy-vs-distillation loss composition (adds `policy_only` and reworks `hybrid` / `distillation_only` paths), while SDFT is ported onto the new base and now finalizes batches by constructing teacher inputs and applying its response-mask skipping logic.

Reworks configuration and docs to use `distillation_mode` (`sampled_token` / `full_logits` / `topk_logits`) and `teacher_model_kind` (`base` / `live` / `ema`) with `teacher_sync_steps` / `teacher_update_rate`, and adds new unit tests covering teacher selection, prompt tokenization behavior, and distillation loss correctness/IS clipping.
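Putting the new configuration surface together, a hedged usage sketch; the parameter names are the ones listed above, while the values and the assumption that `SDPOConfig` is importable from the experimental trainers module are mine:

```python
# Hypothetical usage; values are arbitrary.
config = SDPOConfig(
    distillation_mode="topk_logits",   # "sampled_token" | "full_logits" | "topk_logits"
    teacher_model_kind="ema",          # "base" | "live" | "ema"
    teacher_sync_steps=64,             # sync the EMA teacher every 64 steps
    teacher_update_rate=0.05,          # EMA rate applied at each sync
)
```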