Commit aeeaee8

Fixed the single image inference issue

1 parent 04a873d · commit aeeaee8
9 files changed: +593 -62 lines changed

Visual_Attention/bert_features_extract.py  (+1 -1)

@@ -29,7 +29,7 @@ def extract_bert_features(json_file,dataroot_folder,choice="yes_no",split="train
     h5f.close()

 if __name__ == "__main__":
-    json_file="/proj/digbose92/VQA/VisualQuestion_VQA/Visual_All/data/v2_OpenEnded_mscoco_train2014_1000_questions.json"
+    json_file="/proj/digbose92/VQA/VisualQuestion_VQA/Visual_All/data/v2_OpenEnded_mscoco_train2014_yes_no_questions.json"
     dataroot_folder="/data/digbose92/VQA/COCO/train_hdf5_COCO"
     extract_bert_features(json_file,dataroot_folder)


Visual_Attention/fusion_models.py  (+20 -9)

@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
 class mfh_baseline(nn.Module):
     def __init__(self,QUEST_EMBED,VIS_EMBED,MFB_FACTOR_NUM=5,MFB_OUT_DIM=1000,MFB_DROPOUT_RATIO=0.1,NUM_OUTPUT_UNITS=2):
         super(mfh_baseline, self).__init__()
@@ -15,30 +14,42 @@ def __init__(self,QUEST_EMBED,VIS_EMBED,MFB_FACTOR_NUM=5,MFB_OUT_DIM=1000,MFB_DR
         self.Linear_imgproj2 = nn.Linear(VIS_EMBED, self.JOINT_EMB_SIZE)
         #self.Linear_predict = nn.Linear(MFB_OUT_DIM * 2, NUM_OUTPUT_UNITS)
         #self.Dropout1 = nn.Dropout(p=opt.LSTM_DROPOUT_RATIO)
-        self.Dropout2 = nn.Dropout(MFB_DROPOUT_RATIO)
+        #self.Dropout2 = nn.Dropout(MFB_DROPOUT_RATIO)

     def forward(self, q_feat, img_feat):

         mfb_q_o2_proj = self.Linear_dataproj1(q_feat)               # data_out (N, 5000)
         mfb_i_o2_proj = self.Linear_imgproj1(img_feat.float())      # img_feature (N, 5000)
         mfb_iq_o2_eltwise = torch.mul(mfb_q_o2_proj, mfb_i_o2_proj)
-        mfb_iq_o2_drop = self.Dropout2(mfb_iq_o2_eltwise)
-        mfb_iq_o2_resh = mfb_iq_o2_drop.view(-1, 1, self.MFB_OUT_DIM, self.MFB_FACTOR_NUM)   # N x 1 x 1000 x 5
-        mfb_o2_out = torch.squeeze(torch.sum(mfb_iq_o2_resh, 3))    # N x 1000
-        mfb_o2_out = torch.sqrt(F.relu(mfb_o2_out)) - torch.sqrt(F.relu(-mfb_o2_out))   # signed sqrt
+        mfb_iq_o2_drop = mfb_iq_o2_eltwise
+        #mfb_iq_o2_drop = self.Dropout2(mfb_iq_o2_eltwise)
+        mfb_iq_o2_resh = mfb_iq_o2_drop.view(-1, 1, self.MFB_OUT_DIM, self.MFB_FACTOR_NUM)
+        if(mfb_iq_o2_resh.size(0)>1):                               # N x 1 x 1000 x 5
+            mfb_o2_out = torch.squeeze(torch.sum(mfb_iq_o2_resh, 3))
+        else:
+            mfb_o2_out = torch.sum(mfb_iq_o2_resh, 3).view(1,mfb_iq_o2_resh.size(2))    # N x 1000
+        mfb_o2_out = torch.sqrt(F.relu(mfb_o2_out)) - torch.sqrt(F.relu(-mfb_o2_out))
+        #print(mfb_o2_out.size())                                   # signed sqrt
         mfb_o2_out = F.normalize(mfb_o2_out)
+

         mfb_q_o3_proj = self.Linear_dataproj2(q_feat)               # data_out (N, 5000)
         mfb_i_o3_proj = self.Linear_imgproj2(img_feat.float())      # img_feature (N, 5000)
         mfb_iq_o3_eltwise = torch.mul(mfb_q_o3_proj, mfb_i_o3_proj)
         mfb_iq_o3_eltwise = torch.mul(mfb_iq_o3_eltwise, mfb_iq_o2_drop)
-        mfb_iq_o3_drop = self.Dropout2(mfb_iq_o3_eltwise)
+        mfb_iq_o3_drop = mfb_iq_o3_eltwise
+        #mfb_iq_o3_drop = self.Dropout2(mfb_iq_o3_eltwise)
         mfb_iq_o3_resh = mfb_iq_o3_drop.view(-1, 1, self.MFB_OUT_DIM, self.MFB_FACTOR_NUM)
-        mfb_o3_out = torch.squeeze(torch.sum(mfb_iq_o3_resh, 3))    # N x 1000
+
+        #mfb_o3_out = torch.squeeze(torch.sum(mfb_iq_o3_resh, 3))   # N x 1000
+        if(mfb_iq_o3_resh.size(0)>1):                               # N x 1 x 1000 x 5
+            mfb_o3_out = torch.squeeze(torch.sum(mfb_iq_o3_resh, 3))
+        else:
+            mfb_o3_out = torch.sum(mfb_iq_o3_resh, 3).view(1,mfb_iq_o3_resh.size(2))
         mfb_o3_out = torch.sqrt(F.relu(mfb_o3_out)) - torch.sqrt(F.relu(-mfb_o3_out))
         mfb_o3_out = F.normalize(mfb_o3_out)

-        mfb_o23_out = torch.cat((mfb_o2_out, mfb_o3_out), 1) #200,2000
+        mfb_o23_out = torch.cat((mfb_o2_out, mfb_o3_out), 1)#200,2000
         #prediction = self.Linear_predict(mfb_o23_out)
         #prediction = F.log_softmax(prediction)

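Note on the change above: torch.squeeze removes every size-1 dimension, so with a single image (batch size 1) the summed MFB tensor loses its batch axis and the later torch.cat along dimension 1 fails. A minimal standalone sketch of the behaviour the new size(0) > 1 branch guards against (illustrative only, not repository code):

import torch

# Illustrative sketch, not repository code: why the size(0) > 1 branch is needed.
x = torch.randn(1, 1, 1000, 5)        # single-image batch, N = 1
summed = torch.sum(x, 3)              # shape (1, 1, 1000)

squeezed = torch.squeeze(summed)      # shape (1000,) -- batch axis is gone
kept = summed.view(1, x.size(2))      # shape (1, 1000) -- batch axis preserved

print(squeezed.shape, kept.shape)
# torch.cat((squeezed, squeezed), 1) would raise a dimension error,
# while torch.cat((kept, kept), 1) gives the expected (1, 2000) tensor.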

Visual_Attention/grad_cam.py

Whitespace-only changes.

Visual_Attention/inference_attention_model.py  (+41 -33)

@@ -37,41 +37,48 @@ def compute_score_with_logits(logits, labels):
 def evaluate_attention_model(args):

     class_data=pd.read_csv(args.class_metadata_file)
-    class_label_map={0:"no",1:"yes"}
+    #class_label_map={0:"no",1:"yes"}

-    #class_label_map=class_data['Label_names'].tolist()
+    class_label_map=class_data['Label_names'].tolist()

     print('Loading model checkpoint')
     attention_model_checkpoint=torch.load(args.model_path)
-
     new_state_dict = OrderedDict()
     for k, v in attention_model_checkpoint.items():
         name = k[7:] # remove `module.`
         new_state_dict[name] = v
     print('Model checkpoint loaded')
+    #new_state_dict["classifier.main.2.bias"]=new_state_dict.pop("classifier.main.3.bias")
+    #new_state_dict["classifier.main.2.weight_g"]=new_state_dict.pop("classifier.main.3.weight_g")
+    #new_state_dict["classifier.main.2.weight_v"]=new_state_dict.pop("classifier.main.3.weight_v")

     print(new_state_dict.keys())
     print('Loading Dictionary')
     dictionary=Dictionary.load_from_file(args.pickle_path)

     train_dataset=Dataset_VQA(img_root_dir=args.image_root_dir,feats_data_path=args.feats_data_path,dictionary=dictionary,choice='train',dataroot=args.data_root,arch_choice=args.arch_choice,layer_option=args.layer_option)
     print('Loading the attention model')
-    attention_model = attention_mfh(train_dataset, num_hid=args.num_hid, dropout= args.dropout, norm=args.norm,\
+    attention_model = attention_baseline(train_dataset, num_hid=args.num_hid, dropout= args.dropout, norm=args.norm,\
         activation=args.activation, drop_L=args.dropout_L, drop_G=args.dropout_G,\
-        drop_W=args.dropout_W, drop_C=args.dropout_C,mfb_out_dim=args.mfb_out_dim)
+        drop_W=args.dropout_W, drop_C=args.dropout_C)
+
+    #attention_model=attention_mfh(train_dataset, num_hid=args.num_hid, dropout= args.dropout, norm=args.norm,\
+    #activation=args.activation, drop_L=args.dropout_L, drop_G=args.dropout_G,\
+    #drop_W=args.dropout_W, drop_C=args.dropout_C,mfb_out_dim=args.mfb_out_dim)
     attention_model.load_state_dict(new_state_dict)
     attention_model.eval()

     torch.manual_seed(args.seed)
-    torch.cuda.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+    #torch.cuda.manual_seed(args.seed)
     torch.cuda.set_device(args.device)
     attention_model.to(args.device)
     if(args.image_model is None):
         """use extracted features as a Dataset and Dataloader
         """
         print('Using validation features')
-        dataset_temp=Dataset_VQA(img_root_dir=args.image_root_dir,feats_data_path=args.feats_data_path,dictionary=dictionary,bert_option=args.bert_option,rcnn_pkl_path=args.rcnn_path,choice=args.choice,dataroot=args.data_root,arch_choice=args.arch_choice,layer_option=args.layer_option)
-        loader=DataLoader(dataset_temp, batch_size=args.batch_size, shuffle=False, num_workers=10)
+        dataset_temp=Dataset_VQA(img_root_dir=args.image_root_dir,feats_data_path=args.feats_data_path,dictionary=dictionary,bert_option=args.bert_option,rcnn_pkl_path=None,choice=args.choice,dataroot=args.data_root,arch_choice=args.arch_choice,layer_option=args.layer_option)
+        loader=DataLoader(dataset_temp, batch_size=args.batch_size, shuffle=False, num_workers=1)
         print('Length of validation dataloader:', len(loader))
         upper_bound = 0
         num_data = 0
@@ -82,25 +89,26 @@ def evaluate_attention_model(args):
         predicted_class_labels=[]
         question_set=[]
         question_id=[]
-        for data in tqdm(loader):
-
-            feat,quest,quest_sent,quest_id,target = data
-            feat = feat.to(args.device)
-            quest = quest.to(args.device)
-            target = target.to(args.device)
-
-            question_id=question_id+quest_id.tolist()
-            pred = attention_model(feat, quest, target)
-            question_set=question_set+list(quest_sent)
-            loss = instance_bce_with_logits(pred, target)
-            V_loss += loss.item() * feat.size(0)
-            score_temp, logits, class_labels= compute_score_with_logits(pred, target.data)
-            actual_class_labels=actual_class_labels+list(class_labels.cpu().numpy())
-            predicted_class_labels=predicted_class_labels+list(logits.cpu().numpy())
-            batch_score=score_temp.sum()
-            score += batch_score
-            upper_bound += (target.max(1)[0]).sum()
-            num_data += pred.size(0)
+        count=0
+        for data in tqdm(loader):
+            feat,quest,quest_sent,quest_id,target = data
+            feat = feat.to(args.device)
+            quest = quest.to(args.device)
+            target = target.to(args.device)
+
+            question_id=question_id+quest_id.tolist()
+            pred = attention_model(feat, quest)
+            question_set=question_set+list(quest_sent)
+            loss = instance_bce_with_logits(pred, target)
+            V_loss += loss.item() * feat.size(0)
+            score_temp, logits, class_labels= compute_score_with_logits(pred, target.data)
+            actual_class_labels=actual_class_labels+list(class_labels.cpu().numpy())
+            predicted_class_labels=predicted_class_labels+list(logits.cpu().numpy())
+            batch_score=score_temp.sum()
+            score += batch_score
+            upper_bound += (target.max(1)[0]).sum()
+            num_data += pred.size(0)
+            #count=count+1


@@ -112,10 +120,10 @@ def evaluate_attention_model(args):
     for index,val in tqdm(enumerate(question_id)):
         temp={"answer":class_predicted_name[index],"question_id":val}
         list_set.append(temp)
-    with open('validation_results.json', 'w') as fout:
+    with open('validation_results_resnet_152_attention_baseline_num_hid_512_batch_size_512.json', 'w') as fout:
         json.dump(list_set , fout)
-    #predicted_df=pd.DataFrame({'Questions':question_set,'Actual_Answers':class_actual_name,'Predicted_Answers':class_predicted_name})
-    #predicted_df.to_csv('Validation_Stats.csv')
+    predicted_df=pd.DataFrame({'Question_id':question_id,'Questions':question_set,'Actual_Answers':class_actual_name,'Predicted_Answers':class_predicted_name})
+    predicted_df.to_csv('Validation_Stats_resnet_152_attention_baseline_num_hid_512_batch_size_512.csv')
     score = score / len(loader.dataset)
     V_loss /= len(loader.dataset)
     upper_bound = upper_bound / len(loader.dataset)
@@ -136,10 +144,10 @@ def evaluate_attention_model(args):
     parser.add_argument('--feats_data_path', type=str, default="/data/digbose92/VQA/COCO/train_hdf5_COCO/")
     parser.add_argument('--data_root', type=str, default="/proj/digbose92/VQA/VisualQuestion_VQA/common_resources")
     parser.add_argument('--npy_file', type=str, default="../../VisualQuestion_VQA/Visual_All/data/glove6b_init_300d.npy")
-    parser.add_argument('--model_path', type=str, default="results_GRU_uni/results_rcnn_hid_1280_mfh_YES_NO_ADAM/model.pth")
+    parser.add_argument('--model_path', type=str, default="results_GRU_uni/results_resnet_152_hid_512_YES_NO_ADAM/model.pth")
     parser.add_argument('--image_model', type=str, default=None)
-    parser.add_argument('--batch_size', type=int, default=32)
-    parser.add_argument('--num_hid', type=int, default=1280) # they used 1024
+    parser.add_argument('--batch_size', type=int, default=512)
+    parser.add_argument('--num_hid', type=int, default=512) # they used 1024
     parser.add_argument('--dropout', type=float, default=0.3)
     parser.add_argument('--dropout_L', type=float, default=0.1)
     parser.add_argument('--dropout_G', type=float, default=0.2)
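For context on the checkpoint-loading loop kept above (name = k[7:]): checkpoints saved from an nn.DataParallel-wrapped model prefix every parameter name with "module.", so the keys must be renamed before load_state_dict on the bare module. A hedged standalone sketch of that renaming; the Linear layer and the simulated checkpoint here are illustrative stand-ins, not the repository's model or files:

from collections import OrderedDict
import torch.nn as nn

# Illustrative sketch, not repository code.
model = nn.Linear(4, 2)   # stand-in for the attention model
# Simulate the key names a DataParallel checkpoint would contain.
wrapped_state = OrderedDict(("module." + k, v) for k, v in model.state_dict().items())

# Strip the "module." prefix so the keys match the unwrapped module.
new_state_dict = OrderedDict(
    (k[len("module."):] if k.startswith("module.") else k, v)
    for k, v in wrapped_state.items()
)
model.load_state_dict(new_state_dict)   # loads cleanly without the prefix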

Visual_Attention/model_combined.py  (+55 -2)

@@ -25,7 +25,7 @@ def __init__(self, w_emb, q_emb, v_att, q_net, v_net, classifier):
         self.v_net = v_net
         self.classifier = classifier

-    def forward(self, v, q, labels):
+    def forward(self, v, q):
         """Forward

         v: [batch, num_objs, obj_dim]
@@ -86,7 +86,7 @@ def __init__(self, w_emb, q_emb, v_att, q_net, v_net, mfh_net, classifier):
         self.mfh_net = mfh_net
         self.classifier = classifier

-    def forward(self, v, q, labels):
+    def forward(self, v, q):
         """Forward

         v: [batch, num_objs, obj_dim]
@@ -112,6 +112,7 @@ def forward(self, v, q, labels):
         logits = self.classifier(joint_repr)
         return logits

+
 class VQA_Model_MFH_classifier(nn.Module):
     def __init__(self, w_emb, q_emb, v_att, q_net, v_net, mfh_net):
         super(VQA_Model_MFH_classifier, self).__init__()
@@ -150,10 +151,44 @@ def forward(self, v, q, labels):
         #logits = self.classifier(joint_repr)
         return logits

+class VQA_Model_MFH_BERT_fusion(nn.Module):
+    def __init__(self, bert_emb, v_att, q_net, v_net, mfh_net,classifier):
+        super(VQA_Model_MFH_BERT_fusion, self).__init__()
+        self.bert_emb = bert_emb
+        self.v_att = v_att
+        self.q_net = q_net
+        self.v_net = v_net
+        self.mfh_net = mfh_net
+        self.classifier = classifier
+
+    def forward(self, v, q, labels):
+        """Forward
+
+        v: [batch, num_objs, obj_dim]
+        q: [batch_size, seq_length]
+
+        return: logits, not probs
+        """
+        q_emb = self.bert_emb(q)
+        #print(q_emb.size())
+
+        att = self.v_att(v, q_emb) # [batch, 1, v_dim]
+        v_emb = (att * v).sum(1) # [batch, v_dim]
+
+        q_repr = self.q_net(q_emb)
+        v_repr = self.v_net(v_emb)
+        #joint_repr=self.mfh_net(q_repr,v_repr)
+        joint_repr=self.mfh_net(q_repr,v_repr)
+        #joint_repr = q_repr * v_repr
+
+        #invoke MFH for fusion of q_repr and v_repr

+        logits = self.classifier(joint_repr)
+        return logits

 ############# ATTENTION BASELINE ############
 def attention_baseline(dataset, num_hid, dropout, norm, activation, drop_L , drop_G, drop_W, drop_C, bidirect_val=False):
+    print('Here in the attention baseline')
     w_emb = WordEmbedding(dataset.dictionary.ntoken, emb_dim=300, dropout=drop_W)
     q_emb = QuestionEmbedding(in_dim=300, num_hid=num_hid, nlayers=1, bidirect=bidirect_val, dropout=drop_G, rnn_type='GRU')
     #bert_emb=BertEmbedding(in_dim=7168,num_hid=num_hid)
@@ -225,6 +260,24 @@ def attention_mfh_classifier(dataset, num_hid, dropout, norm, activation, drop_L
     return(VQA_Model_MFH_classifier(w_emb,q_emb,v_att,q_net,v_net,mfh_net))


+###### ATTENTION + BERT + MFH FUSION #############
+def attention_bert_mfh_fusion(dataset, num_hid, dropout, norm, activation, drop_L , drop_G, drop_W, drop_C, mfb_out_dim, bidirect_val=False):
+    #w_emb = WordEmbedding(dataset.dictionary.ntoken, emb_dim=300, dropout=drop_W)
+    #q_emb = QuestionEmbedding(in_dim=300, num_hid=num_hid, nlayers=1, bidirect=bidirect_val, dropout=drop_G, rnn_type='GRU')
+
+    bert_emb=BertEmbedding(in_dim=3072,num_hid=num_hid)
+    v_att = Base_Att(v_dim= dataset.v_dim, q_dim= num_hid, num_hid= num_hid, dropout= dropout, bidirect=bidirect_val,norm= norm, act= activation)
+    if(bidirect_val is False):
+        q_net = FCNet([num_hid, num_hid], dropout= drop_L, norm= norm, act= activation)
+        #v_net = FCNet([dataset.v_dim, num_hid], dropout= drop_L, norm= norm, act= activation)
+    else:
+        q_net = FCNet([2*num_hid, num_hid], dropout= drop_L, norm= norm, act= activation)
+
+    v_net = FCNet([dataset.v_dim, num_hid], dropout= drop_L, norm= norm, act= activation)
+    mfh_net=mfh_baseline(QUEST_EMBED=num_hid,VIS_EMBED=num_hid,MFB_OUT_DIM=mfb_out_dim)
+    classifier = SimpleClassifier(in_dim=2*mfb_out_dim, hid_dim=2 * num_hid, out_dim=dataset.num_ans_candidates, dropout=drop_C, norm= norm, act= activation)
+    return(VQA_Model_MFH_BERT_fusion(bert_emb,v_att,q_net,v_net,mfh_net,classifier))


 def weights_init_xn(m):
     if isinstance(m, nn.Linear):
