Skip to content

Commit

Permalink
Benchmark template rendering times
Browse files Browse the repository at this point in the history
A concern is that rendering the template this way could be much slower,
assuming Jinja2 makes extensive use of caching. We thus write a
comparative benchmark between using Jinja2 templates or simply
concatenating strings.
  • Loading branch information
rlouf committed Sep 26, 2024
1 parent 8be6a60 commit 8ccc783
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ packages = ["prompts"]
write_to = "prompts/_version.py"

[project.optional-dependencies]
test = ["pre-commit", "pytest"]
test = ["pre-commit", "pytest", "pytest-benchmark"]
docs = [
"mkdocs",
"mkdocs-material",
Expand Down
42 changes: 42 additions & 0 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import random
import string

import pytest

import prompts
Expand Down Expand Up @@ -192,3 +195,42 @@ def simple_prompt_name(query: str):
assert simple_prompt("test") == "test"
assert simple_prompt["gpt2"]("test") == "test"
assert simple_prompt["provider/name"]("test") == "name: test"


def test_benchmark_template_render(benchmark):

@prompts.template
def test_tpl(var0, var1):
prompt = var0
return prompt + """{{var1}} test"""

def setup():
"""We generate random strings to make sure we don't hit any potential cache."""
length = 10
var0 = "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(length)
)
var1 = "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(length)
)
return (var0, var1), {}

benchmark.pedantic(test_tpl, setup=setup, rounds=500)


def test_benchmark_template_function(benchmark):

def test_tpl(var0, var1):
return var0 + f"{var1} test"

def setup():
length = 10
var0 = "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(length)
)
var1 = "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(length)
)
return (var0, var1), {}

benchmark.pedantic(test_tpl, setup=setup, rounds=500)

0 comments on commit 8ccc783

Please sign in to comment.