From 7938d245ad3c20493a7c6e6021d7d1e3110e8d52 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 11 Sep 2025 16:14:06 -0700 Subject: [PATCH] Support plugin EPs with original SessionOptionsAppendExecutionProvider API function. --- onnxruntime/core/session/ort_env.cc | 10 ++ onnxruntime/core/session/ort_env.h | 3 + .../core/session/provider_registration.cc | 92 ++++++++++++++++++- onnxruntime/test/autoep/library/ep.cc | 2 +- onnxruntime/test/autoep/test_execution.cc | 40 ++++++++ 5 files changed, 145 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index 1bd8a18d7255f..b7384e778d682 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -46,6 +46,16 @@ OrtEnv::~OrtEnv() { #endif } +/*static*/ +OrtEnv::UniquePtr OrtEnv::GetInstanceIfExists() { + std::lock_guard lock(m_); + if (p_instance_) { + ++ref_count_; + } + + return OrtEnv::UniquePtr(p_instance_, OrtEnv::Release); +} + OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info, onnxruntime::common::Status& status, const OrtThreadingOptions* tp_options) { diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index 94c8e0a6ea2e8..20dec9b574b9f 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -35,6 +35,9 @@ struct OrtEnv { onnxruntime::common::Status& status, const OrtThreadingOptions* tp_options = nullptr); + using UniquePtr = std::unique_ptr; + static UniquePtr GetInstanceIfExists(); + static void Release(OrtEnv* env_ptr); const onnxruntime::Environment& GetEnvironment() const { diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 48d52ae3cf428..997d9130e7d89 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -17,6 +17,12 @@ #include "core/session/ort_apis.h" #include "core/providers/openvino/openvino_provider_factory_creator.h" +#if !defined(ORT_MINIMAL_BUILD) +#include "core/session/environment.h" +#include "core/session/ort_env.h" +#include "core/session/utils.h" +#endif + #ifdef _WIN32 #include #include "core/platform/tracing.h" @@ -73,6 +79,58 @@ OrtStatus* ParseProviderOptions(_In_reads_(num_keys) const char* const* provider return nullptr; } + +#if !defined(ORT_MINIMAL_BUILD) +/// +/// Tries to append a plugin EP with the given name to the session options. If the plugin EP was registered with +/// ORT, creates an EP factory with all OrtEpDevice instances that the factory supports. +/// +/// The OrtSessionOptions instance. +/// The name of the EP to find among registered plugin EPs. +/// User-specified EP option keys. +/// User-specified EP option values. +/// Output parameter set to true if this function successfully appended a plugin EP +/// factory to the session options. +/// A status indicating success or an error. +onnxruntime::Status TryAppendPluginEp(OrtSessionOptions* options, + const char* ep_name, + gsl::span ep_option_keys, + gsl::span ep_option_vals, + /*out*/ bool& added_ep) { + added_ep = false; + + if (OrtEnv::UniquePtr ort_env = OrtEnv::GetInstanceIfExists(); ort_env.get() != nullptr) { + const onnxruntime::Environment& env = ort_env->GetEnvironment(); + const auto& all_ep_devices = env.GetOrtEpDevices(); + + // Find all OrtEpDevices with the target EP name. + std::vector ep_devices; + for (const OrtEpDevice* ep_device : all_ep_devices) { + if (ep_device->ep_name == ep_name) { + ep_devices.push_back(ep_device); + } + } + + if (!ep_devices.empty()) { + // Add factory for EP that supports the selected EP devices. + std::unique_ptr provider_factory = nullptr; + + ORT_RETURN_IF_ERROR(CreateIExecutionProviderFactoryForEpDevices( + env, + options->value, + ep_devices, + ep_option_keys, + ep_option_vals, + /*output*/ provider_factory)); + options->provider_factories.push_back(std::move(provider_factory)); + added_ep = true; + } + } + + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) + } // namespace /** * Implementation of OrtApis functions for provider registration. @@ -147,7 +205,9 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, } #endif - auto create_not_supported_status = [&provider_name]() { + bool provider_is_not_supported = false; + auto create_not_supported_status = [&provider_name, &provider_is_not_supported]() { + provider_is_not_supported = true; return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, (std::string(provider_name) + " execution provider is not supported in this build. ").c_str()); }; @@ -185,6 +245,20 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, }); if (ep_to_append_iter == supported_eps.end()) { +#if !defined(ORT_MINIMAL_BUILD) + // Unknown EP name: try to append a plugin EP from a library registered with the environment. + bool added_plugin_ep = false; + ORT_API_RETURN_IF_STATUS_NOT_OK(TryAppendPluginEp( + options, provider_name, + gsl::span(provider_options_keys, num_keys), + gsl::span(provider_options_values, num_keys), + /*out*/ added_plugin_ep)); + + if (added_plugin_ep) { + return nullptr; + } +#endif + return create_unknown_provider_status(supported_eps); } @@ -328,6 +402,22 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, status = create_unknown_provider_status(supported_eps); } + if (provider_is_not_supported) { + // Provider is not supported in this build of ORT. Check if it was instead registered in a plugin EP library. +#if !defined(ORT_MINIMAL_BUILD) + bool added_plugin_ep = false; + ORT_API_RETURN_IF_STATUS_NOT_OK(TryAppendPluginEp( + options, provider_name, + gsl::span(provider_options_keys, num_keys), + gsl::span(provider_options_values, num_keys), + /*out*/ added_plugin_ep)); + + if (added_plugin_ep) { + return nullptr; + } +#endif + } + return status; API_IMPL_END } diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index e4265713d2d0a..8dc289b1ea7b6 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -233,7 +233,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG for (const auto& node : nodes) { auto op_type = node.GetOperatorType(); - if (op_type != "Mul") { + if (op_type == "Mul") { // Check that Mul has inputs/output of type float std::vector inputs = node.GetInputs(); std::vector outputs = node.GetOutputs(); diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 0f4a654f116c4..50b1cbed52dfb 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -98,6 +98,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { // Create model compilation options from the session options. Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags::OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); compile_options.SetInputModelPath(input_model_file); compile_options.SetOutputModelPath(output_model_file); @@ -107,6 +108,45 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { ASSERT_TRUE(std::filesystem::exists(output_model_file)); } } + +// Tests compiling a model using a plugin EP. Notably, this test uses the +// "SessionOptionsAppendExecutionProvider" API function that accepts the EP name as an argument (not OrtEpDevices). +TEST(OrtEpLibrary, PluginEp_OldAppendEpApi_GenEpContextModel) { + const OrtApi& c_api = Ort::GetApi(); + ASSERT_ORTSTATUS_OK(c_api.RegisterExecutionProviderLibrary(*ort_env, + Utils::example_ep_info.registration_name.c_str(), + Utils::example_ep_info.library_path.c_str())); + const std::string& plugin_ep_name = Utils::example_ep_info.registration_name; + + { + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_ctx.onnx"); + std::filesystem::remove(output_model_file); + + // Create session options with example plugin EP. + // NOTE: Using the old "SessionOptionsAppendExecutionProvider" API function. + // It will use a registered plugin EP library if there is no built-in EP with the given name. + // The plugin EP library must support at least one OrtHardwareDevice to be chosen. + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider(plugin_ep_name, ep_options); + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags::OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + + // Compile the model. + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + // Make sure the compiled model was generated. + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + } + + ASSERT_ORTSTATUS_OK(c_api.UnregisterExecutionProviderLibrary(*ort_env, + Utils::example_ep_info.registration_name.c_str())); +} + } // namespace test } // namespace onnxruntime