Fixing bug 857 (regression from 0.1.14 to 0.1.15) #858

Open
wants to merge 8 commits into main
12 changes: 11 additions & 1 deletion guidance/library/_capture.py
@@ -1,6 +1,9 @@
from .._guidance import guidance
from .._grammar import capture as grammar_capture, GrammarFunction

# Adapted from active_role_end in _model.py; this functionality should probably be shared.
import re
format_pattern = re.compile(r"<\|\|_.*?_\|\|>", flags=re.DOTALL)
Collaborator

Slightly concerned that this appears to be relying on ChatML tags, which not all models use
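For context on this concern, format_pattern only strips guidance's internal <||_..._||> markers from the role-end text; whatever chat-template text remains is what gets accumulated into parts below. A minimal sketch of the regex behavior, using a hypothetical role-end string rather than any real chat template:

import re

format_pattern = re.compile(r"<\|\|_.*?_\|\|>", flags=re.DOTALL)

# Hypothetical role-end text: an internal guidance marker followed by a ChatML-style tag.
role_end = "<||_hidden formatting_||><|im_end|>\n"

# Only the <||_..._||> span is removed; the ChatML-style tag survives.
assert format_pattern.sub("", role_end) == "<|im_end|>\n"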


@guidance(stateless=lambda *args, **kwargs: isinstance(args[0], GrammarFunction))
def capture(lm, value, name):
@@ -9,4 +12,11 @@ def capture(lm, value, name):
else:
start_len = len(lm)
lm += value
return lm.set(name, str(lm)[start_len:])
# Adapted from active_role_end in _model.py
parts = ""
for _, role_end_str in lm.opened_blocks.values():
role_end_str = format_pattern.sub("", role_end_str)
if len(role_end_str) > 0 and not re.fullmatch(r"\s+", role_end_str):
parts += role_end_str

return lm.set(name, str(lm)[start_len-len(parts):].removesuffix(parts))
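To see why the capture needs both the offset shift and the suffix removal, here is a minimal string-only sketch of the arithmetic, under the assumption (which the patch appears to rely on) that both len(lm) and str(lm) include the closing text of any still-open role block; the strings below are illustrative, not real chat-template output:

# Plain strings standing in for the model state; names are illustrative only.
prefix = "<|user|>\n"          # text already in lm before the capture
role_end = "<|end|>\n"         # closing text of the still-open role block
value = "captured text"

start_len = len(prefix + role_end)   # len(lm) before `lm += value`
full = prefix + value + role_end     # str(lm) after `lm += value`
parts = role_end                     # what the loop over lm.opened_blocks builds

old_capture = full[start_len:]                                   # drops the start of value, keeps the suffix
new_capture = full[start_len - len(parts):].removesuffix(parts)  # the fix

assert new_capture == value
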
21 changes: 19 additions & 2 deletions tests/library/test_capture.py
@@ -1,8 +1,18 @@
from guidance import capture, models, one_or_more, select, guidance
from guidance import capture, models, one_or_more, select, guidance, user
import pytest

from ..utils import get_model


@pytest.fixture(scope="module")
def instruct_model(selected_model, selected_model_name):
if selected_model_name in ["transformers_phi3cpu_mini_4k_instruct"]:
Collaborator

Is this the only model for which the test works? I thought that some of the others supported the role tags? Perhaps move the fixture to conftest.py and call it something like model_with_role_tags?

Author

Ah, it appears so! It looks like there are both transformers and non-transformers versions of Phi3 mini 4k instruct, as well as a couple of other instruct versions I recognized.

Good idea to move it to the conftest.py file, though; I can definitely see situations where we could use more tests specifically for models that use roles.

return selected_model
else:
pytest.skip("Requires Phi3 4k Instruct model")



def test_capture():
model = models.Mock()
model += "This is" + capture(select(options=["bad", "quite bad"]), name="my_var")
@@ -44,4 +54,11 @@ def raw_fn(lm):

assert str(lm_nocap).endswith("|the end")
assert str(lm_cap_arg).endswith("|the end")
assert str(lm_cap_kwarg).endswith("|the end")
assert str(lm_cap_kwarg).endswith("|the end")

def test_capture_within_role(instruct_model: models.Model):
lm = instruct_model
test_text = "This is some text in a role."
with user():
lm += capture(test_text, "test")
assert lm["test"] == test_text
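
Following the review suggestion above, one possible shape for the shared fixture if it were moved to conftest.py; the fixture name and the model string come from the discussion, while everything else is an assumption rather than the project's actual conftest.py:

# tests/conftest.py (hypothetical placement per the review discussion)
import pytest


@pytest.fixture(scope="module")
def model_with_role_tags(selected_model, selected_model_name):
    """Return the selected model only if it supports chat role tags."""
    # Placeholder list based on the model used in this PR; other
    # role-capable instruct models could be added here.
    role_capable = ["transformers_phi3cpu_mini_4k_instruct"]
    if selected_model_name in role_capable:
        return selected_model
    pytest.skip("Requires a model that supports role tags")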