@@ -15,7 +15,7 @@ module {
     /// CHECK: tensor.empty
     %dest = tensor.empty() : tensor<512x256xbf16>
     %unpack = tensor.unpack %arg1 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %dest : tensor<32x8x16x32xbf16> -> tensor<512x256xbf16>
-    /// CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 2)
+    /// CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 2)
     %2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %1) -> (tensor<128x256xbf16>) {
       %5 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
       %6 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg4)
@@ -105,7 +105,7 @@ module {
     %cst = arith.constant 0.000000e+00 : f32
     %dest0 = tensor.empty() : tensor<256x256xf32>
     %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    /// CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 2)
+    /// CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 2)
     %1 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %dest1) -> tensor<256x256xf32> {
       %iv0 = affine.apply #map(%arg4)
       %iv1 = affine.apply #map(%arg5)
@@ -157,7 +157,7 @@ module {
     %cst = arith.constant 0.000000e+00 : f32
     %dest0 = tensor.empty() : tensor<256x256xf32>
     %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    /// CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 1)
+    /// CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 1)
     %1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
       %iv0 = affine.apply #map(%arg3)
       %iv1 = affine.apply #map(%arg4)
@@ -205,12 +205,12 @@ module {
     %dest0 = tensor.empty() : tensor<128x256x256xf32>
     %0 = linalg.add ins(%arg0, %arg1 : tensor<128x256x256xf32>, tensor<128x256x256xf32>) outs(%dest0 : tensor<128x256x256xf32>) -> tensor<128x256x256xf32>
     %dest1 = tensor.empty() : tensor<128x256xf32>
-    /// CHECK:   %[[FINAL_RESULT:.*]] = scf.forall (%{{.*}}, %{{.*}}) in (128, 256)
-    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
-    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
-    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
+    /// CHECK:   %[[FINAL_RESULT:.*]] = scf.forall (%{{.*}}) = (0, 0) to (128, 256) step (1, 32)
+    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 32] [1, 1, 1]
+    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 32] [1, 1, 1]
+    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 32] [1, 1, 1]
     /// CHECK: %[[ADD_OUT:.*]] = linalg.add
-    /// CHECK: tensor.extract_slice {{.*}} [1, 1] [1, 1]
+    /// CHECK: tensor.extract_slice {{.*}} [1, 32] [1, 1]
     /// CHECK: %[[REDUCE_OUT:.*]] = linalg.reduce { arith.addf } ins(%[[ADD_OUT]] :
     %1 = linalg.reduce { arith.addf } ins(%0 : tensor<128x256x256xf32>) outs(%dest1 : tensor<128x256xf32>) dimensions = [1]
     /// CHECK: scf.forall.in_parallel
@@ -319,7 +319,7 @@ module {
   /// CHECK-LABEL: @fuse_residual_pattern
   func.func @fuse_residual_pattern(%arg0: tensor<128x256x256xf32>, %arg1: tensor<128x256x256xf32>) -> tensor<128x256x256xf32> {
     %dest0 = tensor.empty() : tensor<128x256x256xf32>
-    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (128, 256)
+    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) = (0, 0, 0) to (128, 256, 256) step (1, 32, 32)
     /// CHECK: %[[ADD_OUT:.*]] = linalg.add
     /// CHECK: %[[EXP_OUT:.*]] = linalg.exp ins(%[[ADD_OUT:.*]] :
     /// CHECK: %[[MUL_OUT:.*]] = linalg.mul ins(%[[ADD_OUT:.*]], %[[EXP_OUT:.*]] :
@@ -353,4 +353,58 @@ module {
     /// CHECK: return %[[PACK_OUT]]
     return %pack : tensor<1x1x128x32x32xbf16>
   }
+}
+
+// -----
+
+module {
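+  /// Two chained packed contractions: the tensor.pack of %arg0 is fused into the first scf.forall loop and the trailing tensor.unpack is fused into the second, so the function returns the loop result directly.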
+  //      CHECK: func.func @fuse_generic_matmul(
+  // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+  // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x16x16xf32>
+  // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<4x16x16xf32>
+  func.func @fuse_generic_matmul(%arg0: tensor<32x32xf32>, %arg1: tensor<2x16x16xf32>, %arg2: tensor<4x16x16xf32>) -> tensor<32x64xf32> attributes {llvm.emit_c_interface} {
+    /// CHECK: %[[EMPTY_OUT_0:.*]] = tensor.empty
+    %0 = tensor.empty() : tensor<2x2x16x16xf32>
+    %pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<32x32xf32> -> tensor<2x2x16x16xf32>
+    /// CHECK: %[[EMPTY_OUT_1:.*]] = tensor.empty
+    %1 = tensor.empty() : tensor<2x16x16xf32>
+    /// CHECK: %[[FIRST_MATMUL_OUT:.*]] = scf.forall (%{{.*}}) in (2)
+    /// CHECK: %[[EXTRACT_SLICE_0:.*]] = tensor.extract_slice %[[ARG0]]{{.*}} [16, 32]
+    /// CHECK: %[[EXTRACT_SLICE_1:.*]] = tensor.extract_slice %[[EMPTY_OUT_0]]{{.*}} [1, 2, 16, 16]
+    /// CHECK: %[[PACK_OUT:.*]] = tensor.pack %[[EXTRACT_SLICE_0]]
+    /// CHECK: %[[EXTRACT_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]]{{.*}} [2, 16, 16]
+    /// CHECK: %[[MATMUL_OUT_0:.*]] = linalg.generic {{.*}} ins(%[[PACK_OUT]], %[[EXTRACT_SLICE_2]] :
+    %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %arg1 : tensor<2x2x16x16xf32>, tensor<2x16x16xf32>) outs(%1 : tensor<2x16x16xf32>) {
+    ^bb0(%in: f32, %in_3: f32, %out: f32):
+      %9 = arith.mulf %in, %in_3 : f32
+      %10 = arith.addf %out, %9 : f32
+      linalg.yield %10 : f32
+    } -> tensor<2x16x16xf32>
+    /// CHECK: scf.forall.in_parallel
+    /// CHECK: tensor.parallel_insert_slice
+    /// CHECK: }
+    /// CHECK: %[[EMPTY_OUT_2:.*]] = tensor.empty
+    /// CHECK: %[[EMPTY_OUT_3:.*]] = tensor.empty
+    %3 = tensor.empty() : tensor<2x4x16x16xf32>
+    /// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%{{.*}}) in (2, 4)
+    /// CHECK: %[[EXTRACT_SLICE_3:.*]] = tensor.extract_slice %[[FIRST_MATMUL_OUT]]{{.*}} [1, 16, 16]
+    /// CHECK: %[[EXTRACT_SLICE_4:.*]] = tensor.extract_slice %[[ARG2]]{{.*}} [1, 16, 16]
+    /// CHECK: %[[MATMUL_OUT_1:.*]] = linalg.generic {{.*}} ins(%[[EXTRACT_SLICE_3]], %[[EXTRACT_SLICE_4]] :
+    /// CHECK: %[[UNPACK_OUT:.*]] = tensor.unpack %[[MATMUL_OUT_1]]
+    %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%2, %arg2 : tensor<2x16x16xf32>, tensor<4x16x16xf32>) outs(%3 : tensor<2x4x16x16xf32>) {
+    ^bb0(%in: f32, %in_3: f32, %out: f32):
+      %9 = arith.mulf %in, %in_3 : f32
+      %10 = arith.addf %out, %9 : f32
+      linalg.yield %10 : f32
+    } -> tensor<2x4x16x16xf32>
+    /// CHECK: scf.forall.in_parallel
+    /// CHECK: tensor.parallel_insert_slice
+    /// CHECK: tensor.parallel_insert_slice
+    /// CHECK: }
+    %5 = tensor.empty() : tensor<32x64xf32>
+    %unpack = tensor.unpack %4 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %5 : tensor<2x4x16x16xf32> -> tensor<32x64xf32>
+    /// CHECK: return %[[FINAL_RESULT]]#1
+    return %unpack : tensor<32x64xf32>
+  }
 }