Code for the "Interpreting convolutional neural networks to study wide-field amacrine cell inhibition in the retina" paper
- Clone the repository
- Create a conda environment
    conda create -n wac_env python=3.10
    conda activate wac_env
- Install the requirements
    pip install -r requirements.txt
The original data is available at G-Node. We used the natural movie part of the dataset to train the CNN models.
The data is expected to be in the following structure:
<your path>/data/marmoset_data/
    fixations/
        fixations_2022_2023.txt
    stas/
        cell_data_01_WN_stas_cell_0.npy
        cell_data_01_WN_stas_cell_1.npy
        ...
        cell_data_01_WN_stas_cell_369.npy
        cell_data_04_WN_stas_cell_0.npy
        cell_data_04_WN_stas_cell_1.npy
        ...
    responses/
        cell_responses_01_fixation_movie.pkl
        cell_responses_04_fixation_movie.pkl
        config_r14_fixed.yaml
    stimuli/
        images/
            00000_img_10088.npy
            00001_img_10089.npy
            ...
We use data from both retinas. Retina 1 (20220412_SN_252MEA6010_le_s4) has the key '01' and Retina 2 (20220426_SS_252MEA6010_le_n3) has the key '04' in this structure.
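The naming conventions in the layout above can be captured in a small path helper. This is a minimal sketch and not part of the repository: the function name expected_paths, its arguments, and the assumption that both AAAAA and BBBBB are zero-padded to five digits (as in the listed examples) are illustrative only.

```python
from pathlib import Path


def expected_paths(data_root, retina_key, cell_id, frame_idx, movie_idx):
    """Build the file paths that the layout above expects for one retina.

    Hypothetical helper: data_root stands for <your path>, retina_key is
    '01' or '04', and image indices are assumed to be zero-padded to 5 digits.
    """
    base = Path(data_root) / "data" / "marmoset_data"
    return {
        "fixations": base / "fixations" / "fixations_2022_2023.txt",
        "sta": base / "stas" / f"cell_data_{retina_key}_WN_stas_cell_{cell_id}.npy",
        "responses": base / "responses" / f"cell_responses_{retina_key}_fixation_movie.pkl",
        "config": base / "responses" / "config_r14_fixed.yaml",
        "image": base / "stimuli" / "images" / f"{frame_idx:05d}_img_{movie_idx:05d}.npy",
    }


paths = expected_paths("/tmp", "01", cell_id=0, frame_idx=0, movie_idx=10088)
print(paths["sta"].name)    # cell_data_01_WN_stas_cell_0.npy
print(paths["image"].name)  # 00000_img_10088.npy
```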
- Fixations
  - The fixations have the same structure as the fixation files provided in the stimulus folder
  - The fixations from the two seeds (2022 and 2023) must be appended into a single file
  - Because we train on a subsampled version of the data, the fixation centers need to be recomputed with the change_fixation_file_based_on_resizing_ratio function in datasets/natural_stimuli/create_dataset.py
  - The first 5100 frames of the fixation files are the test set
  - The fixation file used to train the models in this paper, with the stimulus subsampled by a factor of 4, is in configs as fixations_2022_2023.txt
- STAs
  - The spike-triggered averages (STAs) are computed from the responses to the white noise stimulus in the same data repository
  - They are available in the stas folder of the white noise part of the dataset
  - To use them within this repository, store the STAs as numpy arrays, one array per cell, named cell_data_<dataset_id>_WN_stas_cell_<cell_id>.npy
- Responses
  - The responses are stored in pickle files, one file per retina
    - Each file is expected to hold a dictionary with the keys train_responses, test_responses, and seed
      - train_responses and test_responses are numpy arrays of shape (n_cells, n_frames, n_trials), where n_cells is the number of cells, n_frames is the number of frames in the stimulus, and n_trials is the number of trials
  - The responses directory also contains a config file, which is used to configure the paths and cells
    - It contains precomputed reliability values for the cells, which are used to filter out cells with low reliability
    - The originally used config file is config_r14_fixed.yaml, which is available in the repository
- Stimuli
  - The stimuli are stored in the stimuli/ folder
    - We used a resized version of the original images at 1/4 of the original size: 200x150 pixels
    - The original images were padded with zeros during training, so the resized images are also expected to be padded with 50 pixels of zeros on each side, for a total size of 300x250 pixels
    - To resize the images we used the save_subsampled_dataset function in datasets/natural_stimuli/create_dataset.py
    - The images are expected to be numpy arrays, one per frame, named AAAAA_img_BBBBB.npy, where AAAAA is the index of the frame in the stimulus and BBBBB is the index of the frame within the original movie
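The format expectations listed above can be checked before training with a short sketch like the one below. It is hypothetical and not part of the repository: the helper names are made up, the fixation split assumes one row per frame, and block averaging is only an assumed stand-in for whatever resizing save_subsampled_dataset actually performs. The toy arrays only illustrate the shapes.

```python
import pickle

import numpy as np

N_TEST_FRAMES = 5100  # the first 5100 fixation frames are the test set
PAD = 50              # zero padding (pixels) added on each side of a resized image


def check_responses(resp):
    """Check a responses dict for the expected keys and array ranks."""
    for key in ("train_responses", "test_responses", "seed"):
        if key not in resp:
            raise KeyError(f"missing key {key!r}")
    for key in ("train_responses", "test_responses"):
        # expected shape: (n_cells, n_frames, n_trials)
        if np.ndim(resp[key]) != 3:
            raise ValueError(f"{key} should have shape (n_cells, n_frames, n_trials)")
    if resp["train_responses"].shape[0] != resp["test_responses"].shape[0]:
        raise ValueError("train and test responses should cover the same cells")


def split_fixations(rows):
    """Split fixation rows into (test, train), assuming one row per frame."""
    return rows[:N_TEST_FRAMES], rows[N_TEST_FRAMES:]


def downsample_and_pad(img, factor=4, pad=PAD):
    """Block-average img by factor, then zero-pad pad pixels on each side.

    Block averaging is an assumption; save_subsampled_dataset may resize
    differently.
    """
    h, w = img.shape[0] // factor, img.shape[1] // factor
    small = img[: h * factor, : w * factor].reshape(h, factor, w, factor).mean(axis=(1, 3))
    return np.pad(small, pad, mode="constant", constant_values=0)


# Toy data only; these shapes are made up for illustration.
rng = np.random.default_rng(0)
toy = {
    "train_responses": rng.poisson(1.0, size=(3, 100, 1)).astype(float),
    "test_responses": rng.poisson(1.0, size=(3, 20, 10)).astype(float),
    "seed": 0,
}
check_responses(pickle.loads(pickle.dumps(toy)))  # round-trip through pickle
frame = downsample_and_pad(rng.random((600, 800)))
print(frame.shape)  # (250, 300): a 150x200 image plus 50 zero pixels per side
```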
To train a model, use the run_multi_retinal_marmoset_cnn.py script. The model parameters can be passed as command-line arguments.
The configs/model_configs/ folder contains configurations for the models CNN 3, CNN 4, CNN 3B, and CNN 4B.
We use wandb to log experiments, so you need a wandb account and must be logged in to use it.
The training and validation losses, as well as the correlation between the model predictions and the responses, are logged to wandb automatically.
To get the test performance of a model, use the evaluations/multiretinal_model_evaluation.py script.
It takes the name of a config file that defines the model to test (or the models, if you want the test performance of an ensemble).
The repository contains an example of such a config, test_cnn.json.
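As a rough illustration of such a metric, the sketch below computes a per-cell Pearson correlation between predictions and trial-averaged test responses, and forms an ensemble by averaging model predictions. Both choices are assumptions made for illustration; check multiretinal_model_evaluation.py for the exact metric and ensembling it uses. The data are random toy arrays.

```python
import numpy as np


def per_cell_correlation(pred, target):
    """Pearson correlation per cell; pred and target have shape (n_cells, n_frames)."""
    pred_c = pred - pred.mean(axis=1, keepdims=True)
    target_c = target - target.mean(axis=1, keepdims=True)
    num = np.sum(pred_c * target_c, axis=1)
    den = np.sqrt(np.sum(pred_c ** 2, axis=1) * np.sum(target_c ** 2, axis=1))
    return num / den


def ensemble_prediction(preds):
    """Average a list of (n_cells, n_frames) predictions from several models."""
    return np.mean(np.stack(preds), axis=0)


# Toy test responses with shape (n_cells, n_frames, n_trials).
rng = np.random.default_rng(0)
test_responses = rng.poisson(2.0, size=(5, 200, 10)).astype(float)
target = test_responses.mean(axis=2)  # trial-averaged responses
preds = [target + rng.normal(0.0, 0.5, size=target.shape) for _ in range(3)]
corr = per_cell_correlation(ensemble_prediction(preds), target)
print(corr.round(2))
```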
To generate MEIs, run the find_mei.py script.
The options are passed as command line arguments.
The settings used to run the MEI optimization in the paper are in configs/mei_configs/mei_config.yaml.
Again, pass the model config with the information about the trained model(s).
An example of such a config is test_cnn.json in configs.
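MEIs are found by optimizing the input image to maximize a model neuron's response. The sketch below shows the idea on a toy linear-nonlinear neuron with an analytic gradient, not on the trained CNN, using projected gradient ascent under a fixed image norm. The toy receptive field, the norm constraint, and all hyperparameters are assumptions for illustration; the settings actually used are in mei_config.yaml.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def response(x, w):
    """Toy linear-nonlinear neuron: softplus(w . x)."""
    return np.log1p(np.exp(np.sum(w * x)))


def find_mei_toy(w, norm=10.0, lr=1.0, steps=500, seed=0):
    """Projected gradient ascent on the input image of the toy neuron.

    After each step the image is rescaled to ||x|| = norm, so the response
    cannot grow simply by increasing contrast.
    """
    rng = np.random.default_rng(seed)
    x = rng.normal(size=w.shape)
    x *= norm / np.linalg.norm(x)
    for _ in range(steps):
        grad = sigmoid(np.sum(w * x)) * w  # d softplus(w . x) / dx
        x = x + lr * grad
        x *= norm / np.linalg.norm(x)      # project back onto the norm sphere
    return x


# Toy center-surround receptive field on a 15x15 grid (unit norm).
yy, xx = np.mgrid[-7:8, -7:8]
r2 = xx ** 2 + yy ** 2
w = np.exp(-r2 / 4.0) - 0.5 * np.exp(-r2 / 16.0)
w /= np.linalg.norm(w)

mei = find_mei_toy(w)
cos = np.sum(mei * w) / (np.linalg.norm(mei) * np.linalg.norm(w))
print(f"cosine(MEI, w) = {cos:.3f}, response = {response(mei, w):.2f}")
```

Because the toy response increases monotonically with w · x, the best image of fixed norm is norm · w / ||w||, so a cosine similarity close to 1 confirms that the optimization found it.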
Suppressive surrounds can only be generated after the MEIs.
To generate suppressive surrounds, run the find_suppressive_mei.py script.
The options are passed as command-line arguments.
The settings used to run the suppressive surround optimization in the paper are in configs/mei_configs/smei_config.yaml.
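The suppressive surround search keeps the MEI fixed and optimizes only the pixels around it to lower the response. The toy sketch below illustrates this with a linear-nonlinear neuron (excitatory center, weak uniform inhibitory surround) rather than the trained CNN. The masks, the surround norm, the optimizer, and the use of the center part of w as a stand-in MEI are all assumptions, not the settings in smei_config.yaml.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def response(x, w):
    """Toy linear-nonlinear neuron: softplus(w . x)."""
    return np.log1p(np.exp(np.sum(w * x)))


def find_suppressive_surround_toy(w, mei, surround_mask, norm=2.0, lr=1.0, steps=300, seed=0):
    """Keep mei fixed and descend on the surround pixels to lower the response.

    The surround image is kept on the sphere ||s|| = norm and stays zero
    wherever surround_mask is zero.
    """
    rng = np.random.default_rng(seed)
    s = rng.normal(size=w.shape) * surround_mask
    s *= norm / np.linalg.norm(s)
    for _ in range(steps):
        drive = np.sum(w * (mei + s))
        grad = sigmoid(drive) * w * surround_mask  # gradient w.r.t. surround pixels only
        s = s - lr * grad                          # step downhill on the response
        s *= norm / np.linalg.norm(s)              # project back onto the norm sphere
    return s


# Toy 15x15 receptive field: excitatory center, weak uniform inhibitory surround.
yy, xx = np.mgrid[-7:8, -7:8]
r2 = xx ** 2 + yy ** 2
center = (r2 <= 9).astype(float)
surround = 1.0 - center
w = np.where(r2 <= 9, np.exp(-r2 / 8.0), -0.05)

# Stand-in MEI: the center part of w, scaled to unit norm (zero in the surround).
mei = center * w / np.linalg.norm(center * w)
s = find_suppressive_surround_toy(w, mei, surround)
print(f"MEI alone: {response(mei, w):.2f}, MEI + surround: {response(mei + s, w):.2f}")
```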
The code for evaluating self- and cross-suppression is in the notebooks folder.
After generating the MEIs and suppressive surrounds, you can save them into dictionaries, which are used by the different functions in the Jupyter notebooks.
To get a dictionary of MEIs use the cell_e_mei.ipynb notebook.
It loads the models based on the model config(s) you provide (the example here is test_cnn.json) and requires setting the other MEI generation parameters you used.
It then saves these dictionaries into the MEI_DIR as pickle files from where they can be loaded and used for the plotting functions in mei_plots.ipynb.
To get a dictionary of suppressive surrounds, use the save_smei_dict.ipynb notebook.
It loads the models based on the model config(s) you provide (the example here is test_cnn.json) and requires setting the other MEI and suppressive surround generation parameters.
It then saves these dictionaries into SMEI_DIR, from where they can be loaded and used for the plotting functions in plot_cross_fr_suppression.ipynb.
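The saved dictionaries can also be loaded outside the notebooks with Python's pickle module. The sketch below assumes each dictionary maps a key to a numpy image; the function names, the (retina key, cell index) keys, the file name, and the image shape are hypothetical, and the notebooks may use a different structure.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np


def save_image_dict(images, directory, name):
    """Pickle a {key: image} dictionary to <directory>/<name>.pkl."""
    path = Path(directory) / f"{name}.pkl"
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(images, f)
    return path


def load_image_dict(path):
    """Load a dictionary saved by save_image_dict."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Hypothetical keys (retina key, cell index) and random placeholder images.
rng = np.random.default_rng(0)
meis = {("01", c): rng.normal(size=(250, 300)) for c in range(3)}
mei_dir = tempfile.mkdtemp()  # stands in for MEI_DIR
path = save_image_dict(meis, mei_dir, "meis_test_cnn")
loaded = load_image_dict(path)
print(sorted(loaded))
```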