diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 8b26a1760d..eeb56f03d6 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -469,13 +469,13 @@ def get_attention_backend(
         fp8_recipe = fp8_meta["recipe"]
         if fp8_meta.get("local_recipes", None) is not None:
             fp8_recipe = fp8_meta["local_recipes"][0]
-        if (
-            use_fused_attention
-            and fp8_recipe.float8_current_scaling()
-            and device_compute_capability < (10, 0)
-        ):
-            logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
-            use_fused_attention = False
+        if use_fused_attention and fp8_recipe.float8_current_scaling():
+            if device_compute_capability < (10, 0):
+                logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
+                use_fused_attention = False
+            elif cudnn_version < (9, 14, 0):
+                logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0")
+                use_fused_attention = False
 
     # Filter: KV cache
     # backend | precision | KV cache | architecture | qkv_format | page_size
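
For reference, below is a minimal standalone sketch (not part of the patch) of the revised guard logic. It mirrors the nested check introduced above, assuming `device_compute_capability` and `cudnn_version` are version tuples as used in `get_attention_backend`, and taking a plain boolean in place of `fp8_recipe.float8_current_scaling()`. The function name and parameters are illustrative only, not Transformer Engine API.

    from typing import Tuple

    def fused_attention_allowed_for_fp8_current_scaling(
        use_fused_attention: bool,
        is_fp8_current_scaling: bool,
        device_compute_capability: Tuple[int, int],
        cudnn_version: Tuple[int, int, int],
    ) -> bool:
        # Sketch of the new control flow: for FP8 current scaling, FusedAttention
        # stays enabled only on sm100+ hardware, and (per this patch) only with
        # cuDNN 9.14.0 or newer.
        if use_fused_attention and is_fp8_current_scaling:
            if device_compute_capability < (10, 0):
                use_fused_attention = False
            elif cudnn_version < (9, 14, 0):
                use_fused_attention = False
        return use_fused_attention

    # Usage: Hopper (sm90) with cuDNN 9.14 -> FusedAttention is filtered out.
    assert fused_attention_allowed_for_fp8_current_scaling(
        True, True, (9, 0), (9, 14, 0)
    ) is False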