@@ -229,11 +229,15 @@ def make(self, encoder: LayerRef):
229
229
blank_idx = self .ctx .blank_idx
230
230
231
231
rec_decoder = {
232
- "am0" : {"class" : "gather_nd" , "from" : _base (encoder ), "position" : "prev:t" }, # [B,D]
232
+ "index" : {"class" : "eval" , "from" : ["prev:t" , "enc_seq_len" ], "eval" : 'tf.minimum(source(0), source(1)-1)' },
233
+ "am0" : {"class" : "gather_nd" , "from" : _base (encoder ), "position" : "index" }, # [B,D]
233
234
"am" : {"class" : "copy" , "from" : "am0" if search else "data:source" },
234
235
236
+ "prev_output_wo_b" : {
237
+ "class" : "masked_computation" , "unit" : {"class" : "copy" , "initial_output" : 0 },
238
+ "from" : "prev:output_" , "mask" : "prev:output_emit" , "initial_output" : 0 },
235
239
"prev_out_non_blank" : {
236
- "class" : "reinterpret_data" , "from" : "prev:output_ " , "set_sparse_dim" : target .get_num_classes ()},
240
+ "class" : "reinterpret_data" , "from" : "prev_output_wo_b " , "set_sparse_dim" : target .get_num_classes ()},
237
241
238
242
"slow_rnn" : self .slow_rnn .make (
239
243
prev_sparse_label_nb = "prev_out_non_blank" ,
@@ -252,7 +256,7 @@ def make(self, encoder: LayerRef):
252
256
253
257
"output" : {
254
258
"class" : 'choice' ,
255
- 'target' : target .key , # note: wrong! but this is ignored both in full-sum training and in search
259
+ 'target' : target .key if train else None , # note: wrong! but this is ignored both in full-sum training and in search
256
260
'beam_size' : beam_size ,
257
261
'from' : "output_log_prob_wb" , "input_type" : "log_prob" ,
258
262
"initial_output" : 0 ,
0 commit comments