diff --git a/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp b/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp index 189d855805..701e03721d 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp @@ -76,6 +76,10 @@ LogicalResult ONNXScatterElementsOp::verify() { if (dataDimAtAxis >= 0) { if (ElementsAttr valueAttribute = getElementAttributeFromONNXValue(indices)) { + if (isElementAttrUninitializedDenseResource(valueAttribute)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } for (IntegerAttr value : valueAttribute.getValues()) { int64_t index = value.getInt(); if (index >= -dataDimAtAxis && index < dataDimAtAxis) diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp index 36cefe7675..0468919038 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Path.h" @@ -872,4 +873,15 @@ std::string getNodeNameInPresenceOfOpt(Operation *op, bool useFileLine) { return "NOTSET"; } +//===----------------------------------------------------------------------===// +// Support for DenseElementsAttr. +//===----------------------------------------------------------------------===// + +bool isElementAttrUninitializedDenseResource(mlir::ElementsAttr elementsAttr) { + const auto denseResourceElementsAttr = + mlir::dyn_cast(elementsAttr); + return denseResourceElementsAttr && + !denseResourceElementsAttr.getRawHandle().getBlob(); +} + } // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp index 68976fe05b..e3a364022c 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp @@ -380,6 +380,14 @@ bool isIdentityReshape(mlir::Value input, mlir::Value output, std::string getNodeNameInPresenceOfOpt( mlir::Operation *op, bool useFileLine = true); +//===----------------------------------------------------------------------===// +// Support for DenseElementsAttr. +//===----------------------------------------------------------------------===// + +/// Returns true if elementsAttr is a DenseResourceAttr with a blob that can not +/// be received +bool isElementAttrUninitializedDenseResource(mlir::ElementsAttr elementsAttr); + #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc" } // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp b/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp index 3a17990e56..38f922f765 100644 --- a/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp +++ b/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp @@ -58,6 +58,10 @@ LogicalResult ONNXSplitToSequenceOp::verify() { if (splitRank > 1) return emitOpError() << ": split has rank " << splitRank << " > 1"; if (ElementsAttr entries = getElementAttributeFromONNXValue(splitValue)) { + if (isElementAttrUninitializedDenseResource(entries)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } if (splitRank == 0) { auto scalar = getScalarValue(entries, splitType); if (scalar <= 0) diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp index 787fc9b75e..6058adfcdb 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp @@ -70,6 +70,10 @@ LogicalResult ONNXConstantOfShapeOp::verify() { if (auto constantOp = getONNXConstantOp(input)) { ElementsAttr valueAttribute = mlir::cast(constantOp.getValueAttr()); + if (isElementAttrUninitializedDenseResource(valueAttribute)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } // Get repeat values from valueAttribute. auto valueIt = valueAttribute.getValues().begin(); for (int i = 0; i < inputShape[0]; ++i) { diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp index ce35ad81b3..dde8029994 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp @@ -71,8 +71,13 @@ LogicalResult ONNXGatherElementsOp::verify() { // along axis of size s. ArrayRef dataShape = dataType.getShape(); const int64_t dataDimAtAxis = dataShape[axis]; - if (dataDimAtAxis >= 0) - if (ElementsAttr valueAttribute = getElementAttributeFromONNXValue(indices)) + if (dataDimAtAxis >= 0) { + if (ElementsAttr valueAttribute = + getElementAttributeFromONNXValue(indices)) { + if (isElementAttrUninitializedDenseResource(valueAttribute)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } for (IntegerAttr value : valueAttribute.getValues()) { int64_t index = value.getInt(); if (index >= -dataDimAtAxis && index < dataDimAtAxis) @@ -83,6 +88,8 @@ LogicalResult ONNXGatherElementsOp::verify() { onnx_mlir::Diagnostic::Range( -dataDimAtAxis, dataDimAtAxis - 1)); } + } + } return success(); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp index f5cf329cd0..b388607c12 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp @@ -144,6 +144,10 @@ LogicalResult ONNXGatherNDOp::verify() { // All values in 'indices' are expected to satisfy the inequality: // -data.shape[b + i] <= indices[...,i] <= (data.shape[b + i]-1)]. if (ElementsAttr valueAttribute = getElementAttributeFromONNXValue(indices)) { + if (isElementAttrUninitializedDenseResource(valueAttribute)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } int flatIndex = 0; for (IntegerAttr value : valueAttribute.getValues()) { int64_t indexValue = value.getInt(); diff --git a/test/mlir/onnx/invalid.mlir b/test/mlir/onnx/invalid.mlir index f91d261eaa..7b8d087c03 100644 --- a/test/mlir/onnx/invalid.mlir +++ b/test/mlir/onnx/invalid.mlir @@ -182,6 +182,15 @@ func.func @test_constantofshape_verifier_4() -> tensor<2xi64> { // ----- +func.func @test_constantofshape_elided() -> tensor<2xi64> { + // Tests that we do not crash on elided elements + %0 = onnx.Constant dense_resource<__elided__> : tensor<2xi64> + %1 = "onnx.ConstantOfShape"(%0) : (tensor<2xi64>) -> tensor<2xi64> + "onnx.Return"(%1) : (tensor<2xi64>) -> () +} + +// ----- + func.func @test_flatten_verifier_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { // expected-error @+1 {{onnx.Flatten: 'axis' value is 5, accepted range is [-4, 4]}} %1 = "onnx.Flatten"(%arg0) {axis = 5 : si64} : (tensor<5x5x1x32xf32>) -> tensor<*xf32> @@ -214,6 +223,15 @@ func.func @test_gatherElements_verifier_2(%data: tensor<2x2xf32>, %indices: tens // ----- +func.func @test_gatherElements_verifier_elided(%data: tensor<12x14x1024xf32>) -> tensor<12x14x14xf32> { + // Tests that we do not crash on elided elements + %indices = onnx.Constant dense_resource<__elided__> : tensor<12x14x14xi64> + %1 = "onnx.GatherElements"(%data, %indices) {axis = -1 : si64} : (tensor<12x14x1024xf32>, tensor<12x14x14xi64>) -> tensor<12x14x14xf32> + "onnx.Return"(%1) : (tensor<12x14x14xf32>) -> () +} + +// ----- + func.func @test_hardmax_verifier_1(%arg0: tensor<2x2xf32>) -> tensor<*xf32> { // expected-error @+1 {{onnx.Hardmax: 'axis' value is 3, accepted range is [-2, 1]}} %1 = "onnx.Hardmax"(%arg0) {axis = 3: si64} : (tensor<2x2xf32>) -> tensor<*xf32> @@ -307,6 +325,16 @@ func.func @test_gatherND_verifier_6(%arg0 : tensor<3x4x4x4xf32>) -> tensor<*xf32 // expected-error @+2 {{onnx.GatherND: 'indices[0]' value is 3, accepted range is [-3, 2]}} %indices = "onnx.Constant"() {value = dense<[3,2,2]> : tensor<3xi64>} : () -> tensor<3x3x2xi64> %1 = "onnx.GatherND"(%arg0, %indices) : (tensor<3x4x4x4xf32>, tensor<3x3x2xi64>) -> tensor<*xf32> + "onnx.Return"(%1) : (tensor<*xf32>) -> () +} + +// ----- + +func.func @test_gatherND_verifier_elided(%arg0 : tensor<3x4x4x4xf32>) -> tensor<*xf32> { + // Test that we do not crash on elided elements + %indices = onnx.Constant dense_resource<__elided__> : tensor<3x3x2xi64> + %1 = "onnx.GatherND"(%arg0, %indices) : (tensor<3x4x4x4xf32>, tensor<3x3x2xi64>) -> tensor<*xf32> + "onnx.Return"(%1) : (tensor<*xf32>) -> () } // ----- @@ -580,6 +608,15 @@ func.func @test_splitToSequence_verifier_6(%arg0: tensor<2x2xf32>) -> !onnx.Seq< // ----- +func.func @test_splitToSequence_verifier_elided(%arg0: tensor<2x2xf32>) -> !onnx.Seq> { + // Tests that we do not crash on elided elements + %0 = onnx.Constant dense_resource<__elided__> : tensor + %1 = "onnx.SplitToSequence"(%arg0, %0) : (tensor<2x2xf32>, tensor) -> !onnx.Seq> + "onnx.Return"(%1) : (!onnx.Seq>) -> () +} + +// ----- + func.func @test_topK_verifier_1(%arg0: tensor<3x4xi64>, %arg1: tensor<1xi64>) -> (tensor<*xf32>, tensor<*xi64>) { // expected-error @+1 {{onnx.TopK: 'axis' value is 2, accepted range is [-2, 1]}} %1, %2 = "onnx.TopK"(%arg0, %arg1) {axis = 2 : si64, largest = 1 : si64, sorted = 1 : si64} : (tensor<3x4xi64>, tensor<1xi64>) -> (tensor<*xf32>, tensor<*xi64>)