forked from RyanWangZf/PAC-Bayes-IB
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_vgg_PIB.py
79 lines (66 loc) · 2.2 KB
/
run_vgg_PIB.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
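'''Train a VGG model with the PAC-Bayes information bottleneck (PIB).

Pipeline: fit (or load from a checkpoint) prior weights on the validation split,
store a copy of them on the model as ``w0_dict``, call ``model._initialize_weights()``,
train the model with ``train_pib`` on the training split, and print the test accuracy.
'''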
import numpy as np
import torch
import os
from src.dataset import load_data
from src.utils import img_preprocess, setup_seed, predict, eval_metric, feature_map_size
from src.utils import train
from src.models import VGG
from src.pib_utils import train_pib
# Dataset to run on; checkpoints are kept in one sub-directory per dataset.
__data_set__ = 'cifar10'
__prior_ckpt__ = f'./checkpoints/{__data_set__}/vgg_prior.pt'  # prior weights (fit on validation split)
__save_ckpt__ = f'./checkpoints/{__data_set__}/vgg_pib.pt'  # passed to train_pib as early_stop_ckpt_path

# Hyper-parameters for PIB training; each entry is forwarded to train_pib below.
opt = dict(
    num_epoch=100,
    batch_size=32,
    lr=1e-4,
    weight_decay=0,
    beta=1e-1,  # presumably the weight of the information-bottleneck term -- confirm in train_pib
    noise_scale=1e-10,  # forwarded as noise_scale; semantics defined in src.pib_utils
    schedule=[50, 80],  # presumably epoch milestones for an LR schedule -- confirm in train_pib
    early_stop=10,  # forwarded as early_stop_tolerance
)
if not os.path.exists('./checkpoints/{}'.format(__data_set__)):
os.makedirs('./checkpoints/{}'.format(__data_set__))
# Fix the random seed (setup_seed from src.utils) so runs are repeatable.
setup_seed(2020)

# Load the raw train / validation / test splits.
x_tr, y_tr, x_va, y_va, x_te, y_te = load_data(__data_set__)
all_tr_idx = np.arange(len(x_tr))  # train_pib below is given every training index

# Count classes over the training AND validation labels (the two splits the
# model is fit on) rather than the validation labels alone, so a class that
# happens to be missing from the smaller validation split cannot shrink the
# classifier head. ravel() tolerates (N,) vs (N, 1) label arrays.
num_class = np.unique(np.concatenate((np.ravel(y_tr), np.ravel(y_va)))).shape[0]

# Preprocess each split (img_preprocess from src.utils).
x_tr, y_tr = img_preprocess(x_tr, y_tr)
x_va, y_va = img_preprocess(x_va, y_va)
x_te, y_te = img_preprocess(x_te, y_te)
# Build the VGG classifier for this dataset, with dropout disabled, and move it
# to the GPU (the prior snapshot below relies on parameters living there).
last_map_size = feature_map_size(__data_set__)
model = VGG(num_classes=num_class, dropout_rate=0.0, last_feature_map_size=last_map_size)
model.cuda()
# Obtain the prior: fit the model on the validation split unless a cached
# prior checkpoint already exists, in which case load it.
if not os.path.exists(__prior_ckpt__):
    # The validation split is used both as the fitting data and as the
    # monitoring data here. The trailing positional arguments presumably are
    # (num_epoch=10, batch_size=32, lr=5e-5, weight_decay=0, ckpt_path,
    # early_stop_tolerance=5) -- TODO confirm against src.utils.train.
    train(model, np.arange(len(y_va)), x_va, y_va, x_va, y_va, 10, 32, 5e-5, 0, __prior_ckpt__, 5)
else:
    print("load prior.")
    # assumes the checkpoint holds a state_dict for this VGG -- verify against train()
    model.load_state_dict(torch.load(__prior_ckpt__))

# Snapshot every named parameter as the prior weights w0. clone() copies the
# data and detach() drops it from the autograd graph; the copies stay on the
# same device as the parameters (the GPU).
w0_dict = {name: tensor.clone().detach() for name, tensor in model.named_parameters()}
model.w0_dict = w0_dict

# Re-initialize the network weights (VGG._initialize_weights in src.models)
# before PIB training; the prior survives in model.w0_dict.
model._initialize_weights()
print("done get prior weights")
# Train with the PAC-Bayes information-bottleneck objective on the full
# training split; the validation split and checkpoint path are handed to
# train_pib, presumably for early stopping.
pib_kwargs = dict(
    num_epoch=opt['num_epoch'],
    batch_size=opt['batch_size'],
    lr=opt['lr'],
    weight_decay=opt['weight_decay'],
    beta=opt['beta'],
    early_stop_ckpt_path=__save_ckpt__,
    early_stop_tolerance=opt['early_stop'],
    noise_scale=opt['noise_scale'],
    schedule=opt['schedule'],
)
info_dict, loss_acc_dict = train_pib(model, all_tr_idx, x_tr, y_tr, x_va, y_va, **pib_kwargs)
# Report accuracy on the held-out test split.
# NOTE(review): this scores the in-memory weights left by train_pib; if
# train_pib does not restore the best checkpoint from __save_ckpt__, this is
# last-epoch accuracy -- confirm in src.pib_utils.
acc_te = eval_metric(predict(model, x_te), y_te, num_class)
print("test acc:", acc_te)