Skip to content

Commit f161dd3

Browse files
committed
feat: Add on_stream to agents as tools
1 parent 71fa12c commit f161dd3

File tree

5 files changed

+460
-15
lines changed

5 files changed

+460
-15
lines changed

examples/agent_patterns/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ The mental model for handoffs is that the new agent "takes over". It sees the pr
2828
For example, you could model the translation task above as tool calls instead: rather than handing over to the language-specific agent, you could call the agent as a tool, and then use the result in the next step. This enables things like translating multiple languages at once.
2929

3030
See the [`agents_as_tools.py`](./agents_as_tools.py) file for an example of this.
31+
See the [`agents_as_tools_streaming.py`](./agents_as_tools_streaming.py) file for a streaming variant that taps into nested agent events via `on_stream`.
3132

3233
## LLM-as-a-judge
3334

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import asyncio
2+
3+
from agents import Agent, AgentToolStreamEvent, ModelSettings, Runner, function_tool, trace
4+
5+
6+
@function_tool(
7+
name_override="billing_status_checker",
8+
description_override="Answer questions about customer billing status.",
9+
)
10+
def billing_status_checker(customer_id: str | None = None, question: str = "") -> str:
11+
"""Return a canned billing answer or a fallback when the question is unrelated."""
12+
normalized = question.lower()
13+
if "bill" in normalized or "billing" in normalized:
14+
return f"This customer (ID: {customer_id})'s bill is $100"
15+
return "I can only answer questions about billing."
16+
17+
18+
def handle_stream(event: AgentToolStreamEvent) -> None:
19+
"""Print streaming events emitted by the nested billing agent."""
20+
stream = event["event"]
21+
print(f"[stream] agent={event['agent_name']} type={stream.type} {stream}")
22+
23+
24+
async def main() -> None:
25+
with trace("Agents as tools streaming example"):
26+
billing_agent = Agent(
27+
name="Billing Agent",
28+
instructions="You are a billing agent that answers billing questions.",
29+
model_settings=ModelSettings(tool_choice="required"),
30+
tools=[billing_status_checker],
31+
)
32+
33+
billing_agent_tool = billing_agent.as_tool(
34+
tool_name="billing_agent",
35+
tool_description="You are a billing agent that answers billing questions.",
36+
on_stream=handle_stream,
37+
)
38+
39+
main_agent = Agent(
40+
name="Customer Support Agent",
41+
instructions=(
42+
"You are a customer support agent. Always call the billing agent to answer billing "
43+
"questions and return the billing agent response to the user."
44+
),
45+
tools=[billing_agent_tool],
46+
)
47+
48+
result = await Runner.run(
49+
main_agent,
50+
"Hello, my customer ID is ABC123. How much is my bill for this month?",
51+
)
52+
53+
print(f"\nFinal response:\n{result.final_output}")
54+
55+
56+
if __name__ == "__main__":
57+
asyncio.run(main())

src/agents/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .agent import (
99
Agent,
1010
AgentBase,
11+
AgentToolStreamEvent,
1112
StopAtTools,
1213
ToolsToFinalOutputFunction,
1314
ToolsToFinalOutputResult,
@@ -214,6 +215,7 @@ def enable_verbose_stdout_logging():
214215
__all__ = [
215216
"Agent",
216217
"AgentBase",
218+
"AgentToolStreamEvent",
217219
"StopAtTools",
218220
"ToolsToFinalOutputFunction",
219221
"ToolsToFinalOutputResult",

src/agents/agent.py

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@
3232
from .lifecycle import AgentHooks, RunHooks
3333
from .mcp import MCPServer
3434
from .memory.session import Session
35-
from .result import RunResult
35+
from .result import RunResult, RunResultStreaming
3636
from .run import RunConfig
37+
from .stream_events import StreamEvent
3738

3839

3940
@dataclass
@@ -58,6 +59,19 @@ class ToolsToFinalOutputResult:
5859
"""
5960

6061

62+
class AgentToolStreamEvent(TypedDict):
63+
"""Streaming event emitted when an agent is invoked as a tool."""
64+
65+
event: StreamEvent
66+
"""The streaming event from the nested agent run."""
67+
68+
agent_name: str
69+
"""The name of the nested agent emitting the event."""
70+
71+
tool_call_id: str | None
72+
"""The originating tool call ID, if available."""
73+
74+
6175
class StopAtTools(TypedDict):
6276
stop_at_tool_names: list[str]
6377
"""A list of tool names, any of which will stop the agent from running further."""
@@ -382,9 +396,12 @@ def as_tool(
382396
self,
383397
tool_name: str | None,
384398
tool_description: str | None,
385-
custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
399+
custom_output_extractor: (
400+
Callable[[RunResult | RunResultStreaming], Awaitable[str]] | None
401+
) = None,
386402
is_enabled: bool
387403
| Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
404+
on_stream: Callable[[AgentToolStreamEvent], MaybeAwaitable[None]] | None = None,
388405
run_config: RunConfig | None = None,
389406
max_turns: int | None = None,
390407
hooks: RunHooks[TContext] | None = None,
@@ -409,6 +426,8 @@ def as_tool(
409426
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
410427
context and agent and returns whether the tool is enabled. Disabled tools are hidden
411428
from the LLM at runtime.
429+
on_stream: Optional callback (sync or async) to receive streaming events from the nested
430+
agent run. When provided, the nested agent is executed in streaming mode.
412431
"""
413432

414433
@function_tool(
@@ -421,21 +440,49 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any:
421440

422441
resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS
423442

424-
output = await Runner.run(
425-
starting_agent=self,
426-
input=input,
427-
context=context.context,
428-
run_config=run_config,
429-
max_turns=resolved_max_turns,
430-
hooks=hooks,
431-
previous_response_id=previous_response_id,
432-
conversation_id=conversation_id,
433-
session=session,
434-
)
443+
if on_stream is not None:
444+
run_result = Runner.run_streamed(
445+
starting_agent=self,
446+
input=input,
447+
context=context.context,
448+
run_config=run_config,
449+
max_turns=resolved_max_turns,
450+
hooks=hooks,
451+
previous_response_id=previous_response_id,
452+
conversation_id=conversation_id,
453+
session=session,
454+
)
455+
async for event in run_result.stream_events():
456+
payload: AgentToolStreamEvent = {
457+
"event": event,
458+
"agent_name": self.name,
459+
"tool_call_id": getattr(context, "tool_call_id", None),
460+
}
461+
try:
462+
maybe_result = on_stream(payload)
463+
if inspect.isawaitable(maybe_result):
464+
await maybe_result
465+
except Exception:
466+
logger.exception(
467+
"Error while handling on_stream event for agent tool %s.",
468+
self.name,
469+
)
470+
else:
471+
run_result = await Runner.run(
472+
starting_agent=self,
473+
input=input,
474+
context=context.context,
475+
run_config=run_config,
476+
max_turns=resolved_max_turns,
477+
hooks=hooks,
478+
previous_response_id=previous_response_id,
479+
conversation_id=conversation_id,
480+
session=session,
481+
)
435482
if custom_output_extractor:
436-
return await custom_output_extractor(output)
483+
return await custom_output_extractor(run_result)
437484

438-
return output.final_output
485+
return run_result.final_output
439486

440487
return run_agent
441488

0 commit comments

Comments
 (0)