Skip to content

Commit 16a7e8e

Browse files
committed
cache
1 parent eaa9197 commit 16a7e8e

File tree

7 files changed

+112
-34
lines changed

7 files changed

+112
-34
lines changed

include/gc/Analysis/GlobalAnalysis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ class GlobalAnalysis {
143143
DenseMap<Operation *, OperatorLayout> layoutCache;
144144
};
145145

146+
namespace utils {
147+
bool isPackableNamedOp(Operation *op);
148+
}
146149
} // namespace gc
147150
} // namespace mlir
148151

include/gc/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,15 @@ def PostProcessPackUnpack : Pass<"post-process-pack-unpack"> {
9797
];
9898
}
9999

100+
def LowerPackUnpack : Pass<"lower-pack-unpack"> {
101+
let summary = "Lower pack and unpack ops.";
102+
let description = [{
103+
Lower pack and unpack into transpose and shape related ops.
104+
}];
105+
let dependentDialects = [
106+
"mlir::tensor::TensorDialect",
107+
"mlir::linalg::LinalgDialect"
108+
];
109+
}
110+
100111
#endif // GC_DIALECT_GC_PASSES

lib/gc/Analysis/GlobalAnalysis.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,5 +342,17 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
342342
}
343343
});
344344
}
345+
346+
namespace utils {
347+
bool isPackableNamedOp(Operation *op) {
348+
if ((isa<linalg::LinalgOp>(op) &&
349+
!mlir::linalg::isaContractionOpInterface(
350+
dyn_cast<linalg::LinalgOp>(op)) &&
351+
!isa<linalgx::Mm4DVnniOp>(op)) ||
352+
isa<tensor::ExpandShapeOp>(op) || isa<tensor::PadOp>(op))
353+
return true;
354+
return false;
355+
}
356+
} // namespace utils
345357
} // namespace gc
346358
} // namespace mlir

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_mlir_library(GCPasses
1414
Pipeline.cpp
1515
PropagateLayout.cpp
1616
PostProcessPackUnpack.cpp
17+
LowerPackUnpack.cpp
1718
DeepTileContractionNamedOp.cpp
1819
TilingUtil.cpp
1920
SinkOpIntoInnerLoop.cpp
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
//===- LowerPackUnpack.cpp - Lower pack unpack into linalg ops *---- C++-*-===//
2+
//
3+
// This file is only temporarily used to extend upstream or upcoming utility in
4+
// TilingInterface, which finally aims for upstream.
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
#include <iostream>
9+
#include <numeric>
10+
11+
#include "gc/Transforms/Transforms.h"
12+
#include "mlir/Dialect/Arith/IR/Arith.h"
13+
#include "mlir/Dialect/Func/IR/FuncOps.h"
14+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
15+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
17+
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
18+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
19+
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
20+
#include "mlir/IR/PatternMatch.h"
21+
#include "mlir/Transforms/DialectConversion.h"
22+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23+
24+
#include "gc/Dialect/Linalgx/LinalgxDialect.h"
25+
#include "gc/Dialect/Linalgx/LinalgxOps.h"
26+
#include "gc/Transforms/Passes.h"
27+
namespace mlir {
28+
namespace gc {
29+
#define GEN_PASS_DEF_LOWERPACKUNPACK
30+
#include "gc/Transforms/Passes.h.inc"
31+
32+
#define DEBUG_TYPE "lower-pack-unpack"
33+
34+
using namespace mlir;
35+
36+
// copied from tpp
37+
// A wrapper pattern that calls linalg::lowerPack on tensor::PackOp. It lowers
38+
// a tensor.pack op to tensor.pad + tensor.expand_shape + linalg.transpose ops.
39+
// copied from tpp
// Wrapper pattern delegating to linalg::lowerPack on tensor::PackOp. Rewrites
// a tensor.pack op into tensor.pad + tensor.expand_shape + linalg.transpose.
struct LowerPackPattern : public OpRewritePattern<tensor::PackOp> {
  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PackOp op,
                                PatternRewriter &rewriter) const override {
    // linalg::lowerPack performs the rewrite itself; we only surface failure.
    if (failed(linalg::lowerPack(rewriter, op)))
      return rewriter.notifyMatchFailure(
          op, "cannot lower to pad + expand + transpose");
    return success();
  }
};
52+
53+
// A wrapper pattern that calls linalg::lowerUnPack on tensor::UnPackOp. It
54+
// lowers a tensor.unpack op to tensor.empty + linalg.transpose +
55+
// tensor.collapse_shape + tensor.extract_slice ops.
56+
// Wrapper pattern delegating to linalg::lowerUnPack on tensor::UnPackOp.
// Rewrites a tensor.unpack op into tensor.empty + linalg.transpose +
// tensor.collapse_shape + tensor.extract_slice.
struct LowerUnPackPattern : public OpRewritePattern<tensor::UnPackOp> {
  using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::UnPackOp op,
                                PatternRewriter &rewriter) const override {
    // linalg::lowerUnPack performs the rewrite itself; we only surface
    // failure.
    auto res = linalg::lowerUnPack(rewriter, op);
    if (failed(res))
      return rewriter.notifyMatchFailure(
          op, "cannot lower to empty + transpose + reshape + extract_slice");
    return success();
  }
};
68+
69+
class LowerPackUnpack : public impl::LowerPackUnpackBase<LowerPackUnpack> {
70+
public:
71+
using impl::LowerPackUnpackBase<LowerPackUnpack>::LowerPackUnpackBase;
72+
void runOnOperation() final;
73+
};
74+
75+
void LowerPackUnpack::runOnOperation() {
76+
auto *ctx = &getContext();
77+
RewritePatternSet patterns(ctx);
78+
patterns.add<LowerPackPattern, LowerUnPackPattern>(ctx);
79+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
80+
}
81+
82+
} // namespace gc
83+
} // namespace mlir

lib/gc/Transforms/Pipeline.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) {
5858
pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
5959
pm.addPass(createLoopInvariantCodeMotionPass());
6060
pm.addPass(createControlFlowSinkPass());
61+
// TODO(yifei): remove lower pack here
62+
pm.addPass(createLowerPackUnpack());
6163
populateCleanUpPasses(pm);
6264
}
6365

lib/gc/Transforms/PostProcessPackUnpack.cpp

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -126,39 +126,6 @@ struct EliminateDummyUnpack : public OpRewritePattern<tensor::UnPackOp> {
126126
}
127127
};
128128

129-
// copied from tpp
130-
// A wrapper pattern that calls linalg::lowerPack on tensor::PackOp. It lowers
131-
// a tensor.pack op to tensor.pad + tensor.expand_shape + linalg.transpose ops.
132-
struct LowerPackPattern : public OpRewritePattern<tensor::PackOp> {
133-
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
134-
135-
LogicalResult matchAndRewrite(tensor::PackOp op,
136-
PatternRewriter &rewriter) const override {
137-
FailureOr<linalg::LowerPackResult> res = linalg::lowerPack(rewriter, op);
138-
if (failed(res)) {
139-
return rewriter.notifyMatchFailure(
140-
op, "cannot lower to pad + expand + transpose");
141-
}
142-
return success();
143-
}
144-
};
145-
146-
// A wrapper pattern that calls linalg::lowerUnPack on tensor::UnPackOp. It
147-
// lowers a tensor.unpack op to tensor.empty + linalg.transpose +
148-
// tensor.collapse_shape + tensor.extract_slice ops.
149-
struct LowerUnPackPattern : public OpRewritePattern<tensor::UnPackOp> {
150-
using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
151-
152-
LogicalResult matchAndRewrite(tensor::UnPackOp op,
153-
PatternRewriter &rewriter) const override {
154-
if (failed(linalg::lowerUnPack(rewriter, op))) {
155-
return rewriter.notifyMatchFailure(
156-
op, "cannot lower to empty + transpose + reshape + extract_slice");
157-
}
158-
return success();
159-
}
160-
};
161-
162129
static void populateEliminateDummyPackUnpack(RewritePatternSet &patterns) {
163130
MLIRContext *ctx = patterns.getContext();
164131
patterns.add<EliminateDummyPack, EliminateDummyUnpack>(ctx);
@@ -205,7 +172,6 @@ void PostProcessPackUnpack::runOnOperation() {
205172
tppPopulateSimplifyPacking(patterns);
206173
// gc new packing related simplification
207174
populateEliminateDummyPackUnpack(patterns);
208-
patterns.add<LowerPackPattern, LowerUnPackPattern>(ctx);
209175
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
210176
}
211177

0 commit comments

Comments
 (0)