Skip to content

Commit 42392bc

Browse files
authored
[MLIR][ONNX] Add OnnxToTorch support for matmul ops (#2629)
This commit adds the OnnxToTorch support for Matmul.
1 parent ed4df38 commit 42392bc

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,17 @@ using namespace mlir::torch::onnx_c;
2626
// results in a lot of ONNX test cases that all reduce to the exact same
2727
// thing here, so we simplify.
2828
void mlir::torch::onnx_c::populateDefaultDomainGtoP(
29-
OnnxCustomOpConversionPattern &patterns) {}
29+
OnnxCustomOpConversionPattern &patterns) {
30+
31+
patterns.onOp("MatMul", 13,
32+
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
33+
Torch::ValueTensorType resultType;
34+
Value lhs, rhs;
35+
if (binder.tensorOperands(lhs, rhs) ||
36+
binder.tensorResultType(resultType))
37+
return failure();
38+
rewriter.replaceOpWithNewOp<Torch::AtenMatmulOp>(
39+
binder.op, resultType, lhs, rhs);
40+
return success();
41+
});
42+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch | FileCheck %s
2+
// Generally, the test cases accumulated here come from running the importer
3+
// over all included backend tests that involve simple ops with no model
4+
// level constants. This is a pragmatic choice which lets us have a lot
5+
// of tests in this file, whereas the others tend to be more bespoke.
6+
7+
// CHECK-LABEL: @test_matmul_2d
8+
func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
9+
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32>
10+
%0 = torch.operator "onnx.MatMul"(%arg0, %arg1) : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32>
11+
return %0 : !torch.vtensor<[3,3],f32>
12+
}
13+
14+
// CHECK-LABEL: @test_matmul_3d
15+
func.func @test_matmul_3d(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,4,3],f32>) -> !torch.vtensor<[2,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
16+
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,4,3],f32> -> !torch.vtensor<[2,3,3],f32>
17+
%0 = torch.operator "onnx.MatMul"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,4,3],f32>) -> !torch.vtensor<[2,3,3],f32>
18+
return %0 : !torch.vtensor<[2,3,3],f32>
19+
}
20+
21+
// CHECK-LABEL: @test_matmul_4d
22+
func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vtensor<[1,2,4,3],f32>) -> !torch.vtensor<[1,2,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
23+
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,2,3,4],f32>, !torch.vtensor<[1,2,4,3],f32> -> !torch.vtensor<[1,2,3,3],f32>
24+
%0 = torch.operator "onnx.MatMul"(%arg0, %arg1) : (!torch.vtensor<[1,2,3,4],f32>, !torch.vtensor<[1,2,4,3],f32>) -> !torch.vtensor<[1,2,3,3],f32>
25+
return %0 : !torch.vtensor<[1,2,3,3],f32>
26+
}

0 commit comments

Comments
 (0)