Skip to content

Commit 8d37c20

Browse files
committed
rename pass name and add to CPUPipeline
1 parent 0ce727e commit 8d37c20

File tree

5 files changed

+107
-83
lines changed

5 files changed

+107
-83
lines changed

include/gc/Transforms/Passes.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
5858
];
5959
}
6060

61-
def AnyTilableFusion : Pass<"any-tilable-fusion",
61+
def FineGrainedFusion : Pass<"fine-grained-fusion",
6262
"func::FuncOp"> {
63-
let summary = "Fusion for any tilable operation";
63+
let summary = "Fine Grained Fusion for any tilable operation";
6464
let description = [{
6565
The pass tries to fuse any MLIR operation which can be tiled. Moreover, this pass aims to support:
6666
1. Matmul fusion with element-wise/reduce/broadcast ops.
@@ -74,7 +74,6 @@ def AnyTilableFusion : Pass<"any-tilable-fusion",
7474
* `0`: disable any fusion.
7575
* `1`:[Default] enable pre-op fusion + post-op fusion covering any tilable operation including tensor.pack/tensor.fill/linalg.reduce etc but excluding branches forked by multiple uses.
7676
* `2`: `LEVEL 1` + extend to any topology including branches.
77-
* `3`: `LEVEL 2` + support coarse-grained fusion.
7877
}];
7978
let dependentDialects = ["func::FuncDialect", "linalg::LinalgDialect", "scf::SCFDialect",
8079
"tensor::TensorDialect"];

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ add_mlir_library(GCPasses
1313
OneDNNGraphToLinalg.cpp
1414
Pipeline.cpp
1515
TileNamed.cpp
16-
AnyTilableFusion.cpp
16+
FineGrainedFusion.cpp
1717
TilingUsingInterfaceX.cpp
1818

1919
ADDITIONAL_HEADER_DIRS

lib/gc/Transforms/AnyTilableFusion.cpp renamed to lib/gc/Transforms/FineGrainedFusion.cpp

Lines changed: 101 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===-- AnyTilableFusion.cpp - Fusion For Any Tilable Op --------*- C++ -*-===//
1+
//===-- FineGrainedFusion.cpp - Fine-Grained Fusion -------------*- C++ -*-===//
22
//
33
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "gc/Transforms/Passes.h"
10+
#include "mlir/Analysis/TopologicalSortUtils.h"
1011
#include "mlir/Dialect/DLTI/Traits.h"
1112
#include "mlir/Dialect/Func/IR/FuncOps.h"
1213
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -34,7 +35,7 @@
3435

3536
namespace mlir {
3637
namespace gc {
37-
#define GEN_PASS_DEF_ANYTILABLEFUSION
38+
#define GEN_PASS_DEF_FINEGRAINEDFUSION
3839
#include "gc/Transforms/Passes.h.inc"
3940

4041
static FailureOr<tensor::ExtractSliceOp>
@@ -266,14 +267,14 @@ template <typename T1, typename T2> struct CandidateSliceProcessPipeLine {
266267
: CandidateSliceProcessPipeLine() {
267268
append(newFn);
268269
}
269-
CandidateSliceProcessPipeLine(const SmallVector<T1> &newFns)
270+
CandidateSliceProcessPipeLine(ArrayRef<T1> newFns)
270271
: CandidateSliceProcessPipeLine() {
271272
append(newFns);
272273
}
273274

274275
void append(const T1 &newFn) { candidateProcessFn.push_back(newFn); }
275-
void append(const SmallVector<T1> &newFns) {
276-
candidateProcessFn.append(newFns);
276+
void append(ArrayRef<T1> newFns) {
277+
llvm::append_range(candidateProcessFn, newFns);
277278
}
278279

279280
SmallVector<T1> getDefaultPipeLine() { return {}; }
@@ -282,6 +283,7 @@ template <typename T1, typename T2> struct CandidateSliceProcessPipeLine {
282283
struct CandidateSliceFilterPipeLine
283284
: public CandidateSliceProcessPipeLine<CandidateSliceFilter,
284285
CandidateSliceFilterPipeLine> {
286+
CandidateSliceFilterPipeLine() : CandidateSliceProcessPipeLine() {}
285287
CandidateSliceFilterPipeLine(const CandidateSliceFilter &filter)
286288
: CandidateSliceProcessPipeLine(filter) {}
287289
CandidateSliceFilterPipeLine(const SmallVector<CandidateSliceFilter> &filters)
@@ -362,9 +364,31 @@ struct CandidateSliceComparerPipeLine
362364
}
363365
};
364366

365-
std::optional<scf::SCFFuseProducerOfSliceResult> tileAndFuseProducerOfOpOperand(
366-
RewriterBase &rewriter, OpOperand &operand,
367-
const CandidateSliceFilterPipeLine &filterPipeLine) {
367+
struct CandidateSliceOptions {
368+
// Use for validity
369+
CandidateSliceFilterPipeLine filterPipeLine;
370+
// Use for performance
371+
CandidateSliceComparerPipeLine comparerPipeLine;
372+
373+
CandidateSliceOptions() = default;
374+
375+
void addFilter(const CandidateSliceFilter &filter) {
376+
filterPipeLine.append(filter);
377+
}
378+
void addFilter(ArrayRef<CandidateSliceFilter> filters) {
379+
filterPipeLine.append(filters);
380+
}
381+
void addComparer(const CandidateSliceComparer &comparer) {
382+
comparerPipeLine.append(comparer);
383+
}
384+
void addFilter(ArrayRef<CandidateSliceComparer> comparers) {
385+
comparerPipeLine.append(comparers);
386+
}
387+
};
388+
389+
std::optional<scf::SCFFuseProducerOfSliceResult>
390+
tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
391+
const CandidateSliceOptions &options) {
368392
// a. Find the closest sliceOp
369393
FailureOr<tensor::ExtractSliceOp> closestSliceOp =
370394
getClosestExtractSliceOfOperand(operand);
@@ -388,22 +412,21 @@ std::optional<scf::SCFFuseProducerOfSliceResult> tileAndFuseProducerOfOpOperand(
388412
// d. Filter out invalid candidates
389413
SmallVector<tensor::ExtractSliceOp> validCandidates =
390414
llvm::to_vector(llvm::make_filter_range(
391-
backwardSlice, [&rewriter, &filterPipeLine,
392-
&defOrUse](tensor::ExtractSliceOp &candidate) {
393-
return succeeded(filterPipeLine.filter(
415+
backwardSlice,
416+
[&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidate) {
417+
return succeeded(options.filterPipeLine.filter(
394418
rewriter,
395419
cast<OffsetSizeAndStrideOpInterface>(candidate.getOperation()),
396420
defOrUse));
397421
}));
398422
if (validCandidates.empty())
399423
return std::nullopt;
400424
// e. Select best candidates by Cost Model
401-
CandidateSliceComparerPipeLine comparePipeLine;
402425
tensor::ExtractSliceOp bestCandidate = *llvm::min_element(
403-
validCandidates, [&rewriter, &comparePipeLine,
404-
&defOrUse](tensor::ExtractSliceOp &candidateA,
405-
tensor::ExtractSliceOp &candidateB) {
406-
return comparePipeLine.compare(
426+
validCandidates,
427+
[&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidateA,
428+
tensor::ExtractSliceOp &candidateB) {
429+
return options.comparerPipeLine.compare(
407430
rewriter,
408431
cast<OffsetSizeAndStrideOpInterface>(candidateA.getOperation()),
409432
cast<OffsetSizeAndStrideOpInterface>(candidateB.getOperation()),
@@ -414,9 +437,8 @@ std::optional<scf::SCFFuseProducerOfSliceResult> tileAndFuseProducerOfOpOperand(
414437
}
415438

416439
std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
417-
tileAndFuseConsumerOfOpResult(
418-
RewriterBase &rewriter, OpResult result,
419-
const CandidateSliceFilterPipeLine &filterPipeLine) {
440+
tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
441+
const CandidateSliceOptions &options) {
420442
// a. Find the closest sliceOp
421443
FailureOr<tensor::ExtractSliceOp> closestSliceOp =
422444
getClosestInsertSliceOfResult(result);
@@ -443,22 +465,21 @@ tileAndFuseConsumerOfOpResult(
443465
// d. Filter out invalid candidates
444466
SmallVector<OffsetSizeAndStrideOpInterface> validCandidates =
445467
llvm::to_vector(llvm::make_filter_range(
446-
forwardSlice, [&rewriter, &filterPipeLine, &defOrUse](
468+
forwardSlice, [&rewriter, &options, &defOrUse](
447469
const OffsetSizeAndStrideOpInterface &candidate) {
448470
return succeeded(
449-
filterPipeLine.filter(rewriter, candidate, defOrUse));
471+
options.filterPipeLine.filter(rewriter, candidate, defOrUse));
450472
}));
451473
if (validCandidates.empty())
452474
continue;
453475

454476
// e. Select best candidates by Cost Model
455-
CandidateSliceComparerPipeLine comparePipeLine;
456477
OffsetSizeAndStrideOpInterface bestCandidate = *llvm::min_element(
457-
validCandidates, [&rewriter, &comparePipeLine, &defOrUse](
478+
validCandidates, [&rewriter, &options, &defOrUse](
458479
const OffsetSizeAndStrideOpInterface &candidateA,
459480
const OffsetSizeAndStrideOpInterface &candidateB) {
460-
return comparePipeLine.compare(rewriter, candidateA, candidateB,
461-
defOrUse);
481+
return options.comparerPipeLine.compare(rewriter, candidateA,
482+
candidateB, defOrUse);
462483
});
463484
// f. call tiling interface
464485
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
@@ -496,49 +517,52 @@ tileAndFuseConsumerOfOpResult(
496517
*
497518
* producer1 producer2
498519
* \ /
499-
* tiledOp
520+
* Op
500521
* / \
501522
* consumer1 consumer2
502523
*
503524
* where:
504525
*
505-
* 1. tiled op is responsible for providing scheduled parallel loops and
506-
* several candidate sliceOp including both Producer and Consumer.
507-
* 2. support both pre-op and post-op fusion: try to fuse all of producers and
508-
* consumers of tiled op.
509-
* 3. recursively call forward and backward Fusion on either fused producer or
510-
* consumer op based on BFS.
526+
* Support iterative producer and consumer fusion in BFS fashion.
511527
*/
512-
void IterativelyFuseProducerAndConsumerOfTiledOp(
528+
void iterativelyFuseProducerAndConsumerOfTiledOp(
513529
RewriterBase &rewriter, Operation *tiledOp,
514530
TargetSystemSpecInterface targetSpec) {
515-
516-
// User-defined filter to control whether to fuse or not. If more than one
517-
// filters need given, please use filter list instead.
518-
// E.g.
519-
// SmallVector<CandidateSliceFilter> customizedFilterList
520-
// = {customizedFilter1, customizedFilter2, customizedFilter3, ...};
531+
// Flexible options to control which candidate slice would be selected from
532+
// the view of both validity and performance.
533+
CandidateSliceOptions options;
534+
// User-defined filter to control whether to fuse or not. For instance, the
535+
// maximum amount of fused ops is limited to 20(only used for example).
536+
int64_t numTiledOps = 0;
521537
CandidateSliceFilter customizedFilter =
522-
[](RewriterBase &rewriter, OffsetSizeAndStrideOpInterface candidate,
523-
CandidateDefOrUse defOrUse) -> LogicalResult { return success(); };
538+
[&numTiledOps](RewriterBase &rewriter,
539+
OffsetSizeAndStrideOpInterface candidate,
540+
CandidateDefOrUse defOrUse) -> LogicalResult {
541+
return success(numTiledOps < 20);
542+
};
543+
// If more than one filters need given, please use filter list instead. E.g.
544+
//
545+
// SmallVector<CandidateSliceFilter> customizedFilterList
546+
// = {customizedFilter1, customizedFilter2, ...};
547+
options.addFilter(customizedFilter);
524548

525549
std::deque<Operation *> tiledOpList = {tiledOp};
526550
while (!tiledOpList.empty()) {
527551
tiledOp = tiledOpList.front();
528552
tiledOpList.pop_front();
553+
numTiledOps++;
529554
// fuse producer
530555
for (OpOperand &operand : tiledOp->getOpOperands()) {
531556
if (std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult =
532-
tileAndFuseProducerOfOpOperand(rewriter, operand,
533-
customizedFilter)) {
557+
tileAndFuseProducerOfOpOperand(rewriter, operand, options)) {
534558
tiledOpList.push_back(fuseProducerResult.value().tiledOps[0]);
535559
}
536560
}
537561
// fuse consumer(s)
538562
for (OpResult result : tiledOp->getResults()) {
539563
if (std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
540-
fuseConsumerResults = tileAndFuseConsumerOfOpResult(
541-
rewriter, result, customizedFilter)) {
564+
fuseConsumerResults =
565+
tileAndFuseConsumerOfOpResult(rewriter, result, options)) {
542566
for (auto &fuseConsumerResult : *fuseConsumerResults) {
543567
tiledOpList.push_back(fuseConsumerResult.tiledOps[0]);
544568
}
@@ -548,10 +572,7 @@ void IterativelyFuseProducerAndConsumerOfTiledOp(
548572
}
549573

550574
/**
551-
* What is Tiled Op?
552-
* 1. located in a for loop
553-
* 2. it is the only one TilingInterface op in for loop
554-
* 3. has extract/insert slice
575+
* What is single tiled op in loop?
555576
*
556577
* E.g.
557578
* %1 = scf.for(){
@@ -564,7 +585,7 @@ void IterativelyFuseProducerAndConsumerOfTiledOp(
564585
* }
565586
*
566587
* */
567-
static LogicalResult isTiledOp(Operation *targetOp) {
588+
static LogicalResult isSingleTiledOpInLoop(Operation *targetOp) {
568589
// 0. check tilable
569590
if (!isa<TilingInterface>(targetOp)) {
570591
return failure();
@@ -595,37 +616,40 @@ static LogicalResult isTiledOp(Operation *targetOp) {
595616
return success(walkResult.wasInterrupted());
596617
}
597618

598-
static void FineGrainedFusion(RewriterBase &rewriter, func::FuncOp f,
599-
TargetSystemSpecInterface targetSpec) {
600-
SmallVector<Operation *> tiledOpList;
601-
// Walk through func operation.
602-
f->walk([&tiledOpList](Operation *op) {
603-
// Target at tiled op, like matmul/conv
604-
if (succeeded(isTiledOp(op))) {
605-
tiledOpList.push_back(op);
606-
}
607-
});
608-
// Fuse all tilable ops around tiled op in forward and backward fashion.
609-
for (auto &tiledOp : tiledOpList) {
610-
IterativelyFuseProducerAndConsumerOfTiledOp(rewriter, tiledOp, targetSpec);
611-
}
612-
}
613-
614-
struct AnyTilableFusion : public impl::AnyTilableFusionBase<AnyTilableFusion> {
619+
struct FineGrainedFusion
620+
: public impl::FineGrainedFusionBase<FineGrainedFusion> {
615621

616622
public:
617623
void runOnOperation() final {
618624
auto &ctx = getContext();
619-
// Get funcOp
620-
func::FuncOp func = getOperation();
621-
// Get target descriptor
622-
TargetSystemSpecInterface targetSpec =
623-
mlir::impl::getTargetSystemSpec(func);
624-
// Get rewriter
625-
IRRewriter rewriter(&ctx);
626-
// Do fine-grained fusion
627-
FineGrainedFusion(rewriter, func, targetSpec);
628-
// Perhaps coarse-grained fusion here
625+
{
626+
// Get funcOp
627+
func::FuncOp func = getOperation();
628+
// Get target descriptor
629+
TargetSystemSpecInterface targetSpec =
630+
mlir::impl::getTargetSystemSpec(func);
631+
// Get rewriter
632+
IRRewriter rewriter(&ctx);
633+
634+
// Collect tiled ops before fusion
635+
llvm::SetVector<Operation *> tiledOps;
636+
// Walk through funcOp
637+
func->walk([&tiledOps](Operation *op) {
638+
// Target at certain kind of tiled op, such as matmul/conv implemented
639+
// by multiple level of nest loops and candidate slices for better
640+
// utilization of parallelism and memory hierarchy.
641+
if (succeeded(isSingleTiledOpInLoop(op))) {
642+
tiledOps.insert(op);
643+
}
644+
});
645+
// Sort by topology
646+
mlir::topologicalSort(tiledOps);
647+
// Iteratively fuse in forward and backward fashion.
648+
for (auto &tiledOp : tiledOps) {
649+
iterativelyFuseProducerAndConsumerOfTiledOp(rewriter, tiledOp,
650+
targetSpec);
651+
}
652+
}
629653

630654
{
631655
RewritePatternSet patternSet(&ctx);

lib/gc/Transforms/Pipeline.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) {
4343
// todo: layout propagation pass
4444
// todo: tensor constant propagation pass
4545
// todo: linalg.matmul lowering to (scf.loop + linalg.brgemm) pass
46-
// todo: fine-grain fusion pass
46+
// Fine-grain fusion pass
47+
pm.addNestedPass<func::FuncOp>(createFineGrainedFusion());
4748
// todo: lower linalg to arith/math on virtual vector pass
4849

4950
// REMOVE this pass after the above passes are added. Currently we add this

test/gc/Transform/any-tilable-fusion.mlir renamed to test/gc/Transform/fine-grained-fusion.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: gc-opt --split-input-file -any-tilable-fusion %s
1+
// RUN: gc-opt --split-input-file -fine-grained-fusion %s
22

33
module {
44
func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> {

0 commit comments

Comments
 (0)