
layout propagation for expandshape
yifeizh2 committed Jun 3, 2024
1 parent 2d31467 commit 4e7f7b9
Showing 2 changed files with 35 additions and 11 deletions.
9 changes: 5 additions & 4 deletions lib/gc/Analysis/GlobalAnalysis.cpp
@@ -155,10 +155,9 @@ inferTargetLayout(TensorLayout layoutBase,

GlobalAnalysis::GlobalAnalysis(Operation *root) {
root->walk([&](Operation *op) {
// get input layouts
LLVM_DEBUG(llvm::dbgs()
<< "Inferring layoutCache of op: " << op->getName() << "\n");
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
LLVM_DEBUG(llvm::dbgs()
<< "Inferring layout of op: " << op->getName() << "\n");
auto curInputs = linalgOp.getDpsInputOperands();
auto curResults = linalgOp.getOperation()->getResults();
// ---------------- Get Current Input Layouts -------------------
@@ -258,12 +257,14 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
}
auto inputOuterAxis = curInputLayout.getOuterAxis();
auto inputInnerAxis = curInputLayout.getInnerAxis();
int64_t diffDifference = staticOutputShape.size() - inputShape.size();
int64_t startIdx = 0;
SmallVector<int64_t> outputOuterAxis, outputInnerAxis;
for (int64_t i = 0; i < static_cast<int64_t>(staticOutputShape.size());
++i) {
if (outputInputIdxMapping.find(i) != outputInputIdxMapping.end()) {
outputOuterAxis.push_back(inputOuterAxis[outputInputIdxMapping[i]]);
outputOuterAxis.push_back(inputOuterAxis[outputInputIdxMapping[i]] +
diffDifference);
} else {
outputOuterAxis.push_back(startIdx++);
}
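A hedged reading of the new diffDifference offset, using hypothetical shapes not taken from this commit: for a tensor.expand_shape from tensor<8x32xf32> into tensor<2x4x32xf32>, diffDifference is 3 - 2 = 1, so a dimension index carried over from the input layout (for example outer axis 1) is shifted to 2 in the output layout, while the newly introduced output dimensions receive the fresh indices handed out by startIdx. See the IR sketch after the PropagateLayout.cpp hunk below for the same example end to end.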
37 changes: 30 additions & 7 deletions lib/gc/Transforms/PropagateLayout.cpp
@@ -75,10 +75,6 @@ FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
LLVM_DEBUG(llvm::dbgs() << "Try packing named op "
<< linalgOp.getOperation()->getName() << ".\n");
Location loc = linalgOp->getLoc();
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
SmallVector<utils::IteratorType> iteratorTypes =
linalgOp.getIteratorTypesArray();

SmallVector<tensor::PackOp> packOps;
SmallVector<tensor::UnPackOp> unPackOps;
SmallVector<Value> inputsAndInits, results;
@@ -215,8 +211,9 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
IRRewriter rewriter(ctx);
auto walk = graph->walk([&](Operation *op) {
FailureOr<OperatorLayout> opLayout = controlFn(op);
if (isa<linalg::LinalgOp>(op) && !mlir::linalg::isaContractionOpInterface(
dyn_cast<linalg::LinalgOp>(op))) {
if ((isa<linalg::LinalgOp>(op) && !mlir::linalg::isaContractionOpInterface(
dyn_cast<linalg::LinalgOp>(op))) ||
isa<tensor::ExpandShapeOp>(op) || isa<tensor::PadOp>(op)) {
if (failed(opLayout)) {
LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName()
<< "does not have layout information.\n");
@@ -235,9 +232,35 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
if (failed(packedOp)) {
return WalkResult::skip();
}
} else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
Location loc = expandShapeOp->getLoc();
auto inputLayout = opLayout->getSupportedInputLayouts()[0];
auto outputLayout = opLayout->getSupportedOutputLayouts()[0];
Value dest = tensor::PackOp::createDestinationTensor(
rewriter, loc, expandShapeOp.getSrc(), inputLayout.getTileSizes(),
inputLayout.getInnerAxis(), inputLayout.getOuterAxis());
Value packedSource = rewriter.create<tensor::PackOp>(
loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(),
inputLayout.getTileSizes(), std::nullopt,
inputLayout.getOuterAxis());
auto resultType = RankedTensorType::get(
expandShapeOp.getStaticOutputShape(),
expandShapeOp.getSrcType().getElementType());
RankedTensorType resultPackType = tensor::PackOp::inferPackedType(
resultType, vector::getAsIntegers(outputLayout.getTileSizes()),
outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
auto reassocExpand = getReassociationIndicesForReshape(
cast<ShapedType>(dest.getType()), resultPackType);
auto packedExpandShape = rewriter.create<tensor::ExpandShapeOp>(
loc, expandShapeOp.getSrcType().getElementType(), packedSource,
*reassocExpand);
Value result = rewriter.create<tensor::UnPackOp>(
packedExpandShape->getLoc(), packedExpandShape, packedExpandShape,
outputLayout.getInnerAxis(), outputLayout.getTileSizes(),
outputLayout.getOuterAxis());
rewriter.replaceOp(expandShapeOp, result);
}
}
} else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
}
return WalkResult::advance();
});
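As an illustration of the new tensor::ExpandShapeOp branch above, here is a hand-written IR sketch of the rewrite it is meant to produce: the source is packed into its inferred input layout, the expand_shape is replayed on the packed tensor, and the result is unpacked into the inferred output layout. The shapes, tile sizes, and reassociation indices below are made up for the example and are not emitted by this commit.

// original op (hypothetical shapes)
%e = tensor.expand_shape %src [[0, 1], [2]] output_shape [2, 4, 32]
    : tensor<8x32xf32> into tensor<2x4x32xf32>

// after propagation: pack -> expand_shape on the packed tensor -> unpack
%pack_dest = tensor.empty() : tensor<8x2x16xf32>
%packed = tensor.pack %src inner_dims_pos = [1] inner_tiles = [16]
    into %pack_dest : tensor<8x32xf32> -> tensor<8x2x16xf32>
%packed_expand = tensor.expand_shape %packed [[0, 1], [2], [3]] output_shape [2, 4, 2, 16]
    : tensor<8x2x16xf32> into tensor<2x4x2x16xf32>
%unpack_dest = tensor.empty() : tensor<2x4x32xf32>
%res = tensor.unpack %packed_expand inner_dims_pos = [2] inner_tiles = [16]
    into %unpack_dest : tensor<2x4x2x16xf32> -> tensor<2x4x32xf32>

In this sketch the packed dimension index carried over from the input (1) appears as 2 on the output side, which is the rank-difference shift applied in the GlobalAnalysis.cpp hunk.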
