Upload testing suite for DistillationTrainer #5615
cmpatino wants to merge 5 commits into huggingface:main
Conversation
qgallouedec
left a comment
Thanks for adding tests for `DistillationTrainer`; it had none, so this is directionally welcome. A few things worth addressing before merge:
- Heavy use of `DistillationTrainer.__new__(...)` + manual attribute assignment, plus five `Dummy*`/custom mock classes. Every test assembles a fake trainer by bypassing `__init__` and setting ~10 private attributes inline. This couples every test to internal attribute names. Any rename or new attribute used in `compute_loss` silently breaks the suite. It also cuts against two principles: consistency (the rest of `tests/experimental/`, e.g. `test_gkd_trainer.py`, loads a tiny real model and exercises the real `__init__`) and simplicity (the mock scaffolding and duplicated attribute setup in `test_server_teacher_path_handles_variable_prompt_lengths`/`..._padded_completions` is exactly that). A tiny-model fixture would remove most of it.
- No end-to-end test. One short `trainer.train()` on a tiny student + teacher would catch far more than the current mocked suite combined, and would make most of the attribute-juggling tests redundant. It would also be more aligned with the principle of testing behavior over implementation (see the sketch below).
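For illustration, a behavior-level test of that kind might look roughly like the sketch below. The import path, tiny model id, dataset name, and the `teacher_model`/`processing_class` argument names are assumptions to be checked against the actual `DistillationTrainer` API, not a copy of it:

```python
# Minimal end-to-end sketch (assumed API): adapt the import path, model/dataset ids,
# and argument names to whatever DistillationTrainer actually exposes.
import tempfile

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl.experimental.distillation import DistillationConfig, DistillationTrainer  # assumed path


def test_train_runs_end_to_end():
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"  # assumed tiny test model
    student = AutoModelForCausalLM.from_pretrained(model_id)
    teacher = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")  # assumed dataset

    with tempfile.TemporaryDirectory() as tmp_dir:
        args = DistillationConfig(
            output_dir=tmp_dir,
            per_device_train_batch_size=2,
            max_steps=2,
            report_to="none",
        )
        trainer = DistillationTrainer(
            model=student,
            teacher_model=teacher,  # assumed argument name
            args=args,
            train_dataset=dataset,
            processing_class=tokenizer,
        )
        trainer.train()

        # Behavior-level check: training actually stepped, without touching any internals.
        assert trainer.state.global_step == 2
```

One test along these lines exercises the real `__init__`, the real `compute_loss`, and the data wiring at once, which is exactly what the attribute-level mocks cannot guarantee.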
| torch.testing.assert_close(local_loss, server_loss) |
| def test_sampled_mode_keeps_teacher_argmax_for_forward_support(): |
This is tautological. The expected value is computed by calling the same private helper `_jsd_divergence` with the same support/mask construction that `compute_loss` uses internally, so it tests wiring, not semantics. Any bug inside `_jsd_divergence` is invisible. Derive the expected value from first principles (plain JSD formula) or drop it.
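For example, the expected value could come from the plain generalized-JSD definition, independent of any trl helper. A minimal sketch, assuming the convention JSD_beta(P || Q) = beta * KL(P || M) + (1 - beta) * KL(Q || M) with M = beta * P + (1 - beta) * Q and the teacher playing the role of P; whichever convention the trainer actually uses would need to be matched:

```python
import torch
import torch.nn.functional as F


def reference_jsd(student_logits: torch.Tensor, teacher_logits: torch.Tensor, beta: float = 0.5) -> torch.Tensor:
    """Generalized JSD computed straight from the definition (teacher as P, student as Q)."""
    p = F.softmax(teacher_logits, dim=-1)
    q = F.softmax(student_logits, dim=-1)
    m = beta * p + (1 - beta) * q
    # KL terms computed directly; fine for random test logits where probabilities stay away from 0.
    kl_pm = torch.sum(p * (p.log() - m.log()), dim=-1)
    kl_qm = torch.sum(q * (q.log() - m.log()), dim=-1)
    return (beta * kl_pm + (1 - beta) * kl_qm).mean()
```

The test would then compare this reference against the loss the trainer produces for the same inputs with `torch.testing.assert_close`, so a bug inside `_jsd_divergence` actually shows up.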
| report_to="none", |
| ) |
| assert caught == [] |
Fragile: any unrelated `DeprecationWarning` from a dependency will break it. The `pytest.raises` is already the real assertion.
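If the empty-warnings assertion is kept, one way to make it less brittle is to only count deprecation warnings that originate from trl itself. A sketch, assuming `caught` is the record list from `warnings.catch_warnings(record=True)` and that filtering on the emitting filename is good enough:

```python
# Ignore DeprecationWarnings coming from dependencies; only fail on trl's own.
trl_deprecations = [
    w
    for w in caught
    if issubclass(w.category, DeprecationWarning) and "trl" in (w.filename or "")
]
assert trl_deprecations == []
```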
Thank you for the feedback! I addressed your comments with the following changes:
- Followed GKD's tests and used a small model and dataset to test the trainer.
- Improved the Liger tests by following the testing approach from the SFTTrainer.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit 1033b7c.

What does this PR do?
Uploads tests for the DistillationTrainer.

Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Note
Low Risk
Test-only change; no production code paths are modified, with risk limited to potential CI flakiness due to model/dataset dependencies and optional accelerator coverage.
Overview
Adds a new experimental test module covering `DistillationTrainer` and related utilities.

The suite validates `DistillationConfig` argument constraints (teacher server vs. Liger, server URL required, reverse-KL mode restrictions), correctness of `build_teacher_request_inputs` trimming/prompt-length logic, and `generalized_jsd_loss` behavior across `beta`, temperature, and reductions. It also adds integration-style tests that run a short train with a local teacher (including checkpoint output), optionally exercises the Liger path, verifies parity between local-teacher and mocked vLLM-teacher-server losses in sampled top-1 mode, and checks that `_RepeatBatchDataLoader` forwards `set_epoch`.

Reviewed by Cursor Bugbot for commit 43950bb. Bugbot is set up for automated code reviews on this repo.
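As an illustration of the config-constraint tests described above, they might be parameterized roughly as follows. The import path and the field names `teacher_server_url` / `use_liger_kernel` / `use_teacher_server` are hypothetical stand-ins, not the real `DistillationConfig` fields:

```python
# Hypothetical sketch: field names and import path are placeholders, not the real API.
import pytest

from trl.experimental.distillation import DistillationConfig  # assumed import path


@pytest.mark.parametrize(
    "bad_kwargs",
    [
        # Teacher server and Liger kernels together (assumed to be mutually exclusive).
        {"teacher_server_url": "http://localhost:8000", "use_liger_kernel": True},
        # Server mode requested without a URL (assumed to be rejected).
        {"use_teacher_server": True, "teacher_server_url": None},
    ],
)
def test_config_rejects_invalid_combinations(tmp_path, bad_kwargs):
    with pytest.raises(ValueError):
        DistillationConfig(output_dir=str(tmp_path), report_to="none", **bad_kwargs)
```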
DistillationConfigargument constraints (teacher server vs liger, server URL required, reverse-KL mode restrictions), correctness ofbuild_teacher_request_inputstrimming/prompt-length logic, andgeneralized_jsd_lossbehavior acrossbeta, temperature, and reductions. It also adds integration-style tests that run a short train with a local teacher (including checkpoint output), optionally exercises the liger path, verifies parity between local-teacher and mocked vLLM-teacher-server losses in sampled top-1 mode, and checks_RepeatBatchDataLoaderforwardsset_epoch.Reviewed by Cursor Bugbot for commit 43950bb. Bugbot is set up for automated code reviews on this repo. Configure here.