Skip to content

Commit f529f98

Browse files
committed
ChoiceLayer fix dim tags and batch info
1 parent 31259a7 commit f529f98

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

returnn/tf/layers/rec.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4691,12 +4691,13 @@ def __init__(self, beam_size, keep_beams=False,
46914691
self.output_list.append(Data(
46924692
name="%s_choice_output_%d" % (self.name, index),
46934693
batch_dim_axis=0,
4694-
shape=self.output.shape,
4694+
dim_tags=self.output.dim_tags,
46954695
sparse=True,
46964696
dim=self.sources[index].output.dim,
46974697
dtype=self.output.dtype,
46984698
placeholder=labels_,
46994699
available_for_inference=True,
4700+
batch=self.output.batch,
47004701
beam=self.output.beam))
47014702

47024703
# We use the labels of the first target as "normal" output.
@@ -4733,7 +4734,8 @@ def __init__(self, beam_size, keep_beams=False,
47334734
self.output = Data(
47344735
name="%s_sampled_output" % self.name,
47354736
batch_dim_axis=0,
4736-
shape=self.output.shape,
4737+
dim_tags=self.output.dim_tags,
4738+
batch=self.output.batch,
47374739
sparse=input_type != "regression",
47384740
dim=self.output.dim,
47394741
dtype=self.output.dtype,
@@ -5036,10 +5038,11 @@ def get_out_data_from_opts(cls, name, sources, target, network,
50365038
assert search, "%s %r: no target given, must do search" % (cls.__name__, name)
50375039
# Output will be the sparse version of the input.
50385040
out_data = sources[0].output.copy_template().copy_as_batch_major()
5039-
shape = list(out_data.batch_shape)
5040-
del shape[out_data.feature_dim_axis]
5041-
del shape[out_data.batch_dim_axis]
5042-
out_data = Data(name="%s_output" % name, shape=shape, sparse=True, dim=out_data.dim)
5041+
dim_tags = list(out_data.dim_tags)
5042+
del dim_tags[out_data.feature_dim_axis]
5043+
out_data = Data(
5044+
name="%s_output" % name, dim_tags=dim_tags, sparse=True, dim=out_data.dim,
5045+
batch=out_data.batch.copy_set_beam(None) if out_data.batch else network.get_global_batch_info())
50435046
if search:
50445047
out_data.beam = cls._create_search_beam(name=name, beam_size=beam_size, sources=sources, network=network)
50455048
elif sources and sources[0]:

0 commit comments

Comments
 (0)