1
- // ===- PropagateLayout.cpp - Propagate pack unpack on linalg named ops --*- C++
2
- // -*-===//
1
+ // ===- PropagateLayout.cpp - Propagate packing on linalg ops -*- C++ -*-===//
3
2
//
4
3
// This file temporarily extends upstream (and upcoming) utilities in
// TilingInterface; these changes are intended to be upstreamed eventually.
@@ -44,6 +43,29 @@ static SmallVector<int64_t> getPackedAxes(ArrayRef<int64_t> dimensions,
44
43
return result;
45
44
}
46
45
46
+ static SmallVector<int64_t > getPackedPermAxes (ArrayRef<int64_t > plainPermAxes,
47
+ TensorLayout inputLayout,
48
+ TensorLayout outputLayout) {
49
+ // dim(result, i) = dim(input, permutation[i])
50
+ // input: permutation[i] --> output: i
51
+ // input: permutation[i] --> packed input: std::find(permutation[i]) - begin()
52
+ // output: i --> packed output: std::find(permutation[i]) - begin()
53
+ size_t packedRank =
54
+ outputLayout.getInnerAxis ().size () + outputLayout.getOuterAxis ().size ();
55
+ SmallVector<int64_t > result (packedRank, 0 );
56
+ SmallVector<int64_t > inputCount (inputLayout.getOuterAxis ().size (), 0 );
57
+ auto inputP2B = TensorLayout::getPlain2PackedMapping (inputLayout);
58
+ for (size_t i = 0 ; i < packedRank; ++i) {
59
+ // packedOutput[i] --> output[?]
60
+ size_t originalOutputAxis = *outputLayout.getOriginalAxis (i);
61
+ size_t originalInputAxis = plainPermAxes[originalOutputAxis];
62
+ SmallVector<int64_t > packedInputAxes = inputP2B[originalInputAxis];
63
+ result[i] = packedInputAxes[inputCount[originalInputAxis]++];
64
+ }
65
+ return result;
66
+ }
67
+
68
+ // extends mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp's linalg::pack
47
69
static FailureOr<linalg::PackResult> packNamedOp (RewriterBase &rewriter,
48
70
linalg::LinalgOp linalgOp,
49
71
OperatorLayout opLayout) {
@@ -150,8 +172,11 @@ static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
150
172
} else if (auto broadcastOp = dyn_cast<linalg::BroadcastOp>(&linalgOp)) {
151
173
packedLinalgOp = rewriter.create <linalg::BroadcastOp>(
152
174
loc, inputs[0 ], inits[0 ], broadcastOp->getDimensions ());
153
- } else if (isa<linalg::TransposeOp>(linalgOp)) {
154
- // remove transpose op
175
+ } else if (auto transposeOp = dyn_cast<linalg::TransposeOp>(&linalgOp)) {
176
+ SmallVector<int64_t > packedPermAxes = getPackedPermAxes (
177
+ transposeOp->getPermutation (), inputLayouts[0 ], initLayouts[0 ]);
178
+ packedLinalgOp = rewriter.create <linalg::TransposeOp>(
179
+ loc, inputs[0 ], inits[0 ], packedPermAxes);
155
180
} else if (isa<linalg::SoftmaxOp>(linalgOp) ||
156
181
isa<linalg::GenericOp>(linalgOp) || isa<linalg::MapOp>(linalgOp) ||
157
182
isa<linalg::YieldOp>(linalgOp) || isa<linalg::IndexOp>(linalgOp)) {
@@ -175,7 +200,8 @@ static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
175
200
// Build the symmetrical UnPackOp to the existing PackOp.
176
201
unPackOps.push_back (rewriter.create <tensor::UnPackOp>(
177
202
packedLinalgOp->getLoc (), result, maybePackedInit.getSource (),
178
- maybePackedInit.getInnerDimsPos (), maybePackedInit.getMixedTiles ()));
203
+ maybePackedInit.getInnerDimsPos (), maybePackedInit.getMixedTiles (),
204
+ maybePackedInit.getOuterDimsPerm ()));
179
205
results.push_back (unPackOps.back ());
180
206
}
181
207
0 commit comments