diff --git a/models/FCModel.py b/models/FCModel.py index c275b5b9..604b1140 100644 --- a/models/FCModel.py +++ b/models/FCModel.py @@ -59,10 +59,10 @@ def __init__(self, opt): self.ss_prob = 0.0 # Schedule sampling probability - self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) - self.core = LSTMCore(opt) - self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) - self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) + self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size).to(device) + self.core = LSTMCore(opt).to(device) + self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size).to(device) + self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1).to(device) self.init_weights()