From af62100cc8161fd0183e36b715e3efe1539ea886 Mon Sep 17 00:00:00 2001
From: Markus Fleischhacker
Date: Sat, 7 Dec 2019 13:53:01 +0100
Subject: [PATCH] Fix forward pass of BiRNN

Previously, forward() passed the last time step of the combined LSTM
output to the fully connected layer.

Split the LSTM output into two halves along its last dimension, take the
last time step of the first half and the first time step of the second
half, concatenate them, and pass the result to the fully connected layer.
The in-code comments describe the halves as the forward and backward
directions of the LSTM.

---
 .../src/bi_rnn.cpp | 20 ++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

NOTE(review): the line breaks and code indentation of this patch were
restored from a whitespace-collapsed copy. The 4-space indentation is an
assumption; confirm it against the target file (or apply with
--ignore-whitespace).

diff --git a/tutorials/intermediate/bidirectional_recurrent_neural_network/src/bi_rnn.cpp b/tutorials/intermediate/bidirectional_recurrent_neural_network/src/bi_rnn.cpp
index ae73c05..2bf8160 100644
--- a/tutorials/intermediate/bidirectional_recurrent_neural_network/src/bi_rnn.cpp
+++ b/tutorials/intermediate/bidirectional_recurrent_neural_network/src/bi_rnn.cpp
@@ -10,10 +10,20 @@ BiRNNImpl::BiRNNImpl(int64_t input_size, int64_t hidden_size, int64_t num_layers
 }
 
 torch::Tensor BiRNNImpl::forward(torch::Tensor x) {
-    auto out = lstm->forward(x)
-        .output
-        .slice(1, -1)
-        .squeeze(1);
-    out = fc->forward(out);
+    auto out = lstm->forward(x).output; // out: tensor of shape (batch_size, sequence_length, 2 * hidden_size)
+
+    // Concatenate the last hidden state of forward LSTM and first hidden state of backward LSTM
+    //
+    // Source: Translated from python code at
+    // https://github.com/yunjey/pytorch-tutorial/pull/174/commits/8c0897ee93fed8d9b352d33a60c1f931c9be5351
+    auto out_directions = out.chunk(2, 2);
+    // Last hidden state of forward direction output
+    auto out_1 = out_directions[0].slice(1, -1).squeeze(1); // out_1: tensor of shape (batch_size, hidden_size)
+    // First hidden state of backward direction output
+    auto out_2 = out_directions[1].slice(1, 0, 1).squeeze(1); // out_2: tensor of shape (batch_size, hidden_size)
+    auto out_cat = torch::cat({out_1, out_2}, 1); // out_cat: tensor of shape (batch_size, 2 * hidden_size)
+
+    out = fc->forward(out_cat);
     return torch::log_softmax(out, 1);
 }
+