diff --git a/guidance/library/_capture.py b/guidance/library/_capture.py
index afeff6eed..1efac56c4 100644
--- a/guidance/library/_capture.py
+++ b/guidance/library/_capture.py
@@ -1,6 +1,9 @@
from .._guidance import guidance
from .._grammar import capture as grammar_capture, GrammarFunction
+# Adapted from active_role_end in _model.py, functionality should be shared probably?
+import re
+format_pattern = re.compile(r"<\|\|_.*?_\|\|>", flags=re.DOTALL)
@guidance(stateless=lambda *args, **kwargs: isinstance(args[0], GrammarFunction))
def capture(lm, value, name):
@@ -9,4 +12,11 @@ def capture(lm, value, name):
else:
start_len = len(lm)
lm += value
- return lm.set(name, str(lm)[start_len:])
+ # Adapted from active_role_end in _model.py
+ parts = ""
+ for _, role_end_str in lm.opened_blocks.values():
+ role_end_str = format_pattern.sub("", role_end_str)
+ if len(role_end_str) > 0 and not re.fullmatch(r"\s+", role_end_str):
+ parts += role_end_str
+
+ return lm.set(name, str(lm)[start_len-len(parts):].removesuffix(parts))
\ No newline at end of file
diff --git a/tests/conftest.py b/tests/conftest.py
index 94b697954..fc91a8a0f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -16,9 +16,7 @@
AVAILABLE_MODELS = {
"gpt2cpu": dict(name="transformers:gpt2", kwargs=dict()),
- "phi2cpu": dict(
- name="transformers:microsoft/phi-2", kwargs={"trust_remote_code": True}
- ),
+ "phi2cpu": dict(name="transformers:microsoft/phi-2", kwargs={"trust_remote_code": True}),
"azure_guidance": dict(
name="azure_guidance:",
kwargs={},
@@ -41,9 +39,7 @@
name="huggingface_hubllama:TheBloke/Llama-2-7B-GGUF:llama-2-7b.Q5_K_M.gguf",
kwargs={"verbose": True, "n_ctx": 4096},
),
- "transformers_mistral_7b": dict(
- name="transformers:mistralai/Mistral-7B-v0.1", kwargs=dict()
- ),
+ "transformers_mistral_7b": dict(name="transformers:mistralai/Mistral-7B-v0.1", kwargs=dict()),
"hfllama_mistral_7b": dict(
name="huggingface_hubllama:TheBloke/Mistral-7B-Instruct-v0.2-GGUF:mistral-7b-instruct-v0.2.Q8_0.gguf",
kwargs={"verbose": True},
@@ -101,6 +97,19 @@ def selected_model(selected_model_info: str) -> models.Model:
return model
+@pytest.fixture(scope="module")
+def model_with_role_tags(selected_model, selected_model_name):
+ if selected_model_name in [
+ "transformers_phi3cpu_mini_4k_instruct",
+ "transformers_llama3cpu_8b",
+ "hfllama_phi3cpu_mini_4k_instruct",
+ "hfllama_mistral_7b",
+ ]:
+ return selected_model
+ else:
+ pytest.skip("Requires a model that supports role tags!")
+
+
@pytest.fixture(scope="function")
def rate_limiter() -> int:
"""Limit test execution rate
diff --git a/tests/model_integration/library/test_capture.py b/tests/model_integration/library/test_capture.py
new file mode 100644
index 000000000..13a4ad5bc
--- /dev/null
+++ b/tests/model_integration/library/test_capture.py
@@ -0,0 +1,8 @@
+import guidance
+
+def test_capture_within_role(model_with_role_tags: guidance.models.Model):
+ lm = model_with_role_tags
+ test_text = "This is some text in a role."
+ with guidance.user():
+ lm += guidance.capture(test_text, "test")
+ assert lm["test"] == test_text
diff --git a/tests/unit/library/test_capture.py b/tests/unit/library/test_capture.py
index 0c0f70179..a6e659de6 100644
--- a/tests/unit/library/test_capture.py
+++ b/tests/unit/library/test_capture.py
@@ -25,11 +25,18 @@ def raw_fn(lm):
elif lm["state"] == "2":
lm += select(["5", "6"], name="state_2")
return lm
-
+
lm_nocap = lm + "the beginning|" + raw_fn() + "|the end"
- lm_cap_arg = lm + "the beginning|" + capture("" + raw_fn() + "" , "cap_arg") + "|the end"
- lm_cap_kwarg = lm + "the beginning|" + capture("" + raw_fn() + "", name="cap_kwarg") + "|the end"
-
+ lm_cap_arg = (
+ lm + "the beginning|" + capture("" + raw_fn() + "", "cap_arg") + "|the end"
+ )
+ lm_cap_kwarg = (
+ lm
+ + "the beginning|"
+ + capture("" + raw_fn() + "", name="cap_kwarg")
+ + "|the end"
+ )
+
# Bunch of random tests
assert "state_1" in lm_nocap or "state_2" in lm_nocap
assert "cap_arg" in lm_cap_arg
@@ -42,4 +49,4 @@ def raw_fn(lm):
assert str(lm_nocap).endswith("|the end")
assert str(lm_cap_arg).endswith("|the end")
- assert str(lm_cap_kwarg).endswith("|the end")
\ No newline at end of file
+ assert str(lm_cap_kwarg).endswith("|the end")