From 94eb4a24fe09c8a193933d8c7f48a657da4e6c52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 11 May 2018 16:07:03 +0200 Subject: [PATCH 01/13] update documentation in bregman --- ot/bregman.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ot/bregman.py b/ot/bregman.py index 07b8660e3..9c84aed6e 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -844,6 +844,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, loss matrix for OT reg : float Regularization term >0 + weights : np.ndarray (n,) + Weights of each histogram i_i on the simplex numItermax : int, optional Max number of iterations stopThr : float, optional From 8f6c8cd04db65a5d28c467e00b294b07e8183eb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 11 May 2018 16:11:45 +0200 Subject: [PATCH 02/13] update readme --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 6b7cff03c..466c09ca8 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,8 @@ This open source Python library provide several solvers for optimization problem It provides the following solvers: * OT Network Flow solver for the linear program/ Earth Movers Distance [1]. -* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (required cudamat). +* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat). +* Non regularized Wasserstein barycenters [16] with LP solver. * Bregman projections for Wasserstein barycenter [3] and unmixing [4]. * Optimal transport for domain adaptation with group lasso regularization [5] * Conditional gradient [6] and Generalized conditional gradient for regularized OT [7]. @@ -210,3 +211,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [14] Knott, M. and Smith, C. S. 
(1984).[On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43. [15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) . + +[16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924. From be8817730c7996052e84d21ba08cf60f59020935 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 11 May 2018 16:55:35 +0200 Subject: [PATCH 03/13] add requirement scipy for linprog interior point solver --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9bfca4327..df841ba20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy -scipy +scipy>=1.0 cython matplotlib sphinx-gallery From 060d9046b291c76244deab2d78ee8356a294e91f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 11 May 2018 16:56:47 +0200 Subject: [PATCH 04/13] add cvx barycenter solver --- ot/lp/cvx.py | 138 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 ot/lp/cvx.py diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py new file mode 100644 index 000000000..4d08916fd --- /dev/null +++ b/ot/lp/cvx.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +""" +LP solvers for optimal transport using cvxopt +""" + +# Author: Remi Flamary +# +# License: MIT License + +import numpy as np +import scipy as sp +import scipy.sparse as sps + +try: + import cvxopt + from cvxopt import solvers, matrix, sparse, spmatrix +except ImportError: + cvxopt=False + +def scipy_sparse_to_spmatrix(A): + """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix""" + coo = A.tocoo() + SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape) + return SP + +def 
barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-point'): + """Compute the entropic regularized wasserstein barycenter of distributions A + + The function solves the following optimization problem [16]: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + + The linear program is solved using the default cvxopt solver if installed. + If cvxopt is not installed it uses the lp solver from scipy.optimize. + + Parameters + ---------- + A : np.ndarray (d,n) + n training distributions of size d + M : np.ndarray (d,d) + loss matrix for OT + reg : float + Regularization term >0 + weights : np.ndarray (n,) + Weights of each histogram i_i on the simplex + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + solver : string, optional + the solver used, default 'interior-point' use the lp solver from + scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt. + + Returns + ------- + a : (d,) ndarray + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924. 
+ + + + """ + + if weights is None: + weights = np.ones(A.shape[1]) / A.shape[1] + else: + assert(len(weights) == A.shape[1]) + + n_distributions=A.shape[1] + n=A.shape[0] + + n2=n*n + c=np.zeros((0)) + b_eq1=np.zeros((0)) + for i in range(n_distributions): + c=np.concatenate((c,M.ravel()*weights[i])) + b_eq1=np.concatenate((b_eq1,A[:,i])) + c=np.concatenate((c,np.zeros(n))) + + lst_idiag1=[sps.kron(sps.eye(n),np.ones((1,n))) for i in range(n_distributions)] + # row constraints + A_eq1=sps.hstack((sps.block_diag(lst_idiag1),sps.coo_matrix((n_distributions*n,n)))) + + # columns constraints + lst_idiag2=[] + lst_eye=[] + for i in range(n_distributions): + if i==0: + lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n))) + lst_eye.append(-sps.eye(n)) + else: + lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n-1,n))) + lst_eye.append(-sps.eye(n-1,n)) + + A_eq2=sps.hstack((sps.block_diag(lst_idiag2),sps.vstack(lst_eye))) + b_eq2=np.zeros((A_eq2.shape[0])) + + # full problem + A_eq=sps.vstack((A_eq1,A_eq2)) + b_eq=np.concatenate((b_eq1,b_eq2)) + + if not cvxopt or solver in ['interior-point']: # cvxopt not installed or simplex/interior point + + if solver is None: + solver='interior-point' + + options={'sparse':True,'disp': verbose} + sol=sp.optimize.linprog(c,A_eq=A_eq,b_eq=b_eq,method=solver,options=options) + x=sol.x + b=x[-n:] + + else: + + h=np.zeros((n_distributions*n2+n)) + G=-sps.eye(n_distributions*n2+n) + + sol=solvers.lp(matrix(c),scipy_sparse_to_spmatrix(G),matrix(h),A=scipy_sparse_to_spmatrix(A_eq),b=matrix(b_eq),solver=solver) + + x=np.array(sol['x']) + b=x[-n:].ravel() + + if log: + return b, sol + else: + return b \ No newline at end of file From 3aee908ad42d65897f1916de6eab84921ac94a10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 11 May 2018 16:58:06 +0200 Subject: [PATCH 05/13] pep8 --- ot/bregman.py | 2 +- ot/lp/cvx.py | 104 +++++++++++++++++++++++++------------------------- 2 files changed, 54 insertions(+), 52 
deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 9c84aed6e..e788ef585 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -845,7 +845,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, reg : float Regularization term >0 weights : np.ndarray (n,) - Weights of each histogram i_i on the simplex + Weights of each histogram i_i on the simplex numItermax : int, optional Max number of iterations stopThr : float, optional diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 4d08916fd..93097d16d 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -15,7 +15,8 @@ import cvxopt from cvxopt import solvers, matrix, sparse, spmatrix except ImportError: - cvxopt=False + cvxopt = False + def scipy_sparse_to_spmatrix(A): """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix""" @@ -23,7 +24,8 @@ def scipy_sparse_to_spmatrix(A): SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape) return SP -def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-point'): + +def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'): """Compute the entropic regularized wasserstein barycenter of distributions A The function solves the following optimization problem [16]: @@ -36,7 +38,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - The linear program is solved using the default cvxopt solver if installed. + The linear program is solved using the default cvxopt solver if installed. If cvxopt is not installed it uses the lp solver from scipy.optimize. 
Parameters @@ -48,13 +50,13 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi reg : float Regularization term >0 weights : np.ndarray (n,) - Weights of each histogram i_i on the simplex + Weights of each histogram i_i on the simplex verbose : bool, optional Print information along iterations log : bool, optional record log if True solver : string, optional - the solver used, default 'interior-point' use the lp solver from + the solver used, default 'interior-point' use the lp solver from scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt. Returns @@ -78,61 +80,61 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi weights = np.ones(A.shape[1]) / A.shape[1] else: assert(len(weights) == A.shape[1]) - - n_distributions=A.shape[1] - n=A.shape[0] - - n2=n*n - c=np.zeros((0)) - b_eq1=np.zeros((0)) + + n_distributions = A.shape[1] + n = A.shape[0] + + n2 = n * n + c = np.zeros((0)) + b_eq1 = np.zeros((0)) for i in range(n_distributions): - c=np.concatenate((c,M.ravel()*weights[i])) - b_eq1=np.concatenate((b_eq1,A[:,i])) - c=np.concatenate((c,np.zeros(n))) - - lst_idiag1=[sps.kron(sps.eye(n),np.ones((1,n))) for i in range(n_distributions)] + c = np.concatenate((c, M.ravel() * weights[i])) + b_eq1 = np.concatenate((b_eq1, A[:, i])) + c = np.concatenate((c, np.zeros(n))) + + lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)] # row constraints - A_eq1=sps.hstack((sps.block_diag(lst_idiag1),sps.coo_matrix((n_distributions*n,n)))) - + A_eq1 = sps.hstack((sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n)))) + # columns constraints - lst_idiag2=[] - lst_eye=[] + lst_idiag2 = [] + lst_eye = [] for i in range(n_distributions): - if i==0: - lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n))) + if i == 0: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n))) lst_eye.append(-sps.eye(n)) else: - 
lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n-1,n))) - lst_eye.append(-sps.eye(n-1,n)) - - A_eq2=sps.hstack((sps.block_diag(lst_idiag2),sps.vstack(lst_eye))) - b_eq2=np.zeros((A_eq2.shape[0])) - + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n))) + lst_eye.append(-sps.eye(n - 1, n)) + + A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye))) + b_eq2 = np.zeros((A_eq2.shape[0])) + # full problem - A_eq=sps.vstack((A_eq1,A_eq2)) - b_eq=np.concatenate((b_eq1,b_eq2)) - - if not cvxopt or solver in ['interior-point']: # cvxopt not installed or simplex/interior point - + A_eq = sps.vstack((A_eq1, A_eq2)) + b_eq = np.concatenate((b_eq1, b_eq2)) + + if not cvxopt or solver in ['interior-point']: # cvxopt not installed or simplex/interior point + if solver is None: - solver='interior-point' - - options={'sparse':True,'disp': verbose} - sol=sp.optimize.linprog(c,A_eq=A_eq,b_eq=b_eq,method=solver,options=options) - x=sol.x - b=x[-n:] - + solver = 'interior-point' + + options = {'sparse': True, 'disp': verbose} + sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options) + x = sol.x + b = x[-n:] + else: - - h=np.zeros((n_distributions*n2+n)) - G=-sps.eye(n_distributions*n2+n) - - sol=solvers.lp(matrix(c),scipy_sparse_to_spmatrix(G),matrix(h),A=scipy_sparse_to_spmatrix(A_eq),b=matrix(b_eq),solver=solver) - - x=np.array(sol['x']) - b=x[-n:].ravel() - + + h = np.zeros((n_distributions * n2 + n)) + G = -sps.eye(n_distributions * n2 + n) + + sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), solver=solver) + + x = np.array(sol['x']) + b = x[-n:].ravel() + if log: return b, sol else: - return b \ No newline at end of file + return b From 4285cf64f8a2ec481586a190dd545d2a8946e134 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 11 May 2018 17:01:23 +0200 Subject: [PATCH 06/13] remove unused sparse --- ot/lp/cvx.py | 12 ++++++++---- 1 
file changed, 8 insertions(+), 4 deletions(-) diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 93097d16d..193c0f530 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -13,7 +13,7 @@ try: import cvxopt - from cvxopt import solvers, matrix, sparse, spmatrix + from cvxopt import solvers, matrix, spmatrix except ImportError: cvxopt = False @@ -114,13 +114,15 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po A_eq = sps.vstack((A_eq1, A_eq2)) b_eq = np.concatenate((b_eq1, b_eq2)) - if not cvxopt or solver in ['interior-point']: # cvxopt not installed or simplex/interior point + if not cvxopt or solver in ['interior-point']: + # cvxopt not installed or interior point if solver is None: solver = 'interior-point' options = {'sparse': True, 'disp': verbose} - sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options) + sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, + options=options) x = sol.x b = x[-n:] @@ -129,7 +131,9 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po h = np.zeros((n_distributions * n2 + n)) G = -sps.eye(n_distributions * n2 + n) - sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), solver=solver) + sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), + A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), + solver=solver) x = np.array(sol['x']) b = x[-n:].ravel() From 36f4f7ed2116841d7fe9514ee250bbf16e77b72d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 11 May 2018 17:07:56 +0200 Subject: [PATCH 07/13] better documentation --- ot/lp/__init__.py | 3 +++ ot/lp/cvx.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 6371feba1..5dda82ac4 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -11,9 +11,12 @@ import numpy as np +from .import cvx + # import compiled emd from .emd_wrap 
import emd_c, check_result from ..utils import parmap +from .cvx import barycenter def emd(a, b, M, numItermax=100000, log=False): diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 193c0f530..c62da6aa4 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -38,8 +38,8 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - The linear program is solved using the default cvxopt solver if installed. - If cvxopt is not installed it uses the lp solver from scipy.optimize. + The linear program is solved using the interior point solver from scipy.optimize. + If cvxopt solver if installed it can use cvxopt. Parameters ---------- From fdb2f3af19d04872bafa0d9ec5563732e1d6209b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 11 May 2018 17:24:09 +0200 Subject: [PATCH 08/13] add test for barycenter --- ot/lp/cvx.py | 12 +++++++----- test/test_gpu.py | 2 +- test/test_ot.py | 35 ++++++++++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index c62da6aa4..fe9ac7617 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -39,7 +39,9 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` The linear program is solved using the interior point solver from scipy.optimize. - If cvxopt solver if installed it can use cvxopt. + If cvxopt is installed, it can be used as the solver instead. + + Note that this problem does not scale well (both in memory and computational time).
Parameters ---------- @@ -114,14 +116,14 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po A_eq = sps.vstack((A_eq1, A_eq2)) b_eq = np.concatenate((b_eq1, b_eq2)) - if not cvxopt or solver in ['interior-point']: + if not cvxopt or solver in ['interior-point']: # cvxopt not installed or interior point if solver is None: solver = 'interior-point' options = {'sparse': True, 'disp': verbose} - sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, + sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options) x = sol.x b = x[-n:] @@ -131,8 +133,8 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po h = np.zeros((n_distributions * n2 + n)) G = -sps.eye(n_distributions * n2 + n) - sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), - A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), + sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), + A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), solver=solver) x = np.array(sol['x']) diff --git a/test/test_gpu.py b/test/test_gpu.py index 615c2a7c3..1e97c4565 100644 --- a/test/test_gpu.py +++ b/test/test_gpu.py @@ -76,4 +76,4 @@ def describe_res(r): time3 - time2)) describe_res(G2) - np.testing.assert_allclose(G1, G2, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(G1, G2, rtol=1e-3, atol=1e-3) diff --git a/test/test_ot.py b/test/test_ot.py index ea6d9dcd9..bf23e8c7e 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -10,7 +10,7 @@ import ot from ot.datasets import get_1D_gauss as gauss - +import pytest def test_doctest(): import doctest @@ -117,6 +117,39 @@ def test_emd2_multi(): np.testing.assert_allclose(emd1, emdn) +def test_lp_barycenter(): + + a1 = np.array([1.0, 0, 0])[:, None] + a2 = np.array([0, 0, 1.0])[:, None] + + A = np.hstack((a1, a2)) + M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]]) + + # obvious barycenter between two diracs + bary0 = np.array([0, 1.0, 0]) + + bary = 
ot.lp.barycenter(A, M, [.5, .5]) + + np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(bary.sum(), 1) + +@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") +def test_lp_barycenter_cvxopt(): + + a1 = np.array([1.0, 0, 0])[:, None] + a2 = np.array([0, 0, 1.0])[:, None] + + A = np.hstack((a1, a2)) + M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]]) + + # obvious barycenter between two diracs + bary0 = np.array([0, 1.0, 0]) + + bary = ot.lp.barycenter(A, M, [.5, .5],solver=None) + + np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(bary.sum(), 1) + def test_warnings(): n = 100 # nb bins m = 100 # nb bins From bd1af44ea0a819d5df0ccffbea4d05ed7547960b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 11 May 2018 17:25:32 +0200 Subject: [PATCH 09/13] add test barycenter cvxopt --- test/test_ot.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_ot.py b/test/test_ot.py index bf23e8c7e..cc25bf464 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,6 +12,7 @@ from ot.datasets import get_1D_gauss as gauss import pytest + def test_doctest(): import doctest @@ -133,6 +134,7 @@ def test_lp_barycenter(): np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) np.testing.assert_allclose(bary.sum(), 1) + @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): @@ -145,11 +147,12 @@ def test_lp_barycenter_cvxopt(): # obvious barycenter between two diracs bary0 = np.array([0, 1.0, 0]) - bary = ot.lp.barycenter(A, M, [.5, .5],solver=None) + bary = ot.lp.barycenter(A, M, [.5, .5], solver=None) np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) np.testing.assert_allclose(bary.sum(), 1) + def test_warnings(): n = 100 # nb bins m = 100 # nb bins From da8f6119484642eca6e8efb3e5aaecce7a777622 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= 
Date: Mon, 14 May 2018 11:34:59 +0200 Subject: [PATCH 10/13] add example --- examples/plot_barycenter_lp_vs_entropic.py | 284 +++++++++++++++++++++ 1 file changed, 284 insertions(+) create mode 100644 examples/plot_barycenter_lp_vs_entropic.py diff --git a/examples/plot_barycenter_lp_vs_entropic.py b/examples/plot_barycenter_lp_vs_entropic.py new file mode 100644 index 000000000..2eded2f63 --- /dev/null +++ b/examples/plot_barycenter_lp_vs_entropic.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +""" +================================================================================= +1D Wasserstein barycenter comparison between exact LP and entropic regularization +================================================================================= + +This example illustrates the computation of regularized Wassersyein Barycenter +as proposed in [3]. + + +[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). +Iterative Bregman projections for regularized transportation problems +SIAM Journal on Scientific Computing, 37(2), A1111-A1138. 
+ +""" + +# Author: Remi Flamary +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa +from matplotlib.collections import PolyCollection # noqa + +#import ot.lp.cvx as cvx + +# +# Generate data +# ------------- + +#%% parameters + +problems = [] + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +# Gaussian distributions +a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.get_1D_gauss(n, m=60, s=8) + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + +# +# Plot data +# --------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + +# +# Barycenter computation +# ---------------------- + +#%% barycenter computation + +alpha = 0.5 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +ot.tic() +bary_wass = ot.bregman.barycenter(A, M, reg, weights) +ot.toc() + + +ot.tic() +bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True) +ot.toc() + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') +pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + +problems.append([A, [bary_l2, bary_wass, bary_wass2]]) + +#%% parameters + +a1 = 1.0 * (x > 10) * (x < 50) +a2 = 1.0 * (x > 60) * (x < 80) + +a1 /= a1.sum() +a2 /= a2.sum() + +# creating matrix A containing all distributions +A = np.vstack((a1, 
a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + +# +# Barycenter computation +# ---------------------- + +#%% barycenter computation + +alpha = 0.5 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +ot.tic() +bary_wass = ot.bregman.barycenter(A, M, reg, weights) +ot.toc() + + +ot.tic() +bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True) +ot.toc() + + +problems.append([A, [bary_l2, bary_wass, bary_wass2]]) + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') +pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + +#%% parameters + +a1 = np.zeros(n) +a2 = np.zeros(n) + +a1[10] = .25 +a1[20] = .5 +a1[30] = .25 +a2[80] = 1 + + +a1 /= a1.sum() +a2 /= a2.sum() + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + +# +# Barycenter computation +# ---------------------- + +#%% barycenter computation + +alpha = 0.5 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +ot.tic() +bary_wass = ot.bregman.barycenter(A, M, reg, weights) +ot.toc() + + +ot.tic() +bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True) +ot.toc() + 
+ +problems.append([A, [bary_l2, bary_wass, bary_wass2]]) + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') +pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + + +# +# Final figure +# ------------ +# + +#%% plot + +nbm = len(problems) +nbm2 = (nbm // 2) + + +pl.figure(2, (20, 6)) +pl.clf() + +for i in range(nbm): + + A = problems[i][0] + bary_l2 = problems[i][1][0] + bary_wass = problems[i][1][1] + bary_wass2 = problems[i][1][2] + + pl.subplot(2, nbm, 1 + i) + for j in range(n_distributions): + pl.plot(x, A[:, j]) + if i == nbm2: + pl.title('Distributions') + pl.xticks(()) + pl.yticks(()) + + pl.subplot(2, nbm, 1 + i) + + pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)') + pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') + pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') + if i == nbm - 1: + pl.legend() + if i == nbm2: + pl.title('Barycenters') From 8ba983dea4a44fdf9946e4031db621815852394c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 14 May 2018 11:41:21 +0200 Subject: [PATCH 11/13] update doc + speedup autopep8 --- Makefile | 4 ++-- examples/plot_barycenter_lp_vs_entropic.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 95714b8ad..5fa80e94d 100644 --- a/Makefile +++ b/Makefile @@ -58,9 +58,9 @@ notebook : ipython notebook --matplotlib=inline --notebook-dir=notebooks/ autopep8 : - autopep8 -ir test ot examples + autopep8 -ir test ot examples --jobs -1 aautopep8 : - autopep8 -air test ot examples + autopep8 -air test ot examples --jobs -1 FORCE : diff --git a/examples/plot_barycenter_lp_vs_entropic.py b/examples/plot_barycenter_lp_vs_entropic.py index 2eded2f63..3a65449f6 100644 --- a/examples/plot_barycenter_lp_vs_entropic.py 
+++ b/examples/plot_barycenter_lp_vs_entropic.py @@ -4,14 +4,19 @@ 1D Wasserstein barycenter comparison between exact LP and entropic regularization ================================================================================= -This example illustrates the computation of regularized Wassersyein Barycenter -as proposed in [3]. +This example illustrates the computation of regularized Wasserstein Barycenter +as proposed in [3] and exact LP barycenters using standard LP solver. +It reproduces approximately Figure 3.1 and 3.2 from the following paper: +Cuturi, M., & Peyré, G. (2016). A smoothed dual approach for variational +Wasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + + """ # Author: Remi Flamary From 3f1482238925932bb6c9c606651427491a65365c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 14 May 2018 17:03:36 +0200 Subject: [PATCH 12/13] last change example --- examples/plot_barycenter_lp_vs_entropic.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/plot_barycenter_lp_vs_entropic.py b/examples/plot_barycenter_lp_vs_entropic.py index 3a65449f6..6936bbb00 100644 --- a/examples/plot_barycenter_lp_vs_entropic.py +++ b/examples/plot_barycenter_lp_vs_entropic.py @@ -278,7 +278,7 @@ pl.xticks(()) pl.yticks(()) - pl.subplot(2, nbm, 1 + i) + pl.subplot(2, nbm, 1 + i + nbm) pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)') pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') @@ -287,3 +287,6 @@ pl.legend() if i == nbm2: pl.title('Barycenters') + + pl.xticks(()) + pl.yticks(()) From 54f0b47e55c966d5492e4ce19ec4e704ef3278d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 29 May 2018 16:08:33 +0200 Subject: [PATCH 13/13] update documentation for barycenter function --- 
ot/bregman.py | 4 ++-- ot/lp/cvx.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index e788ef585..b017c1a85 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -839,13 +839,13 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, Parameters ---------- A : np.ndarray (d,n) - n training distributions of size d + n training distributions a_i of size d M : np.ndarray (d,d) loss matrix for OT reg : float Regularization term >0 weights : np.ndarray (n,) - Weights of each histogram i_i on the simplex + Weights of each histogram a_i on the simplex (barycentric coordinates) numItermax : int, optional Max number of iterations stopThr : float, optional diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index fe9ac7617..c8c75bc82 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -46,13 +46,13 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po Parameters ---------- A : np.ndarray (d,n) - n training distributions of size d + n training distributions a_i of size d M : np.ndarray (d,d) loss matrix for OT reg : float Regularization term >0 weights : np.ndarray (n,) - Weights of each histogram i_i on the simplex + Weights of each histogram a_i on the simplex (barycentric coordinates) verbose : bool, optional Print information along iterations log : bool, optional