From 088f4b6a96df42f07499cfdffe44c9869ec8afd9 Mon Sep 17 00:00:00 2001
From: Soeb-aryn <soebh@aryn.ai>
Date: Mon, 27 Jan 2025 16:22:56 -0800
Subject: [PATCH] capturing metadata from LLMs (#1122)

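Capture per-call metadata (model, temperature, token usage, wall
latency, prompt, and output) from the Anthropic, Bedrock, and OpenAI
LLM wrappers, and attach it to the output stream whenever a
thread-local metadata sink (ADD_METADATA_TO_OUTPUT) is active.

A minimal sketch of how the captured usage surfaces downstream,
assuming a pipeline that runs an LLM-backed transform (this mirrors
the integration tests below and introduces no API beyond this patch):

    taken = docs.take_all(include_metadata=True)
    usage = taken[-1].metadata["usage"]
    assert usage["prompt_tokens"] > 0
    assert usage["completion_tokens"] > 0
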
---
 lib/sycamore/sycamore/llms/anthropic.py       |  3 +-
 lib/sycamore/sycamore/llms/bedrock.py         |  1 +
 lib/sycamore/sycamore/llms/llms.py            | 25 +++++++
 lib/sycamore/sycamore/llms/openai.py          | 66 +++++++++++++++----
 .../transforms/test_data_extraction.py        | 18 ++++-
 .../sycamore/tests/unit/llms/test_llms.py     | 32 +++++++++
 6 files changed, 127 insertions(+), 18 deletions(-)

diff --git a/lib/sycamore/sycamore/llms/anthropic.py b/lib/sycamore/sycamore/llms/anthropic.py
index 81bc4903d..ec400e9a0 100644
--- a/lib/sycamore/sycamore/llms/anthropic.py
+++ b/lib/sycamore/sycamore/llms/anthropic.py
@@ -140,7 +140,6 @@ def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] =
         response = self._client.messages.create(model=self.model.value, **kwargs)
 
         wall_latency = datetime.now() - start
-
         in_tokens = response.usage.input_tokens
         out_tokens = response.usage.output_tokens
         output = response.content[0].text
@@ -151,7 +150,7 @@ def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] =
             "in_tokens": in_tokens,
             "out_tokens": out_tokens,
         }
-
+        self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
         logging.debug(f"Generated response from Anthropic model: {ret}")
 
         self._llm_cache_set(prompt_kwargs, llm_kwargs, ret)
diff --git a/lib/sycamore/sycamore/llms/bedrock.py b/lib/sycamore/sycamore/llms/bedrock.py
index a7d115540..07855062f 100644
--- a/lib/sycamore/sycamore/llms/bedrock.py
+++ b/lib/sycamore/sycamore/llms/bedrock.py
@@ -114,6 +114,7 @@ def generate_metadata(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] =
             "in_tokens": in_tokens,
             "out_tokens": out_tokens,
         }
+        self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
         self._llm_cache_set(prompt_kwargs, llm_kwargs, ret)
         return ret
 
diff --git a/lib/sycamore/sycamore/llms/llms.py b/lib/sycamore/sycamore/llms/llms.py
index dc0541862..a3e1c4743 100644
--- a/lib/sycamore/sycamore/llms/llms.py
+++ b/lib/sycamore/sycamore/llms/llms.py
@@ -3,6 +3,8 @@
 from PIL import Image
 from typing import Any, Optional
 from sycamore.utils.cache import Cache
+from sycamore.utils.thread_local import ThreadLocalAccess, ADD_METADATA_TO_OUTPUT
+from sycamore.data.metadata import add_metadata
 
 
 class LLM(ABC):
@@ -91,6 +93,29 @@ def _llm_cache_set(self, prompt_kwargs: dict, llm_kwargs: Optional[dict], result
             },
         )
 
+    def get_metadata(self, kwargs, response_text, wall_latency, in_tokens, out_tokens) -> dict:
+        """Generate metadata for the LLM response."""
+        return {
+            "model": self._model_name,
+            "temperature": kwargs.get("temperature", None),
+            "usage": {
+                "completion_tokens": in_tokens,
+                "prompt_tokens": out_tokens,
+                "total_tokens": in_tokens + out_tokens,
+            },
+            "wall_latency": wall_latency,
+            "prompt": kwargs.get("prompt") or kwargs.get("messages"),
+            "output": response_text,
+        }
+
+    def add_llm_metadata(self, kwargs, output, wall_latency, in_tokens, out_tokens):
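+        """Attach metadata about this call to the output if a thread-local metadata sink is active."""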
+        tls = ThreadLocalAccess(ADD_METADATA_TO_OUTPUT)
+        if tls.present():
+            metadata = self.get_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
+            add_metadata(**metadata)
+
 
 class FakeLLM(LLM):
     """Useful for tests where the fake LLM needs to run in a ray function because mocks are not serializable"""
diff --git a/lib/sycamore/sycamore/llms/openai.py b/lib/sycamore/sycamore/llms/openai.py
index f90d3e1d9..6e44fd35b 100644
--- a/lib/sycamore/sycamore/llms/openai.py
+++ b/lib/sycamore/sycamore/llms/openai.py
@@ -5,7 +5,8 @@
 from dataclasses import dataclass
 from enum import Enum
 from PIL import Image
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
+from datetime import datetime
 
 from openai import AzureOpenAI as AzureOpenAIClient
 from openai import AsyncAzureOpenAI as AsyncAzureOpenAIClient
@@ -23,7 +24,6 @@
 from sycamore.utils.cache import Cache
 from sycamore.utils.image_utils import base64_data_url
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -291,6 +291,16 @@ def is_chat_mode(self):
     def format_image(self, image: Image.Image) -> dict[str, Any]:
         return {"type": "image_url", "image_url": {"url": base64_data_url(image)}}
 
+    def validate_tokens(self, completion) -> Tuple[int, int]:
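+        """Return (completion_tokens, prompt_tokens), treating a missing usage block as zero."""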
+        if completion.usage is not None:
+            completion_tokens = completion.usage.completion_tokens or 0
+            prompt_tokens = completion.usage.prompt_tokens or 0
+        else:
+            completion_tokens = 0
+            prompt_tokens = 0
+        return completion_tokens, prompt_tokens
+
     def _convert_response_format(self, llm_kwargs: Optional[Dict]) -> Optional[Dict]:
         """Convert the response_format parameter to the appropriate OpenAI format."""
         if llm_kwargs is None:
@@ -365,26 +375,41 @@ def _generate_using_openai(self, prompt_kwargs, llm_kwargs) -> str:
         kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
         logging.debug("OpenAI prompt: %s", kwargs)
         if self.is_chat_mode():
+            starttime = datetime.now()
             completion = self.client_wrapper.get_client().chat.completions.create(model=self._model_name, **kwargs)
             logging.debug("OpenAI completion: %s", completion)
-            return completion.choices[0].message.content
+            wall_latency = datetime.now() - starttime
+            response_text = completion.choices[0].message.content
         else:
+            starttime = datetime.now()
             completion = self.client_wrapper.get_client().completions.create(model=self._model_name, **kwargs)
             logging.debug("OpenAI completion: %s", completion)
-            return completion.choices[0].text
+            wall_latency = datetime.now() - starttime
+            response_text = completion.choices[0].text
+
+        completion_tokens, prompt_tokens = self.validate_tokens(completion)
+        self.add_llm_metadata(kwargs, response_text, wall_latency, prompt_tokens, completion_tokens)
+        if not response_text:
+            raise ValueError("OpenAI returned empty response")
+        return response_text
 
     def _generate_using_openai_structured(self, prompt_kwargs, llm_kwargs) -> str:
         try:
             kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
             if self.is_chat_mode():
+                starttime = datetime.now()
                 completion = self.client_wrapper.get_client().beta.chat.completions.parse(
                     model=self._model_name, **kwargs
                 )
+                wall_latency = datetime.now() - starttime
+                completion_tokens, prompt_tokens = self.validate_tokens(completion)
+                response_text = completion.choices[0].message.content
+                self.add_llm_metadata(kwargs, response_text, wall_latency, prompt_tokens, completion_tokens)
             else:
                 raise ValueError("This method doesn't support instruct models. Please use a chat model.")
                 # completion = self.client_wrapper.get_client().beta.completions.parse(model=self._model_name, **kwargs)
-            assert completion.choices[0].message.content is not None, "OpenAI refused to respond to the query"
-            return completion.choices[0].message.content
+            assert response_text is not None, "OpenAI refused to respond to the query"
+            return response_text
         except Exception as e:
             # OpenAI will not respond in two scenarios:
             # 1.) The LLM ran out of output context length (usually due to hallucination of repeating the same phrase)
@@ -408,31 +433,46 @@ async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict
 
     async def _generate_awaitable_using_openai(self, prompt_kwargs, llm_kwargs) -> str:
         kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
+        starttime = datetime.now()
         if self.is_chat_mode():
             completion = await self.client_wrapper.get_async_client().chat.completions.create(
                 model=self._model_name, **kwargs
             )
-            return completion.choices[0].message.content
+            response_text = completion.choices[0].message.content
         else:
             completion = await self.client_wrapper.get_async_client().completions.create(
                 model=self._model_name, **kwargs
             )
-            return completion.choices[0].text
+            response_text = completion.choices[0].text
+
+        wall_latency = datetime.now() - starttime
+
+        if completion.usage is not None:
+            completion_tokens = completion.usage.completion_tokens or 0
+            prompt_tokens = completion.usage.prompt_tokens or 0
+        else:
+            completion_tokens = 0
+            prompt_tokens = 0
+
+        self.add_llm_metadata(kwargs, response_text, wall_latency, prompt_tokens, completion_tokens)
+        return response_text
 
     async def _generate_awaitable_using_openai_structured(self, prompt_kwargs, llm_kwargs) -> str:
         try:
             kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
             if self.is_chat_mode():
+                starttime = datetime.now()
                 completion = await self.client_wrapper.get_async_client().beta.chat.completions.parse(
                     model=self._model_name, **kwargs
                 )
+                wall_latency = datetime.now() - starttime
             else:
                 raise ValueError("This method doesn't support instruct models. Please use a chat model.")
-                # completion = await self.client_wrapper.get_async_client().beta.completions.parse(
-                #     model=self._model_name, **kwargs
-                # )
-            assert completion.choices[0].message.content is not None, "OpenAI refused to respond to the query"
-            return completion.choices[0].message.content
+            response_text = completion.choices[0].message.content
+            assert response_text is not None, "OpenAI refused to respond to the query"
+            completion_tokens, prompt_tokens = self.validate_tokens(completion)
+            self.add_llm_metadata(kwargs, response_text, wall_latency, prompt_tokens, completion_tokens)
+            return response_text
         except Exception as e:
             # OpenAI will not respond in two scenarios:
             # 1.) The LLM ran out of output context length (usually due to hallucination of repeating the same phrase)
diff --git a/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py b/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py
index cb6e226a6..e3f415417 100644
--- a/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py
+++ b/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py
@@ -43,12 +43,17 @@ def test_extract_properties_from_dict_schema(llm):
     docs = ctx.read.document(docs)
     docs = docs.extract_properties(property_extractor)
 
-    taken = docs.take_all()
+    taken = docs.take_all(include_metadata=True)
 
     assert taken[0].properties["entity"]["name"] == "Vinayak"
     assert taken[0].properties["entity"]["age"] == 74
     assert "Honolulu" in taken[0].properties["entity"]["from_location"]
 
+    assert len(taken) == 3
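+    # taken[2] is the metadata record capturing the LLM call's token usage.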
+    assert taken[2].metadata["usage"]["prompt_tokens"] > 0
+    assert taken[2].metadata["usage"]["completion_tokens"] > 0
+
 
 @pytest.mark.parametrize("llm", llms)
 def test_extract_properties_from_schema(llm):
@@ -61,6 +66,7 @@ def test_extract_properties_from_schema(llm):
                 field_type="str",
                 description="This is the name of an entity",
                 examples=["Mark", "Ollie", "Winston"],
+                default="null",
             ),
             SchemaField(name="age", field_type="int", default=999),
             SchemaField(name="date", field_type="str", description="Any date in the doc in YYYY-MM-DD format"),
@@ -80,14 +86,20 @@ def test_extract_properties_from_schema(llm):
     docs = ctx.read.document(docs)
     docs = docs.extract_properties(property_extractor)
 
-    taken = docs.take_all()
+    taken = docs.take_all(include_metadata=True)
 
     assert taken[0].properties["entity"]["name"] == "Vinayak"
     assert taken[0].properties["entity"]["age"] == 74
     assert taken[0].properties["entity"]["from_location"] == "Honolulu, HI", "Invalid location extracted or formatted"
     assert taken[0].properties["entity"]["date"] == "1923-02-24"
 
-    assert taken[1].properties["entity"]["name"] is None, "Default None value not being used correctly"
+    assert taken[1].properties["entity"]["name"] == "None"  # Anthropic isn't generating valid JSON with null values.
     assert taken[1].properties["entity"]["age"] == 999, "Default value not being used correctly"
     assert taken[1].properties["entity"]["from_location"] == "New Delhi"
     assert taken[1].properties["entity"]["date"] == "2014-01-11"
+
+    assert len(taken) == 5
+    assert taken[3].metadata["usage"]["prompt_tokens"] > 0
+    assert taken[3].metadata["usage"]["completion_tokens"] > 0
+    assert taken[4].metadata["usage"]["prompt_tokens"] > 0
+    assert taken[4].metadata["usage"]["completion_tokens"] > 0
diff --git a/lib/sycamore/sycamore/tests/unit/llms/test_llms.py b/lib/sycamore/sycamore/tests/unit/llms/test_llms.py
index 62cc4aed1..ab76de5f3 100644
--- a/lib/sycamore/sycamore/tests/unit/llms/test_llms.py
+++ b/lib/sycamore/sycamore/tests/unit/llms/test_llms.py
@@ -5,6 +5,38 @@
 from sycamore.llms.llms import FakeLLM
 from sycamore.llms.prompts import EntityExtractorFewShotGuidancePrompt, EntityExtractorZeroShotGuidancePrompt
 from sycamore.utils.cache import DiskCache
+import datetime
+from sycamore.utils.thread_local import ThreadLocalAccess
+
+
+def test_get_metadata():
+    llm = FakeLLM()
+    wall_latency = datetime.timedelta(seconds=1)
+    metadata = llm.get_metadata({"prompt": "Hello", "temperature": 0.7}, "Test output", wall_latency, 10, 5)
+    assert metadata["model"] == llm._model_name
+    assert metadata["usage"] == {
+        "completion_tokens": 10,
+        "prompt_tokens": 5,
+        "total_tokens": 15,
+    }
+    assert metadata["prompt"] == "Hello"
+    assert metadata["output"] == "Test output"
+    assert metadata["temperature"] == 0.7
+    assert metadata["wall_latency"] == wall_latency
+
+
+@patch("sycamore.llms.llms.add_metadata")
+def test_add_llm_metadata(mock_add_metadata):
+    llm = FakeLLM()
+    with patch.object(ThreadLocalAccess, "present", return_value=True):
+        llm.add_llm_metadata({}, "Test output", datetime.timedelta(seconds=0.5), 1, 2)
+        mock_add_metadata.assert_called_once()
+
+    # If TLS not present, add_metadata should not be called
+    mock_add_metadata.reset_mock()
+    with patch.object(ThreadLocalAccess, "present", return_value=False):
+        llm.add_llm_metadata({}, "Test output", datetime.timedelta(seconds=0.5), 1, 2)
+        mock_add_metadata.assert_not_called()
 
 
 def test_openai_davinci_fallback():