Skip to content

Commit 4ce48a1

Browse files
committed
fix mypy errors
1 parent f161dd3 commit 4ce48a1

File tree

3 files changed

+64
-43
lines changed

3 files changed

+64
-43
lines changed

examples/financial_research_agent/manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from rich.console import Console
88

9-
from agents import Runner, RunResult, custom_span, gen_trace_id, trace
9+
from agents import Runner, RunResult, RunResultStreaming, custom_span, gen_trace_id, trace
1010

1111
from .agents.financials_agent import financials_agent
1212
from .agents.planner_agent import FinancialSearchItem, FinancialSearchPlan, planner_agent
@@ -17,7 +17,7 @@
1717
from .printer import Printer
1818

1919

20-
async def _summary_extractor(run_result: RunResult) -> str:
20+
async def _summary_extractor(run_result: RunResult | RunResultStreaming) -> str:
2121
"""Custom output extractor for sub‑agents that return an AnalysisSummary."""
2222
# The financial/risk analyst agents emit an AnalysisSummary with a `summary` field.
2323
# We want the tool call to return just that summary text so the writer can drop it inline.

src/agents/agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any:
439439
from .run import DEFAULT_MAX_TURNS, Runner
440440

441441
resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS
442+
run_result: RunResult | RunResultStreaming
442443

443444
if on_stream is not None:
444445
run_result = Runner.run_streamed(

tests/test_agent_as_tool.py

Lines changed: 61 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import Any, cast
44

55
import pytest
66
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
@@ -9,6 +9,7 @@
99
from agents import (
1010
Agent,
1111
AgentBase,
12+
AgentToolStreamEvent,
1213
FunctionTool,
1314
MessageOutputItem,
1415
RunConfig,
@@ -382,8 +383,8 @@ async def test_agent_as_tool_streams_events_with_on_stream(
382383
) -> None:
383384
agent = Agent(name="streamer")
384385
stream_events = [
385-
RawResponsesStreamEvent(data={"type": "response_started"}),
386-
RawResponsesStreamEvent(data={"type": "output_text_delta", "delta": "hi"}),
386+
RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})),
387+
RawResponsesStreamEvent(data=cast(Any, {"type": "output_text_delta", "delta": "hi"})),
387388
]
388389

389390
class DummyStreamingResult:
@@ -431,15 +432,18 @@ async def unexpected_run(*args: Any, **kwargs: Any) -> None:
431432
monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed))
432433
monkeypatch.setattr(Runner, "run", classmethod(unexpected_run))
433434

434-
received_events: list[dict[str, Any]] = []
435+
received_events: list[AgentToolStreamEvent] = []
435436

436-
async def on_stream(payload: dict[str, Any]) -> None:
437+
async def on_stream(payload: AgentToolStreamEvent) -> None:
437438
received_events.append(payload)
438439

439-
tool = agent.as_tool(
440-
tool_name="stream_tool",
441-
tool_description="Streams events",
442-
on_stream=on_stream,
440+
tool = cast(
441+
FunctionTool,
442+
agent.as_tool(
443+
tool_name="stream_tool",
444+
tool_description="Streams events",
445+
on_stream=on_stream,
446+
),
443447
)
444448

445449
tool_context = ToolContext(
@@ -463,7 +467,8 @@ async def test_agent_as_tool_streaming_works_with_custom_extractor(
463467
monkeypatch: pytest.MonkeyPatch,
464468
) -> None:
465469
agent = Agent(name="streamer")
466-
stream_events = [RawResponsesStreamEvent(data={"type": "response_started"})]
470+
stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))]
471+
stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))]
467472

468473
class DummyStreamingResult:
469474
def __init__(self) -> None:
@@ -505,14 +510,17 @@ async def extractor(result) -> str:
505510

506511
callbacks: list[Any] = []
507512

508-
async def on_stream(payload: dict[str, Any]) -> None:
513+
async def on_stream(payload: AgentToolStreamEvent) -> None:
509514
callbacks.append(payload["event"])
510515

511-
tool = agent.as_tool(
512-
tool_name="stream_tool",
513-
tool_description="Streams events",
514-
custom_output_extractor=extractor,
515-
on_stream=on_stream,
516+
tool = cast(
517+
FunctionTool,
518+
agent.as_tool(
519+
tool_name="stream_tool",
520+
tool_description="Streams events",
521+
custom_output_extractor=extractor,
522+
on_stream=on_stream,
523+
),
516524
)
517525

518526
tool_context = ToolContext(
@@ -539,7 +547,7 @@ def __init__(self) -> None:
539547
self.final_output = "ok"
540548

541549
async def stream_events(self):
542-
yield RawResponsesStreamEvent(data={"type": "response_started"})
550+
yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))
543551

544552
monkeypatch.setattr(
545553
Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult())
@@ -552,13 +560,16 @@ async def stream_events(self):
552560

553561
calls: list[str] = []
554562

555-
def sync_handler(event: dict[str, Any]) -> None:
563+
def sync_handler(event: AgentToolStreamEvent) -> None:
556564
calls.append(event["event"].type)
557565

558-
tool = agent.as_tool(
559-
tool_name="sync_tool",
560-
tool_description="Uses sync handler",
561-
on_stream=sync_handler,
566+
tool = cast(
567+
FunctionTool,
568+
agent.as_tool(
569+
tool_name="sync_tool",
570+
tool_description="Uses sync handler",
571+
on_stream=sync_handler,
572+
),
562573
)
563574
tool_context = ToolContext(
564575
context=None,
@@ -584,7 +595,7 @@ def __init__(self) -> None:
584595
self.final_output = "ok"
585596

586597
async def stream_events(self):
587-
yield RawResponsesStreamEvent(data={"type": "response_started"})
598+
yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))
588599

589600
monkeypatch.setattr(
590601
Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult())
@@ -595,13 +606,16 @@ async def stream_events(self):
595606
classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))),
596607
)
597608

598-
def bad_handler(event: dict[str, Any]) -> None:
609+
def bad_handler(event: AgentToolStreamEvent) -> None:
599610
raise RuntimeError("boom")
600611

601-
tool = agent.as_tool(
602-
tool_name="error_tool",
603-
tool_description="Handler throws",
604-
on_stream=bad_handler,
612+
tool = cast(
613+
FunctionTool,
614+
agent.as_tool(
615+
tool_name="error_tool",
616+
tool_description="Handler throws",
617+
on_stream=bad_handler,
618+
),
605619
)
606620
tool_context = ToolContext(
607621
context=None,
@@ -651,9 +665,12 @@ async def fake_run(
651665
classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no stream"))),
652666
)
653667

654-
tool = agent.as_tool(
655-
tool_name="nostream_tool",
656-
tool_description="No streaming path",
668+
tool = cast(
669+
FunctionTool,
670+
agent.as_tool(
671+
tool_name="nostream_tool",
672+
tool_description="No streaming path",
673+
),
657674
)
658675
tool_context = ToolContext(
659676
context=None,
@@ -669,7 +686,7 @@ async def fake_run(
669686

670687

671688
@pytest.mark.asyncio
672-
async def test_agent_as_tool_streaming_sets_tool_call_id_none_for_direct_invocation(
689+
async def test_agent_as_tool_streaming_sets_tool_call_id_from_context(
673690
monkeypatch: pytest.MonkeyPatch,
674691
) -> None:
675692
agent = Agent(name="direct_invocation_agent")
@@ -679,7 +696,7 @@ def __init__(self) -> None:
679696
self.final_output = "ok"
680697

681698
async def stream_events(self):
682-
yield RawResponsesStreamEvent(data={"type": "response_started"})
699+
yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))
683700

684701
monkeypatch.setattr(
685702
Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult())
@@ -690,24 +707,27 @@ async def stream_events(self):
690707
classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))),
691708
)
692709

693-
captured: list[dict[str, Any]] = []
710+
captured: list[AgentToolStreamEvent] = []
694711

695-
async def on_stream(event: dict[str, Any]) -> None:
712+
async def on_stream(event: AgentToolStreamEvent) -> None:
696713
captured.append(event)
697714

698-
tool = agent.as_tool(
699-
tool_name="direct_stream_tool",
700-
tool_description="Direct invocation",
701-
on_stream=on_stream,
715+
tool = cast(
716+
FunctionTool,
717+
agent.as_tool(
718+
tool_name="direct_stream_tool",
719+
tool_description="Direct invocation",
720+
on_stream=on_stream,
721+
),
702722
)
703723
tool_context = ToolContext(
704724
context=None,
705725
tool_name="direct_stream_tool",
706-
tool_call_id=None, # Direct invoke path does not have a tool call ID.
726+
tool_call_id="direct-call-id",
707727
tool_arguments='{"input": "hi"}',
708728
)
709729

710730
output = await tool.on_invoke_tool(tool_context, '{"input": "hi"}')
711731

712732
assert output == "ok"
713-
assert captured[0]["tool_call_id"] is None
733+
assert captured[0]["tool_call_id"] == "direct-call-id"

0 commit comments

Comments
 (0)