Skip to content

Commit c8b0b78

Browse files
authored
Reraise workflow failure errors from OpenAI's UserError (#1060)
* Reraise workflow failure errors from OpenAI's UserError
* Docstring
* PR feedback
* PR feedback
* Fix circular dependency
* Remove a few live openai tests, the models don't always do what they need to
* Lint
1 parent a3c5370 commit c8b0b78

File tree

7 files changed

+133
-53
lines changed

7 files changed

+133
-53
lines changed

temporalio/contrib/openai_agents/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
from temporalio.contrib.openai_agents._trace_interceptor import (
1818
OpenAIAgentsTracingInterceptor,
1919
)
20+
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError
2021

2122
from . import workflow
2223

2324
__all__ = [
25+
"AgentsWorkflowError",
2426
"OpenAIAgentsPlugin",
2527
"ModelActivityParameters",
2628
"workflow",

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from agents import (
77
Agent,
8+
AgentsException,
89
Handoff,
910
RunConfig,
1011
RunContextWrapper,
@@ -21,6 +22,7 @@
2122
from temporalio import workflow
2223
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
2324
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub
25+
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError
2426

2527

2628
class TemporalOpenAIRunner(AgentRunner):
@@ -136,16 +138,28 @@ async def on_invoke(
136138
handoffs=new_handoffs,
137139
)
138140

139-
return await self._runner.run(
140-
starting_agent=convert_agent(starting_agent, None),
141-
input=input,
142-
context=context,
143-
max_turns=max_turns,
144-
hooks=hooks,
145-
run_config=run_config,
146-
previous_response_id=previous_response_id,
147-
session=session,
148-
)
141+
try:
142+
return await self._runner.run(
143+
starting_agent=convert_agent(starting_agent, None),
144+
input=input,
145+
context=context,
146+
max_turns=max_turns,
147+
hooks=hooks,
148+
run_config=run_config,
149+
previous_response_id=previous_response_id,
150+
session=session,
151+
)
152+
except AgentsException as e:
153+
# In order for workflow failures to properly fail the workflow, we need to rewrap them in
154+
# a Temporal error
155+
if e.__cause__ and workflow.is_failure_exception(e.__cause__):
156+
reraise = AgentsWorkflowError(
157+
f"Workflow failure exception in Agents Framework: {e}"
158+
)
159+
reraise.__traceback__ = e.__traceback__
160+
raise reraise from e.__cause__
161+
else:
162+
raise e
149163

150164
def run_sync(
151165
self,

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,19 @@
2424

2525
import temporalio.client
2626
import temporalio.worker
27-
from temporalio.client import ClientConfig, Plugin
27+
from temporalio.client import ClientConfig
2828
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
2929
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
30-
from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner
30+
from temporalio.contrib.openai_agents._openai_runner import (
31+
TemporalOpenAIRunner,
32+
)
3133
from temporalio.contrib.openai_agents._temporal_trace_provider import (
3234
TemporalTraceProvider,
3335
)
3436
from temporalio.contrib.openai_agents._trace_interceptor import (
3537
OpenAIAgentsTracingInterceptor,
3638
)
39+
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError
3740
from temporalio.contrib.pydantic import (
3841
PydanticPayloadConverter,
3942
ToJsonOptions,
@@ -284,6 +287,9 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
284287
config["activities"] = list(config.get("activities") or []) + [
285288
ModelActivity(self._model_provider).invoke_model_activity
286289
]
290+
config["workflow_failure_exception_types"] = list(
291+
config.get("workflow_failure_exception_types") or []
292+
) + [AgentsWorkflowError]
287293
return self.next_worker_plugin.configure_worker(config)
288294

289295
async def run_worker(self, worker: Worker) -> None:

temporalio/contrib/openai_agents/workflow.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,12 @@ class ToolSerializationError(TemporalError):
263263
To fix this error, ensure your tool returns string-convertible values or
264264
modify the tool to return a string representation of the result.
265265
"""
266+
267+
268+
class AgentsWorkflowError(TemporalError):
269+
"""Error that occurs when the agents SDK raises an error which should terminate the calling workflow or update.
270+
271+
.. warning::
272+
This exception is experimental and may change in future versions.
273+
Use with caution in production environments.
274+
"""

temporalio/worker/_workflow_instance.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def activate(
414414
# We want some errors during activation, like those that can happen
415415
# during payload conversion, to be able to fail the workflow not the
416416
# task
417-
if self._is_workflow_failure_exception(err):
417+
if self.workflow_is_failure_exception(err):
418418
try:
419419
self._set_workflow_failure(err)
420420
except Exception as inner_err:
@@ -629,7 +629,7 @@ async def run_update() -> None:
629629
# Validation failures are always update failures. We reuse
630630
# workflow failure logic to decide task failure vs update
631631
# failure after validation.
632-
if not past_validation or self._is_workflow_failure_exception(err):
632+
if not past_validation or self.workflow_is_failure_exception(err):
633633
if command is None:
634634
command = self._add_command()
635635
command.update_response.protocol_instance_id = (
@@ -1686,6 +1686,23 @@ def workflow_set_current_details(self, details: str):
16861686
self._assert_not_read_only("set current details")
16871687
self._current_details = details
16881688

1689+
def workflow_is_failure_exception(self, err: BaseException) -> bool:
1690+
# An exception is a failure instead of a task fail if it's already a
1691+
# failure error or if it is a timeout error or if it is an instance of
1692+
# any of the failure types in the worker or workflow-level setting
1693+
wf_failure_exception_types = self._defn.failure_exception_types
1694+
if self._dynamic_failure_exception_types is not None:
1695+
wf_failure_exception_types = self._dynamic_failure_exception_types
1696+
return (
1697+
isinstance(err, temporalio.exceptions.FailureError)
1698+
or isinstance(err, asyncio.TimeoutError)
1699+
or any(isinstance(err, typ) for typ in wf_failure_exception_types)
1700+
or any(
1701+
isinstance(err, typ)
1702+
for typ in self._worker_level_failure_exception_types
1703+
)
1704+
)
1705+
16891706
#### Calls from outbound impl ####
16901707
# These are in alphabetical order and all start with "_outbound_".
16911708

@@ -1939,7 +1956,7 @@ def _convert_payloads(
19391956
# Don't wrap payload conversion errors that would fail the workflow
19401957
raise
19411958
except Exception as err:
1942-
if self._is_workflow_failure_exception(err):
1959+
if self.workflow_is_failure_exception(err):
19431960
raise
19441961
raise RuntimeError("Failed decoding arguments") from err
19451962

@@ -1982,23 +1999,6 @@ def _instantiate_workflow_object(self) -> Any:
19821999

19832000
return workflow_instance
19842001

1985-
def _is_workflow_failure_exception(self, err: BaseException) -> bool:
1986-
# An exception is a failure instead of a task fail if it's already a
1987-
# failure error or if it is a timeout error or if it is an instance of
1988-
# any of the failure types in the worker or workflow-level setting
1989-
wf_failure_exception_types = self._defn.failure_exception_types
1990-
if self._dynamic_failure_exception_types is not None:
1991-
wf_failure_exception_types = self._dynamic_failure_exception_types
1992-
return (
1993-
isinstance(err, temporalio.exceptions.FailureError)
1994-
or isinstance(err, asyncio.TimeoutError)
1995-
or any(isinstance(err, typ) for typ in wf_failure_exception_types)
1996-
or any(
1997-
isinstance(err, typ)
1998-
for typ in self._worker_level_failure_exception_types
1999-
)
2000-
)
2001-
20022002
def _warn_if_unfinished_handlers(self) -> None:
20032003
def warnable(handler_executions: Iterable[HandlerExecution]):
20042004
return [
@@ -2192,7 +2192,7 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None:
21922192
err
21932193
):
21942194
self._add_command().cancel_workflow_execution.SetInParent()
2195-
elif self._is_workflow_failure_exception(err):
2195+
elif self.workflow_is_failure_exception(err):
21962196
# All other failure errors fail the workflow
21972197
self._set_workflow_failure(err)
21982198
else:

temporalio/workflow.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,9 @@ def workflow_get_current_details(self) -> str: ...
897897
@abstractmethod
898898
def workflow_set_current_details(self, details: str): ...
899899

900+
@abstractmethod
901+
def workflow_is_failure_exception(self, err: BaseException) -> bool: ...
902+
900903

901904
_current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar(
902905
"__temporal_current_update_info"
@@ -981,6 +984,15 @@ def memo() -> Mapping[str, Any]:
981984
return _Runtime.current().workflow_memo()
982985

983986

987+
def is_failure_exception(err: BaseException) -> bool:
988+
"""Checks if the given exception is a workflow failure in the current workflow.
989+
990+
Returns:
991+
True if the given exception is a workflow failure in the current workflow.
992+
"""
993+
return _Runtime.current().workflow_is_failure_exception(err)
994+
995+
984996
@overload
985997
def memo_value(key: str, default: Any = temporalio.common._arg_unset) -> Any: ...
986998

tests/contrib/openai_agents/test_openai.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,10 @@ async def run(self, question: str) -> str:
318318
ActivityWeatherService.get_weather_method,
319319
start_to_close_timeout=timedelta(seconds=10),
320320
),
321+
openai_agents.workflow.activity_as_tool(
322+
get_weather_failure,
323+
start_to_close_timeout=timedelta(seconds=10),
324+
),
321325
],
322326
)
323327
result = await Runner.run(
@@ -462,6 +466,53 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
462466
)
463467

464468

469+
@activity.defn
470+
async def get_weather_failure(city: str) -> Weather:
471+
"""
472+
Get the weather for a given city.
473+
"""
474+
raise ApplicationError("No weather", non_retryable=True)
475+
476+
477+
class TestWeatherFailureModel(StaticTestModel):
478+
responses = [
479+
ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather_failure"),
480+
]
481+
482+
483+
async def test_tool_failure_workflow(client: Client):
484+
new_config = client.config()
485+
new_config["plugins"] = [
486+
openai_agents.OpenAIAgentsPlugin(
487+
model_params=ModelActivityParameters(
488+
start_to_close_timeout=timedelta(seconds=30)
489+
),
490+
model_provider=TestModelProvider(TestWeatherFailureModel()),
491+
)
492+
]
493+
client = Client(**new_config)
494+
495+
async with new_worker(
496+
client,
497+
ToolsWorkflow,
498+
activities=[
499+
get_weather_failure,
500+
],
501+
) as worker:
502+
workflow_handle = await client.start_workflow(
503+
ToolsWorkflow.run,
504+
"What is the weather in Tokio?",
505+
id=f"tools-failure-workflow-{uuid.uuid4()}",
506+
task_queue=worker.task_queue,
507+
execution_timeout=timedelta(seconds=2),
508+
)
509+
with pytest.raises(WorkflowFailureError) as e:
510+
result = await workflow_handle.result()
511+
cause = e.value.cause
512+
assert isinstance(cause, ApplicationError)
513+
assert "Workflow failure exception in Agents Framework" in cause.message
514+
515+
465516
@pytest.mark.parametrize("use_local_model", [True, False])
466517
async def test_nexus_tool_workflow(
467518
client: Client, env: WorkflowEnvironment, use_local_model: bool
@@ -1909,20 +1960,14 @@ async def run(self, question: str) -> str:
19091960
return result.final_output
19101961

19111962

1912-
@pytest.mark.parametrize("use_local_model", [True, False])
1913-
async def test_code_interpreter_tool(client: Client, use_local_model):
1914-
if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
1915-
pytest.skip("No openai API key")
1916-
1963+
async def test_code_interpreter_tool(client: Client):
19171964
new_config = client.config()
19181965
new_config["plugins"] = [
19191966
openai_agents.OpenAIAgentsPlugin(
19201967
model_params=ModelActivityParameters(
19211968
start_to_close_timeout=timedelta(seconds=60)
19221969
),
1923-
model_provider=TestModelProvider(CodeInterpreterModel())
1924-
if use_local_model
1925-
else None,
1970+
model_provider=TestModelProvider(CodeInterpreterModel()),
19261971
)
19271972
]
19281973
client = Client(**new_config)
@@ -1939,8 +1984,7 @@ async def test_code_interpreter_tool(client: Client, use_local_model):
19391984
execution_timeout=timedelta(seconds=60),
19401985
)
19411986
result = await workflow_handle.result()
1942-
if use_local_model:
1943-
assert result == "Over 9000"
1987+
assert result == "Over 9000"
19441988

19451989

19461990
class HostedMCPModel(StaticTestModel):
@@ -2011,20 +2055,14 @@ def approve(_: MCPToolApprovalRequest) -> MCPToolApprovalFunctionResult:
20112055
return result.final_output
20122056

20132057

2014-
@pytest.mark.parametrize("use_local_model", [True, False])
2015-
async def test_hosted_mcp_tool(client: Client, use_local_model):
2016-
if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
2017-
pytest.skip("No openai API key")
2018-
2058+
async def test_hosted_mcp_tool(client: Client):
20192059
new_config = client.config()
20202060
new_config["plugins"] = [
20212061
openai_agents.OpenAIAgentsPlugin(
20222062
model_params=ModelActivityParameters(
20232063
start_to_close_timeout=timedelta(seconds=120)
20242064
),
2025-
model_provider=TestModelProvider(HostedMCPModel())
2026-
if use_local_model
2027-
else None,
2065+
model_provider=TestModelProvider(HostedMCPModel()),
20282066
)
20292067
]
20302068
client = Client(**new_config)
@@ -2041,8 +2079,7 @@ async def test_hosted_mcp_tool(client: Client, use_local_model):
20412079
execution_timeout=timedelta(seconds=120),
20422080
)
20432081
result = await workflow_handle.result()
2044-
if use_local_model:
2045-
assert result == "Some language"
2082+
assert result == "Some language"
20462083

20472084

20482085
class AssertDifferentModelProvider(ModelProvider):

0 commit comments

Comments
 (0)