@@ -117,6 +117,7 @@ def __init__(
117117 wrapped : AbstractAgent [AgentDepsT , OutputDataT ],
118118 * ,
119119 name : str | None = None ,
120+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | Mapping [str , AbstractToolset [AgentDepsT ]] | None = None ,
120121 models : Mapping [str , Model ] | None = None ,
121122 provider_factory : TemporalProviderFactory | None = None ,
122123 event_stream_handler : EventStreamHandler [AgentDepsT ] | None = None ,
@@ -144,6 +145,11 @@ def __init__(
144145 Args:
145146 wrapped: The agent to wrap.
146147 name: Optional unique agent name to use in the Temporal activities' names. If not provided, the agent's `name` will be used.
148+ toolsets:
149+ Optional additional toolsets to register with the agent, or a mapping of toolset names to toolset instances.
150+ Toolsets passed here will be temporalized and their activities registered alongside the wrapped agent's existing toolsets.
151+ If a mapping is provided, toolsets can be referenced by name in `run(toolsets=['name'])`.
152+
147153 models:
148154 Optional mapping of model instances to register with the agent.
149155 Keys define the names that can be referenced at runtime and the values are `Model` instances.
@@ -191,6 +197,7 @@ def __init__(
191197 ]
192198 activity_config ['retry_policy' ] = retry_policy
193199 self .activity_config = activity_config
200+ self ._named_toolsets : Mapping [str , AbstractToolset [AgentDepsT ]] | None = None
194201
195202 model_activity_config = model_activity_config or {}
196203 toolset_activity_config = toolset_activity_config or {}
@@ -235,6 +242,18 @@ async def streamed_response():
235242 activities .extend (temporal_model .temporal_activities )
236243 self ._temporal_model = temporal_model
237244
245+ if toolsets :
246+ if isinstance (toolsets , Mapping ):
247+ # Flatten the mapping for temporalization, but keep track of names
248+ additional_toolsets = list (toolsets .values ())
249+ self ._named_toolsets = toolsets
250+ else :
251+ additional_toolsets = list (toolsets )
252+ self ._named_toolsets = {}
253+ else :
254+ additional_toolsets = []
255+ self ._named_toolsets = {}
256+
238257 def temporalize_toolset (toolset : AbstractToolset [AgentDepsT ]) -> AbstractToolset [AgentDepsT ]:
239258 id = toolset .id
240259 if id is None :
@@ -254,9 +273,35 @@ def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset
254273 activities .extend (toolset .temporal_activities )
255274 return toolset
256275
257- temporal_toolsets = [toolset .visit_and_replace (temporalize_toolset ) for toolset in wrapped .toolsets ]
276+ all_toolsets = [* wrapped .toolsets , * additional_toolsets ]
277+ temporal_toolsets = [toolset .visit_and_replace (temporalize_toolset ) for toolset in all_toolsets ]
278+
279+ # If we had named toolsets, we need to map the names to the temporalized versions
280+ # We know that visit_and_replace returns a new instance (or the same one if no replacement needed)
281+ # matching the structure of the input.
282+ # However, since we flattened everything into `all_toolsets` and then mapped it to `temporal_toolsets`,
283+ # we can reconstruct the named mapping by index.
284+ # But wait, `wrapped.toolsets` are first.
285+ if self ._named_toolsets :
286+ # The additional toolsets are at the end of `temporal_toolsets`
287+ num_wrapped_toolsets = len (wrapped .toolsets )
288+ # The temporalized additional toolsets
289+ temporal_additional_toolsets = temporal_toolsets [num_wrapped_toolsets :]
290+
291+ # create a new mapping pointing to the temporalized versions
292+ new_named_toolsets : dict [str , AbstractToolset [AgentDepsT ]] = {}
293+ for name , temporal_toolset in zip (self ._named_toolsets , temporal_additional_toolsets ):
294+ new_named_toolsets [name ] = temporal_toolset
295+ self ._named_toolsets = new_named_toolsets
296+
297+ # If toolsets were passed as a mapping, they are not added to the active toolsets by default
298+ if isinstance (toolsets , Mapping ):
299+ self ._toolsets = temporal_toolsets [:num_wrapped_toolsets ]
300+ else :
301+ self ._toolsets = temporal_toolsets
302+ else :
303+ self ._toolsets = temporal_toolsets
258304
259- self ._toolsets = temporal_toolsets
260305 self ._temporal_activities = activities
261306
262307 self ._temporal_overrides_active : ContextVar [bool ] = ContextVar ('_temporal_overrides_active' , default = False )
@@ -314,16 +359,21 @@ def temporal_activities(self) -> list[Callable[..., Any]]:
314359 @contextmanager
315360 def _temporal_overrides (
316361 self ,
317- toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
362+ toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None = None ,
318363 model : models .Model | models .KnownModelName | str | None = None ,
319364 force : bool = False ,
320365 ) -> Iterator [None ]:
321366 in_workflow = workflow .in_workflow ()
322367
323368 if toolsets :
324- if in_workflow :
325- _validate_temporal_toolsets (toolsets )
326- overridden_toolsets = [* self ._toolsets , * toolsets ]
369+ if workflow .in_workflow ():
370+ # If toolsets are provided as strings, we can't validate them directly here as they are resolved later.
371+ # We only validate if they are already AbstractToolset instances.
372+ _validate_temporal_toolsets ([t for t in toolsets if not isinstance (t , str )])
373+
374+ resolved_toolsets = self ._resolve_toolsets (toolsets )
375+ assert resolved_toolsets is not None
376+ overridden_toolsets = [* self ._toolsets , * resolved_toolsets ]
327377 else :
328378 overridden_toolsets = list (self ._toolsets )
329379
@@ -352,6 +402,26 @@ def _temporal_overrides(
352402 finally :
353403 self ._temporal_overrides_active .reset (temporal_active_token )
354404
405+ def _resolve_toolsets (
406+ self , toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None
407+ ) -> Sequence [AbstractToolset [AgentDepsT ]] | None :
408+ if toolsets is None :
409+ return None
410+
411+ resolved_toolsets : list [AbstractToolset [AgentDepsT ]] = []
412+ for t in toolsets :
413+ if isinstance (t , str ):
414+ if self ._named_toolsets is None :
415+ raise UserError (f"Unknown toolset name: '{ t } '. No named toolsets registered." )
416+ if t not in self ._named_toolsets :
417+ raise UserError (
418+ f"Unknown toolset name: '{ t } '. Available toolsets: { list (self ._named_toolsets .keys ())} "
419+ )
420+ resolved_toolsets .append (self ._named_toolsets [t ])
421+ else :
422+ resolved_toolsets .append (t )
423+ return resolved_toolsets
424+
355425 @overload
356426 async def run (
357427 self ,
@@ -367,7 +437,7 @@ async def run(
367437 usage_limits : _usage .UsageLimits | None = None ,
368438 usage : _usage .RunUsage | None = None ,
369439 infer_name : bool = True ,
370- toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
440+ toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None = None ,
371441 builtin_tools : Sequence [AbstractBuiltinTool | BuiltinToolFunc [AgentDepsT ]] | None = None ,
372442 event_stream_handler : EventStreamHandler [AgentDepsT ] | None = None ,
373443 ) -> AgentRunResult [OutputDataT ]: ...
@@ -387,7 +457,7 @@ async def run(
387457 usage_limits : _usage .UsageLimits | None = None ,
388458 usage : _usage .RunUsage | None = None ,
389459 infer_name : bool = True ,
390- toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
460+ toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None = None ,
391461 builtin_tools : Sequence [AbstractBuiltinTool | BuiltinToolFunc [AgentDepsT ]] | None = None ,
392462 event_stream_handler : EventStreamHandler [AgentDepsT ] | None = None ,
393463 ) -> AgentRunResult [RunOutputDataT ]: ...
@@ -406,7 +476,7 @@ async def run(
406476 usage_limits : _usage .UsageLimits | None = None ,
407477 usage : _usage .RunUsage | None = None ,
408478 infer_name : bool = True ,
409- toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
479+ toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None = None ,
410480 builtin_tools : Sequence [AbstractBuiltinTool | BuiltinToolFunc [AgentDepsT ]] | None = None ,
411481 event_stream_handler : EventStreamHandler [AgentDepsT ] | None = None ,
412482 ** _deprecated_kwargs : Never ,
@@ -492,7 +562,7 @@ def run_sync(
492562 usage_limits : _usage .UsageLimits | None = None ,
493563 usage : _usage .RunUsage | None = None ,
494564 infer_name : bool = True ,
495- toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
565+ toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None = None ,
496566 builtin_tools : Sequence [AbstractBuiltinTool | BuiltinToolFunc [AgentDepsT ]] | None = None ,
497567 event_stream_handler : EventStreamHandler [AgentDepsT ] | None = None ,
498568 ) -> AgentRunResult [OutputDataT ]: ...
@@ -512,7 +582,7 @@ def run_sync(
512582 usage_limits : _usage .UsageLimits | None = None ,
513583 usage : _usage .RunUsage | None = None ,
514584 infer_name : bool = True ,
515- toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
585+ toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None = None ,
516586 builtin_tools : Sequence [AbstractBuiltinTool | BuiltinToolFunc [AgentDepsT ]] | None = None ,
517587 event_stream_handler : EventStreamHandler [AgentDepsT ] | None = None ,
518588 ) -> AgentRunResult [RunOutputDataT ]: ...
@@ -531,7 +601,7 @@ def run_sync(
531601 usage_limits : _usage .UsageLimits | None = None ,
532602 usage : _usage .RunUsage | None = None ,
533603 infer_name : bool = True ,
534- toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
604+ toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None = None ,
535605 builtin_tools : Sequence [AbstractBuiltinTool | BuiltinToolFunc [AgentDepsT ]] | None = None ,
536606 event_stream_handler : EventStreamHandler [AgentDepsT ] | None = None ,
537607 ** _deprecated_kwargs : Never ,
@@ -589,7 +659,7 @@ def run_sync(
589659 usage_limits = usage_limits ,
590660 usage = usage ,
591661 infer_name = infer_name ,
592- toolsets = toolsets ,
662+ toolsets = self . _resolve_toolsets ( toolsets ) ,
593663 builtin_tools = builtin_tools ,
594664 event_stream_handler = event_stream_handler ,
595665 ** _deprecated_kwargs ,
@@ -610,7 +680,7 @@ def run_stream(
610680 usage_limits : _usage .UsageLimits | None = None ,
611681 usage : _usage .RunUsage | None = None ,
612682 infer_name : bool = True ,
613- toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
683+ toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None = None ,
614684 builtin_tools : Sequence [AbstractBuiltinTool | BuiltinToolFunc [AgentDepsT ]] | None = None ,
615685 event_stream_handler : EventStreamHandler [AgentDepsT ] | None = None ,
616686 ) -> AbstractAsyncContextManager [StreamedRunResult [AgentDepsT , OutputDataT ]]: ...
@@ -630,7 +700,7 @@ def run_stream(
630700 usage_limits : _usage .UsageLimits | None = None ,
631701 usage : _usage .RunUsage | None = None ,
632702 infer_name : bool = True ,
633- toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
703+ toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None = None ,
634704 builtin_tools : Sequence [AbstractBuiltinTool | BuiltinToolFunc [AgentDepsT ]] | None = None ,
635705 event_stream_handler : EventStreamHandler [AgentDepsT ] | None = None ,
636706 ) -> AbstractAsyncContextManager [StreamedRunResult [AgentDepsT , RunOutputDataT ]]: ...
@@ -650,7 +720,7 @@ async def run_stream(
650720 usage_limits : _usage .UsageLimits | None = None ,
651721 usage : _usage .RunUsage | None = None ,
652722 infer_name : bool = True ,
653- toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
723+ toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None = None ,
654724 builtin_tools : Sequence [AbstractBuiltinTool | BuiltinToolFunc [AgentDepsT ]] | None = None ,
655725 event_stream_handler : EventStreamHandler [AgentDepsT ] | None = None ,
656726 ** _deprecated_kwargs : Never ,
@@ -707,7 +777,7 @@ async def main():
707777 usage_limits = usage_limits ,
708778 usage = usage ,
709779 infer_name = infer_name ,
710- toolsets = toolsets ,
780+ toolsets = self . _resolve_toolsets ( toolsets ) ,
711781 event_stream_handler = event_stream_handler ,
712782 builtin_tools = builtin_tools ,
713783 ** _deprecated_kwargs ,
@@ -729,8 +799,9 @@ def run_stream_events(
729799 usage_limits : _usage .UsageLimits | None = None ,
730800 usage : _usage .RunUsage | None = None ,
731801 infer_name : bool = True ,
732- toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
802+ toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None = None ,
733803 builtin_tools : Sequence [AbstractBuiltinTool | BuiltinToolFunc [AgentDepsT ]] | None = None ,
804+ event_stream_handler : EventStreamHandler [AgentDepsT ] | None = None ,
734805 ) -> AsyncIterator [_messages .AgentStreamEvent | AgentRunResultEvent [OutputDataT ]]: ...
735806
736807 @overload
@@ -748,8 +819,9 @@ def run_stream_events(
748819 usage_limits : _usage .UsageLimits | None = None ,
749820 usage : _usage .RunUsage | None = None ,
750821 infer_name : bool = True ,
751- toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
822+ toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None = None ,
752823 builtin_tools : Sequence [AbstractBuiltinTool | BuiltinToolFunc [AgentDepsT ]] | None = None ,
824+ event_stream_handler : EventStreamHandler [AgentDepsT ] | None = None ,
753825 ) -> AsyncIterator [_messages .AgentStreamEvent | AgentRunResultEvent [RunOutputDataT ]]: ...
754826
755827 def run_stream_events (
@@ -766,8 +838,9 @@ def run_stream_events(
766838 usage_limits : _usage .UsageLimits | None = None ,
767839 usage : _usage .RunUsage | None = None ,
768840 infer_name : bool = True ,
769- toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
841+ toolsets : Sequence [AbstractToolset [AgentDepsT ] | str ] | None = None ,
770842 builtin_tools : Sequence [AbstractBuiltinTool | BuiltinToolFunc [AgentDepsT ]] | None = None ,
843+ event_stream_handler : EventStreamHandler [AgentDepsT ] | None = None ,
771844 ) -> AsyncIterator [_messages .AgentStreamEvent | AgentRunResultEvent [Any ]]:
772845 """Run the agent with a user prompt in async mode and stream events from the run.
773846
@@ -818,6 +891,7 @@ async def main():
818891 infer_name: Whether to try to infer the agent name from the call frame if it's not set.
819892 toolsets: Optional additional toolsets for this run.
820893 builtin_tools: Optional additional builtin tools for this run.
894+ event_stream_handler: Optional event stream handler to use for this run. It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager.
821895
822896 Returns:
823897 An async iterable of stream events `AgentStreamEvent` and finally a `AgentRunResultEvent` with the final
@@ -841,7 +915,7 @@ async def main():
841915 usage_limits = usage_limits ,
842916 usage = usage ,
843917 infer_name = infer_name ,
844- toolsets = toolsets ,
918+ toolsets = self . _resolve_toolsets ( toolsets ) ,
845919 builtin_tools = builtin_tools ,
846920 )
847921
@@ -979,6 +1053,7 @@ async def main():
9791053 infer_name: Whether to try to infer the agent name from the call frame if it's not set.
9801054 toolsets: Optional additional toolsets for this run.
9811055 builtin_tools: Optional additional builtin tools for this run.
1056+ event_stream_handler: Optional event stream handler to use for this run.
9821057
9831058 Returns:
9841059 The result of the run.
0 commit comments