
Commit ca0e2b6

support expand shape

1 parent 41cd05f · commit ca0e2b6

5 files changed: +205 −151 lines
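
In brief: this commit switches the analysis' std::cout tracing to LLVM_DEBUG over llvm::raw_ostream, renames the cached layout map to layoutCache, renames the PropagateLayout pass to PropagateLayoutOnNamedOps, adds a new Transforms.h declaring the packing entry points, and extends GlobalAnalysis to propagate layouts through tensor.pad and tensor.expand_shape.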

include/gc/Analysis/GlobalAnalysis.h (+17 −26)
@@ -18,6 +18,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMap.h"
+#include <llvm/Support/Debug.h>
 
 namespace mlir {
 namespace gc {
@@ -27,17 +28,9 @@ using namespace mlir;
 class TensorLayout {
 public:
   TensorLayout(ArrayRef<int64_t> outerAxis, ArrayRef<int64_t> innerAxis,
-               ArrayRef<OpFoldResult> tileSizes) {
+               ArrayRef<OpFoldResult> tileSizes)
+      : OuterAxis(outerAxis), InnerAxis(innerAxis), TileSizes(tileSizes) {
     assert(innerAxis.size() == tileSizes.size());
-    for (auto oa : outerAxis) {
-      OuterAxis.push_back(oa);
-    }
-    for (auto ia : innerAxis) {
-      InnerAxis.push_back(ia);
-    }
-    for (auto ts : tileSizes) {
-      TileSizes.push_back(ts);
-    }
   }
 
   bool isPlainLayout() const {
@@ -55,25 +48,22 @@ class TensorLayout {
                            SmallVector<OpFoldResult>{});
   }
 
-  static DenseMap<int64_t, SmallVector<int64_t>>
-  getPlain2PackedMapping(TensorLayout layout) {
+  DenseMap<int64_t, SmallVector<int64_t>> getPlain2PackedMapping() {
     DenseMap<int64_t, SmallVector<int64_t>> p2b;
-    SmallVector<int64_t> outerAxis = layout.getOuterAxis();
-    SmallVector<int64_t> innerAxis = layout.getInnerAxis();
-    for (size_t i = 0; i < outerAxis.size(); ++i) {
-      p2b[outerAxis[i]].push_back(i);
+    for (size_t i = 0; i < OuterAxis.size(); ++i) {
+      p2b[OuterAxis[i]].push_back(i);
     }
-    for (size_t i = 0; i < innerAxis.size(); ++i) {
-      p2b[innerAxis[i]].push_back(outerAxis.size() + i);
+    for (size_t i = 0; i < InnerAxis.size(); ++i) {
+      p2b[InnerAxis[i]].push_back(OuterAxis.size() + i);
     }
     return p2b;
   }
 
   FailureOr<int64_t> getOriginalAxis(int64_t idx) {
-    size_t totalRank = OuterAxis.size() + InnerAxis.size();
+    int64_t totalRank = OuterAxis.size() + InnerAxis.size();
     if (idx >= totalRank) {
       return failure("Index out of range.");
-    } else if (idx >= OuterAxis.size()) {
+    } else if (idx >= static_cast<int64_t>(OuterAxis.size())) {
       return InnerAxis[idx - OuterAxis.size()];
     } else {
       return OuterAxis[idx];
@@ -88,7 +78,8 @@ class TensorLayout {
 
   SmallVector<OpFoldResult> getTileSizes() const { return TileSizes; }
 
-  friend std::ostream &operator<<(std::ostream &ss, const TensorLayout &layout);
+  friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
+                                       const TensorLayout &layout);
 
   bool operator==(const TensorLayout &layout);

@@ -121,8 +112,8 @@ class OperatorLayout {
     return supportedOutputLayouts[idx];
   }
 
-  friend std::ostream &operator<<(std::ostream &ss,
-                                  const OperatorLayout &opLayout);
+  friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
+                                       const OperatorLayout &opLayout);
 
 private:
   SmallVector<TensorLayout> supportedInputLayouts;
@@ -134,14 +125,14 @@ class GlobalAnalysis {
   explicit GlobalAnalysis(Operation *root);
 
   FailureOr<OperatorLayout> getOpLayout(Operation *op) {
-    if (layout.find(op) != layout.end())
-      return layout[op];
+    if (layoutCache.find(op) != layoutCache.end())
+      return layoutCache[op];
     else
       return failure("Current op does not have layout information.");
   }
 
 private:
-  DenseMap<Operation *, OperatorLayout> layout;
+  DenseMap<Operation *, OperatorLayout> layoutCache;
 };
 
 } // namespace gc
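
For readers skimming the header, a minimal usage sketch of the layout encoding (illustrative only; the builder `b` and the variable names are assumptions, not part of this commit):

    // Layout [0, 1, 0, 1]; {32, 32}: outer axes keep the plain order and both
    // plain axes are tiled by 32, i.e. MxN -> (M/32) x (N/32) x 32 x 32.
    TensorLayout aLayout(
        /*outerAxis=*/{0, 1}, /*innerAxis=*/{0, 1},
        SmallVector<OpFoldResult>{b.getIndexAttr(32), b.getIndexAttr(32)});

    // Plain axis -> its positions in the packed tensor:
    //   axis 0 -> {0, 2}, axis 1 -> {1, 3}
    DenseMap<int64_t, SmallVector<int64_t>> p2b =
        aLayout.getPlain2PackedMapping();

    // And back: packed dim 3 originates from plain axis 1.
    FailureOr<int64_t> orig = aLayout.getOriginalAxis(3);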

include/gc/Transforms/Passes.td (+4 −3)
@@ -44,12 +44,13 @@ def GCCPUPipeline: Pass<"gc-cpu-pipeline"> {
                          "vector::VectorDialect"];
 }
 
-def PropagateLayout : Pass<"propagate-layout"> {
-  let summary = "Insert and propagate tensor.pack to pack the computation of general linalg named ops and tensor ops.";
+def PropagateLayoutOnNamedOps : Pass<"propagate-layout-on-named-ops"> {
+  let summary = "Insert and propagate tensor.pack to pack the computation of linalg named ops and tensor ops.";
   let description = [{
     Insert and propagate tensor.pack
   }];
-  let dependentDialects = ["mlir::tensor::TensorDialect", "mlir::linalg::LinalgDialect"];
+  let dependentDialects = ["mlir::tensor::TensorDialect",
+                           "mlir::linalg::LinalgDialect"];
 }
 
 #endif // GC_DIALECT_GC_PASSES
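
A sketch of driving the renamed pass from C++. The factory name below follows the usual tablegen convention (create<PassName>()) but is an assumption; this commit only renames the pass definition:

    #include "gc/Transforms/Passes.h" // assumed location of generated decls
    #include "mlir/Pass/PassManager.h"

    // Run layout propagation over a module; `module` is an existing ModuleOp.
    void runLayoutPropagation(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      // Hypothetical factory; tablegen typically emits one per pass def.
      pm.addPass(mlir::gc::createPropagateLayoutOnNamedOps());
      if (mlir::failed(pm.run(module)))
        llvm::errs() << "propagate-layout-on-named-ops failed\n";
    }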

include/gc/Transforms/Transforms.h (+28, new file)
@@ -0,0 +1,28 @@
+//===- Transforms.h - transformation utilities ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GC_TRANSFORMS_TRANSFORMS_H
+#define GC_TRANSFORMS_TRANSFORMS_H
+
+#include "gc/Analysis/GlobalAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+namespace mlir {
+namespace gc {
+FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
+                                          linalg::LinalgOp linalgOp,
+                                          OperatorLayout opLayout);
+
+LogicalResult namedOpLayoutPropagation(RewriterBase &rewriter,
+                                       linalg::LinalgOp linalgOp,
+                                       OperatorLayout opLayout);
+} // namespace gc
+} // namespace mlir
+
+#endif // GC_TRANSFORMS_TRANSFORMS_H
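
A hedged sketch of how these declarations might be used together with GlobalAnalysis; the commit only introduces the header, so the call sites below are assumptions:

    // Assuming `analysis` is a GlobalAnalysis built from the enclosing module
    // and `linalgOp` is a linalg::LinalgOp visited during a walk.
    IRRewriter rewriter(linalgOp.getOperation()->getContext());
    FailureOr<OperatorLayout> opLayout = analysis.getOpLayout(linalgOp);
    if (succeeded(opLayout)) {
      // Rewrites the op to consume/produce packed operands via tensor.pack.
      FailureOr<linalg::PackResult> packed =
          gc::packNamedOp(rewriter, linalgOp, *opLayout);
      if (failed(packed))
        linalgOp.getOperation()->emitWarning("packing was not applied");
    }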

lib/gc/Analysis/GlobalAnalysis.cpp (+86 −51)
@@ -12,10 +12,13 @@
 namespace mlir {
 namespace gc {
 
-std::ostream &operator<<(std::ostream &ss, const TensorLayout &layout) {
+#define DEBUG_TYPE "global-analysis"
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
+                              const TensorLayout &layout) {
   SmallVector<int64_t> outerAxis = layout.getOuterAxis();
   SmallVector<int64_t> innerAxis = layout.getInnerAxis();
   SmallVector<OpFoldResult> tileSizes = layout.getTileSizes();
   ss << "[";
   for (size_t i = 0; i < outerAxis.size(); ++i) {
     if (i != 0) {
@@ -43,21 +46,21 @@ std::ostream &operator<<(std::ostream &ss, const TensorLayout &layout) {
   return ss;
 }
 
 bool TensorLayout::operator==(const TensorLayout &layout) {
   return (this->OuterAxis == layout.getOuterAxis()) &&
          (this->InnerAxis == layout.getInnerAxis()) &&
          (this->TileSizes == layout.getTileSizes());
 }
 
-std::ostream &operator<<(std::ostream &ss, const OperatorLayout &opLayout) {
-  ss << "operator has " << opLayout.getSupportedInputLayouts().size()
-     << " inputs; " << opLayout.getSupportedOutputLayouts().size()
-     << " outputs." << std::endl;
-  for (const auto &layout : opLayout.getSupportedInputLayouts()) {
-    ss << "input layout: " << layout << std::endl;
+llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
+                              const OperatorLayout &opLayout) {
+  for (auto &&[idx, layout] :
+       llvm::enumerate(opLayout.getSupportedInputLayouts())) {
+    ss << "input " << idx << "'s layout: " << layout << "\n";
   }
-  for (const auto &layout : opLayout.getSupportedOutputLayouts()) {
-    ss << "output layout: " << layout << std::endl;
+  for (auto &&[idx, layout] :
+       llvm::enumerate(opLayout.getSupportedOutputLayouts())) {
+    ss << "output " << idx << "'s layout: " << layout << "\n";
   }
   return ss;
 }
@@ -119,7 +122,6 @@ getReversedIndexMap(const DenseMap<int64_t, int64_t> &indexMap,
 static FailureOr<TensorLayout>
 inferTargetLayout(TensorLayout layoutBase,
                   const DenseMap<int64_t, int64_t> &indexMap) {
-  int64_t dimDifference = indexMap.size() - layoutBase.getTensorRank();
   SmallVector<int64_t> baseOuterAxis = layoutBase.getOuterAxis();
   SmallVector<int64_t> baseInnerAxis = layoutBase.getInnerAxis();
   SmallVector<OpFoldResult> baseTileSizes = layoutBase.getTileSizes();
@@ -153,38 +155,24 @@
 
 GlobalAnalysis::GlobalAnalysis(Operation *root) {
   root->walk([&](Operation *op) {
+    // get input layouts
+    LLVM_DEBUG(llvm::dbgs()
+               << "Inferring layout of op: " << op->getName() << "\n");
     if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
-      // get input layouts
-      std::cout << std::endl;
-      std::cout << "----------------------------------" << std::endl;
-      linalgOp.getOperation()->getName().print(llvm::errs());
-      std::cout << std::endl;
-      std::cout << "----------------------------------" << std::endl;
-      std::cout << std::endl;
-      SmallVector<AffineMap> indexing_maps = linalgOp.getIndexingMapsArray();
       auto curInputs = linalgOp.getDpsInputOperands();
       auto curResults = linalgOp.getOperation()->getResults();
-
       // ---------------- Get Current Input Layouts -------------------
-      // get current input layouts
-      std::cout << "----- printing ground-truth input layouts -----"
-                << std::endl;
       SmallVector<TensorLayout> curInputLayouts;
       for (auto input : curInputs) {
         auto parent = input->get().getDefiningOp();
-        if (layout.find(parent) != layout.end()) {
+        if (layoutCache.find(parent) != layoutCache.end()) {
           // TODO(yifei): it is not always 0 here
-          curInputLayouts.push_back(layout[parent].getOutputLayout(0));
+          curInputLayouts.push_back(layoutCache[parent].getOutputLayout(0));
         } else {
           curInputLayouts.push_back(TensorLayout::createPlainLayout(
               linalgOp.getMatchingIndexingMap(input).getNumResults()));
         }
       }
-      // debug info
-      for (auto layout : curInputLayouts) {
-        std::cout << "layout: " << layout << std::endl;
-      }
-
       // ------ Get Current Op's Suggested Layout & Do Propagation ------
       IRRewriter rewriter(linalgOp);
       if (mlir::linalg::isaContractionOpInterface(linalgOp)) {
@@ -193,38 +181,33 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
         // curInputLayouts);
 
         // hardcode one for now
         // A side layout, [0, 1, 0, 1]; {32, 32}
         TensorLayout A_layout(
             {0, 1}, {0, 1},
             SmallVector<OpFoldResult>{rewriter.getIndexAttr(32),
                                       rewriter.getIndexAttr(32)});
         // B side layout, [1, 0, 0, 1]; {32, 32}
         TensorLayout B_layout(
             {1, 0}, {0, 1},
             SmallVector<OpFoldResult>{rewriter.getIndexAttr(32),
                                       rewriter.getIndexAttr(32)});
         // C side layout, [0, 1, 0, 1]; {32, 32}
         TensorLayout C_layout(
             {0, 1}, {0, 1},
             SmallVector<OpFoldResult>{rewriter.getIndexAttr(32),
                                       rewriter.getIndexAttr(32)});
         OperatorLayout suggestedLayout({A_layout, B_layout}, {C_layout});
-        layout[linalgOp] = suggestedLayout;
+        layoutCache[linalgOp] = suggestedLayout;
       } else {
         SmallVector<TensorLayout> inputLayouts, outputLayouts;
         inputLayouts.push_back(curInputLayouts[0]);
         // TODO(yifei): wisely choose the input format basis
         // Let's only refer to input[0] for now
         for (size_t i = 1; i < curInputs.size(); ++i) {
-          std::cout << "inferring indexing map relation" << std::endl;
-          // getMatchingIndexingMap
           auto res = inferIndexingMapRelation(
               linalgOp.getMatchingIndexingMap(curInputs[0]),
               linalgOp.getMatchingIndexingMap(curInputs[i]));
-          for (auto tp : *res) {
-            std::cout << "target index: " << tp.first
-                      << " maps to base index: " << tp.second << std::endl;
-          }
           TensorLayout inputLayout =
               *inferTargetLayout(curInputLayouts[0], *res);
           inputLayouts.push_back(inputLayout);
@@ -235,14 +218,66 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
       TensorLayout outputLayout =
           *inferTargetLayout(curInputLayouts[0], *res_out);
       outputLayouts.push_back(outputLayout);
-      for (auto tp : *res_out) {
-        std::cout << "target index: " << tp.first
-                  << " maps to base index: " << tp.second << std::endl;
-      }
       OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
-      layout[linalgOp] = suggestedLayout;
-      }
-    } else if (isa<tensor::PadOp>(op) || isa<tensor::ExpandShapeOp>(op)) {
+      layoutCache[linalgOp] = suggestedLayout;
+      }
+    } else if (auto padOp = dyn_cast<tensor::PadOp>(op)) {
+      auto inputOperand = padOp.getSource();
+      auto inputRank =
+          cast<ShapedType>(inputOperand.getType()).getShape().size();
+      auto parent = inputOperand.getDefiningOp();
+      TensorLayout curInputLayout =
+          layoutCache.find(parent) != layoutCache.end()
+              ? layoutCache[parent].getOutputLayout(0)
+              : TensorLayout::createPlainLayout(inputRank);
+      SmallVector<TensorLayout> inputLayouts{curInputLayout},
+          outputLayouts{curInputLayout};
+      OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
+      layoutCache[padOp] = suggestedLayout;
+    } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
+      auto reassociation = expandShapeOp.getReassociation();
+      auto staticOutputShape = expandShapeOp.getStaticOutputShape();
+      auto parent = expandShapeOp.getSrc().getDefiningOp();
+      auto inputShape = expandShapeOp.getSrcType().getShape();
+      TensorLayout curInputLayout =
+          layoutCache.find(parent) != layoutCache.end()
+              ? layoutCache[parent].getOutputLayout(0)
+              : TensorLayout::createPlainLayout(inputShape.size());
+      DenseMap<int64_t, int64_t> outputInputIdxMapping, inputOutputIndexMapping;
+      int64_t accumulationOffset = 0;
+      for (int64_t i = 0; i < static_cast<int64_t>(reassociation.size()); ++i) {
+        auto subReassociation = llvm::cast<ArrayAttr>(reassociation[i]);
+        for (int64_t j = 0; j < static_cast<int64_t>(subReassociation.size());
+             ++j) {
+          if (staticOutputShape[accumulationOffset + j] == inputShape[i]) {
+            outputInputIdxMapping[accumulationOffset + j] = i;
+            inputOutputIndexMapping[i] = accumulationOffset + j;
+          }
+        }
+        accumulationOffset += subReassociation.size();
+      }
+      auto inputOuterAxis = curInputLayout.getOuterAxis();
+      auto inputInnerAxis = curInputLayout.getInnerAxis();
+      int64_t startIdx = 0;
+      SmallVector<int64_t> outputOuterAxis, outputInnerAxis;
+      for (int64_t i = 0; i < static_cast<int64_t>(staticOutputShape.size());
+           ++i) {
+        if (outputInputIdxMapping.find(i) != outputInputIdxMapping.end()) {
+          outputOuterAxis.push_back(inputOuterAxis[outputInputIdxMapping[i]]);
+        } else {
+          outputOuterAxis.push_back(startIdx++);
+        }
+      }
+      for (int64_t i = 0; i < static_cast<int64_t>(inputInnerAxis.size());
+           ++i) {
+        outputInnerAxis.push_back(inputOutputIndexMapping[inputInnerAxis[i]]);
+      }
+      TensorLayout outputLayout(outputOuterAxis, outputInnerAxis,
+                                curInputLayout.getTileSizes());
+      SmallVector<TensorLayout> inputLayouts{curInputLayout},
+          outputLayouts{outputLayout};
+      OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
+      layoutCache[expandShapeOp] = suggestedLayout;
     }
   });
 }
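
To make the new tensor.expand_shape handling concrete, here is a standalone sketch of its dimension-matching step (an illustrative re-implementation; the function and variable names are not from the commit):

    #include <cstdint>
    #include <map>
    #include <vector>

    // An output dim is matched to an input dim when their static sizes agree
    // within a reassociation group; unmatched output dims are new unit dims.
    std::map<int64_t, int64_t> matchExpandedDims(
        const std::vector<std::vector<int64_t>> &reassociation,
        const std::vector<int64_t> &inputShape,
        const std::vector<int64_t> &outputShape) {
      std::map<int64_t, int64_t> outputToInput;
      int64_t offset = 0;
      for (int64_t i = 0; i < static_cast<int64_t>(reassociation.size()); ++i) {
        for (int64_t j = 0; j < static_cast<int64_t>(reassociation[i].size());
             ++j) {
          if (outputShape[offset + j] == inputShape[i])
            outputToInput[offset + j] = i;
        }
        offset += static_cast<int64_t>(reassociation[i].size());
      }
      return outputToInput;
    }

For example, expanding tensor<128x128xf32> to tensor<1x128x128xf32> with reassociation [[0, 1], [2]] yields {1 -> 0, 2 -> 1}; output dim 0 is the new unit dim, so the input's inner tiling carries over to output dims 1 and 2 while the unit dim stays untiled.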
