Fix spurious KL gradients for zero-std reward groups in GRPOTrainer #5640

robrui wants to merge 2 commits into huggingface:main
Conversation
When all completions in a group receive the same reward, the group std is zero and the advantages are zero. While the policy loss correctly becomes zero (advantage * log_prob = 0), the KL penalty term (beta * per_token_kl) still produces non-zero gradients through the model's log-probabilities. These spurious KL gradients pull the model toward the reference policy without any reward signal guiding the direction.

This fix zeros out the completion mask for groups with zero reward std, eliminating both the (already-zero) policy loss and the spurious KL gradients. This is consistent with the approach in DAPO (Section 3.2, "On-Policy Clipped Objective with Zero Variance Masking"), https://arxiv.org/abs/2505.22257.

Fixes huggingface#5588
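For intuition, here is a minimal self-contained sketch of the failure mode and the fix. The shapes, variable names, and the simplified loss below are illustrative, not the trainer's actual code:

```python
import torch

# Toy setup (names and shapes are illustrative, not TRL's actual code):
# 2 prompt groups x 2 completions each, 3 tokens per completion.
num_generations = 2
per_token_logps = torch.randn(4, 3, requires_grad=True)
ref_per_token_logps = per_token_logps.detach() + 0.1
rewards = torch.tensor([1.0, 1.0, 0.0, 1.0])  # group 0: identical rewards, std = 0
beta = 0.04

grouped = rewards.view(-1, num_generations)
std_rewards = grouped.std(dim=1).repeat_interleave(num_generations)
mean_rewards = grouped.mean(dim=1).repeat_interleave(num_generations)
advantages = (rewards - mean_rewards) / (std_rewards + 1e-4)  # zeros for group 0

# k3 KL estimator, as used in GRPO-style objectives
per_token_kl = (
    torch.exp(ref_per_token_logps - per_token_logps)
    - (ref_per_token_logps - per_token_logps)
    - 1
)

completion_mask = torch.ones(4, 3)
# Without the fix: advantages are 0 for group 0, so the policy term vanishes,
# but beta * per_token_kl still sends gradients through rows 0-1.

# The fix: zero the mask for completions belonging to zero-std groups.
is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))
completion_mask = completion_mask * (~is_std_zero).unsqueeze(1).int()

per_token_loss = -(advantages.unsqueeze(1) * per_token_logps - beta * per_token_kl)
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1)
loss.backward()
print(per_token_logps.grad[:2].abs().sum())  # tensor(0.) -- no spurious KL gradient
```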
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit cfc4e54.
```python
# the completion mask for these groups eliminates the spurious KL gradients.
# See https://huggingface.co/papers/2505.22257 (Section 3.2).
is_std_zero_local = is_std_zero[process_slice]
completion_mask = completion_mask * (~is_std_zero_local).unsqueeze(1).int()
```
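Note that applying the exclusion through `completion_mask` after the per-process slice, rather than filtering the batch, keeps tensor shapes stable across data-parallel workers, so downstream reductions and the existing reward/std logging paths are untouched.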
Zero-std masking blanks all eval completions when num_generations_eval=1
Medium Severity
When num_generations_eval == 1, std_rewards is unconditionally set to all zeros (line 2134–2135), causing is_std_zero to be True for every sample. The new masking at line 2184 then zeros out completion_mask for all eval completions, making the eval loss trivially zero and silencing all eval-time metrics (KL, entropy, clip ratio). The is_std_zero variable was originally designed only "for logging", and its all-True value in this edge case was harmless; now that it drives masking, it incorrectly discards every eval sample. Since the fix targets spurious gradients, which don't exist during eval, the mask is both unnecessary and harmful in eval mode.
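One possible remedy, sketched below under the assumption that train/eval state is visible at this point via the standard `self.model.training` flag (the actual guard the maintainers choose may differ), is to apply the masking only during training:

```python
# Sketch: skip the zero-std masking outside training, where there are no
# gradients to suppress. `self.model.training` is an assumed way to check mode.
if self.model.training:
    is_std_zero_local = is_std_zero[process_slice]
    completion_mask = completion_mask * (~is_std_zero_local).unsqueeze(1).int()
```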


What does this PR do?
Fixes #5588: In GRPOTrainer, groups with identical rewards (zero reward std) produce spurious KL gradients when beta > 0. When std = 0, advantages are 0, so the policy loss contribution is zero, but the KL penalty term still produces gradients pulling the model toward the reference policy without any reward signal.
Fix: mask out completions from groups with zero reward std by multiplying `completion_mask` with `(~is_std_zero_local).unsqueeze(1).int()`. This eliminates the spurious KL gradients while preserving the zero policy loss.

Before submitting

- This PR fixes an existing issue (Spurious KL gradients when beta > 0 in GRPO/RLOO #5588)
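As a sanity check, here is a regression-test sketch (not part of this PR; the toy loss and names are illustrative) asserting that a masked zero-std group contributes no gradient:

```python
import torch

def test_zero_std_group_produces_no_gradient():
    """Sketch only: a group whose rewards all match should contribute
    zero gradient once its completions are masked out."""
    logps = torch.randn(2, 4, requires_grad=True)
    ref_logps = logps.detach() + 0.05
    beta = 0.1
    per_token_kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1

    # Zero-std group => zero advantages and, with the fix, a zeroed mask.
    advantages = torch.zeros(2)
    completion_mask = torch.zeros(2, 4)  # what the fix produces for this group

    per_token_loss = -(advantages.unsqueeze(1) * logps - beta * per_token_kl)
    loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1)
    loss.backward()
    assert torch.all(logps.grad == 0)
```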
Note
Medium Risk
Changes GRPOTrainer loss masking so completions from zero reward-variance groups are excluded, which can alter optimization dynamics and convergence but is localized to the GRPO training loop.
Overview
Prevents spurious KL-only updates in `GRPOTrainer` when a prompt group's rewards have zero variance by masking those completions out of `completion_mask` after per-process slicing. This ensures groups with `std_rewards == 0` contribute no gradients (including the `beta * per_token_kl` term), aligning behavior with the referenced DAPO "zero variance masking" approach while preserving existing reward/std logging.