From fe31ae344a6eec7b65bddc6343f653fd8a2d6aae Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 1 Mar 2022 17:55:05 +0100 Subject: [PATCH 01/29] test_MergeDimsLayer_unspecified_out_dim https://github.com/rwth-i6/returnn/issues/955 https://github.com/rwth-i6/returnn_common/issues/117 --- tests/test_TFNetworkLayer.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index 1a83e5c29c..07c9d807a6 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -2564,6 +2564,22 @@ def test_MergeDimsLayer_modified_time_dim(): session.run(out.placeholder, feed_dict=make_feed_dict(net.extern_data)) +def test_MergeDimsLayer_unspecified_out_dim(): + # https://github.com/rwth-i6/returnn/issues/955 + # https://github.com/rwth-i6/returnn_common/issues/117 + config = Config({ + "extern_data": {"data": {"shape": (None, 3, 5)}}, + }) + out_dim = SpatialDim("out") + with make_scope() as session: + net = TFNetwork(config=config) + net.construct_from_dict({ + "output": { + "class": "merge_dims", "from": "data", "axes": ["dim:3", "dim:5"], "keep_order": True, + "out_dim": out_dim}, + }) + + def test_FlattenBatchLayer(): from returnn.tf.util.data import BatchInfo n_batch, n_time, n_in = 3, 4, 2 From 4bae259a3fb79ff36538d1d71337006732ebe7be Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 1 Mar 2022 22:27:52 +0100 Subject: [PATCH 02/29] Dim declare_same_as, set when is_dim_known #955 --- returnn/tf/util/data.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 8e8ad305f2..722f296e48 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -899,6 +899,12 @@ def declare_same_as(self, other): elif other_same_base.dyn_size_ext is None or not other_same_base._validate_in_current_graph(): other_same_base.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx( other_same_base.batch, other_same_base.control_flow_ctx) + if self.is_dim_known() and other.is_dim_known(): + assert self.dimension == other.dimension + elif self.is_dim_known() and not other.is_dim_known(): + other.dimension = self.dimension + elif not self.is_dim_known() and other.is_dim_known(): + self.dimension = other.dimension if self._vocab and not other_same_base._vocab: other_same_base._vocab = self._vocab elif other_same_base._vocab and not self._vocab: From 0a0641850de120a42e9ec9b35b2cc39ef52b3859 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 1 Mar 2022 22:28:48 +0100 Subject: [PATCH 03/29] MergeDimsLayer, declare_same_as on out_dim #955 --- returnn/tf/layers/basic.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index a16f39751c..dbb788d5ca 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -3488,6 +3488,7 @@ def get_out_data_from_opts(cls, name, axes, keep_order=NotSpecified, :rtype: Data """ from returnn.util import BehaviorVersion + from returnn.util.basic import prod if keep_order is NotSpecified: keep_order = True if BehaviorVersion.get() >= 6 else False assert not out_type, "currently ignored" @@ -3497,21 +3498,18 @@ def get_out_data_from_opts(cls, name, axes, keep_order=NotSpecified, data = input_data.copy(name="%s_output" % name) if len(axes) <= 1: return data - import numpy res_dim = None if all([data.batch_shape[i] is not None for i in axes]): - res_dim = int(numpy.prod([data.batch_shape[i] for i in axes])) + res_dim = 
int(prod([data.batch_shape[i] for i in axes])) merge_dim_tags = [data.dim_tags[axis] for axis in axes] merge_target_axis = cls._get_target_axis(input_data=data, merge_axes=axes) + out_dim_ = prod(merge_dim_tags) + assert isinstance(out_dim_, Dim) + assert out_dim_.dimension == res_dim if out_dim: - assert out_dim.dimension == res_dim - else: - from numpy import prod - out_dim = prod(merge_dim_tags) - assert isinstance(out_dim, Dim) - assert out_dim.dimension == res_dim + out_dim_.declare_same_as(out_dim) new_dim_tags = [d for (i, d) in enumerate(data.dim_tags) if i not in axes] - new_dim_tags.insert(merge_target_axis, out_dim) + new_dim_tags.insert(merge_target_axis, out_dim_) data_opts = data.get_kwargs(include_special_axes=False) data_opts["dim_tags"] = new_dim_tags From b00168853121155009b2addef0d5b8624a7692ad Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 1 Mar 2022 22:42:23 +0100 Subject: [PATCH 04/29] ConvLayer and co, declare_same_as on out_spatial_dims #955 --- returnn/tf/layers/basic.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index dbb788d5ca..644580f049 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -5043,7 +5043,9 @@ def set_output_dim_tags(cls, output, num_batch_dims, in_spatial_dims, out_spatia assert output.feature_dim_axis == output.batch_ndim - 1 out_spatial_dims_ = output.dim_tags[num_batch_dims:-1] if out_spatial_dims: - assert list(out_spatial_dims_) == list(out_spatial_dims) + assert len(out_spatial_dims_) == len(out_spatial_dims) + for i, (out_spatial_dim_, out_spatial_dim) in enumerate(zip(out_spatial_dims_, out_spatial_dims)): + out_spatial_dim_.declare_same_as(out_spatial_dim) assert len(out_spatial_dims_) == len(in_spatial_dims) == len(filter_size) == len(strides) == len(dilation_rate) for i, in_tag in enumerate(in_spatial_dims): out_tag = out_spatial_dims_[i] From b1bbcb6b5f53f92fe43411f53e0b9d6f757b1e52 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 09:32:29 +0100 Subject: [PATCH 05/29] WindowLayer, declare_same_as on out_spatial_dim, window_dim #955 --- returnn/tf/layers/basic.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 644580f049..1bd214ebdb 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -3087,27 +3087,19 @@ def get_out_data_from_opts(cls, name, network, sources, window_size=None, window else: axis = data.get_axis_from_description(axis) in_spatial_dim = data.dim_tags[axis] - if (padding.lower() == "same" or window_size == 1) and stride == 1: # no change in spatial dim - out_spatial_dim = in_spatial_dim # error check in __init__ - else: # new spatial dim - if not out_spatial_dim: - dim = None - if in_spatial_dim.dimension is not None: - dim = ConvLayer.calc_out_dim( - in_dim=in_spatial_dim.dimension, - filter_size=window_size, stride=stride, dilation_rate=1, padding=padding) - out_spatial_dim = Dim( - kind=Dim.Types.Spatial, description="%s:spatial" % name, - dimension=dim, derived_from_tag=in_spatial_dim, auto_generated=True, - batch=data.batch, control_flow_ctx=data.control_flow_ctx) - data = data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_spatial_dim) + out_spatial_dim_ = ConvLayer.calc_out_dim( + in_dim=in_spatial_dim, + filter_size=window_size, stride=stride, dilation_rate=1, padding=padding) + assert isinstance(out_spatial_dim_, Dim) + if out_spatial_dim: + 
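      # Recurring pattern in this series: build the dim tag locally (out_spatial_dim_)
      # and, when the caller passed an explicit out_spatial_dim, unify the two via
      # declare_same_as rather than asserting equality. A minimal sketch of the
      # semantics, assuming SpatialDim from returnn.tf.util.data:
      #   out = SpatialDim("out")        # user-declared, dimension still unknown
      #   calc = SpatialDim("calc", 15)  # derived tag with known static dimension
      #   calc.declare_same_as(out)      # patch 02: out.dimension becomes 15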
out_spatial_dim_.declare_same_as(out_spatial_dim) + data = data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_spatial_dim_) new_dim_axis = axis + 1 # add new axis right after + window_dim_ = Dim( + kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size, auto_generated=True) if window_dim: - assert window_dim.dimension == window_size - else: - window_dim = Dim( - kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size, auto_generated=True) - return data.copy_add_dim_by_tag(axis=new_dim_axis, dim_tag=window_dim, unbroadcast=True) + window_dim_.declare_same_as(window_dim) + return data.copy_add_dim_by_tag(axis=new_dim_axis, dim_tag=window_dim_, unbroadcast=True) # noinspection PyMethodOverriding @classmethod From 723dcc75ccb976576c2cd401860e13378c7b0811 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 09:34:07 +0100 Subject: [PATCH 06/29] WindowLayer, cleanup --- returnn/tf/layers/basic.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 1bd214ebdb..937114ae3b 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -3030,25 +3030,6 @@ def __init__(self, window_size=None, window_dim=None, window_left=None, window_r else: axis = data.get_axis_from_description(axis) new_dim_axis = axis + 1 # new axis will be added right after - in_spatial_dim = data.dim_tags[axis] - out_spatial_dim_ = self.output.dim_tags[axis] - if out_spatial_dim: - assert out_spatial_dim_ == out_spatial_dim - if (padding.lower() == "same" or window_size == 1) and stride == 1: # no change in spatial dim - assert in_spatial_dim == out_spatial_dim - if in_spatial_dim != out_spatial_dim_ and out_spatial_dim_.dimension is None: - if not out_spatial_dim_.dyn_size_ext: - out_spatial_dim_.dyn_size_ext = in_spatial_dim.dyn_size_ext.copy_template(name="%s:spatial-size" % self.name) - if out_spatial_dim_.dyn_size_ext.placeholder is None: - from ..util.basic import same_control_flow_ctx - from ..util.data import Dim - assert in_spatial_dim.dyn_size is not None - size = in_spatial_dim.dyn_size - with same_control_flow_ctx(size): - size = ConvLayer.calc_out_dim( - in_dim=size, - filter_size=window_size, stride=stride, dilation_rate=1, padding=padding) - out_spatial_dim_.dyn_size_ext.placeholder = size from returnn.tf.util.basic import windowed_nd self.output.placeholder = windowed_nd( From 11b9fefb85997ae50245a4638d8a920e8557cfcc Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 11:15:17 +0100 Subject: [PATCH 07/29] TransposedConvLayer, declare_same_as on out_spatial_dims, cleanup #955 --- returnn/tf/layers/basic.py | 71 +++++++++++++++----------------------- 1 file changed, 28 insertions(+), 43 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 937114ae3b..fc79524ab2 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -5157,7 +5157,7 @@ def transform_input(cls, input_data, network, in_dim=None, in_spatial_dims=None, @classmethod def calc_out_dim(cls, in_dim, filter_size, stride, padding, dilation_rate=1): """ - :param int|tf.Tensor|Dim|T in_dim: dimension in some axis + :param T|int|tf.Tensor|Dim in_dim: dimension in some axis :param int filter_size: e.g. 2, for the corresponding axis :param int stride: e.g. 1, for the corresponding axis :param int dilation_rate: e.g. 
1 @@ -5656,26 +5656,6 @@ def __init__(self, filter_size, strides=None, self.output_before_activation = OutputWithActivation(y) y = self.output_before_activation.y self.output.placeholder = y - for idx, axis in enumerate(spatial_axes): - input_tag = input_data.dim_tags[axis] - output_tag = self.output.dim_tags[axis] - if input_tag == output_tag: - continue - assert not input_tag.is_batch_dim() and not output_tag.is_batch_dim() - if input_tag.dimension is None: - assert output_tag.dimension is None - assert input_tag.dyn_size is not None - size = input_tag.dyn_size - with tf_util.same_control_flow_ctx(size): - size = self.deconv_output_length( - size, - filter_size=filter_size[idx], stride=strides[idx], - padding=padding, output_padding=output_padding[idx]) - r = remove_padding[idx] - if r: - assert isinstance(r, int) - size = tf_util.simplify_add(size, -r * 2) - output_tag.set_tag_on_size_tensor(size, batch=self.output.batch) @staticmethod def deconv_output_length(input_length, @@ -5688,7 +5668,9 @@ def deconv_output_length(input_length, Determines output length of a transposed convolution given input length. Copied from conv_utils.deconv_output_length, adapted with simplification. - :param T|int|tf.Tensor input_length: + Also see :func:`ConvLayer.calc_out_dim`. + + :param T|int|tf.Tensor|Dim input_length: :param int filter_size: :param str padding: one of `"same"`, `"valid"`, `"full"`. :param int|None output_padding: amount of padding along the output dimension. @@ -5709,13 +5691,19 @@ def deconv_output_length(input_length, # Infer length if output padding is None, else compute the exact length if output_padding is None: if padding == 'valid': - length = tf_util.simplify_add(input_length, max(filter_size - stride, 0)) + if isinstance(input_length, Dim): + length = input_length + max(filter_size - stride, 0) + else: + length = tf_util.simplify_add(input_length, max(filter_size - stride, 0)) elif padding == 'full': - length = tf_util.simplify_add(input_length, -(stride + filter_size - 2)) + if isinstance(input_length, Dim): + length = input_length - (stride + filter_size - 2) + else: + length = tf_util.simplify_add(input_length, -(stride + filter_size - 2)) elif padding == 'same': length = input_length else: - length = None + raise Exception("invalid padding %r" % (padding,)) else: # output_padding if padding == 'same': pad = filter_size // 2 @@ -5724,8 +5712,11 @@ def deconv_output_length(input_length, elif padding == 'full': pad = filter_size - 1 else: - pad = None - length = tf_util.simplify_add(input_length, -stride + filter_size - 2 * pad + output_padding) + raise Exception("invalid padding %r" % (padding,)) + if isinstance(input_length, Dim): + length = input_length + (-stride + filter_size - 2 * pad + output_padding) + else: + length = tf_util.simplify_add(input_length, -stride + filter_size - 2 * pad + output_padding) return length @classmethod @@ -5769,22 +5760,16 @@ def get_out_data_from_opts(cls, name, sources, network, dim_tags = list(data.dim_tags[:num_batch_dims]) # [B] if out_spatial_dims: assert len(out_spatial_dims) == len(filter_size) - # Be relaxed about incorrect input data. Throw errors later. This can also work during template construction. 
- dim_tags += out_spatial_dims - else: - for i in range(len(filter_size)): - old_tag = old_spatial_dim_tags[i] if i < len(old_spatial_dim_tags) else None - if old_tag and (filter_size[i] == strides[i] == 1 or (strides[i] == 1 and padding == "SAME")): - dim_tags.append(old_tag) # identity in this axis - continue - new_dim = None - if old_tag and old_tag.dimension is not None: - new_dim = cls.deconv_output_length( - old_tag.dimension, filter_size=filter_size[i], stride=strides[i], - padding=padding, output_padding=output_padding[i]) - remove_padding[i] * 2 - dim_tags.append(Dim( - kind=Dim.Types.Spatial, description="%s:conv:s%i" % (name, i), dimension=new_dim, - derived_from_tag=old_tag, undefined=not old_tag, auto_generated=True)) + # Be relaxed about incorrect input data. Throw errors later. This can also work during template construction. + for i in range(len(filter_size)): + old_tag = old_spatial_dim_tags[i] if i < len(old_spatial_dim_tags) else None + new_tag = cls.deconv_output_length( + old_tag, filter_size=filter_size[i], stride=strides[i], + padding=padding, output_padding=output_padding[i]) + new_tag = new_tag.sub_left(remove_padding[i]).sub_right(remove_padding[i]) + if out_spatial_dims: + new_tag.declare_same_as(out_spatial_dims[i]) + dim_tags.append(new_tag) if not out_dim: assert n_out out_dim = FeatureDim("%s:channel" % name, dimension=n_out, auto_generated=True) From f41d79597979194e5da7edbb8e4d5252af8df397 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 11:18:11 +0100 Subject: [PATCH 08/29] PadLayer, declare_same_as on out_dims, cleanup #955 --- returnn/tf/layers/basic.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index fc79524ab2..12fc3f3632 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -3204,25 +3204,6 @@ def __init__(self, axes, padding, out_dims=None, value=0, mode="constant", **kwa self.output.placeholder = tf_util.pad_replicate(self.input_data.placeholder, axes, padding) else: self.output.placeholder = tf.pad(self.input_data.placeholder, paddings=paddings, mode=mode, constant_values=value) - for a in axes: - p = sum(paddings[a]) - in_tag = self.input_data.dim_tags[a] - out_tag = self.output.dim_tags[a] - a = self.input_data.get_batch_axis_excluding_batch(a) - if a is None: - continue - if in_tag.dyn_size is None: - continue - if p == 0: - continue - size = in_tag.dyn_size - with tf_util.same_control_flow_ctx(size): - size = tf_util.simplify_add(size, p) - size_tag = Dim.get_tag_from_size_tensor(size) - if not size_tag: - out_tag.set_tag_on_size_tensor(size, batch=in_tag.batch) - else: - out_tag.declare_same_as(size_tag) @classmethod def _transform_padding(cls, padding, axes): @@ -3275,10 +3256,8 @@ def get_out_data_from_opts(cls, name, sources, axes, padding, out_dims=None, **k pad_left, pad_right = padding[i] out_dim = pad_left + dim_tags[a] + pad_right if out_dims: - assert out_dims[i].dimension == out_dim.dimension - dim_tags[a] = out_dims[i] - else: - dim_tags[a] = out_dim + out_dim.declare_same_as(out_dims[i]) + dim_tags[a] = out_dim return data.copy_template_new_dim_tags(dim_tags, keep_special_axes=True) From 8e939e3c9c70bb7b381467f6e5704e9987233592 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 11:36:39 +0100 Subject: [PATCH 09/29] WindowLayer, small warning fix --- returnn/tf/layers/basic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/returnn/tf/layers/basic.py 
b/returnn/tf/layers/basic.py index 12fc3f3632..f6397767e8 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -3005,6 +3005,7 @@ def __init__(self, window_size=None, window_dim=None, window_left=None, window_r :param int stride: return only each Nth window :param kwargs: """ + out_spatial_dim # noqa # via get_out_data_from_opts super(WindowLayer, self).__init__(**kwargs) if not window_size: assert window_dim and window_dim.dimension From 285a7d539b94a2e64c78d882f50511aa9c593c96 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 11:39:12 +0100 Subject: [PATCH 10/29] Dim, dyn_size, try to complete --- returnn/tf/util/data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 722f296e48..d528347463 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -428,6 +428,9 @@ def dyn_size(self): If the dyn size can potentially be of a different shape, directly access dyn_size_ext. :rtype: tf.Tensor|None """ + if self.dimension is None and (not self.dyn_size_ext or self.dyn_size_ext.placeholder is None): + # Try to complete. + self.complete_dyn_size() if self.dyn_size_ext: return self.dyn_size_ext.placeholder return None From b30a013fc3d552375d7f2e8c7faf5000ce5e4ec5 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 11:48:23 +0100 Subject: [PATCH 11/29] Dim, complete_dyn_size, make sure it is a tensor --- returnn/tf/util/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index d528347463..7522f97ad3 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -617,11 +617,11 @@ def complete_dyn_size(self): def _bin_op(a, b): with tf_util.same_control_flow_ctx([a, b]): if kind == "add": - return tf_util.simplify_add(a, b) + return tf.convert_to_tensor(tf_util.simplify_add(a, b)) elif kind == "sub": - return tf_util.simplify_sub(a, b) + return tf.convert_to_tensor(tf_util.simplify_sub(a, b)) elif kind == "mul": - return a * b + return tf.convert_to_tensor(tf_util.optional_mul(a, b)) elif kind in ("floordiv", "truediv"): # truediv assumes there is no remainder return a // b elif kind == "ceildiv": From 2953a8be142b583403fe9a09aa406d0dc227c438 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 12:00:22 +0100 Subject: [PATCH 12/29] Dim, complete_dyn_size, make sure it is positive --- returnn/tf/util/data.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 7522f97ad3..0bcf716a76 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -614,12 +614,25 @@ def complete_dyn_size(self): if kind.endswith("_left"): kind = kind[:-len("_left")] + def _is_negative(x__): + if isinstance(x__, (int, float)): + return x__ < 0 + assert isinstance(x__, tf.Tensor) + from tensorflow.python.framework import tensor_util + x__ = tensor_util.constant_value(x__) + if x__ is not None: + return x__ < 0 + return False + def _bin_op(a, b): with tf_util.same_control_flow_ctx([a, b]): if kind == "add": + use_relu = _is_negative(a) or _is_negative(b) # for dynamic tensors, assume all positive + if use_relu: + return tf.convert_to_tensor(tf_util.simplify_non_negative_seq_length(tf_util.simplify_add(a, b))) return tf.convert_to_tensor(tf_util.simplify_add(a, b)) elif kind == "sub": - return tf.convert_to_tensor(tf_util.simplify_sub(a, b)) + return 
tf.convert_to_tensor(tf_util.simplify_non_negative_seq_length(tf_util.simplify_sub(a, b))) elif kind == "mul": return tf.convert_to_tensor(tf_util.optional_mul(a, b)) elif kind in ("floordiv", "truediv"): # truediv assumes there is no remainder From e1e68d186b1777c9c4b69d13b449cabb06b78da5 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 12:00:40 +0100 Subject: [PATCH 13/29] ConvLayer calc_out_dim, nicer --- returnn/tf/layers/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index f6397767e8..c882314569 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -5165,7 +5165,7 @@ def ceildiv(a, b): if isinstance(in_dim, Dim): filter_left_dilated = (filter_size - 1) * dilation_rate // 2 filter_right_dilated = (filter_size - 1) * dilation_rate - filter_left_dilated - valid_part = (-filter_left_dilated) + in_dim + (-filter_right_dilated) + valid_part = in_dim.sub_left(filter_left_dilated).sub_right(filter_right_dilated) return valid_part.ceildiv_right(stride) return tf_util.simplify_non_negative_seq_length( ceildiv( From dacb25b97866bdfbd6e7f14f1f8bf4487914cb12 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 12:15:43 +0100 Subject: [PATCH 14/29] Dim complete_dyn_size, small fix for missing batch info --- returnn/tf/util/data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 0bcf716a76..b6a12bb89e 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -657,7 +657,8 @@ def _bin_op(a, b): continue y.placeholder = _bin_op(y.placeholder, x.dimension) continue - x = x.get_for_batch_ctx(self.batch, self.control_flow_ctx) + if self.batch: + x = x.get_for_batch_ctx(self.batch, self.control_flow_ctx) x.complete_dyn_size() if not x.dyn_size_ext or x.dyn_size_ext.placeholder is None: return From 92ba17ded53151ad4a4e8ef0a46f7246f5d349f7 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 12:24:27 +0100 Subject: [PATCH 15/29] ConvLayer, cleanup and fix set_output_dim_tags --- returnn/tf/layers/basic.py | 31 ++----------------------------- 1 file changed, 2 insertions(+), 29 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index c882314569..6196c80e9f 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -4972,13 +4972,10 @@ def __init__(self, filter_size, padding, strides=1, dilation_rate=1, groups=1, self.output_before_activation = OutputWithActivation(y) y = self.output_before_activation.y self.output.placeholder = y - self.set_output_dim_tags( - self.output, num_batch_dims=num_batch_dims, in_spatial_dims=in_spatial_dims_, out_spatial_dims=out_spatial_dims, - filter_size=filter_size, strides=strides, dilation_rate=dilation_rate, padding=padding) @classmethod def set_output_dim_tags(cls, output, num_batch_dims, in_spatial_dims, out_spatial_dims, - filter_size, strides, dilation_rate, padding, calc_dyn_size=True): + filter_size, strides, dilation_rate, padding): """ :param Data output: :param int num_batch_dims: @@ -4988,7 +4985,6 @@ def set_output_dim_tags(cls, output, num_batch_dims, in_spatial_dims, out_spatia :param list[int]|tuple[int] strides: :param list[int]|tuple[int] dilation_rate: :param str padding: - :param bool calc_dyn_size: """ if output.feature_dim_axis == num_batch_dims: out_spatial_dims_ = output.dim_tags[num_batch_dims + 1:] @@ -5008,25 +5004,6 @@ def set_output_dim_tags(cls, output, 
num_batch_dims, in_spatial_dims, out_spatia dilation_rate=dilation_rate[i], padding=padding) assert isinstance(out_tag_calc, Dim) out_tag_calc.declare_same_as(out_tag) - if in_tag.dimension is not None: - size = in_tag.dimension - size = cls.calc_out_dim( - in_dim=size, - filter_size=filter_size[i], stride=strides[i], - dilation_rate=dilation_rate[i], padding=padding) - assert out_tag.dimension == size - elif in_tag.dimension is None and in_tag.dyn_size is not None and calc_dyn_size: - size = in_tag.dyn_size - with tf_util.same_control_flow_ctx(size): - size = cls.calc_out_dim( - in_dim=size, - filter_size=filter_size[i], stride=strides[i], - dilation_rate=dilation_rate[i], padding=padding) - size_tag = Dim.get_tag_from_size_tensor(size) - if not size_tag: - out_tag.set_tag_on_size_tensor(size, batch=in_tag.batch) - else: - out_tag.declare_same_as(size_tag) @classmethod def _check_defined_in_spatial_dims(cls, cond): @@ -5273,8 +5250,7 @@ def get_out_data_from_opts( if len(old_spatial_dim_tags) == len(filter_size): cls.set_output_dim_tags( out, num_batch_dims=num_batch_dims, in_spatial_dims=old_spatial_dim_tags, out_spatial_dims=out_spatial_dims, - filter_size=filter_size, strides=strides, dilation_rate=dilation_rate, padding=padding, - calc_dyn_size=False) + filter_size=filter_size, strides=strides, dilation_rate=dilation_rate, padding=padding) return out def get_dep_layers(self): @@ -5398,9 +5374,6 @@ def __init__(self, mode, pool_size, padding="VALID", dilation_rate=1, strides=No if num_batch_dims > 1: y = tf.reshape(y, tf.concat([extended_batch_shape, tf.shape(y)[1:]], axis=0)) self.output.placeholder = y - ConvLayer.set_output_dim_tags( - self.output, num_batch_dims=num_batch_dims, in_spatial_dims=in_spatial_dims_, out_spatial_dims=out_spatial_dims, - filter_size=pool_size, strides=strides, dilation_rate=dilation_rate, padding=padding) @classmethod def get_out_data_from_opts(cls, name, sources, network, From 48285ec84701d43cd1eee70c6b6b168b4bba59ef Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 13:09:57 +0100 Subject: [PATCH 16/29] ResizeLayer, declare_same_as on out_dim, cleanup #955 --- returnn/tf/layers/basic.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 6196c80e9f..af541ff203 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -7186,7 +7186,7 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_ :param Dim|None out_dim: :param str kind: "linear", "nn"/"nearest_neighbor", "cubic", "fill" :param None|int|float fill_value: if kind=="fill" - :param float fill_dropout: if set, will dropout in the same axis + :param float|None fill_dropout: if set, will dropout in the same axis """ out_dim # noqa # via get_out_data_from_opts super(ResizeLayer, self).__init__(**kwargs) @@ -7196,10 +7196,6 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_ assert axis > 0, "batch-dim resize not supported" input_data = input_data.copy_move_axis(old_axis=axis, new_axis=1) axis = 1 - self.output.placeholder = input_data.placeholder - out_dyn_size = input_data.dim_tags[axis].dyn_size - if out_dyn_size is not None: - out_dyn_size = out_dyn_size * factor # images expected shape: [batch, height, width, channels] remaining_axes = [i for i in range(self.output.batch_ndim) if i not in (0, axis)] @@ -7248,21 +7244,23 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_ 
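      # The fill_dropout hunk below: with fill_dropout set, positions are dropped
      # after the resize, so the output length is data-dependent and the out dim
      # must stay unknown (see get_out_data_from_opts further down). Sketch of the
      # size update, assuming a boolean kept-mask per batch entry:
      #   new_size = seq_len * factor
      #   out_len = count(kept_mask & (positions < new_size))  # per batch entry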
mask = expand_dims_unbroadcast(mask, axis=0, dim=shape[0]) # (batch,new_size) x = tf.boolean_mask(x, mask) # [batch*new_size_dropped] + remaining_shape x = tf.reshape(x, [shape[0], new_size_dropped] + remaining_shape) # [batch, new_size_dropped] + remaining_shape + out_dyn_size = input_data.dim_tags[axis].dyn_size if out_dyn_size is not None: + out_dyn_size = out_dyn_size * factor orig_mask = tf.sequence_mask( out_dyn_size, maxlen=new_size, dtype=tf.bool) # (batch,new_size) out_dyn_size = tf.reduce_sum(tf.cast(tf.logical_and(mask, orig_mask), tf.int32), axis=1) - if out_dyn_size is not None: - self.output.dim_tags[axis].dyn_size = out_dyn_size + self.output.dim_tags[axis].dyn_size = out_dyn_size self.output.placeholder = x @classmethod - def get_out_data_from_opts(cls, factor, axis, sources, name, out_dim=None, **kwargs): + def get_out_data_from_opts(cls, factor, axis, sources, name, fill_dropout=None, out_dim=None, **kwargs): """ :param int factor: :param Dim|str axis: :param list[LayerBase] sources: :param str name: + :param float|None fill_dropout: :param Dim|None out_dim: :rtype: Data """ @@ -7273,12 +7271,13 @@ def get_out_data_from_opts(cls, factor, axis, sources, name, out_dim=None, **kwa axis = 1 assert axis != out.batch_dim_axis, "batch-dim resize not supported" tag = out.dim_tags[axis] - dim = None if tag.dimension is None else (tag.dimension * factor) - if out_dim: - assert out_dim.dimension == dim + if fill_dropout: + out_dim_ = Dim(kind=tag.kind, description="%s_resize" % name, auto_generated=True) # unknown dim else: - out_dim = Dim(kind=tag.kind, description="%s_resize" % name, dimension=dim, auto_generated=True) - return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim) + out_dim_ = tag * factor + if out_dim: + out_dim_.declare_same_as(out_dim) + return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim_) class CombineDimsLayer(MergeDimsLayer): From 183617757a6d39544ddb3fb6aec00890ce35a68a Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 13:26:42 +0100 Subject: [PATCH 17/29] ResizeLayer, fixes --- returnn/tf/layers/basic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index af541ff203..0d9953d8f7 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -7192,16 +7192,16 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_ super(ResizeLayer, self).__init__(**kwargs) # self.output.shape and self.output.batch_dim_axis are already set here via self.get_out_data_from_opts(). 
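    # Patch 16 above switched ResizeLayer to dim-tag arithmetic: `tag * factor`
    # yields a derived Dim whose dyn_size is completed lazily via
    # Dim.complete_dyn_size (patches 10-12). Rough sketch, names hypothetical:
    #   time = SpatialDim("time")   # dynamic input dim
    #   out_dim_ = time * factor    # derived: dyn_size becomes time.dyn_size * factor
    #   if out_dim: out_dim_.declare_same_as(out_dim)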
input_data = self.input_data.copy_as_batch_major() - axis = self.input_data.get_axis_from_description(axis) + axis = input_data.get_axis_from_description(axis) assert axis > 0, "batch-dim resize not supported" input_data = input_data.copy_move_axis(old_axis=axis, new_axis=1) axis = 1 # images expected shape: [batch, height, width, channels] remaining_axes = [i for i in range(self.output.batch_ndim) if i not in (0, axis)] - x = dimshuffle(self.output.placeholder, [0, axis, 'x'] + remaining_axes) # [batch,height,width] + remaining_axes + x = dimshuffle(input_data.placeholder, [0, axis, 'x'] + remaining_axes) # [batch,height,width] + remaining_axes from returnn.tf.util.basic import get_shape, optional_mul - shape = get_shape(self.output.placeholder) + shape = get_shape(input_data.placeholder) remaining_shape = [shape[i] for i in remaining_axes] remaining_dim = optional_mul(*remaining_shape) if remaining_axes else 1 x = tf.reshape(x, [shape[0], shape[axis], 1, remaining_dim]) # [batch,height,width,channels] From 9bb81e3b28c7a0a677453710a683b14b76372b6b Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 13:31:08 +0100 Subject: [PATCH 18/29] PrefixInTimeLayer, declare_same_as on out_dim, cleanup #955 --- returnn/tf/layers/basic.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 0d9953d8f7..4d14568ea5 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -6363,10 +6363,6 @@ def __init__(self, axis="T", out_dim=None, prefix=0.0, repeat=1, size_base=None, assert repeat >= 0 self.repeat = repeat out_dim = self.output.dim_tags[axis_int] - if in_dim.dyn_size is not None and out_dim.dyn_size is None: - if not out_dim.dyn_size_ext: - out_dim.dyn_size_ext = in_dim.dyn_size_ext.copy() - out_dim.dyn_size_ext.placeholder = in_dim.dyn_size_ext.placeholder + repeat max_repeat = repeat if isinstance(repeat, int) else tf.maximum(tf.reduce_max(repeat), 0) shape = [((self.output.batch_shape[i] or tf.shape(input_data.placeholder)[i]) if (i != axis_int) @@ -6417,17 +6413,15 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, size_base x = get_concat_sources_data_template(sources, name="%s_output" % name) axis_int = x.get_axis_from_description(axis, allow_int=False) in_dim = x.dim_tags[axis_int] - out_dim_int = None - if in_dim.dimension and isinstance(repeat, int): - out_dim_int = in_dim.dimension + repeat if size_base: assert not out_dim - out_dim = size_base.output.get_time_dim_tag() - if not out_dim: - out_dim = ( + out_dim_ = size_base.output.get_time_dim_tag() + else: + out_dim_ = ( repeat if isinstance(repeat, int) else SpatialDim("%s:repeat" % repeat.name, auto_generated=True)) + in_dim - assert out_dim.dimension == out_dim_int - x = x.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=out_dim) + if out_dim: + out_dim_.declare_same_as(out_dim) + x = x.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=out_dim_) if isinstance(repeat, LayerBase): x = x.copy_as_batch_spatial_major() return x From e01b32a4a0f5ac798fd1960ae7870fe191f614bf Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 13:32:46 +0100 Subject: [PATCH 19/29] PostfixInTimeLayer, declare_same_as on out_dim, cleanup #955 --- returnn/tf/layers/basic.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 4d14568ea5..1af62fac38 100644 --- a/returnn/tf/layers/basic.py +++ 
b/returnn/tf/layers/basic.py @@ -6476,8 +6476,6 @@ def __init__(self, axis="T", out_dim=None, postfix=0.0, repeat=1, **kwargs): seq_mask = tf.less(idx_range, size_ext.placeholder) assert seq_mask.get_shape().ndims == self.output.batch_ndim self.output.placeholder = tf_util.where_bc(seq_mask, x, c) - out_dim = self.output.dim_tags[axis_int] - out_dim.dyn_size = in_dim.dyn_size + repeat @classmethod def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, postfix=0.0, repeat=1, **kwargs): @@ -6493,13 +6491,10 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, postfix=0 x = get_concat_sources_data_template(sources, name="%s_output" % name) axis_int = x.get_axis_from_description(axis, allow_int=False) in_dim = x.dim_tags[axis_int] - out_dim_int = None - if in_dim.dimension: - out_dim_int = in_dim.dimension + repeat - if not out_dim: - out_dim = in_dim + repeat - assert out_dim.dimension == out_dim_int - x = x.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=out_dim) + out_dim_ = in_dim + repeat + if out_dim: + out_dim_.declare_same_as(out_dim) + x = x.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=out_dim_) return x @classmethod From 7d2955cd52aac2b000d812bbaf036134d4aa40ee Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 14:29:50 +0100 Subject: [PATCH 20/29] SliceLayer, declare_same_as on out_dim, cleanup #955 --- returnn/tf/layers/basic.py | 44 +++++++++++++------------------------- 1 file changed, 15 insertions(+), 29 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 1af62fac38..4d95971c05 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -1007,29 +1007,6 @@ def __init__(self, axis, slice_start=None, slice_end=None, slice_step=None, out_ axis = self.input_data.get_axis_from_description(axis) dim_slice = slice(slice_start, slice_end, slice_step) slices = [slice(None, None)] * axis + [dim_slice] - input_dim_tag = self.input_data.dim_tags[axis] - output_dim_tag = self.output.dim_tags[axis] - if input_dim_tag.dyn_size_ext and output_dim_tag.dyn_size is None: - assert input_dim_tag.dyn_size_ext.placeholder is not None - if not output_dim_tag.dyn_size_ext: - output_dim_tag.dyn_size_ext = input_dim_tag.dyn_size_ext.copy_template(name="%s:slice-size" % self.name) - dyn_size = input_dim_tag.dyn_size_ext.placeholder - if slice_start: - assert slice_start > 0 - dyn_size = tf.maximum(0, dyn_size - slice_start) - if slice_end is not None: - if slice_end >= 0: - dyn_size = tf.minimum(slice_end, dyn_size) - else: # slice_end < 0 - dyn_size = tf.maximum(0, dyn_size + slice_end) - if slice_step: - dyn_size = tf.cast(tf_compat.v1.ceil(tf.divide(dyn_size, slice_step)), tf.int32) - output_dim_tag.dyn_size_ext.placeholder = dyn_size - existing_tag = Dim.get_tag_from_size_tensor(dyn_size) - if existing_tag: - output_dim_tag.declare_same_as(existing_tag) - else: - output_dim_tag.set_tag_on_size_tensor(dyn_size) self.output.placeholder = self.input_data.placeholder[slices] @classmethod @@ -1051,13 +1028,22 @@ def get_out_data_from_opts( input_data = get_concat_sources_data_template(sources) axis = input_data.get_axis_from_description(axis) dim_tag = input_data.dim_tags[axis] - dim_slice = slice(slice_start, slice_end, slice_step) - new_dim = len(range(dim_tag.dimension)[dim_slice]) if dim_tag.dimension is not None else None + out_dim_ = dim_tag + if slice_start: + assert slice_start >= 0 + out_dim_ = out_dim_.sub_left(slice_start) + if slice_end is not None: + if slice_end >= 0: + 
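        # How the slice options map onto dim-tag arithmetic in this patch:
        # positive slice_start -> sub_left, negative slice_end -> `+ slice_end`,
        # slice_step -> ceildiv_right. Sketch on plain ints, mirroring
        # len(range(dim)[start:end:step]) for dim=10, start=2, end=-1, step=3:
        #   n = max(10 - 2, 0)      # sub_left -> 8
        #   n = max(n + (-1), 0)    # negative slice_end -> 7
        #   n = -(-n // 3)          # ceildiv_right -> 3 == len(range(10)[2:-1:3])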
out_dim_ = Dim( + kind=dim_tag.kind, description="%s:slice-end" % name, + dimension=slice_end - (slice_start or 0), auto_generated=True) + else: # slice_end < 0 + out_dim_ = out_dim_ + slice_end + if slice_step and slice_step != 1: + out_dim_ = out_dim_.ceildiv_right(abs(slice_step)) if out_dim: - assert out_dim.dimension == new_dim - else: - out_dim = Dim(kind=dim_tag.kind, description="%s:slice" % name, dimension=new_dim, auto_generated=True) - return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim, name="%s_output" % name) + out_dim_.declare_same_as(out_dim) + return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim_, name="%s_output" % name) class SliceNdLayer(_ConcatInputLayer): From 75b2313c0afe282ff6cd159dedba1f9099a8b934 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 14:32:42 +0100 Subject: [PATCH 21/29] Dim complete_dyn_size, small fix for is_negative check --- returnn/tf/util/data.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index b6a12bb89e..3f69ee620c 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -614,14 +614,18 @@ def complete_dyn_size(self): if kind.endswith("_left"): kind = kind[:-len("_left")] + import numpy + from tensorflow.python.framework import tensor_util + def _is_negative(x__): + if isinstance(x__, numpy.ndarray): + return (x__ < 0).any() if isinstance(x__, (int, float)): return x__ < 0 assert isinstance(x__, tf.Tensor) - from tensorflow.python.framework import tensor_util x__ = tensor_util.constant_value(x__) if x__ is not None: - return x__ < 0 + return _is_negative(x__) return False def _bin_op(a, b): From 09e5bddae38dcad411e847768c172faa0808bce5 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 14:40:17 +0100 Subject: [PATCH 22/29] SliceLayer, small fix --- returnn/tf/layers/basic.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 4d95971c05..404f6e997d 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -1007,7 +1007,9 @@ def __init__(self, axis, slice_start=None, slice_end=None, slice_step=None, out_ axis = self.input_data.get_axis_from_description(axis) dim_slice = slice(slice_start, slice_end, slice_step) slices = [slice(None, None)] * axis + [dim_slice] - self.output.placeholder = self.input_data.placeholder[slices] + y = self.input_data.placeholder[slices] + y.set_shape(self.output.batch_shape) # can be necessary for slice_end>0 + self.output.placeholder = y @classmethod def get_out_data_from_opts( From 43fd1274ed60cef2a08a60723e9e741f529c1fdf Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 15:08:43 +0100 Subject: [PATCH 23/29] SliceNdLayer, declare_same_as on out_spatial_dim #955 --- returnn/tf/layers/basic.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 404f6e997d..057859399a 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -1218,13 +1218,14 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, axis="T else: # size might be None here in which case we set the dyn_size in __init__ assert size is None or isinstance(size, int) + out_spatial_dim_ = Dim( + kind=Dim.Types.Spatial, + description="sliced-time:%s" % name, + dimension=size, auto_generated=True) if out_spatial_dim: - assert out_spatial_dim.dimension == 
size + out_spatial_dim_.declare_same_as(out_spatial_dim) else: - out_spatial_dim = Dim( - kind=Dim.Types.Spatial, - description="sliced-time:%s" % name, - dimension=size, auto_generated=True) + out_spatial_dim = out_spatial_dim_ gather_positions_data = gather_positions_data.copy_add_dim_by_tag( out_spatial_dim, unbroadcast=True, axis=start_data.batch_ndim) position = InternalLayer( From 7d3fc59323b23f95aa3f20ceaf48449602ba2c5e Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 15:41:22 +0100 Subject: [PATCH 24/29] RepeatLayer, declare_same_as on out_dim #955 --- returnn/tf/layers/basic.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 057859399a..f25657bc59 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -4352,12 +4352,13 @@ def get_out_data_from_opts(cls, name, sources, axis, repetitions, out_dim=None, original_axis = data.get_axis_from_description(axis, allow_int=False) tag = data.dim_tags[original_axis] data = data.copy_move_axis(original_axis, data.get_batch_axis(0)) - if not out_dim: - if isinstance(repetitions, int): - out_dim = tag * repetitions - else: - out_dim = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag, auto_generated=True) - return data.copy_template_replace_dim_tag(axis=data.get_batch_axis(0), new_dim_tag=out_dim) + if isinstance(repetitions, int): + out_dim_ = tag * repetitions + else: + out_dim_ = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag, auto_generated=True) + if out_dim: + out_dim_.declare_same_as(out_dim) + return data.copy_template_replace_dim_tag(axis=data.get_batch_axis(0), new_dim_tag=out_dim_) class TileLayer(_ConcatInputLayer): From 0f9cc6107443e0d216420e95e5ad81d68447e406 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 15:52:09 +0100 Subject: [PATCH 25/29] ReduceOutLayer, declare_same_as on out_dim #955 --- returnn/tf/layers/basic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index f25657bc59..fd9a453758 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -6008,10 +6008,10 @@ def get_out_data_from_opts(cls, num_pieces, sources, name, out_dim=None, **kwarg assert not out.sparse assert out.dim % num_pieces == 0 dim = out.dim // num_pieces - if not out_dim: - out_dim = out.feature_dim_or_sparse_dim // num_pieces - assert out_dim.dimension == dim - return out.copy_template_replace_dim_tag(axis=out.feature_dim_axis, new_dim_tag=out_dim) + out_dim_ = out.feature_dim_or_sparse_dim // num_pieces + if out_dim: + out_dim_.declare_same_as(out_dim) + return out.copy_template_replace_dim_tag(axis=out.feature_dim_axis, new_dim_tag=out_dim_) class SqueezeLayer(_ConcatInputLayer): From 5137a827b40db9e7ea17852fa46b79737c6f2203 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 15:53:02 +0100 Subject: [PATCH 26/29] StackLayer, declare_same_as on out_spatial_dim #955 --- returnn/tf/layers/basic.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index fd9a453758..bff898ab4a 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -6128,11 +6128,11 @@ def get_out_data_from_opts(cls, name, sources, axis=None, out_spatial_dim=None, if axis is None: axis = axis_ out = common_source.copy_template(name="%s_output" % name) - if not out_spatial_dim: 
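      # (The removed block below is the old assert-based logic.) Patch 25 above and
      # patch 27 below lean on Dim arithmetic for the same purpose: `dim // n` builds
      # the reduced feature dim and `n * dim` the tiled one. Hypothetical sketch:
      #   feat = FeatureDim("feat", 12)
      #   part = feat // 4   # derived Dim, dimension == 3
      #   tiled = 2 * feat   # derived Dim, dimension == 24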
- out_spatial_dim = Dim( - kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources), auto_generated=True) - assert out_spatial_dim.dimension == len(sources) - out = out.copy_add_dim_by_tag(axis=axis, dim_tag=out_spatial_dim, unbroadcast=True) + out_spatial_dim_ = Dim( + kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources), auto_generated=True) + if out_spatial_dim: + out_spatial_dim_.declare_same_as(out_spatial_dim) + out = out.copy_add_dim_by_tag(axis=axis, dim_tag=out_spatial_dim_, unbroadcast=True) return out From 02548e47d711161d126a006b7ff2a39280cd3104 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 15:56:23 +0100 Subject: [PATCH 27/29] TileLayer, declare_same_as on out_dims #955 --- returnn/tf/layers/basic.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index bff898ab4a..5740c77787 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -4400,13 +4400,9 @@ def get_out_data_from_opts(cls, name, sources, multiples, out_dims=None, **kwarg dim_tags = list(data.dim_tags) for axis, multiple in multiples.items(): axis_int = data.get_axis_from_description(axis, allow_int=False) - tag = dim_tags[axis_int] - dim = None if tag.dimension is None else (tag.dimension * multiple) + tag = multiple * dim_tags[axis_int] if out_dims and axis in out_dims: - tag = out_dims[axis] - assert tag.dimension == dim - else: - tag = multiple * tag + tag.declare_same_as(out_dims[axis]) dim_tags[axis_int] = tag return data.copy_template_new_dim_tags(dim_tags, keep_special_axes=True) From 6a16a957265c49a572d238c0867f93692186ff42 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 16:02:07 +0100 Subject: [PATCH 28/29] GatingLayer, declare_same_as on out_dim #955 --- returnn/tf/layers/basic.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 5740c77787..79814eee85 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -2944,18 +2944,16 @@ def get_out_data_from_opts(cls, name, sources, n_out=NotSpecified, out_dim=None, input_data = get_concat_sources_data_template(sources) assert not input_data.sparse assert input_data.dim % 2 == 0 - dim = input_data.dim // 2 + out_dim_ = input_data.dim_tags[input_data.feature_dim_axis] // 2 if out_dim: - assert out_dim.dimension == dim - else: - out_dim = FeatureDim("%s:gating" % name, dimension=dim, auto_generated=True) + out_dim_.declare_same_as(out_dim) if n_out is not NotSpecified: - assert n_out == dim + assert n_out == input_data.dim // 2 return Data( name="%s_output" % name, dtype=input_data.dtype, dim_tags=[ - out_dim if i == input_data.feature_dim_axis else d + out_dim_ if i == input_data.feature_dim_axis else d for (i, d) in enumerate(input_data.dim_tags)], sparse=False, time_dim_axis=input_data.time_dim_axis, From 802399dd782b8ded29ac784d34555afcb627c45b Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 3 Mar 2022 16:15:28 +0100 Subject: [PATCH 29/29] ReduceOutLayer, fix warning --- returnn/tf/layers/basic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 79814eee85..9c20e2bde1 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -6001,7 +6001,6 @@ def get_out_data_from_opts(cls, num_pieces, sources, name, out_dim=None, **kwarg assert out.have_feature_axis() assert not out.sparse assert out.dim 
% num_pieces == 0 - dim = out.dim // num_pieces out_dim_ = out.feature_dim_or_sparse_dim // num_pieces if out_dim: out_dim_.declare_same_as(out_dim)
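
A rough summary of the pattern this series converges on, sketched with
hypothetical names (this is not code from the patches): each
get_out_data_from_opts computes its own out dim tag and then unifies it with a
user-provided one via declare_same_as, instead of asserting on it.

    def make_out_dim(in_dim, filter_size, stride, dilation_rate, out_dim=None):
        # Mirrors ConvLayer.calc_out_dim (patch 13) for padding="valid";
        # in_dim is a Dim from returnn.tf.util.data.
        left = (filter_size - 1) * dilation_rate // 2
        right = (filter_size - 1) * dilation_rate - left
        out_dim_ = in_dim.sub_left(left).sub_right(right).ceildiv_right(stride)
        if out_dim:
            # Patch 02: declare_same_as also completes an unknown static dimension.
            out_dim_.declare_same_as(out_dim)
        return out_dim_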