-
Notifications
You must be signed in to change notification settings - Fork 528
[MRG] Linear Circular OT #736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #736 +/- ##
========================================
Coverage 97.10% 97.11%
========================================
Files 100 100
Lines 20453 20697 +244
========================================
+ Hits 19861 20099 +238
- Misses 592 598 +6 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a very strong PR, thanks again @clbonet !
Just a few comments and I think we can merge. Also you new function get_projections_sphere is very nice, thansk for the refactoring
tc[mask_end > 0] = ( | ||
(Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp) | ||
)[mask_end > 0] | ||
with warnings.catch_warnings(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is a bit worrying. why are you catching tjose worning? when do they happen. It is probaly OK but we need at last a comment there to explian
return wasserstein1_circle( | ||
u_values, v_values, u_weights, v_weights, require_sort | ||
) | ||
# if p == 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove the code if not usefull anymore
|
||
Returns | ||
------- | ||
loss: float |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this true or is this of shape (...) instead?
Types of changes
This PR aims to add the Linear Circular OT distance (with uniform measure on the circle as reference measure) introduced in LCOT: Linear Circular Optimal Transport
and its sliced counterpart introduced in Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data
.
Changes:
ot.solver_1d.linear_circular_embedding
andot.solver_1d.linear_circular_ot
functions to compute the Linear Circular OT distance.ot.sliced.linear_sliced_wasserstein_sphere
function to compute its sliced counterpart.ot.sliced.get_projections_sphere
andot.sliced.projection_sphere_to_circle
.test/test_1d_solver
andtest/test_sliced
.plot_compute_wasserstein_circle.py
ot.wasserstein_circle
to always useot.solver_1d.binary_search_circle
for any p, which seems to give better results even for p=1 (as discussed in Issue Wasserstein Circle distance doesn't seem correct? #738).plot_ssw_unif_torch.py
.Motivation and context / Related issue
The Linear Circular OT distance with uniform measure on the circle is very fast to compute, as the OT maps can be computed in closed-forms. It is thus faster to compute than the true wasserstein distance on the circle.
How has this been tested (if it applies)
I added tests of these functions in
test/test_1d_solver
andtest/test_sliced
.PR checklist