Skip to content

Commit

Permalink
Fix Expand shape inference: stop rank inference if the shape is symbo…
Browse files Browse the repository at this point in the history
…lic (onnx#4019)

* if expand cannt do rank inference; return early

Signed-off-by: Chun-Wei Chen <[email protected]>
  • Loading branch information
jcwchen authored Feb 16, 2022
1 parent c9d61b6 commit f6053e4
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 15 deletions.
3 changes: 2 additions & 1 deletion onnx/defs/math/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2063,11 +2063,12 @@ ONNX_OPERATOR_SET_SCHEMA(
for (int64_t i = 0; i < dim_value; ++i) {
second_shape.add_dim();
}
} else {
return;
}
bidirectionalBroadcastShapeInference(
input_shape, second_shape, *getOutputShape(ctx, 0));
}
return;
}));

static const char* Sinh_ver9_doc = R"DOC(
Expand Down
3 changes: 2 additions & 1 deletion onnx/defs/math/old.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1020,11 +1020,12 @@ ONNX_OPERATOR_SET_SCHEMA(
for (int64_t i = 0; i < dim_value; ++i) {
second_shape.add_dim();
}
} else {
return;
}
bidirectionalBroadcastShapeInference(
input_shape, second_shape, *getOutputShape(ctx, 0));
}
return;
}));

static const char* Sign_ver9_doc = R"DOC(
Expand Down
2 changes: 1 addition & 1 deletion onnx/defs/tensor/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ ONNX_OPERATOR_SET_SCHEMA(
for (size_t i = 0; i < numInputs; i++) {
const auto& shape = ctx.getInputType(i)->tensor_type().shape();
if (shape.dim_size() != rank) {
fail_shape_inference("All inputs to Concat must have same rank");
fail_shape_inference("All inputs to Concat must have same rank. Input ", i , " has rank ", shape.dim_size(), " != ", rank);
}
for (int j = 0; j < rank; j++) {
if (j == axis) {
Expand Down
2 changes: 1 addition & 1 deletion onnx/defs/tensor/old.cc
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ ONNX_OPERATOR_SET_SCHEMA(
for (size_t i = 0; i < numInputs; i++) {
const auto& shape = ctx.getInputType(i)->tensor_type().shape();
if (shape.dim_size() != rank) {
fail_shape_inference("All inputs to Concat must have same rank");
fail_shape_inference("All inputs to Concat must have same rank. Input ", i , " has rank ", shape.dim_size(), " != ", rank);
}
for (int j = 0; j < rank; j++) {
if (j == axis) {
Expand Down
36 changes: 25 additions & 11 deletions onnx/test/shape_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,16 @@ def _compare_value_infos(self, vi_type: TypeProto, inferred_vi_type: TypeProto)
assert vi_type.tensor_type.HasField('elem_type')
assert inferred_vi_type.tensor_type.HasField('elem_type')
assert vi_type.tensor_type.elem_type == inferred_vi_type.tensor_type.elem_type
for dim_i in range(len(vi_type.tensor_type.shape.dim)):
dim = vi_type.tensor_type.shape.dim[dim_i]
inferred_dim = inferred_vi_type.tensor_type.shape.dim[dim_i]
# if it is a symbolic shape, make sure the inferred symbol has generated (dim_param)
if dim.dim_param:
assert dim.dim_param == inferred_dim.dim_param, '\n%s\n%s\n' % (vi_type, inferred_vi_type)
else:
assert dim.dim_value == inferred_dim.dim_value, '\n%s\n%s\n' % (vi_type, inferred_vi_type)
assert vi_type.tensor_type.HasField('shape') == inferred_vi_type.tensor_type.HasField('shape')
if vi_type.tensor_type.HasField('shape'):
for dim_i in range(len(vi_type.tensor_type.shape.dim)):
dim = vi_type.tensor_type.shape.dim[dim_i]
inferred_dim = inferred_vi_type.tensor_type.shape.dim[dim_i]
# if it is a symbolic shape, make sure the inferred symbol has generated (dim_param)
if dim.dim_param:
assert dim.dim_param == inferred_dim.dim_param, '\n%s\n%s\n' % (vi_type, inferred_vi_type)
else:
assert dim.dim_value == inferred_dim.dim_value, '\n%s\n%s\n' % (vi_type, inferred_vi_type)
elif vi_type.HasField('sequence_type'):
assert inferred_vi_type.HasField('sequence_type')
vi = vi_type.sequence_type.elem_type
Expand Down Expand Up @@ -383,6 +385,18 @@ def test_expand_dynamic_shape(self) -> None:
graph,
[make_tensor_value_info('y', TensorProto.INT32, (None, 2, None))])

def test_expand_symbolic_shape(self) -> None:
graph = self._make_graph(
[('x', TensorProto.INT32, (1, 2, None)),
('shape', TensorProto.INT64, ('unk__0',))],
[make_node("Expand", ['x', 'shape'], ['y'])],
[],
initializer=[])
# if giving a symbolic shape, Expand should not infer any shape or rank inference
self._assert_inferred(
graph,
[make_tensor_value_info('y', TensorProto.INT32, None)])

def test_resize_size(self) -> None:
graph = self._make_graph(
[('x', TensorProto.INT32, (2, 4, 3, 5)),
Expand Down Expand Up @@ -2111,7 +2125,7 @@ def test_if_with_different_optional_shapes_in_then_else_branches(self) -> None:
[]
)

output_tensor_proto = helper.make_tensor_type_proto(elem_type=TensorProto.FLOAT, shape=None)
output_tensor_proto = helper.make_tensor_type_proto(elem_type=TensorProto.FLOAT, shape=(None, ))
output_optional_type_proto = helper.make_optional_type_proto(output_tensor_proto)
output_optional_vi = helper.make_value_info('if_output', output_optional_type_proto)
self._assert_inferred(graph, [output_optional_vi]) # type: ignore
Expand Down Expand Up @@ -3862,7 +3876,7 @@ def test_optional_tensor_has_element(self) -> None:
make_node('OptionalHasElement', ['sequence'], ['output'])],
[])
self._assert_inferred(graph, [optional_val_info,
make_tensor_value_info('output', TensorProto.BOOL, None)]) # type: ignore
make_tensor_value_info('output', TensorProto.BOOL, ())]) # type: ignore

def test_optional_sequence_has_element(self) -> None:
tensor_type_proto = helper.make_tensor_type_proto(elem_type=TensorProto.FLOAT, shape=[0, 3, 4])
Expand All @@ -3881,7 +3895,7 @@ def test_optional_sequence_has_element(self) -> None:
make_node('OptionalHasElement', ['optional'], ['output'])],
[])
self._assert_inferred(graph, [sequence_val_info, optional_val_info,
make_tensor_value_info('output', TensorProto.BOOL, None)]) # type: ignore
make_tensor_value_info('output', TensorProto.BOOL, ())]) # type: ignore

def test_optional_tensor_get_element(self) -> None:
tensor_type_proto = helper.make_tensor_type_proto(elem_type=TensorProto.DOUBLE, shape=[2, 1, 4])
Expand Down

0 comments on commit f6053e4

Please sign in to comment.