@@ -4691,12 +4691,13 @@ def __init__(self, beam_size, keep_beams=False,
4691
4691
self .output_list .append (Data (
4692
4692
name = "%s_choice_output_%d" % (self .name , index ),
4693
4693
batch_dim_axis = 0 ,
4694
- shape = self .output .shape ,
4694
+ dim_tags = self .output .dim_tags ,
4695
4695
sparse = True ,
4696
4696
dim = self .sources [index ].output .dim ,
4697
4697
dtype = self .output .dtype ,
4698
4698
placeholder = labels_ ,
4699
4699
available_for_inference = True ,
4700
+ batch = self .output .batch ,
4700
4701
beam = self .output .beam ))
4701
4702
4702
4703
# We use the labels of the first target as "normal" output.
@@ -4733,7 +4734,8 @@ def __init__(self, beam_size, keep_beams=False,
4733
4734
self .output = Data (
4734
4735
name = "%s_sampled_output" % self .name ,
4735
4736
batch_dim_axis = 0 ,
4736
- shape = self .output .shape ,
4737
+ dim_tags = self .output .dim_tags ,
4738
+ batch = self .output .batch ,
4737
4739
sparse = input_type != "regression" ,
4738
4740
dim = self .output .dim ,
4739
4741
dtype = self .output .dtype ,
@@ -5036,10 +5038,11 @@ def get_out_data_from_opts(cls, name, sources, target, network,
5036
5038
assert search , "%s %r: no target given, must do search" % (cls .__name__ , name )
5037
5039
# Output will be the sparse version of the input.
5038
5040
out_data = sources [0 ].output .copy_template ().copy_as_batch_major ()
5039
- shape = list (out_data .batch_shape )
5040
- del shape [out_data .feature_dim_axis ]
5041
- del shape [out_data .batch_dim_axis ]
5042
- out_data = Data (name = "%s_output" % name , shape = shape , sparse = True , dim = out_data .dim )
5041
+ dim_tags = list (out_data .dim_tags )
5042
+ del dim_tags [out_data .feature_dim_axis ]
5043
+ out_data = Data (
5044
+ name = "%s_output" % name , dim_tags = dim_tags , sparse = True , dim = out_data .dim ,
5045
+ batch = out_data .batch .copy_set_beam (None ) if out_data .batch else network .get_global_batch_info ())
5043
5046
if search :
5044
5047
out_data .beam = cls ._create_search_beam (name = name , beam_size = beam_size , sources = sources , network = network )
5045
5048
elif sources and sources [0 ]:
0 commit comments