Commit 3517083

Merge branch 'main' into xurui/add_benchmark

2 parents 271e0c3 + 6c23c8c

File tree: 2 files changed, +130 -52 lines


lib/gc/Transforms/IterativeTilingAndFusion.cpp (+68 -43)
@@ -8,6 +8,7 @@
 
 #include "gc/Analysis/TargetDescriptionAnalysis.h"
 #include "gc/Dialect/Linalgx/LinalgxOps.h"
+#include "gc/Dialect/Linalgx/Utils.h"
 #include "gc/Transforms/Passes.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/DLTI/Traits.h"
@@ -166,11 +167,10 @@ exactTilingOnPackUnPackFilter(RewriterBase &rewriter,
       tileSizesOnInnerDims =
           llvm::to_vector(ArrayRef(tileSizes).take_back(innerTiles.size()));
     } else {
-      // tileSize comes from OpOperand
-      ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
-      for (auto &pos : innerDimPos) {
-        tileSizesOnInnerDims.push_back(tileSizes[pos]);
-      }
+      // Upstream doesn't implement the `getTiledImplementationFromOperandTile`
+      // interface for `packOp` so far. In other words, `packOp` cannot be
+      // fused as a consumer, so just return failure for now.
+      return failure();
     }
   } else if (auto unPackOp = dyn_cast<tensor::UnPackOp>(defOrUse.ownerOp)) {
     innerTiles = unPackOp.getMixedTiles();
@@ -215,7 +215,7 @@ nonContractionOpFilter(RewriterBase &rewriter,
                        CandidateDefOrUse defOrUse) {
   // Currently this pass focuses on fine-grained fusion, which does not expect
   // two consecutive contraction ops.
-  return failure(isa<mlir::linalg::ContractionOpInterface>(defOrUse.ownerOp));
+  return failure(linalgx::isMatmulOp(defOrUse.ownerOp));
 }
 
 /// If fusing multiple consumers is allowed, there may exist following cases:
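(The predicate swap is wider than it looks: `isa<mlir::linalg::ContractionOpInterface>` only matches ops implementing the contraction interface, while `linalgx::isMatmulOp`, declared in the newly included `gc/Dialect/Linalgx/Utils.h`, presumably also recognizes matmuls written as plain `linalg.generic`; the new `fuse_generic_matmul` test at the bottom of this commit exercises exactly that case.)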
@@ -635,29 +635,36 @@ static LogicalResult isTiledOpInLoop(Operation *targetOp) {
   // 3. check whether has either extract or insert slice op
   auto walkResult = forOp->walk(
       [](tensor::ExtractSliceOp) { return WalkResult::interrupt(); });
-  if (walkResult.wasInterrupted())
-    return success();
-  walkResult = forOp->walk(
-      [](tensor::InsertSliceOp) { return WalkResult::interrupt(); });
+  if (!walkResult.wasInterrupted())
+    return failure();
+  walkResult = forOp->walk([](OffsetSizeAndStrideOpInterface op) {
+    return isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(op)
+               ? WalkResult::interrupt()
+               : WalkResult::advance();
+  });
   return success(walkResult.wasInterrupted());
 }
 
 using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t>>;
 
 /// Default Tiling function only effective for certain `OpTy` operation
-template <typename OpTy>
 static FailureOr<scf::SCFTilingResult>
 defaultTilingOfType(RewriterBase &rewriter, Operation *op,
+                    function_ref<bool(Operation *)> isaOpTy,
                     const OpTileSizeMap &tsMap) {
   // a. Check <OpTy>
-  if (!isa<TilingInterface>(op) || !isa<OpTy>(op))
+  if (!isa<TilingInterface>(op) || !isaOpTy(op))
     return failure();
   auto tilingInterfaceOp = cast<TilingInterface>(op);
 
   scf::SCFTilingOptions options;
   // b. Get default tiling size
   SmallVector<utils::IteratorType> iteratorTypes =
       tilingInterfaceOp.getLoopIteratorTypes();
+  llvm::SmallVector<Range> iterationDomain =
+      tilingInterfaceOp.getIterationDomain(rewriter);
+  assert(iteratorTypes.size() == iterationDomain.size() &&
+         "Iteration domain expected to be as long as the iterator types");
 
   SmallVector<OpFoldResult> defaultTileSize;
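The inverted control flow in `isTiledOpInLoop` above is easy to misread, so here is a compact restatement as a hypothetical standalone helper (`looksTiled` is an illustrative name, not code from this patch; it assumes the usual MLIR Tensor dialect and ViewLikeInterface headers): previously, finding either an extract_slice or an insert_slice was enough; now a loop only counts as tiled if it contains both an extract_slice and an insert-like slice op, including `tensor.parallel_insert_slice`.

    // Hypothetical reading aid, not part of the patch.
    static bool looksTiled(mlir::Operation *forOp) {
      // A read-side slice must be present...
      bool hasExtract = forOp->walk([](mlir::tensor::ExtractSliceOp) {
                               return mlir::WalkResult::interrupt();
                             }).wasInterrupted();
      // ...and a write-side (insert-like) slice as well.
      bool hasInsertLike =
          forOp->walk([](mlir::OffsetSizeAndStrideOpInterface op) {
                 return mlir::isa<mlir::tensor::InsertSliceOp,
                                  mlir::tensor::ParallelInsertSliceOp>(op)
                            ? mlir::WalkResult::interrupt()
                            : mlir::WalkResult::advance();
               }).wasInterrupted();
      return hasExtract && hasInsertLike;
    }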
@@ -671,15 +678,40 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
         getAsOpFoldResult(rewriter.getI64ArrayAttr(userDefaultTileSize));
   } else {
     defaultTileSize.resize(iteratorTypes.size(), rewriter.getIndexAttr(0));
+    // Try tile sizes from `32` down to `16`.
+    SmallVector<int64_t> tsOrder = {32, 16};
+    // Only a 2D tile is expected.
+    int tileDims = (isa<mlir::linalg::LinalgOp>(op) && !linalgx::isMatmulOp(op))
+                       ? cast<mlir::linalg::LinalgOp>(op).getNumReductionLoops()
+                       : 0;
+    // Reverse both the iterator types and the iteration domain, from inner
+    // to outer.
+    std::reverse(iteratorTypes.begin(), iteratorTypes.end());
+    std::reverse(iterationDomain.begin(), iterationDomain.end());
+
     for (auto &&[en, iterType] : llvm::enumerate(iteratorTypes)) {
-      // All outer non reduction loop should contribute parallelism. In another
-      // word, all reduction dimensions should not be tiled.
-      if (iterType == utils::IteratorType::parallel &&
-          (en != iteratorTypes.size() - 1 ||
-           llvm::count(iteratorTypes, utils::IteratorType::reduction)))
-        defaultTileSize[en] = rewriter.getIndexAttr(1);
+      // Every parallel iterator will be tiled by `32` or `16`. If specific
+      // sizes are needed, set the `defaultTileSize` option, e.g. `matmul:{64,64}`.
+      if (iterType == utils::IteratorType::parallel) {
+        Range curDomain = iterationDomain[en];
+        std::optional<int64_t> tripCount = mlir::constantTripCount(
+            curDomain.offset, curDomain.size, curDomain.stride);
+        if (tileDims >= 2 && en > 0) {
+          defaultTileSize[en] = rewriter.getIndexAttr(1);
+          continue;
+        } else if (tripCount) {
+          for (auto &ts : tsOrder) {
+            if (*tripCount % ts == 0 && *tripCount > ts) {
+              defaultTileSize[en] = rewriter.getIndexAttr(ts);
+              break;
+            }
+          }
+        }
+        tileDims++;
+      }
     }
   }
+  // Reverse the default tile sizes back to the original order.
+  std::reverse(defaultTileSize.begin(), defaultTileSize.end());
   // If the tile sizes are all zero, no tiling would happen.
   if (llvm::all_of(defaultTileSize, isZeroIndex))
     return failure();
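Distilled away from the MLIR plumbing, the added branch implements the following divisor rule (a self-contained sketch; `pickDefaultTileSize` is an illustrative name, not a function in this patch):

    #include <cstdint>
    #include <initializer_list>
    #include <optional>

    // Prefer tile size 32, then 16; a candidate is used only if the constant
    // trip count is divisible by it and strictly larger than it. Returning 0
    // means "leave this dimension untiled".
    static int64_t pickDefaultTileSize(std::optional<int64_t> tripCount) {
      if (!tripCount)
        return 0; // dynamic trip count: no default tile size chosen
      for (int64_t ts : {INT64_C(32), INT64_C(16)})
        if (*tripCount % ts == 0 && *tripCount > ts)
          return ts; // e.g. 256 -> 32, 48 -> 16, 16 -> 0
      return 0;
    }

In the pass itself the trip count comes from `mlir::constantTripCount` over each `Range` of the iteration domain, and the surrounding `tileDims` bookkeeping appears to cap non-trivial tiling at two dimensions, with the reduction loops of non-matmul linalg ops counting toward that budget.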
@@ -697,20 +729,6 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
   return tilingResult;
 }
 
-template <typename OpTy1, typename OpTy2, typename... Rest>
-static FailureOr<scf::SCFTilingResult>
-defaultTilingOfType(RewriterBase &rewriter, Operation *op,
-                    const OpTileSizeMap &tsMap) {
-  FailureOr<scf::SCFTilingResult> tilingResult =
-      defaultTilingOfType<OpTy1>(rewriter, op, tsMap);
-  if (failed(tilingResult))
-    return defaultTilingOfType<OpTy2, Rest...>(rewriter, op, tsMap);
-  return tilingResult;
-}
-
-using DefaultTilingFn = std::function<FailureOr<scf::SCFTilingResult>(
-    RewriterBase &, Operation *, const OpTileSizeMap &)>;
-
 void iterativeTilingAndFusionUntilExhaustion(
     RewriterBase &rewriter, func::FuncOp &f,
     const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) {
@@ -738,9 +756,8 @@ void iterativeTilingAndFusionUntilExhaustion(
 
   // Walk through funcOp
   f->walk([&tiledOps](Operation *op) {
-    if (succeeded(isTiledOpInLoop(op))) {
+    if (succeeded(isTiledOpInLoop(op)))
       tiledOps.insert(op);
-    }
   });
 
   // Iterative tiling and fusion until exhaustion.
@@ -764,17 +781,25 @@ void iterativeTilingAndFusionUntilExhaustion(
     // Auto tiling with default tile size if no tiled op found. Follow tiling
     // priority based on OpTy:
    // `ContractionOp`->`ReductionOp`->`LinalgOp`->`TensorOp`.
-    SmallVector<DefaultTilingFn> priorityTilingPipeLine = {
-        defaultTilingOfType<mlir::linalg::ContractionOpInterface>,
-        defaultTilingOfType<mlir::linalg::ReduceOp>,
-        defaultTilingOfType<mlir::linalg::LinalgOp>,
-        defaultTilingOfType<tensor::PackOp, tensor::UnPackOp, tensor::PadOp>,
-        defaultTilingOfType<TilingInterface>};
-
-    for (auto &tilingFn : priorityTilingPipeLine) {
+    SmallVector<std::function<bool(Operation *)>> priorityOpTypeOrder = {
+// Generate a helper function to check `isa<OpTy>`.
+#define GenIsaOpTy(opTy) [](Operation *op) { return opTy(op); }
+        // If ContractionOp
+        GenIsaOpTy(linalgx::isMatmulOp),
+        // If ReduceOp
+        GenIsaOpTy(isa<mlir::linalg::ReduceOp>),
+        // If other LinalgOp
+        GenIsaOpTy(isa<mlir::linalg::LinalgOp>),
+        // If TensorOp
+        GenIsaOpTy((isa<tensor::PackOp, tensor::UnPackOp, tensor::PadOp>)),
+        // Fallback
+        GenIsaOpTy(isa<TilingInterface>)};
+#undef GenIsaOpTy
+    mlir::topologicalSort(unTiledOps);
+    for (auto &isaOpTy : priorityOpTypeOrder) {
       for (auto &op : unTiledOps) {
         FailureOr<scf::SCFTilingResult> tilingResult =
-            tilingFn(rewriter, op, tsMap);
+            defaultTilingOfType(rewriter, op, isaOpTy, tsMap);
         if (succeeded(tilingResult)) {
           tiledOps.insert(tilingResult->tiledOps[0]);
           rewriter.replaceOp(op, tilingResult->replacements);
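The shape of this refactor, stripped of the MLIR specifics (every name below is a stand-in): compile-time dispatch through the deleted recursive variadic template is replaced by one function taking a runtime predicate, driven by an ordered priority list. The `GenIsaOpTy` macro exists because `isa<OpTy>` is a function template rather than an addressable function, so it cannot initialize a `std::function` directly; the generated lambda pins it down to one concrete call.

    #include <cstddef>
    #include <cstdio>
    #include <functional>
    #include <vector>

    struct Op { int kind; }; // stand-in for mlir::Operation

    // Stand-in for defaultTilingOfType(rewriter, op, isaOpTy, tsMap):
    // "tiling" succeeds exactly when the priority predicate accepts the op.
    static bool tryDefaultTiling(Op &op,
                                 const std::function<bool(Op &)> &isaOpTy) {
      return isaOpTy(op);
    }

    int main() {
      // Mirrors priorityOpTypeOrder: contraction -> reduce -> fallback.
      std::vector<std::function<bool(Op &)>> priorityOpTypeOrder = {
          [](Op &op) { return op.kind == 0; }, // "is contraction"
          [](Op &op) { return op.kind == 1; }, // "is reduce"
          [](Op &) { return true; },           // fallback, like TilingInterface
      };
      Op op{1};
      for (std::size_t level = 0; level < priorityOpTypeOrder.size(); ++level)
        if (tryDefaultTiling(op, priorityOpTypeOrder[level])) {
          std::printf("tiled at priority level %zu\n", level); // prints 1
          break;
        }
      return 0;
    }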

test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir (+62 -9)
@@ -15,7 +15,7 @@ module {
     /// CHECK: tensor.empty
     %dest = tensor.empty() : tensor<512x256xbf16>
     %unpack = tensor.unpack %arg1 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %dest : tensor<32x8x16x32xbf16> -> tensor<512x256xbf16>
-    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 2)
+    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 2)
     %2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %1) -> (tensor<128x256xbf16>) {
       %5 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
       %6 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg4)
@@ -105,7 +105,7 @@ module {
     %cst = arith.constant 0.000000e+00 : f32
     %dest0 = tensor.empty() : tensor<256x256xf32>
     %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 2)
+    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 2)
     %1 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %dest1) -> tensor<256x256xf32> {
       %iv0 = affine.apply #map(%arg4)
       %iv1 = affine.apply #map(%arg5)
@@ -157,7 +157,7 @@ module {
     %cst = arith.constant 0.000000e+00 : f32
     %dest0 = tensor.empty() : tensor<256x256xf32>
     %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 1)
+    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 1)
     %1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
       %iv0 = affine.apply #map(%arg3)
       %iv1 = affine.apply #map(%arg4)
@@ -205,12 +205,12 @@ module {
     %dest0 = tensor.empty() : tensor<128x256x256xf32>
     %0 = linalg.add ins(%arg0, %arg1 : tensor<128x256x256xf32>, tensor<128x256x256xf32>) outs(%dest0 : tensor<128x256x256xf32>) -> tensor<128x256x256xf32>
     %dest1 = tensor.empty() : tensor<128x256xf32>
-    /// CHECK: %[[FINAL_RESULT:.*]] = scf.forall (%{{.*}}, %{{.*}}) in (128, 256)
-    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
-    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
-    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
+    /// CHECK: %[[FINAL_RESULT:.*]] = scf.forall (%{{.*}}) = (0, 0) to (128, 256) step (1, 32)
+    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 32] [1, 1, 1]
+    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 32] [1, 1, 1]
+    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 32] [1, 1, 1]
     /// CHECK: %[[ADD_OUT:.*]] = linalg.add
-    /// CHECK: tensor.extract_slice {{.*}} [1, 1] [1, 1]
+    /// CHECK: tensor.extract_slice {{.*}} [1, 32] [1, 1]
     /// CHECK: %[[REDUCE_OUT:.*]] = linalg.reduce { arith.addf } ins(%[[ADD_OUT]] :
     %1 = linalg.reduce { arith.addf } ins(%0 : tensor<128x256x256xf32>) outs(%dest1 : tensor<128x256xf32>) dimensions = [1]
     /// CHECK: scf.forall.in_parallel
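The updated CHECK lines here follow arithmetically from the new heuristic above: for the `linalg.reduce` over `tensor<128x256x256xf32>`, the innermost parallel dimension has a constant trip count of 256, and 256 % 32 == 0 with 256 > 32, so it gets tile size 32 (hence the slice sizes ending in 32); the reduction dimension stays untiled, and because it already counts toward the two-dimension tile budget, the outer parallel dimension falls back to tile size 1. Since the steps are no longer all 1, `scf.forall` now prints in its explicit `= (0, 0) to (128, 256) step (1, 32)` form rather than `in (128, 256)`.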
@@ -319,7 +319,7 @@ module {
   /// CHECK-LABEL: @fuse_residual_pattern
   func.func @fuse_residual_pattern(%arg0: tensor<128x256x256xf32>, %arg1: tensor<128x256x256xf32>) -> tensor<128x256x256xf32> {
     %dest0 = tensor.empty() : tensor<128x256x256xf32>
-    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (128, 256)
+    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) = (0, 0, 0) to (128, 256, 256) step (1, 32, 32)
     /// CHECK: %[[ADD_OUT:.*]] = linalg.add
     /// CHECK: %[[EXP_OUT:.*]] = linalg.exp ins(%[[ADD_OUT:.*]] :
     /// CHECK: %[[MUL_OUT:.*]] = linalg.mul ins(%[[ADD_OUT:.*]], %[[EXP_OUT:.*]] :
@@ -353,4 +353,57 @@ module {
     /// CHECK: return %[[PACK_OUT]]
     return %pack : tensor<1x1x128x32x32xbf16>
   }
+}
+
+// -----
+
+module {
+  // CHECK: func.func @fuse_generic_matmul(
+  // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+  // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x16x16xf32>
+  // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<4x16x16xf32>
+  func.func @fuse_generic_matmul(%arg0: tensor<32x32xf32>, %arg1: tensor<2x16x16xf32>, %arg2: tensor<4x16x16xf32>) -> tensor<32x64xf32> attributes {llvm.emit_c_interface} {
+    /// CHECK: %[[EMPTY_OUT_0:.*]] = tensor.empty
+    %0 = tensor.empty() : tensor<2x2x16x16xf32>
+    %pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<32x32xf32> -> tensor<2x2x16x16xf32>
+    /// CHECK: %[[EMPTY_OUT_1:.*]] = tensor.empty
+    %1 = tensor.empty() : tensor<2x16x16xf32>
+    /// CHECK: %[[FIRST_MATMUL_OUT:.*]] = scf.forall (%{{.*}}) in (2)
+    /// CHECK: %[[EXTRACT_SLICE_0:.*]] = tensor.extract_slice %[[ARG0]]{{.*}} [16, 32]
+    /// CHECK: %[[EXTRACT_SLICE_1:.*]] = tensor.extract_slice %[[EMPTY_OUT_0]]{{.*}} [1, 2, 16, 16]
+    /// CHECK: %[[PACK_OUT:.*]] = tensor.pack %[[EXTRACT_SLICE_0]]
+    /// CHECK: %[[EXTRACT_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]]{{.*}} [2, 16, 16]
+    /// CHECK: %[[MATMUL_OUT_0:.*]] = linalg.generic {{.*}} ins(%[[PACK_OUT]], %[[EXTRACT_SLICE_2]] :
+    %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %arg1 : tensor<2x2x16x16xf32>, tensor<2x16x16xf32>) outs(%1 : tensor<2x16x16xf32>) {
+    ^bb0(%in: f32, %in_3: f32, %out: f32):
+      %9 = arith.mulf %in, %in_3 : f32
+      %10 = arith.addf %out, %9 : f32
+      linalg.yield %10 : f32
+    } -> tensor<2x16x16xf32>
+    /// CHECK: scf.forall.in_parallel
+    /// CHECK: tensor.parallel_insert_slice
+    /// CHECK: }
+    /// CHECK: %[[EMPTY_OUT_2:.*]] = tensor.empty
+    /// CHECK: %[[EMPTY_OUT_3:.*]] = tensor.empty
+    %3 = tensor.empty() : tensor<2x4x16x16xf32>
+    /// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%{{.*}}) in (2, 4)
+    /// CHECK: %[[EXTRACT_SLICE_3:.*]] = tensor.extract_slice %[[FIRST_MATMUL_OUT]]{{.*}} [1, 16, 16]
+    /// CHECK: %[[EXTRACT_SLICE_4:.*]] = tensor.extract_slice %[[ARG2]]{{.*}} [1, 16, 16]
+    /// CHECK: %[[MATMUL_OUT_1:.*]] = linalg.generic {{.*}} ins(%[[EXTRACT_SLICE_3]], %[[EXTRACT_SLICE_4]] :
+    /// CHECK: %[[UNPACK_OUT:.*]] = tensor.unpack %[[MATMUL_OUT_1]]
+    %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%2, %arg2 : tensor<2x16x16xf32>, tensor<4x16x16xf32>) outs(%3 : tensor<2x4x16x16xf32>) {
+    ^bb0(%in: f32, %in_3: f32, %out: f32):
+      %9 = arith.mulf %in, %in_3 : f32
+      %10 = arith.addf %out, %9 : f32
+      linalg.yield %10 : f32
+    } -> tensor<2x4x16x16xf32>
+    /// CHECK: scf.forall.in_parallel
+    /// CHECK: tensor.parallel_insert_slice
+    /// CHECK: tensor.parallel_insert_slice
+    /// CHECK: }
+    %5 = tensor.empty() : tensor<32x64xf32>
+    %unpack = tensor.unpack %4 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %5 : tensor<2x4x16x16xf32> -> tensor<32x64xf32>
+    /// CHECK: return %[[FINAL_RESULT]]#1
+    return %unpack : tensor<32x64xf32>
+  }
 }
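(The new `fuse_generic_matmul` case ties the two files of this commit together: both matmuls are plain `linalg.generic` ops rather than named contractions, so the pass only treats them as matmuls through `linalgx::isMatmulOp`; the CHECK lines then verify that the `tensor.pack` of `%arg0` is fused as a producer into the first matmul's `scf.forall`, and the final `tensor.unpack` is fused as a consumer into the second one, whose result is returned as `%[[FINAL_RESULT]]#1`.)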
