[PyTorch] Add FA4 Support #2432
base: main
Conversation
Signed-off-by: Xin Yao <[email protected]>
for more information, see https://pre-commit.ci
    if not use_flash_attn_3:
        if use_flash_attn_4:
            fa_4_optional_forward_kwargs = {
                # "window_size": window_size,
The default window_size = (-1, 0) doesn't mean "no sliding window" for FA4.
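For illustration, here is a minimal sketch of how the window-size convention could be translated before it is handed to FA4. The helper _to_fa4_window_size and the (-1, 0) → (-1, -1) mapping are assumptions based on this review comment, not the PR's actual code.

```python
# Illustrative sketch, not the PR's code: translate a TE-style window_size before
# passing it to FA4. TE uses (-1, 0) as the default with causal masks, while
# flash-attn treats (-1, -1) as "no sliding window"; passing (-1, 0) through
# unchanged would be read as a real sliding-window constraint.
def _to_fa4_window_size(window_size, attn_mask_type):
    # Hypothetical helper; the mapping below is an assumption.
    if window_size is None:
        return (-1, -1)  # no sliding window
    left, right = window_size
    if (left, right) == (-1, 0) and "causal" in attn_mask_type:
        # Causality is already conveyed via the mask type, so don't re-encode
        # it as a sliding window that FA4 would enforce literally.
        return (-1, -1)
    return left, right


# Example stand-ins for the caller's arguments:
window_size, attn_mask_type = (-1, 0), "padding_causal"
fa_4_optional_forward_kwargs = {
    "window_size": _to_fa4_window_size(window_size, attn_mask_type),  # -> (-1, -1)
}
```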
    if use_flash_attention_2 and FlashAttentionUtils.is_installed:
        logger.debug("Disabling FlashAttention 2 as it does not support MLA.")
        use_flash_attention_2 = False
    if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed:
Currently, FA4 only supports MLA in the forward pass.
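As a sketch of what such a guard could look like, using head_dim_qk != head_dim_v as the MLA signal and an assumed is_training flag (both assumptions, not the PR's actual code):

```python
import logging

# Illustrative sketch, not the PR's code: keep FA4 as an MLA candidate only when
# no backward pass is required, since FA4 currently supports MLA in forward only.
def _filter_fa4_for_mla(use_flash_attention_4, fa4_installed,
                        head_dim_qk, head_dim_v, is_training, logger):
    is_mla = head_dim_qk != head_dim_v  # assumed MLA signal
    if use_flash_attention_4 and fa4_installed and is_mla and is_training:
        logger.debug("Disabling FlashAttention 4 for MLA training (forward-only support).")
        return False
    return use_flash_attention_4


use_flash_attention_4 = _filter_fa4_for_mla(
    use_flash_attention_4=True, fa4_installed=True,
    head_dim_qk=192, head_dim_v=128, is_training=True,
    logger=logging.getLogger(__name__),
)  # -> False
```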
        use_flash_attention_2 = False
    if use_flash_attention_3:

    if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:
Move the FlashAttentionUtils.v3_is_installed check ahead so we don't need to run these checks when FA3 is not installed.
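A tiny, self-contained sketch of the ordering being asked for, with stand-in variables; the deterministic flag is a made-up placeholder for the per-feature checks:

```python
# Illustrative sketch, not the PR's code: gate on installation in the same
# condition so the per-feature checks below never execute when FA3 is absent.
use_flash_attention_3 = True
fa3_installed = False        # stand-in for FlashAttentionUtils.v3_is_installed
deterministic = True         # made-up placeholder for a per-feature check

if use_flash_attention_3 and fa3_installed:
    if deterministic:
        use_flash_attention_3 = False  # body is skipped entirely when FA3 is absent
```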
| " not supported for compute capability = sm120" | ||
| ) | ||
| use_fused_attention = False | ||
| if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed: |
FA4 only supports packed sequences in the forward pass.
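A sketch of how that restriction could be enforced during backend selection, assuming qkv_format == "thd" marks packed (variable-length) sequences and an is_training flag; the variable values below are stand-ins, not the PR's actual code:

```python
import logging

logger = logging.getLogger(__name__)

# Stand-ins for values computed earlier in get_attention_backend (assumptions):
use_flash_attention_4, fa4_installed = True, True
qkv_format, is_training = "thd", True  # "thd" is TE's packed / variable-length layout

# Illustrative sketch, not the PR's code: FA4 handles packed sequences in forward
# only, so drop it whenever a backward pass is needed.
if use_flash_attention_4 and fa4_installed and qkv_format == "thd" and is_training:
    logger.debug("Disabling FlashAttention 4: packed (THD) sequences are forward-only in FA4.")
    use_flash_attention_4 = False
```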
    if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
        use_flash_attn_3 = True
    use_flash_attn_4 = False
    if flash_attention_backend is not None and str(flash_attention_backend).endswith("cute"):
The "cute" suffix is added in get_attention_backend because FA4 is released under the package name flash-attn-cute, with versions starting from 0.1.0. Appending ".cute" to the version number lets us distinguish FA4 from FA2/FA3 by version alone.
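A runnable sketch of the version tagging described here; the "+cute" local label is an assumption chosen so the tagged value still parses as a packaging Version, and the PR's actual code may build the suffix differently:

```python
from packaging.version import Version as PkgVersion

use_flash_attn_3, use_flash_attn_4 = False, False

# Assume get_attention_backend reports FA4 with a "cute"-tagged version
# (flash-attn-cute versions start at 0.1.0, far below FA2/FA3 numbers).
flash_attention_backend = PkgVersion("0.1.0+cute")

if flash_attention_backend is not None and str(flash_attention_backend).endswith("cute"):
    use_flash_attn_4 = True   # FA4: flash-attn-cute
elif flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
    use_flash_attn_3 = True   # FA3: flash-attn >= 3.0.0b without the "cute" tag
```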
    # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
    # When `FusedAttention` does not support the provided attention params, and `FlashAttention`
    # does, we recommend users to install flash-attn if not installed already.
This might not be working correctly. Many of the checks above have the form

    if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:
        if xxx:
            use_flash_attention_3 = False

so those checks are skipped entirely when FA3 is not installed. As a result, use_flash_attention_3 == True at this point doesn't guarantee that all the requirements are met.
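One possible restructuring, sketched with illustrative names only: compute "requirements met" from parameter checks that run regardless of installation, and keep the installation check separate, so the recommendation only fires when FA3 really does support the params.

```python
# Illustrative sketch, not the PR's code. `fa3_requirements_met` is expected to
# come from parameter checks that are evaluated even when FA3 is not installed.
def select_flash_attention_3(fa3_requirements_met, fa3_installed, use_fused_attention, logger):
    use_flash_attention_3 = fa3_requirements_met and fa3_installed
    if fa3_requirements_met and not fa3_installed and not use_fused_attention:
        logger.warning(
            "FlashAttention 3 supports the provided attention params but is not "
            "installed; consider installing flash-attn 3."
        )
    return use_flash_attention_3
```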
Description
A very initial effort to add FA4 support.
Done:
Known issues:
TODO:
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: