trl/trainer/grpo_trainer.py (9 additions, 0 deletions)
@@ -2174,6 +2174,15 @@ def _generate_and_score_completions(
 all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
 advantages = advantages[process_slice]
+
+# Mask out completions from groups with zero reward std (all rewards identical).
+# When std=0, advantages are 0, so the policy loss contribution is already zero.
+# However, the KL penalty term (beta * per_token_kl) still produces gradients that
+# pull the model toward the reference policy without any reward signal. Zeroing
+# the completion mask for these groups eliminates the spurious KL gradients.
+# See https://huggingface.co/papers/2505.22257 (Section 3.2).
+is_std_zero_local = is_std_zero[process_slice]
+completion_mask = completion_mask * (~is_std_zero_local).unsqueeze(1).int()
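For intuition, here is a standalone toy run of the same masking arithmetic (illustrative only; shapes and values are made up, not trainer state):

```python
import torch

# Toy illustration of the masking above: completions from zero-std groups get
# their whole token mask zeroed, so they contribute neither policy-loss nor
# KL gradients.
completion_mask = torch.ones(4, 3, dtype=torch.int)           # 4 completions x 3 tokens
is_std_zero_local = torch.tensor([False, True, False, True])  # groups 1 and 3: identical rewards
completion_mask = completion_mask * (~is_std_zero_local).unsqueeze(1).int()
print(completion_mask)
# tensor([[1, 1, 1],
#         [0, 0, 0],
#         [1, 1, 1],
#         [0, 0, 0]], dtype=torch.int32)
```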
Zero-std masking blanks all eval completions when num_generations_eval=1

Medium Severity

When num_generations_eval == 1, std_rewards is unconditionally set to all zeros (lines 2134–2135), causing is_std_zero to be True for every sample. The new masking at line 2184 then zeros out completion_mask for every eval completion, making the eval loss trivially zero and silencing all eval-time metrics (KL, entropy, clip ratio). The is_std_zero variable was originally intended only "for logging", and its all-True value in this edge case was harmless; now that it drives masking, it incorrectly discards every eval sample. Since the fix targets spurious gradients, which do not exist during eval, the mask is both unnecessary and harmful in eval mode.

Reviewed by Cursor Bugbot for commit cfc4e54.
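One possible shape of a fix, sketched under the assumption that the trainer can distinguish train from eval batches via the policy module's standard `training` flag (this guard is illustrative, not the repository's actual patch):

```python
# Sketch of one possible guard, not the repository's actual fix. Assumption:
# the spurious KL gradients only arise when a backward pass will run, so the
# zero-std mask can be restricted to training mode, and eval batches keep
# their completion_mask intact.
is_std_zero_local = is_std_zero[process_slice]
if self.model.training:  # assumed train/eval signal; may differ from TRL's real mode flag
    completion_mask = completion_mask * (~is_std_zero_local).unsqueeze(1).int()
```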


 # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
 for i, reward_func_name in enumerate(self.reward_func_names):
     mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
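A standalone toy run of the `torch.nanmean` behavior these context lines rely on (values made up):

```python
import torch

# nanmean skips NaN entries, so samples where a given reward function was not
# applied (reward recorded as NaN) do not distort that function's mean.
rewards_col = torch.tensor([1.0, float("nan"), 3.0])
print(torch.nanmean(rewards_col).item())  # 2.0
```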