@@ -38,6 +38,12 @@ using namespace mlir::tensor;
38
38
static SmallVector<int64_t > getPackedAxes (ArrayRef<int64_t > dimensions,
39
39
TensorLayout targetLayout) {
40
40
SmallVector<int64_t > result (dimensions);
41
+ // permuting on outer axis
42
+ auto outerPerm = targetLayout.getOuterAxis ();
43
+ for (size_t i = 0 ; i < dimensions.size (); ++i) {
44
+ result[i] = outerPerm[dimensions[i]];
45
+ }
46
+ // inserting inner axis
41
47
auto innerPos = targetLayout.getInnerAxis ();
42
48
for (size_t i = 0 ; i < dimensions.size (); ++i) {
43
49
if (std::find (innerPos.begin (), innerPos.end (), dimensions[i]) !=
@@ -153,8 +159,10 @@ FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
153
159
loc, inits.getTypes (), inputs, inits, packedAxes);
154
160
packedLinalgOp->getRegion (0 ).takeBody (linalgOp->getRegion (0 ));
155
161
} else if (auto broadcastOp = dyn_cast<linalg::BroadcastOp>(&linalgOp)) {
156
- packedLinalgOp = rewriter.create <linalg::BroadcastOp>(
157
- loc, inputs[0 ], inits[0 ], broadcastOp->getDimensions ());
162
+ SmallVector<int64_t > packedAxes =
163
+ getPackedAxes (broadcastOp->getDimensions (), initLayouts[0 ]);
164
+ packedLinalgOp = rewriter.create <linalg::BroadcastOp>(loc, inputs[0 ],
165
+ inits[0 ], packedAxes);
158
166
} else if (auto transposeOp = dyn_cast<linalg::TransposeOp>(&linalgOp)) {
159
167
SmallVector<int64_t > packedPermAxes = getPackedPermAxes (
160
168
transposeOp->getPermutation (), inputLayouts[0 ], initLayouts[0 ]);
@@ -237,39 +245,41 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
237
245
return WalkResult::skip ();
238
246
}
239
247
} else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
240
- Location loc = expandShapeOp->getLoc ();
241
- auto inputLayout = opLayout->getSupportedInputLayouts ()[0 ];
242
- auto outputLayout = opLayout->getSupportedOutputLayouts ()[0 ];
243
- Value dest = tensor::PackOp::createDestinationTensor (
244
- rewriter, loc, expandShapeOp.getSrc (), inputLayout.getTileSizes (),
245
- inputLayout.getInnerAxis (), inputLayout.getOuterAxis ());
246
- Value packedSource = rewriter.create <tensor::PackOp>(
247
- loc, expandShapeOp.getSrc (), dest, inputLayout.getInnerAxis (),
248
- inputLayout.getTileSizes (), std::nullopt,
249
- inputLayout.getOuterAxis ());
250
- auto resultType = RankedTensorType::get (
251
- expandShapeOp.getStaticOutputShape (),
252
- expandShapeOp.getSrcType ().getElementType ());
253
- RankedTensorType resultPackType = tensor::PackOp::inferPackedType (
254
- resultType, vector::getAsIntegers (outputLayout.getTileSizes ()),
255
- outputLayout.getInnerAxis (), outputLayout.getOuterAxis ());
256
- auto reassocExpand = getReassociationIndicesForReshape (
257
- cast<ShapedType>(dest.getType ()), resultPackType);
258
- auto packedExpandShape = rewriter.create <tensor::ExpandShapeOp>(
259
- loc, expandShapeOp.getSrcType ().getElementType (), packedSource,
260
- *reassocExpand);
261
- Value result = rewriter.create <tensor::UnPackOp>(
262
- packedExpandShape->getLoc (), packedExpandShape, packedExpandShape,
263
- outputLayout.getInnerAxis (), outputLayout.getTileSizes (),
264
- outputLayout.getOuterAxis ());
265
- rewriter.replaceOp (expandShapeOp, result);
248
+ // Location loc = expandShapeOp->getLoc();
249
+ // auto inputLayout = opLayout->getSupportedInputLayouts()[0];
250
+ // auto outputLayout = opLayout->getSupportedOutputLayouts()[0];
251
+ // Value dest = tensor::PackOp::createDestinationTensor(
252
+ // rewriter, loc, expandShapeOp.getSrc(),
253
+ // inputLayout.getTileSizes(), inputLayout.getInnerAxis(),
254
+ // inputLayout.getOuterAxis());
255
+ // Value packedSource = rewriter.create<tensor::PackOp>(
256
+ // loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(),
257
+ // inputLayout.getTileSizes(), std::nullopt,
258
+ // inputLayout.getOuterAxis());
259
+ // auto resultType = RankedTensorType::get(
260
+ // expandShapeOp.getStaticOutputShape(),
261
+ // expandShapeOp.getSrcType().getElementType());
262
+ // RankedTensorType resultPackType = tensor::PackOp::inferPackedType(
263
+ // resultType, vector::getAsIntegers(outputLayout.getTileSizes()),
264
+ // outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
265
+ // auto reassocExpand = getReassociationIndicesForReshape(
266
+ // cast<ShapedType>(dest.getType()), resultPackType);
267
+ // auto packedExpandShape = rewriter.create<tensor::ExpandShapeOp>(
268
+ // loc, expandShapeOp.getSrcType().getElementType(), packedSource,
269
+ // *reassocExpand);
270
+ // Value result = rewriter.create<tensor::UnPackOp>(
271
+ // packedExpandShape->getLoc(), packedExpandShape,
272
+ // packedExpandShape, outputLayout.getInnerAxis(),
273
+ // outputLayout.getTileSizes(), outputLayout.getOuterAxis());
274
+ // rewriter.replaceOp(expandShapeOp, result);
266
275
}
267
276
}
268
277
}
269
278
return WalkResult::advance ();
270
279
});
271
280
if (walk.wasSkipped ())
272
281
return failure ();
282
+ graph->dump ();
273
283
return success ();
274
284
}
275
285
0 commit comments