From 25ed3bd6338b3c343feccd249682976bdde466bf Mon Sep 17 00:00:00 2001 From: benHeid Date: Thu, 20 Feb 2025 16:00:20 +0100 Subject: [PATCH] Add possibility to add graph optimization level to session options --- src/config.cpp | 15 +++++++++++++++ src/config.h | 1 + src/models/model.cpp | 4 ++++ 3 files changed, 20 insertions(+) diff --git a/src/config.cpp b/src/config.cpp index 5e6ad7c03..1b89e0e2e 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -62,6 +62,19 @@ struct ProviderOptionsArray_Element : JSON::Element { ProviderOptionsObject_Element object_{v_}; }; +GraphOptimizationLevel getGraphOptimizationLevel(std::string_view name) { + if (name =="ORT_DISABLE_ALL") { + return ORT_DISABLE_ALL; + } else if (name == "ORT_ENABLE_BASIC") { + return ORT_ENABLE_BASIC; + } else if (name == "ORT_ENABLE_EXTENDED") { + return ORT_ENABLE_EXTENDED; + } else if (name == "ORT_ENABLE_ALL") { + return ORT_ENABLE_ALL; + } else + throw JSON::unknown_value_error{}; +} + struct SessionOptions_Element : JSON::Element { explicit SessionOptions_Element(Config::SessionOptions& v) : v_{v} {} @@ -94,6 +107,8 @@ struct SessionOptions_Element : JSON::Element { v_.ep_context_enable = JSON::Get(value); else if (name == "use_env_allocators") v_.use_env_allocators = JSON::Get(value); + else if (name == "graph_optimization_level") + v_.graph_optimization_level = getGraphOptimizationLevel(JSON::Get(value)); else throw JSON::unknown_value_error{}; } diff --git a/src/config.h b/src/config.h index 5ba21683c..99282233d 100644 --- a/src/config.h +++ b/src/config.h @@ -48,6 +48,7 @@ struct Config { bool use_env_allocators{}; std::vector provider_options; + std::optional graph_optimization_level; }; struct Model { diff --git a/src/models/model.cpp b/src/models/model.cpp index 76f857068..ff00f02c3 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -407,6 +407,10 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_ session_options.AddConfigEntry("session.use_env_allocators", "1"); } + if (config_session_options.graph_optimization_level.has_value()) { + session_options.SetGraphOptimizationLevel(config_session_options.graph_optimization_level.value()); + } + for (auto& provider_options : config_session_options.provider_options) { if (provider_options.name == "cuda") { auto ort_provider_options = OrtCUDAProviderOptionsV2::Create();