@@ -75,10 +75,6 @@ FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
   LLVM_DEBUG(llvm::dbgs() << "Try packing named op "
                           << linalgOp.getOperation()->getName() << ".\n");
   Location loc = linalgOp->getLoc();
-  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
-  SmallVector<utils::IteratorType> iteratorTypes =
-      linalgOp.getIteratorTypesArray();
-
   SmallVector<tensor::PackOp> packOps;
   SmallVector<tensor::UnPackOp> unPackOps;
   SmallVector<Value> inputsAndInits, results;
@@ -215,8 +211,9 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
   IRRewriter rewriter(ctx);
   auto walk = graph->walk([&](Operation *op) {
     FailureOr<OperatorLayout> opLayout = controlFn(op);
-    if (isa<linalg::LinalgOp>(op) && !mlir::linalg::isaContractionOpInterface(
-                                         dyn_cast<linalg::LinalgOp>(op))) {
+    if ((isa<linalg::LinalgOp>(op) && !mlir::linalg::isaContractionOpInterface(
+                                          dyn_cast<linalg::LinalgOp>(op))) ||
+        isa<tensor::ExpandShapeOp>(op) || isa<tensor::PadOp>(op)) {
       if (failed(opLayout)) {
         LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName()
                                 << " does not have layout information.\n");
@@ -235,9 +232,36 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
         if (failed(packedOp)) {
           return WalkResult::skip();
        }
+      } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
+        Location loc = expandShapeOp->getLoc();
+        auto inputLayout = opLayout->getSupportedInputLayouts()[0];
+        auto outputLayout = opLayout->getSupportedOutputLayouts()[0];
+        Value dest = tensor::PackOp::createDestinationTensor(
+            rewriter, loc, expandShapeOp.getSrc(), inputLayout.getTileSizes(),
+            inputLayout.getInnerAxis(), inputLayout.getOuterAxis());
+        Value packedSource = rewriter.create<tensor::PackOp>(
+            loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(),
+            inputLayout.getTileSizes(), std::nullopt,
+            inputLayout.getOuterAxis());
+        auto resultType = RankedTensorType::get(
+            expandShapeOp.getStaticOutputShape(),
+            expandShapeOp.getSrcType().getElementType());
+        RankedTensorType resultPackType = tensor::PackOp::inferPackedType(
+            resultType, vector::getAsIntegers(outputLayout.getTileSizes()),
+            outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
+        auto reassocExpand = getReassociationIndicesForReshape(
+            cast<ShapedType>(dest.getType()), resultPackType);
+        auto packedExpandShape = rewriter.create<tensor::ExpandShapeOp>(
+            loc, resultPackType, packedSource, *reassocExpand);
+        Value unpackDest = tensor::UnPackOp::createDestinationTensor(
+            rewriter, loc, packedExpandShape, outputLayout.getTileSizes(),
+            outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
+        Value result = rewriter.create<tensor::UnPackOp>(
+            loc, packedExpandShape, unpackDest, outputLayout.getInnerAxis(),
+            outputLayout.getTileSizes(), outputLayout.getOuterAxis());
+        rewriter.replaceOp(expandShapeOp, result);
       }
     }
-    } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
     }
     return WalkResult::advance();
   });
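
For reference, the added tensor::ExpandShapeOp branch rewrites a reshape of an unpacked tensor into a pack, a reshape in the packed domain, and an unpack. The IR sketch below is illustrative only: the 128x256xf32 source, the 32x32 inner tiles, the identity outer permutation, and the expansion of the leading dimension into 2x64 are assumed values, not taken from the patch.

    // Before the rewrite (hypothetical shapes):
    //   %e = tensor.expand_shape %src [[0, 1], [2]] output_shape [2, 64, 256]
    //       : tensor<128x256xf32> into tensor<2x64x256xf32>
    // After, assuming 32x32 tiles on both the input and output layouts:
    %dest = tensor.empty() : tensor<4x8x32x32xf32>
    %packed = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [32, 32]
        into %dest : tensor<128x256xf32> -> tensor<4x8x32x32xf32>
    // Only the outer (tiled) dimension is expanded; the 32x32 tiles stay intact.
    %expanded = tensor.expand_shape %packed [[0, 1], [2], [3], [4]]
        output_shape [2, 2, 8, 32, 32]
        : tensor<4x8x32x32xf32> into tensor<2x2x8x32x32xf32>
    %empty = tensor.empty() : tensor<2x64x256xf32>
    %res = tensor.unpack %expanded inner_dims_pos = [1, 2] inner_tiles = [32, 32]
        into %empty : tensor<2x2x8x32x32xf32> -> tensor<2x64x256xf32>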