diff --git a/docs/source/en/model_doc/sapiens2.md b/docs/source/en/model_doc/sapiens2.md index b12c290c0f9c..459d6d2aa6e2 100644 --- a/docs/source/en/model_doc/sapiens2.md +++ b/docs/source/en/model_doc/sapiens2.md @@ -274,6 +274,40 @@ scores = results[0]["scores"] + + +The example below shows how to compute the supervised masked mean squared error (MSE) loss with [`Sapiens2ForPoseEstimation`] by passing `labels` and, optionally, `label_weights` (per-keypoint visibility weights that must be broadcastable to the shape of `labels`). This is useful when fine-tuning with the `Trainer` API. + +```python +import torch +from transformers import AutoImageProcessor, AutoModelForPoseEstimation +from transformers.image_utils import load_image + +image = load_image("http://images.cocodataset.org/val2017/000000004016.jpg") + +image_processor = AutoImageProcessor.from_pretrained("facebook/sapiens2-pose-0.4b") +model = AutoModelForPoseEstimation.from_pretrained("facebook/sapiens2-pose-0.4b", device_map="auto") + +# Provide bounding boxes in COCO format (x, y, width, height) for each person +boxes = [[[270.8, 0.6, 294.1, 379.5]]] +inputs = image_processor(image, boxes=boxes, return_tensors="pt").to(model.device) + +# Create dummy labels (heatmaps) and visibility weights to simulate ground truth +# 1.0 for visible keypoints, 0.0 for occluded/invisible keypoints +batch_size, num_keypoints = 1, 308 +heatmap_height, heatmap_width = 1024, 768 +labels = torch.randn(batch_size, num_keypoints, heatmap_height, heatmap_width, device=model.device) +label_weights = torch.ones(batch_size, num_keypoints, 1, 1, device=model.device) + +# Forward pass with loss calculation +outputs = model(**inputs, labels=labels, label_weights=label_weights) + +print("Loss:", outputs.loss.item()) +``` + + + + The example below shows how to perform body-part segmentation with [`Sapiens2ForSemanticSegmentation`]. 
diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 92c0bd0637e5..a10b62a023f8 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -198,4 +198,5 @@ def ForSemanticSegmentationLoss( "ParakeetForTDT": ParakeetForTDTLoss, "RfDetrForObjectDetection": LwDetrForObjectDetectionLoss, "RfDetrForInstanceSegmentation": RfDetrForSegmentationLoss, + "Sapiens2ForPoseEstimation": torch.nn.functional.mse_loss, } diff --git a/src/transformers/models/sapiens2/modeling_sapiens2.py b/src/transformers/models/sapiens2/modeling_sapiens2.py index c7f88fba35b9..1a3805482ddf 100644 --- a/src/transformers/models/sapiens2/modeling_sapiens2.py +++ b/src/transformers/models/sapiens2/modeling_sapiens2.py @@ -1151,6 +1151,7 @@ def forward( pixel_values: torch.FloatTensor, flip_pairs: torch.Tensor | None = None, labels: torch.FloatTensor | None = None, + label_weights: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> Sapiens2PoseEstimatorOutput: r""" @@ -1161,6 +1162,8 @@ def forward( original orientation. labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*): Heatmap ground truth for computing the loss. + label_weights (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 1, 1)` or `(batch_size, num_keypoints, height, width)`, *optional*): + Visibility weights for each keypoint. Must be broadcastable to the shape of `labels`. 
Example: @@ -1200,7 +1203,7 @@ def forward( loss = None if labels is not None: - raise NotImplementedError("Training is not yet supported") + loss = self.loss_function(heatmaps, labels, weight=label_weights) return Sapiens2PoseEstimatorOutput( loss=loss, diff --git a/src/transformers/models/sapiens2/modular_sapiens2.py b/src/transformers/models/sapiens2/modular_sapiens2.py index e040d7f0d694..b6770e518c61 100644 --- a/src/transformers/models/sapiens2/modular_sapiens2.py +++ b/src/transformers/models/sapiens2/modular_sapiens2.py @@ -1661,6 +1661,7 @@ def forward( pixel_values: torch.FloatTensor, flip_pairs: torch.Tensor | None = None, labels: torch.FloatTensor | None = None, + label_weights: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> Sapiens2PoseEstimatorOutput: r""" @@ -1671,6 +1672,8 @@ def forward( original orientation. labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*): Heatmap ground truth for computing the loss. + label_weights (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 1, 1)` or `(batch_size, num_keypoints, height, width)`, *optional*): + Visibility weights for each keypoint. Must be broadcastable to the shape of `labels`. 
Example: @@ -1710,7 +1713,7 @@ def forward( loss = None if labels is not None: - raise NotImplementedError("Training is not yet supported") + loss = self.loss_function(heatmaps, labels, weight=label_weights) return Sapiens2PoseEstimatorOutput( loss=loss, diff --git a/tests/models/sapiens2/test_modeling_sapiens2.py b/tests/models/sapiens2/test_modeling_sapiens2.py index a269c1eabdba..3e773f13b556 100644 --- a/tests/models/sapiens2/test_modeling_sapiens2.py +++ b/tests/models/sapiens2/test_modeling_sapiens2.py @@ -183,19 +183,32 @@ def create_and_check_for_semantic_segmentation(self, config, pixel_values, label (self.batch_size, config.num_labels, expected_h, expected_h), ) - def create_and_check_for_pose_estimation(self, config, pixel_values, labels): + def create_and_check_for_pose_estimation(self, config, pixel_values, labels, label_weights): model = Sapiens2ForPoseEstimation(config) model.to(torch_device) model.eval() + with torch.no_grad(): result = model(pixel_values) + patch_height = self.image_size // self.patch_size expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels)) + self.parent.assertEqual( result.heatmaps.shape, (self.batch_size, config.num_labels, expected_h, expected_h), ) + # Test standard loss backward pass + result_with_loss = model(pixel_values, labels=labels) + self.parent.assertIsNotNone(result_with_loss.loss) + result_with_loss.loss.backward() + + # Test weighted loss backward pass + result_with_weights = model(pixel_values, labels=labels, label_weights=label_weights) + self.parent.assertIsNotNone(result_with_weights.loss) + result_with_weights.loss.backward() + def create_and_check_for_normal_estimation(self, config, pixel_values, labels): model = Sapiens2ForNormalEstimation(config) model.to(torch_device) @@ -278,6 +291,18 @@ def prepare_config_and_inputs_for_common(self): inputs_dict = {"pixel_values": pixel_values} return config, inputs_dict + def prepare_config_and_inputs_for_pose_estimation(self): + 
config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs[0], config_and_inputs[1] + + patch_height = self.image_size // self.patch_size + expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels)) + + labels = floats_tensor([self.batch_size, config.num_labels, expected_h, expected_h]) + label_weights = floats_tensor([self.batch_size, config.num_labels, expected_h, expected_h]) + + return config, pixel_values, labels, label_weights + @require_torch class Sapiens2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): @@ -351,7 +376,7 @@ def test_for_semantic_segmentation(self): self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs) def test_for_pose_estimation(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs_for_semantic_segmentation() + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_pose_estimation() self.model_tester.create_and_check_for_pose_estimation(*config_and_inputs) def test_for_pointmap_estimation(self):