Add native masked MSE loss for Sapiens2ForPoseEstimation #46764
base: main
Changes from 7 commits
181ff3e
49b0033
d83872e
8d8a847
fcfa574
f996ad9
4a99e4c
cdf9695
35ce0bc
5056650
e94d79f
840b2c3
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -274,6 +274,40 @@ scores = results[0]["scores"] | |
|
|
||
| </hfoption> | ||
|
|
||
| <hfoption id="Supervised fine-tuning (Pose estimation)"> | ||
|
|
||
| The example below shows how to compute the supervised Masked Mean Squared Error (MSE) loss with [`Sapiens2ForPoseEstimation`] by passing `labels` and optional `label_weights`. This is useful for fine-tuning using the `Trainer` API. | ||
|
|
||
| ```python | ||
| import torch | ||
| from transformers import AutoImageProcessor, AutoModelForPoseEstimation | ||
| from transformers.image_utils import load_image | ||
|
|
||
| image = load_image("http://images.cocodataset.org/val2017/000000004016.jpg") | ||
|
|
||
| image_processor = AutoImageProcessor.from_pretrained("facebook/sapiens2-pose-0.4b") | ||
| model = AutoModelForPoseEstimation.from_pretrained("facebook/sapiens2-pose-0.4b", device_map="auto") | ||
|
|
||
| # Provide bounding boxes in COCO format (x, y, width, height) for each person | ||
| boxes = [[[270.8, 0.6, 294.1, 379.5]]] | ||
| inputs = image_processor(image, boxes=boxes, return_tensors="pt").to(model.device) | ||
|
|
||
| # Create dummy labels (heatmaps) and visibility weights to simulate ground truth | ||
| # 1.0 for visible keypoints, 0.0 for occluded/invisible keypoints | ||
| batch_size, num_keypoints = 1, 308 | ||
| heatmap_height, heatmap_width = 256, 192 | ||
|
Member
You can assume here that the heatmaps have the same height and width as the preprocessed image. We'll have to add the pose preprocessing to the [...] For the loss calculation, you might have to interpolate the model outputs to match the label size again.
||
| labels = torch.randn(batch_size, num_keypoints, heatmap_height, heatmap_width, device=model.device) | ||
| label_weights = torch.ones(batch_size, num_keypoints, 1, 1, device=model.device) | ||
|
|
||
| # Forward pass with loss calculation | ||
| outputs = model(**inputs, labels=labels, label_weights=label_weights) | ||
|
|
||
| print("Loss:", outputs.loss.item()) | ||
| ``` | ||
|
|
||
| </hfoption> | ||
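The review note in this hunk mentions that the model outputs may need to be interpolated to the label size before the loss is computed. A minimal sketch of that idea follows, assuming bilinear resizing with `torch.nn.functional.interpolate`; the tensor shapes and variable names are hypothetical and are not taken from the PR.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: predictions are at a lower resolution than the labels.
batch_size, num_keypoints = 1, 4
pred_heatmaps = torch.randn(batch_size, num_keypoints, 64, 48)
labels = torch.randn(batch_size, num_keypoints, 256, 192)

# Resize the predicted heatmaps to the spatial size (height, width) of the labels.
resized = F.interpolate(
    pred_heatmaps, size=labels.shape[-2:], mode="bilinear", align_corners=False
)

# With matching shapes, the element-wise MSE can be computed directly.
loss = F.mse_loss(resized, labels)
print(resized.shape, loss.item())
```

Whether the model or the preprocessing should be responsible for this resizing is a design question the review leaves open; the sketch only shows the tensor operation.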
|
|
||
|
|
||
| <hfoption id="Semantic segmentation"> | ||
|
|
||
| The example below shows how to perform body-part segmentation with [`Sapiens2ForSemanticSegmentation`]. | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1137,6 +1137,8 @@ def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"): | |||||||||||||||
| """, | ||||||||||||||||
| ) | ||||||||||||||||
| class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel): | ||||||||||||||||
| _loss_function = staticmethod(torch.nn.functional.mse_loss) | ||||||||||||||||
|
|
||||||||||||||||
| def __init__(self, config: Sapiens2Config): | ||||||||||||||||
| super().__init__(config) | ||||||||||||||||
| self.num_labels = config.num_labels | ||||||||||||||||
|
|
@@ -1151,6 +1153,7 @@ def forward( | |||||||||||||||
| pixel_values: torch.FloatTensor, | ||||||||||||||||
| flip_pairs: torch.Tensor | None = None, | ||||||||||||||||
| labels: torch.FloatTensor | None = None, | ||||||||||||||||
| label_weights: torch.FloatTensor | None = None, | ||||||||||||||||
| **kwargs: Unpack[TransformersKwargs], | ||||||||||||||||
| ) -> Sapiens2PoseEstimatorOutput: | ||||||||||||||||
| r""" | ||||||||||||||||
|
|
@@ -1161,6 +1164,8 @@ def forward( | |||||||||||||||
| original orientation. | ||||||||||||||||
| labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*): | ||||||||||||||||
| Heatmap ground truth for computing the loss. | ||||||||||||||||
| label_weights (`torch.FloatTensor` of shape `(batch_size, num_labels, 1, 1)` or `(batch_size, num_labels, height, width)`, *optional*): | ||||||||||||||||
| Visibility weights for each keypoint. Must be broadcastable to the shape of `labels`. | ||||||||||||||||
|
|
||||||||||||||||
| Example: | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -1200,7 +1205,12 @@ def forward( | |||||||||||||||
|
|
||||||||||||||||
| loss = None | ||||||||||||||||
| if labels is not None: | ||||||||||||||||
| - raise NotImplementedError("Training is not yet supported") | ||||||||||||||||
| + loss = self.loss_function(heatmaps, labels, reduction="none") | ||||||||||||||||
|
|
||||||||||||||||
| + if label_weights is not None: | ||||||||||||||||
| + loss = (loss * label_weights).mean() | ||||||||||||||||
| + else: | ||||||||||||||||
| + loss = loss.mean() | ||||||||||||||||
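To make the behavior of the weighted branch concrete, here is a small worked example with hand-checkable numbers. The first result mirrors the `(loss * label_weights).mean()` expression from this hunk. The second, which divides by the number of unmasked elements, is an alternative shown only for comparison and is not what the PR implements. All tensor values are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Two keypoints, each with a 1x2 heatmap; labels are all zeros.
heatmaps = torch.tensor([[[[1.0, 2.0]], [[3.0, 4.0]]]])  # shape (1, 2, 1, 2)
labels = torch.zeros_like(heatmaps)

# The second keypoint is masked out; the weights broadcast over height and width.
label_weights = torch.tensor([1.0, 0.0]).view(1, 2, 1, 1)

# Squared errors per element: 1, 4 for keypoint 0 and 9, 16 for keypoint 1.
per_element = F.mse_loss(heatmaps, labels, reduction="none")

# Expression from the diff: masked elements still count in the denominator.
masked_mean = (per_element * label_weights).mean()  # (1 + 4 + 0 + 0) / 4

# Alternative for comparison only: average over unmasked elements.
weight_sum = label_weights.expand_as(per_element).sum().clamp(min=1.0)
normalized = (per_element * label_weights).sum() / weight_sum  # (1 + 4) / 2

print(masked_mean.item(), normalized.item())  # 1.25 2.5
```

The two values differ whenever some keypoints are masked, so the choice affects the loss scale for images with many occluded keypoints.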
|
Member
Could we pass
Suggested change
This should work correctly if |
||||||||||||||||
|
|
||||||||||||||||
| return Sapiens2PoseEstimatorOutput( | ||||||||||||||||
| loss=loss, | ||||||||||||||||
|
|
||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -1647,6 +1647,8 @@ def forward( | |||
| """, | ||||
| ) | ||||
| class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel): | ||||
| _loss_function = staticmethod(torch.nn.functional.mse_loss) | ||||
|
Collaborator
Please let's add this to
so something along |
||||
|
|
||||
| def __init__(self, config: Sapiens2Config): | ||||
| super().__init__(config) | ||||
| self.num_labels = config.num_labels | ||||
|
|
@@ -1661,6 +1663,7 @@ def forward( | |||
| pixel_values: torch.FloatTensor, | ||||
| flip_pairs: torch.Tensor | None = None, | ||||
| labels: torch.FloatTensor | None = None, | ||||
| label_weights: torch.FloatTensor | None = None, | ||||
| **kwargs: Unpack[TransformersKwargs], | ||||
| ) -> Sapiens2PoseEstimatorOutput: | ||||
| r""" | ||||
|
|
@@ -1671,6 +1674,8 @@ def forward( | |||
| original orientation. | ||||
| labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*): | ||||
| Heatmap ground truth for computing the loss. | ||||
| label_weights (`torch.FloatTensor` of shape `(batch_size, num_labels, 1, 1)` or `(batch_size, num_labels, height, width)`, *optional*): | ||||
| Visibility weights for each keypoint. Must be broadcastable to the shape of `labels`. | ||||
|
|
||||
| Example: | ||||
|
|
||||
|
|
@@ -1710,7 +1715,12 @@ def forward( | |||
|
|
||||
| loss = None | ||||
| if labels is not None: | ||||
| - raise NotImplementedError("Training is not yet supported") | ||||
| + loss = self.loss_function(heatmaps, labels, reduction="none") | ||||
|
|
||||
| + if label_weights is not None: | ||||
| + loss = (loss * label_weights).mean() | ||||
| + else: | ||||
| + loss = loss.mean() | ||||
|
|
||||
| return Sapiens2PoseEstimatorOutput( | ||||
| loss=loss, | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -183,19 +183,39 @@ def create_and_check_for_semantic_segmentation(self, config, pixel_values, label | |||||||||
| (self.batch_size, config.num_labels, expected_h, expected_h), | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| - def create_and_check_for_pose_estimation(self, config, pixel_values, labels): | ||||||||||
| + def create_and_check_for_pose_estimation(self, config, pixel_values, labels, label_weights): | ||||||||||
| model = Sapiens2ForPoseEstimation(config) | ||||||||||
| model.to(torch_device) | ||||||||||
| model.eval() | ||||||||||
|
|
||||||||||
| with torch.no_grad(): | ||||||||||
| result = model(pixel_values) | ||||||||||
|
|
||||||||||
| patch_height = self.image_size // self.patch_size | ||||||||||
| expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels)) | ||||||||||
|
|
||||||||||
| self.parent.assertEqual( | ||||||||||
| result.heatmaps.shape, | ||||||||||
| (self.batch_size, config.num_labels, expected_h, expected_h), | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| with torch.no_grad(): | ||||||||||
| result_with_loss = model( | ||||||||||
| pixel_values, | ||||||||||
| labels=labels, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| self.parent.assertIsNotNone(result_with_loss.loss) | ||||||||||
|
|
||||||||||
| with torch.no_grad(): | ||||||||||
| result_with_weights = model( | ||||||||||
| pixel_values, | ||||||||||
| labels=labels, | ||||||||||
| label_weights=label_weights, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| self.parent.assertIsNotNone(result_with_weights.loss) | ||||||||||
|
|
||||||||||
|
Collaborator
IMO we should also check that the backward pass doesn't result in a runtime error, so I would avoid the `no_grad` blocks here and call a backward.
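A minimal sketch of the kind of check suggested here: run the forward pass with gradients enabled, call `backward()` on the weighted loss, and confirm that every parameter received a finite gradient. The single `Conv2d` layer is a hypothetical stand-in for the pose model, chosen so the snippet runs without any checkpoint; it is not the Sapiens2 architecture.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for the pose estimation model (not Sapiens2).
model = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1)
model.train()

pixel_values = torch.randn(2, 3, 8, 8)
labels = torch.randn(2, 4, 8, 8)
label_weights = torch.ones(2, 4, 1, 1)

# Forward pass without torch.no_grad() so the autograd graph is built.
heatmaps = model(pixel_values)
loss = (F.mse_loss(heatmaps, labels, reduction="none") * label_weights).mean()

# The backward pass should run without raising a runtime error.
loss.backward()

# Every parameter should now hold a finite gradient.
grads_ok = all(
    p.grad is not None and bool(torch.isfinite(p.grad).all())
    for p in model.parameters()
)
print(grads_ok)
```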
||||||||||
| def create_and_check_for_normal_estimation(self, config, pixel_values, labels): | ||||||||||
| model = Sapiens2ForNormalEstimation(config) | ||||||||||
| model.to(torch_device) | ||||||||||
|
|
@@ -278,6 +298,28 @@ def prepare_config_and_inputs_for_common(self): | |||||||||
| inputs_dict = {"pixel_values": pixel_values} | ||||||||||
| return config, inputs_dict | ||||||||||
|
|
||||||||||
| def prepare_config_and_inputs_for_pose_estimation(self): | ||||||||||
| config_and_inputs = self.prepare_config_and_inputs() | ||||||||||
| config, pixel_values = config_and_inputs[0], config_and_inputs[1] | ||||||||||
|
|
||||||||||
| patch_height = self.image_size // self.patch_size | ||||||||||
| expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels)) | ||||||||||
|
|
||||||||||
| labels = torch.randn( | ||||||||||
|
Member
Suggested change
|
||||||||||
| self.batch_size, | ||||||||||
| config.num_labels, | ||||||||||
| expected_h, | ||||||||||
| expected_h, | ||||||||||
| device=pixel_values.device, | ||||||||||
| dtype=pixel_values.dtype, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| label_weights = torch.ones( | ||||||||||
| self.batch_size, config.num_labels, 1, 1, device=pixel_values.device, dtype=pixel_values.dtype | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| return config, pixel_values, labels, label_weights | ||||||||||
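The label shape in this helper follows from the `expected_h` formula: the number of patches per side, multiplied by two for every entry in `upsample_out_channels`. The sketch below restates that arithmetic as a standalone function, assuming (as the formula implies) that each upsampling stage doubles the resolution. The function name and the example values are hypothetical.

```python
def expected_heatmap_size(image_size: int, patch_size: int, num_upsample_stages: int) -> int:
    """Side length of the output heatmap, assuming each stage doubles the resolution."""
    patches_per_side = image_size // patch_size
    return patches_per_side * (2 ** num_upsample_stages)


# Hypothetical tester values: 32x32 input, 16x16 patches, two upsampling stages.
size = expected_heatmap_size(image_size=32, patch_size=16, num_upsample_stages=2)
print(size)  # 8
```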
|
|
||||||||||
|
|
||||||||||
| @require_torch | ||||||||||
| class Sapiens2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): | ||||||||||
|
|
@@ -351,7 +393,7 @@ def test_for_semantic_segmentation(self): | |||||||||
| self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs) | ||||||||||
|
|
||||||||||
| def test_for_pose_estimation(self): | ||||||||||
| - config_and_inputs = self.model_tester.prepare_config_and_inputs_for_semantic_segmentation() | ||||||||||
| + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_pose_estimation() | ||||||||||
| self.model_tester.create_and_check_for_pose_estimation(*config_and_inputs) | ||||||||||
|
|
||||||||||
| def test_for_pointmap_estimation(self): | ||||||||||
|
|
||||||||||
Maybe rename this to be closer in form to the `Pose estimation` and `Pose estimation with flip augmentations` sections.