Commit d2b0d9b

Freeze layers for transfer learning.

Author: Daniel
1 parent 555a265

3 files changed, +45 -2 lines changed

training/deepspeech_training/train.py (+28 -2)
@@ -321,8 +321,35 @@ def get_tower_results(iterator, optimizer, dropout_rates):
                     # Retain tower's avg losses
                     tower_avg_losses.append(avg_loss)
 
+                    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+
+                    # Filter out layers if we want to freeze some
+                    if FLAGS.freeze_source_layers > 0:
+                        filter_vars = []
+                        if FLAGS.freeze_source_layers <= 5:
+                            filter_vars.append("layer_1")
+                        if FLAGS.freeze_source_layers <= 4:
+                            filter_vars.append("layer_2")
+                        if FLAGS.freeze_source_layers <= 3:
+                            filter_vars.append("layer_3")
+                        if FLAGS.freeze_source_layers <= 2:
+                            filter_vars.append("lstm")
+                        if FLAGS.freeze_source_layers <= 1:
+                            filter_vars.append("layer_5")
+
+                        new_train_vars = list(train_vars)
+                        for fv in filter_vars:
+                            for tv in train_vars:
+                                if fv in tv.name:
+                                    new_train_vars.remove(tv)
+                        train_vars = new_train_vars
+                        msg = "Tower {} - Training only variables: {}"
+                        print(msg.format(i, [v.name for v in train_vars]))
+                    else:
+                        print("Tower {} - Training all layers".format(i))
+
                     # Compute gradients for model parameters using tower's mini-batch
-                    gradients = optimizer.compute_gradients(avg_loss)
+                    gradients = optimizer.compute_gradients(avg_loss, var_list=train_vars)
 
                     # Retain tower's gradients
                     tower_gradients.append(gradients)
@@ -667,7 +694,6 @@ def __call__(self, progress, data, **kwargs):
 
             print('-' * 80)
 
-
     except KeyboardInterrupt:
         pass
     log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
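
The layer filtering added above is self-contained enough to trace by hand. Below is a minimal standalone sketch of the same selection logic, using a namedtuple stand-in instead of real TensorFlow variables; the helper name select_trainable and the example variable names are illustrative, not part of the commit.

from collections import namedtuple

# Stand-in for a TF variable: only the .name attribute matters for the filtering (illustrative only).
Var = namedtuple("Var", ["name"])

def select_trainable(train_vars, freeze_source_layers):
    # Mirrors the patch: collect name prefixes of the layers to freeze,
    # then keep only the variables whose names do not match any of them.
    if freeze_source_layers <= 0:
        return list(train_vars)  # nothing frozen; train all layers
    filter_vars = []
    if freeze_source_layers <= 5:
        filter_vars.append("layer_1")
    if freeze_source_layers <= 4:
        filter_vars.append("layer_2")
    if freeze_source_layers <= 3:
        filter_vars.append("layer_3")
    if freeze_source_layers <= 2:
        filter_vars.append("lstm")
    if freeze_source_layers <= 1:
        filter_vars.append("layer_5")
    return [v for v in train_vars if not any(fv in v.name for fv in filter_vars)]

all_vars = [Var(n) for n in ("layer_1/weights", "layer_2/weights", "layer_3/weights",
                             "lstm/kernel", "layer_5/weights", "layer_6/weights")]

print([v.name for v in select_trainable(all_vars, 1)])  # ['layer_6/weights']
print([v.name for v in select_trainable(all_vars, 2)])  # ['layer_5/weights', 'layer_6/weights']

With freeze_source_layers=1, only the output layer remains in the list, matching the drop_source_layers=1 convention of dropping just the output; the filtered list is what optimizer.compute_gradients receives via var_list.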

training/deepspeech_training/util/checkpoints.py (+15)
@@ -46,6 +46,21 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
                   'tensors. Missing variables: {}'.format(missing_var_names))
         sys.exit(1)
 
+    if FLAGS.load_frozen_graph:
+        # After training with "freeze_source_layers", the Adam tensors for the frozen layers
+        # no longer exist in the checkpoint because they were never used.
+        # Therefore we have to initialize them again to continue training on such checkpoints.
+        for v in load_vars:
+            if v.op.name not in vars_in_ckpt:
+                if 'Adam' in v.name:
+                    init_vars.add(v)
+                else:
+                    msg = "Tried to load a frozen checkpoint but there was a missing " \
+                          "variable other than the Adam tensors: {}"
+                    log_error(msg.format(v))
+                    sys.exit(1)
+        load_vars -= init_vars
+
     if allow_drop_layers and FLAGS.drop_source_layers > 0:
         # This transfer learning approach requires supplying
         # the layers which we exclude from the source model.
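
The checkpoint handling above partitions missing variables into Adam slots to re-initialize and everything else, which is fatal. Here is a pure-Python sketch of that partitioning, with plain name strings standing in for load_vars, init_vars and the checkpoint contents; partition_missing_vars is an illustrative name, not a function in the patch.

def partition_missing_vars(graph_var_names, ckpt_var_names):
    # Variables present in the graph but absent from the checkpoint are acceptable
    # only if they are Adam slot tensors (never created while the layer was frozen);
    # anything else missing indicates an incompatible checkpoint.
    load_vars = set(graph_var_names)
    init_vars = set()
    for name in graph_var_names:
        if name not in ckpt_var_names:
            if "Adam" in name:
                init_vars.add(name)
            else:
                raise ValueError("Missing variable other than the Adam tensors: {}".format(name))
    return load_vars - init_vars, init_vars

graph_vars = {"layer_1/weights", "layer_1/weights/Adam", "layer_1/weights/Adam_1",
              "layer_6/weights", "layer_6/weights/Adam", "layer_6/weights/Adam_1"}
# Checkpoint written after training with layer_1 frozen: its Adam slots were never created.
ckpt_vars = {"layer_1/weights",
             "layer_6/weights", "layer_6/weights/Adam", "layer_6/weights/Adam_1"}

to_load, to_init = partition_missing_vars(graph_vars, ckpt_vars)
print(sorted(to_init))  # ['layer_1/weights/Adam', 'layer_1/weights/Adam_1']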

training/deepspeech_training/util/flags.py (+2)

@@ -93,6 +93,8 @@ def create_flags():
     # Transfer Learning
 
     f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)')
+    f.DEFINE_integer('freeze_source_layers', 0, 'use same value as above to freeze the other layers')
+    f.DEFINE_boolean('load_frozen_graph', False, 'Needed to load a graph checkpoint which was trained with "freeze_source_layers" flag. Allows initialization of missing optimization tensors.')
 
     # Exporting
 
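
A short note on why load_frozen_graph is needed at all: when freeze_source_layers keeps a layer out of var_list, Adam never creates slot variables for it, so a checkpoint written from that run is missing those tensors and a later unfrozen run cannot restore them. The sketch below reproduces that effect with the TF1-style API DeepSpeech builds on, accessed here through tensorflow.compat.v1; the variable names are made up for illustration.

import tensorflow.compat.v1 as tf  # assumption: TF1-style graph API, as in DeepSpeech
tf.disable_eager_execution()

with tf.Graph().as_default():
    frozen = tf.get_variable("layer_1/weights", shape=[4], initializer=tf.zeros_initializer())
    trained = tf.get_variable("layer_6/weights", shape=[4], initializer=tf.zeros_initializer())
    loss = tf.reduce_sum(frozen + trained)

    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    # As in the patch, only the non-frozen variable is passed via var_list.
    gradients = optimizer.compute_gradients(loss, var_list=[trained])
    optimizer.apply_gradients(gradients)

    # Adam slot variables exist only for layer_6; a checkpoint saved from this graph would
    # contain no layer_1/.../Adam tensors, which is exactly what load_frozen_graph tolerates.
    print([v.op.name for v in tf.global_variables() if "Adam" in v.op.name])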