
Commit d654fc7

committed
Added the basic training code for attention models
1 parent 51f8190 commit d654fc7

File tree

3 files changed: +206 -16 lines changed

Visual_Attention/attention_models.py

-3

@@ -22,10 +22,7 @@ def forward(self, v, q):
     def logits(self, v, q):
         num_objs = v.size(1)
         q = q.unsqueeze(1).repeat(1, num_objs, 1)
-        print(q.size())
-        print(v.size())
         vq = torch.cat((v, q), 2)
-        print(vq.size())
         joint_repr = self.nonlinear(vq)
         logits = self.linear(joint_repr)
         return logits
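
The three deleted print calls were ad-hoc shape checks. For reference, a minimal standalone sketch of the same tile-and-concatenate step, with assumed toy dimensions (batch 2, 36 objects, 2048-d image features, 1024-d question embedding):

import torch

v = torch.randn(2, 36, 2048)                       # image features: [batch, num_objs, v_dim]
q = torch.randn(2, 1024)                           # question embedding: [batch, q_dim]
q_tiled = q.unsqueeze(1).repeat(1, v.size(1), 1)   # -> [2, 36, 1024], one copy per object
vq = torch.cat((v, q_tiled), 2)                    # -> [2, 36, 3072], joint vector per object
print(vq.size())                                   # torch.Size([2, 36, 3072])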

Visual_Attention/train_models.py

+106 -13
@@ -23,6 +23,7 @@
 from vqa_dataset_attention import *
 import torch.nn as nn
 import random
+import utils

 def instance_bce_with_logits(logits, labels):
     assert logits.dim() == 2
@@ -38,6 +39,56 @@ def compute_score_with_logits(logits, labels):
     scores = (one_hots * labels)
     return scores

+def evaluate_model(model, valid_dataloader, device):
+    score = 0
+    upper_bound = 0
+    num_data = 0
+    V_loss = 0
+    print('Validation started')
+    for data in tqdm(valid_dataloader):
+        feat, quest, label, target = data
+        feat = feat.to(device)
+        quest = quest.to(device)
+        target = target.to(device)  # true labels
+
+        pred = model(feat, quest, target)
+        loss = instance_bce_with_logits(pred, target)
+        V_loss += loss.item() * feat.size(0)
+        batch_score = compute_score_with_logits(pred, target.data).sum()
+        score += batch_score
+        upper_bound += (target.max(1)[0]).sum()
+        num_data += pred.size(0)
+
+    score = score / len(valid_dataloader.dataset)
+    V_loss /= len(valid_dataloader.dataset)
+    upper_bound = upper_bound / len(valid_dataloader.dataset)
+    print(score, V_loss)
+    return score, upper_bound, V_loss
+
+def single_batch_run(model, train_dataloader, valid_dataloader, device, output_folder, optim):
+    feat_train, quest_train, label_train, target_train = next(iter(train_dataloader))
+    feat_train = feat_train.to(device)
+    quest_train = quest_train.to(device)
+    target_train = target_train.to(device)  # true labels
+    pred = model(feat_train, quest_train, target_train)
+    loss = instance_bce_with_logits(pred, target_train)
+    logger = utils.Logger(os.path.join(output_folder, 'log_single_batch.txt'))
+    loss.backward()
+    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
+    optim.step()
+    optim.zero_grad()
+    batch_score = compute_score_with_logits(pred, target_train.data).sum()
+    model.train(False)
+    eval_score, bound, V_loss = evaluate_model(model, valid_dataloader, device)
+    model.train(True)
+    logger.write('\teval loss: %.3f, score: %.3f (%.3f)' % (V_loss, 100 * eval_score, 100 * bound))
+
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--eval', action='store_true', help='set this to evaluate.')
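
Both evaluate_model and the training loop below lean on instance_bce_with_logits, whose body this diff shows only down to the assert. A plausible sketch in the style of common VQA baselines (the scaling by the number of answer classes is an assumption, not confirmed by the commit):

import torch.nn.functional as F

def instance_bce_with_logits(logits, labels):
    assert logits.dim() == 2
    # multi-label BCE over soft answer scores, rescaled so the loss behaves like a
    # per-example sum over answer classes rather than a mean (assumed convention)
    loss = F.binary_cross_entropy_with_logits(logits, labels)
    loss *= labels.size(1)
    return loss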
@@ -52,7 +103,7 @@ def parse_args():
     parser.add_argument('--norm', type=str, default='weight', help='weight, batch, layer, none')
     parser.add_argument('--model', type=str, default='A3x2')
     parser.add_argument('--output', type=str, default='saved_models/')
-    parser.add_argument('--batch_size', type=int, default=128)
+    parser.add_argument('--batch_size', type=int, default=512)
     parser.add_argument('--weight_decay', type=float, default=0)
     parser.add_argument('--optimizer', type=str, default='Adamax', help='Adam, Adamax, Adadelta, RMSprop')
     parser.add_argument('--initializer', type=str, default='kaiming_normal')
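
With the new default in place, a run might be launched as (hypothetical invocation; the flags are the ones defined above):

python train_models.py --batch_size 512 --optimizer Adamax --norm weight --initializer kaiming_normal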
@@ -67,11 +118,14 @@ def parse_args():
 feats_data_path="/data/digbose92/VQA/COCO/train_hdf5_COCO/"
 data_root="/proj/digbose92/VQA/VisualQuestion_VQA/common_resources"
 npy_file="../../VisualQuestion_VQA/Visual_All/data/glove6b_init_300d.npy"
+output_folder="/proj/digbose92/VQA/VisualQuestion_VQA/Visual_Attention/results"
 seed = 0
 args = parse_args()
 #device_selection
-device=1
-torch.cuda.set_device(device)
+device_ids=[0,1]
+#device_select=1
+#torch.cuda.set_device(device_select)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 if args.seed == 0:
     seed = random.randint(1, 10000)
@@ -84,16 +138,22 @@ def parse_args():
     torch.cuda.manual_seed(args.seed)

 #train dataset
-train_dataset=Dataset_VQA(img_root_dir=image_root_dir,feats_data_path=feats_data_path,dictionary=dictionary,dataroot=data_root,arch_choice="resnet152",layer_option="pool")
-train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8)
+train_dataset=Dataset_VQA(img_root_dir=image_root_dir,feats_data_path=feats_data_path,dictionary=dictionary,choice='train',dataroot=data_root,arch_choice="resnet152",layer_option="pool")
+valid_dataset=Dataset_VQA(img_root_dir=image_root_dir,feats_data_path=feats_data_path,dictionary=dictionary,choice='val',dataroot=data_root,arch_choice="resnet152",layer_option="pool")
+
+train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10)
+val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)
+print(len(train_loader))
+print(len(val_loader))
 total_step=len(train_loader)

 #model related issues
 model = attention_baseline(train_dataset, num_hid=args.num_hid, dropout=args.dropout, norm=args.norm,\
                            activation=args.activation, drop_L=args.dropout_L, drop_G=args.dropout_G,\
                            drop_W=args.dropout_W, drop_C=args.dropout_C)

-model=model.to(device)
+#model=model.to(device_select)
+
 if args.initializer == 'xavier_normal':
     model.apply(weights_init_xn)
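
The loaders assume Dataset_VQA yields (feat, quest, label, target) tuples, matching the unpacking in the training and evaluation loops. A hypothetical stand-in with assumed shapes (36 pooled resnet152 regions, 14 question tokens, and 3129 candidate answers are guesses, not values from this repo):

import torch
from torch.utils.data import Dataset

class ToyVQADataset(Dataset):
    """Hypothetical stand-in for Dataset_VQA, for shape checking only."""
    def __len__(self):
        return 32

    def __getitem__(self, idx):
        feat = torch.randn(36, 2048)          # region features per image
        quest = torch.randint(0, 400, (14,))  # question token ids
        label = 0                             # unused by the loops shown here
        target = torch.zeros(3129)            # soft answer scores
        target[idx % 3129] = 0.9
        return feat, quest, label, target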
@@ -105,7 +165,9 @@ def parse_args():
     model.apply(weights_init_ku)

 model.w_emb.init_embedding(npy_file)
-
+if torch.cuda.device_count() > 1:
+    print("Let's use", torch.cuda.device_count(), "GPUs!")
+    model = torch.nn.DataParallel(model, device_ids=device_ids)
+model = model.to(device)

 if args.optimizer == 'Adadelta':
     optim = torch.optim.Adadelta(model.parameters(), rho=0.95, eps=1e-6, weight_decay=args.weight_decay)
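
One caveat with the DataParallel wrapper: model.state_dict() keys gain a module. prefix, so the checkpoint written at the end of this file loads most easily either through an identically wrapped model or after stripping the prefix. A minimal sketch (plain_model is a hypothetical unwrapped attention_baseline instance, not part of this commit):

state = torch.load(model_path, map_location='cpu')
state = {k.replace('module.', '', 1): v for k, v in state.items()}  # drop DataParallel prefix
plain_model.load_state_dict(state)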
@@ -115,39 +177,70 @@ def parse_args():
     optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=args.weight_decay)
 else:
     optim = torch.optim.Adamax(model.parameters(), weight_decay=args.weight_decay)

+logger = utils.Logger(os.path.join(output_folder, 'log.txt'))
+best_eval_score = 0
 print('Starting training')
+
+#placeholder for checking whether training and testing work
+#single_batch_run(model,train_loader,val_loader,device_select,output_folder,optim)
+
+device_select=0
+
 for epoch in range(args.epochs):
     total_loss = 0
     train_score = 0
     t = time.time()
     correct = 0
     step=0
+    start_time=time.time()
     for i, (feat, quest, label, target) in enumerate(train_loader):
         feat = feat.to(device)
         quest = quest.to(device)
         target = target.to(device)  # true labels

         pred = model(feat, quest, target)
         loss = instance_bce_with_logits(pred, target)
-        print(loss)
+        #print(loss)
         loss.backward()
-        nn.utils.clip_grad_norm(model.parameters(), 0.25)
+        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
         optim.step()
         optim.zero_grad()

         batch_score = compute_score_with_logits(pred, target.data).sum()
         total_loss += loss.item() * feat.size(0)
         train_score += batch_score
         if step % 10 == 0:
-            #optimizer.zero_grad()
-            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
-                  .format(epoch, args.epochs, step, total_step, loss.item()))
+            end_time=time.time()
+            time_elapsed=end_time-start_time
+
+            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Time elapsed: {:.4f}'
+                  .format(epoch, args.epochs, step, total_step, loss.item(), time_elapsed))
+            start_time=end_time
         step = step + 1

     total_loss /= len(train_loader.dataset)
     train_score = 100 * train_score / len(train_loader.dataset)

+    print('Epoch [{}/{}], Training Loss: {:.4f}, Training Accuracy {:.4f}'
+          .format(epoch, args.epochs, total_loss, train_score))
+
+    model.train(False)
+    eval_score, bound, V_loss = evaluate_model(model, val_loader, device)
+    model.train(True)
+
+    logger.write('epoch %d, time: %.2f' % (epoch, time.time()-t))
+    logger.write('\ttrain_loss: %.3f, score: %.3f' % (total_loss, train_score))
+    logger.write('\teval loss: %.3f, score: %.3f (%.3f)' % (V_loss, 100 * eval_score, 100 * bound))
+
+    if eval_score > best_eval_score:
+        model_path = os.path.join(output_folder, 'model.pth')
+        torch.save(model.state_dict(), model_path)
+        best_eval_score = eval_score
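
For intuition on the score/upper_bound bookkeeping in evaluate_model: with soft VQA targets, each example contributes the target weight of the predicted answer, while upper_bound accumulates the best attainable weight. A toy check with assumed values, assuming the usual argmax/one-hot implementation of compute_score_with_logits:

import torch

target = torch.tensor([[0.0, 0.3, 0.9],    # best answer is index 2 (weight 0.9)
                       [0.6, 0.0, 0.0]])   # best answer is index 0 (weight 0.6)
logits = torch.tensor([[0.1, 2.0, 0.5],    # model picks index 1 -> credit 0.3
                       [3.0, 0.2, 0.1]])   # model picks index 0 -> credit 0.6

score = compute_score_with_logits(logits, target).sum()   # 0.3 + 0.6 = 0.9
upper_bound = target.max(1)[0].sum()                      # 0.9 + 0.6 = 1.5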

Visual_Attention/utils.py

+100
@@ -0,0 +1,100 @@
+from __future__ import print_function
+
+import errno
+import os
+import numpy as np
+from PIL import Image
+import torch
+import torch.nn as nn
+
+
+EPS = 1e-7
+
+
+def assert_eq(real, expected):
+    assert real == expected, '%s (true) vs %s (expected)' % (real, expected)
+
+
+def assert_array_eq(real, expected):
+    assert (np.abs(real - expected) < EPS).all(), \
+        '%s (true) vs %s (expected)' % (real, expected)
+
+
+def load_folder(folder, suffix):
+    imgs = []
+    for f in sorted(os.listdir(folder)):
+        if f.endswith(suffix):
+            imgs.append(os.path.join(folder, f))
+    return imgs
+
+
+def load_imageid(folder):
+    images = load_folder(folder, 'jpg')
+    img_ids = set()
+    for img in images:
+        img_id = int(img.split('/')[-1].split('.')[0].split('_')[-1])
+        img_ids.add(img_id)
+    return img_ids
+
+
+def pil_loader(path):
+    with open(path, 'rb') as f:
+        with Image.open(f) as img:
+            return img.convert('RGB')
+
+
+def weights_init(m):
+    """custom weights initialization."""
+    cname = m.__class__
+    if cname == nn.Linear or cname == nn.Conv2d or cname == nn.ConvTranspose2d:
+        m.weight.data.normal_(0.0, 0.02)
+    elif cname == nn.BatchNorm2d:
+        m.weight.data.normal_(1.0, 0.02)
+        m.bias.data.fill_(0)
+    else:
+        print('%s is not initialized.' % cname)
+
+
+def init_net(net, net_file):
+    if net_file:
+        net.load_state_dict(torch.load(net_file))
+    else:
+        net.apply(weights_init)
+
+
+def create_dir(path):
+    if not os.path.exists(path):
+        try:
+            os.makedirs(path)
+        except OSError as exc:
+            if exc.errno != errno.EEXIST:
+                raise
+
+
+class Logger(object):
+    def __init__(self, output_name):
+        dirname = os.path.dirname(output_name)
+        if not os.path.exists(dirname):
+            os.mkdir(dirname)
+
+        self.log_file = open(output_name, 'w')
+        self.infos = {}
+
+    def append(self, key, val):
+        vals = self.infos.setdefault(key, [])
+        vals.append(val)
+
+    def log(self, extra_msg=''):
+        msgs = [extra_msg]
+        for key, vals in self.infos.items():
+            msgs.append('%s %.6f' % (key, np.mean(vals)))
+        msg = '\n'.join(msgs)
+        self.log_file.write(msg + '\n')
+        self.log_file.flush()
+        self.infos = {}
+        return msg
+
+    def write(self, msg):
+        self.log_file.write(msg + '\n')
+        self.log_file.flush()
+        print(msg)
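
A quick usage sketch for the new Logger (paths and values assumed):

import utils

logger = utils.Logger('results/log.txt')  # creates results/ if missing (one level only)
logger.write('epoch 0, time: 12.34')      # appends to the file and echoes to stdout

logger.append('loss', 0.731)
logger.append('loss', 0.694)
logger.log('end of epoch 0')              # flushes buffered means: 'loss 0.712500'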
