
Commit 7780497

Authored May 19, 2022
[GPT-3] Support generation code for GPT-3 in static graph. (PaddlePaddle#2188)
* Support GPT-3 generation in static graph
* Support batch generation and fix code style
1 parent 1fdf308 commit 7780497

10 files changed, +690 −33 lines changed


‎examples/language_model/gpt

(+1 −1)

@@ -1 +1 @@
-../../model_zoo/gpt
+../../model_zoo/gpt/

‎examples/language_model/gpt-3/README.md

(+11)

@@ -144,5 +144,16 @@ python -u -m paddle.distributed.fleet.launch \
 
 Beyond the hybrid parallelism strategies above, PaddlePaddle also supports recompute, offload, mixed precision, and other strategies to reduce memory usage and speed up training. For more details, see the article: [PaddlePaddle distributed training adds 4D hybrid parallelism for training 100-billion-parameter AI models](https://baijiahao.baidu.com/s?id=1697085717806202673)
 
+### PaddlePaddle ultra-large model deployment
+
+PaddlePaddle tools for deploying ultra-large models:
+
+- Paddle Fleet: PaddlePaddle's adaptive parallel training technology, which also applies to ultra-large model deployment, adaptively partitioning the model for the target inference hardware
+- Paddle Inference: supports model parallelism, pipeline parallelism, and hybrid parallelism, heavily optimized for leading performance
+- Paddle Serving: supports service-based deployment, with automatic batching, fault-tolerant scheduling, service monitoring, and load balancing
+- Paddle Slim: supports quantization and sparse compression of ultra-large models
+
+For a concrete deployment example, see the [GPT-3 ultra-large model deployment tutorial](deploy)
+
 ### References
 - [Language Models are Few-Shot Learners](https://arxiv.org/pdf/2005.14165.pdf)
New file (+11)

@@ -0,0 +1,11 @@
+## Ultra-large model deployment
+
+TBD
+
+### Model export
+
+### Automatic partitioning
+
+### Inference deployment
+
+### Benchmark
‎examples/language_model/gpt-3/static/args.py

(+40 −1)

@@ -160,7 +160,6 @@ def parse_args(MODEL_CLASSES):
         type=int,
         default=10,
         help="Evaluate the model use X steps data.")
-
     # Config for 4D Parallelism
     parser.add_argument(
         "--use_sharding",
@@ -258,6 +257,46 @@ def parse_args(MODEL_CLASSES):
         default=None,
         help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
     )
+    parser.add_argument(
+        "--max_dec_len",
+        type=int,
+        default=20,
+        help="The maximum length of decoded sequence.", )
+    parser.add_argument(
+        "--decoding_strategy",
+        type=str,
+        default="topk_sampling",
+        choices=["topk_sampling", "topp_sampling", "sampling"],
+        help="The decoding strategy, not support beam_search now!", )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=1.,
+        help="The temperature in each generation step.")
+    # top-k sampling
+    parser.add_argument(
+        "--topk",
+        type=int,
+        default=10,
+        help="The hyper-parameter in top-k sampling..")
+    # top-p sampling
+    parser.add_argument(
+        "--topp",
+        type=float,
+        default=0.9,
+        help="The hyper-parameter in top-p sampling.")
+    # beam search
+    parser.add_argument(
+        "--beam_size",
+        type=int,
+        default=1,
+        help="The hyper-parameter in beam search.")
+    parser.add_argument(
+        "--save_inference_model_then_exist",
+        type=bool,
+        default=False,
+        help="save_inference_model_then_exist")
+
     args = parser.parse_args()
     args.test_iters = args.eval_iters * 10

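A side note on the new `--save_inference_model_then_exist` flag: argparse applies `type=bool` to the raw command-line string, and `bool()` of any non-empty string (including `"false"`) is `True`. Below is a hedged sketch of the common workaround, using a hypothetical `str2bool` helper that is not part of this commit:

```python
import argparse


def str2bool(v):
    # Convert common string spellings of a boolean into a real bool;
    # argparse's type=bool would treat any non-empty string as True.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--save_inference_model_then_exist",
    type=str2bool,
    default=False,
    help="Whether to save the inference model and exit.")

print(parser.parse_args(["--save_inference_model_then_exist", "false"]))
# Namespace(save_inference_model_then_exist=False)
```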
‎examples/language_model/gpt-3/static/dataset.py

(+29 −18)

@@ -148,7 +148,7 @@ def _num_tokens(documents, lens):
 
 
 def _num_epochs(tokens_per_epoch, seq_length, num_samples):
-    """Based on number of samples and sequence lenght, calculate how many
+    """Based on number of samples and sequence length, calculate how many
     epochs will be needed."""
     num_epochs = 0
     total_tokens = 0
@@ -256,18 +256,17 @@ def get_train_valid_test_split_(splits_string, size):
     return splits_index
 
 
-def create_pretrained_dataset(
-        args,
-        input_path,
-        local_rank,
-        data_world_rank,
-        data_world_size,
-        eos_id,
-        worker_init=None,
-        max_seq_len=1024,
-        places=None,
-        data_holders=None,
-        pipeline_mode=False, ):
+def create_pretrained_dataset(args,
+                              input_path,
+                              local_rank,
+                              data_world_rank,
+                              data_world_size,
+                              eos_id,
+                              worker_init=None,
+                              max_seq_len=1024,
+                              places=None,
+                              data_holders=None,
+                              pipeline_mode=False):
 
     if local_rank == 0:
         start_time = time.time()
@@ -339,7 +338,8 @@ def build_dataset(index, name, num_samples):
            sample_lens=sample_lens,
            eos_id=eos_id,
            seed=args.seed,
-           use_pure_fp16=args.use_amp and args.amp_level == "O2")
+           use_pure_fp16=args.use_amp and args.amp_level == "O2",
+           data_holders=data_holders)
        batch_sampler = DistributedBatchSampler(
            dataset,
            batch_size=args.micro_batch_size,
@@ -361,14 +361,16 @@ def data_gen():
        data_loader.set_sample_generator(
            data_gen, batch_size=args.micro_batch_size, places=places)
    else:
+       stacks = (Stack(), ) * len(data_holders)
+       collate_fn = Tuple(*stacks)
        data_loader = DataLoader(
            dataset=dataset,
            places=places,
            feed_list=data_holders,
            batch_sampler=batch_sampler,
            num_workers=1,
            worker_init_fn=worker_init,
-           collate_fn=Tuple(Stack(), Stack(), Stack(), Stack()),
+           collate_fn=collate_fn,
            return_list=False)
    return data_loader
 
@@ -401,7 +403,8 @@ def __init__(self,
                 name="gpt",
                 max_seq_len=1024,
                 seed=1234,
-                use_pure_fp16=False):
+                use_pure_fp16=False,
+                data_holders=None):
        self.file_prefix = file_prefix
        self.max_seq_len = max_seq_len
        self.name = name
@@ -410,6 +413,7 @@ def __init__(self,
        self.sample_lens = sample_lens
        self.micro_batch_size = micro_batch_size
        self.use_pure_fp16 = use_pure_fp16
+       self.data_holders = data_holders
 
        if documents is None:
            document_ids = np.arange(0, self.sample_lens.shape[0])
@@ -435,10 +439,17 @@ def _construct_sample(self, tokens):
        else:
            loss_mask = np.ones(seq_length, dtype="float32")
            loss_mask[np.where(np.array(tokens) == self.eos_id)] = 0.0
-       position_ids = np.arange(0, seq_length, dtype="int64")
 
+       position_ids = np.arange(0, seq_length, dtype="int64")
        labels = np.array(labels, dtype="int64")
-       return [tokens, loss_mask, position_ids, labels]
+       if len(self.data_holders) == 4:
+           return [tokens, loss_mask, position_ids, labels]
+       elif len(self.data_holders) == 3:
+           return [tokens, loss_mask, position_ids]
+       else:
+           assert len(self.data_holders) == 1, \
+               "length of daat_holders should be 4, 3 or 1"
+           return [tokens]
 
    def _get_single_sample_from_idx(self, doc_index_f, doc_index_l, offset_f,
                                    offset_l):

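The `collate_fn` change above builds one `Stack` per data holder instead of a fixed four-field `Tuple`. A rough NumPy sketch of what that per-field stacking does to a batch (toy arrays standing in for tokens, loss_mask, and position_ids; not the real paddlenlp.data classes):

```python
import numpy as np

# Each sample is a list of fields; collation stacks each field across the
# batch into one array, producing one batched tensor per data holder.
samples = [
    [np.array([1, 2, 3]), np.ones(3, dtype="float32"), np.arange(3)],
    [np.array([4, 5, 6]), np.ones(3, dtype="float32"), np.arange(3)],
]

num_fields = len(samples[0])                         # mirrors len(data_holders)
stacks = tuple(np.stack for _ in range(num_fields))  # one "Stack" per field
batch = [fn([s[i] for s in samples]) for i, fn in enumerate(stacks)]

for field in batch:
    print(field.shape)  # (2, 3) for each of the three fields
```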
‎examples/language_model/gpt-3/static/modeling.py

(+277 −12)

@@ -29,11 +29,13 @@
 import paddlenlp
 
 __all__ = [
-    'GPTModel',
-    'GPTForPretraining',
-    'GPTPretrainingCriterion',
+    'GPTModel', 'GPTForPretraining', 'GPTPretrainingCriterion',
+    'GPTForGeneration'
 ]
 
+device = "gpu"
+int_type = "int64"
+
 
 class MultiHeadAttention(nn.Layer):
     """
@@ -153,6 +155,11 @@ def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
             # for decoder self-attention in inference
             k = tensor.concat([cache.k, k], axis=2)
             v = tensor.concat([cache.v, v], axis=2)
+
+            ## if not assign here, assign in While loop
+            #layers.assign(k, cache.k)  # update caches
+            #layers.assign(v, cache.v)
+
         if use_cache is True:
             cache = self.Cache(k, v)
 
@@ -229,7 +236,12 @@ def forward(self,
         product = layers.matmul(
             x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
 
-        weights = incubate.softmax_mask_fuse_upper_triangle(product)
+        if self.training:
+            weights = incubate.softmax_mask_fuse_upper_triangle(product)
+        else:
+            if attn_mask is not None:
+                product = product + attn_mask
+            weights = F.softmax(product)
 
         if self.dropout:
             weights = F.dropout(
@@ -311,12 +323,20 @@ def forward(self,
                                        cache=cache)
 
             else:
-                output, new_cache = mod(output,
-                                        memory,
-                                        tgt_mask=tgt_mask,
-                                        use_cache=use_cache,
-                                        cache=cache[i])
-                new_caches.append(new_cache)
+                if use_cache:
+                    output, new_cache = mod(output,
+                                            memory,
+                                            tgt_mask=tgt_mask,
+                                            use_cache=use_cache,
+                                            cache=cache[i])
+                    new_caches.append(new_cache)
+                else:
+                    output = mod(output,
+                                 memory,
+                                 tgt_mask=tgt_mask,
+                                 use_cache=use_cache,
+                                 cache=cache[i])
+
             self.checkpoints.append(output.name)
 
         if self.norm is not None:
@@ -675,6 +695,9 @@ def __init__(self,
         self.topo = topo
         self.hidden_size = hidden_size
         self.vocab_size = vocab_size
+        self.num_attention_heads = num_attention_heads
+        self.num_hidden_layers = num_hidden_layers
+        self.hidden_size = hidden_size
 
         self.pipline_mode = topo is not None and topo.pp_info.size > 1
         if self.pipline_mode:
@@ -738,16 +761,18 @@ def forward(self,
                 paddle.shape(input_ids)[-1] + past_length,
                 dtype='int64')
             position_ids = position_ids.unsqueeze(0)
-            # .expand_as(input_ids)
             position_ids = paddle.fluid.layers.expand_as(position_ids,
                                                          input_ids)
         embedding_output = self.embeddings(
             input_ids=input_ids, position_ids=position_ids)
 
+        tgt_mask = None
+        if not self.training:
+            tgt_mask = attention_mask
         encoder_outputs = self.decoder(
             embedding_output,
             memory=None,
-            tgt_mask=None,
+            tgt_mask=tgt_mask,
             use_cache=use_cache,
             cache=cache)
         self.checkpoints.extend(self.decoder.checkpoints)
@@ -830,3 +855,243 @@ def forward(self, prediction_scores, masked_lm_labels, loss_mask):
         masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask)
         loss = masked_lm_loss / loss_mask.sum()
         return loss
+
+
+class GPTForGeneration(GPTPretrainedModel):
+    def __init__(self,
+                 gpt,
+                 max_length=20,
+                 min_length=0,
+                 decoding_strategy='sampling',
+                 temperature=1.0,
+                 top_k=0,
+                 top_p=1.0,
+                 eos_id=None):
+        super(GPTForGeneration, self).__init__()
+        self.gpt = gpt
+        self.apply(self.init_weights)
+        self.vocab_size = gpt.vocab_size
+        self.eos_token_id = eos_id or 7
+
+        self.min_dec_len = min_length
+        self.max_dec_len = max_length
+        self.decoding_strategy = decoding_strategy
+        self.temperature = temperature
+        self.topk = top_k
+        self.topp = top_p
+        self._fuse = False
+        self._init_gen_cache = False
+        self.generation_caches = []
+        self._dtype = "float32"
+
+    def _init_generation_caches(self, src_ids):
+        if self._init_gen_cache:
+            return self.generation_caches
+
+        num_heads = self.gpt.num_attention_heads
+        num_layers = self.gpt.num_hidden_layers
+        mp_n_head = num_heads // self.gpt.topo.mp_info.size
+        hidden_size = self.gpt.hidden_size
+        head_size = hidden_size // num_heads
+        for i in range(num_layers):
+            k = layers.fill_constant_batch_size_like(
+                input=src_ids,
+                shape=[-1, mp_n_head, 0, head_size],
+                dtype=self._dtype,
+                value=0)
+            v = layers.fill_constant_batch_size_like(
+                input=src_ids,
+                shape=[-1, mp_n_head, 0, head_size],
+                dtype=self._dtype,
+                value=0)
+            self.generation_caches.append(MultiHeadAttention.Cache(k, v))
+        self._init_gen_cache = True
+        return self.generation_caches
+
+    def parallel_matmul(self, lm_output, logit_weights, parallel_output, topo):
+        if topo is not None and topo.mp_info.size > 1:
+            input_parallel = paddle.distributed.collective._c_identity(
+                lm_output, group=None)
+
+            logits = paddle.matmul(
+                input_parallel, logit_weights, transpose_y=True)
+
+            if parallel_output:
+                return logits
+
+            # TODO(qinqing): collective._c_concat is not support in static graph now
+            return paddle.distributed.collective._c_concat(logits, group=None)
+        else:
+            logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
+            return logits
+
+    def topk_sampling(self, probs):
+        topk_probs, _ = paddle.topk(probs, self.topk)
+        ge_cond = paddle.cast(
+            paddle.greater_equal(probs,
+                                 paddle.unsqueeze(topk_probs[:, -1], [1])),
+            "float32")
+        old_probs = probs
+        probs = probs * ge_cond / paddle.sum(topk_probs, axis=-1, keepdim=True)
+        sampling_ids = layers.sampling_id(probs, dtype="int")
+        probs = old_probs
+        return probs, sampling_ids
+
+    def topp_sampling(self, probs):
+        sorted_probs, sorted_idx = layers.argsort(probs, descending=True)
+        cum_sorted_probs = layers.cumsum(sorted_probs, axis=1, exclusive=True)
+        lt_cond = paddle.cast(
+            paddle.less_than(cum_sorted_probs,
+                             layers.fill_constant_batch_size_like(
+                                 cum_sorted_probs, cum_sorted_probs.shape,
+                                 cum_sorted_probs.dtype, self.topp)), "float32")
+        old_probs = probs
+        candidate_probs = sorted_probs * lt_cond
+        probs = candidate_probs / paddle.sum(candidate_probs,
+                                             axis=-1,
+                                             keep_dim=True)
+        sampling_ids = layers.sampling_id(probs, dtype="int")
+        sampling_ids = paddle.index_sample(sorted_idx,
+                                           paddle.unsqueeze(sampling_ids, [1]))
+        sampling_ids = paddle.squeeze(sampling_ids, [1])
+        probs = old_probs
+        return probs, sampling_ids
+
+    def model(self,
+              input_ids,
+              position_ids=None,
+              attention_mask=None,
+              masked_positions=None,
+              use_cache=False,
+              cache=None):
+        outputs = self.gpt(input_ids,
+                           position_ids=position_ids,
+                           attention_mask=attention_mask,
+                           use_cache=use_cache,
+                           cache=cache)
+        if use_cache:
+            encoder_outputs, cached_kvs = outputs[:2]
+        else:
+            encoder_outputs = outputs
+        logits = self.parallel_matmul(
+            encoder_outputs, self.gpt.embeddings.word_embeddings.weight, False,
+            self.gpt.topo)
+        if use_cache:
+            return logits, cached_kvs
+        else:
+            return logits
+
+    def forward(self, inputs, use_cache=False, cache=None):
+        """
+        Args:
+            inputs (dict): include src_ids.
+                pos_ids, input_mask and max_dec_len are optional.
+        """
+        ######### forward context #########
+        input_ids = inputs['src_ids']
+        position_ids = inputs['pos_ids'] if 'pos_ids' in inputs else None
+        attention_mask = inputs[
+            'input_mask'] if 'input_mask' in inputs else None
+
+        causal_mask = paddle.tensor.triu(
+            paddle.ones((paddle.shape(input_ids)[-1],
+                         paddle.shape(input_ids)[-1])) * -1e4,
+            diagonal=1)
+        if attention_mask is not None:
+            tgt_pos = paddle.sum(attention_mask, axis=-1,
+                                 keepdim=True).astype('int64')
+            if len(attention_mask.shape) == 2:
+                attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2])
+            encode_mask = attention_mask + causal_mask
+        else:
+            encode_mask = causal_mask
+
+        # if cached_kvs are assigned to next step in _prepare_qkv of MultiHeadAttention,
+        # need to init the global caches here
+        #gen_caches = self._init_generation_caches(input_ids)
+
+        logits, cached_kvs = self.model(
+            input_ids, position_ids, encode_mask, use_cache=True)
+
+        next_id = paddle.argmax(logits[:, -1, :], axis=-1).reshape([-1, 1])
+        ####################################
+
+        if 'max_dec_len' not in inputs:
+            max_len = layers.fill_constant(
+                [1], dtype=int_type, value=self.max_dec_len, force_cpu=True)
+        else:
+            max_len = inputs['max_dec_len']
+        min_len = layers.fill_constant(
+            shape=[1], dtype=int_type, value=self.min_dec_len, force_cpu=True)
+        step_idx = layers.fill_constant(
+            shape=[1], value=0, dtype='int64', force_cpu=True)
+
+        placehold_ids = layers.fill_constant_batch_size_like(
+            input=inputs["src_ids"],
+            value=0,
+            shape=[-1, 1],
+            dtype=next_id.dtype)
+        ids = layers.array_write(next_id, step_idx)
+
+        if 'max_dec_len' in inputs:
+            max_len = paddle.tensor.creation._memcpy(
+                max_len, place=paddle.CPUPlace())
+        cond_int = paddle.full([1], 0, dtype=int_type, name="cond_int")
+        cond = paddle.less_than(step_idx, max_len)
+
+        if attention_mask is not None:
+            append_mask = layers.fill_constant_batch_size_like(
+                input=next_id,
+                value=1,
+                shape=[-1, 1, 1, 1],
+                dtype=attention_mask.dtype)
+
+        while_op = layers.While(cond, is_test=True)
+        with while_op.block():
+            pre_ids = layers.array_read(array=ids, i=step_idx)
+            if attention_mask:
+                decode_mask = paddle.concat(
+                    [attention_mask, append_mask], axis=-1)
+                tgt_pos = tgt_pos + step_idx
+                att_mask = (1 - decode_mask) * -1e4
+            else:
+                att_mask = None
+                tgt_pos = None
+
+            layers.increment(x=step_idx, value=1.0, in_place=True)
+            layers.array_write(placehold_ids, i=step_idx, array=ids)
+
+            logits, decode_cached_kvs = self.model(
+                pre_ids, tgt_pos, att_mask, use_cache=True, cache=cached_kvs)
+
+            logits = paddle.reshape(logits, shape=(-1, self.vocab_size))
+            probs = F.softmax(logits / self.temperature)
+
+            if self.decoding_strategy.startswith("sampling"):
+                sampling_ids = layers.sampling_id(probs, dtype="int")
+            elif self.decoding_strategy.startswith("topk_sampling"):
+                probs, sampling_ids = self.topk_sampling(probs)
+            elif self.decoding_strategy.startswith("topp_sampling"):
+                probs, sampling_ids = self.topp_sampling(probs)
+            else:
+                raise ValueError(self.decoding_strategy)
+
+            selected_ids = paddle.unsqueeze(sampling_ids, -1)
+            layers.array_write(selected_ids, i=step_idx, array=ids)
+
+            length_cond = paddle.less_than(
+                x=step_idx, y=max_len, name="length_cond")
+            finish_cond = paddle.logical_not(
+                paddle.is_empty(x=selected_ids), name="finish_cond")
+            paddle.logical_and(
+                x=length_cond, y=finish_cond, out=cond, name="logical_and_cond")
+
+            paddle.assign(layers.cast(cond, dtype='bool'), cond)
+            if attention_mask:
+                paddle.assign(decode_mask, attention_mask)
+            for i in range(len(decode_cached_kvs)):
+                paddle.assign(decode_cached_kvs[i].k, cached_kvs[i].k)
+                paddle.assign(decode_cached_kvs[i].v, cached_kvs[i].v)
+
+        ids, _ = layers.tensor_array_to_tensor(ids)
+        return ids
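For reference, the filtering performed by `topk_sampling` and `topp_sampling` above can be written more plainly outside the static graph. A rough NumPy sketch under those assumptions (the helper names here are illustrative only and not part of the commit):

```python
import numpy as np


def topk_filter(probs, k):
    # Keep only the k most likely tokens per row, then renormalize
    # (mirrors the greater_equal mask in topk_sampling).
    kth = np.sort(probs, axis=-1)[:, -k][:, None]
    keep = (probs >= kth).astype(probs.dtype)
    filtered = probs * keep
    return filtered / filtered.sum(axis=-1, keepdims=True)


def topp_filter(probs, p):
    # Keep the smallest prefix of tokens (sorted by probability) whose
    # exclusive cumulative mass stays below p, then renormalize
    # (mirrors topp_sampling, which uses an exclusive cumsum).
    order = np.argsort(-probs, axis=-1)
    sorted_probs = np.take_along_axis(probs, order, axis=-1)
    cum = np.cumsum(sorted_probs, axis=-1) - sorted_probs  # exclusive cumsum
    keep = (cum < p).astype(probs.dtype)
    filtered = np.zeros_like(probs)
    np.put_along_axis(filtered, order, sorted_probs * keep, axis=-1)
    return filtered / filtered.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 8))
probs = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
next_ids = [rng.choice(8, p=row) for row in topk_filter(probs, k=3)]
print(next_ids)  # one sampled token id per row
```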
New file (+31)

@@ -0,0 +1,31 @@
+set -x
+export PADDLE_WITH_GLOO=0
+export FLAGS_call_stack_level=2
+export FLAGS_allocator_strategy=naive_best_fit
+unset CUDA_VISIBLE_DEVICES
+
+rm -rf main_sharding*
+
+task_name="gpt-generation"
+rm -rf output/$task_name/log
+
+python -u -m paddle.distributed.fleet.launch \
+    --gpus "0" \
+    --log_dir "output/$task_name/log" run_generation.py \
+    --model_type "gpt" \
+    --model_name_or_path "gpt2-medium-en" \
+    --input_dir "./data" \
+    --output_dir "output/$task_name" \
+    --max_seq_len 1024 \
+    --micro_batch_size 2 \
+    --global_batch_size 2 \
+    --sharding_degree 1 \
+    --mp_degree 1 \
+    --dp_degree 1 \
+    --pp_degree 1 \
+    --max_dec_len 20 \
+    --decoding_strategy 'topk_sampling' \
+    --topp 0.9 \
+    --save_inference_model_then_exist true \
+    --device "gpu"
+
New file (+288)

@@ -0,0 +1,288 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+import random
+import math
+import time
+import numpy as np
+
+os.path.expandvars('$HOME')
+os.path.expanduser('~')
+
+import paddle
+import paddle.distributed.fleet as fleet
+
+from paddlenlp.transformers import GPTTokenizer, GPTChineseTokenizer
+from paddlenlp.ops import guard, Topology, get_rng_state_tracker
+from paddlenlp.utils.log import logger
+import paddlenlp.ops as ops
+
+from paddle.distributed import init_parallel_env
+
+from modeling import GPTModel, GPTForPretraining, GPTForGeneration
+
+# Used to load the data_tools path, should import before dataset
+filepath = os.path.abspath(os.path.dirname(__file__))
+sys.path.insert(0, os.path.join(filepath, "../../"))
+from dataset import create_pretrained_dataset
+from args import parse_args
+import lr
+
+MODEL_CLASSES = {
+    "gpt": (GPTForGeneration, GPTTokenizer),
+    "gpt-cn": (GPTForGeneration, GPTChineseTokenizer),
+}
+
+USE_LOCAL_HPI = True
+
+device = "gpu"
+ascend = False
+int_type = "int64"
+device_id = int(os.environ.get('FLAGS_selected_gpus', 0))
+
+# yapf: enable.
+
+
+def create_data_holder(args):
+    shapes = [[-1, -1], [-1, -1], [-1, -1]]
+    dtypes = [int_type, 'float32', int_type]
+    names = ['src_ids', 'input_mask', 'pos_ids']  # three inputs
+    #names = ['src_ids'] # one input
+
+    inputs = [
+        paddle.static.data(
+            name=names[i], shape=shapes[i], dtype=dtypes[i])
+        for i in range(len(names))
+    ]
+    return inputs
+
+
+def debug_program(name, program):
+    with open("{}.txt.{}".format(name, device_id), 'w') as f:
+        f.write(str(program))
+
+
+def get_data_file(args):
+    files = [
+        os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
+        if (os.path.isfile(os.path.join(args.input_dir, f)) and str(f).endswith(
+            "_idx.npz"))
+    ]
+    files = [x.replace("_idx.npz", "") for x in files]
+    if len(files) == 0:
+        logger.warning(
+            "Not found dataset with name of xxx_ids.npy and xxx_idx.npz! \
+            Try to found old compatible xxx_ids.npz file.")
+    else:
+        return files
+    files = [
+        os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
+        if (os.path.isfile(os.path.join(args.input_dir, f)) and str(f).endswith(
+            "_ids.npz"))
+    ]
+    files = [x.replace("_ids.npz", "") for x in files]
+    return files
+
+
+def init_static_with_params(model, dygraph_params, topo, prog=None):
+    from paddlenlp.utils.tools import dygraph_params_to_static
+    static_params = dygraph_params_to_static(model, dygraph_params, topo)
+    if prog is None:
+        prog = paddle.static.default_main_program()
+    paddle.static.set_program_state(prog, static_params)
+
+
+def do_generation(args):
+    # Initialize the paddle and paddle fleet execute environment
+    paddle.enable_static()
+
+    strategy = fleet.DistributedStrategy()
+    strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
+    fleet.init(is_collective=True, strategy=strategy)
+
+    group = paddle.distributed.init_parallel_env()
+
+    # Create the random seed for the worker
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    paddle.seed(args.seed)
+    get_rng_state_tracker().add('global_seed', args.seed)
+    get_rng_state_tracker().add('local_seed',
+                                args.seed + fleet.worker_index() + 2021)
+
+    if args.use_amp and args.amp_level == "O2":
+        assert (args.mp_degree == 1 and args.pp_degree == 1
+                ), "When amp level is O2, mp_degree and pp_degree should be 1."
+        assert (args.use_sharding == False
+                ), "When amp level is O2, use_sharding should be False."
+
+    assert args.device in [
+        "cpu", "gpu", "xpu"
+    ], "Invalid device! Available device should be cpu, gpu, or xpu."
+    place = paddle.set_device(args.device)
+
+    worker_num = fleet.worker_num()
+    worker_index = fleet.worker_index()
+    local_rank = 0 if fleet.local_rank() is None else int(fleet.local_rank())
+
+    topo = Topology(
+        device_rank=worker_index,
+        world_size=worker_num,
+        dp_degree=args.dp_degree,
+        pp_degree=args.pp_degree,
+        sharding_degree=args.sharding_degree,
+        mp_degree=args.mp_degree)
+
+    logger.info("The topo of hybrid parallelism:\n{}".format(topo))
+
+    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    pretrained_models_list = list(
+        model_class.pretrained_init_configuration.keys())
+
+    data_file = get_data_file(args)
+    main_program = paddle.static.default_main_program()
+    startup_program = paddle.static.default_startup_program()
+    with paddle.static.program_guard(main_program, startup_program):
+        with paddle.utils.unique_name.guard():
+            with paddle.static.device_guard('gpu:0'):
+                feeds = create_data_holder(args)
+                tokenizer = tokenizer_class.from_pretrained(
+                    args.model_name_or_path)
+                eos_id = tokenizer.eos_token_id
+
+                _, _, test_data_loader = create_pretrained_dataset(
+                    args,
+                    data_file,
+                    local_rank=local_rank,
+                    data_world_size=topo.data_info.size,
+                    data_world_rank=topo.data_info.rank,
+                    eos_id=eos_id,
+                    max_seq_len=args.max_seq_len,
+                    places=paddle.static.cuda_places(),
+                    data_holders=feeds,
+                    pipeline_mode=False)
+
+                if args.model_name_or_path in pretrained_models_list:
+                    model_config = model_class.pretrained_init_configuration[
+                        args.model_name_or_path]
+                    model_config[
+                        "hidden_dropout_prob"] = args.hidden_dropout_prob
+                    model_config[
+                        "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
+                    model_config["topo"] = topo
+                    model = GPTForGeneration(
+                        GPTModel(**model_config),
+                        max_length=args.max_dec_len,
+                        decoding_strategy=args.decoding_strategy,
+                        temperature=args.temperature,
+                        top_k=args.topk,
+                        top_p=args.topp,
+                        eos_id=eos_id)
+                else:
+                    logger.error("No checkpoint load.")
+                model.eval()
+                ins = {v.name: v for v in feeds}
+                preds = model(ins)
+
+    # Define the Executor for running the static model
+    exe = paddle.static.Executor(place)
+    exe.run(startup_program)
+    main_program = main_program.clone(for_test=True)
+
+    model_urls = model.pretrained_resource_files_map['model_state']
+    model_path = args.model_name_or_path
+    if model_path in pretrained_models_list and model_path in model_urls:
+        flag_loaded = False
+        from paddle.utils.download import get_weights_path_from_url
+        dygraph_path = get_weights_path_from_url(model_urls[model_path])
+        if os.path.exists(dygraph_path):
+            if args.sharding_degree > 1:
+                logger.warning("Sharding should init with static vars")
+            else:
+                logger.info("Loading parameters from %s" % dygraph_path)
+                init_static_with_params(
+                    model,
+                    paddle.load(
+                        dygraph_path, return_numpy=True),
+                    topo,
+                    main_program)
+                flag_loaded = True
+        if not flag_loaded:
+            logger.error("No checkpoint load.")
+
+    global_step = 0
+    epoch = 0
+    fetchs = [preds]
+
+    ### check resutls
+    text = [
+        "Question: Where is the capital of China? Answer:",
+        "Question:Who is the CEO of Apple? Answer:"
+    ]
+    inputs = tokenizer(
+        text,
+        padding=True,
+        return_attention_mask=True,
+        return_position_ids=True)
+    ids = np.array(inputs["input_ids"]).reshape(len(text), -1).astype('int64')
+    position_ids = np.array(inputs["position_ids"]).reshape(len(text),
+                                                            -1).astype('int64')
+    attention_mask = np.array(inputs["attention_mask"]).reshape(
+        len(text), -1).astype('float32')
+
+    t_ids = paddle.fluid.core.Tensor()
+    t_ids.set(ids, place)
+    t_mask = paddle.fluid.core.Tensor()
+    t_mask.set(attention_mask, place)
+    t_pos = paddle.fluid.core.Tensor()
+    t_pos.set(position_ids, place)
+    feed_data = {'src_ids': t_ids, 'pos_ids': t_pos, 'input_mask': t_mask}
+    ret = exe.run(main_program, feed=feed_data, fetch_list=fetchs)
+    ret = np.array(ret[0])
+    for i in range(ret.shape[0]):
+        o = [int(x) for x in ret[i]]
+        ret_str = tokenizer.convert_ids_to_string(o)
+        ret_str = text[i] + ret_str
+        logger.info(ret_str)
+    ##################
+
+    for step, batch in enumerate(test_data_loader()):
+        ret = exe.run(main_program, feed=batch, fetch_list=fetchs)
+        if step == 5:
+            break
+
+    if args.save_inference_model_then_exist:
+        save_inference_model_dir = 'inference_model_pp{pp_degree}mp{mp_degree}'.format(
+            pp_degree=args.pp_degree, mp_degree=args.mp_degree)
+        inference_save_path = os.path.join(save_inference_model_dir,
+                                           'rank_' + str(fleet.worker_index()),
+                                           'step_' + str(0))
+        print("saving inference models to {}".format(inference_save_path))
+        feed_names = [v.name for v in feeds]
+        fetchs_names = [v.name for v in fetchs]
+        print('feeds: ', feed_names, 'fetches: ', fetchs_names)
+        paddle.static.save_inference_model(
+            inference_save_path, feeds, fetchs, exe, program=main_program)
+
+
+if __name__ == '__main__':
+    args = parse_args(MODEL_CLASSES)
+    do_generation(args)

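When `--save_inference_model_then_exist` is set, the script above exports a static-graph inference model. A minimal sketch of loading such an export later with `paddle.static.load_inference_model`; the path prefix and the dummy inputs below are hypothetical, and the feed order is assumed to match the holders created by `create_data_holder`:

```python
import numpy as np
import paddle

# Sketch only: point path_prefix at whatever save_inference_model wrote,
# e.g. "inference_model_pp1mp1/rank_0/step_0" (hypothetical location).
paddle.enable_static()
exe = paddle.static.Executor(paddle.CUDAPlace(0))

path_prefix = "inference_model_pp1mp1/rank_0/step_0"  # hypothetical
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    path_prefix, exe)

# Dummy inputs shaped like the feeds created by create_data_holder:
# src_ids / pos_ids are int64, input_mask is float32.
src_ids = np.array([[464, 3139, 286, 2807, 318]], dtype="int64")  # example ids
pos_ids = np.arange(src_ids.shape[1], dtype="int64").reshape(1, -1)
input_mask = np.ones_like(src_ids, dtype="float32")

feed = dict(zip(feed_names, [src_ids, input_mask, pos_ids]))
out = exe.run(program, feed=feed, fetch_list=fetch_targets)
print(np.array(out[0]))  # generated token ids
```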
‎examples/language_model/gpt-3/static/run_static.sh

(+1 −1)

@@ -22,7 +22,7 @@ python -u -m paddle.distributed.fleet.launch \
     --max_seq_len 1024 \
     --micro_batch_size 8 \
     --global_batch_size 16 \
-    --sharding_degree 2\
+    --sharding_degree 2 \
     --mp_degree 2 \
     --dp_degree 1 \
     --pp_degree 1 \

‎model_zoo/gpt/args.py

(+1)

@@ -258,6 +258,7 @@ def parse_args(MODEL_CLASSES):
         default=None,
         help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
     )
+
     args = parser.parse_args()
     args.test_iters = args.eval_iters * 10
 