From 181ff3e1fe94a074baec1283daa89c76a4b140ff Mon Sep 17 00:00:00 2001 From: Sainava Date: Fri, 19 Jun 2026 13:50:59 +0530 Subject: [PATCH 1/8] Add native masked MSE loss for Sapiens2ForPoseEstimation (#46518) --- .../models/sapiens2/modeling_sapiens2.py | 28 ++++++++++++++++- .../models/sapiens2/modular_sapiens2.py | 28 ++++++++++++++++- .../models/sapiens2/test_modeling_sapiens2.py | 30 +++++++++++++++++++ 3 files changed, 84 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/sapiens2/modeling_sapiens2.py b/src/transformers/models/sapiens2/modeling_sapiens2.py index c7f88fba35b9..3a8165a3352f 100644 --- a/src/transformers/models/sapiens2/modeling_sapiens2.py +++ b/src/transformers/models/sapiens2/modeling_sapiens2.py @@ -1151,6 +1151,7 @@ def forward( pixel_values: torch.FloatTensor, flip_pairs: torch.Tensor | None = None, labels: torch.FloatTensor | None = None, + target_weights: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> Sapiens2PoseEstimatorOutput: r""" @@ -1161,6 +1162,9 @@ def forward( original orientation. labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*): Heatmap ground truth for computing the loss. + target_weights (`torch.FloatTensor` of shape `(batch_size, num_keypoints)` or `(batch_size, num_keypoints, height, width)`, *optional*): + Visibility weights for each keypoint. If a keypoint is occluded or invisible, its weight should be 0.0 to prevent + penalizing the model during the loss computation. If `None`, standard unmasked MSE loss is computed. Example: @@ -1200,7 +1204,29 @@ def forward( loss = None if labels is not None: - raise NotImplementedError("Training is not yet supported") + if labels.shape != heatmaps.shape: + raise ValueError(f"Expected labels shape {heatmaps.shape}, got {labels.shape}") + + if target_weights is None: + loss = torch.nn.functional.mse_loss(heatmaps, labels) + else: + if target_weights.ndim not in (2, 4): + raise ValueError(f"Expected target_weights to have 2 or 4 dimensions, got {target_weights.ndim}") + + if target_weights.shape != labels.shape[: target_weights.ndim]: + raise ValueError( + f"Expected target_weights shape to match {labels.shape[: target_weights.ndim]}, " + f"got {target_weights.shape}" + ) + + per_pixel_loss = torch.nn.functional.mse_loss(heatmaps, labels, reduction="none") + + ndim_pad = labels.ndim - target_weights.ndim + mask = target_weights.view(target_weights.shape + (1,) * ndim_pad) + + mask = mask.to(heatmaps.dtype) + + loss = (per_pixel_loss * mask).mean() return Sapiens2PoseEstimatorOutput( loss=loss, diff --git a/src/transformers/models/sapiens2/modular_sapiens2.py b/src/transformers/models/sapiens2/modular_sapiens2.py index e040d7f0d694..62fb3c71d892 100644 --- a/src/transformers/models/sapiens2/modular_sapiens2.py +++ b/src/transformers/models/sapiens2/modular_sapiens2.py @@ -1661,6 +1661,7 @@ def forward( pixel_values: torch.FloatTensor, flip_pairs: torch.Tensor | None = None, labels: torch.FloatTensor | None = None, + target_weights: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> Sapiens2PoseEstimatorOutput: r""" @@ -1671,6 +1672,9 @@ def forward( original orientation. labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*): Heatmap ground truth for computing the loss. + target_weights (`torch.FloatTensor` of shape `(batch_size, num_keypoints)` or `(batch_size, num_keypoints, height, width)`, *optional*): + Visibility weights for each keypoint. If a keypoint is occluded or invisible, its weight should be 0.0 to prevent + penalizing the model during the loss computation. If `None`, standard unmasked MSE loss is computed. Example: @@ -1710,7 +1714,29 @@ def forward( loss = None if labels is not None: - raise NotImplementedError("Training is not yet supported") + if labels.shape != heatmaps.shape: + raise ValueError(f"Expected labels shape {heatmaps.shape}, got {labels.shape}") + + if target_weights is None: + loss = torch.nn.functional.mse_loss(heatmaps, labels) + else: + if target_weights.ndim not in (2, 4): + raise ValueError(f"Expected target_weights to have 2 or 4 dimensions, got {target_weights.ndim}") + + if target_weights.shape != labels.shape[: target_weights.ndim]: + raise ValueError( + f"Expected target_weights shape to match {labels.shape[: target_weights.ndim]}, " + f"got {target_weights.shape}" + ) + + per_pixel_loss = torch.nn.functional.mse_loss(heatmaps, labels, reduction="none") + + ndim_pad = labels.ndim - target_weights.ndim + mask = target_weights.view(target_weights.shape + (1,) * ndim_pad) + + mask = mask.to(heatmaps.dtype) + + loss = (per_pixel_loss * mask).mean() return Sapiens2PoseEstimatorOutput( loss=loss, diff --git a/tests/models/sapiens2/test_modeling_sapiens2.py b/tests/models/sapiens2/test_modeling_sapiens2.py index a269c1eabdba..2927a958e518 100644 --- a/tests/models/sapiens2/test_modeling_sapiens2.py +++ b/tests/models/sapiens2/test_modeling_sapiens2.py @@ -187,15 +187,45 @@ def create_and_check_for_pose_estimation(self, config, pixel_values, labels): model = Sapiens2ForPoseEstimation(config) model.to(torch_device) model.eval() + with torch.no_grad(): result = model(pixel_values) + patch_height = self.image_size // self.patch_size expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels)) + self.parent.assertEqual( result.heatmaps.shape, (self.batch_size, config.num_labels, expected_h, expected_h), ) + + pose_labels = torch.randn_like(result.heatmaps) + + with torch.no_grad(): + result_with_loss = model( + pixel_values, + labels=pose_labels, + ) + + self.parent.assertIsNotNone(result_with_loss.loss) + + + target_weights = torch.ones( + self.batch_size, + config.num_labels, + device=pixel_values.device, + ) + + with torch.no_grad(): + result_with_weights = model( + pixel_values, + labels=pose_labels, + target_weights=target_weights, + ) + + self.parent.assertIsNotNone(result_with_weights.loss) + def create_and_check_for_normal_estimation(self, config, pixel_values, labels): model = Sapiens2ForNormalEstimation(config) model.to(torch_device) From 49b0033151efbffd155a98999f51536246c8e85c Mon Sep 17 00:00:00 2001 From: Sainava Date: Fri, 19 Jun 2026 14:24:04 +0530 Subject: [PATCH 2/8] Fix trailing whitespace in test file --- tests/models/sapiens2/test_modeling_sapiens2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/models/sapiens2/test_modeling_sapiens2.py b/tests/models/sapiens2/test_modeling_sapiens2.py index 2927a958e518..dde3c98fa69e 100644 --- a/tests/models/sapiens2/test_modeling_sapiens2.py +++ b/tests/models/sapiens2/test_modeling_sapiens2.py @@ -199,7 +199,6 @@ def create_and_check_for_pose_estimation(self, config, pixel_values, labels): (self.batch_size, config.num_labels, expected_h, expected_h), ) - pose_labels = torch.randn_like(result.heatmaps) with torch.no_grad(): @@ -210,7 +209,6 @@ def create_and_check_for_pose_estimation(self, config, pixel_values, labels): self.parent.assertIsNotNone(result_with_loss.loss) - target_weights = torch.ones( self.batch_size, config.num_labels, From 8d8a8479c2fcf154e69c29053486d2f9c3eea42d Mon Sep 17 00:00:00 2001 From: Sainava Date: Fri, 19 Jun 2026 22:28:04 +0530 Subject: [PATCH 3/8] Add supervised fine-tuning documentation for Sapiens2 pose estimation --- docs/source/en/model_doc/sapiens2.md | 34 +++++++++++++++ .../models/sapiens2/modeling_sapiens2.py | 36 +++++----------- .../models/sapiens2/modular_sapiens2.py | 36 +++++----------- .../models/sapiens2/test_modeling_sapiens2.py | 41 +++++++++++++------ 4 files changed, 84 insertions(+), 63 deletions(-) diff --git a/docs/source/en/model_doc/sapiens2.md b/docs/source/en/model_doc/sapiens2.md index b12c290c0f9c..20b8a4c36d51 100644 --- a/docs/source/en/model_doc/sapiens2.md +++ b/docs/source/en/model_doc/sapiens2.md @@ -274,6 +274,40 @@ scores = results[0]["scores"] + + +The example below shows how to compute the supervised Masked Mean Squared Error (MSE) loss with [`Sapiens2ForPoseEstimation`] by passing `labels` and optional `label_weights`. This is useful for fine-tuning using the `Trainer` API. + +```python +import torch +from transformers import AutoImageProcessor, AutoModelForPoseEstimation +from transformers.image_utils import load_image + +image = load_image("http://images.cocodataset.org/val2017/000000004016.jpg") + +image_processor = AutoImageProcessor.from_pretrained("facebook/sapiens2-pose-0.4b") +model = AutoModelForPoseEstimation.from_pretrained("facebook/sapiens2-pose-0.4b", device_map="auto") + +# Provide bounding boxes in COCO format (x, y, width, height) for each person +boxes = [[[270.8, 0.6, 294.1, 379.5]]] +inputs = image_processor(image, boxes=boxes, return_tensors="pt").to(model.device) + +# Create dummy labels (heatmaps) and visibility weights to simulate ground truth +# 1.0 for visible keypoints, 0.0 for occluded/invisible keypoints +batch_size, num_keypoints = 1, 308 +heatmap_height, heatmap_width = 256, 192 +labels = torch.randn(batch_size, num_keypoints, heatmap_height, heatmap_width, device=model.device) +label_weights = torch.ones(batch_size, num_keypoints, 1, 1, device=model.device) + +# Forward pass with loss calculation +outputs = model(**inputs, labels=labels, label_weights=label_weights) + +print("Loss:", outputs.loss.item()) +``` + + + + The example below shows how to perform body-part segmentation with [`Sapiens2ForSemanticSegmentation`]. diff --git a/src/transformers/models/sapiens2/modeling_sapiens2.py b/src/transformers/models/sapiens2/modeling_sapiens2.py index 3a8165a3352f..019efc2cb548 100644 --- a/src/transformers/models/sapiens2/modeling_sapiens2.py +++ b/src/transformers/models/sapiens2/modeling_sapiens2.py @@ -1137,6 +1137,8 @@ def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"): """, ) class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel): + _loss_function = torch.nn.functional.mse_loss + def __init__(self, config: Sapiens2Config): super().__init__(config) self.num_labels = config.num_labels @@ -1151,7 +1153,7 @@ def forward( pixel_values: torch.FloatTensor, flip_pairs: torch.Tensor | None = None, labels: torch.FloatTensor | None = None, - target_weights: torch.FloatTensor | None = None, + label_weights: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> Sapiens2PoseEstimatorOutput: r""" @@ -1162,9 +1164,8 @@ def forward( original orientation. labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*): Heatmap ground truth for computing the loss. - target_weights (`torch.FloatTensor` of shape `(batch_size, num_keypoints)` or `(batch_size, num_keypoints, height, width)`, *optional*): - Visibility weights for each keypoint. If a keypoint is occluded or invisible, its weight should be 0.0 to prevent - penalizing the model during the loss computation. If `None`, standard unmasked MSE loss is computed. + label_weights (`torch.FloatTensor` of shape `(batch_size, num_labels, 1, 1)` or `(batch_size, num_labels, height, width)`, *optional*): + Visibility weights for each keypoint. Must be broadcastable to the shape of `labels`. Example: @@ -1204,29 +1205,14 @@ def forward( loss = None if labels is not None: - if labels.shape != heatmaps.shape: - raise ValueError(f"Expected labels shape {heatmaps.shape}, got {labels.shape}") + # Calculate unreduced loss using the standard HF loss API + loss = self.loss_function(heatmaps, labels, reduction="none") - if target_weights is None: - loss = torch.nn.functional.mse_loss(heatmaps, labels) + if label_weights is not None: + # Assume label_weights is already in a broadcastable shape and dtype + loss = (loss * label_weights).mean() else: - if target_weights.ndim not in (2, 4): - raise ValueError(f"Expected target_weights to have 2 or 4 dimensions, got {target_weights.ndim}") - - if target_weights.shape != labels.shape[: target_weights.ndim]: - raise ValueError( - f"Expected target_weights shape to match {labels.shape[: target_weights.ndim]}, " - f"got {target_weights.shape}" - ) - - per_pixel_loss = torch.nn.functional.mse_loss(heatmaps, labels, reduction="none") - - ndim_pad = labels.ndim - target_weights.ndim - mask = target_weights.view(target_weights.shape + (1,) * ndim_pad) - - mask = mask.to(heatmaps.dtype) - - loss = (per_pixel_loss * mask).mean() + loss = loss.mean() return Sapiens2PoseEstimatorOutput( loss=loss, diff --git a/src/transformers/models/sapiens2/modular_sapiens2.py b/src/transformers/models/sapiens2/modular_sapiens2.py index 62fb3c71d892..65f963cb313a 100644 --- a/src/transformers/models/sapiens2/modular_sapiens2.py +++ b/src/transformers/models/sapiens2/modular_sapiens2.py @@ -1647,6 +1647,8 @@ def forward( """, ) class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel): + _loss_function = torch.nn.functional.mse_loss + def __init__(self, config: Sapiens2Config): super().__init__(config) self.num_labels = config.num_labels @@ -1661,7 +1663,7 @@ def forward( pixel_values: torch.FloatTensor, flip_pairs: torch.Tensor | None = None, labels: torch.FloatTensor | None = None, - target_weights: torch.FloatTensor | None = None, + label_weights: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> Sapiens2PoseEstimatorOutput: r""" @@ -1672,9 +1674,8 @@ def forward( original orientation. labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*): Heatmap ground truth for computing the loss. - target_weights (`torch.FloatTensor` of shape `(batch_size, num_keypoints)` or `(batch_size, num_keypoints, height, width)`, *optional*): - Visibility weights for each keypoint. If a keypoint is occluded or invisible, its weight should be 0.0 to prevent - penalizing the model during the loss computation. If `None`, standard unmasked MSE loss is computed. + label_weights (`torch.FloatTensor` of shape `(batch_size, num_labels, 1, 1)` or `(batch_size, num_labels, height, width)`, *optional*): + Visibility weights for each keypoint. Must be broadcastable to the shape of `labels`. Example: @@ -1714,29 +1715,14 @@ def forward( loss = None if labels is not None: - if labels.shape != heatmaps.shape: - raise ValueError(f"Expected labels shape {heatmaps.shape}, got {labels.shape}") + # Calculate unreduced loss using the standard HF loss API + loss = self.loss_function(heatmaps, labels, reduction="none") - if target_weights is None: - loss = torch.nn.functional.mse_loss(heatmaps, labels) + if label_weights is not None: + # Assume label_weights is already in a broadcastable shape and dtype + loss = (loss * label_weights).mean() else: - if target_weights.ndim not in (2, 4): - raise ValueError(f"Expected target_weights to have 2 or 4 dimensions, got {target_weights.ndim}") - - if target_weights.shape != labels.shape[: target_weights.ndim]: - raise ValueError( - f"Expected target_weights shape to match {labels.shape[: target_weights.ndim]}, " - f"got {target_weights.shape}" - ) - - per_pixel_loss = torch.nn.functional.mse_loss(heatmaps, labels, reduction="none") - - ndim_pad = labels.ndim - target_weights.ndim - mask = target_weights.view(target_weights.shape + (1,) * ndim_pad) - - mask = mask.to(heatmaps.dtype) - - loss = (per_pixel_loss * mask).mean() + loss = loss.mean() return Sapiens2PoseEstimatorOutput( loss=loss, diff --git a/tests/models/sapiens2/test_modeling_sapiens2.py b/tests/models/sapiens2/test_modeling_sapiens2.py index dde3c98fa69e..7ed8b0ca023c 100644 --- a/tests/models/sapiens2/test_modeling_sapiens2.py +++ b/tests/models/sapiens2/test_modeling_sapiens2.py @@ -183,7 +183,7 @@ def create_and_check_for_semantic_segmentation(self, config, pixel_values, label (self.batch_size, config.num_labels, expected_h, expected_h), ) - def create_and_check_for_pose_estimation(self, config, pixel_values, labels): + def create_and_check_for_pose_estimation(self, config, pixel_values, labels, label_weights): model = Sapiens2ForPoseEstimation(config) model.to(torch_device) model.eval() @@ -199,27 +199,19 @@ def create_and_check_for_pose_estimation(self, config, pixel_values, labels): (self.batch_size, config.num_labels, expected_h, expected_h), ) - pose_labels = torch.randn_like(result.heatmaps) - with torch.no_grad(): result_with_loss = model( pixel_values, - labels=pose_labels, + labels=labels, ) self.parent.assertIsNotNone(result_with_loss.loss) - target_weights = torch.ones( - self.batch_size, - config.num_labels, - device=pixel_values.device, - ) - with torch.no_grad(): result_with_weights = model( pixel_values, - labels=pose_labels, - target_weights=target_weights, + labels=labels, + label_weights=label_weights, ) self.parent.assertIsNotNone(result_with_weights.loss) @@ -306,6 +298,29 @@ def prepare_config_and_inputs_for_common(self): inputs_dict = {"pixel_values": pixel_values} return config, inputs_dict + def prepare_config_and_inputs_for_pose_estimation(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs[0], config_and_inputs[1] + + patch_height = self.image_size // self.patch_size + expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels)) + + labels = torch.randn( + self.batch_size, + config.num_labels, + expected_h, + expected_h, + device=pixel_values.device, + dtype=pixel_values.dtype, + ) + + # Creating weights in a broadcastable shape as requested by the maintainer + label_weights = torch.ones( + self.batch_size, config.num_labels, 1, 1, device=pixel_values.device, dtype=pixel_values.dtype + ) + + return config, pixel_values, labels, label_weights + @require_torch class Sapiens2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): @@ -379,7 +394,7 @@ def test_for_semantic_segmentation(self): self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs) def test_for_pose_estimation(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs_for_semantic_segmentation() + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_pose_estimation() self.model_tester.create_and_check_for_pose_estimation(*config_and_inputs) def test_for_pointmap_estimation(self): From fcfa5745a8390a899c08d6852af624e548089955 Mon Sep 17 00:00:00 2001 From: Sainava Date: Fri, 19 Jun 2026 23:12:18 +0530 Subject: [PATCH 4/8] Fix method binding bug on _loss_function by wrapping in staticmethod --- src/transformers/models/sapiens2/modeling_sapiens2.py | 4 +--- src/transformers/models/sapiens2/modular_sapiens2.py | 4 +--- tests/models/sapiens2/test_modeling_sapiens2.py | 1 - 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/sapiens2/modeling_sapiens2.py b/src/transformers/models/sapiens2/modeling_sapiens2.py index 019efc2cb548..be051024feff 100644 --- a/src/transformers/models/sapiens2/modeling_sapiens2.py +++ b/src/transformers/models/sapiens2/modeling_sapiens2.py @@ -1137,7 +1137,7 @@ def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"): """, ) class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel): - _loss_function = torch.nn.functional.mse_loss + _loss_function = staticmethod(torch.nn.functional.mse_loss) def __init__(self, config: Sapiens2Config): super().__init__(config) @@ -1205,11 +1205,9 @@ def forward( loss = None if labels is not None: - # Calculate unreduced loss using the standard HF loss API loss = self.loss_function(heatmaps, labels, reduction="none") if label_weights is not None: - # Assume label_weights is already in a broadcastable shape and dtype loss = (loss * label_weights).mean() else: loss = loss.mean() diff --git a/src/transformers/models/sapiens2/modular_sapiens2.py b/src/transformers/models/sapiens2/modular_sapiens2.py index 65f963cb313a..dc2bb0e93730 100644 --- a/src/transformers/models/sapiens2/modular_sapiens2.py +++ b/src/transformers/models/sapiens2/modular_sapiens2.py @@ -1647,7 +1647,7 @@ def forward( """, ) class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel): - _loss_function = torch.nn.functional.mse_loss + _loss_function = staticmethod(torch.nn.functional.mse_loss) def __init__(self, config: Sapiens2Config): super().__init__(config) @@ -1715,11 +1715,9 @@ def forward( loss = None if labels is not None: - # Calculate unreduced loss using the standard HF loss API loss = self.loss_function(heatmaps, labels, reduction="none") if label_weights is not None: - # Assume label_weights is already in a broadcastable shape and dtype loss = (loss * label_weights).mean() else: loss = loss.mean() diff --git a/tests/models/sapiens2/test_modeling_sapiens2.py b/tests/models/sapiens2/test_modeling_sapiens2.py index 7ed8b0ca023c..0e79aaf40b20 100644 --- a/tests/models/sapiens2/test_modeling_sapiens2.py +++ b/tests/models/sapiens2/test_modeling_sapiens2.py @@ -314,7 +314,6 @@ def prepare_config_and_inputs_for_pose_estimation(self): dtype=pixel_values.dtype, ) - # Creating weights in a broadcastable shape as requested by the maintainer label_weights = torch.ones( self.batch_size, config.num_labels, 1, 1, device=pixel_values.device, dtype=pixel_values.dtype ) From cdf969554bd78bad17a290857b89ad9f4ae44dcf Mon Sep 17 00:00:00 2001 From: Sainava Date: Mon, 22 Jun 2026 17:50:35 +0530 Subject: [PATCH 5/8] Switch test to floats_tensor and fix docs heatmap dimensions --- docs/source/en/model_doc/sapiens2.md | 4 ++-- tests/models/sapiens2/test_modeling_sapiens2.py | 9 +-------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/docs/source/en/model_doc/sapiens2.md b/docs/source/en/model_doc/sapiens2.md index 20b8a4c36d51..459d6d2aa6e2 100644 --- a/docs/source/en/model_doc/sapiens2.md +++ b/docs/source/en/model_doc/sapiens2.md @@ -274,7 +274,7 @@ scores = results[0]["scores"] - + The example below shows how to compute the supervised Masked Mean Squared Error (MSE) loss with [`Sapiens2ForPoseEstimation`] by passing `labels` and optional `label_weights`. This is useful for fine-tuning using the `Trainer` API. @@ -295,7 +295,7 @@ inputs = image_processor(image, boxes=boxes, return_tensors="pt").to(model.devic # Create dummy labels (heatmaps) and visibility weights to simulate ground truth # 1.0 for visible keypoints, 0.0 for occluded/invisible keypoints batch_size, num_keypoints = 1, 308 -heatmap_height, heatmap_width = 256, 192 +heatmap_height, heatmap_width = 1024, 768 labels = torch.randn(batch_size, num_keypoints, heatmap_height, heatmap_width, device=model.device) label_weights = torch.ones(batch_size, num_keypoints, 1, 1, device=model.device) diff --git a/tests/models/sapiens2/test_modeling_sapiens2.py b/tests/models/sapiens2/test_modeling_sapiens2.py index 0e79aaf40b20..05f4c19d7d89 100644 --- a/tests/models/sapiens2/test_modeling_sapiens2.py +++ b/tests/models/sapiens2/test_modeling_sapiens2.py @@ -305,14 +305,7 @@ def prepare_config_and_inputs_for_pose_estimation(self): patch_height = self.image_size // self.patch_size expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels)) - labels = torch.randn( - self.batch_size, - config.num_labels, - expected_h, - expected_h, - device=pixel_values.device, - dtype=pixel_values.dtype, - ) + labels = floats_tensor([self.batch_size, config.num_labels, expected_h, expected_h]) label_weights = torch.ones( self.batch_size, config.num_labels, 1, 1, device=pixel_values.device, dtype=pixel_values.dtype From 5056650638361b87320aaf207beae40302fdcc1f Mon Sep 17 00:00:00 2001 From: Sainava Date: Mon, 22 Jun 2026 20:19:36 +0530 Subject: [PATCH 6/8] Simplify loss function by relying on upstream weight expansion --- src/transformers/models/sapiens2/modeling_sapiens2.py | 7 +------ src/transformers/models/sapiens2/modular_sapiens2.py | 7 +------ 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/sapiens2/modeling_sapiens2.py b/src/transformers/models/sapiens2/modeling_sapiens2.py index be051024feff..5354b6a15bca 100644 --- a/src/transformers/models/sapiens2/modeling_sapiens2.py +++ b/src/transformers/models/sapiens2/modeling_sapiens2.py @@ -1205,12 +1205,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(heatmaps, labels, reduction="none") - - if label_weights is not None: - loss = (loss * label_weights).mean() - else: - loss = loss.mean() + loss = self.loss_function(heatmaps, labels, weight=label_weights) return Sapiens2PoseEstimatorOutput( loss=loss, diff --git a/src/transformers/models/sapiens2/modular_sapiens2.py b/src/transformers/models/sapiens2/modular_sapiens2.py index dc2bb0e93730..6b3730cba5ba 100644 --- a/src/transformers/models/sapiens2/modular_sapiens2.py +++ b/src/transformers/models/sapiens2/modular_sapiens2.py @@ -1715,12 +1715,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(heatmaps, labels, reduction="none") - - if label_weights is not None: - loss = (loss * label_weights).mean() - else: - loss = loss.mean() + loss = self.loss_function(heatmaps, labels, weight=label_weights) return Sapiens2PoseEstimatorOutput( loss=loss, From e94d79f97448e44e3492bc21acc617be2fb3c648 Mon Sep 17 00:00:00 2001 From: Sainava Date: Mon, 22 Jun 2026 20:41:37 +0530 Subject: [PATCH 7/8] Update dummy label_weights shape in tests to match heatmaps --- tests/models/sapiens2/test_modeling_sapiens2.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/models/sapiens2/test_modeling_sapiens2.py b/tests/models/sapiens2/test_modeling_sapiens2.py index 05f4c19d7d89..ff35b08db677 100644 --- a/tests/models/sapiens2/test_modeling_sapiens2.py +++ b/tests/models/sapiens2/test_modeling_sapiens2.py @@ -306,10 +306,7 @@ def prepare_config_and_inputs_for_pose_estimation(self): expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels)) labels = floats_tensor([self.batch_size, config.num_labels, expected_h, expected_h]) - - label_weights = torch.ones( - self.batch_size, config.num_labels, 1, 1, device=pixel_values.device, dtype=pixel_values.dtype - ) + label_weights = floats_tensor([self.batch_size, config.num_labels, expected_h, expected_h]) return config, pixel_values, labels, label_weights From 840b2c3651f1aa3b243cc447e424f7b778006eba Mon Sep 17 00:00:00 2001 From: Sainava Date: Wed, 24 Jun 2026 02:10:21 +0530 Subject: [PATCH 8/8] Centralize loss function and test backward pass --- src/transformers/loss/loss_utils.py | 1 + .../models/sapiens2/modeling_sapiens2.py | 2 -- .../models/sapiens2/modular_sapiens2.py | 2 -- .../models/sapiens2/test_modeling_sapiens2.py | 19 ++++++------------- 4 files changed, 7 insertions(+), 17 deletions(-) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 92c0bd0637e5..a10b62a023f8 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -198,4 +198,5 @@ def ForSemanticSegmentationLoss( "ParakeetForTDT": ParakeetForTDTLoss, "RfDetrForObjectDetection": LwDetrForObjectDetectionLoss, "RfDetrForInstanceSegmentation": RfDetrForSegmentationLoss, + "Sapiens2ForPoseEstimation": torch.nn.functional.mse_loss, } diff --git a/src/transformers/models/sapiens2/modeling_sapiens2.py b/src/transformers/models/sapiens2/modeling_sapiens2.py index 5354b6a15bca..1a3805482ddf 100644 --- a/src/transformers/models/sapiens2/modeling_sapiens2.py +++ b/src/transformers/models/sapiens2/modeling_sapiens2.py @@ -1137,8 +1137,6 @@ def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"): """, ) class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel): - _loss_function = staticmethod(torch.nn.functional.mse_loss) - def __init__(self, config: Sapiens2Config): super().__init__(config) self.num_labels = config.num_labels diff --git a/src/transformers/models/sapiens2/modular_sapiens2.py b/src/transformers/models/sapiens2/modular_sapiens2.py index 6b3730cba5ba..b6770e518c61 100644 --- a/src/transformers/models/sapiens2/modular_sapiens2.py +++ b/src/transformers/models/sapiens2/modular_sapiens2.py @@ -1647,8 +1647,6 @@ def forward( """, ) class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel): - _loss_function = staticmethod(torch.nn.functional.mse_loss) - def __init__(self, config: Sapiens2Config): super().__init__(config) self.num_labels = config.num_labels diff --git a/tests/models/sapiens2/test_modeling_sapiens2.py b/tests/models/sapiens2/test_modeling_sapiens2.py index ff35b08db677..3e773f13b556 100644 --- a/tests/models/sapiens2/test_modeling_sapiens2.py +++ b/tests/models/sapiens2/test_modeling_sapiens2.py @@ -199,22 +199,15 @@ def create_and_check_for_pose_estimation(self, config, pixel_values, labels, lab (self.batch_size, config.num_labels, expected_h, expected_h), ) - with torch.no_grad(): - result_with_loss = model( - pixel_values, - labels=labels, - ) - + # Test standard loss backward pass + result_with_loss = model(pixel_values, labels=labels) self.parent.assertIsNotNone(result_with_loss.loss) + result_with_loss.loss.backward() - with torch.no_grad(): - result_with_weights = model( - pixel_values, - labels=labels, - label_weights=label_weights, - ) - + # Test weighted loss backward pass + result_with_weights = model(pixel_values, labels=labels, label_weights=label_weights) self.parent.assertIsNotNone(result_with_weights.loss) + result_with_weights.loss.backward() def create_and_check_for_normal_estimation(self, config, pixel_values, labels): model = Sapiens2ForNormalEstimation(config)