fix 4K OOM with VAE-tiling #144

Merged · 4 commits · Jan 12, 2025
1 change: 1 addition & 0 deletions README.md
@@ -36,6 +36,7 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.

## 🔥🔥 News

- (🔥 New) \[2025/1/12\] DC-AE tiling lets Sana-4K generate 4096x4096px images within 22GB of GPU memory. [\[Guidance\]](asset/docs/model_zoo.md#-3-4k-models)
- (🔥 New) \[2025/1/11\] Sana code-base license changed to Apache 2.0.
- (🔥 New) \[2025/1/10\] Inference Sana with 8bit quantization.[\[Guidance\]](asset/docs/8bit_sana.md#quantization)
- (🔥 New) \[2025/1/8\] 4K resolution [Sana models](asset/docs/model_zoo.md) are supported in [Sana-ComfyUI](https://github.com/Efficient-Large-Model/ComfyUI_ExtraModels), and a [workflow](asset/docs/ComfyUI/Sana_FlowEuler_4K.json) is also prepared. [\[4K guidance\]](asset/docs/ComfyUI/comfyui.md)
11 changes: 4 additions & 7 deletions asset/docs/model_zoo.md
@@ -77,11 +77,9 @@ image = pipe(
image[0].save('sana.png')
```

#### 2). For 4K models
## ❗ 3. 4K models

4K models need [patch_conv](https://github.com/mit-han-lab/patch_conv) to avoid OOM issue.(80GB GPU is recommended)

run `pip install patch_conv` first, then
4K models need VAE tiling to avoid OOM issues (a 24GB GPU is recommended).

```python
# run `pip install git+https://github.com/huggingface/diffusers` before using Sana in diffusers
@@ -98,10 +96,9 @@ pipe.to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)

# for 4096x4096 image generation OOM issue
# for the 4096x4096 image generation OOM issue; feel free to adjust the tile size
if pipe.transformer.config.sample_size == 128:
from patch_conv import convert_model
pipe.vae = convert_model(pipe.vae, splits=32)
pipe.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_min_width=1024)

prompt = 'a cyberpunk cat with a neon sign that says "Sana"'
image = pipe(
```
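For reference, here is a minimal end-to-end sketch of the updated 4K recipe assembled from the snippet above. The 4K checkpoint id and the sampling arguments are assumptions for illustration, not values taken from this diff; check model_zoo.md for the exact names.

```python
# Minimal sketch of the 4K + VAE-tiling recipe (not the verbatim doc snippet).
# Assumed: the 4K diffusers checkpoint id and the sampling arguments below.
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)

# 4K transformers report sample_size == 128; enable VAE tiling there so the
# decoder never materializes the full 4096x4096 activation maps at once.
if pipe.transformer.config.sample_size == 128:
    pipe.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_min_width=1024)

prompt = 'a cyberpunk cat with a neon sign that says "Sana"'
image = pipe(
    prompt=prompt,
    height=4096,
    width=4096,
    guidance_scale=5.0,
    num_inference_steps=20,
    generator=torch.Generator(device="cuda").manual_seed(42),
)[0]
image[0].save("sana_4k.png")
```

Smaller tile sizes lower peak memory further at the cost of decode speed.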
4 changes: 2 additions & 2 deletions configs/sana_config/2048ms/Sana_1600M_img2048_bf16.yaml
@@ -41,8 +41,8 @@ model:
- 8
# VAE setting
vae:
vae_type: dc-ae
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
vae_type: AutoencoderDC
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers
scale_factor: 0.41407
vae_latent_dim: 32
vae_downsample_rate: 32
4 changes: 2 additions & 2 deletions configs/sana_config/4096ms/Sana_1600M_img4096_bf16.yaml
@@ -41,8 +41,8 @@ model:
- 8
# VAE setting
vae:
vae_type: dc-ae
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
vae_type: AutoencoderDC
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers
scale_factor: 0.41407
vae_latent_dim: 32
vae_downsample_rate: 32
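Both config diffs make the same swap: `vae_type` now names the diffusers `AutoencoderDC` class and `vae_pretrained` points at the diffusers-format checkpoint. A minimal sketch of what those two keys resolve to at load time, mirroring the loader change in `diffusion/model/builder.py` below:

```python
# Sketch only: load the VAE the way the new config keys describe it.
import torch
from diffusers import AutoencoderDC

vae_type = "AutoencoderDC"  # vae_type from the yaml above
vae_pretrained = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"  # vae_pretrained

vae = AutoencoderDC.from_pretrained(vae_pretrained, torch_dtype=torch.bfloat16)
vae = vae.to("cuda").eval()
print(vae.config.scaling_factor)  # expected to match the config's scale_factor (0.41407)
```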
34 changes: 32 additions & 2 deletions diffusion/model/builder.py
@@ -15,6 +15,7 @@
# SPDX-License-Identifier: Apache-2.0

import torch
from diffusers import AutoencoderDC
from diffusers.models import AutoencoderKL
from mmcv import Registry
from termcolor import colored
@@ -87,6 +88,10 @@ def get_vae(name, model_path, device="cuda"):
print(colored(f"[DC-AE] Loading model from {model_path}", attrs=["bold"]))
dc_ae = DCAE_HF.from_pretrained(model_path).to(device).eval()
return dc_ae
elif "AutoencoderDC" in name:
print(colored(f"[AutoencoderDC] Loading model from {model_path}", attrs=["bold"]))
dc_ae = AutoencoderDC.from_pretrained(model_path).to(device).eval()
return dc_ae
else:
print("error load vae")
exit()
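With the new branch in place, the `vae_type` string from the config selects the loader. A hypothetical call, shown only to illustrate the dispatch:

```python
# "AutoencoderDC" now resolves to the diffusers implementation instead of DCAE_HF.
vae = get_vae("AutoencoderDC", "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", device="cuda")
print(type(vae).__name__)  # AutoencoderDC
```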
@@ -102,8 +107,14 @@ def vae_encode(name, vae, images, sample_posterior, device):
z = (z - vae.config.shift_factor) * vae.config.scaling_factor
elif "dc-ae" in name:
ae = vae
scaling_factor = ae.cfg.scaling_factor if ae.cfg.scaling_factor else 0.41407
z = ae.encode(images.to(device))
z = z * scaling_factor
elif "AutoencoderDC" in name:
ae = vae
scaling_factor = ae.config.scaling_factor if ae.config.scaling_factor else 0.41407
z = ae.encode(images.to(device))
z = z * ae.cfg.scaling_factor
z = z * scaling_factor
else:
print("error load vae")
exit()
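The encode path now derives `scaling_factor` from the model config and falls back to 0.41407 (Sana's default) when the config leaves it unset. A standalone sketch of the new `AutoencoderDC` branch; the function name and the `return_dict=False` unpacking are assumptions made here for a self-contained example, mirroring how the decode path below calls the diffusers API:

```python
import torch
from diffusers import AutoencoderDC

def encode_images(vae: AutoencoderDC, images: torch.Tensor, device: str = "cuda") -> torch.Tensor:
    # Fall back to Sana's default scale factor if the checkpoint config omits it.
    scaling_factor = vae.config.scaling_factor if vae.config.scaling_factor else 0.41407
    latent = vae.encode(images.to(device), return_dict=False)[0]
    return latent * scaling_factor
```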
@@ -116,7 +127,26 @@ def vae_decode(name, vae, latent):
samples = vae.decode(latent).sample
elif "dc-ae" in name:
ae = vae
samples = ae.decode(latent.detach() / ae.cfg.scaling_factor)
vae_scale_factor = (
2 ** (len(ae.config.encoder_block_out_channels) - 1)
if hasattr(ae, "config") and ae.config is not None
else 32
)
scaling_factor = ae.cfg.scaling_factor if ae.cfg.scaling_factor else 0.41407
if latent.shape[-1] * vae_scale_factor > 4000 or latent.shape[-2] * vae_scale_factor > 4000:
from patch_conv import convert_model

ae = convert_model(ae, splits=4)
samples = ae.decode(latent.detach() / scaling_factor)
elif "AutoencoderDC" in name:
ae = vae
scaling_factor = ae.config.scaling_factor if ae.config.scaling_factor else 0.41407
try:
samples = ae.decode(latent / scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
ae.enable_tiling(tile_sample_min_height=1024, tile_sample_min_width=1024)
samples = ae.decode(latent / scaling_factor, return_dict=False)[0]
else:
print("error load vae")
exit()
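The decode path keeps the fast un-tiled decode as the default and only switches the VAE into tiled mode after a CUDA OOM, so smaller resolutions pay no overhead. A standalone sketch of that retry pattern; the function name and the `empty_cache()` call are additions here for illustration:

```python
import torch
from diffusers import AutoencoderDC

def decode_latent(vae: AutoencoderDC, latent: torch.Tensor) -> torch.Tensor:
    scaling_factor = vae.config.scaling_factor if vae.config.scaling_factor else 0.41407
    latent = latent / scaling_factor
    try:
        return vae.decode(latent, return_dict=False)[0]
    except torch.cuda.OutOfMemoryError:
        print("Warning: ran out of memory during regular VAE decoding, retrying with tiled VAE decoding.")
        torch.cuda.empty_cache()  # release the partial allocations from the failed decode
        vae.enable_tiling(tile_sample_min_height=1024, tile_sample_min_width=1024)
        return vae.decode(latent, return_dict=False)[0]
```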