diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h
index 0d9e6db1a7748..c9a15de9ef897 100644
--- a/include/onnxruntime/core/framework/execution_provider.h
+++ b/include/onnxruntime/core/framework/execution_provider.h
@@ -38,6 +38,8 @@ struct OrtRunOptions;
 
 namespace onnxruntime {
 
+class IResourceAccountant;
+
 /**
    Logical device representation.
 */
@@ -130,7 +132,8 @@ class IExecutionProvider {
   */
   virtual std::vector<std::unique_ptr<ComputeCapability>>
   GetCapability(const onnxruntime::GraphViewer& graph_viewer,
-                const IKernelLookup& kernel_lookup) const;
+                const IKernelLookup& kernel_lookup,
+                IResourceAccountant* resource_accountant = nullptr) const;
 
   /**
      Get kernel registry per execution provider type.
diff --git a/include/onnxruntime/core/framework/op_kernel_context.h b/include/onnxruntime/core/framework/op_kernel_context.h
index ac22d9130983a..3fd9ee0d8b292 100644
--- a/include/onnxruntime/core/framework/op_kernel_context.h
+++ b/include/onnxruntime/core/framework/op_kernel_context.h
@@ -192,7 +192,7 @@ class OpKernelContext {
   onnxruntime::NodeIndex GetNodeIndex() const;
 
   virtual const OrtValue* GetInputMLValue(int index) const;
-  const OrtValue* GetImplicitInputMLValue(int index) const;
+  virtual const OrtValue* GetImplicitInputMLValue(int index) const;
   OrtValue* GetOutputMLValue(int index);
 
 #ifdef ENABLE_ATEN
@@ -204,6 +204,8 @@ class OpKernelContext {
 
   virtual OrtValue* GetOrCreateOutputMLValue(int index);
 
+  virtual int GetOrtValueIndexForOutput(int output_index) const;
+
  private:
   ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext);
   int GetInputArgIndex(int index) const;
diff --git a/include/onnxruntime/core/framework/resource_accountant.h b/include/onnxruntime/core/framework/resource_accountant.h
new file mode 100644
index 0000000000000..274750a505fbd
--- /dev/null
+++ b/include/onnxruntime/core/framework/resource_accountant.h
@@ -0,0 +1,123 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <algorithm>
+#include <filesystem>
+#include <memory>
+#include <optional>
+#include <string>
+#include <variant>
+
+#include "core/common/common.h"
+#include "core/common/inlined_containers_fwd.h"
+
+namespace onnxruntime {
+
+struct ConfigOptions;
+#ifndef SHARED_PROVIDER
+class Node;
+#else
+struct Node;
+#endif
+
+// Common holder for potentially different resource accounting
+// for different EPs
+using ResourceCount = std::variant<size_t>;
+
+/// <summary>
+/// This class is used for graph partitioning by EPs.
+/// It stores the cumulative amount of the resource, such as
+/// memory, that would be consumed by the graph nodes if it is assigned to the EP.
+///
+/// It provides interfaces to add, remove and query the resource consumption.
+///
+/// Each provider may assign its own meaning to the resource according to its constraints.
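+/// For example, a memory-bound EP can count bytes, i.e. the size_t alternative of
+/// ResourceCount (this is what the CUDA EP does further below in this change).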
+/// </summary>
+class IResourceAccountant {
+ protected:
+  IResourceAccountant() = default;
+  IResourceAccountant(const ResourceCount& threshold) : threshold_(threshold) {}
+
+ public:
+  virtual ~IResourceAccountant() = default;
+  virtual ResourceCount GetConsumedAmount() const = 0;
+  virtual void AddConsumedAmount(const ResourceCount& amount) = 0;
+  virtual void RemoveConsumedAmount(const ResourceCount& amount) = 0;
+  virtual ResourceCount ComputeResourceCount(const Node& node) const = 0;
+
+  std::optional<ResourceCount> GetThreshold() const {
+    return threshold_;
+  }
+
+  void SetStopAssignment() noexcept {
+    stop_assignment_ = true;
+  }
+
+  bool IsStopIssued() const noexcept { return stop_assignment_; }
+
+  static std::string MakeUniqueNodeName(const Node& node);
+
+ private:
+  bool stop_assignment_ = false;
+  std::optional<ResourceCount> threshold_;
+};
+
+// A map of Ep Type to a resource accountant for this EP
+using ResourceAccountantMap = InlinedHashMap<std::string, std::unique_ptr<IResourceAccountant>>;
+
+// This struct keeps accounting of the memory allocation stats
+// for a kernel during runtime if enabled.
+struct NodeAllocationStats {
+  size_t input_sizes = 0;
+  size_t initializers_sizes = 0;
+  size_t total_dynamic_sizes = 0;
+  size_t total_temp_allocations = 0;
+
+  NodeAllocationStats& operator+=(const NodeAllocationStats& other) {
+    input_sizes += other.input_sizes;
+    initializers_sizes += other.initializers_sizes;
+    total_dynamic_sizes += other.total_dynamic_sizes;
+    total_temp_allocations += other.total_temp_allocations;
+    return *this;
+  }
+
+  void UpdateIfGreater(const NodeAllocationStats& other) {
+    input_sizes = std::max(input_sizes, other.input_sizes);
+    initializers_sizes = std::max(initializers_sizes, other.initializers_sizes);
+    total_dynamic_sizes = std::max(total_dynamic_sizes, other.total_dynamic_sizes);
+    total_temp_allocations = std::max(total_temp_allocations, other.total_temp_allocations);
+  }
+};
+
+class NodeStatsRecorder {
+ public:
+  explicit NodeStatsRecorder(const std::filesystem::path& stats_file_name);
+  ~NodeStatsRecorder();
+
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(NodeStatsRecorder);
+
+  const std::filesystem::path& GetNodeStatsFileName() const noexcept;
+
+  bool ShouldAccountFor(const std::string& input_output_name) const;
+
+  void ResetPerRunNameDeduper();
+
+  void ReportNodeStats(const std::string& node_name, const NodeAllocationStats& stats);
+
+  void DumpStats(const std::filesystem::path& model_path) const;
+
+  [[nodiscard]] static Status CreateAccountants(
+      const ConfigOptions& config_options,
+      const std::filesystem::path& model_path,
+      std::optional<ResourceAccountantMap>& acc_map);
+
+ private:
+  void DumpStats(std::ostream& os) const;
+
+  struct Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace onnxruntime
diff --git a/include/onnxruntime/core/graph/indexed_sub_graph.h b/include/onnxruntime/core/graph/indexed_sub_graph.h
index c57db41254159..e457d3dcad1f1 100644
--- a/include/onnxruntime/core/graph/indexed_sub_graph.h
+++ b/include/onnxruntime/core/graph/indexed_sub_graph.h
@@ -7,6 +7,8 @@
 #include <memory>
 #include <string>
 
+#include "core/common/inlined_containers_fwd.h"
+#include "core/framework/resource_accountant.h"
 #include "core/graph/basic_types.h"
 #include "core/graph/onnx_protobuf.h"
 
@@ -70,9 +72,45 @@ struct IndexedSubGraph {
     return meta_def_.get();
   }
 
+  // Check if the accounting is enabled for the current EP
+  bool IsAccountingEnabled() const {
+    return resource_accountant != nullptr &&
+           nodes_costs.size() == nodes.size();
+  }
+
+  // Should call IsAccountingEnabled() first.
+  // Takes the previously computed ResourceCount for the node
+  // (usually computed during GetCapability()) if present and adds it to the consumed amount.
+  void AccountForNode(size_t cost_index) const {
+    assert(cost_index < nodes_costs.size());
+    resource_accountant->AddConsumedAmount(nodes_costs[cost_index]);
+  }
+
+  // This computes and accounts for the resource cost of a node that has just
+  // been fused from other nodes, when the EP did not have a chance to compute the costs.
+  void ComputeAndAccountForNode(const Node& node) const {
+    assert(resource_accountant != nullptr);
+    resource_accountant->AddConsumedAmount(resource_accountant->ComputeResourceCount(node));
+  }
+
+  void SetAccountant(IResourceAccountant* res_accountant) {
+    resource_accountant = res_accountant;
+  }
+
+  // Append resource count to the list of costs for the nodes.
+  void AppendNodeCost(const ResourceCount& cost) {
+    assert(resource_accountant != nullptr);
+    nodes_costs.emplace_back(cost);
+  }
+
  private:
   // subgraph meta definition.
   std::unique_ptr<MetaDef> meta_def_;
+  // Optional resource accountant for this subgraph.
+  IResourceAccountant* resource_accountant = nullptr;
+  // Vector with resource costs for the nodes above. Should have the same size as nodes.
+  InlinedVector<ResourceCount> nodes_costs;
 };
 
 }  // namespace onnxruntime
diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index 018e1ddc81799..117a2cdabca2f 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -267,6 +267,34 @@ static const char* const kOrtSessionOptionsModelExternalInitializersFileFolderPa
 static const char* const kOrtSessionOptionsSavePrePackedConstantInitializers =
     "session.save_external_prepacked_constant_initializers";
 
+// Use this config when you want to collect memory stats for each node in the graph.
+// The file will be created if it does not exist, and will be overwritten if it does.
+// The file format is CSV with the following columns:
+//
+// node_name, initializers_memory, dynamic_outputs_sizes, temp_allocations_size
+//
+// The content of the file can be used to estimate memory requirements at run time, including
+// the temporary allocations. This collection is preferably done on a CPU device, as the model may exceed
+// device memory limits in constrained environments. When enabling this option, it is important to disable
+// memory patterns, as they tend to allocate large blocks to avoid fragmentation and accommodate the needs
+// of multiple kernels. Memory patterns may make it difficult to allocate on a device with limited memory.
+//
+// The collected stats can then be used to partition the graph among the devices in a way that only the
+// required memory is allocated on each device.
+//
+// - "full path to file": there is no default for this option. If the file cannot be opened for writing, an error will be returned.
+static const char* const kOrtSessionOptionsCollectNodeMemoryStatsToFile = "session.collect_node_memory_stats_to_file";
+
+/// This is a composite CSV setting formatted as "memory limit in kb,file name for collected stats"
+/// "limit > 0": enables Capacity Aware Partitioning for the CUDA EP. `limit` is optional, and when absent
+/// the provider may attempt to figure out the available memory automatically.
+/// A setting with no limit is expected to look like: ",file name for collected stats"
+/// The EP will place nodes on the device according to the pre-recorded stats in "file name";
+/// this file is expected to be found in the same folder as the model. The file contains
+/// stats collected during a run with kOrtSessionOptionsCollectNodeMemoryStatsToFile enabled (see above).
+static const char* const kOrtSessionOptionsResourceCudaPartitioningSettings =
+    "session.resource_cuda_partitioning_settings";
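A hedged usage sketch (editorial aside, not part of this diff): both keys above are ordinary session config entries, so they can be set through the existing `Ort::SessionOptions::AddConfigEntry` API. The stats file name below is invented for illustration.

```cpp
#include <onnxruntime_cxx_api.h>

// Pass 1 (preferably CPU-only): record per-node memory stats, with memory
// patterns disabled as recommended above. "node_stats.csv" is an assumed name.
Ort::SessionOptions collect_opts;
collect_opts.DisableMemPattern();
collect_opts.AddConfigEntry("session.collect_node_memory_stats_to_file", "node_stats.csv");

// Pass 2: partition for the CUDA EP against a 2 GB budget (2097152 kb) using the
// recorded stats; pass ",node_stats.csv" instead to let the EP detect free memory itself.
Ort::SessionOptions partition_opts;
partition_opts.AddConfigEntry("session.resource_cuda_partitioning_settings", "2097152,node_stats.csv");
```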
 // Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file.
 // The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
 // "0": disable. (default)
diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc
index 894e0daae94b6..c5046353ba528 100644
--- a/onnxruntime/core/framework/execution_frame.cc
+++ b/onnxruntime/core/framework/execution_frame.cc
@@ -614,6 +614,12 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
 #endif
   }
 
+#if !defined(ORT_MINIMAL_BUILD)
+  if (session_state_.GetNodeStatsRecorder() != nullptr) {
+    ort_value_to_dynamic_allocations_size_.insert_or_assign(ort_value_index, size);
+  }
+#endif
+
   return Status::OK();
 }
 
diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h
index de571f86f1c77..7b5a8fd8a4b01 100644
--- a/onnxruntime/core/framework/execution_frame.h
+++ b/onnxruntime/core/framework/execution_frame.h
@@ -92,10 +92,10 @@ class IExecutionFrame {
 
   Status ReleaseMLValue(int ort_value_idx);
 
- protected:
   // get the ort_value_idx from NodeIndexInfo
   int GetNodeIdxToMLValueIdx(int index) const;
 
+ protected:
   OrtValue& GetMutableMLValue(int ort_value_index) { return const_cast<OrtValue&>(GetMLValue(ort_value_index)); }
 
   virtual Status ReleaseMLValueImpl(int ort_value_idx);
@@ -103,6 +103,8 @@ class IExecutionFrame {
   // returns true if the ort_value_idx is an output from the graph
   bool IsOutput(int ort_value_idx) const;
 
+  const OrtValueNameIdxMap& GetOrtValueNameIdxMap() const noexcept { return ort_value_idx_map_; }
+
  private:
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(IExecutionFrame);
 
@@ -166,6 +168,16 @@ class ExecutionFrame final : public IExecutionFrame {
     return planner_.has_value();
   }
 
+#if !defined(ORT_MINIMAL_BUILD)
+  std::optional<size_t> GetOrtValueDynamicAllocation(int ort_value_index) const {
+    auto it = ort_value_to_dynamic_allocations_size_.find(ort_value_index);
+    if (it != ort_value_to_dynamic_allocations_size_.end()) {
+      return it->second;
+    }
+    return std::nullopt;
+  }
+#endif
+
   // This function tries to retrieve the inferred shape for the given NodeArg index.
   // If the retrieval is successful, this function returns true and false otherwise.
   bool TryGetInferredShape(int index, TensorShape& shape) const override;
@@ -258,10 +270,14 @@ class ExecutionFrame final : public IExecutionFrame {
   // This field is not physical memory size.
   // dynamic_activation_memory_sizes_in_byte_[location] is the dynamic memory consumption on "location".
   std::unordered_map<std::string, size_t> dynamic_activation_memory_sizes_in_byte_;
+#endif
+
+#if !defined(ORT_MINIMAL_BUILD)
+  // OrtValue index to the size of dynamic memory allocation.
+  std::unordered_map<int, size_t> ort_value_to_dynamic_allocations_size_;
+#endif
 
   // Mutex which should be acquired when executing non-thread-safe member functions.
   // A current example is the tracker of dynamic memory allocation.
   mutable std::mutex mtx_;
-#endif
 };
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc
index b39924d4c3ff9..3a937a119d03b 100644
--- a/onnxruntime/core/framework/execution_provider.cc
+++ b/onnxruntime/core/framework/execution_provider.cc
@@ -13,7 +13,8 @@ namespace onnxruntime {
 
 std::vector<std::unique_ptr<ComputeCapability>>
 IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
-                                  const IKernelLookup& kernel_lookup) const {
+                                  const IKernelLookup& kernel_lookup,
+                                  IResourceAccountant*) const {
   std::vector<std::unique_ptr<ComputeCapability>> result;
   for (const auto& node : graph.Nodes()) {
     if (const KernelCreateInfo* kernel_create_info = kernel_lookup.LookUpKernel(node);
diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc
index c02e6cf3af5ab..111f8e0a5fc34 100644
--- a/onnxruntime/core/framework/graph_partitioner.cc
+++ b/onnxruntime/core/framework/graph_partitioner.cc
@@ -6,12 +6,15 @@
 #include <cassert>
 #include <functional>
 
+#include "core/common/inlined_containers.h"
+#include "core/common/string_utils.h"
 #include "core/framework/compute_capability.h"
 #include "core/framework/execution_providers.h"
 #include "core/framework/func_kernel.h"
 #include "core/framework/kernel_lookup.h"
 #include "core/framework/kernel_registry_manager.h"
 #include "core/framework/kernel_registry.h"
+#include "core/framework/resource_accountant.h"
 #include "core/graph/function.h"
 #include "core/graph/function_utils.h"
 #include "core/graph/graph_viewer.h"
@@ -93,11 +96,14 @@ static bool TryAssignNodes(Graph& graph, const IndexedSubGraph& capability,
     }
   }
 
-  for (auto node_index : capability.nodes) {
-    auto* node = graph.GetNode(node_index);
+  const bool acc_enabled = capability.IsAccountingEnabled();
+  for (size_t i = 0, limit = capability.nodes.size(); i < limit; ++i) {
+    auto* node = graph.GetNode(capability.nodes[i]);
     node->SetExecutionProviderType(provider_type);
+    if (acc_enabled) {
+      capability.AccountForNode(i);
+    }
   }
-
   return true;
 }
 
@@ -114,6 +120,9 @@ static bool TryAssignSingleNode(Graph& graph,
   if (nullptr != node && node->GetExecutionProviderType().empty()) {
     // The node was not fused or assigned. Assign it to <provider_type>.
     node->SetExecutionProviderType(provider_type);
+    if (indexed_sub_graph.IsAccountingEnabled()) {
+      indexed_sub_graph.AccountForNode(0);
+    }
     return true;
   }
 
@@ -132,12 +141,14 @@ struct GetCapabilityForEPParams {
   std::reference_wrapper<const layout_transformation::TransformLayoutFunction> transform_layout;
   std::reference_wrapper<const layout_transformation::DebugGraphFn> debug_graph_fn;
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+  IResourceAccountant* resource_accountant;
 };
 
 auto get_capabilities = [](const IExecutionProvider& ep,
                            const GraphViewer& graph_viewer,
-                           const IExecutionProvider::IKernelLookup& kernel_lookup) {
-  auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup);
+                           const IExecutionProvider::IKernelLookup& kernel_lookup,
+                           IResourceAccountant* resource_accountant) {
+  auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup, resource_accountant);
 
   // In theory an EP could return an empty capability. Remove those.
   capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(),
@@ -174,7 +185,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const logging::Logger& logger)
   {
     const GraphViewer graph_viewer(graph);
-    capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
+    capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant);
 
     if (capabilities.empty()) {
       return Status::OK();
@@ -212,7 +223,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const logging::Logger& logger)
     capabilities.clear();
 
     const GraphViewer graph_viewer(graph);
-    capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
+    capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant);
 
     // all nodes with an index >= first_new_node with domain of kMSInternalNHWCDomain should be in the capabilities
     InlinedHashSet<NodeIndex> new_nodes_in_capabilities;
@@ -261,7 +272,7 @@ static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer,
                                        logger};
 
   // TODO: Provide EP with a capability to look inside the functions.
-  capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
+  capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, nullptr);
 
   return Status::OK();
 }
@@ -319,6 +330,7 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
   }
 
   if (sub_graph_available_for_assignment) {
+    const bool acc_enabled = capability.IsAccountingEnabled();
     if (mode == GraphPartitioner::Mode::kNormal) {
       std::ostringstream oss;
       oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << fused_node_unique_id++;
@@ -334,6 +346,13 @@
       }
 
       fused_node->SetExecutionProviderType(provider_type);
+      if (acc_enabled) {
+        // We account for the fused node. We operate under the assumption
+        // that the fused node would use no more memory than the nodes we are fusing,
+        // and potentially less; therefore, no threshold check is needed here.
+        // All threshold checks are done within the EP.
+        capability.ComputeAndAccountForNode(*fused_node);
+      }
 
       result = fused_node;
     } else {
@@ -341,10 +360,13 @@
       // This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion
       // at runtime. The original nodes provide a fallback if fewer nodes can be fused at runtime due to device
       // capabilities.
-      for (auto node_index : capability.nodes) {
-        auto* node = graph.GetNode(node_index);
+      for (size_t i = 0, limit = capability.nodes.size(); i < limit; ++i) {
+        auto* node = graph.GetNode(capability.nodes[i]);
         if (node != nullptr) {
           node->SetExecutionProviderType(provider_type);
+          if (acc_enabled) {
+            capability.AccountForNode(i);
+          }
         }
       }
     }
@@ -364,7 +386,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
                                            int& fused_node_unique_id,
                                            const layout_transformation::TransformLayoutFunction& transform_layout_fn,
                                            const layout_transformation::DebugGraphFn& debug_graph_fn,
-                                           const logging::Logger& logger) {
+                                           const logging::Logger& logger, IResourceAccountant* resource_accountant) {
   // handle testing edge case where optimizers or constant lifting results in graph with no nodes.
   // doing it here saves all providers checking for this in GetCapability
   if (graph.NumberOfNodes() == 0) {
     return Status::OK();
   }
@@ -378,7 +400,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
       // we pass through the FuncManager from the top level graph
       ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr,
                                                        fused_kernel_registry, current_ep, mode, fused_node_unique_id,
-                                                       transform_layout_fn, debug_graph_fn, logger));
+                                                       transform_layout_fn, debug_graph_fn, logger, resource_accountant));
     }
   }
 
@@ -401,7 +423,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
       std::ref(capabilities),
       mode,
       std::cref(transform_layout_fn),
-      std::cref(debug_graph_fn)};
+      std::cref(debug_graph_fn),
+      resource_accountant};
 
   ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger));
   if (capabilities.empty()) {
@@ -770,6 +793,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
 static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode,
                                        const ExecutionProviders& execution_providers,
                                        KernelRegistryManager& kernel_registry_manager,
+                                       const std::optional<ResourceAccountantMap>& acc_map,
                                        const logging::Logger& logger) {
   bool modified_graph = false;
 
@@ -782,11 +806,18 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
   do {
     // process full graph with each EP
     for (const auto& ep : execution_providers) {
+      IResourceAccountant* resource_accountant = nullptr;
+      if (acc_map.has_value()) {
+        auto hit = acc_map->find(ep->Type());
+        if (hit != acc_map->end()) {
+          resource_accountant = hit->second.get();
+        }
+      }
       ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_manager,
                                                        fused_kernel_registry, *ep, mode, fused_node_unique_id,
                                                        transform_layout_function,
                                                        partition_params.debug_graph_fn,
-                                                       logger));
+                                                       logger, resource_accountant));
     }
 
     // expand any nodes that have an ONNX function definition but no matching ORT kernel.
@@ -821,8 +852,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
       auto& subgraph = *entry.second;
       PartitionParams subgraph_partition_params = partition_params;
       subgraph_partition_params.graph = std::ref(subgraph);
-      ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep,
-                                                      logger));
+      ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr,
+                                                      current_ep, logger));
     }
   }
 
@@ -838,6 +869,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
       std::cref(partition_params.transform_layout_function),
       std::cref(partition_params.debug_graph_fn),
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+      nullptr
   };
   // clang-format on
 
@@ -870,6 +902,9 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
 
     Node& fused_node = graph.BeginFuseSubGraph(indexed_sub_graph, node_name);
     fused_node.SetExecutionProviderType(type);
+    if (indexed_sub_graph.IsAccountingEnabled()) {
+      indexed_sub_graph.ComputeAndAccountForNode(fused_node);
+    }
 
     // create filtered graph viewer for this set of nodes
     //
@@ -886,6 +921,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
 
 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
   // We will compile the fused nodes one by one, and fuse the subgraph if successful.
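+  // Note: for each compiled entry, the resource accounting is applied after
+  // FinalizeFuseSubGraph() below, once the fused node's inputs/outputs are final.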
   for (const auto& compilation_entry : compilation_entries) {
+    const bool acc_enabled = compilation_entry.capability.get().sub_graph->IsAccountingEnabled();
     Node& node = compilation_entry.fused_node;
     std::vector<NodeComputeInfo> single_node_compute_func;
     ORT_RETURN_IF_ERROR(current_ep.Compile({IExecutionProvider::FusedNodeAndGraph{node, *compilation_entry.viewer}},
@@ -913,6 +949,9 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
 
     // now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
     graph.FinalizeFuseSubGraph(indexed_sub_graph, node);
+    if (acc_enabled) {
+      compilation_entry.capability.get().sub_graph->ComputeAndAccountForNode(node);
+    }
   }
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
 
@@ -1032,11 +1071,18 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
       ORT_RETURN_IF_ERROR(EpContextFilePathCheck(ep_context_path, graph.ModelPath()));
     }
 
-    ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger));
+    // We use this only if Resource Aware Partitioning is enabled for any of the EPs.
+    // The map is not created if the feature is not enabled.
+    std::optional<ResourceAccountantMap> ep_acc_map;
+    ORT_RETURN_IF_ERROR(NodeStatsRecorder::CreateAccountants(config_options, graph.ModelPath(), ep_acc_map));
+
+    ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_,
+                                                 ep_acc_map, logger));
 
     if (ep_context_enabled) {
       std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
-      std::string external_ini_file_name = config_options.GetConfigOrDefault(kOrtSessionOptionsEpContextModelExternalInitializersFileName, "");
+      std::string external_ini_file_name = config_options.GetConfigOrDefault(
+          kOrtSessionOptionsEpContextModelExternalInitializersFileName, "");
       ORT_RETURN_IF_ERROR(CreateEpContextModel(providers_, graph, ep_context_path, external_ini_file_name, logger));
     }
 #else
diff --git a/onnxruntime/core/framework/op_kernel.cc b/onnxruntime/core/framework/op_kernel.cc
index 94b6224440ed0..212ce9c5069ea 100644
--- a/onnxruntime/core/framework/op_kernel.cc
+++ b/onnxruntime/core/framework/op_kernel.cc
@@ -130,6 +130,11 @@ OrtValue* OpKernelContext::GetOrCreateOutputMLValue(int index) {
   return value;
 }
 
+int OpKernelContext::GetOrtValueIndexForOutput(int output_index) const {
+  int output_arg_index = GetOutputArgIndex(output_index);
+  return execution_frame_->GetNodeIdxToMLValueIdx(output_arg_index);
+}
+
 int OpKernelContext::GetInputArgIndex(int index) const {
   return node_input_start_index_ + index;
 }
diff --git a/onnxruntime/core/framework/op_kernel_context_internal.h b/onnxruntime/core/framework/op_kernel_context_internal.h
index 64bd70465a1c7..4c7ee10a07691 100644
--- a/onnxruntime/core/framework/op_kernel_context_internal.h
+++ b/onnxruntime/core/framework/op_kernel_context_internal.h
@@ -36,6 +36,15 @@ class OpKernelContextInternal : public OpKernelContext {
                   implicit_inputs[i]->Name(), " does not.");
       implicit_input_values_.push_back(entry);
     }
+
+#if !defined(ORT_MINIMAL_BUILD)
+    if (session_state_.GetNodeStatsRecorder() != nullptr) {
+      auto alloc = OpKernelContext::GetAllocator(kernel.GetDevice(OrtMemTypeDefault));
+      if (alloc != nullptr) {
+        accounting_allocator_ = std::make_shared<AccountingAllocator>(std::move(alloc));
+      }
+    }
+#endif
   }
 
   bool GetUseDeterministicCompute() const override {
@@ -69,9 +78,63 @@ class OpKernelContextInternal : public OpKernelContext {
     return implicit_input_values_;
   }
+  int GetOrtValueIndexForOutput(int output_index) const override {
+    return OpKernelContext::GetOrtValueIndexForOutput(output_index);
+  }
+
+#if !defined(ORT_MINIMAL_BUILD)
+  Status GetTempSpaceAllocator(AllocatorPtr* output) const override {
+    if (accounting_allocator_) {
+      *output = accounting_allocator_;
+      return Status::OK();
+    }
+    return OpKernelContext::GetTempSpaceAllocator(output);
+  }
+#endif
+
+#if !defined(ORT_MINIMAL_BUILD)
+  bool GetAllocatorStats(AllocatorStats& stats) {
+    if (accounting_allocator_ == nullptr) {
+      return false;
+    }
+    accounting_allocator_->GetStats(&stats);
+    return true;
+  }
+#endif
+
   const bool& GetTerminateFlag() const noexcept { return terminate_flag_; }
 
  private:
+#if !defined(ORT_MINIMAL_BUILD)
+  class AccountingAllocator : public IAllocator {
+   public:
+    AccountingAllocator(AllocatorPtr alloc) : IAllocator(alloc->Info()), allocator_(std::move(alloc)) {
+    }
+
+    void* Alloc(size_t size) override {
+      void* p = allocator_->Alloc(size);
+      if (p != nullptr) {
+        stats_.total_allocated_bytes += size;
+      }
+      return p;
+    }
+
+    void Free(void* p) override {
+      allocator_->Free(p);
+    }
+
+    void GetStats(AllocatorStats* stats) override {
+      *stats = stats_;
+    }
+
+   private:
+    AllocatorPtr allocator_;
+    AllocatorStats stats_;
+  };
+
+  AllocatorPtr accounting_allocator_;
+#endif
+
   const SessionState& session_state_;
   const bool& terminate_flag_;
   std::vector<const OrtValue*> implicit_input_values_;
diff --git a/onnxruntime/core/framework/resource_accountant.cc b/onnxruntime/core/framework/resource_accountant.cc
new file mode 100644
index 0000000000000..4d537219ec714
--- /dev/null
+++ b/onnxruntime/core/framework/resource_accountant.cc
@@ -0,0 +1,223 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/framework/resource_accountant.h"
+
+#include "core/common/inlined_containers.h"
+#include "core/common/narrow.h"
+#include "core/common/parse_string.h"
+#include "core/common/safeint.h"
+#include "core/common/string_utils.h"
+
+#include "core/framework/config_options.h"
+#include "core/framework/murmurhash3.h"
+#include "core/graph/constants.h"
+#include "core/graph/graph.h"
+#include "core/session/onnxruntime_session_options_config_keys.h"
+
+#include <fstream>
+
+namespace onnxruntime {
+
+// Use this accountant if your resource can be counted with size_t type
+class SizeTAccountant : public IResourceAccountant {
+ public:
+  SizeTAccountant() = default;
+  ~SizeTAccountant() = default;
+
+  SizeTAccountant(size_t threshold, InlinedHashMap<std::string, NodeAllocationStats>&& node_stats)
+      : IResourceAccountant(threshold), node_stats_(std::move(node_stats)) {}
+
+  explicit SizeTAccountant(InlinedHashMap<std::string, NodeAllocationStats>&& node_stats)
+      : IResourceAccountant(), node_stats_(std::move(node_stats)) {}
+
+  ResourceCount GetConsumedAmount() const noexcept override {
+    return consumed_amount_;
+  }
+
+  void AddConsumedAmount(const ResourceCount& amount) noexcept override {
+    if (std::holds_alternative<size_t>(amount)) {
+      consumed_amount_ += std::get<size_t>(amount);
+    }
+  }
+  void RemoveConsumedAmount(const ResourceCount& amount) noexcept override {
+    if (std::holds_alternative<size_t>(amount)) {
+      consumed_amount_ -= std::get<0>(amount);
+    }
+  }
+
+  ResourceCount ComputeResourceCount(const Node& node) const override {
+    const auto node_name = MakeUniqueNodeName(node);
+    auto hit = node_stats_.find(node_name);
+    if (hit != node_stats_.end()) {
+      const auto& stats = hit->second;
+      return stats.input_sizes + stats.initializers_sizes +
+             stats.total_dynamic_sizes + stats.total_temp_allocations;
+    }
+    return static_cast<size_t>(0U);
+  }
+
+ private:
+  size_t consumed_amount_ = 0;
+  InlinedHashMap<std::string, NodeAllocationStats> node_stats_;
+};
+
+struct NodeStatsRecorder::Impl {
+  std::filesystem::path node_stats_path;
+  // This is a node name to allocation stats map
+  InlinedHashMap<std::string, NodeAllocationStats> node_stats;
+  // Keeps track of nodes for which input/output sizes are accounted
+  InlinedHashSet<std::string> input_output_accounted;
+};
+
+NodeStatsRecorder::NodeStatsRecorder(const std::filesystem::path& node_stats_path)
+    : impl_(std::make_unique<Impl>()) {
+  impl_->node_stats_path = node_stats_path;
+}
+
+NodeStatsRecorder::~NodeStatsRecorder() = default;
+
+const std::filesystem::path& NodeStatsRecorder::GetNodeStatsFileName() const noexcept {
+  return impl_->node_stats_path;
+}
+
+bool NodeStatsRecorder::ShouldAccountFor(const std::string& input_output_name) const {
+  return impl_->input_output_accounted.insert(input_output_name).second;
+}
+
+void NodeStatsRecorder::ResetPerRunNameDeduper() {
+  impl_->input_output_accounted.clear();
+}
+
+void NodeStatsRecorder::ReportNodeStats(const std::string& node_name, const NodeAllocationStats& stats) {
+  auto result = impl_->node_stats.emplace(node_name, stats);
+  if (!result.second) {
+    // Node already exists, update the stats.
+    // This may happen when the user collects stats from multiple Runs()
+    result.first->second.UpdateIfGreater(stats);
+  }
+}
+
+void NodeStatsRecorder::DumpStats(std::ostream& os) const {
+  os << "#name,input_sizes,initializers_sizes,total_dynamic_sizes,total_temp_allocations\n";
+  for (const auto& [name, stats] : impl_->node_stats) {
+    os << name << "," << stats.input_sizes << "," << stats.initializers_sizes << ","
+       << stats.total_dynamic_sizes << ","
+       << stats.total_temp_allocations << "\n";
+  }
+}
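+// For illustration, a dumped stats file looks like the following
+// (node names and byte counts here are invented):
+//
+//   #name,input_sizes,initializers_sizes,total_dynamic_sizes,total_temp_allocations
+//   Conv_0_1234567890123456,602112,37632,0,1204224
+//   Relu_1_8765432109876543,1204224,0,0,0
+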
+void NodeStatsRecorder::DumpStats(const std::filesystem::path& model_path) const {
+  auto node_stats_file = model_path;
+  if (node_stats_file.has_filename()) {
+    node_stats_file = node_stats_file.parent_path();
+  }
+  node_stats_file /= GetNodeStatsFileName();
+  std::ofstream ofs(node_stats_file, std::ofstream::out);
+  ORT_ENFORCE(ofs.is_open(), "Failed to open file: ", node_stats_file);
+  DumpStats(ofs);
+  ofs.close();
+}
+
+static Status LoadNodeAllocationStats(
+    const std::filesystem::path& model_path, const std::filesystem::path& file_name,
+    InlinedHashMap<std::string, NodeAllocationStats>& result) {
+  InlinedHashMap<std::string, NodeAllocationStats> node_stats;
+  std::filesystem::path file_path = model_path;
+  if (file_path.has_filename()) {
+    file_path = file_path.parent_path();
+  }
+
+  file_path /= file_name;
+
+  std::ifstream file(file_path);
+  ORT_RETURN_IF_NOT(file.is_open(), "Failed to open file ", file_path);
+  std::string line;
+  // Read and load a CSV file line by line
+  while (std::getline(file, line)) {
+    if (line.empty() || line[0] == '#') continue;
+
+    auto splits = utils::SplitString(line, ",", true);
+    ORT_ENFORCE(splits.size() == 5, "Invalid line in the file ", file_path, ": ", line);
+    if (splits[0].empty()) {
+      continue;
+    }
+    std::string node_name{splits[0]};
+    size_t input_sizes = SafeInt<size_t>(std::stoull(std::string{splits[1]}));
+    size_t initializers_sizes = SafeInt<size_t>(std::stoull(std::string{splits[2]}));
+    size_t total_dynamic_sizes = SafeInt<size_t>(std::stoull(std::string{splits[3]}));
+    size_t total_temp_allocations = SafeInt<size_t>(std::stoull(std::string{splits[4]}));
+    node_stats.insert_or_assign(std::move(node_name),
+                                NodeAllocationStats{input_sizes, initializers_sizes,
+                                                    total_dynamic_sizes, total_temp_allocations});
+  }
+
+  result.swap(node_stats);
+  return Status::OK();
+}
+
+Status NodeStatsRecorder::CreateAccountants(
+    const ConfigOptions& config_options,
+    const std::filesystem::path& model_path,
+    std::optional<ResourceAccountantMap>& acc_map) {
+  // Check if CUDA partitioning settings are provided
+  const std::string resource_partitioning_settings = config_options.GetConfigOrDefault(
+      kOrtSessionOptionsResourceCudaPartitioningSettings, "");
+
+  if (!resource_partitioning_settings.empty()) {
+    auto splits = utils::SplitString(resource_partitioning_settings, ",", true);
+    if (splits.size() == 2) {
+      if (splits[1].empty()) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid resource partitioning settings");
+      }
+
+      InlinedHashMap<std::string, NodeAllocationStats> loaded_stats;
+      ORT_RETURN_IF_ERROR(LoadNodeAllocationStats(model_path, splits[1], loaded_stats));
+
+      std::optional<ResourceAccountantMap> result;
+      auto& map = result.emplace();
+
+      if (!splits[0].empty()) {
+        size_t cuda_memory_limit = 0;
+        ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(std::string{splits[0]}, cuda_memory_limit));
+        cuda_memory_limit = SafeInt<size_t>(cuda_memory_limit) * 1024;  // to bytes
+        map.insert_or_assign(kCudaExecutionProvider,
+                             std::make_unique<SizeTAccountant>(cuda_memory_limit,
+                                                               std::move(loaded_stats)));
+      } else {
+        map.insert_or_assign(kCudaExecutionProvider,
+                             std::make_unique<SizeTAccountant>(std::move(loaded_stats)));
+      }
+
+      acc_map = std::move(result);
+    } else {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid format for: ",
+                             kOrtSessionOptionsResourceCudaPartitioningSettings,
+                             " : expecting comma separated fields");
+    }
+  }
+
+  return Status::OK();
+}
+
+std::string IResourceAccountant::MakeUniqueNodeName(const Node& node) {
+  std::string result;
+
+  uint32_t hash[4] = {0, 0, 0, 0};
+  auto hash_str = [&hash](const std::string& str) {
+    MurmurHash3::x86_128(str.data(), narrow<int>(str.size()), hash[0], &hash);
+  };
+
+  const auto& node_name = (node.Name().empty()) ? node.OpType() : node.Name();
+
+  for (const auto& def : node.InputDefs()) {
+    hash_str(def->Name());
+  }
+
+  HashValue node_hash = hash[0] | (uint64_t(hash[1]) << 32);
+  result.reserve(node_name.size() + 1 + 16);
+  result.append(node_name).append("_").append(std::to_string(node_hash));
+
+  return result;
+}
+
+}  // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc
index 61fd9b08655b7..26a57ec3ea02f 100644
--- a/onnxruntime/core/framework/sequential_executor.cc
+++ b/onnxruntime/core/framework/sequential_executor.cc
@@ -11,6 +11,7 @@
 #include "core/common/logging/logging.h"
 #include "core/framework/allocation_planner.h"
 #include "core/framework/execution_frame.h"
+#include "core/framework/resource_accountant.h"
 #include "core/framework/stream_execution_context.h"
 #include "core/framework/session_state.h"
 #include "core/framework/op_kernel_context_internal.h"
@@ -104,7 +105,7 @@ static void CalculateTotalInputSizes(const OpKernelContextInternal* op_kernel_co
   const int input_count = op_kernel_context->InputCount();
   for (auto i = 0; i < input_count; i++) {
     const OrtValue* p_input = op_kernel_context->GetInputMLValue(i);
-    if (p_input != nullptr && p_input->IsTensor() && p_input->IsAllocated()) {
+    if (p_input != nullptr && p_input->IsAllocated() && p_input->IsTensor()) {
       const OpKernelInfo& op_kernel_info = p_op_kernel->Info();
       const Tensor* p_tensor = nullptr;
       bool is_param = op_kernel_info.TryGetConstantInput(i, &p_tensor);
@@ -256,6 +257,8 @@ class SessionScope {
   TimePoint session_start_;
 #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
   const ExecutionFrame& frame_;
+#endif
+#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
   // Whether memory profiler need create events and flush to file.
   // For partial graph run, when the last subgraph of the whole graph is executing, we need flush to file.
   bool flush_memory_info_ = true;
@@ -487,6 +490,65 @@ onnxruntime::Status ExecuteKernel(StreamExecutionContext& ctx,
     }
 #else
     status = p_kernel->Compute(&kernel_ctx);
+
+#if !defined(ORT_MINIMAL_BUILD)
+    auto* node_stats_recorder = ctx.GetSessionState().GetNodeStatsRecorder();
+    if (node_stats_recorder != nullptr) {
+      const auto& node = p_kernel->Node();
+      const OpKernelInfo& op_kernel_info = p_kernel->Info();
+      const auto input_defs = node.InputDefs();
+
+      // Let's first check whether any inputs are initializers;
+      // if so, we need to account for their memory usage.
+      SafeInt<size_t> initializers_size = 0;
+      SafeInt<size_t> input_sizes = 0;
+      for (int i = 0, lim = kernel_ctx.InputCount(); i < lim; ++i) {
+        // Need to get ort_value_index for each input.
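+        // ShouldAccountFor() de-duplicates by name within a single Run(), so a value
+        // consumed by several nodes is only counted toward the first node that sees it.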
+        const OrtValue* p_input = kernel_ctx.GetInputMLValue(i);
+        if (p_input != nullptr && p_input->IsAllocated() && p_input->IsTensor()) {
+          const auto& input_name = input_defs[i]->Name();
+          if (node_stats_recorder->ShouldAccountFor(input_name)) {
+            const Tensor* p_tensor = nullptr;
+            const bool is_constant = op_kernel_info.TryGetConstantInput(i, &p_tensor);
+            if (!is_constant) {
+              p_tensor = &p_input->Get<Tensor>();
+            }
+            input_sizes += p_tensor->SizeInBytes();
+          }
+        }
+      }
+
+      // Get outputs and see if anything was allocated dynamically
+      const auto output_defs = node.OutputDefs();
+      SafeInt<size_t> total_dynamic_sizes = 0;
+      const auto& exec_frame = ctx.GetExecutionFrame();
+      for (int i = 0, lim = kernel_ctx.OutputCount(); i < lim; ++i) {
+        const OrtValue* p_output = kernel_ctx.GetOutputMLValue(i);
+        if (p_output != nullptr && p_output->IsAllocated() && p_output->IsTensor()) {
+          int ort_value_index = kernel_ctx.GetOrtValueIndexForOutput(i);
+          auto maybe_val = exec_frame.GetOrtValueDynamicAllocation(ort_value_index);
+          if (maybe_val.has_value() && node_stats_recorder->ShouldAccountFor(output_defs[i]->Name())) {
+            total_dynamic_sizes += *maybe_val;
+          }
+        }
+      }
+
+      NodeAllocationStats node_stats;
+      node_stats.input_sizes = static_cast<size_t>(input_sizes);
+      node_stats.initializers_sizes = static_cast<size_t>(initializers_size);
+      node_stats.total_dynamic_sizes = total_dynamic_sizes;
+
+      // Get the temporary allocations
+      AllocatorStats temp_stats;
+      if (kernel_ctx.GetAllocatorStats(temp_stats)) {
+        node_stats.total_temp_allocations = narrow<size_t>(temp_stats.total_allocated_bytes);
+      }
+
+      // Record node allocation stats
+      const std::string name = IResourceAccountant::MakeUniqueNodeName(node);
+      node_stats_recorder->ReportNodeStats(name, node_stats);
+    }
+#endif
 #endif
   }
   ORT_CATCH(const std::exception& ex) {
@@ -510,6 +572,7 @@ onnxruntime::Status ExecuteKernel(StreamExecutionContext& ctx,
     LOGS(logger, ERROR) << msg_string;
     return Status(status.Category(), status.Code(), msg_string);
   }
+  ctx.RecycleNodeInputs(idx);
   VLOGS(logger, 0) << "stream " << stream_idx << " launch kernel with idx " << idx;
   return Status::OK();
 
diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h
index 82f520f4a4252..964c059e529f9 100644
--- a/onnxruntime/core/framework/session_state.h
+++ b/onnxruntime/core/framework/session_state.h
@@ -375,6 +375,24 @@ class SessionState {
   /// true or false
   bool GetSaveModeForPrepacks(bool saving_model, bool saving_ort_format);
 
+#if !defined(ORT_MINIMAL_BUILD)
+
+  void SetNodeStatsRecorder(NodeStatsRecorder* node_stats_recorder) {
+    node_stats_recorder_ = node_stats_recorder;
+  }
+
+  /**
+   * Returns a pointer to the NodeStatsRecorder object if it was enabled for the session.
+   * The object pointer is only present at the root SessionState object.
+   */
+  NodeStatsRecorder* GetNodeStatsRecorder() const {
+    if (parent_ != nullptr) {
+      return parent_->GetNodeStatsRecorder();
+    }
+    return node_stats_recorder_;
+  }
+#endif
+
  private:
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SessionState);
 
@@ -502,6 +520,10 @@ class SessionState {
   MemoryProfiler* memory_profiler_;
 #endif
 
+#if !defined(ORT_MINIMAL_BUILD)
+  NodeStatsRecorder* node_stats_recorder_ = nullptr;
+#endif
+
   // switch for enable memory pattern optimization or not.
   bool enable_mem_pattern_;
 
diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc
index 097ce436f4419..17c37b8882168 100644
--- a/onnxruntime/core/framework/tensorprotoutils.cc
+++ b/onnxruntime/core/framework/tensorprotoutils.cc
@@ -844,18 +844,9 @@ INSTANTIATE_UNPACK_TENSOR(UInt4x2)
     break;
 
 template <size_t alignment>
-common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out) {
-  const auto& dims = tensor_proto.dims();
-  size_t size = 1;
-  for (google::protobuf::int64 dim : dims) {
-    if (dim < 0 || static_cast<uint64_t>(dim) >= std::numeric_limits<size_t>::max()) {
-      return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid TensorProto");
-    }
-    if (!IAllocator::CalcMemSizeForArray(size, static_cast<size_t>(dim), &size)) {
-      return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid TensorProto");
-    }
-  }
-  switch (tensor_proto.data_type()) {
+common::Status GetSizeInBytesFromTensorShapeAndType(const TensorShape& shape, int32_t element_type, size_t* out) {
+  const auto size = narrow<size_t>(shape.Size());
+  switch (element_type) {
     CASE_PROTO_TRACE(FLOAT, float);
     CASE_PROTO_TRACE(DOUBLE, double);
     CASE_PROTO_TRACE(BOOL, bool);
@@ -884,24 +875,61 @@ common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto&
   return Status::OK();
 }
 
+template <size_t alignment>
+common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out) {
+  TensorShape tensor_shape = GetTensorShapeFromTensorProto(tensor_proto);
+
+  bool any_out_of_bounds = std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(),
+                                       [](int64_t dim) {
+                                         if (dim < 0 ||
+                                             static_cast<uint64_t>(dim) >= std::numeric_limits<size_t>::max()) {
+                                           return true;
+                                         }
+                                         return false;
+                                       });
+
+  ORT_RETURN_IF(any_out_of_bounds, "Out of bounds dimensions in TypeProto_Tensor");
+
+  return GetSizeInBytesFromTensorShapeAndType<alignment>(tensor_shape, tensor_proto.data_type(), out);
+}
+
+template <size_t alignment>
+common::Status GetSizeInBytesFromTensorTypeProto(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_proto, size_t* out) {
+  ORT_RETURN_IF_NOT(HasShape(tensor_proto), "TypeProto_Tensor does not have shape");
+  ORT_RETURN_IF_NOT(HasElemType(tensor_proto), "TypeProto_Tensor does not have element type");
+
+  TensorShape tensor_shape = GetTensorShapeFromTensorShapeProto(tensor_proto.shape());
+
+  bool any_out_of_bounds = std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(),
+                                       [](int64_t dim) {
+                                         return dim < 0 ||
+                                                static_cast<uint64_t>(dim) >= std::numeric_limits<size_t>::max();
+                                       });
+  ORT_RETURN_IF(any_out_of_bounds, "Out of bounds dimensions in TypeProto_Tensor");
+
+  return GetSizeInBytesFromTensorShapeAndType<alignment>(tensor_shape, tensor_proto.elem_type(), out);
+}
+
+template Status GetSizeInBytesFromTensorTypeProto<0>(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_proto, size_t* out);
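Editorial aside: a hedged sketch of how a caller might use the new TypeProto-based helper to pre-compute a value's size from its declared type and shape. The function name and control flow here are invented; the helper itself requires a fully static shape and returns an error for symbolic dimensions.

```cpp
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph.h"

// Sketch only: estimate the byte size of a NodeArg's tensor value, or 0 if unknown.
size_t EstimateNodeArgBytes(const onnxruntime::NodeArg& node_arg) {
  size_t bytes = 0;
  const auto* type_proto = node_arg.TypeAsProto();
  if (type_proto != nullptr && type_proto->has_tensor_type()) {
    // Alignment template argument 0 requests no extra alignment padding;
    // missing or symbolic dimensions produce a non-OK status.
    auto status = onnxruntime::utils::GetSizeInBytesFromTensorTypeProto<0>(
        type_proto->tensor_type(), &bytes);
    if (!status.IsOK()) {
      bytes = 0;  // size cannot be determined statically
    }
  }
  return bytes;
}
```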
 TensorShape GetTensorShapeFromTensorShapeProto(const ONNX_NAMESPACE::TensorShapeProto& tensor_shape_proto) {
   const auto& dims = tensor_shape_proto.dim();
-  std::vector<int64_t> tensor_shape_vec(static_cast<size_t>(dims.size()));
+  TensorShapeVector tensor_shape_vec(static_cast<size_t>(dims.size()));
   for (int i = 0; i < dims.size(); ++i) {
     tensor_shape_vec[i] = HasDimValue(dims[i]) ? dims[i].dim_value()
                                                : -1; /* symbolic dimensions are represented as -1 in onnxruntime*/
   }
-  return TensorShape(std::move(tensor_shape_vec));
+  return TensorShape(tensor_shape_vec);
 }
 
 TensorShape GetTensorShapeFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
   const auto& dims = tensor_proto.dims();
-  std::vector<int64_t> tensor_shape_vec(static_cast<size_t>(dims.size()));
+  TensorShapeVector tensor_shape_vec(static_cast<size_t>(dims.size()));
   for (int i = 0; i < dims.size(); ++i) {
     tensor_shape_vec[i] = dims[i];
   }
-  return TensorShape(std::move(tensor_shape_vec));
+  return TensorShape(tensor_shape_vec);
 }
 
 struct UnInitializeParam {
diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h
index 7b9a47842388c..f5dec7ae988f2 100644
--- a/onnxruntime/core/framework/tensorprotoutils.h
+++ b/onnxruntime/core/framework/tensorprotoutils.h
@@ -157,6 +157,9 @@ ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto
 template <size_t alignment>
 common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out);
 
+template <size_t alignment>
+Status GetSizeInBytesFromTensorTypeProto(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_proto, size_t* out);
+
 /** Special marker used to indicate an existing memory buffer contains the TensorProto external data.
     If the 'location' field of the external data info is set to this marker, the 'offset' field should contain the
diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.cc b/onnxruntime/core/providers/acl/acl_execution_provider.cc
index 8d34e36fe7cd6..ede476ff74d1b 100644
--- a/onnxruntime/core/providers/acl/acl_execution_provider.cc
+++ b/onnxruntime/core/providers/acl/acl_execution_provider.cc
@@ -152,7 +152,8 @@ std::shared_ptr<KernelRegistry> ACLExecutionProvider::GetKernelRegistry() const
 
 std::vector<std::unique_ptr<ComputeCapability>>
 ACLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
-                                    const IKernelLookup& kernel_lookup) const {
+                                    const IKernelLookup& kernel_lookup,
+                                    IResourceAccountant*) const {
   std::vector<std::unique_ptr<ComputeCapability>> result;
   for (const auto& node : graph.Nodes()) {
     if (const KernelCreateInfo* kernel_create_info = kernel_lookup.LookUpKernel(node);
diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.h b/onnxruntime/core/providers/acl/acl_execution_provider.h
index 1c267d8713673..d635e56add30b 100755
--- a/onnxruntime/core/providers/acl/acl_execution_provider.h
+++ b/onnxruntime/core/providers/acl/acl_execution_provider.h
@@ -38,7 +38,8 @@ class ACLExecutionProvider : public IExecutionProvider {
 
   std::vector<std::unique_ptr<ComputeCapability>> GetCapability(
       const onnxruntime::GraphViewer& graph,
-      const IKernelLookup& kernel_lookup) const override;
+      const IKernelLookup& kernel_lookup,
+      IResourceAccountant* resource_accountant) const override;
 
   Status OnRunStart(const onnxruntime::RunOptions&) override;
 
diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc
index f954baf3eabae..07e83933a890c 100644
--- a/onnxruntime/core/providers/cann/cann_execution_provider.cc
+++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc
@@ -1253,7 +1253,8 @@ GetSubGraphPartition(const std::vector<NodeIndex>& topological_order, const std:
 
 std::vector<std::unique_ptr<ComputeCapability>>
 CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
-                                     const IKernelLookup& kernel_lookup) const {
+                                     const IKernelLookup& kernel_lookup,
+                                     IResourceAccountant*) const {
   std::vector<std::unique_ptr<ComputeCapability>> result;
 
   // TODO(FFFrog): Feature Enhancement
diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h
index 7debfa72778fd..5ff935463a1c1 100644
--- a/onnxruntime/core/providers/cann/cann_execution_provider.h
+++ b/onnxruntime/core/providers/cann/cann_execution_provider.h
@@ -55,7 +55,8 @@ class CANNExecutionProvider : public IExecutionProvider {
 
   std::vector<std::unique_ptr<ComputeCapability>> GetCapability(
       const onnxruntime::GraphViewer& graph_viewer,
-      const IKernelLookup& kernel_lookup) const override;
+      const IKernelLookup& kernel_lookup,
+      IResourceAccountant* resource_accountant) const override;
 
   Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
                  std::vector<NodeComputeInfo>& node_compute_funcs) override;
diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
index b6bb4f2c1d66a..3fa3868267c9b 100644
--- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
+++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
@@ -38,7 +38,8 @@ CoreMLExecutionProvider::~CoreMLExecutionProvider() {}
 
 std::vector<std::unique_ptr<ComputeCapability>>
 CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
-                                       const IKernelLookup& /*kernel_lookup*/) const {
+                                       const IKernelLookup& /*kernel_lookup*/,
+                                       IResourceAccountant* /* resource_accountant */) const {
   std::vector<std::unique_ptr<ComputeCapability>> result;
   const auto& logger = *GetLogger();
 
diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h
index 650d81a4fecf7..0609bf6af726d 100644
--- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h
+++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h
@@ -19,7 +19,8 @@ class CoreMLExecutionProvider : public IExecutionProvider {
 
   std::vector<std::unique_ptr<ComputeCapability>>
   GetCapability(const onnxruntime::GraphViewer& graph_viewer,
-                const IKernelLookup& /*kernel_lookup*/) const override;
+                const IKernelLookup& /*kernel_lookup*/,
+                IResourceAccountant* resource_accountant) const override;
 
 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
   common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 4a10de153653c..b675c08e5f804 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -5,6 +5,7 @@
 #include "core/common/inlined_containers.h"
 #include "core/common/parse_string.h"
 #include "core/framework/int4.h"
+#include "core/framework/resource_accountant.h"
 #include "core/providers/shared_library/provider_api.h"
 #include "core/platform/env_var_utils.h"
 #include "core/providers/cuda/cuda_execution_provider.h"
@@ -2658,11 +2659,43 @@ std::unique_ptr<onnxruntime::IDataTransfer> CUDAExecutionProvider::GetDataTransf
 
 std::vector<std::unique_ptr<ComputeCapability>>
 CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
-                                     const IKernelLookup& kernel_lookup) const {
+                                     const IKernelLookup& kernel_lookup,
+                                     IResourceAccountant* resource_accountant) const {
+  std::vector<std::unique_ptr<ComputeCapability>> result;
+  const logging::Logger& logger = *GetLogger();
+
+  // Figure out the memory limit if accountant is available
+  size_t memory_threshold = std::numeric_limits<size_t>::max();
+  SafeInt<size_t> consumed_memory = 0;
+  if (resource_accountant != nullptr) {
+    if (resource_accountant->IsStopIssued()) {
+      LOGS(logger, WARNING) << "CUDA_EP returning due to Stop Set";
+      return result;
+    }
+
+    auto threshold = resource_accountant->GetThreshold();
+    if (!threshold.has_value()) {
+      // info_.gpu_mem_limit is for BFC arena
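+      // With no explicit threshold, use the smaller of the device's reported free
+      // memory and the configured arena limit.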
+      size_t free_memory, total_memory;
+      if (0 != cudaMemGetInfo(&free_memory, &total_memory)) {
+        memory_threshold = info_.gpu_mem_limit;
+      } else {
+        memory_threshold = std::min(free_memory, info_.gpu_mem_limit);
+      }
+    } else {
+      memory_threshold = std::get<0>(threshold.value());
+    }
+
+    consumed_memory = std::get<0>(resource_accountant->GetConsumedAmount());
+  }
+
+  InlinedHashSet<NodeIndex> previously_assigned_nodes;
+  // On repeated calls to this function, we may have most of the nodes already
+  // assigned to a CUDA EP capability. We'll skip accounting for these nodes.
+  previously_assigned_nodes.reserve(graph.NumberOfNodes());
 
   InlinedVector<NodeIndex> candidates;
   // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU.
   InlinedVector<NodeIndex> tentative_nodes;
-  const logging::Logger& logger = *GetLogger();
 
   for (auto& node_index : graph.GetNodesInTopologicalOrder()) {
     const auto* p_node = graph.GetNode(node_index);
     if (p_node == nullptr)
@@ -2672,6 +2705,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
     if (!node.GetExecutionProviderType().empty()) {
       if (node.GetExecutionProviderType() == kCudaExecutionProvider) {
         candidates.push_back(node.Index());
+        previously_assigned_nodes.insert(node.Index());
       }
       continue;
     }
@@ -2726,14 +2760,40 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
   // These are usually shape related computation subgraphs
   // Following logic can be extended for other EPs
   auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);
-  std::vector<std::unique_ptr<ComputeCapability>> result;
   for (auto& node_index : candidates) {
     if (cpu_nodes.count(node_index) > 0)
       continue;
 
-    auto sub_graph = IndexedSubGraph::Create();
-    sub_graph->Nodes().push_back(node_index);
-    result.push_back(ComputeCapability::Create(std::move(sub_graph)));
+    // Previously assigned nodes have been accounted before
+    if (previously_assigned_nodes.count(node_index) > 0 || resource_accountant == nullptr) {
+      auto sub_graph = IndexedSubGraph::Create();
+      sub_graph->Nodes().push_back(node_index);
+      result.push_back(ComputeCapability::Create(std::move(sub_graph)));
+    } else {
+      auto* node = graph.GetNode(node_index);
+      auto resource_count = std::get<0>(resource_accountant->ComputeResourceCount(*node));
+      const auto would_be_consumed = resource_count + consumed_memory;
+      LOGS(logger, INFO) << "CUDA_EP Node: " << node_index << " Memory usage : " << resource_count
+                         << " would be consumed " << static_cast<size_t>(would_be_consumed)
+                         << " threshold: " << memory_threshold;
+      if (would_be_consumed < memory_threshold) {
+        consumed_memory = would_be_consumed;
+        auto sub_graph = IndexedSubGraph::Create();
+        sub_graph->SetAccountant(resource_accountant);
+        sub_graph->Nodes().push_back(node_index);
+        sub_graph->AppendNodeCost(resource_count);
+        result.push_back(ComputeCapability::Create(std::move(sub_graph)));
+      } else {
+        // We break here so we do not have patches of CUDA assigned nodes.
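+        // Later, cheaper nodes might still fit under the budget, which would produce
+        // interleaved CUDA/CPU islands and extra device transfers; stopping here keeps
+        // the assigned partition contiguous.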
+        auto* node = graph.GetNode(node_index);
+        if (node != nullptr) {
+          LOGS(logger, WARNING) << "CUDA_EP Halting assignment due to capacity threshold at node: "
+                                << node->Name() << " index: " << node_index;
+        }
+        resource_accountant->SetStopAssignment();
+        break;
+      }
+    }
   }
 
   /*
   std::vector<std::unique_ptr<ComputeCapability>> result;
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h
index bd2be2eac2181..79a48e7cb89e1 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h
@@ -72,7 +72,8 @@ class CUDAExecutionProvider : public IExecutionProvider {
 
   std::vector<std::unique_ptr<ComputeCapability>> GetCapability(
      const onnxruntime::GraphViewer& graph,
-     const IKernelLookup& kernel_lookup) const override;
+     const IKernelLookup& kernel_lookup,
+     IResourceAccountant* resource_accountant) const override;
 
   int GetDeviceId() const override { return info_.device_id; }
   const cudaDeviceProp& GetDeviceProp() const { return device_prop_; };
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
index 826f48b5f7a68..9d23b8b950272 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
@@ -92,12 +92,13 @@ namespace Dml
     std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
     ExecutionProvider::GetCapability(
         const onnxruntime::GraphViewer& graph,
-        const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const
+        const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup,
+        onnxruntime::IResourceAccountant* resource_accountant) const
     {
 #ifdef ENABLE_GRAPH_COMPILATION
-        return m_impl->GetCapability(graph, kernel_lookup, *GetLogger());
+        return m_impl->GetCapability(graph, kernel_lookup, resource_accountant, *GetLogger());
 #else
-        return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup);
+        return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup, resource_accountant);
 #endif
     }
 
@@ -877,8 +878,7 @@ namespace Dml
     ExecutionProviderImpl::GetCapability(
         const onnxruntime::GraphViewer& graph,
         const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup,
-        const onnxruntime::logging::Logger& logger) const
-    {
+        onnxruntime::IResourceAccountant*, const onnxruntime::logging::Logger& logger) const {
         uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
std::vector> result; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index e7d859c5764de..7f420f8850001 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -11,6 +11,10 @@ #include #include +namespace onnxruntime { +class IResourceAccountant; +} + namespace WRL { template using Base = Microsoft::WRL::RuntimeClass< @@ -89,8 +93,8 @@ namespace Dml GetCapability( const onnxruntime::GraphViewer& graph, const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, - const onnxruntime::logging::Logger& logger - ) const; + onnxruntime::IResourceAccountant* resource_accountant, + const onnxruntime::logging::Logger& logger) const; uint32_t GetSupportedDeviceDataTypeMask() const; @@ -283,7 +287,8 @@ namespace Dml std::vector> GetCapability(const onnxruntime::GraphViewer& graph, - const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const final override; + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + onnxruntime::IResourceAccountant* resource_accountant) const final override; onnxruntime::common::Status OnSessionInitializationEnd() override { diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index c96f9cc1ff400..4da82b351f1d6 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -146,7 +146,8 @@ std::vector> DnnlExecutionProvider::GetSupportedNodes(con std::vector> DnnlExecutionProvider::GetCapability( const GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const { + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const { // follow from coreml ep's Getcapability std::vector> result; diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h index b7fcbb7765180..bde18e139f2a3 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h @@ -24,7 +24,8 @@ class DnnlExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const override; + const IKernelLookup& /*kernel_lookup*/, + onnxruntime::IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index d2d1d5e6fdd03..5a753d1ccf79a 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -790,7 +790,8 @@ std::vector JsExecutionProvider::CreatePreferredAllocators() { std::vector> JsExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup) const { + const IKernelLookup& kernel_lookup, + IResourceAccountant* /* resource_accountant */) const { InlinedVector candidates; // `tenative_candidates` is a subset of `candidates`. 
InlinedVector tenative_candidates; diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 966f9c6980212..4bead50fc782e 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -44,7 +44,8 @@ class JsExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const override; + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 95fbe7ab58ce2..1558d22137c05 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -992,7 +992,8 @@ GetPartitionedSubgraphs(const std::vector& topological_order, std::vector> MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const { + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const { std::vector> result; auto model = graph_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 91b6a4741b55e..d6af991f9b77e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -68,7 +68,8 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const override; + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index f92c9592742d5..27bd584e2d3c6 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -80,9 +80,10 @@ NnapiExecutionProvider::~NnapiExecutionProvider() {} std::vector> NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const { - const auto& logger = *GetLogger(); + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const { std::vector> result; + const logging::Logger& logger = *GetLogger(); // TODO: Task 812756: NNAPI EP, add support for subgraph (If and Loop operators) if (graph_viewer.IsSubgraph()) { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h index 460616c41991f..ebf9372eb668d 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h @@ -25,7 +25,8 
@@ class NnapiExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, - const IKernelLookup& /*kernel_lookup*/) const override; + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const override; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) common::Status Compile(const std::vector& fused_nodes, diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 22477611ce25b..67cb1910b441e 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -162,7 +162,8 @@ OpenVINOExecutionProvider::~OpenVINOExecutionProvider() { std::vector> OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const { + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const { std::vector> result; // Enable CI Logs diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index 75f4ef9f8ecc8..8cabfdb1b17f3 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -50,7 +50,8 @@ class OpenVINOExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const override; + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const override; Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 9e61583e4e0d2..1b3f7fe3c60e1 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -646,7 +646,8 @@ static void PartitionCtxModel(const onnxruntime::GraphViewer& graph_viewer, std::vector> QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const { + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const { std::vector> result; if (graph_viewer.IsSubgraph()) { diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 48f41c4da384f..2862e1b5f5661 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -30,7 +30,8 @@ class QNNExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, - const IKernelLookup& /*kernel_lookup*/) const override; + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const override; Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc index 44b34f4b4ce6c..10fd81786f977 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc +++ 
b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc @@ -50,7 +50,8 @@ std::vector> RknpuExecutionProvider::GetSupportedNodes( std::vector> RknpuExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const { + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const { // Find inputs, initializers and outputs for each supported subgraph std::vector> result; diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h index 1289c8569f8e8..ce16d63e111d9 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h @@ -19,7 +19,8 @@ class RknpuExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const override; + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 0a427b146dcaa..9d6e9df907ce3 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -2440,7 +2440,8 @@ std::unique_ptr ROCMExecutionProvider::GetDataTransf std::vector> ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup) const { + const IKernelLookup& kernel_lookup, + IResourceAccountant* /* resource_accountant */) const { InlinedVector candidates; // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU. 
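As the surrounding hunks show, the provider-side updates are mechanical: each GetCapability override gains the trailing IResourceAccountant* parameter and simply ignores it. An out-of-tree provider would adapt the same way. Below is an editorial sketch under that assumption; the MyExecutionProvider name and the trivial single-node partitioning are illustrative, not part of the patch.

// Editorial sketch: a hypothetical out-of-tree EP adopting the new signature.
// The accountant may be nullptr and can be ignored, as the unmodified in-tree
// providers in this patch do.
std::vector<std::unique_ptr<ComputeCapability>> MyExecutionProvider::GetCapability(
    const onnxruntime::GraphViewer& graph_viewer,
    const IKernelLookup& kernel_lookup,
    IResourceAccountant* /* resource_accountant */) const {
  std::vector<std::unique_ptr<ComputeCapability>> result;
  for (const auto node_index : graph_viewer.GetNodesInTopologicalOrder()) {
    const auto* node = graph_viewer.GetNode(node_index);
    // Claim only nodes for which this EP has a registered kernel.
    if (node == nullptr || kernel_lookup.LookUpKernel(*node) == nullptr) {
      continue;
    }
    auto sub_graph = IndexedSubGraph::Create();
    sub_graph->Nodes().push_back(node_index);
    result.push_back(ComputeCapability::Create(std::move(sub_graph)));
  }
  return result;
}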
InlinedVector tentative_nodes; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index be467869248ea..ff2bff7c98723 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -61,7 +61,8 @@ class ROCMExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup) const override; + const IKernelLookup& kernel_lookup, + IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const override { return info_.device_id; } const hipDeviceProp_t& GetDeviceProp() const { return device_prop_; }; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 4c050534456da..2dab9f6a402a0 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -331,8 +331,9 @@ bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, siz } std::vector> IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& kernel_lookup) const { - return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_lookup); + const IKernelLookup& kernel_lookup, + IResourceAccountant* resource_accountant) const { + return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_lookup, resource_accountant); } common::Status IExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index a1bb86598ebc0..0dd771f522336 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -9,8 +9,8 @@ // Public wrappers around internal ort interfaces (currently) #include "core/providers/shared_library/provider_host_api.h" - #include "core/common/inlined_containers_fwd.h" +#include "core/framework/resource_accountant.h" #include "core/providers/shared/common.h" #define PROVIDER_DISALLOW_ALL(TypeName) \ @@ -252,7 +252,8 @@ struct ProviderHost { // IExecutionProvider virtual std::vector> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, - const IExecutionProvider::IKernelLookup& kernel_lookup) = 0; + const IExecutionProvider::IKernelLookup& kernel_lookup, + IResourceAccountant* resource_accountant) = 0; virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) = 0; @@ -659,6 +660,7 @@ struct ProviderHost { virtual std::unique_ptr IndexedSubGraph__construct() = 0; virtual void IndexedSubGraph__operator_delete(IndexedSubGraph* p) = 0; + virtual const std::vector& IndexedSubGraph__Nodes(const IndexedSubGraph* p) = 0; virtual std::vector& IndexedSubGraph__Nodes(IndexedSubGraph* p) = 0; virtual void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr&& meta_def_) = 0; @@ -666,6 +668,8 @@ struct ProviderHost { virtual void IndexedSubGraph__SetSchemaSource(IndexedSubGraph* p, IndexedSubGraph_SourceOfSchema schema_source) = 0; virtual IndexedSubGraph_SourceOfSchema 
IndexedSubGraph__GetSchemaSource(const IndexedSubGraph* p) = 0; + virtual void IndexedSubGraph__SetAccountant(IndexedSubGraph* p, IResourceAccountant*) = 0; + virtual void IndexedSubGraph__AppendNodeCost(IndexedSubGraph* p, const ResourceCount& count) = 0; // KernelDef virtual void KernelDef__operator_delete(KernelDef* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 4feedd75f8004..a502ce9c66f69 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -581,6 +581,7 @@ struct IndexedSubGraph final { static std::unique_ptr Create() { return g_host->IndexedSubGraph__construct(); } static void operator delete(void* p) { g_host->IndexedSubGraph__operator_delete(reinterpret_cast(p)); } + gsl::span Nodes() const { return g_host->IndexedSubGraph__Nodes(this); } std::vector& Nodes() { return g_host->IndexedSubGraph__Nodes(this); } void SetMetaDef(std::unique_ptr&& meta_def_) { return g_host->IndexedSubGraph__SetMetaDef(this, std::move(*reinterpret_cast*>(&meta_def_))); } @@ -588,6 +589,12 @@ struct IndexedSubGraph final { void SetSchemaSource(IndexedSubGraph_SourceOfSchema schema_source) { return g_host->IndexedSubGraph__SetSchemaSource(this, schema_source); } IndexedSubGraph_SourceOfSchema GetSchemaSource() const { return g_host->IndexedSubGraph__GetSchemaSource(this); } + void SetAccountant(IResourceAccountant* resource_accountant) { + g_host->IndexedSubGraph__SetAccountant(this, resource_accountant); + } + void AppendNodeCost(const ResourceCount& resource_count) { + g_host->IndexedSubGraph__AppendNodeCost(this, resource_count); + } IndexedSubGraph() = delete; IndexedSubGraph(const IndexedSubGraph&) = delete; diff --git a/onnxruntime/core/providers/snpe/snpe_execution_provider.cc b/onnxruntime/core/providers/snpe/snpe_execution_provider.cc index fb9ce580ea2dc..c7fc6d3a556a7 100644 --- a/onnxruntime/core/providers/snpe/snpe_execution_provider.cc +++ b/onnxruntime/core/providers/snpe/snpe_execution_provider.cc @@ -71,7 +71,8 @@ SNPEExecutionProvider::~SNPEExecutionProvider() {} std::vector> SNPEExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup) const { + const IKernelLookup& kernel_lookup, + IResourceAccountant* /* resource_accountant */) const { std::vector candidates; for (auto& node_index : graph.GetNodesInTopologicalOrder()) { const auto* p_node = graph.GetNode(node_index); diff --git a/onnxruntime/core/providers/snpe/snpe_execution_provider.h b/onnxruntime/core/providers/snpe/snpe_execution_provider.h index c0a62eea11a25..99033649fcbbf 100644 --- a/onnxruntime/core/providers/snpe/snpe_execution_provider.h +++ b/onnxruntime/core/providers/snpe/snpe_execution_provider.h @@ -18,7 +18,8 @@ class SNPEExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup) const override; + const IKernelLookup& kernel_lookup, + IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; std::unordered_map GetRuntimeOptions() const { return runtime_options_; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index c583598bbcc52..0ee5cef7cbaa1 100644 --- 
a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2451,7 +2451,8 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& std::vector> TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const { + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const { // Construct subgraph capability from node list std::vector> result; // Get ModelPath diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index d3e0b0fba8891..92fdcbd3d950c 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -247,7 +247,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const override; + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const { return device_id_; } diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 3a99f56bb732a..5d2204b0b1979 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -51,7 +51,7 @@ const InlinedVector VitisAIExecutionProvider::GetEpContextNodes() c return ep_context_node_ptrs; } std::vector> VitisAIExecutionProvider::GetCapability( - const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup) const { + const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, IResourceAccountant* /* resource_accountant */) const { if (graph_viewer.IsSubgraph()) { // VITIS AI EP not support sungraph. Assigned to CPU. 
return {}; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index f0d1a289a2a73..5b031ab882839 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -28,7 +28,8 @@ class VitisAIExecutionProvider : public IExecutionProvider { ~VitisAIExecutionProvider() = default; std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const override; + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const { return 0; } common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) override; diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc index 7da7cc6cb63ba..4b9f6fae86423 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -61,8 +61,8 @@ VSINPUExecutionProvider::~VSINPUExecutionProvider() {} std::vector> VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const { - const auto& logger = *GetLogger(); + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const { std::vector> result; if (graph_viewer.IsSubgraph()) { diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h index c2605eb65faee..16cfbc8a9c581 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h @@ -39,7 +39,8 @@ class VSINPUExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& kernel_lookup) const override; + const IKernelLookup& kernel_lookup, + IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 3b7b2adcf70ad..87383fe197477 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -759,7 +759,8 @@ std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { std::vector> WebGpuExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup) const { + const IKernelLookup& kernel_lookup, + IResourceAccountant* /* resource_accountant */) const { InlinedVector candidates; // `tenative_candidates` is a subset of `candidates`. 
InlinedVector tenative_candidates; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 1e74c04fba108..7a0ade97aa3df 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -44,7 +44,8 @@ class WebGpuExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const override; + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 00fbb26b731f8..df95b653bd863 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -55,7 +55,8 @@ WebNNExecutionProvider::~WebNNExecutionProvider() {} std::vector> WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_registries*/) const { + const IKernelLookup& /*kernel_registries*/, + IResourceAccountant* /* resource_accountant */) const { // For subgraph which is the attribute of the control flow nodes, part of its initializers are stored in its // ancestor graphs as common initializers shared for other subgraphs. We need to collect all of them used for // identifying the required initializer names and storing into 'meta_def->constant_initializers'. diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h index 26c5e476bcc4f..e806dc340d53e 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -24,7 +24,8 @@ class WebNNExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_registries*/) const override; + const IKernelLookup& /*kernel_registries*/, + IResourceAccountant* /* resource_accountant */) const override; DataLayout GetPreferredLayout() const override { return preferred_layout_; } diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index ee4e7be0f1f49..641f8b0729d0a 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -257,7 +257,8 @@ static void AddComputeCapabilityForEachNodeInNodeUnit( std::vector> XnnpackExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const { + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const { const auto& logger = *GetLogger(); std::vector> capabilities; diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h index 395dc2f90070e..152bef1a1c52c 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h @@ -32,7 +32,8 @@ class XnnpackExecutionProvider : public IExecutionProvider { 
std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const override; + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index afd1e24bd4742..a1903898ea7f0 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1615,6 +1615,17 @@ common::Status InferenceSession::AddPrePackedWeightsContainer(PrepackedWeightsCo return Status::OK(); } +#if !defined(ORT_MINIMAL_BUILD) +Status onnxruntime::InferenceSession::CreateNodeStatsRecorder(const std::filesystem::path& node_stats_file) { + if (node_stats_recorder_.has_value()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The session already has an instance of NodeStatsRecorder"); + } + node_stats_recorder_.emplace(node_stats_file); + return Status::OK(); +} +#endif + namespace { Status PartitionOrtFormatModel(onnxruntime::Graph& graph, const ExecutionProviders& providers, @@ -1816,6 +1827,17 @@ common::Status InferenceSession::Initialize() { } } +#if !defined(ORT_MINIMAL_BUILD) + const std::string node_stats_file = session_options_.config_options.GetConfigOrDefault( + kOrtSessionOptionsCollectNodeMemoryStatsToFile, ""); + + if (!node_stats_file.empty()) { + ORT_RETURN_IF_ERROR_SESSIONID_(CreateNodeStatsRecorder(node_stats_file)); + } + + session_state_->SetNodeStatsRecorder(GetNodeStatsRecorder()); +#endif + #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) // Don't want to pollute SessionState constructor since memory profile is enabled optionally. session_state_->SetMemoryProfiler(&memory_profiler_); @@ -2747,6 +2769,14 @@ Status InferenceSession::Run(const RunOptions& run_options, TraceLoggingWriteStop(ortrun_activity, "OrtRun"); #endif +#if !defined(ORT_MINIMAL_BUILD) + if (IsNodeStatsCollectionEnabled() && retval.IsOK()) { + // Dump node stats if the run was successful + node_stats_recorder_->DumpStats(session_state_->GetGraphViewer().ModelPath()); + node_stats_recorder_->ResetPerRunNameDeduper(); + } +#endif + // As N+1 inference runs (N for memory allocation and 1 for graph capturing) // are needed before replaying the captured graph, here run N inference runs recursively until graph captured, // so that users just need one session run to capture the graph. diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index e28ff75345785..2c0c09dfd3e51 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -21,6 +21,7 @@ #include "core/framework/external_data_loader_manager.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/prepacked_weights_container.h" +#include "core/framework/resource_accountant.h" #include "core/framework/session_state.h" #include "core/framework/tuning_results.h" #include "core/framework/framework_provider_common.h" @@ -545,6 +546,31 @@ class InferenceSession { */ Status AddPrePackedWeightsContainer(PrepackedWeightsContainer* prepacked_weights_container); +#if !defined(ORT_MINIMAL_BUILD) + /** + * CreateNodeStats recorder and enable collection of node statistics that is useful + * for resource constrained partitioning and otherwise. 
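+ *
+ * Editorial note: the typical flow, as exercised by the tests in this change, is to
+ * run once with kOrtSessionOptionsCollectNodeMemoryStatsToFile set so that the stats
+ * file is produced, then pass "<limit>,<stats file>" through
+ * kOrtSessionOptionsResourceCudaPartitioningSettings so that CUDA node assignment
+ * stays within the limit.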
+ * + * @param node_stats_file - this file will be created at the same folder where the model file is present. + */ + Status CreateNodeStatsRecorder(const std::filesystem::path& node_stats_file); + + /** + * Returns true if collection is enabled + */ + bool IsNodeStatsCollectionEnabled() const noexcept { + return node_stats_recorder_.has_value(); + } + + /** + * NodeStatsRecorder pointer. If not present, returns nullptr + */ + NodeStatsRecorder* GetNodeStatsRecorder() noexcept { + return node_stats_recorder_.has_value() ? &*node_stats_recorder_ : nullptr; + } + +#endif + protected: #if !defined(ORT_MINIMAL_BUILD) @@ -911,6 +937,11 @@ class InferenceSession { }; CachedExecutionProviderForGraphReplay cached_execution_provider_for_graph_replay_; + +#if !defined(ORT_MINIMAL_BUILD) + // Enable nodestats collection + std::optional node_stats_recorder_; +#endif }; struct SessionIOBinding { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index f36345cdabf64..a1cd9af3b5091 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -359,8 +359,9 @@ struct ProviderHostImpl : ProviderHost { // IExecutionProvider (direct) std::vector> IExecutionProvider__GetCapability( const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, - const IExecutionProvider::IKernelLookup& kernel_lookup) override { - return p->IExecutionProvider::GetCapability(graph_viewer, kernel_lookup); + const IExecutionProvider::IKernelLookup& kernel_lookup, + IResourceAccountant* resource_accountant) override { + return p->IExecutionProvider::GetCapability(graph_viewer, kernel_lookup, resource_accountant); } common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override { @@ -827,6 +828,9 @@ struct ProviderHostImpl : ProviderHost { std::unique_ptr IndexedSubGraph__construct() override { return std::make_unique(); } void IndexedSubGraph__operator_delete(IndexedSubGraph* p) override { delete p; } + const std::vector& IndexedSubGraph__Nodes(const IndexedSubGraph* p) override { + return p->nodes; + } std::vector& IndexedSubGraph__Nodes(IndexedSubGraph* p) override { return p->nodes; } void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr&& meta_def_) override { p->SetMetaDef(std::move(meta_def_)); } @@ -834,6 +838,12 @@ struct ProviderHostImpl : ProviderHost { void IndexedSubGraph__SetSchemaSource(IndexedSubGraph* p, IndexedSubGraph_SourceOfSchema schema_source) override { p->schema_source = schema_source; } IndexedSubGraph_SourceOfSchema IndexedSubGraph__GetSchemaSource(const IndexedSubGraph* p) override { return p->schema_source; } + void IndexedSubGraph__SetAccountant(IndexedSubGraph* p, IResourceAccountant* resource_accountant) override { + p->SetAccountant(resource_accountant); + } + void IndexedSubGraph__AppendNodeCost(IndexedSubGraph* p, const ResourceCount& resource_count) override { + p->AppendNodeCost(resource_count); + } // KernelDef (wrapped) void KernelDef__operator_delete(KernelDef* p) override { delete p; } diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 740c566794f15..1b06eb55afbd2 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -10,6 +10,7 @@ #include #include #include +#include #include #include "core/common/denormal.h" @@ 
-59,7 +60,6 @@ #include "gtest/gtest.h" #include "gmock/gmock.h" -using namespace std; using namespace ONNX_NAMESPACE; using namespace onnxruntime::logging; using namespace onnxruntime::concurrency; @@ -137,7 +137,8 @@ class FuseExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const override { + const IKernelLookup& /*kernel_lookup*/, + IResourceAccountant* /* resource_accountant */) const override { // Fuse two add into one. std::vector> result; std::unique_ptr sub_graph = std::make_unique(); @@ -283,7 +284,7 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, ProviderType allocation_provider, IExecutionProvider* gpu_provider, OrtDevice* output_device) { - unique_ptr io_binding; + std::unique_ptr io_binding; Status st = session_object.NewIOBinding(&io_binding); ASSERT_TRUE(st.IsOK()); auto input_allocator = io_binding->GetCPUAllocator(bind_provider_type); @@ -358,7 +359,7 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, (output_device && output_device->Type() == OrtDevice::GPU)) { #if defined(USE_CUDA) || defined(USE_ROCM) // in this case we need to copy the tensor from cuda to cpu - vector& outputs = io_binding->GetOutputs(); + std::vector& outputs = io_binding->GetOutputs(); ASSERT_EQ(1u, outputs.size()); auto& rtensor = outputs.front().Get(); auto element_type = rtensor.DataType(); @@ -438,7 +439,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { // Load model with level 0 transform level // and assert that the model has Identity nodes. SessionOptions so; - const string test_model = "testdata/transform/abs-id-max.onnx"; + const std::string test_model = "testdata/transform/abs-id-max.onnx"; so.session_logid = "InferenceSessionTests.TestModelSerialization"; so.graph_optimization_level = TransformerLevel::Default; InferenceSessionWrapper session_object_noopt{so, GetEnvironment()}; @@ -478,9 +479,9 @@ TEST(InferenceSessionTests, TestModelSerialization) { // Assert that re-feed of optimized model with default transform level results // in same runtime model as abs-id-max.onnx with TransformLevel-1. 
- std::ifstream model_fs_session1(so.optimized_model_filepath, ios::in | ios::binary); + std::ifstream model_fs_session1(so.optimized_model_filepath, std::ios::in | std::ios::binary); ASSERT_TRUE(model_fs_session1.good()); - std::ifstream model_fs_session2(so_opt.optimized_model_filepath, ios::in | ios::binary); + std::ifstream model_fs_session2(so_opt.optimized_model_filepath, std::ios::in | std::ios::binary); ASSERT_TRUE(model_fs_session2.good()); ASSERT_TRUE(model_fs_session1.tellg() == model_fs_session2.tellg()); model_fs_session1.seekg(0, std::ifstream::beg); @@ -499,7 +500,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { #ifdef ORT_RUN_EXTERNAL_ONNX_TESTS static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) { if (f_arg.size() != s_arg.size()) { - cout << "Sizes differ: f_arg size: " << f_arg.size() << " s_arg size: " << s_arg.size() << endl; + std::cout << "Sizes differ: f_arg size: " << f_arg.size() << " s_arg size: " << s_arg.size() << std::endl; return false; } @@ -564,9 +565,9 @@ TEST(InferenceSessionTests, ModelMetadata) { } auto retval = session_object.GetModelInputs(); - cout << "weights size: " << weights.size() - << " inputs.size(): " << inputs.size() - << " from session: " << retval.second->size() << endl; + std::cout << "weights size: " << weights.size() + << " inputs.size(): " << inputs.size() + << " from session: " << retval.second->size() << std::endl; ASSERT_TRUE(retval.first.IsOK()); ASSERT_TRUE(Compare(inputs_no_weights, *retval.second)); } @@ -617,7 +618,7 @@ TEST(InferenceSessionTests, CheckRunLogger) { bool have_log_entry_with_run_tag = (std::find_if(msgs.begin(), msgs.end(), [&run_options](std::string msg) { - return msg.find(run_options.run_tag) != string::npos; + return msg.find(run_options.run_tag) != std::string::npos; }) != msgs.end()); ASSERT_TRUE(have_log_entry_with_run_tag); @@ -660,18 +661,18 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions) { auto size = lines.size(); ASSERT_TRUE(size > 1); - ASSERT_TRUE(lines[0].find("[") != string::npos); - ASSERT_TRUE(lines[1].find("model_loading_uri") != string::npos); - ASSERT_TRUE(lines[size - 1].find("]") != string::npos); + ASSERT_TRUE(lines[0].find("[") != std::string::npos); + ASSERT_TRUE(lines[1].find("model_loading_uri") != std::string::npos); + ASSERT_TRUE(lines[size - 1].find("]") != std::string::npos); std::vector tags = {"pid", "dur", "ts", "ph", "X", "name", "args"}; bool has_kernel_info = false; for (size_t i = 1; i < size - 1; ++i) { for (auto& s : tags) { - ASSERT_TRUE(lines[i].find(s) != string::npos); - has_kernel_info = has_kernel_info || lines[i].find("Kernel") != string::npos && - lines[i].find("stream") != string::npos && - lines[i].find("block_x") != string::npos; + ASSERT_TRUE(lines[i].find(s) != std::string::npos); + has_kernel_info = has_kernel_info || lines[i].find("Kernel") != std::string::npos && + lines[i].find("stream") != std::string::npos && + lines[i].find("block_x") != std::string::npos; } } @@ -717,25 +718,25 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { auto size = lines.size(); ASSERT_TRUE(size > 1); - ASSERT_TRUE(lines[0].find("[") != string::npos); - ASSERT_TRUE(lines[1].find("model_loading_uri") != string::npos); - ASSERT_TRUE(lines[size - 1].find("]") != string::npos); + ASSERT_TRUE(lines[0].find("[") != std::string::npos); + ASSERT_TRUE(lines[1].find("model_loading_uri") != std::string::npos); + ASSERT_TRUE(lines[size - 1].find("]") != std::string::npos); std::vector tags = {"pid", "dur", "ts", "ph", 
"X", "name", "args"}; [[maybe_unused]] bool has_api_info = false; for (size_t i = 1; i < size - 1; ++i) { for (auto& s : tags) { - ASSERT_TRUE(lines[i].find(s) != string::npos); + ASSERT_TRUE(lines[i].find(s) != std::string::npos); #ifdef USE_CUDA - has_api_info = has_api_info || lines[i].find("Api") != string::npos && - lines[i].find("cudaLaunch") != string::npos; + has_api_info = has_api_info || lines[i].find("Api") != std::string::npos && + lines[i].find("cudaLaunch") != std::string::npos; #endif #ifdef USE_ROCM - has_api_info = has_api_info || lines[i].find("Api") != string::npos && - lines[i].find("hipLaunch") != string::npos; + has_api_info = has_api_info || lines[i].find("Api") != std::string::npos && + lines[i].find("hipLaunch") != std::string::npos; #endif #ifdef USE_WEBGPU - has_api_info = has_api_info || lines[i].find("Api") != string::npos; + has_api_info = has_api_info || lines[i].find("Api") != std::string::npos; #endif } } @@ -769,17 +770,17 @@ TEST(InferenceSessionTests, CheckRunProfilerWithStartProfile) { int count = 0; while (std::getline(profile, line)) { if (count == 0) { - ASSERT_TRUE(line.find("[") != string::npos); + ASSERT_TRUE(line.find("[") != std::string::npos); } else if (count <= 3) { for (auto& s : tags) { - ASSERT_TRUE(line.find(s) != string::npos); + ASSERT_TRUE(line.find(s) != std::string::npos); } } else { - ASSERT_TRUE(line.find("]") != string::npos); + ASSERT_TRUE(line.find("]") != std::string::npos); } if (count == 1) { - ASSERT_TRUE(line.find("mul_1_kernel_time") != string::npos); + ASSERT_TRUE(line.find("mul_1_kernel_time") != std::string::npos); } count++; } @@ -929,7 +930,7 @@ TEST(InferenceSessionTests, ConfigureVerbosityLevel) { std::copy(msgs.begin(), msgs.end(), std::ostream_iterator(std::cout, "\n")); bool have_log_entry_with_vlog_session_msg = (std::find_if(msgs.begin(), msgs.end(), - [&](std::string msg) { return msg.find("Added input argument with name") != string::npos; }) != + [&](std::string msg) { return msg.find("Added input argument with name") != std::string::npos; }) != msgs.end()); ASSERT_TRUE(have_log_entry_with_vlog_session_msg); @@ -942,7 +943,8 @@ TEST(InferenceSessionTests, ConfigureVerbosityLevel) { // ASSERT_TRUE(have_log_entry_with_vlog_run_msg); bool has_num_streams_msg = - (std::find_if(msgs.begin(), msgs.end(), [&](std::string msg) { return msg.find("Number of streams") != string::npos; }) != msgs.end()); + (std::find_if(msgs.begin(), msgs.end(), [&](std::string msg) { return msg.find("Number of streams") != + std::string::npos; }) != msgs.end()); ASSERT_TRUE(has_num_streams_msg); #endif @@ -983,7 +985,7 @@ TEST(InferenceSessionTests, UseUserSpecifiedLoggingFunctionInSession) { #ifndef NDEBUG bool have_log_entry_with_vlog_session_msg = (std::find_if(log_msgs.begin(), log_msgs.end(), - [&](std::string msg) { return msg.find("Added input argument with name") != string::npos; }) != + [&](std::string msg) { return msg.find("Added input argument with name") != std::string::npos; }) != log_msgs.end()); ASSERT_TRUE(have_log_entry_with_vlog_session_msg); #endif @@ -996,7 +998,7 @@ TEST(InferenceSessionTests, TestWithIstream) { InferenceSession session_object{so, GetEnvironment()}; - std::ifstream model_file_stream(MODEL_URI, ios::in | ios::binary); + std::ifstream model_file_stream(MODEL_URI, std::ios::in | std::ios::binary); ASSERT_TRUE(model_file_stream.good()); ASSERT_TRUE(session_object.Load(model_file_stream).IsOK()); ASSERT_STATUS_OK(session_object.Initialize()); @@ -1015,7 +1017,7 @@ TEST(InferenceSessionTests, 
TestRegisterExecutionProvider) { CPUExecutionProviderInfo epi; ASSERT_TRUE(session_object.RegisterExecutionProvider(std::make_unique(epi)).IsOK()); - std::ifstream model_file_stream(MODEL_URI, ios::in | ios::binary); + std::ifstream model_file_stream(MODEL_URI, std::ios::in | std::ios::binary); ASSERT_TRUE(model_file_stream.good()); ASSERT_TRUE(session_object.Load(model_file_stream).IsOK()); ASSERT_STATUS_OK(session_object.Initialize()); @@ -1092,13 +1094,14 @@ TEST(InferenceSessionTests, TestIOBindingReuse) { std::stringstream sstr(s1); ASSERT_TRUE(session_object.Load(sstr).IsOK()); ASSERT_STATUS_OK(session_object.Initialize()); - unique_ptr io_binding; + std::unique_ptr io_binding; Status st = session_object.NewIOBinding(&io_binding); ASSERT_TRUE(st.IsOK()); OrtValue ml_value1; - vector v1{2.f}; - CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], {1}, v1, &ml_value1); + const std::vector v1{2.f}; + const int64_t shape[] = {1}; + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], shape, v1, &ml_value1); ASSERT_STATUS_OK(io_binding->BindOutput("foo", ml_value1)); ASSERT_TRUE(io_binding->GetOutputs().size() == 1); auto span = io_binding->GetOutputs()[0].Get().DataAsSpan(); @@ -1108,8 +1111,8 @@ TEST(InferenceSessionTests, TestIOBindingReuse) { } OrtValue ml_value2; - vector v2{3.f}; - CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], {1}, v2, &ml_value2); + const std::vector v2{3.f}; + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], shape, v2, &ml_value2); ASSERT_STATUS_OK(io_binding->BindOutput("foo", ml_value2)); ASSERT_TRUE(io_binding->GetOutputs().size() == 1); span = io_binding->GetOutputs()[0].Get().DataAsSpan(); @@ -1651,7 +1654,7 @@ TEST(InferenceSessionTests, Test3LayerNestedSubgraph) { run_options.run_tag = so.session_logid; std::vector dim = {1}; - std::vector va = {false}; + InlinedVector va = {false}; OrtValue ml_value_x; CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dim, va, &ml_value_x); @@ -1807,8 +1810,9 @@ TEST(InferenceSessionTests, Test2LayerNestedSubgraph) { OrtValue ml_value_input_0; CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dim_input_0, data_input_0, &ml_value_input_0); - std::vector dim_input_1 = {1}; - std::vector data_input_1 = {false}; + + const int64_t dim_input_1[] = {1}; + const bool data_input_1[] = {false}; OrtValue ml_value_input_1; CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dim_input_1, data_input_1, &ml_value_input_1); @@ -2047,7 +2051,7 @@ TEST(InferenceSessionTests, TestCopyToFromDevices) { // It creates and registers a dummy transformer and after session initialize // validates that this transformer was called regardless of the graph optimization level set. TEST(InferenceSessionTests, TestRegisterTransformers) { - string model_uri = "testdata/transform/fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; + std::string model_uri = "testdata/transform/fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; for (int i = static_cast(TransformerLevel::Default); i <= static_cast(TransformerLevel::MaxLevel); i++) { SessionOptions so; @@ -2126,7 +2130,7 @@ TEST(InferenceSessionTests, TestStrictShapeInference) { tester.AddInput("data", input_shape, input_data); tester.AddOutput("output", invalid_output_shape, output_data); - const std::unordered_set excluded_provider_types = { + const std::unordered_set excluded_provider_types = { kTensorrtExecutionProvider, // Doesn't handle Unsqueeze. 
kOpenVINOExecutionProvider}; // Disabled temporarily. @@ -2144,7 +2148,7 @@ TEST(InferenceSessionTests, TestStrictShapeInference) { #ifdef USE_CUDA // disable it, since we are going to enable parallel execution with cuda ep TEST(InferenceSessionTests, DISABLED_TestParallelExecutionWithCudaProvider) { - string model_uri = "testdata/transform/fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; + std::string model_uri = "testdata/transform/fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; SessionOptions so; so.execution_mode = ExecutionMode::ORT_PARALLEL; @@ -2822,10 +2826,10 @@ TEST(InferenceSessionTests, InitializerSharing_EnsureSessionsUseUserAddedInitial std::vector input_data_vec{1., 2., 3., 4., 5., 6.}; auto allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; - CreateMLValue(allocator, {3, 2}, input_data_vec, &val_to_share_from_allocator); + CreateMLValue(allocator, AsSpan({3, 2}), input_data_vec, &val_to_share_from_allocator); OrtMemoryInfo mem_info{CPU, OrtArenaAllocator}; - CreateMLValue(std::array{3, 2}, input_data_vec.data(), mem_info, &val_to_share); + CreateMLValue(AsSpan({3, 2}), input_data_vec.data(), mem_info, &val_to_share); // create sessions to share the allocator SessionOptions so1; diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index e7f8b1aaa49d8..b6b915f90d99a 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -22,6 +22,8 @@ #include "core/util/thread_utils.h" #include "gtest/gtest.h" #include "test/test_environment.h" +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/util/include/test_environment.h" #include "test/util/include/default_providers.h" #include "test/util/include/file_util.h" #include "core/optimizer/layout_transformation/layout_transformation.h" @@ -440,6 +442,123 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { } } +#ifdef USE_CUDA + +namespace { + +using PartitionVerifierFn = std::function; + +void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path, + const SessionOptions& sess_options, + const PartitionVerifierFn& verifier_fn) { + const auto& log_manager = DefaultLoggingManager(); + log_manager.SetDefaultLoggerSeverity(onnxruntime::logging::Severity::kVERBOSE); + const auto& default_logger = log_manager.DefaultLogger(); + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_path, model, nullptr, default_logger)); + + Graph& graph = model->MainGraph(); + ASSERT_STATUS_OK(graph.Resolve()); + + OrtThreadPoolParams to; + to.thread_pool_size = 1; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP); + + ExecutionProviders execution_providers; + auto tmp_cuda_execution_provider = DefaultCudaExecutionProvider(); + tmp_cuda_execution_provider->SetLogger(&default_logger); + ASSERT_STATUS_OK(execution_providers.Add(kCudaExecutionProvider, std::move(tmp_cuda_execution_provider))); + + KernelRegistryManager krm; + ASSERT_STATUS_OK(krm.RegisterKernels(execution_providers)); + + DataTransferManager dtm; + ExternalDataLoaderManager edlm; + profiling::Profiler profiler; + + SessionState session_state(model->MainGraph(), execution_providers, tp.get(), nullptr, dtm, edlm, + default_logger, profiler, sess_options); + + GraphPartitioner partitioner(krm, execution_providers); + layout_transformation::TransformLayoutFunction transform_layout_fn; + layout_transformation::DebugGraphFn debug_graph_fn; +
ASSERT_STATUS_OK( + partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, + sess_options.config_options, default_logger, + GraphPartitioner::Mode::kNormal, debug_graph_fn)); + + verifier_fn(graph); +} +} // namespace + +TEST(SessionStateTest, TestResourceAwarePartitioning_NoLimit) { + constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/transformers/tiny_gpt2_beamsearch.onnx"); + + // Try to load the model without restrictions + // and verify all nodes have been assigned to CUDA + SessionOptions sess_options; + sess_options.enable_mem_pattern = false; + sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL; + sess_options.use_deterministic_compute = false; + sess_options.enable_mem_reuse = false; + + LoadWithResourceAwarePartitioning(model_path, sess_options, [](const Graph& graph) { + const auto& graph_nodes = graph.Nodes(); + for (const auto& node : graph_nodes) { + EXPECT_EQ(node.GetExecutionProviderType(), kCudaExecutionProvider); + } + }); +} + +TEST(SessionStateTest, TestResourceAwarePartitioning_LargeLimit) { + constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/transformers/tiny_gpt2_beamsearch.onnx"); + constexpr const char* limit_setting = "10000,tiny_gpt2_beamsearch_node_stats.txt"; + + // Large limit, all nodes are still assigned + SessionOptions sess_options; + sess_options.enable_mem_pattern = false; + sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL; + sess_options.use_deterministic_compute = false; + sess_options.enable_mem_reuse = false; + ASSERT_STATUS_OK(sess_options.config_options.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, limit_setting)); + + LoadWithResourceAwarePartitioning(model_path, sess_options, [](const Graph& graph) { + const auto& graph_nodes = graph.Nodes(); + for (const auto& node : graph_nodes) { + EXPECT_EQ(node.GetExecutionProviderType(), kCudaExecutionProvider); + } + }); +} + +TEST(SessionStateTest, TestResourceAwarePartitioning_CPUOffloaded) { + constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/transformers/tiny_gpt2_beamsearch.onnx"); + constexpr const char* limit_setting = "5000,tiny_gpt2_beamsearch_node_stats.txt"; + + // Smaller limit, expect at least one node to be offloaded to CPU + SessionOptions sess_options; + sess_options.enable_mem_pattern = false; + sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL; + sess_options.use_deterministic_compute = false; + sess_options.enable_mem_reuse = false; + ASSERT_STATUS_OK(sess_options.config_options.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, limit_setting)); + + LoadWithResourceAwarePartitioning(model_path, sess_options, [](const Graph& graph) { + const auto& graph_nodes = graph.Nodes(); + bool cpu_node_found = false; + for (const auto& node : graph_nodes) { + if (node.GetExecutionProviderType() != kCudaExecutionProvider) { + cpu_node_found = true; + break; + } + } + EXPECT_TRUE(cpu_node_found); + }); +} + +#endif // USE_CUDA + INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStateTestP, testing::ValuesIn(param_list)); #ifndef ENABLE_TRAINING_CORE diff --git a/onnxruntime/test/framework/test_utils.h b/onnxruntime/test/framework/test_utils.h index 51b02ee3e7f8c..9c5893948ff1b 100644 --- a/onnxruntime/test/framework/test_utils.h +++ b/onnxruntime/test/framework/test_utils.h @@ -32,8 +32,13 @@ namespace test { IExecutionProvider* TestCPUExecutionProvider(); template +inline void CopyVectorToTensor(gsl::span value, Tensor& tensor) { + gsl::copy(value, tensor.MutableDataAsSpan()); +} + +template +inline
void CopyVectorToTensor(const std::vector& value, Tensor& tensor) { - gsl::copy(gsl::make_span(value), tensor.MutableDataAsSpan()); + gsl::copy(AsSpan(value), tensor.MutableDataAsSpan()); } // vector is specialized so we need to handle it separately @@ -45,8 +50,20 @@ inline void CopyVectorToTensor(const std::vector& value, Tensor& ten } } +template +void CreateMLValue(AllocatorPtr alloc, gsl::span dims, const std::vector& value, + OrtValue* p_mlvalue) { + TensorShape shape(dims); + auto element_type = DataTypeImpl::GetType(); + Tensor::InitOrtValue(element_type, shape, std::move(alloc), *p_mlvalue); + if (!value.empty()) { + Tensor& tensor = *p_mlvalue->GetMutable(); + CopyVectorToTensor(value, tensor); + } +} + template -void CreateMLValue(AllocatorPtr alloc, const std::vector& dims, const std::vector& value, +void CreateMLValue(AllocatorPtr alloc, gsl::span dims, gsl::span value, OrtValue* p_mlvalue) { TensorShape shape(dims); auto element_type = DataTypeImpl::GetType(); @@ -58,6 +75,24 @@ void CreateMLValue(AllocatorPtr alloc, const std::vector& dims, const s } } +template +void CreateMLValue(AllocatorPtr alloc, std::initializer_list dims, gsl::span value, + OrtValue* p_mlvalue) { + CreateMLValue(alloc, AsSpan(dims), value, p_mlvalue); +} + +template +void CreateMLValue(AllocatorPtr alloc, gsl::span dims, std::initializer_list value, + OrtValue* p_mlvalue) { + CreateMLValue(alloc, dims, AsSpan(value), p_mlvalue); +} + +template +void CreateMLValue(AllocatorPtr alloc, std::initializer_list dims, std::initializer_list value, + OrtValue* p_mlvalue) { + CreateMLValue(alloc, AsSpan(dims), AsSpan(value), p_mlvalue); +} + // Lifetime of data_buffer should be managed by the caller. template void CreateMLValue(gsl::span dims, T* data_buffer, const OrtMemoryInfo& info, @@ -68,7 +103,7 @@ void CreateMLValue(gsl::span dims, T* data_buffer, const OrtMemor } template -void AllocateMLValue(AllocatorPtr alloc, const std::vector& dims, OrtValue* p_mlvalue) { +void AllocateMLValue(AllocatorPtr alloc, gsl::span dims, OrtValue* p_mlvalue) { TensorShape shape(dims); auto element_type = DataTypeImpl::GetType(); Tensor::InitOrtValue(element_type, shape, std::move(alloc), *p_mlvalue); diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index 2e073def5d643..b753bc386d722 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -110,7 +110,8 @@ DataLayout InternalTestingExecutionProvider::GetPreferredLayout() const { std::vector> InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& kernel_lookup) const { + const IKernelLookup& kernel_lookup, + IResourceAccountant* /* resource_accountant */) const { // find nodes that have ops in our supported list std::unordered_set supported_static_nodes; std::unordered_set supported_compiled_nodes; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h index 6615eb82f2b05..d2ed8259ee974 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h @@ -19,7 +19,8 @@ class InternalTestingExecutionProvider : public 
diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc
index 2e073def5d643..b753bc386d722 100644
--- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc
+++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc
@@ -110,7 +110,8 @@ DataLayout InternalTestingExecutionProvider::GetPreferredLayout() const {
 
 std::vector<std::unique_ptr<ComputeCapability>>
 InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
-                                                const IKernelLookup& kernel_lookup) const {
+                                                const IKernelLookup& kernel_lookup,
+                                                IResourceAccountant* /* resource_accountant */) const {
   // find nodes that have ops in our supported list
   std::unordered_set<const Node*> supported_static_nodes;
   std::unordered_set<const Node*> supported_compiled_nodes;
diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h
index 6615eb82f2b05..d2ed8259ee974 100644
--- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h
+++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h
@@ -19,7 +19,8 @@ class InternalTestingExecutionProvider : public IExecutionProvider {
 
   std::vector<std::unique_ptr<ComputeCapability>>
   GetCapability(const onnxruntime::GraphViewer& graph_view,
-                const IKernelLookup& /*kernel_lookup*/) const override;
+                const IKernelLookup& /*kernel_lookup*/,
+                IResourceAccountant* /* resource_accountant */) const override;
 
   common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
                          std::vector<NodeComputeInfo>& node_compute_funcs) override;
diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc
index cb385267baf4e..e2deccc4fff0f 100644
--- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc
+++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc
@@ -281,7 +281,7 @@ static BackendSupport GetHTPSupport(const onnxruntime::logging::Logger& logger)
                                                 {{"backend_path", "QnnHtp.dll"}, {"offload_graph_io_quantization", "0"}});
   qnn_ep->SetLogger(&logger);
 
-  auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup);
+  auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, nullptr);
 
   return result.empty() ? BackendSupport::UNSUPPORTED : BackendSupport::SUPPORTED;
 }
@@ -344,7 +344,7 @@ static BackendSupport GetCPUSupport(const onnxruntime::logging::Logger& logger)
                                                 {{"backend_path", "QnnCpu.dll"}, {"offload_graph_io_quantization", "0"}});
   qnn_ep->SetLogger(&logger);
 
-  auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup);
+  auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, nullptr);
 
   return result.empty() ? BackendSupport::UNSUPPORTED : BackendSupport::SUPPORTED;
 }
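On the provider side the new trailing parameter is optional and may be ignored, as the internal testing EP and the QNN call sites above do. For an EP that does participate, the sketch below shows one way an override could consult the accountant; MyTestExecutionProvider is hypothetical, the node-by-node claiming is deliberately simplistic, and real providers would also handle fusion and threshold comparisons.

// Illustrative override only; not the behavior of any in-tree provider.
std::vector<std::unique_ptr<ComputeCapability>>
MyTestExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
                                       const IKernelLookup& kernel_lookup,
                                       IResourceAccountant* resource_accountant) const {
  std::vector<std::unique_ptr<ComputeCapability>> result;
  for (const auto& node : graph_viewer.Nodes()) {
    // Skip nodes we have no registered kernel for.
    if (kernel_lookup.LookUpKernel(node) == nullptr) {
      continue;
    }
    if (resource_accountant != nullptr) {
      // Stop claiming nodes once the partitioner has issued a stop.
      if (resource_accountant->IsStopIssued()) {
        break;
      }
      // Record what this node would consume if assigned to this EP.
      resource_accountant->AddConsumedAmount(resource_accountant->ComputeResourceCount(node));
    }
    auto sub_graph = std::make_unique<IndexedSubGraph>();
    sub_graph->nodes.push_back(node.Index());
    result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph)));
  }
  return result;
}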
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 2ea99151c2bfd..59920487a7248 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -40,6 +40,7 @@
 #endif
 
 #ifdef USE_CUDA
+#include "core/providers/cuda/cuda_provider_options.h"
 #include <cuda_runtime.h>
 #endif
@@ -4777,3 +4778,82 @@ TEST(CApiTest, OrtCustomOp_GetInPlace) {
   ASSERT_EQ(len, static_cast<size_t>(2));
   mock_gqa.ReleaseAliasMap(input_index, output_index);
 }
+
+#if !defined(ORT_MINIMAL_BUILD) && defined(USE_CUDA)
+
+TEST(CApiTest, GenerateNodeStatsFile) {
+  Ort::Env env(ORT_LOGGING_LEVEL_INFO);
+  constexpr const ORTCHAR_T* model_path = TSTR("testdata/transformers/tiny_gpt2_beamsearch.onnx");
+
+  Ort::SessionOptions session_options;
+  session_options.AddConfigEntry(kOrtSessionOptionsCollectNodeMemoryStatsToFile,
+                                 "tiny_gpt2_beamsearch_node_stats.txt");
+
+  OrtCUDAProviderOptionsV2 cuda_options;
+  cuda_options.use_tf32 = false;
+  session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
+
+  std::vector<int64_t> input_ids_shape{3, 12};
+  std::vector<int32_t> input_ids{
+      0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620,
+      41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572,
+      0, 0, 0, 52, 328, 219, 328, 206, 288, 227, 896, 328};
+
+  std::vector<int64_t> parameter_shape{1};
+  std::vector<int32_t> max_length{20};
+  std::vector<int32_t> min_length{1};
+  std::vector<int32_t> num_beams{4};
+  std::vector<int32_t> num_return_sequences{1};
+  std::vector<float> length_penalty{1.0f};
+  std::vector<float> repetition_penalty{1.0f};
+
+  std::vector<int64_t> expected_output_shape{input_ids_shape[0], num_return_sequences[0], max_length[0]};
+  std::vector<int32_t> expected_output{
+      0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620, 131, 131, 131, 181, 638, 638, 638, 638,
+      41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572, 292, 292, 292, 292, 292, 292, 292, 292,
+      0, 0, 0, 52, 328, 219, 328, 206, 288, 227, 896, 328, 328, 669, 669, 669, 669, 669, 669, 669};
+
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+  auto input_ids_tensor = Ort::Value::CreateTensor(
+      info, input_ids.data(), input_ids.size(), input_ids_shape.data(), input_ids_shape.size());
+
+  auto max_length_tensor = Ort::Value::CreateTensor(
+      info, max_length.data(), max_length.size(), parameter_shape.data(), parameter_shape.size());
+
+  auto min_length_tensor = Ort::Value::CreateTensor(
+      info, min_length.data(), min_length.size(), parameter_shape.data(), parameter_shape.size());
+
+  auto num_beams_tensor = Ort::Value::CreateTensor(
+      info, num_beams.data(), num_beams.size(), parameter_shape.data(), parameter_shape.size());
+
+  auto num_return_sequences_tensor = Ort::Value::CreateTensor(
+      info, num_return_sequences.data(), num_return_sequences.size(), parameter_shape.data(), parameter_shape.size());
+
+  auto length_penalty_tensor = Ort::Value::CreateTensor(
+      info, length_penalty.data(), length_penalty.size(), parameter_shape.data(), parameter_shape.size());
+
+  auto repetition_penalty_tensor = Ort::Value::CreateTensor(
+      info, repetition_penalty.data(), repetition_penalty.size(), parameter_shape.data(), parameter_shape.size());
+
+  std::vector<Ort::Value> ort_inputs;
+  ort_inputs.push_back(std::move(input_ids_tensor));
+  ort_inputs.push_back(std::move(max_length_tensor));
+  ort_inputs.push_back(std::move(min_length_tensor));
+  ort_inputs.push_back(std::move(num_beams_tensor));
+  ort_inputs.push_back(std::move(num_return_sequences_tensor));
+  ort_inputs.push_back(std::move(length_penalty_tensor));
+  ort_inputs.push_back(std::move(repetition_penalty_tensor));
+  const char* input_names[] = {"input_ids", "max_length", "min_length", "num_beams", "num_return_sequences",
+                               "length_penalty", "repetition_penalty"};
+  const char* const output_names[] = {"sequences"};
+
+  // The ONNX model is generated like the following:
+  // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2
+  //        --output tiny_gpt2_beamsearch_fp16.onnx --use_gpu --max_length 20
+  // (with separate_gpt2_decoder_for_init_run set to False as it is now set to True by default)
+  Ort::Session session(env, model_path, session_options);
+  session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(),
+              output_names, 1);
+}
+
+#endif
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/transformers/tiny_gpt2_beamsearch_node_stats.txt b/onnxruntime/test/testdata/transformers/tiny_gpt2_beamsearch_node_stats.txt
new file mode 100644
index 0000000000000..df1e0c48825a0
--- /dev/null
+++ b/onnxruntime/test/testdata/transformers/tiny_gpt2_beamsearch_node_stats.txt
@@ -0,0 +1,57 @@
+#name,input_sizes,initializers_sizes,total_dynamic_sizes,total_temp_allocations
+GptAttention_1_matmul_3390928670334833856,22528,0,0,0
+LayerNorm_8_16340230589392852003,18432,0,0,0
+LayerNorm_6_9539917679182944001,18432,0,0,0
+LayerNorm_4_3998281518089755446,18432,0,0,0
+Add_295_12458934867448263403,18432,0,0,0
+GptAttention_1_5945223373512700064,30720,0,36864,165888
+FastGelu_AddBias_0_8293496556664011978,512,0,73728,0
+FullyConnect_MatMul_7_9121431797220490115,90112,0,0,0
+GptAttention_0_7799922821510396356,13248,0,55296,165888
+GptAttention_2_13772881973491265914,30720,0,36864,165888
+LayerNorm_1_10060807585253518719,18432,0,0,0
+LayerNorm_5_12297409543002935527,18432,0,0,0
+Add_492_15870509848159592443,18432,0,0,0
+FullyConnect_MatMul_5_12754193998971094488,90112,0,0,0
+LayerNorm_7_11450735811828114024,18432,0,0,0
+FullyConnect_Add_5_4749853671277160818,18432,0,0,0
+GptAttention_3_add_5419272690383812111,18432,0,0,0
+FullyConnect_MatMul_8_14154070846330210236,34816,0,0,0
+FullyConnect_MatMul_9_9215108924175066058,90112,0,0,0
+GptAttention_2_add_7251589488810842639,18432,0,0,0
+FullyConnect_Add_7_2612800351421913827,18432,0,0,0
+GptAttention_1_add_3894862726029568115,18432,0,0,0
+FullyConnect_MatMul_2_4814122527985171273,34816,0,0,0
+LayerNorm_3_3589946186712403351,18432,0,0,0
+GptAttention_3_8921810316598002134,30720,0,36864,165888
+LayerNorm_9_9113032450990548295,18432,0,0,0
+Add_886_7198133075029541336,18432,0,0,0
+Add_689_16588197583517413999,18432,0,0,0
+GptAttention_3_matmul_14740826065423798917,22528,0,0,0
+FastGelu_AddBias_4_17289691003819959460,73728,0,0,0
+Add_754_3697562882104452642,18432,0,0,0
+FullyConnect_MatMul_4_3508821612885617837,34816,0,0,0
+FastGelu_AddBias_1_17699324882619485158,73728,0,0,0
+FullyConnect_MatMul_3_17781936527365066348,90112,0,0,0
+GptAttention_2_matmul_7328860221231123895,22528,0,0,0
+SkipLayerNormalization_6957325406340516852,18432,0,0,0
+BeamSearch_gpt2_3957842931497654942,24,0,256,1823244
+GptAttention_4_matmul_90143216136586800,22528,0,0,0
+FullyConnect_MatMul_6_11858231833228352542,34816,0,0,0
+GptAttention_0_matmul_16767551145055538728,4096,0,0,0
+FullyConnect_Add_3_17196504264676187520,18432,0,0,0
+GptAttention_0_add_9807374014361508564,18432,0,0,0
+FullyConnect_MatMul_1_17322107022932292417,16384,0,0,0
+GptAttention_4_14364416985904266109,30720,0,36864,165888
+FullyConnect_MatMul_0_3724322618026197588,34816,0,73728,0
+Add_557_10312911821132522354,18432,0,0,0
+Add_360_12940403527838064497,18432,0,0,0
+FastGelu_AddBias_3_13817144420946871274,73728,0,0,0
+EmbedLayerNormalization_0_7260843146120485633,194944,0,37120,0
+FastGelu_AddBias_2_7906787140370676932,73728,0,0,0
+MatMul_1165_4290064500958888402,146432,0,576000,0
+GptAttention_4_add_15131081400494402711,18432,0,0,0
+Add_1083_4580993573699232732,18432,0,0,0
+Add_951_2303460452509012571,18432,0,0,0
+LayerNorm_2_2575702077895349965,18432,0,0,0
+FullyConnect_Add_1_7648227151832366839,18432,0,0,0
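The stats file is a flat CSV: the header row names the four counters tracked by NodeAllocationStats, and each row is keyed by a unique node name (the node name combined with a hash; see MakeUniqueNodeName in the new resource_accountant.h). A minimal reader sketch follows; ParseNodeStatsFile is a hypothetical helper, not part of the ONNX Runtime API.

#include <fstream>
#include <sstream>
#include <string>
#include <unordered_map>

// Hypothetical reader for the CSV format dumped above:
// #name,input_sizes,initializers_sizes,total_dynamic_sizes,total_temp_allocations
std::unordered_map<std::string, NodeAllocationStats> ParseNodeStatsFile(const std::string& path) {
  std::unordered_map<std::string, NodeAllocationStats> stats;
  std::ifstream in(path);
  std::string line;
  while (std::getline(in, line)) {
    if (line.empty() || line[0] == '#') {
      continue;  // skip the header row
    }
    std::istringstream fields(line);
    std::string name, field;
    std::getline(fields, name, ',');
    NodeAllocationStats s;
    std::getline(fields, field, ',');
    s.input_sizes = std::stoull(field);
    std::getline(fields, field, ',');
    s.initializers_sizes = std::stoull(field);
    std::getline(fields, field, ',');
    s.total_dynamic_sizes = std::stoull(field);
    std::getline(fields, field, ',');
    s.total_temp_allocations = std::stoull(field);
    stats.emplace(std::move(name), s);
  }
  return stats;
}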