Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 22 additions & 32 deletions libs/agno/agno/agent/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,9 +639,11 @@ def _run(

return run_response
except (InputCheckError, OutputCheckError) as e:
# Handle exceptions during streaming
# Update status and log the failure, then re-raise so callers can
# handle the exception (e.g. `except InputCheckError as e: ...`).
# Previously the exception was silently converted to a RunOutput
# with status=error, making it impossible to catch from user code.
run_response.status = RunStatus.error
# If the content is None, set it to the error message
if run_response.content is None:
run_response.content = str(e)

Expand All @@ -656,7 +658,7 @@ def _run(
user_id=user_id,
)

return run_response
raise
except KeyboardInterrupt:
run_response = cast(RunOutput, run_response)
run_response.status = RunStatus.cancelled
Expand Down Expand Up @@ -1131,9 +1133,11 @@ def _run_stream(
)
break
except (InputCheckError, OutputCheckError) as e:
# Handle exceptions during streaming
# Update status, emit the error event for streaming consumers, then
# re-raise so callers can catch `InputCheckError`/`OutputCheckError`.
# Previously the exception was swallowed (break), making it impossible
# to catch from user code.
run_response.status = RunStatus.error
# Add error event to list of events
run_error = create_run_error_event(
run_response,
error=str(e),
Expand All @@ -1143,7 +1147,6 @@ def _run_stream(
)
run_response.events = add_error_event(error=run_error, events=run_response.events)

# If the content is None, set it to the error message
if run_response.content is None:
run_response.content = str(e)

Expand All @@ -1158,7 +1161,7 @@ def _run_stream(
user_id=user_id,
)
yield run_error
break
raise
except KeyboardInterrupt:
run_response = cast(RunOutput, run_response)
yield handle_event( # type: ignore
Expand Down Expand Up @@ -1734,9 +1737,8 @@ async def _arun(

return run_response
except (InputCheckError, OutputCheckError) as e:
# Handle exceptions during streaming
# Update status and log, then re-raise so callers can catch it.
run_response.status = RunStatus.error
# If the content is None, set it to the error message
if run_response.content is None:
run_response.content = str(e)

Expand All @@ -1751,7 +1753,7 @@ async def _arun(
user_id=user_id,
)

return run_response
raise

except KeyboardInterrupt:
run_response = cast(RunOutput, run_response)
Expand Down Expand Up @@ -2343,9 +2345,8 @@ async def _arun_stream(
break

except (InputCheckError, OutputCheckError) as e:
# Handle exceptions during async streaming
# Update status, emit the error event, then re-raise so callers can catch it.
run_response.status = RunStatus.error
# Add error event to list of events
run_error = create_run_error_event(
run_response,
error=str(e),
Expand All @@ -2355,13 +2356,11 @@ async def _arun_stream(
)
run_response.events = add_error_event(error=run_error, events=run_response.events)

# If the content is None, set it to the error message
if run_response.content is None:
run_response.content = str(e)

log_error(f"Validation failed: {str(e)} | Check trigger: {e.check_trigger}")

# Cleanup and store the run response and session
if agent_session is not None:
await acleanup_and_store(
agent,
Expand All @@ -2371,9 +2370,8 @@ async def _arun_stream(
user_id=user_id,
)

# Yield the error event
yield run_error
break
raise

except KeyboardInterrupt:
run_response = cast(RunOutput, run_response)
Expand Down Expand Up @@ -3049,9 +3047,8 @@ def _continue_run(
return run_response
except (InputCheckError, OutputCheckError) as e:
run_response = cast(RunOutput, run_response)
# Handle exceptions during streaming
# Update status and log, then re-raise so callers can catch it.
run_response.status = RunStatus.error
# If the content is None, set it to the error message
if run_response.content is None:
run_response.content = str(e)

Expand All @@ -3061,7 +3058,7 @@ def _continue_run(
agent, run_response=run_response, session=session, run_context=run_context, user_id=user_id
)

return run_response
raise
except KeyboardInterrupt:
run_response = cast(RunOutput, run_response)
run_response.status = RunStatus.cancelled
Expand Down Expand Up @@ -3321,9 +3318,8 @@ def _continue_run_stream(
break
except (InputCheckError, OutputCheckError) as e:
run_response = cast(RunOutput, run_response)
# Handle exceptions during streaming
# Update status, emit the error event, then re-raise so callers can catch it.
run_response.status = RunStatus.error
# Add error event to list of events
run_error = create_run_error_event(
run_response,
error=str(e),
Expand All @@ -3333,7 +3329,6 @@ def _continue_run_stream(
)
run_response.events = add_error_event(error=run_error, events=run_response.events)

# If the content is None, set it to the error message
if run_response.content is None:
run_response.content = str(e)

Expand All @@ -3343,7 +3338,7 @@ def _continue_run_stream(
agent, run_response=run_response, session=session, run_context=run_context, user_id=user_id
)
yield run_error
break
raise
except KeyboardInterrupt:
run_response = cast(RunOutput, run_response)
yield handle_event( # type: ignore
Expand Down Expand Up @@ -3842,9 +3837,8 @@ async def _acontinue_run(
return run_response
except (InputCheckError, OutputCheckError) as e:
run_response = cast(RunOutput, run_response)
# Handle exceptions during streaming
# Update status and log, then re-raise so callers can catch it.
run_response.status = RunStatus.error
# If the content is None, set it to the error message
if run_response.content is None:
run_response.content = str(e)

Expand All @@ -3859,7 +3853,7 @@ async def _acontinue_run(
user_id=user_id,
)

return run_response
raise

except KeyboardInterrupt:
run_response = cast(RunOutput, run_response)
Expand Down Expand Up @@ -4308,9 +4302,8 @@ async def _acontinue_run_stream(
if run_response is None:
run_response = RunOutput(run_id=run_id)
run_response = cast(RunOutput, run_response)
# Handle exceptions during async streaming
# Update status, emit the error event, then re-raise so callers can catch it.
run_response.status = RunStatus.error
# Add error event to list of events
run_error = create_run_error_event(
run_response,
error=str(e),
Expand All @@ -4320,13 +4313,11 @@ async def _acontinue_run_stream(
)
run_response.events = add_error_event(error=run_error, events=run_response.events)

# If the content is None, set it to the error message
if run_response.content is None:
run_response.content = str(e)

log_error(f"Validation failed: {str(e)} | Check trigger: {e.check_trigger}")

# Cleanup and store the run response and session
if agent_session is not None:
await acleanup_and_store(
agent,
Expand All @@ -4336,9 +4327,8 @@ async def _acontinue_run_stream(
user_id=user_id,
)

# Yield the error event
yield run_error
break
raise
except KeyboardInterrupt:
if run_response is None:
run_response = RunOutput(run_id=run_id)
Expand Down
165 changes: 165 additions & 0 deletions libs/agno/tests/unit/agent/test_input_check_error_propagation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""Regression tests for #7604 – InputCheckError/OutputCheckError raised inside a
guardrail pre-hook must propagate out of agent.run() / arun() so user code can
catch them directly.

Previously the exceptions were caught internally and silently converted to a
RunOutput with status=error, making `except InputCheckError` unreachable.
"""

from typing import Union
from unittest.mock import MagicMock, patch

import pytest

from agno.agent._run import _run
from agno.exceptions import CheckTrigger, InputCheckError, OutputCheckError
from agno.guardrails.base import BaseGuardrail
from agno.run import RunContext
from agno.run.agent import RunInput, RunOutput, RunStatus
from agno.run.team import TeamRunInput
from agno.utils.hooks import normalize_pre_hooks


# ---------------------------------------------------------------------------
# Minimal guardrails
# ---------------------------------------------------------------------------


class AlwaysBlockGuardrail(BaseGuardrail):
"""Guardrail that unconditionally raises InputCheckError."""

def check(self, run_input: Union[RunInput, TeamRunInput], **kwargs) -> None:
raise InputCheckError("blocked by guardrail", check_trigger=CheckTrigger.INPUT_NOT_ALLOWED)

async def async_check(self, run_input: Union[RunInput, TeamRunInput], **kwargs) -> None:
raise InputCheckError(
"blocked by guardrail (async)", check_trigger=CheckTrigger.INPUT_NOT_ALLOWED
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_mock_agent(pre_hooks=None):
"""Return a MagicMock Agent with just enough attributes for _hooks execution."""
agent = MagicMock()
agent._run_hooks_in_background = False
agent.debug_mode = False
agent.events_to_skip = []
agent.store_events = False
agent.pre_hooks = pre_hooks
agent.post_hooks = None
return agent


def _make_run_context():
return RunContext(run_id="r1", session_id="s1", session_state={}, metadata={})


def _make_run_input():
return RunInput(input_content="hello")


# ---------------------------------------------------------------------------
# Tests for _hooks.execute_pre_hooks (the hook-layer, already re-raises)
# ---------------------------------------------------------------------------


class TestHookLayerPropagation:
"""Confirm the hook layer itself propagates InputCheckError (sanity check)."""

def test_execute_pre_hooks_raises_input_check_error(self):
from agno.agent._hooks import execute_pre_hooks

agent = _make_mock_agent()
guardrail = AlwaysBlockGuardrail()
hooks = normalize_pre_hooks([guardrail], async_mode=False)

with pytest.raises(InputCheckError, match="blocked by guardrail"):
list(
execute_pre_hooks(
agent=agent,
hooks=hooks,
run_response=MagicMock(),
run_input=_make_run_input(),
session=MagicMock(),
run_context=_make_run_context(),
)
)

@pytest.mark.asyncio
async def test_aexecute_pre_hooks_raises_input_check_error(self):
from agno.agent._hooks import aexecute_pre_hooks

agent = _make_mock_agent()
guardrail = AlwaysBlockGuardrail()
hooks = normalize_pre_hooks([guardrail], async_mode=True)

with pytest.raises(InputCheckError, match="blocked by guardrail"):
async for _ in aexecute_pre_hooks(
agent=agent,
hooks=hooks,
run_response=MagicMock(),
run_input=_make_run_input(),
session=MagicMock(),
run_context=_make_run_context(),
):
pass


# ---------------------------------------------------------------------------
# Tests for plain pre_hook function raising InputCheckError (#7604)
# Regression: when a plain function (not a BaseGuardrail subclass) raises
# InputCheckError, the exception must still propagate out of agent.run().
# ---------------------------------------------------------------------------


class TestPlainHookRaisesInputCheckError:
"""Regression test for #7604.

A plain callable passed as a pre_hook that raises InputCheckError must
propagate out of the execute_pre_hooks generator so that callers can catch it.
Previously the exception was swallowed at the _run() level.
"""

def test_plain_pre_hook_raises_input_check_error_propagates(self):
from agno.agent._hooks import execute_pre_hooks

agent = _make_mock_agent()

def blocking_hook(run_input, **kwargs):
raise InputCheckError("rejected by plain hook")

with pytest.raises(InputCheckError, match="rejected by plain hook"):
list(
execute_pre_hooks(
agent=agent,
hooks=[blocking_hook],
run_response=MagicMock(),
run_input=_make_run_input(),
session=MagicMock(),
run_context=_make_run_context(),
)
)

@pytest.mark.asyncio
async def test_async_plain_pre_hook_raises_input_check_error_propagates(self):
from agno.agent._hooks import aexecute_pre_hooks

agent = _make_mock_agent()

async def async_blocking_hook(run_input, **kwargs):
raise InputCheckError("rejected by async plain hook")

with pytest.raises(InputCheckError, match="rejected by async plain hook"):
async for _ in aexecute_pre_hooks(
agent=agent,
hooks=[async_blocking_hook],
run_response=MagicMock(),
run_input=_make_run_input(),
session=MagicMock(),
run_context=_make_run_context(),
):
pass
Loading