capturing metadata from LLMs (#1122)
Soeb-aryn authored Jan 28, 2025
1 parent 1904622 commit 088f4b6
Showing 6 changed files with 123 additions and 18 deletions.
3 changes: 1 addition & 2 deletions lib/sycamore/sycamore/llms/anthropic.py
@@ -140,7 +140,6 @@ def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] =
response = self._client.messages.create(model=self.model.value, **kwargs)

wall_latency = datetime.now() - start

in_tokens = response.usage.input_tokens
out_tokens = response.usage.output_tokens
output = response.content[0].text
@@ -151,7 +150,7 @@ def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] =
"in_tokens": in_tokens,
"out_tokens": out_tokens,
}

self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
logging.debug(f"Generated response from Anthropic model: {ret}")

self._llm_cache_set(prompt_kwargs, llm_kwargs, ret)
1 change: 1 addition & 0 deletions lib/sycamore/sycamore/llms/bedrock.py
@@ -114,6 +114,7 @@ def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] =
"in_tokens": in_tokens,
"out_tokens": out_tokens,
}
self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
self._llm_cache_set(prompt_kwargs, llm_kwargs, ret)
return ret

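The Anthropic and Bedrock hunks above share the pattern this commit applies to every provider: time the call, read the provider's token counts, then hand everything to add_llm_metadata (added to the LLM base class in llms.py below). A minimal sketch of that shape, assuming a stand-in client rather than any real SDK and using FakeLLM only as a convenient carrier of add_llm_metadata:

from datetime import datetime

from sycamore.llms.llms import FakeLLM  # any LLM subclass inherits add_llm_metadata


class _StubClient:
    """Stand-in for a provider SDK client; returns a canned response."""

    def complete(self, **kwargs):
        return {"text": "hello", "usage": {"input_tokens": 4, "output_tokens": 2}}


def generate_with_metadata(llm, client, kwargs):
    start = datetime.now()
    response = client.complete(**kwargs)            # the real provider call goes here
    wall_latency = datetime.now() - start
    in_tokens = response["usage"]["input_tokens"]   # usage field names vary by provider
    out_tokens = response["usage"]["output_tokens"]
    output = response["text"]
    # Records metadata only if the surrounding pipeline opted in via the thread-local flag.
    llm.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
    return output


print(generate_with_metadata(FakeLLM(), _StubClient(), {"prompt": "hi", "temperature": 0.2}))

The names _StubClient and generate_with_metadata are illustrative only; the real provider classes wire this into their existing generate paths as the diffs show.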
23 changes: 23 additions & 0 deletions lib/sycamore/sycamore/llms/llms.py
@@ -3,6 +3,8 @@
from PIL import Image
from typing import Any, Optional
from sycamore.utils.cache import Cache
from sycamore.utils.thread_local import ThreadLocalAccess, ADD_METADATA_TO_OUTPUT
from sycamore.data.metadata import add_metadata


class LLM(ABC):
@@ -91,6 +93,27 @@ def _llm_cache_set(self, prompt_kwargs: dict, llm_kwargs: Optional[dict], result
},
)

def get_metadata(self, kwargs, response_text, wall_latency, in_tokens, out_tokens) -> dict:
"""Generate metadata for the LLM response."""
return {
"model": self._model_name,
"temperature": kwargs.get("temperature", None),
"usage": {
"completion_tokens": in_tokens,
"prompt_tokens": out_tokens,
"total_tokens": in_tokens + out_tokens,
},
"wall_latency": wall_latency,
"prompt": kwargs.get("prompt") or kwargs.get("messages"),
"output": response_text,
}

def add_llm_metadata(self, kwargs, output, wall_latency, in_tokens, out_tokens):
tls = ThreadLocalAccess(ADD_METADATA_TO_OUTPUT)
if tls.present():
metadata = self.get_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
add_metadata(**metadata)


class FakeLLM(LLM):
"""Useful for tests where the fake LLM needs to run in a ray function because mocks are not serializable"""
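add_llm_metadata only records anything when ThreadLocalAccess(ADD_METADATA_TO_OUTPUT).present() is true, i.e. when the calling pipeline has opted in; otherwise the LLM call behaves exactly as before. A rough sketch of that gating idea in plain Python, using threading.local directly rather than Sycamore's ThreadLocalAccess (whose internals are assumed here), with illustrative names:

import threading

_tls = threading.local()


def add_metadata_if_requested(metadata: dict) -> None:
    # Only record when a pipeline has installed a sink on this thread.
    sink = getattr(_tls, "metadata_sink", None)
    if sink is not None:
        sink.append(metadata)


# A pipeline that wants LLM telemetry installs a sink before invoking the LLM:
_tls.metadata_sink = []
add_metadata_if_requested({"usage": {"total_tokens": 8}})
print(len(_tls.metadata_sink))  # 1

# Without a sink the call is a no-op, so existing callers are unaffected.
del _tls.metadata_sink
add_metadata_if_requested({"usage": {"total_tokens": 8}})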
65 changes: 52 additions & 13 deletions lib/sycamore/sycamore/llms/openai.py
@@ -5,7 +5,8 @@
from dataclasses import dataclass
from enum import Enum
from PIL import Image
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union
from datetime import datetime

from openai import AzureOpenAI as AzureOpenAIClient
from openai import AsyncAzureOpenAI as AsyncAzureOpenAIClient
@@ -23,7 +24,6 @@
from sycamore.utils.cache import Cache
from sycamore.utils.image_utils import base64_data_url


logger = logging.getLogger(__name__)


@@ -291,6 +291,15 @@ def is_chat_mode(self):
def format_image(self, image: Image.Image) -> dict[str, Any]:
return {"type": "image_url", "image_url": {"url": base64_data_url(image)}}

def validate_tokens(self, completion) -> Tuple[int, int]:
if completion.usage is not None:
completion_tokens = completion.usage.completion_tokens or 0
prompt_tokens = completion.usage.prompt_tokens or 0
else:
completion_tokens = 0
prompt_tokens = 0
return completion_tokens, prompt_tokens

def _convert_response_format(self, llm_kwargs: Optional[Dict]) -> Optional[Dict]:
"""Convert the response_format parameter to the appropriate OpenAI format."""
if llm_kwargs is None:
@@ -365,26 +374,41 @@ def _generate_using_openai(self, prompt_kwargs, llm_kwargs) -> str:
kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
logging.debug("OpenAI prompt: %s", kwargs)
if self.is_chat_mode():
starttime = datetime.now()
completion = self.client_wrapper.get_client().chat.completions.create(model=self._model_name, **kwargs)
logging.debug("OpenAI completion: %s", completion)
return completion.choices[0].message.content
wall_latency = datetime.now() - starttime
response_text = completion.choices[0].message.content
else:
starttime = datetime.now()
completion = self.client_wrapper.get_client().completions.create(model=self._model_name, **kwargs)
logging.debug("OpenAI completion: %s", completion)
return completion.choices[0].text
wall_latency = datetime.now() - starttime
response_text = completion.choices[0].text

completion_tokens, prompt_tokens = self.validate_tokens(completion)
self.add_llm_metadata(kwargs, response_text, wall_latency, completion_tokens, prompt_tokens)
if not response_text:
raise ValueError("OpenAI returned empty response")
return response_text

def _generate_using_openai_structured(self, prompt_kwargs, llm_kwargs) -> str:
try:
kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
if self.is_chat_mode():
starttime = datetime.now()
completion = self.client_wrapper.get_client().beta.chat.completions.parse(
model=self._model_name, **kwargs
)
completion_tokens, prompt_tokens = self.validate_tokens(completion)
wall_latency = datetime.now() - starttime
response_text = completion.choices[0].message.content
self.add_llm_metadata(kwargs, response_text, wall_latency, completion_tokens, prompt_tokens)
else:
raise ValueError("This method doesn't support instruct models. Please use a chat model.")
# completion = self.client_wrapper.get_client().beta.completions.parse(model=self._model_name, **kwargs)
assert completion.choices[0].message.content is not None, "OpenAI refused to respond to the query"
return completion.choices[0].message.content
assert response_text is not None, "OpenAI refused to respond to the query"
return response_text
except Exception as e:
# OpenAI will not respond in two scenarios:
# 1.) The LLM ran out of output context length (usually due to hallucination of repeating the same phrase)
@@ -408,31 +432,46 @@ async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict

async def _generate_awaitable_using_openai(self, prompt_kwargs, llm_kwargs) -> str:
kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
starttime = datetime.now()
if self.is_chat_mode():
completion = await self.client_wrapper.get_async_client().chat.completions.create(
model=self._model_name, **kwargs
)
return completion.choices[0].message.content
response_text = completion.choices[0].message.content
else:
completion = await self.client_wrapper.get_async_client().completions.create(
model=self._model_name, **kwargs
)
return completion.choices[0].text
response_text = completion.choices[0].text
wall_latency = datetime.now() - starttime

if completion.usage is not None:
completion_tokens = completion.usage.completion_tokens or 0
prompt_tokens = completion.usage.prompt_tokens or 0
else:
completion_tokens = 0
prompt_tokens = 0

self.add_llm_metadata(kwargs, response_text, wall_latency, completion_tokens, prompt_tokens)
return response_text

async def _generate_awaitable_using_openai_structured(self, prompt_kwargs, llm_kwargs) -> str:
try:
kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
if self.is_chat_mode():
starttime = datetime.now()
completion = await self.client_wrapper.get_async_client().beta.chat.completions.parse(
model=self._model_name, **kwargs
)
wall_latency = datetime.now() - starttime
else:
raise ValueError("This method doesn't support instruct models. Please use a chat model.")
# completion = await self.client_wrapper.get_async_client().beta.completions.parse(
# model=self._model_name, **kwargs
# )
assert completion.choices[0].message.content is not None, "OpenAI refused to respond to the query"
return completion.choices[0].message.content
response_text = completion.choices[0].message.content
assert response_text is not None, "OpenAI refused to respond to the query"
completion_tokens, prompt_tokens = self.validate_tokens(completion)
self.add_llm_metadata(kwargs, response_text, wall_latency, completion_tokens, prompt_tokens)
return response_text
except Exception as e:
# OpenAI will not respond in two scenarios:
# 1.) The LLM ran out of output context length (usually due to hallucination of repeating the same phrase)
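validate_tokens exists because completion.usage can be missing on some responses, and the metadata path should never fail the underlying call. A self-contained sketch of that defensive accounting, with small dataclasses standing in for the OpenAI response objects (the real types come from the openai package):

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Usage:
    completion_tokens: Optional[int] = None
    prompt_tokens: Optional[int] = None


@dataclass
class Completion:
    usage: Optional[Usage] = None


def validate_tokens(completion: Completion) -> Tuple[int, int]:
    # Fall back to zero when the provider omits usage information entirely.
    if completion.usage is not None:
        return completion.usage.completion_tokens or 0, completion.usage.prompt_tokens or 0
    return 0, 0


print(validate_tokens(Completion(Usage(completion_tokens=12, prompt_tokens=40))))  # (12, 40)
print(validate_tokens(Completion()))  # (0, 0)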
(integration test file for property extraction; file name not captured in this view)
@@ -43,12 +43,16 @@ def test_extract_properties_from_dict_schema(llm):
docs = ctx.read.document(docs)
docs = docs.extract_properties(property_extractor)

taken = docs.take_all()
taken = docs.take_all(include_metadata=True)

assert taken[0].properties["entity"]["name"] == "Vinayak"
assert taken[0].properties["entity"]["age"] == 74
assert "Honolulu" in taken[0].properties["entity"]["from_location"]

assert len(taken) == 3
assert taken[2].metadata["usage"]["prompt_tokens"] > 0
assert taken[2].metadata["usage"]["completion_tokens"] > 0


@pytest.mark.parametrize("llm", llms)
def test_extract_properties_from_schema(llm):
@@ -61,6 +65,7 @@ def test_extract_properties_from_schema(llm):
field_type="str",
description="This is the name of an entity",
examples=["Mark", "Ollie", "Winston"],
default="null",
),
SchemaField(name="age", field_type="int", default=999),
SchemaField(name="date", field_type="str", description="Any date in the doc in YYYY-MM-DD format"),
@@ -80,14 +85,20 @@ def test_extract_properties_from_schema(llm):
docs = ctx.read.document(docs)
docs = docs.extract_properties(property_extractor)

taken = docs.take_all()
taken = docs.take_all(include_metadata=True)

assert taken[0].properties["entity"]["name"] == "Vinayak"
assert taken[0].properties["entity"]["age"] == 74
assert taken[0].properties["entity"]["from_location"] == "Honolulu, HI", "Invalid location extracted or formatted"
assert taken[0].properties["entity"]["date"] == "1923-02-24"

assert taken[1].properties["entity"]["name"] is None, "Default None value not being used correctly"
assert taken[1].properties["entity"]["name"] == "None" # Anthropic isn't generating valid JSON with null values.
assert taken[1].properties["entity"]["age"] == 999, "Default value not being used correctly"
assert taken[1].properties["entity"]["from_location"] == "New Delhi"
assert taken[1].properties["entity"]["date"] == "2014-01-11"

assert len(taken) == 5
assert taken[3].metadata["usage"]["prompt_tokens"] > 0
assert taken[3].metadata["usage"]["completion_tokens"] > 0
assert taken[4].metadata["usage"]["prompt_tokens"] > 0
assert taken[4].metadata["usage"]["completion_tokens"] > 0
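
With take_all(include_metadata=True), the returned list interleaves documents with metadata records whose usage dict the assertions above inspect. A hedged helper for totalling token usage over such a result; the metadata attribute layout is inferred from these test expectations rather than from a documented API:

def total_tokens(taken) -> int:
    """Sum prompt and completion tokens over any metadata records in taken."""
    total = 0
    for item in taken:
        md = getattr(item, "metadata", None)
        if isinstance(md, dict) and "usage" in md:
            usage = md["usage"]
            total += usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0)
    return total

# e.g. in the tests above: assert total_tokens(taken) > 0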
32 changes: 32 additions & 0 deletions lib/sycamore/sycamore/tests/unit/llms/test_llms.py
@@ -5,6 +5,38 @@
from sycamore.llms.llms import FakeLLM
from sycamore.llms.prompts import EntityExtractorFewShotGuidancePrompt, EntityExtractorZeroShotGuidancePrompt
from sycamore.utils.cache import DiskCache
import datetime
from sycamore.utils.thread_local import ThreadLocalAccess


def test_get_metadata():
llm = FakeLLM()
wall_latency = datetime.timedelta(seconds=1)
metadata = llm.get_metadata({"prompt": "Hello", "temperature": 0.7}, "Test output", wall_latency, 10, 5)
assert metadata["model"] == llm._model_name
assert metadata["usage"] == {
"completion_tokens": 10,
"prompt_tokens": 5,
"total_tokens": 15,
}
assert metadata["prompt"] == "Hello"
assert metadata["output"] == "Test output"
assert metadata["temperature"] == 0.7
assert metadata["wall_latency"] == wall_latency


@patch("sycamore.llms.llms.add_metadata")
def test_add_llm_metadata(mock_add_metadata):
llm = FakeLLM()
with patch.object(ThreadLocalAccess, "present", return_value=True):
llm.add_llm_metadata({}, "Test output", datetime.timedelta(seconds=0.5), 1, 2)
mock_add_metadata.assert_called_once()

# If TLS not present, add_metadata should not be called
mock_add_metadata.reset_mock()
with patch.object(ThreadLocalAccess, "present", return_value=False):
llm.add_llm_metadata({}, "Test output", datetime.timedelta(seconds=0.5), 1, 2)
mock_add_metadata.assert_not_called()


def test_openai_davinci_fallback():