|
71 | 71 | PydanticAIWorkflow, |
72 | 72 | TemporalAgent, |
73 | 73 | ) |
| 74 | + from pydantic_ai.durable_exec.temporal._dynamic_toolset import TemporalDynamicToolset |
74 | 75 | from pydantic_ai.durable_exec.temporal._function_toolset import TemporalFunctionToolset |
75 | 76 | from pydantic_ai.durable_exec.temporal._mcp_server import TemporalMCPServer |
76 | 77 | from pydantic_ai.durable_exec.temporal._model import TemporalModel |
@@ -3117,3 +3118,281 @@ async def test_dynamic_agent_with_named_toolset(allow_model_requests: None, clie |
3117 | 3118 | # Verify tool was called successfully |
3118 | 3119 | assert 'echo' in result |
3119 | 3120 | assert 'echo:' in result |
| 3121 | + |
| 3122 | + |
| 3123 | +def test_temporal_dynamic_toolset_errors(): |
| 3124 | + """Test that TemporalDynamicToolset raises errors when activity_name_prefix or deps_type is None.""" |
| 3125 | + from pydantic_ai.toolsets._dynamic import DynamicToolset |
| 3126 | + |
| 3127 | + def toolset_func(ctx: RunContext[None]) -> FunctionToolset[None]: |
| 3128 | + return FunctionToolset[None](id='test') |
| 3129 | + |
| 3130 | + dynamic_toolset = DynamicToolset(toolset_func, id='test_toolset') |
| 3131 | + |
| 3132 | + # Test missing activity_name_prefix |
| 3133 | + with pytest.raises(UserError, match='activity_name_prefix is required'): |
| 3134 | + TemporalDynamicToolset(dynamic_toolset, activity_name_prefix=None, deps_type=type(None)) |
| 3135 | + |
| 3136 | + # Test missing deps_type |
| 3137 | + with pytest.raises(UserError, match='deps_type is required'): |
| 3138 | + TemporalDynamicToolset(dynamic_toolset, activity_name_prefix='test', deps_type=None) |
| 3139 | + |
| 3140 | + |
| 3141 | +async def test_validate_temporal_toolsets_errors(allow_model_requests: None, client: Client): |
| 3142 | + """Test that _validate_temporal_toolsets raises errors for unwrapped toolsets.""" |
| 3143 | + from pydantic_ai.toolsets._dynamic import DynamicToolset |
| 3144 | + |
| 3145 | + # Create agents with unwrapped toolsets |
| 3146 | + def toolset_func(ctx: RunContext[None]) -> FunctionToolset[None]: |
| 3147 | + return FunctionToolset[None](id='test') |
| 3148 | + |
| 3149 | + dynamic_toolset = DynamicToolset(toolset_func, id='test_dynamic') |
| 3150 | + mcp_server = MCPServerStdio('echo', ['hello']) |
| 3151 | + fastmcp_toolset = FastMCPToolset('http://example.com', id='test_fastmcp') |
| 3152 | + |
| 3153 | + # Test DynamicToolset error |
| 3154 | + base_agent_dynamic = Agent(model, name='agent_with_dynamic', toolsets=[dynamic_toolset]) |
| 3155 | + agent_with_dynamic = TemporalAgent(base_agent_dynamic) |
| 3156 | + |
| 3157 | + @workflow.defn(name='TestDynamicToolsetError') |
| 3158 | + class TestDynamicToolsetErrorWorkflow: |
| 3159 | + @workflow.run |
| 3160 | + async def run(self) -> str: |
| 3161 | + result = await agent_with_dynamic.run('test') |
| 3162 | + return result.output |
| 3163 | + |
| 3164 | + async with Worker( |
| 3165 | + client, |
| 3166 | + task_queue=TASK_QUEUE, |
| 3167 | + workflows=[TestDynamicToolsetErrorWorkflow], |
| 3168 | + ): |
| 3169 | + with workflow_raises( |
| 3170 | + UserError, |
| 3171 | + 'Toolsets of type DynamicToolset must be wrapped with TemporalDynamicToolset before being added to a TemporalAgent.', |
| 3172 | + ): |
| 3173 | + await client.execute_workflow( |
| 3174 | + TestDynamicToolsetErrorWorkflow.run, |
| 3175 | + id='test-workflow-dynamic-error', |
| 3176 | + task_queue=TASK_QUEUE, |
| 3177 | + ) |
| 3178 | + |
| 3179 | + # Test MCPServer error |
| 3180 | + base_agent_mcp = Agent(model, name='agent_with_mcp', toolsets=[mcp_server]) |
| 3181 | + agent_with_mcp = TemporalAgent(base_agent_mcp) |
| 3182 | + |
| 3183 | + @workflow.defn(name='TestMCPServerError') |
| 3184 | + class TestMCPServerErrorWorkflow: |
| 3185 | + @workflow.run |
| 3186 | + async def run(self) -> str: |
| 3187 | + result = await agent_with_mcp.run('test') |
| 3188 | + return result.output |
| 3189 | + |
| 3190 | + async with Worker( |
| 3191 | + client, |
| 3192 | + task_queue=TASK_QUEUE, |
| 3193 | + workflows=[TestMCPServerErrorWorkflow], |
| 3194 | + ): |
| 3195 | + with workflow_raises( |
| 3196 | + UserError, |
| 3197 | + 'Toolsets of type MCPServer must be wrapped with TemporalMCPServer before being added to a TemporalAgent.', |
| 3198 | + ): |
| 3199 | + await client.execute_workflow( |
| 3200 | + TestMCPServerErrorWorkflow.run, |
| 3201 | + id='test-workflow-mcp-error', |
| 3202 | + task_queue=TASK_QUEUE, |
| 3203 | + ) |
| 3204 | + |
| 3205 | + # Test FastMCPToolset error |
| 3206 | + base_agent_fastmcp = Agent(model, name='agent_with_fastmcp', toolsets=[fastmcp_toolset]) |
| 3207 | + agent_with_fastmcp = TemporalAgent(base_agent_fastmcp) |
| 3208 | + |
| 3209 | + @workflow.defn(name='TestFastMCPToolsetError') |
| 3210 | + class TestFastMCPToolsetErrorWorkflow: |
| 3211 | + @workflow.run |
| 3212 | + async def run(self) -> str: |
| 3213 | + result = await agent_with_fastmcp.run('test') |
| 3214 | + return result.output |
| 3215 | + |
| 3216 | + async with Worker( |
| 3217 | + client, |
| 3218 | + task_queue=TASK_QUEUE, |
| 3219 | + workflows=[TestFastMCPToolsetErrorWorkflow], |
| 3220 | + ): |
| 3221 | + with workflow_raises( |
| 3222 | + UserError, |
| 3223 | + 'Toolsets of type FastMCPToolset must be wrapped with TemporalFastMCPToolset before being added to a TemporalAgent.', |
| 3224 | + ): |
| 3225 | + await client.execute_workflow( |
| 3226 | + TestFastMCPToolsetErrorWorkflow.run, |
| 3227 | + id='test-workflow-fastmcp-error', |
| 3228 | + task_queue=TASK_QUEUE, |
| 3229 | + ) |
| 3230 | + |
| 3231 | + |
| 3232 | +async def test_temporal_agent_toolsets_as_list(allow_model_requests: None, client: Client): |
| 3233 | + """Test passing toolsets as a list instead of a mapping.""" |
| 3234 | + |
| 3235 | + # Create agent with toolsets as a list |
| 3236 | + toolset1 = FunctionToolset[None](id='toolset1') |
| 3237 | + |
| 3238 | + @toolset1.tool |
| 3239 | + def echo_tool(ctx: RunContext[None], text: str) -> str: |
| 3240 | + return f'echo: {text}' |
| 3241 | + |
| 3242 | + wrapped_toolset1 = TemporalFunctionToolset(toolset1, deps_type=type(None)) |
| 3243 | + |
| 3244 | + base_agent = Agent( |
| 3245 | + model, |
| 3246 | + name='agent_list_toolsets', |
| 3247 | + toolsets=[wrapped_toolset1], # List instead of mapping |
| 3248 | + ) |
| 3249 | + agent = TemporalAgent(base_agent) |
| 3250 | + |
| 3251 | + @workflow.defn(name='TestListToolsets') |
| 3252 | + class TestListToolsetsWorkflow: |
| 3253 | + @workflow.run |
| 3254 | + async def run(self) -> str: |
| 3255 | + result = await agent.run('test') |
| 3256 | + return result.output |
| 3257 | + |
| 3258 | + async with Worker( |
| 3259 | + client, |
| 3260 | + task_queue=TASK_QUEUE, |
| 3261 | + workflows=[TestListToolsetsWorkflow], |
| 3262 | + activities=[*wrapped_toolset1.temporal_activities], |
| 3263 | + ): |
| 3264 | + result = await client.execute_workflow( |
| 3265 | + TestListToolsetsWorkflow.run, |
| 3266 | + id='test-workflow-list-toolsets', |
| 3267 | + task_queue=TASK_QUEUE, |
| 3268 | + ) |
| 3269 | + assert 'echo' in result |
| 3270 | + |
| 3271 | + |
| 3272 | +async def test_temporal_agent_override_outside_workflow(): |
| 3273 | + """Test that override works outside workflow with toolsets.""" |
| 3274 | + |
| 3275 | + toolset = FunctionToolset[None](id='test_toolset') |
| 3276 | + |
| 3277 | + @toolset.tool |
| 3278 | + def echo_tool(ctx: RunContext[None], text: str) -> str: |
| 3279 | + return f'echo: {text}' |
| 3280 | + |
| 3281 | + wrapped_toolset = TemporalFunctionToolset(toolset, deps_type=type(None)) |
| 3282 | + |
| 3283 | + base_agent = Agent(model, name='agent_override_outside') |
| 3284 | + agent = TemporalAgent(base_agent) |
| 3285 | + |
| 3286 | + # Test override outside workflow |
| 3287 | + with agent.override(toolsets=[wrapped_toolset]): |
| 3288 | + result = await agent.run('test prompt') |
| 3289 | + assert 'echo' in result.output |
| 3290 | + |
| 3291 | + |
| 3292 | +async def test_temporal_agent_unknown_toolset_name(allow_model_requests: None, client: Client): |
| 3293 | + """Test that using an unknown toolset name raises an error.""" |
| 3294 | + |
| 3295 | + toolset = FunctionToolset[None](id='test_toolset') |
| 3296 | + |
| 3297 | + @toolset.tool |
| 3298 | + def echo_tool(ctx: RunContext[None], text: str) -> str: |
| 3299 | + return f'echo: {text}' |
| 3300 | + |
| 3301 | + wrapped_toolset = TemporalFunctionToolset(toolset, deps_type=type(None)) |
| 3302 | + |
| 3303 | + base_agent = Agent( |
| 3304 | + model, |
| 3305 | + name='agent_unknown_toolset', |
| 3306 | + ) |
| 3307 | + agent = TemporalAgent(base_agent, toolsets={'known_toolset': wrapped_toolset}) |
| 3308 | + |
| 3309 | + @workflow.defn(name='TestUnknownToolsetName') |
| 3310 | + class TestUnknownToolsetNameWorkflow: |
| 3311 | + @workflow.run |
| 3312 | + async def run(self) -> str: |
| 3313 | + # Try to use an unknown toolset name |
| 3314 | + result = await agent.run('test', toolsets=['unknown_toolset']) |
| 3315 | + return result.output |
| 3316 | + |
| 3317 | + async with Worker( |
| 3318 | + client, |
| 3319 | + task_queue=TASK_QUEUE, |
| 3320 | + workflows=[TestUnknownToolsetNameWorkflow], |
| 3321 | + activities=[*wrapped_toolset.temporal_activities], |
| 3322 | + ): |
| 3323 | + with workflow_raises( |
| 3324 | + UserError, |
| 3325 | + "Unknown toolset name: 'unknown_toolset'. Available toolsets: ['known_toolset']", |
| 3326 | + ): |
| 3327 | + await client.execute_workflow( |
| 3328 | + TestUnknownToolsetNameWorkflow.run, |
| 3329 | + id='test-workflow-unknown-toolset', |
| 3330 | + task_queue=TASK_QUEUE, |
| 3331 | + ) |
| 3332 | + |
| 3333 | + |
| 3334 | +async def test_temporal_agent_no_named_toolsets(allow_model_requests: None, client: Client): |
| 3335 | + """Test that using a toolset name when no named toolsets are registered raises an error.""" |
| 3336 | + base_agent = Agent( |
| 3337 | + model, |
| 3338 | + name='agent_no_named_toolsets', |
| 3339 | + ) |
| 3340 | + agent = TemporalAgent(base_agent) |
| 3341 | + |
| 3342 | + @workflow.defn(name='TestNoNamedToolsets') |
| 3343 | + class TestNoNamedToolsetsWorkflow: |
| 3344 | + @workflow.run |
| 3345 | + async def run(self) -> str: |
| 3346 | + # Try to use a toolset name when none are registered |
| 3347 | + result = await agent.run('test', toolsets=['some_toolset']) |
| 3348 | + return result.output |
| 3349 | + |
| 3350 | + async with Worker( |
| 3351 | + client, |
| 3352 | + task_queue=TASK_QUEUE, |
| 3353 | + workflows=[TestNoNamedToolsetsWorkflow], |
| 3354 | + ): |
| 3355 | + with workflow_raises( |
| 3356 | + UserError, |
| 3357 | + "Unknown toolset name: 'some_toolset'. No named toolsets registered.", |
| 3358 | + ): |
| 3359 | + await client.execute_workflow( |
| 3360 | + TestNoNamedToolsetsWorkflow.run, |
| 3361 | + id='test-workflow-no-named-toolsets', |
| 3362 | + task_queue=TASK_QUEUE, |
| 3363 | + ) |
| 3364 | + |
| 3365 | + |
| 3366 | +async def test_temporal_mcp_with_deps_type(allow_model_requests: None, client: Client): |
| 3367 | + """Test TemporalMCPServer with deps_type set.""" |
| 3368 | + |
| 3369 | + @dataclass |
| 3370 | + class MCPDeps: |
| 3371 | + value: str = 'test' |
| 3372 | + |
| 3373 | + mock_server = MCPServerStdio('echo', ['hello']) |
| 3374 | + wrapped_mcp = TemporalMCPServer(mock_server, activity_name_prefix='test_mcp', deps_type=MCPDeps) |
| 3375 | + |
| 3376 | + base_agent = Agent(model, name='agent_mcp_deps', toolsets=[wrapped_mcp], deps_type=MCPDeps) |
| 3377 | + agent = TemporalAgent(base_agent) |
| 3378 | + |
| 3379 | + @workflow.defn(name='TestMCPWithDepsType') |
| 3380 | + class TestMCPWithDepsTypeWorkflow: |
| 3381 | + @workflow.run |
| 3382 | + async def run(self, deps: MCPDeps) -> str: |
| 3383 | + result = await agent.run('test', deps=deps) |
| 3384 | + return result.output |
| 3385 | + |
| 3386 | + async with Worker( |
| 3387 | + client, |
| 3388 | + task_queue=TASK_QUEUE, |
| 3389 | + workflows=[TestMCPWithDepsTypeWorkflow], |
| 3390 | + activities=[*wrapped_mcp.temporal_activities], |
| 3391 | + ): |
| 3392 | + result = await client.execute_workflow( |
| 3393 | + TestMCPWithDepsTypeWorkflow.run, |
| 3394 | + args=[MCPDeps()], |
| 3395 | + id='test-workflow-mcp-deps', |
| 3396 | + task_queue=TASK_QUEUE, |
| 3397 | + ) |
| 3398 | + assert isinstance(result, str) |
0 commit comments