Skip to content

Commit 9a6fe3b

Browse files
committed
Allow dynamic in Temporal workflows
1 parent 6c32588 commit 9a6fe3b

File tree

4 files changed

+120
-11
lines changed

4 files changed

+120
-11
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from ...exceptions import UserError
1616
from ._agent import TemporalAgent
17+
from ._function_toolset import TemporalFunctionToolset
1718
from ._logfire import LogfirePlugin
1819
from ._run_context import TemporalRunContext
1920
from ._toolset import TemporalWrapperToolset
@@ -26,6 +27,7 @@
2627
'AgentPlugin',
2728
'TemporalRunContext',
2829
'TemporalWrapperToolset',
30+
'TemporalFunctionToolset',
2931
'PydanticAIWorkflow',
3032
]
3133

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,21 @@ def temporal_activities(self) -> list[Callable[..., Any]]:
242242
return self._temporal_activities
243243

244244
@contextmanager
245-
def _temporal_overrides(self) -> Iterator[None]:
246-
# We reset tools here as the temporalized function toolset is already in self._toolsets.
245+
def _temporal_overrides(self, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None) -> Iterator[None]:
246+
in_workflow = workflow.in_workflow()
247+
248+
if toolsets:
249+
if in_workflow and any(not isinstance(t, TemporalWrapperToolset) for t in toolsets):
250+
raise UserError(
251+
'Toolsets provided at runtime inside a Temporal workflow must be wrapped in a `TemporalWrapperToolset`.'
252+
)
253+
overridden_toolsets = [*self._toolsets, *toolsets]
254+
else:
255+
overridden_toolsets = list(self._toolsets)
256+
257+
# We reset tools here as the temporalized function toolset is already in overridden_toolsets.
247258
with (
248-
super().override(model=self._model, toolsets=self._toolsets, tools=[]),
259+
super().override(model=self._model, toolsets=overridden_toolsets, tools=[]),
249260
_utils.disable_threads(),
250261
):
251262
temporal_active_token = self._temporal_overrides_active.set(True)
@@ -359,7 +370,7 @@ async def main():
359370
'Event stream handler cannot be set at agent run time inside a Temporal workflow, it must be set at agent creation time.'
360371
)
361372

362-
with self._temporal_overrides():
373+
with self._temporal_overrides(toolsets=toolsets):
363374
return await super().run(
364375
user_prompt,
365376
output_type=output_type,
@@ -372,7 +383,7 @@ async def main():
372383
usage_limits=usage_limits,
373384
usage=usage,
374385
infer_name=infer_name,
375-
toolsets=toolsets,
386+
toolsets=None, # Toolsets are set via _temporal_overrides
376387
builtin_tools=builtin_tools,
377388
event_stream_handler=event_stream_handler or self.event_stream_handler,
378389
**_deprecated_kwargs,
@@ -895,9 +906,9 @@ async def main():
895906
raise UserError(
896907
'Model cannot be set at agent run time inside a Temporal workflow, it must be set at agent creation time.'
897908
)
898-
if toolsets is not None:
909+
if toolsets is not None and any(not isinstance(t, TemporalWrapperToolset) for t in toolsets):
899910
raise UserError(
900-
'Toolsets cannot be set at agent run time inside a Temporal workflow, it must be set at agent creation time.'
911+
'Toolsets cannot be set at agent run time inside a Temporal workflow, unless they are wrapped in a `TemporalWrapperToolset`.'
901912
)
902913

903914
async with super().iter(
@@ -947,9 +958,9 @@ def override(
947958
raise UserError(
948959
'Model cannot be contextually overridden inside a Temporal workflow, it must be set at agent creation time.'
949960
)
950-
if _utils.is_set(toolsets):
961+
if _utils.is_set(toolsets) and any(not isinstance(t, TemporalWrapperToolset) for t in toolsets):
951962
raise UserError(
952-
'Toolsets cannot be contextually overridden inside a Temporal workflow, they must be set at agent creation time.'
963+
'Toolsets cannot be contextually overridden inside a Temporal workflow, unless they are wrapped in a `TemporalWrapperToolset`.'
953964
)
954965
if _utils.is_set(tools):
955966
raise UserError(

tests/test_temporal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,7 +1512,7 @@ async def test_temporal_agent_run_in_workflow_with_toolsets(allow_model_requests
15121512
with workflow_raises(
15131513
UserError,
15141514
snapshot(
1515-
'Toolsets cannot be set at agent run time inside a Temporal workflow, it must be set at agent creation time.'
1515+
'Toolsets provided at runtime inside a Temporal workflow must be wrapped in a `TemporalWrapperToolset`.'
15161516
),
15171517
):
15181518
await client.execute_workflow(
@@ -1570,7 +1570,7 @@ async def test_temporal_agent_override_toolsets_in_workflow(allow_model_requests
15701570
with workflow_raises(
15711571
UserError,
15721572
snapshot(
1573-
'Toolsets cannot be contextually overridden inside a Temporal workflow, they must be set at agent creation time.'
1573+
'Toolsets cannot be contextually overridden inside a Temporal workflow, unless they are wrapped in a `TemporalWrapperToolset`.'
15741574
),
15751575
):
15761576
await client.execute_workflow(

tests/test_temporal_dynamic.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from datetime import timedelta
2+
3+
import pytest
4+
from temporalio import workflow
5+
from temporalio.testing import WorkflowEnvironment
6+
from temporalio.worker import Worker
7+
from temporalio.workflow import ActivityConfig
8+
9+
from pydantic_ai import Agent, FunctionToolset
10+
from pydantic_ai.durable_exec.temporal import PydanticAIPlugin, TemporalAgent, TemporalFunctionToolset
11+
from pydantic_ai.models.test import TestModel
12+
13+
14+
# 1. Define Tool
15+
def echo(x: str) -> str:
16+
return f'echo: {x}'
17+
18+
19+
# 2. Create Toolset with specific ID
20+
# Using an explicit ID allows us to reference it later if needed,
21+
# though here we pass the toolset object directly.
22+
toolset = FunctionToolset(tools=[echo], id='my_tools')
23+
24+
# 3. Wrap Toolset for Temporal (DouweM pattern)
25+
wrapped_toolset = TemporalFunctionToolset(
26+
toolset,
27+
activity_name_prefix='shared_tools',
28+
activity_config=ActivityConfig(start_to_close_timeout=timedelta(minutes=1)),
29+
tool_activity_config={},
30+
deps_type=type(None),
31+
)
32+
33+
# 4. Create base agent for model activity registration
34+
# This agent's activities will be registered in the Worker
35+
# We use a known name "test_agent" so the dynamic agent can share it.
36+
base_model = TestModel()
37+
base_agent = Agent(base_model, name='test_agent')
38+
base_temporal_agent = TemporalAgent(base_agent)
39+
40+
41+
# 5. Define Workflow
42+
@workflow.defn
43+
class DynamicToolWorkflow:
44+
@workflow.run
45+
async def run(self, user_prompt: str) -> str:
46+
# Create agent dynamically within the workflow
47+
# Note: We are using TestModel which mocks LLM behavior.
48+
# We explicitly tell TestModel to call the 'echo' tool.
49+
model = TestModel(call_tools=['echo'])
50+
51+
# We reuse the name "test_agent" so that the model activities
52+
# (which are registered under that name) can be found.
53+
# TEST: Revert to run-time passing
54+
agent = Agent(
55+
model,
56+
name='test_agent',
57+
)
58+
59+
temporal_agent = TemporalAgent(agent)
60+
61+
# Pass wrapped toolset at runtime
62+
result = await temporal_agent.run(user_prompt, toolsets=[wrapped_toolset])
63+
return result.output
64+
65+
66+
# 6. Test
67+
pytestmark = pytest.mark.anyio
68+
69+
70+
async def test_dynamic_tool_registration():
71+
"""Test passing a `TemporalWrapperToolset` at runtime to a `TemporalAgent` within a workflow."""
72+
env = await WorkflowEnvironment.start_local() # type: ignore[reportUnknownMemberType]
73+
async with env:
74+
async with Worker(
75+
env.client,
76+
task_queue='test-queue',
77+
workflows=[DynamicToolWorkflow],
78+
# Register activities from both base agent and shared toolset
79+
activities=[
80+
*base_temporal_agent.temporal_activities,
81+
*wrapped_toolset.temporal_activities,
82+
],
83+
plugins=[PydanticAIPlugin()],
84+
):
85+
result = await env.client.execute_workflow(
86+
DynamicToolWorkflow.run,
87+
args=['test prompt'],
88+
id='test-workflow-run',
89+
task_queue='test-queue',
90+
)
91+
92+
# Verify tool was called successfully
93+
# TestModel generates random args, so we just verify echo was called
94+
# "echo" is the tool return value format: "echo: {arg}"
95+
assert 'echo' in result
96+
assert 'echo:' in result

0 commit comments

Comments
 (0)