diff --git a/drag_bench_evaluation/run_drag_diffusion.py b/drag_bench_evaluation/run_drag_diffusion.py index 6d9f26b..a7a11d6 100644 --- a/drag_bench_evaluation/run_drag_diffusion.py +++ b/drag_bench_evaluation/run_drag_diffusion.py @@ -155,7 +155,7 @@ def run_drag(source_image, # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64] # update according to the given supervision - updated_init_code = drag_diffusion_update(model, init_code, + updated_init_code, opt_seq = drag_diffusion_update(model, init_code, None, t, handle_points, target_points, mask, args) # hijack the attention module diff --git a/drag_ui.py b/drag_ui.py index 3edd9fa..9b1a43e 100755 --- a/drag_ui.py +++ b/drag_ui.py @@ -61,6 +61,7 @@ prompt = gr.Textbox(label="Prompt") lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path") lora_status_bar = gr.Textbox(label="display LoRA training status") + results_path = gr.Textbox(value="./results", label="Folder for results") # algorithm specific parameters with gr.Tab("Drag Config"): @@ -79,6 +80,11 @@ latent_lr = gr.Number(value=0.01, label="latent lr") start_step = gr.Number(value=0, label="start_step", precision=0, visible=False) start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False) + save_optimization_seq_rgb = gr.Checkbox( + value=False, + label="Save Opt Seq", + info="Each step of optimization latent is saved in png file. Will take more time." + ) with gr.Tab("Base Model Config"): with gr.Row(): @@ -136,6 +142,7 @@ with gr.Row(): pos_prompt_gen = gr.Textbox(label="Positive Prompt") neg_prompt_gen = gr.Textbox(label="Negative Prompt") + results_path_gen = gr.Textbox(value="./results", label="Folder for results") with gr.Tab("Generation Config"): with gr.Row(): @@ -218,6 +225,11 @@ latent_lr_gen = gr.Number(value=0.01, label="latent lr") start_step_gen = gr.Number(value=0, label="start_step", precision=0, visible=False) start_layer_gen = gr.Number(value=10, label="start_layer", precision=0, visible=False) + save_optimization_seq_rgb_gen = gr.Checkbox( + value=False, + label="Save Opt Seq", + info="Each step of optimization latent is saved in png file. Will take more time." + ) # event definition # event for dragging user-input real image @@ -265,6 +277,8 @@ lora_path, start_step, start_layer, + results_path, + save_optimization_seq_rgb, ], [output_image] ) @@ -343,6 +357,8 @@ b2_gen, s1_gen, s2_gen, + results_path_gen, + save_optimization_seq_rgb_gen, ], [output_image_gen] ) diff --git a/utils/drag_utils.py b/utils/drag_utils.py index ff4f9e4..e8f836f 100755 --- a/utils/drag_utils.py +++ b/utils/drag_utils.py @@ -101,6 +101,7 @@ def drag_diffusion_update(model, # prepare optimizable init_code and optimizer init_code.requires_grad_(True) optimizer = torch.optim.Adam([init_code], lr=args.lr) + opt_seq = [init_code.detach().clone()] # prepare for point tracking and background regularization handle_points_init = copy.deepcopy(handle_points) @@ -157,8 +158,9 @@ def drag_diffusion_update(model, scaler.step(optimizer) scaler.update() optimizer.zero_grad() + opt_seq.append(init_code.detach().clone()) - return init_code + return init_code, opt_seq def drag_diffusion_update_gen(model, init_code, @@ -218,6 +220,7 @@ def drag_diffusion_update_gen(model, # prepare amp scaler for mixed-precision training scaler = torch.cuda.amp.GradScaler() + opt_seq = [init_code.detach().clone()] for step_idx in range(args.n_pix_step): with torch.autocast(device_type='cuda', dtype=torch.float16): if args.guidance_scale > 1.: @@ -281,6 +284,7 @@ def drag_diffusion_update_gen(model, scaler.step(optimizer) scaler.update() optimizer.zero_grad() + opt_seq.append(init_code.detach().clone()) - return init_code + return init_code, opt_seq diff --git a/utils/ui_utils.py b/utils/ui_utils.py index 6f1f09a..bdbb187 100755 --- a/utils/ui_utils.py +++ b/utils/ui_utils.py @@ -185,7 +185,8 @@ def run_drag(source_image, lora_path, start_step, start_layer, - save_dir="./results" + save_dir="./results", + save_seq=True, ): # initialize model device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") @@ -288,7 +289,7 @@ def run_drag(source_image, text_embeddings = text_embeddings.float() model.unet = model.unet.float() - updated_init_code = drag_diffusion_update( + updated_init_code, opt_seq = drag_diffusion_update( model, init_code, text_embeddings, @@ -345,6 +346,27 @@ def run_drag(source_image, save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") save_image(save_result, os.path.join(save_dir, save_prefix + '.png')) + if save_seq: + os.mkdir(os.path.join(save_dir, save_prefix)) + # save list of latents in pt file + torch.save(opt_seq, os.path.join(save_dir, save_prefix, 'opt_seq.pt')) + for i in range(0, len(opt_seq), 2): + # denoise latents and save + latents = torch.cat([opt_seq[i].half(), opt_seq[i+1].half() if i+1 < len(opt_seq) else opt_seq[i].half()], dim=0) + gen_image_seq = model( + args.prompt, + encoder_hidden_states=torch.cat([text_embeddings]*2, dim=0), + batch_size=2, + latents=latents, + guidance_scale=args.guidance_scale, + num_inference_steps=args.n_inference_step, + num_actual_inference_steps=args.n_actual_inference_step + ) + gen_image_seq = F.interpolate(gen_image_seq, (full_h, full_w), mode='bilinear') + save_image(gen_image_seq[0].unsqueeze(dim=0), os.path.join(save_dir, save_prefix, f'iter_{i}.png')) + if i+1 < len(opt_seq): + save_image(gen_image_seq[1].unsqueeze(dim=0), os.path.join(save_dir, save_prefix, f'iter_{i+1}.png')) + out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] out_image = (out_image * 255).astype(np.uint8) return out_image @@ -464,7 +486,9 @@ def run_drag_gen( b2, s1, s2, - save_dir="./results"): + save_dir="./results", + save_seq=True, + ): # initialize model device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = DragPipeline.from_pretrained(model_path, torch_dtype=torch.float16) @@ -572,7 +596,7 @@ def run_drag_gen( init_code = init_code.to(torch.float32) text_embeddings = text_embeddings.to(torch.float32) model.unet = model.unet.to(torch.float32) - updated_init_code = drag_diffusion_update_gen(model, init_code, + updated_init_code, opt_seq = drag_diffusion_update_gen(model, init_code, text_embeddings, t, handle_points, target_points, mask, args) updated_init_code = updated_init_code.to(torch.float16) text_embeddings = text_embeddings.to(torch.float16) @@ -619,6 +643,27 @@ def run_drag_gen( save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") save_image(save_result, os.path.join(save_dir, save_prefix + '.png')) + if save_seq: + os.mkdir(os.path.join(save_dir, save_prefix)) + # save list of latents in pt file + torch.save(opt_seq, os.path.join(save_dir, save_prefix, 'opt_seq.pt')) + for i in range(0, len(opt_seq), 2): + # denoise latents and save + latents = torch.cat([opt_seq[i].half(), opt_seq[i+1].half() if i+1 < len(opt_seq) else opt_seq[i].half()], dim=0) + gen_image_seq = model( + args.prompt, + encoder_hidden_states=torch.cat([text_embeddings]*2, dim=0), + batch_size=2, + latents=latents, + guidance_scale=args.guidance_scale, + num_inference_steps=args.n_inference_step, + num_actual_inference_steps=args.n_actual_inference_step + ) + gen_image_seq = F.interpolate(gen_image_seq, (full_h, full_w), mode='bilinear') + save_image(gen_image_seq[0].unsqueeze(dim=0), os.path.join(save_dir, save_prefix, f'iter_{i}.png')) + if i+1 < len(opt_seq): + save_image(gen_image_seq[1].unsqueeze(dim=0), os.path.join(save_dir, save_prefix, f'iter_{i+1}.png')) + out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] out_image = (out_image * 255).astype(np.uint8) return out_image