Commit 41cd05f

support transpose
1 parent 2462342 commit 41cd05f

3 files changed: +59 -8 lines changed

include/gc/Analysis/GlobalAnalysis.h: +26 -1

@@ -55,6 +55,31 @@ class TensorLayout {
     SmallVector<OpFoldResult>{});
   }
 
+  static DenseMap<int64_t, SmallVector<int64_t>>
+  getPlain2PackedMapping(TensorLayout layout) {
+    DenseMap<int64_t, SmallVector<int64_t>> p2b;
+    SmallVector<int64_t> outerAxis = layout.getOuterAxis();
+    SmallVector<int64_t> innerAxis = layout.getInnerAxis();
+    for (size_t i = 0; i < outerAxis.size(); ++i) {
+      p2b[outerAxis[i]].push_back(i);
+    }
+    for (size_t i = 0; i < innerAxis.size(); ++i) {
+      p2b[innerAxis[i]].push_back(outerAxis.size() + i);
+    }
+    return p2b;
+  }
+
+  FailureOr<int64_t> getOriginalAxis(int64_t idx) {
+    size_t totalRank = OuterAxis.size() + InnerAxis.size();
+    if (idx >= totalRank) {
+      return failure("Index out of range.");
+    } else if (idx >= OuterAxis.size()) {
+      return InnerAxis[idx - OuterAxis.size()];
+    } else {
+      return OuterAxis[idx];
+    }
+  }
+
   size_t getTensorRank() const { return OuterAxis.size(); }
 
   SmallVector<int64_t> getOuterAxis() const { return OuterAxis; }
@@ -112,7 +137,7 @@ class GlobalAnalysis {
     if (layout.find(op) != layout.end())
       return layout[op];
     else
-      return op->emitError("Current op does not have layout information.");
+      return failure("Current op does not have layout information.");
   }
 
 private:
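
For intuition (not part of the commit), here is a minimal standalone sketch of the plain-to-packed mapping added above, with std::map/std::vector standing in for DenseMap/SmallVector and a made-up layout: outer axes [0, 1], inner axes [1], i.e. a 2-D tensor whose second axis is tiled.

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // Example layout (an assumption for illustration): the plain axis carried
  // by each packed outer dim, and by each packed inner (tile) dim.
  std::vector<int64_t> outerAxis = {0, 1};
  std::vector<int64_t> innerAxis = {1};

  // Same loop structure as getPlain2PackedMapping: plain axis -> packed axes.
  std::map<int64_t, std::vector<int64_t>> p2b;
  for (size_t i = 0; i < outerAxis.size(); ++i)
    p2b[outerAxis[i]].push_back(i);
  for (size_t i = 0; i < innerAxis.size(); ++i)
    p2b[innerAxis[i]].push_back(outerAxis.size() + i);

  for (const auto &kv : p2b) {
    std::cout << "plain axis " << kv.first << " -> packed axes:";
    for (int64_t packed : kv.second)
      std::cout << " " << packed;
    std::cout << "\n"; // plain axis 1 -> packed axes 1 and 2
  }

  // Inverse lookup, mirroring getOriginalAxis: packed axis -> plain axis.
  auto originalAxis = [&](size_t idx) {
    return idx >= outerAxis.size() ? innerAxis[idx - outerAxis.size()]
                                   : outerAxis[idx];
  };
  std::cout << "packed axis 2 -> plain axis " << originalAxis(2) << "\n"; // 1
}

With that layout, plain axis 1 maps to packed axes {1, 2} and the inverse lookup recovers plain axis 1 from packed axis 2; these are exactly the two lookups the new getPackedPermAxes in PropagateLayout.cpp relies on.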

lib/gc/Analysis/GlobalAnalysis.cpp: +2 -2

@@ -53,10 +53,10 @@ std::ostream &operator<<(std::ostream &ss, const OperatorLayout &opLayout) {
   ss << "operator has " << opLayout.getSupportedInputLayouts().size()
      << " inputs; " << opLayout.getSupportedOutputLayouts().size()
      << " outputs." << std::endl;
-  for (auto layout : opLayout.getSupportedInputLayouts()) {
+  for (const auto &layout : opLayout.getSupportedInputLayouts()) {
     ss << "input layout: " << layout << std::endl;
   }
-  for (auto layout : opLayout.getSupportedOutputLayouts()) {
+  for (const auto &layout : opLayout.getSupportedOutputLayouts()) {
     ss << "output layout: " << layout << std::endl;
   }
   return ss;

lib/gc/Transforms/PropagateLayout.cpp: +31 -5

@@ -1,5 +1,4 @@
-//===- PropagateLayout.cpp - Propagate pack unpack on linalg named ops --*- C++
-//-*-===//
+//===- PropagateLayout.cpp - Propagate packing on linalg named ops*- C++-*-===//
 //
 // This file is only temporarily used to extend upstream or upcoming utility in
 // TilingInterface, which finally aims for upstream.
@@ -44,6 +43,29 @@ static SmallVector<int64_t> getPackedAxes(ArrayRef<int64_t> dimensions,
   return result;
 }
 
+static SmallVector<int64_t> getPackedPermAxes(ArrayRef<int64_t> plainPermAxes,
+                                              TensorLayout inputLayout,
+                                              TensorLayout outputLayout) {
+  // dim(result, i) = dim(input, permutation[i])
+  // input: permutation[i] --> output: i
+  // input: permutation[i] --> packed input: std::find(permutation[i]) - begin()
+  // output: i --> packed output: std::find(permutation[i]) - begin()
+  size_t packedRank =
+      outputLayout.getInnerAxis().size() + outputLayout.getOuterAxis().size();
+  SmallVector<int64_t> result(packedRank, 0);
+  SmallVector<int64_t> inputCount(inputLayout.getOuterAxis().size(), 0);
+  auto inputP2B = TensorLayout::getPlain2PackedMapping(inputLayout);
+  for (size_t i = 0; i < packedRank; ++i) {
+    // packedOutput[i] --> output[?]
+    size_t originalOutputAxis = *outputLayout.getOriginalAxis(i);
+    size_t originalInputAxis = plainPermAxes[originalOutputAxis];
+    SmallVector<int64_t> packedInputAxes = inputP2B[originalInputAxis];
+    result[i] = packedInputAxes[inputCount[originalInputAxis]++];
+  }
+  return result;
+}
+
+// extends mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp's linalg::pack
 static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
                                                  linalg::LinalgOp linalgOp,
                                                  OperatorLayout opLayout) {
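
As a rough walk-through under assumed layouts (not taken from the commit): with a plain permutation [1, 0], an input packed as [outer 0, outer 1, tile of plain axis 1], and an output whose single tile dimension comes from plain output axis 0, the loop above yields the packed permutation [1, 0, 2]. A self-contained sketch with std containers:

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // Assumed example data: plain linalg.transpose permutation and the packed
  // layouts of its input and init (output) tensors.
  std::vector<int64_t> plainPerm = {1, 0};
  std::vector<int64_t> inOuter = {0, 1}, inInner = {1};   // input:  AB -> ABb
  std::vector<int64_t> outOuter = {0, 1}, outInner = {0}; // output: BA -> BAb

  // Plain input axis -> packed input axes, as in getPlain2PackedMapping.
  std::map<int64_t, std::vector<int64_t>> inP2B;
  for (size_t i = 0; i < inOuter.size(); ++i)
    inP2B[inOuter[i]].push_back(i);
  for (size_t i = 0; i < inInner.size(); ++i)
    inP2B[inInner[i]].push_back(inOuter.size() + i);

  size_t packedRank = outOuter.size() + outInner.size();
  std::vector<int64_t> result(packedRank, 0);
  std::vector<int64_t> used(inOuter.size(), 0);
  for (size_t i = 0; i < packedRank; ++i) {
    // Packed output axis i -> plain output axis (outer dims first, then tiles).
    int64_t plainOut =
        i < outOuter.size() ? outOuter[i] : outInner[i - outOuter.size()];
    // Plain input axis it reads from, then the next unused packed input axis.
    int64_t plainIn = plainPerm[plainOut];
    result[i] = inP2B[plainIn][used[plainIn]++];
  }

  for (int64_t axis : result)
    std::cout << axis << " "; // prints: 1 0 2
  std::cout << "\n";
}

The tile dimension stays innermost while the two outer dimensions swap, which is what allows the packed linalg::TransposeOp created below to stand in for the old "remove transpose op" placeholder.
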
@@ -150,8 +172,11 @@ static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
   } else if (auto broadcastOp = dyn_cast<linalg::BroadcastOp>(&linalgOp)) {
     packedLinalgOp = rewriter.create<linalg::BroadcastOp>(
         loc, inputs[0], inits[0], broadcastOp->getDimensions());
-  } else if (isa<linalg::TransposeOp>(linalgOp)) {
-    // remove transpose op
+  } else if (auto transposeOp = dyn_cast<linalg::TransposeOp>(&linalgOp)) {
+    SmallVector<int64_t> packedPermAxes = getPackedPermAxes(
+        transposeOp->getPermutation(), inputLayouts[0], initLayouts[0]);
+    packedLinalgOp = rewriter.create<linalg::TransposeOp>(
+        loc, inputs[0], inits[0], packedPermAxes);
   } else if (isa<linalg::SoftmaxOp>(linalgOp) ||
              isa<linalg::GenericOp>(linalgOp) || isa<linalg::MapOp>(linalgOp) ||
              isa<linalg::YieldOp>(linalgOp) || isa<linalg::IndexOp>(linalgOp)) {
@@ -175,7 +200,8 @@ static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
     // Build the symmetrical UnPackOp to the existing PackOp.
     unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
         packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
-        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles(),
+        maybePackedInit.getOuterDimsPerm()));
     results.push_back(unPackOps.back());
   }
