Broken video output with Wan 2.1 I2V pipeline + quantized transformer #11006

Open
@rolux

Describe the bug

Since there is no proper documentation yet, I may be missing a difference from other video pipelines, but with the code below the resulting videos are reproducibly broken.

There is a warning:
Expected types for image_encoder: (<class 'transformers.models.clip.modeling_clip.CLIPVisionModel'>,), got <class 'transformers.models.clip.modeling_clip.CLIPVisionModelWithProjection'>.
which I assume I'm expected to ignore.
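
For reference, the warning can probably be avoided by loading the image encoder as the class it names under "Expected types" and passing it to the pipeline explicitly. The following is only a sketch under that assumption: the helper name `load_pipeline_with_expected_encoder` is hypothetical, and whether the encoder class has anything to do with the broken output is not established here.

```python
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"


def load_pipeline_with_expected_encoder(model_id=MODEL_ID):
    """Hypothetical helper: build the I2V pipeline with a CLIPVisionModel encoder."""
    # Heavy imports live inside the function so that defining the helper
    # neither loads torch nor downloads any weights.
    import torch
    from diffusers import (
        BitsAndBytesConfig,
        WanImageToVideoPipeline,
        WanTransformer3DModel,
    )
    from transformers import CLIPVisionModel

    # Load the image encoder as the class listed under "Expected types"
    # in the warning, instead of letting the pipeline pick the class.
    image_encoder = CLIPVisionModel.from_pretrained(
        model_id,
        subfolder="image_encoder",
    )

    # Same 4-bit quantized transformer as in the reproduction.
    transformer = WanTransformer3DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
    )

    return WanImageToVideoPipeline.from_pretrained(
        model_id,
        image_encoder=image_encoder,
        transformer=transformer,
    )
```

Calling `pipe = load_pipeline_with_expected_encoder()` followed by `pipe.enable_model_cpu_offload()` would stand in for the pipeline setup in the reproduction; everything else stays the same.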

Init image:

Image

Result:

test.mp4

Result with different seed:

423258632.0.mp4

Result with different prompt:

423258632.0.mp4

Reproduction

# Tested on Google Colab with an A100 (40GB).
# Uses ~21 GB VRAM, takes ~150 sec per step (30 steps, so ~75 min in total).


!pip install git+https://github.com/huggingface/diffusers.git
!pip install -U bitsandbytes
!pip install ftfy


import os
import torch
from diffusers import (
    BitsAndBytesConfig,
    WanImageToVideoPipeline,
    WanTransformer3DModel
)
from diffusers.utils import export_to_video
from PIL import Image


model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
transformer = WanTransformer3DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config
)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    transformer=transformer
)
pipe.enable_model_cpu_offload()


def render(
    filename,
    image,
    prompt,
    seed=0,
    width=832,
    height=480,
    num_frames=81,
    num_inference_steps=30,
    guidance_scale=5.0,
    fps=16
):
    video = pipe(
        image=image,
        prompt=prompt,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        width=width,
        height=height,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale
    ).frames[0]
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    export_to_video(video, filename, fps=fps)


render(
    filename="/content/test.mp4",
    image=Image.open("/content/test.png"),
    prompt="a woman in a yellow coat is dancing in the desert",
    seed=42
)
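
To describe "broken" in numbers rather than only by eye, the frames could be summarized before export. This is a minimal sketch, not part of the original reproduction: `summarize_frames` is a hypothetical helper, and it assumes the frames returned by the pipeline can be converted with `np.asarray` into an array of shape (num_frames, height, width, channels).

```python
import numpy as np


def summarize_frames(frames):
    """Return simple statistics for a video shaped (num_frames, H, W, C)."""
    arr = np.asarray(frames, dtype=np.float32)

    # Count NaN/inf values, then zero them out so the remaining
    # statistics stay finite and comparable.
    finite = np.isfinite(arr)
    clean = np.where(finite, arr, 0.0)

    # Mean absolute difference between consecutive frames:
    # exactly 0 for a frozen video, large for frame-to-frame noise.
    if arr.shape[0] > 1:
        motion = float(np.abs(np.diff(clean, axis=0)).mean())
    else:
        motion = 0.0

    # Per-frame standard deviation: near 0 for flat, single-color frames.
    per_frame_std = clean.reshape(arr.shape[0], -1).std(axis=1)

    return {
        "shape": arr.shape,
        "non_finite": int((~finite).sum()),
        "min": float(clean.min()),
        "max": float(clean.max()),
        "mean_frame_std": float(per_frame_std.mean()),
        "mean_motion": motion,
    }
```

Inside `render`, a call such as `print(summarize_frames(video))` just before `export_to_video` would show whether the output contains non-finite values, flat frames, or no motion at all, which may help narrow down where the pipeline goes wrong.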

Logs

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-6.1.85+-x86_64-with-glibc2.35
  • Running on Google Colab?: Yes
  • Python version: 3.11.11
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.10.4 (gpu)
  • Jax version: 0.4.33
  • JaxLib version: 0.4.33
  • Huggingface_hub version: 0.28.1
  • Transformers version: 4.48.3
  • Accelerate version: 1.3.0
  • PEFT version: 0.14.0
  • Bitsandbytes version: 0.45.3
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA A100-SXM4-40GB, 40960 MiB

Who can help?

No response

Labels: bug, stale
