Skip to content

Commit 1af566d

Browse files
dansuh17tensorflower-gardener
authored andcommitted
Create a local copy of keras.utils.generic_utils.to_snake_case at pruning_wrapper.py.
`to_snake_case` is private to the `keras.utils` module and is not exported externally. Since this is a short utility function, we can simply have a copy of this function locally. PiperOrigin-RevId: 524779757
1 parent cbcc4e0 commit 1af566d

File tree

3 files changed

+32
-5
lines changed

3 files changed

+32
-5
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ py_strict_library(
105105
":pruning_impl",
106106
":pruning_schedule",
107107
":pruning_utils",
108-
# keras/utils:generic_utils dep1,
109108
# numpy dep1,
110109
# tensorflow:tensorflow_no_contrib dep1,
111110
"//tensorflow_model_optimization/python/core/keras:compat",

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
from __future__ import print_function
2121

2222
import inspect
23+
import re
2324

2425
# import g3
2526

26-
from keras.utils import generic_utils
2727
import numpy as np
2828
import tensorflow as tf
2929

@@ -42,6 +42,26 @@
4242
Wrapper = keras.layers.Wrapper
4343

4444

45+
def _to_snake_case(name: str) -> str:
46+
"""Converts `name` to snake case.
47+
48+
Example: "TensorFlow" -> "tensor_flow"
49+
50+
Args:
51+
name: The name of some python class.
52+
53+
Returns:
54+
`name` converted to snake case.
55+
"""
56+
intermediate = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
57+
insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
58+
# If the class is private the name starts with "_" which is not secure
59+
# for creating scopes. We prefix the name with "private" in this case.
60+
if insecure[0] != '_':
61+
return insecure
62+
return 'private' + insecure
63+
64+
4565
class PruneLowMagnitude(Wrapper):
4666
"""This wrapper augments a keras layer so the weight tensor may be pruned.
4767
@@ -154,8 +174,9 @@ def __init__(self,
154174
# TODO(pulkitb): This should be pushed up to the wrappers.py
155175
# Name the layer using the wrapper and underlying layer name.
156176
# Prune(Dense) becomes prune_dense_1
157-
kwargs.update({'name': '{}_{}'.format(
158-
generic_utils.to_snake_case(self.__class__.__name__), layer.name)})
177+
kwargs.update(
178+
{'name': f'{_to_snake_case(self.__class__.__name__)}_{layer.name}'}
179+
)
159180

160181
if isinstance(layer, prunable_layer.PrunableLayer) or hasattr(
161182
layer, 'get_prunable_weights'):

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,14 @@ def testCustomLayerPrunable(self):
144144
layer = CustomLayerPrunable(input_dim=16, output_dim=32)
145145
inputs = keras.layers.Input(shape=(16))
146146
_ = layer(inputs)
147-
pruning_wrapper.PruneLowMagnitude(layer, block_pooling_type='MAX')
147+
pruned_layer = pruning_wrapper.PruneLowMagnitude(
148+
layer, block_pooling_type='MAX'
149+
)
150+
# The name is the layer's name prefixed by the snake_case version of the
151+
# `PruneLowMagnitude` class's name.
152+
self.assertEqual(
153+
pruned_layer.name, 'prune_low_magnitude_custom_layer_prunable'
154+
)
148155

149156
def testCollectPrunableLayers(self):
150157
lstm_layer = keras.layers.RNN(

0 commit comments

Comments
 (0)