Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Wasserstein convolutional barycenter
  • Loading branch information
ncourty committed Sep 7, 2018
1 parent 5180023 commit d99abf0
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 1 deletion.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,4 +227,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t

[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representations (2018)

[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning

[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional Wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.
Binary file added data/duck.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/heart.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/redcross.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/tooth.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
92 changes: 92 additions & 0 deletions examples/plot_convolutional_barycenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@

#%%
# -*- coding: utf-8 -*-
"""
============================================
Convolutional Wasserstein Barycenter example
============================================
This example is designed to illustrate how the Convolutional Wasserstein Barycenter
function of POT works.
"""

# Author: Nicolas Courty <[email protected]>
#
# License: MIT License


import numpy as np
import pylab as pl
import ot

##############################################################################
# Data preparation
# ----------------
#
# The four distributions are constructed from 4 simple images


# Each input image contributes channel index 2, inverted (1 - value).
image_paths = (
    '../data/redcross.png',
    '../data/duck.png',
    '../data/heart.png',
    '../data/tooth.png',
)
f1, f2, f3, f4 = [1 - pl.imread(path)[:, :, 2] for path in image_paths]

# Normalize every image so that it sums to one before stacking them into A.
f1, f2, f3, f4 = [f / np.sum(f) for f in (f1, f2, f3, f4)]
A = np.array([f1, f2, f3, f4])

# The figure is a nb_images x nb_images grid of interpolated shapes.
nb_images = 5

# Barycentric coordinates of the four corners of the grid; the weights of
# every other cell are obtained from them by bilinear interpolation.
v1 = np.array((1, 0, 0, 0))
v2 = np.array((0, 1, 0, 0))
v3 = np.array((0, 0, 1, 0))
v4 = np.array((0, 0, 0, 1))


##############################################################################
# Barycenter computation and visualization
# ----------------------------------------
#

pl.figure(figsize=(10, 10))
# Use a figure-level title: pl.title() would attach the title to an implicit
# full-figure axes that the subplot grid created below overlaps.
pl.suptitle('Convolutional Wasserstein Barycenters in POT')
cm = 'Blues'
# regularization parameter
reg = 0.004

# The four corner cells show the input images themselves, so no barycenter
# has to be computed for them.
last = nb_images - 1
corners = {
    (0, 0): f1,
    (0, last): f3,
    (last, 0): f2,
    (last, last): f4,
}

for i in range(nb_images):
    for j in range(nb_images):
        pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
        tx = float(i) / (nb_images - 1)
        ty = float(j) / (nb_images - 1)

        # weights are constructed by bilinear interpolation
        tmp1 = (1 - tx) * v1 + tx * v2
        tmp2 = (1 - tx) * v3 + tx * v4
        weights = (1 - ty) * tmp1 + ty * tmp2

        if (i, j) in corners:
            img = corners[(i, j)]
        else:
            # call to barycenter computation
            img = ot.convolutional_barycenter2d(A, reg, weights)

        pl.imshow(img, cmap=cm)
        pl.axis('off')
pl.show()
106 changes: 106 additions & 0 deletions ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,112 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
else:
return geometricBar(weights, UKv)

def _gaussian_kernel_1d(size, reg):
    """Return the (size, size) kernel exp(-(t_i - t_j)**2 / reg) for t = linspace(0, 1, size)."""
    t = np.linspace(0, 1, size)
    return np.exp(-(t[:, None] - t[None, :]) ** 2 / reg)


def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
                               stopThr=1e-9, verbose=False, log=False):
    r"""Compute the entropic regularized Wasserstein barycenter of 2D images.

    The function solves the following optimization problem:

    .. math::
       \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
      distance (see ot.bregman.sinkhorn)
    - :math:`\mathbf{a}_i` are training distributions (2D images) in the last
      two dimensions of matrix :math:`\mathbf{A}`
    - reg is the regularization strength scalar value

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in [21]_. The kernel is never built as a
    full matrix: it is applied as one 1D kernel per image axis, so images do
    not need to be square.

    Parameters
    ----------
    A : np.ndarray (n,w,h)
        n distributions (2D images) of size w x h. Non-floating input is
        converted to float64; floating dtypes are used as given.
    reg : float
        Regularization term >0
    weights : np.ndarray (n,)
        Weights of each image on the simplex (barycentric coordinates).
        Uniform weights are used when None.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    a : (w,h) ndarray
        2D Wasserstein barycenter
    log : dict
        log dictionary (keys 'err', 'niter', 'U'), returned only if log==True
        in parameters

    References
    ----------
    .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A.,
        Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances:
        Efficient optimal transportation on geometric domains.
        ACM Transactions on Graphics (TOG), 34(4), 66
    """
    # Integer (or list) input would make the work arrays below integer-typed
    # and break the in-place float updates, so promote it once up front.
    A = np.asarray(A)
    if not np.issubdtype(A.dtype, np.floating):
        A = A.astype(np.float64)

    n_images = A.shape[0]

    if weights is None:
        weights = np.ones(n_images) / n_images
    else:
        assert len(weights) == n_images, \
            "weights must contain one entry per image"

    if log:
        log = {'err': []}

    b = np.zeros_like(A[0])
    U = np.ones_like(A)
    KV = np.ones_like(A)
    threshold = 1e-30  # lower bound avoiding divisions by zero and log(0)

    cpt = 0
    err = 1

    # Build the convolution operator: one 1D kernel per image axis, applied
    # on the left (first axis) and on the right (second axis).
    xi1 = _gaussian_kernel_1d(A.shape[1], reg)
    xi2 = _gaussian_kernel_1d(A.shape[2], reg)

    def K(x):
        return np.dot(np.dot(xi1, x), xi2)

    while err > stopThr and cpt < numItermax:

        bold = b
        cpt += 1

        # The barycenter is the weighted geometric mean of the U_r * KV_r,
        # accumulated in the log domain.
        b = np.zeros_like(A[0])
        for r in range(n_images):
            KV[r] = K(A[r] / np.maximum(threshold, K(U[r])))
            b += weights[r] * np.log(np.maximum(threshold, U[r] * KV[r]))
        b = np.exp(b)

        # Update all scalings at once (b broadcasts over the image axis).
        U[:] = b / np.maximum(threshold, KV)

        # The error is only evaluated every 10 iterations.
        if cpt % 10 == 1:
            err = np.sum(np.abs(bold - b))
            # log and verbose print
            if log:
                log['err'].append(err)

            if verbose:
                # This branch only runs when cpt % 10 == 1, so the header
                # must be keyed on cpt % 200 == 1 to ever be printed.
                if cpt % 200 == 1:
                    print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(cpt, err))

    if log:
        log['niter'] = cpt
        log['U'] = U
        return b, log
    else:
        return b


def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
stopThr=1e-3, verbose=False, log=False):
Expand Down

0 comments on commit d99abf0

Please sign in to comment.