1919from __future__ import print_function
2020
2121# import g3
22+ import numpy as np
2223import tensorflow as tf
2324
2425from tensorflow .python .keras import backend as K
@@ -61,12 +62,16 @@ def on_epoch_end(self, batch, logs=None):
6162 # At the end of every epoch, remask the weights. This ensures that when
6263 # the model is saved after completion, the weights represent mask*weights.
6364 layers = self .model .layers
65+ weight_mask_ops = []
66+
6467 for layer in layers :
6568 if isinstance (layer , pruning_wrapper .PruneLowMagnitude ):
6669 if tf .executing_eagerly ():
6770 layer .pruning_obj .weight_mask_op ()
6871 else :
69- K .get_session ().run (layer .pruning_obj .weight_mask_op ())
72+ weight_mask_ops .append (layer .pruning_obj .weight_mask_op ())
73+
74+ K .batch_get_value (weight_mask_ops )
7075
7176
7277class PruningSummaries (callbacks .TensorBoard ):
@@ -83,15 +88,28 @@ def on_epoch_end(self, batch, logs=None):
8388 super (PruningSummaries , self ).on_epoch_end (batch , logs )
8489
8590 pruning_logs = {}
91+ params = []
8692 layers = self .model .layers
8793 for layer in layers :
8894 if isinstance (layer , pruning_wrapper .PruneLowMagnitude ):
8995 for _ , mask , threshold in layer .pruning_vars :
90- pruning_logs .update ({
91- mask .name + '/sparsity' :
92- K .get_value (1.0 - math_ops .reduce_mean (mask ))
93- })
94- pruning_logs .update (
95- {threshold .name + '/threshold' : K .get_value (threshold )})
96- self ._log_metrics (pruning_logs , '' ,
97- K .get_value (self .model .optimizer .iterations ))
96+ params .append (mask )
97+ params .append (threshold )
98+ params .append (self .model .optimizer .iterations )
99+
100+ values = K .batch_get_value (params )
101+ iteration = values [- 1 ]
102+ del values [- 1 ]
103+ del params [- 1 ]
104+
105+ param_value_pairs = zip (params , values )
106+
107+ for mask , mask_value in param_value_pairs [::2 ]:
108+ pruning_logs .update ({
109+ mask .name + '/sparsity' : 1 - np .mean (mask_value )
110+ })
111+
112+ for threshold , threshold_value in param_value_pairs [1 ::2 ]:
113+ pruning_logs .update ({threshold .name + '/threshold' : threshold_value })
114+
115+ self ._log_metrics (pruning_logs , '' , iteration )
0 commit comments