@@ -3729,14 +3729,15 @@ class RepeatLayer(_ConcatInputLayer):
3729
3729
"""
3730
3730
layer_class = "repeat"
3731
3731
3732
- def __init__ (self , repetitions , axis = "T" , ** kwargs ):
3732
+ def __init__ (self , repetitions , axis = "T" , out_dim = None , ** kwargs ):
3733
3733
"""
3734
3734
:param LayerBase|int repetitions:
3735
3735
number of repetitions for each sequence and position in target axis.
3736
3736
Can be [B,T] or [T,B] or some subset of that shape
3737
- :param str axis: (dynamic) axis for repetition (currently only time axis is supported)
3737
+ :param DimensionTag|str axis: (dynamic) axis for repetition (currently only time axis is supported)
3738
+ :param DimensionTag|None out_dim:
3738
3739
"""
3739
- super (RepeatLayer , self ).__init__ (** kwargs )
3740
+ super (RepeatLayer , self ).__init__ (out_dim = out_dim , ** kwargs )
3740
3741
self .repetitions = repetitions
3741
3742
if isinstance (self .repetitions , int ):
3742
3743
repetitions_data = Data .from_tensor (tf .constant (self .repetitions ))
@@ -3819,7 +3820,7 @@ def copy_placeholder_with_batch_axis(data, other_batch):
3819
3820
# set size placeholders
3820
3821
output_axis = self .output .get_axis_from_description (axis )
3821
3822
tag = self .output .dim_tags [output_axis ]
3822
- if tag .dimension is None : # dynamic? dyn sizes needed?
3823
+ if tag .dimension is None and tag . dyn_size is None : # dynamic? dyn sizes needed?
3823
3824
tag .set_tag_on_size_tensor (target_seq_len , batch = self .output .batch )
3824
3825
3825
3826
def get_dep_layers (self ):
@@ -3843,12 +3844,13 @@ def transform_config_dict(cls, d, network, get_layer):
3843
3844
d ["repetitions" ] = get_layer (d ["repetitions" ])
3844
3845
3845
3846
@classmethod
3846
- def get_out_data_from_opts (cls , name , axis , repetitions , sources = () , ** kwargs ):
3847
+ def get_out_data_from_opts (cls , name , sources , axis , repetitions , out_dim = None , ** kwargs ):
3847
3848
"""
3848
3849
:param str name:
3849
- :param str axis:
3850
- :param LayerBase|int repetitions:
3851
3850
:param list[LayerBase] sources:
3851
+ :param DimensionTag|str axis:
3852
+ :param LayerBase|int repetitions:
3853
+ :param DimensionTag|None out_dim:
3852
3854
:rtype: Data
3853
3855
"""
3854
3856
from ..util .data import DimensionTag
@@ -3864,8 +3866,11 @@ def get_out_data_from_opts(cls, name, axis, repetitions, sources=(), **kwargs):
3864
3866
else :
3865
3867
new_dim = None
3866
3868
data = data .copy_move_axis (original_axis , data .get_batch_axis (0 ))
3867
- tag = DimensionTag (description = "repeated:%s" % name , kind = tag .kind , dimension = new_dim )
3868
- return data .copy_template_replace_dim_tag (axis = data .get_batch_axis (0 ), new_dim_tag = tag )
3869
+ if not out_dim :
3870
+ out_dim = DimensionTag (description = "repeated:%s" % name , kind = tag .kind , dimension = new_dim , derived_from_tag = tag )
3871
+ else :
3872
+ assert out_dim .dimension == new_dim
3873
+ return data .copy_template_replace_dim_tag (axis = data .get_batch_axis (0 ), new_dim_tag = out_dim )
3869
3874
3870
3875
3871
3876
class TileLayer (_ConcatInputLayer ):
0 commit comments