fix(distillation): reverse-KL server path NaN on variable completion length#1

Closed
k1064190 wants to merge 1 commit into main from fix/distillation-server-nan-on-variable-completion

Conversation

Owner

@k1064190 k1064190 commented Apr 19, 2026

Summary

When use_teacher_server=True with beta > 0 and per_device_train_batch_size * gradient_accumulation_steps > 1, the reverse-KL server path leaks NaN into the backward pass whenever per-sample completion lengths differ within a batch. Forward loss is finite (nan_to_num clamps it); grad_norm == nan at step 1.

This PR adds the missing defensive mask to _compute_server_sparse_top_1_divergence_loss, mirroring the masking already present in the forward-KL server path (_compute_server_forward_kl_loss). Also adds the first dedicated unit/integration tests for DistillationTrainer (the trainer had zero tests at v1.1.0).

Reproduction

Config | per_device_train_batch_size | gradient_accumulation_steps | Result
A      | 1                           | 1                           | loss=0.109, grad_norm=4.80 (OK)
B      | 1                           | 2                           | loss=0.112, grad_norm=nan  (FAIL)
C      | 2                           | 1                           | loss=0.109, grad_norm=nan  (FAIL)

Trigger = use_teacher_server=True + beta > 0 + loss_add_tail=True + bs*ga > 1 + per-sample completion lengths that differ within a batch. The official doc snippet for the server path (lines 72-87 of distillation_trainer.md) omits per_device_train_batch_size/gradient_accumulation_steps, inheriting the HF default bs=8, ga=1. A user following the doc verbatim hits the bug as soon as their dataset has variable-length completions.
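
For reference, a hedged sketch of a config that hits the trigger. The field names use_teacher_server, beta, and loss_add_tail are the ones discussed in this report; the class name DistillationConfig and its import path are assumptions, not verified against the source:

# Hedged sketch only: DistillationConfig and the import path are assumed.
from trl.experimental.distillation import DistillationConfig  # assumed import path

config = DistillationConfig(
    output_dir="out",
    use_teacher_server=True,         # teacher logprobs fetched from the server path
    beta=1.0,                        # > 0 selects the reverse-KL / JSD path
    loss_add_tail=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,   # bs * ga > 1, so ragged batches are possible
)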

Root cause

_get_teacher_token_logprobs_from_server fills the rectangular (B, T) output tensor with the TRL house sentinel float("-inf") at intra-batch padding positions (the tail of shorter samples).
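
Illustration of the shape contract (values made up; only the -inf tail-padding pattern matters): for completion lengths [3, 1] the getter returns a (2, 3) tensor whose shorter row is tail-padded with the sentinel.

import torch

teacher_token_lps = torch.tensor([
    [-0.11, -0.42, -0.07],                    # sample 0: 3 real completion tokens
    [-0.93, float("-inf"), float("-inf")],    # sample 1: 1 real token + 2 padding slots
])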

Two server-path consumers of that tensor:

  • _compute_server_forward_kl_loss (beta == 0): masks the sentinel via torch.where(teacher > -inf, ..., -inf) and threads a support_mask through _add_tail_bucket. Safe.
  • _compute_server_sparse_top_1_divergence_loss (beta > 0): did not mask. The -inf flowed through _add_tail_bucket (producing the teacher distribution [-inf, 0]) and _jsd_divergence, yielding +inf in the forward pass (clamped by nan_to_num) but NaN in the backward pass: autograd's chain rule does not respect nan_to_num, so the gradient of the pre-clamp +inf tensor still leaks. A minimal sketch follows this list.
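
A standalone illustration of that failure mode (not the trainer code; the reverse-KL term below is a hand-rolled stand-in for the tail-bucket/JSD math):

import torch

teacher_lps = torch.tensor([[-0.5, float("-inf")]])              # second slot = padding sentinel
student_lps = torch.tensor([[-0.7, -0.9]], requires_grad=True)

# Stand-in for the reverse-KL term: p_student * (log p_student - log p_teacher).
# At the padded slot this is finite * (finite - (-inf)) = +inf.
per_token = student_lps.exp() * (student_lps - teacher_lps)

loss = torch.nan_to_num(per_token).sum()    # forward: the +inf is clamped to finfo.max
loss.backward()

print(torch.isfinite(loss))                 # True  -- the forward value looks fine
print(student_lps.grad)                     # second entry is non-finite (NaN): the bug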

git blame shows both code paths landed in the original DistillationTrainer PR (huggingface#5407, commit c475b979, same author, same moment). The asymmetric masking is an oversight, not a deliberate design choice — there is no comment explaining why the reverse-KL path would be safe without the mask, and the generic warning at lines 1285-1287 about -inf → NaN applies identically to both paths.

Fix

After the existing isfinite validation at the start of _compute_server_sparse_top_1_divergence_loss, replace the -inf sentinel at known padding positions (labels == -100) with a finite zero via torch.where, before the shared _compute_sparse_top_1_divergence_loss helper runs. The label mask in _reduce_divergence_loss still excludes those positions from the final loss; the new block only prevents the sentinel from reaching the autograd graph through the tail-bucket/JSD math.

# `required` marks real completion positions (labels != -100); everything else is
# intra-batch padding that carries the -inf sentinel from the server getter.
pad_mask_2d = ~required
pad_mask_3d = pad_mask_2d.unsqueeze(-1)
# Neutralise the sentinel with a finite zero so no -inf/+inf intermediate enters
# the tail-bucket/JSD math (and hence the autograd graph).
zero = torch.zeros((), dtype=topk_teacher_lps.dtype, device=topk_teacher_lps.device)
topk_teacher_lps = torch.where(pad_mask_3d, zero, topk_teacher_lps)
actual_teacher_lps = torch.where(pad_mask_2d, zero, actual_teacher_lps)

The -inf sentinel at the getter is intentionally preserved so that the existing isfinite validator at _compute_server_sparse_top_1_divergence_loss:1338-1360 can continue to distinguish "missing required data" (legitimate -inf in required positions; raises) from "padding" (-inf at ~required; neutralised here).
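
The intended contract, reduced to a hedged standalone helper (neutralise_padding_sentinel is a hypothetical name for illustration; the real logic lives inline in the trainer, and this 2D version elides the top-k dimension):

import torch

def neutralise_padding_sentinel(teacher_lps, labels):
    required = labels != -100
    # -inf at a *required* position means the server failed to return data: raise.
    if not torch.isfinite(teacher_lps[required]).all():
        raise ValueError("teacher server returned -inf for a required completion position")
    # -inf at a *padding* position is the expected sentinel: neutralise it to 0
    # so it never reaches the tail-bucket/JSD math or the autograd graph.
    zero = torch.zeros((), dtype=teacher_lps.dtype, device=teacher_lps.device)
    return torch.where(required, teacher_lps, zero)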

Tests

New file tests/experimental/test_distillation_trainer.py (no prior tests existed):

  • test_server_logprobs_variable_lengths_place_neg_inf_sentinel_at_padding — pins the sentinel contract at the server getter: real completion positions preserved, padding tail carries -inf.
  • test_reverse_kl_padding_mask_keeps_forward_and_backward_finite — direct regression: given -inf-padded teacher tensors, apply the Strategy-B mask, then verify _add_tail_bucket + _jsd_divergence(beta=1) produces finite forward AND finite backward (reduced to a standalone sketch after this list).
  • TestDistillationTrainerServerPathVariableCompletion — end-to-end: DistillationTrainer.train() with trl-internal-testing/tiny-Qwen2ForCausalLM-2.5, per_device_train_batch_size=1, gradient_accumulation_steps=2, a dataset with clearly different completion lengths, and a VLLMClient monkey-patched to return canned ragged logprobs. Two cases: beta=1.0 (reverse KL) and beta=0.5 (JSD). Both assert loss and grad_norm from the training log-history are finite (Trainer zeroes per-parameter .grad post-step, so the log-history path is the reliable regression signal).
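
The second test's core assertion, reduced to a self-contained sketch (the real test goes through the trainer's private helpers; here the divergence term is a simple reverse-KL stand-in):

import torch

def test_mask_keeps_forward_and_backward_finite():
    labels = torch.tensor([[5, -100, -100],                  # ragged completions:
                           [7,    8,    9]])                 # lengths 1 and 3
    sentinel = torch.tensor(float("-inf"))
    teacher = torch.where(labels == -100, sentinel, torch.tensor(-0.5))
    student = torch.full(labels.shape, -0.7, requires_grad=True)

    required = labels != -100
    zero = torch.zeros((), dtype=teacher.dtype)
    teacher = torch.where(required, teacher, zero)           # the fix's mask

    per_token = student.exp() * (student - teacher)          # reverse-KL stand-in
    loss = (per_token * required).sum() / required.sum()
    loss.backward()

    assert torch.isfinite(loss)
    assert torch.isfinite(student.grad).all()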

Empirical before/after

Same ragged-batch input; only the mask differs:

[WITHOUT mask (pre-fix)]
  _add_tail_bucket finite? False
  forward finite? True            (saved by nan_to_num)
  backward finite? False          <-- BUG
  grad[sample_0] = [0.37, -0.37, nan, nan, nan, nan]

[WITH Strategy B mask (post-fix)]
  _add_tail_bucket finite? True
  forward finite? True
  backward finite? True
  grad[sample_0] = [0.63, -0.63, -3.19, 3.19, -4.00, 4.00]

Padding positions (indices 1, 2 of sample 0) are correctly zero-grad post-fix — they remain excluded from the learning signal, just without NaN contamination.

Verified against TRL-spec env

Built a fresh env matching TRL v1.1.0's declared pyproject.toml ranges exactly (and not a superset), installed the editable fix branch, and ran pytest:

Package           | Installed
python            | 3.11.15
torch             | 2.10.0+cu130
vllm              | 0.17.1 (TRL pyproject upper bound)
transformers      | 4.57.3 (vLLM 0.17.1 requires <5)
xgrammar          | 0.1.29 (vLLM 0.17.1 transitive)
flashinfer-python | 0.6.4 (vLLM 0.17.1 transitive)
accelerate        | 1.13.0
datasets          | 4.8.4
pytest            | 9.0.3

Result:

$ pytest tests/experimental/test_distillation_trainer.py -v

test_server_logprobs_variable_lengths_place_neg_inf_sentinel_at_padding   PASSED
test_reverse_kl_padding_mask_keeps_forward_and_backward_finite            PASSED
test_reverse_kl_finite_grad_under_ga2_with_ragged_batch                   PASSED
test_jsd_finite_grad_under_ga2_with_ragged_batch                          PASSED

============================ 4 passed in 14.38s ============================

No TRL-declared version was relaxed. The fix is a pure-Python change that does not touch the vLLM wire format or server contract.

Scope note: GKDTrainer is not affected

I audited the sibling trl.experimental.gkd.GKDTrainer in parallel. It is not vulnerable to the same class of bug:

  • No HTTP teacher path (teacher runs in-process at gkd_trainer.py:368).
  • No ragged-to-rectangular reassembly. Teacher logits come from a dense forward on the same padded (B, T) batch as the student.
  • No -inf sentinel anywhere in the GKD loss path (grep -c 'float("-inf")' trl/experimental/gkd/gkd_trainer.py returns 0).
  • Masking uses boolean indexing (jsd[mask] with mask = labels != -100) rather than torch.where(..., -inf, ...), so even if a -inf ever appeared it would be filtered before the reduction (illustrated in the short sketch below).
  • CI already exercises bs=2 in test_gkd_trainer.py:217 — the padded-batch path is tested for GKD but was entirely uncovered for DistillationTrainer.

The NaN-leak pattern is isolated to DistillationTrainer's server path.
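
Illustration only (not GKD's actual code): boolean-index masking drops a non-finite entry before any reduction sees it.

import torch

jsd = torch.tensor([0.3, float("-inf"), 0.2])    # pretend one position went bad
mask = torch.tensor([True, False, True])         # labels != -100

print(jsd[mask].mean())                          # tensor(0.2500) -- the -inf never participates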

Test plan for reviewer

  • pytest tests/experimental/test_distillation_trainer.py -v passes on a clean checkout of this branch.
  • pytest tests/experimental/test_distillation_trainer.py -v fails on main (regression is caught).
  • Training smoke against a real trl vllm-serve teacher with variable-length dataset at bs=1, ga=2, beta=1.0 no longer produces grad_norm=nan.

Refs

@k1064190 k1064190 marked this pull request as ready for review April 19, 2026 10:28
Copilot AI review requested due to automatic review settings April 19, 2026 10:28

Copilot AI left a comment


Pull request overview

Fixes a NaN-gradient issue in DistillationTrainer’s server-backed reverse-KL/JSD loss when batches contain variable completion lengths, and adds initial test coverage for the trainer’s server path.

Changes:

  • Neutralize -inf teacher logprob sentinels at ignore positions (labels == -100) in _compute_server_sparse_top_1_divergence_loss to prevent +inf/NaN propagation through tail-bucket + divergence math.
  • Expand inline documentation around the -inf sentinel contract in _get_teacher_token_logprobs_from_server.
  • Add new unit + functional tests covering server logprob padding semantics and the ragged-batch reverse-KL/JSD regression.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File                                                   | Description
trl/experimental/distillation/distillation_trainer.py | Adds padding-position masking for the server-backed sparse top-1 divergence loss to prevent NaN gradients on ragged batches.
tests/experimental/test_distillation_trainer.py       | Introduces the first DistillationTrainer tests, including regression coverage for variable-length server batches.


Comment on lines +295 to +301
trainer.train()

return [
(name, param.grad.detach().clone())
for name, param in trainer.model.named_parameters()
if param.grad is not None
]

Copilot AI Apr 19, 2026


The end-to-end tests read param.grad after trainer.train() completes. With the underlying transformers.Trainer loop, gradients are typically cleared via zero_grad() after each optimizer step, so at the end of training most/all param.grad entries can be None even when backward was finite. This can make the test fail spuriously or not actually assert the intended regression. Consider capturing grads before they are cleared (e.g., via a callback hook around the optimizer step / end of accumulation, or by manually running 2 microbatches with compute_loss + backward and checking grads before the optimizer step/zeroing).
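
For illustration, one way to implement the suggested capture (a hedged sketch; the on_pre_optimizer_step hook is assumed to be available in the installed transformers version):

import torch
from transformers import TrainerCallback

class GradFinitenessRecorder(TrainerCallback):
    """Record, just before each optimizer step, whether all gradients are finite."""

    def __init__(self):
        self.steps_all_finite = []

    def on_pre_optimizer_step(self, args, state, control, model=None, **kwargs):
        finite = all(
            bool(torch.isfinite(p.grad).all())
            for p in model.parameters()
            if p.grad is not None
        )
        self.steps_all_finite.append(finite)

# usage: pass callbacks=[GradFinitenessRecorder()] to the trainer,
# then assert all(recorder.steps_all_finite) after train().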



# ---------------------------------------------------------------------------
# T1 — Unit: _get_teacher_token_logprobs_from_server padding must be finite.

Copilot AI Apr 19, 2026


The section header says “padding must be finite”, but this test explicitly asserts that padding positions contain the -inf sentinel. Updating the header/comment to match the asserted contract would avoid confusion for future readers.

Suggested change
# T1 — Unit: _get_teacher_token_logprobs_from_server padding must be finite.
# T1 — Unit: _get_teacher_token_logprobs_from_server padding uses `-inf` sentinel.

@k1064190 k1064190 force-pushed the fix/distillation-server-nan-on-variable-completion branch from aeb5aa8 to 52252d2 on April 19, 2026 at 11:32
…length

When ``use_teacher_server=True`` with ``beta > 0`` and ``bs * grad_accum > 1``,
the reverse-KL server path leaked NaN into the backward pass whenever
per-sample completion lengths differed within a batch.

Root cause
----------
``_get_teacher_token_logprobs_from_server`` fills the rectangular (B, T)
output tensor with the TRL house sentinel ``float("-inf")`` at intra-batch
padding positions (the tail of shorter samples). The forward-KL server path
(``_compute_server_forward_kl_loss``) neutralises this via
``torch.where(teacher > -inf, ..., -inf)`` plus a support mask threaded
through ``_add_tail_bucket``; the reverse-KL server path
(``_compute_server_sparse_top_1_divergence_loss``) did not. Both paths
landed in the same commit (huggingface#5407) -- an oversight, not deliberate
asymmetry.

Unmasked, the -inf sentinel produces a teacher distribution [-inf, 0]
after ``_add_tail_bucket`` and +inf in ``_jsd_divergence``'s forward pass
(clamped to ``finfo.max`` by ``nan_to_num``), but NaN in the backward
pass: autograd's chain rule does not respect ``nan_to_num``, so the
pre-clamp +inf leaks through as NaN gradient.

Fix
---
Mirror the forward-KL server path's masking: after the ``isfinite`` checks
that guard required positions, replace the -inf sentinel with a finite
zero at all known padding positions (``labels == -100``) via
``torch.where``. The label mask in ``_reduce_divergence_loss`` still
excludes those positions from the final loss; the new neutralisation
prevents their -inf values from propagating through ``_add_tail_bucket``
and ``_jsd_divergence`` into the autograd graph.

Tests
-----
``tests/experimental/test_distillation_trainer.py`` is new (DistillationTrainer
had zero dedicated tests at v1.1.0):
- Sentinel contract at the server-path getter.
- The reverse-KL mask pattern produces finite forward AND backward on a
  ragged batch.
- End-to-end training step under ``per_device_train_batch_size=1``,
  ``gradient_accumulation_steps=2``, variable completion lengths, with a
  mocked ``VLLMClient``. Covers ``beta=1.0`` (reverse KL) and ``beta=0.5``
  (JSD).

Reproduction pre-fix: ``grad_norm=nan`` on step 1.
Reproduction post-fix: ``grad_norm`` finite; padding positions receive
zero gradient (correctly excluded from the learning signal).

A parallel audit of GKDTrainer confirmed it is not vulnerable to the same
class of bug: its teacher runs in-process on a dense rectangular batch,
with no HTTP ragged-to-rectangular reassembly and no -inf sentinel in the
GKD loss path.

Refs: huggingface#5407.
@k1064190
Owner Author

Closing to replace with a template-compliant PR after rebasing onto upstream/main.

@k1064190 k1064190 force-pushed the fix/distillation-server-nan-on-variable-completion branch from 52252d2 to 3c0d9ae on April 19, 2026 at 11:47
@k1064190 k1064190 closed this Apr 19, 2026