Skip to content

Commit 26862e4

Browse files
authored
interrupts - decorated tools (#1041)
1 parent 7cd10b9 commit 26862e4

File tree

18 files changed

+419
-61
lines changed

18 files changed

+419
-61
lines changed

src/strands/hooks/events.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing_extensions import override
1111

1212
from ..types.content import Message
13-
from ..types.interrupt import InterruptHookEvent
13+
from ..types.interrupt import _Interruptible
1414
from ..types.streaming import StopReason
1515
from ..types.tools import AgentTool, ToolResult, ToolUse
1616
from .registry import HookEvent
@@ -88,7 +88,7 @@ class MessageAddedEvent(HookEvent):
8888

8989

9090
@dataclass
91-
class BeforeToolCallEvent(HookEvent, InterruptHookEvent):
91+
class BeforeToolCallEvent(HookEvent, _Interruptible):
9292
"""Event triggered before a tool is invoked.
9393
9494
This event is fired just before the agent executes a tool, allowing hook
@@ -124,7 +124,7 @@ def _interrupt_id(self, name: str) -> str:
124124
Returns:
125125
Interrupt id.
126126
"""
127-
return f"v1:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}"
127+
return f"v1:before_tool_call:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}"
128128

129129

130130
@dataclass

src/strands/tools/decorator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def my_tool(param1: str, param2: int = 42) -> dict:
6262
from pydantic import BaseModel, Field, create_model
6363
from typing_extensions import override
6464

65-
from ..types._events import ToolResultEvent, ToolStreamEvent
65+
from ..interrupt import InterruptException
66+
from ..types._events import ToolInterruptEvent, ToolResultEvent, ToolStreamEvent
6667
from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolResult, ToolSpec, ToolUse
6768

6869
logger = logging.getLogger(__name__)
@@ -493,6 +494,10 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
493494
result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore
494495
yield self._wrap_tool_result(tool_use_id, result)
495496

497+
except InterruptException as e:
498+
yield ToolInterruptEvent(tool_use, [e.interrupt])
499+
return
500+
496501
except ValueError as e:
497502
# Special handling for validation errors
498503
error_msg = str(e)

src/strands/tools/executors/_executor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,16 @@ async def _stream(
163163
# we yield it directly; all other cases (non-sdk AgentTools), we wrap events in
164164
# ToolStreamEvent and the last event is just the result.
165165

166+
if isinstance(event, ToolInterruptEvent):
167+
yield event
168+
return
169+
166170
if isinstance(event, ToolResultEvent):
167171
# below the last "event" must point to the tool_result
168172
event = event.tool_result
169173
break
170-
elif isinstance(event, ToolStreamEvent):
174+
175+
if isinstance(event, ToolStreamEvent):
171176
yield event
172177
else:
173178
yield ToolStreamEvent(tool_use, event)

src/strands/types/interrupt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ def approve(self, event: BeforeToolCallEvent) -> None:
118118
from ..agent import Agent
119119

120120

121-
class InterruptHookEvent(Protocol):
122-
"""Interface that adds interrupt support to hook events."""
121+
class _Interruptible(Protocol):
122+
"""Interface that adds interrupt support to hook events and tools."""
123123

124124
agent: "Agent"
125125

src/strands/types/tools.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
- Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html
66
"""
77

8+
import uuid
89
from abc import ABC, abstractmethod
910
from dataclasses import dataclass
1011
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union
1112

1213
from typing_extensions import NotRequired, TypedDict
1314

15+
from .interrupt import _Interruptible
1416
from .media import DocumentContent, ImageContent
1517

1618
if TYPE_CHECKING:
@@ -126,7 +128,7 @@ class ToolChoiceTool(TypedDict):
126128

127129

128130
@dataclass
129-
class ToolContext:
131+
class ToolContext(_Interruptible):
130132
"""Context object containing framework-provided data for decorated tools.
131133
132134
This object provides access to framework-level information that may be useful
@@ -148,6 +150,17 @@ class ToolContext:
148150
agent: "Agent"
149151
invocation_state: dict[str, Any]
150152

153+
def _interrupt_id(self, name: str) -> str:
154+
"""Unique id for the interrupt.
155+
156+
Args:
157+
name: User defined name for the interrupt.
158+
159+
Returns:
160+
Interrupt id.
161+
"""
162+
return f"v1:tool_call:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}"
163+
151164

152165
# Individual ToolChoice type aliases
153166
ToolChoiceAutoDict = dict[Literal["auto"], ToolChoiceAuto]

tests/strands/agent/test_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1957,7 +1957,7 @@ def test_agent__call__resume_interrupt(mock_model, tool_decorated, agenerator):
19571957
)
19581958

19591959
interrupt = Interrupt(
1960-
id="v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9",
1960+
id="v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9",
19611961
name="test_name",
19621962
reason="test reason",
19631963
)

tests/strands/event_loop/test_event_loop.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ def interrupt_callback(event):
884884
exp_stop_reason = "interrupt"
885885
exp_interrupts = [
886886
Interrupt(
887-
id="v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9",
887+
id="v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9",
888888
name="test_name",
889889
reason="test reason",
890890
),
@@ -911,8 +911,8 @@ def interrupt_callback(event):
911911
},
912912
},
913913
"interrupts": {
914-
"v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9": {
915-
"id": "v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9",
914+
"v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9": {
915+
"id": "v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9",
916916
"name": "test_name",
917917
"reason": "test reason",
918918
"response": None,
@@ -925,7 +925,7 @@ def interrupt_callback(event):
925925
@pytest.mark.asyncio
926926
async def test_event_loop_cycle_interrupt_resume(agent, model, tool, tool_times_2, agenerator, alist):
927927
interrupt = Interrupt(
928-
id="v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9",
928+
id="v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9",
929929
name="test_name",
930930
reason="test reason",
931931
response="test response",

tests/strands/hooks/test_registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ def test_hook_registry_invoke_callbacks_interrupt(registry, agent):
3838
_, tru_interrupts = registry.invoke_callbacks(event)
3939
exp_interrupts = [
4040
Interrupt(
41-
id="v1:test_tool_id:da3551f3-154b-5978-827e-50ac387877ee",
41+
id="v1:before_tool_call:test_tool_id:da3551f3-154b-5978-827e-50ac387877ee",
4242
name="test_name_1",
4343
reason="test reason 1",
4444
),
4545
Interrupt(
46-
id="v1:test_tool_id:0f5a8068-d1ba-5a48-bf67-c9d33786d8d4",
46+
id="v1:before_tool_call:test_tool_id:0f5a8068-d1ba-5a48-bf67-c9d33786d8d4",
4747
name="test_name_2",
4848
reason="test reason 2",
4949
),

tests/strands/tools/executors/conftest.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from strands.agent.interrupt import InterruptState
88
from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent, HookRegistry
99
from strands.tools.registry import ToolRegistry
10+
from strands.types.tools import ToolContext
1011

1112

1213
@pytest.fixture
@@ -79,12 +80,22 @@ def func():
7980

8081

8182
@pytest.fixture
82-
def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool):
83+
def interrupt_tool():
84+
@strands.tool(name="interrupt_tool", context=True)
85+
def func(tool_context: ToolContext) -> str:
86+
return tool_context.interrupt("test_name", reason="test reason")
87+
88+
return func
89+
90+
91+
@pytest.fixture
92+
def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool, interrupt_tool):
8393
registry = ToolRegistry()
8494
registry.register_tool(weather_tool)
8595
registry.register_tool(temperature_tool)
8696
registry.register_tool(exception_tool)
8797
registry.register_tool(thread_tool)
98+
registry.register_tool(interrupt_tool)
8899
return registry
89100

90101

@@ -113,5 +124,5 @@ def cycle_span():
113124

114125

115126
@pytest.fixture
116-
def invocation_state():
117-
return {}
127+
def invocation_state(agent):
128+
return {"agent": agent}

tests/strands/tools/executors/test_concurrent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ async def test_concurrent_executor_interrupt(
3838
executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist
3939
):
4040
interrupt = Interrupt(
41-
id="v1:test_tool_id_1:78714d6c-613c-5cf4-bf25-7037569941f9",
41+
id="v1:before_tool_call:test_tool_id_1:78714d6c-613c-5cf4-bf25-7037569941f9",
4242
name="test_name",
4343
reason="test reason",
4444
)

0 commit comments

Comments
 (0)