Skip to content

Commit 3642d25

Browse files
committed
[mlir][linalg] Improve linalg.pack consumer fusion.
If a dimension is not tiled, it is always valid to fuse the pack op even if it has padding semantics, because it always generates a full slice along the dimension. Signed-off-by: hanhanW <[email protected]>
1 parent 66e707e commit 3642d25

File tree

2 files changed

+184
-135
lines changed

2 files changed

+184
-135
lines changed

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -887,26 +887,13 @@ struct PackOpTiling
887887

888888
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
889889
ArrayRef<OpFoldResult> sizes(allSizes[0]);
890-
891890
auto packOp = cast<PackOp>(op);
892-
// It is not trivial to infer dest tile from source tile if `packOp` has
893-
// padding semantic.
894-
if (packOp.getPaddingValue())
895-
return failure();
896-
897891
Location loc = packOp.getLoc();
898-
899892
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
900893
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
901894
packOp.getDimAndTileMapping();
902895
for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
903896
if (dimAndTileMapping.count(dim)) {
904-
FailureOr<int64_t> cstSize =
905-
ValueBoundsConstraintSet::computeConstantBound(
906-
presburger::BoundType::UB, sizes[dim],
907-
/*stopCondition=*/nullptr, /*closedUB=*/true);
908-
std::optional<int64_t> cstInnerSize =
909-
getConstantIntValue(dimAndTileMapping[dim]);
910897
// Currently fusing `packOp` as consumer only expects perfect tiling
911898
// scenario because even if without padding semantic, the `packOp` may
912899
// also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
@@ -916,12 +903,25 @@ struct PackOpTiling
916903
// (0,0)~(0,4) at first row.
917904
// 2. the second slice is extracted from (5) to (9) and SHOULD BE
918905
// respectively inserted into two rows with different length, including
919-
// first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate
920-
// them, thus adding below constraint to bypass them temporarily. In
921-
// another word, we can only support tiling with consumer if the tile
922-
// size for the producer is a multiple of the inner tile size for the
923-
// packed dimensions at this moment.
924-
if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
906+
// first row: (0,5) and second row (1,0)~(1,3).
907+
// It is hard to coordinate them, thus adding below constraint to bypass
908+
// them temporarily. In other words, we can only support tiling with
909+
// consumer if the tile size for the producer is either a multiple of
910+
// the inner tile size for the packed dimensions or the dimension is not
911+
// tiled at this moment.
912+
FailureOr<int64_t> cstTileSize =
913+
ValueBoundsConstraintSet::computeConstantBound(
914+
presburger::BoundType::UB, sizes[dim],
915+
/*stopCondition=*/nullptr, /*closedUB=*/true);
916+
std::optional<int64_t> cstInnerSize =
917+
getConstantIntValue(dimAndTileMapping[dim]);
918+
int64_t dimSize = packOp.getSourceType().getDimSize(dim);
919+
// TODO: It could be untiled if the `dimSize` is dynamic. It is a hard
920+
// check to determine if a dimension is tiled or not.
921+
bool isTiled = failed(cstTileSize) || ShapedType::isDynamic(dimSize) ||
922+
cstTileSize.value() != dimSize;
923+
if (isTiled && (failed(cstTileSize) || !cstInnerSize ||
924+
*cstTileSize % *cstInnerSize != 0)) {
925925
return failure();
926926
}
927927

@@ -988,7 +988,8 @@ struct PackOpTiling
988988
loc, packOp.getDest(), outputOffsets, outputSizes, strides);
989989
tiledOperands.push_back(outSlice);
990990

991-
assert(!packOp.getPaddingValue() && "Expect no padding semantic");
991+
if (auto val = packOp.getPaddingValue())
992+
tiledOperands.push_back(val);
992993
for (auto tile : packOp.getInnerTiles())
993994
tiledOperands.push_back(tile);
994995

0 commit comments

Comments
 (0)