 from __future__ import print_function
 # import g3
 import numpy as np
-from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
+
+import tensorflow.compat.v1 as tf
+# TODO(tf-mot): when migrating to 2.0, K.get_session() no longer exists.
+K = tf.keras.backend
+dtypes = tf.dtypes
+test = tf.test
+
 from tensorflow.python.framework import test_util as tf_test_util
-from tensorflow.python.keras import backend as K
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import partitioned_variables
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_impl
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
@@ -66,7 +62,7 @@ def testUpdateSingleMask(self):
     mask_before_pruning = K.get_value(mask)
     self.assertAllEqual(np.count_nonzero(mask_before_pruning), 100)

-    if context.executing_eagerly():
+    if tf.executing_eagerly():
       p.conditional_mask_update()
     else:
       K.get_session().run(p.conditional_mask_update())
@@ -121,7 +117,7 @@ def testBlockMaskingAvg(self):
   def testBlockMaskingMax(self):
     block_size = (2, 2)
     block_pooling_type = "MAX"
-    weight = constant_op.constant([[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2],
+    weight = tf.constant([[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2],
                                     [0.3, 0.0, 0.4, 0.0], [0.0, -0.3, 0.0,
                                                            -0.4]])
     expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
@@ -133,7 +129,7 @@ def testBlockMaskingWithHigherDimensionsRaisesError(self):
     block_size = (2, 2)
     block_pooling_type = "AVG"
     # Weights as in testBlockMasking, but with one extra dimension.
-    weight = constant_op.constant([[[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2],
+    weight = tf.constant([[[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2],
                                      [0.3, 0.3, 0.4, 0.4], [0.3, 0.3, 0.4,
                                                             0.4]]])
     expected_mask = [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
@@ -149,9 +145,9 @@ def testConditionalMaskUpdate(self):
     threshold = K.zeros([])

     def linear_sparsity(step):
-      sparsity_val = ops.convert_to_tensor(
+      sparsity_val = tf.convert_to_tensor(
           [0.0, 0.1, 0.1, 0.3, 0.3, 0.5, 0.5, 0.5, 0.5, 0.5])
-      return ops.convert_to_tensor(True), sparsity_val[step]
+      return tf.convert_to_tensor(True), sparsity_val[step]

     # Set up pruning
     p = pruning_impl.Pruning(
@@ -163,14 +159,14 @@ def linear_sparsity(step):

     non_zero_count = []
     for _ in range(10):
-      if context.executing_eagerly():
+      if tf.executing_eagerly():
         p.conditional_mask_update()
         p.weight_mask_op()
-        state_ops.assign_add(self.global_step, 1)
+        tf.assign_add(self.global_step, 1)
       else:
         K.get_session().run(p.conditional_mask_update())
         K.get_session().run(p.weight_mask_op())
-        K.get_session().run(state_ops.assign_add(self.global_step, 1))
+        K.get_session().run(tf.assign_add(self.global_step, 1))

       non_zero_count.append(np.count_nonzero(K.get_value(weight)))

@@ -180,4 +176,5 @@ def linear_sparsity(step):


 if __name__ == "__main__":
+  tf.disable_v2_behavior()
   test.main()
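For reference, a minimal, self-contained sketch (not part of this patch) of the eager-vs-graph branching pattern the migrated test relies on under tf.compat.v1. The helper run_update and the toy variable are hypothetical names used only for illustration:

import tensorflow.compat.v1 as tf

K = tf.keras.backend


def run_update(update_fn):
  # In eager mode the update executes immediately; in graph mode the returned
  # op has to be run inside the Keras session, mirroring the test bodies above.
  if tf.executing_eagerly():
    return update_fn()
  return K.get_session().run(update_fn())


if __name__ == "__main__":
  # As in the patch, disable v2 behavior so the graph-mode branch is exercised
  # even under TF 2.x.
  tf.disable_v2_behavior()
  step = tf.Variable(0, name="step")
  K.get_session().run(tf.global_variables_initializer())
  run_update(lambda: tf.assign_add(step, 1))
  print(K.get_value(step))  # -> 1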