Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable global thread pool in python #23495

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
set_default_logger_severity, # noqa: F401
set_default_logger_verbosity, # noqa: F401
set_seed, # noqa: F401
set_global_thread_pool_sizes, # noqa: F401
)

import_capi_exception = None
Expand Down
81 changes: 75 additions & 6 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <utility>

#include "core/session/abi_session_options_impl.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/provider_bridge_ort.h"
#include "core/session/onnxruntime_cxx_api.h"

#include "core/session/lora_adapters.h"

Expand Down Expand Up @@ -1388,6 +1389,15 @@
#endif

void addGlobalMethods(py::module& m) {
// Exposed to Python as onnxruntime.set_global_thread_pool_sizes(intra_op_num_threads=0,
// inter_op_num_threads=0). Packs the two requested sizes into an OrtThreadingOptions value
// and forwards it to SetGlobalThreadingOptions(). Both arguments default to 0.
m.def(
    "set_global_thread_pool_sizes",
    [](int intra_op_num_threads, int inter_op_num_threads) {
      // Positional aggregate initialization: intra-op params first, inter-op params second.
      SetGlobalThreadingOptions(OrtThreadingOptions{
          OrtThreadPoolParams{intra_op_num_threads},
          OrtThreadPoolParams{inter_op_num_threads}});
    },
    py::arg("intra_op_num_threads") = 0,
    py::arg("inter_op_num_threads") = 0,
    "Set the number of threads used by the global thread pools for intra and inter op parallelism.");
m.def("get_default_session_options", &GetDefaultCPUSessionOptions, "Return a default session_options instance.");
m.def("get_session_initializer", &SessionObjectInitializer::Get, "Return a default session object initializer.");
m.def(
Expand Down Expand Up @@ -1728,6 +1738,13 @@
},
R"pbdoc(VLOG level if DEBUG build and session_log_severity_level is 0.
Applies to session load, initialization, etc. Default is 0.)pbdoc")
.def_property(
"use_per_session_threads",
[](const PySessionOptions* options) -> bool { return options->value.use_per_session_threads; },
[](PySessionOptions* options, bool use_per_session_threads) -> void {
options->value.use_per_session_threads = use_per_session_threads;
},
R"pbdoc(Whether to use per-session thread pool. Default is True.)pbdoc")
.def_property(
"intra_op_num_threads",
[](const PySessionOptions* options) -> int { return options->value.intra_op_param.thread_pool_size; },
Expand Down Expand Up @@ -1999,6 +2016,14 @@
auto env = GetEnv();
std::unique_ptr<PyInferenceSession> sess;

// Both checks below only make sense once a global thread pool has been requested through
// SetGlobalThreadingOptions(), so they are scoped to that case. Previously the warning was
// emitted for every session that set a thread count, even with no global pool configured.
if (CheckIfUsingGlobalThreadPool()) {
  // This binding treats the global pool and per-session pools as mutually exclusive.
  if (so.value.use_per_session_threads) {
    ORT_THROW("use_per_session_threads must be false when using a global thread pool");
  }

  // Non-zero per-session sizes were set explicitly; tell the user they have no effect here.
  if (so.value.intra_op_param.thread_pool_size != 0 || so.value.inter_op_param.thread_pool_size != 0) {
    LOGS_DEFAULT(WARNING) << "session options intra_op_param.thread_pool_size and inter_op_param.thread_pool_size are ignored when using a global thread pool";
  }
}

// separate creation of the session from model loading unless we have to read the config from the model.
// in a minimal build we only support load via Load(...) and not at session creation time
if (load_config_from_model) {
Expand Down Expand Up @@ -2303,7 +2328,7 @@

import_array1(false);

auto env = GetEnv();
// auto env = GetEnv();

addGlobalMethods(m);
addObjectMethods(m, RegisterExecutionProviders);
Expand Down Expand Up @@ -2360,6 +2385,16 @@
// For all the related details and why it is needed see "Modern C++ design" by A. Alexandrescu Chapter 6.
class EnvInitializer {
public:

// Stores the thread pool parameters to be used when the shared environment is built and
// turns off per-session threads. Must run before the environment exists: once the
// constructor has set `initialized`, this throws instead of changing anything.
static void SetGlobalThreadingOptions(const OrtThreadingOptions& tp_options) {
  const bool env_already_created = EnvInitializer::initialized;
  if (env_already_created) {
    ORT_THROW("Cannot set global threading options after the environment has been initialized.");
  }

  // The parameter shadows the static member, hence the explicit qualification.
  EnvInitializer::tp_options = tp_options;
  EnvInitializer::use_per_session_threads = false;
}

static std::shared_ptr<onnxruntime::Environment> SharedInstance() {
// Guard against attempts to resurrect the singleton
if (EnvInitializer::destroyed) {
Expand All @@ -2369,16 +2404,36 @@
return env_holder.Get();
}

static bool GetUsePerSessionThreads() {
return use_per_session_threads;
}

private:
EnvInitializer() {
std::unique_ptr<Environment> env_ptr;
Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON);
OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>(
std::make_unique<CLogSink>(),
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
&SessionObjectInitializer::default_logger_id),
env_ptr));

// create logging manager here
std::unique_ptr<LoggingManager> lm = std::make_unique<LoggingManager>(
std::make_unique<CLogSink>(),
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
&SessionObjectInitializer::default_logger_id);

if (EnvInitializer::use_per_session_threads) {
OrtPybindThrowIfError(Environment::Create(std::move(lm),
env_ptr
));
} else {
OrtPybindThrowIfError(Environment::Create(std::move(lm),

Check warning on line 2427 in onnxruntime/python/onnxruntime_pybind_state.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/python/onnxruntime_pybind_state.cc:2427: Add #include <utility> for move [build/include_what_you_use] [4]
env_ptr,
&EnvInitializer::tp_options,
true
));
}

session_env_ = std::shared_ptr<Environment>(env_ptr.release());

initialized = true;
destroyed = false;
}

Expand All @@ -2392,12 +2447,26 @@

std::shared_ptr<Environment> session_env_;

static OrtThreadingOptions tp_options;
static bool use_per_session_threads;
static bool initialized;
static bool destroyed;
};

// Thread pool parameters captured by SetGlobalThreadingOptions(); default-constructed until then.
OrtThreadingOptions EnvInitializer::tp_options;
// Starts out true; SetGlobalThreadingOptions() switches it to false.
bool EnvInitializer::use_per_session_threads = true;
// Set to true at the end of the EnvInitializer constructor.
bool EnvInitializer::initialized = false;
// Reset to false by the constructor; SharedInstance() checks it to refuse resurrecting the singleton.
bool EnvInitializer::destroyed = false;
} // namespace

// Public entry point that forwards to EnvInitializer::SetGlobalThreadingOptions, which
// throws if the environment has already been initialized.
void SetGlobalThreadingOptions(const OrtThreadingOptions& tp_options) {
EnvInitializer::SetGlobalThreadingOptions(tp_options);
}

// Returns true when per-session threads have been turned off, i.e. the logical negation
// of EnvInitializer::GetUsePerSessionThreads().
bool CheckIfUsingGlobalThreadPool() {
  const bool per_session_threads = EnvInitializer::GetUsePerSessionThreads();
  return !per_session_threads;
}

// Returns the shared Environment obtained from EnvInitializer::SharedInstance().
std::shared_ptr<onnxruntime::Environment> GetEnv() {
return EnvInitializer::SharedInstance();
}
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,9 @@ class SessionObjectInitializer {
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif

// Sets the threading options to be used for the global thread pools.
void SetGlobalThreadingOptions(const OrtThreadingOptions& tp_options);
// Returns true when a global thread pool is in use rather than per-session threads.
bool CheckIfUsingGlobalThreadPool();
// Returns the shared Environment instance.
std::shared_ptr<Environment> GetEnv();

// Initialize an InferenceSession.
Expand Down
Loading