Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lower polygeist.subindex through memref.reinterpret_cast #147

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/polygeist/Passes/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ std::unique_ptr<Pass> createParallelLowerPass();
std::unique_ptr<Pass>
createConvertPolygeistToLLVMPass(const LowerToLLVMOptions &options);
std::unique_ptr<Pass> createConvertPolygeistToLLVMPass();
std::unique_ptr<Pass> createLowerPolygeistOpsPass();

} // namespace polygeist
} // namespace mlir
Expand Down
6 changes: 6 additions & 0 deletions include/polygeist/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ def RemoveTrivialUse : Pass<"trivialuse"> {
let constructor = "mlir::polygeist::createRemoveTrivialUsePass()";
}

// Pass that lowers Polygeist dialect operations to upstream MLIR operations.
// Both dialects whose operations the lowering may create are declared as
// dependent so that they are loaded into the context before the pass runs:
// memref (memref.reinterpret_cast) and arith (arith.constant, arith.muli).
def LowerPolygeistOps : Pass<"lower-polygeist-ops"> {
  let summary = "Lower polygeist ops to memref operations";
  let constructor = "mlir::polygeist::createLowerPolygeistOpsPass()";
  let dependentDialects = [
    "::mlir::arith::ArithmeticDialect",
    "::mlir::memref::MemRefDialect"
  ];
}

def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert scalar and vector operations from the Standard to the "
"LLVM dialect";
Expand Down
42 changes: 0 additions & 42 deletions lib/polygeist/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,48 +529,6 @@ class SubToCast final : public OpRewritePattern<SubIndexOp> {
}
};

// Rewrite a rank-reducing polygeist.subindex on statically shaped memrefs as
// a memref.subview.
class SubToSubView final : public OpRewritePattern<SubIndexOp> {
public:
  using OpRewritePattern<SubIndexOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubIndexOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceType = op.source().getType().cast<MemRefType>();
    auto resultType = op.result().getType().cast<MemRefType>();
    auto sourceShape = sourceType.getShape();
    auto rank = sourceShape.size();

    // Bail out unless both memref types are fully static.
    if (!(sourceType.hasStaticShape() && resultType.hasStaticShape()))
      return failure();

    // Bail out unless the result has strictly lower rank than the source.
    if (rank <= resultType.getShape().size())
      return failure();

    // First operand list handed to the subview builder: the subindex index
    // for the leading dimension and 0 for every other dimension.
    SmallVector<OpFoldResult> indexThenZeros(rank, rewriter.getIndexAttr(0));
    indexThenZeros[0] = op.index();

    // Second operand list: 1 for the leading dimension and the source extent
    // for each remaining dimension.
    SmallVector<OpFoldResult> oneThenExtents;
    oneThenExtents.push_back(rewriter.getIndexAttr(1));
    for (int64_t extent : sourceShape.drop_front())
      oneThenExtents.push_back(rewriter.getIndexAttr(extent));

    // Third operand list: 1 for every dimension.
    SmallVector<OpFoldResult> allOnes(rank, rewriter.getIndexAttr(1));

    // The replacement has the source shape without its leading dimension.
    auto reducedType = MemRefType::get(sourceShape.drop_front(),
                                       sourceType.getElementType());

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        op, reducedType, op.source(), indexThenZeros, oneThenExtents, allOnes);

    return success();
  }
};

// Simplify redundant dynamic subindex patterns which tries to represent
// rank-reducing indexing:
// %3 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) ->
Expand Down
1 change: 1 addition & 0 deletions lib/polygeist/Passes/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms
TrivialUse.cpp
ConvertPolygeistToLLVM.cpp
InnerSerialization.cpp
LowerPolygeistOps.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
Expand Down
90 changes: 90 additions & 0 deletions lib/polygeist/Passes/LowerPolygeistOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//===- LowerPolygeistOps.cpp - Lower polygeist ops to upstream MLIR ops -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass which lowers any remaining Polygeist dialect
// operations (after canonicalization) to operations found in upstream MLIR
// dialects.
//
//===----------------------------------------------------------------------===//
#include "PassDetails.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "polygeist/Dialect.h"
#include "polygeist/Ops.h"

using namespace mlir;
using namespace polygeist;
using namespace mlir::arith;

namespace {

/// Lowers `polygeist.subindex` with a statically shaped result to
/// `memref.reinterpret_cast`.
///
/// The replacement views the (converted) source with the result's static
/// sizes, dense row-major strides computed from those sizes, and a dynamic
/// offset of `index * <number of elements in the result>`.
///
/// Returns failure (with a diagnostic note) when the result shape is not
/// fully static.
///
/// NOTE(review): when source and result have the same rank, the multiplier is
/// still the result's full element count (e.g. 29 * 30 for
/// memref<30x30> -> memref<29x30>), not the extent of the source's inner
/// dimensions - confirm this matches the intended subindex semantics.
/// NOTE(review): the offset is computed from the index alone; any offset the
/// source value already carries is not added here - confirm chained subindex
/// operations remain correct.
struct SubIndexToReinterpretCast
    : public OpConversionPattern<polygeist::SubIndexOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(polygeist::SubIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resMemRefType = op.result().getType().cast<MemRefType>();

    // Sizes and strides are emitted as static attributes, so every result
    // dimension must be known.
    if (!resMemRefType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "expected result memref with static shape");

    auto loc = op.getLoc();
    auto outShape = resMemRefType.getShape();

    // offset = index * numElements(result). The constant is created before
    // the multiplication so the emitted op order is constant, then muli.
    Value innerSize =
        rewriter.create<ConstantIndexOp>(loc, resMemRefType.getNumElements());
    // Use the adaptor operands: inside a conversion pattern these are the
    // values already remapped by the conversion framework.
    auto offset =
        rewriter.create<arith::MulIOp>(loc, adaptor.index(), innerSize);

    // Fill sizes and row-major strides from the innermost dimension outwards.
    size_t rank = outShape.size();
    llvm::SmallVector<OpFoldResult> sizes(rank), strides(rank);
    int64_t strideAcc = 1;
    for (size_t i = rank; i-- > 0;) {
      sizes[i] = rewriter.getIndexAttr(outShape[i]);
      strides[i] = rewriter.getIndexAttr(strideAcc);
      strideAcc *= outShape[i];
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resMemRefType, adaptor.source(), offset.getResult(), sizes,
        strides);

    return success();
  }
};

/// Pass that runs a partial dialect conversion in which the Polygeist dialect
/// is illegal and the arith and memref dialects are legal. The pass signals
/// failure if the conversion fails.
struct LowerPolygeistOpsPass
    : public LowerPolygeistOpsBase<LowerPolygeistOpsPass> {

  void runOnOperation() override {
    auto root = getOperation();
    MLIRContext *context = root->getContext();

    // Legality: Polygeist ops must be rewritten; arith/memref ops may stay.
    ConversionTarget target(*context);
    target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>();
    target.addIllegalDialect<polygeist::PolygeistDialect>();

    RewritePatternSet patterns(context);
    patterns.add<SubIndexToReinterpretCast>(context);

    if (failed(applyPartialConversion(root, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

namespace mlir {
namespace polygeist {
/// Factory for the Polygeist-ops lowering pass: returns a newly allocated
/// LowerPolygeistOpsPass owned by the caller through the unique_ptr.
std::unique_ptr<Pass> createLowerPolygeistOpsPass() {
return std::make_unique<LowerPolygeistOpsPass>();
}

} // namespace polygeist
} // namespace mlir
48 changes: 48 additions & 0 deletions test/polygeist-opt/lower_polygeist_ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// RUN: polygeist-opt --lower-polygeist-ops --split-input-file %s | FileCheck %s

// CHECK-LABEL: func @main(
// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<30xi32> {
// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<30x30xi32>
// CHECK: %[[VAL_2:.*]] = arith.constant 30 : index
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index
// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [30], strides: [1] : memref<30x30xi32> to memref<30xi32>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you may need to extract the previous stride and then add this new value. Lest this not work if you have two such operations.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have been done in the updated commit (see the new test for reference). Let me know what you think!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps I'm not understanding, or I'm still not seeing it.

Can you add a subindex of subindex test (where both subindices are just offsets, rather than rank-reducing)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally, could you have the memref argument be passed in as an argument?

In other words I would've expected %[[VAL_3:.*]] to set the new offset = oldoffset + index * dimsize, whereas it is currently just index * dimsize

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a subindex of subindex test (where both subindices are just offsets, rather than rank-reducing)?

I've added an extra test for the case of a non-rank-reducing subindex operation. Again, this initial PR only covers the case of statically sized memrefs (which takes care of the most pressing issues on my side). subindex to subindex of statically sized memories should therefore hold transitively.

Additionally, could you have the memref argument be passed in as an argument?
In other words I would've expected %[[VAL_3:.*]] to set the new offset = oldoffset + index * dimsize, whereas it is currently just index * dimsize

Not sure I understand what you mean by this. Could you show a snippet of the IR that you expect here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess my worry is that you don't appear to be adding to the offset (e.g. you are setting the new offset). Suppose you have two subindex's in a row, the total offset should include terms from both subindex operations. Is that the case presently?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless I am misunderstanding the reinterpret_cast operation, the offset is included in the value that the reinterpret_cast returns. So when two subindex operations are in a row, the first one will be converted to a reinterpret_cast that has the same semantics as the initial subindex operation. Lowering the second is therefore independent of this.

But again an IR example of expected behaviour here might clarify any confusion.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
  llvm.func @printf(!llvm.ptr<i8>, ...) -> i32
  llvm.mlir.global internal constant @str0("data[%d][%d]=%d\0A\00")
  func @set(%arg0: memref<?xi32>) attributes {llvm.linkage = #llvm.linkage<external>} {
    %c3_i32 = arith.constant 3 : i32
    affine.store %c3_i32, %arg0[0] : memref<?xi32>
    return
  }
  func @main() -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
    %c3_i32 = arith.constant 3 : i32
    %c2 = arith.constant 2 : index
    %c1 = arith.constant 1 : index
    %c0_i32 = arith.constant 0 : i32
    %c0 = arith.constant 0 : index
    %0 = llvm.mlir.undef : i32
    %1 = memref.alloca() : memref<2x3xi32>
    affine.store %c0_i32, %1[0, 0] : memref<2x3xi32>
    affine.store %c0_i32, %1[0, 1] : memref<2x3xi32>
    affine.store %c0_i32, %1[0, 2] : memref<2x3xi32>
    affine.store %c0_i32, %1[1, 0] : memref<2x3xi32>
    affine.store %c0_i32, %1[1, 1] : memref<2x3xi32>
    affine.store %c0_i32, %1[1, 2] : memref<2x3xi32>
    %2 = "polygeist.subindex"(%1, %c1) : (memref<2x3xi32>, index) -> memref<?xi32>
    %3 = "polygeist.subindex"(%2, %c2) : (memref<?xi32>, index) -> memref<?xi32>
    affine.store %c3_i32, %3[0] : memref<?xi32>
    scf.for %arg0 = %c0 to %c2 step %c1 {
      %4 = arith.index_cast %arg0 : index to i32
      scf.for %arg1 = %c0 to %c2 step %c1 {
        %5 = arith.index_cast %arg1 : index to i32
        %6 = llvm.mlir.addressof @str0 : !llvm.ptr<array<17 x i8>>
        %7 = llvm.getelementptr %6[%c0_i32, %c0_i32] : (!llvm.ptr<array<17 x i8>>, i32, i32) -> !llvm.ptr<i8>
        %8 = memref.load %1[%arg0, %arg1] : memref<2x3xi32>
        %9 = llvm.call @printf(%7, %4, %5, %8) : (!llvm.ptr<i8>, i32, i32, i32) -> i32
      }
    }
    return %0 : i32
  }
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Such cases are not covered by the PR due to the dynamically sized memrefs.

Again, this PR intends to lower polygeist.subindex operations with static memrefs. By having this, even if it is only a subset of all possible polygeist.subindex operations that are converted, we still expand the set of C programs that Polygeist can emit using upstream MLIR dialect operations. It is vital that the Polygeist dialect operations are converted for any downstream tools to be able to consume the output IR.
I am not excluding that I'll in the future look more closely into how polygeist.subindex operations with dynamically shaped memrefs can be lowered into something meaningful, but for now, our usecase requires statically shaped memrefs and as such that is the most pressing issue to have resolved in Polygeist.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh apologies, let me make that a statically sized subindex in that case (also, regardless, running the pass on the above crashes, which shouldn't happen...)

Regardless, my concern here is (perhaps because of unfamiliarity with reinterpret_cast) that one of the offsets will be lost.

module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
  llvm.func @printf(!llvm.ptr<i8>, ...) -> i32
  llvm.mlir.global internal constant @str0("data[%d][%d]=%d\0A\00")
  func @set(%arg0: memref<?xi32>) attributes {llvm.linkage = #llvm.linkage<external>} {
    %c3_i32 = arith.constant 3 : i32
    affine.store %c3_i32, %arg0[0] : memref<?xi32>
    return
  }
  func @main() -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
    %c3_i32 = arith.constant 3 : i32
    %c2 = arith.constant 2 : index
    %c1 = arith.constant 1 : index
    %c0_i32 = arith.constant 0 : i32
    %c0 = arith.constant 0 : index
    %0 = llvm.mlir.undef : i32
    %1 = memref.alloca() : memref<2x3xi32>
    affine.store %c0_i32, %1[0, 0] : memref<2x3xi32>
    affine.store %c0_i32, %1[0, 1] : memref<2x3xi32>
    affine.store %c0_i32, %1[0, 2] : memref<2x3xi32>
    affine.store %c0_i32, %1[1, 0] : memref<2x3xi32>
    affine.store %c0_i32, %1[1, 1] : memref<2x3xi32>
    affine.store %c0_i32, %1[1, 2] : memref<2x3xi32>
    %2 = "polygeist.subindex"(%1, %c1) : (memref<2x3xi32>, index) -> memref<3xi32>
    %3 = "polygeist.subindex"(%2, %c2) : (memref<3xi32>, index) -> memref<?xi32>
    affine.store %c3_i32, %3[0] : memref<?xi32>
    scf.for %arg0 = %c0 to %c2 step %c1 {
      %4 = arith.index_cast %arg0 : index to i32
      scf.for %arg1 = %c0 to %c2 step %c1 {
        %5 = arith.index_cast %arg1 : index to i32
        %6 = llvm.mlir.addressof @str0 : !llvm.ptr<array<17 x i8>>
        %7 = llvm.getelementptr %6[%c0_i32, %c0_i32] : (!llvm.ptr<array<17 x i8>>, i32, i32) -> !llvm.ptr<i8>
        %8 = memref.load %1[%arg0, %arg1] : memref<2x3xi32>
        %9 = llvm.call @printf(%7, %4, %5, %8) : (!llvm.ptr<i8>, i32, i32, i32) -> i32
      }
    }
    return %0 : i32
  }
}

// CHECK: return %[[VAL_4]] : memref<30xi32>
// CHECK: }
// Rank-reducing case: subindex from memref<30x30xi32> to memref<30xi32> with
// a dynamic index.
func @main(%arg0 : index) -> memref<30xi32> {
%0 = memref.alloca() : memref<30x30xi32>
%1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<30xi32>
return %1 : memref<30xi32>
}

// -----

// CHECK-LABEL: func @main(
// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<42x43x44x45xi32> {
// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<41x42x43x44x45xi32>
// CHECK: %[[VAL_2:.*]] = arith.constant 3575880 : index
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index
// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [42, 43, 44, 45], strides: [85140, 1980, 45, 1] : memref<41x42x43x44x45xi32> to memref<42x43x44x45xi32>
// CHECK: return %[[VAL_4]] : memref<42x43x44x45xi32>
// CHECK: }
// Rank-reducing case with a multi-dimensional result: subindex from a rank-5
// memref to the rank-4 memref made of its four inner dimensions.
func @main(%arg0 : index) -> memref<42x43x44x45xi32> {
%0 = memref.alloca() : memref<41x42x43x44x45xi32>
%1 = "polygeist.subindex"(%0, %arg0) : (memref<41x42x43x44x45xi32>, index) -> memref<42x43x44x45xi32>
return %1 : memref<42x43x44x45xi32>
}

// -----

// CHECK-LABEL: func @main(
// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<29x30xi32> {
// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<30x30xi32>
// CHECK: %[[VAL_2:.*]] = arith.constant 870 : index
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index
// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [29, 30], strides: [30, 1] : memref<30x30xi32> to memref<29x30xi32>
// CHECK: return %[[VAL_4]] : memref<29x30xi32>
// CHECK: }

// Non-rank-reducing case: source and result are both rank 2.
// NOTE(review): the expected offset multiplier above (870 = 29 * 30) is the
// result's full element count; confirm this is intended rather than the
// source row extent (30).
func @main(%arg0 : index) -> memref<29x30xi32> {
%0 = memref.alloca() : memref<30x30xi32>
%1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<29x30xi32>
return %1 : memref<29x30xi32>
}