From 49774f53e5cc84a083c9d7db0ce35f3407dde0c4 Mon Sep 17 00:00:00 2001
From: Jiazhen Huang
Date: Thu, 16 Apr 2026 19:27:51 +0800
Subject: [PATCH] Fix empty-target self-distillation loss to stay connected to
 model graph

---
 .../self_distillation/self_distillation_mixin.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/trl/experimental/self_distillation/self_distillation_mixin.py b/trl/experimental/self_distillation/self_distillation_mixin.py
index fb2a8808de1..2036ffaf933 100644
--- a/trl/experimental/self_distillation/self_distillation_mixin.py
+++ b/trl/experimental/self_distillation/self_distillation_mixin.py
@@ -102,11 +102,6 @@ def _compute_self_distillation_loss(
         else:
             response_mask = completion_mask
 
-        if response_mask.sum() == 0:
-            mode = "train" if model.training else "eval"
-            self._log_self_distillation_metric(mode, "distillation_loss", 0.0)
-            return torch.tensor(0.0, device=completion_ids.device, requires_grad=True)
-
         student_input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         student_attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
         student_model_inputs = {
@@ -205,6 +200,11 @@ def _compute_self_distillation_loss(
         self._log_self_distillation_metric(
             mode,
             "distillation_loss",
            self.accelerator.gather(mean_distill_loss).mean().item(),
         )
+        self._log_self_distillation_metric(
+            mode,
+            "empty_target_batch",
+            float(response_mask.sum().item() == 0),
+        )
         return loss
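Note: why the removed early return was a bug. `torch.tensor(0.0, requires_grad=True)`
is a fresh autograd leaf with no path back to the model's parameters, so a rank that
hit an all-empty target batch produced no parameter gradients at all; under, e.g.,
DistributedDataParallel such a rank never joins the gradient all-reduce and training
can stall or error. The sketch below is illustrative only and not part of the patch
(a toy nn.Linear stands in for the policy model); it contrasts a detached zero loss
with a graph-connected one:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    logits = model(torch.randn(3, 4))

    # A fresh zero tensor is a graph leaf: backward() runs, but it never
    # reaches the model, so no parameter receives a gradient.
    detached_zero = torch.tensor(0.0, requires_grad=True)
    detached_zero.backward()
    print(model.weight.grad)  # None

    # A zero derived from the model's own outputs keeps the autograd edges:
    # every parameter gets an all-zero gradient, so all ranks stay in lockstep.
    connected_zero = logits.sum() * 0.0
    connected_zero.backward()
    print(model.weight.grad)  # tensor of zeros, not None

Rather than returning a graph-connected zero, the patch drops the special case
entirely and surfaces the condition through the new "empty_target_batch" metric,
letting the regular loss computation keep the graph intact.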