1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ demo of Optimal transport for domain adaptation with image color adaptation as in [6]
4
+
5
+ [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
6
+ """
7
+
8
+ import numpy as np
9
+ import scipy .ndimage as spi
10
+ import matplotlib .pylab as pl
11
+ import ot
12
+
13
+
14
+ #%% Loading images
15
+
16
+ I1 = spi .imread ('../data/ocean_day.jpg' ).astype (np .float64 )/ 256
17
+ I2 = spi .imread ('../data/ocean_sunset.jpg' ).astype (np .float64 )/ 256
18
+
19
+ #%% Plot images
20
+
21
+ pl .figure (1 )
22
+
23
+ pl .subplot (1 ,2 ,1 )
24
+ pl .imshow (I1 )
25
+ pl .title ('Image 1' )
26
+
27
+ pl .subplot (1 ,2 ,2 )
28
+ pl .imshow (I2 )
29
+ pl .title ('Image 2' )
30
+
31
+ pl .show ()
32
+
33
+ #%% Image conversion and dataset generation
34
+
35
+ def im2mat (I ):
36
+ """Converts and image to matrix (one pixel per line)"""
37
+ return I .reshape ((I .shape [0 ]* I .shape [1 ],I .shape [2 ]))
38
+
39
+ def mat2im (X ,shape ):
40
+ """Converts back a matrix to an image"""
41
+ return X .reshape (shape )
42
+
43
+ X1 = im2mat (I1 )
44
+ X2 = im2mat (I2 )
45
+
46
+ # training samples
47
+ nb = 1000
48
+ idx1 = np .random .randint (X1 .shape [0 ],size = (nb ,))
49
+ idx2 = np .random .randint (X2 .shape [0 ],size = (nb ,))
50
+
51
+ xs = X1 [idx1 ,:]
52
+ xt = X2 [idx2 ,:]
53
+
54
+ #%% domain adaptation between images
55
+
56
+ # LP problem
57
+ da_emd = ot .da .OTDA () # init class
58
+ da_emd .fit (xs ,xt ) # fit distributions
59
+
60
+
61
+ # sinkhorn regularization
62
+ lambd = 1e-1
63
+ da_entrop = ot .da .OTDA_sinkhorn ()
64
+ da_entrop .fit (xs ,xt ,reg = lambd )
65
+
66
+
67
+
68
+ #%% prediction between images (using out of sample prediction as in [6])
69
+
70
+ X1t = da_emd .predict (X1 )
71
+ X2t = da_emd .predict (X2 ,- 1 )
72
+
73
+
74
+ X1te = da_entrop .predict (X1 )
75
+ X2te = da_entrop .predict (X2 ,- 1 )
76
+
77
+
78
+ def minmax (I ):
79
+ return np .minimum (np .maximum (I ,0 ),1 )
80
+
81
+ I1t = minmax (mat2im (X1t ,I1 .shape ))
82
+ I2t = minmax (mat2im (X2t ,I2 .shape ))
83
+
84
+ I1te = minmax (mat2im (X1te ,I1 .shape ))
85
+ I2te = minmax (mat2im (X2te ,I2 .shape ))
86
+
87
+ #%% plot all images
88
+
89
+ pl .figure (2 ,(10 ,8 ))
90
+
91
+ pl .subplot (2 ,3 ,1 )
92
+
93
+ pl .imshow (I1 )
94
+ pl .title ('Image 1' )
95
+
96
+ pl .subplot (2 ,3 ,2 )
97
+ pl .imshow (I1t )
98
+ pl .title ('Image 1 Adapt' )
99
+
100
+
101
+ pl .subplot (2 ,3 ,3 )
102
+ pl .imshow (I1te )
103
+ pl .title ('Image 1 Adapt (reg)' )
104
+
105
+ pl .subplot (2 ,3 ,4 )
106
+
107
+ pl .imshow (I2 )
108
+ pl .title ('Image 2' )
109
+
110
+ pl .subplot (2 ,3 ,5 )
111
+ pl .imshow (I2t )
112
+ pl .title ('Image 2 Adapt' )
113
+
114
+
115
+ pl .subplot (2 ,3 ,6 )
116
+ pl .imshow (I2te )
117
+ pl .title ('Image 2 Adapt (reg)' )
118
+
119
+ pl .show ()
0 commit comments