Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions onnxruntime/core/session/ort_env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ OrtEnv::~OrtEnv() {
#endif
}

/*static*/
OrtEnv::UniquePtr OrtEnv::GetInstanceIfExists() {
std::lock_guard<std::mutex> lock(m_);
if (p_instance_) {
++ref_count_;
}

return OrtEnv::UniquePtr(p_instance_, OrtEnv::Release);
}

OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info,
onnxruntime::common::Status& status,
const OrtThreadingOptions* tp_options) {
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/session/ort_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ struct OrtEnv {
onnxruntime::common::Status& status,
const OrtThreadingOptions* tp_options = nullptr);

using UniquePtr = std::unique_ptr<OrtEnv, void (*)(OrtEnv*)>;
static UniquePtr GetInstanceIfExists();

static void Release(OrtEnv* env_ptr);

const onnxruntime::Environment& GetEnvironment() const {
Expand Down
92 changes: 91 additions & 1 deletion onnxruntime/core/session/provider_registration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
#include "core/session/ort_apis.h"
#include "core/providers/openvino/openvino_provider_factory_creator.h"

#if !defined(ORT_MINIMAL_BUILD)
#include "core/session/environment.h"
#include "core/session/ort_env.h"
#include "core/session/utils.h"
#endif

#ifdef _WIN32
#include <winmeta.h>
#include "core/platform/tracing.h"
Expand Down Expand Up @@ -73,6 +79,58 @@

return nullptr;
}

#if !defined(ORT_MINIMAL_BUILD)
/// <summary>
/// Tries to append a plugin EP with the given name to the session options. If the plugin EP was registered with
/// ORT, creates an EP factory with all OrtEpDevice instances that the factory supports.
/// </summary>
/// <param name="options">The OrtSessionOptions instance.</param>
/// <param name="ep_name">The name of the EP to find among registered plugin EPs.</param>
/// <param name="ep_option_keys">User-specified EP option keys.</param>
/// <param name="ep_option_vals">User-specified EP option values.</param>
/// <param name="added_ep">Output parameter set to true if this function successfully appended a plugin EP
/// factory to the session options.</param>
/// <returns>A status indicating success or an error.</returns>
onnxruntime::Status TryAppendPluginEp(OrtSessionOptions* options,
const char* ep_name,
gsl::span<const char* const> ep_option_keys,
gsl::span<const char* const> ep_option_vals,
/*out*/ bool& added_ep) {
added_ep = false;

if (OrtEnv::UniquePtr ort_env = OrtEnv::GetInstanceIfExists(); ort_env.get() != nullptr) {
const onnxruntime::Environment& env = ort_env->GetEnvironment();
const auto& all_ep_devices = env.GetOrtEpDevices();

// Find all OrtEpDevices with the target EP name.
std::vector<const OrtEpDevice*> ep_devices;

Check warning on line 107 in onnxruntime/core/session/provider_registration.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/session/provider_registration.cc:107: Add #include <vector> for vector<> [build/include_what_you_use] [4]
for (const OrtEpDevice* ep_device : all_ep_devices) {
if (ep_device->ep_name == ep_name) {
ep_devices.push_back(ep_device);
}
}

if (!ep_devices.empty()) {
// Add factory for EP that supports the selected EP devices.
std::unique_ptr<IExecutionProviderFactory> provider_factory = nullptr;

Check warning on line 116 in onnxruntime/core/session/provider_registration.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/session/provider_registration.cc:116: Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]

ORT_RETURN_IF_ERROR(CreateIExecutionProviderFactoryForEpDevices(
env,
options->value,
ep_devices,
ep_option_keys,
ep_option_vals,
/*output*/ provider_factory));
options->provider_factories.push_back(std::move(provider_factory));
added_ep = true;
}
}

return Status::OK();
}
#endif // !defined(ORT_MINIMAL_BUILD)

} // namespace
/**
* Implementation of OrtApis functions for provider registration.
Expand Down Expand Up @@ -147,7 +205,9 @@
}
#endif

auto create_not_supported_status = [&provider_name]() {
bool provider_is_not_supported = false;
auto create_not_supported_status = [&provider_name, &provider_is_not_supported]() {
provider_is_not_supported = true;
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
(std::string(provider_name) + " execution provider is not supported in this build. ").c_str());
};
Expand Down Expand Up @@ -185,6 +245,20 @@
});

if (ep_to_append_iter == supported_eps.end()) {
#if !defined(ORT_MINIMAL_BUILD)
// Unknown EP name: try to append a plugin EP from a library registered with the environment.
bool added_plugin_ep = false;
ORT_API_RETURN_IF_STATUS_NOT_OK(TryAppendPluginEp(
options, provider_name,
gsl::span<const char* const>(provider_options_keys, num_keys),
gsl::span<const char* const>(provider_options_values, num_keys),
/*out*/ added_plugin_ep));

if (added_plugin_ep) {
return nullptr;
}
#endif

return create_unknown_provider_status(supported_eps);
}

Expand Down Expand Up @@ -328,6 +402,22 @@
status = create_unknown_provider_status(supported_eps);
}

if (provider_is_not_supported) {
// Provider is not supported in this build of ORT. Check if it was instead registered in a plugin EP library.
#if !defined(ORT_MINIMAL_BUILD)
bool added_plugin_ep = false;
ORT_API_RETURN_IF_STATUS_NOT_OK(TryAppendPluginEp(
options, provider_name,
gsl::span<const char* const>(provider_options_keys, num_keys),
gsl::span<const char* const>(provider_options_values, num_keys),
/*out*/ added_plugin_ep));

if (added_plugin_ep) {
return nullptr;
}
#endif
}

return status;
API_IMPL_END
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/autoep/library/ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
for (const auto& node : nodes) {
auto op_type = node.GetOperatorType();

if (op_type != "Mul") {
if (op_type == "Mul") {
// Check that Mul has inputs/output of type float
std::vector<Ort::ConstValueInfo> inputs = node.GetInputs();
std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs();
Expand Down
40 changes: 40 additions & 0 deletions onnxruntime/test/autoep/test_execution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) {

// Create model compilation options from the session options.
Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
compile_options.SetFlags(OrtCompileApiFlags::OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED);
compile_options.SetInputModelPath(input_model_file);
compile_options.SetOutputModelPath(output_model_file);

Expand All @@ -107,6 +108,45 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) {
ASSERT_TRUE(std::filesystem::exists(output_model_file));
}
}

// Tests compiling a model using a plugin EP. Notably, this test uses the
// "SessionOptionsAppendExecutionProvider" API function that accepts the EP name as an argument (not OrtEpDevices).
TEST(OrtEpLibrary, PluginEp_OldAppendEpApi_GenEpContextModel) {
const OrtApi& c_api = Ort::GetApi();
ASSERT_ORTSTATUS_OK(c_api.RegisterExecutionProviderLibrary(*ort_env,
Utils::example_ep_info.registration_name.c_str(),
Utils::example_ep_info.library_path.c_str()));
const std::string& plugin_ep_name = Utils::example_ep_info.registration_name;

{
const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx");
const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_ctx.onnx");
std::filesystem::remove(output_model_file);

// Create session options with example plugin EP.
// NOTE: Using the old "SessionOptionsAppendExecutionProvider" API function.
// It will use a registered plugin EP library if there is no built-in EP with the given name.
// The plugin EP library must support at least one OrtHardwareDevice to be chosen.
Ort::SessionOptions session_options;
std::unordered_map<std::string, std::string> ep_options;
session_options.AppendExecutionProvider(plugin_ep_name, ep_options);

// Create model compilation options from the session options.
Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
compile_options.SetFlags(OrtCompileApiFlags::OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED);
compile_options.SetInputModelPath(input_model_file);
compile_options.SetOutputModelPath(output_model_file);

// Compile the model.
ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options));
// Make sure the compiled model was generated.
ASSERT_TRUE(std::filesystem::exists(output_model_file));
}

ASSERT_ORTSTATUS_OK(c_api.UnregisterExecutionProviderLibrary(*ort_env,
Utils::example_ep_info.registration_name.c_str()));
}

} // namespace test
} // namespace onnxruntime

Expand Down
Loading