Skip to content

Commit 3b83a3f

Browse files
committed
Revise V^beta loss
1 parent 9b9400a commit 3b83a3f

File tree

1 file changed: +12 additions, -26 deletions

style_transfer/style_transfer.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,6 @@ def v_beta_loss(x, reduction="mean", channel_reduction=None, beta=2.0, eps=1e-8)
188188
Deep Image Representations by Inverting Them", Mahendran et al (2014)
189189
(https://arxiv.org/abs/1412.0035).
190190
191-
The nine-point stencil is from "Fundamental Solutions of 9-point Discrete
192-
Laplacians; Derivation and Tables", Lynch (1992)
193-
(https://docs.lib.purdue.edu/cgi/viewcontent.cgi?article=1928&context=cstech).
194-
195191
Vectorial total variation was proposed in "Color TV: total variation methods for
196192
restoration of vector-valued images", Blomgren et al (1998)
197193
(https://ieeexplore.ieee.org/document/661180).
@@ -213,37 +209,27 @@ def v_beta_loss(x, reduction="mean", channel_reduction=None, beta=2.0, eps=1e-8)
213209
l, m, h = slice(None, -2), slice(1, -1), slice(2, None)
214210
x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
215211
target = x[..., m, m]
216-
ml = (x[..., m, l] - target) ** 2 / 3
217-
lm = (x[..., l, m] - target) ** 2 / 3
218-
mh = (x[..., m, h] - target) ** 2 / 3
219-
hm = (x[..., h, m] - target) ** 2 / 3
220-
ll = (x[..., l, l] - target) ** 2 / 12
221-
hh = (x[..., h, h] - target) ** 2 / 12
222-
hl = (x[..., h, l] - target) ** 2 / 12
223-
lh = (x[..., l, h] - target) ** 2 / 12
224-
diffs = ml + lm + mh + hm + ll + hh + hl + lh
212+
ml = (x[..., m, l] - target) ** 2 / 4 # horizontal 1
213+
mh = (x[..., m, h] - target) ** 2 / 4 # horizontal 2
214+
lm = (x[..., l, m] - target) ** 2 / 4 # vertical 1
215+
hm = (x[..., h, m] - target) ** 2 / 4 # vertical 2
216+
ll = (x[..., l, l] - target) ** 2 / 8 # diagonal upper left to lower right 1
217+
hh = (x[..., h, h] - target) ** 2 / 8 # diagonal upper left to lower right 2
218+
lh = (x[..., l, h] - target) ** 2 / 8 # diagonal lower left to upper right 1
219+
hl = (x[..., h, l] - target) ** 2 / 8 # diagonal lower left to upper right 2
220+
diffs = ml + mh + lm + hm + ll + hh + lh + hl
225221
losses = torch.pow(reductions[channel_reduction](diffs, dim=-3) + eps, beta / 2)
226222
return reductions[reduction](losses)
227223

228224

229225
class VBetaLoss(nn.Module):
230-
__doc__ = v_beta_loss.__doc__
231-
232-
def __init__(self, reduction="mean", channel_reduction=None, beta=2.0, eps=1e-8):
226+
def __init__(self, beta=2.0, eps=1e-8):
233227
super().__init__()
234-
self.reduction = reduction
235-
self.channel_reduction = channel_reduction
236228
self.beta = beta
237229
self.eps = eps
238230

239231
def forward(self, x):
240-
return v_beta_loss(
241-
x,
242-
reduction=self.reduction,
243-
channel_reduction=self.channel_reduction,
244-
beta=self.beta,
245-
eps=self.eps,
246-
)
232+
return v_beta_loss(x * 4, beta=self.beta, eps=self.eps)
247233

248234

249235
class SumLoss(nn.ModuleList):
@@ -400,7 +386,7 @@ def get_image(self, image_type='pil'):
400386
def stylize(self, content_image, style_images, *,
401387
style_weights=None,
402388
content_weight: float = 0.015,
403-
tv_weight: float = 2.,
389+
tv_weight: float = 0.125,
404390
tv_beta: float = 2.,
405391
optimizer: str = 'adam',
406392
min_scale: int = 128,

0 commit comments

Comments (0)