Skip to content
Open

Rtt #29

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
934 changes: 934 additions & 0 deletions notebooks/draw_raster.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions script/finetune_eval_multi_session.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ conda activate ibl-fm

cd ../

python src/finetune_eval_multi_session.py --mask_ratio 0.3 \
python src/finetune_eval_multi_session.py --mask_ratio 0.2 \
--mask_mode $MASK_MODE \
--model_name $MODEL_NAME \
--prompting $PROMPTING \
Expand All @@ -55,7 +55,8 @@ python src/finetune_eval_multi_session.py --mask_ratio 0.3 \
--base_path $SCRATCH/IBL_foundation_model \
--num_train_sessions $NUM_TRAIN_SESSIONS \
--test_eid $TEST_EID \
--use_dummy
--use_nlb \
--seed 42

cd script

Expand Down
130 changes: 130 additions & 0 deletions src/configs/finetune_nlb_trainer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
seed: 42

savestring: test
wandb_project: multi-session
log_to_wandb: false

verbosity: 0

# wandb configuration
wandb:
use: true
entity: null
project: multi-session
run_name: 671c7ea7

# Logging directories
dirs:
checkpoint_dir: checkpoints # save model state dicts (todo optimizer states)
log_dir: results # save tensorboard logs
dataset_cache_dir: checkpoints/datasets_cache # save dataset cache
# dataset_dir: /home/ppwang/neural-data-transformers/data/lfads_lorenz.h5
# pretrained_model_path: checkpoints/models/ndt1/ssl/temporal/model_best_multi-session.pt
dataset_dir: ibl-foundation-model/671c7ea7-6726-4fbe-adeb-f89c2c8e489b
behav_dir: 671c7ea7-6726-4fbe-adeb-f89c2c8e489b
huggingface_org: ibl-foundation-model
pretrain_model_path: /scratch/yl6624/IBL_foundation_model/results_/train/num_session_34/model_NDT1/method_ssl/mask_all/stitch_True/model_best.pt



# Training configuration
training:
num_epochs: 1000
train_batch_size: 16
test_batch_size: 16
shuffle_test_dataloader: false # Shuffle test dataloader between epochs

save_plot_every_n_epochs: 10 # Plot the model output every n epochs
save_every: 50 # Save checkpoint
eval_every: null # Eval model

dummy: true # Use dummy data


# Model configuration.
# Will be passed to the model __init__ method if a model is not passed to the Trainer __init__ method.
model:
model_class: null # Any registered model class name.

# Data configuration.
data:
# dataset_name: lorenz # Any registered dataset name.
dataset_name: ibl # Any registered dataset name.
dataset_class: ssl # Any registered dataset class name.

# Load raw dataset if a dataset is not passed to the Trainer __init__ method.
hf_dataset_name: null # from huggingface
json_dataset_name: null # from json file

train_name: train # name of the train split in the raw datasete
test_name: test # name of the test split in the raw datasete
train_len: null # used length of the train dataset. null to use all
test_len: null # used length of the test dataset. null to use all

LOG_EPSILON: 1.e-7 # epsilon for log transformation, to prevent log(0)
use_lograte: True # use lograte

max_time_length: 30 # max_time_length has to be a multiple of time patch size
max_space_length: 668 # max_space_length has to be a multiple of space patch size
patching: true # patching the neurons
sort_by_depth: false
brain_region: all

include_behav: false # include behavior data
target: null

load_meta: true

num_sessions: 40
train_session_eid: ['72cb5550-43b4-4ef0-add5-e4adfdfb5e02', '51e53aff-1d5d-4182-a684-aba783d50ae5', 'd57df551-6dcb-4242-9c72-b806cff5613a', 'e2b845a1-e313-4a08-bc61-a5f662ed295e', 'd2832a38-27f6-452d-91d6-af72d794136c']
test_session_eid: []

split_method: predefined # random_split/session_based/predefined

use_aligned_test: False

sort_by_depth: false
sort_by_region: false
brain_region: all

spike_augmentation: false

use_re: true

# Method configuration. Contains kwargs that are specific to the training method.
method:

# Passed to the model __init__ method together with the model config
model_kwargs:
method_name: ssl #ssl

use_lograte: true
loss: poisson_nll # poisson_nll # mse/other distirbutions (todo)
output_size: 2
clf: false
reg: false

# Passed to the Dataset __init__ method together with the raw dataset.
dataset_kwargs: {}

# Passed to the DataLoader __init__ method.
dataloader_kwargs:
# Contains which keys to pad, along which dimension with which value
pad_dict:
spikes:
dim: 0
side: right
value: 0
truncate: null
min_length: null


optimizer:
gradient_accumulation_steps: 1
lr: 1.e-4
wd: 0.01
eps: 1.e-8
warmup_pct: 0.15 # cosine/linear
gamma: 0.95 # step
div_factor: 10 # cosine
scheduler: cosine # step/cosine/linear
4 changes: 2 additions & 2 deletions src/configs/finetune_sessions_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ verbosity: 0

# wandb configuration
wandb:
use: true
use: false
entity: null
project: multi-session
run_name: 671c7ea7
Expand Down Expand Up @@ -63,7 +63,7 @@ data:
LOG_EPSILON: 1.e-7 # epsilon for log transformation, to prevent log(0)
use_lograte: True # use lograte

max_time_length: 100 # max_time_length has to be a multiple of time patch size
max_time_length: 30 # max_time_length has to be a multiple of time patch size
max_space_length: 668 # max_space_length has to be a multiple of space patch size
patching: true # patching the neurons
sort_by_depth: false
Expand Down
2 changes: 1 addition & 1 deletion src/configs/ndt1_stitching.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ encoder:
masker:
force_active: true
mode: all # masking mode, e.g. all, temporal
ratio: 0.3 # ratio of data to predict
ratio: 0.2 # ratio of data to predict
zero_ratio: 1.0 # of the data to predict, ratio of zeroed out
random_ratio: 1.0 # of the not zeroed, ratio of randomly replaced
expand_prob: 0.0 # probability of expanding the mask in ``temporal`` mode
Expand Down
2 changes: 1 addition & 1 deletion src/configs/ndt1_stitching_prompting.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ encoder:
masker:
force_active: true
mode: all # masking mode
ratio: 0.3 # ratio of data to predict
ratio: 0.2 # ratio of data to predict
zero_ratio: 1.0 # of the data to predict, ratio of zeroed out
random_ratio: 1.0 # of the not zeroed, ratio of randomly replaced
expand_prob: 0.0 # probability of expanding the mask in ``temporal`` mode
Expand Down
108 changes: 108 additions & 0 deletions src/configs/ndt1_stitching_prompting_temporal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
model_class: NDT1
# NDT1 stitching mask scheme all: all masking/stitching

encoder:

from_pt: null

stitching: true

# Mask spikes
masker:
force_active: true
mode: temporal # masking mode
ratio: 0.2 # ratio of data to predict
zero_ratio: 1.0 # of the data to predict, ratio of zeroed out
random_ratio: 1.0 # of the not zeroed, ratio of randomly replaced
expand_prob: 0.0 # probability of expanding the mask in ``temporal`` mode
max_timespan: 1 # max span of mask if expanded
channels: null
# [579, 21, 486, 316, 521, 71, 32, 561, 445, 470, 205, 41, 230, 306, 592, 148, 484, 592, 181, 94, 120, 601, 414, 497, 263, 447, 16, 537, 396, 323, 68, 444, 325, 137, 223, 434, 131, 363, 280, 419, 332, 218, 130, 580, 381, 560, 576, 145, 280, 234, 258, 37, 9, 324, 199, 610, 529, 196, 164, 348] # neurons to mask in ``co-smooth`` mode
timesteps: null # time steps to mask in ``forward-pred`` mode
mask_regions: ['all'] # brain regions to mask in ``inter-region`` mode
target_regions: ['all'] # brain regions to predict in ``intra-region`` mode
n_mask_regions: 1 # number of regions to choose from the list of mask_regions or target_regions

# Context available for each timestep
context:
forward: -1
backward: -1

# Normalize and add noise
norm_and_noise:
active: false
smooth_sd: 2 # gaussian smoohing
norm: "zscore" # which normalization layer to use (null/layernorm/scalenorm/zscore)
eps: 1.e-7 # avoid dividing by zero when normalizing padded spikes
white_noise_sd: 1.0 # gaussian noise added to the inputs 1.0 originally
constant_offset_sd: 0.2 # gaussian noise added to the inputs but contsnat in the time dimension 0.2 originally


# Embedding layer
embedder:
n_channels: 668 # number of neurons recorded
n_blocks: 24 # number of blocks of experiments
n_dates: 24 # number of days of experiments
max_F: 100 # max feature len in timesteps

mode: linear # linear/embed/identity
mult: 2 # embedding multiplier. hiddden_sizd = n_channels * mult
adapt: false # adapt the embedding layer for each day
pos: true # embed position
act: softsign # activation for the embedding layers
scale: 1 # scale the embedding multiplying by this number
bias: true # use bias in the embedding layer
dropout: 0.2 # dropout in embedding layer

fixup_init: false # modify weight initialization
init_range: 0.1 # initialization range for embeddings
spike_log_init: false # special initialization
max_spikes: 0 # max number of spikes in a single time bin

tokenize_binary_mask: false
use_prompt: true
use_session: true

stack:
active: false # wether to stack consecutive timesteps
size: 32 # number of consecutive timesteps to stack
stride: 4 # stacking stride


# Transformer
transformer:
n_layers: 5 # number of transformer layers
hidden_size: 512 # hidden space of the transformer
use_scalenorm: false # use scalenorm instead of layernorm
use_rope: false # use rotary postional encoding
rope_theta: 10000.0 # rope angle of rotation


n_heads: 8 # number of attentiomn heads
attention_bias: true # learn bias in the attention layers

act: gelu # activiation function in mlp layers
inter_size: 1024 # intermediate dimension in the mlp layers
mlp_bias: true # learn bias in the mlp layers

dropout: 0.4 # dropout in transformer layers
fixup_init: true # modify weight initialization

# Projection to factor space
factors:
active: false # project from hidden_size to factors
size: 8 # factors size
act: relu # activation function after projecting to factors
bias: true # use bias in projection to factors
dropout: 0.0 # dropout in projection to factors
fixup_init: false # modify weight initialization
init_range: 0.1 # initialization range for factors projetion

decoder:
from_pt: null






Loading