# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import List
-
+ import time
+ import numpy as np
+ import os
import torch

from llama.tokenizer import Tokenizer
from llama.model import Transformer
+ from llama import is_torch_gcu_available
+
+ torch.autograd.set_detect_anomaly(True)
+
+ if is_torch_gcu_available():
+     import torch_gcu
+     import torch_gcu.distributed as dist
+     local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else dist.get_rank()
+     # gcu_device = torch_gcu.gcu_device(local_rank * int(os.getenv("LEO_CLUSTER_NUM", '1')))
+     gcu_device = torch_gcu.gcu_device(local_rank)
+ else:
+     import torch as torch_gcu


class LLaMA:
    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer
+        self.max_prompts_len = 32
+        self.max_seq_len = 512
+
+    def gen_mask_stage_0(self, tokens: torch.Tensor, pad_id: int):
+        temp = torch.full((1, 1, self.max_prompts_len, self.max_prompts_len), -65500.0, device="cpu")
+        temp = torch.triu(temp, diagonal=1)
+        expand_tokens = tokens[:, None, None, :].expand(1, 1, self.max_prompts_len, self.max_prompts_len)
+        temp.masked_fill_(expand_tokens == pad_id, -65500.0)
+        temp[0, 0, :, :].fill_diagonal_(fill_value=0., wrap=False)
+        mask = torch.full((1, 1, self.max_prompts_len, self.max_seq_len), -65500.0, device="cpu")
+        mask[0, 0, :, -self.max_prompts_len:] = temp
+        return mask.to(gcu_device)
+
+    def gen_mask_stage_1(self, cur_pos: int):
+        mask = torch.full((1, 1, 1, self.max_seq_len), -65500.0, device="cpu")
+        mask[:, :, :, self.max_seq_len - cur_pos:] = 0
+        return mask.to(gcu_device)

    def generate(
        self,
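For illustration only (not part of the patch): the two helpers above right-align everything into a fixed-size window. gen_mask_stage_0 builds the prefill mask, a causal triangle over the last max_prompts_len columns with padded positions fully masked, and gen_mask_stage_1 builds the per-step decode mask that exposes only the last cur_pos cache slots. A minimal standalone sketch with toy sizes (max_prompts_len=4, max_seq_len=8 in place of the patch's 32/512) reproduces the layout:

import torch

max_prompts_len, max_seq_len, pad_id = 4, 8, 0
tokens = torch.tensor([[0, 0, 5, 7]])  # right-aligned prompt: two pads, then two real tokens

# stage 0 (prefill): causal mask over the prompt window, pad columns masked out,
# written into the last max_prompts_len columns of the full-width mask
temp = torch.triu(torch.full((1, 1, max_prompts_len, max_prompts_len), -65500.0), diagonal=1)
temp.masked_fill_(tokens[:, None, None, :] == pad_id, -65500.0)
temp[0, 0].fill_diagonal_(0.0)
stage0 = torch.full((1, 1, max_prompts_len, max_seq_len), -65500.0)
stage0[0, 0, :, -max_prompts_len:] = temp[0, 0]
print(stage0[0, 0])

# stage 1 (decode): a single query row that may attend to the last cur_pos cache slots
cur_pos = 5
stage1 = torch.full((1, 1, 1, max_seq_len), -65500.0)
stage1[:, :, :, max_seq_len - cur_pos:] = 0.0
print(stage1[0, 0])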
@@ -31,46 +62,69 @@ def generate(
        max_prompt_size = max([len(t) for t in prompt_tokens])

        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
+        total_padding_len = params.max_seq_len
+        if not is_torch_gcu_available():
+            tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
+
+        else:
+            tokens = torch.full((bsz, total_padding_len), 0, device="cpu")
+            tokens = tokens.long()

-        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
        for k, t in enumerate(prompt_tokens):
-            tokens[k, : len(t)] = torch.tensor(t).long()
-        input_text_mask = tokens != self.tokenizer.pad_id
+            assert len(t) <= self.max_prompts_len, \
+                f"prompt size of {prompts[k]} ({len(t)}) is greater than max_prompts_len: {self.max_prompts_len}"
+            if not is_torch_gcu_available():
+                tokens[k, : len(t)] = torch.tensor(t).long()
+            else:
+                tokens[k, -len(t):] = torch.tensor(t).long()
        start_pos = min_prompt_size
        prev_pos = 0
+        token_time_list = list()
        for cur_pos in range(start_pos, total_len):
-            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+            start_time = time.time()
+            if prev_pos == 0:
+                mask = self.gen_mask_stage_0(tokens[:, -self.max_prompts_len:], 0)
+                logits = self.model.forward(tokens[:, -self.max_prompts_len:].to(gcu_device), start_pos=prev_pos, mask=mask)
+            else:
+                mask = self.gen_mask_stage_1(cur_pos)
+                logits = self.model.forward(tokens[:, -1:].to(gcu_device), start_pos=prev_pos, mask=mask)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
-            next_token = next_token.reshape(-1)
+
+            next_token = next_token.reshape(tokens.shape[0], -1).cpu()
            # only replace token if prompt has already been generated
-            next_token = torch.where(
-                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
-            )
-            tokens[:, cur_pos] = next_token
+            tokens = torch.cat([tokens, next_token], dim=1)
+            tokens = tokens[:, 1:]
            prev_pos = cur_pos
+            end_time = time.time()
+            token_time_list.append(end_time - start_time)

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
-            t = t[: len(prompt_tokens[i]) + max_gen_len]
+            t = t[-len(prompt_tokens[i]) - max_gen_len:]
            # cut to eos tok if any
            try:
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
-        return decoded
+        return decoded, token_time_list
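Not part of the patch, but to make the decode-loop change above concrete: tokens is no longer written in place at position cur_pos; it is kept as a fixed-width, right-aligned buffer that shifts by one slot each step (append on the right, drop the oldest column). A tiny standalone sketch of that update with made-up token ids:

import torch

buf = torch.tensor([[0, 0, 11, 12]])  # toy fixed-width buffer: pads on the left, prompt on the right
for new_id in (13, 14):
    next_token = torch.tensor([[new_id]])
    buf = torch.cat([buf, next_token], dim=1)[:, 1:]  # append right, drop the oldest column
    print(buf.tolist())  # [[0, 11, 12, 13]] then [[11, 12, 13, 14]]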


def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    # call sync_lived_tensor to avoid repeated computation in different subgraphs
+    torch_gcu.sync_lived_tensor()
+    itemp = probs_sort.cpu()
+    probs_sum = torch.cumsum(itemp, dim=-1)
+    probs_sum = probs_sum.to(gcu_device)
    mask = probs_sum - probs_sort > p
-    probs_sort[mask] = 0.0
+    # probs_sort[mask] = 0.0
+    probs_sort.masked_fill_(mask, 0.0)
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
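Since generate() now returns per-token timings alongside the decoded strings, a caller has to unpack a tuple. The sketch below is illustrative only: it assumes a `generator` (LLaMA instance) and `prompts` set up as in the repo's example.py, and the sampling arguments are placeholders, not values taken from the patch.

decoded, token_time_list = generator.generate(
    prompts, max_gen_len=256, temperature=0.8, top_p=0.95
)
for text in decoded:
    print(text)
# the first timing covers the prefill step, the rest are single-token decode steps
prefill_s, decode_times = token_time_list[0], token_time_list[1:]
if decode_times:
    print(f"prefill: {prefill_s:.3f} s, decode: {sum(decode_times) / len(decode_times):.3f} s/token")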