diff --git a/libs/agno/agno/agent/_run.py b/libs/agno/agno/agent/_run.py index f8419d316d..227f456ec9 100644 --- a/libs/agno/agno/agent/_run.py +++ b/libs/agno/agno/agent/_run.py @@ -639,9 +639,11 @@ def _run( return run_response except (InputCheckError, OutputCheckError) as e: - # Handle exceptions during streaming + # Update status and log the failure, then re-raise so callers can + # handle the exception (e.g. `except InputCheckError as e: ...`). + # Previously the exception was silently converted to a RunOutput + # with status=error, making it impossible to catch from user code. run_response.status = RunStatus.error - # If the content is None, set it to the error message if run_response.content is None: run_response.content = str(e) @@ -656,7 +658,7 @@ def _run( user_id=user_id, ) - return run_response + raise except KeyboardInterrupt: run_response = cast(RunOutput, run_response) run_response.status = RunStatus.cancelled @@ -1131,9 +1133,11 @@ def _run_stream( ) break except (InputCheckError, OutputCheckError) as e: - # Handle exceptions during streaming + # Update status, emit the error event for streaming consumers, then + # re-raise so callers can catch `InputCheckError`/`OutputCheckError`. + # Previously the exception was swallowed (break), making it impossible + # to catch from user code. 
run_response.status = RunStatus.error - # Add error event to list of events run_error = create_run_error_event( run_response, error=str(e), @@ -1143,7 +1147,6 @@ def _run_stream( ) run_response.events = add_error_event(error=run_error, events=run_response.events) - # If the content is None, set it to the error message if run_response.content is None: run_response.content = str(e) @@ -1158,7 +1161,7 @@ def _run_stream( user_id=user_id, ) yield run_error - break + raise except KeyboardInterrupt: run_response = cast(RunOutput, run_response) yield handle_event( # type: ignore @@ -1734,9 +1737,8 @@ async def _arun( return run_response except (InputCheckError, OutputCheckError) as e: - # Handle exceptions during streaming + # Update status and log, then re-raise so callers can catch it. run_response.status = RunStatus.error - # If the content is None, set it to the error message if run_response.content is None: run_response.content = str(e) @@ -1751,7 +1753,7 @@ async def _arun( user_id=user_id, ) - return run_response + raise except KeyboardInterrupt: run_response = cast(RunOutput, run_response) @@ -2483,9 +2485,8 @@ async def _arun_stream( break except (InputCheckError, OutputCheckError) as e: - # Handle exceptions during async streaming + # Update status, emit the error event, then re-raise so callers can catch it. 
run_response.status = RunStatus.error - # Add error event to list of events run_error = create_run_error_event( run_response, error=str(e), @@ -2495,13 +2496,11 @@ async def _arun_stream( ) run_response.events = add_error_event(error=run_error, events=run_response.events) - # If the content is None, set it to the error message if run_response.content is None: run_response.content = str(e) log_error(f"Validation failed: {str(e)} | Check trigger: {e.check_trigger}") - # Cleanup and store the run response and session if agent_session is not None: await acleanup_and_store( agent, @@ -2511,9 +2510,8 @@ async def _arun_stream( user_id=user_id, ) - # Yield the error event yield run_error - break + raise except KeyboardInterrupt: run_response = cast(RunOutput, run_response) @@ -3216,9 +3214,8 @@ def _continue_run( return run_response except (InputCheckError, OutputCheckError) as e: run_response = cast(RunOutput, run_response) - # Handle exceptions during streaming + # Update status and log, then re-raise so callers can catch it. run_response.status = RunStatus.error - # If the content is None, set it to the error message if run_response.content is None: run_response.content = str(e) @@ -3228,7 +3225,7 @@ def _continue_run( agent, run_response=run_response, session=session, run_context=run_context, user_id=user_id ) - return run_response + raise except KeyboardInterrupt: run_response = cast(RunOutput, run_response) run_response.status = RunStatus.cancelled @@ -3488,9 +3485,8 @@ def _continue_run_stream( break except (InputCheckError, OutputCheckError) as e: run_response = cast(RunOutput, run_response) - # Handle exceptions during streaming + # Update status, emit the error event, then re-raise so callers can catch it. 
run_response.status = RunStatus.error - # Add error event to list of events run_error = create_run_error_event( run_response, error=str(e), @@ -3500,7 +3496,6 @@ def _continue_run_stream( ) run_response.events = add_error_event(error=run_error, events=run_response.events) - # If the content is None, set it to the error message if run_response.content is None: run_response.content = str(e) @@ -3510,7 +3505,7 @@ def _continue_run_stream( agent, run_response=run_response, session=session, run_context=run_context, user_id=user_id ) yield run_error - break + raise except KeyboardInterrupt: run_response = cast(RunOutput, run_response) yield handle_event( # type: ignore @@ -4175,9 +4170,8 @@ async def _acontinue_run( return run_response except (InputCheckError, OutputCheckError) as e: run_response = cast(RunOutput, run_response) - # Handle exceptions during streaming + # Update status and log, then re-raise so callers can catch it. run_response.status = RunStatus.error - # If the content is None, set it to the error message if run_response.content is None: run_response.content = str(e) @@ -4192,7 +4186,7 @@ async def _acontinue_run( user_id=user_id, ) - return run_response + raise except KeyboardInterrupt: run_response = cast(RunOutput, run_response) @@ -4642,9 +4636,8 @@ async def _acontinue_run_stream( if run_response is None: run_response = RunOutput(run_id=run_id) run_response = cast(RunOutput, run_response) - # Handle exceptions during async streaming + # Update status, emit the error event, then re-raise so callers can catch it. 
run_response.status = RunStatus.error - # Add error event to list of events run_error = create_run_error_event( run_response, error=str(e), @@ -4654,13 +4647,11 @@ async def _acontinue_run_stream( ) run_response.events = add_error_event(error=run_error, events=run_response.events) - # If the content is None, set it to the error message if run_response.content is None: run_response.content = str(e) log_error(f"Validation failed: {str(e)} | Check trigger: {e.check_trigger}") - # Cleanup and store the run response and session if agent_session is not None: await acleanup_and_store( agent, @@ -4670,9 +4661,8 @@ async def _acontinue_run_stream( user_id=user_id, ) - # Yield the error event yield run_error - break + raise except KeyboardInterrupt: if run_response is None: run_response = RunOutput(run_id=run_id) diff --git a/libs/agno/tests/unit/agent/test_input_check_error_propagation.py b/libs/agno/tests/unit/agent/test_input_check_error_propagation.py new file mode 100644 index 0000000000..5d457de2d6 --- /dev/null +++ b/libs/agno/tests/unit/agent/test_input_check_error_propagation.py @@ -0,0 +1,315 @@ +"""Regression tests for #7604 – InputCheckError/OutputCheckError raised inside a +guardrail pre-hook must propagate out of agent.run() / arun() so user code can +catch them directly. + +Previously the exceptions were caught internally and silently converted to a +RunOutput with status=error, making `except InputCheckError` unreachable. 
+""" + +from typing import Any, AsyncIterator, Iterator, Union +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest + +from agno.agent.agent import Agent +from agno.exceptions import CheckTrigger, InputCheckError, OutputCheckError +from agno.guardrails.base import BaseGuardrail +from agno.media import Image +from agno.models.base import Model +from agno.models.message import MessageMetrics +from agno.models.response import ModelResponse +from agno.run import RunContext +from agno.run.agent import RunInput, RunOutput, RunStatus +from agno.run.team import TeamRunInput +from agno.utils.hooks import normalize_pre_hooks + + +# --------------------------------------------------------------------------- +# Minimal guardrails +# --------------------------------------------------------------------------- + + +class AlwaysBlockGuardrail(BaseGuardrail): + """Guardrail that unconditionally raises InputCheckError.""" + + def check(self, run_input: Union[RunInput, TeamRunInput], **kwargs) -> None: + raise InputCheckError("blocked by guardrail", check_trigger=CheckTrigger.INPUT_NOT_ALLOWED) + + async def async_check(self, run_input: Union[RunInput, TeamRunInput], **kwargs) -> None: + raise InputCheckError( + "blocked by guardrail (async)", check_trigger=CheckTrigger.INPUT_NOT_ALLOWED + ) + + +# --------------------------------------------------------------------------- +# Minimal mock model (needed so the Agent can be constructed) +# --------------------------------------------------------------------------- + + +class _DummyModel(Model): + """A no-op model that should never be reached when the guardrail blocks.""" + + def __init__(self): + super().__init__(id="dummy", name="dummy", provider="test") + self.instructions = None + self._resp = ModelResponse( + content="should not reach here", + role="assistant", + response_usage=MessageMetrics(), + ) + self.response = Mock(return_value=self._resp) + self.aresponse = AsyncMock(return_value=self._resp) + + def 
get_instructions_for_model(self, *args, **kwargs): + return None + + def get_system_message_for_model(self, *args, **kwargs): + return None + + async def aget_instructions_for_model(self, *args, **kwargs): + return None + + async def aget_system_message_for_model(self, *args, **kwargs): + return None + + def parse_args(self, *args, **kwargs): + return {} + + def invoke(self, *args, **kwargs) -> ModelResponse: + return self._resp + + async def ainvoke(self, *args, **kwargs) -> ModelResponse: + return await self.aresponse(*args, **kwargs) + + def invoke_stream(self, *args, **kwargs) -> Iterator[ModelResponse]: + yield self._resp + + async def ainvoke_stream(self, *args, **kwargs) -> AsyncIterator[ModelResponse]: + yield self._resp + return + + def _parse_provider_response(self, response: Any, **kwargs) -> ModelResponse: + return self._resp + + def _parse_provider_response_delta(self, response: Any) -> ModelResponse: + return self._resp + + +# --------------------------------------------------------------------------- +# Helpers for hook-layer tests +# --------------------------------------------------------------------------- + + +def _make_mock_agent(pre_hooks=None): + """Return a MagicMock Agent with just enough attributes for _hooks execution.""" + agent = MagicMock() + agent._run_hooks_in_background = False + agent.debug_mode = False + agent.events_to_skip = [] + agent.store_events = False + agent.pre_hooks = pre_hooks + agent.post_hooks = None + return agent + + +def _make_run_context(): + return RunContext(run_id="r1", session_id="s1", session_state={}, metadata={}) + + +def _make_run_input(): + return RunInput(input_content="hello") + + +# --------------------------------------------------------------------------- +# Tests for _hooks.execute_pre_hooks (the hook-layer, already re-raises) +# --------------------------------------------------------------------------- + + +class TestHookLayerPropagation: + """Confirm the hook layer itself propagates 
InputCheckError (sanity check).""" + + def test_execute_pre_hooks_raises_input_check_error(self): + from agno.agent._hooks import execute_pre_hooks + + agent = _make_mock_agent() + guardrail = AlwaysBlockGuardrail() + hooks = normalize_pre_hooks([guardrail], async_mode=False) + + with pytest.raises(InputCheckError, match="blocked by guardrail"): + list( + execute_pre_hooks( + agent=agent, + hooks=hooks, + run_response=MagicMock(), + run_input=_make_run_input(), + session=MagicMock(), + run_context=_make_run_context(), + ) + ) + + @pytest.mark.asyncio + async def test_aexecute_pre_hooks_raises_input_check_error(self): + from agno.agent._hooks import aexecute_pre_hooks + + agent = _make_mock_agent() + guardrail = AlwaysBlockGuardrail() + hooks = normalize_pre_hooks([guardrail], async_mode=True) + + with pytest.raises(InputCheckError, match="blocked by guardrail"): + async for _ in aexecute_pre_hooks( + agent=agent, + hooks=hooks, + run_response=MagicMock(), + run_input=_make_run_input(), + session=MagicMock(), + run_context=_make_run_context(), + ): + pass + + +# --------------------------------------------------------------------------- +# Tests for plain pre_hook function raising InputCheckError (#7604) +# Hook layer: when a plain function (not a BaseGuardrail subclass) raises +# InputCheckError, it must propagate out of execute_pre_hooks / aexecute_pre_hooks. +# --------------------------------------------------------------------------- + + +class TestPlainHookRaisesInputCheckError: + """Regression test for #7604. + + A plain callable passed as a pre_hook that raises InputCheckError must + propagate out of the execute_pre_hooks generator so that callers can catch it. + Previously the exception was swallowed at the _run() level. 
+ """ + + def test_plain_pre_hook_raises_input_check_error_propagates(self): + from agno.agent._hooks import execute_pre_hooks + + agent = _make_mock_agent() + + def blocking_hook(run_input, **kwargs): + raise InputCheckError("rejected by plain hook") + + with pytest.raises(InputCheckError, match="rejected by plain hook"): + list( + execute_pre_hooks( + agent=agent, + hooks=[blocking_hook], + run_response=MagicMock(), + run_input=_make_run_input(), + session=MagicMock(), + run_context=_make_run_context(), + ) + ) + + @pytest.mark.asyncio + async def test_async_plain_pre_hook_raises_input_check_error_propagates(self): + from agno.agent._hooks import aexecute_pre_hooks + + agent = _make_mock_agent() + + async def async_blocking_hook(run_input, **kwargs): + raise InputCheckError("rejected by async plain hook") + + with pytest.raises(InputCheckError, match="rejected by async plain hook"): + async for _ in aexecute_pre_hooks( + agent=agent, + hooks=[async_blocking_hook], + run_response=MagicMock(), + run_input=_make_run_input(), + session=MagicMock(), + run_context=_make_run_context(), + ): + pass + + +# --------------------------------------------------------------------------- +# END-TO-END tests: agent.run() / agent.arun() with a blocking guardrail +# These exercise the actual _run.py code paths that were fixed (#7604). +# Reverting the _run.py changes would cause these tests to FAIL. +# --------------------------------------------------------------------------- + + +class TestEndToEndInputCheckErrorPropagation: + """End-to-end tests verifying InputCheckError propagates through agent.run() + and agent.arun() — the actual fix site in _run.py. + + Before the fix, these exceptions were caught in the try/except blocks inside + _run(), _arun(), etc. and converted to a RunOutput with status=error. The + user's `except InputCheckError` block was unreachable. 
+ """ + + def test_agent_run_raises_input_check_error(self): + """Sync non-stream: agent.run() must propagate InputCheckError.""" + agent = Agent( + model=_DummyModel(), + pre_hooks=[AlwaysBlockGuardrail()], + ) + + with pytest.raises(InputCheckError, match="blocked by guardrail"): + agent.run("hello") + + @pytest.mark.asyncio + async def test_agent_arun_raises_input_check_error(self): + """Async non-stream: agent.arun() must propagate InputCheckError.""" + agent = Agent( + model=_DummyModel(), + pre_hooks=[AlwaysBlockGuardrail()], + ) + + with pytest.raises(InputCheckError, match="blocked by guardrail"): + await agent.arun("hello") + + def test_agent_run_stream_raises_input_check_error(self): + """Sync stream: consuming agent.run(stream=True) must propagate + InputCheckError after yielding the error event.""" + agent = Agent( + model=_DummyModel(), + pre_hooks=[AlwaysBlockGuardrail()], + ) + + with pytest.raises(InputCheckError, match="blocked by guardrail"): + for _ in agent.run("hello", stream=True): + pass + + @pytest.mark.asyncio + async def test_agent_arun_stream_raises_input_check_error(self): + """Async stream: consuming agent.arun(stream=True) must propagate + InputCheckError after yielding the error event.""" + agent = Agent( + model=_DummyModel(), + pre_hooks=[AlwaysBlockGuardrail()], + ) + + with pytest.raises(InputCheckError, match="blocked by guardrail"): + async for _ in agent.arun("hello", stream=True): + pass + + def test_agent_run_plain_hook_raises_input_check_error(self): + """Sync non-stream with a plain callable pre_hook (not a BaseGuardrail).""" + + def blocking_hook(run_input, **kwargs): + raise InputCheckError("plain hook blocked") + + agent = Agent( + model=_DummyModel(), + pre_hooks=[blocking_hook], + ) + + with pytest.raises(InputCheckError, match="plain hook blocked"): + agent.run("hello") + + @pytest.mark.asyncio + async def test_agent_arun_plain_hook_raises_input_check_error(self): + """Async non-stream with a plain callable pre_hook 
(not a BaseGuardrail).""" + + async def async_blocking_hook(run_input, **kwargs): + raise InputCheckError("async plain hook blocked") + + agent = Agent( + model=_DummyModel(), + pre_hooks=[async_blocking_hook], + ) + + with pytest.raises(InputCheckError, match="async plain hook blocked"): + await agent.arun("hello")