Skip to content

Commit

Permalink
Handle out-of-bound value for Gather alike operation (#3077)
Browse files Browse the repository at this point in the history
* implementation

Signed-off-by: Chen Tong <[email protected]>

* format

Signed-off-by: Chen Tong <[email protected]>

* response

Signed-off-by: Chen Tong <[email protected]>

* test

Signed-off-by: Chen Tong <[email protected]>

* comments

Signed-off-by: Chen Tong <[email protected]>

* retrigger check

Signed-off-by: Chen Tong <[email protected]>

---------

Signed-off-by: Chen Tong <[email protected]>
  • Loading branch information
chentong319 authored Feb 19, 2025
1 parent be4a2b8 commit 7b4829b
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 81 deletions.
73 changes: 73 additions & 0 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"

#include "src/Accelerators/Accelerator.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/Krnl/DialectBuilder.hpp"
#include "src/Dialect/Mlir/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
Expand Down Expand Up @@ -882,4 +884,75 @@ void impl::onnxToKrnlSimdReport(Operation *op, bool successful,
static_cast<long long int>(simdLoopTripCount));
}

// Emit a runtime check on the index values consumed by Gather-like ops.
//
// The Gather op is data dependent: ONNX semantics require every value of
// `indices` to be within [-s, s-1], where s is the size of `data` along the
// gathered axis. When enableSafeCodeGen is set, this function generates a loop
// nest over all elements of `indices` that prints a warning for each value
// outside of that range. When enableSafeCodeGen is not set, nothing is
// generated.
//
// This function only reports the offending value; it neither modifies
// `indices` nor aborts execution. Clamping of the index actually used for the
// load is done separately by the op lowerings.
//
// Implementation comments vs. createGenerateRuntimeVerificationPass:
// this check follows the ONNX op semantics and is not a general bound check
// for memref. The implementation of RuntimeVerification could be borrowed. The
// slight difference is that the ONNX semantics check is applied to one
// dimension independently, instead of verifying that the final address is
// within the memref bounds.
//
// Parameters:
//   rewriter: rewriter used to create the checking code.
//   loc:      location attached to the generated operations.
//   op:       operation being lowered; its name and, when present and
//             non-empty, its "onnx_node_name" attribute are included in the
//             warning message.
//   data:     memref that is gathered from.
//   indices:  memref holding the index values to verify.
//   axisLit:  axis of `data` that the index values refer to. It is used
//             directly to index the dims of `data`, so it must already be
//             normalized to a non-negative value by the caller.
void genSafeCodeForGatherAlike(mlir::ConversionPatternRewriter &rewriter,
    mlir::Location loc, Operation *op, Value data, Value indices,
    int64_t axisLit) {
  // Do nothing if not enabled.
  if (!enableSafeCodeGen)
    return;

  MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl> create(
      rewriter, loc);

  // The warning text depends only on the op being lowered, so build it once
  // here rather than inside the nested builder callbacks.
  std::string warningPrefix = "Warning: ";
  warningPrefix += op->getName().getStringRef().str() + " ";
  StringAttr nodeName = op->getAttrOfType<mlir::StringAttr>("onnx_node_name");
  if (nodeName && !nodeName.getValue().empty())
    warningPrefix += nodeName.getValue().str();
  const std::string outOfBoundMsg =
      warningPrefix + ": Value of indices is out of bound. " +
      "The out-of-bound indices value is: ";
  // The lowerings clamp an out-of-bound index to 0 or to (dim size - 1), so
  // the message reports clamping rather than a replacement with zero.
  const std::string clampMsg =
      "The out-of-bound index is clamped to the valid range.\n";

  // Check all the elements of indices.
  DimsExpr dataDims, indicesDims;
  create.krnlIE.getShapeAsDims(data, dataDims);
  create.krnlIE.getShapeAsDims(indices, indicesDims);
  SymbolIndexExpr axisDim(dataDims[axisLit]);
  int64_t indicesRank = mlir::cast<MemRefType>(indices.getType()).getRank();
  ValueRange loopDef = create.krnl.defineLoops(indicesRank);
  LiteralIndexExpr zeroIE(0);
  DimsExpr lbs(indicesRank, zeroIE);
  create.krnl.iterateIE(loopDef, loopDef, lbs, indicesDims,
      [&](const KrnlBuilder &createKrnl, ValueRange loopInd) {
        IndexExprScope innerLoopScope(createKrnl);

        // Access function for indices.
        DimsExpr accessFct;
        getIndexExprList<DimIndexExpr>(loopInd, accessFct);
        // Compute index = indices[i][j]...[n].
        Value indexVal = createKrnl.loadIE(indices, accessFct);
        IndexExpr index = NonAffineIndexExpr(indexVal);

        // The raw value loaded from indices (before any negative-value
        // correction) should be in the range [-r, r-1], where r is the dim
        // size of data[axis].
        Value errorCondition =
            ((index < (-1) * axisDim) | (index >= axisDim)).getValue();
        rewriter.create<scf::IfOp>(
            loc, errorCondition,
            /*thenBuilder=*/
            [&](OpBuilder &thenBuilder, Location thenLoc) {
              MultiDialectBuilder<KrnlBuilder> createThen(
                  thenBuilder, thenLoc);
              // Print the message followed by the offending value and a
              // newline, then the note about how the index is handled.
              createThen.krnl.printf(outOfBoundMsg, indexVal, true);
              createThen.krnl.printf(clampMsg);
              thenBuilder.create<scf::YieldOp>(thenLoc);
            },
            /*elseBuilder=*/nullptr);
      });
}

} // namespace onnx_mlir
9 changes: 9 additions & 0 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -770,5 +770,14 @@ void emitSymmetricQuantRecscaleToScalar(
mlir::Operation *op, mlir::Value input, uint64_t bitWidth,
mlir::Value &recscale, bool enableSIMD, bool enableParallel);

// Generate safe code for GatherOp and GatherElementsOp.
// Insert a runtime check for every value of `indices`; nothing is generated
// unless enableSafeCodeGen is set.
// If a value is outside the valid range [-s, s-1] of the `axisLit` dimension
// of the input `data` (s being the size of that dimension), a warning message
// is printed at runtime.
// This function only emits the check and the warning. Replacing the
// out-of-bound value, to avoid a crash or assertion error, is done by the
// lowering of each op.
void genSafeCodeForGatherAlike(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc, mlir::Operation *op, mlir::Value data,
mlir::Value indices, int64_t axisLit);

} // namespace onnx_mlir
#endif
41 changes: 7 additions & 34 deletions src/Conversion/ONNXToKrnl/Tensor/Gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
Expand Down Expand Up @@ -72,6 +71,9 @@ struct ONNXGatherOpLowering : public OpConversionPattern<ONNXGatherOp> {
// Negative value means counting dimensions from the back.
axisLit = axisLit < 0 ? axisLit + dataRank : axisLit;

// Check the value of indices and change it to zero if out-of-bound
genSafeCodeForGatherAlike(rewriter, loc, op, data, indices, axisLit);

int64_t outputRank = shapeHelper.getOutputDims().size();
int iIndexStart = 0;
int jIndexStart = iIndexStart + axisLit;
Expand Down Expand Up @@ -123,43 +125,14 @@ struct ONNXGatherOpLowering : public OpConversionPattern<ONNXGatherOp> {
Value indexVal = createKrnl.loadIE(indices, indicesAccessFct);
// Loaded value is an index that is not affine
IndexExpr index = NonAffineIndexExpr(indexVal);

// When index may be negative, add axis Dim to it.
if (indicesMayBeNegative)
index = index.selectOrSelf(index < zeroIE, index + axisDim);

// The Gather op is data dependent: the value of index should be
// within the input data size.
// Add runtime check if enableSafeCodeGen is set true
// Implementation comments vs. createGenerateRuntimeVerificationPass
// This check is according to onnx op semantics, not general bound
// check for memref. Implementation of RuntimeVerification could be
// borrowed. Slightly difference is that onnx semenatics check is for
// each dimension independently, not the final address is within
// the memref bound.
if (enableSafeCodeGen) {
// From onnx document:
// All index values are expected to be within bounds [-s, s-1]
// along axis of size s. It is an error if any of the index values
// are out of bounds.
// After the negative correction, the range should be [0, s-1]
Value upperBound = create.mem.dim(data, axisLit);
Value compareUpperBound =
create.math.slt(index.getValue(), upperBound);
// Report onnx_node_name if the op has the attribute
std::string nodeNameStr = op->getName().getStringRef().str() + " ";
StringAttr nodeName =
op->getAttrOfType<mlir::StringAttr>("onnx_node_name");
if (nodeName && !nodeName.getValue().empty()) {
nodeNameStr = nodeNameStr + nodeName.getValue().str();
}
rewriter.create<cf::AssertOp>(loc, compareUpperBound,
nodeNameStr +
" indices of GatherOp is larger than the upper bound");
Value compareLowerBound =
create.math.sge(index.getValue(), zeroIE.getValue());
rewriter.create<cf::AssertOp>(loc, compareLowerBound,
nodeNameStr +
" indices of GatherOp is less than the lower bound");
index = index.selectOrSelf(index < 0, zeroIE);
index = index.selectOrSelf(index >= axisDim, axisDim - 1);
}

// Compute access function of data: data[ii + (indices[jj],) + kk]
Expand Down
42 changes: 12 additions & 30 deletions src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
// This file lowers the ONNX GatherElements Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
Expand Down Expand Up @@ -56,7 +54,7 @@ struct ONNXGatherElementsOpLowering
// Operands and attributes.
Value data = adaptor.getData();
Value indices = adaptor.getIndices();
int64_t axis = adaptor.getAxis();
int64_t axisLit = adaptor.getAxis();
int64_t dataRank = mlir::cast<MemRefType>(data.getType()).getRank();
int64_t indicesRank = mlir::cast<MemRefType>(indices.getType()).getRank();
int64_t outputRank = outputMemRefType.getShape().size();
Expand All @@ -67,8 +65,12 @@ struct ONNXGatherElementsOpLowering
bool indicesMayBeNegative = !indicesAreNonNegativeConstants(indices);

// Negative value means counting dimensions from the back.
axis = axis < 0 ? axis + dataRank : axis;
axisLit = axisLit < 0 ? axisLit + dataRank : axisLit;

// Insert safety check code
genSafeCodeForGatherAlike(rewriter, loc, op, data, indices, axisLit);

LiteralIndexExpr zeroIE(0);
DimsExpr dataDims, indicesDims;
create.krnlIE.getShapeAsDims(data, dataDims);
create.krnlIE.getShapeAsDims(indices, indicesDims);
Expand All @@ -83,6 +85,7 @@ struct ONNXGatherElementsOpLowering
[&](const KrnlBuilder &createKrnl, ValueRange loopInd) {
// Insert code inside the loop.
IndexExprScope innerLoopScope(createKrnl);
SymbolIndexExpr axisDim(dataDims[axisLit]);

// Access function for indices and output.
DimsExpr accessFct;
Expand All @@ -93,41 +96,20 @@ struct ONNXGatherElementsOpLowering
IndexExpr index = NonAffineIndexExpr(indexVal);

if (indicesMayBeNegative) {
LiteralIndexExpr zero(0);
SymbolIndexExpr axisDim(dataDims[axis]);
index = index.selectOrSelf(index < zero, index + axisDim);
index = index.selectOrSelf(index < zeroIE, index + axisDim);
}

// Check the dynamic requirement of GatherElement Op
// Refer to the comments in Gather.cpp
if (enableSafeCodeGen) {
// From onnx document:
// All index values are expected to be within bounds [-s, s-1]
// along axis of size s. It is an error if any of the index values
// are out of bounds.
// After the negative correction, the range should be [0, s-1]
Value upperBound = create.mem.dim(data, axis);
Value compareUpperBound =
create.math.slt(index.getValue(), upperBound);
std::string nodeNameStr = op->getName().getStringRef().str() + " ";
StringAttr nodeName =
op->getAttrOfType<mlir::StringAttr>("onnx_node_name");
if (nodeName && !nodeName.getValue().empty()) {
nodeNameStr = nodeNameStr + nodeName.getValue().str();
}
rewriter.create<cf::AssertOp>(loc, compareUpperBound,
"indices of GatherOp is larger than the upper bound");
LiteralIndexExpr zero(0);
Value compareLowerBound =
create.math.sge(index.getValue(), zero.getValue());
rewriter.create<cf::AssertOp>(loc, compareLowerBound,
"indices of GatherOp is less than the lower bound");
index = index.selectOrSelf(index < 0, zeroIE);
index = index.selectOrSelf(index >= axisDim, axisDim - 1);
}

// Access function for the 'data' tensor.
DimsExpr dataAccessFct;
for (int64_t i = 0; i < dataRank; ++i)
dataAccessFct.emplace_back((i == axis) ? index : accessFct[i]);
dataAccessFct.emplace_back((i == axisLit) ? index : accessFct[i]);

// Gather values from the 'data' tensor and save them.
Value dataVal = createKrnl.loadIE(data, dataAccessFct);
Expand Down
38 changes: 21 additions & 17 deletions test/mlir/conversion/onnx_to_krnl/Tensor/GatherElements.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,27 @@
func.func @test_gather_elements(%arg0: tensor<4xi64>, %arg1: tensor<2xi64>) -> tensor<2xi64> {
%0 = "onnx.GatherElements"(%arg0, %arg1) : (tensor<4xi64>, tensor<2xi64>) -> tensor<2xi64>
return %0 : tensor<2xi64>
// CHECK-LABEL: @test_gather_elements
// CHECK-SAME: ([[PARAM_0:%.+]]: memref<4xi64>, [[PARAM_1:%.+]]: memref<2xi64>) -> memref<2xi64> {
// CHECK-DAG: [[RES:%.+]] = memref.alloc() {alignment = 16 : i64} : memref<2xi64>
// CHECK-DAG: [[LOOP_0:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_0]]) with ([[LOOP_0]] -> [[I_0:%.+]] = 0 to 2){
// CHECK: [[IV:%.+]] = krnl.get_induction_var_value([[LOOP_0]]) : (!krnl.loop) -> index
// CHECK-DAG: [[LOAD_INDEX:%.+]] = krnl.load [[PARAM_1]]{{.*}}[[IV]]]{{.*}} : memref<2xi64>
// CHECK-DAG: [[INDEX:%.+]] = arith.index_cast [[LOAD_INDEX]] : i64 to index
// CHECK-DAG: [[CST_0:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_4:%.+]] = arith.constant 4 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[CMP:%.+]] = arith.cmpi slt, [[INDEX]], [[CST_0]] : index
// CHECK-DAG: [[VAR_1:%.+]] = arith.addi [[INDEX]], [[CST_4]] : index
// CHECK: [[SEL:%.+]] = arith.select [[CMP]], [[VAR_1]], [[INDEX]] : index
// CHECK: [[DATA_VAL:%.+]] = krnl.load [[PARAM_0]]{{.}}[[SEL]]{{.}} : memref<4xi64>
// CHECK: krnl.store [[DATA_VAL]], [[RES]]{{.}}[[IV]]{{.}} : memref<2xi64>
// CHECK-LABEL: func.func @test_gather_elements
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xi64>, [[PARAM_1_:%.+]]: memref<2xi64>) -> memref<2xi64> {
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<2xi64>
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1
// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 2){
// CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : index
// CHECK: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_1_]]{{.}} : memref<2xi64>
// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_]] : i64 to index
// CHECK-DAG: [[VAR_4_:%.+]] = arith.cmpi slt, [[VAR_3_]], [[CST_0_]] : index
// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_3_]], [[CST_4_1_]] : index
// CHECK: [[VAR_6_:%.+]] = arith.select [[VAR_4_]], [[VAR_5_]], [[VAR_3_]] : index
// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_6_]]{{.}} : memref<4xi64>
// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<2xi64>
// CHECK: }
// CHECK: return [[RES]] : memref<2xi64>
// CHECK: return [[RES_]] : memref<2xi64>
// CHECK: }
}

41 changes: 41 additions & 0 deletions test/mlir/conversion/onnx_to_krnl/Tensor/Gather_safe.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: onnx-mlir-opt --enable-safe-code-gen --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s

// Lowering of onnx.Gather with a scalar index into a 4-element tensor when
// safe code generation is enabled. Two krnl.iterate nests are expected:
// the first one tests the loaded index against the range [-4, 4) and prints
// a warning inside scf.if; the second one performs the gather with the index
// wrapped (+4 when negative) and then clamped to [0, 3].
// The check directives for krnl.print have to be modified manually.
func.func @test_gather_scalar(%arg0: tensor<4xi64>, %arg1: tensor<i64>) -> tensor<i64> {
%0 = "onnx.Gather"(%arg0, %arg1) {axis = 0 : si64} : (tensor<4xi64>, tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
// CHECK-LABEL: func.func @test_gather_scalar
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xi64>, [[PARAM_1_:%.+]]: memref<i64>) -> memref<i64> {
// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[CST_minus_4_:%.+]] = arith.constant -4 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<i64>
// First loop nest: runtime range test of the index, with warning prints.
// CHECK: krnl.iterate() with (){
// CHECK: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref<i64>
// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_]] : i64 to index
// CHECK-DAG: [[VAR_2_:%.+]] = arith.cmpi slt, [[VAR_1_]], [[CST_minus_4_]] : index
// CHECK-DAG: [[VAR_3_:%.+]] = arith.cmpi sge, [[VAR_1_]], [[CST_4_]] : index
// CHECK: [[VAR_4_:%.+]] = arith.ori [[VAR_2_]], [[VAR_3_]] : i1
// CHECK: scf.if [[VAR_4_]] {
// CHECK: "krnl.print"
// CHECK: "krnl.print"
// CHECK: }
// CHECK: }
// Second loop nest: the gather itself, using the wrapped and clamped index.
// CHECK: krnl.iterate() with (){
// CHECK: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[PARAM_1_]][] : memref<i64>
// CHECK: [[VAR_1_1_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_1_]] : i64 to index
// CHECK-DAG: [[VAR_2_1_:%.+]] = arith.cmpi slt, [[VAR_1_1_]], [[CST_0_]] : index
// CHECK-DAG: [[VAR_3_1_:%.+]] = arith.addi [[VAR_1_1_]], [[CST_4_]] : index
// CHECK: [[VAR_4_1_:%.+]] = arith.select [[VAR_2_1_]], [[VAR_3_1_]], [[VAR_1_1_]] : index
// CHECK: [[VAR_5_:%.+]] = arith.cmpi slt, [[VAR_4_1_]], [[CST_0_]] : index
// CHECK: [[VAR_6_:%.+]] = arith.select [[VAR_5_]], [[CST_0_]], [[VAR_4_1_]] : index
// CHECK: [[VAR_7_:%.+]] = arith.cmpi sge, [[VAR_6_]], [[CST_4_]] : index
// CHECK: [[VAR_8_:%.+]] = arith.select [[VAR_7_]], [[CST_3_]], [[VAR_6_]] : index
// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_8_]]{{.}} : memref<4xi64>
// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_]][] : memref<i64>
// CHECK: }
// CHECK: return [[RES_]] : memref<i64>
// CHECK: }

}

0 comments on commit 7b4829b

Please sign in to comment.