Add Node to update KV cache in Stateful LLM model #872
base: ovep-develop
Changes from 8 commits
@@ -288,6 +288,68 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span<const ch

```cpp
          LOGS_DEFAULT(WARNING) << "kvcache_rewind index is < 0:\t" << index;
        }
      }
    } else if (key == "kvcache_reorder") {
      // Convert kvcache_reorder value format "1,2,3;4,5,6" into two vectors:
      // src_indices = [1,2,3], dst_indices = [4,5,6]
      size_t delimiter_pos = value.find(';');
      if (delimiter_pos == std::string::npos) {
        LOGS_DEFAULT(WARNING) << "kvcache_reorder value format is incorrect, expected format is 'x1,x2,x3;y1,y2,y3' where x and y are comma-separated int64_t lists";
        return Status::OK();
      }

      std::string src_string = value.substr(0, delimiter_pos);
      std::string dst_string = value.substr(delimiter_pos + 1);

      std::vector<size_t> src_indices;
      std::vector<size_t> dst_indices;

      try {
        // Parse source indices from comma-separated string
        std::stringstream src_stream(src_string);
        std::string src_token;
        while (std::getline(src_stream, src_token, ',')) {
          // Trim whitespace
          src_token.erase(0, src_token.find_first_not_of(" \t"));
          src_token.erase(src_token.find_last_not_of(" \t") + 1);

          if (!src_token.empty()) {
            int64_t index = std::stoll(src_token);
            if (index >= 0) {
              src_indices.push_back(static_cast<size_t>(index));
            } else {
              LOGS_DEFAULT(WARNING) << "kvcache_reorder src_index is < 0: " << index;
            }
          }
        }

        // Parse destination indices from comma-separated string
        std::stringstream dst_stream(dst_string);
        std::string dst_token;
        while (std::getline(dst_stream, dst_token, ',')) {
          // Trim whitespace
          dst_token.erase(0, dst_token.find_first_not_of(" \t"));
          dst_token.erase(dst_token.find_last_not_of(" \t") + 1);

          if (!dst_token.empty()) {
            int64_t index = std::stoll(dst_token);
            if (index >= 0) {
              dst_indices.push_back(static_cast<size_t>(index));
            } else {
              LOGS_DEFAULT(WARNING) << "kvcache_reorder dst_index is < 0: " << index;
            }
          }
        }
      } catch (const std::exception& e) {
        LOGS_DEFAULT(WARNING) << "Conversion for kvcache_reorder string value to int64_t indices failed. "
                              << "Exception: " << e.what();
        return Status::OK();
      }

      // Trigger KVCache Reorder for target Backend with vector arguments
      for (auto& backend : backend_managers_) {
        backend.ReorderKVCache(src_indices, dst_indices);
      }
    } else {
      // Handle unknown options
      LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value;
```
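For context, here is a minimal sketch of how an application could trigger this path at runtime. It assumes the `SetEpDynamicOptions` entry point exposed by the ONNX Runtime C++ API wrapper (the exact wrapper name and signature are an assumption, not confirmed by this PR); the value string uses the format documented in the hunk above, with source positions before the semicolon and destination positions after it.

```cpp
// Hedged sketch: requesting a KV-cache reorder from application code.
// Assumes Ort::Session::SetEpDynamicOptions(keys, values, count) is available
// in the ONNX Runtime C++ wrapper (an assumption; not part of this PR).
#include <onnxruntime_cxx_api.h>

void RequestKVCacheReorder(Ort::Session& session) {
  const char* keys[] = {"kvcache_reorder"};
  // Per the ScatterElementsUpdate fusion later in this PR, this asks the EP to
  // copy cache entries at sequence positions 1,2,3 onto positions 4,5,6.
  const char* values[] = {"1,2,3;4,5,6"};
  session.SetEpDynamicOptions(keys, values, 1);
}
```

Note that, as far as this hunk shows, the parser does not itself check that the two lists have equal length; that appears to be left to the backend.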
@@ -423,6 +423,26 @@ void StatefulOVInferRequest::PreProcessInferRequest() {

```cpp
  // TODO(ankit): Address this issue and implement the fix at the appropriate layer.
  FillTensor("beam_idx", ov::element::i32, {1}, 0);

  if (kv_src_indices.size() > 0) {
    ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {kv_src_indices.size()});
    for (int i = 0; i < kv_src_indices.size(); ++i) {
      src_idx_tensor.data<int32_t>()[i] = int32_t(kv_src_indices[i]);
    }
    ovInfReq.set_tensor("src_idx", src_idx_tensor);
    ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, 32, kv_dst_indices.size(), 96});
    for (int i = 0; i < kv_dst_indices.size(); ++i) {
      // ... (the remainder of this hunk is not captured in this view)
```
Copilot review comments (Dec 9, 2025) on this hunk:

Suggested change (marked outdated) switching both loop counters to `size_t`:

```diff
-    for (int i = 0; i < kv_src_indices.size(); ++i) {
+    for (size_t i = 0; i < kv_src_indices.size(); ++i) {
       src_idx_tensor.data<int32_t>()[i] = int32_t(kv_src_indices[i]);
     }
     ovInfReq.set_tensor("src_idx", src_idx_tensor);
     ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, 32, kv_dst_indices.size(), 96});
-    for (int i = 0; i < kv_dst_indices.size(); ++i) {
+    for (size_t i = 0; i < kv_dst_indices.size(); ++i) {
```

Copilot (marked outdated): "Loop variable i should be size_t instead of int to match the type of kv_dst_indices.size() and avoid signed/unsigned comparison warnings."

```diff
-    for (int i = 0; i < kv_dst_indices.size(); ++i) {
+    for (size_t i = 0; i < kv_dst_indices.size(); ++i) {
```

Copilot: "The hardcoded shape {1, 32, 0, 96} contains magic numbers (32 and 96) that should match the constants used in line 432. These should be extracted as named constants to ensure consistency and improve maintainability."

Copilot: "Loop variable i should be size_t instead of int to match the type of src_indices.size() and avoid signed/unsigned comparison warnings."

```diff
-  for (int i = 0; i < src_indices.size(); ++i) {
+  for (size_t i = 0; i < src_indices.size(); ++i) {
```
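Applying those suggestions, the tensor-binding block above might look like the following sketch. The constant names (`kNumKVHeads`, `kHeadSize`) are hypothetical, and their values simply mirror the hardcoded 32 and 96 in this hunk; this is not the PR's actual code.

```cpp
// Hedged sketch of the same block with the Copilot suggestions applied:
// size_t loop counters and named constants instead of the magic numbers 32/96.
// kNumKVHeads / kHeadSize are illustrative names only.
constexpr size_t kNumKVHeads = 32;
constexpr size_t kHeadSize = 96;

if (!kv_src_indices.empty()) {
  ov::Tensor src_idx_tensor(ov::element::i32, {kv_src_indices.size()});
  for (size_t i = 0; i < kv_src_indices.size(); ++i) {
    src_idx_tensor.data<int32_t>()[i] = static_cast<int32_t>(kv_src_indices[i]);
  }
  ovInfReq.set_tensor("src_idx", src_idx_tensor);

  ov::Tensor dst_idx_tensor(ov::element::i32,
                            {1, kNumKVHeads, kv_dst_indices.size(), kHeadSize});
  // ... fill dst_idx_tensor with the destination positions and bind it as
  // "dst_idx", following the (truncated) original hunk.
}
```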
@@ -132,6 +132,7 @@ class OVInferRequest {

```cpp
    return ovInfReq;
  }
  virtual void RewindKVCache([[maybe_unused]] size_t index) {}
  virtual void ReorderKVCache([[maybe_unused]] const std::vector<size_t>& src_indices, [[maybe_unused]] const std::vector<size_t>& dst_indices) {}
};

class StatefulOVInferRequest : public OVInferRequest {
```

A review thread on the new no-op `ReorderKVCache` overload is truncated in this capture (the question begins "Is ..."); the author replied that it is not necessary for functionality and that the intention is to follow the original style.

@@ -140,6 +141,7 @@ class StatefulOVInferRequest : public OVInferRequest {

```cpp
  void Infer() override;
  void RewindKVCache(size_t index) override;
  void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) override;
  void FillTensor(const std::string& tensor_name, const ov::element::Type& type,
                  const std::vector<size_t>& shape, int32_t fill_value);
  void CacheTensor(const std::string& tensor_name, std::vector<int64_t>& cache);
```

@@ -148,13 +150,16 @@ class StatefulOVInferRequest : public OVInferRequest {

```cpp
 private:
  void PreProcessInferRequest();
  void PostProcessInferRequest();
  std::string target_device;

  // If prefill_use_full_chat_history is true, cache the "input_ids" & "position_ids" tensors,
  // and ensure that full chat history is passed for each prefill call.
  bool prefill_use_full_chat_history = false;
  std::vector<int64_t> cached_input_ids;
  std::vector<int64_t> cached_position_ids;
  std::vector<int64_t> kv_src_indices;
  std::vector<int64_t> kv_dst_indices;
};

}  // namespace openvino_ep
```
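The body of `StatefulOVInferRequest::ReorderKVCache` is not visible in this capture. Given the `kv_src_indices` / `kv_dst_indices` members added above and the way `PreProcessInferRequest` consumes them, a minimal sketch of what such an override could do is shown below; this is an assumption for illustration, not the PR's actual implementation.

```cpp
// Hedged sketch only: record the requested reorder so the next
// PreProcessInferRequest() call can bind the "src_idx" and "dst_idx" tensors.
// The PR's real implementation is not shown in this diff view.
void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indices,
                                            const std::vector<size_t>& dst_indices) {
  kv_src_indices.assign(src_indices.begin(), src_indices.end());
  kv_dst_indices.assign(dst_indices.begin(), dst_indices.end());
}
```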
@@ -91,13 +91,26 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,

```cpp
  std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates);

  auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
  auto update_shape = ov_model->input(key_value_input_names[0]).get_partial_shape();

  auto beam_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({std::move(input_batch)}));
  beam_idx->set_friendly_name("beam_idx");
  beam_idx->output(0).get_tensor().add_names({"beam_idx"});
  ov_model->add_parameters({beam_idx});
  not_kv_inputs.push_back(beam_idx->get_friendly_name());

  auto src_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({update_shape[2]}));
  src_idx->set_friendly_name("src_idx");
  src_idx->output(0).get_tensor().add_names({"src_idx"});
  ov_model->add_parameters({src_idx});
  not_kv_inputs.push_back(src_idx->get_friendly_name());

  auto dst_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, update_shape);
  dst_idx->set_friendly_name("dst_idx");
  dst_idx->output(0).get_tensor().add_names({"dst_idx"});
  ov_model->add_parameters({dst_idx});
  not_kv_inputs.push_back(dst_idx->get_friendly_name());

  // Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
  for (const auto& input_name : key_value_input_names) {
    auto parameter_output_port = ov_model->input(input_name);
```

@@ -108,9 +121,17 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,

```cpp
                                       beam_idx,
                                       ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim}));

    auto updatekv_gather_op =
        std::make_shared<ov::opset13::Gather>(gather_op,
                                              src_idx,
                                              ov::opset13::Constant::create(ov::element::i64, {}, {2}));

    auto updatekv_op = std::make_shared<ov::opset12::ScatterElementsUpdate>(gather_op,
        dst_idx, updatekv_gather_op, ov::opset13::Constant::create(ov::element::i64, {}, {2}));

    // Replace the source output for all consumers of the input tensor
    for (auto& consumer : consumers) {
      consumer.replace_source_output(updatekv_op->output(0));  // previously: gather_op->output(0)
    }
  }
```
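To make the semantics of the fused subgraph concrete: for each past-KV input, the existing `beam_idx` Gather reorders the cache along `gather_dim` as before, the new `src_idx` Gather pulls out the slices at the requested source positions along the sequence axis (axis 2), and `ScatterElementsUpdate` writes those slices back at the destination positions carried by `dst_idx`. The plain-C++ sketch below shows the equivalent tensor operation on one cache tensor; the function name, the flat row-major [B, H, S, D] layout, and the float element type are illustrative only.

```cpp
#include <cstddef>
#include <vector>

// Illustrative only: the "copy src slices onto dst slices along axis 2" effect
// that Gather(src_idx) + ScatterElementsUpdate(dst_idx) has on one KV-cache
// tensor, written as loops over a flat row-major [B, H, S, D] buffer.
void ReorderKVSlices(std::vector<float>& kv, size_t B, size_t H, size_t S, size_t D,
                     const std::vector<size_t>& src_idx,
                     const std::vector<size_t>& dst_idx) {
  const size_t N = src_idx.size();  // assumes dst_idx.size() == N
  // Gather first (mirroring the graph's Gather -> ScatterElementsUpdate order),
  // so overlapping src/dst positions still read the pre-update cache contents.
  std::vector<float> gathered(B * H * N * D);
  for (size_t b = 0; b < B; ++b)
    for (size_t h = 0; h < H; ++h)
      for (size_t n = 0; n < N; ++n)
        for (size_t d = 0; d < D; ++d)
          gathered[((b * H + h) * N + n) * D + d] =
              kv[((b * H + h) * S + src_idx[n]) * D + d];

  // Scatter the gathered slices onto the destination positions along axis 2.
  for (size_t b = 0; b < B; ++b)
    for (size_t h = 0; h < H; ++h)
      for (size_t n = 0; n < N; ++n)
        for (size_t d = 0; d < D; ++d)
          kv[((b * H + h) * S + dst_idx[n]) * D + d] =
              gathered[((b * H + h) * N + n) * D + d];
}
```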