From 8ccc783e0a278b27446eb09ee1c1f1f13964ab24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 26 Sep 2024 16:42:54 +0200 Subject: [PATCH] Benchmark template rendering times A concern is that rendering the template this way could be much slower, assuming Jinja2 makes extensive use of caching. We thus write a comparative benchmark between using Jinja2 templates or simply concatenating strings. --- pyproject.toml | 2 +- tests/test_templates.py | 42 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b7a049f..326a183 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ packages = ["prompts"] write_to = "prompts/_version.py" [project.optional-dependencies] -test = ["pre-commit", "pytest"] +test = ["pre-commit", "pytest", "pytest-benchmark"] docs = [ "mkdocs", "mkdocs-material", diff --git a/tests/test_templates.py b/tests/test_templates.py index 3032b36..9e4e8f9 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,3 +1,6 @@ +import random +import string + import pytest import prompts @@ -192,3 +195,42 @@ def simple_prompt_name(query: str): assert simple_prompt("test") == "test" assert simple_prompt["gpt2"]("test") == "test" assert simple_prompt["provider/name"]("test") == "name: test" + + +def test_benchmark_template_render(benchmark): + + @prompts.template + def test_tpl(var0, var1): + prompt = var0 + return prompt + """{{var1}} test""" + + def setup(): + """We generate random strings to make sure we don't hit any potential cache.""" + length = 10 + var0 = "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(length) + ) + var1 = "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(length) + ) + return (var0, var1), {} + + benchmark.pedantic(test_tpl, setup=setup, rounds=500) + + +def test_benchmark_template_function(benchmark): + + def test_tpl(var0, var1): + return var0 + f"{var1} test" + + def setup(): + length = 10 + var0 = "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(length) + ) + var1 = "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(length) + ) + return (var0, var1), {} + + benchmark.pedantic(test_tpl, setup=setup, rounds=500)