@@ -46,15 +46,14 @@ getClosestExtractSliceOfOperand(OpOperand &operand) {
46
46
}
47
47
48
48
Operation *defineOp = operand.get ().getDefiningOp ();
49
- if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(defineOp)) {
49
+ if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(defineOp))
50
50
return sliceOp;
51
- } else if (isa<linalg::FillOp, tensor::ExpandShapeOp,
52
- tensor::CollapseShapeOp>(defineOp)) {
53
- // For downstream cases
51
+ // For downstream cases
52
+ if (isa<linalg::FillOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp>(
53
+ defineOp))
54
54
return getClosestExtractSliceOfOperand (defineOp->getOpOperand (0 ));
55
- } else {
56
- return failure ();
57
- }
55
+
56
+ return failure ();
58
57
}
59
58
60
59
static FailureOr<OffsetSizeAndStrideOpInterface>
@@ -104,9 +103,8 @@ struct CandidateDefOrUse {
104
103
using CandidateSliceFilter = std::function<LogicalResult(
105
104
RewriterBase &, OffsetSizeAndStrideOpInterface, CandidateDefOrUse)>;
106
105
107
- using CandidateSliceComparer =
108
- std::function<int (RewriterBase &, OffsetSizeAndStrideOpInterface,
109
- OffsetSizeAndStrideOpInterface, CandidateDefOrUse)>;
106
+ using CandidateSliceComparer = std::function<int (
107
+ OffsetSizeAndStrideOpInterface, OffsetSizeAndStrideOpInterface)>;
110
108
111
109
static LogicalResult
112
110
noTilingOnReductionFilter (RewriterBase &rewriter,
@@ -325,22 +323,19 @@ computeTileSizeProductOfCandidate(OffsetSizeAndStrideOpInterface candidate) {
325
323
return totalSize;
326
324
}
327
325
328
- static int TilingSizeComparer (RewriterBase &rewriter,
329
- OffsetSizeAndStrideOpInterface candidateA,
330
- OffsetSizeAndStrideOpInterface candidateB,
331
- CandidateDefOrUse defOrUse) {
326
+ static int TilingSizeComparer (OffsetSizeAndStrideOpInterface candidateA,
327
+ OffsetSizeAndStrideOpInterface candidateB) {
332
328
FailureOr<int64_t > sizeProductA =
333
329
computeTileSizeProductOfCandidate (candidateA),
334
330
sizeProductB =
335
331
computeTileSizeProductOfCandidate (candidateB);
336
332
if (failed (sizeProductA) || failed (sizeProductB))
337
333
return 0 ;
338
334
// deal with equality
339
- if (*sizeProductA == *sizeProductB) {
335
+ if (*sizeProductA == *sizeProductB)
340
336
return 0 ;
341
- } else {
342
- return *sizeProductA < *sizeProductB ? -1 : 1 ;
343
- }
337
+
338
+ return *sizeProductA < *sizeProductB ? -1 : 1 ;
344
339
}
345
340
346
341
struct CandidateSliceComparerPipeLine
@@ -352,17 +347,15 @@ struct CandidateSliceComparerPipeLine
352
347
return SmallVector<CandidateSliceComparer>{TilingSizeComparer};
353
348
}
354
349
355
- bool compare (RewriterBase &rewriter,
356
- OffsetSizeAndStrideOpInterface candidateA,
357
- OffsetSizeAndStrideOpInterface candidateB,
358
- CandidateDefOrUse defOrUse) const {
350
+ bool compare (OffsetSizeAndStrideOpInterface candidateA,
351
+ OffsetSizeAndStrideOpInterface candidateB) const {
359
352
// deal with weak order
360
353
int cmpResult = -1 ;
361
- for ( auto &fn : candidateProcessFn) {
362
- cmpResult = fn (rewriter, candidateA, candidateB, defOrUse);
363
- if ( cmpResult != 0 )
364
- break ;
365
- }
354
+ llvm::any_of (candidateProcessFn, [&cmpResult, &candidateA, &candidateB](
355
+ const CandidateSliceComparer &fn) {
356
+ cmpResult = fn (candidateA, candidateB);
357
+ return cmpResult != 0 ;
358
+ });
366
359
return cmpResult == -1 ;
367
360
}
368
361
};
@@ -389,6 +382,29 @@ struct CandidateSliceOptions {
389
382
}
390
383
};
391
384
385
+ static FailureOr<OffsetSizeAndStrideOpInterface> filterAndSelectCandidate (
386
+ RewriterBase &rewriter,
387
+ ArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceList,
388
+ const CandidateDefOrUse &defOrUse, const CandidateSliceOptions &options) {
389
+ SmallVector<OffsetSizeAndStrideOpInterface> validCandidates =
390
+ llvm::to_vector (llvm::make_filter_range (
391
+ candidateSliceList,
392
+ [&rewriter, &options,
393
+ &defOrUse](const OffsetSizeAndStrideOpInterface &candidate) {
394
+ return succeeded (
395
+ options.filterPipeLine .filter (rewriter, candidate, defOrUse));
396
+ }));
397
+ if (validCandidates.empty ())
398
+ return failure ();
399
+
400
+ OffsetSizeAndStrideOpInterface bestCandidate = *llvm::min_element (
401
+ validCandidates, [&options](OffsetSizeAndStrideOpInterface &candidateA,
402
+ OffsetSizeAndStrideOpInterface &candidateB) {
403
+ return options.comparerPipeLine .compare (candidateA, candidateB);
404
+ });
405
+ return bestCandidate;
406
+ }
407
+
392
408
std::optional<scf::SCFFuseProducerOfSliceResult>
393
409
tileAndFuseProducerOfOpOperand (RewriterBase &rewriter, OpOperand &operand,
394
410
const CandidateSliceOptions &options) {
@@ -412,31 +428,20 @@ tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
412
428
return std::nullopt;
413
429
414
430
CandidateDefOrUse defOrUse{*realProducer};
415
- // d. Filter out invalid candidates
416
- SmallVector<tensor::ExtractSliceOp> validCandidates =
417
- llvm::to_vector (llvm::make_filter_range (
418
- backwardSlice,
419
- [&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidate) {
420
- return succeeded (options.filterPipeLine .filter (
421
- rewriter,
422
- cast<OffsetSizeAndStrideOpInterface>(candidate.getOperation ()),
423
- defOrUse));
424
- }));
425
- if (validCandidates.empty ())
431
+ // d. Filter out invalid candidates and select best candidates
432
+ SmallVector<OffsetSizeAndStrideOpInterface> ossBackwardSlice =
433
+ llvm::map_to_vector (backwardSlice,
434
+ [](tensor::ExtractSliceOp &extractSlice) {
435
+ return cast<OffsetSizeAndStrideOpInterface>(
436
+ extractSlice.getOperation ());
437
+ });
438
+ FailureOr<OffsetSizeAndStrideOpInterface> bestCandidate =
439
+ filterAndSelectCandidate (rewriter, ossBackwardSlice, defOrUse, options);
440
+ if (failed (bestCandidate))
426
441
return std::nullopt;
427
- // e. Select best candidates by Cost Model
428
- tensor::ExtractSliceOp bestCandidate = *llvm::min_element (
429
- validCandidates,
430
- [&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidateA,
431
- tensor::ExtractSliceOp &candidateB) {
432
- return options.comparerPipeLine .compare (
433
- rewriter,
434
- cast<OffsetSizeAndStrideOpInterface>(candidateA.getOperation ()),
435
- cast<OffsetSizeAndStrideOpInterface>(candidateB.getOperation ()),
436
- defOrUse);
437
- });
438
- // f. call tiling interface
439
- return scfX::tileAndFuseProducerOfSlice (rewriter, bestCandidate);
442
+
443
+ // e. call tiling interface
444
+ return scfX::tileAndFuseProducerOfSlice (rewriter, *bestCandidate);
440
445
}
441
446
442
447
std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
@@ -464,28 +469,15 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
464
469
continue ;
465
470
466
471
CandidateDefOrUse defOrUse{useOperand};
467
- // d. Filter out invalid candidates
468
- SmallVector<OffsetSizeAndStrideOpInterface> validCandidates =
469
- llvm::to_vector (llvm::make_filter_range (
470
- forwardSlice, [&rewriter, &options, &defOrUse](
471
- const OffsetSizeAndStrideOpInterface &candidate) {
472
- return succeeded (
473
- options.filterPipeLine .filter (rewriter, candidate, defOrUse));
474
- }));
475
- if (validCandidates.empty ())
472
+ // d. Filter out invalid candidates and select best candidates
473
+ FailureOr<OffsetSizeAndStrideOpInterface> bestCandidate =
474
+ filterAndSelectCandidate (rewriter, forwardSlice, defOrUse, options);
475
+ if (failed (bestCandidate))
476
476
continue ;
477
477
478
- // e. Select best candidates by Cost Model
479
- OffsetSizeAndStrideOpInterface bestCandidate = *llvm::min_element (
480
- validCandidates, [&rewriter, &options, &defOrUse](
481
- const OffsetSizeAndStrideOpInterface &candidateA,
482
- const OffsetSizeAndStrideOpInterface &candidateB) {
483
- return options.comparerPipeLine .compare (rewriter, candidateA,
484
- candidateB, defOrUse);
485
- });
486
- // f. call tiling interface
478
+ // e. call tiling interface
487
479
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
488
- scfX::tileAndFuseConsumerOfSlice (rewriter, bestCandidate);
480
+ scfX::tileAndFuseConsumerOfSlice (rewriter, * bestCandidate);
489
481
490
482
if (succeeded (fusedResult)) {
491
483
fusedResultList.push_back (*fusedResult);
@@ -496,7 +488,7 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
496
488
};
497
489
SmallVector<LoopLikeOpInterface> outerLoops =
498
490
scfX::getOuterNestLoopsWhile (
499
- bestCandidate->getParentOfType <LoopLikeOpInterface>(),
491
+ (* bestCandidate) ->getParentOfType <LoopLikeOpInterface>(),
500
492
whileProducerOutOfLoopBlock);
501
493
// g. Manually run cse on region which contains top-level loop of
502
494
// candidate slice in avoid of conflict with subsequent
@@ -506,11 +498,10 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
506
498
{*outerLoops.front ()->getParentRegion ()});
507
499
}
508
500
}
509
- if (fusedResultList.empty ()) {
501
+ if (fusedResultList.empty ())
510
502
return std::nullopt;
511
- } else {
512
- return fusedResultList;
513
- }
503
+
504
+ return fusedResultList;
514
505
}
515
506
516
507
// / Target at following general topology:
@@ -527,7 +518,7 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
527
518
LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp (
528
519
RewriterBase &rewriter, Operation *tiledOp,
529
520
const CandidateSliceOptions &options) {
530
- int numTiledOps = 0 ;
521
+ unsigned numTiledOps = 0 ;
531
522
std::deque<Operation *> tiledOpList = {tiledOp};
532
523
while (!tiledOpList.empty ()) {
533
524
tiledOp = tiledOpList.front ();
@@ -552,7 +543,7 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
552
543
return success (numTiledOps > 1 );
553
544
}
554
545
555
- // / What is single tiled op in loop ?
546
+ // / What is self tiled op compared with other fused op ?
556
547
// / E.g.
557
548
// / %1 = scf.for(){
558
549
// / %2 = scf.for(){
@@ -562,7 +553,7 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
562
553
// / yield %5
563
554
// / }
564
555
// / }
565
- static LogicalResult isSingleTiledOpInLoop (Operation *targetOp) {
556
+ static LogicalResult isSelfTiledOp (Operation *targetOp) {
566
557
// 0. check tilable
567
558
if (!isa<TilingInterface>(targetOp))
568
559
return failure ();
@@ -694,16 +685,15 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op,
694
685
if (succeeded (tilingResult)) {
695
686
rewriter.replaceOp (op, tilingResult->replacements );
696
687
return true ;
697
- } else {
698
- return false ;
699
688
}
689
+ return false ;
700
690
}
701
691
702
692
void iterativeTilingAndFusionUntilExhaustion (
703
693
RewriterBase &rewriter, func::FuncOp &f,
704
694
const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) {
705
695
// Collect untiled and tiled ops respectively
706
- llvm::SetVector<Operation *> singleTiledOpInLoop , unTiledOps;
696
+ llvm::SetVector<Operation *> selfTiledOp , unTiledOps;
707
697
708
698
auto collectUnTiledOps = [&f, &unTiledOps]() -> bool {
709
699
// Reset
@@ -712,8 +702,7 @@ void iterativeTilingAndFusionUntilExhaustion(
712
702
f->walk <WalkOrder::PreOrder>([&unTiledOps](Operation *op) {
713
703
if (isa<LoopLikeOpInterface>(op))
714
704
return WalkResult::skip ();
715
-
716
- if (isa<TilingInterface>(op) && !op->use_empty ()) {
705
+ if (isa<TilingInterface>(op)) {
717
706
auto parentLoop = op->getParentOfType <LoopLikeOpInterface>();
718
707
if (!parentLoop.getOperation ())
719
708
unTiledOps.insert (op);
@@ -723,32 +712,32 @@ void iterativeTilingAndFusionUntilExhaustion(
723
712
return !unTiledOps.empty ();
724
713
};
725
714
726
- auto collectSingleTiledOpInLoop = [&f, &singleTiledOpInLoop ]() -> bool {
715
+ auto collectSelfTiledOp = [&f, &selfTiledOp ]() -> bool {
727
716
// Reset
728
- singleTiledOpInLoop .clear ();
717
+ selfTiledOp .clear ();
729
718
// Walk through funcOp
730
- f->walk ([&singleTiledOpInLoop ](Operation *op) {
719
+ f->walk ([&selfTiledOp ](Operation *op) {
731
720
// Target at certain kind of tiled op, such as matmul/conv implemented
732
721
// by multiple level of nest loops and candidate slices for better
733
722
// utilization of parallelism and memory hierarchy.
734
- if (succeeded (isSingleTiledOpInLoop (op))) {
735
- singleTiledOpInLoop .insert (op);
723
+ if (succeeded (isSelfTiledOp (op))) {
724
+ selfTiledOp .insert (op);
736
725
}
737
726
});
738
- return !singleTiledOpInLoop .empty ();
727
+ return !selfTiledOp .empty ();
739
728
};
740
729
741
730
// Iterative tiling and fusion until exhaustion.
742
731
while (collectUnTiledOps ()) {
743
732
// If existing tiled op before tiling.
744
- if (collectSingleTiledOpInLoop ()) {
733
+ if (collectSelfTiledOp ()) {
745
734
// Sort by topology
746
- mlir::topologicalSort (singleTiledOpInLoop );
735
+ mlir::topologicalSort (selfTiledOp );
747
736
// Record if any fusion happens
748
737
bool changed = false ;
749
738
// Iteratively fuse in forward and backward fashion.
750
- llvm::for_each (singleTiledOpInLoop , [&rewriter, &sliceOptions,
751
- &changed](Operation *tiledOp) {
739
+ llvm::for_each (selfTiledOp , [&rewriter, &sliceOptions,
740
+ &changed](Operation *tiledOp) {
752
741
changed |= succeeded (iterativelyFuseProducerAndConsumerOfTiledOp (
753
742
rewriter, tiledOp, sliceOptions));
754
743
});
@@ -774,7 +763,7 @@ void iterativeTilingAndFusionUntilExhaustion(
774
763
std::cref (tsMap)));
775
764
})) {
776
765
// If no op can be tiled
777
- break ;
766
+ return ;
778
767
}
779
768
}
780
769
}
0 commit comments