Skip to content

Commit 5a476c7

Browse files
committed
ConcatLayer, add out_dim
1 parent 983f7d0 commit 5a476c7

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

returnn/tf/layers/basic.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,13 @@ class ConcatLayer(LayerBase):
363363
"""
364364
layer_class = "concat"
365365

366-
def __init__(self, sources, allow_broadcast=False, **kwargs):
366+
def __init__(self, sources, allow_broadcast=False, out_dim=None, **kwargs):
367367
"""
368368
:param list[(LayerBase,str|Dim)] sources:
369369
:param bool allow_broadcast:
370+
:param Dim out_dim:
370371
"""
372+
out_dim # noqa # via get_out_data_from_opts
371373
if allow_broadcast:
372374
raise NotImplementedError
373375
sources, axes = zip(*sources) # unzip
@@ -395,11 +397,12 @@ def _copy_compatible(x, axis):
395397
allow_broadcast=[allow_broadcast] * len(sources_data))
396398

397399
@classmethod
398-
def get_out_data_from_opts(cls, name, sources, allow_broadcast=False, **kwargs):
400+
def get_out_data_from_opts(cls, name, sources, allow_broadcast=False, out_dim=None, **kwargs):
399401
"""
400402
:param str name:
401403
:param list[(LayerBase,str|Dim)] sources:
402404
:param bool allow_broadcast:
405+
:param Dim|None out_dim:
403406
:rtype: Data
404407
"""
405408
assert sources
@@ -413,13 +416,15 @@ def get_out_data_from_opts(cls, name, sources, allow_broadcast=False, **kwargs):
413416
dimension = 0
414417
for tag in concat_dim_tags:
415418
dimension += tag.dimension
416-
# We ignore allow_broadcast here... Anyway not currently implemented.
417-
# Just overtake the first input format.
418-
concat_res_dim_tag = Dim(
419-
kind=concat_dim_tags[0].kind, description="%s_concat" % name, dimension=dimension,
420-
derived_from_tag=concat_dim_tags[0])
419+
if not out_dim:
420+
# We ignore allow_broadcast here... Anyway not currently implemented.
421+
# Just overtake the first input format.
422+
out_dim = Dim(
423+
kind=concat_dim_tags[0].kind, description="%s_concat" % name, dimension=dimension,
424+
derived_from_tag=concat_dim_tags[0])
425+
assert out_dim.dimension == dimension
421426
res_dim_tags = list(sources[0].output.dim_tags)
422-
res_dim_tags[axes_int[0]] = concat_res_dim_tag
427+
res_dim_tags[axes_int[0]] = out_dim
423428
return Data(name="%s_output" % name, dim_tags=res_dim_tags, dtype=sources[0].output.dtype)
424429

425430
@classmethod

0 commit comments

Comments
 (0)