From ba122a36435559b168ae30b913cde1506c07bd65 Mon Sep 17 00:00:00 2001
From: 10-zin <33179372+10-zin@users.noreply.github.com>
Date: Mon, 15 Oct 2018 16:14:23 +0530
Subject: [PATCH] Correct the RNN tutorial to use nn.RNN instead of nn.LSTM

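The tutorial builds a model class named RNN but wires an nn.LSTM inside it.
This patch switches the module to nn.RNN so the code matches its name.
Unlike nn.LSTM, which carries both a hidden state and a cell state and
therefore expects an (h0, c0) tuple in its forward call, nn.RNN keeps only
a hidden state, so the patched forward passes h0 alone.

A minimal standalone sketch (not part of the patch) contrasting the two
call signatures; the shapes below are illustrative and only mirror the
tutorial's 28-step, 28-feature MNIST rows:

    import torch
    import torch.nn as nn

    batch_size, seq_length, input_size = 4, 28, 28
    hidden_size, num_layers = 128, 2
    x = torch.randn(batch_size, seq_length, input_size)
    h0 = torch.zeros(num_layers, batch_size, hidden_size)

    # nn.RNN takes a single hidden-state tensor...
    rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
    out, hn = rnn(x, h0)

    # ...while nn.LSTM also carries a cell state, hence the (h0, c0) tuple.
    lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
    c0 = torch.zeros(num_layers, batch_size, hidden_size)
    out, (hn, cn) = lstm(x, (h0, c0))

    print(out.shape)  # torch.Size([4, 28, 128]) in both cases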
---
 tutorials/02-intermediate/recurrent_neural_network/main.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tutorials/02-intermediate/recurrent_neural_network/main.py b/tutorials/02-intermediate/recurrent_neural_network/main.py
index 9b8685ca..c37ac4b4 100644
--- a/tutorials/02-intermediate/recurrent_neural_network/main.py
+++ b/tutorials/02-intermediate/recurrent_neural_network/main.py
@@ -42,7 +42,7 @@ def __init__(self, input_size, hidden_size, num_layers, num_classes):
         super(RNN, self).__init__()
         self.hidden_size = hidden_size
         self.num_layers = num_layers
-        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
+        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
         self.fc = nn.Linear(hidden_size, num_classes)
     
     def forward(self, x):
@@ -51,7 +51,7 @@ def forward(self, x):
         c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
         
-        # Forward propagate LSTM
-        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size)
+        # Forward propagate RNN
+        out, _ = self.rnn(x, h0)  # out: tensor of shape (batch_size, seq_length, hidden_size)
         
         # Decode the hidden state of the last time step
         out = self.fc(out[:, -1, :])
@@ -99,4 +99,4 @@ def forward(self, x):
     print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 
 
 # Save the model checkpoint
-torch.save(model.state_dict(), 'model.ckpt')
\ No newline at end of file
+torch.save(model.state_dict(), 'model.ckpt')