
Commit f38f5ae

Recover relaxed behavior, strict with new behavior version (#1144)
Adds a check that the time dim of the optimized-out ``output`` sub-layer of a RecLayer matches the RecLayer's own time dim. Fixes #1140. This introduces a new behavior version 13 (#508).
1 parent 64cb7fa commit f38f5ae
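Configs opt into the stricter check via the global behavior version setting documented below. A minimal sketch of a RETURNN config (a Python file) that requests the new behavior:

    # RETURNN config: request the new, stricter behavior.
    behavior_version = 13  # >= 13 makes the RecLayer time-dim check a hard error

    # Configs staying below 13 keep the relaxed behavior this commit recovers:
    # the time-dim mismatch is tolerated as before.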

5 files changed: +159 -29


docs/configuration_reference/behavior_version.rst

Lines changed: 18 additions & 0 deletions

@@ -22,6 +22,24 @@ and not listing legacy/deprecated parameters.
 Version History
 ---------------
 
+Behavior version 13 (2022-10-13)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+This enables some extra checks in the :class:`RecLayer` which break some old configs.
+Those old configs were actually broken,
+but the broken parts did not play a role for the training,
+and thus it did not matter.
+However, we do not want to allow such broken configs anymore.
+More specifically, an optimized-out ``output`` sub-layer of a :class:`RecLayer`
+must have the same time dim as the :class:`RecLayer` itself.
+For some specific transducer configs, we have this problem
+(`example <https://github.com/rwth-i6/returnn-experiments/blob/264d13aef3321d48f685cc9750fd277fb70cc74e/2020-rnn-transducer/configs/rna-tf2.blank0.enc6l-grow2l.scratch-lm.rdrop02.lm1-1024.attwb5-drop02.l2_1e_4.mlr50.config#L778>`__).
+
+This behavior version might also require
+that the dim tags of ``extern_data`` are properly defined.
+
+See issue `#1140 <https://github.com/rwth-i6/returnn/issues/1140>`__.
+
 Behavior version 12 (2022-01-06)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
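The note about properly defined ``extern_data`` dim tags refers to the explicit dim-tag style that the updated tests below also switch to. A minimal sketch, mirroring those tests:

    from returnn.tf.util.data import batch_dim, SpatialDim, FeatureDim

    time_dim = SpatialDim("time")       # dynamic time dim, shared by input and targets
    in_dim = FeatureDim("feat", 3)      # static input feature dim
    out_dim = FeatureDim("classes", 5)  # static sparse target dim

    extern_data = {
      # dense input, shape [batch, time, feature]
      "data": {"dim_tags": [batch_dim, time_dim, in_dim]},
      # sparse targets on the same time dim, shape [batch, time]
      "classes": {"dim_tags": [batch_dim, time_dim], "sparse_dim": out_dim},
    }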

returnn/tf/layers/rec.py

Lines changed: 24 additions & 13 deletions

@@ -2325,6 +2325,7 @@ def get_output(self):
     :rtype: tf.Tensor
     """
     from returnn.tf.util.basic import check_input_dim, tensor_array_stack, Dim, get_valid_scope_name_from_str
+    from returnn.util.basic import BehaviorVersion
     assert self.parent_rec_layer
     rec_layer = self.parent_rec_layer

@@ -3067,6 +3068,14 @@ def cond(i, net_vars, acc_tas, seq_len_info=None, allow_inf_max_len=False):
       layer = self.net.layers[name]
       assert layer.search_choices
 
+    for key in (
+        self.net.used_data_keys |
+        (self.input_layers_net.used_data_keys if self.input_layers_net else set()) |
+        (self.output_layers_net.used_data_keys if self.output_layers_net else set())):
+      if key == "source":
+        continue
+      self.parent_net.used_data_keys.add(key)
+
     with tf.name_scope("output"):
       output_layer = None
       if self.input_layers_net and "output" in self.input_layers_net.layers:

@@ -3076,7 +3085,19 @@ def cond(i, net_vars, acc_tas, seq_len_info=None, allow_inf_max_len=False):
       if output_layer:
         assert isinstance(output_layer, LayerBase)
         output_data = output_layer.output.copy_as_time_major()
-        self.time_dim_tag.declare_same_as(output_data.get_time_dim_tag())
+        if not self.time_dim_tag.is_dim_known():
+          self.time_dim_tag.declare_same_as(output_data.get_time_dim_tag())
+        elif self.time_dim_tag not in output_data.dim_tags:
+          # We allow this for older behavior version to not break some older setups.
+          # https://github.com/rwth-i6/returnn/issues/1140
+          BehaviorVersion.require(
+            False,
+            "%s: time-dim-tag mismatch: self %r vs sub-output-layer %r time-dim-tag %r" % (
+              rec_layer, self.time_dim_tag, output_data, output_data.get_time_dim_tag()), version=13)
+          # No further checks, it would fail anyway.
+          # Replace the actual rec layer output and return.
+          rec_layer.output = output_data
+          return output_data.placeholder
         assert len(rec_layer.output.dim_tags) == len(output_data.dim_tags)
         for tag1, tag2 in zip(rec_layer.output.dim_tags, output_data.dim_tags):
           try:

@@ -3090,24 +3111,14 @@ def cond(i, net_vars, acc_tas, seq_len_info=None, allow_inf_max_len=False):
           # and then created once for the template layer, and again for the real layer.
           # Make sure they are really the same such that we get all information like dyn sizes.
           tag1.declare_same_as(tag2)
-        output = output_data.placeholder
+        return output_data.placeholder
       else:
         assert seq_len is not None
         rec_layer.output.size_placeholder[0] = seq_len
         assert not self.net.layers["output"].get_search_choices()
-        output = tensor_array_stack(
+        return tensor_array_stack(
           self.final_acc_tas_dict["output_output"], stop=max_seq_len, name="output_stack")  # e.g. (time, batch, dim)
 
-    for key in (
-        self.net.used_data_keys |
-        (self.input_layers_net.used_data_keys if self.input_layers_net else set()) |
-        (self.output_layers_net.used_data_keys if self.output_layers_net else set())):
-      if key == "source":
-        continue
-      self.parent_net.used_data_keys.add(key)
-
-    return output
-
   def _get_search_choice_seq(self, search_choices):
     """
     :param SearchChoices search_choices:
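The key mechanism here is ``BehaviorVersion.require(False, msg, version=13)``: the failed check only raises for configs on behavior version 13 or newer, while older configs fall through to the relaxed path that replaces ``rec_layer.output``. A simplified sketch of those semantics (not the actual implementation in returnn/util/basic.py, which e.g. also handles logging):

    class BehaviorVersionSketch:
      """Simplified stand-in for returnn.util.basic.BehaviorVersion."""

      _behavior_version = 0  # in the real class, set once from the config

      class RequirementNotSatisfied(Exception):
        """A check failed under a sufficiently new behavior version."""

      @classmethod
      def require(cls, condition, message, version):
        """Enforce `condition` for configs with behavior_version >= `version`."""
        if condition:
          return
        if cls._behavior_version >= version:
          raise cls.RequirementNotSatisfied(message)
        # Older behavior version: tolerated for backward compatibility.

    # Tolerated under behavior version 0; raises once the config sets >= 13:
    BehaviorVersionSketch.require(False, "time-dim-tag mismatch: ...", version=13)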

returnn/util/basic.py

Lines changed: 1 addition & 1 deletion

@@ -238,7 +238,7 @@ class BehaviorVersion:
   The version will be set after the config is defined at __main__.init_config() or Engine.__init__()
   """
 
-  _latest_behavior_version = 12
+  _latest_behavior_version = 13
   _behavior_version = None  # type: typing.Optional[int]
 
   @classmethod
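This bump makes 13 the highest version a config may request; the ``BehaviorVersion.require(..., version=13)`` call added in rec.py above only turns into a hard error for configs that set ``behavior_version`` to 13 or higher.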

tests/test_TFNetworkLayer.py

Lines changed: 14 additions & 10 deletions

@@ -5871,14 +5871,16 @@ def test_SliceNdLayer_dyn_size():
 
 def test_SliceNdLayer_multidimensional_start():
   with make_scope() as session:
-    n_out = 5
+    from returnn.tf.util.data import batch_dim, SpatialDim, FeatureDim
+    out_dim = FeatureDim("feat", 3)
+    time_dim = SpatialDim("time")
     n_batch = 3
     max_seq_len = 10
     config = Config({
       "debug_print_layer_output_template": True,
       "extern_data": {
-        "data": {"dim": n_out},
-        "classes": {"dim": n_out, "sparse": True}
+        "data": {"dim_tags": [batch_dim, time_dim, out_dim]},
+        "classes": {"dim_tags": [batch_dim, time_dim], "sparse_dim": out_dim}
       }})
     net = TFNetwork(config=config, train_flag=True)
     net.construct_from_dict({

@@ -5902,7 +5904,7 @@ def test_SliceNdLayer_multidimensional_start():
     input_data = feed[net.extern_data.data["data"].placeholder]
     max_size = numpy.amax(seq_lens[:, None] - starts)
     max_size = max(max_size, 0)
-    assert segments.shape == (n_batch, max_seq_len, max_size, n_out)
+    assert segments.shape == (n_batch, max_seq_len, max_size, out_dim.dimension)
     for b in range(n_batch):
       for t in range(max_seq_len):
         s = starts[b, t]

@@ -5911,22 +5913,24 @@ def test_SliceNdLayer_multidimensional_start():
         orig_seq = numpy.pad(orig_seq, [(0, max_size - len(orig_seq)), (0, 0)], "constant")
       elif len(orig_seq) > max_size:
         orig_seq = orig_seq[:max_size]
-      assert orig_seq.shape == (max_size, n_out)
+      assert orig_seq.shape == (max_size, out_dim.dimension)
       orig_seq = numpy.where((numpy.arange(s, s + max_size) >= seq_lens[b])[:, None], 0.0, orig_seq)
       for t2 in range(max_size):
         numpy.testing.assert_equal(orig_seq[t2], segments[b, t, t2])
 
 
 def test_SliceNdLayer_multidimensional_size():
   with make_scope() as session:
-    n_out = 5
+    from returnn.tf.util.data import batch_dim, SpatialDim, FeatureDim
+    out_dim = FeatureDim("feat", 3)
+    time_dim = SpatialDim("time")
     n_batch = 3
     max_seq_len = 10
     config = Config({
       "debug_print_layer_output_template": True,
       "extern_data": {
-        "data": {"dim": n_out},
-        "classes": {"dim": n_out, "sparse": True}
+        "data": {"dim_tags": [batch_dim, time_dim, out_dim]},
+        "classes": {"dim_tags": [batch_dim, time_dim], "sparse_dim": out_dim}
       }})
     net = TFNetwork(config=config, train_flag=True)
     net.construct_from_dict({

@@ -5954,7 +5958,7 @@ def test_SliceNdLayer_multidimensional_size():
     input_data = feed[net.extern_data.data["data"].placeholder]
     max_size = numpy.amax(sizes)
     max_size = max(max_size, 0)
-    assert segments.shape == (n_batch, max_seq_len, max_size, n_out)
+    assert segments.shape == (n_batch, max_seq_len, max_size, out_dim.dimension)
     for b in range(n_batch):
       for t in range(max_seq_len):
         s = starts[b, t]

@@ -5965,7 +5969,7 @@ def test_SliceNdLayer_multidimensional_size():
         orig_seq = numpy.pad(orig_seq, [(0, max_size - len(orig_seq)), (0, 0)], "constant")
       elif len(orig_seq) > max_size:
         orig_seq = orig_seq[:max_size]
-      assert orig_seq.shape == (max_size, n_out)
+      assert orig_seq.shape == (max_size, out_dim.dimension)
       orig_seq = numpy.where((numpy.arange(s, s + max_size) >= seq_lens[b])[:, None], 0.0, orig_seq)
       for t2 in range(max_size):
         numpy.testing.assert_equal(orig_seq[t2], segments[b, t, t2])
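The reworked assertions read the expected feature size back from the dim tag instead of a separate ``n_out`` variable. For illustration, the ``Dim`` accessor involved (a tiny sketch with the same sizes as the test):

    from returnn.tf.util.data import FeatureDim, SpatialDim

    out_dim = FeatureDim("feat", 3)
    time_dim = SpatialDim("time")

    assert out_dim.dimension == 3      # static size of a feature dim tag
    assert time_dim.dimension is None  # dynamic dims carry no static size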

tests/test_TFNetworkRecLayer.py

Lines changed: 102 additions & 5 deletions

@@ -600,11 +600,17 @@ def _enc_func(source, **_):
 
 
 def test_rec_subnet_with_choice():
+  from returnn.tf.util.data import batch_dim, SpatialDim, FeatureDim
+  in_dim = FeatureDim("feat", 3)
+  out_dim = FeatureDim("classes", 4)
+  time_dim = SpatialDim("time")
   with tf_compat.v1.Session():
     config = Config()
     config.update({
-      "num_outputs": 3,
-      "num_inputs": 4,
+      "extern_data": {
+        "data": {"dim_tags": [batch_dim, time_dim, in_dim]},
+        "classes": {"dim_tags": [batch_dim, time_dim], "sparse_dim": out_dim}
+      },
       "network": {
         "output": {"class": "rec", "from": "data:data", "target": "classes", "unit": {
           "prob": {"class": "softmax", "from": ["prev:output"], "loss": "ce", "target": "classes"},

@@ -1220,12 +1226,16 @@ def test_rec_RecStepInfoLayer_broadcast_moved_out():
       },
     }
   }
+  from returnn.tf.util.data import batch_dim, SpatialDim, FeatureDim
+  in_dim = FeatureDim("feat", 3)
+  out_dim = FeatureDim("classes", 5)
+  time_dim = SpatialDim("time")
   config = Config({
     "debug_print_layer_output_template": True,
     "extern_data": {
-      "data": {"dim": 3},
-      "classes": {"sparse": True, "dim": 5},
-    }
+      "data": {"dim_tags": [batch_dim, time_dim, in_dim]},
+      "classes": {"dim_tags": [batch_dim, time_dim], "sparse_dim": out_dim}
+    },
   })
   from test_TFNetworkLayer import make_feed_dict
   with make_scope() as session:

@@ -6158,6 +6168,93 @@ def test_reclayer_shape_from_initial():
     session.run(out.placeholder, feed_dict=make_feed_dict(net.extern_data))
 
 
+def test_reclayer_time_sync_target_diff():
+  # https://github.com/rwth-i6/returnn/issues/1140
+  from returnn.util.basic import BehaviorVersion
+  from returnn.tf.util.data import batch_dim, SpatialDim, FeatureDim
+  from returnn.tf.layers.rec import _SubnetworkRecCell
+  src_dim = FeatureDim("src-feat", 5)
+  tgt_dim = FeatureDim("tgt-classes", 7)
+  tgt_with_blank_dim = tgt_dim + 1
+  src_time_dim = SpatialDim("src-time")
+  tgt_time_dim = SpatialDim("out-spatial")
+
+  config = Config({
+    "extern_data": {
+      "data": {"dim_tags": [batch_dim, src_time_dim, src_dim]},
+      "classes": {"dim_tags": [batch_dim, tgt_time_dim], "sparse_dim": tgt_dim, "available_for_inference": False},
+      "align_classes": {
+        "dim_tags": [batch_dim, src_time_dim], "sparse_dim": tgt_with_blank_dim, "available_for_inference": False},
+    },
+    "network": {
+      "encoder": {"class": "linear", "activation": "tanh", "n_out": 5, "from": "data:data"},
+
+      "output": {"class": "rec", "from": "encoder", "unit": {
+        "output_prob": {"class": "softmax", "from": "data:source", "out_dim": tgt_with_blank_dim},
+
+        # Note: This is actually not correct to have 'classes' here.
+        # In practice, in search, it would use output_prob and then have actually one more label.
+        # classes also has the wrong spatial dim, which actually causes the error.
+        # However, then this output is actually never used.
+        # We had such training configs for transducer, and we want to make sure that they still work.
+        # In that case, in search, the config switched to a different target, so that is why it worked.
+        'output': {'class': 'choice', 'target': 'classes', 'beam_size': 12, 'from': "output_prob",
+                   "initial_output": 0},
+
+        # Would also look different for recognition.
+        "classes_embed": {"class": "linear", "activation": "tanh", "n_out": 5, "from": "base:data:classes"},
+        "joint": {
+          "class": "combine", "from": ["output_prob", "classes_embed"], "kind": "mul",
+          "allow_broadcast_all_sources": True},
+
+        # Dummy loss. In transducer, this would be the full-sum after joint network.
+        # Here we just need sth to trigger the dependencies.
+        "loss": {
+          "class": "eval", "from": "joint",
+          "eval": "tf.reduce_mean(source(0,auto_convert=False))",
+          "out_type": {"shape": (), "dtype": "float32", "batch_dim_axis": None, "time_dim_axis": None},
+          "loss": "as_is"
+        },
+
+      }, "target": "classes"},
+    }})
+
+  print("Constructing train network (old behavior).")
+  with make_scope() as session:
+    net = TFNetwork(train_flag=True, config=config)
+    orig_behavior_version = BehaviorVersion._behavior_version
+    try:
+      BehaviorVersion._behavior_version = 0
+      # The net dict requires an older behavior version. This is important for the test.
+      # We want to make sure such old config still works.
+      net.construct_from_dict(config.typed_value("network"))
+    finally:
+      BehaviorVersion._behavior_version = orig_behavior_version
+    # Check whether we triggered the dim tag bug.
+    assert src_time_dim != tgt_time_dim
+    net.initialize_params(session)
+    rec_layer = net.get_layer("output")
+    assert isinstance(rec_layer, RecLayer)
+    cell = rec_layer.cell
+    assert isinstance(cell, _SubnetworkRecCell)
+    assert_equal(cell.layers_in_loop, [])
+    loss = net.get_total_loss()
+    from test_TFNetworkLayer import make_feed_dict
+    loss_v = session.run(loss, feed_dict=make_feed_dict(net.extern_data))
+    print("Loss:", loss_v)
+
+  print("Constructing train network (new behavior).")
+  with make_scope():
+    net = TFNetwork(train_flag=True, config=config)
+    try:
+      net.construct_from_dict(config.typed_value("network"))
+    except BehaviorVersion.RequirementNotSatisfied as exc:
+      assert "time-dim-tag mismatch" in str(exc)
+      print("Got expected exception:", exc)
+    else:
+      raise Exception("did not get expected exception")
+
+
 def test_convert_lstm_params_save_load():
   """
   Test conversions from different units to different units.
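In this new test, the RecLayer runs over the encoder's ``src-time`` dim, while its ``output`` choice sub-layer targets ``classes`` on the separate ``out-spatial`` dim. Once all sub-layers are optimized out of the loop (hence ``layers_in_loop == []``), the ``output`` sub-layer's time dim no longer matches the RecLayer's own, which is exactly the mismatch that behavior version 13 turns into the ``RequirementNotSatisfied`` error asserted at the end.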
