
Commit 18c05e9

MergeDimsLayer, only allow keep_order=True (#784)
Fix #654. Also introduce the Data.get_axes_from_description "dim:%i" variant. This is mostly a simple way to fix many of the test cases, but it could also be useful for users in general.
1 parent 15fed20 commit 18c05e9

7 files changed: +113 additions, -78 deletions

docs/configuration_reference/behavior_version.rst

+10
@@ -22,6 +22,16 @@ and not listing legacy/deprecated parameters.
 Version History
 ---------------
 
+Behavior version 6 (2021-11-27)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+:class:`MergeDimsLayer` uses ``keep_order=True`` and does not allow ``keep_order=False``.
+There should never be a reason to use ``keep_order=False`` anyway.
+If you have it set explicitly, just remove it.
+If that causes any problems, there is probably some other issue in your config.
+
+See issue `#654 <https://github.com/rwth-i6/returnn/issues/654>`__.
+
 Behavior version 5 (2021-11-26)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
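
For configs hit by this change, the migration is typically just dropping the explicit keep_order=False and listing the merged axes explicitly, so their order is well defined. A minimal sketch (the layer names "att"/"att0" and the dim sizes 2 and 5 are only an assumed example, not from this commit):

# Before (behavior version <= 5), relying on implicit axis sorting:
#   "att": {"class": "merge_dims", "axes": "except_batch", "keep_order": False, "from": "att0"},
# Since behavior version 6, keep_order=True is implied; give the axes as an explicit list:
"att": {"class": "merge_dims", "axes": ["dim:2", "dim:5"], "from": "att0"},  # merges the size-2 and size-5 axes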

returnn/tf/layers/basic.py

+18-8
@@ -2848,26 +2848,34 @@ class MergeDimsLayer(_ConcatInputLayer):
   """
   layer_class = "merge_dims"
 
-  def __init__(self, axes, keep_order=False, n_out=None, **kwargs):
+  def __init__(self, axes, keep_order=NotSpecified, n_out=None, **kwargs):
     """
-    :param str|list[str]|list[int] axes: see Data.get_axes_from_description(), e.g. "except_time"
-    :param bool keep_order: By default (for historical reasons), the axes are sorted, and then merged.
+    :param str|list[DimensionTag|str] axes: see :func:`Data.get_axis_from_description`
+    :param bool|NotSpecified keep_order: The old default was: the axes are sorted, and then merged.
       Thus, the order of incoming axes will influence the result.
       E.g. inputs [B,S,F] and [B,F,S], with ``axes=["S","F"]``, will get different results,
       although the output shape is [B,S*F] in both cases.
       This is bad: In general, other layers in RETURNN might reorder the axes for various reasons,
       and all layers should behave in the same way, no matter the order.
       It is recommended to set ``keep_order=True``, such that the order defined in ``axes`` defines the behavior,
       and not the incoming axis order.
+      Since behavior version 6, this is already the case.
     :param int|None n_out:
     """
+    from returnn.util import BehaviorVersion
     super(MergeDimsLayer, self).__init__(**kwargs)
+    if keep_order is NotSpecified:
+      keep_order = True if BehaviorVersion.get() >= 6 else False
+    BehaviorVersion.require(
+      condition=keep_order, message="MergeDimsLayer, only keep_order=True is allowed", version=6)
     if keep_order:
-      assert isinstance(axes, (tuple, list)), "%s: unique axes %r required" % (self, axes)
+      assert isinstance(axes, (tuple, list)), (
+        "%s: axes %r must be a list or tuple, to have a well defined order in input %s" % (self, axes, self.input_data))
       axes_ = []
       for axis in axes:
         axis_ = self.input_data.get_axes_from_description(axis, allow_int=False)
-        assert len(axis_) <= 1, "%s: unique axes %r required, but got %r -> %r" % (self, axes, axis, axis_)
+        assert len(axis_) <= 1, (
+          "%s: unique axes %r required in input %s, but got %r -> %r" % (self, axes, self.input_data, axis, axis_))
         axes_ += axis_
       axes = axes_
     else:
@@ -2981,18 +2989,20 @@ def _set_output_sizes(self, merge_axes):
         target_tag.dyn_size_ext = out_size
 
   @classmethod
-  def get_out_data_from_opts(cls, name, axes, keep_order=False,
+  def get_out_data_from_opts(cls, name, axes, keep_order=NotSpecified,
                              sources=(), n_out=NotSpecified, out_type=None, **kwargs):
     """
     :param str name:
     :param str|list[str] axes:
-    :param bool keep_order:
+    :param bool|NotSpecified keep_order:
     :param list[LayerBase] sources:
     :param int|None|NotSpecified n_out:
     :param None|dict[str] out_type:
    :rtype: Data
     """
-    from ..util.data import DimensionTag
+    from returnn.util import BehaviorVersion
+    if keep_order is NotSpecified:
+      keep_order = True if BehaviorVersion.get() >= 6 else False
     assert not out_type, "currently ignored"
     input_data = get_concat_sources_data_template(sources)
     data = input_data.copy(name="%s_output" % name)
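
As a rough sketch of what the new default handling amounts to (a simplified stand-alone helper; resolve_keep_order is hypothetical and not part of this commit, it just mirrors the logic added in __init__ and get_out_data_from_opts):

from returnn.util import BehaviorVersion
from returnn.util.basic import NotSpecified


def resolve_keep_order(keep_order=NotSpecified):
  """Mirror of how MergeDimsLayer resolves its keep_order default in this commit."""
  if keep_order is NotSpecified:
    # Default follows the behavior version: True from version 6 on, False before.
    keep_order = BehaviorVersion.get() >= 6
  # With behavior version >= 6, keep_order=False raises an error.
  BehaviorVersion.require(
    condition=keep_order, message="MergeDimsLayer, only keep_order=True is allowed", version=6)
  return keep_order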

returnn/tf/util/data.py

+5
@@ -3530,6 +3530,11 @@ def get_axes_from_description(self, axes, allow_int=NotSpecified):
           s += len(static_axes)
         assert 0 <= s < len(static_axes), "%s get_axes_from_description: %r invalid" % (self, axes)
         return [static_axes[s]]
+      elif re.match("(dim):\\d+$", axes):
+        s = int(axes.split(":")[1])
+        dims = [a for a in range(self.batch_ndim) if self.batch_shape[a] == s]
+        assert dims, "%s get_axes_from_description: no dim %i found" % (self, s)
+        return dims
       elif axes in ["f", "feature", "non_spatial"]:
         return self.get_feature_batch_axes()
       elif all([a in "btf" for a in axes]):
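
A small usage sketch of the new "dim:%i" variant (the Data construction here is just an assumed example): it returns every axis whose static dimension matches the given integer, so it is only unambiguous if that size occurs exactly once.

from returnn.tf.util.data import Data

data = Data(name="example", shape=(None, 4, 3))  # batch shape [B, T, 4, 3]
assert data.get_axes_from_description("dim:4") == [2]  # the static axis of size 4
assert data.get_axes_from_description("dim:3") == [3]  # the static axis of size 3 (the feature dim)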

returnn/util/basic.py

+1-1
@@ -209,7 +209,7 @@ class BehaviorVersion:
   The version will be set after the config is defined at __main__.init_config() or Engine.__init__()
   """
 
-  _latest_behavior_version = 5
+  _latest_behavior_version = 6
   _behavior_version = None  # type: typing.Optional[int]
 
   @classmethod
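
Since this bumps the latest supported behavior version, configs opt in to the new checks explicitly; a minimal sketch of the relevant RETURNN config line (everything else omitted):

# in the RETURNN config
behavior_version = 6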

tests/test_TFEngine.py

+17-17
@@ -1233,7 +1233,7 @@ def test_attention_no_encoder_dependency():
           'n_out': 4, 'padding': 'same'},
         "location_feedback": {'class': 'linear', 'from': ['convolved_att'], 'n_out': 6, 'activation': None},
         "att_energy_in": {'class': 'combine', 'kind': 'add', 'from': ['location_feedback', 's_transformed']},
-        "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights"},
+        "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights", "auto_squeeze": True},
       },
     },
     "decision": {"class": "decide", "from": ["output"], "loss": "edit_distance"}
@@ -1345,7 +1345,7 @@ def test_attention_convolutional_feedback_variant1():
     "location_feedback": {'class': 'linear', 'from': ['convolved_att'], 'n_out': 6, 'activation': None},
     "att_energy_in": {'class': 'combine', 'kind': 'add', 'from': [
       'base:enc_transformed', 'location_feedback', 's_transformed']},
-    "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights"},
+    "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights", "auto_squeeze": True},
   }
 
   check_attention_variant(recurrent_unit_dict)
@@ -1373,7 +1373,7 @@ def test_attention_convolutional_feedback_variant2():
     "location_feedback": {'class': 'linear', 'from': ['convolved_att'], 'n_out': 6, 'activation': None},
     "att_energy_in": {'class': 'combine', 'kind': 'add', 'from': [
       'base:enc_transformed', 'location_feedback', 's_transformed']},
-    "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights"},
+    "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights", "auto_squeeze": True},
   }
 
   check_attention_variant(recurrent_unit_dict)
@@ -1412,7 +1412,7 @@ def test_attention_convolutional_feedback_variant3():
     "location_feedback": {'class': 'linear', 'from': ['convolved_att'], 'n_out': 6, 'activation': None},
     "att_energy_in": {'class': 'combine', 'kind': 'add', 'from': [
       'base:enc_transformed', 'location_feedback', 's_transformed']},
-    "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights"},
+    "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights", "auto_squeeze": True},
   }
 
   check_attention_variant(recurrent_unit_dict)
@@ -2135,7 +2135,7 @@ def test_rec_subnet_construct_1():
       "accum_att_weights": {"class": "eval", "from": ["prev:accum_att_weights", "att_weights", "base:inv_fertility"],
                             "eval": "source(0) + source(1) * source(2) * 0.5",
                             "out_type": {"dim": 1, "shape": (None, 1)}},
-      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder"},
+      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder", "auto_squeeze": True},
       "s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["target_embed", "att"], "n_out": 10},
       "s2": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["s"], "n_out": 10},
       "readout_in": {"class": "linear", "from": ["prev:s2", "prev:target_embed", "att"], "activation": None, "n_out": 10},
@@ -2192,7 +2192,7 @@ def test_rec_subnet_construct_2():
      "accum_att_weights": {"class": "eval", "from": ["prev:accum_att_weights", "att_weights", "base:inv_fertility"],
                             "eval": "source(0) + source(1) * source(2) * 0.5",
                             "out_type": {"dim": 1, "shape": (None, 1)}},
-      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder"},
+      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder", "auto_squeeze": True},
       "s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["target_embed", "att"], "n_out": 10},
       "s2": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["s"], "n_out": 10},
       "readout_in": {"class": "linear", "from": ["prev:s2", "prev:target_embed", "att"], "activation": None, "n_out": 10},
@@ -2255,7 +2255,7 @@ def test_rec_subnet_construct_3():
      "accum_att_weights": {"class": "eval", "from": ["prev:accum_att_weights", "att_weights", "base:inv_fertility"],
                             "eval": "source(0) + source(1) * source(2) * 0.5",
                             "out_type": {"dim": 1, "shape": (None, 1)}},
-      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder"},
+      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder", "auto_squeeze": True},
       "s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["target_embed", "att"], "n_out": 10},
       "s2": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["prev:s", "prev:target_embed", "att"], "n_out": 10},
       "readout_in": {"class": "linear", "from": ["s2"], "activation": None, "n_out": 10},
@@ -2288,9 +2288,9 @@ def test_rec_subnet_eval_init_out_apply0():
   # (also defined by num_inputs & num_outputs)
   beam_size = 3
   AttNumHeads = 2
-  EncKeyTotalDim = AttNumHeads * 2
+  EncKeyTotalDim = AttNumHeads * 5
   EncKeyPerHeadDim = EncKeyTotalDim // AttNumHeads
-  EncValueTotalDim = AttNumHeads * 2
+  EncValueTotalDim = AttNumHeads * 5
   EncValuePerHeadDim = EncValueTotalDim // AttNumHeads
   network = {
     "lstm0_fw": {"class": "rec", "unit": "nativelstm2", "n_out": 2, "direction": 1, "from": "data:data"},
@@ -2334,7 +2334,7 @@ def test_rec_subnet_eval_init_out_apply0():
                             "eval": "source(0) + source(1) * source(2) * 0.5",
                             "out_type": {"dim": 1, "shape": (None, 1)}, "initial_output": "apply(0)"},  # (B, enc-T, 1)
      "att0": {"class": "generic_attention", "weights": "att_weights", "base": "base:enc_value"},  # (B, H, V)
-      "att": {"class": "merge_dims", "axes": "except_batch", "from": ["att0"]},  # (B, H*V)
+      "att": {"class": "merge_dims", "axes": ["dim:%i" % AttNumHeads, "dim:%i" % EncValuePerHeadDim], "from": "att0"},  # (B, H*V)
 
       "s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["target_embed", "att"], "n_out": 2},  # transform
       "readout_in": {"class": "linear", "from": ["prev:s", "prev:target_embed", "att"], "activation": None,
@@ -2768,13 +2768,13 @@ def custom_construction_algo(idx, net_dict):
 def test_net_safe_log_to_log_softmax():
   n_out = 5
   net_dict = {
-    "ff_in_window": {"class": "window", "window_size": 3, "from": "data:data"},  # (B,T,3,3)
-    "ff_in": {"class": "merge_dims", "axes": "except_time", "from": ["ff_in_window"]},  # (B,T,9)
-    "ff0": {"class": "hidden", "activation": "relu", "n_out": 8, "L2": 0.01, "from": ["ff_in"]},  # (B,T,8)
-    "ff_out": {"class": "softmax", "n_out": n_out, "from": ["ff0"]},  # (B,T,5)
+    "ff_in_window": {"class": "window", "window_size": 4, "from": "data:data"},  # (B,T,4,3)
+    "ff_in": {"class": "merge_dims", "axes": ["dim:3", "dim:4"], "from": "ff_in_window"},  # (B,T,12)
+    "ff0": {"class": "hidden", "activation": "relu", "n_out": 8, "L2": 0.01, "from": "ff_in"},  # (B,T,8)
+    "ff_out": {"class": "softmax", "n_out": n_out, "from": "ff0"},  # (B,T,5)
     "ff_out_prior": {
       "class": "accumulate_mean", "exp_average": 0.001,
-      "is_prob_distribution": True, "from": ["ff_out"]},  # (5,)
+      "is_prob_distribution": True, "from": "ff_out"},  # (5,)
     "output": {
       "class": "combine", "kind": "eval", "from": ["ff_out", "ff_out_prior"],
       "eval": "safe_log(source(0)) - safe_log(source(1))",
@@ -2826,7 +2826,7 @@ def test_preload_from_files():
         "class": "linear", "activation": None, "n_out": n_hidden, "from": "data:data",
         'bias_init': 1.0, 'forward_weights_init': 'orthogonal'},
       "output": {
-        "class": "linear", "activation": None, "n_out": n_out, "from": ["l1"],
+        "class": "linear", "activation": None, "n_out": n_out, "from": "l1",
         'bias_init': 2.0, 'forward_weights_init': 'orthogonal'}
     }
   })
@@ -3366,7 +3366,7 @@ def test_attention_forward_hdf_then_unflatten_2d():
       # (B, enc-T, 1)
       "energy": {"class": "linear", "activation": None, "with_bias": False, "from": ["energy_tanh"], "n_out": 1},
       "att_weights": {"class": "softmax_over_spatial", "from": ["energy"], "is_output_layer": True},  # (B, enc-T, 1)
-      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder"},
+      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder", "auto_squeeze": True},
       "s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["prev:target_embed", "prev:att"], "n_out": 10},
       "readout_in": {"class": "linear", "from": ["s", "prev:target_embed", "att"], "activation": None, "n_out": 10},
       "readout": {"class": "reduce_out", "mode": "max", "num_pieces": 2, "from": ["readout_in"]},
