fix broadcast axis and binary target input
yifeizh2 committed Jul 17, 2024
1 parent 10d90d6 commit 0087328
Showing 2 changed files with 66 additions and 38 deletions.
lib/gc/Analysis/GlobalAnalysis.cpp (38 changes: 28 additions & 10 deletions)
@@ -154,6 +154,15 @@ inferTargetLayout(TensorLayout layoutBase,
return TensorLayout(targetOuterAxis, targetInnerAxis, targetTileSizes);
}

+ static size_t getTargetInputIdx(ArrayRef<TensorLayout> curInputLayouts) {
+ for (size_t i = 0; i < curInputLayouts.size(); ++i) {
+ if (!curInputLayouts[i].isPlainLayout()) {
+ return i;
+ }
+ }
+ return 0;
+ }

GlobalAnalysis::GlobalAnalysis(Operation *root) {
root->walk([&](Operation *op) {
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
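
To make the intent of the new helper concrete, here is a minimal standalone sketch of its selection rule; FakeLayout and pickTargetInputIdx are illustrative stand-ins, not the project's TensorLayout API. The first input whose layout is not plain becomes the layout basis; if every input is plain, input 0 is used.

// Minimal standalone sketch (assumed semantics, simplified types).
#include <cstddef>
#include <iostream>
#include <vector>

struct FakeLayout {
  bool plain; // stand-in for TensorLayout::isPlainLayout()
};

static size_t pickTargetInputIdx(const std::vector<FakeLayout> &layouts) {
  for (size_t i = 0; i < layouts.size(); ++i)
    if (!layouts[i].plain)
      return i; // first non-plain input becomes the layout basis
  return 0;     // all inputs plain: fall back to input 0
}

int main() {
  // A binary op whose second operand is already packed by its producer:
  std::vector<FakeLayout> layouts = {{/*plain=*/true}, {/*plain=*/false}};
  std::cout << pickTargetInputIdx(layouts) << "\n"; // prints 1, not 0
  return 0;
}
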
@@ -201,23 +210,32 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
layoutCache[linalgOp] = suggestedLayout;
} else {
SmallVector<TensorLayout> inputLayouts, outputLayouts;
- inputLayouts.push_back(curInputLayouts[0]);
+ size_t targetIdx = getTargetInputIdx(curInputLayouts);
// TODO(yifei): wisely choose the input format basis
// For now, use the first non-plain-layout input (targetIdx) as the basis
- for (size_t i = 1; i < curInputs.size(); ++i) {
+ for (size_t i = 0; i < curInputs.size(); ++i) {
std::cout << "inferring indexing map relation" << std::endl;
// getMatchingIndexingMap
- auto res = inferIndexingMapRelation(
- linalgOp.getMatchingIndexingMap(curInputs[0]),
- linalgOp.getMatchingIndexingMap(curInputs[i]));
- TensorLayout inputLayout =
- *inferTargetLayout(curInputLayouts[0], *res);
- inputLayouts.push_back(inputLayout);
+ if (i != targetIdx) {
+ auto res = inferIndexingMapRelation(
+ linalgOp.getMatchingIndexingMap(curInputs[targetIdx]),
+ linalgOp.getMatchingIndexingMap(curInputs[i]));
+ for (auto tp : *res) {
+ std::cout << "target index: " << tp.first
+ << " maps to base index: " << tp.second << std::endl;
+ }
+ TensorLayout inputLayout =
+ *inferTargetLayout(curInputLayouts[targetIdx], *res);
+ inputLayouts.push_back(inputLayout);
+ } else {
+ inputLayouts.push_back(curInputLayouts[targetIdx]);
+ }
}
auto res_out = inferIndexingMapRelation(
- linalgOp.getMatchingIndexingMap(curInputs[0]),
+ linalgOp.getMatchingIndexingMap(curInputs[targetIdx]),
linalgOp.getIndexingMapMatchingResult(curResults[0]));
TensorLayout outputLayout =
- *inferTargetLayout(curInputLayouts[0], *res_out);
+ *inferTargetLayout(curInputLayouts[targetIdx], *res_out);
outputLayouts.push_back(outputLayout);
OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
layoutCache[linalgOp] = suggestedLayout;
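
In effect, for a binary elementwise op whose second operand is already packed (for example, the result of a blocked matmul) while the first operand is still plain, targetIdx is now 1: the plain operand and the result both get layouts inferred from the packed operand via the indexing-map relation, instead of everything being forced to follow input[0] as before.
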
lib/gc/Transforms/PropagateLayout.cpp (66 changes: 38 additions & 28 deletions)
@@ -38,6 +38,12 @@ using namespace mlir::tensor;
static SmallVector<int64_t> getPackedAxes(ArrayRef<int64_t> dimensions,
TensorLayout targetLayout) {
SmallVector<int64_t> result(dimensions);
+ // permuting on outer axis
+ auto outerPerm = targetLayout.getOuterAxis();
+ for (size_t i = 0; i < dimensions.size(); ++i) {
+ result[i] = outerPerm[dimensions[i]];
+ }
+ // inserting inner axis
auto innerPos = targetLayout.getInnerAxis();
for (size_t i = 0; i < dimensions.size(); ++i) {
if (std::find(innerPos.begin(), innerPos.end(), dimensions[i]) !=
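
The outer-axis loop added above is the core of the broadcast-axis fix. The following self-contained snippet illustrates the remapping it performs; the permutation and dimension values are hypothetical, and the inner-axis handling that follows in the real function is omitted.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical packed layout whose outer axes are permuted as below.
  std::vector<int64_t> outerPerm = {0, 2, 1, 3};
  // Broadcast dimensions expressed in the plain (unpacked) index space.
  std::vector<int64_t> dimensions = {1, 3};
  std::vector<int64_t> result(dimensions);
  for (size_t i = 0; i < dimensions.size(); ++i)
    result[i] = outerPerm[dimensions[i]]; // same lookup as in getPackedAxes
  for (int64_t d : result)
    std::cout << d << " "; // prints "2 3"
  std::cout << "\n";
  return 0;
}
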
@@ -153,8 +159,10 @@ FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
loc, inits.getTypes(), inputs, inits, packedAxes);
packedLinalgOp->getRegion(0).takeBody(linalgOp->getRegion(0));
} else if (auto broadcastOp = dyn_cast<linalg::BroadcastOp>(&linalgOp)) {
- packedLinalgOp = rewriter.create<linalg::BroadcastOp>(
- loc, inputs[0], inits[0], broadcastOp->getDimensions());
+ SmallVector<int64_t> packedAxes =
+ getPackedAxes(broadcastOp->getDimensions(), initLayouts[0]);
+ packedLinalgOp = rewriter.create<linalg::BroadcastOp>(loc, inputs[0],
+ inits[0], packedAxes);
} else if (auto transposeOp = dyn_cast<linalg::TransposeOp>(&linalgOp)) {
SmallVector<int64_t> packedPermAxes = getPackedPermAxes(
transposeOp->getPermutation(), inputLayouts[0], initLayouts[0]);
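
A note on the broadcast change above: linalg.broadcast's dimensions attribute refers to dimensions of the init/result tensor, so once the init is packed those indices have to be rewritten into the packed (tiled) domain as well. Routing them through getPackedAxes with initLayouts[0], like the preceding branch that already passes packedAxes, is what the "fix broadcast axis" part of this commit addresses.
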
@@ -237,39 +245,41 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
return WalkResult::skip();
}
} else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
- Location loc = expandShapeOp->getLoc();
- auto inputLayout = opLayout->getSupportedInputLayouts()[0];
- auto outputLayout = opLayout->getSupportedOutputLayouts()[0];
- Value dest = tensor::PackOp::createDestinationTensor(
- rewriter, loc, expandShapeOp.getSrc(), inputLayout.getTileSizes(),
- inputLayout.getInnerAxis(), inputLayout.getOuterAxis());
- Value packedSource = rewriter.create<tensor::PackOp>(
- loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(),
- inputLayout.getTileSizes(), std::nullopt,
- inputLayout.getOuterAxis());
- auto resultType = RankedTensorType::get(
- expandShapeOp.getStaticOutputShape(),
- expandShapeOp.getSrcType().getElementType());
- RankedTensorType resultPackType = tensor::PackOp::inferPackedType(
- resultType, vector::getAsIntegers(outputLayout.getTileSizes()),
- outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
- auto reassocExpand = getReassociationIndicesForReshape(
- cast<ShapedType>(dest.getType()), resultPackType);
- auto packedExpandShape = rewriter.create<tensor::ExpandShapeOp>(
- loc, expandShapeOp.getSrcType().getElementType(), packedSource,
- *reassocExpand);
- Value result = rewriter.create<tensor::UnPackOp>(
- packedExpandShape->getLoc(), packedExpandShape, packedExpandShape,
- outputLayout.getInnerAxis(), outputLayout.getTileSizes(),
- outputLayout.getOuterAxis());
- rewriter.replaceOp(expandShapeOp, result);
+ // Location loc = expandShapeOp->getLoc();
+ // auto inputLayout = opLayout->getSupportedInputLayouts()[0];
+ // auto outputLayout = opLayout->getSupportedOutputLayouts()[0];
+ // Value dest = tensor::PackOp::createDestinationTensor(
+ // rewriter, loc, expandShapeOp.getSrc(),
+ // inputLayout.getTileSizes(), inputLayout.getInnerAxis(),
+ // inputLayout.getOuterAxis());
+ // Value packedSource = rewriter.create<tensor::PackOp>(
+ // loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(),
+ // inputLayout.getTileSizes(), std::nullopt,
+ // inputLayout.getOuterAxis());
+ // auto resultType = RankedTensorType::get(
+ // expandShapeOp.getStaticOutputShape(),
+ // expandShapeOp.getSrcType().getElementType());
+ // RankedTensorType resultPackType = tensor::PackOp::inferPackedType(
+ // resultType, vector::getAsIntegers(outputLayout.getTileSizes()),
+ // outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
+ // auto reassocExpand = getReassociationIndicesForReshape(
+ // cast<ShapedType>(dest.getType()), resultPackType);
+ // auto packedExpandShape = rewriter.create<tensor::ExpandShapeOp>(
+ // loc, expandShapeOp.getSrcType().getElementType(), packedSource,
+ // *reassocExpand);
+ // Value result = rewriter.create<tensor::UnPackOp>(
+ // packedExpandShape->getLoc(), packedExpandShape,
+ // packedExpandShape, outputLayout.getInnerAxis(),
+ // outputLayout.getTileSizes(), outputLayout.getOuterAxis());
+ // rewriter.replaceOp(expandShapeOp, result);
}
}
}
return WalkResult::advance();
});
if (walk.wasSkipped())
return failure();
graph->dump();
return success();
}
