Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions docs/source/en/model_doc/sapiens2.md
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,40 @@ scores = results[0]["scores"]

</hfoption>

<hfoption id="Supervised fine-tuning (Pose estimation)">

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rename to this to be closer in form to the Pose estimation and Pose estimation with flip augmentations sections

Suggested change
<hfoption id="Supervised fine-tuning (Pose estimation)">
<hfoption id="Pose estimation training">

The example below shows how to compute the supervised Masked Mean Squared Error (MSE) loss with [`Sapiens2ForPoseEstimation`] by passing `labels` and optional `label_weights`. This is useful for fine-tuning using the `Trainer` API.

```python
import torch
from transformers import AutoImageProcessor, AutoModelForPoseEstimation
from transformers.image_utils import load_image

image = load_image("http://images.cocodataset.org/val2017/000000004016.jpg")

image_processor = AutoImageProcessor.from_pretrained("facebook/sapiens2-pose-0.4b")
model = AutoModelForPoseEstimation.from_pretrained("facebook/sapiens2-pose-0.4b", device_map="auto")

# Provide bounding boxes in COCO format (x, y, width, height) for each person
boxes = [[[270.8, 0.6, 294.1, 379.5]]]
inputs = image_processor(image, boxes=boxes, return_tensors="pt").to(model.device)

# Create dummy labels (heatmaps) and visibility weights to simulate ground truth
# 1.0 for visible keypoints, 0.0 for occluded/invisible keypoints
batch_size, num_keypoints = 1, 308
heatmap_height, heatmap_width = 256, 192

@guarin guarin Jun 22, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can assume here that the heatmaps have the same height and width as the preprocessed image. We'll have to add the pose preprocessing to the Sapiens2ImageProcessor which will convert it to the correct format and size. The original Sapiens2 code for this is here: https://github.com/facebookresearch/sapiens2/blob/main/sapiens/pose/src/datasets/transforms/pose_transforms.py

For the loss calculation you might have to interpolate the model outputs to match the label size again.

labels = torch.randn(batch_size, num_keypoints, heatmap_height, heatmap_width, device=model.device)
label_weights = torch.ones(batch_size, num_keypoints, 1, 1, device=model.device)

# Forward pass with loss calculation
outputs = model(**inputs, labels=labels, label_weights=label_weights)

print("Loss:", outputs.loss.item())
```

</hfoption>


<hfoption id="Semantic segmentation">

The example below shows how to perform body-part segmentation with [`Sapiens2ForSemanticSegmentation`].
Expand Down
12 changes: 11 additions & 1 deletion src/transformers/models/sapiens2/modeling_sapiens2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,8 @@ def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"):
""",
)
class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel):
_loss_function = staticmethod(torch.nn.functional.mse_loss)

def __init__(self, config: Sapiens2Config):
super().__init__(config)
self.num_labels = config.num_labels
Expand All @@ -1151,6 +1153,7 @@ def forward(
pixel_values: torch.FloatTensor,
flip_pairs: torch.Tensor | None = None,
labels: torch.FloatTensor | None = None,
label_weights: torch.FloatTensor | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> Sapiens2PoseEstimatorOutput:
r"""
Expand All @@ -1161,6 +1164,8 @@ def forward(
original orientation.
labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*):
Heatmap ground truth for computing the loss.
label_weights (`torch.FloatTensor` of shape `(batch_size, num_labels, 1, 1)` or `(batch_size, num_labels, height, width)`, *optional*):
Visibility weights for each keypoint. Must be broadcastable to the shape of `labels`.

Example:

Expand Down Expand Up @@ -1200,7 +1205,12 @@ def forward(

loss = None
if labels is not None:
raise NotImplementedError("Training is not yet supported")
loss = self.loss_function(heatmaps, labels, reduction="none")

if label_weights is not None:
loss = (loss * label_weights).mean()
else:
loss = loss.mean()

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we pass weights in all cases?

Suggested change
loss = self.loss_function(heatmaps, labels, reduction="none")
if label_weights is not None:
loss = (loss * label_weights).mean()
else:
loss = loss.mean()
loss = self.loss_function(heatmaps, labels, weight=weights)

This should work correctly if weights is a Tensor or None and will make it easier for users to customise the loss function. Otherwise they have to overwrite the full forward method to handle weights correctly.


return Sapiens2PoseEstimatorOutput(
loss=loss,
Expand Down
12 changes: 11 additions & 1 deletion src/transformers/models/sapiens2/modular_sapiens2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,6 +1647,8 @@ def forward(
""",
)
class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel):
_loss_function = staticmethod(torch.nn.functional.mse_loss)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please lets add this to

instead

so something along "Sapiens2ForPoseEstimation": torch.nn.functional.mse_loss,


def __init__(self, config: Sapiens2Config):
super().__init__(config)
self.num_labels = config.num_labels
Expand All @@ -1661,6 +1663,7 @@ def forward(
pixel_values: torch.FloatTensor,
flip_pairs: torch.Tensor | None = None,
labels: torch.FloatTensor | None = None,
label_weights: torch.FloatTensor | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> Sapiens2PoseEstimatorOutput:
r"""
Expand All @@ -1671,6 +1674,8 @@ def forward(
original orientation.
labels (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`, *optional*):
Heatmap ground truth for computing the loss.
label_weights (`torch.FloatTensor` of shape `(batch_size, num_labels, 1, 1)` or `(batch_size, num_labels, height, width)`, *optional*):
Visibility weights for each keypoint. Must be broadcastable to the shape of `labels`.

Example:

Expand Down Expand Up @@ -1710,7 +1715,12 @@ def forward(

loss = None
if labels is not None:
raise NotImplementedError("Training is not yet supported")
loss = self.loss_function(heatmaps, labels, reduction="none")

if label_weights is not None:
loss = (loss * label_weights).mean()
else:
loss = loss.mean()

return Sapiens2PoseEstimatorOutput(
loss=loss,
Expand Down
46 changes: 44 additions & 2 deletions tests/models/sapiens2/test_modeling_sapiens2.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,19 +183,39 @@ def create_and_check_for_semantic_segmentation(self, config, pixel_values, label
(self.batch_size, config.num_labels, expected_h, expected_h),
)

def create_and_check_for_pose_estimation(self, config, pixel_values, labels):
def create_and_check_for_pose_estimation(self, config, pixel_values, labels, label_weights):
model = Sapiens2ForPoseEstimation(config)
model.to(torch_device)
model.eval()

with torch.no_grad():
result = model(pixel_values)

patch_height = self.image_size // self.patch_size
expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels))

self.parent.assertEqual(
result.heatmaps.shape,
(self.batch_size, config.num_labels, expected_h, expected_h),
)

with torch.no_grad():
result_with_loss = model(
pixel_values,
labels=labels,
)

self.parent.assertIsNotNone(result_with_loss.loss)

with torch.no_grad():
result_with_weights = model(
pixel_values,
labels=labels,
label_weights=label_weights,
)

self.parent.assertIsNotNone(result_with_weights.loss)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imo we should maybe also check that the backward doesnt result in runtime error so would avoid the no grads here and call a backward

def create_and_check_for_normal_estimation(self, config, pixel_values, labels):
model = Sapiens2ForNormalEstimation(config)
model.to(torch_device)
Expand Down Expand Up @@ -278,6 +298,28 @@ def prepare_config_and_inputs_for_common(self):
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict

def prepare_config_and_inputs_for_pose_estimation(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs[0], config_and_inputs[1]

patch_height = self.image_size // self.patch_size
expected_h = patch_height * (2 ** len(config.head_config.upsample_out_channels))

labels = torch.randn(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
labels = torch.randn(
labels = floats_tensor(

self.batch_size,
config.num_labels,
expected_h,
expected_h,
device=pixel_values.device,
dtype=pixel_values.dtype,
)

label_weights = torch.ones(
self.batch_size, config.num_labels, 1, 1, device=pixel_values.device, dtype=pixel_values.dtype
)

return config, pixel_values, labels, label_weights


@require_torch
class Sapiens2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
Expand Down Expand Up @@ -351,7 +393,7 @@ def test_for_semantic_segmentation(self):
self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs)

def test_for_pose_estimation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_semantic_segmentation()
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_pose_estimation()
self.model_tester.create_and_check_for_pose_estimation(*config_and_inputs)

def test_for_pointmap_estimation(self):
Expand Down