diff --git a/src/pytti/Image/__init__.py b/src/pytti/Image/__init__.py deleted file mode 100644 index 2fdb66e..0000000 --- a/src/pytti/Image/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch, copy -from torch import nn -import numpy as np -from PIL import Image -from pytti.tensor_tools import named_rearrange - -from pytti.Image.differentiable_image import DifferentiableImage -from pytti.Image.ema_image import EMAImage -from pytti.Image.PixelImage import PixelImage -from pytti.Image.RGBImage import RGBImage -from pytti.Image.VQGANImage import VQGANImage diff --git a/src/pytti/ImageGuide.py b/src/pytti/ImageGuide.py index 1855915..36122ed 100644 --- a/src/pytti/ImageGuide.py +++ b/src/pytti/ImageGuide.py @@ -20,8 +20,8 @@ vram_usage_mode, ) from pytti.AudioParse import SpectralAudioParser -from pytti.Image.differentiable_image import DifferentiableImage -from pytti.Image.PixelImage import PixelImage +from pytti.image_models.differentiable_image import DifferentiableImage +from pytti.image_models.pixel import PixelImage from pytti.Notebook import tqdm, make_hbox # from pytti.rotoscoper import update_rotoscopers diff --git a/src/pytti/LossAug/LossOrchestratorClass.py b/src/pytti/LossAug/LossOrchestratorClass.py index fb0fb3d..284b23d 100644 --- a/src/pytti/LossAug/LossOrchestratorClass.py +++ b/src/pytti/LossAug/LossOrchestratorClass.py @@ -2,7 +2,7 @@ from loguru import logger from PIL import Image -from pytti.Image import PixelImage +from pytti.image_models import PixelImage # from pytti.LossAug import build_loss from pytti.LossAug import TVLoss, HSVLoss, OpticalFlowLoss, TargetFlowLoss diff --git a/src/pytti/Perceptor/Embedder.py b/src/pytti/Perceptor/Embedder.py index fee6068..1acbf0c 100644 --- a/src/pytti/Perceptor/Embedder.py +++ b/src/pytti/Perceptor/Embedder.py @@ -4,13 +4,19 @@ from pytti import DEVICE, format_input, cat_with_pad, format_module, normalize # from pytti.ImageGuide import DirectImageGuide -from pytti.Image import DifferentiableImage +from pytti.image_models import DifferentiableImage import torch from torch import nn from torch.nn import functional as F -import kornia.augmentation as K + +# import .cutouts +# import .cutouts as cutouts +# import cutouts + +from .cutouts import augs as cutouts_augs +from .cutouts import samplers as cutouts_samplers PADDING_MODES = { "mirror": "reflect", @@ -43,19 +49,7 @@ def __init__( self.cut_sizes = [p.visual.input_resolution for p in perceptors] self.cutn = cutn self.noise_fac = noise_fac - self.augs = nn.Sequential( - K.RandomHorizontalFlip(p=0.3), - K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"), - K.RandomPerspective( - 0.2, - p=0.4, - ), - K.ColorJitter(hue=0.01, saturation=0.01, p=0.7), - K.RandomErasing( - scale=(0.1, 0.4), ratio=(0.3, 1 / 0.3), same_on_batch=False, p=0.7 - ), - nn.Identity(), - ) + self.augs = cutouts_augs.pytti_classic() self.input_axes = ("n", "s", "y", "x") self.output_axes = ("c", "n", "i") self.perceptors = perceptors @@ -64,69 +58,34 @@ def __init__( self.border_mode = border_mode def make_cutouts( - self, input: torch.Tensor, side_x, side_y, cut_size, device=DEVICE + self, + input: torch.Tensor, + side_x, + side_y, + cut_size, + #### + # padding, + # cutn, + # cut_pow, + # border_mode, + # augs, + # noise_fac, + #### + device=DEVICE, ) -> Tuple[list, list, list]: - min_size = min(side_x, side_y, cut_size) - max_size = min(side_x, side_y) - paddingx = min(round(side_x * self.padding), side_x) - paddingy = min(round(side_y * self.padding), side_y) - cutouts = [] - offsets = 
[] - sizes = [] - for _ in range(self.cutn): - # mean is 0.8 - # varience is 0.3 - size = int( - max_size - * ( - torch.zeros( - 1, - ) - .normal_(mean=0.8, std=0.3) - .clip(cut_size / max_size, 1.0) - ** self.cut_pow - ) - ) - offsetx_max = side_x - size + 1 - offsety_max = side_y - size + 1 - if self.border_mode == "clamp": - offsetx = torch.clamp( - (torch.rand([]) * (offsetx_max + 2 * paddingx) - paddingx) - .floor() - .int(), - 0, - offsetx_max, - ) - offsety = torch.clamp( - (torch.rand([]) * (offsety_max + 2 * paddingy) - paddingy) - .floor() - .int(), - 0, - offsety_max, - ) - cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size] - else: - px = min(size, paddingx) - py = min(size, paddingy) - offsetx = (torch.rand([]) * (offsetx_max + 2 * px) - px).floor().int() - offsety = (torch.rand([]) * (offsety_max + 2 * py) - py).floor().int() - cutout = input[ - :, - :, - paddingy + offsety : paddingy + offsety + size, - paddingx + offsetx : paddingx + offsetx + size, - ] - cutouts.append(F.adaptive_avg_pool2d(cutout, cut_size)) - offsets.append( - torch.as_tensor([[offsetx / side_x, offsety / side_y]]).to(device) - ) - sizes.append(torch.as_tensor([[size / side_x, size / side_y]]).to(device)) - cutouts = self.augs(torch.cat(cutouts)) - offsets = torch.cat(offsets) - sizes = torch.cat(sizes) - if self.noise_fac: - facs = cutouts.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) - cutouts.add_(facs * torch.randn_like(cutouts)) + cutouts, offsets, sizes = cutouts_samplers.pytti_classic( + input=input, + side_x=side_x, + side_y=side_y, + cut_size=cut_size, + padding=self.padding, + cutn=self.cutn, + cut_pow=self.cut_pow, + border_mode=self.border_mode, + augs=self.augs, + noise_fac=self.noise_fac, + device=DEVICE, + ) return cutouts, offsets, sizes def forward( @@ -163,12 +122,15 @@ def forward( (paddingx, paddingx, paddingy, paddingy), mode=PADDING_MODES[self.border_mode], ) + + # to do: add option to use a single perceptor per step and pick this perceptor randomly or in sequence for cut_size, perceptor in zip(self.cut_sizes, perceptors): cutouts, offsets, sizes = self.make_cutouts(input, side_x, side_y, cut_size) clip_in = normalize(cutouts) image_embeds.append(perceptor.encode_image(clip_in).float().unsqueeze(0)) all_offsets.append(offsets) all_sizes.append(sizes) + # What does pytti do with offsets and sizes? and why doesn't dango need to return them? 
return ( cat_with_pad(image_embeds), torch.stack(all_offsets), diff --git a/src/pytti/Perceptor/Prompt.py b/src/pytti/Perceptor/Prompt.py index 0696be9..f8ff836 100644 --- a/src/pytti/Perceptor/Prompt.py +++ b/src/pytti/Perceptor/Prompt.py @@ -26,7 +26,7 @@ parametric_eval, vram_usage_mode, ) -from pytti.Image import RGBImage +from pytti.image_models import RGBImage # from pytti.Notebook import Rotoscoper from pytti.rotoscoper import Rotoscoper diff --git a/src/pytti/Perceptor/cutouts/__init__.py b/src/pytti/Perceptor/cutouts/__init__.py new file mode 100644 index 0000000..94b5630 --- /dev/null +++ b/src/pytti/Perceptor/cutouts/__init__.py @@ -0,0 +1,3 @@ +from .dango import MakeCutouts + +test = MakeCutouts(1) diff --git a/src/pytti/Perceptor/cutouts/augs.py b/src/pytti/Perceptor/cutouts/augs.py new file mode 100644 index 0000000..101b675 --- /dev/null +++ b/src/pytti/Perceptor/cutouts/augs.py @@ -0,0 +1,42 @@ +import kornia.augmentation as K +import torch +from torch import nn +from torchvision import transforms as T + + +def pytti_classic(): + return nn.Sequential( + K.RandomHorizontalFlip(p=0.3), + K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"), + K.RandomPerspective( + 0.2, + p=0.4, + ), + K.ColorJitter(hue=0.01, saturation=0.01, p=0.7), + K.RandomErasing( + scale=(0.1, 0.4), ratio=(0.3, 1 / 0.3), same_on_batch=False, p=0.7 + ), + nn.Identity(), + ) + + +def dango(): + return T.Compose( + [ + # T.RandomHorizontalFlip(p=0.5), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomAffine( + degrees=0, + translate=(0.05, 0.05), + # scale=(0.9,0.95), + fill=-1, + interpolation=T.InterpolationMode.BILINEAR, + ), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + # T.RandomPerspective(p=1, interpolation = T.InterpolationMode.BILINEAR, fill=-1,distortion_scale=0.2), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomGrayscale(p=0.1), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05), + ] + ) diff --git a/src/pytti/Perceptor/cutouts/dango.py b/src/pytti/Perceptor/cutouts/dango.py new file mode 100644 index 0000000..ae94077 --- /dev/null +++ b/src/pytti/Perceptor/cutouts/dango.py @@ -0,0 +1,101 @@ +# via https://github.com/multimodalart/majesty-diffusion/blob/main/latent.ipynb + +# !pip install resize-right +# TO DO: add resize-right to setup instructions and notebook +from resize_right import resize + +import torch +from torch import nn +from torch.nn import functional as F +from torchvision import transforms +from torchvision import transforms as T +from torchvision.transforms import functional as TF + +from . import augs as cutouts_augs + +padargs = {"mode": "constant", "value": -1} + + +class MakeCutouts(nn.Module): + def __init__( + self, + cut_size, + Overview=4, + WholeCrop=0, + WC_Allowance=10, + WC_Grey_P=0.2, + InnerCrop=0, + IC_Size_Pow=0.5, + IC_Grey_P=0.2, + aug=True, + cutout_debug=False, + ): + super().__init__() + self.cut_size = cut_size + self.Overview = Overview + self.WholeCrop = WholeCrop + self.WC_Allowance = WC_Allowance + self.WC_Grey_P = WC_Grey_P + self.InnerCrop = InnerCrop + self.IC_Size_Pow = IC_Size_Pow + self.IC_Grey_P = IC_Grey_P + self.augs = cutouts_augs.dango + self._aug = aug + self.cutout_debug = cutout_debug + + def forward(self, input): + gray = transforms.Grayscale( + 3 + ) # this is possibly a performance improvement? 1 channel instead of 3. but also means we can't use color augs... 
+ sideY, sideX = input.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + l_size = max(sideX, sideY) + output_shape = [input.shape[0], 3, self.cut_size, self.cut_size] + output_shape_2 = [input.shape[0], 3, self.cut_size + 2, self.cut_size + 2] + pad_input = F.pad( + input, + ( + (sideY - max_size) // 2 + round(max_size * 0.055), + (sideY - max_size) // 2 + round(max_size * 0.055), + (sideX - max_size) // 2 + round(max_size * 0.055), + (sideX - max_size) // 2 + round(max_size * 0.055), + ), + **padargs + ) + cutouts_list = [] + + if self.Overview > 0: + cutouts = [] + cutout = resize(pad_input, out_shape=output_shape, antialiasing=True) + output_shape_all = list(output_shape) + output_shape_all[0] = self.Overview * input.shape[0] + pad_input = pad_input.repeat(input.shape[0], 1, 1, 1) + cutout = resize(pad_input, out_shape=output_shape_all) + if self._aug: + cutout = self.augs(cutout) + cutouts_list.append(cutout) + + if self.InnerCrop > 0: + cutouts = [] + for i in range(self.InnerCrop): + size = int( + torch.rand([]) ** self.IC_Size_Pow * (max_size - min_size) + + min_size + ) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size] + if i <= int(self.IC_Grey_P * self.InnerCrop): + cutout = gray(cutout) + cutout = resize(cutout, out_shape=output_shape) + cutouts.append(cutout) + if self.cutout_debug: + TF.to_pil_image(cutouts[-1].add(1).div(2).clamp(0, 1).squeeze(0)).save( + "content/diff/cutouts/cutout_InnerCrop.jpg", quality=99 + ) + cutouts_tensor = torch.cat(cutouts) + cutouts = [] + cutouts_list.append(cutouts_tensor) + cutouts = torch.cat(cutouts_list) + return cutouts diff --git a/src/pytti/Perceptor/cutouts/samplers.py b/src/pytti/Perceptor/cutouts/samplers.py new file mode 100644 index 0000000..086680d --- /dev/null +++ b/src/pytti/Perceptor/cutouts/samplers.py @@ -0,0 +1,117 @@ +""" +Methods for obtaining cutouts, agnostic to augmentations. + +Cutout choices have a significant impact on the performance of the perceptors and the +overall look of the image. + +The objects defined here probably are only being used in pytti.Perceptor.cutouts.Embedder.HDMultiClipEmbedder, but they +should be sufficiently general for use in notebooks without pyttitools otherwise in use. +""" + +import torch +from typing import Tuple +from torch.nn import functional as F + +PADDING_MODES = { + "mirror": "reflect", + "smear": "replicate", + "wrap": "circular", + "black": "constant", +} + +# ( +# cut_size = 64 +# cut_pow = 0.5 +# noise_fac = 0.0 +# cutn = 8 +# border_mode = "clamp" +# augs = None +# return Cutout( +# cut_size=cut_size, +# cut_pow=cut_pow, +# noise_fac=noise_fac, +# cutn=cutn, +# border_mode=border_mode, +# augs=augs, +# ) + + +def pytti_classic( + # self, + input: torch.Tensor, + side_x, + side_y, + cut_size, + padding, + cutn, + cut_pow, + border_mode, + augs, + noise_fac, + device, +) -> Tuple[list, list, list]: + """ + This is the cutout method that was already in use in the original pytti. 
+ """ + min_size = min(side_x, side_y, cut_size) + max_size = min(side_x, side_y) + paddingx = min(round(side_x * padding), side_x) + paddingy = min(round(side_y * padding), side_y) + cutouts = [] + offsets = [] + sizes = [] + for _ in range(cutn): + # mean is 0.8 + # varience is 0.3 + size = int( + max_size + * ( + torch.zeros( + 1, + ) + .normal_(mean=0.8, std=0.3) + .clip(cut_size / max_size, 1.0) + ** cut_pow + ) + ) + offsetx_max = side_x - size + 1 + offsety_max = side_y - size + 1 + if border_mode == "clamp": + offsetx = torch.clamp( + (torch.rand([]) * (offsetx_max + 2 * paddingx) - paddingx) + .floor() + .int(), + 0, + offsetx_max, + ) + offsety = torch.clamp( + (torch.rand([]) * (offsety_max + 2 * paddingy) - paddingy) + .floor() + .int(), + 0, + offsety_max, + ) + cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size] + else: + px = min(size, paddingx) + py = min(size, paddingy) + offsetx = (torch.rand([]) * (offsetx_max + 2 * px) - px).floor().int() + offsety = (torch.rand([]) * (offsety_max + 2 * py) - py).floor().int() + cutout = input[ + :, + :, + paddingy + offsety : paddingy + offsety + size, + paddingx + offsetx : paddingx + offsetx + size, + ] + cutouts.append(F.adaptive_avg_pool2d(cutout, cut_size)) + offsets.append( + torch.as_tensor([[offsetx / side_x, offsety / side_y]]).to(device) + ) + sizes.append(torch.as_tensor([[size / side_x, size / side_y]]).to(device)) + cutouts = augs(torch.cat(cutouts)) + offsets = torch.cat(offsets) + sizes = torch.cat(sizes) + if noise_fac: + facs = cutouts.new_empty([cutn, 1, 1, 1]).uniform_(0, noise_fac) + cutouts.add_(facs * torch.randn_like(cutouts)) + return cutouts, offsets, sizes diff --git a/src/pytti/Transforms.py b/src/pytti/Transforms.py index e1d336d..dfe4863 100644 --- a/src/pytti/Transforms.py +++ b/src/pytti/Transforms.py @@ -388,7 +388,7 @@ def animate_video_source( sampling_mode, ): # ugh this is GROSSSSS.... 
- from pytti.Image.PixelImage import PixelImage + from pytti.image_models.pixel import PixelImage # current frame index frame_n = min( @@ -415,8 +415,8 @@ def animate_video_source( for j, optical_flow in enumerate(optical_flows): # This looks like something that we shouldn't have to recompute # but rather could be attached to the flow object as an attribute - old_frame_n = frame_n - (2 ** j - 1) * frame_stride - save_n = i // save_every - (2 ** j - 1) + old_frame_n = frame_n - (2**j - 1) * frame_stride + save_n = i // save_every - (2**j - 1) if old_frame_n < 0 or save_n < 1: break diff --git a/src/pytti/config/structured_config.py b/src/pytti/config/structured_config.py index a574d5f..024c5c9 100644 --- a/src/pytti/config/structured_config.py +++ b/src/pytti/config/structured_config.py @@ -6,7 +6,7 @@ from attrs import define, field from hydra.core.config_store import ConfigStore -from pytti.Image.VQGANImage import VQGAN_MODEL_NAMES +from pytti.image_models.vqgan import VQGAN_MODEL_NAMES def check_input_against_list(attribute, value, valid_values): diff --git a/src/pytti/image_models/__init__.py b/src/pytti/image_models/__init__.py new file mode 100644 index 0000000..b47a344 --- /dev/null +++ b/src/pytti/image_models/__init__.py @@ -0,0 +1,11 @@ +import torch, copy +from torch import nn +import numpy as np +from PIL import Image +from pytti.tensor_tools import named_rearrange + +from .differentiable_image import DifferentiableImage +from .ema import EMAImage +from .pixel import PixelImage +from .rgb_image import RGBImage +from .vqgan import VQGANImage diff --git a/src/pytti/Image/differentiable_image.py b/src/pytti/image_models/differentiable_image.py similarity index 100% rename from src/pytti/Image/differentiable_image.py rename to src/pytti/image_models/differentiable_image.py diff --git a/src/pytti/Image/ema_image.py b/src/pytti/image_models/ema.py similarity index 95% rename from src/pytti/Image/ema_image.py rename to src/pytti/image_models/ema.py index f7ee04a..ea04791 100644 --- a/src/pytti/Image/ema_image.py +++ b/src/pytti/image_models/ema.py @@ -1,6 +1,6 @@ import torch from torch import nn -from pytti.Image.differentiable_image import DifferentiableImage +from pytti.image_models.differentiable_image import DifferentiableImage class EMAImage(DifferentiableImage): diff --git a/src/pytti/Image/PixelImage.py b/src/pytti/image_models/pixel.py similarity index 99% rename from src/pytti/Image/PixelImage.py rename to src/pytti/image_models/pixel.py index b5c7e27..ba58c09 100644 --- a/src/pytti/Image/PixelImage.py +++ b/src/pytti/image_models/pixel.py @@ -1,5 +1,5 @@ from pytti import DEVICE, named_rearrange, replace_grad, vram_usage_mode -from pytti.Image.differentiable_image import DifferentiableImage +from pytti.image_models.differentiable_image import DifferentiableImage from pytti.LossAug.HSVLossClass import HSVLoss # from pytti.ImageGuide import DirectImageGuide diff --git a/src/pytti/Image/RGBImage.py b/src/pytti/image_models/rgb_image.py similarity index 97% rename from src/pytti/Image/RGBImage.py rename to src/pytti/image_models/rgb_image.py index b92830d..109e22e 100644 --- a/src/pytti/Image/RGBImage.py +++ b/src/pytti/image_models/rgb_image.py @@ -2,7 +2,7 @@ import torch from torch import nn from torchvision.transforms import functional as TF -from pytti.Image import DifferentiableImage +from pytti.image_models import DifferentiableImage from PIL import Image from torch.nn import functional as F diff --git a/src/pytti/Image/VQGANImage.py b/src/pytti/image_models/vqgan.py 
similarity index 99% rename from src/pytti/Image/VQGANImage.py rename to src/pytti/image_models/vqgan.py index f256e88..8c33bac 100644 --- a/src/pytti/Image/VQGANImage.py +++ b/src/pytti/image_models/vqgan.py @@ -11,7 +11,7 @@ from pytti import DEVICE, replace_grad, clamp_with_grad, vram_usage_mode import torch from torch.nn import functional as F -from pytti.Image import EMAImage +from pytti.image_models import EMAImage from torchvision.transforms import functional as TF from PIL import Image from omegaconf import OmegaConf diff --git a/src/pytti/workhorse.py b/src/pytti/workhorse.py index 6836da6..d98e727 100644 --- a/src/pytti/workhorse.py +++ b/src/pytti/workhorse.py @@ -38,7 +38,7 @@ ) from pytti.rotoscoper import ROTOSCOPERS, get_frames -from pytti.Image import PixelImage, RGBImage, VQGANImage +from pytti.image_models import PixelImage, RGBImage, VQGANImage from pytti.ImageGuide import DirectImageGuide from pytti.Perceptor.Embedder import HDMultiClipEmbedder from pytti.Perceptor.Prompt import parse_prompt diff --git a/tests/test_models_load.py b/tests/test_models_load.py index 2ef078d..69f5e05 100644 --- a/tests/test_models_load.py +++ b/tests/test_models_load.py @@ -2,7 +2,7 @@ import pytest from loguru import logger -from pytti.Image.VQGANImage import ( +from pytti.image_models.vqgan import ( VQGANImage, VQGAN_MODEL_NAMES, VQGAN_CONFIG_URLS, diff --git a/tests/test_rough_e2e.py b/tests/test_rough_e2e.py index 831f1b3..63fcb56 100644 --- a/tests/test_rough_e2e.py +++ b/tests/test_rough_e2e.py @@ -108,7 +108,7 @@ def test_vqgan(self, kwargs): super().test_vqgan(**kwargs) -from pytti.Image.VQGANImage import VQGAN_MODEL_NAMES +from pytti.image_models.vqgan import VQGAN_MODEL_NAMES @pytest.mark.parametrize(
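
For downstream code that still imports the old pytti.Image package, the module rename in this diff maps as follows; every target path below is taken directly from the rename entries above.

# Old import path                     -> new import path
#   pytti.Image                        -> pytti.image_models
#   pytti.Image.differentiable_image   -> pytti.image_models.differentiable_image
#   pytti.Image.ema_image              -> pytti.image_models.ema
#   pytti.Image.PixelImage             -> pytti.image_models.pixel
#   pytti.Image.RGBImage               -> pytti.image_models.rgb_image
#   pytti.Image.VQGANImage             -> pytti.image_models.vqgan
from pytti.image_models import (
    DifferentiableImage,
    EMAImage,
    PixelImage,
    RGBImage,
    VQGANImage,
)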
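
The docstring in the new samplers.py says the cutout samplers should be general enough for standalone notebook use, not only for HDMultiClipEmbedder. A minimal sketch of that standalone use, assuming the pytti_classic signature introduced above; the resolution, padding, cutn, and noise values are illustrative rather than pytti defaults.

import torch
from pytti.Perceptor.cutouts import augs as cutouts_augs
from pytti.Perceptor.cutouts import samplers as cutouts_samplers

device = "cuda" if torch.cuda.is_available() else "cpu"
img = torch.rand(1, 3, 256, 256, device=device)  # NCHW image batch
augs = cutouts_augs.pytti_classic()              # kornia augmentation stack from augs.py

cutouts, offsets, sizes = cutouts_samplers.pytti_classic(
    input=img,
    side_x=256,
    side_y=256,
    cut_size=224,         # CLIP ViT-B/32 input resolution
    padding=0.25,         # illustrative value, not a pytti default
    cutn=16,
    cut_pow=1.0,
    border_mode="clamp",
    augs=augs,
    noise_fac=0.1,
    device=device,
)
# For a single input image: cutouts is (cutn, 3, cut_size, cut_size);
# offsets and sizes are (cutn, 2), expressed as fractions of the source dimensions.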
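
Both augmentation stacks defined in augs.py are plain callables over NCHW float tensors, so they can be exercised directly outside the embedder. The batch shape below is illustrative, and importing the cutouts package pulls in resize-right via dango.py.

import torch
from pytti.Perceptor.cutouts import augs

batch = torch.rand(8, 3, 224, 224)
classic_augs = augs.pytti_classic()   # kornia nn.Sequential now used by HDMultiClipEmbedder
dango_augs = augs.dango()             # torchvision Compose ported alongside the dango cutouts
out_classic = classic_augs(batch)
out_dango = dango_augs(batch)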
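
A sketch of the new dango-style MakeCutouts module, which needs the resize-right package flagged in the TODO at the top of dango.py. The sizes and counts below are illustrative, and aug=False keeps the sketch independent of the augmentation stack.

import torch
from pytti.Perceptor.cutouts.dango import MakeCutouts

make_cutouts = MakeCutouts(
    cut_size=224,   # perceptor input resolution (illustrative)
    Overview=4,     # padded whole-image cutouts
    InnerCrop=8,    # random inner crops resized back to cut_size
    IC_Grey_P=0.2,  # roughly this fraction of inner crops is converted to grayscale
    aug=False,
)
img = torch.rand(1, 3, 512, 448) * 2 - 1  # dango.py pads with -1, suggesting images in [-1, 1]
cuts = make_cutouts(img)                  # stacked overview and inner-crop cutouts at cut_size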