@@ -176,43 +176,33 @@ class SubToCast final : public OpRewritePattern<SubIndexOp> {
176
176
}
177
177
};
178
178
179
- // Simplify polygeist.subindex to memref.subview .
180
- class SubToSubView final : public OpRewritePattern<SubIndexOp> {
179
+ // Simplify polygeist.subindex to a memref.reinterpret_cast .
180
+ class SubToReinterpretCast final : public OpRewritePattern<SubIndexOp> {
181
181
public:
182
182
using OpRewritePattern<SubIndexOp>::OpRewritePattern;
183
183
184
184
LogicalResult matchAndRewrite (SubIndexOp op,
185
185
PatternRewriter &rewriter) const override {
186
186
auto srcMemRefType = op.source ().getType ().cast <MemRefType>();
187
187
auto resMemRefType = op.result ().getType ().cast <MemRefType>();
188
- auto dims = srcMemRefType.getShape (). size ();
188
+ auto shape = srcMemRefType.getShape ();
189
189
190
- // For now, restrict subview lowering to statically defined memref's
191
- if (!srcMemRefType.hasStaticShape () | !resMemRefType.hasStaticShape ())
190
+ if (!resMemRefType.hasStaticShape ())
192
191
return failure ();
193
192
194
- // For now, restrict to simple rank-reducing indexing
195
- if (srcMemRefType.getShape ().size () <= resMemRefType.getShape ().size ())
196
- return failure ();
193
+ int64_t innerSize = resMemRefType.getNumElements ();
194
+ auto offset = rewriter.create <arith::MulIOp>(
195
+ op.getLoc (), op.index (),
196
+ rewriter.create <ConstantIndexOp>(op.getLoc (), innerSize));
197
197
198
- // Build offset, sizes and strides
199
- SmallVector<OpFoldResult> sizes (dims, rewriter.getIndexAttr (0 ));
200
- sizes[0 ] = op.index ();
201
- SmallVector<OpFoldResult> offsets (dims);
202
- for (auto dim : llvm::enumerate (srcMemRefType.getShape ())) {
203
- if (dim.index () == 0 )
204
- offsets[0 ] = rewriter.getIndexAttr (1 );
205
- else
206
- offsets[dim.index ()] = rewriter.getIndexAttr (dim.value ());
198
+ llvm::SmallVector<OpFoldResult> sizes, strides;
199
+ for (auto dim : shape.drop_front ()) {
200
+ sizes.push_back (rewriter.getIndexAttr (dim));
201
+ strides.push_back (rewriter.getIndexAttr (1 ));
207
202
}
208
- SmallVector<OpFoldResult> strides (dims, rewriter.getIndexAttr (1 ));
209
-
210
- // Generate the appropriate return type:
211
- auto subMemRefType = MemRefType::get (srcMemRefType.getShape ().drop_front (),
212
- srcMemRefType.getElementType ());
213
203
214
- rewriter.replaceOpWithNewOp <memref::SubViewOp >(
215
- op, subMemRefType , op.source (), sizes, offsets , strides);
204
+ rewriter.replaceOpWithNewOp <memref::ReinterpretCastOp >(
205
+ op, resMemRefType , op.source (), offset. getResult (), sizes , strides);
216
206
217
207
return success ();
218
208
}
@@ -677,8 +667,8 @@ void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
677
667
MLIRContext *context) {
678
668
results.insert <CastOfSubIndex, SubIndex2, SubToCast, SimplifySubViewUsers,
679
669
SimplifySubIndexUsers, SelectOfCast, SelectOfSubIndex,
680
- RedundantDynSubIndex>(context);
681
- // Disabled: SubToSubView
670
+ RedundantDynSubIndex, SubToReinterpretCast >(context);
671
+ // Disabled:
682
672
}
683
673
684
674
// / Simplify pointer2memref(memref2pointer(x)) to cast(x)
0 commit comments