diff --git a/NLI.ipynb b/NLI.ipynb index d67c7fe..d3a81e3 100644 --- a/NLI.ipynb +++ b/NLI.ipynb @@ -588,8 +588,8 @@ "source": [ "# For making predictions at test time\n", "def predict(model: torch.nn.Module, sents: torch.Tensor) -> List:\n", - " logits = model(sents)\n", - " return list(torch.argmax(logits, axis=2).squeeze().numpy())" + " logits = model(sents.to(device))\n", + " return list(torch.Tensor.cpu(torch.argmax(logits, axis=2).squeeze()).numpy())" ] }, { @@ -713,7 +713,7 @@ " all_preds = []\n", " all_labels = []\n", " for sents, labels in tqdm(zip(dev_sents, dev_labels), total=len(dev_sents)):\n", - " pred = predict(model, sents).cpu()\n", + " pred = predict(model, sents)\n", " all_preds.extend(pred)\n", " all_labels.extend(list(labels.cpu().numpy()))\n", "\n",