
Commit a0d5597

Author: Daniel
Refactor freezing.
Parent: fadaf2a

File tree (3 files changed: +45 −53 lines)

  training/deepspeech_training/train.py
  training/deepspeech_training/util/checkpoints.py
  training/deepspeech_training/util/flags.py

training/deepspeech_training/train.py  (+2 −13)

@@ -29,7 +29,7 @@
 from .evaluate import evaluate
 from six.moves import zip, range
 from .util.config import Config, initialize_globals
-from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
+from .util.checkpoints import drop_freeze_number_to_layers, load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
 from .util.evaluate_tools import save_samples_json
 from .util.feeding import create_dataset, audio_to_features, audiofile_to_features
 from .util.flags import create_flags, FLAGS
@@ -326,18 +326,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):

     # Filter out layers if we want to freeze some
     if FLAGS.freeze_source_layers > 0:
-        filter_vars = []
-        if FLAGS.freeze_source_layers <= 5:
-            filter_vars.append("layer_1")
-        if FLAGS.freeze_source_layers <= 4:
-            filter_vars.append("layer_2")
-        if FLAGS.freeze_source_layers <= 3:
-            filter_vars.append("layer_3")
-        if FLAGS.freeze_source_layers <= 2:
-            filter_vars.append("lstm")
-        if FLAGS.freeze_source_layers <= 1:
-            filter_vars.append("layer_5")
-
+        filter_vars = drop_freeze_number_to_layers(FLAGS.freeze_source_layers, "freeze")
         new_train_vars = list(train_vars)
         for fv in filter_vars:
             for tv in train_vars:
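
The hunk ends at the inner loop header, so the filtering body itself is not shown here. A rough, hypothetical sketch of the idea (the variable names below are illustrative stand-ins, and the substring match mirrors how layer keys are compared against variable names elsewhere in this diff):

    # Hypothetical sketch: drop every variable whose name contains one of the
    # frozen layer keys, so only the remaining layers stay trainable.
    filter_vars = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5"]   # freeze_source_layers == 1
    train_vars = ["layer_1/weights", "layer_3/bias", "lstm/kernel", "layer_5/bias", "layer_6/weights"]

    new_train_vars = list(train_vars)
    for fv in filter_vars:
        for tv in train_vars:
            if fv in tv:                  # substring match on the variable name
                new_train_vars.remove(tv)

    print(new_train_vars)                 # ['layer_6/weights'] - only the output layer is trained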

training/deepspeech_training/util/checkpoints.py  (+41 −37)

@@ -1,9 +1,9 @@
 import sys
-import tensorflow as tf
+
 import tensorflow.compat.v1 as tfv1

 from .flags import FLAGS
-from .logging import log_info, log_error, log_warn
+from .logging import log_error, log_info, log_warn


 def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
@@ -19,47 +19,33 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
     # compatibility with older checkpoints.
     lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
     if lr_var and ('learning_rate' not in vars_in_ckpt or
-                   (FLAGS.force_initialize_learning_rate and allow_lr_init)):
+                   (FLAGS.force_initialize_learning_rate and allow_lr_init)):
         assert len(lr_var) <= 1
         load_vars -= lr_var
         init_vars |= lr_var

-    if FLAGS.load_cudnn:
-        # Initialize training from a CuDNN RNN checkpoint
-        # Identify the variables which we cannot load, and set them
-        # for initialization
-        missing_vars = set()
-        for v in load_vars:
-            if v.op.name not in vars_in_ckpt:
-                log_warn('CUDNN variable not found: %s' % (v.op.name))
-                missing_vars.add(v)
+    # After training with "freeze_source_layers" the Adam moment tensors for the frozen layers
+    # are missing because they were not used. This might also occur when loading a cudnn checkpoint
+    # Therefore we have to initialize them again to continue training on such checkpoints
+    print_msg = False
+    for v in load_vars:
+        if v.op.name not in vars_in_ckpt:
+            if 'Adam' in v.name:
                 init_vars.add(v)
+                print_msg = True
+    if print_msg:
+        msg = "Some Adam tensors are missing, they will be initialized automatically."
+        log_info(msg)
+    load_vars -= init_vars

-        load_vars -= init_vars
-
-        # Check that the only missing variables (i.e. those to be initialised)
-        # are the Adam moment tensors, if they aren't then we have an issue
-        missing_var_names = [v.op.name for v in missing_vars]
-        if any('Adam' not in v for v in missing_var_names):
-            log_error('Tried to load a CuDNN RNN checkpoint but there were '
-                      'more missing variables than just the Adam moment '
-                      'tensors. Missing variables: {}'.format(missing_var_names))
-            sys.exit(1)
-
-    if FLAGS.load_frozen_graph:
-        # After training with "freeze_source_layers" the Adam tensors for the frozen layers aren't
-        # existing anymore because they were not used
-        # Therefore we have to initialize them again to continue training on such checkpoints
+    if FLAGS.load_cudnn:
+        # Check all required tensors are included in the cudnn checkpoint we want to load
         for v in load_vars:
-            if v.op.name not in vars_in_ckpt:
-                if 'Adam' in v.name:
-                    init_vars.add(v)
-                else:
-                    msg = "Tried to load a frozen checkpoint but there was a missing " \
-                          "variable other than the Adam tensors: {}"
-                    log_error(msg.format(v))
-                    sys.exit(1)
-        load_vars -= init_vars
+            if v.op.name not in vars_in_ckpt and 'Adam' not in v.op.name:
+                msg = 'Tried to load a CuDNN RNN checkpoint but there was a missing' \
+                      ' variable other than an Adam moment tensor: {}'
+                log_error(msg.format(v.op.name))
+                sys.exit(1)

     if allow_drop_layers and FLAGS.drop_source_layers > 0:
         # This transfer learning approach requires supplying
@@ -74,7 +60,7 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
                      'dropping only 5 layers.')
            FLAGS.drop_source_layers = 5

-        dropped_layers = ['2', '3', 'lstm', '5', '6'][-1 * int(FLAGS.drop_source_layers):]
+        dropped_layers = drop_freeze_number_to_layers(FLAGS.drop_source_layers, "drop")
         # Initialize all variables needed for DS, but not loaded from ckpt
         for v in load_vars:
             if any(layer in v.op.name for layer in dropped_layers):
@@ -90,6 +76,24 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
             session.run(v.initializer)


+def drop_freeze_number_to_layers(drop_freeze_number, mode):
+    """ Convert number of layers to drop or freeze into layer names """
+
+    if drop_freeze_number >= 6:
+        log_warn('The checkpoint only has 6 layers, but you are trying '
+                 'to drop or freeze all of them or more. Continuing with 5 layers.')
+        drop_freeze_number = 5
+
+    layer_keys = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]
+    if mode == "drop":
+        layer_keys = layer_keys[-1 * int(drop_freeze_number):]
+    elif mode == "freeze":
+        layer_keys = layer_keys[:-1 * int(drop_freeze_number)]
+    else:
+        raise ValueError
+    return layer_keys
+
+
 def _checkpoint_path_or_none(checkpoint_filename):
     checkpoint = tfv1.train.get_checkpoint_state(FLAGS.load_checkpoint_dir, checkpoint_filename)
     if not checkpoint:
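
For reference, the slicing in the new drop_freeze_number_to_layers helper selects the topmost layers for dropping and the remaining lower layers for freezing. A standalone illustration of the mapping it produces (same list and slices as in the committed function):

    layer_keys = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]

    n = 2
    assert layer_keys[-1 * n:] == ["layer_5", "layer_6"]                     # mode == "drop"
    assert layer_keys[:-1 * n] == ["layer_1", "layer_2", "layer_3", "lstm"]  # mode == "freeze"

So drop_source_layers=2 re-initializes the last two layers, while freeze_source_layers=2 excludes everything below them from the trainable variables.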

training/deepspeech_training/util/flags.py  (+2 −3)

@@ -92,9 +92,8 @@ def create_flags():

     # Transfer Learning

-    f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)')
-    f.DEFINE_integer('freeze_source_layers', 0, 'use same value as above to freeze the other layers')
-    f.DEFINE_boolean('load_frozen_graph', False, 'Needed to load a graph checkpoint which was trained with "freeze_source_layers" flag. Allows initialization of missing optimization tensors.')
+    f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output == 2, etc)')
+    f.DEFINE_integer('freeze_source_layers', 0, 'freeze layer weights (to freeze all but output == 1, freeze all but penultimate and output == 2, etc). Normally used in combination with "drop_source_layers" flag and should be used in a two step training (first drop and freeze layers and train a few epochs, second continue without both flags)')

     # Exporting

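
The updated help text describes a two-step workflow. A hedged sketch of what the paired flags do in the first step, using the committed helper's slicing (the flag values are illustrative):

    # With drop_source_layers=1 and freeze_source_layers=1 the output layer is
    # re-initialized and is also the only layer left trainable.
    layer_keys = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]
    dropped = layer_keys[-1:]     # re-initialized: ['layer_6']
    frozen = layer_keys[:-1]      # excluded from training: layers 1-5
    assert [k for k in layer_keys if k not in frozen] == dropped == ["layer_6"]

In the second step, training continues from the resulting checkpoint without either flag; the checkpoints.py change above then re-initializes the Adam moment tensors that are missing for the previously frozen layers, which is why the separate load_frozen_graph flag could be removed.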