diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 568e75b38a98f..33f67375b4ef2 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -439,7 +439,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0; - const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_; + const int total_sequence_length = parameters.total_sequence_length_; const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, total_sequence_length});