diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
index ad35f5a3fb..9124ae3a2e 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
@@ -14,10 +14,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 
 #include "src/Accelerators/Accelerator.hpp"
+#include "src/Compiler/CompilerOptions.hpp"
 #include "src/Dialect/Krnl/DialectBuilder.hpp"
 #include "src/Dialect/Mlir/DialectBuilder.hpp"
 #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
@@ -882,4 +884,75 @@ void impl::onnxToKrnlSimdReport(Operation *op, bool successful,
       static_cast(simdLoopTripCount));
 }
 
+// The Gather op is data dependent: the value of each index must be
+// within the size of the input data along the gather axis.
+// A runtime check is added if enableSafeCodeGen is set to true.
+// Implementation note vs. createGenerateRuntimeVerificationPass:
+// this check follows the ONNX op semantics, not a general bound
+// check for the memref. The RuntimeVerification implementation could be
+// borrowed; the slight difference is that the ONNX semantics check each
+// dimension independently, rather than checking that the final address
+// is within the memref bound.
+void genSafeCodeForGatherAlike(mlir::ConversionPatternRewriter &rewriter,
+    mlir::Location loc, Operation *op, Value data, Value indices,
+    int64_t axisLit) {
+  // Do nothing if not enabled.
+  if (!enableSafeCodeGen)
+    return;
+
+  MultiDialectBuilder
+      create(rewriter, loc);
+
+  // Check all the elements of indices.
+  DimsExpr dataDims, indicesDims;
+  create.krnlIE.getShapeAsDims(data, dataDims);
+  create.krnlIE.getShapeAsDims(indices, indicesDims);
+  SymbolIndexExpr axisDim(dataDims[axisLit]);
+  int64_t indicesRank = mlir::cast(indices.getType()).getRank();
+  ValueRange loopDef = create.krnl.defineLoops(indicesRank);
+  LiteralIndexExpr zeroIE(0);
+  DimsExpr lbs(indicesRank, zeroIE);
+  create.krnl.iterateIE(loopDef, loopDef, lbs, indicesDims,
+      [&](const KrnlBuilder &createKrnl, ValueRange loopInd) {
+        IndexExprScope innerLoopScope(createKrnl);
+
+        // Access function for indices.
+        DimsExpr accessFct;
+        getIndexExprList(loopInd, accessFct);
+        // Compute index = indices[i][j]...[n]
+        Value indexVal = createKrnl.loadIE(indices, accessFct);
+        IndexExpr index = NonAffineIndexExpr(indexVal);
+
+        // index should be in the range [-r, r-1], where r is the dim size
+        // of data[axis].
+        // The negative-value correction is assumed to be applied where the
+        // index is loaded from the tensor and used.
+        Value errorCondition =
+            ((index < (-1) * axisDim) | (index >= axisDim)).getValue();
+        rewriter.create<scf::IfOp>(
+            loc, errorCondition,
+            /*thenBuilder=*/
+            [&](OpBuilder &thenBuilder, Location thenLoc) {
+              MultiDialectBuilder create(
+                  thenBuilder, loc);
+              std::string nodeNameStr = "Warning: ";
+              nodeNameStr += op->getName().getStringRef().str() + " ";
+              StringAttr nodeName =
+                  op->getAttrOfType<StringAttr>("onnx_node_name");
+              if (nodeName && !nodeName.getValue().empty()) {
+                nodeNameStr += nodeName.getValue().str();
+              }
+              std::string msg = nodeNameStr +
+                  ": Value of indices is out of bounds. " +
+                  "The out-of-bound indices value is: ";
+              create.krnl.printf(msg, indexVal, true);
+              msg = "The out-of-bound index is replaced with zero.\n";
+              create.krnl.printf(msg);
+              thenBuilder.create<scf::YieldOp>(thenLoc);
+            },
+            /*elseBuilder=*/nullptr);
+      });
+}
+
 } // namespace onnx_mlir
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
index 3db45b4525..c4dea75fb6 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -770,5 +770,14 @@ void emitSymmetricQuantRecscaleToScalar(
     mlir::Operation *op, mlir::Value input, uint64_t bitWidth,
     mlir::Value &recscale, bool enableSIMD, bool enableParallel);
 
+// Generate safe code for GatherOp and GatherElementsOp.
+// Insert a runtime check on the values of indices.
+// If a value is outside the range of the `axis` dimension of the input data,
+// a warning message will be printed and the value will be changed to zero
+// to avoid a crash or an assertion error.
+void genSafeCodeForGatherAlike(mlir::ConversionPatternRewriter &rewriter,
+    mlir::Location loc, mlir::Operation *op, mlir::Value data,
+    mlir::Value indices, int64_t axisLit);
+
 } // namespace onnx_mlir
 #endif
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp b/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp
index d632a43969..88fa54c508 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp
@@ -12,8 +12,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "src/Compiler/CompilerOptions.hpp"
 #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
@@ -72,6 +71,9 @@ struct ONNXGatherOpLowering : public OpConversionPattern {
     // Negative value means counting dimensions from the back.
     axisLit = axisLit < 0 ? axisLit + dataRank : axisLit;
 
+    // Check the value of indices and change it to zero if out of bounds.
+    genSafeCodeForGatherAlike(rewriter, loc, op, data, indices, axisLit);
+
     int64_t outputRank = shapeHelper.getOutputDims().size();
     int iIndexStart = 0;
     int jIndexStart = iIndexStart + axisLit;
@@ -123,43 +125,14 @@ struct ONNXGatherOpLowering : public OpConversionPattern {
           Value indexVal = createKrnl.loadIE(indices, indicesAccessFct);
           // Loaded value is an index that is not affine
          IndexExpr index = NonAffineIndexExpr(indexVal);
+          // When index may be negative, add the axis dim to it.
           if (indicesMayBeNegative)
             index = index.selectOrSelf(index < zeroIE, index + axisDim);
-          // The Gather op is data dependent: the value of index should be
-          // within the input data size.
-          // Add runtime check if enableSafeCodeGen is set true
-          // Implementation comments vs. createGenerateRuntimeVerificationPass
-          // This check is according to onnx op semantics, not general bound
-          // check for memref. Implementation of RuntimeVerification could be
-          // borrowed. Slightly difference is that onnx semenatics check is for
-          // each dimension independently, not the final address is within
-          // the memref bound.
           if (enableSafeCodeGen) {
-            // From onnx document:
-            // All index values are expected to be within bounds [-s, s-1]
-            // along axis of size s. It is an error if any of the index values
-            // are out of bounds.
-            // After the negative correction, the range should be [0, s-1]
-            Value upperBound = create.mem.dim(data, axisLit);
-            Value compareUpperBound =
-                create.math.slt(index.getValue(), upperBound);
-            // Report onnx_node_name if the op has the attribute
-            std::string nodeNameStr = op->getName().getStringRef().str() + " ";
-            StringAttr nodeName =
-                op->getAttrOfType("onnx_node_name");
-            if (nodeName && !nodeName.getValue().empty()) {
-              nodeNameStr = nodeNameStr + nodeName.getValue().str();
-            }
-            rewriter.create(loc, compareUpperBound,
-                nodeNameStr +
-                    " indices of GatherOp is larger than the upper bound");
-            Value compareLowerBound =
-                create.math.sge(index.getValue(), zeroIE.getValue());
-            rewriter.create(loc, compareLowerBound,
-                nodeNameStr +
-                    " indices of GatherOp is less than the lower bound");
+            index = index.selectOrSelf(index < 0, zeroIE);
+            index = index.selectOrSelf(index >= axisDim, axisDim - 1);
           }
 
           // Compute access function of data: data[ii + (indices[jj],) + kk]
diff --git a/src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp b/src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp
index a9571621a6..fa3fb43c92 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp
@@ -11,9 +11,7 @@
 // This file lowers the ONNX GatherElements Operator to Krnl dialect.
 //
 //===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "src/Compiler/CompilerOptions.hpp"
 #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
 
@@ -56,7 +54,7 @@ struct ONNXGatherElementsOpLowering
     // Operands and attributes.
     Value data = adaptor.getData();
     Value indices = adaptor.getIndices();
-    int64_t axis = adaptor.getAxis();
+    int64_t axisLit = adaptor.getAxis();
     int64_t dataRank = mlir::cast(data.getType()).getRank();
     int64_t indicesRank = mlir::cast(indices.getType()).getRank();
     int64_t outputRank = outputMemRefType.getShape().size();
@@ -67,8 +65,12 @@ struct ONNXGatherElementsOpLowering
     bool indicesMayBeNegative = !indicesAreNonNegativeConstants(indices);
 
     // Negative value means counting dimensions from the back.
-    axis = axis < 0 ? axis + dataRank : axis;
+    axisLit = axisLit < 0 ? axisLit + dataRank : axisLit;
+
+    // Insert safety check code
+    genSafeCodeForGatherAlike(rewriter, loc, op, data, indices, axisLit);
+
+    LiteralIndexExpr zeroIE(0);
     DimsExpr dataDims, indicesDims;
     create.krnlIE.getShapeAsDims(data, dataDims);
     create.krnlIE.getShapeAsDims(indices, indicesDims);
@@ -83,6 +85,7 @@ struct ONNXGatherElementsOpLowering
         [&](const KrnlBuilder &createKrnl, ValueRange loopInd) {
           // Insert code inside the loop.
           IndexExprScope innerLoopScope(createKrnl);
+          SymbolIndexExpr axisDim(dataDims[axisLit]);
 
           // Access function for indices and output.
           DimsExpr accessFct;
@@ -93,41 +96,20 @@ struct ONNXGatherElementsOpLowering
           IndexExpr index = NonAffineIndexExpr(indexVal);
 
           if (indicesMayBeNegative) {
-            LiteralIndexExpr zero(0);
-            SymbolIndexExpr axisDim(dataDims[axis]);
-            index = index.selectOrSelf(index < zero, index + axisDim);
+            index = index.selectOrSelf(index < zeroIE, index + axisDim);
           }
 
           // Check the dynamic requirement of GatherElement Op
           // Refer to the comments in Gather.cpp
           if (enableSafeCodeGen) {
-            // From onnx document:
-            // All index values are expected to be within bounds [-s, s-1]
-            // along axis of size s. It is an error if any of the index values
-            // are out of bounds.
-            // After the negative correction, the range should be [0, s-1]
-            Value upperBound = create.mem.dim(data, axis);
-            Value compareUpperBound =
-                create.math.slt(index.getValue(), upperBound);
-            std::string nodeNameStr = op->getName().getStringRef().str() + " ";
-            StringAttr nodeName =
-                op->getAttrOfType("onnx_node_name");
-            if (nodeName && !nodeName.getValue().empty()) {
-              nodeNameStr = nodeNameStr + nodeName.getValue().str();
-            }
-            rewriter.create(loc, compareUpperBound,
-                "indices of GatherOp is larger than the upper bound");
-            LiteralIndexExpr zero(0);
-            Value compareLowerBound =
-                create.math.sge(index.getValue(), zero.getValue());
-            rewriter.create(loc, compareLowerBound,
-                "indices of GatherOp is less than the lower bound");
+            index = index.selectOrSelf(index < 0, zeroIE);
+            index = index.selectOrSelf(index >= axisDim, axisDim - 1);
           }
 
           // Access function for the 'data' tensor.
           DimsExpr dataAccessFct;
           for (int64_t i = 0; i < dataRank; ++i)
-            dataAccessFct.emplace_back((i == axis) ? index : accessFct[i]);
+            dataAccessFct.emplace_back((i == axisLit) ? index : accessFct[i]);
 
           // Gather values from the 'data' tensor and save them.
           Value dataVal = createKrnl.loadIE(data, dataAccessFct);
diff --git a/test/mlir/conversion/onnx_to_krnl/Tensor/GatherElements.mlir b/test/mlir/conversion/onnx_to_krnl/Tensor/GatherElements.mlir
index 85b7caef5b..8bcb1b1797 100644
--- a/test/mlir/conversion/onnx_to_krnl/Tensor/GatherElements.mlir
+++ b/test/mlir/conversion/onnx_to_krnl/Tensor/GatherElements.mlir
@@ -3,23 +3,27 @@
 func.func @test_gather_elements(%arg0: tensor<4xi64>, %arg1: tensor<2xi64>) -> tensor<2xi64> {
   %0 = "onnx.GatherElements"(%arg0, %arg1) : (tensor<4xi64>, tensor<2xi64>) -> tensor<2xi64>
   return %0 : tensor<2xi64>
-// CHECK-LABEL: @test_gather_elements
-// CHECK-SAME: ([[PARAM_0:%.+]]: memref<4xi64>, [[PARAM_1:%.+]]: memref<2xi64>) -> memref<2xi64> {
-// CHECK-DAG: [[RES:%.+]] = memref.alloc() {alignment = 16 : i64} : memref<2xi64>
-// CHECK-DAG: [[LOOP_0:%.+]] = krnl.define_loops 1
-// CHECK: krnl.iterate([[LOOP_0]]) with ([[LOOP_0]] -> [[I_0:%.+]] = 0 to 2){
-// CHECK: [[IV:%.+]] = krnl.get_induction_var_value([[LOOP_0]]) : (!krnl.loop) -> index
-// CHECK-DAG: [[LOAD_INDEX:%.+]] = krnl.load [[PARAM_1]]{{.*}}[[IV]]]{{.*}} : memref<2xi64>
-// CHECK-DAG: [[INDEX:%.+]] = arith.index_cast [[LOAD_INDEX]] : i64 to index
-// CHECK-DAG: [[CST_0:%.+]] = arith.constant 0 : index
-// CHECK-DAG: [[CST_4:%.+]] = arith.constant 4 : index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[CMP:%.+]] = arith.cmpi slt, [[INDEX]], [[CST_0]] : index
-// CHECK-DAG: [[VAR_1:%.+]] = arith.addi [[INDEX]], [[CST_4]] : index
-// CHECK: [[SEL:%.+]] = arith.select [[CMP]], [[VAR_1]], [[INDEX]] : index
-// CHECK: [[DATA_VAL:%.+]] = krnl.load [[PARAM_0]]{{.}}[[SEL]]{{.}} : memref<4xi64>
-// CHECK: krnl.store [[DATA_VAL]], [[RES]]{{.}}[[IV]]{{.}} : memref<2xi64>
+// CHECK-LABEL: func.func @test_gather_elements
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xi64>, [[PARAM_1_:%.+]]: memref<2xi64>) -> memref<2xi64> {
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
+// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<2xi64>
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
+// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : index
+// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1
+// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
+// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 2){
+// CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
+// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : index
+// CHECK: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_1_]]{{.}} : memref<2xi64>
+// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_]] : i64 to index
+// CHECK-DAG: [[VAR_4_:%.+]] = arith.cmpi slt, [[VAR_3_]], [[CST_0_]] : index
+// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_3_]], [[CST_4_1_]] : index
+// CHECK: [[VAR_6_:%.+]] = arith.select [[VAR_4_]], [[VAR_5_]], [[VAR_3_]] : index
+// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_6_]]{{.}} : memref<4xi64>
+// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<2xi64>
 // CHECK: }
-// CHECK: return [[RES]] : memref<2xi64>
+// CHECK: return [[RES_]] : memref<2xi64>
+// CHECK: }
 }
diff --git a/test/mlir/conversion/onnx_to_krnl/Tensor/Gather_safe.mlir b/test/mlir/conversion/onnx_to_krnl/Tensor/Gather_safe.mlir
new file mode 100644
index 0000000000..b6770c2a03
--- /dev/null
+++ b/test/mlir/conversion/onnx_to_krnl/Tensor/Gather_safe.mlir
@@ -0,0 +1,41 @@
+// RUN: onnx-mlir-opt --enable-safe-code-gen --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s
+
+// The CHECK lines for the print ops have to be modified by hand.
+func.func @test_gather_scalar(%arg0: tensor<4xi64>, %arg1: tensor<i64>) -> tensor<i64> {
+  %0 = "onnx.Gather"(%arg0, %arg1) {axis = 0 : si64} : (tensor<4xi64>, tensor<i64>) -> tensor<i64>
+  return %0 : tensor<i64>
+// CHECK-LABEL: func.func @test_gather_scalar
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xi64>, [[PARAM_1_:%.+]]: memref<i64>) -> memref<i64> {
+// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
+// CHECK-DAG: [[CST_minus_4_:%.+]] = arith.constant -4 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
+// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<i64>
+// CHECK: krnl.iterate() with (){
+// CHECK: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref<i64>
+// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_]] : i64 to index
+// CHECK-DAG: [[VAR_2_:%.+]] = arith.cmpi slt, [[VAR_1_]], [[CST_minus_4_]] : index
+// CHECK-DAG: [[VAR_3_:%.+]] = arith.cmpi sge, [[VAR_1_]], [[CST_4_]] : index
+// CHECK: [[VAR_4_:%.+]] = arith.ori [[VAR_2_]], [[VAR_3_]] : i1
+// CHECK: scf.if [[VAR_4_]] {
+// CHECK: "krnl.print"
+// CHECK: "krnl.print"
+// CHECK: }
+// CHECK: }
+// CHECK: krnl.iterate() with (){
+// CHECK: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[PARAM_1_]][] : memref<i64>
+// CHECK: [[VAR_1_1_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_1_]] : i64 to index
+// CHECK-DAG: [[VAR_2_1_:%.+]] = arith.cmpi slt, [[VAR_1_1_]], [[CST_0_]] : index
+// CHECK-DAG: [[VAR_3_1_:%.+]] = arith.addi [[VAR_1_1_]], [[CST_4_]] : index
+// CHECK: [[VAR_4_1_:%.+]] = arith.select [[VAR_2_1_]], [[VAR_3_1_]], [[VAR_1_1_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.cmpi slt, [[VAR_4_1_]], [[CST_0_]] : index
+// CHECK: [[VAR_6_:%.+]] = arith.select [[VAR_5_]], [[CST_0_]], [[VAR_4_1_]] : index
+// CHECK: [[VAR_7_:%.+]] = arith.cmpi sge, [[VAR_6_]], [[CST_4_]] : index
+// CHECK: [[VAR_8_:%.+]] = arith.select [[VAR_7_]], [[CST_3_]], [[VAR_6_]] : index
+// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_8_]]{{.}} : memref<4xi64>
+// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_]][] : memref<i64>
+// CHECK: }
+// CHECK: return [[RES_]] : memref<i64>
+// CHECK: }
+
+}
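
Note: the guard emitted by genSafeCodeForGatherAlike, combined with the clamping selects in the Gather and GatherElements lowerings above, amounts to the following per-element logic, where r is the size of data along `axis`. This is only an illustrative sketch of the generated runtime behavior, assuming indices may be negative; the helper name safeGatherIndex and the message wording are hypothetical and not code from onnx-mlir.

    #include <cstdint>
    #include <cstdio>

    // Warn when the raw index is outside [-r, r-1], then apply the
    // negative-index correction and clamp the result into [0, r-1]
    // before it is used to address the data tensor.
    int64_t safeGatherIndex(int64_t index, int64_t r) {
      if (index < -r || index >= r)
        std::printf("Warning: index %lld is out of bounds\n", (long long)index);
      if (index < 0)
        index += r; // negative-index correction
      if (index < 0)
        index = 0; // clamp below
      if (index >= r)
        index = r - 1; // clamp above
      return index;
    }

For example, with r = 4 the raw indices -9, -1, and 7 map to 0, 3, and 3 respectively, and only -9 and 7 trigger the warning.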