Add external_key/external_value inputs to GroupQueryAttention for KV-shared layers #28242
base: main
@@ -85,7 +85,9 @@ class GQAAttentionBase {
     if (past_key != nullptr && past_value != nullptr) {
       seqlen_past_kv_cache = static_cast<int>(past_key->Shape().GetDims()[2]);
     }
-    int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);
+    int seqlen_present_kv_cache = present_key != nullptr
+                                      ? static_cast<int>(present_key->Shape().GetDims()[2])
+                                      : parameters.seqlen_present_kv_cache;

     // Compute the attention score.
     bool gqa_mlas_supported = MlasGQASupported<T>(CblasNoTrans, CblasTrans) &&
@@ -175,7 +177,7 @@ class GQAAttentionBase {
     const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size;  // L x H
     const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size;  // T x H

-    if (!past_present_share_buffer) {
+    if (present_key && !past_present_share_buffer) {
       memset((void*)present_key,
              0,
              batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
@@ -402,7 +404,7 @@ class GQAAttentionBase {
     const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size;  // L x H
     const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size;  // T x H

-    if (!past_present_share_buffer) {
+    if (present_value && !past_present_share_buffer) {
       memset((void*)present_value,
              0,
              batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
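The CPU-side changes above all follow one pattern: treat the present KV outputs as optional, fall back to the kernel parameters for the cache length, and skip any writes when the buffer is absent. Below is a minimal standalone sketch of that pattern (illustration only, not code from this PR); the PresentKV struct and the helper names are invented for the example.

#include <cstddef>
#include <cstring>

// Hypothetical stand-in for an optional present_key/present_value output tensor.
struct PresentKV {
  float* data = nullptr;      // nullptr when the output is not connected (KV-shared layer)
  int sequence_length = 0;    // valid only when data != nullptr
};

// Mirrors: present_key != nullptr ? present_key->Shape().GetDims()[2]
//                                 : parameters.seqlen_present_kv_cache
int PresentSeqLen(const PresentKV& present, int seqlen_from_parameters) {
  return present.data != nullptr ? present.sequence_length : seqlen_from_parameters;
}

// Mirrors: if (present_key && !past_present_share_buffer) memset(present_key, 0, ...)
void MaybeZeroPresent(const PresentKV& present, size_t element_count, bool past_present_share_buffer) {
  if (present.data != nullptr && !past_present_share_buffer) {
    std::memset(present.data, 0, element_count * sizeof(float));
  }
}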
@@ -98,6 +98,14 @@ Status PrepareQKV(
     q_out = nullptr;
   }

+  // present_key/present_value are required for the CUDA path since flash attention
+  // and memory-efficient attention read directly from the present KV buffers.
+  // The CPU path supports optional present outputs for KV-shared layers.
+  if (data.present_key == nullptr || data.present_value == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "present_key and present_value outputs are required for the CUDA GroupQueryAttention kernel.");
Comment on lines +101 to +106

Suggested change:
-  // present_key/present_value are required for the CUDA path since flash attention
-  // and memory-efficient attention read directly from the present KV buffers.
-  // The CPU path supports optional present outputs for KV-shared layers.
-  if (data.present_key == nullptr || data.present_value == nullptr) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "present_key and present_value outputs are required for the CUDA GroupQueryAttention kernel.");
+  // present_key/present_value are currently required for the CUDA path since flash attention
+  // and memory-efficient attention read directly from the present KV buffers.
+  // Note: although the operator schema/docs may declare these outputs optional,
+  // the CUDA EP does not currently support omitting them.
+  if (data.present_key == nullptr || data.present_value == nullptr) {
+    return ORT_MAKE_STATUS(
+        ONNXRUNTIME, INVALID_ARGUMENT,
+        "CUDA GroupQueryAttention currently requires both present_key and present_value outputs to be connected "
+        "because the CUDA kernels read directly from the present KV buffers. "
+        "If your model relies on omitted optional present outputs, either connect both outputs when using the CUDA "
+        "execution provider or run this node on CPU.");
@@ -307,5 +307,108 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongLength) {
            {}, nullptr, &execution_providers);
 }

+// ============================================================================
+// Optional present_key/present_value output tests
+// ============================================================================
+
+// Helper for tests with optional present outputs.
+// When omit_present=true, present_key and present_value outputs are not connected.
+static void RunGQAOptionalPresentTest(
+    int batch_size,
+    int sequence_length,
+    int total_seq_len,
+    bool omit_present,
+    OpTester::ExpectResult expect,
+    const std::string& expected_message) {
+  constexpr int num_heads = 2;
+  constexpr int kv_num_heads = 1;
+  constexpr int head_size = 8;
+  constexpr int hidden_size = num_heads * head_size;
+  constexpr int kv_hidden_size = kv_num_heads * head_size;
+
+  OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
+  tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
+  tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kv_num_heads));
+
+  std::vector<float> query_data(batch_size * sequence_length * hidden_size, 1.0f);
+  tester.AddInput<float>("query", {batch_size, sequence_length, hidden_size}, query_data);
+
+  std::vector<float> key_data(batch_size * sequence_length * kv_hidden_size, 0.5f);
+  tester.AddInput<float>("key", {batch_size, sequence_length, kv_hidden_size}, key_data);
+
+  std::vector<float> value_data(batch_size * sequence_length * kv_hidden_size, 0.5f);
+  tester.AddInput<float>("value", {batch_size, sequence_length, kv_hidden_size}, value_data);
+
+  tester.AddOptionalInputEdge<float>();  // past_key
+  tester.AddOptionalInputEdge<float>();  // past_value
+
+  std::vector<int32_t> seqlens_k_data(batch_size, static_cast<int32_t>(total_seq_len - 1));
+  tester.AddInput<int32_t>("seqlens_k", {batch_size}, seqlens_k_data);
+  tester.AddInput<int32_t>("total_sequence_length", {1}, {static_cast<int32_t>(total_seq_len)});
+
+  tester.AddOptionalInputEdge<float>();    // cos_cache
+  tester.AddOptionalInputEdge<float>();    // sin_cache
+  tester.AddOptionalInputEdge<int64_t>();  // position_ids
+  tester.AddOptionalInputEdge<float>();    // attention_bias
+  tester.AddOptionalInputEdge<float>();    // head_sink
+
+  // Output 0: output (always required)
+  tester.AddOutput<float>("output", {batch_size, sequence_length, hidden_size},
+                          std::vector<float>(batch_size * sequence_length * hidden_size, 0.0f));
+
+  if (omit_present) {
+    // Omit present_key and present_value; they are optional
+    tester.AddOptionalOutputEdge<float>();  // present_key
+    tester.AddOptionalOutputEdge<float>();  // present_value
+  } else {
+    int present_seq_len = total_seq_len;
+    tester.AddOutput<float>("present_key", {batch_size, kv_num_heads, present_seq_len, head_size},
+                            std::vector<float>(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f));
+    tester.AddOutput<float>("present_value", {batch_size, kv_num_heads, present_seq_len, head_size},
+                            std::vector<float>(batch_size * kv_num_heads * present_seq_len * head_size, 0.0f));
+  }
+
+  if (expect == OpTester::ExpectResult::kExpectSuccess) {
+    tester.SetOutputTolerance(1e6f);
+  }
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCpuExecutionProvider());
+  tester.Run(expect, expected_message, {}, nullptr, &execution_providers);
+}
+
+// Baseline: GQA with present outputs connected works as before
+TEST(GroupQueryAttentionTest, OptionalPresent_WithPresent) {
+  RunGQAOptionalPresentTest(
+      /*batch_size=*/1,
+      /*sequence_length=*/4,
+      /*total_seq_len=*/4,
+      /*omit_present=*/false,
+      OpTester::ExpectResult::kExpectSuccess,
+      "");
+}
+
+// KV-shared layer scenario: present outputs omitted, attention uses K,V directly
+TEST(GroupQueryAttentionTest, OptionalPresent_WithoutPresent) {
+  RunGQAOptionalPresentTest(
+      /*batch_size=*/1,
+      /*sequence_length=*/4,
+      /*total_seq_len=*/4,
+      /*omit_present=*/true,
+      OpTester::ExpectResult::kExpectSuccess,
+      "");
+}
+
+// Batched: present outputs omitted with batch_size > 1
+TEST(GroupQueryAttentionTest, OptionalPresent_Batched) {
+  RunGQAOptionalPresentTest(
+      /*batch_size=*/2,
+      /*sequence_length=*/3,
+      /*total_seq_len=*/3,
+      /*omit_present=*/true,
+      OpTester::ExpectResult::kExpectSuccess,
+      "");
+}
+
 }  // namespace test
 }  // namespace onnxruntime
Docs correctly mark present_key/present_value as optional, but (per the CUDA implementation) omitting them is not currently supported by the CUDA EP. Add a brief note here (mandatory if the CUDA limitation remains) stating that some execution providers (CUDA) may require these outputs to be connected, so users aren't surprised by provider-specific INVALID_ARGUMENT errors.
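To make the provider-specific limitation concrete, a CUDA-side counterpart to the CPU tests above might look like the sketch below. This is not part of this PR: the test name, the USE_CUDA guard, the half-precision inputs, and the expectation that the CUDA kernel rejects omitted present outputs with the INVALID_ARGUMENT message added in this change are all assumptions for illustration.

#if defined(USE_CUDA)
// Hypothetical sketch: the CUDA EP is expected to reject GQA nodes whose present outputs are omitted.
TEST(GroupQueryAttentionTest, OptionalPresent_WithoutPresent_CudaRejected) {
  constexpr int batch_size = 1, sequence_length = 4, total_seq_len = 4;
  constexpr int num_heads = 2, kv_num_heads = 1, head_size = 8;
  constexpr int hidden_size = num_heads * head_size;
  constexpr int kv_hidden_size = kv_num_heads * head_size;

  OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
  tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
  tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kv_num_heads));

  // Use half-precision inputs, since the CUDA GroupQueryAttention kernel targets fp16/bf16.
  tester.AddInput<MLFloat16>("query", {batch_size, sequence_length, hidden_size},
                             std::vector<MLFloat16>(batch_size * sequence_length * hidden_size, MLFloat16(1.0f)));
  tester.AddInput<MLFloat16>("key", {batch_size, sequence_length, kv_hidden_size},
                             std::vector<MLFloat16>(batch_size * sequence_length * kv_hidden_size, MLFloat16(0.5f)));
  tester.AddInput<MLFloat16>("value", {batch_size, sequence_length, kv_hidden_size},
                             std::vector<MLFloat16>(batch_size * sequence_length * kv_hidden_size, MLFloat16(0.5f)));
  tester.AddOptionalInputEdge<MLFloat16>();  // past_key
  tester.AddOptionalInputEdge<MLFloat16>();  // past_value
  tester.AddInput<int32_t>("seqlens_k", {batch_size},
                           std::vector<int32_t>(batch_size, static_cast<int32_t>(total_seq_len - 1)));
  tester.AddInput<int32_t>("total_sequence_length", {1}, {static_cast<int32_t>(total_seq_len)});
  // Remaining inputs (cos_cache, sin_cache, ...) are trailing optionals and are left unconnected.

  tester.AddOutput<MLFloat16>("output", {batch_size, sequence_length, hidden_size},
                              std::vector<MLFloat16>(batch_size * sequence_length * hidden_size, MLFloat16(0.0f)));
  tester.AddOptionalOutputEdge<MLFloat16>();  // present_key omitted
  tester.AddOptionalOutputEdge<MLFloat16>();  // present_value omitted

  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  execution_providers.push_back(DefaultCudaExecutionProvider());
  // Expect the INVALID_ARGUMENT error introduced in this PR (substring match).
  tester.Run(OpTester::ExpectResult::kExpectFailure,
             "present_key and present_value",
             {}, nullptr, &execution_providers);
}
#endif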