|
61 | 61 | get_available_attention_backends, |
62 | 62 | ) |
63 | 63 |
|
64 | | -# Check if hardware supports FP8 |
| 64 | +# Check if hardware supports FP8 attention. |
65 | 65 | fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) |
| 66 | +fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8 |
| 67 | +device_compute_capability = get_device_compute_capability() |
| 68 | +if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)): |
| 69 | + fp8_attn_available = False |
| 70 | + reason_for_no_fp8_attn = ( |
| 71 | + "FP8 attention is not supported for compute capability =" |
| 72 | + f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}" |
| 73 | + ) |
66 | 74 |
|
67 | 75 | # Reset RNG seed and states |
68 | 76 | seed = 1234 |
@@ -1573,8 +1581,7 @@ def _run_transformer_layer( |
1573 | 1581 | } |
1574 | 1582 |
|
1575 | 1583 |
|
1576 | | -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) |
1577 | | -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.") |
| 1584 | +@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) |
1578 | 1585 | @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") |
1579 | 1586 | @pytest.mark.parametrize("model", ["large"]) |
1580 | 1587 | @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) |
@@ -1736,8 +1743,7 @@ def get_model(dtype, config): |
1736 | 1743 |
|
1737 | 1744 |
|
1738 | 1745 | @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") |
1739 | | -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) |
1740 | | -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") |
| 1746 | +@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) |
1741 | 1747 | @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) |
1742 | 1748 | @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) |
1743 | 1749 | @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16) |
@@ -1973,8 +1979,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: |
1973 | 1979 |
|
1974 | 1980 |
|
1975 | 1981 | @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") |
1976 | | -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) |
1977 | | -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") |
| 1982 | +@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) |
1978 | 1983 | @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) |
1979 | 1984 | @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) |
1980 | 1985 | @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) |
@@ -2302,8 +2307,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: |
2302 | 2307 | ), |
2303 | 2308 | reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""", |
2304 | 2309 | ) |
2305 | | -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) |
2306 | | -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") |
| 2310 | +@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) |
2307 | 2311 | @pytest.mark.parametrize("dtype", param_types_fp8) |
2308 | 2312 | @pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0) |
2309 | 2313 | def test_custom_mha_fp8_vs_f16(dtype, model): |
|
0 commit comments