
Cutouts #184


Open · wants to merge 15 commits into base: main
11 changes: 0 additions & 11 deletions src/pytti/Image/__init__.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/pytti/ImageGuide.py
@@ -20,8 +20,8 @@
vram_usage_mode,
)
from pytti.AudioParse import SpectralAudioParser
-from pytti.Image.differentiable_image import DifferentiableImage
-from pytti.Image.PixelImage import PixelImage
+from pytti.image_models.differentiable_image import DifferentiableImage
+from pytti.image_models.pixel import PixelImage
from pytti.Notebook import tqdm, make_hbox

# from pytti.rotoscoper import update_rotoscopers
2 changes: 1 addition & 1 deletion src/pytti/LossAug/LossOrchestratorClass.py
@@ -2,7 +2,7 @@
from loguru import logger
from PIL import Image

-from pytti.Image import PixelImage
+from pytti.image_models import PixelImage

# from pytti.LossAug import build_loss
from pytti.LossAug import TVLoss, HSVLoss, OpticalFlowLoss, TargetFlowLoss
116 changes: 39 additions & 77 deletions src/pytti/Perceptor/Embedder.py
@@ -4,13 +4,19 @@
from pytti import DEVICE, format_input, cat_with_pad, format_module, normalize

# from pytti.ImageGuide import DirectImageGuide
-from pytti.Image import DifferentiableImage
+from pytti.image_models import DifferentiableImage

import torch
from torch import nn
from torch.nn import functional as F

import kornia.augmentation as K

+# import .cutouts
+# import .cutouts as cutouts
+# import cutouts
+
+from .cutouts import augs as cutouts_augs
+from .cutouts import samplers as cutouts_samplers

PADDING_MODES = {
"mirror": "reflect",
@@ -43,19 +49,7 @@ def __init__(
self.cut_sizes = [p.visual.input_resolution for p in perceptors]
self.cutn = cutn
self.noise_fac = noise_fac
-        self.augs = nn.Sequential(
-            K.RandomHorizontalFlip(p=0.3),
-            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
-            K.RandomPerspective(
-                0.2,
-                p=0.4,
-            ),
-            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
-            K.RandomErasing(
-                scale=(0.1, 0.4), ratio=(0.3, 1 / 0.3), same_on_batch=False, p=0.7
-            ),
-            nn.Identity(),
-        )
+        self.augs = cutouts_augs.pytti_classic()
self.input_axes = ("n", "s", "y", "x")
self.output_axes = ("c", "n", "i")
self.perceptors = perceptors
@@ -64,69 +58,34 @@ def __init__(
self.border_mode = border_mode

def make_cutouts(
-        self, input: torch.Tensor, side_x, side_y, cut_size, device=DEVICE
+        self,
+        input: torch.Tensor,
+        side_x,
+        side_y,
+        cut_size,
+        ####
+        # padding,
+        # cutn,
+        # cut_pow,
+        # border_mode,
+        # augs,
+        # noise_fac,
+        ####
+        device=DEVICE,
) -> Tuple[list, list, list]:
-        min_size = min(side_x, side_y, cut_size)
-        max_size = min(side_x, side_y)
-        paddingx = min(round(side_x * self.padding), side_x)
-        paddingy = min(round(side_y * self.padding), side_y)
-        cutouts = []
-        offsets = []
-        sizes = []
-        for _ in range(self.cutn):
-            # mean is 0.8
-            # variance is 0.3
-            size = int(
-                max_size
-                * (
-                    torch.zeros(
-                        1,
-                    )
-                    .normal_(mean=0.8, std=0.3)
-                    .clip(cut_size / max_size, 1.0)
-                    ** self.cut_pow
-                )
-            )
-            offsetx_max = side_x - size + 1
-            offsety_max = side_y - size + 1
-            if self.border_mode == "clamp":
-                offsetx = torch.clamp(
-                    (torch.rand([]) * (offsetx_max + 2 * paddingx) - paddingx)
-                    .floor()
-                    .int(),
-                    0,
-                    offsetx_max,
-                )
-                offsety = torch.clamp(
-                    (torch.rand([]) * (offsety_max + 2 * paddingy) - paddingy)
-                    .floor()
-                    .int(),
-                    0,
-                    offsety_max,
-                )
-                cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
-            else:
-                px = min(size, paddingx)
-                py = min(size, paddingy)
-                offsetx = (torch.rand([]) * (offsetx_max + 2 * px) - px).floor().int()
-                offsety = (torch.rand([]) * (offsety_max + 2 * py) - py).floor().int()
-                cutout = input[
-                    :,
-                    :,
-                    paddingy + offsety : paddingy + offsety + size,
-                    paddingx + offsetx : paddingx + offsetx + size,
-                ]
-            cutouts.append(F.adaptive_avg_pool2d(cutout, cut_size))
-            offsets.append(
-                torch.as_tensor([[offsetx / side_x, offsety / side_y]]).to(device)
-            )
-            sizes.append(torch.as_tensor([[size / side_x, size / side_y]]).to(device))
-        cutouts = self.augs(torch.cat(cutouts))
-        offsets = torch.cat(offsets)
-        sizes = torch.cat(sizes)
-        if self.noise_fac:
-            facs = cutouts.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
-            cutouts.add_(facs * torch.randn_like(cutouts))
+        cutouts, offsets, sizes = cutouts_samplers.pytti_classic(
+            input=input,
+            side_x=side_x,
+            side_y=side_y,
+            cut_size=cut_size,
+            padding=self.padding,
+            cutn=self.cutn,
+            cut_pow=self.cut_pow,
+            border_mode=self.border_mode,
+            augs=self.augs,
+            noise_fac=self.noise_fac,
+            device=DEVICE,
+        )
return cutouts, offsets, sizes

def forward(
@@ -163,12 +122,15 @@ def forward(
(paddingx, paddingx, paddingy, paddingy),
mode=PADDING_MODES[self.border_mode],
)

+        # to do: add option to use a single perceptor per step and pick this perceptor randomly or in sequence
for cut_size, perceptor in zip(self.cut_sizes, perceptors):
cutouts, offsets, sizes = self.make_cutouts(input, side_x, side_y, cut_size)
clip_in = normalize(cutouts)
image_embeds.append(perceptor.encode_image(clip_in).float().unsqueeze(0))
all_offsets.append(offsets)
all_sizes.append(sizes)
+            # What does pytti do with offsets and sizes? and why doesn't dango need to return them?
return (
cat_with_pad(image_embeds),
torch.stack(all_offsets),
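Reviewer note: the new src/pytti/Perceptor/cutouts/samplers.py that make_cutouts now delegates to is not shown in this diff. Judging from the call site above and the inline loop it replaces, pytti_classic is presumably just that loop lifted into a free function; a hedged reconstruction (not the file's actual contents) follows:

# samplers.py -- hypothetical reconstruction from the removed Embedder code
import torch
from torch.nn import functional as F


def pytti_classic(
    input: torch.Tensor,
    side_x: int,
    side_y: int,
    cut_size: int,
    padding: float,
    cutn: int,
    cut_pow: float,
    border_mode: str,
    augs,
    noise_fac: float,
    device,
):
    """Sample `cutn` random square crops, resize each to `cut_size`, and augment."""
    max_size = min(side_x, side_y)
    paddingx = min(round(side_x * padding), side_x)
    paddingy = min(round(side_y * padding), side_y)
    cutouts, offsets, sizes = [], [], []
    for _ in range(cutn):
        # crop side length: clipped normal (mean 0.8, std 0.3) of the short side
        size = int(
            max_size
            * (
                torch.zeros(1).normal_(mean=0.8, std=0.3).clip(cut_size / max_size, 1.0)
                ** cut_pow
            )
        )
        offsetx_max = side_x - size + 1
        offsety_max = side_y - size + 1
        if border_mode == "clamp":
            # clamp crop origins so every crop stays inside the unpadded image
            offsetx = torch.clamp(
                (torch.rand([]) * (offsetx_max + 2 * paddingx) - paddingx).floor().int(),
                0,
                offsetx_max,
            )
            offsety = torch.clamp(
                (torch.rand([]) * (offsety_max + 2 * paddingy) - paddingy).floor().int(),
                0,
                offsety_max,
            )
            cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
        else:
            # the caller has already padded `input`, so offsets may go negative
            px, py = min(size, paddingx), min(size, paddingy)
            offsetx = (torch.rand([]) * (offsetx_max + 2 * px) - px).floor().int()
            offsety = (torch.rand([]) * (offsety_max + 2 * py) - py).floor().int()
            cutout = input[
                :,
                :,
                paddingy + offsety : paddingy + offsety + size,
                paddingx + offsetx : paddingx + offsetx + size,
            ]
        cutouts.append(F.adaptive_avg_pool2d(cutout, cut_size))
        offsets.append(torch.as_tensor([[offsetx / side_x, offsety / side_y]]).to(device))
        sizes.append(torch.as_tensor([[size / side_x, size / side_y]]).to(device))
    cutouts = augs(torch.cat(cutouts))
    if noise_fac:
        facs = cutouts.new_empty([cutn, 1, 1, 1]).uniform_(0, noise_fac)
        cutouts.add_(facs * torch.randn_like(cutouts))
    return cutouts, torch.cat(offsets), torch.cat(sizes)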
2 changes: 1 addition & 1 deletion src/pytti/Perceptor/Prompt.py
@@ -26,7 +26,7 @@
parametric_eval,
vram_usage_mode,
)
-from pytti.Image import RGBImage
+from pytti.image_models import RGBImage

# from pytti.Notebook import Rotoscoper
from pytti.rotoscoper import Rotoscoper
3 changes: 3 additions & 0 deletions src/pytti/Perceptor/cutouts/__init__.py
@@ -0,0 +1,3 @@
from .dango import MakeCutouts

test = MakeCutouts(1)  # note: runs at import time; reads like a leftover smoke test
42 changes: 42 additions & 0 deletions src/pytti/Perceptor/cutouts/augs.py
@@ -0,0 +1,42 @@
import kornia.augmentation as K
import torch
from torch import nn
from torchvision import transforms as T


def pytti_classic():
return nn.Sequential(
K.RandomHorizontalFlip(p=0.3),
K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
K.RandomPerspective(
0.2,
p=0.4,
),
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
K.RandomErasing(
scale=(0.1, 0.4), ratio=(0.3, 1 / 0.3), same_on_batch=False, p=0.7
),
nn.Identity(),
)


def dango():
return T.Compose(
[
# T.RandomHorizontalFlip(p=0.5),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomAffine(
degrees=0,
translate=(0.05, 0.05),
# scale=(0.9,0.95),
fill=-1,
interpolation=T.InterpolationMode.BILINEAR,
),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
# T.RandomPerspective(p=1, interpolation = T.InterpolationMode.BILINEAR, fill=-1,distortion_scale=0.2),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomGrayscale(p=0.1),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05),
]
)
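Reviewer note: a quick usage sketch for these two factories (tensor shapes and value ranges here are my assumptions, not part of the diff). Both return a callable that maps a batch of cutouts [N, 3, H, W] to an augmented batch of the same shape:

import torch
from pytti.Perceptor.cutouts import augs as cutouts_augs

batch = torch.rand(4, 3, 224, 224)      # four RGB cutouts, assumed range [0, 1]
classic = cutouts_augs.pytti_classic()  # kornia pipeline (nn.Sequential)
dango_augs = cutouts_augs.dango()       # torchvision pipeline (T.Compose)

out = classic(batch)                    # same shape, randomly augmented
out2 = dango_augs(batch * 2 - 1)        # fill=-1 in RandomAffine suggests [-1, 1] inputs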
101 changes: 101 additions & 0 deletions src/pytti/Perceptor/cutouts/dango.py
@@ -0,0 +1,101 @@
# via https://github.com/multimodalart/majesty-diffusion/blob/main/latent.ipynb

# !pip install resize-right
# TO DO: add resize-right to setup instructions and notebook
from resize_right import resize

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision import transforms as T
from torchvision.transforms import functional as TF

from . import augs as cutouts_augs

padargs = {"mode": "constant", "value": -1}


class MakeCutouts(nn.Module):
def __init__(
self,
cut_size,
Overview=4,
WholeCrop=0,
WC_Allowance=10,
WC_Grey_P=0.2,
InnerCrop=0,
IC_Size_Pow=0.5,
IC_Grey_P=0.2,
aug=True,
cutout_debug=False,
):
super().__init__()
self.cut_size = cut_size
self.Overview = Overview
self.WholeCrop = WholeCrop
self.WC_Allowance = WC_Allowance
self.WC_Grey_P = WC_Grey_P
self.InnerCrop = InnerCrop
self.IC_Size_Pow = IC_Size_Pow
self.IC_Grey_P = IC_Grey_P
        self.augs = cutouts_augs.dango()  # dango is a factory; call it to get the T.Compose pipeline, otherwise self.augs(cutout) raises TypeError
self._aug = aug
self.cutout_debug = cutout_debug

def forward(self, input):
gray = transforms.Grayscale(
3
) # this is possibly a performance improvement? 1 channel instead of 3. but also means we can't use color augs...
sideY, sideX = input.shape[2:4]
max_size = min(sideX, sideY)
min_size = min(sideX, sideY, self.cut_size)
l_size = max(sideX, sideY)
output_shape = [input.shape[0], 3, self.cut_size, self.cut_size]
output_shape_2 = [input.shape[0], 3, self.cut_size + 2, self.cut_size + 2]
pad_input = F.pad(
input,
(
(sideY - max_size) // 2 + round(max_size * 0.055),
(sideY - max_size) // 2 + round(max_size * 0.055),
(sideX - max_size) // 2 + round(max_size * 0.055),
(sideX - max_size) // 2 + round(max_size * 0.055),
),
**padargs
)
cutouts_list = []

if self.Overview > 0:
            cutouts = []
            # note: this first resize is a dead store; it is overwritten below
            cutout = resize(pad_input, out_shape=output_shape, antialiasing=True)
            output_shape_all = list(output_shape)
            output_shape_all[0] = self.Overview * input.shape[0]
            # repeat by self.Overview (as in the upstream notebook) so the batch
            # matches output_shape_all; repeating by input.shape[0] only
            # coincides with that when the batch size happens to equal Overview
            pad_input = pad_input.repeat(self.Overview, 1, 1, 1)
            cutout = resize(pad_input, out_shape=output_shape_all)
if self._aug:
cutout = self.augs(cutout)
cutouts_list.append(cutout)

if self.InnerCrop > 0:
cutouts = []
for i in range(self.InnerCrop):
size = int(
torch.rand([]) ** self.IC_Size_Pow * (max_size - min_size)
+ min_size
)
offsetx = torch.randint(0, sideX - size + 1, ())
offsety = torch.randint(0, sideY - size + 1, ())
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
if i <= int(self.IC_Grey_P * self.InnerCrop):
cutout = gray(cutout)
cutout = resize(cutout, out_shape=output_shape)
cutouts.append(cutout)
if self.cutout_debug:
TF.to_pil_image(cutouts[-1].add(1).div(2).clamp(0, 1).squeeze(0)).save(
"content/diff/cutouts/cutout_InnerCrop.jpg", quality=99
)
cutouts_tensor = torch.cat(cutouts)
cutouts = []
cutouts_list.append(cutouts_tensor)
cutouts = torch.cat(cutouts_list)
return cutouts
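Reviewer note: a minimal sketch of how this class is presumably driven (parameter values are illustrative; cut_size should match the perceptor's input resolution):

import torch
from pytti.Perceptor.cutouts.dango import MakeCutouts

make_cutouts = MakeCutouts(cut_size=224, Overview=4, InnerCrop=2, aug=True)
image = torch.rand(1, 3, 512, 384) * 2 - 1  # NCHW tensor in [-1, 1]
cuts = make_cutouts(image)  # presumably [Overview + InnerCrop, 3, 224, 224]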