Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[llm unify 2/n] Implement llm_map(_elements) and move extract_entity to it. #1126

Merged
merged 49 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
c2a8cfa
add prompt base classes and ElementListPrompt
HenryL27 Jan 17, 2025
21a115a
override .instead in ElementListPrompt to store net-new keys in self.…
HenryL27 Jan 17, 2025
f94da80
add ElementPrompt and StaticPrompt
HenryL27 Jan 17, 2025
b73c162
add unit tests for prompts
HenryL27 Jan 21, 2025
17b2163
forgot to commit this
HenryL27 Jan 21, 2025
5d145d5
address pr comments; flatten properties with flatten_data
HenryL27 Jan 21, 2025
7fa2ff1
support multiple user prompts
HenryL27 Jan 21, 2025
abf9b0b
rename instead to set
HenryL27 Jan 22, 2025
9909c7e
Merge branch 'main' of github.com:aryn-ai/sycamore into hml-llm-unify
HenryL27 Jan 22, 2025
2d1315b
add LLMMap and LLMMapElements transforms
HenryL27 Jan 22, 2025
1853d51
Merge branch 'main' of github.com:aryn-ai/sycamore into hml-llm-unify
HenryL27 Jan 22, 2025
5e86e56
move llm implementations to use RenderedPrompts
HenryL27 Jan 22, 2025
27581ef
also this guy
HenryL27 Jan 22, 2025
739b672
add docset methods
HenryL27 Jan 23, 2025
73d9bdd
docstrings
HenryL27 Jan 23, 2025
ed8785e
add llm_map unit tests
HenryL27 Jan 23, 2025
523d6e3
fix bedrock tests and chaching
HenryL27 Jan 23, 2025
e1b3206
fix anthropic and bedrock ITs
HenryL27 Jan 23, 2025
6500e1c
adjust caching to handle pydantic class response format properly
HenryL27 Jan 23, 2025
f50032d
fix base llm unit tests
HenryL27 Jan 23, 2025
c3c7ea8
adjust all testing mock llms to updated llm interface
HenryL27 Jan 23, 2025
ffaaf0f
deprecate extract entity and implement it with llm_map
HenryL27 Jan 24, 2025
d71cf1a
add context_params decorator to llm_map
HenryL27 Jan 24, 2025
4225e11
revert extract_entity docset method re-implementation
HenryL27 Jan 24, 2025
0d39b27
add initial support for prompts that generate a sequence of rendered …
HenryL27 Jan 25, 2025
0b5ded4
add stuff to EntityExtractor/OpenAIEntityExtractor to convert to LLMMap
HenryL27 Jan 25, 2025
a52f7c2
make docset.extract_entity construct an LLMMap from its entity_extractor
HenryL27 Jan 25, 2025
3a9ac3c
get extract entity working with tokenizer and token limit
HenryL27 Jan 28, 2025
befc3d0
get all extract_entity unit tests passing
HenryL27 Jan 28, 2025
8bf42d5
fix llm_map_elements to deal with postprocess index
HenryL27 Jan 28, 2025
d7ff1eb
add postprocess_fn unit tests for llm_map
HenryL27 Jan 28, 2025
a7a2cc0
ruff complaint
HenryL27 Jan 28, 2025
ebf721e
fix docset unittests
HenryL27 Jan 28, 2025
0bd2a45
move a bunch of stuff back to llm.generate_old. This includes the act…
HenryL27 Jan 28, 2025
95cbaaf
move more stuff back to llm.generate_old
HenryL27 Jan 28, 2025
ea7f0e6
fix the last few mocks
HenryL27 Jan 28, 2025
2e51ee1
Merge branch 'main' of github.com:aryn-ai/sycamore into hml-llm-unify
HenryL27 Jan 28, 2025
57a4e4b
ruff linelength
HenryL27 Jan 28, 2025
a312ba3
mypy!!!
HenryL27 Jan 28, 2025
ebde879
type: ignore + line length is tricky
HenryL27 Jan 28, 2025
ff5efdc
fix generate_old with SimplePrompts
HenryL27 Jan 28, 2025
370e2b7
set openai system role name to system instead of developer like their…
HenryL27 Jan 28, 2025
98ce6a0
address simple pr comments
HenryL27 Jan 29, 2025
1789409
pickle stuff in llm caching path bc not everything is jsonifiable
HenryL27 Jan 30, 2025
8b6f085
rewrite llm_map to deal with iterative prompting better
HenryL27 Jan 30, 2025
763acc5
add a b64encode-to-str to cache bc you can't put bytes in json either
HenryL27 Jan 30, 2025
0331866
fix llm its to mimic the _llm_cache_set/get pickle/unpickle operations
HenryL27 Jan 30, 2025
dfb7540
fix docstrings
HenryL27 Jan 30, 2025
f7c06e7
oops bad type signature
HenryL27 Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions lib/sycamore/sycamore/docset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from sycamore.context import Context, context_params, OperationTypes
from sycamore.data import Document, Element, MetadataDocument
from sycamore.functions.tokenizer import Tokenizer
from sycamore.llms.llms import LLM
from sycamore.llms.llms import LLM, LLMMode
from sycamore.llms.prompts import SycamorePrompt
from sycamore.llms.prompts.default_prompts import (
LlmClusterEntityAssignGroupsMessagesPrompt,
LlmClusterEntityFormGroupsMessagesPrompt,
Expand All @@ -29,6 +30,7 @@
from sycamore.transforms.extract_table import TableExtractor
from sycamore.transforms.merge_elements import ElementMerger
from sycamore.utils.extract_json import extract_json
from sycamore.utils.deprecate import deprecated
from sycamore.transforms.query import QueryExecutor, Query
from sycamore.materialize_config import MaterializeSourceMode

Expand Down Expand Up @@ -465,6 +467,7 @@ def extract_document_structure(self, structure: DocumentStructure, **kwargs):
document_structure = ExtractDocumentStructure(self.plan, structure=structure, **kwargs)
return DocSet(self.context, document_structure)

@deprecated(version="0.1.31", reason="Use llm_map instead")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the plan also to deprecate extract_properties?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the plan is to deprecate just about everything on my sprint tasks

def extract_entity(self, entity_extractor: EntityExtractor, **kwargs) -> "DocSet":
"""
Applies the ExtractEntity transform on the Docset.
Expand All @@ -489,10 +492,8 @@ def extract_entity(self, entity_extractor: EntityExtractor, **kwargs) -> "DocSet
.extract_entity(entity_extractor=entity_extractor)

"""
from sycamore.transforms import ExtractEntity

entities = ExtractEntity(self.plan, context=self.context, entity_extractor=entity_extractor, **kwargs)
return DocSet(self.context, entities)
llm_map = entity_extractor.as_llm_map(self.plan, context=self.context, **kwargs)
return DocSet(self.context, llm_map)

def extract_schema(self, schema_extractor: SchemaExtractor, **kwargs) -> "DocSet":
"""
Expand Down Expand Up @@ -948,6 +949,42 @@ def custom_flat_mapping_function(document: Document) -> list[Document]:
flat_map = FlatMap(self.plan, f=f, **resource_args)
return DocSet(self.context, flat_map)

def llm_map(
self, prompt: SycamorePrompt, output_field: str, llm: LLM, llm_mode: LLMMode = LLMMode.SYNC, **kwargs
) -> "DocSet":
"""
Renders and runs a prompt on every Document of the DocSet.

Args:
prompt: The prompt to use. Must implement the ``render_document`` method
output_field: Field in properties to store the output.
llm: LLM to use for the inferences.
llm_mode: how to make the api calls to the llm - sync/async/batch
"""
from sycamore.transforms.base_llm import LLMMap

llm_map = LLMMap(self.plan, prompt=prompt, output_field=output_field, llm=llm, llm_mode=llm_mode, **kwargs)
return DocSet(self.context, llm_map)

def llm_map_elements(
self, prompt: SycamorePrompt, output_field: str, llm: LLM, llm_mode: LLMMode = LLMMode.SYNC, **kwargs
) -> "DocSet":
"""
Renders and runs a prompt on every Element of every Document in the DocSet.

Args:
prompt: The prompt to use. Must implement the ``render_document`` method
output_field: Field in properties to store the output.
llm: LLM to use for the inferences.
llm_mode: how to make the api calls to the llm - sync/async/batch
"""
from sycamore.transforms.base_llm import LLMMapElements

llm_map_elements = LLMMapElements(
self.plan, prompt=prompt, output_field=output_field, llm=llm, llm_mode=llm_mode, **kwargs
)
return DocSet(self.context, llm_map_elements)

def filter(self, f: Callable[[Document], bool], **kwargs) -> "DocSet":
"""
Applies the Filter transform on the Docset.
Expand Down Expand Up @@ -1356,7 +1393,7 @@ def llm_cluster_entity(self, llm: LLM, instruction: str, field: str, **kwargs) -
prompt_kwargs = {"messages": messages}

# call to LLM
completion = llm.generate(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0})
completion = llm.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0})

groups = extract_json(completion)

Expand Down
6 changes: 3 additions & 3 deletions lib/sycamore/sycamore/evaluation/subtasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sycamore.docset import DocSet
from sycamore.llms.llms import LLM
from sycamore.llms.openai import OpenAI, OpenAIModels
from sycamore.llms.prompts.default_prompts import SimpleGuidancePrompt, TaskIdentifierZeroShotGuidancePrompt
from sycamore.llms.prompts.default_prompts import SimplePrompt, _TaskIdentifierZeroShotGuidancePrompt
from sycamore.transforms.embed import Embedder, SentenceTransformerEmbedder
from sycamore.transforms.query import QueryExecutor

Expand All @@ -22,7 +22,7 @@ def __init__(
model_name="sentence-transformers/all-MiniLM-L6-v2", batch_size=100
),
llm: LLM = OpenAI(OpenAIModels.GPT_3_5_TURBO.value),
prompt: SimpleGuidancePrompt = TaskIdentifierZeroShotGuidancePrompt(),
prompt: SimplePrompt = _TaskIdentifierZeroShotGuidancePrompt(),
knn_query: bool = False,
):
if subtask_data:
Expand All @@ -44,7 +44,7 @@ def __init__(
def _get_formulas(self, document: Document) -> list[Document]:
f_list = []
if document.properties["subtasks_reqd"]:
task_id = self._llm.generate(
task_id = self._llm.generate_old(
prompt_kwargs={
"prompt": self._prompt,
"question": document["question"],
Expand Down
69 changes: 47 additions & 22 deletions lib/sycamore/sycamore/llms/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from PIL import Image

from sycamore.llms.llms import LLM
from sycamore.llms.prompts.default_prompts import SimplePrompt
from sycamore.llms.prompts import RenderedPrompt
from sycamore.utils.cache import Cache
from sycamore.utils.image_utils import base64_data
from sycamore.utils.import_utils import requires_modules
Expand Down Expand Up @@ -49,29 +49,54 @@ def rewrite_system_messages(messages: Optional[list[dict]]) -> Optional[list[dic
return [m for m in messages if m.get("role") != "system"]


def get_generate_kwargs(prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> dict:
def get_generate_kwargs(prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
kwargs = {
"temperature": 0,
**(llm_kwargs or {}),
}

kwargs["max_tokens"] = kwargs.get("max_tokens", DEFAULT_MAX_TOKENS)

if "prompt" in prompt_kwargs:
prompt = prompt_kwargs.get("prompt")

if isinstance(prompt, SimplePrompt):
kwargs.update({"messages": prompt.as_messages(prompt_kwargs)})
# Anthropic models require _exactly_ alternation between "user" and "assistant"
# roles, so we break the messages into groups of consecutive user/assistant
# messages, treating "system" as "user". Then crunch each group down to a single
# message to ensure alternation.
message_groups = [] # type: ignore
last_role = None

for m in prompt.messages:
r = m.role
if r == "system":
r = "user"
if r != last_role:
message_groups.append([])
message_groups[-1].append(m)
last_role = r

messages = []
for group in message_groups:
role = group[0].role
if role == "system":
role = "user"
content = "\n".join(m.content for m in group)
if any(m.images is not None for m in group):
images = [im for m in group for im in m.images]
contents = [{"type": "text", "text": content}]
for im in images:
contents.append(
{ # type: ignore
"type": "image",
"source": { # type: ignore
"type": "base64",
"media_type": "image/png",
"data": base64_data(im),
},
}
)
messages.append({"role": role, "content": contents})
else:
kwargs.update({"messages": [{"role": "user", "content": f"{prompt}"}]})

elif "messages" in prompt_kwargs:
kwargs.update({"messages": prompt_kwargs["messages"]})
else:
raise ValueError("Either prompt or messages must be present in prompt_kwargs.")

kwargs["messages"] = rewrite_system_messages(kwargs["messages"])
messages.append({"role": role, "content": content})

kwargs["messages"] = messages
return kwargs


Expand Down Expand Up @@ -128,12 +153,12 @@ def is_chat_mode(self) -> bool:
def format_image(self, image: Image.Image) -> dict[str, Any]:
return format_image(image)

def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> dict:
ret = self._llm_cache_get(prompt_kwargs, llm_kwargs)
def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
ret = self._llm_cache_get(prompt, llm_kwargs)
if isinstance(ret, dict):
return ret

kwargs = get_generate_kwargs(prompt_kwargs, llm_kwargs)
kwargs = get_generate_kwargs(prompt, llm_kwargs)

start = datetime.now()

Expand All @@ -153,9 +178,9 @@ def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] =
self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
logging.debug(f"Generated response from Anthropic model: {ret}")

self._llm_cache_set(prompt_kwargs, llm_kwargs, ret)
self._llm_cache_set(prompt, llm_kwargs, ret)
return ret

def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str:
d = self.generate_metadata(prompt_kwargs=prompt_kwargs, llm_kwargs=llm_kwargs)
def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
d = self.generate_metadata(prompt=prompt, llm_kwargs=llm_kwargs)
return d["output"]
13 changes: 7 additions & 6 deletions lib/sycamore/sycamore/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from sycamore.llms.llms import LLM
from sycamore.llms.anthropic import format_image, get_generate_kwargs
from sycamore.llms.prompts.prompts import RenderedPrompt
from sycamore.utils.cache import Cache

DEFAULT_MAX_TOKENS = 1000
Expand Down Expand Up @@ -77,14 +78,14 @@ def format_image(self, image: Image.Image) -> dict[str, Any]:
return format_image(image)
raise NotImplementedError("Images not supported for non-Anthropic Bedrock models.")

def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> dict:
ret = self._llm_cache_get(prompt_kwargs, llm_kwargs)
def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
ret = self._llm_cache_get(prompt, llm_kwargs)
if isinstance(ret, dict):
print(f"cache return {ret}")
return ret
assert ret is None

kwargs = get_generate_kwargs(prompt_kwargs, llm_kwargs)
kwargs = get_generate_kwargs(prompt, llm_kwargs)
if self._model_name.startswith("anthropic."):
anthropic_version = (
DEFAULT_ANTHROPIC_VERSION
Expand Down Expand Up @@ -115,9 +116,9 @@ def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] =
"out_tokens": out_tokens,
}
self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
self._llm_cache_set(prompt_kwargs, llm_kwargs, ret)
self._llm_cache_set(prompt, llm_kwargs, ret)
return ret

def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str:
d = self.generate_metadata(prompt_kwargs=prompt_kwargs, llm_kwargs=llm_kwargs)
def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
d = self.generate_metadata(prompt=prompt, llm_kwargs=llm_kwargs)
return d["output"]
Loading
Loading