
handle mapping ReduceSum
chilo-ms committed Feb 11, 2025
1 parent 9235ae5 commit 413a159
Showing 4 changed files with 129 additions and 62 deletions.
1 change: 1 addition & 0 deletions onnxruntime/core/optimizer/constant_folding.cc
@@ -234,6 +234,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
auto data_type = input_def->TypeAsProto()->tensor_type().elem_type();
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
data_type == ONNX_NAMESPACE::TensorProto_DataType_INT16 ||
data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8 ||
data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
can_constant_fold_dq_node = true;
}
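For context, DequantizeLinear on these integer types is pure elementwise arithmetic, so folding it is safe once its input is a constant initializer. A minimal sketch of the uint8 case this line enables, assuming per-tensor scale and zero-point (the function name and types here are illustrative, not ORT internals):

#include <cstdint>
#include <vector>

// ONNX DequantizeLinear for per-tensor uint8 quantization:
// y[i] = (x[i] - zero_point) * scale
std::vector<float> DequantizeUint8(const std::vector<uint8_t>& x,
                                   float scale, uint8_t zero_point) {
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = (static_cast<float>(x[i]) - static_cast<float>(zero_point)) * scale;
  }
  return y;
}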
3 changes: 3 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -60,6 +60,7 @@
#include "core/optimizer/not_where_fusion.h"
#include "core/optimizer/pad_fusion.h"
#include "core/optimizer/pre_shape_node_elimination.h"
#include "core/optimizer/map_to_four_dimension.h"
#ifdef MLAS_TARGET_AMD64_IX86
#include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h"
#endif
@@ -265,6 +266,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
// run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around.
// shouldn't affect the end result - just easier to debug any issue if it's last.
transformers.emplace_back(std::make_unique<TransposeOptimizer>(std::move(cpu_allocator)));

transformers.emplace_back(std::make_unique<MapToFourDimensions>());
} break;

case TransformerLevel::Level2: {
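Because the transformer is registered under TransformerLevel::Level1, it runs whenever basic graph optimizations are enabled. A minimal sketch of a session that would exercise it, using the public C++ API (ORT_ENABLE_BASIC corresponds to Level1; the model path is illustrative, and on Windows it is a wide string):

#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "demo");
  Ort::SessionOptions session_options;
  // Basic (Level1) optimizations include the MapToFourDimensions pass added above.
  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
  Ort::Session session(env, "model.onnx", session_options);
  return 0;
}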
164 changes: 124 additions & 40 deletions onnxruntime/core/optimizer/map_to_four_dimension.cc
@@ -8,80 +8,155 @@
#include "core/optimizer/utils.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/optimizer_execution_frame.h"
#include "core/optimizer/utils.h"
#include "core/framework/op_kernel.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/data_types.h"
#include "core/session/onnxruntime_c_api.h"

using namespace onnxruntime::common;

namespace onnxruntime {

MapToFourDimensions::MapToFourDimensions(const IExecutionProvider& execution_provider,
bool skip_dequantize_linear,
bool dequantize_initializer_for_dequantize_linear,
const ConfigOptions& config_options,
const InlinedHashSet<std::string_view>& compatible_execution_providers,
const InlinedHashSet<std::string>& excluded_initializers) noexcept
: GraphTransformer("MapToFourDimensions", compatible_execution_providers),
skip_dequantize_linear_(skip_dequantize_linear),
dequantize_initializer_for_dequantize_linear_(dequantize_initializer_for_dequantize_linear),
config_options_(config_options),
excluded_initializers_(excluded_initializers),
execution_provider_(execution_provider) {
MapToFourDimensions::MapToFourDimensions() noexcept
: GraphTransformer("MapToFourDimensions") {
}

onnxruntime::NodeArg* AddSliceReduceConcatNodes(onnxruntime::Graph& graph,
onnxruntime::Node& reshape,
onnxruntime::Node& reduce_sum,
onnxruntime::NodeArg* old_arg,
ONNX_NAMESPACE::TypeProto* new_type,
bool new_on_input,
int64_t to_type,
onnxruntime::ProviderType providerType) {
// Insert 2 Slice nodes, 2 ReduceSum nodes and 1 Concat node.

// Create 2 Slice nodes
/**
* Replace Reshape node and ReduceSum node with
* two Slice nodes, two ReduceSum nodes and one Concat node.
*/
Status AddSliceReduceConcatNodes(Graph& graph, Node& reshape, Node& reduce_sum) {
// Create Slice node names
std::string slice_node_0_name = graph.GenerateNodeName(reshape.Name() + "_slice_0");
std::string slice_node_1_name = graph.GenerateNodeName(reshape.Name() + "_slice_1");

// The Slice node type should be the same as the Reshape node (going to be removed) type.
// The Slice node's output type should be the same as the Reshape node's output (going to be removed) type.
auto* slice_node_0_arg = &graph.GetOrCreateNodeArg(slice_node_0_name, reshape.OutputDefs()[0]->TypeAsProto());
auto* slice_node_1_arg = &graph.GetOrCreateNodeArg(slice_node_1_name, reshape.OutputDefs()[0]->TypeAsProto());


// Create inputs as OrtValue for Slice nodes, i.e. the "starts", "ends" and "axes" inputs.
// The inputs will become initializers.
OrtMemoryInfo* mem_info = nullptr;
const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
auto status = ort_api->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info);
std::unique_ptr<OrtMemoryInfo, decltype(ort_api->ReleaseMemoryInfo)> rel_info(mem_info, ort_api->ReleaseMemoryInfo);

const int input_data_cnt = 1;
int64_t slice_0_input_data_1[input_data_cnt] = {0};
int64_t slice_0_input_data_2[input_data_cnt] = {4};
int64_t slice_0_input_data_3[input_data_cnt] = {3};
int64_t slice_1_input_data_1[input_data_cnt] = {4};
int64_t slice_1_input_data_2[input_data_cnt] = {8};
int64_t slice_1_input_data_3[input_data_cnt] = {3};
const size_t input_len = input_data_cnt * sizeof(int64_t);
const int64_t input_shape[] = {1};
const size_t shape_len = sizeof(input_shape) / sizeof(input_shape[0]);

OrtValue* slice_0_input_1 = nullptr;
OrtValue* slice_0_input_2 = nullptr;
OrtValue* slice_0_input_3 = nullptr;
OrtValue* slice_1_input_1 = nullptr;
OrtValue* slice_1_input_2 = nullptr;
OrtValue* slice_1_input_3 = nullptr;
ort_api->CreateTensorWithDataAsOrtValue(mem_info, slice_0_input_data_1, input_len, input_shape, shape_len, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, &slice_0_input_1);
ort_api->CreateTensorWithDataAsOrtValue(mem_info, slice_0_input_data_2, input_len, input_shape, shape_len, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, &slice_0_input_2);
ort_api->CreateTensorWithDataAsOrtValue(mem_info, slice_0_input_data_3, input_len, input_shape, shape_len, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, &slice_0_input_3);
ort_api->CreateTensorWithDataAsOrtValue(mem_info, slice_1_input_data_1, input_len, input_shape, shape_len, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, &slice_1_input_1);
ort_api->CreateTensorWithDataAsOrtValue(mem_info, slice_1_input_data_2, input_len, input_shape, shape_len, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, &slice_1_input_2);
ort_api->CreateTensorWithDataAsOrtValue(mem_info, slice_1_input_data_3, input_len, input_shape, shape_len, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, &slice_1_input_3);

const Tensor& slice_0_input_tensor_1 = slice_0_input_1->Get<Tensor>();
const Tensor& slice_0_input_tensor_2 = slice_0_input_2->Get<Tensor>();
const Tensor& slice_0_input_tensor_3 = slice_0_input_3->Get<Tensor>();
const Tensor& slice_1_input_tensor_1 = slice_1_input_1->Get<Tensor>();
const Tensor& slice_1_input_tensor_2 = slice_1_input_2->Get<Tensor>();
const Tensor& slice_1_input_tensor_3 = slice_1_input_3->Get<Tensor>();

ONNX_NAMESPACE::TensorProto slice_0_tensorproto_1 = utils::TensorToTensorProto(slice_0_input_tensor_1, slice_node_0_name + "_starts");
ONNX_NAMESPACE::TensorProto slice_0_tensorproto_2 = utils::TensorToTensorProto(slice_0_input_tensor_2, slice_node_0_name + "_ends");
ONNX_NAMESPACE::TensorProto slice_0_tensorproto_3 = utils::TensorToTensorProto(slice_0_input_tensor_3, slice_node_0_name + "_axes");
ONNX_NAMESPACE::TensorProto slice_1_tensorproto_1 = utils::TensorToTensorProto(slice_1_input_tensor_1, slice_node_1_name + "_starts");
ONNX_NAMESPACE::TensorProto slice_1_tensorproto_2 = utils::TensorToTensorProto(slice_1_input_tensor_2, slice_node_1_name + "_ends");
ONNX_NAMESPACE::TensorProto slice_1_tensorproto_3 = utils::TensorToTensorProto(slice_1_input_tensor_3, slice_node_1_name + "_axes");
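// Note on lifetimes: CreateTensorWithDataAsOrtValue wraps the stack-allocated
// int64_t buffers above without copying, but utils::TensorToTensorProto then
// copies the payload into each TensorProto, so the initializers added to the
// graph below outlive this function's locals.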

ONNX_NAMESPACE::TypeProto t;
t.mutable_tensor_type()->set_elem_type(slice_0_tensorproto_1.data_type());
auto* slice_node_0_arg_1 = &graph.GetOrCreateNodeArg(slice_node_0_name + "_starts", &t);
auto* slice_node_0_arg_2 = &graph.GetOrCreateNodeArg(slice_node_0_name + "_ends", &t);
auto* slice_node_0_arg_3 = &graph.GetOrCreateNodeArg(slice_node_0_name + "_axes", &t);
auto* slice_node_1_arg_1 = &graph.GetOrCreateNodeArg(slice_node_1_name + "_starts", &t);
auto* slice_node_1_arg_2 = &graph.GetOrCreateNodeArg(slice_node_1_name + "_ends", &t);
auto* slice_node_1_arg_3 = &graph.GetOrCreateNodeArg(slice_node_1_name + "_axes", &t);

graph.AddInitializedTensor(slice_0_tensorproto_1);
graph.AddInitializedTensor(slice_0_tensorproto_2);
graph.AddInitializedTensor(slice_0_tensorproto_3);
graph.AddInitializedTensor(slice_1_tensorproto_1);
graph.AddInitializedTensor(slice_1_tensorproto_2);
graph.AddInitializedTensor(slice_1_tensorproto_3);

std::vector<onnxruntime::NodeArg*> slice_node_0_input_defs = {reshape.MutableInputDefs()[0], slice_node_0_arg_1, slice_node_0_arg_2, slice_node_0_arg_3};
std::vector<onnxruntime::NodeArg*> slice_node_1_input_defs = {reshape.MutableInputDefs()[0], slice_node_1_arg_1, slice_node_1_arg_2, slice_node_1_arg_3};
std::vector<onnxruntime::NodeArg*> slice_node_0_output_defs = {slice_node_0_arg};
std::vector<onnxruntime::NodeArg*> slice_node_1_output_defs = {slice_node_1_arg};

// Create 2 Slice nodes
auto& slice_node_0 = graph.AddNode(slice_node_0_name, "Slice", "Map 5D/6D to 4D",
reshape.MutableInputDefs(), slice_node_0_output_defs);
slice_node_0_input_defs, slice_node_0_output_defs);
auto& slice_node_1 = graph.AddNode(slice_node_1_name, "Slice", "Map 5D/6D to 4D",
reshape.MutableInputDefs(), slice_node_1_output_defs);
slice_node_1_input_defs, slice_node_1_output_defs);

// Create 2 ReduceSum nodes
std::string reduce_sum_node_0_name = graph.GenerateNodeName(reduce_sum.Name() + "_0");
std::string reduce_sum_node_1_name = graph.GenerateNodeName(reduce_sum.Name() + "_1");

// The ReduceSum node type should be the same as the original ReduceSum node (going to be removed) type.
int64_t reduce_sum_0_input_data_1[input_data_cnt] = {3};
int64_t reduce_sum_1_input_data_1[input_data_cnt] = {3};

OrtValue* reduce_sum_0_input_1 = nullptr;
OrtValue* reduce_sum_1_input_1 = nullptr;
ort_api->CreateTensorWithDataAsOrtValue(mem_info, reduce_sum_0_input_data_1, input_len, input_shape, shape_len, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, &reduce_sum_0_input_1);
ort_api->CreateTensorWithDataAsOrtValue(mem_info, reduce_sum_1_input_data_1, input_len, input_shape, shape_len, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, &reduce_sum_1_input_1);

const Tensor& reduce_sum_0_input_tensor_1 = reduce_sum_0_input_1->Get<Tensor>();
const Tensor& reduce_sum_1_input_tensor_1 = reduce_sum_1_input_1->Get<Tensor>();

ONNX_NAMESPACE::TensorProto reduce_sum_0_tensorproto_1 = utils::TensorToTensorProto(reduce_sum_0_input_tensor_1, reduce_sum_node_0_name + "_axes");
ONNX_NAMESPACE::TensorProto reduce_sum_1_tensorproto_1 = utils::TensorToTensorProto(reduce_sum_1_input_tensor_1, reduce_sum_node_1_name + "_axes");

ONNX_NAMESPACE::TypeProto t1;
t1.mutable_tensor_type()->set_elem_type(reduce_sum_0_tensorproto_1.data_type());
auto* reduce_sum_0_arg_1 = &graph.GetOrCreateNodeArg(reduce_sum_node_0_name + "_axes", &t1);
auto* reduce_sum_1_arg_1 = &graph.GetOrCreateNodeArg(reduce_sum_node_1_name + "_axes", &t1);

graph.AddInitializedTensor(reduce_sum_0_tensorproto_1);
graph.AddInitializedTensor(reduce_sum_1_tensorproto_1);

// The ReduceSum node's output type should be the same as the original ReduceSum node's output (going to be removed) type.
auto* reduce_sum_node_0_arg = &graph.GetOrCreateNodeArg(reduce_sum_node_0_name, reduce_sum.OutputDefs()[0]->TypeAsProto());
auto* reduce_sum_node_1_arg = &graph.GetOrCreateNodeArg(reduce_sum_node_1_name, reduce_sum.OutputDefs()[0]->TypeAsProto());

std::vector<onnxruntime::NodeArg*> reduce_sum_node_0_input_defs = {slice_node_0_output_defs[0], reduce_sum_0_arg_1};
std::vector<onnxruntime::NodeArg*> reduce_sum_node_1_input_defs = {slice_node_1_output_defs[0], reduce_sum_1_arg_1};
std::vector<onnxruntime::NodeArg*> reduce_sum_node_0_output_defs = {reduce_sum_node_0_arg};
std::vector<onnxruntime::NodeArg*> reduce_sum_node_1_output_defs = {reduce_sum_node_1_arg};

auto& reduce_sum_node_0 = graph.AddNode(reduce_sum_node_0_name, "ReduceSum", "Map 5D/6D to 4D",
slice_node_0_output_defs, reduce_sum_node_0_output_defs);
reduce_sum_node_0_input_defs, reduce_sum_node_0_output_defs);
auto& reduce_sum_node_1 = graph.AddNode(reduce_sum_node_1_name, "ReduceSum", "Map 5D/6D to 4D",
slice_node_1_output_defs, reduce_sum_node_1_output_defs);
reduce_sum_node_1_input_defs, reduce_sum_node_1_output_defs);

// Create 1 Concat
reduce_sum_node_0.AddAttribute("keepdims", (int64_t)(1));
reduce_sum_node_1.AddAttribute("keepdims", (int64_t)(1));
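// keepdims=1 leaves each partial sum with a size-1 axis 3, so the Concat
// below produces one element per slice along axis 3, which matches what the
// original Reshape-to-5D plus ReduceSum over the innermost axis produced.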

// Create 1 Concat node
std::string concat_node_name = graph.GenerateNodeName(reduce_sum_node_0_name + "_concat");
auto* concat_node_arg = &graph.GetOrCreateNodeArg(concat_node_name, reduce_sum.OutputDefs()[0]->TypeAsProto());
std::vector<onnxruntime::NodeArg*> concat_node_arg_input_defs = {reduce_sum_node_0_arg, reduce_sum_node_1_arg};
std::vector<onnxruntime::NodeArg*> concat_node_arg_output_defs = {concat_node_arg};
auto& concat_node = graph.AddNode(concat_node_name, "Concat", "Map 5D/6D to 4D",
concat_node_arg_input_defs, concat_node_arg_output_defs);
concat_node_arg_input_defs, reduce_sum.MutableOutputDefs());
concat_node.AddAttribute("axis", (int64_t)(3));

return concat_node_arg;
return Status::OK();
}
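The rewrite relies on a simple identity: reshaping an axis of length 8 into [2, 4] and summing over the new innermost axis gives the same values as summing elements [0, 4) and [4, 8) separately and concatenating the two partial sums, which is exactly what the hard-coded starts/ends/axes initializers above encode. A self-contained sketch of that equivalence, independent of ORT:

#include <cassert>
#include <numeric>
#include <vector>

int main() {
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};  // innermost axis, length 8

  // Original pattern: Reshape [8] -> [2, 4], then ReduceSum over the last axis.
  std::vector<float> reshaped_sums(2, 0.0f);
  for (int row = 0; row < 2; ++row)
    for (int col = 0; col < 4; ++col)
      reshaped_sums[row] += x[row * 4 + col];

  // Rewritten pattern: Slice [0, 4) and [4, 8), ReduceSum each, then Concat.
  std::vector<float> concat_sums = {
      std::accumulate(x.begin(), x.begin() + 4, 0.0f),
      std::accumulate(x.begin() + 4, x.end(), 0.0f)};

  assert(reshaped_sums == concat_sums);  // both are {10, 26}
  return 0;
}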

Status MapToFourDimensions::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
@@ -131,16 +206,25 @@ Status MapToFourDimensions::ApplyImpl(Graph& graph, bool& modified, int graph_le
} else if (node->OpType() == "ReduceSum") {
// assume Reshape -> Q -> DQ -> ReduceSum since we don't remove Q/DQ for now
// TODO: Make sure Reshape, Q and DQ do exist
const Node& node_x = *node->InputNodesBegin(); // Q
const Node& node_y = *node_x.InputNodesBegin(); // DQ
const Node& node_z = *node_y.InputNodesBegin(); // Reshape
const Node& node_x = *node->InputNodesBegin(); // Q
const Node& node_y = *node_x.InputNodesBegin(); // DQ
const Node& node_z = *node_y.InputNodesBegin(); // Reshape
Node* reshape_node = graph.GetNode(node_z.Index()); // Mutable Reshape

const auto* input_0 = node_z.InputDefs()[0];
const auto* output_0 = node_z.OutputDefs()[0];
}

AddSliceReduceConcatNodes(graph, *reshape_node, *node);

// Remove the original Reshape and ReduceSum:
// remove the node's output edges first, then remove the node itself.
graph_utils::RemoveNodeOutputEdges(graph, *node);
graph.RemoveNode(node->Index());
}
}

ORT_RETURN_IF_ERROR(graph.Resolve());

return Status::OK();
}
} // namespace onnxruntime
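One detail worth keeping in mind when a transformer deletes nodes, as ApplyImpl does above: output edges must be dropped before the node itself, and Graph::Resolve() is then needed so shapes and the topological order are recomputed. A sketch of that sequence as a standalone helper (the calls are the ones the commit uses; the wrapper itself is illustrative):

#include "core/graph/graph.h"
#include "core/graph/graph_utils.h"

namespace onnxruntime {
// Illustrative helper mirroring the removal sequence in ApplyImpl.
common::Status RemoveNodeAndResolve(Graph& graph, Node& node) {
  graph_utils::RemoveNodeOutputEdges(graph, node);  // detach downstream edges first
  graph.RemoveNode(node.Index());                   // then delete the node
  return graph.Resolve();                           // recompute shapes and topo order
}
}  // namespace onnxruntime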
23 changes: 1 addition & 22 deletions onnxruntime/core/optimizer/map_to_four_dimension.h
@@ -10,33 +10,12 @@

namespace onnxruntime {

/**
@class ConstantFolding
Transformer that traverses the graph top-down and performs constant folding, i.e.,
it statically computes parts of the graph that rely only on constant initializers.
*/
class MapToFourDimensions : public GraphTransformer {
public:
/*! Constant folding will not be applied to nodes that have one of initializers from excluded_initializers as input.
For pre-training, the trainable weights are those initializers to be excluded.
\param execution_provider Execution provider instance to execute constant folding.
*/
MapToFourDimensions(const IExecutionProvider& execution_provider,
bool skip_dequantize_linear,
bool dequantize_initializer_for_dequantize_linear,
const ConfigOptions& config_options,
const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const InlinedHashSet<std::string>& excluded_initializers = {}) noexcept;
MapToFourDimensions() noexcept;

private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

bool skip_dequantize_linear_;
bool dequantize_initializer_for_dequantize_linear_;
const ConfigOptions& config_options_;
const InlinedHashSet<std::string> excluded_initializers_;
const IExecutionProvider& execution_provider_;
};

} // namespace onnxruntime
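For readers new to the base class: a GraphTransformer subclass only needs to supply ApplyImpl and report changes through the modified out-parameter. A minimal skeleton of that contract, assuming the usual pattern of recursing into subgraphs seen in other ORT transformers (details may vary across versions; the class name is illustrative):

#include "core/optimizer/graph_transformer.h"

namespace onnxruntime {

class ExampleTransformer : public GraphTransformer {
 public:
  ExampleTransformer() noexcept : GraphTransformer("ExampleTransformer") {}

 private:
  Status ApplyImpl(Graph& graph, bool& modified, int graph_level,
                   const logging::Logger& logger) const override {
    for (auto node_index : graph.GetNodesInTopologicalOrder()) {
      auto* node = graph.GetNode(node_index);
      if (node == nullptr) continue;  // node may have been removed by an earlier step
      // ... match and rewrite here, setting modified = true on any change ...
      ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));
    }
    return Status::OK();
  }
};

}  // namespace onnxruntime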
