Commit 5c60edd

layout propagation for expandshape

1 parent 20db35a

2 files changed: +35 −11 lines

lib/gc/Analysis/GlobalAnalysis.cpp (5 additions, 4 deletions)

@@ -155,10 +155,9 @@ inferTargetLayout(TensorLayout layoutBase,
 
 GlobalAnalysis::GlobalAnalysis(Operation *root) {
   root->walk([&](Operation *op) {
-    // get input layouts
-    LLVM_DEBUG(llvm::dbgs()
-               << "Inferring layoutCache of op: " << op->getName() << "\n");
     if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Inferring layout of op: " << op->getName() << "\n");
       auto curInputs = linalgOp.getDpsInputOperands();
       auto curResults = linalgOp.getOperation()->getResults();
       // ---------------- Get Current Input Layouts -------------------
@@ -258,12 +257,14 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
       }
       auto inputOuterAxis = curInputLayout.getOuterAxis();
       auto inputInnerAxis = curInputLayout.getInnerAxis();
+      int64_t diffDifference = staticOutputShape.size() - inputShape.size();
       int64_t startIdx = 0;
       SmallVector<int64_t> outputOuterAxis, outputInnerAxis;
       for (int64_t i = 0; i < static_cast<int64_t>(staticOutputShape.size());
            ++i) {
         if (outputInputIdxMapping.find(i) != outputInputIdxMapping.end()) {
-          outputOuterAxis.push_back(inputOuterAxis[outputInputIdxMapping[i]]);
+          outputOuterAxis.push_back(inputOuterAxis[outputInputIdxMapping[i]] +
+                                    diffDifference);
         } else {
           outputOuterAxis.push_back(startIdx++);
         }
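
The new diffDifference offset accounts for the rank change a tensor.expand_shape introduces: outer-axis indices inherited from the input are shifted by the number of newly created leading dimensions, while the new dimensions take the fresh indices starting from zero. Without the shift, an inherited axis can collide with a freshly assigned one. Below is a minimal standalone sketch of the remapping (not from the commit), assuming a tensor<128x64> input expanded to tensor<1x128x64> with an unpermuted input layout; the shapes and the index mapping are illustrative assumptions:

// Standalone sketch of the outer-axis remapping above. Assumes a 2-D input
// expanded to 3-D by prepending a unit dim, i.e. tensor<128x64> ->
// tensor<1x128x64>, so outputInputIdxMapping = {1 -> 0, 2 -> 1}.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  std::vector<int64_t> inputOuterAxis = {0, 1}; // unpermuted input layout
  std::vector<int64_t> staticOutputShape = {1, 128, 64};
  std::map<int64_t, int64_t> outputInputIdxMapping = {{1, 0}, {2, 1}};
  // Rank delta between output and input; inherited axes shift by this much.
  int64_t diffDifference =
      static_cast<int64_t>(staticOutputShape.size()) - 2 /*input rank*/;
  int64_t startIdx = 0;
  std::vector<int64_t> outputOuterAxis;
  for (int64_t i = 0; i < static_cast<int64_t>(staticOutputShape.size());
       ++i) {
    if (outputInputIdxMapping.count(i))
      outputOuterAxis.push_back(inputOuterAxis[outputInputIdxMapping[i]] +
                                diffDifference);
    else
      outputOuterAxis.push_back(startIdx++); // fresh index for the new dim
  }
  for (int64_t axis : outputOuterAxis)
    std::cout << axis << ' '; // prints: 0 1 2
  std::cout << '\n';
  return 0;
}

With the offset the result is the valid permutation {0, 1, 2}; the pre-commit code produced {0, 0, 1} on this input, reusing index 0 twice.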

lib/gc/Transforms/PropagateLayout.cpp (30 additions, 7 deletions)

@@ -75,10 +75,6 @@ FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
   LLVM_DEBUG(llvm::dbgs() << "Try packing named op "
                           << linalgOp.getOperation()->getName() << ".\n");
   Location loc = linalgOp->getLoc();
-  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
-  SmallVector<utils::IteratorType> iteratorTypes =
-      linalgOp.getIteratorTypesArray();
-
   SmallVector<tensor::PackOp> packOps;
   SmallVector<tensor::UnPackOp> unPackOps;
   SmallVector<Value> inputsAndInits, results;
@@ -215,8 +211,9 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
   IRRewriter rewriter(ctx);
   auto walk = graph->walk([&](Operation *op) {
     FailureOr<OperatorLayout> opLayout = controlFn(op);
-    if (isa<linalg::LinalgOp>(op) && !mlir::linalg::isaContractionOpInterface(
-                                         dyn_cast<linalg::LinalgOp>(op))) {
+    if ((isa<linalg::LinalgOp>(op) && !mlir::linalg::isaContractionOpInterface(
+                                          dyn_cast<linalg::LinalgOp>(op))) ||
+        isa<tensor::ExpandShapeOp>(op) || isa<tensor::PadOp>(op)) {
       if (failed(opLayout)) {
         LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName()
                                 << "does not have layout information.\n");
@@ -235,9 +232,35 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
         if (failed(packedOp)) {
           return WalkResult::skip();
         }
+      } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
+        Location loc = expandShapeOp->getLoc();
+        auto inputLayout = opLayout->getSupportedInputLayouts()[0];
+        auto outputLayout = opLayout->getSupportedOutputLayouts()[0];
+        Value dest = tensor::PackOp::createDestinationTensor(
+            rewriter, loc, expandShapeOp.getSrc(), inputLayout.getTileSizes(),
+            inputLayout.getInnerAxis(), inputLayout.getOuterAxis());
+        Value packedSource = rewriter.create<tensor::PackOp>(
+            loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(),
+            inputLayout.getTileSizes(), std::nullopt,
+            inputLayout.getOuterAxis());
+        auto resultType = RankedTensorType::get(
+            expandShapeOp.getStaticOutputShape(),
+            expandShapeOp.getSrcType().getElementType());
+        RankedTensorType resultPackType = tensor::PackOp::inferPackedType(
+            resultType, vector::getAsIntegers(outputLayout.getTileSizes()),
+            outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
+        auto reassocExpand = getReassociationIndicesForReshape(
+            cast<ShapedType>(dest.getType()), resultPackType);
+        auto packedExpandShape = rewriter.create<tensor::ExpandShapeOp>(
+            loc, expandShapeOp.getSrcType().getElementType(), packedSource,
+            *reassocExpand);
+        Value result = rewriter.create<tensor::UnPackOp>(
+            packedExpandShape->getLoc(), packedExpandShape, packedExpandShape,
+            outputLayout.getInnerAxis(), outputLayout.getTileSizes(),
+            outputLayout.getOuterAxis());
+        rewriter.replaceOp(expandShapeOp, result);
       }
     }
-    } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
     }
     return WalkResult::advance();
   });
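
The new branch rewrites a tensor.expand_shape whose layout is known into a pack of the source, an expand_shape in the blocked domain, and an unpack back to the plain layout. A hedged sketch of the resulting IR, assuming a tensor<128x64xf32> source expanded to tensor<1x128x64xf32> under a 32x32 tile layout; the shapes, tile sizes, and current-upstream tensor-dialect syntax are illustrative assumptions rather than output of the commit, and a separate empty destination is shown for the unpack:

// Before:
%e = tensor.expand_shape %src [[0, 1], [2]] output_shape [1, 128, 64]
    : tensor<128x64xf32> into tensor<1x128x64xf32>

// After the rewrite (sketch):
%dest = tensor.empty() : tensor<4x2x32x32xf32>
%packed = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [32, 32]
    into %dest : tensor<128x64xf32> -> tensor<4x2x32x32xf32>
%expanded = tensor.expand_shape %packed [[0, 1], [2], [3], [4]]
    output_shape [1, 4, 2, 32, 32]
    : tensor<4x2x32x32xf32> into tensor<1x4x2x32x32xf32>
%out = tensor.empty() : tensor<1x128x64xf32>
%res = tensor.unpack %expanded inner_dims_pos = [1, 2] inner_tiles = [32, 32]
    into %out : tensor<1x4x2x32x32xf32> -> tensor<1x128x64xf32>

Keeping the reshape in the tiled domain means the data never round-trips through the plain layout mid-graph, so a consumer that also prefers the blocked layout can fold away the adjacent unpack/pack pair.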
