From 377505746dba2e137a48e0d4c67a196ccba2271c Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 11 Feb 2025 08:43:59 -0800 Subject: [PATCH] use correct total length to fix static kv_cache performance (#23615) when using static kv_cache, past_sequence_length is the max sequence length of kv_cache. issue1: total_sequence_length will be larger than the cache entry issue2: we do way more calculations than needed, so things are noticeably slower --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 568e75b38a98f..33f67375b4ef2 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -439,7 +439,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0; - const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_; + const int total_sequence_length = parameters.total_sequence_length_; const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, total_sequence_length});