diff --git a/docs/durable_execution/temporal.md b/docs/durable_execution/temporal.md index 177329e449..764401d6bd 100644 --- a/docs/durable_execution/temporal.md +++ b/docs/durable_execution/temporal.md @@ -159,7 +159,7 @@ To ensure that Temporal knows what code to run when an activity fails or is inte When `TemporalAgent` dynamically creates activities for the wrapped agent's model requests and toolsets (specifically those that implement their own tool listing and calling, i.e. [`FunctionToolset`][pydantic_ai.toolsets.FunctionToolset] and [`MCPServer`][pydantic_ai.mcp.MCPServer]), their names are derived from the agent's [`name`][pydantic_ai.agent.AbstractAgent.name] and the toolsets' [`id`s][pydantic_ai.toolsets.AbstractToolset.id]. These fields are normally optional, but are required to be set when using Temporal. They should not be changed once the durable agent has been deployed to production as this would break active workflows. -For dynamic toolsets created with the [`@agent.toolset`][pydantic_ai.Agent.toolset] decorator, the `id` parameter must be set explicitly. Note that with Temporal, `per_run_step=False` is not respected, as the toolset always needs to be created on-the-fly in the activity. +For dynamic toolsets created with the [`@agent.toolset`][pydantic_ai.agent.Agent.toolset] decorator, the `id` parameter must be set explicitly. Note that with Temporal, `per_run_step=False` is not respected, as the toolset always needs to be created on-the-fly in the activity. Other than that, any agent and toolset will just work! @@ -264,6 +264,160 @@ class MultiModelWorkflow: return result.output ``` +### Runtime Toolset Selection + +[`Agent.run(toolsets=...)`][pydantic_ai.agent.Agent.run] normally supports passing toolsets directly. However, `TemporalAgent` requires toolsets to be wrapped in a [`TemporalWrapperToolset`][pydantic_ai.durable_exec.temporal.TemporalWrapperToolset] and pre-registered because Temporal activities must be registered with the worker before the workflow starts. + +The preferred way is to pass toolsets to the `TemporalAgent` constructor and reference them by name at runtime using `run(toolsets=[...])`. + +Alternatively, for sharing toolsets across multiple agents, you can manually wrap them using [`temporalize_toolset`][pydantic_ai.durable_exec.temporal.temporalize_toolset]. + +Here's an example with named toolsets: + +```python {title="named_toolset_temporal.py" test="skip"} +from temporalio import workflow + +from pydantic_ai import Agent, FunctionToolset +from pydantic_ai.durable_exec.temporal import TemporalAgent + + +# Define tools and toolset +def magic_trick(input: str) -> str: + return f'Magic: {input}' + +magic_toolset = FunctionToolset(tools=[magic_trick], id='magic') + +# Create agent with pre-registered toolset +agent = Agent('openai:gpt-5', name='magic_agent') +temporal_agent = TemporalAgent( + agent, + toolsets={'magic_tools': magic_toolset}, # (1)! +) + +@workflow.defn +class MagicWorkflow: + __pydantic_ai_agents__ = [temporal_agent] + + @workflow.run + async def run(self, input: str) -> str: + # Reference toolset by name + result = await temporal_agent.run( + input, + toolsets=['magic_tools'], # (2)! + ) + return result.output +``` + +1. Pre-register toolsets by passing a dict to `TemporalAgent`. The keys become the toolset names. The toolsets are automatically wrapped for Temporal. +2. Reference toolsets by name in `run(toolsets=[...])`. + +For sharing toolsets across multiple agents, manually wrap them first using [`temporalize_toolset`][pydantic_ai.durable_exec.temporal.temporalize_toolset]: + +```python {title="shared_toolset_temporal.py" test="skip"} +from datetime import timedelta + +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker +from temporalio.workflow import ActivityConfig + +from pydantic_ai import Agent, FunctionToolset +from pydantic_ai.durable_exec.temporal import TemporalAgent, temporalize_toolset + + +# Define shared tools +def web_search(query: str) -> str: + """Search the web for information.""" + # Actual web search implementation + return f'Search results for: {query}' + + +def calculate(expression: str) -> str: + """Evaluate a mathematical expression.""" + # Actual calculation implementation + return f'Result: {expression}' + + +# Create a shared toolset +shared_toolset = FunctionToolset( + tools=[web_search, calculate], + id='shared_tools', # (1)! +) + +# Wrap the toolset for Temporal +wrapped_shared_toolset = temporalize_toolset( + shared_toolset, + activity_name_prefix='shared', # (2)! + activity_config=ActivityConfig(start_to_close_timeout=timedelta(minutes=2)), +) + +# Create multiple agents that can use the shared toolset +research_agent = Agent( + 'openai:gpt-5', + name='research_agent', +) +math_agent = Agent( + 'anthropic:claude-sonnet-4.5', + name='math_agent', +) + +# Wrap agents for Temporal, pre-registering the shared toolset +temporal_research_agent = TemporalAgent( + research_agent, + toolsets={'shared': wrapped_shared_toolset}, # (3)! +) +temporal_math_agent = TemporalAgent( + math_agent, + toolsets={'shared': wrapped_shared_toolset}, +) + + +@workflow.defn +class SharedToolsetWorkflow: + __pydantic_ai_agents__ = [temporal_research_agent, temporal_math_agent] # (4)! + + @workflow.run + async def run(self, task: str, use_research: bool) -> str: + if use_research: + # Research agent uses shared toolset by name + result = await temporal_research_agent.run( + task, + toolsets=['shared'], # (5)! + ) + else: + # Math agent also uses the same shared toolset + result = await temporal_math_agent.run( + task, + toolsets=['shared'], + ) + return result.output + + +async def main(): + client = await Client.connect('localhost:7233') + + async with Worker( + client, + task_queue='shared-toolset-queue', + workflows=[SharedToolsetWorkflow], + # Toolset activities are automatically registered because they were passed + # to TemporalAgent(...toolsets=...) + ): + result = await client.execute_workflow( + SharedToolsetWorkflow.run, + args=['Search for Python tutorials', True], + id='shared-toolset-workflow', + task_queue='shared-toolset-queue', + ) + print(result) +``` + +1. The toolset must have a unique `id` to be used with Temporal. +2. The `activity_name_prefix` ensures activity names don't conflict across different toolset registrations. +3. Pass the wrapped toolset to `TemporalAgent` to pre-register it. This enables referencing it by name and ensures its activities are automatically registered with the worker. +4. The `__pydantic_ai_agents__` pattern automatically registers all agent activities (including the pre-registered toolset activities) with the workflow. +5. Reference the toolset by its registered name. + ## Activity Configuration Temporal activity configuration, like timeouts and retry policies, can be customized by passing [`temporalio.workflow.ActivityConfig`](https://python.temporal.io/temporalio.workflow.ActivityConfig.html) objects to the `TemporalAgent` constructor: diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py index dc27bd9409..0e96843f26 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py @@ -16,7 +16,7 @@ from ._agent import TemporalAgent from ._logfire import LogfirePlugin from ._run_context import TemporalRunContext -from ._toolset import TemporalWrapperToolset +from ._toolset import TemporalWrapperToolset, temporalize_toolset from ._workflow import PydanticAIWorkflow __all__ = [ @@ -26,6 +26,7 @@ 'AgentPlugin', 'TemporalRunContext', 'TemporalWrapperToolset', + 'temporalize_toolset', 'PydanticAIWorkflow', ] diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py index 7fb158fc66..e01e8d40d5 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py @@ -18,6 +18,7 @@ from pydantic_ai import ( AbstractToolset, AgentRunResultEvent, + FunctionToolset, _utils, messages as _messages, models, @@ -45,6 +46,63 @@ from ._toolset import TemporalWrapperToolset, temporalize_toolset +def _validate_temporal_toolsets(toolsets: Sequence[AbstractToolset[AgentDepsT]]) -> None: + """Validate that all leaf toolsets requiring temporal wrapping are properly wrapped. + + This function recursively traverses the toolset hierarchy and checks that any leaf + toolsets that need temporal wrapping (FunctionToolset, MCPServer, FastMCPToolset, DynamicToolset) + are wrapped in a TemporalWrapperToolset. + + Args: + toolsets: The toolsets to validate. + + Raises: + UserError: If an unwrapped leaf toolset is found that requires temporal wrapping. + The error message includes the toolset label for identification. + """ + + def validate_toolset(t: AbstractToolset[AgentDepsT]) -> None: + # If we encounter a TemporalWrapperToolset, we don't need to check its children + # since they're already wrapped + if isinstance(t, TemporalWrapperToolset): + return + + if isinstance(t, FunctionToolset): + raise UserError(f'Toolset {t.label} must be wrapped in a `TemporalWrapperToolset`.') + + # Check if this is a DynamicToolset that needs wrapping + from pydantic_ai.toolsets._dynamic import DynamicToolset + + if isinstance(t, DynamicToolset): + raise UserError(f'Toolset {t.label} must be wrapped in a `TemporalWrapperToolset`.') + + # Check if this is an MCPServer that needs wrapping + try: + from pydantic_ai.mcp import MCPServer + except ImportError: + pass + else: + if isinstance(t, MCPServer): + raise UserError(f'Toolset {t.label} must be wrapped in a `TemporalWrapperToolset`.') + + # Check if this is a FastMCPToolset that needs wrapping + try: + from pydantic_ai.toolsets.fastmcp import FastMCPToolset + except ImportError: + pass + else: + if isinstance(t, FastMCPToolset): + raise UserError(f'Toolset {t.label} must be wrapped in a `TemporalWrapperToolset`.') + + # For other toolsets (like CombinedToolset, WrapperToolset, etc.), + # we return them unchanged - apply will handle recursion + return # pragma: no cover - defensive code for future toolset types + + # Visit and validate each toolset recursively + for toolset in toolsets: + toolset.apply(validate_toolset) + + @dataclass @with_config(ConfigDict(arbitrary_types_allowed=True)) class _EventStreamHandlerParams: @@ -58,6 +116,7 @@ def __init__( wrapped: AbstractAgent[AgentDepsT, OutputDataT], *, name: str | None = None, + toolsets: Mapping[str, AbstractToolset[AgentDepsT]] | None = None, models: Mapping[str, Model] | None = None, provider_factory: TemporalProviderFactory | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -85,6 +144,11 @@ def __init__( Args: wrapped: The agent to wrap. name: Optional unique agent name to use in the Temporal activities' names. If not provided, the agent's `name` will be used. + toolsets: + Optional mapping of toolset names to toolset instances to register with the agent. + Toolsets passed here will be temporalized and their activities registered alongside the wrapped agent's existing toolsets. + Registered toolsets can be referenced by name in `run(toolsets=['name'])`. + models: Optional mapping of model instances to register with the agent. Keys define the names that can be referenced at runtime and the values are `Model` instances. @@ -195,9 +259,15 @@ def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset activities.extend(toolset.temporal_activities) return toolset - temporal_toolsets = [toolset.visit_and_replace(temporalize_toolset) for toolset in wrapped.toolsets] + # Temporalize wrapped agent's toolsets + self._toolsets = [toolset.visit_and_replace(temporalize_toolset) for toolset in wrapped.toolsets] + + # Process additional toolsets (if provided) + # Temporalize named toolsets and store the mapping + self._named_toolsets: Mapping[str, AbstractToolset[AgentDepsT]] = { + name: toolset.visit_and_replace(temporalize_toolset) for name, toolset in (toolsets or {}).items() + } - self._toolsets = temporal_toolsets self._temporal_activities = activities self._temporal_overrides_active: ContextVar[bool] = ContextVar('_temporal_overrides_active', default=False) @@ -254,21 +324,30 @@ def temporal_activities(self) -> list[Callable[..., Any]]: @contextmanager def _temporal_overrides( - self, *, model: models.Model | models.KnownModelName | str | None = None, force: bool = False + self, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + force: bool = False, ) -> Iterator[None]: - """Context manager for workflow-specific overrides. + in_workflow = workflow.in_workflow() - When called outside a workflow, this is a no-op. - When called inside a workflow, it overrides the model and toolsets. - """ - if not workflow.in_workflow() and not force: - yield + if toolsets: + overridden_toolsets = [*self._toolsets, *toolsets] + else: + overridden_toolsets = list(self._toolsets) + + # Outside workflow, only apply toolsets override (model is passed directly to run) + if not in_workflow and not force: + if toolsets: + with super().override(toolsets=overridden_toolsets, tools=[]): + yield + else: + yield return - # We reset tools here as the temporalized function toolset is already in self._toolsets. - # Override model and set the model for workflow execution + # We reset tools here as the temporalized function toolset is already in overridden_toolsets. with ( - super().override(model=self._temporal_model, toolsets=self._toolsets, tools=[]), + super().override(model=self._temporal_model, toolsets=overridden_toolsets, tools=[]), self._temporal_model.using_model(model), _utils.disable_threads(), ): @@ -282,6 +361,45 @@ def _temporal_overrides( finally: self._temporal_overrides_active.reset(temporal_active_token) + def _resolve_toolsets( + self, toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None + ) -> Sequence[AbstractToolset[AgentDepsT]] | None: + if toolsets is None: + return None + + resolved_toolsets: list[AbstractToolset[AgentDepsT]] = [] + for t in toolsets: + if isinstance(t, str): + # String name: lookup in named toolsets + try: + resolved_toolsets.append(self._named_toolsets[t]) + except KeyError as e: + if not self._named_toolsets: + raise UserError(f"Unknown toolset name: '{t}'. No named toolsets registered.") from e + raise UserError( + f"Unknown toolset name: '{t}'. Available toolsets: {list(self._named_toolsets.keys())}" + ) from e + elif isinstance(t, TemporalWrapperToolset): + # Already a temporal wrapper: use as-is + resolved_toolsets.append(t) + else: + # Original toolset instance: find its temporal wrapper + # Check if this toolset instance is wrapped in any of our named toolsets + wrapper = next( + ( + wrapper + for wrapper in self._named_toolsets.values() + if isinstance(wrapper, TemporalWrapperToolset) and wrapper.wrapped is t + ), + None, + ) + if wrapper is not None: + resolved_toolsets.append(wrapper) + else: + # Not found in named toolsets, use as-is (will be validated later) + resolved_toolsets.append(t) + return resolved_toolsets + @overload async def run( self, @@ -297,7 +415,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None, builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AgentRunResult[OutputDataT]: ... @@ -317,7 +435,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None, builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @@ -336,7 +454,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None, builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, **_deprecated_kwargs: Never, @@ -372,13 +490,27 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional additional toolsets for this run. + toolsets: Optional additional toolsets for this run. Can be toolset instances or strings + referencing toolsets registered by name in the agent constructor's `toolsets` parameter. event_stream_handler: Optional event stream handler to use for this run. builtin_tools: Optional additional builtin tools for this run. Returns: The result of the run. """ + # Validate and resolve toolsets at callsite + resolved_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None + if toolsets: + if workflow.in_workflow(): + # Validate toolsets in workflow context + try: + _validate_temporal_toolsets([t for t in toolsets if not isinstance(t, str)]) + except UserError as e: + raise UserError( + f'Toolsets provided at runtime inside a Temporal workflow must be wrapped in a `TemporalWrapperToolset`. {e}' + ) from e + resolved_toolsets = self._resolve_toolsets(toolsets) + if workflow.in_workflow(): if event_stream_handler is not None: raise UserError( @@ -388,7 +520,7 @@ async def main(): else: resolved_model = self._temporal_model.resolve_model(model) - with self._temporal_overrides(model=model): + with self._temporal_overrides(toolsets=resolved_toolsets, model=model): return await super().run( user_prompt, output_type=output_type, @@ -401,7 +533,7 @@ async def main(): usage_limits=usage_limits, usage=usage, infer_name=infer_name, - toolsets=toolsets, + toolsets=None, # Toolsets are set via _temporal_overrides builtin_tools=builtin_tools, event_stream_handler=event_stream_handler or self.event_stream_handler, **_deprecated_kwargs, @@ -422,7 +554,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None, builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AgentRunResult[OutputDataT]: ... @@ -442,7 +574,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None, builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @@ -461,7 +593,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None, builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, **_deprecated_kwargs: Never, @@ -495,7 +627,8 @@ def run_sync( usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional additional toolsets for this run. + toolsets: Optional additional toolsets for this run. Can be toolset instances or strings + referencing toolsets registered by name in the agent constructor's `toolsets` parameter. event_stream_handler: Optional event stream handler to use for this run. builtin_tools: Optional additional builtin tools for this run. @@ -519,7 +652,7 @@ def run_sync( usage_limits=usage_limits, usage=usage, infer_name=infer_name, - toolsets=toolsets, + toolsets=self._resolve_toolsets(toolsets), builtin_tools=builtin_tools, event_stream_handler=event_stream_handler, **_deprecated_kwargs, @@ -540,7 +673,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None, builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ... @@ -560,7 +693,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None, builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @@ -580,7 +713,7 @@ async def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None, builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, **_deprecated_kwargs: Never, @@ -612,7 +745,8 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional additional toolsets for this run. + toolsets: Optional additional toolsets for this run. Can be toolset instances or strings + referencing toolsets registered by name in the agent constructor's `toolsets` parameter. builtin_tools: Optional additional builtin tools for this run. event_stream_handler: Optional event stream handler to use for this run. It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager. @@ -637,7 +771,7 @@ async def main(): usage_limits=usage_limits, usage=usage, infer_name=infer_name, - toolsets=toolsets, + toolsets=self._resolve_toolsets(toolsets), event_stream_handler=event_stream_handler, builtin_tools=builtin_tools, **_deprecated_kwargs, @@ -659,7 +793,7 @@ def run_stream_events( usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None, builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ... @@ -678,7 +812,7 @@ def run_stream_events( usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None, builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ... @@ -696,13 +830,12 @@ def run_stream_events( usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT] | str] | None = None, builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]: """Run the agent with a user prompt in async mode and stream events from the run. - This is a convenience method that wraps [`self.run`][pydantic_ai.agent.AbstractAgent.run] and - uses the `event_stream_handler` kwarg to get a stream of events from the run. + This is a convenience method that wraps [`self.run`][pydantic_ai.agent.AbstractAgent.run]. Example: ```python @@ -730,8 +863,7 @@ async def main(): ''' ``` - Arguments are the same as for [`self.run`][pydantic_ai.agent.AbstractAgent.run], - except that `event_stream_handler` is now allowed. + Arguments are the same as for [`self.run`][pydantic_ai.agent.AbstractAgent.run]. Args: user_prompt: User input to start/continue the conversation. @@ -746,7 +878,8 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional additional toolsets for this run. + toolsets: Optional additional toolsets for this run. Can be toolset instances or strings + referencing toolsets registered by name in the agent constructor's `toolsets` parameter. builtin_tools: Optional additional builtin tools for this run. Returns: @@ -771,7 +904,7 @@ async def main(): usage_limits=usage_limits, usage=usage, infer_name=infer_name, - toolsets=toolsets, + toolsets=self._resolve_toolsets(toolsets), builtin_tools=builtin_tools, ) @@ -920,12 +1053,17 @@ async def main(): 'Set an `event_stream_handler` on the agent and use `agent.run()` instead.' ) - assert model is None, 'Temporal overrides must set the model before `agent.iter()` is invoked' - - if toolsets is not None: + if model is not None: # pragma: no cover - defensive check for workflow execution path raise UserError( - 'Toolsets cannot be set at agent run time inside a Temporal workflow, it must be set at agent creation time.' + 'Model cannot be set at agent run time inside a Temporal workflow, it must be set at agent creation time.' ) + if toolsets is not None: # pragma: no cover - defensive check for workflow execution path + try: + _validate_temporal_toolsets(toolsets) + except UserError as e: + raise UserError( + f'Toolsets provided at runtime inside a Temporal workflow must be wrapped in a `TemporalWrapperToolset`. {e}' + ) from e resolved_model = None else: @@ -979,9 +1117,12 @@ def override( 'Model cannot be contextually overridden inside a Temporal workflow, it must be set at agent creation time.' ) if _utils.is_set(toolsets): - raise UserError( - 'Toolsets cannot be contextually overridden inside a Temporal workflow, they must be set at agent creation time.' - ) + try: + _validate_temporal_toolsets(toolsets) + except UserError as e: + raise UserError( + f'Toolsets cannot be contextually overridden inside a Temporal workflow, unless they are wrapped in a `TemporalWrapperToolset`. {e}' + ) from e if _utils.is_set(tools): raise UserError( 'Tools cannot be contextually overridden inside a Temporal workflow, they must be set at agent creation time.' diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_dynamic_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_dynamic_toolset.py index 696a91dcf8..9c56b98832 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_dynamic_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_dynamic_toolset.py @@ -47,15 +47,17 @@ def __init__( self, toolset: DynamicToolset[AgentDepsT], *, - activity_name_prefix: str, - activity_config: ActivityConfig, - tool_activity_config: dict[str, ActivityConfig | Literal[False]], - deps_type: type[AgentDepsT], + activity_name_prefix: str | None = None, + activity_config: ActivityConfig | None = None, + tool_activity_config: dict[str, ActivityConfig | Literal[False]] | None = None, + deps_type: type[AgentDepsT] | None = None, run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT], ): super().__init__(toolset) - self.activity_config = activity_config - self.tool_activity_config = tool_activity_config + from datetime import timedelta + + self.activity_config = activity_config or ActivityConfig(start_to_close_timeout=timedelta(minutes=1)) + self.tool_activity_config = tool_activity_config or {} self.run_context_type = run_context_type async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[str, _ToolInfo]: @@ -69,8 +71,10 @@ async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[ for name, tool in tools.items() } - get_tools_activity.__annotations__['deps'] = deps_type + # Set type hint explicitly so that Temporal can take care of serialization and deserialization + get_tools_activity.__annotations__['deps'] = deps_type or Any + activity_name_prefix = activity_name_prefix or '' self.get_tools_activity = activity.defn(name=f'{activity_name_prefix}__dynamic_toolset__{self.id}__get_tools')( get_tools_activity ) @@ -90,7 +94,8 @@ async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallTo return await self._call_tool_in_activity(params.name, params.tool_args, ctx, tool) - call_tool_activity.__annotations__['deps'] = deps_type + # Set type hint explicitly so that Temporal can take care of serialization and deserialization + call_tool_activity.__annotations__['deps'] = deps_type or Any self.call_tool_activity = activity.defn(name=f'{activity_name_prefix}__dynamic_toolset__{self.id}__call_tool')( call_tool_activity diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_fastmcp_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_fastmcp_toolset.py index 5682c32f2c..ac924122ff 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_fastmcp_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_fastmcp_toolset.py @@ -17,10 +17,10 @@ def __init__( self, toolset: FastMCPToolset[AgentDepsT], *, - activity_name_prefix: str, - activity_config: ActivityConfig, - tool_activity_config: dict[str, ActivityConfig | Literal[False]], - deps_type: type[AgentDepsT], + activity_name_prefix: str | None = None, + activity_config: ActivityConfig | None = None, + tool_activity_config: dict[str, ActivityConfig | Literal[False]] | None = None, + deps_type: type[AgentDepsT] | None = None, run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT], ): super().__init__( diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py index 5412825ea3..0f47e2084b 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Callable +from datetime import timedelta from typing import Any, Literal from temporalio import activity, workflow @@ -24,15 +25,15 @@ def __init__( self, toolset: FunctionToolset[AgentDepsT], *, - activity_name_prefix: str, - activity_config: ActivityConfig, - tool_activity_config: dict[str, ActivityConfig | Literal[False]], - deps_type: type[AgentDepsT], + activity_name_prefix: str | None = None, + activity_config: ActivityConfig | None = None, + tool_activity_config: dict[str, ActivityConfig | Literal[False]] | None = None, + deps_type: type[AgentDepsT] | None = None, run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT], ): super().__init__(toolset) - self.activity_config = activity_config - self.tool_activity_config = tool_activity_config + self.activity_config = activity_config or ActivityConfig(start_to_close_timeout=timedelta(minutes=1)) + self.tool_activity_config = tool_activity_config or {} self.run_context_type = run_context_type async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult: @@ -49,8 +50,9 @@ async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallTo return await self._call_tool_in_activity(name, params.tool_args, ctx, tool) # Set type hint explicitly so that Temporal can take care of serialization and deserialization - call_tool_activity.__annotations__['deps'] = deps_type + call_tool_activity.__annotations__['deps'] = deps_type or Any + activity_name_prefix = activity_name_prefix or '' self.call_tool_activity = activity.defn(name=f'{activity_name_prefix}__toolset__{self.id}__call_tool')( call_tool_activity ) @@ -62,7 +64,7 @@ def temporal_activities(self) -> list[Callable[..., Any]]: async def call_tool( self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] ) -> Any: - if not workflow.in_workflow(): # pragma: no cover + if not workflow.in_workflow(): return await super().call_tool(name, tool_args, ctx, tool) tool_activity_config = self.tool_activity_config.get(name, {}) diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp.py index 92fad1c3c6..f62a004010 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass +from datetime import timedelta from typing import Any, Literal from pydantic import ConfigDict, with_config @@ -33,16 +34,17 @@ def __init__( self, toolset: AbstractToolset[AgentDepsT], *, - activity_name_prefix: str, - activity_config: ActivityConfig, - tool_activity_config: dict[str, ActivityConfig | Literal[False]], - deps_type: type[AgentDepsT], + activity_name_prefix: str | None = None, + activity_config: ActivityConfig | None = None, + tool_activity_config: dict[str, ActivityConfig | Literal[False]] | None = None, + deps_type: type[AgentDepsT] | None = None, run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT], ): super().__init__(toolset) - self.activity_config = activity_config + self.activity_config = activity_config or ActivityConfig(start_to_close_timeout=timedelta(minutes=1)) self.tool_activity_config: dict[str, ActivityConfig] = {} + tool_activity_config = tool_activity_config or {} for tool_name, tool_config in tool_activity_config.items(): if tool_config is False: raise UserError( @@ -61,7 +63,9 @@ async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[ return {name: tool.tool_def for name, tool in tools.items()} # Set type hint explicitly so that Temporal can take care of serialization and deserialization - get_tools_activity.__annotations__['deps'] = deps_type + get_tools_activity.__annotations__['deps'] = deps_type or Any + + activity_name_prefix = activity_name_prefix or '' self.get_tools_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__get_tools')( get_tools_activity @@ -80,7 +84,7 @@ async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallTo ) # Set type hint explicitly so that Temporal can take care of serialization and deserialization - call_tool_activity.__annotations__['deps'] = deps_type + call_tool_activity.__annotations__['deps'] = deps_type or Any self.call_tool_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__call_tool')( call_tool_activity diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py index 8fe779239a..a2bf3c57a5 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py @@ -17,10 +17,10 @@ def __init__( self, server: MCPServer, *, - activity_name_prefix: str, - activity_config: ActivityConfig, - tool_activity_config: dict[str, ActivityConfig | Literal[False]], - deps_type: type[AgentDepsT], + activity_name_prefix: str | None = None, + activity_config: ActivityConfig | None = None, + tool_activity_config: dict[str, ActivityConfig | Literal[False]] | None = None, + deps_type: type[AgentDepsT] | None = None, run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT], ): super().__init__( diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py index 5dd4465516..b14cbc3a65 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py @@ -75,13 +75,17 @@ def visit_and_replace( # Temporalized toolsets cannot be swapped out after the fact. return self + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: + # Stop recursion at temporalized toolsets - they're already wrapped and validated. + visitor(self) + async def __aenter__(self) -> Self: - if not workflow.in_workflow(): # pragma: no cover + if not workflow.in_workflow(): await self.wrapped.__aenter__() return self async def __aexit__(self, *args: Any) -> bool | None: - if not workflow.in_workflow(): # pragma: no cover + if not workflow.in_workflow(): return await self.wrapped.__aexit__(*args) return None @@ -126,10 +130,10 @@ async def _call_tool_in_activity( def temporalize_toolset( toolset: AbstractToolset[AgentDepsT], - activity_name_prefix: str, - activity_config: ActivityConfig, - tool_activity_config: dict[str, ActivityConfig | Literal[False]], - deps_type: type[AgentDepsT], + activity_name_prefix: str | None = None, + activity_config: ActivityConfig | None = None, + tool_activity_config: dict[str, ActivityConfig | Literal[False]] | None = None, + deps_type: type[AgentDepsT] | None = None, run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT], ) -> AbstractToolset[AgentDepsT]: """Temporalize a toolset. diff --git a/tests/test_temporal.py b/tests/test_temporal.py index e3ee550c34..e454c7edd2 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -17,6 +17,7 @@ AgentRunResultEvent, AgentStreamEvent, BinaryImage, + CombinedToolset, ExternalToolset, FinalResultEvent, FunctionToolCallEvent, @@ -96,6 +97,11 @@ except ImportError: # pragma: lax no cover pytest.skip('fastmcp not installed', allow_module_level=True) +try: + from pydantic_ai.toolsets._dynamic import DynamicToolset +except ImportError: # pragma: lax no cover + pytest.skip('dynamic toolset not available', allow_module_level=True) + try: from pydantic_ai.models.openai import OpenAIChatModel, OpenAIResponsesModel from pydantic_ai.providers.openai import OpenAIProvider @@ -1654,7 +1660,7 @@ async def test_temporal_agent_run_in_workflow_with_toolsets(allow_model_requests with workflow_raises( UserError, snapshot( - 'Toolsets cannot be set at agent run time inside a Temporal workflow, it must be set at agent creation time.' + 'Toolsets provided at runtime inside a Temporal workflow must be wrapped in a `TemporalWrapperToolset`. Toolset FunctionToolset must be wrapped in a `TemporalWrapperToolset`.' ), ): await client.execute_workflow( @@ -1712,7 +1718,7 @@ async def test_temporal_agent_override_toolsets_in_workflow(allow_model_requests with workflow_raises( UserError, snapshot( - 'Toolsets cannot be contextually overridden inside a Temporal workflow, they must be set at agent creation time.' + 'Toolsets cannot be contextually overridden inside a Temporal workflow, unless they are wrapped in a `TemporalWrapperToolset`. Toolset FunctionToolset must be wrapped in a `TemporalWrapperToolset`.' ), ): await client.execute_workflow( @@ -2970,3 +2976,439 @@ async def test_temporal_model_request_stream_outside_workflow(): # Verify response comes from the wrapped TestModel assert any(isinstance(part, TextPart) and part.content == 'Direct stream response' for part in response.parts) + + +combined_override_child_toolset_1 = FunctionToolset(id='combined_override_child_1') +combined_override_child_toolset_2 = FunctionToolset(id='combined_override_child_2') +combined_override_wrapped_toolset_1 = TemporalFunctionToolset(combined_override_child_toolset_1) +combined_override_wrapped_toolset_2 = TemporalFunctionToolset(combined_override_child_toolset_2) +combined_override_toolset = CombinedToolset([combined_override_wrapped_toolset_1, combined_override_wrapped_toolset_2]) + + +@workflow.defn +class SimpleAgentWorkflowWithOverrideCombinedToolsets: + @workflow.run + async def run(self, prompt: str) -> str: + with simple_temporal_agent.override(toolsets=[combined_override_toolset]): + return 'ok' + + +async def test_temporal_agent_override_combined_toolsets_in_workflow(allow_model_requests: None, client: Client): + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[SimpleAgentWorkflowWithOverrideCombinedToolsets], + plugins=[AgentPlugin(simple_temporal_agent)], + ): + output = await client.execute_workflow( + SimpleAgentWorkflowWithOverrideCombinedToolsets.run, + args=['What is the capital of Mexico?'], + id=SimpleAgentWorkflowWithOverrideCombinedToolsets.__name__, + task_queue=TASK_QUEUE, + ) + assert output == 'ok' + + +# Dynamic agent with runtime toolset test +def echo_tool(x: str) -> str: + return f'echo: {x}' + + +# Create toolset for dynamic agent test +dynamic_test_toolset = FunctionToolset(tools=[echo_tool], id='my_tools') + +# Wrap toolset for Temporal +wrapped_dynamic_test_toolset = TemporalFunctionToolset( + dynamic_test_toolset, + activity_name_prefix='shared_tools', + activity_config=ActivityConfig(start_to_close_timeout=timedelta(minutes=1)), + tool_activity_config={}, + deps_type=type(None), +) + +# Create agent that will be used in workflow +# This demonstrates dynamic model selection with runtime toolset passing +dynamic_runtime_test_model = TestModel(call_tools=['echo_tool']) +dynamic_runtime_test_agent = Agent(dynamic_runtime_test_model, name='test_agent_dynamic_runtime') +dynamic_runtime_test_temporal_agent = TemporalAgent(dynamic_runtime_test_agent) + + +@workflow.defn +class DynamicAgentRuntimeToolsetWorkflow: + __pydantic_ai_agents__ = [dynamic_runtime_test_temporal_agent] + + @workflow.run + async def run(self, user_prompt: str) -> str: + # Use the pre-created agent but pass toolset at runtime + # This demonstrates decoupling tool registration from agent definition + result = await dynamic_runtime_test_temporal_agent.run(user_prompt, toolsets=[wrapped_dynamic_test_toolset]) + return result.output + + +async def test_dynamic_agent_with_runtime_toolset(allow_model_requests: None, client: Client): + """Test passing a TemporalWrapperToolset at runtime to a TemporalAgent within a workflow. + + This test demonstrates the pattern described in the issue where tools are registered + separately from agents, allowing dynamic agents to use a shared set of tools. + """ + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[DynamicAgentRuntimeToolsetWorkflow], + # Only register the shared toolset activities + # Agent activities are automatically registered via __pydantic_ai_agents__ + activities=[ + *wrapped_dynamic_test_toolset.temporal_activities, + ], + ): + result = await client.execute_workflow( + DynamicAgentRuntimeToolsetWorkflow.run, + args=['test prompt'], + id='test-workflow-run-dynamic-runtime', + task_queue=TASK_QUEUE, + ) + + # Verify tool was called successfully + assert 'echo' in result + assert 'echo:' in result + + +# specific toolset for named registration test to avoid conflicts +named_test_toolset = FunctionToolset(tools=[echo_tool], id='named_tools') +wrapped_named_test_toolset = TemporalFunctionToolset( + named_test_toolset, + activity_name_prefix='named', + deps_type=type(None), +) +named_toolset_agent = TemporalAgent( + Agent(TestModel(), name='named_agent'), + name='test_agent_named_toolset', + toolsets={'shared_tools_name': wrapped_named_test_toolset}, +) + + +@workflow.defn +class DynamicAgentNamedToolsetWorkflowLocal: + __pydantic_ai_agents__ = [named_toolset_agent] + + @workflow.run + async def run(self, user_prompt: str) -> str: + # Reference toolset by name + result = await named_toolset_agent.run(user_prompt, toolsets=['shared_tools_name']) + return result.output + + +async def test_dynamic_agent_with_named_toolset(allow_model_requests: None, client: Client): + """Test passing a toolset name at runtime to a TemporalAgent within a workflow. + + This test checks that toolsets pre-registered with a name in TemporalAgent + can be referenced by that name in `run()`. + """ + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[DynamicAgentNamedToolsetWorkflowLocal], + activities=[ + *wrapped_named_test_toolset.temporal_activities, + ], + ): + result = await client.execute_workflow( + DynamicAgentNamedToolsetWorkflowLocal.run, + args=['test prompt'], + id='test-workflow-run-named-toolset', + task_queue=TASK_QUEUE, + ) + + # Verify tool was called successfully + assert 'echo' in result + assert 'echo:' in result + + +@workflow.defn +class UnknownToolsetWorkflow: + __pydantic_ai_agents__ = [named_toolset_agent] + + @workflow.run + async def run(self, user_prompt: str) -> str: + # Reference a toolset name that doesn't exist + result = await named_toolset_agent.run(user_prompt, toolsets=['nonexistent_toolset']) + return result.output # pragma: no cover - workflow fails before reaching this line + + +async def test_unknown_toolset_name_error(allow_model_requests: None, client: Client): + """Test that referencing an unknown toolset name raises an error.""" + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[UnknownToolsetWorkflow], + activities=[ + *wrapped_named_test_toolset.temporal_activities, + ], + ): + with pytest.raises(WorkflowFailureError) as exc_info: + await client.execute_workflow( + UnknownToolsetWorkflow.run, + args=['test prompt'], + id='test-workflow-unknown-toolset', + task_queue=TASK_QUEUE, + ) + assert isinstance(exc_info.value.__cause__, ApplicationError) + assert 'Unknown toolset name' in exc_info.value.__cause__.message + assert 'nonexistent_toolset' in exc_info.value.__cause__.message + + +# Create an agent without named toolsets for testing +agent_without_named = TemporalAgent( + Agent(TestModel(), name='no_named_agent'), + name='test_agent_no_named', +) + + +@workflow.defn +class NoNamedToolsetsWorkflow: + __pydantic_ai_agents__ = [agent_without_named] + + @workflow.run + async def run(self, user_prompt: str) -> str: + result = await agent_without_named.run(user_prompt, toolsets=['some_name']) + return result.output # pragma: no cover - workflow fails before reaching this line + + +async def test_unknown_toolset_name_when_no_named_toolsets(allow_model_requests: None, client: Client): + """Test that referencing a toolset name when no named toolsets are registered raises an error.""" + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[NoNamedToolsetsWorkflow], + activities=[], + ): + with pytest.raises(WorkflowFailureError) as exc_info: + await client.execute_workflow( + NoNamedToolsetsWorkflow.run, + args=['test prompt'], + id='test-workflow-no-named-toolsets', + task_queue=TASK_QUEUE, + ) + assert isinstance(exc_info.value.__cause__, ApplicationError) + assert 'Unknown toolset name' in exc_info.value.__cause__.message + assert 'No named toolsets registered' in exc_info.value.__cause__.message + + +async def test_temporal_agent_run_with_toolsets_outside_workflow(): + """Test that calling run() with toolsets outside a workflow works correctly.""" + + # Create a simple tool + def simple_tool(ctx: RunContext[None]) -> str: + """A simple tool.""" + return 'tool result' + + extra_toolset = FunctionToolset(tools=[simple_tool], id='extra_tools') + wrapped_extra = TemporalFunctionToolset( + extra_toolset, + activity_name_prefix='extra', + deps_type=type(None), + ) + + agent = TemporalAgent( + Agent(TestModel(), name='test_with_toolsets'), + name='test_agent_with_toolsets', + ) + + # Outside workflow, toolsets should be applied correctly + result = await agent.run('test', toolsets=[wrapped_extra]) + assert result.output is not None + + +# Create an unwrapped DynamicToolset for testing + + +def _unwrapped_dynamic_toolset_func(ctx: RunContext[None]) -> FunctionToolset[None]: + return FunctionToolset[None]( + id='inner_toolset' + ) # pragma: no cover - test helper function not executed during error test + + +_unwrapped_dynamic_toolset = DynamicToolset(_unwrapped_dynamic_toolset_func, id='dynamic') + + +@workflow.defn +class WorkflowWithUnwrappedDynamicToolset: + @workflow.run + async def run(self, prompt: str) -> str: + result = await simple_temporal_agent.run(prompt, toolsets=[_unwrapped_dynamic_toolset]) + return result.output # pragma: no cover + + +async def test_temporal_agent_run_with_unwrapped_dynamic_toolset_error(allow_model_requests: None, client: Client): + """Test that passing an unwrapped DynamicToolset at runtime raises an error.""" + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[WorkflowWithUnwrappedDynamicToolset], + plugins=[AgentPlugin(simple_temporal_agent)], + ): + with workflow_raises( + UserError, + snapshot( + "Toolsets provided at runtime inside a Temporal workflow must be wrapped in a `TemporalWrapperToolset`. Toolset DynamicToolset 'dynamic' must be wrapped in a `TemporalWrapperToolset`." + ), + ): + await client.execute_workflow( + WorkflowWithUnwrappedDynamicToolset.run, + args=['test'], + id=WorkflowWithUnwrappedDynamicToolset.__name__, + task_queue=TASK_QUEUE, + ) + + +# Create an unwrapped MCP server for testing + +_unwrapped_mcp_server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], timeout=20, id='mcp') + + +@workflow.defn +class WorkflowWithUnwrappedMCPServer: + @workflow.run + async def run(self, prompt: str) -> str: + result = await simple_temporal_agent.run(prompt, toolsets=[_unwrapped_mcp_server]) + return result.output # pragma: no cover + + +async def test_temporal_agent_run_with_unwrapped_mcp_server_error(client: Client): + """Test that passing an unwrapped MCPServer at runtime raises an error.""" + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[WorkflowWithUnwrappedMCPServer], + plugins=[AgentPlugin(simple_temporal_agent)], + ): + with workflow_raises( + UserError, + snapshot( + "Toolsets provided at runtime inside a Temporal workflow must be wrapped in a `TemporalWrapperToolset`. Toolset MCPServerStdio 'mcp' must be wrapped in a `TemporalWrapperToolset`." + ), + ): + await client.execute_workflow( + WorkflowWithUnwrappedMCPServer.run, + args=['test'], + id=WorkflowWithUnwrappedMCPServer.__name__, + task_queue=TASK_QUEUE, + ) + + +# Create an unwrapped FastMCP toolset for testing + +_unwrapped_fastmcp_toolset = FastMCPToolset('https://mcp.deepwiki.com/mcp', id='deepwiki') + + +@workflow.defn +class WorkflowWithUnwrappedFastMCPToolset: + @workflow.run + async def run(self, prompt: str) -> str: + result = await simple_temporal_agent.run(prompt, toolsets=[_unwrapped_fastmcp_toolset]) + return result.output # pragma: no cover + + +async def test_temporal_agent_run_with_unwrapped_fastmcp_toolset_error(allow_model_requests: None, client: Client): + """Test that passing an unwrapped FastMCPToolset at runtime raises an error.""" + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[WorkflowWithUnwrappedFastMCPToolset], + plugins=[AgentPlugin(simple_temporal_agent)], + ): + with workflow_raises( + UserError, + snapshot( + "Toolsets provided at runtime inside a Temporal workflow must be wrapped in a `TemporalWrapperToolset`. Toolset FastMCPToolset 'deepwiki' must be wrapped in a `TemporalWrapperToolset`." + ), + ): + await client.execute_workflow( + WorkflowWithUnwrappedFastMCPToolset.run, + args=['test'], + id=WorkflowWithUnwrappedFastMCPToolset.__name__, + task_queue=TASK_QUEUE, + ) + + +def test_validate_temporal_toolsets_logic(): + from collections.abc import Callable + from typing import Any + + from pydantic_ai import FunctionToolset + from pydantic_ai.durable_exec.temporal import TemporalWrapperToolset + from pydantic_ai.durable_exec.temporal._agent import ( + _validate_temporal_toolsets, # pyright: ignore[reportPrivateUsage] + ) + + class MockTemporalWrapper(TemporalWrapperToolset): + @property + def temporal_activities(self) -> list[Callable[..., Any]]: + return [] # pragma: no cover + + def tool_func(x: int) -> int: + return x # pragma: no cover + + # 1. Test Wrapped Toolset (Should Pass) + func_toolset = FunctionToolset(tools=[tool_func], id='test_func_pass') + wrapper = MockTemporalWrapper(func_toolset) + _validate_temporal_toolsets([wrapper]) + + # 2. Test Unwrapped Toolset (Should Fail) + func_toolset_fail = FunctionToolset(tools=[tool_func], id='test_func_fail') + with pytest.raises(UserError, match='must be wrapped in a `TemporalWrapperToolset`'): + _validate_temporal_toolsets([func_toolset_fail]) + + +def test_resolve_toolsets_logic(): + from collections.abc import Callable + from typing import Any + + from pydantic_ai import Agent, FunctionToolset + from pydantic_ai.durable_exec.temporal import TemporalAgent, TemporalWrapperToolset + + class MockWrapper(TemporalWrapperToolset): + @property + def temporal_activities(self) -> list[Callable[..., Any]]: + return [] # pragma: no cover + + base_agent = Agent(TestModel(), name='test-agent') + + # 1. No toolsets (None) + agent = TemporalAgent(base_agent) + assert agent._resolve_toolsets(None) is None # pyright: ignore[reportPrivateUsage] + + # 2. String lookup + t1 = FunctionToolset(tools=[], id='t1') + w1 = MockWrapper(t1) + agent_with_tools = TemporalAgent(base_agent, toolsets={'my_tool': w1}) + + # Found + assert agent_with_tools._resolve_toolsets(['my_tool']) == [w1] # pyright: ignore[reportPrivateUsage] + + # Not found - No named toolsets registered + with pytest.raises(UserError, match=r"Unknown toolset name: 'missing'. No named toolsets registered."): + agent._resolve_toolsets(['missing']) # pyright: ignore[reportPrivateUsage] + + # Not found - With available toolsets + with pytest.raises(UserError, match=r"Unknown toolset name: 'missing'. Available toolsets: \['my_tool'\]"): + agent_with_tools._resolve_toolsets(['missing']) # pyright: ignore[reportPrivateUsage] + + # 3. Already a wrapper + assert agent._resolve_toolsets([w1]) == [w1] # pyright: ignore[reportPrivateUsage] + + # 4. Original instance auto-resolution - wrapper found + # w1 wraps t1, so passing t1 should return w1 if it's in named_toolsets + result = agent_with_tools._resolve_toolsets([t1]) # pyright: ignore[reportPrivateUsage] + assert result is not None + assert result == [w1] + assert len(result) == 1 + assert isinstance(result[0], MockWrapper) + + # 5. Original instance not found - returns as-is + t2 = FunctionToolset(tools=[], id='t2') + result = agent_with_tools._resolve_toolsets([t2]) # pyright: ignore[reportPrivateUsage] + assert result is not None + assert result == [t2] + assert len(result) == 1 + assert result[0] is t2