@@ -74,19 +74,25 @@ class VectorFusionBase {
74
74
func::FuncOp func;
75
75
// / Type helper class, can help us to get operation type
76
76
TypeHelper typehelper;
77
+ // / IR rewriter
78
+ IRRewriter *rewriter;
77
79
78
80
public:
79
- VectorFusionBase () = default ;
80
- VectorFusionBase (func::FuncOp & func, HardWareInfo & info)
81
- : func(func), typehelper(info) {}
82
- VectorFusionBase (VectorFusionBase & base)
83
- : func(base.getFunction()), typehelper(base.getHardwareInfo() ) {}
81
+ VectorFusionBase (func::FuncOp &func, HardWareInfo &info, IRRewriter *rewriter)
82
+ : func(func), typehelper( info), rewriter(rewriter) {}
83
+ VectorFusionBase (VectorFusionBase &base, IRRewriter *rewriter)
84
+ : func(base.getFunction()), typehelper( base.getHardwareInfo()),
85
+ rewriter (rewriter ) {}
84
86
85
87
// / get current function IR
86
88
func::FuncOp &getFunction () { return func; }
87
89
// / get current hardware info
88
- HardWareInfo &getHardwareInfo () { return typehelper.getHardwareInfo (); }
89
- TypeHelper &getTypeHelper () { return typehelper; }
90
+ HardWareInfo &getHardwareInfo () noexcept {
91
+ return typehelper.getHardwareInfo ();
92
+ }
93
+ TypeHelper &getTypeHelper () noexcept { return typehelper; }
94
+ IRRewriter *getRewriter () noexcept { return rewriter; }
95
+ void setRewriter (IRRewriter *rewriter) noexcept { this ->rewriter = rewriter; }
90
96
};
91
97
92
98
// / Group operation fusion strategy class.
@@ -132,17 +138,20 @@ class GroupOperationFusion : public VectorFusionBase {
132
138
DenseMap<Value, Value> operandOriginalValue;
133
139
134
140
public:
135
- GroupOperationFusion (func::FuncOp &func, HardWareInfo &info)
136
- : VectorFusionBase(func, info) {}
141
+ GroupOperationFusion (func::FuncOp &func, HardWareInfo &info,
142
+ IRRewriter *rewriter)
143
+ : VectorFusionBase(func, info, rewriter) {}
137
144
138
- GroupOperationFusion (GroupOperationFusion &strategy)
139
- : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()),
145
+ GroupOperationFusion (GroupOperationFusion &strategy, IRRewriter *rewriter)
146
+ : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo(),
147
+ rewriter),
140
148
opGroups (strategy.opGroups), groupMaxSteps(strategy.groupMaxSteps),
141
149
opGroupIndexMap(strategy.opGroupIndexMap),
142
150
opAnchorPos(strategy.opAnchorPos){};
143
151
144
- GroupOperationFusion (GroupOperationFusion &&strategy)
145
- : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()),
152
+ GroupOperationFusion (GroupOperationFusion &&strategy, IRRewriter *rewriter)
153
+ : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo(),
154
+ rewriter),
146
155
opGroups(std::move(strategy.opGroups)),
147
156
groupMaxSteps(std::move(strategy.groupMaxSteps)),
148
157
groupBigestRankVectorType(
@@ -165,9 +174,9 @@ class GroupOperationFusion : public VectorFusionBase {
165
174
this ->getFunction () = fusion.getFunction ();
166
175
this ->getHardwareInfo () = fusion.getHardwareInfo ();
167
176
this ->getTypeHelper () = fusion.getTypeHelper ();
177
+ this ->setRewriter (fusion.getRewriter ());
168
178
return *this ;
169
179
};
170
- GroupOperationFusion &operator =(GroupOperationFusion &&) = default ;
171
180
172
181
// / Get the map which contains each group vector type which has biggest
173
182
// / rank.
@@ -275,10 +284,12 @@ class GroupOperationAnalysis {
275
284
private:
276
285
// / vector-based fusion related data
277
286
GroupOperationFusion fusionStrategy;
287
+ IRRewriter *rewriter;
278
288
279
289
public:
280
- GroupOperationAnalysis (func::FuncOp &func, HardWareInfo &info)
281
- : fusionStrategy(func, info) {}
290
+ GroupOperationAnalysis (func::FuncOp &func, HardWareInfo &info,
291
+ IRRewriter *rewriter)
292
+ : fusionStrategy(func, info, rewriter), rewriter(rewriter) {}
282
293
// / remove the useless operation, due to it result is not require by other
283
294
// / operation
284
295
void analysisEmptyGroup ();
@@ -288,6 +299,8 @@ class GroupOperationAnalysis {
288
299
GroupOperationFusion &getGroupOperationFusion () { return fusionStrategy; }
289
300
// / running the vector-based fusion
290
301
void run () { fusionStrategy.run (); }
302
+ // / get current function rewriter
303
+ IRRewriter *getRewriter () { return rewriter; }
291
304
};
292
305
} // namespace gc
293
306
} // namespace mlir
0 commit comments