|
2 | 2 | // Licensed under the MIT License |
3 | 3 |
|
4 | 4 | #include "core/providers/openvino/ov_stateful_patch_utils.h" |
| 5 | +#include "core/providers/shared_library/provider_api.h" |
| 6 | +#include "core/common/common.h" |
5 | 7 |
|
6 | 8 | namespace onnxruntime { |
7 | 9 | namespace openvino_ep { |
@@ -132,50 +134,135 @@ void MakeStateful(std::shared_ptr<ov::Model>& ov_model, |
132 | 134 | manager.run_passes(ov_model); |
133 | 135 | } |
134 | 136 |
|
135 | | -// Converted to C++ from below reference URL: |
136 | | -// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281 |
137 | | -void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) { |
| 137 | +// Helper function to extract KV patterns from output names dynamically |
| 138 | +// |
| 139 | +// Example: Given output names ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1", "logits"] |
| 140 | +// key_value_output_names = ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1"] |
| 141 | +// unique_patterns = {"key_cross", "value_cross"} |
| 142 | +std::pair<std::vector<std::string>, std::unordered_set<std::string>> ExtractKVPatternsFromOutputs(const std::shared_ptr<ov::Model>& model) { |
| 143 | + std::vector<std::string> key_value_output_names; |
| 144 | + std::unordered_set<std::string> unique_patterns; |
| 145 | + |
| 146 | + const std::string prefix = "present_"; |
| 147 | + const size_t prefix_len = prefix.length(); |
| 148 | + for (const ov::Output<ov::Node>& output : model->outputs()) { |
| 149 | + const auto& names = output.get_names(); |
| 150 | + for (const auto& name : names) { |
| 151 | + if (name.find(prefix) == 0 && name.length() > prefix_len) { |
| 152 | + size_t last_underscore_pos = name.rfind('_'); |
| 153 | + // Extract pattern between "present_" and the last underscore |
| 154 | + if (last_underscore_pos != std::string::npos && last_underscore_pos > prefix_len) { |
| 155 | + std::string pattern = name.substr(prefix_len, last_underscore_pos - prefix_len); |
| 156 | + if (!pattern.empty()) { |
| 157 | + unique_patterns.insert(pattern); |
| 158 | + key_value_output_names.push_back(name); |
| 159 | + } |
| 160 | + } |
| 161 | + break; |
| 162 | + } |
| 163 | + } |
| 164 | + } |
| 165 | + |
| 166 | + if (unique_patterns.size() > 2) { |
| 167 | + ORT_THROW("More than two unique KV patterns found in output names."); |
| 168 | + } |
| 169 | + return std::make_pair(key_value_output_names, unique_patterns); |
| 170 | +} |
| 171 | + |
| 172 | +// Main function to extract KV tensors using dynamic pattern matching |
| 173 | +// |
| 174 | +// Example: Given input names ["input_ids", "attention_mask", "past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"] |
| 175 | +// kv_patterns = {"key_cross", "value_cross"} |
| 176 | +// |
| 177 | +// key_value_input_names = ["past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"] |
| 178 | +// not_kv_inputs = ["input_ids", "attention_mask"] |
| 179 | +std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTensors( |
| 180 | + const std::shared_ptr<ov::Model>& model, const std::unordered_set<std::string>& kv_patterns) { |
| 181 | + |
138 | 182 | std::vector<std::string> key_value_input_names; |
139 | 183 | std::vector<std::string> not_kv_inputs; |
| 184 | + |
| 185 | + if (kv_patterns.empty()) { |
| 186 | + // Fallback: use original substring matching |
| 187 | + for (const ov::Output<ov::Node>& input : model->inputs()) { |
| 188 | + const auto& names = input.get_names(); |
| 189 | + const std::string input_name = input.get_any_name(); |
| 190 | + |
| 191 | + bool is_kv_input = false; |
| 192 | + for (const auto& name : names) { |
| 193 | + if (name.find("key_values") != std::string::npos || |
| 194 | + name.find("keys") != std::string::npos || |
| 195 | + name.find("values") != std::string::npos) { |
| 196 | + key_value_input_names.push_back(name); |
| 197 | + is_kv_input = true; |
| 198 | + break; |
| 199 | + } |
| 200 | + } |
| 201 | + |
| 202 | + if (!is_kv_input) { |
| 203 | + not_kv_inputs.push_back(input_name); |
| 204 | + } |
| 205 | + } |
| 206 | + |
| 207 | + return std::make_pair(key_value_input_names, not_kv_inputs); |
| 208 | + } |
| 209 | + |
| 210 | + // Inline helper function to check if name is matched with provided pattern followed by "_%d" |
| 211 | + auto matches_pattern = [](const std::string& name, const std::string& pattern) -> bool { |
| 212 | + size_t pos = name.find(pattern); |
| 213 | + if (pos == std::string::npos) { |
| 214 | + return false; |
| 215 | + } |
| 216 | + |
| 217 | + size_t after_pattern = pos + pattern.length(); |
| 218 | + if (after_pattern >= name.length() || name[after_pattern] != '_') { |
| 219 | + return false; |
| 220 | + } |
| 221 | + |
| 222 | + std::string suffix = name.substr(after_pattern + 1); |
| 223 | + return !suffix.empty() && std::all_of(suffix.begin(), suffix.end(), ::isdigit); |
| 224 | + }; |
| 225 | + |
140 | 226 | for (const ov::Output<ov::Node>& input : model->inputs()) { |
141 | 227 | auto& names = input.get_names(); |
142 | | - |
143 | 228 | bool found = false; |
144 | | - for (auto& name : names) { |
145 | | - if (name.find("key_values") != std::string::npos) { |
146 | | - key_value_input_names.push_back(name); |
147 | | - found = true; |
148 | | - break; |
149 | | - } else if (name.find("keys") != std::string::npos) { |
150 | | - key_value_input_names.push_back(name); |
151 | | - found = true; |
152 | | - break; |
153 | | - } else if (name.find("values") != std::string::npos) { |
154 | | - key_value_input_names.push_back(name); |
155 | | - found = true; |
156 | | - break; |
| 229 | + |
| 230 | + // Check if any input name contains either key or value pattern |
| 231 | + for (const auto& name : names) { |
| 232 | + for (const auto& pattern : kv_patterns) { |
| 233 | + if (matches_pattern(name, pattern)) { |
| 234 | + key_value_input_names.push_back(name); |
| 235 | + found = true; |
| 236 | + break; |
| 237 | + } |
157 | 238 | } |
| 239 | + if (found) break; |
158 | 240 | } |
159 | 241 |
|
160 | 242 | if (!found) { |
161 | 243 | not_kv_inputs.push_back(input.get_any_name()); |
162 | 244 | } |
163 | 245 | } |
164 | 246 |
|
165 | | - std::vector<std::string> key_value_output_names; |
166 | | - for (const ov::Output<ov::Node>& output : model->outputs()) { |
167 | | - auto& names = output.get_names(); |
168 | | - for (auto& name : names) { |
169 | | - if (name.find("present") != std::string::npos) { |
170 | | - key_value_output_names.push_back(name); |
171 | | - break; |
172 | | - } |
173 | | - } |
174 | | - } |
| 247 | + return std::make_pair(key_value_input_names, not_kv_inputs); |
| 248 | +} |
| 249 | + |
| 250 | +// Updated PatchStatefulDecoder function |
| 251 | +void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) { |
| 252 | + // Use the dynamic pattern-based extraction logic |
| 253 | + auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model); |
| 254 | + auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns); |
175 | 255 |
|
176 | 256 | if (key_value_input_names.empty() || key_value_output_names.empty()) { |
177 | | - std::cout << "no key_value_input_names or key_value_output_names found" << std::endl; |
178 | | - return; |
| 257 | + ORT_THROW("No key_value_input_names or key_value_output_names found"); |
| 258 | + } |
| 259 | + |
| 260 | + if (key_value_input_names.size() != key_value_output_names.size()) { |
| 261 | + ORT_THROW("Found different sizes between key_value_input_names (", |
| 262 | + key_value_input_names.size(), |
| 263 | + ") and key_value_output_names (", |
| 264 | + key_value_output_names.size(), |
| 265 | + "). They couldn't be paired."); |
179 | 266 | } |
180 | 267 |
|
181 | 268 | // By default, batch is the 0 - th but chatglm uses 1 - st dimension as batch |
|
0 commit comments