Add native masked MSE loss for Sapiens2ForPoseEstimation#46764
Add native masked MSE loss for Sapiens2ForPoseEstimation#46764Sainava wants to merge 12 commits into
Conversation
guarin
left a comment
There was a problem hiding this comment.
Thank you for the PR! Enabling pose estimation fine-tuning will be great addition!
Left some comments :) Ideally you could also add a small example in sapiens2.md on how to pass the labels and get the loss.
|
Hi @guarin, thanks for the feedback! I've updated the PR to follow the suggested loss-function architecture, moved the pose-estimation test inputs into a dedicated factory method, and added a supervised fine-tuning example to the documentation. One implementation detail: I'd really appreciate any further feedback when you have a chance :) |
guarin
left a comment
There was a problem hiding this comment.
Thanks for the update and adding the docs! This is already looking pretty good, left minor comments on how we could simplify further :)
Let me know if you would be interested to add the preprocessing as well (might be a bit tricky).
| loss = self.loss_function(heatmaps, labels, reduction="none") | ||
|
|
||
| if label_weights is not None: | ||
| loss = (loss * label_weights).mean() | ||
| else: | ||
| loss = loss.mean() |
There was a problem hiding this comment.
Could we pass weights in all cases?
| loss = self.loss_function(heatmaps, labels, reduction="none") | |
| if label_weights is not None: | |
| loss = (loss * label_weights).mean() | |
| else: | |
| loss = loss.mean() | |
| loss = self.loss_function(heatmaps, labels, weight=weights) |
This should work correctly if weights is a Tensor or None and will make it easier for users to customise the loss function. Otherwise they have to overwrite the full forward method to handle weights correctly.
| <hfoption id="Supervised fine-tuning (Pose estimation)"> | ||
|
|
There was a problem hiding this comment.
Maybe rename to this to be closer in form to the Pose estimation and Pose estimation with flip augmentations sections
| <hfoption id="Supervised fine-tuning (Pose estimation)"> | |
| <hfoption id="Pose estimation training"> | |
| batch_size, num_keypoints = 1, 308 | ||
| heatmap_height, heatmap_width = 256, 192 |
There was a problem hiding this comment.
You can assume here that the heatmaps have the same height and width as the preprocessed image. We'll have to add the pose preprocessing to the Sapiens2ImageProcessor which will convert it to the correct format and size. The original Sapiens2 code for this is here: https://github.com/facebookresearch/sapiens2/blob/main/sapiens/pose/src/datasets/transforms/pose_transforms.py
For the loss calculation you might have to interpolate the model outputs to match the label size again.
|
|
||
| labels = torch.randn( |
There was a problem hiding this comment.
| labels = torch.randn( | |
| labels = floats_tensor( |
|
Hi @guarin, thanks for the review! I've pushed the suggested updates (tests and documentation) Regarding the loss function: I tried passing when using OpenMMLab-style visibility weights of shape To preserve the original masking behavior, I'm currently computing the unreduced loss and applying the broadcasted weights explicitly. Would you prefer that I instead move the weighting logic into a small custom loss = self.loss_function(heatmaps, labels, weight=label_weights)while preserving the same behavior? And yes, I'd definitely be interested in working on the preprocessing side as well. My preference would be to get the training-loss support merged first and then tackle the preprocessing changes in a follow-up PR if that's okay . |
|
Hi! I think we can safely assume that |
|
CI Dashboard: View test results in Grafana |
|
Hi @guarin ! I've simplified the forward pass to use weight=label_weights directly in the loss function as suggested. |
vasqu
left a comment
There was a problem hiding this comment.
Overall pretty much ready just some smaller comments from my side 🤗
| """, | ||
| ) | ||
| class Sapiens2ForPoseEstimation(Sapiens2PreTrainedModel): | ||
| _loss_function = staticmethod(torch.nn.functional.mse_loss) |
There was a problem hiding this comment.
please lets add this to
insteadso something along "Sapiens2ForPoseEstimation": torch.nn.functional.mse_loss,
| with torch.no_grad(): | ||
| result_with_loss = model( | ||
| pixel_values, | ||
| labels=labels, | ||
| ) | ||
|
|
||
| self.parent.assertIsNotNone(result_with_loss.loss) | ||
|
|
||
| with torch.no_grad(): | ||
| result_with_weights = model( | ||
| pixel_values, | ||
| labels=labels, | ||
| label_weights=label_weights, | ||
| ) | ||
|
|
||
| self.parent.assertIsNotNone(result_with_weights.loss) | ||
|
|
There was a problem hiding this comment.
Imo we should maybe also check that the backward doesnt result in runtime error so would avoid the no grads here and call a backward
|
[For maintainers] Suggested jobs to run (before merge) run-slow: sapiens2 |
|
Hi @vasqu, thanks for the review :) .I've moved the loss to LOSS_MAPPING and added the .backward() tests as well . |
What does this PR do?
Fixes #46518
As discussed with the maintainers in the linked issue, this PR implements native supervised pose-estimation loss directly in
Sapiens2ForPoseEstimationto unlock fine-tuning capabilities using theTrainerAPI.Specific Changes:
target_weightsparameter to theforwardsignature to handle keypoint visibility masking, following the masking behavior of OpenMMLab'sKeypointMSELosswithout adding external dependencies.fp16/bf16mixed-precision training.forwarddocstring to explicitly document the new parameter.test_modeling_sapiens2.pyto verify the loss computation both with and withouttarget_weights.All local
make fix-repoandmake check-repochecks have passed.Scope of this PR
This implementation follows the masking behavior of OpenMMLab's
KeypointMSELossthrough optionaltarget_weightssupport. It intentionally does not implementskip_empty_channelor the configurableloss_weightparameter from OpenMMLab, as neither is currently exposed through the Sapiens2 Transformers configuration. The goal of this PR is to provide native supervised pose-estimation training support while keeping the initial implementation focused and aligned with existing Transformers conventions.Before submitting
Pull Request checks?
to it if that's the case. Investigate training support for Sapiens2ForPoseEstimation when labels are provided #46518 (comment)
Who can review?
Hi @guarin, following up on the discussion in #46518, I've put together a draft implementation for supervised pose-estimation loss support in Sapiens2ForPoseEstimation. I'd appreciate any feedback when you have a chance.