Skip to content

Commit ad5ba4f

Browse files
authored
Allow TemporalAgent to switch model at agent.run-time (#3537)
1 parent e6b7f1f commit ad5ba4f

File tree

10 files changed

+695
-81
lines changed

10 files changed

+695
-81
lines changed

docs/durable_execution/temporal.md

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,86 @@ As the streaming model request activity, workflow, and workflow execution call a
184184
- To get data from the workflow call site or workflow to the event stream handler, you can use a [dependencies object](#agent-run-context-and-dependencies).
185185
- To get data from the event stream handler to the workflow, workflow call site, or a frontend, you need to use an external system that the event stream handler can write to and the event consumer can read from, like a message queue. You can use the dependency object to make sure the same connection string or other unique ID is available in all the places that need it.
186186

187+
### Model Selection at Runtime
188+
189+
[`Agent.run(model=...)`][pydantic_ai.agent.Agent.run] normally supports both model strings (like `'openai:gpt-5.2'`) and model instances. However, `TemporalAgent` does not support arbitrary model instances because they cannot be serialized for Temporal's replay mechanism.
190+
191+
To use model instances with `TemporalAgent`, you need to pre-register them by passing a dict of model instances to `TemporalAgent(models={...})`. You can then reference them by name or by passing the registered instance directly. If the wrapped agent doesn't have a model set, the first registered model will be used as the default.
192+
193+
Model strings work as expected. For scenarios where you need to customize the provider used by the model string (e.g., inject API keys from deps), you can pass a `provider_factory` to `TemporalAgent`; it is called with the [`RunContext`][pydantic_ai.tools.RunContext] and the provider name.
194+
195+
Here's an example showing how to pre-register and use multiple models:
196+
197+
```python {title="multi_model_temporal.py" test="skip"}
from dataclasses import dataclass
from typing import Any

from temporalio import workflow

from pydantic_ai import Agent
from pydantic_ai.durable_exec.temporal import TemporalAgent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.models.openai import OpenAIResponsesModel
from pydantic_ai.providers import Provider
from pydantic_ai.tools import RunContext


@dataclass
class Deps:
    openai_api_key: str | None = None
    anthropic_api_key: str | None = None


# Create models from different providers
default_model = OpenAIResponsesModel('gpt-5.2')
fast_model = AnthropicModel('claude-sonnet-4-5')
reasoning_model = GoogleModel('gemini-2.5-pro')


# Optional: provider factory for dynamic model configuration
def my_provider_factory(run_context: RunContext[Deps], provider_name: str) -> Provider[Any]:
    """Create providers with custom configuration based on run context."""
    if provider_name == 'openai':
        from pydantic_ai.providers.openai import OpenAIProvider

        return OpenAIProvider(api_key=run_context.deps.openai_api_key)
    elif provider_name == 'anthropic':
        from pydantic_ai.providers.anthropic import AnthropicProvider

        return AnthropicProvider(api_key=run_context.deps.anthropic_api_key)
    else:
        raise ValueError(f'Unknown provider: {provider_name}')


agent = Agent(default_model, name='multi_model_agent', deps_type=Deps)

temporal_agent = TemporalAgent(
    agent,
    models={
        'fast': fast_model,
        'reasoning': reasoning_model,
    },
    provider_factory=my_provider_factory,  # Optional
)


@workflow.defn
class MultiModelWorkflow:
    @workflow.run
    async def run(self, prompt: str, use_reasoning: bool, use_fast: bool) -> str:
        if use_reasoning:
            # Select by registered name
            result = await temporal_agent.run(prompt, model='reasoning')
        elif use_fast:
            # Or pass the registered instance directly
            result = await temporal_agent.run(prompt, model=fast_model)
        else:
            # Or pass a model string (uses provider_factory if set)
            result = await temporal_agent.run(prompt, model='openai:gpt-4.1-mini')
        return result.output
```
266+
187267
## Activity Configuration
188268

189269
Temporal activity configuration, like timeouts and retry policies, can be customized by passing [`temporalio.workflow.ActivityConfig`](https://python.temporal.io/temporalio.workflow.ActivityConfig.html) objects to the `TemporalAgent` constructor:

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pydantic_graph.nodes import End, NodeRunEndT
2727

2828
from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
29+
from ._run_context import set_current_run_context
2930
from .exceptions import ToolRetryError
3031
from .output import OutputDataT, OutputSpec
3132
from .settings import ModelSettings
@@ -447,25 +448,26 @@ async def stream(
447448
assert not self._did_stream, 'stream() should only be called once per node'
448449

449450
model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx)
450-
async with ctx.deps.model.request_stream(
451-
message_history, model_settings, model_request_parameters, run_context
452-
) as streamed_response:
453-
self._did_stream = True
454-
ctx.state.usage.requests += 1
455-
agent_stream = result.AgentStream[DepsT, T](
456-
_raw_stream_response=streamed_response,
457-
_output_schema=ctx.deps.output_schema,
458-
_model_request_parameters=model_request_parameters,
459-
_output_validators=ctx.deps.output_validators,
460-
_run_ctx=build_run_context(ctx),
461-
_usage_limits=ctx.deps.usage_limits,
462-
_tool_manager=ctx.deps.tool_manager,
463-
)
464-
yield agent_stream
465-
# In case the user didn't manually consume the full stream, ensure it is fully consumed here,
466-
# otherwise usage won't be properly counted:
467-
async for _ in agent_stream:
468-
pass
451+
with set_current_run_context(run_context):
452+
async with ctx.deps.model.request_stream(
453+
message_history, model_settings, model_request_parameters, run_context
454+
) as streamed_response:
455+
self._did_stream = True
456+
ctx.state.usage.requests += 1
457+
agent_stream = result.AgentStream[DepsT, T](
458+
_raw_stream_response=streamed_response,
459+
_output_schema=ctx.deps.output_schema,
460+
_model_request_parameters=model_request_parameters,
461+
_output_validators=ctx.deps.output_validators,
462+
_run_ctx=build_run_context(ctx),
463+
_usage_limits=ctx.deps.usage_limits,
464+
_tool_manager=ctx.deps.tool_manager,
465+
)
466+
yield agent_stream
467+
# In case the user didn't manually consume the full stream, ensure it is fully consumed here,
468+
# otherwise usage won't be properly counted:
469+
async for _ in agent_stream:
470+
pass
469471

470472
model_response = streamed_response.get()
471473

@@ -478,8 +480,9 @@ async def _make_request(
478480
if self._result is not None:
479481
return self._result # pragma: no cover
480482

481-
model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx)
482-
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
483+
model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx)
484+
with set_current_run_context(run_context):
485+
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
483486
ctx.state.usage.requests += 1
484487

485488
return self._finish_handling(ctx, model_response)

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations as _annotations
22

33
import dataclasses
4-
from collections.abc import Sequence
4+
from collections.abc import Iterator, Sequence
5+
from contextlib import contextmanager
6+
from contextvars import ContextVar
57
from dataclasses import field
68
from typing import TYPE_CHECKING, Any, Generic
79

@@ -71,3 +73,36 @@ def last_attempt(self) -> bool:
7173
return self.retry == self.max_retries
7274

7375
__repr__ = _utils.dataclasses_no_defaults_repr
76+
77+
78+
_CURRENT_RUN_CONTEXT: ContextVar[RunContext[Any] | None] = ContextVar(
    'pydantic_ai.current_run_context',
    default=None,
)
"""Context variable storing the current [`RunContext`][pydantic_ai.tools.RunContext]."""


def get_current_run_context() -> RunContext[Any] | None:
    """Get the current run context, if one is set.

    Returns:
        The current [`RunContext`][pydantic_ai.tools.RunContext], or `None` if not in an agent run.
    """
    return _CURRENT_RUN_CONTEXT.get()


@contextmanager
def set_current_run_context(run_context: RunContext[Any]) -> Iterator[None]:
    """Context manager to set the current run context.

    Args:
        run_context: The run context to set as current.

    Yields:
        None
    """
    token = _CURRENT_RUN_CONTEXT.set(run_context)
    try:
        yield
    finally:
        _CURRENT_RUN_CONTEXT.reset(token)

0 commit comments

Comments
 (0)