Skip to content

Commit 6035e86

Browse files
jantrienesrlouf
authored andcommitted
Reuse jinja environment for a prompt
1 parent 538e714 commit 6035e86

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

outlines/prompts.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Prompt:
2424

2525
def __post_init__(self):
2626
self.parameters: List[str] = list(self.signature.parameters.keys())
27+
self.jinja_environment = create_jinja_template(self.template)
2728

2829
def __call__(self, *args, **kwargs) -> str:
2930
"""Render and return the template.
@@ -35,7 +36,7 @@ def __call__(self, *args, **kwargs) -> str:
3536
"""
3637
bound_arguments = self.signature.bind(*args, **kwargs)
3738
bound_arguments.apply_defaults()
38-
return render(self.template, **bound_arguments.arguments)
39+
return self.jinja_environment.render(**bound_arguments.arguments)
3940

4041
def __str__(self):
4142
return self.template
@@ -182,6 +183,11 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
182183
A string that contains the rendered template.
183184
184185
"""
186+
jinja_template = create_jinja_template(template)
187+
return jinja_template.render(**values)
188+
189+
190+
def create_jinja_template(template: str):
185191
# Dedent, and remove extra linebreak
186192
cleaned_template = inspect.cleandoc(template)
187193

@@ -210,8 +216,7 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
210216
env.filters["args"] = get_fn_args
211217

212218
jinja_template = env.from_string(cleaned_template)
213-
214-
return jinja_template.render(**values)
219+
return jinja_template
215220

216221

217222
def get_fn_name(fn: Callable):

0 commit comments

Comments
 (0)