diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 9aa44a1600ae6..45e85fcd9c9d5 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2671,14 +2671,14 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dd>Scale tensor for past_value.</dd>
 </dl>
 
-#### Outputs (3 - 4)
+#### Outputs (1 - 4)
 
 <dl>
 <dt><tt>output</tt> : T</dt>
 <dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
-<dt><tt>present_key</tt> : T_CACHE</dt>
+<dt><tt>present_key</tt> (optional) : T_CACHE</dt>
 <dd>present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
-<dt><tt>present_value</tt> : T_CACHE</dt>
+<dt><tt>present_value</tt> (optional) : T_CACHE</dt>
 <dd>present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
 <dt><tt>output_qk</tt> (optional) : T</dt>
 <dd>Values of QK matrix multiplication, either before or after softmax normalization</dd>
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index adc7b623ec8c4..1f03cf9f105a2 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -85,7 +85,9 @@ class GQAAttentionBase {
     if (past_key != nullptr && past_value != nullptr) {
       seqlen_past_kv_cache = static_cast<int>(past_key->Shape().GetDims()[2]);
     }
-    int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);
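+    // present_key may be null when the graph does not request the present outputs
+    // (e.g. layers that share another layer's KV cache); fall back to the cache
+    // length recorded in `parameters` during input validation.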
+    int seqlen_present_kv_cache = present_key != nullptr
+                                      ? static_cast<int>(present_key->Shape().GetDims()[2])
+                                      : parameters.seqlen_present_kv_cache;
 
     // Compute the attention score.
     bool gqa_mlas_supported = MlasGQASupported(CblasNoTrans, CblasTrans) &&
@@ -175,7 +177,7 @@ class GQAAttentionBase {
     const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size;        // L x H
     const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size;  // T x H
 
-    if (!past_present_share_buffer) {
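+    // Without a present_key output there is no buffer to zero-initialize; the
+    // memset below would otherwise dereference a null pointer.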
+    if (present_key && !past_present_share_buffer) {
       memset((void*)present_key, 0,
              batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
@@ -402,7 +404,7 @@ class GQAAttentionBase {
     const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size;        // L x H
     const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size;  // T x H
 
-    if (!past_present_share_buffer) {
+    if (present_value && !past_present_share_buffer) {
       memset((void*)present_value, 0,
              batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index 5698bcb659f20..5ee2f31539bae 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -233,7 +233,9 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
   const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data<T>() : nullptr;
 
   // Compute the attention score and apply the score to V
-  return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
+  const T* k_data = packed_qkv ? nullptr : k_rotary;
+  const T* v_data = packed_qkv ? nullptr : V.Get<Tensor>().Data<T>();
+  return ApplyAttention(q_rotary, k_data, v_data,
                         head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v,
                         output_qk, seqlens_k, parameters, allocator, context);
 }
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
index f5399e307fbca..f65568700c0c9 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -242,7 +242,7 @@ Status CheckInputs(const T* query,
   int q_hidden_size = 0;
   int kv_hidden_size = 0;
   int head_size = 0;
-  const bool is_packed_qkv = key == nullptr;
+  const bool is_packed_qkv = (key == nullptr);
   if (!is_packed_qkv) {
     ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, kv_num_heads,
                                     batch_size, sequence_length, q_hidden_size, kv_hidden_size, head_size));
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 5f21f3cd34e8f..9563292f9187c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -227,6 +227,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
                                                  attention_bias,
                                                  head_sink,
                                                  parameters));
+  parameters.local_window_size = local_window_size_;
   parameters.is_unidirectional = is_unidirectional_;
   parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr;
@@ -291,13 +292,10 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
   data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
-
-  data.present_key = reinterpret_cast<CudaT*>(present_key_output->MutableData<T>());
-  data.present_value = reinterpret_cast<CudaT*>(present_value_output->MutableData<T>());
-
+  data.present_key = (present_key_output != nullptr) ? reinterpret_cast<CudaT*>(present_key_output->MutableData<T>()) : nullptr;
+  data.present_value = (present_value_output != nullptr) ? reinterpret_cast<CudaT*>(present_value_output->MutableData<T>()) : nullptr;
   // Compute past_present_share_buffer early since it's needed for flash attention path selection.
-  // This compares the final pointer values after quantization handling.
-  parameters.past_present_share_buffer = (data.past_key == data.present_key);
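+  // Both pointers are null when the present outputs are omitted; nullptr == nullptr
+  // must not be treated as a shared K-V buffer.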
+  parameters.past_present_share_buffer = (data.past_key != nullptr && data.past_key == data.present_key);
 
   bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE);
   constexpr bool is_int8 = std::is_same<T, int8_t>::value;
@@ -562,10 +560,12 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   }
 
   // Validate past_value pointer consistency (past_present_share_buffer was computed early after pointer setup)
-  if (parameters.past_present_share_buffer) {
-    ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true");
-  } else {
-    ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false");
+  if (data.present_value != nullptr) {
+    if (parameters.past_present_share_buffer) {
+      ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true");
+    } else {
+      ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false");
+    }
   }
 
   data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index ebb6a0b0da215..3ce396989b181 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -98,6 +98,14 @@ Status PrepareQKV(
     q_out = nullptr;
   }
 
+  // present_key/present_value are required for the CUDA path since flash attention
+  // and memory-efficient attention read directly from the present KV buffers.
+  // The CPU path supports optional present outputs for KV-shared layers.
+  if (data.present_key == nullptr || data.present_value == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "present_key and present_value outputs are required for the CUDA GroupQueryAttention kernel.");
+  }
+
   U* k = reinterpret_cast<U*>(data.present_key);
   U* v = reinterpret_cast<U*>(data.present_value);
   int max_cache_length = parameters.seqlen_present_kv_cache;
diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
index fd72f751ee810..5fff0516c7ce3 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -212,7 +212,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
                                               scale_,
                                               softcap_,
                                               0,
-                                              context.DeviceLimits().maxComputeInvocationsPerWorkgroup));
+                                              static_cast<int>(context.DeviceLimits().maxComputeInvocationsPerWorkgroup)));
 
   params.use_smooth_softmax = use_smooth_softmax_;
   params.rotary_interleaved = rotary_interleaved_;
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 1209446c6a367..e8ec04586a9d6 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1323,13 +1323,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
                 "present state key with support for format BNSH. When past_key uses same tensor as present_key"
                 "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
                 "kv_sequence_length.",
-                "T_CACHE")
+                "T_CACHE",
+                OpSchema::Optional)
         .Output(2,
                 "present_value",
                 "present state value with support for format BNSH. When past_value uses same tensor as present_value"
                 "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
                 "kv_sequence_length.",
-                "T_CACHE")
+                "T_CACHE",
+                OpSchema::Optional)
         .Output(3,
                 "output_qk",
                 "Values of QK matrix multiplication, either before or after softmax normalization",
diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
index 0690094031bb8..1d57488d51363 100644
--- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
@@ -307,5 +307,108 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongLength) {
                 {}, nullptr, &execution_providers);
 }
 
+// ============================================================================
+// Optional present_key/present_value output tests
+// ============================================================================
+
+// Helper for tests with optional present outputs.
+// When omit_present=true, present_key and present_value outputs are not connected.
+static void RunGQAOptionalPresentTest(
+    int batch_size,
+    int sequence_length,
+    int total_seq_len,
+    bool omit_present,
+    OpTester::ExpectResult expect,
+    const std::string& expected_message) {
+  constexpr int num_heads = 2;
+  constexpr int kv_num_heads = 1;
+  constexpr int head_size = 8;
+  constexpr int hidden_size = num_heads * head_size;
+  constexpr int kv_hidden_size = kv_num_heads * head_size;
+
+  OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
+  tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
+  tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kv_num_heads));
+
+  std::vector<float> query_data(batch_size * sequence_length * hidden_size, 1.0f);
+  tester.AddInput<float>("query", {batch_size, sequence_length, hidden_size}, query_data);
+
+  std::vector<float> key_data(batch_size * sequence_length * kv_hidden_size, 0.5f);
+  tester.AddInput<float>("key", {batch_size, sequence_length, kv_hidden_size}, key_data);
+
+  std::vector<float> value_data(batch_size * sequence_length * kv_hidden_size, 0.5f);
+  tester.AddInput<float>("value", {batch_size, sequence_length, kv_hidden_size}, value_data);
+
+  tester.AddOptionalInputEdge<float>();  // past_key
+  tester.AddOptionalInputEdge<float>();  // past_value
+
+  std::vector<int32_t> seqlens_k_data(batch_size, static_cast<int32_t>(total_seq_len - 1));
+  tester.AddInput<int32_t>("seqlens_k", {batch_size}, seqlens_k_data);
+  tester.AddInput<int32_t>("total_sequence_length", {1}, {static_cast<int32_t>(total_seq_len)});
+
+  tester.AddOptionalInputEdge<float>();    // cos_cache
+  tester.AddOptionalInputEdge<float>();    // sin_cache
+  tester.AddOptionalInputEdge<int64_t>();  // position_ids
+  tester.AddOptionalInputEdge<float>();    // attention_bias
+  tester.AddOptionalInputEdge<float>();    // head_sink
+
+  // Output 0: output (always required)
+  tester.AddOutput<float>("output", {batch_size, sequence_length, hidden_size},
+                          std::vector<float>(batch_size * sequence_length * hidden_size, 0.0f));
+
+  if (omit_present) {
+    // Omit present_key and present_value; they are optional
+    tester.AddOptionalOutputEdge<float>();  // present_key
+    tester.AddOptionalOutputEdge<float>();  // present_value
+  } else {
+    int present_seq_len = total_seq_len;
+    tester.AddOutput<float>("present_key", {batch_size, kv_num_heads, present_seq_len, head_size},
+                            std::vector<float>(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f));
+    tester.AddOutput<float>("present_value", {batch_size, kv_num_heads, present_seq_len, head_size},
+                            std::vector<float>(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f));
+  }
+
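+  // The expected outputs above are zero-filled placeholders; with the large
+  // tolerance below these tests verify only that the kernel runs and produces
+  // correctly shaped outputs, not numerical accuracy.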
+  if (expect == OpTester::ExpectResult::kExpectSuccess) {
+    tester.SetOutputTolerance(1e6f);
+  }
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCpuExecutionProvider());
+  tester.Run(expect, expected_message, {}, nullptr, &execution_providers);
+}
+
+// Baseline: GQA with present outputs connected works as before
+TEST(GroupQueryAttentionTest, OptionalPresent_WithPresent) {
+  RunGQAOptionalPresentTest(
+      /*batch_size=*/1,
+      /*sequence_length=*/4,
+      /*total_seq_len=*/4,
+      /*omit_present=*/false,
+      OpTester::ExpectResult::kExpectSuccess,
+      "");
+}
+
+// KV-shared layer scenario: present outputs omitted, attention uses K,V directly
+TEST(GroupQueryAttentionTest, OptionalPresent_WithoutPresent) {
+  RunGQAOptionalPresentTest(
+      /*batch_size=*/1,
+      /*sequence_length=*/4,
+      /*total_seq_len=*/4,
+      /*omit_present=*/true,
+      OpTester::ExpectResult::kExpectSuccess,
+      "");
+}
+
+// Batched: present outputs omitted with batch_size > 1
+TEST(GroupQueryAttentionTest, OptionalPresent_Batched) {
+  RunGQAOptionalPresentTest(
+      /*batch_size=*/2,
+      /*sequence_length=*/3,
+      /*total_seq_len=*/3,
+      /*omit_present=*/true,
+      OpTester::ExpectResult::kExpectSuccess,
+      "");
+}
+
 }  // namespace test
 }  // namespace onnxruntime