Update benchmark e2e test to use chat template (#1236)
Currently the benchmark e2e script does not use a chat template; this change updates it to use one.

---------

Co-authored-by: Kunal Vaishnavi <[email protected]>
apsonawane and kunal-vaishnavi authored Feb 13, 2025
1 parent 09c1c0d commit 131a335
Showing 1 changed file with 53 additions and 38 deletions.
91 changes: 53 additions & 38 deletions benchmark/python/benchmark_e2e.py
@@ -71,7 +71,8 @@ def generate_prompt(model, tokenizer, prompt_length, use_graph_capture) -> str:
    prompt = "a"
    tokens = tokenizer.encode(prompt)
    params=og.GeneratorParams(model)
    params.set_search_options(max_length=prompt_length, min_length=prompt_length)
    max_length_to_use = prompt_length + len(tokens)
    params.set_search_options(max_length=max_length_to_use, min_length=prompt_length)

    if use_graph_capture:
        params.try_graph_capture_with_max_batch_size(1)
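For context on the hunk above: generate_prompt now raises max_length so it also accounts for the tokens of the seed prompt ("a") rather than counting them against the requested prompt length. A minimal sketch with hypothetical values:

# Minimal sketch (hypothetical token values): the new max_length reserves room
# for the seed tokens returned by tokenizer.encode(prompt), so generation can
# still produce prompt_length tokens instead of stopping early.
prompt_length = 16
seed_tokens = [101, 2023, 102]  # stand-in for tokenizer.encode("a")
max_length_to_use = prompt_length + len(seed_tokens)
assert max_length_to_use == 19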
@@ -85,10 +86,9 @@ def generate_prompt(model, tokenizer, prompt_length, use_graph_capture) -> str:
# Use prompt length to get pre-defined prompt
def get_prompt_by_length(prompt_length):
    json_path = "prompts.json"
    with open(json_path) as prompts_file:
        content = prompts_file.read()
        data = json.load(content)
        return data[f"{prompt_length}"]
    with open(json_path, "r") as file:
        data = json.load(file)
        return data[f"{prompt_length}"]

def get_target_pip_package_version(target_pip_package_name_list):
    # get package name and version
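The get_prompt_by_length rewrite fixes a real bug: json.load expects a file-like object while json.loads expects a string, so passing prompts_file.read() to json.load raises an AttributeError. A minimal sketch of the corrected pattern; the "16" key is a hypothetical prompt-length entry in prompts.json:

import json

# json.load reads from a file object; json.loads parses a str.
with open("prompts.json", "r") as file:   # assumes prompts.json sits next to the script
    data = json.load(file)                # the old code handed it a str and failed
print(data["16"])                         # hypothetical entry keyed by prompt length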
@@ -110,18 +110,6 @@ def get_target_pip_package_version(target_pip_package_name_list):
    pkg_version = installed_packages_list[0].split("==")[1]
    return pkg_name, pkg_version

def get_model_info_from_genai_config(model_input_folder):
    genai_config_file_path = os.path.join(model_input_folder, "genai_config.json")
    genai_config_file = open(genai_config_file_path)
    genai_config = json.load(genai_config_file)
    model_info = {}
    model_info["execution_provider"] = "cpu"
    provider_options = genai_config["model"]["decoder"]["session_options"]["provider_options"]
    if len(provider_options) > 0 and len(provider_options[0].keys()) > 0:
        model_info["execution_provider"] = list(genai_config["model"]["decoder"]["session_options"]["provider_options"][0].keys())[0]
    genai_config_file.close()
    return model_info

def save_results(args, results, filename, print_memory_usage=False):
    import pandas as pd
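The deleted get_model_info_from_genai_config helper inferred the execution provider by parsing genai_config.json; later hunks replace that lookup with a new required --execution_provider command-line argument. A hedged sketch of what the removed helper computed, using a context manager instead of a bare open/close (function name is illustrative):

import json, os

def read_execution_provider(model_input_folder):
    # Returns the first provider key from provider_options, or "cpu" if none is set.
    with open(os.path.join(model_input_folder, "genai_config.json")) as f:
        options = json.load(f)["model"]["decoder"]["session_options"]["provider_options"]
    return next(iter(options[0]), "cpu") if options else "cpu"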

@@ -155,11 +143,10 @@ def save_results(args, results, filename, print_memory_usage=False):
    # df = df.transpose() # This line swaps the rows and columns

    genai_package_name, genai_package_version = get_target_pip_package_version(["onnxruntime-genai", "onnxruntime-genai-cuda", "onnxruntime-genai-directml"])
    model_info = get_model_info_from_genai_config(args.input_folder)

    records = []
    for _, row in df.iterrows():
        record = BenchmarkRecord(args.model_name, args.precision, "onnxruntime-genai", model_info["execution_provider"], genai_package_name, genai_package_version )
        record = BenchmarkRecord(args.model_name, args.precision, "onnxruntime-genai", args.execution_provider, genai_package_name, genai_package_version )
        record.config.batch_size = row["Batch Size"]
        record.config.customized["prompt_length"] = row["Prompt Length"]
        record.config.customized["tokens_generated"] = row["Tokens Generated"]
@@ -184,7 +171,6 @@ records.append(record)
        records.append(record)

    # df.to_csv(filename, header=True, index=False)
    BenchmarkRecord.save_as_csv(filename, records)
    BenchmarkRecord.save_as_json(filename.replace(".csv", ".json"), records)
    print(f"Results saved in {filename}!")

@@ -227,33 +213,60 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
    temperature = 1.0

    # Get tokenizer, and model
    if args.verbose: print("Getting config")
    config = og.Config(f'{args.input_folder}')
    config.clear_providers()
    if args.verbose: print("Loading model... ")
    model=og.Model(f'{args.input_folder}')
    if args.execution_provider != "cpu":
        if args.verbose: print(f"Setting model to {args.execution_provider}")
        config.append_provider(args.execution_provider)
    model = og.Model(config)
    if args.verbose: print("Model loaded")
    tokenizer = og.Tokenizer(model)


    # Get model type
    model_type = None
    if hasattr(model, "type"):
        model_type = model.type
    else:
        with open(os.path.join(args.input_folder, "genai_config.json"), "r") as f:
            genai_config = json.load(f)
            model_type = genai_config["model"]["type"]

    # Set chat template
    if args.chat_template:
        if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
            raise ValueError("Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
    else:
        if model_type.startswith("phi2") or model_type.startswith("phi3"):
            args.chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
        elif model_type.startswith("phi4"):
            args.chat_template = '<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>'
        elif model_type.startswith("llama"):
            args.chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
        elif model_type.startswith("llama2"):
            args.chat_template = '<s>{input}'
        elif model_type.startswith("qwen2"):
            args.chat_template = '<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n'
        else:
            raise ValueError(f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template")

    # Generate prompt
    tokens, prompt = None, None
    if args.use_random_tokens:
        # use random tokens instead of generating a prompt using the model and then tokenizing it
        tokens = np.random.randint(100, size=(batch_size, prompt_length))
        prompt = [tokenizer.decode(tokens[0])] * batch_size
        text = [tokenizer.decode(tokens[0])] * batch_size
        prompt = f'{args.chat_template.format(input=text)}'
    elif args.use_prompt_set:
        prompt = [get_prompt_by_length(prompt_length)] * batch_size
        tokens = tokenizer.encode_batch(prompt)

        if len(tokens) > max_length:
            # Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
            tokens = tokens[:, :max_length]
        elif len(tokens) < max_length:
            # Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
            tokens_first_col = tokens[:, 0].unsqueeze(0).T
            for _ in range(max_length - len(tokens)):
                tokens = np.hstack((tokens_first_col, tokens))
        text = [get_prompt_by_length(prompt_length)] * batch_size
        prompt = f'{args.chat_template.format(input=text)}'
        tokens = tokenizer.encode(prompt)
    else:
        prompt = [generate_prompt(model, tokenizer, prompt_length, args.use_graph_capture)] * batch_size
        tokens = tokenizer.encode_batch(prompt)
        text = [generate_prompt(model, tokenizer, prompt_length, args.use_graph_capture)] * batch_size
        prompt = f'{args.chat_template.format(input=text)}'
        tokens = tokenizer.encode(prompt)
    prompt_length = len(tokens)
    max_length = prompt_length + generation_length

    params = og.GeneratorParams(model)
    do_sample = args.top_k > 1 or (args.top_p != 1.0 and args.top_p > 0.0)
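To make the new prompt construction concrete, here is a minimal, self-contained sketch of how the chat template is applied with str.format before tokenization. It uses the phi3-style default template from the hunk above and a hypothetical user string; in the benchmark itself the formatted value is a batch-sized list rather than a single string.

# Minimal sketch of the chat-template step (plain str.format; no og.Model needed here).
chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'   # phi3-style default from this diff
text = "Summarize the benefits of KV caching."               # hypothetical user input
prompt = f'{chat_template.format(input=text)}'
# prompt == '<|user|>\nSummarize the benefits of KV caching. <|end|>\n<|assistant|>'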
@@ -283,7 +296,7 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length

        # Measure tokenization
        tokenize_start_time = time.perf_counter()
        tokens = tokenizer.encode_batch(prompt)
        tokens = tokenizer.encode(prompt)
        tokenize_end_time = time.perf_counter()
        tokenize_times.append(tokenize_end_time - tokenize_start_time)
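The hunk above swaps encode_batch for encode, consistent with the prompt now being a single formatted string. A small hedged sketch of the measurement pattern used here (perf_counter around the call, per-repetition timings collected in a list and averaged later); the repetition count and workload are placeholders:

import time

tokenize_times = []
for _ in range(3):                                # hypothetical repetition count
    start = time.perf_counter()
    _ = [len(word) for word in "a b c".split()]   # stand-in for tokenizer.encode(prompt)
    tokenize_times.append(time.perf_counter() - start)
avg_tokenize_latency = sum(tokenize_times) / len(tokenize_times)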

@@ -437,6 +450,8 @@ def str2strlist(value):
parser.add_argument('-pr', '--precision', type=str, default='fp16', help='Model precision for metrics info')
parser.add_argument('--use_random_tokens', action='store_true', help='Use random tokens instead of generating a prompt')
parser.add_argument('--use_prompt_set', action='store_true', help='Use pre-generated prompt set instead of generating a prompt')
parser.add_argument('--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}')
parser.add_argument('-e', '--execution_provider', type=str, required=True, choices=["cpu", "cuda", "dml"], help='Execution provider to run ONNX model with')
args = parser.parse_args()

# check max_lengths
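The two new arguments tie the change together: --chat_template overrides the per-model defaults (and must contain exactly one {input} slot, per the check in run_benchmark), and the required --execution_provider flag replaces the removed genai_config.json lookup. A tiny sketch of the same brace-count rule the script enforces (helper name is illustrative):

def chat_template_is_valid(chat_template: str) -> bool:
    # Mirrors the check in run_benchmark: exactly one '{' and one '}' (the {input} slot).
    return chat_template.count('{') == 1 and chat_template.count('}') == 1

assert chat_template_is_valid('<|user|>\n{input} <|end|>\n<|assistant|>')
assert not chat_template_is_valid('{input} with a stray {brace}')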
