@@ -15,7 +15,7 @@ module {
     /// CHECK: tensor.empty
     %dest = tensor.empty() : tensor<512x256xbf16>
     %unpack = tensor.unpack %arg1 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %dest : tensor<32x8x16x32xbf16> -> tensor<512x256xbf16>
-    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 2)
+    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 2)
     %2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %1) -> (tensor<128x256xbf16>) {
       %5 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
       %6 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg4)
@@ -105,7 +105,7 @@ module {
     %cst = arith.constant 0.000000e+00 : f32
     %dest0 = tensor.empty() : tensor<256x256xf32>
     %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 2)
+    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 2)
     %1 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %dest1) -> tensor<256x256xf32> {
       %iv0 = affine.apply #map(%arg4)
       %iv1 = affine.apply #map(%arg5)
@@ -157,7 +157,7 @@ module {
     %cst = arith.constant 0.000000e+00 : f32
     %dest0 = tensor.empty() : tensor<256x256xf32>
     %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 1)
+    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 1)
     %1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
       %iv0 = affine.apply #map(%arg3)
       %iv1 = affine.apply #map(%arg4)
@@ -205,12 +205,12 @@ module {
     %dest0 = tensor.empty() : tensor<128x256x256xf32>
     %0 = linalg.add ins(%arg0, %arg1 : tensor<128x256x256xf32>, tensor<128x256x256xf32>) outs(%dest0 : tensor<128x256x256xf32>) -> tensor<128x256x256xf32>
     %dest1 = tensor.empty() : tensor<128x256xf32>
-    /// CHECK: %[[FINAL_RESULT:.*]] = scf.forall (%{{.*}}, %{{.*}}) in (128, 256)
-    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
-    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
-    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
+    /// CHECK: %[[FINAL_RESULT:.*]] = scf.forall (%{{.*}}) = (0, 0) to (128, 256) step (1, 32)
+    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 32] [1, 1, 1]
+    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 32] [1, 1, 1]
+    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 32] [1, 1, 1]
     /// CHECK: %[[ADD_OUT:.*]] = linalg.add
-    /// CHECK: tensor.extract_slice {{.*}} [1, 1] [1, 1]
+    /// CHECK: tensor.extract_slice {{.*}} [1, 32] [1, 1]
     /// CHECK: %[[REDUCE_OUT:.*]] = linalg.reduce { arith.addf } ins(%[[ADD_OUT]] :
     %1 = linalg.reduce { arith.addf } ins(%0 : tensor<128x256x256xf32>) outs(%dest1 : tensor<128x256xf32>) dimensions = [1]
     /// CHECK: scf.forall.in_parallel
@@ -319,7 +319,7 @@ module {
   /// CHECK-LABEL: @fuse_residual_pattern
   func.func @fuse_residual_pattern(%arg0: tensor<128x256x256xf32>, %arg1: tensor<128x256x256xf32>) -> tensor<128x256x256xf32> {
     %dest0 = tensor.empty() : tensor<128x256x256xf32>
-    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (128, 256)
+    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) = (0, 0, 0) to (128, 256, 256) step (1, 32, 32)
     /// CHECK: %[[ADD_OUT:.*]] = linalg.add
     /// CHECK: %[[EXP_OUT:.*]] = linalg.exp ins(%[[ADD_OUT:.*]] :
     /// CHECK: %[[MUL_OUT:.*]] = linalg.mul ins(%[[ADD_OUT:.*]], %[[EXP_OUT:.*]] :
@@ -353,4 +353,57 @@ module {
     /// CHECK: return %[[PACK_OUT]]
     return %pack : tensor<1x1x128x32x32xbf16>
   }
+}
+
+// -----
+
+module {
+  // CHECK: func.func @fuse_generic_matmul(
+  // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+  // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x16x16xf32>
+  // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<4x16x16xf32>
+  func.func @fuse_generic_matmul(%arg0: tensor<32x32xf32>, %arg1: tensor<2x16x16xf32>, %arg2: tensor<4x16x16xf32>) -> tensor<32x64xf32> attributes {llvm.emit_c_interface} {
+    /// CHECK: %[[EMPTY_OUT_0:.*]] = tensor.empty
+    %0 = tensor.empty() : tensor<2x2x16x16xf32>
+    %pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<32x32xf32> -> tensor<2x2x16x16xf32>
+    /// CHECK: %[[EMPTY_OUT_1:.*]] = tensor.empty
+    %1 = tensor.empty() : tensor<2x16x16xf32>
+    /// CHECK: %[[FIRST_MATMUL_OUT:.*]] = scf.forall (%{{.*}}) in (2)
+    /// CHECK: %[[EXTRACT_SLICE_0:.*]] = tensor.extract_slice %[[ARG0]]{{.*}} [16, 32]
+    /// CHECK: %[[EXTRACT_SLICE_1:.*]] = tensor.extract_slice %[[EMPTY_OUT_0]]{{.*}} [1, 2, 16, 16]
+    /// CHECK: %[[PACK_OUT:.*]] = tensor.pack %[[EXTRACT_SLICE_0]]
+    /// CHECK: %[[EXTRACT_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]]{{.*}} [2, 16, 16]
+    /// CHECK: %[[MATMUL_OUT_0:.*]] = linalg.generic {{.*}} ins(%[[PACK_OUT]], %[[EXTRACT_SLICE_2]] :
+    %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %arg1 : tensor<2x2x16x16xf32>, tensor<2x16x16xf32>) outs(%1 : tensor<2x16x16xf32>) {
+    ^bb0(%in: f32, %in_3: f32, %out: f32):
+      %9 = arith.mulf %in, %in_3 : f32
+      %10 = arith.addf %out, %9 : f32
+      linalg.yield %10 : f32
+    } -> tensor<2x16x16xf32>
+    /// CHECK: scf.forall.in_parallel
+    /// CHECK: tensor.parallel_insert_slice
+    /// CHECK: }
+    /// CHECK: %[[EMPTY_OUT_2:.*]] = tensor.empty
+    /// CHECK: %[[EMPTY_OUT_3:.*]] = tensor.empty
+    %3 = tensor.empty() : tensor<2x4x16x16xf32>
+    /// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%{{.*}}) in (2, 4)
+    /// CHECK: %[[EXTRACT_SLICE_3:.*]] = tensor.extract_slice %[[FIRST_MATMUL_OUT]]{{.*}} [1, 16, 16]
+    /// CHECK: %[[EXTRACT_SLICE_4:.*]] = tensor.extract_slice %[[ARG2]]{{.*}} [1, 16, 16]
+    /// CHECK: %[[MATMUL_OUT_1:.*]] = linalg.generic {{.*}} ins(%[[EXTRACT_SLICE_3]], %[[EXTRACT_SLICE_4]] :
+    /// CHECK: %[[UNPACK_OUT:.*]] = tensor.unpack %[[MATMUL_OUT_1]]
+    %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%2, %arg2 : tensor<2x16x16xf32>, tensor<4x16x16xf32>) outs(%3 : tensor<2x4x16x16xf32>) {
+    ^bb0(%in: f32, %in_3: f32, %out: f32):
+      %9 = arith.mulf %in, %in_3 : f32
+      %10 = arith.addf %out, %9 : f32
+      linalg.yield %10 : f32
+    } -> tensor<2x4x16x16xf32>
+    /// CHECK: scf.forall.in_parallel
+    /// CHECK: tensor.parallel_insert_slice
+    /// CHECK: tensor.parallel_insert_slice
+    /// CHECK: }
+    %5 = tensor.empty() : tensor<32x64xf32>
+    %unpack = tensor.unpack %4 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %5 : tensor<2x4x16x16xf32> -> tensor<32x64xf32>
+    /// CHECK: return %[[FINAL_RESULT]]#1
+    return %unpack : tensor<32x64xf32>
+  }
 }