use correct total length to fix static kv_cache performance (#23615)
when using a static kv_cache, past_sequence_length is the max sequence
length of the kv_cache, not the number of valid past tokens.
issue1: total_sequence_length will be larger than the cache entry
issue2: we do far more calculations than needed, so things are noticeably
slower
guschmue authored Feb 11, 2025
1 parent 3901e96 commit 3775057
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -439,7 +439,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
                       WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
   const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
   const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
-  const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_;
+  const int total_sequence_length = parameters.total_sequence_length_;
 
   const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_,
                                       parameters.sequence_length_, total_sequence_length});
