Skip to content

Commit 2a0d722

Browse files
committed
add reorder KV cache API
1 parent 729d59d commit 2a0d722

File tree

7 files changed

+79
-0
lines changed

7 files changed

+79
-0
lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,5 +844,11 @@ void BackendManager::RewindKVCache(size_t index) {
844844
}
845845
}
846846

847+
void BackendManager::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
848+
if (concrete_backend_) {
849+
concrete_backend_->ReorderKVCache(src_indices, dst_indices);
850+
}
851+
}
852+
847853
} // namespace openvino_ep
848854
} // namespace onnxruntime

onnxruntime/core/providers/openvino/backend_manager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class BackendManager {
3131
void TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, bool include_embed_data);
3232
ov::CompiledModel GetOVCompiledModel();
3333
void RewindKVCache(size_t index);
34+
void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices);
3435

3536
private:
3637
std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,12 @@ void BasicBackend::RewindKVCache(size_t index) {
334334
});
335335
}
336336

337+
void BasicBackend::ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {
338+
infer_req_pool_->forEachIdleRequest([&](OVInferRequestPtr& infer_request) {
339+
infer_request->ReorderKVCache(src_indices, dst_indices);
340+
});
341+
}
342+
337343
void BasicBackend::Infer(OrtKernelContext* ctx) const {
338344
Ort::KernelContext context(ctx);
339345

onnxruntime/core/providers/openvino/backends/basic_backend.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ class BasicBackend : public IBackend {
138138
return exe_network_.Get();
139139
}
140140
void RewindKVCache(size_t index) override;
141+
void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) override;
141142

142143
private:
143144
bool ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);

onnxruntime/core/providers/openvino/ibackend.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class IBackend {
1818
virtual ov::CompiledModel GetOVCompiledModel() = 0;
1919
virtual ~IBackend() = default;
2020
virtual void RewindKVCache(size_t index) {}
21+
virtual void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) {}
2122
};
2223
using ptr_stream_t = std::unique_ptr<ModelBlobWrapper>;
2324
class BackendFactory {

onnxruntime/core/providers/openvino/openvino_execution_provider.cc

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,68 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span<const ch
286286
LOGS_DEFAULT(WARNING) << "kvcache_rewind index is < 0:\t" << index;
287287
}
288288
}
289+
} else if (key == "kvcache_reorder") {
290+
// Convert kvcache_reorder value format "1,2,3;4,5,6" into two vectors
291+
// src_indices = [1,2,3], dst_indices = [4,5,6]
292+
size_t delimiter_pos = value.find(';');
293+
if (delimiter_pos == std::string::npos) {
294+
LOGS_DEFAULT(WARNING) << "kvcache_reorder value format is incorrect, expected format is 'x1,x2,x3;y1,y2,y3' where x and y are comma-separated int64_t lists";
295+
return Status::OK();
296+
}
297+
298+
std::string src_string = value.substr(0, delimiter_pos);
299+
std::string dst_string = value.substr(delimiter_pos + 1);
300+
301+
std::vector<size_t> src_indices;
302+
std::vector<size_t> dst_indices;
303+
304+
try {
305+
// Parse source indices from comma-separated string
306+
std::stringstream src_stream(src_string);
307+
std::string src_token;
308+
while (std::getline(src_stream, src_token, ',')) {
309+
// Trim whitespace
310+
src_token.erase(0, src_token.find_first_not_of(" \t"));
311+
src_token.erase(src_token.find_last_not_of(" \t") + 1);
312+
313+
if (!src_token.empty()) {
314+
int64_t index = std::stoll(src_token);
315+
if (index >= 0) {
316+
src_indices.push_back(static_cast<size_t>(index));
317+
} else {
318+
LOGS_DEFAULT(WARNING) << "kvcache_reorder src_index is < 0: " << index;
319+
}
320+
}
321+
}
322+
323+
// Parse destination indices from comma-separated string
324+
std::stringstream dst_stream(dst_string);
325+
std::string dst_token;
326+
while (std::getline(dst_stream, dst_token, ',')) {
327+
// Trim whitespace
328+
dst_token.erase(0, dst_token.find_first_not_of(" \t"));
329+
dst_token.erase(dst_token.find_last_not_of(" \t") + 1);
330+
331+
if (!dst_token.empty()) {
332+
int64_t index = std::stoll(dst_token);
333+
if (index >= 0) {
334+
dst_indices.push_back(static_cast<size_t>(index));
335+
} else {
336+
LOGS_DEFAULT(WARNING) << "kvcache_reorder dst_index is < 0: " << index;
337+
}
338+
}
339+
}
340+
341+
} catch (const std::exception& e) {
342+
LOGS_DEFAULT(WARNING) << "Conversion for kvcache_reorder string value to int64_t indices failed. "
343+
<< "Exception: " << e.what();
344+
return Status::OK();
345+
}
346+
347+
// Trigger KVCache Reorder for target Backend with vector arguments
348+
for (auto& backend : backend_managers_) {
349+
backend.ReorderKVCache(src_indices, dst_indices);
350+
}
289351
} else {
290352
// Handle unknown options
291353
LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value;

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class OVInferRequest {
134134
return ovInfReq;
135135
}
136136
virtual void RewindKVCache([[maybe_unused]] size_t index) {}
137+
virtual void ReorderKVCache([[maybe_unused]] const std::vector<size_t>& src_indices, [[maybe_unused]] const std::vector<size_t>& dst_indices) {}
137138
};
138139

139140
class StatefulOVInferRequest : public OVInferRequest {
@@ -142,6 +143,7 @@ class StatefulOVInferRequest : public OVInferRequest {
142143

143144
void Infer() override;
144145
void RewindKVCache(size_t index) override;
146+
void ReorderKVCache(const std::vector<size_t>& src_indices, const std::vector<size_t>& dst_indices) override;
145147
void FillTensor(const std::string& tensor_name, const ov::element::Type& type,
146148
const std::vector<size_t>& shape, int32_t fill_value);
147149
void CacheTensor(const std::string& tensor_name, std::vector<int64_t>& cache);

0 commit comments

Comments
 (0)