diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4691ef0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +*egg-info/ +__pycache__/ +src/ +checkpoints/ +results/ \ No newline at end of file diff --git a/environment.yaml b/environment.yaml index 9663e45..12737f4 100644 --- a/environment.yaml +++ b/environment.yaml @@ -11,7 +11,9 @@ dependencies: - numpy=1.19.2 - pip: - albumentations==0.4.3 - - diffusers + - diffusers==0.12.1 + - psutil + - appdirs - bezier - opencv-python==4.1.2.30 - pudb==2019.2 diff --git a/scripts/inference_parallel.py b/scripts/inference_parallel.py new file mode 100644 index 0000000..cd85460 --- /dev/null +++ b/scripts/inference_parallel.py @@ -0,0 +1,968 @@ +import argparse, os, sys, glob +import cv2 +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from imwatermark import WatermarkEncoder +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +import time +from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import contextmanager, nullcontext +import torchvision +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler + +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from transformers import AutoFeatureExtractor +import clip +from torchvision.transforms import Resize +wm = "Paint-by-Example" +wm_encoder = WatermarkEncoder() +wm_encoder.set_watermark('bytes', wm.encode('utf-8')) +safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) +safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + + +# utils functions +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + +def get_tensor_clip(normalize=True, toTensor=True): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711))] + return torchvision.transforms.Compose(transform_list) + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] 
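+ # scale the [0, 1] float batch to uint8 before converting each array to a PIL image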
+ images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def load_replacement(x): + try: + hwc = x.shape + y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0])) + y = (np.array(y)/255.0).astype(x.dtype) + assert y.shape == x.shape + return y + except Exception: + return x + + +def check_safety(x_image): + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + assert x_checked_image.shape[0] == len(has_nsfw_concept) + for i in range(len(has_nsfw_concept)): + if has_nsfw_concept[i]: + x_checked_image[i] = load_replacement(x_checked_image[i]) + return x_checked_image, has_nsfw_concept + +def get_tensor(normalize=True, toTensor=True): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + return torchvision.transforms.Compose(transform_list) + +def get_tensor_clip(normalize=True, toTensor=True): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711))] + return torchvision.transforms.Compose(transform_list) + + +# dist utils +def get_world_size(): + return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + + +def get_rank(): + return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + + +def expand_first_dim(x, num_repeats=5, unsqueeze=False): + if unsqueeze: + x = x.unsqueeze(0) + x = x.repeat(num_repeats, *([1] * (x.dim() - 1))) + return x + + +def dist_init(): + if 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = 'localhost' + if 'MASTER_PORT' not in os.environ: + os.environ['MASTER_PORT'] = '29500' + if 'RANK' not in os.environ: + os.environ['RANK'] = '0' + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = '0' + if 'WORLD_SIZE' not in os.environ: + os.environ['WORLD_SIZE'] = '1' + + backend = 'gloo' if os.name == 'nt' else 'nccl' + torch.distributed.init_process_group(backend=backend, init_method='env://') + torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) + + +# val dataset +import torchvision.transforms.functional as TF +import torch.nn.functional as F +import torchvision.transforms as T +import pycocotools.mask as maskUtils +import random +from torchvision.transforms import InterpolationMode, Compose, Resize, ToTensor, Normalize, CenterCrop +try: + from torchvision.transforms 
import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +def _tokenzie_text(text: str): + """Tokenize text.""" + return clip.tokenize(text, context_length=77, truncate=True) # 1x77 + + +def _find_tokens_to_replace(class_name, text): + """Find the tokens to be replaced.""" + if f"{class_name.lower()}s" in text.lower(): + class_name = class_name + "s" + elif f"{class_name.lower()}es" in text.lower(): + class_name = class_name + "es" + tokenized_class_name = _tokenzie_text(class_name).squeeze(0) + tokenized_text = _tokenzie_text(text).squeeze(0) + tokens_to_replace = [] + valid_tokens = tokenized_class_name[ + (tokenized_class_name != 49406) & \ + (tokenized_class_name != 49407) & \ + (tokenized_class_name != 0)] + len_valid_tokens = len(valid_tokens) + for i in range(len(tokenized_text) - len_valid_tokens + 1): + if torch.equal(tokenized_text[i:i+len_valid_tokens], valid_tokens): + tokens_to_replace.extend( + list(range(i, i+len_valid_tokens))) + + return tokenized_class_name, tokenized_text, tokens_to_replace + + +# def _transform(n_px): +# return Compose([ +# Resize(n_px, interpolation=BICUBIC), +# CenterCrop(n_px), +# Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), +# ]) + +# if dropping ration > 2, some objects like human, guitar, etc. will be dropped +# that is not what we want ... +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def denormalize(img): + mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).to(img.device) + std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).to(img.device) + # print("Before ", img.min(), img.max()) + img = img * std[None, :, None, None] + mean[None, :, None, None] + # print("After ", img.min(), img.max()) + return img + + +class RetrieverSA(object): + def __init__( + self, + path_filtered_ids="/mnt/external/datasets/sam/merged_meta_data_new2.pt", + path_meta_info="/mnt/external/tmp/2023/05/12/clip_embeddings_large_sa-1b.pt") -> None: + # filtered_ids = torch.load(path_filtered_ids, map_location="cpu") + filtered_ids = torch.load(path_filtered_ids, map_location="cpu") + meta_info = torch.load(path_meta_info, map_location="cpu") + # img_ids = list(meta_info.keys()) + # img_ids = img_ids[:100] + + # new dict + new_filtered_ids = {} + for item in filtered_ids: + new_filtered_ids[item["image"]["path"]] = item + + self.new_filtered_ids = new_filtered_ids + # for key in meta_info.keys(): + # print(meta_info[key].shape) + # import pdb; pdb.set_trace() + num_instances = 0 + for key in meta_info: + num_instances += len(meta_info[key]) + + print("num_instances", num_instances) + + # build total tensor + total_tensor = torch.zeros(num_instances, 1024) + total_tensor = total_tensor.half() + + index_map = {} + index = 0 + # for img_id in img_ids: + # for instance_id in range(len(meta_info[img_id])): + # total_tensor[index] = meta_info[img_id][instance_id] + # index_map[(img_id, instance_id)] = index + # index += 1 + for key in meta_info: + for instance_id in range(len(meta_info[key])): + total_tensor[index] = meta_info[key][instance_id] + index_map[(key, instance_id)] = index + index += 1 + + # index to img_id + instance_id + index_to_img_id = {v: k for k, v in index_map.items()} + self.index_to_img_id = 
index_to_img_id + self.total_tensor = total_tensor + self.index_map = index_map + + def get_instance_patch(self, img_id, instance_id, bg="white"): + item = self.new_filtered_ids[img_id] + image_path = item["image"]["path"] + image_path = os.path.join("/mnt/external/datasets/sam/", image_path) + img = Image.open(image_path).convert("RGB") + # img = img.resize((224, 224)) + img = TF.to_tensor(img) + + box = item["annotations"][instance_id]["bbox"] + mask = item["annotations"][instance_id]["segmentation"] + # normalize bbox + box = [ + box[0] / item["image"]["width"], + box[1] / item["image"]["height"], + (box[0] + box[2]) / item["image"]["width"], + (box[1] + box[3]) / item["image"]["height"], + ] + # decode mask + mask = torch.from_numpy(maskUtils.decode(mask)).bool()# [:, :, instance_id] + instance_patch = self._process_img(img, box, mask, bg=bg)[0] + instance_patch = TF.to_pil_image(instance_patch, mode="RGB") + instance_patch = _transform(224)(instance_patch) + return instance_patch + + def get_class_name(self, img_id, instance_id): + item = self.new_filtered_ids[img_id] + + # item["instance_caption"] = [['Question: what is the name of the object? Answer: ', 2, 'a roll of tape'], ['Question: what is the name of the object? Answer: ', 3, '3m tape']] + # class_name = None + # for caption in item["instance_caption"]: + # if caption[1] == instance_id: + # class_name = caption[2] + # break + + # if class_name is None: + # print(item["instance_caption"], instance_id) + # exit() + try: + class_name = item["instance_caption"][instance_id][-1] + except: + print(item["instance_caption"], instance_id, img_id) + exit() + + return class_name + + def _process_img(self, img, box, mask, bg="none"): + # resize the image to the size of mask + # print(img) + mask = mask.unsqueeze(0).float() + mask = TF.resize(mask, (img.shape[1], img.shape[2])) + mask = (mask > 0.5).float() + # raw mask + raw_mask = mask.clone() + # generate a random number ranging in [0, 1] + p1 = random.random() + if p1 < 0.5: + mask = F.conv2d(mask, torch.ones(1, 1, 3, 3), padding=1) + else: + mask = F.conv2d(mask, torch.ones(1, 1, 5, 5), padding=2) + mask = (mask > 0.5).float() + + p = random.random() + if bg == "white": + # multiply the mask with the image + img = img * mask + extended_mask = mask.repeat(3, 1, 1) + img[extended_mask <= 0.0] = 1.0 + elif bg == "black": + img = img * mask + else: + pass + # crop the image using the 4-d box range in [0, 1] + bbox = [ + int((box[1] ) * img.shape[1]), # top + int((box[0] ) * img.shape[2]), # left + int((box[3] - box[1]) * img.shape[1]), # height + int((box[2] - box[0]) * img.shape[2]), # width + ] + # random jitter the box in form of [top, left, height, width] with 2 pixels + offset = torch.randint(0, 3, (4,)) + bbox = [bbox[0] - offset[0], bbox[1] - offset[1], bbox[2] + offset[2], bbox[3] + offset[3]] + # keep the range + bbox[0] = max(0, bbox[0]) + bbox[1] = max(0, bbox[1]) + bbox[2] = min(img.shape[1] - bbox[0], bbox[2]) + bbox[3] = min(img.shape[2] - bbox[1], bbox[3]) + + max_size = max(bbox[2], bbox[3]) + # pad the box to a square and crop from the image + raw_center = [bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2] + new_left_top = [raw_center[0] - max_size / 2, raw_center[1] - max_size / 2] + new_right_bottom = [raw_center[0] + max_size / 2, raw_center[1] + max_size / 2] + # move the new bbox to keep the box in the image + if new_left_top[0] < 0: + new_right_bottom[0] -= new_left_top[0] + new_left_top[0] = 0 + if new_left_top[1] < 0: + new_right_bottom[1] -= new_left_top[1] + 
new_left_top[1] = 0 + if new_right_bottom[0] > img.shape[1]: + new_left_top[0] -= new_right_bottom[0] - img.shape[1] + new_right_bottom[0] = img.shape[1] + if new_right_bottom[1] > img.shape[2]: + new_left_top[1] -= new_right_bottom[1] - img.shape[2] + new_right_bottom[1] = img.shape[2] + # convert the new bbox to the form of [top, left, height, width] + bbox = [ + int(new_left_top[0]), + int(new_left_top[1]), + int(new_right_bottom[0] - new_left_top[0]), + int(new_right_bottom[1] - new_left_top[1]), + ] + + img = TF.crop( + img.clone(), + *bbox, + ) + + return img, raw_mask + + def query(self, query_tensor, total_tensor, topk=10): + # query_tensor = query_tensor.half() + query_tensor = query_tensor.to("cpu") + total_tensor = total_tensor.to("cpu") + + r""" + q: (1024,) + database: (N, 1024) + """ + # normalize and compute cosine similarity + query_tensor = F.normalize(query_tensor.float(), dim=0) + total_tensor = F.normalize(total_tensor.float(), dim=1) + sim = torch.matmul(total_tensor, query_tensor)# .squeeze(1) + + # sort + sim, indices = torch.sort(sim, descending=True) + indices = indices[:topk] + + selcted_ids = [] + for index in indices: + selcted_ids.append(self.index_to_img_id[index.item()]) + + return selcted_ids + + + +class SAValDataset2(torch.utils.data.Dataset): + + def __init__( + self, + path, + split="train", + splits=(0.9, 0.05, 0.05), + crop_res=256, + min_resize_res=256, + max_resize_res=256, + flip_prob=0.5, + max_size=-1, + bg="none", + **kwargs): + # data_path = DATA_PATH_TABLE["openimages"] + self.data_path = path + self.split = split + # self.task_indicator = task_indicator + self.crop_res = crop_res + self.flip_prob = flip_prob + + self.min_resize_res = min_resize_res + self.max_resize_res = max_resize_res + + # fix the path now + # self.filtered_ids = torch.load("../filtered_ids.pt", map_location="cpu") + self.meta_info = torch.load(os.path.join(self.data_path, "val_999.pt"), map_location="cpu") + if max_size is not None and max_size > 0: + self.meta_info = self.meta_info[:max_size] + + self.bg = bg + + def _process_img(self, img, box, mask): + # resize the image to the size of mask + mask = mask.unsqueeze(0).float() + mask = TF.resize(mask, (img.shape[1], img.shape[2])) + mask = (mask > 0.5).float() + # raw mask + raw_mask = mask.clone() + # generate a random number ranging in [0, 1] + p1 = random.random() + if p1 < 0.5: + mask = F.conv2d(mask, torch.ones(1, 1, 3, 3), padding=1) + else: + mask = F.conv2d(mask, torch.ones(1, 1, 5, 5), padding=2) + mask = (mask > 0.5).float() + + p = random.random() + # if 0.0 <= p < 0.25: + assert self.bg in ["white", "black", "none"] + if self.bg == "white": + # multiply the mask with the image + img = img * mask + extended_mask = mask.repeat(3, 1, 1) + img[extended_mask <= 0.0] = 1.0 + # elif 0.25 <= p < 0.5: + elif self.bg == "black": + img = img * mask + else: + pass + # crop the image using the 4-d box range in [0, 1] + bbox = [ + int((box[1] ) * img.shape[1]), # top + int((box[0] ) * img.shape[2]), # left + int((box[3] - box[1]) * img.shape[1]), # height + int((box[2] - box[0]) * img.shape[2]), # width + ] + # random jitter the box in form of [top, left, height, width] with 2 pixels + offset = torch.randint(0, 3, (4,)) + bbox = [bbox[0] - offset[0], bbox[1] - offset[1], bbox[2] + offset[2], bbox[3] + offset[3]] + # keep the range + bbox[0] = max(0, bbox[0]) + bbox[1] = max(0, bbox[1]) + bbox[2] = min(img.shape[1] - bbox[0], bbox[2]) + bbox[3] = min(img.shape[2] - bbox[1], bbox[3]) + + max_size = max(bbox[2], bbox[3]) + # 
pad the box to a square and crop from the image + raw_center = [bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2] + new_left_top = [raw_center[0] - max_size / 2, raw_center[1] - max_size / 2] + new_right_bottom = [raw_center[0] + max_size / 2, raw_center[1] + max_size / 2] + # move the new bbox to keep the box in the image + if new_left_top[0] < 0: + new_right_bottom[0] -= new_left_top[0] + new_left_top[0] = 0 + if new_left_top[1] < 0: + new_right_bottom[1] -= new_left_top[1] + new_left_top[1] = 0 + if new_right_bottom[0] > img.shape[1]: + new_left_top[0] -= new_right_bottom[0] - img.shape[1] + new_right_bottom[0] = img.shape[1] + if new_right_bottom[1] > img.shape[2]: + new_left_top[1] -= new_right_bottom[1] - img.shape[2] + new_right_bottom[1] = img.shape[2] + # convert the new bbox to the form of [top, left, height, width] + bbox = [ + int(new_left_top[0]), + int(new_left_top[1]), + int(new_right_bottom[0] - new_left_top[0]), + int(new_right_bottom[1] - new_left_top[1]), + ] + + img = TF.crop( + img.clone(), + *bbox, + ) + + return img, raw_mask + + def __len__(self): + return len(self.meta_info) + + def __getitem__(self, index): + item = self.meta_info[index] + img_id = item["image"]["path"] + + # just for debug + while not os.path.exists(os.path.join(self.data_path, img_id)): + index = random.randint(0, len(self.meta_info) - 1) + item = self.meta_info[index] + img_id = item["image"]["path"] + + caption = item["image_caption"][0][-1] + class_names = item["instance_caption"] + + random_instance_id = random.choice(list(range(len(class_names)))) + class_name = class_names[random_instance_id][-1] + instance_id = class_names[random_instance_id][1] + + # print(class_names, class_name) + # class_name = remove_a(class_name) + + box = item["annotations"][instance_id]["bbox"] + # convert the box [x, y, w, h] to [x1, y1, x2, y2] and normalize to [0, 1] + box = [ + box[0] / item["image"]["width"], + box[1] / item["image"]["height"], + (box[0] + box[2]) / item["image"]["width"], + (box[1] + box[3]) / item["image"]["height"], + ] + + # box to tensor + box = torch.tensor(box) + + seg = item["annotations"][instance_id]["segmentation"] + + img_path = os.path.join(self.data_path, img_id) + image = Image.open(img_path).convert("RGB") + reize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item() + image = image.resize((reize_res, reize_res), Image.Resampling.LANCZOS) + # convert the PIL.Image img to torch.Tensor + image = TF.to_tensor(image) + + mask = torch.from_numpy(maskUtils.decode(seg)).bool() + mask = mask.squeeze(0) + processed_img, mask = self._process_img(image, box, mask) + tokenized_class_name, tokenized_text, tokens_to_replace = \ + _find_tokens_to_replace(class_name, caption) + + crop = T.RandomCrop(self.crop_res) + # flip = T.RandomHorizontalFlip(float(self.flip_prob)) + + image = crop(image) + mask = crop(mask) + # flip image and mask + if random.random() < self.flip_prob: + image = TF.hflip(image) + mask = TF.hflip(mask) + box = torch.tensor([1 - box[2], box[1], 1 - box[0], box[3]]) + image = 2 * image - 1 + + # TODO: normalize the image & process the cropped image for CLIP embedding + # to PIL image + processed_img = TF.to_pil_image(processed_img, mode="RGB") + processed_img = _transform(224)(processed_img) + + res = dict( + img_id=img_id, + instance_id=instance_id, + edited=image, + class_name=class_name, + box=box, + mask=mask, + text=caption, + tokens_to_replace=tokens_to_replace, + processed_img=processed_img, # cropped image + 
tokenized_class_name=tokenized_class_name, + tokenized_text=tokenized_text, + ) + # # print the type of values in res + # for k, v in res.items(): + # print(k, type(v)) + return res + + +@torch.no_grad() +def _get_instance_image_embedding(clip_encoder, images): + # return self.clip_encoder.encode_image(images) + x = clip_encoder.visual.conv1(images.type(clip_encoder.visual.conv1.weight.dtype)) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([clip_encoder.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + clip_encoder.visual.positional_embedding.to(x.dtype) + x = clip_encoder.visual.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = clip_encoder.visual.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + return x + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + parser.add_argument( + "--skip_save", + action='store_true', + help="do not save individual samples. For speed measurements.", + ) + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across samples ", + ) + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=2, + help="sample this often", + ) + parser.add_argument( + "--H", + type=int, + default=512, + help="image height, in pixel space", + ) + parser.add_argument( + "--W", + type=int, + default=512, + help="image width, in pixel space", + ) + parser.add_argument( + "--n_imgs", + type=int, + default=100, + help="image width, in pixel space", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor", + ) + parser.add_argument( + "--n_samples", + type=int, + default=1, + help="how many samples to produce for each given reference image. A.k.a. 
batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=1, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument( + "--config", + type=str, + default="", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="", + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + parser.add_argument( + "--image_path", + type=str, + help="evaluate at this precision", + default="" + ) + parser.add_argument( + "--mask_path", + type=str, + help="evaluate at this precision", + default="" + ) + parser.add_argument( + "--reference_path", + type=str, + help="evaluate at this precision", + default="" + ) + opt = parser.parse_args() + + dist_init() + seed_everything(opt.seed + get_rank()) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + + sample_path = os.path.join(outpath, "source") + result_path = os.path.join(outpath, "results") + grid_path=os.path.join(outpath, "grid") + os.makedirs(sample_path, exist_ok=True) + os.makedirs(result_path, exist_ok=True) + os.makedirs(grid_path, exist_ok=True) + + + start_code = None + if opt.fixed_code: + start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + + precision_scope = autocast if opt.precision=="autocast" else nullcontext + + dataset_obj = SAValDataset2( + path="/mnt/external/datasets/sam/", + split="val", + flip_prob=0.0, + ) + + retriever = RetrieverSA( + path_filtered_ids="/mnt/external/datasets/sam/merged_meta_data_new2.pt", + path_meta_info=f"/mnt/external/tmp/2023/05/12/clip_embeddings_large_sa-1b_white.pt" + ) + + clip_encoder = clip.load("ViT-L/14")[0] + clip_encoder = clip_encoder.cuda() + clip_encoder.eval() + + # split dataset into different ranks + inds = list(range(len(dataset_obj))) + inds = inds[get_rank()::get_world_size()] + + for idx in tqdm(inds): + # image_path = "" + # mask_path = "" + # reference_path = "" + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + # filename=os.path.basename(image_path) + # img_p = Image.open(image_path).convert("RGB") + # image_tensor = get_tensor()(img_p) + # image_tensor = image_tensor.unsqueeze(0) + # ref_p = Image.open(reference_path).convert("RGB").resize((224,224)) + # ref_tensor=get_tensor_clip()(ref_p) + # ref_tensor = ref_tensor.unsqueeze(0) + # mask=Image.open(mask_path).convert("L") + # mask = np.array(mask)[None,None] + # mask = 1 - mask.astype(np.float32)/255.0 + # mask[mask < 0.5] = 0 + # mask[mask >= 0.5] = 1 + # mask_tensor = torch.from_numpy(mask) + image_tensor = dataset_obj[idx]['edited'][None, ...] + ref_tensor = dataset_obj[idx]['processed_img'][None, ...] 
+ + image_clip_embedding = _get_instance_image_embedding(clip_encoder, ref_tensor.cuda())[0, 0, ...] + top_k_instances = retriever.query(image_clip_embedding, retriever.total_tensor, 50) + + _count = 0 + for _img_id, _instance_id in top_k_instances: + # res = dataset_obj._get_item_by_img_instance_id(_img_id, _instance_id) + # import pdb; pdb.set_trace() + if _img_id not in retriever.new_filtered_ids: + print(f"img_id {_img_id} not in new_filtered_ids") + continue + instance_patches = retriever.get_instance_patch(_img_id, _instance_id, bg="none").unsqueeze(0) + ref_tensor = instance_patches + # import pdb; pdb.set_trace() + + mask_tensor = 1 - dataset_obj[idx]['mask'][None, ...] + filename = dataset_obj[idx]['img_id'] + filename = os.path.basename(filename) + # import pdb;pdb.set_trace() + inpaint_image = image_tensor*mask_tensor + test_model_kwargs={} + test_model_kwargs['inpaint_mask'] = mask_tensor.to(device) + test_model_kwargs['inpaint_image'] = inpaint_image.to(device) + ref_tensor=ref_tensor.to(device) + uc = None + if opt.scale != 1.0: + uc = model.learnable_vector + c = model.get_learned_conditioning(ref_tensor.to(torch.float16)) + c = model.proj_out(c) + inpaint_mask=test_model_kwargs['inpaint_mask'] + z_inpaint = model.encode_first_stage(test_model_kwargs['inpaint_image']) + z_inpaint = model.get_first_stage_encoding(z_inpaint).detach() + test_model_kwargs['inpaint_image']=z_inpaint + test_model_kwargs['inpaint_mask']=Resize([z_inpaint.shape[-2],z_inpaint.shape[-1]])(test_model_kwargs['inpaint_mask']) + + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code, + test_model_kwargs=test_model_kwargs) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) + x_checked_image=x_samples_ddim + x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + + def un_norm(x): + return (x+1.0)/2.0 + def un_norm_clip(x): + x[0,:,:] = x[0,:,:] * 0.26862954 + 0.48145466 + x[1,:,:] = x[1,:,:] * 0.26130258 + 0.4578275 + x[2,:,:] = x[2,:,:] * 0.27577711 + 0.40821073 + return x + + if not opt.skip_save: + for i,x_sample in enumerate(x_checked_image_torch): + + + all_img=[] + all_img.append(un_norm(image_tensor[i]).cpu()) + all_img.append(un_norm(inpaint_image[i]).cpu()) + ref_img=ref_tensor + ref_img=Resize([opt.H, opt.W])(ref_img) + all_img.append(un_norm_clip(ref_img[i]).cpu()) + all_img.append(x_sample) + grid = torch.stack(all_img, 0) + grid = make_grid(grid) + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + img = Image.fromarray(grid.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(grid_path, 'grid-'+filename[:-4]+'_'+str(opt.seed)+f'_{_count}.png')) + + x_sample = 255. 
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(result_path, filename[:-4]+'_'+str(opt.seed)+f'_{_count}.png')) + + mask_save=255.*rearrange(un_norm(inpaint_mask[i]).cpu(), 'c h w -> h w c').numpy() + mask_save= cv2.cvtColor(mask_save,cv2.COLOR_GRAY2RGB) + mask_save = Image.fromarray(mask_save.astype(np.uint8)) + mask_save.save(os.path.join(sample_path, filename[:-4]+'_'+str(opt.seed)+f"_mask_{_count}.png")) + GT_img=255.*rearrange(all_img[0], 'c h w -> h w c').numpy() + GT_img = Image.fromarray(GT_img.astype(np.uint8)) + GT_img.save(os.path.join(sample_path, filename[:-4]+'_'+str(opt.seed)+f"_GT_{_count}.png")) + inpaint_img=255.*rearrange(all_img[1], 'c h w -> h w c').numpy() + inpaint_img = Image.fromarray(inpaint_img.astype(np.uint8)) + inpaint_img.save(os.path.join(sample_path, filename[:-4]+'_'+str(opt.seed)+f"_inpaint_{_count}.png")) + ref_img=255.*rearrange(all_img[2], 'c h w -> h w c').numpy() + ref_img = Image.fromarray(ref_img.astype(np.uint8)) + ref_img.save(os.path.join(sample_path, filename[:-4]+'_'+str(opt.seed)+f"_ref_{_count}.png")) + + _count += 1 + if _count >= 50: + break + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + main() diff --git a/scripts/inferencev1.py b/scripts/inferencev1.py new file mode 100644 index 0000000..1b84ded --- /dev/null +++ b/scripts/inferencev1.py @@ -0,0 +1,450 @@ +import argparse, os, sys, glob +import cv2 +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from imwatermark import WatermarkEncoder +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +import time +from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import contextmanager, nullcontext +import torchvision +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler + +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from transformers import AutoFeatureExtractor +import clip +from torchvision.transforms import Resize +wm = "Paint-by-Example" +wm_encoder = WatermarkEncoder() +wm_encoder.set_watermark('bytes', wm.encode('utf-8')) +safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) +safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + +def get_tensor_clip(normalize=True, toTensor=True): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711))] + return torchvision.transforms.Compose(transform_list) + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] 
+ images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def load_replacement(x): + try: + hwc = x.shape + y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0])) + y = (np.array(y)/255.0).astype(x.dtype) + assert y.shape == x.shape + return y + except Exception: + return x + + +def check_safety(x_image): + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + assert x_checked_image.shape[0] == len(has_nsfw_concept) + for i in range(len(has_nsfw_concept)): + if has_nsfw_concept[i]: + x_checked_image[i] = load_replacement(x_checked_image[i]) + return x_checked_image, has_nsfw_concept + +def get_tensor(normalize=True, toTensor=True): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + return torchvision.transforms.Compose(transform_list) + +def get_tensor_clip(normalize=True, toTensor=True): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711))] + return torchvision.transforms.Compose(transform_list) + + + +# dist utils +def get_world_size(): + return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + + +def get_rank(): + return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + + +def expand_first_dim(x, num_repeats=5, unsqueeze=False): + if unsqueeze: + x = x.unsqueeze(0) + x = x.repeat(num_repeats, *([1] * (x.dim() - 1))) + return x + + +def dist_init(): + if 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = 'localhost' + if 'MASTER_PORT' not in os.environ: + os.environ['MASTER_PORT'] = '29500' + if 'RANK' not in os.environ: + os.environ['RANK'] = '0' + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = '0' + if 'WORLD_SIZE' not in os.environ: + os.environ['WORLD_SIZE'] = '1' + + backend = 'gloo' if os.name == 'nt' else 'nccl' + torch.distributed.init_process_group(backend=backend, init_method='env://') + torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) + + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. 
Helpful when evaluating lots of samples", + ) + parser.add_argument( + "--skip_save", + action='store_true', + help="do not save individual samples. For speed measurements.", + ) + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across samples ", + ) + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=2, + help="sample this often", + ) + parser.add_argument( + "--H", + type=int, + default=512, + help="image height, in pixel space", + ) + parser.add_argument( + "--W", + type=int, + default=512, + help="image width, in pixel space", + ) + parser.add_argument( + "--n_imgs", + type=int, + default=100, + help="image width, in pixel space", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor", + ) + parser.add_argument( + "--n_samples", + type=int, + default=1, + help="how many samples to produce for each given reference image. A.k.a. batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=1, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument( + "--config", + type=str, + default="", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="", + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + parser.add_argument( + "--image_path", + type=str, + help="evaluate at this precision", + default="" + ) + parser.add_argument( + "--mask_path", + type=str, + help="evaluate at this precision", + default="" + ) + parser.add_argument( + "--reference_path", + type=str, + help="evaluate at this precision", + default="" + ) + opt = parser.parse_args() + + dist_init() + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + + sample_path = os.path.join(outpath, "source") + result_path = os.path.join(outpath, "results") + grid_path=os.path.join(outpath, "grid") + os.makedirs(sample_path, exist_ok=True) + os.makedirs(result_path, exist_ok=True) + os.makedirs(grid_path, exist_ok=True) + + + start_code = None + if opt.fixed_code: + start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + + file_names = glob.glob(f"test_cases/*_input.jpg") + file_names = sorted(file_names) + file_names = file_names[get_rank()::get_world_size()] 
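+ # each rank now runs the sampling loop below over its round-robin shard of test_cases/*_input.jpg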
+ + precision_scope = autocast if opt.precision=="autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + for filename in file_names: + # filename = os.path.basename(opt.image_path) + img_p = Image.open(filename).convert("RGB") + img_p = img_p.resize((opt.H, opt.W)) + image_tensor = get_tensor()(img_p) + image_tensor = image_tensor.unsqueeze(0) + reference_path = filename.replace("_input.jpg", "_ref.jpg") + ref_p = Image.open(reference_path).convert("RGB").resize((224,224)) + ref_tensor=get_tensor_clip()(ref_p) + ref_tensor = ref_tensor.unsqueeze(0) + mask_path = filename.replace("_input.jpg", "_mask.jpg") + mask=Image.open(mask_path).convert("L") + mask = mask.resize((opt.H, opt.W)) + mask = np.array(mask)[None,None] + mask = 1 - mask.astype(np.float32)/255.0 + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask_tensor = torch.from_numpy(mask) + inpaint_image = image_tensor*mask_tensor + test_model_kwargs={} + test_model_kwargs['inpaint_mask']=mask_tensor.to(device) + test_model_kwargs['inpaint_image']=inpaint_image.to(device) + ref_tensor=ref_tensor.to(device) + uc = None + if opt.scale != 1.0: + uc = model.learnable_vector + c = model.get_learned_conditioning(ref_tensor.to(torch.float16)) + c = model.proj_out(c) + inpaint_mask=test_model_kwargs['inpaint_mask'] + z_inpaint = model.encode_first_stage(test_model_kwargs['inpaint_image']) + z_inpaint = model.get_first_stage_encoding(z_inpaint).detach() + test_model_kwargs['inpaint_image']=z_inpaint + test_model_kwargs['inpaint_mask']=Resize([z_inpaint.shape[-2],z_inpaint.shape[-1]])(test_model_kwargs['inpaint_mask']) + + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code, + test_model_kwargs=test_model_kwargs) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) + x_checked_image=x_samples_ddim + x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + + def un_norm(x): + return (x+1.0)/2.0 + def un_norm_clip(x): + x[0,:,:] = x[0,:,:] * 0.26862954 + 0.48145466 + x[1,:,:] = x[1,:,:] * 0.26130258 + 0.4578275 + x[2,:,:] = x[2,:,:] * 0.27577711 + 0.40821073 + return x + + if not opt.skip_save: + for i,x_sample in enumerate(x_checked_image_torch): + + + all_img=[] + all_img.append(un_norm(image_tensor[i]).cpu()) + all_img.append(un_norm(inpaint_image[i]).cpu()) + ref_img=ref_tensor + ref_img=Resize([opt.H, opt.W])(ref_img) + all_img.append(un_norm_clip(ref_img[i]).cpu()) + all_img.append(x_sample) + grid = torch.stack(all_img, 0) + grid = make_grid(grid) + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + img = Image.fromarray(grid.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(grid_path, 'grid-'+ os.path.basename(filename)[:-4] +'_'+str(opt.seed)+'.png')) + + + + x_sample = 255. 
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(result_path, os.path.basename(filename)[:-4] +'_'+str(opt.seed)+".png")) + + mask_save=255.*rearrange(un_norm(inpaint_mask[i]).cpu(), 'c h w -> h w c').numpy() + mask_save= cv2.cvtColor(mask_save,cv2.COLOR_GRAY2RGB) + mask_save = Image.fromarray(mask_save.astype(np.uint8)) + mask_save.save(os.path.join(sample_path, os.path.basename(filename)[:-4]+'_'+str(opt.seed)+"_mask.png")) + GT_img=255.*rearrange(all_img[0], 'c h w -> h w c').numpy() + GT_img = Image.fromarray(GT_img.astype(np.uint8)) + GT_img.save(os.path.join(sample_path, os.path.basename(filename)[:-4]+'_'+str(opt.seed)+"_GT.png")) + inpaint_img=255.*rearrange(all_img[1], 'c h w -> h w c').numpy() + inpaint_img = Image.fromarray(inpaint_img.astype(np.uint8)) + inpaint_img.save(os.path.join(sample_path, os.path.basename(filename)[:-4]+'_'+str(opt.seed)+"_inpaint.png")) + ref_img=255.*rearrange(all_img[2], 'c h w -> h w c').numpy() + ref_img = Image.fromarray(ref_img.astype(np.uint8)) + ref_img.save(os.path.join(sample_path, os.path.basename(filename)[:-4] +'_'+str(opt.seed)+"_ref.png")) + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + main() diff --git a/test.sh b/test.sh index c2d7b70..ae700ad 100644 --- a/test.sh +++ b/test.sh @@ -1,29 +1,157 @@ -python scripts/inference.py \ ---plms --outdir results \ ---config configs/v1.yaml \ ---ckpt checkpoints/model.ckpt \ ---image_path examples/image/example_1.png \ ---mask_path examples/mask/example_1.png \ ---reference_path examples/reference/example_1.jpg \ ---seed 321 \ ---scale 5 - -python scripts/inference.py \ ---plms --outdir results \ ---config configs/v1.yaml \ ---ckpt checkpoints/model.ckpt \ ---image_path examples/image/example_2.png \ ---mask_path examples/mask/example_2.png \ ---reference_path examples/reference/example_2.jpg \ ---seed 5876 \ ---scale 5 - -python scripts/inference.py \ ---plms --outdir results \ ---config configs/v1.yaml \ ---ckpt checkpoints/model.ckpt \ ---image_path examples/image/example_3.png \ ---mask_path examples/mask/example_3.png \ ---reference_path examples/reference/example_3.jpg \ ---seed 5065 \ ---scale 5 +pip install pycocotools + +FILENAME=checkpoints/model.ckpt +# check if file exists +if [ ! -f "$FILENAME" ]; then + ID=15QzaTWsvZonJcXsNv-ilMRCYaQLhzR_i + mkdir -p checkpoints + + wget --load-cookies /tmp/cookies.txt \ + "https://docs.google.com/uc?export=download&confirm=$(wget \ + --quiet --save-cookies /tmp/cookies.txt \ + --keep-session-cookies \ + --no-check-certificate \ + "https://docs.google.com/uc?export=download&id=$ID" \ + -O- | sed -rn "s/.*confirm=([0-9A-Za-z_]+).*/\1\n/p")&id=$ID" \ + -O $FILENAME && rm -rf /tmp/cookies.txt +else + echo "Model checkpoint already exists." 
+fi + +torchrun --nproc_per_node=4 scripts/inferencev1.py \ + --plms --outdir results-512 \ + --config configs/v1.yaml \ + --ckpt checkpoints/model.ckpt \ + --image_path test_cases/${num}_input.jpg \ + --mask_path test_cases/${num}_mask.jpg \ + --reference_path test_cases/${num}_ref.jpg \ + --seed 321 \ + --scale 5 \ + --n_samples 1 + +# python scripts/inferencev1.py \ +# --plms --outdir results-512 \ +# --config configs/v1.yaml \ +# --ckpt checkpoints/model.ckpt \ +# --image_path test_cases/${num}_input.jpg \ +# --mask_path test_cases/${num}_mask.jpg \ +# --reference_path test_cases/${num}_ref.jpg \ +# --seed 321 \ +# --scale 5 \ +# --n_samples 1 \ +# --H 256 --W 256 + + +# iterate over all examples, the names are like test_cases/972_input.jpg, test_cases/972_mask.png, test_cases/972_reference.jpg + +# for num in {1..10} + +# do +# echo "Processing $num" +# # test whether the file exists +# if [ ! -f "test_cases/${num}_input.jpg" ]; then +# echo "test_cases/${num}_input.jpg does not exist." +# continue +# fi +# mkdir -p results-512 +# python scripts/inference.py \ +# --plms --outdir results-512 \ +# --config configs/v1.yaml \ +# --ckpt checkpoints/model.ckpt \ +# --image_path test_cases/${num}_input.jpg \ +# --mask_path test_cases/${num}_mask.jpg \ +# --reference_path test_cases/${num}_ref.jpg \ +# --seed 321 \ +# --scale 5 \ +# --n_samples 1 + +# mkdir -p results-256 +# python scripts/inference.py \ +# --plms --outdir results-256 \ +# --config configs/v1.yaml \ +# --ckpt checkpoints/model.ckpt \ +# --image_path test_cases/${num}_input.jpg \ +# --mask_path test_cases/${num}_mask.jpg \ +# --reference_path test_cases/${num}_ref.jpg \ +# --seed 321 \ +# --scale 5 \ +# --n_samples 1 \ +# --H 256 --W 256 +# done + +# python scripts/inference.py \ +# --plms --outdir results \ +# --config configs/v1.yaml \ +# --ckpt checkpoints/model.ckpt \ +# --image_path examples/image/example_1.png \ +# --mask_path examples/mask/example_1.png \ +# --reference_path examples/reference/example_1.jpg \ +# --seed 321 \ +# --scale 5 + +# torchrun --nproc_per_node=8 scripts/inference_parallel.py \ +# --plms --outdir /mnt/external/tmp/2023/05/22/paint-by-example-results \ +# --config configs/v1.yaml \ +# --ckpt checkpoints/model.ckpt \ +# --image_path examples/image/example_1.png \ +# --mask_path examples/mask/example_1.png \ +# --reference_path examples/reference/example_1.jpg \ +# --seed 321 \ +# --scale 5 \ +# --n_samples 1 --H 256 --W 256 + +# torchrun --nproc_per_node=8 scripts/inference_parallel.py \ +# --plms --outdir /mnt/external/tmp/2023/05/22/paint-by-example-results \ +# --config configs/v1.yaml \ +# --ckpt checkpoints/model.ckpt \ +# --image_path examples/image/example_2.png \ +# --mask_path examples/mask/example_2.png \ +# --reference_path examples/reference/example_2.jpg \ +# --seed 5876 \ +# --scale 5 \ +# --n_samples 1 --H 256 --W 256 + +# torchrun --nproc_per_node=8 scripts/inference_parallel.py \ +# --plms --outdir /mnt/external/tmp/2023/05/22/paint-by-example-results \ +# --config configs/v1.yaml \ +# --ckpt checkpoints/model.ckpt \ +# --image_path examples/image/example_3.png \ +# --mask_path examples/mask/example_3.png \ +# --reference_path examples/reference/example_3.jpg \ +# --seed 5065 \ +# --scale 5 \ +# --n_samples 1 --H 256 --W 256 + + +# python scripts/inference.py \ +# --plms --outdir results \ +# --config configs/v1.yaml \ +# --ckpt checkpoints/model.ckpt \ +# --image_path examples/image/example_1.png \ +# --mask_path examples/mask/example_1.png \ +# --reference_path 
examples/reference/example_1.jpg \ +# --seed 321 \ +# --scale 5 \ +# --n_samples 1 + +# python scripts/inference.py \ +# --plms --outdir results \ +# --config configs/v1.yaml \ +# --ckpt checkpoints/model.ckpt \ +# --image_path examples/image/example_2.png \ +# --mask_path examples/mask/example_2.png \ +# --reference_path examples/reference/example_2.jpg \ +# --seed 5876 \ +# --scale 5 \ +# --n_samples 1 + +# python scripts/inference.py \ +# --plms --outdir results \ +# --config configs/v1.yaml \ +# --ckpt checkpoints/model.ckpt \ +# --image_path examples/image/example_3.png \ +# --mask_path examples/mask/example_3.png \ +# --reference_path examples/reference/example_3.jpg \ +# --seed 5065 \ +# --scale 5 \ +# --n_samples 1