diff --git a/model.py b/model.py index 571fe32c8..c77ba93b1 100755 --- a/model.py +++ b/model.py @@ -383,8 +383,6 @@ def decode(self, decoder_input, attention_weights=None): cell_input, (self.attention_hidden, self.attention_cell)) self.attention_hidden = F.dropout( self.attention_hidden, self.p_attention_dropout, self.training) - self.attention_cell = F.dropout( - self.attention_cell, self.p_attention_dropout, self.training) attention_weights_cat = torch.cat( (self.attention_weights.unsqueeze(1), @@ -400,8 +398,6 @@ def decode(self, decoder_input, attention_weights=None): decoder_input, (self.decoder_hidden, self.decoder_cell)) self.decoder_hidden = F.dropout( self.decoder_hidden, self.p_decoder_dropout, self.training) - self.decoder_cell = F.dropout( - self.decoder_cell, self.p_decoder_dropout, self.training) decoder_hidden_attention_context = torch.cat( (self.decoder_hidden, self.attention_context), dim=1)