54 | 54 | PLATFORM_SUPPORTS_CUDNN_ATTENTION, |
55 | 55 | tf32_on_and_off, |
56 | 56 | tf32_enabled, |
| 57 | + math_sdp_precision, |
57 | 58 | ) |
58 | 59 |
59 | 60 | if TEST_FAIRSEQ: |
@@ -128,6 +129,12 @@ def _check_equal( |
128 | 129 | _check_equal(gold, ref, tst, fudge_factor, tensor_name) |
129 | 130 | return |
130 | 131 |
| 132 | + if golden.is_cuda and golden.dtype == torch.float32: |
| 133 | + assert torch.backends.cuda.math_sdp.fp32_precision == "ieee", ( |
| 134 | + "Testing script error: FP32 golden tensor must be calculated with IEEE" |
| 135 | + " precision. Add @math_sdp_precision('ieee') to related tests to fix it." |
| 136 | + ) |
| 137 | + |
131 | 138 | # Compute error between golden |
132 | 139 | test_error = (golden - test).abs().max() |
133 | 140 | ref_error = (golden - reference).abs().max() |
@@ -3413,6 +3420,7 @@ def test_mem_eff_backwards_determinism(self, device): |
3413 | 3420 | ) |
3414 | 3421 | @parametrize("scale", [None, "l1"]) |
3415 | 3422 | @tf32_enabled() |
| 3423 | + @math_sdp_precision("ieee") |
3416 | 3424 | def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, |
3417 | 3425 | head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, |
3418 | 3426 | scale: str): |
@@ -3528,6 +3536,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, |
3528 | 3536 | ) |
3529 | 3537 | @parametrize("scale", [None, "l1"]) |
3530 | 3538 | @tf32_enabled() |
| 3539 | + @math_sdp_precision("ieee") |
3531 | 3540 | def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, |
3532 | 3541 | seq_len_k: int, head_dim: int, is_causal: bool, |
3533 | 3542 | dropout_p: float, dtype: torch.dtype, |
@@ -3641,6 +3650,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, |
3641 | 3650 | @parametrize("enable_gqa", [True, False]) |
3642 | 3651 | @parametrize("n_heads", [[16, 8], [10, 2]]) |
3643 | 3652 | @tf32_enabled() |
| 3653 | + @math_sdp_precision("ieee") |
3644 | 3654 | def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, |
3645 | 3655 | head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, |
3646 | 3656 | scale: str, enable_gqa: bool, n_heads: list[int]): |
@@ -3786,6 +3796,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le |
3786 | 3796 | @parametrize("scale", [None, "l1"]) |
3787 | 3797 | @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA) |
3788 | 3798 | @tf32_enabled() |
| 3799 | + @math_sdp_precision("ieee") |
3789 | 3800 | def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int, |
3790 | 3801 | seq_len_q: int, seq_len_k: int, |
3791 | 3802 | head_dim: int, |
@@ -4100,6 +4111,7 @@ def test_fused_kernels_nested_broadcasting_query_dense(self, device): |
4100 | 4111 | @parametrize("dtype", [torch.float16]) |
4101 | 4112 | @parametrize("scale", [None, "l1"]) |
4102 | 4113 | @parametrize("is_causal", [True, False]) |
| 4114 | + @math_sdp_precision("ieee") |
4103 | 4115 | def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int, |
4104 | 4116 | head_dim: int, dropout_p: float, dtype: torch.dtype, |
4105 | 4117 | scale: str, is_causal: bool): |