[mlir][linalg] Improve linalg.pack consumer fusion. #148993
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes: If a dimension is not tiled, it is always valid to fuse the pack op, even if it has padding semantics, because the pack op always generates a full slice along that dimension. The revision also formats the corresponding tests for consistency.

Patch is 25.76 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/148993.diff

2 Files Affected:
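The per-dimension fusibility condition described in the change can be sketched as a small predicate. This is illustrative only; `can_fuse_pack_dim` is a hypothetical helper, not part of the MLIR API, and it only models the constant-size case.

```python
def can_fuse_pack_dim(dim_size, tile_size, inner_tile):
    """Per-dimension check mirroring the PR's fusion logic (illustrative).

    dim_size:   static size of the source dimension
    tile_size:  constant tile size used by the tiled loop (None if unknown)
    inner_tile: inner tile size of linalg.pack for this dim (None if not packed)
    """
    if inner_tile is None:
        return True  # dimension is not packed; nothing to check
    if tile_size is None:
        return False  # cannot prove anything without a constant bound
    if tile_size == dim_size:
        # Dimension is not tiled: a full slice is always produced, so fusion
        # is valid even when the pack op has padding semantics.
        return True
    # Otherwise only perfect tiling is supported: the producer's tile size
    # must be a multiple of the inner tile size.
    return tile_size % inner_tile == 0
```

For example, an untiled dim of size 64 with inner tile 3 is fusible even with padding, while tiling it by 32 is rejected because 32 is not a multiple of 3.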
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 513cecef29b61..7a2931fd6d645 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -887,26 +887,13 @@ struct PackOpTiling
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
ArrayRef<OpFoldResult> sizes(allSizes[0]);
-
auto packOp = cast<PackOp>(op);
- // It is not trivial to infer dest tile from source tile if `packOp` has
- // padding semantic.
- if (packOp.getPaddingValue())
- return failure();
-
Location loc = packOp.getLoc();
-
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
if (dimAndTileMapping.count(dim)) {
- FailureOr<int64_t> cstSize =
- ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, sizes[dim],
- /*stopCondition=*/nullptr, /*closedUB=*/true);
- std::optional<int64_t> cstInnerSize =
- getConstantIntValue(dimAndTileMapping[dim]);
// Currently fusing `packOp` as consumer only expects perfect tiling
// scenario because even if without padding semantic, the `packOp` may
// also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
@@ -916,12 +903,23 @@ struct PackOpTiling
// (0,0)~(0,4) at first row.
// 2. the second slice is extracted from (5) to (9) and SHOULD BE
// respectively inserted into two rows with different length, including
- // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate
- // them, thus adding below constraint to bypass them temporarily. In
- // another word, we can only support tiling with consumer if the tile
- // size for the producer is a multiple of the inner tile size for the
- // packed dimensions at this moment.
- if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
+ // first row: (0,5) and second row (1,0)~(1,3).
+ // It is hard to coordinate them, thus the constraint below is added to
+ // bypass them temporarily. In other words, at this moment we only support
+ // tiling with the consumer if the tile size for the producer is a
+ // multiple of the inner tile size for the packed dimensions, or the
+ // dimension is not tiled.
+ FailureOr<int64_t> cstTileSize =
+ ValueBoundsConstraintSet::computeConstantBound(
+ presburger::BoundType::UB, sizes[dim],
+ /*stopCondition=*/nullptr, /*closedUB=*/true);
+ std::optional<int64_t> cstInnerSize =
+ getConstantIntValue(dimAndTileMapping[dim]);
+ std::optional<int64_t> cstDimSize = getConstantIntValue(sizes[dim]);
+ bool isTiled = failed(cstTileSize) || !cstDimSize ||
+ cstTileSize.value() != cstDimSize.value();
+ if (isTiled && (failed(cstTileSize) || !cstInnerSize ||
+ *cstTileSize % *cstInnerSize != 0)) {
return failure();
}
@@ -988,7 +986,8 @@ struct PackOpTiling
loc, packOp.getDest(), outputOffsets, outputSizes, strides);
tiledOperands.push_back(outSlice);
- assert(!packOp.getPaddingValue() && "Expect no padding semantic");
+ if (auto val = packOp.getPaddingValue())
+ tiledOperands.push_back(val);
for (auto tile : packOp.getInnerTiles())
tiledOperands.push_back(tile);
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index d09373bdb3f14..d51621c18fd54 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -193,33 +193,33 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
- %c4 = arith.constant 4 : index
- %c64 = arith.constant 64 : index
- %c0 = arith.constant 0 : index
- %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
- %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
- %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
- %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
- tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
- }
+ func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
+ %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+ %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+ tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
}
- %1 = tensor.empty() : tensor<64x64xf32>
- %2 = tensor.empty() : tensor<64x64xf32>
- %3 = tensor.empty() : tensor<64x64xf32>
- %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32):
- %6 = arith.mulf %in, %in_0 : f32
- %7 = arith.subf %out, %6 : f32
- %8 = arith.addf %out_1, %in : f32
- linalg.yield %7, %8 : f32, f32
- } -> (tensor<64x64xf32>, tensor<64x64xf32>)
- %5 = tensor.empty() : tensor<2048xf32>
- %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32>
- return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32>
}
+ %1 = tensor.empty() : tensor<64x64xf32>
+ %2 = tensor.empty() : tensor<64x64xf32>
+ %3 = tensor.empty() : tensor<64x64xf32>
+ %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32):
+ %6 = arith.mulf %in, %in_0 : f32
+ %7 = arith.subf %out, %6 : f32
+ %8 = arith.addf %out_1, %in : f32
+ linalg.yield %7, %8 : f32, f32
+ } -> (tensor<64x64xf32>, tensor<64x64xf32>)
+ %5 = tensor.empty() : tensor<2048xf32>
+ %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32>
+ return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32>
+ }
}
module attributes {transform.with_named_sequence} {
@@ -269,38 +269,38 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
- %c4 = arith.constant 4 : index
- %c64 = arith.constant 64 : index
- %c0 = arith.constant 0 : index
- %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
- %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
- %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
- ^bb0(%in: f32, %in_16: f32, %out: f32):
- %13 = arith.mulf %in, %in_16 : f32
- %14 = arith.addf %out, %13 : f32
- linalg.yield %14 : f32
- } -> tensor<32x32xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
- }
- }
- %output = tensor.empty() : tensor<2048xf32>
- %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
- return %unpack : tensor<2048xf32>
+ func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
}
+ %output = tensor.empty() : tensor<2048xf32>
+ %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
+ return %unpack : tensor<2048xf32>
+ }
}
-
+
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %loop = transform.structured.match ops{["scf.forall"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
- : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
}
// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
@@ -332,38 +332,38 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
- %c4 = arith.constant 4 : index
- %c64 = arith.constant 64 : index
- %c0 = arith.constant 0 : index
- %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
- %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
- %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
- ^bb0(%in: f32, %in_16: f32, %out: f32):
- %13 = arith.mulf %in, %in_16 : f32
- %14 = arith.addf %out, %13 : f32
- linalg.yield %14 : f32
- } -> tensor<32x32xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
- }
- }
- %output = tensor.empty() : tensor<2047xf32>
- %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
- return %unpack : tensor<2047xf32>
+ func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
}
+ %output = tensor.empty() : tensor<2047xf32>
+ %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
+ return %unpack : tensor<2047xf32>
+ }
}
-
+
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %loop = transform.structured.match ops{["scf.forall"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
- : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
}
// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
@@ -395,38 +395,38 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
- %c4 = arith.constant 4 : index
- %c64 = arith.constant 64 : index
- %c0 = arith.constant 0 : index
- %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
- %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
- %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
- ^bb0(%in: f32, %in_16: f32, %out: f32):
- %13 = arith.mulf %in, %in_16 : f32
- %14 = arith.addf %out, %13 : f32
- linalg.yield %14 : f32
- } -> tensor<32x32xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
- }
- }
- %output = tensor.empty() : tensor<4x32x16xf32>
- %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
- return %pack : tensor<4x32x16xf32>
+ func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
}
+ %output = tensor.empty() : tensor<4x32x16xf32>
+ %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
+ return %pack : tensor<4x32x16xf32>
+ }
}
-
+
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %loop = transform.structured.match ops{["scf.forall"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
- : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop =...
[truncated]
Force-pushed from 280d3e0 to 33b3812.
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from 044102c to c54d04d.
If a dimension is not tiled, it is always valid to fuse the pack op even if it has padding semantics, because it always generates a full slice along the dimension. Signed-off-by: hanhanW <[email protected]>
Force-pushed from c54d04d to 3642d25.
LGTM, but I'll leave it to another reviewer that has a better overview to approve :)
  }
  %1 = tensor.empty() : tensor<23x32x3x16xf32>
  %cst = arith.constant 0.000000e+00 : f32
  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
nit: I'm not sure I understand what the result shape is demonstrating here. Why is it 23x32x3x16 and not 23x2x3x16?
This test case is a bit iffy in general, but thanks to that it caught my attention too 🙃
(Side note: the two iterations of this loop overwrite the same output in parallel.)
After fusion, the iteration space doesn't span the whole 32 output dims.
Good catch! I think I had a fat finger; it should be 2 in this case. I updated the logic a bit, which makes sure that the fusion does not happen in this case.
I also found that the other test used the wrong number of threads, so I fixed it as well.
I need to think about the imperfect tiling case; we might need the check. Let me play a bit with the IR and I'll add a corresponding test.
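As a sanity check for the shapes being discussed, the minimal outer dims of a linalg.pack result are a ceiling division of each packed source dim by its inner tile. This is illustrative Python; `packed_outer_shape` is a hypothetical helper, and a pack with a padding value may legally use a larger outer size than this minimum.

```python
import math

def packed_outer_shape(src_shape, inner_dims_pos, inner_tiles):
    """Minimal result shape of linalg.pack: outer dims are
    ceil(src_dim / inner_tile), followed by the inner tile dims."""
    outer = list(src_shape)
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        outer[pos] = math.ceil(src_shape[pos] / tile)
    return outer + list(inner_tiles)

# tensor<64x32xf32> with inner_dims_pos = [0, 1], inner_tiles = [3, 16]:
# the minimal result shape is 22x2x3x16; the test deliberately uses a
# larger outer size (23) to exercise padding.
```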
(Resolved review thread on mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir, outdated.)
Maybe also a new integration test for extra verification?
// It is not trivial to infer dest tile from source tile if `packOp` has
// padding semantic.
if (packOp.getPaddingValue())
  return failure();
Why is this part getting removed? I don't see any updates to the logic about the padding value in this function.
I think it should live in getTiledImplementationFromOperandTiles, but I did not spot it in my review. The idea is that if the dimension is not tiled and the padding is not needed, the fusion is allowed. Originally, I only added the "not tiled" logic in the first commit and thought that *cstTileSize % *cstInnerSize != 0 was the check. Now I have a more descriptive variable and comments for checking the padding semantics.
Signed-off-by: hanhanW <[email protected]>
I fixed the issue and added more tests, please take a look. Thanks!
Thanks for the fix
LGTM
The logic LGTM. A nit and a comment about the test cases though :)
(Resolved review thread on mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir, outdated.)
// It is valid to fuse the pack if the dimension is not tiled even when it needs
// extra padding.

func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
Is this test case testing something different than the one above? The shapes are different, but I believe they're testing the same thing: both of them have dimensions that need padding but are not tiled, 23x3 vs 33x3.
The original size is 64, so 23x3 has the minimal number of padding elements; 33x3 requires extra padding elements. This is why I named it untiled_extra_padding.
It is similar to nofuse_pack_consumer_with_extra_padding, but the extra padding happens on an un-tiled dimension.
Ok, I get the point, that makes sense. But some remarks/questions:
- I think the minimum number of elements would be 22x3=66 in that case.
- If you're distinguishing between those cases, I would expect the minimal-padding case to be tiled and still be valid. If that's not the case, what difference does it make that they need a different number of padding elements? The output shape corresponds to the untiled op in both cases anyway. Maybe I'm missing something here.
Oh, I miscalculated. Yes, it should be 22x3=66. It is rare to see the extra-padding case, but some people use it like pad. I'm -1 on that use though. I'll update the value to 22.

Regarding "I would expect the minimal-padding case to be tiled and still be valid": this is not possible to fuse because only perfect tiling is allowed atm.
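The padding counts in this thread can be verified with a quick ceiling-division calculation. This is illustrative arithmetic only; `padding_elements` is a hypothetical helper.

```python
import math

def padding_elements(dim_size, inner_tile, outer_size=None):
    """Number of padded elements along one packed dimension.

    outer_size defaults to the minimal ceil(dim_size / inner_tile);
    a larger outer_size models the 'extra padding' case discussed above.
    """
    if outer_size is None:
        outer_size = math.ceil(dim_size / inner_tile)
    return outer_size * inner_tile - dim_size

# dim 64, inner tile 3: minimal outer size is 22 -> 22*3 - 64 = 2 pad elements.
# Outer size 23 gives 23*3 - 64 = 5; outer size 33 gives 33*3 - 64 = 35,
# i.e. far more padding than the minimum.
```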
If a dimension is not tiled, it is always valid to fuse the pack op, even if it has padding semantics. Because it always generates a full slice along the dimension. If a dimension is tiled and it does not need extra padding, the fusion is valid. The revision also formats corresponding tests for consistency. --------- Signed-off-by: hanhanW <[email protected]>