
Commit 484338e

feat: improve interactive CLI Ctrl+C handling
- Modify enhanced_prompt.py to allow session continuation when Ctrl+C is pressed at the input prompt.
- Modify interactive_prompt.py to catch KeyboardInterrupt and asyncio.CancelledError during agent tasks, allowing the session to continue after task cancellation.
1 parent 0cc2447 commit 484338e

File tree

7 files changed: +163 −18 lines changed


src/mcp_agent/core/enhanced_prompt.py

Lines changed: 4 additions & 2 deletions

@@ -320,10 +320,12 @@ def pre_process_input(text):
         result = await session.prompt_async(HTML(prompt_text), default=default)
         return pre_process_input(result)
     except KeyboardInterrupt:
-        # Handle Ctrl+C gracefully
-        return "STOP"
+        # Handle Ctrl+C gracefully at the prompt
+        rich_print("\n[yellow]Input cancelled. Type a command or 'STOP' to exit session.[/yellow]")
+        return ""  # Return empty string to re-prompt
     except EOFError:
         # Handle Ctrl+D gracefully
+        rich_print("\n[yellow]EOF received. Type 'STOP' to exit session.[/yellow]")
         return "STOP"
     except Exception as e:
         # Log and gracefully handle other exceptions
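
Note on the mechanism: prompt_toolkit's prompt_async raises KeyboardInterrupt on Ctrl+C and EOFError on Ctrl+D, which is the contract the handlers above rely on. A minimal sketch of the caller side, using a plain input() loop; read_input is an illustrative stand-in for the real async prompt helper, not code from this commit:

    def read_input() -> str:
        try:
            return input("> ")  # stands in for session.prompt_async(...)
        except KeyboardInterrupt:
            return ""      # Ctrl+C: empty string makes the loop re-prompt
        except EOFError:
            return "STOP"  # Ctrl+D: sentinel string ends the session

    while True:
        text = read_input()
        if text.upper() == "STOP":
            break
        if not text:
            continue  # re-prompt instead of ending the session
        print(f"echo: {text}")

Previously the KeyboardInterrupt branch returned "STOP", so a stray Ctrl+C at the prompt ended the whole session; returning "" lets the loop fall through to the re-prompt path instead.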

src/mcp_agent/core/interactive_prompt.py

Lines changed: 63 additions & 14 deletions

@@ -14,6 +14,7 @@
 )
 """

+import asyncio
 from typing import Awaitable, Callable, Dict, List, Mapping, Optional, Protocol, Union

 from mcp.types import Prompt, PromptMessage
@@ -38,12 +39,20 @@

 class PromptProvider(Protocol):
     """Protocol for objects that can provide prompt functionality."""
-
-    async def list_prompts(self, server_name: Optional[str] = None, agent_name: Optional[str] = None) -> Mapping[str, List[Prompt]]:
+
+    async def list_prompts(
+        self, server_name: Optional[str] = None, agent_name: Optional[str] = None
+    ) -> Mapping[str, List[Prompt]]:
         """List available prompts."""
         ...
-
-    async def apply_prompt(self, prompt_name: str, arguments: Optional[Dict[str, str]] = None, agent_name: Optional[str] = None, **kwargs) -> str:
+
+    async def apply_prompt(
+        self,
+        prompt_name: str,
+        arguments: Optional[Dict[str, str]] = None,
+        agent_name: Optional[str] = None,
+        **kwargs,
+    ) -> str:
         """Apply a prompt."""
         ...

@@ -160,17 +169,19 @@ async def prompt_loop(
                     await self._list_prompts(prompt_provider, agent)
                 else:
                     # Use the name-based selection
-                    await self._select_prompt(
-                        prompt_provider, agent, prompt_name
-                    )
+                    await self._select_prompt(prompt_provider, agent, prompt_name)
                 continue

             # Skip further processing if:
             # 1. The command was handled (command_result is truthy)
             # 2. The original input was a dictionary (special command like /prompt)
             # 3. The command result itself is a dictionary (special command handling result)
             # This fixes the issue where /prompt without arguments gets sent to the LLM
-            if command_result or isinstance(user_input, dict) or isinstance(command_result, dict):
+            if (
+                command_result
+                or isinstance(user_input, dict)
+                or isinstance(command_result, dict)
+            ):
                 continue

             if user_input.upper() == "STOP":
@@ -179,11 +190,45 @@ async def prompt_loop(
                 continue

             # Send the message to the agent
-            result = await send_func(user_input, agent)
+            try:
+                result = await send_func(user_input, agent)
+            except KeyboardInterrupt:
+                rich_print("\n[yellow]Request cancelled by user (Ctrl+C).[/yellow]")
+                result = ""  # Ensure result has a benign value for the loop
+                # Attempt to stop progress display safely
+                try:
+                    # For rich.progress.Progress, 'progress_display.live.is_started' is a common check
+                    if hasattr(progress_display, "live") and progress_display.live.is_started:
+                        progress_display.stop()
+                    # Fallback for older rich or different progress setup
+                    elif hasattr(progress_display, "is_running") and progress_display.is_running:
+                        progress_display.stop()
+                    else:  # If unsure, try stopping directly if stop() is available
+                        if hasattr(progress_display, "stop"):
+                            progress_display.stop()
+                except Exception:
+                    pass  # Continue anyway, don't let progress display crash the cancel
+                continue
+            except asyncio.CancelledError:
+                rich_print("\n[yellow]Request task was cancelled.[/yellow]")
+                result = ""
+                try:
+                    if hasattr(progress_display, "live") and progress_display.live.is_started:
+                        progress_display.stop()
+                    elif hasattr(progress_display, "is_running") and progress_display.is_running:
+                        progress_display.stop()
+                    else:
+                        if hasattr(progress_display, "stop"):
+                            progress_display.stop()
+                except Exception:
+                    pass
+                continue

         return result

-    async def _get_all_prompts(self, prompt_provider: PromptProvider, agent_name: Optional[str] = None):
+    async def _get_all_prompts(
+        self, prompt_provider: PromptProvider, agent_name: Optional[str] = None
+    ):
         """
         Get a list of all available prompts.

@@ -196,8 +241,10 @@ async def _get_all_prompts(self, prompt_provider: PromptProvider, agent_name: Op
         """
         try:
             # Call list_prompts on the provider
-            prompt_servers = await prompt_provider.list_prompts(server_name=None, agent_name=agent_name)
-
+            prompt_servers = await prompt_provider.list_prompts(
+                server_name=None, agent_name=agent_name
+            )
+
             all_prompts = []

             # Process the returned prompt servers
@@ -326,9 +373,11 @@ async def _select_prompt(
         try:
             # Get all available prompts directly from the prompt provider
             rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")
-
+
             # Call list_prompts on the provider
-            prompt_servers = await prompt_provider.list_prompts(server_name=None, agent_name=agent_name)
+            prompt_servers = await prompt_provider.list_prompts(
+                server_name=None, agent_name=agent_name
+            )

             if not prompt_servers:
                 rich_print("[yellow]No prompts available for this agent[/yellow]")
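
The new try/except around send_func covers two surfaces of the same user action: KeyboardInterrupt when the interrupt is delivered while the coroutine is awaiting, and asyncio.CancelledError when the in-flight task itself is cancelled. A self-contained sketch of the recovery pattern, assuming a simplified loop; slow_send stands in for send_func, and the real handler additionally stops the rich progress display:

    import asyncio

    async def slow_send(text: str) -> str:
        await asyncio.sleep(3)  # pretend this is a long-running agent request
        return f"reply to {text!r}"

    async def session() -> None:
        for text in ["first", "second"]:
            task = asyncio.ensure_future(slow_send(text))
            if text == "first":
                task.cancel()  # simulate Ctrl+C cancelling the in-flight request
            try:
                print(await task)
            except asyncio.CancelledError:
                print("request cancelled, session continues")
                continue

    asyncio.run(session())

Cancelling the first request raises CancelledError at the await; the except branch absorbs it and the loop goes on to serve the second request, which is the session-continuation behaviour the commit message describes.
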
src/mcp_agent/llm/augmented_llm_slow.py

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+import asyncio
+from typing import Any, List, Optional, Union
+
+from mcp_agent.llm.augmented_llm import (
+    MessageParamT,
+    RequestParams,
+)
+from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
+from mcp_agent.llm.provider_types import Provider
+from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
+
+
+class SlowLLM(PassthroughLLM):
+    """
+    A specialized LLM implementation that sleeps for 3 seconds before responding like PassthroughLLM.
+
+    This is useful for testing scenarios where you want to simulate slow responses
+    or for debugging timing-related issues in parallel workflows.
+    """
+
+    def __init__(
+        self, provider=Provider.FAST_AGENT, name: str = "Slow", **kwargs: dict[str, Any]
+    ) -> None:
+        super().__init__(name=name, provider=provider, **kwargs)
+
+    async def generate_str(
+        self,
+        message: Union[str, MessageParamT, List[MessageParamT]],
+        request_params: Optional[RequestParams] = None,
+    ) -> str:
+        """Sleep for 3 seconds then return the input message as a string."""
+        await asyncio.sleep(3)
+        return await super().generate_str(message, request_params)
+
+    async def _apply_prompt_provider_specific(
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
+    ) -> PromptMessageMultipart:
+        """Sleep for 3 seconds then apply prompt like PassthroughLLM."""
+        await asyncio.sleep(3)
+        return await super()._apply_prompt_provider_specific(multipart_messages, request_params)
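
SlowLLM is PassthroughLLM with a fixed delay injected before each response. A toy, dependency-free sketch of the same pattern, with a hypothetical EchoLLM base so the snippet runs standalone, showing why an awaited delay is useful for exercising parallel workflows:

    import asyncio
    import time

    class EchoLLM:
        async def generate_str(self, message: str) -> str:
            return message  # stand-in for PassthroughLLM's echo behaviour

    class SlowEchoLLM(EchoLLM):
        async def generate_str(self, message: str) -> str:
            await asyncio.sleep(3)  # fixed delay, as SlowLLM adds
            return await super().generate_str(message)

    async def main() -> None:
        llm = SlowEchoLLM()
        start = time.perf_counter()
        results = await asyncio.gather(*(llm.generate_str(f"msg {i}") for i in range(5)))
        # Five concurrent 3-second calls complete in about 3s, not 15s
        print(results, f"{time.perf_counter() - start:.1f}s")

    asyncio.run(main())

Because the delay is awaited rather than a blocking time.sleep, concurrent requests overlap, which is what the sample_parallel test below exercises.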

src/mcp_agent/llm/model_factory.py

Lines changed: 4 additions & 0 deletions

@@ -8,6 +8,7 @@
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
 from mcp_agent.llm.augmented_llm_playback import PlaybackLLM
+from mcp_agent.llm.augmented_llm_slow import SlowLLM
 from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_anthropic import AnthropicAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_azure import AzureOpenAIAugmentedLLM
@@ -29,6 +30,7 @@
     Type[OpenAIAugmentedLLM],
     Type[PassthroughLLM],
     Type[PlaybackLLM],
+    Type[SlowLLM],
     Type[DeepSeekAugmentedLLM],
     Type[OpenRouterAugmentedLLM],
     Type[TensorZeroAugmentedLLM],
@@ -73,6 +75,7 @@ class ModelFactory:
     DEFAULT_PROVIDERS = {
         "passthrough": Provider.FAST_AGENT,
         "playback": Provider.FAST_AGENT,
+        "slow": Provider.FAST_AGENT,
         "gpt-4o": Provider.OPENAI,
         "gpt-4o-mini": Provider.OPENAI,
         "gpt-4.1": Provider.OPENAI,
@@ -139,6 +142,7 @@ class ModelFactory:
     # This overrides the provider-based class selection
     MODEL_SPECIFIC_CLASSES: Dict[str, LLMClass] = {
         "playback": PlaybackLLM,
+        "slow": SlowLLM,
     }

     @classmethod
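
Model strings are resolved in two steps: an entry in MODEL_SPECIFIC_CLASSES overrides class selection outright (per the comment above), otherwise the provider from DEFAULT_PROVIDERS drives the choice. A hedged sketch of that order; PROVIDER_TO_CLASS is a hypothetical table, and the real ModelFactory does more work (parsing model strings, applying defaults):

    def resolve_llm_class(model_name: str):
        # Model-specific classes win: "playback" -> PlaybackLLM, "slow" -> SlowLLM
        if model_name in ModelFactory.MODEL_SPECIFIC_CLASSES:
            return ModelFactory.MODEL_SPECIFIC_CLASSES[model_name]
        # Otherwise fall back to the default class for the model's provider
        provider = ModelFactory.DEFAULT_PROVIDERS[model_name]
        return PROVIDER_TO_CLASS[provider]  # hypothetical provider -> class table

Under this scheme "slow" resolves to SlowLLM even though its DEFAULT_PROVIDERS entry points at Provider.FAST_AGENT.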

tests/integration/sampling/fastagent.config.yaml

Lines changed: 5 additions & 1 deletion

@@ -23,7 +23,11 @@ mcp:
       args: ["run", "sampling_test_server.py"]
       sampling:
         model: "passthrough"
-
+    slow_sampling:
+      command: "uv"
+      args: ["run", "sampling_test_server.py"]
+      sampling:
+        model: "slow"
     sampling_test_no_config:
       command: "uv"
       args: ["run", "sampling_test_server.py"]

tests/integration/sampling/live.py

Lines changed: 4 additions & 1 deletion

@@ -7,13 +7,16 @@


 # Define the agent
-@fast.agent(servers=["sampling_test"])
+@fast.agent(servers=["sampling_test", "slow_sampling"])
 async def main():
     # use the --model command line switch or agent arguments to change model
     async with fast.run() as agent:
         result = await agent.send('***CALL_TOOL sampling_test-sample {"to_sample": "123foo"}')
         print(f"RESULT: {result}")

+        result = await agent.send('***CALL_TOOL slow_sampling-sample_parallel')
+        print(f"RESULT: {result}")
+

 if __name__ == "__main__":
     asyncio.run(main())

tests/integration/sampling/sampling_test_server.py

Lines changed: 41 additions & 0 deletions

@@ -61,6 +61,47 @@ async def sample_many(ctx: Context) -> CallToolResult:
     return CallToolResult(content=[TextContent(type="text", text=str(result))])


+@mcp.tool()
+async def sample_parallel(ctx: Context, count: int = 5) -> CallToolResult:
+    """Tool that makes multiple concurrent sampling requests to test parallel processing"""
+    try:
+        logger.info(f"Making {count} concurrent sampling requests")
+
+        # Create multiple concurrent sampling requests
+        import asyncio
+
+        async def _send_sampling(request: int):
+            return await ctx.session.create_message(
+                max_tokens=100,
+                messages=[SamplingMessage(
+                    role="user",
+                    content=TextContent(type="text", text=f"Parallel request {request+1}")
+                )],
+            )
+
+        # Build one coroutine per request
+        tasks = []
+        for i in range(count):
+            task = _send_sampling(i)
+            tasks.append(task)
+
+        # Execute all requests concurrently
+        results = await asyncio.gather(*tasks)
+
+        # Combine results
+        response_texts = [result.content.text for result in results]
+        combined_response = f"Completed {len(results)} parallel requests: " + ", ".join(response_texts[:3])
+        if len(response_texts) > 3:
+            combined_response += f"... and {len(response_texts) - 3} more"
+
+        logger.info(f"Parallel sampling completed: {combined_response}")
+        return CallToolResult(content=[TextContent(type="text", text=combined_response)])
+
+    except Exception as e:
+        logger.error(f"Error in sample_parallel tool: {e}", exc_info=True)
+        return CallToolResult(isError=True, content=[TextContent(type="text", text=f"Error: {str(e)}")])
+
+
 if __name__ == "__main__":
     logger.info("Starting sampling test server...")
     mcp.run()
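
Since the commit is about surviving cancellation, it is worth noting how cancellation interacts with exactly this kind of fan-out: cancelling a gather() cancels all of its in-flight children. A self-contained sketch, where slow_request stands in for the create_message calls above:

    import asyncio

    async def slow_request(i: int) -> str:
        await asyncio.sleep(3)
        return f"response {i}"

    async def main() -> None:
        batch = asyncio.gather(*(slow_request(i) for i in range(5)))
        await asyncio.sleep(0.1)
        batch.cancel()  # e.g. upstream Ctrl+C handling cancelling the request
        try:
            await batch
        except asyncio.CancelledError:
            print("batch cancelled; all five children were cancelled too")

    asyncio.run(main())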
