@@ -227,14 +227,15 @@ ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(
227
227
std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper (
228
228
ONNXTensorElementDataType type,
229
229
onnxruntime::TensorShape shape,
230
- const std::vector<std::string>* dim_params) {
230
+ const std::vector<std::string>* dim_params,
231
+ bool has_shape) {
231
232
auto type_and_shape = std::make_unique<OrtTensorTypeAndShapeInfo>();
232
233
type_and_shape->type = type;
233
234
type_and_shape->shape = std::move (shape);
235
+ type_and_shape->has_shape = has_shape;
234
236
235
237
if (dim_params != nullptr ) {
236
238
type_and_shape->dim_params = *dim_params;
237
- type_and_shape->has_shape = true ;
238
239
} else {
239
240
type_and_shape->dim_params .resize (type_and_shape->shape .NumDimensions (), " " );
240
241
}
@@ -244,18 +245,20 @@ std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorS
244
245
245
246
std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorShapeAndType (
246
247
onnxruntime::TensorShape shape,
247
- const onnxruntime::DataTypeImpl& tensor_data_type) {
248
+ const onnxruntime::DataTypeImpl& tensor_data_type,
249
+ bool has_shape) {
248
250
ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType (&tensor_data_type);
249
251
if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
250
252
ORT_NOT_IMPLEMENTED (" Tensor type is undefined" );
251
253
}
252
- return GetTensorShapeAndTypeHelper (type, std::move (shape), nullptr );
254
+ return GetTensorShapeAndTypeHelper (type, std::move (shape), nullptr , has_shape );
253
255
}
254
256
255
257
std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorShapeAndType (
256
258
onnxruntime::TensorShape shape,
257
259
const std::vector<std::string>* dim_params,
258
- const ONNX_NAMESPACE::TypeProto& type_proto) {
260
+ const ONNX_NAMESPACE::TypeProto& type_proto,
261
+ bool has_shape) {
259
262
auto value_case = type_proto.value_case ();
260
263
assert (value_case == ONNX_NAMESPACE::TypeProto::kTensorType ||
261
264
value_case == ONNX_NAMESPACE::TypeProto::kSparseTensorType );
@@ -266,7 +269,8 @@ std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorS
266
269
if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
267
270
ORT_NOT_IMPLEMENTED (" Tensor type is undefined" );
268
271
}
269
- return GetTensorShapeAndTypeHelper (type, std::move (shape), dim_params);
272
+
273
+ return GetTensorShapeAndTypeHelper (type, std::move (shape), dim_params, has_shape);
270
274
}
271
275
272
276
ORT_API_STATUS_IMPL (OrtApis::GetTensorTypeAndShape,
@@ -283,14 +287,14 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape,
283
287
const Tensor& tensor = v->Get <onnxruntime::Tensor>();
284
288
shape = &tensor.Shape ();
285
289
data_type = tensor.DataType ();
286
- auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType (*shape, *data_type);
290
+ auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType (*shape, *data_type, true );
287
291
*out = ptr.release ();
288
292
} else {
289
293
#if !defined(DISABLE_SPARSE_TENSORS)
290
294
const SparseTensor& tensor = v->Get <onnxruntime::SparseTensor>();
291
295
shape = &tensor.DenseShape ();
292
296
data_type = tensor.DataType ();
293
- auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType (*shape, *data_type);
297
+ auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType (*shape, *data_type, true );
294
298
*out = ptr.release ();
295
299
#else
296
300
ORT_NOT_IMPLEMENTED (" SparseTensor is not supported in this build." );
@@ -309,7 +313,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorValuesTypeAndShape, _In_ const OrtVa
309
313
#if !defined(DISABLE_SPARSE_TENSORS)
310
314
const auto & sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue (*v);
311
315
const auto & values = sparse_tensor.Values ();
312
- auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType (values.Shape (), *values.DataType ());
316
+ auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType (values.Shape (), *values.DataType (), true );
313
317
*out = ptr.release ();
314
318
return nullptr ;
315
319
#else
@@ -351,7 +355,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorIndicesTypeShape, _In_ const OrtValu
351
355
API_IMPL_BEGIN
352
356
#if !defined(DISABLE_SPARSE_TENSORS)
353
357
const Tensor& indices_tensor = GetIndicesTensor (*v, indices_format);
354
- auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType (indices_tensor.Shape (), *indices_tensor.DataType ());
358
+ auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType (indices_tensor.Shape (), *indices_tensor.DataType (), true );
355
359
*out = ptr.release ();
356
360
return nullptr ;
357
361
#else
0 commit comments