Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
markschoene committed Jun 11, 2024
0 parents commit f912fca
Show file tree
Hide file tree
Showing 31 changed files with 5,061 additions and 0 deletions.
166 changes: 166 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
*.pyc

# S5 specific stuff
wandb/
cache_dir/
raw_datasets/
120 changes: 120 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Scalable Event-by-event Processing of Neuromorphic Sensory Signals With Deep State-Space Models
![Figure 1](docs/figure1.png)
This is the official implementation of our paper [Scalable Event-by-event Processing of Neuromorphic Sensory Signals With Deep State-Space Models
](https://arxiv.org/abs/2404.18508).
The core motivation for this work was the irregular time-series modeling problem presented in the paper [Simplified State Space Layers for Sequence Modeling
](https://arxiv.org/abs/2208.04933).
We acknowledge the awesome [S5 project](https://github.com/lindermanlab/S5) and the trainer class provided by this [UvA tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/guide4/Research_Projects_with_JAX.html), which highly influenced our code.

Our project treats a quite general machine learning problem:
Modeling **long sequences** that are **irregularly** sampled by a possibly large number of **asynchronous** sensors.
This problem is particularly present in the field of neuromorphic computing, where event-based sensors emit up to millions events per second from asynchronous channels.

We show how linear state-space models can be tuned to effectively model asynchronous event-based sequences.
Our contributions are
- Integration of dirac delta coded event streams
- time-invariant input normalization to effectively learn from long event-streams
- formulating neuromorphic event-streams as a language modeling problem with **asynchronous tokens**
- effectively model event-based vision **without frames and without CNNs**

## Installation
The project is implemented in [JAX](https://github.com/google/jax) with [Flax](https://flax.readthedocs.io/en/latest/).
By default, we install JAX with GPU support with CUDA >= 12.0.
To install JAX for CPU, replace `jax[cuda]` with `jax[cpu]` in the `requirements.txt` file.
PyTorch is only required for loading data.
Therefore, we install only the CPU version of PyTorch.
Install the requirements with
```bash
pip install -r requirements.txt
```
Install this repository
```bash
pip install -e .
```
We tested with JAX versions between `0.4.20` and `0.4.29`.
Different CUDA and JAX versions might result in slightly different results.

## Reproducing experiments
We use the [hydra](https://hydra.cc/docs/intro/) package to manage configurations.
If you are not familiar with hydra, we recommend to read the [documentation](https://hydra.cc/docs/intro/).

### Run benchmark tasks
The basic command to run an experiment is
```bash
python run_training.py
```
This will default to running the Spiking Heidelberg Digits (SHD) dataset.
All benchmark tasks are defined by the configurations in `configs/tasks/`, and can be run by specifying the `task` argument.
E.g. run the Spiking Speech Commands (SSC) task with
```bash
python run_training.py task=spiking-speech-commands
```
or run the DVS128 Gestures task with
```bash
python run_training.py task=dvs-gesture
```

### Trained models
We provide our best models for [download](https://datashare.tu-dresden.de/s/g2dQCi792B8DqnC).
Check out the `tutorial_inference.ipynb` notebook to see how to load and run inference with these models.
We also provide a script to evaluate the models on the test set
```bash
python run_evaluation.py task=spiking-speech-commands checkpoint=downloaded/model/SSC
```


### Specify HPC system and logging
Many researchers operate on different HPC systems and perhaps log their experiments to multiple platforms.
Therefore, the user can specify configurations for
- different systems (directories for reading data and saving outputs)
- logging methods (e.g. whether to log locally or to [wandb](https://wandb.ai/))

By default, the `configs/system/local.yaml` and `configs/logging/local.yaml` configurations are used, respectively.
We suggest to create new configs for the HPC systems and wandb projects you are using.

For example, to run the model on SSC with your custom wandb logging config and your custom HPC specification do
```bash
python run_training.py task=spiking-speech-commands logging=wandb system=hpc
```
where `configs/logging/wandb.yaml` should look like
```yaml
log_dir: ${output_dir}
interval: 1000
wandb: False
summary_metric: "Performance/Validation accuracy"
project: wandb_project_name
entity: wandb_entity_name
```
and `configs/system/hpc.yaml` should specify data and output directories
```yaml
# @package _global_
data_dir: my/fast/storage/location/data
output_dir: my/job/output/location/${task.name}/${oc.env:SLURM_JOB_ID}/${now:%Y-%m-%d-%H-%M-%S}
```
The string `${task.name}/${oc.env:SLURM_JOB_ID}/${now:%Y-%m-%d-%H-%M-%S}` will create subdirectories named by task, slurm job ID, and date,
which we found useful in practice.
This specification of the `output_dir` is not required though.

## Tutorials
To get started with event-based state-space models, we created tutorials for training and inference.
- `tutorial_training.ipynb` shows how to train a model on a reduced version of the Spiking Heidelberg Digits with just two classes. The model converges after few minutes on CPUs.
- `tutorial_inference.ipynb` shows how to load a trained model and run inference. The models are available for download from the provided [download link](https://datashare.tu-dresden.de/s/g2dQCi792B8DqnC).
- `tutorial_online_inference.ipynb` runs event-by-event inference with batch size one (online inference) on the DVS128 Gestures dataset and measures the throughput of the model.

## Help and support
We are eager to help you with any questions or issues you might have.
Please use the GitHub issue tracker for questions and to report issues.

## Citation
Please use the following when citing our work:
```
@misc{Schoene2024,
title={Scalable Event-by-event Processing of Neuromorphic Sensory Signals With Deep State-Space Models},
author={Mark Schöne and Neeraj Mohan Sushma and Jingyue Zhuge and Christian Mayr and Anand Subramoney and David Kappel},
year={2024},
eprint={2404.18508},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```
12 changes: 12 additions & 0 deletions configs/base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
defaults:
- _self_
- system: local
- task: spiking-heidelberg-digits
- logging: local

seed: 1234
checkpoint: null

hydra:
run:
dir: ${output_dir}/hydra-outputs/${now:%Y-%m-%d-%H-%M-%S}
6 changes: 6 additions & 0 deletions configs/logging/local.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
log_dir: ${output_dir}
interval: 1000
wandb: False
summary_metric: "Performance/Validation accuracy"
project: ???
entity: ???
24 changes: 24 additions & 0 deletions configs/model/dvs/small.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# @package _global_

model:
ssm_init:
C_init: lecun_normal
dt_min: 0.001
dt_max: 0.1
conj_sym: false
clip_eigs: true
ssm:
discretization: async
d_model: 128
d_ssm: 128
ssm_block_size: 16
num_stages: 2
num_layers_per_stage: 3
dropout: 0.25
classification_mode: timepool
prenorm: true
batchnorm: false
bn_momentum: 0.95
pooling_stride: 16
pooling_mode: timepool
state_expansion_factor: 2
24 changes: 24 additions & 0 deletions configs/model/shd/medium.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# @package _global_

model:
ssm_init:
C_init: lecun_normal
dt_min: 0.004
dt_max: 0.1
conj_sym: false
clip_eigs: false
ssm:
discretization: async
d_model: 96
d_ssm: 128
ssm_block_size: 8
num_stages: 2
num_layers_per_stage: 3
dropout: 0.23
classification_mode: pool
prenorm: true
batchnorm: false
bn_momentum: 0.95
pooling_stride: 8
pooling_mode: avgpool
state_expansion_factor: 1
24 changes: 24 additions & 0 deletions configs/model/shd/tiny.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# @package _global_

model:
ssm_init:
C_init: lecun_normal
dt_min: 0.004
dt_max: 0.1
conj_sym: false
clip_eigs: false
ssm:
discretization: async
d_model: 16
d_ssm: 16
ssm_block_size: 8
num_stages: 1
num_layers_per_stage: 6
dropout: 0.1
classification_mode: timepool
prenorm: true
batchnorm: false
bn_momentum: 0.95
pooling_stride: 32
pooling_mode: timepool
state_expansion_factor: 1
Loading

0 comments on commit f912fca

Please sign in to comment.