@@ -21,7 +21,7 @@
     _attention_backends,
 )
 from transformer_engine.pytorch.dot_product_attention.utils import (
-    FAUtils,
+    FlashAttentionUtils,
     get_attention_backend,
     check_set_window_size,
     AttentionParams,
@@ -280,12 +280,12 @@ def test_dot_product_attention(
     # mannually pads and unpads the input and output of FlashAttention for testing purposes
     if (
         pad_between_seqs
-        and FAUtils.is_installed
+        and FlashAttentionUtils.is_installed
         and not (
             config.max_seqlen_q != config.max_seqlen_kv
             and config.attn_mask_type in ["causal", "padding_causal"]
         )
-        and (config.window_size[0] == -1 or FAUtils.v2_3_plus)
+        and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
     ):
         flash_attn_supported = True

@@ -592,7 +592,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
 }


-@pytest.mark.skipif(not FAUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
+@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_swa])
 @pytest.mark.parametrize("model", model_configs_swa.keys())
@@ -614,7 +614,7 @@ def test_dpa_sliding_window(dtype, model_configs, model):
 }


-@pytest.mark.skipif(not FAUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
+@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
 @pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
@@ -1456,7 +1456,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     ):
         pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")

-    if FAUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
+    if FlashAttentionUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
@@ -1482,7 +1482,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     rtol = 5e-1
     rmse_tol = 0.15
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    if FAUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
+    if FlashAttentionUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
         _error(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
@@ -1667,7 +1667,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"

-    if FAUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
+    if FlashAttentionUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
@@ -1696,7 +1696,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     rmse_tol = 0.11
     bwd_names = ["dq", "dk", "dv"]
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    if FAUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
+    if FlashAttentionUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
         _error(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
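
For context, a minimal sketch (not part of this diff) of the pattern the rename touches, assuming only the class name changed from FAUtils to FlashAttentionUtils and that its is_installed, v2_3_plus, and v3_is_installed flags keep their previous meaning. The flash_attn_path_enabled helper below is hypothetical; the tests inline this condition instead.

# Sketch only: illustrates how the renamed utility's flags gate these tests.
import pytest
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils

# Version gate reused by the sliding-window and ALiBi-slopes tests.
requires_fa_2_3 = pytest.mark.skipif(
    not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required."
)

def flash_attn_path_enabled(is_training: bool, attn_mask_type: str) -> bool:
    # The FP8-vs-F16 comparisons only exercise the flash-attn backend when
    # FlashAttention 3 is installed, the run is inference-only, and no padding
    # mask is involved (hypothetical helper mirroring the inlined checks above).
    return (
        FlashAttentionUtils.v3_is_installed
        and not is_training
        and "padding" not in attn_mask_type
    )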