From b3eb60ab0d7fdf3909a45611214da0f45bbd5530 Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Sat, 4 Oct 2025 01:59:08 +0000
Subject: [PATCH 1/2] Require cuDNN 9.14.0+ for fused attention with FP8
 current scaling

Signed-off-by: Tim Moon
---
 .../pytorch/attention/dot_product_attention/utils.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index ea7b0e8763..bc5abe5755 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -472,10 +472,13 @@ def get_attention_backend(
         if (
             use_fused_attention
             and fp8_recipe.float8_current_scaling()
-            and device_compute_capability < (10, 0)
         ):
-            logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
-            use_fused_attention = False
+            if device_compute_capability < (10, 0):
+                logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
+                use_fused_attention = False
+            elif cudnn_version < (9, 14, 0):
+                logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0")
+                use_fused_attention = False

     # Filter: KV cache
     # backend | precision | KV cache | architecture | qkv_format | page_size

From 252be729a784cc35ccf22fa2e4e7445d539fadca Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 4 Oct 2025 02:04:22 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../pytorch/attention/dot_product_attention/utils.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index bc5abe5755..03e67b5a4c 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -469,10 +469,7 @@ def get_attention_backend(
         fp8_recipe = fp8_meta["recipe"]
         if fp8_meta.get("local_recipes", None) is not None:
             fp8_recipe = fp8_meta["local_recipes"][0]
-        if (
-            use_fused_attention
-            and fp8_recipe.float8_current_scaling()
-        ):
+        if use_fused_attention and fp8_recipe.float8_current_scaling():
            if device_compute_capability < (10, 0):
                 logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
                 use_fused_attention = False
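
Reviewer note: below is a minimal, self-contained sketch of the filter logic that the two patches converge on, useful as a quick sanity check of the version gating. The function name filter_fp8_current_scaling and the plain-tuple arguments are illustrative assumptions for this sketch only; in Transformer Engine the logic lives inline in get_attention_backend().

# Sketch (assumed names, not Transformer Engine API) of the FP8 current-scaling
# backend filter after both patches are applied.
import logging

logger = logging.getLogger("DotProductAttention")


def filter_fp8_current_scaling(
    use_fused_attention: bool,
    is_float8_current_scaling: bool,
    device_compute_capability: tuple,
    cudnn_version: tuple,
) -> bool:
    """Return whether FusedAttention remains enabled for FP8 current scaling."""
    if use_fused_attention and is_float8_current_scaling:
        if device_compute_capability < (10, 0):
            # Compute capability below 10.0 (sm100) never takes the fused path here.
            logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
            use_fused_attention = False
        elif cudnn_version < (9, 14, 0):
            # On sm100+, the fused path additionally needs cuDNN 9.14.0 or newer.
            logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0")
            use_fused_attention = False
    return use_fused_attention


# sm100 GPU with cuDNN 9.13 -> fused attention is filtered out.
assert filter_fp8_current_scaling(True, True, (10, 0), (9, 13, 0)) is False
# sm100 GPU with cuDNN 9.14 -> fused attention stays enabled.
assert filter_fp8_current_scaling(True, True, (10, 0), (9, 14, 0)) is True

The tuple comparisons mirror how the patched function compares device_compute_capability and cudnn_version, so the new elif only tightens the sm100+ path rather than changing behavior on older architectures.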