CompositeLossMetrics now performs a weighted sum of losses. #1251


Merged
merged 1 commit into main from metric_weighted on Jul 17, 2025

Conversation

ds-hwang
Contributor

Currently, `CompositeLossMetrics` sums the losses without considering their weights (i.e., the number of live targets). To make this a weighted sum, downstream code has been implementing `CompositeLossWeights` to inject the number of live targets into `loss_weights`. This essentially patches surprising logic (a plain, unweighted loss sum) with complex logic (`CompositeLossWeights`) to arrive at straightforward behavior (a weighted sum).

Therefore, we’re changing the default loss aggregation logic to be straightforward from the beginning.

From now on, our standardized loss aggregation logic is:

loss = sum(each_loss_weight * each_loss * num_each_samples) / sum(each_loss_weight * num_each_samples)

Historically, the complex logic was introduced because the weights of the losses returned by child metrics were unknown. But now that child metrics return losses as `WeightedScalar`, we can adopt a simpler, cleaner aggregation logic.
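
For illustration, here is a minimal, self-contained sketch of this aggregation rule. The `WeightedScalar` class and the `aggregate` helper below are stand-ins written for this example, not the actual AXLearn API; the real `WeightedScalar` may differ in structure.

```python
# Minimal sketch of the weighted aggregation; `WeightedScalar` and
# `aggregate` are illustrative stand-ins, not the actual AXLearn API.
from typing import Dict, NamedTuple

class WeightedScalar(NamedTuple):
    mean: float    # each_loss
    weight: float  # num_each_samples (e.g., number of live targets)

def aggregate(losses: Dict[str, WeightedScalar],
              loss_weights: Dict[str, float]) -> WeightedScalar:
    # loss = sum(w * loss * n) / sum(w * n)
    numerator = sum(loss_weights[k] * v.mean * v.weight for k, v in losses.items())
    denominator = sum(loss_weights[k] * v.weight for k, v in losses.items())
    return WeightedScalar(numerator / denominator, denominator)

# A down-weighted auxiliary loss contributes proportionally less:
losses = {"lm": WeightedScalar(2.0, 100.0), "aux": WeightedScalar(8.0, 10.0)}
print(aggregate(losses, {"lm": 1.0, "aux": 0.1}).mean)  # 208 / 101 ≈ 2.06
```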

Note: an alternative formulation could be

loss = sum(each_loss_weight * each_loss * num_each_samples) / sum(num_each_samples)

However, when `num_each_samples` is large and `each_loss_weight` is small, the denominator grows disproportionately large relative to the numerator, diluting the combined loss. So we discard this option.
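
For example, with hypothetical values: an auxiliary loss with `each_loss_weight = 0.01` over 1000 samples alongside a main loss with weight 1.0 over 10 samples. The adopted formula divides by `1.0 * 10 + 0.01 * 1000 = 20`, while the discarded one divides by `10 + 1000 = 1010`, deflating the combined loss by roughly 50x even though the auxiliary loss was meant to contribute very little.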

@ds-hwang ds-hwang requested review from ruomingp, markblee and a team as code owners June 10, 2025 16:22
@ds-hwang
Contributor Author

@markblee Could you take a look? From 1399

@markblee markblee (Contributor) left a comment


(Will approve after the internal review completes.)

@ds-hwang ds-hwang force-pushed the metric_weighted branch 3 times, most recently from 00d1611 to 29f13f7 Compare July 3, 2025 20:40
@ds-hwang
Contributor Author

ds-hwang commented Jul 15, 2025

@markblee could you take a look again?

> (Will approve after the internal review completes.)

All reviewers approved internally at 23540

@ds-hwang ds-hwang requested a review from markblee July 15, 2025 16:56
@ds-hwang ds-hwang added this pull request to the merge queue Jul 16, 2025
Merged via the queue into main with commit 343102a Jul 17, 2025
11 checks passed
@ds-hwang ds-hwang deleted the metric_weighted branch July 17, 2025 00:34