@@ -26,16 +26,20 @@ def __init__(
        self.pil_image = None
        self.has_latent = False
        w, h = image_shape
-        try:
-            comp_adjusted = TF.resize(comp.clone(), (h, w))
-        except:
-            # comp_adjusted = comp.clone()
-            # Need to convert the latent to its image form
-            comp_adjusted = img_model.decode_tensor(comp.clone())
+        comp_adjusted = TF.resize(comp.clone(), (h, w))
+        # try:
+        #     comp_adjusted = TF.resize(comp.clone(), (h, w))
+        # except:
+        #     # comp_adjusted = comp.clone()
+        #     # Need to convert the latent to its image form
+        #     comp_adjusted = img_model.decode_tensor(comp.clone())
        self.direct_loss = MSELoss(comp_adjusted, weight, stop, name, image_shape)

    @torch.no_grad()
    def set_comp(self, pil_image, device=DEVICE):
+        """
+        sets the DIRECT loss anchor "comp" to the tensorized image.
+        """
        logger.debug(type(pil_image))
        self.pil_image = pil_image
        self.has_latent = False
@@ -47,6 +51,10 @@ def set_comp(self, pil_image, device=DEVICE):

    @classmethod
    def convert_input(cls, input, img):
+        """
+        Converts the input image tensor to the image representation of the image model.
+        E.g. if img is VQGAN, then the input tensor is converted to the latent representation.
+        """
        logger.debug(type(input))  # pretty sure this is gonna be tensor
        # return input  # this is the default MSE loss version
        return img.make_latent(input)
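The docstring above describes the conversion contract. As a rough sketch of what the loss expects, an image model only needs to expose a make_latent() that maps an RGB tensor to whatever that model treats as its latent; the class and shapes below are hypothetical stand-ins, not part of the patch or of pytti.

import torch
import torch.nn.functional as F

class FakeLatentImageModel:
    # hypothetical stand-in for a pytti image model (e.g. a VQGAN wrapper)
    def make_latent(self, image_tensor):
        # a real model would encode the RGB tensor into its latent representation;
        # a coarse average-pool stands in for that here
        return F.avg_pool2d(image_tensor, 16)

img = FakeLatentImageModel()
rgb = torch.rand(1, 3, 256, 256)
latent = img.make_latent(rgb)  # what convert_input() hands back to the latent loss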
@@ -107,25 +115,62 @@ def get_loss(self, input, img):
        logger.debug(
            self.comp.shape
        )  # [1 1 1 1] -> from target image constructor when no input image provided
+
+        # why is the latent comp only set here? why not in the __init__ and set_comp?
        if not self.has_latent:
            # make_latent() encodes the image through a dummy class instance, returns the resulting fitted image representation
            # if get_image_tensor() is not implemented, then the returned 'latent' tensor is just the tensorized pil image
            latent = img.make_latent(self.pil_image)
            logger.debug(type(latent))  # EMAParametersDict
            logger.debug(type(self.comp))  # torch.Tensor
            with torch.no_grad():
-                self.comp.set_(latent.clone())
+                if type(latent) == type(self.comp):
+                    self.comp.set_(latent.clone())
+                # else:
+
            self.has_latent = True
+
        l1 = super().get_loss(img.get_latent_tensor(), img) / 2
        l2 = self.direct_loss.get_loss(input, img) / 10
        return l1 + l2


######################################################################

+# Let's just make a DIP latent loss from scratch.
+
+
+# The issue we're resolving here is that by inheriting from MSELoss,
+# I can't easily set the comp to the parameters of the image model.
+
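# Illustrative aside (not part of the patch): a minimal sketch of the problem the
# comment above describes, using hypothetical shapes and a hypothetical parameter name.
import torch

comp = torch.zeros(1, 3, 64, 64)         # an MSELoss-style comp is a plain tensor buffer,
comp.set_(torch.rand(1, 3, 64, 64))      # so it can be overwritten in place

dip_latent = {"skip.0.weight": torch.rand(192, 32, 3, 3)}  # a DIP "latent" is a dict of net parameters
# comp.set_(dip_latent)                  # TypeError: there is no single tensor to copy,
#                                        # hence LatentLossDIP registers comp as a module instead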
+from pytti.LossAug.BaseLossClass import Loss
+from pytti.image_models.ema import EMAImage, EMAParametersDict
+from pytti.rotoscoper import Rotoscoper
+
+import deep_image_prior
+import deep_image_prior.models
+from deep_image_prior.models import (
+    get_hq_skip_net,
+    get_non_offset_params,
+    get_offset_params,
+)

-class LatentLossGeneric(LatentLoss):
-    # class LatentLoss(MSELoss):
+
+def load_dip(input_depth, num_scales, offset_type, offset_groups, device):
+    dip_net = get_hq_skip_net(
+        input_depth,
+        skip_n33d=192,
+        skip_n33u=192,
+        skip_n11=4,
+        num_scales=num_scales,
+        offset_type=offset_type,
+        offset_groups=offset_groups,
+    ).to(device)
+
+    return dip_net
+
+
+class LatentLossDIP(Loss):
    @torch.no_grad()
    def __init__(
        self,
@@ -134,29 +179,109 @@ def __init__(
        stop=-math.inf,
        name="direct target loss",
        image_shape=None,
+        device=None,
    ):
-        super().__init__(comp, weight, stop, name, image_shape)
+        ##################################################################
+        super().__init__(weight, stop, name, device)
+        if image_shape is None:
+            raise ValueError("image_shape must be provided")
+            # height, width = comp.shape[-2:]
+            # image_shape = (width, height)
+        self.image_shape = image_shape
+        self.register_buffer("mask", torch.ones(1, 1, 1, 1, device=self.device))
+        self.use_mask = False
+        ##################################################################
        self.pil_image = None
        self.has_latent = False
-        w, h = image_shape
-        self.direct_loss = MSELoss(
-            TF.resize(comp.clone(), (h, w)), weight, stop, name, image_shape
+        logger.debug(type(comp))  # inits to image tensor
+        if comp is None:
+            comp = self.default_comp()
+        if isinstance(comp, EMAParametersDict):
+            logger.debug("initializing loss from latent")
+            self.register_module("comp", comp)
+            self.has_latent = True
+        else:
+            w, h = image_shape
+            comp_adjusted = TF.resize(comp.clone(), (h, w))
+            # try:
+            #     comp_adjusted = TF.resize(comp.clone(), (h, w))
+            # except:
+            #     # comp_adjusted = comp.clone()
+            #     # Need to convert the latent to its image form
+            #     comp_adjusted = img_model.decode_tensor(comp.clone())
+            self.direct_loss = MSELoss(comp_adjusted, weight, stop, name, image_shape)
+
+        ##################################################################
+
+        logger.debug(type(comp))
+
+    @classmethod
+    def default_comp(*args, **kargs):
+        logger.debug("default_comp")
+        device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu"
+        net = load_dip(
+            input_depth=32,
+            num_scales=7,
+            offset_type="none",
+            offset_groups=4,
+            device=device,
        )
+        return EMAParametersDict(z=net, decay=0.99, device=device)
+
+    ###################################################################################

    @torch.no_grad()
    def set_comp(self, pil_image, device=DEVICE):
+        """
+        sets the DIRECT loss anchor "comp" to the tensorized image.
+        """
+        logger.debug(type(pil_image))
        self.pil_image = pil_image
        self.has_latent = False
-        self.direct_loss.set_comp(
-            pil_image.resize(self.image_shape, Image.LANCZOS)
+        im_resized = pil_image.resize(
+            self.image_shape, Image.LANCZOS
        )  # to do: ResizeRight
+        # self.direct_loss.set_comp(im_resized)
+
+        im_tensor = (
+            TF.to_tensor(pil_image)
+            .unsqueeze(0)
+            .to(device, memory_format=torch.channels_last)
+        )
+
+        if hasattr(self, "direct_loss"):
+            self.direct_loss.set_comp(im_tensor)
+        else:
+            self.direct_loss = MSELoss(
+                im_tensor, self.weight, self.stop, self.name, self.image_shape
+            )
+        # self.direct_loss.set_comp(im_resized)
+
+    @classmethod
+    def convert_input(cls, input, img):
+        """
+        Converts the input image tensor to the image representation of the image model.
+        E.g. if img is VQGAN, then the input tensor is converted to the latent representation.
+        """
+        logger.debug(type(input))  # pretty sure this is gonna be tensor
+        # return input  # this is the default MSE loss version
+        return img.make_latent(input)

    @classmethod
    @vram_usage_mode("Latent Image Loss")
    @torch.no_grad()
    def TargetImage(
-        cls, prompt_string, image_shape, pil_image=None, is_path=False, device=DEVICE
+        cls,
+        prompt_string,
+        image_shape,
+        pil_image=None,
+        is_path=False,
+        device=DEVICE,
+        img_model=None,
    ):
+        logger.debug(
+            type(pil_image)
+        )  # None. emitted prior to do_run:559 but after parse_scenes:122. Why even use this constructor if no pil_image?
        text, weight, stop = parse(
            prompt_string, r"(?<!^http)(?<!s):|:(?!/)", ["", "1", "-inf"]
        )
@@ -168,24 +293,69 @@ def TargetImage(
        comp = (
            MSELoss.make_comp(pil_image)
            if pil_image is not None
-            else torch.zeros(1, 1, 1, 1, device=device)
+            # else torch.zeros(1, 1, 1, 1, device=device)
+            else cls.default_comp(img_model=img_model)
        )
        out = cls(comp, weight, stop, text + " (latent)", image_shape)
        if pil_image is not None:
            out.set_comp(pil_image)
-            out.set_mask(mask)
+        if (
+            mask
+        ):  # this will break if there's no pil_image since the direct_loss won't be initialized
+            out.set_mask(mask)
        return out
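# Illustrative aside (not part of the patch): a rough usage sketch of the updated
# constructor; the prompt string and file name below are hypothetical placeholders.
from PIL import Image

init_pil = Image.open("init.png")  # hypothetical init image
dip_loss = LatentLossDIP.TargetImage(
    "init image:1",          # prompt_string, parsed above into text / weight / stop
    image_shape=(512, 512),
    pil_image=init_pil,      # if None, comp falls back to cls.default_comp(img_model=...)
    img_model=None,
)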

    def set_mask(self, mask, inverted=False):
        self.direct_loss.set_mask(mask, inverted)
-        super().set_mask(mask, inverted)
+        # super().set_mask(mask, inverted)
+        # if device is None:
+        device = self.device
+        if isinstance(mask, str) and mask != "":
+            if mask[0] == "-":
+                mask = mask[1:]
+                inverted = True
+            if mask.strip()[-4:] == ".mp4":
+                r = Rotoscoper(mask, self)
+                r.update(0)
+                return
+            mask = Image.open(fetch(mask)).convert("L")
+        if isinstance(mask, Image.Image):
+            with vram_usage_mode("Masks"):
+                mask = (
+                    TF.to_tensor(mask)
+                    .unsqueeze(0)
+                    .to(device, memory_format=torch.channels_last)
+                )
+        if mask not in ["", None]:
+            self.mask.set_(mask if not inverted else (1 - mask))
+        self.use_mask = mask not in ["", None]
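# Descriptive note (not part of the patch): the branches above accept a mask given as
#   - a path or URL string to a grayscale image, which is fetched and tensorized,
#   - a string with a leading "-", which inverts the mask,
#   - a string ending in ".mp4", which is handed to a Rotoscoper for animated masks,
#   - or a PIL Image, which is tensorized directly; the final membership checks treat
#     an empty string or None as "no mask".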

    def get_loss(self, input, img):
+        logger.debug(type(input))  # Tensor
+        logger.debug(input.shape)  # this is an image tensor
+        logger.debug(type(img))  # DIPImage
+        logger.debug(type(self.comp))  # EMAParametersDict
+        # logger.debug(
+        #     self.comp.shape
+        # )  # [1 1 1 1] -> from target image constructor when no input image provided
+
+        # why is the latent comp only set here? why not in the __init__ and set_comp?
        if not self.has_latent:
+            raise RuntimeError("comp should already hold a latent (EMAParametersDict) here")
+            # make_latent() encodes the image through a dummy class instance, returns the resulting fitted image representation
+            # if get_image_tensor() is not implemented, then the returned 'latent' tensor is just the tensorized pil image
            latent = img.make_latent(self.pil_image)
+            logger.debug(type(latent))  # EMAParametersDict
+            logger.debug(type(self.comp))  # torch.Tensor
            with torch.no_grad():
-                self.comp.set_(latent.clone())
+                if type(latent) == type(self.comp):
+                    self.comp.set_(latent.clone())
+                # else:
+
            self.has_latent = True
+
+        estimated_image = self.comp.get_image_tensor()
+
        l1 = super().get_loss(img.get_latent_tensor(), img) / 2
        l2 = self.direct_loss.get_loss(input, img) / 10
        return l1 + l2
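Both get_loss implementations above combine a latent-space term and a pixel-space term with the same fixed weights (1/2 and 1/10). A stand-alone sketch of that weighting, with plain MSE calls standing in for the pytti loss objects:

import torch.nn.functional as F

def combined_latent_loss(latent_pred, latent_target, image_pred, image_target):
    l1 = F.mse_loss(latent_pred, latent_target) / 2   # mirrors super().get_loss(...) / 2
    l2 = F.mse_loss(image_pred, image_target) / 10    # mirrors self.direct_loss.get_loss(...) / 10
    return l1 + l2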