Commit 2462342

add logic for reduce and broadcast
1 parent 51d1d16 commit 2462342

2 files changed (+41 -9)


lib/gc/Analysis/GlobalAnalysis.cpp (+4 -4)
@@ -1,5 +1,4 @@
-//===- GlobalAnalysis.cpp - Propagate pack unpack on linalg named ops --*- C++
-//-*-===//
+//===- GlobalAnalysis.cpp - Propagate packing on linalg named ops *- C++-*-===//
 //
 // This file is only temporarily used to extend upstream or upcoming utility in
 // TilingInterface, which finally aims for upstream.
@@ -101,7 +100,8 @@ inferIndexingMapRelation(AffineMap indexingMapBase,
 
 // given j --> i and max rank of i, return i --> j
 static DenseMap<int64_t, int64_t>
-getReversedIndexMap(DenseMap<int64_t, int64_t> indexMap, size_t maxRank) {
+getReversedIndexMap(const DenseMap<int64_t, int64_t> &indexMap,
+                    size_t maxRank) {
   DenseMap<int64_t, int64_t> res;
   for (auto pair : indexMap) {
     if (pair.second >= 0) {
@@ -118,7 +118,7 @@ getReversedIndexMap(DenseMap<int64_t, int64_t> indexMap, size_t maxRank) {
 
 static FailureOr<TensorLayout>
 inferTargetLayout(TensorLayout layoutBase,
-                  DenseMap<int64_t, int64_t> indexMap) {
+                  const DenseMap<int64_t, int64_t> &indexMap) {
   int64_t dimDifference = indexMap.size() - layoutBase.getTensorRank();
   SmallVector<int64_t> baseOuterAxis = layoutBase.getOuterAxis();
   SmallVector<int64_t> baseInnerAxis = layoutBase.getInnerAxis();
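
Both hunks above are mechanical const-correctness fixes: the DenseMap arguments are now taken by const reference instead of by value, which avoids copying the map on every call. For readers without the rest of the file, here is a minimal standalone sketch of what getReversedIndexMap plausibly computes, inferred from the pair.second >= 0 guard; the -1 sentinel for unmapped dims is an assumption, not project code:

#include "llvm/ADT/DenseMap.h"
#include <cstdint>

// Sketch only: invert a dim-correspondence map j --> i into i --> j over the
// target's full rank. Unmapped target dims get -1 (assumed sentinel).
static llvm::DenseMap<int64_t, int64_t>
reverseIndexMapSketch(const llvm::DenseMap<int64_t, int64_t> &indexMap,
                      size_t maxRank) {
  llvm::DenseMap<int64_t, int64_t> res;
  for (const auto &pair : indexMap)
    if (pair.second >= 0) // skip source dims with no target counterpart
      res[pair.second] = pair.first;
  for (size_t i = 0; i < maxRank; ++i)
    res.try_emplace(static_cast<int64_t>(i), -1); // keeps existing entries
  return res;
}

For indexMap = {0 -> 2, 1 -> 0} and maxRank = 3, the sketch returns {0 -> 1, 1 -> -1, 2 -> 0}.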

lib/gc/Transforms/PropagateLayout.cpp (+37 -5)
@@ -31,6 +31,19 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::tensor;
 
+static SmallVector<int64_t> getPackedAxes(ArrayRef<int64_t> dimensions,
+                                          TensorLayout targetLayout) {
+  SmallVector<int64_t> result(dimensions);
+  auto innerPos = targetLayout.getInnerAxis();
+  for (size_t i = 0; i < dimensions.size(); ++i) {
+    auto it = std::find(innerPos.begin(), innerPos.end(), dimensions[i]);
+    if (it != innerPos.end())
+      result.push_back(std::distance(innerPos.begin(), it) +
+                       targetLayout.getOuterAxis().size());
+  }
+  return result;
+}
+
 static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
                                                  linalg::LinalgOp linalgOp,
                                                  OperatorLayout opLayout) {
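
A note on the getPackedAxes helper added above: when a reduction dim is tiled, i.e. it appears in the layout's inner-axis list, the reduce over the packed tensor must also cover that dim's tile axis, which sits at outer-rank + (the dim's position within the inner-axis list). The self-contained illustration below mirrors that logic with plain containers standing in for TensorLayout; the semantics assumed for getOuterAxis()/getInnerAxis() (outer permutation, list of tiled source dims in order) are inferred, not confirmed by the diff:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Standalone illustration of the getPackedAxes logic (not project code).
static std::vector<int64_t>
packedAxesSketch(const std::vector<int64_t> &dimensions,
                 const std::vector<int64_t> &outerAxis,
                 const std::vector<int64_t> &innerAxis) {
  std::vector<int64_t> result(dimensions);
  for (int64_t dim : dimensions) {
    auto it = std::find(innerAxis.begin(), innerAxis.end(), dim);
    if (it != innerAxis.end()) // dim is tiled: also reduce its tile axis
      result.push_back(static_cast<int64_t>(outerAxis.size()) +
                       (it - innerAxis.begin()));
  }
  return result;
}

int main() {
  // A 128x64 tensor whose dim 1 is tiled by 32 packs to 128x2x32
  // (outer axes {0, 1}, inner axis {1}), so reducing dim 1 of the
  // original tensor reduces dims {1, 2} of the packed tensor.
  for (int64_t axis : packedAxesSketch({1}, {0, 1}, {1}))
    std::printf("%lld ", static_cast<long long>(axis));
  std::printf("\n"); // prints: 1 2
}

This positioning is also why the helper records std::distance(innerPos.begin(), it) rather than the loop index: a dim's slot in the reduction-dims list need not match its slot in the inner-axis list.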
@@ -125,10 +138,30 @@ static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
       ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
   ValueRange inits =
       ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
-  // TODO(yifei): the axis info of reduce/broadcast/transpose may change
-  auto packedLinalgOp = mlir::clone(
-      rewriter, linalgOp, SmallVector<Type>{inputsAndInits.back().getType()},
-      inputsAndInits);
+  // TODO(yifei): deal with transpose
+  // TODO(yifei): deal with generic
+  linalg::LinalgOp packedLinalgOp;
+  if (auto reduceOp = dyn_cast<linalg::ReduceOp>(linalgOp)) {
+    SmallVector<int64_t> packedAxes =
+        getPackedAxes(reduceOp.getDimensions(), inputLayouts[0]);
+    packedLinalgOp = rewriter.create<linalg::ReduceOp>(
+        loc, inits.getTypes(), inputs, inits, packedAxes);
+    packedLinalgOp->getRegion(0).takeBody(linalgOp->getRegion(0));
+  } else if (auto broadcastOp = dyn_cast<linalg::BroadcastOp>(linalgOp)) {
+    packedLinalgOp = rewriter.create<linalg::BroadcastOp>(
+        loc, inputs[0], inits[0], broadcastOp.getDimensions());
+  } else if (isa<linalg::TransposeOp>(linalgOp)) {
+    return rewriter.notifyMatchFailure(linalgOp, "TODO: remove transpose op");
+  } else if (isa<linalg::SoftmaxOp>(linalgOp) ||
+             isa<linalg::GenericOp>(linalgOp) || isa<linalg::MapOp>(linalgOp) ||
+             isa<linalg::YieldOp>(linalgOp) || isa<linalg::IndexOp>(linalgOp)) {
+    return rewriter.notifyMatchFailure(
+        linalgOp, "Packing not implemented for SoftMax/Generic/Map/Yield/Index.");
+  } else {
+    packedLinalgOp = mlir::clone(
+        rewriter, linalgOp, SmallVector<Type>{inputsAndInits.back().getType()},
+        inputsAndInits);
+  }
 
   // Step 4. Unpack all the op results.
   for (OpResult result : packedLinalgOp->getResults()) {
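
Two details in the dispatch above deserve a note. The reduce path creates the new linalg.reduce with remapped axes and then moves the original combiner over with Region::takeBody, which splices the blocks instead of cloning them. The bail-out paths use RewriterBase::notifyMatchFailure, since mlir::failure() accepts no message: it records a diagnostic and returns a failed LogicalResult, which converts into the failed FailureOr<linalg::PackResult> this function returns. A compressed sketch of both patterns with hypothetical names (not project code):

#include "mlir/IR/PatternMatch.h"

// Assumes newOp was just created with an empty region and replaces oldOp.
static mlir::FailureOr<mlir::Operation *>
rebuildWithBody(mlir::RewriterBase &rewriter, mlir::Operation *oldOp,
                mlir::Operation *newOp) {
  if (oldOp->getNumRegions() != 1)
    return rewriter.notifyMatchFailure(oldOp, "expected a single region");
  // Splice the old op's blocks into the new op; no cloning involved.
  newOp->getRegion(0).takeBody(oldOp->getRegion(0));
  return newOp;
}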
@@ -195,7 +228,6 @@ void PropagateLayout::runOnOperation() {
         FailureOr<linalg::PackResult> packedOp =
            packNamedOp(rewriter, linalgOp, *opLayout);
       }
-      graph->dump();
     }
   }
   graph->dump();
