 import paddlenlp

 __all__ = [
-    'GPTModel',
-    'GPTForPretraining',
-    'GPTPretrainingCriterion',
+    'GPTModel', 'GPTForPretraining', 'GPTPretrainingCriterion',
+    'GPTForGeneration'
 ]

+device = "gpu"
+int_type = "int64"
+

 class MultiHeadAttention(nn.Layer):
     """
@@ -153,6 +155,11 @@ def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
             # for decoder self-attention in inference
             k = tensor.concat([cache.k, k], axis=2)
             v = tensor.concat([cache.v, v], axis=2)
+
+            # if not assigned here, the caches are assigned inside the While loop
+            # layers.assign(k, cache.k)  # update caches
+            # layers.assign(v, cache.v)
+
         if use_cache is True:
             cache = self.Cache(k, v)

@@ -229,7 +236,12 @@ def forward(self,
         product = layers.matmul(
             x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)

-        weights = incubate.softmax_mask_fuse_upper_triangle(product)
+        if self.training:
+            weights = incubate.softmax_mask_fuse_upper_triangle(product)
+        else:
+            if attn_mask is not None:
+                product = product + attn_mask
+            weights = F.softmax(product)

         if self.dropout:
             weights = F.dropout(
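
The eval branch above swaps the fused causal softmax for an explicit additive mask followed by a plain softmax. A minimal dygraph sketch of why the two formulations agree for purely causal masking (toy shapes, with -1e4 standing in for -inf; illustrative only, not part of this diff):

    import paddle
    import paddle.nn.functional as F

    seq_len = 4
    scores = paddle.rand([1, 2, seq_len, seq_len])       # [batch, heads, seq, seq]

    # additive causal mask: -1e4 above the diagonal, 0 elsewhere
    causal = paddle.tensor.triu(
        paddle.ones([seq_len, seq_len]) * -1e4, diagonal=1)

    # eval-style path in this diff: add the mask, then softmax
    weights_masked = F.softmax(scores + causal)

    # reference: softmax restricted to the allowed (lower-triangular) positions
    ref = F.softmax(scores) * paddle.tril(paddle.ones([seq_len, seq_len]))
    ref = ref / ref.sum(axis=-1, keepdim=True)

    print(paddle.allclose(weights_masked, ref, atol=1e-3))  # True up to the -1e4 approximation
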
@@ -311,12 +323,20 @@ def forward(self,
                                         cache=cache)

             else:
-                output, new_cache = mod(output,
-                                        memory,
-                                        tgt_mask=tgt_mask,
-                                        use_cache=use_cache,
-                                        cache=cache[i])
-                new_caches.append(new_cache)
+                if use_cache:
+                    output, new_cache = mod(output,
+                                            memory,
+                                            tgt_mask=tgt_mask,
+                                            use_cache=use_cache,
+                                            cache=cache[i])
+                    new_caches.append(new_cache)
+                else:
+                    output = mod(output,
+                                 memory,
+                                 tgt_mask=tgt_mask,
+                                 use_cache=use_cache,
+                                 cache=cache[i])
+
             self.checkpoints.append(output.name)

         if self.norm is not None:
@@ -675,6 +695,9 @@ def __init__(self,
         self.topo = topo
         self.hidden_size = hidden_size
         self.vocab_size = vocab_size
+        self.num_attention_heads = num_attention_heads
+        self.num_hidden_layers = num_hidden_layers
+        self.hidden_size = hidden_size

         self.pipline_mode = topo is not None and topo.pp_info.size > 1
         if self.pipline_mode:
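
The attributes added here are read later by GPTForGeneration._init_generation_caches to size the per-layer incremental k/v caches. A small arithmetic sketch of the resulting cache shape, using assumed gpt2-small-style values:

    # assumed example values; the real ones come from the GPTModel config
    num_attention_heads = 12
    num_hidden_layers = 12
    hidden_size = 768
    mp_degree = 1                  # topo.mp_info.size when tensor parallelism is off

    mp_n_head = num_attention_heads // mp_degree
    head_size = hidden_size // num_attention_heads

    # each of the num_hidden_layers caches starts as [batch, mp_n_head, 0, head_size]
    print(num_hidden_layers, [-1, mp_n_head, 0, head_size])   # 12 [-1, 12, 0, 64]
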
@@ -738,16 +761,18 @@ def forward(self,
                 paddle.shape(input_ids)[-1] + past_length,
                 dtype='int64')
             position_ids = position_ids.unsqueeze(0)
-            # .expand_as(input_ids)
             position_ids = paddle.fluid.layers.expand_as(position_ids,
                                                          input_ids)
         embedding_output = self.embeddings(
             input_ids=input_ids, position_ids=position_ids)

+        tgt_mask = None
+        if not self.training:
+            tgt_mask = attention_mask
         encoder_outputs = self.decoder(
             embedding_output,
             memory=None,
-            tgt_mask=None,
+            tgt_mask=tgt_mask,
             use_cache=use_cache,
             cache=cache)
         self.checkpoints.extend(self.decoder.checkpoints)
@@ -830,3 +855,243 @@ def forward(self, prediction_scores, masked_lm_labels, loss_mask):
         masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask)
         loss = masked_lm_loss / loss_mask.sum()
         return loss
+
+
+class GPTForGeneration(GPTPretrainedModel):
+    def __init__(self,
+                 gpt,
+                 max_length=20,
+                 min_length=0,
+                 decoding_strategy='sampling',
+                 temperature=1.0,
+                 top_k=0,
+                 top_p=1.0,
+                 eos_id=None):
+        super(GPTForGeneration, self).__init__()
+        self.gpt = gpt
+        self.apply(self.init_weights)
+        self.vocab_size = gpt.vocab_size
+        self.eos_token_id = eos_id or 7
+
+        self.min_dec_len = min_length
+        self.max_dec_len = max_length
+        self.decoding_strategy = decoding_strategy
+        self.temperature = temperature
+        self.topk = top_k
+        self.topp = top_p
+        self._fuse = False
+        self._init_gen_cache = False
+        self.generation_caches = []
+        self._dtype = "float32"
+
+    def _init_generation_caches(self, src_ids):
+        if self._init_gen_cache:
+            return self.generation_caches
+
+        num_heads = self.gpt.num_attention_heads
+        num_layers = self.gpt.num_hidden_layers
+        mp_n_head = num_heads // self.gpt.topo.mp_info.size
+        hidden_size = self.gpt.hidden_size
+        head_size = hidden_size // num_heads
+        for i in range(num_layers):
+            k = layers.fill_constant_batch_size_like(
+                input=src_ids,
+                shape=[-1, mp_n_head, 0, head_size],
+                dtype=self._dtype,
+                value=0)
+            v = layers.fill_constant_batch_size_like(
+                input=src_ids,
+                shape=[-1, mp_n_head, 0, head_size],
+                dtype=self._dtype,
+                value=0)
+            self.generation_caches.append(MultiHeadAttention.Cache(k, v))
+        self._init_gen_cache = True
+        return self.generation_caches
+
+    def parallel_matmul(self, lm_output, logit_weights, parallel_output, topo):
+        if topo is not None and topo.mp_info.size > 1:
+            input_parallel = paddle.distributed.collective._c_identity(
+                lm_output, group=None)
+
+            logits = paddle.matmul(
+                input_parallel, logit_weights, transpose_y=True)
+
+            if parallel_output:
+                return logits
+
+            # TODO(qinqing): collective._c_concat is not supported in static graph yet
+            return paddle.distributed.collective._c_concat(logits, group=None)
+        else:
+            logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
+            return logits
+
+    def topk_sampling(self, probs):
+        topk_probs, _ = paddle.topk(probs, self.topk)
+        ge_cond = paddle.cast(
+            paddle.greater_equal(probs,
+                                 paddle.unsqueeze(topk_probs[:, -1], [1])),
+            "float32")
+        old_probs = probs
+        probs = probs * ge_cond / paddle.sum(topk_probs, axis=-1, keepdim=True)
+        sampling_ids = layers.sampling_id(probs, dtype="int")
+        probs = old_probs
+        return probs, sampling_ids
+
+    def topp_sampling(self, probs):
+        sorted_probs, sorted_idx = layers.argsort(probs, descending=True)
+        cum_sorted_probs = layers.cumsum(sorted_probs, axis=1, exclusive=True)
+        lt_cond = paddle.cast(
+            paddle.less_than(cum_sorted_probs,
+                             layers.fill_constant_batch_size_like(
+                                 cum_sorted_probs, cum_sorted_probs.shape,
+                                 cum_sorted_probs.dtype, self.topp)), "float32")
+        old_probs = probs
+        candidate_probs = sorted_probs * lt_cond
+        probs = candidate_probs / paddle.sum(candidate_probs,
+                                             axis=-1,
+                                             keepdim=True)
+        sampling_ids = layers.sampling_id(probs, dtype="int")
+        sampling_ids = paddle.index_sample(sorted_idx,
+                                           paddle.unsqueeze(sampling_ids, [1]))
+        sampling_ids = paddle.squeeze(sampling_ids, [1])
+        probs = old_probs
+        return probs, sampling_ids
+
+    def model(self,
+              input_ids,
+              position_ids=None,
+              attention_mask=None,
+              masked_positions=None,
+              use_cache=False,
+              cache=None):
+        outputs = self.gpt(input_ids,
+                           position_ids=position_ids,
+                           attention_mask=attention_mask,
+                           use_cache=use_cache,
+                           cache=cache)
+        if use_cache:
+            encoder_outputs, cached_kvs = outputs[:2]
+        else:
+            encoder_outputs = outputs
+        logits = self.parallel_matmul(
+            encoder_outputs, self.gpt.embeddings.word_embeddings.weight, False,
+            self.gpt.topo)
+        if use_cache:
+            return logits, cached_kvs
+        else:
+            return logits
+
+    def forward(self, inputs, use_cache=False, cache=None):
+        """
+        Args:
+            inputs (dict): must contain `src_ids`;
+                `pos_ids`, `input_mask` and `max_dec_len` are optional.
+        """
+        ######### forward context #########
+        input_ids = inputs['src_ids']
+        position_ids = inputs['pos_ids'] if 'pos_ids' in inputs else None
+        attention_mask = inputs[
+            'input_mask'] if 'input_mask' in inputs else None
+
+        causal_mask = paddle.tensor.triu(
+            paddle.ones((paddle.shape(input_ids)[-1],
+                         paddle.shape(input_ids)[-1])) * -1e4,
+            diagonal=1)
+        if attention_mask is not None:
+            tgt_pos = paddle.sum(attention_mask, axis=-1,
+                                 keepdim=True).astype('int64')
+            if len(attention_mask.shape) == 2:
+                attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2])
+            encode_mask = attention_mask + causal_mask
+        else:
+            encode_mask = causal_mask
+
+        # if cached_kvs are assigned to the next step in _prepare_qkv of MultiHeadAttention,
+        # the global caches need to be initialized here
+        #gen_caches = self._init_generation_caches(input_ids)
+
+        logits, cached_kvs = self.model(
+            input_ids, position_ids, encode_mask, use_cache=True)
+
+        next_id = paddle.argmax(logits[:, -1, :], axis=-1).reshape([-1, 1])
+        ####################################
+
+        if 'max_dec_len' not in inputs:
+            max_len = layers.fill_constant(
+                [1], dtype=int_type, value=self.max_dec_len, force_cpu=True)
+        else:
+            max_len = inputs['max_dec_len']
+        min_len = layers.fill_constant(
+            shape=[1], dtype=int_type, value=self.min_dec_len, force_cpu=True)
+        step_idx = layers.fill_constant(
+            shape=[1], value=0, dtype='int64', force_cpu=True)
+
+        placehold_ids = layers.fill_constant_batch_size_like(
+            input=inputs["src_ids"],
+            value=0,
+            shape=[-1, 1],
+            dtype=next_id.dtype)
+        ids = layers.array_write(next_id, step_idx)
+
+        if 'max_dec_len' in inputs:
+            max_len = paddle.tensor.creation._memcpy(
+                max_len, place=paddle.CPUPlace())
+        cond_int = paddle.full([1], 0, dtype=int_type, name="cond_int")
+        cond = paddle.less_than(step_idx, max_len)
+
+        if attention_mask is not None:
+            append_mask = layers.fill_constant_batch_size_like(
+                input=next_id,
+                value=1,
+                shape=[-1, 1, 1, 1],
+                dtype=attention_mask.dtype)
+
+        while_op = layers.While(cond, is_test=True)
+        with while_op.block():
+            pre_ids = layers.array_read(array=ids, i=step_idx)
+            if attention_mask is not None:
+                decode_mask = paddle.concat(
+                    [attention_mask, append_mask], axis=-1)
+                tgt_pos = tgt_pos + step_idx
+                att_mask = (1 - decode_mask) * -1e4
+            else:
+                att_mask = None
+                tgt_pos = None
+
+            layers.increment(x=step_idx, value=1.0, in_place=True)
+            layers.array_write(placehold_ids, i=step_idx, array=ids)
+
+            logits, decode_cached_kvs = self.model(
+                pre_ids, tgt_pos, att_mask, use_cache=True, cache=cached_kvs)
+
+            logits = paddle.reshape(logits, shape=(-1, self.vocab_size))
+            probs = F.softmax(logits / self.temperature)
+
+            if self.decoding_strategy.startswith("sampling"):
+                sampling_ids = layers.sampling_id(probs, dtype="int")
+            elif self.decoding_strategy.startswith("topk_sampling"):
+                probs, sampling_ids = self.topk_sampling(probs)
+            elif self.decoding_strategy.startswith("topp_sampling"):
+                probs, sampling_ids = self.topp_sampling(probs)
+            else:
+                raise ValueError(self.decoding_strategy)
+
+            selected_ids = paddle.unsqueeze(sampling_ids, -1)
+            layers.array_write(selected_ids, i=step_idx, array=ids)
+
+            length_cond = paddle.less_than(
+                x=step_idx, y=max_len, name="length_cond")
+            finish_cond = paddle.logical_not(
+                paddle.is_empty(x=selected_ids), name="finish_cond")
+            paddle.logical_and(
+                x=length_cond, y=finish_cond, out=cond, name="logical_and_cond")
+
+            paddle.assign(layers.cast(cond, dtype='bool'), cond)
+            if attention_mask is not None:
+                paddle.assign(decode_mask, attention_mask)
+            for i in range(len(decode_cached_kvs)):
+                paddle.assign(decode_cached_kvs[i].k, cached_kvs[i].k)
+                paddle.assign(decode_cached_kvs[i].v, cached_kvs[i].v)
+
+        ids, _ = layers.tensor_array_to_tensor(ids)
+        return ids
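
For context, a hedged sketch of how this class might be wired into a static-graph program. The constructor arguments, vocab size, token ids, and the choice of GPU place are assumptions for illustration, not something this diff prescribes:

    import numpy as np
    import paddle
    import paddle.static as static

    paddle.enable_static()
    main_prog, startup_prog = static.Program(), static.Program()

    with static.program_guard(main_prog, startup_prog):
        src_ids = static.data(name="src_ids", shape=[-1, -1], dtype="int64")
        input_mask = static.data(name="input_mask", shape=[-1, -1], dtype="float32")

        gpt = GPTModel(vocab_size=50304)             # assuming the usual GPTModel defaults
        model = GPTForGeneration(gpt, max_length=32,
                                 decoding_strategy="topk_sampling", top_k=5)
        model.eval()                                 # take the non-fused attention path
        out_ids = model({"src_ids": src_ids, "input_mask": input_mask})

    exe = static.Executor(paddle.CUDAPlace(0))
    exe.run(startup_prog)
    generated = exe.run(main_prog,
                        feed={"src_ids": np.array([[464, 3290, 318]], dtype="int64"),
                              "input_mask": np.ones([1, 3], dtype="float32")},
                        fetch_list=[out_ids])

Calling model.eval() before building the program matters because both the fused-softmax branch in MultiHeadAttention and the tgt_mask plumbing added above key off self.training.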