@@ -41,9 +41,8 @@ static FailureOr<tensor::ExtractSliceOp>
4141getClosestExtractSliceOfOperand (OpOperand &operand) {
4242 if (auto iterArg = dyn_cast<BlockArgument>(operand.get ())) {
4343 if (auto loop =
44- dyn_cast<LoopLikeOpInterface>(iterArg.getOwner ()->getParentOp ())) {
44+ dyn_cast<LoopLikeOpInterface>(iterArg.getOwner ()->getParentOp ()))
4545 return getClosestExtractSliceOfOperand (*loop.getTiedLoopInit (iterArg));
46- }
4746 }
4847
4948 Operation *defineOp = operand.get ().getDefiningOp ();
@@ -69,10 +68,9 @@ getClosestInsertSliceOfResult(OpResult result) {
6968 sliceOp =
7069 dyn_cast<OffsetSizeAndStrideOpInterface>(useOfResult.getOwner ());
7170 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOfResult.getOwner ())) {
72- if (auto loop = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ())) {
71+ if (auto loop = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ()))
7372 return getClosestInsertSliceOfResult (
7473 loop->getResult (useOfResult.getOperandNumber ()));
75- }
7674 }
7775 }
7876
@@ -138,9 +136,8 @@ noTilingOnReductionFilter(RewriterBase &rewriter,
138136 presburger::BoundType::UB, tileSizes[resultExpr.index ()], nullptr ,
139137 true );
140138 if (!cstIterDomain || failed (cstTileSizes) ||
141- cstIterDomain != cstTileSizes) {
139+ cstIterDomain != cstTileSizes)
142140 return failure ();
143- }
144141 }
145142 }
146143 return success ();
@@ -246,9 +243,8 @@ SingleCandidateInBlockFilter(RewriterBase &rewriter,
246243 scfX::getRealProducerOfExtractSliceOp (otherCandidate,
247244 backwardSlice);
248245 if (succeeded (realProducer) &&
249- realProducer->getDefiningOp () == defOrUse.ownerOp ) {
246+ realProducer->getDefiningOp () == defOrUse.ownerOp )
250247 return failure ();
251- }
252248 } else {
253249 SmallVector<OffsetSizeAndStrideOpInterface> forwardSlice;
254250 FailureOr<SmallVector<OpOperand *>> realConsumers =
@@ -257,9 +253,8 @@ SingleCandidateInBlockFilter(RewriterBase &rewriter,
257253 if (succeeded (realConsumers) &&
258254 llvm::any_of (*realConsumers, [&defOrUse](OpOperand *use) {
259255 return use->getOwner () == defOrUse.ownerOp ;
260- })) {
256+ }))
261257 return failure ();
262- }
263258 }
264259 }
265260 }
@@ -338,9 +333,8 @@ static int TilingSizeComparer(RewriterBase &rewriter,
338333 computeTileSizeProductOfCandidate (candidateA),
339334 sizeProductB =
340335 computeTileSizeProductOfCandidate (candidateB);
341- if (failed (sizeProductA) || failed (sizeProductB)) {
336+ if (failed (sizeProductA) || failed (sizeProductB))
342337 return 0 ;
343- }
344338 // deal with equality
345339 if (*sizeProductA == *sizeProductB) {
346340 return 0 ;
@@ -401,17 +395,17 @@ tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
401395 // a. Find the closest sliceOp
402396 FailureOr<tensor::ExtractSliceOp> closestSliceOp =
403397 getClosestExtractSliceOfOperand (operand);
404- if (failed (closestSliceOp)) {
398+ if (failed (closestSliceOp))
405399 return std::nullopt ;
406- }
400+
407401 // b. Find the real producer and collect the sliceOp chain during backward
408402 // stage, sorted from inner to outer.
409403 SmallVector<tensor::ExtractSliceOp> backwardSlice;
410404 FailureOr<OpResult> realProducer =
411405 scfX::getRealProducerOfExtractSliceOp (*closestSliceOp, backwardSlice);
412- if (failed (realProducer)) {
406+ if (failed (realProducer))
413407 return std::nullopt ;
414- }
408+
415409 // c. Check the producer of root source if is tilable.
416410 Operation *producer = realProducer->getDefiningOp <TilingInterface>();
417411 if (!producer)
@@ -451,17 +445,16 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
451445 // a. Find the closest sliceOp
452446 FailureOr<tensor::ExtractSliceOp> closestSliceOp =
453447 getClosestInsertSliceOfResult (result);
454- if (failed (closestSliceOp)) {
448+ if (failed (closestSliceOp))
455449 return std::nullopt ;
456- }
450+
457451 // b. Find the real consumers and collect the sliceOp chain during forward
458452 // stage, sorted from inner to outer.
459453 SmallVector<OffsetSizeAndStrideOpInterface> forwardSlice;
460454 FailureOr<SmallVector<OpOperand *>> realConsumers =
461455 scfX::getRealConsumersFromInsertSliceOp (*closestSliceOp, forwardSlice);
462- if (failed (realConsumers)) {
456+ if (failed (realConsumers))
463457 return std::nullopt ;
464- }
465458
466459 SmallVector<scf::SCFFuseConsumerOfSliceResult> fusedResultList;
467460 for (auto useOperand : *realConsumers) {
@@ -543,18 +536,16 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
543536 // fuse producer
544537 for (OpOperand &operand : tiledOp->getOpOperands ()) {
545538 if (std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult =
546- tileAndFuseProducerOfOpOperand (rewriter, operand, options)) {
539+ tileAndFuseProducerOfOpOperand (rewriter, operand, options))
547540 tiledOpList.push_back (fuseProducerResult.value ().tiledOps [0 ]);
548- }
549541 }
550542 // fuse consumer(s)
551543 for (OpResult result : tiledOp->getResults ()) {
552544 if (std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
553545 fuseConsumerResults =
554546 tileAndFuseConsumerOfOpResult (rewriter, result, options)) {
555- for (auto &fuseConsumerResult : *fuseConsumerResults) {
547+ for (auto &fuseConsumerResult : *fuseConsumerResults)
556548 tiledOpList.push_back (fuseConsumerResult.tiledOps [0 ]);
557- }
558549 }
559550 }
560551 }
@@ -573,30 +564,26 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
573564// / }
574565static LogicalResult isSingleTiledOpInLoop (Operation *targetOp) {
575566 // 0. check tilable
576- if (!isa<TilingInterface>(targetOp)) {
567+ if (!isa<TilingInterface>(targetOp))
577568 return failure ();
578- }
579569 // 1. check parentOp
580570 auto forOp = targetOp->getParentOfType <LoopLikeOpInterface>();
581- if (!forOp) {
571+ if (!forOp)
582572 return failure ();
583- }
584573 // 2. check single one tiling interface in loop body
585574 auto walkResult = forOp->walk ([&targetOp](TilingInterface op) {
586575 // some special op maybe already deal with in template
587576 if (isa<linalg::FillOp, linalg::CopyOp>(op))
588577 return WalkResult::skip ();
589578 return op != targetOp ? WalkResult::interrupt () : WalkResult::advance ();
590579 });
591- if (walkResult.wasInterrupted ()) {
580+ if (walkResult.wasInterrupted ())
592581 return failure ();
593- }
594582 // 3. check whether has either extract or insert slice op
595583 walkResult = forOp->walk (
596584 [](tensor::ExtractSliceOp) { return WalkResult::interrupt (); });
597- if (walkResult.wasInterrupted ()) {
585+ if (walkResult.wasInterrupted ())
598586 return success ();
599- }
600587 walkResult = forOp->walk (
601588 [](tensor::InsertSliceOp) { return WalkResult::interrupt (); });
602589 return success (walkResult.wasInterrupted ());
@@ -690,9 +677,8 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op,
690677 // word, all reduction dimensions should not be tiled.
691678 if (iterType == utils::IteratorType::parallel &&
692679 (en != iteratorTypes.size () - 1 ||
693- llvm::count (iteratorTypes, utils::IteratorType::reduction))) {
680+ llvm::count (iteratorTypes, utils::IteratorType::reduction)))
694681 defaultTileSize[en] = rewriter.getIndexAttr (1 );
695- }
696682 }
697683 }
698684 // If the tile sizes are all zero, no tiling would happen.
@@ -724,14 +710,13 @@ void iterativeTilingAndFusionUntilExhaustion(
724710 unTiledOps.clear ();
725711 // Pre-order walk through funcOp
726712 f->walk <WalkOrder::PreOrder>([&unTiledOps](Operation *op) {
727- if (isa<LoopLikeOpInterface>(op)) {
713+ if (isa<LoopLikeOpInterface>(op))
728714 return WalkResult::skip ();
729- }
715+
730716 if (isa<TilingInterface>(op) && !op->use_empty ()) {
731717 auto parentLoop = op->getParentOfType <LoopLikeOpInterface>();
732- if (!parentLoop.getOperation ()) {
718+ if (!parentLoop.getOperation ())
733719 unTiledOps.insert (op);
734- }
735720 }
736721 return WalkResult::advance ();
737722 });
@@ -767,9 +752,8 @@ void iterativeTilingAndFusionUntilExhaustion(
767752 changed |= succeeded (iterativelyFuseProducerAndConsumerOfTiledOp (
768753 rewriter, tiledOp, sliceOptions));
769754 });
770- if (changed) {
755+ if (changed)
771756 (void )mlir::simplifyRegions (rewriter, {f.getRegion ()});
772- }
773757 } else {
774758 // Auto tiling with default tile size if no tiled op found. Follow tiling
775759 // priority based on OpTy: `Contraction`->`Reduction`->`Elementwise`.
@@ -803,15 +787,15 @@ static OpTileSizeMap defaultTileSizeParser(ArrayRef<std::string> strArgs) {
803787 for (auto str : strArgs) {
804788 str.erase (llvm::remove_if (str, llvm::isSpace), str.end ());
805789 size_t pos = str.find (" :" );
806- if (pos == std::string::npos) {
790+ if (pos == std::string::npos)
807791 llvm_unreachable (warning);
808- }
792+
809793 std::string opType = str.substr (0 , pos);
810794 std::string strTileSize = str.erase (0 , pos + 1 );
811795 if (strTileSize.size () <= 2 || strTileSize.front () != ' {' ||
812- strTileSize.back () != ' }' ) {
796+ strTileSize.back () != ' }' )
813797 llvm_unreachable (warning);
814- }
798+
815799 strTileSize = strTileSize.substr (1 , strTileSize.size () - 2 );
816800 SmallVector<int64_t > intTileSize;
817801 while ((pos = strTileSize.find (" ," )) != std::string::npos) {
0 commit comments