
fix(distillation): reverse-KL server path NaN on variable completion length#5594

Open
k1064190 wants to merge 6 commits into huggingface:main from k1064190:fix/distillation-server-nan-on-variable-completion

Conversation


k1064190 commented Apr 19, 2026

What does this PR do?

Fixes a NaN-gradient bug in DistillationTrainer's server-backed reverse-KL loss when batches produce per-sample padding within the rectangular teacher logprob tensor.

Trigger: use_teacher_server=True + beta == 1.0 (reverse KL) + any config where _get_teacher_token_logprobs_from_server emits -inf padding. This happens whenever per-sample completion_offsets differ (verified empirically with bs=1, ga=2 and bs=2, ga=1). Forward loss is finite (clamped by nan_to_num); grad_norm=nan on the first optim step.

Root cause: _get_teacher_token_logprobs_from_server pads the rectangular (B, T) teacher tensor with -inf — both for shorter samples in cross-sample ragged batches and at per-sample head offsets (completion_start - aligned_prompt_length). The forward-KL server path (_compute_server_forward_kl_loss) tolerates this: its p·(log_t - log_s) form has teacher_probs=0 at padding, so 0·(-inf)=NaN in forward is nan_to_num'd to 0, and backward partials ∂/∂safe_teacher = teacher_probs = 0 zero the gradient. The JSD mixture path is safe for the same reason — log((1-β)·student_probs + β·teacher_probs) stays finite when teacher_probs=0.

Reverse-KL (student_probs·(log_s - log_t)) has no such protection: the multiplier student_probs is nonzero at padding, so finite·(+inf)=+inf reaches the forward pass (clamped by nan_to_num), and in backward the pre-clamp +inf leaks through as a NaN gradient (the clamped position's zeroed grad multiplies an infinite local partial, 0·inf=NaN), so grad_norm goes NaN.
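The asymmetry reproduces outside the trainer. A self-contained sketch with toy scalars (`s` stands in for a single student token logprob at a padding position; these are not the trainer's real tensors):

```python
import torch

s = torch.tensor([0.3], requires_grad=True)  # student token logprob at a padding slot
t_log = torch.tensor([float("-inf")])        # teacher -inf padding sentinel
t_prob = t_log.exp()                         # exp(-inf) == 0

# Forward-KL form: the multiplier is the teacher prob, which is 0 at padding.
fwd = torch.nan_to_num(t_prob * (t_log - s)).sum()  # 0 * -inf = nan, clamped to 0
fwd.backward()
print(s.grad)  # zero gradient: padding contributes nothing

s.grad = None

# Reverse-KL form: the multiplier is the student prob, which is nonzero.
rev = torch.nan_to_num(s.exp() * (s - t_log)).sum()  # finite * +inf = +inf, clamped
rev.backward()
print(s.grad)  # tensor([nan]): the pre-clamp +inf poisons the gradient
```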

Fix: In _compute_server_sparse_top_1_divergence_loss, after the existing isfinite validation, neutralise the -inf sentinel at known padding positions (labels == -100) with a finite zero via torch.where, before the shared divergence helper runs. The label mask in _reduce_divergence_loss continues to exclude these positions from the final loss — the new block only prevents the sentinel from reaching the autograd graph.

This is the minimal surgical fix: the shared helper _compute_sparse_top_1_divergence_loss is untouched, the -inf sentinel contract (used by the existing isfinite validator to flag missing-required-data) is preserved, and forward-KL / JSD paths are not affected.
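For illustration, the masking step reduces to a single torch.where. A minimal sketch with toy tensors (names and values are hypothetical; only the -inf/-100 contract is from the PR):

```python
import torch

# Rectangular (B, T) teacher logprobs with -inf sentinels at ragged padding,
# and labels carrying -100 at the same positions (toy values).
teacher_logprobs = torch.tensor([[-1.2, -0.7, float("-inf")],
                                 [-0.9, float("-inf"), float("-inf")]])
labels = torch.tensor([[11, 42, -100],
                       [7, -100, -100]])

# Neutralise the sentinel where the label mask already excludes the position
# from the loss, so -inf never enters the autograd graph.
safe_teacher = torch.where(labels == -100, 0.0, teacher_logprobs)
assert torch.isfinite(safe_teacher).all()
```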

Tests (tests/experimental/test_distillation_trainer.py, new file — trainer had no dedicated tests):

  • Sentinel contract at the server getter.
  • Mask pattern in isolation: _add_tail_bucket + _jsd_divergence(beta=1) post-mask produces finite forward AND backward.
  • End-to-end DistillationTrainer.train() with mocked VLLMClient, parametrized over bs=1, ga=2 and bs=2, ga=1. Both fail (grad_norm=nan) without the fix; both pass with it.

pytest tests/experimental/test_distillation_trainer.py -v: 4/4 pass.
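For reference, the mid-level check boils down to something like the following hedged sketch (toy tensors and a simplified reverse-KL integrand, not the committed test):

```python
import torch

def test_masked_sentinel_keeps_backward_finite():
    # Ragged batch: sample 2's tail is -inf sentinel with labels == -100.
    teacher = torch.tensor([[-1.0, -2.0, -3.0],
                            [-1.0, float("-inf"), float("-inf")]])
    labels = torch.tensor([[5, 7, 9], [5, -100, -100]])
    student = torch.randn(2, 3, requires_grad=True)

    safe_teacher = torch.where(labels == -100, 0.0, teacher)
    mask = (labels != -100).float()
    per_token = student.exp() * (student - safe_teacher)  # reverse-KL integrand
    loss = (per_token * mask).sum() / mask.sum()
    loss.backward()
    assert torch.isfinite(loss) and torch.isfinite(student.grad).all()
```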

Env (trl env):

- Platform: Linux-5.14.0-427.22.1.el9_4.x86_64-x86_64-with-glibc2.35
- Python version: 3.11.15
- TRL version: 1.3.0.dev0+3c0d9ae
- PyTorch version: 2.10.0+cu130
- accelerator(s): NVIDIA RTX PRO 6000 Blackwell Server Edition x3
- Transformers version: 4.57.3
- Accelerate version: 1.13.0
- Datasets version: 4.8.4
- HF Hub version: 0.36.2
- bitsandbytes version: 0.49.2
- DeepSpeed version: 0.18.9
- Liger-Kernel version: 0.7.0
- PEFT version: 0.19.1
- vLLM version: 0.17.1

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Note

Medium Risk
Touches loss math in the teacher-server distillation path; a small masking change, but it directly affects training stability and gradients for reverse-KL on variable-length completions.

Overview
Fixes a NaN/inf gradient issue in DistillationTrainer when using a teacher server with reverse-KL (beta=1) and variable completion lengths by zeroing out teacher -inf logprob padding at labels == -100 before divergence computation.

Adds a new experimental test suite that (1) asserts the server getter preserves -inf sentinels for ragged batches, (2) verifies the padding-neutralisation strategy keeps forward/backward finite, and (3) runs an end-to-end mocked teacher-server training step to ensure loss/grad_norm remain finite under different batch/accumulation shapes.

Reviewed by Cursor Bugbot for commit efa22bc.

fix(distillation): reverse-KL server path NaN on variable completion length

When ``use_teacher_server=True`` with ``beta > 0`` and ``bs * grad_accum > 1``,
the reverse-KL server path leaked NaN into the backward pass whenever
per-sample completion lengths differed within a batch.

Root cause
----------
``_get_teacher_token_logprobs_from_server`` fills the rectangular (B, T)
output tensor with the TRL house sentinel ``float("-inf")`` at intra-batch
padding positions (the tail of shorter samples). The forward-KL server path
(``_compute_server_forward_kl_loss``) neutralises this via
``torch.where(teacher > -inf, ..., -inf)`` plus a support mask threaded
through ``_add_tail_bucket``; the reverse-KL server path
(``_compute_server_sparse_top_1_divergence_loss``) did not. Both paths
landed in the same commit (huggingface#5407) -- an oversight, not deliberate
asymmetry.

Unmasked, the -inf sentinel produces a teacher distribution [-inf, 0]
after ``_add_tail_bucket`` and +inf in ``_jsd_divergence``'s forward pass
(clamped to ``finfo.max`` by ``nan_to_num``), but NaN in the backward
pass: autograd's chain rule does not respect ``nan_to_num``, so the
pre-clamp +inf leaks through as NaN gradient.
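Stripped to the mechanism, the leak needs only a few lines of standalone torch (not trainer code):

```python
import torch

x = torch.tensor([1.0], requires_grad=True)
inf_term = x * torch.tensor([float("inf")])  # +inf in the forward pass
torch.nan_to_num(inf_term).sum().backward()  # forward is clamped to finfo.max
print(x.grad)  # tensor([nan])
```

``nan_to_num``'s backward zeroes the incoming gradient at the non-finite position, but that zero then multiplies the infinite local derivative of the product, and 0 * inf = NaN.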

Fix
---
Mirror the forward-KL server path's masking: after the ``isfinite`` checks
that guard required positions, replace the -inf sentinel with a finite
zero at all known padding positions (``labels == -100``) via
``torch.where``. The label mask in ``_reduce_divergence_loss`` still
excludes those positions from the final loss; the new neutralisation
prevents their -inf values from propagating through ``_add_tail_bucket``
and ``_jsd_divergence`` into the autograd graph.

Tests
-----
``tests/experimental/test_distillation_trainer.py`` is new (DistillationTrainer
had zero dedicated tests at v1.1.0):
- Sentinel contract at the server-path getter.
- The reverse-KL mask pattern produces finite forward AND backward on a
  ragged batch.
- End-to-end training step under ``per_device_train_batch_size=1``,
  ``gradient_accumulation_steps=2``, variable completion lengths, with a
  mocked ``VLLMClient``. Covers ``beta=1.0`` (reverse KL) and ``beta=0.5``
  (JSD).

Reproduction pre-fix: ``grad_norm=nan`` on step 1.
Reproduction post-fix: ``grad_norm`` finite; padding positions receive
zero gradient (correctly excluded from the learning signal).

A parallel audit of GKDTrainer confirmed it is not vulnerable to the same
class of bug: its teacher runs in-process on a dense rectangular batch,
with no HTTP ragged-to-rectangular reassembly and no -inf sentinel in the
GKD loss path.

Refs: huggingface#5407.

Collapse the module summary, triple-line test docstrings, and the one-shot
helper factories in `tests/experimental/test_distillation_trainer.py` into
the repo's terse style. Functional coverage (sentinel pin, mid-level mask
finite forward/backward, end-to-end train() under bs*ga>1 with ragged
batches for beta=1.0 and beta=0.5) is unchanged; all 4 tests still pass.
Copilot AI review requested due to automatic review settings April 19, 2026 12:45
Contributor

Copilot AI left a comment


Pull request overview

Fixes a NaN-gradient issue in DistillationTrainer’s server-backed generalized JSD / reverse-KL loss when training with variable per-sample completion lengths (teacher server pads with -inf sentinels).

Changes:

  • Neutralize -inf teacher logprob sentinels at padding positions (labels == -100) in the server reverse-KL / sparse top-1 divergence path before divergence math runs.
  • Clarify the sentinel/masking contract in _get_teacher_token_logprobs_from_server comments.
  • Add dedicated tests covering the sentinel contract, masking stability (forward + backward finite), and an end-to-end DistillationTrainer.train() run with gradient accumulation and ragged batches using a mocked VLLMClient.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File | Description
trl/experimental/distillation/distillation_trainer.py | Masks out -inf padding sentinels for the server reverse-KL/sparse-top-1 divergence path to prevent NaN gradients.
tests/experimental/test_distillation_trainer.py | Adds unit + integration tests to reproduce and guard against the NaN-gradient regression with ragged completion lengths.


Comment thread tests/experimental/test_distillation_trainer.py Outdated

cursor Bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

Reviewed by Cursor Bugbot for commit d3f6a18.

Comment thread tests/experimental/test_distillation_trainer.py Outdated
Experiments showed the end-to-end regression tests were miscalibrated:

- `bs=1, ga=2` and `bs=2, ga=1` both reproduce `grad_norm=nan` when the
  fix is removed (because `_get_teacher_token_logprobs_from_server`
  emits -inf padding not only for cross-sample ragged batches but also
  via per-sample `completion_offsets`). Parametrize the reverse-KL test
  over both configs for fuller coverage.
- `beta=0.5` (JSD mixture) does not actually produce NaN without the
  fix in either config: `_jsd_divergence`'s mixture branch routes
  student gradients through `log((1-beta)*student_probs + beta*teacher_probs)`,
  which stays finite when teacher_probs=0 at padding (see the sketch
  below). Drop the JSD end-to-end test — it was a vacuous guard.

Unit + mid-level tests (sentinel contract, mask-keeps-forward-and-
backward-finite) are unchanged.
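The mixture-branch finiteness claim checks out in isolation (toy values, not ``_jsd_divergence`` itself):

```python
import torch

beta = 0.5
s = torch.tensor([-0.3], requires_grad=True)         # student logprob at padding
student_probs = s.exp()
teacher_probs = torch.tensor([float("-inf")]).exp()  # exp(-inf) == 0

mixture_log = ((1 - beta) * student_probs + beta * teacher_probs).log()
loss = (student_probs * (s - mixture_log)).sum()     # KL(student || mixture) term
loss.backward()
print(loss.isfinite().item(), s.grad)                # True, finite gradient
```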
cmpatino (Collaborator) left a comment


Thank you for the contribution! This is a good check to prevent NaNs while training.

I left some comments to address before approving the merge of the PR.

Comment thread trl/experimental/distillation/distillation_trainer.py Outdated
Comment thread trl/experimental/distillation/distillation_trainer.py Outdated
Comment thread tests/experimental/test_distillation_trainer.py
- Trim padding-mask comment to two lines focused on what it prevents;
  the backward-autograd exposition lived in the PR description.
- Drop the explicit `zero` scalar tensor — `torch.where` broadcasts
  the `0.0` literal to the tensor's dtype/device (verified bit-exact
  equivalent in fp32/bf16/fp16; sanity-checked below).
- Mark the end-to-end `trainer.train()` test `@pytest.mark.slow` to
  match repo convention for heavy tests (saves ~8s per warm CI run).
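The broadcasting claim is easy to sanity-check standalone:

```python
import torch

for dtype in (torch.float32, torch.bfloat16, torch.float16):
    t = torch.full((2, 2), float("-inf"), dtype=dtype)
    mask = torch.tensor([[True, False], [False, True]])
    out = torch.where(mask, 0.0, t)  # Python 0.0 broadcasts to t's dtype/device
    assert out.dtype == dtype and out[mask].eq(0).all()
```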
k1064190 requested a review from cmpatino April 22, 2026 08:43
