@classmethod
def get_out_data_from_opts(cls, name, sources, axis=None, out_spatial_dim=None, **kwargs):
    """
    Build the output template of the stack layer:
    the common template of all sources, with one additional axis of size ``len(sources)``.

    :param str name: layer name. Used for the output name ("<name>_output")
        and for the description of an automatically created dim tag ("<name>:stack").
    :param list[LayerBase] sources: the layers which are stacked
    :param int|None axis: position of the new axis.
        If None, the axis returned by ``cls._get_axis_and_common(sources)`` is used.
    :param DimensionTag|None out_spatial_dim: dim tag for the new axis.
        If None, a new spatial dim tag with ``dimension=len(sources)`` is created.
        If given, its ``dimension`` must be equal to ``len(sources)``.
    :param kwargs: remaining layer options, not used here
    :rtype: Data
    """
    default_axis, common_source = cls._get_axis_and_common(sources)
    if axis is None:
        axis = default_axis
    out = common_source.copy_template(name="%s_output" % name)
    num_sources = len(sources)
    # Explicit None check: only a missing tag triggers auto-creation,
    # independent of whatever truthiness the dim tag type might define.
    if out_spatial_dim is None:
        out_spatial_dim = DimensionTag(
            kind=DimensionTag.Types.Spatial,
            description="%s:stack" % name,
            dimension=num_sources)
    # A user-provided tag must agree with the number of stacked sources.
    assert out_spatial_dim.dimension == num_sources, (
        "%s: out_spatial_dim %r has dimension %r, but there are %i sources to stack" % (
            name, out_spatial_dim, out_spatial_dim.dimension, num_sources))
    return out.copy_add_dim_by_tag(axis=axis, dim_tag=out_spatial_dim, unbroadcast=True)