[MLOB-3112] make shared base classes for llmobs traced streams #13736

Draft · wants to merge 9 commits into main
178 changes: 51 additions & 127 deletions ddtrace/contrib/internal/anthropic/_streaming.py
@@ -7,6 +7,10 @@
import anthropic
import wrapt

from ddtrace.llmobs._integrations.base_stream_handler import AsyncStreamHandler
from ddtrace.llmobs._integrations.base_stream_handler import make_traced_async_stream
from ddtrace.llmobs._integrations.base_stream_handler import make_traced_stream
from ddtrace.llmobs._integrations.base_stream_handler import StreamHandler
from ddtrace.contrib.internal.anthropic.utils import tag_tool_use_output_on_span
from ddtrace.internal.logger import get_logger
from ddtrace.llmobs._utils import _get_attr
@@ -15,139 +19,59 @@
log = get_logger(__name__)


def handle_streamed_response(integration, resp, args, kwargs, span):
if _is_stream(resp):
return TracedAnthropicStream(resp, integration, span, args, kwargs)
elif _is_async_stream(resp):
return TracedAnthropicAsyncStream(resp, integration, span, args, kwargs)
elif _is_stream_manager(resp):
return TracedAnthropicStreamManager(resp, integration, span, args, kwargs)
elif _is_async_stream_manager(resp):
return TracedAnthropicAsyncStreamManager(resp, integration, span, args, kwargs)


class BaseTracedAnthropicStream(wrapt.ObjectProxy):
def __init__(self, wrapped, integration, span, args, kwargs):
super().__init__(wrapped)
self._dd_span = span
self._streamed_chunks = []
self._dd_integration = integration
self._kwargs = kwargs
self._args = args


class TracedAnthropicStream(BaseTracedAnthropicStream):
def __init__(self, wrapped, integration, span, args, kwargs):
super().__init__(wrapped, integration, span, args, kwargs)
# we need to set a text_stream attribute so we can trace the yielded chunks
self.text_stream = self.__stream_text__()

def __enter__(self):
self.__wrapped__.__enter__()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.__wrapped__.__exit__(exc_type, exc_val, exc_tb)

def __iter__(self):
return self

def __next__(self):
try:
chunk = self.__wrapped__.__next__()
self._streamed_chunks.append(chunk)
return chunk
except StopIteration:
_process_finished_stream(
self._dd_integration, self._dd_span, self._args, self._kwargs, self._streamed_chunks
)
self._dd_span.finish()
raise
except Exception:
self._dd_span.set_exc_info(*sys.exc_info())
self._dd_span.finish()
raise

def __stream_text__(self):
# this is overridden because it is a helper function that collects all stream content chunks
for chunk in self:
if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
yield chunk.delta.text


class TracedAnthropicAsyncStream(BaseTracedAnthropicStream):
def __init__(self, wrapped, integration, span, args, kwargs):
super().__init__(wrapped, integration, span, args, kwargs)
# we need to set a text_stream attribute so we can trace the yielded chunks
self.text_stream = self.__stream_text__()

async def __aenter__(self):
await self.__wrapped__.__aenter__()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)

def __aiter__(self):
return self

async def __anext__(self):
try:
chunk = await self.__wrapped__.__anext__()
self._streamed_chunks.append(chunk)
return chunk
except StopAsyncIteration:
_process_finished_stream(
self._dd_integration,
self._dd_span,
self._args,
self._kwargs,
self._streamed_chunks,
)
self._dd_span.finish()
raise
except Exception:
self._dd_span.set_exc_info(*sys.exc_info())
self._dd_span.finish()
raise

async def __stream_text__(self):
# this is overridden because it is a helper function that collects all stream content chunks
async for chunk in self:
if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
yield chunk.delta.text


class TracedAnthropicStreamManager(BaseTracedAnthropicStream):
def __enter__(self):
stream = self.__wrapped__.__enter__()
traced_stream = TracedAnthropicStream(
stream,
self._dd_integration,
self._dd_span,
self._args,
self._kwargs,
)
return traced_stream

def __exit__(self, exc_type, exc_val, exc_tb):
self.__wrapped__.__exit__(exc_type, exc_val, exc_tb)
def _text_stream_generator(traced_stream):
for chunk in traced_stream:
if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
yield chunk.delta.text


async def _async_text_stream_generator(traced_stream):
async for chunk in traced_stream:
if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
yield chunk.delta.text

class TracedAnthropicAsyncStreamManager(BaseTracedAnthropicStream):
async def __aenter__(self):
stream = await self.__wrapped__.__aenter__()
traced_stream = TracedAnthropicAsyncStream(
stream,
self._dd_integration,
self._dd_span,
self._args,
self._kwargs,
def handle_streamed_response(integration, resp, args, kwargs, span):
def add_text_stream(stream):
stream.text_stream = _text_stream_generator(stream)

def add_async_text_stream(stream):
stream.text_stream = _async_text_stream_generator(stream)

if _is_stream(resp) or _is_stream_manager(resp):
traced_stream = make_traced_stream(
resp,
AnthropicStreamHandler(integration, span, args, kwargs),
on_stream_created=add_text_stream
)
traced_stream.text_stream = _text_stream_generator(traced_stream)
return traced_stream
elif _is_async_stream(resp) or _is_async_stream_manager(resp):
traced_stream = make_traced_async_stream(
resp,
AnthropicAsyncStreamHandler(integration, span, args, kwargs),
on_stream_created=add_async_text_stream
)
traced_stream.text_stream = _async_text_stream_generator(traced_stream)
return traced_stream

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)
class BaseAnthropicStreamHandler:
Contributor review comment (Code Quality Violation):
Class BaseAnthropicStreamHandler should have an __init__ method. Ensure that a class has an __init__ method; this check is bypassed when the class is a data class (annotated with @dataclass).

def _initialize_chunk_storage(self):
return []

def finalize_stream(self, exception=None):
_process_finished_stream(
self.integration, self.primary_span, self.request_args, self.request_kwargs, self.chunks
)
self.primary_span.finish()

class AnthropicStreamHandler(BaseAnthropicStreamHandler, StreamHandler):
def process_chunk(self, chunk, iterator=None):
self.chunks.append(chunk)

class AnthropicAsyncStreamHandler(BaseAnthropicStreamHandler, AsyncStreamHandler):
async def process_chunk(self, chunk, iterator=None):
self.chunks.append(chunk)


def _process_finished_stream(integration, span, args, kwargs, streamed_chunks):
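
The shared base module imported at the top of this file, ddtrace.llmobs._integrations.base_stream_handler, is not shown in this diff. Purely for orientation, here is a minimal sketch of what its synchronous half presumably provides, inferred from how the Anthropic and Bedrock handlers use it (the handler attributes integration, primary_span, request_args, request_kwargs, options, and chunks, plus the on_stream_created hook, are taken from that usage; the _TracedStream proxy name is invented for this sketch, and the real implementation may differ):

# Hedged sketch only: inferred from how the handlers in this PR call into
# ddtrace.llmobs._integrations.base_stream_handler; the real module may differ.
import sys

import wrapt


class StreamHandler:
    """Holds the span, request data, and collected chunks for a traced sync stream."""

    def __init__(self, integration, span, args, kwargs, **options):
        self.integration = integration
        self.primary_span = span
        self.request_args = args
        self.request_kwargs = kwargs
        self.options = options  # integration-specific extras, e.g. execution_ctx
        self.chunks = self._initialize_chunk_storage()

    def _initialize_chunk_storage(self):
        return []

    def process_chunk(self, chunk, iterator=None):
        """Record a single chunk as the user consumes the stream."""
        raise NotImplementedError

    def handle_exception(self, exception):
        """Tag the span with info about an exception raised mid-stream."""
        self.primary_span.set_exc_info(*sys.exc_info())

    def finalize_stream(self, exception=None):
        """Run once the stream is exhausted or errors out."""
        self.primary_span.finish()


class _TracedStream(wrapt.ObjectProxy):
    """Proxies the underlying stream and routes every chunk through the handler."""

    def __init__(self, wrapped, handler):
        super().__init__(wrapped)
        # exposed so integration code can reach the handler, e.g. traced_stream.handler
        self.handler = handler

    def __iter__(self):
        return self

    def __next__(self):
        try:
            chunk = self.__wrapped__.__next__()
            self.handler.process_chunk(chunk, iterator=self)
            return chunk
        except StopIteration:
            self.handler.finalize_stream()
            raise
        except Exception as e:
            self.handler.handle_exception(e)
            self.handler.finalize_stream(exception=e)
            raise


def make_traced_stream(stream, handler, on_stream_created=None):
    traced = _TracedStream(stream, handler)
    if on_stream_created is not None:
        on_stream_created(traced)
    return traced

The async counterpart (AsyncStreamHandler / make_traced_async_stream) would presumably mirror this with __aiter__/__anext__, StopAsyncIteration, and an awaited process_chunk, which is why AnthropicAsyncStreamHandler.process_chunk above is declared async.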
153 changes: 81 additions & 72 deletions ddtrace/contrib/internal/botocore/services/bedrock.py
@@ -13,6 +13,8 @@
from ddtrace.internal import core
from ddtrace.internal.logger import get_logger
from ddtrace.internal.schema import schematize_service_name
from ddtrace.llmobs._integrations.base_stream_handler import make_traced_stream
from ddtrace.llmobs._integrations.base_stream_handler import StreamHandler
from ddtrace.llmobs._integrations.bedrock_utils import parse_model_id


@@ -27,81 +29,88 @@
_STABILITY = "stability"


class TracedBotocoreStreamingBody(wrapt.ObjectProxy):
"""
This class wraps the StreamingBody object returned by botocore api calls, specifically for Bedrock invocations.
Since the response body is in the form of a stream object, we need to wrap it in order to tag the response data
and fire completion events as the user consumes the streamed response.
"""

def __init__(self, wrapped, ctx: core.ExecutionContext):
super().__init__(wrapped)
self._body = []
self._execution_ctx = ctx

def read(self, amt=None):
"""Wraps around method to tags the response data and finish the span as the user consumes the stream."""
try:
body = self.__wrapped__.read(amt=amt)
self._body.append(json.loads(body))
if self.__wrapped__.tell() == int(self.__wrapped__._content_length):
formatted_response = _extract_text_and_response_reason(self._execution_ctx, self._body[0])
model_provider = self._execution_ctx["model_provider"]
model_name = self._execution_ctx["model_name"]
should_set_choice_ids = model_provider == _COHERE and "embed" not in model_name
core.dispatch(
"botocore.bedrock.process_response",
[self._execution_ctx, formatted_response, None, self._body[0], should_set_choice_ids],
)
return body
except Exception:
core.dispatch("botocore.patched_bedrock_api_call.exception", [self._execution_ctx, sys.exc_info()])
raise

def readlines(self):
"""Wraps around method to tags the response data and finish the span as the user consumes the stream."""
try:
lines = self.__wrapped__.readlines()
for line in lines:
self._body.append(json.loads(line))
formatted_response = _extract_text_and_response_reason(self._execution_ctx, self._body[0])
model_provider = self._execution_ctx["model_provider"]
model_name = self._execution_ctx["model_name"]
def traced_stream_read(traced_stream, original_read, amt=None):
"""Wraps around method to tags the response data and finish the span as the user consumes the stream."""
handler = traced_stream.handler
execution_ctx = handler.options.get("execution_ctx", {})
try:
body = original_read(amt=amt)
handler.chunks.append(json.loads(body))
if traced_stream.__wrapped__.tell() == int(traced_stream.__wrapped__._content_length):
formatted_response = _extract_text_and_response_reason(execution_ctx, handler.chunks[0])
model_provider = execution_ctx["model_provider"]
model_name = execution_ctx["model_name"]
should_set_choice_ids = model_provider == _COHERE and "embed" not in model_name
core.dispatch(
"botocore.bedrock.process_response",
[self._execution_ctx, formatted_response, None, self._body[0], should_set_choice_ids],
)
return lines
except Exception:
core.dispatch("botocore.patched_bedrock_api_call.exception", [self._execution_ctx, sys.exc_info()])
raise

def __iter__(self):
"""Wraps around method to tags the response data and finish the span as the user consumes the stream."""
exception_raised = False
try:
for line in self.__wrapped__:
self._body.append(json.loads(line["chunk"]["bytes"]))
yield line
except Exception:
core.dispatch("botocore.patched_bedrock_api_call.exception", [self._execution_ctx, sys.exc_info()])
exception_raised = True
raise
finally:
if exception_raised:
return
metadata = _extract_streamed_response_metadata(self._execution_ctx, self._body)
formatted_response = _extract_streamed_response(self._execution_ctx, self._body)
model_provider = self._execution_ctx["model_provider"]
model_name = self._execution_ctx["model_name"]
should_set_choice_ids = (
model_provider == _COHERE and "is_finished" not in self._body[0] and "embed" not in model_name
)
core.dispatch(
"botocore.bedrock.process_response",
[self._execution_ctx, formatted_response, metadata, self._body, should_set_choice_ids],
[execution_ctx, formatted_response, None, handler.chunks[0], should_set_choice_ids],
)
return body
except Exception:
core.dispatch("botocore.patched_bedrock_api_call.exception", [execution_ctx, sys.exc_info()])
raise

def traced_stream_readlines(traced_stream, original_readlines):
"""Wraps around method to tags the response data and finish the span as the user consumes the stream."""
handler = traced_stream.handler
execution_ctx = handler.options.get("execution_ctx", {})
try:
lines = original_readlines()
for line in lines:
handler.chunks.append(json.loads(line))
formatted_response = _extract_text_and_response_reason(execution_ctx, handler.chunks[0])
model_provider = execution_ctx["model_provider"]
model_name = execution_ctx["model_name"]
should_set_choice_ids = model_provider == _COHERE and "embed" not in model_name
core.dispatch(
"botocore.bedrock.process_response",
[execution_ctx, formatted_response, None, handler.chunks[0], should_set_choice_ids],
)
return lines
except Exception:
core.dispatch("botocore.patched_bedrock_api_call.exception", [execution_ctx, sys.exc_info()])
raise

class BotocoreStreamingBodyStreamHandler(StreamHandler):
def _initialize_chunk_storage(self):
return []

def process_chunk(self, chunk, iterator=None):
self.chunks.append(json.loads(chunk["chunk"]["bytes"]))

def handle_exception(self, exception):
core.dispatch("botocore.patched_bedrock_api_call.exception", [self.options.get("execution_ctx", {}), sys.exc_info()])

def finalize_stream(self, exception=None):
if exception:
return
execution_ctx = self.options.get("execution_ctx", {})
metadata = _extract_streamed_response_metadata(execution_ctx, self.chunks)
formatted_response = _extract_streamed_response(execution_ctx, self.chunks)
model_provider = execution_ctx["model_provider"]
model_name = execution_ctx["model_name"]
should_set_choice_ids = (
model_provider == _COHERE and "is_finished" not in self.chunks[0] and "embed" not in model_name
)
core.dispatch(
"botocore.bedrock.process_response",
[execution_ctx, formatted_response, metadata, self.chunks, should_set_choice_ids],
)


def make_botocore_streaming_body_traced_stream(streaming_body, integration, span, args, kwargs, execution_ctx):
original_read = getattr(streaming_body, "read", None)
original_readlines = getattr(streaming_body, "readlines", None)
traced_stream = make_traced_stream(
streaming_body,
BotocoreStreamingBodyStreamHandler(integration, span, args, kwargs, execution_ctx=execution_ctx),
)
# add bedrock-specific methods to the traced stream
if original_read:
traced_stream.read = lambda amt=None: traced_stream_read(traced_stream, original_read, amt)
if original_readlines:
traced_stream.readlines = lambda: traced_stream_readlines(traced_stream, original_readlines)
return traced_stream


class TracedBotocoreConverseStream(wrapt.ObjectProxy):
@@ -440,7 +449,7 @@ def handle_bedrock_response(
return result

body = result["body"]
result["body"] = TracedBotocoreStreamingBody(body, ctx)
result["body"] = make_botocore_streaming_body_traced_stream(body, None, None, None, None, ctx)
return result


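
For orientation only, here is a toy, self-contained illustration of the two consumption paths the Bedrock wrappers above cover. FakeStreamingBody and DemoHandler are invented stand-ins for botocore's StreamingBody and for BotocoreStreamingBodyStreamHandler; they mimic only the pieces the wrappers rely on (read(), tell(), _content_length, and iteration over {"chunk": {"bytes": ...}} events), so this is an assumption-laden sketch rather than the integration's real behavior:

# Invented stand-ins for illustration; not part of the diff.
import json


class DemoHandler:
    """Minimal stand-in for BotocoreStreamingBodyStreamHandler's chunk storage."""

    def __init__(self):
        self.chunks = []

    def process_chunk(self, chunk):
        # Same parsing as BotocoreStreamingBodyStreamHandler.process_chunk above.
        self.chunks.append(json.loads(chunk["chunk"]["bytes"]))


class FakeStreamingBody:
    """Minimal stand-in exposing only what the traced wrappers touch."""

    def __init__(self, payload=b"", events=None):
        self._payload = payload
        self._content_length = len(payload)
        self._pos = 0
        self._events = events or []

    def read(self, amt=None):
        end = self._content_length if amt is None else min(self._pos + amt, self._content_length)
        data = self._payload[self._pos:end]
        self._pos = end
        return data

    def tell(self):
        return self._pos

    def __iter__(self):
        return iter(self._events)


# Non-streamed invoke_model path: read() plus the tell()/_content_length check
# that traced_stream_read uses to decide when the body has been fully consumed.
body = FakeStreamingBody(payload=json.dumps({"completion": "hello"}).encode())
handler = DemoHandler()
handler.chunks.append(json.loads(body.read()))
assert body.tell() == int(body._content_length)

# Streamed invoke_model_with_response_stream path: iterating routes each event
# through process_chunk; in the real wrapper, finalize_stream then dispatches
# the collected chunks for LLM Observability processing.
stream = FakeStreamingBody(events=[{"chunk": {"bytes": json.dumps({"completion": "hi"}).encode()}}])
for event in stream:
    handler.process_chunk(event)
print(handler.chunks)  # [{'completion': 'hello'}, {'completion': 'hi'}]

In the real code, that tell() check is what gates the "botocore.bedrock.process_response" dispatch for non-streamed reads, while BotocoreStreamingBodyStreamHandler.finalize_stream dispatches it once iteration ends.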