diff --git a/guidance/library/_capture.py b/guidance/library/_capture.py index afeff6eed..1efac56c4 100644 --- a/guidance/library/_capture.py +++ b/guidance/library/_capture.py @@ -1,6 +1,9 @@ from .._guidance import guidance from .._grammar import capture as grammar_capture, GrammarFunction +# Adapted from active_role_end in _model.py, functionality should be shared probably? +import re +format_pattern = re.compile(r"<\|\|_.*?_\|\|>", flags=re.DOTALL) @guidance(stateless=lambda *args, **kwargs: isinstance(args[0], GrammarFunction)) def capture(lm, value, name): @@ -9,4 +12,11 @@ def capture(lm, value, name): else: start_len = len(lm) lm += value - return lm.set(name, str(lm)[start_len:]) + # Adapted from active_role_end in _model.py + parts = "" + for _, role_end_str in lm.opened_blocks.values(): + role_end_str = format_pattern.sub("", role_end_str) + if len(role_end_str) > 0 and not re.fullmatch(r"\s+", role_end_str): + parts += role_end_str + + return lm.set(name, str(lm)[start_len-len(parts):].removesuffix(parts)) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 94b697954..fc91a8a0f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,9 +16,7 @@ AVAILABLE_MODELS = { "gpt2cpu": dict(name="transformers:gpt2", kwargs=dict()), - "phi2cpu": dict( - name="transformers:microsoft/phi-2", kwargs={"trust_remote_code": True} - ), + "phi2cpu": dict(name="transformers:microsoft/phi-2", kwargs={"trust_remote_code": True}), "azure_guidance": dict( name="azure_guidance:", kwargs={}, @@ -41,9 +39,7 @@ name="huggingface_hubllama:TheBloke/Llama-2-7B-GGUF:llama-2-7b.Q5_K_M.gguf", kwargs={"verbose": True, "n_ctx": 4096}, ), - "transformers_mistral_7b": dict( - name="transformers:mistralai/Mistral-7B-v0.1", kwargs=dict() - ), + "transformers_mistral_7b": dict(name="transformers:mistralai/Mistral-7B-v0.1", kwargs=dict()), "hfllama_mistral_7b": dict( name="huggingface_hubllama:TheBloke/Mistral-7B-Instruct-v0.2-GGUF:mistral-7b-instruct-v0.2.Q8_0.gguf", kwargs={"verbose": True}, @@ -101,6 +97,19 @@ def selected_model(selected_model_info: str) -> models.Model: return model +@pytest.fixture(scope="module") +def model_with_role_tags(selected_model, selected_model_name): + if selected_model_name in [ + "transformers_phi3cpu_mini_4k_instruct", + "transformers_llama3cpu_8b", + "hfllama_phi3cpu_mini_4k_instruct", + "hfllama_mistral_7b", + ]: + return selected_model + else: + pytest.skip("Requires a model that supports role tags!") + + @pytest.fixture(scope="function") def rate_limiter() -> int: """Limit test execution rate diff --git a/tests/model_integration/library/test_capture.py b/tests/model_integration/library/test_capture.py new file mode 100644 index 000000000..13a4ad5bc --- /dev/null +++ b/tests/model_integration/library/test_capture.py @@ -0,0 +1,8 @@ +import guidance + +def test_capture_within_role(model_with_role_tags: guidance.models.Model): + lm = model_with_role_tags + test_text = "This is some text in a role." + with guidance.user(): + lm += guidance.capture(test_text, "test") + assert lm["test"] == test_text diff --git a/tests/unit/library/test_capture.py b/tests/unit/library/test_capture.py index 0c0f70179..a6e659de6 100644 --- a/tests/unit/library/test_capture.py +++ b/tests/unit/library/test_capture.py @@ -25,11 +25,18 @@ def raw_fn(lm): elif lm["state"] == "2": lm += select(["5", "6"], name="state_2") return lm - + lm_nocap = lm + "the beginning|" + raw_fn() + "|the end" - lm_cap_arg = lm + "the beginning|" + capture("" + raw_fn() + "" , "cap_arg") + "|the end" - lm_cap_kwarg = lm + "the beginning|" + capture("" + raw_fn() + "", name="cap_kwarg") + "|the end" - + lm_cap_arg = ( + lm + "the beginning|" + capture("" + raw_fn() + "", "cap_arg") + "|the end" + ) + lm_cap_kwarg = ( + lm + + "the beginning|" + + capture("" + raw_fn() + "", name="cap_kwarg") + + "|the end" + ) + # Bunch of random tests assert "state_1" in lm_nocap or "state_2" in lm_nocap assert "cap_arg" in lm_cap_arg @@ -42,4 +49,4 @@ def raw_fn(lm): assert str(lm_nocap).endswith("|the end") assert str(lm_cap_arg).endswith("|the end") - assert str(lm_cap_kwarg).endswith("|the end") \ No newline at end of file + assert str(lm_cap_kwarg).endswith("|the end")