
Broken video output with Wan 2.1 I2V pipeline + quantized transformer #11006

Open
rolux opened this issue Mar 7, 2025 · 6 comments
Labels
bug Something isn't working

Comments

@rolux

rolux commented Mar 7, 2025

Describe the bug

Since there is no proper documentation yet, I'm not sure whether there is some difference from other video pipelines that I'm unaware of, but with the code below, the video results are reproducibly broken.

There is a warning:
Expected types for image_encoder: (<class 'transformers.models.clip.modeling_clip.CLIPVisionModel'>,), got <class 'transformers.models.clip.modeling_clip.CLIPVisionModelWithProjection'>.
which I assume I'm expected to ignore.
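
(For reference, one way to avoid that warning might be to load the image encoder explicitly with the class the pipeline expects and pass it in. This is an untested sketch that reuses model_id and transformer from the reproduction below, and I don't know whether it affects the output:)

# Untested sketch: load the image encoder as CLIPVisionModel (the type the
# pipeline says it expects) instead of letting the pipeline pick up
# CLIPVisionModelWithProjection from the repo config.
from transformers import CLIPVisionModel

image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    image_encoder=image_encoder
)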

Init image:

Image

Result:

test.mp4

Result with different seed:

423258632.0.mp4

Result with different prompt:

423258632.0.mp4

Reproduction

# Tested on Google Colab with an A100 (40GB).
# Uses ~21 GB VRAM, takes ~150 sec per step, ~75 min in total.


!pip install git+https://github.com/huggingface/diffusers.git
!pip install -U bitsandbytes
!pip install ftfy


import os
import torch
from diffusers import (
    BitsAndBytesConfig,
    WanImageToVideoPipeline,
    WanTransformer3DModel
)
from diffusers.utils import export_to_video
from PIL import Image


model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
transformer = WanTransformer3DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config
)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    transformer=transformer
)
pipe.enable_model_cpu_offload()


def render(
    filename,
    image,
    prompt,
    seed=0,
    width=832,
    height=480,
    num_frames=81,
    num_inference_steps=30,
    guidance_scale=5.0,
    fps=16
):
    video = pipe(
        image=image,
        prompt=prompt,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        width=width,
        height=height,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale
    ).frames[0]
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    export_to_video(video, filename, fps=fps)


render(
    filename="/content/test.mp4",
    image=Image.open("/content/test.png"),
    prompt="a woman in a yellow coat is dancing in the desert",
    seed=42
)

Logs

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-6.1.85+-x86_64-with-glibc2.35
  • Running on Google Colab?: Yes
  • Python version: 3.11.11
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.10.4 (gpu)
  • Jax version: 0.4.33
  • JaxLib version: 0.4.33
  • Huggingface_hub version: 0.28.1
  • Transformers version: 4.48.3
  • Accelerate version: 1.3.0
  • PEFT version: 0.14.0
  • Bitsandbytes version: 0.45.3
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA A100-SXM4-40GB, 40960 MiB

Who can help?

No response

@rolux rolux added the bug Something isn't working label Mar 7, 2025
@rolux
Author

rolux commented Mar 7, 2025

Replaced

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

with

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

This uses ~34 GB of VRAM, takes ~84 sec per step (~42 min in total), but the output quality does not seem to improve.

8-bit.mp4

@a-r-r-o-w
Member

Thanks for the detailed post and reproducible code, @rolux. Quantization doesn't always work well on video models, so that could very well be the cause. I'll run the same code with torch.bfloat16 first to verify that it's not a problem with our implementation; if not, the quantization is probably what's degrading the quality.
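
Roughly, the comparison I have in mind is just your reproduction with the quantized load swapped for a plain bf16 load (a sketch, not run yet):

# Sketch: same reproduction, but load the transformer in bf16 without a
# BitsAndBytesConfig, to rule out the quantization itself.
transformer = WanTransformer3DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16
)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()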

@a-r-r-o-w
Member

Btw, if you have a decent amount of RAM and want to run the model with low VRAM and without quantization, I would recommend trying this: https://huggingface.co/docs/diffusers/main/en/optimization/memory#group-offloading. Combining it with a few other memory optimizations can get you down to 7-10 GB of VRAM; see the sketch below.
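
Something along these lines, as a sketch based on that docs page (group offloading is only on main, so the exact argument names may change):

# Sketch of group offloading from the memory docs linked above (diffusers main).
# Use this instead of enable_model_cpu_offload().
import torch
from diffusers.hooks import apply_group_offloading

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

# Offload the big transformer at leaf level, using CUDA streams to overlap
# weight transfers with compute; offload the text encoder block by block.
pipe.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True
)
apply_group_offloading(
    pipe.text_encoder,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=4
)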

@rolux
Author

rolux commented Mar 7, 2025

@a-r-r-o-w: Thanks – yes, I had seen this earlier today.

Added a comment in #10999. TL;DR: it didn't work for me, as of now.

@tin2tin

tin2tin commented Mar 23, 2025

@rolux Did you somehow figure out how to run Wan 2.1 I2V locally with Diffusers?

@bghira
Contributor

bghira commented Mar 23, 2025

Using fp32 also doesn't help with the issue, unfortunately.
