35 commits
06f02a8
v0.1 transition sdft into unified base
LeonEricsson Apr 15, 2026
be1bcbc
sdft transition v1 complete, starting on sdpo
LeonEricsson Apr 15, 2026
0628701
sdpo transitioned, needs testing
LeonEricsson Apr 15, 2026
55111ff
remove legacy trainers
LeonEricsson Apr 15, 2026
81def8a
sdft and sdpo transitioned and tested with new base
LeonEricsson Apr 16, 2026
bad6b62
restructure training batch builder
LeonEricsson Apr 16, 2026
ef43c95
nits
LeonEricsson Apr 16, 2026
efe0eda
wip removing mixin
LeonEricsson Apr 16, 2026
fa1a8f3
remove mixin, refactoring and cleanup
LeonEricsson Apr 16, 2026
6a7d5a8
always set teacher_model
LeonEricsson Apr 16, 2026
56b2fd1
align generation tokenization with grpotrainer
LeonEricsson Apr 16, 2026
4a9d527
fix: generation_kwargs bug
LeonEricsson Apr 16, 2026
196feee
fix: incorrect import source
LeonEricsson Apr 16, 2026
3c87400
fixes: cleanup, standardized tokenization, distill loss=0 fix, sdpo c…
LeonEricsson Apr 17, 2026
d2a78e2
tests: ported old tests + new tests for base class
LeonEricsson Apr 17, 2026
8807088
couple more tests and test cleanup
LeonEricsson Apr 18, 2026
0612699
test: nit fix
LeonEricsson Apr 18, 2026
3d0cd72
move loss aggregation to loss_util + a few docstrings
LeonEricsson Apr 18, 2026
aa36955
fix: emit accumulated _metrics via log() override
LeonEricsson Apr 20, 2026
a432c20
fix: minor cursor issues + config docstrings
LeonEricsson Apr 20, 2026
e30ca04
fix: rename full logit distillation+topk into explicit flags
LeonEricsson Apr 21, 2026
3a9ecb2
fix(self-distillation): warn on preloaded peft students
LeonEricsson Apr 21, 2026
03718eb
docs: cleanup
LeonEricsson Apr 22, 2026
d0e6657
feat: srt implemented and validated, sdzero wip
LeonEricsson Apr 20, 2026
517d4f4
feat: sdzero phase 2
LeonEricsson Apr 20, 2026
f35011c
docs: update paper index
LeonEricsson Apr 20, 2026
bd00d4b
fix: default sync to
LeonEricsson Apr 21, 2026
c1af84e
fix: adapt new distillation config params
LeonEricsson Apr 21, 2026
4003c2f
fix: default behavior when args=None
LeonEricsson Apr 21, 2026
289bc8c
wip: review minors
LeonEricsson Apr 21, 2026
5391c6c
docs: cleanup
LeonEricsson Apr 22, 2026
04e1a5d
fix: srt chat template and tokenization of dataset
LeonEricsson Apr 22, 2026
701e8eb
fix: wrap teacher prompt building in tokenizer apply chat template
LeonEricsson Apr 22, 2026
17fe1c2
feat: add chat_kwargs and cleanup code
LeonEricsson Apr 22, 2026
ab20ace
docs: add docstring
LeonEricsson Apr 22, 2026
66 changes: 64 additions & 2 deletions docs/source/paper_index.md
@@ -1641,8 +1641,8 @@ from trl.experimental.sdpo import SDPOConfig, SDPOTrainer

training_args = SDPOConfig(
distillation_alpha=0.5, # Jensen-Shannon divergence (recommended)
-distillation_topk=100, # Top-K logit distillation approximation
-full_logit_distillation=True, # Required for top-K logit-level SDPO
+distillation_mode="topk_logits", # Explicitly select top-K logit distillation
+distillation_topk=100, # Required for top-K logit distillation
distillation_is_clip=2.0, # Importance sampling clipping
distillation_weight=1.0, # Weight for self-distillation loss
sdpo_policy_loss_mode="distillation_only",
@@ -1689,6 +1689,7 @@ dataset = Dataset.from_dict(

training_args = SDFTConfig(
distillation_alpha=0.5,
distillation_mode="topk_logits",
distillation_topk=5,
max_completion_length=64,
)
@@ -1739,6 +1740,67 @@ Expected dataset columns:

For more details, see the [SSD Trainer documentation](ssd_trainer).

### Self-Distillation Zero: Self-Revision Turns Binary Rewards into Dense Supervision

**📜 Paper**: https://huggingface.co/papers/2604.12002

SD-ZERO turns binary verifier rewards into dense supervision in two phases. Phase 1 — **Self-Revision Training (SRT)** — first has a model answer a problem `x` with an initial attempt `y_init`. A binary verifier then decides whether `y_init` is correct and chooses a control prompt `P_r`: rephrase the solution if the attempt is correct, or restart if it is not. Conditioned on `(x, y_init, P_r)`, the model samples revised answers and keeps only revisions `y_revised` that the verifier marks correct. Those accepted self-revision traces are then used for supervised learning with a joint objective: predict `y_revised` given `(x, y_init, P_r)`, and predict the full assistant trace `[y_init, P_r, y_revised]` from `x` (a sketch of this joint loss follows the column list below). Phase 1 is implemented as [`experimental.sdzero.SRTTrainer`], and the companion collection script [`trl/experimental/sdzero/srt_collect.py`] is the recommended way to build the offline revision dataset.

```python
from datasets import load_from_disk

from trl.experimental.sdzero import SRTConfig, SRTTrainer

training_args = SRTConfig(
include_revision_loss=True, # L_revision term
include_generation_loss=True, # L_generation term
assistant_turn_template="{y_init}\n\n{control_prompt}\n\n{y_revised}",
)

trainer = SRTTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct",
args=training_args,
train_dataset=load_from_disk("/path/to/revision_dataset"),
)
trainer.train()
```

Expected dataset columns:

- `problem`
- `y_init`
- `control_prompt`
- `y_revised`
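
A minimal sketch of the joint SRT objective described above, assuming a causal LM and pre-tokenized batches; the helper name `masked_lm_loss` and the batch layout are illustrative, not the actual `SRTTrainer` internals:

```python
import torch.nn.functional as F


def masked_lm_loss(logits, labels):
    # Standard next-token cross-entropy; positions labeled -100 are ignored,
    # which is how each term masks out its conditioning context.
    return F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        labels[:, 1:].reshape(-1),
        ignore_index=-100,
    )


def srt_loss(model, revision_batch, generation_batch,
             include_revision_loss=True, include_generation_loss=True):
    # L_revision: predict y_revised conditioned on (x, y_init, P_r);
    # labels mask everything except the y_revised span.
    # L_generation: predict the full trace [y_init, P_r, y_revised] from x;
    # labels mask only the prompt x.
    loss = 0.0
    if include_revision_loss:
        out = model(input_ids=revision_batch["input_ids"])
        loss = loss + masked_lm_loss(out.logits, revision_batch["labels"])
    if include_generation_loss:
        out = model(input_ids=generation_batch["input_ids"])
        loss = loss + masked_lm_loss(out.logits, generation_batch["labels"])
    return loss
```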

Phase 2 — **On-Policy Self-Distillation** — distills the reviser back into the generator. At each training step, the student generates a response `y` on-policy. The teacher context is the problem `x`, the on-policy generation `y` (taking the place of `y_init`), and the verifier-selected `P_r`. By default, [`experimental.sdzero.SDZeroTrainer`] matches the paper's frozen SRT teacher and full-vocabulary `D_KL(student || teacher)` objective (sketched after the column list below).

```python
from datasets import Dataset

from trl.experimental.sdzero import SDZeroConfig, SDZeroTrainer

dataset = Dataset.from_list([
{"prompt": [{"role": "user", "content": "...problem..."}], "answer": "...gold answer..."},
])

training_args = SDZeroConfig(
max_completion_length=512,
assistant_turn_template="{y}\n\n{control_prompt}\n\n",
)

trainer = SDZeroTrainer(
model="path/to/srt-checkpoint",
args=training_args,
train_dataset=dataset,
)
trainer.train()
```

Expected dataset columns:

- `prompt` (conversational list or plain string)
- `answer` (gold answer; passed to the binary verifier)
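
A minimal sketch of the full-vocabulary `D_KL(student || teacher)` term described above, assuming student and teacher logits are already aligned over the completion tokens; a sketch under those assumptions, not the `SDZeroTrainer` implementation:

```python
import torch.nn.functional as F


def reverse_kl(student_logits, teacher_logits, completion_mask):
    # Per-token D_KL(student || teacher), summed over the vocabulary.
    student_logp = F.log_softmax(student_logits, dim=-1)
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    per_token_kl = (student_logp.exp() * (student_logp - teacher_logp)).sum(-1)
    # Average over completion positions only (mask is 1 on completion tokens).
    return (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1)
```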

## Distributed Training

### ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
1 change: 1 addition & 0 deletions docs/source/sdft_trainer.md
@@ -32,6 +32,7 @@ dataset = Dataset.from_dict(
training_args = SDFTConfig(
output_dir="sdft-model",
distillation_alpha=0.5,
distillation_mode="topk_logits",
distillation_topk=5,
max_completion_length=64,
)
11 changes: 6 additions & 5 deletions docs/source/sdpo_trainer.md
@@ -11,8 +11,9 @@ In the current TRL implementation:
- the default SDPO policy loss mode is `distillation_only`
- `hybrid` mode is also available to combine the base policy loss with the self-distillation loss
- supported teacher regularization modes are `ema` and `none`
-- `distillation_topk` is only valid when `full_logit_distillation=True`
-- when `full_logit_distillation=False`, SDPO uses token-level reverse KL and requires `distillation_alpha=1.0`
+- `distillation_mode` selects between `sampled_token`, `full_logits`, and `topk_logits` (see the sketch below)
+- `distillation_topk` is only valid when `distillation_mode="topk_logits"`
+- when `distillation_mode="sampled_token"`, SDPO uses token-level reverse KL and requires `distillation_alpha=1.0`
- environment feedback can be injected into teacher reprompts when the dataset exposes a `privileged_context` column
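
A short sketch of selecting each mode via `SDPOConfig`, mirroring the constraints listed above; the surrounding alpha values are illustrative and defaults may differ:

```python
from trl.experimental.sdpo import SDPOConfig

# Token-level reverse KL on the sampled tokens; this mode requires alpha = 1.0.
sampled = SDPOConfig(distillation_mode="sampled_token", distillation_alpha=1.0)

# Full-vocabulary logit distillation; enables non-reverse divergences.
full = SDPOConfig(distillation_mode="full_logits", distillation_alpha=0.5)

# Top-K approximation of the logit-level loss; distillation_topk is only
# valid in this mode.
topk = SDPOConfig(distillation_mode="topk_logits", distillation_topk=100)
```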

## Expected dataset columns
@@ -38,8 +39,8 @@ dataset = Dataset.from_dict(

training_args = SDPOConfig(
output_dir="sdpo-model",
-distillation_topk=100, # Top-K logit distillation approximation
-full_logit_distillation=True, # Required for top-K; enables non-reverse divergences
+distillation_mode="topk_logits", # Explicitly select top-K logit distillation
+distillation_topk=100, # Required when using top-K logit distillation
include_environment_feedback=True, # Use dataset privileged_context for teacher reprompts
)

@@ -88,7 +89,7 @@ python trl/experimental/sdpo/sdpo.py \
--num_generations 8 \
--generation_batch_size 32 \
--distillation_alpha 1.0 \
---full_logit_distillation false \
+--distillation_mode sampled_token \
--sdpo_policy_loss_mode hybrid \
--report_to none \
--eval_strategy steps \