From 66c66272e6ede638f3b921ffe998f1ef0c877c6d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 16 Sep 2025 01:26:33 -0700 Subject: [PATCH 01/15] add max_score for fused/unfused F16 non-CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .gitmodules | 3 +- 3rdparty/cudnn-frontend | 2 +- tests/pytorch/attention/test_attention.py | 54 ++- tests/pytorch/utils.py | 3 + .../common/fused_attn/fused_attn.cpp | 32 +- .../fused_attn_f16_arbitrary_seqlen.cu | 324 +++++++++++------- .../fused_attn_f16_arbitrary_seqlen.h | 6 +- .../common/fused_attn/fused_attn_fp8.cu | 8 +- transformer_engine/common/fused_attn/utils.h | 7 +- .../include/transformer_engine/fused_attn.h | 42 ++- .../jax/csrc/extensions/attention.cpp | 18 +- .../dot_product_attention/backends.py | 38 +- .../dot_product_attention.py | 8 + .../attention/dot_product_attention/utils.py | 22 ++ .../pytorch/cpp_extensions/fused_attn.py | 14 +- transformer_engine/pytorch/csrc/extensions.h | 4 +- .../pytorch/csrc/extensions/attention.cpp | 69 ++-- 17 files changed, 422 insertions(+), 232 deletions(-) diff --git a/.gitmodules b/.gitmodules index 21492db5ef..f41b9f5c50 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,4 +3,5 @@ url = https://github.com/google/googletest.git [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://github.com/NVIDIA/cudnn-frontend.git + url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git + branch = feature/muon diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index deda80e537..d210115c3b 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit deda80e5372d50e925d7bf4f76c5db779be3fbd5 +Subproject commit d210115c3ba176d45247801f0f2f8464bca935e7 diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 56bfa14234..c0c53d5bc9 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -166,7 +166,7 @@ def test_dot_product_attention( # UnfusedDotProductAttention backend if unfused_attn_supported: - unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( + unfused_attn_fwd, max_score, unfused_attn_bwd = _run_dot_product_attention( dtype, config, "UnfusedDotProductAttention", @@ -180,7 +180,7 @@ def test_dot_product_attention( # FusedAttention backend if fused_attn_supported: if len(fused_attn_backends) == 1: - fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( + fused_attn_fwd, max_score, fused_attn_bwd = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -192,7 +192,7 @@ def test_dot_product_attention( ) if len(fused_attn_backends) == 2: os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" - fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( + fused_attn_fwd, max_score, fused_attn_bwd = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -203,7 +203,7 @@ def test_dot_product_attention( is_training, ) os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" - fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention( + fused_attn_fwd_1, max_score_1, fused_attn_bwd_1 = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -216,7 +216,7 @@ def test_dot_product_attention( # FlashAttention backend if flash_attn_supported: - flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( + flash_attn_fwd, max_score, flash_attn_bwd = _run_dot_product_attention( dtype, config, "FlashAttention", @@ -259,6 +259,32 @@ def 
test_dpa_checkpoint(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) +model_configs_max_score = { + # test: b, h, hg, d + "max_score_1_0": ModelConfig(8, 128, 16, 64), + "max_score_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256), + "max_score_2_0": ModelConfig(2, 2048, 24, 128), + "max_score_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096), + "max_score_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048), + "max_score_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048), + "max_score_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048), + "max_score_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048), + "max_score_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048), + "max_score_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048), + "max_score_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048), + "max_score_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048), +} +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("model_configs", [model_configs_max_score]) +@pytest.mark.parametrize("model", model_configs_max_score.keys()) +def test_dpa_max_score(dtype, model_configs, model): + """Test DotProductAttention module with return_max_score""" + config = model_configs[model] + config.return_max_score = True + test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) + + model_configs_mla = { # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0 @@ -989,6 +1015,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: tp_group=None, layer_number=1, attention_type=config.attn_type, + return_max_score=config.return_max_score, ).to(dtype=dtype, device="cuda") if not is_training: block = block.eval() @@ -1004,7 +1031,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: k = inp[1] v = inp[2] d_out = out_grad - out = block( + out, max_score = block( q, k, v, @@ -1025,13 +1052,18 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: fast_zero_fill=True, ) if is_training: - out.backward(d_out) + out.backward((d_out, torch.zeros(1, device="cuda"))) + + if config.return_max_score: + out = (out, max_score) + else: + out = (out, None) if backend in ["FlashAttention", "UnfusedDotProductAttention"]: if is_training: - return out, (q.grad, k.grad, v.grad) + return *out, (q.grad, k.grad, v.grad) else: - return out, (None, None, None) + return *out, (None, None, None) if backend == "FusedAttention": if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) @@ -1065,9 +1097,9 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: return out_orig, (None, None, None) else: if is_training: - return out, (q.grad, k.grad, v.grad) + return *out, (q.grad, k.grad, v.grad) else: - return out, (None, None, None) + return *out, (None, None, None) model_configs_te_layer = { diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 38f400f659..72f4919549 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -153,6 +153,7 @@ def __init__( alibi_type: str = "none", bias_shape: str = "1hss", window_size: Tuple[int, int] = (-1, -1), + return_max_score: bool = False, total_requests: int = None, max_ctx_len: int = None, num_layers: int = 1, @@ -178,6 +179,7 @@ def __init__( self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv)
else "cross" self.bias_shape = bias_shape self.window_size = window_size + self.return_max_score = return_max_score self.total_requests = total_requests self.max_ctx_len = max_ctx_len self.num_layers = num_layers @@ -263,6 +265,7 @@ def test(): fp8_meta=fp8_meta, is_training=is_training, inference_params=inference_params, + return_max_score=config.return_max_score, ) ( use_flash_attention, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 60b10862e6..c001b00d1d 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -137,7 +137,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_sum_exp) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -185,7 +185,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && !requires_64bit_ragged_offset && // 9.10.0: known bugs with SDPA FP8 - (cudnn_runtime_version != 91000)) { + (cudnn_runtime_version != 91000) && + !return_max_sum_exp) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else { @@ -213,7 +214,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && - !requires_64bit_ragged_offset) { + !requires_64bit_ragged_offset && + !return_max_sum_exp) { flag_m512 = true; } if ( @@ -409,7 +411,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + bool return_max_sum_exp, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { @@ -448,7 +450,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_sum_exp); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -462,7 +464,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, 
t, is_training, attn_scale, dropout, qkv_layout, bias_type, + b, h, max_seqlen, d, t, is_training, return_max_sum_exp, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); @@ -530,7 +532,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); + max_seqlen, d, d, window_size_left, window_size_right, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -584,7 +586,7 @@ void nvte_fused_attn_fwd_kvpacked( NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_sum_exp, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { @@ -661,7 +663,7 @@ void nvte_fused_attn_fwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, return_max_sum_exp); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -676,7 +678,7 @@ void nvte_fused_attn_fwd_kvpacked( #if (CUDNN_VERSION >= 8903) fused_attn_arbitrary_seqlen_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_sum_exp, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, @@ -756,7 +758,7 @@ void nvte_fused_attn_bwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -815,7 +817,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, 
NVTE_QKV_Layout qkv_layout, + bool return_max_sum_exp, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { @@ -887,7 +889,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, return_max_sum_exp); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -902,7 +904,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_sum_exp, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, @@ -979,7 +981,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 4e6c3c858b..f53547b314 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -52,11 +52,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, - int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, + int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, bool generate_max_sum_exp, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, - void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrS1, void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t 
tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -99,6 +99,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( s_kv = is_ragged_kv ? max_t_kv : s_kv; } const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; + bool generate_stats = !generate_max_sum_exp; try { FADescriptor_v1 descriptor{b, @@ -126,7 +127,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( window_size_right, true, tensorType, - tensorType}; + tensorType, + generate_stats, + generate_max_sum_exp, + }; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -136,7 +140,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // V std::shared_ptr, // attn_scale std::shared_ptr, // O - std::shared_ptr, // Stats + std::shared_ptr, // S1 + std::shared_ptr, // S2 std::shared_ptr, // bias std::shared_ptr, // seq_q std::shared_ptr, // seq_kv @@ -238,6 +243,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options = fe::graph::SDPA_attributes() .set_name("flash_attention") .set_is_inference(false) + .set_generate_stats(generate_stats) + .set_generate_max(generate_max_sum_exp) + .set_generate_sum_exp(generate_max_sum_exp) .set_causal_mask(is_causal) .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); @@ -302,7 +310,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options); + auto sdpa_outputs = mha_graph->sdpa_internal(Q, K, V, std::move(sdpa_options)); + std::shared_ptr O, Stats, Max, Sum_Exp; + O = sdpa_outputs.O; + if (generate_max_sum_exp) { + Max = sdpa_outputs.Max; + Sum_Exp = sdpa_outputs.Sum_exp; + } else { + Stats = sdpa_outputs.Stats; + } std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, @@ -317,17 +333,22 @@ void fused_attn_arbitrary_seqlen_fwd_impl( O->set_ragged_offset(offset_o); } - Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); - if (is_ragged_q && cudnn_runtime_version >= 90600) { - offset_stats = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_stats") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); + if (generate_max_sum_exp) { + Max->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}).set_stride({h * s_q, s_q, 1, 1}); + Sum_Exp->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}).set_stride({h * s_q, s_q, 1, 1}); } else { - Stats->set_stride({h * s_q, s_q, 1, 1}); + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); + if (is_ragged_q && cudnn_runtime_version >= 90600) { + offset_stats = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); + } else { + Stats->set_stride({h * s_q, s_q, 1, 1}); + } } std::tuple, // Q @@ -336,7 +357,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // attn_scale std::shared_ptr> // O key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O); - auto Stats_tuple = std::make_tuple(Stats); + auto Stats_tuple = generate_stats ? 
std::make_tuple(Stats, nullptr) : std::make_tuple(Max, Sum_Exp); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); @@ -366,7 +387,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( return return_tuple; }; - auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v, + auto [mha_graph, Q, K, V, attn_scale, O, S1, S2, bias, seq_q, seq_kv, page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); @@ -401,7 +422,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::unordered_map, void *> variant_pack = { {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &scaling_factor}, - {O, devPtrO}, {Stats, devPtrSoftmaxStats}}; + {O, devPtrO}, {S1, devPtrS1}}; + + if (generate_max_sum_exp) { + variant_pack[S2] = devPtrS2; + } if (is_bias) { variant_pack[bias] = devPtrBias; @@ -562,7 +587,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( window_size_right, deterministic, tensorType, - tensorType}; + tensorType, + true, + true, + }; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -948,7 +976,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( using namespace transformer_engine::fused_attn; void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + bool is_training, bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, @@ -979,7 +1007,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( } void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; + void *devPtrS1 = nullptr; + void *devPtrS2 = nullptr; void *devPtrCuSeqlens = cu_seqlens->data.dptr; void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; @@ -990,29 +1019,28 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( max_tokens = get_max_tokens(num_tokens); } + size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; + if (return_max_sum_exp) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Max->data.dptr = nullptr; if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; + output_Max->data.shape = {max_tokens, num_attn_heads, 1}; } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + output_Max->data.shape = {batch, num_attn_heads, max_seqlen, 1}; } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = 
{bias_b, bias_h, max_seqlen, max_seqlen}; - output_bias->data.dtype = QKV_type; + output_Max->data.dtype = DType::kFloat32; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Sum_Exp->data.dptr = nullptr; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Sum_Exp->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } + output_Sum_Exp->data.dtype = DType::kFloat32; } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_S->data.dptr = nullptr; if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens, num_attn_heads, 1}; @@ -1020,23 +1048,37 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = nullptr; + output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; + output_bias->data.dtype = QKV_type; + } + + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + if (return_max_sum_exp) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_Max->data.dptr; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS2 = output_Sum_Exp->data.dptr; + } else { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_S->data.dptr; + } + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = devPtrBias; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -1049,9 +1091,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, + 
max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, + devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1159,7 +1201,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, @@ -1194,7 +1236,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( } void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; + void *devPtrS1 = nullptr; + void *devPtrS2 = nullptr; void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; @@ -1216,29 +1259,28 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( max_tokens_kv = get_max_tokens(num_tokens_kv); } + size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; + if (return_max_sum_exp) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Max->data.dptr = nullptr; if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + output_Max->data.shape = {max_tokens_q, num_attn_heads, 1}; } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; - output_bias->data.dtype = QKV_type; + output_Max->data.dtype = DType::kFloat32; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Sum_Exp->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Sum_Exp->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_Sum_Exp->data.dtype = DType::kFloat32; } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = 
convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_S->data.dptr = nullptr; if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; @@ -1246,23 +1288,37 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = nullptr; + output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; + output_bias->data.dtype = QKV_type; + } + + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + if (return_max_sum_exp) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_Max->data.dptr; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS2 = output_Sum_Exp->data.dptr; + } else { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_S->data.dptr; + } + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = devPtrBias; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -1276,9 +1332,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, + devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, 
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1400,7 +1456,7 @@ void fused_attn_arbitrary_seqlen_fwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, @@ -1416,7 +1472,8 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrK = input_K->data.dptr; void *devPtrV = input_V->data.dptr; void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; + void *devPtrS1 = nullptr; + void *devPtrS2 = nullptr; void *devPtrBias = nullptr; size_t bias_b = 0; size_t bias_h = 0; @@ -1446,29 +1503,28 @@ void fused_attn_arbitrary_seqlen_fwd( max_tokens_kv = get_max_tokens(num_tokens_kv); } + size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; + if (return_max_sum_exp) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Max->data.dptr = nullptr; if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + output_Max->data.shape = {max_tokens_q, num_attn_heads, 1}; } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; - output_bias->data.dtype = QKV_type; + output_Max->data.dtype = DType::kFloat32; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Sum_Exp->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Sum_Exp->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_Sum_Exp->data.dtype = DType::kFloat32; } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_S->data.dptr = nullptr; if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; @@ -1476,23 +1532,37 @@ void fused_attn_arbitrary_seqlen_fwd( output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor 
*output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = nullptr; + output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; + output_bias->data.dtype = QKV_type; + } + + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + if (return_max_sum_exp) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_Max->data.dptr; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS2 = output_Sum_Exp->data.dptr; + } else { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_S->data.dptr; + } + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = devPtrBias; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -1506,9 +1576,9 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, + devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index e1a20274f4..5b549dfccb 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -20,7 +20,7 @@ namespace transformer_engine { #if 
(CUDNN_VERSION >= 8900) void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + bool is_training, bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, @@ -39,7 +39,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, @@ -64,7 +64,7 @@ void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, - size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, + size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index d7f0983763..d23599165e 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1699,7 +1699,9 @@ void fused_attn_fp8_fwd_impl_v1( 0, true, fwd_tensor_type, - fwd_tensor_type}; + fwd_tensor_type, + false, + true}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -2004,7 +2006,9 @@ void fused_attn_fp8_bwd_impl_v1( 0, false, fwd_tensor_type, - bwd_tensor_type}; + bwd_tensor_type, + false, + true}; namespace fe = cudnn_frontend; using graph_and_tensors = diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 678b636910..646bc638f0 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -112,18 +112,21 @@ struct FADescriptor_v1 { bool deterministic; cudnn_frontend::DataType_t fwd_tensor_type; cudnn_frontend::DataType_t bwd_tensor_type; + bool generate_stats; + bool generate_max_sum_exp; bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, 
mask_type, window_size_left, - window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < + window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type, + generate_stats, generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, - rhs.bwd_tensor_type); + rhs.bwd_tensor_type, rhs.generate_stats, rhs.generate_max_sum_exp); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 44f5791490..3573c8237f 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -172,27 +172,28 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Get fused attention backend based on input parameters. * - * \param[in] is_training Whether the model is in training mode. - * \param[in] q_dtype The data type of Tensor Q. - * \param[in] kv_dtype The data type of Tensors K, V. - * \param[in] qkv_layout The layout of Tensors Q, K, V. - * \param[in] bias_type The attention bias type. - * \param[in] attn_mask_type The attention mask type. - * \param[in] dropout The dropout probability. - * \param[in] num_attn_heads The number of heads in Q. - * \param[in] num_gqa_groups The number of heads in K, V. - * \param[in] max_seqlen_q The sequence length of Q. - * \param[in] max_seqlen_kv The sequence length of K, V. - * \param[in] head_dim_qk The head dimension of Q, K. - * \param[in] head_dim_v The head dimension of V. - * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). + * \param[in] is_training Whether the model is in training mode. + * \param[in] q_dtype The data type of Tensor Q. + * \param[in] kv_dtype The data type of Tensors K, V. + * \param[in] qkv_layout The layout of Tensors Q, K, V. + * \param[in] bias_type The attention bias type. + * \param[in] attn_mask_type The attention mask type. + * \param[in] dropout The dropout probability. + * \param[in] num_attn_heads The number of heads in Q. + * \param[in] num_gqa_groups The number of heads in K, V. + * \param[in] max_seqlen_q The sequence length of Q. + * \param[in] max_seqlen_kv The sequence length of K, V. + * \param[in] head_dim_qk The head dimension of Q, K. + * \param[in] head_dim_v The head dimension of V. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). + * \param[in] return_max_sum_exp Whether to produce Max and Sum_Exp, or Stats. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_sum_exp); /*! 
\brief Compute dot product attention with packed QKV input. * @@ -234,6 +235,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] max_seqlen Max sequence length used for computing, * it may be >= max(seqlen_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. + * \param[in] return_max_sum_exp Whether to produce Max and Sum_Exp, or Stats. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensor's layout. @@ -248,7 +250,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + bool return_max_sum_exp, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); @@ -356,6 +358,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] max_seqlen_kv Max sequence length used for computing for KV. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. + * \param[in] return_max_sum_exp Whether to produce Max and Sum_Exp, or Stats. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensor's layout. @@ -372,7 +375,7 @@ void nvte_fused_attn_fwd_kvpacked( NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_sum_exp, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); @@ -489,6 +492,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] max_seqlen_kv Max sequence length used for computing for K and V. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. + * \param[in] return_max_sum_exp Whether to produce Max and Sum_Exp, or Stats. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. 
@@ -506,7 +510,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + bool return_max_sum_exp, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 40089dc2d6..914fc96d1a 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -21,7 +21,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, false); return backend; } @@ -174,7 +174,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( nvte_fused_attn_fwd_qkvpacked( qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor, + dummy_rng_state_tensor.data(), q_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { @@ -183,7 +183,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, - kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, + kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_fwd( @@ -191,7 +191,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else { @@ -265,7 +265,7 @@ static void FusedAttnForwardImpl( auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, 
window_size_left, window_size_right); + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, false); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -283,7 +283,7 @@ static void FusedAttnForwardImpl( nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, is_training, scaling_factor, dropout_probability, + q_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { @@ -297,7 +297,7 @@ static void FusedAttnForwardImpl( &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, - is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + is_training, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; @@ -311,7 +311,7 @@ static void FusedAttnForwardImpl( o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, + q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); @@ -522,7 +522,7 @@ static void FusedAttnBackwardImpl( auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, false); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index afa1bae633..5e83beee6f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -142,6 +142,7 @@ def __init__( attention_dropout: float = 0.0, attention_dropout_ctx: Optional[Callable] = nullcontext, layer_number: Optional[int] = None, + return_max_score: Optional[bool] = False, ) -> None: super().__init__() @@ -149,6 +150,7 @@ def __init__( self.attention_type = attention_type self.attention_dropout_ctx = attention_dropout_ctx self.layer_number = layer_number + self.return_max_score = 
return_max_score def mask_func(x, y): return ( @@ -337,6 +339,12 @@ def forward( if "padding" in attn_mask_type: attention_probs = attention_probs.masked_fill(attention_mask, 0) + # max attention score + max_score = None + if self.return_max_score: + max_indices = torch.argmax(attention_probs, dim=-1, keepdim=True) + max_score = torch.max(matmul_result.gather(dim=-1, index=max_indices)) + # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. with self.attention_dropout_ctx(): @@ -391,6 +399,8 @@ def forward( # [tq, np, hn] --> [tq, hp] context_layer = context_layer.view(total_tokens, -1) + if self.return_max_score: + return context_layer, max_score return context_layer @@ -442,6 +452,7 @@ def __init__( attention_type: str = "self", layer_number: Optional[int] = None, deterministic: bool = False, + return_max_score: Optional[bool] = False, ) -> None: super().__init__() @@ -465,6 +476,7 @@ def __init__( self.logger.setLevel(attn_log._log_level) if not self.logger.hasHandlers(): self.logger.addHandler(attn_log._stream_handler) + self.return_max_score = return_max_score def forward( self, @@ -660,6 +672,7 @@ def forward( batch_size * context_len, ) + max_score = None use_flash_attn_3 = False if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"): use_flash_attn_3 = True @@ -762,7 +775,11 @@ def forward( softmax_scale=self.softmax_scale, causal="causal" in attn_mask_type, **fa_optional_forward_kwargs, + return_attn_probs=self.return_max_score, ) + #if self.return_max_score and (self.attention_dropout == 0.0 or not self.training): + # output, _, S_dmask = output + # max_score = torch.max(S_dmask) else: fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs["window_size"] = window_size @@ -889,6 +906,8 @@ def convert_to_torch_float8(tensor, dtype): # thd -> t(hd) output = output.reshape(output.shape[0], -1) + if self.return_max_score: + return output.contiguous(), max_score return output.contiguous() @@ -925,6 +944,7 @@ def forward( fp8_meta, quantizers, deterministic, + return_max_score, ): # pylint: disable=missing-function-docstring # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype @@ -939,6 +959,7 @@ def forward( QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) ) + max_score = None if fp8: fused_attention_backend = FusedAttnBackend["FP8"] assert isinstance(k, q.__class__) and isinstance( @@ -1035,7 +1056,7 @@ def forward( fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) else: # q, k, v, out_ret: torch.float16 or torch.bfloat16 - out_ret, aux_ctx_tensors = fused_attn_fwd( + out_ret, aux_ctx_tensors, max_score = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -1061,6 +1082,7 @@ def forward( attn_mask_type, window_size, rng_gen, + return_max_score, ) out_save = out_ret fp8_tensors = (None, None, None, None) @@ -1121,11 +1143,12 @@ def forward( ctx.use_FAv2_bwd = use_FAv2_bwd ctx.deterministic = deterministic - return out_ret + return out_ret, max_score @staticmethod - def backward(ctx, d_out): + def backward(ctx, *args): # pylint: disable=missing-function-docstring + d_out = args[0] if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor @@ -1321,6 +1344,7 @@ def backward(ctx, d_out): None, None, None, + None, ) # else, return (dqkv, dbias) return ( @@ -1390,6 +1414,7 @@ def __init__( attention_type: 
str = "self", layer_number: Optional[int] = None, deterministic: bool = False, + return_max_score: Optional[bool] = False, ) -> None: super().__init__() @@ -1402,6 +1427,7 @@ def __init__( ) == "1" and get_device_compute_capability() == (9, 0) self.layer_number = 1 if layer_number is None else layer_number self.deterministic = deterministic + self.return_max_score = return_max_score def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ @@ -1634,7 +1660,11 @@ def forward( fp8_meta, quantizers, self.deterministic, + self.return_max_score, ) + if self.return_max_score and not context_parallel: + # ...hd -> ...(hd) + return output[0].view(*output[0].shape[:-2], -1), output[1] # ...hd -> ...(hd) - return output.view(*output.shape[:-2], -1) + return output[0].view(*output[0].shape[:-2], -1) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index b35b87a83f..a8cbbbfd9b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -168,6 +168,11 @@ class DotProductAttention(TransformerEngineBaseModule): softmax_scale: Optional[float], default = `None` softmax scale for the attention scores. If `None`, defaults to `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`. + return_max_score: Optional[bool], default = `False` + If true, returns the maximum attention score, max_score = max(S), where + S = Q*K^T and in shape [b, h, s_q, s_kv]. max_score can be used to rescale + the Q and K projection weights in a MuonClip optimizer (see + `Muon is Scalable for LLM Training `_). Parallelism parameters ---------------------- @@ -223,6 +228,7 @@ def __init__( cp_stream: torch.cuda.Stream = None, cp_comm_type: str = "p2p", softmax_scale: Optional[float] = None, + return_max_score: Optional[bool] = False, ) -> None: super().__init__() @@ -306,10 +312,12 @@ def __init__( self.attention_type = attention_type self.attention_dropout = attention_dropout + self.return_max_score = return_max_score attn_kwargs = { "attention_dropout": attention_dropout, "attention_dropout_ctx": attention_dropout_ctx, + "return_max_score": return_max_score, } self.flash_attention = FlashAttention( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7097f4ba0f..526e3f5982 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -216,6 +216,8 @@ class AttentionParams: The FP8 metadata tensor of `DotProductAttention`. inference_params: Optional[InferenceParams], default = `None` Inference-related parameters. See InferenceParams for details. + return_max_score: bool, default = `False` + Whether to output max_score. 
""" qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor @@ -242,6 +244,7 @@ class AttentionParams: fp8: bool = False fp8_meta: Union[Dict[str, Any], None] = None inference_params: Optional[InferenceParams] = None + return_max_score: bool = False def __eq__(self, other): """ @@ -313,6 +316,7 @@ def get_attention_backend( fp8 = attention_params.fp8 fp8_meta = attention_params.fp8_meta inference_params = attention_params.inference_params + return_max_score = attention_params.return_max_score # Run config logger = logging.getLogger("DotProductAttention") @@ -425,6 +429,23 @@ def get_attention_backend( logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False + # Filter: Return max_score + if return_max_score: + if context_parallel: + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False + logger.debug("Disabling all backends for max_score with context parallelism") + if use_flash_attention: + use_flash_attention = False + logger.debug("Disabling FlashAttention for max_score") + #if use_flash_attention and not (attention_dropout == 0.0 or not is_training): + # use_flash_attention = False + # logger.debug("Disabling FlashAttention for max_score with dropout") + #if use_flash_attention and fp8 and fp8_meta["recipe"].fp8_dpa: + # use_flash_attention = False + # logger.debug("Disabling FlashAttention for max_score with FP8 attention") + # Filter: KV cache # backend | precision | KV cache | architecture | qkv_format | page_size # --------------------------------------------------------------------------------------- @@ -783,6 +804,7 @@ def get_attention_backend( head_dim_v, window_size[0], window_size[1], + return_max_score, ) if fused_attention_backend == FusedAttnBackend["No_Backend"]: logger.debug("Disabling FusedAttention as no backend supports the provided input") diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index b9810bf861..f141d0b8c9 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -133,6 +133,7 @@ def fused_attn_fwd( attn_mask_type: str = "padding", window_size: Tuple[int, int] = (-1, -1), rng_gen: torch.Generator = None, + return_max_score: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. 
@@ -205,6 +206,8 @@ def fused_attn_fwd( rng_gen: torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen + return_max_score: bool, default = False + whether to return the maximum attention score Returns ---------- @@ -235,6 +238,7 @@ def fused_attn_fwd( rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen state of the random number generator; [seed, offset], dtype uint64 + max_score: float if return_max_score = True, otherwise None """ if attn_scale is None: @@ -302,10 +306,18 @@ def fused_attn_fwd( attn_bias, rng_gen, rng_elts_per_thread, + return_max_score, ) + if return_max_score: + # output_tensors: out, Max, Sum_Exp + stats = output_tensors[1] + torch.log(output_tensors[2]) + max_score = torch.max(output_tensors[1]) + # still return stats for bwd + return output_tensors[0], stats, max_score + # out, aux_ctx_tensors - return output_tensors[0], output_tensors[1:] + return output_tensors[0], output_tensors[1:], None def fused_attn_bwd( diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4cb05725bc..47e6fe8fd4 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -75,7 +75,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_sum_exp); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, @@ -87,7 +87,7 @@ std::vector fused_attn_fwd( const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional rng_gen, size_t rng_elts_per_thread); + const std::optional rng_gen, size_t rng_elts_per_thread, bool return_max_sum_exp); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 6d835a5c94..a7e71aca5e 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -60,11 +60,11 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_sum_exp) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, - max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); + 
max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, return_max_sum_exp); return fused_attention_backend; } @@ -79,7 +79,7 @@ std::vector fused_attn_fwd( const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional rng_gen, size_t rng_elts_per_thread) { + const std::optional rng_gen, size_t rng_elts_per_thread, bool return_max_sum_exp) { TensorWrapper te_Q, te_K, te_V, te_O, te_S; auto none = py::none(); @@ -203,7 +203,7 @@ std::vector fused_attn_fwd( &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], + return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -215,41 +215,40 @@ std::vector fused_attn_fwd( // output_tensors = [O, nvte_aux_tensor_pack.tensors] std::vector output_tensors; output_tensors.push_back(o_python); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - // allocate memory for nvte_aux_tensor_pack.tensors - at::Tensor output_tensor; - if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = - (i < nvte_aux_tensor_pack.size - 1) - ? 
allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false) - : rng_state; - } - } else { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); - } + auto set_tensor_param = [&](size_t i, const at::Tensor &output_tensor) { output_tensors.push_back(py::cast(output_tensor)); NVTEBasicTensor temp_data = {output_tensor.data_ptr(), nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]), nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])}; nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); + }; + // allocate memory for nvte_aux_tensor_pack.tensors + // f16_max512 : S [b, h, sq, skv] + // f16_arbitrary: + // return_max_sum_exp=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv] + // return_max_sum_exp=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv] + // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2] + size_t i = 0; + at::Tensor output_tensor; + // intermediate softmax tensor, S or M + output_tensor = + allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); + set_tensor_param(i++, output_tensor); + // fp8 has an additional softmax stats tensor, ZInv + if (return_max_sum_exp || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + output_tensor = + allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); + set_tensor_param(i++, output_tensor); + } + // rng_state + if (i < nvte_aux_tensor_pack.size) { + set_tensor_param(i++, rng_state); + } + // bias (optional) + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { + set_tensor_param(i++, Bias.value()); } // execute the kernel @@ -259,7 +258,7 @@ std::vector fused_attn_fwd( &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], + return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); }); From 8f9155fa4956d2e04559b6bd499a6408aac62a0a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 30 Sep 2025 11:22:15 -0700 Subject: [PATCH 02/15] calculate max per head instead of max over all heads Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../pytorch/attention/dot_product_attention/backends.py | 6 ++++-- transformer_engine/pytorch/cpp_extensions/fused_attn.py | 8 +++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 5e83beee6f..2aa8985a28 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -342,8 +342,10 @@ def forward( # max attention score 
max_score = None if self.return_max_score: - max_indices = torch.argmax(attention_probs, dim=-1, keepdim=True) - max_score = torch.max(matmul_result.gather(dim=-1, index=max_indices)) + # matmul_result [b, np, sq, dk], max_score [np] + max_score = matmul_result.view(*matmul_result.shape[:2],-1) + max_score = torch.max(max_score, dim=-1)[0] + max_score = torch.max(max_score, dim=0)[0] # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index f141d0b8c9..a0c261745d 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -310,10 +310,12 @@ def fused_attn_fwd( ) if return_max_score: - # output_tensors: out, Max, Sum_Exp + # output_tensors: out [b, sq, h, d] or [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] stats = output_tensors[1] + torch.log(output_tensors[2]) - max_score = torch.max(output_tensors[1]) - # still return stats for bwd + max_score = output_tensors[1].unsqueeze(-1) + max_score = torch.max(max_score, dim=-1)[0] + max_score = torch.max(max_score, dim=0)[0] + # return out [b, sq, h, d] or [sq, b, h, d], stats [b, h, sq, 1], max_score [h] return output_tensors[0], stats, max_score # out, aux_ctx_tensors From efaf827f91e1700d249a8c0daddcc9fcb2d85f02 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 7 Oct 2025 22:51:53 -0700 Subject: [PATCH 03/15] fix fused attn max_score shape Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 16 +++++++++++----- .../pytorch/cpp_extensions/fused_attn.py | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index c0c53d5bc9..c8a097c8a7 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -166,7 +166,7 @@ def test_dot_product_attention( # UnfusedDotProductAttention backend if unfused_attn_supported: - unfused_attn_fwd, max_score, unfused_attn_bwd = _run_dot_product_attention( + unfused_attn_fwd, unfused_max_score, unfused_attn_bwd = _run_dot_product_attention( dtype, config, "UnfusedDotProductAttention", @@ -180,7 +180,7 @@ def test_dot_product_attention( # FusedAttention backend if fused_attn_supported: if len(fused_attn_backends) == 1: - fused_attn_fwd, max_score, fused_attn_bwd = _run_dot_product_attention( + fused_attn_fwd, fused_max_score, fused_attn_bwd = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -192,7 +192,7 @@ def test_dot_product_attention( ) if len(fused_attn_backends) == 2: os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" - fused_attn_fwd, max_score, fused_attn_bwd = _run_dot_product_attention( + fused_attn_fwd, fused_max_score, fused_attn_bwd = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -203,7 +203,7 @@ def test_dot_product_attention( is_training, ) os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" - fused_attn_fwd_1, max_score_1, fused_attn_bwd_1 = _run_dot_product_attention( + fused_attn_fwd_1, fused_max_score_1, fused_attn_bwd_1 = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -216,7 +216,7 @@ def test_dot_product_attention( # FlashAttention backend if flash_attn_supported: - flash_attn_fwd, max_score, flash_attn_bwd = 
_run_dot_product_attention( + flash_attn_fwd, flash_max_score, flash_attn_bwd = _run_dot_product_attention( dtype, config, "FlashAttention", @@ -231,16 +231,22 @@ def test_dot_product_attention( if unfused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs flash attn") torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols) + if config.return_max_score: + torch.testing.assert_close(flash_max_score, unfused_max_score, **tols) for i, _ in enumerate(flash_attn_bwd): torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols) if unfused_attn_supported and fused_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs fused attn") torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) + if config.return_max_score: + torch.testing.assert_close(fused_max_score, unfused_max_score, **tols) for i, _ in enumerate(unfused_attn_bwd): torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) if fused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: fused attn vs flash attn") torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) + if config.return_max_score: + torch.testing.assert_close(fused_max_score, flash_max_score, **tols) for i, _ in enumerate(flash_attn_bwd): torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols) if fused_attn_supported and len(fused_attn_backends) == 2: diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index a0c261745d..091968d799 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -312,7 +312,7 @@ def fused_attn_fwd( if return_max_score: # output_tensors: out [b, sq, h, d] or [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] stats = output_tensors[1] + torch.log(output_tensors[2]) - max_score = output_tensors[1].unsqueeze(-1) + max_score = output_tensors[1].squeeze(-1).to(dtype=output_tensors[0].dtype) max_score = torch.max(max_score, dim=-1)[0] max_score = torch.max(max_score, dim=0)[0] # return out [b, sq, h, d] or [sq, b, h, d], stats [b, h, sq, 1], max_score [h] From 290dfb93eff99eb1cb95b9a9d1f7ffcb65f1ab6d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 10 Oct 2025 04:29:36 -0700 Subject: [PATCH 04/15] revert FE to github Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .gitmodules | 3 +-- 3rdparty/cudnn-frontend | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitmodules b/.gitmodules index f41b9f5c50..21492db5ef 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,5 +3,4 @@ url = https://github.com/google/googletest.git [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git - branch = feature/muon + url = https://github.com/NVIDIA/cudnn-frontend.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index d210115c3b..deda80e537 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit d210115c3ba176d45247801f0f2f8464bca935e7 +Subproject commit deda80e5372d50e925d7bf4f76c5db779be3fbd5 From c93518bc5585859a4a51c69105d7b741822c03bf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Oct 2025 12:18:48 +0000 Subject: [PATCH 05/15] 
[pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 8 +- tests/pytorch/utils.py | 2 +- .../common/fused_attn/fused_attn.cpp | 78 +++--- .../fused_attn_f16_arbitrary_seqlen.cu | 223 +++++++++--------- .../fused_attn_f16_arbitrary_seqlen.h | 44 ++-- transformer_engine/common/fused_attn/utils.h | 3 +- .../include/transformer_engine/fused_attn.h | 47 ++-- .../jax/csrc/extensions/attention.cpp | 28 ++- .../dot_product_attention/backends.py | 6 +- .../attention/dot_product_attention/utils.py | 4 +- .../pytorch/cpp_extensions/fused_attn.py | 4 +- .../pytorch/csrc/extensions/attention.cpp | 13 +- 12 files changed, 241 insertions(+), 219 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 8df2f37299..04513350f1 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -281,6 +281,8 @@ def test_dpa_checkpoint(dtype, model_configs, model): "max_score_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048), "max_score_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048), } + + @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model_configs", [model_configs_max_score]) @@ -288,10 +290,10 @@ def test_dpa_checkpoint(dtype, model_configs, model): def test_dpa_max_score(dtype, model_configs, model): """Test DotProductAttention module with checkpointing""" config = model_configs[model] - config.return_max_score=True + config.return_max_score = True test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) - + model_configs_softmax = { # test: ModelConfig(b, sq, hq, dqk) "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8), @@ -1136,7 +1138,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: fast_zero_fill=True, ) if is_training: - out.backward((d_out, torch.zeros(1,device="cuda"))) + out.backward((d_out, torch.zeros(1, device="cuda"))) if config.return_max_score: out = (out, max_score) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index dd440f0c3a..86ffebd2b7 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -205,7 +205,7 @@ def __init__( window_size: Tuple[int, int] = (-1, -1), context_parallel: bool = False, cp_comm_type: str = "p2p", - return_max_score = False, + return_max_score=False, total_requests: int = None, max_ctx_len: int = None, num_layers: int = 1, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 9e46849165..6c2da5248c 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -187,8 +187,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 - (cudnn_runtime_version != 91000) && - !return_max_sum_exp) { + (cudnn_runtime_version != 91000) && !return_max_sum_exp) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else { @@ -217,8 +216,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && ((window_size_left == -1) && (window_size_right == -1 
|| window_size_right == 0)) && !requires_64bit_ragged_offset && - (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && - !return_max_sum_exp) { + (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_sum_exp) { flag_m512 = true; } if ( @@ -420,8 +418,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_sum_exp, float attn_scale, - float dropout, NVTE_QKV_Layout qkv_layout, + size_t max_seqlen, bool is_training, bool return_max_sum_exp, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, @@ -476,10 +474,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, t, is_training, return_max_sum_exp, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, - input_rng_state, wkspace, stream, handle); + b, h, max_seqlen, d, t, is_training, return_max_sum_exp, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); @@ -604,10 +602,10 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, bool return_max_sum_exp, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { + size_t max_seqlen_kv, bool is_training, bool return_max_sum_exp, float attn_scale, + float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -682,7 +680,8 @@ void nvte_fused_attn_fwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, return_max_sum_exp); + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, + return_max_sum_exp); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -697,12 +696,12 @@ void nvte_fused_attn_fwd_kvpacked( #if (CUDNN_VERSION >= 8903) fused_attn_arbitrary_seqlen_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_sum_exp, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, - window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O, - Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, - wkspace, stream, handle); + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + return_max_sum_exp, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, input_Q, input_KV, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); @@ -834,18 +833,16 @@ void nvte_fused_attn_bwd_kvpacked( } } // NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_sum_exp, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_sum_exp, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -915,7 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, return_max_sum_exp); + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, + return_max_sum_exp); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -930,12 +928,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_sum_exp, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, - window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O, - Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, - wkspace, stream, handle); + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + return_max_sum_exp, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + 
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 9e1728762f..2bd62b7767 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -52,15 +52,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, - int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, bool generate_max_sum_exp, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, - void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, - void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, - void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, + bool generate_max_sum_exp, float scaling_factor, float dropout_probability, + NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, + void *devPtrS1, void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, + void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, + void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, + void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, + size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -104,37 +105,38 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? 
DType::kInt64 : DType::kInt32; bool generate_stats = !generate_max_sum_exp; try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, - num_pages_k, - num_pages_v, - page_size_k, - page_size_v, - max_pages_per_seq_k, - max_pages_per_seq_v, - bias_b, - bias_h, - scaling_factor, - is_training, - dropout_probability, - layout, - bias_type, - mask_type, - softmax_type, - window_size_left, - window_size_right, - true, - tensorType, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET, - generate_stats, - generate_max_sum_exp, + FADescriptor_v1 descriptor{ + b, + h, + hg, + s_q, + s_kv, + d_qk, + d_v, + num_pages_k, + num_pages_v, + page_size_k, + page_size_v, + max_pages_per_seq_k, + max_pages_per_seq_v, + bias_b, + bias_h, + scaling_factor, + is_training, + dropout_probability, + layout, + bias_type, + mask_type, + softmax_type, + window_size_left, + window_size_right, + true, + tensorType, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + generate_stats, + generate_max_sum_exp, }; namespace fe = cudnn_frontend; @@ -329,7 +331,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr O, Stats, Max, Sum_Exp; O = sdpa_outputs.O; if (generate_max_sum_exp) { - Max = sdpa_outputs.Max; + Max = sdpa_outputs.Max; Sum_Exp = sdpa_outputs.Sum_exp; } else { Stats = sdpa_outputs.Stats; @@ -349,8 +351,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } if (generate_max_sum_exp) { - Max->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}).set_stride({h * s_q, s_q, 1, 1}); - Sum_Exp->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}).set_stride({h * s_q, s_q, 1, 1}); + Max->set_output(true) + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}); + Sum_Exp->set_output(true) + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}); } else { Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); if (is_ragged_q && cudnn_runtime_version >= 90600) { @@ -372,7 +380,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // attn_scale std::shared_ptr> // O key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O); - auto Stats_tuple = generate_stats ? std::make_tuple(Stats, nullptr) : std::make_tuple(Max, Sum_Exp); + auto Stats_tuple = + generate_stats ? std::make_tuple(Stats, nullptr) : std::make_tuple(Max, Sum_Exp); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto softmax_offset_tuple = is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr); @@ -438,8 +447,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( // Build variant pack std::unordered_map, void *> variant_pack = { - {Q, devPtrQ}, {K, devPtrK}, - {V, devPtrV}, {attn_scale, &scaling_factor}, + {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &scaling_factor}, {O, devPtrO}, {S1, devPtrS1}}; if (generate_max_sum_exp) { @@ -586,37 +594,38 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? 
DType::kInt64 : DType::kInt32; try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, - 0, - 0, - 0, - 0, - 0, - 0, - bias_b, - bias_h, - scaling_factor, - true, - dropout_probability, - layout, - bias_type, - mask_type, - softmax_type, - window_size_left, - window_size_right, - deterministic, - tensorType, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET, - true, - true, + FADescriptor_v1 descriptor{ + b, + h, + hg, + s_q, + s_kv, + d_qk, + d_v, + 0, + 0, + 0, + 0, + 0, + 0, + bias_b, + bias_h, + scaling_factor, + true, + dropout_probability, + layout, + bias_type, + mask_type, + softmax_type, + window_size_left, + window_size_right, + deterministic, + tensorType, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + true, + true, }; namespace fe = cudnn_frontend; @@ -1029,12 +1038,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl( using namespace transformer_engine::fused_attn; void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + bool is_training, bool return_max_sum_exp, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, + const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_QKV->data.dtype; @@ -1162,9 +1172,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, - devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, - nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), + window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, + devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, + nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { @@ -1277,12 +1287,13 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, size_t 
num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_sum_exp, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, + bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_KV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1425,9 +1436,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, - devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, + devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { @@ -1555,14 +1566,14 @@ void fused_attn_arbitrary_seqlen_fwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, 
NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1693,9 +1704,9 @@ void fused_attn_arbitrary_seqlen_fwd( max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, - devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, + devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 5cd285740b..8cef0cd3f7 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -20,12 +20,13 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + bool is_training, bool return_max_sum_exp, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, + const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, @@ -41,12 +42,13 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t 
head_dim, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_sum_exp, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, + bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_KV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -67,15 +69,15 @@ void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, - size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_sum_exp, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, + bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 
9f650c157c..b3cd815ddd 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -123,7 +123,8 @@ struct FADescriptor_v1 { page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, - o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_stats, generate_max_sum_exp) < + o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_stats, + generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 4fb1fe4d7f..8a8d870edd 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -268,13 +268,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd_qkvpacked( - const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, - bool is_training, bool return_max_sum_exp, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, bool return_max_sum_exp, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. * @@ -402,10 +405,10 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, bool return_max_sum_exp, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); + size_t max_seqlen_kv, bool is_training, bool return_max_sum_exp, float attn_scale, + float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! 
\brief Compute the backward of the dot product attention with packed KV input. * @@ -535,18 +538,16 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_sum_exp, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_sum_exp, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. 
* diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index fc1857d7cd..ffc0706fe7 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -22,7 +22,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, false); + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + false); return backend; } @@ -179,17 +180,18 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, - false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); + false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), + nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, query_workspace_tensor.data(), nullptr); + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_fwd( q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), @@ -197,8 +199,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, - kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, + kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported QKVLayout."); @@ -276,7 +278,8 @@ static void FusedAttnForwardImpl( auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, 
window_size_left, window_size_right, false); + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + false); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -308,8 +311,8 @@ static void FusedAttnForwardImpl( s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, softmax_type, window_size_left, window_size_right, + q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; @@ -542,7 +545,8 @@ static void FusedAttnBackwardImpl( auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, false); + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + false); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 1fccd5d130..b3e28a55a5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -432,7 +432,7 @@ def forward( max_score = None if self.return_max_score: # matmul_result [b, np, sq, dk], max_score [np] - max_score = matmul_result.view(*matmul_result.shape[:2],-1) + max_score = matmul_result.view(*matmul_result.shape[:2], -1) max_score = torch.max(max_score, dim=-1)[0] max_score = torch.max(max_score, dim=0)[0] @@ -538,7 +538,7 @@ def forward( # quantize O if fp8_output: context_layer = O_quantizer(context_layer) - + if self.return_max_score: return context_layer, max_score @@ -920,7 +920,7 @@ def forward( **fa_optional_forward_kwargs, return_attn_probs=self.return_max_score, ) - #if self.return_max_score and (self.attention_dropout == 0.0 or not self.training): + # if self.return_max_score and (self.attention_dropout == 0.0 or not self.training): # output, _, S_dmask = output # max_score = torch.max(S_dmask) else: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 58245f68dc..bdd412f7ae 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -491,10 +491,10 @@ def get_attention_backend( if use_flash_attention: use_flash_attention = False logger.debug("Disabling FlashAttention for max_score") - #if use_flash_attention and not (attention_dropout == 0.0 or not 
is_training): + # if use_flash_attention and not (attention_dropout == 0.0 or not is_training): # use_flash_attention = False # logger.debug("Disabling FlashAttention for max_score with dropout") - #if use_flash_attention and fp8 and fp8_meta["recipe"].fp8_dpa: + # if use_flash_attention and fp8 and fp8_meta["recipe"].fp8_dpa: # use_flash_attention = False # logger.debug("Disabling FlashAttention for max_score with FP8 attention") diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 9d8843d72a..9a4693b5f2 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -326,8 +326,8 @@ def fused_attn_fwd( # output_tensors: out [b, sq, h, d] or [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] stats = output_tensors[1] + torch.log(output_tensors[2]) max_score = output_tensors[1].squeeze(-1).to(dtype=output_tensors[0].dtype) - max_score = torch.max(max_score, dim=-1)[0] - max_score = torch.max(max_score, dim=0)[0] + max_score = torch.max(max_score, dim=-1)[0] + max_score = torch.max(max_score, dim=0)[0] # return out [b, sq, h, d] or [sq, b, h, d], stats [b, h, sq, 1], max_score [h] return output_tensors[0], stats, max_score diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 556de15063..fb8f79a68e 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -49,7 +49,8 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, - max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, return_max_sum_exp); + max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, + return_max_sum_exp); return fused_attention_backend; } @@ -228,8 +229,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); + return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], window_size[1], workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace and auxiliary output tensors @@ -287,8 +289,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); + return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], window_size[1], workspace.data(), + 
at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers, but not allocated memory From 63a7f79f8bbcea581370c732cd8e9265b84195fd Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 10 Oct 2025 05:21:50 -0700 Subject: [PATCH 06/15] update FE to 1.15.0-rc Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 80a8e4af4d..bb37575d10 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 80a8e4af4d89d33a2c59d51fcf9fda1c9d368cd4 +Subproject commit bb37575d103b9974bc619a193dc1a96d835dc117 From 78d542642a83efa14d70b4e211adab83880d7d01 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 10 Oct 2025 05:51:35 -0700 Subject: [PATCH 07/15] fix merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../fused_attn_f16_arbitrary_seqlen.cu | 63 +++++++++---------- .../dot_product_attention/backends.py | 3 +- 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 2bd62b7767..4488ae4a62 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -252,8 +252,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_name("flash_attention") .set_is_inference(false) .set_generate_stats(generate_stats) - .set_generate_max(generate_max_sum_exp) - .set_generate_sum_exp(generate_max_sum_exp) .set_causal_mask(is_causal) .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); @@ -327,16 +325,24 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options.set_sink_token(softmax_offset); } - auto sdpa_outputs = mha_graph->sdpa_internal(Q, K, V, std::move(sdpa_options)); - std::shared_ptr O, Stats, Max, Sum_Exp; - O = sdpa_outputs.O; + std::shared_ptr Max, Sum_Exp; if (generate_max_sum_exp) { - Max = sdpa_outputs.Max; - Sum_Exp = sdpa_outputs.Sum_exp; - } else { - Stats = sdpa_outputs.Stats; + Max = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Max") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_options.set_logit_max(Max); + Sum_Exp = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Sum_Exp") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_options.set_score_sum_exp(Sum_Exp); } + auto [O, Stats] = mha_graph->sdpa(Q, K, V, std::move(sdpa_options)); + std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); @@ -350,28 +356,19 @@ void fused_attn_arbitrary_seqlen_fwd_impl( O->set_ragged_offset(offset_o); } - if (generate_max_sum_exp) { - Max->set_output(true) - .set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}); - Sum_Exp->set_output(true) - .set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}); + if (!generate_max_sum_exp) { + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); + if (is_ragged_q && cudnn_runtime_version >= 90600) { + offset_stats = + 
mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); } else { - Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); - if (is_ragged_q && cudnn_runtime_version >= 90600) { - offset_stats = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_stats") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); - } else { - Stats->set_stride({h * s_q, s_q, 1, 1}); - } + Stats->set_stride({h * s_q, s_q, 1, 1}); + } } std::tuple, // Q @@ -1171,7 +1168,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), @@ -1435,7 +1432,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, @@ -1703,7 +1700,7 @@ void fused_attn_arbitrary_seqlen_fwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index b3e28a55a5..0b5bdbed5b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1332,7 +1332,7 @@ def forward( return out_ret, max_score 
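For context on the two auxiliary outputs wired up above: when return_max_sum_exp is set, the forward graph emits Max and Sum_Exp tensors of shape [b, h, sq, 1] instead of the usual Stats, and the host side recovers the softmax stats as Max + log(Sum_Exp) while reducing Max to a per-head max_score. A minimal PyTorch sketch of that reduction (illustrative shapes and names, not part of the patch; it mirrors the cpp_extensions/fused_attn.py change in this series):

    import torch

    # Illustrative shapes; Max and Sum_Exp come back from the fused kernel
    # as [b, h, sq, 1] when return_max_sum_exp=True.
    b, h, sq = 2, 16, 128
    Max = torch.randn(b, h, sq, 1)
    Sum_Exp = torch.rand(b, h, sq, 1) + 1e-6

    # Softmax stats (log-sum-exp per query position), recovered from the
    # two auxiliary tensors: stats = Max + log(Sum_Exp).
    stats = Max + torch.log(Sum_Exp)            # [b, h, sq, 1]

    # Per-head max_score: reduce over batch, sequence, and the trailing dim.
    max_score = torch.amax(Max, dim=(0, 2, 3))  # [h]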
@staticmethod - def backward(ctx, *args): + def backward(ctx, d_out, *args): # pylint: disable=missing-function-docstring # d_out is expected to be in FP8 if is_output_fp8=True, @@ -1576,6 +1576,7 @@ def backward(ctx, *args): d_softmax_offset, None, None, + None, ) From 437219b8dd5c063f77d01fb89f6fcd90433c3ce0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Oct 2025 12:52:27 +0000 Subject: [PATCH 08/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../fused_attn_f16_arbitrary_seqlen.cu | 76 ++++++++++--------- 1 file changed, 39 insertions(+), 37 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 4488ae4a62..e3c7e467ab 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -327,17 +327,17 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr Max, Sum_Exp; if (generate_max_sum_exp) { - Max = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Max") - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + Max = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Max") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); sdpa_options.set_logit_max(Max); - Sum_Exp = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Sum_Exp") - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + Sum_Exp = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Sum_Exp") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); sdpa_options.set_score_sum_exp(Sum_Exp); } @@ -357,18 +357,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } if (!generate_max_sum_exp) { - Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); - if (is_ragged_q && cudnn_runtime_version >= 90600) { - offset_stats = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_stats") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); - } else { - Stats->set_stride({h * s_q, s_q, 1, 1}); - } + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); + if (is_ragged_q && cudnn_runtime_version >= 90600) { + offset_stats = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); + } else { + Stats->set_stride({h * s_q, s_q, 1, 1}); + } } std::tuple, // Q @@ -1168,11 +1168,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, - return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, 
devPtrS1, - devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, - nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, + devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1432,11 +1432,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, - devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, - devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, + devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1700,11 +1701,12 @@ void fused_attn_arbitrary_seqlen_fwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, - devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, - devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, + devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { From 2adb1f2b0f761e04ce7e2ae7be22502d678231be Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 12 Oct 
2025 02:14:35 -0700 Subject: [PATCH 09/15] reduce ew kernels; fix causal masks; add more tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 19 ++++++++----------- .../dot_product_attention/backends.py | 10 +++++++--- .../pytorch/cpp_extensions/fused_attn.py | 10 +++++----- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 04513350f1..df1d0f7f51 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -267,15 +267,15 @@ def test_dpa_checkpoint(dtype, model_configs, model): model_configs_max_score = { - # test: b, h, hg, d + # test: ModelConfig(b, sq, hq, dqk) "max_score_1_0": ModelConfig(8, 128, 16, 64), "max_score_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256), - "max_score_2_0": ModelConfig(2, 2048, 24, 128), + "max_score_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"), "max_score_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096), - "max_score_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048), + "max_score_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"), "max_score_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048), "max_score_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048), - "max_score_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048), + "max_score_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"), "max_score_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048), "max_score_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048), "max_score_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048), @@ -1117,7 +1117,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: k = inp[1] v = inp[2] d_out = out_grad - out, max_score = block( + out = block( q, k, v, @@ -1137,13 +1137,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: alibi_slopes=alibi_slopes, fast_zero_fill=True, ) - if is_training: - out.backward((d_out, torch.zeros(1, device="cuda"))) - - if config.return_max_score: - out = (out, max_score) - else: + if not config.return_max_score: out = (out, None) + if is_training: + out[0].backward(d_out) d_softmax_offset = None if is_training and config.softmax_type != "vanilla": diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 0b5bdbed5b..901e7a26b1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -218,6 +218,7 @@ def mask_func(x, y): if is_in_onnx_export_mode() else attention_mask_func(x, y) ) + self.mask_func = mask_func self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func) @@ -432,9 +433,12 @@ def forward( max_score = None if self.return_max_score: # matmul_result [b, np, sq, dk], max_score [np] - max_score = matmul_result.view(*matmul_result.shape[:2], -1) - max_score = torch.max(max_score, dim=-1)[0] - max_score = torch.max(max_score, dim=0)[0] + max_score = matmul_result + if attn_mask_type != "no_mask": + max_score = self.mask_func(matmul_result, attention_mask) + with self.attention_dropout_ctx(): + max_score = self.attention_dropout(max_score) + max_score = torch.amax(max_score, dim=(0,2,3)) # add attention sink to the last column: [b, np, sq, sk+1] if self.softmax_type != "vanilla": diff --git 
a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 9a4693b5f2..d09aa25876 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -325,11 +325,11 @@ def fused_attn_fwd( if return_max_score: # output_tensors: out [b, sq, h, d] or [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] stats = output_tensors[1] + torch.log(output_tensors[2]) - max_score = output_tensors[1].squeeze(-1).to(dtype=output_tensors[0].dtype) - max_score = torch.max(max_score, dim=-1)[0] - max_score = torch.max(max_score, dim=0)[0] - # return out [b, sq, h, d] or [sq, b, h, d], stats [b, h, sq, 1], max_score [h] - return output_tensors[0], stats, max_score + # Max [b, h, sq, 1] -> max_score [h] + max_score = torch.amax(output_tensors[1], dim=(0,2,3)).to(dtype=output_tensors[0].dtype) + aux_ctx_tensors = [stats] + aux_ctx_tensors.extend(output_tensors[3:]) + return output_tensors[0], aux_ctx_tensors, max_score # out, aux_ctx_tensors return output_tensors[0], output_tensors[1:], None From bc7d6b09184829992e213a3fef85d6331665ab3e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 12 Oct 2025 09:16:00 +0000 Subject: [PATCH 10/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 8 ++++++-- .../pytorch/attention/dot_product_attention/backends.py | 3 ++- transformer_engine/pytorch/cpp_extensions/fused_attn.py | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index df1d0f7f51..858e5a2c27 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -272,10 +272,14 @@ def test_dpa_checkpoint(dtype, model_configs, model): "max_score_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256), "max_score_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"), "max_score_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096), - "max_score_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"), + "max_score_3_0": ModelConfig( + 8, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal" + ), "max_score_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048), "max_score_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048), - "max_score_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"), + "max_score_4_1": ModelConfig( + 8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias" + ), "max_score_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048), "max_score_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048), "max_score_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048), diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 901e7a26b1..b729834602 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -218,6 +218,7 @@ def mask_func(x, y): if is_in_onnx_export_mode() else attention_mask_func(x, y) ) + self.mask_func = mask_func self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func) @@ -438,7 +439,7 @@ def forward( max_score = self.mask_func(matmul_result, 
attention_mask) with self.attention_dropout_ctx(): max_score = self.attention_dropout(max_score) - max_score = torch.amax(max_score, dim=(0,2,3)) + max_score = torch.amax(max_score, dim=(0, 2, 3)) # add attention sink to the last column: [b, np, sq, sk+1] if self.softmax_type != "vanilla": diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index d09aa25876..1900b56789 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -326,7 +326,7 @@ def fused_attn_fwd( # output_tensors: out [b, sq, h, d] or [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] stats = output_tensors[1] + torch.log(output_tensors[2]) # Max [b, h, sq, 1] -> max_score [h] - max_score = torch.amax(output_tensors[1], dim=(0,2,3)).to(dtype=output_tensors[0].dtype) + max_score = torch.amax(output_tensors[1], dim=(0, 2, 3)).to(dtype=output_tensors[0].dtype) aux_ctx_tensors = [stats] aux_ctx_tensors.extend(output_tensors[3:]) return output_tensors[0], aux_ctx_tensors, max_score From 7946127f4f268b3748dd31bd1e94721062bc630e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 12 Oct 2025 02:28:24 -0700 Subject: [PATCH 11/15] minor fix to tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 21 +++++------ .../common/fused_attn/fused_attn.cpp | 24 ++++++------- .../fused_attn_f16_arbitrary_seqlen.cu | 36 +++++++++---------- .../fused_attn_f16_arbitrary_seqlen.h | 6 ++-- .../include/transformer_engine/fused_attn.h | 16 ++++----- transformer_engine/pytorch/csrc/extensions.h | 4 +-- .../pytorch/csrc/extensions/attention.cpp | 16 ++++----- 7 files changed, 62 insertions(+), 61 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 858e5a2c27..35e4257c09 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -281,7 +281,7 @@ def test_dpa_checkpoint(dtype, model_configs, model): 8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias" ), "max_score_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048), - "max_score_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048), + "max_score_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20,0)), "max_score_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048), "max_score_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048), } @@ -1141,19 +1141,20 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: alibi_slopes=alibi_slopes, fast_zero_fill=True, ) - if not config.return_max_score: - out = (out, None) + max_score = None + if config.return_max_score: + out, max_score = out if is_training: - out[0].backward(d_out) + out.backward(d_out) d_softmax_offset = None if is_training and config.softmax_type != "vanilla": d_softmax_offset = block.softmax_offset.grad if backend in ["FlashAttention", "UnfusedDotProductAttention"]: if is_training: - return *out, (q.grad, k.grad, v.grad, d_softmax_offset) + return out, max_score, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return *out, (None, None, None, d_softmax_offset) + return out, max_score, (None, None, None, d_softmax_offset) if backend == "FusedAttention": if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) @@ -1182,14 +1183,14 @@ def 
get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 ) if is_training: - return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset) + return out_orig, max_score, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset) else: - return out_orig, (None, None, None, d_softmax_offset) + return out_orig, max_score, (None, None, None, d_softmax_offset) else: if is_training: - return *out, (q.grad, k.grad, v.grad, d_softmax_offset) + return out, max_score, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return *out, (None, None, None, d_softmax_offset) + return out, max_score, (None, None, None, d_softmax_offset) model_configs_te_layer = { diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 6c2da5248c..4c4d44db53 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_sum_exp) { + int64_t window_size_right, bool return_max_score) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -187,7 +187,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 - (cudnn_runtime_version != 91000) && !return_max_sum_exp) { + (cudnn_runtime_version != 91000) && !return_max_score) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else { @@ -216,7 +216,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && !requires_64bit_ragged_offset && - (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_sum_exp) { + (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_score) { flag_m512 = true; } if ( @@ -418,7 +418,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_sum_exp, + size_t max_seqlen, bool is_training, bool return_max_score, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, @@ -460,7 +460,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_sum_exp); + h, h, max_seqlen, max_seqlen, d, d, 
window_size_left, window_size_right, return_max_score); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -474,7 +474,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, t, is_training, return_max_sum_exp, attn_scale, dropout, qkv_layout, + b, h, max_seqlen, d, t, is_training, return_max_score, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); @@ -602,7 +602,7 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, bool return_max_sum_exp, float attn_scale, + size_t max_seqlen_kv, bool is_training, bool return_max_score, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { @@ -681,7 +681,7 @@ void nvte_fused_attn_fwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, - return_max_sum_exp); + return_max_score); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -697,7 +697,7 @@ void nvte_fused_attn_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, - return_max_sum_exp, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + return_max_score, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, @@ -839,7 +839,7 @@ void nvte_fused_attn_fwd( const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_sum_exp, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_score, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { @@ -913,7 +913,7 @@ void nvte_fused_attn_fwd( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, 
softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_sum_exp); + return_max_score); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -929,7 +929,7 @@ void nvte_fused_attn_fwd( fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, - return_max_sum_exp, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + return_max_score, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index e3c7e467ab..5541185de7 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -53,7 +53,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, - bool generate_max_sum_exp, float scaling_factor, float dropout_probability, + bool return_max_score, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, @@ -103,7 +103,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? 
DType::kInt64 : DType::kInt32; - bool generate_stats = !generate_max_sum_exp; + bool generate_stats = !return_max_score; try { FADescriptor_v1 descriptor{ b, @@ -136,7 +136,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( cudnn_frontend::DataType_t::NOT_SET, cudnn_frontend::DataType_t::NOT_SET, generate_stats, - generate_max_sum_exp, + return_max_score, }; namespace fe = cudnn_frontend; @@ -326,7 +326,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } std::shared_ptr Max, Sum_Exp; - if (generate_max_sum_exp) { + if (return_max_score) { Max = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Max") .set_dim({b, h, s_q, 1}) @@ -356,7 +356,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( O->set_ragged_offset(offset_o); } - if (!generate_max_sum_exp) { + if (!return_max_score) { Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); if (is_ragged_q && cudnn_runtime_version >= 90600) { offset_stats = @@ -447,7 +447,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &scaling_factor}, {O, devPtrO}, {S1, devPtrS1}}; - if (generate_max_sum_exp) { + if (return_max_score) { variant_pack[S2] = devPtrS2; } @@ -1035,7 +1035,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( using namespace transformer_engine::fused_attn; void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, bool return_max_sum_exp, float attn_scale, float p_dropout, + bool is_training, bool return_max_score, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, @@ -1087,7 +1087,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); - if (return_max_sum_exp) { + if (return_max_score) { Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Max->data.dptr = nullptr; if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { @@ -1136,7 +1136,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( Aux_CTX_Tensors->size = i; } else if (Aux_CTX_Tensors->size >= 2) { - if (return_max_sum_exp) { + if (return_max_score) { Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrS1 = output_Max->data.dptr; Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); @@ -1168,7 +1168,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, - return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, + return_max_score, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets, @@ -1285,7 +1285,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t 
num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + bool return_max_score, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, @@ -1350,7 +1350,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); - if (return_max_sum_exp) { + if (return_max_score) { Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Max->data.dptr = nullptr; if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { @@ -1399,7 +1399,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( Aux_CTX_Tensors->size = i; } else if (Aux_CTX_Tensors->size >= 2) { - if (return_max_sum_exp) { + if (return_max_score) { Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrS1 = output_Max->data.dptr; Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); @@ -1432,7 +1432,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, + return_max_score, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, @@ -1564,7 +1564,7 @@ void fused_attn_arbitrary_seqlen_fwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + bool return_max_score, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, @@ -1619,7 +1619,7 @@ void fused_attn_arbitrary_seqlen_fwd( size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); - if (return_max_sum_exp) { + if (return_max_score) { Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Max->data.dptr = nullptr; if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { @@ -1668,7 +1668,7 @@ void fused_attn_arbitrary_seqlen_fwd( Aux_CTX_Tensors->size = i; } else if (Aux_CTX_Tensors->size >= 2) { - if (return_max_sum_exp) { + if (return_max_score) { Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrS1 = output_Max->data.dptr; Tensor 
*output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); @@ -1701,7 +1701,7 @@ void fused_attn_arbitrary_seqlen_fwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, + return_max_score, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 8cef0cd3f7..99a2f9ca76 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -20,7 +20,7 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, bool return_max_sum_exp, float attn_scale, float p_dropout, + bool is_training, bool return_max_score, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, @@ -43,7 +43,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + bool return_max_score, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, @@ -70,7 +70,7 @@ void fused_attn_arbitrary_seqlen_fwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - bool return_max_sum_exp, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + bool return_max_score, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 8a8d870edd..9230ba8012 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ 
b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -206,14 +206,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] head_dim_v The head dimension of V. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). - * \param[in] return_max_sum_exp Whether to produce Max and Sum_Exp, or Stats. + * \param[in] return_max_score Whether to produce Max and Sum_Exp, or Stats. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_sum_exp); + int64_t window_size_right, bool return_max_score); /*! \brief Compute dot product attention with packed QKV input. * @@ -256,7 +256,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] max_seqlen Max sequence length used for computing, * it may be >= max(seqlen_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. - * \param[in] return_max_sum_exp Whether to produce Max and Sum_Exp, or Stats. + * \param[in] return_max_score Whether to produce Max and Sum_Exp, or Stats. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensor's layout. @@ -272,7 +272,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_sum_exp, + size_t max_seqlen, bool is_training, bool return_max_score, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, @@ -386,7 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] max_seqlen_kv Max sequence length used for computing for KV. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. - * \param[in] return_max_sum_exp Whether to produce Max and Sum_Exp, or Stats. + * \param[in] return_max_score Whether to produce Max and Sum_Exp, or Stats. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensor's layout. 
@@ -405,7 +405,7 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, bool return_max_sum_exp, float attn_scale, + size_t max_seqlen_kv, bool is_training, bool return_max_score, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); @@ -526,7 +526,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] max_seqlen_kv Max sequence length used for computing for K and V. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. - * \param[in] return_max_sum_exp Whether to produce Max and Sum_Exp, or Stats. + * \param[in] return_max_score Whether to produce Max and Sum_Exp, or Stats. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. @@ -544,7 +544,7 @@ void nvte_fused_attn_fwd( const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_sum_exp, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_score, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e76ec425c9..9d0f54a3d0 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -76,7 +76,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_sum_exp); + int64_t window_size_right, bool return_max_score); std::pair quantizer_helper(py::handle quantizer, const std::vector &shape, DType dtype, @@ -94,7 +94,7 @@ std::vector fused_attn_fwd( const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread, bool return_max_sum_exp); + size_t rng_elts_per_thread, bool return_max_score); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index fb8f79a68e..b0535676bf 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -45,12 +45,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type 
bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_sum_exp) { + int64_t window_size_right, bool return_max_score) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, - return_max_sum_exp); + return_max_score); return fused_attention_backend; } @@ -107,7 +107,7 @@ std::vector fused_attn_fwd( const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread, bool return_max_sum_exp) { + size_t rng_elts_per_thread, bool return_max_score) { auto none = py::none(); // create QKV tensor wrappers @@ -229,7 +229,7 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + return_max_score, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -252,8 +252,8 @@ std::vector fused_attn_fwd( // allocate memory for nvte_aux_tensor_pack.tensors // f16_max512 : S [b, h, sq, skv] // f16_arbitrary: - // return_max_sum_exp=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] - // return_max_sum_exp=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv] + // return_max_score=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] + // return_max_score=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv] // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2] size_t i = 0; at::Tensor output_tensor; @@ -263,7 +263,7 @@ std::vector fused_attn_fwd( static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); set_tensor_param(i++, output_tensor); // fp8 has an additional softmax stats tensor, ZInv - if (return_max_sum_exp || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + if (return_max_score || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { output_tensor = allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); @@ -289,7 +289,7 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_sum_exp, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + return_max_score, attn_scale, p_dropout, qkv_layout, bias_type, 
attn_mask_type, softmax_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); }); From f984602608278874f3f0a135a566e8c8d9ef6b56 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 12 Oct 2025 09:29:20 +0000 Subject: [PATCH 12/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 10 ++++-- .../common/fused_attn/fused_attn.cpp | 18 +++++----- .../fused_attn_f16_arbitrary_seqlen.cu | 35 +++++++++---------- .../fused_attn_f16_arbitrary_seqlen.h | 18 +++++----- .../include/transformer_engine/fused_attn.h | 8 ++--- 5 files changed, 47 insertions(+), 42 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 35e4257c09..a68ecde51f 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -281,7 +281,9 @@ def test_dpa_checkpoint(dtype, model_configs, model): 8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias" ), "max_score_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048), - "max_score_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20,0)), + "max_score_5_1": ModelConfig( + 8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0) + ), "max_score_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048), "max_score_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048), } @@ -1183,7 +1185,11 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 ) if is_training: - return out_orig, max_score, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset) + return ( + out_orig, + max_score, + (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset), + ) else: return out_orig, max_score, (None, None, None, d_softmax_offset) else: diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 4c4d44db53..1e4cec2dfd 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -602,10 +602,10 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, bool return_max_score, float attn_scale, - float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { + size_t max_seqlen_kv, bool is_training, bool return_max_score, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -697,9 +697,9 @@ void nvte_fused_attn_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, 
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, - return_max_score, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, input_Q, input_KV, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + return_max_score, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, + output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else @@ -929,8 +929,8 @@ void nvte_fused_attn_fwd( fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, - return_max_score, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, + return_max_score, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 5541185de7..590b50ad4a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -53,15 +53,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, - bool return_max_score, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, - void *devPtrS1, void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, - void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, - void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, - void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, - size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + bool return_max_score, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, + void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, + void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, + void *workspace, 
size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -1284,15 +1283,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - bool return_max_score, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_score, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 99a2f9ca76..094b04da5c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -42,15 +42,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - bool return_max_score, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool 
return_max_score, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 9230ba8012..8f03d5e187 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -405,10 +405,10 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, bool return_max_score, float attn_scale, - float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); + size_t max_seqlen_kv, bool is_training, bool return_max_score, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. 
* From cb01843a0c9122466b46e9c6c1a947b40039ed1f Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 12 Oct 2025 02:44:08 -0700 Subject: [PATCH 13/15] remove logic for flash-attn Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../pytorch/attention/dot_product_attention/backends.py | 8 -------- .../dot_product_attention/dot_product_attention.py | 3 ++- .../pytorch/attention/dot_product_attention/utils.py | 6 ------ 3 files changed, 2 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index b729834602..d480a02caf 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -598,7 +598,6 @@ def __init__( attention_type: str = "self", layer_number: Optional[int] = None, deterministic: bool = False, - return_max_score: Optional[bool] = False, ) -> None: super().__init__() @@ -622,7 +621,6 @@ def __init__( self.logger.setLevel(attn_log._log_level) if not self.logger.hasHandlers(): self.logger.addHandler(attn_log._stream_handler) - self.return_max_score = return_max_score def forward( self, @@ -923,11 +921,7 @@ def forward( softmax_scale=self.softmax_scale, causal="causal" in attn_mask_type, **fa_optional_forward_kwargs, - return_attn_probs=self.return_max_score, ) - # if self.return_max_score and (self.attention_dropout == 0.0 or not self.training): - # output, _, S_dmask = output - # max_score = torch.max(S_dmask) else: fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs["window_size"] = window_size @@ -1052,8 +1046,6 @@ def convert_to_torch_float8(tensor, dtype): # thd -> t(hd) output = output.reshape(output.shape[0], -1) - if self.return_max_score: - return output.contiguous(), max_score return output.contiguous() diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 3af595e9f0..c980e6bfcf 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -419,7 +419,6 @@ def __init__( attn_kwargs = { "attention_dropout": attention_dropout, "attention_dropout_ctx": attention_dropout_ctx, - "return_max_score": return_max_score, } self.flash_attention = FlashAttention( @@ -439,6 +438,7 @@ def __init__( deterministic=self.deterministic, **attn_kwargs, softmax_type=self.softmax_type, + return_max_score=self.return_max_score, ) self.unfused_attention = UnfusedDotProductAttention( @@ -447,6 +447,7 @@ def __init__( **attn_kwargs, layer_number=layer_number, softmax_type=self.softmax_type, + return_max_score=self.return_max_score, ) def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index bdd412f7ae..2d2c4499e9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -491,12 +491,6 @@ def get_attention_backend( if use_flash_attention: use_flash_attention = False logger.debug("Disabling FlashAttention for max_score") - # if 
use_flash_attention and not (attention_dropout == 0.0 or not is_training): - # use_flash_attention = False - # logger.debug("Disabling FlashAttention for max_score with dropout") - # if use_flash_attention and fp8 and fp8_meta["recipe"].fp8_dpa: - # use_flash_attention = False - # logger.debug("Disabling FlashAttention for max_score with FP8 attention") # Filter: KV cache # backend | precision | KV cache | architecture | qkv_format | page_size From 966b65733e83a1c082c17d2cf69fd6a7ca80b0b4 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 12 Oct 2025 06:23:30 -0700 Subject: [PATCH 14/15] WIP: add CP support for p2p/a2a/all_gather Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/run_attention_with_cp.py | 15 ++-- .../attention/test_attention_with_cp.py | 6 +- .../dot_product_attention/backends.py | 5 +- .../dot_product_attention/context_parallel.py | 69 ++++++++++++++++--- .../attention/dot_product_attention/utils.py | 8 ++- 5 files changed, 81 insertions(+), 22 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index d490c235bb..a272680ec2 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -249,6 +249,7 @@ def run_dpa_with_cp( attn_mask_type=config.attn_mask_type, window_size=config.window_size, softmax_type=config.softmax_type, + return_max_score=config.return_max_score, ).cuda() if config.softmax_type != "vanilla": core_attn.softmax_offset.requires_grad = True @@ -309,6 +310,7 @@ def run_dpa_with_cp( fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) else: fp8_context = nullcontext() + max_score = None with fp8_context: # q, k, v, out in FP8; dout in F16 out = core_attn( @@ -323,6 +325,8 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=cu_seqlens_kv_padded, fp8_output=fp8_mha, ) + if config.return_max_score: + out, max_score = out if fp8_bwd and fp8_mha: dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) @@ -401,6 +405,7 @@ def run_dpa_with_cp( fp8_context = nullcontext() # run attention + max_score_ = None with fp8_context: # q, k, v, out in FP8; dout in F16 out_ = core_attn( @@ -415,6 +420,8 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=cu_seqlens_kv_padded, fp8_output=fp8_mha, ) + if config.return_max_score: + out_, max_score_ = out_ if fp8_bwd and fp8_mha: dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) @@ -496,15 +503,15 @@ def run_dpa_with_cp( ) atol, rtol, rmse_tol = get_tols(config, dtype) - tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_] - tensors_no_cp = [out, dq, dk, dv, d_softmax_offset] - names = ["out", "dq", "dk", "dv", "d_softmax_offset"] + tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_score_] + tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_score] + names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_score"] names_cp = [x + "_cp" for x in names] names_no_cp = [x + "_no_cp" for x in names] is_fp8 = dtype == "fp8" for i, t in enumerate(tensors_no_cp): if t is not None: - if "softmax_offset" not in names[i]: + if "softmax_offset" not in names[i] and "max_score" not in names[i]: if qkv_format == "bshd": compare_and_assert( t[:, 0], diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 0f00b8b0ef..05585d3462 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ 
b/tests/pytorch/attention/test_attention_with_cp.py @@ -137,8 +137,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): model_configs_fused_attn = { # test: ModelConfig(b, sq, hq, dqk) - "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA - "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA + "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_score=True), # MHA + "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_score=True), # MHA "cp_1_2": ModelConfig( 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA @@ -183,7 +183,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] + configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] qkv_formats = ["sbhd", "thd"] diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index d480a02caf..d9681d04fa 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1150,7 +1150,7 @@ def forward( # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - out_, aux_ctx_tensors = fused_attn_fwd( + out_, aux_ctx_tensors, max_score = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -1848,6 +1848,7 @@ def forward( softmax_offset=softmax_offset, fp8_output=fp8_output, layer_number=self.layer_number, + return_max_score=self.return_max_score, ) else: with self.attention_dropout_ctx(): @@ -1886,7 +1887,7 @@ def forward( self.return_max_score, ) - if self.return_max_score and not context_parallel: + if self.return_max_score: # ...hd -> ...(hd) return output[0].view(*output[0].shape[:-2], -1), output[1] # ...hd -> ...(hd) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index d1374e949e..33542cf573 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -617,6 +617,7 @@ def cp_p2p_fwd_fused_attn( rank, step, cp_size, + return_max_score, q_part, k_part, v_part, @@ -693,7 +694,7 @@ def cp_p2p_fwd_fused_attn( fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step - out_per_step, aux_ctx_tensors = fused_attn_fwd( + out_per_step, aux_ctx_tensors, max_score = fused_attn_fwd( is_training, max_seqlen_q_, max_seqlen_kv_, @@ -713,6 +714,7 @@ def cp_p2p_fwd_fused_attn( cu_seqlens_q_padded=cu_seqlens_q_padded_, cu_seqlens_kv_padded=cu_seqlens_kv_padded_, **fp8_meta_kwargs, + return_max_score=return_max_score, ) if fp8: @@ -721,7 +723,7 @@ def cp_p2p_fwd_fused_attn( softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors attn_bias = rest[0] if len(rest) > 0 else None - return out_per_step, softmax_lse_per_step, rng_states, attn_bias + return out_per_step, softmax_lse_per_step, rng_states, attn_bias, max_score def 
cp_p2p_fwd_flash_attn( @@ -1096,6 +1098,7 @@ def forward( use_flash_attn_3, fp8_output, layer_number, + return_max_score, ): # pylint: disable=missing-function-docstring @@ -1156,6 +1159,8 @@ def forward( amax_per_step = None S_quantizer_per_step = [None for _ in range(cp_size)] O_quantizer_per_step = [None for _ in range(cp_size)] + max_score_per_step = [None for _ in range(cp_size)] + max_score = None assert isinstance(k, q.__class__) and isinstance( v, q.__class__ @@ -1244,6 +1249,8 @@ def forward( q_f16 = q if use_fused_attention: fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + if return_max_score: + max_score_per_step = torch.empty((cp_size, q.shape[-2]), dtype=q.dtype, device=q.device) # split qkv to two halves and prepare for load balancing assert qkv_format == "thd" or ( @@ -1418,6 +1425,7 @@ def forward( rank, i, cp_size, + return_max_score, ] else: flash_attn_inputs = [ @@ -1462,6 +1470,7 @@ def forward( softmax_lse_per_step[i], rng_states[i], attn_biases[i], + max_score_per_step[i], ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1488,6 +1497,7 @@ def forward( softmax_lse_per_step[i], rng_states[i], attn_biases[i], + max_score_per_step[i], ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1514,6 +1524,7 @@ def forward( softmax_lse_per_step[i], rng_states[i], attn_biases[i], + max_score_per_step[i], ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1541,6 +1552,7 @@ def forward( softmax_lse_per_step[i], rng_states[i], attn_biases[i], + max_score_per_step[i], ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( @@ -1600,11 +1612,18 @@ def forward( softmax_lse.view(*softmax_lse.shape[:-1], 2, -1), softmax_lse_per_step[i - 1], ) + if return_max_score: + if i == 1: + max_score = torch.clone(max_score_per_step[0]) + else: + max_score = torch.maximum(max_score, max_score_per_step[i-1]) if i < cp_size: flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done) torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) + if return_max_score: + torch.distributed.all_reduce(max_score, op=torch.distributed.ReduceOp.MAX, group=cp_group) second_half_lse_seqlen = None if causal and rank < (cp_size - 1): @@ -1682,6 +1701,10 @@ def forward( elif qkv_format == "sbhd": # [s*b, h, d] -> [s, b, h, d] out = out.view(-1, ctx.batch_size, *out.shape[-2:]) + if return_max_score: + max_score = flash_attn_a2a_communicate_softmax_offset( + max_score, 0, cp_size_a2a, cp_group_a2a, cp_stream, False + ) elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) @@ -1811,10 +1834,10 @@ def forward( nvtx_range_pop(f"{nvtx_label}") - return out_ret + return out_ret, max_score @staticmethod - def backward(ctx, dout): + def backward(ctx, dout, *args): # pylint: disable=missing-function-docstring # add NVTX range @@ -2522,6 +2545,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -2581,6 +2605,7 @@ def forward( cp_group, cp_stream, use_flash_attn_3, + return_max_score, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") @@ -2682,6 +2707,8 @@ def forward( softmax_lse_per_step = [None, None] rng_states = [None, None] out = torch.empty_like(q) + max_score_per_step = [None, None] + max_score = None for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): @@ -2712,7 +2739,7 @@ def forward( # [s_range, b, h, 
d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: - out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd( + out_per_step[i], [softmax_lse_per_step[i], rng_states[i]], max_score_per_step[i] = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv_, @@ -2732,6 +2759,7 @@ def forward( cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], window_size=window_size_per_step[i], + return_max_score=return_max_score, ) else: fa_forward_args_thd = get_fa_args( @@ -2767,14 +2795,21 @@ def forward( if not use_flash_attn_3: rng_states[i] = fa_outputs[3] + if return_max_score and i == 0: + max_score = torch.clone(max_score_per_step[0]) if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): if qkv_format == "bshd": out[:, i - 1].copy_(out_per_step[i - 1]) elif qkv_format == "sbhd": out[i - 1].copy_(out_per_step[i - 1]) + if return_max_score: + max_score = torch.maximum(max_score, max_score_per_step[i-1]) + torch.cuda.current_stream().wait_stream(cp_stream) + if return_max_score: + torch.distributed.all_reduce(max_score, op=torch.distributed.ReduceOp.MAX, group=cp_group) if use_fused_attention: if qkv_format == "bshd": @@ -2811,10 +2846,10 @@ def forward( ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") - return out + return out, max_score @staticmethod - def backward(ctx, dout): + def backward(ctx, dout, *args): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") cp_size = get_distributed_world_size(ctx.cp_group) @@ -3035,6 +3070,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -3075,6 +3111,7 @@ def forward( softmax_type, softmax_offset, fp8_output, + return_max_score, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") @@ -3158,6 +3195,7 @@ def forward( fp8_recipe = fp8_meta["local_recipes"][0] fwd_nominal_dtype = q.dtype fused_attn_backend = None + max_score = None QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, quantizers) @@ -3203,7 +3241,7 @@ def forward( Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) ] - out_, aux_ctx_tensors = fused_attn_fwd( + out_, aux_ctx_tensors, max_score = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -3226,6 +3264,7 @@ def forward( **fp8_meta_kwargs, softmax_type=softmax_type, softmax_offset=softmax_offset, + return_max_score=return_max_score, ) if isinstance(out_, Float8Tensor): out_fp8 = out_ @@ -3276,6 +3315,10 @@ def forward( out_ = flash_attn_a2a_communicate( out_, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False ) + if return_max_score: + max_score = flash_attn_a2a_communicate_softmax_offset( + max_score, 0, cp_size, cp_group, cp_stream, False + ) if use_fused_attention: if qkv_format == "bshd": @@ -3362,10 +3405,10 @@ def forward( ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") - return out_ret + return out_ret, max_score @staticmethod - def backward(ctx, dout): + def backward(ctx, dout, *args): # pylint: disable=missing-function-docstring 
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") cp_size = get_distributed_world_size(ctx.cp_group) @@ -3601,6 +3644,7 @@ def backward(ctx, dout): None, d_softmax_offset, None, + None, ) @@ -3637,6 +3681,7 @@ def attn_forward_func_with_cp( softmax_offset=None, fp8_output=False, layer_number=1, + return_max_score=False, ) -> torch.Tensor: """ Attention implementation with context parallelism (CP). CP partitions tensors along the sequence @@ -3798,12 +3843,13 @@ def attn_forward_func_with_cp( use_flash_attn_3, fp8_output, layer_number, + return_max_score, ] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": args.pop(5) args.pop(8) - args += [window_size, cp_group, cp_stream, use_flash_attn_3] + args += [window_size, cp_group, cp_stream, use_flash_attn_3, return_max_score] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": args += [ @@ -3817,6 +3863,7 @@ def attn_forward_func_with_cp( softmax_type, softmax_offset, fp8_output, + return_max_score, ] out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 2d2c4499e9..4afbfcc496 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -485,12 +485,16 @@ def get_attention_backend( if return_max_score: if context_parallel: use_flash_attention = False - use_fused_attention = False use_unfused_attention = False - logger.debug("Disabling all backends for max_score with context parallelism") + logger.debug("Disabling FlashAttention and UnfusedAttention for max_score with context parallelism") if use_flash_attention: use_flash_attention = False logger.debug("Disabling FlashAttention for max_score") + if fp8 and fp8_meta["recipe"].fp8_dpa: + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False + logger.debug("Disabling all backends for max_score with context parallelism in FP8") # Filter: KV cache # backend | precision | KV cache | architecture | qkv_format | page_size From 9d614a82bd38e87d1f66d99c0aa4b1367dab572c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 12 Oct 2025 13:24:17 +0000 Subject: [PATCH 15/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../dot_product_attention/context_parallel.py | 23 +++++++++++++------ .../attention/dot_product_attention/utils.py | 5 +++- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 33542cf573..6a8659cff1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1250,7 +1250,9 @@ def forward( if use_fused_attention: fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if return_max_score: - max_score_per_step = torch.empty((cp_size, q.shape[-2]), dtype=q.dtype, device=q.device) + max_score_per_step = torch.empty( + (cp_size, q.shape[-2]), dtype=q.dtype, device=q.device + ) # split qkv to two halves and prepare for load balancing assert qkv_format == "thd" or ( @@ -1616,14 +1618,16 @@ def forward( if i == 1: max_score = 
torch.clone(max_score_per_step[0]) else: - max_score = torch.maximum(max_score, max_score_per_step[i-1]) + max_score = torch.maximum(max_score, max_score_per_step[i - 1]) if i < cp_size: flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done) torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) if return_max_score: - torch.distributed.all_reduce(max_score, op=torch.distributed.ReduceOp.MAX, group=cp_group) + torch.distributed.all_reduce( + max_score, op=torch.distributed.ReduceOp.MAX, group=cp_group + ) second_half_lse_seqlen = None if causal and rank < (cp_size - 1): @@ -2739,7 +2743,11 @@ def forward( # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: - out_per_step[i], [softmax_lse_per_step[i], rng_states[i]], max_score_per_step[i] = fused_attn_fwd( + ( + out_per_step[i], + [softmax_lse_per_step[i], rng_states[i]], + max_score_per_step[i], + ) = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv_, @@ -2804,12 +2812,13 @@ def forward( elif qkv_format == "sbhd": out[i - 1].copy_(out_per_step[i - 1]) if return_max_score: - max_score = torch.maximum(max_score, max_score_per_step[i-1]) - + max_score = torch.maximum(max_score, max_score_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream) if return_max_score: - torch.distributed.all_reduce(max_score, op=torch.distributed.ReduceOp.MAX, group=cp_group) + torch.distributed.all_reduce( + max_score, op=torch.distributed.ReduceOp.MAX, group=cp_group + ) if use_fused_attention: if qkv_format == "bshd": diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 4afbfcc496..75aa855590 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -486,7 +486,10 @@ def get_attention_backend( if context_parallel: use_flash_attention = False use_unfused_attention = False - logger.debug("Disabling FlashAttention and UnfusedAttention for max_score with context parallelism") + logger.debug( + "Disabling FlashAttention and UnfusedAttention for max_score with context" + " parallelism" + ) if use_flash_attention: use_flash_attention = False logger.debug("Disabling FlashAttention for max_score")
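
For orientation, here is a minimal unfused sketch of the quantities the `return_max_score` path switches between: when the flag is set, the cuDNN graph above emits row-wise `Max` and `Sum_Exp` tensors of shape `[b, h, s_q, 1]` instead of the usual `Stats` tensor, and the user-visible `max_score` is a per-head reduction of those row maxima. The decomposition `stats = max + log(sum_exp)` and the per-head reduction used below are assumptions made for illustration only; they are not guaranteed by the kernels in this series.

```python
# Hypothetical reference sketch (not part of the patch): an unfused emulation of the
# auxiliary tensors selected by return_max_score, assuming Stats is the softmax
# log-sum-exp and max_score is the per-head maximum of the scaled pre-softmax logits.
import torch

def reference_max_score(q, k, attn_scale):
    # q: [b, h, sq, d], k: [b, h, skv, d]
    logits = attn_scale * torch.einsum("bhqd,bhkd->bhqk", q, k)        # [b, h, sq, skv]
    row_max = logits.amax(dim=-1, keepdim=True)                        # Max:     [b, h, sq, 1]
    sum_exp = torch.exp(logits - row_max).sum(dim=-1, keepdim=True)    # Sum_Exp: [b, h, sq, 1]
    stats = row_max + torch.log(sum_exp)                               # Stats (softmax LSE)
    max_score = row_max.amax(dim=(0, 2, 3))                            # per-head value: [h]
    return max_score, stats

b, h, sq, skv, d = 2, 16, 128, 128, 64
q = torch.randn(b, h, sq, d)
k = torch.randn(b, h, skv, d)
max_score, stats = reference_max_score(q, k, attn_scale=d ** -0.5)
print(max_score.shape, stats.shape)  # torch.Size([16]) torch.Size([2, 16, 128, 1])
```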
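
At the module level the flag is consumed as in `run_attention_with_cp.py` above: when `return_max_score` is set, the forward returns an `(output, max_score)` pair. A hedged usage sketch follows, assuming `DotProductAttention` accepts `return_max_score` at construction (as this series adds) and the default `sbhd` layout; the shapes and dtype are illustrative only.

```python
# Usage sketch under the assumptions stated above; not a verbatim excerpt from the tests.
import torch
import transformer_engine.pytorch as te

attn = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attention_dropout=0.0,
    return_max_score=True,   # new flag introduced by this patch series
).cuda()

# [sq, b, h, d] tensors in the default sbhd layout (illustrative sizes)
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")

out, max_score = attn(q, k, v)   # max_score: per-head maximum attention logit
```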
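
The context-parallel paths above combine `max_score` the same way in all three schemes: clone the first step's per-head maxima, fold in later steps with `torch.maximum`, then run a MAX all-reduce over the CP group so every rank holds the global value. A minimal sketch of that reduction, assuming an already-initialized `torch.distributed` process group; `combine_max_score` is an illustrative helper name, not a function in the patch.

```python
# Sketch of the cross-step / cross-rank max_score reduction used under context parallelism.
import torch
import torch.distributed as dist

def combine_max_score(max_score_per_step, cp_group):
    # max_score_per_step: list of per-head maxima, one tensor of shape [h] per local CP step
    max_score = max_score_per_step[0].clone()
    for step_max in max_score_per_step[1:]:
        max_score = torch.maximum(max_score, step_max)      # elementwise max across local steps
    # a single MAX all-reduce makes the result identical on every CP rank
    dist.all_reduce(max_score, op=dist.ReduceOp.MAX, group=cp_group)
    return max_score
```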