-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcontinue_training.py
More file actions
71 lines (62 loc) · 2.71 KB
/
Copy pathcontinue_training.py
File metadata and controls
71 lines (62 loc) · 2.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import yaml
from training.trainers import continue_training_cnn
from utils.global_functions import DATA_DIR
def _build_parser():
    """Build the command-line parser for resuming training of a saved CNN model.

    The integer flags ``--continue_optim`` and ``--flip_imgs`` are used as
    booleans by the script's entry point (it compares them against 1).
    Flags are registered in a fixed order, which is also the ``--help`` order.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        # Identity of the trained model to resume.
        ('--model_file', dict(default="2025-04-17_earnest-sea-170", type=str)),
        ('--model_seed', dict(default=16, type=int)),
        ('--data_dir', dict(default=DATA_DIR, type=str)),
        ('--setting', dict(default='nm', type=str)),
        # Name (without the .yaml suffix) of the training config to load.
        ('--config_file', dict(default='config_s4_fixed_r14', type=str)),
        ('--directory_prefix', dict(default='cnn_nm_nina')),
        ('--data_type', dict(default='marmoset', type=str)),
        ('--retina_index', dict(default='01', type=str)),
        # 1 -> resume the saved optimizer state; anything else -> do not.
        ('--continue_optim', dict(default=0, type=int)),
        ('--lr', dict(default=0.0001, type=float)),
        ('--save_every_epoch', dict(default=20, type=int)),
        # 1 -> enable image flipping; anything else -> disable.
        ('--flip_imgs', dict(default=0, type=int)),
        # Checkpoint file name to resume from.
        ('--model_name', dict(default='model_e_5.m')),
        ('--cluster_number', dict(default=1, type=int)),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli


parser = _build_parser()
if __name__ == '__main__':
args = parser.parse_args()
model_file = args.model_file
model_seed = args.model_seed
data_dir = args.data_dir
setting = args.setting
data_type = args.data_type
retina_index = args.retina_index
with open(f'{data_dir}/data/{data_type}_data/responses/{args.config_file}.yaml', 'rb') as file:
config_dict = yaml.unsafe_load(file)
directory = f"{data_dir}/wandb/models/{args.directory_prefix}_ev_0.15_cnn/{data_type}/retina{retina_index}/cell_None/readout_isotropic/gmp_0/"
# training parameters
update_config= {'optimizer_config': {'lr': args.lr}}
continue_training_cnn(
directory,
model_file,
model_seed,
data_dir,
setting,
config_dict,
dataloader_config=None,
fancy_nonlin=None,
freeze=False,
stimulus_seed=0,
fixation_file='None',
time_chunk_size=None,
batch_size=16,
data_type="marmoset",
model_fn="models.MultiRetinalFactorizedEncoder.build_trained",
device=f"cuda",
directory_prefix=args.directory_prefix,
stopper_patience=20,
multiretinal=True,
change_dict=update_config,
wandb_log=True,
continue_optimizer=args.continue_optim==1,
save_every_epoch=args.save_every_epoch,
config_file=args.config_file,
flip_imgs=args.flip_imgs==1,
model_name=args.model_name, #give the checkpoint name here
loss='clustering',
cluster_number=args.cluster_number,
)