
Fix spurious KL gradients for zero-std reward groups in GRPOTrainer #5640

Open

robrui wants to merge 2 commits into huggingface:main from robrui:fix/grpo-zero-std-kl-masking

Conversation


robrui commented Apr 24, 2026

What does this PR do?

Fixes #5588: In GRPOTrainer, groups with identical rewards (zero reward std) produce spurious KL gradients when beta > 0.

When the reward std is 0, all advantages are 0, so the policy-loss contribution is zero. But the KL penalty term still produces gradients that pull the model toward the reference policy without any reward signal guiding the direction.

Fix: mask out completions from groups with zero reward std by multiplying completion_mask with (~is_std_zero_local).unsqueeze(1).int(). This eliminates the spurious KL gradients while preserving the zero policy loss.
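
For illustration, here is a minimal sketch of what that multiplication does (illustrative shapes and values, not the trainer's actual tensors): `is_std_zero_local` is a per-completion bool vector, and `unsqueeze(1)` broadcasts it across the token dimension so entire rows of `completion_mask` are zeroed.

```python
import torch

# Illustrative shapes, not the trainer's actual tensors: completion_mask is
# (num_local_completions, seq_len); is_std_zero_local marks completions whose
# group had identical rewards (zero reward std).
completion_mask = torch.ones(4, 5, dtype=torch.int64)
is_std_zero_local = torch.tensor([False, True, False, True])

# ~ flips the bools so normal groups keep 1; unsqueeze(1) broadcasts over tokens
completion_mask = completion_mask * (~is_std_zero_local).unsqueeze(1).int()
print(completion_mask)  # rows 1 and 3 are all zeros and drop out of the loss
```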

Before submitting

AI writing disclosure

  • AI-assisted (AI tools assisted with code generation; all changes reviewed and verified by a human before submission)

Note

Medium Risk
Changes GRPOTrainer loss masking so completions from zero reward-variance groups are excluded, which can alter optimization dynamics and convergence but is localized to the GRPO training loop.

Overview
Prevents spurious KL-only updates in GRPOTrainer when a prompt group’s rewards have zero variance by masking those completions out of completion_mask after per-process slicing.

This ensures groups with std_rewards == 0 contribute no gradients (including the beta * per_token_kl term), aligning behavior with the referenced DAPO “zero variance masking” approach while preserving existing reward/std logging.

Reviewed by Cursor Bugbot for commit cfc4e54.

When all completions in a group receive the same reward, the group std is
zero and advantages are zero. While the policy loss correctly becomes zero
(advantage * log_prob = 0), the KL penalty term (beta * per_token_kl)
still produces non-zero gradients through the model's log-probabilities.
These spurious KL gradients pull the model toward the reference policy
without any reward signal guiding the direction.

This fix zeros out the completion mask for groups with zero reward std,
eliminating both the (already-zero) policy loss and the spurious KL
gradients. This is consistent with the approach in DAPO (Section 3.2,
"On-Policy Clipped Objective with Zero Variance Masking"),
https://arxiv.org/abs/2505.22257.

Fixes huggingface#5588
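
A minimal repro of this claim (a sketch, not TRL code: it assumes a k3-style KL estimator, exp(ref - logp) - (ref - logp) - 1, as a stand-in for per_token_kl, and a simplified REINFORCE-style policy term in place of the clipped objective): with zero advantages the policy term contributes no gradient, but the KL penalty still does.

```python
import torch

logps = torch.tensor([-1.2, -0.8, -1.5], requires_grad=True)  # policy log-probs
ref_logps = torch.tensor([-1.0, -1.0, -1.0])                  # reference log-probs
advantages = torch.zeros(3)  # identical rewards -> zero std -> zero advantages
beta = 0.04

# k3-style KL estimator (assumption, standing in for per_token_kl)
per_token_kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1
loss = (-advantages * logps + beta * per_token_kl).mean()
loss.backward()
print(logps.grad)  # non-zero wherever logps != ref_logps: the spurious pull
```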
Comment thread: trl/trainer/grpo_trainer.py (Outdated)

cursor (bot) left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.


```python
# Groups with zero reward std have zero advantages, so the policy loss is
# already zero, but the KL penalty would still produce gradients. Zeroing
# the completion mask for these groups eliminates the spurious KL gradients.
# See https://huggingface.co/papers/2505.22257 (Section 3.2).
is_std_zero_local = is_std_zero[process_slice]
completion_mask = completion_mask * (~is_std_zero_local).unsqueeze(1).int()
```
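
As a quick sanity check (illustrative, with a stand-in per-token loss rather than the real GRPO objective), zeroed mask rows stop gradients entirely:

```python
import torch

logps = torch.randn(2, 3, requires_grad=True)
per_token_loss = logps ** 2                 # stand-in for the real per-token loss
completion_mask = torch.tensor([[1, 1, 1],  # normal group: tokens kept
                                [0, 0, 0]]) # zero-std group: masked out
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1)
loss.backward()
print(logps.grad[1])  # tensor([0., 0., 0.]): no gradient reaches the masked row
```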

Zero-std masking blanks all eval completions when num_generations_eval=1

Medium Severity

When num_generations_eval == 1, std_rewards is unconditionally set to all zeros (lines 2134–2135), so is_std_zero is True for every sample. The new masking at line 2184 then zeros out completion_mask for all eval completions, making the eval loss trivially zero and silencing all eval-time metrics (KL, entropy, clip ratio). The is_std_zero variable was originally designed only "for logging", and its all-True value in this edge case was harmless; now that it drives masking, it incorrectly discards every eval sample. Since the fix targets spurious gradients, which do not exist during eval, the mask is both unnecessary and harmful in eval mode.

Additional Locations (1)
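
One possible shape of a fix (a hypothetical sketch, not code from this PR; the helper name and the training flag plumbing are assumptions) is to apply the masking only in training mode, where the spurious gradients actually exist:

```python
import torch

def apply_zero_std_mask(completion_mask: torch.Tensor,
                        is_std_zero_local: torch.Tensor,
                        training: bool) -> torch.Tensor:
    # Hypothetical guard: during eval, num_generations_eval == 1 makes every
    # group "zero std", and there are no gradients to suppress anyway.
    if not training:
        return completion_mask
    return completion_mask * (~is_std_zero_local).unsqueeze(1).int()
```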



Development

Successfully merging this pull request may close these issues.

[Bug] Zero-std reward groups produce spurious KL gradients when beta > 0 in GRPO/RLOO
