diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc index c31cc4f36f6..acfd5a2c518 100644 --- a/onnx/defs/nn/defs.cc +++ b/onnx/defs/nn/defs.cc @@ -1765,11 +1765,11 @@ ONNX_OPERATOR_SET_SCHEMA( TensorShapeProto outputs_shape; *outputs_shape.add_dim() = num_channels; // channel - propagateElemTypeFromInputToOutput(ctx, 0, 1); + propagateElemTypeFromInputToOutput(ctx, 3, 1); updateOutputShape(ctx, 1, outputs_shape); if (ctx.getNumOutputs() > 2) { - propagateElemTypeFromInputToOutput(ctx, 0, 2); + propagateElemTypeFromInputToOutput(ctx, 4, 2); updateOutputShape(ctx, 2, outputs_shape); } } diff --git a/onnx/test/shape_inference_test.py b/onnx/test/shape_inference_test.py index ec1202b4da8..6a5b01d7cd6 100644 --- a/onnx/test/shape_inference_test.py +++ b/onnx/test/shape_inference_test.py @@ -3485,6 +3485,21 @@ def test_batch_norm_train_dim_param(self): # type: () -> None make_tensor_value_info('output_var', TensorProto.FLOAT, ('C',)), # type: ignore ]) + def test_batch_norm_train_with_diff_type(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.FLOAT16, (3, 4, 5, 6, 7)), + ('scale', TensorProto.FLOAT16, (4,)), + ('b', TensorProto.FLOAT16, (4,)), + ('input_mean', TensorProto.FLOAT, (4,)), + ('input_var', TensorProto.FLOAT, (4,))], + [make_node('BatchNormalization', ['x', 'scale', 'b', 'input_mean', 'input_var'], + ['out', 'output_mean', 'output_var'], training_mode=1)], + []) + self._assert_inferred(graph, [make_tensor_value_info('out', TensorProto.FLOAT16, (3, 4, 5, 6, 7)), # type: ignore + make_tensor_value_info('output_mean', TensorProto.FLOAT, (4,)), # type: ignore + make_tensor_value_info('output_var', TensorProto.FLOAT, (4,)), # type: ignore + ]) + def test_batch_norm_test(self): # type: () -> None graph = self._make_graph( [('x', TensorProto.FLOAT, (3, 4, 5, 6, 7)),