Skip to content

Commit 42c24b0

Browse files
committed
rename pass name and add to CPUPipeline
1 parent 0ce727e commit 42c24b0

File tree

5 files changed

+108
-83
lines changed

5 files changed

+108
-83
lines changed

include/gc/Transforms/Passes.td

+2-3
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
5858
];
5959
}
6060

61-
def AnyTilableFusion : Pass<"any-tilable-fusion",
61+
def FineGrainedFusion : Pass<"fine-grained-fusion",
6262
"func::FuncOp"> {
63-
let summary = "Fusion for any tilable operation";
63+
let summary = "Fine Grained Fusion for any tilable operation";
6464
let description = [{
6565
The pass tries to fuse any MLIR operation which can be tiled. Moreover, this pass aims to support:
6666
1. Matmul fusion with element-wise/reduce/broadcast ops.
@@ -74,7 +74,6 @@ def AnyTilableFusion : Pass<"any-tilable-fusion",
7474
* `0`: disable any fusion.
7575
* `1`:[Default] enable pre-op fusion + post-op fusion covering any tilable operation including tensor.pack/tensor.fill/linalg.reduce etc but excluding branches forked by multiple uses.
7676
* `2`: `LEVEL 1` + extend to any topology including branches.
77-
* `3`: `LEVEL 2` + support coarse-grained fusion.
7877
}];
7978
let dependentDialects = ["func::FuncDialect", "linalg::LinalgDialect", "scf::SCFDialect",
8079
"tensor::TensorDialect"];

lib/gc/Transforms/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ add_mlir_library(GCPasses
1313
OneDNNGraphToLinalg.cpp
1414
Pipeline.cpp
1515
TileNamed.cpp
16-
AnyTilableFusion.cpp
16+
FineGrainedFusion.cpp
1717
TilingUsingInterfaceX.cpp
1818

1919
ADDITIONAL_HEADER_DIRS

lib/gc/Transforms/AnyTilableFusion.cpp renamed to lib/gc/Transforms/FineGrainedFusion.cpp

+102-77
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
//===-- AnyTilableFusion.cpp - Fusion For Any Tilable Op --------*- C++ -*-===//
1+
//===-- FineGrainedFusion.cpp - Fusion For Any Tilable Op -------*- C++ -*-===//
23
//
34
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
45
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,6 +8,7 @@
78
//===----------------------------------------------------------------------===//
89

910
#include "gc/Transforms/Passes.h"
11+
#include "mlir/Analysis/TopologicalSortUtils.h"
1012
#include "mlir/Dialect/DLTI/Traits.h"
1113
#include "mlir/Dialect/Func/IR/FuncOps.h"
1214
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -34,7 +36,7 @@
3436

3537
namespace mlir {
3638
namespace gc {
37-
#define GEN_PASS_DEF_ANYTILABLEFUSION
39+
#define GEN_PASS_DEF_FINEGRAINEDFUSION
3840
#include "gc/Transforms/Passes.h.inc"
3941

4042
static FailureOr<tensor::ExtractSliceOp>
@@ -266,14 +268,14 @@ template <typename T1, typename T2> struct CandidateSliceProcessPipeLine {
266268
: CandidateSliceProcessPipeLine() {
267269
append(newFn);
268270
}
269-
CandidateSliceProcessPipeLine(const SmallVector<T1> &newFns)
271+
CandidateSliceProcessPipeLine(ArrayRef<T1> newFns)
270272
: CandidateSliceProcessPipeLine() {
271273
append(newFns);
272274
}
273275

274276
void append(const T1 &newFn) { candidateProcessFn.push_back(newFn); }
275-
void append(const SmallVector<T1> &newFns) {
276-
candidateProcessFn.append(newFns);
277+
void append(ArrayRef<T1> newFns) {
278+
llvm::append_range(candidateProcessFn, newFns);
277279
}
278280

279281
SmallVector<T1> getDefaultPipeLine() { return {}; }
@@ -282,6 +284,7 @@ template <typename T1, typename T2> struct CandidateSliceProcessPipeLine {
282284
struct CandidateSliceFilterPipeLine
283285
: public CandidateSliceProcessPipeLine<CandidateSliceFilter,
284286
CandidateSliceFilterPipeLine> {
287+
CandidateSliceFilterPipeLine() : CandidateSliceProcessPipeLine() {}
285288
CandidateSliceFilterPipeLine(const CandidateSliceFilter &filter)
286289
: CandidateSliceProcessPipeLine(filter) {}
287290
CandidateSliceFilterPipeLine(const SmallVector<CandidateSliceFilter> &filters)
@@ -362,9 +365,31 @@ struct CandidateSliceComparerPipeLine
362365
}
363366
};
364367

365-
std::optional<scf::SCFFuseProducerOfSliceResult> tileAndFuseProducerOfOpOperand(
366-
RewriterBase &rewriter, OpOperand &operand,
367-
const CandidateSliceFilterPipeLine &filterPipeLine) {
368+
struct CandidateSliceOptions {
369+
// Use for validity
370+
CandidateSliceFilterPipeLine filterPipeLine;
371+
// Use for performance
372+
CandidateSliceComparerPipeLine comparerPipeLine;
373+
374+
CandidateSliceOptions() = default;
375+
376+
void addFilter(const CandidateSliceFilter &filter) {
377+
filterPipeLine.append(filter);
378+
}
379+
void addFilter(ArrayRef<CandidateSliceFilter> filters) {
380+
filterPipeLine.append(filters);
381+
}
382+
void addComparer(const CandidateSliceComparer &comparer) {
383+
comparerPipeLine.append(comparer);
384+
}
385+
void addFilter(ArrayRef<CandidateSliceComparer> comparers) {
386+
comparerPipeLine.append(comparers);
387+
}
388+
};
389+
390+
std::optional<scf::SCFFuseProducerOfSliceResult>
391+
tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
392+
const CandidateSliceOptions &options) {
368393
// a. Find the closest sliceOp
369394
FailureOr<tensor::ExtractSliceOp> closestSliceOp =
370395
getClosestExtractSliceOfOperand(operand);
@@ -388,22 +413,21 @@ std::optional<scf::SCFFuseProducerOfSliceResult> tileAndFuseProducerOfOpOperand(
388413
// d. Filter out invalid candidates
389414
SmallVector<tensor::ExtractSliceOp> validCandidates =
390415
llvm::to_vector(llvm::make_filter_range(
391-
backwardSlice, [&rewriter, &filterPipeLine,
392-
&defOrUse](tensor::ExtractSliceOp &candidate) {
393-
return succeeded(filterPipeLine.filter(
416+
backwardSlice,
417+
[&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidate) {
418+
return succeeded(options.filterPipeLine.filter(
394419
rewriter,
395420
cast<OffsetSizeAndStrideOpInterface>(candidate.getOperation()),
396421
defOrUse));
397422
}));
398423
if (validCandidates.empty())
399424
return std::nullopt;
400425
// e. Select best candidates by Cost Model
401-
CandidateSliceComparerPipeLine comparePipeLine;
402426
tensor::ExtractSliceOp bestCandidate = *llvm::min_element(
403-
validCandidates, [&rewriter, &comparePipeLine,
404-
&defOrUse](tensor::ExtractSliceOp &candidateA,
405-
tensor::ExtractSliceOp &candidateB) {
406-
return comparePipeLine.compare(
427+
validCandidates,
428+
[&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidateA,
429+
tensor::ExtractSliceOp &candidateB) {
430+
return options.comparerPipeLine.compare(
407431
rewriter,
408432
cast<OffsetSizeAndStrideOpInterface>(candidateA.getOperation()),
409433
cast<OffsetSizeAndStrideOpInterface>(candidateB.getOperation()),
@@ -414,9 +438,8 @@ std::optional<scf::SCFFuseProducerOfSliceResult> tileAndFuseProducerOfOpOperand(
414438
}
415439

416440
std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
417-
tileAndFuseConsumerOfOpResult(
418-
RewriterBase &rewriter, OpResult result,
419-
const CandidateSliceFilterPipeLine &filterPipeLine) {
441+
tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
442+
const CandidateSliceOptions &options) {
420443
// a. Find the closest sliceOp
421444
FailureOr<tensor::ExtractSliceOp> closestSliceOp =
422445
getClosestInsertSliceOfResult(result);
@@ -443,22 +466,21 @@ tileAndFuseConsumerOfOpResult(
443466
// d. Filter out invalid candidates
444467
SmallVector<OffsetSizeAndStrideOpInterface> validCandidates =
445468
llvm::to_vector(llvm::make_filter_range(
446-
forwardSlice, [&rewriter, &filterPipeLine, &defOrUse](
469+
forwardSlice, [&rewriter, &options, &defOrUse](
447470
const OffsetSizeAndStrideOpInterface &candidate) {
448471
return succeeded(
449-
filterPipeLine.filter(rewriter, candidate, defOrUse));
472+
options.filterPipeLine.filter(rewriter, candidate, defOrUse));
450473
}));
451474
if (validCandidates.empty())
452475
continue;
453476

454477
// e. Select best candidates by Cost Model
455-
CandidateSliceComparerPipeLine comparePipeLine;
456478
OffsetSizeAndStrideOpInterface bestCandidate = *llvm::min_element(
457-
validCandidates, [&rewriter, &comparePipeLine, &defOrUse](
479+
validCandidates, [&rewriter, &options, &defOrUse](
458480
const OffsetSizeAndStrideOpInterface &candidateA,
459481
const OffsetSizeAndStrideOpInterface &candidateB) {
460-
return comparePipeLine.compare(rewriter, candidateA, candidateB,
461-
defOrUse);
482+
return options.comparerPipeLine.compare(rewriter, candidateA,
483+
candidateB, defOrUse);
462484
});
463485
// f. call tiling interface
464486
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
@@ -496,49 +518,52 @@ tileAndFuseConsumerOfOpResult(
496518
*
497519
* producer1 producer2
498520
* \ /
499-
* tiledOp
521+
* Op
500522
* / \
501523
* consumer1 consumer2
502524
*
503525
* where:
504526
*
505-
* 1. tiled op is responsible for providing scheduled parallel loops and
506-
* several candidate sliceOp including both Producer and Consumer.
507-
* 2. support both pre-op and post-op fusion: try to fuse all of producers and
508-
* consumers of tiled op.
509-
* 3. recursively call forward and backward Fusion on either fused producer or
510-
* consumer op based on BFS.
527+
* Support iterative producer and consumer fusion in BFS fashion.
511528
*/
512-
void IterativelyFuseProducerAndConsumerOfTiledOp(
529+
void iterativelyFuseProducerAndConsumerOfTiledOp(
513530
RewriterBase &rewriter, Operation *tiledOp,
514531
TargetSystemSpecInterface targetSpec) {
515-
516-
// User-defined filter to control whether to fuse or not. If more than one
517-
// filters need given, please use filter list instead.
518-
// E.g.
519-
// SmallVector<CandidateSliceFilter> customizedFilterList
520-
// = {customizedFilter1, customizedFilter2, customizedFilter3, ...};
532+
// Flexible options to control which candidate slice would be selected from
533+
// the view of both validity and performance.
534+
CandidateSliceOptions options;
535+
// User-defined filter to control whether to fuse or not. For instance, the
536+
// maximum number of fused ops is limited to 20 (for illustration only).
537+
int64_t numTiledOps = 0;
521538
CandidateSliceFilter customizedFilter =
522-
[](RewriterBase &rewriter, OffsetSizeAndStrideOpInterface candidate,
523-
CandidateDefOrUse defOrUse) -> LogicalResult { return success(); };
539+
[&numTiledOps](RewriterBase &rewriter,
540+
OffsetSizeAndStrideOpInterface candidate,
541+
CandidateDefOrUse defOrUse) -> LogicalResult {
542+
return success(numTiledOps < 20);
543+
};
544+
// If more than one filter needs to be given, please use a filter list instead. E.g.
545+
//
546+
// SmallVector<CandidateSliceFilter> customizedFilterList
547+
// = {customizedFilter1, customizedFilter2, ...};
548+
options.addFilter(customizedFilter);
524549

525550
std::deque<Operation *> tiledOpList = {tiledOp};
526551
while (!tiledOpList.empty()) {
527552
tiledOp = tiledOpList.front();
528553
tiledOpList.pop_front();
554+
numTiledOps++;
529555
// fuse producer
530556
for (OpOperand &operand : tiledOp->getOpOperands()) {
531557
if (std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult =
532-
tileAndFuseProducerOfOpOperand(rewriter, operand,
533-
customizedFilter)) {
558+
tileAndFuseProducerOfOpOperand(rewriter, operand, options)) {
534559
tiledOpList.push_back(fuseProducerResult.value().tiledOps[0]);
535560
}
536561
}
537562
// fuse consumer(s)
538563
for (OpResult result : tiledOp->getResults()) {
539564
if (std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
540-
fuseConsumerResults = tileAndFuseConsumerOfOpResult(
541-
rewriter, result, customizedFilter)) {
565+
fuseConsumerResults =
566+
tileAndFuseConsumerOfOpResult(rewriter, result, options)) {
542567
for (auto &fuseConsumerResult : *fuseConsumerResults) {
543568
tiledOpList.push_back(fuseConsumerResult.tiledOps[0]);
544569
}
@@ -548,10 +573,7 @@ void IterativelyFuseProducerAndConsumerOfTiledOp(
548573
}
549574

550575
/**
551-
* What is Tiled Op?
552-
* 1. located in a for loop
553-
* 2. it is the only one TilingInterface op in for loop
554-
* 3. has extract/insert slice
576+
* What is a single tiled op in a loop?
555577
*
556578
* E.g.
557579
* %1 = scf.for(){
@@ -564,7 +586,7 @@ void IterativelyFuseProducerAndConsumerOfTiledOp(
564586
* }
565587
*
566588
* */
567-
static LogicalResult isTiledOp(Operation *targetOp) {
589+
static LogicalResult isSingleTiledOpInLoop(Operation *targetOp) {
568590
// 0. check tilable
569591
if (!isa<TilingInterface>(targetOp)) {
570592
return failure();
@@ -595,37 +617,40 @@ static LogicalResult isTiledOp(Operation *targetOp) {
595617
return success(walkResult.wasInterrupted());
596618
}
597619

598-
static void FineGrainedFusion(RewriterBase &rewriter, func::FuncOp f,
599-
TargetSystemSpecInterface targetSpec) {
600-
SmallVector<Operation *> tiledOpList;
601-
// Walk through func operation.
602-
f->walk([&tiledOpList](Operation *op) {
603-
// Target at tiled op, like matmul/conv
604-
if (succeeded(isTiledOp(op))) {
605-
tiledOpList.push_back(op);
606-
}
607-
});
608-
// Fuse all tilable ops around tiled op in forward and backward fashion.
609-
for (auto &tiledOp : tiledOpList) {
610-
IterativelyFuseProducerAndConsumerOfTiledOp(rewriter, tiledOp, targetSpec);
611-
}
612-
}
613-
614-
struct AnyTilableFusion : public impl::AnyTilableFusionBase<AnyTilableFusion> {
620+
struct FineGrainedFusion
621+
: public impl::FineGrainedFusionBase<FineGrainedFusion> {
615622

616623
public:
617624
void runOnOperation() final {
618625
auto &ctx = getContext();
619-
// Get funcOp
620-
func::FuncOp func = getOperation();
621-
// Get target descriptor
622-
TargetSystemSpecInterface targetSpec =
623-
mlir::impl::getTargetSystemSpec(func);
624-
// Get rewriter
625-
IRRewriter rewriter(&ctx);
626-
// Do fine-grained fusion
627-
FineGrainedFusion(rewriter, func, targetSpec);
628-
// Perhaps coarse-grained fusion here
626+
{
627+
// Get funcOp
628+
func::FuncOp func = getOperation();
629+
// Get target descriptor
630+
TargetSystemSpecInterface targetSpec =
631+
mlir::impl::getTargetSystemSpec(func);
632+
// Get rewriter
633+
IRRewriter rewriter(&ctx);
634+
635+
// Collect tiled ops before fusion
636+
llvm::SetVector<Operation *> tiledOps;
637+
// Walk through funcOp
638+
func->walk([&tiledOps](Operation *op) {
639+
// Target a certain kind of tiled op, such as matmul/conv implemented
640+
// by multiple levels of nested loops and candidate slices for better
641+
// utilization of parallelism and memory hierarchy.
642+
if (succeeded(isSingleTiledOpInLoop(op))) {
643+
tiledOps.insert(op);
644+
}
645+
});
646+
// Sort by topology
647+
mlir::topologicalSort(tiledOps);
648+
// Iteratively fuse in forward and backward fashion.
649+
for (auto &tiledOp : tiledOps) {
650+
iterativelyFuseProducerAndConsumerOfTiledOp(rewriter, tiledOp,
651+
targetSpec);
652+
}
653+
}
629654

630655
{
631656
RewritePatternSet patternSet(&ctx);

lib/gc/Transforms/Pipeline.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) {
4343
// todo: layout propagation pass
4444
// todo: tensor constant propagation pass
4545
// todo: linalg.matmul lowering to (scf.loop + linalg.brgemm) pass
46-
// todo: fine-grain fusion pass
46+
// Fine-grained fusion pass
47+
pm.addNestedPass<func::FuncOp>(createFineGrainedFusion());
4748
// todo: lower linalg to arith/math on virtual vector pass
4849

4950
// REMOVE this pass after the above passes are added. Currently we add this

test/gc/Transform/any-tilable-fusion.mlir renamed to test/gc/Transform/fine-grained-fusion.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: gc-opt --split-input-file -any-tilable-fusion %s
1+
// RUN: gc-opt --split-input-file -fine-grained-fusion %s
22

33
module {
44
func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> {

0 commit comments

Comments
 (0)