4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -101,6 +101,10 @@ struct GroupQueryAttentionParameters : AttentionParameters {
KVQuantizationType k_quant_type = KVQuantizationType::NONE;
KVQuantizationType v_quant_type = KVQuantizationType::NONE;
int kv_cache_bit_width = 0;
+
+ // External KV parameters for KV-shared layers (e.g., Gemma4)
+ bool use_external_kv = false; // When true, use external K,V tensors instead of internal KV cache
+ int external_kv_sequence_length = 0; // Sequence length of external KV tensors
};

// Parameters deduced from node attributes and inputs/outputs.
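The kernel code below reads these two fields after group_query_attention_helper::CheckAndSetExternalKV populates them (that helper's body is not shown in this diff). A minimal sketch of the kind of BNSH shape check and population involved, with hypothetical stand-in names rather than the real ORT declarations:

```cpp
#include <cstdint>
#include <string>

// Hypothetical stand-in mirroring only the two new fields plus the attributes
// the check would compare against; not the real GroupQueryAttentionParameters.
struct ParamsSketch {
  int kv_num_heads = 0;
  int head_size = 0;
  bool use_external_kv = false;
  int external_kv_sequence_length = 0;
};

// dims are assumed BNSH: (batch, kv_num_heads, sequence, head_size), matching
// the "BNSH format" comments in gqa_attention_base.h below.
bool CheckAndSetExternalKVSketch(const int64_t* key_dims, const int64_t* value_dims,
                                 ParamsSketch& p, std::string& error) {
  for (int i = 0; i < 4; ++i) {
    if (key_dims[i] != value_dims[i]) {
      error = "external_key and external_value must have identical shapes";
      return false;
    }
  }
  if (key_dims[1] != p.kv_num_heads || key_dims[3] != p.head_size) {
    error = "external KV head count / head size must match the node attributes";
    return false;
  }
  p.use_external_kv = true;
  p.external_kv_sequence_length = static_cast<int>(key_dims[2]);
  return true;
}

int main() {
  ParamsSketch p;
  p.kv_num_heads = 2;
  p.head_size = 64;
  const int64_t dims[4] = {1, 2, 128, 64};  // (B, kv_heads, S_ext, H)
  std::string error;
  return CheckAndSetExternalKVSketch(dims, dims, p, error) ? 0 : 1;
}
```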
39 changes: 31 additions & 8 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -104,6 +104,10 @@ class GQAAttentionBase {

bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;

+ // External KV mode: K and V are nullptr, past_key/past_value contain the external KV data.
+ // Skip KV cache concatenation and use external KV directly.
+ const bool use_external_kv = parameters.use_external_kv;
+
const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;

T* output_qk_buffer = output_qk != nullptr ? output_qk->MutableData<T>() : nullptr;
@@ -112,28 +116,28 @@
ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, head_sink, seqlens_k->Data<int32_t>(), attention_bias_data,
batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache,
seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer,
- past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
+ past_present_share_buffer, packed_qkv, is_prompt, use_external_kv, tp, allocator);

// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v,
seqlens_k->Data<int32_t>(),
batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
- is_prompt, tp, allocator);
+ is_prompt, use_external_kv, tp, allocator);
} else {
ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, head_sink, seqlens_k->Data<int32_t>(), attention_bias_data,
batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache,
seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer,
- past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
+ past_present_share_buffer, packed_qkv, is_prompt, use_external_kv, tp, allocator);

// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
seqlens_k->Data<int32_t>(),
batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
- is_prompt, tp, allocator);
+ is_prompt, use_external_kv, tp, allocator);
}

return Status::OK();
@@ -164,6 +168,7 @@ class GQAAttentionBase {
const bool past_present_share_buffer, // whether present key and value share the same buffer
const bool packed_qkv, // whether Q, K, V are packed
const bool is_prompt, // whether it is prompt
+ const bool use_external_kv, // whether using external KV (skip KV concat)
ThreadPool* tp, // thread pool
AllocatorPtr allocator) const { // allocator for temporary buffer
const ptrdiff_t packed_batch_stride =
@@ -237,12 +242,21 @@
}

const T* k;
- if (packed_qkv) {
+ if (use_external_kv) {
+ // External KV mode: use past_key directly (it holds the external KV data in BNSH format).
+ // No new K to concatenate — the external KV is the complete key sequence.
+ k = past_key + (i / kv_num_heads_factor) * present_buff_chunk_length;
+ // Copy to present for output pass-through (once per KV head, not per Q head)
+ if (present_key != nullptr && !past_present_share_buffer && head_index % kv_num_heads_factor == 0) {
+ memcpy(present_key + (i / kv_num_heads_factor) * present_buff_chunk_length,
+ k, SafeInt<size_t>(total_seqlen) * head_size * sizeof(T));
+ }
+ } else if (packed_qkv) {
k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
} else {
k = K + kv_input_chunk_length * (i / kv_num_heads_factor);
}
- if (nullptr != present_key) {
+ if (!use_external_kv && nullptr != present_key) {
k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length,
past_chunk_length, kv_input_chunk_length, past_present_share_buffer,
i / kv_num_heads_factor);
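Both the chunk offset (i / kv_num_heads_factor) and the copy guard (head_index % kv_num_heads_factor == 0) rely on GQA head grouping: each block of kv_num_heads_factor consecutive query heads shares one KV head, so the pass-through memcpy should fire once per group. A standalone illustration with example sizes (8 query heads over 2 KV heads):

```cpp
#include <cstdio>

int main() {
  // Example sizes; in the kernel these come from the node attributes.
  const int num_heads = 8;     // query heads
  const int kv_num_heads = 2;  // shared KV heads
  const int factor = num_heads / kv_num_heads;  // kv_num_heads_factor == 4

  // (The kernel's loop index i also folds in the batch, i = b * num_heads + head_index,
  // so i / factor lands on the flattened (batch, kv_head) chunk.)
  for (int head_index = 0; head_index < num_heads; ++head_index) {
    const int kv_head = head_index / factor;         // which KV chunk this Q head reads
    const bool copies = (head_index % factor) == 0;  // pass-through memcpy fires here
    std::printf("Q head %d -> KV head %d%s\n", head_index, kv_head,
                copies ? "  (performs the present copy)" : "");
  }
  return 0;
}
```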
@@ -392,6 +406,7 @@ class GQAAttentionBase {
const bool past_present_share_buffer, // whether present key and value share the same buffer
const bool packed_qkv, // whether Q, K, V are packed
const bool is_prompt, // whether it is prompt
+ const bool use_external_kv, // whether using external KV (skip KV concat)
ThreadPool* tp,
AllocatorPtr allocator) const {
const ptrdiff_t packed_batch_stride =
@@ -445,12 +460,20 @@
const size_t past_chunk_length = SafeInt<size_t>(past_seqlen) * head_size;

const T* v;
- if (packed_qkv) {
+ if (use_external_kv) {
+ // External KV mode: use past_value directly (it holds the external KV data in BNSH format).
+ v = past_value + (i / kv_num_heads_factor) * present_buff_chunk_length;
+ // Copy to present for output pass-through (once per KV head, not per Q head)
+ if (present_value != nullptr && !past_present_share_buffer && head_index % kv_num_heads_factor == 0) {
+ memcpy(present_value + (i / kv_num_heads_factor) * present_buff_chunk_length,
+ v, SafeInt<size_t>(total_seqlen) * head_size * sizeof(T));
+ }
+ } else if (packed_qkv) {
v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
} else {
v = V + kv_input_chunk_length * (i / kv_num_heads_factor);
}
- if (nullptr != present_value) {
+ if (!use_external_kv && nullptr != present_value) {
v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
past_chunk_length, kv_input_chunk_length, past_present_share_buffer,
i / kv_num_heads_factor);
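For contrast with the external path above: in normal operation ConcatStateChunkGQA appends the new K/V chunk after the past chunk inside the present buffer and hands back a pointer to the combined sequence, which is exactly the step use_external_kv skips. A schematic of that layout under simplified assumptions (float data, non-shared past/present buffers, hypothetical names):

```cpp
#include <cstring>
#include <vector>

// Schematic of the concat step (hedged; the real ConcatStateChunkGQA also
// handles the shared past/present buffer case, which this sketch omits).
// Chunks are per (batch, kv_head): past holds past_seqlen*head_size floats,
// input holds new_seqlen*head_size, present receives their concatenation.
const float* ConcatChunkSketch(const float* past, const float* input, float* present,
                               size_t past_seqlen, size_t new_seqlen, size_t head_size) {
  std::memcpy(present, past, past_seqlen * head_size * sizeof(float));
  std::memcpy(present + past_seqlen * head_size, input, new_seqlen * head_size * sizeof(float));
  return present;  // attention then reads past_seqlen + new_seqlen rows from here
}

int main() {
  const size_t head_size = 4, past_seqlen = 3, new_seqlen = 2;
  std::vector<float> past(past_seqlen * head_size, 1.0f);
  std::vector<float> input(new_seqlen * head_size, 2.0f);
  std::vector<float> present((past_seqlen + new_seqlen) * head_size);
  ConcatChunkSketch(past.data(), input.data(), present.data(), past_seqlen, new_seqlen, head_size);
  return 0;
}
```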
35 changes: 29 additions & 6 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -55,6 +55,8 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
const Tensor* position_ids = context->Input<Tensor>(9);
const Tensor* attention_bias = context->Input<Tensor>(10);
const Tensor* head_sink = context->Input<Tensor>(11);
+ const Tensor* external_key = context->Input<Tensor>(14);
+ const Tensor* external_value = context->Input<Tensor>(15);

GroupQueryAttentionParameters parameters = {};
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
@@ -71,13 +73,23 @@
total_seqlen_tensor,
scale_,
softcap_,
- 0));
+ 0,
+ external_key != nullptr));

ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids,
attention_bias,
head_sink,
parameters));

+ ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckAndSetExternalKV(external_key, external_value, parameters));
+
+ // External KV mode requires do_rotary=0 — K already has RoPE from the source layer
+ if (parameters.use_external_kv && do_rotary_) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "do_rotary must be 0 when using external_key/external_value. "
+ "Pre-apply RoPE to Q and use already-rotated K from the source layer.");
+ }
+
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int present_kv_seqlen = parameters.seqlen_present_kv_cache;
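Since do_rotary must be 0 in external-KV mode, the producer of Q has to rotate it before this kernel runs (K arrives already rotated from the source layer). A minimal sketch of non-interleaved RoPE applied to one query row, assuming the half-split pairing convention; this is illustrative math, not the kernel's rotary implementation:

```cpp
#include <cmath>
#include <vector>

// Rotate one query row in place. Non-interleaved convention: element j pairs
// with element j + head_size/2. cos_row/sin_row hold head_size/2 values for
// this token's position (as a cos/sin cache row would).
void ApplyRopeToQueryRow(float* q, int head_size, const float* cos_row, const float* sin_row) {
  const int half = head_size / 2;
  for (int j = 0; j < half; ++j) {
    const float x0 = q[j];
    const float x1 = q[j + half];
    q[j]        = x0 * cos_row[j] - x1 * sin_row[j];
    q[j + half] = x0 * sin_row[j] + x1 * cos_row[j];
  }
}

int main() {
  const int head_size = 8, position = 5;
  std::vector<float> q(head_size, 1.0f), cos_row(head_size / 2), sin_row(head_size / 2);
  for (int j = 0; j < head_size / 2; ++j) {
    // Standard RoPE frequencies: theta_j = position * base^(-2j/d), base = 10000.
    const float theta = position * std::pow(10000.0f, -2.0f * j / head_size);
    cos_row[j] = std::cos(theta);
    sin_row[j] = std::sin(theta);
  }
  ApplyRopeToQueryRow(q.data(), head_size, cos_row.data(), sin_row.data());
  return 0;
}
```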
@@ -125,7 +137,11 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
OrtValue Q;
OrtValue K;
OrtValue V;
- if (packed_qkv) {
+ if (parameters.use_external_kv) {
+ // External KV mode: only Q needs transposing. K,V come from external tensors.
+ ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
+ allocator, batch_size, num_heads_, sequence_length, head_size, query, Q));
+ } else if (packed_qkv) {
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q));
} else {
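In external-KV mode only Q is run through MaybeTransposeToBNSH, since K and V arrive pre-laid-out. Conceptually the helper is a BSNH to BNSH relayout (presumably skippable when no copy is needed, hence the "Maybe"; that reading is an assumption). A plain sketch of the index math:

```cpp
#include <vector>

// BSNH -> BNSH copy: out[b][n][s][h] = in[b][s][n][h].
// B = batch, S = sequence, N = num_heads, H = head_size.
void TransposeToBNSHSketch(const float* in, float* out, int B, int S, int N, int H) {
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < S; ++s)
      for (int n = 0; n < N; ++n)
        for (int h = 0; h < H; ++h)
          out[((static_cast<size_t>(b) * N + n) * S + s) * H + h] =
              in[((static_cast<size_t>(b) * S + s) * N + n) * H + h];
}

int main() {
  const int B = 1, S = 3, N = 2, H = 4;
  std::vector<float> in(B * S * N * H), out(in.size());
  for (size_t i = 0; i < in.size(); ++i) in[i] = static_cast<float>(i);
  TransposeToBNSHSketch(in.data(), out.data(), B, S, N, H);
  return 0;
}
```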
@@ -141,8 +157,8 @@
OrtValue RotaryQ;
OrtValue RotaryK;
T* q_rotary = Q.GetMutable<Tensor>()->MutableData<T>();
- T* k_rotary = packed_qkv ? nullptr : K.GetMutable<Tensor>()->MutableData<T>();
- if (do_rotary_) {
+ T* k_rotary = (packed_qkv || parameters.use_external_kv) ? nullptr : K.GetMutable<Tensor>()->MutableData<T>();
+ if (do_rotary_ && !parameters.use_external_kv) {
ORT_ENFORCE(cos_cache != nullptr && sin_cache != nullptr, "cos_cache and sin_cache must be provided when do_rotary is true");
// Initialize rotary parameters
rotary_embedding_helper::RotaryParameters rotary_params = {};
@@ -232,9 +248,16 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {

const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data<T>() : nullptr;

+ // When external KV is provided, use it in place of past_key/past_value for attention computation.
+ // External KV is pre-computed from another layer (KV-shared layers, e.g., Gemma4).
+ const Tensor* effective_past_key = parameters.use_external_kv ? external_key : past_key;
+ const Tensor* effective_past_value = parameters.use_external_kv ? external_value : past_value;
+
// Compute the attention score and apply the score to V
- return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
- head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v,
+ const T* k_data = (packed_qkv || parameters.use_external_kv) ? nullptr : k_rotary;
+ const T* v_data = (packed_qkv || parameters.use_external_kv) ? nullptr : V.Get<Tensor>().Data<T>();
+ return ApplyAttention(q_rotary, k_data, v_data,
+ head_sink_data, attention_bias, effective_past_key, effective_past_value, output, present_k, present_v,
output_qk, seqlens_k, parameters, allocator, context);
}
} // namespace contrib
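Putting the pieces together: the intent is that a source layer's present K/V feed a KV-shared layer's external_key/external_value inputs (slots 14 and 15 above) with do_rotary=0 on the consuming node. A toy sketch of that data flow, with every name hypothetical; the real wiring happens in the ONNX graph, not in C++ like this:

```cpp
#include <vector>

// Toy stand-ins; real tensors are OrtValues routed through the graph.
struct KV {
  std::vector<float> key;    // BNSH-flattened
  std::vector<float> value;  // BNSH-flattened
  int sequence_length = 0;
};

// The source layer runs ordinary GQA: it rotates K, fills its cache, and its
// present K/V (already RoPE'd) are what get shared downstream.
KV RunSourceLayer() {
  KV kv;
  kv.sequence_length = 128;          // becomes external_kv_sequence_length
  kv.key.resize(1 * 2 * 128 * 64);   // (B=1, kv_heads=2, S=128, H=64)
  kv.value.resize(1 * 2 * 128 * 64);
  return kv;
}

// The shared layer passes external.key / external.value where past_key /
// past_value would normally flow, with use_external_kv = true and do_rotary = 0.
int RunSharedLayer(const KV& external) {
  return external.sequence_length;
}

int main() {
  const KV kv = RunSourceLayer();
  return RunSharedLayer(kv) == 128 ? 0 : 1;
}
```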