-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ecd0b07
commit ef17546
Showing
1 changed file
with
209 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
from __future__ import annotations | ||
|
||
import math | ||
import os.path | ||
import random | ||
import shutil | ||
import sys | ||
from argparse import ArgumentParser | ||
from tqdm import tqdm | ||
|
||
import einops | ||
import k_diffusion as K | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
from einops import rearrange | ||
from omegaconf import OmegaConf | ||
from PIL import Image, ImageOps | ||
from torch import autocast | ||
import json | ||
import os | ||
sys.path.append("./stable_diffusion") | ||
|
||
from stable_diffusion.ldm.util import instantiate_from_config | ||
|
||
|
||
class CFGDenoiser(nn.Module):
    """Wrap a denoiser so one call blends three differently-conditioned predictions.

    The wrapped ``inner_model`` is evaluated once on a batch made of three
    stacked copies of the input:

    1. text conditioning + image conditioning,
    2. unconditional text + image conditioning,
    3. unconditional text + unconditional image.

    The three outputs are then combined using the two guidance scales.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        """Return the guided prediction for ``z`` at noise level ``sigma``.

        Args:
            z: Input tensor whose first dimension is the batch.
            sigma: Noise-level tensor whose first dimension is the batch.
            cond: Dict with ``"c_crossattn"`` and ``"c_concat"`` entries, each a
                list whose first element is the conditioning tensor.
            uncond: Dict with the same layout as ``cond`` holding the
                unconditional counterparts.
            text_cfg_scale: Weight applied to (full cond - image-only cond).
            image_cfg_scale: Weight applied to (image-only cond - uncond).

        Returns:
            ``out_uncond + text_cfg_scale * (out_cond - out_img_cond)
            + image_cfg_scale * (out_img_cond - out_uncond)``.
        """
        # Stack three copies along the batch axis. Unlike a "1 ... -> n ..."
        # repeat, concatenation does not require the batch size to be 1, and
        # it lines up with the concatenated conditioning tensors below.
        cfg_z = torch.cat([z] * 3)
        cfg_sigma = torch.cat([sigma] * 3)

        text_cond = cond["c_crossattn"][0]
        text_uncond = uncond["c_crossattn"][0]
        image_cond = cond["c_concat"][0]
        image_uncond = uncond["c_concat"][0]
        cfg_cond = {
            "c_crossattn": [torch.cat([text_cond, text_uncond, text_uncond])],
            "c_concat": [torch.cat([image_cond, image_cond, image_uncond])],
        }

        # chunk(3) undoes the stacking: each chunk has the original batch size.
        out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
        return (
            out_uncond
            + text_cfg_scale * (out_cond - out_img_cond)
            + image_cfg_scale * (out_img_cond - out_uncond)
        )
|
||
|
||
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    """Instantiate the model described by ``config`` and load checkpoint weights.

    Args:
        config: Configuration object; its ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file containing a ``"state_dict"`` entry.
        vae_ckpt: Optional path to a second checkpoint. When given, every
            ``first_stage_model.*`` weight is replaced by the entry of the
            same (un-prefixed) name from that checkpoint's ``"state_dict"``.
        verbose: When true, print the missing and unexpected keys reported
            by ``load_state_dict``.

    Returns:
        The instantiated model with the weights loaded non-strictly.
    """
    first_stage_prefix = "first_stage_model."

    print(f"Loading model from {ckpt}")
    print(ckpt)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]

    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        # Strip exactly the prefix that was matched, so the lookup key in the
        # VAE state dict is the original key minus "first_stage_model.".
        sd = {
            k: vae_sd[k[len(first_stage_prefix):]] if k.startswith(first_stage_prefix) else v
            for k, v in sd.items()
        }

    model = instantiate_from_config(config.model)
    # strict=False: mismatching keys are reported instead of raising.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)
    return model
|
||
|
||
def main():
    """Edit every image session listed in ``edit_sessions.json``.

    The sessions file is read from the parent directory of ``--input_path``
    and maps an image id to a list of turns; each turn is a dict with an
    ``"input"`` file name and an ``"instruction"`` string.

    Two passes are made over the sessions:

    * Iterative (multi-turn), unless ``--skip_iter`` is given: turn 0 edits
      the source image; every later turn edits the previous turn's output.
      Outputs are ``<dir>_1.png`` then ``<dir>_iter_<turn>.png``.
    * Independent (single-turn): every turn edits its own ``"input"`` image.
      Outputs are ``<dir>_1.png`` then ``<dir>_inde_<turn>.png``.

    Outputs that already exist on disk are not regenerated.
    """
    parser = ArgumentParser()
    parser.add_argument("--resolution", default=512, type=int)
    parser.add_argument("--steps", default=100, type=int)
    parser.add_argument("--config", default="configs/generate.yaml", type=str)  # default 2.1
    parser.add_argument("--ckpt", default="checkpoints/MagicBrush-epoch-52-step-4999.ckpt", type=str)  # default 2.1
    parser.add_argument("--vae-ckpt", default=None, type=str)
    # `required=True` made the declared defaults unreachable; the options are
    # now optional so the defaults actually apply. Passing them still works.
    parser.add_argument("--input_path", type=str, default='./dev/images')
    parser.add_argument("--output_path", type=str, default='./DevMagicBrushOutput')
    parser.add_argument("--cfg-text", default=7.5, type=float)
    parser.add_argument("--cfg-image", default=1.5, type=float)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--skip_iter", action="store_true")
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()

    with open(os.path.join(args.input_path, '..', 'edit_sessions.json'), encoding="utf-8") as fp:
        data_json = json.load(fp)
    os.makedirs(args.output_path, exist_ok=True)

    config = OmegaConf.load(args.config)
    model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
    model.eval().cuda()
    model_wrap = K.external.CompVisDenoiser(model)
    model_wrap_cfg = CFGDenoiser(model_wrap)
    null_token = model.get_learned_conditioning([""])

    # One seed for the whole run; it is re-applied before every sample.
    seed = random.randint(0, 100000) if args.seed is None else args.seed

    def edit_image(input_path, output_path, edit_instruction):
        """Apply ``edit_instruction`` to the image at ``input_path`` and save the result."""
        input_image = Image.open(input_path).convert("RGB")
        width, height = input_image.size
        # Scale so the longer side matches --resolution, then nudge the factor
        # so the shorter side rounds up to a multiple of 64; finally floor
        # both sides to multiples of 64.
        factor = args.resolution / max(width, height)
        factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
        width = int((width * factor) // 64) * 64
        height = int((height * factor) // 64) * 64
        input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)

        with torch.no_grad(), autocast("cuda"), model.ema_scope():
            cond = {}
            cond["c_crossattn"] = [model.get_learned_conditioning([edit_instruction])]
            # Map pixel values from [0, 255] to [-1, 1] and go to 1 x C x H x W.
            input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
            input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device)
            cond["c_concat"] = [model.encode_first_stage(input_image).mode()]

            uncond = {}
            uncond["c_crossattn"] = [null_token]
            uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]

            sigmas = model_wrap.get_sigmas(args.steps)

            extra_args = {
                "cond": cond,
                "uncond": uncond,
                "text_cfg_scale": args.cfg_text,
                "image_cfg_scale": args.cfg_image,
            }
            torch.manual_seed(seed)
            z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
            z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
            x = model.decode_first_stage(z)
            # Map the decoded output from [-1, 1] back to [0, 255].
            x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
            x = 255.0 * rearrange(x, "1 c h w -> h w c")
            edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())
            edited_image.save(output_path)

    def generate_if_missing(image_name, image_path, output_img_path, edit_instruction):
        """Run one edit unless its output file is already on disk."""
        if os.path.exists(output_img_path):
            print('already generated, skip')
            return
        print('image_name', image_name)
        print('image_path', image_path)
        print('save_output_img_path', output_img_path)
        print('edit_instruction', edit_instruction)
        edit_image(image_path, output_img_path, edit_instruction)

    if args.debug:
        # data_json is a dict and cannot be sliced directly; keep the first
        # three sessions in file order.
        data_json = dict(list(data_json.items())[:3])

    # iterative edit
    if args.skip_iter:
        print("Skip Iterative (Mutli-Turn) Editing......")
    else:
        print("Iterative (Mutli-Turn) Editing......")
        for datas in tqdm(data_json.values()):
            previous_output = None
            for turn_id, data in enumerate(datas):
                print("data", data)
                image_name = data['input']
                # e.g. "139306-input.png" / "139306-output1.png" -> "139306"
                image_dir = image_name.split('-')[0]
                if turn_id == 0:  # first turn starts from the source image
                    image_path = os.path.join(args.input_path, image_dir, image_name)
                else:  # later turns continue from the previous turn's output
                    image_path = previous_output
                save_output_dir_path = os.path.join(args.output_path, image_dir)
                os.makedirs(save_output_dir_path, exist_ok=True)
                if turn_id == 0:
                    save_output_img_path = os.path.join(save_output_dir_path, image_dir + '_1.png')
                else:
                    save_output_img_path = os.path.join(
                        save_output_dir_path, image_dir + '_iter_' + str(turn_id + 1) + '.png'
                    )
                # Recorded before the skip check so that a skipped turn still
                # feeds its existing output into the next turn.
                previous_output = save_output_img_path
                generate_if_missing(image_name, image_path, save_output_img_path, data['instruction'])

    print("Independent (Single-Turn) Editing......")
    for datas in tqdm(data_json.values()):
        for turn_id, data in enumerate(datas):
            image_name = data['input']
            image_dir = image_name.split('-')[0]
            image_path = os.path.join(args.input_path, image_dir, image_name)
            save_output_dir_path = os.path.join(args.output_path, image_dir)
            os.makedirs(save_output_dir_path, exist_ok=True)
            if turn_id == 0:
                # Same file name as the first iterative turn, so it is skipped
                # here when the iterative pass already produced it.
                save_output_img_path = os.path.join(save_output_dir_path, image_dir + '_1.png')
            else:
                save_output_img_path = os.path.join(
                    save_output_dir_path, image_dir + '_inde_' + str(turn_id + 1) + '.png'
                )
            generate_if_missing(image_name, image_path, save_output_img_path, data['instruction'])
|
||
|
||
# Run the editing pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|