Skip to content

Commit c504ca6

Browse files
nkovela1tensorflower-gardener
authored andcommitted
Creates non-breaking changes where necessary in preparation for switching all of Keras to new serialization format.
PiperOrigin-RevId: 507864605
1 parent 8c33592 commit c504ca6

25 files changed

+311
-85
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,11 @@ def testStripClusteringSequentialModel(self):
667667
stripped_model = cluster.strip_clustering(clustered_model)
668668

669669
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
670-
self.assertEqual(model.get_config(), stripped_model.get_config())
670+
model_config = model.get_config()
671+
for layer in model_config['layers']:
672+
# New serialization format includes `build_config` in wrapper
673+
layer.pop('build_config', None)
674+
self.assertEqual(model_config, stripped_model.get_config())
671675

672676
def testClusterStrippingFunctionalModel(self):
673677
"""Verifies that stripping the clustering wrappers from a functional model produces the expected config."""

tensorflow_model_optimization/python/core/quantization/keras/BUILD

+9-1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ py_strict_test(
7272
visibility = ["//visibility:public"],
7373
deps = [
7474
":quantizers",
75+
":utils",
7576
# absl/testing:parameterized dep1,
7677
# numpy dep1,
7778
# tensorflow dep1,
@@ -87,9 +88,10 @@ py_strict_library(
8788
srcs_version = "PY3",
8889
visibility = ["//visibility:public"],
8990
deps = [
91+
":quantizers",
92+
":utils",
9093
# six dep1,
9194
# tensorflow dep1,
92-
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
9395
],
9496
)
9597

@@ -125,6 +127,7 @@ py_strict_library(
125127
srcs_version = "PY3",
126128
visibility = ["//visibility:public"],
127129
deps = [
130+
":utils",
128131
# tensorflow dep1,
129132
],
130133
)
@@ -152,6 +155,7 @@ py_strict_library(
152155
srcs_version = "PY3",
153156
visibility = ["//visibility:public"],
154157
deps = [
158+
":utils",
155159
# tensorflow dep1,
156160
"//tensorflow_model_optimization/python/core/keras:utils",
157161
],
@@ -167,6 +171,7 @@ py_strict_test(
167171
deps = [
168172
":quantize_aware_activation",
169173
":quantizers",
174+
":utils",
170175
# absl/testing:parameterized dep1,
171176
# numpy dep1,
172177
# tensorflow dep1,
@@ -182,6 +187,7 @@ py_strict_library(
182187
visibility = ["//visibility:public"],
183188
deps = [
184189
":quantizers",
190+
":utils",
185191
# tensorflow dep1,
186192
"//tensorflow_model_optimization/python/core/keras:utils",
187193
],
@@ -211,6 +217,7 @@ py_strict_library(
211217
visibility = ["//visibility:public"],
212218
deps = [
213219
":quantize_aware_activation",
220+
":utils",
214221
# tensorflow dep1,
215222
# python/util tensorflow dep2,
216223
"//tensorflow_model_optimization/python/core/keras:metrics",
@@ -249,6 +256,7 @@ py_strict_library(
249256
":quantize_layer",
250257
":quantize_wrapper",
251258
":quantizers",
259+
":utils",
252260
# tensorflow dep1,
253261
"//tensorflow_model_optimization/python/core/keras:metrics",
254262
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@
2525
import tensorflow as tf
2626

2727
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
28+
from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
2829
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
2930

3031
keras = tf.keras
3132
K = tf.keras.backend
3233
l = tf.keras.layers
3334

34-
deserialize_keras_object = tf.keras.utils.deserialize_keras_object
35-
serialize_keras_object = tf.keras.utils.serialize_keras_object
35+
deserialize_keras_object = quantize_utils.deserialize_keras_object
36+
serialize_keras_object = quantize_utils.serialize_keras_object
3637

3738

3839
class _TestHelper(object):

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

+41-12
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
2525
from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
2626
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
27+
from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
2728
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_configs
2829
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
2930
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms
@@ -67,13 +68,17 @@ def _get_params(conv_layer, bn_layer, relu_layer=None):
6768
list(conv_layer['config'].items()) + list(bn_layer['config'].items()))
6869

6970
if relu_layer is not None:
70-
params['post_activation'] = keras.layers.deserialize(relu_layer)
71+
params['post_activation'] = quantize_utils.deserialize_layer(
72+
relu_layer, use_legacy_format=True
73+
)
7174

7275
return params
7376

7477

7578
def _get_layer_node(fused_layer, weights):
76-
layer_config = keras.layers.serialize(fused_layer)
79+
layer_config = quantize_utils.serialize_layer(
80+
fused_layer, use_legacy_format=True
81+
)
7782
layer_config['name'] = layer_config['config']['name']
7883
# This config tracks which layers get quantized, and whether they have a
7984
# custom QuantizeConfig.
@@ -118,7 +123,10 @@ def _replace(self, bn_layer_node, conv_layer_node):
118123
return bn_layer_node
119124

120125
conv_layer_node.layer['config']['activation'] = (
121-
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
126+
quantize_utils.serialize_activation(
127+
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
128+
)
129+
)
122130
bn_layer_node.metadata['quantize_config'] = (
123131
default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())
124132

@@ -180,7 +188,10 @@ def _replace(self, relu_layer_node, bn_layer_node, conv_layer_node):
180188
return relu_layer_node
181189

182190
conv_layer_node.layer['config']['activation'] = (
183-
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
191+
quantize_utils.serialize_activation(
192+
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
193+
)
194+
)
184195
bn_layer_node.metadata['quantize_config'] = (
185196
default_8bit_quantize_configs.NoOpQuantizeConfig())
186197

@@ -261,7 +272,10 @@ def _replace(self, bn_layer_node, dense_layer_node):
261272
return bn_layer_node
262273

263274
dense_layer_node.layer['config']['activation'] = (
264-
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
275+
quantize_utils.serialize_activation(
276+
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
277+
)
278+
)
265279
bn_layer_node.metadata['quantize_config'] = (
266280
default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())
267281

@@ -297,7 +311,10 @@ def _replace(self, relu_layer_node, bn_layer_node, dense_layer_node):
297311
return relu_layer_node
298312

299313
dense_layer_node.layer['config']['activation'] = (
300-
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
314+
quantize_utils.serialize_activation(
315+
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
316+
)
317+
)
301318
bn_layer_node.metadata['quantize_config'] = (
302319
default_8bit_quantize_configs.NoOpQuantizeConfig())
303320

@@ -408,7 +425,9 @@ def replacement(self, match_layer):
408425
else:
409426
spatial_dim = 2
410427

411-
sepconv2d_layer_config = keras.layers.serialize(sepconv2d_layer)
428+
sepconv2d_layer_config = quantize_utils.serialize_layer(
429+
sepconv2d_layer, use_legacy_format=True
430+
)
412431
sepconv2d_layer_config['name'] = sepconv2d_layer.name
413432

414433
# Needed to ensure these new layers are considered for quantization.
@@ -420,15 +439,19 @@ def replacement(self, match_layer):
420439
expand_layer = tf.keras.layers.Lambda(
421440
lambda x: tf.expand_dims(x, spatial_dim),
422441
name=self._get_name('sepconv1d_expand'))
423-
expand_layer_config = keras.layers.serialize(expand_layer)
442+
expand_layer_config = quantize_utils.serialize_layer(
443+
expand_layer, use_legacy_format=True
444+
)
424445
expand_layer_config['name'] = expand_layer.name
425446
expand_layer_metadata = {
426447
'quantize_config': default_8bit_quantize_configs.NoOpQuantizeConfig()}
427448

428449
squeeze_layer = tf.keras.layers.Lambda(
429450
lambda x: tf.squeeze(x, [spatial_dim]),
430451
name=self._get_name('sepconv1d_squeeze'))
431-
squeeze_layer_config = keras.layers.serialize(squeeze_layer)
452+
squeeze_layer_config = quantize_utils.serialize_layer(
453+
squeeze_layer, use_legacy_format=True
454+
)
432455
squeeze_layer_config['name'] = squeeze_layer.name
433456
squeeze_layer_metadata = {
434457
'quantize_config': default_8bit_quantize_configs.NoOpQuantizeConfig()}
@@ -493,7 +516,9 @@ def replacement(self, match_layer):
493516
)
494517
dconv_weights = collections.OrderedDict()
495518
dconv_weights['depthwise_kernel:0'] = sepconv_weights[0]
496-
dconv_layer_config = keras.layers.serialize(dconv_layer)
519+
dconv_layer_config = quantize_utils.serialize_layer(
520+
dconv_layer, use_legacy_format=True
521+
)
497522
dconv_layer_config['name'] = dconv_layer.name
498523
# Needed to ensure these new layers are considered for quantization.
499524
dconv_metadata = {'quantize_config': None}
@@ -521,7 +546,9 @@ def replacement(self, match_layer):
521546
conv_weights['kernel:0'] = sepconv_weights[1]
522547
if sepconv_layer['config']['use_bias']:
523548
conv_weights['bias:0'] = sepconv_weights[2]
524-
conv_layer_config = keras.layers.serialize(conv_layer)
549+
conv_layer_config = quantize_utils.serialize_layer(
550+
conv_layer, use_legacy_format=True
551+
)
525552
conv_layer_config['name'] = conv_layer.name
526553
# Needed to ensure these new layers are considered for quantization.
527554
conv_metadata = {'quantize_config': None}
@@ -588,7 +615,9 @@ def replacement(self, match_layer):
588615
quant_layer = quantize_layer.QuantizeLayer(
589616
quantizers.AllValuesQuantizer(
590617
num_bits=8, per_axis=False, symmetric=False, narrow_range=False))
591-
layer_config = keras.layers.serialize(quant_layer)
618+
layer_config = quantize_utils.serialize_layer(
619+
quant_layer, use_legacy_format=True
620+
)
592621
layer_config['name'] = quant_layer.name
593622

594623
quant_layer_node = LayerNode(

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -707,4 +707,6 @@ def testConcatConcatTransformDisablesOutput(self):
707707

708708

709709
if __name__ == '__main__':
710+
if hasattr(tf.keras.__internal__, 'enable_unsafe_deserialization'):
711+
tf.keras.__internal__.enable_unsafe_deserialization()
710712
tf.test.main()

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,6 @@ def testModelEndToEnd(self, model_fn):
200200

201201

202202
if __name__ == '__main__':
203+
if hasattr(tf.keras.__internal__, 'enable_unsafe_deserialization'):
204+
tf.keras.__internal__.enable_unsafe_deserialization()
203205
tf.test.main()

tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@
2525
import tensorflow as tf
2626

2727
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
28+
from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
2829
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry as n_bit_registry
2930

3031
keras = tf.keras
3132
K = tf.keras.backend
3233
l = tf.keras.layers
3334

34-
deserialize_keras_object = tf.keras.utils.deserialize_keras_object
35-
serialize_keras_object = tf.keras.utils.serialize_keras_object
35+
deserialize_keras_object = quantize_utils.deserialize_keras_object
36+
serialize_keras_object = quantize_utils.serialize_keras_object
3637

3738

3839
class _TestHelper(object):

0 commit comments

Comments
 (0)