
Attention Dispatcher #11368


Draft: wants to merge 6 commits into main

Conversation

@a-r-r-o-w (Member) commented on Apr 19, 2025

WIP, but ready for some initial reviews. I'll update the description soon with benchmarks and the user API for changing the attention provider on the fly.

# test.py
import torch
from diffusers import Lumina2Pipeline

pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says 'Hello, World!' in a colorful park with flowers and trees"
image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
image.save("output.png")
# fails because flex attention requires head dim to be a power of 2
DIFFUSERS_ATTN_PROVIDER="flex" CUDA_VISIBLE_DEVICES=3 python3 test.py
# dispatches to cudnn internally in pytorch, so it's the same as using "_native_cudnn" (see below)
DIFFUSERS_ATTN_PROVIDER="native" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="flash_varlen" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="sage_varlen" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="_native_cudnn" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="_native_efficient" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="xformers" CUDA_VISIBLE_DEVICES=3 python3 test.py
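The provider can also be selected from Python instead of the shell, as long as the environment variable is set before diffusers is imported (it is read via os.getenv at import time, as the snippet quoted in the review below suggests). A minimal sketch:

# select the attention provider programmatically; must run before importing diffusers
import os
os.environ["DIFFUSERS_ATTN_PROVIDER"] = "sage_varlen"

import torch
from diffusers import Lumina2Pipeline

pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
pipe.to("cuda")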

Context parallelism is based on experimental PyTorch APIs, so it will be added in a separate PR soon.

Benchmarks

A100

Flux

These numbers don't really reflect the true speedup from Sage Attention, because Flux is bounded by the time spent in its feed-forward layers rather than in attention. Also, Sage Attention recommends CUDA 12.4 or above; our DGX has CUDA 12.2.
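A minimal sketch of how per-provider wall-clock time and peak memory like the numbers below could be collected (this is not the actual benchmark script; it assumes the attention_provider context manager described later in this thread):

# rough timing/peak-memory harness per attention provider (illustrative only)
import torch
from diffusers import FluxPipeline, attention_provider  # attention_provider is added by this PR

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
prompt = "A cat holding a sign that says 'Hello, World!'"

for provider in ["flash", "flash_varlen", "native", "_native_cudnn", "sage", "xformers"]:
    with attention_provider(provider):
        pipe(prompt, generator=torch.Generator().manual_seed(42))  # warmup run
        torch.cuda.reset_peak_memory_stats()
        start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        start.record()
        pipe(prompt, generator=torch.Generator().manual_seed(42))
        end.record()
        torch.cuda.synchronize()
        print(provider, start.elapsed_time(end) / 1000, torch.cuda.max_memory_allocated() / 1024**3)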

batch_size=1

| model_id | attn_provider | time (s) | model_memory (GB) | inference_memory (GB) |
| --- | --- | --- | --- | --- |
| flux | flash | 16.893 | 22.387 | 22.994 |
| flux | flash_varlen | 17.551 | 22.387 | 22.994 |
| flux | flex | 49.968 | 22.387 | 25.16 |
| flux | native | 17.008 | 22.387 | 22.994 |
| flux | _native_cudnn | 17.065 | 22.387 | 22.994 |
| flux | _native_efficient | 18.279 | 22.387 | 22.994 |
| flux | _native_flash | 16.995 | 22.387 | 22.994 |
| flux | sage | 17.25 | 22.387 | 22.941 |
| flux | sage_varlen | 18.828 | 22.387 | 22.941 |
| flux | xformers | 16.993 | 22.387 | 22.994 |

batch_size=4

| model_id | attn_provider | time (s) | model_memory (GB) | inference_memory (GB) |
| --- | --- | --- | --- | --- |
| flux | flash | 63.625 | 22.387 | 24.863 |
| flux | flash_varlen | 65.1 | 22.387 | 24.863 |
| flux | flex | 187.939 | 22.387 | 33.529 |
| flux | native | 64.342 | 22.387 | 24.863 |
| flux | _native_cudnn | 64.769 | 22.387 | 24.863 |
| flux | _native_efficient | 69.545 | 22.387 | 24.863 |
| flux | _native_flash | 64.443 | 22.387 | 24.863 |
| flux | sage | 65.258 | 22.387 | 24.863 |
| flux | sage_varlen | 67.588 | 22.387 | 24.863 |
| flux | xformers | 64.523 | 22.387 | 24.863 |

4090

TODO

cc @DN6 @sayakpaul @yiyixuxu

Supported providers: flash, flash_varlen, flex, native, sage, sage_varlen, xformers
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member) left a comment

Interesting PR! I only left some higher-level comments. My major comment is around having an attention config class instead of environment vars. Or would that be too much for this PR?


For the attention config class (if we decide to proceed down that route), I was thinking of the following API:

attn_config = AttentionConfig(
    attn_implementation="...",
    enable_gqa=...
)
model.set_attn_config(attn_config)
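For concreteness, a rough sketch of what such a class could look like; every name and field here is hypothetical, just to anchor the discussion:

from dataclasses import dataclass

@dataclass
class AttentionConfig:
    # hypothetical fields collecting what the env vars / kwargs currently carry
    attn_implementation: str = "native"   # e.g. "flash", "sage_varlen", "xformers", ...
    enable_gqa: bool = False
    run_checks: bool = False              # would subsume DIFFUSERS_ATTN_CHECKS

class SomeModel:
    # hypothetical setter; real models would presumably get this via a mixin
    def set_attn_config(self, attn_config: AttentionConfig) -> None:
        self._attn_config = attn_config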

class BlockMask:
    def __init__(self, *args, **kwargs):
        raise OptionalDependencyNotAvailable(
            "The `torch` library version is too old. Please update it to at least 2.5.0."
        )

We could further clarify that "To use BlockMask you need an updated torch installation."

Comment on lines +44 to +45
DIFFUSERS_ATTN_PROVIDER = os.getenv("DIFFUSERS_ATTN_PROVIDER", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
@sayakpaul (Member) commented on Apr 22, 2025

Would it instead make sense to have them parsed through some kind of AttentionConfig class?

Comment on lines +153 to +154
def get_active_provider(cls):
    return cls._active_provider, cls._providers[cls._active_provider]

Should it only return cls._active_provider?
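For context, a simplified sketch of the registry pattern this method sits in (not the PR's exact code); one possible reason for returning the (name, callable) pair is that dispatch sites then need only a single lookup:

# simplified sketch of the provider registry around get_active_provider
from typing import Callable, Dict

class _AttentionProviderRegistry:
    _providers: Dict[str, Callable] = {}
    _active_provider: str = "native"

    @classmethod
    def register(cls, name: str, fn: Callable) -> None:
        cls._providers[name] = fn

    @classmethod
    def get_active_provider(cls):
        # returning both the name and the callable keeps dispatch sites to one lookup
        return cls._active_provider, cls._providers[cls._active_provider]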

dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,

Do we have many models using GQA? If not, would it be sensible to have it in the attention config class if we decide to go that route instead of ENV vars?

@a-r-r-o-w (Member, Author)

At least Cosmos uses GQA, IIRC, and the change is with respect to that. The flag exists here to mirror torch's SDPA signature, so we should probably not move it elsewhere.
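For reference, a small sketch of what the flag does in torch's SDPA (requires torch >= 2.5; shapes are illustrative):

# grouped-query attention via torch SDPA: fewer key/value heads than query heads
import torch
import torch.nn.functional as F

batch, q_heads, kv_heads, seq, head_dim = 2, 8, 2, 128, 64
query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

# with enable_gqa=True the KV heads are broadcast across query-head groups
# instead of raising a shape mismatch
out = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])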


Makes sense. We're also logging a warning about the kwargs being discarded. So okay.

@a-r-r-o-w (Member, Author)

The environment variables were initially only for my quick testing from the CLI instead of changing the code every time. We can get rid of them completely.

The intended API in my mind, and what currently exists in the PR, is a context manager:

from diffusers import attention_provider

with attention_provider("sage_varlen"):
    model(...)

Can change once we finalize something
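For the record, such a context manager can be a thin wrapper that swaps the registry's active provider and restores it on exit; a minimal sketch reusing the hypothetical registry from earlier, not the PR's actual implementation:

# minimal sketch of an attention_provider-style context manager
from contextlib import contextmanager

@contextmanager
def attention_provider(name: str):
    previous = _AttentionProviderRegistry._active_provider
    _AttentionProviderRegistry._active_provider = name
    try:
        yield
    finally:
        # restore the previous provider even if the forward pass raises
        _AttentionProviderRegistry._active_provider = previous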
