add pre-training notebook from gregor
stevenabreu7 committed Apr 22, 2024
1 parent 053d136 commit 73febb9
Showing 11 changed files with 624 additions and 367 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -15,3 +15,9 @@ lightning_logs/
paper/03/data/ds_train.pt
paper/03/data/ds_val.pt
env_NIR.yml
# tmp (steve)
venv_lif/
paper/01_lif/lava
paper/01_lif/.venv
paper/02_cnn/lava-dl
paper/02_cnn/ann_pretraining/data/MNIST
14 changes: 14 additions & 0 deletions paper/02_cnn/ann_pretraining/README.md
@@ -0,0 +1,14 @@
# NMNIST experiments in Sinabs

## Install requirements
```
pip install -r requirements.txt
```

## Convert a trained CNN to an SNN
Run the `test-converted-snn.ipynb` notebook (see the conversion sketch below).

## Train the CNN from scratch
```
python train.py --num_workers=4 --model=cnn --batch_size=64
```
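
For readers without the notebook at hand, the conversion step might look roughly like the following. This is a minimal sketch, assuming sinabs' `from_torch.from_model` helper and the `CNN` module from `cnn.py`; the checkpoint path is hypothetical and exact argument names can differ between sinabs versions.

```
from sinabs.from_torch import from_model
from cnn import CNN

# Load a pre-trained ANN (hypothetical checkpoint path).
ann = CNN.load_from_checkpoint("lightning_logs/version_0/checkpoints/last.ckpt")

# Replace the ReLUs with integrate-and-fire spiking layers.
# from_model returns a sinabs Network wrapper; the converted SNN is its .spiking_model.
net = from_model(ann.model, input_shape=(2, 34, 34), batch_size=32)
snn = net.spiking_model
print(snn)
```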
367 changes: 0 additions & 367 deletions paper/02_cnn/ann_pretraining/ann_to_snn_conversion.ipynb

This file was deleted.

Binary file not shown.
53 changes: 53 additions & 0 deletions paper/02_cnn/ann_pretraining/cnn.py
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.nn import functional as F


class CNN(pl.LightningModule):
    """ANN for pre-training on NMNIST frames; its Conv2d/Linear layers match those of the SNN in snn.py."""

    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.model = nn.Sequential(
            nn.Conv2d(2, 20, 5, 1, bias=False),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(20, 32, 5, 1, bias=False),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(32, 128, 3, 1, bias=False),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(128, 500, bias=False),
            nn.ReLU(),
            nn.Linear(500, 10, bias=False),
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('valid_loss', loss)
        prediction = (y_hat.argmax(1) == y).float()
        self.log('valid_acc', prediction.sum() / len(prediction), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        prediction = (y_hat.argmax(1) == y).float()
        self.log('test_acc', prediction.sum() / len(prediction), prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
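
The training entry point `train.py` is not part of this diff. As a rough illustration only, pre-training this CNN on the frame-based data module could look like the sketch below; the module and argument names follow the files in this commit, everything else (paths, epoch count) is an assumption.

```
import pytorch_lightning as pl

from cnn import CNN
from nmnist import NMNISTFrames

# Frame-based NMNIST data module defined in nmnist.py (download path is a placeholder).
data = NMNISTFrames(save_to="./data", batch_size=64, num_workers=4)
model = CNN(lr=1e-3)

trainer = pl.Trainer(max_epochs=10, accelerator="auto")
trainer.fit(model, datamodule=data)
trainer.test(model, datamodule=data)
```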
160 changes: 160 additions & 0 deletions paper/02_cnn/ann_pretraining/nmnist.py
@@ -0,0 +1,160 @@
import os
from typing import Callable, Optional, Tuple
import pytorch_lightning as pl

import numpy as np
import tonic
from tonic import (DiskCachedDataset, MemoryCachedDataset, SlicedDataset,
                   datasets, slicers, transforms)
from torch.utils.data import DataLoader


class NMNISTFrames(pl.LightningDataModule):
    """
    This data module provides 3 frames for each sample in the original NMNIST dataset.
    The dataset length is 3*60000 for the training set and 3*10000 for the test set.
    The frames are cached to disk in an efficient format.

    Parameters:
        save_to: Path where the raw data is saved to.
        batch_size: The dataloader batch size.
        augmentation: An optional callable that is applied to each sample.
        cache_path: Where to store cached versions of all the frames.
        metadata_path: Where to store metadata about how recordings are sliced into individual samples.
                       Providing this path saves time when the dataset is loaded again.
        num_workers: The number of worker processes for the dataloader.
        precision: Can be 16 for half or 32 for full precision.
    """

    def __init__(
        self,
        save_to: str,
        batch_size: int,
        augmentation: Optional[Callable] = None,
        cache_path: str = 'cache/frames',
        metadata_path: str = 'metadata/frames',
        num_workers: int = 6,
        precision: int = 32,
    ):
        super().__init__()
        self.save_to = save_to
        self.batch_size = batch_size
        self.augmentation = augmentation
        self.cache_path = cache_path
        self.metadata_path = metadata_path
        self.num_workers = num_workers
        self.precision = precision

    def prepare_data(self):
        datasets.NMNIST(save_to=self.save_to, train=True)
        datasets.NMNIST(save_to=self.save_to, train=False)

    def get_train_or_testset(self, train: bool):
        dataset = datasets.NMNIST(save_to=self.save_to, train=train)

        slicer = slicers.SliceByTimeBins(3)
        image_transform = transforms.ToImage(sensor_size=dataset.sensor_size)

        dtype = {
            32: np.float32,
            16: np.float16,
        }

        sliced_dataset = SlicedDataset(
            dataset,
            slicer=slicer,
            metadata_path=os.path.join(self.metadata_path, f"train_{train}"),
            transform=lambda x: image_transform(x).astype(dtype[self.precision]),
        )

        return DiskCachedDataset(
            dataset=sliced_dataset,
            cache_path=os.path.join(self.cache_path, f"train_{train}", f"precision_{self.precision}"),
            transform=self.augmentation,
        )

    def setup(self, stage=None):
        self.train_data = self.get_train_or_testset(True)
        self.test_data = self.get_train_or_testset(False)

    def train_dataloader(self):
        return DataLoader(self.train_data, num_workers=self.num_workers, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.test_data, num_workers=self.num_workers, batch_size=self.batch_size)

    def test_dataloader(self):
        return self.val_dataloader()


class NMNISTRaster(pl.LightningDataModule):
    """
    This data module provides the original NMNIST samples as rasters
    and caches them to disk.

    Parameters:
        save_to: Path where the raw data is saved to.
        batch_size: The batch size.
        n_time_bins: How many time bins per sample.
        augmentation: An optional callable that is applied to each sample.
        cache_path: Where to store cached versions of all the frames.
        num_workers: The number of worker processes for the dataloader.
        precision: Can be 16 for half or 32 for full precision.
    """

    def __init__(
        self,
        save_to: str,
        batch_size: int,
        n_time_bins: int,
        augmentation: Optional[Callable] = None,
        cache_path: str = 'cache/rasters',
        num_workers: int = 6,
        precision: int = 32,
    ):
        super().__init__()
        self.save_to = save_to
        self.batch_size = batch_size
        self.n_time_bins = n_time_bins
        self.augmentation = augmentation
        self.cache_path = cache_path
        self.num_workers = num_workers
        self.precision = precision

    def prepare_data(self):
        datasets.NMNIST(save_to=self.save_to, train=True)
        datasets.NMNIST(save_to=self.save_to, train=False)

    def get_train_or_testset(self, train: bool):
        frame_transform = transforms.ToFrame(sensor_size=datasets.NMNIST.sensor_size, n_time_bins=self.n_time_bins)

        dtype = {
            32: np.float32,
            16: np.float16,
        }

        dataset = datasets.NMNIST(
            save_to=self.save_to,
            train=train,
            transform=lambda x: frame_transform(x).astype(dtype[self.precision]),
        )

        return DiskCachedDataset(
            dataset=dataset,
            cache_path=os.path.join(self.cache_path, f"train_{train}", f"precision_{self.precision}"),
            transform=self.augmentation,
        )

    def setup(self, stage=None):
        self.train_data = self.get_train_or_testset(True)
        self.test_data = self.get_train_or_testset(False)

    def train_dataloader(self):
        return DataLoader(self.train_data, num_workers=self.num_workers, batch_size=self.batch_size, shuffle=True)  # , collate_fn=tonic.collation.PadTensors()

    def val_dataloader(self):
        return DataLoader(self.test_data, num_workers=self.num_workers, batch_size=self.batch_size)  # , collate_fn=tonic.collation.PadTensors()

    def test_dataloader(self):
        return self.val_dataloader()
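
As a quick usage sketch of the raster data module: the download directory and bin count below are arbitrary choices, and the expected sample shape follows from tonic's `ToFrame` transform with NMNIST's 34x34, 2-polarity sensor.

```
from nmnist import NMNISTRaster

dm = NMNISTRaster(save_to="./data", batch_size=64, n_time_bins=10, num_workers=4)
dm.prepare_data()   # downloads NMNIST via tonic if it is not already on disk
dm.setup()

x, y = dm.train_data[0]
print(x.shape)  # expected: (n_time_bins, 2, 34, 34) = time x polarity x height x width
print(len(dm.train_dataloader()))  # number of training batches
```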

4 changes: 4 additions & 0 deletions paper/02_cnn/ann_pretraining/requirements.txt
@@ -0,0 +1,4 @@
pytorch_lightning==1.9.5
sinabs==1.2.10
tonic
ipykernel
69 changes: 69 additions & 0 deletions paper/02_cnn/ann_pretraining/snn.py
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.nn import functional as F
import sinabs.layers as sl
# import sinabs.exodus.layers as sel


class SNN(pl.LightningModule):
    def __init__(self, batch_size, lr=1e-3):
        super().__init__()
        self.batch_size = batch_size
        self.lr = lr
        # Plain sinabs layers; the commented-out exodus import above is a CUDA-accelerated drop-in alternative.
        backend = sl
        self.model = nn.Sequential(
            nn.Conv2d(2, 20, 5, 1, bias=False),
            backend.IAFSqueeze(shape=[batch_size, 20, 30, 30], batch_size=batch_size),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(20, 32, 5, 1, bias=False),
            backend.IAFSqueeze(shape=[batch_size, 32, 11, 11], batch_size=batch_size),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(32, 128, 3, 1, bias=False),
            backend.IAFSqueeze(shape=[batch_size, 128, 3, 3], batch_size=batch_size),
            nn.AvgPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(128, 500, bias=False),
            backend.IAFSqueeze(shape=[batch_size, 500], batch_size=batch_size),
            nn.Linear(500, 10, bias=False),
        )

    def forward(self, x):
        # x has shape (batch, time, channels, height, width). The squeezed spiking layers
        # expect batch and time merged into one dimension, so flatten before the model
        # and unflatten afterwards.
        self.reset_states()
        return self.model(x.flatten(0, 1)).unflatten(0, (self.batch_size, -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x).sum(1)  # sum spikes over time to get class scores
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x).sum(1)
        loss = F.cross_entropy(y_hat, y)
        self.log('valid_loss', loss)
        prediction = (y_hat.argmax(1) == y).float()
        self.log('valid_acc', prediction.sum() / len(prediction), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x).sum(1)
        prediction = (y_hat.argmax(1) == y).float()
        self.log('test_acc', prediction.sum() / len(prediction), prog_bar=True)

    @property
    def sinabs_layers(self):
        return [
            layer
            for layer in self.model.modules()
            if isinstance(layer, sl.StatefulLayer)
        ]

    def reset_states(self):
        for layer in self.sinabs_layers:
            layer.reset_states()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
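
Since this `SNN` mirrors `CNN` layer for layer (the ReLUs are simply replaced by `IAFSqueeze` layers at the same `nn.Sequential` indices), one plausible way to reuse the pre-trained ANN weights is to copy the convolution and linear parameters directly. This is only a sketch under that assumption, not necessarily what the accompanying notebook does; the checkpoint path is hypothetical.

```
from cnn import CNN
from snn import SNN

batch_size = 32
ann = CNN.load_from_checkpoint("path/to/cnn.ckpt")  # hypothetical checkpoint path
snn = SNN(batch_size=batch_size)

# Conv2d/Linear modules sit at the same indices in both nn.Sequential models,
# so their parameter keys ("0.weight", "3.weight", ...) line up.
# strict=False skips any spiking-state buffers that only exist in the SNN.
missing, unexpected = snn.model.load_state_dict(ann.model.state_dict(), strict=False)
assert not unexpected, "ANN parameters without a counterpart in the SNN"
```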
