diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h
index ba847ee251b8..a48cfbb14524 100644
--- a/include/polygeist/Passes/Passes.h
+++ b/include/polygeist/Passes/Passes.h
@@ -24,6 +24,7 @@ std::unique_ptr<Pass> createParallelLowerPass();
 std::unique_ptr<Pass>
 createConvertPolygeistToLLVMPass(const LowerToLLVMOptions &options);
 std::unique_ptr<Pass> createConvertPolygeistToLLVMPass();
+std::unique_ptr<Pass> createLowerPolygeistOpsPass();
 
 } // namespace polygeist
 } // namespace mlir
diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td
index d2ab39bbd6e4..2d5c85bd107a 100644
--- a/include/polygeist/Passes/Passes.td
+++ b/include/polygeist/Passes/Passes.td
@@ -78,6 +78,12 @@ def RemoveTrivialUse : Pass<"trivialuse"> {
   let constructor = "mlir::polygeist::createRemoveTrivialUsePass()";
 }
 
+def LowerPolygeistOps : Pass<"lower-polygeist-ops"> {
+  let summary = "Lower polygeist ops to memref operations";
+  let constructor = "mlir::polygeist::createLowerPolygeistOpsPass()";
+  let dependentDialects = ["::mlir::memref::MemRefDialect"];
+}
+
 def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> {
   let summary = "Convert scalar and vector operations from the Standard to the "
                 "LLVM dialect";
diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp
index 1a7ffff582cf..fc1e6d0a2eb1 100644
--- a/lib/polygeist/Ops.cpp
+++ b/lib/polygeist/Ops.cpp
@@ -529,48 +529,6 @@ class SubToCast final : public OpRewritePattern<SubIndexOp> {
   }
 };
 
-// Simplify polygeist.subindex to memref.subview.
-class SubToSubView final : public OpRewritePattern<SubIndexOp> {
-public:
-  using OpRewritePattern<SubIndexOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(SubIndexOp op,
-                                PatternRewriter &rewriter) const override {
-    auto srcMemRefType = op.source().getType().cast<MemRefType>();
-    auto resMemRefType = op.result().getType().cast<MemRefType>();
-    auto dims = srcMemRefType.getShape().size();
-
-    // For now, restrict subview lowering to statically defined memref's
-    if (!srcMemRefType.hasStaticShape() | !resMemRefType.hasStaticShape())
-      return failure();
-
-    // For now, restrict to simple rank-reducing indexing
-    if (srcMemRefType.getShape().size() <= resMemRefType.getShape().size())
-      return failure();
-
-    // Build offset, sizes and strides
-    SmallVector<OpFoldResult> sizes(dims, rewriter.getIndexAttr(0));
-    sizes[0] = op.index();
-    SmallVector<OpFoldResult> offsets(dims);
-    for (auto dim : llvm::enumerate(srcMemRefType.getShape())) {
-      if (dim.index() == 0)
-        offsets[0] = rewriter.getIndexAttr(1);
-      else
-        offsets[dim.index()] = rewriter.getIndexAttr(dim.value());
-    }
-    SmallVector<OpFoldResult> strides(dims, rewriter.getIndexAttr(1));
-
-    // Generate the appropriate return type:
-    auto subMemRefType = MemRefType::get(srcMemRefType.getShape().drop_front(),
-                                         srcMemRefType.getElementType());
-
-    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
-        op, subMemRefType, op.source(), sizes, offsets, strides);
-
-    return success();
-  }
-};
-
 // Simplify redundant dynamic subindex patterns which tries to represent
 // rank-reducing indexing:
 //   %3 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) ->
diff --git a/lib/polygeist/Passes/CMakeLists.txt b/lib/polygeist/Passes/CMakeLists.txt
index 495942aaa34a..436c8e94b882 100644
--- a/lib/polygeist/Passes/CMakeLists.txt
+++ b/lib/polygeist/Passes/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms
   TrivialUse.cpp
   ConvertPolygeistToLLVM.cpp
   InnerSerialization.cpp
+  LowerPolygeistOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
diff --git a/lib/polygeist/Passes/LowerPolygeistOps.cpp b/lib/polygeist/Passes/LowerPolygeistOps.cpp
new file mode 100644
index 000000000000..2c0128cfcb6b
--- /dev/null
+++ b/lib/polygeist/Passes/LowerPolygeistOps.cpp
@@ -0,0 +1,105 @@
+//===- LowerPolygeistOps.cpp - Lower polygeist ops to upstream MLIR ops -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass which lowers any remaining Polygeist dialect
+// operations (after canonicalization) to operations found in upstream MLIR
+// dialects.
+//
+//===----------------------------------------------------------------------===//
+#include "PassDetails.h"
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "polygeist/Dialect.h"
+#include "polygeist/Ops.h"
+
+using namespace mlir;
+using namespace polygeist;
+using namespace mlir::arith;
+
+namespace {
+
+/// Rewrites `polygeist.subindex` into `memref.reinterpret_cast`.
+///
+/// Only results with a static shape are handled; the pattern fails to match
+/// otherwise. The emitted cast uses the result shape as its sizes, strides
+/// formed by the running product of the result dimensions (innermost stride
+/// is 1), and a dynamic offset of `index * <result element count>`.
+struct SubIndexToReinterpretCast
+    : public OpConversionPattern<polygeist::SubIndexOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(polygeist::SubIndexOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resMemRefType = op.result().getType().cast<MemRefType>();
+
+    // The offset and strides below are built from the result dimensions as
+    // constants, so dynamic result shapes are not supported.
+    if (!resMemRefType.hasStaticShape())
+      return failure();
+    auto outShape = resMemRefType.getShape();
+
+    // offset = index * (number of elements in the result memref).
+    int64_t innerSize = resMemRefType.getNumElements();
+    auto offset = rewriter.create<MulIOp>(
+        op.getLoc(), adaptor.index(),
+        rewriter.create<ConstantIndexOp>(op.getLoc(), innerSize));
+
+    // Walk the result shape innermost-first; each stride is the product of
+    // the dimensions that come after it.
+    llvm::SmallVector<OpFoldResult> strides, sizes;
+    int64_t strideAcc = 1;
+    for (auto dim : llvm::reverse(outShape)) {
+      sizes.insert(sizes.begin(), rewriter.getIndexAttr(dim));
+      strides.insert(strides.begin(), rewriter.getIndexAttr(strideAcc));
+      strideAcc *= dim;
+    }
+
+    // Use the adaptor's operand so the cast refers to the already-converted
+    // source value.
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, resMemRefType, adaptor.source(), offset.getResult(), sizes,
+        strides);
+
+    return success();
+  }
+};
+
+/// Pass that applies SubIndexToReinterpretCast through a partial dialect
+/// conversion and signals pass failure when that conversion fails.
+struct LowerPolygeistOpsPass
+    : public LowerPolygeistOpsBase<LowerPolygeistOpsPass> {
+
+  void runOnOperation() override {
+    auto op = getOperation();
+    auto ctx = op->getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.insert<SubIndexToReinterpretCast>(ctx);
+
+    ConversionTarget target(*ctx);
+    target.addIllegalDialect<polygeist::PolygeistDialect>();
+    target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>();
+
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace polygeist {
+/// Creates an instance of the `lower-polygeist-ops` pass.
+std::unique_ptr<Pass> createLowerPolygeistOpsPass() {
+  return std::make_unique<LowerPolygeistOpsPass>();
+}
+
+} // namespace polygeist
+} // namespace mlir
diff --git a/test/polygeist-opt/lower_polygeist_ops.mlir b/test/polygeist-opt/lower_polygeist_ops.mlir
new file mode 100644
index 000000000000..11ec89f5b6e7
--- /dev/null
+++ b/test/polygeist-opt/lower_polygeist_ops.mlir
@@ -0,0 +1,48 @@
+// RUN: polygeist-opt --lower-polygeist-ops --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @main(
+// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<30xi32> {
+// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<30x30xi32>
+// CHECK: %[[VAL_2:.*]] = arith.constant 30 : index
+// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index
+// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [30], strides: [1] : memref<30x30xi32> to memref<30xi32>
+// CHECK: return %[[VAL_4]] : memref<30xi32>
+// CHECK: }
+func @main(%arg0 : index) -> memref<30xi32> {
+  %0 = memref.alloca() : memref<30x30xi32>
+  %1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<30xi32>
+  return %1 : memref<30xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main(
+// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<42x43x44x45xi32> {
+// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<41x42x43x44x45xi32>
+// CHECK: %[[VAL_2:.*]] = arith.constant 3575880 : index
+// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index
+// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [42, 43, 44, 45], strides: [85140, 1980, 45, 1] : memref<41x42x43x44x45xi32> to memref<42x43x44x45xi32>
+// CHECK: return %[[VAL_4]] : memref<42x43x44x45xi32>
+// CHECK: }
+func @main(%arg0 : index) -> memref<42x43x44x45xi32> {
+  %0 = memref.alloca() : memref<41x42x43x44x45xi32>
+  %1 = "polygeist.subindex"(%0, %arg0) : (memref<41x42x43x44x45xi32>, index) -> memref<42x43x44x45xi32>
+  return %1 : memref<42x43x44x45xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main(
+// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<29x30xi32> {
+// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<30x30xi32>
+// CHECK: %[[VAL_2:.*]] = arith.constant 870 : index
+// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index
+// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [29, 30], strides: [30, 1] : memref<30x30xi32> to memref<29x30xi32>
+// CHECK: return %[[VAL_4]] : memref<29x30xi32>
+// CHECK: }
+
+func @main(%arg0 : index) -> memref<29x30xi32> {
+  %0 = memref.alloca() : memref<30x30xi32>
+  %1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<29x30xi32>
+  return %1 : memref<29x30xi32>
+}