
Commit 0087328

fix broadcast axis and binary target input
1 parent 10d90d6

2 files changed: +66, -38

lib/gc/Analysis/GlobalAnalysis.cpp (+28, -10)
@@ -154,6 +154,15 @@ inferTargetLayout(TensorLayout layoutBase,
   return TensorLayout(targetOuterAxis, targetInnerAxis, targetTileSizes);
 }
 
+static size_t getTargetInputIdx(ArrayRef<TensorLayout> curInputLayouts) {
+  for (size_t i = 0; i < curInputLayouts.size(); ++i) {
+    if (!curInputLayouts[i].isPlainLayout()) {
+      return i;
+    }
+  }
+  return 0;
+}
+
 GlobalAnalysis::GlobalAnalysis(Operation *root) {
   root->walk([&](Operation *op) {
     if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
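
The helper above chooses the first input that already carries a non-plain (packed) layout as the basis for layout inference, falling back to input 0 when every input is plain. A minimal standalone sketch of that selection rule; ToyLayout and the driver are hypothetical stand-ins, not the MLIR types:

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for TensorLayout, reduced to the one query used here.
struct ToyLayout {
  bool plain;
  bool isPlainLayout() const { return plain; }
};

// Same selection rule as the new helper: first packed input, else input 0.
static size_t getTargetInputIdx(const std::vector<ToyLayout> &layouts) {
  for (size_t i = 0; i < layouts.size(); ++i)
    if (!layouts[i].isPlainLayout())
      return i;
  return 0;
}

int main() {
  // Binary op whose second operand is the packed one: index 1 is chosen.
  std::vector<ToyLayout> layouts = {{true}, {false}};
  std::cout << getTargetInputIdx(layouts) << "\n"; // prints 1
}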
@@ -201,23 +210,32 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
         layoutCache[linalgOp] = suggestedLayout;
       } else {
         SmallVector<TensorLayout> inputLayouts, outputLayouts;
-        inputLayouts.push_back(curInputLayouts[0]);
+        size_t targetIdx = getTargetInputIdx(curInputLayouts);
         // TODO(yifei): wisely choose the input format basis
         // Let's only refer to input[0] for now
-        for (size_t i = 1; i < curInputs.size(); ++i) {
+        for (size_t i = 0; i < curInputs.size(); ++i) {
+          std::cout << "inferring indexing map relation" << std::endl;
           // getMatchingIndexingMap
-          auto res = inferIndexingMapRelation(
-              linalgOp.getMatchingIndexingMap(curInputs[0]),
-              linalgOp.getMatchingIndexingMap(curInputs[i]));
-          TensorLayout inputLayout =
-              *inferTargetLayout(curInputLayouts[0], *res);
-          inputLayouts.push_back(inputLayout);
+          if (i != targetIdx) {
+            auto res = inferIndexingMapRelation(
+                linalgOp.getMatchingIndexingMap(curInputs[targetIdx]),
+                linalgOp.getMatchingIndexingMap(curInputs[i]));
+            for (auto tp : *res) {
+              std::cout << "target index: " << tp.first
+                        << " maps to base index: " << tp.second << std::endl;
+            }
+            TensorLayout inputLayout =
+                *inferTargetLayout(curInputLayouts[targetIdx], *res);
+            inputLayouts.push_back(inputLayout);
+          } else {
+            inputLayouts.push_back(curInputLayouts[targetIdx]);
+          }
         }
         auto res_out = inferIndexingMapRelation(
-            linalgOp.getMatchingIndexingMap(curInputs[0]),
+            linalgOp.getMatchingIndexingMap(curInputs[targetIdx]),
             linalgOp.getIndexingMapMatchingResult(curResults[0]));
         TensorLayout outputLayout =
-            *inferTargetLayout(curInputLayouts[0], *res_out);
+            *inferTargetLayout(curInputLayouts[targetIdx], *res_out);
         outputLayouts.push_back(outputLayout);
         OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
         layoutCache[linalgOp] = suggestedLayout;
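
With getTargetInputIdx in place, the loop above now visits every input: the target input keeps its current layout, and each remaining input has a layout inferred from the target through the indexing-map relation. A self-contained sketch of that control flow, with strings standing in for TensorLayout and a trivial stand-in for inferIndexingMapRelation/inferTargetLayout (all names illustrative):

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

using Layout = std::string; // stand-in for TensorLayout

// Stand-in for inferTargetLayout(inferIndexingMapRelation(...)): derive a
// layout for input `i` from the target input's layout.
static Layout deriveFromTarget(const Layout &targetLayout, size_t i) {
  return targetLayout + "@in" + std::to_string(i);
}

int main() {
  std::vector<Layout> curInputLayouts = {"plain", "packedNCnc"};
  size_t targetIdx = 1; // as getTargetInputIdx would return
  std::vector<Layout> inputLayouts;
  for (size_t i = 0; i < curInputLayouts.size(); ++i) {
    if (i != targetIdx)
      inputLayouts.push_back(deriveFromTarget(curInputLayouts[targetIdx], i));
    else // the target input keeps its own layout untouched
      inputLayouts.push_back(curInputLayouts[targetIdx]);
  }
  for (const Layout &l : inputLayouts)
    std::cout << l << "\n"; // packedNCnc@in0, then packedNCnc
}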

lib/gc/Transforms/PropagateLayout.cpp (+38, -28)
@@ -38,6 +38,12 @@ using namespace mlir::tensor;
 static SmallVector<int64_t> getPackedAxes(ArrayRef<int64_t> dimensions,
                                           TensorLayout targetLayout) {
   SmallVector<int64_t> result(dimensions);
+  // permuting on outer axis
+  auto outerPerm = targetLayout.getOuterAxis();
+  for (size_t i = 0; i < dimensions.size(); ++i) {
+    result[i] = outerPerm[dimensions[i]];
+  }
+  // inserting inner axis
   auto innerPos = targetLayout.getInnerAxis();
   for (size_t i = 0; i < dimensions.size(); ++i) {
     if (std::find(innerPos.begin(), innerPos.end(), dimensions[i]) !=
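
The new first loop remaps each dimension through the layout's outer-axis permutation before the existing inner-axis pass runs. A standalone model of the two-step computation using plain vectors; the inner-axis step is written out under assumed semantics, since the diff truncates that loop body:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Plain-vector model of getPackedAxes: remap through the outer permutation,
// then append an inner axis for every dimension that is also tiled.
static std::vector<int64_t>
getPackedAxesSketch(const std::vector<int64_t> &dimensions,
                    const std::vector<int64_t> &outerPerm,
                    const std::vector<int64_t> &innerPos) {
  std::vector<int64_t> result(dimensions);
  // Step added by this commit: permute on the outer axes.
  for (size_t i = 0; i < dimensions.size(); ++i)
    result[i] = outerPerm[dimensions[i]];
  // Inner-axis step (semantics assumed): a tiled dim also owns the
  // corresponding inner axis positioned past the outer rank.
  for (size_t i = 0; i < dimensions.size(); ++i) {
    auto it = std::find(innerPos.begin(), innerPos.end(), dimensions[i]);
    if (it != innerPos.end())
      result.push_back(static_cast<int64_t>(outerPerm.size()) +
                       (it - innerPos.begin()));
  }
  return result;
}

int main() {
  // 2D layout with outer order {1, 0} and dim 0 tiled: axis 0 becomes the
  // packed outer axis 1 plus the inner axis 2.
  for (int64_t a : getPackedAxesSketch({0}, {1, 0}, {0}))
    std::cout << a << " "; // prints "1 2"
  std::cout << "\n";
}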
@@ -153,8 +159,10 @@ FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
         loc, inits.getTypes(), inputs, inits, packedAxes);
     packedLinalgOp->getRegion(0).takeBody(linalgOp->getRegion(0));
   } else if (auto broadcastOp = dyn_cast<linalg::BroadcastOp>(&linalgOp)) {
-    packedLinalgOp = rewriter.create<linalg::BroadcastOp>(
-        loc, inputs[0], inits[0], broadcastOp->getDimensions());
+    SmallVector<int64_t> packedAxes =
+        getPackedAxes(broadcastOp->getDimensions(), initLayouts[0]);
+    packedLinalgOp = rewriter.create<linalg::BroadcastOp>(loc, inputs[0],
+                                                          inits[0], packedAxes);
   } else if (auto transposeOp = dyn_cast<linalg::TransposeOp>(&linalgOp)) {
     SmallVector<int64_t> packedPermAxes = getPackedPermAxes(
         transposeOp->getPermutation(), inputLayouts[0], initLayouts[0]);
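
This is the "broadcast axis" half of the commit: a linalg::BroadcastOp built over packed operands needs its dimensions in the packed iteration space, so the plain-domain broadcastOp->getDimensions() is now routed through getPackedAxes against the init's layout. A small worked example under the same assumed semantics as the sketch above; the layout values are illustrative:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Assumed init layout: outer axes kept in order ({0, 1}), dim 1 tiled.
  std::vector<int64_t> outerPerm = {0, 1}, innerPos = {1};
  std::vector<int64_t> dims = {1}; // plain-domain broadcast dimensions
  std::vector<int64_t> packed;
  for (int64_t d : dims) // outer remap (identity for this layout)
    packed.push_back(outerPerm[d]);
  for (int64_t d : dims) { // tiled dims broadcast on their inner axis too
    auto it = std::find(innerPos.begin(), innerPos.end(), d);
    if (it != innerPos.end())
      packed.push_back(static_cast<int64_t>(outerPerm.size()) +
                       (it - innerPos.begin()));
  }
  // The packed linalg::BroadcastOp would receive {1, 2} instead of {1}.
  for (int64_t d : packed)
    std::cout << d << " "; // prints "1 2"
  std::cout << "\n";
}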
@@ -237,39 +245,41 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
             return WalkResult::skip();
           }
         } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
-          Location loc = expandShapeOp->getLoc();
-          auto inputLayout = opLayout->getSupportedInputLayouts()[0];
-          auto outputLayout = opLayout->getSupportedOutputLayouts()[0];
-          Value dest = tensor::PackOp::createDestinationTensor(
-              rewriter, loc, expandShapeOp.getSrc(), inputLayout.getTileSizes(),
-              inputLayout.getInnerAxis(), inputLayout.getOuterAxis());
-          Value packedSource = rewriter.create<tensor::PackOp>(
-              loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(),
-              inputLayout.getTileSizes(), std::nullopt,
-              inputLayout.getOuterAxis());
-          auto resultType = RankedTensorType::get(
-              expandShapeOp.getStaticOutputShape(),
-              expandShapeOp.getSrcType().getElementType());
-          RankedTensorType resultPackType = tensor::PackOp::inferPackedType(
-              resultType, vector::getAsIntegers(outputLayout.getTileSizes()),
-              outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
-          auto reassocExpand = getReassociationIndicesForReshape(
-              cast<ShapedType>(dest.getType()), resultPackType);
-          auto packedExpandShape = rewriter.create<tensor::ExpandShapeOp>(
-              loc, expandShapeOp.getSrcType().getElementType(), packedSource,
-              *reassocExpand);
-          Value result = rewriter.create<tensor::UnPackOp>(
-              packedExpandShape->getLoc(), packedExpandShape, packedExpandShape,
-              outputLayout.getInnerAxis(), outputLayout.getTileSizes(),
-              outputLayout.getOuterAxis());
-          rewriter.replaceOp(expandShapeOp, result);
+          // Location loc = expandShapeOp->getLoc();
+          // auto inputLayout = opLayout->getSupportedInputLayouts()[0];
+          // auto outputLayout = opLayout->getSupportedOutputLayouts()[0];
+          // Value dest = tensor::PackOp::createDestinationTensor(
+          //     rewriter, loc, expandShapeOp.getSrc(),
+          //     inputLayout.getTileSizes(), inputLayout.getInnerAxis(),
+          //     inputLayout.getOuterAxis());
+          // Value packedSource = rewriter.create<tensor::PackOp>(
+          //     loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(),
+          //     inputLayout.getTileSizes(), std::nullopt,
+          //     inputLayout.getOuterAxis());
+          // auto resultType = RankedTensorType::get(
+          //     expandShapeOp.getStaticOutputShape(),
+          //     expandShapeOp.getSrcType().getElementType());
+          // RankedTensorType resultPackType = tensor::PackOp::inferPackedType(
+          //     resultType, vector::getAsIntegers(outputLayout.getTileSizes()),
+          //     outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
+          // auto reassocExpand = getReassociationIndicesForReshape(
+          //     cast<ShapedType>(dest.getType()), resultPackType);
+          // auto packedExpandShape = rewriter.create<tensor::ExpandShapeOp>(
+          //     loc, expandShapeOp.getSrcType().getElementType(), packedSource,
+          //     *reassocExpand);
+          // Value result = rewriter.create<tensor::UnPackOp>(
+          //     packedExpandShape->getLoc(), packedExpandShape,
+          //     packedExpandShape, outputLayout.getInnerAxis(),
+          //     outputLayout.getTileSizes(), outputLayout.getOuterAxis());
+          // rewriter.replaceOp(expandShapeOp, result);
         }
       }
     }
     return WalkResult::advance();
   });
   if (walk.wasSkipped())
     return failure();
+  graph->dump();
   return success();
 }
 
