Skip to content

Commit 1552014

Browse files
committed
h8 this. saving progress, but I think I need to just backtrack and simplify how losses and image_models work first, then come back to this afterwards.
1 parent 02a9831 commit 1552014

File tree

6 files changed

+227
-28
lines changed

6 files changed

+227
-28
lines changed

src/pytti/LossAug/LatentLossClass.py

+190-20
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,20 @@ def __init__(
2626
self.pil_image = None
2727
self.has_latent = False
2828
w, h = image_shape
29-
try:
30-
comp_adjusted = TF.resize(comp.clone(), (h, w))
31-
except:
32-
# comp_adjusted = comp.clone()
33-
# Need to convert the latent to its image form
34-
comp_adjusted = img_model.decode_tensor(comp.clone())
29+
comp_adjusted = TF.resize(comp.clone(), (h, w))
30+
# try:
31+
# comp_adjusted = TF.resize(comp.clone(), (h, w))
32+
# except:
33+
# # comp_adjusted = comp.clone()
34+
# # Need to convert the latent to its image form
35+
# comp_adjusted = img_model.decode_tensor(comp.clone())
3536
self.direct_loss = MSELoss(comp_adjusted, weight, stop, name, image_shape)
3637

3738
@torch.no_grad()
3839
def set_comp(self, pil_image, device=DEVICE):
40+
"""
41+
sets the DIRECT loss anchor "comp" to the tensorized image.
42+
"""
3943
logger.debug(type(pil_image))
4044
self.pil_image = pil_image
4145
self.has_latent = False
@@ -47,6 +51,10 @@ def set_comp(self, pil_image, device=DEVICE):
4751

4852
@classmethod
4953
def convert_input(cls, input, img):
54+
"""
55+
Converts the input image tensor to the image representation of the image model.
56+
E.g. if img is VQGAN, then the input tensor is converted to the latent representation.
57+
"""
5058
logger.debug(type(input)) # pretty sure this is gonna be tensor
5159
# return input # this is the default MSE loss version
5260
return img.make_latent(input)
@@ -107,25 +115,62 @@ def get_loss(self, input, img):
107115
logger.debug(
108116
self.comp.shape
109117
) # [1 1 1 1] -> from target image constructor when no input image provided
118+
119+
# why is the latent comp only set here? why not in the __init__ and set_comp?
110120
if not self.has_latent:
111121
# make_latent() encodes the image through a dummy class instance, returns the resulting fitted image representation
112122
# if get_image_tensor() is not implemented, then the returned 'latent' tensor is just the tensorized pil image
113123
latent = img.make_latent(self.pil_image)
114124
logger.debug(type(latent)) # EMAParametersDict
115125
logger.debug(type(self.comp)) # torch.Tensor
116126
with torch.no_grad():
117-
self.comp.set_(latent.clone())
127+
if type(latent) == type(self.comp):
128+
self.comp.set_(latent.clone())
129+
# else:
130+
118131
self.has_latent = True
132+
119133
l1 = super().get_loss(img.get_latent_tensor(), img) / 2
120134
l2 = self.direct_loss.get_loss(input, img) / 10
121135
return l1 + l2
122136

123137

124138
######################################################################
125139

140+
# fuck it, let's just make a dip latent loss from scratch.
141+
142+
143+
# The issue we're resolving here is that by inheriting from the MSELoss,
144+
# I can't easily set the comp to the parameters of the image model.
145+
146+
from pytti.LossAug.BaseLossClass import Loss
147+
from pytti.image_models.ema import EMAImage, EMAParametersDict
148+
from pytti.rotoscoper import Rotoscoper
149+
150+
import deep_image_prior
151+
import deep_image_prior.models
152+
from deep_image_prior.models import (
153+
get_hq_skip_net,
154+
get_non_offset_params,
155+
get_offset_params,
156+
)
126157

127-
class LatentLossGeneric(LatentLoss):
128-
# class LatentLoss(MSELoss):
158+
159+
def load_dip(input_depth, num_scales, offset_type, offset_groups, device):
160+
dip_net = get_hq_skip_net(
161+
input_depth,
162+
skip_n33d=192,
163+
skip_n33u=192,
164+
skip_n11=4,
165+
num_scales=num_scales,
166+
offset_type=offset_type,
167+
offset_groups=offset_groups,
168+
).to(device)
169+
170+
return dip_net
171+
172+
173+
class LatentLossDIP(Loss):
129174
@torch.no_grad()
130175
def __init__(
131176
self,
@@ -134,29 +179,109 @@ def __init__(
134179
stop=-math.inf,
135180
name="direct target loss",
136181
image_shape=None,
182+
device=None,
137183
):
138-
super().__init__(comp, weight, stop, name, image_shape)
184+
##################################################################
185+
super().__init__(weight, stop, name, device)
186+
if image_shape is None:
187+
raise
188+
# height, width = comp.shape[-2:]
189+
# image_shape = (width, height)
190+
self.image_shape = image_shape
191+
self.register_buffer("mask", torch.ones(1, 1, 1, 1, device=self.device))
192+
self.use_mask = False
193+
##################################################################
139194
self.pil_image = None
140195
self.has_latent = False
141-
w, h = image_shape
142-
self.direct_loss = MSELoss(
143-
TF.resize(comp.clone(), (h, w)), weight, stop, name, image_shape
196+
logger.debug(type(comp)) # inits to image tensor
197+
if comp is None:
198+
comp = self.default_comp()
199+
if isinstance(comp, EMAParametersDict):
200+
logger.debug("initializing loss from latent")
201+
self.register_module("comp", comp)
202+
self.has_latent = True
203+
else:
204+
w, h = image_shape
205+
comp_adjusted = TF.resize(comp.clone(), (h, w))
206+
# try:
207+
# comp_adjusted = TF.resize(comp.clone(), (h, w))
208+
# except:
209+
# # comp_adjusted = comp.clone()
210+
# # Need to convert the latent to its image form
211+
# comp_adjusted = img_model.decode_tensor(comp.clone())
212+
self.direct_loss = MSELoss(comp_adjusted, weight, stop, name, image_shape)
213+
214+
##################################################################
215+
216+
logger.debug(type(comp))
217+
218+
@classmethod
219+
def default_comp(*args, **kargs):
220+
logger.debug("default_comp")
221+
device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu"
222+
net = load_dip(
223+
input_depth=32,
224+
num_scales=7,
225+
offset_type="none",
226+
offset_groups=4,
227+
device=device,
144228
)
229+
return EMAParametersDict(z=net, decay=0.99, device=device)
230+
231+
###################################################################################
145232

146233
@torch.no_grad()
147234
def set_comp(self, pil_image, device=DEVICE):
235+
"""
236+
sets the DIRECT loss anchor "comp" to the tensorized image.
237+
"""
238+
logger.debug(type(pil_image))
148239
self.pil_image = pil_image
149240
self.has_latent = False
150-
self.direct_loss.set_comp(
151-
pil_image.resize(self.image_shape, Image.LANCZOS)
241+
im_resized = pil_image.resize(
242+
self.image_shape, Image.LANCZOS
152243
) # to do: ResizeRight
244+
# self.direct_loss.set_comp(im_resized)
245+
246+
im_tensor = (
247+
TF.to_tensor(pil_image)
248+
.unsqueeze(0)
249+
.to(device, memory_format=torch.channels_last)
250+
)
251+
252+
if hasattr(self, "direct_loss"):
253+
self.direct_loss.set_comp(im_tensor)
254+
else:
255+
self.direct_loss = MSELoss(
256+
im_tensor, self.weight, self.stop, self.name, self.image_shape
257+
)
258+
# self.direct_loss.set_comp(im_resized)
259+
260+
@classmethod
261+
def convert_input(cls, input, img):
262+
"""
263+
Converts the input image tensor to the image representation of the image model.
264+
E.g. if img is VQGAN, then the input tensor is converted to the latent representation.
265+
"""
266+
logger.debug(type(input)) # pretty sure this is gonna be tensor
267+
# return input # this is the default MSE loss version
268+
return img.make_latent(input)
153269

154270
@classmethod
155271
@vram_usage_mode("Latent Image Loss")
156272
@torch.no_grad()
157273
def TargetImage(
158-
cls, prompt_string, image_shape, pil_image=None, is_path=False, device=DEVICE
274+
cls,
275+
prompt_string,
276+
image_shape,
277+
pil_image=None,
278+
is_path=False,
279+
device=DEVICE,
280+
img_model=None,
159281
):
282+
logger.debug(
283+
type(pil_image)
284+
) # None. emitted prior to do_run:559 but after parse_scenes:122. Why even use this constructor if no pil_image?
160285
text, weight, stop = parse(
161286
prompt_string, r"(?<!^http)(?<!s):|:(?!/)", ["", "1", "-inf"]
162287
)
@@ -168,24 +293,69 @@ def TargetImage(
168293
comp = (
169294
MSELoss.make_comp(pil_image)
170295
if pil_image is not None
171-
else torch.zeros(1, 1, 1, 1, device=device)
296+
# else torch.zeros(1, 1, 1, 1, device=device)
297+
else cls.default_comp(img_model=img_model)
172298
)
173299
out = cls(comp, weight, stop, text + " (latent)", image_shape)
174300
if pil_image is not None:
175301
out.set_comp(pil_image)
176-
out.set_mask(mask)
302+
if (
303+
mask
304+
): # this will break if there's no pil_image since the direct_loss won't be initialized
305+
out.set_mask(mask)
177306
return out
178307

179308
def set_mask(self, mask, inverted=False):
180309
self.direct_loss.set_mask(mask, inverted)
181-
super().set_mask(mask, inverted)
310+
# super().set_mask(mask, inverted)
311+
# if device is None:
312+
device = self.device
313+
if isinstance(mask, str) and mask != "":
314+
if mask[0] == "-":
315+
mask = mask[1:]
316+
inverted = True
317+
if mask.strip()[-4:] == ".mp4":
318+
r = Rotoscoper(mask, self)
319+
r.update(0)
320+
return
321+
mask = Image.open(fetch(mask)).convert("L")
322+
if isinstance(mask, Image.Image):
323+
with vram_usage_mode("Masks"):
324+
mask = (
325+
TF.to_tensor(mask)
326+
.unsqueeze(0)
327+
.to(device, memory_format=torch.channels_last)
328+
)
329+
if mask not in ["", None]:
330+
self.mask.set_(mask if not inverted else (1 - mask))
331+
self.use_mask = mask not in ["", None]
182332

183333
def get_loss(self, input, img):
334+
logger.debug(type(input)) # Tensor
335+
logger.debug(input.shape) # this is an image tensor
336+
logger.debug(type(img)) # DIPImage
337+
logger.debug(type(self.comp)) # EMAParametersDict
338+
# logger.debug(
339+
# self.comp.shape
340+
# ) # [1 1 1 1] -> from target image constructor when no input image provided
341+
342+
# why is the latent comp only set here? why not in the __init__ and set_comp?
184343
if not self.has_latent:
344+
raise
345+
# make_latent() encodes the image through a dummy class instance, returns the resulting fitted image representation
346+
# if get_image_tensor() is not implemented, then the returned 'latent' tensor is just the tensorized pil image
185347
latent = img.make_latent(self.pil_image)
348+
logger.debug(type(latent)) # EMAParametersDict
349+
logger.debug(type(self.comp)) # torch.Tensor
186350
with torch.no_grad():
187-
self.comp.set_(latent.clone())
351+
if type(latent) == type(self.comp):
352+
self.comp.set_(latent.clone())
353+
# else:
354+
188355
self.has_latent = True
356+
357+
estimated_image = self.comp.get_image_tensor()
358+
189359
l1 = super().get_loss(img.get_latent_tensor(), img) / 2
190360
l2 = self.direct_loss.get_loss(input, img) / 10
191361
return l1 + l2

src/pytti/LossAug/LossOrchestratorClass.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def build_loss(weight_name, weight, name, img, pil_target):
3333
f"{weight_name} {name}:{weight}",
3434
img.image_shape,
3535
pil_target,
36-
img_model=img, # type(img)
36+
# img_model=img, # type(img)
3737
)
3838
out.set_enabled(pil_target is not None)
3939
return out

src/pytti/LossAug/MSELossClass.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343

4444
@classmethod
4545
def default_comp(cls, img_model=None, *args, **kargs):
46-
# logger.debug("default_comp")
46+
logger.debug("default_comp")
4747
# logger.debug(type(img_model))
4848
# device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu"
4949
# if img_model is None:

src/pytti/image_models/deep_image_prior.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,10 @@ def __init__(
131131
# )
132132

133133
# def get_image_tensor(self):
134-
def decode_tensor(self):
134+
def decode_tensor(self, input_latent=None):
135+
"""
136+
Generates the image tensor from the attached DIP representation
137+
"""
135138
with torch.cuda.amp.autocast():
136139
# out = net(net_input_noised * input_scale).float()
137140
# logger.debug(self.net)
@@ -199,9 +202,12 @@ def encode_random(self):
199202

200203
@classmethod
201204
def get_preferred_loss(cls):
202-
from pytti.LossAug.LatentLossClass import LatentLoss
205+
from pytti.LossAug.LatentLossClass import LatentLoss, LatentLossDIP
206+
207+
return LatentLossDIP # LatentLoss
203208

204-
return LatentLoss
209+
# it'll be stupid complicated, but I could put a closure in here...
210+
# yeah no fuck that. I'm not adding complexity to enable deep image. I need to simplify how loss stuff works FIRST.
205211

206212
def make_latent(self, pil_image):
207213
"""
@@ -233,7 +239,7 @@ def default_comp(*args, **kargs):
233239

234240
def encode_image(self, pil_image, device="cuda"):
235241
"""
236-
Encodes the image into a tensor.
242+
Fits the attached DIP model representation to the input pil_image.
237243
238244
:param pil_image: The image to encode
239245
:param smart_encode: If True, the pallet will be optimized to match the image, defaults to True
@@ -262,3 +268,19 @@ def encode_image(self, pil_image, device="cuda"):
262268
)
263269
# why is there a magic number here?
264270
guide.run_steps(self.image_encode_steps, [], [], [mse])
271+
272+
273+
##############################################################################################################################
274+
275+
# round three
276+
277+
# gonna implement this the way that makes sense to me, and then see if I can't square-peg-round-hole it
278+
class DipSimpleLatentLoss(nn.Module):
279+
def __init__(
280+
self,
281+
net,
282+
image_shape,
283+
pil_image=None,
284+
):
285+
super().__init__()
286+
self.net = net

src/pytti/image_models/ema.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,14 @@ def average(self):
143143
def set_(self, d):
144144
if isinstance(d, torch.Tensor):
145145
logger.debug(self._container)
146+
logger.debug(d.shape)
146147

147148
d_ = d
148149
if isinstance(d, EMAParametersDict):
149150
d_ = d._container
150151
logger.debug(type(d_))
151-
logger.debug(d_.shape) # fuck it
152+
# logger.debug(d_.shape) # fuck it
153+
logger.debug(type(self._container))
152154
for k, v in d_.items():
153155
self._container[k].set_(v)
154156
# self._container[k].tensor.set_(v)

src/pytti/workhorse.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,12 @@ def do_run():
394394
loss_augs.extend(
395395
type(img)
396396
.get_preferred_loss()
397-
.TargetImage(p.strip(), img.image_shape, is_path=True, img_model=type(img))
397+
.TargetImage(
398+
p.strip(),
399+
img.image_shape,
400+
is_path=True,
401+
# img_model=type(img)
402+
)
398403
for p in params.direct_image_prompts.split("|")
399404
if p.strip()
400405
)

0 commit comments

Comments
 (0)