Skip to content

Commit

Permalink
fix compiler error
Browse files Browse the repository at this point in the history
  • Loading branch information
chilo-ms committed Feb 7, 2025
1 parent 3360dfd commit 5f7da9f
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 50 deletions.
3 changes: 3 additions & 0 deletions onnxruntime/core/framework/compute_capability.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ struct ComputeCapability {
// Optimization: std::function<Status(const Graph& graph, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update)>
std::function<Status(Graph&, const ComputeCapability&, ComputeCapability&)> optimization_func;

// Optional key/value strings to configure an optimizer
std::unordered_map<std::string, std::string> optimization_configs;

// optional ComputeCapability instances for sets of nodes within this ComputeCapability that should be optimized.
// when an optimization is applied, ORT will update this ComputeCapability to reflect the changes made.
// IndexedSubGraph.nodes:
Expand Down
59 changes: 25 additions & 34 deletions onnxruntime/core/optimizer/graph_optimizer_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,63 +15,54 @@ GraphOptimizerRegistry::GraphOptimizerRegistry() {
logger_ = &logging::LoggingManager::DefaultLogger();
}

common::Status GraphOptimizerRegistry::Register(std::unique_ptr<GraphTransformer> transformer) {
const auto& name = transformer->Name();
if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end() &&
name_to_transformer_map_.at(name)) {
LOGS(*logger_, WARNING) << "This optimizer is already created and registered " << name;
return Status::OK();
}

name_to_transformer_map_[name] = transformer.get();
transformer_list_.push_back(std::move(transformer));

if (name == kCONSTANT_FOLDING_DQ) {
transformer_name_to_selection_func_[name] = ConstantFoldingDQ_selection;
}

return Status::OK();
}

common::Status GraphOptimizerRegistry::AddPredefinedOptimizerNames(std::vector<std::string>& optimizer_names) {
for (auto name : optimizer_names) {
if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end()) {
LOGS(*logger_, WARNING) << "This transformer name is already added " << name;
return Status::OK();
}
name_to_transformer_map_[name] = nullptr; // The transformer will be instantizted only when EP requests it

if (name == kCONSTANT_FOLDING_DQ) {
transformer_name_to_selection_func_[name] = ConstantFoldingDQ_selection;
}
}
return Status::OK();
}

std::unique_ptr<GraphTransformer> GraphOptimizerRegistry::CreateOptimizer(std::string& name, std::unordered_map<std::string, std::string>& key_value_configs) {
std::unique_ptr<GraphTransformer> transformer;
common::Status GraphOptimizerRegistry::CreateOptimizer(std::string& name, std::unordered_map<std::string, std::string>& key_value_configs) {
if (name == kCONSTANT_FOLDING_DQ) {
const InlinedHashSet<NodeIndex> node_index_set = {};
return std::make_unique<ConstantFoldingDQ>(*cpu_ep_, false /*skip_dequantize_linear*/,
session_options_->config_options, node_index_set);
auto transformer = std::make_unique<ConstantFoldingDQ>(*cpu_ep_, false /*skip_dequantize_linear*/,
session_options_->config_options, node_index_set);
Get()->Register(std::move(transformer));
return Status::OK();
}
LOGS(*logger_, WARNING) << "Can't create optimizer " << name;
return transformer;

LOGS(*logger_, WARNING) << "Can't create optimizer for " << name << ". It's not in the predefined optimizer list.";
return Status::OK();
}

std::optional<std::function<std::vector<std::unique_ptr<ComputeCapability>>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name,
std::unordered_map<std::string, std::string>& key_value_configs) const {
if (name_to_transformer_map_.find(name) == name_to_transformer_map_.end()) {
LOGS(*logger_, WARNING) << "Can't find optimizer " << name;
return std::nullopt;
common::Status GraphOptimizerRegistry::Register(std::unique_ptr<GraphTransformer> transformer) {
const auto& name = transformer->Name();
if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end() &&
name_to_transformer_map_.at(name)) {
LOGS(*logger_, WARNING) << "This optimizer is already created and registered " << name;
return Status::OK();
}

// Create and register if the transformer instance is not created.
if (!name_to_transformer_map_.at(name)) {
auto new_transformer = Get()->CreateOptimizer(name, key_value_configs);
Get()->Register(std::move(new_transformer));
}
name_to_transformer_map_[name] = transformer.get();
transformer_list_.push_back(std::move(transformer));

return Status::OK();
}

std::optional<std::function<std::vector<std::unique_ptr<ComputeCapability>>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name) const {
auto lookup = transformer_name_to_selection_func_.find(name);
if (lookup != transformer_name_to_selection_func_.end()) {
return transformer_name_to_selection_func_.at(name);
}
LOGS(*logger_, WARNING) << "Can't find selection function of " << name;
return std::nullopt;
}

Expand Down
23 changes: 15 additions & 8 deletions onnxruntime/core/optimizer/graph_optimizer_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ class GraphOptimizerRegistry {
const logging::Logger& logger);

/**
* Create optimizer instance.
* Create and register optimizer.
*/
std::unique_ptr<GraphTransformer> CreateOptimizer(std::string& name, std::unordered_map<std::string, std::string>& key_value_configs);
common::Status GraphOptimizerRegistry::CreateOptimizer(std::string& name, std::unordered_map<std::string, std::string>& key_value_configs);

/**
* Get optimizer by name.
Expand All @@ -67,23 +67,30 @@ class GraphOptimizerRegistry {
common::Status Register(std::unique_ptr<GraphTransformer> transformer);

/**
* Get optimizer selection function requested by EP. If the optimizer name can't be found, return nullopt.
*
* Please note that this function also creates and registers the optimizer if its instance is not existed.
* Get optimizer selection function. If the optimizer name can't be found, return nullopt.
*/
std::optional<std::function<std::vector<std::unique_ptr<ComputeCapability>>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name,
std::unordered_map<std::string, std::string>& key_value_configs) const;
std::optional<std::function<std::vector<std::unique_ptr<ComputeCapability>>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name) const;

/**
* Add CPU EP reference from InferenceSession as it's needed for some optimizers, ex: ConstantFoldingDQ.
*/
common::Status AddCpuEpReference(onnxruntime::IExecutionProvider* cpu_ep);

/**
* Add Session Options reference from InferenceSession as it's needed for some optimizers, ex: ConstantFoldingDQ.
* Get CPU EP reference.
*/
onnxruntime::IExecutionProvider* GetCpuEpReference() const { return cpu_ep_; }

/**
* Add session options reference from InferenceSession as it's needed for some optimizers, ex: ConstantFoldingDQ.
*/
common::Status AddSessionOptionsReference(onnxruntime::SessionOptions* session_options);

/**
* Get Session Options reference.
*/
onnxruntime::SessionOptions* GetSessionOptionsReference() const { return session_options_; }

private:
InlinedVector<std::unique_ptr<GraphTransformer>> transformer_list_;
InlinedHashMap<std::string, GraphTransformer*> name_to_transformer_map_;
Expand Down
19 changes: 16 additions & 3 deletions onnxruntime/core/optimizer/selection_and_optimization_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,22 @@ Status ConstantFoldingDQ_optimization(Graph& graph, const ComputeCapability& opt
}

auto optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get();

ConstantFoldingDQ* transformer = static_cast<ConstantFoldingDQ*>(optimizer_registry->GetTransformerByName(optimizer_name));
transformer->UpdateNodeIndexSet(dq_node_index_set);

// ConstantFoldingDQ optimizer doesn't need the key/value strings.
std::unordered_map<std::string, std::string> key_value_configs = optimization_cc.optimization_configs;

// Don't use CreateOptimizer as ConstantFoldingDQ needs dq_node_index_set for instantiation.
// optimizer_registry->CreateOptimizer(optimizer_name, key_value_configs);

// Create ConstantFoldingDQ optimizer if it's not existed.
if (!optimizer_registry->GetTransformerByName(optimizer_name)) {
auto transformer = std::make_unique<ConstantFoldingDQ>(*optimizer_registry->GetCpuEpReference(),
false /*skip_dequantize_linear*/,
optimizer_registry->GetSessionOptionsReference()->config_options,
dq_node_index_set);
optimizer_registry->Register(std::move(transformer));
}


// apply constant folding on DQ nodes
optimizer_registry->ApplyTransformer(graph, optimizer_name, *logger);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ struct ProviderHost {
virtual const OrtApiBase* OrtGetApiBase() = 0;

virtual Status GetOptimizerByName(const std::string& name,
std::unordered_map<std::string, std::string>& key_value_configs,
std::function<std::vector<std::unique_ptr<ComputeCapability>>(const GraphViewer&)>& selection_func) = 0;

virtual void* HeapAllocate(size_t size) = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2670,8 +2670,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
*/

std::function<std::vector<std::unique_ptr<ComputeCapability>>(const GraphViewer&)> selection_func;
std::unordered_map<std::string, std::string> key_value_configs = {};
auto status = g_host->GetOptimizerByName("ConstantFoldingDQ", key_value_configs, selection_func);
auto status = g_host->GetOptimizerByName("ConstantFoldingDQ", selection_func);
std::vector<std::unique_ptr<ComputeCapability>> selection_cc;
if (selection_func) {
selection_cc = selection_func(graph);
Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,11 @@ struct ProviderHostImpl : ProviderHost {
const OrtApiBase* OrtGetApiBase() override { return ::OrtGetApiBase(); }

Status GetOptimizerByName(const std::string& name,
std::unordered_map<std::string, std::string>& key_value_configs,
std::function<std::vector<std::unique_ptr<ComputeCapability>>(const GraphViewer&)>& selection_func) override {
std::string optimizer_name(name);

auto optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get();
auto func = optimizer_registry->GetSelectionFunc(optimizer_name, key_value_configs);
auto func = optimizer_registry->GetSelectionFunc(optimizer_name);

if (func.has_value()) {
selection_func = func.value();
Expand Down

0 comments on commit 5f7da9f

Please sign in to comment.