Skip to content

Commit 5f28bb5

Browse files
authored
Merge pull request #13 from mfl28/bidirectional-rnn-fix
Fix forward pass of BiRNN
2 parents f09f44d + af62100 commit 5f28bb5

File tree

1 file changed

+15
-5
lines changed
  • tutorials/intermediate/bidirectional_recurrent_neural_network/src

1 file changed

+15
-5
lines changed

tutorials/intermediate/bidirectional_recurrent_neural_network/src/bi_rnn.cpp

+15-5
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,20 @@ BiRNNImpl::BiRNNImpl(int64_t input_size, int64_t hidden_size, int64_t num_layers
1010
}
1111

1212
torch::Tensor BiRNNImpl::forward(torch::Tensor x) {
13-
auto out = lstm->forward(x)
14-
.output
15-
.slice(1, -1)
16-
.squeeze(1);
17-
out = fc->forward(out);
13+
auto out = lstm->forward(x).output; // out: tensor of shape (batch_size, sequence_length, 2 * hidden_size)
14+
15+
// Concatenate the last hidden state of forward LSTM and first hidden state of backward LSTM
16+
//
17+
// Source: Translated from python code at
18+
// https://github.com/yunjey/pytorch-tutorial/pull/174/commits/8c0897ee93fed8d9b352d33a60c1f931c9be5351
19+
auto out_directions = out.chunk(2, 2);
20+
// Last hidden state of forward direction output
21+
auto out_1 = out_directions[0].slice(1, -1).squeeze(1); // out_1: tensor of shape (batch_size, hidden_size)
22+
// First hidden state of backward direction output
23+
auto out_2 = out_directions[1].slice(1, 0, 1).squeeze(1); // out_2: tensor of shape (batch_size, hidden_size)
24+
auto out_cat = torch::cat({out_1, out_2}, 1); // out_cat: tensor of shape (batch_size, 2 * hidden_size)
25+
26+
out = fc->forward(out_cat);
1827
return torch::log_softmax(out, 1);
1928
}
29+

0 commit comments

Comments
 (0)