-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Add size threshold to prevent constant folding from inflating model memory footprint #28204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
80f9872
768d741
43b5cf0
38c9da7
b3a5439
dea8bad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,8 @@ | |
| #include "core/optimizer/utils.h" | ||
| #include "core/framework/op_kernel.h" | ||
| #include "core/framework/tensorprotoutils.h" | ||
| #include "core/session/onnxruntime_session_options_config_keys.h" | ||
| #include "core/common/parse_string.h" | ||
|
|
||
| using namespace onnxruntime::common; | ||
|
|
||
|
|
@@ -145,6 +147,18 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, | |
| GraphViewer graph_viewer(graph); | ||
| auto& order = graph_viewer.GetNodesInTopologicalOrder(); | ||
|
|
||
| // Read the optional size threshold for constant folding. A value of 0 (the default) means no limit. | ||
| size_t output_size_threshold = 0; | ||
| { | ||
| const std::string threshold_str = config_options_.GetConfigOrDefault( | ||
| kOrtSessionOptionsConfigConstantFoldingNodeWeightSizeThreshold, "0"); | ||
| if (!TryParseStringWithClassicLocale(threshold_str, output_size_threshold)) { | ||
| LOGS(logger, WARNING) << "Failed to parse constant folding size threshold from config value '" | ||
| << threshold_str << "'. Using no threshold."; | ||
| output_size_threshold = 0; | ||
| } | ||
| } | ||
|
|
||
| #if !defined(DISABLE_SPARSE_TENSORS) | ||
| std::function<bool(const std::string&)> is_sparse_initializer_check = [&graph](const std::string& name) -> bool { | ||
| return graph.IsSparseInitializer(name); | ||
|
|
@@ -336,6 +350,24 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, | |
| } | ||
| } | ||
|
|
||
| // If a size threshold was configured, check whether any output tensor exceeds it. | ||
| // Skipping large outputs prevents the optimized model from having a much larger | ||
| // memory footprint than the original model. | ||
| if (converted_to_constant && output_size_threshold > 0) { | ||
| for (const OrtValue& ort_value : fetches) { | ||
| if (ort_value.IsTensor()) { | ||
| const size_t tensor_size = ort_value.Get<Tensor>().SizeInBytes(); | ||
| if (tensor_size > output_size_threshold) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @copilot: if inputs are deleted, you need to take that into account and only consider the net increase.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done in dea8bad. The threshold check now computes the net memory increase.
||
| LOGS(logger, INFO) << "Skipping constant folding for " << node->OpType() | ||
| << " node '" << node->Name() << "': output size " << tensor_size | ||
| << " bytes exceeds the threshold of " << output_size_threshold << " bytes."; | ||
| converted_to_constant = false; | ||
| break; | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if (converted_to_constant) { | ||
| for (size_t fetch_idx = 0; fetch_idx < fetches.size(); ++fetch_idx) { | ||
| OrtValue& ort_value = fetches[fetch_idx]; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot you need to move the test before the computation happens and use the estimated size before computation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in 38c9da7. The check now runs before the computation: for each output
NodeArg, the element type and all concrete dim values are read from the TypeAsProto, and estimated_size = num_elements * elem_size is compared against the threshold before the OptimizerExecutionFrame is created or the kernel is run. If any dim is symbolic/unknown, the pre-check is skipped for that output (safe fallback).