Skip to content

Commit d900755

Browse files
committed
Change has_shape to parameter
1 parent 0e66c18 commit d900755

File tree

4 files changed

+26
-19
lines changed

4 files changed

+26
-19
lines changed

onnxruntime/core/framework/onnxruntime_typeinfo.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ std::unique_ptr<OrtTypeInfo> OrtTypeInfo::FromOrtValue(const OrtValue& value) {
170170
const Tensor& tensor = value.Get<onnxruntime::Tensor>();
171171
const auto* tensor_data_type = tensor.DataType();
172172
if (tensor_data_type != nullptr) {
173-
auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.Shape(), *tensor_data_type);
173+
auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.Shape(), *tensor_data_type, true);
174174
return MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape));
175175
}
176176
return MakePtr(ONNX_TYPE_TENSOR);
@@ -181,7 +181,7 @@ std::unique_ptr<OrtTypeInfo> OrtTypeInfo::FromOrtValue(const OrtValue& value) {
181181
const SparseTensor& tensor = value.Get<onnxruntime::SparseTensor>();
182182
const auto* tensor_data_type = tensor.DataType();
183183
if (tensor_data_type != nullptr) {
184-
auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.DenseShape(), *tensor_data_type);
184+
auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.DenseShape(), *tensor_data_type, true);
185185
return MakePtr(ONNX_TYPE_SPARSETENSOR, std::move(type_shape));
186186
}
187187
return MakePtr(ONNX_TYPE_SPARSETENSOR);
@@ -195,7 +195,7 @@ std::unique_ptr<OrtTypeInfo> OrtTypeInfo::FromOrtValue(const OrtValue& value) {
195195
ORT_ENFORCE(tensor_data_type != nullptr, "OrtValue is TensorSequence type but has no element Tensor DataType.");
196196

197197
TensorShape void_shape = {};
198-
auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(void_shape, *tensor_data_type);
198+
auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(void_shape, *tensor_data_type, false);
199199
auto type_info = MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape));
200200
auto sequence_type_info = std::make_unique<OrtSequenceTypeInfo>(std::move(type_info));
201201
return MakePtr(std::move(sequence_type_info));
@@ -303,9 +303,9 @@ std::unique_ptr<OrtTypeInfo> OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::Ty
303303
assert(false);
304304
}
305305
}
306-
type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(std::move(shape_data), &dim_params, input);
306+
type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(std::move(shape_data), &dim_params, input, true);
307307
} else {
308-
type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(TensorShape(), nullptr, input);
308+
type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(TensorShape(), nullptr, input, false);
309309
}
310310

311311
result = MakePtr(ten_type, std::move(type_shape));

onnxruntime/core/framework/tensor_type_and_shape.cc

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -227,14 +227,15 @@ ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(
227227
std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(
228228
ONNXTensorElementDataType type,
229229
onnxruntime::TensorShape shape,
230-
const std::vector<std::string>* dim_params) {
230+
const std::vector<std::string>* dim_params,
231+
bool has_shape) {
231232
auto type_and_shape = std::make_unique<OrtTensorTypeAndShapeInfo>();
232233
type_and_shape->type = type;
233234
type_and_shape->shape = std::move(shape);
235+
type_and_shape->has_shape = has_shape;
234236

235237
if (dim_params != nullptr) {
236238
type_and_shape->dim_params = *dim_params;
237-
type_and_shape->has_shape = true;
238239
} else {
239240
type_and_shape->dim_params.resize(type_and_shape->shape.NumDimensions(), "");
240241
}
@@ -244,18 +245,20 @@ std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorS
244245

245246
std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(
246247
onnxruntime::TensorShape shape,
247-
const onnxruntime::DataTypeImpl& tensor_data_type) {
248+
const onnxruntime::DataTypeImpl& tensor_data_type,
249+
bool has_shape) {
248250
ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(&tensor_data_type);
249251
if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
250252
ORT_NOT_IMPLEMENTED("Tensor type is undefined");
251253
}
252-
return GetTensorShapeAndTypeHelper(type, std::move(shape), nullptr);
254+
return GetTensorShapeAndTypeHelper(type, std::move(shape), nullptr, has_shape);
253255
}
254256

255257
std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(
256258
onnxruntime::TensorShape shape,
257259
const std::vector<std::string>* dim_params,
258-
const ONNX_NAMESPACE::TypeProto& type_proto) {
260+
const ONNX_NAMESPACE::TypeProto& type_proto,
261+
bool has_shape) {
259262
auto value_case = type_proto.value_case();
260263
assert(value_case == ONNX_NAMESPACE::TypeProto::kTensorType ||
261264
value_case == ONNX_NAMESPACE::TypeProto::kSparseTensorType);
@@ -266,7 +269,8 @@ std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorS
266269
if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
267270
ORT_NOT_IMPLEMENTED("Tensor type is undefined");
268271
}
269-
return GetTensorShapeAndTypeHelper(type, std::move(shape), dim_params);
272+
273+
return GetTensorShapeAndTypeHelper(type, std::move(shape), dim_params, has_shape);
270274
}
271275

272276
ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape,
@@ -283,14 +287,14 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape,
283287
const Tensor& tensor = v->Get<onnxruntime::Tensor>();
284288
shape = &tensor.Shape();
285289
data_type = tensor.DataType();
286-
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type);
290+
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type, true);
287291
*out = ptr.release();
288292
} else {
289293
#if !defined(DISABLE_SPARSE_TENSORS)
290294
const SparseTensor& tensor = v->Get<onnxruntime::SparseTensor>();
291295
shape = &tensor.DenseShape();
292296
data_type = tensor.DataType();
293-
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type);
297+
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type, true);
294298
*out = ptr.release();
295299
#else
296300
ORT_NOT_IMPLEMENTED("SparseTensor is not supported in this build.");
@@ -309,7 +313,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorValuesTypeAndShape, _In_ const OrtVa
309313
#if !defined(DISABLE_SPARSE_TENSORS)
310314
const auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(*v);
311315
const auto& values = sparse_tensor.Values();
312-
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(values.Shape(), *values.DataType());
316+
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(values.Shape(), *values.DataType(), true);
313317
*out = ptr.release();
314318
return nullptr;
315319
#else
@@ -351,7 +355,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorIndicesTypeShape, _In_ const OrtValu
351355
API_IMPL_BEGIN
352356
#if !defined(DISABLE_SPARSE_TENSORS)
353357
const Tensor& indices_tensor = GetIndicesTensor(*v, indices_format);
354-
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(indices_tensor.Shape(), *indices_tensor.DataType());
358+
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(indices_tensor.Shape(), *indices_tensor.DataType(), true);
355359
*out = ptr.release();
356360
return nullptr;
357361
#else

onnxruntime/core/framework/tensor_type_and_shape.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,19 @@ struct OrtTensorTypeAndShapeInfo {
3333
static std::unique_ptr<OrtTensorTypeAndShapeInfo> GetTensorShapeAndTypeHelper(
3434
ONNXTensorElementDataType type,
3535
onnxruntime::TensorShape shape,
36-
const std::vector<std::string>* dim_params);
36+
const std::vector<std::string>* dim_params,
37+
bool has_shape);
3738

3839
static std::unique_ptr<OrtTensorTypeAndShapeInfo> GetTensorShapeAndType(
3940
onnxruntime::TensorShape shape,
40-
const onnxruntime::DataTypeImpl& tensor_data_type);
41+
const onnxruntime::DataTypeImpl& tensor_data_type,
42+
bool has_shape);
4143

4244
static std::unique_ptr<OrtTensorTypeAndShapeInfo> GetTensorShapeAndType(
4345
onnxruntime::TensorShape shape,
4446
const std::vector<std::string>* dim_params,
45-
const ONNX_NAMESPACE::TypeProto&);
47+
const ONNX_NAMESPACE::TypeProto&,
48+
bool has_shape);
4649

4750
// We provide Clone() here to satisfy the existing coding pattern
4851
// as we need copies made on the heap even though we achieve that

onnxruntime/core/session/custom_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ struct OrtShapeInferContext {
8080
auto tensor_shape = ::onnxruntime::utils::GetTensorShapeFromTensorShapeProto(shape_proto);
8181
auto symbolic_dims = GetSymbolicDims(shape_proto);
8282
input_type_shapes_.emplace_back(
83-
OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims).release());
83+
OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims, type_proto.has_shape()).release());
8484
}
8585
}
8686

0 commit comments

Comments
 (0)