@@ -1,6 +1,8 @@
"""OpenTelemetry Google Generative AI API instrumentation"""

import logging
import os
import time
import types
from typing import Collection

@@ -33,7 +35,8 @@
LLMRequestTypeValues,
SpanAttributes,
)
from opentelemetry.trace import SpanKind, get_tracer
from opentelemetry.metrics import Meter, get_meter
from opentelemetry.trace import SpanKind, get_tracer, StatusCode
from wrapt import wrap_function_wrapper

logger = logging.getLogger(__name__)
@@ -79,6 +82,7 @@ def _build_from_streaming_response(
response: GenerateContentResponse,
llm_model,
event_logger,
token_histogram,
):
complete_response = ""
last_chunk = None
@@ -93,12 +97,14 @@ def _build_from_streaming_response(
emit_choice_events(response, event_logger)
else:
set_response_attributes(span, complete_response, llm_model)
set_model_response_attributes(span, last_chunk or response, llm_model)
set_model_response_attributes(
span, last_chunk or response, llm_model, token_histogram
)
span.end()


async def _abuild_from_streaming_response(
span, response: GenerateContentResponse, llm_model, event_logger
span, response: GenerateContentResponse, llm_model, event_logger, token_histogram
):
complete_response = ""
last_chunk = None
@@ -113,7 +119,9 @@ async def _abuild_from_streaming_response(
emit_choice_events(response, event_logger)
else:
set_response_attributes(span, complete_response, llm_model)
set_model_response_attributes(span, last_chunk if last_chunk else response, llm_model)
set_model_response_attributes(
span, last_chunk if last_chunk else response, llm_model, token_histogram
)
span.end()


@@ -128,21 +136,33 @@ def _handle_request(span, args, kwargs, llm_model, event_logger):


@dont_throw
def _handle_response(span, response, llm_model, event_logger):
def _handle_response(span, response, llm_model, event_logger, token_histogram):
if should_emit_events() and event_logger:
emit_choice_events(response, event_logger)
else:
set_response_attributes(span, response, llm_model)

set_model_response_attributes(span, response, llm_model)
set_model_response_attributes(span, response, llm_model, token_histogram)


def _with_tracer_wrapper(func):
"""Helper for providing tracer for wrapper functions."""

def _with_tracer(tracer, event_logger, to_wrap):
def _with_tracer(
tracer, event_logger, to_wrap, token_histogram, duration_histogram
):
def wrapper(wrapped, instance, args, kwargs):
return func(tracer, event_logger, to_wrap, wrapped, instance, args, kwargs)
return func(
tracer,
event_logger,
to_wrap,
token_histogram,
duration_histogram,
wrapped,
instance,
args,
kwargs,
)

return wrapper

@@ -154,6 +174,8 @@ async def _awrap(
tracer,
event_logger,
to_wrap,
token_histogram,
duration_histogram,
wrapped,
instance,
args,
@@ -186,22 +208,38 @@ async def _awrap(
SpanAttributes.LLM_REQUEST_TYPE: LLMRequestTypeValues.COMPLETION.value,
},
)

start_time = time.perf_counter()
_handle_request(span, args, kwargs, llm_model, event_logger)

response = await wrapped(*args, **kwargs)

try:
response = await wrapped(*args, **kwargs)
except Exception as e:
span.record_exception(e)
span.set_status(StatusCode.ERROR)
span.end()
raise e

if duration_histogram:
duration = time.perf_counter() - start_time
duration_histogram.record(
duration,
attributes={
GenAIAttributes.GEN_AI_PROVIDER_NAME: "Google",
GenAIAttributes.GEN_AI_RESPONSE_MODEL: llm_model,
},
)
if response:
if is_streaming_response(response):
return _build_from_streaming_response(
span, response, llm_model, event_logger
span, response, llm_model, event_logger, token_histogram
)
elif is_async_streaming_response(response):
return _abuild_from_streaming_response(
span, response, llm_model, event_logger
span, response, llm_model, event_logger, token_histogram
)
else:
_handle_response(span, response, llm_model, event_logger)
_handle_response(
span, response, llm_model, event_logger, token_histogram
)

span.end()
return response
@@ -212,6 +250,8 @@ def _wrap(
tracer,
event_logger,
to_wrap,
token_histogram,
duration_histogram,
wrapped,
instance,
args,
@@ -245,30 +285,72 @@ def _wrap(
},
)

start_time = time.perf_counter()
_handle_request(span, args, kwargs, llm_model, event_logger)

response = wrapped(*args, **kwargs)

try:
response = wrapped(*args, **kwargs)
except Exception as e:
span.record_exception(e)
span.set_status(StatusCode.ERROR)
span.end()
raise e

if duration_histogram:
duration = time.perf_counter() - start_time
duration_histogram.record(
duration,
attributes={
GenAIAttributes.GEN_AI_PROVIDER_NAME: "Google",
GenAIAttributes.GEN_AI_RESPONSE_MODEL: llm_model,
},
)
if response:
if is_streaming_response(response):
return _build_from_streaming_response(
span, response, llm_model, event_logger
span, response, llm_model, event_logger, token_histogram
)
elif is_async_streaming_response(response):
return _abuild_from_streaming_response(
span, response, llm_model, event_logger
span, response, llm_model, event_logger, token_histogram
)
else:
_handle_response(span, response, llm_model, event_logger)
_handle_response(
span, response, llm_model, event_logger, token_histogram
)

span.end()
return response


def is_metrics_enabled() -> bool:
return (os.getenv("TRACELOOP_METRICS_ENABLED") or "true").lower() == "true"


def _create_metrics(meter: Meter):
token_histogram = meter.create_histogram(
name="gen_ai.client.token.usage",
unit="token",
description="Measures number of input and output tokens used",
)

duration_histogram = meter.create_histogram(
name="gen_ai.client.operation.duration",
unit="s",
description="GenAI operation duration",
)

return token_histogram, duration_histogram


class GoogleGenerativeAiInstrumentor(BaseInstrumentor):
"""An instrumentor for Google Generative AI's client library."""

def __init__(self, exception_logger=None, use_legacy_attributes=True, upload_base64_image=None):
def __init__(
self,
exception_logger=None,
use_legacy_attributes=True,
upload_base64_image=None,
):
super().__init__()
Config.exception_logger = exception_logger
Config.use_legacy_attributes = use_legacy_attributes
@@ -285,6 +367,15 @@ def _instrument(self, **kwargs):
tracer_provider = kwargs.get("tracer_provider")
tracer = get_tracer(__name__, __version__, tracer_provider)

meter_provider = kwargs.get("meter_provider")
meter = get_meter(__name__, __version__, meter_provider)

token_histogram = None
duration_histogram = None

if is_metrics_enabled():
token_histogram, duration_histogram = _create_metrics(meter)

event_logger = None
if not Config.use_legacy_attributes:
logger_provider = kwargs.get("logger_provider")
@@ -297,14 +388,24 @@ def _instrument(self, **kwargs):
wrap_object = wrapped_method.get("object")
wrap_method = wrapped_method.get("method")

wrapper_args = (
tracer,
event_logger,
wrapped_method,
token_histogram,
duration_histogram,
)

wrapper = (
_awrap(*wrapper_args)
if wrap_object == "AsyncModels"
else _wrap(*wrapper_args)
)

wrap_function_wrapper(
wrap_package,
f"{wrap_object}.{wrap_method}",
(
_awrap(tracer, event_logger, wrapped_method)
if wrap_object == "AsyncModels"
else _wrap(tracer, event_logger, wrapped_method)
),
wrapper,
)

def _uninstrument(self, **kwargs):

@@ -446,7 +446,7 @@ def set_response_attributes(span, response, llm_model):
)


def set_model_response_attributes(span, response, llm_model):
def set_model_response_attributes(span, response, llm_model, token_histogram):
if not span.is_recording():
return

@@ -469,4 +469,22 @@ def set_model_response_attributes(span, response, llm_model):
response.usage_metadata.prompt_token_count,
)

if token_histogram and hasattr(response, "usage_metadata"):
token_histogram.record(
response.usage_metadata.prompt_token_count,
attributes={
GenAIAttributes.GEN_AI_PROVIDER_NAME: "Google",
GenAIAttributes.GEN_AI_TOKEN_TYPE: "input",
GenAIAttributes.GEN_AI_RESPONSE_MODEL: llm_model,
}
)
token_histogram.record(
response.usage_metadata.candidates_token_count,
attributes={
GenAIAttributes.GEN_AI_PROVIDER_NAME: "Google",
GenAIAttributes.GEN_AI_TOKEN_TYPE: "output",
GenAIAttributes.GEN_AI_RESPONSE_MODEL: llm_model,
},
)

span.set_status(Status(StatusCode.OK))
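
To exercise the new histograms end to end, a minimal sketch along these lines should work. The instrumentor import path and the client call are assumptions, since neither appears in this diff; _instrument() reads "meter_provider" from its kwargs, so the provider can be passed directly to instrument().

from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import InMemoryMetricReader
from opentelemetry.instrumentation.google_generativeai import (
    GoogleGenerativeAiInstrumentor,  # assumed import path, not shown in the diff
)

# Metrics are on by default; set TRACELOOP_METRICS_ENABLED=false to opt out.
reader = InMemoryMetricReader()
meter_provider = MeterProvider(metric_readers=[reader])

# _instrument() picks up "meter_provider" from kwargs and creates the histograms.
GoogleGenerativeAiInstrumentor().instrument(meter_provider=meter_provider)

# ... run a generate_content call with the google-genai client here ...

# gen_ai.client.token.usage and gen_ai.client.operation.duration should now
# be present in the collected metrics.
print(reader.get_metrics_data())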