matmul add fusion
centwang committed Oct 18, 2024
1 parent b4cb937 commit 5b75280
Showing 2 changed files with 131 additions and 17 deletions.
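For context, a minimal standalone sketch of the shape arithmetic behind this change, with example dimensions assumed; it is not part of the diff. When MatMul's first input has rank other than 2, the new code path flattens the leading dimensions into m, runs Gemm on [m, k] x [k, n], and reshapes the result back.

// Standalone sketch (not part of the commit) of the flattening the fusion applies.
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> a_shape = {8, 16, 32};  // example rank-3 MatMul input A
  const int64_t k = 32;                        // B is [k, n]
  const int64_t n = 768;
  // Collapse all leading dims of A into m, keeping the last dim as k.
  const int64_t m = std::accumulate(a_shape.begin(), a_shape.end() - 1, int64_t{1},
                                    std::multiplies<int64_t>());  // 8 * 16 = 128
  // Gemm computes [m, k] x [k, n] + bias -> [m, n]; the trailing Reshape restores {8, 16, n}.
  std::vector<int64_t> out_shape = a_shape;
  out_shape.back() = n;
  return (m == 128 && out_shape == std::vector<int64_t>{8, 16, 768}) ? 0 : 1;
}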
86 changes: 69 additions & 17 deletions onnxruntime/core/optimizer/matmul_add_fusion.cc
@@ -65,13 +65,37 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     // Gemm only support Matrix, need to check the shape of MatMul and Add
     auto matmul_a_shape = matmul_input_defs[0]->Shape();
     auto matmul_b_shape = matmul_input_defs[1]->Shape();
-    if (nullptr == matmul_a_shape || nullptr == matmul_b_shape) {
+    if (nullptr == matmul_a_shape || nullptr == matmul_b_shape || matmul_b_shape->dim_size() != 2) {
       continue;
     }
 
-    if (2 != matmul_a_shape->dim_size() || 2 != matmul_b_shape->dim_size()) {
-      // Gemm only support Matrix
-      continue;
+    bool need_reshape = matmul_a_shape->dim_size() != 2;
+    const auto& dim_n = matmul_b_shape->dim(1);
+    std::vector<int64_t> shape_values;
+    int64_t m = 0, k = 0, n = 0;
+    if (need_reshape) {
+      // Logically we can use Shape-Concat to produce shape input for Reshape, to keep it simple, we require
+      // both inputs have concrete shape for now, we can add dynamic shape support in future.
+      bool is_concrete_shape = true;
+      for (int i = 0; i < matmul_a_shape->dim_size(); ++i) {
+        const auto& dim = matmul_a_shape->dim(i);
+        if (!utils::HasDimValue(dim)) {
+          is_concrete_shape = false;
+          break;
+        }
+        shape_values.emplace_back(dim.dim_value());
+      }
+      if (!is_concrete_shape) {
+        continue;
+      }
+      const auto& dim_k = matmul_b_shape->dim(0);
+      if (!utils::HasDimValue(dim_k) || !utils::HasDimValue(dim_n)) {
+        continue;
+      }
+      k = dim_k.dim_value();
+      n = dim_n.dim_value();
+      ORT_ENFORCE(shape_values.back() == k);
+      m = std::accumulate(shape_values.begin(), shape_values.end() - 1, 1, std::multiplies<int64_t>());
     }
 
     const auto& matmul_output = *matmul_node.OutputDefs()[0];
@@ -92,31 +116,59 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
       continue;
     }
     const auto& bias_shape = *gemm_input_defs.back()->Shape();
-    const auto& M = matmul_output.Shape()->dim()[0];
-    const auto& N = matmul_output.Shape()->dim()[1];
     auto dim_has_value_1 = [](const TensorShapeProto_Dimension& dim) {
       return dim.has_dim_value() && dim.dim_value() == 1;
     };
 
-    bool valid = ((bias_shape.dim_size() == 1 && bias_shape.dim()[0] == N) ||
-                  (bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim()[0]) && bias_shape.dim()[1] == N) ||
-                  (bias_shape.dim_size() == 2 && bias_shape.dim()[0] == M &&
-                   (dim_has_value_1(bias_shape.dim()[1]) || bias_shape.dim()[1] == N)));
+    bool valid = ((bias_shape.dim_size() == 1 && bias_shape.dim(0) == dim_n) ||
+                  (bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim(0)) && bias_shape.dim(1) == dim_n) ||
+                  (bias_shape.dim_size() == 2 &&
+                   ((!need_reshape && bias_shape.dim(0) == matmul_a_shape->dim(0)) ||
+                    (need_reshape && utils::HasDimValue(bias_shape.dim(0)) && bias_shape.dim(0).dim_value() == m)) &&
+                   (dim_has_value_1(bias_shape.dim(1)) || bias_shape.dim(1) == dim_n)));
     if (!valid) {
       continue;
     }
 
-    Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion/"),
-                                    "Gemm",
-                                    "fused Matmul and Add " + add_node.OpType(),
-                                    gemm_input_defs,
-                                    {});
+    auto gemm_output_defs = add_node.MutableOutputDefs();
+    if (need_reshape) {
+      auto add_reshape = [&](const std::vector<int64_t>& shape, Graph& graph, bool is_input) {
+        const std::string name = is_input ? "gemm_input" : "gemm_output";
+        ONNX_NAMESPACE::TensorProto shape_initializer_proto;
+        shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_shape"));
+        shape_initializer_proto.add_dims(static_cast<int64_t>(shape.size()));
+        shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+        shape_initializer_proto.set_raw_data(shape.data(), shape.size() * sizeof(int64_t));
+        NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto);
+        ONNX_NAMESPACE::TypeProto new_arg_type;
+        const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(
+            gemm_input_defs[0]->TypeAsProto()->tensor_type().elem_type());
+        new_arg_type.mutable_tensor_type()->set_elem_type(element_type);
+        new_arg_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(m);
+        new_arg_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(is_input ? k : n);
+        NodeArg* new_arg = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(name + "_reshape_arg"), &new_arg_type);
+        Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_reshape"), "Reshape", "Reshape for " + name,
+                                           {is_input ? gemm_input_defs[0] : new_arg, shape_arg},
+                                           {is_input ? new_arg : gemm_output_defs[0]});
+        reshape_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType());
+        return new_arg;
+      };
+
+      gemm_input_defs[0] = add_reshape({m, k}, graph, true);
+      shape_values.back() = n;
+      gemm_output_defs[0] = add_reshape(shape_values, graph, false);
+    }
+
+    Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion/"), "Gemm",
+                                    "fused Matmul and Add", gemm_input_defs, gemm_output_defs);
 
     // Assign provider to this new node. Provider should be same as the provider for old node.
     gemm_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType());
 
-    // move output definitions and edges from act_node to gemm_node. delete gemm_node and act_node.
-    graph_utils::FinalizeNodeFusion(graph, {matmul_node, add_node}, gemm_node);
+    graph_utils::RemoveNodeOutputEdges(graph, matmul_node);
+    graph.RemoveNode(matmul_node.Index());
+    graph_utils::RemoveNodeOutputEdges(graph, add_node);
+    graph.RemoveNode(add_node.Index());
 
     modified = true;
   }
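As a reading aid, the bias-shape condition checked in the hunk above can be restated as a standalone predicate; the helper name is hypothetical and this snippet is not part of the commit. The Add's bias must broadcast into Gemm's C input as [n], [1, n], [m, 1], or [m, n], where m is the flattened row count when a Reshape is needed.

// Hypothetical restatement of the `valid` check above, assuming all dims are concrete.
#include <cstdint>
#include <vector>

bool BiasShapeOkForGemm(const std::vector<int64_t>& bias, int64_t m, int64_t n) {
  if (bias.size() == 1) return bias[0] == n;   // [n]
  if (bias.size() != 2) return false;
  return (bias[0] == 1 && bias[1] == n) ||                  // [1, n]
         (bias[0] == m && (bias[1] == 1 || bias[1] == n));  // [m, 1] or [m, n]
}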
62 changes: 62 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -2471,6 +2471,68 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_MissingShape) {
   ASSERT_EQ(op_to_count["Gemm"], 0);
 }
 
+TEST_F(GraphTransformationTests, MatMulAddFusion_NeedReshape_1D) {
+  auto build_test_case = [&](ModelTestBuilder& builder) {
+    auto* input_arg = builder.MakeInput<float>({{16}});
+    auto* weight_arg = builder.MakeInput<float>({{16, 768}});
+    auto* bias_arg = builder.MakeInput<float>({{768}});
+    auto* matmul_out = builder.MakeIntermediate();
+    auto* output_arg = builder.MakeOutput();
+    builder.AddNode("MatMul", {input_arg, weight_arg}, {matmul_out});
+    builder.AddNode("Add", {matmul_out, bias_arg}, {output_arg});
+  };
+
+  auto pre_graph_checker = [](Graph& graph) {
+    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 1);
+    TEST_RETURN_IF_NOT(op_to_count["Add"] == 1);
+    return Status::OK();
+  };
+
+  auto post_graph_checker = [](Graph& graph) {
+    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 0);
+    TEST_RETURN_IF_NOT(op_to_count["Add"] == 0);
+    TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 1);
+    return Status::OK();
+  };
+
+  std::unique_ptr<GraphTransformer> transformer = std::make_unique<MatMulAddFusion>();
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1,
+                                        1, pre_graph_checker, post_graph_checker));
+}
+
+TEST_F(GraphTransformationTests, MatMulAddFusion_NeedReshape_3D) {
+  auto build_test_case = [&](ModelTestBuilder& builder) {
+    auto* input_arg = builder.MakeInput<float>({{8, 16, 32}});
+    auto* weight_arg = builder.MakeInput<float>({{32, 768}});
+    auto* bias_arg = builder.MakeInput<float>({{1, 768}});
+    auto* matmul_out = builder.MakeIntermediate();
+    auto* output_arg = builder.MakeOutput();
+    builder.AddNode("MatMul", {input_arg, weight_arg}, {matmul_out});
+    builder.AddNode("Add", {matmul_out, bias_arg}, {output_arg});
+  };
+
+  auto pre_graph_checker = [](Graph& graph) {
+    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 1);
+    TEST_RETURN_IF_NOT(op_to_count["Add"] == 1);
+    return Status::OK();
+  };
+
+  auto post_graph_checker = [](Graph& graph) {
+    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 0);
+    TEST_RETURN_IF_NOT(op_to_count["Add"] == 0);
+    TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 1);
+    return Status::OK();
+  };
+
+  std::unique_ptr<GraphTransformer> transformer = std::make_unique<MatMulAddFusion>();
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1,
+                                        1, pre_graph_checker, post_graph_checker));
+}
+
 #ifndef DISABLE_CONTRIB_OPS
 TEST_F(GraphTransformationTests, Gemm_Relu_three_input) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/gemm_relu.onnx";
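One detail the new tests do not assert explicitly: in the NeedReshape_3D case the pass wraps the fused Gemm with one Reshape before and one after it, so a stricter post_graph_checker could also count those nodes. A possible extra assertion, not part of the commit, would be:

// Possible extra line inside the NeedReshape_3D post_graph_checker (not in the commit):
// the rank-3 input forces exactly two Reshape nodes around the fused Gemm.
TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 2);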
