forked from PaddlePaddle/PaddleNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path gen_utils.py
316 lines (276 loc) · 12.3 KB
/
gen_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from functools import partial
import numpy as np
import paddle
import paddle.distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler
from paddlenlp.data import Pad
def print_args(args):
    """Print every attribute of ``args`` as ``name: value`` lines.

    Attributes come from ``vars(args)`` and are printed in sorted order,
    framed by a header banner and a closing rule.
    """
    header = "----------- Configuration Arguments -----------"
    footer = "------------------------------------------------"
    print(header)
    for entry in sorted(vars(args).items()):
        # ``entry`` is a (name, value) pair, which fills both placeholders.
        print("%s: %s" % entry)
    print(footer)
def set_seed(seed):
    """Seed the Python, NumPy and Paddle random number generators.

    ``random`` and ``numpy.random`` (used for data shuffling) receive the same
    ``seed`` on every process so sharded data stays consistent. Paddle's op
    seed is offset by the process rank from ``paddle.distributed``.
    """
    for seed_data_rng in (random.seed, np.random.seed):
        seed_data_rng(seed)
    # Per-rank op seed: ops such as dropout may then differ between processes.
    rank_offset = dist.get_rank()
    paddle.seed(seed + rank_offset)
def _cls_positions(input_ids, cls_token_id):
    """Return the indices of every ``cls_token_id`` in ``input_ids``, in order."""
    return [index for index, token_id in enumerate(input_ids) if token_id == cls_token_id]


def _template4_token_type_ids(cls_positions, seq_length):
    """Build template-4 segment ids from the positions of its seven [CLS] tokens.

    Positions [cls0, cls1) get type 2, [cls1, cls4) type 3, [cls4, cls6) type 0
    and [cls6, seq_length) type 1. Nothing is emitted before cls0, so the
    result only covers the whole sequence when the first [CLS] is at index 0.
    """
    return (
        [2] * (cls_positions[1] - cls_positions[0])
        + [3] * (cls_positions[4] - cls_positions[1])
        + [0] * (cls_positions[6] - cls_positions[4])
        + [1] * (seq_length - cls_positions[6])
    )


def _apply_template(example, source, title, target, template, tokenizer):
    """Rewrite ``(source, title, target)`` according to prompt ``template``.

    Templates 1-4 fold the answer (``title``) into ``source`` and prefix the
    target, returning ``title`` as None. Any other template id leaves the
    three values unchanged.
    """
    if template in (1, 2):
        source = "答案：" + title + tokenizer.sep_token + "上下文：" + source
        target_prefix = "问题：" if template == 1 else "在已知答案的前提下，问题："
    elif template == 3:
        source = (
            "这是一个问题生成任务，根据提供的答案和上下文，来生成问题。"
            + title
            + tokenizer.sep_token
            + "上下文："
            + source
        )
        target_prefix = "问题："
    elif template == 4:
        sep, cls = tokenizer.sep_token, tokenizer.cls_token
        # Each domain prompt becomes its own " [CLS] prompt [SEP] " segment.
        domain_prompts = "".join(" " + cls + " " + one + " " + sep + " " for one in example["prompt_domain"])
        source = (
            example["prompt_common"]
            + " "
            + sep
            + " "
            + domain_prompts
            + " "
            + cls
            + " "
            + "答案："
            + title
            + " "
            + sep
            + " "
            + cls
            + "上下文："
            + source
        )
        target_prefix = "问题："
    else:
        return source, title, target
    if target:
        target = target_prefix + target
    return source, None, target


def convert_example(
    example, tokenizer, max_seq_len=512, max_target_len=128, max_title_len=256, mode="train", template=0
):
    """Convert one raw example into the features consumed by ``batchify_fn``.

    Args:
        example (dict): raw sample.
            * "pretrain"/"pretrain_test" modes read "context", "answer" and "target".
            * "train"/"test" modes read either "source"+"title" or "context"+"answer",
              plus an optional "target" (or "question").
            * Template 4 also reads "prompt_common" and "prompt_domain".
        tokenizer: tokenizer exposing ``gen_encode``, ``convert_ids_to_tokens``,
            ``sep_token``, ``cls_token`` and ``cls_token_id``.
        max_seq_len (int): forwarded to ``gen_encode``.
        max_target_len (int): forwarded to ``gen_encode`` in training modes.
        max_title_len (int): forwarded to ``gen_encode``.
        mode (str): one of "train", "test", "pretrain", "pretrain_test".
        template (int): prompt template id; templates 1-4 only apply in
            "train"/"test" modes.

    Returns:
        dict: the ``gen_encode`` output. In training modes it additionally holds
        "masked_positions" and "labels". In test modes a non-empty reference
        "target"/"question" from ``example`` is copied to "target". Template 4
        overwrites "token_type_ids".

    Raises:
        ValueError: if ``mode`` is unsupported, or a "train"/"test" example has
            neither source/title nor context/answer.
    """
    if mode in ("pretrain", "pretrain_test"):
        context = example["context"]
        answer = example["answer"]
        target = example["target"]
        source = "答案：" + answer + tokenizer.sep_token + "上下文：" + context
        title = None
    elif mode in ("train", "test"):
        if "source" in example and "title" in example:
            source, title = example["source"], example["title"]
        elif "context" in example and "answer" in example:
            source, title = example["context"], example["answer"]
        else:
            raise ValueError("Source and title are not in the input dictionary, nor are context and answer.")
        target = None
        if "target" in example:
            target = example["target"]
        elif "question" in example:
            target = example["question"]
        source, title, target = _apply_template(example, source, title, target, template, tokenizer)
    else:
        raise ValueError("Unsupported mode: {}".format(mode))

    if mode in ("train", "pretrain"):
        tokenized_example = tokenizer.gen_encode(
            source,
            title=title,
            target=target,
            max_seq_len=max_seq_len,
            max_target_len=max_target_len,
            max_title_len=max_title_len,
            return_position_ids=True,
            return_length=True,
        )
        input_ids = tokenized_example["input_ids"]
        cls_positions = _cls_positions(input_ids, tokenizer.cls_token_id)
        # Template 4 is expected to produce 7 [CLS] tokens, other templates 2.
        # The token dump in the message is only built if the check fails.
        assert len(cls_positions) in (2, 7), (
            str(len(cls_positions))
            + " is not in [2, 7], temp_tokens: "
            + " ".join(tokenizer.convert_ids_to_tokens(input_ids))
            + " source: "
            + source
        )
        if template == 4:
            tokenized_example["token_type_ids"] = _template4_token_type_ids(cls_positions, len(input_ids))
        # The last [CLS] opens the target span; logits at position p are
        # trained to predict the token at p + 1.
        target_start = cls_positions[-1]
        target_end = tokenized_example["seq_len"]
        # Used to gather the logits corresponding to the labels during training.
        tokenized_example["masked_positions"] = list(range(target_start, target_end - 1))
        tokenized_example["labels"] = input_ids[target_start + 1 : target_end]
        return tokenized_example

    # "test" / "pretrain_test": encode the source only and ask gen_encode to
    # append the start token for decoding.
    tokenized_example = tokenizer.gen_encode(
        source,
        title=title,
        max_seq_len=max_seq_len,
        max_title_len=max_title_len,
        add_start_token_for_decoding=True,
        return_position_ids=True,
    )
    if template == 4:
        input_ids = tokenized_example["input_ids"]
        cls_positions = _cls_positions(input_ids, tokenizer.cls_token_id)
        assert len(cls_positions) == 7, str(len(cls_positions)) + " is not in [7]"
        tokenized_example["token_type_ids"] = _template4_token_type_ids(cls_positions, len(input_ids))
    # Keep the reference target (if any) so predictions can be evaluated.
    if "target" in example and example["target"]:
        tokenized_example["target"] = example["target"]
    elif "question" in example and example["question"]:
        tokenized_example["target"] = example["question"]
    return tokenized_example
def batchify_fn(batch_examples, pad_val, mode):
    """Collate encoded examples into left-padded batch arrays.

    Args:
        batch_examples (list[dict]): outputs of ``convert_example``.
        pad_val (int): id used to left-pad the int64 id sequences.
        mode (str): "train"/"pretrain" or "test"/"pretrain_test".

    Returns:
        tuple: ``(input_ids, token_type_ids, position_ids, attention_mask)``.
        Training modes additionally return flattened ``masked_positions`` and
        ``labels``. Any other mode returns None.
    """

    def pad_mask(batch_attention_mask):
        # Start from a float32 cube filled with -1e9 and copy each example's
        # own mask into its bottom-right corner, matching the left padding.
        num_rows = len(batch_attention_mask)
        width = max(len(mask) for mask in batch_attention_mask)
        padded = np.full((num_rows, width, width), -1e9, dtype="float32")
        for square, mask in zip(padded, batch_attention_mask):
            size = len(mask)
            square[-size:, -size:] = np.array(mask, dtype="float32")
        # Insert an axis at dim 1 (n_head of the Transformer) for broadcasting.
        return padded[:, np.newaxis]

    pad_ids = Pad(pad_val=pad_val, pad_right=False, dtype="int64")
    input_ids = pad_ids([ex["input_ids"] for ex in batch_examples])
    token_type_ids = pad_ids([ex["token_type_ids"] for ex in batch_examples])
    position_ids = pad_ids([ex["position_ids"] for ex in batch_examples])
    attention_mask = pad_mask([ex["attention_mask"] for ex in batch_examples])

    if mode in ("test", "pretrain_test"):
        return input_ids, token_type_ids, position_ids, attention_mask
    if mode not in ("train", "pretrain"):
        return None

    longest = max(ex["seq_len"] for ex in batch_examples)
    # Shift each example's positions by its left padding plus its row offset
    # in a flattened (batch * longest) layout. This assumes "seq_len" equals
    # the padded length of the longest example. TODO confirm.
    masked_positions = np.concatenate(
        [
            np.array(ex["masked_positions"]) + (longest - ex["seq_len"]) + row * longest
            for row, ex in enumerate(batch_examples)
        ]
    )
    labels = np.concatenate([np.array(ex["labels"], dtype="int64") for ex in batch_examples])
    return input_ids, token_type_ids, position_ids, attention_mask, masked_positions, labels
def create_data_loader(dataset, tokenizer, args, mode):
    """Map ``dataset`` through ``convert_example`` and wrap it in a DataLoader.

    Args:
        dataset: dataset exposing ``map(fn, lazy=True)``.
        tokenizer: tokenizer passed to ``convert_example``; its ``pad_token_id``
            is used for padding.
        args: namespace providing ``max_seq_len``, ``max_target_len``,
            ``max_title_len``, ``template`` and ``batch_size``.
        mode (str): "train", "pretrain", "test" or "pretrain_test".

    Returns:
        tuple: ``(mapped_dataset, data_loader)``.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    if mode not in ("train", "pretrain", "test", "pretrain_test"):
        raise ValueError("Unsupported mode: {}".format(mode))

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_len=args.max_seq_len,
        max_target_len=args.max_target_len,
        max_title_len=args.max_title_len,
        mode=mode,
        template=args.template,
    )
    dataset = dataset.map(trans_func, lazy=True)

    if mode in ("train", "pretrain"):
        # Training batches are shuffled and split across processes.
        batch_sampler = DistributedBatchSampler(dataset, batch_size=args.batch_size, shuffle=True)
    else:
        # Evaluation uses half the training batch size, in order and unsharded.
        # Clamp to at least 1 so batch_size=1 does not produce an invalid 0.
        batch_sampler = BatchSampler(dataset, batch_size=max(1, args.batch_size // 2), shuffle=False)

    collate_fn = partial(batchify_fn, pad_val=tokenizer.pad_token_id, mode=mode)
    data_loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, return_list=True)
    return dataset, data_loader
def post_process_sum(token_ids, tokenizer):
    """Truncate a decoded id sequence and convert it to merged tokens.

    The sequence is cut before the first ``tokenizer.mask_token_id``, which
    serves as the end marker here; without one, it is kept whole. The kept
    ids are turned into tokens, merged with ``tokenizer.merge_subword``, and
    any "[UNK]" tokens are dropped.

    Returns:
        tuple: ``(kept_token_ids, tokens)``.
    """
    stop = next(
        (pos for pos, tok_id in enumerate(token_ids) if tok_id == tokenizer.mask_token_id),
        len(token_ids),
    )
    kept_ids = token_ids[:stop]
    merged = tokenizer.merge_subword(tokenizer.convert_ids_to_tokens(kept_ids))
    dropped = ("[UNK]",)
    return kept_ids, [tok for tok in merged if tok not in dropped]
def remove_template(instr):
    """Remove the question-template prefix from a decoded sequence.

    Training targets are prefixed with "在已知答案的前提下，问题：" (template 2)
    or "问题：" (templates 1, 3 and 4). If ``instr`` starts with either prefix,
    it is removed once (the longer prefix is checked first). The rest of the
    text is returned untouched.

    Args:
        instr (str): decoded model output.

    Returns:
        str: ``instr`` without its template prefix.
    """
    # Prefix removal, not ``str.strip``: strip() treats its argument as a set
    # of characters and would also eat matching characters at the end of the
    # real question text.
    for prefix in ("在已知答案的前提下，问题：", "问题："):
        if instr.startswith(prefix):
            return instr[len(prefix):]
    return instr
def select_sum(ids, scores, tokenizer, max_dec_len=None, num_return_sequences=1):
    """Decode generated sequences and keep one prediction per source example.

    ``ids`` holds ``num_return_sequences`` consecutive candidates per example.
    Each candidate is decoded with ``post_process_sum`` and its template
    prefix removed with ``remove_template``.
    * With ``scores``: the highest-scoring candidate of each group is kept. A
      candidate whose truncated length reaches ``max_dec_len`` (it did not stop
      on its own) is penalized by 1e3 first.
    * Without ``scores``: the first candidate of each group is kept.

    Args:
        ids: generated token ids; ``.numpy()`` is called on it.
        scores: per-candidate scores (``.numpy()`` is called on it) or None.
        tokenizer: tokenizer passed to ``post_process_sum``.
        max_dec_len (int, optional): length at which a candidate counts as
            unfinished.
        num_return_sequences (int): candidates generated per example.

    Returns:
        list[str]: one decoded prediction per example.

    Raises:
        ValueError: if ``num_return_sequences`` < 1, ``len(ids)`` differs
            from ``len(scores)``, or ``len(ids)`` is not a multiple of
            ``num_return_sequences``.
    """

    def decode(pred):
        # Returns (kept_token_ids, text without template prefix).
        pred_token_ids, pred_tokens = post_process_sum(pred, tokenizer)
        return pred_token_ids, remove_template("".join(pred_tokens))

    if num_return_sequences < 1:
        raise ValueError("`num_return_sequences` must be >= 1, got {}".format(num_return_sequences))
    ids = ids.numpy()
    if scores is not None:
        scores = scores.numpy()
        if len(ids) != len(scores):
            raise ValueError(
                "the length of `ids` is {}, but the length of `scores` is {}".format(len(ids), len(scores))
            )
    # Checked in both paths so an incomplete trailing group is never dropped silently.
    if len(ids) % num_return_sequences != 0:
        raise ValueError(
            "the length of `ids` is {}, but the `num_return_sequences` is {}".format(len(ids), num_return_sequences)
        )

    results = []
    for start in range(0, len(ids), num_return_sequences):
        if scores is None:
            # TODO: Support return scores in FT. Without scores keep the first candidate.
            results.append(decode(ids[start])[1])
            continue
        end = start + num_return_sequences
        candidates = []
        for pred, score in zip(ids[start:end], scores[start:end]):
            pred_token_ids, text = decode(pred)
            # not ending: the candidate hit the length limit
            if max_dec_len is not None and len(pred_token_ids) >= max_dec_len:
                score = score - 1e3
            candidates.append((text, score))
        # max() returns the first of tied maxima, like a stable descending sort.
        results.append(max(candidates, key=lambda cand: cand[1])[0])
    return results