Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions src/agents/models/openai_chatcompletions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from openai.types.chat.chat_completion import Choice
from openai.types.responses import Response
from openai.types.responses.response_prompt_param import ResponsePromptParam
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

from .. import _debug
from ..agent_output import AgentOutputSchemaBase
Expand Down Expand Up @@ -102,18 +101,9 @@ async def get_response(
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
input_tokens_details=InputTokensDetails(
cached_tokens=getattr(
response.usage.prompt_tokens_details, "cached_tokens", 0
)
or 0,
),
output_tokens_details=OutputTokensDetails(
reasoning_tokens=getattr(
response.usage.completion_tokens_details, "reasoning_tokens", 0
)
or 0,
),
# BeforeValidator in Usage normalizes these from Chat Completions types
input_tokens_details=response.usage.prompt_tokens_details, # type: ignore[arg-type]
output_tokens_details=response.usage.completion_tokens_details, # type: ignore[arg-type]
)
if response.usage
else Usage()
Expand Down
41 changes: 34 additions & 7 deletions src/agents/usage.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,36 @@
from __future__ import annotations

from dataclasses import field
from typing import Annotated

from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
from pydantic import BeforeValidator
from pydantic.dataclasses import dataclass


def _normalize_input_tokens_details(
    v: InputTokensDetails | PromptTokensDetails | None,
) -> InputTokensDetails:
    """Coerce a missing or Chat Completions-style value into InputTokensDetails.

    The Responses API reports ``InputTokensDetails`` while Chat Completions
    reports ``PromptTokensDetails``; some providers omit the details entirely.
    """
    if isinstance(v, PromptTokensDetails):
        # Chat Completions type: keep only the shared field, mapping a missing
        # cached_tokens to 0.
        return InputTokensDetails(cached_tokens=v.cached_tokens or 0)
    if v is None:
        # Provider did not populate token details at all.
        return InputTokensDetails(cached_tokens=0)
    return v


def _normalize_output_tokens_details(
    v: OutputTokensDetails | CompletionTokensDetails | None,
) -> OutputTokensDetails:
    """Coerce a missing or Chat Completions-style value into OutputTokensDetails.

    The Responses API reports ``OutputTokensDetails`` while Chat Completions
    reports ``CompletionTokensDetails``; some providers omit the details
    entirely.
    """
    if isinstance(v, CompletionTokensDetails):
        # Chat Completions type: keep only the shared field, mapping a missing
        # reasoning_tokens to 0.
        return OutputTokensDetails(reasoning_tokens=v.reasoning_tokens or 0)
    if v is None:
        # Provider did not populate token details at all.
        return OutputTokensDetails(reasoning_tokens=0)
    return v


@dataclass
class RequestUsage:
"""Usage details for a single API request."""
Expand Down Expand Up @@ -32,16 +59,16 @@ class Usage:
input_tokens: int = 0
"""Total input tokens sent, across all requests."""

input_tokens_details: InputTokensDetails = field(
default_factory=lambda: InputTokensDetails(cached_tokens=0)
)
input_tokens_details: Annotated[
InputTokensDetails, BeforeValidator(_normalize_input_tokens_details)
] = field(default_factory=lambda: InputTokensDetails(cached_tokens=0))
"""Details about the input tokens, matching responses API usage details."""
output_tokens: int = 0
"""Total output tokens received, across all requests."""

output_tokens_details: OutputTokensDetails = field(
default_factory=lambda: OutputTokensDetails(reasoning_tokens=0)
)
output_tokens_details: Annotated[
OutputTokensDetails, BeforeValidator(_normalize_output_tokens_details)
] = field(default_factory=lambda: OutputTokensDetails(reasoning_tokens=0))
"""Details about the output tokens, matching responses API usage details."""

total_tokens: int = 0
Expand Down Expand Up @@ -70,7 +97,7 @@ def __post_init__(self) -> None:
if self.output_tokens_details.reasoning_tokens is None:
self.output_tokens_details = OutputTokensDetails(reasoning_tokens=0)

def add(self, other: "Usage") -> None:
def add(self, other: Usage) -> None:
"""Add another Usage object to this one, aggregating all fields.

This method automatically preserves request_usage_entries.
Expand Down
50 changes: 49 additions & 1 deletion tests/test_usage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

from agents.usage import RequestUsage, Usage
Expand Down Expand Up @@ -270,7 +271,24 @@ def test_anthropic_cost_calculation_scenario():


def test_usage_normalizes_none_token_details():
# Some providers don't populate optional fields, resulting in None values
# Some providers don't populate optional token detail fields
# (cached_tokens, reasoning_tokens), and the OpenAI SDK's generated
# code can bypass Pydantic validation (e.g., via model_construct),
# allowing None values. We normalize these to 0 to prevent TypeErrors.

# Test entire objects being None (BeforeValidator)
usage = Usage(
requests=1,
input_tokens=100,
input_tokens_details=None, # type: ignore[arg-type]
output_tokens=50,
output_tokens_details=None, # type: ignore[arg-type]
total_tokens=150,
)
assert usage.input_tokens_details.cached_tokens == 0
assert usage.output_tokens_details.reasoning_tokens == 0

# Test fields within objects being None (__post_init__)
input_details = InputTokensDetails(cached_tokens=0)
input_details.__dict__["cached_tokens"] = None

Expand All @@ -289,3 +307,33 @@ def test_usage_normalizes_none_token_details():
# __post_init__ should normalize None to 0
assert usage.input_tokens_details.cached_tokens == 0
assert usage.output_tokens_details.reasoning_tokens == 0


def test_usage_normalizes_chat_completions_types():
    # The Chat Completions API reports PromptTokensDetails /
    # CompletionTokensDetails, while Usage stores the Responses API's
    # InputTokensDetails / OutputTokensDetails. The BeforeValidator on the
    # Usage fields should convert between the two transparently.
    chat_input_details = PromptTokensDetails(audio_tokens=10, cached_tokens=50)
    chat_output_details = CompletionTokensDetails(
        accepted_prediction_tokens=5,
        audio_tokens=10,
        reasoning_tokens=100,
        rejected_prediction_tokens=2,
    )

    usage = Usage(
        requests=1,
        input_tokens=200,
        input_tokens_details=chat_input_details,  # type: ignore[arg-type]
        output_tokens=150,
        output_tokens_details=chat_output_details,  # type: ignore[arg-type]
        total_tokens=350,
    )

    # Only the fields shared with the Responses API types survive conversion.
    assert isinstance(usage.input_tokens_details, InputTokensDetails)
    assert usage.input_tokens_details.cached_tokens == 50

    assert isinstance(usage.output_tokens_details, OutputTokensDetails)
    assert usage.output_tokens_details.reasoning_tokens == 100