-// ===- PropagateLayout.cpp - Propagate pack unpack on linalg named ops --*- C++
-// -*-===//
+// ===- PropagateLayout.cpp - Propagate packing on linalg named ops -*- C++ -*-===//
 //
 // This file temporarily extends upstream or upcoming TilingInterface
 // utilities; it is ultimately intended to be upstreamed.
@@ -44,6 +43,29 @@ static SmallVector<int64_t> getPackedAxes(ArrayRef<int64_t> dimensions,
   return result;
 }
 
+static SmallVector<int64_t> getPackedPermAxes(ArrayRef<int64_t> plainPermAxes,
+                                              TensorLayout inputLayout,
+                                              TensorLayout outputLayout) {
+  // For a transpose, dim(result, i) = dim(input, permutation[i]): plain input
+  // axis permutation[i] maps to plain output axis i. Walk the packed output
+  // axes in order; map each back to its plain output axis, through the plain
+  // permutation to a plain input axis, then to its next unused packed axis.
+  size_t packedRank =
+      outputLayout.getInnerAxis().size() + outputLayout.getOuterAxis().size();
+  SmallVector<int64_t> result(packedRank, 0);
+  SmallVector<int64_t> inputCount(inputLayout.getOuterAxis().size(), 0);
+  auto inputP2B = TensorLayout::getPlain2PackedMapping(inputLayout);
+  for (size_t i = 0; i < packedRank; ++i) {
+    // packedOutput[i] --> output[?]
+    size_t originalOutputAxis = *outputLayout.getOriginalAxis(i);
+    size_t originalInputAxis = plainPermAxes[originalOutputAxis];
+    SmallVector<int64_t> packedInputAxes = inputP2B[originalInputAxis];
+    result[i] = packedInputAxes[inputCount[originalInputAxis]++];
+  }
+  return result;
+}
+
+// Extends linalg::pack from mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp.
 static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
                                                   linalg::LinalgOp linalgOp,
                                                   OperatorLayout opLayout) {
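
To see what getPackedPermAxes computes, here is a standalone sketch using plain STL containers in place of the project's TensorLayout API. The 2-D transpose with permutation [1, 0] and the 2x2 packing (plain axis i mapping to packed axes {i, i + 2}) are made-up example values; the two lookup tables stand in for TensorLayout::getPlain2PackedMapping and TensorLayout::getOriginalAxis:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Transpose permutation on the plain (unpacked) tensors.
  std::vector<int64_t> plainPerm = {1, 0};
  // Plain axis -> packed axes of the input (stands in for
  // TensorLayout::getPlain2PackedMapping).
  std::vector<std::vector<int64_t>> inputP2B = {{0, 2}, {1, 3}};
  // Packed axis -> plain axis of the output (stands in for
  // TensorLayout::getOriginalAxis).
  std::vector<int64_t> outputPacked2Plain = {0, 1, 0, 1};

  size_t packedRank = outputPacked2Plain.size();
  std::vector<int64_t> result(packedRank, 0);
  std::vector<int64_t> inputCount(inputP2B.size(), 0);
  for (size_t i = 0; i < packedRank; ++i) {
    int64_t plainOut = outputPacked2Plain[i]; // packed output -> plain output
    int64_t plainIn = plainPerm[plainOut];    // plain output -> plain input
    // Plain input -> its next unused packed input axis.
    result[i] = inputP2B[plainIn][inputCount[plainIn]++];
  }
  for (int64_t r : result)
    std::cout << r << ' '; // prints: 1 0 3 2
  std::cout << '\n';
  return 0;
}

A plain [1, 0] transpose thus becomes [1, 0, 3, 2] on the packed operands: outer axes are permuted among outer axes and inner axes among inner axes.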
@@ -150,8 +172,11 @@ static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
   } else if (auto broadcastOp = dyn_cast<linalg::BroadcastOp>(&linalgOp)) {
     packedLinalgOp = rewriter.create<linalg::BroadcastOp>(
         loc, inputs[0], inits[0], broadcastOp->getDimensions());
-  } else if (isa<linalg::TransposeOp>(linalgOp)) {
-    // remove transpose op
+  } else if (auto transposeOp = dyn_cast<linalg::TransposeOp>(&linalgOp)) {
+    SmallVector<int64_t> packedPermAxes = getPackedPermAxes(
+        transposeOp->getPermutation(), inputLayouts[0], initLayouts[0]);
+    packedLinalgOp = rewriter.create<linalg::TransposeOp>(
+        loc, inputs[0], inits[0], packedPermAxes);
   } else if (isa<linalg::SoftmaxOp>(linalgOp) ||
              isa<linalg::GenericOp>(linalgOp) || isa<linalg::MapOp>(linalgOp) ||
              isa<linalg::YieldOp>(linalgOp) || isa<linalg::IndexOp>(linalgOp)) {
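
linalg::TransposeOp requires its permutation attribute to be a valid permutation, so one cheap property test for getPackedPermAxes (a hypothetical helper, not part of this patch) is that every result is a permutation of [0, packedRank):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Returns true iff v is a permutation of 0..v.size()-1, mirroring what the
// linalg::TransposeOp verifier expects of its permutation attribute.
static bool isPermutationOfIota(std::vector<int64_t> v) {
  std::vector<int64_t> expected(v.size());
  std::iota(expected.begin(), expected.end(), 0);
  std::sort(v.begin(), v.end());
  return v == expected;
}

int main() {
  std::cout << isPermutationOfIota({1, 0, 3, 2}) << '\n'; // prints: 1
  std::cout << isPermutationOfIota({1, 1, 3, 2}) << '\n'; // prints: 0
  return 0;
}

MLIR's isPermutationVector (mlir/Dialect/Utils/IndexingUtils.h) provides the same check and could back an assertion here.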
@@ -175,7 +200,8 @@ static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
     // Build the symmetrical UnPackOp to the existing PackOp.
     unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
         packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
-        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles(),
+        maybePackedInit.getOuterDimsPerm()));
     results.push_back(unPackOps.back());
   }
 
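
The extra getOuterDimsPerm() argument matters because the packed init may have had its outer dimensions permuted, and the unpack must invert exactly the layout the pack produced. Here is a standalone sketch of how the pack metadata determines the packed shape, following tensor.pack semantics; the 128x96 shape, tile sizes, and outer permutation are made-up values:

#include <cstdint>
#include <iostream>
#include <vector>

// Packed shape per tensor.pack semantics: divide the tiled dims by their tile
// sizes, permute the outer dims by outerDimsPerm, then append the tile sizes.
static std::vector<int64_t>
packedShape(std::vector<int64_t> shape,
            const std::vector<int64_t> &innerDimsPos,
            const std::vector<int64_t> &innerTiles,
            const std::vector<int64_t> &outerDimsPerm) {
  for (size_t i = 0; i < innerDimsPos.size(); ++i)
    shape[innerDimsPos[i]] /= innerTiles[i]; // outer extent after tiling
  std::vector<int64_t> result;
  for (int64_t d : outerDimsPerm)
    result.push_back(shape[d]);
  for (int64_t t : innerTiles)
    result.push_back(t);
  return result;
}

int main() {
  // Pack a 128x96 tensor with 32x16 tiles and swapped outer dims.
  for (int64_t d : packedShape({128, 96}, /*innerDimsPos=*/{0, 1},
                               /*innerTiles=*/{32, 16},
                               /*outerDimsPerm=*/{1, 0}))
    std::cout << d << ' '; // prints: 6 4 32 16
  std::cout << '\n';
  return 0;
}

Built without the outer permutation, the symmetric unpack would read the 6x4 outer grid in the wrong order (or simply fail to verify against the packed type), which is what threading getOuterDimsPerm() through avoids.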