diff --git a/Dockerfile b/Dockerfile
index dc317e7..530eb8b 100755
--- a/Dockerfile
+++ b/Dockerfile
@@ -25,4 +25,3 @@ RUN ./environment_setup.sh sana
 
 # COPY server.py server.py
 CMD ["conda", "run", "-n", "sana", "--no-capture-output", "python", "-u", "-W", "ignore", "app/app_sana.py", "--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml", "--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth"]
-
diff --git a/app/app_sana.py b/app/app_sana.py
index a392924..67dee52 100755
--- a/app/app_sana.py
+++ b/app/app_sana.py
@@ -19,6 +19,8 @@
 import argparse
 import os
 import random
+import socket
+import sqlite3
 import time
 import uuid
 from datetime import datetime
@@ -41,6 +43,7 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
 os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
 
+COUNTER_DB = os.getenv("COUNTER_DB", ".count.db")
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
@@ -109,36 +112,37 @@ SCHEDULE_NAME = ["Flow_DPM_Solver"]
 DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
 NUM_IMAGES_PER_PROMPT = 1
-TEST_TIMES = 0
 INFER_SPEED = 0
-FILENAME = f"output/port{DEMO_PORT}_inference_count.txt"
 
 
-def read_inference_count():
-    global TEST_TIMES
-    try:
-        with open(FILENAME) as f:
-            count = int(f.read().strip())
-    except FileNotFoundError:
-        count = 0
-    TEST_TIMES = count
+def open_db():
+    db = sqlite3.connect(COUNTER_DB)
+    db.execute("CREATE TABLE IF NOT EXISTS counter(app CHARS PRIMARY KEY UNIQUE, value INTEGER)")
+    db.execute('INSERT OR IGNORE INTO counter(app, value) VALUES("Sana", 0)')
+    return db
+
 
-    return count
+def read_inference_count():
+    with open_db() as db:
+        cur = db.execute('SELECT value FROM counter WHERE app="Sana"')
+        db.commit()
+    return cur.fetchone()[0]
 
 
 def write_inference_count(count):
-    with open(FILENAME, "w") as f:
-        f.write(str(count))
+    count = max(0, int(count))
+    with open_db() as db:
+        db.execute(f'UPDATE counter SET value=value+{count} WHERE app="Sana"')
+        db.commit()
 
 
 def run_inference(num_imgs=1):
-    TEST_TIMES = read_inference_count()
-    TEST_TIMES += int(num_imgs)
-    write_inference_count(TEST_TIMES)
+    write_inference_count(num_imgs)
+    count = read_inference_count()
     return (
-        f"Total inference runs: {TEST_TIMES}"
+        f"Total inference runs: <span style='font-size: 16px; color:red; font-weight: bold;'>{count}</span>"
     )
@@ -238,12 +242,12 @@ def generate(
     flow_dpms_inference_steps: int = 20,
     randomize_seed: bool = False,
 ):
-    global TEST_TIMES
    global INFER_SPEED
     # seed = 823753551
+    box = run_inference(num_imgs)
     seed = int(randomize_seed_fn(seed, randomize_seed))
     generator = torch.Generator(device=device).manual_seed(seed)
-    print(f"PORT: {DEMO_PORT}, model_path: {model_path}, time_times: {TEST_TIMES}")
+    print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
 
     if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
         prompt = "A red heart."
@@ -294,11 +298,11 @@ def generate(
         img,
         seed,
         f"Inference Speed: {INFER_SPEED:.3f} s/Img",
+        box,
     )
 
 
-TEST_TIMES = read_inference_count()
-model_size = "1.6" if "D20" in args.model_path else "0.6"
+model_size = "1.6" if "1600M" in args.model_path else "0.6"
 title = f"""
     Sana-{model_size}B {args.image_size}px
     Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer
     [Paper] [Github(coming soon)] [Project]
-    Powered by DC-AE with 32x latent space, running on A6000 node.
+    Powered by DC-AE with 32x latent space, running on node {socket.gethostname()}.
     Unsafe word will give you a 'Red Heart' in the image instead.
""" if model_size == "0.6": @@ -334,9 +338,9 @@ def generate( .gradio-container{max-width: 640px !important} h1{text-align:center} """ -with gr.Blocks(css=css) as demo: +with gr.Blocks(css=css, title="Sana") as demo: gr.Markdown(title) - gr.Markdown(DESCRIPTION) + gr.HTML(DESCRIPTION) gr.DuplicateButton( value="Duplicate Space for private use", elem_id="duplicate-button", @@ -442,8 +446,6 @@ def generate( value=1, ) - run_button.click(fn=run_inference, inputs=num_imgs, outputs=info_box) - gr.Examples( examples=examples, inputs=prompt, @@ -480,9 +482,9 @@ def generate( flow_dpms_inference_steps, randomize_seed, ], - outputs=[result, seed, speed_box], + outputs=[result, seed, speed_box, info_box], api_name="run", ) if __name__ == "__main__": - demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=True, share=True) + demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=False) diff --git a/app/sana_pipeline.py b/app/sana_pipeline.py index cfafe2b..a3251c1 100644 --- a/app/sana_pipeline.py +++ b/app/sana_pipeline.py @@ -231,86 +231,85 @@ def forward( ), torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1), ) + for _ in range(num_images_per_prompt): - with torch.no_grad(): - prompts.append( - prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip() - ) + prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip()) - # prepare text feature - if not self.config.text_encoder.chi_prompt: - max_length_all = self.config.text_encoder.model_max_length - prompts_all = prompts - else: - chi_prompt = "\n".join(self.config.text_encoder.chi_prompt) - prompts_all = [chi_prompt + prompt for prompt in prompts] - num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) - max_length_all = ( - num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2 - ) # magic number 2: [bos], [_] - - caption_token = self.tokenizer( - prompts_all, - max_length=max_length_all, - padding="max_length", - truncation=True, - return_tensors="pt", - ).to(device=self.device) - select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0)) - caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][ - :, :, select_index - ].to(self.weight_dtype) - emb_masks = caption_token.attention_mask[:, select_index] - null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype) - - n = len(prompts) - if latents is None: - z = torch.randn( - n, - self.config.vae.vae_latent_dim, - self.latent_size_h, - self.latent_size_w, - generator=generator, - device=self.device, - dtype=self.weight_dtype, - ) - else: - z = latents.to(self.weight_dtype).to(self.device) - model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks) - if self.vis_sampler == "flow_euler": - flow_solver = FlowEuler( - self.model, - condition=caption_embs, - uncondition=null_y, - cfg_scale=guidance_scale, - model_kwargs=model_kwargs, - ) - sample = flow_solver.sample( - z, - steps=num_inference_steps, - ) - elif self.vis_sampler == "flow_dpm-solver": - scheduler = DPMS( - self.model, - condition=caption_embs, - uncondition=null_y, - guidance_type=self.guidance_type, - cfg_scale=guidance_scale, - pag_scale=pag_guidance_scale, - pag_applied_layers=self.config.model.pag_applied_layers, - model_type="flow", - model_kwargs=model_kwargs, - schedule="FLOW", - ) - 
-                    scheduler.register_progress_bar(self.progress_fn)
-                    sample = scheduler.sample(
-                        z,
-                        steps=num_inference_steps,
-                        order=2,
-                        skip_type="time_uniform_flow",
-                        method="multistep",
-                        flow_shift=self.flow_shift,
-                    )
+        with torch.no_grad():
+            # prepare text feature
+            if not self.config.text_encoder.chi_prompt:
+                max_length_all = self.config.text_encoder.model_max_length
+                prompts_all = prompts
+            else:
+                chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
+                prompts_all = [chi_prompt + prompt for prompt in prompts]
+                num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+                max_length_all = (
+                    num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
+                )  # magic number 2: [bos], [_]
+
+            caption_token = self.tokenizer(
+                prompts_all,
+                max_length=max_length_all,
+                padding="max_length",
+                truncation=True,
+                return_tensors="pt",
+            ).to(device=self.device)
+            select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
+            caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
+                :, :, select_index
+            ].to(self.weight_dtype)
+            emb_masks = caption_token.attention_mask[:, select_index]
+            null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
+
+            n = len(prompts)
+            if latents is None:
+                z = torch.randn(
+                    n,
+                    self.config.vae.vae_latent_dim,
+                    self.latent_size_h,
+                    self.latent_size_w,
+                    generator=generator,
+                    device=self.device,
+                    dtype=self.weight_dtype,
+                )
+            else:
+                z = latents.to(self.weight_dtype).to(self.device)
+            model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
+            if self.vis_sampler == "flow_euler":
+                flow_solver = FlowEuler(
+                    self.model,
+                    condition=caption_embs,
+                    uncondition=null_y,
+                    cfg_scale=guidance_scale,
+                    model_kwargs=model_kwargs,
+                )
+                sample = flow_solver.sample(
+                    z,
+                    steps=num_inference_steps,
+                )
+            elif self.vis_sampler == "flow_dpm-solver":
+                scheduler = DPMS(
+                    self.model,
+                    condition=caption_embs,
+                    uncondition=null_y,
+                    guidance_type=self.guidance_type,
+                    cfg_scale=guidance_scale,
+                    pag_scale=pag_guidance_scale,
+                    pag_applied_layers=self.config.model.pag_applied_layers,
+                    model_type="flow",
+                    model_kwargs=model_kwargs,
+                    schedule="FLOW",
+                )
+                scheduler.register_progress_bar(self.progress_fn)
+                sample = scheduler.sample(
+                    z,
+                    steps=num_inference_steps,
+                    order=2,
+                    skip_type="time_uniform_flow",
+                    method="multistep",
+                    flow_shift=self.flow_shift,
+                )
 
         sample = sample.to(self.weight_dtype)
         with torch.no_grad():
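
Note on the counter change in app/app_sana.py above: the per-port text file is replaced by a small SQLite table keyed by app name, so several demo instances can share one running total. The snippet below is a minimal standalone sketch of that flow for local testing only, assuming the same COUNTER_DB path, table schema, and "Sana" row key as the diff; the parameterized UPDATE is a deliberate deviation from the f-string query used in the app, and the __main__ driver is purely illustrative.

import os
import sqlite3

COUNTER_DB = os.getenv("COUNTER_DB", ".count.db")

def open_db():
    # Create the counter table on first use and seed the "Sana" row, mirroring open_db() in the diff.
    db = sqlite3.connect(COUNTER_DB)
    db.execute("CREATE TABLE IF NOT EXISTS counter(app CHARS PRIMARY KEY UNIQUE, value INTEGER)")
    db.execute('INSERT OR IGNORE INTO counter(app, value) VALUES("Sana", 0)')
    return db

def read_inference_count():
    # "with db" commits the open transaction on exit; it does not close the connection.
    with open_db() as db:
        cur = db.execute('SELECT value FROM counter WHERE app="Sana"')
        return cur.fetchone()[0]

def write_inference_count(count):
    # Increment the shared counter; negative inputs are clamped to zero as in the diff.
    count = max(0, int(count))
    with open_db() as db:
        db.execute('UPDATE counter SET value = value + ? WHERE app="Sana"', (count,))

if __name__ == "__main__":
    write_inference_count(2)  # e.g. one request that produced two images
    print(f"Total inference runs: {read_inference_count()}")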