Skip to content

Commit aa1243a

Browse files
committed
add pack vnni
1 parent 151f3c2 commit aa1243a

File tree

3 files changed

+66
-5
lines changed

3 files changed

+66
-5
lines changed

include/gc/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ namespace linalgx {
3232
class LinalgxDialect;
3333
}
3434

35+
namespace linalgx {
36+
class LinalgxDialect;
37+
}
38+
3539
namespace MemRef {
3640
class MemRefDialect;
3741
}

include/gc/Transforms/Passes.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,11 @@ def PropagateLayoutOnNamedOps : Pass<"propagate-layout-on-named-ops"> {
5151
let description = [{
5252
Insert and propagate tensor.pack
5353
}];
54-
let dependentDialects = ["mlir::tensor::TensorDialect",
55-
"mlir::linalg::LinalgDialect"];
54+
let dependentDialects = [
55+
"mlir::tensor::TensorDialect",
56+
"mlir::linalg::LinalgDialect",
57+
"mlir::linalgx::LinalgxDialect"
58+
];
5659
}
5760

5861
#endif // GC_DIALECT_GC_PASSES

lib/gc/Transforms/PropagateLayout.cpp

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include "mlir/Transforms/DialectConversion.h"
2222
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2323

24+
#include "gc/Dialect/Linalgx/LinalgxDialect.h"
25+
#include "gc/Dialect/Linalgx/LinalgxOps.h"
2426
#include "gc/Transforms/Passes.h"
2527
namespace mlir {
2628
namespace gc {
@@ -211,8 +213,10 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
211213
IRRewriter rewriter(ctx);
212214
auto walk = graph->walk([&](Operation *op) {
213215
FailureOr<OperatorLayout> opLayout = controlFn(op);
214-
if ((isa<linalg::LinalgOp>(op) && !mlir::linalg::isaContractionOpInterface(
215-
dyn_cast<linalg::LinalgOp>(op))) ||
216+
if ((isa<linalg::LinalgOp>(op) &&
217+
!mlir::linalg::isaContractionOpInterface(
218+
dyn_cast<linalg::LinalgOp>(op)) &&
219+
!isa<linalgx::Mm4DVnniOp>(op)) ||
216220
isa<tensor::ExpandShapeOp>(op) || isa<tensor::PadOp>(op)) {
217221
if (failed(opLayout)) {
218222
LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName()
@@ -269,6 +273,50 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
269273
return success();
270274
}
271275

276+
/// Rewrite a bf16 `linalg.mmt4d` into a `linalgx.mm4d_vnni` op by re-packing
/// the RHS weight operand into VNNI layout.
///
/// VNNI execution on bf16 consumes pairs of reduction (k) elements, so the
/// RHS is packed on its innermost dim (index 3) with an inner tile of 2 and
/// the two innermost outer dims are transposed: NKnk --> NKkn2k.
///
/// \param rewriter rewriter used to create the pack and the new matmul op.
/// \param mmt4dOp  the matmul to rewrite; replaced on success.
/// \return the newly created Mm4DVnniOp, or failure when the element type is
///         not bf16 (the only type this VNNI packing currently supports).
static FailureOr<mlir::linalgx::Mm4DVnniOp>
packVNNIMatmul(RewriterBase &rewriter, linalg::Mmt4DOp mmt4dOp) {
  auto elementType = getElementTypeOrSelf(mmt4dOp.getInputs()[0].getType());
  if (!elementType.isBF16())
    return rewriter.notifyMatchFailure(mmt4dOp, "require bf16 type");
  Location loc = mmt4dOp.getLoc();
  // NKnk --> NKkn2k: pack dim 3 (k) by a tile of 2 and swap outer dims 2/3.
  SmallVector<int64_t> innerPos{3};
  SmallVector<OpFoldResult> tileSize{rewriter.getIndexAttr(2)};
  SmallVector<int64_t> outerPerm{0, 1, 3, 2};
  OpOperand *RHSOperand = mmt4dOp.getDpsInputOperand(1);
  Value dest = tensor::PackOp::createDestinationTensor(
      rewriter, loc, RHSOperand->get(), tileSize, innerPos, outerPerm);
  Value VNNIPack =
      rewriter.create<tensor::PackOp>(loc, RHSOperand->get(), dest, innerPos,
                                      tileSize, std::nullopt, outerPerm);
  // Keep the original inputs, substituting the VNNI-packed RHS.
  // (The previous revision also materialized a vector of init operands that
  // was never used; that dead code is removed here.)
  SmallVector<Value> inputsValues = llvm::to_vector(mmt4dOp.getInputs());
  inputsValues[1] = VNNIPack;
  auto vnniOp = rewriter.create<mlir::linalgx::Mm4DVnniOp>(
      loc, mmt4dOp.getDpsInits().getTypes(), inputsValues,
      mmt4dOp.getDpsInits());
  rewriter.replaceOp(mmt4dOp, vnniOp);
  return vnniOp;
}
306+
307+
/// Rewrite pattern that lowers `linalg.mmt4d` to `linalgx.mm4d_vnni`,
/// delegating the actual packing and replacement to packVNNIMatmul.
struct VNNIOnMatmul : public OpRewritePattern<linalg::Mmt4DOp> {
  VNNIOnMatmul(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<linalg::Mmt4DOp>(context, benefit) {}
  LogicalResult matchAndRewrite(linalg::Mmt4DOp matmulOp,
                                PatternRewriter &rewriter) const override {
    // FailureOr<T> is-a LogicalResult, so the helper's outcome can be
    // propagated directly without unpacking it.
    return packVNNIMatmul(rewriter, matmulOp);
  }
};
319+
272320
void PropagateLayoutOnNamedOps::runOnOperation() {
273321
MLIRContext *ctx = &getContext();
274322
mlir::Operation *graph = getOperation();
@@ -296,7 +344,13 @@ void PropagateLayoutOnNamedOps::runOnOperation() {
296344
if (failed(applyPatternsAndFoldGreedily(graph, std::move(patterns))))
297345
return signalPassFailure();
298346

299-
// stage3: propagate layout on other namsed ops
347+
// stage2: pack VNNI
348+
RewritePatternSet VNNIPatterns(&getContext());
349+
VNNIPatterns.add<VNNIOnMatmul>(ctx);
350+
if (failed(applyPatternsAndFoldGreedily(graph, std::move(VNNIPatterns))))
351+
return signalPassFailure();
352+
353+
// stage3: propagate layout on other named ops
300354
ControlPackNamedOpsFn layoutControlFn =
301355
[&](Operation *op) -> FailureOr<OperatorLayout> {
302356
auto &layoutAnalysisResult = getAnalysis<GlobalAnalysis>();

0 commit comments

Comments
 (0)