Skip to content

Commit 7713ade

Browse files
authored
TileLayer, out_dims option (#810)
#597
1 parent 43ce5a9 commit 7713ade

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

returnn/tf/layers/basic.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3874,10 +3874,12 @@ class TileLayer(_ConcatInputLayer):
38743874
"""
38753875
layer_class = "tile"
38763876

3877-
def __init__(self, multiples, **kwargs):
3877+
def __init__(self, multiples, out_dims=None, **kwargs):
38783878
"""
3879-
:param dict[str, int] multiples: number of multiples per axis (axis provided as str)
3879+
:param dict[DimensionTag|str, int] multiples: number of multiples per axis (axis provided as dim tag or str desc)
3880+
:param dict[DimensionTag|str, DimensionTag]|None out_dims:
38803881
"""
3882+
out_dims # noqa # handled in get_out_data_from_opts
38813883
super(TileLayer, self).__init__(**kwargs)
38823884
self.multiples = multiples
38833885
input_data = self.input_data
@@ -3893,22 +3895,27 @@ def __init__(self, multiples, **kwargs):
38933895
self.output.placeholder = tf.tile(input_data.placeholder, multiples_full)
38943896

38953897
@classmethod
3896-
def get_out_data_from_opts(cls, name, multiples, sources=(), **kwargs):
3898+
def get_out_data_from_opts(cls, name, sources, multiples, out_dims=None, **kwargs):
38973899
"""
38983900
:param str name:
3899-
:param dict[str, int] multiples:
39003901
:param list[LayerBase] sources:
3902+
:param dict[DimensionTag|str, int] multiples:
3903+
:param dict[DimensionTag|str, DimensionTag]|None out_dims:
39013904
:rtype: Data
39023905
"""
39033906
from ..util.data import DimensionTag
39043907
data = get_concat_sources_data_template(sources, name="%s_output" % name)
39053908
dim_tags = list(data.dim_tags)
39063909
for axis, multiple in multiples.items():
3907-
axis = data.get_axis_from_description(axis, allow_int=False)
3908-
tag = dim_tags[axis]
3910+
axis_int = data.get_axis_from_description(axis, allow_int=False)
3911+
tag = dim_tags[axis_int]
39093912
dim = None if tag.dimension is None else (tag.dimension * multiple)
3910-
tag = DimensionTag(kind=tag.kind, description="%s_tile" % name, dimension=dim)
3911-
dim_tags[axis] = tag
3913+
if out_dims and axis in out_dims:
3914+
tag = out_dims[axis]
3915+
assert tag.dimension == dim
3916+
else:
3917+
tag = DimensionTag(kind=tag.kind, description="%s_tile" % name, dimension=dim)
3918+
dim_tags[axis_int] = tag
39123919
return data.copy_template_new_dim_tags(dim_tags, keep_special_axes=True)
39133920

39143921

0 commit comments

Comments
 (0)