-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathbert_features_extract.py
41 lines (36 loc) · 1.72 KB
/
bert_features_extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import json
from flair.embeddings import BertEmbeddings,DocumentPoolEmbeddings
import numpy as np
from tqdm import tqdm
from flair.data import Sentence
import os
import h5py
def extract_bert_features(json_file, dataroot_folder, choice="yes_no", split="train"):
    """Embed every question in a VQA question file with pooled BERT features.

    Each question string is wrapped in a flair ``Sentence``, embedded with
    ``BertEmbeddings('bert-base-uncased')`` pooled by ``DocumentPoolEmbeddings``,
    and stored as one row of a float64 matrix. The matrix and the question ids
    (same row order) are written to
    ``<dataroot_folder>/<split>_bert_<choice>.hdf5`` as the datasets
    ``bert_embeddings`` and ``question_ids``.

    Args:
        json_file: Path to a JSON file whose top-level ``'questions'`` list
            holds dicts with ``'question_id'`` and ``'question'`` keys.
        dataroot_folder: Directory the HDF5 file is written into.
        choice: Tag used only in the output filename (e.g. ``"yes_no"``).
        split: Tag used only in the output filename (e.g. ``"train"``).
    """
    # Close the JSON file deterministically instead of leaking the handle.
    with open(json_file) as f:
        questions = json.load(f)['questions']
    question_ids = [quest['question_id'] for quest in questions]

    bert = BertEmbeddings('bert-base-uncased')
    doc_bert = DocumentPoolEmbeddings([bert])
    # Size the matrix from the embedder rather than a hard-coded 3072
    # (presumably 4 concatenated 768-d BERT layers) so it stays correct if
    # the model or layer selection changes.
    bert_embed_matrix = np.zeros((len(questions), doc_bert.embedding_length))

    print('Extracting bert features')
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Wrapping the list (not the enumerate object) lets tqdm show a total.
        for index, quest in enumerate(tqdm(questions)):
            sentence = Sentence(quest['question'])
            doc_bert.embed(sentence)
            # flair may place embeddings on the GPU and/or attach them to a
            # graph; plain .numpy() fails in both cases.
            bert_embed_matrix[index] = sentence.embedding.detach().cpu().numpy()

    hdf5_file_path = os.path.join(dataroot_folder, split + '_bert_' + choice + '.hdf5')
    # Context manager guarantees the HDF5 file is flushed and closed even if
    # writing a dataset fails.
    with h5py.File(hdf5_file_path, 'w') as h5f:
        h5f.create_dataset('bert_embeddings', data=bert_embed_matrix)
        h5f.create_dataset('question_ids', data=question_ids)
if __name__ == "__main__":
    # Paths are command-line options (defaults = the committed upstream
    # values) so each developer can point at their own data without editing
    # this file and producing merge conflicts.
    import argparse

    parser = argparse.ArgumentParser(
        description="Extract pooled BERT features for VQA questions into an HDF5 file.")
    parser.add_argument(
        "--json_file",
        default="/proj/digbose92/VQA/VisualQuestion_VQA/Visual_All/data/v2_OpenEnded_mscoco_train2014_yes_no_questions.json",
        help="VQA question JSON file with a top-level 'questions' list.")
    parser.add_argument(
        "--dataroot_folder",
        default="/data/digbose92/VQA/COCO/train_hdf5_COCO",
        help="Directory the <split>_bert_<choice>.hdf5 file is written to.")
    parser.add_argument(
        "--choice",
        default="yes_no",
        help="Question-type tag used in the output filename.")
    parser.add_argument(
        "--split",
        default="train",
        help="Dataset-split tag used in the output filename.")
    args = parser.parse_args()

    extract_bert_features(args.json_file, args.dataroot_folder,
                          choice=args.choice, split=args.split)