dataset.py

import argparse

import numpy as np
import pandas as pd
import lightning as L
from torch.utils.data import Dataset, DataLoader


class KoBARTSummaryDataset(Dataset):
    """Summarization dataset read from a tab-separated file with
    'news' (source document) and 'summary' (reference) columns."""

    def __init__(self, file, tokenizer, max_len, ignore_index=-100):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.docs = pd.read_csv(file, sep='\t')
        self.len = self.docs.shape[0]
        self.pad_index = self.tokenizer.pad_token_id
        self.ignore_index = ignore_index

    def add_padding_data(self, inputs):
        # Right-pad with pad_token_id up to max_len, or truncate to max_len.
        if len(inputs) < self.max_len:
            pad = np.array([self.pad_index] * (self.max_len - len(inputs)))
            inputs = np.concatenate([inputs, pad])
        else:
            inputs = inputs[:self.max_len]
        return inputs

    def add_ignored_data(self, inputs):
        # Right-pad labels with ignore_index (-100 by default) so the
        # cross-entropy loss skips the padded positions.
        if len(inputs) < self.max_len:
            pad = np.array([self.ignore_index] * (self.max_len - len(inputs)))
            inputs = np.concatenate([inputs, pad])
        else:
            inputs = inputs[:self.max_len]
        return inputs

    def __getitem__(self, idx):
        instance = self.docs.iloc[idx]
        input_ids = self.tokenizer.encode(instance['news'])
        input_ids = self.add_padding_data(input_ids)

        label_ids = self.tokenizer.encode(instance['summary'])
        label_ids.append(self.tokenizer.eos_token_id)
        # Teacher forcing: the decoder input is the label sequence shifted
        # right by one position, with <eos> used as the decoder start token.
        dec_input_ids = [self.tokenizer.eos_token_id]
        dec_input_ids += label_ids[:-1]
        dec_input_ids = self.add_padding_data(dec_input_ids)
        label_ids = self.add_ignored_data(label_ids)

        return {'input_ids': np.array(input_ids, dtype=np.int_),
                'decoder_input_ids': np.array(dec_input_ids, dtype=np.int_),
                'labels': np.array(label_ids, dtype=np.int_)}

    def __len__(self):
        return self.len
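

# Illustrative sketch only (hypothetical token ids, not part of the original
# file): if a summary encodes to [S1, S2, S3] and the eos token id is E,
# __getitem__ above builds, before padding,
#   label_ids      = [S1, S2, S3, E]
#   dec_input_ids  = [E, S1, S2, S3]
# i.e. the decoder input is the label sequence shifted right by one step.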


class KobartSummaryModule(L.LightningDataModule):
    """LightningDataModule wrapping the train/test summarization datasets."""

    def __init__(self, train_file,
                 test_file, tok,
                 max_len=512,
                 batch_size=8,
                 num_workers=4):
        super().__init__()
        self.batch_size = batch_size
        self.max_len = max_len
        self.train_file_path = train_file
        self.test_file_path = test_file
        self.tok = tok
        self.num_workers = num_workers

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = argparse.ArgumentParser(
            parents=[parent_parser], add_help=False)
        parser.add_argument('--num_workers',
                            type=int,
                            default=4,
                            help='number of dataloader workers')
        return parser

    # Called on every GPU/machine; assigning state here is fine.
    def setup(self, stage):
        # Build the train and test datasets from their TSV files.
        self.train = KoBARTSummaryDataset(self.train_file_path,
                                          self.tok,
                                          self.max_len)
        self.test = KoBARTSummaryDataset(self.test_file_path,
                                         self.tok,
                                         self.max_len)

    def train_dataloader(self):
        return DataLoader(self.train,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=True)

    def val_dataloader(self):
        # The test set doubles as the validation set in this setup.
        return DataLoader(self.test,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False)