7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " gc/Analysis/TargetDescriptionAnalysis.h"
10
+ #include " gc/Dialect/Linalgx/LinalgxOps.h"
10
11
#include " gc/Transforms/Passes.h"
11
12
#include " mlir/Analysis/TopologicalSortUtils.h"
12
13
#include " mlir/Dialect/DLTI/Traits.h"
@@ -68,17 +69,22 @@ getClosestInsertSliceOfResult(OpResult result) {
68
69
sliceOp =
69
70
dyn_cast<OffsetSizeAndStrideOpInterface>(useOfResult.getOwner ());
70
71
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOfResult.getOwner ())) {
71
- if (auto loop = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ()))
72
+ if (isa<LoopLikeOpInterface, RegionBranchOpInterface>(
73
+ yieldOp->getParentOp ())) {
72
74
return getClosestInsertSliceOfResult (
73
- loop->getResult (useOfResult.getOperandNumber ()));
75
+ yieldOp->getParentOp ()->getResult (useOfResult.getOperandNumber ()));
76
+ }
77
+ } else if (isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(
78
+ useOfResult.getOwner ())) {
79
+ return getClosestInsertSliceOfResult (
80
+ useOfResult.getOwner ()->getResult (useOfResult.getOperandNumber ()));
74
81
}
75
82
}
76
83
77
84
if (!llvm::detail::isPresent (sliceOp))
78
85
return failure ();
79
- else {
80
- return sliceOp;
81
- }
86
+
87
+ return sliceOp;
82
88
}
83
89
84
90
struct CandidateDefOrUse {
@@ -155,12 +161,12 @@ exactTilingOnPackUnPackFilter(RewriterBase &rewriter,
155
161
if (auto packOp = dyn_cast<tensor::PackOp>(defOrUse.ownerOp )) {
156
162
// tileSize comes from OpResult
157
163
if (defOrUse.isDef ()) {
158
- targetInnerTileSizes = packOp.getInnerTiles ();
164
+ targetInnerTileSizes = packOp.getMixedTiles ();
159
165
targetTileSizes = llvm::to_vector (
160
166
ArrayRef (tileSizes).take_back (targetInnerTileSizes.size ()));
161
167
} else {
162
168
// tileSize comes from OpOperand
163
- targetTileSizes = llvm::to_vector ( tileSizes) ;
169
+ targetTileSizes = tileSizes;
164
170
DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
165
171
packOp.getDimAndTileMapping ();
166
172
targetInnerTileSizes.resize (dimAndTileMapping.size ());
@@ -171,16 +177,18 @@ exactTilingOnPackUnPackFilter(RewriterBase &rewriter,
171
177
} else if (auto unPackOp = dyn_cast<tensor::UnPackOp>(defOrUse.ownerOp )) {
172
178
// tileSize comes from OpResult
173
179
if (defOrUse.isDef ()) {
174
- targetTileSizes = llvm::to_vector ( tileSizes) ;
180
+ targetTileSizes = tileSizes;
175
181
DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
176
182
unPackOp.getDimAndTileMapping ();
183
+ if (dimAndTileMapping.empty ())
184
+ return failure ();
177
185
targetInnerTileSizes.resize (dimAndTileMapping.size ());
178
186
for (const auto &dimAndTile : dimAndTileMapping) {
179
187
targetInnerTileSizes[dimAndTile.first ] = dimAndTile.second ;
180
188
}
181
189
} else {
182
190
// tileSize comes from OpOperand
183
- targetInnerTileSizes = unPackOp.getInnerTiles ();
191
+ targetInnerTileSizes = unPackOp.getMixedTiles ();
184
192
targetTileSizes = llvm::to_vector (
185
193
ArrayRef (tileSizes).take_back (targetInnerTileSizes.size ()));
186
194
}
@@ -481,21 +489,10 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
481
489
482
490
if (succeeded (fusedResult)) {
483
491
fusedResultList.push_back (*fusedResult);
484
- auto whileProducerOutOfLoopBlock =
485
- [&fusedResult](LoopLikeOpInterface loop) -> LogicalResult {
486
- Block &body = loop->getRegion (0 ).front ();
487
- return failure (fusedResult.value ().tiledOps [0 ]->getBlock () == &body);
488
- };
489
- SmallVector<LoopLikeOpInterface> outerLoops =
490
- scfX::getOuterNestLoopsWhile (
491
- (*bestCandidate)->getParentOfType <LoopLikeOpInterface>(),
492
- whileProducerOutOfLoopBlock);
493
- // g. Manually run cse on region which contains top-level loop of
494
- // candidate slice in avoid of conflict with subsequent
495
- // `tileAndFuseConsumerOfSlice` get nest loops between next candidate
496
- // sliceOp and tiled producer.
497
- (void )mlir::simplifyRegions (rewriter,
498
- {*outerLoops.front ()->getParentRegion ()});
492
+ // f. Manually run cse on the region which contains the original consumer op,
493
+ // to avoid a conflict when the subsequent `tileAndFuseConsumerOfSlice` gets
494
+ // the nest loops between the next candidate sliceOp and the tiled producer.
495
+ (void )mlir::simplifyRegions (rewriter, {*consumer->getParentRegion ()});
499
496
}
500
497
}
501
498
if (fusedResultList.empty ())
@@ -543,7 +540,13 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
543
540
return success (numTiledOps > 1 );
544
541
}
545
542
546
- // / What is self tiled op compared with other fused op?
543
+ // / Workaround for LinalgX ops: returns true when `op` is one of the four linalgx ops named in the isa<> check below; isTiledOpInLoop accepts such ops in addition to TilingInterface ops. NOTE(review): presumably these ops do not implement TilingInterface — confirm.
544
+ static bool isTilableLinalgXOp (Operation *op) {
545
+ return isa<linalgx::BatchReduceMatmulVnniOp, linalgx::MultiBatchMatmulOp,
546
+ linalgx::Mm2DVnniOp, linalgx::Mm4DVnniOp>(op);
547
+ }
548
+
549
+ // / Check whether the tiled op is inside a loop.
547
550
// / E.g.
548
551
// / %1 = scf.for(){
549
552
// / %2 = scf.for(){
@@ -553,25 +556,17 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
553
556
// / yield %5
554
557
// / }
555
558
// / }
556
- static LogicalResult isSelfTiledOp (Operation *targetOp) {
557
- // 0 . check tilable
558
- if (!isa<TilingInterface>(targetOp))
559
+ static LogicalResult isTiledOpInLoop (Operation *targetOp) {
560
+ // 1 . check tilable
561
+ if (!isa<TilingInterface>(targetOp) && ! isTilableLinalgXOp (targetOp) )
559
562
return failure ();
560
- // 1 . check parentOp
563
+ // 2 . check parentOp
561
564
auto forOp = targetOp->getParentOfType <LoopLikeOpInterface>();
562
565
if (!forOp)
563
566
return failure ();
564
- // 2. check single one tiling interface in loop body
565
- auto walkResult = forOp->walk ([&targetOp](TilingInterface op) {
566
- // some special op maybe already deal with in template
567
- if (isa<linalg::FillOp, linalg::CopyOp>(op))
568
- return WalkResult::skip ();
569
- return op != targetOp ? WalkResult::interrupt () : WalkResult::advance ();
570
- });
571
- if (walkResult.wasInterrupted ())
572
- return failure ();
567
+
573
568
// 3. check whether the loop contains either an extract or insert slice op
574
- walkResult = forOp->walk (
569
+ auto walkResult = forOp->walk (
575
570
[](tensor::ExtractSliceOp) { return WalkResult::interrupt (); });
576
571
if (walkResult.wasInterrupted ())
577
572
return success ();
@@ -583,11 +578,12 @@ static LogicalResult isSelfTiledOp(Operation *targetOp) {
583
578
using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t >>;
584
579
585
580
template <typename OpTy>
586
- static bool defaultTilingOfType (RewriterBase &rewriter, Operation *op,
587
- const OpTileSizeMap &tsMap) {
581
+ static FailureOr<scf::SCFTilingResult>
582
+ defaultTilingOfType (RewriterBase &rewriter, Operation *op,
583
+ const OpTileSizeMap &tsMap) {
588
584
// a. Check <OpTy>
589
585
if (!isa<TilingInterface>(op) || !isa<OpTy>(op))
590
- return false ;
586
+ return failure () ;
591
587
auto tilingInterfaceOp = cast<TilingInterface>(op);
592
588
593
589
scf::SCFTilingOptions options;
@@ -618,26 +614,29 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op,
618
614
}
619
615
// If the tile sizes are all zero, no tiling would happen.
620
616
if (llvm::all_of (defaultTileSize, isZeroIndex))
621
- return false ;
617
+ return failure () ;
622
618
623
619
options.setTileSizes (defaultTileSize);
624
620
// c. Set loop type
625
621
options.setLoopType (scf::SCFTilingOptions::LoopType::ForallOp);
626
622
// d. Use builtin tiling interface
627
623
FailureOr<scf::SCFTilingResult> tilingResult =
628
624
scf::tileUsingSCF (rewriter, tilingInterfaceOp, options);
629
- if ( succeeded (tilingResult)) {
630
- rewriter. replaceOp (op, tilingResult-> replacements );
631
- return true ;
632
- }
633
- return false ;
625
+
626
+ if ( failed ( tilingResult))
627
+ return failure () ;
628
+
629
+ return tilingResult ;
634
630
}
635
631
632
+ using DefaultTilingFn = std::function<FailureOr<scf::SCFTilingResult>(
633
+ RewriterBase &, Operation *, const OpTileSizeMap &)>; // Callable type of the entries stored in `priorityTilingPipeLine` (the `defaultTilingOfType<...>` instantiations); returns the tiling result or failure.
634
+
636
635
void iterativeTilingAndFusionUntilExhaustion (
637
636
RewriterBase &rewriter, func::FuncOp &f,
638
637
const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) {
639
638
// Collect untiled and tiled ops respectively
640
- llvm::SetVector<Operation *> selfTiledOp , unTiledOps;
639
+ llvm::SetVector<Operation *> tiledOps , unTiledOps;
641
640
642
641
auto collectUnTiledOps = [&f, &unTiledOps]() -> bool {
643
642
// Reset
@@ -648,67 +647,64 @@ void iterativeTilingAndFusionUntilExhaustion(
648
647
return WalkResult::skip ();
649
648
if (isa<TilingInterface>(op)) {
650
649
auto parentLoop = op->getParentOfType <LoopLikeOpInterface>();
651
- if (!parentLoop.getOperation ())
650
+ auto parentGeneric = op->getParentOfType <linalg::GenericOp>();
651
+ if (!llvm::detail::isPresent (parentLoop) &&
652
+ !llvm::detail::isPresent (parentGeneric))
652
653
unTiledOps.insert (op);
653
654
}
654
655
return WalkResult::advance ();
655
656
});
656
657
return !unTiledOps.empty ();
657
658
};
658
659
659
- auto collectSelfTiledOp = [&f, &selfTiledOp]() -> bool {
660
- // Reset
661
- selfTiledOp.clear ();
662
- // Walk through funcOp
663
- f->walk ([&selfTiledOp](Operation *op) {
664
- // Target at certain kind of tiled op, such as matmul/conv implemented
665
- // by multiple level of nest loops and candidate slices for better
666
- // utilization of parallelism and memory hierarchy.
667
- if (succeeded (isSelfTiledOp (op))) {
668
- selfTiledOp.insert (op);
669
- }
670
- });
671
- return !selfTiledOp.empty ();
672
- };
660
+ // Walk through funcOp
661
+ f->walk ([&tiledOps](Operation *op) {
662
+ if (succeeded (isTiledOpInLoop (op))) {
663
+ tiledOps.insert (op);
664
+ }
665
+ });
673
666
674
667
// Iterative tiling and fusion until exhaustion.
675
668
while (collectUnTiledOps ()) {
676
669
// If tiled ops already exist before tiling.
677
- if (collectSelfTiledOp ()) {
670
+ if (!tiledOps. empty ()) {
678
671
// Sort by topology
679
- mlir::topologicalSort (selfTiledOp );
672
+ mlir::topologicalSort (tiledOps );
680
673
// Record if any fusion happens
681
674
bool changed = false ;
682
675
// Iteratively fuse in forward and backward fashion.
683
- llvm::for_each (selfTiledOp, [&rewriter, &sliceOptions,
684
- &changed](Operation *tiledOp) {
685
- changed |= succeeded (iterativelyFuseProducerAndConsumerOfTiledOp (
686
- rewriter, tiledOp, sliceOptions));
687
- });
676
+ llvm::for_each (
677
+ tiledOps, [&rewriter, &sliceOptions, &changed](Operation *tiledOp) {
678
+ changed |= succeeded (iterativelyFuseProducerAndConsumerOfTiledOp (
679
+ rewriter, tiledOp, sliceOptions));
680
+ });
681
+ tiledOps.clear ();
688
682
if (changed)
689
- (void )mlir::simplifyRegions (rewriter, {f. getRegion ()} );
683
+ (void )mlir::simplifyRegions (rewriter, f-> getRegions () );
690
684
} else {
691
685
// Auto tiling with default tile size if no tiled op found. Follow tiling
692
686
// priority based on OpTy: `Contraction`->`Reduction`->`Elementwise`.
693
- SmallVector<std::function<bool (RewriterBase &, Operation *,
694
- const OpTileSizeMap &)>>
695
- priorityTilingPipeLine = {
696
- defaultTilingOfType<mlir::linalg::ContractionOpInterface>,
697
- defaultTilingOfType<mlir::linalg::ReduceOp>,
698
- defaultTilingOfType<mlir::linalg::LinalgOp>};
699
- if (llvm::all_of (priorityTilingPipeLine,
700
- [&rewriter, &tsMap, &unTiledOps](
701
- function_ref<bool (RewriterBase &, Operation *,
702
- const OpTileSizeMap &)>
703
- tilingFn) {
704
- return !llvm::any_of (
705
- unTiledOps, std::bind (tilingFn, std::ref (rewriter),
706
- std::placeholders::_1,
707
- std::cref (tsMap)));
708
- })) {
709
- // If no op can be tiled
710
- return ;
687
+ SmallVector<DefaultTilingFn> priorityTilingPipeLine = {
688
+ defaultTilingOfType<mlir::linalg::ContractionOpInterface>,
689
+ defaultTilingOfType<mlir::linalg::ReduceOp>,
690
+ defaultTilingOfType<TilingInterface>};
691
+
692
+ for (auto &tilingFn : priorityTilingPipeLine) {
693
+ for (auto &op : unTiledOps) {
694
+ FailureOr<scf::SCFTilingResult> tilingResult =
695
+ tilingFn (rewriter, op, tsMap);
696
+ if (succeeded (tilingResult)) {
697
+ tiledOps.insert (tilingResult->tiledOps [0 ]);
698
+ rewriter.replaceOp (op, tilingResult->replacements );
699
+ break ;
700
+ }
701
+ }
702
+ if (!tiledOps.empty ())
703
+ break ;
711
704
}
705
+ // If no op can be tiled
706
+ if (tiledOps.empty ())
707
+ return ;
712
708
}
713
709
}
714
710
}
0 commit comments