Skip to content

Commit dcae584

Browse files
committed
limit "t" and correct prev non blank for search
1 parent e62f264 commit dcae584

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

common/models/transducer/transducer_fullsum.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,15 @@ def make(self, encoder: LayerRef):
229229
blank_idx = self.ctx.blank_idx
230230

231231
rec_decoder = {
232-
"am0": {"class": "gather_nd", "from": _base(encoder), "position": "prev:t"}, # [B,D]
232+
"index": {"class": "eval", "from": ["prev:t", "enc_seq_len"], "eval": 'tf.minimum(source(0), source(1)-1)'},
233+
"am0": {"class": "gather_nd", "from": _base(encoder), "position": "index"}, # [B,D]
233234
"am": {"class": "copy", "from": "am0" if search else "data:source"},
234235

236+
"prev_output_wo_b": {
237+
"class": "masked_computation", "unit": {"class": "copy", "initial_output": 0},
238+
"from": "prev:output_", "mask": "prev:output_emit", "initial_output": 0},
235239
"prev_out_non_blank": {
236-
"class": "reinterpret_data", "from": "prev:output_", "set_sparse_dim": target.get_num_classes()},
240+
"class": "reinterpret_data", "from": "prev_output_wo_b", "set_sparse_dim": target.get_num_classes()},
237241

238242
"slow_rnn": self.slow_rnn.make(
239243
prev_sparse_label_nb="prev_out_non_blank",
@@ -252,7 +256,7 @@ def make(self, encoder: LayerRef):
252256

253257
"output": {
254258
"class": 'choice',
255-
'target': target.key, # note: wrong! but this is ignored both in full-sum training and in search
259+
'target': target.key if train else None, # note: wrong! but this is ignored both in full-sum training and in search
256260
'beam_size': beam_size,
257261
'from': "output_log_prob_wb", "input_type": "log_prob",
258262
"initial_output": 0,

0 commit comments

Comments
 (0)