diff --git a/token_benchmark_ray.py b/token_benchmark_ray.py
index 63216b1..c0380f4 100644
--- a/token_benchmark_ray.py
+++ b/token_benchmark_ray.py
@@ -24,10 +24,11 @@
 )
 from tqdm import tqdm
 
-from transformers import LlamaTokenizerFast
+from transformers import AutoTokenizer, LlamaTokenizerFast
 
 def get_token_throughput_latencies(
     model: str,
+    tokenizer_name_or_path: str,
     mean_input_tokens: int,
     stddev_input_tokens: int,
     mean_output_tokens: int,
@@ -42,6 +43,7 @@ def get_token_throughput_latencies(
 
     Args:
         model: The name of the model to query.
+        tokenizer_name_or_path: The Hugging Face name or local path of the tokenizer to use for tokenizing prompts and completions.
         mean_input_tokens: The mean number of tokens to send in the prompt for the request.
         stddev_input_tokens: The standard deviation of the number of tokens to send in the prompt for the request.
         mean_output_tokens: The mean number of tokens to generate per request.
@@ -60,9 +62,12 @@ def get_token_throughput_latencies(
     """
     random.seed(11111)
 
-    tokenizer = LlamaTokenizerFast.from_pretrained(
-        "hf-internal-testing/llama-tokenizer"
-    )
+    if tokenizer_name_or_path:
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+    else:
+        tokenizer = LlamaTokenizerFast.from_pretrained(
+            "hf-internal-testing/llama-tokenizer"
+        )
     get_token_length = lambda text: len(tokenizer.encode(text))
 
     if not additional_sampling_params:
@@ -282,6 +287,7 @@ def flatten(item):
 def run_token_benchmark(
     llm_api: str,
     model: str,
+    tokenizer_name_or_path: str,
     test_timeout_s: int,
     max_num_completed_requests: int,
     num_concurrent_requests: int,
@@ -297,6 +303,7 @@ def run_token_benchmark(
     Args:
         llm_api: The name of the llm api to use.
         model: The name of the model to query.
+        tokenizer_name_or_path: The Hugging Face name or local path of the tokenizer to use for tokenizing prompts and completions.
         max_num_completed_requests: The number of requests to complete before finishing the test.
         test_timeout_s: The amount of time to run the test for before reporting results.
         num_concurrent_requests: The number of concurrent requests to make. Increase
@@ -318,6 +325,7 @@
 
     summary, individual_responses = get_token_throughput_latencies(
         model=model,
+        tokenizer_name_or_path=tokenizer_name_or_path,
         llm_api=llm_api,
         test_timeout_s=test_timeout_s,
         max_num_completed_requests=max_num_completed_requests,
@@ -368,6 +376,9 @@ def run_token_benchmark(
 args.add_argument(
     "--model", type=str, required=True, help="The model to use for this load test."
 )
+args.add_argument(
+    "--tokenizer", type=str, required=False, default="", help="The tokenizer to use for this load test."
+)
 args.add_argument(
     "--mean-input-tokens",
     type=int,
@@ -478,6 +489,7 @@
     run_token_benchmark(
         llm_api=args.llm_api,
         model=args.model,
+        tokenizer_name_or_path=args.tokenizer,
         test_timeout_s=args.timeout,
         max_num_completed_requests=args.max_num_completed_requests,
         mean_input_tokens=args.mean_input_tokens,
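A quick way to sanity-check the fallback this patch introduces, outside the benchmark harness: the sketch below reproduces the new tokenizer-selection logic in isolation (the `load_tokenizer` wrapper name is mine, not part of the patch, which inlines this logic in `get_token_throughput_latencies`).

```python
# Sketch of the tokenizer selection added by this patch:
# a user-supplied tokenizer wins; an empty string keeps the pre-patch default.
from transformers import AutoTokenizer, LlamaTokenizerFast

def load_tokenizer(tokenizer_name_or_path: str = ""):
    # Hypothetical helper; the patch inlines this in get_token_throughput_latencies.
    if tokenizer_name_or_path:
        # Accepts a Hugging Face repo id or a local directory path.
        return AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    # Backward-compatible default preserved from the original code.
    return LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")

tokenizer = load_tokenizer()  # "" -> Llama default, matching old behavior
get_token_length = lambda text: len(tokenizer.encode(text))
print(get_token_length("The quick brown fox"))
```

Because `--tokenizer` defaults to an empty string, existing invocations of `token_benchmark_ray.py` behave exactly as before; only callers who pass the new flag take the `AutoTokenizer` path.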