-
Notifications
You must be signed in to change notification settings - Fork 516
/
Copy pathplot.py
110 lines (87 loc) · 3.26 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Functions for plotting OT matrices

.. warning::
    Note that by default this module is not imported in :mod:`ot`. In order to
    use it you need to explicitly import :mod:`ot.plot`.
"""
# Author: Remi Flamary <[email protected]>
#
# License: MIT License
import numpy as np
import matplotlib.pylab as pl
from matplotlib import gridspec
def plot1D_mat(a, b, M, title=''):
    r""" Plot matrix :math:`\mathbf{M}` with the source and target 1D distributions

    The current figure is split on a 3x3 grid. The target distribution
    :math:`\mathbf{b}` is drawn along the top row. The source distribution
    :math:`\mathbf{a}` is drawn sideways along the left column. The matrix
    :math:`\mathbf{M}` fills the remaining bottom-right 2x2 area and shares
    its x axis with the top plot and its y axis with the left plot.

    Parameters
    ----------
    a : ndarray, shape (na,)
        Source distribution
    b : ndarray, shape (nb,)
        Target distribution
    M : ndarray, shape (na, nb)
        Matrix to plot
    title : str, optional
        Title shown above the target distribution plot
    """
    n_src, n_tgt = M.shape
    grid = gridspec.GridSpec(3, 3)

    # Top row: target distribution, indexed like the columns of M.
    ax_top = pl.subplot(grid[0, 1:])
    ax_top.plot(np.arange(n_tgt), b, 'r', label='Target distribution')
    ax_top.set_yticks(())
    ax_top.set_title(title)

    # Left column: source distribution plotted sideways (values on x,
    # indices on y). Both axes are flipped so the values grow away from
    # the matrix and index 0 sits at the top, matching imshow's default
    # top-left origin.
    ax_left = pl.subplot(grid[1:, 0])
    ax_left.plot(a, np.arange(n_src), 'b', label='Source distribution')
    ax_left.invert_xaxis()
    ax_left.invert_yaxis()
    ax_left.set_xticks(())

    # Bottom-right: the matrix itself. The pyplot calls are kept here so
    # that pl.imshow also registers M as the current image.
    pl.subplot(grid[1:, 1:], sharex=ax_top, sharey=ax_left)
    pl.imshow(M, interpolation='nearest')
    pl.axis('off')
    pl.xlim((0, n_tgt))

    pl.tight_layout()
    pl.subplots_adjust(wspace=0., hspace=0.2)
def plot2D_samples_mat(xs, xt, G, draw_arrows: bool = True, thr=1e-8, **kwargs):
    r""" Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values

    Plot lines between source and target 2D samples with a color
    proportional to the value of the matrix :math:`\mathbf{G}` between samples.

    Parameters
    ----------
    xs : ndarray, shape (ns,2)
        Source samples positions
    xt : ndarray, shape (nt,2)
        Target samples positions
    G : ndarray, shape (ns,nt)
        OT matrix
    draw_arrows : bool, optional
        If True, draw an arrow head in the middle of each drawn line,
        pointing from the source sample towards the target sample
    thr : float, optional
        threshold above which the line is drawn; it is compared against the
        normalized value ``G[i, j] / G.max()``
    **kwargs : dict
        parameters given to the plot functions (default color is black if
        nothing given). If ``alpha`` is given it is not passed through.
        Instead it multiplies the per-line transparency
        ``G[i, j] / G.max()``.
    """
    if ('color' not in kwargs) and ('c' not in kwargs):
        kwargs['color'] = 'k'
    # Arrow heads reuse the line color. 'c' is matplotlib's short alias of
    # 'color', so look up whichever one is present instead of assuming 'color'.
    arrow_color = kwargs['color'] if 'color' in kwargs else kwargs['c']
    # A user-given alpha scales the G-proportional alpha of every line.
    scale = kwargs.pop('alpha', 1)

    mx = G.max()
    if mx == 0:
        # All-zero matrix: nothing to draw, and G / mx would divide by zero.
        return

    for i in range(xs.shape[0]):
        for j in range(xt.shape[0]):
            ratio = G[i, j] / mx
            if ratio > thr:
                alpha = ratio * scale
                pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
                        alpha=alpha, **kwargs)
                if draw_arrows:
                    # Midpoint of the segment, where the arrow head goes.
                    mid_x = (xs[i, 0] + xt[j, 0]) / 2
                    mid_y = (xs[i, 1] + xt[j, 1]) / 2
                    # Empty-text annotation used only for its arrow. The
                    # shaft runs from the source point to the midpoint (on
                    # top of the line), so the head points towards the target.
                    pl.annotate('',
                                xy=(mid_x, mid_y),
                                xytext=(xs[i, 0], xs[i, 1]),
                                arrowprops=dict(arrowstyle='-|>',
                                                color=arrow_color,
                                                alpha=alpha))