Skip to content

Commit 729d59d

Browse files
mdvoretc-intel authored and Kotomi-Du committed
Reorder KV cache using the new gather_by_axis API
Do a ScatterElementsUpdate-based reorder during execution. Get variable update lengths from incoming indices. Make changes to support new KVCache fusion. Add proper include.
1 parent fa8f464 commit 729d59d

File tree

4 files changed

+78
-1
lines changed

4 files changed

+78
-1
lines changed

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,26 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
463463
// TODO(ankit): Address this issue and implement the fix at the appropriate layer.
464464
FillTensor("beam_idx", ov::element::i32, {1}, 0);
465465

466+
if (src_idx_val.size() > 0) {
467+
ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {src_idx_val.size()});
468+
for (int i = 0; i < src_idx_val.size(); ++i) {
469+
src_idx_tensor.data<int32_t>()[i] = int32_t(src_idx_val[i]);
470+
}
471+
ovInfReq.set_tensor("src_idx", src_idx_tensor);
472+
ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, 32, dst_idx_val.size(), 96});
473+
for (int i = 0; i < dst_idx_val.size(); ++i) {
474+
for (int j = 0; j < 32; ++j) {
475+
for (int k = 0; k < 96; ++k) {
476+
dst_idx_tensor.data<int32_t>()[(j * dst_idx_val.size() + i) * 96 + k] = int32_t(dst_idx_val[i]);
477+
}
478+
}
479+
}
480+
ovInfReq.set_tensor("dst_idx", dst_idx_tensor);
481+
} else {
482+
FillTensor("src_idx", ov::element::i32, {0}, 0);
483+
FillTensor("dst_idx", ov::element::i32, {1, 32, 0, 96}, 0);
484+
}
485+
466486
// If 'prefill use full chat history' mode is enabled, we need to cache input_ids and position_ids.
467487
if (prefill_use_full_chat_history) {
468488
auto input_ids_tensor = ovInfReq.get_tensor("input_ids");
@@ -501,6 +521,38 @@ void StatefulOVInferRequest::Infer() {
501521
OVInferRequest::Infer();
502522
}
503523

524+
void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
525+
// Validate input parameters
526+
if (src_indices.size() != dst_indices.size()) {
527+
ORT_THROW(log_tag + "ReorderKVCache: src_indices and dst_indices must have the same size. "
528+
"Got src_indices.size()=" + std::to_string(src_indices.size()) +
529+
", dst_indices.size()=" + std::to_string(dst_indices.size()));
530+
}
531+
532+
LOGS_DEFAULT(INFO) << log_tag << "ReorderKVCache: Reordering OpenVINO-internal KVCache state with "
533+
<< src_indices.size() << " index pairs";
534+
535+
// set beam_idx and dst_idx based on provided values
536+
src_idx_val.clear();
537+
dst_idx_val.clear();
538+
for (int i = 0; i < src_indices.size(); ++i) {
539+
src_idx_val.emplace_back(src_indices[i]);
540+
dst_idx_val.emplace_back(dst_indices[i]);
541+
}
542+
/*
543+
// Retrieve KVCache states and reorder them based on the provided indices
544+
auto states = ovInfReq.query_state();
545+
546+
for (auto& state : states) {
547+
auto start_time = std::chrono::high_resolution_clock::now();
548+
state.gather_by_axis(src_indices, dst_indices);
549+
auto end_time = std::chrono::high_resolution_clock::now();
550+
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
551+
LOGS_DEFAULT(INFO) << log_tag << "gather_by_axis: " << duration << " microseconds";
552+
}
553+
*/
554+
}
555+
504556
void StatefulOVInferRequest::RewindKVCache(size_t index) {
505557
LOGS_DEFAULT(INFO) << log_tag << "RewindKVCache: Rewinding OpenVINO-internal KVCache state to index=" << index;
506558

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ class StatefulOVInferRequest : public OVInferRequest {
161161

162162
bool IsNPULogitsSliceRequired();
163163
bool _npu_logits_slice_required = false;
164+
std::vector<int64_t> src_idx_val;
165+
std::vector<int64_t> dst_idx_val;
164166
};
165167

166168
} // namespace openvino_ep

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,26 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
9191
std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates);
9292

9393
auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
94+
auto update_shape = ov_model->input(key_value_input_names[0]).get_partial_shape();
9495

9596
auto beam_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({std::move(input_batch)}));
9697
beam_idx->set_friendly_name("beam_idx");
9798
beam_idx->output(0).get_tensor().add_names({"beam_idx"});
9899
ov_model->add_parameters({beam_idx});
99100
not_kv_inputs.push_back(beam_idx->get_friendly_name());
100101

102+
auto src_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({update_shape[2]}));
103+
src_idx->set_friendly_name("src_idx");
104+
src_idx->output(0).get_tensor().add_names({"src_idx"});
105+
ov_model->add_parameters({src_idx});
106+
not_kv_inputs.push_back(src_idx->get_friendly_name());
107+
108+
auto dst_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, update_shape);
109+
dst_idx->set_friendly_name("dst_idx");
110+
dst_idx->output(0).get_tensor().add_names({"dst_idx"});
111+
ov_model->add_parameters({dst_idx});
112+
not_kv_inputs.push_back(dst_idx->get_friendly_name());
113+
101114
// Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
102115
for (const auto& input_name : key_value_input_names) {
103116
auto parameter_output_port = ov_model->input(input_name);
@@ -108,9 +121,17 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
108121
beam_idx,
109122
ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim}));
110123

124+
auto update_gather_op =
125+
std::make_shared<ov::opset13::Gather>(gather_op,
126+
src_idx,
127+
ov::opset13::Constant::create(ov::element::i64, {}, {2}));
128+
129+
auto update_op = std::make_shared<ov::opset12::ScatterElementsUpdate>(gather_op,
130+
dst_idx, update_gather_op, ov::opset13::Constant::create(ov::element::i64, {}, {2}));
131+
111132
// Replace the source output for all consumers of the input tensor
112133
for (auto& consumer : consumers) {
113-
consumer.replace_source_output(gather_op->output(0));
134+
consumer.replace_source_output(update_op->output(0));
114135
}
115136
}
116137

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
#include "openvino/pass/manager.hpp"
1515
#include "openvino/pass/make_stateful.hpp"
16+
#include "openvino/opsets/opset3.hpp"
17+
#include "openvino/opsets/opset12.hpp"
1618
#include "openvino/opsets/opset13.hpp"
1719

1820
namespace onnxruntime {

0 commit comments

Comments
 (0)