Skip to content

Commit ffd57ad

Browse files
committed
feat: improve interactive CLI Ctrl+C handling
- Modify enhanced_prompt.py to allow session continuation when Ctrl+C is pressed at the input prompt. - Modify interactive_prompt.py to catch KeyboardInterrupt and asyncio.CancelledError during agent tasks, allowing the session to continue after task cancellation.
1 parent 983400b commit ffd57ad

File tree

2 files changed

+67
-16
lines changed

2 files changed

+67
-16
lines changed

src/mcp_agent/core/enhanced_prompt.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,10 +320,12 @@ def pre_process_input(text):
320320
result = await session.prompt_async(HTML(prompt_text), default=default)
321321
return pre_process_input(result)
322322
except KeyboardInterrupt:
323-
# Handle Ctrl+C gracefully
324-
return "STOP"
323+
# Handle Ctrl+C gracefully at the prompt
324+
rich_print("\n[yellow]Input cancelled. Type a command or 'STOP' to exit session.[/yellow]")
325+
return "" # Return empty string to re-prompt
325326
except EOFError:
326327
# Handle Ctrl+D gracefully
328+
rich_print("\n[yellow]EOF received. Type 'STOP' to exit session.[/yellow]")
327329
return "STOP"
328330
except Exception as e:
329331
# Log and gracefully handle other exceptions

src/mcp_agent/core/interactive_prompt.py

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
"""
1616

17+
import asyncio
1718
from typing import Awaitable, Callable, Dict, List, Mapping, Optional, Protocol, Union
1819

1920
from mcp.types import Prompt, PromptMessage
@@ -38,12 +39,20 @@
3839

3940
class PromptProvider(Protocol):
4041
"""Protocol for objects that can provide prompt functionality."""
41-
42-
async def list_prompts(self, server_name: Optional[str] = None, agent_name: Optional[str] = None) -> Mapping[str, List[Prompt]]:
42+
43+
async def list_prompts(
44+
self, server_name: Optional[str] = None, agent_name: Optional[str] = None
45+
) -> Mapping[str, List[Prompt]]:
4346
"""List available prompts."""
4447
...
45-
46-
async def apply_prompt(self, prompt_name: str, arguments: Optional[Dict[str, str]] = None, agent_name: Optional[str] = None, **kwargs) -> str:
48+
49+
async def apply_prompt(
50+
self,
51+
prompt_name: str,
52+
arguments: Optional[Dict[str, str]] = None,
53+
agent_name: Optional[str] = None,
54+
**kwargs,
55+
) -> str:
4756
"""Apply a prompt."""
4857
...
4958

@@ -160,17 +169,19 @@ async def prompt_loop(
160169
await self._list_prompts(prompt_provider, agent)
161170
else:
162171
# Use the name-based selection
163-
await self._select_prompt(
164-
prompt_provider, agent, prompt_name
165-
)
172+
await self._select_prompt(prompt_provider, agent, prompt_name)
166173
continue
167174

168175
# Skip further processing if:
169176
# 1. The command was handled (command_result is truthy)
170177
# 2. The original input was a dictionary (special command like /prompt)
171178
# 3. The command result itself is a dictionary (special command handling result)
172179
# This fixes the issue where /prompt without arguments gets sent to the LLM
173-
if command_result or isinstance(user_input, dict) or isinstance(command_result, dict):
180+
if (
181+
command_result
182+
or isinstance(user_input, dict)
183+
or isinstance(command_result, dict)
184+
):
174185
continue
175186

176187
if user_input.upper() == "STOP":
@@ -179,11 +190,45 @@ async def prompt_loop(
179190
continue
180191

181192
# Send the message to the agent
182-
result = await send_func(user_input, agent)
193+
try:
194+
result = await send_func(user_input, agent)
195+
except KeyboardInterrupt:
196+
rich_print("\n[yellow]Request cancelled by user (Ctrl+C).[/yellow]")
197+
result = "" # Ensure result has a benign value for the loop
198+
# Attempt to stop progress display safely
199+
try:
200+
# For rich.progress.Progress, 'progress_display.live.is_started' is a common check
201+
if hasattr(progress_display, "live") and progress_display.live.is_started:
202+
progress_display.stop()
203+
# Fallback for older rich or different progress setup
204+
elif hasattr(progress_display, "is_running") and progress_display.is_running:
205+
progress_display.stop()
206+
else: # If unsure, try stopping directly if stop() is available
207+
if hasattr(progress_display, "stop"):
208+
progress_display.stop()
209+
except Exception:
210+
pass # Continue anyway, don't let progress display crash the cancel
211+
continue
212+
except asyncio.CancelledError:
213+
rich_print("\n[yellow]Request task was cancelled.[/yellow]")
214+
result = ""
215+
try:
216+
if hasattr(progress_display, "live") and progress_display.live.is_started:
217+
progress_display.stop()
218+
elif hasattr(progress_display, "is_running") and progress_display.is_running:
219+
progress_display.stop()
220+
else:
221+
if hasattr(progress_display, "stop"):
222+
progress_display.stop()
223+
except Exception:
224+
pass
225+
continue
183226

184227
return result
185228

186-
async def _get_all_prompts(self, prompt_provider: PromptProvider, agent_name: Optional[str] = None):
229+
async def _get_all_prompts(
230+
self, prompt_provider: PromptProvider, agent_name: Optional[str] = None
231+
):
187232
"""
188233
Get a list of all available prompts.
189234
@@ -196,8 +241,10 @@ async def _get_all_prompts(self, prompt_provider: PromptProvider, agent_name: Op
196241
"""
197242
try:
198243
# Call list_prompts on the provider
199-
prompt_servers = await prompt_provider.list_prompts(server_name=None, agent_name=agent_name)
200-
244+
prompt_servers = await prompt_provider.list_prompts(
245+
server_name=None, agent_name=agent_name
246+
)
247+
201248
all_prompts = []
202249

203250
# Process the returned prompt servers
@@ -326,9 +373,11 @@ async def _select_prompt(
326373
try:
327374
# Get all available prompts directly from the prompt provider
328375
rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")
329-
376+
330377
# Call list_prompts on the provider
331-
prompt_servers = await prompt_provider.list_prompts(server_name=None, agent_name=agent_name)
378+
prompt_servers = await prompt_provider.list_prompts(
379+
server_name=None, agent_name=agent_name
380+
)
332381

333382
if not prompt_servers:
334383
rich_print("[yellow]No prompts available for this agent[/yellow]")

0 commit comments

Comments
 (0)