Add external_key/external_value inputs to GroupQueryAttention for KV-shared layers #28242
apsonawane wants to merge 10 commits into main
Conversation
Pull request overview
Adds support for KV-shared decoder layers by allowing com.microsoft.GroupQueryAttention to optionally consume pre-computed external K/V tensors (instead of maintaining/updating its own KV cache), enabling architectures like Gemma4-style KV sharing.
Changes:
- Extended the GroupQueryAttention schema with optional inputs `external_key`/`external_value` (indices 14/15).
- Added new parameters and validation helpers to detect/configure "external KV" mode and enforce `do_rotary=0`.
- Updated CPU and CUDA kernels to source KV from external tensors and bypass the KV-cache update / RoPE-on-KV paths.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| onnxruntime/core/graph/contrib_ops/bert_defs.cc | Adds schema inputs for external KV (but type/shape inference also needs external-KV awareness). |
| onnxruntime/contrib_ops/cpu/bert/attention_parameters.h | Adds use_external_kv and external_kv_sequence_length to GQA parameters. |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h | Adds Q-only checks and external-KV shape validation/configuration helpers; updates CheckInputs to distinguish packed-QKV vs Q-only. |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | Plumbs external KV inputs into CPU kernel and skips K/V transpose + rotary when external KV is used. |
| onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h | Updates CPU attention core to skip KV concatenation and copy external KV to present outputs once per KV head. |
| onnxruntime/contrib_ops/cuda/bert/attention_data.h | Adds external KV pointers to CUDA attention data struct. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | Plumbs external KV inputs into CUDA kernel and enforces do_rotary=0 for external KV mode. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu | Adds external-KV path in PrepareQKV (copy external KV to present and skip append/RoPE). |
tianleiwu left a comment
There is no need to add extra inputs; you can use key/value for that and make past_key/past_value/present_key/present_value optional.
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
```cpp
std::vector<int32_t> seqlens_k_data(batch_size, static_cast<int32_t>(total_seq_len - 1));
tester.AddInput<int32_t>("seqlens_k", {batch_size}, seqlens_k_data);
tester.AddInput<int32_t>("total_sequence_length", {1}, {static_cast<int32_t>(total_seq_len)});

tester.AddOptionalInputEdge<float>();    // cos_cache
tester.AddOptionalInputEdge<float>();    // sin_cache
tester.AddOptionalInputEdge<int64_t>();  // position_ids
tester.AddOptionalInputEdge<float>();    // attention_bias
tester.AddOptionalInputEdge<float>();    // head_sink

// Output 0: output (always required)
tester.AddOutput<float>("output", {batch_size, sequence_length, hidden_size},
                        std::vector<float>(batch_size * sequence_length * hidden_size, 0.0f));

if (omit_present) {
  // Omit present_key and present_value - they are optional
  tester.AddOptionalOutputEdge<float>();  // present_key
  tester.AddOptionalOutputEdge<float>();  // present_value
} else {
  int present_seq_len = total_seq_len;
  tester.AddOutput<float>("present_key", {batch_size, kv_num_heads, present_seq_len, head_size},
                          std::vector<float>(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f));
  tester.AddOutput<float>("present_value", {batch_size, kv_num_heads, present_seq_len, head_size},
                          std::vector<float>(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f));
}

if (expect == OpTester::ExpectResult::kExpectSuccess) {
  tester.SetOutputTolerance(1e6f);
}
```
These tests effectively only verify that the kernel doesn't fail: the expected output is all zeros and the tolerance is set to 1e6f, which is likely large enough to pass regardless of correctness. Also, seqlens_k is set to total_seq_len - 1, an unusual and possibly incorrect pairing with total_sequence_length that may reduce the chance of catching length/edge bugs. Consider (mandatory) asserting something meaningful (e.g., run once with present connected and once omitted and compare the resulting outputs, or compute a small expected output for a deterministic tiny case), and (optional) setting seqlens_k to a consistent value such as total_seq_len for the no-past scenario.
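For the deterministic tiny case, a hypothetical reference helper (not part of this PR; names and layout assumptions are illustrative) could produce a non-trivial expected output: with sequence_length == 1, total_seq_len == 1, no rotary, and no attention bias, softmax over the single logit is 1, so each query head's output is exactly the V row of the KV head it maps to.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of this PR. Assumes V is laid out as
// (batch, kv_num_heads, 1, head_size) and the op output as
// (batch, 1, num_heads * head_size). With one query token attending to one KV
// position, the attention weight is 1, so the output equals V per head group.
std::vector<float> ExpectedSingleTokenOutput(const std::vector<float>& v,
                                             int batch, int num_heads,
                                             int kv_num_heads, int head_size) {
  std::vector<float> expected(static_cast<size_t>(batch) * num_heads * head_size);
  const int heads_per_kv = num_heads / kv_num_heads;
  for (int b = 0; b < batch; ++b) {
    for (int h = 0; h < num_heads; ++h) {
      const int kv_h = h / heads_per_kv;  // query head -> shared KV head
      const float* v_row = v.data() + (static_cast<size_t>(b) * kv_num_heads + kv_h) * head_size;
      float* out_row = expected.data() + (static_cast<size_t>(b) * num_heads + h) * head_size;
      for (int d = 0; d < head_size; ++d) out_row[d] = v_row[d];
    }
  }
  return expected;
}
```

The test could then feed `expected` to `tester.AddOutput<float>("output", ...)` with a small tolerance instead of all zeros with `1e6f`.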
```cpp
// present_key/present_value are required for the CUDA path since flash attention
// and memory-efficient attention read directly from the present KV buffers.
// The CPU path supports optional present outputs for KV-shared layers.
if (data.present_key == nullptr || data.present_value == nullptr) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "present_key and present_value outputs are required for the CUDA GroupQueryAttention kernel.");
```
The schema/docs now declare present_key/present_value as optional, but the CUDA EP hard-errors if they're omitted. This makes behavior provider-dependent in a way that API consumers may not expect. Prefer (mandatory) supporting omitted present outputs on CUDA as well, either by allocating internal temporary present KV buffers (scratch) when the outputs are not requested or by introducing a non-output internal KV storage path (see the sketch after the suggested change below); if that's not feasible, at least make the error message explicitly actionable (e.g., indicate the CUDA EP limitation and suggest connecting the outputs or using CPU).
Suggested change:

```diff
-// present_key/present_value are required for the CUDA path since flash attention
-// and memory-efficient attention read directly from the present KV buffers.
-// The CPU path supports optional present outputs for KV-shared layers.
-if (data.present_key == nullptr || data.present_value == nullptr) {
-  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                         "present_key and present_value outputs are required for the CUDA GroupQueryAttention kernel.");
+// present_key/present_value are currently required for the CUDA path since flash attention
+// and memory-efficient attention read directly from the present KV buffers.
+// Note: although the operator schema/docs may declare these outputs optional,
+// the CUDA EP does not currently support omitting them.
+if (data.present_key == nullptr || data.present_value == nullptr) {
+  return ORT_MAKE_STATUS(
+      ONNXRUNTIME, INVALID_ARGUMENT,
+      "CUDA GroupQueryAttention currently requires both present_key and present_value outputs to be connected "
+      "because the CUDA kernels read directly from the present KV buffers. "
+      "If your model relies on omitted optional present outputs, either connect both outputs when using the CUDA "
+      "execution provider or run this node on CPU.");
```
```diff
 #### Outputs (1 - 4)

 <dl>
 <dt><tt>output</tt> : T</dt>
 <dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
-<dt><tt>present_key</tt> : T_CACHE</dt>
+<dt><tt>present_key</tt> (optional) : T_CACHE</dt>
 <dd>present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
-<dt><tt>present_value</tt> : T_CACHE</dt>
+<dt><tt>present_value</tt> (optional) : T_CACHE</dt>
 <dd>present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
```
Docs correctly mark present_key/present_value as optional, but (per the CUDA implementation) omitting them is not currently supported by the CUDA EP. Add a brief note here (mandatory if the CUDA limitation remains) stating that some execution providers (CUDA) may require these outputs to be connected, so users aren't surprised by provider-specific INVALID_ARGUMENT errors.
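A possible wording for such a note, in the same `<dd>` style as the quoted docs above (illustrative only, not the PR's text):

```html
<dd>Note: some execution providers (e.g. the CUDA execution provider) currently require
present_key and present_value to be connected even though they are declared optional;
omitting them on those providers results in an INVALID_ARGUMENT error.</dd>
```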
Summary
Adds optional `external_key` and `external_value` inputs (inputs 14/15) to the `com.microsoft.GroupQueryAttention` op, enabling KV-shared layers where decoder layers borrow pre-computed K,V tensors from earlier "source" layers instead of maintaining their own KV cache. This is needed for architectures like Gemma4 E2B, which has 35 decoder layers: 15 with independent KV projections and 20 KV-shared layers that reuse K,V from the last matching source layer.
Changes
Schema (`bert_defs.cc`)
- Adds optional inputs 14 (`external_key`) and 15 (`external_value`) in BNSH format
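A minimal sketch of what the new schema entries could look like; the descriptions and the type-constraint name below are assumptions, not the verbatim `bert_defs.cc` change:

```cpp
// Sketch only: two trailing optional inputs carrying pre-computed KV in BNSH layout.
.Input(14, "external_key",
       "Pre-computed key from a KV-source layer, shape (batch_size, kv_num_heads, "
       "external_kv_sequence_length, head_size). Optional; enables KV-shared layers.",
       "T_CACHE", OpSchema::Optional)
.Input(15, "external_value",
       "Pre-computed value from a KV-source layer, same shape as external_key. Optional.",
       "T_CACHE", OpSchema::Optional)
```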
Parameters (`attention_parameters.h`)
- Adds `use_external_kv` flag and `external_kv_sequence_length` to `GroupQueryAttentionParameters`
Input validation (`group_query_attention_helper.h`)
- `Check_Q_Only()`: validates Q-only input when key/value are nullptr with external KV
- `CheckExternalKV()`: validates external KV tensor shapes (4D BNSH, batch/heads match)
- `CheckAndSetExternalKV()`: sets parameters and validates that `total_sequence_length` doesn't exceed the external KV length
- `CheckInputs`: a `has_external_kv` parameter distinguishes Q-only from packed QKV
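A rough sketch of the kind of check `CheckExternalKV()` performs; the exact signature, parameter list, and messages in the PR may differ:

```cpp
// Sketch only: validate that external_key/external_value are 4D BNSH tensors whose
// batch size, KV head count, and head size match the query-derived parameters.
Status CheckExternalKV(const Tensor* external_key, const Tensor* external_value,
                       int batch_size, int kv_num_heads, int head_size) {
  const auto& key_shape = external_key->Shape();
  if (key_shape.NumDimensions() != 4) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "external_key must be 4D with shape (batch_size, kv_num_heads, "
                           "external_kv_sequence_length, head_size).");
  }
  if (key_shape != external_value->Shape()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "external_key and external_value must have the same shape.");
  }
  if (key_shape[0] != batch_size || key_shape[1] != kv_num_heads || key_shape[3] != head_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "external_key batch size, kv_num_heads, or head_size does not match the query.");
  }
  return Status::OK();
}
```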
CPU kernel (`group_query_attention.cc`, `gqa_attention_base.h`)
- Reads `external_key`/`external_value` from inputs 14/15
- Selects the KV source via `effective_past_key`/`effective_past_value`
- `ComputeAttentionProbs`/`ComputeVxAttentionScore`: skip `ConcatStateChunkGQA`, use external KV directly, and copy it to present once per KV head (not per Q head)
- Enforces `do_rotary=0` with a clear error message
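Conceptually, the CPU change replaces the per-head cache append with a direct copy of the shared KV into the present output, done once per KV head. A simplified, self-contained sketch of that copy (an illustrative helper, not the PR's `gqa_attention_base.h` code; names and layouts are assumptions):

```cpp
#include <cstddef>
#include <cstring>

// Sketch only: copy one KV head's external key/value block into the matching slot of
// the present buffer. In the PR this happens once per KV head instead of calling
// ConcatStateChunkGQA, and the attention math reads the external KV directly.
template <typename T>
void CopyExternalKvHeadToPresent(const T* external_kv,  // (batch, kv_num_heads, ext_seq, head)
                                 T* present,            // (batch, kv_num_heads, present_seq, head)
                                 int b, int n, int kv_num_heads,
                                 int external_kv_sequence_length,
                                 int present_buffer_sequence_length, int head_size) {
  const size_t src_offset = (static_cast<size_t>(b) * kv_num_heads + n) *
                            external_kv_sequence_length * head_size;
  const size_t dst_offset = (static_cast<size_t>(b) * kv_num_heads + n) *
                            present_buffer_sequence_length * head_size;
  std::memcpy(present + dst_offset, external_kv + src_offset,
              static_cast<size_t>(external_kv_sequence_length) * head_size * sizeof(T));
}
```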
CUDA kernel (`group_query_attention.cc`, `group_query_attention_impl.cu`)
- Reads `external_key`/`external_value` from inputs 14/15
- Sets `past_present_share_buffer=false` for external KV
- `PrepareQKV`: copies external KV into the present buffers, skips KV append/RoPE
- Enforces `do_rotary=0` with a clear error message
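On CUDA, the external-KV branch in `PrepareQKV` amounts to a device-to-device copy into the present buffers with the append/RoPE steps skipped. Roughly, under the assumption that the parameter/field names below exist as written (again a sketch, not the PR's exact code):

```cpp
// Sketch only: copy external KV (BNSH) straight into the present buffers and skip
// the KV append and rotary-embedding steps for K/V; Q is processed as usual.
if (parameters.use_external_kv) {
  const size_t kv_bytes = static_cast<size_t>(parameters.batch_size) *
                          parameters.kv_num_heads *
                          parameters.external_kv_sequence_length *
                          parameters.head_size * sizeof(CudaT);
  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(data.present_key, data.external_key, kv_bytes,
                                       cudaMemcpyDeviceToDevice, stream));
  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(data.present_value, data.external_value, kv_bytes,
                                       cudaMemcpyDeviceToDevice, stream));
}
```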
CUDA data (`attention_data.h`)
- Adds `external_key`/`external_value` pointers to `GroupQueryAttentionData`

Backward compatibility
All changes are gated behind `use_external_kv == true`, which is only set when optional inputs 14/15 are provided. No existing model provides these inputs. Zero impact on existing models.

Related