Skip to content

Commit 7e061b2

Browse files
feat: implement early exit mechanism for SequentialAgent using escalate action
1 parent 902994e commit 7e061b2

File tree

2 files changed

+120
-109
lines changed

2 files changed

+120
-109
lines changed

src/google/adk/agents/sequential_agent.py

Lines changed: 118 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -33,126 +33,135 @@
3333
from .llm_agent import LlmAgent
3434
from .sequential_agent_config import SequentialAgentConfig
3535

36-
logger = logging.getLogger('google_adk.' + __name__)
36+
logger = logging.getLogger("google_adk." + __name__)
3737

3838

3939
@experimental
4040
class SequentialAgentState(BaseAgentState):
41-
"""State for SequentialAgent."""
41+
"""State for SequentialAgent."""
4242

43-
current_sub_agent: str = ''
44-
"""The name of the current sub-agent to run."""
43+
current_sub_agent: str = ""
44+
"""The name of the current sub-agent to run."""
4545

4646

4747
class SequentialAgent(BaseAgent):
48-
"""A shell agent that runs its sub-agents in sequence."""
49-
50-
config_type: ClassVar[Type[BaseAgentConfig]] = SequentialAgentConfig
51-
"""The config type for this agent."""
52-
53-
@override
54-
async def _run_async_impl(
55-
self, ctx: InvocationContext
56-
) -> AsyncGenerator[Event, None]:
57-
if not self.sub_agents:
58-
return
59-
60-
# Initialize or resume the execution state from the agent state.
61-
agent_state = self._load_agent_state(ctx, SequentialAgentState)
62-
start_index = self._get_start_index(agent_state)
63-
64-
pause_invocation = False
65-
resuming_sub_agent = agent_state is not None
66-
for i in range(start_index, len(self.sub_agents)):
67-
sub_agent = self.sub_agents[i]
68-
if not resuming_sub_agent:
69-
# If we are resuming from the current event, it means the same event has
70-
# already been logged, so we should avoid yielding it again.
48+
"""A shell agent that runs its sub-agents in sequence."""
49+
50+
config_type: ClassVar[Type[BaseAgentConfig]] = SequentialAgentConfig
51+
"""The config type for this agent."""
52+
53+
@override
54+
async def _run_async_impl(
55+
self, ctx: InvocationContext
56+
) -> AsyncGenerator[Event, None]:
57+
if not self.sub_agents:
58+
return
59+
60+
# Initialize or resume the execution state from the agent state.
61+
agent_state = self._load_agent_state(ctx, SequentialAgentState)
62+
start_index = self._get_start_index(agent_state)
63+
64+
pause_invocation = False
65+
resuming_sub_agent = agent_state is not None
66+
for i in range(start_index, len(self.sub_agents)):
67+
sub_agent = self.sub_agents[i]
68+
if not resuming_sub_agent:
69+
# If we are resuming from the current event, it means the same event has
70+
# already been logged, so we should avoid yielding it again.
71+
if ctx.is_resumable:
72+
agent_state = SequentialAgentState(current_sub_agent=sub_agent.name)
73+
ctx.set_agent_state(self.name, agent_state=agent_state)
74+
yield self._create_agent_state_event(ctx)
75+
76+
async with Aclosing(sub_agent.run_async(ctx)) as agen:
77+
async for event in agen:
78+
yield event
79+
if ctx.should_pause_invocation(event):
80+
pause_invocation = True
81+
# Check for escalate action to enable early exit from the sequence.
82+
# When escalate is set, we terminate immediately, stopping both
83+
# subsequent events in the current agent and all remaining agents.
84+
# Note: escalate takes precedence over pause_invocation.
85+
if event.actions and event.actions.escalate:
86+
return
87+
88+
# Skip the rest of the sub-agents if the invocation is paused.
89+
if pause_invocation:
90+
return
91+
92+
# Reset the flag for the next sub-agent.
93+
resuming_sub_agent = False
94+
7195
if ctx.is_resumable:
72-
agent_state = SequentialAgentState(current_sub_agent=sub_agent.name)
73-
ctx.set_agent_state(self.name, agent_state=agent_state)
74-
yield self._create_agent_state_event(ctx)
75-
76-
async with Aclosing(sub_agent.run_async(ctx)) as agen:
77-
async for event in agen:
78-
yield event
79-
if ctx.should_pause_invocation(event):
80-
pause_invocation = True
81-
82-
# Skip the rest of the sub-agents if the invocation is paused.
83-
if pause_invocation:
84-
return
85-
86-
# Reset the flag for the next sub-agent.
87-
resuming_sub_agent = False
88-
89-
if ctx.is_resumable:
90-
ctx.set_agent_state(self.name, end_of_agent=True)
91-
yield self._create_agent_state_event(ctx)
92-
93-
def _get_start_index(
94-
self,
95-
agent_state: SequentialAgentState,
96-
) -> int:
97-
"""Calculates the start index for the sub-agent loop."""
98-
if not agent_state:
99-
return 0
100-
101-
if not agent_state.current_sub_agent:
102-
# This means the process was finished.
103-
return len(self.sub_agents)
104-
105-
try:
106-
sub_agent_names = [sub_agent.name for sub_agent in self.sub_agents]
107-
return sub_agent_names.index(agent_state.current_sub_agent)
108-
except ValueError:
109-
# A sub-agent was removed so the agent name is not found.
110-
# For now, we restart from the beginning.
111-
logger.warning(
112-
'Sub-agent %s was removed so the agent name is not found. Restarting'
113-
' from the beginning.',
114-
agent_state.current_sub_agent,
115-
)
116-
return 0
117-
118-
@override
119-
async def _run_live_impl(
120-
self, ctx: InvocationContext
121-
) -> AsyncGenerator[Event, None]:
122-
"""Implementation for live SequentialAgent.
123-
124-
Compared to the non-live case, live agents process a continuous stream of audio
125-
or video, so there is no way to tell if it's finished and should pass
126-
to the next agent or not. So we introduce a task_completed() function so the
127-
model can call this function to signal that it's finished the task and we
128-
can move on to the next agent.
129-
130-
Args:
131-
ctx: The invocation context of the agent.
132-
"""
133-
if not self.sub_agents:
134-
return
135-
136-
# There is no way to know if it's using live during init phase so we have to init it here
137-
for sub_agent in self.sub_agents:
138-
# add tool
139-
def task_completed():
140-
"""
141-
Signals that the agent has successfully completed the user's question
142-
or task.
96+
ctx.set_agent_state(self.name, end_of_agent=True)
97+
yield self._create_agent_state_event(ctx)
98+
99+
def _get_start_index(
100+
self,
101+
agent_state: SequentialAgentState,
102+
) -> int:
103+
"""Calculates the start index for the sub-agent loop."""
104+
if not agent_state:
105+
return 0
106+
107+
if not agent_state.current_sub_agent:
108+
# This means the process was finished.
109+
return len(self.sub_agents)
110+
111+
try:
112+
sub_agent_names = [sub_agent.name for sub_agent in self.sub_agents]
113+
return sub_agent_names.index(agent_state.current_sub_agent)
114+
except ValueError:
115+
# A sub-agent was removed so the agent name is not found.
116+
# For now, we restart from the beginning.
117+
logger.warning(
118+
"Sub-agent %s was removed so the agent name is not found. Restarting"
119+
" from the beginning.",
120+
agent_state.current_sub_agent,
121+
)
122+
return 0
123+
124+
@override
125+
async def _run_live_impl(
126+
self, ctx: InvocationContext
127+
) -> AsyncGenerator[Event, None]:
128+
"""Implementation for live SequentialAgent.
129+
130+
Compared to the non-live case, live agents process a continuous stream of audio
131+
or video, so there is no way to tell if it's finished and should pass
132+
to the next agent or not. So we introduce a task_completed() function so the
133+
model can call this function to signal that it's finished the task and we
134+
can move on to the next agent.
135+
136+
Args:
137+
ctx: The invocation context of the agent.
143138
"""
144-
return 'Task completion signaled.'
145-
146-
if isinstance(sub_agent, LlmAgent):
147-
# Use function name to dedupe.
148-
if task_completed.__name__ not in sub_agent.tools:
149-
sub_agent.tools.append(task_completed)
150-
sub_agent.instruction += f"""If you finished the user's request
139+
if not self.sub_agents:
140+
return
141+
142+
# There is no way to know if it's using live during init phase so we have to init it here
143+
for sub_agent in self.sub_agents:
144+
# add tool
145+
def task_completed():
146+
"""
147+
Signals that the agent has successfully completed the user's question
148+
or task.
149+
"""
150+
return "Task completion signaled."
151+
152+
if isinstance(sub_agent, LlmAgent):
153+
# Use function name to dedupe.
154+
if task_completed.__name__ not in sub_agent.tools:
155+
sub_agent.tools.append(task_completed)
156+
sub_agent.instruction += f"""If you finished the user's request
151157
according to its description, call the {task_completed.__name__} function
152158
to exit so the next agents can take over. When calling this function,
153159
do not generate any text other than the function call."""
154160

155-
for sub_agent in self.sub_agents:
156-
async with Aclosing(sub_agent.run_live(ctx)) as agen:
157-
async for event in agen:
158-
yield event
161+
for sub_agent in self.sub_agents:
162+
async with Aclosing(sub_agent.run_live(ctx)) as agen:
163+
async for event in agen:
164+
yield event
165+
# Check for escalate action to enable early exit in live mode.
166+
if event.actions and event.actions.escalate:
167+
return

src/google/adk/tools/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .enterprise_search_tool import enterprise_web_search_tool as enterprise_web_search
2323
from .example_tool import ExampleTool
2424
from .exit_loop_tool import exit_loop
25+
from .exit_sequence_tool import exit_sequence
2526
from .function_tool import FunctionTool
2627
from .get_user_choice_tool import get_user_choice_tool as get_user_choice
2728
from .google_maps_grounding_tool import google_maps_grounding
@@ -48,6 +49,7 @@
4849
'VertexAiSearchTool',
4950
'ExampleTool',
5051
'exit_loop',
52+
'exit_sequence',
5153
'FunctionTool',
5254
'get_user_choice',
5355
'load_artifacts',

0 commit comments

Comments
 (0)