Skip to content

Commit db55bb0

Browse files
committed
Support named toolset pre-registration
1 parent 9cc0ce6 commit db55bb0

File tree

3 files changed

+200
-22
lines changed

3 files changed

+200
-22
lines changed

docs/durable_execution/temporal.md

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,11 @@ To use toolsets at runtime with `TemporalAgent`, you need to:
276276

277277
This pattern allows multiple agents to share the same set of tools without duplicating activity registrations, enabling dynamic agent creation while maintaining proper tool registration.
278278

279-
Here's an example showing how to register and use shared toolsets across multiple agents:
279+
Alternatively, you can pre-register toolsets with the `TemporalAgent` constructor and reference them by name at runtime. This is similar to how models are handled.
280+
281+
### Using Toolset Instances
282+
283+
Here's an example showing how to register and use shared toolsets across multiple agents using toolset instances:
280284

281285
```python {title="shared_toolset_temporal.py" test="skip"}
282286
from datetime import timedelta
@@ -381,6 +385,54 @@ async def main():
381385
5. Pass the wrapped toolset at runtime to any agent that needs it.
382386
6. Register the shared toolset's activities once with the worker. Agent activities are automatically registered via `__pydantic_ai_agents__`.
383387

388+
6. Register the shared toolset's activities once with the worker. Agent activities are automatically registered via `__pydantic_ai_agents__`.
389+
390+
### Using Named Toolsets
391+
392+
You can also pre-register toolsets with names and reference them by name at runtime:
393+
394+
```python {title="named_toolset_temporal.py" test="skip"}
395+
from temporalio import workflow
396+
from pydantic_ai import Agent, FunctionToolset
397+
from pydantic_ai.durable_exec.temporal import TemporalAgent, TemporalFunctionToolset
398+
399+
# Define tools and toolset
400+
def magic_trick(input: str) -> str:
401+
return f"Magic: {input}"
402+
403+
magic_toolset = FunctionToolset(tools=[magic_trick], id='magic')
404+
405+
# Wrap toolset
406+
wrapped_magic_toolset = TemporalFunctionToolset(
407+
magic_toolset,
408+
activity_name_prefix='magic',
409+
deps_type=type(None),
410+
)
411+
412+
# Create agent with pre-registered toolset
413+
agent = Agent('openai:gpt-5', name='magic_agent')
414+
temporal_agent = TemporalAgent(
415+
agent,
416+
toolsets={'magic_tools': wrapped_magic_toolset}, # (1)!
417+
)
418+
419+
@workflow.defn
420+
class MagicWorkflow:
421+
__pydantic_ai_agents__ = [temporal_agent]
422+
423+
@workflow.run
424+
async def run(self, input: str) -> str:
425+
# Reference toolset by name
426+
result = await temporal_agent.run(
427+
input,
428+
toolsets=['magic_tools'], # (2)!
429+
)
430+
return result.output
431+
```
432+
433+
1. Pass a dictionary of toolsets to `TemporalAgent` to pre-register them. The keys are the names used to reference the toolsets at runtime.
434+
2. Pass the listing of toolset names to `run(toolsets=[...])` to use the pre-registered toolsets.
435+
384436
You can also wrap toolsets at agent creation time by passing them to the wrapped agent's constructor, which will automatically temporalize them. The runtime pattern shown above is useful when you want to share toolsets across multiple agents or select toolsets dynamically based on workflow parameters.
385437

386438
## Activity Configuration

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py

Lines changed: 96 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def __init__(
117117
wrapped: AbstractAgent[AgentDepsT, OutputDataT],
118118
*,
119119
name: str | None = None,
120+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | Mapping[str, AbstractToolset[AgentDepsT]] | None = None,
120121
models: Mapping[str, Model] | None = None,
121122
provider_factory: TemporalProviderFactory | None = None,
122123
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
@@ -144,6 +145,11 @@ def __init__(
144145
Args:
145146
wrapped: The agent to wrap.
146147
name: Optional unique agent name to use in the Temporal activities' names. If not provided, the agent's `name` will be used.
148+
toolsets:
149+
Optional additional toolsets to register with the agent, or a mapping of toolset names to toolset instances.
150+
Toolsets passed here will be temporalized and their activities registered alongside the wrapped agent's existing toolsets.
151+
If a mapping is provided, toolsets can be referenced by name in `run(toolsets=['name'])`.
152+
147153
models:
148154
Optional mapping of model instances to register with the agent.
149155
Keys define the names that can be referenced at runtime and the values are `Model` instances.
@@ -191,6 +197,7 @@ def __init__(
191197
]
192198
activity_config['retry_policy'] = retry_policy
193199
self.activity_config = activity_config
200+
self._named_toolsets: Mapping[str, AbstractToolset[AgentDepsT]] | None = None
194201

195202
model_activity_config = model_activity_config or {}
196203
toolset_activity_config = toolset_activity_config or {}
@@ -235,6 +242,18 @@ async def streamed_response():
235242
activities.extend(temporal_model.temporal_activities)
236243
self._temporal_model = temporal_model
237244

245+
if toolsets:
246+
if isinstance(toolsets, Mapping):
247+
# Flatten the mapping for temporalization, but keep track of names
248+
additional_toolsets = list(toolsets.values())
249+
self._named_toolsets = toolsets
250+
else:
251+
additional_toolsets = list(toolsets)
252+
self._named_toolsets = {}
253+
else:
254+
additional_toolsets = []
255+
self._named_toolsets = {}
256+
238257
def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
239258
id = toolset.id
240259
if id is None:
@@ -254,9 +273,35 @@ def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset
254273
activities.extend(toolset.temporal_activities)
255274
return toolset
256275

257-
temporal_toolsets = [toolset.visit_and_replace(temporalize_toolset) for toolset in wrapped.toolsets]
276+
all_toolsets = [*wrapped.toolsets, *additional_toolsets]
277+
temporal_toolsets = [toolset.visit_and_replace(temporalize_toolset) for toolset in all_toolsets]
278+
279+
# If we had named toolsets, we need to map the names to the temporalized versions
280+
# We know that visit_and_replace returns a new instance (or the same one if no replacement needed)
281+
# matching the structure of the input.
282+
# However, since we flattened everything into `all_toolsets` and then mapped it to `temporal_toolsets`,
283+
# we can reconstruct the named mapping by index.
284+
# But wait, `wrapped.toolsets` are first.
285+
if self._named_toolsets:
286+
# The additional toolsets are at the end of `temporal_toolsets`
287+
num_wrapped_toolsets = len(wrapped.toolsets)
288+
# The temporalized additional toolsets
289+
temporal_additional_toolsets = temporal_toolsets[num_wrapped_toolsets:]
290+
291+
# create a new mapping pointing to the temporalized versions
292+
new_named_toolsets: dict[str, AbstractToolset[AgentDepsT]] = {}
293+
for name, temporal_toolset in zip(self._named_toolsets, temporal_additional_toolsets):
294+
new_named_toolsets[name] = temporal_toolset
295+
self._named_toolsets = new_named_toolsets
296+
297+
# If toolsets were passed as a mapping, they are not added to the active toolsets by default
298+
if isinstance(toolsets, Mapping):
299+
self._toolsets = temporal_toolsets[:num_wrapped_toolsets]
300+
else:
301+
self._toolsets = temporal_toolsets
302+
else:
303+
self._toolsets = temporal_toolsets
258304

259-
self._toolsets = temporal_toolsets
260305
self._temporal_activities = activities
261306

262307
self._temporal_overrides_active: ContextVar[bool] = ContextVar('_temporal_overrides_active', default=False)
@@ -314,16 +359,21 @@ def temporal_activities(self) -> list[Callable[..., Any]]:
314359
@contextmanager
315360
def _temporal_overrides(
316361
self,
317-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
362+
toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None,
318363
model: models.Model | models.KnownModelName | str | None = None,
319364
force: bool = False,
320365
) -> Iterator[None]:
321366
in_workflow = workflow.in_workflow()
322367

323368
if toolsets:
324-
if in_workflow:
325-
_validate_temporal_toolsets(toolsets)
326-
overridden_toolsets = [*self._toolsets, *toolsets]
369+
if workflow.in_workflow():
370+
# If toolsets are provided as strings, we can't validate them directly here as they are resolved later.
371+
# We only validate if they are already AbstractToolset instances.
372+
_validate_temporal_toolsets([t for t in toolsets if not isinstance(t, str)])
373+
374+
resolved_toolsets = self._resolve_toolsets(toolsets)
375+
assert resolved_toolsets is not None
376+
overridden_toolsets = [*self._toolsets, *resolved_toolsets]
327377
else:
328378
overridden_toolsets = list(self._toolsets)
329379

@@ -352,6 +402,26 @@ def _temporal_overrides(
352402
finally:
353403
self._temporal_overrides_active.reset(temporal_active_token)
354404

405+
def _resolve_toolsets(
406+
self, toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None
407+
) -> Sequence[AbstractToolset[AgentDepsT]] | None:
408+
if toolsets is None:
409+
return None
410+
411+
resolved_toolsets: list[AbstractToolset[AgentDepsT]] = []
412+
for t in toolsets:
413+
if isinstance(t, str):
414+
if self._named_toolsets is None:
415+
raise UserError(f"Unknown toolset name: '{t}'. No named toolsets registered.")
416+
if t not in self._named_toolsets:
417+
raise UserError(
418+
f"Unknown toolset name: '{t}'. Available toolsets: {list(self._named_toolsets.keys())}"
419+
)
420+
resolved_toolsets.append(self._named_toolsets[t])
421+
else:
422+
resolved_toolsets.append(t)
423+
return resolved_toolsets
424+
355425
@overload
356426
async def run(
357427
self,
@@ -367,7 +437,7 @@ async def run(
367437
usage_limits: _usage.UsageLimits | None = None,
368438
usage: _usage.RunUsage | None = None,
369439
infer_name: bool = True,
370-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
440+
toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None,
371441
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
372442
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
373443
) -> AgentRunResult[OutputDataT]: ...
@@ -387,7 +457,7 @@ async def run(
387457
usage_limits: _usage.UsageLimits | None = None,
388458
usage: _usage.RunUsage | None = None,
389459
infer_name: bool = True,
390-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
460+
toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None,
391461
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
392462
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
393463
) -> AgentRunResult[RunOutputDataT]: ...
@@ -406,7 +476,7 @@ async def run(
406476
usage_limits: _usage.UsageLimits | None = None,
407477
usage: _usage.RunUsage | None = None,
408478
infer_name: bool = True,
409-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
479+
toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None,
410480
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
411481
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
412482
**_deprecated_kwargs: Never,
@@ -492,7 +562,7 @@ def run_sync(
492562
usage_limits: _usage.UsageLimits | None = None,
493563
usage: _usage.RunUsage | None = None,
494564
infer_name: bool = True,
495-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
565+
toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None,
496566
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
497567
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
498568
) -> AgentRunResult[OutputDataT]: ...
@@ -512,7 +582,7 @@ def run_sync(
512582
usage_limits: _usage.UsageLimits | None = None,
513583
usage: _usage.RunUsage | None = None,
514584
infer_name: bool = True,
515-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
585+
toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None,
516586
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
517587
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
518588
) -> AgentRunResult[RunOutputDataT]: ...
@@ -531,7 +601,7 @@ def run_sync(
531601
usage_limits: _usage.UsageLimits | None = None,
532602
usage: _usage.RunUsage | None = None,
533603
infer_name: bool = True,
534-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
604+
toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None,
535605
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
536606
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
537607
**_deprecated_kwargs: Never,
@@ -589,7 +659,7 @@ def run_sync(
589659
usage_limits=usage_limits,
590660
usage=usage,
591661
infer_name=infer_name,
592-
toolsets=toolsets,
662+
toolsets=self._resolve_toolsets(toolsets),
593663
builtin_tools=builtin_tools,
594664
event_stream_handler=event_stream_handler,
595665
**_deprecated_kwargs,
@@ -610,7 +680,7 @@ def run_stream(
610680
usage_limits: _usage.UsageLimits | None = None,
611681
usage: _usage.RunUsage | None = None,
612682
infer_name: bool = True,
613-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
683+
toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None,
614684
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
615685
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
616686
) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ...
@@ -630,7 +700,7 @@ def run_stream(
630700
usage_limits: _usage.UsageLimits | None = None,
631701
usage: _usage.RunUsage | None = None,
632702
infer_name: bool = True,
633-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
703+
toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None,
634704
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
635705
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
636706
) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
@@ -650,7 +720,7 @@ async def run_stream(
650720
usage_limits: _usage.UsageLimits | None = None,
651721
usage: _usage.RunUsage | None = None,
652722
infer_name: bool = True,
653-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
723+
toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None,
654724
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
655725
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
656726
**_deprecated_kwargs: Never,
@@ -707,7 +777,7 @@ async def main():
707777
usage_limits=usage_limits,
708778
usage=usage,
709779
infer_name=infer_name,
710-
toolsets=toolsets,
780+
toolsets=self._resolve_toolsets(toolsets),
711781
event_stream_handler=event_stream_handler,
712782
builtin_tools=builtin_tools,
713783
**_deprecated_kwargs,
@@ -729,8 +799,9 @@ def run_stream_events(
729799
usage_limits: _usage.UsageLimits | None = None,
730800
usage: _usage.RunUsage | None = None,
731801
infer_name: bool = True,
732-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
802+
toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None,
733803
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
804+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
734805
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ...
735806

736807
@overload
@@ -748,8 +819,9 @@ def run_stream_events(
748819
usage_limits: _usage.UsageLimits | None = None,
749820
usage: _usage.RunUsage | None = None,
750821
infer_name: bool = True,
751-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
822+
toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None,
752823
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
824+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
753825
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ...
754826

755827
def run_stream_events(
@@ -766,8 +838,9 @@ def run_stream_events(
766838
usage_limits: _usage.UsageLimits | None = None,
767839
usage: _usage.RunUsage | None = None,
768840
infer_name: bool = True,
769-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
841+
toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None,
770842
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
843+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
771844
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]:
772845
"""Run the agent with a user prompt in async mode and stream events from the run.
773846
@@ -818,6 +891,7 @@ async def main():
818891
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
819892
toolsets: Optional additional toolsets for this run.
820893
builtin_tools: Optional additional builtin tools for this run.
894+
event_stream_handler: Optional event stream handler to use for this run. It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager.
821895
822896
Returns:
823897
An async iterable of stream events `AgentStreamEvent` and finally a `AgentRunResultEvent` with the final
@@ -841,7 +915,7 @@ async def main():
841915
usage_limits=usage_limits,
842916
usage=usage,
843917
infer_name=infer_name,
844-
toolsets=toolsets,
918+
toolsets=self._resolve_toolsets(toolsets),
845919
builtin_tools=builtin_tools,
846920
)
847921

@@ -979,6 +1053,7 @@ async def main():
9791053
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
9801054
toolsets: Optional additional toolsets for this run.
9811055
builtin_tools: Optional additional builtin tools for this run.
1056+
event_stream_handler: Optional event stream handler to use for this run.
9821057
9831058
Returns:
9841059
The result of the run.

0 commit comments

Comments
 (0)