-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtemp.py
137 lines (112 loc) · 3.89 KB
/
temp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import SequentialSampler
from box import Box
import yaml
# Load the training/preprocessing configuration once at import time.
# The path is relative to the current working directory. `cfg` wraps the
# parsed YAML in a Box so settings are read with attribute access
# (e.g. cfg.preprocessing.path_to_data).
# The encoding is pinned so decoding does not depend on the platform locale.
with open('configs/config.yaml', encoding="utf-8") as f:
    cfg = Box(yaml.safe_load(f))
class Img2MML_dataset(Dataset):
    """
    Dataset over a dataframe of (image, equation) rows.

    Column 0 of `dataframe` is returned as-is as the image entry; column 1
    is a whitespace-separated equation string that is converted to
    vocabulary indices when a sample is accessed.

    dataframe: pandas DataFrame, one sample per row
    vocab: object exposing a `stoi` token -> index mapping that contains
           an "<unk>" entry
    tokenizer: accepted for interface compatibility only; it is not used,
               equations are always split on whitespace
    """

    def __init__(self, dataframe, vocab, tokenizer):
        self.dataframe = dataframe
        self.vocab = vocab

    def __len__(self):
        """Number of rows (samples) in the dataframe."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """
        return: (image entry of row `index`,
                 tensor of token indices built with torch.Tensor)
        """
        stoi = self.vocab.stoi
        indexed_eqn = []
        for token in self.dataframe.iloc[index, 1].split():
            # A single lookup per token. A plain dict raises KeyError for
            # an unknown token, other mappings may hand back None; both
            # cases fall back to the "<unk>" index.
            try:
                token_idx = stoi[token]
            except KeyError:
                token_idx = None
            if token_idx is None:
                token_idx = stoi["<unk>"]
            indexed_eqn.append(token_idx)
        return self.dataframe.iloc[index, 0], torch.Tensor(indexed_eqn)
class My_pad_collate(object):
    """
    Collate callable for batches of (image entry, equation tensor) pairs.

    Every equation is brought to exactly max_len tokens: shorter ones are
    filled with the "<pad>" index, longer ones are cut off at max_len.
    return: image tensor with one entry per sample,
            mml tensor of shape [batch, max_len]
            (both moved to `device`)
    """

    def __init__(self, device, vocab, max_len):
        self.device, self.vocab, self.max_len = device, vocab, max_len
        self.pad_idx = vocab.stoi["<pad>"]

    def __call__(self, batch):
        images, equations = zip(*batch)

        # Start from a [batch, max_len] grid holding only the pad index,
        # then overwrite the head of each row with as many tokens as fit.
        grid_shape = [len(equations), self.max_len]
        mml_batch = torch.ones(grid_shape, dtype=torch.long) * self.pad_idx
        for row, tokens in enumerate(equations):
            keep = min(len(tokens), self.max_len)
            mml_batch[row][:keep] = tokens[:keep]

        # image entries are packed into one tensor
        image_batch = torch.Tensor(images)
        return image_batch.to(self.device), mml_batch.to(self.device)
def _token_count_in_bin(equation, start, end):
    """True when the whitespace-token count of `equation` is in (start, end]."""
    n_tokens = len(equation.split())
    above_start = start is None or n_tokens > start
    within_end = end is None or n_tokens <= end
    return above_start and within_end


def bin_test_dataloader(
    vocab,
    device,
    start=None,
    end=None,
    length_based_binning=False,
    content_based_binning=False,
):
    """
    Build a DataLoader over the subset of the test set that falls in one
    token-count bin.

    Rows are read from `test.csv` (columns "IMG" and "EQUATION") under
    cfg.preprocessing.path_to_data and kept when the token count is
    greater than `start` and at most `end`.

    vocab: vocabulary object handed to the dataset and the collate class
    device: device the collated tensors are moved to
    start: exclusive lower bound on the token count; None means no lower
           bound
    end: inclusive upper bound on the token count; None means no upper
         bound
    length_based_binning: bin on the token count of the test equation
                          itself
    content_based_binning: bin on the token count of the equation at the
                           same row position in `test_latex.csv`; only
                           consulted when length_based_binning is False

    return: DataLoader yielding (images, padded equations) batches. When
            neither binning flag is set no row is selected and the loader
            is empty.
    """
    df = pd.read_csv(
        f"{cfg.preprocessing.path_to_data}/test.csv"
    )
    imgs, eqns = df["IMG"], df["EQUATION"]

    eqns_arr = []
    imgs_arr = []

    if length_based_binning:
        for i, e in zip(imgs, eqns):
            if _token_count_in_bin(e, start, end):
                eqns_arr.append(e)
                imgs_arr.append(i)

    elif content_based_binning:  # only for mml
        # First run the latex config and save the csv as test_latex,
        # which is then used here as the reference for the bin membership.
        df_latex = pd.read_csv(
            f"{cfg.preprocessing.path_to_data}/test_latex.csv"
        )
        eqns_latex = df_latex["EQUATION"]
        for idx, e in enumerate(eqns_latex):
            if _token_count_in_bin(e, start, end):
                eqns_arr.append(eqns[idx])
                imgs_arr.append(imgs[idx])

    raw_mml_data = {
        "IMG": imgs_arr,
        "EQUATION": eqns_arr,
    }
    test = pd.DataFrame(raw_mml_data, columns=["IMG", "EQUATION"])

    # define tokenizer function
    def tokenizer(x):
        return x.split()

    # initializing pad collate class
    mypadcollate = My_pad_collate(
        device, vocab, cfg.model.decoder_transformer.max_len
    )

    # initializing class Img2MML_dataset: test dataloader
    imml_test = Img2MML_dataset(test, vocab, tokenizer)

    # no custom sampler: ordering is governed by the configured shuffle flag
    test_dataloader = DataLoader(
        imml_test,
        batch_size=cfg.training.batch_size,
        num_workers=0,
        shuffle=cfg.preprocessing.shuffle,
        sampler=None,
        collate_fn=mypadcollate,
        pin_memory=False,
    )
    return test_dataloader