@@ -272,12 +272,37 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
 void PropagateLayoutOnNamedOps::runOnOperation() {
   MLIRContext *ctx = &getContext();
   mlir::Operation *graph = getOperation();
-  ControlPackNamedOpsFn controlFn =
+  // stage1: pack matmul ops into blocked layout
+  RewritePatternSet patterns(&getContext());
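+  // The control callback below queries the layout analysis for each matmul
+  // and converts the chosen layouts into block-packing options.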
+  mlir::linalg::ControlBlockPackMatmulFn packMatmulControlFn =
+      [&](linalg::LinalgOp op) -> mlir::linalg::BlockPackMatmulOptions {
+    mlir::linalg::BlockPackMatmulOptions options;
+    auto &layoutAnalysisResult = getAnalysis<GlobalAnalysis>();
+    auto matmulLayout = *(layoutAnalysisResult.getOpLayout(op));
+    TensorLayout LHSLayout = matmulLayout.getSupportedInputLayouts()[0];
+    TensorLayout RHSLayout = matmulLayout.getSupportedInputLayouts()[1];
+    // hardcode to mmt4d format
+    options.rhsTransposeOuterBlocks = true;
+    options.rhsTransposeInnerBlocks = true;
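+    // The three block factors are taken from the tile sizes the analysis
+    // assigned to the LHS and RHS operands.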
+    options.blockFactors.push_back(
+        *getConstantIntValue(LHSLayout.getTileSizes()[0]));
+    options.blockFactors.push_back(
+        *getConstantIntValue(LHSLayout.getTileSizes()[1]));
+    options.blockFactors.push_back(
+        *getConstantIntValue(RHSLayout.getTileSizes()[1]));
+    return options;
+  };
+  linalg::populateBlockPackMatmulPatterns(patterns, packMatmulControlFn);
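+  // Greedily apply the packing patterns now, so matmuls are already in
+  // blocked form before the layout-propagation stage below runs.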
+  if (failed(applyPatternsAndFoldGreedily(graph, std::move(patterns))))
+    return signalPassFailure();
+
+  // stage2: propagate layout on other named ops
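+  // The propagation control function simply forwards the layout that the
+  // global analysis computed for each candidate op.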
+  ControlPackNamedOpsFn layoutControlFn =
       [&](Operation *op) -> FailureOr<OperatorLayout> {
     auto &layoutAnalysisResult = getAnalysis<GlobalAnalysis>();
     return layoutAnalysisResult.getOpLayout(op);
   };
-  if (failed(namedOpLayoutPropagation(ctx, graph, controlFn)))
+  if (failed(namedOpLayoutPropagation(ctx, graph, layoutControlFn)))
     return signalPassFailure();
 }
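
For context, a minimal sketch of how this pass might be driven from a pipeline. The factory function name `createPropagateLayoutOnNamedOpsPass()` is an assumption for illustration and is not defined in this patch:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Run layout propagation over a module (sketch; the pass factory name below
// is hypothetical, not taken from this patch).
void runLayoutPropagation(mlir::ModuleOp module, mlir::MLIRContext &ctx) {
  mlir::PassManager pm(&ctx);
  pm.addPass(createPropagateLayoutOnNamedOpsPass()); // hypothetical factory
  if (mlir::failed(pm.run(module)))
    module.emitError("layout propagation failed");
}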