From 0db957875ef7e06f9e94abe5b0285857327084c2 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Fri, 24 Apr 2026 17:37:30 -0700 Subject: [PATCH 1/9] Update GQA to support Gemma4 --- .../cpu/bert/attention_parameters.h | 4 + .../contrib_ops/cpu/bert/gqa_attention_base.h | 39 ++++-- .../cpu/bert/group_query_attention.cc | 28 ++++- .../cpu/bert/group_query_attention_helper.h | 114 +++++++++++++++++- .../contrib_ops/cuda/bert/attention_data.h | 4 + .../cuda/bert/group_query_attention.cc | 49 ++++++-- .../cuda/bert/group_query_attention_impl.cu | 17 +++ .../core/graph/contrib_ops/bert_defs.cc | 15 +++ 8 files changed, 240 insertions(+), 30 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index f316a0dfdf91c..101ec88df375c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -101,6 +101,10 @@ struct GroupQueryAttentionParameters : AttentionParameters { KVQuantizationType k_quant_type = KVQuantizationType::NONE; KVQuantizationType v_quant_type = KVQuantizationType::NONE; int kv_cache_bit_width = 0; + + // External KV parameters for KV-shared layers (e.g., Gemma4) + bool use_external_kv = false; // When true, use external K,V tensors instead of internal KV cache + int external_kv_sequence_length = 0; // Sequence length of external KV tensors }; // Parameters deduced from node attributes and inputs/outputs. 
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index adc7b623ec8c4..8d2468bb7b21e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -104,6 +104,10 @@ class GQAAttentionBase { bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; + // External KV mode: K and V are nullptr, past_key/past_value contain the external KV data. + // Skip KV cache concatenation and use external KV directly. + const bool use_external_kv = parameters.use_external_kv; + const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; T* output_qk_buffer = output_qk != nullptr ? output_qk->MutableData() : nullptr; @@ -112,7 +116,7 @@ class GQAAttentionBase { ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, - past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + past_present_share_buffer, packed_qkv, is_prompt, use_external_kv, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? 
Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -120,12 +124,12 @@ class GQAAttentionBase { seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, - is_prompt, tp, allocator); + is_prompt, use_external_kv, tp, allocator); } else { ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, - past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + past_present_share_buffer, packed_qkv, is_prompt, use_external_kv, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -133,7 +137,7 @@ class GQAAttentionBase { seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, - is_prompt, tp, allocator); + is_prompt, use_external_kv, tp, allocator); } return Status::OK(); @@ -164,6 +168,7 @@ class GQAAttentionBase { const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // whether it is prompt + const bool use_external_kv, // whether using external KV (skip KV concat) ThreadPool* tp, // thread pool AllocatorPtr allocator) const { // allocator for temporary buffer const ptrdiff_t packed_batch_stride = @@ -237,12 +242,21 @@ class GQAAttentionBase { } const T* k; - if (packed_qkv) { + if (use_external_kv) { + // External KV mode: use past_key directly (it holds the external 
KV data in BNSH format). + // No new K to concatenate — the external KV is the complete key sequence. + k = past_key + (i / kv_num_heads_factor) * present_buff_chunk_length; + // Also copy to present for output pass-through + if (present_key != nullptr && !past_present_share_buffer) { + memcpy(present_key + (i / kv_num_heads_factor) * present_buff_chunk_length, + k, SafeInt(total_seqlen) * head_size * sizeof(T)); + } + } else if (packed_qkv) { k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); } else { k = K + kv_input_chunk_length * (i / kv_num_heads_factor); } - if (nullptr != present_key) { + if (!use_external_kv && nullptr != present_key) { k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length, past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); @@ -392,6 +406,7 @@ class GQAAttentionBase { const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // whether it is prompt + const bool use_external_kv, // whether using external KV (skip KV concat) ThreadPool* tp, AllocatorPtr allocator) const { const ptrdiff_t packed_batch_stride = @@ -445,12 +460,20 @@ class GQAAttentionBase { const size_t past_chunk_length = SafeInt(past_seqlen) * head_size; const T* v; - if (packed_qkv) { + if (use_external_kv) { + // External KV mode: use past_value directly (it holds the external KV data in BNSH format). 
+ v = past_value + (i / kv_num_heads_factor) * present_buff_chunk_length; + // Copy to present for output pass-through + if (present_value != nullptr && !past_present_share_buffer) { + memcpy(present_value + (i / kv_num_heads_factor) * present_buff_chunk_length, + v, SafeInt(total_seqlen) * head_size * sizeof(T)); + } + } else if (packed_qkv) { v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); } else { v = V + kv_input_chunk_length * (i / kv_num_heads_factor); } - if (nullptr != present_value) { + if (!use_external_kv && nullptr != present_value) { v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 5698bcb659f20..7513c64aeeac0 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -55,6 +55,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* position_ids = context->Input(9); const Tensor* attention_bias = context->Input(10); const Tensor* head_sink = context->Input(11); + const Tensor* external_key = context->Input(14); + const Tensor* external_value = context->Input(15); GroupQueryAttentionParameters parameters = {}; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, @@ -71,13 +73,16 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { total_seqlen_tensor, scale_, softcap_, - 0)); + 0, + external_key != nullptr)); ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, attention_bias, head_sink, parameters)); + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckAndSetExternalKV(external_key, external_value, parameters)); + const int 
batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int present_kv_seqlen = parameters.seqlen_present_kv_cache; @@ -125,7 +130,11 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { OrtValue Q; OrtValue K; OrtValue V; - if (packed_qkv) { + if (parameters.use_external_kv) { + // External KV mode: only Q needs transposing. K,V come from external tensors. + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( + allocator, batch_size, num_heads_, sequence_length, head_size, query, Q)); + } else if (packed_qkv) { ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q)); } else { @@ -141,8 +150,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { OrtValue RotaryQ; OrtValue RotaryK; T* q_rotary = Q.GetMutable()->MutableData(); - T* k_rotary = packed_qkv ? nullptr : K.GetMutable()->MutableData(); - if (do_rotary_) { + T* k_rotary = (packed_qkv || parameters.use_external_kv) ? nullptr : K.GetMutable()->MutableData(); + if (do_rotary_ && !parameters.use_external_kv) { ORT_ENFORCE(cos_cache != nullptr && sin_cache != nullptr, "cos_cache and sin_cache must be provided when do_rotary is true"); // Initialize rotary parameters rotary_embedding_helper::RotaryParameters rotary_params = {}; @@ -232,9 +241,16 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data() : nullptr; + // When external KV is provided, use it in place of past_key/past_value for attention computation. + // External KV is pre-computed from another layer (KV-shared layers, e.g., Gemma4). + const Tensor* effective_past_key = parameters.use_external_kv ? external_key : past_key; + const Tensor* effective_past_value = parameters.use_external_kv ? 
external_value : past_value; + // Compute the attention score and apply the score to V - return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), - head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, + const T* k_data = (packed_qkv || parameters.use_external_kv) ? nullptr : k_rotary; + const T* v_data = (packed_qkv || parameters.use_external_kv) ? nullptr : V.Get().Data(); + return ApplyAttention(q_rotary, k_data, v_data, + head_sink_data, attention_bias, effective_past_key, effective_past_value, output, present_k, present_v, output_qk, seqlens_k, parameters, allocator, context); } } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index f5399e307fbca..f730f5dbdd82d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -96,6 +96,27 @@ Status Check_QKV(const T* packed_qkv, const T* value, const int num_heads, const return Status::OK(); } +template +Status Check_Q_Only(const T* query, const int num_heads, const int kv_num_heads, + int& batch_size, int& sequence_length, int& q_hidden_size, int& kv_hidden_size, int& head_size) { + const auto& query_dims = query->Shape().GetDims(); + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", + query_dims.size()); + } + batch_size = static_cast(query_dims[0]); + sequence_length = static_cast(query_dims[1]); + q_hidden_size = static_cast(query_dims[2]); + head_size = static_cast(q_hidden_size) / num_heads; + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. 
Got head_size % 8 == ", + head_size % 8); + } + kv_hidden_size = head_size * kv_num_heads; + return Status::OK(); +} + template Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_num_heads, int head_size, int kv_cache_bit_width, int& past_sequence_length) { @@ -157,6 +178,56 @@ Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_ return Status::OK(); } +template +Status CheckExternalKV(const T* external_key, const T* external_value, int batch_size, int kv_num_heads, + int& external_sequence_length) { + if (external_key == nullptr || external_value == nullptr) { + if (external_key != nullptr || external_value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'external_key' and 'external_value' shall be both present or both absent."); + } + return Status::OK(); + } + + const auto& ext_key_dims = external_key->Shape().GetDims(); + const auto& ext_value_dims = external_value->Shape().GetDims(); + + if (ext_key_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'external_key' is expected to have 4 dimensions (BNSH), got ", + ext_key_dims.size()); + } + if (ext_value_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'external_value' is expected to have 4 dimensions (BNSH), got ", + ext_value_dims.size()); + } + if (ext_key_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'external_key' dimension 0 should be batch_size, got ", ext_key_dims[0]); + } + if (ext_value_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'external_value' dimension 0 should be batch_size, got ", ext_value_dims[0]); + } + if (ext_key_dims[1] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'external_key' shall have kv_num_heads, got ", ext_key_dims[1]); + } + if (ext_value_dims[1] != kv_num_heads) { + return 
ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'external_value' shall have kv_num_heads, got ", ext_value_dims[1]); + } + if (ext_key_dims[2] != ext_value_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'external_key' and 'external_value' should have same sequence length dimension."); + } + // Note: head_size validation is relaxed here — external KV may have different head_size + // than the query (e.g., Gemma4 global layers with head_dim=512 vs local head_dim=256). + external_sequence_length = static_cast(ext_key_dims[2]); + return Status::OK(); +} + template Status CheckRotaryCaches(const T* cos_cache, const T* sin_cache, int head_size, int total_sequence_length, int& rotary_dim) { @@ -207,7 +278,8 @@ Status CheckInputs(const T* query, const T* total_seqlen, float scale, float softcap, - int kv_cache_bit_width) { + int kv_cache_bit_width, + bool has_external_kv = false) { // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache // past_key : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr // past_value : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr @@ -242,8 +314,14 @@ Status CheckInputs(const T* query, int q_hidden_size = 0; int kv_hidden_size = 0; int head_size = 0; - const bool is_packed_qkv = key == nullptr; - if (!is_packed_qkv) { + // When external KV is provided, key/value can be nullptr without implying packed QKV. + // In this mode, query contains only Q (not packed QKV). 
+ const bool is_packed_qkv = (key == nullptr) && !has_external_kv; + if (has_external_kv && key == nullptr) { + // Q-only mode: query is just Q, K and V come from external tensors + ORT_RETURN_IF_ERROR(Check_Q_Only(query, num_heads, kv_num_heads, batch_size, sequence_length, + q_hidden_size, kv_hidden_size, head_size)); + } else if (!is_packed_qkv) { ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, kv_num_heads, batch_size, sequence_length, q_hidden_size, kv_hidden_size, head_size)); } else { @@ -350,12 +428,13 @@ Status CheckInputs(const T* query, float scale, float softcap, int kv_cache_bit_width, - int max_threads_per_block) { + int max_threads_per_block, + bool has_external_kv = false) { if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap, kv_cache_bit_width); + return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap, kv_cache_bit_width, has_external_kv); } template @@ -445,6 +524,31 @@ inline Status CheckNoQKOutput(int num_outputs, int qk_output) { return Status::OK(); } +// Validate and configure external KV inputs for KV-shared layers. +// Call this after CheckInputs to set up external KV parameters. 
+template +Status CheckAndSetExternalKV(const T* external_key, const T* external_value, + GroupQueryAttentionParameters& parameters) { + if (external_key == nullptr && external_value == nullptr) { + return Status::OK(); + } + + int external_sequence_length = 0; + ORT_RETURN_IF_ERROR(CheckExternalKV(external_key, external_value, + parameters.batch_size, parameters.kv_num_heads, + external_sequence_length)); + + parameters.use_external_kv = true; + parameters.external_kv_sequence_length = external_sequence_length; + + // When using external KV, the total sequence length for attention is determined + // by the external KV tensor, not the internal KV cache. + parameters.total_sequence_length = external_sequence_length; + parameters.seqlen_present_kv_cache = external_sequence_length; + + return Status::OK(); +} + } // namespace group_query_attention_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 486bf05bd86d5..a95e9b818ca81 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -192,6 +192,10 @@ struct GroupQueryAttentionData { U* present_key = nullptr; U* present_value = nullptr; + // External KV for KV-shared layers (e.g., Gemma4) + const U* external_key = nullptr; + const U* external_value = nullptr; + // Kernel Flags bool use_flash_attention = false; bool use_memory_efficient_attention = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 3b6b5f9079ebe..f8de4d44261f5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -164,6 +164,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons const Tensor* head_sink = context->Input(11); const Tensor* k_scale = 
context->Input(12); const Tensor* v_scale = context->Input(13); + const Tensor* external_key = context->Input(14); + const Tensor* external_value = context->Input(15); if (k_quant_type_ != KVQuantizationType::NONE) { if (k_scale == nullptr) { @@ -219,12 +221,16 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons scale_, softcap_, kv_cache_bit_width_, - device_prop.maxThreadsPerBlock)); + device_prop.maxThreadsPerBlock, + external_key != nullptr)); ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, attention_bias, head_sink, parameters)); + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckAndSetExternalKV(external_key, external_value, parameters)); + parameters.local_window_size = local_window_size_; parameters.is_unidirectional = is_unidirectional_; parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr; @@ -236,6 +242,14 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; + // When using external KV, reject do_rotary — the external KV already has RoPE applied + // from the source layer, and the caller is expected to pre-apply RoPE to Q. + if (parameters.use_external_kv && parameters.do_rotary) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "do_rotary must be 0 when using external_key/external_value. " + "Pre-apply RoPE to Q and use already-rotated K from the source layer."); + } + // The current GQA CUDA implementation will never be able to have a QK output. // GQA CUDA uses either flash attention or memory efficient attention. Neither kernel supports returning the QK output. ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( @@ -287,15 +301,26 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons data.k_scale = k_scale == nullptr ?
nullptr : reinterpret_cast(k_scale->DataRaw()); data.v_scale = v_scale == nullptr ? nullptr : reinterpret_cast(v_scale->DataRaw()); - data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); - data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); - - data.present_key = reinterpret_cast(present_key_output->MutableData()); - data.present_value = reinterpret_cast(present_value_output->MutableData()); - - // Compute past_present_share_buffer early since it's needed for flash attention path selection. - // This compares the final pointer values after quantization handling. - parameters.past_present_share_buffer = (data.past_key == data.present_key); + if (parameters.use_external_kv) { + // External KV mode: use external tensors as the KV source for attention. + // The external KV is treated as "past" KV since it's already computed. + // No KV cache update is performed — the present outputs are copies/views of external KV. + data.external_key = reinterpret_cast(external_key->Data()); + data.external_value = reinterpret_cast(external_value->Data()); + data.past_key = data.external_key; + data.past_value = data.external_value; + data.present_key = reinterpret_cast(present_key_output->MutableData()); + data.present_value = reinterpret_cast(present_value_output->MutableData()); + // External KV source and present outputs are distinct buffers, so past/present never alias here + parameters.past_present_share_buffer = false; + } else { + data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); + data.present_key = reinterpret_cast(present_key_output->MutableData()); + data.present_value = reinterpret_cast(present_value_output->MutableData()); + // Compute past_present_share_buffer early since it's needed for flash attention path selection.
+ parameters.past_present_share_buffer = (data.past_key == data.present_key); + } bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE); constexpr bool is_int8 = std::is_same::value; @@ -519,7 +544,9 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons } // Validate past_value pointer consistency (past_present_share_buffer was computed early after pointer setup) - if (parameters.past_present_share_buffer) { + if (parameters.use_external_kv) { + // External KV mode: past and present are separate (external source -> present output) + } else if (parameters.past_present_share_buffer) { ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true"); } else { ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false"); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index c617de747ccf7..d1a4bde9fe20d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -97,6 +97,23 @@ Status PrepareQKV( q_out = nullptr; } + // External KV mode: the external KV data is already set as past_key/past_value. + // Copy it into the present buffers and skip the KV append/RoPE logic. 
+ if (parameters.use_external_kv) { + U* k = reinterpret_cast(data.present_key); + U* v = reinterpret_cast(data.present_value); + int external_seq_len = parameters.external_kv_sequence_length; + + // Copy external KV into present buffers + size_t kv_copy_size = (size_t)batch_size * kv_num_heads * external_seq_len * head_size * sizeof(U); + CUDA_CALL_THROW(cudaMemcpyAsync(k, data.past_key, kv_copy_size, cudaMemcpyDeviceToDevice, stream)); + CUDA_CALL_THROW(cudaMemcpyAsync(v, data.past_value, kv_copy_size, cudaMemcpyDeviceToDevice, stream)); + + // Q is used directly from the input + q = reinterpret_cast(data.query); + return Status::OK(); + } + U* k = reinterpret_cast(data.present_key); U* v = reinterpret_cast(data.present_value); int max_cache_length = parameters.seqlen_present_kv_cache; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 1209446c6a367..9b13cc170a27d 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1314,6 +1314,21 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(12, "k_scale", "Scale tensor for past_key.", "T_KV_SCALE", OpSchema::Optional) .Input(13, "v_scale", "Scale tensor for past_value.", "T_KV_SCALE", OpSchema::Optional) + .Input(14, + "external_key", + "External pre-computed key tensor in BNSH format (batch_size, kv_num_heads, external_seq_len, head_size). " + "Used for KV-shared layers that borrow K,V from another layer's present KV output. " + "When provided, the operator skips its internal KV cache update and uses this tensor directly " + "for attention computation. RoPE is not applied to external keys (assumed already applied).", + "T_CACHE", + OpSchema::Optional) + .Input(15, + "external_value", + "External pre-computed value tensor in BNSH format (batch_size, kv_num_heads, external_seq_len, head_size). " + "Must be provided together with external_key. 
When provided, the operator uses this tensor " + "for attention computation instead of the internal KV cache.", + "T_CACHE", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", From ec041dbb8294f095374871b22262d5489a4442bd Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Mon, 27 Apr 2026 10:44:38 -0700 Subject: [PATCH 2/9] Fix the op --- onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h | 8 ++++---- onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | 7 +++++++ .../contrib_ops/cpu/bert/group_query_attention_helper.h | 7 +++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 8d2468bb7b21e..f2ea775dd78ff 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -246,8 +246,8 @@ class GQAAttentionBase { // External KV mode: use past_key directly (it holds the external KV data in BNSH format). // No new K to concatenate — the external KV is the complete key sequence. k = past_key + (i / kv_num_heads_factor) * present_buff_chunk_length; - // Also copy to present for output pass-through - if (present_key != nullptr && !past_present_share_buffer) { + // Copy to present for output pass-through (once per KV head, not per Q head) + if (present_key != nullptr && !past_present_share_buffer && head_index % kv_num_heads_factor == 0) { memcpy(present_key + (i / kv_num_heads_factor) * present_buff_chunk_length, k, SafeInt(total_seqlen) * head_size * sizeof(T)); } @@ -463,8 +463,8 @@ class GQAAttentionBase { if (use_external_kv) { // External KV mode: use past_value directly (it holds the external KV data in BNSH format). 
v = past_value + (i / kv_num_heads_factor) * present_buff_chunk_length; - // Copy to present for output pass-through - if (present_value != nullptr && !past_present_share_buffer) { + // Copy to present for output pass-through (once per KV head, not per Q head) + if (present_value != nullptr && !past_present_share_buffer && head_index % kv_num_heads_factor == 0) { memcpy(present_value + (i / kv_num_heads_factor) * present_buff_chunk_length, v, SafeInt(total_seqlen) * head_size * sizeof(T)); } diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 7513c64aeeac0..6cba0c32757b5 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -83,6 +83,13 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckAndSetExternalKV(external_key, external_value, parameters)); + // External KV mode requires do_rotary=0 — K already has RoPE from the source layer + if (parameters.use_external_kv && do_rotary_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "do_rotary must be 0 when using external_key/external_value. 
" + "Pre-apply RoPE to Q and use already-rotated K from the source layer."); + } + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int present_kv_seqlen = parameters.seqlen_present_kv_cache; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index f730f5dbdd82d..7b72cbe3b9804 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -543,6 +543,13 @@ Status CheckAndSetExternalKV(const T* external_key, const T* external_value, // When using external KV, the total sequence length for attention is determined // by the external KV tensor, not the internal KV cache. + // Validate that the original total_sequence_length doesn't exceed external KV length. + if (parameters.total_sequence_length > external_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "total_sequence_length (", parameters.total_sequence_length, + ") exceeds external KV sequence length (", external_sequence_length, + "). 
Ensure seqlens_k is consistent with the external KV tensor size."); + } parameters.total_sequence_length = external_sequence_length; parameters.seqlen_present_kv_cache = external_sequence_length; From 7139053b63617de2d4c5716de604c9f0223f3130 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Mon, 27 Apr 2026 11:02:58 -0700 Subject: [PATCH 3/9] Fix lint error --- onnxruntime/contrib_ops/cpu/bert/attention_parameters.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index 101ec88df375c..08345727677a6 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -103,8 +103,8 @@ struct GroupQueryAttentionParameters : AttentionParameters { int kv_cache_bit_width = 0; // External KV parameters for KV-shared layers (e.g., Gemma4) - bool use_external_kv = false; // When true, use external K,V tensors instead of internal KV cache - int external_kv_sequence_length = 0; // Sequence length of external KV tensors + bool use_external_kv = false; // When true, use external K,V tensors instead of internal KV cache + int external_kv_sequence_length = 0; // Sequence length of external KV tensors }; // Parameters deduced from node attributes and inputs/outputs. 
From 0bf639413032c0af0b4ea0b5da8b2e605f4146f1 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Mon, 27 Apr 2026 14:03:45 -0700 Subject: [PATCH 4/9] Fix webgpu build --- onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 7b72cbe3b9804..b94db4c9774e6 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -279,7 +279,7 @@ Status CheckInputs(const T* query, float scale, float softcap, int kv_cache_bit_width, - bool has_external_kv = false) { + bool has_external_kv) { // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache // past_key : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr // past_value : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr From 3c07e1dd8d4ea4bc678b5faf00348db5b4c93220 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Mon, 27 Apr 2026 14:32:09 -0700 Subject: [PATCH 5/9] fix webgpu build --- onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index fd72f751ee810..98f09a496f579 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -212,7 +212,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& scale_, softcap_, 0, - context.DeviceLimits().maxComputeInvocationsPerWorkgroup)); + static_cast(context.DeviceLimits().maxComputeInvocationsPerWorkgroup), + /*has_external_kv=*/false)); params.use_smooth_softmax = use_smooth_softmax_; params.rotary_interleaved = rotary_interleaved_; From 
f9051967780107f35ef999b5104762936b35f6df Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Mon, 27 Apr 2026 14:58:36 -0700 Subject: [PATCH 6/9] Address copilot comments --- .../cpu/bert/group_query_attention.cc | 7 + .../cpu/bert/group_query_attention_helper.h | 44 +++++- .../cuda/bert/group_query_attention.cc | 7 + .../cuda/bert/group_query_attention_impl.cu | 6 +- .../core/graph/contrib_ops/bert_defs.cc | 9 +- .../group_query_attention_op_test.cc | 127 ++++++++++++++++++ 6 files changed, 193 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 6cba0c32757b5..d126e68f69edd 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -83,6 +83,13 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckAndSetExternalKV(external_key, external_value, parameters)); + // External KV is mutually exclusive with provided key/value inputs + if (parameters.use_external_kv && (key != nullptr || value != nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "key and value (inputs 1/2) must not be provided when using external_key/external_value. 
" + "External KV replaces the K,V projections entirely."); + } + // External KV mode requires do_rotary=0 — K already has RoPE from the source layer if (parameters.use_external_kv && do_rotary_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index b94db4c9774e6..0a458c3a8e124 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -107,7 +107,18 @@ Status Check_Q_Only(const T* query, const int num_heads, const int kv_num_heads, batch_size = static_cast(query_dims[0]); sequence_length = static_cast(query_dims[1]); q_hidden_size = static_cast(query_dims[2]); - head_size = static_cast(q_hidden_size) / num_heads; + if (q_hidden_size % num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "q_hidden_size (", q_hidden_size, ") must be divisible by num_heads (", num_heads, + ") in Q-only mode (external KV). Got q_hidden_size % num_heads == ", + q_hidden_size % num_heads); + } + head_size = q_hidden_size / num_heads; + if (head_size == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be > 0. Got q_hidden_size=", q_hidden_size, + ", num_heads=", num_heads); + } if (head_size % 8 != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "head_size must be a multiple of 8. 
Got head_size % 8 == ", @@ -180,7 +191,7 @@ Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_ template Status CheckExternalKV(const T* external_key, const T* external_value, int batch_size, int kv_num_heads, - int& external_sequence_length) { + int head_size, int kv_cache_bit_width, int& external_sequence_length) { if (external_key == nullptr || external_value == nullptr) { if (external_key != nullptr || external_value != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -222,8 +233,18 @@ Status CheckExternalKV(const T* external_key, const T* external_value, int batch return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'external_key' and 'external_value' should have same sequence length dimension."); } - // Note: head_size validation is relaxed here — external KV may have different head_size - // than the query (e.g., Gemma4 global layers with head_dim=512 vs local head_dim=256). + // Validate head dimension (dim 3). For 4-bit quantized KV cache, the stored dimension is head_size / 2. + int expected_head_dim = (kv_cache_bit_width == 4) ? 
(head_size / 2) : head_size; + if (ext_key_dims[3] != expected_head_dim) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'external_key' dimension 3 should be head_size (", expected_head_dim, + "), got ", ext_key_dims[3]); + } + if (ext_value_dims[3] != expected_head_dim) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'external_value' dimension 3 should be head_size (", expected_head_dim, + "), got ", ext_value_dims[3]); + } external_sequence_length = static_cast(ext_key_dims[2]); return Status::OK(); } @@ -536,11 +557,26 @@ Status CheckAndSetExternalKV(const T* external_key, const T* external_value, int external_sequence_length = 0; ORT_RETURN_IF_ERROR(CheckExternalKV(external_key, external_value, parameters.batch_size, parameters.kv_num_heads, + parameters.head_size, parameters.kv_cache_bit_width, external_sequence_length)); parameters.use_external_kv = true; parameters.external_kv_sequence_length = external_sequence_length; + // External KV mode is incompatible with packed QKV — query must contain only Q. + if (parameters.is_packed_qkv) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "external_key/external_value cannot be used with packed QKV input. " + "Provide query as Q-only (without K,V) when using external KV."); + } + + // External KV replaces the internal KV cache — past_key/past_value should not be provided. + if (parameters.seqlen_past_kv_cache > 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "past_key/past_value should not be provided when using external_key/external_value. " + "External KV replaces the internal KV cache entirely."); + } + // When using external KV, the total sequence length for attention is determined // by the external KV tensor, not the internal KV cache. // Validate that the original total_sequence_length doesn't exceed external KV length. 
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index f3fdb27e09b10..2562e72451bd8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -233,6 +233,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckAndSetExternalKV(external_key, external_value, parameters)); + // External KV is mutually exclusive with provided key/value inputs + if (parameters.use_external_kv && (key != nullptr || value != nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "key and value (inputs 1/2) must not be provided when using external_key/external_value. " + "External KV replaces the K,V projections entirely."); + } + parameters.local_window_size = local_window_size_; parameters.is_unidirectional = is_unidirectional_; parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index e8a222dcc4396..4f0da8fda4d08 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -105,8 +105,10 @@ Status PrepareQKV( U* v = reinterpret_cast(data.present_value); int external_seq_len = parameters.external_kv_sequence_length; - // Copy external KV into present buffers - size_t kv_copy_size = (size_t)batch_size * kv_num_heads * external_seq_len * head_size * sizeof(U); + // For 4-bit quantized KV cache, the stored head dimension is head_size/2 (two nibbles per byte). + // Use the packed dimension to compute the correct copy size. + int cache_head_dim = (parameters.kv_cache_bit_width == 4) ? 
(head_size + 1) / 2 : head_size; + size_t kv_copy_size = (size_t)batch_size * kv_num_heads * external_seq_len * cache_head_dim * sizeof(U); CUDA_CALL_THROW(cudaMemcpyAsync(k, data.past_key, kv_copy_size, cudaMemcpyDeviceToDevice, stream)); CUDA_CALL_THROW(cudaMemcpyAsync(v, data.past_value, kv_copy_size, cudaMemcpyDeviceToDevice, stream)); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 9b13cc170a27d..ccecc28ec6afc 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -240,13 +240,20 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte if (ctx.getNumOutputs() >= 3) { // has present output const auto* past_key_type = ctx.getInputType(past_key_index); + // external_key is at input index 14 for GroupQueryAttention + const auto* external_key_type = (ctx.getNumInputs() > 14) ? ctx.getInputType(14) : nullptr; if (past_key_type != nullptr) { // present_key and present_value have the same type as past_key/past_value. // This allows them to be int8 or packed uint8 when quantization is enabled. ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1); // present_key ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index + 1, 2); // present_value + } else if (external_key_type != nullptr) { + // When external KV is provided (inputs 14/15), present outputs should match + // the external KV type (T_CACHE), not the query type (T). + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 14, 1); // present_key from external_key + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 15, 2); // present_value from external_value } else { - // If no past state, present is the same type as query. + // If no past state and no external KV, present is the same type as query. 
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1); ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2); } diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 0690094031bb8..593fc55c01e99 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -307,5 +307,132 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongLength) { {}, nullptr, &execution_providers); } +// ============================================================================ +// External KV tests (inputs 14/15: external_key, external_value) +// ============================================================================ + +// Helper for external KV tests +static void RunGQAExternalKVTest( + int external_seq_len, + OpTester::ExpectResult expect, + const std::string& expected_message, + bool provide_key_value = false, + bool provide_past = false, + bool do_rotary = false) { + constexpr int batch_size = 1; + constexpr int sequence_length = 1; + constexpr int num_heads = 2; + constexpr int kv_num_heads = 1; + constexpr int head_size = 8; + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + if (do_rotary) { + tester.AddAttribute("do_rotary", 1); + } + + // Query (Q-only when using external KV) + std::vector query_data(batch_size * sequence_length * hidden_size, 1.0f); + tester.AddInput("query", {batch_size, sequence_length, hidden_size}, query_data); + + // Key/Value inputs (should be absent for external KV) + if (provide_key_value) { + std::vector key_data(batch_size * sequence_length * kv_hidden_size, 1.0f); + std::vector value_data(batch_size * 
sequence_length * kv_hidden_size, 1.0f); + tester.AddInput("key", {batch_size, sequence_length, kv_hidden_size}, key_data); + tester.AddInput("value", {batch_size, sequence_length, kv_hidden_size}, value_data); + } else { + tester.AddOptionalInputEdge(); // key + tester.AddOptionalInputEdge(); // value + } + + // Past key/value (should be absent for external KV) + if (provide_past) { + std::vector past_k(batch_size * kv_num_heads * 4 * head_size, 0.5f); + std::vector past_v(batch_size * kv_num_heads * 4 * head_size, 0.5f); + tester.AddInput("past_key", {batch_size, kv_num_heads, 4, head_size}, past_k); + tester.AddInput("past_value", {batch_size, kv_num_heads, 4, head_size}, past_v); + } else { + tester.AddOptionalInputEdge(); // past_key + tester.AddOptionalInputEdge(); // past_value + } + + // seqlens_k = external_seq_len - 1 (historical convention) + tester.AddInput("seqlens_k", {batch_size}, {static_cast(external_seq_len - 1)}); + tester.AddInput("total_sequence_length", {1}, {static_cast(external_seq_len)}); + + tester.AddOptionalInputEdge(); // cos_cache (7) + tester.AddOptionalInputEdge(); // sin_cache (8) + tester.AddOptionalInputEdge(); // position_ids (9) + tester.AddOptionalInputEdge(); // attention_bias (10) + tester.AddOptionalInputEdge(); // head_sink (11) + tester.AddOptionalInputEdge(); // k_scale (12) + tester.AddOptionalInputEdge(); // v_scale (13) + + // External key/value (inputs 14/15) — BNSH format + std::vector ext_key(batch_size * kv_num_heads * external_seq_len * head_size, 0.5f); + std::vector ext_value(batch_size * kv_num_heads * external_seq_len * head_size, 0.5f); + tester.AddInput("external_key", {batch_size, kv_num_heads, external_seq_len, head_size}, ext_key); + tester.AddInput("external_value", {batch_size, kv_num_heads, external_seq_len, head_size}, ext_value); + + // Outputs + int present_seq_len = std::max(1, external_seq_len); + tester.AddOutput("output", {batch_size, sequence_length, hidden_size}, + std::vector(batch_size * 
sequence_length * hidden_size, 0.0f)); + tester.AddOutput("present_key", {batch_size, kv_num_heads, present_seq_len, head_size}, + std::vector(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f)); + tester.AddOutput("present_value", {batch_size, kv_num_heads, present_seq_len, head_size}, + std::vector(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f)); + + if (expect == OpTester::ExpectResult::kExpectSuccess) { + tester.SetOutputTolerance(1e6f); + } + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(expect, expected_message, {}, nullptr, &execution_providers); +} + +// Basic: external KV with Q-only query should succeed +TEST(GroupQueryAttentionTest, ExternalKV_BasicSuccess) { + RunGQAExternalKVTest( + /*external_seq_len=*/8, + OpTester::ExpectResult::kExpectSuccess, + ""); +} + +// Reject: external KV with key/value inputs provided (mutual exclusivity) +TEST(GroupQueryAttentionTest, ExternalKV_RejectsProvidedKeyValue) { + RunGQAExternalKVTest( + /*external_seq_len=*/8, + OpTester::ExpectResult::kExpectFailure, + "key and value (inputs 1/2) must not be provided", + /*provide_key_value=*/true); +} + +// Reject: external KV with past_key/past_value provided +TEST(GroupQueryAttentionTest, ExternalKV_RejectsPastKV) { + RunGQAExternalKVTest( + /*external_seq_len=*/8, + OpTester::ExpectResult::kExpectFailure, + "past_key/past_value should not be provided", + /*provide_key_value=*/false, + /*provide_past=*/true); +} + +// Reject: external KV with do_rotary=1 +TEST(GroupQueryAttentionTest, ExternalKV_RejectsDoRotary) { + RunGQAExternalKVTest( + /*external_seq_len=*/8, + OpTester::ExpectResult::kExpectFailure, + "do_rotary must be 0", + /*provide_key_value=*/false, + /*provide_past=*/false, + /*do_rotary=*/true); +} + } // namespace test } // namespace onnxruntime From d2ead38cfe01669faa157b78338f4ee58d9181c9 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Tue, 28 Apr 2026 00:59:51 
-0700 Subject: [PATCH 7/9] Make GQA present_key/present_value outputs optional for KV-shared layers --- .../cpu/bert/attention_parameters.h | 4 - .../contrib_ops/cpu/bert/gqa_attention_base.h | 47 ++---- .../cpu/bert/group_query_attention.cc | 42 +---- .../cpu/bert/group_query_attention_helper.h | 157 +----------------- .../contrib_ops/cuda/bert/attention_data.h | 4 - .../cuda/bert/group_query_attention.cc | 60 ++----- .../cuda/bert/group_query_attention_impl.cu | 23 +-- .../webgpu/bert/group_query_attention.cc | 3 +- .../core/graph/contrib_ops/bert_defs.cc | 30 +--- .../group_query_attention_op_test.cc | 151 +++++++---------- 10 files changed, 113 insertions(+), 408 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index 08345727677a6..f316a0dfdf91c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -101,10 +101,6 @@ struct GroupQueryAttentionParameters : AttentionParameters { KVQuantizationType k_quant_type = KVQuantizationType::NONE; KVQuantizationType v_quant_type = KVQuantizationType::NONE; int kv_cache_bit_width = 0; - - // External KV parameters for KV-shared layers (e.g., Gemma4) - bool use_external_kv = false; // When true, use external K,V tensors instead of internal KV cache - int external_kv_sequence_length = 0; // Sequence length of external KV tensors }; // Parameters deduced from node attributes and inputs/outputs. 
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index f2ea775dd78ff..1f03cf9f105a2 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -85,7 +85,9 @@ class GQAAttentionBase { if (past_key != nullptr && past_value != nullptr) { seqlen_past_kv_cache = static_cast(past_key->Shape().GetDims()[2]); } - int seqlen_present_kv_cache = static_cast(present_key->Shape().GetDims()[2]); + int seqlen_present_kv_cache = present_key != nullptr + ? static_cast(present_key->Shape().GetDims()[2]) + : parameters.seqlen_present_kv_cache; // Compute the attention score. bool gqa_mlas_supported = MlasGQASupported(CblasNoTrans, CblasTrans) && @@ -104,10 +106,6 @@ class GQAAttentionBase { bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; - // External KV mode: K and V are nullptr, past_key/past_value contain the external KV data. - // Skip KV cache concatenation and use external KV directly. - const bool use_external_kv = parameters.use_external_kv; - const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; T* output_qk_buffer = output_qk != nullptr ? output_qk->MutableData() : nullptr; @@ -116,7 +114,7 @@ class GQAAttentionBase { ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, - past_present_share_buffer, packed_qkv, is_prompt, use_external_kv, tp, allocator); + past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? 
Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -124,12 +122,12 @@ class GQAAttentionBase { seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, - is_prompt, use_external_kv, tp, allocator); + is_prompt, tp, allocator); } else { ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, - past_present_share_buffer, packed_qkv, is_prompt, use_external_kv, tp, allocator); + past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -137,7 +135,7 @@ class GQAAttentionBase { seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, - is_prompt, use_external_kv, tp, allocator); + is_prompt, tp, allocator); } return Status::OK(); @@ -168,7 +166,6 @@ class GQAAttentionBase { const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // whether it is prompt - const bool use_external_kv, // whether using external KV (skip KV concat) ThreadPool* tp, // thread pool AllocatorPtr allocator) const { // allocator for temporary buffer const ptrdiff_t packed_batch_stride = @@ -180,7 +177,7 @@ class GQAAttentionBase { const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H const size_t present_buff_chunk_length = 
present_buffer_sequence_length * head_size; // T x H - if (!past_present_share_buffer) { + if (present_key && !past_present_share_buffer) { memset((void*)present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); @@ -242,21 +239,12 @@ class GQAAttentionBase { } const T* k; - if (use_external_kv) { - // External KV mode: use past_key directly (it holds the external KV data in BNSH format). - // No new K to concatenate — the external KV is the complete key sequence. - k = past_key + (i / kv_num_heads_factor) * present_buff_chunk_length; - // Copy to present for output pass-through (once per KV head, not per Q head) - if (present_key != nullptr && !past_present_share_buffer && head_index % kv_num_heads_factor == 0) { - memcpy(present_key + (i / kv_num_heads_factor) * present_buff_chunk_length, - k, SafeInt(total_seqlen) * head_size * sizeof(T)); - } - } else if (packed_qkv) { + if (packed_qkv) { k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); } else { k = K + kv_input_chunk_length * (i / kv_num_heads_factor); } - if (!use_external_kv && nullptr != present_key) { + if (nullptr != present_key) { k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length, past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); @@ -406,7 +394,6 @@ class GQAAttentionBase { const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // whether it is prompt - const bool use_external_kv, // whether using external KV (skip KV concat) ThreadPool* tp, AllocatorPtr allocator) const { const ptrdiff_t packed_batch_stride = @@ -417,7 +404,7 @@ class GQAAttentionBase { const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H const size_t present_buff_chunk_length = 
present_buffer_sequence_length * head_size; // T x H - if (!past_present_share_buffer) { + if (present_value && !past_present_share_buffer) { memset((void*)present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); @@ -460,20 +447,12 @@ class GQAAttentionBase { const size_t past_chunk_length = SafeInt(past_seqlen) * head_size; const T* v; - if (use_external_kv) { - // External KV mode: use past_value directly (it holds the external KV data in BNSH format). - v = past_value + (i / kv_num_heads_factor) * present_buff_chunk_length; - // Copy to present for output pass-through (once per KV head, not per Q head) - if (present_value != nullptr && !past_present_share_buffer && head_index % kv_num_heads_factor == 0) { - memcpy(present_value + (i / kv_num_heads_factor) * present_buff_chunk_length, - v, SafeInt(total_seqlen) * head_size * sizeof(T)); - } - } else if (packed_qkv) { + if (packed_qkv) { v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); } else { v = V + kv_input_chunk_length * (i / kv_num_heads_factor); } - if (!use_external_kv && nullptr != present_value) { + if (nullptr != present_value) { v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index d126e68f69edd..5ee2f31539bae 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -55,8 +55,6 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* position_ids = context->Input(9); const Tensor* attention_bias = context->Input(10); const Tensor* head_sink = context->Input(11); - const Tensor* external_key = context->Input(14); - 
const Tensor* external_value = context->Input(15); GroupQueryAttentionParameters parameters = {}; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, @@ -73,30 +71,13 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { total_seqlen_tensor, scale_, softcap_, - 0, - external_key != nullptr)); + 0)); ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, attention_bias, head_sink, parameters)); - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckAndSetExternalKV(external_key, external_value, parameters)); - - // External KV is mutually exclusive with provided key/value inputs - if (parameters.use_external_kv && (key != nullptr || value != nullptr)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "key and value (inputs 1/2) must not be provided when using external_key/external_value. " - "External KV replaces the K,V projections entirely."); - } - - // External KV mode requires do_rotary=0 — K already has RoPE from the source layer - if (parameters.use_external_kv && do_rotary_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "do_rotary must be 0 when using external_key/external_value. " - "Pre-apply RoPE to Q and use already-rotated K from the source layer."); - } - const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int present_kv_seqlen = parameters.seqlen_present_kv_cache; @@ -144,11 +125,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { OrtValue Q; OrtValue K; OrtValue V; - if (parameters.use_external_kv) { - // External KV mode: only Q needs transposing. K,V come from external tensors. 
- ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( - allocator, batch_size, num_heads_, sequence_length, head_size, query, Q)); - } else if (packed_qkv) { + if (packed_qkv) { ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q)); } else { @@ -164,8 +141,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { OrtValue RotaryQ; OrtValue RotaryK; T* q_rotary = Q.GetMutable()->MutableData(); - T* k_rotary = (packed_qkv || parameters.use_external_kv) ? nullptr : K.GetMutable()->MutableData(); - if (do_rotary_ && !parameters.use_external_kv) { + T* k_rotary = packed_qkv ? nullptr : K.GetMutable()->MutableData(); + if (do_rotary_) { ORT_ENFORCE(cos_cache != nullptr && sin_cache != nullptr, "cos_cache and sin_cache must be provided when do_rotary is true"); // Initialize rotary parameters rotary_embedding_helper::RotaryParameters rotary_params = {}; @@ -255,16 +232,11 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data() : nullptr; - // When external KV is provided, use it in place of past_key/past_value for attention computation. - // External KV is pre-computed from another layer (KV-shared layers, e.g., Gemma4). - const Tensor* effective_past_key = parameters.use_external_kv ? external_key : past_key; - const Tensor* effective_past_value = parameters.use_external_kv ? external_value : past_value; - // Compute the attention score and apply the score to V - const T* k_data = (packed_qkv || parameters.use_external_kv) ? nullptr : k_rotary; - const T* v_data = (packed_qkv || parameters.use_external_kv) ? nullptr : V.Get().Data(); + const T* k_data = packed_qkv ? nullptr : k_rotary; + const T* v_data = packed_qkv ? 
nullptr : V.Get().Data(); return ApplyAttention(q_rotary, k_data, v_data, - head_sink_data, attention_bias, effective_past_key, effective_past_value, output, present_k, present_v, + head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, output_qk, seqlens_k, parameters, allocator, context); } } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 0a458c3a8e124..f65568700c0c9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -96,38 +96,6 @@ Status Check_QKV(const T* packed_qkv, const T* value, const int num_heads, const return Status::OK(); } -template -Status Check_Q_Only(const T* query, const int num_heads, const int kv_num_heads, - int& batch_size, int& sequence_length, int& q_hidden_size, int& kv_hidden_size, int& head_size) { - const auto& query_dims = query->Shape().GetDims(); - if (query_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", - query_dims.size()); - } - batch_size = static_cast(query_dims[0]); - sequence_length = static_cast(query_dims[1]); - q_hidden_size = static_cast(query_dims[2]); - if (q_hidden_size % num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "q_hidden_size (", q_hidden_size, ") must be divisible by num_heads (", num_heads, - ") in Q-only mode (external KV). Got q_hidden_size % num_heads == ", - q_hidden_size % num_heads); - } - head_size = q_hidden_size / num_heads; - if (head_size == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be > 0. Got q_hidden_size=", q_hidden_size, - ", num_heads=", num_heads); - } - if (head_size % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be a multiple of 8. 
Got head_size % 8 == ", - head_size % 8); - } - kv_hidden_size = head_size * kv_num_heads; - return Status::OK(); -} - template Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_num_heads, int head_size, int kv_cache_bit_width, int& past_sequence_length) { @@ -189,66 +157,6 @@ Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_ return Status::OK(); } -template -Status CheckExternalKV(const T* external_key, const T* external_value, int batch_size, int kv_num_heads, - int head_size, int kv_cache_bit_width, int& external_sequence_length) { - if (external_key == nullptr || external_value == nullptr) { - if (external_key != nullptr || external_value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'external_key' and 'external_value' shall be both present or both absent."); - } - return Status::OK(); - } - - const auto& ext_key_dims = external_key->Shape().GetDims(); - const auto& ext_value_dims = external_value->Shape().GetDims(); - - if (ext_key_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'external_key' is expected to have 4 dimensions (BNSH), got ", - ext_key_dims.size()); - } - if (ext_value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'external_value' is expected to have 4 dimensions (BNSH), got ", - ext_value_dims.size()); - } - if (ext_key_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'external_key' dimension 0 should be batch_size, got ", ext_key_dims[0]); - } - if (ext_value_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'external_value' dimension 0 should be batch_size, got ", ext_value_dims[0]); - } - if (ext_key_dims[1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'external_key' shall have kv_num_heads, got ", ext_key_dims[1]); - } - if (ext_value_dims[1] != 
kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'external_value' shall have kv_num_heads, got ", ext_value_dims[1]); - } - if (ext_key_dims[2] != ext_value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'external_key' and 'external_value' should have same sequence length dimension."); - } - // Validate head dimension (dim 3). For 4-bit quantized KV cache, the stored dimension is head_size / 2. - int expected_head_dim = (kv_cache_bit_width == 4) ? (head_size / 2) : head_size; - if (ext_key_dims[3] != expected_head_dim) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'external_key' dimension 3 should be head_size (", expected_head_dim, - "), got ", ext_key_dims[3]); - } - if (ext_value_dims[3] != expected_head_dim) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'external_value' dimension 3 should be head_size (", expected_head_dim, - "), got ", ext_value_dims[3]); - } - external_sequence_length = static_cast(ext_key_dims[2]); - return Status::OK(); -} - template Status CheckRotaryCaches(const T* cos_cache, const T* sin_cache, int head_size, int total_sequence_length, int& rotary_dim) { @@ -299,8 +207,7 @@ Status CheckInputs(const T* query, const T* total_seqlen, float scale, float softcap, - int kv_cache_bit_width, - bool has_external_kv) { + int kv_cache_bit_width) { // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache // past_key : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr // past_value : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr @@ -335,14 +242,8 @@ Status CheckInputs(const T* query, int q_hidden_size = 0; int kv_hidden_size = 0; int head_size = 0; - // When external KV is provided, key/value can be nullptr without implying packed QKV. - // In this mode, query contains only Q (not packed QKV). 
- const bool is_packed_qkv = (key == nullptr) && !has_external_kv; - if (has_external_kv && key == nullptr) { - // Q-only mode: query is just Q, K and V come from external tensors - ORT_RETURN_IF_ERROR(Check_Q_Only(query, num_heads, kv_num_heads, batch_size, sequence_length, - q_hidden_size, kv_hidden_size, head_size)); - } else if (!is_packed_qkv) { + const bool is_packed_qkv = (key == nullptr); + if (!is_packed_qkv) { ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, kv_num_heads, batch_size, sequence_length, q_hidden_size, kv_hidden_size, head_size)); } else { @@ -449,13 +350,12 @@ Status CheckInputs(const T* query, float scale, float softcap, int kv_cache_bit_width, - int max_threads_per_block, - bool has_external_kv = false) { + int max_threads_per_block) { if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap, kv_cache_bit_width, has_external_kv); + return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap, kv_cache_bit_width); } template @@ -545,53 +445,6 @@ inline Status CheckNoQKOutput(int num_outputs, int qk_output) { return Status::OK(); } -// Validate and configure external KV inputs for KV-shared layers. -// Call this after CheckInputs to set up external KV parameters. 
-template -Status CheckAndSetExternalKV(const T* external_key, const T* external_value, - GroupQueryAttentionParameters& parameters) { - if (external_key == nullptr && external_value == nullptr) { - return Status::OK(); - } - - int external_sequence_length = 0; - ORT_RETURN_IF_ERROR(CheckExternalKV(external_key, external_value, - parameters.batch_size, parameters.kv_num_heads, - parameters.head_size, parameters.kv_cache_bit_width, - external_sequence_length)); - - parameters.use_external_kv = true; - parameters.external_kv_sequence_length = external_sequence_length; - - // External KV mode is incompatible with packed QKV — query must contain only Q. - if (parameters.is_packed_qkv) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "external_key/external_value cannot be used with packed QKV input. " - "Provide query as Q-only (without K,V) when using external KV."); - } - - // External KV replaces the internal KV cache — past_key/past_value should not be provided. - if (parameters.seqlen_past_kv_cache > 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "past_key/past_value should not be provided when using external_key/external_value. " - "External KV replaces the internal KV cache entirely."); - } - - // When using external KV, the total sequence length for attention is determined - // by the external KV tensor, not the internal KV cache. - // Validate that the original total_sequence_length doesn't exceed external KV length. - if (parameters.total_sequence_length > external_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "total_sequence_length (", parameters.total_sequence_length, - ") exceeds external KV sequence length (", external_sequence_length, - "). 
Ensure seqlens_k is consistent with the external KV tensor size."); - } - parameters.total_sequence_length = external_sequence_length; - parameters.seqlen_present_kv_cache = external_sequence_length; - - return Status::OK(); -} - } // namespace group_query_attention_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 7c2f805b5292b..c54a1fea9ad3a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -192,10 +192,6 @@ struct GroupQueryAttentionData { U* present_key = nullptr; U* present_value = nullptr; - // External KV for KV-shared layers (e.g., Gemma4) - const U* external_key = nullptr; - const U* external_value = nullptr; - // Kernel Flags bool use_flash_attention = false; bool use_memory_efficient_attention = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 2562e72451bd8..9563292f9187c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -166,8 +166,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons const Tensor* head_sink = context->Input(11); const Tensor* k_scale = context->Input(12); const Tensor* v_scale = context->Input(13); - const Tensor* external_key = context->Input(14); - const Tensor* external_value = context->Input(15); if (k_quant_type_ != KVQuantizationType::NONE) { if (k_scale == nullptr) { @@ -223,23 +221,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons scale_, softcap_, kv_cache_bit_width_, - device_prop.maxThreadsPerBlock, - external_key != nullptr)); + device_prop.maxThreadsPerBlock)); ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, attention_bias, head_sink, 
parameters)); - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckAndSetExternalKV(external_key, external_value, parameters)); - - // External KV is mutually exclusive with provided key/value inputs - if (parameters.use_external_kv && (key != nullptr || value != nullptr)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "key and value (inputs 1/2) must not be provided when using external_key/external_value. " - "External KV replaces the K,V projections entirely."); - } - parameters.local_window_size = local_window_size_; parameters.is_unidirectional = is_unidirectional_; parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr; @@ -251,14 +239,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; - // When using external KV, disable rotary embedding — the external KV already has RoPE applied - // from the source layer, and the caller is expected to pre-apply RoPE to Q. - if (parameters.use_external_kv && parameters.do_rotary) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "do_rotary must be 0 when using external_key/external_value. " - "Pre-apply RoPE to Q and use already-rotated K from the source layer."); - } - // The current GQA CUDA implementation will never be able to have a QK output. // GQA CUDA uses either flash attention or memory efficient attention. Neither kernel supports returning the QK output. ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( @@ -310,26 +290,12 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons data.k_scale = k_scale == nullptr ? nullptr : reinterpret_cast(k_scale->DataRaw()); data.v_scale = v_scale == nullptr ? nullptr : reinterpret_cast(v_scale->DataRaw()); - if (parameters.use_external_kv) { - // External KV mode: use external tensors as the KV source for attention. 
- // The external KV is treated as "past" KV since it's already computed. - // No KV cache update is performed — the present outputs are copies/views of external KV. - data.external_key = reinterpret_cast(external_key->Data()); - data.external_value = reinterpret_cast(external_value->Data()); - data.past_key = data.external_key; - data.past_value = data.external_value; - data.present_key = reinterpret_cast(present_key_output->MutableData()); - data.present_value = reinterpret_cast(present_value_output->MutableData()); - // Mark as shared buffer so the kernel treats external KV as already-populated cache - parameters.past_present_share_buffer = false; - } else { - data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); - data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); - data.present_key = reinterpret_cast(present_key_output->MutableData()); - data.present_value = reinterpret_cast(present_value_output->MutableData()); - // Compute past_present_share_buffer early since it's needed for flash attention path selection. - parameters.past_present_share_buffer = (data.past_key == data.present_key); - } + data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); + data.present_key = (present_key_output != nullptr) ? reinterpret_cast(present_key_output->MutableData()) : nullptr; + data.present_value = (present_value_output != nullptr) ? reinterpret_cast(present_value_output->MutableData()) : nullptr; + // Compute past_present_share_buffer early since it's needed for flash attention path selection. 
+ parameters.past_present_share_buffer = (data.past_key != nullptr && data.past_key == data.present_key); bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE); constexpr bool is_int8 = std::is_same::value; @@ -594,12 +560,12 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons } // Validate past_value pointer consistency (past_present_share_buffer was computed early after pointer setup) - if (parameters.use_external_kv) { - // External KV mode: past and present are separate (external source -> present output) - } else if (parameters.past_present_share_buffer) { - ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true"); - } else { - ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false"); + if (data.present_value != nullptr) { + if (parameters.past_present_share_buffer) { + ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true"); + } else { + ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false"); + } } data.output = reinterpret_cast(output->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 4f0da8fda4d08..3ce396989b181 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -98,23 +98,12 @@ Status PrepareQKV( q_out = nullptr; } - // External KV mode: the external KV data is already set as past_key/past_value. - // Copy it into the present buffers and skip the KV append/RoPE logic. 
- if (parameters.use_external_kv) { - U* k = reinterpret_cast(data.present_key); - U* v = reinterpret_cast(data.present_value); - int external_seq_len = parameters.external_kv_sequence_length; - - // For 4-bit quantized KV cache, the stored head dimension is head_size/2 (two nibbles per byte). - // Use the packed dimension to compute the correct copy size. - int cache_head_dim = (parameters.kv_cache_bit_width == 4) ? (head_size + 1) / 2 : head_size; - size_t kv_copy_size = (size_t)batch_size * kv_num_heads * external_seq_len * cache_head_dim * sizeof(U); - CUDA_CALL_THROW(cudaMemcpyAsync(k, data.past_key, kv_copy_size, cudaMemcpyDeviceToDevice, stream)); - CUDA_CALL_THROW(cudaMemcpyAsync(v, data.past_value, kv_copy_size, cudaMemcpyDeviceToDevice, stream)); - - // Q is used directly from the input - q = reinterpret_cast(data.query); - return Status::OK(); + // present_key/present_value are required for the CUDA path since flash attention + // and memory-efficient attention read directly from the present KV buffers. + // The CPU path supports optional present outputs for KV-shared layers. 
+ if (data.present_key == nullptr || data.present_value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "present_key and present_value outputs are required for the CUDA GroupQueryAttention kernel."); } U* k = reinterpret_cast(data.present_key); diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 98f09a496f579..5fff0516c7ce3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -212,8 +212,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& scale_, softcap_, 0, - static_cast(context.DeviceLimits().maxComputeInvocationsPerWorkgroup), - /*has_external_kv=*/false)); + static_cast(context.DeviceLimits().maxComputeInvocationsPerWorkgroup))); params.use_smooth_softmax = use_smooth_softmax_; params.rotary_interleaved = rotary_interleaved_; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index ccecc28ec6afc..e8ec04586a9d6 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -240,20 +240,13 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte if (ctx.getNumOutputs() >= 3) { // has present output const auto* past_key_type = ctx.getInputType(past_key_index); - // external_key is at input index 14 for GroupQueryAttention - const auto* external_key_type = (ctx.getNumInputs() > 14) ? ctx.getInputType(14) : nullptr; if (past_key_type != nullptr) { // present_key and present_value have the same type as past_key/past_value. // This allows them to be int8 or packed uint8 when quantization is enabled. 
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1); // present_key ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index + 1, 2); // present_value - } else if (external_key_type != nullptr) { - // When external KV is provided (inputs 14/15), present outputs should match - // the external KV type (T_CACHE), not the query type (T). - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 14, 1); // present_key from external_key - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 15, 2); // present_value from external_value } else { - // If no past state and no external KV, present is the same type as query. + // If no past state, present is the same type as query. ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1); ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2); } @@ -1321,21 +1314,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(12, "k_scale", "Scale tensor for past_key.", "T_KV_SCALE", OpSchema::Optional) .Input(13, "v_scale", "Scale tensor for past_value.", "T_KV_SCALE", OpSchema::Optional) - .Input(14, - "external_key", - "External pre-computed key tensor in BNSH format (batch_size, kv_num_heads, external_seq_len, head_size). " - "Used for KV-shared layers that borrow K,V from another layer's present KV output. " - "When provided, the operator skips its internal KV cache update and uses this tensor directly " - "for attention computation. RoPE is not applied to external keys (assumed already applied).", - "T_CACHE", - OpSchema::Optional) - .Input(15, - "external_value", - "External pre-computed value tensor in BNSH format (batch_size, kv_num_heads, external_seq_len, head_size). " - "Must be provided together with external_key. 
When provided, the operator uses this tensor " - "for attention computation instead of the internal KV cache.", - "T_CACHE", - OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", @@ -1345,13 +1323,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "present state key with support for format BNSH. When past_key uses same tensor as present_key" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T_CACHE") + "T_CACHE", + OpSchema::Optional) .Output(2, "present_value", "present state value with support for format BNSH. When past_value uses same tensor as present_value" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T_CACHE") + "T_CACHE", + OpSchema::Optional) .Output(3, "output_qk", "Values of QK matrix multiplication, either before or after softmax normalization", diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 593fc55c01e99..cc00d1bf61fbb 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -308,19 +308,18 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongLength) { } // ============================================================================ -// External KV tests (inputs 14/15: external_key, external_value) +// Optional present_key/present_value output tests // ============================================================================ -// Helper for external KV tests -static void RunGQAExternalKVTest( - int external_seq_len, +// Helper for tests with optional present outputs. +// When omit_present=true, present_key and present_value outputs are not connected. 
+static void RunGQAOptionalPresentTest( + int batch_size, + int sequence_length, + int total_seq_len, + bool omit_present, OpTester::ExpectResult expect, - const std::string& expected_message, - bool provide_key_value = false, - bool provide_past = false, - bool do_rotary = false) { - constexpr int batch_size = 1; - constexpr int sequence_length = 1; + const std::string& expected_message) { constexpr int num_heads = 2; constexpr int kv_num_heads = 1; constexpr int head_size = 8; @@ -330,62 +329,43 @@ static void RunGQAExternalKVTest( OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); - if (do_rotary) { - tester.AddAttribute("do_rotary", 1); - } - // Query (Q-only when using external KV) std::vector query_data(batch_size * sequence_length * hidden_size, 1.0f); tester.AddInput("query", {batch_size, sequence_length, hidden_size}, query_data); - // Key/Value inputs (should be absent for external KV) - if (provide_key_value) { - std::vector key_data(batch_size * sequence_length * kv_hidden_size, 1.0f); - std::vector value_data(batch_size * sequence_length * kv_hidden_size, 1.0f); - tester.AddInput("key", {batch_size, sequence_length, kv_hidden_size}, key_data); - tester.AddInput("value", {batch_size, sequence_length, kv_hidden_size}, value_data); - } else { - tester.AddOptionalInputEdge(); // key - tester.AddOptionalInputEdge(); // value - } + std::vector key_data(batch_size * sequence_length * kv_hidden_size, 0.5f); + tester.AddInput("key", {batch_size, sequence_length, kv_hidden_size}, key_data); - // Past key/value (should be absent for external KV) - if (provide_past) { - std::vector past_k(batch_size * kv_num_heads * 4 * head_size, 0.5f); - std::vector past_v(batch_size * kv_num_heads * 4 * head_size, 0.5f); - tester.AddInput("past_key", {batch_size, kv_num_heads, 4, head_size}, past_k); - tester.AddInput("past_value", 
{batch_size, kv_num_heads, 4, head_size}, past_v); - } else { - tester.AddOptionalInputEdge(); // past_key - tester.AddOptionalInputEdge(); // past_value - } + std::vector value_data(batch_size * sequence_length * kv_hidden_size, 0.5f); + tester.AddInput("value", {batch_size, sequence_length, kv_hidden_size}, value_data); + + tester.AddOptionalInputEdge(); // past_key + tester.AddOptionalInputEdge(); // past_value + + tester.AddInput("seqlens_k", {batch_size}, {static_cast(total_seq_len - 1)}); + tester.AddInput("total_sequence_length", {1}, {static_cast(total_seq_len)}); + + tester.AddOptionalInputEdge(); // cos_cache + tester.AddOptionalInputEdge(); // sin_cache + tester.AddOptionalInputEdge(); // position_ids + tester.AddOptionalInputEdge(); // attention_bias + tester.AddOptionalInputEdge(); // head_sink - // seqlens_k = external_seq_len - 1 (historical convention) - tester.AddInput("seqlens_k", {batch_size}, {static_cast(external_seq_len - 1)}); - tester.AddInput("total_sequence_length", {1}, {static_cast(external_seq_len)}); - - tester.AddOptionalInputEdge(); // cos_cache (7) - tester.AddOptionalInputEdge(); // sin_cache (8) - tester.AddOptionalInputEdge(); // position_ids (9) - tester.AddOptionalInputEdge(); // attention_bias (10) - tester.AddOptionalInputEdge(); // head_sink (11) - tester.AddOptionalInputEdge(); // k_scale (12) - tester.AddOptionalInputEdge(); // v_scale (13) - - // External key/value (inputs 14/15) — BNSH format - std::vector ext_key(batch_size * kv_num_heads * external_seq_len * head_size, 0.5f); - std::vector ext_value(batch_size * kv_num_heads * external_seq_len * head_size, 0.5f); - tester.AddInput("external_key", {batch_size, kv_num_heads, external_seq_len, head_size}, ext_key); - tester.AddInput("external_value", {batch_size, kv_num_heads, external_seq_len, head_size}, ext_value); - - // Outputs - int present_seq_len = std::max(1, external_seq_len); + // Output 0: output (always required) tester.AddOutput("output", {batch_size, 
sequence_length, hidden_size}, std::vector(batch_size * sequence_length * hidden_size, 0.0f)); - tester.AddOutput("present_key", {batch_size, kv_num_heads, present_seq_len, head_size}, - std::vector(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f)); - tester.AddOutput("present_value", {batch_size, kv_num_heads, present_seq_len, head_size}, - std::vector(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f)); + + if (omit_present) { + // Omit present_key and present_value — they are optional + tester.AddOptionalOutputEdge(); // present_key + tester.AddOptionalOutputEdge(); // present_value + } else { + int present_seq_len = total_seq_len; + tester.AddOutput("present_key", {batch_size, kv_num_heads, present_seq_len, head_size}, + std::vector(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f)); + tester.AddOutput("present_value", {batch_size, kv_num_heads, present_seq_len, head_size}, + std::vector(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f)); + } if (expect == OpTester::ExpectResult::kExpectSuccess) { tester.SetOutputTolerance(1e6f); @@ -396,42 +376,37 @@ static void RunGQAExternalKVTest( tester.Run(expect, expected_message, {}, nullptr, &execution_providers); } -// Basic: external KV with Q-only query should succeed -TEST(GroupQueryAttentionTest, ExternalKV_BasicSuccess) { - RunGQAExternalKVTest( - /*external_seq_len=*/8, +// Baseline: GQA with present outputs connected works as before +TEST(GroupQueryAttentionTest, OptionalPresent_WithPresent) { + RunGQAOptionalPresentTest( + /*batch_size=*/1, + /*sequence_length=*/4, + /*total_seq_len=*/4, + /*omit_present=*/false, OpTester::ExpectResult::kExpectSuccess, ""); } -// Reject: external KV with key/value inputs provided (mutual exclusivity) -TEST(GroupQueryAttentionTest, ExternalKV_RejectsProvidedKeyValue) { - RunGQAExternalKVTest( - /*external_seq_len=*/8, - OpTester::ExpectResult::kExpectFailure, - "key and value (inputs 1/2) must not be provided", - 
/*provide_key_value=*/true); -} - -// Reject: external KV with past_key/past_value provided -TEST(GroupQueryAttentionTest, ExternalKV_RejectsPastKV) { - RunGQAExternalKVTest( - /*external_seq_len=*/8, - OpTester::ExpectResult::kExpectFailure, - "past_key/past_value should not be provided", - /*provide_key_value=*/false, - /*provide_past=*/true); +// KV-shared layer scenario: present outputs omitted, attention uses K,V directly +TEST(GroupQueryAttentionTest, OptionalPresent_WithoutPresent) { + RunGQAOptionalPresentTest( + /*batch_size=*/1, + /*sequence_length=*/4, + /*total_seq_len=*/4, + /*omit_present=*/true, + OpTester::ExpectResult::kExpectSuccess, + ""); } -// Reject: external KV with do_rotary=1 -TEST(GroupQueryAttentionTest, ExternalKV_RejectsDoRotary) { - RunGQAExternalKVTest( - /*external_seq_len=*/8, - OpTester::ExpectResult::kExpectFailure, - "do_rotary must be 0", - /*provide_key_value=*/false, - /*provide_past=*/false, - /*do_rotary=*/true); +// Batched: present outputs omitted with batch_size > 1 +TEST(GroupQueryAttentionTest, OptionalPresent_Batched) { + RunGQAOptionalPresentTest( + /*batch_size=*/2, + /*sequence_length=*/3, + /*total_seq_len=*/3, + /*omit_present=*/true, + OpTester::ExpectResult::kExpectSuccess, + ""); } } // namespace test From aad74efae671df82a813b519b3fcc0e3fb493c83 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Tue, 28 Apr 2026 11:17:09 -0700 Subject: [PATCH 8/9] Fix tests --- onnxruntime/test/contrib_ops/group_query_attention_op_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index cc00d1bf61fbb..1d57488d51363 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -342,7 +342,8 @@ static void RunGQAOptionalPresentTest( tester.AddOptionalInputEdge(); // past_key 
tester.AddOptionalInputEdge(); // past_value - tester.AddInput("seqlens_k", {batch_size}, {static_cast(total_seq_len - 1)}); + std::vector seqlens_k_data(batch_size, static_cast(total_seq_len - 1)); + tester.AddInput("seqlens_k", {batch_size}, seqlens_k_data); tester.AddInput("total_sequence_length", {1}, {static_cast(total_seq_len)}); tester.AddOptionalInputEdge(); // cos_cache From 64005dd59d91c1fa65b4aa0aa0f8ac752167879c Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Tue, 28 Apr 2026 17:32:23 -0700 Subject: [PATCH 9/9] Update the docs --- docs/ContribOperators.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9aa44a1600ae6..45e85fcd9c9d5 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2671,14 +2671,14 @@ This version of the operator has been available since version 1 of the 'com.micr
Scale tensor for past_value.
-#### Outputs (3 - 4) +#### Outputs (1 - 4)
output : T
3D output tensor with shape (batch_size, sequence_length, hidden_size)
-
present_key : T_CACHE
+
present_key (optional) : T_CACHE
present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
-
present_value : T_CACHE
+
present_value (optional) : T_CACHE
present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
output_qk (optional) : T
Values of QK matrix multiplication, either before or after softmax normalization