From 25754d826be1f0443d12498ad7eee651ca730322 Mon Sep 17 00:00:00 2001
From: LoveSy
Date: Sun, 24 Nov 2024 08:29:51 -0500
Subject: [PATCH 1/7] Refine graio

---
 app/app_sana.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/app/app_sana.py b/app/app_sana.py
index a392924..4efd71d 100755
--- a/app/app_sana.py
+++ b/app/app_sana.py
@@ -21,6 +21,7 @@
 import random
 import time
 import uuid
+import socket
 from datetime import datetime
 
 import gradio as gr
@@ -308,7 +309,7 @@ def generate(
         Sana-{model_size}B{args.image_size}px
         Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer
         [Paper] [Github(coming soon)] [Project]
-        Powered by DC-AE with 32x latent space, running on A6000 node.
+        Powered by DC-AE with 32x latent space, running on node {socket.gethostname()}.
         Unsafe word will give you a 'Red Heart' in the image instead.
 """
 if model_size == "0.6":
@@ -336,7 +337,7 @@ def generate(
 """
 with gr.Blocks(css=css) as demo:
     gr.Markdown(title)
-    gr.Markdown(DESCRIPTION)
+    gr.HTML(DESCRIPTION)
     gr.DuplicateButton(
         value="Duplicate Space for private use",
         elem_id="duplicate-button",
@@ -485,4 +486,4 @@ def generate(
     )
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=True, share=True)
+    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=False)

From 56fd9f9d9292c9ee9501ca04d02e25a53fa24f86 Mon Sep 17 00:00:00 2001
From: LoveSy
Date: Sun, 24 Nov 2024 09:59:26 -0500
Subject: [PATCH 2/7] Refine counter

---
 app/app_sana.py | 39 +++++++++++++++++++++------------------
 1 file changed, 21 insertions(+), 18 deletions(-)

diff --git a/app/app_sana.py b/app/app_sana.py
index 4efd71d..4839c25 100755
--- a/app/app_sana.py
+++ b/app/app_sana.py
@@ -19,9 +19,10 @@
 import argparse
 import os
 import random
+import socket
+import sqlite3
 import time
 import uuid
-import socket
 from datetime import datetime
 
 import gradio as gr
@@ -42,6 +43,7 @@
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
 os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
+COUNTER_DB = os.getenv('COUNTER_DB', '.count.db')
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
@@ -112,30 +114,30 @@
 NUM_IMAGES_PER_PROMPT = 1
 TEST_TIMES = 0
 INFER_SPEED = 0
-FILENAME = f"output/port{DEMO_PORT}_inference_count.txt"
+def open_db():
+    db = sqlite3.connect(COUNTER_DB)
+    db.execute('CREATE TABLE IF NOT EXISTS counter(app CHARS PRIMARY KEY UNIQUE, value INTEGER)')
+    db.execute('INSERT OR IGNORE INTO counter(app, value) VALUES("Sana", 0)')
+    return db
 
 
-def read_inference_count():
-    global TEST_TIMES
-    try:
-        with open(FILENAME) as f:
-            count = int(f.read().strip())
-    except FileNotFoundError:
-        count = 0
-    TEST_TIMES = count
-    return count
+def read_inference_count():
+    with open_db() as db:
+        cur = db.execute('SELECT value FROM counter WHERE app="Sana"')
+        db.commit()
+        return cur.fetchone()[0]
 
 
 def write_inference_count(count):
-    with open(FILENAME, "w") as f:
-        f.write(str(count))
+    count = max(0, int(count))
+    with open_db() as db:
+        db.execute(f'UPDATE counter SET value=value+{count} WHERE app="Sana"')
+        db.commit()
 
 
 def run_inference(num_imgs=1):
-    TEST_TIMES = read_inference_count()
-    TEST_TIMES += int(num_imgs)
-    write_inference_count(TEST_TIMES)
+    write_inference_count(num_imgs)
 
     return (
         f"Total inference runs: 
@@ -335,7 +338,7 @@ def generate(
 .gradio-container{max-width: 640px !important}
 h1{text-align:center}
 """
-with gr.Blocks(css=css) as demo:
+with gr.Blocks(css=css, title='Sana') as demo:
     gr.Markdown(title)
     gr.HTML(DESCRIPTION)
     gr.DuplicateButton(

From f732c8c5fa1cd362f01b03f3fabe659cd0286630 Mon Sep 17 00:00:00 2001
From: LoveSy
Date: Sun, 24 Nov 2024 10:01:43 -0500
Subject: [PATCH 3/7] Fix lint

---
 app/app_sana.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/app/app_sana.py b/app/app_sana.py
index 4839c25..1b905be 100755
--- a/app/app_sana.py
+++ b/app/app_sana.py
@@ -43,7 +43,7 @@
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
 os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
-COUNTER_DB = os.getenv('COUNTER_DB', '.count.db')
+COUNTER_DB = os.getenv("COUNTER_DB", ".count.db")
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
@@ -115,9 +115,10 @@
 TEST_TIMES = 0
 INFER_SPEED = 0
 
+
 def open_db():
     db = sqlite3.connect(COUNTER_DB)
-    db.execute('CREATE TABLE IF NOT EXISTS counter(app CHARS PRIMARY KEY UNIQUE, value INTEGER)')
+    db.execute("CREATE TABLE IF NOT EXISTS counter(app CHARS PRIMARY KEY UNIQUE, value INTEGER)")
     db.execute('INSERT OR IGNORE INTO counter(app, value) VALUES("Sana", 0)')
     return db
 
@@ -338,7 +339,7 @@ def generate(
 .gradio-container{max-width: 640px !important}
 h1{text-align:center}
 """
-with gr.Blocks(css=css, title='Sana') as demo:
+with gr.Blocks(css=css, title="Sana") as demo:
     gr.Markdown(title)
     gr.HTML(DESCRIPTION)
     gr.DuplicateButton(

From b0c9c2d9538ef49cad4806b58e852788dedc85d8 Mon Sep 17 00:00:00 2001
From: LoveSy
Date: Sun, 24 Nov 2024 18:40:10 -0500
Subject: [PATCH 4/7] Remote TEST_TIMES

---
 app/app_sana.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/app/app_sana.py b/app/app_sana.py
index 1b905be..4f0c43a 100755
--- a/app/app_sana.py
+++ b/app/app_sana.py
@@ -112,7 +112,6 @@
 SCHEDULE_NAME = ["Flow_DPM_Solver"]
 DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
 NUM_IMAGES_PER_PROMPT = 1
-TEST_TIMES = 0
 INFER_SPEED = 0
 
 
@@ -139,10 +138,11 @@ def write_inference_count(count):
 
 
 def run_inference(num_imgs=1):
     write_inference_count(num_imgs)
+    count = read_inference_count()
 
     return (
         f"Total inference runs: 
-        f"16px; color:red; font-weight: bold;'>{TEST_TIMES}"
+        f"16px; color:red; font-weight: bold;'>{count}"
     )
 
@@ -242,13 +242,12 @@ def generate(
     flow_dpms_inference_steps: int = 20,
     randomize_seed: bool = False,
 ):
-    global TEST_TIMES
     global INFER_SPEED
     # seed = 823753551
     run_inference(num_imgs)
     seed = int(randomize_seed_fn(seed, randomize_seed))
     generator = torch.Generator(device=device).manual_seed(seed)
-    print(f"PORT: {DEMO_PORT}, model_path: {model_path}, time_times: {TEST_TIMES}")
+    print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
 
     if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
         prompt = "A red heart."
@@ -302,7 +301,6 @@ def generate(
     )
 
 
-TEST_TIMES = read_inference_count()
 model_size = "1.6" if "1600M" in args.model_path else "0.6"
 title = f"""
From 6b7dddbf77240f34a33e59a62dd373232a887165 Mon Sep 17 00:00:00 2001
From: LoveSy
Date: Mon, 25 Nov 2024 10:10:02 -0500
Subject: [PATCH 5/7] Update on trigger instead of click

---
 app/app_sana.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/app/app_sana.py b/app/app_sana.py
index 4f0c43a..67dee52 100755
--- a/app/app_sana.py
+++ b/app/app_sana.py
@@ -244,7 +244,7 @@ def generate(
 ):
     global INFER_SPEED
     # seed = 823753551
-    run_inference(num_imgs)
+    box = run_inference(num_imgs)
     seed = int(randomize_seed_fn(seed, randomize_seed))
     generator = torch.Generator(device=device).manual_seed(seed)
     print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
@@ -298,6 +298,7 @@ def generate(
         img,
         seed,
         f"Inference Speed: {INFER_SPEED:.3f} s/Img",
+        box,
     )
 
 
@@ -445,8 +446,6 @@ def generate(
                 value=1,
             )
 
-    run_button.click(fn=run_inference, inputs=num_imgs, outputs=info_box)
-
     gr.Examples(
         examples=examples,
         inputs=prompt,
@@ -483,7 +482,7 @@ def generate(
             flow_dpms_inference_steps,
            randomize_seed,
         ],
-        outputs=[result, seed, speed_box],
+        outputs=[result, seed, speed_box, info_box],
         api_name="run",
     )

From b215540bbf013f3ee470a4c4e4e6a003984b8623 Mon Sep 17 00:00:00 2001
From: LoveSy
Date: Mon, 25 Nov 2024 10:43:05 -0500
Subject: [PATCH 6/7] Fix batch

---
 app/sana_pipeline.py | 157 ++++++++++++++++++++++---------------------
 1 file changed, 79 insertions(+), 78 deletions(-)

diff --git a/app/sana_pipeline.py b/app/sana_pipeline.py
index cfafe2b..e89b063 100644
--- a/app/sana_pipeline.py
+++ b/app/sana_pipeline.py
@@ -231,86 +231,87 @@ def forward(
             ),
             torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
         )
+
         for _ in range(num_images_per_prompt):
-            with torch.no_grad():
-                prompts.append(
-                    prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip()
-                )
+            prompts.append(
+                prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip()
+            )
 
-                # prepare text feature
-                if not self.config.text_encoder.chi_prompt:
-                    max_length_all = self.config.text_encoder.model_max_length
-                    prompts_all = prompts
-                else:
-                    chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
-                    prompts_all = [chi_prompt + prompt for prompt in prompts]
-                    num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
-                    max_length_all = (
-                        num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
-                    )  # magic number 2: [bos], [_]
-
-                caption_token = self.tokenizer(
-                    prompts_all,
-                    max_length=max_length_all,
-                    padding="max_length",
-                    truncation=True,
-                    return_tensors="pt",
-                ).to(device=self.device)
-                select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
-                caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
-                    :, :, select_index
-                ].to(self.weight_dtype)
-                emb_masks = caption_token.attention_mask[:, select_index]
-                null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
-
-                n = len(prompts)
-                if latents is None:
-                    z = torch.randn(
-                        n,
-                        self.config.vae.vae_latent_dim,
-                        self.latent_size_h,
-                        self.latent_size_w,
-                        generator=generator,
-                        device=self.device,
-                        dtype=self.weight_dtype,
-                    )
-                else:
-                    z = latents.to(self.weight_dtype).to(self.device)
-                model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
-                if self.vis_sampler == "flow_euler":
-                    flow_solver = FlowEuler(
-                        self.model,
-                        condition=caption_embs,
-                        uncondition=null_y,
-                        cfg_scale=guidance_scale,
-                        model_kwargs=model_kwargs,
-                    )
-                    sample = flow_solver.sample(
-                        z,
-                        steps=num_inference_steps,
-                    )
-                elif self.vis_sampler == "flow_dpm-solver":
-                    scheduler = DPMS(
-                        self.model,
-                        condition=caption_embs,
-                        uncondition=null_y,
-                        guidance_type=self.guidance_type,
-                        cfg_scale=guidance_scale,
-                        pag_scale=pag_guidance_scale,
-                        pag_applied_layers=self.config.model.pag_applied_layers,
-                        model_type="flow",
-                        model_kwargs=model_kwargs,
-                        schedule="FLOW",
-                    )
-                    scheduler.register_progress_bar(self.progress_fn)
-                    sample = scheduler.sample(
-                        z,
-                        steps=num_inference_steps,
-                        order=2,
-                        skip_type="time_uniform_flow",
-                        method="multistep",
-                        flow_shift=self.flow_shift,
-                    )
+        with torch.no_grad():
+            # prepare text feature
+            if not self.config.text_encoder.chi_prompt:
+                max_length_all = self.config.text_encoder.model_max_length
+                prompts_all = prompts
+            else:
+                chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
+                prompts_all = [chi_prompt + prompt for prompt in prompts]
+                num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+                max_length_all = (
+                    num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
+                )  # magic number 2: [bos], [_]
+
+            caption_token = self.tokenizer(
+                prompts_all,
+                max_length=max_length_all,
+                padding="max_length",
+                truncation=True,
+                return_tensors="pt",
+            ).to(device=self.device)
+            select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
+            caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
+                :, :, select_index
+            ].to(self.weight_dtype)
+            emb_masks = caption_token.attention_mask[:, select_index]
+            null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
+
+            n = len(prompts)
+            if latents is None:
+                z = torch.randn(
+                    n,
+                    self.config.vae.vae_latent_dim,
+                    self.latent_size_h,
+                    self.latent_size_w,
+                    generator=generator,
+                    device=self.device,
+                    dtype=self.weight_dtype,
+                )
+            else:
+                z = latents.to(self.weight_dtype).to(self.device)
+            model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
+            if self.vis_sampler == "flow_euler":
+                flow_solver = FlowEuler(
+                    self.model,
+                    condition=caption_embs,
+                    uncondition=null_y,
+                    cfg_scale=guidance_scale,
+                    model_kwargs=model_kwargs,
+                )
+                sample = flow_solver.sample(
+                    z,
+                    steps=num_inference_steps,
+                )
+            elif self.vis_sampler == "flow_dpm-solver":
+                scheduler = DPMS(
+                    self.model,
+                    condition=caption_embs,
+                    uncondition=null_y,
+                    guidance_type=self.guidance_type,
+                    cfg_scale=guidance_scale,
+                    pag_scale=pag_guidance_scale,
+                    pag_applied_layers=self.config.model.pag_applied_layers,
+                    model_type="flow",
+                    model_kwargs=model_kwargs,
+                    schedule="FLOW",
+                )
+                scheduler.register_progress_bar(self.progress_fn)
+                sample = scheduler.sample(
+                    z,
+                    steps=num_inference_steps,
+                    order=2,
+                    skip_type="time_uniform_flow",
+                    method="multistep",
+                    flow_shift=self.flow_shift,
+                )
 
         sample = sample.to(self.weight_dtype)
         with torch.no_grad():

From 2c0f0c341d6fb6d9361aab0a4b5e77eec215e371 Mon Sep 17 00:00:00 2001
From: lawrence-cj
Date: Mon, 25 Nov 2024 23:45:00 +0800
Subject: [PATCH 7/7] pre-commit;

Signed-off-by: lawrence-cj
---
 Dockerfile           | 1 -
 app/sana_pipeline.py | 4 +---
 2 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index dc317e7..530eb8b 100755
--- a/Dockerfile
+++ b/Dockerfile
@@ -25,4 +25,3 @@ RUN ./environment_setup.sh sana
 
 # COPY server.py server.py
 CMD ["conda", "run", "-n", "sana", "--no-capture-output", "python", "-u", "-W", "ignore", "app/app_sana.py", "--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml", "--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth"]
-
diff --git a/app/sana_pipeline.py b/app/sana_pipeline.py
index e89b063..a3251c1 100644
--- a/app/sana_pipeline.py
+++ b/app/sana_pipeline.py
@@ -233,9 +233,7 @@ def forward(
         )
 
         for _ in range(num_images_per_prompt):
-            prompts.append(
-                prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip()
-            )
+            prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())
 
         with torch.no_grad():
             # prepare text feature