diff --git a/ddtrace/contrib/internal/anthropic/_streaming.py b/ddtrace/contrib/internal/anthropic/_streaming.py index 439f61bb5a6..21f5b27350b 100644 --- a/ddtrace/contrib/internal/anthropic/_streaming.py +++ b/ddtrace/contrib/internal/anthropic/_streaming.py @@ -7,6 +7,10 @@ import anthropic import wrapt +from ddtrace.llmobs._integrations.base_stream_handler import AsyncStreamHandler +from ddtrace.llmobs._integrations.base_stream_handler import make_traced_async_stream +from ddtrace.llmobs._integrations.base_stream_handler import make_traced_stream +from ddtrace.llmobs._integrations.base_stream_handler import StreamHandler from ddtrace.contrib.internal.anthropic.utils import tag_tool_use_output_on_span from ddtrace.internal.logger import get_logger from ddtrace.llmobs._utils import _get_attr @@ -15,139 +19,59 @@ log = get_logger(__name__) -def handle_streamed_response(integration, resp, args, kwargs, span): - if _is_stream(resp): - return TracedAnthropicStream(resp, integration, span, args, kwargs) - elif _is_async_stream(resp): - return TracedAnthropicAsyncStream(resp, integration, span, args, kwargs) - elif _is_stream_manager(resp): - return TracedAnthropicStreamManager(resp, integration, span, args, kwargs) - elif _is_async_stream_manager(resp): - return TracedAnthropicAsyncStreamManager(resp, integration, span, args, kwargs) - - -class BaseTracedAnthropicStream(wrapt.ObjectProxy): - def __init__(self, wrapped, integration, span, args, kwargs): - super().__init__(wrapped) - self._dd_span = span - self._streamed_chunks = [] - self._dd_integration = integration - self._kwargs = kwargs - self._args = args - - -class TracedAnthropicStream(BaseTracedAnthropicStream): - def __init__(self, wrapped, integration, span, args, kwargs): - super().__init__(wrapped, integration, span, args, kwargs) - # we need to set a text_stream attribute so we can trace the yielded chunks - self.text_stream = self.__stream_text__() - - def __enter__(self): - self.__wrapped__.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.__wrapped__.__exit__(exc_type, exc_val, exc_tb) - - def __iter__(self): - return self - - def __next__(self): - try: - chunk = self.__wrapped__.__next__() - self._streamed_chunks.append(chunk) - return chunk - except StopIteration: - _process_finished_stream( - self._dd_integration, self._dd_span, self._args, self._kwargs, self._streamed_chunks - ) - self._dd_span.finish() - raise - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - self._dd_span.finish() - raise - - def __stream_text__(self): - # this is overridden because it is a helper function that collects all stream content chunks - for chunk in self: - if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": - yield chunk.delta.text - - -class TracedAnthropicAsyncStream(BaseTracedAnthropicStream): - def __init__(self, wrapped, integration, span, args, kwargs): - super().__init__(wrapped, integration, span, args, kwargs) - # we need to set a text_stream attribute so we can trace the yielded chunks - self.text_stream = self.__stream_text__() - - async def __aenter__(self): - await self.__wrapped__.__aenter__() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - chunk = await self.__wrapped__.__anext__() - self._streamed_chunks.append(chunk) - return chunk - except StopAsyncIteration: - _process_finished_stream( 
- self._dd_integration, - self._dd_span, - self._args, - self._kwargs, - self._streamed_chunks, - ) - self._dd_span.finish() - raise - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - self._dd_span.finish() - raise - - async def __stream_text__(self): - # this is overridden because it is a helper function that collects all stream content chunks - async for chunk in self: - if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": - yield chunk.delta.text - - -class TracedAnthropicStreamManager(BaseTracedAnthropicStream): - def __enter__(self): - stream = self.__wrapped__.__enter__() - traced_stream = TracedAnthropicStream( - stream, - self._dd_integration, - self._dd_span, - self._args, - self._kwargs, - ) - return traced_stream - def __exit__(self, exc_type, exc_val, exc_tb): - self.__wrapped__.__exit__(exc_type, exc_val, exc_tb) +def _text_stream_generator(traced_stream): + for chunk in traced_stream: + if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": + yield chunk.delta.text + +async def _async_text_stream_generator(traced_stream): + async for chunk in traced_stream: + if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": + yield chunk.delta.text -class TracedAnthropicAsyncStreamManager(BaseTracedAnthropicStream): - async def __aenter__(self): - stream = await self.__wrapped__.__aenter__() - traced_stream = TracedAnthropicAsyncStream( - stream, - self._dd_integration, - self._dd_span, - self._args, - self._kwargs, +def handle_streamed_response(integration, resp, args, kwargs, span): + def add_text_stream(stream): + stream.text_stream = _text_stream_generator(stream) + + def add_async_text_stream(stream): + stream.text_stream = _async_text_stream_generator(stream) + + if _is_stream(resp) or _is_stream_manager(resp): + traced_stream = make_traced_stream( + resp, + AnthropicStreamHandler(integration, span, args, kwargs), + on_stream_created=add_text_stream + ) + traced_stream.text_stream = _text_stream_generator(traced_stream) + return traced_stream + elif _is_async_stream(resp) or _is_async_stream_manager(resp): + traced_stream = make_traced_async_stream( + resp, + AnthropicAsyncStreamHandler(integration, span, args, kwargs), + on_stream_created=add_async_text_stream ) + traced_stream.text_stream = _async_text_stream_generator(traced_stream) return traced_stream - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb) +class BaseAnthropicStreamHandler: + def _initialize_chunk_storage(self): + return [] + + def finalize_stream(self, exception=None): + _process_finished_stream( + self.integration, self.primary_span, self.request_args, self.request_kwargs, self.chunks + ) + self.primary_span.finish() + +class AnthropicStreamHandler(BaseAnthropicStreamHandler, StreamHandler): + def process_chunk(self, chunk, iterator=None): + self.chunks.append(chunk) + +class AnthropicAsyncStreamHandler(BaseAnthropicStreamHandler, AsyncStreamHandler): + async def process_chunk(self, chunk, iterator=None): + self.chunks.append(chunk) def _process_finished_stream(integration, span, args, kwargs, streamed_chunks): diff --git a/ddtrace/contrib/internal/botocore/services/bedrock.py b/ddtrace/contrib/internal/botocore/services/bedrock.py index 028d812a36e..dbd5168b395 100644 --- a/ddtrace/contrib/internal/botocore/services/bedrock.py +++ b/ddtrace/contrib/internal/botocore/services/bedrock.py @@ -13,6 +13,8 @@ from ddtrace.internal import core from 
ddtrace.internal.logger import get_logger from ddtrace.internal.schema import schematize_service_name +from ddtrace.llmobs._integrations.base_stream_handler import make_traced_stream +from ddtrace.llmobs._integrations.base_stream_handler import StreamHandler from ddtrace.llmobs._integrations.bedrock_utils import parse_model_id @@ -27,81 +29,88 @@ _STABILITY = "stability" -class TracedBotocoreStreamingBody(wrapt.ObjectProxy): - """ - This class wraps the StreamingBody object returned by botocore api calls, specifically for Bedrock invocations. - Since the response body is in the form of a stream object, we need to wrap it in order to tag the response data - and fire completion events as the user consumes the streamed response. - """ - - def __init__(self, wrapped, ctx: core.ExecutionContext): - super().__init__(wrapped) - self._body = [] - self._execution_ctx = ctx - - def read(self, amt=None): - """Wraps around method to tags the response data and finish the span as the user consumes the stream.""" - try: - body = self.__wrapped__.read(amt=amt) - self._body.append(json.loads(body)) - if self.__wrapped__.tell() == int(self.__wrapped__._content_length): - formatted_response = _extract_text_and_response_reason(self._execution_ctx, self._body[0]) - model_provider = self._execution_ctx["model_provider"] - model_name = self._execution_ctx["model_name"] - should_set_choice_ids = model_provider == _COHERE and "embed" not in model_name - core.dispatch( - "botocore.bedrock.process_response", - [self._execution_ctx, formatted_response, None, self._body[0], should_set_choice_ids], - ) - return body - except Exception: - core.dispatch("botocore.patched_bedrock_api_call.exception", [self._execution_ctx, sys.exc_info()]) - raise - - def readlines(self): - """Wraps around method to tags the response data and finish the span as the user consumes the stream.""" - try: - lines = self.__wrapped__.readlines() - for line in lines: - self._body.append(json.loads(line)) - formatted_response = _extract_text_and_response_reason(self._execution_ctx, self._body[0]) - model_provider = self._execution_ctx["model_provider"] - model_name = self._execution_ctx["model_name"] +def traced_stream_read(traced_stream, original_read, amt=None): + """Wraps around method to tags the response data and finish the span as the user consumes the stream.""" + handler = traced_stream.handler + execution_ctx = handler.options.get("execution_ctx", {}) + try: + body = original_read(amt=amt) + handler.chunks.append(json.loads(body)) + if traced_stream.__wrapped__.tell() == int(traced_stream.__wrapped__._content_length): + formatted_response = _extract_text_and_response_reason(execution_ctx, handler.chunks[0]) + model_provider = execution_ctx["model_provider"] + model_name = execution_ctx["model_name"] should_set_choice_ids = model_provider == _COHERE and "embed" not in model_name core.dispatch( "botocore.bedrock.process_response", - [self._execution_ctx, formatted_response, None, self._body[0], should_set_choice_ids], - ) - return lines - except Exception: - core.dispatch("botocore.patched_bedrock_api_call.exception", [self._execution_ctx, sys.exc_info()]) - raise - - def __iter__(self): - """Wraps around method to tags the response data and finish the span as the user consumes the stream.""" - exception_raised = False - try: - for line in self.__wrapped__: - self._body.append(json.loads(line["chunk"]["bytes"])) - yield line - except Exception: - core.dispatch("botocore.patched_bedrock_api_call.exception", [self._execution_ctx, 
sys.exc_info()]) - exception_raised = True - raise - finally: - if exception_raised: - return - metadata = _extract_streamed_response_metadata(self._execution_ctx, self._body) - formatted_response = _extract_streamed_response(self._execution_ctx, self._body) - model_provider = self._execution_ctx["model_provider"] - model_name = self._execution_ctx["model_name"] - should_set_choice_ids = ( - model_provider == _COHERE and "is_finished" not in self._body[0] and "embed" not in model_name - ) - core.dispatch( - "botocore.bedrock.process_response", - [self._execution_ctx, formatted_response, metadata, self._body, should_set_choice_ids], + [execution_ctx, formatted_response, None, handler.chunks[0], should_set_choice_ids], ) + return body + except Exception: + core.dispatch("botocore.patched_bedrock_api_call.exception", [execution_ctx, sys.exc_info()]) + raise + +def traced_stream_readlines(traced_stream, original_readlines): + """Wraps around method to tags the response data and finish the span as the user consumes the stream.""" + handler = traced_stream.handler + execution_ctx = handler.options.get("execution_ctx", {}) + try: + lines = original_readlines() + for line in lines: + handler.chunks.append(json.loads(line)) + formatted_response = _extract_text_and_response_reason(execution_ctx, handler.chunks[0]) + model_provider = execution_ctx["model_provider"] + model_name = execution_ctx["model_name"] + should_set_choice_ids = model_provider == _COHERE and "embed" not in model_name + core.dispatch( + "botocore.bedrock.process_response", + [execution_ctx, formatted_response, None, handler.chunks[0], should_set_choice_ids], + ) + return lines + except Exception: + core.dispatch("botocore.patched_bedrock_api_call.exception", [execution_ctx, sys.exc_info()]) + raise + +class BotocoreStreamingBodyStreamHandler(StreamHandler): + def _initialize_chunk_storage(self): + return [] + + def process_chunk(self, chunk, iterator=None): + self.chunks.append(json.loads(chunk["chunk"]["bytes"])) + + def handle_exception(self, exception): + core.dispatch("botocore.patched_bedrock_api_call.exception", [self.options.get("execution_ctx", {}), sys.exc_info()]) + + def finalize_stream(self, exception=None): + if exception: + return + execution_ctx = self.options.get("execution_ctx", {}) + metadata = _extract_streamed_response_metadata(execution_ctx, self.chunks) + formatted_response = _extract_streamed_response(execution_ctx, self.chunks) + model_provider = execution_ctx["model_provider"] + model_name = execution_ctx["model_name"] + should_set_choice_ids = ( + model_provider == _COHERE and "is_finished" not in self.chunks[0] and "embed" not in model_name + ) + core.dispatch( + "botocore.bedrock.process_response", + [execution_ctx, formatted_response, metadata, self.chunks, should_set_choice_ids], + ) + + +def make_botocore_streaming_body_traced_stream(streaming_body, integration, span, args, kwargs, execution_ctx): + original_read = getattr(streaming_body, "read", None) + original_readlines = getattr(streaming_body, "readlines", None) + traced_stream = make_traced_stream( + streaming_body, + BotocoreStreamingBodyStreamHandler(integration, span, args, kwargs, execution_ctx=execution_ctx), + ) + # add bedrock-specific methods to the traced stream + if original_read: + traced_stream.read = lambda amt=None: traced_stream_read(traced_stream, original_read, amt) + if original_readlines: + traced_stream.readlines = lambda: traced_stream_readlines(traced_stream, original_readlines) + return traced_stream class 
TracedBotocoreConverseStream(wrapt.ObjectProxy): @@ -440,7 +449,7 @@ def handle_bedrock_response( return result body = result["body"] - result["body"] = TracedBotocoreStreamingBody(body, ctx) + result["body"] = make_botocore_streaming_body_traced_stream(body, None, None, None, None, ctx) return result diff --git a/ddtrace/contrib/internal/google_generativeai/_utils.py b/ddtrace/contrib/internal/google_generativeai/_utils.py index e41729380fa..e4210d20833 100644 --- a/ddtrace/contrib/internal/google_generativeai/_utils.py +++ b/ddtrace/contrib/internal/google_generativeai/_utils.py @@ -4,62 +4,32 @@ import wrapt from ddtrace.internal.utils import get_argument_value +from ddtrace.llmobs._integrations.base_stream_handler import AsyncStreamHandler +from ddtrace.llmobs._integrations.base_stream_handler import StreamHandler from ddtrace.llmobs._integrations.utils import get_generation_config_google from ddtrace.llmobs._integrations.utils import get_system_instructions_from_google_model from ddtrace.llmobs._integrations.utils import tag_request_content_part_google from ddtrace.llmobs._integrations.utils import tag_response_part_google +class BaseGoogleGenerativeAIStramHandler: + def finalize_stream(self, exception=None): + tag_response(self.primary_span, self.options.get("wrapped_stream", None), self.integration, self.options.get("model_instance", None)) + self.request_kwargs["instance"] = self.options.get("model_instance", None) + self.integration.llmobs_set_tags( + self.primary_span, + args=self.request_args, + kwargs=self.request_kwargs, + response=self.options.get("wrapped_stream", None), + ) + self.primary_span.finish() -class BaseTracedGenerateContentResponse(wrapt.ObjectProxy): - """Base wrapper class for GenerateContentResponse objects for tracing streamed responses.""" - - def __init__(self, wrapped, instance, integration, span, args, kwargs): - super().__init__(wrapped) - self._model_instance = instance - self._dd_integration = integration - self._dd_span = span - self._args = args - self._kwargs = kwargs - - -class TracedGenerateContentResponse(BaseTracedGenerateContentResponse): - def __iter__(self): - try: - for chunk in self.__wrapped__.__iter__(): - yield chunk - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - raise - finally: - tag_response(self._dd_span, self.__wrapped__, self._dd_integration, self._model_instance) - self._kwargs["instance"] = self._model_instance - self._dd_integration.llmobs_set_tags( - self._dd_span, - args=self._args, - kwargs=self._kwargs, - response=self.__wrapped__, - ) - self._dd_span.finish() - - -class TracedAsyncGenerateContentResponse(BaseTracedGenerateContentResponse): - async def __aiter__(self): - try: - async for chunk in self.__wrapped__.__aiter__(): - yield chunk - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - raise - finally: - tag_response(self._dd_span, self.__wrapped__, self._dd_integration, self._model_instance) - self._kwargs["instance"] = self._model_instance - self._dd_integration.llmobs_set_tags( - self._dd_span, - args=self._args, - kwargs=self._kwargs, - response=self.__wrapped__, - ) - self._dd_span.finish() +class GoogleGenerativeAIStramHandler(BaseGoogleGenerativeAIStramHandler, StreamHandler): + def process_chunk(self, chunk, iterator=None): + pass + +class GoogleGenerativeAIAsyncStreamHandler(BaseGoogleGenerativeAIStramHandler, AsyncStreamHandler): + async def process_chunk(self, chunk, iterator=None): + pass def _extract_api_key(instance): diff --git 
a/ddtrace/contrib/internal/google_generativeai/patch.py b/ddtrace/contrib/internal/google_generativeai/patch.py index 7f6bdbaf43b..2411abddc7f 100644 --- a/ddtrace/contrib/internal/google_generativeai/patch.py +++ b/ddtrace/contrib/internal/google_generativeai/patch.py @@ -5,8 +5,6 @@ import google.generativeai as genai from ddtrace import config -from ddtrace.contrib.internal.google_generativeai._utils import TracedAsyncGenerateContentResponse -from ddtrace.contrib.internal.google_generativeai._utils import TracedGenerateContentResponse from ddtrace.contrib.internal.google_generativeai._utils import _extract_api_key from ddtrace.contrib.internal.google_generativeai._utils import tag_request from ddtrace.contrib.internal.google_generativeai._utils import tag_response @@ -15,6 +13,10 @@ from ddtrace.contrib.internal.trace_utils import wrap from ddtrace.llmobs._integrations import GeminiIntegration from ddtrace.llmobs._integrations.utils import extract_model_name_google +from ddtrace.llmobs._integrations.base_stream_handler import make_traced_async_stream +from ddtrace.llmobs._integrations.base_stream_handler import make_traced_stream +from ddtrace.contrib.internal.google_generativeai._utils import GoogleGenerativeAIAsyncStreamHandler +from ddtrace.contrib.internal.google_generativeai._utils import GoogleGenerativeAIStramHandler from ddtrace.trace import Pin @@ -57,7 +59,7 @@ def traced_generate(genai, pin, func, instance, args, kwargs): if api_key: span.set_tag("google_generativeai.request.api_key", "...{}".format(api_key[-4:])) if stream: - return TracedGenerateContentResponse(generations, instance, integration, span, args, kwargs) + return make_traced_stream(generations, GoogleGenerativeAIStramHandler(integration, span, args, kwargs, model_instance=instance, wrapped_stream=generations)) tag_response(span, generations, integration, instance) except Exception: span.set_exc_info(*sys.exc_info()) @@ -87,7 +89,7 @@ async def traced_agenerate(genai, pin, func, instance, args, kwargs): tag_request(span, integration, instance, args, kwargs) generations = await func(*args, **kwargs) if stream: - return TracedAsyncGenerateContentResponse(generations, instance, integration, span, args, kwargs) + return make_traced_async_stream(generations, GoogleGenerativeAIAsyncStreamHandler(integration, span, args, kwargs, model_instance=instance, wrapped_stream=generations)) tag_response(span, generations, integration, instance) except Exception: span.set_exc_info(*sys.exc_info()) diff --git a/ddtrace/contrib/internal/langchain/utils.py b/ddtrace/contrib/internal/langchain/utils.py index 74ba810f473..f06e059fc9f 100644 --- a/ddtrace/contrib/internal/langchain/utils.py +++ b/ddtrace/contrib/internal/langchain/utils.py @@ -2,70 +2,34 @@ import sys from ddtrace.internal import core - - -class BaseTracedLangChainStreamResponse: - def __init__(self, generator, integration, span, on_span_finish, chunk_callback): - self._generator = generator - self._dd_integration = integration - self._dd_span = span - self._on_span_finish = on_span_finish - self._chunk_callback = chunk_callback - self._chunks = [] - - -class TracedLangchainStreamResponse(BaseTracedLangChainStreamResponse): - def __enter__(self): - self._generator.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._generator.__exit__(exc_type, exc_val, exc_tb) - - def __iter__(self): - return self - - def __next__(self): - try: - chunk = self._generator.__next__() - self._chunks.append(chunk) - self._chunk_callback(chunk) - return chunk 
- except StopIteration: - self._on_span_finish(self._dd_span, self._chunks) - self._dd_span.finish() - raise - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - self._dd_span.finish() - raise - - -class TracedLangchainAsyncStreamResponse(BaseTracedLangChainStreamResponse): - async def __aenter__(self): - await self._generator.__enter__() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._generator.__exit__(exc_type, exc_val, exc_tb) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - chunk = await self._generator.__anext__() - self._chunks.append(chunk) - self._chunk_callback(chunk) - return chunk - except StopAsyncIteration: - self._on_span_finish(self._dd_span, self._chunks) - self._dd_span.finish() - raise - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - self._dd_span.finish() - raise +from ddtrace.llmobs._integrations.base_stream_handler import AsyncStreamHandler +from ddtrace.llmobs._integrations.base_stream_handler import make_traced_async_stream +from ddtrace.llmobs._integrations.base_stream_handler import make_traced_stream +from ddtrace.llmobs._integrations.base_stream_handler import StreamHandler + +class BaseLangchainStreamHandler: + def _initialize_chunk_storage(self): + return [] + + def _process_chunk(self, chunk): + self.chunks.append(chunk) + chunk_callback = self.options.get("chunk_callback", None) + if chunk_callback: + chunk_callback(chunk) + + def finalize_stream(self, exception=None): + on_span_finish = self.options.get("on_span_finish", None) + if on_span_finish: + on_span_finish(self.primary_span, self.chunks) + self.primary_span.finish() + +class LangchainStreamHandler(BaseLangchainStreamHandler, StreamHandler): + def process_chunk(self, chunk, iterator=None): + self._process_chunk(chunk) + +class LangchainAsyncStreamHandler(BaseLangchainStreamHandler, AsyncStreamHandler): + async def process_chunk(self, chunk, iterator=None): + self._process_chunk(chunk) def shared_stream( @@ -95,9 +59,10 @@ def shared_stream( try: resp = func(*args, **kwargs) - cls = TracedLangchainAsyncStreamResponse if inspect.isasyncgen(resp) else TracedLangchainStreamResponse chunk_callback = _get_chunk_callback(interface_type, args, kwargs) - return cls(resp, integration, span, on_span_finished, chunk_callback) + if inspect.isasyncgen(resp): + return make_traced_async_stream(resp, LangchainAsyncStreamHandler(integration, span, args, kwargs, on_span_finish=on_span_finished, chunk_callback=chunk_callback)) + return make_traced_stream(resp, LangchainStreamHandler(integration, span, args, kwargs, on_span_finish=on_span_finished, chunk_callback=chunk_callback)) except Exception: # error with the method call itself span.set_exc_info(*sys.exc_info()) diff --git a/ddtrace/contrib/internal/litellm/patch.py b/ddtrace/contrib/internal/litellm/patch.py index d24d4581e82..f3bfb0f58df 100644 --- a/ddtrace/contrib/internal/litellm/patch.py +++ b/ddtrace/contrib/internal/litellm/patch.py @@ -4,15 +4,17 @@ import litellm from ddtrace import config -from ddtrace.contrib.internal.litellm.utils import TracedLiteLLMAsyncStream -from ddtrace.contrib.internal.litellm.utils import TracedLiteLLMStream from ddtrace.contrib.internal.litellm.utils import extract_host_tag +from ddtrace.contrib.internal.litellm.utils import LiteLLMAsyncStreamHandler +from ddtrace.contrib.internal.litellm.utils import LiteLLMStreamHandler from ddtrace.contrib.trace_utils import unwrap from ddtrace.contrib.trace_utils import with_traced_module 
from ddtrace.contrib.trace_utils import wrap from ddtrace.internal.utils import get_argument_value from ddtrace.llmobs._constants import LITELLM_ROUTER_INSTANCE_KEY from ddtrace.llmobs._integrations import LiteLLMIntegration +from ddtrace.llmobs._integrations.base_stream_handler import make_traced_async_stream +from ddtrace.llmobs._integrations.base_stream_handler import make_traced_stream from ddtrace.trace import Pin @@ -46,7 +48,7 @@ def traced_completion(litellm, pin, func, instance, args, kwargs): try: resp = func(*args, **kwargs) if stream: - return TracedLiteLLMStream(resp, integration, span, kwargs) + return make_traced_stream(resp, LiteLLMStreamHandler(integration, span, args, kwargs)) return resp except Exception: span.set_exc_info(*sys.exc_info()) @@ -77,7 +79,7 @@ async def traced_acompletion(litellm, pin, func, instance, args, kwargs): try: resp = await func(*args, **kwargs) if stream: - return TracedLiteLLMAsyncStream(resp, integration, span, kwargs) + return make_traced_async_stream(resp, LiteLLMAsyncStreamHandler(integration, span, args, kwargs)) return resp except Exception: span.set_exc_info(*sys.exc_info()) @@ -108,7 +110,7 @@ def traced_router_completion(litellm, pin, func, instance, args, kwargs): try: resp = func(*args, **kwargs) if stream: - resp._add_router_span_info(span, kwargs, instance) + resp.handler.add_span(span, kwargs, instance) return resp except Exception: span.set_exc_info(*sys.exc_info()) @@ -139,7 +141,7 @@ async def traced_router_acompletion(litellm, pin, func, instance, args, kwargs): try: resp = await func(*args, **kwargs) if stream: - resp._add_router_span_info(span, kwargs, instance) + resp.handler.add_span(span, kwargs, instance) return resp except Exception: span.set_exc_info(*sys.exc_info()) diff --git a/ddtrace/contrib/internal/litellm/utils.py b/ddtrace/contrib/internal/litellm/utils.py index be33e17c22d..541a5894d5e 100644 --- a/ddtrace/contrib/internal/litellm/utils.py +++ b/ddtrace/contrib/internal/litellm/utils.py @@ -1,12 +1,11 @@ -from collections import defaultdict import sys -import wrapt - from ddtrace.internal.logger import get_logger from ddtrace.llmobs._constants import LITELLM_ROUTER_INSTANCE_KEY from ddtrace.llmobs._integrations.utils import openai_construct_completion_from_streamed_chunks from ddtrace.llmobs._integrations.utils import openai_construct_message_from_streamed_chunks +from ddtrace.llmobs._integrations.base_stream_handler import StreamHandler +from ddtrace.llmobs._integrations.base_stream_handler import AsyncStreamHandler log = get_logger(__name__) @@ -18,137 +17,54 @@ def extract_host_tag(kwargs): return None -class BaseTracedLiteLLMStream(wrapt.ObjectProxy): - def __init__(self, wrapped, integration, span, kwargs): - super().__init__(wrapped) - self._dd_integration = integration - self._span_info = [(span, kwargs)] - self._streamed_chunks = defaultdict(list) - - def _add_router_span_info(self, span, kwargs, instance): - """Handler to add router span to this streaming object. - - Helps to ensure that all spans associated with a single stream are finished and have the correct tags. 
- """ +class BaseLiteLLMStreamHandler: + def add_span(self, span, kwargs, instance): kwargs[LITELLM_ROUTER_INSTANCE_KEY] = instance - self._span_info.append((span, kwargs)) + self.spans.append((span, kwargs)) - def _finish_spans(self): - """Helper to finish all spans associated with this stream.""" + def _process_chunk(self, chunk, iterator=None): + for choice in getattr(chunk, "choices", []): + choice_index = getattr(choice, "index", 0) + self.chunks[choice_index].append(choice) + if getattr(chunk, "usage", None): + self.chunks[0].insert(0, chunk) + + def finalize_stream(self, exception=None): formatted_completions = None - for span, kwargs in self._span_info: + for span, kwargs in self.spans: if not formatted_completions: - formatted_completions = _process_finished_stream( - self._dd_integration, span, kwargs, self._streamed_chunks, span.resource - ) - elif self._dd_integration.is_pc_sampled_llmobs(span): - self._dd_integration.llmobs_set_tags( + formatted_completions = self._process_finished_stream(span) + if self.integration.is_pc_sampled_llmobs(span): + self.integration.llmobs_set_tags( span, args=[], kwargs=kwargs, response=formatted_completions, operation=span.resource ) span.finish() - -class TracedLiteLLMStream(BaseTracedLiteLLMStream): - def __enter__(self): - self.__wrapped__.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.__wrapped__.__exit__(exc_type, exc_val, exc_tb) - - def __iter__(self): - try: - for chunk in self.__wrapped__: - yield chunk - _loop_handler(chunk, self._streamed_chunks) - except Exception: - if self._span_info and len(self._span_info[0]) > 0: - span = self._span_info[0][0] - span.set_exc_info(*sys.exc_info()) - raise - finally: - self._finish_spans() - - def __next__(self): + def _process_finished_stream(self, span): try: - chunk = self.__wrapped__.__next__() - _loop_handler(chunk, self._streamed_chunks) - return chunk - except StopIteration: - raise + operation = span.resource + formatted_completions = None + if self.integration.is_completion_operation(operation): + formatted_completions = [ + openai_construct_completion_from_streamed_chunks(choice) + for choice in self.chunks.values() + ] + else: + formatted_completions = [ + openai_construct_message_from_streamed_chunks(choice) + for choice in self.chunks.values() + ] + return formatted_completions except Exception: - if self._span_info and len(self._span_info[0]) > 0: - span = self._span_info[0][0] - span.set_exc_info(*sys.exc_info()) - raise - finally: - self._finish_spans() + log.warning("Error processing streamed completion/chat response.", exc_info=True) + return formatted_completions -class TracedLiteLLMAsyncStream(BaseTracedLiteLLMStream): - async def __aenter__(self): - await self.__wrapped__.__aenter__() - return self +class LiteLLMStreamHandler(BaseLiteLLMStreamHandler, StreamHandler): + def process_chunk(self, chunk, iterator=None): + self._process_chunk(chunk, iterator) - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb) - async def __aiter__(self): - try: - async for chunk in self.__wrapped__: - yield chunk - _loop_handler(chunk, self._streamed_chunks) - except Exception: - if self._span_info and len(self._span_info[0]) > 0: - span = self._span_info[0][0] - span.set_exc_info(*sys.exc_info()) - raise - finally: - self._finish_spans() - - async def __anext__(self): - try: - chunk = await self.__wrapped__.__anext__() - _loop_handler(chunk, self._streamed_chunks) - return chunk - except 
StopAsyncIteration: - raise - except Exception: - if self._span_info and len(self._span_info[0]) > 0: - span = self._span_info[0][0] - span.set_exc_info(*sys.exc_info()) - raise - finally: - self._finish_spans() - - -def _loop_handler(chunk, streamed_chunks): - """Appends the chunk to the correct index in the streamed_chunks list. - - When handling a streamed chat/completion response, this function is called for each chunk in the streamed response. - """ - for choice in getattr(chunk, "choices", []): - choice_index = getattr(choice, "index", 0) - streamed_chunks[choice_index].append(choice) - if getattr(chunk, "usage", None): - streamed_chunks[0].insert(0, chunk) - - -def _process_finished_stream(integration, span, kwargs, streamed_chunks, operation): - try: - formatted_completions = None - if integration.is_completion_operation(operation): - formatted_completions = [ - openai_construct_completion_from_streamed_chunks(choice) for choice in streamed_chunks.values() - ] - else: - formatted_completions = [ - openai_construct_message_from_streamed_chunks(choice) for choice in streamed_chunks.values() - ] - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags( - span, args=[], kwargs=kwargs, response=formatted_completions, operation=operation - ) - except Exception: - log.warning("Error processing streamed completion/chat response.", exc_info=True) - return formatted_completions +class LiteLLMAsyncStreamHandler(BaseLiteLLMStreamHandler, AsyncStreamHandler): + async def process_chunk(self, chunk, iterator=None): + self._process_chunk(chunk, iterator) \ No newline at end of file diff --git a/ddtrace/contrib/internal/openai/_endpoint_hooks.py b/ddtrace/contrib/internal/openai/_endpoint_hooks.py index 94c3cccd262..47769d474d6 100644 --- a/ddtrace/contrib/internal/openai/_endpoint_hooks.py +++ b/ddtrace/contrib/internal/openai/_endpoint_hooks.py @@ -1,7 +1,9 @@ from openai.version import VERSION as OPENAI_VERSION -from ddtrace.contrib.internal.openai.utils import TracedOpenAIAsyncStream -from ddtrace.contrib.internal.openai.utils import TracedOpenAIStream +from ddtrace.llmobs._integrations.base_stream_handler import make_traced_async_stream +from ddtrace.llmobs._integrations.base_stream_handler import make_traced_stream +from ddtrace.contrib.internal.openai.utils import OpenAIAsyncStreamHandler +from ddtrace.contrib.internal.openai.utils import OpenAIStreamHandler from ddtrace.contrib.internal.openai.utils import _format_openai_api_key from ddtrace.contrib.internal.openai.utils import _is_async_generator from ddtrace.contrib.internal.openai.utils import _is_generator @@ -102,9 +104,9 @@ def _handle_streamed_response(self, integration, span, kwargs, resp, is_completi """ if parse_version(OPENAI_VERSION) >= (1, 6, 0): if _is_async_generator(resp): - return TracedOpenAIAsyncStream(resp, integration, span, kwargs, is_completion) + return make_traced_async_stream(resp, OpenAIAsyncStreamHandler(integration, span, None, kwargs, is_completion=is_completion)) elif _is_generator(resp): - return TracedOpenAIStream(resp, integration, span, kwargs, is_completion) + return make_traced_stream(resp, OpenAIStreamHandler(integration, span, None, kwargs, is_completion=is_completion)) def shared_gen(): try: diff --git a/ddtrace/contrib/internal/openai/utils.py b/ddtrace/contrib/internal/openai/utils.py index 0bb26c42aad..b3476b8b984 100644 --- a/ddtrace/contrib/internal/openai/utils.py +++ b/ddtrace/contrib/internal/openai/utils.py @@ -8,6 +8,8 @@ from ddtrace.internal.logger import get_logger from 
ddtrace.llmobs._integrations.utils import openai_construct_completion_from_streamed_chunks from ddtrace.llmobs._integrations.utils import openai_construct_message_from_streamed_chunks +from ddtrace.llmobs._integrations.base_stream_handler import AsyncStreamHandler +from ddtrace.llmobs._integrations.base_stream_handler import StreamHandler from ddtrace.llmobs._utils import _get_attr @@ -24,70 +26,62 @@ _punc_regex = re.compile(r"[\w']+|[.,!?;~@#$%^&*()+/-]") -class BaseTracedOpenAIStream(wrapt.ObjectProxy): - def __init__(self, wrapped, integration, span, kwargs, is_completion=False): - super().__init__(wrapped) - n = kwargs.get("n", 1) or 1 - prompts = kwargs.get("prompt", "") - if is_completion and prompts and isinstance(prompts, list) and not isinstance(prompts[0], int): - n *= len(prompts) - self._dd_span = span - self._streamed_chunks = [[] for _ in range(n)] - self._dd_integration = integration - self._is_completion = is_completion - self._kwargs = kwargs - - -class TracedOpenAIStream(BaseTracedOpenAIStream): - """ - This class is used to trace OpenAI stream objects for chat/completion/response. - """ - - def __enter__(self): - self.__wrapped__.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.__wrapped__.__exit__(exc_type, exc_val, exc_tb) +class BaseOpenAIStreamHandler: + def _loop_handler(self, span, chunk, streamed_chunks): + """ + Sets the openai model tag and appends the chunk to the correct index in the streamed_chunks list. + When handling a streamed chat/completion/responses, + this function is called for each chunk in the streamed response. + """ + + if span.get_tag("openai.response.model") is None: + if hasattr(chunk, "type") and chunk.type.startswith("response."): + response = getattr(chunk, "response", None) + model = getattr(response, "model", "") + else: + model = getattr(chunk, "model", "") + span.set_tag_str("openai.response.model", model) + # Only run if the chunk is a completion/chat completion + for choice in getattr(chunk, "choices", []): + streamed_chunks[choice.index].append(choice) + if getattr(chunk, "usage", None): + streamed_chunks[0].insert(0, chunk) + + def finalize_stream(self, exception=None): + if not exception: + self._process_finished_stream( + self.options.get("is_completion", False) + ) + self.primary_span.finish() - def __iter__(self): - exception_raised = False - try: - for chunk in self.__wrapped__: - self._extract_token_chunk(chunk) - yield chunk - _loop_handler(self._dd_span, chunk, self._streamed_chunks) - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - exception_raised = True - raise - finally: - if not exception_raised: - _process_finished_stream( - self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion - ) - self._dd_span.finish() - - def __next__(self): + def _process_finished_stream(self, is_completion=False): + prompts = self.request_kwargs.get("prompt", None) + request_messages = self.request_kwargs.get("messages", None) try: - chunk = self.__wrapped__.__next__() - self._extract_token_chunk(chunk) - _loop_handler(self._dd_span, chunk, self._streamed_chunks) - return chunk - except StopIteration: - _process_finished_stream( - self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion - ) - self._dd_span.finish() - raise + if is_completion: + formatted_completions = [ + openai_construct_completion_from_streamed_chunks(choice) for choice in self.chunks.values() + ] + else: + formatted_completions = [ + 
openai_construct_message_from_streamed_chunks(choice) for choice in self.chunks.values() + ] + if self.integration.is_pc_sampled_span(self.primary_span): + _tag_streamed_response(self.integration, self.primary_span, formatted_completions) + _set_token_metrics(self.primary_span, formatted_completions, prompts, request_messages, self.request_kwargs) + operation = "completion" if is_completion else "chat" + self.integration.llmobs_set_tags(self.primary_span, args=[], kwargs=self.request_kwargs, response=formatted_completions, operation=operation) except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - self._dd_span.finish() - raise + log.warning("Error processing streamed completion/chat response.", exc_info=True) + +class OpenAIStreamHandler(BaseOpenAIStreamHandler, StreamHandler): + def process_chunk(self, chunk, iterator=None): + self._extract_token_chunk(chunk, iterator) + self._loop_handler(self.primary_span, chunk, self.chunks) - def _extract_token_chunk(self, chunk): + def _extract_token_chunk(self, chunk, iterator=None): """Attempt to extract the token chunk (last chunk in the stream) from the streamed response.""" - if not self._dd_span._get_ctx_item("_dd.auto_extract_token_chunk"): + if not self.primary_span._get_ctx_item("_dd.auto_extract_token_chunk"): return choices = getattr(chunk, "choices") if not choices: @@ -99,62 +93,19 @@ def _extract_token_chunk(self, chunk): try: # User isn't expecting last token chunk to be present since it's not part of the default streamed response, # so we consume it and extract the token usage metadata before it reaches the user. - usage_chunk = self.__wrapped__.__next__() - self._streamed_chunks[0].insert(0, usage_chunk) + usage_chunk = iterator.__next__() + self.chunks[0].insert(0, usage_chunk) except (StopIteration, GeneratorExit): return +class OpenAIAsyncStreamHandler(BaseOpenAIStreamHandler, AsyncStreamHandler): + async def process_chunk(self, chunk, iterator=None): + await self._extract_token_chunk(chunk, iterator) + self._loop_handler(self.primary_span, chunk, self.chunks) -class TracedOpenAIAsyncStream(BaseTracedOpenAIStream): - """ - This class is used to trace AsyncOpenAI stream objects for chat/completion/response. 
- """ - - async def __aenter__(self): - await self.__wrapped__.__aenter__() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb) - - async def __aiter__(self): - exception_raised = False - try: - async for chunk in self.__wrapped__: - await self._extract_token_chunk(chunk) - yield chunk - _loop_handler(self._dd_span, chunk, self._streamed_chunks) - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - exception_raised = True - raise - finally: - if not exception_raised: - _process_finished_stream( - self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion - ) - self._dd_span.finish() - - async def __anext__(self): - try: - chunk = await self.__wrapped__.__anext__() - await self._extract_token_chunk(chunk) - _loop_handler(self._dd_span, chunk, self._streamed_chunks) - return chunk - except StopAsyncIteration: - _process_finished_stream( - self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion - ) - self._dd_span.finish() - raise - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - self._dd_span.finish() - raise - - async def _extract_token_chunk(self, chunk): + async def _extract_token_chunk(self, chunk, iterator=None): """Attempt to extract the token chunk (last chunk in the stream) from the streamed response.""" - if not self._dd_span._get_ctx_item("_dd.auto_extract_token_chunk"): + if not self.primary_span._get_ctx_item("_dd.auto_extract_token_chunk"): return choices = getattr(chunk, "choices") if not choices: @@ -163,12 +114,11 @@ async def _extract_token_chunk(self, chunk): if not getattr(choice, "finish_reason", None): return try: - usage_chunk = await self.__wrapped__.__anext__() - self._streamed_chunks[0].insert(0, usage_chunk) + usage_chunk = await iterator.__anext__() + self.chunks[0].insert(0, usage_chunk) except (StopAsyncIteration, GeneratorExit): return - def _compute_token_count(content, model): # type: (Union[str, List[int]], Optional[str]) -> Tuple[bool, int] """ diff --git a/ddtrace/contrib/internal/vertexai/_utils.py b/ddtrace/contrib/internal/vertexai/_utils.py index 8d2b28f06c2..5ea3d5782e4 100644 --- a/ddtrace/contrib/internal/vertexai/_utils.py +++ b/ddtrace/contrib/internal/vertexai/_utils.py @@ -4,85 +4,39 @@ from vertexai.generative_models import Part from ddtrace.internal.utils import get_argument_value +from ddtrace.llmobs._integrations.base_stream_handler import AsyncStreamHandler +from ddtrace.llmobs._integrations.base_stream_handler import StreamHandler from ddtrace.llmobs._integrations.utils import get_generation_config_google from ddtrace.llmobs._integrations.utils import get_system_instructions_from_google_model from ddtrace.llmobs._integrations.utils import tag_request_content_part_google from ddtrace.llmobs._integrations.utils import tag_response_part_google from ddtrace.llmobs._utils import _get_attr - -class BaseTracedVertexAIStreamResponse: - def __init__(self, generator, model_instance, integration, span, args, kwargs, is_chat, history): - self._generator = generator - self._model_instance = model_instance - self._dd_integration = integration - self._dd_span = span - self._args = args - self._kwargs = kwargs - self.is_chat = is_chat - self._chunks = [] - self._history = history - - -class TracedVertexAIStreamResponse(BaseTracedVertexAIStreamResponse): - def __enter__(self): - self._generator.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - 
self._generator.__exit__(exc_type, exc_val, exc_tb) - - def __iter__(self): - try: - for chunk in self._generator.__iter__(): - # only keep track of the first chunk for chat messages since - # it is modified during the streaming process - if not self.is_chat or not self._chunks: - self._chunks.append(chunk) - yield chunk - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - raise - finally: - tag_stream_response(self._dd_span, self._chunks, self._dd_integration) - if self._dd_integration.is_pc_sampled_llmobs(self._dd_span): - self._kwargs["instance"] = self._model_instance - self._kwargs["history"] = self._history - self._dd_integration.llmobs_set_tags( - self._dd_span, args=self._args, kwargs=self._kwargs, response=self._chunks - ) - self._dd_span.finish() - - -class TracedAsyncVertexAIStreamResponse(BaseTracedVertexAIStreamResponse): - def __aenter__(self): - self._generator.__enter__() - return self - - def __aexit__(self, exc_type, exc_val, exc_tb): - self._generator.__exit__(exc_type, exc_val, exc_tb) - - async def __aiter__(self): - try: - async for chunk in self._generator.__aiter__(): - # only keep track of the first chunk for chat messages since - # it is modified during the streaming process - if not self.is_chat or not self._chunks: - self._chunks.append(chunk) - yield chunk - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - raise - finally: - tag_stream_response(self._dd_span, self._chunks, self._dd_integration) - if self._dd_integration.is_pc_sampled_llmobs(self._dd_span): - self._kwargs["instance"] = self._model_instance - self._kwargs["history"] = self._history - self._dd_integration.llmobs_set_tags( - self._dd_span, args=self._args, kwargs=self._kwargs, response=self._chunks - ) - self._dd_span.finish() - +class BaseVertexAIStreamHandler: + def _initialize_chunk_storage(self): + return [] + + def _process_chunk(self, chunk): + if not self.options.get("is_chat", False) or not self.chunks: + self.chunks.append(chunk) + + def finalize_stream(self, exception=None): + tag_stream_response(self.primary_span, self.chunks, self.integration) + if self.integration.is_pc_sampled_llmobs(self.primary_span): + self.request_kwargs["instance"] = self.options.get("model_instance", None) + self.request_kwargs["history"] = self.options.get("history", None) + self.integration.llmobs_set_tags( + self.primary_span, args=self.request_args, kwargs=self.request_kwargs, response=self.chunks + ) + self.primary_span.finish() + +class VertexAIStreamHandler(BaseVertexAIStreamHandler, StreamHandler): + def process_chunk(self, chunk, iterator=None): + self._process_chunk(chunk) + +class VertexAIAsyncStreamHandler(BaseVertexAIStreamHandler, AsyncStreamHandler): + async def process_chunk(self, chunk, iterator=None): + self._process_chunk(chunk) def extract_info_from_parts(parts): """Return concatenated text from parts and function calls.""" diff --git a/ddtrace/contrib/internal/vertexai/patch.py b/ddtrace/contrib/internal/vertexai/patch.py index 6dbe455cd99..b24197510b1 100644 --- a/ddtrace/contrib/internal/vertexai/patch.py +++ b/ddtrace/contrib/internal/vertexai/patch.py @@ -8,11 +8,13 @@ from ddtrace.contrib.internal.trace_utils import unwrap from ddtrace.contrib.internal.trace_utils import with_traced_module from ddtrace.contrib.internal.trace_utils import wrap -from ddtrace.contrib.internal.vertexai._utils import TracedAsyncVertexAIStreamResponse -from ddtrace.contrib.internal.vertexai._utils import TracedVertexAIStreamResponse from ddtrace.contrib.internal.vertexai._utils 
import tag_request from ddtrace.contrib.internal.vertexai._utils import tag_response from ddtrace.llmobs._integrations import VertexAIIntegration +from ddtrace.llmobs._integrations.base_stream_handler import make_traced_async_stream +from ddtrace.llmobs._integrations.base_stream_handler import make_traced_stream +from ddtrace.contrib.internal.vertexai._utils import VertexAIAsyncStreamHandler +from ddtrace.contrib.internal.vertexai._utils import VertexAIStreamHandler from ddtrace.llmobs._integrations.utils import extract_model_name_google from ddtrace.trace import Pin @@ -72,9 +74,7 @@ def _traced_generate(vertexai, pin, func, instance, args, kwargs, model_instance tag_request(span, integration, instance, args, kwargs, is_chat) generations = func(*args, **kwargs) if stream: - return TracedVertexAIStreamResponse( - generations, model_instance, integration, span, args, kwargs, is_chat, history - ) + return make_traced_stream(generations, VertexAIStreamHandler(integration, span, args, kwargs, is_chat=is_chat, history=history, model_instance=model_instance)) tag_response(span, generations, integration) except Exception: span.set_exc_info(*sys.exc_info()) @@ -107,9 +107,7 @@ async def _traced_agenerate(vertexai, pin, func, instance, args, kwargs, model_i tag_request(span, integration, instance, args, kwargs, is_chat) generations = await func(*args, **kwargs) if stream: - return TracedAsyncVertexAIStreamResponse( - generations, model_instance, integration, span, args, kwargs, is_chat, history - ) + return make_traced_async_stream(generations, VertexAIAsyncStreamHandler(integration, span, args, kwargs, is_chat=is_chat, history=history, model_instance=model_instance)) tag_response(span, generations, integration) except Exception: span.set_exc_info(*sys.exc_info()) diff --git a/ddtrace/llmobs/_integrations/base_stream_handler.py b/ddtrace/llmobs/_integrations/base_stream_handler.py new file mode 100644 index 00000000000..4940db099c9 --- /dev/null +++ b/ddtrace/llmobs/_integrations/base_stream_handler.py @@ -0,0 +1,214 @@ +import sys +from abc import ABC, abstractmethod +from collections import defaultdict + +import wrapt + +from ddtrace.internal.logger import get_logger + +log = get_logger(__name__) + + +class BaseStreamHandler(ABC): + def __init__(self, integration, span, args, kwargs, **options): + self.integration = integration + self.primary_span = span + self.request_args = args + self.request_kwargs = kwargs + self.options = options + + self.spans = [(span, kwargs)] + self.chunks = self._initialize_chunk_storage() + + def _initialize_chunk_storage(self): + return defaultdict(list) + + def add_span(self, span, kwargs): + self.spans.append((span, kwargs)) + + def handle_exception(self, exception): + """ + Handle exceptions that occur during streaming. + + Default implementation sets exception info on the primary span. + + Args: + exception: The exception that occurred + """ + if self.primary_span: + self.primary_span.set_exc_info(*sys.exc_info()) + + @abstractmethod + def finalize_stream(self, exception=None): + """ + Finalize the stream and complete all spans. + + This method is called when the stream ends (successfully or with error). + Implementations should: + 1. Process accumulated chunks into final response + 2. Set appropriate span tags + 3. Finish all spans + """ + pass + + +class StreamHandler(BaseStreamHandler): + @abstractmethod + def process_chunk(self, chunk, iterator=None): + """ + Process a single chunk from the stream. + + This method is called for each chunk as it's received. 
+ Implementations should extract and store relevant data. + + Args: + chunk: The chunk object from the stream + iterator: The sync iterator object from the stream + """ + pass + + +class AsyncStreamHandler(BaseStreamHandler): + @abstractmethod + async def process_chunk(self, chunk, iterator=None): + """ + Process a single chunk from the stream. + + This method is called for each chunk as it's received. + Implementations should extract and store relevant data. + + Args: + chunk: The chunk object from the stream + iterator: The async iterator object from the stream + """ + pass + + +class TracedStream(wrapt.ObjectProxy): + def __init__(self, wrapped, handler: StreamHandler, on_stream_created=None): + """ + Wrap a stream object to trace the stream. + + Args: + wrapped: The stream object to wrap + handler: The StreamHandler instance to use for processing chunks + on_stream_created: In the case that the stream is created by a stream manager, this + callback function will be called when the underlying stream is created in case + modifications to the stream object are needed + """ + super().__init__(wrapped) + self._self_handler = handler + self._self_on_stream_created = on_stream_created + self._self_stream_iter = self.__wrapped__ + + def __iter__(self): + exc = None + try: + for chunk in self._self_stream_iter: + self._self_handler.process_chunk(chunk, self._self_stream_iter) + yield chunk + except Exception as e: + exc = e + self._self_handler.handle_exception(e) + raise + finally: + self._self_handler.finalize_stream(exc) + + def __enter__(self): + """ + Enter the context of the stream. + + If the stream is wrapped by a stream manager, the stream manager will be entered and the + underlying stream will be wrapped in a TracedStream object. The _self_on_stream_created + callback function will be called on the TracedStream object if it is provided and then it + will be returned. + + If the stream is not wrapped by a stream manager, the stream will be returned as is. + """ + if hasattr(self.__wrapped__, '__enter__'): + result = self.__wrapped__.__enter__() + # update iterator in case we are wrapping a stream manager + if result is not self.__wrapped__: + self._self_stream_iter = result + traced_stream = TracedStream(result, self._self_handler, self._self_on_stream_created) + if self._self_on_stream_created: + self._self_on_stream_created(traced_stream) + return traced_stream + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if hasattr(self.__wrapped__, '__exit__'): + return self.__wrapped__.__exit__(exc_type, exc_val, exc_tb) + + @property + def handler(self): + return self._self_handler + + +class TracedAsyncStream(wrapt.ObjectProxy): + def __init__(self, wrapped, handler: AsyncStreamHandler, on_stream_created=None): + """ + Wrap an async stream object to trace the stream. 
+ + Args: + wrapped: The stream object to wrap + handler: The AsyncStreamHandler instance to use for processing chunks + on_stream_created: In the case that the stream is created by a stream manager, this + callback function will be called when the underlying stream is created in case + modifications to the stream object are needed + """ + super().__init__(wrapped) + self._self_handler = handler + self._self_on_stream_created = on_stream_created + self._self_async_stream_iter = self.__wrapped__ + + async def __aiter__(self): + exc = None + try: + async for chunk in self._self_async_stream_iter: + await self._self_handler.process_chunk(chunk, self._self_async_stream_iter) + yield chunk + except Exception as e: + exc = e + self._self_handler.handle_exception(e) + raise + finally: + self._self_handler.finalize_stream(exc) + + async def __aenter__(self): + """ + Enter the context of the stream. + + If the stream is wrapped by a stream manager, the stream manager will be entered and the + underlying stream will be wrapped in a TracedAsyncStream object. The _self_on_stream_created + callback function will be called on the TracedAsyncStream object if it is provided and then it + will be returned. + + If the stream is not wrapped by a stream manager, the stream will be returned as is. + """ + if hasattr(self.__wrapped__, '__aenter__'): + result = await self.__wrapped__.__aenter__() + # update iterator in case we are wrapping a stream manager + if result is not self.__wrapped__: + self._self_async_stream_iter = result + traced_stream = TracedAsyncStream(result, self._self_handler, self._self_on_stream_created) + if self._self_on_stream_created: + self._self_on_stream_created(traced_stream) + return traced_stream + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if hasattr(self.__wrapped__, '__aexit__'): + return await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb) + + @property + def handler(self): + return self._self_handler + + +def make_traced_stream(wrapped, handler: StreamHandler, on_stream_created=None): + return TracedStream(wrapped, handler, on_stream_created) + + +def make_traced_async_stream(wrapped, handler: AsyncStreamHandler, on_stream_created=None): + return TracedAsyncStream(wrapped, handler, on_stream_created) \ No newline at end of file
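Note: the sketch below is illustrative only and is not part of the patch. Under assumed names (ExampleStreamHandler, trace_streamed_response, the example.chunk_count tag), it shows how an integration is expected to plug into the new base_stream_handler API introduced above: subclass StreamHandler, accumulate chunks in process_chunk, finish the span in finalize_stream, and hand the provider's stream object to make_traced_stream, which returns a proxy the caller iterates (or uses as a context manager) exactly like the original object. For example:

    from ddtrace.llmobs._integrations.base_stream_handler import StreamHandler
    from ddtrace.llmobs._integrations.base_stream_handler import make_traced_stream


    class ExampleStreamHandler(StreamHandler):
        """Hypothetical handler: collect every chunk, then tag and finish the span."""

        def process_chunk(self, chunk, iterator=None):
            # Called once per chunk as the user consumes the proxied stream; `iterator`
            # is the underlying stream iterator (used e.g. by the OpenAI handler to pull
            # the trailing usage chunk). self.chunks defaults to a defaultdict(list).
            self.chunks[0].append(chunk)

        def finalize_stream(self, exception=None):
            # Called exactly once when iteration stops, whether normally or with an error.
            if exception is None:
                self.primary_span.set_tag_str("example.chunk_count", str(len(self.chunks[0])))
            self.primary_span.finish()


    def trace_streamed_response(integration, span, resp, args, kwargs):
        # `resp` is whatever stream object the wrapped SDK returned; the returned proxy
        # drives process_chunk/finalize_stream as the user consumes it.
        handler = ExampleStreamHandler(integration, span, args, kwargs)
        return make_traced_stream(resp, handler)

The async path is symmetric: an AsyncStreamHandler subclass defines process_chunk as a coroutine and the stream is wrapped with make_traced_async_stream. Integration-specific extras (model_instance, chunk_callback, execution_ctx, and so on) are passed as keyword options to the handler constructor and read back through self.options, as the integrations in this patch do.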