1
+ // RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch | FileCheck %s
2
+ // Generally, the test cases accumulated here come from running the importer
3
+ // over all included backend tests that involve simple ops with no model
4
+ // level constants. This is a pragmatic choice which lets us have a lot
5
+ // of tests in this file, whereas the others tend to be more bespoke.
6
+
7
+ // CHECK-LABEL: @test_matmul_2d
8
+ func.func @test_matmul_2d (%arg0: !torch.vtensor <[3 ,4 ],f32 >, %arg1: !torch.vtensor <[4 ,3 ],f32 >) -> !torch.vtensor <[3 ,3 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
9
+ // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32>
10
+ %0 = torch.operator " onnx.MatMul" (%arg0 , %arg1 ) : (!torch.vtensor <[3 ,4 ],f32 >, !torch.vtensor <[4 ,3 ],f32 >) -> !torch.vtensor <[3 ,3 ],f32 >
11
+ return %0 : !torch.vtensor <[3 ,3 ],f32 >
12
+ }
13
+
14
+ // CHECK-LABEL: @test_matmul_3d
15
+ func.func @test_matmul_3d (%arg0: !torch.vtensor <[2 ,3 ,4 ],f32 >, %arg1: !torch.vtensor <[2 ,4 ,3 ],f32 >) -> !torch.vtensor <[2 ,3 ,3 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
16
+ // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,4,3],f32> -> !torch.vtensor<[2,3,3],f32>
17
+ %0 = torch.operator " onnx.MatMul" (%arg0 , %arg1 ) : (!torch.vtensor <[2 ,3 ,4 ],f32 >, !torch.vtensor <[2 ,4 ,3 ],f32 >) -> !torch.vtensor <[2 ,3 ,3 ],f32 >
18
+ return %0 : !torch.vtensor <[2 ,3 ,3 ],f32 >
19
+ }
20
+
21
+ // CHECK-LABEL: @test_matmul_4d
22
+ func.func @test_matmul_4d (%arg0: !torch.vtensor <[1 ,2 ,3 ,4 ],f32 >, %arg1: !torch.vtensor <[1 ,2 ,4 ,3 ],f32 >) -> !torch.vtensor <[1 ,2 ,3 ,3 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
23
+ // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,2,3,4],f32>, !torch.vtensor<[1,2,4,3],f32> -> !torch.vtensor<[1,2,3,3],f32>
24
+ %0 = torch.operator " onnx.MatMul" (%arg0 , %arg1 ) : (!torch.vtensor <[1 ,2 ,3 ,4 ],f32 >, !torch.vtensor <[1 ,2 ,4 ,3 ],f32 >) -> !torch.vtensor <[1 ,2 ,3 ,3 ],f32 >
25
+ return %0 : !torch.vtensor <[1 ,2 ,3 ,3 ],f32 >
26
+ }
0 commit comments