Commit 053d136

continue pretraining script (dim mismatch issues)

stevenabreu7 committed Apr 16, 2024
1 parent ed19710 commit 053d136
Showing 2 changed files with 160 additions and 24 deletions.
1 change: 1 addition & 0 deletions paper/02_cnn/.gitignore
@@ -1 +1,2 @@
nmnist
ann_pretraining/data/MNIST
183 changes: 159 additions & 24 deletions paper/02_cnn/ann_pretraining/ann_to_snn_conversion.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -16,37 +16,113 @@
"from sinabs.from_torch import from_model\n",
"import sinabs.layers as sl\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running on device: `mps`\n"
]
}
],
"source": [
"device = \"cpu\"\n",
"if torch.cuda.is_available():\n",
" device = \"cuda\"\n",
"elif torch.backends.mps.is_available():\n",
" device = \"mps\"\n",
"print(f\"Running on device: `{device}`\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.weight torch.Size([16, 2, 5, 5])\n",
"2.weight torch.Size([16, 16, 3, 3])\n",
"5.weight torch.Size([8, 16, 3, 3])\n",
"9.weight torch.Size([256, 128])\n",
"11.weight torch.Size([10, 256])\n"
]
}
],
"source": [
"params = torch.load(\"pretrained_ann_weights.pth\")\n",
"for k, v in params.items():\n",
" print(k, v.shape)"
]
},
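A quick aside on the checkpoint above (not part of the commit): the first conv weight has shape [16, 2, 5, 5], i.e. it was trained on two input channels, presumably the ON/OFF polarity channels of N-MNIST frames, while the cell below builds the network with N_INPUTS = 1 for MNIST. A minimal sanity check, assuming the same file path the notebook uses:

    import torch

    params = torch.load("pretrained_ann_weights.pth")
    in_channels = params["0.weight"].shape[1]  # dim 1 of a Conv2d weight is in_channels
    print(f"checkpoint expects {in_channels} input channel(s)")  # prints 2

This one-channel vs. two-channel discrepancy is one half of the "dim mismatch issues" named in the commit message; the flattened feature size is the other.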
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Sequential(\n",
" (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): ReLU()\n",
" (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (3): ReLU()\n",
" (4): SumPool2d(norm_type=1, kernel_size=(2, 2), stride=None, ceil_mode=False)\n",
" (5): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (6): ReLU()\n",
" (7): SumPool2d(norm_type=1, kernel_size=(2, 2), stride=None, ceil_mode=False)\n",
" (8): Flatten(start_dim=1, end_dim=-1)\n",
" (9): Linear(in_features=128, out_features=256, bias=False)\n",
" (10): ReLU()\n",
" (11): Linear(in_features=256, out_features=10, bias=False)\n",
" (12): ReLU()\n",
")"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N_INPUTS = 1\n",
"ann = nn.Sequential(\n",
" nn.Conv2d(1, 20, 5, 1, bias=False),\n",
" nn.Conv2d(\n",
" N_INPUTS, 16, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1), bias=False\n",
" ), # 16, 18, 18\n",
" nn.ReLU(),\n",
" nn.AvgPool2d(2, 2),\n",
" nn.Conv2d(20, 32, 5, 1, bias=False),\n",
" nn.Conv2d(\n",
" 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
" ), # 8, 18,18\n",
" nn.ReLU(),\n",
" nn.AvgPool2d(2, 2),\n",
" nn.Conv2d(32, 128, 3, 1, bias=False),\n",
" sl.SumPool2d(kernel_size=(2, 2)), # 8, 17,17\n",
" nn.Conv2d(\n",
" 16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
" ), # 8, 9, 9\n",
" nn.ReLU(),\n",
" nn.AvgPool2d(2, 2),\n",
" nn.Flatten(),\n",
" nn.Linear(128, 500, bias=False),\n",
" sl.SumPool2d(kernel_size=(2, 2)),\n",
" nn.Flatten(), # 4, 4, 8 -> 128\n",
" nn.Linear(128, 256, bias=False),\n",
" nn.ReLU(),\n",
" nn.Linear(500, 10, bias=False),\n",
")"
" nn.Linear(256, 10, bias=False),\n",
" nn.ReLU(),\n",
")\n",
"ann"
]
},
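To see where the mismatch surfaces, it helps to trace shapes through the stack above. A sketch, assuming a 28x28 single-channel MNIST input (the inline comments in the cell annotate the 34x34 two-channel N-MNIST case instead):

    import torch

    with torch.no_grad():
        x = torch.zeros(1, 1, 28, 28)
        for layer in ann[:9]:  # conv/pool stack up to and including Flatten
            x = layer(x)
            print(type(layer).__name__, tuple(x.shape))

For 28x28 input the Flatten output is 8 * 3 * 3 = 72 features, while Linear(128, 256) expects 128. With a 2x34x34 N-MNIST frame (and a 2-channel first conv) the same stack gives 8 * 4 * 4 = 128, matching both that layer and the pretrained "9.weight" of shape [256, 128].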
{
"cell_type": "code",
"execution_count": null,
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
@@ -70,15 +146,56 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 39,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/2 [00:00<?, ?it/s]\n",
" 0%| | 0/1875 [00:00<?, ?it/s]\u001b[A\n",
" 0%| | 0/2 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([32, 1, 28, 28])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"ename": "RuntimeError",
"evalue": "linear(): input and weight.T shapes cannot be multiplied (32x72 and 128x256)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[39], line 20\u001b[0m\n\u001b[1;32m 18\u001b[0m data, target \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39mto(device), target\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28mprint\u001b[39m(data\u001b[38;5;241m.\u001b[39mshape)\n\u001b[0;32m---> 20\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mann\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m optim\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 23\u001b[0m loss \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mcross_entropy(output, target)\n",
"File \u001b[0;32m~/.pyenv/versions/3.10.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/.pyenv/versions/3.10.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/.pyenv/versions/3.10.10/lib/python3.10/site-packages/torch/nn/modules/container.py:217\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 217\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 218\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
"File \u001b[0;32m~/.pyenv/versions/3.10.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/.pyenv/versions/3.10.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/.pyenv/versions/3.10.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:116\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mRuntimeError\u001b[0m: linear(): input and weight.T shapes cannot be multiplied (32x72 and 128x256)"
]
}
],
"source": [
"mnist_train = MNIST(\"./data\", train=True, is_spiking=False)\n",
"train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)\n",
"train_loader = DataLoader(mnist_train, batch_size=32, shuffle=True)\n",
"\n",
"mnist_test = MNIST(\"./data\", train=False, is_spiking=False)\n",
"test_loader = DataLoader(mnist_test, batch_size=128, shuffle=False)\n",
"test_loader = DataLoader(mnist_test, batch_size=32, shuffle=False)\n",
"\n",
"ann = ann.to(device)\n",
"ann.train()\n",
@@ -87,15 +204,19 @@
"\n",
"n_epochs = 2\n",
"\n",
"losses = []\n",
"\n",
"for n in tqdm(range(n_epochs)):\n",
" for data, target in iter(train_loader):\n",
" for data, target in tqdm(iter(train_loader)):\n",
" data, target = data.to(device), target.to(device)\n",
" print(data.shape)\n",
" output = ann(data)\n",
" optim.zero_grad()\n",
"\n",
" loss = F.cross_entropy(output, target)\n",
" loss.backward()\n",
" optim.step()\n"
" losses.append(loss.item())\n",
" optim.step()"
]
},
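The traceback above is the dimension mismatch in action: the batch flattens to 32x72, but layer 9 carries a 128x256 weight. Two plausible ways out, neither of which is necessarily the fix this commit settles on: pad/convert the MNIST input to the checkpoint's 2x34x34 format, or rebuild the first Linear to match the actual flattened size. A sketch of the latter, assuming the dummy-forward trick from above:

    import torch
    import torch.nn as nn

    with torch.no_grad():
        n_features = ann[:9](torch.zeros(1, N_INPUTS, 28, 28).to(device)).shape[1]  # 72 here

    ann[9] = nn.Linear(n_features, 256, bias=False).to(device)  # replaces Linear(128, 256)

Note that swapping the layer makes the pretrained "9.weight" (shape [256, 128]) unloadable, so it only helps if that layer is retrained; matching the input format to the checkpoint keeps all pretrained weights usable.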
{
@@ -223,8 +344,22 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
