-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
69 lines (61 loc) · 2.16 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import pandas
import tokenizer
class LangDataset(torch.utils.data.Dataset):
    """Parallel-sentence dataset that yields token-id tensors.

    Each row of the tab-separated source file supplies an ``eng`` sentence
    and a ``cantonese`` sentence. An item exposes the encoded ``eng``
    sentence as context, plus the encoded ``cantonese`` sentence twice:
    prefixed with the BOS id (``input``) and suffixed with the EOS id
    (``label``), i.e. the same sequence shifted by one position.

    NOTE(review): the default file is named ``eng-mandarin.tsv`` while the
    columns are called ``cantonese`` -- confirm which language the data
    actually holds.
    """

    # Value used to right-pad every sequence in a batch.
    # NOTE(review): 0 is the value the original code hard-coded; confirm it
    # matches the tokenizer's pad id and whatever the loss ignores.
    PAD_ID = 0

    def __init__(self, path="./eng-mandarin.tsv"):
        """Load the sentence pairs and the tokenizer.

        Args:
            path: Location of the headerless, tab-separated file whose four
                columns are ``id_eng``, ``eng``, ``id_cantonese`` and
                ``cantonese``. Defaults to the path used historically.
        """
        super().__init__()
        self.column_names = ["id_eng", "eng", "id_cantonese", "cantonese"]
        self.df = pandas.read_csv(
            path,
            delimiter="\t",
            encoding="utf-8",
            on_bad_lines="skip",
            header=None,
            names=self.column_names,
            # Keep sentences such as "NA", "null" or "None" as text; the
            # default NA filter would turn them into float NaN, which is
            # not something the tokenizer can encode.
            keep_default_na=False,
        )
        # NOTE(review): default CSV quoting is still active, so a sentence
        # containing a double quote may be merged with following lines or
        # dropped by on_bad_lines="skip". Consider quoting=csv.QUOTE_NONE
        # (this would change the row count, so it is not done here).
        self.tk = tokenizer.LangTokenizer()
        self.tk.load()

    def __len__(self):
        """Return the number of sentence pairs that were loaded."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the raw text and token-id tensors for row ``idx``.

        Returns:
            A dict with keys ``txt_eng`` and ``txt_cantonese`` (raw strings),
            ``contx`` (encoded ``eng``), ``input`` (BOS + encoded
            ``cantonese``) and ``label`` (encoded ``cantonese`` + EOS).
            All three tensors are ``torch.long``.
        """
        row = self.df.iloc[idx]
        source_ids = self.tk.encode(row["eng"])
        # Encode the target once so that input and label are guaranteed to
        # be the same token sequence, merely shifted by one position.
        target_ids = self.tk.encode(row["cantonese"])
        decoder_input = [self.tk.sp.bos_id()] + target_ids
        label = target_ids + [self.tk.sp.eos_id()]
        # An explicit dtype keeps an empty encoding from becoming a float32
        # tensor, which would make the batch dtype depend on item order.
        return {
            "txt_eng": row["eng"],
            "txt_cantonese": row["cantonese"],
            "contx": torch.tensor(source_ids, dtype=torch.long),
            "input": torch.tensor(decoder_input, dtype=torch.long),
            "label": torch.tensor(label, dtype=torch.long),
        }

    def collate_fn(self, batch):
        """Merge items from ``__getitem__`` into one right-padded batch.

        Args:
            batch: Non-empty list of dicts as returned by ``__getitem__``.

        Returns:
            A dict with ``eng`` and ``cantonese`` (lists of raw strings) and
            ``contx``, ``input`` and ``label`` (batch-first tensors padded
            with ``PAD_ID`` to the longest sequence of each field).
        """

        def pad(key):
            # Stack one tensor field across the batch, padding on the right.
            return torch.nn.utils.rnn.pad_sequence(
                [item[key] for item in batch],
                batch_first=True,
                padding_value=self.PAD_ID,
            )

        return {
            "eng": [item["txt_eng"] for item in batch],
            "cantonese": [item["txt_cantonese"] for item in batch],
            "contx": pad("contx"),
            "input": pad("input"),
            "label": pad("label"),
        }
if __name__ == "__main__":
    # Smoke test: load the dataset, show one item, and decode its decoder
    # input back to text to eyeball the tokenizer round trip.
    ds = LangDataset()
    print("len(ds)", len(ds))
    # Fetch the sample once and reuse it for both prints.
    sample = ds[362]
    print("ds[362]", sample)
    # Decode with the dataset's own tokenizer: it has already been through
    # load(), whereas a freshly constructed LangTokenizer has not.
    decoded_input = ds.tk.decode(sample["input"].tolist())
    print("ds[362] decode input", decoded_input)