diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py
index 594848d3e..76f36d86b 100644
--- a/src/agents/models/openai_chatcompletions.py
+++ b/src/agents/models/openai_chatcompletions.py
@@ -11,7 +11,6 @@ from openai.types.chat.chat_completion import Choice
 from openai.types.responses import Response
 from openai.types.responses.response_prompt_param import ResponsePromptParam
-from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
 
 from .. import _debug
 from ..agent_output import AgentOutputSchemaBase
@@ -102,18 +101,9 @@ async def get_response(
                     input_tokens=response.usage.prompt_tokens,
                     output_tokens=response.usage.completion_tokens,
                     total_tokens=response.usage.total_tokens,
-                    input_tokens_details=InputTokensDetails(
-                        cached_tokens=getattr(
-                            response.usage.prompt_tokens_details, "cached_tokens", 0
-                        )
-                        or 0,
-                    ),
-                    output_tokens_details=OutputTokensDetails(
-                        reasoning_tokens=getattr(
-                            response.usage.completion_tokens_details, "reasoning_tokens", 0
-                        )
-                        or 0,
-                    ),
+                    # BeforeValidator in Usage normalizes these from Chat Completions types
+                    input_tokens_details=response.usage.prompt_tokens_details,  # type: ignore[arg-type]
+                    output_tokens_details=response.usage.completion_tokens_details,  # type: ignore[arg-type]
                 )
                 if response.usage
                 else Usage()
diff --git a/src/agents/usage.py b/src/agents/usage.py
index a10778123..915c903ff 100644
--- a/src/agents/usage.py
+++ b/src/agents/usage.py
@@ -1,9 +1,36 @@
+from __future__ import annotations
+
 from dataclasses import field
+from typing import Annotated
 
+from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails
 from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
+from pydantic import BeforeValidator
 from pydantic.dataclasses import dataclass
 
 
+def _normalize_input_tokens_details(
+    v: InputTokensDetails | PromptTokensDetails | None,
+) -> InputTokensDetails:
+    """Converts None or PromptTokensDetails to InputTokensDetails."""
+    if v is None:
+        return InputTokensDetails(cached_tokens=0)
+    if isinstance(v, PromptTokensDetails):
+        return InputTokensDetails(cached_tokens=v.cached_tokens or 0)
+    return v
+
+
+def _normalize_output_tokens_details(
+    v: OutputTokensDetails | CompletionTokensDetails | None,
+) -> OutputTokensDetails:
+    """Converts None or CompletionTokensDetails to OutputTokensDetails."""
+    if v is None:
+        return OutputTokensDetails(reasoning_tokens=0)
+    if isinstance(v, CompletionTokensDetails):
+        return OutputTokensDetails(reasoning_tokens=v.reasoning_tokens or 0)
+    return v
+
+
 @dataclass
 class RequestUsage:
     """Usage details for a single API request."""
@@ -32,16 +59,16 @@ class Usage:
     input_tokens: int = 0
     """Total input tokens sent, across all requests."""
 
-    input_tokens_details: InputTokensDetails = field(
-        default_factory=lambda: InputTokensDetails(cached_tokens=0)
-    )
+    input_tokens_details: Annotated[
+        InputTokensDetails, BeforeValidator(_normalize_input_tokens_details)
+    ] = field(default_factory=lambda: InputTokensDetails(cached_tokens=0))
     """Details about the input tokens, matching responses API usage details."""
     output_tokens: int = 0
     """Total output tokens received, across all requests."""
 
-    output_tokens_details: OutputTokensDetails = field(
-        default_factory=lambda: OutputTokensDetails(reasoning_tokens=0)
-    )
+    output_tokens_details: Annotated[
+        OutputTokensDetails, BeforeValidator(_normalize_output_tokens_details)
+    ] = field(default_factory=lambda: OutputTokensDetails(reasoning_tokens=0))
     """Details about the output tokens, matching responses API usage details."""
 
     total_tokens: int = 0
@@ -70,7 +97,7 @@ def __post_init__(self) -> None:
         if self.output_tokens_details.reasoning_tokens is None:
             self.output_tokens_details = OutputTokensDetails(reasoning_tokens=0)
 
-    def add(self, other: "Usage") -> None:
+    def add(self, other: Usage) -> None:
         """Add another Usage object to this one, aggregating all fields.
 
         This method automatically preserves request_usage_entries.
diff --git a/tests/test_usage.py b/tests/test_usage.py
index 9d89cc750..fbe26c98d 100644
--- a/tests/test_usage.py
+++ b/tests/test_usage.py
@@ -1,3 +1,4 @@
+from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails
 from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
 
 from agents.usage import RequestUsage, Usage
@@ -270,7 +271,24 @@ def test_anthropic_cost_calculation_scenario():
 
 
 def test_usage_normalizes_none_token_details():
-    # Some providers don't populate optional fields, resulting in None values
+    # Some providers don't populate optional token detail fields
+    # (cached_tokens, reasoning_tokens), and the OpenAI SDK's generated
+    # code can bypass Pydantic validation (e.g., via model_construct),
+    # allowing None values. We normalize these to 0 to prevent TypeErrors.
+
+    # Test entire objects being None (BeforeValidator)
+    usage = Usage(
+        requests=1,
+        input_tokens=100,
+        input_tokens_details=None,  # type: ignore[arg-type]
+        output_tokens=50,
+        output_tokens_details=None,  # type: ignore[arg-type]
+        total_tokens=150,
+    )
+    assert usage.input_tokens_details.cached_tokens == 0
+    assert usage.output_tokens_details.reasoning_tokens == 0
+
+    # Test fields within objects being None (__post_init__)
     input_details = InputTokensDetails(cached_tokens=0)
     input_details.__dict__["cached_tokens"] = None
 
@@ -289,3 +307,33 @@ def test_usage_normalizes_none_token_details():
     # __post_init__ should normalize None to 0
     assert usage.input_tokens_details.cached_tokens == 0
     assert usage.output_tokens_details.reasoning_tokens == 0
+
+
+def test_usage_normalizes_chat_completions_types():
+    # Chat Completions API uses PromptTokensDetails and CompletionTokensDetails,
+    # while Usage expects InputTokensDetails and OutputTokensDetails (Responses API).
+    # The BeforeValidator should convert between these types.
+
+    prompt_details = PromptTokensDetails(audio_tokens=10, cached_tokens=50)
+    completion_details = CompletionTokensDetails(
+        accepted_prediction_tokens=5,
+        audio_tokens=10,
+        reasoning_tokens=100,
+        rejected_prediction_tokens=2,
+    )
+
+    usage = Usage(
+        requests=1,
+        input_tokens=200,
+        input_tokens_details=prompt_details,  # type: ignore[arg-type]
+        output_tokens=150,
+        output_tokens_details=completion_details,  # type: ignore[arg-type]
+        total_tokens=350,
+    )
+
+    # Should convert to Responses API types, extracting the relevant fields
+    assert isinstance(usage.input_tokens_details, InputTokensDetails)
+    assert usage.input_tokens_details.cached_tokens == 50
+
+    assert isinstance(usage.output_tokens_details, OutputTokensDetails)
+    assert usage.output_tokens_details.reasoning_tokens == 100