From 5f7da9f7f430268993bd3bbdc160920fc1091be6 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 6 Feb 2025 16:47:44 -0800 Subject: [PATCH] fix compiler error --- .../core/framework/compute_capability.h | 3 + .../optimizer/graph_optimizer_registry.cc | 59 ++++++++----------- .../core/optimizer/graph_optimizer_registry.h | 23 +++++--- .../selection_and_optimization_func.cc | 19 +++++- .../shared_library/provider_interfaces.h | 1 - .../tensorrt/tensorrt_execution_provider.cc | 3 +- .../core/session/provider_bridge_ort.cc | 3 +- 7 files changed, 61 insertions(+), 50 deletions(-) diff --git a/onnxruntime/core/framework/compute_capability.h b/onnxruntime/core/framework/compute_capability.h index 2d1d4e0e0153f..dfe8536fe983a 100644 --- a/onnxruntime/core/framework/compute_capability.h +++ b/onnxruntime/core/framework/compute_capability.h @@ -30,6 +30,9 @@ struct ComputeCapability { // Optimization: std::function std::function optimization_func; + // Optional key/value strings to configure an optimizer + std::unordered_map optimization_configs; + // optional ComputeCapability instances for sets of nodes within this ComputeCapability that should be optimized. // when an optimization is applied, ORT will update this ComputeCapability to reflect the changes made. // IndexedSubGraph.nodes: diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.cc b/onnxruntime/core/optimizer/graph_optimizer_registry.cc index 6caa6ad2e28f0..c7a702f1e4a52 100644 --- a/onnxruntime/core/optimizer/graph_optimizer_registry.cc +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.cc @@ -15,24 +15,6 @@ GraphOptimizerRegistry::GraphOptimizerRegistry() { logger_ = &logging::LoggingManager::DefaultLogger(); } -common::Status GraphOptimizerRegistry::Register(std::unique_ptr transformer) { - const auto& name = transformer->Name(); - if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end() && - name_to_transformer_map_.at(name)) { - LOGS(*logger_, WARNING) << "This optimizer is already created and registered " << name; - return Status::OK(); - } - - name_to_transformer_map_[name] = transformer.get(); - transformer_list_.push_back(std::move(transformer)); - - if (name == kCONSTANT_FOLDING_DQ) { - transformer_name_to_selection_func_[name] = ConstantFoldingDQ_selection; - } - - return Status::OK(); -} - common::Status GraphOptimizerRegistry::AddPredefinedOptimizerNames(std::vector& optimizer_names) { for (auto name : optimizer_names) { if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end()) { @@ -40,38 +22,47 @@ common::Status GraphOptimizerRegistry::AddPredefinedOptimizerNames(std::vector GraphOptimizerRegistry::CreateOptimizer(std::string& name, std::unordered_map& key_value_configs) { - std::unique_ptr transformer; +common::Status GraphOptimizerRegistry::CreateOptimizer(std::string& name, std::unordered_map& key_value_configs) { if (name == kCONSTANT_FOLDING_DQ) { const InlinedHashSet node_index_set = {}; - return std::make_unique(*cpu_ep_, false /*skip_dequantize_linear*/, - session_options_->config_options, node_index_set); + auto transformer = std::make_unique(*cpu_ep_, false /*skip_dequantize_linear*/, + session_options_->config_options, node_index_set); + Get()->Register(std::move(transformer)); + return Status::OK(); } - LOGS(*logger_, WARNING) << "Can't create optimizer " << name; - return transformer; + + LOGS(*logger_, WARNING) << "Can't create optimizer for " << name << ". It's not in the predefined optimizer list."; + return Status::OK(); } -std::optional>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name, - std::unordered_map& key_value_configs) const { - if (name_to_transformer_map_.find(name) == name_to_transformer_map_.end()) { - LOGS(*logger_, WARNING) << "Can't find optimizer " << name; - return std::nullopt; +common::Status GraphOptimizerRegistry::Register(std::unique_ptr transformer) { + const auto& name = transformer->Name(); + if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end() && + name_to_transformer_map_.at(name)) { + LOGS(*logger_, WARNING) << "This optimizer is already created and registered " << name; + return Status::OK(); } - // Create and register if the transformer instance is not created. - if (!name_to_transformer_map_.at(name)) { - auto new_transformer = Get()->CreateOptimizer(name, key_value_configs); - Get()->Register(std::move(new_transformer)); - } + name_to_transformer_map_[name] = transformer.get(); + transformer_list_.push_back(std::move(transformer)); + + return Status::OK(); +} +std::optional>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name) const { auto lookup = transformer_name_to_selection_func_.find(name); if (lookup != transformer_name_to_selection_func_.end()) { return transformer_name_to_selection_func_.at(name); } + LOGS(*logger_, WARNING) << "Can't find selection function of " << name; return std::nullopt; } diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.h b/onnxruntime/core/optimizer/graph_optimizer_registry.h index 99ddc6542d665..844f714104028 100644 --- a/onnxruntime/core/optimizer/graph_optimizer_registry.h +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.h @@ -46,9 +46,9 @@ class GraphOptimizerRegistry { const logging::Logger& logger); /** - * Create optimizer instance. + * Create and register optimizer. */ - std::unique_ptr CreateOptimizer(std::string& name, std::unordered_map& key_value_configs); + common::Status GraphOptimizerRegistry::CreateOptimizer(std::string& name, std::unordered_map& key_value_configs); /** * Get optimizer by name. @@ -67,12 +67,9 @@ class GraphOptimizerRegistry { common::Status Register(std::unique_ptr transformer); /** - * Get optimizer selection function requested by EP. If the optimizer name can't be found, return nullopt. - * - * Please note that this function also creates and registers the optimizer if its instance is not existed. + * Get optimizer selection function. If the optimizer name can't be found, return nullopt. */ - std::optional>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name, - std::unordered_map& key_value_configs) const; + std::optional>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name) const; /** * Add CPU EP reference from InferenceSession as it's needed for some optimizers, ex: ConstantFoldingDQ. @@ -80,10 +77,20 @@ class GraphOptimizerRegistry { common::Status AddCpuEpReference(onnxruntime::IExecutionProvider* cpu_ep); /** - * Add Session Options reference from InferenceSession as it's needed for some optimizers, ex: ConstantFoldingDQ. + * Get CPU EP reference. + */ + onnxruntime::IExecutionProvider* GetCpuEpReference() const { return cpu_ep_; } + + /** + * Add session options reference from InferenceSession as it's needed for some optimizers, ex: ConstantFoldingDQ. */ common::Status AddSessionOptionsReference(onnxruntime::SessionOptions* session_options); + /** + * Get Session Options reference. + */ + onnxruntime::SessionOptions* GetSessionOptionsReference() const { return session_options_; } + private: InlinedVector> transformer_list_; InlinedHashMap name_to_transformer_map_; diff --git a/onnxruntime/core/optimizer/selection_and_optimization_func.cc b/onnxruntime/core/optimizer/selection_and_optimization_func.cc index a3fe45d89f928..6592631ff07df 100644 --- a/onnxruntime/core/optimizer/selection_and_optimization_func.cc +++ b/onnxruntime/core/optimizer/selection_and_optimization_func.cc @@ -57,9 +57,22 @@ Status ConstantFoldingDQ_optimization(Graph& graph, const ComputeCapability& opt } auto optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get(); - - ConstantFoldingDQ* transformer = static_cast(optimizer_registry->GetTransformerByName(optimizer_name)); - transformer->UpdateNodeIndexSet(dq_node_index_set); + + // ConstantFoldingDQ optimizer doesn't need the key/value strings. + std::unordered_map key_value_configs = optimization_cc.optimization_configs; + + // Don't use CreateOptimizer as ConstantFoldingDQ needs dq_node_index_set for instantiation. + // optimizer_registry->CreateOptimizer(optimizer_name, key_value_configs); + + // Create ConstantFoldingDQ optimizer if it's not existed. + if (!optimizer_registry->GetTransformerByName(optimizer_name)) { + auto transformer = std::make_unique(*optimizer_registry->GetCpuEpReference(), + false /*skip_dequantize_linear*/, + optimizer_registry->GetSessionOptionsReference()->config_options, + dq_node_index_set); + optimizer_registry->Register(std::move(transformer)); + } + // apply constant folding on DQ nodes optimizer_registry->ApplyTransformer(graph, optimizer_name, *logger); diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index ec6cd85b4279e..16fc15ea76725 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -143,7 +143,6 @@ struct ProviderHost { virtual const OrtApiBase* OrtGetApiBase() = 0; virtual Status GetOptimizerByName(const std::string& name, - std::unordered_map& key_value_configs, std::function>(const GraphViewer&)>& selection_func) = 0; virtual void* HeapAllocate(size_t size) = 0; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 41005522572ac..8a899884e5892 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2670,8 +2670,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, */ std::function>(const GraphViewer&)> selection_func; - std::unordered_map key_value_configs = {}; - auto status = g_host->GetOptimizerByName("ConstantFoldingDQ", key_value_configs, selection_func); + auto status = g_host->GetOptimizerByName("ConstantFoldingDQ", selection_func); std::vector> selection_cc; if (selection_func) { selection_cc = selection_func(graph); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 10b5a4138b336..46346e9457f21 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -214,12 +214,11 @@ struct ProviderHostImpl : ProviderHost { const OrtApiBase* OrtGetApiBase() override { return ::OrtGetApiBase(); } Status GetOptimizerByName(const std::string& name, - std::unordered_map& key_value_configs, std::function>(const GraphViewer&)>& selection_func) override { std::string optimizer_name(name); auto optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get(); - auto func = optimizer_registry->GetSelectionFunc(optimizer_name, key_value_configs); + auto func = optimizer_registry->GetSelectionFunc(optimizer_name); if (func.has_value()) { selection_func = func.value();