Skip to content

Commit 092c1ad

Browse files
committed
Lower polygeist.subindex through memref.reinterpret_cast
This should be a (hopefully) foolproof method of performing indexing into a memref. A reinterpret_cast is inserted with a dynamic offset calculated as the subindex index operand multiplied by the product of the sizes of the target type. Also re-enables the existing canonicalization tests.
1 parent 4325b34 commit 092c1ad

File tree

2 files changed

+33
-37
lines changed

2 files changed

+33
-37
lines changed

lib/polygeist/Ops.cpp

+16-26
Original file line numberDiff line numberDiff line change
@@ -176,43 +176,33 @@ class SubToCast final : public OpRewritePattern<SubIndexOp> {
176176
}
177177
};
178178

179-
// Simplify polygeist.subindex to memref.subview.
180-
class SubToSubView final : public OpRewritePattern<SubIndexOp> {
179+
// Simplify polygeist.subindex to a memref.reinterpret_cast.
180+
class SubToReinterpretCast final : public OpRewritePattern<SubIndexOp> {
181181
public:
182182
using OpRewritePattern<SubIndexOp>::OpRewritePattern;
183183

184184
LogicalResult matchAndRewrite(SubIndexOp op,
185185
PatternRewriter &rewriter) const override {
186186
auto srcMemRefType = op.source().getType().cast<MemRefType>();
187187
auto resMemRefType = op.result().getType().cast<MemRefType>();
188-
auto dims = srcMemRefType.getShape().size();
188+
auto shape = srcMemRefType.getShape();
189189

190-
// For now, restrict subview lowering to statically defined memref's
191-
if (!srcMemRefType.hasStaticShape() | !resMemRefType.hasStaticShape())
190+
if (!resMemRefType.hasStaticShape())
192191
return failure();
193192

194-
// For now, restrict to simple rank-reducing indexing
195-
if (srcMemRefType.getShape().size() <= resMemRefType.getShape().size())
196-
return failure();
193+
int64_t innerSize = resMemRefType.getNumElements();
194+
auto offset = rewriter.create<arith::MulIOp>(
195+
op.getLoc(), op.index(),
196+
rewriter.create<ConstantIndexOp>(op.getLoc(), innerSize));
197197

198-
// Build offset, sizes and strides
199-
SmallVector<OpFoldResult> sizes(dims, rewriter.getIndexAttr(0));
200-
sizes[0] = op.index();
201-
SmallVector<OpFoldResult> offsets(dims);
202-
for (auto dim : llvm::enumerate(srcMemRefType.getShape())) {
203-
if (dim.index() == 0)
204-
offsets[0] = rewriter.getIndexAttr(1);
205-
else
206-
offsets[dim.index()] = rewriter.getIndexAttr(dim.value());
198+
llvm::SmallVector<OpFoldResult> sizes, strides;
199+
for (auto dim : shape.drop_front()) {
200+
sizes.push_back(rewriter.getIndexAttr(dim));
201+
strides.push_back(rewriter.getIndexAttr(1));
207202
}
208-
SmallVector<OpFoldResult> strides(dims, rewriter.getIndexAttr(1));
209-
210-
// Generate the appropriate return type:
211-
auto subMemRefType = MemRefType::get(srcMemRefType.getShape().drop_front(),
212-
srcMemRefType.getElementType());
213203

214-
rewriter.replaceOpWithNewOp<memref::SubViewOp>(
215-
op, subMemRefType, op.source(), sizes, offsets, strides);
204+
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
205+
op, resMemRefType, op.source(), offset.getResult(), sizes, strides);
216206

217207
return success();
218208
}
@@ -677,8 +667,8 @@ void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
677667
MLIRContext *context) {
678668
results.insert<CastOfSubIndex, SubIndex2, SubToCast, SimplifySubViewUsers,
679669
SimplifySubIndexUsers, SelectOfCast, SelectOfSubIndex,
680-
RedundantDynSubIndex>(context);
681-
// Disabled: SubToSubView
670+
RedundantDynSubIndex, SubToReinterpretCast>(context);
671+
// Disabled:
682672
}
683673

684674
/// Simplify pointer2memref(memref2pointer(x)) to cast(x)

test/polygeist-opt/canonicalization.mlir

+17-11
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
// RUN: polygeist-opt --canonicalize --split-input-file %s | FileCheck %s
2-
// XFAIL: *
3-
// CHECK: func @main(%arg0: index) -> memref<30xi32> {
4-
// CHECK: %0 = memref.alloca() : memref<30x30xi32>
5-
// CHECK: %1 = memref.subview %0[%arg0, 0] [1, 30] [1, 1] : memref<30x30xi32> to memref<30xi32>
6-
// CHECK: return %1 : memref<30xi32>
7-
// CHECK: }
2+
3+
// CHECK-LABEL: func @main(
4+
// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<30xi32> {
5+
// CHECK: %[[VAL_1:.*]] = arith.constant 30 : index
6+
// CHECK: %[[VAL_2:.*]] = memref.alloca() : memref<30x30xi32>
7+
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : index
8+
// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_3]]], sizes: [30], strides: [1] : memref<30x30xi32> to memref<30xi32>
9+
// CHECK: return %[[VAL_4]] : memref<30xi32>
10+
// CHECK: }
811
module {
912
func @main(%arg0 : index) -> memref<30xi32> {
1013
%0 = memref.alloca() : memref<30x30xi32>
@@ -15,11 +18,14 @@ module {
1518

1619
// -----
1720

18-
// CHECK: func @main(%arg0: index) -> memref<1000xi32> {
19-
// CHECK: %0 = memref.alloca() : memref<2x1000xi32>
20-
// CHECK: %1 = memref.subview %0[%arg0, 0] [1, 1000] [1, 1] : memref<2x1000xi32> to memref<1000xi32>
21-
// CHECK: return %1 : memref<1000xi32>
22-
// CHECK: }
21+
// CHECK-LABEL: func @main(
22+
// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<1000xi32> {
23+
// CHECK: %[[VAL_1:.*]] = arith.constant 1000 : index
24+
// CHECK: %[[VAL_2:.*]] = memref.alloca() : memref<2x1000xi32>
25+
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : index
26+
// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_3]]], sizes: [1000], strides: [1] : memref<2x1000xi32> to memref<1000xi32>
27+
// CHECK: return %[[VAL_4]] : memref<1000xi32>
28+
// CHECK: }
2329
func @main(%arg0 : index) -> memref<1000xi32> {
2430
%c0 = arith.constant 0 : index
2531
%1 = memref.alloca() : memref<2x1000xi32>

0 commit comments

Comments
 (0)