Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions src/input_parsers/OnnxParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,12 +411,26 @@ Vector<int> getTensorIntValues( const onnx::TensorProto& tensor, const TensorSha
{
checkEndianness();
const char* bytes = raw_data.c_str();
const int* ints = reinterpret_cast<const int*>( bytes );
for ( int i = 0; i < size; i++ )
{
int value = *(ints + i);
result.append( value );
}
onnx::TensorProto_DataType dataType = static_cast<onnx::TensorProto_DataType>( tensor.data_type() );

if (dataType == onnx::TensorProto_DataType_INT32) {
const int32_t* ints = reinterpret_cast<const int32_t*>( bytes );
for ( int i = 0; i < size; i++ )
{
int value = *(ints + i);
result.append( value );
}
} else if (dataType == onnx::TensorProto_DataType_INT64) {
const int64_t* ints = reinterpret_cast<const int64_t*>( bytes );
for ( int i = 0; i < size; i++ )
{
int value = *(ints + i);
result.append( value );
}
} else {
String errorMessage = Stringf( "Illegal data type for integer tensors used in the model. Only INT32 and INT64 supported." );
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be "unsupported" not "illegal"?

throw MarabouError( MarabouError::ONNX_PARSER_ERROR, errorMessage.ascii() );
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you also need to add a corresponding case for the else below when the data is not stored in raw_data. I'd advice lifting the dataType check above the raw_data check.

}
else
{
Expand Down Expand Up @@ -1239,6 +1253,13 @@ void OnnxParser::convEquations( onnx::NodeProto& node, [[maybe_unused]] bool mak
// First input should be variable tensor
String inputNodeName = node.input()[0];
TensorShape inputShape = _shapeMap[inputNodeName];

// Added: Check if convolutional layers are only 2D
if (inputShape.size()!=4) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very minor but this should be laid out as:

if ( inputShape.size() != 4 ) 
{

and tabs should be four spaces. Can you check that this style is used everywhere?

String errorMessage = Stringf( "Onnx '%s' operation has an unsupported number of dimensions -- only 2D convolutions are supported.", node.op_type().c_str() ) ;
throw MarabouError( MarabouError::ONNX_PARSER_ERROR, errorMessage.ascii() ) ;
}

[[maybe_unused]] unsigned int inputChannels = inputShape[1];
unsigned int inputWidth = inputShape[2];
unsigned int inputHeight = inputShape[3];
Expand Down