@@ -3874,10 +3874,12 @@ class TileLayer(_ConcatInputLayer):
3874
3874
"""
3875
3875
layer_class = "tile"
3876
3876
3877
- def __init__ (self , multiples , ** kwargs ):
3877
+ def __init__ (self , multiples , out_dims = None , ** kwargs ):
3878
3878
"""
3879
- :param dict[str, int] multiples: number of multiples per axis (axis provided as str)
3879
+ :param dict[DimensionTag|str, int] multiples: number of multiples per axis (axis provided as dim tag or str desc)
3880
+ :param dict[DimensionTag|str, DimensionTag]|None out_dims:
3880
3881
"""
3882
+ out_dims # noqa # handled in get_out_data_from_opts
3881
3883
super (TileLayer , self ).__init__ (** kwargs )
3882
3884
self .multiples = multiples
3883
3885
input_data = self .input_data
@@ -3893,22 +3895,27 @@ def __init__(self, multiples, **kwargs):
3893
3895
self .output .placeholder = tf .tile (input_data .placeholder , multiples_full )
3894
3896
3895
3897
@classmethod
3896
- def get_out_data_from_opts (cls , name , multiples , sources = () , ** kwargs ):
3898
+ def get_out_data_from_opts (cls , name , sources , multiples , out_dims = None , ** kwargs ):
3897
3899
"""
3898
3900
:param str name:
3899
- :param dict[str, int] multiples:
3900
3901
:param list[LayerBase] sources:
3902
+ :param dict[DimensionTag|str, int] multiples:
3903
+ :param dict[DimensionTag|str, DimensionTag]|None out_dims:
3901
3904
:rtype: Data
3902
3905
"""
3903
3906
from ..util .data import DimensionTag
3904
3907
data = get_concat_sources_data_template (sources , name = "%s_output" % name )
3905
3908
dim_tags = list (data .dim_tags )
3906
3909
for axis , multiple in multiples .items ():
3907
- axis = data .get_axis_from_description (axis , allow_int = False )
3908
- tag = dim_tags [axis ]
3910
+ axis_int = data .get_axis_from_description (axis , allow_int = False )
3911
+ tag = dim_tags [axis_int ]
3909
3912
dim = None if tag .dimension is None else (tag .dimension * multiple )
3910
- tag = DimensionTag (kind = tag .kind , description = "%s_tile" % name , dimension = dim )
3911
- dim_tags [axis ] = tag
3913
+ if out_dims and axis in out_dims :
3914
+ tag = out_dims [axis ]
3915
+ assert tag .dimension == dim
3916
+ else :
3917
+ tag = DimensionTag (kind = tag .kind , description = "%s_tile" % name , dimension = dim )
3918
+ dim_tags [axis_int ] = tag
3912
3919
return data .copy_template_new_dim_tags (dim_tags , keep_special_axes = True )
3913
3920
3914
3921
0 commit comments