diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 9aa44a1600ae6..45e85fcd9c9d5 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2671,14 +2671,14 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
Scale tensor for past_value.
-#### Outputs (3 - 4)
+#### Outputs (1 - 4)
- output : T
- 3D output tensor with shape (batch_size, sequence_length, hidden_size)
-- present_key : T_CACHE
+- present_key (optional) : T_CACHE
- present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
-- present_value : T_CACHE
+- present_value (optional) : T_CACHE
- present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
- output_qk (optional) : T
- Values of QK matrix multiplication, either before or after softmax normalization
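With the output arity relaxed to (1 - 4), a graph can drop the KV-cache outputs entirely. As a minimal sketch (assuming the ONNX protobuf C++ bindings; `MakeGQANodeWithoutPresent` is an illustrative helper, not part of this change), trailing optional outputs are simply omitted, while skipped optional *inputs* in the middle of the list use empty-string placeholders:

```cpp
#include "onnx/onnx_pb.h"

// Build a GroupQueryAttention node that requests only the first output.
// Required attributes (num_heads, kv_num_heads) are omitted here for brevity.
onnx::NodeProto MakeGQANodeWithoutPresent() {
  onnx::NodeProto node;
  node.set_op_type("GroupQueryAttention");
  node.set_domain("com.microsoft");
  // past_key/past_value (inputs 3 and 4) are skipped via "" placeholders.
  for (const char* in : {"query", "key", "value", "", "",
                         "seqlens_k", "total_sequence_length"}) {
    node.add_input(in);
  }
  node.add_output("output");  // present_key/present_value omitted entirely
  return node;
}
```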
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index adc7b623ec8c4..1f03cf9f105a2 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -85,7 +85,9 @@ class GQAAttentionBase {
if (past_key != nullptr && past_value != nullptr) {
seqlen_past_kv_cache = static_cast<int>(past_key->Shape().GetDims()[2]);
}
- int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);
+ int seqlen_present_kv_cache = present_key != nullptr
+ ? static_cast<int>(present_key->Shape().GetDims()[2])
+ : parameters.seqlen_present_kv_cache;
// Compute the attention score.
bool gqa_mlas_supported = MlasGQASupported(CblasNoTrans, CblasTrans) &&
@@ -175,7 +177,7 @@ class GQAAttentionBase {
const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H
- if (!past_present_share_buffer) {
+ if (present_key && !past_present_share_buffer) {
memset((void*)present_key,
0,
batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
@@ -402,7 +404,7 @@ class GQAAttentionBase {
const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H
- if (!past_present_share_buffer) {
+ if (present_value && !past_present_share_buffer) {
memset((void*)present_value,
0,
batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
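Both memset guards above follow the same shape: zero the present cache only when a distinct buffer actually exists. A standalone sketch of that pattern, with illustrative names rather than the ORT internals:

```cpp
#include <cstddef>
#include <cstring>

// Zero-initialize a "present" KV cache buffer only when it is safe to do so.
template <typename T>
void InitPresentCache(T* present, bool past_present_share_buffer,
                      size_t batch, size_t kv_heads, size_t seq, size_t head) {
  // `present` is null when the graph omits the optional present output
  // (the KV-shared-layer case); when past and present alias one buffer,
  // zeroing would wipe the cached past entries.
  if (present != nullptr && !past_present_share_buffer) {
    std::memset(present, 0, batch * kv_heads * seq * head * sizeof(T));
  }
}
```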
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index 5698bcb659f20..5ee2f31539bae 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -233,7 +233,9 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data<T>() : nullptr;
// Compute the attention score and apply the score to V
- return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
+ const T* k_data = packed_qkv ? nullptr : k_rotary;
+ const T* v_data = packed_qkv ? nullptr : V.Get<Tensor>().Data<T>();
+ return ApplyAttention(q_rotary, k_data, v_data,
head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v,
output_qk, seqlens_k, parameters, allocator, context);
}
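For packed QKV the key/value pointers are null because K and V live inside the query buffer itself: per the contrib-op docs, packed QKV has shape (batch_size, sequence_length, d) with d = (num_heads + 2 * kv_num_heads) * head_size. A rough sketch of that per-token layout (the view type and offsets below are illustrative assumptions, not the kernel's actual code):

```cpp
// Each packed token is laid out as [q | k | v], so K and V are recovered
// by offsetting from the Q pointer instead of separate input tensors.
template <typename T>
struct PackedQKVView {
  const T* base;  // packed (batch * seq, d_q + 2 * d_kv) buffer
  int d_q;        // num_heads * head_size
  int d_kv;       // kv_num_heads * head_size
  const T* q(int token) const { return base + static_cast<size_t>(token) * (d_q + 2 * d_kv); }
  const T* k(int token) const { return q(token) + d_q; }
  const T* v(int token) const { return k(token) + d_kv; }
};
```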
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
index f5399e307fbca..f65568700c0c9 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -242,7 +242,7 @@ Status CheckInputs(const T* query,
int q_hidden_size = 0;
int kv_hidden_size = 0;
int head_size = 0;
- const bool is_packed_qkv = key == nullptr;
+ const bool is_packed_qkv = (key == nullptr);
if (!is_packed_qkv) {
ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, kv_num_heads, batch_size, sequence_length,
q_hidden_size, kv_hidden_size, head_size));
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 5f21f3cd34e8f..9563292f9187c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -227,6 +227,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
attention_bias,
head_sink,
parameters));
+
parameters.local_window_size = local_window_size_;
parameters.is_unidirectional = is_unidirectional_;
parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr;
@@ -291,13 +292,10 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
-
- data.present_key = reinterpret_cast<CudaT*>(present_key_output->MutableData<T>());
- data.present_value = reinterpret_cast<CudaT*>(present_value_output->MutableData<T>());
-
+ data.present_key = (present_key_output != nullptr) ? reinterpret_cast<CudaT*>(present_key_output->MutableData<T>()) : nullptr;
+ data.present_value = (present_value_output != nullptr) ? reinterpret_cast<CudaT*>(present_value_output->MutableData<T>()) : nullptr;
// Compute past_present_share_buffer early since it's needed for flash attention path selection.
- // This compares the final pointer values after quantization handling.
- parameters.past_present_share_buffer = (data.past_key == data.present_key);
+ parameters.past_present_share_buffer = (data.past_key != nullptr && data.past_key == data.present_key);
bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE);
constexpr bool is_int8 = std::is_same::value;
@@ -562,10 +560,12 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
// Validate past_value pointer consistency (past_present_share_buffer was computed early after pointer setup)
- if (parameters.past_present_share_buffer) {
- ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true");
- } else {
- ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false");
+ if (data.present_value != nullptr) {
+ if (parameters.past_present_share_buffer) {
+ ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true");
+ } else {
+ ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false");
+ }
}
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
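The null guard added to `past_present_share_buffer` matters because two disconnected optional tensors both surface as null pointers, and `nullptr == nullptr` would previously report a buffer as shared when neither tensor exists. The corrected predicate, as a standalone sketch:

```cpp
// Pointer identity implies a shared past/present KV buffer only when the
// pointers are non-null; two absent tensors must not count as sharing.
inline bool SharesBuffer(const void* past, const void* present) {
  return past != nullptr && past == present;
}
```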
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index ebb6a0b0da215..3ce396989b181 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -98,6 +98,14 @@ Status PrepareQKV(
q_out = nullptr;
}
+ // present_key/present_value are required for the CUDA path since flash attention
+ // and memory-efficient attention read directly from the present KV buffers.
+ // The CPU path supports optional present outputs for KV-shared layers.
+ if (data.present_key == nullptr || data.present_value == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "present_key and present_value outputs are required for the CUDA GroupQueryAttention kernel.");
+ }
+
U* k = reinterpret_cast<U*>(data.present_key);
U* v = reinterpret_cast<U*>(data.present_value);
int max_cache_length = parameters.seqlen_present_kv_cache;
diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
index fd72f751ee810..5fff0516c7ce3 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -212,7 +212,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
scale_,
softcap_,
0,
- context.DeviceLimits().maxComputeInvocationsPerWorkgroup));
+ static_cast<int>(context.DeviceLimits().maxComputeInvocationsPerWorkgroup)));
params.use_smooth_softmax = use_smooth_softmax_;
params.rotary_interleaved = rotary_interleaved_;
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 1209446c6a367..e8ec04586a9d6 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1323,13 +1323,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"present state key with support for format BNSH. When past_key uses same tensor as present_key"
"(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
"kv_sequence_length.",
- "T_CACHE")
+ "T_CACHE",
+ OpSchema::Optional)
.Output(2,
"present_value",
"present state value with support for format BNSH. When past_value uses same tensor as present_value"
"(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
"kv_sequence_length.",
- "T_CACHE")
+ "T_CACHE",
+ OpSchema::Optional)
.Output(3,
"output_qk",
"Values of QK matrix multiplication, either before or after softmax normalization",
diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
index 0690094031bb8..1d57488d51363 100644
--- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
@@ -307,5 +307,108 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongLength) {
{}, nullptr, &execution_providers);
}
+// ============================================================================
+// Optional present_key/present_value output tests
+// ============================================================================
+
+// Helper for tests with optional present outputs.
+// When omit_present=true, present_key and present_value outputs are not connected.
+static void RunGQAOptionalPresentTest(
+ int batch_size,
+ int sequence_length,
+ int total_seq_len,
+ bool omit_present,
+ OpTester::ExpectResult expect,
+ const std::string& expected_message) {
+ constexpr int num_heads = 2;
+ constexpr int kv_num_heads = 1;
+ constexpr int head_size = 8;
+ constexpr int hidden_size = num_heads * head_size;
+ constexpr int kv_hidden_size = kv_num_heads * head_size;
+
+ OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
+ tester.AddAttribute("num_heads", static_cast(num_heads));
+ tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads));
+
+ std::vector<float> query_data(batch_size * sequence_length * hidden_size, 1.0f);
+ tester.AddInput<float>("query", {batch_size, sequence_length, hidden_size}, query_data);
+
+ std::vector<float> key_data(batch_size * sequence_length * kv_hidden_size, 0.5f);
+ tester.AddInput<float>("key", {batch_size, sequence_length, kv_hidden_size}, key_data);
+
+ std::vector<float> value_data(batch_size * sequence_length * kv_hidden_size, 0.5f);
+ tester.AddInput<float>("value", {batch_size, sequence_length, kv_hidden_size}, value_data);
+
+ tester.AddOptionalInputEdge<float>(); // past_key
+ tester.AddOptionalInputEdge<float>(); // past_value
+
+ std::vector<int32_t> seqlens_k_data(batch_size, static_cast<int32_t>(total_seq_len - 1));
+ tester.AddInput<int32_t>("seqlens_k", {batch_size}, seqlens_k_data);
+ tester.AddInput<int32_t>("total_sequence_length", {1}, {static_cast<int32_t>(total_seq_len)});
+
+ tester.AddOptionalInputEdge<float>(); // cos_cache
+ tester.AddOptionalInputEdge<float>(); // sin_cache
+ tester.AddOptionalInputEdge<int64_t>(); // position_ids
+ tester.AddOptionalInputEdge<float>(); // attention_bias
+ tester.AddOptionalInputEdge<float>(); // head_sink
+
+ // Output 0: output (always required)
+ tester.AddOutput("output", {batch_size, sequence_length, hidden_size},
+ std::vector(batch_size * sequence_length * hidden_size, 0.0f));
+
+ if (omit_present) {
+ // Omit present_key and present_value — they are optional
+ tester.AddOptionalOutputEdge<float>(); // present_key
+ tester.AddOptionalOutputEdge<float>(); // present_value
+ } else {
+ int present_seq_len = total_seq_len;
+ tester.AddOutput("present_key", {batch_size, kv_num_heads, present_seq_len, head_size},
+ std::vector(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f));
+ tester.AddOutput("present_value", {batch_size, kv_num_heads, present_seq_len, head_size},
+ std::vector(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f));
+ }
+
+ if (expect == OpTester::ExpectResult::kExpectSuccess) {
+ // The expected tensors above hold placeholder zeros; the extremely loose
+ // tolerance makes this a shape/success check rather than a numeric one.
+ tester.SetOutputTolerance(1e6f);
+ }
+
+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+ execution_providers.push_back(DefaultCpuExecutionProvider());
+ tester.Run(expect, expected_message, {}, nullptr, &execution_providers);
+}
+
+// Baseline: GQA with present outputs connected works as before
+TEST(GroupQueryAttentionTest, OptionalPresent_WithPresent) {
+ RunGQAOptionalPresentTest(
+ /*batch_size=*/1,
+ /*sequence_length=*/4,
+ /*total_seq_len=*/4,
+ /*omit_present=*/false,
+ OpTester::ExpectResult::kExpectSuccess,
+ "");
+}
+
+// KV-shared layer scenario: present outputs omitted, attention uses K,V directly
+TEST(GroupQueryAttentionTest, OptionalPresent_WithoutPresent) {
+ RunGQAOptionalPresentTest(
+ /*batch_size=*/1,
+ /*sequence_length=*/4,
+ /*total_seq_len=*/4,
+ /*omit_present=*/true,
+ OpTester::ExpectResult::kExpectSuccess,
+ "");
+}
+
+// Batched: present outputs omitted with batch_size > 1
+TEST(GroupQueryAttentionTest, OptionalPresent_Batched) {
+ RunGQAOptionalPresentTest(
+ /*batch_size=*/2,
+ /*sequence_length=*/3,
+ /*total_seq_len=*/3,
+ /*omit_present=*/true,
+ OpTester::ExpectResult::kExpectSuccess,
+ "");
+}
+
} // namespace test
} // namespace onnxruntime
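Assuming the standard `onnxruntime_test_all` test target, the new cases can be run in isolation with `--gtest_filter=GroupQueryAttentionTest.OptionalPresent_*`.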