From a2e9318e113495e8fb692ad052c4ac7410aa0638 Mon Sep 17 00:00:00 2001 From: Gregor Lenz Date: Sat, 6 Jul 2024 21:48:51 +0200 Subject: [PATCH] Updated docs with snntorch to Norse example and dev docs (#109) * Updated docs with snntorch to Norse example and dev docs * Updated docs for porting --------- Co-authored-by: jegp --- docs/source/_config.yml | 11 +- docs/source/_toc.yml | 5 +- docs/source/api_design.md | 10 +- docs/source/{dev.md => contributing.md} | 0 docs/source/examples/snntorch_to_norse.ipynb | 929 +++++++++++++++++++ docs/source/nir_graph_example.svg | 250 +++++ docs/source/porting_nir.md | 115 +++ docs/source/usage.md | 2 +- flake.lock | 8 +- flake.nix | 5 +- 10 files changed, 1324 insertions(+), 11 deletions(-) rename docs/source/{dev.md => contributing.md} (100%) create mode 100644 docs/source/examples/snntorch_to_norse.ipynb create mode 100644 docs/source/nir_graph_example.svg create mode 100644 docs/source/porting_nir.md diff --git a/docs/source/_config.yml b/docs/source/_config.yml index 0129a78..3d06869 100644 --- a/docs/source/_config.yml +++ b/docs/source/_config.yml @@ -4,4 +4,13 @@ logo: ../logo_light.png repository: url: https://github.com/neuromorphs/nir - branch: main \ No newline at end of file + branch: main + path_to_book: docs/source + +execute: + execute_notebooks: off + +launch_buttons: + notebook_interface: "jupyterlab" + binderhub_url: "https://mybinder.org/v2/gh/neuromorphs/nir/main?urlpath=lab" + colab_url: "https://colab.research.google.com" \ No newline at end of file diff --git a/docs/source/_toc.yml b/docs/source/_toc.yml index 1882b8a..80178af 100644 --- a/docs/source/_toc.yml +++ b/docs/source/_toc.yml @@ -24,8 +24,9 @@ parts: - file: examples/snntorch/nir-conversion - file: examples/spinnaker2/import - file: examples/spyx/conversion + - file: examples/snntorch_to_norse - caption: Developer guide chapters: + - file: porting_nir - file: api_design - - file: dev - title: Contributing + - file: contributing diff --git a/docs/source/api_design.md b/docs/source/api_design.md index d523696..0675573 100644 --- a/docs/source/api_design.md +++ b/docs/source/api_design.md @@ -1,6 +1,6 @@ # API design -The reference implementation simply consists of a series of Python classes that *represent* [NIR structures](primitives). +NIR is simple: it consists of a series of objects that *represent* [NIR structures](primitives). In other words, they do not implement the functionality of the nodes, but simply represent the necessary parameters required to *eventually* evaluate the node. We chose Python because the language is straight-forward, known by most, and has excellent [dataclasses](https://docs.python.org/3/library/dataclasses.html) exactly for our purpose. @@ -16,6 +16,14 @@ In this example, we create a class that inherits from the parent [`NIRNode`](htt Instantiating the class is simply `MyNIRNode(np.array([...]))`. ## NIR Graphs and edges +```{figure} nir_graph_example.svg +--- +height: 200px +name: nir-graph-example +--- +An example of a NIR graph with four nodes: Input, Leaky-Integrator, Affine map, and Output. +``` + A collection of nodes is a `NIRGraph`, which is, you guessed it, a `NIRNode`. But the graph node is special in that it contains a number of named nodes (`.nodes`) and connections between them (`.edges`). The nodes are named because we need to uniquely distinguish them from each other, so `.nodes` is actually a dictionary (`Dict[str, NIRNode]`). 
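+
+As a concrete sketch, the four-node graph from the figure above can be built and saved like this (the parameter values are illustrative; the exact signatures are documented on the primitives page):
+
+```python
+import numpy as np
+import nir
+
+graph = nir.NIRGraph(
+    nodes={
+        "input": nir.Input(input_type=np.array([2])),  # two input channels
+        "li": nir.LI(tau=np.full(2, 0.01), r=np.ones(2), v_leak=np.zeros(2)),
+        "affine": nir.Affine(weight=np.eye(2), bias=np.zeros(2)),
+        "output": nir.Output(output_type=np.array([2])),
+    },
+    edges=[("input", "li"), ("li", "affine"), ("affine", "output")],
+)
+nir.write("my_graph.nir", graph)  # serialize the graph to an HDF5 file
+```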
diff --git a/docs/source/dev.md b/docs/source/contributing.md
similarity index 100%
rename from docs/source/dev.md
rename to docs/source/contributing.md
diff --git a/docs/source/examples/snntorch_to_norse.ipynb b/docs/source/examples/snntorch_to_norse.ipynb
new file mode 100644
index 0000000..984baf1
--- /dev/null
+++ b/docs/source/examples/snntorch_to_norse.ipynb
@@ -0,0 +1,929 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "view-in-github"
+   },
+   "source": [
+    "\"Open"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "8R08u5xKHBv6"
+   },
+   "source": [
+    "# snnTorch to Norse (Sim2Sim)\n",
+    "\n",
+    "## NIR for deep Spiking Neural Networks - From snnTorch to Norse\n",
+    "### Written by Jason Eshraghian and Bernhard Vogginger\n",
+    "\n",
+    "What you will learn:\n",
+    "* How spiking neurons are implemented as a recurrent network\n",
+    "* How to download event-based data and train a spiking neural network with it\n",
+    "* How to export the network to the Neuromorphic Intermediate Representation (NIR)\n",
+    "* How to import it into Norse and run inference\n",
+    "\n",
+    "Install the latest PyPI distribution of snnTorch by clicking into the following cell and pressing `Shift+Enter`.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "H_3jC4pJ8xzO"
+   },
+   "source": [
+    "# 1. Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "l5rPCjp5OIhf"
+   },
+   "outputs": [],
+   "source": [
+    "!pip install snntorch --quiet\n",
+    "!pip install tonic --quiet"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "kh4NW-mc0JaY"
+   },
+   "outputs": [],
+   "source": [
+    "# imports\n",
+    "import snntorch as snn\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "from torch.utils.data import DataLoader\n",
+    "from torchvision import datasets, transforms"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "4lvCNRHbOGW7"
+   },
+   "source": [
+    "# 2. Handling Event-based Data with Tonic"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "nhJunMSEOVMk"
+   },
+   "source": [
+    "## 2.1 PokerDVS Dataset\n",
+    "\n",
+    "The dataset used in this tutorial is POKERDVS by T. Serrano-Gotarredona and B. Linares-Barranco:\n",
+    "\n",
+    "```\n",
+    "Serrano-Gotarredona, Teresa, and Bernabé Linares-Barranco. \"Poker-DVS and MNIST-DVS. Their history, how they were made, and other details.\" Frontiers in Neuroscience 9 (2015): 481.\n",
+    "```\n",
+    "\n",
+    "It comprises four classes, each being a suit of a playing card deck: clubs, spades, hearts, and diamonds. The data consists of 131 poker pip symbols, and was collected by flipping poker cards in front of a DVS128 camera."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "mM4cKrtmOgPS"
+   },
+   "outputs": [],
+   "source": [
+    "import tonic\n",
+    "\n",
+    "poker_train = tonic.datasets.POKERDVS(save_to='./data', train=True)\n",
+    "poker_test = tonic.datasets.POKERDVS(save_to='./data', train=False)\n",
+    "\n",
+    "events, target = poker_train[0]\n",
+    "print(events)\n",
+    "tonic.utils.plot_event_grid(events)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "cGgrBCjqOhpo"
+   },
+   "outputs": [],
+   "source": [
+    "import tonic.transforms as transforms\n",
+    "from tonic import DiskCachedDataset\n",
+    "\n",
+    "# denoise, then bin the events into frames with a fixed time window\n",
+    "frame_transform = tonic.transforms.Compose([tonic.transforms.Denoise(filter_time=10000),\n",
+    "                                            tonic.transforms.ToFrame(\n",
+    "                                                sensor_size=tonic.datasets.POKERDVS.sensor_size,\n",
+    "                                                time_window=1000)\n",
+    "                                           ])\n",
+    "\n",
+    "batch_size = 8\n",
+    "cached_trainset = DiskCachedDataset(poker_train, transform=frame_transform, cache_path='./cache/pokerdvs/train')\n",
+    "cached_testset = DiskCachedDataset(poker_test, transform=frame_transform, cache_path='./cache/pokerdvs/test')\n",
+    "\n",
+    "train_loader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)\n",
+    "test_loader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)\n",
+    "\n",
+    "data, labels = next(iter(train_loader))\n",
+    "print(data.size())\n",
+    "print(labels)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "IVacJHwO6t4M"
+   },
+   "source": [
+    "# 3. Define the SNN"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "SmclKI7d62Oc"
+   },
+   "outputs": [],
+   "source": [
+    "num_inputs = 35*35*2  # PokerDVS frames: 35 x 35 pixels, 2 polarities\n",
+    "num_hidden = 128\n",
+    "num_outputs = 4"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "CWdyYErq7zXe"
+   },
+   "outputs": [],
+   "source": [
+    "dtype = torch.float\n",
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "hH8nHOsZ9Hxo"
+   },
+   "source": [
+    "In the following code block, note how the decay rate `beta` has two alternative definitions:\n",
+    "* `beta1` is set to a global decay rate for all neurons in the first spiking layer.\n",
+    "* `beta2` is randomly initialized to a vector of four different numbers. Each spiking neuron in the output layer (which not-so-coincidentally has four neurons) therefore has a unique, random decay rate.\n",
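+    "\n",
+    "For reference, here is a sketch of the update rule that `snn.Synaptic` implements (leaving out details of the reset mechanism; see the snnTorch docs): `alpha` decays the synaptic current and `beta` decays the membrane potential. With input current $I_{in}$ and a reset flag $R \in \{0,1\}$ that subtracts the threshold $U_{thr}$ after a spike:\n",
+    "\n",
+    "$$I_{syn}[t+1] = \alpha I_{syn}[t] + I_{in}[t+1]$$\n",
+    "\n",
+    "$$U[t+1] = \beta U[t] + I_{syn}[t+1] - R U_{thr}$$"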
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "7rEX7-U687zV"
+   },
+   "outputs": [],
+   "source": [
+    "# Define Network\n",
+    "class Net(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "\n",
+    "        alpha1 = 0.5\n",
+    "        beta1 = 0.9 # global decay rate for all leaky neurons in layer 1\n",
+    "        beta2 = torch.rand((num_outputs), dtype = torch.float) # independent decay rate for each leaky neuron in layer 2: [0, 1)\n",
+    "        threshold2 = torch.ones_like(beta2) # threshold parameter must have the same shape as beta for NIR\n",
+    "        alpha2 = torch.ones_like(beta2)*0.9\n",
+    "\n",
+    "        # Initialize layers\n",
+    "        self.fc1 = nn.Linear(num_inputs, num_hidden)\n",
+    "        self.lif1 = snn.Synaptic(alpha=alpha1, beta=beta1) # not a learnable decay rate\n",
+    "        self.fc2 = nn.Linear(num_hidden, num_outputs)\n",
+    "        self.lif2 = snn.Synaptic(alpha=alpha2, beta=beta2, threshold=threshold2, learn_beta=True) # learnable decay rate\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        syn1, mem1 = self.lif1.init_synaptic() # reset/init hidden states at t=0\n",
+    "        syn2, mem2 = self.lif2.init_synaptic() # reset/init hidden states at t=0\n",
+    "\n",
+    "        spk2_rec = [] # record output spikes\n",
+    "        mem2_rec = [] # record output hidden states\n",
+    "\n",
+    "        for step in range(x.size(0)): # loop over time\n",
+    "            cur1 = self.fc1(x[step].flatten(1))\n",
+    "            spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1)\n",
+    "            cur2 = self.fc2(spk1)\n",
+    "            spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)\n",
+    "\n",
+    "            spk2_rec.append(spk2) # record spikes\n",
+    "            mem2_rec.append(mem2) # record membrane\n",
+    "\n",
+    "        return torch.stack(spk2_rec), torch.stack(mem2_rec)\n",
+    "\n",
+    "# Load the network onto CUDA if available\n",
+    "net = Net().to(device)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "3le9vSo29ocU"
+   },
+   "source": [
+    "The code in the `forward()` function will only be called once the input argument `x` is explicitly passed into `net`.\n",
+    "\n",
+    "* `fc1` applies a linear transformation to all input pixels from the POKERDVS dataset;\n",
+    "* `lif1` integrates the weighted input over time, emitting a spike if the threshold condition is met;\n",
+    "* `fc2` applies a linear transformation to the output spikes of `lif1`;\n",
+    "* `lif2` is another spiking neuron layer, integrating the weighted spikes over time.\n",
+    "\n",
+    "A 'biophysical' interpretation is that `fc1` and `fc2` generate current injections that are fed into a set of $128$ and $4$ spiking neurons in `lif1` and `lif2`, respectively.\n",
+    "\n",
+    "> Note: the number of spiking neurons is automatically inferred from the dimensionality of the current injection."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "L-ekyCxt78qB"
+   },
+   "source": [
+    "# 4. Training the SNN"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "1zKiDeSM1onU"
+   },
+   "source": [
+    "## 4.1 Accuracy Metric\n",
+    "Below is a function that takes a batch of data, counts up all the spikes from each neuron (i.e., a rate code over the simulation time), and compares the index of the highest count with the actual target. If they match, the network correctly predicted the target."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eytnSVQ8OtBj" + }, + "outputs": [], + "source": [ + "def measure_accuracy(model, dataloader):\n", + " with torch.no_grad():\n", + " model.eval()\n", + " running_length = 0\n", + " running_accuracy = 0\n", + "\n", + " for data, targets in iter(dataloader):\n", + " data = data.to(device)\n", + " targets = targets.to(device)\n", + "\n", + " # forward-pass\n", + " spk_rec, _ = model(data)\n", + " spike_count = spk_rec.sum(0) # batch x num_outputs\n", + " _, max_spike = spike_count.max(1)\n", + "\n", + " # correct classes for one batch\n", + " num_correct = (max_spike == targets).sum()\n", + "\n", + " # total accuracy\n", + " running_length += len(targets)\n", + " running_accuracy += num_correct\n", + "\n", + " accuracy = (running_accuracy / running_length)\n", + "\n", + " return accuracy.item()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k1Fzf-3S3WGp" + }, + "source": [ + "## 4.2 Loss Definition\n", + "The `nn.CrossEntropyLoss` function in PyTorch automatically handles taking the softmax of the output layer as well as generating a loss at the output." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mK0xRQQp3ypC" + }, + "outputs": [], + "source": [ + "loss = nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wr4TTQ204ubI" + }, + "source": [ + "## 4.3 Optimizer\n", + "Adam is a robust optimizer that performs well on recurrent networks, so let's use that with a learning rate of $5\\times10^{-4}$." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "muBPATWo40pI" + }, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IKTBRtsX-ZNy" + }, + "source": [ + "## 4.4 One Iteration of Training\n", + "Take the first batch of data and load it onto CUDA if available." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e323axy244wa" + }, + "outputs": [], + "source": [ + "data, targets = next(iter(train_loader))\n", + "data = data.to(device)\n", + "targets = targets.to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZZ1zApYM-cRk" + }, + "source": [ + "Pass the input data to the network." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "J-w2w7ME-d2o"
+   },
+   "outputs": [],
+   "source": [
+    "spk_rec, mem_rec = net(data)\n",
+    "print(mem_rec.size())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "AIikA-Rh-fSa"
+   },
+   "source": [
+    "The recording of the membrane potential is taken across:\n",
+    "* 29 time steps\n",
+    "* 8 samples of data\n",
+    "* 4 output neurons\n",
+    "\n",
+    "We wish to calculate the loss at every time step and sum them over time:\n",
+    "\n",
+    "$$\mathcal{L}_{Total-CE} = \sum_t\mathcal{L}_{CE}[t]$$"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "WHlczZEb-vqj"
+   },
+   "outputs": [],
+   "source": [
+    "# initialize the total loss value\n",
+    "loss_val = torch.zeros((1), dtype=dtype, device=device)\n",
+    "\n",
+    "# sum loss at every step\n",
+    "for step in range(mem_rec.size(0)):\n",
+    "    loss_val += loss(mem_rec[step], targets)\n",
+    "\n",
+    "print(f\"Training loss: {loss_val.item():.3f}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "InrjD9n--xCt"
+   },
+   "source": [
+    "The loss is quite large because it is summed over roughly 29 time steps. The accuracy is also poor (around 25%, i.e., chance level for four classes) since the network is untrained:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "LaDDSPLK9u7p"
+   },
+   "outputs": [],
+   "source": [
+    "measure_accuracy(net, train_loader)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "0dvil_l3-09y"
+   },
+   "source": [
+    "A single weight update is applied to the network as follows:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "7Okzx-3p-0Mt"
+   },
+   "outputs": [],
+   "source": [
+    "# clear previously stored gradients\n",
+    "optimizer.zero_grad()\n",
+    "\n",
+    "# calculate the gradients\n",
+    "loss_val.backward()\n",
+    "\n",
+    "# weight update\n",
+    "optimizer.step()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "iHbLCKsa-3Ao"
+   },
+   "source": [
+    "Now, re-run the loss calculation and accuracy after a single iteration:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "Cvx9Ux2--4Aw"
+   },
+   "outputs": [],
+   "source": [
+    "# calculate new network outputs using the same data\n",
+    "spk_rec, mem_rec = net(data)\n",
+    "\n",
+    "# initialize the total loss value\n",
+    "loss_val = torch.zeros((1), dtype=dtype, device=device)\n",
+    "\n",
+    "# sum loss at every step\n",
+    "for step in range(mem_rec.size(0)):\n",
+    "    loss_val += loss(mem_rec[step], targets)\n",
+    "\n",
+    "print(f\"Training loss: {loss_val.item():.3f}\")\n",
+    "measure_accuracy(net, train_loader)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "7T3MCnxo-5fE"
+   },
+   "source": [
+    "After only one iteration, the loss should have decreased and the accuracy should have increased. Note how the membrane potential is used to calculate the cross-entropy loss, while the spike count is used to measure accuracy. It is also possible to use the spike count in the loss ([see Tutorial 6 in the snnTorch docs](https://snntorch.readthedocs.io/en/latest/tutorials/index.html))."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "hDrQ7oMM--dU"
+   },
+   "source": [
+    "## 4.5 Training Loop\n",
+    "\n",
+    "Let's combine everything into a training loop. We will train for ten epochs (`num_epochs = 10` below; feel free to change it), exposing our network to each sample of data multiple times."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UCnjeX4U_AVD" + }, + "outputs": [], + "source": [ + "num_epochs = 10\n", + "loss_hist = []\n", + "test_loss_hist = []\n", + "counter = 0\n", + "\n", + "# Outer training loop\n", + "for epoch in range(num_epochs):\n", + " iter_counter = 0\n", + " train_batch = iter(train_loader)\n", + "\n", + " # Minibatch training loop\n", + " for data, targets in train_batch:\n", + " data = data.to(device)\n", + " targets = targets.to(device)\n", + "\n", + " # forward pass\n", + " net.train()\n", + " spk_rec, mem_rec = net(data)\n", + "\n", + " # initialize the loss & sum over time\n", + " loss_val = torch.zeros((1), dtype=dtype, device=device)\n", + " for step in range(mem_rec.size(0)):\n", + " loss_val += loss(mem_rec[step], targets)\n", + "\n", + " # Gradient calculation + weight update\n", + " optimizer.zero_grad()\n", + " loss_val.backward()\n", + " optimizer.step()\n", + "\n", + " # Store loss history for future plotting\n", + " loss_hist.append(loss_val.item())\n", + "\n", + " # Test set\n", + " with torch.no_grad():\n", + " net.eval()\n", + " test_data, test_targets = next(iter(test_loader))\n", + " test_data = test_data.to(device)\n", + " test_targets = test_targets.to(device)\n", + "\n", + " # Test set forward pass\n", + " test_spk, test_mem = net(test_data)\n", + "\n", + " # Test set loss\n", + " test_loss = torch.zeros((1), dtype=dtype, device=device)\n", + " for step in range(test_mem.size(0)):\n", + " test_loss += loss(test_mem[step], test_targets)\n", + " test_loss_hist.append(test_loss.item())\n", + "\n", + " # Print train/test loss/accuracy\n", + " # if counter % 50 == 0:\n", + " print(f\"Iteration: {counter} \\t Accuracy: {measure_accuracy(net, test_loader)}\")\n", + " counter += 1\n", + " iter_counter +=1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MWXOCBNI_B47" + }, + "source": [ + "If this was your first time training an SNN, then congratulations. I'm proud of you and I always believed in you." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4pAJ9_WsFeSB" + }, + "outputs": [], + "source": [ + "measure_accuracy(net, test_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7BZpywK3_LZL" + }, + "source": [ + "# 5. Export to NIR" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yJ2ZcBH-ARmt" + }, + "outputs": [], + "source": [ + "import nir" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qvtZS8kW_euK" + }, + "outputs": [], + "source": [ + "nir_model = snn.export_to_nir(net.cpu(), data.cpu())\n", + "nir.write(\"nir_model.nir\", nir_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VFq1hIbkFeSC" + }, + "source": [ + "# 6. 
Run the model with Norse"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "0-xAVat6FeSD"
+   },
+   "source": [
+    "## 6.1 Import NIR model to Norse"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "doW_7eauFeSD"
+   },
+   "outputs": [],
+   "source": [
+    "!pip install norse --quiet"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "_ITH3yqMFeSD"
+   },
+   "outputs": [],
+   "source": [
+    "import norse.torch as norse"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "la-3vqRmFeSD"
+   },
+   "outputs": [],
+   "source": [
+    "nir_model = nir.read(\"nir_model.nir\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "iq0JrUPJFeSD"
+   },
+   "outputs": [],
+   "source": [
+    "norse_model = norse.from_nir(nir_model, dt=0.0001) # dt is the simulation step width assumed by snnTorch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "zI5ZfDJuKqCf"
+   },
+   "outputs": [],
+   "source": [
+    "norse_model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "S56YqwRYFeSD"
+   },
+   "source": [
+    "`norse.from_nir(..)` returns a `GraphExecutor` object, which is callable like an `nn.Module`.\n",
+    "\n",
+    "In this case it contains:\n",
+    "- Two linear layers\n",
+    "- Two CubaLIF layers, each composed of a leaky integrator and an LIF neuron\n",
+    "- Identity layers for input and output\n",
+    "\n",
+    "The order in which the layers are called can also be obtained:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "e6LvAiH6MbnX"
+   },
+   "outputs": [],
+   "source": [
+    "print([elem.name for elem in norse_model.get_execution_order()])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "r6HSJq7LFeSD"
+   },
+   "source": [
+    "## 6.2 Run the model with a single batch of data\n",
+    "\n",
+    "The graph executor can run a single forward step. Let's write a function that applies the data for all time steps..."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "pnqTzuWvFeSD"
+   },
+   "outputs": [],
+   "source": [
+    "def apply(data):\n",
+    "    \"\"\"\n",
+    "    Apply an input data batch to the Norse model, one time step at a time.\n",
+    "    \"\"\"\n",
+    "    state = None\n",
+    "    hid_rec = []  # hidden state at every time step\n",
+    "    out = []      # output spikes at every time step\n",
+    "\n",
+    "    for t in data:\n",
+    "        z, state = norse_model(t.flatten(1), state)\n",
+    "        out.append(z)\n",
+    "        hid_rec.append(state)\n",
+    "    spk_out = torch.stack(out)\n",
+    "    # hid_rec = torch.stack(hid_rec)\n",
+    "    return spk_out, hid_rec"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "QCL4EIZQFeSE"
+   },
+   "source": [
+    "Apply it to a batch of data:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "5Z6MVgU6FeSE"
+   },
+   "outputs": [],
+   "source": [
+    "data, targets = next(iter(test_loader))\n",
+    "\n",
+    "# data = data.to(device)\n",
+    "\n",
+    "spk, hid = apply(data)\n",
+    "\n",
+    "# count the number of spikes for each neuron and assess the winner\n",
+    "predictions = spk.sum(axis=0).argmax(axis=-1)\n",
+    "print(f\"Predicted classes: {predictions}\")\n",
+    "print(f\"Actual classes: {targets}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "x3eGShtIFeSE"
+   },
+   "source": [
+    "## 6.3 Measure accuracy on the test dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "w8pgFPB5FeSE"
+   },
+   "outputs": [],
+   "source": [
+    "def measure_accuracy2(model, dataloader):\n",
+    "    # same as measure_accuracy above, but without moving data to the device\n",
+    "    with torch.no_grad():\n",
+    "        # model.eval() # not needed!\n",
+    "        running_length = 0\n",
+    "        running_accuracy = 0\n",
+    "\n",
+    "        for data, targets in iter(dataloader):\n",
+    "            # data = data.to(device)\n",
+    "            # targets = targets.to(device)\n",
+    "\n",
+    "            # forward-pass\n",
+    "            spk_rec, _ = model(data)\n",
+    "            spike_count = spk_rec.sum(0)  # batch x num_outputs\n",
+    "            _, max_spike = spike_count.max(1)\n",
+    "\n",
+    "            # correct classes for one batch\n",
+    "            num_correct = (max_spike == targets).sum()\n",
+    "\n",
+    "            # total accuracy\n",
+    "            running_length += len(targets)\n",
+    "            running_accuracy += num_correct\n",
+    "\n",
+    "        accuracy = (running_accuracy / running_length)\n",
+    "\n",
+    "    return accuracy.item()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "PDkVh1maFeSE"
+   },
+   "outputs": [],
+   "source": [
+    "measure_accuracy2(apply, test_loader)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "cellView": "form",
+    "id": "NaTLN3HN_7R0"
+   },
+   "outputs": [],
+   "source": [
+    "#@title Run this block for a good time\n",
+    "import requests\n",
+    "from IPython.display import Image, display\n",
+    "\n",
+    "def display_image_from_url(url):\n",
+    "    response = requests.get(url, stream=True)\n",
+    "    display(Image(response.content))\n",
+    "\n",
+    "url = \"http://www.clker.com/cliparts/7/8/a/0/1498553633398980412very-nice-borat.med.png\"\n",
+    "display_image_from_url(url)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "3SlDUImz6Q2e"
+   },
+   "source": [
+    "# Conclusion\n",
+    "\n",
+    "That covers how to train a spiking neural network, how to convert it to the Neuromorphic Intermediate Representation, and how to load it into another PyTorch-based framework.\n",
+    "\n",
+    "There are many ways to extend this: different neuron models, surrogate gradients, learnable beta and threshold values, or replacing the fully-connected layers with convolutions, or whatever else you fancy.\n",
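+    "\n",
+    "To recap, the whole framework-to-framework round trip reduces to the calls used above (a sketch; `net` and `data` are the objects defined earlier in this notebook):\n",
+    "\n",
+    "```python\n",
+    "import nir\n",
+    "import norse.torch as norse\n",
+    "import snntorch as snn\n",
+    "\n",
+    "nir_graph = snn.export_to_nir(net.cpu(), data.cpu())  # snnTorch -> NIR\n",
+    "nir.write(\"nir_model.nir\", nir_graph)                 # serialize to HDF5\n",
+    "norse_model = norse.from_nir(nir.read(\"nir_model.nir\"), dt=1e-4)  # NIR -> Norse\n",
+    "```"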
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "T4",
+   "include_colab_link": true,
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python [conda env:nir-demo] *",
+   "language": "python",
+   "name": "conda-env-nir-demo-py"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/docs/source/nir_graph_example.svg b/docs/source/nir_graph_example.svg
new file mode 100644
index 0000000..2a58f4d
--- /dev/null
+++ b/docs/source/nir_graph_example.svg
@@ -0,0 +1,250 @@
[SVG markup stripped in this view; the figure shows a NIRGraph containing four nodes: Input, LI, Affine, and Output.]
diff --git a/docs/source/porting_nir.md b/docs/source/porting_nir.md
new file mode 100644
index 0000000..32071a9
--- /dev/null
+++ b/docs/source/porting_nir.md
@@ -0,0 +1,115 @@
+# Using NIR in hardware
+
+NIR is easily portable to any platform and is liberally licensed under the BSD 3-Clause license, so it can be used in any project, commercial or open-source.
+We provide a reference implementation in Python, but NIR graphs are saved as [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) files that can be read by any language or platform.
+**To use NIR, you simply convert the nodes and edges into your platform's primitives**.
+That's all.
+Let us unpack that statement.
+
+## The NIR format
+NIR is an intermediate representation and **only consists of declarations**.
+That is, we do not implement any dynamics or describe any runtime behavior---that is up to the individual platforms.
+NIR consists of a *hierarchical* structure, with the top-most level being a single [`NIRGraph` object](api_design.md#nir-graphs-and-edges).
+
+```{figure} nir_graph_example.svg
+---
+height: 200px
+name: nir-graph-example
+---
+An example of a NIR graph with four nodes: Input, Leaky-Integrator, Affine map, and Output.
+```
+
+```{note}
+See our [API design](api_design.md) for more information about the NIR hierarchical structure.
+
+See the [NIR primitives](primitives.md) for more information about the individual nodes.
+```
+
+## Integrating with Python
+
+The first step is to load the NIR graph into your platform.
+In Python, you can import NIR as a library (installable via [`pip install nir`](https://pypi.org/project/nir/)).
+If your graph is stored in a file, you can load it using the `nir.read` function.
+
+```python
+import nir
+my_graph = nir.read("path_to_my_graph.nir")
+```
+
+Once that is done, the graph can be parsed by (1) matching the nodes to your platform's primitives and (2) connecting the nodes together.
+Note that `NIRGraph` nodes may contain nested subgraphs, so we recommend a recursive function that traverses the graph and maps the nodes.
+Here's a simple example (without recursion):
+
+```python
+import nir
+
+def parse_graph(graph: nir.NIRGraph):
+    # Create a dictionary of nodes
+    nodes = {}
+    for name, node in graph.nodes.items():
+        # Match the node to your platform's primitive
+        if isinstance(node, nir.Input):
+            nodes[name] = MyPlatformInput()
+        elif isinstance(node, nir.LI):
+            nodes[name] = MyPlatformLeakyIntegrator(node.tau, node.r, node.v_leak)
+        elif isinstance(node, nir.Affine):
+            nodes[name] = MyPlatformAffine(node.weight, node.bias)
+        elif isinstance(node, nir.Output):
+            nodes[name] = MyPlatformOutput()
+        else:
+            raise NotImplementedError(f"Node {node} not supported.")
+
+    # Connect the nodes along each edge (edges are (source, target) name pairs)
+    for edge in graph.edges:
+        nodes[edge[0]].connect(nodes[edge[1]])
+
+    return nodes
+```
+
+```{note}
+See the [NIR primitives](primitives.md) for more information about the content of each node.
+```
+
+### Integrating with PyTorch
+
+Since several libraries are built on top of PyTorch, we provide default PyTorch mappings in [nirtorch](https://github.com/neuromorphs/nirtorch).
+`nirtorch` provides a simple way to write and load NIR graphs, but you still need to let `nirtorch` know how to evaluate the SNN-specific nodes (such as Leaky-Integrator and Spike).
+
+```python
+from typing import Optional
+
+import nir, nirtorch, torch
+
+# Map the nodes that are specific to your library
+# - nirtorch will map obvious nodes like `Input`, `Output`, `Affine`, `Conv2d`, etc.,
+# - but only if your parsing function does not return a module for that node
+def parse_module(node: nir.NIRNode) -> Optional[torch.nn.Module]:
+    if isinstance(node, nir.LIF):
+        return MyPlatformLIF(node.tau, node.r, node.v_leak, node.v_threshold)  # your library's module
+    else:
+        return None  # return None to let nirtorch map the node
+
+# Load a graph as a PyTorch module (`torch.nn.Module`)
+nir_graph = ...
+torch_graph = nirtorch.load(nir_graph, parse_module)
+```
+
+## Integrating via HDF5 files
+
+If you are not using Python, you can load the NIR graph from the HDF5 file and parse it using your platform's primitives.
+The data follows the same structure as the Python API described above, and the values in the nodes are encoded as [Numpy arrays](https://numpy.org/doc/stable/reference/c-api/array.html).
+
+HDF5 interfaces with numerous languages, including C, C++, Java, and MATLAB.
+We refer to the Wikipedia page on [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) for more information on how to use it in your language, and give a small inspection sketch at the end of this page.
+
+## What if my platform doesn't support `X`?
+NIR contains a number of primitives, like the `LI` (Leaky Integrator) node or the `Spike` (threshold) primitive, that may not be directly supported by your platform.
+In this case, you have two options:
+* **Approximate the behavior**: You can approximate the behavior of the node using your platform's primitives.
+  For example, the `LIF` node can be approximated by a simple leaky integrator.
+* **Ignore the node**: If your platform does not support the node, you can simply ignore it.
+  This is a valid strategy, as most hardware platforms are naturally constrained. In this case, we advise that you raise an exception and inform the user that the node is not supported.
+  In our [roadmap](roadmap.md), we plan to work on optimization and approximation strategies for these cases.
+  If this is interesting to you, we invite you to [get in touch](about.md#contact).
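+
+## Inspecting the HDF5 file
+
+As referenced above, here is a minimal sketch for peeking into a `.nir` file with Python's `h5py` (the file name is assumed; HDF5 bindings in other languages follow the same pattern):
+
+```python
+import h5py
+
+# Walk the HDF5 tree and print every group and dataset,
+# which reveals how the nodes and edges are laid out in the file.
+with h5py.File("path_to_my_graph.nir", "r") as f:
+    f.visititems(lambda name, obj: print(name, obj))
+```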
diff --git a/docs/source/usage.md b/docs/source/usage.md index 8b604c1..120702f 100644 --- a/docs/source/usage.md +++ b/docs/source/usage.md @@ -32,7 +32,7 @@ sample_data = torch.randn(batch_size, 10) nir_model = norse.to_nir(model, sample_data) ``` -### Part 2: Convert NIR model to +### Part 2: Convert NIR model to chip ```python import sinabs from sinabs.backend.dynapcnn import DynapcnnNetwork diff --git a/flake.lock b/flake.lock index 6f0a383..7ea6986 100644 --- a/flake.lock +++ b/flake.lock @@ -20,16 +20,16 @@ }, "nixpkgs": { "locked": { - "lastModified": 1688389917, - "narHash": "sha256-RKiK1QeommEsjQ8fLgxt4831x9O6n2gD7wAhVZTrr8M=", + "lastModified": 1720110830, + "narHash": "sha256-E5dN9GDV4LwMEduhBLSkyEz51zM17XkWZ3/9luvNOPs=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "aed4b19d312525ae7ca9bceb4e1efe3357d0e2eb", + "rev": "c0d0be00d4ecc4b51d2d6948e37466194c1e6c51", "type": "github" }, "original": { "id": "nixpkgs", - "ref": "nixos-23.05", + "ref": "nixos-24.05", "type": "indirect" } }, diff --git a/flake.nix b/flake.nix index 3d09f07..75aac6c 100644 --- a/flake.nix +++ b/flake.nix @@ -1,7 +1,7 @@ { description = "Neuromorphic Intermediate Representation reference implementation"; inputs = { - nixpkgs.url = "nixpkgs/nixos-23.05"; + nixpkgs.url = "nixpkgs/nixos-24.05"; flake-utils.url = "github:numtide/flake-utils"; }; outputs = { self, nixpkgs, flake-utils }: @@ -9,7 +9,7 @@ let pkgs = nixpkgs.legacyPackages.${system}; in { devShells.default = - let pythonPackages = pkgs.python39Packages; + let pythonPackages = pkgs.python3Packages; in pkgs.mkShell rec { name = "impurePythonEnv"; venvDir = "./.venv"; @@ -20,6 +20,7 @@ pythonPackages.h5py pythonPackages.black pkgs.ruff + pkgs.autoPatchelfHook ]; postVenvCreation = '' unset SOURCE_DATE_EPOCH