Skip to content

Commit 899feb5

Browse files
committed
clean up code
1 parent 7676b30 commit 899feb5

File tree

4 files changed

+21
-35
lines changed

4 files changed

+21
-35
lines changed

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -423,17 +423,17 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
423423
// TODO(ankit): Address this issue and implement the fix at the appropriate layer.
424424
FillTensor("beam_idx", ov::element::i32, {1}, 0);
425425

426-
if (src_idx_val.size() > 0) {
427-
ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {src_idx_val.size()});
428-
for (int i = 0; i < src_idx_val.size(); ++i) {
429-
src_idx_tensor.data<int32_t>()[i] = int32_t(src_idx_val[i]);
426+
if (kv_src_indices.size() > 0) {
427+
ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {kv_src_indices.size()});
428+
for (int i = 0; i < kv_src_indices.size(); ++i) {
429+
src_idx_tensor.data<int32_t>()[i] = int32_t(kv_src_indices[i]);
430430
}
431431
ovInfReq.set_tensor("src_idx", src_idx_tensor);
432-
ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, 32, dst_idx_val.size(), 96});
433-
for (int i = 0; i < dst_idx_val.size(); ++i) {
432+
ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, 32, kv_dst_indices.size(), 96});
433+
for (int i = 0; i < kv_dst_indices.size(); ++i) {
434434
for (int j = 0; j < 32; ++j) {
435435
for (int k = 0; k < 96; ++k) {
436-
dst_idx_tensor.data<int32_t>()[(j * dst_idx_val.size() + i) * 96 + k] = int32_t(dst_idx_val[i]);
436+
dst_idx_tensor.data<int32_t>()[(j * kv_dst_indices.size() + i) * 96 + k] = int32_t(kv_dst_indices[i]);
437437
}
438438
}
439439
}
@@ -492,25 +492,12 @@ void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indic
492492
LOGS_DEFAULT(INFO) << log_tag << "ReorderKVCache: Reordering OpenVINO-internal KVCache state with "
493493
<< src_indices.size() << " index pairs";
494494

495-
// set beam_idx and dst_idx based on provided values
496-
src_idx_val.clear();
497-
dst_idx_val.clear();
495+
kv_src_indices.clear();
496+
kv_dst_indices.clear();
498497
for (int i = 0; i < src_indices.size(); ++i) {
499-
src_idx_val.emplace_back(src_indices[i]);
500-
dst_idx_val.emplace_back(dst_indices[i]);
498+
kv_src_indices.emplace_back(src_indices[i]);
499+
kv_dst_indices.emplace_back(dst_indices[i]);
501500
}
502-
/*
503-
// Retrieve KVCache states and reorder them based on the provided indices
504-
auto states = ovInfReq.query_state();
505-
506-
for (auto& state : states) {
507-
auto start_time = std::chrono::high_resolution_clock::now();
508-
state.gather_by_axis(src_indices, dst_indices);
509-
auto end_time = std::chrono::high_resolution_clock::now();
510-
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
511-
LOGS_DEFAULT(INFO) << log_tag << "gather_by_axis: " << duration << " microseconds";
512-
}
513-
*/
514501
}
515502

516503
void StatefulOVInferRequest::RewindKVCache(size_t index) {

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ class StatefulOVInferRequest : public OVInferRequest {
157157
bool prefill_use_full_chat_history = false;
158158
std::vector<int64_t> cached_input_ids;
159159
std::vector<int64_t> cached_position_ids;
160-
std::vector<int64_t> src_idx_val;
161-
std::vector<int64_t> dst_idx_val;
160+
std::vector<int64_t> kv_src_indices;
161+
std::vector<int64_t> kv_dst_indices;
162162
};
163163

164164
} // namespace openvino_ep

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,17 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
121121
beam_idx,
122122
ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim}));
123123

124-
auto update_gather_op =
124+
auto updatekv_gather_op =
125125
std::make_shared<ov::opset13::Gather>(gather_op,
126126
src_idx,
127127
ov::opset13::Constant::create(ov::element::i64, {}, {2}));
128128

129-
auto update_op = std::make_shared<ov::opset12::ScatterElementsUpdate>(gather_op,
130-
dst_idx, update_gather_op, ov::opset13::Constant::create(ov::element::i64, {}, {2}));
129+
auto updatekv_op = std::make_shared<ov::opset12::ScatterElementsUpdate>(gather_op,
130+
dst_idx, updatekv_gather_op, ov::opset13::Constant::create(ov::element::i64, {}, {2}));
131131

132132
// Replace the source output for all consumers of the input tensor
133133
for (auto& consumer : consumers) {
134-
consumer.replace_source_output(update_op->output(0));
134+
consumer.replace_source_output(updatekv_op->output(0));
135135
}
136136
}
137137

@@ -279,10 +279,10 @@ void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
279279
}
280280

281281
if (key_value_input_names.size() != key_value_output_names.size()) {
282-
ORT_THROW("Found different sizes between key_value_input_names (",
283-
key_value_input_names.size(),
284-
") and key_value_output_names (",
285-
key_value_output_names.size(),
282+
ORT_THROW("Found different sizes between key_value_input_names (",
283+
key_value_input_names.size(),
284+
") and key_value_output_names (",
285+
key_value_output_names.size(),
286286
"). They couldn't be paired.");
287287
}
288288

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
#include "openvino/pass/manager.hpp"
1515
#include "openvino/pass/make_stateful.hpp"
16-
#include "openvino/opsets/opset3.hpp"
1716
#include "openvino/opsets/opset12.hpp"
1817
#include "openvino/opsets/opset13.hpp"
1918

0 commit comments

Comments
 (0)