diff --git a/lib/sycamore/sycamore/llms/anthropic.py b/lib/sycamore/sycamore/llms/anthropic.py
index 81bc4903d..ec400e9a0 100644
--- a/lib/sycamore/sycamore/llms/anthropic.py
+++ b/lib/sycamore/sycamore/llms/anthropic.py
@@ -140,7 +140,6 @@ def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] =
         response = self._client.messages.create(model=self.model.value, **kwargs)
 
         wall_latency = datetime.now() - start
-
         in_tokens = response.usage.input_tokens
         out_tokens = response.usage.output_tokens
         output = response.content[0].text
@@ -151,7 +150,7 @@ def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] =
             "in_tokens": in_tokens,
             "out_tokens": out_tokens,
         }
-
+        self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
         logging.debug(f"Generated response from Anthropic model: {ret}")
         self._llm_cache_set(prompt_kwargs, llm_kwargs, ret)
 
diff --git a/lib/sycamore/sycamore/llms/bedrock.py b/lib/sycamore/sycamore/llms/bedrock.py
index a7d115540..07855062f 100644
--- a/lib/sycamore/sycamore/llms/bedrock.py
+++ b/lib/sycamore/sycamore/llms/bedrock.py
@@ -114,6 +114,7 @@ def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] =
             "in_tokens": in_tokens,
             "out_tokens": out_tokens,
         }
+        self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
 
         self._llm_cache_set(prompt_kwargs, llm_kwargs, ret)
         return ret
diff --git a/lib/sycamore/sycamore/llms/llms.py b/lib/sycamore/sycamore/llms/llms.py
index dc0541862..a3e1c4743 100644
--- a/lib/sycamore/sycamore/llms/llms.py
+++ b/lib/sycamore/sycamore/llms/llms.py
@@ -3,6 +3,8 @@
 from PIL import Image
 from typing import Any, Optional
 from sycamore.utils.cache import Cache
+from sycamore.utils.thread_local import ThreadLocalAccess, ADD_METADATA_TO_OUTPUT
+from sycamore.data.metadata import add_metadata
 
 
 class LLM(ABC):
@@ -91,6 +93,27 @@ def _llm_cache_set(self, prompt_kwargs: dict, llm_kwargs: Optional[dict], result
             },
         )
 
+    def get_metadata(self, kwargs, response_text, wall_latency, in_tokens, out_tokens) -> dict:
+        """Generate metadata for the LLM response."""
+        return {
+            "model": self._model_name,
+            "temperature": kwargs.get("temperature", None),
+            "usage": {
+                "completion_tokens": in_tokens,
+                "prompt_tokens": out_tokens,
+                "total_tokens": in_tokens + out_tokens,
+            },
+            "wall_latency": wall_latency,
+            "prompt": kwargs.get("prompt") or kwargs.get("messages"),
+            "output": response_text,
+        }
+
+    def add_llm_metadata(self, kwargs, output, wall_latency, in_tokens, out_tokens):
+        tls = ThreadLocalAccess(ADD_METADATA_TO_OUTPUT)
+        if tls.present():
+            metadata = self.get_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
+            add_metadata(**metadata)
+
 
 class FakeLLM(LLM):
     """Useful for tests where the fake LLM needs to run in a ray function because mocks are not serializable"""
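
Not part of the patch: a minimal sketch of what the two new base-class hooks do, using FakeLLM from the unit tests below. The literal token values are illustrative; only the dictionary keys and the thread-local gate come from the diff.

    from datetime import timedelta

    from sycamore.llms.llms import FakeLLM

    llm = FakeLLM()

    # get_metadata only assembles the dict; add_llm_metadata additionally checks the
    # ADD_METADATA_TO_OUTPUT thread-local and, when present, forwards it to add_metadata().
    md = llm.get_metadata(
        {"prompt": "Hello", "temperature": 0.7},  # kwargs that were sent to the client
        "Test output",                            # response text
        timedelta(seconds=1),                     # wall-clock latency
        10,                                       # in_tokens
        5,                                        # out_tokens
    )
    assert md["model"] == llm._model_name
    assert md["usage"]["total_tokens"] == 15

    # Outside a scope that sets ADD_METADATA_TO_OUTPUT, this call is a no-op.
    llm.add_llm_metadata({"prompt": "Hello"}, "Test output", timedelta(seconds=1), 10, 5)
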
diff --git a/lib/sycamore/sycamore/llms/openai.py b/lib/sycamore/sycamore/llms/openai.py
index f90d3e1d9..6e44fd35b 100644
--- a/lib/sycamore/sycamore/llms/openai.py
+++ b/lib/sycamore/sycamore/llms/openai.py
@@ -5,7 +5,8 @@
 from dataclasses import dataclass
 from enum import Enum
 from PIL import Image
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
+from datetime import datetime
 
 from openai import AzureOpenAI as AzureOpenAIClient
 from openai import AsyncAzureOpenAI as AsyncAzureOpenAIClient
@@ -23,7 +24,6 @@
 from sycamore.utils.cache import Cache
 from sycamore.utils.image_utils import base64_data_url
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -291,6 +291,15 @@ def is_chat_mode(self):
     def format_image(self, image: Image.Image) -> dict[str, Any]:
         return {"type": "image_url", "image_url": {"url": base64_data_url(image)}}
 
+    def validate_tokens(self, completion) -> Tuple[int, int]:
+        if completion.usage is not None:
+            completion_tokens = completion.usage.completion_tokens or 0
+            prompt_tokens = completion.usage.prompt_tokens or 0
+        else:
+            completion_tokens = 0
+            prompt_tokens = 0
+        return completion_tokens, prompt_tokens
+
     def _convert_response_format(self, llm_kwargs: Optional[Dict]) -> Optional[Dict]:
         """Convert the response_format parameter to the appropriate OpenAI format."""
         if llm_kwargs is None:
@@ -365,26 +374,41 @@ def _generate_using_openai(self, prompt_kwargs, llm_kwargs) -> str:
         kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
         logging.debug("OpenAI prompt: %s", kwargs)
         if self.is_chat_mode():
+            starttime = datetime.now()
             completion = self.client_wrapper.get_client().chat.completions.create(model=self._model_name, **kwargs)
             logging.debug("OpenAI completion: %s", completion)
-            return completion.choices[0].message.content
+            wall_latency = datetime.now() - starttime
+            response_text = completion.choices[0].message.content
         else:
+            starttime = datetime.now()
             completion = self.client_wrapper.get_client().completions.create(model=self._model_name, **kwargs)
             logging.debug("OpenAI completion: %s", completion)
-            return completion.choices[0].text
+            wall_latency = datetime.now() - starttime
+            response_text = completion.choices[0].text
+
+        completion_tokens, prompt_tokens = self.validate_tokens(completion)
+        self.add_llm_metadata(kwargs, response_text, wall_latency, completion_tokens, prompt_tokens)
+        if not response_text:
+            raise ValueError("OpenAI returned empty response")
+        return response_text
 
     def _generate_using_openai_structured(self, prompt_kwargs, llm_kwargs) -> str:
         try:
             kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
             if self.is_chat_mode():
+                starttime = datetime.now()
                 completion = self.client_wrapper.get_client().beta.chat.completions.parse(
                     model=self._model_name, **kwargs
                 )
+                completion_tokens, prompt_tokens = self.validate_tokens(completion)
+                wall_latency = datetime.now() - starttime
+                response_text = completion.choices[0].message.content
+                self.add_llm_metadata(kwargs, response_text, wall_latency, completion_tokens, prompt_tokens)
             else:
                 raise ValueError("This method doesn't support instruct models. Please use a chat model.")
                 # completion = self.client_wrapper.get_client().beta.completions.parse(model=self._model_name, **kwargs)
-            assert completion.choices[0].message.content is not None, "OpenAI refused to respond to the query"
-            return completion.choices[0].message.content
+            assert response_text is not None, "OpenAI refused to respond to the query"
+            return response_text
         except Exception as e:
             # OpenAI will not respond in two scenarios:
             # 1.) The LLM ran out of output context length(usually do to hallucination of repeating the same phrase)
@@ -408,31 +432,45 @@ async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict
 
     async def _generate_awaitable_using_openai(self, prompt_kwargs, llm_kwargs) -> str:
         kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
+        starttime = datetime.now()
         if self.is_chat_mode():
             completion = await self.client_wrapper.get_async_client().chat.completions.create(
                 model=self._model_name, **kwargs
             )
-            return completion.choices[0].message.content
+            response_text = completion.choices[0].message.content
         else:
             completion = await self.client_wrapper.get_async_client().completions.create(
                 model=self._model_name, **kwargs
             )
-            return completion.choices[0].text
+            response_text = completion.choices[0].text
+        wall_latency = datetime.now() - starttime
+
+        if completion.usage is not None:
+            completion_tokens = completion.usage.completion_tokens or 0
+            prompt_tokens = completion.usage.prompt_tokens or 0
+        else:
+            completion_tokens = 0
+            prompt_tokens = 0
+
+        self.add_llm_metadata(kwargs, response_text, wall_latency, completion_tokens, prompt_tokens)
+        return response_text
 
     async def _generate_awaitable_using_openai_structured(self, prompt_kwargs, llm_kwargs) -> str:
         try:
             kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
             if self.is_chat_mode():
+                starttime = datetime.now()
                 completion = await self.client_wrapper.get_async_client().beta.chat.completions.parse(
                     model=self._model_name, **kwargs
                 )
+                wall_latency = datetime.now() - starttime
             else:
                 raise ValueError("This method doesn't support instruct models. Please use a chat model.")
-            # completion = await self.client_wrapper.get_async_client().beta.completions.parse(
-            #     model=self._model_name, **kwargs
-            # )
-            assert completion.choices[0].message.content is not None, "OpenAI refused to respond to the query"
-            return completion.choices[0].message.content
+            response_text = completion.choices[0].message.content
+            assert response_text is not None, "OpenAI refused to respond to the query"
+            completion_tokens, prompt_tokens = self.validate_tokens(completion)
+            self.add_llm_metadata(kwargs, response_text, wall_latency, completion_tokens, prompt_tokens)
+            return response_text
         except Exception as e:
             # OpenAI will not respond in two scenarios:
             # 1.) The LLM ran out of output context length(usually do to hallucination of repeating the same phrase)
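
Also not part of the patch, a small sketch of the fallback behavior validate_tokens introduces. The helper below mirrors the method rather than importing the OpenAI client, and SimpleNamespace objects stand in for the completion object; everything outside the diff is an assumption.

    from types import SimpleNamespace

    def validate_tokens(completion):
        # Mirrors OpenAI.validate_tokens from the hunk above: tolerate responses without usage info.
        if completion.usage is not None:
            return completion.usage.completion_tokens or 0, completion.usage.prompt_tokens or 0
        return 0, 0

    with_usage = SimpleNamespace(usage=SimpleNamespace(completion_tokens=42, prompt_tokens=7))
    without_usage = SimpleNamespace(usage=None)

    assert validate_tokens(with_usage) == (42, 7)
    assert validate_tokens(without_usage) == (0, 0)  # missing usage falls back to zero counts
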
diff --git a/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py b/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py
index cb6e226a6..e3f415417 100644
--- a/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py
+++ b/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py
@@ -43,12 +43,16 @@ def test_extract_properties_from_dict_schema(llm):
     docs = ctx.read.document(docs)
     docs = docs.extract_properties(property_extractor)
 
-    taken = docs.take_all()
+    taken = docs.take_all(include_metadata=True)
 
     assert taken[0].properties["entity"]["name"] == "Vinayak"
     assert taken[0].properties["entity"]["age"] == 74
     assert "Honolulu" in taken[0].properties["entity"]["from_location"]
 
+    assert len(taken) == 3
+    assert taken[2].metadata["usage"]["prompt_tokens"] > 0
+    assert taken[2].metadata["usage"]["completion_tokens"] > 0
+
 
 @pytest.mark.parametrize("llm", llms)
 def test_extract_properties_from_schema(llm):
@@ -61,6 +65,7 @@ def test_extract_properties_from_schema(llm):
             field_type="str",
             description="This is the name of an entity",
             examples=["Mark", "Ollie", "Winston"],
+            default="null",
         ),
         SchemaField(name="age", field_type="int", default=999),
         SchemaField(name="date", field_type="str", description="Any date in the doc in YYYY-MM-DD format"),
@@ -80,14 +85,20 @@ def test_extract_properties_from_schema(llm):
     docs = ctx.read.document(docs)
     docs = docs.extract_properties(property_extractor)
 
-    taken = docs.take_all()
+    taken = docs.take_all(include_metadata=True)
 
     assert taken[0].properties["entity"]["name"] == "Vinayak"
     assert taken[0].properties["entity"]["age"] == 74
     assert taken[0].properties["entity"]["from_location"] == "Honolulu, HI", "Invalid location extracted or formatted"
     assert taken[0].properties["entity"]["date"] == "1923-02-24"
 
-    assert taken[1].properties["entity"]["name"] is None, "Default None value not being used correctly"
+    assert taken[1].properties["entity"]["name"] == "None"  # Anthropic isn't generating valid JSON with null values.
     assert taken[1].properties["entity"]["age"] == 999, "Default value not being used correctly"
     assert taken[1].properties["entity"]["from_location"] == "New Delhi"
     assert taken[1].properties["entity"]["date"] == "2014-01-11"
+
+    assert len(taken) == 5
+    assert taken[3].metadata["usage"]["prompt_tokens"] > 0
+    assert taken[3].metadata["usage"]["completion_tokens"] > 0
+    assert taken[4].metadata["usage"]["prompt_tokens"] > 0
+    assert taken[4].metadata["usage"]["completion_tokens"] > 0
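
The integration tests above pick metadata documents out of take_all(include_metadata=True) by position. A position-independent aggregation might look like the sketch below; the helper name is made up, and the only assumed shape is the metadata["usage"] dict the tests already assert on.

    def summarize_llm_usage(taken):
        # `taken` is the list returned by docset.take_all(include_metadata=True);
        # entries without a metadata dict (regular documents) are skipped.
        usage = [
            m.metadata["usage"]
            for m in taken
            if isinstance(getattr(m, "metadata", None), dict) and "usage" in m.metadata
        ]
        return {
            "prompt_tokens": sum(u.get("prompt_tokens", 0) for u in usage),
            "completion_tokens": sum(u.get("completion_tokens", 0) for u in usage),
        }
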
diff --git a/lib/sycamore/sycamore/tests/unit/llms/test_llms.py b/lib/sycamore/sycamore/tests/unit/llms/test_llms.py
index 62cc4aed1..ab76de5f3 100644
--- a/lib/sycamore/sycamore/tests/unit/llms/test_llms.py
+++ b/lib/sycamore/sycamore/tests/unit/llms/test_llms.py
@@ -5,6 +5,38 @@
 from sycamore.llms.llms import FakeLLM
 from sycamore.llms.prompts import EntityExtractorFewShotGuidancePrompt, EntityExtractorZeroShotGuidancePrompt
 from sycamore.utils.cache import DiskCache
+import datetime
+from sycamore.utils.thread_local import ThreadLocalAccess
+
+
+def test_get_metadata():
+    llm = FakeLLM()
+    wall_latency = datetime.timedelta(seconds=1)
+    metadata = llm.get_metadata({"prompt": "Hello", "temperature": 0.7}, "Test output", wall_latency, 10, 5)
+    assert metadata["model"] == llm._model_name
+    assert metadata["usage"] == {
+        "completion_tokens": 10,
+        "prompt_tokens": 5,
+        "total_tokens": 15,
+    }
+    assert metadata["prompt"] == "Hello"
+    assert metadata["output"] == "Test output"
+    assert metadata["temperature"] == 0.7
+    assert metadata["wall_latency"] == wall_latency
+
+
+@patch("sycamore.llms.llms.add_metadata")
+def test_add_llm_metadata(mock_add_metadata):
+    llm = FakeLLM()
+    with patch.object(ThreadLocalAccess, "present", return_value=True):
+        llm.add_llm_metadata({}, "Test output", datetime.timedelta(seconds=0.5), 1, 2)
+    mock_add_metadata.assert_called_once()
+
+    # If TLS not present, add_metadata should not be called
+    mock_add_metadata.reset_mock()
+    with patch.object(ThreadLocalAccess, "present", return_value=False):
+        llm.add_llm_metadata({}, "Test output", datetime.timedelta(seconds=0.5), 1, 2)
+    mock_add_metadata.assert_not_called()
 
 
 def test_openai_davinci_fallback():
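
A possible follow-up to the unit tests above, sketched here rather than included in the patch: verify the payload handed to add_metadata, relying only on the keys get_metadata produces and the thread-local gate shown in the diff.

    import datetime
    from unittest.mock import patch

    from sycamore.llms.llms import FakeLLM
    from sycamore.utils.thread_local import ThreadLocalAccess


    @patch("sycamore.llms.llms.add_metadata")
    def test_add_llm_metadata_payload(mock_add_metadata):
        llm = FakeLLM()
        with patch.object(ThreadLocalAccess, "present", return_value=True):
            llm.add_llm_metadata({"temperature": 0.2}, "Test output", datetime.timedelta(seconds=0.5), 1, 2)
        # add_metadata is called as add_metadata(**metadata), so everything is in kwargs.
        payload = mock_add_metadata.call_args.kwargs
        assert payload["model"] == llm._model_name
        assert payload["usage"]["total_tokens"] == 3
        assert payload["output"] == "Test output"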