diff --git a/AML_lab_1&2_task_2.ipynb b/AML_lab_1&2_task_2.ipynb new file mode 100644 index 0000000..acaf65b --- /dev/null +++ b/AML_lab_1&2_task_2.ipynb @@ -0,0 +1,582 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gq8hC9S2WyIM" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch as tc\n", + "from torch import nn\n", + "from torchvision import datasets\n", + "from torchvision import transforms\n", + "import matplotlib.pyplot as plt\n", + "from torch import optim\n", + "import torch.nn.functional as F" + ] + }, + { + "cell_type": "code", + "source": [ + "# Transforms images to a PyTorch Tensor\n", + "tensor_transform = transforms.ToTensor()\n", + "\n", + "dataset = datasets.FashionMNIST(root = \"./data\",\n", + "\t\t\t\t\t\ttrain = True,\n", + "\t\t\t\t\t\tdownload = True,\n", + "\t\t\t\t\t\ttransform = tensor_transform)\n", + "\n", + "loader = tc.utils.data.DataLoader(dataset = dataset,\n", + "\t\t\t\t\t\t\t\t\tbatch_size = 32,\n", + "\t\t\t\t\t\t\t\t\tshuffle = True, num_workers = 3)\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Z7C9S_5YXQXU", + "outputId": "3ef9a131-4311-40d9-996c-97fa86562bee" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n", + "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 26421880/26421880 [00:02<00:00, 11830179.71it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw\n", + "\n", + "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n", + "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 29515/29515 [00:00<00:00, 200436.32it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw\n", + "\n", + "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n", + "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 4422102/4422102 [00:01<00:00, 3739051.57it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw\n", + "\n", + "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n", + "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 5148/5148 [00:00<00:00, 18329606.95it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw\n", + "\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 3 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", + " warnings.warn(_create_warning_msg(\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "for data in loader:\n", + "\n", + " plt.imshow(data[0][0].reshape((28, 28)))\n", + " print(data[0].shape)\n", + " print(data[1])\n", + " break" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 482 + }, + "id": "UG-vfDi7XR9A", + "outputId": "09e0948c-8318-44ec-c43b-1402a9644f27" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([32, 1, 28, 28])\n", + "tensor([0, 8, 3, 2, 0, 7, 2, 2, 7, 2, 8, 2, 1, 3, 6, 5, 8, 6, 7, 3, 5, 4, 0, 0,\n", + " 9, 0, 2, 0, 4, 9, 3, 4])\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAhg0lEQVR4nO3df3DU9b3v8dfuJlkIJBtDyC8JGPAHrfzoKYU0VSmWXCCd44ByWlHvDDgWrjY4RWp10lHRns6kxRnr0UNxzrkt1HvFX+cIXL099CqaMLZAC8phOK0poVHCQAJymiwE8mv3c//gmt4VkH4+7O4nP56Pme8M2f2+833vZ7/klW92807AGGMEAECaBX03AAAYngggAIAXBBAAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF5k+G7g0+LxuI4ePaqcnBwFAgHf7QAALBljdOrUKZWWlioYvPh1zoALoKNHj6qsrMx3GwCAy9TS0qJx48Zd9P4BF0A5OTmSpBv1dWUoM7UHc73CcpleFAzZ18Rj9jVp1LFklnXNn//LWeua//OVf7auqfpf91vXSFL55i7rmsBvf29dE7oiYl1z9m+usq5p+bu4dY0kvXzjP1nXLNn5LeuaCf/D/v9gaMe/W9ekUyDD/suq6etLQSf+9KlX7+qX/V/PLyZlAbRu3To9+eSTam1t1fTp0/Xss89q1qxLf8H65MduGcpURmCABpAcAijgEECBgf0SXShrhHVNMNt+7XJy7NchOMK+N0ly+NqhgMN5GgpmWddkZDqs90i3ABrtsubZ9v1lZDgEUKq/LlymQMAhgIbayw3/77/5pV5GSclXuJdfflmrV6/WmjVr9N5772n69OmaP3++jh8/norDAQAGoZQE0FNPPaXly5fr7rvv1uc//3k999xzys7O1s9//vNUHA4AMAglPYB6enq0d+9eVVVV/eUgwaCqqqq0c+fO8/bv7u5WNBpN2AAAQ1/SA+jjjz9WLBZTUVFRwu1FRUVqbW09b/+6ujpFIpH+jXfAAcDw4P1V7traWnV0dPRvLS0tvlsCAKRB0t8FV1BQoFAopLa2toTb29raVFxcfN7+4XBY4XA42W0AAAa4pF8BZWVlacaMGdq+fXv/bfF4XNu3b1dlZWWyDwcAGKRS8ntAq1ev1tKlS/WlL31Js2bN0tNPP63Ozk7dfffdqTgcAGAQSkkA3X777Tpx4oQee+wxtba26gtf+IK2bdt23hsTAADDV8AYl7kyqRONRhWJRDRHC1M/CWGAy7hqvHXNH799pXXNNTM/sq6RpNws+7E1H0WvsK4ZM/KMdc0LV79mXSNJowPpeT0y7jJNw0G36XWqu7XxG9Y1Z3rt/7/OHHvYuqbp1Fjrmg/2TrCukaRJr9qfe9q13+lYQ0mf6VW9tqqjo0O5ubkX3c/7u+AAAMMTAQQA8IIAAgB4QQABALwggAAAXhBAAAAvCCAAgBcEEADACwIIAOAFAQQA8IIAAgB4QQABALwY3sNIgyG3unjMuuRI7Vesa964d611TWPvGOuapw9XWddI0ocf51vX5I0+a13T0TnSuqYr6jZUtKik3bomb4T9Y+qN2597XX32w+tP/DnHukaSssJ91jWRbPt1ONVl/zytvK7BuuaqzBPWNZI0NevP1jXz/vEh65rStb+xrknn1y9bDCMFAAxoBBAAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF4QQAAALwggAIAXBBAAwAsCCADgBQEEAPCCAAIAeDG8p2FDktT80jSnulcq/tm65u3Oz1nXjA51WdfsP11mXSNJH/eMsq453Ws/0TkYsP9vF5R9TXZGj3WNJE0c9bF1TbRvhHXNfyvYYV3zjf/+Xeuash86TJuGM6ZhAwAGNAIIAOAFAQQA8IIAAgB4QQABALwggAAAXhBAAAAvCCAAgBcEEADACwIIAOAFAQQA8IIAAgB4keG7gcHoxH2V1jWBr5+0rjH/e4x1Tckbh61rSvOj1jWS9IWw/RDOmD6wrjnYU2Rd880xu61rJKnXhKxrRgXsB362x7Ota0YEeq1rsoPd1jWSFHIYfNoeH2ldU55hv96Zp61L1Pqdr9gXSfrcN+3P15wM+zX/4MdTrGuyX3M7xwcSroAAAF4QQAAALwggAIAXBBAAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF4QQAAALwggAIAXBBAAwIuAMcZ+6mAKRaNRRSIRzdFCZQQyfbdzQX/7H3+2rlmR12Rd09TbZ11zfZb9QMh/6ii1rpGkLX93k3XNH76Ta13zDze/YF3TZdzOnRN99v3Fjf33cZkB++c2MxCzrhkRtB9g6splkGtn3H6gbU1ei3WNqx1d9jVfzLIvWtJ0q3VN75xj1jXp0md6Va+t6ujoUG7uxf9PcQUEAPCCAAIAeJH0AHr88ccVCAQStsmTJyf7MACAQS4lf5Du+uuv11tvvfWXg2Twd+8AAIlSkgwZGRkqLi5OxacGAAwRKXkN6ODBgyotLdXEiRN111136fDhi/+Z6O7ubkWj0YQNADD0JT2AKioqtHHjRm3btk3r169Xc3OzbrrpJp06deqC+9fV1SkSifRvZWVlyW4JADAAJT2Aqqur9Y1vfEPTpk3T/Pnz9ctf/lLt7e165ZVXLrh/bW2tOjo6+reWlvS9xx8A4E/K3x2Ql5ena6+9Vk1NF/5FzHA4rHDY/pfRAACDW8p/D+j06dM6dOiQSkpKUn0oAMAgkvQAevDBB9XQ0KAPP/xQv/nNb3TrrbcqFArpjjvuSPahAACDWNJ/BHfkyBHdcccdOnnypMaOHasbb7xRu3bt0tixY5N9KADAIJb0AHrppZeS/SkHnD+csf9x4r9knE5BJ+c7EbMflDprRLPTsYq3tFvXrP7dN52OZetIzxinutJM+/VzGWAac/jhQ1D2c4NPxUZY10hu/bkMS3UZ5Pqvp+3X22X4qyS19uVZ1xztPWtd82z5q9Y19+pG65qBhllwAAAvCCAAgBcEEADACwIIAOAFAQQA8IIAAgB4QQABALwggAAAXhBAAAAvCCAAgBcEEADACwIIAOBFyv8g3UAWunaSU92CvH+zrukymdY1IwK91jUxBaxr2uMjrWtcvX3jP1rXfNg32rpmyoj0/WXdvNAZ6xqX4Zguwz674vbnnSSFAnHrml6Tni8neUH79R6XYT8gVJL2OHyPftxhOO3mU9OsazKKi6xrJKmvtc2pLhW4AgIAeEEAAQC8IIAAAF4QQAAALwggAIAXBBAAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF4QQAAALwggAIAXw3oadud1Y5zqNh67wbrmiiz7abyTsk9Y11TlHLCuORMPW9dI0s+P3mRdMzqz27rmTJ/bRGcXwYCxrumLh1LQyfniDpPOXWVn9FjXZAXtJ3y7uHPsbuua/Z1lTsd6/qMvW9ec6rL//3RlpMO65sNvXWVdI0llP2QaNgBgmCOAAABeEEAAAC8IIACAFwQQAMALAggA4AUBBADwggACAHhBAAEAvCCAAABeEEAAAC8IIACAF8N6GOmRuW75+7f5f7Ku+dfDf2Nd87tW+wGK/7NvpnXN1WM/tq6RpAMfllrXhLN7rWvicfshnPGY23MbcBhGKpeadDFuA0xDGXHrGpe1izk8Tz1x+y9bo0L2w1UlaVSmfd31+cesa2bkfGRd81TWBOuagYYrIACAFwQQAMALAggA4AUBBADwggACAHhBAAEAvCCAAABeEEAAAC8IIACAFwQQAMALAggA4AUBBADwYlgPI7169e+c6v7lv86zrjn+lZh1zcSrW61rKguarWsiGWesayQp2j3CuqY7FrKuCTkMuexxOM5QFHMY5CpJo7Lsh8a6CDo8t1ePOmFd86fOAusaSfrjH+0H7p7caz9EuGXraOuaCSd2WtcMNFwBAQC8IIAAAF5YB9COHTt0yy23qLS0VIFAQFu2bEm43xijxx57TCUlJRo5cqSqqqp08ODBZPULABgirAOos7NT06dP17p16y54/9q1a/XMM8/oueee0+7duzVq1CjNnz9fXV1dl90sAGDosH4TQnV1taqrqy94nzFGTz/9tB555BEtXLhQkvT888+rqKhIW7Zs0ZIlSy6vWwDAkJHU14Cam5vV2tqqqqqq/tsikYgqKiq0c+eF37HR3d2taDSasAEAhr6kBlBr67m3DRcVFSXcXlRU1H/fp9XV1SkSifRvZWX2b2EEAAw+3t8FV1tbq46Ojv6tpaXFd0sAgDRIagAVFxdLktra2hJub2tr67/v08LhsHJzcxM2AMDQl9QAKi8vV3FxsbZv395/WzQa1e7du1VZWZnMQwEABjnrd8GdPn1aTU1N/R83Nzdr3759ys/P1/jx47Vq1Sr98Ic/1DXXXKPy8nI9+uijKi0t1aJFi5LZNwBgkLMOoD179ujmm2/u/3j16tWSpKVLl2rjxo166KGH1NnZqRUrVqi9vV033nijtm3bphEj7OeGAQCGroAxxn4aYApFo1FFIhHN0UJlBDJ9tzMsnPyW249HZ6zYZ13z7x9f6XQsW70xt58uG2M/vDPgMFBzoAsF7R9Tb5/9ANgvFdu/6Wjn5unWNVf+6DfWNXDXZ3pVr63q6Oj4zNf1vb8LDgAwPBFAAAAvCCAAgBcEEADACwIIAOAFAQQA8IIAAgB4QQABALwggAAAXhBAAAAvCCAAgBcEEADACwIIAOCF9Z9jGFKC9tN7JSkQsq8zvT1Ox0qH/N+fcaobk9lpXRN3mDadGYpZ17hOw3aRrgna8bj9Y8oIxa1rJLc1j8Udntug/XH6stM4fTxg/5jcjuNwvsbt126g4QoIAOAFAQQA8IIAAgB4QQABALwggAAAXhBAAAAvCCAAgBcEEADACwIIAOAFAQQA8IIAAgB4QQABALwY3sNIHYf5mTQNAQxk2D89pq/PuqZ7TNi6xlXQYQinyzjIoOMMyZjDnEuXwaIu/RmXtXOokdzWPF2M2wzh9DEOa24G/2BRF1wBAQC8IIAAAF4QQAAALwggAIAXBBAAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF4QQAAALwggAIAXw3sY6UAXcpi66DCMNOAygVNSLE3fv2QE49Y13Y7HMsZ+DKfLwE+3GuuStHIZsHrkTJ51TWB4zu0ckrgCAgB4QQABALwggAAAXhBAAAAvCCAAgBcEEADACwIIAOAFAQQA8IIAAgB4QQABALwggAAAXhBAAAAvGEY6gAUcpk86jRV1m0WqkOyHhLoMFnUZ3Okq6NCfC5ehpy6CjmuXGbKf+NndZ//lJO6wDvFM6xIMUFwBAQC8IIAAAF5YB9COHTt0yy23qLS0VIFAQFu2bEm4f9myZQoEAgnbggULktUvAGCIsA6gzs5OTZ8+XevWrbvoPgsWLNCxY8f6txdffPGymgQADD3WrxpWV1erurr6M/cJh8MqLi52bgoAMPSl5DWg+vp6FRYW6rrrrtN9992nkydPXnTf7u5uRaPRhA0AMPQlPYAWLFig559/Xtu3b9ePf/xjNTQ0qLq6WrHYhd/WWVdXp0gk0r+VlZUluyUAwACU9N8DWrJkSf+/p06dqmnTpmnSpEmqr6/X3Llzz9u/trZWq1ev7v84Go0SQgAwDKT8bdgTJ05UQUGBmpqaLnh/OBxWbm5uwgYAGPpSHkBHjhzRyZMnVVJSkupDAQAGEesfwZ0+fTrhaqa5uVn79u1Tfn6+8vPz9cQTT2jx4sUqLi7WoUOH9NBDD+nqq6/W/Pnzk9o4AGBwsw6gPXv26Oabb+7/+JPXb5YuXar169dr//79+sUvfqH29naVlpZq3rx5+vu//3uFw+HkdQ0AGPSsA2jOnDky5uIDDn/1q19dVkNIP9dZn66DLtMh5DhUNOAwHDMWT89gUReug1xdntu4w6FcjtOXPXDPO9hhFhwAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF4QQAAALwggAIAXBBAAwAsCCADgBQEEAPCCAAIAeEEAAQC8SPqf5MbwEXeYHJ2uOcauU6BN3P57sqDDMGynad1Ovbmtg8tz6yIjELOuGcBD2GGJKyAAgBcEEADACwIIAOAFAQQA8IIAAgB4QQABALwggAAAXhBAAAAvCCAAgBcEEADACwIIAOAFAQQA8IJhpHDmMujSZcily3Fch3D2xey/J8sIOQwWdeA6YDVdMh3WoSdu/yUo6z/5vnmo4JkEAHhBAAEAvCCAAABeEEAAAC8IIACAFwQQAMALAggA4AUBBADwggACAHhBAAEAvCCAAABeEEAAAC8YRjqAGZOe4ZOBPrdhmi6DRdN1nJDj4M5wZp91TbrWwXXAqotYPD3fm2YE7M+9jLMpaARecAUEAPCCAAIAeEEAAQC8IIAAAF4QQAAALwggAIAXBBAAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF4wjBRymAfpzGVIaEbQvkHXwZ0hh2P1xEJOx7LlMvLUdR3SNWB10ugT1jV/ik9KQScXEXD4Ht3Ekt/HEMUVEADACwIIAOCFVQDV1dVp5syZysnJUWFhoRYtWqTGxsaEfbq6ulRTU6MxY8Zo9OjRWrx4sdra2pLaNABg8LMKoIaGBtXU1GjXrl1688031dvbq3nz5qmzs7N/nwceeECvv/66Xn31VTU0NOjo0aO67bbbkt44AGBws3oTwrZt2xI+3rhxowoLC7V3717Nnj1bHR0d+tnPfqZNmzbpa1/7miRpw4YN+tznPqddu3bpy1/+cvI6BwAMapf1GlBHR4ckKT8/X5K0d+9e9fb2qqqqqn+fyZMna/z48dq5c+cFP0d3d7ei0WjCBgAY+pwDKB6Pa9WqVbrhhhs0ZcoUSVJra6uysrKUl5eXsG9RUZFaW1sv+Hnq6uoUiUT6t7KyMteWAACDiHMA1dTU6MCBA3rppZcuq4Ha2lp1dHT0by0tLZf1+QAAg4PTL6KuXLlSb7zxhnbs2KFx48b1315cXKyenh61t7cnXAW1tbWpuLj4gp8rHA4rHA67tAEAGMSsroCMMVq5cqU2b96st99+W+Xl5Qn3z5gxQ5mZmdq+fXv/bY2NjTp8+LAqKyuT0zEAYEiwugKqqanRpk2btHXrVuXk5PS/rhOJRDRy5EhFIhHdc889Wr16tfLz85Wbm6v7779flZWVvAMOAJDAKoDWr18vSZozZ07C7Rs2bNCyZcskST/5yU8UDAa1ePFidXd3a/78+frpT3+alGYBAEOHVQAZc+nBhiNGjNC6deu0bt0656aQZumZOylJchmNGXSochkqKkkmTUM4XY4TcBws6iSenildH53Jt67py05BIxdj0jipdxhiFhwAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF4QQAAALwggAIAXBBAAwAsCCADgBQEEAPCCAAIAeEEAAQC8cPqLqBhi0jlk2WEKdJ+x/z4pQ+mbYpwVjKXlOC7r4CoYsn9MZ3szrWsyHKaWdxWmZ70lSX/FXwA4T8BhorrLcYYAroAAAF4QQAAALwggAIAXBBAAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF4QQAAALwggAIAXBBAAwAuGkUJ92aG0HSsUsB+6mJmmYZ+ugg6PyWUoazoHrMblMFDT5TgO66CcvuQ3Ai+4AgIAeEEAAQC8IIAAAF4QQAAALwggAIAXBBAAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF4QQAAALxhGOoAFAvaDGu3HYkrG8Sz44+lC6xqXwZ0xY/99UtBpJaRQ0H7gp8tjcqnpi9uvQ2/cbdCsS3/pWjsMHVwBAQC8IIAAAF4QQAAALwggAIAXBBAAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF4QQAAALwggAIAXDCOFjMPQU8ltkGSvw0DNWG+mdY3bI5IyQzHrmnQN1Iwb+0cVc1hvye15ynAYRurymEzM9dnFQMMVEADACwIIAOCFVQDV1dVp5syZysnJUWFhoRYtWqTGxsaEfebMmaNAIJCw3XvvvUltGgAw+FkFUENDg2pqarRr1y69+eab6u3t1bx589TZ2Zmw3/Lly3Xs2LH+be3atUltGgAw+Fm9CWHbtm0JH2/cuFGFhYXau3evZs+e3X97dna2iouLk9MhAGBIuqzXgDo6OiRJ+fn5Cbe/8MILKigo0JQpU1RbW6szZ85c9HN0d3crGo0mbACAoc/5bdjxeFyrVq3SDTfcoClTpvTffuedd2rChAkqLS3V/v379fDDD6uxsVGvvfbaBT9PXV2dnnjiCdc2AACDlHMA1dTU6MCBA3r33XcTbl+xYkX/v6dOnaqSkhLNnTtXhw4d0qRJk877PLW1tVq9enX/x9FoVGVlZa5tAQAGCacAWrlypd544w3t2LFD48aN+8x9KyoqJElNTU0XDKBwOKxwOOzSBgBgELMKIGOM7r//fm3evFn19fUqLy+/ZM2+ffskSSUlJU4NAgCGJqsAqqmp0aZNm7R161bl5OSotbVVkhSJRDRy5EgdOnRImzZt0te//nWNGTNG+/fv1wMPPKDZs2dr2rRpKXkAAIDBySqA1q9fL+ncL5v+/zZs2KBly5YpKytLb731lp5++ml1dnaqrKxMixcv1iOPPJK0hgEAQ4P1j+A+S1lZmRoaGi6rIQDA8MA0bOjsGLdfB5uQ/Z/WNT0x+1Nu0ugT1jXplK5p2C7HCcl+QrUkxRx+RbA7bv/cuvR3qHCMdY2rQIb9YzJ9fSnoZGhiGCkAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF4QQAAALwggAIAXBBAAwAsCCADgBQEEAPCCAAIAeMEw0gEsXUMNC7c2OdX9W0Gldc2oo/YDNY8GJlrXuM4HNQ7fkjnVBOxrXB5TwG0WqVNdsM++we5c+4UY8x9d1jWuTCyWtmMNR1wBAQC8IIAAAF4QQAAALwggAIAXBBAAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF4QQAAALwbcLDhjzs2T6lOv5DjPa6gIGPsFMMZ+fpyJ91jXSFKs234mV6zH4TGlaW6axCy4y6kzMfsGYz32C9HXZ3/eBU2vdc05Dk+Uw//boaZP59bbXGItAuZSe6TZkSNHVFZW5rsNAMBlamlp0bhx4y56/4ALoHg8rqNHjyonJ0eBQOJ3H9FoVGVlZWppaVFubq6nDv1jHc5hHc5hHc5hHc4ZCOtgjNGpU6dUWlqqYPDiPyIYcD+CCwaDn5mYkpSbmzusT7BPsA7nsA7nsA7nsA7n+F6HSCRyyX14EwIAwAsCCADgxaAKoHA4rDVr1igcDvtuxSvW4RzW4RzW4RzW4ZzBtA4D7k0IAIDhYVBdAQEAhg4CCADgBQEEAPCCAAIAeDFoAmjdunW66qqrNGLECFVUVOi3v/2t75bS7vHHH1cgEEjYJk+e7LutlNuxY4duueUWlZaWKhAIaMuWLQn3G2P02GOPqaSkRCNHjlRVVZUOHjzop9kUutQ6LFu27LzzY8GCBX6aTZG6ujrNnDlTOTk5Kiws1KJFi9TY2JiwT1dXl2pqajRmzBiNHj1aixcvVltbm6eOU+OvWYc5c+acdz7ce++9njq+sEERQC+//LJWr16tNWvW6L333tP06dM1f/58HT9+3HdraXf99dfr2LFj/du7777ru6WU6+zs1PTp07Vu3boL3r927Vo988wzeu6557R7926NGjVK8+fPV1eX/dDKgexS6yBJCxYsSDg/XnzxxTR2mHoNDQ2qqanRrl279Oabb6q3t1fz5s1TZ2dn/z4PPPCAXn/9db366qtqaGjQ0aNHddttt3nsOvn+mnWQpOXLlyecD2vXrvXU8UWYQWDWrFmmpqam/+NYLGZKS0tNXV2dx67Sb82aNWb69Om+2/BKktm8eXP/x/F43BQXF5snn3yy/7b29nYTDofNiy++6KHD9Pj0OhhjzNKlS83ChQu99OPL8ePHjSTT0NBgjDn33GdmZppXX321f58//OEPRpLZuXOnrzZT7tPrYIwxX/3qV813vvMdf039FQb8FVBPT4/27t2rqqqq/tuCwaCqqqq0c+dOj535cfDgQZWWlmrixIm66667dPjwYd8tedXc3KzW1taE8yMSiaiiomJYnh/19fUqLCzUddddp/vuu08nT5703VJKdXR0SJLy8/MlSXv37lVvb2/C+TB58mSNHz9+SJ8Pn16HT7zwwgsqKCjQlClTVFtbqzNnzvho76IG3DDST/v4448Vi8VUVFSUcHtRUZE++OADT135UVFRoY0bN+q6667TsWPH9MQTT+imm27SgQMHlJOT47s9L1pbWyXpgufHJ/cNFwsWLNBtt92m8vJyHTp0SN///vdVXV2tnTt3KhQK+W4v6eLxuFatWqUbbrhBU6ZMkXTufMjKylJeXl7CvkP5fLjQOkjSnXfeqQkTJqi0tFT79+/Xww8/rMbGRr322mseu0004AMIf1FdXd3/72nTpqmiokITJkzQK6+8onvuucdjZxgIlixZ0v/vqVOnatq0aZo0aZLq6+s1d+5cj52lRk1NjQ4cODAsXgf9LBdbhxUrVvT/e+rUqSopKdHcuXN16NAhTZo0Kd1tXtCA/xFcQUGBQqHQee9iaWtrU3FxsaeuBoa8vDxde+21ampq8t2KN5+cA5wf55s4caIKCgqG5PmxcuVKvfHGG3rnnXcS/nxLcXGxenp61N7enrD/UD0fLrYOF1JRUSFJA+p8GPABlJWVpRkzZmj79u39t8XjcW3fvl2VlZUeO/Pv9OnTOnTokEpKSny34k15ebmKi4sTzo9oNKrdu3cP+/PjyJEjOnny5JA6P4wxWrlypTZv3qy3335b5eXlCffPmDFDmZmZCedDY2OjDh8+PKTOh0utw4Xs27dPkgbW+eD7XRB/jZdeesmEw2GzceNG8/vf/96sWLHC5OXlmdbWVt+tpdV3v/tdU19fb5qbm82vf/1rU1VVZQoKCszx48d9t5ZSp06dMu+//755//33jSTz1FNPmffff9989NFHxhhjfvSjH5m8vDyzdetWs3//frNw4UJTXl5uzp4967nz5PqsdTh16pR58MEHzc6dO01zc7N56623zBe/+EVzzTXXmK6uLt+tJ819991nIpGIqa+vN8eOHevfzpw507/Pvffea8aPH2/efvtts2fPHlNZWWkqKys9dp18l1qHpqYm84Mf/MDs2bPHNDc3m61bt5qJEyea2bNne+480aAIIGOMefbZZ8348eNNVlaWmTVrltm1a5fvltLu9ttvNyUlJSYrK8tceeWV5vbbbzdNTU2+20q5d955x0g6b1u6dKkx5txbsR999FFTVFRkwuGwmTt3rmlsbPTbdAp81jqcOXPGzJs3z4wdO9ZkZmaaCRMmmOXLlw+5b9Iu9PglmQ0bNvTvc/bsWfPtb3/bXHHFFSY7O9vceuut5tixY/6aToFLrcPhw4fN7NmzTX5+vgmHw+bqq6823/ve90xHR4ffxj+FP8cAAPBiwL8GBAAYmgggAIAXBBAAwAsCCADgBQEEAPCCAAIAeEEAAQC8IIAAAF4QQAAALwggAIAXBBAAwAsCCADgxf8Fk5Z0ptjRNS8AAAAASUVORK5CYII=\n" + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "source": [ + "class Model_Linear(nn.Module):\n", + "\n", + " def __init__(self) -> None:\n", + "\n", + " super(Model_Linear, self).__init__()\n", + "\n", + " self.encoder = nn.Sequential(\n", + " nn.Flatten(),\n", + " nn.Linear(28*28, 128),\n", + " nn.BatchNorm1d(128),\n", + " nn.ReLU(),\n", + " nn.Linear(128, 60),\n", + " )\n", + "\n", + " self.decoder = nn.Sequential(\n", + " nn.Linear(60, 128),\n", + " nn.BatchNorm1d(128),\n", + " nn.ReLU(),\n", + " nn.Linear(128, 28*28),\n", + " nn.Sigmoid()\n", + " )\n", + "\n", + " def forward(self, x):\n", + "\n", + " encoded = self.encoder(x)\n", + " decoded = self.decoder(encoded)\n", + "\n", + " return decoded\n", + "\n", + "model_linear = Model_Linear().cuda()" + ], + "metadata": { + "id": "ZWDdOR_mXXZV" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#Loss function\n", + "CRITERION = nn.MSELoss()\n", + "\n", + "#Hyperparameters\n", + "EPOCHS = 20\n", + "LEARNING_RATE = 1e-3\n", + "WEIGHT_DECAY = 1e-8\n", + "\n", + "# Optimizers\n", + "OPTIMIZER = optim.Adam(model_linear.parameters(), lr=LEARNING_RATE, weight_decay = WEIGHT_DECAY)\n" + ], + "metadata": { + "id": "iAPTG9HzYAuO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from tqdm import tqdm\n", + "\n", + "conv_losses = []\n", + "conv_output = []\n", + "\n", + "for epoch in range(EPOCHS):\n", + " loss_full = 0\n", + " for img, label in tqdm(loader):\n", + " # img = img.reshape(-1, 1, 64, 64)\n", + "\n", + " # Forward pass\n", + " img = img.to(tc.float32)\n", + " img = img.cuda()\n", + "\n", + " outputs = model_linear(img)\n", + "\n", + " # print(img.shape, outputs.shape, decoded.shape)\n", + " loss = CRITERION(outputs, img.reshape((32, 28*28)))\n", + "\n", + " # Backward pass and optimization\n", + " OPTIMIZER.zero_grad()\n", + " loss.backward()\n", + " OPTIMIZER.step()\n", + "\n", + " loss_full += loss.item()\n", + "\n", + " loss_full = loss_full/len(loader)\n", + " conv_losses.append(loss_full)\n", + "\n", + " print(f'Epoch: {epoch + 1}, Loss: {loss_full:.7f}')\n", + " conv_output.append((epoch, img, outputs))\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uiRykfL2YBHD", + "outputId": "92b1cb36-63ec-4e78-d391-2932b9381d4e" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\r 0%| | 0/1875 [00:00