@@ -75,10 +75,6 @@ FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
7575 LLVM_DEBUG (llvm::dbgs () << " Try packing named op "
7676 << linalgOp.getOperation ()->getName () << " .\n " );
7777 Location loc = linalgOp->getLoc ();
78- SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray ();
79- SmallVector<utils::IteratorType> iteratorTypes =
80- linalgOp.getIteratorTypesArray ();
81-
8278 SmallVector<tensor::PackOp> packOps;
8379 SmallVector<tensor::UnPackOp> unPackOps;
8480 SmallVector<Value> inputsAndInits, results;
@@ -215,8 +211,9 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
215211 IRRewriter rewriter (ctx);
216212 auto walk = graph->walk ([&](Operation *op) {
217213 FailureOr<OperatorLayout> opLayout = controlFn (op);
218- if (isa<linalg::LinalgOp>(op) && !mlir::linalg::isaContractionOpInterface (
219- dyn_cast<linalg::LinalgOp>(op))) {
214+ if ((isa<linalg::LinalgOp>(op) && !mlir::linalg::isaContractionOpInterface (
215+ dyn_cast<linalg::LinalgOp>(op))) ||
216+ isa<tensor::ExpandShapeOp>(op) || isa<tensor::PadOp>(op)) {
220217 if (failed (opLayout)) {
221218 LLVM_DEBUG (llvm::dbgs () << " Op " << op->getName ()
222219 << " does not have layout information.\n " );
@@ -235,9 +232,35 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
235232 if (failed (packedOp)) {
236233 return WalkResult::skip ();
237234 }
235+ } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
236+ Location loc = expandShapeOp->getLoc ();
237+ auto inputLayout = opLayout->getSupportedInputLayouts ()[0 ];
238+ auto outputLayout = opLayout->getSupportedOutputLayouts ()[0 ];
239+ Value dest = tensor::PackOp::createDestinationTensor (
240+ rewriter, loc, expandShapeOp.getSrc (), inputLayout.getTileSizes (),
241+ inputLayout.getInnerAxis (), inputLayout.getOuterAxis ());
242+ Value packedSource = rewriter.create <tensor::PackOp>(
243+ loc, expandShapeOp.getSrc (), dest, inputLayout.getInnerAxis (),
244+ inputLayout.getTileSizes (), std::nullopt ,
245+ inputLayout.getOuterAxis ());
246+ auto resultType = RankedTensorType::get (
247+ expandShapeOp.getStaticOutputShape (),
248+ expandShapeOp.getSrcType ().getElementType ());
249+ RankedTensorType resultPackType = tensor::PackOp::inferPackedType (
250+ resultType, vector::getAsIntegers (outputLayout.getTileSizes ()),
251+ outputLayout.getInnerAxis (), outputLayout.getOuterAxis ());
252+ auto reassocExpand = getReassociationIndicesForReshape (
253+ cast<ShapedType>(dest.getType ()), resultPackType);
254+ auto packedExpandShape = rewriter.create <tensor::ExpandShapeOp>(
255+ loc, expandShapeOp.getSrcType ().getElementType (), packedSource,
256+ *reassocExpand);
257+ Value result = rewriter.create <tensor::UnPackOp>(
258+ packedExpandShape->getLoc (), packedExpandShape, packedExpandShape,
259+ outputLayout.getInnerAxis (), outputLayout.getTileSizes (),
260+ outputLayout.getOuterAxis ());
261+ rewriter.replaceOp (expandShapeOp, result);
238262 }
239263 }
240- } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
241264 }
242265 return WalkResult::advance ();
243266 });