1
- // ===-- AnyTilableFusion.cpp - Fusion For Any Tilable Op --------*- C++ -*-===//
1
+ // ===-- FineGrainedFusion.cpp - Fusion For Any Tilable Op -------*- C++ -*-===//
2
3
//
3
4
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4
5
// See https://llvm.org/LICENSE.txt for license information.
7
8
// ===----------------------------------------------------------------------===//
8
9
9
10
#include " gc/Transforms/Passes.h"
11
+ #include " mlir/Analysis/TopologicalSortUtils.h"
10
12
#include " mlir/Dialect/DLTI/Traits.h"
11
13
#include " mlir/Dialect/Func/IR/FuncOps.h"
12
14
#include " mlir/Dialect/Linalg/IR/Linalg.h"
34
36
35
37
namespace mlir {
36
38
namespace gc {
37
- #define GEN_PASS_DEF_ANYTILABLEFUSION
39
+ #define GEN_PASS_DEF_FINEGRAINEDFUSION
38
40
#include " gc/Transforms/Passes.h.inc"
39
41
40
42
static FailureOr<tensor::ExtractSliceOp>
@@ -266,14 +268,14 @@ template <typename T1, typename T2> struct CandidateSliceProcessPipeLine {
266
268
: CandidateSliceProcessPipeLine() {
267
269
append (newFn);
268
270
}
269
- CandidateSliceProcessPipeLine (const SmallVector <T1> & newFns)
271
// Builds a pipeline by delegating to the default constructor and then
// appending every callback in `newFns` via the range overload of append().
+ CandidateSliceProcessPipeLine (ArrayRef <T1> newFns)
270
272
: CandidateSliceProcessPipeLine() {
271
273
append (newFns);
272
274
}
273
275
274
276
// Appends a single callback `newFn` to the end of `candidateProcessFn`.
void append (const T1 &newFn) { candidateProcessFn.push_back (newFn); }
275
- void append (const SmallVector <T1> & newFns) {
276
- candidateProcessFn. append ( newFns);
277
// Appends every callback in `newFns` to `candidateProcessFn`. Taking an
// ArrayRef lets callers pass any contiguous list of callbacks.
+ void append (ArrayRef <T1> newFns) {
278
+ llvm::append_range (candidateProcessFn, newFns);
277
279
}
278
280
279
281
// Default pipeline for this base template: an empty list of callbacks.
SmallVector<T1> getDefaultPipeLine () { return {}; }
@@ -282,6 +284,7 @@ template <typename T1, typename T2> struct CandidateSliceProcessPipeLine {
282
284
struct CandidateSliceFilterPipeLine
283
285
: public CandidateSliceProcessPipeLine<CandidateSliceFilter,
284
286
CandidateSliceFilterPipeLine> {
287
+ CandidateSliceFilterPipeLine () : CandidateSliceProcessPipeLine() {}
285
288
CandidateSliceFilterPipeLine (const CandidateSliceFilter &filter)
286
289
: CandidateSliceProcessPipeLine(filter) {}
287
290
CandidateSliceFilterPipeLine (const SmallVector<CandidateSliceFilter> &filters)
@@ -362,9 +365,31 @@ struct CandidateSliceComparerPipeLine
362
365
}
363
366
};
364
367
365
- std::optional<scf::SCFFuseProducerOfSliceResult> tileAndFuseProducerOfOpOperand (
366
- RewriterBase &rewriter, OpOperand &operand,
367
- const CandidateSliceFilterPipeLine &filterPipeLine) {
368
+ struct CandidateSliceOptions {
369
+ // Use for validity
370
+ CandidateSliceFilterPipeLine filterPipeLine;
371
+ // Use for performance
372
+ CandidateSliceComparerPipeLine comparerPipeLine;
373
+
374
+ CandidateSliceOptions () = default ;
375
+
376
+ void addFilter (const CandidateSliceFilter &filter) {
377
+ filterPipeLine.append (filter);
378
+ }
379
+ void addFilter (ArrayRef<CandidateSliceFilter> filters) {
380
+ filterPipeLine.append (filters);
381
+ }
382
+ void addComparer (const CandidateSliceComparer &comparer) {
383
+ comparerPipeLine.append (comparer);
384
+ }
385
+ void addFilter (ArrayRef<CandidateSliceComparer> comparers) {
386
+ comparerPipeLine.append (comparers);
387
+ }
388
+ };
389
+
390
+ std::optional<scf::SCFFuseProducerOfSliceResult>
391
+ tileAndFuseProducerOfOpOperand (RewriterBase &rewriter, OpOperand &operand,
392
+ const CandidateSliceOptions &options) {
368
393
// a. Find the closest sliceOp
369
394
FailureOr<tensor::ExtractSliceOp> closestSliceOp =
370
395
getClosestExtractSliceOfOperand (operand);
@@ -388,22 +413,21 @@ std::optional<scf::SCFFuseProducerOfSliceResult> tileAndFuseProducerOfOpOperand(
388
413
// d. Filter out invalid candidates
389
414
SmallVector<tensor::ExtractSliceOp> validCandidates =
390
415
llvm::to_vector (llvm::make_filter_range (
391
- backwardSlice, [&rewriter, &filterPipeLine,
392
- &defOrUse](tensor::ExtractSliceOp &candidate) {
393
- return succeeded (filterPipeLine.filter (
416
+ backwardSlice,
417
+ [&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidate) {
418
+ return succeeded (options. filterPipeLine .filter (
394
419
rewriter,
395
420
cast<OffsetSizeAndStrideOpInterface>(candidate.getOperation ()),
396
421
defOrUse));
397
422
}));
398
423
if (validCandidates.empty ())
399
424
return std::nullopt;
400
425
// e. Select best candidates by Cost Model
401
- CandidateSliceComparerPipeLine comparePipeLine;
402
426
tensor::ExtractSliceOp bestCandidate = *llvm::min_element (
403
- validCandidates, [&rewriter, &comparePipeLine,
404
- &defOrUse](tensor::ExtractSliceOp &candidateA,
405
- tensor::ExtractSliceOp &candidateB) {
406
- return comparePipeLine .compare (
427
+ validCandidates,
428
+ [&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidateA,
429
+ tensor::ExtractSliceOp &candidateB) {
430
+ return options. comparerPipeLine .compare (
407
431
rewriter,
408
432
cast<OffsetSizeAndStrideOpInterface>(candidateA.getOperation ()),
409
433
cast<OffsetSizeAndStrideOpInterface>(candidateB.getOperation ()),
@@ -414,9 +438,8 @@ std::optional<scf::SCFFuseProducerOfSliceResult> tileAndFuseProducerOfOpOperand(
414
438
}
415
439
416
440
std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
417
- tileAndFuseConsumerOfOpResult (
418
- RewriterBase &rewriter, OpResult result,
419
- const CandidateSliceFilterPipeLine &filterPipeLine) {
441
+ tileAndFuseConsumerOfOpResult (RewriterBase &rewriter, OpResult result,
442
+ const CandidateSliceOptions &options) {
420
443
// a. Find the closest sliceOp
421
444
FailureOr<tensor::ExtractSliceOp> closestSliceOp =
422
445
getClosestInsertSliceOfResult (result);
@@ -443,22 +466,21 @@ tileAndFuseConsumerOfOpResult(
443
466
// d. Filter out invalid candidates
444
467
SmallVector<OffsetSizeAndStrideOpInterface> validCandidates =
445
468
llvm::to_vector (llvm::make_filter_range (
446
- forwardSlice, [&rewriter, &filterPipeLine , &defOrUse](
469
+ forwardSlice, [&rewriter, &options , &defOrUse](
447
470
const OffsetSizeAndStrideOpInterface &candidate) {
448
471
return succeeded (
449
- filterPipeLine.filter (rewriter, candidate, defOrUse));
472
+ options. filterPipeLine .filter (rewriter, candidate, defOrUse));
450
473
}));
451
474
if (validCandidates.empty ())
452
475
continue ;
453
476
454
477
// e. Select best candidates by Cost Model
455
- CandidateSliceComparerPipeLine comparePipeLine;
456
478
OffsetSizeAndStrideOpInterface bestCandidate = *llvm::min_element (
457
- validCandidates, [&rewriter, &comparePipeLine , &defOrUse](
479
+ validCandidates, [&rewriter, &options , &defOrUse](
458
480
const OffsetSizeAndStrideOpInterface &candidateA,
459
481
const OffsetSizeAndStrideOpInterface &candidateB) {
460
- return comparePipeLine. compare (rewriter, candidateA, candidateB ,
461
- defOrUse);
482
+ return options. comparerPipeLine . compare (rewriter, candidateA,
483
+ candidateB, defOrUse);
462
484
});
463
485
// f. call tiling interface
464
486
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
@@ -496,49 +518,52 @@ tileAndFuseConsumerOfOpResult(
496
518
*
497
519
* producer1 producer2
498
520
* \ /
499
- * tiledOp
521
+ * Op
500
522
* / \
501
523
* consumer1 consumer2
502
524
*
503
525
* where:
504
526
*
505
- * 1. tiled op is responsible for providing scheduled parallel loops and
506
- * several candidate sliceOp including both Producer and Consumer.
507
- * 2. support both pre-op and post-op fusion: try to fuse all of producers and
508
- * consumers of tiled op.
509
- * 3. recursively call forward and backward Fusion on either fused producer or
510
- * consumer op based on BFS.
527
+ * Support iterative producer and consumer fusion in BFS fashion.
511
528
*/
512
- void IterativelyFuseProducerAndConsumerOfTiledOp (
529
+ void iterativelyFuseProducerAndConsumerOfTiledOp (
513
530
RewriterBase &rewriter, Operation *tiledOp,
514
531
TargetSystemSpecInterface targetSpec) {
515
-
516
- // User-defined filter to control whether to fuse or not. If more than one
517
- // filters need given, please use filter list instead.
518
- // E.g.
519
- // SmallVector<CandidateSliceFilter> customizedFilterList
520
- // = {customizedFilter1, customizedFilter2, customizedFilter3, ...} ;
532
+ // Flexible options to control which candidate slice would be selected from
533
+ // the view of both validity and performance.
534
+ CandidateSliceOptions options;
535
+ // User-defined filter to control whether to fuse or not. For instance, the
536
+ // maximum number of fused ops is limited to 20 (for illustration only).
537
+ int64_t numTiledOps = 0 ;
521
538
CandidateSliceFilter customizedFilter =
522
- [](RewriterBase &rewriter, OffsetSizeAndStrideOpInterface candidate,
523
- CandidateDefOrUse defOrUse) -> LogicalResult { return success (); };
539
+ [&numTiledOps](RewriterBase &rewriter,
540
+ OffsetSizeAndStrideOpInterface candidate,
541
+ CandidateDefOrUse defOrUse) -> LogicalResult {
542
+ return success (numTiledOps < 20 );
543
+ };
544
+ // If more than one filter is needed, please use a filter list instead. E.g.
545
+ //
546
+ // SmallVector<CandidateSliceFilter> customizedFilterList
547
+ // = {customizedFilter1, customizedFilter2, ...};
548
+ options.addFilter (customizedFilter);
524
549
525
550
std::deque<Operation *> tiledOpList = {tiledOp};
526
551
while (!tiledOpList.empty ()) {
527
552
tiledOp = tiledOpList.front ();
528
553
tiledOpList.pop_front ();
554
+ numTiledOps++;
529
555
// fuse producer
530
556
for (OpOperand &operand : tiledOp->getOpOperands ()) {
531
557
if (std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult =
532
- tileAndFuseProducerOfOpOperand (rewriter, operand,
533
- customizedFilter)) {
558
+ tileAndFuseProducerOfOpOperand (rewriter, operand, options)) {
534
559
tiledOpList.push_back (fuseProducerResult.value ().tiledOps [0 ]);
535
560
}
536
561
}
537
562
// fuse consumer(s)
538
563
for (OpResult result : tiledOp->getResults ()) {
539
564
if (std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
540
- fuseConsumerResults = tileAndFuseConsumerOfOpResult (
541
- rewriter, result, customizedFilter )) {
565
+ fuseConsumerResults =
566
+ tileAndFuseConsumerOfOpResult ( rewriter, result, options )) {
542
567
for (auto &fuseConsumerResult : *fuseConsumerResults) {
543
568
tiledOpList.push_back (fuseConsumerResult.tiledOps [0 ]);
544
569
}
@@ -548,10 +573,7 @@ void IterativelyFuseProducerAndConsumerOfTiledOp(
548
573
}
549
574
550
575
/* *
551
- * What is Tiled Op?
552
- * 1. located in a for loop
553
- * 2. it is the only one TilingInterface op in for loop
554
- * 3. has extract/insert slice
576
+ * What is a single tiled op in a loop?
555
577
*
556
578
* E.g.
557
579
* %1 = scf.for(){
@@ -564,7 +586,7 @@ void IterativelyFuseProducerAndConsumerOfTiledOp(
564
586
* }
565
587
*
566
588
* */
567
- static LogicalResult isTiledOp (Operation *targetOp) {
589
+ static LogicalResult isSingleTiledOpInLoop (Operation *targetOp) {
568
590
// 0. check tilable
569
591
if (!isa<TilingInterface>(targetOp)) {
570
592
return failure ();
@@ -595,37 +617,40 @@ static LogicalResult isTiledOp(Operation *targetOp) {
595
617
return success (walkResult.wasInterrupted ());
596
618
}
597
619
598
- static void FineGrainedFusion (RewriterBase &rewriter, func::FuncOp f,
599
- TargetSystemSpecInterface targetSpec) {
600
- SmallVector<Operation *> tiledOpList;
601
- // Walk through func operation.
602
- f->walk ([&tiledOpList](Operation *op) {
603
- // Target at tiled op, like matmul/conv
604
- if (succeeded (isTiledOp (op))) {
605
- tiledOpList.push_back (op);
606
- }
607
- });
608
- // Fuse all tilable ops around tiled op in forward and backward fashion.
609
- for (auto &tiledOp : tiledOpList) {
610
- IterativelyFuseProducerAndConsumerOfTiledOp (rewriter, tiledOp, targetSpec);
611
- }
612
- }
613
-
614
- struct AnyTilableFusion : public impl ::AnyTilableFusionBase<AnyTilableFusion> {
620
+ struct FineGrainedFusion
621
+ : public impl::FineGrainedFusionBase<FineGrainedFusion> {
615
622
616
623
public:
617
624
void runOnOperation () final {
618
625
auto &ctx = getContext ();
619
- // Get funcOp
620
- func::FuncOp func = getOperation ();
621
- // Get target descriptor
622
- TargetSystemSpecInterface targetSpec =
623
- mlir::impl::getTargetSystemSpec (func);
624
- // Get rewriter
625
- IRRewriter rewriter (&ctx);
626
- // Do fine-grained fusion
627
- FineGrainedFusion (rewriter, func, targetSpec);
628
- // Perhaps coarse-grained fusion here
626
+ {
627
+ // Get funcOp
628
+ func::FuncOp func = getOperation ();
629
+ // Get target descriptor
630
+ TargetSystemSpecInterface targetSpec =
631
+ mlir::impl::getTargetSystemSpec (func);
632
+ // Get rewriter
633
+ IRRewriter rewriter (&ctx);
634
+
635
+ // Collect tiled ops before fusion
636
+ llvm::SetVector<Operation *> tiledOps;
637
+ // Walk through funcOp
638
+ func->walk ([&tiledOps](Operation *op) {
639
+ // Target certain kinds of tiled ops, such as matmul/conv implemented
640
+ // by multiple levels of nested loops and candidate slices, for better
641
+ // utilization of parallelism and memory hierarchy.
642
+ if (succeeded (isSingleTiledOpInLoop (op))) {
643
+ tiledOps.insert (op);
644
+ }
645
+ });
646
+ // Sort by topology
647
+ mlir::topologicalSort (tiledOps);
648
+ // Iteratively fuse in forward and backward fashion.
649
+ for (auto &tiledOp : tiledOps) {
650
+ iterativelyFuseProducerAndConsumerOfTiledOp (rewriter, tiledOp,
651
+ targetSpec);
652
+ }
653
+ }
629
654
630
655
{
631
656
RewritePatternSet patternSet (&ctx);
0 commit comments