
Add EasyAnimateV5.1 text-to-video, image-to-video, control-to-video generation model #10626

Merged: 34 commits merged into huggingface:main on Mar 3, 2025

Conversation

@bubbliiiing (Contributor) commented on Jan 22, 2025:

What does this PR do?

This PR adds EasyAnimateV5.1 as a diffusers-supported inference model, including three complete pipelines (text-to-video, image-to-video, and control-to-video) and the corresponding modules.
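For reference, a minimal text-to-video usage sketch along the lines of the examples later in this thread (the checkpoint id and call signature are taken from those examples; the final merged API may differ slightly):

import torch
from diffusers import EasyAnimatePipeline
from diffusers.utils import export_to_video

# Load the full-precision pipeline; see the quantized variants further down this thread
# for lower-VRAM setups.
pipe = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

video = pipe(
    prompt="A cat walks on the grass, realistic style.",
    negative_prompt="bad detailed",
    num_frames=81,
    num_inference_steps=40,
    width=512,
    height=320,
).frames[0]
export_to_video(video, "output.mp4", fps=8)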

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@a-r-r-o-w (Member) left a review:

Thank you for the PR @bubbliiiing! This is in great shape and already mostly in the implementation style used in diffusers 🤗

I've left some comments from a quick look through the PR. Happy to help make any of the required changes to bring the PR to completion.

encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
attn2: Attention = None,
@a-r-r-o-w (Member):

This seems similar to Flux/SD3/HunyuanVideo's Joint-attention processors that concatenate the visual and text tokens. Let's do it the same way as done here:
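The pattern being referred to concatenates text and visual tokens before a single attention call; a minimal illustrative sketch (not the actual diffusers processor; the add_q_proj/add_k_proj/add_v_proj names follow the convention discussed below, and multi-head reshaping is omitted):

import torch
import torch.nn as nn
import torch.nn.functional as F

class JointAttentionSketch(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        # Regular projections for the visual tokens.
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        # Added projections for the text tokens.
        self.add_q_proj = nn.Linear(dim, dim)
        self.add_k_proj = nn.Linear(dim, dim)
        self.add_v_proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
        # Project each token set, then concatenate along the sequence dimension so that
        # attention runs jointly over text and visual tokens.
        query = torch.cat([self.add_q_proj(encoder_hidden_states), self.to_q(hidden_states)], dim=1)
        key = torch.cat([self.add_k_proj(encoder_hidden_states), self.to_k(hidden_states)], dim=1)
        value = torch.cat([self.add_v_proj(encoder_hidden_states), self.to_v(hidden_states)], dim=1)
        out = F.scaled_dot_product_attention(query, key, value)
        # Split the joint output back into (visual, text) parts.
        text_len = encoder_hidden_states.shape[1]
        return out[:, text_len:], out[:, :text_len]

attn = JointAttentionSketch(dim=64)
vis_out, txt_out = attn(torch.randn(1, 128, 64), torch.randn(1, 16, 64))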

@bubbliiiing (Contributor, Author):

Does this mean using add_q, add_k, add_v instead of using attn2?
[screenshot]

@a-r-r-o-w (Member):

There are two things we could do here:

  • Either convert the state dict of the original-format models (that you currently have on the HuggingFace Hub) and update them to the diffusers format (which would map attn2.to_q -> add_q_proj, attn2.to_k -> add_k_proj, attn2.to_v -> add_v_proj; see the sketch below), or
  • Create a custom attention class, similar to Attention and MochiAttention, in which you are free to use layer naming of your choice (basically keeping the same to_q, to_k and to_v).

The first approach is more closely aligned with diffusers code style but would require you to update multiple checkpoints. However, we are transitioning to a single-file modeling format, so if you choose to go with the second approach for convenience, that works for us as well. Essentially, irrespective of the design you choose, we need to make sure that:

  • When the forward of a layer is called, it only takes tensors as input and produces tensors as output.
  • Taking intermediate layers as input to forward, or making out-of-order calls to other layers, is not supported by the design of several current and upcoming features.

cc @DN6 here in case you have thoughts about the single file format and model-specific Attention classes
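To illustrate the first option, the state-dict conversion could look roughly like this (the suffix mapping comes from the comment above; the target module name attn1 is a hypothetical placeholder, and a real conversion script would enumerate every affected key):

from typing import Dict
import torch

# attn2.* keys in the original checkpoints map onto add_*_proj layers in diffusers naming.
SUFFIX_MAP = {
    "attn2.to_q": "attn1.add_q_proj",
    "attn2.to_k": "attn1.add_k_proj",
    "attn2.to_v": "attn1.add_v_proj",
}

def convert_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    converted = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in SUFFIX_MAP.items():
            new_key = new_key.replace(old, new)
        converted[new_key] = value
    return converted

example = {"transformer_blocks.0.attn2.to_q.weight": torch.zeros(8, 8)}
print(list(convert_state_dict(example)))  # ['transformer_blocks.0.attn1.add_q_proj.weight']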

@bubbliiiing (Contributor, Author):

I moved attn2 to the processor's init; does this meet the requirements?

for name, module in self.named_children():
    _set_3dgroupnorm_for_submodule(name, module)

def single_forward(self, x: torch.Tensor) -> torch.Tensor:
@a-r-r-o-w (Member):

This is very different from diffusers-style implementation of encoder/decoder. Could we follow the style as done in:

@bubbliiiing (Contributor, Author):

Sorry, does this mean that I cannot use functions like set_padding_one_frame?

@bubbliiiing (Contributor, Author):

Do I need to use this conv_cache in autoencoder_kl_mochi?

@a-r-r-o-w (Member):

In the model implementations, we usually try to keep only (at least in the latest model integrations):

  • Submodel initializations
  • Forward method

So, unless a helper function like set_padding_one_frame is used in multiple locations, I would suggest substituting its code directly into the forward implementation. If a helper function is required, let's make it private by prefixing its name with an underscore.

The conv_cache saves a few computations in the VAE encode/decode process that come from the repeated frames used as padding. As such, it is not required if it is not needed for framewise encoding and decoding.
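A minimal illustration of both conventions (an underscore-prefixed private helper plus an optional conv_cache threaded through forward), using a toy causal Conv3d block rather than the actual EasyAnimate VAE:

from typing import Optional, Tuple
import torch
import torch.nn as nn

class CausalConvBlock(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        # No temporal padding inside the conv; it is supplied explicitly in forward().
        self.conv = nn.Conv3d(channels, channels, kernel_size=3, padding=(0, 1, 1))

    def _temporal_pad(self, x: torch.Tensor, conv_cache: Optional[torch.Tensor]) -> torch.Tensor:
        # Pad the time axis with the cached tail of the previous chunk when available,
        # otherwise fall back to repeating the first frame (the repeated padding frames).
        if conv_cache is None:
            conv_cache = x[:, :, :1].repeat(1, 1, 2, 1, 1)
        return torch.cat([conv_cache, x], dim=2)

    def forward(
        self, x: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        padded = self._temporal_pad(x, conv_cache)
        # Return the last two input frames as the cache for the next chunk, so framewise
        # decoding does not need to recompute the padding.
        return self.conv(padded), padded[:, :, -2:]

block = CausalConvBlock(channels=4)
out_a, cache = block(torch.randn(1, 4, 5, 16, 16))
out_b, _ = block(torch.randn(1, 4, 5, 16, 16), conv_cache=cache)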

@yiyixuxu added the roadmap (Add to current release roadmap) label on Jan 22, 2025.
@bubbliiiing (Contributor, Author):

Sorry for not standardizing some parts; I will make the necessary modifications. Also, should I add test files under tests/pipelines and documentation under docs/source/en/api/pipelines?

@a-r-r-o-w (Member):

Yes, we will need tests for all three pipelines, model tests in tests/models, and the documentation pages. Of course, I will try to help you with everything :)

Also, congratulations on the release! I tried out the original repository example and the model is very good! 🎉

@a-r-r-o-w (Member):

Thank you so much for addressing the reviews @bubbliiiing! The PR is almost ready to merge IMO. There are just a few more small changes to make to align to our model implementation design. Is it okay if I quickly push some last changes to this branch directly?

@bubbliiiing (Contributor, Author):

> Thank you so much for addressing the reviews @bubbliiiing! The PR is almost ready to merge IMO. There are just a few more small changes to make to align to our model implementation design. Is it okay if I quickly push some last changes to this branch directly?

Of course, thank you very much for your help.

@nitinmukesh:

For some reason this isn't working anymore, it was working earlier. :(

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, EasyAnimateTransformer3DModel, EasyAnimatePipeline
from diffusers.utils import export_to_video
from transformers import AutoModel
dtype = torch.bfloat16
quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=dtype)
text_encoder_4bit = AutoModel.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh",
    subfolder="text_encoder",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
transformer_4bit = EasyAnimateTransformer3DModel.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=dtype,
)
pipeline = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh",
    text_encoder=text_encoder_4bit,
    transformer=transformer_4bit,
    torch_dtype=dtype,
)
pipeline.enable_model_cpu_offload()
pipeline.vae.enable_slicing()
pipeline.vae.enable_tiling()
prompt = "A cat walks on the grass, realistic style."
negative_prompt = "bad detailed"
video = pipeline(
    prompt=prompt, 
    negative_prompt=negative_prompt, 
    num_frames=81, 
    num_inference_steps=40,
    width=512,
    height=320
).frames[0]
export_to_video(video, "cat2.mp4", fps=8)

It just crashes

(venv) C:\aiOWN\lumina2>python EasyAnimate5.1.py
`low_cpu_mem_usage` was None, now default to True since model is quantized.
Downloading shards: 100%|███████████████████████████████████████████████████| 5/5 [00:00<?, ?it/s]
`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46
Loading checkpoint shards: 100%|████████████████████████████████████| 5/5 [00:19<00:00,  3.90s/it]
The config attributes {'add_ref_latent_in_control_model': True, 'clip_channels': None, 'enable_clip_in_inpaint': False, 'ref_channels': None, 'swa_layers': None} were passed to EasyAnimateTransformer3DModel, but are not expected and will be ignored. Please verify your config.json configuration file.

@bubbliiiing (Contributor, Author):

> For some reason this isn't working anymore, it was working earlier. :( [quoted code and log identical to the comment above]

Thank you for your feedback, I'll give it a try.

@bubbliiiing (Contributor, Author):

> For some reason this isn't working anymore, it was working earlier. :( [quoted code and log identical to the comment above]

It seems to be working fine in my environment.
[screenshot]

@nitinmukesh:

@bubbliiiing Sorry for the trouble.
It's working now after reinstalling the PR.

@nitinmukesh:

20250212_154949_370380_easyanimate51_bnb.mp4

A little white rabbit with glasses was sitting on a chair in a cafe reading a newspaper. There was a cup of hot coffee on the table.

Nice output. Please add support for pipeline.enable_sequential_cpu_offload().

@a-r-r-o-w (Member):

@nitinmukesh enable_sequential_cpu_offload() should work with any diffusers model out of the box. Are you facing any particular error? Apologies for iterating slowly on this PR, as there are many integrations happening in parallel, but I will try to wrap it up soon and fix any sequential offloading problem you might be facing.

@bubbliiiing (Contributor, Author):

> enable_sequential_cpu_offload

It seems that Qwen2VL does not support enable_sequential_cpu_offload and raises the following error:
[screenshot of error]

@nitinmukesh:

This is exactly the error I'm getting.
Currently, with int4 and enable_model_cpu_offload() it works fine on 10 GB of VRAM.
With sequential offloading it should work with as little as 6 GB of VRAM. However, if it is not supported for Qwen2VL, we can simply skip adding enable_sequential_cpu_offload().

@a-r-r-o-w (Member):

cc @SunMarc here for sequential cpu offload related issues in Qwen2VL

@nitinmukesh commented on Feb 12, 2025:

> @nitinmukesh enable_sequential_cpu_offload() should work with any diffusers model out of the box. Are you facing any particular error? Apologies for iterating slowly on this PR, as there are many integrations happening in parallel, but I will try to wrap it up soon and fix any sequential offloading problem you might be facing.

No need to apologize, you guys are doing great. I'm not in a hurry (maybe just excited to see all the development happening); I'm just trying out all the features being added and posting my feedback. I want to see this library grow, as it's very easy to use.

@nitinmukesh:

@a-r-r-o-w

Here is the code to reproduce the error (pipeline.enable_model_cpu_offload() works fine):

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, EasyAnimateTransformer3DModel, EasyAnimatePipeline
from diffusers.utils import export_to_video
from transformers import AutoModel
dtype = torch.bfloat16
quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=dtype)
text_encoder_4bit = AutoModel.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh",
    subfolder="text_encoder",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
transformer_4bit = EasyAnimateTransformer3DModel.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=dtype,
)
pipeline = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh",
    text_encoder=text_encoder_4bit,
    transformer=transformer_4bit,
    torch_dtype=dtype,
)
pipeline.enable_sequential_cpu_offload()
pipeline.vae.enable_slicing()
pipeline.vae.enable_tiling()
prompt = "A cat walks on the grass, realistic style."
negative_prompt = "bad detailed"
video = pipeline(
    prompt=prompt, 
    negative_prompt=negative_prompt, 
    num_frames=81, 
    num_inference_steps=40,
    width=512,
    height=320
).frames[0]
export_to_video(video, "cat2.mp4", fps=8)

and log

Log 1: #10626 (comment)

Log 2:

(venv) C:\aiOWN\diffuser_webui>python easyanimate_bnb.py
`low_cpu_mem_usage` was None, now default to True since model is quantized.
Downloading shards: 100%|███████████████████████████████████████████████████| 5/5 [00:00<?, ?it/s]
`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46
Loading checkpoint shards: 100%|████████████████████████████████████| 5/5 [00:34<00:00,  6.83s/it]
The config attributes {'add_ref_latent_in_control_model': True, 'clip_channels': None, 'enable_clip_in_inpaint': False, 'ref_channels': None, 'swa_layers': None} were passed to EasyAnimateTransformer3DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Expected types for text_encoder: ['Qwen2VLForConditionalGeneration', 'BertModel'], got Qwen2VLModel.
Loading pipeline components...:   0%|                                       | 0/5 [00:00<?, ?it/s]The config attributes {'force_upcast': True, 'mid_block_use_attention': False, 'sample_size': 256, 'slice_mag_vae': False, 'slice_compression_vae': False, 'cache_compression_vae': False, 'cache_mag_vae': True, 'use_tiling': False, 'norm_type': None, 'use_tiling_encoder': False, 'use_tiling_decoder': False, 'mid_block_attention_type': 'spatial'} were passed to AutoencoderKLMagvit, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading pipeline components...: 100%|███████████████████████████████| 5/5 [00:01<00:00,  4.86it/s]
Traceback (most recent call last):
  File "C:\aiOWN\diffuser_webui\easyanimate_bnb.py", line 30, in <module>
    video = pipeline(
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\diffusers\pipelines\easyanimate\pipeline_easyanimate.py", line 803, in __call__
    ) = self.encode_prompt(
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\diffusers\pipelines\easyanimate\pipeline_easyanimate.py", line 383, in encode_prompt
    prompt_embeds = text_encoder(
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\accelerate\hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\transformers\models\qwen2_vl\modeling_qwen2_vl.py", line 1082, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\accelerate\hooks.py", line 171, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\accelerate\hooks.py", line 370, in pre_forward
    return send_to_device(args, self.execution_device), send_to_device(
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\accelerate\utils\operations.py", line 174, in send_to_device
    return honor_type(
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\accelerate\utils\operations.py", line 81, in honor_type
    return type(obj)(generator)
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\accelerate\utils\operations.py", line 175, in <genexpr>
    tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
  File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\accelerate\utils\operations.py", line 155, in send_to_device
    return tensor.to(device, non_blocking=non_blocking)
NotImplementedError: Cannot copy out of meta tensor; no data!

@a-r-r-o-w (Member):

@bubbliiiing I've verified that the -diffusers suffixed checkpoints work with the T2V pipeline. I've refactored the VAE to mostly match the diffusers-expected format. Only the pipelines remain now, specifically the following changes:

  • Removing the einops dependency (see the sketch at the end of this comment)
  • Creating a specific RoPE layer, similar to how we do it in HunyuanVideo and other recent model integrations
  • Removing calls to Image.open; users will always have to pass a PIL.Image.Image, np.ndarray or torch.Tensor themselves
  • Some other minimal cleanup

I can take these up when I find some time this week, but if you or someone from your team has the time to look into this, it would be very helpful too!

output.mp4
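On the einops removal, a typical rearrange can usually be replaced with native tensor ops; a small sketch of the kind of substitution involved (the exact patterns used in the pipelines may differ):

import torch
from einops import rearrange  # imported here only to verify equivalence

x = torch.randn(2, 16, 8, 32, 32)  # (batch, channels, frames, height, width)

# einops version
y_einops = rearrange(x, "b c f h w -> (b f) c h w")

# dependency-free equivalent with native torch ops
b, c, f, h, w = x.shape
y_torch = x.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)

assert torch.equal(y_einops, y_torch)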

@bubbliiiing (Contributor, Author):

Sure, I'll do my best to make the necessary fixes. If there are any parts I haven't covered, please let me know and I'll be happy to help with the modifications. Thank you!

@bubbliiiing (Contributor, Author):

I have fixed some issues. Please take another look when you have time.

@a-r-r-o-w (Member) left a review:

Thanks @bubbliiiing, the latest changes are great! @yiyixuxu Could you give this a look?

@a-r-r-o-w requested a review from yiyixuxu on February 27, 2025 at 01:20.
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu (Collaborator) left a review:

awesome! thank you

@yiyixuxu (Collaborator):

We still have some tests failing, though.

@a-r-r-o-w merged commit 5e3b7d2 into huggingface:main on Mar 3, 2025 (23 of 27 checks passed).
Labels: close-to-merge, roadmap (Add to current release roadmap)
7 participants