Skip to content

Commit a2e7988

Browse files
committed
add postprocess pack passes
1 parent b6b090f commit a2e7988

File tree

4 files changed

+141
-0
lines changed

4 files changed

+141
-0
lines changed

include/gc/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,15 @@ def PropagateLayoutOnNamedOps : Pass<"propagate-layout-on-named-ops"> {
5858
];
5959
}
6060

61+
def PostProcessPackUnpack : Pass<"post-process-pack-unpack"> {
  let summary = "Fold and simplify pack and unpack ops.";
  let description = [{
    Fold and simplify pack and unpack ops.
  }];
  // The pass creates/folds tensor and linalg ops.
  let dependentDialects = [
    "mlir::tensor::TensorDialect",
    "mlir::linalg::LinalgDialect"
  ];
}
71+
6172
#endif // GC_DIALECT_GC_PASSES

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_library(GCPasses
1111
OneDNNGraphToLinalg.cpp
1212
Pipeline.cpp
1313
PropagateLayout.cpp
14+
PostProcessPackUnpack.cpp
1415
TileNamed.cpp
1516

1617
ADDITIONAL_HEADER_DIRS

lib/gc/Transforms/Pipeline.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ void populateFrontendPasses(mlir::PassManager &pm) {
4040
void populateTensorPasses(mlir::PassManager &pm) {
4141
// todo: padding propagation pass
4242
// todo: layout propagation pass
43+
pm.addPass(createPropagateLayoutOnNamedOps());
44+
pm.addPass(createPostProcessPackUnpack());
4345
// todo: tensor constant propagation pass
4446
// todo: linalg.matmul lowering to (scf.loop + linalg.brgemm) pass
4547
// todo: fine-grain fusion pass
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
//===- PostProcessPackUnpack.cpp - Fold and simplify pack unpack -*- C++ -*-===//
2+
//
3+
// This file temporarily extends upstream (or upcoming) TilingInterface
// utilities; its contents are ultimately intended to be upstreamed.
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
#include <iostream>
#include <numeric>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "gc/Dialect/Linalgx/LinalgxDialect.h"
#include "gc/Dialect/Linalgx/LinalgxOps.h"
#include "gc/Transforms/Passes.h"
#include "gc/Transforms/Transforms.h"
27+
namespace mlir {
28+
namespace gc {
29+
#define GEN_PASS_DEF_POSTPROCESSPACKUNPACK
30+
#include "gc/Transforms/Passes.h.inc"
31+
32+
#define DEBUG_TYPE "post-process-pack-unpack"
33+
34+
using namespace mlir;
35+
36+
// Helper pattern - lower tensor.pack operations that pack constants.
37+
struct LowerConstantPacking : public OpRewritePattern<tensor::PackOp> {
38+
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
39+
40+
LogicalResult matchAndRewrite(tensor::PackOp packOp,
41+
PatternRewriter &rewriter) const override {
42+
auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
43+
if (!constOp)
44+
return failure();
45+
// Must be a dense constant.
46+
auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
47+
if (!denseAttr)
48+
return failure();
49+
50+
// Bail out if the pack is used as a writing operation i.e., the destination
51+
// is not a tensor.empty.
52+
if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
53+
return rewriter.notifyMatchFailure(packOp,
54+
"expects empty tensor destination");
55+
// Pack destination must have static shape.
56+
if (!packOp.getDestType().hasStaticShape())
57+
return rewriter.notifyMatchFailure(
58+
packOp, "expects destination with static shape");
59+
60+
// If it is a splat constant, skip and let tensor.pack folder to handle this
61+
// case.
62+
if (denseAttr.isSplat())
63+
return rewriter.notifyMatchFailure(
64+
packOp, "skip pack - existing folder covers constant splats");
65+
66+
return linalg::lowerPack(rewriter, packOp);
67+
}
68+
};
69+
70+
// Populate patterns that constant-fold packing: lower packs of constants,
// then fold the lowered IR back into new constants.
static void tppPopulateConstantFoldPack(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerConstantPacking>(context);
  // Canonicalizers clean up trivial cases and help the linalg constant
  // folders collapse the lowered packs.
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::PackOp::getCanonicalizationPatterns(patterns, context);
  // Fold unconditionally: no operand is exempt from rewriting to a constant.
  tensor::populateRewriteAsConstantPatterns(
      patterns, [](OpOperand *) -> bool { return true; });
  linalg::populateConstantFoldLinalgOperations(
      patterns, [](OpOperand *) -> bool { return true; });
}
82+
83+
class PostProcessPackUnpack
84+
: public impl::PostProcessPackUnpackBase<PostProcessPackUnpack> {
85+
public:
86+
using impl::PostProcessPackUnpackBase<
87+
PostProcessPackUnpack>::PostProcessPackUnpackBase;
88+
void runOnOperation() final;
89+
};
90+
91+
// Populate patterns that simplify pack/unpack pairs together with the
// surrounding tensor ops they commonly interact with.
static void tppPopulateSimplifyPacking(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  tensor::populateSimplifyPackAndUnpackPatterns(patterns);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  // Canonicalize the tensor/scf ops that pack simplification exposes.
  tensor::PackOp::getCanonicalizationPatterns(patterns, context);
  tensor::UnPackOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::CastOp::getCanonicalizationPatterns(patterns, context);
  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::PadOp::getCanonicalizationPatterns(patterns, context);
  tensor::ParallelInsertSliceOp::getCanonicalizationPatterns(patterns, context);
  scf::ForallOp::getCanonicalizationPatterns(patterns, context);
  // Propagate packs/unpacks only through expand shapes at this point.
  // This captures the transformation scope of the replaced downstream pass.
  linalg::populateDataLayoutPropagationPatterns(patterns, [](Operation *op) {
    return isa<tensor::ExpandShapeOp>(op);
  });
  context->getLoadedDialect<tensor::TensorDialect>()
      ->getCanonicalizationPatterns(patterns);
  // Disabled for now: patterns.add<FoldUnPackIntoInsertSlice>(context);
  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
}
114+
115+
void PostProcessPackUnpack::runOnOperation() {
116+
MLIRContext *ctx = getOperation()->getContext();
117+
RewritePatternSet patterns(ctx);
118+
119+
// constant fold packing
120+
tppPopulateConstantFoldPack(patterns);
121+
// simplify packing
122+
tppPopulateSimplifyPacking(patterns);
123+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
124+
}
125+
126+
} // namespace gc
127+
} // namespace mlir

0 commit comments

Comments
 (0)