-
Notifications
You must be signed in to change notification settings - Fork 3k
feat: Add on_stream to agents as tools #2169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
f161dd3
4ce48a1
7f1672a
94d2239
dbcb6e3
a8104bf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| import asyncio | ||
|
|
||
| from agents import Agent, AgentToolStreamEvent, ModelSettings, Runner, function_tool, trace | ||
|
|
||
|
|
||
@function_tool(
    name_override="billing_status_checker",
    description_override="Answer questions about customer billing status.",
)
def billing_status_checker(customer_id: str | None = None, question: str = "") -> str:
    """Return a canned billing answer or a fallback when the question is unrelated.

    Args:
        customer_id: Optional customer identifier echoed back in the canned answer.
        question: Free-form user question, matched case-insensitively.

    Returns:
        A canned billing amount when the question mentions billing, otherwise a
        message saying only billing questions are supported.
    """
    # "billing" contains "bill", so a single substring check covers both words;
    # the original `or "billing" in normalized` clause was redundant.
    if "bill" in question.lower():
        return f"This customer (ID: {customer_id})'s bill is $100"
    return "I can only answer questions about billing."
|
|
||
|
|
||
def handle_stream(event: AgentToolStreamEvent) -> None:
    """Print streaming events emitted by the nested billing agent."""
    stream_event = event["event"]
    originating_call = event.get("tool_call")
    # Fall back to a placeholder when the event carries no originating tool call.
    call_id = "unknown" if originating_call is None else originating_call.call_id
    agent_name = event["agent"].name
    print(f"[stream] agent={agent_name} call={call_id} type={stream_event.type} {stream_event}")
|
|
||
|
|
||
async def main() -> None:
    """Run a support agent whose nested billing-agent tool streams its events."""
    with trace("Agents as tools streaming example"):
        # Nested agent; tool_choice="required" forces it to call its billing tool.
        billing = Agent(
            name="Billing Agent",
            instructions="You are a billing agent that answers billing questions.",
            model_settings=ModelSettings(tool_choice="required"),
            tools=[billing_status_checker],
        )

        # Expose the nested agent as a tool; on_stream forwards its stream events.
        billing_as_tool = billing.as_tool(
            tool_name="billing_agent",
            tool_description="You are a billing agent that answers billing questions.",
            on_stream=handle_stream,
        )

        support = Agent(
            name="Customer Support Agent",
            instructions=(
                "You are a customer support agent. Always call the billing agent to answer billing "
                "questions and return the billing agent response to the user."
            ),
            tools=[billing_as_tool],
        )

        run = await Runner.run(
            support,
            "Hello, my customer ID is ABC123. How much is my bill for this month?",
        )

        print(f"\nFinal response:\n{run.final_output}")
|
|
||
|
|
||
# Script entry point: run the async example to completion.
if __name__ == "__main__":
    asyncio.run(main())
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,11 +29,14 @@ | |
| from .util._types import MaybeAwaitable | ||
|
|
||
| if TYPE_CHECKING: | ||
| from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall | ||
|
|
||
| from .lifecycle import AgentHooks, RunHooks | ||
| from .mcp import MCPServer | ||
| from .memory.session import Session | ||
| from .result import RunResult | ||
| from .result import RunResult, RunResultStreaming | ||
| from .run import RunConfig | ||
| from .stream_events import StreamEvent | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -58,6 +61,19 @@ class ToolsToFinalOutputResult: | |
| """ | ||
|
|
||
|
|
||
class AgentToolStreamEvent(TypedDict):
    """Streaming event emitted when an agent is invoked as a tool.

    Passed to the `on_stream` callback of `as_tool` for every stream event
    produced by the nested agent run.
    """

    event: StreamEvent
    """The streaming event from the nested agent run."""

    agent: Agent[Any]
    """The nested agent emitting the event."""

    tool_call: ResponseFunctionToolCall | None
    """The originating tool call, if available."""
|
|
||
|
|
||
| class StopAtTools(TypedDict): | ||
| stop_at_tool_names: list[str] | ||
| """A list of tool names, any of which will stop the agent from running further.""" | ||
|
|
@@ -382,9 +398,12 @@ def as_tool( | |
| self, | ||
| tool_name: str | None, | ||
| tool_description: str | None, | ||
| custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None, | ||
| custom_output_extractor: ( | ||
| Callable[[RunResult | RunResultStreaming], Awaitable[str]] | None | ||
| ) = None, | ||
| is_enabled: bool | ||
| | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True, | ||
| on_stream: Callable[[AgentToolStreamEvent], MaybeAwaitable[None]] | None = None, | ||
| run_config: RunConfig | None = None, | ||
| max_turns: int | None = None, | ||
| hooks: RunHooks[TContext] | None = None, | ||
|
|
@@ -409,6 +428,10 @@ def as_tool( | |
| is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run | ||
| context and agent and returns whether the tool is enabled. Disabled tools are hidden | ||
| from the LLM at runtime. | ||
| on_stream: Optional callback (sync or async) to receive streaming events from the nested | ||
| agent run. The callback receives an `AgentToolStreamEvent` containing the nested | ||
| agent, the originating tool call (when available), and each stream event. When | ||
| provided, the nested agent is executed in streaming mode. | ||
| """ | ||
|
|
||
| @function_tool( | ||
|
|
@@ -420,22 +443,79 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any: | |
| from .run import DEFAULT_MAX_TURNS, Runner | ||
|
|
||
| resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS | ||
|
|
||
| output = await Runner.run( | ||
| starting_agent=self, | ||
| input=input, | ||
| context=context.context, | ||
| run_config=run_config, | ||
| max_turns=resolved_max_turns, | ||
| hooks=hooks, | ||
| previous_response_id=previous_response_id, | ||
| conversation_id=conversation_id, | ||
| session=session, | ||
| ) | ||
| run_result: RunResult | RunResultStreaming | ||
|
|
||
| if on_stream is not None: | ||
| run_result = Runner.run_streamed( | ||
| starting_agent=self, | ||
| input=input, | ||
| context=context.context, | ||
| run_config=run_config, | ||
| max_turns=resolved_max_turns, | ||
| hooks=hooks, | ||
| previous_response_id=previous_response_id, | ||
| conversation_id=conversation_id, | ||
| session=session, | ||
| ) | ||
| # Dispatch callbacks in the background so slow handlers do not block | ||
| # event consumption. | ||
| event_queue: asyncio.Queue[AgentToolStreamEvent | None] = asyncio.Queue() | ||
|
|
||
| async def _run_handler(payload: AgentToolStreamEvent) -> None: | ||
| """Execute the user callback while capturing exceptions.""" | ||
| try: | ||
| maybe_result = on_stream(payload) | ||
| if inspect.isawaitable(maybe_result): | ||
| await maybe_result | ||
| except Exception: | ||
| logger.exception( | ||
| "Error while handling on_stream event for agent tool %s.", | ||
| self.name, | ||
| ) | ||
|
|
||
| async def dispatch_stream_events() -> None: | ||
| while True: | ||
| payload = await event_queue.get() | ||
| is_sentinel = payload is None # None marks the end of the stream. | ||
| try: | ||
| if payload is not None: | ||
| await _run_handler(payload) | ||
| finally: | ||
| event_queue.task_done() | ||
|
|
||
| if is_sentinel: | ||
| break | ||
|
|
||
| dispatch_task = asyncio.create_task(dispatch_stream_events()) | ||
|
|
||
| try: | ||
| async for event in run_result.stream_events(): | ||
| payload: AgentToolStreamEvent = { | ||
| "event": event, | ||
| "agent": self, | ||
| "tool_call": getattr(context, "tool_call", None), | ||
|
||
| } | ||
| await event_queue.put(payload) | ||
| finally: | ||
| await event_queue.put(None) | ||
| await event_queue.join() | ||
| await dispatch_task | ||
| else: | ||
| run_result = await Runner.run( | ||
| starting_agent=self, | ||
| input=input, | ||
| context=context.context, | ||
| run_config=run_config, | ||
| max_turns=resolved_max_turns, | ||
| hooks=hooks, | ||
| previous_response_id=previous_response_id, | ||
| conversation_id=conversation_id, | ||
| session=session, | ||
| ) | ||
| if custom_output_extractor: | ||
| return await custom_output_extractor(output) | ||
| return await custom_output_extractor(run_result) | ||
|
|
||
| return output.final_output | ||
| return run_result.final_output | ||
|
|
||
| return run_agent | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this can be kinda bad since `on_stream` will block you from going to the next event. Fine for now, but we should consider a queue-based pattern where we write to a queue from here and the consumer reads from the queue, so we aren't blocking.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call — I actually made these executions async in the TS SDK, so I will make this more efficient and consistent with that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
improved by 94d2239