diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index a16f39751c..9c20e2bde1 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -1007,30 +1007,9 @@ def __init__(self, axis, slice_start=None, slice_end=None, slice_step=None, out_ axis = self.input_data.get_axis_from_description(axis) dim_slice = slice(slice_start, slice_end, slice_step) slices = [slice(None, None)] * axis + [dim_slice] - input_dim_tag = self.input_data.dim_tags[axis] - output_dim_tag = self.output.dim_tags[axis] - if input_dim_tag.dyn_size_ext and output_dim_tag.dyn_size is None: - assert input_dim_tag.dyn_size_ext.placeholder is not None - if not output_dim_tag.dyn_size_ext: - output_dim_tag.dyn_size_ext = input_dim_tag.dyn_size_ext.copy_template(name="%s:slice-size" % self.name) - dyn_size = input_dim_tag.dyn_size_ext.placeholder - if slice_start: - assert slice_start > 0 - dyn_size = tf.maximum(0, dyn_size - slice_start) - if slice_end is not None: - if slice_end >= 0: - dyn_size = tf.minimum(slice_end, dyn_size) - else: # slice_end < 0 - dyn_size = tf.maximum(0, dyn_size + slice_end) - if slice_step: - dyn_size = tf.cast(tf_compat.v1.ceil(tf.divide(dyn_size, slice_step)), tf.int32) - output_dim_tag.dyn_size_ext.placeholder = dyn_size - existing_tag = Dim.get_tag_from_size_tensor(dyn_size) - if existing_tag: - output_dim_tag.declare_same_as(existing_tag) - else: - output_dim_tag.set_tag_on_size_tensor(dyn_size) - self.output.placeholder = self.input_data.placeholder[slices] + y = self.input_data.placeholder[slices] + y.set_shape(self.output.batch_shape) # can be necessary for slice_end>0 + self.output.placeholder = y @classmethod def get_out_data_from_opts( @@ -1051,13 +1030,22 @@ def get_out_data_from_opts( input_data = get_concat_sources_data_template(sources) axis = input_data.get_axis_from_description(axis) dim_tag = input_data.dim_tags[axis] - dim_slice = slice(slice_start, slice_end, slice_step) - new_dim = len(range(dim_tag.dimension)[dim_slice]) if dim_tag.dimension is not None else None + out_dim_ = dim_tag + if slice_start: + assert slice_start >= 0 + out_dim_ = out_dim_.sub_left(slice_start) + if slice_end is not None: + if slice_end >= 0: + out_dim_ = Dim( + kind=dim_tag.kind, description="%s:slice-end" % name, + dimension=slice_end - (slice_start or 0), auto_generated=True) + else: # slice_end < 0 + out_dim_ = out_dim_ + slice_end + if slice_step and slice_step != 1: + out_dim_ = out_dim_.ceildiv_right(abs(slice_step)) if out_dim: - assert out_dim.dimension == new_dim - else: - out_dim = Dim(kind=dim_tag.kind, description="%s:slice" % name, dimension=new_dim, auto_generated=True) - return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim, name="%s_output" % name) + out_dim_.declare_same_as(out_dim) + return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim_, name="%s_output" % name) class SliceNdLayer(_ConcatInputLayer): @@ -1230,13 +1218,14 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, axis="T else: # size might be None here in which case we set the dyn_size in __init__ assert size is None or isinstance(size, int) + out_spatial_dim_ = Dim( + kind=Dim.Types.Spatial, + description="sliced-time:%s" % name, + dimension=size, auto_generated=True) if out_spatial_dim: - assert out_spatial_dim.dimension == size + out_spatial_dim_.declare_same_as(out_spatial_dim) else: - out_spatial_dim = Dim( - kind=Dim.Types.Spatial, - description="sliced-time:%s" % name, - dimension=size, auto_generated=True) + 
out_spatial_dim = out_spatial_dim_ gather_positions_data = gather_positions_data.copy_add_dim_by_tag( out_spatial_dim, unbroadcast=True, axis=start_data.batch_ndim) position = InternalLayer( @@ -2955,18 +2944,16 @@ def get_out_data_from_opts(cls, name, sources, n_out=NotSpecified, out_dim=None, input_data = get_concat_sources_data_template(sources) assert not input_data.sparse assert input_data.dim % 2 == 0 - dim = input_data.dim // 2 + out_dim_ = input_data.dim_tags[input_data.feature_dim_axis] // 2 if out_dim: - assert out_dim.dimension == dim - else: - out_dim = FeatureDim("%s:gating" % name, dimension=dim, auto_generated=True) + out_dim_.declare_same_as(out_dim) if n_out is not NotSpecified: - assert n_out == dim + assert n_out == input_data.dim // 2 return Data( name="%s_output" % name, dtype=input_data.dtype, dim_tags=[ - out_dim if i == input_data.feature_dim_axis else d + out_dim_ if i == input_data.feature_dim_axis else d for (i, d) in enumerate(input_data.dim_tags)], sparse=False, time_dim_axis=input_data.time_dim_axis, @@ -3005,6 +2992,7 @@ def __init__(self, window_size=None, window_dim=None, window_left=None, window_r :param int stride: return only each Nth window :param kwargs: """ + out_spatial_dim # noqa # via get_out_data_from_opts super(WindowLayer, self).__init__(**kwargs) if not window_size: assert window_dim and window_dim.dimension @@ -3030,25 +3018,6 @@ def __init__(self, window_size=None, window_dim=None, window_left=None, window_r else: axis = data.get_axis_from_description(axis) new_dim_axis = axis + 1 # new axis will be added right after - in_spatial_dim = data.dim_tags[axis] - out_spatial_dim_ = self.output.dim_tags[axis] - if out_spatial_dim: - assert out_spatial_dim_ == out_spatial_dim - if (padding.lower() == "same" or window_size == 1) and stride == 1: # no change in spatial dim - assert in_spatial_dim == out_spatial_dim - if in_spatial_dim != out_spatial_dim_ and out_spatial_dim_.dimension is None: - if not out_spatial_dim_.dyn_size_ext: - out_spatial_dim_.dyn_size_ext = in_spatial_dim.dyn_size_ext.copy_template(name="%s:spatial-size" % self.name) - if out_spatial_dim_.dyn_size_ext.placeholder is None: - from ..util.basic import same_control_flow_ctx - from ..util.data import Dim - assert in_spatial_dim.dyn_size is not None - size = in_spatial_dim.dyn_size - with same_control_flow_ctx(size): - size = ConvLayer.calc_out_dim( - in_dim=size, - filter_size=window_size, stride=stride, dilation_rate=1, padding=padding) - out_spatial_dim_.dyn_size_ext.placeholder = size from returnn.tf.util.basic import windowed_nd self.output.placeholder = windowed_nd( @@ -3087,27 +3056,19 @@ def get_out_data_from_opts(cls, name, network, sources, window_size=None, window else: axis = data.get_axis_from_description(axis) in_spatial_dim = data.dim_tags[axis] - if (padding.lower() == "same" or window_size == 1) and stride == 1: # no change in spatial dim - out_spatial_dim = in_spatial_dim # error check in __init__ - else: # new spatial dim - if not out_spatial_dim: - dim = None - if in_spatial_dim.dimension is not None: - dim = ConvLayer.calc_out_dim( - in_dim=in_spatial_dim.dimension, - filter_size=window_size, stride=stride, dilation_rate=1, padding=padding) - out_spatial_dim = Dim( - kind=Dim.Types.Spatial, description="%s:spatial" % name, - dimension=dim, derived_from_tag=in_spatial_dim, auto_generated=True, - batch=data.batch, control_flow_ctx=data.control_flow_ctx) - data = data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_spatial_dim) + out_spatial_dim_ = 
ConvLayer.calc_out_dim( + in_dim=in_spatial_dim, + filter_size=window_size, stride=stride, dilation_rate=1, padding=padding) + assert isinstance(out_spatial_dim_, Dim) + if out_spatial_dim: + out_spatial_dim_.declare_same_as(out_spatial_dim) + data = data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_spatial_dim_) new_dim_axis = axis + 1 # add new axis right after + window_dim_ = Dim( + kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size, auto_generated=True) if window_dim: - assert window_dim.dimension == window_size - else: - window_dim = Dim( - kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size, auto_generated=True) - return data.copy_add_dim_by_tag(axis=new_dim_axis, dim_tag=window_dim, unbroadcast=True) + window_dim_.declare_same_as(window_dim) + return data.copy_add_dim_by_tag(axis=new_dim_axis, dim_tag=window_dim_, unbroadcast=True) # noinspection PyMethodOverriding @classmethod @@ -3231,25 +3192,6 @@ def __init__(self, axes, padding, out_dims=None, value=0, mode="constant", **kwa self.output.placeholder = tf_util.pad_replicate(self.input_data.placeholder, axes, padding) else: self.output.placeholder = tf.pad(self.input_data.placeholder, paddings=paddings, mode=mode, constant_values=value) - for a in axes: - p = sum(paddings[a]) - in_tag = self.input_data.dim_tags[a] - out_tag = self.output.dim_tags[a] - a = self.input_data.get_batch_axis_excluding_batch(a) - if a is None: - continue - if in_tag.dyn_size is None: - continue - if p == 0: - continue - size = in_tag.dyn_size - with tf_util.same_control_flow_ctx(size): - size = tf_util.simplify_add(size, p) - size_tag = Dim.get_tag_from_size_tensor(size) - if not size_tag: - out_tag.set_tag_on_size_tensor(size, batch=in_tag.batch) - else: - out_tag.declare_same_as(size_tag) @classmethod def _transform_padding(cls, padding, axes): @@ -3302,10 +3244,8 @@ def get_out_data_from_opts(cls, name, sources, axes, padding, out_dims=None, **k pad_left, pad_right = padding[i] out_dim = pad_left + dim_tags[a] + pad_right if out_dims: - assert out_dims[i].dimension == out_dim.dimension - dim_tags[a] = out_dims[i] - else: - dim_tags[a] = out_dim + out_dim.declare_same_as(out_dims[i]) + dim_tags[a] = out_dim return data.copy_template_new_dim_tags(dim_tags, keep_special_axes=True) @@ -3488,6 +3428,7 @@ def get_out_data_from_opts(cls, name, axes, keep_order=NotSpecified, :rtype: Data """ from returnn.util import BehaviorVersion + from returnn.util.basic import prod if keep_order is NotSpecified: keep_order = True if BehaviorVersion.get() >= 6 else False assert not out_type, "currently ignored" @@ -3497,21 +3438,18 @@ def get_out_data_from_opts(cls, name, axes, keep_order=NotSpecified, data = input_data.copy(name="%s_output" % name) if len(axes) <= 1: return data - import numpy res_dim = None if all([data.batch_shape[i] is not None for i in axes]): - res_dim = int(numpy.prod([data.batch_shape[i] for i in axes])) + res_dim = int(prod([data.batch_shape[i] for i in axes])) merge_dim_tags = [data.dim_tags[axis] for axis in axes] merge_target_axis = cls._get_target_axis(input_data=data, merge_axes=axes) + out_dim_ = prod(merge_dim_tags) + assert isinstance(out_dim_, Dim) + assert out_dim_.dimension == res_dim if out_dim: - assert out_dim.dimension == res_dim - else: - from numpy import prod - out_dim = prod(merge_dim_tags) - assert isinstance(out_dim, Dim) - assert out_dim.dimension == res_dim + out_dim_.declare_same_as(out_dim) new_dim_tags = [d for (i, d) in enumerate(data.dim_tags) if i not in 
axes] - new_dim_tags.insert(merge_target_axis, out_dim) + new_dim_tags.insert(merge_target_axis, out_dim_) data_opts = data.get_kwargs(include_special_axes=False) data_opts["dim_tags"] = new_dim_tags @@ -4412,12 +4350,13 @@ def get_out_data_from_opts(cls, name, sources, axis, repetitions, out_dim=None, original_axis = data.get_axis_from_description(axis, allow_int=False) tag = data.dim_tags[original_axis] data = data.copy_move_axis(original_axis, data.get_batch_axis(0)) - if not out_dim: - if isinstance(repetitions, int): - out_dim = tag * repetitions - else: - out_dim = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag, auto_generated=True) - return data.copy_template_replace_dim_tag(axis=data.get_batch_axis(0), new_dim_tag=out_dim) + if isinstance(repetitions, int): + out_dim_ = tag * repetitions + else: + out_dim_ = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag, auto_generated=True) + if out_dim: + out_dim_.declare_same_as(out_dim) + return data.copy_template_replace_dim_tag(axis=data.get_batch_axis(0), new_dim_tag=out_dim_) class TileLayer(_ConcatInputLayer): @@ -4459,13 +4398,9 @@ def get_out_data_from_opts(cls, name, sources, multiples, out_dims=None, **kwarg dim_tags = list(data.dim_tags) for axis, multiple in multiples.items(): axis_int = data.get_axis_from_description(axis, allow_int=False) - tag = dim_tags[axis_int] - dim = None if tag.dimension is None else (tag.dimension * multiple) + tag = multiple * dim_tags[axis_int] if out_dims and axis in out_dims: - tag = out_dims[axis] - assert tag.dimension == dim - else: - tag = multiple * tag + tag.declare_same_as(out_dims[axis]) dim_tags[axis_int] = tag return data.copy_template_new_dim_tags(dim_tags, keep_special_axes=True) @@ -5021,13 +4956,10 @@ def __init__(self, filter_size, padding, strides=1, dilation_rate=1, groups=1, self.output_before_activation = OutputWithActivation(y) y = self.output_before_activation.y self.output.placeholder = y - self.set_output_dim_tags( - self.output, num_batch_dims=num_batch_dims, in_spatial_dims=in_spatial_dims_, out_spatial_dims=out_spatial_dims, - filter_size=filter_size, strides=strides, dilation_rate=dilation_rate, padding=padding) @classmethod def set_output_dim_tags(cls, output, num_batch_dims, in_spatial_dims, out_spatial_dims, - filter_size, strides, dilation_rate, padding, calc_dyn_size=True): + filter_size, strides, dilation_rate, padding): """ :param Data output: :param int num_batch_dims: @@ -5037,7 +4969,6 @@ def set_output_dim_tags(cls, output, num_batch_dims, in_spatial_dims, out_spatia :param list[int]|tuple[int] strides: :param list[int]|tuple[int] dilation_rate: :param str padding: - :param bool calc_dyn_size: """ if output.feature_dim_axis == num_batch_dims: out_spatial_dims_ = output.dim_tags[num_batch_dims + 1:] @@ -5045,7 +4976,9 @@ def set_output_dim_tags(cls, output, num_batch_dims, in_spatial_dims, out_spatia assert output.feature_dim_axis == output.batch_ndim - 1 out_spatial_dims_ = output.dim_tags[num_batch_dims:-1] if out_spatial_dims: - assert list(out_spatial_dims_) == list(out_spatial_dims) + assert len(out_spatial_dims_) == len(out_spatial_dims) + for i, (out_spatial_dim_, out_spatial_dim) in enumerate(zip(out_spatial_dims_, out_spatial_dims)): + out_spatial_dim_.declare_same_as(out_spatial_dim) assert len(out_spatial_dims_) == len(in_spatial_dims) == len(filter_size) == len(strides) == len(dilation_rate) for i, in_tag in enumerate(in_spatial_dims): out_tag = out_spatial_dims_[i] @@ -5055,25 +4988,6 @@ def 
set_output_dim_tags(cls, output, num_batch_dims, in_spatial_dims, out_spatia dilation_rate=dilation_rate[i], padding=padding) assert isinstance(out_tag_calc, Dim) out_tag_calc.declare_same_as(out_tag) - if in_tag.dimension is not None: - size = in_tag.dimension - size = cls.calc_out_dim( - in_dim=size, - filter_size=filter_size[i], stride=strides[i], - dilation_rate=dilation_rate[i], padding=padding) - assert out_tag.dimension == size - elif in_tag.dimension is None and in_tag.dyn_size is not None and calc_dyn_size: - size = in_tag.dyn_size - with tf_util.same_control_flow_ctx(size): - size = cls.calc_out_dim( - in_dim=size, - filter_size=filter_size[i], stride=strides[i], - dilation_rate=dilation_rate[i], padding=padding) - size_tag = Dim.get_tag_from_size_tensor(size) - if not size_tag: - out_tag.set_tag_on_size_tensor(size, batch=in_tag.batch) - else: - out_tag.declare_same_as(size_tag) @classmethod def _check_defined_in_spatial_dims(cls, cond): @@ -5184,7 +5098,7 @@ def transform_input(cls, input_data, network, in_dim=None, in_spatial_dims=None, @classmethod def calc_out_dim(cls, in_dim, filter_size, stride, padding, dilation_rate=1): """ - :param int|tf.Tensor|Dim|T in_dim: dimension in some axis + :param T|int|tf.Tensor|Dim in_dim: dimension in some axis :param int filter_size: e.g. 2, for the corresponding axis :param int stride: e.g. 1, for the corresponding axis :param int dilation_rate: e.g. 1 @@ -5212,7 +5126,7 @@ def ceildiv(a, b): if isinstance(in_dim, Dim): filter_left_dilated = (filter_size - 1) * dilation_rate // 2 filter_right_dilated = (filter_size - 1) * dilation_rate - filter_left_dilated - valid_part = (-filter_left_dilated) + in_dim + (-filter_right_dilated) + valid_part = in_dim.sub_left(filter_left_dilated).sub_right(filter_right_dilated) return valid_part.ceildiv_right(stride) return tf_util.simplify_non_negative_seq_length( ceildiv( @@ -5320,8 +5234,7 @@ def get_out_data_from_opts( if len(old_spatial_dim_tags) == len(filter_size): cls.set_output_dim_tags( out, num_batch_dims=num_batch_dims, in_spatial_dims=old_spatial_dim_tags, out_spatial_dims=out_spatial_dims, - filter_size=filter_size, strides=strides, dilation_rate=dilation_rate, padding=padding, - calc_dyn_size=False) + filter_size=filter_size, strides=strides, dilation_rate=dilation_rate, padding=padding) return out def get_dep_layers(self): @@ -5445,9 +5358,6 @@ def __init__(self, mode, pool_size, padding="VALID", dilation_rate=1, strides=No if num_batch_dims > 1: y = tf.reshape(y, tf.concat([extended_batch_shape, tf.shape(y)[1:]], axis=0)) self.output.placeholder = y - ConvLayer.set_output_dim_tags( - self.output, num_batch_dims=num_batch_dims, in_spatial_dims=in_spatial_dims_, out_spatial_dims=out_spatial_dims, - filter_size=pool_size, strides=strides, dilation_rate=dilation_rate, padding=padding) @classmethod def get_out_data_from_opts(cls, name, sources, network, @@ -5683,26 +5593,6 @@ def __init__(self, filter_size, strides=None, self.output_before_activation = OutputWithActivation(y) y = self.output_before_activation.y self.output.placeholder = y - for idx, axis in enumerate(spatial_axes): - input_tag = input_data.dim_tags[axis] - output_tag = self.output.dim_tags[axis] - if input_tag == output_tag: - continue - assert not input_tag.is_batch_dim() and not output_tag.is_batch_dim() - if input_tag.dimension is None: - assert output_tag.dimension is None - assert input_tag.dyn_size is not None - size = input_tag.dyn_size - with tf_util.same_control_flow_ctx(size): - size = self.deconv_output_length( - 
size, - filter_size=filter_size[idx], stride=strides[idx], - padding=padding, output_padding=output_padding[idx]) - r = remove_padding[idx] - if r: - assert isinstance(r, int) - size = tf_util.simplify_add(size, -r * 2) - output_tag.set_tag_on_size_tensor(size, batch=self.output.batch) @staticmethod def deconv_output_length(input_length, @@ -5715,7 +5605,9 @@ def deconv_output_length(input_length, Determines output length of a transposed convolution given input length. Copied from conv_utils.deconv_output_length, adapted with simplification. - :param T|int|tf.Tensor input_length: + Also see :func:`ConvLayer.calc_out_dim`. + + :param T|int|tf.Tensor|Dim input_length: :param int filter_size: :param str padding: one of `"same"`, `"valid"`, `"full"`. :param int|None output_padding: amount of padding along the output dimension. @@ -5736,13 +5628,19 @@ def deconv_output_length(input_length, # Infer length if output padding is None, else compute the exact length if output_padding is None: if padding == 'valid': - length = tf_util.simplify_add(input_length, max(filter_size - stride, 0)) + if isinstance(input_length, Dim): + length = input_length + max(filter_size - stride, 0) + else: + length = tf_util.simplify_add(input_length, max(filter_size - stride, 0)) elif padding == 'full': - length = tf_util.simplify_add(input_length, -(stride + filter_size - 2)) + if isinstance(input_length, Dim): + length = input_length - (stride + filter_size - 2) + else: + length = tf_util.simplify_add(input_length, -(stride + filter_size - 2)) elif padding == 'same': length = input_length else: - length = None + raise Exception("invalid padding %r" % (padding,)) else: # output_padding if padding == 'same': pad = filter_size // 2 @@ -5751,8 +5649,11 @@ def deconv_output_length(input_length, elif padding == 'full': pad = filter_size - 1 else: - pad = None - length = tf_util.simplify_add(input_length, -stride + filter_size - 2 * pad + output_padding) + raise Exception("invalid padding %r" % (padding,)) + if isinstance(input_length, Dim): + length = input_length + (-stride + filter_size - 2 * pad + output_padding) + else: + length = tf_util.simplify_add(input_length, -stride + filter_size - 2 * pad + output_padding) return length @classmethod @@ -5796,22 +5697,16 @@ def get_out_data_from_opts(cls, name, sources, network, dim_tags = list(data.dim_tags[:num_batch_dims]) # [B] if out_spatial_dims: assert len(out_spatial_dims) == len(filter_size) - # Be relaxed about incorrect input data. Throw errors later. This can also work during template construction. - dim_tags += out_spatial_dims - else: - for i in range(len(filter_size)): - old_tag = old_spatial_dim_tags[i] if i < len(old_spatial_dim_tags) else None - if old_tag and (filter_size[i] == strides[i] == 1 or (strides[i] == 1 and padding == "SAME")): - dim_tags.append(old_tag) # identity in this axis - continue - new_dim = None - if old_tag and old_tag.dimension is not None: - new_dim = cls.deconv_output_length( - old_tag.dimension, filter_size=filter_size[i], stride=strides[i], - padding=padding, output_padding=output_padding[i]) - remove_padding[i] * 2 - dim_tags.append(Dim( - kind=Dim.Types.Spatial, description="%s:conv:s%i" % (name, i), dimension=new_dim, - derived_from_tag=old_tag, undefined=not old_tag, auto_generated=True)) + # Be relaxed about incorrect input data. Throw errors later. This can also work during template construction. 
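+      # Derive each output spatial Dim from the input Dim tag via deconv_output_length (Dim arithmetic),
+      # then strip remove_padding on both sides and unify with out_spatial_dims if given.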
+ for i in range(len(filter_size)): + old_tag = old_spatial_dim_tags[i] if i < len(old_spatial_dim_tags) else None + new_tag = cls.deconv_output_length( + old_tag, filter_size=filter_size[i], stride=strides[i], + padding=padding, output_padding=output_padding[i]) + new_tag = new_tag.sub_left(remove_padding[i]).sub_right(remove_padding[i]) + if out_spatial_dims: + new_tag.declare_same_as(out_spatial_dims[i]) + dim_tags.append(new_tag) if not out_dim: assert n_out out_dim = FeatureDim("%s:channel" % name, dimension=n_out, auto_generated=True) @@ -6106,11 +6001,10 @@ def get_out_data_from_opts(cls, num_pieces, sources, name, out_dim=None, **kwarg assert out.have_feature_axis() assert not out.sparse assert out.dim % num_pieces == 0 - dim = out.dim // num_pieces - if not out_dim: - out_dim = out.feature_dim_or_sparse_dim // num_pieces - assert out_dim.dimension == dim - return out.copy_template_replace_dim_tag(axis=out.feature_dim_axis, new_dim_tag=out_dim) + out_dim_ = out.feature_dim_or_sparse_dim // num_pieces + if out_dim: + out_dim_.declare_same_as(out_dim) + return out.copy_template_replace_dim_tag(axis=out.feature_dim_axis, new_dim_tag=out_dim_) class SqueezeLayer(_ConcatInputLayer): @@ -6227,11 +6121,11 @@ def get_out_data_from_opts(cls, name, sources, axis=None, out_spatial_dim=None, if axis is None: axis = axis_ out = common_source.copy_template(name="%s_output" % name) - if not out_spatial_dim: - out_spatial_dim = Dim( - kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources), auto_generated=True) - assert out_spatial_dim.dimension == len(sources) - out = out.copy_add_dim_by_tag(axis=axis, dim_tag=out_spatial_dim, unbroadcast=True) + out_spatial_dim_ = Dim( + kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources), auto_generated=True) + if out_spatial_dim: + out_spatial_dim_.declare_same_as(out_spatial_dim) + out = out.copy_add_dim_by_tag(axis=axis, dim_tag=out_spatial_dim_, unbroadcast=True) return out @@ -6452,10 +6346,6 @@ def __init__(self, axis="T", out_dim=None, prefix=0.0, repeat=1, size_base=None, assert repeat >= 0 self.repeat = repeat out_dim = self.output.dim_tags[axis_int] - if in_dim.dyn_size is not None and out_dim.dyn_size is None: - if not out_dim.dyn_size_ext: - out_dim.dyn_size_ext = in_dim.dyn_size_ext.copy() - out_dim.dyn_size_ext.placeholder = in_dim.dyn_size_ext.placeholder + repeat max_repeat = repeat if isinstance(repeat, int) else tf.maximum(tf.reduce_max(repeat), 0) shape = [((self.output.batch_shape[i] or tf.shape(input_data.placeholder)[i]) if (i != axis_int) @@ -6506,17 +6396,15 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, size_base x = get_concat_sources_data_template(sources, name="%s_output" % name) axis_int = x.get_axis_from_description(axis, allow_int=False) in_dim = x.dim_tags[axis_int] - out_dim_int = None - if in_dim.dimension and isinstance(repeat, int): - out_dim_int = in_dim.dimension + repeat if size_base: assert not out_dim - out_dim = size_base.output.get_time_dim_tag() - if not out_dim: - out_dim = ( + out_dim_ = size_base.output.get_time_dim_tag() + else: + out_dim_ = ( repeat if isinstance(repeat, int) else SpatialDim("%s:repeat" % repeat.name, auto_generated=True)) + in_dim - assert out_dim.dimension == out_dim_int - x = x.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=out_dim) + if out_dim: + out_dim_.declare_same_as(out_dim) + x = x.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=out_dim_) if isinstance(repeat, LayerBase): x = 
x.copy_as_batch_spatial_major() return x @@ -6571,8 +6459,6 @@ def __init__(self, axis="T", out_dim=None, postfix=0.0, repeat=1, **kwargs): seq_mask = tf.less(idx_range, size_ext.placeholder) assert seq_mask.get_shape().ndims == self.output.batch_ndim self.output.placeholder = tf_util.where_bc(seq_mask, x, c) - out_dim = self.output.dim_tags[axis_int] - out_dim.dyn_size = in_dim.dyn_size + repeat @classmethod def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, postfix=0.0, repeat=1, **kwargs): @@ -6588,13 +6474,10 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, postfix=0 x = get_concat_sources_data_template(sources, name="%s_output" % name) axis_int = x.get_axis_from_description(axis, allow_int=False) in_dim = x.dim_tags[axis_int] - out_dim_int = None - if in_dim.dimension: - out_dim_int = in_dim.dimension + repeat - if not out_dim: - out_dim = in_dim + repeat - assert out_dim.dimension == out_dim_int - x = x.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=out_dim) + out_dim_ = in_dim + repeat + if out_dim: + out_dim_.declare_same_as(out_dim) + x = x.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=out_dim_) return x @classmethod @@ -7275,26 +7158,22 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_ :param Dim|None out_dim: :param str kind: "linear", "nn"/"nearest_neighbor", "cubic", "fill" :param None|int|float fill_value: if kind=="fill" - :param float fill_dropout: if set, will dropout in the same axis + :param float|None fill_dropout: if set, will dropout in the same axis """ out_dim # noqa # via get_out_data_from_opts super(ResizeLayer, self).__init__(**kwargs) # self.output.shape and self.output.batch_dim_axis are already set here via self.get_out_data_from_opts(). 
input_data = self.input_data.copy_as_batch_major() - axis = self.input_data.get_axis_from_description(axis) + axis = input_data.get_axis_from_description(axis) assert axis > 0, "batch-dim resize not supported" input_data = input_data.copy_move_axis(old_axis=axis, new_axis=1) axis = 1 - self.output.placeholder = input_data.placeholder - out_dyn_size = input_data.dim_tags[axis].dyn_size - if out_dyn_size is not None: - out_dyn_size = out_dyn_size * factor # images expected shape: [batch, height, width, channels] remaining_axes = [i for i in range(self.output.batch_ndim) if i not in (0, axis)] - x = dimshuffle(self.output.placeholder, [0, axis, 'x'] + remaining_axes) # [batch,height,width] + remaining_axes + x = dimshuffle(input_data.placeholder, [0, axis, 'x'] + remaining_axes) # [batch,height,width] + remaining_axes from returnn.tf.util.basic import get_shape, optional_mul - shape = get_shape(self.output.placeholder) + shape = get_shape(input_data.placeholder) remaining_shape = [shape[i] for i in remaining_axes] remaining_dim = optional_mul(*remaining_shape) if remaining_axes else 1 x = tf.reshape(x, [shape[0], shape[axis], 1, remaining_dim]) # [batch,height,width,channels] @@ -7337,21 +7216,23 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_ mask = expand_dims_unbroadcast(mask, axis=0, dim=shape[0]) # (batch,new_size) x = tf.boolean_mask(x, mask) # [batch*new_size_dropped] + remaining_shape x = tf.reshape(x, [shape[0], new_size_dropped] + remaining_shape) # [batch, new_size_dropped] + remaining_shape + out_dyn_size = input_data.dim_tags[axis].dyn_size if out_dyn_size is not None: + out_dyn_size = out_dyn_size * factor orig_mask = tf.sequence_mask( out_dyn_size, maxlen=new_size, dtype=tf.bool) # (batch,new_size) out_dyn_size = tf.reduce_sum(tf.cast(tf.logical_and(mask, orig_mask), tf.int32), axis=1) - if out_dyn_size is not None: - self.output.dim_tags[axis].dyn_size = out_dyn_size + self.output.dim_tags[axis].dyn_size = out_dyn_size self.output.placeholder = x @classmethod - def get_out_data_from_opts(cls, factor, axis, sources, name, out_dim=None, **kwargs): + def get_out_data_from_opts(cls, factor, axis, sources, name, fill_dropout=None, out_dim=None, **kwargs): """ :param int factor: :param Dim|str axis: :param list[LayerBase] sources: :param str name: + :param float|None fill_dropout: :param Dim|None out_dim: :rtype: Data """ @@ -7362,12 +7243,13 @@ def get_out_data_from_opts(cls, factor, axis, sources, name, out_dim=None, **kwa axis = 1 assert axis != out.batch_dim_axis, "batch-dim resize not supported" tag = out.dim_tags[axis] - dim = None if tag.dimension is None else (tag.dimension * factor) - if out_dim: - assert out_dim.dimension == dim + if fill_dropout: + out_dim_ = Dim(kind=tag.kind, description="%s_resize" % name, auto_generated=True) # unknown dim else: - out_dim = Dim(kind=tag.kind, description="%s_resize" % name, dimension=dim, auto_generated=True) - return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim) + out_dim_ = tag * factor + if out_dim: + out_dim_.declare_same_as(out_dim) + return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim_) class CombineDimsLayer(MergeDimsLayer): diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 8e8ad305f2..3f69ee620c 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -428,6 +428,9 @@ def dyn_size(self): If the dyn size can potentially be of a different shape, directly access dyn_size_ext. 
:rtype: tf.Tensor|None """ + if self.dimension is None and (not self.dyn_size_ext or self.dyn_size_ext.placeholder is None): + # Try to complete. + self.complete_dyn_size() if self.dyn_size_ext: return self.dyn_size_ext.placeholder return None @@ -611,14 +614,31 @@ def complete_dyn_size(self): if kind.endswith("_left"): kind = kind[:-len("_left")] + import numpy + from tensorflow.python.framework import tensor_util + + def _is_negative(x__): + if isinstance(x__, numpy.ndarray): + return (x__ < 0).any() + if isinstance(x__, (int, float)): + return x__ < 0 + assert isinstance(x__, tf.Tensor) + x__ = tensor_util.constant_value(x__) + if x__ is not None: + return _is_negative(x__) + return False + def _bin_op(a, b): with tf_util.same_control_flow_ctx([a, b]): if kind == "add": - return tf_util.simplify_add(a, b) + use_relu = _is_negative(a) or _is_negative(b) # for dynamic tensors, assume all positive + if use_relu: + return tf.convert_to_tensor(tf_util.simplify_non_negative_seq_length(tf_util.simplify_add(a, b))) + return tf.convert_to_tensor(tf_util.simplify_add(a, b)) elif kind == "sub": - return tf_util.simplify_sub(a, b) + return tf.convert_to_tensor(tf_util.simplify_non_negative_seq_length(tf_util.simplify_sub(a, b))) elif kind == "mul": - return a * b + return tf.convert_to_tensor(tf_util.optional_mul(a, b)) elif kind in ("floordiv", "truediv"): # truediv assumes there is no remainder return a // b elif kind == "ceildiv": @@ -641,7 +661,8 @@ def _bin_op(a, b): continue y.placeholder = _bin_op(y.placeholder, x.dimension) continue - x = x.get_for_batch_ctx(self.batch, self.control_flow_ctx) + if self.batch: + x = x.get_for_batch_ctx(self.batch, self.control_flow_ctx) x.complete_dyn_size() if not x.dyn_size_ext or x.dyn_size_ext.placeholder is None: return @@ -899,6 +920,12 @@ def declare_same_as(self, other): elif other_same_base.dyn_size_ext is None or not other_same_base._validate_in_current_graph(): other_same_base.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx( other_same_base.batch, other_same_base.control_flow_ctx) + if self.is_dim_known() and other.is_dim_known(): + assert self.dimension == other.dimension + elif self.is_dim_known() and not other.is_dim_known(): + other.dimension = self.dimension + elif not self.is_dim_known() and other.is_dim_known(): + self.dimension = other.dimension if self._vocab and not other_same_base._vocab: other_same_base._vocab = self._vocab elif other_same_base._vocab and not self._vocab: diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index 1a83e5c29c..07c9d807a6 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -2564,6 +2564,22 @@ def test_MergeDimsLayer_modified_time_dim(): session.run(out.placeholder, feed_dict=make_feed_dict(net.extern_data)) +def test_MergeDimsLayer_unspecified_out_dim(): + # https://github.com/rwth-i6/returnn/issues/955 + # https://github.com/rwth-i6/returnn_common/issues/117 + config = Config({ + "extern_data": {"data": {"shape": (None, 3, 5)}}, + }) + out_dim = SpatialDim("out") + with make_scope() as session: + net = TFNetwork(config=config) + net.construct_from_dict({ + "output": { + "class": "merge_dims", "from": "data", "axes": ["dim:3", "dim:5"], "keep_order": True, + "out_dim": out_dim}, + }) + + def test_FlattenBatchLayer(): from returnn.tf.util.data import BatchInfo n_batch, n_time, n_in = 3, 4, 2
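For reference, the new test above exercises the pattern this patch applies across layers: each layer now always builds its own dim tag (the local out_dim_ etc.) and, if the user passes out_dim / out_spatial_dim / out_dims, binds it via declare_same_as, which with the change in data.py also fills in a still-unknown static dimension. A minimal sketch of that usage, assuming a TFNetwork net constructed with the extern_data from the test; the final assert states the expected outcome under this change and is not part of the patch:

  from returnn.tf.util.data import SpatialDim

  out_dim = SpatialDim("out")  # dimension still unknown at this point
  net.construct_from_dict({
    "output": {
      "class": "merge_dims", "from": "data", "axes": ["dim:3", "dim:5"], "keep_order": True,
      "out_dim": out_dim},
  })
  # MergeDimsLayer computes its own product dim (3 * 5) and calls declare_same_as on the
  # user-provided tag, so out_dim is now resolved:
  assert out_dim.dimension == 15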