77// ===----------------------------------------------------------------------===//
88
99#include " gc/Analysis/TargetDescriptionAnalysis.h"
10+ #include " gc/Dialect/Linalgx/LinalgxOps.h"
1011#include " gc/Transforms/Passes.h"
1112#include " mlir/Analysis/TopologicalSortUtils.h"
1213#include " mlir/Dialect/DLTI/Traits.h"
@@ -68,17 +69,22 @@ getClosestInsertSliceOfResult(OpResult result) {
6869 sliceOp =
6970 dyn_cast<OffsetSizeAndStrideOpInterface>(useOfResult.getOwner ());
7071 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOfResult.getOwner ())) {
71- if (auto loop = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ()))
72+ if (isa<LoopLikeOpInterface, RegionBranchOpInterface>(
73+ yieldOp->getParentOp ())) {
7274 return getClosestInsertSliceOfResult (
73- loop->getResult (useOfResult.getOperandNumber ()));
75+ yieldOp->getParentOp ()->getResult (useOfResult.getOperandNumber ()));
76+ }
77+ } else if (isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(
78+ useOfResult.getOwner ())) {
79+ return getClosestInsertSliceOfResult (
80+ useOfResult.getOwner ()->getResult (useOfResult.getOperandNumber ()));
7481 }
7582 }
7683
7784 if (!llvm::detail::isPresent (sliceOp))
7885 return failure ();
79- else {
80- return sliceOp;
81- }
86+
87+ return sliceOp;
8288}
8389
8490struct CandidateDefOrUse {
@@ -155,12 +161,12 @@ exactTilingOnPackUnPackFilter(RewriterBase &rewriter,
155161 if (auto packOp = dyn_cast<tensor::PackOp>(defOrUse.ownerOp )) {
156162 // tileSize comes from OpResult
157163 if (defOrUse.isDef ()) {
158- targetInnerTileSizes = packOp.getInnerTiles ();
164+ targetInnerTileSizes = packOp.getMixedTiles ();
159165 targetTileSizes = llvm::to_vector (
160166 ArrayRef (tileSizes).take_back (targetInnerTileSizes.size ()));
161167 } else {
162168 // tileSize comes from OpOperand
163- targetTileSizes = llvm::to_vector ( tileSizes) ;
169+ targetTileSizes = tileSizes;
164170 DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
165171 packOp.getDimAndTileMapping ();
166172 targetInnerTileSizes.resize (dimAndTileMapping.size ());
@@ -171,16 +177,18 @@ exactTilingOnPackUnPackFilter(RewriterBase &rewriter,
171177 } else if (auto unPackOp = dyn_cast<tensor::UnPackOp>(defOrUse.ownerOp )) {
172178 // tileSize comes from OpResult
173179 if (defOrUse.isDef ()) {
174- targetTileSizes = llvm::to_vector ( tileSizes) ;
180+ targetTileSizes = tileSizes;
175181 DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
176182 unPackOp.getDimAndTileMapping ();
183+ if (dimAndTileMapping.empty ())
184+ return failure ();
177185 targetInnerTileSizes.resize (dimAndTileMapping.size ());
178186 for (const auto &dimAndTile : dimAndTileMapping) {
179187 targetInnerTileSizes[dimAndTile.first ] = dimAndTile.second ;
180188 }
181189 } else {
182190 // tileSize comes from OpOperand
183- targetInnerTileSizes = unPackOp.getInnerTiles ();
191+ targetInnerTileSizes = unPackOp.getMixedTiles ();
184192 targetTileSizes = llvm::to_vector (
185193 ArrayRef (tileSizes).take_back (targetInnerTileSizes.size ()));
186194 }
@@ -481,21 +489,10 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
481489
482490 if (succeeded (fusedResult)) {
483491 fusedResultList.push_back (*fusedResult);
484- auto whileProducerOutOfLoopBlock =
485- [&fusedResult](LoopLikeOpInterface loop) -> LogicalResult {
486- Block &body = loop->getRegion (0 ).front ();
487- return failure (fusedResult.value ().tiledOps [0 ]->getBlock () == &body);
488- };
489- SmallVector<LoopLikeOpInterface> outerLoops =
490- scfX::getOuterNestLoopsWhile (
491- (*bestCandidate)->getParentOfType <LoopLikeOpInterface>(),
492- whileProducerOutOfLoopBlock);
493- // g. Manually run cse on region which contains top-level loop of
494- // candidate slice in avoid of conflict with subsequent
495- // `tileAndFuseConsumerOfSlice` get nest loops between next candidate
496- // sliceOp and tiled producer.
497- (void )mlir::simplifyRegions (rewriter,
498- {*outerLoops.front ()->getParentRegion ()});
492+ // f. Manually run cse on region which contains original consumer op in
493+ // avoid of conflict with subsequent `tileAndFuseConsumerOfSlice` get nest
494+ // loops between next candidate sliceOp and tiled producer.
495+ (void )mlir::simplifyRegions (rewriter, {*consumer->getParentRegion ()});
499496 }
500497 }
501498 if (fusedResultList.empty ())
@@ -543,7 +540,13 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
543540 return success (numTiledOps > 1 );
544541}
545542
546- // / What is self tiled op compared with other fused op?
543+ // / This is a workaround to deal with LinalgXOp
544+ static bool isTilableLinalgXOp (Operation *op) {
545+ return isa<linalgx::BatchReduceMatmulVnniOp, linalgx::MultiBatchMatmulOp,
546+ linalgx::Mm2DVnniOp, linalgx::Mm4DVnniOp>(op);
547+ }
548+
549+ // / Check if tiled op inside a loop?
547550// / E.g.
548551// / %1 = scf.for(){
549552// / %2 = scf.for(){
@@ -553,25 +556,17 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
553556// / yield %5
554557// / }
555558// / }
556- static LogicalResult isSelfTiledOp (Operation *targetOp) {
557- // 0 . check tilable
558- if (!isa<TilingInterface>(targetOp))
559+ static LogicalResult isTiledOpInLoop (Operation *targetOp) {
560+ // 1 . check tilable
561+ if (!isa<TilingInterface>(targetOp) && ! isTilableLinalgXOp (targetOp) )
559562 return failure ();
560- // 1 . check parentOp
563+ // 2 . check parentOp
561564 auto forOp = targetOp->getParentOfType <LoopLikeOpInterface>();
562565 if (!forOp)
563566 return failure ();
564- // 2. check single one tiling interface in loop body
565- auto walkResult = forOp->walk ([&targetOp](TilingInterface op) {
566- // some special op maybe already deal with in template
567- if (isa<linalg::FillOp, linalg::CopyOp>(op))
568- return WalkResult::skip ();
569- return op != targetOp ? WalkResult::interrupt () : WalkResult::advance ();
570- });
571- if (walkResult.wasInterrupted ())
572- return failure ();
567+
573568 // 3. check whether has either extract or insert slice op
574- walkResult = forOp->walk (
569+ auto walkResult = forOp->walk (
575570 [](tensor::ExtractSliceOp) { return WalkResult::interrupt (); });
576571 if (walkResult.wasInterrupted ())
577572 return success ();
@@ -583,11 +578,12 @@ static LogicalResult isSelfTiledOp(Operation *targetOp) {
583578using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t >>;
584579
585580template <typename OpTy>
586- static bool defaultTilingOfType (RewriterBase &rewriter, Operation *op,
587- const OpTileSizeMap &tsMap) {
581+ static FailureOr<scf::SCFTilingResult>
582+ defaultTilingOfType (RewriterBase &rewriter, Operation *op,
583+ const OpTileSizeMap &tsMap) {
588584 // a. Check <OpTy>
589585 if (!isa<TilingInterface>(op) || !isa<OpTy>(op))
590- return false ;
586+ return failure () ;
591587 auto tilingInterfaceOp = cast<TilingInterface>(op);
592588
593589 scf::SCFTilingOptions options;
@@ -618,26 +614,29 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op,
618614 }
619615 // If the tile sizes are all zero, no tiling would happen.
620616 if (llvm::all_of (defaultTileSize, isZeroIndex))
621- return false ;
617+ return failure () ;
622618
623619 options.setTileSizes (defaultTileSize);
624620 // c. Set loop type
625621 options.setLoopType (scf::SCFTilingOptions::LoopType::ForallOp);
626622 // d. Use builtin tiling interface
627623 FailureOr<scf::SCFTilingResult> tilingResult =
628624 scf::tileUsingSCF (rewriter, tilingInterfaceOp, options);
629- if ( succeeded (tilingResult)) {
630- rewriter. replaceOp (op, tilingResult-> replacements );
631- return true ;
632- }
633- return false ;
625+
626+ if ( failed ( tilingResult))
627+ return failure () ;
628+
629+ return tilingResult ;
634630}
635631
632+ using DefaultTilingFn = std::function<FailureOr<scf::SCFTilingResult>(
633+ RewriterBase &, Operation *, const OpTileSizeMap &)>;
634+
636635void iterativeTilingAndFusionUntilExhaustion (
637636 RewriterBase &rewriter, func::FuncOp &f,
638637 const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) {
639638 // Collect untiled and tiled ops respectively
640- llvm::SetVector<Operation *> selfTiledOp , unTiledOps;
639+ llvm::SetVector<Operation *> tiledOps , unTiledOps;
641640
642641 auto collectUnTiledOps = [&f, &unTiledOps]() -> bool {
643642 // Reset
@@ -648,67 +647,64 @@ void iterativeTilingAndFusionUntilExhaustion(
648647 return WalkResult::skip ();
649648 if (isa<TilingInterface>(op)) {
650649 auto parentLoop = op->getParentOfType <LoopLikeOpInterface>();
651- if (!parentLoop.getOperation ())
650+ auto parentGeneric = op->getParentOfType <linalg::GenericOp>();
651+ if (!llvm::detail::isPresent (parentLoop) &&
652+ !llvm::detail::isPresent (parentGeneric))
652653 unTiledOps.insert (op);
653654 }
654655 return WalkResult::advance ();
655656 });
656657 return !unTiledOps.empty ();
657658 };
658659
659- auto collectSelfTiledOp = [&f, &selfTiledOp]() -> bool {
660- // Reset
661- selfTiledOp.clear ();
662- // Walk through funcOp
663- f->walk ([&selfTiledOp](Operation *op) {
664- // Target at certain kind of tiled op, such as matmul/conv implemented
665- // by multiple level of nest loops and candidate slices for better
666- // utilization of parallelism and memory hierarchy.
667- if (succeeded (isSelfTiledOp (op))) {
668- selfTiledOp.insert (op);
669- }
670- });
671- return !selfTiledOp.empty ();
672- };
660+ // Walk through funcOp
661+ f->walk ([&tiledOps](Operation *op) {
662+ if (succeeded (isTiledOpInLoop (op))) {
663+ tiledOps.insert (op);
664+ }
665+ });
673666
674667 // Iterative tiling and fusion until exhaustion.
675668 while (collectUnTiledOps ()) {
676669 // If existing tiled op before tiling.
677- if (collectSelfTiledOp ()) {
670+ if (!tiledOps. empty ()) {
678671 // Sort by topology
679- mlir::topologicalSort (selfTiledOp );
672+ mlir::topologicalSort (tiledOps );
680673 // Record if any fusion happens
681674 bool changed = false ;
682675 // Iteratively fuse in forward and backward fashion.
683- llvm::for_each (selfTiledOp, [&rewriter, &sliceOptions,
684- &changed](Operation *tiledOp) {
685- changed |= succeeded (iterativelyFuseProducerAndConsumerOfTiledOp (
686- rewriter, tiledOp, sliceOptions));
687- });
676+ llvm::for_each (
677+ tiledOps, [&rewriter, &sliceOptions, &changed](Operation *tiledOp) {
678+ changed |= succeeded (iterativelyFuseProducerAndConsumerOfTiledOp (
679+ rewriter, tiledOp, sliceOptions));
680+ });
681+ tiledOps.clear ();
688682 if (changed)
689- (void )mlir::simplifyRegions (rewriter, {f. getRegion ()} );
683+ (void )mlir::simplifyRegions (rewriter, f-> getRegions () );
690684 } else {
691685 // Auto tiling with default tile size if no tiled op found. Follow tiling
692686 // priority based on OpTy: `Contraction`->`Reduction`->`Elementwise`.
693- SmallVector<std::function<bool (RewriterBase &, Operation *,
694- const OpTileSizeMap &)>>
695- priorityTilingPipeLine = {
696- defaultTilingOfType<mlir::linalg::ContractionOpInterface>,
697- defaultTilingOfType<mlir::linalg::ReduceOp>,
698- defaultTilingOfType<mlir::linalg::LinalgOp>};
699- if (llvm::all_of (priorityTilingPipeLine,
700- [&rewriter, &tsMap, &unTiledOps](
701- function_ref<bool (RewriterBase &, Operation *,
702- const OpTileSizeMap &)>
703- tilingFn) {
704- return !llvm::any_of (
705- unTiledOps, std::bind (tilingFn, std::ref (rewriter),
706- std::placeholders::_1,
707- std::cref (tsMap)));
708- })) {
709- // If no op can be tiled
710- return ;
687+ SmallVector<DefaultTilingFn> priorityTilingPipeLine = {
688+ defaultTilingOfType<mlir::linalg::ContractionOpInterface>,
689+ defaultTilingOfType<mlir::linalg::ReduceOp>,
690+ defaultTilingOfType<TilingInterface>};
691+
692+ for (auto &tilingFn : priorityTilingPipeLine) {
693+ for (auto &op : unTiledOps) {
694+ FailureOr<scf::SCFTilingResult> tilingResult =
695+ tilingFn (rewriter, op, tsMap);
696+ if (succeeded (tilingResult)) {
697+ tiledOps.insert (tilingResult->tiledOps [0 ]);
698+ rewriter.replaceOp (op, tilingResult->replacements );
699+ break ;
700+ }
701+ }
702+ if (!tiledOps.empty ())
703+ break ;
711704 }
705+ // If no op can be tiled
706+ if (tiledOps.empty ())
707+ return ;
712708 }
713709 }
714710}
0 commit comments