Add native masked MSE loss for Sapiens2ForPoseEstimation #46764

Open
Sainava wants to merge 12 commits into huggingface:main from Sainava:feat/46518-sapiens2-pose-loss

Conversation

Sainava commented Jun 19, 2026

What does this PR do?

Fixes #46518

As discussed with the maintainers in the linked issue, this PR implements native supervised pose-estimation loss directly in Sapiens2ForPoseEstimation to unlock fine-tuning capabilities using the Trainer API.

Specific Changes:

  • Added an optional target_weights parameter to the forward signature to handle keypoint visibility masking, following the masking behavior of OpenMMLab's KeypointMSELoss without adding external dependencies.
  • Implemented the masked MSE loss calculation using pure PyTorch, explicitly casting masks to the heatmap dtype to ensure safe fp16/bf16 mixed-precision training.
  • Added strict shape and dimensionality validation for supervision targets to prevent silent broadcasting errors.
  • Updated the forward docstring to explicitly document the new parameter.
  • Updated test_modeling_sapiens2.py to verify the loss computation both with and without target_weights.
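For illustration, the masking and validation described in the bullets above could look roughly like the sketch below. The helper name `masked_mse_loss` and the accepted weight shapes are assumptions made for this example, not the code in the PR. The reduction is a plain mean over all heatmap elements after multiplying by the weights.

```python
import torch
import torch.nn.functional as F


def masked_mse_loss(heatmaps, labels, target_weights=None):
    """Hypothetical helper: MSE over heatmaps with optional per-keypoint masking."""
    if heatmaps.dim() != 4:
        raise ValueError(f"expected 4D heatmaps, got {heatmaps.dim()}D")
    if labels.shape != heatmaps.shape:
        raise ValueError(
            f"labels shape {tuple(labels.shape)} != heatmaps shape {tuple(heatmaps.shape)}"
        )
    loss = F.mse_loss(heatmaps, labels, reduction="none")
    if target_weights is not None:
        if target_weights.dim() == 2:
            # (batch, keypoints) -> (batch, keypoints, 1, 1) so it broadcasts over H and W
            target_weights = target_weights[:, :, None, None]
        if target_weights.dim() != 4 or target_weights.shape[:2] != heatmaps.shape[:2]:
            raise ValueError(
                f"target_weights shape {tuple(target_weights.shape)} does not match heatmaps"
            )
        # Cast the mask to the heatmap dtype so half-precision losses are not upcast
        loss = loss * target_weights.to(heatmaps.dtype)
    return loss.mean()


# Toy data: every pixel has squared error 1, and 3 of 6 keypoints are visible
heatmaps = torch.zeros(2, 3, 4, 4)
labels = torch.ones(2, 3, 4, 4)
visibility = torch.tensor([[1, 1, 0], [1, 0, 0]])
loss = masked_mse_loss(heatmaps, labels, visibility)
```

With these toy inputs half of the keypoints are masked out, so the mean drops from 1.0 to 0.5. Note that masked elements still count in the denominator under this reduction.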

All local make fix-repo and make check-repo checks have passed.

Scope of this PR

This implementation follows the masking behavior of OpenMMLab's KeypointMSELoss through optional target_weights support. It intentionally does not implement skip_empty_channel or the configurable loss_weight parameter from OpenMMLab, as neither is currently exposed through the Sapiens2 Transformers configuration. The goal of this PR is to provide native supervised pose-estimation training support while keeping the initial implementation focused and aligned with existing Transformers conventions.

  • I confirm that this is not a pure code agent PR.

Before submitting

Who can review?

Hi @guarin, following up on the discussion in #46518, I've put together a draft implementation for supervised pose-estimation loss support in Sapiens2ForPoseEstimation. I'd appreciate any feedback when you have a chance.

guarin left a comment

Member

Thank you for the PR! Enabling pose estimation fine-tuning will be a great addition!

Left some comments :) Ideally you could also add a small example in sapiens2.md on how to pass the labels and get the loss.

Comment thread src/transformers/models/sapiens2/modeling_sapiens2.py Outdated
Comment thread src/transformers/models/sapiens2/modeling_sapiens2.py Outdated
Comment thread tests/models/sapiens2/test_modeling_sapiens2.py Outdated
Sainava commented Jun 19, 2026

Author

Hi @guarin, thanks for the feedback!

I've updated the PR to follow the suggested loss-function architecture, moved the pose-estimation test inputs into a dedicated factory method, and added a supervised fine-tuning example to the documentation.

One implementation detail: _loss_function is wrapped with staticmethod(...) so it works correctly with the self.loss_function property and avoids Python binding the model instance as the first argument.
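The binding behavior mentioned here is plain Python and can be reproduced without the model. The sketch below uses a stand-in `mse` function and two hypothetical classes; it is not the PR's code.

```python
def mse(prediction, target):
    """Stand-in for a two-argument loss function."""
    return (prediction - target) ** 2


class Unwrapped:
    # A plain function stored on a class becomes a bound method on instances,
    # so the instance is passed as an extra first argument.
    _loss_function = mse


class Wrapped:
    # staticmethod(...) disables that binding.
    _loss_function = staticmethod(mse)


wrapped_result = Wrapped()._loss_function(3.0, 1.0)

try:
    Unwrapped()._loss_function(3.0, 1.0)
    binding_error = None
except TypeError as error:
    # mse() received three positional arguments instead of two
    binding_error = error
```

Here `wrapped_result` is `4.0`, while the unwrapped call fails with a `TypeError` because the instance is injected as the first argument.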

I'd really appreciate any further feedback when you have a chance :)

guarin left a comment

Member

Thanks for the update and adding the docs! This is already looking pretty good, left minor comments on how we could simplify further :)

Let me know if you would be interested in adding the preprocessing as well (might be a bit tricky).

Comment on lines +1208 to +1213
loss = self.loss_function(heatmaps, labels, reduction="none")

if label_weights is not None:
    loss = (loss * label_weights).mean()
else:
    loss = loss.mean()

Member

Could we pass weights in all cases?

Suggested change
- loss = self.loss_function(heatmaps, labels, reduction="none")
- if label_weights is not None:
-     loss = (loss * label_weights).mean()
- else:
-     loss = loss.mean()
+ loss = self.loss_function(heatmaps, labels, weight=weights)

This should work correctly if weights is a Tensor or None and will make it easier for users to customise the loss function. Otherwise they have to override the full forward method to handle weights correctly.
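A rough sketch of why a single call site helps customisation: any function with the same `(prediction, target, weight=None)` signature can be swapped in. The names `weighted_mse`, `weighted_smooth_l1`, and `compute_loss` are hypothetical and only stand in for the model's loss hook.

```python
import torch
import torch.nn.functional as F


def weighted_mse(prediction, target, weight=None):
    """Element-wise MSE, optionally weighted, then averaged."""
    loss = F.mse_loss(prediction, target, reduction="none")
    if weight is not None:
        loss = loss * weight
    return loss.mean()


def weighted_smooth_l1(prediction, target, weight=None):
    """A user-supplied alternative with the same signature."""
    loss = F.smooth_l1_loss(prediction, target, reduction="none")
    if weight is not None:
        loss = loss * weight
    return loss.mean()


def compute_loss(loss_function, heatmaps, labels, weights=None):
    # Single call site: no branching on whether weights were provided
    return loss_function(heatmaps, labels, weight=weights)


heatmaps = torch.zeros(1, 2, 4, 4)
labels = torch.ones(1, 2, 4, 4)
weights = torch.ones_like(labels)
weights[:, 1] = 0.0  # mask out the second keypoint

unweighted = compute_loss(weighted_mse, heatmaps, labels)
weighted = compute_loss(weighted_mse, heatmaps, labels, weights)
custom = compute_loss(weighted_smooth_l1, heatmaps, labels, weights)
```

The caller never inspects `weights`; handling `None` is the loss function's responsibility, so replacing the loss does not require touching the forward pass.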

Comment thread docs/source/en/model_doc/sapiens2.md Outdated
Comment on lines +277 to +278
<hfoption id="Supervised fine-tuning (Pose estimation)">

Member

Maybe rename this to be closer in form to the Pose estimation and Pose estimation with flip augmentations sections.

Suggested change
- <hfoption id="Supervised fine-tuning (Pose estimation)">
+ <hfoption id="Pose estimation training">

Comment thread docs/source/en/model_doc/sapiens2.md Outdated
Comment on lines +297 to +298
batch_size, num_keypoints = 1, 308
heatmap_height, heatmap_width = 256, 192

guarin Jun 22, 2026

Member

You can assume here that the heatmaps have the same height and width as the preprocessed image. We'll have to add the pose preprocessing to the Sapiens2ImageProcessor which will convert it to the correct format and size. The original Sapiens2 code for this is here: https://github.com/facebookresearch/sapiens2/blob/main/sapiens/pose/src/datasets/transforms/pose_transforms.py

For the loss calculation you might have to interpolate the model outputs to match the label size again.
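As a hedged illustration of that last point, model outputs at a lower resolution can be resized to the label resolution before the loss is taken. The tensor sizes and the choice of bilinear interpolation below are assumptions for this sketch, not something the thread specifies.

```python
import torch
import torch.nn.functional as F

# Toy sizes for illustration only: labels at the (assumed) preprocessed image
# resolution, model heatmaps at a lower resolution.
labels = torch.rand(1, 4, 32, 24)
heatmaps = torch.randn(1, 4, 8, 6)

if heatmaps.shape[-2:] != labels.shape[-2:]:
    # Resize the spatial dimensions (H, W) of the predictions to match the labels
    heatmaps = F.interpolate(
        heatmaps, size=labels.shape[-2:], mode="bilinear", align_corners=False
    )

loss = F.mse_loss(heatmaps, labels)
```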

Comment on lines +307 to +308

labels = torch.randn(

Member

Suggested change
- labels = torch.randn(
+ labels = floats_tensor(

Sainava commented Jun 22, 2026

Author

Hi @guarin, thanks for the review! I've pushed the suggested updates (tests and documentation).

Regarding the loss function: I tried passing weight=label_weights directly to self.loss_function, but since the default implementation resolves to torch.nn.functional.mse_loss, this raises:

ValueError: Weights and input must have the same size

when using OpenMMLab-style visibility weights of shape [batch_size, num_keypoints, 1, 1] against heatmaps of shape [batch_size, num_keypoints, height, width].

To preserve the original masking behavior, I'm currently computing the unreduced loss and applying the broadcasted weights explicitly.

Would you prefer that I instead move the weighting logic into a small custom _loss_function implementation so the forward pass can remain:

loss = self.loss_function(heatmaps, labels, weight=label_weights)

while preserving the same behavior?

And yes, I'd definitely be interested in working on the preprocessing side as well. My preference would be to get the training-loss support merged first and then tackle the preprocessing changes in a follow-up PR, if that's okay.

guarin commented Jun 22, 2026

Member

Hi! I think we can safely assume that label_weights has the same shape as labels in the forward function. The pre-processing can make sure that this is the case and expand any (batch_size, num_keypoints, 1, 1) tensors to match labels.shape so we don't have to worry about that in forward.
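A minimal sketch of the expansion described above, assuming per-keypoint visibility flags of shape (batch_size, num_keypoints) as the starting point. Where this would live in the preprocessing is not decided in this thread, and the variable names are illustrative.

```python
import torch

labels = torch.rand(2, 3, 8, 6)  # (batch_size, num_keypoints, height, width)
visibility = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 1.0]])

# (batch, keypoints) -> (batch, keypoints, 1, 1) -> labels.shape
compact_weights = visibility[:, :, None, None]
label_weights = compact_weights.expand_as(labels).to(labels.dtype)
```

`expand_as` returns a view without copying data, so this costs no extra memory until something calls `.contiguous()` or writes to the tensor.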

@github-actions

Contributor

CI Dashboard: View test results in Grafana

Sainava commented Jun 22, 2026

Author

Hi @guarin! I've simplified the forward pass to use weight=label_weights directly in the loss function as suggested.
Thanks again for the guidance :)

vasqu left a comment

Collaborator

Overall pretty much ready, just some smaller comments from my side 🤗

""",
)
class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel):
    _loss_function = staticmethod(torch.nn.functional.mse_loss)

Collaborator

Please let's add this to

instead

so something along the lines of "Sapiens2ForPoseEstimation": torch.nn.functional.mse_loss,
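To illustrate the general idea of a registry keyed by class name, here is a self-contained sketch. `EXAMPLE_LOSS_MAPPING`, `ExampleForPoseEstimation`, and the lookup by `type(self).__name__` are hypothetical and do not show how Transformers resolves losses internally.

```python
def mse(prediction, target):
    """Stand-in for a two-argument loss function."""
    return (prediction - target) ** 2


# Hypothetical registry: model class name -> loss function
EXAMPLE_LOSS_MAPPING = {"ExampleForPoseEstimation": mse}


class ExampleForPoseEstimation:
    @property
    def loss_function(self):
        # Values stored in a dict are returned as plain functions,
        # so no staticmethod wrapper is needed to avoid method binding.
        return EXAMPLE_LOSS_MAPPING[type(self).__name__]


result = ExampleForPoseEstimation().loss_function(3.0, 1.0)
```

Because the function is looked up in a dictionary rather than on the class, the instance is never passed as an implicit first argument.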

Comment on lines +202 to +218
with torch.no_grad():
    result_with_loss = model(
        pixel_values,
        labels=labels,
    )

self.parent.assertIsNotNone(result_with_loss.loss)

with torch.no_grad():
    result_with_weights = model(
        pixel_values,
        labels=labels,
        label_weights=label_weights,
    )

self.parent.assertIsNotNone(result_with_weights.loss)

Collaborator

IMO we should maybe also check that the backward pass doesn't result in a runtime error, so I would avoid the no_grad blocks here and call backward.
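A sketch of what such a check could look like, using a small convolution as a stand-in for the model so the example runs on its own. The real test would call the Sapiens2 model with `labels` and `label_weights` instead.

```python
import torch
import torch.nn.functional as F

# Stand-in for the pose model: 3 input channels -> 4 "keypoint" heatmaps
model = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1)
pixel_values = torch.rand(2, 3, 8, 6)
labels = torch.rand(2, 4, 8, 6)
label_weights = torch.ones_like(labels)

# No torch.no_grad() here, so the autograd graph is kept
heatmaps = model(pixel_values)
loss = (F.mse_loss(heatmaps, labels, reduction="none") * label_weights).mean()

# Raises a RuntimeError if the loss is not connected to the parameters
loss.backward()
grads = [parameter.grad for parameter in model.parameters()]
```

After `backward()`, each parameter's `.grad` should be populated, which is a cheap signal that the loss is differentiable end to end.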

@github-actions

Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: sapiens2

Sainava commented Jun 23, 2026

Author

Hi @vasqu, thanks for the review :) I've moved the loss to LOSS_MAPPING and added the .backward() tests as well.

Development

Successfully merging this pull request may close these issues.

Investigate training support for Sapiens2ForPoseEstimation when labels are provided

3 participants