 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -30,6 +31,130 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::tensor;
 
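+// Packs a named linalg op into the layouts suggested by `opLayout`: each DPS
+// input and init operand is wrapped in a tensor.pack, the op is cloned onto
+// the packed operands, and the results are unpacked back via tensor.unpack.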
+static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
+                                                 linalg::LinalgOp linalgOp,
+                                                 OperatorLayout opLayout) {
+  std::cout << "----------------------------------" << std::endl;
+  std::cout << "Visiting op in packNamedOp ";
+  linalgOp->getName().print(llvm::errs());
+  std::cout << std::endl;
+  std::cout << "----------------------------------" << std::endl;
+  Location loc = linalgOp->getLoc();
+  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+  SmallVector<utils::IteratorType> iteratorTypes =
+      linalgOp.getIteratorTypesArray();
+
+  SmallVector<tensor::PackOp> packOps;
+  SmallVector<tensor::UnPackOp> unPackOps;
+  SmallVector<Value> inputsAndInits, results;
+  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
+      linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
+  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
+  std::cout << "Num of input operands: " << inputOperands.size() << std::endl;
+  std::cout << "Num of init operands: " << initOperands.size() << std::endl;
+  SmallVector<TensorLayout> inputLayouts = opLayout.getSupportedInputLayouts();
+  SmallVector<TensorLayout> initLayouts = opLayout.getSupportedOutputLayouts();
+  std::cout << "Num of input layouts: " << inputLayouts.size() << std::endl;
+  std::cout << "Num of init layouts: " << initLayouts.size() << std::endl;
+
+  // Check that all inputs and inits are tensors; otherwise there is no need
+  // for layout propagation.
+  bool allTensor =
+      llvm::all_of(inputOperands,
+                   [](OpOperand *opOperand) {
+                     return opOperand->get().getType().isa<TensorType>();
+                   }) &&
+      llvm::all_of(initOperands, [](OpOperand *opOperand) {
+        return opOperand->get().getType().isa<TensorType>();
+      });
+  std::cout << "The op's operands are all tensors? " << allTensor << std::endl;
+  if (!allTensor) {
+    return rewriter.notifyMatchFailure(linalgOp,
+                                       "the op does not need packing.");
+  }
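+  // Pack every DPS input and init operand into its target layout from
+  // `opLayout`, recording the created tensor.pack ops.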
+  for (const auto &operandsList : {inputOperands, initOperands}) {
+    for (OpOperand *opOperand : operandsList) {
+      int64_t pos = opOperand->getOperandNumber();
+      std::cout << "pos: " << pos << std::endl;
+      Value operand = opOperand->get();
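+      // Inputs are listed first, so init operands index into `initLayouts`
+      // after subtracting the number of input layouts.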
+      TensorLayout targetLayout = pos >= inputLayouts.size()
+                                      ? initLayouts[pos - inputLayouts.size()]
+                                      : inputLayouts[pos];
+      SmallVector<int64_t> outerPerm = targetLayout.getOuterAxis();
+      SmallVector<int64_t> innerPos = targetLayout.getInnerAxis();
+      SmallVector<OpFoldResult> innerPackSizes = targetLayout.getTileSizes();
+
+      std::cout << "Suggested layout: " << targetLayout << std::endl;
+
+      std::cout << "Operand shape: ";
+      for (auto dim :
+           llvm::cast<RankedTensorType>(operand.getType()).getShape()) {
+        std::cout << dim << ", ";
+      }
+      std::cout << std::endl;
+
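+      // Pack without a padding value when the tile sizes are constant, the
+      // shape is static, and no padding is actually required; otherwise pad
+      // with zero.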
+      Value dest = tensor::PackOp::createDestinationTensor(
+          rewriter, loc, operand, innerPackSizes, innerPos, outerPerm);
+      ShapedType operandType = cast<ShapedType>(operand.getType());
+      bool areConstantTiles =
+          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
+            return getConstantIntValue(tile).has_value();
+          });
+      if (areConstantTiles && operandType.hasStaticShape() &&
+          !tensor::PackOp::requirePaddingValue(
+              operandType.getShape(), innerPos,
+              cast<ShapedType>(dest.getType()).getShape(), {},
+              innerPackSizes)) {
+        packOps.push_back(rewriter.create<tensor::PackOp>(
+            loc, operand, dest, innerPos, innerPackSizes, std::nullopt,
+            outerPerm));
+      } else {
+        // TODO: value of the padding attribute should be determined by
+        // consumers.
+        auto zeroAttr =
+            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
+        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+        packOps.push_back(rewriter.create<tensor::PackOp>(
+            loc, operand, dest, innerPos, innerPackSizes, zero, outerPerm));
+      }
+      inputsAndInits.push_back(packOps.back());
+    }
+  }
+
+  // Step 3. Build the packed op. The result type is taken from the packed
+  // init (this currently assumes a single init/result).
+  ValueRange inputs =
+      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
+  ValueRange inits =
+      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
+  // TODO(yifei): the axis info of reduce/broadcast/transpose may change
+  auto packedLinalgOp = mlir::clone(
+      rewriter, linalgOp, SmallVector<Type>{inputsAndInits.back().getType()},
+      inputsAndInits);
+
+  // Step 4. Unpack all the op results.
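+  // Only results whose init was produced by a tensor.pack need a matching
+  // unpack; other results are forwarded unchanged.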
+  for (OpResult result : packedLinalgOp->getResults()) {
+    int64_t resultNum = result.getResultNumber();
+    tensor::PackOp maybePackedInit =
+        inits[resultNum].getDefiningOp<tensor::PackOp>();
+    if (!maybePackedInit) {
+      results.push_back(result);
+      continue;
+    }
+    // Build the symmetrical UnPackOp to the existing PackOp.
+    unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
+        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
+        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+    results.push_back(unPackOps.back());
+  }
+
+  // Step 5. Replace `linalgOp`.
+  rewriter.replaceOp(linalgOp, results);
+
+  // Return packedLinalgOp.
+  return linalg::PackResult{packOps, cast<linalg::LinalgOp>(packedLinalgOp),
+                            unPackOps};
+}
+
 class PropagateLayout : public impl::PropagateLayoutBase<PropagateLayout> {
 public:
   using impl::PropagateLayoutBase<PropagateLayout>::PropagateLayoutBase;
@@ -42,24 +167,37 @@ void PropagateLayout::runOnOperation() {
   IRRewriter rewriter(ctx);
   // walk the entire graph
   auto &layoutAnalysisResult = getAnalysis<GlobalAnalysis>();
-  graph->walk([&](linalg::LinalgOp linalgOp) {
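+  // Collect candidate ops first: packNamedOp replaces ops in the IR, so the
+  // packing itself is done after the walk completes.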
+  SmallVector<Operation *> packTODOList;
+  graph->walk([&](Operation *op) {
+    if (isa<linalg::LinalgOp>(op) && !mlir::linalg::isaContractionOpInterface(
+                                         dyn_cast<linalg::LinalgOp>(op))) {
+      packTODOList.push_back(op);
+    }
+  });
+  for (auto op : packTODOList) {
     std::cout << std::endl;
     std::cout << "----------------------------------" << std::endl;
     std::cout << "Visiting op ";
-    linalgOp.getOperation()->getName().print(llvm::errs());
+    op->getName().print(llvm::errs());
     std::cout << std::endl;
     std::cout << "----------------------------------" << std::endl;
-    FailureOr<OperatorLayout> opLayout =
-        layoutAnalysisResult.getOpLayout(linalgOp);
+    FailureOr<OperatorLayout> opLayout = layoutAnalysisResult.getOpLayout(op);
     if (failed(opLayout)) {
       std::cout << "infer failed" << std::endl;
     } else {
       // pack op into ideal layout
       std::cout << "-------- supported layouts -------" << std::endl;
       std::cout << *opLayout << std::endl;
       // insert pack
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(op);
+      if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+        FailureOr<linalg::PackResult> packedOp =
+            packNamedOp(rewriter, linalgOp, *opLayout);
+      }
+      graph->dump();
     }
-  });
+  }
   graph->dump();
 }
 