From 81c33c3e1a449e5d9bcf23bc902d68536a573fdf Mon Sep 17 00:00:00 2001 From: Morten Borup Petersen Date: Fri, 7 Jan 2022 10:51:20 +0100 Subject: [PATCH 1/6] Lower polygeist.subindex through memref.reinterpret_cast This should be a (hopefully) foolproof method of performing indexing into a memref. A reintrepret_cast is inserted with a dynamic index calculated from the subindex index operand + the product of the sizes of the target type. This has been added as a separate conversion pass instead of through the canonicalization drivers. When added as a canonicalization, the conversion may preemptively apply, resulting in sub-par IR. Nevertheless, i think it has its merits to have a polygeist op lowering pass which can be used as a fallback to convert the dialect operations, if canonicalization fails. For now, just added support for statically shaped memrefs (enough to fix the regression on my side) but should be possible for dynamically shaped as well. --- include/polygeist/Passes/Passes.h | 1 + include/polygeist/Passes/Passes.td | 6 ++ lib/polygeist/Ops.cpp | 42 ---------- lib/polygeist/Passes/CMakeLists.txt | 1 + lib/polygeist/Passes/LowerPolygeistOps.cpp | 88 +++++++++++++++++++++ test/polygeist-opt/lower_polygeist_ops.mlir | 17 ++++ 6 files changed, 113 insertions(+), 42 deletions(-) create mode 100644 lib/polygeist/Passes/LowerPolygeistOps.cpp create mode 100644 test/polygeist-opt/lower_polygeist_ops.mlir diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h index ba847ee251b8..a48cfbb14524 100644 --- a/include/polygeist/Passes/Passes.h +++ b/include/polygeist/Passes/Passes.h @@ -24,6 +24,7 @@ std::unique_ptr createParallelLowerPass(); std::unique_ptr createConvertPolygeistToLLVMPass(const LowerToLLVMOptions &options); std::unique_ptr createConvertPolygeistToLLVMPass(); +std::unique_ptr createLowerPolygeistOpsPass(); } // namespace polygeist } // namespace mlir diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index d2ab39bbd6e4..49f4e1ef127d 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -78,6 +78,12 @@ def RemoveTrivialUse : Pass<"trivialuse"> { let constructor = "mlir::polygeist::createRemoveTrivialUsePass()"; } +def LowerPolygeistOps : FunctionPass<"lower-polygeist-ops"> { + let summary = "Lower polygeist ops to memref operations"; + let constructor = "mlir::polygeist::createLowerPolygeistOpsPass()"; + let dependentDialects = ["::mlir::memref::MemRefDialect"]; +} + def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> { let summary = "Convert scalar and vector operations from the Standard to the " "LLVM dialect"; diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index 1a7ffff582cf..fc1e6d0a2eb1 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -529,48 +529,6 @@ class SubToCast final : public OpRewritePattern { } }; -// Simplify polygeist.subindex to memref.subview. -class SubToSubView final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SubIndexOp op, - PatternRewriter &rewriter) const override { - auto srcMemRefType = op.source().getType().cast(); - auto resMemRefType = op.result().getType().cast(); - auto dims = srcMemRefType.getShape().size(); - - // For now, restrict subview lowering to statically defined memref's - if (!srcMemRefType.hasStaticShape() | !resMemRefType.hasStaticShape()) - return failure(); - - // For now, restrict to simple rank-reducing indexing - if (srcMemRefType.getShape().size() <= resMemRefType.getShape().size()) - return failure(); - - // Build offset, sizes and strides - SmallVector sizes(dims, rewriter.getIndexAttr(0)); - sizes[0] = op.index(); - SmallVector offsets(dims); - for (auto dim : llvm::enumerate(srcMemRefType.getShape())) { - if (dim.index() == 0) - offsets[0] = rewriter.getIndexAttr(1); - else - offsets[dim.index()] = rewriter.getIndexAttr(dim.value()); - } - SmallVector strides(dims, rewriter.getIndexAttr(1)); - - // Generate the appropriate return type: - auto subMemRefType = MemRefType::get(srcMemRefType.getShape().drop_front(), - srcMemRefType.getElementType()); - - rewriter.replaceOpWithNewOp( - op, subMemRefType, op.source(), sizes, offsets, strides); - - return success(); - } -}; - // Simplify redundant dynamic subindex patterns which tries to represent // rank-reducing indexing: // %3 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) -> diff --git a/lib/polygeist/Passes/CMakeLists.txt b/lib/polygeist/Passes/CMakeLists.txt index 495942aaa34a..436c8e94b882 100644 --- a/lib/polygeist/Passes/CMakeLists.txt +++ b/lib/polygeist/Passes/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms TrivialUse.cpp ConvertPolygeistToLLVM.cpp InnerSerialization.cpp + LowerPolygeistOps.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine diff --git a/lib/polygeist/Passes/LowerPolygeistOps.cpp b/lib/polygeist/Passes/LowerPolygeistOps.cpp new file mode 100644 index 000000000000..be3152b0d513 --- /dev/null +++ b/lib/polygeist/Passes/LowerPolygeistOps.cpp @@ -0,0 +1,88 @@ +//===- TrivialUse.cpp - Remove trivial use instruction ---------------- -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to lower gpu kernels in NVVM/gpu dialects into +// a generic parallel for representation +//===----------------------------------------------------------------------===// +#include "PassDetails.h" + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/DialectConversion.h" +#include "polygeist/Dialect.h" +#include "polygeist/Ops.h" + +using namespace mlir; +using namespace polygeist; +using namespace mlir::arith; + +namespace { + +struct SubIndexToReinterpretCast + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(polygeist::SubIndexOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcMemRefType = op.source().getType().cast(); + auto resMemRefType = op.result().getType().cast(); + auto shape = srcMemRefType.getShape(); + + if (!resMemRefType.hasStaticShape()) + return failure(); + + int64_t innerSize = resMemRefType.getNumElements(); + auto offset = rewriter.create( + op.getLoc(), op.index(), + rewriter.create(op.getLoc(), innerSize)); + + llvm::SmallVector sizes, strides; + for (auto dim : shape.drop_front()) { + sizes.push_back(rewriter.getIndexAttr(dim)); + strides.push_back(rewriter.getIndexAttr(1)); + } + + rewriter.replaceOpWithNewOp( + op, resMemRefType, op.source(), offset.getResult(), sizes, strides); + + return success(); + } +}; + +struct LowerPolygeistOpsPass + : public LowerPolygeistOpsBase { + + void runOnFunction() override { + auto op = getOperation(); + auto ctx = op.getContext(); + RewritePatternSet patterns(ctx); + patterns.insert(ctx); + + ConversionTarget target(*ctx); + target.addIllegalDialect(); + target.addLegalDialect(); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +namespace mlir { +namespace polygeist { +std::unique_ptr createLowerPolygeistOpsPass() { + return std::make_unique(); +} + +} // namespace polygeist +} // namespace mlir diff --git a/test/polygeist-opt/lower_polygeist_ops.mlir b/test/polygeist-opt/lower_polygeist_ops.mlir new file mode 100644 index 000000000000..cd84039e637b --- /dev/null +++ b/test/polygeist-opt/lower_polygeist_ops.mlir @@ -0,0 +1,17 @@ +// RUN: polygeist-opt --lower-polygeist-ops --split-input-file %s | FileCheck %s + +// CHECK-LABEL: func @main( +// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<30xi32> { +// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<30x30xi32> +// CHECK: %[[VAL_2:.*]] = arith.constant 30 : index +// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index +// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [30], strides: [1] : memref<30x30xi32> to memref<30xi32> +// CHECK: return %[[VAL_4]] : memref<30xi32> +// CHECK: } +module { + func @main(%arg0 : index) -> memref<30xi32> { + %0 = memref.alloca() : memref<30x30xi32> + %1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<30xi32> + return %1 : memref<30xi32> + } +} From 0769f0032b26d9c890f7b27b2669c77486b70b50 Mon Sep 17 00:00:00 2001 From: Morten Borup Petersen Date: Sat, 29 Jan 2022 09:48:45 +0100 Subject: [PATCH 2/6] Extend polygeist subindex lowering to multidim memrefs --- lib/polygeist/Passes/LowerPolygeistOps.cpp | 8 +++--- test/polygeist-opt/lower_polygeist_ops.mlir | 27 ++++++++++++++++----- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/lib/polygeist/Passes/LowerPolygeistOps.cpp b/lib/polygeist/Passes/LowerPolygeistOps.cpp index be3152b0d513..0c7a02a21ee3 100644 --- a/lib/polygeist/Passes/LowerPolygeistOps.cpp +++ b/lib/polygeist/Passes/LowerPolygeistOps.cpp @@ -46,9 +46,11 @@ struct SubIndexToReinterpretCast rewriter.create(op.getLoc(), innerSize)); llvm::SmallVector sizes, strides; - for (auto dim : shape.drop_front()) { - sizes.push_back(rewriter.getIndexAttr(dim)); - strides.push_back(rewriter.getIndexAttr(1)); + int64_t strideAcc = 1; + for (auto dim : llvm::reverse(shape.drop_front())) { + sizes.insert(sizes.begin(), rewriter.getIndexAttr(dim)); + strides.insert(strides.begin(), rewriter.getIndexAttr(strideAcc)); + strideAcc *= dim; } rewriter.replaceOpWithNewOp( diff --git a/test/polygeist-opt/lower_polygeist_ops.mlir b/test/polygeist-opt/lower_polygeist_ops.mlir index cd84039e637b..7f1a6fa1624c 100644 --- a/test/polygeist-opt/lower_polygeist_ops.mlir +++ b/test/polygeist-opt/lower_polygeist_ops.mlir @@ -8,10 +8,25 @@ // CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [30], strides: [1] : memref<30x30xi32> to memref<30xi32> // CHECK: return %[[VAL_4]] : memref<30xi32> // CHECK: } -module { - func @main(%arg0 : index) -> memref<30xi32> { - %0 = memref.alloca() : memref<30x30xi32> - %1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<30xi32> - return %1 : memref<30xi32> - } +func @main(%arg0 : index) -> memref<30xi32> { + %0 = memref.alloca() : memref<30x30xi32> + %1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<30xi32> + return %1 : memref<30xi32> +} + + +// ----- + +// CHECK-LABEL: func @main( +// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<42x43x44x45xi32> { +// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<41x42x43x44x45xi32> +// CHECK: %[[VAL_2:.*]] = arith.constant 3575880 : index +// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index +// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [42, 43, 44, 45], strides: [85140, 1980, 45, 1] : memref<41x42x43x44x45xi32> to memref<42x43x44x45xi32> +// CHECK: return %[[VAL_4]] : memref<42x43x44x45xi32> +// CHECK: } +func @main(%arg0 : index) -> memref<42x43x44x45xi32> { + %0 = memref.alloca() : memref<41x42x43x44x45xi32> + %1 = "polygeist.subindex"(%0, %arg0) : (memref<41x42x43x44x45xi32>, index) -> memref<42x43x44x45xi32> + return %1 : memref<42x43x44x45xi32> } From 8c565b206fa9dabb6f829237d6aa1a59813d5d97 Mon Sep 17 00:00:00 2001 From: Morten Borup Petersen Date: Mon, 21 Feb 2022 13:51:18 +0100 Subject: [PATCH 3/6] rebase --- include/polygeist/Passes/Passes.td | 2 +- lib/polygeist/Passes/LowerPolygeistOps.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index 49f4e1ef127d..e2cbd2516336 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -78,7 +78,7 @@ def RemoveTrivialUse : Pass<"trivialuse"> { let constructor = "mlir::polygeist::createRemoveTrivialUsePass()"; } -def LowerPolygeistOps : FunctionPass<"lower-polygeist-ops"> { +def LowerPolygeistOps : Pass<"lower-polygeist-ops", "FuncOp"> { let summary = "Lower polygeist ops to memref operations"; let constructor = "mlir::polygeist::createLowerPolygeistOpsPass()"; let dependentDialects = ["::mlir::memref::MemRefDialect"]; diff --git a/lib/polygeist/Passes/LowerPolygeistOps.cpp b/lib/polygeist/Passes/LowerPolygeistOps.cpp index 0c7a02a21ee3..392f0c5fa1a2 100644 --- a/lib/polygeist/Passes/LowerPolygeistOps.cpp +++ b/lib/polygeist/Passes/LowerPolygeistOps.cpp @@ -63,7 +63,7 @@ struct SubIndexToReinterpretCast struct LowerPolygeistOpsPass : public LowerPolygeistOpsBase { - void runOnFunction() override { + void runOnOperation() override { auto op = getOperation(); auto ctx = op.getContext(); RewritePatternSet patterns(ctx); From c070c5e5befacd4313fd7fc9860b898fc48574df Mon Sep 17 00:00:00 2001 From: Morten Borup Petersen Date: Wed, 23 Feb 2022 08:45:05 +0100 Subject: [PATCH 4/6] Rename pass --- lib/polygeist/Passes/LowerPolygeistOps.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/polygeist/Passes/LowerPolygeistOps.cpp b/lib/polygeist/Passes/LowerPolygeistOps.cpp index 392f0c5fa1a2..02f04e5f7171 100644 --- a/lib/polygeist/Passes/LowerPolygeistOps.cpp +++ b/lib/polygeist/Passes/LowerPolygeistOps.cpp @@ -1,4 +1,4 @@ -//===- TrivialUse.cpp - Remove trivial use instruction ---------------- -*-===// +//===- LowerPolygeistOps.cpp - Lower polygeist ops to upstream MLIR ops -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,10 @@ // //===----------------------------------------------------------------------===// // -// This file implements a pass to lower gpu kernels in NVVM/gpu dialects into -// a generic parallel for representation +// This file implements a pass which lowers any remaining Polygeist dialect +// operations (after canonicalization) to operations found in upstream MLIR +// dialects. +// //===----------------------------------------------------------------------===// #include "PassDetails.h" From 52f6117c5f86852f14531050d6289889ff319757 Mon Sep 17 00:00:00 2001 From: Morten Borup Petersen Date: Wed, 23 Feb 2022 09:27:45 +0100 Subject: [PATCH 5/6] Also handle non-rank reducing offset operations --- lib/polygeist/Passes/LowerPolygeistOps.cpp | 7 ++++--- test/polygeist-opt/lower_polygeist_ops.mlir | 18 +++++++++++++++++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/lib/polygeist/Passes/LowerPolygeistOps.cpp b/lib/polygeist/Passes/LowerPolygeistOps.cpp index 02f04e5f7171..e5dd7748f29a 100644 --- a/lib/polygeist/Passes/LowerPolygeistOps.cpp +++ b/lib/polygeist/Passes/LowerPolygeistOps.cpp @@ -37,19 +37,20 @@ struct SubIndexToReinterpretCast ConversionPatternRewriter &rewriter) const override { auto srcMemRefType = op.source().getType().cast(); auto resMemRefType = op.result().getType().cast(); - auto shape = srcMemRefType.getShape(); + auto inShape = srcMemRefType.getShape(); + auto outShape = resMemRefType.getShape(); if (!resMemRefType.hasStaticShape()) return failure(); + llvm::SmallVector strides, sizes; int64_t innerSize = resMemRefType.getNumElements(); auto offset = rewriter.create( op.getLoc(), op.index(), rewriter.create(op.getLoc(), innerSize)); - llvm::SmallVector sizes, strides; int64_t strideAcc = 1; - for (auto dim : llvm::reverse(shape.drop_front())) { + for (auto dim : llvm::reverse(outShape)) { sizes.insert(sizes.begin(), rewriter.getIndexAttr(dim)); strides.insert(strides.begin(), rewriter.getIndexAttr(strideAcc)); strideAcc *= dim; diff --git a/test/polygeist-opt/lower_polygeist_ops.mlir b/test/polygeist-opt/lower_polygeist_ops.mlir index 7f1a6fa1624c..11ec89f5b6e7 100644 --- a/test/polygeist-opt/lower_polygeist_ops.mlir +++ b/test/polygeist-opt/lower_polygeist_ops.mlir @@ -14,7 +14,6 @@ func @main(%arg0 : index) -> memref<30xi32> { return %1 : memref<30xi32> } - // ----- // CHECK-LABEL: func @main( @@ -30,3 +29,20 @@ func @main(%arg0 : index) -> memref<42x43x44x45xi32> { %1 = "polygeist.subindex"(%0, %arg0) : (memref<41x42x43x44x45xi32>, index) -> memref<42x43x44x45xi32> return %1 : memref<42x43x44x45xi32> } + +// ----- + +// CHECK-LABEL: func @main( +// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<29x30xi32> { +// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<30x30xi32> +// CHECK: %[[VAL_2:.*]] = arith.constant 870 : index +// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index +// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [29, 30], strides: [30, 1] : memref<30x30xi32> to memref<29x30xi32> +// CHECK: return %[[VAL_4]] : memref<29x30xi32> +// CHECK: } + +func @main(%arg0 : index) -> memref<29x30xi32> { + %0 = memref.alloca() : memref<30x30xi32> + %1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<29x30xi32> + return %1 : memref<29x30xi32> +} From 524ccc81f9714aec2b6fc768fe4fb113efc0c401 Mon Sep 17 00:00:00 2001 From: Morten Borup Petersen Date: Sun, 17 Apr 2022 15:14:22 +0200 Subject: [PATCH 6/6] rebase to polygeist upstream --- include/polygeist/Passes/Passes.td | 2 +- lib/polygeist/Passes/LowerPolygeistOps.cpp | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index e2cbd2516336..2d5c85bd107a 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -78,7 +78,7 @@ def RemoveTrivialUse : Pass<"trivialuse"> { let constructor = "mlir::polygeist::createRemoveTrivialUsePass()"; } -def LowerPolygeistOps : Pass<"lower-polygeist-ops", "FuncOp"> { +def LowerPolygeistOps : Pass<"lower-polygeist-ops"> { let summary = "Lower polygeist ops to memref operations"; let constructor = "mlir::polygeist::createLowerPolygeistOpsPass()"; let dependentDialects = ["::mlir::memref::MemRefDialect"]; diff --git a/lib/polygeist/Passes/LowerPolygeistOps.cpp b/lib/polygeist/Passes/LowerPolygeistOps.cpp index e5dd7748f29a..2c0128cfcb6b 100644 --- a/lib/polygeist/Passes/LowerPolygeistOps.cpp +++ b/lib/polygeist/Passes/LowerPolygeistOps.cpp @@ -15,8 +15,6 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/DialectConversion.h" #include "polygeist/Dialect.h" @@ -68,14 +66,13 @@ struct LowerPolygeistOpsPass void runOnOperation() override { auto op = getOperation(); - auto ctx = op.getContext(); + auto ctx = op->getContext(); RewritePatternSet patterns(ctx); patterns.insert(ctx); ConversionTarget target(*ctx); target.addIllegalDialect(); - target.addLegalDialect(); + target.addLegalDialect(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) return signalPassFailure();