Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions include/onnxruntime/core/providers/utils/ort_graph_to_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi,
bool get_symbolic_dims,
/*out*/ ONNXTensorElementDataType& elem_type,
/*out*/ std::vector<int64_t>& dims,
/*out*/ std::vector<std::string>& symbolic_dims);
/*out*/ std::vector<std::string>& symbolic_dims,
/*out*/ bool& has_shape);
static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto);
static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr ort_attr, onnx::AttributeProto& attr_proto);

Expand Down Expand Up @@ -390,9 +391,10 @@ Ort::Status OrtGraphToProto(const OrtGraph& graph,
std::vector<int64_t> initializer_dims;
std::vector<std::string> initializer_sym_dims;
ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
bool has_shape = false;
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(initializer_value_info, /*get_sym_dims*/ false,
initializer_elem_type, initializer_dims,
initializer_sym_dims));
initializer_sym_dims, has_shape));

onnx::TensorProto* tensor_proto = graph_proto.add_initializer();
tensor_proto->set_name(initializer_name);
Expand Down Expand Up @@ -493,7 +495,8 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi,
bool get_symbolic_dims,
/*out*/ ONNXTensorElementDataType& elem_type,
/*out*/ std::vector<int64_t>& dims,
/*out*/ std::vector<std::string>& symbolic_dims) {
/*out*/ std::vector<std::string>& symbolic_dims,
/*out*/ bool& has_shape) {
try {
Ort::ConstTypeInfo ort_type_info = vi.TypeInfo();
ONNXType ort_onnx_type = ort_type_info.GetONNXType();
Expand All @@ -505,6 +508,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi,
size_t num_dims = ort_type_shape.GetDimensionsCount();
std::vector<int64_t> ort_dims = ort_type_shape.GetShape();

has_shape = ort_type_shape.GetHasShape();
elem_type = ort_elem_type;
dims = std::move(ort_dims);

Expand All @@ -531,10 +535,11 @@ static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info,
std::vector<int64_t> ort_dims;
std::vector<std::string> ort_dim_syms;
ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
bool has_shape = false;

// We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later.
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true,
ort_elem_type, ort_dims, ort_dim_syms));
ort_elem_type, ort_dims, ort_dim_syms, has_shape));

value_info_proto.set_name(ort_value_info.GetName());

Expand All @@ -543,7 +548,7 @@ static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info,

// If there are no dimensions in the shape, do not set a TensorShapeProto. Otherwise, it always looks
// like a scalar value.
if (!ort_dims.empty()) {
if (!ort_dims.empty() || has_shape) {
onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape();

for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) {
Expand Down
11 changes: 11 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -6580,6 +6580,17 @@ struct OrtApi {
_In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out);

/// @}
/// \name OrtTensorTypeAndShapeInfo
/// @{

/** \brief Get the attribute `has_shape` from ::OrtTensorTypeAndShapeInfo object
*
* \param[out] out Returns bool
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(GetHasShape, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ bool* out);
/// @}
};

/*
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1768,6 +1768,7 @@ struct TensorTypeAndShapeInfoImpl : Base<T> {
void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions
std::vector<const char*> GetSymbolicDimensions() const;

bool GetHasShape() const; ///< Wraps OrtApi::GetHasShape
std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
};

Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1982,6 +1982,13 @@ inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
return static_cast<size_t>(out);
}

template <typename T>
inline bool TensorTypeAndShapeInfoImpl<T>::GetHasShape() const {
bool out;
ThrowOnError(GetApi().GetHasShape(this->p_, &out));
return static_cast<bool>(out);
}

template <typename T>
inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
size_t out;
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/framework/tensor_type_and_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions,
return nullptr;
}

ORT_API_STATUS_IMPL(OrtApis::GetHasShape, _In_ const struct OrtTensorTypeAndShapeInfo* info,
_Out_ bool* out) {
*out = info->has_shape;
return nullptr;
}

ORT_API_STATUS_IMPL(OrtApis::SetSymbolicDimensions,
_In_ struct OrtTensorTypeAndShapeInfo* info,
_In_ const char** names, _In_ size_t dim_params_length) {
Expand Down Expand Up @@ -228,6 +234,7 @@ std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorS

if (dim_params != nullptr) {
type_and_shape->dim_params = *dim_params;
type_and_shape->has_shape = true;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Being able to set has_shape to true when dim_params is not null seems like a convenient coincidence. Perhaps it would be good to pass that information to this function? I haven't fully thought it through but perhaps something like the following could work:

  1. Add a has_shape parameter to this GetTensorShapeAndTypeHelper function.
  2. Callers of GetTensorShapeAndTypeHelper can use onnx::TypeProto_Tensor::has_shape() to pass this information to GetTensorShapeAndTypeHelper.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

abs_lostdim_case.zip
This is the test model, you can use the abs_0d_input.onnx as input. The abs_0d_lostdim.onnx is the dumped model after serialization.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For your suggestion, could you take a look at onnxruntime/core/framework/onnxruntime_typeinfo.cc:L269, L285:L310. Here already convert the HasShape to nullptr or a valid pointer. After you review that code, I can add it if you think a new parameter is necessary.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think a new parameter is better. Right now, using dim_params != nullptr seems like a coincidence. When you upstream to ORT, others may have other suggestions, but I do think this is preferable.

Note that the line utils::HasShape() internally calls onnx::Typeproto_Tensor::has_shape() (as I pointed out in my original message). This seems like the value we want to pass down to GetTensorShapeAndType().

One other thing to note: the GetTensorShapeAndTypeHelper function is also called in core/session/custom_ops.cc:L83. This location also has access to onnx::TypeProto_Tensor::has_shape().

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks again for fixing this and providing the model.

} else {
type_and_shape->dim_params.resize(type_and_shape->shape.NumDimensions(), "");
}
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/framework/tensor_type_and_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct OrtTensorTypeAndShapeInfo {
// dim_param values. empty string if dim_value or no dim_param was specified.
// one entry per dimension in shape. only guaranteed to be populated for graph inputs and outputs
std::vector<std::string> dim_params;
bool has_shape = false;

OrtTensorTypeAndShapeInfo();
~OrtTensorTypeAndShapeInfo();
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4228,6 +4228,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
&OrtApis::Graph_GetModelMetadata,
&OrtApis::GetModelCompatibilityForEpDevices,
&OrtApis::CreateExternalInitializerInfo,
&OrtApis::GetHasShape,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ ORT_API_STATUS_IMPL(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* in
ORT_API_STATUS_IMPL(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
ORT_API_STATUS_IMPL(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info,
_Out_writes_all_(dim_params_length) const char* dim_params[], size_t dim_params_length);
ORT_API_STATUS_IMPL(GetHasShape, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ bool* out);
ORT_API_STATUS_IMPL(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out);
ORT_API_STATUS_IMPL(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out);
ORT_API_STATUS_IMPL(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_maybenull_ OrtTypeInfo** out);
Expand Down
Loading