 namespace mlir {
 namespace gc {

-std::ostream &operator<<(std::ostream &ss, const TensorLayout &layout) {
-  SmallVector<int64_t> outerAxis = layout.getOuterAxis();
-  SmallVector<int64_t> innerAxis = layout.getInnerAxis();
-  SmallVector<OpFoldResult> tileSizes = layout.getTileSizes();
+#define DEBUG_TYPE "global-analysis"
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
+                              const TensorLayout &layoutCache) {
+  SmallVector<int64_t> outerAxis = layoutCache.getOuterAxis();
+  SmallVector<int64_t> innerAxis = layoutCache.getInnerAxis();
+  SmallVector<OpFoldResult> tileSizes = layoutCache.getTileSizes();
   ss << "[";
   for (size_t i = 0; i < outerAxis.size(); ++i) {
     if (i != 0) {
@@ -43,21 +46,21 @@ std::ostream &operator<<(std::ostream &ss, const TensorLayout &layout) {
   return ss;
 }

-bool TensorLayout::operator==(const TensorLayout &layout) {
-  return (this->OuterAxis == layout.getOuterAxis()) &&
-         (this->InnerAxis == layout.getInnerAxis()) &&
-         (this->TileSizes == layout.getTileSizes());
+bool TensorLayout::operator==(const TensorLayout &layoutCache) {
+  return (this->OuterAxis == layoutCache.getOuterAxis()) &&
+         (this->InnerAxis == layoutCache.getInnerAxis()) &&
+         (this->TileSizes == layoutCache.getTileSizes());
 }

-std::ostream &operator<<(std::ostream &ss, const OperatorLayout &opLayout) {
-  ss << "operator has " << opLayout.getSupportedInputLayouts().size()
-     << " inputs; " << opLayout.getSupportedOutputLayouts().size()
-     << " outputs." << std::endl;
-  for (const auto &layout : opLayout.getSupportedInputLayouts()) {
-    ss << "input layout: " << layout << std::endl;
+llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
+                              const OperatorLayout &opLayout) {
+  for (auto &&[idx, layoutCache] :
+       llvm::enumerate(opLayout.getSupportedInputLayouts())) {
+    ss << "input " << idx << "'s layoutCache: " << layoutCache << "\n";
   }
-  for (const auto &layout : opLayout.getSupportedOutputLayouts()) {
-    ss << "output layout: " << layout << std::endl;
+  for (auto &&[idx, layoutCache] :
+       llvm::enumerate(opLayout.getSupportedOutputLayouts())) {
+    ss << "output " << idx << "'s layoutCache: " << layoutCache << "\n";
   }
   return ss;
 }
@@ -119,7 +122,6 @@ getReversedIndexMap(const DenseMap<int64_t, int64_t> &indexMap,
 static FailureOr<TensorLayout>
 inferTargetLayout(TensorLayout layoutBase,
                   const DenseMap<int64_t, int64_t> &indexMap) {
-  int64_t dimDifference = indexMap.size() - layoutBase.getTensorRank();
   SmallVector<int64_t> baseOuterAxis = layoutBase.getOuterAxis();
   SmallVector<int64_t> baseInnerAxis = layoutBase.getInnerAxis();
   SmallVector<OpFoldResult> baseTileSizes = layoutBase.getTileSizes();
@@ -153,38 +155,24 @@ inferTargetLayout(TensorLayout layoutBase,

 GlobalAnalysis::GlobalAnalysis(Operation *root) {
   root->walk([&](Operation *op) {
+    // get input layouts
+    LLVM_DEBUG(llvm::dbgs()
+               << "Inferring layoutCache of op: " << op->getName() << "\n");
     if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
-      // get input layouts
-      std::cout << std::endl;
-      std::cout << "----------------------------------" << std::endl;
-      linalgOp.getOperation()->getName().print(llvm::errs());
-      std::cout << std::endl;
-      std::cout << "----------------------------------" << std::endl;
-      std::cout << std::endl;
-      SmallVector<AffineMap> indexing_maps = linalgOp.getIndexingMapsArray();
       auto curInputs = linalgOp.getDpsInputOperands();
       auto curResults = linalgOp.getOperation()->getResults();
-
       // ---------------- Get Current Input Layouts -------------------
-      // get current input layouts
-      std::cout << "----- printing ground-truth input layouts -----"
-                << std::endl;
       SmallVector<TensorLayout> curInputLayouts;
       for (auto input : curInputs) {
         auto parent = input->get().getDefiningOp();
-        if (layout.find(parent) != layout.end()) {
+        if (layoutCache.find(parent) != layoutCache.end()) {
           // TODO(yifei): it is not always 0 here
-          curInputLayouts.push_back(layout[parent].getOutputLayout(0));
+          curInputLayouts.push_back(layoutCache[parent].getOutputLayout(0));
         } else {
           curInputLayouts.push_back(TensorLayout::createPlainLayout(
               linalgOp.getMatchingIndexingMap(input).getNumResults()));
         }
       }
-      // debug info
-      for (auto layout : curInputLayouts) {
-        std::cout << "layout: " << layout << std::endl;
-      }
-
       // ------ Get Current Op's Suggested Layout & Do Propagation ------
       IRRewriter rewriter(linalgOp);
       if (mlir::linalg::isaContractionOpInterface(linalgOp)) {
@@ -193,38 +181,33 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
         // curInputLayouts);

         // hardcode one for now
-        // A side layout, [0, 1, 0, 1]; {32, 32}
+        // A side layoutCache, [0, 1, 0, 1]; {32, 32}
         TensorLayout A_layout(
             {0, 1}, {0, 1},
             SmallVector<OpFoldResult>{rewriter.getIndexAttr(32),
                                       rewriter.getIndexAttr(32)});
-        // B side layout, [1, 0, 0, 1]; {32, 32}
+        // B side layoutCache, [1, 0, 0, 1]; {32, 32}
         TensorLayout B_layout(
             {1, 0}, {0, 1},
             SmallVector<OpFoldResult>{rewriter.getIndexAttr(32),
                                       rewriter.getIndexAttr(32)});
-        // C side layout, [0, 1, 0, 1]; {32, 32}
+        // C side layoutCache, [0, 1, 0, 1]; {32, 32}
         TensorLayout C_layout(
             {0, 1}, {0, 1},
             SmallVector<OpFoldResult>{rewriter.getIndexAttr(32),
                                       rewriter.getIndexAttr(32)});
         OperatorLayout suggestedLayout({A_layout, B_layout}, {C_layout});
-        layout[linalgOp] = suggestedLayout;
+        layoutCache[linalgOp] = suggestedLayout;
       } else {
         SmallVector<TensorLayout> inputLayouts, outputLayouts;
         inputLayouts.push_back(curInputLayouts[0]);
         // TODO(yifei): wisely choose the input format basis
         // Let's only refer to input[0] for now
         for (size_t i = 1; i < curInputs.size(); ++i) {
-          std::cout << "inferring indexing map relation" << std::endl;
           // getMatchingIndexingMap
           auto res = inferIndexingMapRelation(
               linalgOp.getMatchingIndexingMap(curInputs[0]),
               linalgOp.getMatchingIndexingMap(curInputs[i]));
-          for (auto tp : *res) {
-            std::cout << "target index: " << tp.first
-                      << " maps to base index: " << tp.second << std::endl;
-          }
           TensorLayout inputLayout =
               *inferTargetLayout(curInputLayouts[0], *res);
           inputLayouts.push_back(inputLayout);
@@ -235,14 +218,66 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
         TensorLayout outputLayout =
             *inferTargetLayout(curInputLayouts[0], *res_out);
         outputLayouts.push_back(outputLayout);
-        for (auto tp : *res_out) {
-          std::cout << "target index: " << tp.first
-                    << " maps to base index: " << tp.second << std::endl;
-        }
         OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
-        layout[linalgOp] = suggestedLayout;
+        layoutCache[linalgOp] = suggestedLayout;
+      }
+    } else if (auto padOp = dyn_cast<tensor::PadOp>(op)) {
+      auto inputOperand = padOp.getSource();
+      auto inputRank =
+          cast<ShapedType>(inputOperand.getType()).getShape().size();
+      auto parent = inputOperand.getDefiningOp();
+      TensorLayout curInputLayout =
+          layoutCache.find(parent) != layoutCache.end()
+              ? layoutCache[parent].getOutputLayout(0)
+              : TensorLayout::createPlainLayout(inputRank);
+      SmallVector<TensorLayout> inputLayouts{curInputLayout},
+          outputLayouts{curInputLayout};
+      OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
+      layoutCache[padOp] = suggestedLayout;
+    } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
+      auto reassociation = expandShapeOp.getReassociation();
+      auto staticOutputShape = expandShapeOp.getStaticOutputShape();
+      auto parent = expandShapeOp.getSrc().getDefiningOp();
+      auto inputShape = expandShapeOp.getSrcType().getShape();
+      TensorLayout curInputLayout =
+          layoutCache.find(parent) != layoutCache.end()
+              ? layoutCache[parent].getOutputLayout(0)
+              : TensorLayout::createPlainLayout(inputShape.size());
+      DenseMap<int64_t, int64_t> outputInputIdxMapping, inputOutputIndexMapping;
+      int64_t accumulationOffset = 0;
+      for (int64_t i = 0; i < static_cast<int64_t>(reassociation.size()); ++i) {
+        auto subReassociation = llvm::cast<ArrayAttr>(reassociation[i]);
+        for (int64_t j = 0; j < static_cast<int64_t>(subReassociation.size());
+             ++j) {
+          if (staticOutputShape[accumulationOffset + j] == inputShape[i]) {
+            outputInputIdxMapping[accumulationOffset + j] = i;
+            inputOutputIndexMapping[i] = accumulationOffset + j;
+          }
+        }
+        accumulationOffset += subReassociation.size();
+      }
+      auto inputOuterAxis = curInputLayout.getOuterAxis();
+      auto inputInnerAxis = curInputLayout.getInnerAxis();
+      int64_t startIdx = 0;
+      SmallVector<int64_t> outputOuterAxis, outputInnerAxis;
+      for (int64_t i = 0; i < static_cast<int64_t>(staticOutputShape.size());
+           ++i) {
+        if (outputInputIdxMapping.find(i) != outputInputIdxMapping.end()) {
+          outputOuterAxis.push_back(inputOuterAxis[outputInputIdxMapping[i]]);
+        } else {
+          outputOuterAxis.push_back(startIdx++);
+        }
+      }
+      for (int64_t i = 0; i < static_cast<int64_t>(inputInnerAxis.size());
+           ++i) {
+        outputInnerAxis.push_back(inputOutputIndexMapping[inputInnerAxis[i]]);
       }
-    } else if (isa<tensor::PadOp>(op) || isa<tensor::ExpandShapeOp>(op)) {
+      TensorLayout outputLayout(outputOuterAxis, outputInnerAxis,
+                                curInputLayout.getTileSizes());
+      SmallVector<TensorLayout> inputLayouts{curInputLayout},
+          outputLayouts{outputLayout};
+      OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
+      layoutCache[expandShapeOp] = suggestedLayout;
     }
   });
 }