diff --git a/examples/seq2seq_exposure_bias/interpolation_decoder.py b/examples/seq2seq_exposure_bias/interpolation_decoder.py index a0936fc6..382409d3 100644 --- a/examples/seq2seq_exposure_bias/interpolation_decoder.py +++ b/examples/seq2seq_exposure_bias/interpolation_decoder.py @@ -110,12 +110,12 @@ def step(self, time, inputs, state, name=None): logits, sample_ids, wrapper_outputs, attention_scores, attention_context) - return (outputs, wrapper_state) + return (outputs, [decoded_ids, wrapper_state]) def next_inputs(self, time, outputs, state): (finished, next_inputs, next_state) = self._helper.next_inputs( time=time, outputs=outputs.logits, - state=[state[0], state], + state=state, sample_ids=outputs.sample_id) return (finished, next_inputs, next_state)