Skip to content

Commit 3f46a24

Browse files
committed
add image color adaptation demo
1 parent cc3da98 commit 3f46a24

File tree

6 files changed

+119
-0
lines changed

6 files changed

+119
-0
lines changed

data/autumn.jpg

606 KB
Loading

data/fallingwater.jpg

374 KB
Loading

data/ocean_day.jpg

74.8 KB
Loading

data/ocean_sunset.jpg

129 KB
Loading

data/woods.jpg

378 KB
Loading

examples/demo_OTDA_color_images.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
demo of Optimal transport for domain adaptation with image color adaptation as in [6]
4+
5+
[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
6+
"""
7+
8+
import numpy as np
9+
import scipy.ndimage as spi
10+
import matplotlib.pylab as pl
11+
import ot
12+
13+
14+
#%% Loading images
15+
16+
I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256
17+
I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256
18+
19+
#%% Plot images
20+
21+
pl.figure(1)
22+
23+
pl.subplot(1,2,1)
24+
pl.imshow(I1)
25+
pl.title('Image 1')
26+
27+
pl.subplot(1,2,2)
28+
pl.imshow(I2)
29+
pl.title('Image 2')
30+
31+
pl.show()
32+
33+
#%% Image conversion and dataset generation
34+
35+
def im2mat(I):
36+
"""Converts and image to matrix (one pixel per line)"""
37+
return I.reshape((I.shape[0]*I.shape[1],I.shape[2]))
38+
39+
def mat2im(X,shape):
40+
"""Converts back a matrix to an image"""
41+
return X.reshape(shape)
42+
43+
X1=im2mat(I1)
44+
X2=im2mat(I2)
45+
46+
# training samples
47+
nb=1000
48+
idx1=np.random.randint(X1.shape[0],size=(nb,))
49+
idx2=np.random.randint(X2.shape[0],size=(nb,))
50+
51+
xs=X1[idx1,:]
52+
xt=X2[idx2,:]
53+
54+
#%% domain adaptation between images
55+
56+
# LP problem
57+
da_emd=ot.da.OTDA() # init class
58+
da_emd.fit(xs,xt) # fit distributions
59+
60+
61+
# sinkhorn regularization
62+
lambd=1e-1
63+
da_entrop=ot.da.OTDA_sinkhorn()
64+
da_entrop.fit(xs,xt,reg=lambd)
65+
66+
67+
68+
#%% prediction between images (using out of sample prediction as in [6])
69+
70+
X1t=da_emd.predict(X1)
71+
X2t=da_emd.predict(X2,-1)
72+
73+
74+
X1te=da_entrop.predict(X1)
75+
X2te=da_entrop.predict(X2,-1)
76+
77+
78+
def minmax(I):
79+
return np.minimum(np.maximum(I,0),1)
80+
81+
I1t=minmax(mat2im(X1t,I1.shape))
82+
I2t=minmax(mat2im(X2t,I2.shape))
83+
84+
I1te=minmax(mat2im(X1te,I1.shape))
85+
I2te=minmax(mat2im(X2te,I2.shape))
86+
87+
#%% plot all images
88+
89+
pl.figure(2,(10,8))
90+
91+
pl.subplot(2,3,1)
92+
93+
pl.imshow(I1)
94+
pl.title('Image 1')
95+
96+
pl.subplot(2,3,2)
97+
pl.imshow(I1t)
98+
pl.title('Image 1 Adapt')
99+
100+
101+
pl.subplot(2,3,3)
102+
pl.imshow(I1te)
103+
pl.title('Image 1 Adapt (reg)')
104+
105+
pl.subplot(2,3,4)
106+
107+
pl.imshow(I2)
108+
pl.title('Image 2')
109+
110+
pl.subplot(2,3,5)
111+
pl.imshow(I2t)
112+
pl.title('Image 2 Adapt')
113+
114+
115+
pl.subplot(2,3,6)
116+
pl.imshow(I2te)
117+
pl.title('Image 2 Adapt (reg)')
118+
119+
pl.show()

0 commit comments

Comments
 (0)