6 changes: 3 additions & 3 deletions docs/ContribOperators.md
@@ -2671,14 +2671,14 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Scale tensor for past_value.</dd>
</dl>

#### Outputs (3 - 4)
#### Outputs (1 - 4)

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>present_key</tt> : T_CACHE</dt>
<dt><tt>present_key</tt> (optional) : T_CACHE</dt>
<dd>present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
<dt><tt>present_value</tt> : T_CACHE</dt>
<dt><tt>present_value</tt> (optional) : T_CACHE</dt>
<dd>present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
Comment on lines +2674 to 2682

Copilot AI Apr 29, 2026

The docs correctly mark present_key/present_value as optional, but the CUDA EP does not currently support omitting them. Add a brief note here (mandatory if the CUDA limitation remains) stating that some execution providers (e.g., CUDA) may require these outputs to be connected, so users aren't surprised by provider-specific INVALID_ARGUMENT errors.

<dt><tt>output_qk</tt> (optional) : T</dt>
<dd>Values of QK matrix multiplication, either before or after softmax normalization</dd>
8 changes: 5 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -85,7 +85,9 @@ class GQAAttentionBase {
if (past_key != nullptr && past_value != nullptr) {
seqlen_past_kv_cache = static_cast<int>(past_key->Shape().GetDims()[2]);
}
int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);
int seqlen_present_kv_cache = present_key != nullptr
? static_cast<int>(present_key->Shape().GetDims()[2])
: parameters.seqlen_present_kv_cache;

// Compute the attention score.
bool gqa_mlas_supported = MlasGQASupported<T>(CblasNoTrans, CblasTrans) &&
@@ -175,7 +177,7 @@ class GQAAttentionBase {
const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H
Contributor

When present_key/present_value are omitted, this guard prevents ConcatStateChunkGQA from ever building the past+current KV buffer, but the later GEMMs still use total_seqlen = seqlens_k[b] + 1 and seqlen_present_kv_cache as if that full buffer exists. For decoding/subsequent-prompt cases with past_key or total_sequence_length > sequence_length, k/v still point only at the current K/V input, so the CPU kernel can read past that tensor or attend over missing past tokens. Please either allocate an internal temporary present-KV buffer when these outputs are omitted, or reject omitted present outputs unless this is the first-prompt/no-past case where sequence_length == total_sequence_length.
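A minimal sketch of the second option suggested here (hypothetical placement and names, assuming past/present pointers and a parameters struct with sequence_length and total_sequence_length are in scope; not the actual patch):

```cpp
// Hypothetical guard: only allow omitted present outputs for the first-prompt / no-past
// case, where the current K/V inputs already cover the whole attention window and no
// past KV needs to be concatenated.
if (present_key == nullptr || present_value == nullptr) {
  const bool no_past = past_key == nullptr && past_value == nullptr &&
                       parameters.sequence_length == parameters.total_sequence_length;
  if (!no_past) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "present_key/present_value may only be omitted when there is no past KV, "
                           "i.e. sequence_length == total_sequence_length.");
  }
}
```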

const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H

if (!past_present_share_buffer) {
if (present_key && !past_present_share_buffer) {
memset((void*)present_key,
0,
batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
@@ -402,7 +404,7 @@ class GQAAttentionBase {
const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H

if (!past_present_share_buffer) {
if (present_value && !past_present_share_buffer) {
memset((void*)present_value,
0,
batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -233,7 +233,9 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data<T>() : nullptr;

// Compute the attention score and apply the score to V
return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
const T* k_data = packed_qkv ? nullptr : k_rotary;
const T* v_data = packed_qkv ? nullptr : V.Get<Tensor>().Data<T>();
return ApplyAttention(q_rotary, k_data, v_data,
head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v,
output_qk, seqlens_k, parameters, allocator, context);
}
@@ -242,7 +242,7 @@ Status CheckInputs(const T* query,
int q_hidden_size = 0;
int kv_hidden_size = 0;
int head_size = 0;
const bool is_packed_qkv = key == nullptr;
const bool is_packed_qkv = (key == nullptr);
if (!is_packed_qkv) {
ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, kv_num_heads, batch_size, sequence_length,
q_hidden_size, kv_hidden_size, head_size));
20 changes: 10 additions & 10 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -227,6 +227,7 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
attention_bias,
head_sink,
parameters));

parameters.local_window_size = local_window_size_;
parameters.is_unidirectional = is_unidirectional_;
parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr;
@@ -291,13 +292,10 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons

data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast<const CudaU*>(past_key->Data<U>());
data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaU*>(past_value->Data<U>());

data.present_key = reinterpret_cast<CudaU*>(present_key_output->MutableData<U>());
data.present_value = reinterpret_cast<CudaU*>(present_value_output->MutableData<U>());

data.present_key = (present_key_output != nullptr) ? reinterpret_cast<CudaU*>(present_key_output->MutableData<U>()) : nullptr;
data.present_value = (present_value_output != nullptr) ? reinterpret_cast<CudaU*>(present_value_output->MutableData<U>()) : nullptr;
// Compute past_present_share_buffer early since it's needed for flash attention path selection.
// This compares the final pointer values after quantization handling.
parameters.past_present_share_buffer = (data.past_key == data.present_key);
parameters.past_present_share_buffer = (data.past_key != nullptr && data.past_key == data.present_key);

bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE);
constexpr bool is_int8 = std::is_same<U, int8_t>::value;
@@ -562,10 +560,12 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
}

// Validate past_value pointer consistency (past_present_share_buffer was computed early after pointer setup)
if (parameters.past_present_share_buffer) {
ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true");
} else {
ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false");
if (data.present_value != nullptr) {
if (parameters.past_present_share_buffer) {
ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true");
} else {
ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false");
}
}

data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
@@ -98,6 +98,14 @@ Status PrepareQKV(
q_out = nullptr;
}

// present_key/present_value are required for the CUDA path since flash attention
// and memory-efficient attention read directly from the present KV buffers.
// The CPU path supports optional present outputs for KV-shared layers.
if (data.present_key == nullptr || data.present_value == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"present_key and present_value outputs are required for the CUDA GroupQueryAttention kernel.");
Comment on lines +101 to +106

Copilot AI Apr 29, 2026

The schema/docs now declare present_key/present_value as optional, but the CUDA EP hard-errors if they’re omitted. This makes behavior provider-dependent in a way that API consumers may not expect. Prefer (mandatory) supporting omitted present outputs on CUDA as well by allocating internal temporary present KV buffers (scratch) when the outputs are not requested, or by introducing a non-output internal KV storage path; if that’s not feasible, at least ensure the error message is explicitly actionable (e.g., indicates CUDA EP limitation and suggests connecting outputs or using CPU).

Suggested change
// present_key/present_value are required for the CUDA path since flash attention
// and memory-efficient attention read directly from the present KV buffers.
// The CPU path supports optional present outputs for KV-shared layers.
if (data.present_key == nullptr || data.present_value == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"present_key and present_value outputs are required for the CUDA GroupQueryAttention kernel.");
// present_key/present_value are currently required for the CUDA path since flash attention
// and memory-efficient attention read directly from the present KV buffers.
// Note: although the operator schema/docs may declare these outputs optional,
// the CUDA EP does not currently support omitting them.
if (data.present_key == nullptr || data.present_value == nullptr) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"CUDA GroupQueryAttention currently requires both present_key and present_value outputs to be connected "
"because the CUDA kernels read directly from the present KV buffers. "
"If your model relies on omitted optional present outputs, either connect both outputs when using the CUDA "
"execution provider or run this node on CPU.");

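For the first (preferred) alternative mentioned in this comment, a rough sketch of backing omitted outputs with scratch memory in GroupQueryAttention<T, U>::ComputeInternal, before data.present_key/data.present_value are consumed; the sizing and GetScratchBuffer usage are assumptions about the surrounding kernel code, not the actual implementation:

```cpp
// Hypothetical scratch-buffer fallback: if the graph does not request present_key/present_value,
// still give the CUDA kernels a full present-KV buffer to write into and read from by allocating
// temporary device memory for the lifetime of this kernel invocation.
IAllocatorUniquePtr<CudaU> present_key_scratch;
IAllocatorUniquePtr<CudaU> present_value_scratch;
if (present_key_output == nullptr || present_value_output == nullptr) {
  const size_t present_kv_elems = static_cast<size_t>(parameters.batch_size) *
                                  static_cast<size_t>(parameters.kv_num_heads) *
                                  static_cast<size_t>(parameters.seqlen_present_kv_cache) *
                                  static_cast<size_t>(parameters.head_size);
  present_key_scratch = GetScratchBuffer<CudaU>(present_kv_elems, context->GetComputeStream());
  present_value_scratch = GetScratchBuffer<CudaU>(present_kv_elems, context->GetComputeStream());
  data.present_key = present_key_scratch.get();
  data.present_value = present_value_scratch.get();
}
```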
}

U* k = reinterpret_cast<U*>(data.present_key);
U* v = reinterpret_cast<U*>(data.present_value);
int max_cache_length = parameters.seqlen_present_kv_cache;
@@ -212,7 +212,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
scale_,
softcap_,
0,
context.DeviceLimits().maxComputeInvocationsPerWorkgroup));
static_cast<int>(context.DeviceLimits().maxComputeInvocationsPerWorkgroup)));
params.use_smooth_softmax = use_smooth_softmax_;
params.rotary_interleaved = rotary_interleaved_;

6 changes: 4 additions & 2 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1323,13 +1323,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"present state key with support for format BNSH. When past_key uses same tensor as present_key"
"(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
"kv_sequence_length.",
"T_CACHE")
"T_CACHE",
OpSchema::Optional)
.Output(2,
"present_value",
"present state value with support for format BNSH. When past_value uses same tensor as present_value"
"(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
"kv_sequence_length.",
"T_CACHE")
"T_CACHE",
OpSchema::Optional)
.Output(3,
"output_qk",
"Values of QK matrix multiplication, either before or after softmax normalization",
103 changes: 103 additions & 0 deletions onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
@@ -307,5 +307,108 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongLength) {
{}, nullptr, &execution_providers);
}

// ============================================================================
// Optional present_key/present_value output tests
// ============================================================================

// Helper for tests with optional present outputs.
// When omit_present=true, present_key and present_value outputs are not connected.
static void RunGQAOptionalPresentTest(
int batch_size,
int sequence_length,
int total_seq_len,
bool omit_present,
OpTester::ExpectResult expect,
const std::string& expected_message) {
constexpr int num_heads = 2;
constexpr int kv_num_heads = 1;
constexpr int head_size = 8;
constexpr int hidden_size = num_heads * head_size;
constexpr int kv_hidden_size = kv_num_heads * head_size;

OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kv_num_heads));

std::vector<float> query_data(batch_size * sequence_length * hidden_size, 1.0f);
tester.AddInput<float>("query", {batch_size, sequence_length, hidden_size}, query_data);

std::vector<float> key_data(batch_size * sequence_length * kv_hidden_size, 0.5f);
tester.AddInput<float>("key", {batch_size, sequence_length, kv_hidden_size}, key_data);

std::vector<float> value_data(batch_size * sequence_length * kv_hidden_size, 0.5f);
tester.AddInput<float>("value", {batch_size, sequence_length, kv_hidden_size}, value_data);

tester.AddOptionalInputEdge<float>(); // past_key
tester.AddOptionalInputEdge<float>(); // past_value

std::vector<int32_t> seqlens_k_data(batch_size, static_cast<int32_t>(total_seq_len - 1));
tester.AddInput<int32_t>("seqlens_k", {batch_size}, seqlens_k_data);
tester.AddInput<int32_t>("total_sequence_length", {1}, {static_cast<int32_t>(total_seq_len)});

tester.AddOptionalInputEdge<float>(); // cos_cache
tester.AddOptionalInputEdge<float>(); // sin_cache
tester.AddOptionalInputEdge<int64_t>(); // position_ids
tester.AddOptionalInputEdge<float>(); // attention_bias
tester.AddOptionalInputEdge<float>(); // head_sink

// Output 0: output (always required)
tester.AddOutput<float>("output", {batch_size, sequence_length, hidden_size},
std::vector<float>(batch_size * sequence_length * hidden_size, 0.0f));

if (omit_present) {
// Omit present_key and present_value — they are optional
tester.AddOptionalOutputEdge<float>(); // present_key
tester.AddOptionalOutputEdge<float>(); // present_value
} else {
int present_seq_len = total_seq_len;
tester.AddOutput<float>("present_key", {batch_size, kv_num_heads, present_seq_len, head_size},
std::vector<float>(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f));
tester.AddOutput<float>("present_value", {batch_size, kv_num_heads, present_seq_len, head_size},
std::vector<float>(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f));
}

if (expect == OpTester::ExpectResult::kExpectSuccess) {
tester.SetOutputTolerance(1e6f);
}
Comment on lines +345 to +373

Copilot AI Apr 29, 2026

These tests effectively only verify the kernel doesn’t fail: the expected output is all zeros and the tolerance is set to 1e6f, which is likely large enough to pass regardless of correctness. Also seqlens_k is set to total_seq_len - 1, which is an unusual/possibly incorrect pairing with total_sequence_length and may reduce the chance of catching length/edge bugs. Consider (mandatory) asserting something meaningful (e.g., run once with present connected and once omitted and compare the resulting output, or compute a small expected output for a deterministic tiny case) and (optional) set seqlens_k to a consistent value like total_seq_len for the no-past scenario.
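One way to make the success cases assert real numerics, reusing the helper's own names (a sketch, assuming rotary and attention_bias remain disabled as in the helper): with query set to 1.0 and key/value to 0.5, every attended V row is identical, so the attention output is 0.5 in every element regardless of the softmax weights.

```cpp
// Hypothetical tightening of the helper above: replace the zero-filled expected output
// and the 1e6f tolerance with the analytically known result for this constant-input case.
std::vector<float> expected_output(batch_size * sequence_length * hidden_size, 0.5f);
tester.AddOutput<float>("output", {batch_size, sequence_length, hidden_size}, expected_output);
tester.SetOutputTolerance(1e-3f);  // tight enough to catch incorrect attention math
```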


std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
tester.Run(expect, expected_message, {}, nullptr, &execution_providers);
}

// Baseline: GQA with present outputs connected works as before
TEST(GroupQueryAttentionTest, OptionalPresent_WithPresent) {
RunGQAOptionalPresentTest(
/*batch_size=*/1,
/*sequence_length=*/4,
/*total_seq_len=*/4,
/*omit_present=*/false,
OpTester::ExpectResult::kExpectSuccess,
"");
}

// KV-shared layer scenario: present outputs omitted, attention uses K,V directly
TEST(GroupQueryAttentionTest, OptionalPresent_WithoutPresent) {
RunGQAOptionalPresentTest(
/*batch_size=*/1,
/*sequence_length=*/4,
/*total_seq_len=*/4,
/*omit_present=*/true,
OpTester::ExpectResult::kExpectSuccess,
"");
}

// Batched: present outputs omitted with batch_size > 1
TEST(GroupQueryAttentionTest, OptionalPresent_Batched) {
RunGQAOptionalPresentTest(
/*batch_size=*/2,
/*sequence_length=*/3,
/*total_seq_len=*/3,
/*omit_present=*/true,
OpTester::ExpectResult::kExpectSuccess,
"");
}

} // namespace test
} // namespace onnxruntime