Attention Dispatcher #11368
Conversation
Supported providers: flash, flash_varlen, flex, native, sage, sage_varlen, xformers
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
… flux attention processors
Interesting PR! I only left some higher-level comments. My major comment is around having an attention config class instead of environment vars. Or would that be too much for this PR?
For the attention config class (if decided to proceed that route), I was thinking of the following APIs:
attn_config = AttentionConfig(
    attn_implementation="...",
    enable_gqa=...
)
model.set_attn_config(attn_config)
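A minimal sketch of what such a config class could look like (hypothetical: `AttentionConfig` and `set_attn_config` are the names proposed in this comment, not existing diffusers APIs):

```python
# Hypothetical sketch of the proposed AttentionConfig API; names and fields
# are assumptions taken from the comment above, not part of the PR.
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttentionConfig:
    attn_implementation: str = "native"  # e.g. "flash", "sage_varlen", "xformers"
    enable_gqa: bool = False
    scale: Optional[float] = None


class AttentionConfigMixin:
    _attn_config: AttentionConfig = AttentionConfig()

    def set_attn_config(self, config: AttentionConfig) -> None:
        # Store the config; attention processors would read it at call time.
        self._attn_config = config
```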
class BlockMask:
    def __init__(self, *args, **kwargs):
        raise OptionalDependencyNotAvailable(
            "The `torch` library version is too old. Please update it to at least 2.5.0."
We could further clarify that "To use BlockMask, you need an updated torch installation."
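As a sketch, the fallback stub with a message along those lines (exact wording is up to the authors; `OptionalDependencyNotAvailable` is the exception already used in the hunk above):

```python
# Sketch: fallback stub with a clarified error message, reusing the
# OptionalDependencyNotAvailable exception from the diff above.
class BlockMask:
    def __init__(self, *args, **kwargs):
        raise OptionalDependencyNotAvailable(
            "To use `BlockMask`, you need `torch>=2.5.0`. Please update your `torch` installation."
        )
```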
DIFFUSERS_ATTN_PROVIDER = os.getenv("DIFFUSERS_ATTN_PROVIDER", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
Would it instead make sense to have them parsed through some kind of AttentionConfig class?
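For illustration, env-var parsing routed through a config class could look roughly like this (hypothetical sketch; `AttentionConfig.from_env` is an assumed name, not part of the PR):

```python
# Hypothetical: parse the environment variables through a config object
# instead of module-level constants. Names are assumptions for illustration.
import os
from dataclasses import dataclass

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}


@dataclass
class AttentionConfig:
    provider: str = "native"
    run_checks: bool = False

    @classmethod
    def from_env(cls) -> "AttentionConfig":
        return cls(
            provider=os.getenv("DIFFUSERS_ATTN_PROVIDER", "native"),
            run_checks=os.getenv("DIFFUSERS_ATTN_CHECKS", "0").upper() in ENV_VARS_TRUE_VALUES,
        )
```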
def get_active_provider(cls):
    return cls._active_provider, cls._providers[cls._active_provider]
Should it only return cls._active_provider?
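For illustration, the two options side by side in a minimal registry sketch (class and attribute names here are made up, not the PR's actual registry):

```python
# Hypothetical registry sketch contrasting the current tuple return with the
# name-only alternative raised in the review. Names are illustrative only.
class ProviderRegistry:
    _providers = {"native": object()}  # name -> registered implementation
    _active_provider = "native"

    @classmethod
    def get_active_provider(cls):
        # Current form in the diff: return both the name and the implementation.
        return cls._active_provider, cls._providers[cls._active_provider]

    @classmethod
    def get_active_provider_name(cls):
        # Alternative: return only the name; callers look up the implementation
        # separately when they actually need it.
        return cls._active_provider
```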
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
Do we have many models using GQA? If not, would it be sensible to have it in the attention config class if we decide to go that route instead of ENV vars?
At least Cosmos uses GQA IIRC, and the change is with respect to that. The flag exists here to mirror torch's sdpa, so we should probably not move it elsewhere.
Makes sense. We're also logging a warning about the kwargs being discarded. So okay.
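For reference, `enable_gqa` maps directly onto torch's SDPA (torch 2.5+); a minimal example with made-up shapes, where 8 query heads share 2 key/value heads:

```python
# Minimal GQA example using torch.nn.functional.scaled_dot_product_attention.
# Requires torch >= 2.5 for the enable_gqa flag; shapes are illustrative only.
import torch
import torch.nn.functional as F

query = torch.randn(1, 8, 128, 64)  # (batch, num_query_heads, seq_len, head_dim)
key = torch.randn(1, 2, 128, 64)    # (batch, num_kv_heads, seq_len, head_dim)
value = torch.randn(1, 2, 128, 64)

out = F.scaled_dot_product_attention(
    query, key, value, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=True
)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```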
The environment vars were initially only for my quick testing from the CLI, instead of changing the code every time; we can get rid of them completely. The intended API in my mind, and what currently exists in the PR, uses context managers:

from diffusers import attention_provider

with attention_provider("sage_varlen"):
    model(...)

Can change once we finalize something.
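A rough sketch of how such a context manager can be implemented (hypothetical; only the `attention_provider("...")` usage above comes from the PR, the state handling here is made up):

```python
# Hypothetical sketch of an attention-provider context manager. Only the usage
# pattern is taken from the comment above; the state handling is illustrative.
from contextlib import contextmanager

_ACTIVE_PROVIDER = "native"


@contextmanager
def attention_provider(name: str):
    global _ACTIVE_PROVIDER
    previous = _ACTIVE_PROVIDER
    _ACTIVE_PROVIDER = name  # dispatch code would read this when running attention
    try:
        yield
    finally:
        _ACTIVE_PROVIDER = previous  # restore the previous provider on exit
```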
WIP, but ready for some initial reviews. Will update the description with benchmarks and the user API for changing the attention provider on the fly soon.
Context parallelism is based on experimental PyTorch APIs, so it will be added in a separate PR soon.
Benchmarks
A100
Flux
These numbers don't really reflect the true speedup from Sage attention, because Flux is bounded by time spent in feed-forwards rather than attention. Also, Sage Attention recommends CUDA 12.4 or above, and our DGX has 12.2.
batch_size=1
batch_size=4
4090
TODO
cc @DN6 @sayakpaul @yiyixuxu