diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py
index cd864ff26c..b2809ab61e 100644
--- a/tests/cute/test_flash_attn.py
+++ b/tests/cute/test_flash_attn.py
@@ -25,10 +25,13 @@
     flash_attn_func,
     flash_attn_varlen_func,
     flash_attn_combine,
+    _get_device_capability,
 )
 
 
 DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
+# SplitKV and paged KV are not supported on SM90
+IS_SM90 = _get_device_capability() == 9
 
 TEST_BWD_ONLY = False
 VERBOSE = True
@@ -238,6 +241,9 @@ def test_flash_attn_output(
     # pack_gqa_vals = [False]
     num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
     for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
+        # SplitKV not supported on SM90 - skip this iteration
+        if IS_SM90 and num_splits > 1:
+            continue
         out, lse = flash_attn_func(
             q,
             k,
@@ -276,6 +282,19 @@ def test_flash_attn_output(
         # and False
         and not ((causal or local) and seqlen_k < seqlen_q)
     ):
+        # TODO: SM90 backward pass has invalid MMA tile config for d=64 + non-causal
+        # The m_block_size=80 (non-causal) with head_dim=64 creates an invalid tile.
+        # Fix requires adjusting m_block_size or MMA config in flash_bwd_sm90.py
+        if IS_SM90 and d == 64 and not causal:
+            pytest.xfail("SM90 backward: d=64 + non-causal has invalid MMA tile config (m_block=80)")
+        # TODO: SM90 backward pass has tensor layout issue for GQA/MQA (qhead_per_kvhead > 1)
+        # Error: "invalid mode element for input of rank 3" in utils.select()
+        # Fix requires adjusting layout handling in flash_bwd_sm90.py for GQA
+        if IS_SM90 and mha_type != "mha":
+            pytest.xfail("SM90 backward: GQA/MQA has tensor layout issue (qhead_per_kvhead > 1)")
+        # TODO: SM90 backward pass does not support local attention yet
+        if IS_SM90 and local:
+            pytest.xfail("SM90 backward: local attention not supported yet")
         g = torch.randn_like(out)
         # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
         dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
@@ -606,6 +625,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
     # SplitKV is not supported for hdim >= 192
     num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1]
     for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
+        # SplitKV not supported on SM90 - skip this iteration
+        if IS_SM90 and num_splits > 1:
+            continue
         out_unpad, lse = flash_attn_varlen_func(
             q_unpad,
             k_unpad,
@@ -816,6 +838,8 @@ def test_flash_attn_kvcache(
 ):
     if page_size is not None and seqlen_k % page_size != 0:
         pytest.skip()
+    if page_size is not None and IS_SM90:
+        pytest.xfail("paged KV not supported on SM90")
     if seqlen_q > seqlen_k and new_kv:
         pytest.skip()
     if not new_kv and rotary_fraction > 0.0:
@@ -1134,12 +1158,16 @@ def test_flash_attn_kvcache(
     k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone()
     v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone()
     # num_splits_vals = [1, 0]
+    # SplitKV is not supported for hdim >= 192
    num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1]
     # precompute_metadata_vals = [False, True]
     precompute_metadata_vals = [False]
     for num_splits, precompute_metadata in itertools.product(
         num_splits_vals, precompute_metadata_vals
     ):
+        # SplitKV not supported on SM90 - skip this iteration
+        if IS_SM90 and num_splits > 1:
+            continue
         # if precompute_metadata:
         #     scheduler_metadata = get_scheduler_metadata(
         #         batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,
@@ -1279,6 +1307,9 @@ def test_flash_attn_kvcache(
 @pytest.mark.parametrize("d", [64, 128])
 @pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (256, 256)])
 def test_flash_attn_bwd_preallocated_outputs(seqlen_q, seqlen_k, d, causal, dtype):
+    if IS_SM90 and d == 64 and not causal:
+        pytest.xfail("SM90 backward: d=64 + non-causal has invalid MMA tile config (m_block=80)")
+
     from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd
 
     device = "cuda"
diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py
index 520cf6466a..0174040687 100644
--- a/tests/cute/test_flash_attn_race_condition.py
+++ b/tests/cute/test_flash_attn_race_condition.py
@@ -30,6 +30,7 @@
 
 
 DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
+IS_SM90 = torch.cuda.get_device_capability()[0] == 9
 
 
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@@ -247,6 +248,10 @@ def test_flash_attn_output(
         and learnable_sink is None
         # and False
     ):
+        if IS_SM90 and mha_type != "mha":
+            pytest.xfail("SM90 backward: GQA/MQA has tensor layout issue (qhead_per_kvhead > 1)")
+        if IS_SM90 and local:
+            pytest.xfail("SM90 backward: local attention not supported yet")
         g = torch.randn_like(out)
         # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
         dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
diff --git a/tests/cute/test_flash_attn_varlen.py b/tests/cute/test_flash_attn_varlen.py
index 53d907eed9..3f72667674 100644
--- a/tests/cute/test_flash_attn_varlen.py
+++ b/tests/cute/test_flash_attn_varlen.py
@@ -7,6 +7,9 @@
 import torch.nn.functional as F
 from flash_attn.cute import flash_attn_varlen_func
 
+IS_SM90 = torch.cuda.get_device_capability()[0] == 9
+
+
 @pytest.mark.parametrize("B", [1, 7, 20])
 @pytest.mark.parametrize("H", [1, 4, 6])
 @pytest.mark.parametrize("D", [64, 128])
@@ -40,9 +43,8 @@ def test_varlen(
         dtype=dtype
     )
 
-    # SM100 (Blackwell) backward pass doesn't support varlen yet
-    compute_capability = torch.cuda.get_device_capability()[0]
-    skip_backward = (compute_capability == 10)
+    # SM90/SM100 backward pass doesn't support varlen yet
+    skip_backward = IS_SM90 or torch.cuda.get_device_capability()[0] == 10
 
     ok = check_varlen_vs_torch_flash(
         q, k, v,
diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py
index 5ebb8f53cf..01261789f3 100644
--- a/tests/cute/test_mask_mod.py
+++ b/tests/cute/test_mask_mod.py
@@ -165,7 +165,7 @@ def _run_mask_test(
         pack_gqa = False
     elif kv_mode == "gqa":
         if COMPUTE_CAPABILITY != 10:
-            pytest.skip("pack_gqa requires SM100")
+            pytest.xfail("pack_gqa requires SM100")
         nheads_kv = nheads // 4
         pack_gqa = True
     elif kv_mode == "mqa":
diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py
index 26cdecde43..82c135a8ee 100644
--- a/tests/cute/test_score_mod.py
+++ b/tests/cute/test_score_mod.py
@@ -298,6 +298,8 @@ def test_score_mod_with_paged_kvcache(
     dtype,
     score_mod_pair,
 ):
+    if COMPUTE_CAPABILITY == 9:
+        pytest.xfail("Paged KV cache only supported on SM100")
     if page_size is not None and seqlen_kv % page_size != 0:
         pytest.skip()
 
@@ -452,6 +454,8 @@ def test_score_mod_with_paged_kvcache_aux_tensors(
     dtype,
     score_mod_pair,
 ):
+    if COMPUTE_CAPABILITY == 9:
+        pytest.xfail("Paged KV cache only supported on SM100")
     if page_size is not None and seqlen_kv % page_size != 0:
         pytest.skip()
 
@@ -799,7 +803,7 @@ def make_aux_tensors_for_bwd(cute_score_mod, eager_factory, seqlen_q, num_heads,
         (256, 128),
     ],
 )
-@pytest.mark.parametrize("dim", [64])
+@pytest.mark.parametrize("dim", [64, 128])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS_WITH_AUX)
 def test_cute_vs_flex_attention_backward_with_aux(
@@ -865,6 +869,7 @@ def test_cute_vs_flex_attention_backward_with_aux(
 def test_cute_vs_flex_attention_backward_pack_gqa(
     seqlen_q, seqlen_kv, dim, dtype, qhead_per_kvhead, num_kv_heads, score_mod_triple
 ):
+    pytest.skip("pack_gqa backward not yet implemented")
     torch.random.manual_seed(42)
     cute_fwd, cute_bwd, eager_ref = score_mod_triple
 
diff --git a/tests/cute/test_score_mod_varlen.py b/tests/cute/test_score_mod_varlen.py
index 3f339e548c..7cca7f2aa0 100644
--- a/tests/cute/test_score_mod_varlen.py
+++ b/tests/cute/test_score_mod_varlen.py
@@ -54,6 +54,8 @@
     debug_global_idx_factory,
 )
 
+IS_SM90 = torch.cuda.get_device_capability()[0] == 9
+
 # =============================================================================
 # Test pairs
 # =============================================================================
@@ -694,6 +696,9 @@ def test_varlen_score_mod_kvcache(
     score_mod_tuple,
 ):
     """Test varlen attention with score_mod and paged KV cache."""
+    if IS_SM90 and page_size is not None:
+        pytest.xfail("paged KV not supported on SM90")
+
     if not varlen_q and not varlen_k:
         pytest.skip(
             "At least one of varlen_q or varlen_k must be True for varlen tests"
@@ -850,6 +855,9 @@ def test_varlen_score_mod_with_paged_kvcache_global(
     score_mod_tuple,
 ):
     """Test varlen attention with global idx score_mod and paged KV cache."""
+    if IS_SM90 and page_size is not None:
+        pytest.xfail("paged KV not supported on SM90")
+
     if page_size is not None and varlen_k:
         pytest.skip("Paged KV cache requires batched (non-varlen) K")