Skip to content
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/openvino/backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -892,5 +892,11 @@ void BackendManager::RewindKVCache(size_t index) {
}
}

// Forwards a KV cache reorder request to the concrete backend, pairing each
// src index with the dst index at the same position. No-op when no backend
// has been created yet.
void BackendManager::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
  if (!concrete_backend_) {
    return;  // Nothing compiled yet; there is no KV cache state to reorder.
  }
  concrete_backend_->ReorderKVCache(src_indices, dst_indices);
}

} // namespace openvino_ep
} // namespace onnxruntime
1 change: 1 addition & 0 deletions onnxruntime/core/providers/openvino/backend_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class BackendManager {
void TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, bool include_embed_data);
ov::CompiledModel GetOVCompiledModel();
void RewindKVCache(size_t index);
void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices);

private:
std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/openvino/backends/basic_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,12 @@ void BasicBackend::RewindKVCache(size_t index) {
});
}

// Broadcasts the KV cache reorder (src -> dst index pairs) to every idle
// inference request in the pool so all pooled requests observe a consistent
// cache layout on their next run.
void BasicBackend::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
  auto apply_reorder = [&src_indices, &dst_indices](OVInferRequestPtr& request) {
    request->ReorderKVCache(src_indices, dst_indices);
  };
  infer_req_pool_->forEachIdleRequest(apply_reorder);
}

void BasicBackend::Infer(OrtKernelContext* ctx) const {
Ort::KernelContext context(ctx);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class BasicBackend : public IBackend {
return exe_network_.Get();
}
void RewindKVCache(size_t index) override;
void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) override;

private:
bool ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/openvino/ibackend.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class IBackend {
virtual ov::CompiledModel GetOVCompiledModel() = 0;
virtual ~IBackend() = default;
virtual void RewindKVCache(size_t index) {}
virtual void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {}
};
using ptr_stream_t = std::unique_ptr<ModelBlobWrapper>;
class BackendFactory {
Expand Down
62 changes: 62 additions & 0 deletions onnxruntime/core/providers/openvino/openvino_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,68 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span<const ch
LOGS_DEFAULT(WARNING) << "kvcache_rewind index is < 0:\t" << index;
}
}
} else if (key == "kvcache_reorder") {
// Convert kvcache_reorder value format "1,2,3;4,5,6" into two vectors

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Convert kvcache_reorder value format "1,2,3;4,5,6" into two vectors
// Convert kvcache_reorder value format "1,2,3;4,5,6" into two vectors

// src_indices = [1,2,3], dst_indices = [4,5,6]
size_t delimiter_pos = value.find(';');
if (delimiter_pos == std::string::npos) {
LOGS_DEFAULT(WARNING) << "kvcache_reorder value format is incorrect, expected format is 'x1,x2,x3;y1,y2,y3' where x and y are comma-separated int64_t lists";
return Status::OK();
}

std::string src_string = value.substr(0, delimiter_pos);
std::string dst_string = value.substr(delimiter_pos + 1);

std::vector<size_t> src_indices;
std::vector<size_t> dst_indices;

try {
// Parse source indices from comma-separated string
std::stringstream src_stream(src_string);
std::string src_token;
while (std::getline(src_stream, src_token, ',')) {
// Trim whitespace
src_token.erase(0, src_token.find_first_not_of(" \t"));
src_token.erase(src_token.find_last_not_of(" \t") + 1);
Copy link

Copilot AI Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whitespace trimming logic is duplicated for both src_token and dst_token. Consider extracting this into a helper function (e.g., TrimWhitespace(std::string&)) to reduce code duplication and improve maintainability.

Copilot uses AI. Check for mistakes.

if (!src_token.empty()) {
int64_t index = std::stoll(src_token);
if (index >= 0) {
src_indices.push_back(static_cast<size_t>(index));
} else {
LOGS_DEFAULT(WARNING) << "kvcache_reorder src_index is < 0: " << index;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should throw an exception here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
}
}

// Parse destination indices from comma-separated string
std::stringstream dst_stream(dst_string);
std::string dst_token;
while (std::getline(dst_stream, dst_token, ',')) {
// Trim whitespace
dst_token.erase(0, dst_token.find_first_not_of(" \t"));
dst_token.erase(dst_token.find_last_not_of(" \t") + 1);
Copy link

Copilot AI Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whitespace trimming logic is duplicated for both src_token and dst_token. Consider extracting this into a helper function (e.g., TrimWhitespace(std::string&)) to reduce code duplication and improve maintainability.

Copilot uses AI. Check for mistakes.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Almost identical branches for src and dst, consider refactoring

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


if (!dst_token.empty()) {
int64_t index = std::stoll(dst_token);
if (index >= 0) {
dst_indices.push_back(static_cast<size_t>(index));
} else {
LOGS_DEFAULT(WARNING) << "kvcache_reorder dst_index is < 0: " << index;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should throw an exception here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
}
}

} catch (const std::exception& e) {
LOGS_DEFAULT(WARNING) << "Conversion for kvcache_reorder string value to int64_t indices failed. "

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here we should actually return an error / throw an exception.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

<< "Exception: " << e.what();
return Status::OK();
}

// Trigger KVCache Reorder for target Backend with vector arguments
for (auto& backend : backend_managers_) {
backend.ReorderKVCache(src_indices, dst_indices);
}
} else {
// Handle unknown options
LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value;
Expand Down
45 changes: 45 additions & 0 deletions onnxruntime/core/providers/openvino/ov_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,26 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
// TODO(ankit): Address this issue and implement the fix at the appropriate layer.
FillTensor("beam_idx", ov::element::i32, {1}, 0);

if (kv_src_indices.size() > 0) {
ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {kv_src_indices.size()});
for (int i = 0; i < kv_src_indices.size(); ++i) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Signed-unsigned mismatch

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All loop indices have been switched to auto.

src_idx_tensor.data<int32_t>()[i] = int32_t(kv_src_indices[i]);
}
ovInfReq.set_tensor("src_idx", src_idx_tensor);
ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, 32, kv_dst_indices.size(), 96});
Copy link

Copilot AI Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hardcoded dimensions {1, 32, kv_dst_indices.size(), 96} in dst_idx_tensor creation appear as magic numbers. These values (32 and 96) should be extracted from the model's KV cache shape or defined as named constants to improve maintainability and prevent issues if model dimensions change.

Copilot uses AI. Check for mistakes.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hardcoded 32 and 96 values for this block -- could they be derived by something instead of fixing them as a magic number?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

for (int i = 0; i < kv_dst_indices.size(); ++i) {
Copy link

Copilot AI Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Loop variable i should be size_t instead of int to match the type of kv_src_indices.size() and avoid signed/unsigned comparison warnings.

Suggested change
for (int i = 0; i < kv_src_indices.size(); ++i) {
src_idx_tensor.data<int32_t>()[i] = int32_t(kv_src_indices[i]);
}
ovInfReq.set_tensor("src_idx", src_idx_tensor);
ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, 32, kv_dst_indices.size(), 96});
for (int i = 0; i < kv_dst_indices.size(); ++i) {
for (size_t i = 0; i < kv_src_indices.size(); ++i) {
src_idx_tensor.data<int32_t>()[i] = int32_t(kv_src_indices[i]);
}
ovInfReq.set_tensor("src_idx", src_idx_tensor);
ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, 32, kv_dst_indices.size(), 96});
for (size_t i = 0; i < kv_dst_indices.size(); ++i) {

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Loop variable i should be size_t instead of int to match the type of kv_dst_indices.size() and avoid signed/unsigned comparison warnings.

Suggested change
for (int i = 0; i < kv_dst_indices.size(); ++i) {
for (size_t i = 0; i < kv_dst_indices.size(); ++i) {

Copilot uses AI. Check for mistakes.
for (int j = 0; j < 32; ++j) {
for (int k = 0; k < 96; ++k) {
dst_idx_tensor.data<int32_t>()[(j * kv_dst_indices.size() + i) * 96 + k] = int32_t(kv_dst_indices[i]);
}
}
}
ovInfReq.set_tensor("dst_idx", dst_idx_tensor);
} else {
FillTensor("src_idx", ov::element::i32, {0}, 0);
FillTensor("dst_idx", ov::element::i32, {1, 32, 0, 96}, 0);
Copy link

Copilot AI Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hardcoded shape {1, 32, 0, 96} contains magic numbers (32 and 96) that should match the constants used in line 432. These should be extracted as named constants to ensure consistency and improve maintainability.

Copilot uses AI. Check for mistakes.
}

// If 'prefill use full chat history' mode is enabled, we need to cache input_ids and position_ids.
if (prefill_use_full_chat_history) {
auto input_ids_tensor = ovInfReq.get_tensor("input_ids");
Expand Down Expand Up @@ -459,6 +479,31 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
// Runs one inference: stages per-call state into the request (beam_idx and
// the pending src_idx/dst_idx reorder tensors) in PreProcessInferRequest,
// executes the base Infer, then clears the one-shot reorder indices in
// PostProcessInferRequest so they apply to exactly this call.
void StatefulOVInferRequest::Infer() {
  PreProcessInferRequest();
  OVInferRequest::Infer();
  PostProcessInferRequest();
}

// Resets one-shot per-inference state. The pending KV cache reorder index
// lists are consumed by a single Infer() call, so they are dropped here to
// prevent the same reorder from being re-applied on the next run.
void StatefulOVInferRequest::PostProcessInferRequest() {
  kv_dst_indices.clear();
  kv_src_indices.clear();
}

// Records a KV cache reorder to be applied on the next Infer() call.
// Each src_indices[i] names a source cache slot whose contents will be
// written into slot dst_indices[i]; the actual tensor plumbing happens in
// PreProcessInferRequest, and the lists are cleared afterwards.
//
// Throws via ORT_THROW when the two index lists differ in length, since the
// pairing would otherwise be ambiguous.
void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
  // Validate input parameters: indices are consumed pairwise.
  if (src_indices.size() != dst_indices.size()) {
    ORT_THROW(log_tag + "ReorderKVCache: src_indices and dst_indices must have the same size. "
              "Got src_indices.size()=" + std::to_string(src_indices.size()) +
              ", dst_indices.size()=" + std::to_string(dst_indices.size()));
  }

  LOGS_DEFAULT(INFO) << log_tag << "ReorderKVCache: Reordering OpenVINO-internal KVCache state with "
                     << src_indices.size() << " index pairs";

  // Replace (not append to) any previously staged reorder. Use size_t for the
  // loop index to avoid a signed/unsigned comparison, and reserve up front to
  // avoid incremental reallocation.
  kv_src_indices.clear();
  kv_dst_indices.clear();
  kv_src_indices.reserve(src_indices.size());
  kv_dst_indices.reserve(dst_indices.size());
  for (size_t i = 0; i < src_indices.size(); ++i) {
    // Members are int64_t while the inputs are size_t; widen explicitly.
    kv_src_indices.emplace_back(static_cast<int64_t>(src_indices[i]));
    kv_dst_indices.emplace_back(static_cast<int64_t>(dst_indices[i]));
  }
}

void StatefulOVInferRequest::RewindKVCache(size_t index) {
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/openvino/ov_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class OVInferRequest {
return ovInfReq;
}
virtual void RewindKVCache([[maybe_unused]] size_t index) {}
virtual void ReorderKVCache([[maybe_unused]] const std::vector<size_t>& src_indices, [[maybe_unused]] const std::vector<size_t>& dst_indices) {}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is [[maybe_unused]] really necessary here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is not necessary for functionality. More intention is to follow the original style.

};

class StatefulOVInferRequest : public OVInferRequest {
Expand All @@ -140,6 +141,7 @@ class StatefulOVInferRequest : public OVInferRequest {

void Infer() override;
void RewindKVCache(size_t index) override;
void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) override;
void FillTensor(const std::string& tensor_name, const ov::element::Type& type,
const std::vector<size_t>& shape, int32_t fill_value);
void CacheTensor(const std::string& tensor_name, std::vector<int64_t>& cache);
Expand All @@ -148,13 +150,16 @@ class StatefulOVInferRequest : public OVInferRequest {

private:
void PreProcessInferRequest();
void PostProcessInferRequest();
std::string target_device;

// If prefill_use_full_chat_history is true, cache the "input_ids" & "position_ids" tensors,
// and ensure that full chat history is passed for each prefill call.
bool prefill_use_full_chat_history = false;
std::vector<int64_t> cached_input_ids;
std::vector<int64_t> cached_position_ids;
std::vector<int64_t> kv_src_indices;
std::vector<int64_t> kv_dst_indices;
};

} // namespace openvino_ep
Expand Down
23 changes: 22 additions & 1 deletion onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,26 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates);

auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
auto update_shape = ov_model->input(key_value_input_names[0]).get_partial_shape();

auto beam_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({std::move(input_batch)}));
beam_idx->set_friendly_name("beam_idx");
beam_idx->output(0).get_tensor().add_names({"beam_idx"});
ov_model->add_parameters({beam_idx});
not_kv_inputs.push_back(beam_idx->get_friendly_name());

auto src_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({update_shape[2]}));

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So do I understand correctly that stateful flow will always add src_idx / dst_idx input tensors to the model?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, they will always be added to the model. For OV GPU, they will be optimized out if the inputs are all 0s. For NPU, a flag is added to bypass the logic of KV cache reorder.

src_idx->set_friendly_name("src_idx");
src_idx->output(0).get_tensor().add_names({"src_idx"});
ov_model->add_parameters({src_idx});
not_kv_inputs.push_back(src_idx->get_friendly_name());

auto dst_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, update_shape);
dst_idx->set_friendly_name("dst_idx");
dst_idx->output(0).get_tensor().add_names({"dst_idx"});
ov_model->add_parameters({dst_idx});
not_kv_inputs.push_back(dst_idx->get_friendly_name());

// Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
for (const auto& input_name : key_value_input_names) {
auto parameter_output_port = ov_model->input(input_name);
Expand All @@ -108,9 +121,17 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
beam_idx,
ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim}));

auto updatekv_gather_op =
std::make_shared<ov::opset13::Gather>(gather_op,
src_idx,
ov::opset13::Constant::create(ov::element::i64, {}, {2}));

auto updatekv_op = std::make_shared<ov::opset12::ScatterElementsUpdate>(gather_op,
dst_idx, updatekv_gather_op, ov::opset13::Constant::create(ov::element::i64, {}, {2}));

// Replace the source output for all consumers of the input tensor
for (auto& consumer : consumers) {
consumer.replace_source_output(gather_op->output(0));
consumer.replace_source_output(updatekv_op->output(0));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "openvino/pass/manager.hpp"
#include "openvino/pass/make_stateful.hpp"
#include "openvino/opsets/opset12.hpp"
#include "openvino/opsets/opset13.hpp"

namespace onnxruntime {
Expand Down