11from __future__ import annotations
22
3- from typing import Any
3+ from typing import Any , cast
44
55import pytest
66from openai .types .responses import ResponseOutputMessage , ResponseOutputText
99from agents import (
1010 Agent ,
1111 AgentBase ,
12+ AgentToolStreamEvent ,
1213 FunctionTool ,
1314 MessageOutputItem ,
1415 RunConfig ,
@@ -382,8 +383,8 @@ async def test_agent_as_tool_streams_events_with_on_stream(
382383) -> None :
383384 agent = Agent (name = "streamer" )
384385 stream_events = [
385- RawResponsesStreamEvent (data = {"type" : "response_started" }),
386- RawResponsesStreamEvent (data = {"type" : "output_text_delta" , "delta" : "hi" }),
386+ RawResponsesStreamEvent (data = cast ( Any , {"type" : "response_started" }) ),
387+ RawResponsesStreamEvent (data = cast ( Any , {"type" : "output_text_delta" , "delta" : "hi" }) ),
387388 ]
388389
389390 class DummyStreamingResult :
@@ -431,15 +432,18 @@ async def unexpected_run(*args: Any, **kwargs: Any) -> None:
431432 monkeypatch .setattr (Runner , "run_streamed" , classmethod (fake_run_streamed ))
432433 monkeypatch .setattr (Runner , "run" , classmethod (unexpected_run ))
433434
434- received_events : list [dict [ str , Any ] ] = []
435+ received_events : list [AgentToolStreamEvent ] = []
435436
436- async def on_stream (payload : dict [ str , Any ] ) -> None :
437+ async def on_stream (payload : AgentToolStreamEvent ) -> None :
437438 received_events .append (payload )
438439
439- tool = agent .as_tool (
440- tool_name = "stream_tool" ,
441- tool_description = "Streams events" ,
442- on_stream = on_stream ,
440+ tool = cast (
441+ FunctionTool ,
442+ agent .as_tool (
443+ tool_name = "stream_tool" ,
444+ tool_description = "Streams events" ,
445+ on_stream = on_stream ,
446+ ),
443447 )
444448
445449 tool_context = ToolContext (
@@ -463,7 +467,8 @@ async def test_agent_as_tool_streaming_works_with_custom_extractor(
463467 monkeypatch : pytest .MonkeyPatch ,
464468) -> None :
465469 agent = Agent (name = "streamer" )
466- stream_events = [RawResponsesStreamEvent (data = {"type" : "response_started" })]
470+ stream_events = [RawResponsesStreamEvent (data = cast (Any , {"type" : "response_started" }))]
471+ stream_events = [RawResponsesStreamEvent (data = cast (Any , {"type" : "response_started" }))]
467472
468473 class DummyStreamingResult :
469474 def __init__ (self ) -> None :
@@ -505,14 +510,17 @@ async def extractor(result) -> str:
505510
506511 callbacks : list [Any ] = []
507512
508- async def on_stream (payload : dict [ str , Any ] ) -> None :
513+ async def on_stream (payload : AgentToolStreamEvent ) -> None :
509514 callbacks .append (payload ["event" ])
510515
511- tool = agent .as_tool (
512- tool_name = "stream_tool" ,
513- tool_description = "Streams events" ,
514- custom_output_extractor = extractor ,
515- on_stream = on_stream ,
516+ tool = cast (
517+ FunctionTool ,
518+ agent .as_tool (
519+ tool_name = "stream_tool" ,
520+ tool_description = "Streams events" ,
521+ custom_output_extractor = extractor ,
522+ on_stream = on_stream ,
523+ ),
516524 )
517525
518526 tool_context = ToolContext (
@@ -539,7 +547,7 @@ def __init__(self) -> None:
539547 self .final_output = "ok"
540548
541549 async def stream_events (self ):
542- yield RawResponsesStreamEvent (data = {"type" : "response_started" })
550+ yield RawResponsesStreamEvent (data = cast ( Any , {"type" : "response_started" }) )
543551
544552 monkeypatch .setattr (
545553 Runner , "run_streamed" , classmethod (lambda * args , ** kwargs : DummyStreamingResult ())
@@ -552,13 +560,16 @@ async def stream_events(self):
552560
553561 calls : list [str ] = []
554562
555- def sync_handler (event : dict [ str , Any ] ) -> None :
563+ def sync_handler (event : AgentToolStreamEvent ) -> None :
556564 calls .append (event ["event" ].type )
557565
558- tool = agent .as_tool (
559- tool_name = "sync_tool" ,
560- tool_description = "Uses sync handler" ,
561- on_stream = sync_handler ,
566+ tool = cast (
567+ FunctionTool ,
568+ agent .as_tool (
569+ tool_name = "sync_tool" ,
570+ tool_description = "Uses sync handler" ,
571+ on_stream = sync_handler ,
572+ ),
562573 )
563574 tool_context = ToolContext (
564575 context = None ,
@@ -584,7 +595,7 @@ def __init__(self) -> None:
584595 self .final_output = "ok"
585596
586597 async def stream_events (self ):
587- yield RawResponsesStreamEvent (data = {"type" : "response_started" })
598+ yield RawResponsesStreamEvent (data = cast ( Any , {"type" : "response_started" }) )
588599
589600 monkeypatch .setattr (
590601 Runner , "run_streamed" , classmethod (lambda * args , ** kwargs : DummyStreamingResult ())
@@ -595,13 +606,16 @@ async def stream_events(self):
595606 classmethod (lambda * args , ** kwargs : (_ for _ in ()).throw (AssertionError ("no run" ))),
596607 )
597608
598- def bad_handler (event : dict [ str , Any ] ) -> None :
609+ def bad_handler (event : AgentToolStreamEvent ) -> None :
599610 raise RuntimeError ("boom" )
600611
601- tool = agent .as_tool (
602- tool_name = "error_tool" ,
603- tool_description = "Handler throws" ,
604- on_stream = bad_handler ,
612+ tool = cast (
613+ FunctionTool ,
614+ agent .as_tool (
615+ tool_name = "error_tool" ,
616+ tool_description = "Handler throws" ,
617+ on_stream = bad_handler ,
618+ ),
605619 )
606620 tool_context = ToolContext (
607621 context = None ,
@@ -651,9 +665,12 @@ async def fake_run(
651665 classmethod (lambda * args , ** kwargs : (_ for _ in ()).throw (AssertionError ("no stream" ))),
652666 )
653667
654- tool = agent .as_tool (
655- tool_name = "nostream_tool" ,
656- tool_description = "No streaming path" ,
668+ tool = cast (
669+ FunctionTool ,
670+ agent .as_tool (
671+ tool_name = "nostream_tool" ,
672+ tool_description = "No streaming path" ,
673+ ),
657674 )
658675 tool_context = ToolContext (
659676 context = None ,
@@ -669,7 +686,7 @@ async def fake_run(
669686
670687
671688@pytest .mark .asyncio
672- async def test_agent_as_tool_streaming_sets_tool_call_id_none_for_direct_invocation (
689+ async def test_agent_as_tool_streaming_sets_tool_call_id_from_context (
673690 monkeypatch : pytest .MonkeyPatch ,
674691) -> None :
675692 agent = Agent (name = "direct_invocation_agent" )
@@ -679,7 +696,7 @@ def __init__(self) -> None:
679696 self .final_output = "ok"
680697
681698 async def stream_events (self ):
682- yield RawResponsesStreamEvent (data = {"type" : "response_started" })
699+ yield RawResponsesStreamEvent (data = cast ( Any , {"type" : "response_started" }) )
683700
684701 monkeypatch .setattr (
685702 Runner , "run_streamed" , classmethod (lambda * args , ** kwargs : DummyStreamingResult ())
@@ -690,24 +707,27 @@ async def stream_events(self):
690707 classmethod (lambda * args , ** kwargs : (_ for _ in ()).throw (AssertionError ("no run" ))),
691708 )
692709
693- captured : list [dict [ str , Any ] ] = []
710+ captured : list [AgentToolStreamEvent ] = []
694711
695- async def on_stream (event : dict [ str , Any ] ) -> None :
712+ async def on_stream (event : AgentToolStreamEvent ) -> None :
696713 captured .append (event )
697714
698- tool = agent .as_tool (
699- tool_name = "direct_stream_tool" ,
700- tool_description = "Direct invocation" ,
701- on_stream = on_stream ,
715+ tool = cast (
716+ FunctionTool ,
717+ agent .as_tool (
718+ tool_name = "direct_stream_tool" ,
719+ tool_description = "Direct invocation" ,
720+ on_stream = on_stream ,
721+ ),
702722 )
703723 tool_context = ToolContext (
704724 context = None ,
705725 tool_name = "direct_stream_tool" ,
706- tool_call_id = None , # Direct invoke path does not have a tool call ID.
726+ tool_call_id = "direct- call-id" ,
707727 tool_arguments = '{"input": "hi"}' ,
708728 )
709729
710730 output = await tool .on_invoke_tool (tool_context , '{"input": "hi"}' )
711731
712732 assert output == "ok"
713- assert captured [0 ]["tool_call_id" ] is None
733+ assert captured [0 ]["tool_call_id" ] == "direct-call-id"
0 commit comments