diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 39d2562852eee..f6ac9ef81b60f 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -60,6 +60,15 @@ static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_qu static const char* const kOrtSessionOptionsDisableQDQConstantFolding = "session.disable_qdq_constant_folding"; +// Constant folding produces new initializers (folded outputs) that get added to the graph. +// This option limits the maximum size in bytes of any single constant folding output tensor. +// Nodes whose folded output(s) would exceed this limit are skipped to prevent the optimized +// model's memory footprint from growing too much compared to the original model. +// The value should be a non-negative integer in decimal string form. +// The default value of "0" disables the threshold check (all sizes are allowed). +static const char* const kOrtSessionOptionsConfigConstantFoldingNodeWeightSizeThreshold = + "session.constant_folding_node_weight_size_threshold"; + // It controls whether to enable Double QDQ remover and Identical Children Consolidation // "0": not to disable. ORT does remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs // "1": disable. 
ORT doesn't remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index cb6d65342bc54..d1b462b682374 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -11,6 +11,8 @@ #include "core/optimizer/utils.h" #include "core/framework/op_kernel.h" #include "core/framework/tensorprotoutils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/common/parse_string.h" using namespace onnxruntime::common; @@ -145,6 +147,18 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, GraphViewer graph_viewer(graph); auto& order = graph_viewer.GetNodesInTopologicalOrder(); + // Read the optional size threshold for constant folding. A value of 0 (the default) means no limit. + size_t output_size_threshold = 0; + { + const std::string threshold_str = config_options_.GetConfigOrDefault( + kOrtSessionOptionsConfigConstantFoldingNodeWeightSizeThreshold, "0"); + if (!TryParseStringWithClassicLocale(threshold_str, output_size_threshold)) { + LOGS(logger, WARNING) << "Failed to parse constant folding size threshold from config value '" + << threshold_str << "'. Using no threshold."; + output_size_threshold = 0; + } + } + #if !defined(DISABLE_SPARSE_TENSORS) std::function<bool(const std::string&)> is_sparse_initializer_check = [&graph](const std::string& name) -> bool { return graph.IsSparseInitializer(name); }; @@ -273,6 +287,89 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } + // If a size threshold is configured, estimate the net memory impact before computation. + // The net increase is the total estimated output size minus the sizes of constant inputs + // that will be freed (i.e., inputs exclusively consumed by this node). Only skip when + // the net increase exceeds the threshold.
+ if (output_size_threshold > 0) { + // Step 1: sum up the estimated output sizes. If any output has an unknown shape, + // we skip the threshold check entirely and proceed with the folding. + size_t total_estimated_output_size = 0; + bool all_output_sizes_known = true; + for (size_t output_idx : fetch_to_output_idx) { + const auto* node_out = node->OutputDefs()[output_idx]; + const auto* type_proto = node_out->TypeAsProto(); + if (type_proto == nullptr || !utils::HasTensorType(*type_proto)) { + all_output_sizes_known = false; + break; + } + const auto& tensor_type = type_proto->tensor_type(); + if (!utils::HasElemType(tensor_type) || !utils::HasShape(tensor_type)) { + all_output_sizes_known = false; + break; + } + const auto elem_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(tensor_type.elem_type()); + const size_t elem_size = utils::GetElementSizeOfTensor(elem_type); + if (elem_size == 0) { + all_output_sizes_known = false; + break; + } + const auto& shape = tensor_type.shape(); + size_t num_elements = 1; + bool all_dims_known = true; + for (const auto& dim : shape.dim()) { + if (!utils::HasDimValue(dim) || dim.dim_value() < 0) { + all_dims_known = false; + break; + } + num_elements *= static_cast<size_t>(dim.dim_value()); + } + if (!all_dims_known) { + all_output_sizes_known = false; + break; + } + total_estimated_output_size += num_elements * elem_size; + } + + if (all_output_sizes_known) { + // Step 2: compute the sizes of constant inputs that will be freed after folding. + // An input is freed only if this is its sole consumer.
+ size_t freed_input_size = 0; + for (const auto& [inp_name, tensor_proto] : constant_inputs) { + if (graph.GetConsumerNodes(inp_name).size() == 1) { + const size_t inp_elem_size = utils::GetElementSizeOfTensor( + static_cast<ONNX_NAMESPACE::TensorProto_DataType>(tensor_proto->data_type())); + if (inp_elem_size > 0) { + size_t num_inp_elements = 1; + bool valid = true; + for (int64_t d : tensor_proto->dims()) { + if (d < 0) { + valid = false; + break; + } + num_inp_elements *= static_cast<size_t>(d); + } + if (valid) { + freed_input_size += num_inp_elements * inp_elem_size; + } + } + } + } + + // The net memory increase is outputs added minus inputs freed (floor at 0). + const size_t net_increase = (total_estimated_output_size > freed_input_size) + ? total_estimated_output_size - freed_input_size + : 0; + if (net_increase > output_size_threshold) { + LOGS(logger, INFO) << "Skipping constant folding for " << node->OpType() + << " node '" << node->Name() + << "': estimated net memory increase " << net_increase + << " bytes exceeds the threshold of " << output_size_threshold << " bytes."; + continue; + } + } + } + + const bool node_on_cpu_ep = node->GetExecutionProviderType() == kCpuExecutionProvider; std::unique_ptr<OpKernel> kernel; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 2aae3383a1072..011c92bf4e17c 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -601,6 +601,135 @@ TEST_F(GraphTransformationTests, ConstantFolding) { ASSERT_TRUE(op_to_count["Unsqueeze"] == 0); } +// Test that constant folding respects the size threshold config option. +// The threshold guards against net memory *increases*: output_size - freed_input_size. +// Inputs are "freed" only when the node is their sole consumer.
+TEST_F(GraphTransformationTests, ConstantFoldingWithSizeThreshold) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; + + // Case 1: no threshold — all Unsqueeze nodes are folded. + { + std::shared_ptr<Model> model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + ASSERT_EQ(CountOpsInGraph(graph)["Unsqueeze"], 2); + + std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + const ConfigOptions empty_config_options; + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options), + TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + ASSERT_EQ(CountOpsInGraph(graph)["Unsqueeze"], 0); + } + + // Case 2: threshold of 1 byte — Unsqueeze nodes still fold because each input is + // exclusively consumed (freed after folding), making the net memory increase zero, + // which does not exceed the 1-byte threshold.
+ { + std::shared_ptr<Model> model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + ASSERT_EQ(CountOpsInGraph(graph)["Unsqueeze"], 2); + + std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ConfigOptions config_options_with_threshold; + ASSERT_STATUS_OK(config_options_with_threshold.AddConfigEntry( + kOrtSessionOptionsConfigConstantFoldingNodeWeightSizeThreshold, "1")); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, config_options_with_threshold), + TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + // Net increase = output_size (256 bytes) - freed_input_size (256 bytes) = 0, which <= 1. + ASSERT_EQ(CountOpsInGraph(graph)["Unsqueeze"], 0); + } + + // Case 3: build a Tile graph where the output is genuinely larger than the inputs. + // tile_input: float32[1] = {1.0} → 4 bytes (exclusively consumed) + // tile_repeats: int64[1] = {200} → 8 bytes (exclusively consumed) + // Tile output: float32[200] → 800 bytes + // Net increase = 800 - 4 - 8 = 788 bytes.
+ auto build_tile_graph = [](Graph& graph) { + // Add initializers + TensorProto tile_input_tp; + tile_input_tp.set_name("tile_input"); + tile_input_tp.add_dims(1); + tile_input_tp.add_float_data(1.0f); + tile_input_tp.set_data_type(TensorProto_DataType_FLOAT); + graph.AddInitializedTensor(tile_input_tp); + + TensorProto tile_repeats_tp; + tile_repeats_tp.set_name("tile_repeats"); + tile_repeats_tp.add_dims(1); + tile_repeats_tp.add_int64_data(200LL); + tile_repeats_tp.set_data_type(TensorProto_DataType_INT64); + graph.AddInitializedTensor(tile_repeats_tp); + + TypeProto float_1_type; + float_1_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + float_1_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + TypeProto int64_1_type; + int64_1_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64); + int64_1_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + TypeProto float_200_type; + float_200_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + float_200_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(200); + + auto& tile_input_arg = graph.GetOrCreateNodeArg("tile_input", &float_1_type); + auto& tile_repeats_arg = graph.GetOrCreateNodeArg("tile_repeats", &int64_1_type); + auto& tile_output_arg = graph.GetOrCreateNodeArg("tile_output", &float_200_type); + + graph.AddNode("tile", "Tile", "Tile node", {&tile_input_arg, &tile_repeats_arg}, {&tile_output_arg}); + ASSERT_STATUS_OK(graph.Resolve()); + }; + + // Case 3a: threshold below net increase (100 < 788) — Tile node must NOT be folded. 
+ { + Model model("ConstantFoldingWithSizeThreshold_Tile", + false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 13}}, {}, *logger_); + Graph& graph = model.MainGraph(); + build_tile_graph(graph); + ASSERT_EQ(CountOpsInGraph(graph)["Tile"], 1); + + std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ConfigOptions config_low_threshold; + ASSERT_STATUS_OK(config_low_threshold.AddConfigEntry( + kOrtSessionOptionsConfigConstantFoldingNodeWeightSizeThreshold, "100")); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, config_low_threshold), + TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + ASSERT_EQ(CountOpsInGraph(graph)["Tile"], 1); // not folded — net (788) > threshold (100) + } + + // Case 3b: threshold above net increase (1000 > 788) — Tile node SHOULD be folded.
+ { + Model model("ConstantFoldingWithSizeThreshold_Tile2", + false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 13}}, {}, *logger_); + Graph& graph = model.MainGraph(); + build_tile_graph(graph); + ASSERT_EQ(CountOpsInGraph(graph)["Tile"], 1); + + std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ConfigOptions config_high_threshold; + ASSERT_STATUS_OK(config_high_threshold.AddConfigEntry( + kOrtSessionOptionsConfigConstantFoldingNodeWeightSizeThreshold, "1000")); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, config_high_threshold), + TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + ASSERT_EQ(CountOpsInGraph(graph)["Tile"], 0); // folded — net (788) <= threshold (1000) + } +} + TEST_F(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; std::shared_ptr<Model> model;