Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions docs/source/en/model_doc/sapiens2.md
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,40 @@ scores = results[0]["scores"]

</hfoption>

<hfoption id="Pose estimation training">

The example below shows how to compute the supervised Masked Mean Squared Error (MSE) loss with [`Sapiens2ForPoseEstimation`] by passing `labels` and optional `label_weights`. This is useful for fine-tuning using the `Trainer` API.

```python
import torch
from transformers import AutoImageProcessor, AutoModelForPoseEstimation
from transformers.image_utils import load_image

image = load_image("http://images.cocodataset.org/val2017/000000004016.jpg")

image_processor = AutoImageProcessor.from_pretrained("facebook/sapiens2-pose-0.4b")
model = AutoModelForPoseEstimation.from_pretrained("facebook/sapiens2-pose-0.4b", device_map="auto")

# Provide bounding boxes in COCO format (x, y, width, height) for each person
boxes = [[[270.8, 0.6, 294.1, 379.5]]]
inputs = image_processor(image, boxes=boxes, return_tensors="pt").to(model.device)

# Create dummy labels (heatmaps) and visibility weights to simulate ground truth
# 1.0 for visible keypoints, 0.0 for occluded/invisible keypoints
batch_size, num_keypoints = 1, 308
heatmap_height, heatmap_width = 1024, 768
labels = torch.randn(batch_size, num_keypoints, heatmap_height, heatmap_width, device=model.device)
label_weights = torch.ones(batch_size, num_keypoints, 1, 1, device=model.device)

# Forward pass with loss calculation
outputs = model(**inputs, labels=labels, label_weights=label_weights)

print("Loss:", outputs.loss.item())
```

</hfoption>


<hfoption id="Semantic segmentation">

The example below shows how to perform body-part segmentation with [`Sapiens2ForSemanticSegmentation`].
Expand Down
7 changes: 6 additions & 1 deletion src/transformers/models/sapiens2/modeling_sapiens2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,8 @@ def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"):
""",
)
class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel):
_loss_function = staticmethod(torch.nn.functional.mse_loss)

def __init__(self, config: Sapiens2Config):
super().__init__(config)
self.num_labels = config.num_labels
Expand All @@ -1151,6 +1153,7 @@ def forward(
pixel_values: torch.FloatTensor,
flip_pairs: torch.Tensor | None = None,
labels: torch.FloatTensor | None = None,
label_weights: torch.FloatTensor | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> Sapiens2PoseEstimatorOutput:
r"""
Expand All @@ -1161,6 +1164,8 @@ def forward(
original orientation.
labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*):
Heatmap ground truth for computing the loss.
label_weights (`torch.FloatTensor` of shape `(batch_size, num_labels, 1, 1)` or `(batch_size, num_labels, height, width)`, *optional*):
Visibility weights for each keypoint. Must be broadcastable to the shape of `labels`.

Example:

Expand Down Expand Up @@ -1200,7 +1205,7 @@ def forward(

loss = None
if labels is not None:
raise NotImplementedError("Training is not yet supported")
loss = self.loss_function(heatmaps, labels, weight=label_weights)

return Sapiens2PoseEstimatorOutput(
loss=loss,
Expand Down
7 changes: 6 additions & 1 deletion src/transformers/models/sapiens2/modular_sapiens2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,6 +1647,8 @@ def forward(
""",
)
class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel):
_loss_function = staticmethod(torch.nn.functional.mse_loss)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please lets add this to

instead

so something along "Sapiens2ForPoseEstimation": torch.nn.functional.mse_loss,


def __init__(self, config: Sapiens2Config):
super().__init__(config)
self.num_labels = config.num_labels
Expand All @@ -1661,6 +1663,7 @@ def forward(
pixel_values: torch.FloatTensor,
flip_pairs: torch.Tensor | None = None,
labels: torch.FloatTensor | None = None,
label_weights: torch.FloatTensor | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> Sapiens2PoseEstimatorOutput:
r"""
Expand All @@ -1671,6 +1674,8 @@ def forward(
original orientation.
labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*):
Heatmap ground truth for computing the loss.
label_weights (`torch.FloatTensor` of shape `(batch_size, num_labels, 1, 1)` or `(batch_size, num_labels, height, width)`, *optional*):
Visibility weights for each keypoint. Must be broadcastable to the shape of `labels`.

Example:

Expand Down Expand Up @@ -1710,7 +1715,7 @@ def forward(

loss = None
if labels is not None:
raise NotImplementedError("Training is not yet supported")
loss = self.loss_function(heatmaps, labels, weight=label_weights)

return Sapiens2PoseEstimatorOutput(
loss=loss,
Expand Down
36 changes: 34 additions & 2 deletions tests/models/sapiens2/test_modeling_sapiens2.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,19 +183,39 @@ def create_and_check_for_semantic_segmentation(self, config, pixel_values, label
(self.batch_size, config.num_labels, expected_h, expected_h),
)

def create_and_check_for_pose_estimation(self, config, pixel_values, labels):
def create_and_check_for_pose_estimation(self, config, pixel_values, labels, label_weights):
model = Sapiens2ForPoseEstimation(config)
model.to(torch_device)
model.eval()

with torch.no_grad():
result = model(pixel_values)

patch_height = self.image_size // self.patch_size
expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels))

self.parent.assertEqual(
result.heatmaps.shape,
(self.batch_size, config.num_labels, expected_h, expected_h),
)

with torch.no_grad():
result_with_loss = model(
pixel_values,
labels=labels,
)

self.parent.assertIsNotNone(result_with_loss.loss)

with torch.no_grad():
result_with_weights = model(
pixel_values,
labels=labels,
label_weights=label_weights,
)

self.parent.assertIsNotNone(result_with_weights.loss)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imo we should maybe also check that the backward doesnt result in runtime error so would avoid the no grads here and call a backward

def create_and_check_for_normal_estimation(self, config, pixel_values, labels):
model = Sapiens2ForNormalEstimation(config)
model.to(torch_device)
Expand Down Expand Up @@ -278,6 +298,18 @@ def prepare_config_and_inputs_for_common(self):
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict

def prepare_config_and_inputs_for_pose_estimation(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs[0], config_and_inputs[1]

patch_height = self.image_size // self.patch_size
expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels))

labels = floats_tensor([self.batch_size, config.num_labels, expected_h, expected_h])
label_weights = floats_tensor([self.batch_size, config.num_labels, expected_h, expected_h])

return config, pixel_values, labels, label_weights


@require_torch
class Sapiens2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
Expand Down Expand Up @@ -351,7 +383,7 @@ def test_for_semantic_segmentation(self):
self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs)

def test_for_pose_estimation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_semantic_segmentation()
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_pose_estimation()
self.model_tester.create_and_check_for_pose_estimation(*config_and_inputs)

def test_for_pointmap_estimation(self):
Expand Down