diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..328479f --- /dev/null +++ b/.gitignore @@ -0,0 +1,166 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+.idea/
+*.pyc
+
+# S5 specific stuff
+wandb/
+cache_dir/
+raw_datasets/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..6ff54f8
--- /dev/null
+++ b/README.md
@@ -0,0 +1,120 @@
+# Scalable Event-by-event Processing of Neuromorphic Sensory Signals With Deep State-Space Models
+![Figure 1](docs/figure1.png)
+This is the official implementation of our paper [Scalable Event-by-event Processing of Neuromorphic Sensory Signals With Deep State-Space Models](https://arxiv.org/abs/2404.18508).
+The core motivation for this work was the irregular time-series modeling problem presented in the paper [Simplified State Space Layers for Sequence Modeling](https://arxiv.org/abs/2208.04933).
+We acknowledge the awesome [S5 project](https://github.com/lindermanlab/S5) and the trainer class provided by this [UvA tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/guide4/Research_Projects_with_JAX.html), both of which strongly influenced our code.
+
+Our project addresses a quite general machine learning problem:
+modeling **long sequences** that are **irregularly** sampled by a possibly large number of **asynchronous** sensors.
+This problem is particularly prevalent in the field of neuromorphic computing, where event-based sensors emit up to millions of events per second from asynchronous channels.
+
+We show how linear state-space models can be tuned to effectively model asynchronous event-based sequences.
+Our contributions are:
+- integration of Dirac delta coded event streams
+- time-invariant input normalization to effectively learn from long event streams
+- formulating neuromorphic event streams as a language modeling problem with **asynchronous tokens**
+- modeling event-based vision effectively, **without frames and without CNNs**
+
+## Installation
+The project is implemented in [JAX](https://github.com/google/jax) with [Flax](https://flax.readthedocs.io/en/latest/).
+By default, we install JAX with GPU support, which requires CUDA >= 12.0.
+To install JAX for CPU, replace `jax[cuda]` with `jax[cpu]` in the `requirements.txt` file.
+PyTorch is only required for loading data, so we install only its CPU version.
+Install the requirements with
+```bash
+pip install -r requirements.txt
+```
+Then install this repository with
+```bash
+pip install -e .
+```
+We tested JAX versions between `0.4.20` and `0.4.29`.
+Different CUDA and JAX versions might result in slightly different results.
+
+## Reproducing experiments
+We use the [hydra](https://hydra.cc/docs/intro/) package to manage configurations.
+If you are not familiar with hydra, we recommend reading the [documentation](https://hydra.cc/docs/intro/).
+
+### Run benchmark tasks
+The basic command to run an experiment is
+```bash
+python run_training.py
+```
+By default, this runs the Spiking Heidelberg Digits (SHD) task.
+All benchmark tasks are defined by the configurations in `configs/tasks/` and can be run by specifying the `task` argument.
+For example, run the Spiking Speech Commands (SSC) task with
+```bash
+python run_training.py task=spiking-speech-commands
+```
+or the DVS128 Gestures task with
+```bash
+python run_training.py task=dvs-gesture
+```
+
+### Trained models
+We provide our best models for [download](https://datashare.tu-dresden.de/s/g2dQCi792B8DqnC).
+Check out the `tutorial_inference.ipynb` notebook to see how to load and run inference with these models.
+We also provide a script to evaluate the models on the test set:
+```bash
+python run_evaluation.py task=spiking-speech-commands checkpoint=downloaded/model/SSC
+```
+
+
+### Specify HPC system and logging
+Many researchers operate on different HPC systems and perhaps log their experiments to multiple platforms.
+Therefore, the user can specify configurations for
+- different systems (directories for reading data and saving outputs)
+- logging methods (e.g. whether to log locally or to [wandb](https://wandb.ai/))
+
+By default, the `configs/system/local.yaml` and `configs/logging/local.yaml` configurations are used, respectively.
+We suggest creating new configs for the HPC systems and wandb projects you are using.
+
+For example, to run the model on SSC with your custom wandb logging config and your custom HPC specification, run
+```bash
+python run_training.py task=spiking-speech-commands logging=wandb system=hpc
+```
+where `configs/logging/wandb.yaml` should look like
+```yaml
+log_dir: ${output_dir}
+interval: 1000
+wandb: True
+summary_metric: "Performance/Validation accuracy"
+project: wandb_project_name
+entity: wandb_entity_name
+```
+and `configs/system/hpc.yaml` should specify data and output directories
+```yaml
+# @package _global_
+
+data_dir: my/fast/storage/location/data
+output_dir: my/job/output/location/${task.name}/${oc.env:SLURM_JOB_ID}/${now:%Y-%m-%d-%H-%M-%S}
+```
+The string `${task.name}/${oc.env:SLURM_JOB_ID}/${now:%Y-%m-%d-%H-%M-%S}` will create subdirectories named by task, SLURM job ID, and date,
+which we found useful in practice, though this particular `output_dir` layout is not required.
+
+## Tutorials
+To get started with event-based state-space models, we created tutorials for training and inference.
+- `tutorial_training.ipynb` shows how to train a model on a reduced version of the Spiking Heidelberg Digits with just two classes. The model converges within a few minutes on CPU.
+- `tutorial_inference.ipynb` shows how to load a trained model and run inference. The models are available for download from the provided [download link](https://datashare.tu-dresden.de/s/g2dQCi792B8DqnC).
+- `tutorial_online_inference.ipynb` runs event-by-event inference with batch size one (online inference) on the DVS128 Gestures dataset and measures the throughput of the model.
+
+## Help and support
+We are eager to help you with any questions or issues you might have.
+Please use the GitHub issue tracker for questions and to report issues.
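+
+Any other configuration value can be overridden on the command line in the same way. As a sketch (the override keys below are hypothetical examples; the valid keys follow the composed config, e.g. the fields in `configs/model/` and the `optimizer` block of the task configs):
+```bash
+python run_training.py task=spiking-heidelberg-digits model.ssm.d_model=64 optimizer.ssm_base_lr=2e-5
+```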
+ +## Citation +Please use the following when citing our work: +``` +@misc{Schoene2024, + title={Scalable Event-by-event Processing of Neuromorphic Sensory Signals With Deep State-Space Models}, + author={Mark Schöne and Neeraj Mohan Sushma and Jingyue Zhuge and Christian Mayr and Anand Subramoney and David Kappel}, + year={2024}, + eprint={2404.18508}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` diff --git a/configs/base.yaml b/configs/base.yaml new file mode 100644 index 0000000..af73c99 --- /dev/null +++ b/configs/base.yaml @@ -0,0 +1,12 @@ +defaults: + - _self_ + - system: local + - task: spiking-heidelberg-digits + - logging: local + +seed: 1234 +checkpoint: null + +hydra: + run: + dir: ${output_dir}/hydra-outputs/${now:%Y-%m-%d-%H-%M-%S} \ No newline at end of file diff --git a/configs/logging/local.yaml b/configs/logging/local.yaml new file mode 100644 index 0000000..9d610e1 --- /dev/null +++ b/configs/logging/local.yaml @@ -0,0 +1,6 @@ +log_dir: ${output_dir} +interval: 1000 +wandb: False +summary_metric: "Performance/Validation accuracy" +project: ??? +entity: ??? \ No newline at end of file diff --git a/configs/model/dvs/small.yaml b/configs/model/dvs/small.yaml new file mode 100644 index 0000000..6bcb069 --- /dev/null +++ b/configs/model/dvs/small.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +model: + ssm_init: + C_init: lecun_normal + dt_min: 0.001 + dt_max: 0.1 + conj_sym: false + clip_eigs: true + ssm: + discretization: async + d_model: 128 + d_ssm: 128 + ssm_block_size: 16 + num_stages: 2 + num_layers_per_stage: 3 + dropout: 0.25 + classification_mode: timepool + prenorm: true + batchnorm: false + bn_momentum: 0.95 + pooling_stride: 16 + pooling_mode: timepool + state_expansion_factor: 2 diff --git a/configs/model/shd/medium.yaml b/configs/model/shd/medium.yaml new file mode 100644 index 0000000..a5ff46e --- /dev/null +++ b/configs/model/shd/medium.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +model: + ssm_init: + C_init: lecun_normal + dt_min: 0.004 + dt_max: 0.1 + conj_sym: false + clip_eigs: false + ssm: + discretization: async + d_model: 96 + d_ssm: 128 + ssm_block_size: 8 + num_stages: 2 + num_layers_per_stage: 3 + dropout: 0.23 + classification_mode: pool + prenorm: true + batchnorm: false + bn_momentum: 0.95 + pooling_stride: 8 + pooling_mode: avgpool + state_expansion_factor: 1 \ No newline at end of file diff --git a/configs/model/shd/tiny.yaml b/configs/model/shd/tiny.yaml new file mode 100644 index 0000000..00da6fe --- /dev/null +++ b/configs/model/shd/tiny.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +model: + ssm_init: + C_init: lecun_normal + dt_min: 0.004 + dt_max: 0.1 + conj_sym: false + clip_eigs: false + ssm: + discretization: async + d_model: 16 + d_ssm: 16 + ssm_block_size: 8 + num_stages: 1 + num_layers_per_stage: 6 + dropout: 0.1 + classification_mode: timepool + prenorm: true + batchnorm: false + bn_momentum: 0.95 + pooling_stride: 32 + pooling_mode: timepool + state_expansion_factor: 1 diff --git a/configs/model/ssc/medium.yaml b/configs/model/ssc/medium.yaml new file mode 100644 index 0000000..3bc6bdd --- /dev/null +++ b/configs/model/ssc/medium.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +model: + ssm_init: + C_init: lecun_normal + dt_min: 0.0015 + dt_max: 0.1 + conj_sym: true + clip_eigs: false + ssm: + discretization: async + d_model: 96 + d_ssm: 128 + ssm_block_size: 16 + num_stages: 2 + num_layers_per_stage: 3 + dropout: 0.23 + classification_mode: pool + prenorm: true + batchnorm: true + bn_momentum: 0.95 + pooling_stride: 8 + 
pooling_mode: avgpool + state_expansion_factor: 2 diff --git a/configs/model/ssc/small.yaml b/configs/model/ssc/small.yaml new file mode 100644 index 0000000..d41c394 --- /dev/null +++ b/configs/model/ssc/small.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +model: + ssm_init: + C_init: lecun_normal + dt_min: 0.002 + dt_max: 0.1 + conj_sym: true + clip_eigs: false + ssm: + discretization: async + d_model: 64 + d_ssm: 64 + ssm_block_size: 8 + num_stages: 1 + num_layers_per_stage: 6 + dropout: 0.27 + classification_mode: timepool + prenorm: true + batchnorm: true + bn_momentum: 0.95 + pooling_stride: 8 + pooling_mode: timepool + state_expansion_factor: 1 \ No newline at end of file diff --git a/configs/system/local.yaml b/configs/system/local.yaml new file mode 100644 index 0000000..eefb5df --- /dev/null +++ b/configs/system/local.yaml @@ -0,0 +1,5 @@ +# @package _global_ + +data_dir: ./data +output_dir: ./outputs/${now:%Y-%m-%d-%H-%M-%S} +checkpoint_dir: ./checkpoints \ No newline at end of file diff --git a/configs/task/dvs-gesture.yaml b/configs/task/dvs-gesture.yaml new file mode 100644 index 0000000..0b95657 --- /dev/null +++ b/configs/task/dvs-gesture.yaml @@ -0,0 +1,34 @@ +# @package _global_ +defaults: + - /model: dvs/small + +task: + name: dvs-gesture-classification + +training: + num_epochs: 100 + per_device_batch_size: 16 + per_device_eval_batch_size: 4 + num_workers: 4 + time_jitter: 5 + spatial_jitter: 0.3 + noise: 0.0 + drop_event: 0.05 + time_skew: 1.12 + max_roll: 32 + max_angle: 10 + max_scale: 1.2 + max_drop_chunk: 0.02 + cut_mix: 0.4 + pad_unit: 524288 + slice_events: 65536 + validate_on_test: true + +optimizer: + ssm_base_lr: 0.000012 + lr_factor: 6 + warmup_epochs: 10 + ssm_weight_decay: 0.0 + weight_decay: 0.02 + schedule: cosine + accumulation_steps: 4 \ No newline at end of file diff --git a/configs/task/spiking-heidelberg-digits.yaml b/configs/task/spiking-heidelberg-digits.yaml new file mode 100644 index 0000000..825a622 --- /dev/null +++ b/configs/task/spiking-heidelberg-digits.yaml @@ -0,0 +1,30 @@ +# @package _global_ +defaults: + - /model: shd/medium + +task: + name: shd-classification + +training: + num_epochs: 30 + per_device_batch_size: 32 + per_device_eval_batch_size: 128 + num_workers: 4 + time_jitter: 1 + spatial_jitter: 0.55 + noise: 35 + max_drop_chunk: 0.02 + drop_event: 0.1 + time_skew: 1.2 + cut_mix: 0.3 + pad_unit: 8192 + validate_on_test: true + +optimizer: + ssm_base_lr: 1.7e-5 + lr_factor: 10 + warmup_epochs: 3 + ssm_weight_decay: 0.0 + weight_decay: 0.03 + schedule: cosine + accumulation_steps: 1 \ No newline at end of file diff --git a/configs/task/spiking-speech-commands.yaml b/configs/task/spiking-speech-commands.yaml new file mode 100644 index 0000000..2f21d91 --- /dev/null +++ b/configs/task/spiking-speech-commands.yaml @@ -0,0 +1,29 @@ +# @package _global_ +defaults: + - /model: ssc/medium + +task: + name: ssc-classification + +training: + num_epochs: 200 + per_device_batch_size: 64 + per_device_eval_batch_size: 128 + num_workers: 4 + time_jitter: 3 + spatial_jitter: 1.0 + noise: 100 + drop_event: 0.1 + max_drop_chunk: 0.02 + cut_mix: 0.3 + time_skew: 1.05 + pad_unit: 8192 + +optimizer: + ssm_base_lr: 0.000005 + lr_factor: 5 + warmup_epochs: 20 + ssm_weight_decay: 0.0 + weight_decay: 0.05 + schedule: cosine + accumulation_steps: 1 \ No newline at end of file diff --git a/configs/task/tutorial.yaml b/configs/task/tutorial.yaml new file mode 100644 index 0000000..f22c230 --- /dev/null +++ b/configs/task/tutorial.yaml @@ -0,0 +1,30 @@ +# 
@package _global_
+defaults:
+  - /model: shd/tiny
+
+task:
+  name: shd-classification
+
+training:
+  num_epochs: 5
+  per_device_batch_size: 16
+  per_device_eval_batch_size: 16
+  num_workers: 4
+  time_jitter: 1
+  spatial_jitter: 0.55
+  noise: 35
+  max_drop_chunk: 0.02
+  drop_event: 0.1
+  time_skew: 1.2
+  cut_mix: 0.3
+  pad_unit: 8192
+  validate_on_test: false
+
+optimizer:
+  ssm_base_lr: 5e-5
+  lr_factor: 10
+  warmup_epochs: 1
+  ssm_weight_decay: 0.0
+  weight_decay: 0.01
+  schedule: cosine
+  accumulation_steps: 1
\ No newline at end of file
diff --git a/docs/figure1.png b/docs/figure1.png
new file mode 100644
index 0000000..246116c
Binary files /dev/null and b/docs/figure1.png differ
diff --git a/event_ssm/__init__.py b/event_ssm/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/event_ssm/dataloading.py b/event_ssm/dataloading.py
new file mode 100644
index 0000000..06fd699
--- /dev/null
+++ b/event_ssm/dataloading.py
@@ -0,0 +1,444 @@
+import torch
+from pathlib import Path
+from typing import Callable, Optional, TypeVar, Dict, Tuple, List, Union
+import tonic
+from functools import partial
+import numpy as np
+from event_ssm.transform import Identity, Roll, Rotate, Scale, DropEventChunk, Jitter1D, OneHotLabels, cut_mix_augmentation
+
+DEFAULT_CACHE_DIR_ROOT = Path('./cache_dir/')
+
+DataLoader = TypeVar('DataLoader')
+InputType = [str, Optional[int], Optional[int]]
+
+
+class Data:
+    """
+    Data class for storing dataset specific information
+    """
+    def __init__(
+            self,
+            n_classes: int,
+            num_embeddings: int,
+            train_size: int
+    ):
+        self.n_classes = n_classes
+        self.num_embeddings = num_embeddings
+        self.train_size = train_size
+
+
+def event_stream_collate_fn(batch, resolution, pad_unit, cut_mix=0.0, no_time_information=False):
+    """
+    Collate function to turn event stream data into tokens ready for the JAX model
+
+    :param batch: list of tuples of (events, target)
+    :param resolution: resolution of the event stream
+    :param pad_unit: padding unit for the tokens. All sequences will be padded to integer multiples of this unit.
+                     This option results in JAX compiling multiple GPU kernels for different sequence lengths,
+                     which might increase compilation time, but improves throughput for the rest of the training process.
+    :param cut_mix: probability of applying cut mix augmentation
+    :param no_time_information: if True, the time information is ignored and all events are treated as if they were
+                                sampled at uniform time intervals.
+                                This option is only used for ablation studies.
+    """
+    # x are inputs, y are targets, z are aux data
+    x, y, *z = zip(*batch)
+    assert len(z) == 0
+    batch_size_one = len(x) == 1
+
+    # apply cut mix augmentation
+    if np.random.rand() < cut_mix:
+        x, y = cut_mix_augmentation(x, y)
+
+    # set labels to numpy array
+    y = np.stack(y)
+
+    # integration time steps are the difference between two consecutive time stamps
+    if no_time_information:
+        timesteps = [np.ones_like(e['t'][:-1]) for e in x]
+    else:
+        timesteps = [np.diff(e['t']) for e in x]
+
+    # NOTE: since timesteps are deltas, their length is L - 1, and we have to remove the last token in the following
+    # process tokens for single input dim (e.g. audio)
+    if len(resolution) == 1:
+        tokens = [e['x'][:-1].astype(np.int32) for e in x]
+    elif len(resolution) == 2:
+        tokens = [(e['x'][:-1] * e['y'][:-1] + np.prod(resolution) * e['p'][:-1].astype(np.int32)).astype(np.int32) for e in x]
+    else:
+        raise ValueError('resolution must contain 1 or 2 elements')
+
+    # get padding lengths
+    lengths = np.array([len(e) for e in timesteps], dtype=np.int32)
+    pad_length = (lengths.max() // pad_unit) * pad_unit + pad_unit
+
+    # pad tokens with -1, which results in a zero vector with embedding look-ups
+    tokens = np.stack(
+        [np.pad(e, (0, pad_length - len(e)), mode='constant', constant_values=-1) for e in tokens])
+    timesteps = np.stack(
+        [np.pad(e, (0, pad_length - len(e)), mode='constant', constant_values=0) for e in timesteps])
+
+    # timesteps are in microseconds: convert to milliseconds
+    timesteps = timesteps / 1000
+
+    if batch_size_one:
+        lengths = lengths[None, ...]
+
+    return tokens, y, timesteps, lengths
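+
+# Example: with pad_unit=8192, a batch whose longest sequence has 5000 events is
+# padded to 8192 tokens, and one with 9000 events to 16384, so JAX only compiles
+# kernels for a small set of distinct sequence lengths.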
+
+
+def event_stream_dataloader(
+        train_data,
+        val_data,
+        test_data,
+        batch_size,
+        eval_batch_size,
+        train_collate_fn,
+        eval_collate_fn,
+        rng,
+        num_workers=0,
+        shuffle_training=True
+):
+    """
+    Create dataloaders for training, validation and testing
+
+    :param train_data: training dataset
+    :param val_data: validation dataset
+    :param test_data: test dataset
+    :param batch_size: batch size for training
+    :param eval_batch_size: batch size for evaluation
+    :param train_collate_fn: collate function for training
+    :param eval_collate_fn: collate function for evaluation
+    :param rng: random number generator
+    :param num_workers: number of workers for data loading
+    :param shuffle_training: whether to shuffle the training data
+
+    :return: train_loader, val_loader, test_loader
+    """
+    def dataloader(dset, bsz, collate_fn, shuffle, drop_last):
+        return torch.utils.data.DataLoader(
+            dset,
+            batch_size=bsz,
+            drop_last=drop_last,
+            collate_fn=collate_fn,
+            shuffle=shuffle,
+            generator=rng,
+            num_workers=num_workers
+        )
+    train_loader = dataloader(train_data, batch_size, train_collate_fn, shuffle=shuffle_training, drop_last=True)
+    val_loader = dataloader(val_data, eval_batch_size, eval_collate_fn, shuffle=False, drop_last=True)
+    test_loader = dataloader(test_data, eval_batch_size, eval_collate_fn, shuffle=False, drop_last=False)
+    return train_loader, val_loader, test_loader
+
+
+def create_events_shd_classification_dataset(
+        cache_dir: Union[str, Path] = DEFAULT_CACHE_DIR_ROOT,
+        per_device_batch_size: int = 32,
+        per_device_eval_batch_size: int = 64,
+        world_size: int = 1,
+        num_workers: int = 0,
+        seed: int = 42,
+        time_jitter: float = 100,
+        spatial_jitter: float = 1.0,
+        max_drop_chunk: float = 0.1,
+        noise: int = 100,
+        drop_event: float = 0.1,
+        time_skew: float = 1.1,
+        cut_mix: float = 0.5,
+        pad_unit: int = 8192,
+        validate_on_test: bool = False,
+        no_time_information: bool = False,
+        **kwargs
+) -> Tuple[DataLoader, DataLoader, DataLoader, Data]:
+    """
+    Creates a view of the Spiking Heidelberg Digits dataset
+
+    :param cache_dir: (str) where to store the dataset
+    :param per_device_batch_size: (int) Training batch size per device.
+    :param per_device_eval_batch_size: (int) Evaluation batch size per device.
+    :param seed: (int) Seed for shuffling data.
+    :param time_jitter: (float) Standard deviation of the time jitter.
+    :param spatial_jitter: (float) Standard deviation of the spatial jitter.
+    :param max_drop_chunk: (float) Maximum fraction of events to drop in a single chunk.
+    :param noise: (int) Number of noise events to add.
+    :param drop_event: (float) Probability of dropping an event.
+    :param time_skew: (float) Time skew factor.
+    :param cut_mix: (float) Probability of applying cut mix augmentation.
+    :param pad_unit: (int) Padding unit for the tokens. See collate function for more details.
+    :param validate_on_test: (bool) If True, use the test set for validation.
+                             Else use a random validation split from the training set.
+    :param no_time_information: (bool) Whether to ignore the time information in the events.
+
+    :return: train_loader, val_loader, test_loader, data
+    """
+    print("[*] Generating Spiking Heidelberg Digits Classification Dataset")
+
+    if seed is not None:
+        rng = torch.Generator()
+        rng.manual_seed(seed)
+    else:
+        rng = None
+
+    sensor_size = (700, 1, 1)
+
+    transforms = tonic.transforms.Compose([
+        tonic.transforms.DropEvent(p=drop_event),
+        DropEventChunk(p=0.3, max_drop_size=max_drop_chunk),
+        Jitter1D(sensor_size=sensor_size, var=spatial_jitter),
+        tonic.transforms.TimeSkew(coefficient=(1 / time_skew, time_skew), offset=0),
+        tonic.transforms.TimeJitter(std=time_jitter, clip_negative=False, sort_timestamps=True),
+        tonic.transforms.UniformNoise(sensor_size=sensor_size, n=(0, noise))
+    ])
+    target_transforms = OneHotLabels(num_classes=20)
+
+    train_data = tonic.datasets.SHD(save_to=cache_dir, train=True, transform=transforms, target_transform=target_transforms)
+    val_data = tonic.datasets.SHD(save_to=cache_dir, train=True, target_transform=target_transforms)
+    test_data = tonic.datasets.SHD(save_to=cache_dir, train=False, target_transform=target_transforms)
+
+    # create validation set
+    if validate_on_test:
+        print("[*] WARNING: Using test set for validation")
+        val_data = tonic.datasets.SHD(save_to=cache_dir, train=False, target_transform=target_transforms)
+    else:
+        val_length = int(0.1 * len(train_data))
+        indices = torch.randperm(len(train_data), generator=rng)
+        train_data = torch.utils.data.Subset(train_data, indices[:-val_length])
+        val_data = torch.utils.data.Subset(val_data, indices[-val_length:])
+
+    collate_fn = partial(event_stream_collate_fn, resolution=(700,), pad_unit=pad_unit, no_time_information=no_time_information)
+    train_loader, val_loader, test_loader = event_stream_dataloader(
+        train_data, val_data, test_data,
+        train_collate_fn=partial(collate_fn, cut_mix=cut_mix),
+        eval_collate_fn=collate_fn,
+        batch_size=per_device_batch_size * world_size, eval_batch_size=per_device_eval_batch_size * world_size,
+        rng=rng, num_workers=num_workers, shuffle_training=True
+    )
+    data = Data(
+        n_classes=20, num_embeddings=700, train_size=len(train_data)
+    )
+    return train_loader, val_loader, test_loader, data
+
+
+def create_events_ssc_classification_dataset(
+        cache_dir: Union[str, Path] = DEFAULT_CACHE_DIR_ROOT,
+        per_device_batch_size: int = 32,
+        per_device_eval_batch_size: int = 64,
+        world_size: int = 1,
+        num_workers: int = 0,
+        seed: int = 42,
+        time_jitter: float = 100,
+        spatial_jitter: float = 1.0,
+        max_drop_chunk: float = 0.1,
+        noise: int = 100,
+        drop_event: float = 0.1,
+        time_skew: float = 1.1,
+        cut_mix: float = 0.5,
+        pad_unit: int = 8192,
+        no_time_information: bool = False,
+        **kwargs
+) -> Tuple[DataLoader, DataLoader, DataLoader, Data]:
+    """
+    Creates a view of the Spiking Speech Commands dataset
+
+    :param cache_dir: (str) where to store the dataset
+    :param per_device_batch_size: (int) Training batch size per device.
+    :param per_device_eval_batch_size: (int) Evaluation batch size per device.
+    :param seed: (int) Seed for shuffling data.
+    :param time_jitter: (float) Standard deviation of the time jitter.
+    :param spatial_jitter: (float) Standard deviation of the spatial jitter.
+    :param max_drop_chunk: (float) Maximum fraction of events to drop in a single chunk.
+    :param noise: (int) Number of noise events to add.
+    :param drop_event: (float) Probability of dropping an event.
+    :param time_skew: (float) Time skew factor.
+    :param cut_mix: (float) Probability of applying cut mix augmentation.
+    :param pad_unit: (int) Padding unit for the tokens. See collate function for more details.
+    :param no_time_information: (bool) Whether to ignore the time information in the events.
+
+    :return: train_loader, val_loader, test_loader, data
+    """
+    print("[*] Generating Spiking Speech Commands Classification Dataset")
+
+    if seed is not None:
+        rng = torch.Generator()
+        rng.manual_seed(seed)
+    else:
+        rng = None
+
+    sensor_size = (700, 1, 1)
+
+    transforms = tonic.transforms.Compose([
+        tonic.transforms.DropEvent(p=drop_event),
+        DropEventChunk(p=0.3, max_drop_size=max_drop_chunk),
+        Jitter1D(sensor_size=sensor_size, var=spatial_jitter),
+        tonic.transforms.TimeSkew(coefficient=(1 / time_skew, time_skew), offset=0),
+        tonic.transforms.TimeJitter(std=time_jitter, clip_negative=False, sort_timestamps=True),
+        tonic.transforms.UniformNoise(sensor_size=sensor_size, n=(0, noise))
+    ])
+    target_transforms = OneHotLabels(num_classes=35)
+
+    train_data = tonic.datasets.SSC(save_to=cache_dir, split='train', transform=transforms, target_transform=target_transforms)
+    val_data = tonic.datasets.SSC(save_to=cache_dir, split='valid', target_transform=target_transforms)
+    test_data = tonic.datasets.SSC(save_to=cache_dir, split='test', target_transform=target_transforms)
+
+    collate_fn = partial(event_stream_collate_fn, resolution=(700,), pad_unit=pad_unit, no_time_information=no_time_information)
+    train_loader, val_loader, test_loader = event_stream_dataloader(
+        train_data, val_data, test_data,
+        train_collate_fn=partial(collate_fn, cut_mix=cut_mix),
+        eval_collate_fn=collate_fn,
+        batch_size=per_device_batch_size * world_size, eval_batch_size=per_device_eval_batch_size * world_size,
+        rng=rng, num_workers=num_workers, shuffle_training=True
+    )
+
+    data = Data(
+        n_classes=35, num_embeddings=700, train_size=len(train_data)
+    )
+    return train_loader, val_loader, test_loader, data
+
+
+def create_events_dvs_gesture_classification_dataset(
+        cache_dir: Union[str, Path] = DEFAULT_CACHE_DIR_ROOT,
+        per_device_batch_size: int = 32,
+        per_device_eval_batch_size: int = 64,
+        world_size: int = 1,
+        num_workers: int = 0,
+        seed: int = 42,
+        slice_events: int = 0,
+        pad_unit: int = 2 ** 19,
+        # Augmentation parameters
+        time_jitter: float = 100,
+        spatial_jitter: float = 1.0,
+        noise: int = 100,
+        drop_event: float = 0.1,
+        time_skew: float = 1.1,
+        cut_mix: float = 0.5,
+        downsampling: int = 1,
+        max_roll: int = 4,
+        max_angle: float = 10,
+        max_scale: float = 1.5,
+        max_drop_chunk: float = 0.1,
+        validate_on_test: bool = False,
+        **kwargs
+) -> Tuple[DataLoader, DataLoader, DataLoader, Data]:
+    """
+    Creates a view of the DVS128 Gesture dataset
+
+    :param cache_dir: (str) where to store the dataset
+    :param per_device_batch_size: (int) Training batch size per device.
+    :param per_device_eval_batch_size: (int) Evaluation batch size per device.
+    :param seed: (int) Seed for shuffling data.
+    :param slice_events: (int) Number of events per slice.
+    :param pad_unit: (int) Padding unit for the tokens. See collate function for more details.
+    :param time_jitter: (float) Standard deviation of the time jitter.
+    :param spatial_jitter: (float) Standard deviation of the spatial jitter.
+    :param noise: (int) Number of noise events to add.
+    :param drop_event: (float) Probability of dropping an event.
+    :param time_skew: (float) Time skew factor.
+    :param cut_mix: (float) Probability of applying cut mix augmentation.
+    :param downsampling: (int) Downsampling factor.
+    :param max_roll: (int) Maximum number of pixels to roll the events.
+    :param max_angle: (float) Maximum angle to rotate the events.
+    :param max_scale: (float) Maximum scale factor to scale the events.
+    :param max_drop_chunk: (float) Maximum fraction of events to drop in a single chunk.
+    :param validate_on_test: (bool) If True, use the test set for validation.
+                             Else use a random validation split from the training set.
+
+    :return: train_loader, val_loader, test_loader, data
+    """
+    print("[*] Generating DVS Gesture Classification Dataset")
+
+    assert time_skew > 1, "time_skew must be greater than 1"
+
+    if seed is not None:
+        rng = torch.Generator()
+        rng.manual_seed(seed)
+    else:
+        rng = None
+
+    orig_sensor_size = (128, 128, 2)
+    new_sensor_size = (128 // downsampling, 128 // downsampling, 2)
+    train_transforms = [
+        # Event transformations
+        DropEventChunk(p=0.3, max_drop_size=max_drop_chunk),
+        tonic.transforms.DropEvent(p=drop_event),
+        tonic.transforms.UniformNoise(sensor_size=new_sensor_size, n=(0, noise)),
+        # Time transformations
+        tonic.transforms.TimeSkew(coefficient=(1 / time_skew, time_skew), offset=0),
+        tonic.transforms.TimeJitter(std=time_jitter, clip_negative=False, sort_timestamps=True),
+        # Spatial transformations
+        tonic.transforms.SpatialJitter(sensor_size=orig_sensor_size, var_x=spatial_jitter, var_y=spatial_jitter, clip_outliers=True),
+        tonic.transforms.Downsample(sensor_size=orig_sensor_size, target_size=new_sensor_size[:2]) if downsampling > 1 else Identity(),
+        # Geometric transformations
+        Roll(sensor_size=new_sensor_size, p=0.3, max_roll=max_roll),
+        Rotate(sensor_size=new_sensor_size, p=0.3, max_angle=max_angle),
+        Scale(sensor_size=new_sensor_size, p=0.3, max_scale=max_scale),
+    ]
+
+    train_transforms = tonic.transforms.Compose(train_transforms)
+    test_transforms = tonic.transforms.Compose([
+        tonic.transforms.Downsample(sensor_size=orig_sensor_size, target_size=new_sensor_size[:2]) if downsampling > 1 else Identity(),
+    ])
+    target_transforms = OneHotLabels(num_classes=11)
+
+    TrainData = partial(tonic.datasets.DVSGesture, save_to=cache_dir, train=True)
+    TestData = partial(tonic.datasets.DVSGesture, save_to=cache_dir, train=False)
+
+    # create validation set
+    if validate_on_test:
+        print("[*] WARNING: Using test set for validation")
+        val_data = TestData(transform=test_transforms, target_transform=target_transforms)
+    else:
+        # create train validation split
+        val_data = TrainData(transform=test_transforms, target_transform=target_transforms)
+        val_length = int(0.2 * len(val_data))
+        indices = torch.randperm(len(val_data), generator=rng)
+        val_data = torch.utils.data.Subset(val_data, indices[-val_length:])
+
+    # if slice event count is given, train on slices of the training data
+    if slice_events > 0:
+        slicer = tonic.slicers.SliceByEventCount(event_count=slice_events, overlap=slice_events // 2, include_incomplete=True)
+        train_subset = torch.utils.data.Subset(TrainData(), indices[:-val_length]) if not validate_on_test else TrainData()
+        train_data = tonic.sliced_dataset.SlicedDataset(
+            dataset=train_subset,
+            slicer=slicer,
+            transform=train_transforms,
+            target_transform=target_transforms,
+            metadata_path=None
+        )
+    else:
+        train_data = torch.utils.data.Subset(
+            TrainData(transform=train_transforms, target_transform=target_transforms),
+            indices[:-val_length]
+        ) if not validate_on_test else TrainData(transform=train_transforms, target_transform=target_transforms)
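+    # NOTE: with overlap = slice_events // 2, consecutive slices share half of their
+    # events, so slicing roughly doubles the number of (shorter) training samples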
+
+    # Always evaluate on the full sequences
+    test_data = TestData(transform=test_transforms, target_transform=target_transforms)
+
+    # define collate functions
+    train_collate_fn = partial(
+        event_stream_collate_fn,
+        resolution=new_sensor_size[:2],
+        pad_unit=slice_events if (slice_events != 0 and slice_events < pad_unit) else pad_unit,
+        cut_mix=cut_mix
+    )
+    eval_collate_fn = partial(
+        event_stream_collate_fn,
+        resolution=new_sensor_size[:2],
+        pad_unit=pad_unit,
+    )
+    train_loader, val_loader, test_loader = event_stream_dataloader(
+        train_data, val_data, test_data,
+        train_collate_fn=train_collate_fn,
+        eval_collate_fn=eval_collate_fn,
+        batch_size=per_device_batch_size * world_size, eval_batch_size=per_device_eval_batch_size * world_size,
+        rng=rng, num_workers=num_workers, shuffle_training=True
+    )
+
+    data = Data(
+        n_classes=11, num_embeddings=np.prod(new_sensor_size), train_size=len(train_data)
+    )
+    return train_loader, val_loader, test_loader, data
+
+
+Datasets = {
+    "shd-classification": create_events_shd_classification_dataset,
+    "ssc-classification": create_events_ssc_classification_dataset,
+    "dvs-gesture-classification": create_events_dvs_gesture_classification_dataset,
+}
diff --git a/event_ssm/layers.py b/event_ssm/layers.py
new file mode 100644
index 0000000..b740359
--- /dev/null
+++ b/event_ssm/layers.py
@@ -0,0 +1,198 @@
+from flax import linen as nn
+import jax
+from functools import partial
+
+
+class EventPooling(nn.Module):
+    """
+    Subsampling layer for event sequences.
+    """
+    stride: int = 1
+    mode: str = "last"
+    eps: float = 1e-6
+
+    def __call__(self, x, integration_timesteps):
+        """
+        Compute the pooled (L/stride)xH output given an LxH input.
+
+        :param x: input sequence (L, d_model)
+        :param integration_timesteps: the integration timesteps for the SSM
+        :return: output sequence (L/stride, d_model)
+        """
+        if self.stride == 1:
+            raise ValueError("Stride 1 not supported for pooling")
+
+        else:
+            remaining_timesteps = (len(integration_timesteps) // self.stride) * self.stride
+            new_integration_timesteps = integration_timesteps[:remaining_timesteps].reshape(-1, self.stride).sum(axis=1)
+            x = x[:remaining_timesteps]
+            d_model = x.shape[-1]
+
+            if self.mode == 'last':
+                x = x[::self.stride]
+                return x, new_integration_timesteps
+            elif self.mode == 'avgpool':
+                x = x.reshape(-1, self.stride, d_model).mean(axis=1)
+                return x, new_integration_timesteps
+            elif self.mode == 'timepool':
+                weight = integration_timesteps[:remaining_timesteps, None] + self.eps
+                x = (x * weight).reshape(-1, self.stride, d_model).sum(axis=1)
+                x = x / weight.reshape(-1, self.stride, 1).sum(axis=1)
+                return x, new_integration_timesteps
+            else:
+                raise NotImplementedError("Pooling mode: {} not implemented".format(self.mode))
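+
+# Example: with stride=2 and integration_timesteps [1, 3, 2, 2], the pooled
+# timesteps are [4, 4]. 'avgpool' averages the two feature vectors in each window
+# equally, while 'timepool' weights each vector by its own integration time.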
+
+
+class SequenceStage(nn.Module):
+    """
+    Defines a block of EventSSM layers with the same hidden size and event-resolution
+
+    :param ssm: the SSM to be used (i.e. S5 ssm)
+    :param discretization: the discretization method to use (zoh, dirac, async)
+    :param d_model_in: the feature size of the layer inputs; we usually refer to this size as H
+    :param d_model_out: the feature size of the layer outputs
+    :param d_ssm: the size of the state space model
+    :param ssm_block_size: the block size of the state space model
+    :param layers_per_stage: the number of S5 layers to stack
+    :param dropout: dropout rate
+    :param prenorm: whether to use layernorm before the module or after it
+    :param batchnorm: If True, use batchnorm instead of layernorm
+    :param bn_momentum: momentum for batchnorm
+    :param step_rescale: rescale the integration timesteps by this factor
+    :param pooling_stride: stride for pooling
+    :param pooling_mode: pooling mode (last, avgpool, timepool)
+    :param state_expansion_factor: factor to expand the state space model
+    """
+    ssm: nn.Module
+    discretization: str
+    d_model_in: int
+    d_model_out: int
+    d_ssm: int
+    ssm_block_size: int
+    layers_per_stage: int
+    dropout: float = 0.0
+    prenorm: bool = False
+    batchnorm: bool = False
+    bn_momentum: float = 0.9
+    step_rescale: float = 1.0
+    pooling_stride: int = 1
+    pooling_mode: str = "last"
+    state_expansion_factor: int = 1
+
+    @nn.compact
+    def __call__(self, x, integration_timesteps, train: bool):
+        """
+        Compute the LxH output of the stacked encoder given an Lxd_input input sequence.
+
+        :param x: input sequence (L, d_input)
+        :param integration_timesteps: the integration timesteps for the SSM
+        :param train: If True, applies dropout and batch norm from batch statistics
+        :return: output sequence (L, d_model), integration_timesteps
+        """
+        EventSSMLayer = partial(
+            SequenceLayer,
+            ssm=self.ssm,
+            discretization=self.discretization,
+            dropout=self.dropout,
+            d_ssm=self.d_ssm,
+            block_size=self.ssm_block_size,
+            prenorm=self.prenorm,
+            batchnorm=self.batchnorm,
+            bn_momentum=self.bn_momentum,
+            step_rescale=self.step_rescale,
+        )
+
+        # first layer with pooling
+        x, integration_timesteps = EventSSMLayer(
+            d_model_in=self.d_model_in,
+            d_model_out=self.d_model_out,
+            pooling_stride=self.pooling_stride,
+            pooling_mode=self.pooling_mode
+        )(x, integration_timesteps, train=train)
+
+        # further layers without pooling
+        for l in range(self.layers_per_stage - 1):
+            x, integration_timesteps = EventSSMLayer(
+                d_model_in=self.d_model_out,
+                d_model_out=self.d_model_out,
+                pooling_stride=1
+            )(x, integration_timesteps, train=train)
+
+        return x, integration_timesteps
+
+
+class SequenceLayer(nn.Module):
+    """
+    Defines a single event-ssm layer, with S5 SSM, nonlinearity,
+    dropout, batch/layer norm, etc.
+
+    :param ssm: the SSM to be used (i.e.
S5 ssm) + :param discretization: the discretization method to use (zoh, dirac, async) + :param dropout: dropout rate + :param d_model_in: the input feature size + :param d_model_out: the output feature size + :param d_ssm: the size of the state space model + :param block_size: the block size of the state space model + :param prenorm: whether to use layernorm before the module or after it + :param batchnorm: If True, use batchnorm instead of layernorm + :param bn_momentum: momentum for batchnorm + :param step_rescale: rescale the integration timesteps by this factor + :param pooling_stride: stride for pooling + :param pooling_mode: pooling mode (last, avgpool, timepool) + """ + ssm: nn.Module + discretization: str + dropout: float + d_model_in: int + d_model_out: int + d_ssm: int + block_size: int + prenorm: bool = False + batchnorm: bool = False + bn_momentum: float = 0.90 + step_rescale: float = 1.0 + pooling_stride: int = 1 + pooling_mode: str = "last" + + @nn.compact + def __call__(self, x, integration_timesteps, train: bool): + """ + Compute a layer step + + :param x: input sequence (L, d_model_in) + :param integration_timesteps: the integration timesteps for the SSM + :param train: If True, applies dropout and batch norm from batch statistics + :return: output sequence (L, d_model_out), integration_timesteps + """ + skip = x + + if self.prenorm: + norm = nn.BatchNorm(momentum=self.bn_momentum, axis_name='batch') if self.batchnorm else nn.LayerNorm() + x = norm(x, use_running_average=not train) if self.batchnorm else norm(x) + + # apply state space model + x = self.ssm( + H_in=self.d_model_in, H_out=self.d_model_out, P=self.d_ssm, block_size=self.block_size, + step_rescale=self.step_rescale, discretization=self.discretization, + stride=self.pooling_stride, pooling_mode=self.pooling_mode + )(x, integration_timesteps) + + # non-linear activation function + x1 = nn.Dropout(self.dropout, broadcast_dims=[0], deterministic=not train)(nn.gelu(x)) + x1 = nn.Dense(self.d_model_out)(x1) + x = x * nn.sigmoid(x1) + x = nn.Dropout(self.dropout, broadcast_dims=[0], deterministic=not train)(x) + + if self.pooling_stride > 1: + pool = EventPooling(stride=self.pooling_stride, mode=self.pooling_mode) + skip, integration_timesteps = pool(skip, integration_timesteps) + + if self.d_model_in != self.d_model_out: + skip = nn.Dense(self.d_model_out)(skip) + + x = skip + x + + if not self.prenorm: + norm = nn.BatchNorm(momentum=self.bn_momentum, axis_name='batch') if self.batchnorm else nn.LayerNorm() + x = norm(x, use_running_average=not train) if self.batchnorm else norm(x) + + return x, integration_timesteps diff --git a/event_ssm/seq_model.py b/event_ssm/seq_model.py new file mode 100644 index 0000000..0919671 --- /dev/null +++ b/event_ssm/seq_model.py @@ -0,0 +1,277 @@ +import jax +import jax.numpy as np +from flax import linen as nn +from .layers import SequenceStage + + +class StackedEncoderModel(nn.Module): + """ + Defines a stack of S5 layers to be used as an encoder. + + :param ssm: the SSM to be used (i.e. S5 ssm) + :param discretization: the discretization to be used for the SSM + :param d_model: the feature size of the layer inputs and outputs. We usually refer to this size as H + :param d_ssm: the size of the state space model. 
We usually refer to this size as P + :param ssm_block_size: the block size of the state space model + :param num_stages: the number of S5 layers to stack + :param num_layers_per_stage: the number of EventSSM layers to stack + :param num_embeddings: the number of embeddings to use + :param dropout: dropout rate + :param prenorm: whether to use layernorm before the module or after it + :param batchnorm: If True, use batchnorm instead of layernorm + :param bn_momentum: momentum for batchnorm + :param step_rescale: rescale the integration timesteps by this factor + :param pooling_stride: stride for subsampling + :param pooling_every_n_layers: pool every n layers + :param pooling_mode: pooling mode (last, avgpool, timepool) + :param state_expansion_factor: factor to expand the state space model + """ + ssm: nn.Module + discretization: str + d_model: int + d_ssm: int + ssm_block_size: int + num_stages: int + num_layers_per_stage: int + num_embeddings: int = 0 + dropout: float = 0.0 + prenorm: bool = False + batchnorm: bool = False + bn_momentum: float = 0.9 + step_rescale: float = 1.0 + pooling_stride: int = 1 + pooling_every_n_layers: int = 1 + pooling_mode: str = "last" + state_expansion_factor: int = 1 + + def setup(self): + """ + Initializes a linear encoder and the stack of EventSSM layers. + """ + assert self.num_embeddings > 0 + self.encoder = nn.Embed(num_embeddings=self.num_embeddings, features=self.d_model) + + # generate strides for the model + stages = [] + d_model_in = self.d_model + d_model_out = self.d_model + d_ssm = self.d_ssm + total_downsampling = 1 + for stage in range(self.num_stages): + # pool from the first layer but don't expand the state dim for the first layer + total_downsampling *= self.pooling_stride + + stages.append( + SequenceStage( + ssm=self.ssm, + discretization=self.discretization, + d_model_in=d_model_in, + d_model_out=d_model_out, + d_ssm=d_ssm, + ssm_block_size=self.ssm_block_size, + layers_per_stage=self.num_layers_per_stage, + dropout=self.dropout, + prenorm=self.prenorm, + batchnorm=self.batchnorm, + bn_momentum=self.bn_momentum, + step_rescale=self.step_rescale, + pooling_stride=self.pooling_stride, + pooling_mode=self.pooling_mode + ) + ) + + d_ssm = self.state_expansion_factor * d_ssm + d_model_out = self.state_expansion_factor * d_model_in + + if stage > 0: + d_model_in = self.state_expansion_factor * d_model_in + + self.stages = stages + self.total_downsampling = total_downsampling + + def __call__(self, x, integration_timesteps, train: bool): + """ + Compute the LxH output of the stacked encoder given an Lxd_input + input sequence. + :param x: input sequence (L, d_input) + :param integration_timesteps: the integration timesteps for the SSM + :param train: If True, applies dropout and batch norm from batch statistics + :return: output sequence (L, d_model), integration timesteps + """ + x = self.encoder(x) + for i, stage in enumerate(self.stages): + # apply layer SSM + x, integration_timesteps = stage(x, integration_timesteps, train=train) + return x, integration_timesteps + + +def masked_meanpool(x, lengths): + """ + Helper function to perform mean pooling across the sequence length + when sequences have variable lengths. We only want to pool across + the prepadded sequence length. 
+ + :param x: input sequence (L, d_model) + :param lengths: the original length of the sequence before padding + :return: mean pooled output sequence (d_model) + """ + L = x.shape[0] + mask = np.arange(L) < lengths + return np.sum(mask[..., None]*x, axis=0)/lengths + + +def timepool(x, integration_timesteps): + """ + Helper function to perform weighted mean across the sequence length. + Means are weighted with the integration time steps + + :param x: input sequence (L, d_model) + :param integration_timesteps: the integration timesteps for the SSM + :return: time pooled output sequence (d_model) + """ + T = np.sum(integration_timesteps, axis=0) + integral = np.sum(x * integration_timesteps[..., None], axis=0) + return integral / T + + +def masked_timepool(x, lengths, integration_timesteps, eps=1e-6): + """ + Helper function to perform weighted mean across the sequence length + when sequences have variable lengths. We only want to pool across + the prepadded sequence length. Means are weighted with the integration time steps + + :param x: input sequence (L, d_model) + :param lengths: the original length of the sequence before padding + :param integration_timesteps: the integration timesteps for the SSM + :param eps: small value to avoid division by zero + :return: time pooled output sequence (d_model) + """ + L = x.shape[0] + mask = np.arange(L) < lengths + T = np.sum(integration_timesteps) + + # integrate with time weighting + weight = integration_timesteps[..., None] + eps + integral = np.sum(mask[..., None] * x * weight, axis=0) + return integral / T + + +# Here we call vmap to parallelize across a batch of input sequences +batch_masked_meanpool = jax.vmap(masked_meanpool) + + +class ClassificationModel(nn.Module): + """ + EventSSM classificaton sequence model. This consists of the stacked encoder + (which consists of a linear encoder and stack of S5 layers), mean pooling + across the sequence length, a linear decoder, and a softmax operation. + + :param ssm: the SSM to be used (i.e. S5 ssm) + :param discretization: the discretization to be used for the SSM (zoh, dirac, async) + :param num_classes: the number of classes for the classification task + :param d_model: the feature size of the layer inputs and outputs. We usually refer to this size as H + :param d_ssm: the size of the state space model. 
We usually refer to this size as P + :param ssm_block_size: the block size of the state space model + :param num_stages: the number of S5 layers to stack + :param num_layers_per_stage: the number of EventSSM layers to stack + :param num_embeddings: the number of embeddings to use + :param dropout: dropout rate + :param classification_mode: the classification mode (pool, timepool, last) + :param prenorm: whether to use layernorm before the module or after it + :param batchnorm: If True, use batchnorm instead of layernorm + :param bn_momentum: momentum for batchnorm + :param step_rescale: rescale the integration timesteps by this factor + :param pooling_stride: stride for subsampling + :param pooling_every_n_layers: pool every n layers + :param pooling_mode: pooling mode (last, avgpool, timepool) + :param state_expansion_factor: factor to expand the state space model + """ + ssm: nn.Module + discretization: str + num_classes: int + d_model: int + d_ssm: int + ssm_block_size: int + num_stages: int + num_layers_per_stage: int + num_embeddings: int = 0 + dropout: float = 0.2 + classification_mode: str = "pool" + prenorm: bool = False + batchnorm: bool = False + bn_momentum: float = 0.9 + step_rescale: float = 1.0 + pooling_stride: int = 1 + pooling_every_n_layers: int = 1 + pooling_mode: str = "last" + state_expansion_factor: int = 1 + + def setup(self): + """ + Initializes the stacked EventSSM encoder and a linear decoder. + """ + self.encoder = StackedEncoderModel( + ssm=self.ssm, + discretization=self.discretization, + d_model=self.d_model, + d_ssm=self.d_ssm, + ssm_block_size=self.ssm_block_size, + num_stages=self.num_stages, + num_layers_per_stage=self.num_layers_per_stage, + num_embeddings=self.num_embeddings, + dropout=self.dropout, + prenorm=self.prenorm, + batchnorm=self.batchnorm, + bn_momentum=self.bn_momentum, + step_rescale=self.step_rescale, + pooling_stride=self.pooling_stride, + pooling_every_n_layers=self.pooling_every_n_layers, + pooling_mode=self.pooling_mode, + state_expansion_factor=self.state_expansion_factor + ) + self.decoder = nn.Dense(self.num_classes) + + def __call__(self, x, integration_timesteps, length, train=True): + """ + Compute the size num_classes log softmax output given a + Lxd_input input sequence. 
+
+        :param x: input sequence (L, d_input)
+        :param integration_timesteps: the integration timesteps for the SSM
+        :param length: the original length of the sequence before padding
+        :param train: If True, applies dropout and batch norm from batch statistics
+
+        :return: output (num_classes)
+        """
+        # if the sequence is downsampled we need to adjust the length
+        length = length // self.encoder.total_downsampling
+
+        # run encoder backbone
+        x, integration_timesteps = self.encoder(x, integration_timesteps, train=train)
+
+        # apply classification head
+        if self.classification_mode in ["pool"]:
+            # Perform mean pooling across time
+            x = masked_meanpool(x, length)
+
+        elif self.classification_mode in ["timepool"]:
+            # Perform mean pooling across time weighted by integration time steps
+            x = masked_timepool(x, length, integration_timesteps)
+
+        elif self.classification_mode in ["last"]:
+            # Just take the last state
+            x = x[-1]
+        else:
+            raise NotImplementedError("classification_mode must be in ['pool', 'timepool', 'last']")
+
+        x = self.decoder(x)
+        return x
+
+
+# Here we call vmap to parallelize across a batch of input sequences
+BatchClassificationModel = nn.vmap(
+    ClassificationModel,
+    in_axes=(0, 0, 0, None),
+    out_axes=0,
+    variable_axes={"params": None, "dropout": None, 'batch_stats': None, "cache": 0, "prime": None},
+    split_rngs={"params": False, "dropout": True}, axis_name='batch')
diff --git a/event_ssm/ssm.py b/event_ssm/ssm.py
new file mode 100644
index 0000000..c605747
--- /dev/null
+++ b/event_ssm/ssm.py
@@ -0,0 +1,280 @@
+from functools import partial
+import jax
+import jax.numpy as np
+from jax.scipy.linalg import block_diag
+
+from flax import linen as nn
+from jax.nn.initializers import lecun_normal, normal, glorot_normal
+
+from .ssm_init import init_CV, init_VinvB, init_log_steps, trunc_standard_normal, make_DPLR_HiPPO
+
+from .layers import EventPooling
+
+
+def discretize_zoh(Lambda, step_delta, time_delta):
+    """
+    Discretize a diagonalized, continuous-time linear SSM
+    using the zero-order hold method.
+    This is the default discretization method used by many SSM works including S5.
+
+    :param Lambda: diagonal state matrix (P,)
+    :param step_delta: learned discretization step sizes (P,)
+    :param time_delta: (float32) integration time step of the current event
+    :return: discretized Lambda_bar (complex64) (P,) and input scaling gamma_bar (complex64) (P,)
+    """
+    Identity = np.ones(Lambda.shape[0])
+    Delta = step_delta * time_delta
+    Lambda_bar = np.exp(Lambda * Delta)
+    gamma_bar = (1/Lambda * (Lambda_bar-Identity))
+    return Lambda_bar, gamma_bar
+
+
+def discretize_dirac(Lambda, step_delta, time_delta):
+    """
+    Discretize a diagonalized, continuous-time linear SSM
+    with Dirac delta input spikes.
+
+    :param Lambda: diagonal state matrix (P,)
+    :param step_delta: learned discretization step sizes (P,)
+    :param time_delta: (float32) integration time step of the current event
+    :return: discretized Lambda_bar (complex64) (P,) and input scaling gamma_bar
+    """
+    Delta = step_delta * time_delta
+    Lambda_bar = np.exp(Lambda * Delta)
+    gamma_bar = 1.0
+    return Lambda_bar, gamma_bar
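+
+# Summary of the three discretization variants (zoh, dirac, async), with
+# Delta = step_delta * time_delta and I the identity:
+#   zoh:   Lambda_bar = exp(Lambda * Delta),  gamma_bar = (Lambda_bar - I) / Lambda
+#   dirac: Lambda_bar = exp(Lambda * Delta),  gamma_bar = 1
+#   async: Lambda_bar = exp(Lambda * Delta),  gamma_bar = (exp(Lambda * step_delta) - I) / Lambda
+# 'async' thus integrates Dirac-delta inputs like 'dirac', but normalizes them with
+# the learned step size alone, independent of the irregular event spacing time_delta.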
+
+
+def discretize_async(Lambda, step_delta, time_delta):
+    """
+    Discretize a diagonalized, continuous-time linear SSM
+    with Dirac delta input spikes and appropriate input normalization.
+
+    :param Lambda: diagonal state matrix (P,)
+    :param step_delta: learned discretization step sizes (P,)
+    :param time_delta: (float32) integration time step of the current event
+    :return: discretized Lambda_bar (complex64) (P,) and input scaling gamma_bar (complex64) (P,)
+    """
+    Identity = np.ones(Lambda.shape[0])
+    Lambda_bar = np.exp(Lambda * step_delta * time_delta)
+    gamma_bar = (1/Lambda * (np.exp(Lambda * step_delta)-Identity))
+    return Lambda_bar, gamma_bar
+
+
+# Parallel scan operations
+@jax.vmap
+def binary_operator(q_i, q_j):
+    """
+    Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A.
+
+    :param q_i: tuple containing A_i and Bu_i at position i (P,), (P,)
+    :param q_j: tuple containing A_j and Bu_j at position j (P,), (P,)
+    :return: new element ( A_out, Bu_out )
+    """
+    A_i, b_i = q_i
+    A_j, b_j = q_j
+    return A_j * A_i, A_j * b_i + b_j
+
+
+def apply_ssm(Lambda_elements, Bu_elements, C_tilde, conj_sym, stride=1):
+    """
+    Compute the LxH output of discretized SSM given an LxH input.
+
+    :param Lambda_elements: (complex64) discretized state matrix (L, P)
+    :param Bu_elements: (complex64) discretized inputs projected to state space (L, P)
+    :param C_tilde: (complex64) output matrix (H, P)
+    :param conj_sym: (bool) whether conjugate symmetry is enforced
+    :param stride: (int) subsampling stride applied to the scanned states
+    :return: ys: (float32) the SSM outputs (S5 layer preactivations) (L, H)
+    """
+    remaining_timesteps = (Bu_elements.shape[0] // stride) * stride
+
+    _, xs = jax.lax.associative_scan(binary_operator, (Lambda_elements, Bu_elements))
+
+    xs = xs[:remaining_timesteps:stride]
+
+    if conj_sym:
+        return jax.vmap(lambda x: 2*(C_tilde @ x).real)(xs)
+    else:
+        return jax.vmap(lambda x: (C_tilde @ x).real)(xs)
+
+
+class S5SSM(nn.Module):
+    """
+    Event-based S5 module
+
+    :param H_in: int, SSM input dimension
+    :param H_out: int, SSM output dimension
+    :param P: int, SSM state dimension
+    :param block_size: int, block size for block-diagonal state matrix
+    :param C_init: str, initialization method for output matrix C
+    :param discretization: str, discretization method for event-based SSM (zoh, dirac, async)
+    :param dt_min: float, minimum value of log timestep
+    :param dt_max: float, maximum value of log timestep
+    :param conj_sym: bool, whether to enforce conjugate symmetry in the state space operator
+    :param clip_eigs: bool, whether to clip eigenvalues of the state space operator
+    :param step_rescale: float, rescale factor for step size
+    :param stride: int, stride for subsampling layer
+    :param pooling_mode: str, pooling mode for subsampling layer
+    """
+    H_in: int
+    H_out: int
+    P: int
+    block_size: int
+    C_init: str
+    discretization: str
+    dt_min: float
+    dt_max: float
+    conj_sym: bool = True
+    clip_eigs: bool = False
+    step_rescale: float = 1.0
+    stride: int = 1
+    pooling_mode: str = "last"
+
+    def setup(self):
+        """
+        Initializes parameters once and performs discretization each time the SSM is applied to a sequence
+        """
+
+        # Initialize state matrix A using approximation to HiPPO-LegS matrix
+        Lambda, _, B, V, B_orig = make_DPLR_HiPPO(self.block_size)
+
+        blocks = self.P // self.block_size
+        block_size = self.block_size // 2 if self.conj_sym else self.block_size
+        local_P = self.P // 2 if self.conj_sym else self.P
+
+        Lambda = Lambda[:block_size]
+        V = V[:, :block_size]
+        Vc = V.conj().T
+
+        # If initializing state matrix A as block-diagonal, put HiPPO approximation
+        # on each block
+        Lambda = (Lambda * np.ones((blocks, block_size))).ravel()
+        V = block_diag(*([V] * blocks))
+        Vinv = block_diag(*([Vc] * blocks))
+
+        state_str = f"SSM: {self.H_in} -> 
{self.P} -> {self.H_out}" + if self.stride > 1: + state_str += f" (stride {self.stride} with pooling mode {self.pooling_mode})" + print(state_str) + + # Initialize diagonal state to state matrix Lambda (eigenvalues) + self.Lambda_re = self.param("Lambda_re", lambda rng, shape: Lambda.real, (None,)) + self.Lambda_im = self.param("Lambda_im", lambda rng, shape: Lambda.imag, (None,)) + + if self.clip_eigs: + self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im + else: + self.Lambda = self.Lambda_re + 1j * self.Lambda_im + + # Initialize input to state (B) matrix + B_init = lecun_normal() + B_shape = (self.P, self.H_in) + self.B = self.param("B", + lambda rng, shape: init_VinvB(B_init, rng, shape, Vinv), + B_shape) + + # Initialize state to output (C) matrix + if self.C_init in ["trunc_standard_normal"]: + C_init = trunc_standard_normal + C_shape = (self.H_out, self.P, 2) + elif self.C_init in ["lecun_normal"]: + C_init = lecun_normal() + C_shape = (self.H_out, self.P, 2) + elif self.C_init in ["complex_normal"]: + C_init = normal(stddev=0.5 ** 0.5) + else: + raise NotImplementedError( + "C_init method {} not implemented".format(self.C_init)) + + if self.C_init in ["complex_normal"]: + C = self.param("C", C_init, (self.H_out, local_P, 2)) + self.C_tilde = C[..., 0] + 1j * C[..., 1] + + else: + self.C = self.param("C", + lambda rng, shape: init_CV(C_init, rng, shape, V), + C_shape) + + self.C_tilde = self.C[..., 0] + 1j * self.C[..., 1] + + # Initialize feedthrough (D) matrix + if self.H_in == self.H_out: + self.D = self.param("D", normal(stddev=1.0), (self.H_in,)) + else: + self.D = self.param("D", glorot_normal(), (self.H_out, self.H_in)) + + # Initialize learnable discretization timescale value + self.log_step = self.param("log_step", + init_log_steps, + (local_P, self.dt_min, self.dt_max)) + + # pooling layer + self.pool = EventPooling(stride=self.stride, mode=self.pooling_mode) + + # Discretize + if self.discretization in ["zoh"]: + self.discretize_fn = discretize_zoh + elif self.discretization in ["dirac"]: + self.discretize_fn = discretize_dirac + elif self.discretization in ["async"]: + self.discretize_fn = discretize_async + else: + raise NotImplementedError("Discretization method {} not implemented".format(self.discretization)) + + def __call__(self, input_sequence, integration_timesteps): + """ + Compute the LxH output of the S5 SSM given an LxH input sequence using a parallel scan. 
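For illustration, a minimal sketch (toy shapes, not part of the repository) of why the associative scan above reproduces the sequential recurrence x_k = A_k * x_{k-1} + Bu_k, using the `binary_operator` defined earlier:

```python
import jax
import jax.numpy as jnp

# toy (L=3, P=1) elements: constant decay 0.5 and unit inputs
A = jnp.full((3, 1), 0.5)
Bu = jnp.ones((3, 1))

_, xs = jax.lax.associative_scan(binary_operator, (A, Bu))
# xs == [[1.0], [1.5], [1.75]] -- identical to unrolling the
# linear recurrence x_k = A_k * x_{k-1} + Bu_k step by step.
```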
+ + :param input_sequence: (float32) input sequence (L, H) + :param integration_timesteps: (float32) integration timesteps (L,) + :return: (float32) output sequence (L, H) + """ + + # discretize on the fly + B = self.B[..., 0] + 1j * self.B[..., 1] + + def discretize_and_project_inputs(u, _timestep): + step = self.step_rescale * np.exp(self.log_step[:, 0]) + Lambda_bar, gamma_bar = self.discretize_fn(self.Lambda, step, _timestep) + Bu = gamma_bar * (B @ u) + return Lambda_bar, Bu + + Lambda_bar_elements, Bu_bar_elements = jax.vmap(discretize_and_project_inputs)(input_sequence, integration_timesteps) + + ys = apply_ssm( + Lambda_bar_elements, + Bu_bar_elements, + self.C_tilde, + self.conj_sym, + stride=self.stride + ) + + if self.stride > 1: + input_sequence, _ = self.pool(input_sequence, integration_timesteps) + + if self.H_in == self.H_out: + Du = jax.vmap(lambda u: self.D * u)(input_sequence) + else: + Du = jax.vmap(lambda u: self.D @ u)(input_sequence) + + return ys + Du + + +def init_S5SSM( + C_init, + dt_min, + dt_max, + conj_sym, + clip_eigs, +): + """ + Convenience function that will be used to initialize the SSM. + Same arguments as defined in S5SSM above. + """ + return partial(S5SSM, + C_init=C_init, + dt_min=dt_min, + dt_max=dt_max, + conj_sym=conj_sym, + clip_eigs=clip_eigs + ) diff --git a/event_ssm/ssm_init.py b/event_ssm/ssm_init.py new file mode 100644 index 0000000..7c5cd52 --- /dev/null +++ b/event_ssm/ssm_init.py @@ -0,0 +1,153 @@ +from jax import random +import jax.numpy as np +from jax.nn.initializers import lecun_normal +from jax.numpy.linalg import eigh + + +def make_HiPPO(N): + """ + Create a HiPPO-LegS matrix. + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + + :params N: int32, state size + :returns: N x N HiPPO LegS matrix + """ + P = np.sqrt(1 + 2 * np.arange(N)) + A = P[:, np.newaxis] * P[np.newaxis, :] + A = np.tril(A) - np.diag(np.arange(N)) + return -A + + +def make_NPLR_HiPPO(N): + """ + Makes components needed for NPLR representation of HiPPO-LegS + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + + :params N: int32, state size + :returns: N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B + """ + # Make -HiPPO + hippo = make_HiPPO(N) + + # Add in a rank 1 term. Makes it Normal. + P = np.sqrt(np.arange(N) + 0.5) + + # HiPPO also specifies the B matrix + B = np.sqrt(2 * np.arange(N) + 1.0) + return hippo, P, B + + +def make_DPLR_HiPPO(N): + """ + Makes components needed for DPLR representation of HiPPO-LegS + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Note, we will only use the diagonal part + + :params N: int32, state size + :returns: eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B, + eigenvectors V, HiPPO B pre-conjugation + """ + A, P, B = make_NPLR_HiPPO(N) + + S = A + P[:, np.newaxis] * P[np.newaxis, :] + + S_diag = np.diagonal(S) + Lambda_real = np.mean(S_diag) * np.ones_like(S_diag) + + # Diagonalize S to V \Lambda V^* + Lambda_imag, V = eigh(S * -1j) + + P = V.conj().T @ P + B_orig = B + B = V.conj().T @ B + return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig + + +def log_step_initializer(dt_min=0.001, dt_max=0.1): + """ + Initialize the learnable timescale Delta by sampling + uniformly between dt_min and dt_max. 
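For intuition, a small sketch (assuming the surrounding `log_step_initializer`) showing that the exponentiated samples land log-uniformly in `[dt_min, dt_max]`:

```python
import jax
import jax.numpy as jnp

init = log_step_initializer(dt_min=0.001, dt_max=0.1)
log_steps = init(jax.random.PRNGKey(0), shape=(4,))
steps = jnp.exp(log_steps)  # all values fall in [0.001, 0.1]
```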
+ + :params dt_min: float32, minimum value of log timestep + :params dt_max: float32, maximum value of log timestep + :returns: init function + """ + def init(key, shape): + return random.uniform(key, shape) * ( + np.log(dt_max) - np.log(dt_min) + ) + np.log(dt_min) + + return init + + +def init_log_steps(key, input): + """ + Initialize an array of learnable timescale parameters + + :params key: jax random + :params input: tuple containing the array shape H and + dt_min and dt_max + :returns: initialized array of timescales (float32): (H,) + """ + H, dt_min, dt_max = input + log_steps = [] + for i in range(H): + key, skey = random.split(key) + log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,)) + log_steps.append(log_step) + + return np.array(log_steps) + + +def init_VinvB(init_fun, rng, shape, Vinv): + """ + Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B. + Note we will parameterize this with two different matrices for complex numbers. + + :params init_fun: function, the initialization function to use, e.g. lecun_normal() + :params rng: jax random key to be used with init function. + :params shape: tuple, desired shape (P,H) + :params Vinv: complex64, the inverse eigenvectors used for initialization + :returns: B_tilde (complex64) of shape (P,H,2) + """ + B = init_fun(rng, shape) + VinvB = Vinv @ B + VinvB_real = VinvB.real + VinvB_imag = VinvB.imag + return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1) + + +def trunc_standard_normal(key, shape): + """ + Sample C with a truncated normal distribution with standard deviation 1. + + :params key: jax random key + :params shape: tuple, desired shape (H,P, _) + :returns: sampled C matrix (float32) of shape (H,P,2) (for complex parameterization) + """ + H, P, _ = shape + Cs = [] + for i in range(H): + key, skey = random.split(key) + C = lecun_normal()(skey, shape=(1, P, 2)) + Cs.append(C) + return np.array(Cs)[:, 0] + + +def init_CV(init_fun, rng, shape, V): + """ + Initialize C_tilde=CV. First sample C. Then compute CV. + Note we will parameterize this with two different matrices for complex numbers. + + :params init_fun: function, the initialization function to use, e.g. lecun_normal() + :params rng: jax random key to be used with init function. + :params shape: tuple, desired shape (H,P) + :params V: complex64, the eigenvectors used for initialization + :returns: C_tilde (complex64) of shape (H,P,2) + """ + C_ = init_fun(rng, shape) + C = C_[..., 0] + 1j * C_[..., 1] + CV = C @ V + CV_real = CV.real + CV_imag = CV.imag + return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1) diff --git a/event_ssm/train_utils.py b/event_ssm/train_utils.py new file mode 100644 index 0000000..177847b --- /dev/null +++ b/event_ssm/train_utils.py @@ -0,0 +1,249 @@ +import numpy as np +import jax +import jax.numpy as jnp +from jaxtyping import Array +from typing import Any, Dict +import random +from flax.training import train_state +import optax +from functools import partial + + +class TrainState(train_state.TrainState): + key: Array + model_state: Dict + + +def training_step( + train_state: TrainState, + batch: Array, + dropout_key: Array, + distributed: bool = False +): + """ + Conducts a single training step on a batch of data. 
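The step below relies on `jax.value_and_grad` with `has_aux=True`; as a reminder of that pattern, a minimal standalone sketch (toy loss, not the model's):

```python
import jax
import jax.numpy as jnp

def loss_fn(w):
    loss = jnp.sum(w ** 2)
    return loss, {"logits": w}  # auxiliary outputs ride along with the loss

# returns ((loss, aux), grads), matching the unpacking used below
(loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(jnp.ones(3))
```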
+ + :param train_state: a Flax TrainState that carries the parameters, optimizer states etc + :param batch: the data consisting of [data, target, integration_timesteps, lengths] + :param dropout_key: a PRNG key for the dropout layers + :param distributed: If True, apply reduce operations like psum, pmean etc + :return: train_state, metrics + """ + inputs, targets, integration_timesteps, lengths = batch + + def loss_fn(params): + logits, updates = train_state.apply_fn( + {'params': params, **train_state.model_state}, + inputs, integration_timesteps, lengths, + True, + rngs={'dropout': dropout_key}, + mutable=['batch_stats'] + ) + + loss = optax.softmax_cross_entropy(logits, targets) + loss = loss.mean() + + return loss, (logits, updates) + + (loss, (logits, batch_updates)), grads = jax.value_and_grad(loss_fn, has_aux=True)(train_state.params) + + preds = jnp.argmax(logits, axis=-1) + targets = jnp.argmax(targets, axis=-1) + accuracy = (preds == targets).mean() + + if distributed: + grads = jax.lax.pmean(grads, axis_name='data') + loss = jax.lax.pmean(loss, axis_name='data') + accuracy = jax.lax.pmean(accuracy, axis_name='data') + + train_state = train_state.apply_gradients(grads=grads) + train_state = train_state.replace(model_state=batch_updates) + + return train_state, {'loss': loss, 'accuracy': accuracy} + + +def evaluation_step( + train_state: TrainState, + batch: Array, + distributed: bool = False +): + """ + Conducts a single evaluation step on a batch of data. + + :param train_state: a Flax TrainState that carries the parameters, optimizer states etc + :param batch: the data consisting of [data, target, integration_timesteps, lengths] + :param distributed: If True, apply reduce operations like psum, pmean etc + :return: train_state, metrics + """ + inputs, targets, integration_timesteps, lengths = batch + logits = train_state.apply_fn( + {'params': train_state.params, **train_state.model_state}, + inputs, integration_timesteps, lengths, + False, + ) + loss = optax.softmax_cross_entropy(logits, targets) + loss = loss.mean() + preds = jnp.argmax(logits, axis=-1) + targets = jnp.argmax(targets, axis=-1) + accuracy = (preds == targets).mean() + + if distributed: + loss = jax.lax.pmean(loss, axis_name='data') + accuracy = jax.lax.pmean(accuracy, axis_name='data') + + return train_state, {'loss': loss, 'accuracy': accuracy} + + +def map_nested_fn(fn): + """ + Recursively apply `fn` to the key-value pairs of a nested dict / pytree. + We use this for some of the optax definitions below. 
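A toy example of what `map_nested_fn` produces, since `get_optimizer` below uses it to route parameters to different optax transforms (the parameter tree here is made up):

```python
labels = map_nested_fn(
    lambda k, _: "ssm" if k in ["B", "Lambda_re", "Lambda_im"] else "regular"
)
print(labels({"layer_0": {"B": 0.0, "Lambda_re": 0.0, "bias": 0.0}}))
# {'layer_0': {'B': 'ssm', 'Lambda_re': 'ssm', 'bias': 'regular'}}
```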
+ """ + + def map_fn(nested_dict): + return { + k: (map_fn(v) if hasattr(v, "keys") else fn(k, v)) + for k, v in nested_dict.items() + } + + return map_fn + + +def map_nested_fn_with_keyword(keyword_1, keyword_2): + '''labels all the leaves that are descendants of keyword_1 with keyword 1, + else label the leaf with keyword_2''' + + def map_fn(nested_dict): + output_dict = {} + for k, v in nested_dict.items(): + if isinstance(v, dict): + if k == keyword_1: + output_dict[k] = map_fn_2(v) + else: + output_dict[k] = map_fn(v) + else: + if k == keyword_1: + output_dict[k] = keyword_1 + else: + output_dict[k] = keyword_2 + return output_dict + + def map_fn_2(nested_dict): + output_dict = {} + for k, v in nested_dict.items(): + if isinstance(v, dict): + output_dict[k] = map_fn_2(v) + else: + output_dict[k] = keyword_1 + return output_dict + + return map_fn + + +def seed_all(seed): + random.seed(seed) + np.random.seed(seed) + + +def get_first_device(x): + x = jax.tree_util.tree_map(lambda a: a[0], x) + return jax.device_get(x) + + +def print_model_size(params, name=''): + fn_is_complex = lambda x: x.dtype in [np.complex64, np.complex128] + param_sizes = map_nested_fn(lambda k, param: param.size * (2 if fn_is_complex(param) else 1))(params) + total_params_size = sum(jax.tree_leaves(param_sizes)) + print('[*] Model parameter count:', total_params_size) + + +def get_learning_rate_fn(lr, total_steps, warmup_steps, schedule, **kwargs): + if schedule == 'cosine': + learning_rate_fn = optax.warmup_cosine_decay_schedule( + init_value=0., + peak_value=lr, + warmup_steps=warmup_steps, + decay_steps=total_steps + ) + elif schedule == 'constant': + learning_rate_fn = optax.join_schedules([ + optax.linear_schedule( + init_value=0., + end_value=lr, + transition_steps=warmup_steps + ), + optax.constant_schedule(lr) + ], [warmup_steps]) + else: + raise ValueError(f'Unknown schedule: {schedule}') + + return learning_rate_fn + + +def get_optimizer(opt_config): + + ssm_lrs = ["B", "Lambda_re", "Lambda_im"] + ssm_fn = map_nested_fn( + lambda k, _: "ssm" + if k in ssm_lrs + else "regular" + ) + learning_rate_fn = partial( + get_learning_rate_fn, + total_steps=opt_config.total_steps, + warmup_steps=opt_config.warmup_steps, + schedule=opt_config.schedule + ) + + def optimizer(learning_rate): + tx = optax.multi_transform( + { + "ssm": optax.inject_hyperparams(partial( + optax.adamw, + b1=0.9, b2=0.999, + weight_decay=opt_config.ssm_weight_decay + ))(learning_rate=learning_rate_fn(lr=learning_rate)), + "regular": optax.adamw( + learning_rate=learning_rate_fn(lr=learning_rate * opt_config.lr_factor), + b1=0.9, b2=0.999, + weight_decay=opt_config.weight_decay), + }, + ssm_fn, + ) + if opt_config.get('accumulation_steps', False): + print(f"[*] Using gradient accumulation with {opt_config.accumulation_steps} steps") + tx = optax.MultiSteps(tx, every_k_schedule=opt_config.accumulation_steps) + return tx + + return optimizer(opt_config.ssm_lr) + + +def init_model_state(rng_key, model, inputs, steps, lengths, opt_config): + """ + Initialize the training state. 
+ + :param rng_key: a PRNGKey + :param model: the Flax model to train + :param inputs: dummy input data + :param steps: dummy integration timesteps + :param lengths: dummy number of events + :param opt_config: a dictionary containing the optimizer configuration + :return: a TrainState object + """ + init_key, dropout_key = jax.random.split(rng_key) + variables = model.init( + {"params": init_key, + "dropout": dropout_key}, + inputs, steps, lengths, True + ) + params = variables.pop('params') + model_state = variables + print_model_size(params) + + tx = get_optimizer(opt_config) + return TrainState.create( + apply_fn=model.apply, + params=params, + tx=tx, + key=dropout_key, + model_state=model_state + ) diff --git a/event_ssm/trainer.py b/event_ssm/trainer.py new file mode 100644 index 0000000..80fd9fa --- /dev/null +++ b/event_ssm/trainer.py @@ -0,0 +1,357 @@ +import time +import os +import json +import sys +import wandb +from collections import defaultdict, OrderedDict +from omegaconf import OmegaConf as om +from omegaconf import DictConfig +import jax.numpy as jnp +import jax.random +from jaxtyping import Array +from typing import Callable, Dict, Optional, Iterator, Any +from flax.training.train_state import TrainState +from flax.training import checkpoints +from flax import jax_utils +from functools import partial + + +@partial(jax.jit, static_argnums=(1,)) +def reshape_batch_per_device(x, num_devices): + return jax.tree_util.tree_map(partial(reshape_array_per_device, num_devices=num_devices), x) + + +def reshape_array_per_device(x, num_devices): + batch_size_per_device, ragged = divmod(x.shape[0], num_devices) + if ragged: + msg = "batch size must be divisible by device count, got {} and {}." + raise ValueError(msg.format(x.shape[0], num_devices)) + return x.reshape((num_devices, batch_size_per_device, ) + (x.shape[1:])) + + +class TrainerModule: + """ + Handles training and logging of models. Most of the boilerplate code is hidden from the user. + """ + def __init__( + self, + train_state: TrainState, + training_step_fn: Callable, + evaluation_step_fn: Callable, + world_size: int, + config: DictConfig, + ): + """ + + :param train_state: A TrainState object that contains the model parameters, optimizer states etc. + :param training_step_fn: A function that takes the train_state and a batch of data and returns the updated train_state and metrics. + :param evaluation_step_fn: A function that takes the train_state and a batch of data and returns the updated train_state and metrics. + :param world_size: Number of devices to run the training on. + :param config: The configuration of the training run. 
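The `reshape_array_per_device` helper above simply folds the batch axis into a leading device axis for `pmap`; a toy check, assuming two local devices:

```python
import numpy as np

x = np.zeros((8, 700))              # global batch of 8 token sequences
y = reshape_array_per_device(x, 2)  # one slice per device
assert y.shape == (2, 4, 700)
```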
+ """ + super().__init__() + self.train_state = train_state + self.train_step = training_step_fn + self.eval_step = evaluation_step_fn + + self.world_size = world_size + self.log_config = config.logging + self.epoch_idx = 0 + self.num_epochs = config.training.num_epochs + self.best_eval_metrics = {} + + # logger details + self.log_dir = os.path.join(self.log_config.log_dir) + print('[*] Logging to', self.log_dir) + + if not os.path.isdir(self.log_dir): + os.makedirs(self.log_dir) + if not os.path.isdir(os.path.join(self.log_dir, 'metrics')): + os.makedirs(os.path.join(self.log_dir, 'metrics')) + if not os.path.isdir(os.path.join(self.log_dir, 'checkpoints')): + os.makedirs(os.path.join(self.log_dir, 'checkpoints')) + + num_parameters = int(sum( + [arr.size for arr in jax.tree_flatten(self.train_state.params)[0] + if isinstance(arr, Array)] + ) / self.world_size) + print("[*] Number of model parameters:", num_parameters) + + if self.log_config.wandb: + wandb.init( + # set the wandb project where this run will be logged + dir=self.log_config.log_dir, + project=self.log_config.project, + entity=self.log_config.entity, + config=om.to_container(config, resolve=True)) + wandb.config.update({'SLURM_JOB_ID': os.getenv('SLURM_JOB_ID')}) + + # log number of parameters + wandb.run.summary['Num parameters'] = num_parameters + wandb.define_metric(self.log_config.summary_metric, summary='max') + + def train_model( + self, + train_loader: Iterator, + val_loader: Iterator, + dropout_key: Array, + test_loader: Optional[Iterator] = None, + ) -> Dict[str, Any]: + """ + Trains a model on a dataset. + + :param train_loader: Data loader of the training set. + :param val_loader: Data loader of the validation set. + :param dropout_key: Random key for dropout. + :param test_loader: Data loader of the test set. + :return: A dictionary of the best evaluation metrics. 
+ """ + + # Prepare training loop + self.on_training_start() + + for epoch_idx in range(1, self.num_epochs+1): + self.epoch_idx = epoch_idx + + # run training step for this epoch + train_metrics = self.train_epoch(train_loader, dropout_key) + + self.on_training_epoch_end(train_metrics) + + # Validation every N epochs + eval_metrics = self.eval_model( + val_loader, + log_prefix='Performance/Validation', + ) + + self.on_validation_epoch_end(eval_metrics) + + if self.log_config.wandb: + from optax import MultiStepsState + wandb_metrics = {'Performance/epoch': epoch_idx} + wandb_metrics.update(train_metrics) + wandb_metrics.update(eval_metrics) + if isinstance(self.train_state.opt_state, MultiStepsState): + lr = self.train_state.opt_state.inner_opt_state.inner_states['ssm'].inner_state.hyperparams['learning_rate'].item() + else: + lr = self.train_state.opt_state.inner_states['ssm'].inner_state.hyperparams['learning_rate'].item() + wandb_metrics['learning rate'] = lr + wandb.log(wandb_metrics) + + # Test best model if possible + if test_loader is not None: + self.load_model() + test_metrics = self.eval_model( + test_loader, + log_prefix='Performance/Test', + ) + self.save_metrics('test', test_metrics) + self.best_eval_metrics.update(test_metrics) + + if self.log_config.wandb: + wandb.log(test_metrics) + + print('-' * 89) + print('| End of Training |') + print('| Test Metrics |', + ' | '.join([f"{k.split('/')[1].replace('Test','')}: {v:5.2f}" for k, v in test_metrics.items() if 'Test' in k])) + print('-' * 89) + + return self.best_eval_metrics + + def train_epoch(self, train_loader: Iterator, dropout_key) -> Dict[str, Any]: + """ + Trains the model on one epoch of the training set. + + :param train_loader: Data loader of the training set. + :param dropout_key: Random key for dropout. + :return: A dictionary of the training metrics. 
+ """ + + # Train model for one epoch, and log avg loss and accuracy + metrics = defaultdict(float) + running_metrics = defaultdict(float) + num_batches = 0 + num_train_batches = len(train_loader) + start_time = time.time() + epoch_start_time = start_time + + # set up intra epoch logging + log_interval = self.log_config.interval + + for i, batch in enumerate(train_loader): + num_batches += 1 + + # skip batches with empty sequences which might randomly occur due to data augmentation + _, _, _, lengths = batch + if jnp.any(lengths == 0): + continue + + if self.world_size > 1: + step_key, dropout_key = jax.vmap(jax.random.split, in_axes=0, out_axes=1)(dropout_key) + step_key = jax.vmap(jax.random.fold_in)(step_key, jnp.arange(self.world_size)) + batch = reshape_batch_per_device(batch, self.world_size) + else: + step_key, dropout_key = jax.random.split(dropout_key) + + self.train_state, step_metrics = self.train_step(self.train_state, batch, step_key) + + # exit from training if loss is nan + if jnp.isnan(step_metrics['loss']).any(): + print("EXITING TRAINING DUE TO NAN LOSS") + break + + # record metrics + for key in step_metrics: + metrics['Performance/Training ' + key] += step_metrics[key] + running_metrics['Performance/Training ' + key] += step_metrics[key] + + # print metrics to terminal + if (i + 1) % log_interval == 0: + elapsed = time.time() - start_time + start_time = time.time() + print(f'| epoch {self.epoch_idx} | {i + 1}/{num_train_batches} batches | ms/batch {elapsed * 1000 / log_interval:5.2f} |', + ' | '.join([f'{k}: {jnp.mean(v).item() / log_interval:5.2f}' for k, v in running_metrics.items()])) + for key in step_metrics: + running_metrics['Performance/Training ' + key] = 0 + + metrics = {key: jnp.mean(metrics[key] / num_batches).item() for key in metrics} + metrics['epoch_time'] = time.time() - epoch_start_time + return metrics + + def eval_model( + self, + data_loader: Iterator, + log_prefix: Optional[str] = '', + ) -> Dict[str, Any]: + """ + Evaluates the model on a dataset. + + :param data_loader: Data loader of the dataset. + :param log_prefix: Prefix to add to the keys of the logged metrics such as "Best" or "Validation". + :return: A dictionary of the evaluation metrics. + """ + + # Test model on all images of a data loader and return avg loss + metrics = defaultdict(float) + num_batches = 0 + + for i, batch in enumerate(iter(data_loader)): + + if self.world_size > 1: + batch = reshape_batch_per_device(batch, self.world_size) + + self.train_state, step_metrics = self.eval_step(self.train_state, batch) + + for key in step_metrics: + metrics[key] += step_metrics[key] + num_batches += 1 + + prefix = log_prefix + ' ' if log_prefix else '' + metrics = {(prefix + key): jnp.mean(metrics[key] / num_batches).item() for key in metrics} + return metrics + + def is_new_model_better(self, new_metrics: Dict[str, Any], old_metrics: Dict[str, Any]) -> bool: + """ + Compares two sets of evaluation metrics to decide whether the + new model is better than the previous ones or not. + + :params new_metrics: A dictionary of the evaluation metrics of the new model. + :params old_metrics: A dictionary of the evaluation metrics of the previously + best model, i.e. the one to compare to. + :return: True if the new model is better than the old one, and False otherwise. 
+ """ + if len(old_metrics) == 0: + return True + for key, is_larger in [('val/val_metric', False), ('Performance/Validation accuracy', True), ('Performance/Validation loss', False)]: + if key in new_metrics: + if is_larger: + return new_metrics[key] > old_metrics[key] + else: + return new_metrics[key] < old_metrics[key] + assert False, f'No known metrics to log on: {new_metrics}' + + def save_metrics(self, filename: str, metrics: Dict[str, Any]): + """ + Saves a dictionary of metrics to file. Can be used as a textual + representation of the validation performance for checking in the terminal. + + :param filename: The name of the file to save the metrics to. + :param metrics: A dictionary of the metrics to save. + """ + with open(os.path.join(self.log_dir, f'metrics/{filename}.json'), 'w') as f: + json.dump(metrics, f, indent=4) + + def save_model(self): + """ + Saves the model to a file. The model is saved in the log directory. + """ + if self.world_size > 1: + state = jax_utils.unreplicate(self.train_state) + else: + state = self.train_state + checkpoints.save_checkpoint( + ckpt_dir=os.path.abspath(os.path.join(self.log_dir, 'checkpoints')), + target=state, + step=state.step, + overwrite=True, + keep=1 + ) + del state + + def load_model(self): + """ + Loads the model from a file. The model is loaded from the log directory. + """ + if self.world_size > 1: + state = jax_utils.unreplicate(self.train_state) + raw_restored = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(self.log_dir, 'checkpoints'), target=state) + self.train_state = jax_utils.replicate(raw_restored) + del state + else: + self.train_state = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(self.log_dir, 'checkpoints'), target=self.train_state) + + def on_training_start(self): + """ + Method called before training is started. Can be used for additional + initialization operations etc. + """ + pass + + def on_training_epoch_end(self, train_metrics): + """ + Method called at the end of each training epoch. Can be used for additional + logging or similar. + """ + print('-' * 89) + print(f"| end of epoch {self.epoch_idx:3d} | time per epoch: {train_metrics['epoch_time']:5.2f}s |") + print('| Train Metrics |', ' | '.join( + [f"{k.split('/')[1].replace('Training ', '')}: {v:5.2f}" for k, v in train_metrics.items() if + 'Train' in k])) + + # check metrics for nan values and possibly exit training + if jnp.isnan(train_metrics['Performance/Training loss']).item(): + print("EXITING TRAINING DUE TO NAN LOSS") + sys.exit(1) + + def on_validation_epoch_end(self, eval_metrics: Dict[str, Any]): + """ + Method called at the end of each validation epoch. Can be used for additional + logging and evaluation. + + Args: + eval_metrics: A dictionary of the validation metrics. New metrics added to + this dictionary will be logged as well. 
+ """ + print('| Eval Metrics |', ' | '.join( + [f"{k.split('/')[1].replace('Validation ', '')}: {v:5.2f}" for k, v in eval_metrics.items() if + 'Validation' in k])) + print('-' * 89) + + self.save_metrics(f'eval_epoch_{str(self.epoch_idx).zfill(3)}', eval_metrics) + + # Save best model + if self.is_new_model_better(eval_metrics, self.best_eval_metrics): + self.best_eval_metrics = eval_metrics + self.save_model() + self.save_metrics('best_eval', eval_metrics) diff --git a/event_ssm/transform.py b/event_ssm/transform.py new file mode 100644 index 0000000..7c88df8 --- /dev/null +++ b/event_ssm/transform.py @@ -0,0 +1,246 @@ +import numpy as np + + +class Identity: + def __call__(self, events): + return events + + +class CropEvents: + """Crops event stream to a specified number of events + + Parameters: + num_events (int): number of events to keep + """ + + def __init__(self, num_events): + self.num_events = num_events + + def __call__(self, events): + if self.num_events >= len(events): + return events + else: + start = np.random.randint(0, len(events) - self.num_events) + return events[start:start + self.num_events] + + +class Jitter1D: + """ + Apply random jitter to event coordinates + Parameters: + max_roll (int): maximum number of pixels to roll by + """ + def __init__(self, sensor_size, var): + self.sensor_size = sensor_size + self.var = var + + def __call__(self, events): + # roll x, y coordinates by a random amount + shift = np.random.normal(0, self.var, len(events)).astype(np.int32) + events['x'] += shift + # remove events who got shifted out of the sensor size + mask = (events['x'] >= 0) & (events['x'] < self.sensor_size[0]) + events = events[mask] + return events + + +class Roll: + """ + Roll event x, y coordinates by a random amount + + Parameters: + max_roll (int): maximum number of pixels to roll by + """ + def __init__(self, sensor_size, p, max_roll): + self.sensor_size = sensor_size + self.max_roll = max_roll + self.p = p + + def __call__(self, events): + if np.random.rand() > self.p: + return events + # roll x, y coordinates by a random amount + roll_x = np.random.randint(-self.max_roll, self.max_roll) + roll_y = np.random.randint(-self.max_roll, self.max_roll) + events['x'] += roll_x + events['y'] += roll_y + # remove events who got shifted out of the sensor size + mask = (events['x'] >= 0) & (events['x'] < self.sensor_size[0]) & (events['y'] >= 0) & (events['y'] < self.sensor_size[1]) + events = events[mask] + return events + + +class Rotate: + """ + Rotate event x, y coordinates by a random angle + """ + def __init__(self, sensor_size, p, max_angle): + self.p = p + self.sensor_size = sensor_size + self.max_angle = 2 * np.pi * max_angle / 360 + + def __call__(self, events): + if np.random.rand() > self.p: + return events + # rotate x, y coordinates by a random angle + angle = np.random.uniform(-self.max_angle, self.max_angle) + x = events['x'] - self.sensor_size[0] / 2 + y = events['y'] - self.sensor_size[1] / 2 + x_new = x * np.cos(angle) - y * np.sin(angle) + y_new = x * np.sin(angle) + y * np.cos(angle) + events['x'] = (x_new + self.sensor_size[0] / 2).astype(np.int32) + events['y'] = (y_new + self.sensor_size[1] / 2).astype(np.int32) + # clip to original range + events['x'] = np.clip(events['x'], 0, self.sensor_size[0]) + events['y'] = np.clip(events['y'], 0, self.sensor_size[1]) + return events + + +class Scale: + """ + Scale event x, y coordinates by a random factor + """ + def __init__(self, sensor_size, p, max_scale): + assert max_scale >= 1 + self.p = p + 
self.sensor_size = sensor_size + self.max_scale = max_scale + + def __call__(self, events): + if np.random.rand() > self.p: + return events + # scale x, y coordinates by a random factor + scale = np.random.uniform(1/self.max_scale, self.max_scale) + x = events['x'] - self.sensor_size[0] / 2 + y = events['y'] - self.sensor_size[1] / 2 + x_new = x * scale + y_new = y * scale + events['x'] = (x_new + self.sensor_size[0] / 2).astype(np.int32) + events['y'] = (y_new + self.sensor_size[1] / 2).astype(np.int32) + # remove events that were shifted out of the sensor size + mask = (events['x'] >= 0) & (events['x'] < self.sensor_size[0]) & (events['y'] >= 0) & (events['y'] < self.sensor_size[1]) + events = events[mask] + return events + + +class DropEventChunk: + """ + Randomly drop a chunk of events + + Parameters: + p (float): probability of dropping a chunk + max_drop_size (float): maximum fraction of the event stream to drop + """ + def __init__(self, p, max_drop_size): + self.drop_prob = p + self.max_drop_size = max_drop_size + + def __call__(self, events): + max_drop_events = self.max_drop_size * len(events) + if np.random.rand() < self.drop_prob: + drop_size = np.random.randint(1, max_drop_events) + start = np.random.randint(0, len(events) - drop_size) + events = np.delete(events, slice(start, start + drop_size), axis=0) + return events + + +class OneHotLabels: + """ + Convert integer labels to one-hot encoding + """ + def __init__(self, num_classes): + self.num_classes = num_classes + + def __call__(self, label): + return np.eye(self.num_classes)[label] + + +def cut_mix_augmentation(events, targets): + """ + Cut and mix two event streams by a random event chunk. Input is a list of event streams. + + :param events: batch of event streams of shape (batch_size, num_events, 4) + :param targets: batch of targets of shape (batch_size, num_classes) + + :return: mixed events, mixed targets + """ + # get the number of events in each stream + lengths = np.array([e.shape[0] for e in events]) + + # get fraction of the event-stream to cut + cut_size = np.random.randint(low=1, high=lengths) + start_event = np.random.randint(low=0, high=lengths - cut_size) + + # a random permutation to mix the events + rand_index = np.random.permutation(len(events)) + + mixed_events = [] + mixed_targets = [] + + # cut events from b and mix them with events from a + for i in range(len(events)): + events_b = events[rand_index[i]][start_event[rand_index[i]]:start_event[rand_index[i]] + cut_size[rand_index[i]]] + mask_a = (events[i]['t'] >= events_b['t'][0]) & (events[i]['t'] <= events_b['t'][-1]) + events_a = events[i][~mask_a] + + # mix and sort events + new_events = np.concatenate([events_a, events_b]) + new_events = new_events[np.argsort(new_events['t'])] + + # mix targets + lam = events_b.shape[0] / new_events.shape[0] + assert 0 <= lam <= 1, f'lam should be between 0 and 1, but got {lam} {cut_size[rand_index[i]]} {events_a.shape[0]} {events_b.shape[0]}' + + # append mixed events and targets + mixed_events.append(new_events) + mixed_targets.append(targets[i] * (1 - lam) + targets[rand_index[i]] * lam) + + return mixed_events, mixed_targets + + +def cut_mix_augmentation_time(events, targets): + """ + Cut and mix two event streams by a random time window. Input is a list of event streams. 
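A worked example of the label mixing used by both CutMix variants: if 200 of the 1,000 events in the mixed stream came from stream b, then lam = 0.2 and the one-hot targets blend accordingly (toy numbers):

```python
import numpy as np

lam = 200 / 1000                                # fraction of events taken from b
target_a, target_b = np.eye(3)[0], np.eye(3)[2]
mixed = target_a * (1 - lam) + target_b * lam   # [0.8, 0.0, 0.2]
```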
+ + :param events: batch of event streams of shape (batch_size, num_events, 4) + :param targets: batch of targets of shape (batch_size, num_classes) + + :return: mixed events, mixed targets + """ + # get the total time of all events + lengths = np.array([e['t'][-1] - e['t'][0] for e in events], dtype=np.float32) + + # get fraction of the event-stream to cut + cut_size = np.random.uniform(low=0, high=lengths) + start_time = np.random.uniform(low=0, high=lengths - cut_size) + + # a random permutation to mix the events + rand_index = np.random.permutation(len(events)) + + mixed_events = [] + mixed_targets = [] + + # cut events from b and mix them with events from a + for i in range(len(events)): + start, end = start_time[rand_index[i]], start_time[rand_index[i]] + cut_size[rand_index[i]] + mask_a = (events[i]['t'] >= start) & (events[i]['t'] <= end) + mask_b = (events[rand_index[i]]['t'] >= start) & (events[rand_index[i]]['t'] <= end) + + # mix events + new_events = np.concatenate([events[i][~mask_a], events[rand_index[i]][mask_b]]) + + # avoid the case that the new events are empty + if len(new_events) == 0: + mixed_events.append(events[i]) + mixed_targets.append(targets[i]) + else: + # sort events + new_events = new_events[np.argsort(new_events['t'])] + mixed_events.append(new_events) + + # mix targets + new_length = new_events['t'][-1] - new_events['t'][0] + if len(events[rand_index[i]]['t'][mask_b]) == 0: + cut_length = 0 + else: + cut_length = events[rand_index[i]]['t'][mask_b][-1] - events[rand_index[i]]['t'][mask_b][0] + lam = cut_length / new_length + assert 0 <= lam <= 1, f'lam should be between 0 and 1, but got {lam} {new_length} {cut_size[rand_index[i]]} {start} {end}' + mixed_targets.append(targets[i] * (1 - lam) + targets[rand_index[i]] * lam) + + return mixed_events, mixed_targets diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f82bda8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +--extra-index-url https://download.pytorch.org/whl/cpu +jax[cuda12] +flax +jaxtyping +optax +torch +torchvision +torchaudio +wandb +tonic +hydra-core \ No newline at end of file diff --git a/run_evaluation.py b/run_evaluation.py new file mode 100644 index 0000000..f6b086a --- /dev/null +++ b/run_evaluation.py @@ -0,0 +1,126 @@ +import hydra +from omegaconf import OmegaConf as om +from omegaconf import DictConfig, open_dict +from functools import partial + +import jax.random +import jax.numpy as jnp +import optax +from flax.training import checkpoints + +from event_ssm.dataloading import Datasets +from event_ssm.ssm import init_S5SSM +from event_ssm.seq_model import BatchClassificationModel + + +def setup_evaluation(cfg: DictConfig): + num_devices = jax.local_device_count() + assert cfg.checkpoint, "No checkpoint directory provided. Use checkpoint= to specify a checkpoint." + + # load task specific data + create_dataset_fn = Datasets[cfg.task.name] + + # Create dataset... 
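+    # the dataset factory returns train/val/test loaders plus dataset metadata (n_classes, num_embeddings) used to size the model below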
+ print("[*] Loading dataset...") + train_loader, val_loader, test_loader, data = create_dataset_fn( + cache_dir=cfg.data_dir, + seed=cfg.seed, + world_size=num_devices, + **cfg.training + ) + + with open_dict(cfg): + # optax updates the schedule every iteration and not every epoch + cfg.optimizer.total_steps = cfg.training.num_epochs * len(train_loader) // cfg.optimizer.accumulation_steps + cfg.optimizer.warmup_steps = cfg.optimizer.warmup_epochs * len(train_loader) // cfg.optimizer.accumulation_steps + + # scale learning rate by batch size + cfg.optimizer.ssm_lr = cfg.optimizer.ssm_base_lr * cfg.training.per_device_batch_size * num_devices * cfg.optimizer.accumulation_steps + + # load model + print("[*] Creating model...") + ssm_init_fn = init_S5SSM(**cfg.model.ssm_init) + model = BatchClassificationModel( + ssm=ssm_init_fn, + num_classes=data.n_classes, + num_embeddings=data.num_embeddings, + **cfg.model.ssm, + ) + + # initialize training state + state = checkpoints.restore_checkpoint(cfg.checkpoint, target=None) + params = state['params'] + model_state = state['model_state'] + + return model, params, model_state, train_loader, val_loader, test_loader + + +def evaluation_step( + apply_fn, + params, + model_state, + batch +): + """ + Evaluates the loss of the function passed as argument on a batch + + :param train_state: a Flax TrainState that carries the parameters, optimizer states etc + :param batch: the data consisting of [data, target] + :return: train_state, metrics + """ + inputs, targets, integration_timesteps, lengths = batch + logits = apply_fn( + + {'params': params, **model_state}, + inputs, integration_timesteps, lengths, + False, + ) + + loss = optax.softmax_cross_entropy(logits, targets) + loss = loss.mean() + preds = jnp.argmax(logits, axis=-1) + targets = jnp.argmax(targets, axis=-1) + accuracy = (preds == targets).mean() + + return {'loss': loss, 'accuracy': accuracy}, preds + + +@hydra.main(version_base=None, config_path='configs', config_name='base') +def main(config: DictConfig): + print(om.to_yaml(config)) + + model, params, model_state, train_loader, val_loader, test_loader = setup_evaluation(cfg=config) + step = partial(evaluation_step, model.apply, params, model_state) + step = jax.jit(step) + + # run training + print("[*] Running evaluation...") + metrics = {} + events_per_sample = [] + time_per_sample = [] + targets = [] + predictions = [] + num_batches = 0 + + for i, batch in enumerate(test_loader): + step_metrics, preds = step(batch) + + predictions.append(preds) + targets.append(jnp.argmax(batch[1], axis=-1)) + time_per_sample.append(jnp.sum(batch[2], axis=1)) + events_per_sample.append(batch[3]) + + if not metrics: + metrics = step_metrics + else: + for key, val in step_metrics.items(): + metrics[key] += val + num_batches += 1 + + metrics = {key: jnp.mean(metrics[key] / num_batches).item() for key in metrics} + + print(f"[*] Test accuracy: {100 * metrics['accuracy']:.2f}%") + + +if __name__ == '__main__': + main() diff --git a/run_training.py b/run_training.py new file mode 100644 index 0000000..f6741c0 --- /dev/null +++ b/run_training.py @@ -0,0 +1,121 @@ +import hydra +from omegaconf import OmegaConf as om +from omegaconf import DictConfig, open_dict +from functools import partial +import os + +import jax.random +from flax import jax_utils +from flax.training import checkpoints + +from event_ssm.dataloading import Datasets +from event_ssm.ssm import init_S5SSM +from event_ssm.seq_model import BatchClassificationModel +from event_ssm.train_utils import 
training_step, evaluation_step, init_model_state +from event_ssm.trainer import TrainerModule + + +def setup_training(key, cfg: DictConfig): + num_devices = jax.local_device_count() + + # load task specific data + create_dataset_fn = Datasets[cfg.task.name] + + # Create dataset... + print("[*] Loading dataset...") + train_loader, val_loader, test_loader, data = create_dataset_fn( + cache_dir=cfg.data_dir, + seed=cfg.seed, + world_size=num_devices, + **cfg.training + ) + + with open_dict(cfg): + # optax updates the schedule every iteration and not every epoch + cfg.optimizer.total_steps = cfg.training.num_epochs * len(train_loader) // cfg.optimizer.accumulation_steps + cfg.optimizer.warmup_steps = cfg.optimizer.warmup_epochs * len(train_loader) // cfg.optimizer.accumulation_steps + + # scale learning rate by batch size + cfg.optimizer.ssm_lr = cfg.optimizer.ssm_base_lr * cfg.training.per_device_batch_size * num_devices * cfg.optimizer.accumulation_steps + + # load model + print("[*] Creating model...") + ssm_init_fn = init_S5SSM(**cfg.model.ssm_init) + model = BatchClassificationModel( + ssm=ssm_init_fn, + num_classes=data.n_classes, + num_embeddings=data.num_embeddings, + **cfg.model.ssm, + ) + + # initialize training state + print("[*] Initializing model state...") + single_bsz = cfg.training.per_device_batch_size + batch = next(iter(train_loader)) + inputs, targets, timesteps, lengths = batch + state = init_model_state(key, model, inputs[:single_bsz], timesteps[:single_bsz], lengths[:single_bsz], cfg.optimizer) + + if cfg.training.get('from_checkpoint', None): + print(f'[*] Resuming model from {cfg.training.from_checkpoint}') + state = checkpoints.restore_checkpoint(cfg.training.from_checkpoint, state) + + # check if multiple GPUs are available and distribute training + if num_devices >= 2: + print(f"[*] Running training on {num_devices} GPUs") + state = jax_utils.replicate(state) + train_step = jax.pmap( + partial(training_step, distributed=True), + axis_name='data', + ) + eval_step = jax.pmap( + partial(evaluation_step, distributed=True), + axis_name='data' + ) + else: + train_step = jax.jit( + training_step + ) + eval_step = jax.jit( + evaluation_step + ) + + # set up trainer module + trainer = TrainerModule( + train_state=state, + training_step_fn=train_step, + evaluation_step_fn=eval_step, + world_size=num_devices, + config=cfg, + ) + + return trainer, train_loader, val_loader, test_loader + + +@hydra.main(version_base=None, config_path='configs', config_name='base') +def main(config: DictConfig): + # print config and save to log directory + print(om.to_yaml(config)) + with open(os.path.join(config.logging.log_dir, 'config.yaml'), 'w') as f: + om.save(config, f) + + # Set the random seed manually for reproducibility. 
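+    # jax.random.PRNGKey(seed) is deterministic; splitting it yields independent subkeys so that model initialization and dropout draw decorrelated random numbers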
+ + key = jax.random.PRNGKey(config.seed) + init_key, dropout_key = jax.random.split(key) + + if jax.local_device_count() > 1: + dropout_key = jax.random.split(dropout_key, jax.local_device_count()) + + trainer, train_loader, val_loader, test_loader = setup_training(key=init_key, cfg=config) + + # run training + print("[*] Running training...") + trainer.train_model( + train_loader=train_loader, + val_loader=val_loader, + test_loader=test_loader, + dropout_key=dropout_key + ) + + +if __name__ == '__main__': + main() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..39f09c8 --- /dev/null +++ b/setup.py @@ -0,0 +1,10 @@ +from distutils.core import setup + +setup( + name='Event-based-SSM', + packages=['event_ssm'], + version='0.1', + description='Event-stream modeling with state-space models', + author='Mark Schoene', + author_email='mark.schoene@tu-dresden.de', +) diff --git a/tutorial_inference.ipynb b/tutorial_inference.ipynb new file mode 100644 index 0000000..1d15ca1 --- /dev/null +++ b/tutorial_inference.ipynb @@ -0,0 +1,1098 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook, we demonstrate how to evaluate a trained event-SSM model on batches of unseen data on the three tasks:\n", + " \n", + "1) Spiking Speech Commands\n", + "2) Spiking Heidelberg Digits\n", + "3) DVS128 Gesture \n", + "\n", + "\n", + "\n", + "\n", + "# Setup\n", + "\n", + "Install and load the required modules and configuration.\n", + "\n", + "To install the required packages, please run ``` pip3 install -r requirements.txt ```
\n", + "\n", + "Directories for loading datasets, model checkpoints and saving results are defined in the configuration file `system/local.yaml`.\n", + "Please set your directories accordingly.\n", + "\n", + "The trained model checkpoints are [available for download](https://datashare.tu-dresden.de/s/g2dQCi792B8DqnC).\n", + "\n", + "## Important Libraries\n", + "* [Hydra](https://hydra.cc/docs/intro/) - to manage configurations.\n", + "* [Flax](https://flax.readthedocs.io/en/latest/), Neural network package built on top of [Jax](https://jax.readthedocs.io/en/latest/) - for model development\n", + "* [Tonic](https://tonic.readthedocs.io/en/latest/) - for datasets" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import os\n", + "from pathlib import Path\n", + "from functools import partial\n", + "from typing import Optional, TypeVar, Tuple, Union\n", + "\n", + "import numpy as np\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.metrics import confusion_matrix\n", + "\n", + "import torch\n", + "import tonic\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from flax.training import checkpoints\n", + "\n", + "from hydra import initialize, compose\n", + "from omegaconf import OmegaConf as om\n", + "\n", + "from event_ssm.ssm import init_S5SSM\n", + "from event_ssm.seq_model import BatchClassificationModel" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": "os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Turn off GPU", + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Task 1 - Spiking Heidelberg Digits\n", + "\n", + "Spike-based version of Heidelberg digits dataset, consist of approximately 10k high-quality recordings of spoken digits ranging from zero to nine in English and German language. In total 12 speakers were included, six of which were female and six male. \n", + "\n", + "Two speakers were heldout exclusively for the test set. 
The remainder of the test set was filled with samples (5 % of the trials) from speakers also present in the training set.\n", + "\n", + "\n", + "\n", + "Ref : https://arxiv.org/pdf/1910.07407v3" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 1 : Load configuration" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Load configurations\n", + "with initialize(version_base=None, config_path=\"configs\"):\n", + " cfg = compose(config_name=\"base.yaml\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# See the model config:\n", + "print(om.to_yaml(cfg.model))" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 2 : Visualise data" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "data = tonic.datasets.SHD(cfg.data_dir, train=False)\n", + "audio_events, label = data[0]\n", + "tonic.utils.plot_event_grid(audio_events)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 3: Load single data sample for inference" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "DEFAULT_CACHE_DIR_ROOT = Path('./cache_dir/')\n", + "DataLoader = TypeVar('DataLoader')\n", + "InputType = [str, Optional[int], Optional[int]]\n", + "class Data:\n", + " def __init__(\n", + " self,\n", + " n_classes: int,\n", + " num_embeddings: int\n", + " ):\n", + " self.n_classes = n_classes\n", + " self.num_embeddings = num_embeddings" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "def event_stream_collate_fn(batch, resolution, pad_unit, no_time_information=False):\n", + " # x are inputs, y are targets, z are aux data\n", + " x, y, *z = zip(*batch)\n", + " assert len(z) == 0\n", + " batch_size_one = len(x) == 1\n", + "\n", + " # set labels to numpy array\n", + " y = np.stack(y)\n", + "\n", + " # integration time steps are the difference between two consequtive time stamps\n", + " if no_time_information:\n", + " timesteps = [np.ones_like(e['t'][:-1]) for e in x]\n", + " else:\n", + " timesteps = [np.diff(e['t']) for e in x]\n", + "\n", + " # NOTE: since timesteps are deltas, their length is L - 1, and we have to remove the last token in the following\n", + "\n", + " # process tokens for single input dim (e.g. audio)\n", + " if len(resolution) == 1:\n", + " tokens = [e['x'][:-1].astype(np.int32) for e in x]\n", + " elif len(resolution) == 2:\n", + " tokens = [(e['x'][:-1] * e['y'][:-1] + np.prod(resolution) * e['p'][:-1].astype(np.int32)).astype(np.int32) for e in x]\n", + " else:\n", + " raise ValueError('resolution must contain 1 or 2 elements')\n", + "\n", + " # get padding lengths\n", + " lengths = np.array([len(e) for e in timesteps], dtype=np.int32)\n", + " pad_length = (lengths.max() // pad_unit) * pad_unit + pad_unit\n", + "\n", + " # pad tokens with -1, which results in a zero vector with embedding look-ups\n", + " tokens = np.stack(\n", + " [np.pad(e, (0, pad_length - len(e)), mode='constant', constant_values=-1) for e in tokens])\n", + " timesteps = np.stack(\n", + " [np.pad(e, (0, pad_length - len(e)), mode='constant', constant_values=0) for e in timesteps])\n", + "\n", + " # timesteps are in micro seconds... 
transform to milliseconds\n", + " timesteps = timesteps / 1000\n", + "\n", + " if batch_size_one:\n", + " lengths = lengths[None, ...]\n", + "\n", + " return tokens, y, timesteps, lengths" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "def event_stream_dataloader(test_data,eval_batch_size,eval_collate_fn, rng, num_workers=0):\n", + " def dataloader(dset, bsz, collate_fn, shuffle, drop_last):\n", + " return torch.utils.data.DataLoader(\n", + " dset,\n", + " batch_size=bsz,\n", + " drop_last=drop_last,\n", + " collate_fn=collate_fn,\n", + " shuffle=shuffle,\n", + " generator=rng,\n", + " num_workers=num_workers\n", + " )\n", + " test_loader = dataloader(test_data, eval_batch_size, eval_collate_fn, shuffle=True, drop_last=False)\n", + " return test_loader" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "def create_events_shd_classification_dataset(\n", + " cache_dir: Union[str, Path] = DEFAULT_CACHE_DIR_ROOT,\n", + " per_device_eval_batch_size: int = 64,\n", + " world_size: int = 1,\n", + " num_workers: int = 0,\n", + " seed: int = 42,\n", + " pad_unit: int = 8192,\n", + " no_time_information: bool = False,\n", + " **kwargs\n", + ") -> Tuple[DataLoader, Data]:\n", + " \"\"\"\n", + " creates a view of the spiking heidelberg digits dataset\n", + "\n", + " :param cache_dir:\t\t (str):\t\twhere to store the dataset\n", + " :param per_device_eval_batch_size:\t\t\t\t(int):\t\tEvaluation Batch size.\n", + " :param seed:\t\t\t (int):\t\tSeed for shuffling data.\n", + " \"\"\"\n", + " print(\"[*] Generating Spiking Heidelberg Digits Classification Dataset\")\n", + "\n", + " if seed is not None:\n", + " rng = torch.Generator()\n", + " rng.manual_seed(seed)\n", + " else:\n", + " rng = None\n", + " \n", + " #target_transforms = OneHotLabels(num_classes=20)\n", + " test_data = tonic.datasets.SHD(save_to=cache_dir, train=False)\n", + " collate_fn = partial(event_stream_collate_fn, resolution=(700,), pad_unit=pad_unit, no_time_information=no_time_information)\n", + " test_loader = event_stream_dataloader(\n", + " test_data,\n", + " eval_collate_fn=collate_fn,\n", + " eval_batch_size=per_device_eval_batch_size * world_size,\n", + " rng=rng, \n", + " num_workers=num_workers\n", + " )\n", + " data = Data(\n", + " n_classes=20, num_embeddings=700)\n", + " return test_loader, data" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "print(\"[*] Loading dataset...\")\n", + "num_devices = jax.local_device_count()\n", + "test_loader, data = create_events_shd_classification_dataset(\n", + " cache_dir=cfg.data_dir,\n", + " seed=cfg.seed,\n", + " world_size=num_devices,\n", + " per_device_eval_batch_size = 1,\n", + " pad_unit=cfg.training.pad_unit \n", + " )" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Load a sample\n", + "batch = next(iter(test_loader))\n", + "inputs, targets, timesteps, lengths = batch" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 4 : Load model" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Set the random seed manually for reproducibility.\n", + "init_key = jax.random.PRNGKey(cfg.seed)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Model initialisation in 
flax\n", + "ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)\n", + "model = BatchClassificationModel(\n", + " ssm=ssm_init_fn,\n", + " num_classes=data.n_classes,\n", + " num_embeddings=data.num_embeddings,\n", + " **cfg.model.ssm,\n", + " )" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Visualise model\n", + "print(model.tabulate({\"params\": init_key},\n", + " inputs, timesteps, lengths, False))" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "checkpoint_dir = os.path.abspath(os.path.join(cfg.checkpoint_dir, 'SHD'))\n", + "training_state = checkpoints.restore_checkpoint(checkpoint_dir, target=None)\n", + "params = training_state['params']\n", + "model_state = training_state['model_state']" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 5 - Model prediction" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": "logits = model.apply({'params': params, **model_state}, inputs, timesteps, lengths, False)", + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "print(f\"Predicted label:{jnp.argmax(logits,axis=-1)}\")\n", + "print(f\"Actual label:{targets}\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 6 - Evaluate model on a batch" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "print(\"[*] Loading dataset...\")\n", + "num_devices = jax.local_device_count()\n", + "test_loader, data = create_events_shd_classification_dataset(\n", + " cache_dir=cfg.data_dir,\n", + " seed=cfg.seed,\n", + " world_size=num_devices,\n", + " per_device_eval_batch_size = cfg.training.per_device_eval_batch_size,\n", + " pad_unit=cfg.training.pad_unit,\n", + " #no_time_information = cfg.training.no_time_information\n", + " \n", + " )" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Load a batch\n", + "batch = next(iter(test_loader))\n", + "inputs, targets, timesteps, lengths = batch\n", + "logits = model.apply({'params': params, **model_state},inputs, timesteps, lengths,False)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Plot the confusion matrix\n", + "cm = confusion_matrix(jnp.argmax(logits,axis=1), targets)\n", + "sns.heatmap(cm, annot=True,fmt='d', cmap='YlGnBu')\n", + "plt.ylabel('Prediction',fontsize=12)\n", + "plt.xlabel('Actual',fontsize=12)\n", + "plt.title('Confusion Matrix',fontsize=16)\n", + "plt.show()" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "print(f\"Accuracy of the model: {(jnp.argmax(logits,axis=1)==targets).mean()}\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Task 2 - Spiking Speech Commands \n", + "\n", + "The Spiking Speech Commands is based on the Speech Commands release by Google which consists of utterances recorded from a larger number of speakers under less controlled conditions. 
+ "\n",
+ "Ref : https://arxiv.org/pdf/1910.07407v3"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Step 1 : Load configuration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Load configurations\n",
+ "with initialize(version_base=None, config_path=\"configs\"):\n",
+ "    cfg = compose(config_name=\"base.yaml\", overrides=[\"task=spiking-speech-commands\"])"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# See the model config:\n",
+ "print(om.to_yaml(cfg.model))"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Step 2 : Visualise data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "data = tonic.datasets.SSC(cfg.data_dir, split='test')\n",
+ "audio_events, label = data[0]\n",
+ "tonic.utils.plot_event_grid(audio_events)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Step 3: Load single data sample for inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "def create_events_ssc_classification_dataset(\n",
+ "    cache_dir: Union[str, Path] = DEFAULT_CACHE_DIR_ROOT,\n",
+ "    per_device_eval_batch_size: int = 64,\n",
+ "    world_size: int = 1,\n",
+ "    num_workers: int = 0,\n",
+ "    seed: int = 42,\n",
+ "    pad_unit: int = 8192,\n",
+ "    no_time_information: bool = False,\n",
+ "    **kwargs\n",
+ ") -> Tuple[DataLoader, Data]:\n",
+ "    \"\"\"\n",
+ "    Creates a test-set view of the Spiking Speech Commands dataset.\n",
+ "\n",
+ "    :param cache_dir: (str) where to store the dataset\n",
+ "    :param per_device_eval_batch_size: (int) evaluation batch size per device\n",
+ "    :param seed: (int) seed for shuffling data\n",
+ "    \"\"\"\n",
+ "    print(\"[*] Generating Spiking Speech Commands Classification Dataset\")\n",
+ "\n",
+ "    if seed is not None:\n",
+ "        rng = torch.Generator()\n",
+ "        rng.manual_seed(seed)\n",
+ "    else:\n",
+ "        rng = None\n",
+ "\n",
+ "    test_data = tonic.datasets.SSC(save_to=cache_dir, split='test')\n",
+ "    collate_fn = partial(event_stream_collate_fn, resolution=(700,), pad_unit=pad_unit, no_time_information=no_time_information)\n",
+ "    test_loader = event_stream_dataloader(\n",
+ "        test_data,\n",
+ "        eval_collate_fn=collate_fn,\n",
+ "        eval_batch_size=per_device_eval_batch_size * world_size,\n",
+ "        rng=rng,\n",
+ "        num_workers=num_workers,\n",
+ "    )\n",
+ "\n",
+ "    data = Data(n_classes=35, num_embeddings=700)\n",
+ "    return test_loader, data"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "print(\"[*] Loading dataset...\")\n",
+ "num_devices = jax.local_device_count()\n",
+ "test_loader, data = create_events_ssc_classification_dataset(\n",
+ "    cache_dir=cfg.data_dir,\n",
+ "    seed=cfg.seed,\n",
+ "    world_size=num_devices,\n",
+ "    per_device_eval_batch_size=1,\n",
+ "    pad_unit=cfg.training.pad_unit\n",
+ ")"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Load a sample\n",
+ "batch = next(iter(test_loader))\n",
+ "inputs, targets, timesteps, lengths = batch"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Step 4 : Load model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Set the random seed manually for reproducibility.\n",
+ "init_key = jax.random.PRNGKey(cfg.seed)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)\n",
+ "model = BatchClassificationModel(\n",
+ "    ssm=ssm_init_fn,\n",
+ "    num_classes=data.n_classes,\n",
+ "    num_embeddings=data.num_embeddings,\n",
+ "    **cfg.model.ssm,\n",
+ ")"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "print(model.tabulate({\"params\": init_key},\n",
+ "                     inputs, timesteps, lengths, False))"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# load model parameters from checkpoint\n",
+ "checkpoint_dir = os.path.abspath(os.path.join(cfg.checkpoint_dir, 'SSC'))\n",
+ "training_state = checkpoints.restore_checkpoint(checkpoint_dir, target=None)\n",
+ "params = training_state['params']\n",
+ "model_state = training_state['model_state']"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Step 5 - Model prediction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": "logits = model.apply({'params': params, **model_state}, inputs, timesteps, lengths, False)",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "print(f\"Predicted label: {jnp.argmax(logits, axis=-1)}\")\n",
+ "print(f\"Actual label: {targets}\")"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Step 6 - Evaluate model on a single batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "print(\"[*] Loading dataset...\")\n",
+ "num_devices = jax.local_device_count()\n",
+ "test_loader, data = create_events_ssc_classification_dataset(\n",
+ "    cache_dir=cfg.data_dir,\n",
+ "    seed=cfg.seed,\n",
+ "    world_size=num_devices,\n",
+ "    per_device_eval_batch_size=cfg.training.per_device_eval_batch_size,\n",
+ "    pad_unit=cfg.training.pad_unit,\n",
+ "    # no_time_information=cfg.training.no_time_information\n",
+ ")"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Load a batch\n",
+ "batch = next(iter(test_loader))\n",
+ "inputs, targets, timesteps, lengths = batch"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": "logits = model.apply({'params': params, **model_state}, inputs, timesteps, lengths, False)",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Plot the confusion matrix\n",
+ "cm = confusion_matrix(jnp.argmax(logits, axis=-1), targets)\n",
+ "sns.heatmap(cm, annot=True, fmt='d', cmap='YlGnBu')\n",
+ "plt.ylabel('Prediction', fontsize=12)\n",
+ "plt.xlabel('Actual', fontsize=12)\n",
+ "plt.title('Confusion Matrix', fontsize=16)\n",
+ "plt.show()"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": "print(f\"Accuracy of the model: {(jnp.argmax(logits, axis=-1) == targets).mean()}\")",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Task 3 - DVS Gesture"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Task Description\n",
+ "\n",
+ "DVS Gesture originates from the first gesture recognition system implemented end-to-end on event-based hardware. The dataset comprises 11 hand gesture categories from 29 subjects under 3 illumination conditions.\n",
+ "\n",
+ "Ref : https://ieeexplore.ieee.org/document/8100264\n",
+ "\n",
+ "### Exercise\n",
+ "\n",
+ "Similar to SHD and SSC, implement the inference steps for the DVS Gesture data."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Step 1 : Load configuration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Load configurations\n",
+ "with initialize(version_base=None, config_path=\"configs\"):\n",
+ "    cfg = compose(config_name=\"base.yaml\", overrides=[\"task=dvs-gesture\"])"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# model config:\n",
+ "print(om.to_yaml(cfg.model))"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Step 2 : Visualise Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "from IPython.display import HTML\n",
+ "import warnings\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "data = tonic.datasets.DVSGesture(cfg.data_dir, train=False)\n",
+ "events, label = data[0]\n",
+ "\n",
+ "transform = tonic.transforms.Compose(\n",
+ "    [\n",
+ "        tonic.transforms.TimeJitter(std=100, clip_negative=False),\n",
+ "        tonic.transforms.ToFrame(\n",
+ "            sensor_size=data.sensor_size,\n",
+ "            time_window=10000,\n",
+ "        ),\n",
+ "    ]\n",
+ ")\n",
+ "\n",
+ "frames = transform(events)\n",
+ "HTML(tonic.utils.plot_animation((frames * 255).astype(np.uint8)).to_html5_video())"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Step 3: Load single inference sample"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "from event_ssm.transform import Identity"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "def create_events_dvs_gesture_classification_dataset(\n",
+ "    cache_dir: Union[str, Path] = DEFAULT_CACHE_DIR_ROOT,\n",
+ "    per_device_eval_batch_size: int = 64,\n",
+ "    world_size: int = 1,\n",
+ "    num_workers: int = 0,\n",
+ "    seed: int = 42,\n",
+ "    pad_unit: int = 2 ** 19,\n",
+ "    downsampling: int = 1,\n",
+ "    **kwargs\n",
+ ") -> Tuple[DataLoader, Data]:\n",
+ "    \"\"\"\n",
+ "    Creates a test-set view of the DVS Gesture dataset.\n",
+ "\n",
+ "    :param cache_dir: (str) where to store the dataset\n",
+ "    :param per_device_eval_batch_size: (int) evaluation batch size per device\n",
+ "    :param seed: (int) seed for shuffling data\n",
+ "    \"\"\"\n",
+ "    print(\"[*] Generating DVS Gesture Classification Dataset\")\n",
+ "\n",
+ "    if seed is not None:\n",
+ "        rng = torch.Generator()\n",
+ "        rng.manual_seed(seed)\n",
+ "    else:\n",
+ "        rng = None\n",
+ "\n",
+ "    orig_sensor_size = (128, 128, 2)\n",
+ "    new_sensor_size = (128 // downsampling, 128 // downsampling, 2)\n",
+ "    test_transforms = tonic.transforms.Compose([\n",
+ "        tonic.transforms.Downsample(sensor_size=orig_sensor_size, target_size=new_sensor_size[:2]) if downsampling > 1 else Identity(),\n",
+ "    ])\n",
+ "\n",
+ "    TestData = partial(tonic.datasets.DVSGesture, save_to=cache_dir, train=False)\n",
+ "    test_data = TestData(transform=test_transforms)\n",
+ "\n",
+ "    # define collate function\n",
+ "    eval_collate_fn = partial(\n",
+ "        event_stream_collate_fn,\n",
+ "        resolution=new_sensor_size[:2],\n",
+ "        pad_unit=pad_unit,\n",
+ "    )\n",
+ "    test_loader = event_stream_dataloader(\n",
+ "        test_data,\n",
+ "        eval_collate_fn=eval_collate_fn,\n",
+ "        eval_batch_size=per_device_eval_batch_size * world_size,\n",
+ "        rng=rng,\n",
+ "        num_workers=num_workers\n",
+ "    )\n",
+ "\n",
+ "    data = Data(\n",
+ "        n_classes=11, num_embeddings=np.prod(new_sensor_size)\n",
+ "    )\n",
+ "    return test_loader, data"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "num_devices = jax.local_device_count()\n",
+ "# Create dataset\n",
+ "test_loader, data = create_events_dvs_gesture_classification_dataset(\n",
+ "    cache_dir=cfg.data_dir,\n",
+ "    seed=cfg.seed,\n",
+ "    world_size=num_devices,\n",
+ "    per_device_eval_batch_size=1,\n",
+ "    pad_unit=cfg.training.pad_unit,\n",
+ ")"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Load a sample\n",
+ "batch = next(iter(test_loader))\n",
+ "inputs, targets, timesteps, lengths = batch"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Step 4 : Load model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Set the random key for the task\n",
+ "init_key = jax.random.PRNGKey(cfg.seed)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "print(\"[*] Creating model...\")\n",
+ "ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)\n",
+ "model = BatchClassificationModel(\n",
+ "    ssm=ssm_init_fn,\n",
+ "    num_classes=data.n_classes,\n",
+ "    num_embeddings=data.num_embeddings,\n",
+ "    **cfg.model.ssm,\n",
+ ")"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# visualise model\n",
+ "print(model.tabulate({\"params\": init_key},\n",
+ "                     inputs, timesteps, lengths, False))"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# load model parameters from checkpoint\n",
+ "checkpoint_dir = os.path.abspath(os.path.join(cfg.checkpoint_dir, 'DVS'))\n",
+ "training_state = checkpoints.restore_checkpoint(checkpoint_dir, target=None)\n",
+ "params = training_state['params']\n",
+ "model_state = training_state['model_state']"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Step 5 - Model prediction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": "logits = model.apply({'params': params, **model_state}, inputs, timesteps, lengths, False)",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "print(f\"Predicted label: {jnp.argmax(logits, axis=-1)}\")\n",
+ "print(f\"Actual label: {targets}\")"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Step 6 - Evaluate model on a single batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "num_devices = jax.local_device_count()\n",
+ "# Create dataset\n",
+ "test_loader, data = create_events_dvs_gesture_classification_dataset(\n",
+ "    cache_dir=cfg.data_dir,\n",
+ "    seed=cfg.seed,\n",
+ "    world_size=num_devices,\n",
+ "    per_device_eval_batch_size=
cfg.training.per_device_eval_batch_size,\n", + " pad_unit=cfg.training.pad_unit,\n", + " #no_time_information = cfg.training.no_time_information\n", + " )" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Load a batch\n", + "batch = next(iter(test_loader))\n", + "inputs, targets, timesteps, lengths = batch" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "jupyter": { + "is_executing": true + } + }, + "source": "logits = model.apply({'params': params, **model_state}, inputs, timesteps, lengths, False)", + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Plot the confusion matrix\n", + "cm = confusion_matrix(jax.numpy.argmax(logits, axis=-1), targets)\n", + "sns.heatmap(cm, annot=True, fmt='d', cmap='YlGnBu')\n", + "plt.ylabel('Prediction', fontsize=12)\n", + "plt.xlabel('Actual', fontsize=12)\n", + "plt.title('Confusion Matrix', fontsize=16)\n", + "plt.show()" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": "print(f\"Accuracy of the model: {(jnp.argmax(logits, axis=-1) == targets).mean()}\")", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "code", + "source": "", + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "blocksparse", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorial_online_inference.ipynb b/tutorial_online_inference.ipynb new file mode 100644 index 0000000..f2060f0 --- /dev/null +++ b/tutorial_online_inference.ipynb @@ -0,0 +1,491 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Online Inference Tutorial\n", + "In this tutorial, we will implement online inference with event-based state-space models.\n", + "Online inference is the process of classifying events as they arrive in real-time.\n", + "For many edge systems, the batch size is 1, and the model has to meet a specific throughput of events per second.\n", + "Here, you will test if your CPU is able to run real-time classification with EventSSM.\n", + "\n", + "The tutorial requires basic familiarity with JAX." 
+ ], + "id": "b99721b9d6b26c10" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-05-27T09:22:32.658921Z", + "start_time": "2024-05-27T09:22:32.654126Z" + } + }, + "cell_type": "code", + "source": [ + "from hydra import initialize, compose\n", + "from omegaconf import OmegaConf as om\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "from event_ssm.ssm import init_S5SSM\n", + "from event_ssm.seq_model import ClassificationModel" + ], + "id": "bc0a9044321d654d", + "outputs": [], + "execution_count": 24 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## Step 1: Load the model", + "id": "d8b261a76014fbc7" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-05-27T09:22:33.679045Z", + "start_time": "2024-05-27T09:22:33.561733Z" + } + }, + "cell_type": "code", + "source": [ + "# Load configurations\n", + "with initialize(version_base=None, config_path=\"configs\"):\n", + " cfg = compose(config_name=\"base.yaml\", overrides=[\"model=dvs/small\"])" + ], + "id": "7efb7b5428f7472", + "outputs": [], + "execution_count": 25 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-05-27T09:22:33.771341Z", + "start_time": "2024-05-27T09:22:33.766065Z" + } + }, + "cell_type": "code", + "source": [ + "# Print the configuration\n", + "print(om.to_yaml(cfg.model))" + ], + "id": "16eb6e254f8090cd", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ssm_init:\n", + " C_init: lecun_normal\n", + " dt_min: 0.001\n", + " dt_max: 0.1\n", + " conj_sym: false\n", + " clip_eigs: true\n", + "ssm:\n", + " discretization: async\n", + " d_model: 128\n", + " d_ssm: 128\n", + " ssm_block_size: 16\n", + " num_stages: 2\n", + " num_layers_per_stage: 3\n", + " dropout: 0.25\n", + " classification_mode: timepool\n", + " prenorm: true\n", + " batchnorm: false\n", + " bn_momentum: 0.95\n", + " pooling_stride: 16\n", + " pooling_mode: timepool\n", + " state_expansion_factor: 2\n", + "\n" + ] + } + ], + "execution_count": 26 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-05-27T09:22:34.290002Z", + "start_time": "2024-05-27T09:22:34.282856Z" + } + }, + "cell_type": "code", + "source": [ + "# Set the random seed manually for reproducibility.\n", + "key = jax.random.PRNGKey(cfg.seed)\n", + "init_key, data_key = jax.random.split(key)" + ], + "id": "9806959c6627a4d5", + "outputs": [], + "execution_count": 27 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-05-27T09:22:34.904836Z", + "start_time": "2024-05-27T09:22:34.897859Z" + } + }, + "cell_type": "code", + "source": [ + "# Model initialisation in flax\n", + "ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)\n", + "\n", + "# number of classes (dummy)\n", + "classes = 10\n", + "\n", + "# number of tokens for a DVS sensor of size 128x128\n", + "num_tokens = 128 * 128 * 2\n", + "model = ClassificationModel(\n", + " ssm=ssm_init_fn,\n", + " num_classes=10,\n", + " num_embeddings=num_tokens,\n", + " **cfg.model.ssm,\n", + " )" + ], + "id": "b936f3fdd1538bfe", + "outputs": [], + "execution_count": 28 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "EventSSM subsamples sequences in multiple stages to reduce the computational cost.\n", + "Let's investigate the total subsampling" + ], + "id": "accb046df2d07e7" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-05-27T09:56:14.174709Z", + "start_time": "2024-05-27T09:56:14.161702Z" + } + }, + "cell_type": "code", + "source": [ + "total_subsampling = cfg.model.ssm.pooling_stride ** 
cfg.model.ssm.num_stages\n", + "print(f\"Total subsampling: {total_subsampling}\")" + ], + "id": "3ed763820fe9f204", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total subsampling: 256\n" + ] + } + ], + "execution_count": 35 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-05-27T09:56:42.653733Z", + "start_time": "2024-05-27T09:56:38.056333Z" + } + }, + "cell_type": "code", + "source": [ + "# initialize model parameters\n", + "x = jnp.zeros(total_subsampling, dtype=jnp.int32)\n", + "t = jnp.ones(total_subsampling)\n", + "variables = model.init(\n", + " {\"params\": init_key},\n", + " x, t, total_subsampling, False\n", + " )" + ], + "id": "e18fbb811f6c46e0", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SSM: 128 -> 128 -> 128 (stride 16 with pooling mode timepool)\n", + "SSM: 128 -> 128 -> 128\n", + "SSM: 128 -> 128 -> 128\n", + "SSM: 128 -> 256 -> 256 (stride 16 with pooling mode timepool)\n", + "SSM: 256 -> 256 -> 256\n", + "SSM: 256 -> 256 -> 256\n" + ] + } + ], + "execution_count": 36 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Step 2: Run the model on random data\n", + "Generate a random list of integer tokens, jit compile the model and classify online." + ], + "id": "8ed847f8098b7f53" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-05-27T10:14:54.375839Z", + "start_time": "2024-05-27T10:14:54.360101Z" + } + }, + "cell_type": "code", + "source": [ + "# Generate random data\n", + "sequence_length = 2 ** 18\n", + "tokens = jax.random.randint(data_key, shape=(sequence_length,), minval=0, maxval=num_tokens)\n", + "timesteps = jnp.ones(sequence_length)\n", + "print(\"Sequence length:\", sequence_length)" + ], + "id": "9b32e55bfaf178e9", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sequence length: 262144\n" + ] + } + ], + "execution_count": 63 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-05-27T10:15:07.170346Z", + "start_time": "2024-05-27T10:14:55.732901Z" + } + }, + "cell_type": "code", + "source": [ + "# jit compile the model\n", + "from functools import partial\n", + "jit_apply = jax.jit(partial(model.apply, length=total_subsampling, train=False))\n", + "jit_apply(variables, x[:total_subsampling], t[:total_subsampling])" + ], + "id": "8f49cd496d6ef30d", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SSM: 128 -> 128 -> 128 (stride 16 with pooling mode timepool)\n", + "SSM: 128 -> 128 -> 128\n", + "SSM: 128 -> 128 -> 128\n", + "SSM: 128 -> 256 -> 256 (stride 16 with pooling mode timepool)\n", + "SSM: 256 -> 256 -> 256\n", + "SSM: 256 -> 256 -> 256\n" + ] + }, + { + "data": { + "text/plain": [ + "Array([-0.12317943, -0.17902763, -0.26315966, 0.5992651 , 0.7048361 ,\n", + " 1.2036127 , 0.00121723, 0.41398254, 0.26262668, 0.18357195], dtype=float32)" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 64 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-05-27T10:15:11.763566Z", + "start_time": "2024-05-27T10:15:08.166525Z" + } + }, + "cell_type": "code", + "source": [ + "# loop through the model\n", + "from tqdm import tqdm\n", + "from time import time\n", + "print(f\"Looping through {sequence_length} events with total_subsampling={total_subsampling} --> {sequence_length // total_subsampling} iterations\")\n", + "start = time()\n", + "for i in tqdm(range(0, sequence_length, total_subsampling)):\n", 
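+ "    # classify one chunk of events at a time; block_until_ready() waits for the device so the timing is honest\n",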
+ "    x = tokens[i:i + total_subsampling]\n",
+ "    t = timesteps[i:i + total_subsampling]\n",
+ "    logits = jit_apply(variables, x, t).block_until_ready()\n",
+ "end = time()\n",
+ "print(f\"Time taken: {end - start:.2f}s\")\n",
+ "print(f\"Events per second: {sequence_length / (end - start):.2f}\")"
+ ],
+ "id": "55a885c77a44e8eb",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Looping through 262144 events with total_subsampling=256 --> 1024 iterations\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 1024/1024 [00:03<00:00, 285.19it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Time taken: 3.59s\n",
+ "Events per second: 72962.94\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "execution_count": 65
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "## Step 3: Optimize the inference speed\n",
+ "We suggest using [jax.lax.scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) instead of the Python for loop to further speed up inference."
+ ],
+ "id": "541f0afde67081f8"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-05-27T10:15:27.619686Z",
+ "start_time": "2024-05-27T10:15:14.529552Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "def step(carry, inputs):\n",
+ "    x, t = inputs\n",
+ "    logits = model.apply(variables, x, t, total_subsampling, False)\n",
+ "    return None, logits\n",
+ "\n",
+ "tokens = tokens.reshape(-1, total_subsampling)\n",
+ "timesteps = timesteps.reshape(-1, total_subsampling)\n",
+ "\n",
+ "# run the scan: jit-compiles on the first call, then iterates\n",
+ "# (lax.scan returns a (carry, stacked_outputs) tuple)\n",
+ "_, logits = jax.lax.scan(step, init=None, xs=(tokens, timesteps))"
+ ],
+ "id": "1318e7467cbb3b3f",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "SSM: 128 -> 128 -> 128 (stride 16 with pooling mode timepool)\n",
+ "SSM: 128 -> 128 -> 128\n",
+ "SSM: 128 -> 128 -> 128\n",
+ "SSM: 128 -> 256 -> 256 (stride 16 with pooling mode timepool)\n",
+ "SSM: 256 -> 256 -> 256\n",
+ "SSM: 256 -> 256 -> 256\n"
+ ]
+ }
+ ],
+ "execution_count": 66
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-05-27T10:15:49.444621Z",
+ "start_time": "2024-05-27T10:15:46.788818Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# measure run-time\n",
+ "start = time()\n",
+ "_, logits = jax.block_until_ready(jax.lax.scan(step, init=None, xs=(tokens, timesteps)))\n",
+ "end = time()\n",
+ "print(f\"Time taken: {end - start:.2f}s\")\n",
+ "print(f\"Events per second: {sequence_length / (end - start):.2f}\")"
+ ],
+ "id": "aa170aadad84036d",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Time taken: 2.65s\n",
+ "Events per second: 99018.86\n"
+ ]
+ }
+ ],
+ "execution_count": 68
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-05-27T10:15:53.224299Z",
+ "start_time": "2024-05-27T10:15:53.220810Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "logits.shape",
+ "id": "718dffb170c2df1c",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(1024, 10)"
+ ]
+ },
+ "execution_count": 69,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 69
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "## Step 4: Run inference on the DVS128 Gestures dataset\n",
+ "Follow the steps in the `tutorial_inference.ipynb` to run inference on the DVS128 Gestures dataset with a pretrained model.\n",
+ "Plot the confidence of the model in the correct class over time. A sketch of one possible approach is given below."
+ ],
+ "id": "bcaba7dc4697605d"
+ },
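+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "A minimal sketch of one possible approach. It reuses the per-chunk `logits` from Step 3, which were computed on random data; for the actual exercise, feed the events of a DVS128 Gestures sample through a pretrained model instead. `true_class` is a hypothetical label for the processed sample."
+ ],
+ "id": "3f2a1b4c5d6e7f80"
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "true_class = 0  # hypothetical label of the processed sample\n",
+ "\n",
+ "# per-chunk class probabilities, shape (num_chunks, num_classes)\n",
+ "probs = jax.nn.softmax(logits, axis=-1)\n",
+ "\n",
+ "plt.plot(probs[:, true_class])\n",
+ "plt.xlabel(f'Chunk index ({total_subsampling} events per chunk)')\n",
+ "plt.ylabel('Confidence in the correct class')\n",
+ "plt.title('Model confidence over time')\n",
+ "plt.show()"
+ ],
+ "id": "3f2a1b4c5d6e7f81",
+ "outputs": [],
+ "execution_count": null
+ },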
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null,
+ "source": "",
+ "id": "d9110111c449d185"
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tutorial_training.ipynb b/tutorial_training.ipynb
new file mode 100644
index 0000000..5c1ccf0
--- /dev/null
+++ b/tutorial_training.ipynb
@@ -0,0 +1,448 @@
+{
+ "cells": [
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "# Tutorial: Training a model\n",
+ "In this tutorial, we will train an event-based state-space model on a reduced version of the [Spiking Heidelberg Digits](https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/) dataset.\n",
+ "For training on larger datasets or multiple GPUs, we recommend using the training script `run_training.py` instead.\n",
+ "\n",
+ "## Setup\n",
+ "\n",
+ "Install and load the required modules and configuration. To install the required packages, please run\n",
+ "```\n",
+ "pip3 install -r requirements.txt\n",
+ "```\n",
+ "\n",
+ "Directories for loading datasets, model checkpoints and saving results are defined in the configuration file `system/local.yaml`.\n",
+ "Please set your directories accordingly."
+ ],
+ "id": "4d02d51dcadfcfb"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "## Data loading\n",
+ "The SHD dataset contains 20 classes: the digits 0 to 9 spoken in both German and English.\n",
+ "We will use a reduced version of the dataset containing only two digits, so that the model trains to non-trivial performance in reasonable time even on CPUs.\n",
+ "\n",
+ "[Download the training and test dataset](https://zenkelab.org/datasets/) and unpack the archives to `./data/`. Alternatively, run the optional download cell below."
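+ ],
+ "id": "df7ee68ff3e429ed"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "Optionally, the download can be scripted. The cell below is a minimal sketch that assumes the archive names of the SHD release (`shd_train.h5.gz` and `shd_test.h5.gz`); adjust the names if the release changes."
+ ],
+ "id": "a1b2c3d4e5f60001"
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "# Optional: download and unpack the SHD archives to ./data/\n",
+ "# (sketch; assumes the filenames shd_train.h5.gz / shd_test.h5.gz)\n",
+ "import gzip, os, shutil, urllib.request\n",
+ "\n",
+ "base_url = 'https://zenkelab.org/datasets'\n",
+ "os.makedirs('data', exist_ok=True)\n",
+ "for name in ['shd_train.h5', 'shd_test.h5']:\n",
+ "    archive = os.path.join('data', name + '.gz')\n",
+ "    target = os.path.join('data', name)\n",
+ "    if not os.path.exists(target):\n",
+ "        urllib.request.urlretrieve(f'{base_url}/{name}.gz', archive)\n",
+ "        # decompress the gzip archive to the plain .h5 file\n",
+ "        with gzip.open(archive, 'rb') as src, open(target, 'wb') as dst:\n",
+ "            shutil.copyfileobj(src, dst)"
+ ],
+ "id": "a1b2c3d4e5f60002",
+ "outputs": [],
+ "execution_count": null
+ },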
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "from torch.utils.data import Dataset, DataLoader, random_split\n",
+ "import h5py\n",
+ "import numpy as np\n",
+ "\n",
+ "class SpikingHeidelbergDigits(Dataset):\n",
+ "    def __init__(self, path_to_file):\n",
+ "        self.num_classes = 2\n",
+ "        self.num_channels = 700\n",
+ "        self.path_to_file = path_to_file\n",
+ "\n",
+ "        # load the dataset\n",
+ "        with h5py.File(path_to_file, 'r') as f:\n",
+ "            self.channels = f['spikes']['units'][:]\n",
+ "            self.timesteps = f['spikes']['times'][:]\n",
+ "            self.labels = f['labels'][:]\n",
+ "\n",
+ "        # filter the dataset to contain only two classes\n",
+ "        mask = (self.labels == 0) | (self.labels == 1)\n",
+ "        self.channels = self.channels[mask]\n",
+ "        self.timesteps = self.timesteps[mask]\n",
+ "        self.labels = self.labels[mask]\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        return len(self.labels)\n",
+ "\n",
+ "    def __getitem__(self, idx):\n",
+ "        # create tonic-like structured arrays\n",
+ "        dtype = np.dtype([(\"t\", int), (\"x\", int), (\"p\", int)])\n",
+ "        struct_arr = np.empty_like(self.channels[idx], dtype=dtype)\n",
+ "\n",
+ "        # convert timesteps from seconds to microseconds (tonic convention)\n",
+ "        timesteps = self.timesteps[idx] * 1e6\n",
+ "\n",
+ "        struct_arr['t'] = timesteps\n",
+ "        struct_arr['x'] = self.channels[idx]\n",
+ "        struct_arr['p'] = 1\n",
+ "\n",
+ "        # one-hot encoding of labels (required for CutMix augmentation)\n",
+ "        label = np.eye(self.num_classes)[self.labels[idx]].astype(np.int32)\n",
+ "\n",
+ "        return struct_arr, label"
+ ],
+ "id": "f9883d23c86e5bcd",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "# Load the training and test dataset\n",
+ "train_dataset = SpikingHeidelbergDigits('data/shd_train.h5')\n",
+ "test_dataset = SpikingHeidelbergDigits('data/shd_test.h5')"
+ ],
+ "id": "3be0429979f96a3f",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "Print the lengths of the datasets to verify that loading was successful.",
+ "id": "cf72529578541b9"
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "print(f\"Number of training samples: {len(train_dataset)}\")\n",
+ "print(f\"Number of test samples: {len(test_dataset)}\")"
+ ],
+ "id": "ec059aefa5d3408",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "Now, create a validation set by randomly splitting the training dataset, and create data loaders for the training, validation, and test datasets.",
+ "id": "ee746256534411df"
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "# Split the training dataset into training and validation sets\n",
+ "train_dataset, val_dataset = random_split(train_dataset, [int(0.8 * len(train_dataset)), len(train_dataset) - int(0.8 * len(train_dataset))])\n",
+ "\n",
+ "# Create data loaders\n",
+ "from event_ssm.dataloading import event_stream_collate_fn\n",
+ "from functools import partial\n",
+ "\n",
+ "collate_fn = partial(event_stream_collate_fn, resolution=(700,), pad_unit=8192)\n",
+ "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True, collate_fn=collate_fn)\n",
+ "val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=True, collate_fn=collate_fn)\n",
+ "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)"
+ ],
+ "id": "1ab24a1d63c4c194",
+ "outputs": [],
+
"execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Model definition\n", + "We use the [hydra](https://hydra.cc/docs/intro/) package for efficient configuration management. Define the model configuration in a config file in the `configs` directory." + ], + "id": "acc88e3270fda10b" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "from hydra import compose, initialize\n", + "from omegaconf import OmegaConf, open_dict\n", + "\n", + "with initialize(version_base=None, config_path=\"configs\", job_name=\"training tutorial\"):\n", + " cfg = compose(config_name=\"base\", overrides=[\"task=tutorial\"])\n", + "\n", + "with open_dict(cfg): \n", + " # optax updates the schedule every iteration and not every epoch\n", + " cfg.optimizer.total_steps = cfg.training.num_epochs * len(train_loader) // cfg.optimizer.accumulation_steps\n", + " cfg.optimizer.warmup_steps = cfg.optimizer.warmup_epochs * len(train_loader) // cfg.optimizer.accumulation_steps\n", + " \n", + " # scale learning rate by batch size\n", + " cfg.optimizer.ssm_lr = cfg.optimizer.ssm_base_lr * cfg.training.per_device_batch_size * cfg.optimizer.accumulation_steps\n", + "\n", + "print(OmegaConf.to_yaml(cfg))" + ], + "id": "810edc0798ad7622", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Now, create the model using the configuration defined above.", + "id": "2ca62d33ebabfdb2" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "from event_ssm.ssm import init_S5SSM\n", + "from event_ssm.seq_model import BatchClassificationModel\n", + "\n", + "ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)\n", + "model = BatchClassificationModel(\n", + " ssm=ssm_init_fn,\n", + " num_classes=test_dataset.num_classes,\n", + " num_embeddings=test_dataset.num_channels,\n", + " **cfg.model.ssm,\n", + ")" + ], + "id": "83e2062ea8b4fe02", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "\n", + "Initialize the training state by feeding a dummy input" + ], + "id": "c855737447f70896" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "import jax\n", + "from event_ssm.train_utils import init_model_state\n", + "\n", + "# pick the first batch from the training loader\n", + "batch = next(iter(train_loader))\n", + "inputs, targets, timesteps, lengths = batch\n", + "\n", + "# initialize the training state\n", + "key = jax.random.PRNGKey(cfg.seed)\n", + "state = init_model_state(key, model, inputs, timesteps, lengths, cfg.optimizer)" + ], + "id": "d4def1c65952a8ba", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Inspect the model\n", + "The model parameters are accessible as part of the training state. \n", + "We will look into the spectrum of the recurrent operator here.\n", + "The model was initialized with a single stage of blocks." 
+ ],
+ "id": "424bce6010abb8f1"
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "def get_spectrum(state):\n",
+ "    params = state.params['encoder']['stages_0']\n",
+ "    lambda_bar = []\n",
+ "    time_scales = []\n",
+ "    for name, sequence_layer in params.items():\n",
+ "        # read lambda parameters\n",
+ "        Lambda_im = sequence_layer['S5SSM_0']['Lambda_im']\n",
+ "        Lambda_re = sequence_layer['S5SSM_0']['Lambda_re']\n",
+ "\n",
+ "        # read and compute delta and Lambda\n",
+ "        delta = np.exp(sequence_layer['S5SSM_0']['log_step'][:, 0])\n",
+ "        Lambda = Lambda_re + 1j * Lambda_im\n",
+ "\n",
+ "        # compute lambda_bar and time scales\n",
+ "        lambda_bar.append(np.exp(Lambda * delta))\n",
+ "        time_scales.append(1 / np.abs(Lambda) / delta)\n",
+ "    return lambda_bar, time_scales\n",
+ "\n",
+ "spectrum, time_scales = get_spectrum(state)"
+ ],
+ "id": "5a1602ed4a962265",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "Plot the spectrum of the recurrent operator and the corresponding time scales upon initialization.",
+ "id": "2d1377b48ebf3728"
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "def plot_spectrum(spectrum):\n",
+ "    fig, axes = plt.subplots(1, len(spectrum), figsize=(len(spectrum) * 4, 4))\n",
+ "    # draw the unit circle\n",
+ "    theta = np.linspace(0, 2 * np.pi, 100)  # 100 points from 0 to 2*pi\n",
+ "    x = np.cos(theta)\n",
+ "    y = np.sin(theta)\n",
+ "\n",
+ "    # plot the spectrum\n",
+ "    for i, (ax, layer) in enumerate(zip(axes, spectrum)):\n",
+ "        ax.plot(x, y, 'r', linewidth=1)\n",
+ "        ax.scatter(np.real(layer), np.imag(layer), marker='o', alpha=0.8)\n",
+ "\n",
+ "        # format axis\n",
+ "        ax.set_title(f'Layer {i}')\n",
+ "        ax.set_aspect('equal', adjustable='box')\n",
+ "        ax.set_xlim(-1.1, 1.1)\n",
+ "        ax.set_ylim(-1.1, 1.1)\n",
+ "\n",
+ "    plt.tight_layout()\n",
+ "    plt.show()\n",
+ "\n",
+ "plot_spectrum(spectrum)"
+ ],
+ "id": "9ff987be826d7314",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "def plot_time_scales(time_scales):\n",
+ "    log_scales = np.log2(np.stack(time_scales).flatten())\n",
+ "    fig, ax = plt.subplots(1, 1, figsize=(6, 6))\n",
+ "    ax.hist(log_scales)\n",
+ "\n",
+ "    # format axis\n",
+ "    max_scale = np.max(np.ceil(log_scales))\n",
+ "    min_scale = np.min(np.floor(log_scales))\n",
+ "    ax.set_xlim((min_scale, max_scale))\n",
+ "    xticks = np.arange(1 + max_scale - min_scale) + min_scale\n",
+ "    ax.set_xticks(xticks, (2 ** xticks).astype(np.int32))\n",
+ "    ax.set_title('Distribution of time scales')\n",
+ "    ax.set_xlabel('Time scale')\n",
+ "    ax.set_ylabel('Count')\n",
+ "    plt.show()\n",
+ "\n",
+ "plot_time_scales(time_scales)"
+ ],
+ "id": "c1779ae6f5f72b44",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "## Train the model\n",
+ "For training, we implemented a trainer module that makes training as easy as possible. It hides the boilerplate training code from the user and provides a simple interface: it loops through the data loader, computes the loss, and updates the model parameters. We only need to provide the `training_step` and `evaluation_step` functions that the loop calls on the model. These are already implemented and can be used here."
+ ],
+ "id": "b4970c69f459df1d"
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "from event_ssm.train_utils import training_step, evaluation_step\n",
+ "from event_ssm.trainer import TrainerModule\n",
+ "\n",
+ "# just-in-time compile the training and evaluation functions\n",
+ "train_step = jax.jit(training_step)\n",
+ "eval_step = jax.jit(evaluation_step)\n",
+ "\n",
+ "# initialize the trainer module\n",
+ "num_devices = 1\n",
+ "trainer = TrainerModule(\n",
+ "    train_state=state,\n",
+ "    training_step_fn=train_step,\n",
+ "    evaluation_step_fn=eval_step,\n",
+ "    world_size=num_devices,\n",
+ "    config=cfg,\n",
+ ")"
+ ],
+ "id": "61ab72052c47f47c",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "We are now ready to start the training loop.\n",
+ "\n",
+ "**Note:** JAX compiles your program just-in-time (JIT) to optimize performance. This means that the first iteration of the training loop will be slower than the following ones."
+ ],
+ "id": "d66a413bc0ac7d2b"
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "# generate random key for dropout\n",
+ "key, dropout_key = jax.random.split(key)\n",
+ "\n",
+ "# train the model\n",
+ "trainer.train_model(\n",
+ "    train_loader=train_loader,\n",
+ "    val_loader=val_loader,\n",
+ "    test_loader=test_loader,\n",
+ "    dropout_key=dropout_key\n",
+ ")"
+ ],
+ "id": "9d5ab8aa623db697",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "## Inspect the trained model\n",
+ "We now have a trained toy model on the SHD dataset.\n",
+ "Let's look into the spectrum of the recurrent operator after training."
+ ],
+ "id": "a929b74f8ce235e5"
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "spectrum, time_scales = get_spectrum(trainer.train_state)\n",
+ "plot_spectrum(spectrum)\n",
+ "plot_time_scales(time_scales)"
+ ],
+ "id": "3281a08743303429",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "## Assignment\n",
+ "The function `apply_ssm` in `event_ssm/ssm.py` implements the recurrent operator with an associative scan. On highly parallel GPUs, this can speed up training on very long sequences.\n",
+ "On CPUs, however, the overhead of the scan operation can slow down training.\n",
+ "Your task is to implement a CPU-friendly version of the recurrent operator in `event_ssm/ssm.py` and compare the training time with the original implementation.\n",
+ "We suggest implementing a step-by-step recurrence with [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) instead of the currently used [`jax.lax.associative_scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.associative_scan.html). A minimal sketch of such a recurrence is given below."
+ ],
+ "id": "28ad17b8c7b61230"
+ },
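+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "A minimal sketch of the kind of step-by-step recurrence we have in mind. It is not a drop-in replacement for `apply_ssm` (whose exact signature may differ); the names `lambda_bar` and `b_u` are assumptions for the per-step discretized diagonal operator and the projected inputs."
+ ],
+ "id": "c0ffee0123456701"
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "def apply_ssm_sequential(lambda_bar, b_u):\n",
+ "    # lambda_bar, b_u: shape (L, P); recurrence x_k = lambda_bar_k * x_{k-1} + b_u_k\n",
+ "    def step(x, inputs):\n",
+ "        lam_k, bu_k = inputs\n",
+ "        x = lam_k * x + bu_k  # diagonal state update\n",
+ "        return x, x  # carry the new state and emit it\n",
+ "    x0 = jnp.zeros(b_u.shape[-1], dtype=b_u.dtype)\n",
+ "    _, xs = jax.lax.scan(step, x0, (lambda_bar, b_u))\n",
+ "    return xs"
+ ],
+ "id": "c0ffee0123456702",
+ "outputs": [],
+ "execution_count": null
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}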