Skip to content

Commit 6ba9310

Browse files
authored
[Transform][Fusion] Support fusion for K-tiling deep matmul (#217)
Fix several issues found in llama2 MLP workload: 1. Support `if-else` control block. 2. Cover `GenericOp` and `LinalgXOp`. 3. Enhance use-def check for residual pattern.
1 parent b22af8c commit 6ba9310

File tree

4 files changed

+242
-178
lines changed

4 files changed

+242
-178
lines changed

include/gc/Transforms/Passes.td

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,25 +59,18 @@ def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
5959
3. Multi-consumer and multi-producer support.
6060
4. Multiple level of nest loops and candidates.
6161
5. Flexible option to control the boundary of iterative process.
62-
6. Cost-model to determine whether to fuse or not.
63-
64-
It intends to control the granularity of fusion by `fusion-level`, E.g.
65-
* `0`: disable any fusion.
66-
* `1`:[Default] enable both producer and consumer fusion, covering any tilable operation including tensor.pack/tensor.fill/linalg.reduce etc but excluding branches forked by multiple uses.
67-
* `2`: `LEVEL 1` + extend to any topology including branches.
62+
6. Default tiling when no op is tiled before fusion.
63+
7. Cost-model to determine whether to fuse or not.
6864
}];
6965
let dependentDialects = ["func::FuncDialect", "linalg::LinalgDialect", "scf::SCFDialect",
7066
"tensor::TensorDialect"];
7167

7268
let options = [
73-
Option<"fusionLevel", "fusion-level", "int64_t",
74-
/*default=*/"1",
75-
"Control the granularity of fusion.">,
7669
Option<"useCostModel", "use-cost-model", "bool",
7770
/*default=*/"false",
7871
"Decide if enable cost model to control iterative fusion.">,
7972
ListOption<"defaultTileSize", "default-tile-size", "std::string",
80-
"Set default TileSize for the certain type of op, saying matmul:{32,32}">,
73+
"Set default TileSize for the certain type of op, saying `matmul:{32,32}`">,
8174
];
8275
}
8376

lib/gc/Transforms/IterativeTilingAndFusion.cpp

Lines changed: 86 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "gc/Analysis/TargetDescriptionAnalysis.h"
10+
#include "gc/Dialect/Linalgx/LinalgxOps.h"
1011
#include "gc/Transforms/Passes.h"
1112
#include "mlir/Analysis/TopologicalSortUtils.h"
1213
#include "mlir/Dialect/DLTI/Traits.h"
@@ -68,17 +69,22 @@ getClosestInsertSliceOfResult(OpResult result) {
6869
sliceOp =
6970
dyn_cast<OffsetSizeAndStrideOpInterface>(useOfResult.getOwner());
7071
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOfResult.getOwner())) {
71-
if (auto loop = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp()))
72+
if (isa<LoopLikeOpInterface, RegionBranchOpInterface>(
73+
yieldOp->getParentOp())) {
7274
return getClosestInsertSliceOfResult(
73-
loop->getResult(useOfResult.getOperandNumber()));
75+
yieldOp->getParentOp()->getResult(useOfResult.getOperandNumber()));
76+
}
77+
} else if (isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(
78+
useOfResult.getOwner())) {
79+
return getClosestInsertSliceOfResult(
80+
useOfResult.getOwner()->getResult(useOfResult.getOperandNumber()));
7481
}
7582
}
7683

7784
if (!llvm::detail::isPresent(sliceOp))
7885
return failure();
79-
else {
80-
return sliceOp;
81-
}
86+
87+
return sliceOp;
8288
}
8389

8490
struct CandidateDefOrUse {
@@ -155,12 +161,12 @@ exactTilingOnPackUnPackFilter(RewriterBase &rewriter,
155161
if (auto packOp = dyn_cast<tensor::PackOp>(defOrUse.ownerOp)) {
156162
// tileSize comes from OpResult
157163
if (defOrUse.isDef()) {
158-
targetInnerTileSizes = packOp.getInnerTiles();
164+
targetInnerTileSizes = packOp.getMixedTiles();
159165
targetTileSizes = llvm::to_vector(
160166
ArrayRef(tileSizes).take_back(targetInnerTileSizes.size()));
161167
} else {
162168
// tileSize comes from OpOperand
163-
targetTileSizes = llvm::to_vector(tileSizes);
169+
targetTileSizes = tileSizes;
164170
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
165171
packOp.getDimAndTileMapping();
166172
targetInnerTileSizes.resize(dimAndTileMapping.size());
@@ -171,16 +177,18 @@ exactTilingOnPackUnPackFilter(RewriterBase &rewriter,
171177
} else if (auto unPackOp = dyn_cast<tensor::UnPackOp>(defOrUse.ownerOp)) {
172178
// tileSize comes from OpResult
173179
if (defOrUse.isDef()) {
174-
targetTileSizes = llvm::to_vector(tileSizes);
180+
targetTileSizes = tileSizes;
175181
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
176182
unPackOp.getDimAndTileMapping();
183+
if (dimAndTileMapping.empty())
184+
return failure();
177185
targetInnerTileSizes.resize(dimAndTileMapping.size());
178186
for (const auto &dimAndTile : dimAndTileMapping) {
179187
targetInnerTileSizes[dimAndTile.first] = dimAndTile.second;
180188
}
181189
} else {
182190
// tileSize comes from OpOperand
183-
targetInnerTileSizes = unPackOp.getInnerTiles();
191+
targetInnerTileSizes = unPackOp.getMixedTiles();
184192
targetTileSizes = llvm::to_vector(
185193
ArrayRef(tileSizes).take_back(targetInnerTileSizes.size()));
186194
}
@@ -481,21 +489,10 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
481489

482490
if (succeeded(fusedResult)) {
483491
fusedResultList.push_back(*fusedResult);
484-
auto whileProducerOutOfLoopBlock =
485-
[&fusedResult](LoopLikeOpInterface loop) -> LogicalResult {
486-
Block &body = loop->getRegion(0).front();
487-
return failure(fusedResult.value().tiledOps[0]->getBlock() == &body);
488-
};
489-
SmallVector<LoopLikeOpInterface> outerLoops =
490-
scfX::getOuterNestLoopsWhile(
491-
(*bestCandidate)->getParentOfType<LoopLikeOpInterface>(),
492-
whileProducerOutOfLoopBlock);
493-
// g. Manually run cse on region which contains top-level loop of
494-
// candidate slice in avoid of conflict with subsequent
495-
// `tileAndFuseConsumerOfSlice` get nest loops between next candidate
496-
// sliceOp and tiled producer.
497-
(void)mlir::simplifyRegions(rewriter,
498-
{*outerLoops.front()->getParentRegion()});
492+
// f. Manually run CSE on the region that contains the original consumer op to
493+
// avoid conflicts when a subsequent `tileAndFuseConsumerOfSlice` gets the nest
494+
// loops between the next candidate sliceOp and the tiled producer.
495+
(void)mlir::simplifyRegions(rewriter, {*consumer->getParentRegion()});
499496
}
500497
}
501498
if (fusedResultList.empty())
@@ -543,7 +540,13 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
543540
return success(numTiledOps > 1);
544541
}
545542

546-
/// What is self tiled op compared with other fused op?
543+
/// This is a workaround to deal with LinalgXOp
544+
static bool isTilableLinalgXOp(Operation *op) {
545+
return isa<linalgx::BatchReduceMatmulVnniOp, linalgx::MultiBatchMatmulOp,
546+
linalgx::Mm2DVnniOp, linalgx::Mm4DVnniOp>(op);
547+
}
548+
549+
/// Check whether the tiled op is inside a loop.
547550
/// E.g.
548551
/// %1 = scf.for(){
549552
/// %2 = scf.for(){
@@ -553,25 +556,17 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
553556
/// yield %5
554557
/// }
555558
/// }
556-
static LogicalResult isSelfTiledOp(Operation *targetOp) {
557-
// 0. check tilable
558-
if (!isa<TilingInterface>(targetOp))
559+
static LogicalResult isTiledOpInLoop(Operation *targetOp) {
560+
// 1. check tilable
561+
if (!isa<TilingInterface>(targetOp) && !isTilableLinalgXOp(targetOp))
559562
return failure();
560-
// 1. check parentOp
563+
// 2. check parentOp
561564
auto forOp = targetOp->getParentOfType<LoopLikeOpInterface>();
562565
if (!forOp)
563566
return failure();
564-
// 2. check single one tiling interface in loop body
565-
auto walkResult = forOp->walk([&targetOp](TilingInterface op) {
566-
// some special op maybe already deal with in template
567-
if (isa<linalg::FillOp, linalg::CopyOp>(op))
568-
return WalkResult::skip();
569-
return op != targetOp ? WalkResult::interrupt() : WalkResult::advance();
570-
});
571-
if (walkResult.wasInterrupted())
572-
return failure();
567+
573568
// 3. check whether the loop has either an extract or insert slice op
574-
walkResult = forOp->walk(
569+
auto walkResult = forOp->walk(
575570
[](tensor::ExtractSliceOp) { return WalkResult::interrupt(); });
576571
if (walkResult.wasInterrupted())
577572
return success();
@@ -583,11 +578,12 @@ static LogicalResult isSelfTiledOp(Operation *targetOp) {
583578
using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t>>;
584579

585580
template <typename OpTy>
586-
static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op,
587-
const OpTileSizeMap &tsMap) {
581+
static FailureOr<scf::SCFTilingResult>
582+
defaultTilingOfType(RewriterBase &rewriter, Operation *op,
583+
const OpTileSizeMap &tsMap) {
588584
// a. Check <OpTy>
589585
if (!isa<TilingInterface>(op) || !isa<OpTy>(op))
590-
return false;
586+
return failure();
591587
auto tilingInterfaceOp = cast<TilingInterface>(op);
592588

593589
scf::SCFTilingOptions options;
@@ -618,26 +614,29 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op,
618614
}
619615
// If the tile sizes are all zero, no tiling would happen.
620616
if (llvm::all_of(defaultTileSize, isZeroIndex))
621-
return false;
617+
return failure();
622618

623619
options.setTileSizes(defaultTileSize);
624620
// c. Set loop type
625621
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
626622
// d. Use builtin tiling interface
627623
FailureOr<scf::SCFTilingResult> tilingResult =
628624
scf::tileUsingSCF(rewriter, tilingInterfaceOp, options);
629-
if (succeeded(tilingResult)) {
630-
rewriter.replaceOp(op, tilingResult->replacements);
631-
return true;
632-
}
633-
return false;
625+
626+
if (failed(tilingResult))
627+
return failure();
628+
629+
return tilingResult;
634630
}
635631

632+
using DefaultTilingFn = std::function<FailureOr<scf::SCFTilingResult>(
633+
RewriterBase &, Operation *, const OpTileSizeMap &)>;
634+
636635
void iterativeTilingAndFusionUntilExhaustion(
637636
RewriterBase &rewriter, func::FuncOp &f,
638637
const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) {
639638
// Collect untiled and tiled ops respectively
640-
llvm::SetVector<Operation *> selfTiledOp, unTiledOps;
639+
llvm::SetVector<Operation *> tiledOps, unTiledOps;
641640

642641
auto collectUnTiledOps = [&f, &unTiledOps]() -> bool {
643642
// Reset
@@ -648,67 +647,64 @@ void iterativeTilingAndFusionUntilExhaustion(
648647
return WalkResult::skip();
649648
if (isa<TilingInterface>(op)) {
650649
auto parentLoop = op->getParentOfType<LoopLikeOpInterface>();
651-
if (!parentLoop.getOperation())
650+
auto parentGeneric = op->getParentOfType<linalg::GenericOp>();
651+
if (!llvm::detail::isPresent(parentLoop) &&
652+
!llvm::detail::isPresent(parentGeneric))
652653
unTiledOps.insert(op);
653654
}
654655
return WalkResult::advance();
655656
});
656657
return !unTiledOps.empty();
657658
};
658659

659-
auto collectSelfTiledOp = [&f, &selfTiledOp]() -> bool {
660-
// Reset
661-
selfTiledOp.clear();
662-
// Walk through funcOp
663-
f->walk([&selfTiledOp](Operation *op) {
664-
// Target at certain kind of tiled op, such as matmul/conv implemented
665-
// by multiple level of nest loops and candidate slices for better
666-
// utilization of parallelism and memory hierarchy.
667-
if (succeeded(isSelfTiledOp(op))) {
668-
selfTiledOp.insert(op);
669-
}
670-
});
671-
return !selfTiledOp.empty();
672-
};
660+
// Walk through funcOp
661+
f->walk([&tiledOps](Operation *op) {
662+
if (succeeded(isTiledOpInLoop(op))) {
663+
tiledOps.insert(op);
664+
}
665+
});
673666

674667
// Iterative tiling and fusion until exhaustion.
675668
while (collectUnTiledOps()) {
676669
// If a tiled op already exists before tiling.
677-
if (collectSelfTiledOp()) {
670+
if (!tiledOps.empty()) {
678671
// Sort by topology
679-
mlir::topologicalSort(selfTiledOp);
672+
mlir::topologicalSort(tiledOps);
680673
// Record if any fusion happens
681674
bool changed = false;
682675
// Iteratively fuse in forward and backward fashion.
683-
llvm::for_each(selfTiledOp, [&rewriter, &sliceOptions,
684-
&changed](Operation *tiledOp) {
685-
changed |= succeeded(iterativelyFuseProducerAndConsumerOfTiledOp(
686-
rewriter, tiledOp, sliceOptions));
687-
});
676+
llvm::for_each(
677+
tiledOps, [&rewriter, &sliceOptions, &changed](Operation *tiledOp) {
678+
changed |= succeeded(iterativelyFuseProducerAndConsumerOfTiledOp(
679+
rewriter, tiledOp, sliceOptions));
680+
});
681+
tiledOps.clear();
688682
if (changed)
689-
(void)mlir::simplifyRegions(rewriter, {f.getRegion()});
683+
(void)mlir::simplifyRegions(rewriter, f->getRegions());
690684
} else {
691685
// Auto tiling with default tile size if no tiled op found. Follow tiling
692686
// priority based on OpTy: `Contraction`->`Reduction`->`Elementwise`.
693-
SmallVector<std::function<bool(RewriterBase &, Operation *,
694-
const OpTileSizeMap &)>>
695-
priorityTilingPipeLine = {
696-
defaultTilingOfType<mlir::linalg::ContractionOpInterface>,
697-
defaultTilingOfType<mlir::linalg::ReduceOp>,
698-
defaultTilingOfType<mlir::linalg::LinalgOp>};
699-
if (llvm::all_of(priorityTilingPipeLine,
700-
[&rewriter, &tsMap, &unTiledOps](
701-
function_ref<bool(RewriterBase &, Operation *,
702-
const OpTileSizeMap &)>
703-
tilingFn) {
704-
return !llvm::any_of(
705-
unTiledOps, std::bind(tilingFn, std::ref(rewriter),
706-
std::placeholders::_1,
707-
std::cref(tsMap)));
708-
})) {
709-
// If no op can be tiled
710-
return;
687+
SmallVector<DefaultTilingFn> priorityTilingPipeLine = {
688+
defaultTilingOfType<mlir::linalg::ContractionOpInterface>,
689+
defaultTilingOfType<mlir::linalg::ReduceOp>,
690+
defaultTilingOfType<TilingInterface>};
691+
692+
for (auto &tilingFn : priorityTilingPipeLine) {
693+
for (auto &op : unTiledOps) {
694+
FailureOr<scf::SCFTilingResult> tilingResult =
695+
tilingFn(rewriter, op, tsMap);
696+
if (succeeded(tilingResult)) {
697+
tiledOps.insert(tilingResult->tiledOps[0]);
698+
rewriter.replaceOp(op, tilingResult->replacements);
699+
break;
700+
}
701+
}
702+
if (!tiledOps.empty())
703+
break;
711704
}
705+
// If no op can be tiled
706+
if (tiledOps.empty())
707+
return;
712708
}
713709
}
714710
}

0 commit comments

Comments
 (0)