
Commit 490a5f4

[PyTorch] Fix attention backend and tests for sm120 (#2320)
* Fix attention backend and tests for sm120

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Disable MLA only for backward

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
1 parent 5e8a9a9 commit 490a5f4

File tree

2 files changed: +48 additions, −9 deletions

tests/pytorch/attention/test_attention.py

Lines changed: 13 additions & 9 deletions
@@ -61,8 +61,16 @@
     get_available_attention_backends,
 )
 
-# Check if hardware supports FP8
+# Check if hardware supports FP8 attention.
 fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
+fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8
+device_compute_capability = get_device_compute_capability()
+if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)):
+    fp8_attn_available = False
+    reason_for_no_fp8_attn = (
+        "FP8 attention is not supported for compute capability ="
+        f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
+    )
 
 # Reset RNG seed and states
 seed = 1234
@@ -1573,8 +1581,7 @@ def _run_transformer_layer(
 }
 
 
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
+@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
 @pytest.mark.parametrize("model", ["large"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -1736,8 +1743,7 @@ def get_model(dtype, config):
 
 
 @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
+@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
 @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
 @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@@ -1973,8 +1979,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
 
 
 @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
+@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
 @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
 @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@@ -2302,8 +2307,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     ),
     reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
 )
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
+@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.parametrize("dtype", param_types_fp8)
 @pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
 def test_custom_mha_fp8_vs_f16(dtype, model):
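For reference, the gate added above restricts the FP8 attention tests to compute capabilities in the [sm90, sm120) range and builds the skip reason from the (major, minor) tuple. Below is a minimal, standalone sketch of that logic, assuming only PyTorch is installed; fp8_attention_supported is a hypothetical helper for illustration, not part of the test suite.

# Minimal sketch of the compute-capability gate used in the tests above.
# fp8_attention_supported is a hypothetical helper, not Transformer Engine code.
import torch


def fp8_attention_supported(compute_capability):
    """Return (supported, reason) for FP8 attention on a given (major, minor) capability."""
    # The tests above allow FP8 attention only for sm90 and above, but below sm120.
    if compute_capability < (9, 0) or compute_capability >= (12, 0):
        sm = compute_capability[0] * 10 + compute_capability[1]  # e.g. (12, 0) -> 120
        return False, f"FP8 attention is not supported for compute capability = sm{sm}"
    return True, ""


if torch.cuda.is_available():
    # torch.cuda.get_device_capability() returns (major, minor), e.g. (12, 0) on sm120.
    supported, reason = fp8_attention_supported(torch.cuda.get_device_capability())
    print(supported, reason)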

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 35 additions & 0 deletions
@@ -481,6 +481,20 @@ def get_attention_backend(
             logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0")
             use_fused_attention = False
 
+        if device_compute_capability == (12, 0):
+            if use_flash_attention:
+                logger.debug(
+                    "Disabling FlashAttention as FP8 is not supported"
+                    " for compute capability = sm120"
+                )
+            if use_fused_attention:
+                logger.debug(
+                    "Disabling FusedAttention as FP8 is not supported"
+                    " for compute capability = sm120"
+                )
+            use_flash_attention = False
+            use_fused_attention = False
+
     # Filter: Return max_logit
     if return_max_logit:
         if use_flash_attention:
@@ -560,6 +574,20 @@ def get_attention_backend(
                 qkv_layout,
             )
             use_fused_attention = False
+    if (
+        device_compute_capability == (12, 0)
+        and (head_dim_qk > 128 or head_dim_qk % 8 != 0)
+        and is_training
+    ):
+        if use_fused_attention:
+            logger.debug(
+                "Disabling FusedAttention as MLA for backward pass is not supported for compute"
+                " capability = sm120 for a head_dim_qk > 128 or head_dim_qk %%8 != 0. Found:"
+                " head_dim_qk = %s",
+                head_dim_qk,
+            )
+            use_fused_attention = False
+
     if use_flash_attention_2 and (
         head_dim_qk > 256
         or head_dim_qk % 8 != 0
@@ -629,6 +657,13 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
                 "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
             )
             use_flash_attention = False
+        if device_compute_capability == (12, 0):
+            if use_fused_attention:
+                logger.debug(
+                    "Disabling FusedAttention as qkv_format = thd is"
+                    " not supported for compute capability = sm120"
+                )
+                use_fused_attention = False
 
     # Filter: Dropout
     if attention_dropout != 0.0 and use_flash_attention_3:
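The three new filters in get_attention_backend() follow the same pattern as the existing ones: check a condition, emit a debug log naming the backend being disabled, and flip the corresponding flag. Below is a simplified, standalone sketch of that pattern for the sm120 FP8 case; the function name and the explicit fp8 argument are illustrative assumptions, not the library's API.

# Simplified sketch of the backend-filter pattern shown in the diff above.
# filter_sm120_fp8 is a hypothetical helper, not Transformer Engine code.
import logging

logger = logging.getLogger("attention_backend_sketch")


def filter_sm120_fp8(device_compute_capability, use_flash_attention, use_fused_attention, fp8):
    """Disable FP8 attention backends on sm120; leave other configurations untouched."""
    if fp8 and device_compute_capability == (12, 0):
        if use_flash_attention:
            logger.debug("Disabling FlashAttention as FP8 is not supported on sm120")
        if use_fused_attention:
            logger.debug("Disabling FusedAttention as FP8 is not supported on sm120")
        use_flash_attention = False
        use_fused_attention = False
    return use_flash_attention, use_fused_attention


# On an sm120 device with FP8 requested, both flags come back False.
print(filter_sm120_fp8((12, 0), True, True, fp8=True))  # (False, False)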
