@@ -46,15 +46,14 @@ getClosestExtractSliceOfOperand(OpOperand &operand) {
4646 }
4747
4848 Operation *defineOp = operand.get ().getDefiningOp ();
49- if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(defineOp)) {
49+ if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(defineOp))
5050 return sliceOp;
51- } else if (isa<linalg::FillOp, tensor::ExpandShapeOp,
52- tensor::CollapseShapeOp>(defineOp)) {
53- // For downstream cases
51+ // For downstream cases
52+ if (isa<linalg::FillOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp>(
53+ defineOp))
5454 return getClosestExtractSliceOfOperand (defineOp->getOpOperand (0 ));
55- } else {
56- return failure ();
57- }
55+
56+ return failure ();
5857}
5958
6059static FailureOr<OffsetSizeAndStrideOpInterface>
@@ -104,9 +103,8 @@ struct CandidateDefOrUse {
104103using CandidateSliceFilter = std::function<LogicalResult(
105104 RewriterBase &, OffsetSizeAndStrideOpInterface, CandidateDefOrUse)>;
106105
107- using CandidateSliceComparer =
108- std::function<int (RewriterBase &, OffsetSizeAndStrideOpInterface,
109- OffsetSizeAndStrideOpInterface, CandidateDefOrUse)>;
106+ using CandidateSliceComparer = std::function<int (
107+ OffsetSizeAndStrideOpInterface, OffsetSizeAndStrideOpInterface)>;
110108
111109static LogicalResult
112110noTilingOnReductionFilter (RewriterBase &rewriter,
@@ -325,22 +323,19 @@ computeTileSizeProductOfCandidate(OffsetSizeAndStrideOpInterface candidate) {
325323 return totalSize;
326324}
327325
328- static int TilingSizeComparer (RewriterBase &rewriter,
329- OffsetSizeAndStrideOpInterface candidateA,
330- OffsetSizeAndStrideOpInterface candidateB,
331- CandidateDefOrUse defOrUse) {
326+ static int TilingSizeComparer (OffsetSizeAndStrideOpInterface candidateA,
327+ OffsetSizeAndStrideOpInterface candidateB) {
332328 FailureOr<int64_t > sizeProductA =
333329 computeTileSizeProductOfCandidate (candidateA),
334330 sizeProductB =
335331 computeTileSizeProductOfCandidate (candidateB);
336332 if (failed (sizeProductA) || failed (sizeProductB))
337333 return 0 ;
338334 // deal with equality
339- if (*sizeProductA == *sizeProductB) {
335+ if (*sizeProductA == *sizeProductB)
340336 return 0 ;
341- } else {
342- return *sizeProductA < *sizeProductB ? -1 : 1 ;
343- }
337+
338+ return *sizeProductA < *sizeProductB ? -1 : 1 ;
344339}
345340
346341struct CandidateSliceComparerPipeLine
@@ -352,17 +347,15 @@ struct CandidateSliceComparerPipeLine
352347 return SmallVector<CandidateSliceComparer>{TilingSizeComparer};
353348 }
354349
355- bool compare (RewriterBase &rewriter,
356- OffsetSizeAndStrideOpInterface candidateA,
357- OffsetSizeAndStrideOpInterface candidateB,
358- CandidateDefOrUse defOrUse) const {
350+ bool compare (OffsetSizeAndStrideOpInterface candidateA,
351+ OffsetSizeAndStrideOpInterface candidateB) const {
359352 // deal with weak order
360353 int cmpResult = -1 ;
361- for ( auto &fn : candidateProcessFn) {
362- cmpResult = fn (rewriter, candidateA, candidateB, defOrUse);
363- if ( cmpResult != 0 )
364- break ;
365- }
354+ llvm::any_of (candidateProcessFn, [&cmpResult, &candidateA, &candidateB](
355+ const CandidateSliceComparer &fn) {
356+ cmpResult = fn (candidateA, candidateB);
357+ return cmpResult != 0 ;
358+ });
366359 return cmpResult == -1 ;
367360 }
368361};
@@ -389,6 +382,29 @@ struct CandidateSliceOptions {
389382 }
390383};
391384
385+ static FailureOr<OffsetSizeAndStrideOpInterface> filterAndSelectCandidate (
386+ RewriterBase &rewriter,
387+ ArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceList,
388+ const CandidateDefOrUse &defOrUse, const CandidateSliceOptions &options) {
389+ SmallVector<OffsetSizeAndStrideOpInterface> validCandidates =
390+ llvm::to_vector (llvm::make_filter_range (
391+ candidateSliceList,
392+ [&rewriter, &options,
393+ &defOrUse](const OffsetSizeAndStrideOpInterface &candidate) {
394+ return succeeded (
395+ options.filterPipeLine .filter (rewriter, candidate, defOrUse));
396+ }));
397+ if (validCandidates.empty ())
398+ return failure ();
399+
400+ OffsetSizeAndStrideOpInterface bestCandidate = *llvm::min_element (
401+ validCandidates, [&options](OffsetSizeAndStrideOpInterface &candidateA,
402+ OffsetSizeAndStrideOpInterface &candidateB) {
403+ return options.comparerPipeLine .compare (candidateA, candidateB);
404+ });
405+ return bestCandidate;
406+ }
407+
392408std::optional<scf::SCFFuseProducerOfSliceResult>
393409tileAndFuseProducerOfOpOperand (RewriterBase &rewriter, OpOperand &operand,
394410 const CandidateSliceOptions &options) {
@@ -412,31 +428,20 @@ tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
412428 return std::nullopt ;
413429
414430 CandidateDefOrUse defOrUse{*realProducer};
415- // d. Filter out invalid candidates
416- SmallVector<tensor::ExtractSliceOp> validCandidates =
417- llvm::to_vector (llvm::make_filter_range (
418- backwardSlice,
419- [&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidate) {
420- return succeeded (options.filterPipeLine .filter (
421- rewriter,
422- cast<OffsetSizeAndStrideOpInterface>(candidate.getOperation ()),
423- defOrUse));
424- }));
425- if (validCandidates.empty ())
431+ // d. Filter out invalid candidates and select best candidates
432+ SmallVector<OffsetSizeAndStrideOpInterface> ossBackwardSlice =
433+ llvm::map_to_vector (backwardSlice,
434+ [](tensor::ExtractSliceOp &extractSlice) {
435+ return cast<OffsetSizeAndStrideOpInterface>(
436+ extractSlice.getOperation ());
437+ });
438+ FailureOr<OffsetSizeAndStrideOpInterface> bestCandidate =
439+ filterAndSelectCandidate (rewriter, ossBackwardSlice, defOrUse, options);
440+ if (failed (bestCandidate))
426441 return std::nullopt ;
427- // e. Select best candidates by Cost Model
428- tensor::ExtractSliceOp bestCandidate = *llvm::min_element (
429- validCandidates,
430- [&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidateA,
431- tensor::ExtractSliceOp &candidateB) {
432- return options.comparerPipeLine .compare (
433- rewriter,
434- cast<OffsetSizeAndStrideOpInterface>(candidateA.getOperation ()),
435- cast<OffsetSizeAndStrideOpInterface>(candidateB.getOperation ()),
436- defOrUse);
437- });
438- // f. call tiling interface
439- return scfX::tileAndFuseProducerOfSlice (rewriter, bestCandidate);
442+
443+ // e. call tiling interface
444+ return scfX::tileAndFuseProducerOfSlice (rewriter, *bestCandidate);
440445}
441446
442447std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
@@ -464,28 +469,15 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
464469 continue ;
465470
466471 CandidateDefOrUse defOrUse{useOperand};
467- // d. Filter out invalid candidates
468- SmallVector<OffsetSizeAndStrideOpInterface> validCandidates =
469- llvm::to_vector (llvm::make_filter_range (
470- forwardSlice, [&rewriter, &options, &defOrUse](
471- const OffsetSizeAndStrideOpInterface &candidate) {
472- return succeeded (
473- options.filterPipeLine .filter (rewriter, candidate, defOrUse));
474- }));
475- if (validCandidates.empty ())
472+ // d. Filter out invalid candidates and select best candidates
473+ FailureOr<OffsetSizeAndStrideOpInterface> bestCandidate =
474+ filterAndSelectCandidate (rewriter, forwardSlice, defOrUse, options);
475+ if (failed (bestCandidate))
476476 continue ;
477477
478- // e. Select best candidates by Cost Model
479- OffsetSizeAndStrideOpInterface bestCandidate = *llvm::min_element (
480- validCandidates, [&rewriter, &options, &defOrUse](
481- const OffsetSizeAndStrideOpInterface &candidateA,
482- const OffsetSizeAndStrideOpInterface &candidateB) {
483- return options.comparerPipeLine .compare (rewriter, candidateA,
484- candidateB, defOrUse);
485- });
486- // f. call tiling interface
478+ // e. call tiling interface
487479 FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
488- scfX::tileAndFuseConsumerOfSlice (rewriter, bestCandidate);
480+ scfX::tileAndFuseConsumerOfSlice (rewriter, * bestCandidate);
489481
490482 if (succeeded (fusedResult)) {
491483 fusedResultList.push_back (*fusedResult);
@@ -496,7 +488,7 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
496488 };
497489 SmallVector<LoopLikeOpInterface> outerLoops =
498490 scfX::getOuterNestLoopsWhile (
499- bestCandidate->getParentOfType <LoopLikeOpInterface>(),
491+ (* bestCandidate) ->getParentOfType <LoopLikeOpInterface>(),
500492 whileProducerOutOfLoopBlock);
501493 // g. Manually run cse on region which contains top-level loop of
502494 // candidate slice in avoid of conflict with subsequent
@@ -506,11 +498,10 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
506498 {*outerLoops.front ()->getParentRegion ()});
507499 }
508500 }
509- if (fusedResultList.empty ()) {
501+ if (fusedResultList.empty ())
510502 return std::nullopt ;
511- } else {
512- return fusedResultList;
513- }
503+
504+ return fusedResultList;
514505}
515506
516507// / Target at following general topology:
@@ -527,7 +518,7 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
527518LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp (
528519 RewriterBase &rewriter, Operation *tiledOp,
529520 const CandidateSliceOptions &options) {
530- int numTiledOps = 0 ;
521+ unsigned numTiledOps = 0 ;
531522 std::deque<Operation *> tiledOpList = {tiledOp};
532523 while (!tiledOpList.empty ()) {
533524 tiledOp = tiledOpList.front ();
@@ -552,7 +543,7 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
552543 return success (numTiledOps > 1 );
553544}
554545
555- // / What is single tiled op in loop ?
546+ // / What is self tiled op compared with other fused op ?
556547// / E.g.
557548// / %1 = scf.for(){
558549// / %2 = scf.for(){
@@ -562,7 +553,7 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
562553// / yield %5
563554// / }
564555// / }
565- static LogicalResult isSingleTiledOpInLoop (Operation *targetOp) {
556+ static LogicalResult isSelfTiledOp (Operation *targetOp) {
566557 // 0. check tilable
567558 if (!isa<TilingInterface>(targetOp))
568559 return failure ();
@@ -694,16 +685,15 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op,
694685 if (succeeded (tilingResult)) {
695686 rewriter.replaceOp (op, tilingResult->replacements );
696687 return true ;
697- } else {
698- return false ;
699688 }
689+ return false ;
700690}
701691
702692void iterativeTilingAndFusionUntilExhaustion (
703693 RewriterBase &rewriter, func::FuncOp &f,
704694 const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) {
705695 // Collect untiled and tiled ops respectively
706- llvm::SetVector<Operation *> singleTiledOpInLoop , unTiledOps;
696+ llvm::SetVector<Operation *> selfTiledOp , unTiledOps;
707697
708698 auto collectUnTiledOps = [&f, &unTiledOps]() -> bool {
709699 // Reset
@@ -712,8 +702,7 @@ void iterativeTilingAndFusionUntilExhaustion(
712702 f->walk <WalkOrder::PreOrder>([&unTiledOps](Operation *op) {
713703 if (isa<LoopLikeOpInterface>(op))
714704 return WalkResult::skip ();
715-
716- if (isa<TilingInterface>(op) && !op->use_empty ()) {
705+ if (isa<TilingInterface>(op)) {
717706 auto parentLoop = op->getParentOfType <LoopLikeOpInterface>();
718707 if (!parentLoop.getOperation ())
719708 unTiledOps.insert (op);
@@ -723,32 +712,32 @@ void iterativeTilingAndFusionUntilExhaustion(
723712 return !unTiledOps.empty ();
724713 };
725714
726- auto collectSingleTiledOpInLoop = [&f, &singleTiledOpInLoop ]() -> bool {
715+ auto collectSelfTiledOp = [&f, &selfTiledOp ]() -> bool {
727716 // Reset
728- singleTiledOpInLoop .clear ();
717+ selfTiledOp .clear ();
729718 // Walk through funcOp
730- f->walk ([&singleTiledOpInLoop ](Operation *op) {
719+ f->walk ([&selfTiledOp ](Operation *op) {
731720 // Target at certain kind of tiled op, such as matmul/conv implemented
732721 // by multiple level of nest loops and candidate slices for better
733722 // utilization of parallelism and memory hierarchy.
734- if (succeeded (isSingleTiledOpInLoop (op))) {
735- singleTiledOpInLoop .insert (op);
723+ if (succeeded (isSelfTiledOp (op))) {
724+ selfTiledOp .insert (op);
736725 }
737726 });
738- return !singleTiledOpInLoop .empty ();
727+ return !selfTiledOp .empty ();
739728 };
740729
741730 // Iterative tiling and fusion until exhaustion.
742731 while (collectUnTiledOps ()) {
743732 // If existing tiled op before tiling.
744- if (collectSingleTiledOpInLoop ()) {
733+ if (collectSelfTiledOp ()) {
745734 // Sort by topology
746- mlir::topologicalSort (singleTiledOpInLoop );
735+ mlir::topologicalSort (selfTiledOp );
747736 // Record if any fusion happens
748737 bool changed = false ;
749738 // Iteratively fuse in forward and backward fashion.
750- llvm::for_each (singleTiledOpInLoop , [&rewriter, &sliceOptions,
751- &changed](Operation *tiledOp) {
739+ llvm::for_each (selfTiledOp , [&rewriter, &sliceOptions,
740+ &changed](Operation *tiledOp) {
752741 changed |= succeeded (iterativelyFuseProducerAndConsumerOfTiledOp (
753742 rewriter, tiledOp, sliceOptions));
754743 });
@@ -774,7 +763,7 @@ void iterativeTilingAndFusionUntilExhaustion(
774763 std::cref (tsMap)));
775764 })) {
776765 // If no op can be tiled
777- break ;
766+ return ;
778767 }
779768 }
780769 }
0 commit comments