Skip to content

[mlir][linalg] Improve linalg.pack consumer fusion. #148993

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 17, 2025

Conversation

hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Jul 15, 2025

If a dimension is not tiled, it is always valid to fuse the pack op, even if it has padding semantics. Because it always generates a full slice along the dimension.

If a dimension is tiled and it does not need extra padding, the fusion is valid.

The revision also formats corresponding tests for consistency.

@llvmbot
Copy link
Member

llvmbot commented Jul 15, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

If a dimension is not tiled, it is always valid to fuse the pack op, even if it has padding semantics. Because it always generates a full slice along the dimension.

The revision also formats corresponding tests for consistency.


Patch is 25.76 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148993.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+19-20)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+164-115)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 513cecef29b61..7a2931fd6d645 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -887,26 +887,13 @@ struct PackOpTiling
 
     ArrayRef<OpFoldResult> offsets(allOffsets[0]);
     ArrayRef<OpFoldResult> sizes(allSizes[0]);
-
     auto packOp = cast<PackOp>(op);
-    // It is not trivial to infer dest tile from source tile if `packOp` has
-    // padding semantic.
-    if (packOp.getPaddingValue())
-      return failure();
-
     Location loc = packOp.getLoc();
-
     SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
     DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
         packOp.getDimAndTileMapping();
     for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
       if (dimAndTileMapping.count(dim)) {
-        FailureOr<int64_t> cstSize =
-            ValueBoundsConstraintSet::computeConstantBound(
-                presburger::BoundType::UB, sizes[dim],
-                /*stopCondition=*/nullptr, /*closedUB=*/true);
-        std::optional<int64_t> cstInnerSize =
-            getConstantIntValue(dimAndTileMapping[dim]);
         // Currently fusing `packOp` as consumer only expects perfect tiling
         // scenario because even if without padding semantic, the `packOp` may
         // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
@@ -916,12 +903,23 @@ struct PackOpTiling
         // (0,0)~(0,4) at first row.
         // 2. the second slice is extracted from (5) to (9) and SHOULD BE
         // respectively inserted into two rows with different length, including
-        // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate
-        // them, thus adding below constraint to bypass them temporarily. In
-        // another word, we can only support tiling with consumer if the tile
-        // size for the producer is a multiple of the inner tile size for the
-        // packed dimensions at this moment.
-        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
+        // first row: (0,5) and second row (1,0)~(1,3).
+        // It is hard to coordinate them, thus adding below constraint to bypass
+        // them temporarily. In another word, we can only support tiling with
+        // consumer if the tile size for the producer is either a multiple of
+        // the inner tile size for the packed dimensions or the dimension is not
+        // tiled at this moment.
+        FailureOr<int64_t> cstTileSize =
+            ValueBoundsConstraintSet::computeConstantBound(
+                presburger::BoundType::UB, sizes[dim],
+                /*stopCondition=*/nullptr, /*closedUB=*/true);
+        std::optional<int64_t> cstInnerSize =
+            getConstantIntValue(dimAndTileMapping[dim]);
+        std::optional<int64_t> cstDimSize = getConstantIntValue(sizes[dim]);
+        bool isTiled = failed(cstTileSize) || !cstDimSize ||
+                       cstTileSize.value() != cstDimSize.value();
+        if (isTiled && (failed(cstTileSize) || !cstInnerSize ||
+                            *cstTileSize % *cstInnerSize != 0)) {
           return failure();
         }
 
@@ -988,7 +986,8 @@ struct PackOpTiling
         loc, packOp.getDest(), outputOffsets, outputSizes, strides);
     tiledOperands.push_back(outSlice);
 
-    assert(!packOp.getPaddingValue() && "Expect no padding semantic");
+    if (auto val = packOp.getPaddingValue())
+      tiledOperands.push_back(val);
     for (auto tile : packOp.getInnerTiles())
       tiledOperands.push_back(tile);
 
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index d09373bdb3f14..d51621c18fd54 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -193,33 +193,33 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-    func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
-      %c4 = arith.constant 4 : index
-      %c64 = arith.constant 64 : index
-      %c0 = arith.constant 0 : index
-      %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
-        %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
-        %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
-        %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
-        scf.forall.in_parallel {
-          tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
-          tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
-        }
+  func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+        tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
       }
-      %1 = tensor.empty() : tensor<64x64xf32>
-      %2 = tensor.empty() : tensor<64x64xf32>
-      %3 = tensor.empty() : tensor<64x64xf32>
-      %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) {
-      ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32):
-        %6 = arith.mulf %in, %in_0 : f32
-        %7 = arith.subf %out, %6 : f32
-        %8 = arith.addf %out_1, %in : f32
-        linalg.yield %7, %8 : f32, f32
-      } -> (tensor<64x64xf32>, tensor<64x64xf32>)
-      %5 = tensor.empty() : tensor<2048xf32>
-      %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32>
-      return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32>
     }
+    %1 = tensor.empty() : tensor<64x64xf32>
+    %2 = tensor.empty() : tensor<64x64xf32>
+    %3 = tensor.empty() : tensor<64x64xf32>
+    %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32):
+      %6 = arith.mulf %in, %in_0 : f32
+      %7 = arith.subf %out, %6 : f32
+      %8 = arith.addf %out_1, %in : f32
+      linalg.yield %7, %8 : f32, f32
+    } -> (tensor<64x64xf32>, tensor<64x64xf32>)
+    %5 = tensor.empty() : tensor<2048xf32>
+    %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32>
+    return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32>
+  }
 }
 
 module attributes {transform.with_named_sequence} {
@@ -269,38 +269,38 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-    func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
-        %c4 = arith.constant 4 : index
-        %c64 = arith.constant 64 : index
-        %c0 = arith.constant 0 : index
-        %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
-            %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
-            %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
-                ^bb0(%in: f32, %in_16: f32, %out: f32):
-                %13 = arith.mulf %in, %in_16 : f32
-                %14 = arith.addf %out, %13 : f32
-                linalg.yield %14 : f32
-            } -> tensor<32x32xf32>
-            scf.forall.in_parallel {
-                tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
-            }
-        }
-        %output = tensor.empty() : tensor<2048xf32>
-        %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
-        return %unpack : tensor<2048xf32>
+  func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+      }
     }
+    %output = tensor.empty() : tensor<2048xf32>
+    %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
+    return %unpack : tensor<2048xf32>
+  }
 }
-  
+
 module attributes {transform.with_named_sequence} {
-    transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-        %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
-        : (!transform.any_op) -> !transform.any_op
-        %loop = transform.structured.match ops{["scf.forall"]} in %arg1
-        : (!transform.any_op) -> !transform.any_op
-        %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
-        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-        transform.yield
-    }
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
 }
 //  CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
 //  CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
@@ -332,38 +332,38 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-    func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
-        %c4 = arith.constant 4 : index
-        %c64 = arith.constant 64 : index
-        %c0 = arith.constant 0 : index
-        %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
-            %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
-            %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
-                ^bb0(%in: f32, %in_16: f32, %out: f32):
-                %13 = arith.mulf %in, %in_16 : f32
-                %14 = arith.addf %out, %13 : f32
-                linalg.yield %14 : f32
-            } -> tensor<32x32xf32>
-            scf.forall.in_parallel {
-                tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
-            }
-        }
-        %output = tensor.empty() : tensor<2047xf32>
-        %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
-        return %unpack : tensor<2047xf32>
+  func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+      }
     }
+    %output = tensor.empty() : tensor<2047xf32>
+    %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
+    return %unpack : tensor<2047xf32>
+  }
 }
-  
+
 module attributes {transform.with_named_sequence} {
-    transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-        %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
-        : (!transform.any_op) -> !transform.any_op
-        %loop = transform.structured.match ops{["scf.forall"]} in %arg1
-        : (!transform.any_op) -> !transform.any_op
-        %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
-        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-        transform.yield
-    }
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
 }
 //  CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
 //  CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
@@ -395,38 +395,38 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-    func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
-        %c4 = arith.constant 4 : index
-        %c64 = arith.constant 64 : index
-        %c0 = arith.constant 0 : index
-        %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
-            %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
-            %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
-                ^bb0(%in: f32, %in_16: f32, %out: f32):
-                %13 = arith.mulf %in, %in_16 : f32
-                %14 = arith.addf %out, %13 : f32
-                linalg.yield %14 : f32
-            } -> tensor<32x32xf32>
-            scf.forall.in_parallel {
-                tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
-            }
-        }
-        %output = tensor.empty() : tensor<4x32x16xf32>
-        %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
-        return %pack : tensor<4x32x16xf32>
+  func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+      }
     }
+    %output = tensor.empty() : tensor<4x32x16xf32>
+    %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
+    return %pack : tensor<4x32x16xf32>
+  }
 }
-  
+
 module attributes {transform.with_named_sequence} {
-    transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-        %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
-        : (!transform.any_op) -> !transform.any_op
-        %loop = transform.structured.match ops{["scf.forall"]} in %arg1
-        : (!transform.any_op) -> !transform.any_op
-        %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
-        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-        transform.yield
-    }
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %loop =...
[truncated]

@hanhanW
Copy link
Contributor Author

hanhanW commented Jul 15, 2025

To reviewers, you can use "hide whitespace" feature to skip the diff about formatting.

image

@hanhanW hanhanW force-pushed the improve-pack-fusion branch from 280d3e0 to 33b3812 Compare July 15, 2025 23:56
Copy link

github-actions bot commented Jul 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@hanhanW hanhanW force-pushed the improve-pack-fusion branch 2 times, most recently from 044102c to c54d04d Compare July 16, 2025 00:30
If a dimension is not tiled, it is always valid to to fuse the pack op
even if it has padding semantics. Because it always generates a full
slice along the dimension.

Signed-off-by: hanhanW <[email protected]>
Copy link
Contributor

@egebeysel egebeysel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but I'll leave it to another reviewer that has a better overview to approve :)

}
%1 = tensor.empty() : tensor<23x32x3x16xf32>
%cst = arith.constant 0.000000e+00 : f32
%pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I'm not sure I understand what the result shape is demonstrating here, why is it 23x 32 x3x16 and not 23x 2 x3x16?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case is a bit iffy in general but thanks to that it caught my attention too 🙃
(Side note: the two iterations of this loop overwrite the same output in parallel).

After fusion, the iteration space doesn't span the whole 32 output dims.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I think I had a fat finger, it should be 2 in this case. I update the logic a bit, which makes sure that the fusion does not happen in this case.

I found that the other test used wrong number of thread, so I fixed it as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to think about the imperfect tiling case, we might need the check. Let me play a bit with the IR and I'll add a corresponding test.

@rengolin rengolin requested a review from adam-smnk July 17, 2025 13:09
}
%1 = tensor.empty() : tensor<23x32x3x16xf32>
%cst = arith.constant 0.000000e+00 : f32
%pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case is a bit iffy in general but thanks to that it caught my attention too 🙃
(Side note: the two iterations of this loop overwrite the same output in parallel).

After fusion, the iteration space doesn't span the whole 32 output dims.

@adam-smnk
Copy link
Contributor

Maybe also a new integration test for extra verification?

Comment on lines -892 to -896
// It is not trivial to infer dest tile from source tile if `packOp` has
// padding semantic.
if (packOp.getPaddingValue())
return failure();

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this part getting removed? I don't see any updates to the logic about the padding value in this function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should live in getTiledImplementationFromOperandTiles, but I did not spot it in my review. The idea is that if the dimension is not tiled and the padding is not needed, the fusion is allowed. Originally, I only added "not tiled" logic in the first commit and thought that *cstTileSize % *cstInnerSize != 0 is the check. Now I have more descriptive variable and comments for checking the padding semantic.

hanhanW added 2 commits July 17, 2025 11:15
Signed-off-by: hanhanW <[email protected]>
Copy link
Contributor Author

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed the issue and added more tests, please take a look. Thanks!

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix
LGTM

Signed-off-by: hanhanW <[email protected]>
Copy link
Contributor

@egebeysel egebeysel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic LGTM. A nit and a comment about the test cases though :)

// It is valid to fuse the pack if the dimension is not tiled even when it needs
// extra padding.

func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this test case testing something different than the above one? shapes are different but they're testing the same thing I believe, both of them have dimensions that need padding but are not tiled: 23x3. vs 33x3.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original size is 64, so 23x3 has the minimal number of padding elements; 33x3 requires extra padding elements. This is why I name it untiled_extra_padding.

It is similar to nofuse_pack_consumer_with_extra_padding, but the extra padding happens on an un-tiled dimension.

Copy link
Contributor

@egebeysel egebeysel Jul 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I get the point, that makes sense. But some remarks/questions:

  • I think the minimum number of elements would be 22x3=66 in that case.
  • If you're distinguishing between those cases, I would expect the minimal number of padding elements case to be tiled and still be valid. If that's not the case, what difference does it make that they need different number of padding elements? The output shape corresponds to the untiled op in both cases anyways. Maybe I'm missing something here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I mis-calculate it. Yes, it should be 22x3=66. It is rare to see extra padding case, but some people use it like pad. I'm -1 on the use though. I'll update the value to 22.

I would expect the minimal number of padding elements case to be tiled and still be valid.

This is not possible to fuse because only perfect tiling is allowed atm.

hanhanW added 2 commits July 17, 2025 14:32
Signed-off-by: hanhanW <[email protected]>
Signed-off-by: hanhanW <[email protected]>
@hanhanW hanhanW merged commit 6ff4718 into llvm:main Jul 17, 2025
9 checks passed
@hanhanW hanhanW deleted the improve-pack-fusion branch July 17, 2025 23:06
hanhanW added a commit to iree-org/llvm-project that referenced this pull request Jul 18, 2025
If a dimension is not tiled, it is always valid to fuse the pack op,
even if it has padding semantics. Because it always generates a full
slice along the dimension.

If a dimension is tiled and it does not need extra padding, the fusion
is valid.

The revision also formats corresponding tests for consistency.

---------

Signed-off-by: hanhanW <[email protected]>
hanhanW added a commit to iree-org/llvm-project that referenced this pull request Jul 19, 2025
If a dimension is not tiled, it is always valid to fuse the pack op,
even if it has padding semantics. Because it always generates a full
slice along the dimension.

If a dimension is tiled and it does not need extra padding, the fusion
is valid.

The revision also formats corresponding tests for consistency.

---------

Signed-off-by: hanhanW <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants