Conversation


@yaox12 (Member) commented Nov 28, 2025

Description

Very initial effort to add FA4.

Done:

  • Basic GQA/MQA (head dim = 64, 96, 128) support for SM100

Known issues:

  • SM90 is not working, so I just disabled it.
  • Some configurations only work in the forward pass, so I just disabled them, e.g.,
    • qk_head_dim = 192 and v_head_dim = 128
    • packed sequence

TODO:

  • Add tests
  • Correctly handle sliding window

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add initial FlashAttention 4 (FA4) support: basic GQA/MQA (head dim = 64, 96, 128) on SM100
  • Disable FA4 on SM90 and for configurations that currently only work in the forward pass (qk_head_dim = 192 with v_head_dim = 128, packed sequences)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

if not use_flash_attn_3:
    if use_flash_attn_4:
        fa_4_optional_forward_kwargs = {
            # "window_size": window_size,

@yaox12 (Member Author):

For FA4, the default window_size = (-1, 0) does not mean "no sliding window", so the kwarg is left commented out until sliding windows are handled correctly (see the TODO above).
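
As a rough illustration of where this could go once sliding window is handled (a sketch only, with a hypothetical helper name, not the PR's code): forward window_size only when the caller explicitly set a sliding window, so FA4's own default never silently applies.

def build_fa4_optional_forward_kwargs(window_size=None):
    # Sketch: treat None / (-1, -1) as "no sliding window" (the FA2/FA3 convention)
    # and forward anything else explicitly, so FA4's differing default for
    # window_size is never relied upon.
    kwargs = {}
    if window_size is not None and tuple(window_size) != (-1, -1):
        kwargs["window_size"] = tuple(window_size)
    return kwargs

print(build_fa4_optional_forward_kwargs())          # {}
print(build_fa4_optional_forward_kwargs((128, 0)))  # {'window_size': (128, 0)}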

if use_flash_attention_2 and FlashAttentionUtils.is_installed:
    logger.debug("Disabling FlashAttention 2 as it does not support MLA.")
    use_flash_attention_2 = False
if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed:

@yaox12 (Member Author) Nov 28, 2025:

Currently, FA4 only supports MLA in the forward pass.
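
For illustration, a minimal sketch of the kind of gate this implies (the helper name, the simplified MLA test, and the is_training flag are assumptions for the sketch; the real logic lives in get_attention_backend):

import logging

logger = logging.getLogger(__name__)

def maybe_disable_fa4_for_mla(use_flash_attention_4, qk_head_dim, v_head_dim, is_training):
    # Simplified MLA detection for this sketch: differing QK and V head dims.
    is_mla = qk_head_dim != v_head_dim
    if use_flash_attention_4 and is_mla and is_training:
        logger.debug("Disabling FlashAttention 4 as it only supports MLA in forward.")
        return False
    return use_flash_attention_4

print(maybe_disable_fa4_for_mla(True, 192, 128, is_training=True))   # False: backward needed
print(maybe_disable_fa4_for_mla(True, 192, 128, is_training=False))  # True: forward-only is fine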

     use_flash_attention_2 = False
-if use_flash_attention_3:
+if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:

@yaox12 (Member Author):

Move FlashAttentionUtils.v3_is_installed into the condition so the checks below are skipped entirely when FA3 is not installed.

" not supported for compute capability = sm120"
)
use_fused_attention = False
if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed:

@yaox12 (Member Author):

FA4 only supports packed sequences in the forward pass.
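
A small sketch of the same restriction, assuming TE's "thd" qkv_format denotes packed sequences (the function name and signature are illustrative, not the PR's API):

def fa4_supports_config(qkv_format, requires_grad):
    # Packed (THD) sequences currently run only in FA4's forward pass,
    # so reject them whenever a backward pass will be needed.
    packed = qkv_format == "thd"
    return not (packed and requires_grad)

assert fa4_supports_config("bshd", requires_grad=True)
assert fa4_supports_config("thd", requires_grad=False)      # inference-only is fine
assert not fa4_supports_config("thd", requires_grad=True)   # fall back to another backend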

if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
    use_flash_attn_3 = True
use_flash_attn_4 = False
if flash_attention_backend is not None and str(flash_attention_backend).endswith("cute"):

@yaox12 (Member Author):

The suffix "cute" is added in get_attention_backend because FA4 is released under the package name flash-attn-cute with versions starting from 0.1.0. We append ".cute" to the version number to distinguish it from the other FlashAttention versions.
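
For illustration (the exact string the PR builds may differ), one PEP 440-friendly way to make str(version) end with "cute" is a local version label, which keeps the endswith("cute") check and the existing PkgVersion comparison working side by side:

from packaging.version import Version as PkgVersion  # assumed alias, as used elsewhere in TE

fa3_backend = PkgVersion("3.0.0b1")
fa4_backend = PkgVersion("0.1.0+cute")  # illustrative stand-in for the suffixed FA4 version

for backend in (fa3_backend, fa4_backend):
    if str(backend).endswith("cute"):
        print(backend, "-> FA4 (flash-attn-cute) code path")
    elif backend > PkgVersion("3.0.0b"):
        print(backend, "-> FA3 code path")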


# `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
# When `FusedAttention` does not support the provided attention params, and `FlashAttention`
# does, we recommend users to install flash-attn if not installed already.

@yaox12 (Member Author) Nov 28, 2025:

This might not be working correctly.

Many of the checks above have the form

if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:
    if xxx:
        use_flash_attention_3 = False

Those checks are skipped entirely when FA3 is not installed, so even use_flash_attention_3 == True at this point does not mean that all the requirements are met.
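
One possible restructuring (a sketch with hypothetical names, not the PR's fix): evaluate the per-configuration filters on a separate "supported" flag regardless of installation, and fold the installation status in only at the end, so the install recommendation can trust the flag:

def select_fa3(fa3_requested, fa3_installed, blocking_checks):
    # blocking_checks: iterable of booleans, True meaning "this check rules FA3 out".
    fa3_supported = fa3_requested and not any(blocking_checks)
    use_flash_attention_3 = fa3_supported and fa3_installed
    recommend_install = fa3_supported and not fa3_installed
    return use_flash_attention_3, recommend_install

print(select_fa3(True, False, blocking_checks=[False, False]))  # (False, True): recommend installing
print(select_fa3(True, True, blocking_checks=[True]))           # (False, False): config unsupported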
