@@ -24,6 +24,7 @@ class Prompt:
24
24
25
25
def __post_init__ (self ):
26
26
self .parameters : List [str ] = list (self .signature .parameters .keys ())
27
+ self .jinja_environment = create_jinja_template (self .template )
27
28
28
29
def __call__ (self , * args , ** kwargs ) -> str :
29
30
"""Render and return the template.
@@ -35,7 +36,7 @@ def __call__(self, *args, **kwargs) -> str:
35
36
"""
36
37
bound_arguments = self .signature .bind (* args , ** kwargs )
37
38
bound_arguments .apply_defaults ()
38
- return render ( self .template , ** bound_arguments .arguments )
39
+ return self .jinja_environment . render ( ** bound_arguments .arguments )
39
40
40
41
def __str__ (self ):
41
42
return self .template
@@ -182,6 +183,11 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
182
183
A string that contains the rendered template.
183
184
184
185
"""
186
+ jinja_template = create_jinja_template (template )
187
+ return jinja_template .render (** values )
188
+
189
+
190
+ def create_jinja_template (template : str ):
185
191
# Dedent, and remove extra linebreak
186
192
cleaned_template = inspect .cleandoc (template )
187
193
@@ -210,8 +216,7 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
210
216
env .filters ["args" ] = get_fn_args
211
217
212
218
jinja_template = env .from_string (cleaned_template )
213
-
214
- return jinja_template .render (** values )
219
+ return jinja_template
215
220
216
221
217
222
def get_fn_name (fn : Callable ):
0 commit comments