
Add external_key/external_value inputs to GroupQueryAttention for KV-shared layers #28242

Open
apsonawane wants to merge 10 commits into main from asonawane/gemma4

Conversation

@apsonawane
Contributor

Summary

Adds optional external_key and external_value inputs (inputs 14/15) to the com.microsoft.GroupQueryAttention op, enabling KV-shared layers where decoder layers borrow pre-computed K,V tensors from earlier "source" layers instead of maintaining their own KV cache.

This is needed for architectures like Gemma4 E2B, which has 35 decoder layers — 15 with independent KV projections and 20 KV-shared layers that reuse K,V from the last matching source layer.
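To make the layer-sharing idea concrete, here is a small hypothetical sketch of how a model loader might map each decoder layer to the layer whose KV cache it reads. The function name and the "reuse the last cache-owning layer" policy are illustrative assumptions, not the actual Gemma4 mapping, which may select source layers differently (e.g. by attention type):

```cpp
#include <vector>

// Hypothetical mapping: layers [0, num_kv_layers) own their KV cache (they map
// to themselves); every later layer borrows K,V from the last cache-owning
// layer. Illustrates the "KV-shared layer" concept only.
std::vector<int> KvSourceLayers(int num_layers, int num_kv_layers) {
  std::vector<int> source(num_layers);
  for (int i = 0; i < num_layers; ++i) {
    source[i] = (i < num_kv_layers) ? i : (num_kv_layers - 1);
  }
  return source;
}
```

With 35 layers and 15 KV-owning layers, layers 15..34 would all read the cache of layer 14 under this simple policy.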

Changes

Schema (bert_defs.cc)

  • Added optional inputs 14 (external_key) and 15 (external_value) in BNSH format

Parameters (attention_parameters.h)

  • Added use_external_kv flag and external_kv_sequence_length to GroupQueryAttentionParameters

Input validation (group_query_attention_helper.h)

  • Check_Q_Only() — validates Q-only input when key/value are nullptr with external KV
  • CheckExternalKV() — validates external KV tensor shapes (4D BNSH, batch/heads match)
  • CheckAndSetExternalKV() — sets parameters and validates total_sequence_length doesn't exceed external KV length
  • CheckInputs() — new has_external_kv parameter distinguishes Q-only from packed QKV inputs
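The shape rules that CheckExternalKV() enforces can be sketched roughly as follows. This is an illustrative stand-in, not the actual helper: the function name and exact rule set are assumptions based on the bullet above (4D BNSH, batch/heads must match):

```cpp
#include <array>
#include <cstdint>

// Hedged sketch of an external-KV shape check: the tensor must be 4D in BNSH
// layout, agreeing with the query on batch size (B), with the op attribute on
// kv_num_heads (N), and with the head size (H). S is the external KV sequence
// length and only needs to be positive here.
bool ValidExternalKvShape(const std::array<int64_t, 4>& kv_dims,
                          int64_t batch_size, int64_t kv_num_heads,
                          int64_t head_size) {
  return kv_dims[0] == batch_size &&    // B
         kv_dims[1] == kv_num_heads &&  // N (KV heads)
         kv_dims[2] >= 1 &&             // S (external KV sequence length)
         kv_dims[3] == head_size;       // H
}
```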

CPU kernel (group_query_attention.cc, gqa_attention_base.h)

  • Reads external_key/external_value from inputs 14/15
  • Skips K/V transpose and RoPE for external KV mode
  • Uses external KV as effective_past_key/effective_past_value
  • ComputeAttentionProbs/ComputeVxAttentionScore: skips ConcatStateChunkGQA, uses external KV directly, copies to present once per KV head (not per Q head)
  • Enforces do_rotary=0 with clear error message

CUDA kernel (group_query_attention.cc, group_query_attention_impl.cu)

  • Reads external_key/external_value from inputs 14/15
  • Sets data pointers and past_present_share_buffer=false for external KV
  • PrepareQKV: copies external KV into present buffers, skips KV append/RoPE
  • Enforces do_rotary=0 with clear error message

CUDA data (attention_data.h)

  • Added external_key/external_value pointers to GroupQueryAttentionData

Backward compatibility

All changes are gated behind use_external_kv == true which is only set when optional inputs 14/15 are provided. No existing model provides these inputs. Zero impact on existing models.
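The gating described above amounts to a presence check on the two new optional inputs. A minimal sketch (illustrative struct and names, not the literal kernel code):

```cpp
// External-KV mode is enabled only when BOTH optional inputs 14/15 are
// connected; existing graphs never provide them, so behavior is unchanged.
struct OptionalKvInputs {
  const void* external_key = nullptr;    // input 14, if connected
  const void* external_value = nullptr;  // input 15, if connected
};

bool UseExternalKv(const OptionalKvInputs& in) {
  return in.external_key != nullptr && in.external_value != nullptr;
}
```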

Contributor

@github-actions bot left a comment


You can commit the suggested changes from lintrunner.

@apsonawane requested a review from Copilot April 27, 2026 21:05
Contributor

Copilot AI left a comment


Pull request overview

Adds support for KV-shared decoder layers by allowing com.microsoft.GroupQueryAttention to optionally consume pre-computed external K/V tensors (instead of maintaining/updating its own KV cache), enabling architectures like Gemma4-style KV sharing.

Changes:

  • Extended the GroupQueryAttention schema with optional inputs external_key / external_value (indices 14/15).
  • Added new parameters + validation helpers to detect/configure “external KV” mode and enforce do_rotary=0.
  • Updated CPU and CUDA kernels to source KV from external tensors and bypass KV-cache update / RoPE-on-KV paths.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 9 comments.

Summary per file:

  • onnxruntime/core/graph/contrib_ops/bert_defs.cc — Adds schema inputs for external KV (but type/shape inference also needs external-KV awareness).
  • onnxruntime/contrib_ops/cpu/bert/attention_parameters.h — Adds use_external_kv and external_kv_sequence_length to GQA parameters.
  • onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h — Adds Q-only checks and external-KV shape validation/configuration helpers; updates CheckInputs to distinguish packed-QKV vs Q-only.
  • onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc — Plumbs external KV inputs into CPU kernel and skips K/V transpose + rotary when external KV is used.
  • onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h — Updates CPU attention core to skip KV concatenation and copy external KV to present outputs once per KV head.
  • onnxruntime/contrib_ops/cuda/bert/attention_data.h — Adds external KV pointers to CUDA attention data struct.
  • onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc — Plumbs external KV inputs into CUDA kernel and enforces do_rotary=0 for external KV mode.
  • onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu — Adds external-KV path in PrepareQKV (copy external KV to present and skip append/RoPE).


Contributor

@tianleiwu left a comment


There is no need to add extra inputs; you can use key/value for that, and make past_key/past_value/present_key/present_value optional.

@apsonawane apsonawane requested review from Copilot and tianleiwu April 29, 2026 17:09
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.



Comment on lines +345 to +373
std::vector<int32_t> seqlens_k_data(batch_size, static_cast<int32_t>(total_seq_len - 1));
tester.AddInput<int32_t>("seqlens_k", {batch_size}, seqlens_k_data);
tester.AddInput<int32_t>("total_sequence_length", {1}, {static_cast<int32_t>(total_seq_len)});

tester.AddOptionalInputEdge<float>();    // cos_cache
tester.AddOptionalInputEdge<float>();    // sin_cache
tester.AddOptionalInputEdge<int64_t>();  // position_ids
tester.AddOptionalInputEdge<float>();    // attention_bias
tester.AddOptionalInputEdge<float>();    // head_sink

// Output 0: output (always required)
tester.AddOutput<float>("output", {batch_size, sequence_length, hidden_size},
                        std::vector<float>(batch_size * sequence_length * hidden_size, 0.0f));

if (omit_present) {
  // Omit present_key and present_value — they are optional
  tester.AddOptionalOutputEdge<float>();  // present_key
  tester.AddOptionalOutputEdge<float>();  // present_value
} else {
  int present_seq_len = total_seq_len;
  tester.AddOutput<float>("present_key", {batch_size, kv_num_heads, present_seq_len, head_size},
                          std::vector<float>(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f));
  tester.AddOutput<float>("present_value", {batch_size, kv_num_heads, present_seq_len, head_size},
                          std::vector<float>(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f));
}

if (expect == OpTester::ExpectResult::kExpectSuccess) {
  tester.SetOutputTolerance(1e6f);
}

Copilot AI Apr 29, 2026


These tests effectively only verify the kernel doesn’t fail: the expected output is all zeros and the tolerance is set to 1e6f, which is likely large enough to pass regardless of correctness. Also seqlens_k is set to total_seq_len - 1, which is an unusual/possibly incorrect pairing with total_sequence_length and may reduce the chance of catching length/edge bugs. Consider (mandatory) asserting something meaningful (e.g., run once with present connected and once omitted and compare the resulting output, or compute a small expected output for a deterministic tiny case) and (optional) set seqlens_k to a consistent value like total_seq_len for the no-past scenario.

Comment on lines +101 to +106
// present_key/present_value are required for the CUDA path since flash attention
// and memory-efficient attention read directly from the present KV buffers.
// The CPU path supports optional present outputs for KV-shared layers.
if (data.present_key == nullptr || data.present_value == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"present_key and present_value outputs are required for the CUDA GroupQueryAttention kernel.");

Copilot AI Apr 29, 2026


The schema/docs now declare present_key/present_value as optional, but the CUDA EP hard-errors if they’re omitted. This makes behavior provider-dependent in a way that API consumers may not expect. Prefer (mandatory) supporting omitted present outputs on CUDA as well by allocating internal temporary present KV buffers (scratch) when the outputs are not requested, or by introducing a non-output internal KV storage path; if that’s not feasible, at least ensure the error message is explicitly actionable (e.g., indicates CUDA EP limitation and suggests connecting outputs or using CPU).

Suggested change
- // present_key/present_value are required for the CUDA path since flash attention
- // and memory-efficient attention read directly from the present KV buffers.
- // The CPU path supports optional present outputs for KV-shared layers.
- if (data.present_key == nullptr || data.present_value == nullptr) {
-   return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                          "present_key and present_value outputs are required for the CUDA GroupQueryAttention kernel.");
+ // present_key/present_value are currently required for the CUDA path since flash attention
+ // and memory-efficient attention read directly from the present KV buffers.
+ // Note: although the operator schema/docs may declare these outputs optional,
+ // the CUDA EP does not currently support omitting them.
+ if (data.present_key == nullptr || data.present_value == nullptr) {
+   return ORT_MAKE_STATUS(
+       ONNXRUNTIME, INVALID_ARGUMENT,
+       "CUDA GroupQueryAttention currently requires both present_key and present_value outputs to be connected "
+       "because the CUDA kernels read directly from the present KV buffers. "
+       "If your model relies on omitted optional present outputs, either connect both outputs when using the CUDA "
+       "execution provider or run this node on CPU.");

Comment thread docs/ContribOperators.md
Comment on lines +2674 to 2682
#### Outputs (1 - 4)

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
- <dt><tt>present_key</tt> : T_CACHE</dt>
+ <dt><tt>present_key</tt> (optional) : T_CACHE</dt>
<dd>present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
- <dt><tt>present_value</tt> : T_CACHE</dt>
+ <dt><tt>present_value</tt> (optional) : T_CACHE</dt>
<dd>present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>

Copilot AI Apr 29, 2026


Docs correctly mark present_key/present_value as optional, but (per CUDA implementation) omitting them is not currently supported by the CUDA EP. Add a brief note here (mandatory if CUDA limitation remains) stating that some execution providers (CUDA) may require these outputs to be connected, so users aren’t surprised by provider-specific INVALID_ARGUMENT errors.
