Skip to content

Commit f267b7e

Browse files
authored
Fix MatmulTransposeFusion when input A and B are the same (microsoft#24373)
### Description MatmulTransposeFusion does not work correctly when inputs A and B of a `MatMul` node are the same tensor. ![image](https://github.com/user-attachments/assets/48a6afd8-13d0-48d4-b86f-53a866c47803) Fixes microsoft#24341 ### Motivation and Context Without this change, fusing a `MatMul` whose two inputs are identical produces an invalid graph, and model initialization can fail; see microsoft#24341. The fix moves the right-input type check before the left-input transpose handling and skips fusion entirely when both inputs are the same `NodeArg`.
1 parent c5b82a5 commit f267b7e

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

onnxruntime/core/optimizer/matmul_transpose_fusion.cc

+13-6
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,19 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
310310
continue;
311311
}
312312

313+
NodeArg* right_input = node.MutableInputDefs()[1];
314+
auto right_type = right_input->TypeAsProto()->tensor_type().elem_type();
315+
if (!IsAllowedFusedMatMulDataType(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(right_type))) {
316+
continue;
317+
}
318+
319+
if (left_input == right_input) {
320+
// If both inputs are the same, we skip the fusion.
321+
// Currently, this situation is not handled correctly in the code below.
322+
// Otherwise, the model initialization may fail. See https://github.com/microsoft/onnxruntime/issues/24341.
323+
continue;
324+
}
325+
313326
bool is_trans_left = false;
314327
bool is_trans_batch_left = false;
315328
Node* left = nullptr;
@@ -325,12 +338,6 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
325338
}
326339
}
327340

328-
NodeArg* right_input = node.MutableInputDefs()[1];
329-
auto right_type = right_input->TypeAsProto()->tensor_type().elem_type();
330-
if (!IsAllowedFusedMatMulDataType(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(right_type))) {
331-
continue;
332-
}
333-
334341
bool is_trans_right = false;
335342
bool is_trans_batch_right = false;
336343
Node* right = nullptr;

onnxruntime/test/optimizer/graph_transform_test.cc

+18
Original file line numberDiff line numberDiff line change
@@ -2946,6 +2946,24 @@ TEST_F(GraphTransformationTests, TransposeMatmulTransBatchNoFusion) {
29462946
}
29472947
}
29482948

2949+
TEST_F(GraphTransformationTests, TransposeMatmulFusion_SameInput_gh_issue_24341) {
2950+
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gh_issue_24341.onnx";
2951+
2952+
std::shared_ptr<Model> p_model;
2953+
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
2954+
Graph& graph = p_model->MainGraph();
2955+
std::map<std::string, int> orig_op_to_count = CountOpsInGraph(graph);
2956+
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
2957+
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
2958+
std::make_unique<MatmulTransposeFusion>(), TransformerLevel::Level1));
2959+
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
2960+
2961+
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
2962+
ASSERT_EQ(op_to_count["Transpose"], orig_op_to_count["Transpose"]);
2963+
ASSERT_EQ(op_to_count["MatMul"], orig_op_to_count["MatMul"]);
2964+
ASSERT_EQ(op_to_count["Cast"], orig_op_to_count["Cast"]);
2965+
}
2966+
29492967
TEST_F(GraphTransformationTests, Gemm_LeakyRelu_Fusion) {
29502968
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "gemm_activation_fusion/gemm_activation_fusion.onnx";
29512969

Binary file not shown.

0 commit comments

Comments (0)