diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc
index abb5b31b76e44..fe29d075b5829 100644
--- a/onnxruntime/core/providers/openvino/backend_manager.cc
+++ b/onnxruntime/core/providers/openvino/backend_manager.cc
@@ -892,5 +892,11 @@ void BackendManager::RewindKVCache(size_t index) {
   }
 }
 
+void BackendManager::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
+  if (concrete_backend_) {
+    concrete_backend_->ReorderKVCache(src_indices, dst_indices);
+  }
+}
+
 }  // namespace openvino_ep
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h
index 64dadb6c2151b..474bf2a01a019 100644
--- a/onnxruntime/core/providers/openvino/backend_manager.h
+++ b/onnxruntime/core/providers/openvino/backend_manager.h
@@ -31,6 +31,7 @@ class BackendManager {
   void TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, bool include_embed_data);
   ov::CompiledModel GetOVCompiledModel();
   void RewindKVCache(size_t index);
+  void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices);
 
  private:
   std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(
diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc
index d7fc0553fb1d4..d08fa548b388b 100644
--- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc
+++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc
@@ -334,6 +334,12 @@ void BasicBackend::RewindKVCache(size_t index) {
   });
 }
 
+void BasicBackend::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
+  infer_req_pool_->forEachIdleRequest([&](OVInferRequestPtr& infer_request) {
+    infer_request->ReorderKVCache(src_indices, dst_indices);
+  });
+}
+
 void BasicBackend::Infer(OrtKernelContext* ctx) const {
   Ort::KernelContext context(ctx);
diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h
index 2cf3d3faa8b47..a1b052ea7aa98 100644
--- a/onnxruntime/core/providers/openvino/backends/basic_backend.h
+++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h
@@ -138,6 +138,7 @@ class BasicBackend : public IBackend {
     return exe_network_.Get();
   }
   void RewindKVCache(size_t index) override;
+  void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) override;
 
  private:
   bool ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);
diff --git a/onnxruntime/core/providers/openvino/ibackend.h b/onnxruntime/core/providers/openvino/ibackend.h
index 365a4625815d6..672fdbc218a78 100644
--- a/onnxruntime/core/providers/openvino/ibackend.h
+++ b/onnxruntime/core/providers/openvino/ibackend.h
@@ -18,6 +18,7 @@ class IBackend {
   virtual ov::CompiledModel GetOVCompiledModel() = 0;
   virtual ~IBackend() = default;
   virtual void RewindKVCache(size_t index) {}
+  virtual void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {}
 };
 using ptr_stream_t = std::unique_ptr<std::istream>;
 class BackendFactory {
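Taken together, the changes above thread one new operation through the backend plumbing: BackendManager::ReorderKVCache forwards to the concrete backend, which applies it to every idle infer request, and the no-op default on IBackend leaves other backends unaffected. Semantically, each (src, dst) index pair copies the KV cache entry at a source sequence position onto a destination position, which speculative decoding uses to compact accepted draft tokens. A minimal toy sketch of that semantic, not ORT code (the flat cache layout and ReorderToyCache are illustrative):

#include <cstddef>
#include <iostream>
#include <vector>

// One float stands in for a whole KV cache row at each sequence position.
void ReorderToyCache(std::vector<float>& cache,
                     const std::vector<size_t>& src_indices,
                     const std::vector<size_t>& dst_indices) {
  // Snapshot first so overlapping src/dst pairs read pre-reorder values,
  // matching a gather-then-scatter over the original cache contents.
  const std::vector<float> snapshot = cache;
  for (size_t i = 0; i < src_indices.size(); ++i) {
    cache[dst_indices[i]] = snapshot[src_indices[i]];
  }
}

int main() {
  std::vector<float> cache = {10, 20, 30, 40};
  ReorderToyCache(cache, /*src*/ {2, 3}, /*dst*/ {0, 1});
  for (float v : cache) std::cout << v << ' ';  // prints: 30 40 30 40
}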
diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
index f9c9fa2ea6f48..0642ad55b5526 100644
--- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
+++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
@@ -288,6 +288,63 @@ common::Status
 OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span<const char* const> keys,
                                                gsl::span<const char* const> values) {
+    } else if (key == "kvcache_reorder") {
+      // Value format: "<src_indices>;<dst_indices>", two comma-separated lists,
+      // e.g. "4,5,6;1,2,3".
+      auto sep_pos = value.find(';');
+      std::string src_string = value.substr(0, sep_pos);
+      std::string dst_string = (sep_pos == std::string::npos) ? std::string() : value.substr(sep_pos + 1);
+
+      auto parse_indices = [](const std::string& input,
+                              const std::string& index_type) -> std::pair<Status, std::vector<size_t>> {
+        std::vector<size_t> indices;
+        std::stringstream stream(input);
+        std::string token;
+
+        try {
+          while (std::getline(stream, token, ',')) {
+            // Trim whitespace
+            token.erase(0, token.find_first_not_of(" \t"));
+            token.erase(token.find_last_not_of(" \t") + 1);
+
+            if (!token.empty()) {
+              int64_t index = std::stoll(token);
+              if (index >= 0) {
+                indices.push_back(static_cast<size_t>(index));
+              } else {
+                return {Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                               "kvcache_reorder " + index_type + " cannot be negative: " + std::to_string(index)),
+                        std::vector<size_t>()};
+              }
+            }
+          }
+        } catch (const std::exception& e) {
+          return {Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                         "Failed to parse kvcache_reorder " + index_type + ": " + std::string(e.what())),
+                  std::vector<size_t>()};
+        }
+
+        return {Status::OK(), std::move(indices)};
+      };
+
+      auto [src_status, src_indices] = parse_indices(src_string, "src_index");
+      if (!src_status.IsOK()) {
+        return src_status;
+      }
+
+      auto [dst_status, dst_indices] = parse_indices(dst_string, "dst_index");
+      if (!dst_status.IsOK()) {
+        return dst_status;
+      }
+
+      // Trigger KVCache Reorder for the target backend with the parsed index vectors
+      for (auto& backend : backend_managers_) {
+        backend.ReorderKVCache(src_indices, dst_indices);
+      }
     } else {
       // Handle unknown options
       LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value;
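For context, a client drives this branch through the session-level dynamic-options API. A sketch, assuming ORT's C++ wrapper Ort::Session::SetEpDynamicOptions (available in recent releases) and the ';'-separated value format reconstructed above; the model path, device choice, and index values are placeholders:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  // OpenVINO EP targeting GPU; the reorder path is only wired up for GPU.
  so.AppendExecutionProvider_OpenVINO_V2({{"device_type", "GPU"}});
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);

  // Copy KV cache entries at sequence positions 4,5,6 onto positions 1,2,3;
  // the reorder is applied on the next Run() of this session.
  const char* keys[] = {"kvcache_reorder"};
  const char* values[] = {"4,5,6;1,2,3"};
  session.SetEpDynamicOptions(keys, values, 1);
  return 0;
}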
FillTensor("beam_idx", ov::element::i32, {1}, 0); + if (is_support_kvcache_reorder){ + ov::Shape dst_idx_shape = ovInfReq.get_tensor("dst_idx").get_shape(); + uint64_t kv_num_heads = dst_idx_shape[1]; + uint64_t kv_head_size = dst_idx_shape[3]; + if (kv_src_indices.size() > 0) { + ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {kv_src_indices.size()}); + for (auto i = 0; i < kv_src_indices.size(); ++i) { + src_idx_tensor.data()[i] = int32_t(kv_src_indices[i]); + } + ovInfReq.set_tensor("src_idx", src_idx_tensor); + + ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, kv_num_heads, kv_dst_indices.size(), kv_head_size}); + for (auto i = 0; i < kv_dst_indices.size(); ++i) { + for (auto j = 0; j < kv_num_heads; ++j) { + for (auto k = 0; k < kv_head_size; ++k) { + dst_idx_tensor.data()[(j * kv_dst_indices.size() + i) * kv_head_size + k] = int32_t(kv_dst_indices[i]); + } + } + } + ovInfReq.set_tensor("dst_idx", dst_idx_tensor); + } else { + FillTensor("src_idx", ov::element::i32, {0}, 0); + FillTensor("dst_idx", ov::element::i32, {1, kv_num_heads, 0, kv_head_size}, 0); + } + } + // If 'prefill use full chat history' mode is enabled, we need to cache input_ids and position_ids. if (prefill_use_full_chat_history) { auto input_ids_tensor = ovInfReq.get_tensor("input_ids"); @@ -459,6 +486,33 @@ void StatefulOVInferRequest::PreProcessInferRequest() { void StatefulOVInferRequest::Infer() { PreProcessInferRequest(); OVInferRequest::Infer(); + PostProcessInferRequest(); +} + +void StatefulOVInferRequest::PostProcessInferRequest() { + if(is_support_kvcache_reorder){ + kv_src_indices.clear(); + kv_dst_indices.clear(); + } +} + +void StatefulOVInferRequest::ReorderKVCache(const std::vector& src_indices, const std::vector& dst_indices) { + // Validate input parameters + if (src_indices.size() != dst_indices.size()) { + ORT_THROW(log_tag + "ReorderKVCache: src_indices and dst_indices must have the same size. 
" + "Got src_indices.size()=" + std::to_string(src_indices.size()) + + ", dst_indices.size()=" + std::to_string(dst_indices.size())); + } + + LOGS_DEFAULT(INFO) << log_tag << "ReorderKVCache: Reordering OpenVINO-internal KVCache state with " + << src_indices.size() << " index pairs"; + + kv_src_indices.clear(); + kv_dst_indices.clear(); + for (int i = 0; i < src_indices.size(); ++i) { + kv_src_indices.emplace_back(src_indices[i]); + kv_dst_indices.emplace_back(dst_indices[i]); + } } void StatefulOVInferRequest::RewindKVCache(size_t index) { diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 8765cd040d098..2d70cc505f871 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -132,6 +132,7 @@ class OVInferRequest { return ovInfReq; } virtual void RewindKVCache([[maybe_unused]] size_t index) {} + virtual void ReorderKVCache([[maybe_unused]] const std::vector& src_indices, [[maybe_unused]] const std::vector& dst_indices) {} }; class StatefulOVInferRequest : public OVInferRequest { @@ -140,6 +141,7 @@ class StatefulOVInferRequest : public OVInferRequest { void Infer() override; void RewindKVCache(size_t index) override; + void ReorderKVCache(const std::vector& src_indices, const std::vector& dst_indices) override; void FillTensor(const std::string& tensor_name, const ov::element::Type& type, const std::vector& shape, int32_t fill_value); void CacheTensor(const std::string& tensor_name, std::vector& cache); @@ -148,6 +150,7 @@ class StatefulOVInferRequest : public OVInferRequest { private: void PreProcessInferRequest(); + void PostProcessInferRequest(); std::string target_device; // If prefill_use_full_chat_history is true, cache the "input_ids" & "position_ids" tensors, @@ -155,6 +158,10 @@ class StatefulOVInferRequest : public OVInferRequest { bool prefill_use_full_chat_history = false; std::vector cached_input_ids; std::vector cached_position_ids; + + bool is_support_kvcache_reorder = false; + std::vector kv_src_indices; + std::vector kv_dst_indices; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index c4ec47534d009..cda2fed1fe3e2 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -75,11 +75,16 @@ std::string GetInputOutputName(std::shared_ptr ov_model, void FuseCacheReorder(std::shared_ptr ov_model, std::vector& not_kv_inputs, const std::vector& key_value_input_names, - int gather_dim) { + int gather_dim, + const std::string& device) { if (ModelHasInputOutputNames(ov_model, "beam_idx")) { throw std::runtime_error("Model already has fused cache"); } + // Flag to add Gather+ScatterElementsUpdate subgraph to reorder KV cache for LLM speculative decoding + // TO-DO: extend to NPU device when OpenVINO NPU has related optimization + bool is_support_kvcache_reorder = device.find("GPU") != std::string::npos; + // Define input name candidates in priority order const std::vector input_name_candidates = { "inputs_embeds", // Default fallback @@ -91,6 +96,7 @@ void FuseCacheReorder(std::shared_ptr ov_model, std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates); auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0]; + auto update_shape = ov_model->input(key_value_input_names[0]).get_partial_shape(); 
diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
index c4ec47534d009..cda2fed1fe3e2 100644
--- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
+++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
@@ -75,11 +75,16 @@ std::string GetInputOutputName(std::shared_ptr<ov::Model> ov_model,
 void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
                       std::vector<std::string>& not_kv_inputs,
                       const std::vector<std::string>& key_value_input_names,
-                      int gather_dim) {
+                      int gather_dim,
+                      const std::string& device) {
   if (ModelHasInputOutputNames(ov_model, "beam_idx")) {
     throw std::runtime_error("Model already has fused cache");
   }
 
+  // Flag to add a Gather + ScatterElementsUpdate subgraph that reorders the KV cache for LLM speculative decoding.
+  // TODO: extend to the NPU device once OpenVINO NPU has the related optimization.
+  bool is_support_kvcache_reorder = device.find("GPU") != std::string::npos;
+
   // Define input name candidates in priority order
   const std::vector<std::string> input_name_candidates = {
       "inputs_embeds",  // Default fallback
@@ -91,6 +96,7 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
   std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates);
 
   auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
+  auto update_shape = ov_model->input(key_value_input_names[0]).get_partial_shape();
 
   auto beam_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({std::move(input_batch)}));
   beam_idx->set_friendly_name("beam_idx");
@@ -98,6 +104,23 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
   ov_model->add_parameters({beam_idx});
   not_kv_inputs.push_back(beam_idx->get_friendly_name());
 
+  std::shared_ptr<ov::opset13::Parameter> src_idx;
+  std::shared_ptr<ov::opset13::Parameter> dst_idx;
+
+  if (is_support_kvcache_reorder) {
+    src_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({update_shape[2]}));
+    src_idx->set_friendly_name("src_idx");
+    src_idx->output(0).get_tensor().add_names({"src_idx"});
+    ov_model->add_parameters({src_idx});
+    not_kv_inputs.push_back(src_idx->get_friendly_name());
+
+    dst_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, update_shape);
+    dst_idx->set_friendly_name("dst_idx");
+    dst_idx->output(0).get_tensor().add_names({"dst_idx"});
+    ov_model->add_parameters({dst_idx});
+    not_kv_inputs.push_back(dst_idx->get_friendly_name());
+  }
+
   // Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
   for (const auto& input_name : key_value_input_names) {
     auto parameter_output_port = ov_model->input(input_name);
     auto consumers = parameter_output_port.get_target_inputs();
@@ -108,9 +131,25 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
     auto gather_op = std::make_shared<ov::opset13::Gather>(parameter_output_port,
                                                            beam_idx,
                                                            ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim}));
 
+    std::shared_ptr<ov::Node> output_node;
+    if (is_support_kvcache_reorder) {
+      auto updatekv_gather_op =
+          std::make_shared<ov::opset13::Gather>(gather_op,
+                                                src_idx,
+                                                ov::opset13::Constant::create(ov::element::i64, {}, {2}));
+
+      auto updatekv_op = std::make_shared<ov::opset12::ScatterElementsUpdate>(gather_op,
+                                                                              dst_idx,
+                                                                              updatekv_gather_op,
+                                                                              ov::opset13::Constant::create(ov::element::i64, {}, {2}));
+      output_node = updatekv_op;
+    } else {
+      output_node = gather_op;
+    }
+
     // Replace the source output for all consumers of the input tensor
     for (auto& consumer : consumers) {
-      consumer.replace_source_output(gather_op->output(0));
+      consumer.replace_source_output(output_node->output(0));
     }
   }
@@ -248,7 +287,7 @@ std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTensors(
 }
 
 // Updated PatchStatefulDecoder function
-void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
+void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const std::string& device) {
   // Use the dynamic pattern-based extraction logic
   auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
   auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);
@@ -258,10 +297,10 @@ void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
   }
 
   if (key_value_input_names.size() != key_value_output_names.size()) {
-    ORT_THROW("Found different sizes between key_value_input_names (",
-              key_value_input_names.size(),
-              ") and key_value_output_names (",
-              key_value_output_names.size(),
+    ORT_THROW("Found different sizes between key_value_input_names (",
+              key_value_input_names.size(),
+              ") and key_value_output_names (",
+              key_value_output_names.size(),
               "). They couldn't be paired.");
   }
 
@@ -270,7 +309,7 @@ void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
   // batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0
   auto batch_dim = 0;
 
-  FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim);
+  FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim, device);
 
   MakeStateful(model, key_value_input_names, key_value_output_names);
 }
diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h
index 0b89c4ed02e13..ce7db01063426 100644
--- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h
+++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h
@@ -13,6 +13,7 @@
 #include "openvino/pass/manager.hpp"
 #include "openvino/pass/make_stateful.hpp"
+#include "openvino/opsets/opset12.hpp"
 #include "openvino/opsets/opset13.hpp"
 
 namespace onnxruntime {
@@ -25,13 +26,14 @@ bool ModelHasInputOutputNames(std::shared_ptr<ov::Model> model, const std::string& name);
 void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
                       std::vector<std::string>& not_kv_inputs,
                       const std::vector<std::string>& key_value_input_names,
-                      int gather_dim);
+                      int gather_dim,
+                      const std::string& device = "");
 
 void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
                   const std::vector<std::string>& key_value_input_names,
                   const std::vector<std::string>& key_value_output_names);
 
-void PatchStatefulDecoder(std::shared_ptr<ov::Model> model);
+void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const std::string& device = "");
 
 bool HasOpWithType(const std::shared_ptr<ov::Model>& function, const std::string& type_name);
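Net effect on the stateful model: for every KV cache input, FuseCacheReorder now emits a Gather on the batch axis driven by beam_idx and, on GPU targets, a second Gather on the sequence axis driven by src_idx feeding a ScatterElementsUpdate at dst_idx on the same axis, so the runtime can rewrite cache rows between steps. A standalone sketch of just the reorder subgraph, using the same opsets as the patch (shapes and the BuildReorderDemo helper are illustrative, not part of the patch):

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/opsets/opset12.hpp"
#include "openvino/opsets/opset13.hpp"

std::shared_ptr<ov::Model> BuildReorderDemo() {
  // Dummy KV cache input: [batch=1, heads=8, seq=dynamic, head_size=64].
  auto kv = std::make_shared<ov::opset13::Parameter>(
      ov::element::f32, ov::PartialShape{1, 8, ov::Dimension::dynamic(), 64});
  // 1-D source positions; rank-4 destination indices shaped like the updates.
  auto src_idx = std::make_shared<ov::opset13::Parameter>(
      ov::element::i32, ov::PartialShape{ov::Dimension::dynamic()});
  auto dst_idx = std::make_shared<ov::opset13::Parameter>(
      ov::element::i32, ov::PartialShape{1, 8, ov::Dimension::dynamic(), 64});

  auto seq_axis = ov::opset13::Constant::create(ov::element::i64, {}, {2});
  // Pull the rows to move out of their source positions...
  auto gathered = std::make_shared<ov::opset13::Gather>(kv, src_idx, seq_axis);
  // ...and write them onto the destination positions along the sequence axis.
  auto reordered = std::make_shared<ov::opset12::ScatterElementsUpdate>(
      kv, dst_idx, gathered, seq_axis);

  return std::make_shared<ov::Model>(ov::OutputVector{reordered},
                                     ov::ParameterVector{kv, src_idx, dst_idx});
}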