fix(distillation): reverse-KL server path NaN on variable completion length #5594
k1064190 wants to merge 6 commits into huggingface:main
Conversation
When ``use_teacher_server=True`` with ``beta > 0`` and ``bs * grad_accum > 1``,
the reverse-KL server path leaked NaN into the backward pass whenever
per-sample completion lengths differed within a batch.
Root cause
----------
``_get_teacher_token_logprobs_from_server`` fills the rectangular (B, T)
output tensor with the TRL house sentinel ``float("-inf")`` at intra-batch
padding positions (the tail of shorter samples). The forward-KL server path
(``_compute_server_forward_kl_loss``) neutralises this via
``torch.where(teacher > -inf, ..., -inf)`` plus a support mask threaded
through ``_add_tail_bucket``; the reverse-KL server path
(``_compute_server_sparse_top_1_divergence_loss``) did not. Both paths
landed in the same commit (huggingface#5407) -- an oversight, not deliberate
asymmetry.
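To make the sentinel contract concrete, a toy illustration (shapes and values assumed; this is not the actual server response path):

```python
import torch

# Two completions of lengths 3 and 1 reassembled into a rectangular (B, T)
# teacher logprob tensor; the shorter sample's tail positions carry the
# float("-inf") sentinel that the reverse-KL path previously consumed unmasked.
teacher_logprobs = torch.full((2, 3), float("-inf"))
teacher_logprobs[0] = torch.tensor([-0.1, -0.2, -0.3])  # length-3 completion
teacher_logprobs[1, 0] = -0.5                           # length-1 completion
# tensor([[-0.1000, -0.2000, -0.3000],
#         [-0.5000,    -inf,    -inf]])
```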
Unmasked, the -inf sentinel produces a teacher distribution [-inf, 0]
after ``_add_tail_bucket`` and +inf in ``_jsd_divergence``'s forward pass
(clamped to ``finfo.max`` by ``nan_to_num``), but NaN in the backward
pass: autograd's chain rule does not respect ``nan_to_num``, so the
pre-clamp +inf leaks through as NaN gradient.
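A minimal standalone sketch of that failure mode (toy two-bucket distribution, not the trainer's helpers; the clamp value is an assumption):

```python
import torch

# One position, two buckets; the first bucket carries the -inf padding sentinel.
teacher_logprobs = torch.tensor([[float("-inf"), 0.0]])

student_logits = torch.zeros(1, 2, requires_grad=True)
log_s = torch.log_softmax(student_logits, dim=-1)

# Reverse-KL style term p_s * (log p_s - log p_t): +inf at the sentinel position.
rkl = log_s.exp() * (log_s - teacher_logprobs)

# nan_to_num keeps the *forward* value finite, but the pre-clamp +inf still
# shows up as a local gradient in the backward of the multiply above.
loss = torch.nan_to_num(rkl, posinf=torch.finfo(rkl.dtype).max).sum()
loss.backward()

print(torch.isfinite(loss))                       # tensor(True)
print(torch.isfinite(student_logits.grad).all())  # tensor(False): NaN gradient
```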
Fix
---
Mirror the forward-KL server path's masking: after the ``isfinite`` checks
that guard required positions, replace the -inf sentinel with a finite
zero at all known padding positions (``labels == -100``) via
``torch.where``. The label mask in ``_reduce_divergence_loss`` still
excludes those positions from the final loss; the new neutralisation
prevents their -inf values from propagating through ``_add_tail_bucket``
and ``_jsd_divergence`` into the autograd graph.
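A sketch of the masking pattern being mirrored (variable names assumed; the real change lives inside ``_compute_server_sparse_top_1_divergence_loss``):

```python
import torch

# Assumed inputs: rectangular teacher logprobs with the -inf sentinel at
# padding, and HF-convention labels where -100 marks ignored positions.
teacher_logprobs = torch.tensor([[-0.1, -0.2, float("-inf")]])
labels = torch.tensor([[5, 7, -100]])

# Replace the sentinel with a finite zero at known padding positions only;
# the downstream label mask still excludes these positions from the loss.
safe_teacher_logprobs = torch.where(labels.eq(-100), 0.0, teacher_logprobs)
# tensor([[-0.1000, -0.2000,  0.0000]])
```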
Tests
-----
``tests/experimental/test_distillation_trainer.py`` is new (DistillationTrainer
had zero dedicated tests at v1.1.0):
- Sentinel contract at the server-path getter.
- The reverse-KL mask pattern produces finite forward AND backward on a
ragged batch.
- End-to-end training step under ``per_device_train_batch_size=1``,
``gradient_accumulation_steps=2``, variable completion lengths, with a
mocked ``VLLMClient``. Covers ``beta=1.0`` (reverse KL) and ``beta=0.5``
(JSD).
Reproduction pre-fix: ``grad_norm=nan`` on step 1.
Reproduction post-fix: ``grad_norm`` finite; padding positions receive
zero gradient (correctly excluded from the learning signal).
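In essence, the mid-level test asserts something like the standalone sketch below (toy tensors; the clamp-then-mask ordering is assumed from the description above, and the real test drives the trainer's own helpers):

```python
import torch

def test_masked_sentinel_keeps_forward_and_backward_finite():
    # Ragged batch: the second sample is shorter, so its tail carries the sentinel.
    teacher = torch.tensor([[-0.1, -0.2, -0.3],
                            [-0.5, float("-inf"), float("-inf")]])
    labels = torch.tensor([[3, 4, 5], [6, -100, -100]])
    student = torch.full((2, 3), -0.7, requires_grad=True)  # per-token student logprobs
    label_mask = labels.ne(-100).float()

    # The fix under test: neutralise the sentinel at known padding positions.
    safe_teacher = torch.where(labels.eq(-100), 0.0, teacher)

    # Reverse-KL style per-token term, clamped then reduced under the label mask.
    rkl = student.exp() * (student - safe_teacher)
    rkl = torch.nan_to_num(rkl, posinf=torch.finfo(rkl.dtype).max)
    loss = (rkl * label_mask).sum() / label_mask.sum()
    loss.backward()

    assert torch.isfinite(loss)
    assert torch.isfinite(student.grad).all()
```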
A parallel audit of GKDTrainer confirmed it is not vulnerable to the same
class of bug: its teacher runs in-process on a dense rectangular batch,
with no HTTP ragged-to-rectangular reassembly and no -inf sentinel in the
GKD loss path.
Refs: huggingface#5407.
Collapse the module summary, triple-line test docstrings, and the one-shot helper factories in `tests/experimental/test_distillation_trainer.py` into the repo's terse style. Functional coverage (sentinel pin, mid-level mask finite forward/backward, end-to-end train() under bs*ga>1 with ragged batches for beta=1.0 and beta=0.5) is unchanged; all 4 tests still pass.
Pull request overview
Fixes a NaN-gradient issue in DistillationTrainer’s server-backed generalized JSD / reverse-KL loss when training with variable per-sample completion lengths (teacher server pads with -inf sentinels).
Changes:
- Neutralize `-inf` teacher logprob sentinels at padding positions (`labels == -100`) in the server reverse-KL / sparse top-1 divergence path before divergence math runs.
- Clarify the sentinel/masking contract in `_get_teacher_token_logprobs_from_server` comments.
- Add dedicated tests covering the sentinel contract, masking stability (forward + backward finite), and an end-to-end `DistillationTrainer.train()` run with gradient accumulation and ragged batches using a mocked `VLLMClient`.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `trl/experimental/distillation/distillation_trainer.py` | Masks out `-inf` padding sentinels for the server reverse-KL/sparse-top1 divergence path to prevent NaN gradients. |
| `tests/experimental/test_distillation_trainer.py` | Adds unit + integration tests to reproduce and guard against the NaN-gradient regression with ragged completion lengths. |
Co-authored-by: Copilot <[email protected]>
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit d3f6a18.
Experiments showed the end-to-end regression tests were miscalibrated:

- `bs=1, ga=2` and `bs=2, ga=1` both reproduce `grad_norm=nan` when the fix is removed (because `_get_teacher_token_logprobs_from_server` emits -inf padding not only for cross-sample ragged batches but also via per-sample `completion_offsets`). Parametrize the reverse-KL test over both configs for fuller coverage.
- `beta=0.5` (JSD mixture) does not actually produce NaN without the fix in either config: `_jsd_divergence`'s mixture branch routes student gradients through `log((1-beta)*student_probs + beta*teacher_probs)`, which stays finite when `teacher_probs=0` at padding. Drop the JSD end-to-end test; it was a vacuous guard.

Unit + mid-level tests (sentinel contract, mask-keeps-forward-and-backward-finite) are unchanged.
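The `beta=0.5` point is easy to verify in isolation. A toy check of the claim (student-side KL term of the JSD only; values assumed, not the trainer's `_jsd_divergence`):

```python
import torch

beta = 0.5
teacher_probs = torch.tensor([[0.0, 1.0]])  # exp of [-inf, 0]: zero mass at padding
student_logits = torch.zeros(1, 2, requires_grad=True)
student_probs = torch.softmax(student_logits, dim=-1)

# Student gradients flow through log of the mixture, which stays strictly
# positive even where teacher_probs == 0, so nothing non-finite is created.
mixture = (1 - beta) * student_probs + beta * teacher_probs
loss = (student_probs * (student_probs.log() - mixture.log())).sum()
loss.backward()

print(torch.isfinite(loss), torch.isfinite(student_logits.grad).all())  # True, True
```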
cmpatino left a comment
Thank you for the contribution! This is a good check to prevent NaNs while training.
I left some comments before approving the merge of the PR.
- Trim padding-mask comment to two lines focused on what it prevents; the backward-autograd exposition lived in the PR description.
- Drop the explicit `zero` scalar tensor: `torch.where` broadcasts the `0.0` literal to the tensor's dtype/device (verified bit-exact equivalent in fp32/bf16/fp16).
- Mark the end-to-end `trainer.train()` test `@pytest.mark.slow` to match repo convention for heavy tests (saves ~8s per warm CI run).
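The `0.0`-literal equivalence is quick to sanity-check (standalone snippet, not from the PR):

```python
import torch

mask = torch.tensor([[False, True]])
for dtype in (torch.float32, torch.bfloat16, torch.float16):
    t = torch.tensor([[-0.5, float("-inf")]], dtype=dtype)
    explicit = torch.where(mask, torch.zeros((), dtype=dtype), t)
    literal = torch.where(mask, 0.0, t)  # scalar broadcasts to t's dtype/device
    assert torch.equal(explicit, literal)
```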

What does this PR do?
Fixes a NaN-gradient bug in `DistillationTrainer`'s server-backed reverse-KL loss when batches produce per-sample padding within the rectangular teacher logprob tensor.

Trigger: `use_teacher_server=True` + `beta == 1.0` (reverse KL) + any config where `_get_teacher_token_logprobs_from_server` emits `-inf` padding. This happens whenever per-sample `completion_offsets` differ (verified empirically with `bs=1, ga=2` and `bs=2, ga=1`). Forward loss is finite (clamped by `nan_to_num`); `grad_norm=nan` on the first optim step.

Root cause: `_get_teacher_token_logprobs_from_server` pads the rectangular `(B, T)` teacher tensor with `-inf`, both for shorter samples in cross-sample ragged batches and at per-sample head offsets (`completion_start - aligned_prompt_length`). The forward-KL server path (`_compute_server_forward_kl_loss`) tolerates this: its `p·(log_t - log_s)` form has `teacher_probs=0` at padding, so `0·(-inf)=NaN` in forward is `nan_to_num`'d to 0, and the backward partial `∂/∂safe_teacher = teacher_probs = 0` zeroes the gradient. The JSD mixture path is safe for the same reason: `log((1-β)·student_probs + β·teacher_probs)` stays finite when `teacher_probs=0`.

Reverse KL (`student_probs·(log_s - log_t)`) has no such protection: the multiplier `student_probs` is nonzero at padding, so `finite·(+inf)=+inf` survives into backward, and `grad_norm` becomes NaN via `sqrt(sum(inf²))` / clipping.

Fix: In `_compute_server_sparse_top_1_divergence_loss`, after the existing `isfinite` validation, neutralise the `-inf` sentinel at known padding positions (`labels == -100`) with a finite zero via `torch.where`, before the shared divergence helper runs. The label mask in `_reduce_divergence_loss` continues to exclude these positions from the final loss; the new block only prevents the sentinel from reaching the autograd graph.

This is the minimal surgical fix: the shared helper `_compute_sparse_top_1_divergence_loss` is untouched, the `-inf` sentinel contract (used by the existing `isfinite` validator to flag missing-required-data) is preserved, and the forward-KL / JSD paths are not affected.

Tests (`tests/experimental/test_distillation_trainer.py`, new file; the trainer had no dedicated tests):
- Sentinel contract at the server-path getter.
- `_add_tail_bucket` + `_jsd_divergence(beta=1)` post-mask produces finite forward AND backward.
- End-to-end `DistillationTrainer.train()` with a mocked `VLLMClient`, parametrized over `bs=1, ga=2` and `bs=2, ga=1`. Both fail (`grad_norm=nan`) without the fix; both pass with it.

`pytest tests/experimental/test_distillation_trainer.py -v`: 4/4 pass.

Env (`trl env`):

Before submitting
AI writing disclosure
Note
Medium Risk
Touches loss math in the teacher-server distillation path; a small masking change, but it directly affects training stability and gradients for reverse-KL on variable-length completions.
Overview
Fixes a NaN/inf gradient issue in `DistillationTrainer` when using a teacher server with reverse-KL (`beta=1`) and variable completion lengths, by zeroing out teacher `-inf` logprob padding at `labels == -100` before divergence computation.

Adds a new experimental test suite that (1) asserts the server getter preserves `-inf` sentinels for ragged batches, (2) verifies the padding-neutralisation strategy keeps forward/backward finite, and (3) runs an end-to-end mocked teacher-server training step to ensure `loss`/`grad_norm` remain finite under different batch/accumulation shapes.

Reviewed by Cursor Bugbot for commit efa22bc.