Skip to content

Commit 443100f

Browse files
committed
fix second portion comment
1 parent 3ad0e30 commit 443100f

File tree

4 files changed

+123
-123
lines changed

4 files changed

+123
-123
lines changed

lib/gc/Transforms/IterativeTilingAndFusion.cpp

+83-94
Original file line number | Diff line number | Diff line change
@@ -46,15 +46,14 @@ getClosestExtractSliceOfOperand(OpOperand &operand) {
4646
}
4747

4848
Operation *defineOp = operand.get().getDefiningOp();
49-
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(defineOp)) {
49+
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(defineOp))
5050
return sliceOp;
51-
} else if (isa<linalg::FillOp, tensor::ExpandShapeOp,
52-
tensor::CollapseShapeOp>(defineOp)) {
53-
// For downstream cases
51+
// For downstream cases
52+
if (isa<linalg::FillOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp>(
53+
defineOp))
5454
return getClosestExtractSliceOfOperand(defineOp->getOpOperand(0));
55-
} else {
56-
return failure();
57-
}
55+
56+
return failure();
5857
}
5958

6059
static FailureOr<OffsetSizeAndStrideOpInterface>
@@ -104,9 +103,8 @@ struct CandidateDefOrUse {
104103
using CandidateSliceFilter = std::function<LogicalResult(
105104
RewriterBase &, OffsetSizeAndStrideOpInterface, CandidateDefOrUse)>;
106105

107-
using CandidateSliceComparer =
108-
std::function<int(RewriterBase &, OffsetSizeAndStrideOpInterface,
109-
OffsetSizeAndStrideOpInterface, CandidateDefOrUse)>;
106+
using CandidateSliceComparer = std::function<int(
107+
OffsetSizeAndStrideOpInterface, OffsetSizeAndStrideOpInterface)>;
110108

111109
static LogicalResult
112110
noTilingOnReductionFilter(RewriterBase &rewriter,
@@ -325,22 +323,19 @@ computeTileSizeProductOfCandidate(OffsetSizeAndStrideOpInterface candidate) {
325323
return totalSize;
326324
}
327325

328-
static int TilingSizeComparer(RewriterBase &rewriter,
329-
OffsetSizeAndStrideOpInterface candidateA,
330-
OffsetSizeAndStrideOpInterface candidateB,
331-
CandidateDefOrUse defOrUse) {
326+
static int TilingSizeComparer(OffsetSizeAndStrideOpInterface candidateA,
327+
OffsetSizeAndStrideOpInterface candidateB) {
332328
FailureOr<int64_t> sizeProductA =
333329
computeTileSizeProductOfCandidate(candidateA),
334330
sizeProductB =
335331
computeTileSizeProductOfCandidate(candidateB);
336332
if (failed(sizeProductA) || failed(sizeProductB))
337333
return 0;
338334
// deal with equality
339-
if (*sizeProductA == *sizeProductB) {
335+
if (*sizeProductA == *sizeProductB)
340336
return 0;
341-
} else {
342-
return *sizeProductA < *sizeProductB ? -1 : 1;
343-
}
337+
338+
return *sizeProductA < *sizeProductB ? -1 : 1;
344339
}
345340

346341
struct CandidateSliceComparerPipeLine
@@ -352,17 +347,15 @@ struct CandidateSliceComparerPipeLine
352347
return SmallVector<CandidateSliceComparer>{TilingSizeComparer};
353348
}
354349

355-
bool compare(RewriterBase &rewriter,
356-
OffsetSizeAndStrideOpInterface candidateA,
357-
OffsetSizeAndStrideOpInterface candidateB,
358-
CandidateDefOrUse defOrUse) const {
350+
bool compare(OffsetSizeAndStrideOpInterface candidateA,
351+
OffsetSizeAndStrideOpInterface candidateB) const {
359352
// deal with weak order
360353
int cmpResult = -1;
361-
for (auto &fn : candidateProcessFn) {
362-
cmpResult = fn(rewriter, candidateA, candidateB, defOrUse);
363-
if (cmpResult != 0)
364-
break;
365-
}
354+
llvm::any_of(candidateProcessFn, [&cmpResult, &candidateA, &candidateB](
355+
const CandidateSliceComparer &fn) {
356+
cmpResult = fn(candidateA, candidateB);
357+
return cmpResult != 0;
358+
});
366359
return cmpResult == -1;
367360
}
368361
};
@@ -389,6 +382,29 @@ struct CandidateSliceOptions {
389382
}
390383
};
391384

385+
/// Filter `candidateSliceList` with `options.filterPipeLine` and return the
/// minimum surviving candidate under `options.comparerPipeLine.compare`.
///
/// \param rewriter            forwarded to the filter pipeline.
/// \param candidateSliceList  candidate slices to choose from.
/// \param defOrUse            forwarded to the filter pipeline together with
///                            each candidate.
/// \param options             supplies both the filter and comparer pipelines.
/// \returns the selected candidate, or failure() when no candidate passes the
///          filter pipeline.
static FailureOr<OffsetSizeAndStrideOpInterface> filterAndSelectCandidate(
386+
RewriterBase &rewriter,
387+
ArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceList,
388+
const CandidateDefOrUse &defOrUse, const CandidateSliceOptions &options) {
389+
// Keep only the candidates for which the filter pipeline reports success.
SmallVector<OffsetSizeAndStrideOpInterface> validCandidates =
390+
llvm::to_vector(llvm::make_filter_range(
391+
candidateSliceList,
392+
[&rewriter, &options,
393+
&defOrUse](const OffsetSizeAndStrideOpInterface &candidate) {
394+
return succeeded(
395+
options.filterPipeLine.filter(rewriter, candidate, defOrUse));
396+
}));
397+
// No candidate survived filtering: report failure to the caller.
if (validCandidates.empty())
398+
return failure();
399+
400+
// Select the minimum element, using the comparer pipeline as the ordering.
// NOTE(review): the lambda takes non-const references although compare()
// receives both candidates by value; `const &` looks sufficient — confirm.
OffsetSizeAndStrideOpInterface bestCandidate = *llvm::min_element(
401+
validCandidates, [&options](OffsetSizeAndStrideOpInterface &candidateA,
402+
OffsetSizeAndStrideOpInterface &candidateB) {
403+
return options.comparerPipeLine.compare(candidateA, candidateB);
404+
});
405+
return bestCandidate;
406+
}
407+
392408
std::optional<scf::SCFFuseProducerOfSliceResult>
393409
tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
394410
const CandidateSliceOptions &options) {
@@ -412,31 +428,20 @@ tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
412428
return std::nullopt;
413429

414430
CandidateDefOrUse defOrUse{*realProducer};
415-
// d. Filter out invalid candidates
416-
SmallVector<tensor::ExtractSliceOp> validCandidates =
417-
llvm::to_vector(llvm::make_filter_range(
418-
backwardSlice,
419-
[&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidate) {
420-
return succeeded(options.filterPipeLine.filter(
421-
rewriter,
422-
cast<OffsetSizeAndStrideOpInterface>(candidate.getOperation()),
423-
defOrUse));
424-
}));
425-
if (validCandidates.empty())
431+
// d. Filter out invalid candidates and select best candidates
432+
SmallVector<OffsetSizeAndStrideOpInterface> ossBackwardSlice =
433+
llvm::map_to_vector(backwardSlice,
434+
[](tensor::ExtractSliceOp &extractSlice) {
435+
return cast<OffsetSizeAndStrideOpInterface>(
436+
extractSlice.getOperation());
437+
});
438+
FailureOr<OffsetSizeAndStrideOpInterface> bestCandidate =
439+
filterAndSelectCandidate(rewriter, ossBackwardSlice, defOrUse, options);
440+
if (failed(bestCandidate))
426441
return std::nullopt;
427-
// e. Select best candidates by Cost Model
428-
tensor::ExtractSliceOp bestCandidate = *llvm::min_element(
429-
validCandidates,
430-
[&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidateA,
431-
tensor::ExtractSliceOp &candidateB) {
432-
return options.comparerPipeLine.compare(
433-
rewriter,
434-
cast<OffsetSizeAndStrideOpInterface>(candidateA.getOperation()),
435-
cast<OffsetSizeAndStrideOpInterface>(candidateB.getOperation()),
436-
defOrUse);
437-
});
438-
// f. call tiling interface
439-
return scfX::tileAndFuseProducerOfSlice(rewriter, bestCandidate);
442+
443+
// e. call tiling interface
444+
return scfX::tileAndFuseProducerOfSlice(rewriter, *bestCandidate);
440445
}
441446

442447
std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
@@ -464,28 +469,15 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
464469
continue;
465470

466471
CandidateDefOrUse defOrUse{useOperand};
467-
// d. Filter out invalid candidates
468-
SmallVector<OffsetSizeAndStrideOpInterface> validCandidates =
469-
llvm::to_vector(llvm::make_filter_range(
470-
forwardSlice, [&rewriter, &options, &defOrUse](
471-
const OffsetSizeAndStrideOpInterface &candidate) {
472-
return succeeded(
473-
options.filterPipeLine.filter(rewriter, candidate, defOrUse));
474-
}));
475-
if (validCandidates.empty())
472+
// d. Filter out invalid candidates and select best candidates
473+
FailureOr<OffsetSizeAndStrideOpInterface> bestCandidate =
474+
filterAndSelectCandidate(rewriter, forwardSlice, defOrUse, options);
475+
if (failed(bestCandidate))
476476
continue;
477477

478-
// e. Select best candidates by Cost Model
479-
OffsetSizeAndStrideOpInterface bestCandidate = *llvm::min_element(
480-
validCandidates, [&rewriter, &options, &defOrUse](
481-
const OffsetSizeAndStrideOpInterface &candidateA,
482-
const OffsetSizeAndStrideOpInterface &candidateB) {
483-
return options.comparerPipeLine.compare(rewriter, candidateA,
484-
candidateB, defOrUse);
485-
});
486-
// f. call tiling interface
478+
// e. call tiling interface
487479
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
488-
scfX::tileAndFuseConsumerOfSlice(rewriter, bestCandidate);
480+
scfX::tileAndFuseConsumerOfSlice(rewriter, *bestCandidate);
489481

490482
if (succeeded(fusedResult)) {
491483
fusedResultList.push_back(*fusedResult);
@@ -496,7 +488,7 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
496488
};
497489
SmallVector<LoopLikeOpInterface> outerLoops =
498490
scfX::getOuterNestLoopsWhile(
499-
bestCandidate->getParentOfType<LoopLikeOpInterface>(),
491+
(*bestCandidate)->getParentOfType<LoopLikeOpInterface>(),
500492
whileProducerOutOfLoopBlock);
501493
// g. Manually run cse on region which contains top-level loop of
502494
// candidate slice in avoid of conflict with subsequent
@@ -506,11 +498,10 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
506498
{*outerLoops.front()->getParentRegion()});
507499
}
508500
}
509-
if (fusedResultList.empty()) {
501+
if (fusedResultList.empty())
510502
return std::nullopt;
511-
} else {
512-
return fusedResultList;
513-
}
503+
504+
return fusedResultList;
514505
}
515506

516507
/// Target at following general topology:
@@ -527,7 +518,7 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
527518
LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
528519
RewriterBase &rewriter, Operation *tiledOp,
529520
const CandidateSliceOptions &options) {
530-
int numTiledOps = 0;
521+
unsigned numTiledOps = 0;
531522
std::deque<Operation *> tiledOpList = {tiledOp};
532523
while (!tiledOpList.empty()) {
533524
tiledOp = tiledOpList.front();
@@ -552,7 +543,7 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
552543
return success(numTiledOps > 1);
553544
}
554545

555-
/// What is single tiled op in loop?
546+
/// What is a self-tiled op, as opposed to other fused ops?
556547
/// E.g.
557548
/// %1 = scf.for(){
558549
/// %2 = scf.for(){
@@ -562,7 +553,7 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
562553
/// yield %5
563554
/// }
564555
/// }
565-
static LogicalResult isSingleTiledOpInLoop(Operation *targetOp) {
556+
static LogicalResult isSelfTiledOp(Operation *targetOp) {
566557
// 0. check tilable
567558
if (!isa<TilingInterface>(targetOp))
568559
return failure();
@@ -694,16 +685,15 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op,
694685
if (succeeded(tilingResult)) {
695686
rewriter.replaceOp(op, tilingResult->replacements);
696687
return true;
697-
} else {
698-
return false;
699688
}
689+
return false;
700690
}
701691

702692
void iterativeTilingAndFusionUntilExhaustion(
703693
RewriterBase &rewriter, func::FuncOp &f,
704694
const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) {
705695
// Collect untiled and tiled ops respectively
706-
llvm::SetVector<Operation *> singleTiledOpInLoop, unTiledOps;
696+
llvm::SetVector<Operation *> selfTiledOp, unTiledOps;
707697

708698
auto collectUnTiledOps = [&f, &unTiledOps]() -> bool {
709699
// Reset
@@ -712,8 +702,7 @@ void iterativeTilingAndFusionUntilExhaustion(
712702
f->walk<WalkOrder::PreOrder>([&unTiledOps](Operation *op) {
713703
if (isa<LoopLikeOpInterface>(op))
714704
return WalkResult::skip();
715-
716-
if (isa<TilingInterface>(op) && !op->use_empty()) {
705+
if (isa<TilingInterface>(op)) {
717706
auto parentLoop = op->getParentOfType<LoopLikeOpInterface>();
718707
if (!parentLoop.getOperation())
719708
unTiledOps.insert(op);
@@ -723,32 +712,32 @@ void iterativeTilingAndFusionUntilExhaustion(
723712
return !unTiledOps.empty();
724713
};
725714

726-
auto collectSingleTiledOpInLoop = [&f, &singleTiledOpInLoop]() -> bool {
715+
auto collectSelfTiledOp = [&f, &selfTiledOp]() -> bool {
727716
// Reset
728-
singleTiledOpInLoop.clear();
717+
selfTiledOp.clear();
729718
// Walk through funcOp
730-
f->walk([&singleTiledOpInLoop](Operation *op) {
719+
f->walk([&selfTiledOp](Operation *op) {
731720
// Target at certain kind of tiled op, such as matmul/conv implemented
732721
// by multiple level of nest loops and candidate slices for better
733722
// utilization of parallelism and memory hierarchy.
734-
if (succeeded(isSingleTiledOpInLoop(op))) {
735-
singleTiledOpInLoop.insert(op);
723+
if (succeeded(isSelfTiledOp(op))) {
724+
selfTiledOp.insert(op);
736725
}
737726
});
738-
return !singleTiledOpInLoop.empty();
727+
return !selfTiledOp.empty();
739728
};
740729

741730
// Iterative tiling and fusion until exhaustion.
742731
while (collectUnTiledOps()) {
743732
// If existing tiled op before tiling.
744-
if (collectSingleTiledOpInLoop()) {
733+
if (collectSelfTiledOp()) {
745734
// Sort by topology
746-
mlir::topologicalSort(singleTiledOpInLoop);
735+
mlir::topologicalSort(selfTiledOp);
747736
// Record if any fusion happens
748737
bool changed = false;
749738
// Iteratively fuse in forward and backward fashion.
750-
llvm::for_each(singleTiledOpInLoop, [&rewriter, &sliceOptions,
751-
&changed](Operation *tiledOp) {
739+
llvm::for_each(selfTiledOp, [&rewriter, &sliceOptions,
740+
&changed](Operation *tiledOp) {
752741
changed |= succeeded(iterativelyFuseProducerAndConsumerOfTiledOp(
753742
rewriter, tiledOp, sliceOptions));
754743
});
@@ -774,7 +763,7 @@ void iterativeTilingAndFusionUntilExhaustion(
774763
std::cref(tsMap)));
775764
})) {
776765
// If no op can be tiled
777-
break;
766+
return;
778767
}
779768
}
780769
}

0 commit comments

Comments
 (0)