From c2a8cfa0ecdf5a1db1a2a3971344305468401fcb Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 16 Jan 2025 16:40:07 -0800 Subject: [PATCH 01/46] add prompt base classes and ElementListPrompt Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 211 ++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 lib/sycamore/sycamore/llms/prompts/prompts.py diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py new file mode 100644 index 000000000..c7ec8ec37 --- /dev/null +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -0,0 +1,211 @@ +from dataclasses import dataclass +from typing import Any, Union, Optional, Callable +import copy + +import pydantic +from sycamore.data.document import Document, Element + + +@dataclass +class RenderedMessage: + """Represents a message per the LLM messages interface - i.e. a role and a content string + + Args: + role: the role of this message. Should be one of "user", "system", "assistant" + content: the content of this message. + """ + + role: str + content: str + + def to_dict(self): + return {"role": self.role, "content": self.content} + + +@dataclass +class RenderedPrompt: + """Represents a prompt to be sent to the LLM per the LLM messages interface + + Args: + messages: the list of messages to be sent to the LLM + response_format: optional output schema, speicified as pydict/json or + a pydantic model. Can only be used (iirc) with modern OpenAI models. + """ + + messages: list[RenderedMessage] + response_format: Union[None, dict[str, Any], pydantic.BaseModel] = None + + def to_dict(self): + res = {"messages": [m.to_dict() for m in self.messages]} + if self.response_format is not None: + res["response_format"] = self.output_structure # type: ignore + return res + + +class SycamorePrompt: + """Base class/API for all Sycamore LLM Prompt objects. Sycamore Prompts + convert sycamore objects (``Document``s, ``Element``s) into ``RenderedPrompts`` + """ + + def render_document(self, doc: Document) -> RenderedPrompt: + """Render this prompt, given this document as context. + Used in llm_map + + Args: + doc: The document to use to populate the prompt + + Returns: + A fully rendered prompt that can be sent to an LLM for inference + """ + raise NotImplementedError(f"render_document is not implemented for {self.__class__.__name__}") + + def render_element(self, elt: Element) -> RenderedPrompt: + """Render this prompt, given this element as context. + Used in llm_map_elements + + Args: + elt: The element to use to populate the prompt + + Returns: + A fully rendered prompt that can be sent to an LLM for inference + """ + raise NotImplementedError(f"render_element is not implemented for {self.__class__.__name__}") + + def render_multiple_documents(self, docs: list[Document]) -> RenderedPrompt: + """Render this prompt, given a list of documents as context. + Used in llm_reduce + + Args: + docs: The list of documents to use to populate the prompt + + Returns: + A fully rendered prompt that can be sent to an LLM for inference""" + raise NotImplementedError(f"render_multiple_documents is not implemented for {self.__class__.__name__}") + + def instead(self, **kwargs) -> "SycamorePrompt": + """Create a new prompt with some fields changed. + + Args: + **kwargs: any keyword arguments will get set as fields in the + resulting prompt + + Returns: + A new SycamorePrompt with updated fields. + + Example: + .. 
code-block:: python + + p = StaticPrompt(system="hello", user="world") + p.render_document(Document()) + # [ + # {"role": "system", "content": "hello"}, + # {"role": "user", "content": "world"} + # ] + p2 = p.instead(user="bob") + p2.render_document(Document()) + # [ + # {"role": "system", "content": "hello"}, + # {"role": "user", "content": "bob"} + # ] + """ + new = copy.deepcopy(self) + new.__dict__.update(kwargs) + return new + + +class ElementListPrompt(SycamorePrompt): + """A prompt with utilities for constructing a list of elements to include + in the rendered prompt. + + Args: + system: The system prompt string. Use {} to reference names that should + be interpolated. Defaults to None + user: The user prompt string. Use {} to reference names that should be + interpolated. Defaults to None + element_select: Function to choose which set of elements to include in + the prompt. If None, defaults to the first ``num_elements`` elements. + element_order: Function to reorder the selected elements. Defaults to + a noop. + element_list_constructor: Function to turn a list of elements into a + string that can be accessed with the interpolation key "{elements}". + Defaults to "ELEMENT 0: {elts[0].text_representation}\n + ELEMENT 1: {elts[1].text_representation}\n + ..." + num_elements: Sets the number of elements to take if ``element_select`` is + unset. Default is 35. + **kwargs: other keyword arguments are stored and can be used as interpolation keys. + + Example: + .. code-block:: python + + prompt = ElementListPrompt( + system = "Hello {name}. This is a prompt about {doc_property_path}" + user = "What do you make of these tables?\nTables:\n{elements}" + element_select = lambda elts: [e for e in elts if e.type == "table"] + element_order = reversed + name = "David Rothschild" + ) + prompt.render_document(doc) + # [ + # {"role": "system", "content": "Hello David Rothschild. This is a prompt about data/mypdf.pdf"}, + # {"role": "user", "content": "What do you make of these tables?\nTables:\n + # ELEMENT 0: \nELEMENT 1: ..."} + # ] + """ + + def __init__( + self, + *, + system: Optional[str] = None, + user: Optional[str] = None, + element_select: Optional[Callable[[list[Element]], list[Element]]] = None, + element_order: Optional[Callable[[list[Element]], list[Element]]] = None, + element_list_constructor: Optional[Callable[[list[Element]], str]] = None, + num_elements: int = 35, + **kwargs, + ): + self.system = system + self.user = user + self.element_select = element_select or (lambda elts: elts[:num_elements]) + self.element_order = element_order or (lambda elts: elts) + self.element_list_constructor = element_list_constructor or ( + lambda elts: "\n".join(f"ELEMENT {i}: {elts[i].text_representation}" for i in range(len(elts))) + ) + self.kwargs = kwargs + super().__init__() + + def _render_element_list_to_string(self, doc: Document): + elts = self.element_select(doc.elements) + elts = self.element_order(elts) + return self.element_list_constructor(elts) + + def render_document(self, doc: Document) -> RenderedPrompt: + """Render this prompt, given this document as context, using python's + ``str.format()`` method. The keys passed into ``format()`` are as follows: + + - self.kwargs: the additional kwargs specified when creating this prompt. + - doc_text: doc.text_representation + - doc_property_: each property name in doc.properties is + prefixed with 'doc_property_'. So if ``doc.properties = {'k1': 0, 'k2': 3}``, + you get ``doc_property_k1 = 0, doc_property_k2 = 3``. 
+ - elements: the element list constructed from doc.elements using ``self.element_select``, + ``self.element_order``, and ``self.element_list_constructor``. + + Args: + doc: The document to use as context for rendering this prompt + + Returns: + A two-message RenderedPrompt containing ``self.system.format()`` and ``self.user.format()`` + using the format keys as specified above. + """ + format_args = self.kwargs + format_args["doc_text"] = doc.text_representation + format_args.update({"doc_property_" + k: v for k, v in doc.properties.items()}) + format_args["elements"] = self._render_element_list_to_string(doc) + + result = RenderedPrompt(messages=[]) + if self.system is not None: + result.messages.append(RenderedMessage(role="system", content=self.system.format(**format_args))) + if self.user is not None: + result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) + return result From 21a115a614133244facdb99bb6ef20dc10427550 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 16 Jan 2025 16:49:31 -0800 Subject: [PATCH 02/46] override .instead in ElementListPrompt to store net-new keys in self.kwargs Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index c7ec8ec37..f4185af51 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -209,3 +209,12 @@ def render_document(self, doc: Document) -> RenderedPrompt: if self.user is not None: result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) return result + + def instead(self, **kwargs) -> "SycamorePrompt": + new = copy.deepcopy(self) + for k in kwargs: + if k in new.__dict__: + new.__dict__[k] = kwargs[k] + else: + new.kwargs[k] = kwargs[k] + return new From f94da80f841104870e770a13de2f1f9912f0a078 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Fri, 17 Jan 2025 15:38:49 -0800 Subject: [PATCH 03/46] add ElementPrompt and StaticPrompt Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 184 +++++++++++++++--- 1 file changed, 161 insertions(+), 23 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index f4185af51..9feb8d27f 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -3,7 +3,9 @@ import copy import pydantic +from PIL import Image from sycamore.data.document import Document, Element +from sycamore.utils.pdf_utils import get_element_image @dataclass @@ -12,11 +14,13 @@ class RenderedMessage: Args: role: the role of this message. Should be one of "user", "system", "assistant" - content: the content of this message. + content: the content of this message, either a python string or a PIL image. + images: optional list of images to include in this message. """ role: str content: str + images: Optional[list[Image.Image]] = None def to_dict(self): return {"role": self.role, "content": self.content} @@ -59,8 +63,8 @@ def render_document(self, doc: Document) -> RenderedPrompt: """ raise NotImplementedError(f"render_document is not implemented for {self.__class__.__name__}") - def render_element(self, elt: Element) -> RenderedPrompt: - """Render this prompt, given this element as context. 
+ def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: + """Render this prompt, given this element and its parent document as context. Used in llm_map_elements Args: @@ -86,10 +90,12 @@ def instead(self, **kwargs) -> "SycamorePrompt": """Create a new prompt with some fields changed. Args: + **kwargs: any keyword arguments will get set as fields in the resulting prompt Returns: + A new SycamorePrompt with updated fields. Example: @@ -98,18 +104,22 @@ def instead(self, **kwargs) -> "SycamorePrompt": p = StaticPrompt(system="hello", user="world") p.render_document(Document()) # [ - # {"role": "system", "content": "hello"}, - # {"role": "user", "content": "world"} - # ] - p2 = p.instead(user="bob") - p2.render_document(Document()) + # {"role": "system", "content": "hello"}, + # {"role": "user", "content": "world"} + # ] + p2 = p.instead(user="bob") + p2.render_document(Document()) # [ # {"role": "system", "content": "hello"}, # {"role": "user", "content": "bob"} # ] """ new = copy.deepcopy(self) - new.__dict__.update(kwargs) + for k, v in kwargs.items(): + if hasattr(new, "kwargs") and k not in new.__dict__: + getattr(new, "kwargs")[k] = v + else: + new.__dict__[k] = v return new @@ -118,6 +128,7 @@ class ElementListPrompt(SycamorePrompt): in the rendered prompt. Args: + system: The system prompt string. Use {} to reference names that should be interpolated. Defaults to None user: The user prompt string. Use {} to reference names that should be @@ -128,8 +139,8 @@ class ElementListPrompt(SycamorePrompt): a noop. element_list_constructor: Function to turn a list of elements into a string that can be accessed with the interpolation key "{elements}". - Defaults to "ELEMENT 0: {elts[0].text_representation}\n - ELEMENT 1: {elts[1].text_representation}\n + Defaults to "ELEMENT 0: {elts[0].text_representation}\\n + ELEMENT 1: {elts[1].text_representation}\\n ..." num_elements: Sets the number of elements to take if ``element_select`` is unset. Default is 35. @@ -140,7 +151,7 @@ class ElementListPrompt(SycamorePrompt): prompt = ElementListPrompt( system = "Hello {name}. This is a prompt about {doc_property_path}" - user = "What do you make of these tables?\nTables:\n{elements}" + user = "What do you make of these tables?\\nTables:\\n{elements}" element_select = lambda elts: [e for e in elts if e.type == "table"] element_order = reversed name = "David Rothschild" @@ -148,8 +159,8 @@ class ElementListPrompt(SycamorePrompt): prompt.render_document(doc) # [ # {"role": "system", "content": "Hello David Rothschild. 
This is a prompt about data/mypdf.pdf"}, - # {"role": "user", "content": "What do you make of these tables?\nTables:\n - # ELEMENT 0: \nELEMENT 1: ..."} + # {"role": "user", "content": "What do you make of these tables?\\nTables:\\n + # ELEMENT 0: \\nELEMENT 1: ..."} # ] """ @@ -164,6 +175,7 @@ def __init__( num_elements: int = 35, **kwargs, ): + super().__init__() self.system = system self.user = user self.element_select = element_select or (lambda elts: elts[:num_elements]) @@ -172,7 +184,6 @@ def __init__( lambda elts: "\n".join(f"ELEMENT {i}: {elts[i].text_representation}" for i in range(len(elts))) ) self.kwargs = kwargs - super().__init__() def _render_element_list_to_string(self, doc: Document): elts = self.element_select(doc.elements) @@ -210,11 +221,138 @@ def render_document(self, doc: Document) -> RenderedPrompt: result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) return result - def instead(self, **kwargs) -> "SycamorePrompt": - new = copy.deepcopy(self) - for k in kwargs: - if k in new.__dict__: - new.__dict__[k] = kwargs[k] - else: - new.kwargs[k] = kwargs[k] - return new + +class ElementPrompt(SycamorePrompt): + """A prompt for rendering an element with utilities for capturing information + from the element's parent document, with a system and user prompt. + + Args: + system: The system prompt string. Use {} to reference names to be interpolated. + Defaults to None + user: The user prompt string. Use {} to reference names to be interpolated. + Defaults to None + include_element_image: Whether to include an image of the element in the rendered user + message. Only works if the parent document is a PDF. Defaults to False (no image) + capture_parent_context: Function to gather context from the element's parent document. + Should return {"key": value} dictionary, which will be made available as interpolation + keys. Defaults to returning {} + **kwargs: other keyword arguments are stored and can be used as interpolation keys + + Example: + .. code-block:: python + + prompt = ElementPrompt( + system = "You know everything there is to know about {custom_kwarg}, {name}", + user = "Summarize the information on page {elt_property_page}. \\nTEXT: {elt_text}", + capture_parent_context = lambda doc, elt: {"custom_kwarg": doc.properties["path"]}, + name = "Frank Sinatra", + ) + prompt.render_element(doc.elements[0], doc) + # [ + # {"role": "system", "content": "You know everything there is to know + # about /path/to/doc.pdf, Frank Sinatra"}, + # {"role": "user", "content": "Summarize the information on page 1. \\nTEXT: "} + # ] + """ + + def __init__( + self, + *, + system: Optional[str] = None, + user: Optional[str] = None, + include_element_image: bool = False, + capture_parent_context: Optional[Callable[[Document, Element], dict[str, Any]]] = None, + **kwargs, + ): + super().__init__() + self.system = system + self.user = user + self.include_element_image = include_element_image + self.capture_parent_context = capture_parent_context or (lambda doc, elt: {}) + self.kwargs = kwargs + + def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: + """Render this prompt for this element; also take the parent document + if there is context in that to account for as well. Rendering is done + using pythons ``str.format()`` method. The keys passed into ``format`` + are as follows: + + - self.kwargs: the additional kwargs specified when creating this prompt. 
+ - self.capture_parent_content(doc, elt): key-value pairs returned by the + context-capturing function. + - elt_text: elt.text_representation (the text representation of the element) + - elt_property_: each property name in elt.properties is + prefixed with 'elt_property_'. So if ``elt.properties = {'k1': 0, 'k2': 3}``, + you get ``elt_property_k1 = 0, elt_property_k2 = 3``. + + Args: + elt: The element used as context for rendering this prompt. + doc: The element's parent document; used to add additional context. + + Returns: + A two-message rendered prompt containing ``self.system.format()`` and + ``self.user.format()`` using the format keys as specified above. + If self.include_element_image is true, crop out the image from the page + of the PDF it's on and attach it to the last message (user message if there + is one, o/w system message). + """ + format_args = self.kwargs + format_args.update(self.capture_parent_context(doc, elt)) + format_args["elt_text"] = elt.text_representation + format_args.update({"elt_property_" + k: v for k, v in elt.properties.items()}) + + result = RenderedPrompt(messages=[]) + if self.system is not None: + result.messages.append(RenderedMessage(role="system", content=self.system.format(**format_args))) + if self.user is not None: + result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) + if self.include_element_image and len(result.messages) > 0: + result.messages[-1].images = [get_element_image(elt, doc)] + return result + + +class StaticPrompt(SycamorePrompt): + """A prompt that always renders the same regardless of the Document or Elements + passed in as context. + + Args: + + system: the system prompt string. Use {} to reference names to be interpolated. + Interpolated names only come from kwargs. + user: the user prompt string. Use {} to reference names to be interpolated. + Interpolated names only come from kwargs. + **kwargs: keyword arguments to interpolate. + + Example: + .. 
code-block:: python + + prompt = StaticPrompt(system="static", user = "prompt - {number}", number=7) + prompt.render_document(Document()) + # [ + # { "role": "system", "content": "static" }, + # { "role": "user", "content": "prompt - 7" }, + # ] + """ + + def __init__(self, *, system: Optional[str] = None, user: Optional[str] = None, **kwargs): + super().__init__() + self.system = system + self.user = user + self.kwargs = kwargs + + def render_generic(self) -> RenderedPrompt: + result = RenderedPrompt(messages=[]) + if self.system is not None: + result.messages.append(RenderedMessage(role="system", content=self.system.format(**self.kwargs))) + if self.user is not None: + result.messages.append(RenderedMessage(role="user", content=self.user.format(**self.kwargs))) + return result + + def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: + return self.render_generic() + + def render_document(self, doc: Document) -> RenderedPrompt: + return self.render_generic() + + def render_multiple_documents(self, docs: list[Document]) -> RenderedPrompt: + return self.render_generic() From b73c1624951f39799d20548df332d37cbeb915b1 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 21 Jan 2025 09:34:59 -0800 Subject: [PATCH 04/46] add unit tests for prompts Signed-off-by: Henry Lindeman --- .../tests/unit/llms/prompts/test_prompts.py | 238 ++++++++++++++++++ lib/sycamore/sycamore/utils/pdf_utils.py | 18 +- 2 files changed, 255 insertions(+), 1 deletion(-) create mode 100644 lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py diff --git a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py new file mode 100644 index 000000000..6ce28c1dc --- /dev/null +++ b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py @@ -0,0 +1,238 @@ +from sycamore.data.element import Element +from sycamore.llms.prompts.prompts import ( + RenderedPrompt, + RenderedMessage, + StaticPrompt, + SycamorePrompt, + ElementPrompt, + ElementListPrompt, +) +from sycamore.data import Document +from sycamore.tests.config import TEST_DIR +from pyarrow.fs import LocalFileSystem +import pytest + + +@pytest.fixture(scope="module") +def dummy_document(): + docpath = TEST_DIR / "resources/data/pdfs/ntsb-report.pdf" + local = LocalFileSystem() + path = str(docpath) + input_stream = local.open_input_stream(path) + document = Document() + document.binary_representation = input_stream.readall() + document.type = "pdf" + document.properties["path"] = path + document.properties["pages"] = 6 + document.elements = [ + Element( + text_representation="Element 1", + type="Text", + element_id="e1", + properties={"page_number": 1}, + bbox=(0.1, 0.1, 0.4, 0.4), + ), + Element( + text_representation="Element 2", + type="Text", + element_id="e2", + properties={"page_number": 2}, + bbox=(0.1, 0.1, 0.4, 0.4), + ), + Element( + text_representation="Element 3", + type="Text", + element_id="e3", + properties={"page_number": 3}, + bbox=(0.1, 0.1, 0.4, 0.4), + ), + Element( + text_representation="Element 4", + type="Text", + element_id="e4", + properties={"page_number": 3}, + bbox=(0.4, 0.1, 0.8, 0.4), + ), + Element( + text_representation="Element 5", + type="Text", + element_id="e5", + properties={"page_number": 3}, + bbox=(0.1, 0.4, 0.8, 0.8), + ), + Element( + text_representation="Element 6", + type="Text", + element_id="e6", + properties={"page_number": 4}, + bbox=(0.1, 0.1, 0.4, 0.4), + ), + ] + return document + + +class TestRenderedPrompt: + 
"""RenderedPrompt and RenderedMessage are dataclasses, + no need to test them. Nothing to test :)""" + + pass + + +class TestSycamorePrompt: + def test_instead_is_cow(self): + sp = SycamorePrompt() + sp.__dict__["key"] = "value" + sp2 = sp.instead(key="other value") + assert sp.key == "value" + assert sp2.key == "other value" + + +class TestStaticPrompt: + def test_static_rd(self, dummy_document): + prompt = StaticPrompt(system="system {x}", user="computers") + with pytest.raises(KeyError): + prompt.render_document(dummy_document) + + prompt = prompt.instead(x=76) + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="system 76"), + RenderedMessage(role="user", content="computers"), + ] + ) + assert prompt.render_document(dummy_document) == expected + assert prompt.render_element(dummy_document.elements[0], dummy_document) == expected + assert prompt.render_multiple_documents([dummy_document]) == expected + + +class TestElementPrompt: + def test_basic(self, dummy_document): + prompt = ElementPrompt( + system="You know everything there is to know about jazz, {name}", + user="Summarize the information on page {elt_property_page_number}.\nTEXT: {elt_text}", + name="Frank Sinatra", + ) + expected = RenderedPrompt( + messages=[ + RenderedMessage( + role="system", content="You know everything there is to know about jazz, Frank Sinatra" + ), + RenderedMessage(role="user", content="Summarize the information on page 3.\nTEXT: Element 4"), + ] + ) + assert prompt.render_element(dummy_document.elements[3], dummy_document) == expected + with pytest.raises(NotImplementedError): + prompt.render_document(dummy_document) + with pytest.raises(NotImplementedError): + prompt.render_multiple_documents([dummy_document]) + + def test_get_parent_context(self, dummy_document): + prompt = ElementPrompt( + system="You know everything there is to know about {custom_property}, {name}", + user="Summarize the information on page {elt_property_page_number}.\nTEXT: {elt_text}", + name="Frank Sinatra", + capture_parent_context=lambda doc, elt: {"custom_property": doc.properties["pages"]}, + ) + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="You know everything there is to know about 6, Frank Sinatra"), + RenderedMessage(role="user", content="Summarize the information on page 3.\nTEXT: Element 4"), + ] + ) + assert prompt.render_element(dummy_document.elements[3], dummy_document) == expected + + def test_include_image(self, dummy_document): + prompt = ElementPrompt( + system="You know everything there is to know about {custom_property}, {name}", + user="Summarize the information on page {elt_property_page_number}.\nTEXT: {elt_text}", + name="Frank Sinatra", + capture_parent_context=lambda doc, elt: {"custom_property": doc.properties["pages"]}, + include_element_image=True, + ) + rp = prompt.render_element(dummy_document.elements[3], dummy_document) + assert rp.messages[1].images is not None and len(rp.messages[1].images) == 1 + assert rp.messages[1].role == "user" + assert rp.messages[0].images is None + + prompt = prompt.instead(user=None) + rp2 = prompt.render_element(dummy_document.elements[1], dummy_document) + assert len(rp2.messages) == 1 + assert rp2.messages[0].role == "system" + assert rp2.messages[0].images is not None + assert len(rp2.messages[0].images) == 1 + + +class TestElementListPrompt: + def test_basic(self, dummy_document): + prompt = ElementListPrompt(system="sys", user="usr: {elements}") + expected = RenderedPrompt( + messages=[ + 
RenderedMessage(role="system", content="sys"), + RenderedMessage( + role="user", + content="usr: ELEMENT 0: Element 1\nELEMENT 1: Element 2\n" + "ELEMENT 2: Element 3\nELEMENT 3: Element 4\nELEMENT 4: Element 5\nELEMENT 5: Element 6", + ), + ] + ) + assert prompt.render_document(dummy_document) == expected + + def test_limit_elements(self, dummy_document): + prompt = ElementListPrompt(system="sys", user="usr: {elements}", num_elements=3) + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="sys"), + RenderedMessage( + role="user", + content="usr: ELEMENT 0: Element 1\nELEMENT 1: Element 2\nELEMENT 2: Element 3", + ), + ] + ) + assert prompt.render_document(dummy_document) == expected + + def test_select_odd_elements(self, dummy_document): + prompt = ElementListPrompt( + system="sys", + user="usr: {elements}", + element_select=lambda elts: [elts[i] for i in range(len(elts)) if i % 2 == 1], + ) + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="sys"), + RenderedMessage( + role="user", + content="usr: ELEMENT 0: Element 2\nELEMENT 1: Element 4\nELEMENT 2: Element 6", + ), + ] + ) + assert prompt.render_document(dummy_document) == expected + + def test_order_elements(self, dummy_document): + prompt = ElementListPrompt(system="sys", user="usr: {elements}", element_order=lambda e: list(reversed(e))) + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="sys"), + RenderedMessage( + role="user", + content="usr: ELEMENT 0: Element 6\nELEMENT 1: Element 5\n" + "ELEMENT 2: Element 4\nELEMENT 3: Element 3\nELEMENT 4: Element 2\nELEMENT 5: Element 1", + ), + ] + ) + assert prompt.render_document(dummy_document) == expected + + def test_construct_element_list(self, dummy_document): + def list_constructor(elts: list[Element]) -> str: + return "<>" + "<>".join(f"{i}-{e.type}" for i, e in enumerate(elts)) + "" + + prompt = ElementListPrompt(system="sys", user="usr: {elements}", element_list_constructor=list_constructor) + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="sys"), + RenderedMessage( + role="user", + content="usr: <>0-Text<>1-Text<>2-Text<>3-Text<>4-Text<>5-Text", + ), + ] + ) + assert prompt.render_document(dummy_document) == expected diff --git a/lib/sycamore/sycamore/utils/pdf_utils.py b/lib/sycamore/sycamore/utils/pdf_utils.py index 8665bee1e..92331b578 100644 --- a/lib/sycamore/sycamore/utils/pdf_utils.py +++ b/lib/sycamore/sycamore/utils/pdf_utils.py @@ -5,10 +5,11 @@ from PIL import Image from pypdf import PdfReader, PdfWriter +import pdf2image from sycamore import DocSet from sycamore.functions.document import DrawBoxes, split_and_convert_to_image -from sycamore.utils.image_utils import show_images +from sycamore.utils.image_utils import show_images, crop_to_bbox from sycamore.data import Document, Element import json @@ -180,3 +181,18 @@ def promote_title(elements: list[Element], title_candidate_elements=["Section-he if section_header: section_header.type = "Title" return elements + + +def get_element_image(element: Element, document: Document) -> Image.Image: + assert document.type == "pdf", "Cannot get picture of element from non-pdf" + assert document.binary_representation is not None, "Cannot get image since there is not binary representation" + assert element.bbox is not None, "Cannot get picture of element if it has no BBox" + assert element.properties.get("page_number") is not None and isinstance( + element.properties["page_number"], int + ), "Cannot get 
picture of element without known page number" + bits = BytesIO(document.binary_representation) + pagebits = BytesIO() + select_pdf_pages(bits, pagebits, [element.properties["page_number"]]) + images = pdf2image.convert_from_bytes(pagebits.getvalue()) + im = crop_to_bbox(images[0], element.bbox) + return im From 17b21635a41ac4a18b3384ba3cb057fc48e29802 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 21 Jan 2025 10:33:34 -0800 Subject: [PATCH 05/46] forgot to commit this Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 9feb8d27f..40b020f76 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -5,7 +5,6 @@ import pydantic from PIL import Image from sycamore.data.document import Document, Element -from sycamore.utils.pdf_utils import get_element_image @dataclass @@ -22,9 +21,6 @@ class RenderedMessage: content: str images: Optional[list[Image.Image]] = None - def to_dict(self): - return {"role": self.role, "content": self.content} - @dataclass class RenderedPrompt: @@ -39,12 +35,6 @@ class RenderedPrompt: messages: list[RenderedMessage] response_format: Union[None, dict[str, Any], pydantic.BaseModel] = None - def to_dict(self): - res = {"messages": [m.to_dict() for m in self.messages]} - if self.response_format is not None: - res["response_format"] = self.output_structure # type: ignore - return res - class SycamorePrompt: """Base class/API for all Sycamore LLM Prompt objects. Sycamore Prompts @@ -307,6 +297,8 @@ def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: if self.user is not None: result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) if self.include_element_image and len(result.messages) > 0: + from sycamore.utils.pdf_utils import get_element_image + result.messages[-1].images = [get_element_image(elt, doc)] return result From 5d145d5f8772e6385e7bf1a987547161728e58d9 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 21 Jan 2025 12:37:35 -0800 Subject: [PATCH 06/46] address pr comments; flatten properties with flatten_data Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 25 ++++++++----------- .../tests/unit/llms/prompts/test_prompts.py | 10 +++++++- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 40b020f76..24c2e6392 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -5,6 +5,7 @@ import pydantic from PIL import Image from sycamore.data.document import Document, Element +from sycamore.connectors.common import flatten_data @dataclass @@ -12,8 +13,8 @@ class RenderedMessage: """Represents a message per the LLM messages interface - i.e. a role and a content string Args: - role: the role of this message. Should be one of "user", "system", "assistant" - content: the content of this message, either a python string or a PIL image. + role: the role of this message. e.g. for OpenAI should be one of "user", "system", "assistant" + content: the content of this message images: optional list of images to include in this message. 
""" @@ -29,7 +30,7 @@ class RenderedPrompt: Args: messages: the list of messages to be sent to the LLM response_format: optional output schema, speicified as pydict/json or - a pydantic model. Can only be used (iirc) with modern OpenAI models. + a pydantic model. Can only be used with modern OpenAI models. """ messages: list[RenderedMessage] @@ -123,10 +124,8 @@ class ElementListPrompt(SycamorePrompt): be interpolated. Defaults to None user: The user prompt string. Use {} to reference names that should be interpolated. Defaults to None - element_select: Function to choose which set of elements to include in - the prompt. If None, defaults to the first ``num_elements`` elements. - element_order: Function to reorder the selected elements. Defaults to - a noop. + element_select: Function to choose the elements (and their order) to include + in the prompt. If None, defaults to the first ``num_elements`` elements. element_list_constructor: Function to turn a list of elements into a string that can be accessed with the interpolation key "{elements}". Defaults to "ELEMENT 0: {elts[0].text_representation}\\n @@ -142,8 +141,7 @@ class ElementListPrompt(SycamorePrompt): prompt = ElementListPrompt( system = "Hello {name}. This is a prompt about {doc_property_path}" user = "What do you make of these tables?\\nTables:\\n{elements}" - element_select = lambda elts: [e for e in elts if e.type == "table"] - element_order = reversed + element_select = lambda elts: list(reversed(e for e in elts if e.type == "table")) name = "David Rothschild" ) prompt.render_document(doc) @@ -160,7 +158,6 @@ def __init__( system: Optional[str] = None, user: Optional[str] = None, element_select: Optional[Callable[[list[Element]], list[Element]]] = None, - element_order: Optional[Callable[[list[Element]], list[Element]]] = None, element_list_constructor: Optional[Callable[[list[Element]], str]] = None, num_elements: int = 35, **kwargs, @@ -169,7 +166,6 @@ def __init__( self.system = system self.user = user self.element_select = element_select or (lambda elts: elts[:num_elements]) - self.element_order = element_order or (lambda elts: elts) self.element_list_constructor = element_list_constructor or ( lambda elts: "\n".join(f"ELEMENT {i}: {elts[i].text_representation}" for i in range(len(elts))) ) @@ -177,7 +173,6 @@ def __init__( def _render_element_list_to_string(self, doc: Document): elts = self.element_select(doc.elements) - elts = self.element_order(elts) return self.element_list_constructor(elts) def render_document(self, doc: Document) -> RenderedPrompt: @@ -201,7 +196,8 @@ def render_document(self, doc: Document) -> RenderedPrompt: """ format_args = self.kwargs format_args["doc_text"] = doc.text_representation - format_args.update({"doc_property_" + k: v for k, v in doc.properties.items()}) + flat_props = flatten_data(doc.properties, prefix="doc_property", separator="_") + format_args.update(flat_props) format_args["elements"] = self._render_element_list_to_string(doc) result = RenderedPrompt(messages=[]) @@ -289,7 +285,8 @@ def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: format_args = self.kwargs format_args.update(self.capture_parent_context(doc, elt)) format_args["elt_text"] = elt.text_representation - format_args.update({"elt_property_" + k: v for k, v in elt.properties.items()}) + flat_props = flatten_data(elt.properties, prefix="elt_property", separator="_") + format_args.update(flat_props) result = RenderedPrompt(messages=[]) if self.system is not None: diff --git 
a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py index 6ce28c1dc..6c36d7510 100644 --- a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py +++ b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py @@ -208,7 +208,7 @@ def test_select_odd_elements(self, dummy_document): assert prompt.render_document(dummy_document) == expected def test_order_elements(self, dummy_document): - prompt = ElementListPrompt(system="sys", user="usr: {elements}", element_order=lambda e: list(reversed(e))) + prompt = ElementListPrompt(system="sys", user="usr: {elements}", element_select=lambda e: list(reversed(e))) expected = RenderedPrompt( messages=[ RenderedMessage(role="system", content="sys"), @@ -236,3 +236,11 @@ def list_constructor(elts: list[Element]) -> str: ] ) assert prompt.render_document(dummy_document) == expected + + def test_flattened_properties(self, dummy_document): + doc = dummy_document.copy() + doc.properties["entity"] = {"key": "value"} + + prompt = ElementListPrompt(system="sys {doc_property_entity_key}") + expected = RenderedPrompt(messages=[RenderedMessage(role="system", content="sys value")]) + assert prompt.render_document(doc) == expected From 7fa2ff1488a4db4be5b58c8e15a565b37c32854b Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 21 Jan 2025 13:12:51 -0800 Subject: [PATCH 07/46] support multiple user prompts Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 40 ++++++++++--------- .../tests/unit/llms/prompts/test_prompts.py | 11 +++++ 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 24c2e6392..40930bad3 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -114,6 +114,19 @@ def instead(self, **kwargs) -> "SycamorePrompt": return new +def _build_format_str( + system: Optional[str], user: Union[None, str, list[str]], format_args: dict[str, Any] +) -> list[RenderedMessage]: + messages = [] + if system is not None: + messages.append(RenderedMessage(role="system", content=system.format(**format_args))) + if isinstance(user, list): + messages.extend([RenderedMessage(role="user", content=u.format(**format_args)) for u in user]) + elif isinstance(user, str): + messages.append(RenderedMessage(role="user", content=user.format(**format_args))) + return messages + + class ElementListPrompt(SycamorePrompt): """A prompt with utilities for constructing a list of elements to include in the rendered prompt. 
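A minimal sketch of what this multi-user-prompt support renders to (illustration only, not a hunk of the patch; it mirrors the unit test added further down in this same patch). Keeping ``user`` as either a single string or a list of strings lets callers emit several user messages from one prompt object without changing the rendering API:

    from sycamore.data import Document
    from sycamore.llms.prompts.prompts import StaticPrompt

    # One system template plus a list of user templates; every template is
    # interpolated with the same kwargs, producing one user message per entry.
    prompt = StaticPrompt(system="system {x}", user=["{x} user {y}", "{x} user {z}"], x=1, y=2, z=3)
    rendered = prompt.render_document(Document())
    # rendered.messages ==
    #   [RenderedMessage(role="system", content="system 1"),
    #    RenderedMessage(role="user", content="1 user 2"),
    #    RenderedMessage(role="user", content="1 user 3")]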
@@ -156,7 +169,7 @@ def __init__( self, *, system: Optional[str] = None, - user: Optional[str] = None, + user: Union[None, str, list[str]] = None, element_select: Optional[Callable[[list[Element]], list[Element]]] = None, element_list_constructor: Optional[Callable[[list[Element]], str]] = None, num_elements: int = 35, @@ -200,11 +213,8 @@ def render_document(self, doc: Document) -> RenderedPrompt: format_args.update(flat_props) format_args["elements"] = self._render_element_list_to_string(doc) - result = RenderedPrompt(messages=[]) - if self.system is not None: - result.messages.append(RenderedMessage(role="system", content=self.system.format(**format_args))) - if self.user is not None: - result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) + messages = _build_format_str(self.system, self.user, format_args) + result = RenderedPrompt(messages=messages) return result @@ -245,7 +255,7 @@ def __init__( self, *, system: Optional[str] = None, - user: Optional[str] = None, + user: Union[None, str, list[str]] = None, include_element_image: bool = False, capture_parent_context: Optional[Callable[[Document, Element], dict[str, Any]]] = None, **kwargs, @@ -288,11 +298,8 @@ def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: flat_props = flatten_data(elt.properties, prefix="elt_property", separator="_") format_args.update(flat_props) - result = RenderedPrompt(messages=[]) - if self.system is not None: - result.messages.append(RenderedMessage(role="system", content=self.system.format(**format_args))) - if self.user is not None: - result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) + messages = _build_format_str(self.system, self.user, format_args) + result = RenderedPrompt(messages=messages) if self.include_element_image and len(result.messages) > 0: from sycamore.utils.pdf_utils import get_element_image @@ -323,18 +330,15 @@ class StaticPrompt(SycamorePrompt): # ] """ - def __init__(self, *, system: Optional[str] = None, user: Optional[str] = None, **kwargs): + def __init__(self, *, system: Optional[str] = None, user: Union[None, str, list[str]] = None, **kwargs): super().__init__() self.system = system self.user = user self.kwargs = kwargs def render_generic(self) -> RenderedPrompt: - result = RenderedPrompt(messages=[]) - if self.system is not None: - result.messages.append(RenderedMessage(role="system", content=self.system.format(**self.kwargs))) - if self.user is not None: - result.messages.append(RenderedMessage(role="user", content=self.user.format(**self.kwargs))) + messages = _build_format_str(self.system, self.user, self.kwargs) + result = RenderedPrompt(messages=messages) return result def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: diff --git a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py index 6c36d7510..111a6ee33 100644 --- a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py +++ b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py @@ -104,6 +104,17 @@ def test_static_rd(self, dummy_document): assert prompt.render_element(dummy_document.elements[0], dummy_document) == expected assert prompt.render_multiple_documents([dummy_document]) == expected + def test_static_with_multiple_user_prompts(self, dummy_document): + prompt = StaticPrompt(system="system {x}", user=["{x} user {y}", "{x} user {z}"], x=1, y=2, z=3) + expected = RenderedPrompt( + messages=[ + 
RenderedMessage(role="system", content="system 1"), + RenderedMessage(role="user", content="1 user 2"), + RenderedMessage(role="user", content="1 user 3"), + ] + ) + assert prompt.render_document(dummy_document) == expected + class TestElementPrompt: def test_basic(self, dummy_document): From abf9b0b7dcd32d37e1d849cd509f8e33bf710d0b Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 21 Jan 2025 16:16:30 -0800 Subject: [PATCH 08/46] rename instead to set Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 4 ++-- .../sycamore/tests/unit/llms/prompts/test_prompts.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 40930bad3..0ea81112e 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -77,7 +77,7 @@ def render_multiple_documents(self, docs: list[Document]) -> RenderedPrompt: A fully rendered prompt that can be sent to an LLM for inference""" raise NotImplementedError(f"render_multiple_documents is not implemented for {self.__class__.__name__}") - def instead(self, **kwargs) -> "SycamorePrompt": + def set(self, **kwargs) -> "SycamorePrompt": """Create a new prompt with some fields changed. Args: @@ -98,7 +98,7 @@ def instead(self, **kwargs) -> "SycamorePrompt": # {"role": "system", "content": "hello"}, # {"role": "user", "content": "world"} # ] - p2 = p.instead(user="bob") + p2 = p.set(user="bob") p2.render_document(Document()) # [ # {"role": "system", "content": "hello"}, diff --git a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py index 111a6ee33..76d4fefdf 100644 --- a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py +++ b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py @@ -79,10 +79,10 @@ class TestRenderedPrompt: class TestSycamorePrompt: - def test_instead_is_cow(self): + def test_set_is_cow(self): sp = SycamorePrompt() sp.__dict__["key"] = "value" - sp2 = sp.instead(key="other value") + sp2 = sp.set(key="other value") assert sp.key == "value" assert sp2.key == "other value" @@ -93,7 +93,7 @@ def test_static_rd(self, dummy_document): with pytest.raises(KeyError): prompt.render_document(dummy_document) - prompt = prompt.instead(x=76) + prompt = prompt.set(x=76) expected = RenderedPrompt( messages=[ RenderedMessage(role="system", content="system 76"), @@ -165,7 +165,7 @@ def test_include_image(self, dummy_document): assert rp.messages[1].role == "user" assert rp.messages[0].images is None - prompt = prompt.instead(user=None) + prompt = prompt.set(user=None) rp2 = prompt.render_element(dummy_document.elements[1], dummy_document) assert len(rp2.messages) == 1 assert rp2.messages[0].role == "system" From 2d1315bc59ee6c5e473eea07bfe61f5ec4f0bc8e Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Wed, 22 Jan 2025 09:40:21 -0800 Subject: [PATCH 09/46] add LLMMap and LLMMapElements transforms Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/transforms/base_llm.py | 87 ++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 lib/sycamore/sycamore/transforms/base_llm.py diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py new file mode 100644 index 000000000..f1f8c1521 --- /dev/null +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -0,0 +1,87 @@ +from typing import Optional + +from 
sycamore.llms.llms import LLM, LLMMode +from sycamore.llms.prompts.prompts import SycamorePrompt, RenderedPrompt +from sycamore.plan_nodes import Node +from sycamore.transforms.map import MapBatch +from sycamore.data import Document, Element + + +def _infer_prompts(prompts: list[RenderedPrompt], llm: LLM, llm_mode: LLMMode) -> list[str]: + if llm_mode == LLMMode.SYNC: + return [llm.generate(p) for p in prompts] + elif llm_mode == LLMMode.ASYNC: + raise NotImplementedError("Haven't done async yet") + elif llm_mode == LLMMode.BATCH: + raise NotImplementedError("Haven't done batch yet") + else: + raise NotImplementedError("Unknown LLM Mode") + + +class LLMMap(MapBatch): + def __init__( + self, + child: Optional[Node], + prompt: SycamorePrompt, + output_field: str, + llm: LLM, + llm_mode: LLMMode = LLMMode.SYNC, + **kwargs, + ): + self._prompt = prompt + self._validate_prompt() + self._output_field = output_field + self._llm = llm + self._llm_mode = llm_mode + super().__init__(child, f=self.llm_map, **kwargs) + + def llm_map(self, documents: list[Document]) -> list[Document]: + rendered = [self._prompt.render_document(d) for d in documents] + results = _infer_prompts(rendered, self._llm, self._llm_mode) + for d, r in zip(documents, results): + d.properties[self._output_field] = r + return documents + + def _validate_prompt(self): + doc = Document() + try: + _ = self._prompt.render_document(doc) + except NotImplementedError as e: + raise e + except Exception: + pass + + +class LLMMapElements(MapBatch): + def __init__( + self, + child: Optional[Node], + prompt: SycamorePrompt, + output_field: str, + llm: LLM, + llm_mode: LLMMode = LLMMode.SYNC, + **kwargs, + ): + self._prompt = prompt + self._validate_prompt() + self._output_field = output_field + self._llm = llm + self._llm_mode = llm_mode + super().__init__(child, f=self.llm_map_elements, **kwargs) + + def llm_map_elements(self, documents: list[Document]) -> list[Document]: + rendered = [(e, self._prompt.render_element(e, d)) for d in documents for e in d.elements] + results = _infer_prompts([p for _, p in rendered], self._llm, self._llm_mode) + for r, (e, _) in zip(results, rendered): + e.properties[self._output_field] = r + return documents + + def _validate_prompt(self): + doc = Document() + elt = Element() + try: + _ = self._prompt.render_element(elt, doc) + except NotImplementedError as e: + raise e + except Exception: + pass From 5e86e56fa2a9226a11d40cfaecde9e02c06b92f3 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Wed, 22 Jan 2025 15:50:48 -0800 Subject: [PATCH 10/46] move llm implementations to use RenderedPrompts Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/anthropic.py | 68 +++++++++----- lib/sycamore/sycamore/llms/bedrock.py | 13 +-- lib/sycamore/sycamore/llms/llms.py | 31 ++++--- lib/sycamore/sycamore/llms/openai.py | 93 ++++++++++--------- .../sycamore/llms/prompts/__init__.py | 19 +++- lib/sycamore/sycamore/transforms/base_llm.py | 2 +- 6 files changed, 140 insertions(+), 86 deletions(-) diff --git a/lib/sycamore/sycamore/llms/anthropic.py b/lib/sycamore/sycamore/llms/anthropic.py index 81bc4903d..840ed850c 100644 --- a/lib/sycamore/sycamore/llms/anthropic.py +++ b/lib/sycamore/sycamore/llms/anthropic.py @@ -6,6 +6,7 @@ from PIL import Image from sycamore.llms.llms import LLM +from sycamore.llms.prompts import RenderedPrompt from sycamore.llms.prompts.default_prompts import SimplePrompt from sycamore.utils.cache import Cache from sycamore.utils.image_utils import base64_data @@ -49,29 +50,54 @@ def 
rewrite_system_messages(messages: Optional[list[dict]]) -> Optional[list[dic return [m for m in messages if m.get("role") != "system"] -def get_generate_kwargs(prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> dict: +def get_generate_kwargs(prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict: kwargs = { "temperature": 0, **(llm_kwargs or {}), } - kwargs["max_tokens"] = kwargs.get("max_tokens", DEFAULT_MAX_TOKENS) - if "prompt" in prompt_kwargs: - prompt = prompt_kwargs.get("prompt") - - if isinstance(prompt, SimplePrompt): - kwargs.update({"messages": prompt.as_messages(prompt_kwargs)}) + # Anthropic models require _exactly_ alternation between "user" and "assistant" + # roles, so we break the messages into groups of consecutive user/assistant + # messages, treating "system" as "user". Then crunch each group down to a single + # message to ensure alternation. + message_groups = [] + last_role = None + + for m in prompt.messages: + r = m.role + if r == "system": + r = "user" + if r != last_role: + message_groups.append([]) + message_groups[-1].append(m) + last_role = r + + messages = [] + for group in message_groups: + role = group[0].role + if role == "system": + role = "user" + content = "\n".join(m.content for m in group) + if any(m.images is not None for m in group): + images = [im for m in group for im in m.images] + contents = [{"type": "text", "text": content}] + for im in images: + contents.append( + { # type: ignore + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": base64_data(im), + }, + } + ) + messages.append({"role": role, "content": contents}) else: - kwargs.update({"messages": [{"role": "user", "content": f"{prompt}"}]}) - - elif "messages" in prompt_kwargs: - kwargs.update({"messages": prompt_kwargs["messages"]}) - else: - raise ValueError("Either prompt or messages must be present in prompt_kwargs.") - - kwargs["messages"] = rewrite_system_messages(kwargs["messages"]) + messages.append({"role": role, "content": content}) + kwargs["messages"] = messages return kwargs @@ -128,12 +154,12 @@ def is_chat_mode(self) -> bool: def format_image(self, image: Image.Image) -> dict[str, Any]: return format_image(image) - def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> dict: - ret = self._llm_cache_get(prompt_kwargs, llm_kwargs) + def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict: + ret = self._llm_cache_get(prompt, llm_kwargs) if isinstance(ret, dict): return ret - kwargs = get_generate_kwargs(prompt_kwargs, llm_kwargs) + kwargs = get_generate_kwargs(prompt, llm_kwargs) start = datetime.now() @@ -154,9 +180,9 @@ def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = logging.debug(f"Generated response from Anthropic model: {ret}") - self._llm_cache_set(prompt_kwargs, llm_kwargs, ret) + self._llm_cache_set(prompt, llm_kwargs, ret) return ret - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str: - d = self.generate_metadata(prompt_kwargs=prompt_kwargs, llm_kwargs=llm_kwargs) + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + d = self.generate_metadata(prompt=prompt, llm_kwargs=llm_kwargs) return d["output"] diff --git a/lib/sycamore/sycamore/llms/bedrock.py b/lib/sycamore/sycamore/llms/bedrock.py index a7d115540..d6364a135 100644 --- a/lib/sycamore/sycamore/llms/bedrock.py +++ b/lib/sycamore/sycamore/llms/bedrock.py @@ -9,6 +9,7 @@ from 
sycamore.llms.llms import LLM from sycamore.llms.anthropic import format_image, get_generate_kwargs +from sycamore.llms.prompts.prompts import RenderedPrompt from sycamore.utils.cache import Cache DEFAULT_MAX_TOKENS = 1000 @@ -77,14 +78,14 @@ def format_image(self, image: Image.Image) -> dict[str, Any]: return format_image(image) raise NotImplementedError("Images not supported for non-Anthropic Bedrock models.") - def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> dict: - ret = self._llm_cache_get(prompt_kwargs, llm_kwargs) + def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict: + ret = self._llm_cache_get(prompt, llm_kwargs) if isinstance(ret, dict): print(f"cache return {ret}") return ret assert ret is None - kwargs = get_generate_kwargs(prompt_kwargs, llm_kwargs) + kwargs = get_generate_kwargs(prompt, llm_kwargs) if self._model_name.startswith("anthropic."): anthropic_version = ( DEFAULT_ANTHROPIC_VERSION @@ -114,9 +115,9 @@ def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = "in_tokens": in_tokens, "out_tokens": out_tokens, } - self._llm_cache_set(prompt_kwargs, llm_kwargs, ret) + self._llm_cache_set(prompt, llm_kwargs, ret) return ret - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str: - d = self.generate_metadata(prompt_kwargs=prompt_kwargs, llm_kwargs=llm_kwargs) + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + d = self.generate_metadata(prompt=prompt, llm_kwargs=llm_kwargs) return d["output"] diff --git a/lib/sycamore/sycamore/llms/llms.py b/lib/sycamore/sycamore/llms/llms.py index dc0541862..6b10418be 100644 --- a/lib/sycamore/sycamore/llms/llms.py +++ b/lib/sycamore/sycamore/llms/llms.py @@ -1,8 +1,17 @@ from abc import ABC, abstractmethod +from enum import Enum import pickle from PIL import Image from typing import Any, Optional from sycamore.utils.cache import Cache +from sycamore.llms.prompts import RenderedPrompt + + +class LLMMode(Enum): + UNKNOWN = 0 + SYNC = 1 + ASYNC = 2 + BATCH = 3 class LLM(ABC): @@ -13,7 +22,7 @@ def __init__(self, model_name, cache: Optional[Cache] = None): self._cache = cache @abstractmethod - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str: + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: """Generates a response from the LLM for the given prompt and LLM parameters.""" pass @@ -26,17 +35,17 @@ def format_image(self, image: Image.Image) -> dict[str, Any]: """Returns a dictionary containing the specified image suitable for use in an LLM message.""" raise NotImplementedError("This LLM does not support images.") - async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str: + async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: """Generates a response from the LLM for the given prompt and LLM parameters asynchronously.""" raise NotImplementedError("This LLM does not support asynchronous generation.") def __str__(self): return f"{self.__class__.__name__}({self._model_name})" - def _llm_cache_key(self, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str: + def _llm_cache_key(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: """Return a cache key for the given prompt and LLM parameters.""" assert self._cache - combined = {"prompt_kwargs": prompt_kwargs, "llm_kwargs": 
llm_kwargs, "model_name": self._model_name} + combined = {"prompt": prompt, "llm_kwargs": llm_kwargs, "model_name": self._model_name} data = pickle.dumps(combined) return self._cache.get_hash_context(data).hexdigest() @@ -48,43 +57,43 @@ def _use_caching(self, llm_kwargs: Optional[dict]): # Only cache when temperature setting is zero. return llm_kwargs.get("temperature", 0) == 0 - def _llm_cache_get(self, prompt_kwargs: dict, llm_kwargs: Optional[dict]) -> Any: + def _llm_cache_get(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]) -> Any: """Get a cached result for the given prompt and LLM parameters. Returns the cached result if found, or otherwise None.""" if not self._use_caching(llm_kwargs): return None assert self._cache is not None, "make mypy happy" - key = self._llm_cache_key(prompt_kwargs, llm_kwargs) + key = self._llm_cache_key(prompt, llm_kwargs) hit = self._cache.get(key) if hit: assert ( len(hit) == 4 - and hit.get("prompt_kwargs") == prompt_kwargs + and hit.get("prompt") == prompt and hit.get("llm_kwargs") == llm_kwargs and hit.get("model_name") == self._model_name and "result" in hit ), f""" Found LLM cache content mismatch: key={key} - prompt_kwargs={prompt_kwargs}, cached={hit.get("prompt_kwargs")} + prompt_kwargs={prompt}, cached={hit.get("prompt")} llm_kwargs={llm_kwargs}, cached={hit.get("llm_kwargs")} model_name={self._model_name}, cached={hit.get("model_name")} Complete hit: {hit}""" return hit.get("result") return None - def _llm_cache_set(self, prompt_kwargs: dict, llm_kwargs: Optional[dict], result: Any) -> None: + def _llm_cache_set(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict], result: Any) -> None: """Set a cached result for the given key.""" if not self._use_caching(llm_kwargs): return assert self._cache is not None, "make mypy happy" - key = self._llm_cache_key(prompt_kwargs, llm_kwargs) + key = self._llm_cache_key(prompt, llm_kwargs) self._cache.set( key, { - "prompt_kwargs": prompt_kwargs, + "prompt_kwargs": prompt, "llm_kwargs": llm_kwargs, "model_name": self._model_name, "result": result, diff --git a/lib/sycamore/sycamore/llms/openai.py b/lib/sycamore/sycamore/llms/openai.py index f90d3e1d9..6351488c0 100644 --- a/lib/sycamore/sycamore/llms/openai.py +++ b/lib/sycamore/sycamore/llms/openai.py @@ -19,7 +19,7 @@ import pydantic from sycamore.llms.llms import LLM -from sycamore.llms.prompts import SimplePrompt +from sycamore.llms.prompts import SimplePrompt, RenderedPrompt from sycamore.utils.cache import Cache from sycamore.utils.image_utils import base64_data_url @@ -304,7 +304,7 @@ def _convert_response_format(self, llm_kwargs: Optional[Dict]) -> Optional[Dict] return retval return llm_kwargs - def _get_generate_kwargs(self, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> dict: + def _get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict: kwargs = { "temperature": 0, **(llm_kwargs or {}), @@ -312,28 +312,29 @@ def _get_generate_kwargs(self, prompt_kwargs: dict, llm_kwargs: Optional[dict] = if "SYCAMORE_OPENAI_USER" in os.environ: kwargs.update({"user": os.environ.get("SYCAMORE_OPENAI_USER")}) - if "prompt" in prompt_kwargs: - prompt = prompt_kwargs.get("prompt") - if self.is_chat_mode(): - if isinstance(prompt, SimplePrompt): - prompt = prompt.as_messages() - for idx, prompt_message in enumerate(prompt): - prompt[idx]["content"] = prompt_message["content"].format(**prompt_kwargs) - kwargs.update( - { - "messages": prompt, - } - ) - else: - kwargs.update({"messages": [{"role": "user", 
"content": prompt}]}) + if not self.is_chat_mode(): + # If plain completions model, we want a 'prompt' arg with a + # single string in it, not a list of messages. Make it by + # concatenating the messages. + pstring = "\n".join(m.content for m in prompt.messages) + kwargs.update({"prompt": pstring}) + return kwargs + + messages_list = [] + for m in prompt.messages: + if m.role == "system": + role = "developer" else: - if isinstance(prompt, SimplePrompt): - prompt = f"{prompt.system}\n{prompt.user}" - kwargs.update({"prompt": prompt}) - elif "messages" in prompt_kwargs: - kwargs.update({"messages": prompt_kwargs["messages"]}) - else: - raise ValueError("Either prompt or messages must be present in prompt_kwargs.") + role = m.role + if m.images is None: + content = m.content + else: + content = [{"type": "text", "text": m.content}] + for im in m.images: + content.append({"type": "image_url", "image_url": base64_data_url(im)}) + messages_list.append({"role": role, "content": content}) + + kwargs.update({"messages": messages_list}) return kwargs def _determine_using_beta(self, response_format: Any) -> bool: @@ -344,25 +345,22 @@ def _determine_using_beta(self, response_format: Any) -> bool: else: return False - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str: + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: llm_kwargs = self._convert_response_format(llm_kwargs) - ret = self._llm_cache_get(prompt_kwargs, llm_kwargs) + ret = self._llm_cache_get(prompt, llm_kwargs) if ret is not None: return ret - if llm_kwargs is not None: - if self._determine_using_beta(llm_kwargs.get("response_format", None)): - ret = self._generate_using_openai_structured(prompt_kwargs, llm_kwargs) - else: - ret = self._generate_using_openai(prompt_kwargs, llm_kwargs) + if prompt.response_format is not None: + ret = self._generate_using_openai_structured(prompt, llm_kwargs) else: - ret = self._generate_using_openai(prompt_kwargs, llm_kwargs) + ret = self._generate_using_openai(prompt, llm_kwargs) - self._llm_cache_set(prompt_kwargs, llm_kwargs, ret) + self._llm_cache_set(prompt, llm_kwargs, ret) return ret - def _generate_using_openai(self, prompt_kwargs, llm_kwargs) -> str: - kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs) + def _generate_using_openai(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]) -> str: + kwargs = self._get_generate_kwargs(prompt, llm_kwargs) logging.debug("OpenAI prompt: %s", kwargs) if self.is_chat_mode(): completion = self.client_wrapper.get_client().chat.completions.create(model=self._model_name, **kwargs) @@ -373,9 +371,9 @@ def _generate_using_openai(self, prompt_kwargs, llm_kwargs) -> str: logging.debug("OpenAI completion: %s", completion) return completion.choices[0].text - def _generate_using_openai_structured(self, prompt_kwargs, llm_kwargs) -> str: + def _generate_using_openai_structured(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]) -> str: try: - kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs) + kwargs = self._get_generate_kwargs(prompt, llm_kwargs) if self.is_chat_mode(): completion = self.client_wrapper.get_client().beta.chat.completions.parse( model=self._model_name, **kwargs @@ -391,23 +389,24 @@ def _generate_using_openai_structured(self, prompt_kwargs, llm_kwargs) -> str: # 2.) 
The LLM refused to respond to the request because it did not meet guidelines raise e - async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str: - ret = self._llm_cache_get(prompt_kwargs, llm_kwargs) + async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + ret = self._llm_cache_get(prompt, llm_kwargs) if ret is not None: return ret if llm_kwargs is None: raise ValueError("Must include llm_kwargs to generate future call") - if self._determine_using_beta(llm_kwargs.get("response_format", None)): - ret = await self._generate_awaitable_using_openai_structured(prompt_kwargs, llm_kwargs) + + if prompt.response_format is not None: + ret = await self._generate_awaitable_using_openai_structured(prompt, llm_kwargs) else: - ret = await self._generate_awaitable_using_openai(prompt_kwargs, llm_kwargs) + ret = await self._generate_awaitable_using_openai(prompt, llm_kwargs) - self._llm_cache_set(prompt_kwargs, llm_kwargs, ret) + self._llm_cache_set(prompt, llm_kwargs, ret) return ret - async def _generate_awaitable_using_openai(self, prompt_kwargs, llm_kwargs) -> str: - kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs) + async def _generate_awaitable_using_openai(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]) -> str: + kwargs = self._get_generate_kwargs(prompt, llm_kwargs) if self.is_chat_mode(): completion = await self.client_wrapper.get_async_client().chat.completions.create( model=self._model_name, **kwargs @@ -419,9 +418,11 @@ async def _generate_awaitable_using_openai(self, prompt_kwargs, llm_kwargs) -> s ) return completion.choices[0].text - async def _generate_awaitable_using_openai_structured(self, prompt_kwargs, llm_kwargs) -> str: + async def _generate_awaitable_using_openai_structured( + self, prompt: RenderedPrompt, llm_kwargs: Optional[dict] + ) -> str: try: - kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs) + kwargs = self._get_generate_kwargs(prompt, llm_kwargs) if self.is_chat_mode(): completion = await self.client_wrapper.get_async_client().beta.chat.completions.parse( model=self._model_name, **kwargs diff --git a/lib/sycamore/sycamore/llms/prompts/__init__.py b/lib/sycamore/sycamore/llms/prompts/__init__.py index 3fc2487ad..6261099a6 100644 --- a/lib/sycamore/sycamore/llms/prompts/__init__.py +++ b/lib/sycamore/sycamore/llms/prompts/__init__.py @@ -15,6 +15,14 @@ ExtractTablePropertiesPrompt, ) from sycamore.llms.prompts.default_prompts import _deprecated_prompts +from sycamore.llms.prompts.prompts import ( + RenderedPrompt, + RenderedMessage, + SycamorePrompt, + ElementListPrompt, + ElementPrompt, + StaticPrompt, +) prompts = [ "SimplePrompt", @@ -28,7 +36,16 @@ "ExtractTablePropertiesPrompt", ] + list(_deprecated_prompts.keys()) -__all__ = prompts +_all = prompts + [ + "RenderedPrompt", + "RenderedMessage", + "SycamorePrompt", + "ElementListPrompt", + "ElementPrompt", + "StaticPrompt", +] + +__all__ = _all def __getattr__(name): diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py index f1f8c1521..6fafd64a4 100644 --- a/lib/sycamore/sycamore/transforms/base_llm.py +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -9,7 +9,7 @@ def _infer_prompts(prompts: list[RenderedPrompt], llm: LLM, llm_mode: LLMMode) -> list[str]: if llm_mode == LLMMode.SYNC: - return [llm.generate(p) for p in prompts] + return [llm.generate(prompt=p) for p in prompts] elif llm_mode == LLMMode.ASYNC: raise NotImplementedError("Haven't 
done async yet") elif llm_mode == LLMMode.BATCH: From 27581efcc8d0e39d7156c0145e683c05d1a2f226 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Wed, 22 Jan 2025 15:51:35 -0800 Subject: [PATCH 11/46] also this guy Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/llms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/sycamore/sycamore/llms/llms.py b/lib/sycamore/sycamore/llms/llms.py index 6b10418be..7191ab5d7 100644 --- a/lib/sycamore/sycamore/llms/llms.py +++ b/lib/sycamore/sycamore/llms/llms.py @@ -108,7 +108,7 @@ def __init__(self, *, return_value="trivial", cache: Optional[Cache] = None): super().__init__("trivial", cache=cache) self._return_value = return_value - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str: + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: return self._return_value def is_chat_mode(self) -> bool: From 739b672e254dcedca7a46e696ad71a28c69ff2f1 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Wed, 22 Jan 2025 16:38:04 -0800 Subject: [PATCH 12/46] add docset methods Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/docset.py | 39 ++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/lib/sycamore/sycamore/docset.py b/lib/sycamore/sycamore/docset.py index 495f5381b..91655d565 100644 --- a/lib/sycamore/sycamore/docset.py +++ b/lib/sycamore/sycamore/docset.py @@ -8,7 +8,8 @@ from sycamore.context import Context, context_params, OperationTypes from sycamore.data import Document, Element, MetadataDocument from sycamore.functions.tokenizer import Tokenizer -from sycamore.llms.llms import LLM +from sycamore.llms.llms import LLM, LLMMode +from sycamore.llms.prompts import SycamorePrompt from sycamore.llms.prompts.default_prompts import ( LlmClusterEntityAssignGroupsMessagesPrompt, LlmClusterEntityFormGroupsMessagesPrompt, @@ -948,6 +949,42 @@ def custom_flat_mapping_function(document: Document) -> list[Document]: flat_map = FlatMap(self.plan, f=f, **resource_args) return DocSet(self.context, flat_map) + def llm_map( + self, prompt: SycamorePrompt, output_field: str, llm: LLM, llm_mode: LLMMode = LLMMode.SYNC, **kwargs + ) -> "DocSet": + """ + Renders and runs a prompt on every Document of the DocSet. + + Args: + prompt: The prompt to use. Must implement the ``render_document`` method + output_field: Field in properties to store the output. + llm: LLM to use for the inferences. + llm_mode: how to make the api calls to the llm - sync/async/batch + """ + from sycamore.transforms.base_llm import LLMMap + + llm_map = LLMMap(self.plan, prompt=prompt, output_field=output_field, llm=llm, llm_mode=llm_mode, **kwargs) + return DocSet(self.context, llm_map) + + def llm_map_elements( + self, prompt: SycamorePrompt, output_field: str, llm: LLM, llm_mode: LLMMode = LLMMode.SYNC, **kwargs + ) -> "DocSet": + """ + Renders and runs a prompt on every Element of every Document in the DocSet. + + Args: + prompt: The prompt to use. Must implement the ``render_document`` method + output_field: Field in properties to store the output. + llm: LLM to use for the inferences. 
+ llm_mode: how to make the api calls to the llm - sync/async/batch + """ + from sycamore.transforms.base_llm import LLMMapElements + + llm_map_elements = LLMMapElements( + self.plan, prompt=prompt, output_field=output_field, llm=llm, llm_mode=llm_mode, **kwargs + ) + return DocSet(self.context, llm_map_elements) + def filter(self, f: Callable[[Document], bool], **kwargs) -> "DocSet": """ Applies the Filter transform on the Docset. From 73d9bddf76580c6ff1d7a34a47e78110aff00c44 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 23 Jan 2025 10:45:32 -0800 Subject: [PATCH 13/46] docstrings Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/transforms/base_llm.py | 51 ++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py index 6fafd64a4..5569b1c1a 100644 --- a/lib/sycamore/sycamore/transforms/base_llm.py +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -19,6 +19,32 @@ def _infer_prompts(prompts: list[RenderedPrompt], llm: LLM, llm_mode: LLMMode) - class LLMMap(MapBatch): + """The LLMMap transform renders each Document in a docset into + a prompt for an LLM, calls the LLM, and attaches the output to + the document. + + Args: + child: Child node in the sycamore execution graph + prompt: The SycamorePrompt to use to render each document. + Must implement the ``render_document`` method. + output_field: The name of the field in doc.properties in which + to store the llm output + llm: The llm to use for inference. + llm_mode: How to call the llm - sync/async/batch. All LLMs do not + necessarily implement all options. + + Example: + .. code-block:: python + + prompt = EntityExtractorZeroShotGuidancePrompt.set(entity="title") + + docset.llm_map( + prompt=prompt, + output_field="title", + llm=OpenAI(OpenAIModels.GPT_4O_MINI) + ) + """ + def __init__( self, child: Optional[Node], @@ -53,6 +79,31 @@ def _validate_prompt(self): class LLMMapElements(MapBatch): + """The LLMMapElements transform renders each Element for each + Document in a docset into a prompt for an LLM, calls the LLM, + and attaches the output to the document. + + Args: + child: Child node in the sycamore execution graph + prompt: The SycamorePrompt to use to render each element. + Must implement the ``render_element`` method. + output_field: The name of the field in elt.properties in which + to store the llm output. + llm: The llm to use for inference. + llm_mode: How to call the llm - sync/async/batch. All LLMs do not + necessarily implement all options. + + Example: + .. 
code-block:: python + + prompt = TextSummarizerGuidancePrompt + + docset.llm_map_elements( + prompt = prompt, + output_field = "summary", + llm = OpenAI(OpenAIModels.GPT_4O) + """ + def __init__( self, child: Optional[Node], From ed8785e95d6e5aae55d5430b177554533ae98b95 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 23 Jan 2025 10:46:23 -0800 Subject: [PATCH 14/46] add llm_map unit tests Signed-off-by: Henry Lindeman --- .../tests/unit/transforms/test_base_llm.py | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py b/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py new file mode 100644 index 000000000..3ba4efdcc --- /dev/null +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py @@ -0,0 +1,79 @@ +from sycamore.data import Document, Element +from sycamore.llms.llms import LLM +from sycamore.llms.prompts import RenderedPrompt, SycamorePrompt +from sycamore.llms.prompts.prompts import RenderedMessage +from sycamore.transforms.base_llm import LLMMap, LLMMapElements +import pytest +from typing import Optional + + +class FakeLLM(LLM): + def __init__(self): + super().__init__(model_name="dummy") + + def is_chat_mode(self) -> bool: + return True + + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + return "".join(m.content for m in prompt.messages) + + +class FakeDocPrompt(SycamorePrompt): + def render_document(self, doc: Document) -> RenderedPrompt: + return RenderedPrompt(messages=[RenderedMessage(role="system", content=doc.text_representation or "None")]) + + +class FakeEltPrompt(SycamorePrompt): + def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: + return RenderedPrompt( + messages=[ + RenderedMessage(role="system", content=doc.text_representation or "None"), + RenderedMessage(role="user", content=elt.text_representation or "None"), + ] + ) + + +class TestLLMMap: + def test_wrong_prompt_fails_fast(self): + prompt = FakeEltPrompt() + llm = FakeLLM() + with pytest.raises(NotImplementedError) as einfo: + _ = LLMMap(None, prompt, "out", llm) + assert "FakeEltPrompt" in str(einfo.value) + + def test_happy_path(self): + prompt = FakeDocPrompt() + llm = FakeLLM() + doc1 = Document({"text_representation": "ooga"}) + doc2 = Document({"text_representation": "booga"}) + map = LLMMap(None, prompt, "out", llm) + outdocs = map.llm_map([doc1, doc2]) + + assert outdocs[0].text_representation == "ooga" + assert outdocs[0].properties["out"] == "ooga" + assert outdocs[1].text_representation == "booga" + assert outdocs[1].properties["out"] == "booga" + + +class TestLLMMapElements: + def test_wrong_prompt_fails_fast(self): + prompt = FakeDocPrompt() + llm = FakeLLM() + with pytest.raises(NotImplementedError) as einfo: + _ = LLMMapElements(None, prompt, "out", llm) + assert "FakeDocPrompt" in str(einfo.value) + + def test_happy_path(self): + prompt = FakeEltPrompt() + llm = FakeLLM() + doc1 = Document( + {"text_representation": "ooga", "elements": [{"text_representation": "yo"}, {"text_representation": "ho"}]} + ) + doc2 = Document({"elements": [{"text_representation": "booga"}, {}]}) + map = LLMMapElements(None, prompt, "out", llm) + outdocs = map.llm_map_elements([doc1, doc2]) + + assert outdocs[0].elements[0].properties["out"] == "oogayo" + assert outdocs[0].elements[1].properties["out"] == "oogaho" + assert outdocs[1].elements[0].properties["out"] == 
"Nonebooga" + assert outdocs[1].elements[1].properties["out"] == "NoneNone" From 523d6e32953f87fff4f26b12f6098387110fa5c4 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 23 Jan 2025 10:58:40 -0800 Subject: [PATCH 15/46] fix bedrock tests and chaching Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/llms.py | 2 +- .../sycamore/tests/unit/llms/test_bedrock.py | 43 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/lib/sycamore/sycamore/llms/llms.py b/lib/sycamore/sycamore/llms/llms.py index 7191ab5d7..3713310a7 100644 --- a/lib/sycamore/sycamore/llms/llms.py +++ b/lib/sycamore/sycamore/llms/llms.py @@ -93,7 +93,7 @@ def _llm_cache_set(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict], res self._cache.set( key, { - "prompt_kwargs": prompt, + "prompt": prompt, "llm_kwargs": llm_kwargs, "model_name": self._model_name, "result": result, diff --git a/lib/sycamore/sycamore/tests/unit/llms/test_bedrock.py b/lib/sycamore/sycamore/tests/unit/llms/test_bedrock.py index f50804e75..af606ab88 100644 --- a/lib/sycamore/sycamore/tests/unit/llms/test_bedrock.py +++ b/lib/sycamore/sycamore/tests/unit/llms/test_bedrock.py @@ -2,6 +2,7 @@ from unittest.mock import patch import tempfile +from sycamore.llms.prompts import RenderedPrompt, RenderedMessage from sycamore.llms import Bedrock, BedrockModels from sycamore.llms.bedrock import DEFAULT_ANTHROPIC_VERSION, DEFAULT_MAX_TOKENS from sycamore.utils.cache import DiskCache @@ -40,11 +41,11 @@ def test_bedrock_simple(mock_boto3_client): assert client._model_name == BedrockModels.CLAUDE_3_5_SONNET.value.name result = client.generate( - prompt_kwargs={ - "messages": [ - {"role": "user", "content": "Roll 4d20 and tell me the final sum."}, + prompt=RenderedPrompt( + messages=[ + RenderedMessage(role="user", content="Roll 4d20 and tell me the final sum."), ] - } + ) ) assert result == "Here is your result: 56" @@ -68,12 +69,12 @@ def test_bedrock_system_role(mock_boto3_client): assert client._model_name == BedrockModels.CLAUDE_3_5_SONNET.value.name result = client.generate( - prompt_kwargs={ - "messages": [ - {"role": "system", "content": "You are a DM for a game of D&D."}, - {"role": "user", "content": "Roll 4d20 and tell me the final sum."}, + prompt=RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="You are a DM for a game of D&D."), + RenderedMessage(role="user", content="Roll 4d20 and tell me the final sum."), ] - } + ) ) assert result == "Here is your result: 56" @@ -99,11 +100,11 @@ def test_bedrock_with_llm_kwargs(mock_boto3_client): assert client._model_name == BedrockModels.CLAUDE_3_5_SONNET.value.name result = client.generate( - prompt_kwargs={ - "messages": [ - {"role": "user", "content": "Roll 4d20 and tell me the final sum."}, + prompt=RenderedPrompt( + messages=[ + RenderedMessage(role="user", content="Roll 4d20 and tell me the final sum."), ] - }, + ), llm_kwargs={"max_tokens": 100, "anthropic_version": "v1"}, ) assert result == "Here is your result: 56" @@ -134,11 +135,11 @@ def test_bedrock_with_cache(mock_boto3_client): assert client._model_name == BedrockModels.CLAUDE_3_5_SONNET.value.name result = client.generate( - prompt_kwargs={ - "messages": [ - {"role": "user", "content": "Roll 4d20 and tell me the final sum."}, + prompt=RenderedPrompt( + messages=[ + RenderedMessage(role="user", content="Roll 4d20 and tell me the final sum."), ] - } + ) ) assert result == "Here is your result: 56" @@ -154,11 +155,11 @@ def test_bedrock_with_cache(mock_boto3_client): } result = 
client.generate( - prompt_kwargs={ - "messages": [ - {"role": "user", "content": "Roll 4d20 and tell me the final sum."}, + prompt=RenderedPrompt( + messages=[ + RenderedMessage(role="user", content="Roll 4d20 and tell me the final sum."), ] - } + ) ) assert result == "Here is your result: 56" From e1b32062dcc3c3bf7631a3085472305e4e2c2662 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 23 Jan 2025 13:09:46 -0800 Subject: [PATCH 16/46] fix anthropic and bedrock ITs Signed-off-by: Henry Lindeman --- .../tests/integration/llms/test_anthropic.py | 81 +++++++++++-------- .../tests/integration/llms/test_bedrock.py | 81 +++++++++++-------- 2 files changed, 98 insertions(+), 64 deletions(-) diff --git a/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py b/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py index c5f77666b..1f59c5d2f 100644 --- a/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py +++ b/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py @@ -2,14 +2,17 @@ from typing import Any from sycamore.llms import Anthropic, AnthropicModels +from sycamore.llms.prompts.prompts import RenderedPrompt, RenderedMessage from sycamore.utils.cache import DiskCache def test_anthropic_defaults(): llm = Anthropic(AnthropicModels.CLAUDE_3_HAIKU) - prompt_kwargs = {"prompt": "Write a limerick about large language models."} + prompt = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")] + ) - res = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + res = llm.generate(prompt=prompt, llm_kwargs={}) assert len(res) > 0 @@ -17,14 +20,14 @@ def test_anthropic_defaults(): def test_anthropic_messages_defaults(): llm = Anthropic(AnthropicModels.CLAUDE_3_HAIKU) messages = [ - { - "role": "user", - "content": "Write a caption for a recent trip to a sunny beach", - }, + RenderedMessage( + role="user", + content="Write a caption for a recent trip to a sunny beach", + ), ] - prompt_kwargs = {"messages": messages} + prompt = RenderedPrompt(messages=messages) - res = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + res = llm.generate(prompt=prompt, llm_kwargs={}) assert len(res) > 0 @@ -32,43 +35,53 @@ def test_anthropic_messages_defaults(): def test_cached_anthropic(tmp_path: Path): cache = DiskCache(str(tmp_path)) llm = Anthropic(AnthropicModels.CLAUDE_3_HAIKU, cache=cache) - prompt_kwargs = {"prompt": "Write a limerick about large language models."} + prompt = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")] + ) # pylint: disable=protected-access - key = llm._llm_cache_key(prompt_kwargs, {}) + key = llm._llm_cache_key(prompt, {}) - res = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + res = llm.generate(prompt=prompt, llm_kwargs={}) # assert result is cached assert cache.get(key).get("result")["output"] == res - assert cache.get(key).get("prompt_kwargs") == prompt_kwargs + assert cache.get(key).get("prompt") == prompt assert cache.get(key).get("llm_kwargs") == {} assert cache.get(key).get("model_name") == AnthropicModels.CLAUDE_3_HAIKU.value # assert llm.generate is using cached result custom_output: dict[str, Any] = { "result": {"output": "This is a custom response"}, - "prompt_kwargs": prompt_kwargs, + "prompt": prompt, "llm_kwargs": {}, "model_name": AnthropicModels.CLAUDE_3_HAIKU.value, } cache.set(key, custom_output) - assert llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) == 
custom_output["result"]["output"] + assert llm.generate(prompt=prompt, llm_kwargs={}) == custom_output["result"]["output"] def test_cached_bedrock_different_prompts(tmp_path: Path): cache = DiskCache(str(tmp_path)) llm = Anthropic(AnthropicModels.CLAUDE_3_HAIKU, cache=cache) - prompt_kwargs_1 = {"prompt": "Write a limerick about large language models."} - prompt_kwargs_2 = {"prompt": "Write a short limerick about large language models."} - prompt_kwargs_3 = {"prompt": "Write a poem about large language models."} - prompt_kwargs_4 = {"prompt": "Write a short poem about large language models."} - - key_1 = llm._llm_cache_key(prompt_kwargs_1, {}) - key_2 = llm._llm_cache_key(prompt_kwargs_2, {}) - key_3 = llm._llm_cache_key(prompt_kwargs_3, {}) - key_4 = llm._llm_cache_key(prompt_kwargs_4, {}) + prompt_1 = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")] + ) + prompt_2 = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a short limerick about large language models.")] + ) + prompt_3 = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a poem about large language models.")] + ) + prompt_4 = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a short poem about large language models.")] + ) + + key_1 = llm._llm_cache_key(prompt_1, {}) + key_2 = llm._llm_cache_key(prompt_2, {}) + key_3 = llm._llm_cache_key(prompt_3, {}) + key_4 = llm._llm_cache_key(prompt_4, {}) keys = [key_1, key_2, key_3, key_4] assert len(keys) == len( @@ -87,21 +100,23 @@ def test_cached_anthropic_different_models(tmp_path: Path): llm_HAIKU = Anthropic(AnthropicModels.CLAUDE_3_HAIKU, cache=cache) llm_SONNET = Anthropic(AnthropicModels.CLAUDE_3_SONNET, cache=cache) - prompt_kwargs = {"prompt": "Write a limerick about large language models."} + prompt = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")] + ) # populate cache - key_HAIKU = llm_HAIKU._llm_cache_key(prompt_kwargs, {}) - res_HAIKU = llm_HAIKU.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) - key_SONNET = llm_SONNET._llm_cache_key(prompt_kwargs, {}) - res_SONNET = llm_SONNET.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + key_HAIKU = llm_HAIKU._llm_cache_key(prompt, {}) + res_HAIKU = llm_HAIKU.generate(prompt=prompt, llm_kwargs={}) + key_SONNET = llm_SONNET._llm_cache_key(prompt, {}) + res_SONNET = llm_SONNET.generate(prompt=prompt, llm_kwargs={}) # check proper cached results assert cache.get(key_HAIKU).get("result")["output"] == res_HAIKU - assert cache.get(key_HAIKU).get("prompt_kwargs") == prompt_kwargs + assert cache.get(key_HAIKU).get("prompt") == prompt assert cache.get(key_HAIKU).get("llm_kwargs") == {} assert cache.get(key_HAIKU).get("model_name") == AnthropicModels.CLAUDE_3_HAIKU.value assert cache.get(key_SONNET).get("result")["output"] == res_SONNET - assert cache.get(key_SONNET).get("prompt_kwargs") == prompt_kwargs + assert cache.get(key_SONNET).get("prompt") == prompt assert cache.get(key_SONNET).get("llm_kwargs") == {} assert cache.get(key_SONNET).get("model_name") == AnthropicModels.CLAUDE_3_SONNET.value @@ -112,9 +127,11 @@ def test_cached_anthropic_different_models(tmp_path: Path): def test_metadata(): llm = Anthropic(AnthropicModels.CLAUDE_3_HAIKU) - prompt_kwargs = {"prompt": "Write a limerick about large language models."} + prompt = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large language 
models.")] + ) - res = llm.generate_metadata(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + res = llm.generate_metadata(prompt=prompt, llm_kwargs={}) assert "output" in res assert "wall_latency" in res diff --git a/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py b/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py index d8622d888..d27f39be7 100644 --- a/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py +++ b/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py @@ -2,6 +2,7 @@ from typing import Any from sycamore.llms import Bedrock, BedrockModels +from sycamore.llms.prompts import RenderedPrompt, RenderedMessage from sycamore.utils.cache import DiskCache @@ -10,9 +11,11 @@ def test_bedrock_defaults(): llm = Bedrock(BedrockModels.CLAUDE_3_HAIKU) - prompt_kwargs = {"prompt": "Write a limerick about large language models."} + prompt = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")] + ) - res = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + res = llm.generate(prompt=prompt, llm_kwargs={}) assert len(res) > 0 @@ -20,14 +23,14 @@ def test_bedrock_defaults(): def test_bedrock_messages_defaults(): llm = Bedrock(BedrockModels.CLAUDE_3_HAIKU) messages = [ - { - "role": "user", - "content": "Write a caption for a recent trip to a sunny beach", - }, + RenderedMessage( + role="user", + content="Write a caption for a recent trip to a sunny beach", + ), ] - prompt_kwargs = {"messages": messages} + prompt = RenderedPrompt(messages=messages) - res = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + res = llm.generate(prompt=prompt, llm_kwargs={}) assert len(res) > 0 @@ -35,43 +38,53 @@ def test_bedrock_messages_defaults(): def test_cached_bedrock(tmp_path: Path): cache = DiskCache(str(tmp_path)) llm = Bedrock(BedrockModels.CLAUDE_3_HAIKU, cache=cache) - prompt_kwargs = {"prompt": "Write a limerick about large language models."} + prompt = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")] + ) # pylint: disable=protected-access - key = llm._llm_cache_key(prompt_kwargs, {}) + key = llm._llm_cache_key(prompt, {}) - res = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + res = llm.generate(prompt=prompt, llm_kwargs={}) # assert result is cached assert cache.get(key).get("result")["output"] == res - assert cache.get(key).get("prompt_kwargs") == prompt_kwargs + assert cache.get(key).get("prompt") == prompt assert cache.get(key).get("llm_kwargs") == {} assert cache.get(key).get("model_name") == BedrockModels.CLAUDE_3_HAIKU.value.name # assert llm.generate is using cached result custom_output: dict[str, Any] = { "result": {"output": "This is a custom response"}, - "prompt_kwargs": prompt_kwargs, + "prompt": prompt, "llm_kwargs": {}, "model_name": BedrockModels.CLAUDE_3_HAIKU.value.name, } cache.set(key, custom_output) - assert llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) == custom_output["result"]["output"] + assert llm.generate(prompt=prompt, llm_kwargs={}) == custom_output["result"]["output"] def test_cached_bedrock_different_prompts(tmp_path: Path): cache = DiskCache(str(tmp_path)) llm = Bedrock(BedrockModels.CLAUDE_3_HAIKU, cache=cache) - prompt_kwargs_1 = {"prompt": "Write a limerick about large language models."} - prompt_kwargs_2 = {"prompt": "Write a short limerick about large language models."} - prompt_kwargs_3 = {"prompt": "Write a poem about large language models."} - prompt_kwargs_4 = 
{"prompt": "Write a short poem about large language models."} - - key_1 = llm._llm_cache_key(prompt_kwargs_1, {}) - key_2 = llm._llm_cache_key(prompt_kwargs_2, {}) - key_3 = llm._llm_cache_key(prompt_kwargs_3, {}) - key_4 = llm._llm_cache_key(prompt_kwargs_4, {}) + prompt_1 = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")] + ) + prompt_2 = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a short limerick about large language models.")] + ) + prompt_3 = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a poem about large language models.")] + ) + prompt_4 = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a short poem about large language models.")] + ) + + key_1 = llm._llm_cache_key(prompt_1, {}) + key_2 = llm._llm_cache_key(prompt_2, {}) + key_3 = llm._llm_cache_key(prompt_3, {}) + key_4 = llm._llm_cache_key(prompt_4, {}) keys = [key_1, key_2, key_3, key_4] assert len(keys) == len( @@ -90,21 +103,23 @@ def test_cached_bedrock_different_models(tmp_path: Path): llm_HAIKU = Bedrock(BedrockModels.CLAUDE_3_HAIKU, cache=cache) llm_SONNET = Bedrock(BedrockModels.CLAUDE_3_SONNET, cache=cache) - prompt_kwargs = {"prompt": "Write a limerick about large language models."} + prompt = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")] + ) # populate cache - key_HAIKU = llm_HAIKU._llm_cache_key(prompt_kwargs, {}) - res_HAIKU = llm_HAIKU.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) - key_SONNET = llm_SONNET._llm_cache_key(prompt_kwargs, {}) - res_SONNET = llm_SONNET.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + key_HAIKU = llm_HAIKU._llm_cache_key(prompt, {}) + res_HAIKU = llm_HAIKU.generate(prompt=prompt, llm_kwargs={}) + key_SONNET = llm_SONNET._llm_cache_key(prompt, {}) + res_SONNET = llm_SONNET.generate(prompt=prompt, llm_kwargs={}) # check proper cached results assert cache.get(key_HAIKU).get("result")["output"] == res_HAIKU - assert cache.get(key_HAIKU).get("prompt_kwargs") == prompt_kwargs + assert cache.get(key_HAIKU).get("prompt") == prompt assert cache.get(key_HAIKU).get("llm_kwargs") == {} assert cache.get(key_HAIKU).get("model_name") == BedrockModels.CLAUDE_3_HAIKU.value.name assert cache.get(key_SONNET).get("result")["output"] == res_SONNET - assert cache.get(key_SONNET).get("prompt_kwargs") == prompt_kwargs + assert cache.get(key_SONNET).get("prompt") == prompt assert cache.get(key_SONNET).get("llm_kwargs") == {} assert cache.get(key_SONNET).get("model_name") == BedrockModels.CLAUDE_3_SONNET.value.name @@ -115,9 +130,11 @@ def test_cached_bedrock_different_models(tmp_path: Path): def test_metadata(): llm = Bedrock(BedrockModels.CLAUDE_3_HAIKU) - prompt_kwargs = {"prompt": "Write a limerick about large language models."} + prompt = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")] + ) - res = llm.generate_metadata(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + res = llm.generate_metadata(prompt=prompt, llm_kwargs={}) assert "output" in res assert "wall_latency" in res From 6500e1c82985d73bd5cbf9f2586f4269a71f8acb Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 23 Jan 2025 14:09:45 -0800 Subject: [PATCH 17/46] adjust caching to handle pydantic class response format properly Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/llms.py | 30 +++- lib/sycamore/sycamore/llms/prompts/prompts.py | 2 +- 
.../tests/integration/llms/test_anthropic.py | 2 + .../tests/integration/llms/test_bedrock.py | 2 + .../tests/integration/llms/test_openai.py | 139 ++++++++++-------- 5 files changed, 111 insertions(+), 64 deletions(-) diff --git a/lib/sycamore/sycamore/llms/llms.py b/lib/sycamore/sycamore/llms/llms.py index 3713310a7..7c1b39e5a 100644 --- a/lib/sycamore/sycamore/llms/llms.py +++ b/lib/sycamore/sycamore/llms/llms.py @@ -1,8 +1,11 @@ +import inspect +from urllib import response from abc import ABC, abstractmethod from enum import Enum import pickle from PIL import Image from typing import Any, Optional +import pydantic from sycamore.utils.cache import Cache from sycamore.llms.prompts import RenderedPrompt @@ -42,10 +45,24 @@ async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[d def __str__(self): return f"{self.__class__.__name__}({self._model_name})" + @staticmethod + def _pickleable_response_format(prompt: RenderedPrompt) -> Any: + if inspect.isclass(prompt.response_format) and issubclass(prompt.response_format, pydantic.BaseModel): + return prompt.response_format.model_json_schema() + else: + return prompt.response_format + def _llm_cache_key(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: """Return a cache key for the given prompt and LLM parameters.""" assert self._cache - combined = {"prompt": prompt, "llm_kwargs": llm_kwargs, "model_name": self._model_name} + rf = self._pickleable_response_format(prompt) + ms = prompt.messages + combined = { + "prompt": RenderedPrompt(messages=ms), + "prompt.response_format": rf, + "llm_kwargs": llm_kwargs, + "model_name": self._model_name, + } data = pickle.dumps(combined) return self._cache.get_hash_context(data).hexdigest() @@ -68,15 +85,17 @@ def _llm_cache_get(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]) -> hit = self._cache.get(key) if hit: assert ( - len(hit) == 4 - and hit.get("prompt") == prompt + len(hit) == 5 + and hit.get("prompt") == RenderedPrompt(messages=prompt.messages) + and hit.get("prompt.response_format") == self._pickleable_response_format(prompt) and hit.get("llm_kwargs") == llm_kwargs and hit.get("model_name") == self._model_name and "result" in hit ), f""" Found LLM cache content mismatch: key={key} - prompt_kwargs={prompt}, cached={hit.get("prompt")} + prompt={prompt}, cached={hit.get("prompt")} + cached_response_format={hit.get("prompt.response_format")} llm_kwargs={llm_kwargs}, cached={hit.get("llm_kwargs")} model_name={self._model_name}, cached={hit.get("model_name")} Complete hit: {hit}""" @@ -93,7 +112,8 @@ def _llm_cache_set(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict], res self._cache.set( key, { - "prompt": prompt, + "prompt": RenderedPrompt(messages=prompt.messages), + "prompt.response_format": self._pickleable_response_format(prompt), "llm_kwargs": llm_kwargs, "model_name": self._model_name, "result": result, diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 0ea81112e..0ba0eb18b 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -34,7 +34,7 @@ class RenderedPrompt: """ messages: list[RenderedMessage] - response_format: Union[None, dict[str, Any], pydantic.BaseModel] = None + response_format: Union[None, dict[str, Any], type[pydantic.BaseModel]] = None class SycamorePrompt: diff --git a/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py b/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py 
index 1f59c5d2f..43e906925 100644 --- a/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py +++ b/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py @@ -47,6 +47,7 @@ def test_cached_anthropic(tmp_path: Path): # assert result is cached assert cache.get(key).get("result")["output"] == res assert cache.get(key).get("prompt") == prompt + assert cache.get(key).get("prompt.response_format") == None assert cache.get(key).get("llm_kwargs") == {} assert cache.get(key).get("model_name") == AnthropicModels.CLAUDE_3_HAIKU.value @@ -54,6 +55,7 @@ def test_cached_anthropic(tmp_path: Path): custom_output: dict[str, Any] = { "result": {"output": "This is a custom response"}, "prompt": prompt, + "prompt.response_format": None, "llm_kwargs": {}, "model_name": AnthropicModels.CLAUDE_3_HAIKU.value, } diff --git a/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py b/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py index d27f39be7..fe959230b 100644 --- a/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py +++ b/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py @@ -50,6 +50,7 @@ def test_cached_bedrock(tmp_path: Path): # assert result is cached assert cache.get(key).get("result")["output"] == res assert cache.get(key).get("prompt") == prompt + assert cache.get(key).get("prompt.response_format") == None assert cache.get(key).get("llm_kwargs") == {} assert cache.get(key).get("model_name") == BedrockModels.CLAUDE_3_HAIKU.value.name @@ -57,6 +58,7 @@ def test_cached_bedrock(tmp_path: Path): custom_output: dict[str, Any] = { "result": {"output": "This is a custom response"}, "prompt": prompt, + "prompt.response_format": None, "llm_kwargs": {}, "model_name": BedrockModels.CLAUDE_3_HAIKU.value.name, } diff --git a/lib/sycamore/sycamore/tests/integration/llms/test_openai.py b/lib/sycamore/sycamore/tests/integration/llms/test_openai.py index c2102f6d8..aaa96af3c 100644 --- a/lib/sycamore/sycamore/tests/integration/llms/test_openai.py +++ b/lib/sycamore/sycamore/tests/integration/llms/test_openai.py @@ -5,6 +5,7 @@ from sycamore.llms import OpenAI, OpenAIModels, OpenAIClientWrapper from sycamore.llms.openai import OpenAIModel, OpenAIClientType from sycamore.llms.prompts.default_prompts import SimplePrompt +from sycamore.llms.prompts import RenderedPrompt, RenderedMessage, StaticPrompt from sycamore.utils.cache import DiskCache from pydantic import BaseModel @@ -16,9 +17,10 @@ def test_openai_defaults(): llm = OpenAI(OpenAIModels.GPT_3_5_TURBO) - prompt_kwargs = {"prompt": "Write a limerick about large language models."} - - res = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + prompt = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")] + ) + res = llm.generate(prompt=prompt, llm_kwargs={}) assert len(res) > 0 @@ -26,18 +28,18 @@ def test_openai_defaults(): def test_openai_messages_defaults(): llm = OpenAI(OpenAIModels.GPT_3_5_TURBO) messages = [ - { - "role": "system", - "content": "You are a social media influencer", - }, - { - "role": "user", - "content": "Write a caption for a recent trip to a sunny beach", - }, + RenderedMessage( + role="system", + content="You are a social media influencer", + ), + RenderedMessage( + role="user", + content="Write a caption for a recent trip to a sunny beach", + ), ] - prompt_kwargs = {"messages": messages} + prompt = RenderedPrompt(messages=messages) - res = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + res = 
llm.generate(prompt=prompt, llm_kwargs={}) assert len(res) > 0 @@ -45,70 +47,84 @@ def test_openai_messages_defaults(): def test_cached_openai(tmp_path: Path): cache = DiskCache(str(tmp_path)) llm = OpenAI(OpenAIModels.GPT_3_5_TURBO, cache=cache) - prompt_kwargs = {"prompt": "Write a limerick about large language models."} + prompt = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")] + ) - key = llm._llm_cache_key(prompt_kwargs, {}) + key = llm._llm_cache_key(prompt, {}) - res = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + res = llm.generate(prompt=prompt, llm_kwargs={}) # assert result is cached assert cache.get(key).get("result") == res - assert cache.get(key).get("prompt_kwargs") == prompt_kwargs + assert cache.get(key).get("prompt") == prompt + assert cache.get(key).get("prompt.response_format") == None assert cache.get(key).get("llm_kwargs") == {} assert cache.get(key).get("model_name") == "gpt-3.5-turbo" # assert llm.generate is using cached result custom_output = { "result": "This is a custom response", - "prompt_kwargs": prompt_kwargs, + "prompt": prompt, + "prompt.response_format": None, "llm_kwargs": {}, "model_name": "gpt-3.5-turbo", } cache.set(key, custom_output) - assert llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) == custom_output["result"] + assert llm.generate(prompt=prompt, llm_kwargs={}) == custom_output["result"] def test_cached_guidance(tmp_path: Path): cache = DiskCache(str(tmp_path)) llm = OpenAI(OpenAIModels.GPT_3_5_TURBO, cache=cache) - prompt_kwargs = {"prompt": TestPrompt()} + prompt = TestPrompt().render_generic() - key = llm._llm_cache_key(prompt_kwargs, None) + key = llm._llm_cache_key(prompt, None) - res = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs=None) + res = llm.generate(prompt=prompt, llm_kwargs=None) # assert result is cached assert isinstance(cache.get(key), dict) assert cache.get(key).get("result") == res - assert cache.get(key).get("prompt_kwargs") == prompt_kwargs + assert cache.get(key).get("prompt") == prompt + assert cache.get(key).get("prompt.response_format") == None assert cache.get(key).get("llm_kwargs") is None assert cache.get(key).get("model_name") == "gpt-3.5-turbo" # assert llm.generate is using cached result custom_output = { "result": "This is a custom response", - "prompt_kwargs": {"prompt": TestPrompt()}, + "prompt": TestPrompt().render_generic(), + "prompt.response_format": None, "llm_kwargs": None, "model_name": "gpt-3.5-turbo", } cache.set(key, custom_output) - assert llm.generate(prompt_kwargs={"prompt": TestPrompt()}, llm_kwargs=None) == custom_output["result"] + assert llm.generate(prompt=TestPrompt().render_generic(), llm_kwargs=None) == custom_output["result"] def test_cached_openai_different_prompts(tmp_path: Path): cache = DiskCache(str(tmp_path)) llm = OpenAI(OpenAIModels.GPT_3_5_TURBO, cache=cache) - prompt_kwargs_1 = {"prompt": "Write a limerick about large language models."} - prompt_kwargs_2 = {"prompt": "Write a short limerick about large language models."} - prompt_kwargs_3 = {"prompt": "Write a poem about large language models."} - prompt_kwargs_4 = {"prompt": "Write a short poem about large language models."} - - key_1 = llm._llm_cache_key(prompt_kwargs_1, {}) - key_2 = llm._llm_cache_key(prompt_kwargs_2, {}) - key_3 = llm._llm_cache_key(prompt_kwargs_3, {}) - key_4 = llm._llm_cache_key(prompt_kwargs_4, {}) + prompt_1 = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large 
language models.")] + ) + prompt_2 = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a short limerick about large language models.")] + ) + prompt_3 = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a poem about large language models.")] + ) + prompt_4 = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a short poem about large language models.")] + ) + + key_1 = llm._llm_cache_key(prompt_1, {}) + key_2 = llm._llm_cache_key(prompt_2, {}) + key_3 = llm._llm_cache_key(prompt_3, {}) + key_4 = llm._llm_cache_key(prompt_4, {}) keys = [key_1, key_2, key_3, key_4] assert len(keys) == len( @@ -127,21 +143,23 @@ def test_cached_openai_different_models(tmp_path: Path): llm_GPT_3_5_TURBO = OpenAI(OpenAIModels.GPT_3_5_TURBO, cache=cache) llm_GPT_4O_MINI = OpenAI(OpenAIModels.GPT_4O_MINI, cache=cache) - prompt_kwargs = {"prompt": "Write a limerick about large language models."} + prompt = RenderedPrompt( + messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")] + ) # populate cache - key_GPT_3_5_TURBO = llm_GPT_3_5_TURBO._llm_cache_key(prompt_kwargs, {}) - res_GPT_3_5_TURBO = llm_GPT_3_5_TURBO.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) - key_GPT_4O_MINI = llm_GPT_4O_MINI._llm_cache_key(prompt_kwargs, {}) - res_GPT_4O_MINI = llm_GPT_4O_MINI.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + key_GPT_3_5_TURBO = llm_GPT_3_5_TURBO._llm_cache_key(prompt, {}) + res_GPT_3_5_TURBO = llm_GPT_3_5_TURBO.generate(prompt=prompt, llm_kwargs={}) + key_GPT_4O_MINI = llm_GPT_4O_MINI._llm_cache_key(prompt, {}) + res_GPT_4O_MINI = llm_GPT_4O_MINI.generate(prompt=prompt, llm_kwargs={}) # check proper cached results assert cache.get(key_GPT_3_5_TURBO).get("result") == res_GPT_3_5_TURBO - assert cache.get(key_GPT_3_5_TURBO).get("prompt_kwargs") == prompt_kwargs + assert cache.get(key_GPT_3_5_TURBO).get("prompt") == prompt assert cache.get(key_GPT_3_5_TURBO).get("llm_kwargs") == {} assert cache.get(key_GPT_3_5_TURBO).get("model_name") == "gpt-3.5-turbo" assert cache.get(key_GPT_4O_MINI).get("result") == res_GPT_4O_MINI - assert cache.get(key_GPT_4O_MINI).get("prompt_kwargs") == prompt_kwargs + assert cache.get(key_GPT_4O_MINI).get("prompt") == prompt assert cache.get(key_GPT_4O_MINI).get("llm_kwargs") == {} assert cache.get(key_GPT_4O_MINI).get("model_name") == "gpt-4o-mini" @@ -157,39 +175,44 @@ def test_cached_openai_pydantic_model(tmp_path: Path): class Statement(BaseModel): is_true: bool - prompt_kwargs = {"prompt": "2+2 = 4, is this statement true?"} - llm_kwargs = {"response_format": Statement} - llm_kwargs_cached = {"response_format": type_to_response_format_param(Statement)} + llm_kwargs = {} + llm_kwargs_cached = {} + + prompt = RenderedPrompt( + messages=[RenderedMessage(role="user", content="2+2 = 4, is this statement true?")], response_format=Statement + ) # populate cache # pylint: disable=protected-access - key_GPT_4O_MINI = llm_GPT_4O_MINI._llm_cache_key(prompt_kwargs, llm_kwargs_cached) - res_GPT_4O_MINI = llm_GPT_4O_MINI.generate(prompt_kwargs=prompt_kwargs, llm_kwargs=llm_kwargs) + key_GPT_4O_MINI = llm_GPT_4O_MINI._llm_cache_key(prompt, llm_kwargs_cached) + res_GPT_4O_MINI = llm_GPT_4O_MINI.generate(prompt=prompt, llm_kwargs=llm_kwargs) + print(res_GPT_4O_MINI) assert key_GPT_4O_MINI is not None # check cache assert cache.get(key_GPT_4O_MINI).get("result") == res_GPT_4O_MINI - assert cache.get(key_GPT_4O_MINI).get("prompt_kwargs") == prompt_kwargs + assert 
cache.get(key_GPT_4O_MINI).get("prompt") == RenderedPrompt(messages=prompt.messages) + assert cache.get(key_GPT_4O_MINI).get("prompt.response_format") == llm_GPT_4O_MINI._pickleable_response_format( + prompt + ) assert cache.get(key_GPT_4O_MINI).get("llm_kwargs") == llm_kwargs_cached assert cache.get(key_GPT_4O_MINI).get("model_name") == "gpt-4o-mini" -class TestPrompt(SimplePrompt): - system = "You are a skilled poet" - user = "Write a limerick about large language models" +class TestPrompt(StaticPrompt): + def __init__(self): + super().__init__(system="You are a skilled poet", user="Write a limerick about large language models") def test_openai_defaults_guidance_chat(): llm = OpenAI(OpenAIModels.GPT_3_5_TURBO) - prompt_kwargs = {"prompt": TestPrompt()} - res = llm.generate(prompt_kwargs=prompt_kwargs) + res = llm.generate(prompt=TestPrompt().render_generic()) print(res) assert len(res) > 0 def test_openai_defaults_guidance_instruct(): llm = OpenAI(OpenAIModels.GPT_3_5_TURBO_INSTRUCT) - prompt_kwargs = {"prompt": TestPrompt()} - res = llm.generate(prompt_kwargs=prompt_kwargs) + res = llm.generate(prompt=TestPrompt().render_generic()) assert len(res) > 0 @@ -208,14 +231,14 @@ def azure_llm(): def test_azure_defaults_guidance_chat(azure_llm): - prompt_kwargs = {"prompt": TestPrompt()} - res = azure_llm.generate(prompt_kwargs=prompt_kwargs) + prompt = TestPrompt().render_generic() + res = azure_llm.generate(prompt=prompt) assert len(res) > 0 def test_azure_defaults_guidance_instruct(azure_llm): - prompt_kwargs = {"prompt": TestPrompt()} - res = azure_llm.generate(prompt_kwargs=prompt_kwargs) + prompt = TestPrompt().render_generic() + res = azure_llm.generate(prompt=prompt) assert len(res) > 0 From f50032d1c0e1446b427d56fd1808f4a2300b5d4f Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 23 Jan 2025 14:33:46 -0800 Subject: [PATCH 18/46] fix base llm unit tests Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 2 +- .../sycamore/tests/unit/llms/test_llms.py | 32 +++++++++++-------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 0ba0eb18b..86676d036 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -93,7 +93,7 @@ def set(self, **kwargs) -> "SycamorePrompt": .. 
code-block:: python p = StaticPrompt(system="hello", user="world") - p.render_document(Document()) + p.render_document(Document()) # [ # {"role": "system", "content": "hello"}, # {"role": "user", "content": "world"} diff --git a/lib/sycamore/sycamore/tests/unit/llms/test_llms.py b/lib/sycamore/sycamore/tests/unit/llms/test_llms.py index 62cc4aed1..de2a7491c 100644 --- a/lib/sycamore/sycamore/tests/unit/llms/test_llms.py +++ b/lib/sycamore/sycamore/tests/unit/llms/test_llms.py @@ -1,8 +1,10 @@ from pathlib import Path from unittest.mock import patch +import pytest from sycamore.llms import OpenAI, OpenAIModels, Bedrock, BedrockModels, get_llm, MODELS from sycamore.llms.llms import FakeLLM +from sycamore.llms.prompts import RenderedPrompt, RenderedMessage from sycamore.llms.prompts import EntityExtractorFewShotGuidancePrompt, EntityExtractorZeroShotGuidancePrompt from sycamore.utils.cache import DiskCache @@ -12,6 +14,8 @@ def test_openai_davinci_fallback(): assert llm._model_name == OpenAIModels.GPT_3_5_TURBO_INSTRUCT.value.name +# Skip bc prompts are changing entirely +@pytest.mark.skip def test_deprecated_prompt_fallback(): from sycamore.llms.prompts.default_prompts import ENTITY_EXTRACTOR_ZERO_SHOT_GUIDANCE_PROMPT @@ -36,8 +40,8 @@ def test_get_llm(mock_boto3_client): class TestCache: def test_nocache(self, tmp_path): llm = FakeLLM() - llm._llm_cache_set({}, None, "abc") - assert llm._llm_cache_get({}, None) is None + llm._llm_cache_set(RenderedPrompt(messages=[]), None, "abc") + assert llm._llm_cache_get(RenderedPrompt(messages=[]), None) is None def test_use_caching(self, tmp_path: Path): llm = FakeLLM() @@ -52,19 +56,19 @@ def test_use_caching(self, tmp_path: Path): def test_cache(self, tmp_path: Path): llm = FakeLLM(cache=DiskCache(str(tmp_path))) - def doit(prompt_kwargs, llm_kwargs, result, overwrite=False, already_set=False): + def doit(prompt, llm_kwargs, result, overwrite=False, already_set=False): if overwrite: - assert llm._llm_cache_get(prompt_kwargs, llm_kwargs) is not None - llm._llm_cache_set(prompt_kwargs, llm_kwargs, result) + assert llm._llm_cache_get(prompt, llm_kwargs) is not None + llm._llm_cache_set(prompt, llm_kwargs, result) elif not already_set: - assert llm._llm_cache_get(prompt_kwargs, llm_kwargs) is None - llm._llm_cache_set(prompt_kwargs, llm_kwargs, result) + assert llm._llm_cache_get(prompt, llm_kwargs) is None + llm._llm_cache_set(prompt, llm_kwargs, result) - assert llm._llm_cache_get(prompt_kwargs, llm_kwargs) == result + assert llm._llm_cache_get(prompt, llm_kwargs) == result - doit({}, None, "abc") - doit({}, None, "abc2", overwrite=True) - doit({}, {}, "def") - doit({"prompt": "foff"}, {}, {"ghi": "jkl"}) - doit({}, {"magic": True}, [1, 2, 3]) - doit({}, None, "abc2", already_set=True) + doit(RenderedPrompt(messages=[]), None, "abc") + doit(RenderedPrompt(messages=[]), None, "abc2", overwrite=True) + doit(RenderedPrompt(messages=[]), {}, "def") + doit(RenderedPrompt(messages=[RenderedMessage(role="user", content="foff")]), {}, {"ghi": "jkl"}) + doit(RenderedPrompt(messages=[]), {"magic": True}, [1, 2, 3]) + doit(RenderedPrompt(messages=[]), None, "abc2", already_set=True) From c3c7ea8e655c4e590f024989c7dcf424fbe4983b Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 23 Jan 2025 15:25:21 -0800 Subject: [PATCH 19/46] adjust all testing mock llms to updated llm interface Signed-off-by: Henry Lindeman --- .../sycamore/tests/unit/query/test_operations.py | 3 ++- .../sycamore/tests/unit/query/test_strategy.py | 3 ++- 
lib/sycamore/sycamore/tests/unit/test_docset.py | 3 ++- .../tests/unit/transforms/test_extract_entity.py | 3 ++- .../unit/transforms/test_graph_entity_extractor.py | 7 ++++--- .../transforms/test_graph_relationship_extractor.py | 11 ++++++----- .../unit/transforms/test_resolve_graph_entities.py | 13 +++++++------ 7 files changed, 25 insertions(+), 18 deletions(-) diff --git a/lib/sycamore/sycamore/tests/unit/query/test_operations.py b/lib/sycamore/sycamore/tests/unit/query/test_operations.py index 7e1116cf0..59ffb7b38 100644 --- a/lib/sycamore/sycamore/tests/unit/query/test_operations.py +++ b/lib/sycamore/sycamore/tests/unit/query/test_operations.py @@ -8,6 +8,7 @@ from sycamore.functions import CharacterTokenizer from sycamore.functions.basic_filters import MatchFilter, RangeFilter from sycamore.llms import LLM +from sycamore.llms.prompts import RenderedPrompt from sycamore.llms.prompts.default_prompts import ( LlmClusterEntityAssignGroupsMessagesPrompt, LlmClusterEntityFormGroupsMessagesPrompt, @@ -24,7 +25,7 @@ class MockLLM(LLM): def __init__(self): super().__init__(model_name="mock_model") - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None): + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: if ( prompt_kwargs["messages"] == LlmClusterEntityFormGroupsMessagesPrompt( diff --git a/lib/sycamore/sycamore/tests/unit/query/test_strategy.py b/lib/sycamore/sycamore/tests/unit/query/test_strategy.py index 8a272ea93..0b1db8d9f 100644 --- a/lib/sycamore/sycamore/tests/unit/query/test_strategy.py +++ b/lib/sycamore/sycamore/tests/unit/query/test_strategy.py @@ -2,6 +2,7 @@ from typing import Optional from sycamore.llms import LLM +from sycamore.llms.prompts import RenderedPrompt from sycamore.query.logical_plan import LogicalPlan from sycamore.query.operators.query_database import QueryDatabase from sycamore.query.strategy import ( @@ -17,7 +18,7 @@ class DummyLLMClient(LLM): def is_chat_mode(self) -> bool: return False - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str: + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: return "Dummy response from an LLM Client" diff --git a/lib/sycamore/sycamore/tests/unit/test_docset.py b/lib/sycamore/sycamore/tests/unit/test_docset.py index 443aadff6..c1d623acc 100644 --- a/lib/sycamore/sycamore/tests/unit/test_docset.py +++ b/lib/sycamore/sycamore/tests/unit/test_docset.py @@ -8,6 +8,7 @@ from sycamore import DocSet, Context from sycamore.data import Document, Element from sycamore.llms import LLM +from sycamore.llms.prompts import RenderedPrompt from sycamore.llms.prompts.default_prompts import ( LlmClusterEntityAssignGroupsMessagesPrompt, LlmClusterEntityFormGroupsMessagesPrompt, @@ -41,7 +42,7 @@ class MockLLM(LLM): def __init__(self): super().__init__(model_name="mock_model") - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None): + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: if ( prompt_kwargs == {"messages": [{"role": "user", "content": "Element_index: 1\nText: third element\n"}]} and llm_kwargs == {} diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py index 389165653..d5b47c48f 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py @@ -8,6 +8,7 @@ 
from sycamore.transforms import ExtractEntity from sycamore.transforms.extract_entity import OpenAIEntityExtractor from sycamore.llms import LLM +from sycamore.llms.prompts import RenderedPrompt from sycamore.llms.prompts.default_prompts import ( EntityExtractorFewShotGuidancePrompt, EntityExtractorZeroShotGuidancePrompt, @@ -21,7 +22,7 @@ class MockLLM(LLM): def __init__(self): super().__init__(model_name="mock_model") - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None): + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: if prompt_kwargs == {"messages": [{"role": "user", "content": "s3://path"}]} and llm_kwargs == {}: return "alt_title" if prompt_kwargs == {"prompt": "s3://path"} and llm_kwargs == {}: diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_graph_entity_extractor.py b/lib/sycamore/sycamore/tests/unit/transforms/test_graph_entity_extractor.py index c7b5ab5f5..b7b3d88bd 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_graph_entity_extractor.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_graph_entity_extractor.py @@ -4,6 +4,7 @@ from sycamore.data.document import Document, HierarchicalDocument from sycamore.data.element import Element from sycamore.llms.llms import LLM +from sycamore.llms.prompts import RenderedPrompt from sycamore.reader import DocSetReader from sycamore.transforms.extract_document_structure import StructureBySection from sycamore.transforms.extract_graph_entities import EntityExtractor @@ -58,13 +59,13 @@ class MockLLM(LLM): def __init__(self): super().__init__(model_name="mock_model") - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None): - pass + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + return "" def is_chat_mode(self): return True - async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None): + async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None): return """{ "Company": [ {"name": "Microsoft"}, diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_graph_relationship_extractor.py b/lib/sycamore/sycamore/tests/unit/transforms/test_graph_relationship_extractor.py index 59c898ff2..d299651ca 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_graph_relationship_extractor.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_graph_relationship_extractor.py @@ -4,6 +4,7 @@ from sycamore.data.document import Document, HierarchicalDocument from sycamore.data.element import Element from sycamore.llms.llms import LLM +from sycamore.llms.prompts import RenderedPrompt from sycamore.reader import DocSetReader from sycamore.transforms.extract_document_structure import StructureBySection from sycamore.transforms.extract_graph_entities import EntityExtractor @@ -60,13 +61,13 @@ class MockEntityLLM(LLM): def __init__(self): super().__init__(model_name="mock_model") - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None): - pass + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + return "" def is_chat_mode(self): return True - async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None): + async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None): return """{ "Company": [ {"name": "Microsoft"}, @@ -80,13 +81,13 @@ class MockRelationshipLLM(LLM): def __init__(self): 
super().__init__(model_name="mock_model") - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str: + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: return "" def is_chat_mode(self): return True - async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None): + async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None): return """{ "Competes": [ {"start": {"name": "Microsoft"}, "end": {"name": "Google"}} diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_resolve_graph_entities.py b/lib/sycamore/sycamore/tests/unit/transforms/test_resolve_graph_entities.py index c7cd071fb..ab909d5dc 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_resolve_graph_entities.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_resolve_graph_entities.py @@ -4,6 +4,7 @@ from sycamore.data.document import Document, HierarchicalDocument from sycamore.data.element import Element from sycamore.llms.llms import LLM +from sycamore.llms.prompts import RenderedPrompt from sycamore.reader import DocSetReader from sycamore.transforms.extract_document_structure import StructureBySection from sycamore.transforms.extract_graph_entities import EntityExtractor @@ -59,13 +60,13 @@ class MockEntityLLM(LLM): def __init__(self): super().__init__(model_name="mock_model") - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None): - pass + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + """""" def is_chat_mode(self): return True - async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None): + async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None): return """{ "Company": [ {"name": "Microsoft"}, @@ -79,13 +80,13 @@ class MockRelationshipLLM(LLM): def __init__(self): super().__init__(model_name="mock_model") - def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None): - pass + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + return "" def is_chat_mode(self): return True - async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None): + async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None): return """{ "Competes": [ {"start": {"name": "Microsoft"}, "end": {"name": "Google"}} From ffaaf0f16abc3e8b4b1d9ba222088924c251e53e Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 23 Jan 2025 16:04:26 -0800 Subject: [PATCH 20/46] deprecate extract entity and implement it with llm_map Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/docset.py | 21 ++- .../sycamore/llms/prompts/default_prompts.py | 145 ++++++++++++++++-- lib/sycamore/sycamore/utils/deprecate.py | 26 ++++ 3 files changed, 174 insertions(+), 18 deletions(-) create mode 100644 lib/sycamore/sycamore/utils/deprecate.py diff --git a/lib/sycamore/sycamore/docset.py b/lib/sycamore/sycamore/docset.py index 91655d565..cd922070d 100644 --- a/lib/sycamore/sycamore/docset.py +++ b/lib/sycamore/sycamore/docset.py @@ -30,6 +30,7 @@ from sycamore.transforms.extract_table import TableExtractor from sycamore.transforms.merge_elements import ElementMerger from sycamore.utils.extract_json import extract_json +from sycamore.utils.deprecate import deprecated from sycamore.transforms.query import QueryExecutor, Query from sycamore.materialize_config import MaterializeSourceMode 
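The gist of this patch: rather than routing through the ExtractEntity transform, extract_entity now builds one of the default entity-extraction prompts and hands it to the llm_map machinery. A minimal sketch of the equivalent user-facing call, assuming a docset whose documents already carry elements and an LLM client configured elsewhere (the entity name "title" is illustrative, not part of the patch):

.. code-block:: python

    from sycamore.llms.prompts.default_prompts import EntityExtractorZeroShotGuidancePrompt

    # llm: any sycamore LLM implementation (OpenAI, Bedrock, ...) constructed elsewhere.
    prompt = EntityExtractorZeroShotGuidancePrompt.set(entity="title")
    docset = docset.llm_map(prompt=prompt, output_field="title", llm=llm)
    # After the map runs, each document carries the result in doc.properties["title"].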
@@ -466,7 +467,10 @@ def extract_document_structure(self, structure: DocumentStructure, **kwargs): document_structure = ExtractDocumentStructure(self.plan, structure=structure, **kwargs) return DocSet(self.context, document_structure) - def extract_entity(self, entity_extractor: EntityExtractor, **kwargs) -> "DocSet": + @deprecated(version="0.1.31", reason="Use llm_map instead") + def extract_entity( + self, entity_name: str, llm: LLM, examples: Optional[str] = None, llm_mode: LLMMode = LLMMode.SYNC, **kwargs + ) -> "DocSet": """ Applies the ExtractEntity transform on the Docset. @@ -490,10 +494,19 @@ def extract_entity(self, entity_extractor: EntityExtractor, **kwargs) -> "DocSet .extract_entity(entity_extractor=entity_extractor) """ - from sycamore.transforms import ExtractEntity + if examples is None: + from sycamore.llms.prompts.default_prompts import EntityExtractorZeroShotGuidancePrompt as zero_shot - entities = ExtractEntity(self.plan, context=self.context, entity_extractor=entity_extractor, **kwargs) - return DocSet(self.context, entities) + prompt = zero_shot.set(entity=entity_name) + else: + from sycamore.llms.prompts.default_prompts import EntityExtractorFewShotGuidancePrompt as few_shot + + prompt = few_shot.set(entity=entity_name, examples=examples) + + from sycamore.transforms.base_llm import LLMMap + + llm_map = LLMMap(self.plan, prompt=prompt, output_field=entity_name, llm=llm, llm_mode=llm_mode, **kwargs) + return DocSet(self.context, llm_map) def extract_schema(self, schema_extractor: SchemaExtractor, **kwargs) -> "DocSet": """ diff --git a/lib/sycamore/sycamore/llms/prompts/default_prompts.py b/lib/sycamore/sycamore/llms/prompts/default_prompts.py index ac2760704..5170187b1 100644 --- a/lib/sycamore/sycamore/llms/prompts/default_prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/default_prompts.py @@ -1,13 +1,17 @@ +from dataclasses import dataclass import logging from abc import ABC -from typing import Any, Optional, Type +from typing import Any, Optional, Type, ClassVar, Callable +from sycamore.data.element import Element +from sycamore.data.document import Document from sycamore.schema import Schema +from sycamore.llms.prompts.prompts import ElementListPrompt, ElementPrompt, StaticPrompt logger = logging.getLogger(__name__) -class SimplePrompt(ABC): +class _SimplePrompt(ABC): system: Optional[str] = None user: Optional[str] = None var_name: str = "answer" @@ -35,7 +39,10 @@ def __hash__(self): return hash((self.system, self.user, self.var_name)) -class EntityExtractorZeroShotGuidancePrompt(SimplePrompt): +SimplePrompt = _SimplePrompt + + +class _EntityExtractorZeroShotGuidancePrompt(_SimplePrompt): system = "You are a helpful entity extractor" # ruff: noqa: E501 user = """You are given a few text elements of a document. The {entity} of the document is in these few text elements.Using @@ -45,7 +52,15 @@ class EntityExtractorZeroShotGuidancePrompt(SimplePrompt): """ -class EntityExtractorFewShotGuidancePrompt(SimplePrompt): +EntityExtractorZeroShotGuidancePrompt = ElementListPrompt( + system="You are a helpful entity extractor", + user="""You are given a few text elements of a document. The {entity} of the document is in these few text elements. + Using this context, FIND, COPY, and RETURN the {entity}. DO NOT REPHRASE OR MAKE UP AN ANSWER. + {elements}""", +) + + +class _EntityExtractorFewShotGuidancePrompt(SimplePrompt): system = "You are a helpful entity extractor." # ruff: noqa: E501 user = """You are given a few text elements of a document. 
The {entity} of the document is in these few text elements. Here are
@@ -57,7 +72,19 @@ class EntityExtractorFewShotGuidancePrompt(SimplePrompt):
     """


+EntityExtractorFewShotGuidancePrompt = ElementListPrompt(
+    system="You are a helpful entity extractor",
+    user="""You are given a few text elements of a document. The {entity} of the document is in these few text elements. Here are
+    some example groups of text elements where the {entity} has been identified.
+    {examples}
+    Using the context from the document and the provided examples, FIND, COPY, and RETURN the {entity}. Only return the {entity} as part
+    of your answer. DO NOT REPHRASE OR MAKE UP AN ANSWER.
+    {elements}
+    """,
+)
+
+
-class TextSummarizerGuidancePrompt(SimplePrompt):
+class _TextSummarizerGuidancePrompt(SimplePrompt):
     system = "You are a helpful text summarizer."
     user = """Write a summary of the following. Use only the information provided.
     Include as many key details as possible. Do not make up answer. Only return the summary as part of your answer.
@@ -66,7 +93,16 @@ class TextSummarizerGuidancePrompt(SimplePrompt):
     var_name = "summary"


+TextSummarizerGuidancePrompt = ElementPrompt(
+    system="You are a helpful text summarizer.",
+    user="""Write a summary of the following. Use only the information provided.
+    Include as many key details as possible. Do not make up your answer. Only return the summary as part of your answer
+    {elt_text}
+    """,
+)
+
+
-class SchemaZeroShotGuidancePrompt(SimplePrompt):
+class _SchemaZeroShotGuidancePrompt(SimplePrompt):
     system = "You are a helpful entity extractor. You only return JSON Schema."
     user = """You are given a few text elements of a document. Extract JSON Schema representing one entity of
 class {entity} from the document. Using this context, FIND, FORMAT, and RETURN the JSON-LD Schema.
@@ -76,7 +112,18 @@ class {entity} from the document. Using this context, FIND, FORMAT, and RETURN t
     """


+SchemaZeroShotGuidancePrompt = ElementListPrompt(
+    system="You are a helpful entity extractor. You only return JSON Schema.",
+    user="""You are given a few text elements of a document. Extract JSON Schema representing one entity of
+    class {entity} from the document. Using this context, FIND, FORMAT, and RETURN the JSON-LD Schema.
+    Return a flat schema, without nested properties. Return at most {max_num_properties} properties.
+    Only return JSON Schema as part of your answer.
+    {elements}""",
+    max_num_properties=7,
+)
+
+
-class TaskIdentifierZeroShotGuidancePrompt(SimplePrompt):
+class _TaskIdentifierZeroShotGuidancePrompt(SimplePrompt):
     system = "You are a helpful task identifier. You return a string containing no whitespace."
     user = """You are given a dictionary where the keys are task IDs and the values are descriptions of tasks.
     Using this context, FIND and RETURN only the task ID that best matches the given question.
@@ -86,6 +133,17 @@ class TaskIdentifierZeroShotGuidancePrompt(SimplePrompt):
     """


+TaskIdentifierZeroShotGuidancePrompt = StaticPrompt(
+    system="You are a helpful task identifier. You return a string containing no whitespace.",
+    user="""You are given a dictionary where the keys are task IDs and the values are descriptions of tasks.
+    Using this context, FIND and RETURN only the task ID that best matches the given question.
+    Only return the task ID as a string. Do not return any additional information.
+ {task_descriptions} + Question: {question} + """, +) + + class GraphEntityExtractorPrompt(SimplePrompt): user = """ -Instructions- @@ -108,7 +166,7 @@ class GraphRelationshipExtractorPrompt(SimplePrompt): """ -class ExtractTablePropertiesPrompt(SimplePrompt): +class _ExtractTablePropertiesPrompt(SimplePrompt): user = """ You are given a text string represented as a CSV (comma-separated values) and an image of a table. @@ -150,8 +208,8 @@ class ExtractTablePropertiesPrompt(SimplePrompt): |---------------------------------|------------------|------------------| return False - - example of a key value table containing null walues + + example of a key value table containing null values |---------------------------------|---------------------| | header 1 : | header 2: 'value2' | | header 3 : | header 4 : | @@ -163,6 +221,65 @@ class ExtractTablePropertiesPrompt(SimplePrompt): """ +ExtractTablePropertiesPrompt = ElementPrompt( + user=""" + You are given a text string represented as a CSV (comma-separated values) and an image of a table. + + Instructions: + Check if the table contains key-value pairs. A key-value pair table is a table where data is structured as key-value pairs. Generally, the first column contains the key and the second column contains the value. However, key-value pairs can also appear in other formats. + If there is a one-to-one mapping between two cells, even if the relationship is not direct, they should be considered key-value pairs. + If the table is a key-value pair table, return its key-value pairs as a JSON object. + If the table is not a key-value pair table, return False. + Use camelCase for the key names in the JSON object. + Parse the CSV table, check the image, and return a flattened JSON object representing the key-value pairs from the table. The extracted key-value pairs should be formatted as a JSON object. + Do not return nested objects; keep the dictionary only one level deep. The only valid value types are numbers, strings, None, and lists. + A table can have multiple or all null values for a key. In such cases, return a JSON object with the specified key set to null for all rows in the table. + For fields where the values are in standard measurement units like miles, nautical miles, knots, or Celsius, include the unit in the key name and only set the numeric value as the value: + + "Wind Speed: 9 knots" should become "windSpeedInKnots": 9 + "Temperature: 3°C" should become "temperatureInC": 3 + Ensure that key names are enclosed in double quotes. + + Return only the JSON object between ``` if the table is a key-value pair table; otherwise, return False. 
+ + example of a key-value pair table: + |---------------------------------|------------------| + | header 1 | header 2 | + |---------------------------------|------------------| + | NEW FIRE ALARM SYSTEMS | $272 TWO HOURS | + | NEW SPRINKLER SYSTEMS | $408 THREE HOURS | + | NEW GASEOUS SUPPRESSION SYSTEMS | $272 TWO HOURS | + |---------------------------------|------------------| + + return ```{"NEW FIRE ALARM SYSTEMS": "$272 TWO HOURS", "NEW SPRINKLER SYSTEMS": "$408 THREE HOURS", "NEW GASEOUS SUPPRESSION SYSTEMS": "$272 TWO HOURS"}``` + + example of a table which is not key-value pair table: + |---------------------------------|------------------|------------------| + | header 1 | header 2 | header 3 | + |---------------------------------|------------------|------------------| + | NEW FIRE ALARM SYSTEMS | $272 TWO HOURS | $2752 ONE HOUR | + | NEW SPRINKLER SYSTEMS | $408 THREE HOURS | $128 FIVE HOURS | + | NEW GASEOUS SUPPRESSION SYSTEMS | $272 TWO HOURS | $652 TEN HOURS | + |---------------------------------|------------------|------------------| + + return False + + example of a key value table containing null values + |---------------------------------|---------------------| + | header 1 : | header 2: 'value2' | + | header 3 : | header 4 : | + | header 5 : | header 6: | + |---------------------------------|---------------------| + + return ```{"header1": null, "header2": "value2", "header3": null, "header4": null, "header5": null, "header6": null}``` + + CSV: + {elt_text} + """, + include_element_image=True, +) + + class ExtractPropertiesFromSchemaPrompt(SimplePrompt): def __init__(self, schema: Schema, text: str): super().__init__() @@ -171,13 +288,13 @@ def __init__(self, schema: Schema, text: str): self.user = f""" Extract values for the following fields: {self._format_schema(schema)} - + Document text: {text} - + Don't return extra information. If you cannot find a value for a requested property, use the provided default or the value 'None'. - Return your answers as a valid json dictionary that will be parsed in python. + Return your answers as a valid json dictionary that will be parsed in python. """ @staticmethod @@ -188,7 +305,7 @@ def _format_schema(schema: Schema) -> str: {i} {field.name}: type={field.field_type}: default={field.default} {field.description}\n Examples values: {field.examples} - + """ return text diff --git a/lib/sycamore/sycamore/utils/deprecate.py b/lib/sycamore/sycamore/utils/deprecate.py new file mode 100644 index 000000000..6810c8d55 --- /dev/null +++ b/lib/sycamore/sycamore/utils/deprecate.py @@ -0,0 +1,26 @@ +from functools import wraps +from typing import Optional, Callable, TypeVar +from typing_extensions import ParamSpec +import warnings + +P = ParamSpec("P") +T = TypeVar("T") + + +def deprecated(version: Optional[str] = None, reason: Optional[str] = None): + + def decorator(fn: Callable[P, T]) -> Callable[P, T]: + warn_msg = f"{fn.__name__} is deprecated" + if version is not None: + warn_msg += f" since version {version}" + if reason is not None: + warn_msg += f". 
Reason: {reason}" + + @wraps(fn) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + warnings.warn(warn_msg, category=FutureWarning) + return fn(*args, **kwargs) + + return wrapper + + return decorator From d71cf1a2c0ccbe6d5c3fc9e9623b2255a740c704 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 23 Jan 2025 16:33:45 -0800 Subject: [PATCH 21/46] add context_params decorator to llm_map Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/docset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/sycamore/sycamore/docset.py b/lib/sycamore/sycamore/docset.py index cd922070d..36266429e 100644 --- a/lib/sycamore/sycamore/docset.py +++ b/lib/sycamore/sycamore/docset.py @@ -18,7 +18,6 @@ from sycamore.transforms.augment_text import TextAugmentor from sycamore.transforms.embed import Embedder from sycamore.transforms import DocumentStructure, Sort -from sycamore.transforms.extract_entity import EntityExtractor, OpenAIEntityExtractor from sycamore.transforms.extract_graph_entities import GraphEntityExtractor from sycamore.transforms.extract_graph_relationships import GraphRelationshipExtractor from sycamore.transforms.extract_schema import SchemaExtractor, PropertyExtractor @@ -962,6 +961,7 @@ def custom_flat_mapping_function(document: Document) -> list[Document]: flat_map = FlatMap(self.plan, f=f, **resource_args) return DocSet(self.context, flat_map) + @context_params def llm_map( self, prompt: SycamorePrompt, output_field: str, llm: LLM, llm_mode: LLMMode = LLMMode.SYNC, **kwargs ) -> "DocSet": @@ -979,6 +979,7 @@ def llm_map( llm_map = LLMMap(self.plan, prompt=prompt, output_field=output_field, llm=llm, llm_mode=llm_mode, **kwargs) return DocSet(self.context, llm_map) + @context_params def llm_map_elements( self, prompt: SycamorePrompt, output_field: str, llm: LLM, llm_mode: LLMMode = LLMMode.SYNC, **kwargs ) -> "DocSet": From 4225e11d4f361108089a1000ce318fe59c042bfc Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Fri, 24 Jan 2025 10:21:27 -0800 Subject: [PATCH 22/46] revert extract_entity docset method re-implementation Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/docset.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/lib/sycamore/sycamore/docset.py b/lib/sycamore/sycamore/docset.py index 36266429e..b9dd76232 100644 --- a/lib/sycamore/sycamore/docset.py +++ b/lib/sycamore/sycamore/docset.py @@ -18,6 +18,7 @@ from sycamore.transforms.augment_text import TextAugmentor from sycamore.transforms.embed import Embedder from sycamore.transforms import DocumentStructure, Sort +from sycamore.transforms.extract_entity import EntityExtractor, OpenAIEntityExtractor from sycamore.transforms.extract_graph_entities import GraphEntityExtractor from sycamore.transforms.extract_graph_relationships import GraphRelationshipExtractor from sycamore.transforms.extract_schema import SchemaExtractor, PropertyExtractor @@ -467,9 +468,7 @@ def extract_document_structure(self, structure: DocumentStructure, **kwargs): return DocSet(self.context, document_structure) @deprecated(version="0.1.31", reason="Use llm_map instead") - def extract_entity( - self, entity_name: str, llm: LLM, examples: Optional[str] = None, llm_mode: LLMMode = LLMMode.SYNC, **kwargs - ) -> "DocSet": + def extract_entity(self, entity_extractor: EntityExtractor, **kwargs) -> "DocSet": """ Applies the ExtractEntity transform on the Docset. 
@@ -493,19 +492,10 @@ def extract_entity( .extract_entity(entity_extractor=entity_extractor) """ - if examples is None: - from sycamore.llms.prompts.default_prompts import EntityExtractorZeroShotGuidancePrompt as zero_shot - - prompt = zero_shot.set(entity=entity_name) - else: - from sycamore.llms.prompts.default_prompts import EntityExtractorFewShotGuidancePrompt as few_shot - - prompt = few_shot.set(entity=entity_name, examples=examples) + from sycamore.transforms import ExtractEntity - from sycamore.transforms.base_llm import LLMMap - - llm_map = LLMMap(self.plan, prompt=prompt, output_field=entity_name, llm=llm, llm_mode=llm_mode, **kwargs) - return DocSet(self.context, llm_map) + entities = ExtractEntity(self.plan, context=self.context, entity_extractor=entity_extractor, **kwargs) + return DocSet(self.context, entities) def extract_schema(self, schema_extractor: SchemaExtractor, **kwargs) -> "DocSet": """ @@ -961,7 +951,6 @@ def custom_flat_mapping_function(document: Document) -> list[Document]: flat_map = FlatMap(self.plan, f=f, **resource_args) return DocSet(self.context, flat_map) - @context_params def llm_map( self, prompt: SycamorePrompt, output_field: str, llm: LLM, llm_mode: LLMMode = LLMMode.SYNC, **kwargs ) -> "DocSet": @@ -979,7 +968,6 @@ def llm_map( llm_map = LLMMap(self.plan, prompt=prompt, output_field=output_field, llm=llm, llm_mode=llm_mode, **kwargs) return DocSet(self.context, llm_map) - @context_params def llm_map_elements( self, prompt: SycamorePrompt, output_field: str, llm: LLM, llm_mode: LLMMode = LLMMode.SYNC, **kwargs ) -> "DocSet": From 0d39b27485bae4719927d31bd5e7450a26f19e6d Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Fri, 24 Jan 2025 16:37:26 -0800 Subject: [PATCH 23/46] add initial support for prompts that generate a sequence of rendered prompts (to do things like adhere to token limits) Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 109 +++++++++++++++++- lib/sycamore/sycamore/transforms/base_llm.py | 30 ++++- 2 files changed, 129 insertions(+), 10 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 86676d036..703422573 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Union, Optional, Callable +from typing import Any, Union, Optional, Callable, Sequence import copy import pydantic @@ -42,7 +42,7 @@ class SycamorePrompt: convert sycamore objects (``Document``s, ``Element``s) into ``RenderedPrompts`` """ - def render_document(self, doc: Document) -> RenderedPrompt: + def render_document(self, doc: Document) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]: """Render this prompt, given this document as context. Used in llm_map @@ -54,7 +54,7 @@ def render_document(self, doc: Document) -> RenderedPrompt: """ raise NotImplementedError(f"render_document is not implemented for {self.__class__.__name__}") - def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: + def render_element(self, elt: Element, doc: Document) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]: """Render this prompt, given this element and its parent document as context. 
Used in llm_map_elements @@ -66,7 +66,7 @@ def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: """ raise NotImplementedError(f"render_element is not implemented for {self.__class__.__name__}") - def render_multiple_documents(self, docs: list[Document]) -> RenderedPrompt: + def render_multiple_documents(self, docs: list[Document]) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]: """Render this prompt, given a list of documents as context. Used in llm_reduce @@ -113,6 +113,19 @@ def set(self, **kwargs) -> "SycamorePrompt": new.__dict__[k] = v return new + def is_done(self, s: str) -> bool: + """Decide whether a given response is sufficient. Used when rendering + the prompt generates a sequence of prompts rather than a single prompt. + The default implementation always returns True + + Args: + s: a string response from the LLM + + Returns: + Whether to continue making LLM calls + """ + return True + def _build_format_str( system: Optional[str], user: Union[None, str, list[str]], format_args: dict[str, Any] @@ -188,7 +201,7 @@ def _render_element_list_to_string(self, doc: Document): elts = self.element_select(doc.elements) return self.element_list_constructor(elts) - def render_document(self, doc: Document) -> RenderedPrompt: + def render_document(self, doc: Document) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]: """Render this prompt, given this document as context, using python's ``str.format()`` method. The keys passed into ``format()`` are as follows: @@ -218,6 +231,92 @@ def render_document(self, doc: Document) -> RenderedPrompt: return result +class ElementListIterPrompt(ElementListPrompt): + """A prompt with utilities for constructing a lists of elements to include + in a sequence of rendered prompts. Functions almost identically to ElementListPrompt, + but renders into a series of prompts. + + Args: + + system: The system prompt string. Use {} to reference names that should + be interpolated. Defaults to None + user: The user prompt string. Use {} to reference names that should be + interpolated. Defaults to None + element_select: Function to choose the elements (and their order) to include + in the prompt. If None, defaults to the first ``num_elements`` elements. + element_list_constructor: Function to turn a list of elements into a + string that can be accessed with the interpolation key "{elements}". + Defaults to "ELEMENT 0: {elts[0].text_representation}\\n + ELEMENT 1: {elts[1].text_representation}\\n + ..." + num_elements: Sets the number of elements to take if ``element_select`` is + unset. Default is 35. + element_batcher: Constructs batches of elements to render in sequence to generate + several rendered prompts. Defaults to one batch with all elements. + **kwargs: other keyword arguments are stored and can be used as interpolation keys. + + Example: + .. 
code-block:: python + + prompt = ElementListIterPrompt( + system = "You are a program that returns 'None' if you don't know the answer to my question" + user = "What is the capital of the country described?\\nElements:\\n{elements}" + element_batcher = lambda elts: [elts[i:i+2] for i in range(0, len(elts), 2)] + ).set(is_done=lambda s: s != 'None') + prompt.render_document(doc) + # [ + # [ + # {"role": "system", "content": "You are a program that returns 'None' if you don't know the answer to my question"}, + # {"role": "user", "content": "What is the capital of the country described?\\nElements:\\n + # ELEMENT 0: \\nELEMENT 1: "} + # ], + # [ + # {"role": "system", "content": "You are a program that returns 'None' if you don't know the answer to my question"}, + # {"role": "user", "content": "What is the capital of the country described?\\nElements:\\n + # ELEMENT 0: \\nELEMENT 1: "} + # ] + # ] + """ + + def __init__(self, *, element_batcher: Optional[Callable[[list[Element]], list[list[Element]]]] = None, **kwargs): + self.element_batcher = element_batcher or (lambda e: [e]) + super().__init__(**kwargs) + + def render_document(self, doc: Document) -> Sequence[RenderedPrompt]: + """Render this prompt, given this document as context, using python's + ``str.format()`` method. The keys passed into ``format()`` are as follows: + + - self.kwargs: the additional kwargs specified when creating this prompt. + - doc_text: doc.text_representation + - doc_property_: each property name in doc.properties is + prefixed with 'doc_property_'. So if ``doc.properties = {'k1': 0, 'k2': 3}``, + you get ``doc_property_k1 = 0, doc_property_k2 = 3``. + - elements: the element list constructed from doc.elements using ``self.element_select``, + ``self.element_order``, and ``self.element_list_constructor``. + + Args: + doc: The document to use as context for rendering this prompt + + Returns: + A list of two-message RenderedPrompts containing ``self.system.format()`` and + ``self.user.format()`` using the format keys as specified above. Each instance + is rendered from a batch of elements generated by ``self.element_batcher`` + """ + + format_args = self.kwargs + format_args["doc_text"] = doc.text_representation + flat_props = flatten_data(doc.properties, prefix="doc_property", separator="_") + format_args.update(flat_props) + + prompts = [] + for elt_batch in self.element_batcher(doc.elements): + elements = self.element_select(elt_batch) + elementstr = self.element_list_constructor(elements) + messages = _build_format_str(self.system, self.user, {"elements": elementstr, **format_args}) + prompts.append(RenderedPrompt(messages=messages)) + return prompts + + class ElementPrompt(SycamorePrompt): """A prompt for rendering an element with utilities for capturing information from the element's parent document, with a system and user prompt. 
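The commit message cites token limits as the motivating case for rendering a sequence of prompts. One way the ``element_batcher`` hook might be wired up for that, as a sketch: the ``tokenizer`` object is assumed to be any sycamore tokenizer exposing a ``tokenize`` method, and the 1000-token budget is arbitrary.

.. code-block:: python

    from sycamore.llms.prompts.prompts import ElementListIterPrompt

    def batch_by_tokens(elts, tokenizer, max_tokens=1000):
        """Greedily pack elements into batches that stay under a rough token budget."""
        batches, current, used = [], [], 0
        for elt in elts:
            n = len(tokenizer.tokenize(elt.text_representation or ""))
            if current and used + n > max_tokens:
                batches.append(current)
                current, used = [], 0
            current.append(elt)
            used += n
        if current:
            batches.append(current)
        return batches

    prompt = ElementListIterPrompt(
        system="You answer from the provided elements, or say 'None'.",
        user="What is the title of this document?\nElements:\n{elements}",
        element_batcher=lambda elts: batch_by_tokens(elts, tokenizer, max_tokens=1000),
    ).set(is_done=lambda s: s != "None")

Rendering a document then yields one RenderedPrompt per batch, and the sync inference loop stops calling the model as soon as ``is_done`` accepts a response.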
diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py index 5569b1c1a..5582840d7 100644 --- a/lib/sycamore/sycamore/transforms/base_llm.py +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Sequence, Callable, Union from sycamore.llms.llms import LLM, LLMMode from sycamore.llms.prompts.prompts import SycamorePrompt, RenderedPrompt @@ -7,9 +7,22 @@ from sycamore.data import Document, Element -def _infer_prompts(prompts: list[RenderedPrompt], llm: LLM, llm_mode: LLMMode) -> list[str]: +def _infer_prompts( + prompts: list[Sequence[RenderedPrompt]], + llm: LLM, + llm_mode: LLMMode, + is_done: Callable[[str], bool] = lambda s: True, +) -> list[str]: if llm_mode == LLMMode.SYNC: - return [llm.generate(prompt=p) for p in prompts] + res = [] + for piter in prompts: + s = "" + for p in piter: + s = llm.generate(prompt=p) + if is_done(s): + break + res.append(s) + return res elif llm_mode == LLMMode.ASYNC: raise NotImplementedError("Haven't done async yet") elif llm_mode == LLMMode.BATCH: @@ -63,7 +76,8 @@ def __init__( def llm_map(self, documents: list[Document]) -> list[Document]: rendered = [self._prompt.render_document(d) for d in documents] - results = _infer_prompts(rendered, self._llm, self._llm_mode) + rendered = _as_sequences(rendered) + results = _infer_prompts(rendered, self._llm, self._llm_mode, self._prompt.is_done) for d, r in zip(documents, results): d.properties[self._output_field] = r return documents @@ -122,7 +136,9 @@ def __init__( def llm_map_elements(self, documents: list[Document]) -> list[Document]: rendered = [(e, self._prompt.render_element(e, d)) for d in documents for e in d.elements] - results = _infer_prompts([p for _, p in rendered], self._llm, self._llm_mode) + results = _infer_prompts( + _as_sequences([p for _, p in rendered]), self._llm, self._llm_mode, self._prompt.is_done + ) for r, (e, _) in zip(results, rendered): e.properties[self._output_field] = r return documents @@ -136,3 +152,7 @@ def _validate_prompt(self): raise e except Exception: pass + + +def _as_sequences(l: list[Union[RenderedPrompt, Sequence[RenderedPrompt]]]) -> list[Sequence[RenderedPrompt]]: + return [[p] if isinstance(p, RenderedPrompt) else p for p in l] From 0b5ded46b523c604e629b14031acca95ba90aa6a Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Fri, 24 Jan 2025 16:38:53 -0800 Subject: [PATCH 24/46] add stuff to EntityExtractor/OpenAIEntityExtractor to convert to LLMMap Signed-off-by: Henry Lindeman --- .../sycamore/transforms/extract_entity.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/lib/sycamore/sycamore/transforms/extract_entity.py b/lib/sycamore/sycamore/transforms/extract_entity.py index 19b3663a1..96be10d6e 100644 --- a/lib/sycamore/sycamore/transforms/extract_entity.py +++ b/lib/sycamore/sycamore/transforms/extract_entity.py @@ -9,6 +9,7 @@ EntityExtractorFewShotGuidancePrompt, ) from sycamore.plan_nodes import Node +from sycamore.transforms.base_llm import LLMMap from sycamore.transforms.map import Map from sycamore.utils.time_trace import timetrace from sycamore.functions.tokenizer import Tokenizer @@ -29,6 +30,12 @@ class EntityExtractor(ABC): def __init__(self, entity_name: str): self._entity_name = entity_name + @abstractmethod + def as_llm_map( + self, child: Optional[Node], context: Optional[Context] = None, llm: Optional[LLM] = None, **kwargs + ) -> LLMMap: + pass + @abstractmethod def extract_entity( self, document: 
Document, context: Optional[Context] = None, llm: Optional[LLM] = None @@ -100,6 +107,32 @@ def __init__( self._similarity_query = similarity_query self._similarity_scorer = similarity_scorer + @context_params(OperationTypes.INFORMATION_EXTRACTOR) + def as_llm_map( + self, child: Optional[Node], context: Optional[Context] = None, llm: Optional[LLM] = None, **kwargs + ) -> LLMMap: + if llm is None: + llm = self._llm + assert llm is not None, "Could not find an LLM to use" + if self._prompt_template is not None: + prompt = EntityExtractorFewShotGuidancePrompt + prompt = prompt.set(examples=self._prompt_template) + else: + prompt = EntityExtractorZeroShotGuidancePrompt + + def elt_sorter(elts: list[Element]) -> list[Element]: + sorter_inner = make_element_sorter_fn(self._field, self._similarity_query, self._similarity_scorer) + dummy_doc = Document(elements=elts) + sorter_inner(dummy_doc) + return dummy_doc.elements + + prompt = prompt.set(element_select=lambda e: elt_sorter(e)[: self._num_of_elements]) + prompt = prompt.set(element_list_constructor=lambda e: self._prompt_formatter(e, self._field)) + prompt = prompt.set(entity=self._entity_name) + + llm_map = LLMMap(child, prompt, self._entity_name, llm, **kwargs) + return llm_map + @context_params(OperationTypes.INFORMATION_EXTRACTOR) @timetrace("OaExtract") def extract_entity( From a52f7c27b62b421d3773d38158e3c4a2244da45a Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Fri, 24 Jan 2025 16:40:31 -0800 Subject: [PATCH 25/46] make docset.extract_entity construct an LLMMap from its entity_extractor Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/docset.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/sycamore/sycamore/docset.py b/lib/sycamore/sycamore/docset.py index b9dd76232..d41139280 100644 --- a/lib/sycamore/sycamore/docset.py +++ b/lib/sycamore/sycamore/docset.py @@ -492,10 +492,8 @@ def extract_entity(self, entity_extractor: EntityExtractor, **kwargs) -> "DocSet .extract_entity(entity_extractor=entity_extractor) """ - from sycamore.transforms import ExtractEntity - - entities = ExtractEntity(self.plan, context=self.context, entity_extractor=entity_extractor, **kwargs) - return DocSet(self.context, entities) + llm_map = entity_extractor.as_llm_map(self.plan, context=self.context, **kwargs) + return DocSet(self.context, llm_map) def extract_schema(self, schema_extractor: SchemaExtractor, **kwargs) -> "DocSet": """ From 3a9ac3c815b2a53f49b0c0017d1837da39cce47e Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Mon, 27 Jan 2025 17:55:35 -0800 Subject: [PATCH 26/46] get extract entity working with tokenizer and token limit Signed-off-by: Henry Lindeman --- .../sycamore/tests/unit/test_docset.py | 27 ++-- .../unit/transforms/test_extract_entity.py | 127 +++++++++--------- .../tests/unit/transforms/test_llm_filter.py | 4 +- lib/sycamore/sycamore/transforms/base_llm.py | 21 ++- .../sycamore/transforms/extract_entity.py | 72 +++++++++- 5 files changed, 165 insertions(+), 86 deletions(-) diff --git a/lib/sycamore/sycamore/tests/unit/test_docset.py b/lib/sycamore/sycamore/tests/unit/test_docset.py index c1d623acc..55cfd29dc 100644 --- a/lib/sycamore/sycamore/tests/unit/test_docset.py +++ b/lib/sycamore/sycamore/tests/unit/test_docset.py @@ -1,6 +1,7 @@ import random import string from typing import Callable, Optional +from dataclasses import asdict import pytest @@ -43,41 +44,47 @@ def __init__(self): super().__init__(model_name="mock_model") def generate(self, *, prompt: RenderedPrompt, llm_kwargs: 
Optional[dict] = None) -> str: + if llm_kwargs is None: + llm_kwargs = {} + if len(prompt.messages) > 1 and prompt.messages[1].content.endswith("Element_index: 1\nText: third element\n"): + return "None" if ( - prompt_kwargs == {"messages": [{"role": "user", "content": "Element_index: 1\nText: third element\n"}]} + asdict(prompt) == {"messages": [{"role": "user", "content": "Element_index: 1\nText: third element\n"}]} and llm_kwargs == {} ): return "None" elif ( - "first short element" in prompt_kwargs["messages"][0]["content"] - and "second longer element with more words" in prompt_kwargs["messages"][0]["content"] + len(prompt.messages) > 0 + and "first short element" in prompt.messages[1].content + and "second longer element with more words" in prompt.messages[1].content and llm_kwargs == {} ): return "4" elif ( - "very long element with many words that might exceed token limit" in prompt_kwargs["messages"][0]["content"] + len(prompt.messages) > 0 + and "very long element with many words that might exceed token limit" in prompt.messages[1].content and llm_kwargs == {} ): return "5" - elif prompt_kwargs == {"messages": [{"role": "user", "content": "test1"}]} and llm_kwargs == {}: + elif asdict(prompt) == {"messages": [{"role": "user", "content": "test1"}]} and llm_kwargs == {}: return "4" - elif prompt_kwargs == {"messages": [{"role": "user", "content": "test2"}]} and llm_kwargs == {}: + elif asdict(prompt) == {"messages": [{"role": "user", "content": "test2"}]} and llm_kwargs == {}: return "2" elif ( - prompt_kwargs["messages"] + prompt.messages == LlmClusterEntityFormGroupsMessagesPrompt( field="text_representation", instruction="", text="1, 2, one, two, 1, 3" ).as_messages() ): return '{"groups": ["group1", "group2", "group3"]}' elif ( - prompt_kwargs["messages"][0] + prompt.messages[0] == LlmClusterEntityAssignGroupsMessagesPrompt( field="text_representation", groups=["group1", "group2", "group3"] ).as_messages()[0] ): - value = prompt_kwargs["messages"][1]["content"] + value = prompt.messages[1].content if value == "1" or value == "one": return "group1" elif value == "2" or value == "two": @@ -85,7 +92,7 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) elif value == "3" or value == "three": return "group3" else: - return "" + return prompt.messages[-1].content def is_chat_mode(self): return True diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py index d5b47c48f..ece682bf3 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py @@ -1,6 +1,7 @@ from typing import Optional import logging from unittest.mock import MagicMock +from dataclasses import asdict import sycamore from sycamore.context import Context, OperationTypes, ExecMode @@ -23,39 +24,23 @@ def __init__(self): super().__init__(model_name="mock_model") def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: - if prompt_kwargs == {"messages": [{"role": "user", "content": "s3://path"}]} and llm_kwargs == {}: + if asdict(prompt) == {"messages": [{"role": "user", "content": "s3://path"}]} and llm_kwargs == {}: return "alt_title" - if prompt_kwargs == {"prompt": "s3://path"} and llm_kwargs == {}: + if asdict(prompt) == {"prompt": "s3://path"} and llm_kwargs == {}: return "alt_title" - if ( - prompt_kwargs["entity"] == "title" - and prompt_kwargs["query"] == "ELEMENT 1: 
None\nELEMENT 2: None\n" - and prompt_kwargs["examples"] is None - ): - assert isinstance(prompt_kwargs["prompt"], EntityExtractorZeroShotGuidancePrompt) - assert llm_kwargs is None - return "title1" + usermessage = prompt.messages[1].content + + if "Jack Black" in usermessage: + return "Jack Black" - if ( - prompt_kwargs["entity"] == "title" - and prompt_kwargs["query"] == "ELEMENT 1: None\nELEMENT 2: None\n" - and prompt_kwargs["examples"] == "title" - ): - assert isinstance(prompt_kwargs["prompt"], EntityExtractorFewShotGuidancePrompt) - assert llm_kwargs is None + if "example" in usermessage: return "title2" - if ( - prompt_kwargs["entity"] == "title" - and prompt_kwargs["query"] == "ELEMENT 1: Jack Black\nELEMENT 2: None\n" - and prompt_kwargs["examples"] is None - ): - assert isinstance(prompt_kwargs["prompt"], EntityExtractorZeroShotGuidancePrompt) - assert llm_kwargs is None - return "Jack Black" + if "title" in usermessage: + return "title1" - logging.error(f"{prompt_kwargs} // {llm_kwargs}") + logging.error(f"{prompt} // {llm_kwargs}") assert False, "Make all generate branches explicitly check the arguments" def is_chat_mode(self): @@ -88,17 +73,17 @@ class TestEntityExtraction: def test_extract_entity_zero_shot(self, mocker): llm = MockLLM() - extract_entity = ExtractEntity(None, entity_extractor=OpenAIEntityExtractor("title", llm=llm)) - out_doc = extract_entity.run(self.doc) - assert out_doc.properties.get("title") == "title1" + extractor = OpenAIEntityExtractor("title", llm=llm) + llm_map = extractor.as_llm_map(None) + out_docs = llm_map.run([self.doc]) + assert out_docs[0].properties.get("title") == "title1" def test_extract_entity_zero_shot_custom_field(self, mocker): llm = MockLLM() - extract_entity = ExtractEntity( - None, entity_extractor=OpenAIEntityExtractor("title", llm=llm, field="properties.entity.author") - ) - out_doc = extract_entity.run(self.doc) - assert out_doc.properties.get("title") == "Jack Black" + extractor = OpenAIEntityExtractor("title", llm=llm, field="properties.entity.author") + llm_map = extractor.as_llm_map(None) + out_docs = llm_map.run([self.doc]) + assert out_docs[0].properties.get("title") == "Jack Black" def test_extract_entity_with_context_llm(self, mocker): llm = MockLLM() @@ -107,40 +92,48 @@ def test_extract_entity_with_context_llm(self, mocker): "default": {"llm": llm}, } ) - extract_entity = ExtractEntity(None, context=context, entity_extractor=OpenAIEntityExtractor("title")) - out_doc = extract_entity.run(self.doc) - assert out_doc.properties.get("title") == "title1" + extractor = OpenAIEntityExtractor("title") + llm_map = extractor.as_llm_map(None, context=context) + out_docs = llm_map.run([self.doc]) + assert out_docs[0].properties.get("title") == "title1" def test_extract_entity_few_shot(self, mocker): llm = MockLLM() - extract_entity = ExtractEntity( - None, entity_extractor=OpenAIEntityExtractor("title", llm=llm, prompt_template="title") - ) - out_doc = extract_entity.run(self.doc) - assert out_doc.properties.get("title") == "title2" + extractor = OpenAIEntityExtractor("title", llm=llm, prompt_template="title") + llm_map = extractor.as_llm_map(None) + out_docs = llm_map.run([self.doc]) + assert out_docs[0].properties.get("title") == "title2" def test_extract_entity_document_field_messages(self, mocker): llm = MockLLM() - extract_entity = ExtractEntity( - None, - entity_extractor=OpenAIEntityExtractor( - "title", llm=llm, use_elements=False, prompt=[], field="properties.path" - ), - ) - out_doc = extract_entity.run(self.doc) - - 
assert out_doc.properties.get("title") == "alt_title" + extractor = OpenAIEntityExtractor("title", llm=llm, use_elements=False, prompt=[], field="properties.path") + llm_map = extractor.as_llm_map(None) + out_docs = llm_map.run([self.doc]) + assert out_docs[0].properties.get("title") == "alt_title" + # extract_entity = ExtractEntity( + # None, + # entity_extractor=OpenAIEntityExtractor( + # "title", llm=llm, use_elements=False, prompt=[], field="properties.path" + # ), + # ) + # out_doc = extract_entity.run(self.doc) + + # assert out_doc.properties.get("title") == "alt_title" def test_extract_entity_document_field_string(self, mocker): llm = MockLLM() - extract_entity = ExtractEntity( - None, - entity_extractor=OpenAIEntityExtractor( - "title", llm=llm, use_elements=False, prompt="", field="properties.path" - ), - ) - out_doc = extract_entity.run(self.doc) - assert out_doc.properties.get("title") == "alt_title" + extractor = OpenAIEntityExtractor("title", llm=llm, use_elements=False, prompt="", field="properties.path") + llm_map = extractor.as_llm_map(None) + out_docs = llm_map.run([self.doc]) + assert out_docs[0].properties.get("title") == "alt_title" + # extract_entity = ExtractEntity( + # None, + # entity_extractor=OpenAIEntityExtractor( + # "title", llm=llm, use_elements=False, prompt="", field="properties.path" + # ), + # ) + # out_doc = extract_entity.run(self.doc) + # assert out_doc.properties.get("title") == "alt_title" def test_extract_entity_with_similarity_sorting(self, mocker): doc_list = [ @@ -194,10 +187,15 @@ def test_extract_entity_with_similarity_sorting(self, mocker): taken = entity_docset.take() assert len(taken) == 4 assert len(taken[0].elements) == 2 - assert (taken[1].elements[0]["properties"]["_element_index"]) == 9 - assert (taken[1].elements[1]["properties"]["_element_index"]) == 4 + # Element order should be unchanged regardless of scorer + assert (taken[1].elements[0]["properties"]["_element_index"]) == 4 + assert (taken[1].elements[1]["properties"]["_element_index"]) == 9 assert (taken[0].elements[1]["properties"]["_element_index"]) == 2 + # Element order should be changed in the prompt + assert "ELEMENT 1: test2" in taken[1].properties[new_field] + assert "ELEMENT 2: test1" in taken[1].properties[new_field] + def test_extract_entity_with_tokenizer(self, mocker): mock_llm = docsetMockLLM() mock_tokenizer = MockTokenizer() @@ -219,7 +217,7 @@ def test_extract_entity_with_tokenizer(self, mocker): prompt=[], field="text_representation", tokenizer=mock_tokenizer, - max_tokens=10, # Low token limit to test windowing + max_tokens=20, # Low token limit to test windowing ) entity_docset = docset.extract_entity( @@ -235,8 +233,3 @@ def test_extract_entity_with_tokenizer(self, mocker): assert taken[0].elements[2]["properties"]["_autogen_LLMExtractEntityOutput_source_element_index"] == {0, 1, 2} assert taken[1].elements[0]["properties"]["_autogen_LLMExtractEntityOutput_source_element_index"] == {1} assert taken[1].elements[1]["properties"]["_autogen_LLMExtractEntityOutput_source_element_index"] == {2} - assert taken[0].elements[0]["properties"]["_autogen_LLMExtractEntityOutput"] == "4" - assert taken[0].elements[1]["properties"]["_autogen_LLMExtractEntityOutput"] == "4" - assert taken[0].elements[2]["properties"]["_autogen_LLMExtractEntityOutput"] == "4" - assert taken[1].elements[0]["properties"]["_autogen_LLMExtractEntityOutput"] == "None" - assert taken[1].elements[1]["properties"]["_autogen_LLMExtractEntityOutput"] == "5" diff --git 
a/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py b/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py index e4d171bd7..26a6ba3d6 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py @@ -27,7 +27,9 @@ Element(properties={"_element_index": 1}, text_representation="third element"), # llm_filter result = 2 Element( properties={"_element_index": 2}, - text_representation="very long element with many words that might exceed token limit", + text_representation="very long element with many words that might exceed token limit." + " Specifically, it has so many words that even with the additional contextualization" + " like 'Element type' and 'page number' it still overflows", ), # llm_filter result = 5 ], ), diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py index 5582840d7..069340bca 100644 --- a/lib/sycamore/sycamore/transforms/base_llm.py +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -12,16 +12,18 @@ def _infer_prompts( llm: LLM, llm_mode: LLMMode, is_done: Callable[[str], bool] = lambda s: True, -) -> list[str]: +) -> list[tuple[str, int]]: if llm_mode == LLMMode.SYNC: res = [] for piter in prompts: s = "" + i = -1 for p in piter: + i += 1 s = llm.generate(prompt=p) if is_done(s): break - res.append(s) + res.append((s, i)) return res elif llm_mode == LLMMode.ASYNC: raise NotImplementedError("Haven't done async yet") @@ -37,6 +39,7 @@ class LLMMap(MapBatch): the document. Args: + child: Child node in the sycamore execution graph prompt: The SycamorePrompt to use to render each document. Must implement the ``render_document`` method. @@ -45,6 +48,11 @@ class LLMMap(MapBatch): llm: The llm to use for inference. llm_mode: How to call the llm - sync/async/batch. All LLMs do not necessarily implement all options. + postprocess_fn: function to call on documents after performing the + llm inference. If the prompt rendered into multiple RenderedPrompts, + ``i`` is the index of the RenderedPrompt that succeeded; if the + prompt rendered into an empty list, ``i`` is -1; and otherwise + ``i`` is 0 Example: .. 
code-block:: python @@ -65,6 +73,7 @@ def __init__( output_field: str, llm: LLM, llm_mode: LLMMode = LLMMode.SYNC, + postprocess_fn: Callable[[Document, int], Document] = lambda d, i: d, **kwargs, ): self._prompt = prompt @@ -72,15 +81,19 @@ def __init__( self._output_field = output_field self._llm = llm self._llm_mode = llm_mode + self._postprocess_fn = postprocess_fn super().__init__(child, f=self.llm_map, **kwargs) def llm_map(self, documents: list[Document]) -> list[Document]: rendered = [self._prompt.render_document(d) for d in documents] rendered = _as_sequences(rendered) results = _infer_prompts(rendered, self._llm, self._llm_mode, self._prompt.is_done) - for d, r in zip(documents, results): + postprocessed = [] + for d, (r, i) in zip(documents, results): d.properties[self._output_field] = r - return documents + new_d = self._postprocess_fn(d, i) + postprocessed.append(new_d) + return postprocessed def _validate_prompt(self): doc = Document() diff --git a/lib/sycamore/sycamore/transforms/extract_entity.py b/lib/sycamore/sycamore/transforms/extract_entity.py index 96be10d6e..7ffb62589 100644 --- a/lib/sycamore/sycamore/transforms/extract_entity.py +++ b/lib/sycamore/sycamore/transforms/extract_entity.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Callable, Any, Optional, Union +from typing import Callable, Any, Optional, Union, cast +from functools import partial from sycamore.context import Context, context_params, OperationTypes from sycamore.data import Element, Document @@ -8,6 +9,7 @@ EntityExtractorZeroShotGuidancePrompt, EntityExtractorFewShotGuidancePrompt, ) +from sycamore.llms.prompts.prompts import ElementListIterPrompt, ElementListPrompt from sycamore.plan_nodes import Node from sycamore.transforms.base_llm import LLMMap from sycamore.transforms.map import Map @@ -116,10 +118,71 @@ def as_llm_map( assert llm is not None, "Could not find an LLM to use" if self._prompt_template is not None: prompt = EntityExtractorFewShotGuidancePrompt - prompt = prompt.set(examples=self._prompt_template) + prompt = cast(ElementListPrompt, prompt.set(examples=self._prompt_template)) else: prompt = EntityExtractorZeroShotGuidancePrompt + def postprocess(d: Document, i: int) -> Document: + return d + + if self._tokenizer is not None: + prompt.render_document = partial(ElementListIterPrompt.render_document, prompt) # type: ignore + + def elt_list_ctor(elts: list[Element]) -> str: + combined_text = "" + for element in elts: + if "type" in element: + combined_text += f"Element type: {element['type']}\n" + if "page_number" in element["properties"]: + combined_text += f"Page_number: {element['properties']['page_number']}\n" + if "_element_index" in element["properties"]: + combined_text += f"Element_index: {element['properties']['_element_index']}\n" + combined_text += f"Text: {element.field_to_value(self._field)}\n" + return combined_text + + if self._prompt_formatter is not element_list_formatter: + prompt = prompt.set(element_list_constructor=self._prompt_formatter) + else: + prompt = prompt.set(element_list_constructor=elt_list_ctor) + + def eb(elts: list[Element]) -> list[list[Element]]: + curr_tks = 0 + curr_batch = [] + batches = [] + source_indices = set() + for e in elts: + eltl = prompt.element_list_constructor([e]) + tks = len(self._tokenizer.tokenize(eltl)) + if tks + curr_tks > self._max_tokens: + batches.append(curr_batch) + curr_tks = tks + curr_batch = [e] + source_indices = {e.element_index} + e.properties[f"{self._entity_name}_source_element_index"] = 
source_indices + else: + e.properties[f"{self._entity_name}_source_element_index"] = source_indices + source_indices.add(e.element_index) + curr_batch.append(e) + curr_tks += tks + batches.append(curr_batch) + return batches + + setattr(prompt, "element_batcher", eb) + prompt.is_done = lambda s: s != "None" + print(prompt.__dict__) + + def postprocess(d: Document, i: int) -> Document: + last_club = set() + source_key = f"{self._entity_name}_source_element_index" + for e in d.elements: + if e.properties[source_key] != last_club: + i -= 1 + last_club = e.properties[source_key] + if i == -1: + break + d.properties[source_key] = last_club + return d + def elt_sorter(elts: list[Element]) -> list[Element]: sorter_inner = make_element_sorter_fn(self._field, self._similarity_query, self._similarity_scorer) dummy_doc = Document(elements=elts) @@ -127,10 +190,11 @@ def elt_sorter(elts: list[Element]) -> list[Element]: return dummy_doc.elements prompt = prompt.set(element_select=lambda e: elt_sorter(e)[: self._num_of_elements]) - prompt = prompt.set(element_list_constructor=lambda e: self._prompt_formatter(e, self._field)) + if self._tokenizer is None: + prompt = prompt.set(element_list_constructor=lambda e: self._prompt_formatter(e, self._field)) prompt = prompt.set(entity=self._entity_name) - llm_map = LLMMap(child, prompt, self._entity_name, llm, **kwargs) + llm_map = LLMMap(child, prompt, self._entity_name, llm, postprocess_fn=postprocess, **kwargs) return llm_map @context_params(OperationTypes.INFORMATION_EXTRACTOR) From befc3d0b974538512680c577b13db12fecf79e56 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 28 Jan 2025 09:43:48 -0800 Subject: [PATCH 27/46] get all extract_entity unit tests passing Signed-off-by: Henry Lindeman --- .../unit/transforms/test_extract_entity.py | 11 +-- .../sycamore/transforms/extract_entity.py | 78 +++++++++++++------ 2 files changed, 59 insertions(+), 30 deletions(-) diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py index ece682bf3..38b13d78e 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py @@ -24,12 +24,13 @@ def __init__(self): super().__init__(model_name="mock_model") def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: - if asdict(prompt) == {"messages": [{"role": "user", "content": "s3://path"}]} and llm_kwargs == {}: - return "alt_title" - if asdict(prompt) == {"prompt": "s3://path"} and llm_kwargs == {}: - return "alt_title" + if len(prompt.messages) == 1: + usermessage = prompt.messages[0].content + else: + usermessage = prompt.messages[1].content - usermessage = prompt.messages[1].content + if "s3://path" in usermessage: + return "alt_title" if "Jack Black" in usermessage: return "Jack Black" diff --git a/lib/sycamore/sycamore/transforms/extract_entity.py b/lib/sycamore/sycamore/transforms/extract_entity.py index 7ffb62589..4376910cc 100644 --- a/lib/sycamore/sycamore/transforms/extract_entity.py +++ b/lib/sycamore/sycamore/transforms/extract_entity.py @@ -9,7 +9,13 @@ EntityExtractorZeroShotGuidancePrompt, EntityExtractorFewShotGuidancePrompt, ) -from sycamore.llms.prompts.prompts import ElementListIterPrompt, ElementListPrompt +from sycamore.llms.prompts.prompts import ( + ElementListIterPrompt, + ElementListPrompt, + RenderedMessage, + SycamorePrompt, + RenderedPrompt, +) from sycamore.plan_nodes import Node 
from sycamore.transforms.base_llm import LLMMap from sycamore.transforms.map import Map @@ -28,6 +34,19 @@ def element_list_formatter(elements: list[Element], field: str = "text_represent return query +class FieldToValuePrompt(SycamorePrompt): + def __init__(self, messages: list[RenderedMessage], field: str): + self.messages = messages + self.field = field + + def render_document(self, doc: Document) -> RenderedPrompt: + value = doc.field_to_value(self.field) + rendered = [] + for m in self.messages: + rendered.append(RenderedMessage(role=m.role, content=m.content.format(value=value))) + return RenderedPrompt(messages=rendered) + + class EntityExtractor(ABC): def __init__(self, entity_name: str): self._entity_name = entity_name @@ -122,11 +141,19 @@ def as_llm_map( else: prompt = EntityExtractorZeroShotGuidancePrompt - def postprocess(d: Document, i: int) -> Document: - return d - if self._tokenizer is not None: - prompt.render_document = partial(ElementListIterPrompt.render_document, prompt) # type: ignore + + def postprocess(d: Document, i: int) -> Document: + last_club = set() + source_key = f"{self._entity_name}_source_element_index" + for e in d.elements: + if e.properties[source_key] != last_club: + i -= 1 + last_club = e.properties[source_key] + if i == -1: + break + d.properties[source_key] = last_club + return d def elt_list_ctor(elts: list[Element]) -> str: combined_text = "" @@ -140,11 +167,6 @@ def elt_list_ctor(elts: list[Element]) -> str: combined_text += f"Text: {element.field_to_value(self._field)}\n" return combined_text - if self._prompt_formatter is not element_list_formatter: - prompt = prompt.set(element_list_constructor=self._prompt_formatter) - else: - prompt = prompt.set(element_list_constructor=elt_list_ctor) - def eb(elts: list[Element]) -> list[list[Element]]: curr_tks = 0 curr_batch = [] @@ -167,21 +189,28 @@ def eb(elts: list[Element]) -> list[list[Element]]: batches.append(curr_batch) return batches + prompt.render_document = partial(ElementListIterPrompt.render_document, prompt) # type: ignore + if self._prompt_formatter is not element_list_formatter: + prompt = prompt.set(element_list_constructor=self._prompt_formatter) + else: + prompt = prompt.set(element_list_constructor=elt_list_ctor) setattr(prompt, "element_batcher", eb) prompt.is_done = lambda s: s != "None" - print(prompt.__dict__) + prompt = prompt.set(entity=self._entity_name) + return LLMMap(child, prompt, self._entity_name, llm, postprocess_fn=postprocess, **kwargs) - def postprocess(d: Document, i: int) -> Document: - last_club = set() - source_key = f"{self._entity_name}_source_element_index" - for e in d.elements: - if e.properties[source_key] != last_club: - i -= 1 - last_club = e.properties[source_key] - if i == -1: - break - d.properties[source_key] = last_club - return d + elif not self._use_elements: + if self._prompt is None: + raise ValueError("prompt must be specified if use_elements is False") + if isinstance(self._prompt, str): + prompt = FieldToValuePrompt( + messages=[RenderedMessage(role="user", content=self._prompt + "{value}")], field=self._field + ) + elif isinstance(self._prompt, list): + ms = [RenderedMessage(role=m["role"], content=m["content"]) for m in self._prompt] + ms.append(RenderedMessage(role="user", content="{value}")) + prompt = FieldToValuePrompt(messages=ms, field=self._field) + return LLMMap(child, prompt, self._entity_name, llm, **kwargs) def elt_sorter(elts: list[Element]) -> list[Element]: sorter_inner = make_element_sorter_fn(self._field, 
self._similarity_query, self._similarity_scorer) @@ -190,11 +219,10 @@ def elt_sorter(elts: list[Element]) -> list[Element]: return dummy_doc.elements prompt = prompt.set(element_select=lambda e: elt_sorter(e)[: self._num_of_elements]) - if self._tokenizer is None: - prompt = prompt.set(element_list_constructor=lambda e: self._prompt_formatter(e, self._field)) + prompt = prompt.set(element_list_constructor=lambda e: self._prompt_formatter(e, self._field)) prompt = prompt.set(entity=self._entity_name) - llm_map = LLMMap(child, prompt, self._entity_name, llm, postprocess_fn=postprocess, **kwargs) + llm_map = LLMMap(child, prompt, self._entity_name, llm, **kwargs) return llm_map @context_params(OperationTypes.INFORMATION_EXTRACTOR) From 8bf42d55765a60af5369019d9404942cb25140fc Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 28 Jan 2025 10:41:33 -0800 Subject: [PATCH 28/46] fix llm_map_elements to deal with postprocess index Signed-off-by: Henry Lindeman --- .../unit/transforms/test_extract_entity.py | 17 ----------------- lib/sycamore/sycamore/transforms/base_llm.py | 2 +- 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py index 38b13d78e..35e3e4880 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py @@ -111,15 +111,6 @@ def test_extract_entity_document_field_messages(self, mocker): llm_map = extractor.as_llm_map(None) out_docs = llm_map.run([self.doc]) assert out_docs[0].properties.get("title") == "alt_title" - # extract_entity = ExtractEntity( - # None, - # entity_extractor=OpenAIEntityExtractor( - # "title", llm=llm, use_elements=False, prompt=[], field="properties.path" - # ), - # ) - # out_doc = extract_entity.run(self.doc) - - # assert out_doc.properties.get("title") == "alt_title" def test_extract_entity_document_field_string(self, mocker): llm = MockLLM() @@ -127,14 +118,6 @@ def test_extract_entity_document_field_string(self, mocker): llm_map = extractor.as_llm_map(None) out_docs = llm_map.run([self.doc]) assert out_docs[0].properties.get("title") == "alt_title" - # extract_entity = ExtractEntity( - # None, - # entity_extractor=OpenAIEntityExtractor( - # "title", llm=llm, use_elements=False, prompt="", field="properties.path" - # ), - # ) - # out_doc = extract_entity.run(self.doc) - # assert out_doc.properties.get("title") == "alt_title" def test_extract_entity_with_similarity_sorting(self, mocker): doc_list = [ diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py index 069340bca..ae270d816 100644 --- a/lib/sycamore/sycamore/transforms/base_llm.py +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -152,7 +152,7 @@ def llm_map_elements(self, documents: list[Document]) -> list[Document]: results = _infer_prompts( _as_sequences([p for _, p in rendered]), self._llm, self._llm_mode, self._prompt.is_done ) - for r, (e, _) in zip(results, rendered): + for (r, i), (e, _) in zip(results, rendered): e.properties[self._output_field] = r return documents From d7ff1ebf46e22da762c94054c99855da71be9c35 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 28 Jan 2025 11:00:54 -0800 Subject: [PATCH 29/46] add postprocess_fn unit tests for llm_map Signed-off-by: Henry Lindeman --- .../tests/unit/transforms/test_base_llm.py | 48 ++++++++++++++++++- lib/sycamore/sycamore/transforms/base_llm.py 
| 22 +++++++-- 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py b/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py index 3ba4efdcc..c2d4838d1 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py @@ -54,6 +54,23 @@ def test_happy_path(self): assert outdocs[1].text_representation == "booga" assert outdocs[1].properties["out"] == "booga" + def test_postprocess(self): + prompt = FakeDocPrompt() + llm = FakeLLM() + doc1 = Document({"text_representation": "ooga"}) + doc2 = Document({"text_representation": "booga"}) + count = 0 + + def ppfn(d: Document, i: int) -> Document: + nonlocal count + count += 1 + return d + + map = LLMMap(None, prompt, "out", llm, postprocess_fn=ppfn) + _ = map.llm_map([doc1, doc2]) + + assert count == 2 + class TestLLMMapElements: def test_wrong_prompt_fails_fast(self): @@ -67,9 +84,13 @@ def test_happy_path(self): prompt = FakeEltPrompt() llm = FakeLLM() doc1 = Document( - {"text_representation": "ooga", "elements": [{"text_representation": "yo"}, {"text_representation": "ho"}]} + { + "doc_id": "1", + "text_representation": "ooga", + "elements": [{"text_representation": "yo"}, {"text_representation": "ho"}], + } ) - doc2 = Document({"elements": [{"text_representation": "booga"}, {}]}) + doc2 = Document({"doc_id": "2", "elements": [{"text_representation": "booga"}, {}]}) map = LLMMapElements(None, prompt, "out", llm) outdocs = map.llm_map_elements([doc1, doc2]) @@ -77,3 +98,26 @@ def test_happy_path(self): assert outdocs[0].elements[1].properties["out"] == "oogaho" assert outdocs[1].elements[0].properties["out"] == "Nonebooga" assert outdocs[1].elements[1].properties["out"] == "NoneNone" + + def test_postprocess(self): + prompt = FakeEltPrompt() + llm = FakeLLM() + doc1 = Document( + { + "doc_id": "1", + "text_representation": "ooga", + "elements": [{"text_representation": "yo"}, {"text_representation": "ho"}], + } + ) + doc2 = Document({"doc_id": "2", "elements": [{"text_representation": "booga"}, {}]}) + count = 0 + + def ppfn(e: Element, i: int) -> Element: + nonlocal count + count += 1 + return e + + map = LLMMapElements(None, prompt, "out", llm, postprocess_fn=ppfn) + _ = map.llm_map_elements([doc1, doc2]) + + assert count == 4 diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py index ae270d816..e8414c306 100644 --- a/lib/sycamore/sycamore/transforms/base_llm.py +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -119,6 +119,11 @@ class LLMMapElements(MapBatch): llm: The llm to use for inference. llm_mode: How to call the llm - sync/async/batch. All LLMs do not necessarily implement all options. + postprocess_fn: function to call on documents after performing the + llm inference. If the prompt rendered into multiple RenderedPrompts, + ``i`` is the index of the RenderedPrompt that succeeded; if the + prompt rendered into an empty list, ``i`` is -1; and otherwise + ``i`` is 0 Example: .. 
code-block:: python @@ -138,6 +143,7 @@ def __init__( output_field: str, llm: LLM, llm_mode: LLMMode = LLMMode.SYNC, + postprocess_fn: Callable[[Element, int], Element] = lambda e, i: e, **kwargs, ): self._prompt = prompt @@ -145,15 +151,25 @@ def __init__( self._output_field = output_field self._llm = llm self._llm_mode = llm_mode + self._postprocess_fn = postprocess_fn super().__init__(child, f=self.llm_map_elements, **kwargs) def llm_map_elements(self, documents: list[Document]) -> list[Document]: - rendered = [(e, self._prompt.render_element(e, d)) for d in documents for e in d.elements] + rendered = [(d, e, self._prompt.render_element(e, d)) for d in documents for e in d.elements] results = _infer_prompts( - _as_sequences([p for _, p in rendered]), self._llm, self._llm_mode, self._prompt.is_done + _as_sequences([p for _, _, p in rendered]), self._llm, self._llm_mode, self._prompt.is_done ) - for (r, i), (e, _) in zip(results, rendered): + new_elts = [] + last_doc = None + for (r, i), (d, e, _) in zip(results, rendered): + if last_doc is not None and last_doc.doc_id != d.doc_id: + last_doc.elements = new_elts + new_elts = [] e.properties[self._output_field] = r + new_elts.append(self._postprocess_fn(e, i)) + last_doc = d + if last_doc is not None: + last_doc.elements = new_elts return documents def _validate_prompt(self): From a7a2cc00a66bde784fcc88536b2b78bf1464c0ad Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 28 Jan 2025 11:03:25 -0800 Subject: [PATCH 30/46] ruff complaint Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/transforms/base_llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py index e8414c306..46a8de6dd 100644 --- a/lib/sycamore/sycamore/transforms/base_llm.py +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -183,5 +183,5 @@ def _validate_prompt(self): pass -def _as_sequences(l: list[Union[RenderedPrompt, Sequence[RenderedPrompt]]]) -> list[Sequence[RenderedPrompt]]: - return [[p] if isinstance(p, RenderedPrompt) else p for p in l] +def _as_sequences(ls: list[Union[RenderedPrompt, Sequence[RenderedPrompt]]]) -> list[Sequence[RenderedPrompt]]: + return [[p] if isinstance(p, RenderedPrompt) else p for p in ls] From ebf721e189ceff815e1a5992126fc732ce451b33 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 28 Jan 2025 11:44:47 -0800 Subject: [PATCH 31/46] fix docset unittests Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/docset.py | 2 +- lib/sycamore/sycamore/llms/llms.py | 17 ++++++++++++++++- lib/sycamore/sycamore/tests/unit/test_docset.py | 15 +++++++++++---- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/lib/sycamore/sycamore/docset.py b/lib/sycamore/sycamore/docset.py index d41139280..c63b3a297 100644 --- a/lib/sycamore/sycamore/docset.py +++ b/lib/sycamore/sycamore/docset.py @@ -1393,7 +1393,7 @@ def llm_cluster_entity(self, llm: LLM, instruction: str, field: str, **kwargs) - prompt_kwargs = {"messages": messages} # call to LLM - completion = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0}) + completion = llm.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0}) groups = extract_json(completion) diff --git a/lib/sycamore/sycamore/llms/llms.py b/lib/sycamore/sycamore/llms/llms.py index 7c1b39e5a..f8903f335 100644 --- a/lib/sycamore/sycamore/llms/llms.py +++ b/lib/sycamore/sycamore/llms/llms.py @@ -7,7 +7,8 @@ from typing import Any, Optional import pydantic 
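# Illustrative sketch (not part of the patch series): the postprocess_fn
# contract described in the LLMMapElements docstring above. The callable gets
# the element (or the document, for LLMMap) plus the index of the
# RenderedPrompt that succeeded (-1 when the prompt rendered to an empty
# list). `record_winning_prompt` is a hypothetical example, not sycamore API.
from sycamore.data import Element


def record_winning_prompt(elt: Element, prompt_index: int) -> Element:
    # Record which rendered prompt produced the accepted completion so later
    # transforms can tell how much of the element list was actually used.
    elt.properties["winning_prompt_index"] = prompt_index
    return elt


# Usage, assuming an existing prompt and llm:
# LLMMapElements(child, prompt, "out", llm, postprocess_fn=record_winning_prompt)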
from sycamore.utils.cache import Cache -from sycamore.llms.prompts import RenderedPrompt +from sycamore.llms.prompts import RenderedPrompt, RenderedMessage +from sycamore.utils.deprecate import deprecated class LLMMode(Enum): @@ -29,6 +30,20 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) """Generates a response from the LLM for the given prompt and LLM parameters.""" pass + @deprecated(version="0.1.31", reason="Use generate, with a RenderedPrompt, instead") + def generate_old(self, *, prompt_kwargs: dict[str, Any], llm_kwargs: Optional[dict] = None) -> str: + """Generates a response from the LLM""" + if "prompt" in prompt_kwargs: + prompt = prompt_kwargs.get("prompt") + rendered = RenderedPrompt(messages=[RenderedMessage(role="user", content=f"{prompt}")]) + elif "messages" in prompt_kwargs: + ms = prompt_kwargs.get("messages", []) + messages = [RenderedMessage(role=m["role"], content=m["content"]) for m in ms] + rendered = RenderedPrompt(messages=messages) + else: + raise ValueError("Either 'prompt' or 'messages' must be specified in prompt_kwargs") + return self.generate(prompt=rendered, llm_kwargs=llm_kwargs) + @abstractmethod def is_chat_mode(self) -> bool: """Returns True if the LLM is in chat mode, False otherwise.""" diff --git a/lib/sycamore/sycamore/tests/unit/test_docset.py b/lib/sycamore/sycamore/tests/unit/test_docset.py index 55cfd29dc..9f166ab71 100644 --- a/lib/sycamore/sycamore/tests/unit/test_docset.py +++ b/lib/sycamore/sycamore/tests/unit/test_docset.py @@ -31,6 +31,7 @@ ) from sycamore.transforms import Filter from sycamore.transforms.base import get_name_from_callable +from sycamore.transforms.base_llm import LLMMap from sycamore.transforms.extract_entity import OpenAIEntityExtractor from sycamore.transforms.extract_schema import SchemaExtractor from sycamore.transforms.query import QueryExecutor @@ -44,6 +45,7 @@ def __init__(self): super().__init__(model_name="mock_model") def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + print(prompt) if llm_kwargs is None: llm_kwargs = {} if len(prompt.messages) > 1 and prompt.messages[1].content.endswith("Element_index: 1\nText: third element\n"): @@ -54,14 +56,14 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) ): return "None" elif ( - len(prompt.messages) > 0 + len(prompt.messages) > 1 and "first short element" in prompt.messages[1].content and "second longer element with more words" in prompt.messages[1].content and llm_kwargs == {} ): return "4" elif ( - len(prompt.messages) > 0 + len(prompt.messages) > 1 and "very long element with many words that might exceed token limit" in prompt.messages[1].content and llm_kwargs == {} ): @@ -71,6 +73,9 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) elif asdict(prompt) == {"messages": [{"role": "user", "content": "test2"}]} and llm_kwargs == {}: return "2" + elif prompt.messages[-1].content.endswith('"1, 2, one, two, 1, 3".'): + return '{"groups": ["group1", "group2", "group3"]}' + elif ( prompt.messages == LlmClusterEntityFormGroupsMessagesPrompt( @@ -79,7 +84,8 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) ): return '{"groups": ["group1", "group2", "group3"]}' elif ( - prompt.messages[0] + "['group1', 'group2', 'group3']" in prompt.messages[0].content + or prompt.messages[0] == LlmClusterEntityAssignGroupsMessagesPrompt( field="text_representation", groups=["group1", "group2", "group3"] 
).as_messages()[0] @@ -171,10 +177,11 @@ def test_embedding(self, mocker): def test_llm_extract_entity(self, mocker): context = mocker.Mock(spec=Context) + context.params = {} llm = mocker.Mock(spec=LLM) docset = DocSet(context, None) docset = docset.extract_entity(entity_extractor=OpenAIEntityExtractor("title", llm=llm, prompt_template="")) - assert isinstance(docset.lineage(), ExtractEntity) + assert isinstance(docset.lineage(), LLMMap) def test_query(self, mocker): context = mocker.Mock(spec=Context) From 0bd2a459163323d6ca072e72cb5003c8a3b3f869 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 28 Jan 2025 13:24:31 -0800 Subject: [PATCH 32/46] move a bunch of stuff back to llm.generate_old. This includes the active implementation of extract entity bc I don't want to deal with llm_filter just yet Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/llms.py | 13 +++++++++++++ .../sycamore/query/execution/operations.py | 2 +- lib/sycamore/sycamore/query/strategy.py | 6 +++--- .../sycamore/tests/unit/query/test_operations.py | 9 ++++++--- lib/sycamore/sycamore/tests/unit/test_docset.py | 14 ++++++++------ .../tests/unit/transforms/test_llm_filter.py | 9 ++++++++- lib/sycamore/sycamore/transforms/extract_entity.py | 14 ++++++++------ .../sycamore/transforms/extract_graph_entities.py | 2 +- lib/sycamore/sycamore/transforms/llm_filter.py | 8 +++++++- 9 files changed, 55 insertions(+), 22 deletions(-) diff --git a/lib/sycamore/sycamore/llms/llms.py b/lib/sycamore/sycamore/llms/llms.py index f8903f335..07515874d 100644 --- a/lib/sycamore/sycamore/llms/llms.py +++ b/lib/sycamore/sycamore/llms/llms.py @@ -57,6 +57,19 @@ async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[d """Generates a response from the LLM for the given prompt and LLM parameters asynchronously.""" raise NotImplementedError("This LLM does not support asynchronous generation.") + @deprecated(version="0.1.31", reason="Use generate_async, with a RenderedPrompt, instead") + async def generate_async_old(self, *, prompt_kwargs: dict[str, Any], llm_kwargs: Optional[dict] = None) -> str: + if "prompt" in prompt_kwargs: + prompt = prompt_kwargs.get("prompt") + rendered = RenderedPrompt(messages=[RenderedMessage(role="user", content=f"{prompt}")]) + elif "messages" in prompt_kwargs: + ms = prompt_kwargs.get("messages", []) + messages = [RenderedMessage(role=m["role"], content=m["content"]) for m in ms] + rendered = RenderedPrompt(messages=messages) + else: + raise ValueError("Either 'prompt' or 'messages' must be specified in prompt_kwargs") + return await self.generate_async(prompt=rendered, llm_kwargs=llm_kwargs) + def __str__(self): return f"{self.__class__.__name__}({self._model_name})" diff --git a/lib/sycamore/sycamore/query/execution/operations.py b/lib/sycamore/sycamore/query/execution/operations.py index 200cf832d..033d0b7fb 100644 --- a/lib/sycamore/sycamore/query/execution/operations.py +++ b/lib/sycamore/sycamore/query/execution/operations.py @@ -94,7 +94,7 @@ def summarize_data( prompt_kwargs = {"messages": messages} # call to LLM - completion = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0}) + completion = llm.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0}) # LLM response return completion diff --git a/lib/sycamore/sycamore/query/strategy.py b/lib/sycamore/sycamore/query/strategy.py index edf03c727..f38fa9b31 100644 --- a/lib/sycamore/sycamore/query/strategy.py +++ b/lib/sycamore/sycamore/query/strategy.py @@ -86,7 +86,7 @@ def 
__call__(self, plan: LogicalPlan) -> LogicalPlan: modified_description = self.postprocess_llm_helper( f""" The following is the description of a Python function. I am modifying the function code - to remove any functionality that specifically has to do with "{op.query_phrase}", thereby + to remove any functionality that specifically has to do with "{op.query_phrase}", thereby generalizing the description to be more flexible. Return only the modified description. Do not make assumptions about the intent of the question that are not explicitly specified. @@ -115,7 +115,7 @@ def __call__(self, plan: LogicalPlan) -> LogicalPlan: f""" Generate a one-line description for a python function whose goal is to filter the input records based on whether they contain {op.query_phrase}. - Here are two example outputs: + Here are two example outputs: (1) Filter to records involving wildfires. (2) Filter to records that occurred in Northwest USA. """, @@ -158,7 +158,7 @@ def postprocess_llm_helper(self, user_message: str) -> str: ] prompt_kwargs = {"messages": messages} - chat_completion = self.llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0}) + chat_completion = self.llm.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0}) return chat_completion diff --git a/lib/sycamore/sycamore/tests/unit/query/test_operations.py b/lib/sycamore/sycamore/tests/unit/query/test_operations.py index 59ffb7b38..1a847d8fe 100644 --- a/lib/sycamore/sycamore/tests/unit/query/test_operations.py +++ b/lib/sycamore/sycamore/tests/unit/query/test_operations.py @@ -26,20 +26,23 @@ def __init__(self): super().__init__(model_name="mock_model") def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + if prompt.messages[0].content.endswith('"1, 2, one, two, 1, 3".'): + return '{"groups": ["group1", "group2", "group3"]}' if ( - prompt_kwargs["messages"] + prompt.messages == LlmClusterEntityFormGroupsMessagesPrompt( field="text_representation", instruction="", text="1, 2, one, two, 1, 3" ).as_messages() ): return '{"groups": ["group1", "group2", "group3"]}' elif ( - prompt_kwargs["messages"][0] + "['group1', 'group2', 'group3']" in prompt.messages[0].content + or prompt.messages[0] == LlmClusterEntityAssignGroupsMessagesPrompt( field="text_representation", groups=["group1", "group2", "group3"] ).as_messages()[0] ): - value = prompt_kwargs["messages"][1]["content"] + value = prompt.messages[1].content if value == "1" or value == "one": return "group1" elif value == "2" or value == "two": diff --git a/lib/sycamore/sycamore/tests/unit/test_docset.py b/lib/sycamore/sycamore/tests/unit/test_docset.py index 9f166ab71..78301ceb0 100644 --- a/lib/sycamore/sycamore/tests/unit/test_docset.py +++ b/lib/sycamore/sycamore/tests/unit/test_docset.py @@ -48,7 +48,7 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) print(prompt) if llm_kwargs is None: llm_kwargs = {} - if len(prompt.messages) > 1 and prompt.messages[1].content.endswith("Element_index: 1\nText: third element\n"): + if prompt.messages[-1].content.endswith("Element_index: 1\nText: third element\n"): return "None" if ( asdict(prompt) == {"messages": [{"role": "user", "content": "Element_index: 1\nText: third element\n"}]} @@ -56,22 +56,24 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) ): return "None" elif ( - len(prompt.messages) > 1 - and "first short element" in prompt.messages[1].content - and "second longer element with more words" in 
prompt.messages[1].content + "first short element" in prompt.messages[-1].content + and "second longer element with more words" in prompt.messages[-1].content and llm_kwargs == {} ): return "4" elif ( - len(prompt.messages) > 1 - and "very long element with many words that might exceed token limit" in prompt.messages[1].content + "very long element with many words that might exceed token limit" in prompt.messages[-1].content and llm_kwargs == {} ): return "5" elif asdict(prompt) == {"messages": [{"role": "user", "content": "test1"}]} and llm_kwargs == {}: return "4" + elif prompt.messages[0].content == "test1": + return "4" elif asdict(prompt) == {"messages": [{"role": "user", "content": "test2"}]} and llm_kwargs == {}: return "2" + elif prompt.messages[0].content == "test2": + return "2" elif prompt.messages[-1].content.endswith('"1, 2, one, two, 1, 3".'): return '{"groups": ["group1", "group2", "group3"]}' diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py b/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py index 26a6ba3d6..08a4e484d 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py @@ -7,8 +7,10 @@ from sycamore.data import Document, Element from sycamore.functions import Tokenizer from sycamore.llms import LLM +from sycamore.plan_nodes import Node from sycamore.tests.unit.test_docset import MockLLM, TestSimilarityScorer, MockTokenizer from sycamore.transforms.extract_entity import EntityExtractor +from sycamore.transforms.base_llm import LLMMap tokenizer_doc = [ Document( @@ -225,7 +227,7 @@ def test_llm_filter_with_tokenizer_and_max_tokens(self): threshold=3, use_elements=True, tokenizer=mock_tokenizer, - max_tokens=10, # Low token limit to test windowing + max_tokens=20, # Low token limit to test windowing ) taken = filtered_docset.take() @@ -254,6 +256,11 @@ def __init__(self, entity_name, bad_val): super().__init__(entity_name) self.bad_val = bad_val + def as_llm_map( + self, child: Optional[Node], context: Optional[Context] = None, llm: Optional[LLM] = None, **kwargs + ) -> LLMMap: + pass + def extract_entity( self, document: Document, context: Optional[Context] = None, llm: Optional[LLM] = None ) -> Document: diff --git a/lib/sycamore/sycamore/transforms/extract_entity.py b/lib/sycamore/sycamore/transforms/extract_entity.py index 4376910cc..a0ec652e6 100644 --- a/lib/sycamore/sycamore/transforms/extract_entity.py +++ b/lib/sycamore/sycamore/transforms/extract_entity.py @@ -5,9 +5,11 @@ from sycamore.context import Context, context_params, OperationTypes from sycamore.data import Element, Document from sycamore.llms import LLM -from sycamore.llms.prompts import ( +from sycamore.llms.prompts.default_prompts import ( EntityExtractorZeroShotGuidancePrompt, EntityExtractorFewShotGuidancePrompt, + _EntityExtractorZeroShotGuidancePrompt, + _EntityExtractorFewShotGuidancePrompt, ) from sycamore.llms.prompts.prompts import ( ElementListIterPrompt, @@ -255,10 +257,10 @@ def _handle_element_prompting(self, document: Document) -> Any: if self._prompt is None: prompt: Any = None if self._prompt_template: - prompt = EntityExtractorFewShotGuidancePrompt() + prompt = _EntityExtractorFewShotGuidancePrompt() else: - prompt = EntityExtractorZeroShotGuidancePrompt() - entities = self._llm.generate( + prompt = _EntityExtractorZeroShotGuidancePrompt() + entities = self._llm.generate_old( prompt_kwargs={ "prompt": prompt, "entity": self._entity_name, @@ -302,10 +304,10 @@ 
def _get_entities(self, content: str, prompt: Optional[Union[list[dict], str]] = assert prompt is not None, "No prompt found for entity extraction" if isinstance(self._prompt, str): prompt = self._prompt + content - response = self._llm.generate(prompt_kwargs={"prompt": prompt}, llm_kwargs={}) + response = self._llm.generate_old(prompt_kwargs={"prompt": prompt}, llm_kwargs={}) else: messages = (self._prompt or []) + [{"role": "user", "content": content}] - response = self._llm.generate(prompt_kwargs={"messages": messages}, llm_kwargs={}) + response = self._llm.generate_old(prompt_kwargs={"messages": messages}, llm_kwargs={}) return response diff --git a/lib/sycamore/sycamore/transforms/extract_graph_entities.py b/lib/sycamore/sycamore/transforms/extract_graph_entities.py index 654880e7f..21e2122fd 100644 --- a/lib/sycamore/sycamore/transforms/extract_graph_entities.py +++ b/lib/sycamore/sycamore/transforms/extract_graph_entities.py @@ -140,7 +140,7 @@ async def _extract_from_section(self, summary: str) -> list[str]: llm_kwargs = {"response_format": model} messages = generate_prompt_messages(self.prompt, summary) outputs.append( - await self.llm.generate_async(prompt_kwargs={"messages": messages}, llm_kwargs=llm_kwargs) + await self.llm.generate_async_old(prompt_kwargs={"messages": messages}, llm_kwargs=llm_kwargs) ) except Exception as e: logger.warn(f"OPENAI CALL FAILED: {e}") diff --git a/lib/sycamore/sycamore/transforms/llm_filter.py b/lib/sycamore/sycamore/transforms/llm_filter.py index 611c4f831..b350da0b4 100644 --- a/lib/sycamore/sycamore/transforms/llm_filter.py +++ b/lib/sycamore/sycamore/transforms/llm_filter.py @@ -35,6 +35,8 @@ def tokenized_threshold_llm_filter( max_tokens: int, tokenizer: Tokenizer, ) -> bool: + print("=" * 80) + print(doc) element_sorter(doc) evaluated_elements = 0 @@ -57,10 +59,14 @@ def tokenized_threshold_llm_filter( if score >= doc.get(doc_source_field_name, 0): doc.properties[f"{new_field}"] = score doc.properties[f"{new_field}_source_element_index"] = window_indices + print(score, combined_text) if score >= threshold: + print("-" * 80) + print(doc) return True evaluated_elements += 1 - + print("-" * 80) + print(doc) if evaluated_elements == 0: # no elements found for property return keep_none return False From 95cbaafda298e417eabaf69d8d34b3aa862c9bea Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 28 Jan 2025 13:39:54 -0800 Subject: [PATCH 33/46] move more stuff back to llm.generate_old Signed-off-by: Henry Lindeman --- .../sycamore/tests/unit/transforms/test_schema.py | 8 ++++---- .../sycamore/tests/unit/transforms/test_summarize.py | 2 +- .../transforms/extract_graph_relationships.py | 2 +- lib/sycamore/sycamore/transforms/extract_schema.py | 12 ++++++------ lib/sycamore/sycamore/transforms/summarize.py | 8 ++++---- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_schema.py b/lib/sycamore/sycamore/tests/unit/transforms/test_schema.py index 56c4c6943..e3fe3ae75 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_schema.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_schema.py @@ -4,7 +4,7 @@ from ray.util import inspect_serializability -from sycamore.llms.prompts import SchemaZeroShotGuidancePrompt +from sycamore.llms.prompts.default_prompts import _SchemaZeroShotGuidancePrompt from sycamore.data import Document, Element from sycamore.llms.llms import LLM, FakeLLM from sycamore.schema import Schema, SchemaField @@ -37,7 +37,7 @@ def test_serializable(self, 
mocker): def test_extract_schema(self, mocker): llm = mocker.Mock(spec=LLM) - generate = mocker.patch.object(llm, "generate") + generate = mocker.patch.object(llm, "generate_old") generate.return_value = '```json {"accidentNumber": "string"}```' num_of_elements = 10 @@ -66,7 +66,7 @@ def test_extract_schema(self, mocker): assert doc.properties == ground_truth generate.assert_called_once_with( prompt_kwargs={ - "prompt": SchemaZeroShotGuidancePrompt(), + "prompt": _SchemaZeroShotGuidancePrompt(), "entity": class_name, "max_num_properties": max_num_properties, "query": schema_extractor._prompt_formatter(doc.elements), @@ -75,7 +75,7 @@ def test_extract_schema(self, mocker): def test_extract_batch_schema(self, mocker): llm = mocker.Mock(spec=LLM) - generate = mocker.patch.object(llm, "generate") + generate = mocker.patch.object(llm, "generate_old") generate.return_value = '```json {"accidentNumber": "string"}```' schema_extractor = LLMSchemaExtractor("AircraftIncident", llm) diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_summarize.py b/lib/sycamore/sycamore/tests/unit/transforms/test_summarize.py index f4a934c46..f749280eb 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_summarize.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_summarize.py @@ -21,7 +21,7 @@ def test_summarize_text_does_not_call_llm(self, mocker): def test_summarize_text_calls_llm(self, mocker): llm = mocker.Mock(spec=LLM) - generate = mocker.patch.object(llm, "generate") + generate = mocker.patch.object(llm, "generate_old") generate.return_value = "this is the summary" doc = Document() element1 = Element() diff --git a/lib/sycamore/sycamore/transforms/extract_graph_relationships.py b/lib/sycamore/sycamore/transforms/extract_graph_relationships.py index 6d1f5e1fc..c37806568 100644 --- a/lib/sycamore/sycamore/transforms/extract_graph_relationships.py +++ b/lib/sycamore/sycamore/transforms/extract_graph_relationships.py @@ -162,7 +162,7 @@ async def _generate_relationships(self, section: HierarchicalDocument) -> dict: messages = generate_prompt_messages(self.prompt, entities_list[i], section.data["summary"]) llm_kwargs = {"response_format": models[i]} prompt_kwargs = {"messages": messages} - outputs.append(await self.llm.generate_async(prompt_kwargs=prompt_kwargs, llm_kwargs=llm_kwargs)) + outputs.append(await self.llm.generate_async_old(prompt_kwargs=prompt_kwargs, llm_kwargs=llm_kwargs)) async def _process_llm_output(outputs: list[str], parsed_metadata: dict, summary: str): parsed_res: dict[str, Any] = {} diff --git a/lib/sycamore/sycamore/transforms/extract_schema.py b/lib/sycamore/sycamore/transforms/extract_schema.py index c4e0bdd41..d8f066cde 100644 --- a/lib/sycamore/sycamore/transforms/extract_schema.py +++ b/lib/sycamore/sycamore/transforms/extract_schema.py @@ -5,8 +5,8 @@ from sycamore.data import Element, Document from sycamore.schema import Schema from sycamore.llms import LLM -from sycamore.llms.prompts import ( - SchemaZeroShotGuidancePrompt, +from sycamore.llms.prompts.default_prompts import ( + _SchemaZeroShotGuidancePrompt, PropertiesZeroShotGuidancePrompt, ) from sycamore.llms.prompts.default_prompts import ExtractPropertiesFromSchemaPrompt @@ -99,9 +99,9 @@ def extract_schema(self, document: Document) -> Document: def _handle_zero_shot_prompting(self, document: Document) -> Any: sub_elements = [document.elements[i] for i in range((min(self._num_of_elements, len(document.elements))))] - prompt = SchemaZeroShotGuidancePrompt() + prompt = _SchemaZeroShotGuidancePrompt() - 
entities = self._llm.generate( + entities = self._llm.generate_old( prompt_kwargs={ "prompt": prompt, "entity": self._entity_name, @@ -228,7 +228,7 @@ def _handle_zero_shot_prompting(self, document: Document) -> Any: ) if isinstance(self._schema, Schema): prompt = ExtractPropertiesFromSchemaPrompt(schema=self._schema, text=text) - entities = self._llm.generate(prompt_kwargs={"prompt": prompt}) + entities = self._llm.generate_old(prompt_kwargs={"prompt": prompt}) else: schema = self._schema or document.properties.get("_schema") assert schema is not None, "Schema must be provided or detected before extracting properties." @@ -236,7 +236,7 @@ def _handle_zero_shot_prompting(self, document: Document) -> Any: schema_name = self._schema_name or document.properties.get("_schema_class") assert schema_name is not None, "Schema name must be provided or detected before extracting properties." - entities = self._llm.generate( + entities = self._llm.generate_old( prompt_kwargs={ "prompt": PropertiesZeroShotGuidancePrompt(), "entity": schema_name, diff --git a/lib/sycamore/sycamore/transforms/summarize.py b/lib/sycamore/sycamore/transforms/summarize.py index 9230f44da..bd54452a5 100644 --- a/lib/sycamore/sycamore/transforms/summarize.py +++ b/lib/sycamore/sycamore/transforms/summarize.py @@ -9,7 +9,7 @@ from sycamore.llms.prompts.default_prompts import SummarizeDataMessagesPrompt from sycamore.plan_nodes import NonCPUUser, NonGPUUser, Node from sycamore.llms import LLM -from sycamore.llms.prompts import TextSummarizerGuidancePrompt +from sycamore.llms.prompts.default_prompts import _TextSummarizerGuidancePrompt from sycamore.transforms.map import Map from sycamore.utils.time_trace import timetrace @@ -77,10 +77,10 @@ def summarize(self, document: Document) -> Document: @timetrace("SummText") def _summarize_text_element(self, element: Element) -> Element: - prompt = TextSummarizerGuidancePrompt() + prompt = _TextSummarizerGuidancePrompt() if element.text_representation: - response = self._llm.generate(prompt_kwargs={"prompt": prompt, "query": element.text_representation}) + response = self._llm.generate_old(prompt_kwargs={"prompt": prompt, "query": element.text_representation}) element.properties["summary"] = response return element @@ -96,7 +96,7 @@ def __call__(self, text: str) -> str: t0 = time.time() # call to LLM - summary = self.llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0}) + summary = self.llm.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0}) t1 = time.time() logging.info(f"Summarizer took {t1 - t0} seconds to generate summary.") From ea7f0e6e623672acb48bbc7da34b0e9d08df8a8e Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 28 Jan 2025 13:45:30 -0800 Subject: [PATCH 34/46] fix the last few mocks Signed-off-by: Henry Lindeman --- .../sycamore/tests/unit/transforms/test_schema.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_schema.py b/lib/sycamore/sycamore/tests/unit/transforms/test_schema.py index e3fe3ae75..141b9e87c 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_schema.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_schema.py @@ -98,7 +98,7 @@ def test_extract_batch_schema(self, mocker): def test_extract_properties(self, mocker): llm = mocker.Mock(spec=LLM) - generate = mocker.patch.object(llm, "generate") + generate = mocker.patch.object(llm, "generate_old") generate.return_value = '```json {"accidentNumber": "FTW95FA129", "location": 
"Fort Worth, TX"}```' doc = Document() @@ -124,7 +124,7 @@ def test_extract_properties(self, mocker): def test_extract_properties_explicit_json(self, mocker): llm = mocker.Mock(spec=LLM) - generate = mocker.patch.object(llm, "generate") + generate = mocker.patch.object(llm, "generate_old") generate.return_value = '{"accidentNumber": "FTW95FA129"}' doc = Document() @@ -147,7 +147,7 @@ def test_extract_properties_explicit_json(self, mocker): def test_extract_properties_fixed_json(self, mocker): llm = mocker.Mock(spec=LLM) - generate = mocker.patch.object(llm, "generate") + generate = mocker.patch.object(llm, "generate_old") generate.return_value = '{"accidentNumber": "FTW95FA129"}' doc = Document() @@ -166,7 +166,7 @@ def test_extract_properties_fixed_json(self, mocker): def test_extract_properties_with_schema(self, mocker): llm = mocker.Mock(spec=LLM) - generate = mocker.patch.object(llm, "generate") + generate = mocker.patch.object(llm, "generate_old") generate.return_value = ( '{"startDate": "2022-01-22 00:01:31", ' '"endDate": "2022-01-24 00:01:59", ' From 57a4e4b1bb25e3516319bd2f52b878c5c28377f4 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 28 Jan 2025 14:20:24 -0800 Subject: [PATCH 35/46] ruff linelength Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 703422573..db806077b 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -266,12 +266,14 @@ class ElementListIterPrompt(ElementListPrompt): prompt.render_document(doc) # [ # [ - # {"role": "system", "content": "You are a program that returns 'None' if you don't know the answer to my question"}, + # {"role": "system", "content": "You are a program that returns 'None' if you don't + # know the answer to my question"}, # {"role": "user", "content": "What is the capital of the country described?\\nElements:\\n # ELEMENT 0: \\nELEMENT 1: "} # ], # [ - # {"role": "system", "content": "You are a program that returns 'None' if you don't know the answer to my question"}, + # {"role": "system", "content": "You are a program that returns 'None' if you don't + # know the answer to my question"}, # {"role": "user", "content": "What is the capital of the country described?\\nElements:\\n # ELEMENT 0: \\nELEMENT 1: "} # ] From a312ba31f8c6e5c69637cc6cd00fd12e20365ed2 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 28 Jan 2025 15:03:28 -0800 Subject: [PATCH 36/46] mypy!!! 
Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/evaluation/subtasks.py | 6 +++--- lib/sycamore/sycamore/llms/openai.py | 3 ++- .../sycamore/llms/prompts/default_prompts.py | 16 ++++++++-------- lib/sycamore/sycamore/query/planner.py | 4 ++-- .../tests/integration/llms/test_openai.py | 4 ++-- .../sycamore/tests/unit/query/test_operations.py | 1 + .../transforms/test_extract_table_properties.py | 2 +- .../tests/unit/transforms/test_llm_filter.py | 2 +- .../tests/unit/transforms/test_llm_query.py | 4 ++-- .../transforms/test_resolve_graph_entities.py | 1 + lib/sycamore/sycamore/transforms/base_llm.py | 4 ++-- .../sycamore/transforms/extract_entity.py | 12 ++++++++---- .../transforms/extract_table_properties.py | 4 ++-- lib/sycamore/sycamore/transforms/llm_query.py | 4 ++-- .../sycamore/transforms/summarize_images.py | 6 +++--- 15 files changed, 40 insertions(+), 33 deletions(-) diff --git a/lib/sycamore/sycamore/evaluation/subtasks.py b/lib/sycamore/sycamore/evaluation/subtasks.py index f7db14f89..81189c4cd 100644 --- a/lib/sycamore/sycamore/evaluation/subtasks.py +++ b/lib/sycamore/sycamore/evaluation/subtasks.py @@ -5,7 +5,7 @@ from sycamore.docset import DocSet from sycamore.llms.llms import LLM from sycamore.llms.openai import OpenAI, OpenAIModels -from sycamore.llms.prompts.default_prompts import SimpleGuidancePrompt, TaskIdentifierZeroShotGuidancePrompt +from sycamore.llms.prompts.default_prompts import SimplePrompt, _TaskIdentifierZeroShotGuidancePrompt from sycamore.transforms.embed import Embedder, SentenceTransformerEmbedder from sycamore.transforms.query import QueryExecutor @@ -22,7 +22,7 @@ def __init__( model_name="sentence-transformers/all-MiniLM-L6-v2", batch_size=100 ), llm: LLM = OpenAI(OpenAIModels.GPT_3_5_TURBO.value), - prompt: SimpleGuidancePrompt = TaskIdentifierZeroShotGuidancePrompt(), + prompt: SimplePrompt = _TaskIdentifierZeroShotGuidancePrompt(), knn_query: bool = False, ): if subtask_data: @@ -44,7 +44,7 @@ def __init__( def _get_formulas(self, document: Document) -> list[Document]: f_list = [] if document.properties["subtasks_reqd"]: - task_id = self._llm.generate( + task_id = self._llm.generate_old( prompt_kwargs={ "prompt": self._prompt, "question": document["question"], diff --git a/lib/sycamore/sycamore/llms/openai.py b/lib/sycamore/sycamore/llms/openai.py index d6dc2b89b..c5783434d 100644 --- a/lib/sycamore/sycamore/llms/openai.py +++ b/lib/sycamore/sycamore/llms/openai.py @@ -336,9 +336,10 @@ def _get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict else: role = m.role if m.images is None: - content = m.content + content: Union[str, list] = m.content else: content = [{"type": "text", "text": m.content}] + assert isinstance(content, list) # mypy!!! 
for im in m.images: content.append({"type": "image_url", "image_url": base64_data_url(im)}) messages_list.append({"role": role, "content": content}) diff --git a/lib/sycamore/sycamore/llms/prompts/default_prompts.py b/lib/sycamore/sycamore/llms/prompts/default_prompts.py index fed516ce0..e769eb790 100644 --- a/lib/sycamore/sycamore/llms/prompts/default_prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/default_prompts.py @@ -405,14 +405,14 @@ def __init__(self, field: str, groups: list[str]): _deprecated_prompts: dict[str, Type[SimplePrompt]] = { - "ENTITY_EXTRACTOR_ZERO_SHOT_GUIDANCE_PROMPT": EntityExtractorZeroShotGuidancePrompt, - "ENTITY_EXTRACTOR_ZERO_SHOT_GUIDANCE_PROMPT_CHAT": EntityExtractorFewShotGuidancePrompt, - "ENTITY_EXTRACTOR_FEW_SHOT_GUIDANCE_PROMPT_CHAT": EntityExtractorFewShotGuidancePrompt, - "ENTITY_EXTRACTOR_FEW_SHOT_GUIDANCE_PROMPT": EntityExtractorFewShotGuidancePrompt, - "TEXT_SUMMARIZER_GUIDANCE_PROMPT": TextSummarizerGuidancePrompt, - "TEXT_SUMMARIZER_GUIDANCE_PROMPT_CHAT": TextSummarizerGuidancePrompt, - "SCHEMA_ZERO_SHOT_GUIDANCE_PROMPT": SchemaZeroShotGuidancePrompt, - "SCHEMA_ZERO_SHOT_GUIDANCE_PROMPT_CHAT": SchemaZeroShotGuidancePrompt, + "ENTITY_EXTRACTOR_ZERO_SHOT_GUIDANCE_PROMPT": _EntityExtractorZeroShotGuidancePrompt, + "ENTITY_EXTRACTOR_ZERO_SHOT_GUIDANCE_PROMPT_CHAT": _EntityExtractorFewShotGuidancePrompt, + "ENTITY_EXTRACTOR_FEW_SHOT_GUIDANCE_PROMPT_CHAT": _EntityExtractorFewShotGuidancePrompt, + "ENTITY_EXTRACTOR_FEW_SHOT_GUIDANCE_PROMPT": _EntityExtractorFewShotGuidancePrompt, + "TEXT_SUMMARIZER_GUIDANCE_PROMPT": _TextSummarizerGuidancePrompt, + "TEXT_SUMMARIZER_GUIDANCE_PROMPT_CHAT": _TextSummarizerGuidancePrompt, + "SCHEMA_ZERO_SHOT_GUIDANCE_PROMPT": _SchemaZeroShotGuidancePrompt, + "SCHEMA_ZERO_SHOT_GUIDANCE_PROMPT_CHAT": _SchemaZeroShotGuidancePrompt, "PROPERTIES_ZERO_SHOT_GUIDANCE_PROMPT": PropertiesZeroShotGuidancePrompt, "PROPERTIES_ZERO_SHOT_GUIDANCE_PROMPT_CHAT": PropertiesZeroShotGuidancePrompt, } diff --git a/lib/sycamore/sycamore/query/planner.py b/lib/sycamore/sycamore/query/planner.py index c23182653..fcfdeb207 100644 --- a/lib/sycamore/sycamore/query/planner.py +++ b/lib/sycamore/sycamore/query/planner.py @@ -41,7 +41,7 @@ "properties.key" or "properties.count"; you can only reference one of those fields. Other than those, DO NOT USE ANY OTHER FIELD NAMES. 5. If an optional field does not have a value in the query plan, return null in its place. - 6. The first step of each plan MUST be a **QueryDatabase** or **QueryVectorDatabase" operation. + 6. The first step of each plan MUST be a **QueryDatabase** or **QueryVectorDatabase" operation. Whenever possible, include all possible filtering operations in the first step. That is, you should strive to construct an OpenSearch query that filters the data as much as possible, reducing the need for further query operations. 
If using a QueryVectorDatabase, always @@ -518,7 +518,7 @@ def generate_from_llm(self, question: str) -> Tuple[Any, str]: ] prompt_kwargs = {"messages": messages} - chat_completion = self._llm_client.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0}) + chat_completion = self._llm_client.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0}) return prompt_kwargs, chat_completion def plan(self, question: str) -> LogicalPlan: diff --git a/lib/sycamore/sycamore/tests/integration/llms/test_openai.py b/lib/sycamore/sycamore/tests/integration/llms/test_openai.py index 3977f1b9f..67152e1b0 100644 --- a/lib/sycamore/sycamore/tests/integration/llms/test_openai.py +++ b/lib/sycamore/sycamore/tests/integration/llms/test_openai.py @@ -173,8 +173,8 @@ def test_cached_openai_pydantic_model(tmp_path: Path): class Statement(BaseModel): is_true: bool - llm_kwargs = {} - llm_kwargs_cached = {} + llm_kwargs = {} # type: ignore + llm_kwargs_cached = {} # type: ignore prompt = RenderedPrompt( messages=[RenderedMessage(role="user", content="2+2 = 4, is this statement true?")], response_format=Statement diff --git a/lib/sycamore/sycamore/tests/unit/query/test_operations.py b/lib/sycamore/sycamore/tests/unit/query/test_operations.py index 1a847d8fe..52807315a 100644 --- a/lib/sycamore/sycamore/tests/unit/query/test_operations.py +++ b/lib/sycamore/sycamore/tests/unit/query/test_operations.py @@ -51,6 +51,7 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) return "group3" else: return "" + return "" def is_chat_mode(self): return True diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_table_properties.py b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_table_properties.py index 5fcc2060c..0aa0f649d 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_table_properties.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_table_properties.py @@ -54,7 +54,7 @@ def test_extract_key_value_pair(self, mocker): mock_frombytes.return_value = image llm = mocker.Mock(spec=OpenAI) - llm.generate.return_value = '{"key1":"val1"}' + llm.generate_old.return_value = '{"key1":"val1"}' llm.format_image.return_value = {"type": "image", "data": "dummy"} property_name = "llm_response" diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py b/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py index 08a4e484d..5e195a2ca 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py @@ -259,7 +259,7 @@ def __init__(self, entity_name, bad_val): def as_llm_map( self, child: Optional[Node], context: Optional[Context] = None, llm: Optional[LLM] = None, **kwargs ) -> LLMMap: - pass + raise NotImplementedError("Not using this yet") def extract_entity( self, document: Document, context: Optional[Context] = None, llm: Optional[LLM] = None diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_llm_query.py b/lib/sycamore/sycamore/tests/unit/transforms/test_llm_query.py index 6b352dfdb..da9b944f6 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_llm_query.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_llm_query.py @@ -21,7 +21,7 @@ def test_llm_query_text_does_not_call_llm(self, mocker): def test_summarize_text_element_calls_llm(self, mocker): llm = mocker.Mock(spec=OpenAI) - generate = mocker.patch.object(llm, "generate") + generate = mocker.patch.object(llm, "generate_old") 
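# Illustrative sketch (not part of the patch series): the pattern the updated
# unit tests follow, a stub LLM that dispatches on the RenderedPrompt's
# message contents instead of the old prompt_kwargs dict. The class name and
# the canned return values are assumptions for illustration only.
from typing import Optional

from sycamore.llms import LLM
from sycamore.llms.prompts import RenderedPrompt


class CannedLLM(LLM):
    def __init__(self):
        super().__init__(model_name="canned_model")

    def is_chat_mode(self) -> bool:
        return True

    def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
        last = prompt.messages[-1].content
        if "third element" in last:
            return "None"
        return "4"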
generate.return_value = {"summary": "summary"} doc = Document() element1 = Element() @@ -39,7 +39,7 @@ def test_summarize_text_element_calls_llm(self, mocker): def test_summarize_text_document_calls_llm(self, mocker): llm = mocker.Mock(spec=OpenAI) - generate = mocker.patch.object(llm, "generate") + generate = mocker.patch.object(llm, "generate_old") generate.return_value = {"summary": "summary"} doc = Document() element1 = Element() diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_resolve_graph_entities.py b/lib/sycamore/sycamore/tests/unit/transforms/test_resolve_graph_entities.py index ab909d5dc..58cfe1c7f 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_resolve_graph_entities.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_resolve_graph_entities.py @@ -62,6 +62,7 @@ def __init__(self): def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: """""" + raise NotImplementedError("All these calls are async") def is_chat_mode(self): return True diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py index 46a8de6dd..d9efbddf6 100644 --- a/lib/sycamore/sycamore/transforms/base_llm.py +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -85,8 +85,8 @@ def __init__( super().__init__(child, f=self.llm_map, **kwargs) def llm_map(self, documents: list[Document]) -> list[Document]: - rendered = [self._prompt.render_document(d) for d in documents] - rendered = _as_sequences(rendered) + rendered_inc = [self._prompt.render_document(d) for d in documents] + rendered = _as_sequences(rendered_inc) results = _infer_prompts(rendered, self._llm, self._llm_mode, self._prompt.is_done) postprocessed = [] for d, (r, i) in zip(documents, results): diff --git a/lib/sycamore/sycamore/transforms/extract_entity.py b/lib/sycamore/sycamore/transforms/extract_entity.py index a0ec652e6..15c05e313 100644 --- a/lib/sycamore/sycamore/transforms/extract_entity.py +++ b/lib/sycamore/sycamore/transforms/extract_entity.py @@ -137,6 +137,7 @@ def as_llm_map( if llm is None: llm = self._llm assert llm is not None, "Could not find an LLM to use" + prompt: SycamorePrompt # grr mypy if self._prompt_template is not None: prompt = EntityExtractorFewShotGuidancePrompt prompt = cast(ElementListPrompt, prompt.set(examples=self._prompt_template)) @@ -146,7 +147,7 @@ def as_llm_map( if self._tokenizer is not None: def postprocess(d: Document, i: int) -> Document: - last_club = set() + last_club: set[int] = set() source_key = f"{self._entity_name}_source_element_index" for e in d.elements: if e.properties[source_key] != last_club: @@ -171,11 +172,14 @@ def elt_list_ctor(elts: list[Element]) -> str: def eb(elts: list[Element]) -> list[list[Element]]: curr_tks = 0 - curr_batch = [] + curr_batch: list[Element] = [] batches = [] source_indices = set() + assert ( + self._tokenizer is not None + ), "Cannot batch elements based on token counts because tokenier is None" for e in elts: - eltl = prompt.element_list_constructor([e]) + eltl = cast(ElementListPrompt, prompt).element_list_constructor([e]) tks = len(self._tokenizer.tokenize(eltl)) if tks + curr_tks > self._max_tokens: batches.append(curr_batch) @@ -197,7 +201,7 @@ def eb(elts: list[Element]) -> list[list[Element]]: else: prompt = prompt.set(element_list_constructor=elt_list_ctor) setattr(prompt, "element_batcher", eb) - prompt.is_done = lambda s: s != "None" + setattr(prompt, "is_done", lambda s: s != "None") prompt = prompt.set(entity=self._entity_name) return 
LLMMap(child, prompt, self._entity_name, llm, postprocess_fn=postprocess, **kwargs) diff --git a/lib/sycamore/sycamore/transforms/extract_table_properties.py b/lib/sycamore/sycamore/transforms/extract_table_properties.py index 81e9e0cc6..0c684e37b 100644 --- a/lib/sycamore/sycamore/transforms/extract_table_properties.py +++ b/lib/sycamore/sycamore/transforms/extract_table_properties.py @@ -87,7 +87,7 @@ def extract_table_properties( "text": ( prompt_LLM if prompt_LLM is not None - else ExtractTablePropertiesPrompt.user + f"\n CSV: {ele.text_representation}" + else ExtractTablePropertiesPrompt.user + f"\n CSV: {ele.text_representation}" # type: ignore[operator] # thinks ETPP.user could be None ), }, llm.format_image(img), @@ -96,7 +96,7 @@ def extract_table_properties( {"role": "user", "content": content}, ] prompt_kwargs = {"messages": messages} - raw_answer = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + raw_answer = llm.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs={}) parsed_json = ExtractTableProperties.extract_parent_json(raw_answer) if parsed_json: ele.properties[property_name] = json.loads(parsed_json) diff --git a/lib/sycamore/sycamore/transforms/llm_query.py b/lib/sycamore/sycamore/transforms/llm_query.py index ba0f86426..731c4af84 100644 --- a/lib/sycamore/sycamore/transforms/llm_query.py +++ b/lib/sycamore/sycamore/transforms/llm_query.py @@ -85,7 +85,7 @@ def execute_query(self, document: Document) -> Document: break if not self._per_element: prompt_kwargs = {"prompt": final_prompt} - llm_resp = self._llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs=self._llm_kwargs) + llm_resp = self._llm.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs=self._llm_kwargs) document["properties"][self._output_property] = llm_resp else: if document.text_representation: @@ -118,7 +118,7 @@ def _query_text_object( else: prompt = self._prompt + "\n" + object.text_representation prompt_kwargs = {"prompt": prompt} - llm_resp = self._llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs=self._llm_kwargs) + llm_resp = self._llm.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs=self._llm_kwargs) if self._table_cont: object["properties"]["table_continuation"] = llm_resp else: diff --git a/lib/sycamore/sycamore/transforms/summarize_images.py b/lib/sycamore/sycamore/transforms/summarize_images.py index 9c274dd2c..d64e60404 100644 --- a/lib/sycamore/sycamore/transforms/summarize_images.py +++ b/lib/sycamore/sycamore/transforms/summarize_images.py @@ -23,11 +23,11 @@ class LLMImageSummarizer: Example: The following code demonstrates how to partition a pdf DocSet and summarize the images it contains. - This version uses a Claude model via Bedrock. + This version uses a Claude model via Bedrock. .. 
code-block:: python llm = Bedrock(BedrockModels.CLAUDE_3_5_SONNET) - + context = sycamore.init() doc = context.read.binary(paths=paths, binary_format="pdf")\ .partition(partitioner=SycamorePartitioner(extract_images=True))\ @@ -91,7 +91,7 @@ def summarize_image( prompt_kwargs = {"messages": messages} - raw_answer = self.llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={}) + raw_answer = self.llm.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs={}) return extract_json(raw_answer) def summarize_all_images(self, doc: Document) -> Document: From ebde879a17d30778095181069ca0f3cc1458c131 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 28 Jan 2025 15:16:18 -0800 Subject: [PATCH 37/46] type: ignore + line length is tricky Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/transforms/extract_table_properties.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/sycamore/sycamore/transforms/extract_table_properties.py b/lib/sycamore/sycamore/transforms/extract_table_properties.py index 0c684e37b..b07d3f68c 100644 --- a/lib/sycamore/sycamore/transforms/extract_table_properties.py +++ b/lib/sycamore/sycamore/transforms/extract_table_properties.py @@ -87,7 +87,9 @@ def extract_table_properties( "text": ( prompt_LLM if prompt_LLM is not None - else ExtractTablePropertiesPrompt.user + f"\n CSV: {ele.text_representation}" # type: ignore[operator] # thinks ETPP.user could be None + else ( + ExtractTablePropertiesPrompt.user + f"\n CSV: {ele.text_representation}" # type: ignore + ) # type ignore - thinks ETPP.user could be None ), }, llm.format_image(img), From ff5efdcb2f9db3b4475a3c1644b8604bee72b483 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 28 Jan 2025 15:48:48 -0800 Subject: [PATCH 38/46] fix generate_old with SimplePrompts Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/llms.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/lib/sycamore/sycamore/llms/llms.py b/lib/sycamore/sycamore/llms/llms.py index a5eb624da..24028bf66 100644 --- a/lib/sycamore/sycamore/llms/llms.py +++ b/lib/sycamore/sycamore/llms/llms.py @@ -8,7 +8,8 @@ from sycamore.utils.cache import Cache from sycamore.utils.thread_local import ThreadLocalAccess, ADD_METADATA_TO_OUTPUT from sycamore.data.metadata import add_metadata -from sycamore.llms.prompts import RenderedPrompt, RenderedMessage +from sycamore.llms.prompts import RenderedPrompt, RenderedMessage, SimplePrompt + from sycamore.utils.deprecate import deprecated @@ -36,7 +37,15 @@ def generate_old(self, *, prompt_kwargs: dict[str, Any], llm_kwargs: Optional[di """Generates a response from the LLM""" if "prompt" in prompt_kwargs: prompt = prompt_kwargs.get("prompt") - rendered = RenderedPrompt(messages=[RenderedMessage(role="user", content=f"{prompt}")]) + if isinstance(prompt, SimplePrompt): + prompt = prompt.as_messages() + for idx, prompt_message in enumerate(prompt): + prompt[idx]["content"] = prompt_message["content"].format(**prompt_kwargs) + rendered = RenderedPrompt( + messages=[RenderedMessage(role=m["role"], content=m["content"]) for m in prompt] + ) + else: + rendered = RenderedPrompt(messages=[RenderedMessage(role="user", content=f"{prompt}")]) elif "messages" in prompt_kwargs: ms = prompt_kwargs.get("messages", []) messages = [RenderedMessage(role=m["role"], content=m["content"]) for m in ms] @@ -62,7 +71,15 @@ async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[d async def generate_async_old(self, *, prompt_kwargs: dict[str, 
Any], llm_kwargs: Optional[dict] = None) -> str: if "prompt" in prompt_kwargs: prompt = prompt_kwargs.get("prompt") - rendered = RenderedPrompt(messages=[RenderedMessage(role="user", content=f"{prompt}")]) + if isinstance(prompt, SimplePrompt): + prompt = prompt.as_messages() + for idx, prompt_message in enumerate(prompt): + prompt[idx]["content"] = prompt_message["content"].format(**prompt_kwargs) + rendered = RenderedPrompt( + messages=[RenderedMessage(role=m["role"], content=m["content"]) for m in prompt] + ) + else: + rendered = RenderedPrompt(messages=[RenderedMessage(role="user", content=f"{prompt}")]) elif "messages" in prompt_kwargs: ms = prompt_kwargs.get("messages", []) messages = [RenderedMessage(role=m["role"], content=m["content"]) for m in ms] From 370e2b7c1ef0fd0267796afc2d6fed14314997a8 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 28 Jan 2025 15:54:48 -0800 Subject: [PATCH 39/46] set openai system role name to system instead of developer like their docs say because azure disagrees and they'd better both accept 'system' Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/openai.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/sycamore/sycamore/llms/openai.py b/lib/sycamore/sycamore/llms/openai.py index c5783434d..e475c416a 100644 --- a/lib/sycamore/sycamore/llms/openai.py +++ b/lib/sycamore/sycamore/llms/openai.py @@ -332,7 +332,9 @@ def _get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict messages_list = [] for m in prompt.messages: if m.role == "system": - role = "developer" + # OpenAI docs say "developer" is the new "system" + # but Azure don't like that + role = "system" else: role = m.role if m.images is None: From 98ce6a0f73e73063e285f6b3fdfc69b16ad8d066 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Wed, 29 Jan 2025 14:21:08 -0800 Subject: [PATCH 40/46] address simple pr comments Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/llms.py | 1 - lib/sycamore/sycamore/llms/openai.py | 10 +--------- lib/sycamore/sycamore/llms/prompts/default_prompts.py | 2 +- lib/sycamore/sycamore/transforms/base_llm.py | 2 +- lib/sycamore/sycamore/transforms/llm_filter.py | 7 ------- 5 files changed, 3 insertions(+), 19 deletions(-) diff --git a/lib/sycamore/sycamore/llms/llms.py b/lib/sycamore/sycamore/llms/llms.py index 24028bf66..fb24a3a63 100644 --- a/lib/sycamore/sycamore/llms/llms.py +++ b/lib/sycamore/sycamore/llms/llms.py @@ -14,7 +14,6 @@ class LLMMode(Enum): - UNKNOWN = 0 SYNC = 1 ASYNC = 2 BATCH = 3 diff --git a/lib/sycamore/sycamore/llms/openai.py b/lib/sycamore/sycamore/llms/openai.py index e475c416a..7730da9ea 100644 --- a/lib/sycamore/sycamore/llms/openai.py +++ b/lib/sycamore/sycamore/llms/openai.py @@ -343,20 +343,12 @@ def _get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict content = [{"type": "text", "text": m.content}] assert isinstance(content, list) # mypy!!! 
for im in m.images: - content.append({"type": "image_url", "image_url": base64_data_url(im)}) + content.append(self.format_image(im)) messages_list.append({"role": role, "content": content}) kwargs.update({"messages": messages_list}) return kwargs - def _determine_using_beta(self, response_format: Any) -> bool: - if isinstance(response_format, dict) and response_format.get("type") == "json_schema": - return True - elif inspect.isclass(response_format) and issubclass(response_format, pydantic.BaseModel): - return True - else: - return False - def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: llm_kwargs = self._convert_response_format(llm_kwargs) ret = self._llm_cache_get(prompt, llm_kwargs) diff --git a/lib/sycamore/sycamore/llms/prompts/default_prompts.py b/lib/sycamore/sycamore/llms/prompts/default_prompts.py index e769eb790..f92d881d8 100644 --- a/lib/sycamore/sycamore/llms/prompts/default_prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/default_prompts.py @@ -92,7 +92,7 @@ class _TextSummarizerGuidancePrompt(SimplePrompt): TextSummarizerGuidancePrompt = ElementPrompt( system="You are a helpful text summarizer.", - user="""Write a summary of the following. Use onlt the information provided. + user="""Write a summary of the following. Use only the information provided. Include as many key details as possible. Do not make up your answer. Only return the summary as part of your answer {elt_text} """, diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py index d9efbddf6..514190dc9 100644 --- a/lib/sycamore/sycamore/transforms/base_llm.py +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -108,7 +108,7 @@ def _validate_prompt(self): class LLMMapElements(MapBatch): """The LLMMapElements transform renders each Element for each Document in a docset into a prompt for an LLM, calls the LLM, - and attaches the output to the document. + and attaches the output to the element. 
Args: child: Child node in the sycamore execution graph diff --git a/lib/sycamore/sycamore/transforms/llm_filter.py b/lib/sycamore/sycamore/transforms/llm_filter.py index b350da0b4..1e7868cad 100644 --- a/lib/sycamore/sycamore/transforms/llm_filter.py +++ b/lib/sycamore/sycamore/transforms/llm_filter.py @@ -35,8 +35,6 @@ def tokenized_threshold_llm_filter( max_tokens: int, tokenizer: Tokenizer, ) -> bool: - print("=" * 80) - print(doc) element_sorter(doc) evaluated_elements = 0 @@ -59,14 +57,9 @@ def tokenized_threshold_llm_filter( if score >= doc.get(doc_source_field_name, 0): doc.properties[f"{new_field}"] = score doc.properties[f"{new_field}_source_element_index"] = window_indices - print(score, combined_text) if score >= threshold: - print("-" * 80) - print(doc) return True evaluated_elements += 1 - print("-" * 80) - print(doc) if evaluated_elements == 0: # no elements found for property return keep_none return False From 178940968da79657c9d1fc4f2c8de8a00b4e0b7e Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Wed, 29 Jan 2025 16:07:42 -0800 Subject: [PATCH 41/46] pickle stuff in llm caching path bc not everything is jsonifiable Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/llms.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/lib/sycamore/sycamore/llms/llms.py b/lib/sycamore/sycamore/llms/llms.py index fb24a3a63..79f5f16a7 100644 --- a/lib/sycamore/sycamore/llms/llms.py +++ b/lib/sycamore/sycamore/llms/llms.py @@ -129,6 +129,7 @@ def _llm_cache_get(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]) -> key = self._llm_cache_key(prompt, llm_kwargs) hit = self._cache.get(key) if hit: + hit = pickle.loads(hit) assert ( len(hit) == 5 and hit.get("prompt") == RenderedPrompt(messages=prompt.messages) @@ -156,13 +157,15 @@ def _llm_cache_set(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict], res key = self._llm_cache_key(prompt, llm_kwargs) self._cache.set( key, - { - "prompt": RenderedPrompt(messages=prompt.messages), - "prompt.response_format": self._pickleable_response_format(prompt), - "llm_kwargs": llm_kwargs, - "model_name": self._model_name, - "result": result, - }, + pickle.dumps( + { + "prompt": RenderedPrompt(messages=prompt.messages), + "prompt.response_format": self._pickleable_response_format(prompt), + "llm_kwargs": llm_kwargs, + "model_name": self._model_name, + "result": result, + } + ), ) def get_metadata(self, kwargs, response_text, wall_latency, in_tokens, out_tokens) -> dict: From 8b6f085eb14e31e7a4bef675d46cf3b15148dab1 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 30 Jan 2025 09:46:34 -0800 Subject: [PATCH 42/46] rewrite llm_map to deal with iterative prompting better Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 51 ++++----- .../sycamore/tests/unit/test_docset.py | 3 +- .../tests/unit/transforms/test_base_llm.py | 14 +-- .../unit/transforms/test_extract_entity.py | 1 + lib/sycamore/sycamore/transforms/base_llm.py | 103 ++++++++++++------ .../sycamore/transforms/extract_entity.py | 67 +++++++----- 6 files changed, 146 insertions(+), 93 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index db806077b..05e7674bd 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Union, Optional, Callable, Sequence +from typing import Any, Union, Optional, Callable import 
copy import pydantic @@ -42,7 +42,7 @@ class SycamorePrompt: convert sycamore objects (``Document``s, ``Element``s) into ``RenderedPrompts`` """ - def render_document(self, doc: Document) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]: + def render_document(self, doc: Document) -> RenderedPrompt: """Render this prompt, given this document as context. Used in llm_map @@ -54,7 +54,7 @@ def render_document(self, doc: Document) -> Union[RenderedPrompt, Sequence[Rende """ raise NotImplementedError(f"render_document is not implemented for {self.__class__.__name__}") - def render_element(self, elt: Element, doc: Document) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]: + def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: """Render this prompt, given this element and its parent document as context. Used in llm_map_elements @@ -66,7 +66,7 @@ def render_element(self, elt: Element, doc: Document) -> Union[RenderedPrompt, S """ raise NotImplementedError(f"render_element is not implemented for {self.__class__.__name__}") - def render_multiple_documents(self, docs: list[Document]) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]: + def render_multiple_documents(self, docs: list[Document]) -> RenderedPrompt: """Render this prompt, given a list of documents as context. Used in llm_reduce @@ -113,19 +113,6 @@ def set(self, **kwargs) -> "SycamorePrompt": new.__dict__[k] = v return new - def is_done(self, s: str) -> bool: - """Decide whether a given response is sufficient. Used when rendering - the prompt generates a sequence of prompts rather than a single prompt. - The default implementation always returns True - - Args: - s: a string response from the LLM - - Returns: - Whether to continue making LLM calls - """ - return True - def _build_format_str( system: Optional[str], user: Union[None, str, list[str]], format_args: dict[str, Any] @@ -201,7 +188,7 @@ def _render_element_list_to_string(self, doc: Document): elts = self.element_select(doc.elements) return self.element_list_constructor(elts) - def render_document(self, doc: Document) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]: + def render_document(self, doc: Document) -> RenderedPrompt: """Render this prompt, given this document as context, using python's ``str.format()`` method. The keys passed into ``format()`` are as follows: @@ -280,11 +267,18 @@ class ElementListIterPrompt(ElementListPrompt): # ] """ - def __init__(self, *, element_batcher: Optional[Callable[[list[Element]], list[list[Element]]]] = None, **kwargs): + def __init__( + self, + *, + element_batcher: Optional[Callable[[list[Element]], list[list[Element]]]] = None, + iteration_var_name: str = "i", + **kwargs, + ): self.element_batcher = element_batcher or (lambda e: [e]) + self.iteration_var_name = iteration_var_name super().__init__(**kwargs) - def render_document(self, doc: Document) -> Sequence[RenderedPrompt]: + def render_document(self, doc: Document) -> RenderedPrompt: """Render this prompt, given this document as context, using python's ``str.format()`` method. The keys passed into ``format()`` are as follows: @@ -304,19 +298,22 @@ def render_document(self, doc: Document) -> Sequence[RenderedPrompt]: ``self.user.format()`` using the format keys as specified above. 
Each instance is rendered from a batch of elements generated by ``self.element_batcher`` """ + i = doc.properties.get(self.iteration_var_name, 0) format_args = self.kwargs format_args["doc_text"] = doc.text_representation flat_props = flatten_data(doc.properties, prefix="doc_property", separator="_") format_args.update(flat_props) - prompts = [] - for elt_batch in self.element_batcher(doc.elements): - elements = self.element_select(elt_batch) - elementstr = self.element_list_constructor(elements) - messages = _build_format_str(self.system, self.user, {"elements": elementstr, **format_args}) - prompts.append(RenderedPrompt(messages=messages)) - return prompts + for j, elt_batch in enumerate(self.element_batcher(doc.elements)): + if j < i: + continue + else: + elements = self.element_select(elt_batch) + elementstr = self.element_list_constructor(elements) + messages = _build_format_str(self.system, self.user, {"elements": elementstr, **format_args}) + return RenderedPrompt(messages=messages) + return RenderedPrompt(messages=[]) class ElementPrompt(SycamorePrompt): diff --git a/lib/sycamore/sycamore/tests/unit/test_docset.py b/lib/sycamore/sycamore/tests/unit/test_docset.py index ba4c8bf1f..3655f1dff 100644 --- a/lib/sycamore/sycamore/tests/unit/test_docset.py +++ b/lib/sycamore/sycamore/tests/unit/test_docset.py @@ -44,7 +44,6 @@ def __init__(self): super().__init__(model_name="mock_model") def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: - print(prompt) if llm_kwargs is None: llm_kwargs = {} if prompt.messages[-1].content.endswith("Element_index: 1\nText: third element\n"): @@ -98,6 +97,8 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) return "group2" elif value == "3" or value == "three": return "group3" + else: + return "" else: return prompt.messages[-1].content diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py b/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py index c2d4838d1..0710eb8fe 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py @@ -54,19 +54,19 @@ def test_happy_path(self): assert outdocs[1].text_representation == "booga" assert outdocs[1].properties["out"] == "booga" - def test_postprocess(self): + def test_validate(self): prompt = FakeDocPrompt() llm = FakeLLM() doc1 = Document({"text_representation": "ooga"}) doc2 = Document({"text_representation": "booga"}) count = 0 - def ppfn(d: Document, i: int) -> Document: + def valfn(d: Document) -> bool: nonlocal count count += 1 - return d + return count > 1 - map = LLMMap(None, prompt, "out", llm, postprocess_fn=ppfn) + map = LLMMap(None, prompt, "out", llm, validate=valfn) _ = map.llm_map([doc1, doc2]) assert count == 2 @@ -112,12 +112,12 @@ def test_postprocess(self): doc2 = Document({"doc_id": "2", "elements": [{"text_representation": "booga"}, {}]}) count = 0 - def ppfn(e: Element, i: int) -> Element: + def valfn(e: Element) -> bool: nonlocal count count += 1 - return e + return count > 1 - map = LLMMapElements(None, prompt, "out", llm, postprocess_fn=ppfn) + map = LLMMapElements(None, prompt, "out", llm, validate=valfn) _ = map.llm_map_elements([doc1, doc2]) assert count == 4 diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py index df706fc6f..0c06d0bff 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py 
+++ b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py @@ -202,6 +202,7 @@ def test_extract_entity_with_tokenizer(self, mocker): entity_extractor=entity_extractor, ) taken = entity_docset.take() + assert taken[0].properties[f"{new_field}_source_element_index"] == {0, 1, 2} assert taken[1].properties[f"{new_field}_source_element_index"] == {2} assert taken[0].properties[new_field] == "4" diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py index 514190dc9..7aa0f4bb2 100644 --- a/lib/sycamore/sycamore/transforms/base_llm.py +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -8,22 +8,15 @@ def _infer_prompts( - prompts: list[Sequence[RenderedPrompt]], + prompts: list[RenderedPrompt], llm: LLM, llm_mode: LLMMode, - is_done: Callable[[str], bool] = lambda s: True, ) -> list[tuple[str, int]]: if llm_mode == LLMMode.SYNC: res = [] - for piter in prompts: - s = "" - i = -1 - for p in piter: - i += 1 - s = llm.generate(prompt=p) - if is_done(s): - break - res.append((s, i)) + for p in prompts: + s = llm.generate(prompt=p) + res.append(s) return res elif llm_mode == LLMMode.ASYNC: raise NotImplementedError("Haven't done async yet") @@ -73,7 +66,9 @@ def __init__( output_field: str, llm: LLM, llm_mode: LLMMode = LLMMode.SYNC, - postprocess_fn: Callable[[Document, int], Document] = lambda d, i: d, + iteration_var: Optional[str] = None, + validate: Callable[[Document], bool] = lambda d: True, + max_tries: int = 5, **kwargs, ): self._prompt = prompt @@ -81,19 +76,37 @@ def __init__( self._output_field = output_field self._llm = llm self._llm_mode = llm_mode - self._postprocess_fn = postprocess_fn + self._iteration_var = iteration_var + self._validate = validate + self._max_tries = max_tries super().__init__(child, f=self.llm_map, **kwargs) def llm_map(self, documents: list[Document]) -> list[Document]: - rendered_inc = [self._prompt.render_document(d) for d in documents] - rendered = _as_sequences(rendered_inc) - results = _infer_prompts(rendered, self._llm, self._llm_mode, self._prompt.is_done) - postprocessed = [] - for d, (r, i) in zip(documents, results): - d.properties[self._output_field] = r - new_d = self._postprocess_fn(d, i) - postprocessed.append(new_d) - return postprocessed + if self._iteration_var is not None: + for d in documents: + d.properties[self._iteration_var] = 0 + + valid = [False] * len(documents) + tries = 0 + while not all(valid) and tries < self._max_tries: + tries += 1 + rendered = [self._prompt.render_document(d) for v, d in zip(valid, documents) if not v] + if sum([0, *(len(r.messages) for r in rendered)]) == 0: + break + results = _infer_prompts(rendered, self._llm, self._llm_mode) + ri = 0 + for i in range(len(documents)): + if valid[i]: + continue + documents[i].properties[self._output_field] = results[ri] + valid[i] = self._validate(documents[i]) + ri += 1 + if self._iteration_var is not None and not valid[i]: + documents[i].properties[self._iteration_var] += 1 + if self._iteration_var is None: + break + + return documents def _validate_prompt(self): doc = Document() @@ -143,7 +156,9 @@ def __init__( output_field: str, llm: LLM, llm_mode: LLMMode = LLMMode.SYNC, - postprocess_fn: Callable[[Element, int], Element] = lambda e, i: e, + iteration_var: Optional[str] = None, + validate: Callable[[Element], bool] = lambda d: True, + max_tries: int = 5, **kwargs, ): self._prompt = prompt @@ -151,22 +166,46 @@ def __init__( self._output_field = output_field self._llm = llm self._llm_mode = llm_mode - 
self._postprocess_fn = postprocess_fn + self._iteration_var = iteration_var + self._validate = validate + self._max_tries = max_tries super().__init__(child, f=self.llm_map_elements, **kwargs) def llm_map_elements(self, documents: list[Document]) -> list[Document]: - rendered = [(d, e, self._prompt.render_element(e, d)) for d in documents for e in d.elements] - results = _infer_prompts( - _as_sequences([p for _, _, p in rendered]), self._llm, self._llm_mode, self._prompt.is_done - ) - new_elts = [] + elt_doc_pairs = [(e, d) for d in documents for e in d.elements] + if self._iteration_var is not None: + for e, _ in elt_doc_pairs: + e.properties[self._iteration_var] = 0 + + valid = [False] * len(elt_doc_pairs) + tries = 0 + while not all(valid) and tries < self._max_tries: + tries += 1 + rendered = [self._prompt.render_element(e, d) for v, (e, d) in zip(valid, elt_doc_pairs) if not v] + if sum([0, *(len(r.messages) for r in rendered)]) == 0: + break + results = _infer_prompts(rendered, self._llm, self._llm_mode) + ri = 0 + for i in range(len(elt_doc_pairs)): + if valid[i]: + continue + print(ri) + elt, doc = elt_doc_pairs[i] + elt.properties[self._output_field] = results[ri] + valid[i] = self._validate(elt) + ri += 1 + if self._iteration_var is not None: + elt.properties[self._iteration_var] += 1 + if self._iteration_var is None: + break + last_doc = None - for (r, i), (d, e, _) in zip(results, rendered): + new_elts = [] + for e, d in elt_doc_pairs: if last_doc is not None and last_doc.doc_id != d.doc_id: last_doc.elements = new_elts new_elts = [] - e.properties[self._output_field] = r - new_elts.append(self._postprocess_fn(e, i)) + new_elts.append(e) last_doc = d if last_doc is not None: last_doc.elements = new_elts diff --git a/lib/sycamore/sycamore/transforms/extract_entity.py b/lib/sycamore/sycamore/transforms/extract_entity.py index 15c05e313..260db3c12 100644 --- a/lib/sycamore/sycamore/transforms/extract_entity.py +++ b/lib/sycamore/sycamore/transforms/extract_entity.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from typing import Callable, Any, Optional, Union, cast -from functools import partial from sycamore.context import Context, context_params, OperationTypes from sycamore.data import Element, Document @@ -56,7 +55,7 @@ def __init__(self, entity_name: str): @abstractmethod def as_llm_map( self, child: Optional[Node], context: Optional[Context] = None, llm: Optional[LLM] = None, **kwargs - ) -> LLMMap: + ) -> Node: pass @abstractmethod @@ -133,7 +132,7 @@ def __init__( @context_params(OperationTypes.INFORMATION_EXTRACTOR) def as_llm_map( self, child: Optional[Node], context: Optional[Context] = None, llm: Optional[LLM] = None, **kwargs - ) -> LLMMap: + ) -> Node: if llm is None: llm = self._llm assert llm is not None, "Could not find an LLM to use" @@ -146,19 +145,12 @@ def as_llm_map( if self._tokenizer is not None: - def postprocess(d: Document, i: int) -> Document: - last_club: set[int] = set() - source_key = f"{self._entity_name}_source_element_index" - for e in d.elements: - if e.properties[source_key] != last_club: - i -= 1 - last_club = e.properties[source_key] - if i == -1: - break - d.properties[source_key] = last_club - return d + def validate(d: Document) -> bool: + return d.properties.get(self._entity_name, "None") != "None" def elt_list_ctor(elts: list[Element]) -> str: + if self._prompt_formatter is not element_list_formatter: + return self._prompt_formatter(elts, self._field) combined_text = "" for element in elts: if "type" in element: @@ -170,6 +162,8 @@ 
def elt_list_ctor(elts: list[Element]) -> str: combined_text += f"Text: {element.field_to_value(self._field)}\n" return combined_text + source_idx_key = f"{self._entity_name}_source_element_index" + def eb(elts: list[Element]) -> list[list[Element]]: curr_tks = 0 curr_batch: list[Element] = [] @@ -177,7 +171,7 @@ def eb(elts: list[Element]) -> list[list[Element]]: source_indices = set() assert ( self._tokenizer is not None - ), "Cannot batch elements based on token counts because tokenier is None" + ), "Cannot batch elements based on token counts because tokenizer is None" for e in elts: eltl = cast(ElementListPrompt, prompt).element_list_constructor([e]) tks = len(self._tokenizer.tokenize(eltl)) @@ -186,24 +180,45 @@ def eb(elts: list[Element]) -> list[list[Element]]: curr_tks = tks curr_batch = [e] source_indices = {e.element_index} - e.properties[f"{self._entity_name}_source_element_index"] = source_indices + e.properties[source_idx_key] = source_indices else: - e.properties[f"{self._entity_name}_source_element_index"] = source_indices + e.properties[source_idx_key] = source_indices source_indices.add(e.element_index) curr_batch.append(e) curr_tks += tks batches.append(curr_batch) return batches - prompt.render_document = partial(ElementListIterPrompt.render_document, prompt) # type: ignore - if self._prompt_formatter is not element_list_formatter: - prompt = prompt.set(element_list_constructor=self._prompt_formatter) - else: - prompt = prompt.set(element_list_constructor=elt_list_ctor) - setattr(prompt, "element_batcher", eb) - setattr(prompt, "is_done", lambda s: s != "None") - prompt = prompt.set(entity=self._entity_name) - return LLMMap(child, prompt, self._entity_name, llm, postprocess_fn=postprocess, **kwargs) + iteration_var_name = f"{self._entity_name}_i" + + def postprocess(d: Document) -> Document: + last_eclub: set[int] = set() + club_idx = 0 + target_club_idx = d.properties[iteration_var_name] + for e in d.elements: + if len(last_eclub) > 0 and e.properties[source_idx_key] != last_eclub: + club_idx += 1 + last_eclub = e.properties[source_idx_key] + if club_idx == target_club_idx: + d.properties[source_idx_key] = last_eclub + break + return d + + prompt = ElementListIterPrompt( + system=prompt.system, + user=prompt.user, + element_list_constructor=elt_list_ctor, + element_batcher=eb, + entity=self._entity_name, + examples=self._prompt_template, + iteration_var_name=iteration_var_name, + ) + + llm_map = LLMMap( + child, prompt, self._entity_name, llm, iteration_var=iteration_var_name, validate=validate, **kwargs + ) + ppmap = Map(llm_map, f=postprocess) + return ppmap elif not self._use_elements: if self._prompt is None: From 763acc5a3d120ac704f09fada035083480641070 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 30 Jan 2025 09:56:31 -0800 Subject: [PATCH 43/46] add a b64encode-to-str to cache bc you can't put bytes in json either Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/llms.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/lib/sycamore/sycamore/llms/llms.py b/lib/sycamore/sycamore/llms/llms.py index 79f5f16a7..070c60098 100644 --- a/lib/sycamore/sycamore/llms/llms.py +++ b/lib/sycamore/sycamore/llms/llms.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from enum import Enum import pickle +import base64 from PIL import Image from typing import Any, Optional import pydantic @@ -129,6 +130,7 @@ def _llm_cache_get(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]) -> key = 
self._llm_cache_key(prompt, llm_kwargs) hit = self._cache.get(key) if hit: + hit = base64.b64decode(hit) hit = pickle.loads(hit) assert ( len(hit) == 5 @@ -155,17 +157,19 @@ def _llm_cache_set(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict], res assert self._cache is not None, "make mypy happy" key = self._llm_cache_key(prompt, llm_kwargs) + databytes = pickle.dumps( + { + "prompt": RenderedPrompt(messages=prompt.messages), + "prompt.response_format": self._pickleable_response_format(prompt), + "llm_kwargs": llm_kwargs, + "model_name": self._model_name, + "result": result, + } + ) + datastr = base64.b64encode(databytes).decode("utf-8") self._cache.set( key, - pickle.dumps( - { - "prompt": RenderedPrompt(messages=prompt.messages), - "prompt.response_format": self._pickleable_response_format(prompt), - "llm_kwargs": llm_kwargs, - "model_name": self._model_name, - "result": result, - } - ), + datastr, ) def get_metadata(self, kwargs, response_text, wall_latency, in_tokens, out_tokens) -> dict: From 0331866fe3634fb647ffc37217faf975c5353b08 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 30 Jan 2025 11:33:37 -0800 Subject: [PATCH 44/46] fix llm its to mimic the _llm_cache_set/get pickle/unpickle operations Signed-off-by: Henry Lindeman --- .../tests/integration/llms/test_anthropic.py | 40 +++++++---- .../tests/integration/llms/test_bedrock.py | 40 +++++++---- .../tests/integration/llms/test_openai.py | 68 +++++++++++-------- 3 files changed, 92 insertions(+), 56 deletions(-) diff --git a/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py b/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py index 11c0a1605..e642e266b 100644 --- a/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py +++ b/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py @@ -1,11 +1,23 @@ from pathlib import Path from typing import Any +import base64 +import pickle from sycamore.llms import Anthropic, AnthropicModels from sycamore.llms.prompts.prompts import RenderedPrompt, RenderedMessage from sycamore.utils.cache import DiskCache +def cacheget(cache: DiskCache, key: str): + hit = cache.get(key) + return pickle.loads(base64.b64decode(hit)) # type: ignore + + +def cacheset(cache: DiskCache, key: str, data: Any): + databytes = pickle.dumps(data) + cache.set(key, base64.b64encode(databytes).decode("utf-8")) + + def test_anthropic_defaults(): llm = Anthropic(AnthropicModels.CLAUDE_3_HAIKU) prompt = RenderedPrompt( @@ -45,11 +57,11 @@ def test_cached_anthropic(tmp_path: Path): res = llm.generate(prompt=prompt, llm_kwargs={}) # assert result is cached - assert cache.get(key).get("result")["output"] == res - assert cache.get(key).get("prompt") == prompt - assert cache.get(key).get("prompt.response_format") is None - assert cache.get(key).get("llm_kwargs") == {} - assert cache.get(key).get("model_name") == AnthropicModels.CLAUDE_3_HAIKU.value + assert cacheget(cache, key).get("result")["output"] == res + assert cacheget(cache, key).get("prompt") == prompt + assert cacheget(cache, key).get("prompt.response_format") is None + assert cacheget(cache, key).get("llm_kwargs") == {} + assert cacheget(cache, key).get("model_name") == AnthropicModels.CLAUDE_3_HAIKU.value # assert llm.generate is using cached result custom_output: dict[str, Any] = { @@ -59,7 +71,7 @@ def test_cached_anthropic(tmp_path: Path): "llm_kwargs": {}, "model_name": AnthropicModels.CLAUDE_3_HAIKU.value, } - cache.set(key, custom_output) + cacheset(cache, key, custom_output) assert llm.generate(prompt=prompt, 
llm_kwargs={}) == custom_output["result"]["output"] @@ -113,14 +125,14 @@ def test_cached_anthropic_different_models(tmp_path: Path): res_SONNET = llm_SONNET.generate(prompt=prompt, llm_kwargs={}) # check proper cached results - assert cache.get(key_HAIKU).get("result")["output"] == res_HAIKU - assert cache.get(key_HAIKU).get("prompt") == prompt - assert cache.get(key_HAIKU).get("llm_kwargs") == {} - assert cache.get(key_HAIKU).get("model_name") == AnthropicModels.CLAUDE_3_HAIKU.value - assert cache.get(key_SONNET).get("result")["output"] == res_SONNET - assert cache.get(key_SONNET).get("prompt") == prompt - assert cache.get(key_SONNET).get("llm_kwargs") == {} - assert cache.get(key_SONNET).get("model_name") == AnthropicModels.CLAUDE_3_SONNET.value + assert cacheget(cache, key_HAIKU).get("result")["output"] == res_HAIKU + assert cacheget(cache, key_HAIKU).get("prompt") == prompt + assert cacheget(cache, key_HAIKU).get("llm_kwargs") == {} + assert cacheget(cache, key_HAIKU).get("model_name") == AnthropicModels.CLAUDE_3_HAIKU.value + assert cacheget(cache, key_SONNET).get("result")["output"] == res_SONNET + assert cacheget(cache, key_SONNET).get("prompt") == prompt + assert cacheget(cache, key_SONNET).get("llm_kwargs") == {} + assert cacheget(cache, key_SONNET).get("model_name") == AnthropicModels.CLAUDE_3_SONNET.value # check for difference with model change assert key_HAIKU != key_SONNET diff --git a/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py b/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py index 8621039eb..01719ab0e 100644 --- a/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py +++ b/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py @@ -1,11 +1,23 @@ from pathlib import Path from typing import Any +import pickle +import base64 from sycamore.llms import Bedrock, BedrockModels from sycamore.llms.prompts import RenderedPrompt, RenderedMessage from sycamore.utils.cache import DiskCache +def cacheget(cache: DiskCache, key: str): + hit = cache.get(key) + return pickle.loads(base64.b64decode(hit)) # type: ignore + + +def cacheset(cache: DiskCache, key: str, data: Any): + databytes = pickle.dumps(data) + cache.set(key, base64.b64encode(databytes).decode("utf-8")) + + # Note: These tests assume your environment has been configured to access Amazon Bedrock. 
@@ -48,11 +60,11 @@ def test_cached_bedrock(tmp_path: Path): res = llm.generate(prompt=prompt, llm_kwargs={}) # assert result is cached - assert cache.get(key).get("result")["output"] == res - assert cache.get(key).get("prompt") == prompt - assert cache.get(key).get("prompt.response_format") is None - assert cache.get(key).get("llm_kwargs") == {} - assert cache.get(key).get("model_name") == BedrockModels.CLAUDE_3_HAIKU.value.name + assert cacheget(cache, key).get("result")["output"] == res + assert cacheget(cache, key).get("prompt") == prompt + assert cacheget(cache, key).get("prompt.response_format") is None + assert cacheget(cache, key).get("llm_kwargs") == {} + assert cacheget(cache, key).get("model_name") == BedrockModels.CLAUDE_3_HAIKU.value.name # assert llm.generate is using cached result custom_output: dict[str, Any] = { @@ -62,7 +74,7 @@ def test_cached_bedrock(tmp_path: Path): "llm_kwargs": {}, "model_name": BedrockModels.CLAUDE_3_HAIKU.value.name, } - cache.set(key, custom_output) + cacheset(cache, key, custom_output) assert llm.generate(prompt=prompt, llm_kwargs={}) == custom_output["result"]["output"] @@ -116,14 +128,14 @@ def test_cached_bedrock_different_models(tmp_path: Path): res_SONNET = llm_SONNET.generate(prompt=prompt, llm_kwargs={}) # check proper cached results - assert cache.get(key_HAIKU).get("result")["output"] == res_HAIKU - assert cache.get(key_HAIKU).get("prompt") == prompt - assert cache.get(key_HAIKU).get("llm_kwargs") == {} - assert cache.get(key_HAIKU).get("model_name") == BedrockModels.CLAUDE_3_HAIKU.value.name - assert cache.get(key_SONNET).get("result")["output"] == res_SONNET - assert cache.get(key_SONNET).get("prompt") == prompt - assert cache.get(key_SONNET).get("llm_kwargs") == {} - assert cache.get(key_SONNET).get("model_name") == BedrockModels.CLAUDE_3_SONNET.value.name + assert cacheget(cache, key_HAIKU).get("result")["output"] == res_HAIKU + assert cacheget(cache, key_HAIKU).get("prompt") == prompt + assert cacheget(cache, key_HAIKU).get("llm_kwargs") == {} + assert cacheget(cache, key_HAIKU).get("model_name") == BedrockModels.CLAUDE_3_HAIKU.value.name + assert cacheget(cache, key_SONNET).get("result")["output"] == res_SONNET + assert cacheget(cache, key_SONNET).get("prompt") == prompt + assert cacheget(cache, key_SONNET).get("llm_kwargs") == {} + assert cacheget(cache, key_SONNET).get("model_name") == BedrockModels.CLAUDE_3_SONNET.value.name # check for difference with model change assert key_HAIKU != key_SONNET diff --git a/lib/sycamore/sycamore/tests/integration/llms/test_openai.py b/lib/sycamore/sycamore/tests/integration/llms/test_openai.py index 67152e1b0..70910f4a1 100644 --- a/lib/sycamore/sycamore/tests/integration/llms/test_openai.py +++ b/lib/sycamore/sycamore/tests/integration/llms/test_openai.py @@ -1,6 +1,8 @@ from pathlib import Path import pickle +import base64 import pytest +from typing import Any from sycamore.llms import OpenAI, OpenAIModels, OpenAIClientWrapper from sycamore.llms.openai import OpenAIModel, OpenAIClientType @@ -10,6 +12,16 @@ from pydantic import BaseModel +def cacheget(cache: DiskCache, key: str): + hit = cache.get(key) + return pickle.loads(base64.b64decode(hit)) # type: ignore + + +def cacheset(cache: DiskCache, key: str, data: Any): + databytes = pickle.dumps(data) + cache.set(key, base64.b64encode(databytes).decode("utf-8")) + + # Note: These tests expect you to have OPENAI_API_KEY set in your environment. 
@@ -54,11 +66,11 @@ def test_cached_openai(tmp_path: Path): res = llm.generate(prompt=prompt, llm_kwargs={}) # assert result is cached - assert cache.get(key).get("result") == res - assert cache.get(key).get("prompt") == prompt - assert cache.get(key).get("prompt.response_format") is None - assert cache.get(key).get("llm_kwargs") == {} - assert cache.get(key).get("model_name") == "gpt-3.5-turbo" + assert cacheget(cache, key).get("result") == res + assert cacheget(cache, key).get("prompt") == prompt + assert cacheget(cache, key).get("prompt.response_format") is None + assert cacheget(cache, key).get("llm_kwargs") == {} + assert cacheget(cache, key).get("model_name") == "gpt-3.5-turbo" # assert llm.generate is using cached result custom_output = { @@ -68,7 +80,7 @@ def test_cached_openai(tmp_path: Path): "llm_kwargs": {}, "model_name": "gpt-3.5-turbo", } - cache.set(key, custom_output) + cacheset(cache, key, custom_output) assert llm.generate(prompt=prompt, llm_kwargs={}) == custom_output["result"] @@ -83,12 +95,12 @@ def test_cached_guidance(tmp_path: Path): res = llm.generate(prompt=prompt, llm_kwargs=None) # assert result is cached - assert isinstance(cache.get(key), dict) - assert cache.get(key).get("result") == res - assert cache.get(key).get("prompt") == prompt - assert cache.get(key).get("prompt.response_format") is None - assert cache.get(key).get("llm_kwargs") is None - assert cache.get(key).get("model_name") == "gpt-3.5-turbo" + assert isinstance(cacheget(cache, key), dict) + assert cacheget(cache, key).get("result") == res + assert cacheget(cache, key).get("prompt") == prompt + assert cacheget(cache, key).get("prompt.response_format") is None + assert cacheget(cache, key).get("llm_kwargs") is None + assert cacheget(cache, key).get("model_name") == "gpt-3.5-turbo" # assert llm.generate is using cached result custom_output = { @@ -98,7 +110,7 @@ def test_cached_guidance(tmp_path: Path): "llm_kwargs": None, "model_name": "gpt-3.5-turbo", } - cache.set(key, custom_output) + cacheset(cache, key, custom_output) assert llm.generate(prompt=TestPrompt().render_generic(), llm_kwargs=None) == custom_output["result"] @@ -152,14 +164,14 @@ def test_cached_openai_different_models(tmp_path: Path): res_GPT_4O_MINI = llm_GPT_4O_MINI.generate(prompt=prompt, llm_kwargs={}) # check proper cached results - assert cache.get(key_GPT_3_5_TURBO).get("result") == res_GPT_3_5_TURBO - assert cache.get(key_GPT_3_5_TURBO).get("prompt") == prompt - assert cache.get(key_GPT_3_5_TURBO).get("llm_kwargs") == {} - assert cache.get(key_GPT_3_5_TURBO).get("model_name") == "gpt-3.5-turbo" - assert cache.get(key_GPT_4O_MINI).get("result") == res_GPT_4O_MINI - assert cache.get(key_GPT_4O_MINI).get("prompt") == prompt - assert cache.get(key_GPT_4O_MINI).get("llm_kwargs") == {} - assert cache.get(key_GPT_4O_MINI).get("model_name") == "gpt-4o-mini" + assert cacheget(cache, key_GPT_3_5_TURBO).get("result") == res_GPT_3_5_TURBO + assert cacheget(cache, key_GPT_3_5_TURBO).get("prompt") == prompt + assert cacheget(cache, key_GPT_3_5_TURBO).get("llm_kwargs") == {} + assert cacheget(cache, key_GPT_3_5_TURBO).get("model_name") == "gpt-3.5-turbo" + assert cacheget(cache, key_GPT_4O_MINI).get("result") == res_GPT_4O_MINI + assert cacheget(cache, key_GPT_4O_MINI).get("prompt") == prompt + assert cacheget(cache, key_GPT_4O_MINI).get("llm_kwargs") == {} + assert cacheget(cache, key_GPT_4O_MINI).get("model_name") == "gpt-4o-mini" # check for difference with model change assert key_GPT_3_5_TURBO != key_GPT_4O_MINI @@ -187,13 +199,13 @@ 
class Statement(BaseModel): print(res_GPT_4O_MINI) assert key_GPT_4O_MINI is not None # check cache - assert cache.get(key_GPT_4O_MINI).get("result") == res_GPT_4O_MINI - assert cache.get(key_GPT_4O_MINI).get("prompt") == RenderedPrompt(messages=prompt.messages) - assert cache.get(key_GPT_4O_MINI).get("prompt.response_format") == llm_GPT_4O_MINI._pickleable_response_format( - prompt - ) - assert cache.get(key_GPT_4O_MINI).get("llm_kwargs") == llm_kwargs_cached - assert cache.get(key_GPT_4O_MINI).get("model_name") == "gpt-4o-mini" + assert cacheget(cache, key_GPT_4O_MINI).get("result") == res_GPT_4O_MINI + assert cacheget(cache, key_GPT_4O_MINI).get("prompt") == RenderedPrompt(messages=prompt.messages) + assert cacheget(cache, key_GPT_4O_MINI).get( + "prompt.response_format" + ) == llm_GPT_4O_MINI._pickleable_response_format(prompt) + assert cacheget(cache, key_GPT_4O_MINI).get("llm_kwargs") == llm_kwargs_cached + assert cacheget(cache, key_GPT_4O_MINI).get("model_name") == "gpt-4o-mini" class TestPrompt(StaticPrompt): From dfb75402b2b0f45171aa721722af16beca7c5d8c Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 30 Jan 2025 11:50:39 -0800 Subject: [PATCH 45/46] fix docstrings Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 34 +++++++++++-------- lib/sycamore/sycamore/transforms/base_llm.py | 22 ++++++------ 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 05e7674bd..2baef85f2 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -240,6 +240,8 @@ class ElementListIterPrompt(ElementListPrompt): unset. Default is 35. element_batcher: Constructs batches of elements to render in sequence to generate several rendered prompts. Defaults to one batch with all elements. + iteration_var_name: Name of the property to look for in the document to determine + which batch of elements to use to render the prompt. Default is "i" **kwargs: other keyword arguments are stored and can be used as interpolation keys. 
Example: @@ -250,20 +252,21 @@ class ElementListIterPrompt(ElementListPrompt): user = "What is the capital of the country described?\\nElements:\\n{elements}" element_batcher = lambda elts: [elts[i:i+2] for i in range(0, len(elts), 2)] ).set(is_done=lambda s: s != 'None') + doc.properties["i"] = 0 prompt.render_document(doc) # [ - # [ - # {"role": "system", "content": "You are a program that returns 'None' if you don't - # know the answer to my question"}, - # {"role": "user", "content": "What is the capital of the country described?\\nElements:\\n - # ELEMENT 0: \\nELEMENT 1: "} - # ], - # [ - # {"role": "system", "content": "You are a program that returns 'None' if you don't - # know the answer to my question"}, - # {"role": "user", "content": "What is the capital of the country described?\\nElements:\\n - # ELEMENT 0: \\nELEMENT 1: "} - # ] + # {"role": "system", "content": "You are a program that returns 'None' if you don't + # know the answer to my question"}, + # {"role": "user", "content": "What is the capital of the country described?\\nElements:\\n + # ELEMENT 0: \\nELEMENT 1: "} + # ] + doc.properties["i"] = 1 + prompt.render_document(doc) + # [ + # {"role": "system", "content": "You are a program that returns 'None' if you don't + # know the answer to my question"}, + # {"role": "user", "content": "What is the capital of the country described?\\nElements:\\n + # ELEMENT 0: \\nELEMENT 1: "} # ] """ @@ -294,9 +297,10 @@ def render_document(self, doc: Document) -> RenderedPrompt: doc: The document to use as context for rendering this prompt Returns: - A list of two-message RenderedPrompts containing ``self.system.format()`` and - ``self.user.format()`` using the format keys as specified above. Each instance - is rendered from a batch of elements generated by ``self.element_batcher`` + A two-message RenderedPrompt containing ``self.system.format()`` and + ``self.user.format()`` using the format keys as specified above. The prompt is + rendered from the ``doc.properties[self.iteration_var_name]``'th batch of + elements generated by ``self.element_batcher`` """ i = doc.properties.get(self.iteration_var_name, 0) diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py index 7aa0f4bb2..043f7fb5e 100644 --- a/lib/sycamore/sycamore/transforms/base_llm.py +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -41,11 +41,12 @@ class LLMMap(MapBatch): llm: The llm to use for inference. llm_mode: How to call the llm - sync/async/batch. All LLMs do not necessarily implement all options. - postprocess_fn: function to call on documents after performing the - llm inference. If the prompt rendered into multiple RenderedPrompts, - ``i`` is the index of the RenderedPrompt that succeeded; if the - prompt rendered into an empty list, ``i`` is -1; and otherwise - ``i`` is 0 + iteration_var: Name of the document property to increment with every + invalid response. Default is None, which means no re-try. + validate: Function to determine whether an LLM response is valid. + Default is 'everything is valid' + max_tries: Hard limit on the number of LLM calls per document. Default + is 5 Example: .. code-block:: python @@ -132,11 +133,12 @@ class LLMMapElements(MapBatch): llm: The llm to use for inference. llm_mode: How to call the llm - sync/async/batch. All LLMs do not necessarily implement all options. - postprocess_fn: function to call on documents after performing the - llm inference. 
If the prompt rendered into multiple RenderedPrompts, - ``i`` is the index of the RenderedPrompt that succeeded; if the - prompt rendered into an empty list, ``i`` is -1; and otherwise - ``i`` is 0 + iteration_var: Name of the element property to increment with every + invalid response. Default is None, which means no re-try. + validate: Function to determine whether an LLM response is valid. + Default is 'everything is valid' + max_tries: Hard limit on the number of LLM calls per element. Default + is 5 Example: .. code-block:: python From f7c06e75f535b1889c606825ba41bf4061312522 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 30 Jan 2025 11:58:09 -0800 Subject: [PATCH 46/46] oops bad type signature Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/transforms/base_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py index 043f7fb5e..d1514fc45 100644 --- a/lib/sycamore/sycamore/transforms/base_llm.py +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -11,7 +11,7 @@ def _infer_prompts( prompts: list[RenderedPrompt], llm: LLM, llm_mode: LLMMode, -) -> list[tuple[str, int]]: +) -> list[str]: if llm_mode == LLMMode.SYNC: res = [] for p in prompts: