Skip to content

Commit ff96f5c

Browse files
committed
Also handle non-rank reducing offset operations
1 parent d2c7b3b commit ff96f5c

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

lib/polygeist/Passes/LowerPolygeistOps.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,20 @@ struct SubIndexToReinterpretCast
3737
ConversionPatternRewriter &rewriter) const override {
3838
auto srcMemRefType = op.source().getType().cast<MemRefType>();
3939
auto resMemRefType = op.result().getType().cast<MemRefType>();
40-
auto shape = srcMemRefType.getShape();
40+
auto inShape = srcMemRefType.getShape();
41+
auto outShape = resMemRefType.getShape();
4142

4243
if (!resMemRefType.hasStaticShape())
4344
return failure();
4445

46+
llvm::SmallVector<OpFoldResult> strides, sizes;
4547
int64_t innerSize = resMemRefType.getNumElements();
4648
auto offset = rewriter.create<arith::MulIOp>(
4749
op.getLoc(), op.index(),
4850
rewriter.create<ConstantIndexOp>(op.getLoc(), innerSize));
4951

50-
llvm::SmallVector<OpFoldResult> sizes, strides;
5152
int64_t strideAcc = 1;
52-
for (auto dim : llvm::reverse(shape.drop_front())) {
53+
for (auto dim : llvm::reverse(outShape)) {
5354
sizes.insert(sizes.begin(), rewriter.getIndexAttr(dim));
5455
strides.insert(strides.begin(), rewriter.getIndexAttr(strideAcc));
5556
strideAcc *= dim;

test/polygeist-opt/lower_polygeist_ops.mlir

+17-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ func @main(%arg0 : index) -> memref<30xi32> {
1414
return %1 : memref<30xi32>
1515
}
1616

17-
1817
// -----
1918

2019
// CHECK-LABEL: func @main(
@@ -30,3 +29,20 @@ func @main(%arg0 : index) -> memref<42x43x44x45xi32> {
3029
%1 = "polygeist.subindex"(%0, %arg0) : (memref<41x42x43x44x45xi32>, index) -> memref<42x43x44x45xi32>
3130
return %1 : memref<42x43x44x45xi32>
3231
}
32+
33+
// -----
34+
35+
// CHECK-LABEL: func @main(
36+
// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<29x30xi32> {
37+
// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<30x30xi32>
38+
// CHECK: %[[VAL_2:.*]] = arith.constant 870 : index
39+
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index
40+
// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [29, 30], strides: [30, 1] : memref<30x30xi32> to memref<29x30xi32>
41+
// CHECK: return %[[VAL_4]] : memref<29x30xi32>
42+
// CHECK: }
43+
44+
func @main(%arg0 : index) -> memref<29x30xi32> {
45+
%0 = memref.alloca() : memref<30x30xi32>
46+
%1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<29x30xi32>
47+
return %1 : memref<29x30xi32>
48+
}

0 commit comments

Comments
 (0)