@@ -31,6 +31,26 @@ using namespace mlir;
using namespace mlir::arith;
using namespace mlir::tensor;

+static SmallVector<int64_t> getPackedAxes(ArrayRef<int64_t> dimensions,
+                                          TensorLayout targetLayout) {
+  SmallVector<int64_t> result(dimensions.begin(), dimensions.end());
+  auto innerPos = targetLayout.getInnerAxis();
+  int64_t outerRank = targetLayout.getOuterAxis().size();
+  for (size_t i = 0; i < dimensions.size(); ++i) {
+    // A reduced dimension that is also packed into an inner tile must be
+    // reduced over the corresponding inner axis of the packed tensor as well.
+    auto it = std::find(innerPos.begin(), innerPos.end(), dimensions[i]);
+    if (it != innerPos.end())
+      result.push_back(std::distance(innerPos.begin(), it) + outerRank);
+  }
+  return result;
+}
+
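+// Worked example (assuming getOuterAxis/getInnerAxis return axis index
+// lists, as used above): with outerAxis = [0, 1] and innerAxis = [1], a
+// reduction over dimensions = {1} maps to packed axes {1, 2}, i.e. the
+// packed outer axis 1 plus the inner tile axis at outerRank + 0 = 2.
+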
static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
                                                 linalg::LinalgOp linalgOp,
                                                 OperatorLayout opLayout) {
@@ -125,10 +145,38 @@ static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
-  // TODO(yifei): the axis info of reduce/broadcast/transpose may change
-  auto packedLinalgOp = mlir::clone(
-      rewriter, linalgOp, SmallVector<Type>{inputsAndInits.back().getType()},
-      inputsAndInits);
+  // TODO(yifei): deal with reduce/broadcast/transpose
+  // TODO(yifei): deal with generic
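+  // Dispatch below is a sketch: ops carrying static axis attributes
+  // (reduce, broadcast, transpose) need those axes remapped into the packed
+  // domain, while axis-free named ops can simply be cloned on packed operands.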
+  linalg::LinalgOp packedLinalgOp;
+  if (auto reduceOp = dyn_cast<linalg::ReduceOp>(linalgOp.getOperation())) {
+    SmallVector<int64_t> packedAxes =
+        getPackedAxes(reduceOp.getDimensions(), inputLayouts[0]);
+    packedLinalgOp = rewriter.create<linalg::ReduceOp>(
+        loc, inits.getTypes(), inputs, inits, packedAxes);
+    // Reuse the original reduction body in the newly created packed op.
+    packedLinalgOp->getRegion(0).takeBody(linalgOp->getRegion(0));
+  } else if (auto broadcastOp =
+                 dyn_cast<linalg::BroadcastOp>(linalgOp.getOperation())) {
+    packedLinalgOp = rewriter.create<linalg::BroadcastOp>(
+        loc, inputs[0], inits[0], broadcastOp.getDimensions());
+  } else if (isa<linalg::TransposeOp>(linalgOp)) {
+    // TODO(yifei): remove the transpose op; bail out for now so that the
+    // null packedLinalgOp is never dereferenced below.
+    return rewriter.notifyMatchFailure(linalgOp, "transpose is not packed yet");
+  } else if (isa<linalg::SoftmaxOp>(linalgOp) ||
+             isa<linalg::GenericOp>(linalgOp) || isa<linalg::MapOp>(linalgOp) ||
+             isa<linalg::YieldOp>(linalgOp) || isa<linalg::IndexOp>(linalgOp)) {
+    return rewriter.notifyMatchFailure(
+        linalgOp,
+        "packing logic not implemented for SoftMax/Generic/Map/Yield/Index");
+  } else {
+    packedLinalgOp = mlir::clone(
+        rewriter, linalgOp, SmallVector<Type>{inputsAndInits.back().getType()},
+        inputsAndInits);
+  }

  // Step 4. Unpack all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
@@ -195,7 +243,6 @@ void PropagateLayout::runOnOperation()
        FailureOr<linalg::PackResult> packedOp =
            packNamedOp(rewriter, linalgOp, *opLayout);
      }
-      graph->dump();
    }
  }
  graph->dump();