
Commit 4beb832

Committed Oct 2, 2023
updated AWQ benchmarks
1 parent: 25a1e35

3 files changed: +142 -78 lines

‎packages/llm/awq/Dockerfile

+2 -2
@@ -2,7 +2,7 @@
 # name: awq
 # group: llm
 # config: config.py
-# depends: [pytorch, llava]
+# depends: [transformers]
 # requires: '>=34.1.0'
 # test: test.sh
 # docs: docs.md
@@ -20,7 +20,7 @@ WORKDIR /opt
 # force rebuild on new git commits - https://stackoverflow.com/a/56945508
 ADD https://api.github.com/repos/${AWQ_REPO}/git/refs/heads/${AWQ_BRANCH} /tmp/awq_version.json

-RUN git clone --depth=1 https://github.com/${AWQ_REPO} awq
+RUN git clone --branch=${AWQ_BRANCH} --depth=1 https://github.com/${AWQ_REPO} awq

 # enable giving huggingface model names (as opposed to paths only)
 #RUN sed 's|^ if not os.path.exists(model_path)|# if not os.path.exists(model_path)|g' -i awq/awq/entry.py && \

‎packages/llm/awq/benchmark.py

+132 -76
@@ -3,12 +3,14 @@
 import os
 import sys
 import time
+import json
 import datetime
 import argparse
 import resource
 import socket
 import threading
 import torch
+import tinychat

 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
@@ -25,31 +27,86 @@

 parser.add_argument('--model', type=str, default='', required=True, help="name or path of the huggingface model")
 parser.add_argument('--quant', type=str, default='', required=True, help="path to the real AWQ quantized model checkpoint")
-parser.add_argument('--prompt', type=str, default='Once upon a time,')
+parser.add_argument("--prompt", action='append', nargs='*')

 # benchmarking options
-parser.add_argument('--tokens', type=int, default=128, help='number of output tokens to generate, including the input prompt')
-parser.add_argument('--runs', type=int, default=2, help='the number of benchmark timing iterations')
-parser.add_argument('--warmup', type=int, default=2, help='the number of warmup iterations')
+parser.add_argument("--max-new-tokens", type=int, default=128)
+parser.add_argument("--max-num-prompts", type=int, default=None)
 parser.add_argument('--save', type=str, default='', help='CSV file to save benchmarking results to')

 # quantization options
 parser.add_argument('--w_bit', type=int, default=4)
 parser.add_argument('--q_group_size', type=int, default=128)
 parser.add_argument('--no_zero_point', action='store_true', help="disable zero_point")
 parser.add_argument('--no_tinychat_kernels', action='store_true', help="disable tinychat kernels")
-parser.add_argument('--no_tinychat_infer', action='store_true', help="disable tinychat inference")
-parser.add_argument('--no_streaming', action='store_true', help="disable streaming mode")
+#parser.add_argument('--no_tinychat_infer', action='store_true', help="disable tinychat inference")
 parser.add_argument('--no_quant', action='store_true', help="disable quantization and use FP16 through transformers")
 parser.add_argument('--do_sample', action='store_true')

 args = parser.parse_args()

-#args.prompt="Once upon a time, there was a young man named Jack who lived in a small village nestled in the rolling hills of the countryside. Jack was a curious and adventurous soul, always eager to explore the world beyond his village. One day, while wandering through the nearby forest, he stumbled upon a hidden path that he had never seen before. The path was overgrown with weeds and vines, and it looked as though it had been untouched for many years. Jack's curiosity was piqued, and he decided to follow the path to see where it led"
-#args.prompt="Once upon a time, there was a young man named Jack who lived in a small village nestled in the rolling hills of the countryside. Jack was a curious and adventurous soul, always eager to explore the world beyond his village. One day, while wandering through the nearby forest, he stumbled upon a hidden path that he had never seen before. The path was overgrown with weeds and vines, and it looked as though it had been untouched for many years. Jack's curiosity was piqued, and he decided to follow the path to see where it led. As he walked down the path, the trees grew taller and the air grew colder. Jack could feel a strange energy emanating from the forest, as if it were alive and watching him. He quickened his pace, eager to reach the end of the path and discover its secrets. After a while, the path opened up into a clearing, and Jack found himself standing in front of a massive stone structure. The building was unlike anything he had ever seen before, with intricate carvings and symbols etched into its walls. Jack felt a sense of awe and wonder as he approached the"
-
+if not args.prompt:
+    if args.chat: # https://modal.com/docs/guide/ex/vllm_inference
+        args.prompt = [
+            "What is the meaning of life?",
+            "How many points did you list out?",
+            "What is the weather forecast today?",
+            "What is the fable involving a fox and grapes?",
+            "What's a good recipe for making tabouli?",
+            "What is the product of 9 and 8?",
+            "If a train travels 120 miles in 2 hours, what is its average speed?",
+        ]
+    else:
+        args.prompt = [
+            "Once upon a time,",
+            "A great place to live is",
+            "In a world where dreams are shared,",
+            "The weather forecast today is",
+            "Large language models are",
+            "Space exploration is exciting",
+            "The history of the Hoover Dam is",
+            "San Fransisco is a city in",
+            "To train for running a marathon,",
+            "A recipe for making tabouli is"
+        ]
+else:
+    args.prompt = [x[0] for x in args.prompt]
+
 print(args)

+def load_prompts(prompts):
+    """
+    Load prompts from a list of txt or json files
+    (or if these are strings, just return the strings)
+    """
+    prompt_list = []
+
+    for prompt in prompts:
+        ext = os.path.splitext(prompt)[1]
+
+        if ext == '.json':
+            with open(prompt) as file:
+                json_prompts = json.load(file)
+            for json_prompt in json_prompts:
+                if isinstance(json_prompt, dict):
+                    prompt_list.append(json_prompt) # json_prompt['text']
+                elif isinstance(json_prompt, str):
+                    prompt_list.append(json_prompt)
+                else:
+                    raise TypeError(f"{type(json_prompt)}")
+        elif ext == '.txt':
+            with open(prompt) as file:
+                prompt_list.append(file.read())
+        else:
+            prompt_list.append(prompt)
+
+    return prompt_list
+
+args.prompt = load_prompts(args.prompt)
+
+if args.max_num_prompts:
+    args.prompt = args.prompt[:args.max_num_prompts]
+
 def get_model_name_from_path(model_path):
     model_path = model_path.strip("/")
     model_paths = model_path.split("/")
@@ -97,100 +154,99 @@ def get_model_name_from_path(model_path):
     no_split_module_classes=["OPTDecoderLayer", "LlamaDecoderLayer"]
 )

-precision = f"int{args.w_bit}"
+precision = f"W{args.w_bit}A16"

 print(model)
 print(f"model device: {model.device}")

 if not args.no_tinychat_kernels:
+    tinychat.utils.constants.max_seq_len = model.config.max_position_embeddings
     make_quant_attn(model, device)
     make_quant_norm(model)
     make_fused_mlp(model)
-
-print(model)
-print(f"model device: {model.device}")
-
+print("TinyChat model:\n", model)
+print(f"Model max context length: {model.config.max_position_embeddings}")
+
 #for name, param in model.named_parameters():
 #    print(f"{name} {param}")

 model.eval()

 # create tokenizer
 tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
-input_ids = tokenizer(args.prompt, return_tensors="pt").input_ids.to(device)

 # benchmark inference
-avg_latency=0
-avg_tokens_sec=0
+avg_prefill_time = 0
+avg_prefill_rate = 0
+avg_decode_time = 0
+avg_decode_rate = 0
+
+for i, prompt in enumerate(args.prompt):
+    if isinstance(prompt, dict):
+        prompt = prompt['text']
+
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+    num_input_tokens = input_ids.shape[-1]

-for run in range(args.runs + args.warmup):
-    if args.no_streaming:
-        time_begin = time.perf_counter()
-        generated_ids = model.generate(input_ids, do_sample=args.do_sample, min_new_tokens=args.tokens, max_new_tokens=args.tokens)
-        time_elapsed = time.perf_counter() - time_begin
-
-        print(tokenizer.decode(generated_ids[0], skip_special_tokens=False))
-
-        num_tokens=len(generated_ids[0])
-        tokens_sec=num_tokens / time_elapsed
-        latency=time_elapsed
-    else:
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
-
-        def generate():
-            with torch.inference_mode():
-                model.generate(
-                    inputs=input_ids,
-                    do_sample=args.do_sample,
-                    min_new_tokens=args.tokens,
-                    max_new_tokens=args.tokens,
-                    streamer=streamer
-                )
-
-        thread = threading.Thread(target=generate)
-        thread.start()
-
-        print(f"Prompt: {args.prompt}")
-
-        new_tokens = ''
-        num_tokens = 0
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
+    time_begin = 0
+
+    def generate():
+        global time_begin
         time_begin = time.perf_counter()
-
-        for token in streamer:
-            print(token, end='')
-            sys.stdout.flush()
-
-            if num_tokens == 0:
-                time_first_token=time.perf_counter()
-                latency=time_first_token - time_begin
-                time_begin=time_first_token
-
-            new_tokens += token
-            num_tokens += 1
+        with torch.inference_mode():
+            model.generate(
+                inputs=input_ids,
+                do_sample=args.do_sample,
+                min_new_tokens=args.max_new_tokens,
+                max_new_tokens=args.max_new_tokens,
+                streamer=streamer
+            )

-        time_elapsed=time.perf_counter() - time_begin
-        tokens_sec=(num_tokens-1) / time_elapsed
-
-    print(f"\n{model_name}: {num_tokens} tokens in {time_elapsed:.2f} sec, {tokens_sec:.2f} tokens/sec, latency {latency:.2f} sec ({precision})\n")
+    thread = threading.Thread(target=generate)
+    thread.start()
+
+    print(f"Prompt: {prompt}\n")
+
+    new_tokens = ''
+    num_tokens = 0
+
+    for token in streamer:
+        print(token, end='', flush=True)
+
+        if num_tokens == 0:
+            time_first_token=time.perf_counter()
+            prefill_time=time_first_token - time_begin
+            time_begin=time_first_token

-    if run >= args.warmup:
-        avg_latency += latency
-        avg_tokens_sec += tokens_sec
+        new_tokens += token
+        num_tokens += 1
+
+    decode_time=time.perf_counter() - time_begin
+    decode_rate=(args.max_new_tokens-1) / decode_time
+    prefill_rate=num_input_tokens / prefill_time
+
+    print(f"\n\n{model_name}: input={num_input_tokens} output={num_tokens} prefill_time {prefill_time:.3f} sec, prefill_rate {prefill_rate:.1f} tokens/sec, decode_time {decode_time:.3f} sec, decode_rate {decode_rate:.1f} tokens/sec\n")
+
+    if i > 0:
+        avg_factor = 1.0 / (len(args.prompt) - 1)
+        avg_prefill_time += prefill_time * avg_factor
+        avg_prefill_rate += prefill_rate * avg_factor
+        avg_decode_time += decode_time * avg_factor
+        avg_decode_rate += decode_rate * avg_factor

-# compute statistics
-avg_latency /= args.runs
-avg_tokens_sec /= args.runs
+print(f"AVERAGE OVER {len(args.prompt) - 1} RUNS, input={num_input_tokens}, output={args.max_new_tokens}, precision={precision}")
+print(f"{model_name}: prefill_time {avg_prefill_time:.3f} sec, prefill_rate {avg_prefill_rate:.1f} tokens/sec, decode_time {avg_decode_time:.3f} sec, decode_rate {avg_decode_rate:.1f} tokens/sec\n")

 memory_usage = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss) / 1024 # https://stackoverflow.com/a/7669482
-
-print(f"AVERAGE of {args.runs} runs:")
-print(f"{model_name}: {avg_tokens_sec:.2f} tokens/sec, latency {avg_latency:.2f} sec, memory={memory_usage:.2f} MB ({precision})\n")
+print(f"Peak memory usage: {memory_usage:.2f} MB")

 if args.save:
     if not os.path.isfile(args.save): # csv header
         with open(args.save, 'w') as file:
-            file.write(f"timestamp, hostname, api, model, precision, tokens, tokens/sec, latency, memory\n")
+            file.write(f"timestamp, hostname, api, model, precision, input_tokens, output_tokens, prefill_time, prefill_rate, decode_time, decode_rate, memory\n")
     with open(args.save, 'a') as file:
-        file.write(f"{datetime.datetime.now().strftime('%Y%m%d %H:%M:%S')}, {socket.gethostname()}, {'tinychat' if args.tiny_chat else 'awq'}, ")
-        file.write(f"{model_name}, {precision}, {args.tokens}, {avg_tokens_sec}, {avg_latency}, {memory_usage}\n")
+        file.write(f"{datetime.datetime.now().strftime('%Y%m%d %H:%M:%S')}, {socket.gethostname()}, {'tinychat' if not args.no_tinychat_kernels else 'awq'}, ")
+        file.write(f"{model_name}, {precision}, {num_input_tokens}, {args.max_new_tokens}, ")
+        file.write(f"{avg_prefill_time}, {avg_prefill_rate}, {avg_decode_time}, {avg_decode_rate}, {memory_usage}\n")
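For context, the updated benchmark splits each generation into a prefill phase (time to the first streamed token) and a decode phase (the remaining tokens), reporting per-prompt prefill_rate and decode_rate and averaging every prompt after the first. The following is a minimal, self-contained sketch of that timing pattern, not the benchmark itself; the model name "gpt2" and the token count are illustrative placeholders, not values used by this commit:

# Minimal sketch of prefill/decode timing with transformers' TextIteratorStreamer.
# Assumption: "gpt2" is a placeholder model for illustration, not the
# AWQ-quantized model this commit benchmarks; CPU execution is fine here.
import time
import threading

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "gpt2"  # placeholder model for the sketch
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "Once upon a time,"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
num_input_tokens = input_ids.shape[-1]
max_new_tokens = 32

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

def generate():
    # run generation in a background thread so the streamer can be consumed here
    with torch.inference_mode():
        model.generate(inputs=input_ids, min_new_tokens=max_new_tokens,
                       max_new_tokens=max_new_tokens, streamer=streamer)

time_begin = time.perf_counter()
threading.Thread(target=generate).start()

num_tokens = 0
prefill_time = 0.0

for token in streamer:
    if num_tokens == 0:
        # time to first streamed token approximates the prefill (prompt processing) phase
        first_token_time = time.perf_counter()
        prefill_time = first_token_time - time_begin
        time_begin = first_token_time
    num_tokens += 1

decode_time = time.perf_counter() - time_begin  # remaining tokens are the decode phase

print(f"prefill: {num_input_tokens / prefill_time:.1f} tokens/sec, "
      f"decode: {max(num_tokens - 1, 1) / decode_time:.1f} tokens/sec")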

‎packages/llm/awq/config.py

+8 -0
@@ -1,5 +1,6 @@

 from jetson_containers import CUDA_ARCHITECTURES
+import copy

 # AWQ package is for Orin only:
 # ptxas /tmp/tmpxft_000000b4_00000000-7_gemm_cuda_gen.compute_72.ptx, line 889; error
@@ -8,3 +9,10 @@
 package['build_args'] = {
     'TORCH_CUDA_ARCH_LIST': '8.7',
 }
+
+dev_package = copy.deepcopy(package)
+
+dev_package['name'] = 'awq:dev'
+dev_package['build_args']['AWQ_BRANCH'] = 'dev/tinychat_update_0918'
+
+package = [package, dev_package]
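The reason config.py uses copy.deepcopy rather than a shallow copy is that build_args is a nested dict shared by reference, so a shallow copy would let the dev variant's AWQ_BRANCH leak back into the base package. A standalone sketch of that difference, with placeholder dict contents rather than the actual jetson_containers package definition:

# Standalone illustration of deepcopy for package variants.
# The dict below is a placeholder, not the real jetson_containers package.
import copy

package = {'name': 'awq', 'build_args': {'TORCH_CUDA_ARCH_LIST': '8.7'}}

shallow = dict(package)                             # shallow copy shares build_args
shallow['build_args']['AWQ_BRANCH'] = 'dev/branch'  # mutates package too
assert 'AWQ_BRANCH' in package['build_args']

del package['build_args']['AWQ_BRANCH']             # undo the accidental edit

deep = copy.deepcopy(package)                       # independent nested dicts
deep['build_args']['AWQ_BRANCH'] = 'dev/branch'
assert 'AWQ_BRANCH' not in package['build_args']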
