Skip to content

Commit

Permalink
Add possibility to change graph optimization level
Browse files Browse the repository at this point in the history
  • Loading branch information
benHeid committed Feb 20, 2025
1 parent fe3604a commit 3e2d44e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ struct ProviderOptionsArray_Element : JSON::Element {
ProviderOptionsObject_Element object_{v_};
};

GraphOptimizationLevel getGraphOptimizationLevel(std::string_view name) {
if (name =="ORT_DISABLE_ALL") {
return ORT_DISABLE_ALL;
} else if (name == "ORT_ENABLE_BASIC") {
return ORT_ENABLE_BASIC;
} else if (name == "ORT_ENABLE_EXTENDED") {
return ORT_ENABLE_EXTENDED;
} else if (name == "ORT_ENABLE_ALL") {
return ORT_ENABLE_ALL;
} else
throw JSON::unknown_value_error{};
}

struct SessionOptions_Element : JSON::Element {
explicit SessionOptions_Element(Config::SessionOptions& v) : v_{v} {}

Expand Down Expand Up @@ -94,6 +107,8 @@ struct SessionOptions_Element : JSON::Element {
v_.ep_context_enable = JSON::Get<bool>(value);
else if (name == "use_env_allocators")
v_.use_env_allocators = JSON::Get<bool>(value);
else if (name == "graph_optimization_level")
v_.graph_optimization_level = getGraphOptimizationLevel(JSON::Get<std::string_view>(value));
else
throw JSON::unknown_value_error{};
}
Expand Down
1 change: 1 addition & 0 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct Config {
bool use_env_allocators{};

std::vector<ProviderOptions> provider_options;
std::optional<GraphOptimizationLevel> graph_optimization_level;
};

struct Model {
Expand Down
4 changes: 4 additions & 0 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,10 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
session_options.AddConfigEntry("session.use_env_allocators", "1");
}

if (config_session_options.graph_optimization_level.has_value()) {
session_options.SetGraphOptimizationLevel(config_session_options.graph_optimization_level.value());
}

for (auto& provider_options : config_session_options.provider_options) {
if (provider_options.name == "cuda") {
auto ort_provider_options = OrtCUDAProviderOptionsV2::Create();
Expand Down

0 comments on commit 3e2d44e

Please sign in to comment.