From 4474c7eb5c6358e43b86dbbde4a410aee4fefacc Mon Sep 17 00:00:00 2001 From: Utsav Krishnan Date: Mon, 13 May 2019 17:03:58 +0530 Subject: [PATCH] Use tensor.item() for scalar conversion To handle IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number --- Capsule Network.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Capsule Network.ipynb b/Capsule Network.ipynb index f3cd682..654867d 100644 --- a/Capsule Network.ipynb +++ b/Capsule Network.ipynb @@ -306,7 +306,7 @@ " loss.backward()\n", " optimizer.step()\n", "\n", - " train_loss += loss.data[0]\n", + " train_loss += loss.item()\n", " \n", " if batch_id % 100 == 0:\n", " print \"train accuracy:\", sum(np.argmax(masked.data.cpu().numpy(), 1) == \n", @@ -327,7 +327,7 @@ " output, reconstructions, masked = capsule_net(data)\n", " loss = capsule_net.loss(data, output, target, reconstructions)\n", "\n", - " test_loss += loss.data[0]\n", + " test_loss += loss.item()\n", " \n", " if batch_id % 100 == 0:\n", " print \"test accuracy:\", sum(np.argmax(masked.data.cpu().numpy(), 1) == \n",