diff --git a/ep7.ipynb b/ep7.ipynb index a54bc04..1839c13 100644 --- a/ep7.ipynb +++ b/ep7.ipynb @@ -294,7 +294,8 @@ ], "source": [ "# here's one it gets right\n", - "print (\"Predicted %d, Label: %d\" % (classifier.predict(test_data[0]), test_labels[0]))\n", + "prediction = classifier.predict(np.array([test_data[0]], dtype=float), as_iterable=False)\n", + "print (\"Predicted %d, Label: %d\" % (prediction, test_labels[0]))\n", "display(0)" ] }, @@ -325,7 +326,8 @@ ], "source": [ "# and one it gets wrong\n", - "print (\"Predicted %d, Label: %d\" % (classifier.predict(test_data[8]), test_labels[8]))\n", + "prediction = classifier.predict(np.array([test_data[0]], dtype=float), as_iterable=False)\n", + "print (\"Predicted %d, Label: %d\" % (prediction, test_labels[8]))\n", "display(8)" ] }, @@ -358,7 +360,7 @@ } ], "source": [ - "weights = classifier.weights_\n", + "weights = classifier.get_variable_value(\"linear//weight\")\n", "f, axes = plt.subplots(2, 5, figsize=(10,4))\n", "axes = axes.reshape(-1)\n", "for i in range(len(axes)):\n",