1
- // ===-- AnyTilableFusion .cpp - Fusion For Any Tilable Op --------*- C++ -*-===//
1
+ // ===-- FineGrainedFusion .cpp - Fine-Grained Fusion ----- --------*- C++ -*-===//
2
2
//
3
3
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " gc/Transforms/Passes.h"
10
+ #include " mlir/Analysis/TopologicalSortUtils.h"
10
11
#include " mlir/Dialect/DLTI/Traits.h"
11
12
#include " mlir/Dialect/Func/IR/FuncOps.h"
12
13
#include " mlir/Dialect/Linalg/IR/Linalg.h"
34
35
35
36
namespace mlir {
36
37
namespace gc {
37
- #define GEN_PASS_DEF_ANYTILABLEFUSION
38
+ #define GEN_PASS_DEF_FINEGRAINEDFUSION
38
39
#include " gc/Transforms/Passes.h.inc"
39
40
40
41
static FailureOr<tensor::ExtractSliceOp>
@@ -266,14 +267,14 @@ template <typename T1, typename T2> struct CandidateSliceProcessPipeLine {
266
267
: CandidateSliceProcessPipeLine() {
267
268
append (newFn);
268
269
}
269
- CandidateSliceProcessPipeLine (const SmallVector <T1> & newFns)
270
+ CandidateSliceProcessPipeLine (ArrayRef <T1> newFns)
270
271
: CandidateSliceProcessPipeLine() {
271
272
append (newFns);
272
273
}
273
274
274
275
void append (const T1 &newFn) { candidateProcessFn.push_back (newFn); }
275
- void append (const SmallVector <T1> & newFns) {
276
- candidateProcessFn. append ( newFns);
276
+ void append (ArrayRef <T1> newFns) {
277
+ llvm::append_range (candidateProcessFn, newFns);
277
278
}
278
279
279
280
SmallVector<T1> getDefaultPipeLine () { return {}; }
@@ -282,6 +283,7 @@ template <typename T1, typename T2> struct CandidateSliceProcessPipeLine {
282
283
struct CandidateSliceFilterPipeLine
283
284
: public CandidateSliceProcessPipeLine<CandidateSliceFilter,
284
285
CandidateSliceFilterPipeLine> {
286
+ CandidateSliceFilterPipeLine () : CandidateSliceProcessPipeLine() {}
285
287
CandidateSliceFilterPipeLine (const CandidateSliceFilter &filter)
286
288
: CandidateSliceProcessPipeLine(filter) {}
287
289
CandidateSliceFilterPipeLine (const SmallVector<CandidateSliceFilter> &filters)
@@ -362,9 +364,31 @@ struct CandidateSliceComparerPipeLine
362
364
}
363
365
};
364
366
365
- std::optional<scf::SCFFuseProducerOfSliceResult> tileAndFuseProducerOfOpOperand (
366
- RewriterBase &rewriter, OpOperand &operand,
367
- const CandidateSliceFilterPipeLine &filterPipeLine) {
367
+ struct CandidateSliceOptions {
368
+ // Use for validity
369
+ CandidateSliceFilterPipeLine filterPipeLine;
370
+ // Use for performance
371
+ CandidateSliceComparerPipeLine comparerPipeLine;
372
+
373
+ CandidateSliceOptions () = default ;
374
+
375
+ void addFilter (const CandidateSliceFilter &filter) {
376
+ filterPipeLine.append (filter);
377
+ }
378
+ void addFilter (ArrayRef<CandidateSliceFilter> filters) {
379
+ filterPipeLine.append (filters);
380
+ }
381
+ void addComparer (const CandidateSliceComparer &comparer) {
382
+ comparerPipeLine.append (comparer);
383
+ }
384
+ void addFilter (ArrayRef<CandidateSliceComparer> comparers) {
385
+ comparerPipeLine.append (comparers);
386
+ }
387
+ };
388
+
389
+ std::optional<scf::SCFFuseProducerOfSliceResult>
390
+ tileAndFuseProducerOfOpOperand (RewriterBase &rewriter, OpOperand &operand,
391
+ const CandidateSliceOptions &options) {
368
392
// a. Find the closest sliceOp
369
393
FailureOr<tensor::ExtractSliceOp> closestSliceOp =
370
394
getClosestExtractSliceOfOperand (operand);
@@ -388,22 +412,21 @@ std::optional<scf::SCFFuseProducerOfSliceResult> tileAndFuseProducerOfOpOperand(
388
412
// d. Filter out invalid candidates
389
413
SmallVector<tensor::ExtractSliceOp> validCandidates =
390
414
llvm::to_vector (llvm::make_filter_range (
391
- backwardSlice, [&rewriter, &filterPipeLine,
392
- &defOrUse](tensor::ExtractSliceOp &candidate) {
393
- return succeeded (filterPipeLine.filter (
415
+ backwardSlice,
416
+ [&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidate) {
417
+ return succeeded (options. filterPipeLine .filter (
394
418
rewriter,
395
419
cast<OffsetSizeAndStrideOpInterface>(candidate.getOperation ()),
396
420
defOrUse));
397
421
}));
398
422
if (validCandidates.empty ())
399
423
return std::nullopt;
400
424
// e. Select best candidates by Cost Model
401
- CandidateSliceComparerPipeLine comparePipeLine;
402
425
tensor::ExtractSliceOp bestCandidate = *llvm::min_element (
403
- validCandidates, [&rewriter, &comparePipeLine,
404
- &defOrUse](tensor::ExtractSliceOp &candidateA,
405
- tensor::ExtractSliceOp &candidateB) {
406
- return comparePipeLine .compare (
426
+ validCandidates,
427
+ [&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidateA,
428
+ tensor::ExtractSliceOp &candidateB) {
429
+ return options. comparerPipeLine .compare (
407
430
rewriter,
408
431
cast<OffsetSizeAndStrideOpInterface>(candidateA.getOperation ()),
409
432
cast<OffsetSizeAndStrideOpInterface>(candidateB.getOperation ()),
@@ -414,9 +437,8 @@ std::optional<scf::SCFFuseProducerOfSliceResult> tileAndFuseProducerOfOpOperand(
414
437
}
415
438
416
439
std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
417
- tileAndFuseConsumerOfOpResult (
418
- RewriterBase &rewriter, OpResult result,
419
- const CandidateSliceFilterPipeLine &filterPipeLine) {
440
+ tileAndFuseConsumerOfOpResult (RewriterBase &rewriter, OpResult result,
441
+ const CandidateSliceOptions &options) {
420
442
// a. Find the closest sliceOp
421
443
FailureOr<tensor::ExtractSliceOp> closestSliceOp =
422
444
getClosestInsertSliceOfResult (result);
@@ -443,22 +465,21 @@ tileAndFuseConsumerOfOpResult(
443
465
// d. Filter out invalid candidates
444
466
SmallVector<OffsetSizeAndStrideOpInterface> validCandidates =
445
467
llvm::to_vector (llvm::make_filter_range (
446
- forwardSlice, [&rewriter, &filterPipeLine , &defOrUse](
468
+ forwardSlice, [&rewriter, &options , &defOrUse](
447
469
const OffsetSizeAndStrideOpInterface &candidate) {
448
470
return succeeded (
449
- filterPipeLine.filter (rewriter, candidate, defOrUse));
471
+ options. filterPipeLine .filter (rewriter, candidate, defOrUse));
450
472
}));
451
473
if (validCandidates.empty ())
452
474
continue ;
453
475
454
476
// e. Select best candidates by Cost Model
455
- CandidateSliceComparerPipeLine comparePipeLine;
456
477
OffsetSizeAndStrideOpInterface bestCandidate = *llvm::min_element (
457
- validCandidates, [&rewriter, &comparePipeLine , &defOrUse](
478
+ validCandidates, [&rewriter, &options , &defOrUse](
458
479
const OffsetSizeAndStrideOpInterface &candidateA,
459
480
const OffsetSizeAndStrideOpInterface &candidateB) {
460
- return comparePipeLine. compare (rewriter, candidateA, candidateB ,
461
- defOrUse);
481
+ return options. comparerPipeLine . compare (rewriter, candidateA,
482
+ candidateB, defOrUse);
462
483
});
463
484
// f. call tiling interface
464
485
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
@@ -496,49 +517,52 @@ tileAndFuseConsumerOfOpResult(
496
517
*
497
518
* producer1 producer2
498
519
* \ /
499
- * tiledOp
520
+ * Op
500
521
* / \
501
522
* consumer1 consumer2
502
523
*
503
524
* where:
504
525
*
505
- * 1. tiled op is responsible for providing scheduled parallel loops and
506
- * several candidate sliceOp including both Producer and Consumer.
507
- * 2. support both pre-op and post-op fusion: try to fuse all of producers and
508
- * consumers of tiled op.
509
- * 3. recursively call forward and backward Fusion on either fused producer or
510
- * consumer op based on BFS.
526
+ * Support iterative producer and consumer fusion in BFS fashion.
511
527
*/
512
- void IterativelyFuseProducerAndConsumerOfTiledOp (
528
+ void iterativelyFuseProducerAndConsumerOfTiledOp (
513
529
RewriterBase &rewriter, Operation *tiledOp,
514
530
TargetSystemSpecInterface targetSpec) {
515
-
516
- // User-defined filter to control whether to fuse or not. If more than one
517
- // filters need given, please use filter list instead.
518
- // E.g.
519
- // SmallVector<CandidateSliceFilter> customizedFilterList
520
- // = {customizedFilter1, customizedFilter2, customizedFilter3, ...} ;
531
+ // Flexible options to control which candidate slice would be selected from
532
+ // the view of both validity and performance.
533
+ CandidateSliceOptions options;
534
+ // User-defined filter to control whether to fuse or not. For instance, the
535
+ // maximum number of fused ops is limited to 20 (for illustration only).
536
+ int64_t numTiledOps = 0 ;
521
537
CandidateSliceFilter customizedFilter =
522
- [](RewriterBase &rewriter, OffsetSizeAndStrideOpInterface candidate,
523
- CandidateDefOrUse defOrUse) -> LogicalResult { return success (); };
538
+ [&numTiledOps](RewriterBase &rewriter,
539
+ OffsetSizeAndStrideOpInterface candidate,
540
+ CandidateDefOrUse defOrUse) -> LogicalResult {
541
+ return success (numTiledOps < 20 );
542
+ };
543
+ // If more than one filter is needed, pass a list of filters instead. E.g.
544
+ //
545
+ // SmallVector<CandidateSliceFilter> customizedFilterList
546
+ // = {customizedFilter1, customizedFilter2, ...};
547
+ options.addFilter (customizedFilter);
524
548
525
549
std::deque<Operation *> tiledOpList = {tiledOp};
526
550
while (!tiledOpList.empty ()) {
527
551
tiledOp = tiledOpList.front ();
528
552
tiledOpList.pop_front ();
553
+ numTiledOps++;
529
554
// fuse producer
530
555
for (OpOperand &operand : tiledOp->getOpOperands ()) {
531
556
if (std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult =
532
- tileAndFuseProducerOfOpOperand (rewriter, operand,
533
- customizedFilter)) {
557
+ tileAndFuseProducerOfOpOperand (rewriter, operand, options)) {
534
558
tiledOpList.push_back (fuseProducerResult.value ().tiledOps [0 ]);
535
559
}
536
560
}
537
561
// fuse consumer(s)
538
562
for (OpResult result : tiledOp->getResults ()) {
539
563
if (std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
540
- fuseConsumerResults = tileAndFuseConsumerOfOpResult (
541
- rewriter, result, customizedFilter )) {
564
+ fuseConsumerResults =
565
+ tileAndFuseConsumerOfOpResult ( rewriter, result, options )) {
542
566
for (auto &fuseConsumerResult : *fuseConsumerResults) {
543
567
tiledOpList.push_back (fuseConsumerResult.tiledOps [0 ]);
544
568
}
@@ -548,10 +572,7 @@ void IterativelyFuseProducerAndConsumerOfTiledOp(
548
572
}
549
573
550
574
/* *
551
- * What is Tiled Op?
552
- * 1. located in a for loop
553
- * 2. it is the only one TilingInterface op in for loop
554
- * 3. has extract/insert slice
575
+ * What is a single tiled op in a loop?
555
576
*
556
577
* E.g.
557
578
* %1 = scf.for(){
@@ -564,7 +585,7 @@ void IterativelyFuseProducerAndConsumerOfTiledOp(
564
585
* }
565
586
*
566
587
* */
567
- static LogicalResult isTiledOp (Operation *targetOp) {
588
+ static LogicalResult isSingleTiledOpInLoop (Operation *targetOp) {
568
589
// 0. check tilable
569
590
if (!isa<TilingInterface>(targetOp)) {
570
591
return failure ();
@@ -595,37 +616,40 @@ static LogicalResult isTiledOp(Operation *targetOp) {
595
616
return success (walkResult.wasInterrupted ());
596
617
}
597
618
598
- static void FineGrainedFusion (RewriterBase &rewriter, func::FuncOp f,
599
- TargetSystemSpecInterface targetSpec) {
600
- SmallVector<Operation *> tiledOpList;
601
- // Walk through func operation.
602
- f->walk ([&tiledOpList](Operation *op) {
603
- // Target at tiled op, like matmul/conv
604
- if (succeeded (isTiledOp (op))) {
605
- tiledOpList.push_back (op);
606
- }
607
- });
608
- // Fuse all tilable ops around tiled op in forward and backward fashion.
609
- for (auto &tiledOp : tiledOpList) {
610
- IterativelyFuseProducerAndConsumerOfTiledOp (rewriter, tiledOp, targetSpec);
611
- }
612
- }
613
-
614
- struct AnyTilableFusion : public impl ::AnyTilableFusionBase<AnyTilableFusion> {
619
+ struct FineGrainedFusion
620
+ : public impl::FineGrainedFusionBase<FineGrainedFusion> {
615
621
616
622
public:
617
623
void runOnOperation () final {
618
624
auto &ctx = getContext ();
619
- // Get funcOp
620
- func::FuncOp func = getOperation ();
621
- // Get target descriptor
622
- TargetSystemSpecInterface targetSpec =
623
- mlir::impl::getTargetSystemSpec (func);
624
- // Get rewriter
625
- IRRewriter rewriter (&ctx);
626
- // Do fine-grained fusion
627
- FineGrainedFusion (rewriter, func, targetSpec);
628
- // Perhaps coarse-grained fusion here
625
+ {
626
+ // Get funcOp
627
+ func::FuncOp func = getOperation ();
628
+ // Get target descriptor
629
+ TargetSystemSpecInterface targetSpec =
630
+ mlir::impl::getTargetSystemSpec (func);
631
+ // Get rewriter
632
+ IRRewriter rewriter (&ctx);
633
+
634
+ // Collect tiled ops before fusion
635
+ llvm::SetVector<Operation *> tiledOps;
636
+ // Walk through funcOp
637
+ func->walk ([&tiledOps](Operation *op) {
638
+ // Target a certain kind of tiled op, such as matmul/conv implemented
639
+ // by multiple levels of nested loops and candidate slices for better
640
+ // utilization of parallelism and memory hierarchy.
641
+ if (succeeded (isSingleTiledOpInLoop (op))) {
642
+ tiledOps.insert (op);
643
+ }
644
+ });
645
+ // Sort by topology
646
+ mlir::topologicalSort (tiledOps);
647
+ // Iteratively fuse in forward and backward fashion.
648
+ for (auto &tiledOp : tiledOps) {
649
+ iterativelyFuseProducerAndConsumerOfTiledOp (rewriter, tiledOp,
650
+ targetSpec);
651
+ }
652
+ }
629
653
630
654
{
631
655
RewritePatternSet patternSet (&ctx);
0 commit comments