import os
import sys
import time
+ import json
import datetime
import argparse
import resource
import socket
import threading
import torch
+ import tinychat

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

parser.add_argument('--model', type=str, default='', required=True, help="name or path of the huggingface model")
parser.add_argument('--quant', type=str, default='', required=True, help="path to the real AWQ quantized model checkpoint")
- parser.add_argument('--prompt', type=str, default='Once upon a time,')
+ parser.add_argument("--prompt", action='append', nargs='*')

# benchmarking options
- parser.add_argument('--tokens', type=int, default=128, help='number of output tokens to generate, including the input prompt')
- parser.add_argument('--runs', type=int, default=2, help='the number of benchmark timing iterations')
- parser.add_argument('--warmup', type=int, default=2, help='the number of warmup iterations')
+ parser.add_argument("--max-new-tokens", type=int, default=128)
+ parser.add_argument("--max-num-prompts", type=int, default=None)
parser.add_argument('--save', type=str, default='', help='CSV file to save benchmarking results to')

# quantization options
parser.add_argument('--w_bit', type=int, default=4)
parser.add_argument('--q_group_size', type=int, default=128)
parser.add_argument('--no_zero_point', action='store_true', help="disable zero_point")
parser.add_argument('--no_tinychat_kernels', action='store_true', help="disable tinychat kernels")
- parser.add_argument('--no_tinychat_infer', action='store_true', help="disable tinychat inference")
- parser.add_argument('--no_streaming', action='store_true', help="disable streaming mode")
+ #parser.add_argument('--no_tinychat_infer', action='store_true', help="disable tinychat inference")
parser.add_argument('--no_quant', action='store_true', help="disable quantization and use FP16 through transformers")
parser.add_argument('--do_sample', action='store_true')

args = parser.parse_args()

- #args.prompt="Once upon a time, there was a young man named Jack who lived in a small village nestled in the rolling hills of the countryside. Jack was a curious and adventurous soul, always eager to explore the world beyond his village. One day, while wandering through the nearby forest, he stumbled upon a hidden path that he had never seen before. The path was overgrown with weeds and vines, and it looked as though it had been untouched for many years. Jack's curiosity was piqued, and he decided to follow the path to see where it led"
- #args.prompt="Once upon a time, there was a young man named Jack who lived in a small village nestled in the rolling hills of the countryside. Jack was a curious and adventurous soul, always eager to explore the world beyond his village. One day, while wandering through the nearby forest, he stumbled upon a hidden path that he had never seen before. The path was overgrown with weeds and vines, and it looked as though it had been untouched for many years. Jack's curiosity was piqued, and he decided to follow the path to see where it led. As he walked down the path, the trees grew taller and the air grew colder. Jack could feel a strange energy emanating from the forest, as if it were alive and watching him. He quickened his pace, eager to reach the end of the path and discover its secrets. After a while, the path opened up into a clearing, and Jack found himself standing in front of a massive stone structure. The building was unlike anything he had ever seen before, with intricate carvings and symbols etched into its walls. Jack felt a sense of awe and wonder as he approached the"
-
+ if not args.prompt:
+     if args.chat:  # https://modal.com/docs/guide/ex/vllm_inference
+         args.prompt = [
+             "What is the meaning of life?",
+             "How many points did you list out?",
+             "What is the weather forecast today?",
+             "What is the fable involving a fox and grapes?",
+             "What's a good recipe for making tabouli?",
+             "What is the product of 9 and 8?",
+             "If a train travels 120 miles in 2 hours, what is its average speed?",
+         ]
+     else:
+         args.prompt = [
+             "Once upon a time,",
+             "A great place to live is",
+             "In a world where dreams are shared,",
+             "The weather forecast today is",
+             "Large language models are",
+             "Space exploration is exciting",
+             "The history of the Hoover Dam is",
+             "San Francisco is a city in",
+             "To train for running a marathon,",
+             "A recipe for making tabouli is"
+         ]
+ else:
+     args.prompt = [x[0] for x in args.prompt]
+
print(args)

+ def load_prompts(prompts):
+     """
+     Load prompts from a list of txt or json files
+     (or if these are strings, just return the strings)
+     """
+     prompt_list = []
+
+     for prompt in prompts:
+         ext = os.path.splitext(prompt)[1]
+
+         if ext == '.json':
+             with open(prompt) as file:
+                 json_prompts = json.load(file)
+                 for json_prompt in json_prompts:
+                     if isinstance(json_prompt, dict):
+                         prompt_list.append(json_prompt)  # json_prompt['text']
+                     elif isinstance(json_prompt, str):
+                         prompt_list.append(json_prompt)
+                     else:
+                         raise TypeError(f"{type(json_prompt)}")
+         elif ext == '.txt':
+             with open(prompt) as file:
+                 prompt_list.append(file.read())
+         else:
+             prompt_list.append(prompt)
+
+     return prompt_list
+
+ args.prompt = load_prompts(args.prompt)
+
+ if args.max_num_prompts:
+     args.prompt = args.prompt[:args.max_num_prompts]
+
def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
@@ -97,100 +154,99 @@ def get_model_name_from_path(model_path):
    no_split_module_classes=["OPTDecoderLayer", "LlamaDecoderLayer"]
)

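+ # precision label: W4A16 means 4-bit quantized weights with FP16 activations (weight-only quantization)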
- precision = f"int{args.w_bit}"
+ precision = f"W{args.w_bit}A16"

print(model)
print(f"model device: {model.device}")

if not args.no_tinychat_kernels:
+     tinychat.utils.constants.max_seq_len = model.config.max_position_embeddings
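+     # replace the quantized model's attention, norm, and MLP modules with TinyChat's fused kernels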
    make_quant_attn(model, device)
    make_quant_norm(model)
    make_fused_mlp(model)
-
- print(model)
- print(f"model device: {model.device}")
-
+     print("TinyChat model:\n", model)
+     print(f"Model max context length: {model.config.max_position_embeddings}")
+

#for name, param in model.named_parameters():
#    print(f"{name} {param}")

model.eval()

# create tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
- input_ids = tokenizer(args.prompt, return_tensors="pt").input_ids.to(device)

# benchmark inference
- avg_latency = 0
- avg_tokens_sec = 0
+ avg_prefill_time = 0
+ avg_prefill_rate = 0
+ avg_decode_time = 0
+ avg_decode_rate = 0
+
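+ # one generation pass per prompt; the first prompt acts as a warmup and is excluded from the averages (see 'if i > 0' below)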
+ for i, prompt in enumerate(args.prompt):
+     if isinstance(prompt, dict):
+         prompt = prompt['text']
+
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+     num_input_tokens = input_ids.shape[-1]

- for run in range(args.runs + args.warmup):
-     if args.no_streaming:
-         time_begin = time.perf_counter()
-         generated_ids = model.generate(input_ids, do_sample=args.do_sample, min_new_tokens=args.tokens, max_new_tokens=args.tokens)
-         time_elapsed = time.perf_counter() - time_begin
-
-         print(tokenizer.decode(generated_ids[0], skip_special_tokens=False))
-
-         num_tokens = len(generated_ids[0])
-         tokens_sec = num_tokens / time_elapsed
-         latency = time_elapsed
-     else:
-         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
-
-         def generate():
-             with torch.inference_mode():
-                 model.generate(
-                     inputs=input_ids,
-                     do_sample=args.do_sample,
-                     min_new_tokens=args.tokens,
-                     max_new_tokens=args.tokens,
-                     streamer=streamer
-                 )
-
-         thread = threading.Thread(target=generate)
-         thread.start()
-
-         print(f"Prompt: {args.prompt}")
-
-         new_tokens = ''
-         num_tokens = 0
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
+     time_begin = 0
+
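+     # generation runs in a background thread while the main thread consumes tokens from the streamer as they arrive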
+     def generate():
+         global time_begin
        time_begin = time.perf_counter()
-
-         for token in streamer:
-             print(token, end='')
-             sys.stdout.flush()
-
-             if num_tokens == 0:
-                 time_first_token = time.perf_counter()
-                 latency = time_first_token - time_begin
-                 time_begin = time_first_token
-
-             new_tokens += token
-             num_tokens += 1
+         with torch.inference_mode():
+             model.generate(
+                 inputs=input_ids,
+                 do_sample=args.do_sample,
+                 min_new_tokens=args.max_new_tokens,
+                 max_new_tokens=args.max_new_tokens,
+                 streamer=streamer
+             )

-     time_elapsed = time.perf_counter() - time_begin
-     tokens_sec = (num_tokens - 1) / time_elapsed
-
-     print(f"\n{model_name}: {num_tokens} tokens in {time_elapsed:.2f} sec, {tokens_sec:.2f} tokens/sec, latency {latency:.2f} sec ({precision})\n")
+     thread = threading.Thread(target=generate)
+     thread.start()
+
+     print(f"Prompt: {prompt}\n")
+
+     new_tokens = ''
+     num_tokens = 0
+
+     for token in streamer:
+         print(token, end='', flush=True)
+
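+         # the first streamed token marks the end of prefill; restart the clock so decode time is measured separately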
+         if num_tokens == 0:
+             time_first_token = time.perf_counter()
+             prefill_time = time_first_token - time_begin
+             time_begin = time_first_token

-     if run >= args.warmup:
-         avg_latency += latency
-         avg_tokens_sec += tokens_sec
+         new_tokens += token
+         num_tokens += 1
+
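+     # decode rate excludes the first generated token, which is counted toward prefill above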
+     decode_time = time.perf_counter() - time_begin
+     decode_rate = (args.max_new_tokens - 1) / decode_time
+     prefill_rate = num_input_tokens / prefill_time
+
+     print(f"\n\n{model_name}: input={num_input_tokens} output={num_tokens} prefill_time {prefill_time:.3f} sec, prefill_rate {prefill_rate:.1f} tokens/sec, decode_time {decode_time:.3f} sec, decode_rate {decode_rate:.1f} tokens/sec\n")
+
+     if i > 0:
+         avg_factor = 1.0 / (len(args.prompt) - 1)
+         avg_prefill_time += prefill_time * avg_factor
+         avg_prefill_rate += prefill_rate * avg_factor
+         avg_decode_time += decode_time * avg_factor
+         avg_decode_rate += decode_rate * avg_factor

- # compute statistics
- avg_latency /= args.runs
- avg_tokens_sec /= args.runs
+ print(f"AVERAGE OVER {len(args.prompt) - 1} RUNS, input={num_input_tokens}, output={args.max_new_tokens}, precision={precision}")
+ print(f"{model_name}: prefill_time {avg_prefill_time:.3f} sec, prefill_rate {avg_prefill_rate:.1f} tokens/sec, decode_time {avg_decode_time:.3f} sec, decode_rate {avg_decode_rate:.1f} tokens/sec\n")

memory_usage = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss) / 1024  # https://stackoverflow.com/a/7669482
-
- print(f"AVERAGE of {args.runs} runs:")
- print(f"{model_name}: {avg_tokens_sec:.2f} tokens/sec, latency {avg_latency:.2f} sec, memory={memory_usage:.2f} MB ({precision})\n")
+ print(f"Peak memory usage: {memory_usage:.2f} MB")

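+ # log results to the CSV file given by --save, writing a header row first if the file does not exist yet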
if args.save:
    if not os.path.isfile(args.save):  # csv header
        with open(args.save, 'w') as file:
-             file.write(f"timestamp, hostname, api, model, precision, tokens, tokens/sec, latency, memory\n")
+             file.write(f"timestamp, hostname, api, model, precision, input_tokens, output_tokens, prefill_time, prefill_rate, decode_time, decode_rate, memory\n")
    with open(args.save, 'a') as file:
-         file.write(f"{datetime.datetime.now().strftime('%Y%m%d %H:%M:%S')}, {socket.gethostname()}, {'tinychat' if args.tiny_chat else 'awq'}, ")
-         file.write(f"{model_name}, {precision}, {args.tokens}, {avg_tokens_sec}, {avg_latency}, {memory_usage}\n")
+         file.write(f"{datetime.datetime.now().strftime('%Y%m%d %H:%M:%S')}, {socket.gethostname()}, {'tinychat' if not args.no_tinychat_kernels else 'awq'}, ")
+         file.write(f"{model_name}, {precision}, {num_input_tokens}, {args.max_new_tokens}, ")
+         file.write(f"{avg_prefill_time}, {avg_prefill_rate}, {avg_decode_time}, {avg_decode_rate}, {memory_usage}\n")