@@ -21,6 +21,8 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include "gc/Dialect/Linalgx/LinalgxDialect.h"
+#include "gc/Dialect/Linalgx/LinalgxOps.h"
 #include "gc/Transforms/Passes.h"
 namespace mlir {
 namespace gc {
@@ -211,8 +213,10 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
   IRRewriter rewriter(ctx);
   auto walk = graph->walk([&](Operation *op) {
     FailureOr<OperatorLayout> opLayout = controlFn(op);
-    if ((isa<linalg::LinalgOp>(op) && !mlir::linalg::isaContractionOpInterface(
-                                          dyn_cast<linalg::LinalgOp>(op))) ||
+    if ((isa<linalg::LinalgOp>(op) &&
+         !mlir::linalg::isaContractionOpInterface(
+             dyn_cast<linalg::LinalgOp>(op)) &&
+         !isa<linalgx::Mm4DVnniOp>(op)) ||
         isa<tensor::ExpandShapeOp>(op) || isa<tensor::PadOp>(op)) {
       if (failed(opLayout)) {
         LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName()
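The updated filter above is dense; restated as a named predicate, it selects plain linalg ops that are neither contractions nor the new linalgx VNNI matmul, plus tensor.expand_shape and tensor.pad. This is a sketch only, assuming the file's existing includes; the helper name is invented and does not appear in the patch:

    // Sketch: an equivalent form of the walk filter above. The helper
    // name is hypothetical; it does not exist in the patch.
    static bool takesGenericLayoutPath(Operation *op) {
      if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op))
        return !mlir::linalg::isaContractionOpInterface(linalgOp) &&
               !isa<linalgx::Mm4DVnniOp>(op);
      return isa<tensor::ExpandShapeOp, tensor::PadOp>(op);
    }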
@@ -269,6 +273,50 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
   return success();
 }
 
+static FailureOr<mlir::linalgx::Mm4DVnniOp>
+packVNNIMatmul(RewriterBase &rewriter, linalg::Mmt4DOp mmt4dOp) {
+  auto elementType = getElementTypeOrSelf(mmt4dOp.getInputs()[0].getType());
+  if (!elementType.isBF16())
+    return rewriter.notifyMatchFailure(mmt4dOp, "require bf16 type");
+  Location loc = mmt4dOp.getLoc();
+  // NKnk --> NKkn2k
+  SmallVector<int64_t> innerPos{3};
+  SmallVector<OpFoldResult> tileSize{rewriter.getIndexAttr(2)};
+  SmallVector<int64_t> outerPerm{0, 1, 3, 2};
+  OpOperand *RHSOperand = mmt4dOp.getDpsInputOperand(1);
+  Value dest = tensor::PackOp::createDestinationTensor(
+      rewriter, loc, RHSOperand->get(), tileSize, innerPos, outerPerm);
+  Value VNNIPack =
+      rewriter.create<tensor::PackOp>(loc, RHSOperand->get(), dest, innerPos,
+                                      tileSize, std::nullopt, outerPerm);
+  SmallVector<Value> inputsValues;
+  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
+      mmt4dOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
+  SmallVector<OpOperand *> inputOperands = mmt4dOp.getDpsInputOperands();
+  for (OpOperand *opOperand : inputOperands) {
+    inputsValues.push_back(opOperand->get());
+  }
+  inputsValues[1] = VNNIPack;
+  auto vnniOp = rewriter.create<mlir::linalgx::Mm4DVnniOp>(
+      loc, mmt4dOp.getDpsInits().getTypes(), inputsValues,
+      mmt4dOp.getDpsInits());
+  rewriter.replaceOp(mmt4dOp, vnniOp);
+  return vnniOp;
+}
+
+struct VNNIOnMatmul : public OpRewritePattern<linalg::Mmt4DOp> {
+  VNNIOnMatmul(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<linalg::Mmt4DOp>(context, benefit) {}
+  LogicalResult matchAndRewrite(linalg::Mmt4DOp matmulOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<mlir::linalgx::Mm4DVnniOp> packedMatmul =
+        packVNNIMatmul(rewriter, matmulOp);
+    if (failed(packedMatmul))
+      return failure();
+    return success();
+  }
+};
+
 void PropagateLayoutOnNamedOps::runOnOperation() {
   MLIRContext *ctx = &getContext();
   mlir::Operation *graph = getOperation();
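The pack parameters in packVNNIMatmul encode the NKnk --> NKkn2k relabeling from the comment: innerPos = {3} tiles the innermost k dimension by 2 (one bf16 pair per VNNI lane), and outerPerm = {0, 1, 3, 2} swaps the two inner blocking dimensions. Below is a minimal, self-contained sketch of the resulting shape arithmetic, using made-up example sizes:

    #include <array>
    #include <cstdint>
    #include <iostream>

    int main() {
      // Hypothetical mmt4d RHS shape N x K x n x k (example sizes only).
      std::array<int64_t, 4> nknk = {4, 8, 32, 32};
      constexpr int64_t vnniTile = 2; // two bf16 elements per VNNI pair

      // tensor.pack with innerDimsPos = {3} and innerTiles = {2} splits k
      // into k/2 (outer) and 2 (inner); outerDimsPerm = {0, 1, 3, 2} then
      // reorders the outer dims to N, K, k/2, n.
      std::array<int64_t, 5> packed = {nknk[0], nknk[1], nknk[3] / vnniTile,
                                       nknk[2], vnniTile};

      std::cout << "NKnk:   " << nknk[0] << 'x' << nknk[1] << 'x' << nknk[2]
                << 'x' << nknk[3] << '\n';
      std::cout << "NKkn2k: " << packed[0] << 'x' << packed[1] << 'x'
                << packed[2] << 'x' << packed[3] << 'x' << packed[4] << '\n';
      return 0;
    }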
@@ -296,7 +344,13 @@ void PropagateLayoutOnNamedOps::runOnOperation() {
   if (failed(applyPatternsAndFoldGreedily(graph, std::move(patterns))))
     return signalPassFailure();
 
-  // stage3: propagate layout on other namsed ops
+  // stage2: pack VNNI
+  RewritePatternSet VNNIPatterns(&getContext());
+  VNNIPatterns.add<VNNIOnMatmul>(ctx);
+  if (failed(applyPatternsAndFoldGreedily(graph, std::move(VNNIPatterns))))
+    return signalPassFailure();
+
+  // stage3: propagate layout on other named ops
   ControlPackNamedOpsFn layoutControlFn =
       [&](Operation *op) -> FailureOr<OperatorLayout> {
     auto &layoutAnalysisResult = getAnalysis<GlobalAnalysis>();
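The new greedy rewrite runs as its own stage2, after the stage1 pattern set and before the stage3 layout propagation. Because VNNIOnMatmul is an ordinary OpRewritePattern, it could also be driven on its own; a hedged sketch, assuming the same headers as this file (the wrapper function is invented for illustration):

    // Hypothetical standalone driver for the new pattern; not part of the
    // patch. Mirrors what stage2 of runOnOperation() does above.
    static LogicalResult runVNNIPackOnly(Operation *root) {
      RewritePatternSet patterns(root->getContext());
      patterns.add<VNNIOnMatmul>(root->getContext());
      return applyPatternsAndFoldGreedily(root, std::move(patterns));
    }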