
Commit 983400b

wreed4 and evalstate authored

added slow llm to test parallel sampling (#197)

* added slow llm to test parallel sampling
* linter

Co-authored-by: evalstate <[email protected]>

1 parent 0cc2447 commit 983400b

File tree

5 files changed (+96, -2 lines)
src/mcp_agent/llm/augmented_llm_slow.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+import asyncio
+from typing import Any, List, Optional, Union
+
+from mcp_agent.llm.augmented_llm import (
+    MessageParamT,
+    RequestParams,
+)
+from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
+from mcp_agent.llm.provider_types import Provider
+from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
+
+
+class SlowLLM(PassthroughLLM):
+    """
+    A specialized LLM implementation that sleeps for 3 seconds before responding like PassthroughLLM.
+
+    This is useful for testing scenarios where you want to simulate slow responses
+    or for debugging timing-related issues in parallel workflows.
+    """
+
+    def __init__(
+        self, provider=Provider.FAST_AGENT, name: str = "Slow", **kwargs: dict[str, Any]
+    ) -> None:
+        super().__init__(name=name, provider=provider, **kwargs)
+
+    async def generate_str(
+        self,
+        message: Union[str, MessageParamT, List[MessageParamT]],
+        request_params: Optional[RequestParams] = None,
+    ) -> str:
+        """Sleep for 3 seconds then return the input message as a string."""
+        await asyncio.sleep(3)
+        return await super().generate_str(message, request_params)
+
+    async def _apply_prompt_provider_specific(
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
+    ) -> PromptMessageMultipart:
+        """Sleep for 3 seconds then apply prompt like PassthroughLLM."""
+        await asyncio.sleep(3)
+        return await super()._apply_prompt_provider_specific(multipart_messages, request_params)

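The new file is a thin delay wrapper: SlowLLM defers all real work to PassthroughLLM and only inserts a fixed asyncio.sleep(3) in front of each response. A fixed delay is what makes concurrency observable in tests, since overlapped calls finish in roughly one delay rather than the sum. A minimal self-contained sketch of the same pattern (EchoLLM and SlowEchoLLM are illustrative stand-ins, not fast-agent classes):

```python
import asyncio
import time


class EchoLLM:
    """Stand-in for PassthroughLLM: returns the input unchanged."""

    async def generate_str(self, message: str) -> str:
        return message


class SlowEchoLLM(EchoLLM):
    """Stand-in for SlowLLM: a fixed sleep in front of the parent's behaviour."""

    async def generate_str(self, message: str) -> str:
        await asyncio.sleep(3)  # same fixed delay SlowLLM uses
        return await super().generate_str(message)


async def demo() -> None:
    llm = SlowEchoLLM()
    start = time.monotonic()
    # Two concurrent calls overlap their sleeps...
    await asyncio.gather(llm.generate_str("a"), llm.generate_str("b"))
    # ...so this prints ~3s, not ~6s.
    print(f"concurrent: {time.monotonic() - start:.1f}s")


asyncio.run(demo())
```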
src/mcp_agent/llm/model_factory.py

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,7 @@
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
 from mcp_agent.llm.augmented_llm_playback import PlaybackLLM
+from mcp_agent.llm.augmented_llm_slow import SlowLLM
 from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_anthropic import AnthropicAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_azure import AzureOpenAIAugmentedLLM
@@ -29,6 +30,7 @@
     Type[OpenAIAugmentedLLM],
     Type[PassthroughLLM],
     Type[PlaybackLLM],
+    Type[SlowLLM],
     Type[DeepSeekAugmentedLLM],
     Type[OpenRouterAugmentedLLM],
     Type[TensorZeroAugmentedLLM],
@@ -73,6 +75,7 @@ class ModelFactory:
     DEFAULT_PROVIDERS = {
         "passthrough": Provider.FAST_AGENT,
         "playback": Provider.FAST_AGENT,
+        "slow": Provider.FAST_AGENT,
         "gpt-4o": Provider.OPENAI,
         "gpt-4o-mini": Provider.OPENAI,
         "gpt-4.1": Provider.OPENAI,
@@ -139,6 +142,7 @@ class ModelFactory:
     # This overrides the provider-based class selection
     MODEL_SPECIFIC_CLASSES: Dict[str, LLMClass] = {
         "playback": PlaybackLLM,
+        "slow": SlowLLM,
     }

     @classmethod

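The factory change is a two-step lookup: DEFAULT_PROVIDERS pins the "slow" model name to Provider.FAST_AGENT, while MODEL_SPECIFIC_CLASSES overrides the usual provider-based class selection so "slow" resolves directly to SlowLLM, just as "playback" resolves to PlaybackLLM. A rough sketch of that resolution order (resolve_class and the stub classes here are simplified assumptions, not the real ModelFactory code):

```python
from typing import Dict, Type


class PassthroughLLM: ...
class PlaybackLLM(PassthroughLLM): ...
class SlowLLM(PassthroughLLM): ...


# Model-specific overrides win before any provider-based choice is made.
MODEL_SPECIFIC_CLASSES: Dict[str, Type[PassthroughLLM]] = {
    "playback": PlaybackLLM,
    "slow": SlowLLM,
}


def resolve_class(model_name: str) -> Type[PassthroughLLM]:
    # Step 1: exact model-name override, as in the diff above.
    if model_name in MODEL_SPECIFIC_CLASSES:
        return MODEL_SPECIFIC_CLASSES[model_name]
    # Step 2 (simplified): fall back to provider-based selection,
    # represented here by a single default class.
    return PassthroughLLM


assert resolve_class("slow") is SlowLLM
assert resolve_class("passthrough") is PassthroughLLM
```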
tests/integration/sampling/fastagent.config.yaml

Lines changed: 5 additions & 1 deletion
@@ -23,7 +23,11 @@ mcp:
       args: ["run", "sampling_test_server.py"]
       sampling:
         model: "passthrough"
-
+    slow_sampling:
+      command: "uv"
+      args: ["run", "sampling_test_server.py"]
+      sampling:
+        model: "slow"
     sampling_test_no_config:
       command: "uv"
       args: ["run", "sampling_test_server.py"]

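The config gains a second server entry, slow_sampling, identical to sampling_test except that its sampling model is "slow", so sampling requests from that server are answered by SlowLLM. A small sketch of how the nesting reads once parsed (assumes PyYAML and the usual mcp -> servers layout of fastagent.config.yaml):

```python
import yaml  # PyYAML; an assumption, not a fast-agent dependency

CONFIG = """
mcp:
  servers:
    sampling_test:
      command: "uv"
      args: ["run", "sampling_test_server.py"]
      sampling:
        model: "passthrough"
    slow_sampling:
      command: "uv"
      args: ["run", "sampling_test_server.py"]
      sampling:
        model: "slow"
"""

servers = yaml.safe_load(CONFIG)["mcp"]["servers"]
for name, spec in servers.items():
    # Each server pins its own sampling model independently.
    print(name, "->", spec["sampling"]["model"])
# sampling_test -> passthrough
# slow_sampling -> slow
```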
tests/integration/sampling/live.py

Lines changed: 4 additions & 1 deletion
@@ -7,13 +7,16 @@


 # Define the agent
-@fast.agent(servers=["sampling_test"])
+@fast.agent(servers=["sampling_test", "slow_sampling"])
 async def main():
     # use the --model command line switch or agent arguments to change model
     async with fast.run() as agent:
         result = await agent.send('***CALL_TOOL sampling_test-sample {"to_sample": "123foo"}')
         print(f"RESULT: {result}")

+        result = await agent.send('***CALL_TOOL slow_sampling-sample_parallel')
+        print(f"RESULT: {result}")
+

 if __name__ == "__main__":
     asyncio.run(main())

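The live test now attaches both servers and calls the new slow_sampling-sample_parallel tool. Because the tool fans out five sampling requests and each SlowLLM response takes 3 seconds, wall-clock time is the signal: roughly 3 seconds means the requests were sampled in parallel, roughly 15 means they serialized. A hedged extension of the test (the timing wrapper is an illustrative addition, not part of the commit; it reuses the agent.send call from the diff):

```python
import time


async def timed_parallel_check(agent) -> None:
    start = time.monotonic()
    result = await agent.send("***CALL_TOOL slow_sampling-sample_parallel")
    elapsed = time.monotonic() - start
    print(f"RESULT: {result}")
    # 5 requests x 3s each: ~3s if sampled in parallel, ~15s if serialized.
    print(f"elapsed: {elapsed:.1f}s")
```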
tests/integration/sampling/sampling_test_server.py

Lines changed: 41 additions & 0 deletions
@@ -61,6 +61,47 @@ async def sample_many(ctx: Context) -> CallToolResult:
     return CallToolResult(content=[TextContent(type="text", text=str(result))])


+@mcp.tool()
+async def sample_parallel(ctx: Context, count: int = 5) -> CallToolResult:
+    """Tool that makes multiple concurrent sampling requests to test parallel processing"""
+    try:
+        logger.info(f"Making {count} concurrent sampling requests")
+
+        # Create multiple concurrent sampling requests
+        import asyncio
+
+        async def _send_sampling(request: int):
+            return await ctx.session.create_message(
+                max_tokens=100,
+                messages=[SamplingMessage(
+                    role="user",
+                    content=TextContent(type="text", text=f"Parallel request {request + 1}")
+                )],
+            )
+
+        tasks = []
+        for i in range(count):
+            task = _send_sampling(i)
+            tasks.append(task)
+
+        # Execute all requests concurrently
+        results = await asyncio.gather(*tasks)
+
+        # Combine results
+        response_texts = [result.content.text for result in results]
+        combined_response = f"Completed {len(results)} parallel requests: " + ", ".join(response_texts[:3])
+        if len(response_texts) > 3:
+            combined_response += f"... and {len(response_texts) - 3} more"
+
+        logger.info(f"Parallel sampling completed: {combined_response}")
+        return CallToolResult(content=[TextContent(type="text", text=combined_response)])
+
+    except Exception as e:
+        logger.error(f"Error in sample_parallel tool: {e}", exc_info=True)
+        return CallToolResult(isError=True, content=[TextContent(type="text", text=f"Error: {str(e)}")])
+
+
 if __name__ == "__main__":
     logger.info("Starting sampling test server...")
     mcp.run()

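One property the tool leans on implicitly: asyncio.gather returns results in submission order regardless of completion order, which is what keeps response_texts aligned with the request numbers even when responses race. A small self-contained demonstration (illustrative only, no MCP dependencies):

```python
import asyncio
import random


async def fake_sample(i: int) -> str:
    # Random latency: completion order is scrambled...
    await asyncio.sleep(random.random())
    return f"Parallel request {i + 1}"


async def main() -> None:
    results = await asyncio.gather(*[fake_sample(i) for i in range(5)])
    # ...but gather still returns results in submission order.
    print(results)  # ['Parallel request 1', ..., 'Parallel request 5']


asyncio.run(main())
```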