diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 184f565314..94b4738f1f 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -1635,7 +1635,7 @@ def __init__(self, axis=None, energy_factor=None, By default, if dyn seq len exists, it uses it. :param bool log_space: if True, returns in log space (i.e. uses log_softmax) """ - from returnn.tf.util.basic import where_bc + from returnn.tf.util.basic import where_bc, set_padding_info super(SoftmaxOverSpatialLayer, self).__init__(**kwargs) self.start = start self.window_start = window_start @@ -1665,6 +1665,8 @@ def __init__(self, axis=None, energy_factor=None, self.output_before_activation = OutputWithActivation( energy, act_func=tf.nn.log_softmax if log_space else tf.nn.softmax) # (...,T) self.output.placeholder = self.output_before_activation.y + if use_time_mask: + set_padding_info(self.output.placeholder, dim=self.output.dim_tags[axis], pad_value=0.) # Never allow inf in output, as softmax should remove all -inf values used for masking self.allow_inf_in_output = False @@ -1753,6 +1755,7 @@ def __init__(self, mask_value, axis="T", seq_len_source=None, :param LayerBase|None window_start: Tensor of shape (B,) indicating the window start :param LayerBase|int|None window_size: """ + from ..util.basic import set_padding_info super(SeqLenMaskLayer, self).__init__(**kwargs) self.seq_len_source = seq_len_source self.start = start @@ -1768,6 +1771,8 @@ def __init__(self, mask_value, axis="T", seq_len_source=None, window_size=window_size.output if isinstance(window_size, LayerBase) else window_size) from returnn.tf.util.basic import where_bc x_ = where_bc(mask, self.input_data.placeholder, mask_value) + axis_ = self.input_data.get_axis_from_description(axis) + set_padding_info(x_, dim=self.input_data.dim_tags[axis_], pad_value=mask_value) self.output.placeholder = x_ if mask_value in [float("-inf"), float("inf")]: self.allow_inf_in_output = True @@ -5370,7 +5375,7 @@ def __init__(self, red1=-1, red2=-2, var1=-2, var2=-1, add_var2_if_empty=True, d :param bool add_var2_if_empty: if var2=None, add dim=1 at the end :param bool debug: will print debug shapes, etc. """ - from returnn.tf.util.basic import prod + from returnn.tf.util.basic import prod, get_shape, get_padding_info_dict_ref, mask_dyn_seq_len_nd super(DotLayer, self).__init__(**kwargs) a_out = self.sources[0].output.copy() b_out = self.sources[1].output.copy() @@ -5403,27 +5408,48 @@ def __init__(self, red1=-1, red2=-2, var1=-2, var2=-1, add_var2_if_empty=True, d # So we reshape such that we collapse all reduce-axes and var-axes into each a single axis. a = a_out.placeholder b = b_out.placeholder - a_shape = tf.shape(a) - b_shape = tf.shape(b) - a_shape = [a_out.batch_shape[i] or a_shape[i] for i in range(a_out.batch_ndim)] - b_shape = [b_out.batch_shape[i] or b_shape[i] for i in range(b_out.batch_ndim)] + a_shape = get_shape(a) + b_shape = get_shape(b) a_rem_dims = [a_shape[i] for i in a_rem_axes] b_rem_dims = [b_shape[i] for i in b_rem_axes] assert len(a_rem_axes) == len(b_rem_axes), "%s: remaining shared (batch) axes do not match. 
sources %r" % (
       self, self.sources)
     assert all([
-      isinstance(d1, tf.Tensor) or isinstance(d2, tf.Tensor) or d1 == d2
-      for (d1, d2) in zip(a_rem_dims, b_rem_dims)])
+      a_out.dim_tags[i1] == b_out.dim_tags[i2] or d1 == d2
+      for (d1, d2, i1, i2) in zip(a_rem_dims, b_rem_dims, a_rem_axes, b_rem_axes)])
     a_var_dims = [a_shape[i] for i in a_var_axes]
     b_var_dims = [b_shape[i] for i in b_var_axes]
     a_reduce_dims = [a_shape[i] for i in a_reduce_axes]
     b_reduce_dims = [b_shape[i] for i in b_reduce_axes]
     assert len(a_reduce_axes) == len(b_reduce_axes)
     assert all([
-      isinstance(d1, tf.Tensor) or isinstance(d2, tf.Tensor) or d1 == d2
-      for (d1, d2) in zip(a_reduce_dims, b_reduce_dims)])
+      a_out.dim_tags[i1] == b_out.dim_tags[i2] or d1 == d2
+      for (d1, d2, i1, i2) in zip(a_reduce_dims, b_reduce_dims, a_reduce_axes, b_reduce_axes)])
     a_var_dim = prod(a_var_dims)
     b_var_dim = prod(b_var_dims)
+    a_reduce_dyn_axes = [i for i in a_reduce_axes if a_out.batch_shape[i] is None]
+    b_reduce_dyn_axes = [i for i in b_reduce_axes if b_out.batch_shape[i] is None]
+    assert len(a_reduce_dyn_axes) == len(b_reduce_dyn_axes)
+    if a_reduce_dyn_axes:
+      a_pad, b_pad = get_padding_info_dict_ref(a), get_padding_info_dict_ref(b)
+      a_pad_values = [a_pad.get(a_out.dim_tags[i], None) for i in a_reduce_dyn_axes]
+      b_pad_values = [b_pad.get(b_out.dim_tags[i], None) for i in b_reduce_dyn_axes]
+      if set(a_pad_values) == {0}:
+        self._info_reduce_mask = "source-0-already-masked"  # it's already masked as needed
+      elif set(b_pad_values) == {0}:
+        self._info_reduce_mask = "source-1-already-masked"  # it's already masked as needed
+      else:
+        # We need to apply a mask.
+        # We do not need it on both a and b; applying it to either one is enough.
+        # Use a very simple heuristic to pick the side where masking is likely cheaper.
+        if len(a_shape) < len(b_shape):
+          a = mask_dyn_seq_len_nd(a_out, pad_value=0, axes=a_reduce_dyn_axes)
+          self._info_reduce_mask = "mask-source-0"
+        else:
+          b = mask_dyn_seq_len_nd(b_out, pad_value=0, axes=b_reduce_dyn_axes)
+          self._info_reduce_mask = "mask-source-1"
+    else:
+      self._info_reduce_mask = "none-dynamic"
     a_reduce_dim = prod(a_reduce_dims)
     b_reduce_dim = prod(b_reduce_dims)
     if debug:
diff --git a/returnn/tf/util/basic.py b/returnn/tf/util/basic.py
index 318f4abf6b..a7314f737a 100644
--- a/returnn/tf/util/basic.py
+++ b/returnn/tf/util/basic.py
@@ -211,6 +211,71 @@ def copy_with_new_split_axes(old_axis_splits, new_axis_splits, old_values, new_v
   return new_values
 
 
+def get_padding_info_dict_ref(x):
+  """
+  :param tf.Tensor x:
+  :rtype: dict[DimensionTag,float|int]
+  """
+  _attr = "RETURNN_attr_padding_value_info"
+  if hasattr(x, _attr):
+    d = getattr(x, _attr)
+    assert isinstance(d, dict)
+    return d
+  d = {}
+  setattr(x, _attr, d)
+  return d
+
+
+def set_padding_info(x, dim, pad_value):
+  """
+  Stores the information about which padding value to expect after masking in the given dynamic dim.
+
+  :param tf.Tensor x:
+  :param returnn.tf.util.data.DimensionTag dim: dynamic seq len axis
+  :param float|int pad_value:
+  """
+  d = get_padding_info_dict_ref(x)
+  # If there is some earlier padding info, only keep it when it is the same value.
+  # Otherwise it becomes invalid.
+ for k, v in list(d.items()): + if v != pad_value: + del d[k] + d[dim] = pad_value + + +def mask_dyn_seq_len_nd(x, pad_value, axes): + """ + :param Data x: + :param float|int pad_value: + :param list[int]|tuple[int] axes: + :return: masked x + :rtype: tf.Tensor + """ + x_ = x.placeholder + dim_tags = [x.dim_tags[i] for i in axes] + d = get_padding_info_dict_ref(x_) + existing_pad_values = [d.get(tag) for tag in dim_tags] + if set(existing_pad_values) == {pad_value}: + return x.placeholder # nothing to do + + x_shape = get_shape(x_) + mask = tf.ones([1] * len(x_shape), dtype=tf.bool) + for axis in axes: + tag = x.dim_tags[axis] + idx_range = tf.range(x_shape[axis]) + idx_range = tf.reshape(idx_range, [1] * (axis - 1) + x_shape[axis:axis + 1] + [1] * (len(x_shape) - axis - 1)) + assert tag.dyn_size_ext + assert set(tag.dyn_size_ext.dim_tags).issubset(x.dim_tags) + size_ext = tag.dyn_size_ext.copy_compatible_to(x, check_dtype=False) + mask_ = tf.less(idx_range, size_ext.placeholder) + mask = tf.logical_and(mask, mask_) + x_ = where_bc(mask, x_, tf.cast(tf.constant(pad_value, name="pad_value"), dtype=x_.dtype)) + d = get_padding_info_dict_ref(x_) + d.clear() + d.update({k: pad_value for k in dim_tags}) + return x_ + + class OutputWithActivation(object): """ Stores some tensor before and after some activation function, diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 7e3d933f8b..a1f6754bb7 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -1306,6 +1306,8 @@ def __init__(self, name, if time_dim_axis is NotSpecified: time_dim_axis = _default_time_dim_axis_dim_tags(dim_tags) dim_tags = tuple(dim_tags) + if auto_create_placeholders: + _auto_create_size_placeholders_on_dim_tags(name=name, dim_tags=dim_tags) del shape_ del batch_dim_axis_ else: @@ -3663,6 +3665,19 @@ def _batch_shape_from_shape(shape, batch_dim_axis): return shape +def _create_size_placeholder(name, axis_wo_b, tag): + """ + :param str name: + :param int axis_wo_b: + :param DimensionTag tag: + """ + from .basic import reuse_name_scope + with reuse_name_scope("extern_data/placeholders/%s" % name, absolute=True): + dyn_size = tf_compat.v1.placeholder( + name="%s_dim%i_size" % (name, axis_wo_b), dtype=Data.size_dtype, shape=(None,)) + tag.set_tag_on_size_tensor(dyn_size) + + def _infer_dim_tags_tuple_from_shape( shape, batch_dim_axis, time_dim_axis, feature_dim_axis, @@ -3685,7 +3700,6 @@ def _infer_dim_tags_tuple_from_shape( :return: dim tags tuple :rtype: tuple[DimensionTag] """ - from .basic import reuse_name_scope assert isinstance(shape, (tuple, list)) shape = tuple(shape) batch_shape = _batch_shape_from_shape(shape, batch_dim_axis=batch_dim_axis) @@ -3723,21 +3737,19 @@ def _infer_dim_tags_tuple_from_shape( dyn_size = size_placeholder.get(axis_wo_b) if (size_placeholder and axis_wo_b is not None) else None dim = batch_shape[axis] if auto_create_placeholders and dim is None and dyn_size is None and axis != batch_dim_axis: - with reuse_name_scope("extern_data/placeholders/%s" % name, absolute=True): - dyn_size = tf_compat.v1.placeholder( - name="%s_dim%i_size" % (name, axis_wo_b), dtype=Data.size_dtype, shape=(None,)) - if not tag: - if axis == time_dim_axis: - tag_name = "time" - else: - tag_name = "spatial%i" % axis - tag = DimensionTag( - description="%s:var:extern_data:%s" % (tag_name, name), - # Spatial dim tag, even if axis == feature_dim_axis. This is to keep the old behavior. - # This is such that DimensionTag.is_equal behaves as before, e.g. in Data.get_common_data. 
- kind=DimensionTag.Types.Spatial) - dim_tags[axis] = tag - tag.set_tag_on_size_tensor(dyn_size) + if not tag: + if axis == time_dim_axis: + tag_name = "time" + else: + tag_name = "spatial%i" % axis + tag = DimensionTag( + description="%s:var:extern_data:%s" % (tag_name, name), + # Spatial dim tag, even if axis == feature_dim_axis. This is to keep the old behavior. + # This is such that DimensionTag.is_equal behaves as before, e.g. in Data.get_common_data. + kind=DimensionTag.Types.Spatial) + dim_tags[axis] = tag + _create_size_placeholder(name=name, axis_wo_b=axis_wo_b, tag=tag) + dyn_size = tag.dyn_size if tag: # Just some sanity checks. assert isinstance(tag, DimensionTag) @@ -3765,6 +3777,23 @@ def _infer_dim_tags_tuple_from_shape( return tuple(dim_tags[axis] for axis in range(len(batch_shape))) +def _auto_create_size_placeholders_on_dim_tags(name, dim_tags): + """ + :param str name: + :param tuple[DimensionTag] dim_tags: + """ + batch_dim_axis = _batch_dim_axis_from_dim_tags_tuple(dim_tags) + for axis, tag in enumerate(dim_tags): + if tag.is_batch_dim(): + continue + if tag.dimension is not None: + continue + if tag.dyn_size is not None: + continue + axis_wo_b = _get_axis_wo_b(axis, batch_dim_axis=batch_dim_axis) + _create_size_placeholder(name=name, axis_wo_b=axis_wo_b, tag=tag) + + def _get_axis_wo_b(axis_wb, batch_dim_axis, batch_ndim=None): """ :param int axis_wb: counted with batch-dim diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index 8b774903c4..fb852e61bc 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -3172,6 +3172,7 @@ def test_DotLayer(): assert_equal(seq_lens.tolist(), a_seq_lens) assert_equal(out.shape, (B, H, max(a_seq_lens), 1)) + def test_DotLayer2(): """ Test if DotLayer can handle inputs which dont have a batch-dim """ @@ -3193,8 +3194,8 @@ def test_DotLayer2(): b = InternalLayer(name='B', network=net, output=Data(name="B", shape=(S1, S2, R, V), batch_dim_axis=None, time_dim_axis=None)) - assert b.output.batch_dim_axis == None - assert b.output.time_dim_axis == None + assert b.output.batch_dim_axis is None + assert b.output.time_dim_axis is None assert b.output.shape == (S1, S2, R, V) assert b.output.dim == V b.output.placeholder = tf.reshape(tf.range(S1 * S2 * R * V, dtype=tf.float32), (S1, S2, R, V)) @@ -3215,6 +3216,69 @@ def test_DotLayer2(): assert_equal(out.shape, (S1, S2, B, V)) +def test_DotLayer_mask_dyn_seq(): + batch = DimensionTag(kind=DimensionTag.Types.Batch, description="batch") + time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time") + feat1 = DimensionTag(kind=DimensionTag.Types.Feature, description="feature 1", dimension=3) + feat2 = DimensionTag(kind=DimensionTag.Types.Feature, description="feature 2", dimension=5) + config = Config({ + "extern_data": { + "src1": {"dim_tags": [batch, time, feat1]}, + "src2": {"dim_tags": [batch, time, feat2]}, + }, + "network": { + "dot": { + "class": "dot", "from": ["data:src1", "data:src2"], "is_output_layer": True, + "red1": time, "red2": time, "var1": feat1, "var2": feat2 + }, + }, + "debug_print_layer_output_template": True, + }) + + with make_scope() as session: + net = TFNetwork(config=config) + net.construct_from_dict(config.typed_dict["network"]) + layer = net.layers["dot"] + assert isinstance(layer, DotLayer) + assert layer.output.dim_tags == (batch, feat1, feat2) + assert layer._info_reduce_mask == "mask-source-1" + + feed_dict = make_feed_dict(net.extern_data) + session.run(layer.output.placeholder, feed_dict=feed_dict) + + 
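+# Note on the surrounding two tests: the assertions on layer._info_reduce_mask check the masking decision
+# which DotLayer records (see returnn/tf/layers/basic.py above); the possible values are "none-dynamic",
+# "source-0-already-masked", "source-1-already-masked", "mask-source-0" and "mask-source-1".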
+def test_DotLayer_mask_dyn_seq_after_softmax(): + batch = DimensionTag(kind=DimensionTag.Types.Batch, description="batch") + time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time") + feat1 = DimensionTag(kind=DimensionTag.Types.Feature, description="feature 1", dimension=3) + feat2 = DimensionTag(kind=DimensionTag.Types.Feature, description="feature 2", dimension=5) + config = Config({ + "extern_data": { + "src1": {"dim_tags": [batch, time, feat1]}, + "src2": {"dim_tags": [batch, time, feat2]}, + }, + "network": { + "sm1": {"class": "softmax_over_spatial", "from": "data:src1"}, + "dot": { + "class": "dot", "from": ["sm1", "data:src2"], "is_output_layer": True, + "red1": time, "red2": time, "var1": feat1, "var2": feat2 + }, + }, + "debug_print_layer_output_template": True, + }) + + with make_scope() as session: + net = TFNetwork(config=config) + net.construct_from_dict(config.typed_dict["network"]) + layer = net.layers["dot"] + assert isinstance(layer, DotLayer) + assert layer.output.dim_tags == (batch, feat1, feat2) + assert layer._info_reduce_mask == "source-0-already-masked" + + feed_dict = make_feed_dict(net.extern_data) + session.run(layer.output.placeholder, feed_dict=feed_dict) + + def test_subnet_load_on_init(): import tempfile model_tmp_dir = tempfile.mkdtemp("tmp-checkpoint") diff --git a/tests/test_TFNetworkRecLayer.py b/tests/test_TFNetworkRecLayer.py index 8c3d4966d6..081af39e87 100644 --- a/tests/test_TFNetworkRecLayer.py +++ b/tests/test_TFNetworkRecLayer.py @@ -4844,15 +4844,17 @@ def add_lstm(i, direction, src): def test_GenericAttentionLayer_basic0(): from returnn.tf.layers.base import InternalLayer net = TFNetwork(extern_data=ExternData(), config=Config({"debug_print_layer_output_template": True})) + time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time") kwargs = dict( name="att", network=net, weights=InternalLayer( name="att_weights", network=net, output=Data( - name='att_weights_output', shape=(None, 1), auto_create_placeholders=True)), + name='att_weights_output', shape=(None, 1), auto_create_placeholders=True, same_dim_tags_as={"T": time})), base=InternalLayer( name="enc_value", network=net, - output=Data(name='enc_value_output', shape=(None, 20), auto_create_placeholders=True))) + output=Data( + name='enc_value_output', shape=(None, 20), auto_create_placeholders=True, same_dim_tags_as={"T": time}))) print("GenericAttentionLayer kwargs:") pprint(kwargs) kwargs["output"] = GenericAttentionLayer.get_out_data_from_opts(**kwargs) @@ -4867,14 +4869,19 @@ def test_GenericAttentionLayer_basic(): # This is a common situation when the GenericAttentionLayer is inside a recurrent loop, # and it gets the encoder values from outside ("base:enc_value" or so), # and the attention weights from inside the loop, and they have the same time dim axis as the encoder values. 
+ time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time") kwargs = dict( name="att", network=net, weights=InternalLayer( name="att_weights", network=net, - output=Data(name='att_weights_output', shape=(None, 1), batch_dim_axis=1, auto_create_placeholders=True)), + output=Data( + name='att_weights_output', shape=(None, 1), batch_dim_axis=1, auto_create_placeholders=True, + same_dim_tags_as={"T": time})), base=InternalLayer( name="enc_value", network=net, - output=Data(name='enc_value_output', shape=(None, 1, 2048), batch_dim_axis=1, auto_create_placeholders=True))) + output=Data( + name='enc_value_output', shape=(None, 1, 2048), batch_dim_axis=1, auto_create_placeholders=True, + same_dim_tags_as={"T": time}))) kwargs["output"] = GenericAttentionLayer.get_out_data_from_opts(**kwargs) layer = GenericAttentionLayer(**kwargs) layer.output.sanity_check() @@ -4884,17 +4891,20 @@ def test_GenericAttentionLayer_basic(): def test_GenericAttentionLayer_basic_multi_head(): from returnn.tf.layers.base import InternalLayer net = TFNetwork(extern_data=ExternData(), config=Config({"debug_print_layer_output_template": True})) + time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time") num_heads = 8 kwargs = dict( name="att", network=net, weights=InternalLayer( name="att_weights", network=net, output=Data( - name='att_weights_output', shape=(None, num_heads), batch_dim_axis=1, auto_create_placeholders=True)), + name='att_weights_output', shape=(None, num_heads), batch_dim_axis=1, auto_create_placeholders=True, + same_dim_tags_as={"T": time})), base=InternalLayer( name="enc_value", network=net, output=Data( - name='enc_value_output', shape=(None, num_heads, 2048), batch_dim_axis=1, auto_create_placeholders=True))) + name='enc_value_output', shape=(None, num_heads, 2048), batch_dim_axis=1, auto_create_placeholders=True, + same_dim_tags_as={"T": time}))) kwargs["output"] = GenericAttentionLayer.get_out_data_from_opts(**kwargs) layer = GenericAttentionLayer(**kwargs) layer.output.sanity_check() @@ -4905,15 +4915,18 @@ def test_GenericAttentionLayer_weights_auto_squeeze_time_end(): # Example: weights (B,1,T), base (B,T,V) from returnn.tf.layers.base import InternalLayer net = TFNetwork(extern_data=ExternData(), config=Config({"debug_print_layer_output_template": True})) + time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time") kwargs = dict( name="att", network=net, weights=InternalLayer( name="att_weights", network=net, output=Data( - name='att_weights_output', shape=(1, None), time_dim_axis=2, auto_create_placeholders=True)), + name='att_weights_output', shape=(1, None), time_dim_axis=2, auto_create_placeholders=True, + same_dim_tags_as={"T": time})), base=InternalLayer( name="enc_value", network=net, - output=Data(name='enc_value_output', shape=(None, 2048), auto_create_placeholders=True))) + output=Data( + name='enc_value_output', shape=(None, 2048), auto_create_placeholders=True, same_dim_tags_as={"T": time}))) print("GenericAttentionLayer kwargs:") pprint(kwargs) kwargs["output"] = GenericAttentionLayer.get_out_data_from_opts(**kwargs) @@ -4927,15 +4940,19 @@ def test_GenericAttentionLayer_weights_static_time_axis(): window_size = 10 from returnn.tf.layers.base import InternalLayer net = TFNetwork(extern_data=ExternData(), config=Config({"debug_print_layer_output_template": True})) + time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time") kwargs = dict( name="att", network=net, weights=InternalLayer( name="att_weights", network=net, 
output=Data( - name='att_weights_output', shape=(1, 10), time_dim_axis=2, auto_create_placeholders=True)), + name='att_weights_output', shape=(1, 10), time_dim_axis=2, auto_create_placeholders=True, + same_dim_tags_as={"T": time})), base=InternalLayer( name="enc_value", network=net, - output=Data(name='enc_value_output', shape=(10, 2048), time_dim_axis=1, auto_create_placeholders=True))) + output=Data( + name='enc_value_output', shape=(10, 2048), time_dim_axis=1, auto_create_placeholders=True, + same_dim_tags_as={"T": time}))) print("GenericAttentionLayer kwargs:") pprint(kwargs) kwargs["output"] = GenericAttentionLayer.get_out_data_from_opts(**kwargs) @@ -4948,16 +4965,20 @@ def test_GenericAttentionLayer_weights_heads_time_end(): # Example: weights (B,H,T), base (B,T,H,V) from returnn.tf.layers.base import InternalLayer net = TFNetwork(extern_data=ExternData(), config=Config({"debug_print_layer_output_template": True})) + time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time") num_heads = 8 kwargs = dict( name="att", network=net, weights=InternalLayer( name="att_weights", network=net, output=Data( - name='att_weights_output', shape=(num_heads, None), time_dim_axis=2, auto_create_placeholders=True)), + name='att_weights_output', shape=(num_heads, None), time_dim_axis=2, auto_create_placeholders=True, + same_dim_tags_as={"T": time})), base=InternalLayer( name="enc_value", network=net, - output=Data(name='enc_value_output', shape=(None, num_heads, 2048), auto_create_placeholders=True))) + output=Data( + name='enc_value_output', shape=(None, num_heads, 2048), auto_create_placeholders=True, + same_dim_tags_as={"T": time}))) print("GenericAttentionLayer kwargs:") pprint(kwargs) kwargs["output"] = GenericAttentionLayer.get_out_data_from_opts(**kwargs) @@ -4970,6 +4991,7 @@ def test_GenericAttentionLayer_weights_heads_auto_squeeze_time_end(): # Example: weights (B,H,1,T), base (B,T,H,V) from returnn.tf.layers.base import InternalLayer net = TFNetwork(extern_data=ExternData(), config=Config({"debug_print_layer_output_template": True})) + time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time") num_heads = 8 kwargs = dict( name="att", network=net, @@ -4977,10 +4999,12 @@ def test_GenericAttentionLayer_weights_heads_auto_squeeze_time_end(): name="att_weights", network=net, output=Data( name='att_weights_output', shape=(num_heads, 1, None), time_dim_axis=3, - auto_create_placeholders=True)), + auto_create_placeholders=True, same_dim_tags_as={"T": time})), base=InternalLayer( name="enc_value", network=net, - output=Data(name='enc_value_output', shape=(None, num_heads, 2048), auto_create_placeholders=True))) + output=Data( + name='enc_value_output', shape=(None, num_heads, 2048), auto_create_placeholders=True, + same_dim_tags_as={"T": time}))) print("GenericAttentionLayer kwargs:") pprint(kwargs) kwargs["output"] = GenericAttentionLayer.get_out_data_from_opts(**kwargs)
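
Usage sketch (illustrative only, assuming RETURNN with the helpers added above; the tensor and dim tag below are made up): a producer such as SoftmaxOverSpatialLayer or SeqLenMaskLayer records the padding value per dynamic dim tag on the tensor, and a consumer such as DotLayer reads it back to decide whether the extra masking before the reduction can be skipped.

import tensorflow as tf
from returnn.tf.util.data import DimensionTag
from returnn.tf.util.basic import set_padding_info, get_padding_info_dict_ref

with tf.Graph().as_default():  # graph-mode tensors (as used by RETURNN) accept custom Python attributes
  time_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
  x = tf.zeros([2, 7, 3])  # e.g. (B, T, F), assumed to be already masked with 0.0 along T
  # Producer side: record the pad value per dynamic dim tag on the tensor.
  set_padding_info(x, dim=time_tag, pad_value=0.)
  # Consumer side: check whether the reduce axis is already masked as needed.
  pad = get_padding_info_dict_ref(x)  # dict[DimensionTag,float|int], attached to the tensor object
  assert pad.get(time_tag) == 0.  # -> no extra where_bc masking needed before reducing over T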