Commit eb73a85

Add SavedModel export to Resnet (tensorflow#3759)
* Adding export_dir and model saving for Resnet
* Moving to utils for tests
* Adding batch_size
* Adding multi-gpu export warning
* Responding to CR
* Py3 compliance
1 parent 1bfe1df commit eb73a85

File tree: 8 files changed (+188, −15 lines)


official/mnist/mnist.py

Lines changed: 3 additions & 8 deletions
@@ -246,14 +246,9 @@ class MNISTArgParser(argparse.ArgumentParser):
   def __init__(self):
     super(MNISTArgParser, self).__init__(parents=[
         parsers.BaseParser(),
-        parsers.ImageModelParser()])
-
-    self.add_argument(
-        '--export_dir',
-        type=str,
-        help='[default: %(default)s] If set, a SavedModel serialization of the '
-             'model will be exported to this directory at the end of training. '
-             'See the README for more details and relevant links.')
+        parsers.ImageModelParser(),
+        parsers.ExportParser(),
+    ])
 
     self.set_defaults(
         data_dir='/tmp/mnist_data',
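A quick, hedged sanity check (not part of the commit; the path below is illustrative only) that the refactor is behavior-preserving: the --export_dir flag is still present, now inherited from the shared parsers.ExportParser.

# Hedged sketch: the flag previously defined inline is now supplied by
# parsers.ExportParser via the `parents` list above.
from official.mnist import mnist

parser = mnist.MNISTArgParser()
flags = parser.parse_args(['--export_dir', '/tmp/mnist_saved_model'])
print(flags.export_dir)  # /tmp/mnist_saved_model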

official/resnet/cifar10_main.py

Lines changed: 4 additions & 1 deletion
@@ -228,7 +228,10 @@ def main(argv):
   flags = parser.parse_args(args=argv[1:])
 
   input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn
-  resnet_run_loop.resnet_main(flags, cifar10_model_fn, input_function)
+
+  resnet_run_loop.resnet_main(
+      flags, cifar10_model_fn, input_function,
+      shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
 
 
 if __name__ == '__main__':
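For context, a hedged sketch (not part of the commit) of where the new shape argument ends up: CIFAR-10 images are 32x32x3, which is what _HEIGHT, _WIDTH and _NUM_CHANNELS hold in cifar10_main.py, so together with the batch size flag it determines the serving placeholder built by official/utils/export (shown further below). The batch size of 128 here is only an example.

# Hedged illustration; the shape and batch_size values are assumptions made
# for the sake of the example.
import tensorflow as tf
from official.utils.export import export

receiver_fn = export.build_tensor_serving_input_receiver_fn(
    shape=[32, 32, 3], batch_size=128)
with tf.Graph().as_default():
  receiver = receiver_fn()
  print(receiver.features.shape)  # (128, 32, 32, 3)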

official/resnet/imagenet_main.py

Lines changed: 4 additions & 1 deletion
@@ -305,7 +305,10 @@ def main(argv):
   flags = parser.parse_args(args=argv[1:])
 
   input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn
-  resnet_run_loop.resnet_main(flags, imagenet_model_fn, input_function)
+
+  resnet_run_loop.resnet_main(
+      flags, imagenet_model_fn, input_function,
+      shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
 
 
 if __name__ == '__main__':

official/resnet/resnet_run_loop.py

Lines changed: 42 additions & 5 deletions
@@ -30,6 +30,7 @@
 
 from official.resnet import resnet_model
 from official.utils.arg_parsers import parsers
+from official.utils.export import export
 from official.utils.logging import hooks_helper
 from official.utils.logging import logger
 
@@ -219,7 +220,13 @@ def resnet_model_fn(features, labels, mode, model_class,
   }
 
   if mode == tf.estimator.ModeKeys.PREDICT:
-    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
+    # Return the predictions and the specification for serving a SavedModel
+    return tf.estimator.EstimatorSpec(
+        mode=mode,
+        predictions=predictions,
+        export_outputs={
+            'predict': tf.estimator.export.PredictOutput(predictions)
+        })
 
   # Calculate loss, which includes softmax cross entropy and L2 regularization.
   cross_entropy = tf.losses.softmax_cross_entropy(
@@ -310,8 +317,20 @@ def validate_batch_size_for_multi_gpu(batch_size):
     raise ValueError(err)
 
 
-def resnet_main(flags, model_function, input_function):
-  """Shared main loop for ResNet Models."""
+def resnet_main(flags, model_function, input_function, shape=None):
+  """Shared main loop for ResNet Models.
+
+  Args:
+    flags: FLAGS object that contains the params for running. See
+      ResnetArgParser for created flags.
+    model_function: the function that instantiates the Model and builds the
+      ops for train/eval. This will be passed directly into the estimator.
+    input_function: the function that processes the dataset and returns a
+      dataset that the estimator can train on. This will be wrapped with
+      all the relevant flags for running and passed to estimator.
+    shape: list of ints representing the shape of the images used for training.
+      This is only used if flags.export_dir is passed.
+  """
 
   # Using the Winograd non-fused algorithms provides a small performance boost.
   os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
@@ -389,16 +408,34 @@ def input_fn_eval():
   if benchmark_logger:
     benchmark_logger.log_estimator_evaluation_result(eval_results)
 
+  if flags.export_dir is not None:
+    warn_on_multi_gpu_export(flags.multi_gpu)
+
+    # Exports a saved model for the given classifier.
+    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
+        shape, batch_size=flags.batch_size)
+    classifier.export_savedmodel(flags.export_dir, input_receiver_fn)
+
+
+def warn_on_multi_gpu_export(multi_gpu=False):
+  """For the time being, multi-GPU mode does not play nicely with exporting."""
+  if multi_gpu:
+    tf.logging.warning(
+        'You are exporting a SavedModel while in multi-GPU mode. Note that '
+        'the resulting SavedModel will require the same GPUs be available.'
+        'If you wish to serve the SavedModel from a different device, '
+        'try exporting the SavedModel with multi-GPU mode turned off.')
+
 
 class ResnetArgParser(argparse.ArgumentParser):
-  """Arguments for configuring and running a Resnet Model.
-  """
+  """Arguments for configuring and running a Resnet Model."""
 
   def __init__(self, resnet_size_choices=None):
     super(ResnetArgParser, self).__init__(parents=[
         parsers.BaseParser(),
         parsers.PerformanceParser(),
         parsers.ImageModelParser(),
+        parsers.ExportParser(),
         parsers.BenchmarkParser(),
     ])
 
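Once classifier.export_savedmodel has run, the SavedModel can be loaded back for inference. Below is a hedged sketch using the TF 1.x tensorflow.contrib.predictor API, which is not part of this commit: the export directory name is illustrative (export_savedmodel creates a timestamped subdirectory under --export_dir), the feed key 'input' is the TensorServingInputReceiver default rather than something this change sets explicitly, and the output keys come from the model_fn's predictions dict, so confirm the actual signature with `saved_model_cli show --dir <export_dir> --all` before relying on it.

# Hedged sketch: query the 'predict' signature defined in resnet_model_fn.
import numpy as np
from tensorflow.contrib import predictor  # TF 1.x only

predict_fn = predictor.from_saved_model(
    '/tmp/resnet_export/1523000000',  # hypothetical timestamped export dir
    signature_def_key='predict')

# The batch size and image shape must match the flags.batch_size and `shape`
# values baked into the serving placeholder at export time.
images = np.zeros([1, 224, 224, 3], dtype=np.float32)
outputs = predict_fn({'input': images})
print(sorted(outputs.keys()))  # prediction keys from the model_fn's dict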

official/utils/arg_parsers/parsers.py

Lines changed: 23 additions & 0 deletions
@@ -226,6 +226,29 @@ def __init__(self, add_help=False, data_format=True):
     )
 
 
+class ExportParser(argparse.ArgumentParser):
+  """Parsing options for exporting saved models or other graph defs.
+
+  This is a separate parser for now, but should be made part of BaseParser
+  once all models are brought up to speed.
+
+  Args:
+    add_help: Create the "--help" flag. False if class instance is a parent.
+    export_dir: Create a flag to specify where a SavedModel should be exported.
+  """
+
+  def __init__(self, add_help=False, export_dir=True):
+    super(ExportParser, self).__init__(add_help=add_help)
+    if export_dir:
+      self.add_argument(
+          "--export_dir", "-ed",
+          help="[default: %(default)s] If set, a SavedModel serialization of "
+               "the model will be exported to this directory at the end of "
+               "training. See the README for more details and relevant links.",
+          metavar="<ED>"
+      )
+
+
 class BenchmarkParser(argparse.ArgumentParser):
   """Default parser for benchmark logging.
 
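ExportParser is designed to be mixed into a model's parser through argparse's parents mechanism, exactly as the MNIST and ResNet parsers above now do. A minimal hedged sketch (the directory is illustrative only):

# Composing ExportParser via argparse `parents`.
import argparse
from official.utils.arg_parsers import parsers

parser = argparse.ArgumentParser(parents=[parsers.ExportParser()])

flags = parser.parse_args(['--export_dir', '/tmp/my_saved_model'])
print(flags.export_dir)  # /tmp/my_saved_model

flags = parser.parse_args([])
print(flags.export_dir)  # None, so resnet_main skips the export step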

official/utils/export/__init__.py

Whitespace-only changes.

official/utils/export/export.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Convenience functions for exporting models as SavedModels or other types."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def build_tensor_serving_input_receiver_fn(shape, dtype=tf.float32,
+                                           batch_size=1):
+  """Returns a input_receiver_fn that can be used during serving.
+
+  This expects examples to come through as float tensors, and simply
+  wraps them as TensorServingInputReceivers.
+
+  Arguably, this should live in tf.estimator.export. Testing here first.
+
+  Args:
+    shape: list representing target size of a single example.
+    dtype: the expected datatype for the input example
+    batch_size: number of input tensors that will be passed for prediction
+
+  Returns:
+    A function that itself returns a TensorServingInputReceiver.
+  """
+  def serving_input_receiver_fn():
+    # Prep a placeholder where the input example will be fed in
+    features = tf.placeholder(
+        dtype=dtype, shape=[batch_size] + shape, name='input_tensor')
+
+    return tf.estimator.export.TensorServingInputReceiver(
+        features=features, receiver_tensors=features)
+
+  return serving_input_receiver_fn
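A design note as a hedged aside (not part of the commit): the helper returns a TensorServingInputReceiver rather than the plain ServingInputReceiver because the plain receiver wraps a bare Tensor into a single-entry features dict, while the tensor variant hands the Tensor to the model_fn unchanged, which is what the ResNet and MNIST model_fns expect. A small TF 1.x sketch of the difference, with arbitrary shapes:

# Hedged illustration of ServingInputReceiver vs. TensorServingInputReceiver.
import tensorflow as tf

with tf.Graph().as_default():
  images = tf.placeholder(tf.float32, [1, 224, 224, 3], name='input_tensor')

  dict_receiver = tf.estimator.export.ServingInputReceiver(
      features=images, receiver_tensors=images)
  tensor_receiver = tf.estimator.export.TensorServingInputReceiver(
      features=images, receiver_tensors=images)

  print(type(dict_receiver.features))    # dict: the Tensor gets wrapped
  print(type(tensor_receiver.features))  # tf.Tensor: passed through as-is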

official/utils/export/export_test.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for exporting utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+
+from official.utils.export import export
+
+
+class ExportUtilsTest(tf.test.TestCase):
+  """Tests for the ExportUtils."""
+
+  def test_build_tensor_serving_input_receiver_fn(self):
+    receiver_fn = export.build_tensor_serving_input_receiver_fn(shape=[4, 5])
+    with tf.Graph().as_default():
+      receiver = receiver_fn()
+      self.assertIsInstance(
+          receiver, tf.estimator.export.TensorServingInputReceiver)
+
+      self.assertIsInstance(receiver.features, tf.Tensor)
+      self.assertEqual(receiver.features.shape, tf.TensorShape([1, 4, 5]))
+      self.assertEqual(receiver.features.dtype, tf.float32)
+      self.assertIsInstance(receiver.receiver_tensors, dict)
+      # Note that Python 3 can no longer index .values() directly; cast to list.
+      self.assertEqual(list(receiver.receiver_tensors.values())[0].shape,
+                       tf.TensorShape([1, 4, 5]))
+
+  def test_build_tensor_serving_input_receiver_fn_batch_dtype(self):
+    receiver_fn = export.build_tensor_serving_input_receiver_fn(
+        shape=[4, 5], dtype=tf.int8, batch_size=10)
+
+    with tf.Graph().as_default():
+      receiver = receiver_fn()
+      self.assertIsInstance(
+          receiver, tf.estimator.export.TensorServingInputReceiver)
+
+      self.assertIsInstance(receiver.features, tf.Tensor)
+      self.assertEqual(receiver.features.shape, tf.TensorShape([10, 4, 5]))
+      self.assertEqual(receiver.features.dtype, tf.int8)
+      self.assertIsInstance(receiver.receiver_tensors, dict)
+      # Note that Python 3 can no longer index .values() directly; cast to list.
+      self.assertEqual(list(receiver.receiver_tensors.values())[0].shape,
+                       tf.TensorShape([10, 4, 5]))
+
+
+if __name__ == "__main__":
+  tf.test.main()
