11 // RUN: gc-opt --split-input-file --deep-tile-contraction-named-op %s
22
3- // -----
3+ // // -----
44
5- /// CHECK-LABEL: @blocked_matmul_f32
6- func.func @blocked_matmul_f32(%arg0: tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> {
7-   %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32xf32>
8-   %cst_0 = arith.constant 0.000000e+00 : f32
9-   %0 = tensor.empty() : tensor<128x128x32x32xf32>
10-   %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32>
11-   %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<128x128x32x32xf32>, tensor<128x128x32x32xf32>) outs(%1 : tensor<128x128x32x32xf32>) {
12-   ^bb0(%in: f32, %in_1: f32, %out: f32):
13-     %3 = arith.mulf %in, %in_1 : f32
14-     %4 = arith.addf %out, %3 : f32
15-     linalg.yield %4 : f32
16-   } -> tensor<128x128x32x32xf32>
17-   return %2 : tensor<128x128x32x32xf32>
18- }
5+ // /// CHECK-LABEL: @matmul_4Dx4D_f32
6+ // func.func @matmul_4Dx4D_f32(%arg0: tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> {
7+ // %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32x1xf32>
8+ // %cst_0 = arith.constant 0.000000e+00 : f32
9+ // %0 = tensor.empty() : tensor<128x128x32x32xf32>
10+ // %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32>
11+ // %2 = linalgx.mm4d_vnni ins(%arg0, %cst : tensor<128x128x32x32xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32>
12+ // return %2 : tensor<128x128x32x32xf32>
13+ // }
1914
2015// -----
2116
22- /// CHECK-LABEL: @plain_matmul_f32
23- func.func @plain_matmul_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> {
17+ /// CHECK-LABEL: @matmul_2Dx2D_f32
18+ func.func @matmul_2Dx2D_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> {
2419   %cst = arith.constant dense<1.000000e+00> : tensor<4096x4096xf32>
2520   %cst_0 = arith.constant 0.000000e+00 : f32
2621   %0 = tensor.empty() : tensor<4096x4096xf32>
@@ -29,20 +24,39 @@ func.func @plain_matmul_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf3
2924   return %2 : tensor<4096x4096xf32>
3025 }
3126
27+ // // -----
28+
29+ // /// CHECK-LABEL: @matmul_2Dx4D_f32
30+ // func.func @matmul_2Dx4D_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> {
31+ // %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32x1xf32>
32+ // %cst_0 = arith.constant 0.000000e+00 : f32
33+ // %0 = tensor.empty() : tensor<4096x4096xf32>
34+ // %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
35+ // %2 = linalgx.mm2d_vnni ins(%arg0, %cst : tensor<4096x4096xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
36+ // return %2 : tensor<4096x4096xf32>
37+ // }
38+
3239// -----
3340
34- /// CHECK-LABEL: @blocked_matmul_bf16
35- func.func @blocked_matmul_bf16(%arg0: tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> {
41+ /// CHECK-LABEL: @matmul_4Dx4D_bf16
42+ func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> {
3643   %cst = arith.constant dense<1.000000e+00> : tensor<128x128x16x32x2xbf16>
3744   %cst_0 = arith.constant 0.000000e+00 : bf16
3845   %0 = tensor.empty() : tensor<128x128x32x32xbf16>
3946   %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16>
40-   %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)>], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) {
41-   ^bb0(%in: bf16, %in_1: bf16, %out: bf16):
42-     %3 = arith.mulf %in, %in_1 : bf16
43-     %4 = arith.addf %out, %3 : bf16
44-     linalg.yield %4 : bf16
45-   } -> tensor<128x128x32x32xbf16>
47+   %2 = linalgx.mm4d_vnni ins(%arg0, %cst : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16>
4648   return %2 : tensor<128x128x32x32xbf16>
4749 }
4850
51+ // // -----
52+
53+ // /// CHECK-LABEL: @matmul_2Dx4D_bf16
54+ // func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> {
55+ // %cst = arith.constant dense<1.000000e+00> : tensor<128x128x16x32x2xbf16>
56+ // %cst_0 = arith.constant 0.000000e+00 : bf16
57+ // %0 = tensor.empty() : tensor<4096x4096xbf16>
58+ // %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
59+ // %2 = linalgx.mm2d_vnni ins(%arg0, %cst : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
60+ // return %2 : tensor<4096x4096xbf16>
61+ // }
62+
0 commit comments