@@ -15,7 +15,7 @@ module {
1515 /// CHECK: tensor.empty
1616 %dest = tensor.empty () : tensor <512 x256 xbf16 >
1717 %unpack = tensor.unpack %arg1 inner_dims_pos = [0 , 1 ] inner_tiles = [16 , 32 ] into %dest : tensor <32 x8 x16 x32 xbf16 > -> tensor <512 x256 xbf16 >
18- /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}} ) in (2, 2)
18+ /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 2)
1919 %2 = scf.forall (%arg3 , %arg4 ) in (2 , 2 ) shared_outs (%arg5 = %1 ) -> (tensor <128 x256 xbf16 >) {
2020 %5 = affine.apply affine_map <(d0 ) -> (d0 * 64 )>(%arg3 )
2121 %6 = affine.apply affine_map <(d0 ) -> (d0 * 128 )>(%arg4 )
@@ -105,7 +105,7 @@ module {
105105 %cst = arith.constant 0.000000e+00 : f32
106106 %dest0 = tensor.empty () : tensor <256 x256 xf32 >
107107 %dest1 = linalg.fill ins (%cst : f32 ) outs (%dest0 : tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 >
108- /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}} ) in (2, 2)
108+ /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 2)
109109 %1 = scf.forall (%arg4 , %arg5 ) in (2 , 2 ) shared_outs (%arg6 = %dest1 ) -> tensor <256 x256 xf32 > {
110110 %iv0 = affine.apply #map (%arg4 )
111111 %iv1 = affine.apply #map (%arg5 )
@@ -157,7 +157,7 @@ module {
157157 %cst = arith.constant 0.000000e+00 : f32
158158 %dest0 = tensor.empty () : tensor <256 x256 xf32 >
159159 %dest1 = linalg.fill ins (%cst : f32 ) outs (%dest0 : tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 >
160- /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}} ) in (2, 1)
160+ /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 1)
161161 %1 = scf.forall (%arg3 , %arg4 ) in (2 , 1 ) shared_outs (%arg5 = %dest1 ) -> tensor <256 x256 xf32 > {
162162 %iv0 = affine.apply #map (%arg3 )
163163 %iv1 = affine.apply #map (%arg4 )
@@ -205,12 +205,12 @@ module {
205205 %dest0 = tensor.empty () : tensor <128 x256 x256 xf32 >
206206 %0 = linalg.add ins (%arg0 , %arg1 : tensor <128 x256 x256 xf32 >, tensor <128 x256 x256 xf32 >) outs (%dest0 : tensor <128 x256 x256 xf32 >) -> tensor <128 x256 x256 xf32 >
207207 %dest1 = tensor.empty () : tensor <128 x256 xf32 >
208- /// CHECK: %[[FINAL_RESULT:.*]] = scf.forall (%{{.*}}, %{{.*}}) in (128, 256)
209- /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1 ] [1, 1, 1]
210- /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1 ] [1, 1, 1]
211- /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1 ] [1, 1, 1]
208+ /// CHECK: %[[FINAL_RESULT:.*]] = scf.forall (%{{.*}}) = (0, 0) to (128, 256) step (1, 32 )
209+ /// CHECK: tensor.extract_slice {{.*}} [1, 256, 32 ] [1, 1, 1]
210+ /// CHECK: tensor.extract_slice {{.*}} [1, 256, 32 ] [1, 1, 1]
211+ /// CHECK: tensor.extract_slice {{.*}} [1, 256, 32 ] [1, 1, 1]
212212 /// CHECK: %[[ADD_OUT:.*]] = linalg.add
213- /// CHECK: tensor.extract_slice {{.*}} [1, 1 ] [1, 1]
213+ /// CHECK: tensor.extract_slice {{.*}} [1, 32 ] [1, 1]
214214 /// CHECK: %[[REDUCE_OUT:.*]] = linalg.reduce { arith.addf } ins(%[[ADD_OUT]] :
215215 %1 = linalg.reduce { arith.addf } ins (%0 : tensor <128 x256 x256 xf32 >) outs (%dest1 : tensor <128 x256 xf32 >) dimensions = [1 ]
216216 /// CHECK: scf.forall.in_parallel
@@ -319,7 +319,7 @@ module {
319319 /// CHECK-LABEL: @fuse_residual_pattern
320320 func.func @fuse_residual_pattern (%arg0: tensor <128 x256 x256 xf32 >, %arg1: tensor <128 x256 x256 xf32 >) -> tensor <128 x256 x256 xf32 > {
321321 %dest0 = tensor.empty () : tensor <128 x256 x256 xf32 >
322- /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (128, 256)
322+ /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) = (0, 0, 0) to (128, 256, 256) step (1, 32, 32 )
323323 /// CHECK: %[[ADD_OUT:.*]] = linalg.add
324324 /// CHECK: %[[EXP_OUT:.*]] = linalg.exp ins(%[[ADD_OUT:.*]] :
325325 /// CHECK: %[[MUL_OUT:.*]] = linalg.mul ins(%[[ADD_OUT:.*]], %[[EXP_OUT:.*]] :
@@ -353,4 +353,57 @@ module {
353353 /// CHECK: return %[[PACK_OUT]]
354354 return %pack : tensor <1 x1 x128 x32 x32 xbf16 >
355355 }
356+ }
357+
358+ // -----
359+
360+ module {
// NOTE(review): FileCheck regression test — every `// CHECK` / `/// CHECK`
// line below is a pattern assertion matched against the pass output, not a
// comment; their exact text is load-bearing and must not be edited casually.
// NOTE(review): this listing carries extraction artifacts (stray spaces inside
// tokens, e.g. `index ing_maps` for `indexing_maps`, `" parallel"` for
// `"parallel"`, and the `NNN+` diff-number prefixes) — the real test file
// presumably has none of these; confirm against the upstream source.
361+ // CHECK: func.func @fuse_generic_matmul(
362+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
363+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x16x16xf32>
364+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<4x16x16xf32>
// Test intent (as the CHECK sequence shows): a tensor.pack, two matmul-style
// linalg.generic contractions, and a final tensor.unpack are expected to be
// tiled and fused into two scf.forall nests — the pack into the first nest,
// the unpack into the second — so no standalone pack/unpack survives.
365+ func.func @fuse_generic_matmul (%arg0: tensor <32 x32 xf32 >, %arg1: tensor <2 x16 x16 xf32 >, %arg2: tensor <4 x16 x16 xf32 >) -> tensor <32 x64 xf32 > attributes {llvm.emit_c_interface } {
366+ /// CHECK: %[[EMPTY_OUT_0:.*]] = tensor.empty
367+ %0 = tensor.empty () : tensor <2 x2 x16 x16 xf32 >
// Pack 32x32 into 2x2 tiles of 16x16; the CHECK lines below expect this pack
// to be rematerialized on a slice of %arg0 inside the first scf.forall.
368+ %pack = tensor.pack %arg0 outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [16 , 16 ] into %0 : tensor <32 x32 xf32 > -> tensor <2 x2 x16 x16 xf32 >
369+ /// CHECK: %[[EMPTY_OUT_1:.*]] = tensor.empty
370+ %1 = tensor.empty () : tensor <2 x16 x16 xf32 >
371+ /// CHECK: %[[FIRST_MATMUL_OUT:.*]] = scf.forall (%{{.*}}) in (2)
372+ /// CHECK: %[[EXTRACT_SLICE_0:.*]] = tensor.extract_slice %[[ARG0]]{{.*}} [16, 32]
373+ /// CHECK: %[[EXTRACT_SLICE_1:.*]] = tensor.extract_slice %[[EMPTY_OUT_0]]{{.*}} [1, 2, 16, 16]
374+ /// CHECK: %[[PACK_OUT:.*]] = tensor.pack %[[EXTRACT_SLICE_0]]
375+ /// CHECK: %[[EXTRACT_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]]{{.*}} [2, 16, 16]
376+ /// CHECK: %[[MATMUL_OUT_0:.*]] = linalg.generic {{.*}} ins(%[[PACK_OUT]], %[[EXTRACT_SLICE_2]] :
// First contraction: mulf + addf reduction over d1 and d4 of the packed
// input against %arg1, yielding tensor<2x16x16xf32>.
377+ %2 = linalg.generic {index ing_maps = [affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d1 , d2 , d4 )>, affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d1 , d4 , d3 )>, affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d2 , d3 )>], iterator_types = [" parallel" , " reduction" , " parallel" , " parallel" , " reduction" ]} ins (%pack , %arg1 : tensor <2 x2 x16 x16 xf32 >, tensor <2 x16 x16 xf32 >) outs (%1 : tensor <2 x16 x16 xf32 >) {
378+ ^bb0 (%in: f32 , %in_3: f32 , %out: f32 ):
379+ %9 = arith.mulf %in , %in_3 : f32
380+ %10 = arith.addf %out , %9 : f32
381+ linalg.yield %10 : f32
382+ } -> tensor <2 x16 x16 xf32 >
383+ /// CHECK: scf.forall.in_parallel
384+ /// CHECK: tensor.parallel_insert_slice
385+ /// CHECK: }
386+ /// CHECK: %[[EMPTY_OUT_2:.*]] = tensor.empty
387+ /// CHECK: %[[EMPTY_OUT_3:.*]] = tensor.empty
388+ %3 = tensor.empty () : tensor <2 x4 x16 x16 xf32 >
// Second nest is expected to carry two results (`:2` + two
// parallel_insert_slice below): the matmul tiles and the fused unpack.
389+ /// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%{{.*}}) in (2, 4)
390+ /// CHECK: %[[EXTRACT_SLICE_3:.*]] = tensor.extract_slice %[[FIRST_MATMUL_OUT]]{{.*}} [1, 16, 16]
391+ /// CHECK: %[[EXTRACT_SLICE_4:.*]] = tensor.extract_slice %[[ARG2]]{{.*}} [1, 16, 16]
392+ /// CHECK: %[[MATMUL_OUT_1:.*]] = linalg.generic {{.*}} ins(%[[EXTRACT_SLICE_3]], %[[EXTRACT_SLICE_4]] :
393+ /// CHECK: %[[UNPACK_OUT:.*]] = tensor.unpack %[[MATMUL_OUT_1]]
// Second contraction: same mulf + addf body, reducing over d4, consuming the
// first result %2 against %arg2 into tensor<2x4x16x16xf32>.
394+ %4 = linalg.generic {index ing_maps = [affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d2 , d4 )>, affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d1 , d4 , d3 )>, affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d1 , d2 , d3 )>], iterator_types = [" parallel" , " parallel" , " parallel" , " parallel" , " reduction" ]} ins (%2 , %arg2 : tensor <2 x16 x16 xf32 >, tensor <4 x16 x16 xf32 >) outs (%3 : tensor <2 x4 x16 x16 xf32 >) {
395+ ^bb0 (%in: f32 , %in_3: f32 , %out: f32 ):
396+ %9 = arith.mulf %in , %in_3 : f32
397+ %10 = arith.addf %out , %9 : f32
398+ linalg.yield %10 : f32
399+ } -> tensor <2 x4 x16 x16 xf32 >
400+ /// CHECK: scf.forall.in_parallel
401+ /// CHECK: tensor.parallel_insert_slice
402+ /// CHECK: tensor.parallel_insert_slice
403+ /// CHECK: }
404+ %5 = tensor.empty () : tensor <32 x64 xf32 >
// The unpack is expected to be fused into the forall: the function returns
// the loop's second result (`%[[FINAL_RESULT]]#1`) instead of a standalone
// tensor.unpack.
405+ %unpack = tensor.unpack %4 inner_dims_pos = [0 , 1 ] inner_tiles = [16 , 16 ] into %5 : tensor <2 x4 x16 x16 xf32 > -> tensor <32 x64 xf32 >
406+ /// CHECK: return %[[FINAL_RESULT]]#1
407+ return %unpack : tensor <32 x64 xf32 >
408+ }
356409}
0 commit comments