diff --git a/README.md b/README.md index 50402a0..55301b3 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ from prompts import template @template def few_shots(instructions, examples, question): - """{{ instructions }} + return """{{ instructions }} Examples -------- diff --git a/docs/reference/template.md b/docs/reference/template.md index 14baaf1..0a6fd2c 100644 --- a/docs/reference/template.md +++ b/docs/reference/template.md @@ -38,7 +38,7 @@ will pass to the prompt function. @prompts.template def greetings(name, question): - """Hello, {{ name }}! + return """Hello, {{ name }}! {{ question }} """ @@ -62,7 +62,7 @@ If a variable is missing in the function's arguments, Jinja2 will throw an `Unde @prompts.template def greetings(name): - """Hello, {{ surname }}!""" + return """Hello, {{ surname }}!""" prompt = greetings("user") ``` @@ -94,7 +94,7 @@ Prompt functions are functions, and thus can be imported from other modules: @prompts.template def greetings(name, question): - """Hello, {{ name }}! + return """Hello, {{ name }}! {{ question }} """ ``` @@ -128,7 +128,7 @@ keys `question` and `answer` to the prompt function: @prompts.template def few_shots(instructions, examples, question): - """{{ instructions }} + return """{{ instructions }} Examples -------- @@ -207,12 +207,12 @@ below does not matter for formatting: @prompts.template def prompt1(): - """My prompt + return """My prompt """ @prompts.template def prompt2(): - """ + return """ My prompt """ @@ -236,20 +236,20 @@ Indentation is relative to the second line of the docstring, and leading spaces @prompts.template def example1(): - """First line + return """First line Second line """ @prompts.template def example2(): - """ + return """ Second line Third line """ @prompts.template def example3(): - """ + return """ Second line Third line """ @@ -285,7 +285,7 @@ You can use the backslash `\` to break a long line of text. It will render as a @prompts.template def example(): - """ + return """ Break in \ several lines \ But respect the indentation diff --git a/prompts/templates.py b/prompts/templates.py index a7f5a8c..56188f8 100644 --- a/prompts/templates.py +++ b/prompts/templates.py @@ -3,7 +3,7 @@ import warnings from dataclasses import dataclass, field from functools import lru_cache -from typing import Callable, Dict, Hashable, Optional, cast +from typing import Callable, Dict, Hashable, Optional from jinja2 import Environment, StrictUndefined @@ -15,7 +15,7 @@ class Template: """Represents a prompt template. A prompt template is a callable that, given a Jinja2 template and a set of values, - renders the template using those values. It is recommended to instantiate `Temaplate` + renders the template using those values. It is recommended to instantiate `Template` using the `template` decorator, which extracts the template from the function's docstring and its variables from the function's signature. @@ -40,11 +40,15 @@ class Template: """ - template: str signature: inspect.Signature + fn: Callable model: Optional[str] = None registry: Dict[str, Callable] = field(default_factory=dict) + def __init__(self, fn: Callable): + self.fn = fn + self.signature = inspect.signature(fn) + def __call__(self, *args, **kwargs) -> str: """Render and return the template. @@ -55,7 +59,10 @@ def __call__(self, *args, **kwargs) -> str: """ bound_arguments = self.signature.bind(*args, **kwargs) bound_arguments.apply_defaults() - return render(self.template, self.model, **bound_arguments.arguments) + + template = self.fn(**bound_arguments.arguments) + + return render(template, self.model, **bound_arguments.arguments) def __str__(self): return self.template @@ -104,11 +111,11 @@ def template(fn: Callable) -> Template: manipulation by providing some degree of encapsulation. It uses the `render` function internally to render templates. - >>> import outlines + >>> import prompts >>> - >>> @outlines.prompt + >>> @prompts.template >>> def build_prompt(question): - ... "I have a ${question}" + ... return "I have a {{question}}" ... >>> prompt = build_prompt("How are you?") @@ -116,12 +123,11 @@ def template(fn: Callable) -> Template: are set when the agent is initialized and never modified later. In this situation we can partially apply the prompt function at initialization. - >>> import outlines - >>> import functools as ft + >>> import prompts ... - >>> @outlines.prompt + >>> @prompts.template ... def solve_task(name: str, objective: str, task: str): - ... '''Your name is {{name}}. + ... return '''Your name is {{name}}. .. Your overall objective is to {{objective}}. ... Please solve the following task: {{task}}''' ... @@ -132,17 +138,7 @@ def template(fn: Callable) -> Template: A `Prompt` callable class which will render the template when called. """ - signature = inspect.signature(fn) - - # The docstring contains the template that will be rendered to be used - # as a prompt to the language model. - docstring = fn.__doc__ - if docstring is None: - raise TypeError("Could not find a template in the function's docstring.") - - template = cast(str, docstring) - - return Template(template, signature) + return Template(fn) @lru_cache diff --git a/tests/test_templates.py b/tests/test_templates.py index 8bd295d..4784134 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -129,9 +129,8 @@ def test_render_jinja(): def test_prompt_basic(): @prompts.template def test_tpl(variable): - """{{variable}} test""" + return """{{variable}} test""" - assert test_tpl.template == "{{variable}} test" assert list(test_tpl.signature.parameters.keys()) == ["variable"] with pytest.raises(TypeError): @@ -145,7 +144,7 @@ def test_tpl(variable): @prompts.template def test_single_quote_tpl(variable): - "${variable} test" + return "{{variable}} test" p = test_tpl("test") assert p == "test test" @@ -154,9 +153,8 @@ def test_single_quote_tpl(variable): def test_prompt_kwargs(): @prompts.template def test_kwarg_tpl(var, other_var="other"): - """{{var}} and {{other_var}}""" + return """{{var}} and {{other_var}}""" - assert test_kwarg_tpl.template == "{{var}} and {{other_var}}" assert list(test_kwarg_tpl.signature.parameters.keys()) == ["var", "other_var"] p = test_kwarg_tpl("test") @@ -169,30 +167,16 @@ def test_kwarg_tpl(var, other_var="other"): assert p == "test and test" -def test_no_prompt(): - with pytest.raises(TypeError, match="template"): - - @prompts.template - def test_empty(variable): - pass - - with pytest.raises(TypeError, match="template"): - - @prompts.template - def test_only_code(variable): - return variable - - @pytest.mark.filterwarnings("ignore: The model") def test_dispatch(): @prompts.template def simple_prompt(query: str): - """{{ query }}""" + return """{{ query }}""" @simple_prompt.register("provider/name") def simple_prompt_name(query: str): - """name: {{ query }}""" + return """name: {{ query }}""" assert list(simple_prompt.registry.keys()) == ["provider/name"] assert callable(simple_prompt) @@ -214,7 +198,7 @@ def test_special_tokens(): @prompts.template def simple_prompt(query: str): - """{{ bos + query + eos }}""" + return """{{ bos + query + eos }}""" assert simple_prompt("test") == "test" assert simple_prompt["openai-community/gpt2"]("test") == "test<|endoftext|>" @@ -225,7 +209,7 @@ def test_warn(): @prompts.template def simple_prompt(): - """test""" + return """test""" with pytest.warns(UserWarning, match="not present in the special token"): simple_prompt["non-existent-model"]()