@@ -188,10 +188,6 @@ def v_beta_loss(x, reduction="mean", channel_reduction=None, beta=2.0, eps=1e-8)
     Deep Image Representations by Inverting Them", Mahendran et al (2014)
     (https://arxiv.org/abs/1412.0035).
 
-    The nine-point stencil is from "Fundamental Solutions of 9-point Discrete
-    Laplacians; Derivation and Tables", Lynch (1992)
-    (https://docs.lib.purdue.edu/cgi/viewcontent.cgi?article=1928&context=cstech).
-
     Vectorial total variation was proposed in "Color TV: total variation methods for
     restoration of vector-valued images", Blomgren et al (1998)
     (https://ieeexplore.ieee.org/document/661180).
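
A minimal usage sketch of the vectorial coupling described above (not part of
the diff; the "sum" key for the module-level reductions table is an assumption,
though the None default in the signature suggests both keys exist):

    import torch

    img = torch.rand(1, 3, 64, 64)
    # channel_reduction=None keeps the channels decoupled (per-channel TV);
    # channel_reduction="sum" pools the squared differences across channels
    # before the beta/2 power, i.e. the coupled color TV of Blomgren et al.
    per_channel_tv = v_beta_loss(img, channel_reduction=None)
    color_tv = v_beta_loss(img, channel_reduction="sum")
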
@@ -213,37 +209,27 @@ def v_beta_loss(x, reduction="mean", channel_reduction=None, beta=2.0, eps=1e-8)
     l, m, h = slice(None, -2), slice(1, -1), slice(2, None)
     x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
     target = x[..., m, m]
-    ml = (x[..., m, l] - target) ** 2 / 3
-    lm = (x[..., l, m] - target) ** 2 / 3
-    mh = (x[..., m, h] - target) ** 2 / 3
-    hm = (x[..., h, m] - target) ** 2 / 3
-    ll = (x[..., l, l] - target) ** 2 / 12
-    hh = (x[..., h, h] - target) ** 2 / 12
-    hl = (x[..., h, l] - target) ** 2 / 12
-    lh = (x[..., l, h] - target) ** 2 / 12
-    diffs = ml + lm + mh + hm + ll + hh + hl + lh
+    ml = (x[..., m, l] - target) ** 2 / 4  # horizontal 1
+    mh = (x[..., m, h] - target) ** 2 / 4  # horizontal 2
+    lm = (x[..., l, m] - target) ** 2 / 4  # vertical 1
+    hm = (x[..., h, m] - target) ** 2 / 4  # vertical 2
+    ll = (x[..., l, l] - target) ** 2 / 8  # diagonal upper left to lower right 1
+    hh = (x[..., h, h] - target) ** 2 / 8  # diagonal upper left to lower right 2
+    lh = (x[..., l, h] - target) ** 2 / 8  # diagonal lower left to upper right 1
+    hl = (x[..., h, l] - target) ** 2 / 8  # diagonal lower left to upper right 2
+    diffs = ml + mh + lm + hm + ll + hh + lh + hl
     losses = torch.pow(reductions[channel_reduction](diffs, dim=-3) + eps, beta / 2)
     return reductions[reduction](losses)
 
 
 class VBetaLoss(nn.Module):
-    __doc__ = v_beta_loss.__doc__
-
-    def __init__(self, reduction="mean", channel_reduction=None, beta=2.0, eps=1e-8):
+    def __init__(self, beta=2.0, eps=1e-8):
         super().__init__()
-        self.reduction = reduction
-        self.channel_reduction = channel_reduction
         self.beta = beta
         self.eps = eps
 
     def forward(self, x):
-        return v_beta_loss(
-            x,
-            reduction=self.reduction,
-            channel_reduction=self.channel_reduction,
-            beta=self.beta,
-            eps=self.eps,
-        )
+        return v_beta_loss(x * 4, beta=self.beta, eps=self.eps)
 
 
 class SumLoss(nn.ModuleList):
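
A quick consistency check on the re-weighted stencil (a sketch, not part of the
diff): on any linear ramp, the old weights (1/3 axial, 1/12 diagonal, the 4:1
ratio of the nine-point Laplacian) and the new weights (1/4 axial, 1/8
diagonal) both reproduce the squared gradient magnitude exactly, so the change
re-balances diagonal against axial information rather than fixing the gradient
estimate:

    import torch

    a, b = 0.7, -0.3  # arbitrary gradient components for the test ramp
    u = torch.arange(32.0)
    ramp = (a * u[None, :] + b * u[:, None]).expand(1, 1, 32, 32)

    l, m, h = slice(None, -2), slice(1, -1), slice(2, None)
    x = torch.nn.functional.pad(ramp, (1, 1, 1, 1), "replicate")
    target = x[..., m, m]
    axial = sum((x[..., i, j] - target) ** 2 / 4
                for i, j in [(m, l), (m, h), (l, m), (h, m)])
    diag = sum((x[..., i, j] - target) ** 2 / 8
               for i, j in [(l, l), (h, h), (l, h), (h, l)])
    interior = (axial + diag)[..., 1:-1, 1:-1]  # skip the replicated border
    assert torch.allclose(interior, torch.full_like(interior, a**2 + b**2))
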
@@ -400,7 +386,7 @@ def get_image(self, image_type='pil'):
     def stylize(self, content_image, style_images, *,
                 style_weights=None,
                 content_weight: float = 0.015,
-                tv_weight: float = 2.,
+                tv_weight: float = 0.125,
                 tv_beta: float = 2.,
                 optimizer: str = 'adam',
                 min_scale: int = 128,
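
The smaller tv_weight default pairs with the x * 4 scaling added to
VBetaLoss.forward above. A back-of-the-envelope check (assuming the default
beta == 2, where the loss is quadratic in its input, and ignoring eps):

    # 4x the input -> 16x the squared differences -> 16x the loss at beta == 2,
    # so the default weight drops by the same factor to keep the effective
    # smoothing strength at its old value of 2.0.
    old_strength = 2.0               # old tv_weight on the unscaled input
    new_strength = 0.125 * 4.0 ** 2  # new tv_weight times the 16x loss scale
    assert old_strength == new_strength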