diff --git a/README.md b/README.md index b779785..030c286 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ # ⚑️Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer +### ICLR 2025 +
@@ -36,18 +38,24 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion models (e ## πŸ”₯πŸ”₯ News -- (πŸ”₯ New) \[2025/1/12\] DC-AE tiling makes Sana-4K inferences 4096x4096px images within 22GB GPU memory.[\[Guidance\]](asset/docs/model_zoo.md#-3-4k-models) +- (πŸ”₯ New) \[2025/1/24\] 4bit Sana is released, powered by the [SVDQuant and Nunchaku](https://github.com/mit-han-lab/nunchaku) inference engine. You can now run Sana within **8GB** GPU VRAM. [\[Guidance\]](asset/docs/4bit_sana.md) [\[Demo\]](https://svdquant.mit.edu/) [\[Model\]](asset/docs/model_zoo.md) +- (πŸ”₯ New) \[2025/1/24\] DC-AE 1.1 is released with better reconstruction quality. [\[Model\]](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.1) [\[diffusers\]](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers) +- (πŸ”₯ New) \[2025/1/23\] Sana is accepted to ICLR 2025. πŸŽ‰πŸŽ‰πŸŽ‰ + +______________________________________________________________________ + +- (πŸ”₯ New) \[2025/1/12\] DC-AE tiling lets Sana-4K generate 4096x4096px images within 22GB GPU memory. With model offloading and 8-bit/4-bit quantization, 4K Sana runs within **8GB** GPU VRAM. [\[Guidance\]](asset/docs/model_zoo.md#-3-4k-models) - (πŸ”₯ New) \[2025/1/11\] Sana code-base license changed to Apache 2.0. - (πŸ”₯ New) \[2025/1/10\] Inference Sana with 8bit quantization.[\[Guidance\]](asset/docs/8bit_sana.md#quantization) - (πŸ”₯ New) \[2025/1/8\] 4K resolution [Sana models](asset/docs/model_zoo.md) is supported in [Sana-ComfyUI](https://github.com/Efficient-Large-Model/ComfyUI_ExtraModels) and [work flow](asset/docs/ComfyUI/Sana_FlowEuler_4K.json) is also prepared. [\[4K guidance\]](asset/docs/ComfyUI/comfyui.md) - (πŸ”₯ New) \[2025/1/8\] 1.6B 4K resolution [Sana models](asset/docs/model_zoo.md) are released: [\[BF16 pth\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_4Kpx_BF16) or [\[BF16 diffusers\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers). πŸš€ Get your 4096x4096 resolution images within 20 seconds! Find more samples in [Sana page](https://nvlabs.github.io/Sana/). Thanks [SUPIR](https://github.com/Fanghua-Yu/SUPIR) for their wonderful work and support. - (πŸ”₯ New) \[2025/1/2\] Bug in the `diffusers` pipeline is solved. [Solved PR](https://github.com/huggingface/diffusers/pull/10431) - (πŸ”₯ New) \[2025/1/2\] 2K resolution [Sana models](asset/docs/model_zoo.md) is supported in [Sana-ComfyUI](https://github.com/Efficient-Large-Model/ComfyUI_ExtraModels) and [work flow](asset/docs/ComfyUI/Sana_FlowEuler_2K.json) is also prepared. -- (πŸ”₯ New) \[2024/12/20\] 1.6B 2K resolution [Sana models](asset/docs/model_zoo.md) are released: [\[BF16 pth\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16) or [\[BF16 diffusers\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers). πŸš€ Get your 2K resolution images within 4 seconds! Find more samples in [Sana page](https://nvlabs.github.io/Sana/). Thanks [SUPIR](https://github.com/Fanghua-Yu/SUPIR) for their wonderful work and support. -- (πŸ”₯ New) \[2024/12/18\] `diffusers` supports Sana-LoRA fine-tuning! Sana-LoRA's training and convergence speed is super fast. [\[Guidance\]](asset/docs/sana_lora_dreambooth.md) or [\[diffusers docs\]](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sana.md). -- (πŸ”₯ New) \[2024/12/13\] `diffusers` has Sana! 
[All Sana models in diffusers safetensors](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released and diffusers pipeline `SanaPipeline`, `SanaPAGPipeline`, `DPMSolverMultistepScheduler(with FlowMatching)` are all supported now. We prepare a [Model Card](asset/docs/model_zoo.md) for you to choose. -- (πŸ”₯ New) \[2024/12/10\] 1.6B BF16 [Sana model](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16) is released for stable fine-tuning. -- (πŸ”₯ New) \[2024/12/9\] We release the [ComfyUI node](https://github.com/Efficient-Large-Model/ComfyUI_ExtraModels) for Sana. [\[Guidance\]](asset/docs/ComfyUI/comfyui.md) +- βœ… \[2024/12\] 1.6B 2K resolution [Sana models](asset/docs/model_zoo.md) are released: [\[BF16 pth\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16) or [\[BF16 diffusers\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers). πŸš€ Get your 2K resolution images within 4 seconds! Find more samples in [Sana page](https://nvlabs.github.io/Sana/). Thanks [SUPIR](https://github.com/Fanghua-Yu/SUPIR) for their wonderful work and support. +- βœ… \[2024/12\] `diffusers` supports Sana-LoRA fine-tuning! Sana-LoRA's training and convergence speed is super fast. [\[Guidance\]](asset/docs/sana_lora_dreambooth.md) or [\[diffusers docs\]](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sana.md). +- βœ… \[2024/12\] `diffusers` has Sana! [All Sana models in diffusers safetensors](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released, and the `SanaPipeline` and `SanaPAGPipeline` pipelines and the `DPMSolverMultistepScheduler(with FlowMatching)` scheduler are all supported now. We prepared a [Model Card](asset/docs/model_zoo.md) to help you choose. +- βœ… \[2024/12\] 1.6B BF16 [Sana model](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16) is released for stable fine-tuning. +- βœ… \[2024/12\] We release the [ComfyUI node](https://github.com/Efficient-Large-Model/ComfyUI_ExtraModels) for Sana. [\[Guidance\]](asset/docs/ComfyUI/comfyui.md) - βœ… \[2024/11\] All multi-linguistic (Emoji & Chinese & English) SFT models are released: [1.6B-512px](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing), [1.6B-1024px](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing), [600M-512px](https://huggingface.co/Efficient-Large-Model/Sana_600M_512px), [600M-1024px](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px). The metric performance is shown [here](#performance) - βœ… \[2024/11\] Sana Replicate API is launching at [Sana-API](https://replicate.com/chenxwh/sana). - βœ… \[2024/11\] 1.6B [Sana models](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released. diff --git a/app/app_sana_4bit.py b/app/app_sana_4bit.py new file mode 100644 index 0000000..44bb3d9 --- /dev/null +++ b/app/app_sana_4bit.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import argparse +import os +import random +import time +import uuid +from datetime import datetime + +import gradio as gr +import numpy as np +import spaces +import torch +from diffusers import SanaPipeline +from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel +from torchvision.utils import save_image + +MAX_SEED = np.iinfo(np.int32).max +CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1" +MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096")) +USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1" +ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1" +DEMO_PORT = int(os.getenv("DEMO_PORT", "15432")) +os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache" +COUNTER_DB = os.getenv("COUNTER_DB", ".count.db") +INFER_SPEED = 0 + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +style_list = [ + { + "name": "(No style)", + "prompt": "{prompt}", + "negative_prompt": "", + }, + { + "name": "Cinematic", + "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, " + "cinemascope, moody, epic, gorgeous, film grain, grainy", + "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured", + }, + { + "name": "Photographic", + "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed", + "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly", + }, + { + "name": "Anime", + "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed", + "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast", + }, + { + "name": "Manga", + "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style", + "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style", + }, + { + "name": "Digital Art", + "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed", + "negative_prompt": "photo, photorealistic, realism, ugly", + }, + { + "name": "Pixel art", + "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics", + "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic", + }, + { + "name": "Fantasy art", + "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, " + "majestic, magical, fantasy art, cover art, dreamy", + "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, " + "glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, " + "disfigured, sloppy, duplicate, mutated, black and white", + }, + { + "name": "Neonpunk", + "prompt": "neonpunk style {prompt} . 
cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, " + "detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, " + "ultra detailed, intricate, professional", + "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured", + }, + { + "name": "3D Model", + "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting", + "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting", + }, +] + +styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list} +STYLE_NAMES = list(styles.keys()) +DEFAULT_STYLE_NAME = "(No style)" +SCHEDULE_NAME = ["Flow_DPM_Solver"] +DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver" +NUM_IMAGES_PER_PROMPT = 1 + + +def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]: + p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME]) + if not negative: + negative = "" + return p.replace("{prompt}", positive), n + negative + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_path", + nargs="?", + default="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", + type=str, + help="Path to the model file (positional)", + ) + parser.add_argument("--share", action="store_true") + + return parser.parse_known_args()[0] + + +args = get_args() + +if torch.cuda.is_available(): + + transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m") + pipe = SanaPipeline.from_pretrained( + "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", + transformer=transformer, + variant="bf16", + torch_dtype=torch.bfloat16, + ).to(device) + + pipe.text_encoder.to(torch.bfloat16) + pipe.vae.to(torch.bfloat16) + + +def save_image_sana(img, seed="", save_img=False): + unique_name = f"{str(uuid.uuid4())}_{seed}.png" + save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}") + os.umask(0o000) # file permission: 666; dir permission: 777 + os.makedirs(save_path, exist_ok=True) + unique_name = os.path.join(save_path, unique_name) + if save_img: + save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1)) + + return unique_name + + +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + + +@torch.no_grad() +@torch.inference_mode() +@spaces.GPU(enable_queue=True) +def generate( + prompt: str = None, + negative_prompt: str = "", + style: str = DEFAULT_STYLE_NAME, + use_negative_prompt: bool = False, + num_imgs: int = 1, + seed: int = 0, + height: int = 1024, + width: int = 1024, + flow_dpms_guidance_scale: float = 5.0, + flow_dpms_inference_steps: int = 20, + randomize_seed: bool = False, +): + global INFER_SPEED + # seed = 823753551 + seed = int(randomize_seed_fn(seed, randomize_seed)) + generator = torch.Generator(device=device).manual_seed(seed) + print(f"PORT: {DEMO_PORT}, model_path: {args.model_path}") + + print(prompt) + + num_inference_steps = flow_dpms_inference_steps + guidance_scale = flow_dpms_guidance_scale + + if not use_negative_prompt: + negative_prompt = None # type: ignore + prompt, negative_prompt = apply_style(style, prompt, negative_prompt) + + time_start = time.time() + images = pipe( + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_imgs, + 
generator=generator, + ).images + INFER_SPEED = (time.time() - time_start) / num_imgs + + save_img = False + if save_img: + img = [save_image_sana(img, seed, save_img=save_image) for img in images] + print(img) + else: + img = images + + torch.cuda.empty_cache() + + return ( + img, + seed, + f"Inference Speed: {INFER_SPEED:.3f} s/Img", + ) + + +model_size = "1.6" if "1600M" in args.model_path else "0.6" +title = f""" +
    logo +""" +DESCRIPTION = f""" +        Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer (4bit version) +        """ +if model_size == "0.6": +    DESCRIPTION += "\n0.6B model's text rendering ability is limited." +if not torch.cuda.is_available(): +    DESCRIPTION += "\nRunning on CPU πŸ₯Ά This demo does not work on CPU.
" + +examples = [ + 'a cyberpunk cat with a neon sign that says "Sana"', + "A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.", + "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.", + "portrait photo of a girl, photograph, highly detailed face, depth of field", + 'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language', + "🐢 Wearing πŸ•Ά flying on the 🌈", + "πŸ‘§ with 🌹 in the ❄️", + "an old rusted robot wearing pants and a jacket riding skis in a supermarket.", + "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.", + "Astronaut in a jungle, cold color palette, muted colors, detailed", + "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests", +] + +css = """ +.gradio-container {max-width: 850px !important; height: auto !important;} +h1 {text-align: center;} +""" +theme = gr.themes.Base() +with gr.Blocks(css=css, theme=theme, title="Sana") as demo: + gr.Markdown(title) + gr.HTML(DESCRIPTION) + gr.DuplicateButton( + value="Duplicate Space for private use", + elem_id="duplicate-button", + visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1", + ) + # with gr.Row(equal_height=False): + with gr.Group(): + with gr.Row(): + prompt = gr.Text( + label="Prompt", + show_label=False, + max_lines=1, + placeholder="Enter your prompt", + container=False, + ) + run_button = gr.Button("Run", scale=0) + result = gr.Gallery( + label="Result", + show_label=False, + height=750, + columns=NUM_IMAGES_PER_PROMPT, + format="jpeg", + ) + + speed_box = gr.Markdown( + value=f"Inference speed: {INFER_SPEED} s/Img" + ) + with gr.Accordion("Advanced options", open=False): + with gr.Group(): + with gr.Row(visible=True): + height = gr.Slider( + label="Height", + minimum=256, + maximum=MAX_IMAGE_SIZE, + step=32, + value=1024, + ) + width = gr.Slider( + label="Width", + minimum=256, + maximum=MAX_IMAGE_SIZE, + step=32, + value=1024, + ) + with gr.Row(): + flow_dpms_inference_steps = gr.Slider( + label="Sampling steps", + minimum=5, + maximum=40, + step=1, + value=20, + ) + flow_dpms_guidance_scale = gr.Slider( + label="CFG Guidance scale", + minimum=1, + maximum=10, + step=0.1, + value=4.5, + ) + with gr.Row(): + use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True) + negative_prompt = gr.Text( + label="Negative prompt", + max_lines=1, + placeholder="Enter a negative prompt", + visible=True, + ) + style_selection = gr.Radio( + show_label=True, + container=True, + interactive=True, + choices=STYLE_NAMES, + value=DEFAULT_STYLE_NAME, + label="Image Style", + ) + seed = gr.Slider( + label="Seed", + minimum=0, + maximum=MAX_SEED, + step=1, + value=0, + ) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + with gr.Row(visible=True): + schedule = gr.Radio( + show_label=True, + container=True, + interactive=True, + choices=SCHEDULE_NAME, + value=DEFAULT_SCHEDULE_NAME, + label="Sampler Schedule", 
+ visible=True, + ) + num_imgs = gr.Slider( + label="Num Images", + minimum=1, + maximum=6, + step=1, + value=1, + ) + + gr.Examples( + examples=examples, + inputs=prompt, + outputs=[result, seed], + fn=generate, + cache_examples=CACHE_EXAMPLES, + ) + + use_negative_prompt.change( + fn=lambda x: gr.update(visible=x), + inputs=use_negative_prompt, + outputs=negative_prompt, + api_name=False, + ) + + gr.on( + triggers=[ + prompt.submit, + negative_prompt.submit, + run_button.click, + ], + fn=generate, + inputs=[ + prompt, + negative_prompt, + style_selection, + use_negative_prompt, + num_imgs, + seed, + height, + width, + flow_dpms_guidance_scale, + flow_dpms_inference_steps, + randomize_seed, + ], + outputs=[result, seed, speed_box], + api_name="run", + ) + +if __name__ == "__main__": + demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share) diff --git a/app/app_sana_4bit_compare_bf16.py b/app/app_sana_4bit_compare_bf16.py new file mode 100644 index 0000000..e378a22 --- /dev/null +++ b/app/app_sana_4bit_compare_bf16.py @@ -0,0 +1,313 @@ +# Changed from https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py +import argparse +import os +import random +import time +from datetime import datetime + +import GPUtil + +# import gradio last to avoid conflicts with other imports +import gradio as gr +import safety_check +import spaces +import torch +from diffusers import SanaPipeline +from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel +from transformers import AutoModelForCausalLM, AutoTokenizer + +MAX_IMAGE_SIZE = 2048 +MAX_SEED = 1000000000 + +DEFAULT_HEIGHT = 1024 +DEFAULT_WIDTH = 1024 + +# num_inference_steps, guidance_scale, seed +EXAMPLES = [ + [ + "🐢 Wearing πŸ•Ά flying on the 🌈", + 1024, + 1024, + 20, + 5, + 2, + ], + [ + "ε€§ζΌ ε­€ηƒŸη›΄, 长河落ζ—₯εœ†", + 1024, + 1024, + 20, + 5, + 23, + ], + [ + "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, " + "volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, " + "art nouveau style, illustration art artwork by SenseiJaye, intricate detail.", + 1024, + 1024, + 20, + 5, + 233, + ], + [ + "A photo of a Eurasian lynx in a sunlit forest, with tufted ears and a spotted coat. The lynx should be " + "sharply focused, gazing into the distance, while the background is softly blurred for depth. Use cinematic " + "lighting with soft rays filtering through the trees, and capture the scene with a shallow depth of field " + "for a natural, peaceful atmosphere. 8K resolution, highly detailed, photorealistic, " + "cinematic lighting, ultra-HD.", + 1024, + 1024, + 20, + 5, + 2333, + ], + [ + "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. " + "She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. " + "She wears sunglasses and red lipstick. She walks confidently and casually. " + "The street is damp and reflective, creating a mirror effect of the colorful lights. " + "Many pedestrians walk about.", + 1024, + 1024, + 20, + 5, + 23333, + ], + [ + "Cozy bedroom with vintage wooden furniture and a large circular window covered in lush green vines, " + "opening to a misty forest. Soft, ambient lighting highlights the bed with crumpled blankets, a bookshelf, " + "and a desk. The atmosphere is serene and natural. 
8K resolution, highly detailed, photorealistic, " + "cinematic lighting, ultra-HD.", + 1024, + 1024, + 20, + 5, + 233333, + ], +] + + +def hash_str_to_int(s: str) -> int: + """Hash a string to an integer.""" + modulus = 10**9 + 7 # Large prime modulus + hash_int = 0 + for char in s: + hash_int = (hash_int * 31 + ord(char)) % modulus + return hash_int + + +def get_pipeline( + precision: str, use_qencoder: bool = False, device: str | torch.device = "cuda", pipeline_init_kwargs: dict = {} +) -> SanaPipeline: + if precision == "int4": + assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices" + transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m") + + pipeline_init_kwargs["transformer"] = transformer + if use_qencoder: + raise NotImplementedError("Quantized encoder not supported for Sana for now") + else: + assert precision == "bf16" + pipeline = SanaPipeline.from_pretrained( + "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", + variant="bf16", + torch_dtype=torch.bfloat16, + **pipeline_init_kwargs, + ) + + pipeline = pipeline.to(device) + return pipeline + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-p", + "--precisions", + type=str, + default=["int4"], + nargs="*", + choices=["int4", "bf16"], + help="Which precisions to use", + ) + parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder") + parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker") + parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses") + return parser.parse_args() + + +args = get_args() + + +pipelines = [] +pipeline_init_kwargs = {} +for i, precision in enumerate(args.precisions): + + pipeline = get_pipeline( + precision=precision, + use_qencoder=args.use_qencoder, + device="cuda", + pipeline_init_kwargs={**pipeline_init_kwargs}, + ) + pipelines.append(pipeline) + if i == 0: + pipeline_init_kwargs["vae"] = pipeline.vae + pipeline_init_kwargs["text_encoder"] = pipeline.text_encoder + +# safety checker +safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path) +safety_checker_model = AutoModelForCausalLM.from_pretrained( + args.shield_model_path, + device_map="auto", + torch_dtype=torch.bfloat16, +).to(pipeline.device) + + +@spaces.GPU(enable_queue=True) +def generate( + prompt: str = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 4, + guidance_scale: float = 0, + seed: int = 0, +): + print(f"Prompt: {prompt}") + is_unsafe_prompt = False + if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2): + prompt = "A peaceful world." 
+ images, latency_strs = [], [] + for i, pipeline in enumerate(pipelines): + progress = gr.Progress(track_tqdm=True) + start_time = time.time() + image = pipeline( + prompt=prompt, + height=height, + width=width, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=torch.Generator().manual_seed(seed), + ).images[0] + end_time = time.time() + latency = end_time - start_time + if latency < 1: + latency = latency * 1000 + latency_str = f"{latency:.2f}ms" + else: + latency_str = f"{latency:.2f}s" + images.append(image) + latency_strs.append(latency_str) + if is_unsafe_prompt: + for i in range(len(latency_strs)): + latency_strs[i] += " (Unsafe prompt detected)" + torch.cuda.empty_cache() + + if args.count_use: + if os.path.exists("use_count.txt"): + with open("use_count.txt") as f: + count = int(f.read()) + else: + count = 0 + count += 1 + current_time = datetime.now() + print(f"{current_time}: {count}") + with open("use_count.txt", "w") as f: + f.write(str(count)) + with open("use_record.txt", "a") as f: + f.write(f"{current_time}: {count}\n") + + return *images, *latency_strs + + +with open("./assets/description.html") as f: + DESCRIPTION = f.read() +gpus = GPUtil.getGPUs() +if len(gpus) > 0: + gpu = gpus[0] + memory = gpu.memoryTotal / 1024 + device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory." +else: + device_info = "Running on CPU πŸ₯Ά This demo does not work on CPU." +notice = f'Notice: We will replace unsafe prompts with a default prompt: "A peaceful world."' + +with gr.Blocks( + css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"], + title=f"SVDQuant SANA-1600M Demo", +) as demo: + + def get_header_str(): + + if args.count_use: + if os.path.exists("use_count.txt"): + with open("use_count.txt") as f: + count = int(f.read()) + else: + count = 0 + count_info = ( + f"
" + f"Total inference runs: " + f" {count}
" + ) + else: + count_info = "" + header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info) + return header_str + + header = gr.HTML(get_header_str()) + demo.load(fn=get_header_str, outputs=header) + + with gr.Row(): + image_results, latency_results = [], [] + for i, precision in enumerate(args.precisions): + with gr.Column(): + gr.Markdown(f"# {precision.upper()}", elem_id="image_header") + with gr.Group(): + image_result = gr.Image( + format="png", + image_mode="RGB", + label="Result", + show_label=False, + show_download_button=True, + interactive=False, + ) + latency_result = gr.Text(label="Inference Latency", show_label=True) + image_results.append(image_result) + latency_results.append(latency_result) + with gr.Row(): + prompt = gr.Text( + label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, scale=4 + ) + run_button = gr.Button("Run", scale=1) + + with gr.Row(): + seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4) + randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed") + with gr.Accordion("Advanced options", open=False): + with gr.Group(): + height = gr.Slider(label="Height", minimum=256, maximum=4096, step=32, value=1024) + width = gr.Slider(label="Width", minimum=256, maximum=4096, step=32, value=1024) + with gr.Group(): + num_inference_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, step=1, value=20) + guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=5) + + input_args = [prompt, height, width, num_inference_steps, guidance_scale, seed] + + gr.Examples(examples=EXAMPLES, inputs=input_args, outputs=[*image_results, *latency_results], fn=generate) + + gr.on( + triggers=[prompt.submit, run_button.click], + fn=generate, + inputs=input_args, + outputs=[*image_results, *latency_results], + api_name="run", + ) + randomize_seed.click( + lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False + ).then(fn=generate, inputs=input_args, outputs=[*image_results, *latency_results], api_name=False, queue=False) + + gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility") + + +if __name__ == "__main__": + demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True) diff --git a/asset/docs/4bit_sana.md b/asset/docs/4bit_sana.md new file mode 100644 index 0000000..e79079a --- /dev/null +++ b/asset/docs/4bit_sana.md @@ -0,0 +1,68 @@ + + +# 4bit SanaPipeline + +### 1. Environment setup + +Follow the official [SVDQuant-Nunchaku](https://github.com/mit-han-lab/nunchaku) repository to set up the environment. The guidance can be found [here](https://github.com/mit-han-lab/nunchaku?tab=readme-ov-file#installation). + +### 2. Code snap for inference + +Here we show the code snippet for SanaPipeline. For SanaPAGPipeline, please refer to the [SanaPAGPipeline](https://github.com/mit-han-lab/nunchaku/blob/main/examples/sana_1600m_pag.py) section. 
+ +```python +import torch +from diffusers import SanaPipeline + +from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel + +transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m") +pipe = SanaPipeline.from_pretrained( + "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", + transformer=transformer, + variant="bf16", + torch_dtype=torch.bfloat16, +).to("cuda") + +pipe.text_encoder.to(torch.bfloat16) +pipe.vae.to(torch.bfloat16) + +image = pipe( + prompt="A cute 🐼 eating πŸŽ‹, ink drawing style", + height=1024, + width=1024, + guidance_scale=4.5, + num_inference_steps=20, + generator=torch.Generator().manual_seed(42), +).images[0] +image.save("sana_1600m.png") +``` + +### 3. Online demo + +1). Launch the 4bit Sana. + +```bash +python app/app_sana_4bit.py +``` + +2). Compare with BF16 version + +Refer to the original [Nunchaku-Sana.](https://github.com/mit-han-lab/nunchaku/tree/main/app/sana/t2i) guidance for SanaPAGPipeline + +```bash +python app/app_sana_4bit_compare_bf16.py +``` diff --git a/asset/docs/model_zoo.md b/asset/docs/model_zoo.md index 01ea915..5c7c2ee 100644 --- a/asset/docs/model_zoo.md +++ b/asset/docs/model_zoo.md @@ -9,6 +9,7 @@ | Sana-1.6B | 1024px | [Sana_1600M_1024px](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px) | [Efficient-Large-Model/Sana_1600M_1024px_diffusers](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | fp16/fp32 | - | | Sana-1.6B | 1024px | [Sana_1600M_1024px_MultiLing](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing) | [Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | fp16/fp32 | Multi-Language | | Sana-1.6B | 1024px | [Sana_1600M_1024px_BF16](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16) | [Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | **bf16**/fp32 | Multi-Language | +| Sana-1.6B | 1024px | - | [mit-han-lab/svdq-int4-sana-1600m](https://huggingface.co/mit-han-lab/svdq-int4-sana-1600m) | **int4** | Multi-Language | | Sana-1.6B | 2Kpx | [Sana_1600M_2Kpx_BF16](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16) | [Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers) | **bf16**/fp32 | Multi-Language | | Sana-1.6B | 4Kpx | [Sana_1600M_4Kpx_BF16](https://huggingface.co/Efficient-Large-Model/Sana_1600M_4Kpx_BF16) | [Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers](https://huggingface.co/Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers) | **bf16**/fp32 | Multi-Language | @@ -79,7 +80,7 @@ image[0].save('sana.png') ## ❗ 3. 
4K models -4K models need VAE tiling to avoid OOM issue.(24 GPU is recommended) +4K models need VAE tiling to avoid OOM issues (a 16GB GPU is recommended). ```python # run `pip install git+https://github.com/huggingface/diffusers` before use Sana in diffusers @@ -98,8 +99,12 @@ pipe.text_encoder.to(torch.bfloat16) # for 4096x4096 image generation OOM issue, feel free adjust the tile size if pipe.transformer.config.sample_size == 128: - pipe.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_min_width=1024) - + pipe.vae.enable_tiling( + tile_sample_min_height=1024, + tile_sample_min_width=1024, + tile_sample_stride_height=896, + tile_sample_stride_width=896, + ) prompt = 'a cyberpunk cat with a neon sign that says "Sana"' image = pipe( prompt=prompt, @@ -112,3 +117,37 @@ image = pipe( image[0].save("sana_4K.png") ``` + +## ❗ 4. int4 inference + +This int4 model is quantized with [SVDQuant-Nunchaku](https://github.com/mit-han-lab/nunchaku). First follow the [installation guidance](https://github.com/mit-han-lab/nunchaku?tab=readme-ov-file#installation) for the nunchaku engine, then use the following code snippet to run inference with the int4 Sana model. + +Here we show the code snippet for SanaPipeline. For SanaPAGPipeline, please refer to the [SanaPAGPipeline](https://github.com/mit-han-lab/nunchaku/blob/main/examples/sana_1600m_pag.py) example. + +```python +import torch +from diffusers import SanaPipeline + +from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel + +transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m") +pipe = SanaPipeline.from_pretrained( + "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", + transformer=transformer, + variant="bf16", + torch_dtype=torch.bfloat16, +).to("cuda") + +pipe.text_encoder.to(torch.bfloat16) +pipe.vae.to(torch.bfloat16) + +image = pipe( + prompt="A cute 🐼 eating πŸŽ‹, ink drawing style", + height=1024, + width=1024, + guidance_scale=4.5, + num_inference_steps=20, + generator=torch.Generator().manual_seed(42), +).images[0] +image.save("sana_1600m.png") +```
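For readers aiming at the low-VRAM figures quoted in the README news (4K generation with model offloading), the sketch below combines the Section 3 tiling settings with diffusers' generic `enable_model_cpu_offload()` API. This is an illustrative, unofficial recipe rather than part of this PR; actual memory use depends on your setup, and the quantized variants above reduce it further.

```python
# Unofficial sketch: 4K generation with VAE tiling plus model CPU offload to cut peak VRAM.
# Assumes the same model and settings as Section 3; not an exact reproduction of the 8GB figure.
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
    variant="bf16",
    torch_dtype=torch.bfloat16,
)
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.bfloat16)

# Keep sub-models on CPU and move each one to the GPU only while it runs (standard diffusers API).
pipe.enable_model_cpu_offload()

# Same tiling parameters as Section 3 to avoid OOM at 4096x4096.
if pipe.transformer.config.sample_size == 128:
    pipe.vae.enable_tiling(
        tile_sample_min_height=1024,
        tile_sample_min_width=1024,
        tile_sample_stride_height=896,
        tile_sample_stride_width=896,
    )

image = pipe(
    prompt='a cyberpunk cat with a neon sign that says "Sana"',
    height=4096,
    width=4096,
    guidance_scale=5.0,
    num_inference_steps=20,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_4K_offload.png")
```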