@@ -363,11 +363,13 @@ class ConcatLayer(LayerBase):
363
363
"""
364
364
layer_class = "concat"
365
365
366
- def __init__ (self , sources , allow_broadcast = False , ** kwargs ):
366
+ def __init__ (self , sources , allow_broadcast = False , out_dim = None , ** kwargs ):
367
367
"""
368
368
:param list[(LayerBase,str|Dim)] sources:
369
369
:param bool allow_broadcast:
370
+ :param Dim out_dim:
370
371
"""
372
+ out_dim # noqa # via get_out_data_from_opts
371
373
if allow_broadcast :
372
374
raise NotImplementedError
373
375
sources , axes = zip (* sources ) # unzip
@@ -395,11 +397,12 @@ def _copy_compatible(x, axis):
395
397
allow_broadcast = [allow_broadcast ] * len (sources_data ))
396
398
397
399
@classmethod
398
- def get_out_data_from_opts (cls , name , sources , allow_broadcast = False , ** kwargs ):
400
+ def get_out_data_from_opts (cls , name , sources , allow_broadcast = False , out_dim = None , ** kwargs ):
399
401
"""
400
402
:param str name:
401
403
:param list[(LayerBase,str|Dim)] sources:
402
404
:param bool allow_broadcast:
405
+ :param Dim|None out_dim:
403
406
:rtype: Data
404
407
"""
405
408
assert sources
@@ -413,13 +416,15 @@ def get_out_data_from_opts(cls, name, sources, allow_broadcast=False, **kwargs):
413
416
dimension = 0
414
417
for tag in concat_dim_tags :
415
418
dimension += tag .dimension
416
- # We ignore allow_broadcast here... Anyway not currently implemented.
417
- # Just overtake the first input format.
418
- concat_res_dim_tag = Dim (
419
- kind = concat_dim_tags [0 ].kind , description = "%s_concat" % name , dimension = dimension ,
420
- derived_from_tag = concat_dim_tags [0 ])
419
+ if not out_dim :
420
+ # We ignore allow_broadcast here... Anyway not currently implemented.
421
+ # Just overtake the first input format.
422
+ out_dim = Dim (
423
+ kind = concat_dim_tags [0 ].kind , description = "%s_concat" % name , dimension = dimension ,
424
+ derived_from_tag = concat_dim_tags [0 ])
425
+ assert out_dim .dimension == dimension
421
426
res_dim_tags = list (sources [0 ].output .dim_tags )
422
- res_dim_tags [axes_int [0 ]] = concat_res_dim_tag
427
+ res_dim_tags [axes_int [0 ]] = out_dim
423
428
return Data (name = "%s_output" % name , dim_tags = res_dim_tags , dtype = sources [0 ].output .dtype )
424
429
425
430
@classmethod
0 commit comments