From 3ad444b4b1338ab8a91379bf4e3018caa22a8566 Mon Sep 17 00:00:00 2001
From: Marcelo Trylesinski
Date: Thu, 19 Feb 2026 14:56:38 +0100
Subject: [PATCH 1/6] Add Azure AI Inference integration

Instrument `ChatCompletionsClient.complete()` and `EmbeddingsClient.embed()`
from the `azure-ai-inference` SDK with custom monkey-patching. Supports sync
and async clients, streaming, tool calls, and both v1/latest semconv versions.
---
 docs/integrations/llms/azure-ai-inference.md  | 138 ++++
 logfire-api/logfire_api/__init__.py           |   4 +
 logfire/__init__.py                           |   2 +
 .../llm_providers/azure_ai_inference.py       | 734 ++++++++++++++++++
 logfire/_internal/main.py                     |  99 +++
 mkdocs.yml                                    |   1 +
 pyproject.toml                                |   2 +
 .../test_azure_ai_inference.py                | 566 ++++++++++++++
 uv.lock                                       |  44 +-
 9 files changed, 1589 insertions(+), 1 deletion(-)
 create mode 100644 docs/integrations/llms/azure-ai-inference.md
 create mode 100644 logfire/_internal/integrations/llm_providers/azure_ai_inference.py
 create mode 100644 tests/otel_integrations/test_azure_ai_inference.py

diff --git a/docs/integrations/llms/azure-ai-inference.md b/docs/integrations/llms/azure-ai-inference.md
new file mode 100644
index 000000000..5279a0f14
--- /dev/null
+++ b/docs/integrations/llms/azure-ai-inference.md
@@ -0,0 +1,138 @@
---
title: Pydantic Logfire Azure AI Inference Integration
description: "Instrument calls to Azure AI Inference with logfire.instrument_azure_ai_inference(). Track chat completions, embeddings, streaming responses, and token usage."
integration: logfire
---
# Azure AI Inference

**Logfire** supports instrumenting calls to [Azure AI Inference](https://pypi.org/project/azure-ai-inference/) with the [`logfire.instrument_azure_ai_inference()`][logfire.Logfire.instrument_azure_ai_inference] method.

```python hl_lines="11-12" skip-run="true" skip-reason="external-connection"
from azure.ai.inference import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential

import logfire

client = ChatCompletionsClient(
    endpoint='https://my-endpoint.inference.ai.azure.com',
    credential=AzureKeyCredential('my-api-key'),
)

logfire.configure()
logfire.instrument_azure_ai_inference(client)

response = client.complete(
    model='gpt-4',
    messages=[
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'Please write me a limerick about Python logging.'},
    ],
)
print(response.choices[0].message.content)
```

With that you get:

* a span around the call which records duration and captures any exceptions that might occur
* Human-readable display of the conversation with the model
* details of the response, including the number of tokens used

## Installation

Install Logfire with the `azure-ai-inference` extra:

{{ install_logfire(extras=['azure-ai-inference']) }}

## Methods covered

The following methods are covered:

- [`ChatCompletionsClient.complete`](https://learn.microsoft.com/python/api/azure-ai-inference/azure.ai.inference.chatcompletionsclient) - with and without `stream=True`
- [`EmbeddingsClient.embed`](https://learn.microsoft.com/python/api/azure-ai-inference/azure.ai.inference.embeddingsclient)

All methods are covered with both sync (`azure.ai.inference`) and async (`azure.ai.inference.aio`) clients.

## Streaming Responses

When instrumenting streaming responses, Logfire creates two spans: one around the initial request and one around the streamed response.
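The streaming span is emitted once the stream has been fully consumed, and records the combined content of the received chunks.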
+ +```python skip-run="true" skip-reason="external-connection" +from azure.ai.inference import ChatCompletionsClient +from azure.core.credentials import AzureKeyCredential + +import logfire + +client = ChatCompletionsClient( + endpoint='https://my-endpoint.inference.ai.azure.com', + credential=AzureKeyCredential('my-api-key'), +) + +logfire.configure() +logfire.instrument_azure_ai_inference(client) + +response = client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Write Python to show a tree of files.'}], + stream=True, +) +for chunk in response: + if chunk.choices: + delta = chunk.choices[0].delta + if delta and delta.content: + print(delta.content, end='', flush=True) +``` + +## Embeddings + +You can also instrument the `EmbeddingsClient`: + +```python skip-run="true" skip-reason="external-connection" +from azure.ai.inference import EmbeddingsClient +from azure.core.credentials import AzureKeyCredential + +import logfire + +client = EmbeddingsClient( + endpoint='https://my-endpoint.inference.ai.azure.com', + credential=AzureKeyCredential('my-api-key'), +) + +logfire.configure() +logfire.instrument_azure_ai_inference(client) + +response = client.embed( + model='text-embedding-ada-002', + input=['Hello world'], +) +print(len(response.data[0].embedding)) +``` + +## Async Support + +Async clients from `azure.ai.inference.aio` are fully supported: + +```python skip-run="true" skip-reason="external-connection" +from azure.ai.inference.aio import ChatCompletionsClient +from azure.core.credentials import AzureKeyCredential + +import logfire + +client = ChatCompletionsClient( + endpoint='https://my-endpoint.inference.ai.azure.com', + credential=AzureKeyCredential('my-api-key'), +) + +logfire.configure() +logfire.instrument_azure_ai_inference(client) +``` + +## Global Instrumentation + +If no client is passed, all `ChatCompletionsClient` and `EmbeddingsClient` classes (both sync and async) are instrumented: + +```python skip-run="true" skip-reason="external-connection" +import logfire + +logfire.configure() +logfire.instrument_azure_ai_inference() +``` diff --git a/logfire-api/logfire_api/__init__.py b/logfire-api/logfire_api/__init__.py index ab431b5ea..d83abca09 100644 --- a/logfire-api/logfire_api/__init__.py +++ b/logfire-api/logfire_api/__init__.py @@ -187,6 +187,9 @@ def instrument_print(self, *args, **kwargs) -> ContextManager[None]: def instrument_openai_agents(self, *args, **kwargs) -> None: ... + def instrument_azure_ai_inference(self, *args, **kwargs) -> ContextManager[None]: + return nullcontext() + def instrument_google_genai(self, *args, **kwargs) -> None: ... def instrument_litellm(self, *args, **kwargs) -> None: ... @@ -230,6 +233,7 @@ def shutdown(self, *args, **kwargs) -> None: ... 
instrument_openai = DEFAULT_LOGFIRE_INSTANCE.instrument_openai instrument_openai_agents = DEFAULT_LOGFIRE_INSTANCE.instrument_openai_agents instrument_anthropic = DEFAULT_LOGFIRE_INSTANCE.instrument_anthropic + instrument_azure_ai_inference = DEFAULT_LOGFIRE_INSTANCE.instrument_azure_ai_inference instrument_google_genai = DEFAULT_LOGFIRE_INSTANCE.instrument_google_genai instrument_litellm = DEFAULT_LOGFIRE_INSTANCE.instrument_litellm instrument_dspy = DEFAULT_LOGFIRE_INSTANCE.instrument_dspy diff --git a/logfire/__init__.py b/logfire/__init__.py index 8c0d01e02..2170e96f7 100644 --- a/logfire/__init__.py +++ b/logfire/__init__.py @@ -45,6 +45,7 @@ instrument_openai = DEFAULT_LOGFIRE_INSTANCE.instrument_openai instrument_openai_agents = DEFAULT_LOGFIRE_INSTANCE.instrument_openai_agents instrument_anthropic = DEFAULT_LOGFIRE_INSTANCE.instrument_anthropic +instrument_azure_ai_inference = DEFAULT_LOGFIRE_INSTANCE.instrument_azure_ai_inference instrument_google_genai = DEFAULT_LOGFIRE_INSTANCE.instrument_google_genai instrument_litellm = DEFAULT_LOGFIRE_INSTANCE.instrument_litellm instrument_dspy = DEFAULT_LOGFIRE_INSTANCE.instrument_dspy @@ -152,6 +153,7 @@ def loguru_handler() -> Any: 'instrument_openai', 'instrument_openai_agents', 'instrument_anthropic', + 'instrument_azure_ai_inference', 'instrument_google_genai', 'instrument_litellm', 'instrument_dspy', diff --git a/logfire/_internal/integrations/llm_providers/azure_ai_inference.py b/logfire/_internal/integrations/llm_providers/azure_ai_inference.py new file mode 100644 index 000000000..702da9a62 --- /dev/null +++ b/logfire/_internal/integrations/llm_providers/azure_ai_inference.py @@ -0,0 +1,734 @@ +# pyright: reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false +from __future__ import annotations + +import json +from collections.abc import AsyncIterator, Iterator +from contextlib import AbstractContextManager, ExitStack, contextmanager, nullcontext +from typing import TYPE_CHECKING, Any, cast + +from opentelemetry.trace import SpanKind + +from logfire import attach_context, get_context + +from ...constants import ONE_SECOND_IN_NANOSECONDS +from ...utils import handle_internal_errors, is_instrumentation_suppressed, log_internal_error, suppress_instrumentation +from .semconv import ( + INPUT_MESSAGES, + INPUT_TOKENS, + OPERATION_NAME, + OUTPUT_MESSAGES, + OUTPUT_TOKENS, + PROVIDER_NAME, + REQUEST_FREQUENCY_PENALTY, + REQUEST_MAX_TOKENS, + REQUEST_MODEL, + REQUEST_PRESENCE_PENALTY, + REQUEST_SEED, + REQUEST_STOP_SEQUENCES, + REQUEST_TEMPERATURE, + REQUEST_TOP_P, + RESPONSE_FINISH_REASONS, + RESPONSE_ID, + RESPONSE_MODEL, + SYSTEM_INSTRUCTIONS, + TOOL_DEFINITIONS, + BlobPart, + ChatMessage, + InputMessages, + MessagePart, + OutputMessage, + OutputMessages, + Role, + SemconvVersion, + SystemInstructions, + TextPart, + ToolCallPart, + ToolCallResponsePart, + UriPart, +) + +if TYPE_CHECKING: + from ...main import Logfire, LogfireSpan + +__all__ = ('instrument_azure_ai_inference',) + +AZURE_PROVIDER = 'azure.ai.inference' + + +# --- Main instrumentation entry point --- + + +def instrument_azure_ai_inference( + logfire_instance: Logfire, + client: Any, + suppress_other_instrumentation: bool, + versions: frozenset[SemconvVersion], +) -> AbstractContextManager[None]: + """Instrument Azure AI Inference clients.""" + if isinstance(client, (tuple, list)): + context_managers = [ + instrument_azure_ai_inference(logfire_instance, c, suppress_other_instrumentation, versions) for c in client + ] + + 
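        # Each recursive call above has already instrumented its client; the context
        # managers it returns only revert that instrumentation. An ExitStack bundles
        # them so that exiting the single combined context manager below restores
        # every client at once.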
@contextmanager + def uninstrument_all() -> Iterator[None]: + with ExitStack() as stack: + for cm in context_managers: + stack.enter_context(cm) + yield + + return uninstrument_all() + + if getattr(client, '_is_instrumented_by_logfire', False): + return nullcontext() + + client_cls = client if isinstance(client, type) else type(client) + is_async = _is_async_client(client_cls) + client_type = _get_client_type(client_cls) + + if client_type is None: # pragma: no cover + return nullcontext() + + logfire_llm = logfire_instance.with_settings(custom_scope_suffix='azure_ai_inference', tags=['LLM']) + client._is_instrumented_by_logfire = True + + if client_type == 'chat': + method_name = 'complete' + original = client.complete + client._original_logfire_method = original + client.complete = _make_instrumented_complete( + original, logfire_llm, suppress_other_instrumentation, versions, is_async + ) + else: + method_name = 'embed' + original = client.embed + client._original_logfire_method = original + client.embed = _make_instrumented_embed( + original, logfire_llm, suppress_other_instrumentation, versions, is_async + ) + + @contextmanager + def uninstrument() -> Iterator[None]: + try: + yield + finally: + setattr(client, method_name, client._original_logfire_method) + del client._original_logfire_method + client._is_instrumented_by_logfire = False + + return uninstrument() + + +# --- Client type detection --- + + +def _is_async_client(client_cls: type[Any]) -> bool: + return 'aio' in client_cls.__module__ + + +def _get_client_type(client_cls: type[Any]) -> str | None: + name = client_cls.__name__ + if 'ChatCompletions' in name: + return 'chat' + if 'Embeddings' in name: + return 'embeddings' + return None # pragma: no cover + + +# --- Instrumented method factories --- + + +def _make_instrumented_complete( + original: Any, + logfire_llm: Logfire, + suppress: bool, + versions: frozenset[SemconvVersion], + is_async: bool, +) -> Any: + if is_async: + + async def instrumented_complete(*args: Any, **kwargs: Any) -> Any: + if is_instrumentation_suppressed(): + return await original(*args, **kwargs) + try: + span_data = _build_chat_span_data(args, kwargs, versions) + except Exception: + log_internal_error() + return await original(*args, **kwargs) + + is_streaming = kwargs.get('stream', False) + original_context = get_context() + + with logfire_llm.span( + 'Chat completion with {request_data[model]!r}', + _span_kind=SpanKind.CLIENT, + **span_data, + ) as span: + if suppress: + with suppress_instrumentation(): + response = await original(*args, **kwargs) + else: + response = await original(*args, **kwargs) + + if is_streaming: + return _AsyncStreamWrapper(response, logfire_llm, span_data, versions, original_context) + _on_chat_response(response, span, versions) + return response + + return instrumented_complete + else: + + def instrumented_complete_sync(*args: Any, **kwargs: Any) -> Any: + if is_instrumentation_suppressed(): + return original(*args, **kwargs) + try: + span_data = _build_chat_span_data(args, kwargs, versions) + except Exception: + log_internal_error() + return original(*args, **kwargs) + + is_streaming = kwargs.get('stream', False) + original_context = get_context() + + with logfire_llm.span( + 'Chat completion with {request_data[model]!r}', + _span_kind=SpanKind.CLIENT, + **span_data, + ) as span: + if suppress: + with suppress_instrumentation(): + response = original(*args, **kwargs) + else: + response = original(*args, **kwargs) + + if is_streaming: + return _SyncStreamWrapper(response, 
logfire_llm, span_data, versions, original_context) + _on_chat_response(response, span, versions) + return response + + return instrumented_complete_sync + + +def _make_instrumented_embed( + original: Any, + logfire_llm: Logfire, + suppress: bool, + versions: frozenset[SemconvVersion], + is_async: bool, +) -> Any: + if is_async: + + async def instrumented_embed(*args: Any, **kwargs: Any) -> Any: + if is_instrumentation_suppressed(): + return await original(*args, **kwargs) + try: + span_data = _build_embed_span_data(args, kwargs, versions) + except Exception: + log_internal_error() + return await original(*args, **kwargs) + + with logfire_llm.span( + 'Embeddings with {request_data[model]!r}', + _span_kind=SpanKind.CLIENT, + **span_data, + ) as span: + if suppress: + with suppress_instrumentation(): + response = await original(*args, **kwargs) + else: + response = await original(*args, **kwargs) + _on_embed_response(response, span, versions) + return response + + return instrumented_embed + else: + + def instrumented_embed_sync(*args: Any, **kwargs: Any) -> Any: + if is_instrumentation_suppressed(): + return original(*args, **kwargs) + try: + span_data = _build_embed_span_data(args, kwargs, versions) + except Exception: + log_internal_error() + return original(*args, **kwargs) + + with logfire_llm.span( + 'Embeddings with {request_data[model]!r}', + _span_kind=SpanKind.CLIENT, + **span_data, + ) as span: + if suppress: + with suppress_instrumentation(): + response = original(*args, **kwargs) + else: + response = original(*args, **kwargs) + _on_embed_response(response, span, versions) + return response + + return instrumented_embed_sync + + +# --- Span data builders --- + + +def _build_chat_span_data( + args: tuple[Any, ...], + kwargs: dict[str, Any], + versions: frozenset[SemconvVersion], +) -> dict[str, Any]: + params = _extract_params(args, kwargs) + messages = params.get('messages', []) + model = params.get('model') + + request_data: dict[str, Any] = {'model': model} + if 1 in versions: + if messages: + request_data['messages'] = [_msg_to_dict(m) for m in messages] + for key in ('temperature', 'max_tokens', 'top_p', 'frequency_penalty', 'presence_penalty', 'seed', 'stop'): + val = params.get(key) + if val is not None: + request_data[key] = val + if (tools := params.get('tools')) is not None: + request_data['tools'] = [t if isinstance(t, dict) else t.as_dict() for t in tools] + + span_data: dict[str, Any] = { + 'request_data': request_data, + PROVIDER_NAME: AZURE_PROVIDER, + OPERATION_NAME: 'chat', + } + if model: + span_data[REQUEST_MODEL] = model + + _extract_request_parameters(params, span_data) + + if 'latest' in versions and messages: + input_messages, system_instructions = convert_messages_to_semconv(messages) + span_data[INPUT_MESSAGES] = input_messages + if system_instructions: + span_data[SYSTEM_INSTRUCTIONS] = system_instructions + + return span_data + + +def _build_embed_span_data( + args: tuple[Any, ...], + kwargs: dict[str, Any], + versions: frozenset[SemconvVersion], +) -> dict[str, Any]: + params = _extract_params(args, kwargs) + model = params.get('model') + + request_data: dict[str, Any] = {'model': model} + if 1 in versions: + input_val = params.get('input') + if input_val is not None: + request_data['input'] = input_val + + span_data: dict[str, Any] = { + 'request_data': request_data, + PROVIDER_NAME: AZURE_PROVIDER, + OPERATION_NAME: 'embeddings', + } + if model: + span_data[REQUEST_MODEL] = model + + return span_data + + +def _extract_params(args: tuple[Any, ...], 
kwargs: dict[str, Any]) -> dict[str, Any]: + """Extract parameters from method call, handling both body and keyword styles.""" + if 'body' in kwargs and isinstance(kwargs['body'], dict): + return kwargs['body'] + for arg in args: + if isinstance(arg, dict) and ('messages' in arg or 'input' in arg): + return arg + return kwargs + + +def _extract_request_parameters(params: dict[str, Any], span_data: dict[str, Any]) -> None: + if (max_tokens := params.get('max_tokens')) is not None: + span_data[REQUEST_MAX_TOKENS] = max_tokens + if (temperature := params.get('temperature')) is not None: + span_data[REQUEST_TEMPERATURE] = temperature + if (top_p := params.get('top_p')) is not None: + span_data[REQUEST_TOP_P] = top_p + if (frequency_penalty := params.get('frequency_penalty')) is not None: + span_data[REQUEST_FREQUENCY_PENALTY] = frequency_penalty + if (presence_penalty := params.get('presence_penalty')) is not None: + span_data[REQUEST_PRESENCE_PENALTY] = presence_penalty + if (seed := params.get('seed')) is not None: + span_data[REQUEST_SEED] = seed + if (stop := params.get('stop')) is not None: + span_data[REQUEST_STOP_SEQUENCES] = json.dumps(stop) + if (tools := params.get('tools')) is not None: + span_data[TOOL_DEFINITIONS] = json.dumps([t if isinstance(t, dict) else t.as_dict() for t in tools]) + + +# --- Response processors --- + + +@handle_internal_errors +def _on_chat_response(response: Any, span: LogfireSpan, versions: frozenset[SemconvVersion]) -> None: + choices = getattr(response, 'choices', []) + usage = getattr(response, 'usage', None) + + if 1 in versions: + response_data: dict[str, Any] = {} + if choices: + message = getattr(choices[0], 'message', None) + if message: + msg_data: dict[str, Any] = {'role': getattr(message, 'role', 'assistant')} + content = getattr(message, 'content', None) + if content: + msg_data['content'] = content + tool_calls = getattr(message, 'tool_calls', None) + if tool_calls: + msg_data['tool_calls'] = [ + { + 'id': getattr(tc, 'id', ''), + 'function': { + 'name': getattr(getattr(tc, 'function', None), 'name', ''), + 'arguments': getattr(getattr(tc, 'function', None), 'arguments', ''), + }, + } + for tc in tool_calls + ] + response_data['message'] = msg_data + if usage: + response_data['usage'] = { + 'prompt_tokens': getattr(usage, 'prompt_tokens', 0), + 'completion_tokens': getattr(usage, 'completion_tokens', 0), + 'total_tokens': getattr(usage, 'total_tokens', 0), + } + span.set_attribute('response_data', response_data) + + if 'latest' in versions: + output_messages = convert_response_to_semconv(response) + if output_messages: + span.set_attribute(OUTPUT_MESSAGES, output_messages) + + model = getattr(response, 'model', None) + if model: + span.set_attribute(RESPONSE_MODEL, model) + + response_id = getattr(response, 'id', None) + if response_id: + span.set_attribute(RESPONSE_ID, response_id) + + if usage: + prompt_tokens = getattr(usage, 'prompt_tokens', None) + if prompt_tokens is not None: + span.set_attribute(INPUT_TOKENS, prompt_tokens) + completion_tokens = getattr(usage, 'completion_tokens', None) + if completion_tokens is not None: + span.set_attribute(OUTPUT_TOKENS, completion_tokens) + + finish_reasons = [str(c.finish_reason) for c in choices if getattr(c, 'finish_reason', None)] + if finish_reasons: + span.set_attribute(RESPONSE_FINISH_REASONS, finish_reasons) + + +@handle_internal_errors +def _on_embed_response(response: Any, span: LogfireSpan, versions: frozenset[SemconvVersion]) -> None: + usage = getattr(response, 'usage', None) + + if 1 
in versions: + response_data: dict[str, Any] = {} + if usage: + response_data['usage'] = { + 'prompt_tokens': getattr(usage, 'prompt_tokens', 0), + 'total_tokens': getattr(usage, 'total_tokens', 0), + } + data = getattr(response, 'data', None) + if data: + response_data['data_count'] = len(data) + span.set_attribute('response_data', response_data) + + model = getattr(response, 'model', None) + if model: + span.set_attribute(RESPONSE_MODEL, model) + + response_id = getattr(response, 'id', None) + if response_id: + span.set_attribute(RESPONSE_ID, response_id) + + if usage: + prompt_tokens = getattr(usage, 'prompt_tokens', None) + if prompt_tokens is not None: + span.set_attribute(INPUT_TOKENS, prompt_tokens) + + +# --- Message conversion --- + + +def _msg_to_dict(msg: Any) -> dict[str, Any]: + """Convert an Azure message object or dict to a plain dict.""" + if isinstance(msg, dict): + return msg + if hasattr(msg, 'as_dict'): + return msg.as_dict() + return {} # pragma: no cover + + +def convert_messages_to_semconv(messages: list[Any]) -> tuple[InputMessages, SystemInstructions]: + """Convert Azure AI Inference messages to OTel GenAI semconv format.""" + input_messages: InputMessages = [] + system_instructions: SystemInstructions = [] + + for msg in messages: + msg_dict = _msg_to_dict(msg) + role: str = msg_dict.get('role', 'user') + content = msg_dict.get('content') + + if role in ('system', 'developer'): + if isinstance(content, str): + system_instructions.append(TextPart(type='text', content=content)) + continue + + if role == 'tool': + tool_call_id = msg_dict.get('tool_call_id', '') + input_messages.append( + ChatMessage( + role='tool', + parts=[ + ToolCallResponsePart( + type='tool_call_response', + id=tool_call_id, + response=content if isinstance(content, str) else str(content) if content else '', + ) + ], + ) + ) + continue + + parts: list[MessagePart] = [] + if isinstance(content, str) and content: + parts.append(TextPart(type='text', content=content)) + elif isinstance(content, list): + for item in content: + parts.append(_convert_content_item(item)) + + tool_calls = msg_dict.get('tool_calls') + if tool_calls: + for tc in tool_calls: + tc_dict = tc if isinstance(tc, dict) else (tc.as_dict() if hasattr(tc, 'as_dict') else {}) + func = tc_dict.get('function', {}) + parts.append( + ToolCallPart( + type='tool_call', + id=tc_dict.get('id', ''), + name=func.get('name', ''), + arguments=func.get('arguments'), + ) + ) + + chat_role: Role = cast('Role', role if role in ('user', 'assistant') else 'user') + input_messages.append(ChatMessage(role=chat_role, parts=parts)) + + return input_messages, system_instructions + + +def _convert_content_item(item: Any) -> MessagePart: + """Convert a content item (text, image, audio) to semconv format.""" + if isinstance(item, str): + return TextPart(type='text', content=item) + + item_dict = item if isinstance(item, dict) else (item.as_dict() if hasattr(item, 'as_dict') else {}) + item_type = item_dict.get('type', 'text') + + if item_type == 'text': + return TextPart(type='text', content=item_dict.get('text', '')) + elif item_type == 'image_url': + image_url = item_dict.get('image_url', {}) + return UriPart(type='uri', uri=image_url.get('url', ''), modality='image') + elif item_type == 'input_audio': + audio = item_dict.get('input_audio', {}) + return BlobPart( + type='blob', + content=audio.get('data', ''), + media_type=f'audio/{audio.get("format", "wav")}', + modality='audio', + ) + else: # pragma: no cover + return cast('MessagePart', item_dict) + + 
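# Illustrative sketch in comments only (not executed), using made-up values: a
# hypothetical multimodal content list such as
#     [{'type': 'text', 'text': 'Describe this image'},
#      {'type': 'image_url', 'image_url': {'url': 'https://example.com/cat.png'}}]
# is mapped item by item through `_convert_content_item` to
#     [TextPart(type='text', content='Describe this image'),
#      UriPart(type='uri', uri='https://example.com/cat.png', modality='image')]
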
+def convert_response_to_semconv(response: Any) -> OutputMessages: + """Convert a ChatCompletions response to OTel GenAI semconv format.""" + output_messages: OutputMessages = [] + + for choice in getattr(response, 'choices', []): + message = getattr(choice, 'message', None) + if not message: + continue + + parts: list[MessagePart] = [] + content = getattr(message, 'content', None) + if content: + parts.append(TextPart(type='text', content=content)) + + tool_calls = getattr(message, 'tool_calls', None) + if tool_calls: + for tc in tool_calls: + func = getattr(tc, 'function', None) + if func: + parts.append( + ToolCallPart( + type='tool_call', + id=getattr(tc, 'id', ''), + name=getattr(func, 'name', ''), + arguments=getattr(func, 'arguments', None), + ) + ) + + output_msg: OutputMessage = { + 'role': cast('Role', getattr(message, 'role', 'assistant')), + 'parts': parts, + } + finish_reason = getattr(choice, 'finish_reason', None) + if finish_reason: + output_msg['finish_reason'] = str(finish_reason) + output_messages.append(output_msg) + + return output_messages + + +# --- Streaming wrappers --- + + +class _SyncStreamWrapper: + """Wraps a sync streaming response to record chunks and emit a streaming info span.""" + + def __init__( + self, + wrapped: Any, + logfire_llm: Logfire, + span_data: dict[str, Any], + versions: frozenset[SemconvVersion], + original_context: Any, + ) -> None: + self._wrapped = wrapped + self._logfire_llm = logfire_llm + self._span_data = span_data + self._versions = versions + self._original_context = original_context + self._chunks: list[str] = [] + + def __enter__(self) -> _SyncStreamWrapper: + if hasattr(self._wrapped, '__enter__'): + self._wrapped.__enter__() + return self + + def __exit__(self, *args: Any) -> None: + if hasattr(self._wrapped, '__exit__'): + self._wrapped.__exit__(*args) + + def __iter__(self) -> Iterator[Any]: + timer = self._logfire_llm._config.advanced.ns_timestamp_generator # type: ignore + start = timer() + try: + for chunk in self._wrapped: + self._record_chunk(chunk) + yield chunk + finally: + duration = (timer() - start) / ONE_SECOND_IN_NANOSECONDS + with attach_context(self._original_context): + self._logfire_llm.info( + 'streaming response from {request_data[model]!r} took {duration:.2f}s', + duration=duration, + **self._get_stream_attributes(), + ) + + def _record_chunk(self, chunk: Any) -> None: + for choice in getattr(chunk, 'choices', []): + delta = getattr(choice, 'delta', None) + if delta: + content = getattr(delta, 'content', None) + if content: + self._chunks.append(content) + + def _get_stream_attributes(self) -> dict[str, Any]: + result = dict(**self._span_data) + combined = ''.join(self._chunks) + if 1 in self._versions: + result['response_data'] = { + 'combined_chunk_content': combined, + 'chunk_count': len(self._chunks), + } + if 'latest' in self._versions and self._chunks: + result[OUTPUT_MESSAGES] = [ + OutputMessage( + role='assistant', + parts=[TextPart(type='text', content=combined)], + ) + ] + return result + + +class _AsyncStreamWrapper: + """Wraps an async streaming response to record chunks and emit a streaming info span.""" + + def __init__( + self, + wrapped: Any, + logfire_llm: Logfire, + span_data: dict[str, Any], + versions: frozenset[SemconvVersion], + original_context: Any, + ) -> None: + self._wrapped = wrapped + self._logfire_llm = logfire_llm + self._span_data = span_data + self._versions = versions + self._original_context = original_context + self._chunks: list[str] = [] + + async def __aenter__(self) -> 
_AsyncStreamWrapper: + if hasattr(self._wrapped, '__aenter__'): + await self._wrapped.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> None: + if hasattr(self._wrapped, '__aexit__'): + await self._wrapped.__aexit__(*args) + + async def __aiter__(self) -> AsyncIterator[Any]: + timer = self._logfire_llm._config.advanced.ns_timestamp_generator # type: ignore + start = timer() + try: + async for chunk in self._wrapped: + self._record_chunk(chunk) + yield chunk + finally: + duration = (timer() - start) / ONE_SECOND_IN_NANOSECONDS + with attach_context(self._original_context): + self._logfire_llm.info( + 'streaming response from {request_data[model]!r} took {duration:.2f}s', + duration=duration, + **self._get_stream_attributes(), + ) + + def _record_chunk(self, chunk: Any) -> None: + for choice in getattr(chunk, 'choices', []): + delta = getattr(choice, 'delta', None) + if delta: + content = getattr(delta, 'content', None) + if content: + self._chunks.append(content) + + def _get_stream_attributes(self) -> dict[str, Any]: + result = dict(**self._span_data) + combined = ''.join(self._chunks) + if 1 in self._versions: + result['response_data'] = { + 'combined_chunk_content': combined, + 'chunk_count': len(self._chunks), + } + if 'latest' in self._versions and self._chunks: + result[OUTPUT_MESSAGES] = [ + OutputMessage( + role='assistant', + parts=[TextPart(type='text', content=combined)], + ) + ] + return result diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index 472910163..80399c33f 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -1389,6 +1389,105 @@ def instrument_anthropic( is_async_client, ) + def instrument_azure_ai_inference( + self, + azure_ai_inference_client: Any = None, + *, + suppress_other_instrumentation: bool = True, + version: SemconvVersion | Sequence[SemconvVersion] = 1, + ) -> AbstractContextManager[None]: + """Instrument an Azure AI Inference client so that spans are automatically created for each request. + + Supports both the sync and async clients from the + [`azure-ai-inference`](https://pypi.org/project/azure-ai-inference/) package: + + - [`ChatCompletionsClient.complete`](https://learn.microsoft.com/python/api/azure-ai-inference/azure.ai.inference.chatcompletionsclient) - with and without `stream=True` + - [`EmbeddingsClient.embed`](https://learn.microsoft.com/python/api/azure-ai-inference/azure.ai.inference.embeddingsclient) + + Example usage: + + ```python skip-run="true" skip-reason="external-connection" + from azure.ai.inference import ChatCompletionsClient + from azure.core.credentials import AzureKeyCredential + + import logfire + + client = ChatCompletionsClient( + endpoint='https://my-endpoint.inference.ai.azure.com', + credential=AzureKeyCredential('my-api-key'), + ) + + logfire.configure() + logfire.instrument_azure_ai_inference(client) + + response = client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'What is four plus five?'}], + ) + print(response.choices[0].message.content) + ``` + + Args: + azure_ai_inference_client: The Azure AI Inference client or class to instrument: + + - `None` (the default) to instrument all Azure AI Inference client classes. + - A `ChatCompletionsClient` or `EmbeddingsClient` class or instance (sync or async). + + suppress_other_instrumentation: If True, suppress any other OTEL instrumentation that may be otherwise + enabled. 
In reality, this means the Azure Core tracing instrumentation, which could otherwise be + called since the Azure SDK uses its own pipeline to make HTTP requests. + + version: The version(s) of the span attribute format to use: + + - `1` (the default): Uses `request_data` and `response_data` attributes. + - `'latest'`: Uses OpenTelemetry Gen AI semantic convention attributes + (`gen_ai.input.messages`, `gen_ai.output.messages`, etc.) and omits the full + `response_data` attribute. A minimal `request_data` (e.g. `{"model": ...}`) is + still recorded for message template compatibility. This format may change between + releases. + - `[1, 'latest']`: Emits both the full legacy attributes and the semantic convention + attributes simultaneously, useful for migration and testing. + + Returns: + A context manager that will revert the instrumentation when exited. + Use of this context manager is optional. + """ + try: + from azure.ai.inference import ChatCompletionsClient, EmbeddingsClient + except ImportError: + raise RuntimeError( + 'The `logfire.instrument_azure_ai_inference()` method ' + 'requires the `azure-ai-inference` package.\n' + 'You can install this with:\n' + " pip install 'logfire[azure-ai-inference]'" + ) + + from .integrations.llm_providers.azure_ai_inference import instrument_azure_ai_inference + from .integrations.llm_providers.semconv import normalize_versions + + normalized_versions = normalize_versions(version) + self._warn_if_not_initialized_for_instrumentation() + + if azure_ai_inference_client is None: + clients_to_instrument: list[Any] = [ChatCompletionsClient, EmbeddingsClient] + try: + from azure.ai.inference.aio import ( + ChatCompletionsClient as AsyncChatCompletionsClient, + EmbeddingsClient as AsyncEmbeddingsClient, + ) + + clients_to_instrument.extend([AsyncChatCompletionsClient, AsyncEmbeddingsClient]) + except ImportError: # pragma: no cover + pass + azure_ai_inference_client = clients_to_instrument + + return instrument_azure_ai_inference( + self, + azure_ai_inference_client, + suppress_other_instrumentation, + normalized_versions, + ) + def instrument_google_genai(self, **kwargs: Any): """Instrument the [Google Gen AI SDK (`google-genai`)](https://googleapis.github.io/python-genai/). 
diff --git a/mkdocs.yml b/mkdocs.yml index 67b619d78..f79d650c9 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -114,6 +114,7 @@ nav: - OpenAI: integrations/llms/openai.md - Google Gen AI: integrations/llms/google-genai.md - Anthropic: integrations/llms/anthropic.md + - Azure AI Inference: integrations/llms/azure-ai-inference.md - LangChain: integrations/llms/langchain.md - LiteLLM: integrations/llms/litellm.md - DSPy: integrations/llms/dspy.md diff --git a/pyproject.toml b/pyproject.toml index e4e1a35cb..333ab4d9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ requests = ["opentelemetry-instrumentation-requests >= 0.42b0"] mysql = ["opentelemetry-instrumentation-mysql >= 0.42b0"] sqlite3 = ["opentelemetry-instrumentation-sqlite3 >= 0.42b0"] aws-lambda = ["opentelemetry-instrumentation-aws-lambda >= 0.42b0"] +azure-ai-inference = ["azure-ai-inference >= 1.0.0b1"] google-genai = ["opentelemetry-instrumentation-google-genai >= 0.4b0"] litellm = ["openinference-instrumentation-litellm >= 0"] dspy = ["openinference-instrumentation-dspy >= 0"] @@ -159,6 +160,7 @@ dev = [ "cryptography >= 44.0.0", "cloudpickle>=3.0.0", "anthropic>=0.27.0", + "azure-ai-inference>=1.0.0b1", "sqlmodel>=0.0.15", "mypy>=1.10.0", "celery>=5.4.0", diff --git a/tests/otel_integrations/test_azure_ai_inference.py b/tests/otel_integrations/test_azure_ai_inference.py new file mode 100644 index 000000000..a19091c6d --- /dev/null +++ b/tests/otel_integrations/test_azure_ai_inference.py @@ -0,0 +1,566 @@ +# pyright: reportCallIssue=false, reportArgumentType=false +from __future__ import annotations as _annotations + +from datetime import datetime +from typing import Any + +import pytest +from azure.ai.inference.models import ( + ChatChoice, + ChatCompletions, + ChatResponseMessage, + CompletionsUsage, + EmbeddingItem, + EmbeddingsResult, + EmbeddingsUsage, + StreamingChatChoiceUpdate, + StreamingChatCompletionsUpdate, + StreamingChatResponseMessageUpdate, +) +from inline_snapshot import snapshot + +import logfire +from logfire.testing import TestExporter + + +def _make_chat_response( + content: str = 'Nine', + finish_reason: str = 'stop', + tool_calls: list[Any] | None = None, +) -> ChatCompletions: + message_kwargs: dict[str, Any] = {'role': 'assistant', 'content': content} + if tool_calls is not None: + message_kwargs['tool_calls'] = tool_calls + return ChatCompletions( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[ + ChatChoice( + index=0, + finish_reason=finish_reason, + message=ChatResponseMessage(**message_kwargs), + ) + ], + usage=CompletionsUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + + +def _make_tool_response() -> ChatCompletions: + return _make_chat_response( + content='', + finish_reason='tool_calls', + tool_calls=[ + { + 'id': 'call_1', + 'type': 'function', + 'function': {'name': 'get_weather', 'arguments': '{"city": "London"}'}, + } + ], + ) + + +def _make_streaming_chunks() -> list[StreamingChatCompletionsUpdate]: + return [ + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[ + StreamingChatChoiceUpdate( + index=0, delta=StreamingChatResponseMessageUpdate(role='assistant', content='') + ) + ], + ), + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[ + StreamingChatChoiceUpdate(index=0, delta=StreamingChatResponseMessageUpdate(content='The answer')) + ], + ), + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + 
created=datetime(2024, 1, 1), + choices=[ + StreamingChatChoiceUpdate(index=0, delta=StreamingChatResponseMessageUpdate(content=' is secret')) + ], + ), + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[ + StreamingChatChoiceUpdate( + index=0, finish_reason='stop', delta=StreamingChatResponseMessageUpdate(content='') + ) + ], + ), + ] + + +def _make_embed_response() -> EmbeddingsResult: + return EmbeddingsResult( + id='test-id', + model='text-embedding-ada-002', + data=[EmbeddingItem(embedding=[0.1, 0.2, 0.3], index=0)], + usage=EmbeddingsUsage(prompt_tokens=5, total_tokens=5), + ) + + +class MockChatCompletionsClient: + """Mock ChatCompletionsClient that returns preconfigured responses.""" + + __module__ = 'azure.ai.inference' + + def __init__(self, response: Any = None, stream_chunks: list[Any] | None = None) -> None: + self._response = response or _make_chat_response() + self._stream_chunks = stream_chunks + + def complete(self, **kwargs: Any) -> Any: + if kwargs.get('stream'): + return iter(self._stream_chunks or _make_streaming_chunks()) + return self._response + + +class MockAsyncChatCompletionsClient: + """Mock async ChatCompletionsClient.""" + + __module__ = 'azure.ai.inference.aio' + + def __init__(self, response: Any = None, stream_chunks: list[Any] | None = None) -> None: + self._response = response or _make_chat_response() + self._stream_chunks = stream_chunks + + async def complete(self, **kwargs: Any) -> Any: + if kwargs.get('stream'): + return _async_iter(self._stream_chunks or _make_streaming_chunks()) + return self._response + + +class MockEmbeddingsClient: + """Mock EmbeddingsClient.""" + + __module__ = 'azure.ai.inference' + + def __init__(self, response: Any = None) -> None: + self._response = response or _make_embed_response() + + def embed(self, **kwargs: Any) -> Any: + return self._response + + +class MockAsyncEmbeddingsClient: + """Mock async EmbeddingsClient.""" + + __module__ = 'azure.ai.inference.aio' + + def __init__(self, response: Any = None) -> None: + self._response = response or _make_embed_response() + + async def embed(self, **kwargs: Any) -> Any: + return self._response + + +async def _async_iter(items: list[Any]) -> Any: + for item in items: + yield item + + +def test_sync_chat(exporter: TestExporter) -> None: + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + response = client.complete( + model='gpt-4', + messages=[ + {'role': 'system', 'content': 'You are helpful.'}, + {'role': 'user', 'content': 'What is four plus five?'}, + ], + temperature=0.5, + ) + assert response.choices[0].message.content == 'Nine' + assert exporter.exported_spans_as_dict(parse_json_attributes=True) == snapshot( + [ + { + 'name': 'Chat completion with {request_data[model]!r}', + 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'parent': None, + 'start_time': 1000000000, + 'end_time': 2000000000, + 'attributes': { + 'code.filepath': 'test_azure_ai_inference.py', + 'code.function': 'test_sync_chat', + 'code.lineno': 123, + 'request_data': { + 'model': 'gpt-4', + 'messages': [ + {'role': 'system', 'content': 'You are helpful.'}, + {'role': 'user', 'content': 'What is four plus five?'}, + ], + 'temperature': 0.5, + }, + 'gen_ai.provider.name': 'azure.ai.inference', + 'gen_ai.operation.name': 'chat', + 'gen_ai.request.model': 'gpt-4', + 'gen_ai.request.temperature': 0.5, + 'gen_ai.input.messages': [ + {'role': 'user', 'parts': [{'type': 
'text', 'content': 'What is four plus five?'}]} + ], + 'gen_ai.system_instructions': [{'type': 'text', 'content': 'You are helpful.'}], + 'logfire.msg_template': 'Chat completion with {request_data[model]!r}', + 'logfire.msg': "Chat completion with 'gpt-4'", + 'logfire.tags': ('LLM',), + 'logfire.span_type': 'span', + 'response_data': { + 'message': {'role': 'assistant', 'content': 'Nine'}, + 'usage': {'prompt_tokens': 10, 'completion_tokens': 5, 'total_tokens': 15}, + }, + 'gen_ai.output.messages': [ + { + 'role': 'assistant', + 'parts': [{'type': 'text', 'content': 'Nine'}], + 'finish_reason': 'CompletionsFinishReason.STOPPED', + } + ], + 'gen_ai.response.model': 'gpt-4', + 'gen_ai.response.id': 'test-id', + 'gen_ai.usage.input_tokens': 10, + 'gen_ai.usage.output_tokens': 5, + 'gen_ai.response.finish_reasons': ['CompletionsFinishReason.STOPPED'], + 'logfire.json_schema': { + 'type': 'object', + 'properties': { + 'request_data': {'type': 'object'}, + 'gen_ai.provider.name': {}, + 'gen_ai.operation.name': {}, + 'gen_ai.request.model': {}, + 'gen_ai.request.temperature': {}, + 'gen_ai.input.messages': {'type': 'array'}, + 'gen_ai.system_instructions': {'type': 'array'}, + 'response_data': { + 'type': 'object', + 'properties': { + 'message': { + 'type': 'object', + 'properties': { + 'role': { + 'type': 'string', + 'title': 'ChatRole', + 'x-python-datatype': 'Enum', + 'enum': ['system', 'user', 'assistant', 'tool', 'developer'], + } + }, + } + }, + }, + 'gen_ai.output.messages': { + 'type': 'array', + 'items': { + 'type': 'object', + 'properties': { + 'role': { + 'type': 'string', + 'title': 'ChatRole', + 'x-python-datatype': 'Enum', + 'enum': ['system', 'user', 'assistant', 'tool', 'developer'], + } + }, + }, + }, + 'gen_ai.response.model': {}, + 'gen_ai.response.id': {}, + 'gen_ai.usage.input_tokens': {}, + 'gen_ai.usage.output_tokens': {}, + 'gen_ai.response.finish_reasons': {'type': 'array'}, + }, + }, + }, + } + ] + ) + + +def test_sync_chat_streaming(exporter: TestExporter) -> None: + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + response = client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Tell me a secret'}], + stream=True, + ) + chunks = list(response) + assert len(chunks) == 4 + assert exporter.exported_spans_as_dict(parse_json_attributes=True) == snapshot( + [ + { + 'name': 'Chat completion with {request_data[model]!r}', + 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'parent': None, + 'start_time': 1000000000, + 'end_time': 2000000000, + 'attributes': { + 'code.filepath': 'test_azure_ai_inference.py', + 'code.function': 'test_sync_chat_streaming', + 'code.lineno': 123, + 'request_data': { + 'model': 'gpt-4', + 'messages': [{'role': 'user', 'content': "[Scrubbed due to 'secret']"}], + }, + 'gen_ai.provider.name': 'azure.ai.inference', + 'gen_ai.operation.name': 'chat', + 'gen_ai.request.model': 'gpt-4', + 'gen_ai.input.messages': [ + {'role': 'user', 'parts': [{'type': 'text', 'content': 'Tell me a secret'}]} + ], + 'logfire.msg_template': 'Chat completion with {request_data[model]!r}', + 'logfire.msg': "Chat completion with 'gpt-4'", + 'logfire.json_schema': { + 'type': 'object', + 'properties': { + 'request_data': {'type': 'object'}, + 'gen_ai.provider.name': {}, + 'gen_ai.operation.name': {}, + 'gen_ai.request.model': {}, + 'gen_ai.input.messages': {'type': 'array'}, + }, + }, + 'logfire.tags': ('LLM',), + 'logfire.span_type': 'span', + 'gen_ai.response.model': 
'gpt-4', + 'logfire.scrubbed': [ + { + 'path': ['attributes', 'request_data', 'messages', 0, 'content'], + 'matched_substring': 'secret', + } + ], + }, + }, + { + 'name': 'streaming response from {request_data[model]!r} took {duration:.2f}s', + 'context': {'trace_id': 2, 'span_id': 3, 'is_remote': False}, + 'parent': None, + 'start_time': 5000000000, + 'end_time': 5000000000, + 'attributes': { + 'logfire.span_type': 'log', + 'logfire.level_num': 9, + 'logfire.msg_template': 'streaming response from {request_data[model]!r} took {duration:.2f}s', + 'logfire.msg': "streaming response from 'gpt-4' took 1.00s", + 'code.filepath': 'test_azure_ai_inference.py', + 'code.function': 'test_sync_chat_streaming', + 'code.lineno': 123, + 'duration': 1.0, + 'request_data': { + 'model': 'gpt-4', + 'messages': [{'role': 'user', 'content': "[Scrubbed due to 'secret']"}], + }, + 'gen_ai.provider.name': 'azure.ai.inference', + 'gen_ai.operation.name': 'chat', + 'gen_ai.request.model': 'gpt-4', + 'gen_ai.input.messages': [ + {'role': 'user', 'parts': [{'type': 'text', 'content': 'Tell me a secret'}]} + ], + 'response_data': {'combined_chunk_content': "[Scrubbed due to 'secret']", 'chunk_count': 2}, + 'gen_ai.output.messages': [ + {'role': 'assistant', 'parts': [{'type': 'text', 'content': 'The answer is secret'}]} + ], + 'logfire.json_schema': { + 'type': 'object', + 'properties': { + 'duration': {}, + 'request_data': {'type': 'object'}, + 'gen_ai.provider.name': {}, + 'gen_ai.operation.name': {}, + 'gen_ai.request.model': {}, + 'gen_ai.input.messages': {'type': 'array'}, + 'response_data': {'type': 'object'}, + 'gen_ai.output.messages': {'type': 'array'}, + }, + }, + 'logfire.tags': ('LLM',), + 'gen_ai.response.model': 'gpt-4', + 'logfire.scrubbed': [ + { + 'path': ['attributes', 'request_data', 'messages', 0, 'content'], + 'matched_substring': 'secret', + }, + { + 'path': ['attributes', 'response_data', 'combined_chunk_content'], + 'matched_substring': 'secret', + }, + ], + }, + }, + ] + ) + + +def test_sync_chat_tool_calls(exporter: TestExporter) -> None: + client = MockChatCompletionsClient(response=_make_tool_response()) + with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'What is the weather?'}], + ) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + attrs = spans[0]['attributes'] + # Check tool calls in response_data + assert attrs['response_data']['message']['tool_calls'] == [ + { + 'id': 'call_1', + 'function': {'name': 'get_weather', 'arguments': '{"city": "London"}'}, + } + ] + # Check tool calls in semconv output + output_msgs = attrs['gen_ai.output.messages'] + assert len(output_msgs) == 1 + tool_part = output_msgs[0]['parts'][0] + assert tool_part['type'] == 'tool_call' + assert tool_part['name'] == 'get_weather' + + +@pytest.mark.anyio +async def test_async_chat(exporter: TestExporter) -> None: + client = MockAsyncChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + response = await client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'What is four plus five?'}], + ) + assert response.choices[0].message.content == 'Nine' + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + assert spans[0]['attributes']['gen_ai.response.model'] == 'gpt-4' + assert spans[0]['attributes']['gen_ai.usage.input_tokens'] == 10 + + +@pytest.mark.anyio +async def 
test_async_chat_streaming(exporter: TestExporter) -> None: + client = MockAsyncChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + response = await client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Tell me a secret'}], + stream=True, + ) + chunks = [chunk async for chunk in response] + assert len(chunks) == 4 + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 2 + # First span: the request + assert spans[0]['attributes']['logfire.msg'] == "Chat completion with 'gpt-4'" + # Second span: streaming info + assert 'streaming response from' in spans[1]['attributes']['logfire.msg'] + assert spans[1]['attributes']['response_data']['combined_chunk_content'] == "[Scrubbed due to 'secret']" + + +def test_sync_embeddings(exporter: TestExporter) -> None: + client = MockEmbeddingsClient() + with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + response = client.embed( + model='text-embedding-ada-002', + input=['Hello world'], + ) + assert len(response.data) == 1 + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + attrs = spans[0]['attributes'] + assert attrs['gen_ai.provider.name'] == 'azure.ai.inference' + assert attrs['gen_ai.operation.name'] == 'embeddings' + assert attrs['gen_ai.request.model'] == 'text-embedding-ada-002' + assert attrs['gen_ai.response.model'] == 'text-embedding-ada-002' + assert attrs['gen_ai.usage.input_tokens'] == 5 + assert attrs['response_data'] == { + 'usage': {'prompt_tokens': 5, 'total_tokens': 5}, + 'data_count': 1, + } + + +@pytest.mark.anyio +async def test_async_embeddings(exporter: TestExporter) -> None: + client = MockAsyncEmbeddingsClient() + with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + response = await client.embed( + model='text-embedding-ada-002', + input=['Hello world'], + ) + assert len(response.data) == 1 + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + assert spans[0]['attributes']['gen_ai.operation.name'] == 'embeddings' + + +def test_uninstrumentation(exporter: TestExporter) -> None: + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client, version=1): + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + assert len(exporter.exported_spans_as_dict()) == 1 + + # After exiting context, client should be uninstrumented + exporter.clear() + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + assert len(exporter.exported_spans_as_dict()) == 0 + + +def test_double_instrumentation(exporter: TestExporter) -> None: + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client, version=1): + with logfire.instrument_azure_ai_inference(client, version=1): + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + # Should only produce one span (not double-instrumented) + assert len(exporter.exported_spans_as_dict()) == 1 + + +def test_message_conversion_with_typed_objects() -> None: + """Test that Azure SDK typed message objects are converted correctly.""" + from azure.ai.inference.models import SystemMessage, UserMessage + + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_messages_to_semconv + + messages = [ + SystemMessage(content='You are helpful.'), + UserMessage(content='Hello'), + ] + input_msgs, system_instructions = 
convert_messages_to_semconv(messages) + assert system_instructions == [{'type': 'text', 'content': 'You are helpful.'}] + assert input_msgs == [{'role': 'user', 'parts': [{'type': 'text', 'content': 'Hello'}]}] + + +def test_message_conversion_with_tool_messages() -> None: + """Test that tool messages are converted correctly.""" + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_messages_to_semconv + + messages = [ + {'role': 'user', 'content': 'What is the weather?'}, + { + 'role': 'assistant', + 'content': '', + 'tool_calls': [ + { + 'id': 'call_1', + 'type': 'function', + 'function': {'name': 'get_weather', 'arguments': '{"city": "London"}'}, + }, + ], + }, + {'role': 'tool', 'content': '72F', 'tool_call_id': 'call_1'}, + ] + input_msgs, system_instructions = convert_messages_to_semconv(messages) + assert len(input_msgs) == 3 + assert system_instructions == [] + # User message + assert input_msgs[0] == {'role': 'user', 'parts': [{'type': 'text', 'content': 'What is the weather?'}]} + # Assistant with tool call + assert input_msgs[1]['role'] == 'assistant' + tool_part: Any = input_msgs[1]['parts'][0] + assert tool_part['type'] == 'tool_call' + assert tool_part['name'] == 'get_weather' + # Tool response + assert input_msgs[2]['role'] == 'tool' + tool_resp: Any = input_msgs[2]['parts'][0] + assert tool_resp['type'] == 'tool_call_response' + assert tool_resp['id'] == 'call_1' + assert tool_resp['response'] == '72F' diff --git a/uv.lock b/uv.lock index fcc647ece..102a21162 100644 --- a/uv.lock +++ b/uv.lock @@ -425,6 +425,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" }, ] +[[package]] +name = "azure-ai-inference" +version = "1.0.0b9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "azure-core" }, + { name = "isodate" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4e/6a/ed85592e5c64e08c291992f58b1a94dab6869f28fb0f40fd753dced73ba6/azure_ai_inference-1.0.0b9.tar.gz", hash = "sha256:1feb496bd84b01ee2691befc04358fa25d7c344d8288e99364438859ad7cd5a4", size = 182408, upload-time = "2025-02-15T00:37:28.464Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/0f/27520da74769db6e58327d96c98e7b9a07ce686dff582c9a5ec60b03f9dd/azure_ai_inference-1.0.0b9-py3-none-any.whl", hash = "sha256:49823732e674092dad83bb8b0d1b65aa73111fab924d61349eb2a8cdc0493990", size = 124885, upload-time = "2025-02-15T00:37:29.964Z" }, +] + +[[package]] +name = "azure-core" +version = "1.38.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/fe/5c7710bc611a4070d06ba801de9a935cc87c3d4b689c644958047bdf2cba/azure_core-1.38.2.tar.gz", hash = "sha256:67562857cb979217e48dc60980243b61ea115b77326fa93d83b729e7ff0482e7", size = 363734, upload-time = "2026-02-18T19:33:05.6Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/23/6371a551800d3812d6019cd813acd985f9fac0fedc1290129211a73da4ae/azure_core-1.38.2-py3-none-any.whl", hash = "sha256:074806c75cf239ea284a33a66827695ef7aeddac0b4e19dda266a93e4665ead9", size = 217957, upload-time = "2026-02-18T19:33:07.696Z" }, +] + [[package]] name = "babel" version = 
"2.18.0" @@ -2442,6 +2469,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/97/25/0e84a6322e5fdb1bf67870b2269151449f4894987b26c78718918dd64ea6/inline_snapshot-0.32.0-py3-none-any.whl", hash = "sha256:b522ae2c891f666e80213c5f9677ec6fd4a2a7d334ab9d6ce745675bec6a40f0", size = 84087, upload-time = "2026-02-13T19:51:52.604Z" }, ] +[[package]] +name = "isodate" +version = "0.7.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/54/4d/e940025e2ce31a8ce1202635910747e5a87cc3a6a6bb2d00973375014749/isodate-0.7.2.tar.gz", hash = "sha256:4cd1aa0f43ca76f4a6c6c0292a85f40b35ec2e43e315b59f06e6d32171a953e6", size = 29705, upload-time = "2024-10-08T23:04:11.5Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/aa/0aca39a37d3c7eb941ba736ede56d689e7be91cab5d9ca846bde3999eba6/isodate-0.7.2-py3-none-any.whl", hash = "sha256:28009937d8031054830160fce6d409ed342816b543597cece116d966c6d99e15", size = 22320, upload-time = "2024-10-08T23:04:09.501Z" }, +] + [[package]] name = "itsdangerous" version = "2.2.0" @@ -3258,6 +3294,9 @@ asyncpg = [ aws-lambda = [ { name = "opentelemetry-instrumentation-aws-lambda" }, ] +azure-ai-inference = [ + { name = "azure-ai-inference" }, +] celery = [ { name = "opentelemetry-instrumentation-celery" }, ] @@ -3338,6 +3377,7 @@ dev = [ { name = "anthropic" }, { name = "asyncpg" }, { name = "attrs" }, + { name = "azure-ai-inference" }, { name = "boto3" }, { name = "botocore" }, { name = "celery" }, @@ -3463,6 +3503,7 @@ docs = [ [package.metadata] requires-dist = [ + { name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = ">=1.0.0b1" }, { name = "executing", specifier = ">=2.0.1" }, { name = "httpx", marker = "extra == 'datasets'", specifier = ">=0.27.2" }, { name = "openinference-instrumentation-dspy", marker = "extra == 'dspy'", specifier = ">=0" }, @@ -3504,7 +3545,7 @@ requires-dist = [ { name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=2.0.1" }, { name = "typing-extensions", specifier = ">=4.1.0" }, ] -provides-extras = ["system-metrics", "asgi", "wsgi", "aiohttp", "aiohttp-client", "aiohttp-server", "celery", "django", "fastapi", "flask", "httpx", "starlette", "sqlalchemy", "asyncpg", "psycopg", "psycopg2", "pymongo", "redis", "requests", "mysql", "sqlite3", "aws-lambda", "google-genai", "litellm", "dspy", "datasets", "variables"] +provides-extras = ["system-metrics", "asgi", "wsgi", "aiohttp", "aiohttp-client", "aiohttp-server", "celery", "django", "fastapi", "flask", "httpx", "starlette", "sqlalchemy", "asyncpg", "psycopg", "psycopg2", "pymongo", "redis", "requests", "mysql", "sqlite3", "aws-lambda", "azure-ai-inference", "google-genai", "litellm", "dspy", "datasets", "variables"] [package.metadata.requires-dev] dev = [ @@ -3513,6 +3554,7 @@ dev = [ { name = "anthropic", specifier = ">=0.27.0" }, { name = "asyncpg", specifier = ">=0.30.0" }, { name = "attrs", specifier = ">=23.1.0" }, + { name = "azure-ai-inference", specifier = ">=1.0.0b1" }, { name = "boto3", specifier = ">=1.28.57" }, { name = "botocore", specifier = ">=1.31.57" }, { name = "celery", specifier = ">=5.4.0" }, From 1d6ca36273a3a7cc01f2861e8b9d5144ecdd5dbc Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 19 Feb 2026 15:28:09 +0100 Subject: [PATCH 2/6] Remove old semconv (version=1) from Azure AI Inference integration Only use the latest OTel GenAI semantic convention attributes (gen_ai.input.messages, gen_ai.output.messages, etc.) 
instead of the legacy request_data/response_data format. Also backfill model from response when the request model is None (Azure OpenAI deployments). --- .../llm_providers/azure_ai_inference.py | 174 ++++++------------ logfire/_internal/main.py | 15 -- .../test_azure_ai_inference.py | 89 ++------- 3 files changed, 71 insertions(+), 207 deletions(-) diff --git a/logfire/_internal/integrations/llm_providers/azure_ai_inference.py b/logfire/_internal/integrations/llm_providers/azure_ai_inference.py index 702da9a62..41a6d7ac7 100644 --- a/logfire/_internal/integrations/llm_providers/azure_ai_inference.py +++ b/logfire/_internal/integrations/llm_providers/azure_ai_inference.py @@ -39,7 +39,6 @@ OutputMessage, OutputMessages, Role, - SemconvVersion, SystemInstructions, TextPart, ToolCallPart, @@ -62,12 +61,11 @@ def instrument_azure_ai_inference( logfire_instance: Logfire, client: Any, suppress_other_instrumentation: bool, - versions: frozenset[SemconvVersion], ) -> AbstractContextManager[None]: """Instrument Azure AI Inference clients.""" if isinstance(client, (tuple, list)): context_managers = [ - instrument_azure_ai_inference(logfire_instance, c, suppress_other_instrumentation, versions) for c in client + instrument_azure_ai_inference(logfire_instance, c, suppress_other_instrumentation) for c in client ] @contextmanager @@ -96,16 +94,12 @@ def uninstrument_all() -> Iterator[None]: method_name = 'complete' original = client.complete client._original_logfire_method = original - client.complete = _make_instrumented_complete( - original, logfire_llm, suppress_other_instrumentation, versions, is_async - ) + client.complete = _make_instrumented_complete(original, logfire_llm, suppress_other_instrumentation, is_async) else: method_name = 'embed' original = client.embed client._original_logfire_method = original - client.embed = _make_instrumented_embed( - original, logfire_llm, suppress_other_instrumentation, versions, is_async - ) + client.embed = _make_instrumented_embed(original, logfire_llm, suppress_other_instrumentation, is_async) @contextmanager def uninstrument() -> Iterator[None]: @@ -142,7 +136,6 @@ def _make_instrumented_complete( original: Any, logfire_llm: Logfire, suppress: bool, - versions: frozenset[SemconvVersion], is_async: bool, ) -> Any: if is_async: @@ -151,7 +144,7 @@ async def instrumented_complete(*args: Any, **kwargs: Any) -> Any: if is_instrumentation_suppressed(): return await original(*args, **kwargs) try: - span_data = _build_chat_span_data(args, kwargs, versions) + span_data = _build_chat_span_data(args, kwargs) except Exception: log_internal_error() return await original(*args, **kwargs) @@ -171,8 +164,8 @@ async def instrumented_complete(*args: Any, **kwargs: Any) -> Any: response = await original(*args, **kwargs) if is_streaming: - return _AsyncStreamWrapper(response, logfire_llm, span_data, versions, original_context) - _on_chat_response(response, span, versions) + return _AsyncStreamWrapper(response, logfire_llm, span_data, original_context) + _on_chat_response(response, span, span_data) return response return instrumented_complete @@ -182,7 +175,7 @@ def instrumented_complete_sync(*args: Any, **kwargs: Any) -> Any: if is_instrumentation_suppressed(): return original(*args, **kwargs) try: - span_data = _build_chat_span_data(args, kwargs, versions) + span_data = _build_chat_span_data(args, kwargs) except Exception: log_internal_error() return original(*args, **kwargs) @@ -202,8 +195,8 @@ def instrumented_complete_sync(*args: Any, **kwargs: Any) -> Any: response = 
original(*args, **kwargs) if is_streaming: - return _SyncStreamWrapper(response, logfire_llm, span_data, versions, original_context) - _on_chat_response(response, span, versions) + return _SyncStreamWrapper(response, logfire_llm, span_data, original_context) + _on_chat_response(response, span, span_data) return response return instrumented_complete_sync @@ -213,7 +206,6 @@ def _make_instrumented_embed( original: Any, logfire_llm: Logfire, suppress: bool, - versions: frozenset[SemconvVersion], is_async: bool, ) -> Any: if is_async: @@ -222,7 +214,7 @@ async def instrumented_embed(*args: Any, **kwargs: Any) -> Any: if is_instrumentation_suppressed(): return await original(*args, **kwargs) try: - span_data = _build_embed_span_data(args, kwargs, versions) + span_data = _build_embed_span_data(args, kwargs) except Exception: log_internal_error() return await original(*args, **kwargs) @@ -237,7 +229,7 @@ async def instrumented_embed(*args: Any, **kwargs: Any) -> Any: response = await original(*args, **kwargs) else: response = await original(*args, **kwargs) - _on_embed_response(response, span, versions) + _on_embed_response(response, span, span_data) return response return instrumented_embed @@ -247,7 +239,7 @@ def instrumented_embed_sync(*args: Any, **kwargs: Any) -> Any: if is_instrumentation_suppressed(): return original(*args, **kwargs) try: - span_data = _build_embed_span_data(args, kwargs, versions) + span_data = _build_embed_span_data(args, kwargs) except Exception: log_internal_error() return original(*args, **kwargs) @@ -262,7 +254,7 @@ def instrumented_embed_sync(*args: Any, **kwargs: Any) -> Any: response = original(*args, **kwargs) else: response = original(*args, **kwargs) - _on_embed_response(response, span, versions) + _on_embed_response(response, span, span_data) return response return instrumented_embed_sync @@ -274,25 +266,13 @@ def instrumented_embed_sync(*args: Any, **kwargs: Any) -> Any: def _build_chat_span_data( args: tuple[Any, ...], kwargs: dict[str, Any], - versions: frozenset[SemconvVersion], ) -> dict[str, Any]: params = _extract_params(args, kwargs) messages = params.get('messages', []) model = params.get('model') - request_data: dict[str, Any] = {'model': model} - if 1 in versions: - if messages: - request_data['messages'] = [_msg_to_dict(m) for m in messages] - for key in ('temperature', 'max_tokens', 'top_p', 'frequency_penalty', 'presence_penalty', 'seed', 'stop'): - val = params.get(key) - if val is not None: - request_data[key] = val - if (tools := params.get('tools')) is not None: - request_data['tools'] = [t if isinstance(t, dict) else t.as_dict() for t in tools] - span_data: dict[str, Any] = { - 'request_data': request_data, + 'request_data': {'model': model}, PROVIDER_NAME: AZURE_PROVIDER, OPERATION_NAME: 'chat', } @@ -301,7 +281,7 @@ def _build_chat_span_data( _extract_request_parameters(params, span_data) - if 'latest' in versions and messages: + if messages: input_messages, system_instructions = convert_messages_to_semconv(messages) span_data[INPUT_MESSAGES] = input_messages if system_instructions: @@ -313,19 +293,12 @@ def _build_chat_span_data( def _build_embed_span_data( args: tuple[Any, ...], kwargs: dict[str, Any], - versions: frozenset[SemconvVersion], ) -> dict[str, Any]: params = _extract_params(args, kwargs) model = params.get('model') - request_data: dict[str, Any] = {'model': model} - if 1 in versions: - input_val = params.get('input') - if input_val is not None: - request_data['input'] = input_val - span_data: dict[str, Any] = { - 
'request_data': request_data, + 'request_data': {'model': model}, PROVIDER_NAME: AZURE_PROVIDER, OPERATION_NAME: 'embeddings', } @@ -367,45 +340,29 @@ def _extract_request_parameters(params: dict[str, Any], span_data: dict[str, Any # --- Response processors --- +def _backfill_model(response: Any, span: LogfireSpan, span_data: dict[str, Any]) -> None: + """If the request model was None, backfill it from the response model.""" + model = getattr(response, 'model', None) + if not model: + return + request_data = span_data.get('request_data') + if not isinstance(request_data, dict) or request_data.get('model') is not None: + return + request_data['model'] = model + span.set_attribute('request_data', request_data) + span.set_attribute(REQUEST_MODEL, model) + span.message = span.message.replace('None', repr(model)) + + @handle_internal_errors -def _on_chat_response(response: Any, span: LogfireSpan, versions: frozenset[SemconvVersion]) -> None: +def _on_chat_response(response: Any, span: LogfireSpan, span_data: dict[str, Any]) -> None: + _backfill_model(response, span, span_data) choices = getattr(response, 'choices', []) usage = getattr(response, 'usage', None) - if 1 in versions: - response_data: dict[str, Any] = {} - if choices: - message = getattr(choices[0], 'message', None) - if message: - msg_data: dict[str, Any] = {'role': getattr(message, 'role', 'assistant')} - content = getattr(message, 'content', None) - if content: - msg_data['content'] = content - tool_calls = getattr(message, 'tool_calls', None) - if tool_calls: - msg_data['tool_calls'] = [ - { - 'id': getattr(tc, 'id', ''), - 'function': { - 'name': getattr(getattr(tc, 'function', None), 'name', ''), - 'arguments': getattr(getattr(tc, 'function', None), 'arguments', ''), - }, - } - for tc in tool_calls - ] - response_data['message'] = msg_data - if usage: - response_data['usage'] = { - 'prompt_tokens': getattr(usage, 'prompt_tokens', 0), - 'completion_tokens': getattr(usage, 'completion_tokens', 0), - 'total_tokens': getattr(usage, 'total_tokens', 0), - } - span.set_attribute('response_data', response_data) - - if 'latest' in versions: - output_messages = convert_response_to_semconv(response) - if output_messages: - span.set_attribute(OUTPUT_MESSAGES, output_messages) + output_messages = convert_response_to_semconv(response) + if output_messages: + span.set_attribute(OUTPUT_MESSAGES, output_messages) model = getattr(response, 'model', None) if model: @@ -429,21 +386,10 @@ def _on_chat_response(response: Any, span: LogfireSpan, versions: frozenset[Semc @handle_internal_errors -def _on_embed_response(response: Any, span: LogfireSpan, versions: frozenset[SemconvVersion]) -> None: +def _on_embed_response(response: Any, span: LogfireSpan, span_data: dict[str, Any]) -> None: + _backfill_model(response, span, span_data) usage = getattr(response, 'usage', None) - if 1 in versions: - response_data: dict[str, Any] = {} - if usage: - response_data['usage'] = { - 'prompt_tokens': getattr(usage, 'prompt_tokens', 0), - 'total_tokens': getattr(usage, 'total_tokens', 0), - } - data = getattr(response, 'data', None) - if data: - response_data['data_count'] = len(data) - span.set_attribute('response_data', response_data) - model = getattr(response, 'model', None) if model: span.set_attribute(RESPONSE_MODEL, model) @@ -461,15 +407,6 @@ def _on_embed_response(response: Any, span: LogfireSpan, versions: frozenset[Sem # --- Message conversion --- -def _msg_to_dict(msg: Any) -> dict[str, Any]: - """Convert an Azure message object or dict to a plain 
dict.""" - if isinstance(msg, dict): - return msg - if hasattr(msg, 'as_dict'): - return msg.as_dict() - return {} # pragma: no cover - - def convert_messages_to_semconv(messages: list[Any]) -> tuple[InputMessages, SystemInstructions]: """Convert Azure AI Inference messages to OTel GenAI semconv format.""" input_messages: InputMessages = [] @@ -528,6 +465,15 @@ def convert_messages_to_semconv(messages: list[Any]) -> tuple[InputMessages, Sys return input_messages, system_instructions +def _msg_to_dict(msg: Any) -> dict[str, Any]: + """Convert an Azure message object or dict to a plain dict.""" + if isinstance(msg, dict): + return msg + if hasattr(msg, 'as_dict'): + return msg.as_dict() + return {} # pragma: no cover + + def _convert_content_item(item: Any) -> MessagePart: """Convert a content item (text, image, audio) to semconv format.""" if isinstance(item, str): @@ -604,13 +550,11 @@ def __init__( wrapped: Any, logfire_llm: Logfire, span_data: dict[str, Any], - versions: frozenset[SemconvVersion], original_context: Any, ) -> None: self._wrapped = wrapped self._logfire_llm = logfire_llm self._span_data = span_data - self._versions = versions self._original_context = original_context self._chunks: list[str] = [] @@ -640,6 +584,11 @@ def __iter__(self) -> Iterator[Any]: ) def _record_chunk(self, chunk: Any) -> None: + if self._span_data.get('request_data', {}).get('model') is None: + model = getattr(chunk, 'model', None) + if model: + self._span_data['request_data']['model'] = model + self._span_data[REQUEST_MODEL] = model for choice in getattr(chunk, 'choices', []): delta = getattr(choice, 'delta', None) if delta: @@ -650,12 +599,7 @@ def _record_chunk(self, chunk: Any) -> None: def _get_stream_attributes(self) -> dict[str, Any]: result = dict(**self._span_data) combined = ''.join(self._chunks) - if 1 in self._versions: - result['response_data'] = { - 'combined_chunk_content': combined, - 'chunk_count': len(self._chunks), - } - if 'latest' in self._versions and self._chunks: + if self._chunks: result[OUTPUT_MESSAGES] = [ OutputMessage( role='assistant', @@ -673,13 +617,11 @@ def __init__( wrapped: Any, logfire_llm: Logfire, span_data: dict[str, Any], - versions: frozenset[SemconvVersion], original_context: Any, ) -> None: self._wrapped = wrapped self._logfire_llm = logfire_llm self._span_data = span_data - self._versions = versions self._original_context = original_context self._chunks: list[str] = [] @@ -709,6 +651,11 @@ async def __aiter__(self) -> AsyncIterator[Any]: ) def _record_chunk(self, chunk: Any) -> None: + if self._span_data.get('request_data', {}).get('model') is None: + model = getattr(chunk, 'model', None) + if model: + self._span_data['request_data']['model'] = model + self._span_data[REQUEST_MODEL] = model for choice in getattr(chunk, 'choices', []): delta = getattr(choice, 'delta', None) if delta: @@ -719,12 +666,7 @@ def _record_chunk(self, chunk: Any) -> None: def _get_stream_attributes(self) -> dict[str, Any]: result = dict(**self._span_data) combined = ''.join(self._chunks) - if 1 in self._versions: - result['response_data'] = { - 'combined_chunk_content': combined, - 'chunk_count': len(self._chunks), - } - if 'latest' in self._versions and self._chunks: + if self._chunks: result[OUTPUT_MESSAGES] = [ OutputMessage( role='assistant', diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index 80399c33f..85235344e 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -1394,7 +1394,6 @@ def instrument_azure_ai_inference( 
azure_ai_inference_client: Any = None, *, suppress_other_instrumentation: bool = True, - version: SemconvVersion | Sequence[SemconvVersion] = 1, ) -> AbstractContextManager[None]: """Instrument an Azure AI Inference client so that spans are automatically created for each request. @@ -1437,17 +1436,6 @@ def instrument_azure_ai_inference( enabled. In reality, this means the Azure Core tracing instrumentation, which could otherwise be called since the Azure SDK uses its own pipeline to make HTTP requests. - version: The version(s) of the span attribute format to use: - - - `1` (the default): Uses `request_data` and `response_data` attributes. - - `'latest'`: Uses OpenTelemetry Gen AI semantic convention attributes - (`gen_ai.input.messages`, `gen_ai.output.messages`, etc.) and omits the full - `response_data` attribute. A minimal `request_data` (e.g. `{"model": ...}`) is - still recorded for message template compatibility. This format may change between - releases. - - `[1, 'latest']`: Emits both the full legacy attributes and the semantic convention - attributes simultaneously, useful for migration and testing. - Returns: A context manager that will revert the instrumentation when exited. Use of this context manager is optional. @@ -1463,9 +1451,7 @@ def instrument_azure_ai_inference( ) from .integrations.llm_providers.azure_ai_inference import instrument_azure_ai_inference - from .integrations.llm_providers.semconv import normalize_versions - normalized_versions = normalize_versions(version) self._warn_if_not_initialized_for_instrumentation() if azure_ai_inference_client is None: @@ -1485,7 +1471,6 @@ def instrument_azure_ai_inference( self, azure_ai_inference_client, suppress_other_instrumentation, - normalized_versions, ) def instrument_google_genai(self, **kwargs: Any): diff --git a/tests/otel_integrations/test_azure_ai_inference.py b/tests/otel_integrations/test_azure_ai_inference.py index a19091c6d..1b639a608 100644 --- a/tests/otel_integrations/test_azure_ai_inference.py +++ b/tests/otel_integrations/test_azure_ai_inference.py @@ -171,7 +171,7 @@ async def _async_iter(items: list[Any]) -> Any: def test_sync_chat(exporter: TestExporter) -> None: client = MockChatCompletionsClient() - with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + with logfire.instrument_azure_ai_inference(client): response = client.complete( model='gpt-4', messages=[ @@ -193,14 +193,7 @@ def test_sync_chat(exporter: TestExporter) -> None: 'code.filepath': 'test_azure_ai_inference.py', 'code.function': 'test_sync_chat', 'code.lineno': 123, - 'request_data': { - 'model': 'gpt-4', - 'messages': [ - {'role': 'system', 'content': 'You are helpful.'}, - {'role': 'user', 'content': 'What is four plus five?'}, - ], - 'temperature': 0.5, - }, + 'request_data': {'model': 'gpt-4'}, 'gen_ai.provider.name': 'azure.ai.inference', 'gen_ai.operation.name': 'chat', 'gen_ai.request.model': 'gpt-4', @@ -213,10 +206,6 @@ def test_sync_chat(exporter: TestExporter) -> None: 'logfire.msg': "Chat completion with 'gpt-4'", 'logfire.tags': ('LLM',), 'logfire.span_type': 'span', - 'response_data': { - 'message': {'role': 'assistant', 'content': 'Nine'}, - 'usage': {'prompt_tokens': 10, 'completion_tokens': 5, 'total_tokens': 15}, - }, 'gen_ai.output.messages': [ { 'role': 'assistant', @@ -239,22 +228,6 @@ def test_sync_chat(exporter: TestExporter) -> None: 'gen_ai.request.temperature': {}, 'gen_ai.input.messages': {'type': 'array'}, 'gen_ai.system_instructions': {'type': 'array'}, - 'response_data': { - 'type': 
'object', - 'properties': { - 'message': { - 'type': 'object', - 'properties': { - 'role': { - 'type': 'string', - 'title': 'ChatRole', - 'x-python-datatype': 'Enum', - 'enum': ['system', 'user', 'assistant', 'tool', 'developer'], - } - }, - } - }, - }, 'gen_ai.output.messages': { 'type': 'array', 'items': { @@ -284,7 +257,7 @@ def test_sync_chat(exporter: TestExporter) -> None: def test_sync_chat_streaming(exporter: TestExporter) -> None: client = MockChatCompletionsClient() - with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + with logfire.instrument_azure_ai_inference(client): response = client.complete( model='gpt-4', messages=[{'role': 'user', 'content': 'Tell me a secret'}], @@ -304,10 +277,7 @@ def test_sync_chat_streaming(exporter: TestExporter) -> None: 'code.filepath': 'test_azure_ai_inference.py', 'code.function': 'test_sync_chat_streaming', 'code.lineno': 123, - 'request_data': { - 'model': 'gpt-4', - 'messages': [{'role': 'user', 'content': "[Scrubbed due to 'secret']"}], - }, + 'request_data': {'model': 'gpt-4'}, 'gen_ai.provider.name': 'azure.ai.inference', 'gen_ai.operation.name': 'chat', 'gen_ai.request.model': 'gpt-4', @@ -329,12 +299,6 @@ def test_sync_chat_streaming(exporter: TestExporter) -> None: 'logfire.tags': ('LLM',), 'logfire.span_type': 'span', 'gen_ai.response.model': 'gpt-4', - 'logfire.scrubbed': [ - { - 'path': ['attributes', 'request_data', 'messages', 0, 'content'], - 'matched_substring': 'secret', - } - ], }, }, { @@ -352,17 +316,13 @@ def test_sync_chat_streaming(exporter: TestExporter) -> None: 'code.function': 'test_sync_chat_streaming', 'code.lineno': 123, 'duration': 1.0, - 'request_data': { - 'model': 'gpt-4', - 'messages': [{'role': 'user', 'content': "[Scrubbed due to 'secret']"}], - }, + 'request_data': {'model': 'gpt-4'}, 'gen_ai.provider.name': 'azure.ai.inference', 'gen_ai.operation.name': 'chat', 'gen_ai.request.model': 'gpt-4', 'gen_ai.input.messages': [ {'role': 'user', 'parts': [{'type': 'text', 'content': 'Tell me a secret'}]} ], - 'response_data': {'combined_chunk_content': "[Scrubbed due to 'secret']", 'chunk_count': 2}, 'gen_ai.output.messages': [ {'role': 'assistant', 'parts': [{'type': 'text', 'content': 'The answer is secret'}]} ], @@ -375,22 +335,11 @@ def test_sync_chat_streaming(exporter: TestExporter) -> None: 'gen_ai.operation.name': {}, 'gen_ai.request.model': {}, 'gen_ai.input.messages': {'type': 'array'}, - 'response_data': {'type': 'object'}, 'gen_ai.output.messages': {'type': 'array'}, }, }, 'logfire.tags': ('LLM',), 'gen_ai.response.model': 'gpt-4', - 'logfire.scrubbed': [ - { - 'path': ['attributes', 'request_data', 'messages', 0, 'content'], - 'matched_substring': 'secret', - }, - { - 'path': ['attributes', 'response_data', 'combined_chunk_content'], - 'matched_substring': 'secret', - }, - ], }, }, ] @@ -399,7 +348,7 @@ def test_sync_chat_streaming(exporter: TestExporter) -> None: def test_sync_chat_tool_calls(exporter: TestExporter) -> None: client = MockChatCompletionsClient(response=_make_tool_response()) - with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + with logfire.instrument_azure_ai_inference(client): client.complete( model='gpt-4', messages=[{'role': 'user', 'content': 'What is the weather?'}], @@ -407,13 +356,6 @@ def test_sync_chat_tool_calls(exporter: TestExporter) -> None: spans = exporter.exported_spans_as_dict(parse_json_attributes=True) assert len(spans) == 1 attrs = spans[0]['attributes'] - # Check tool calls in response_data - assert 
attrs['response_data']['message']['tool_calls'] == [ - { - 'id': 'call_1', - 'function': {'name': 'get_weather', 'arguments': '{"city": "London"}'}, - } - ] # Check tool calls in semconv output output_msgs = attrs['gen_ai.output.messages'] assert len(output_msgs) == 1 @@ -425,7 +367,7 @@ def test_sync_chat_tool_calls(exporter: TestExporter) -> None: @pytest.mark.anyio async def test_async_chat(exporter: TestExporter) -> None: client = MockAsyncChatCompletionsClient() - with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + with logfire.instrument_azure_ai_inference(client): response = await client.complete( model='gpt-4', messages=[{'role': 'user', 'content': 'What is four plus five?'}], @@ -440,7 +382,7 @@ async def test_async_chat(exporter: TestExporter) -> None: @pytest.mark.anyio async def test_async_chat_streaming(exporter: TestExporter) -> None: client = MockAsyncChatCompletionsClient() - with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + with logfire.instrument_azure_ai_inference(client): response = await client.complete( model='gpt-4', messages=[{'role': 'user', 'content': 'Tell me a secret'}], @@ -454,12 +396,11 @@ async def test_async_chat_streaming(exporter: TestExporter) -> None: assert spans[0]['attributes']['logfire.msg'] == "Chat completion with 'gpt-4'" # Second span: streaming info assert 'streaming response from' in spans[1]['attributes']['logfire.msg'] - assert spans[1]['attributes']['response_data']['combined_chunk_content'] == "[Scrubbed due to 'secret']" def test_sync_embeddings(exporter: TestExporter) -> None: client = MockEmbeddingsClient() - with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + with logfire.instrument_azure_ai_inference(client): response = client.embed( model='text-embedding-ada-002', input=['Hello world'], @@ -473,16 +414,12 @@ def test_sync_embeddings(exporter: TestExporter) -> None: assert attrs['gen_ai.request.model'] == 'text-embedding-ada-002' assert attrs['gen_ai.response.model'] == 'text-embedding-ada-002' assert attrs['gen_ai.usage.input_tokens'] == 5 - assert attrs['response_data'] == { - 'usage': {'prompt_tokens': 5, 'total_tokens': 5}, - 'data_count': 1, - } @pytest.mark.anyio async def test_async_embeddings(exporter: TestExporter) -> None: client = MockAsyncEmbeddingsClient() - with logfire.instrument_azure_ai_inference(client, version=[1, 'latest']): + with logfire.instrument_azure_ai_inference(client): response = await client.embed( model='text-embedding-ada-002', input=['Hello world'], @@ -495,7 +432,7 @@ async def test_async_embeddings(exporter: TestExporter) -> None: def test_uninstrumentation(exporter: TestExporter) -> None: client = MockChatCompletionsClient() - with logfire.instrument_azure_ai_inference(client, version=1): + with logfire.instrument_azure_ai_inference(client): client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) assert len(exporter.exported_spans_as_dict()) == 1 @@ -507,8 +444,8 @@ def test_uninstrumentation(exporter: TestExporter) -> None: def test_double_instrumentation(exporter: TestExporter) -> None: client = MockChatCompletionsClient() - with logfire.instrument_azure_ai_inference(client, version=1): - with logfire.instrument_azure_ai_inference(client, version=1): + with logfire.instrument_azure_ai_inference(client): + with logfire.instrument_azure_ai_inference(client): client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) # Should only produce one span (not double-instrumented) assert 
len(exporter.exported_spans_as_dict()) == 1 From cae114f297559d2cf2e2b13099727efb00036cd0 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 19 Feb 2026 16:06:10 +0100 Subject: [PATCH 3/6] Handle missing model gracefully and improve test coverage Use conditional message templates so spans show "Chat completion" instead of "Chat completion with None" when model isn't in request params (Azure OpenAI deployments). Backfill updates the span message once the response arrives. Add 15 new tests covering model backfill, suppress=False, list instrumentation, all request parameters, body-style params, multimodal content items, and stream context managers. Coverage: 73% -> 89%. --- .../llm_providers/azure_ai_inference.py | 62 +++-- .../test_azure_ai_inference.py | 215 ++++++++++++++++++ 2 files changed, 244 insertions(+), 33 deletions(-) diff --git a/logfire/_internal/integrations/llm_providers/azure_ai_inference.py b/logfire/_internal/integrations/llm_providers/azure_ai_inference.py index 41a6d7ac7..f0515afa1 100644 --- a/logfire/_internal/integrations/llm_providers/azure_ai_inference.py +++ b/logfire/_internal/integrations/llm_providers/azure_ai_inference.py @@ -53,6 +53,13 @@ AZURE_PROVIDER = 'azure.ai.inference' +CHAT_MSG_TEMPLATE = 'Chat completion with {request_data[model]!r}' +CHAT_MSG_TEMPLATE_NO_MODEL = 'Chat completion' +EMBED_MSG_TEMPLATE = 'Embeddings with {request_data[model]!r}' +EMBED_MSG_TEMPLATE_NO_MODEL = 'Embeddings' +STREAM_MSG_TEMPLATE = 'streaming response from {request_data[model]!r} took {duration:.2f}s' +STREAM_MSG_TEMPLATE_NO_MODEL = 'streaming response took {duration:.2f}s' + # --- Main instrumentation entry point --- @@ -151,12 +158,9 @@ async def instrumented_complete(*args: Any, **kwargs: Any) -> Any: is_streaming = kwargs.get('stream', False) original_context = get_context() + msg = CHAT_MSG_TEMPLATE if span_data['request_data']['model'] else CHAT_MSG_TEMPLATE_NO_MODEL - with logfire_llm.span( - 'Chat completion with {request_data[model]!r}', - _span_kind=SpanKind.CLIENT, - **span_data, - ) as span: + with logfire_llm.span(msg, _span_kind=SpanKind.CLIENT, **span_data) as span: if suppress: with suppress_instrumentation(): response = await original(*args, **kwargs) @@ -182,12 +186,9 @@ def instrumented_complete_sync(*args: Any, **kwargs: Any) -> Any: is_streaming = kwargs.get('stream', False) original_context = get_context() + msg = CHAT_MSG_TEMPLATE if span_data['request_data']['model'] else CHAT_MSG_TEMPLATE_NO_MODEL - with logfire_llm.span( - 'Chat completion with {request_data[model]!r}', - _span_kind=SpanKind.CLIENT, - **span_data, - ) as span: + with logfire_llm.span(msg, _span_kind=SpanKind.CLIENT, **span_data) as span: if suppress: with suppress_instrumentation(): response = original(*args, **kwargs) @@ -219,11 +220,9 @@ async def instrumented_embed(*args: Any, **kwargs: Any) -> Any: log_internal_error() return await original(*args, **kwargs) - with logfire_llm.span( - 'Embeddings with {request_data[model]!r}', - _span_kind=SpanKind.CLIENT, - **span_data, - ) as span: + msg = EMBED_MSG_TEMPLATE if span_data['request_data']['model'] else EMBED_MSG_TEMPLATE_NO_MODEL + + with logfire_llm.span(msg, _span_kind=SpanKind.CLIENT, **span_data) as span: if suppress: with suppress_instrumentation(): response = await original(*args, **kwargs) @@ -244,11 +243,9 @@ def instrumented_embed_sync(*args: Any, **kwargs: Any) -> Any: log_internal_error() return original(*args, **kwargs) - with logfire_llm.span( - 'Embeddings with {request_data[model]!r}', - 
_span_kind=SpanKind.CLIENT, - **span_data, - ) as span: + msg = EMBED_MSG_TEMPLATE if span_data['request_data']['model'] else EMBED_MSG_TEMPLATE_NO_MODEL + + with logfire_llm.span(msg, _span_kind=SpanKind.CLIENT, **span_data) as span: if suppress: with suppress_instrumentation(): response = original(*args, **kwargs) @@ -340,7 +337,7 @@ def _extract_request_parameters(params: dict[str, Any], span_data: dict[str, Any # --- Response processors --- -def _backfill_model(response: Any, span: LogfireSpan, span_data: dict[str, Any]) -> None: +def _backfill_model(response: Any, span: LogfireSpan, span_data: dict[str, Any], operation: str = 'chat') -> None: """If the request model was None, backfill it from the response model.""" model = getattr(response, 'model', None) if not model: @@ -351,7 +348,10 @@ def _backfill_model(response: Any, span: LogfireSpan, span_data: dict[str, Any]) request_data['model'] = model span.set_attribute('request_data', request_data) span.set_attribute(REQUEST_MODEL, model) - span.message = span.message.replace('None', repr(model)) + if operation == 'chat': + span.message = f'Chat completion with {model!r}' + else: + span.message = f'Embeddings with {model!r}' @handle_internal_errors @@ -387,7 +387,7 @@ def _on_chat_response(response: Any, span: LogfireSpan, span_data: dict[str, Any @handle_internal_errors def _on_embed_response(response: Any, span: LogfireSpan, span_data: dict[str, Any]) -> None: - _backfill_model(response, span, span_data) + _backfill_model(response, span, span_data, operation='embeddings') usage = getattr(response, 'usage', None) model = getattr(response, 'model', None) @@ -576,12 +576,10 @@ def __iter__(self) -> Iterator[Any]: yield chunk finally: duration = (timer() - start) / ONE_SECOND_IN_NANOSECONDS + has_model = self._span_data.get('request_data', {}).get('model') is not None + msg = STREAM_MSG_TEMPLATE if has_model else STREAM_MSG_TEMPLATE_NO_MODEL with attach_context(self._original_context): - self._logfire_llm.info( - 'streaming response from {request_data[model]!r} took {duration:.2f}s', - duration=duration, - **self._get_stream_attributes(), - ) + self._logfire_llm.info(msg, duration=duration, **self._get_stream_attributes()) def _record_chunk(self, chunk: Any) -> None: if self._span_data.get('request_data', {}).get('model') is None: @@ -643,12 +641,10 @@ async def __aiter__(self) -> AsyncIterator[Any]: yield chunk finally: duration = (timer() - start) / ONE_SECOND_IN_NANOSECONDS + has_model = self._span_data.get('request_data', {}).get('model') is not None + msg = STREAM_MSG_TEMPLATE if has_model else STREAM_MSG_TEMPLATE_NO_MODEL with attach_context(self._original_context): - self._logfire_llm.info( - 'streaming response from {request_data[model]!r} took {duration:.2f}s', - duration=duration, - **self._get_stream_attributes(), - ) + self._logfire_llm.info(msg, duration=duration, **self._get_stream_attributes()) def _record_chunk(self, chunk: Any) -> None: if self._span_data.get('request_data', {}).get('model') is None: diff --git a/tests/otel_integrations/test_azure_ai_inference.py b/tests/otel_integrations/test_azure_ai_inference.py index 1b639a608..0d1a944f5 100644 --- a/tests/otel_integrations/test_azure_ai_inference.py +++ b/tests/otel_integrations/test_azure_ai_inference.py @@ -451,6 +451,221 @@ def test_double_instrumentation(exporter: TestExporter) -> None: assert len(exporter.exported_spans_as_dict()) == 1 +def test_no_model_backfill(exporter: TestExporter) -> None: + """When request has no model, backfill from response.""" + client 
= MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + client.complete( + messages=[{'role': 'user', 'content': 'Hi'}], + ) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + attrs = spans[0]['attributes'] + # Model backfilled from response + assert attrs['logfire.msg'] == "Chat completion with 'gpt-4'" + assert attrs['gen_ai.request.model'] == 'gpt-4' + assert attrs['gen_ai.response.model'] == 'gpt-4' + + +def test_no_model_streaming_backfill(exporter: TestExporter) -> None: + """When streaming request has no model, backfill from first chunk.""" + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + response = client.complete( + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + list(response) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 2 + # Streaming info span should have model from chunks + assert spans[1]['attributes']['request_data']['model'] == 'gpt-4' + assert spans[1]['attributes']['gen_ai.request.model'] == 'gpt-4' + + +@pytest.mark.anyio +async def test_no_model_async_streaming_backfill(exporter: TestExporter) -> None: + """When async streaming request has no model, backfill from first chunk.""" + client = MockAsyncChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + response = await client.complete( + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + async for _ in response: + pass + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 2 + assert spans[1]['attributes']['request_data']['model'] == 'gpt-4' + + +def test_no_model_embed_backfill(exporter: TestExporter) -> None: + """When embed request has no model, backfill from response.""" + client = MockEmbeddingsClient() + with logfire.instrument_azure_ai_inference(client): + client.embed(input=['Hello']) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + attrs = spans[0]['attributes'] + assert attrs['logfire.msg'] == "Embeddings with 'text-embedding-ada-002'" + assert attrs['gen_ai.request.model'] == 'text-embedding-ada-002' + + +@pytest.mark.anyio +async def test_no_model_async_embed_backfill(exporter: TestExporter) -> None: + """When async embed request has no model, backfill from response.""" + client = MockAsyncEmbeddingsClient() + with logfire.instrument_azure_ai_inference(client): + await client.embed(input=['Hello']) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + assert attrs_msg(spans[0]) == "Embeddings with 'text-embedding-ada-002'" + + +def attrs_msg(span: dict[str, Any]) -> str: + return span['attributes']['logfire.msg'] + + +def test_suppress_false(exporter: TestExporter) -> None: + """Test with suppress_other_instrumentation=False.""" + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client, suppress_other_instrumentation=False): + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + assert len(exporter.exported_spans_as_dict()) == 1 + + +@pytest.mark.anyio +async def test_suppress_false_async(exporter: TestExporter) -> None: + """Test with suppress_other_instrumentation=False for async.""" + client = MockAsyncChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client, suppress_other_instrumentation=False): + await client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 
'Hi'}]) + assert len(exporter.exported_spans_as_dict()) == 1 + + +def test_suppress_false_embed(exporter: TestExporter) -> None: + """Test embed with suppress=False.""" + client = MockEmbeddingsClient() + with logfire.instrument_azure_ai_inference(client, suppress_other_instrumentation=False): + client.embed(model='text-embedding-ada-002', input=['Hi']) + assert len(exporter.exported_spans_as_dict()) == 1 + + +@pytest.mark.anyio +async def test_suppress_false_async_embed(exporter: TestExporter) -> None: + """Test async embed with suppress=False.""" + client = MockAsyncEmbeddingsClient() + with logfire.instrument_azure_ai_inference(client, suppress_other_instrumentation=False): + await client.embed(model='text-embedding-ada-002', input=['Hi']) + assert len(exporter.exported_spans_as_dict()) == 1 + + +def test_list_instrumentation(exporter: TestExporter) -> None: + """Test instrumenting a list of clients.""" + chat_client = MockChatCompletionsClient() + embed_client = MockEmbeddingsClient() + with logfire.instrument_azure_ai_inference([chat_client, embed_client]): + chat_client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + embed_client.embed(model='text-embedding-ada-002', input=['Hello']) + assert len(exporter.exported_spans_as_dict()) == 2 + + # After exiting, both should be uninstrumented + exporter.clear() + chat_client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + embed_client.embed(model='text-embedding-ada-002', input=['Hello']) + assert len(exporter.exported_spans_as_dict()) == 0 + + +def test_request_parameters(exporter: TestExporter) -> None: + """Test that all request parameters are captured.""" + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + temperature=0.7, + max_tokens=100, + top_p=0.9, + frequency_penalty=0.5, + presence_penalty=0.3, + seed=42, + stop=['\n', 'END'], + ) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + attrs = spans[0]['attributes'] + assert attrs['gen_ai.request.temperature'] == 0.7 + assert attrs['gen_ai.request.max_tokens'] == 100 + assert attrs['gen_ai.request.top_p'] == 0.9 + assert attrs['gen_ai.request.frequency_penalty'] == 0.5 + assert attrs['gen_ai.request.presence_penalty'] == 0.3 + assert attrs['gen_ai.request.seed'] == 42 + assert attrs['gen_ai.request.stop_sequences'] == ['\n', 'END'] + + +def test_extract_params_body_style(exporter: TestExporter) -> None: + """Test that body-style parameters are extracted.""" + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + client.complete(body={'model': 'gpt-4', 'messages': [{'role': 'user', 'content': 'Hi'}]}) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert spans[0]['attributes']['gen_ai.request.model'] == 'gpt-4' + + +def test_content_item_conversion() -> None: + """Test conversion of multimodal content items.""" + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_messages_to_semconv + + messages = [ + { + 'role': 'user', + 'content': [ + 'plain string item', + {'type': 'text', 'text': 'text item'}, + {'type': 'image_url', 'image_url': {'url': 'https://example.com/img.png'}}, + {'type': 'input_audio', 'input_audio': {'data': 'base64data', 'format': 'mp3'}}, + ], + }, + ] + input_msgs, _ = convert_messages_to_semconv(messages) + parts = input_msgs[0]['parts'] + assert parts[0] == {'type': 'text', 
'content': 'plain string item'} + assert parts[1] == {'type': 'text', 'content': 'text item'} + assert parts[2] == {'type': 'uri', 'uri': 'https://example.com/img.png', 'modality': 'image'} + assert parts[3] == {'type': 'blob', 'content': 'base64data', 'media_type': 'audio/mp3', 'modality': 'audio'} + + +def test_stream_context_manager(exporter: TestExporter) -> None: + """Test that sync stream wrapper supports context manager protocol.""" + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + response = client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + # Use as context manager + with response: + for _ in response: + pass + assert len(exporter.exported_spans_as_dict()) == 2 + + +@pytest.mark.anyio +async def test_async_stream_context_manager(exporter: TestExporter) -> None: + """Test that async stream wrapper supports async context manager protocol.""" + client = MockAsyncChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + response = await client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + async with response: + async for _ in response: + pass + assert len(exporter.exported_spans_as_dict()) == 2 + + def test_message_conversion_with_typed_objects() -> None: """Test that Azure SDK typed message objects are converted correctly.""" from azure.ai.inference.models import SystemMessage, UserMessage From 8c9e8bb92081cdfb970d236e63056b47de144cd4 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 19 Feb 2026 16:29:29 +0100 Subject: [PATCH 4/6] Reach 100% test coverage for Azure AI Inference integration Add comprehensive tests for all code paths including: - Positional arg extraction in _extract_params - Tool objects with as_dict() method - Model backfill when response has no model - Minimal/empty responses (no model, id, usage, finish_reason) - Choice without message in response conversion - Stream context manager delegation (__enter__/__exit__, __aenter__/__aexit__) - Streaming with empty chunks and no-model chunks - System messages with non-string content - Response tool calls without function attribute Add pragma: no cover to defensive paths (is_instrumentation_suppressed, except Exception, ImportError) matching patterns from other integrations. 
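
For reference, the defensive pattern on the sync chat path now reads as
follows (a sketch of the code in the diff below; the async and embed
paths have the same shape, modulo `await`):

    def instrumented_complete_sync(*args: Any, **kwargs: Any) -> Any:
        # Suppression and span-data failures fall back to the original
        # call; these paths are excluded from coverage, matching other
        # integrations.
        if is_instrumentation_suppressed():  # pragma: no cover
            return original(*args, **kwargs)
        try:
            span_data = _build_chat_span_data(args, kwargs)
        except Exception:  # pragma: no cover
            log_internal_error()
            return original(*args, **kwargs)
        ...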
--- .../llm_providers/azure_ai_inference.py | 16 +- logfire/_internal/main.py | 2 +- .../test_azure_ai_inference.py | 474 +++++++++++++++++- 3 files changed, 480 insertions(+), 12 deletions(-) diff --git a/logfire/_internal/integrations/llm_providers/azure_ai_inference.py b/logfire/_internal/integrations/llm_providers/azure_ai_inference.py index f0515afa1..b80cbdf73 100644 --- a/logfire/_internal/integrations/llm_providers/azure_ai_inference.py +++ b/logfire/_internal/integrations/llm_providers/azure_ai_inference.py @@ -148,11 +148,11 @@ def _make_instrumented_complete( if is_async: async def instrumented_complete(*args: Any, **kwargs: Any) -> Any: - if is_instrumentation_suppressed(): + if is_instrumentation_suppressed(): # pragma: no cover return await original(*args, **kwargs) try: span_data = _build_chat_span_data(args, kwargs) - except Exception: + except Exception: # pragma: no cover log_internal_error() return await original(*args, **kwargs) @@ -176,11 +176,11 @@ async def instrumented_complete(*args: Any, **kwargs: Any) -> Any: else: def instrumented_complete_sync(*args: Any, **kwargs: Any) -> Any: - if is_instrumentation_suppressed(): + if is_instrumentation_suppressed(): # pragma: no cover return original(*args, **kwargs) try: span_data = _build_chat_span_data(args, kwargs) - except Exception: + except Exception: # pragma: no cover log_internal_error() return original(*args, **kwargs) @@ -212,11 +212,11 @@ def _make_instrumented_embed( if is_async: async def instrumented_embed(*args: Any, **kwargs: Any) -> Any: - if is_instrumentation_suppressed(): + if is_instrumentation_suppressed(): # pragma: no cover return await original(*args, **kwargs) try: span_data = _build_embed_span_data(args, kwargs) - except Exception: + except Exception: # pragma: no cover log_internal_error() return await original(*args, **kwargs) @@ -235,11 +235,11 @@ async def instrumented_embed(*args: Any, **kwargs: Any) -> Any: else: def instrumented_embed_sync(*args: Any, **kwargs: Any) -> Any: - if is_instrumentation_suppressed(): + if is_instrumentation_suppressed(): # pragma: no cover return original(*args, **kwargs) try: span_data = _build_embed_span_data(args, kwargs) - except Exception: + except Exception: # pragma: no cover log_internal_error() return original(*args, **kwargs) diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index 85235344e..83615ff23 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -1442,7 +1442,7 @@ def instrument_azure_ai_inference( """ try: from azure.ai.inference import ChatCompletionsClient, EmbeddingsClient - except ImportError: + except ImportError: # pragma: no cover raise RuntimeError( 'The `logfire.instrument_azure_ai_inference()` method ' 'requires the `azure-ai-inference` package.\n' diff --git a/tests/otel_integrations/test_azure_ai_inference.py b/tests/otel_integrations/test_azure_ai_inference.py index 0d1a944f5..b2a7c593c 100644 --- a/tests/otel_integrations/test_azure_ai_inference.py +++ b/tests/otel_integrations/test_azure_ai_inference.py @@ -1,6 +1,7 @@ -# pyright: reportCallIssue=false, reportArgumentType=false +# pyright: reportCallIssue=false, reportArgumentType=false, reportPrivateUsage=false from __future__ import annotations as _annotations +from collections.abc import AsyncIterator, Iterator from datetime import datetime from typing import Any @@ -119,7 +120,7 @@ def __init__(self, response: Any = None, stream_chunks: list[Any] | None = None) self._response = response or _make_chat_response() self._stream_chunks = 
stream_chunks - def complete(self, **kwargs: Any) -> Any: + def complete(self, *args: Any, **kwargs: Any) -> Any: if kwargs.get('stream'): return iter(self._stream_chunks or _make_streaming_chunks()) return self._response @@ -134,7 +135,7 @@ def __init__(self, response: Any = None, stream_chunks: list[Any] | None = None) self._response = response or _make_chat_response() self._stream_chunks = stream_chunks - async def complete(self, **kwargs: Any) -> Any: + async def complete(self, *args: Any, **kwargs: Any) -> Any: if kwargs.get('stream'): return _async_iter(self._stream_chunks or _make_streaming_chunks()) return self._response @@ -666,6 +667,473 @@ async def test_async_stream_context_manager(exporter: TestExporter) -> None: assert len(exporter.exported_spans_as_dict()) == 2 +def test_positional_arg_extraction() -> None: + """Test that _extract_params handles positional args correctly.""" + from logfire._internal.integrations.llm_providers.azure_ai_inference import _extract_params + + # Single positional arg with 'messages' key + result = _extract_params(({'model': 'gpt-4', 'messages': [{'role': 'user', 'content': 'Hi'}]},), {}) + assert result['model'] == 'gpt-4' + + # Multiple args, first doesn't match, second does (covers loop iteration) + result = _extract_params(('not-a-dict', {'messages': [{'role': 'user', 'content': 'Hi'}]}), {}) + assert 'messages' in result + + # No matching arg, falls back to kwargs + result = _extract_params(('not-a-dict',), {'model': 'gpt-4'}) + assert result == {'model': 'gpt-4'} + + +def test_tools_with_as_dict(exporter: TestExporter) -> None: + """Test that tool objects with as_dict() are handled.""" + + class MockTool: + def as_dict(self) -> dict[str, Any]: + return {'type': 'function', 'function': {'name': 'my_tool', 'parameters': {}}} + + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + tools=[MockTool()], + ) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + tool_defs = spans[0]['attributes']['gen_ai.tool.definitions'] + assert tool_defs == [{'type': 'function', 'function': {'name': 'my_tool', 'parameters': {}}}] + + +def test_backfill_no_model_in_response(exporter: TestExporter) -> None: + """Test backfill when response also has no model.""" + response = ChatCompletions( + id='test-id', + model=None, + created=datetime(2024, 1, 1), + choices=[ + ChatChoice( + index=0, + finish_reason='stop', + message=ChatResponseMessage(role='assistant', content='Hello'), + ) + ], + usage=CompletionsUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + client = MockChatCompletionsClient(response=response) + with logfire.instrument_azure_ai_inference(client): + client.complete(messages=[{'role': 'user', 'content': 'Hi'}]) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + # Model stays None - no backfill + assert spans[0]['attributes']['logfire.msg'] == 'Chat completion' + + +def test_minimal_chat_response(exporter: TestExporter) -> None: + """Test response with no model, no id, no usage, no finish_reason. + + Exercises the false branches in _on_chat_response. 
+ """ + response = ChatCompletions( + id=None, + model=None, + created=datetime(2024, 1, 1), + choices=[ + ChatChoice( + index=0, + finish_reason=None, + message=ChatResponseMessage(role='assistant', content='Hi'), + ) + ], + usage=None, + ) + client = MockChatCompletionsClient(response=response) + with logfire.instrument_azure_ai_inference(client): + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + + +def test_choice_without_message() -> None: + """Test response with a choice that has no message. + + Exercises the `if not message: continue` path in convert_response_to_semconv. + """ + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_response_to_semconv + + class FakeChoice: + message = None + finish_reason = 'stop' + + class FakeResponse: + choices = [FakeChoice()] + + output = convert_response_to_semconv(FakeResponse()) + assert output == [] + + +def test_minimal_embed_response(exporter: TestExporter) -> None: + """Test embed response with no model, no id, no usage. + + Exercises the false branches in _on_embed_response. + """ + response = EmbeddingsResult( + id=None, + model=None, + data=[EmbeddingItem(embedding=[0.1], index=0)], + usage=None, + ) + client = MockEmbeddingsClient(response=response) + with logfire.instrument_azure_ai_inference(client): + client.embed(model='text-embedding-ada-002', input=['Hi']) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + + +def test_response_empty_choices(exporter: TestExporter) -> None: + """Test response with empty choices list. + + Exercises the false branch of `if output_messages:` in _on_chat_response. + """ + + class FakeResponse: + id = 'test-id' + model = 'gpt-4' + choices: list[Any] = [] + usage = None + + client = MockChatCompletionsClient(response=FakeResponse()) + with logfire.instrument_azure_ai_inference(client): + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + + +def test_usage_with_none_tokens(exporter: TestExporter) -> None: + """Test response with usage but None prompt/completion tokens. + + Exercises false branches of prompt_tokens/completion_tokens checks. + """ + + class FakeUsage: + prompt_tokens = None + completion_tokens = None + + class FakeMessage: + role = 'assistant' + content = 'Hi' + tool_calls = None + + class FakeChoice: + index = 0 + finish_reason = None + message = FakeMessage() + + class FakeResponse: + id = None + model = None + choices = [FakeChoice()] + usage = FakeUsage() + + client = MockChatCompletionsClient(response=FakeResponse()) + with logfire.instrument_azure_ai_inference(client): + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + + +def test_embed_usage_with_none_tokens(exporter: TestExporter) -> None: + """Test embed response with usage but None prompt_tokens. + + Exercises false branch of prompt_tokens check in _on_embed_response. 
+ """ + + class FakeUsage: + prompt_tokens = None + + class FakeResponse: + id = None + model = None + data = [] + usage = FakeUsage() + + client = MockEmbeddingsClient(response=FakeResponse()) + with logfire.instrument_azure_ai_inference(client): + client.embed(model='text-embedding-ada-002', input=['Hi']) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + + +def test_no_messages_in_request(exporter: TestExporter) -> None: + """Test chat completion with no messages.""" + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + client.complete(model='gpt-4', messages=[]) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert 'gen_ai.input.messages' not in spans[0]['attributes'] + + +def test_system_message_non_string_content() -> None: + """Test system message with non-string content. + + Exercises the false branch of isinstance(content, str) for system messages. + """ + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_messages_to_semconv + + messages = [ + {'role': 'system', 'content': None}, + {'role': 'user', 'content': 'Hi'}, + ] + input_msgs, system_instructions = convert_messages_to_semconv(messages) + assert system_instructions == [] + assert len(input_msgs) == 1 + + +def test_response_no_output_messages() -> None: + """Test convert_response_to_semconv with empty choices. + + Exercises the false branch of `if output_messages:` in _on_chat_response. + """ + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_response_to_semconv + + class EmptyResponse: + choices = [] + + output = convert_response_to_semconv(EmptyResponse()) + assert output == [] + + +def test_response_tool_call_no_function() -> None: + """Test response tool call without function attribute. + + Exercises the false branch of `if func:` in convert_response_to_semconv. 
+ """ + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_response_to_semconv + + class FakeToolCall: + id = 'tc1' + function = None + + class FakeMessage: + role = 'assistant' + content = None + tool_calls = [FakeToolCall()] + + class FakeChoice: + message = FakeMessage() + finish_reason = None + + class FakeResponse: + choices = [FakeChoice()] + + output = convert_response_to_semconv(FakeResponse()) + assert len(output) == 1 + assert output[0]['parts'] == [] + + +def test_stream_wrapped_with_context_manager(exporter: TestExporter) -> None: + """Test sync stream where wrapped object has __enter__/__exit__.""" + + class ContextManagerIterator: + def __init__(self, items: list[Any]) -> None: + self.items = items + self.entered = False + self.exited = False + + def __enter__(self) -> ContextManagerIterator: + self.entered = True + return self + + def __exit__(self, *args: Any) -> None: + self.exited = True + + def __iter__(self) -> Iterator[Any]: + return iter(self.items) + + chunks = _make_streaming_chunks() + wrapped = ContextManagerIterator(chunks) + + class MockChatCompletionsClientWithCM: + __module__ = 'azure.ai.inference' + + def complete(self, **kwargs: Any) -> Any: + if kwargs.get('stream'): + return wrapped + return _make_chat_response() + + client = MockChatCompletionsClientWithCM() + with logfire.instrument_azure_ai_inference(client): + response = client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + with response: + for _ in response: + pass + assert wrapped.entered + assert wrapped.exited + assert len(exporter.exported_spans_as_dict()) == 2 + + +@pytest.mark.anyio +async def test_async_stream_wrapped_with_context_manager(exporter: TestExporter) -> None: + """Test async stream where wrapped object has __aenter__/__aexit__.""" + + class AsyncContextManagerIterator: + def __init__(self, items: list[Any]) -> None: + self.items = items + self.entered = False + self.exited = False + + async def __aenter__(self) -> AsyncContextManagerIterator: + self.entered = True + return self + + async def __aexit__(self, *args: Any) -> None: + self.exited = True + + async def __aiter__(self) -> AsyncIterator[Any]: + for item in self.items: + yield item + + chunks = _make_streaming_chunks() + wrapped = AsyncContextManagerIterator(chunks) + + class MockAsyncChatCompletionsClientWithCM: + __module__ = 'azure.ai.inference.aio' + + async def complete(self, **kwargs: Any) -> Any: + if kwargs.get('stream'): + return wrapped + return _make_chat_response() + + client = MockAsyncChatCompletionsClientWithCM() + with logfire.instrument_azure_ai_inference(client): + response = await client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + async with response: + async for _ in response: + pass + assert wrapped.entered + assert wrapped.exited + assert len(exporter.exported_spans_as_dict()) == 2 + + +def test_streaming_empty_chunks(exporter: TestExporter) -> None: + """Test streaming with chunks that have no choices or no content.""" + empty_chunks = [ + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[], + ), + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[StreamingChatChoiceUpdate(index=0, delta=None)], + ), + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[StreamingChatChoiceUpdate(index=0, 
delta=StreamingChatResponseMessageUpdate(content=None))], + ), + ] + client = MockChatCompletionsClient(stream_chunks=empty_chunks) + with logfire.instrument_azure_ai_inference(client): + response = client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + list(response) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + # Streaming info span should NOT have output messages (no actual content) + assert 'gen_ai.output.messages' not in spans[1]['attributes'] + + +@pytest.mark.anyio +async def test_async_streaming_empty_chunks(exporter: TestExporter) -> None: + """Test async streaming with chunks that have no choices or no content.""" + empty_chunks = [ + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[], + ), + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[StreamingChatChoiceUpdate(index=0, delta=None)], + ), + ] + client = MockAsyncChatCompletionsClient(stream_chunks=empty_chunks) + with logfire.instrument_azure_ai_inference(client): + response = await client.complete( + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + async for _ in response: + pass + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert 'gen_ai.output.messages' not in spans[1]['attributes'] + + +def test_streaming_no_model_chunks(exporter: TestExporter) -> None: + """Test streaming where both request and chunk have no model. + + Exercises false branch of `if model:` in _record_chunk (sync). + """ + no_model_chunks = [ + StreamingChatCompletionsUpdate( + id='test-id', + model=None, + created=datetime(2024, 1, 1), + choices=[StreamingChatChoiceUpdate(index=0, delta=StreamingChatResponseMessageUpdate(content='Hi'))], + ), + ] + client = MockChatCompletionsClient(stream_chunks=no_model_chunks) + with logfire.instrument_azure_ai_inference(client): + response = client.complete( + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + list(response) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 2 + + +@pytest.mark.anyio +async def test_async_streaming_no_model_chunks(exporter: TestExporter) -> None: + """Test async streaming where both request and chunk have no model. + + Exercises false branch of `if model:` in _record_chunk (async). 
+ """ + no_model_chunks = [ + StreamingChatCompletionsUpdate( + id='test-id', + model=None, + created=datetime(2024, 1, 1), + choices=[StreamingChatChoiceUpdate(index=0, delta=StreamingChatResponseMessageUpdate(content='Hi'))], + ), + ] + client = MockAsyncChatCompletionsClient(stream_chunks=no_model_chunks) + with logfire.instrument_azure_ai_inference(client): + response = await client.complete( + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + async for _ in response: + pass + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 2 + + def test_message_conversion_with_typed_objects() -> None: """Test that Azure SDK typed message objects are converted correctly.""" from azure.ai.inference.models import SystemMessage, UserMessage From c62c243cc5bb6611c3cae109c973fa1dbe479a1a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 19 Feb 2026 16:58:36 +0100 Subject: [PATCH 5/6] Use proper type annotations and move ImportError handling to integration module Address @Kludex's PR review: use TYPE_CHECKING imports with proper union type for the azure_ai_inference_client parameter (matching OpenAI/Anthropic pattern), and move the ImportError check + client list building from main.py to the integration module. --- .../llm_providers/azure_ai_inference.py | 23 ++++++++++ logfire/_internal/main.py | 44 ++++++++----------- .../test_azure_ai_inference.py | 11 +++++ 3 files changed, 53 insertions(+), 25 deletions(-) diff --git a/logfire/_internal/integrations/llm_providers/azure_ai_inference.py b/logfire/_internal/integrations/llm_providers/azure_ai_inference.py index b80cbdf73..57954ddb8 100644 --- a/logfire/_internal/integrations/llm_providers/azure_ai_inference.py +++ b/logfire/_internal/integrations/llm_providers/azure_ai_inference.py @@ -70,6 +70,29 @@ def instrument_azure_ai_inference( suppress_other_instrumentation: bool, ) -> AbstractContextManager[None]: """Instrument Azure AI Inference clients.""" + if client is None: + try: + from azure.ai.inference import ChatCompletionsClient, EmbeddingsClient + except ImportError: # pragma: no cover + raise RuntimeError( + 'The `logfire.instrument_azure_ai_inference()` method ' + 'requires the `azure-ai-inference` package.\n' + 'You can install this with:\n' + " pip install 'logfire[azure-ai-inference]'" + ) + + clients: list[Any] = [ChatCompletionsClient, EmbeddingsClient] + try: + from azure.ai.inference.aio import ( + ChatCompletionsClient as AsyncChatCompletionsClient, + EmbeddingsClient as AsyncEmbeddingsClient, + ) + + clients.extend([AsyncChatCompletionsClient, AsyncEmbeddingsClient]) + except ImportError: # pragma: no cover + pass + client = clients + if isinstance(client, (tuple, list)): context_managers = [ instrument_azure_ai_inference(logfire_instance, c, suppress_other_instrumentation) for c in client diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index 83615ff23..f3b9e9429 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -79,6 +79,14 @@ import openai import pydantic_ai.models import requests + from azure.ai.inference import ( + ChatCompletionsClient as AzureChatCompletionsClient, + EmbeddingsClient as AzureEmbeddingsClient, + ) + from azure.ai.inference.aio import ( + ChatCompletionsClient as AsyncAzureChatCompletionsClient, + EmbeddingsClient as AsyncAzureEmbeddingsClient, + ) from django.http import HttpRequest, HttpResponse from fastapi import FastAPI from flask.app import Flask @@ -1391,7 +1399,17 @@ def instrument_anthropic( def 
instrument_azure_ai_inference( self, - azure_ai_inference_client: Any = None, + azure_ai_inference_client: ( + AzureChatCompletionsClient + | AzureEmbeddingsClient + | AsyncAzureChatCompletionsClient + | AsyncAzureEmbeddingsClient + | type[AzureChatCompletionsClient] + | type[AzureEmbeddingsClient] + | type[AsyncAzureChatCompletionsClient] + | type[AsyncAzureEmbeddingsClient] + | None + ) = None, *, suppress_other_instrumentation: bool = True, ) -> AbstractContextManager[None]: @@ -1440,33 +1458,9 @@ def instrument_azure_ai_inference( A context manager that will revert the instrumentation when exited. Use of this context manager is optional. """ - try: - from azure.ai.inference import ChatCompletionsClient, EmbeddingsClient - except ImportError: # pragma: no cover - raise RuntimeError( - 'The `logfire.instrument_azure_ai_inference()` method ' - 'requires the `azure-ai-inference` package.\n' - 'You can install this with:\n' - " pip install 'logfire[azure-ai-inference]'" - ) - from .integrations.llm_providers.azure_ai_inference import instrument_azure_ai_inference self._warn_if_not_initialized_for_instrumentation() - - if azure_ai_inference_client is None: - clients_to_instrument: list[Any] = [ChatCompletionsClient, EmbeddingsClient] - try: - from azure.ai.inference.aio import ( - ChatCompletionsClient as AsyncChatCompletionsClient, - EmbeddingsClient as AsyncEmbeddingsClient, - ) - - clients_to_instrument.extend([AsyncChatCompletionsClient, AsyncEmbeddingsClient]) - except ImportError: # pragma: no cover - pass - azure_ai_inference_client = clients_to_instrument - return instrument_azure_ai_inference( self, azure_ai_inference_client, diff --git a/tests/otel_integrations/test_azure_ai_inference.py b/tests/otel_integrations/test_azure_ai_inference.py index b2a7c593c..43bdd0c93 100644 --- a/tests/otel_integrations/test_azure_ai_inference.py +++ b/tests/otel_integrations/test_azure_ai_inference.py @@ -1184,3 +1184,14 @@ def test_message_conversion_with_tool_messages() -> None: assert tool_resp['type'] == 'tool_call_response' assert tool_resp['id'] == 'call_1' assert tool_resp['response'] == '72F' + + +def test_global_instrumentation() -> None: + """Test passing None instruments all client classes via the integration module.""" + from logfire._internal.integrations.llm_providers.azure_ai_inference import instrument_azure_ai_inference + + # Call with client=None so the integration module resolves client classes itself + cm = instrument_azure_ai_inference(logfire.DEFAULT_LOGFIRE_INSTANCE, None, True) + # Just verify it returns a context manager without error + with cm: + pass From f175bd4679d682ded8bae452b20e7a3ab82b2214 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 19 Feb 2026 17:01:46 +0100 Subject: [PATCH 6/6] Document integration typing and ImportError patterns in CLAUDE.md --- CLAUDE.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index dc5d683ed..25147b8a9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -75,6 +75,10 @@ Some tests are decorated with `@pytest.mark.vcr()` and use `pytest-recording` to The `logfire-api` package is a no-op shim that libraries can depend on to avoid hard dependencies on Logfire itself. It provides minimal 'implementations' in `logfire-api/logfire_api/__init__.py`, which needs to be kept up to date with the public API of the `logfire` module, especially if `test_logfire_api.py` starts failing. The rest is just `.pyi` stubs which should be ignored and are autogenerated when needed during release. 
+# Integrations
+
+`instrument_*` methods in `main.py` should use proper type annotations for their client parameters, not `Any`. The third-party types are imported under `if TYPE_CHECKING:` at the top of `main.py` (aliased if needed to avoid name collisions). `ImportError` handling and client resolution (e.g., building the list of all client classes when `None` is passed) should live in the integration module under `logfire/_internal/integrations/`, not in `main.py`. See `instrument_openai`, `instrument_anthropic`, and `instrument_azure_ai_inference` for examples.
+
 # Misc
 
 Use `git push origin HEAD` to push, not just `git push`, so that it pushes to the current branch without needing to set upstream explicitly.
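A minimal sketch of the pattern the CLAUDE.md note describes, assuming a hypothetical `acme_sdk` package — `Client`, `instrument_acme`, and `do_instrument` are invented names for illustration only; the real counterparts are the `instrument_*` methods in `main.py` paired with modules under `logfire/_internal/integrations/`:

```python
# Sketch only: `acme_sdk`, `Client`, `instrument_acme`, and `do_instrument` are invented names.
from __future__ import annotations

from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Any

# --- main.py side: third-party types are imported only for type checking ---
if TYPE_CHECKING:
    from acme_sdk import Client as AcmeClient  # alias avoids name collisions


def instrument_acme(client: AcmeClient | type[AcmeClient] | None = None) -> AbstractContextManager[None]:
    # Thin wrapper: all runtime concerns are deferred to the integration module.
    return do_instrument(client)


# --- integration-module side: runtime import, ImportError message, `None` resolution ---
def do_instrument(client: Any) -> AbstractContextManager[None]:
    if client is None:
        try:
            from acme_sdk import Client
        except ImportError:
            raise RuntimeError('`instrument_acme()` requires the `acme-sdk` package.')
        client = [Client]  # `None` means: instrument every known client class
    # ... monkey-patching of `client` (or each class in the list) goes here ...
    return nullcontext()
```

As the diffs above show, the actual `instrument_azure_ai_inference` integration additionally accepts a tuple or list of clients and instruments each one recursively, returning a context manager that reverts the patching on exit.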