@@ -31,6 +31,19 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::tensor;
 
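+// Maps op axes (e.g. reduction dimensions) into the packed domain: any axis
+// that the TensorLayout tiles into an inner block gains a second, packed
+// counterpart. Illustrative example: with two outer axes and inner axes [1],
+// dimensions [1] become [1, 2], since axis 1's tile lives at position 2.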
+static SmallVector<int64_t> getPackedAxes(ArrayRef<int64_t> dimensions,
+                                          TensorLayout targetLayout) {
+  SmallVector<int64_t> result(dimensions);
+  auto innerPos = targetLayout.getInnerAxis();
+  for (size_t i = 0; i < dimensions.size(); ++i) {
+    auto pos = std::find(innerPos.begin(), innerPos.end(), dimensions[i]);
+    if (pos != innerPos.end()) {
+      // The tiled counterpart of this axis lives after all outer axes, at
+      // outer-rank + its position within the inner-axis list.
+      result.push_back(std::distance(innerPos.begin(), pos) +
+                       targetLayout.getOuterAxis().size());
+    }
+  }
+  return result;
+}
+
 static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
                                                  linalg::LinalgOp linalgOp,
                                                  OperatorLayout opLayout) {
@@ -125,10 +138,30 @@ static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
       ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
   ValueRange inits =
       ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
-  // TODO(yifei): the axis info of reduce/broadcast/transpose may change
-  auto packedLinalgOp = mlir::clone(
-      rewriter, linalgOp, SmallVector<Type>{inputsAndInits.back().getType()},
-      inputsAndInits);
+  // TODO(yifei): complete the axis remapping for reduce/broadcast/transpose
+  // TODO(yifei): support linalg.generic
+  linalg::LinalgOp packedLinalgOp;
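+  // A reduce over an axis that the target layout tiles must also reduce the
+  // corresponding inner tile axis, so its dimension list is extended below.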
+  if (auto reduceOp = dyn_cast<linalg::ReduceOp>(linalgOp.getOperation())) {
+    SmallVector<int64_t> packedAxes =
+        getPackedAxes(reduceOp.getDimensions(), inputLayouts[0]);
+    packedLinalgOp = rewriter.create<linalg::ReduceOp>(
+        loc, inits.getTypes(), inputs, inits, packedAxes);
+    packedLinalgOp->getRegion(0).takeBody(linalgOp->getRegion(0));
+  } else if (auto broadcastOp =
+                 dyn_cast<linalg::BroadcastOp>(linalgOp.getOperation())) {
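+    // Per the TODO above, the broadcast dimensions are reused unchanged here;
+    // they may also need remapping once the packed layout tiles them.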
+    packedLinalgOp = rewriter.create<linalg::BroadcastOp>(
+        loc, inputs[0], inits[0], broadcastOp.getDimensions());
+  } else if (isa<linalg::TransposeOp>(linalgOp)) {
+    // TODO: a transpose whose operand is already packed into the target
+    // layout should simply be removed; bail out until that rewrite exists,
+    // rather than leaving packedLinalgOp null and crashing below.
+    return rewriter.notifyMatchFailure(
+        linalgOp, "removal of packed transpose not implemented yet");
+  } else if (isa<linalg::SoftmaxOp>(linalgOp) ||
+             isa<linalg::GenericOp>(linalgOp) || isa<linalg::MapOp>(linalgOp) ||
+             isa<linalg::YieldOp>(linalgOp) || isa<linalg::IndexOp>(linalgOp)) {
+    return rewriter.notifyMatchFailure(
+        linalgOp,
+        "packing logic not implemented for softmax/generic/map/yield/index");
+  } else {
+    packedLinalgOp = mlir::clone(
+        rewriter, linalgOp, SmallVector<Type>{inputsAndInits.back().getType()},
+        inputsAndInits);
+  }
 
   // Step 4. Unpack all the op results.
   for (OpResult result : packedLinalgOp->getResults()) {
@@ -195,7 +228,6 @@ void PropagateLayout::runOnOperation() {
         FailureOr<linalg::PackResult> packedOp =
             packNamedOp(rewriter, linalgOp, *opLayout);
       }
-      graph->dump();
     }
   }
   graph->dump();