diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 2baa8f5cd14..71cfc041cfa 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -146,7 +146,8 @@ def generate(  # noqa: C901
 
         generate_time = time.time() - generate_start
         print(f"Prefill time: {prefill_time}")
-        print(f"Generation tok/s: {len(tokens) / generate_time}")
+        num_generated_tokens = len(tokens) - len(prompt_tokens) - 1
+        print(f"Generation tok/s: {num_generated_tokens / generate_time}")
 
         return tokens if echo else tokens[len(prompt_tokens) :]
 
diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py
index 6d5d4730844..e146a8c1a42 100644
--- a/examples/models/llama/runner/native.py
+++ b/examples/models/llama/runner/native.py
@@ -126,6 +126,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Maximum length of the generated response sequence.",
     )
 
+    parser.add_argument(
+        "--cpu_threads",
+        type=int,
+        default=4,
+        help="Number of CPU threads to use for inference.",
+    )
+
     return parser
 
 
@@ -133,6 +140,7 @@ def main() -> None:
     parser = build_args_parser()
     args = parser.parse_args()
    	validate_args(args)
+    portable_lib._unsafe_reset_threadpool(args.cpu_threads)
     runner = NativeLlamaRunner(args)
     generated_tokens = runner.text_completion(
         prompt=args.prompt,
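
Below is a short, hypothetical usage sketch (not part of the patch) of how the two changes fit together at runtime: the CPU threadpool is resized via `portable_lib._unsafe_reset_threadpool` before any inference runs, and decode throughput counts only generated tokens, not the prompt. The import path, the example token ids, and the reading of the patch's `- 1` (a token attributed to prefill rather than decode) are assumptions, not statements from the patch.

```python
import time

# Assumption: portable_lib is the ExecuTorch pybindings module already imported
# by native.py; the exact import path here is an educated guess.
from executorch.extension.pybindings import portable_lib

# Mirror of the new --cpu_threads flag: resize the CPU threadpool once,
# before the model is loaded and any inference runs.
portable_lib._unsafe_reset_threadpool(4)

# Mirror of the corrected throughput math in generation.py: count only tokens
# produced during decode, excluding the prompt (and, per the patch's "- 1",
# presumably the single token emitted during prefill as well).
prompt_tokens = [1, 2, 3]            # illustrative prompt token ids
tokens = [1, 2, 3, 40, 41, 42, 43]   # prompt ids followed by generated ids
generate_start = time.time()
# ... decode loop would run here ...
generate_time = time.time() - generate_start
num_generated_tokens = len(tokens) - len(prompt_tokens) - 1
# max() guard is only to keep this sketch from dividing by zero; the patch
# divides by generate_time directly.
print(f"Generation tok/s: {num_generated_tokens / max(generate_time, 1e-9)}")
```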