|
| 1 | +// Copyright (C) Intel Corporation |
| 2 | +// Licensed under the MIT License |
| 3 | + |
| 4 | +#include <memory> |
| 5 | +#include <map> |
| 6 | +#include <string> |
| 7 | +#include <algorithm> |
| 8 | +#include <vector> |
| 9 | +#include <ranges> |
| 10 | +#include <format> |
| 11 | + |
| 12 | +#define ORT_API_MANUAL_INIT |
| 13 | +#include "onnxruntime_cxx_api.h" |
| 14 | +#undef ORT_API_MANUAL_INIT |
| 15 | + |
| 16 | +#include "onnxruntime_c_api.h" |
| 17 | +#include "ov_factory.h" |
| 18 | +#include "openvino/openvino.hpp" |
| 19 | +#include "ov_interface.h" |
| 20 | + |
| 21 | +using namespace onnxruntime::openvino_ep; |
| 22 | +using ov_core_singleton = onnxruntime::openvino_ep::WeakSingleton<ov::Core>; |
| 23 | + |
| 24 | +static void InitCxxApi(const OrtApiBase& ort_api_base) { |
| 25 | + static std::once_flag init_api; |
| 26 | + std::call_once(init_api, [&]() { |
| 27 | + const OrtApi* ort_api = ort_api_base.GetApi(ORT_API_VERSION); |
| 28 | + Ort::InitApi(ort_api); |
| 29 | + }); |
| 30 | +} |
| 31 | + |
| 32 | +OpenVINOEpPluginFactory::OpenVINOEpPluginFactory(ApiPtrs apis, const std::string& ov_metadevice_name, std::shared_ptr<ov::Core> core) |
| 33 | + : ApiPtrs{apis}, |
| 34 | + ep_name_(ov_metadevice_name.empty() ? provider_name_ : std::string(provider_name_) + "." + ov_metadevice_name), |
| 35 | + device_type_(ov_metadevice_name), |
| 36 | + ov_core_(std::move(core)) { |
| 37 | + OrtEpFactory::GetName = GetNameImpl; |
| 38 | + OrtEpFactory::GetVendor = GetVendorImpl; |
| 39 | + OrtEpFactory::GetVendorId = GetVendorIdImpl; |
| 40 | + OrtEpFactory::GetSupportedDevices = GetSupportedDevicesImpl; |
| 41 | + OrtEpFactory::GetVersion = GetVersionImpl; |
| 42 | + OrtEpFactory::CreateDataTransfer = CreateDataTransferImpl; |
| 43 | + |
| 44 | + ort_version_supported = ORT_API_VERSION; // Set to the ORT version we were compiled with. |
| 45 | +} |
| 46 | + |
| 47 | +const std::vector<std::string>& OpenVINOEpPluginFactory::GetOvDevices() { |
| 48 | + static std::vector<std::string> devices = ov_core_singleton::Get()->get_available_devices(); |
| 49 | + return devices; |
| 50 | +} |
| 51 | + |
| 52 | +const std::vector<std::string>& OpenVINOEpPluginFactory::GetOvMetaDevices() { |
| 53 | + static std::vector<std::string> virtual_devices = [ov_core = ov_core_singleton::Get()] { |
| 54 | + std::vector<std::string> supported_virtual_devices{}; |
| 55 | + for (const auto& meta_device : known_meta_devices_) { |
| 56 | + try { |
| 57 | + ov_core->get_property(meta_device, ov::supported_properties); |
| 58 | + supported_virtual_devices.push_back(meta_device); |
| 59 | + } catch (ov::Exception&) { |
| 60 | + // meta device isn't supported. |
| 61 | + } |
| 62 | + } |
| 63 | + return supported_virtual_devices; |
| 64 | + }(); |
| 65 | + |
| 66 | + return virtual_devices; |
| 67 | +} |
| 68 | + |
| 69 | +OrtStatus* OpenVINOEpPluginFactory::GetSupportedDevices(const OrtHardwareDevice* const* devices, |
| 70 | + size_t num_devices, |
| 71 | + OrtEpDevice** ep_devices, |
| 72 | + size_t max_ep_devices, |
| 73 | + size_t* p_num_ep_devices) { |
| 74 | + size_t& num_ep_devices = *p_num_ep_devices; |
| 75 | + |
| 76 | + // Create a map for device type mapping |
| 77 | + static const std::map<OrtHardwareDeviceType, std::string> ort_to_ov_device_name = { |
| 78 | + {OrtHardwareDeviceType::OrtHardwareDeviceType_CPU, "CPU"}, |
| 79 | + {OrtHardwareDeviceType::OrtHardwareDeviceType_GPU, "GPU"}, |
| 80 | + {OrtHardwareDeviceType::OrtHardwareDeviceType_NPU, "NPU"}, |
| 81 | + }; |
| 82 | + |
| 83 | + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { |
| 84 | + const OrtHardwareDevice& device = *devices[i]; |
| 85 | + if (ort_api.HardwareDevice_VendorId(&device) != vendor_id_) { |
| 86 | + // Not an Intel Device. |
| 87 | + continue; |
| 88 | + } |
| 89 | + |
| 90 | + auto device_type = ort_api.HardwareDevice_Type(&device); |
| 91 | + auto device_it = ort_to_ov_device_name.find(device_type); |
| 92 | + if (device_it == ort_to_ov_device_name.end()) { |
| 93 | + // We don't know about this device type |
| 94 | + continue; |
| 95 | + } |
| 96 | + |
| 97 | + const auto& ov_device_type = device_it->second; |
| 98 | + std::string ov_device_name; |
| 99 | + auto get_pci_device_id = [&](const std::string& ov_device) { |
| 100 | + try { |
| 101 | + ov::device::PCIInfo pci_info = ov_core_->get_property(ov_device, ov::device::pci_info); |
| 102 | + return pci_info.device; |
| 103 | + } catch (ov::Exception&) { |
| 104 | + return 0u; // If we can't get the PCI info, we won't have a device ID. |
| 105 | + } |
| 106 | + }; |
| 107 | + |
| 108 | + auto filtered_devices = GetOvDevices(ov_device_type); |
| 109 | + auto matched_device = filtered_devices.begin(); |
| 110 | + if (filtered_devices.size() > 1 && device_type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { |
| 111 | + // If there are multiple devices of the same type, we need to match by device ID. |
| 112 | + matched_device = std::find_if(filtered_devices.begin(), filtered_devices.end(), [&](const std::string& ov_device) { |
| 113 | + uint32_t ort_device_id = ort_api.HardwareDevice_DeviceId(&device); |
| 114 | + return ort_device_id == get_pci_device_id(ov_device); |
| 115 | + }); |
| 116 | + } |
| 117 | + |
| 118 | + if (matched_device == filtered_devices.end()) { |
| 119 | + // We didn't find a matching OpenVINO device for the OrtHardwareDevice. |
| 120 | + continue; |
| 121 | + } |
| 122 | + |
| 123 | + // these can be returned as nullptr if you have nothing to add. |
| 124 | + OrtKeyValuePairs* ep_metadata = nullptr; |
| 125 | + OrtKeyValuePairs* ep_options = nullptr; |
| 126 | + ort_api.CreateKeyValuePairs(&ep_metadata); |
| 127 | + ort_api.AddKeyValuePair(ep_metadata, ov_device_key_, matched_device->c_str()); |
| 128 | + |
| 129 | + if (IsMetaDeviceFactory()) { |
| 130 | + ort_api.AddKeyValuePair(ep_metadata, ov_meta_device_key_, device_type_.c_str()); |
| 131 | + } |
| 132 | + |
| 133 | + // Create EP device |
| 134 | + auto* status = ort_api.GetEpApi()->CreateEpDevice(this, &device, ep_metadata, ep_options, |
| 135 | + &ep_devices[num_ep_devices++]); |
| 136 | + |
| 137 | + ort_api.ReleaseKeyValuePairs(ep_metadata); |
| 138 | + ort_api.ReleaseKeyValuePairs(ep_options); |
| 139 | + |
| 140 | + if (status != nullptr) { |
| 141 | + return status; |
| 142 | + } |
| 143 | + } |
| 144 | + |
| 145 | + return nullptr; |
| 146 | +} |
| 147 | + |
| 148 | +extern "C" { |
| 149 | +// |
| 150 | +// Public symbols |
| 151 | +// |
| 152 | +OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, |
| 153 | + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { |
| 154 | + InitCxxApi(*ort_api_base); |
| 155 | + const ApiPtrs api_ptrs{Ort::GetApi(), Ort::GetEpApi(), Ort::GetModelEditorApi()}; |
| 156 | + |
| 157 | + // Get available devices from OpenVINO |
| 158 | + auto ov_core = ov_core_singleton::Get(); |
| 159 | + std::vector<std::string> supported_factories = {""}; |
| 160 | + const auto& meta_devices = OpenVINOEpPluginFactory::GetOvMetaDevices(); |
| 161 | + supported_factories.insert(supported_factories.end(), meta_devices.begin(), meta_devices.end()); |
| 162 | + |
| 163 | + const size_t required_factories = supported_factories.size(); |
| 164 | + if (max_factories < required_factories) { |
| 165 | + return Ort::Status(std::format("Not enough space to return EP factories. Need at least {} factories.", required_factories).c_str(), ORT_INVALID_ARGUMENT); |
| 166 | + } |
| 167 | + |
| 168 | + size_t factory_index = 0; |
| 169 | + for (const auto& device_name : supported_factories) { |
| 170 | + // Create a factory for this specific device |
| 171 | + factories[factory_index++] = new OpenVINOEpPluginFactory(api_ptrs, device_name, ov_core); |
| 172 | + } |
| 173 | + |
| 174 | + *num_factories = factory_index; |
| 175 | + return nullptr; |
| 176 | +} |
| 177 | + |
| 178 | +OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { |
| 179 | + delete static_cast<OpenVINOEpPluginFactory*>(factory); |
| 180 | + return nullptr; |
| 181 | +} |
| 182 | +} |
0 commit comments