Skip to content

Commit d0bac3e

Browse files
CVS-175736 - [OVEP] Optimize Stateful Path: use output-to-input strategy to get the pairs of KV name (#845)
* Use output-to-input strategy to get the pairs of KV names
* Minor change
* Remove regex for extracting pattern
* Address review
* Design strict KV patterns: only two, separately for key and value; patterns have to be followed by _%d
* Simplify code structure
* Address review
* Remove useless comment
* Add brief example to explain the functionalities

---------

Co-authored-by: MayureshV1 <[email protected]>
1 parent 51493cd commit d0bac3e

File tree

1 file changed

+116
-29
lines changed

1 file changed

+116
-29
lines changed

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

Lines changed: 116 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// Licensed under the MIT License
33

44
#include "core/providers/openvino/ov_stateful_patch_utils.h"
5+
#include "core/providers/shared_library/provider_api.h"
6+
#include "core/common/common.h"
57

68
namespace onnxruntime {
79
namespace openvino_ep {
@@ -132,50 +134,135 @@ void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
132134
manager.run_passes(ov_model);
133135
}
134136

135-
// Converted to C++ from below reference URL:
136-
// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281
137-
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
137+
// Helper function to extract KV patterns from output names dynamically
138+
//
139+
// Example: Given output names ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1", "logits"]
140+
// key_value_output_names = ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1"]
141+
// unique_patterns = {"key_cross", "value_cross"}
142+
std::pair<std::vector<std::string>, std::unordered_set<std::string>> ExtractKVPatternsFromOutputs(const std::shared_ptr<ov::Model>& model) {
143+
std::vector<std::string> key_value_output_names;
144+
std::unordered_set<std::string> unique_patterns;
145+
146+
const std::string prefix = "present_";
147+
const size_t prefix_len = prefix.length();
148+
for (const ov::Output<ov::Node>& output : model->outputs()) {
149+
const auto& names = output.get_names();
150+
for (const auto& name : names) {
151+
if (name.find(prefix) == 0 && name.length() > prefix_len) {
152+
size_t last_underscore_pos = name.rfind('_');
153+
// Extract pattern between "present_" and the last underscore
154+
if (last_underscore_pos != std::string::npos && last_underscore_pos > prefix_len) {
155+
std::string pattern = name.substr(prefix_len, last_underscore_pos - prefix_len);
156+
if (!pattern.empty()) {
157+
unique_patterns.insert(pattern);
158+
key_value_output_names.push_back(name);
159+
}
160+
}
161+
break;
162+
}
163+
}
164+
}
165+
166+
if (unique_patterns.size() > 2) {
167+
ORT_THROW("More than two unique KV patterns found in output names.");
168+
}
169+
return std::make_pair(key_value_output_names, unique_patterns);
170+
}
171+
172+
// Main function to extract KV tensors using dynamic pattern matching
173+
//
174+
// Example: Given input names ["input_ids", "attention_mask", "past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"]
175+
// kv_patterns = {"key_cross", "value_cross"}
176+
//
177+
// key_value_input_names = ["past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"]
178+
// not_kv_inputs = ["input_ids", "attention_mask"]
179+
std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTensors(
180+
const std::shared_ptr<ov::Model>& model, const std::unordered_set<std::string>& kv_patterns) {
181+
138182
std::vector<std::string> key_value_input_names;
139183
std::vector<std::string> not_kv_inputs;
184+
185+
if (kv_patterns.empty()) {
186+
// Fallback: use original substring matching
187+
for (const ov::Output<ov::Node>& input : model->inputs()) {
188+
const auto& names = input.get_names();
189+
const std::string input_name = input.get_any_name();
190+
191+
bool is_kv_input = false;
192+
for (const auto& name : names) {
193+
if (name.find("key_values") != std::string::npos ||
194+
name.find("keys") != std::string::npos ||
195+
name.find("values") != std::string::npos) {
196+
key_value_input_names.push_back(name);
197+
is_kv_input = true;
198+
break;
199+
}
200+
}
201+
202+
if (!is_kv_input) {
203+
not_kv_inputs.push_back(input_name);
204+
}
205+
}
206+
207+
return std::make_pair(key_value_input_names, not_kv_inputs);
208+
}
209+
210+
// Inline helper function to check if name is matched with provided pattern followed by "_%d"
211+
auto matches_pattern = [](const std::string& name, const std::string& pattern) -> bool {
212+
size_t pos = name.find(pattern);
213+
if (pos == std::string::npos) {
214+
return false;
215+
}
216+
217+
size_t after_pattern = pos + pattern.length();
218+
if (after_pattern >= name.length() || name[after_pattern] != '_') {
219+
return false;
220+
}
221+
222+
std::string suffix = name.substr(after_pattern + 1);
223+
return !suffix.empty() && std::all_of(suffix.begin(), suffix.end(), ::isdigit);
224+
};
225+
140226
for (const ov::Output<ov::Node>& input : model->inputs()) {
141227
auto& names = input.get_names();
142-
143228
bool found = false;
144-
for (auto& name : names) {
145-
if (name.find("key_values") != std::string::npos) {
146-
key_value_input_names.push_back(name);
147-
found = true;
148-
break;
149-
} else if (name.find("keys") != std::string::npos) {
150-
key_value_input_names.push_back(name);
151-
found = true;
152-
break;
153-
} else if (name.find("values") != std::string::npos) {
154-
key_value_input_names.push_back(name);
155-
found = true;
156-
break;
229+
230+
// Check if any input name contains either key or value pattern
231+
for (const auto& name : names) {
232+
for (const auto& pattern : kv_patterns) {
233+
if (matches_pattern(name, pattern)) {
234+
key_value_input_names.push_back(name);
235+
found = true;
236+
break;
237+
}
157238
}
239+
if (found) break;
158240
}
159241

160242
if (!found) {
161243
not_kv_inputs.push_back(input.get_any_name());
162244
}
163245
}
164246

165-
std::vector<std::string> key_value_output_names;
166-
for (const ov::Output<ov::Node>& output : model->outputs()) {
167-
auto& names = output.get_names();
168-
for (auto& name : names) {
169-
if (name.find("present") != std::string::npos) {
170-
key_value_output_names.push_back(name);
171-
break;
172-
}
173-
}
174-
}
247+
return std::make_pair(key_value_input_names, not_kv_inputs);
248+
}
249+
250+
// Updated PatchStatefulDecoder function
251+
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
252+
// Use the dynamic pattern-based extraction logic
253+
auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
254+
auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);
175255

176256
if (key_value_input_names.empty() || key_value_output_names.empty()) {
177-
std::cout << "no key_value_input_names or key_value_output_names found" << std::endl;
178-
return;
257+
ORT_THROW("No key_value_input_names or key_value_output_names found");
258+
}
259+
260+
if (key_value_input_names.size() != key_value_output_names.size()) {
261+
ORT_THROW("Found different sizes between key_value_input_names (",
262+
key_value_input_names.size(),
263+
") and key_value_output_names (",
264+
key_value_output_names.size(),
265+
"). They couldn't be paired.");
179266
}
180267

181268
// By default, batch is the 0 - th but chatglm uses 1 - st dimension as batch

0 commit comments

Comments
 (0)