Skip to content

Commit 151f3c2

Browse files
committed
extend LinalgBlockPackMatmul pass to support mmt4d op
1 parent 5c60edd commit 151f3c2

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

lib/gc/Transforms/PropagateLayout.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,37 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
272272
void PropagateLayoutOnNamedOps::runOnOperation() {
273273
MLIRContext *ctx = &getContext();
274274
mlir::Operation *graph = getOperation();
275-
ControlPackNamedOpsFn controlFn =
275+
// stage1:
276+
RewritePatternSet patterns(&getContext());
277+
mlir::linalg::ControlBlockPackMatmulFn packMatmulControlFn =
278+
[&](linalg::LinalgOp op) -> mlir::linalg::BlockPackMatmulOptions {
279+
mlir::linalg::BlockPackMatmulOptions options;
280+
auto &layoutAnalysisResult = getAnalysis<GlobalAnalysis>();
281+
auto matmulLayout = *(layoutAnalysisResult.getOpLayout(op));
282+
TensorLayout LHSLayout = matmulLayout.getSupportedInputLayouts()[0];
283+
TensorLayout RHSLayout = matmulLayout.getSupportedInputLayouts()[1];
284+
// hardcode to mmt4d format
285+
options.rhsTransposeOuterBlocks = true;
286+
options.rhsTransposeInnerBlocks = true;
287+
options.blockFactors.push_back(
288+
*getConstantIntValue(LHSLayout.getTileSizes()[0]));
289+
options.blockFactors.push_back(
290+
*getConstantIntValue(LHSLayout.getTileSizes()[1]));
291+
options.blockFactors.push_back(
292+
*getConstantIntValue(RHSLayout.getTileSizes()[1]));
293+
return options;
294+
};
295+
linalg::populateBlockPackMatmulPatterns(patterns, packMatmulControlFn);
296+
if (failed(applyPatternsAndFoldGreedily(graph, std::move(patterns))))
297+
return signalPassFailure();
298+
299+
// stage3: propagate layout on other namsed ops
300+
ControlPackNamedOpsFn layoutControlFn =
276301
[&](Operation *op) -> FailureOr<OperatorLayout> {
277302
auto &layoutAnalysisResult = getAnalysis<GlobalAnalysis>();
278303
return layoutAnalysisResult.getOpLayout(op);
279304
};
280-
if (failed(namedOpLayoutPropagation(ctx, graph, controlFn)))
305+
if (failed(namedOpLayoutPropagation(ctx, graph, layoutControlFn)))
281306
return signalPassFailure();
282307
}
283308

packMatmul.patch

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
2+
index 91d4efa3372b..f3f61ff92140 100644
3+
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
4+
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
5+
@@ -210,6 +210,19 @@ linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
6+
packedMatmul->packOps[1] = packedRhs->transposedPackOp;
7+
packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
8+
9+
+ // rewrite generic to mmt4d
10+
+ if (!options->lhsTransposeOuterBlocks && !options->lhsTransposeInnerBlocks &&
11+
+ options->rhsTransposeOuterBlocks && options->rhsTransposeInnerBlocks &&
12+
+ options->mnkOrder == SmallVector<int64_t>{0, 1, 2}) {
13+
+ auto originalLinalgOp = packedMatmul->packedLinalgOp;
14+
+ rewriter.setInsertionPoint(originalLinalgOp);
15+
+ auto mmt4d = rewriter.create<linalg::Mmt4DOp>(
16+
+ originalLinalgOp.getLoc(), originalLinalgOp.getDpsInits().getTypes(),
17+
+ originalLinalgOp.getDpsInputs(), originalLinalgOp.getDpsInits());
18+
+ rewriter.replaceOp(originalLinalgOp, mmt4d);
19+
+ packedMatmul->packedLinalgOp = mmt4d;
20+
+ }
21+
+
22+
return packedMatmul;
23+
}
24+
25+
@@ -307,6 +320,7 @@ struct LinalgBlockPackMatmul
26+
};
27+
} // namespace
28+
29+
+// extend to transform to mmt4d or batch_mmt4d
30+
void linalg::populateBlockPackMatmulPatterns(
31+
RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
32+
patterns.add<BlockPackMatmul<linalg::GenericOp>,

0 commit comments

Comments
 (0)