diff --git a/docs/source/en/model_doc/sapiens2.md b/docs/source/en/model_doc/sapiens2.md
index b12c290c0f9c..459d6d2aa6e2 100644
--- a/docs/source/en/model_doc/sapiens2.md
+++ b/docs/source/en/model_doc/sapiens2.md
@@ -274,6 +274,40 @@ scores = results[0]["scores"]
+
+
+The example below shows how to compute the masked mean squared error (MSE) training loss with [`Sapiens2ForPoseEstimation`] by passing ground-truth heatmaps as `labels` and, optionally, per-keypoint visibility weights as `label_weights`. This is useful when fine-tuning the model with the `Trainer` API.
+
+```python
+import torch
+from transformers import AutoImageProcessor, AutoModelForPoseEstimation
+from transformers.image_utils import load_image
+
+image = load_image("http://images.cocodataset.org/val2017/000000004016.jpg")
+
+image_processor = AutoImageProcessor.from_pretrained("facebook/sapiens2-pose-0.4b")
+model = AutoModelForPoseEstimation.from_pretrained("facebook/sapiens2-pose-0.4b", device_map="auto")
+
+# Provide bounding boxes in COCO format (x, y, width, height) for each person
+boxes = [[[270.8, 0.6, 294.1, 379.5]]]
+inputs = image_processor(image, boxes=boxes, return_tensors="pt").to(model.device)
+
+# Create dummy ground-truth heatmaps (`labels`) and per-keypoint visibility weights (`label_weights`)
+# In `label_weights`, use 1.0 for visible keypoints and 0.0 for occluded/invisible keypoints
+batch_size, num_keypoints = 1, 308
+heatmap_height, heatmap_width = 1024, 768
+labels = torch.randn(batch_size, num_keypoints, heatmap_height, heatmap_width, device=model.device)
+label_weights = torch.ones(batch_size, num_keypoints, 1, 1, device=model.device)
+
+# Forward pass with loss calculation
+outputs = model(**inputs, labels=labels, label_weights=label_weights)
+
+print("Loss:", outputs.loss.item())
+```
+
+
+
+
The example below shows how to perform body-part segmentation with [`Sapiens2ForSemanticSegmentation`].
diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py
index 92c0bd0637e5..a10b62a023f8 100644
--- a/src/transformers/loss/loss_utils.py
+++ b/src/transformers/loss/loss_utils.py
@@ -198,4 +198,5 @@ def ForSemanticSegmentationLoss(
"ParakeetForTDT": ParakeetForTDTLoss,
"RfDetrForObjectDetection": LwDetrForObjectDetectionLoss,
"RfDetrForInstanceSegmentation": RfDetrForSegmentationLoss,
+ "Sapiens2ForPoseEstimation": torch.nn.functional.mse_loss,
}
diff --git a/src/transformers/models/sapiens2/modeling_sapiens2.py b/src/transformers/models/sapiens2/modeling_sapiens2.py
index c7f88fba35b9..1a3805482ddf 100644
--- a/src/transformers/models/sapiens2/modeling_sapiens2.py
+++ b/src/transformers/models/sapiens2/modeling_sapiens2.py
@@ -1151,6 +1151,7 @@ def forward(
pixel_values: torch.FloatTensor,
flip_pairs: torch.Tensor | None = None,
labels: torch.FloatTensor | None = None,
+ label_weights: torch.FloatTensor | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> Sapiens2PoseEstimatorOutput:
r"""
@@ -1161,6 +1162,8 @@ def forward(
original orientation.
labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*):
Heatmap ground truth for computing the loss.
+ label_weights (`torch.FloatTensor` of shape `(batch_size, num_labels, 1, 1)` or `(batch_size, num_labels, height, width)`, *optional*):
+ Visibility weights for each keypoint. Must be broadcastable to the shape of `labels`.
Example:
@@ -1200,7 +1203,7 @@ def forward(
loss = None
if labels is not None:
- raise NotImplementedError("Training is not yet supported")
+ loss = self.loss_function(heatmaps, labels, weight=label_weights)
return Sapiens2PoseEstimatorOutput(
loss=loss,
diff --git a/src/transformers/models/sapiens2/modular_sapiens2.py b/src/transformers/models/sapiens2/modular_sapiens2.py
index e040d7f0d694..b6770e518c61 100644
--- a/src/transformers/models/sapiens2/modular_sapiens2.py
+++ b/src/transformers/models/sapiens2/modular_sapiens2.py
@@ -1661,6 +1661,7 @@ def forward(
pixel_values: torch.FloatTensor,
flip_pairs: torch.Tensor | None = None,
labels: torch.FloatTensor | None = None,
+ label_weights: torch.FloatTensor | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> Sapiens2PoseEstimatorOutput:
r"""
@@ -1671,6 +1672,8 @@ def forward(
original orientation.
labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*):
Heatmap ground truth for computing the loss.
+ label_weights (`torch.FloatTensor` of shape `(batch_size, num_labels, 1, 1)` or `(batch_size, num_labels, height, width)`, *optional*):
+ Visibility weights for each keypoint. Must be broadcastable to the shape of `labels`.
Example:
@@ -1710,7 +1713,7 @@ def forward(
loss = None
if labels is not None:
- raise NotImplementedError("Training is not yet supported")
+ loss = self.loss_function(heatmaps, labels, weight=label_weights)
return Sapiens2PoseEstimatorOutput(
loss=loss,
diff --git a/tests/models/sapiens2/test_modeling_sapiens2.py b/tests/models/sapiens2/test_modeling_sapiens2.py
index a269c1eabdba..3e773f13b556 100644
--- a/tests/models/sapiens2/test_modeling_sapiens2.py
+++ b/tests/models/sapiens2/test_modeling_sapiens2.py
@@ -183,19 +183,32 @@ def create_and_check_for_semantic_segmentation(self, config, pixel_values, label
(self.batch_size, config.num_labels, expected_h, expected_h),
)
- def create_and_check_for_pose_estimation(self, config, pixel_values, labels):
+ def create_and_check_for_pose_estimation(self, config, pixel_values, labels, label_weights):
model = Sapiens2ForPoseEstimation(config)
model.to(torch_device)
model.eval()
+
with torch.no_grad():
result = model(pixel_values)
+
patch_height = self.image_size // self.patch_size
expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels))
+
self.parent.assertEqual(
result.heatmaps.shape,
(self.batch_size, config.num_labels, expected_h, expected_h),
)
+ # Test standard loss backward pass
+ result_with_loss = model(pixel_values, labels=labels)
+ self.parent.assertIsNotNone(result_with_loss.loss)
+ result_with_loss.loss.backward()
+
+ # Test weighted loss backward pass
+ result_with_weights = model(pixel_values, labels=labels, label_weights=label_weights)
+ self.parent.assertIsNotNone(result_with_weights.loss)
+ result_with_weights.loss.backward()
+
def create_and_check_for_normal_estimation(self, config, pixel_values, labels):
model = Sapiens2ForNormalEstimation(config)
model.to(torch_device)
@@ -278,6 +291,18 @@ def prepare_config_and_inputs_for_common(self):
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
+ def prepare_config_and_inputs_for_pose_estimation(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs[0], config_and_inputs[1]
+
+ patch_height = self.image_size // self.patch_size
+ expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels))
+
+ labels = floats_tensor([self.batch_size, config.num_labels, expected_h, expected_h])
+ label_weights = floats_tensor([self.batch_size, config.num_labels, expected_h, expected_h])
+
+ return config, pixel_values, labels, label_weights
+
@require_torch
class Sapiens2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
@@ -351,7 +376,7 @@ def test_for_semantic_segmentation(self):
self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs)
def test_for_pose_estimation(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs_for_semantic_segmentation()
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_pose_estimation()
self.model_tester.create_and_check_for_pose_estimation(*config_and_inputs)
def test_for_pointmap_estimation(self):