fix(distillation): reverse-KL server path NaN on variable completion length #1
Conversation
Pull request overview
Fixes a NaN-gradient issue in DistillationTrainer’s server-backed reverse-KL/JSD loss when batches contain variable completion lengths, and adds initial test coverage for the trainer’s server path.
Changes:

- Neutralize `-inf` teacher logprob sentinels at ignore positions (`labels == -100`) in `_compute_server_sparse_top_1_divergence_loss` to prevent `+inf`/NaN propagation through tail-bucket + divergence math.
- Expand inline documentation around the `-inf` sentinel contract in `_get_teacher_token_logprobs_from_server`.
- Add new unit + functional tests covering server logprob padding semantics and the ragged-batch reverse-KL/JSD regression.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `trl/experimental/distillation/distillation_trainer.py` | Adds padding-position masking for server-backed sparse top-1 divergence loss to prevent NaN gradients on ragged batches. |
| `tests/experimental/test_distillation_trainer.py` | Introduces the first DistillationTrainer tests, including regression coverage for variable-length server batches. |
```python
trainer.train()

return [
    (name, param.grad.detach().clone())
    for name, param in trainer.model.named_parameters()
    if param.grad is not None
]
```
The end-to-end tests read `param.grad` after `trainer.train()` completes. With the underlying `transformers.Trainer` loop, gradients are typically cleared via `zero_grad()` after each optimizer step, so at the end of training most/all `param.grad` entries can be `None` even when backward was finite. This can make the test fail spuriously or not actually assert the intended regression. Consider capturing grads before they are cleared (e.g., via a callback hook around the optimizer step / end of accumulation, or by manually running 2 microbatches with `compute_loss` + backward and checking grads before the optimizer step/zeroing).
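One way to implement the suggested capture, as a sketch: a `TrainerCallback` (hypothetical name `GradCaptureCallback`) that snapshots gradients in `on_pre_optimizer_step`, which recent `transformers` versions fire after accumulation but before the step and zeroing.

```python
from transformers import TrainerCallback

class GradCaptureCallback(TrainerCallback):
    """Snapshot gradients after accumulation, before the optimizer clears them."""

    def __init__(self, model):
        self.model = model
        self.captured = []

    def on_pre_optimizer_step(self, args, state, control, **kwargs):
        # Gradients are fully accumulated here but not yet zeroed.
        self.captured = [
            (name, param.grad.detach().clone())
            for name, param in self.model.named_parameters()
            if param.grad is not None
        ]
```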
```python
# ---------------------------------------------------------------------------
# T1 — Unit: _get_teacher_token_logprobs_from_server padding must be finite.
```
The section header says "padding must be finite", but this test explicitly asserts that padding positions contain the `-inf` sentinel. Updating the header/comment to match the asserted contract would avoid confusion for future readers.
```diff
- # T1 — Unit: _get_teacher_token_logprobs_from_server padding must be finite.
+ # T1 — Unit: _get_teacher_token_logprobs_from_server padding uses `-inf` sentinel.
```
Force-pushed from `aeb5aa8` to `52252d2`. Commit message:

fix(distillation): reverse-KL server path NaN on variable completion length
When ``use_teacher_server=True`` with ``beta > 0`` and ``bs * grad_accum > 1``,
the reverse-KL server path leaked NaN into the backward pass whenever
per-sample completion lengths differed within a batch.
Root cause
----------
``_get_teacher_token_logprobs_from_server`` fills the rectangular (B, T)
output tensor with the TRL house sentinel ``float("-inf")`` at intra-batch
padding positions (the tail of shorter samples). The forward-KL server path
(``_compute_server_forward_kl_loss``) neutralises this via
``torch.where(teacher > -inf, ..., -inf)`` plus a support mask threaded
through ``_add_tail_bucket``; the reverse-KL server path
(``_compute_server_sparse_top_1_divergence_loss``) did not. Both paths
landed in the same commit (huggingface#5407) -- an oversight, not deliberate
asymmetry.
Unmasked, the -inf sentinel produces a teacher distribution [-inf, 0]
after ``_add_tail_bucket`` and +inf in ``_jsd_divergence``'s forward pass
(clamped to ``finfo.max`` by ``nan_to_num``), but NaN in the backward
pass: autograd's chain rule does not respect ``nan_to_num``, so the
pre-clamp +inf leaks through as NaN gradient.
Fix
---
Mirror the forward-KL server path's masking: after the ``isfinite`` checks
that guard required positions, replace the -inf sentinel with a finite
zero at all known padding positions (``labels == -100``) via
``torch.where``. The label mask in ``_reduce_divergence_loss`` still
excludes those positions from the final loss; the new neutralisation
prevents their -inf values from propagating through ``_add_tail_bucket``
and ``_jsd_divergence`` into the autograd graph.
Tests
-----
``tests/experimental/test_distillation_trainer.py`` is new (DistillationTrainer
had zero dedicated tests at v1.1.0):
- Sentinel contract at the server-path getter.
- The reverse-KL mask pattern produces finite forward AND backward on a
ragged batch.
- End-to-end training step under ``per_device_train_batch_size=1``,
``gradient_accumulation_steps=2``, variable completion lengths, with a
mocked ``VLLMClient``. Covers ``beta=1.0`` (reverse KL) and ``beta=0.5``
(JSD).
Reproduction pre-fix: ``grad_norm=nan`` on step 1.
Reproduction post-fix: ``grad_norm`` finite; padding positions receive
zero gradient (correctly excluded from the learning signal).
A parallel audit of GKDTrainer confirmed it is not vulnerable to the same
class of bug: its teacher runs in-process on a dense rectangular batch,
with no HTTP ragged-to-rectangular reassembly and no -inf sentinel in the
GKD loss path.
Refs: huggingface#5407.
Closing to replace with a template-compliant PR after rebasing onto upstream/main.
Force-pushed from `52252d2` to `3c0d9ae`.
Summary
When `use_teacher_server=True` with `beta > 0` and `per_device_train_batch_size * gradient_accumulation_steps > 1`, the reverse-KL server path leaks NaN into the backward pass whenever per-sample completion lengths differ within a batch. The forward loss is finite (`nan_to_num` clamps it); `grad_norm == nan` appears at step 1.

This PR adds the missing defensive mask to `_compute_server_sparse_top_1_divergence_loss`, mirroring the masking already present in the forward-KL server path (`_compute_server_forward_kl_loss`). It also adds the first dedicated unit/integration tests for `DistillationTrainer` (the trainer had zero tests at v1.1.0).

Reproduction
| `per_device_train_batch_size` | `gradient_accumulation_steps` | Step-1 result | Status |
|---|---|---|---|
| 1 | 1 | loss=0.109, grad_norm=4.80 | OK |
| 2 | 1 | loss=0.112, grad_norm=nan | FAIL |
| 1 | 2 | loss=0.109, grad_norm=nan | FAIL |

Trigger = `use_teacher_server=True` + `beta > 0` + `loss_add_tail=True` + `bs*ga > 1` + per-sample completion lengths that differ within a batch. The official doc snippet for the server path (lines 72-87 of `distillation_trainer.md`) omits `per_device_train_batch_size` / `gradient_accumulation_steps`, inheriting the HF defaults `bs=8, ga=1`. A user following the doc verbatim hits the bug as soon as their dataset has variable-length completions.
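For concreteness, a hedged sketch of a triggering configuration; the config class name `DistillationConfig` and the exact argument surface are assumptions based on the parameter names quoted above, not copied from TRL's docs:

```python
# Hypothetical sketch: class name and argument surface assumed, not verified
# against TRL's actual experimental API.
from trl.experimental.distillation import DistillationConfig

config = DistillationConfig(
    use_teacher_server=True,        # teacher logprobs fetched from the server
    beta=1.0,                       # beta > 0 selects the reverse-KL/JSD path
    loss_add_tail=True,             # routes through _add_tail_bucket
    per_device_train_batch_size=1,  # bs * ga > 1 makes ragged batches possible
    gradient_accumulation_steps=2,
)
```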
Root cause

`_get_teacher_token_logprobs_from_server` fills the rectangular `(B, T)` output tensor with the TRL house sentinel `float("-inf")` at intra-batch padding positions (the tail of shorter samples).

Two server-path consumers of that tensor:

- `_compute_server_forward_kl_loss` (`beta == 0`): masks the sentinel via `torch.where(teacher > -inf, ..., -inf)` and threads a `support_mask` through `_add_tail_bucket`. Safe.
- `_compute_server_sparse_top_1_divergence_loss` (`beta > 0`): did not mask. The `-inf` flowed through `_add_tail_bucket` (producing the teacher distribution `[-inf, 0]`) and `_jsd_divergence` (producing `+inf` in forward, clamped by `nan_to_num`, but NaN in backward, because autograd's chain rule does not respect `nan_to_num`; the pre-clamp `+inf` tensor's gradient still leaks, as sketched below).

`git blame` shows both code paths landed in the original DistillationTrainer PR (huggingface#5407, commit `c475b979`, same author, same moment). The asymmetric masking is an oversight, not a deliberate design choice: there is no comment explaining why the reverse-KL path would be safe without the mask, and the generic warning at lines 1285-1287 about `-inf → NaN` applies identically to both paths.
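To make that backward-pass mechanism concrete, here is a minimal self-contained PyTorch sketch of the same failure shape (illustrative only, not the trainer's code): the forward value is clamped finite by `nan_to_num`, but the chain rule still multiplies through the op that produced the `+inf`, yielding a NaN gradient.

```python
import torch

x = torch.tensor([100.0], requires_grad=True)
y = torch.exp(x * x)     # forward: exp(10000) overflows to +inf in float32
z = torch.nan_to_num(y)  # forward: +inf clamped to finfo.max, so the loss is finite
z.backward()

print(z.item())  # ~3.4e38, finite forward value
print(x.grad)    # tensor([nan]): the pre-clamp +inf poisons the gradient
```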
Fix

After the existing `isfinite` validation at the start of `_compute_server_sparse_top_1_divergence_loss`, replace the `-inf` sentinel at known padding positions (`labels == -100`) with a finite zero via `torch.where`, before the shared `_compute_sparse_top_1_divergence_loss` helper runs. The label mask in `_reduce_divergence_loss` still excludes those positions from the final loss; the new block only prevents the sentinel from reaching the autograd graph through the tail-bucket/JSD math.

The `-inf` sentinel at the getter is intentionally preserved so that the existing `isfinite` validator at `_compute_server_sparse_top_1_divergence_loss:1338-1360` can continue to distinguish "missing required data" (legitimate `-inf` in `required` positions; raises) from "padding" (`-inf` at `~required`; neutralised here).
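The mask itself is small; a simplified sketch of the pattern (shapes and the helper name are illustrative, not the exact diff):

```python
import torch

def neutralize_padding_sentinel(teacher_logprobs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Replace the -inf padding sentinel with a finite zero at ignore positions.

    The label mask downstream still drops these positions from the loss; the
    zeros only keep -inf out of the tail-bucket / JSD autograd graph.
    """
    padding = labels.eq(-100)
    return torch.where(padding, torch.zeros_like(teacher_logprobs), teacher_logprobs)

# Toy ragged batch: sample 0 is shorter, so its tail carries the sentinel.
teacher = torch.tensor([[-0.5, float("-inf"), float("-inf")],
                        [-0.4, -0.7, -1.2]])
labels = torch.tensor([[5, -100, -100],
                       [7, 9, 11]])
print(neutralize_padding_sentinel(teacher, labels))  # finite everywhere
```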
Tests

New file `tests/experimental/test_distillation_trainer.py` (no prior tests existed):

- `test_server_logprobs_variable_lengths_place_neg_inf_sentinel_at_padding`: pins the sentinel contract at the server getter; real completion positions are preserved, the padding tail carries `-inf`.
- `test_reverse_kl_padding_mask_keeps_forward_and_backward_finite`: direct regression; given `-inf`-padded teacher tensors, apply the Strategy-B mask, then verify `_add_tail_bucket` + `_jsd_divergence` (`beta=1`) produces a finite forward AND a finite backward.
- `TestDistillationTrainerServerPathVariableCompletion`: end-to-end `DistillationTrainer.train()` with `trl-internal-testing/tiny-Qwen2ForCausalLM-2.5`, `per_device_train_batch_size=1`, `gradient_accumulation_steps=2`, a dataset with clearly different completion lengths, and a `VLLMClient` monkey-patched to return canned ragged logprobs. Two cases: `beta=1.0` (reverse KL) and `beta=0.5` (JSD). Both assert that `loss` and `grad_norm` from the training log-history are finite (Trainer zeroes per-parameter `.grad` post-step, so the log-history path is the reliable regression signal).

Empirical before/after
Same ragged-batch input; only the mask differs. Padding positions (indices 1, 2 of sample 0) are correctly zero-grad post-fix: they remain excluded from the learning signal, just without NaN contamination.
Verified against TRL-spec env
Built a fresh env matching TRL v1.1.0's declared `pyproject.toml` ranges exactly (and not a superset), installed the editable fix branch, and ran `pytest`. No TRL-declared version was relaxed. The fix is a pure-Python change that does not touch the vLLM wire format or server contract.
Scope note: GKDTrainer is not affected

I audited the sibling `trl.experimental.gkd.GKDTrainer` in parallel. It is not vulnerable to the same class of bug:

- Its teacher runs in-process (`gkd_trainer.py:368`), on the same dense rectangular `(B, T)` batch as the student, with no HTTP ragged-to-rectangular reassembly.
- No `-inf` sentinel appears in its loss path (`grep -c 'float(\"-inf\")' trl/experimental/gkd/gkd_trainer.py` returns 0).
- Its reduction uses boolean indexing (`jsd[mask]` with `mask = labels != -100`) rather than `torch.where(..., -inf, ...)`, so even if `-inf` ever appeared it would be filtered pre-reduction (toy sketch below).
- Existing GKD tests already exercise padded batches (`bs=2` in `test_gkd_trainer.py:217`); that path is tested for GKD but was entirely uncovered for DistillationTrainer.

The NaN-leak pattern is isolated to DistillationTrainer's server path.
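A toy forward-pass illustration of that boolean-indexing point (generic tensors, not GKD's code): positions excluded by the mask never enter the reduced tensor, so a non-finite value there cannot reach the loss value.

```python
import torch

jsd = torch.tensor([0.3, float("inf"), 0.1])  # per-token divergences, one poisoned
labels = torch.tensor([5, -100, 7])
mask = labels != -100
loss = jsd[mask].mean()  # the inf entry is filtered before the reduction
print(loss)              # tensor(0.2000)
```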
Test plan for reviewer
- `pytest tests/experimental/test_distillation_trainer.py -v` passes on a clean checkout of this branch.
- `pytest tests/experimental/test_distillation_trainer.py -v` fails on `main` (the regression is caught).
- A real `trl vllm-serve` teacher with a variable-length dataset at `bs=1, ga=2, beta=1.0` no longer produces `grad_norm=nan`.

Refs
- `DistillationTrainer` for efficient on-policy distillation: huggingface/trl#5407 (the same commit introduced both the masked forward-KL and the unmasked reverse-KL paths).
- `DistillationTrainer`: huggingface/trl#5500 (did not touch the loss math).