diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 0f400e14fcd..caa7cf378c1 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -2174,6 +2174,15 @@ def _generate_and_score_completions(
         all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
         advantages = advantages[process_slice]
 
+        # Mask out completions from groups with zero reward std (all rewards identical).
+        # When std=0, advantages are 0, so the policy loss contribution is already zero.
+        # However, the KL penalty term (beta * per_token_kl) still produces gradients that
+        # pull the model toward the reference policy without any reward signal. Zeroing
+        # the completion mask for these groups eliminates the spurious KL gradients.
+        # See https://huggingface.co/papers/2505.22257 (Section 3.2).
+        is_std_zero_local = is_std_zero[process_slice]
+        completion_mask = completion_mask * (~is_std_zero_local).unsqueeze(1).int()
+
         # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
         for i, reward_func_name in enumerate(self.reward_func_names):
             mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
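
For context, here is a minimal standalone sketch (not part of the patch) of the masking behavior on a toy batch. The group size of 2, the toy reward values, and the recomputation of `is_std_zero` from per-group reward std are assumptions for illustration; in the trainer, `is_std_zero` is computed upstream during advantage normalization.

```python
import torch

# Toy batch: 4 completions in 2 groups of 2 (group size assumed for illustration).
num_generations = 2
rewards = torch.tensor([1.0, 1.0, 0.0, 2.0])  # group 0: identical rewards -> std 0

# Per-group reward std, as used for GRPO advantage normalization.
std_per_group = rewards.view(-1, num_generations).std(dim=1)  # tensor([0.0000, 1.4142])
is_std_zero = torch.isclose(std_per_group, torch.zeros_like(std_per_group))
is_std_zero = is_std_zero.repeat_interleave(num_generations)  # back to per-completion shape

# Completion mask: (batch, seq_len), 1 for real tokens, 0 for padding.
completion_mask = torch.ones(4, 5, dtype=torch.long)

# Zero out every token of completions whose group had zero reward std, so
# neither the policy loss nor the KL penalty produces gradients for them.
completion_mask = completion_mask * (~is_std_zero).unsqueeze(1).int()
print(completion_mask.sum(dim=1))  # tensor([0, 0, 5, 5])
```

Masking at the `completion_mask` level (rather than special-casing the KL term) keeps the change local: every downstream loss term that already respects the mask automatically ignores zero-std groups.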