Skip to content

Commit 705d249

Browse files
committed
fix comments
1 parent 0e7794c commit 705d249

File tree

6 files changed

+304
-235
lines changed

6 files changed

+304
-235
lines changed

include/gc/Analysis/VectorBasedFusionAnalysis.h

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,19 +74,25 @@ class VectorFusionBase {
7474
func::FuncOp func;
7575
/// Type helper class, can help us to get operation type
7676
TypeHelper typehelper;
77+
/// IR rewriter
78+
IRRewriter *rewriter;
7779

7880
public:
79-
VectorFusionBase() = default;
80-
VectorFusionBase(func::FuncOp &func, HardWareInfo &info)
81-
: func(func), typehelper(info) {}
82-
VectorFusionBase(VectorFusionBase &base)
83-
: func(base.getFunction()), typehelper(base.getHardwareInfo()) {}
81+
VectorFusionBase(func::FuncOp &func, HardWareInfo &info, IRRewriter *rewriter)
82+
: func(func), typehelper(info), rewriter(rewriter) {}
83+
VectorFusionBase(VectorFusionBase &base, IRRewriter *rewriter)
84+
: func(base.getFunction()), typehelper(base.getHardwareInfo()),
85+
rewriter(rewriter) {}
8486

8587
/// get current function IR
8688
func::FuncOp &getFunction() { return func; }
8789
/// get current hardware info
88-
HardWareInfo &getHardwareInfo() { return typehelper.getHardwareInfo(); }
89-
TypeHelper &getTypeHelper() { return typehelper; }
90+
HardWareInfo &getHardwareInfo() noexcept {
91+
return typehelper.getHardwareInfo();
92+
}
93+
TypeHelper &getTypeHelper() noexcept { return typehelper; }
94+
IRRewriter *getRewriter() noexcept { return rewriter; }
95+
void setRewriter(IRRewriter *rewriter) noexcept { this->rewriter = rewriter; }
9096
};
9197

9298
/// Group operation fusion strategy class.
@@ -132,17 +138,20 @@ class GroupOperationFusion : public VectorFusionBase {
132138
DenseMap<Value, Value> operandOriginalValue;
133139

134140
public:
135-
GroupOperationFusion(func::FuncOp &func, HardWareInfo &info)
136-
: VectorFusionBase(func, info) {}
141+
GroupOperationFusion(func::FuncOp &func, HardWareInfo &info,
142+
IRRewriter *rewriter)
143+
: VectorFusionBase(func, info, rewriter) {}
137144

138-
GroupOperationFusion(GroupOperationFusion &strategy)
139-
: VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()),
145+
GroupOperationFusion(GroupOperationFusion &strategy, IRRewriter *rewriter)
146+
: VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo(),
147+
rewriter),
140148
opGroups(strategy.opGroups), groupMaxSteps(strategy.groupMaxSteps),
141149
opGroupIndexMap(strategy.opGroupIndexMap),
142150
opAnchorPos(strategy.opAnchorPos){};
143151

144-
GroupOperationFusion(GroupOperationFusion &&strategy)
145-
: VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()),
152+
GroupOperationFusion(GroupOperationFusion &&strategy, IRRewriter *rewriter)
153+
: VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo(),
154+
rewriter),
146155
opGroups(std::move(strategy.opGroups)),
147156
groupMaxSteps(std::move(strategy.groupMaxSteps)),
148157
groupBigestRankVectorType(
@@ -165,9 +174,9 @@ class GroupOperationFusion : public VectorFusionBase {
165174
this->getFunction() = fusion.getFunction();
166175
this->getHardwareInfo() = fusion.getHardwareInfo();
167176
this->getTypeHelper() = fusion.getTypeHelper();
177+
this->setRewriter(fusion.getRewriter());
168178
return *this;
169179
};
170-
GroupOperationFusion &operator=(GroupOperationFusion &&) = default;
171180

172181
/// Get the map which contains each group vector type which has biggest
173182
/// rank.
@@ -275,10 +284,12 @@ class GroupOperationAnalysis {
275284
private:
276285
/// vector-based fusion related data
277286
GroupOperationFusion fusionStrategy;
287+
IRRewriter *rewriter;
278288

279289
public:
280-
GroupOperationAnalysis(func::FuncOp &func, HardWareInfo &info)
281-
: fusionStrategy(func, info) {}
290+
GroupOperationAnalysis(func::FuncOp &func, HardWareInfo &info,
291+
IRRewriter *rewriter)
292+
: fusionStrategy(func, info, rewriter), rewriter(rewriter) {}
282293
/// remove the useless operation, due to it result is not require by other
283294
/// operation
284295
void analysisEmptyGroup();
@@ -288,6 +299,8 @@ class GroupOperationAnalysis {
288299
GroupOperationFusion &getGroupOperationFusion() { return fusionStrategy; }
289300
/// running the vector-based fusion
290301
void run() { fusionStrategy.run(); }
302+
/// get current function rewriter
303+
IRRewriter *getRewriter() { return rewriter; }
291304
};
292305
} // namespace gc
293306
} // namespace mlir

include/gc/Transforms/TilingVector.h

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- TilingVector.hpp - Tiling large vector to small vector ---*- C++ -*-===//
1+
//===- TilingVector.h - Tiling large vector to small vector -----*- C++ -*-===//
22
//
33
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -343,12 +343,16 @@ class ShapeCastCanonicalizer
343343
class ForLoopGenerator {
344344
private:
345345
GroupOperationFusion vectorBasedFusion;
346+
IRRewriter *rewriter;
346347

347348
public:
348-
ForLoopGenerator(GroupOperationFusion &fusion) : vectorBasedFusion(fusion) {}
349+
ForLoopGenerator(GroupOperationFusion &fusion, IRRewriter *rewriter)
350+
: vectorBasedFusion(fusion, rewriter), rewriter(rewriter) {}
349351

350352
virtual ~ForLoopGenerator() noexcept {}
351353

354+
IRRewriter *getRewriter() noexcept { return rewriter; }
355+
352356
void setVectorBaseFusion(GroupOperationFusion &vectorBasedFusion) {
353357
this->vectorBasedFusion = vectorBasedFusion;
354358
};
@@ -466,7 +470,8 @@ class LoopGeneratorImpl : public ForLoopGenerator {
466470
SmallVector<ShapeCastCanonicalizer, 8> shapeCastCanonicalizers;
467471

468472
public:
469-
LoopGeneratorImpl(GroupOperationFusion &fusion) : ForLoopGenerator(fusion){};
473+
LoopGeneratorImpl(GroupOperationFusion &fusion, IRRewriter *rewriter)
474+
: ForLoopGenerator(fusion, rewriter){};
470475

471476
virtual ~LoopGeneratorImpl() noexcept {};
472477

@@ -569,8 +574,9 @@ class GroupOperationFusionImpl : public GroupOperationAnalysis {
569574

570575
public:
571576
virtual ~GroupOperationFusionImpl() = default;
572-
GroupOperationFusionImpl(func::FuncOp &func, HardWareInfo &info)
573-
: GroupOperationAnalysis(func, info) {}
577+
GroupOperationFusionImpl(func::FuncOp &func, HardWareInfo &info,
578+
IRRewriter *rewriter)
579+
: GroupOperationAnalysis(func, info, rewriter) {}
574580

575581
void broadcastFromElements(Operation *op, size_t grpIdx);
576582
void scalarOperandFromElements();
@@ -632,17 +638,20 @@ class VectorOperationCanonicalizer {
632638
LoopGeneratorImpl loopGenerator;
633639
CanonicalizerKind kind;
634640
func::FuncOp func;
635-
IRRewriter rewriter;
641+
IRRewriter *rewriter;
636642

637643
public:
638644
VectorOperationCanonicalizer(
639-
func::FuncOp &func, HardWareInfo &info,
645+
func::FuncOp &func, HardWareInfo &info, IRRewriter *rewriter,
640646
CanonicalizerKind kind = CanonicalizerKind::GroupOperations)
641-
: fusion(func, info), loopGenerator(fusion.getGroupOperationFusion()),
642-
kind(kind), rewriter(func) {}
647+
: fusion(func, info, rewriter),
648+
loopGenerator(fusion.getGroupOperationFusion(), rewriter), kind(kind),
649+
rewriter(rewriter) {}
643650
virtual ~VectorOperationCanonicalizer() = default;
644651
/// run the vector canonicalizer for the IR
645652
void run();
653+
/// get current funtion rewriter
654+
IRRewriter *getRewriter() noexcept { return rewriter; }
646655
};
647656
} // namespace gc
648657
} // namespace mlir

include/gc/Transforms/Utils/VectorUtils.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@
2424

2525
namespace mlir {
2626
namespace gc {
27+
28+
enum class OPPRIORITY : uint8_t {
29+
FIRST = 0,
30+
SECOND,
31+
THIRD,
32+
LAST,
33+
OTHERS = 255,
34+
};
2735
/// Need to move some operations like extract_slice or insert_slice.
2836
/// Because those operation may interpret our analysis result. e.g.:
2937
/// ```
@@ -54,8 +62,8 @@ namespace gc {
5462
/// block.
5563
/// insert_slice just move them to the privious of the first operation which
5664
/// use it.
57-
void moveOpsFrontOrBack(func::FuncOp *func, MLIRContext *ctx,
58-
std::function<bool(Operation *)> &conditionalFunc);
65+
void moveOpsFrontOrBack(func::FuncOp *func, IRRewriter &rewriter,
66+
OPPRIORITY start, OPPRIORITY end);
5967

6068
/// build a constant operation of index type
6169
Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc,

0 commit comments

Comments
 (0)