Skip to content

Commit dab65e8

Browse files
committed
Extend polygeist subindex lowering to multidim memrefs
1 parent fb01c62 commit dab65e8

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

lib/polygeist/Passes/LowerPolygeistOps.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,11 @@ struct SubIndexToReinterpretCast
4646
rewriter.create<ConstantIndexOp>(op.getLoc(), innerSize));
4747

4848
llvm::SmallVector<OpFoldResult> sizes, strides;
49-
for (auto dim : shape.drop_front()) {
50-
sizes.push_back(rewriter.getIndexAttr(dim));
51-
strides.push_back(rewriter.getIndexAttr(1));
49+
int64_t strideAcc = 1;
50+
for (auto dim : llvm::reverse(shape.drop_front())) {
51+
sizes.insert(sizes.begin(), rewriter.getIndexAttr(dim));
52+
strides.insert(strides.begin(), rewriter.getIndexAttr(strideAcc));
53+
strideAcc *= dim;
5254
}
5355

5456
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(

test/polygeist-opt/lower_polygeist_ops.mlir

+21-6
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,25 @@
88
// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [30], strides: [1] : memref<30x30xi32> to memref<30xi32>
99
// CHECK: return %[[VAL_4]] : memref<30xi32>
1010
// CHECK: }
11-
module {
12-
func @main(%arg0 : index) -> memref<30xi32> {
13-
%0 = memref.alloca() : memref<30x30xi32>
14-
%1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<30xi32>
15-
return %1 : memref<30xi32>
16-
}
11+
func @main(%arg0 : index) -> memref<30xi32> {
12+
%0 = memref.alloca() : memref<30x30xi32>
13+
%1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<30xi32>
14+
return %1 : memref<30xi32>
15+
}
16+
17+
18+
// -----
19+
20+
// CHECK-LABEL: func @main(
21+
// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<42x43x44x45xi32> {
22+
// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<41x42x43x44x45xi32>
23+
// CHECK: %[[VAL_2:.*]] = arith.constant 3575880 : index
24+
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index
25+
// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [42, 43, 44, 45], strides: [85140, 1980, 45, 1] : memref<41x42x43x44x45xi32> to memref<42x43x44x45xi32>
26+
// CHECK: return %[[VAL_4]] : memref<42x43x44x45xi32>
27+
// CHECK: }
28+
func @main(%arg0 : index) -> memref<42x43x44x45xi32> {
29+
%0 = memref.alloca() : memref<41x42x43x44x45xi32>
30+
%1 = "polygeist.subindex"(%0, %arg0) : (memref<41x42x43x44x45xi32>, index) -> memref<42x43x44x45xi32>
31+
return %1 : memref<42x43x44x45xi32>
1732
}

0 commit comments

Comments
 (0)