Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_qu
static const char* const kOrtSessionOptionsDisableQDQConstantFolding =
"session.disable_qdq_constant_folding";

// Constant folding produces new initializers (folded outputs) that get added to the graph,
// which can grow the optimized model's memory footprint compared to the original model.
// This option bounds that growth per folded node: the estimated total size of the node's
// folded outputs, minus the size of constant inputs that folding frees (inputs whose sole
// consumer is the folded node), is compared against this threshold. Nodes whose estimated
// *net* memory increase exceeds the threshold are skipped. Outputs with unknown static
// shapes bypass the check and are folded unconditionally.
// The value should be a non-negative integer in decimal string form.
// The default value of "0" disables the threshold check (all sizes are allowed).
static const char* const kOrtSessionOptionsConfigConstantFoldingNodeWeightSizeThreshold =
"session.constant_folding_node_weight_size_threshold";

// It controls whether to enable Double QDQ remover and Identical Children Consolidation
// "0": not to disable. ORT does remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
// "1": disable. ORT doesn't remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
Expand Down
97 changes: 97 additions & 0 deletions onnxruntime/core/optimizer/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "core/optimizer/utils.h"
#include "core/framework/op_kernel.h"
#include "core/framework/tensorprotoutils.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/common/parse_string.h"

using namespace onnxruntime::common;

Expand Down Expand Up @@ -145,6 +147,18 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
GraphViewer graph_viewer(graph);
auto& order = graph_viewer.GetNodesInTopologicalOrder();

// Read the optional size threshold for constant folding. A value of 0 (the default) means no limit.
// The threshold arrives as a decimal string in the session config. A malformed value is treated
// as best-effort configuration, not an error: log a warning and fall back to "no limit" so that
// session creation is never blocked by a bad config entry.
size_t output_size_threshold = 0;
{
const std::string threshold_str = config_options_.GetConfigOrDefault(
kOrtSessionOptionsConfigConstantFoldingNodeWeightSizeThreshold, "0");
if (!TryParseStringWithClassicLocale(threshold_str, output_size_threshold)) {
LOGS(logger, WARNING) << "Failed to parse constant folding size threshold from config value '"
<< threshold_str << "'. Using no threshold.";
// A failed parse may leave the out-parameter partially written; reset it explicitly.
output_size_threshold = 0;
}
}

#if !defined(DISABLE_SPARSE_TENSORS)
std::function<bool(const std::string&)> is_sparse_initializer_check = [&graph](const std::string& name) -> bool {
return graph.IsSparseInitializer(name);
Expand Down Expand Up @@ -273,6 +287,89 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}

// If a size threshold is configured, estimate the net memory impact before computation.
// The net increase is the total estimated output size minus the sizes of constant inputs
// that will be freed (i.e., inputs exclusively consumed by this node). Only skip when
// the net increase exceeds the threshold.
if (output_size_threshold > 0) {
// Step 1: sum up the estimated output sizes. If any output has an unknown shape,
// we skip the threshold check entirely and proceed with the folding.
size_t total_estimated_output_size = 0;
bool all_output_sizes_known = true;
for (size_t output_idx : fetch_to_output_idx) {
const auto* node_out = node->OutputDefs()[output_idx];
const auto* type_proto = node_out->TypeAsProto();
// An output without a tensor type (e.g. sequence/map, or no type at all) cannot be sized.
if (type_proto == nullptr || !utils::HasTensorType(*type_proto)) {
all_output_sizes_known = false;
break;
}
const auto& tensor_type = type_proto->tensor_type();
if (!utils::HasElemType(tensor_type) || !utils::HasShape(tensor_type)) {
all_output_sizes_known = false;
break;
}
const auto elem_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(tensor_type.elem_type());
const size_t elem_size = utils::GetElementSizeOfTensor(elem_type);
// elem_size == 0 signals an element type whose size is unknown/unsupported.
if (elem_size == 0) {
all_output_sizes_known = false;
break;
}
const auto& shape = tensor_type.shape();
size_t num_elements = 1;
bool all_dims_known = true;
for (const auto& dim : shape.dim()) {
// Symbolic (dim_param) or negative dims mean the static size cannot be computed.
if (!utils::HasDimValue(dim) || dim.dim_value() < 0) {
all_dims_known = false;
break;
}
// NOTE(review): this product (and the * elem_size below) can wrap around for very large
// shapes since size_t arithmetic is modular; consider a checked multiply — TODO confirm
// whether such shapes are reachable here.
num_elements *= static_cast<size_t>(dim.dim_value());
}
if (!all_dims_known) {
all_output_sizes_known = false;
break;
}
total_estimated_output_size += num_elements * elem_size;
}

if (all_output_sizes_known) {
// Step 2: compute the sizes of constant inputs that will be freed after folding.
// An input is freed only if this is its sole consumer.
// NOTE(review): "freed" is inferred purely from the consumer count; an initializer that is
// also a graph output would not actually be released — TODO confirm this cannot occur here.
size_t freed_input_size = 0;
for (const auto& [inp_name, tensor_proto] : constant_inputs) {
if (graph.GetConsumerNodes(inp_name).size() == 1) {
const size_t inp_elem_size = utils::GetElementSizeOfTensor(
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(tensor_proto->data_type()));
if (inp_elem_size > 0) {
size_t num_inp_elements = 1;
bool valid = true;
for (int64_t d : tensor_proto->dims()) {
if (d < 0) {
valid = false;
break;
}
num_inp_elements *= static_cast<size_t>(d);
}
if (valid) {
freed_input_size += num_inp_elements * inp_elem_size;
}
}
}
}

// The net memory increase is outputs added minus inputs freed (floor at 0).
const size_t net_increase = (total_estimated_output_size > freed_input_size)
? total_estimated_output_size - freed_input_size
: 0;
if (net_increase > output_size_threshold) {
LOGS(logger, INFO) << "Skipping constant folding for " << node->OpType()
<< " node '" << node->Name()
<< "': estimated net memory increase " << net_increase
<< " bytes exceeds the threshold of " << output_size_threshold << " bytes.";
continue;
}
}
}

const bool node_on_cpu_ep = node->GetExecutionProviderType() == kCpuExecutionProvider;

std::unique_ptr<const OpKernel> kernel;
Expand Down
129 changes: 129 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,135 @@ TEST_F(GraphTransformationTests, ConstantFolding) {
ASSERT_TRUE(op_to_count["Unsqueeze"] == 0);
}

// Test that constant folding respects the size threshold config option.
// The threshold guards against net memory *increases*: output_size - freed_input_size.
// Inputs are "freed" only when the node is their sole consumer.
TEST_F(GraphTransformationTests, ConstantFoldingWithSizeThreshold) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx";

// Registers a ConstantFolding transformer configured with `config_options` and runs the
// Level1 transformers on `graph`. Factored out because all four cases below need the same
// CPU EP + GraphTransformerManager boilerplate; the lambda returns void so gtest ASSERT_*
// macros are usable inside it.
auto run_constant_folding = [this](Graph& graph, const ConfigOptions& config_options) {
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<ConstantFolding>(*e, false /*skip_dequantize_linear*/, config_options),
TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
};

// Case 1: no threshold — all Unsqueeze nodes are folded.
{
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
Graph& graph = model->MainGraph();
ASSERT_EQ(CountOpsInGraph(graph)["Unsqueeze"], 2);

const ConfigOptions empty_config_options;
run_constant_folding(graph, empty_config_options);
ASSERT_EQ(CountOpsInGraph(graph)["Unsqueeze"], 0);
}

// Case 2: threshold of 1 byte — Unsqueeze nodes still fold because each input is
// exclusively consumed (freed after folding), making the net memory increase zero,
// which does not exceed the 1-byte threshold.
{
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
Graph& graph = model->MainGraph();
ASSERT_EQ(CountOpsInGraph(graph)["Unsqueeze"], 2);

ConfigOptions config_options_with_threshold;
ASSERT_STATUS_OK(config_options_with_threshold.AddConfigEntry(
kOrtSessionOptionsConfigConstantFoldingNodeWeightSizeThreshold, "1"));
run_constant_folding(graph, config_options_with_threshold);
// Net increase = output_size (256 bytes) - freed_input_size (256 bytes) = 0, which <= 1.
ASSERT_EQ(CountOpsInGraph(graph)["Unsqueeze"], 0);
}

// Case 3: build a Tile graph where the output is genuinely larger than the inputs.
// tile_input: float32[1] = {1.0} → 4 bytes (exclusively consumed)
// tile_repeats: int64[1] = {200} → 8 bytes (exclusively consumed)
// Tile output: float32[200] → 800 bytes
// Net increase = 800 - 4 - 8 = 788 bytes.
auto build_tile_graph = [](Graph& graph) {
// Add initializers
TensorProto tile_input_tp;
tile_input_tp.set_name("tile_input");
tile_input_tp.add_dims(1);
tile_input_tp.add_float_data(1.0f);
tile_input_tp.set_data_type(TensorProto_DataType_FLOAT);
graph.AddInitializedTensor(tile_input_tp);

TensorProto tile_repeats_tp;
tile_repeats_tp.set_name("tile_repeats");
tile_repeats_tp.add_dims(1);
tile_repeats_tp.add_int64_data(200LL);
tile_repeats_tp.set_data_type(TensorProto_DataType_INT64);
graph.AddInitializedTensor(tile_repeats_tp);

TypeProto float_1_type;
float_1_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
float_1_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);

TypeProto int64_1_type;
int64_1_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
int64_1_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);

TypeProto float_200_type;
float_200_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
float_200_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(200);

auto& tile_input_arg = graph.GetOrCreateNodeArg("tile_input", &float_1_type);
auto& tile_repeats_arg = graph.GetOrCreateNodeArg("tile_repeats", &int64_1_type);
auto& tile_output_arg = graph.GetOrCreateNodeArg("tile_output", &float_200_type);

graph.AddNode("tile", "Tile", "Tile node", {&tile_input_arg, &tile_repeats_arg}, {&tile_output_arg});
ASSERT_STATUS_OK(graph.Resolve());
};

// Case 3a: threshold below net increase (100 < 788) — Tile node must NOT be folded.
{
Model model("ConstantFoldingWithSizeThreshold_Tile",
false, ModelMetaData(), PathString(),
IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 13}}, {}, *logger_);
Graph& graph = model.MainGraph();
build_tile_graph(graph);
ASSERT_EQ(CountOpsInGraph(graph)["Tile"], 1);

ConfigOptions config_low_threshold;
ASSERT_STATUS_OK(config_low_threshold.AddConfigEntry(
kOrtSessionOptionsConfigConstantFoldingNodeWeightSizeThreshold, "100"));
run_constant_folding(graph, config_low_threshold);
ASSERT_EQ(CountOpsInGraph(graph)["Tile"], 1);  // not folded — net (788) > threshold (100)
}

// Case 3b: threshold above net increase (1000 > 788) — Tile node SHOULD be folded.
{
Model model("ConstantFoldingWithSizeThreshold_Tile2",
false, ModelMetaData(), PathString(),
IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 13}}, {}, *logger_);
Graph& graph = model.MainGraph();
build_tile_graph(graph);
ASSERT_EQ(CountOpsInGraph(graph)["Tile"], 1);

ConfigOptions config_high_threshold;
ASSERT_STATUS_OK(config_high_threshold.AddConfigEntry(
kOrtSessionOptionsConfigConstantFoldingNodeWeightSizeThreshold, "1000"));
run_constant_folding(graph, config_high_threshold);
ASSERT_EQ(CountOpsInGraph(graph)["Tile"], 0);  // folded — net (788) <= threshold (1000)
}
}

TEST_F(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx";
std::shared_ptr<Model> model;
Expand Down
Loading