31 changes: 31 additions & 0 deletions tests/cute/test_flash_attn.py
@@ -25,10 +25,13 @@
flash_attn_func,
flash_attn_varlen_func,
flash_attn_combine,
_get_device_capability,
)


DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
# SplitKV and paged KV are not supported on SM90
IS_SM90 = _get_device_capability() == 9
TEST_BWD_ONLY = False
VERBOSE = True

@@ -238,6 +241,9 @@ def test_flash_attn_output(
# pack_gqa_vals = [False]
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
# SplitKV not supported on SM90 - skip this iteration
if IS_SM90 and num_splits > 1:
continue
out, lse = flash_attn_func(
q,
k,
@@ -276,6 +282,19 @@ def test_flash_attn_output(
# and False
and not ((causal or local) and seqlen_k < seqlen_q)
):
# TODO: SM90 backward pass has invalid MMA tile config for d=64 + non-causal
# The m_block_size=80 (non-causal) with head_dim=64 creates an invalid tile.
# Fix requires adjusting m_block_size or MMA config in flash_bwd_sm90.py
if IS_SM90 and d == 64 and not causal:
pytest.xfail("SM90 backward: d=64 + non-causal has invalid MMA tile config (m_block=80)")
# TODO: SM90 backward pass has tensor layout issue for GQA/MQA (qhead_per_kvhead > 1)
# Error: "invalid mode element for input of rank 3" in utils.select()
# Fix requires adjusting layout handling in flash_bwd_sm90.py for GQA
if IS_SM90 and mha_type != "mha":
pytest.xfail("SM90 backward: GQA/MQA has tensor layout issue (qhead_per_kvhead > 1)")
# TODO: SM90 backward pass does not support local attention yet
if IS_SM90 and local:
pytest.xfail("SM90 backward: local attention not supported yet")
g = torch.randn_like(out)
# do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
@@ -606,6 +625,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
# SplitKV is not supported for hdim >= 192
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1]
for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
# SplitKV not supported on SM90 - skip this iteration
if IS_SM90 and num_splits > 1:
continue
out_unpad, lse = flash_attn_varlen_func(
q_unpad,
k_unpad,
@@ -816,6 +838,8 @@ def test_flash_attn_kvcache(
):
if page_size is not None and seqlen_k % page_size != 0:
pytest.skip()
if page_size is not None and IS_SM90:
pytest.xfail("paged KV not supported on SM90")
if seqlen_q > seqlen_k and new_kv:
pytest.skip()
if not new_kv and rotary_fraction > 0.0:
@@ -1134,12 +1158,16 @@ def test_flash_attn_kvcache(
k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone()
v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone()
# num_splits_vals = [1, 0]
# SplitKV is not supported for hdim >= 192
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1]
# precompute_metadata_vals = [False, True]
precompute_metadata_vals = [False]
for num_splits, precompute_metadata in itertools.product(
num_splits_vals, precompute_metadata_vals
):
# SplitKV not supported on SM90 - skip this iteration
if IS_SM90 and num_splits > 1:
continue
# if precompute_metadata:
# scheduler_metadata = get_scheduler_metadata(
# batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,
@@ -1279,6 +1307,9 @@ def test_flash_attn_kvcache(
@pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (256, 256)])
def test_flash_attn_bwd_preallocated_outputs(seqlen_q, seqlen_k, d, causal, dtype):
if IS_SM90 and d == 64 and not causal:
pytest.xfail("SM90 backward: d=64 + non-causal has invalid MMA tile config (m_block=80)")

from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd

device = "cuda"
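Every gate added in test_flash_attn.py follows the same shape: detect the Hopper (SM90) part once at module import, `continue` past SplitKV iterations the kernel does not implement, and mark known-broken backward configurations with `pytest.xfail` so they stay visible as expected failures. A minimal sketch of that pattern, with a hypothetical test body standing in for the real parametrized tests:

import pytest
import torch

# Major compute capability: 9 = Hopper (SM90), 10 = Blackwell (SM100).
IS_SM90 = torch.cuda.get_device_capability()[0] == 9


def run_split_and_backward_checks(head_dim: int, causal: bool) -> None:
    """Hypothetical test body illustrating the SM90 gating used in this PR."""
    num_splits_vals = [1, 3] if head_dim < 192 else [1]
    for num_splits in num_splits_vals:
        if IS_SM90 and num_splits > 1:
            # SplitKV kernels are not built for SM90; skip only this iteration
            # so the num_splits == 1 case still runs.
            continue
        ...  # forward call with num_splits goes here

    if IS_SM90 and head_dim == 64 and not causal:
        # Known-broken backward config: reported as XFAIL rather than silently skipped.
        pytest.xfail("SM90 backward: d=64 + non-causal has invalid MMA tile config")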
5 changes: 5 additions & 0 deletions tests/cute/test_flash_attn_race_condition.py
@@ -30,6 +30,7 @@


DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
IS_SM90 = torch.cuda.get_device_capability()[0] == 9


# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@@ -247,6 +248,10 @@ def test_flash_attn_output(
and learnable_sink is None
# and False
):
if IS_SM90 and mha_type != "mha":
pytest.xfail("SM90 backward: GQA/MQA has tensor layout issue (qhead_per_kvhead > 1)")
if IS_SM90 and local:
pytest.xfail("SM90 backward: local attention not supported yet")
g = torch.randn_like(out)
# do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
8 changes: 5 additions & 3 deletions tests/cute/test_flash_attn_varlen.py
@@ -7,6 +7,9 @@
import torch.nn.functional as F
from flash_attn.cute import flash_attn_varlen_func

IS_SM90 = torch.cuda.get_device_capability()[0] == 9


@pytest.mark.parametrize("B", [1, 7, 20])
@pytest.mark.parametrize("H", [1, 4, 6])
@pytest.mark.parametrize("D", [64, 128])
@@ -40,9 +43,8 @@ def test_varlen(
dtype=dtype
)

# SM100 (Blackwell) backward pass doesn't support varlen yet
compute_capability = torch.cuda.get_device_capability()[0]
skip_backward = (compute_capability == 10)
# SM90/SM100 backward pass doesn't support varlen yet
skip_backward = IS_SM90 or torch.cuda.get_device_capability()[0] == 10

ok = check_varlen_vs_torch_flash(
q, k, v,
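The varlen change simply folds SM90 into the existing SM100 gate on the backward check. For reference, the same condition written against the capability tuple directly (a sketch, not the module's actual code):

# Sketch only: varlen backward is currently unsupported on SM90 (Hopper) and
# SM100 (Blackwell), so both major versions disable the backward comparison.
major, _minor = torch.cuda.get_device_capability()
skip_backward = major in (9, 10)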
2 changes: 1 addition & 1 deletion tests/cute/test_mask_mod.py
@@ -165,7 +165,7 @@ def _run_mask_test(
pack_gqa = False
elif kv_mode == "gqa":
if COMPUTE_CAPABILITY != 10:
pytest.skip("pack_gqa requires SM100")
pytest.xfail("pack_gqa requires SM100")
nheads_kv = nheads // 4
pack_gqa = True
elif kv_mode == "mqa":
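The single change in test_mask_mod.py replaces pytest.skip with pytest.xfail. Both stop the test at that point, but they report differently: a skip says the test does not apply in this environment, while an imperative xfail records it as a known, expected failure, so the missing pack_gqa support on non-SM100 parts stays visible in the test summary. A small stand-alone illustration (the capability value is hard-coded for the example):

import pytest

COMPUTE_CAPABILITY = 9  # pretend we are on SM90 for this illustration


def test_pack_gqa_as_skip():
    if COMPUTE_CAPABILITY != 10:
        pytest.skip("pack_gqa requires SM100")   # outcome: SKIPPED (environment mismatch)


def test_pack_gqa_as_xfail():
    if COMPUTE_CAPABILITY != 10:
        pytest.xfail("pack_gqa requires SM100")  # outcome: XFAIL (known, expected failure)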
7 changes: 6 additions & 1 deletion tests/cute/test_score_mod.py
@@ -298,6 +298,8 @@ def test_score_mod_with_paged_kvcache(
dtype,
score_mod_pair,
):
if COMPUTE_CAPABILITY == 9:
pytest.xfail("Paged KV cache only supported on SM100")
if page_size is not None and seqlen_kv % page_size != 0:
pytest.skip()

@@ -452,6 +454,8 @@ def test_score_mod_with_paged_kvcache_aux_tensors(
dtype,
score_mod_pair,
):
if COMPUTE_CAPABILITY == 9:
pytest.xfail("Paged KV cache only supported on SM100")
if page_size is not None and seqlen_kv % page_size != 0:
pytest.skip()

@@ -799,7 +803,7 @@ def make_aux_tensors_for_bwd(cute_score_mod, eager_factory, seqlen_q, num_heads,
(256, 128),
],
)
@pytest.mark.parametrize("dim", [64])
@pytest.mark.parametrize("dim", [64, 128])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS_WITH_AUX)
def test_cute_vs_flex_attention_backward_with_aux(
@@ -865,6 +869,7 @@ def test_cute_vs_flex_attention_backward_with_aux(
def test_cute_vs_flex_attention_backward_pack_gqa(
seqlen_q, seqlen_kv, dim, dtype, qhead_per_kvhead, num_kv_heads, score_mod_triple
):
pytest.skip("pack_gqa backward not yet implemented")
torch.random.manual_seed(42)
cute_fwd, cute_bwd, eager_ref = score_mod_triple

8 changes: 8 additions & 0 deletions tests/cute/test_score_mod_varlen.py
@@ -54,6 +54,8 @@
debug_global_idx_factory,
)

IS_SM90 = torch.cuda.get_device_capability()[0] == 9

# =============================================================================
# Test pairs
# =============================================================================
@@ -694,6 +696,9 @@ def test_varlen_score_mod_kvcache(
score_mod_tuple,
):
"""Test varlen attention with score_mod and paged KV cache."""
if IS_SM90 and page_size is not None:
pytest.xfail("paged KV not supported on SM90")

if not varlen_q and not varlen_k:
pytest.skip(
"At least one of varlen_q or varlen_k must be True for varlen tests"
@@ -850,6 +855,9 @@ def test_varlen_score_mod_with_paged_kvcache_global(
score_mod_tuple,
):
"""Test varlen attention with global idx score_mod and paged KV cache."""
if IS_SM90 and page_size is not None:
pytest.xfail("paged KV not supported on SM90")

if page_size is not None and varlen_k:
pytest.skip("Paged KV cache requires batched (non-varlen) K")
