@@ -363,11 +363,13 @@ class ConcatLayer(LayerBase):
363363 """
364364 layer_class = "concat"
365365
366- def __init__ (self , sources , allow_broadcast = False , ** kwargs ):
366+ def __init__ (self , sources , allow_broadcast = False , out_dim = None , ** kwargs ):
367367 """
368368 :param list[(LayerBase,str|Dim)] sources:
369369 :param bool allow_broadcast:
370+ :param Dim out_dim:
370371 """
372+ out_dim # noqa # via get_out_data_from_opts
371373 if allow_broadcast :
372374 raise NotImplementedError
373375 sources , axes = zip (* sources ) # unzip
@@ -395,11 +397,12 @@ def _copy_compatible(x, axis):
395397 allow_broadcast = [allow_broadcast ] * len (sources_data ))
396398
397399 @classmethod
398- def get_out_data_from_opts (cls , name , sources , allow_broadcast = False , ** kwargs ):
400+ def get_out_data_from_opts (cls , name , sources , allow_broadcast = False , out_dim = None , ** kwargs ):
399401 """
400402 :param str name:
401403 :param list[(LayerBase,str|Dim)] sources:
402404 :param bool allow_broadcast:
405+ :param Dim|None out_dim:
403406 :rtype: Data
404407 """
405408 assert sources
@@ -413,13 +416,15 @@ def get_out_data_from_opts(cls, name, sources, allow_broadcast=False, **kwargs):
413416 dimension = 0
414417 for tag in concat_dim_tags :
415418 dimension += tag .dimension
416- # We ignore allow_broadcast here... Anyway not currently implemented.
417- # Just overtake the first input format.
418- concat_res_dim_tag = Dim (
419- kind = concat_dim_tags [0 ].kind , description = "%s_concat" % name , dimension = dimension ,
420- derived_from_tag = concat_dim_tags [0 ])
419+ if not out_dim :
420+ # We ignore allow_broadcast here... Anyway not currently implemented.
421+ # Just overtake the first input format.
422+ out_dim = Dim (
423+ kind = concat_dim_tags [0 ].kind , description = "%s_concat" % name , dimension = dimension ,
424+ derived_from_tag = concat_dim_tags [0 ])
425+ assert out_dim .dimension == dimension
421426 res_dim_tags = list (sources [0 ].output .dim_tags )
422- res_dim_tags [axes_int [0 ]] = concat_res_dim_tag
427+ res_dim_tags [axes_int [0 ]] = out_dim
423428 return Data (name = "%s_output" % name , dim_tags = res_dim_tags , dtype = sources [0 ].output .dtype )
424429
425430 @classmethod
0 commit comments