Skip to content

Commit

Permalink
Add possibility to specify graph optimization level for session options.
Browse files Browse the repository at this point in the history
The program was tested solely for our own use cases, which might differ from yours.

Benedikt Heidrich <[email protected]> Mercedes-Benz Tech Innovation GmbH ; Licensed subject to the terms of the MPL-2.0. https://github.com/mercedes-benz/foss/blob/master/PROVIDER_INFORMATION.md
  • Loading branch information
benHeid committed Feb 20, 2025
1 parent fe3604a commit 97a94cc
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ struct ProviderOptionsArray_Element : JSON::Element {
ProviderOptionsObject_Element object_{v_};
};

GraphOptimizationLevel getGraphOptimizationLevel(std::string_view name) {
if (name =="ORT_DISABLE_ALL") {
return ORT_DISABLE_ALL;
} else if (name == "ORT_ENABLE_BASIC") {
return ORT_ENABLE_BASIC;
} else if (name == "ORT_ENABLE_EXTENDED") {
return ORT_ENABLE_EXTENDED;
} else if (name == "ORT_ENABLE_ALL") {
return ORT_ENABLE_ALL;
} else
throw JSON::unknown_value_error{};
}

struct SessionOptions_Element : JSON::Element {
explicit SessionOptions_Element(Config::SessionOptions& v) : v_{v} {}

Expand Down Expand Up @@ -94,6 +107,8 @@ struct SessionOptions_Element : JSON::Element {
v_.ep_context_enable = JSON::Get<bool>(value);
else if (name == "use_env_allocators")
v_.use_env_allocators = JSON::Get<bool>(value);
else if (name == "graph_optimization_level")
v_.graph_optimization_level = getGraphOptimizationLevel(JSON::Get<std::string_view>(value));
else
throw JSON::unknown_value_error{};
}
Expand Down
1 change: 1 addition & 0 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct Config {
bool use_env_allocators{};

std::vector<ProviderOptions> provider_options;
std::optional<GraphOptimizationLevel> graph_optimization_level;
};

struct Model {
Expand Down
4 changes: 4 additions & 0 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,10 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
session_options.AddConfigEntry("session.use_env_allocators", "1");
}

if (config_session_options.graph_optimization_level.has_value()) {
session_options.SetGraphOptimizationLevel(config_session_options.graph_optimization_level.value());
}

for (auto& provider_options : config_session_options.provider_options) {
if (provider_options.name == "cuda") {
auto ort_provider_options = OrtCUDAProviderOptionsV2::Create();
Expand Down

0 comments on commit 97a94cc

Please sign in to comment.