Skip to content

Commit 73db296

Browse files
Xharktensorflower-gardener
authored andcommitted
Fix compatibility issues for the TF/Keras 2.13.
PiperOrigin-RevId: 535031817
1 parent af9d021 commit 73db296

File tree

3 files changed

+23
-7
lines changed

3 files changed

+23
-7
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import collections
1818
import inspect
1919

20-
from keras import backend
2120
import numpy as np
2221
import tensorflow as tf
2322

@@ -29,6 +28,12 @@
2928
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
3029
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms
3130

31+
try:
32+
from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top
33+
except ImportError:
34+
# Path as seen in pip packages as of TF/Keras 2.13.
35+
from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top
36+
3237
LayerNode = transforms.LayerNode
3338
LayerPattern = transforms.LayerPattern
3439

@@ -364,9 +369,9 @@ def pattern(self):
364369
return LayerPattern('SeparableConv1D')
365370

366371
def _get_name(self, prefix):
367-
# TODO(pulkitb): Move away from `backend.unique_object_name` since it isn't
372+
# TODO(pulkitb): Move away from `unique_object_name` since it isn't
368373
# exposed as externally usable.
369-
return backend.unique_object_name(prefix)
374+
return unique_object_name(prefix)
370375

371376
def replacement(self, match_layer):
372377
if _has_custom_quantize_config(match_layer):

tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import collections
1818
import inspect
1919

20-
from keras import backend
2120
import numpy as np
2221
import tensorflow as tf
2322

@@ -29,6 +28,13 @@
2928
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry
3029
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms
3130

31+
32+
try:
33+
from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top
34+
except ImportError:
35+
# Path as seen in pip packages as of TF/Keras 2.13.
36+
from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top
37+
3238
LayerNode = transforms.LayerNode
3339
LayerPattern = transforms.LayerPattern
3440

@@ -395,9 +401,9 @@ def pattern(self):
395401
return LayerPattern('SeparableConv1D')
396402

397403
def _get_name(self, prefix):
398-
# TODO(pulkitb): Move away from `backend.unique_object_name` since it isn't
404+
# TODO(pulkitb): Move away from `unique_object_name` since it isn't
399405
# exposed as externally usable.
400-
return backend.unique_object_name(prefix)
406+
return unique_object_name(prefix)
401407

402408
def replacement(self, match_layer):
403409
if _has_custom_quantize_config(match_layer):

tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@
1414
# ==============================================================================
1515
"""Registry responsible for built-in keras classes."""
1616

17-
from keras.engine import base_layer
1817
import tensorflow as tf
1918

2019
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
2120

21+
try:
22+
from keras.engine import base_layer # pylint: disable=g-import-not-at-top
23+
except ImportError:
24+
# Path as seen in pip packages as of TF/Keras 2.13.
25+
from keras.src.engine import base_layer # pylint: disable=g-import-not-at-top
26+
2227
# TODO(b/139939526): move to public API.
2328

2429
layers = tf.keras.layers

0 commit comments

Comments
 (0)