Skip to content

Commit 5d005f2

Browse files
authored
RepeatLayer, handle out_dim option (#803)
#597
1 parent 7713ade commit 5d005f2

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

returnn/tf/layers/basic.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -3729,14 +3729,15 @@ class RepeatLayer(_ConcatInputLayer):
37293729
"""
37303730
layer_class = "repeat"
37313731

3732-
def __init__(self, repetitions, axis="T", **kwargs):
3732+
def __init__(self, repetitions, axis="T", out_dim=None, **kwargs):
37333733
"""
37343734
:param LayerBase|int repetitions:
37353735
number of repetitions for each sequence and position in target axis.
37363736
Can be [B,T] or [T,B] or some subset of that shape
3737-
:param str axis: (dynamic) axis for repetition (currently only time axis is supported)
3737+
:param DimensionTag|str axis: (dynamic) axis for repetition (currently only time axis is supported)
3738+
:param DimensionTag|None out_dim:
37383739
"""
3739-
super(RepeatLayer, self).__init__(**kwargs)
3740+
super(RepeatLayer, self).__init__(out_dim=out_dim, **kwargs)
37403741
self.repetitions = repetitions
37413742
if isinstance(self.repetitions, int):
37423743
repetitions_data = Data.from_tensor(tf.constant(self.repetitions))
@@ -3819,7 +3820,7 @@ def copy_placeholder_with_batch_axis(data, other_batch):
38193820
# set size placeholders
38203821
output_axis = self.output.get_axis_from_description(axis)
38213822
tag = self.output.dim_tags[output_axis]
3822-
if tag.dimension is None: # dynamic? dyn sizes needed?
3823+
if tag.dimension is None and tag.dyn_size is None: # dynamic? dyn sizes needed?
38233824
tag.set_tag_on_size_tensor(target_seq_len, batch=self.output.batch)
38243825

38253826
def get_dep_layers(self):
@@ -3843,12 +3844,13 @@ def transform_config_dict(cls, d, network, get_layer):
38433844
d["repetitions"] = get_layer(d["repetitions"])
38443845

38453846
@classmethod
3846-
def get_out_data_from_opts(cls, name, axis, repetitions, sources=(), **kwargs):
3847+
def get_out_data_from_opts(cls, name, sources, axis, repetitions, out_dim=None, **kwargs):
38473848
"""
38483849
:param str name:
3849-
:param str axis:
3850-
:param LayerBase|int repetitions:
38513850
:param list[LayerBase] sources:
3851+
:param DimensionTag|str axis:
3852+
:param LayerBase|int repetitions:
3853+
:param DimensionTag|None out_dim:
38523854
:rtype: Data
38533855
"""
38543856
from ..util.data import DimensionTag
@@ -3864,8 +3866,11 @@ def get_out_data_from_opts(cls, name, axis, repetitions, sources=(), **kwargs):
38643866
else:
38653867
new_dim = None
38663868
data = data.copy_move_axis(original_axis, data.get_batch_axis(0))
3867-
tag = DimensionTag(description="repeated:%s" % name, kind=tag.kind, dimension=new_dim)
3868-
return data.copy_template_replace_dim_tag(axis=data.get_batch_axis(0), new_dim_tag=tag)
3869+
if not out_dim:
3870+
out_dim = DimensionTag(description="repeated:%s" % name, kind=tag.kind, dimension=new_dim, derived_from_tag=tag)
3871+
else:
3872+
assert out_dim.dimension == new_dim
3873+
return data.copy_template_replace_dim_tag(axis=data.get_batch_axis(0), new_dim_tag=out_dim)
38693874

38703875

38713876
class TileLayer(_ConcatInputLayer):

0 commit comments

Comments
 (0)