From 0a0a5ca7a83111cec5c0c44215028d15a4aa3413 Mon Sep 17 00:00:00 2001
From: Guenther Schmuelling
Date: Fri, 7 Feb 2025 08:21:06 -0800
Subject: [PATCH 1/2] use correct total length to fix static kv_cache performance

---
 onnxruntime/contrib_ops/webgpu/bert/attention.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
index 568e75b38a98f..33f67375b4ef2 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -439,7 +439,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
                       WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
   const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
   const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
-  const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_;
+  const int total_sequence_length = parameters.total_sequence_length_;
 
   const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_,
                                       parameters.sequence_length_, total_sequence_length});

From 8f808c26c28385655b6a6398839d9f0a937490ba Mon Sep 17 00:00:00 2001
From: Sushanth Rajasankar
Date: Mon, 10 Feb 2025 13:09:28 -0800
Subject: [PATCH 2/2] Simple version of enabling GQA

---
 onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc     | 7 +++++--
 .../contrib_ops/webgpu/bert/group_query_attention.cc       | 6 ++++++
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index b51c2fbe27e1d..bdb2234a5d3b2 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -420,8 +420,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   FlashAttentionProgram program{"FlashAttention", has_attention_bias, parameters.head_size_, parameters.num_heads_};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
                      {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
-                     {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4},
-                     {attention_bias, ProgramTensorMetadataDependency::TypeAndRank}});
+                     {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}});
+  if (has_attention_bias) {
+    program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank});
+  }
   program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}});
   const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
                                                 : parameters.scale_;
@@ -443,6 +445,7 @@ bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const
   return parameters.batch_size_ == 1 &&
          bias == nullptr &&
          parameters.sequence_length_ > 1 &&
+         parameters.qkv_format_ == Q_K_V_BSNH &&
          context.Device().HasFeature(wgpu::FeatureName::Subgroups) &&
          present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 &&
          present_value->SizeInBytes() > 0 && parameters.head_size_ % 4 == 0;
diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
index 31c8af9b4f922..bd62c68d871fa 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -5,6 +5,7 @@
 #include "contrib_ops/webgpu/bert/attention_common.h"
 #include "contrib_ops/webgpu/bert/group_query_attention.h"
 #include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
+#include "contrib_ops/webgpu/bert/flash_attention.h"
 
 #include "core/providers/webgpu/webgpu_supported_types.h"
 
@@ -74,6 +75,11 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   Tensor* present_value = context.Output(2, present_kv_shape);
   parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw();
 
+  if (CanApplyFlashAttention(/*bias*/ nullptr, present_key, present_value, parameters, context)) {
+    return ApplyFlashAttention(query, key, value, /*attention_bias*/ nullptr, output, past_key, present_key, past_value,
+                               present_value, parameters, context);
+  }
+
   TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
                                 parameters.sequence_length_, parameters.head_size_});
   TensorShape q_new_shape(q_new_dims);