Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 95 additions & 44 deletions src/llm_tools_mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,92 @@
import os
import traceback
import uuid
import asyncio
from contextlib import asynccontextmanager
from typing import TextIO
from dataclasses import dataclass
from typing import TextIO, Any, AsyncContextManager


@dataclass
class _CachedSession:
session: ClientSession
connection_cm: AsyncContextManager[Any]
session_cm: AsyncContextManager[ClientSession]
log_file: TextIO | None

async def close(self) -> None:
await self.session_cm.__aexit__(None, None, None)
await self.connection_cm.__aexit__(None, None, None)
if self.log_file:
self.log_file.close()


class McpClient:
def __init__(self, config: McpConfig):
self.config = config
self._sessions: dict[str, _CachedSession] = {}

async def __aenter__(self) -> "McpClient":
return self

async def __aexit__(self, exc_type, exc, tb) -> None:
await self.close()

def __del__(self) -> None:
if self._sessions:
try:
asyncio.run(self.close())
except Exception:
pass

async def _get_or_create_session(self, name: str) -> ClientSession | None:
if name in self._sessions:
return self._sessions[name].session

server_config = self.config.get().mcpServers.get(name)
if not server_config:
raise ValueError(f"There is no such MCP server: {name}")

log_file: TextIO | None = None
connection_cm: AsyncContextManager[Any]

if isinstance(server_config, SseServerConfig):
connection_cm = sse_client(server_config.url)
read, write = await connection_cm.__aenter__()
elif isinstance(server_config, HttpServerConfig):
connection_cm = streamablehttp_client(server_config.url)
read, write, _ = await connection_cm.__aenter__()
elif isinstance(server_config, StdioServerConfig):
params = StdioServerParameters(
command=server_config.command,
args=server_config.args or [],
env=server_config.env,
)
log_file = self._log_file_for_session(name)
connection_cm = stdio_client(params, errlog=log_file)
read, write = await connection_cm.__aenter__()
else:
raise ValueError(f"Unknown server config type: {type(server_config)}")

assert connection_cm is not None
session_cm = self._client_session_with_logging(name, read, write)
session = await session_cm.__aenter__()

if session is None:
await session_cm.__aexit__(None, None, None)
await connection_cm.__aexit__(None, None, None)
if log_file:
log_file.close()
return None

self._sessions[name] = _CachedSession(
session=session,
connection_cm=connection_cm,
session_cm=session_cm,
log_file=log_file,
)

return session

@asynccontextmanager
async def _client_session_with_logging(self, name, read, write):
Expand All @@ -44,38 +123,6 @@ async def _client_session_with_logging(self, name, read, write):
print(traceback.format_exc(), file=sys.stderr)
yield None

@asynccontextmanager
async def _client_session(self, name: str):
server_config = self.config.get().mcpServers.get(name)
if not server_config:
raise ValueError(f"There is no such MCP server: {name}")
if isinstance(server_config, SseServerConfig):
async with sse_client(server_config.url) as (read, write):
async with self._client_session_with_logging(
name, read, write
) as session:
yield session
elif isinstance(server_config, HttpServerConfig):
async with streamablehttp_client(server_config.url) as (read, write, _):
async with self._client_session_with_logging(
name, read, write
) as session:
yield session
elif isinstance(server_config, StdioServerConfig):
params = StdioServerParameters(
command=server_config.command,
args=server_config.args or [],
env=server_config.env,
)
log_file = self._log_file_for_session(name)
async with stdio_client(params, errlog=log_file) as (read, write):
async with self._client_session_with_logging(
name, read, write
) as session:
yield session
else:
raise ValueError(f"Unknown server config type: {type(server_config)}")

def _log_file_for_session(self, name: str) -> TextIO:
log_file = (
self.config.log_path.parent
Expand All @@ -86,10 +133,10 @@ def _log_file_for_session(self, name: str) -> TextIO:
return open(log_file, "w")

async def get_tools_for(self, name: str) -> ListToolsResult:
async with self._client_session(name) as session:
if session is None:
return ListToolsResult(tools=[])
return await session.list_tools()
session = await self._get_or_create_session(name)
if session is None:
return ListToolsResult(tools=[])
return await session.list_tools()

async def get_all_tools(self) -> dict[str, list[Tool]]:
tools_for_server: dict[str, list[Tool]] = dict()
Expand All @@ -99,10 +146,14 @@ async def get_all_tools(self) -> dict[str, list[Tool]]:
return tools_for_server

async def call_tool(self, server_name: str, name: str, **kwargs):
async with self._client_session(server_name) as session:
if session is None:
return (
f"Error: Failed to call tool {name} from MCP server {server_name}"
)
tool_result = await session.call_tool(name, kwargs)
return str(tool_result.content)
session = await self._get_or_create_session(server_name)
if session is None:
return f"Error: Failed to call tool {name} from MCP server {server_name}"
tool_result = await session.call_tool(name, kwargs)
return str(tool_result.content)

async def close(self) -> None:
"""Close all cached MCP sessions."""
for name, cached in list(self._sessions.items()):
await cached.close()
del self._sessions[name]
11 changes: 6 additions & 5 deletions src/llm_tools_mcp/register_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,15 @@ def compute_tools(config_path: str = DEFAULT_MCP_JSON_PATH) -> list[llm.Tool]:
nonlocal mcp_client
previous_config = mcp_config.get() if mcp_config else None
new_mcp_config = McpConfig.for_file_path(config_path)
new_mcp_client = McpClient(new_mcp_config)
if previous_config is None or new_mcp_config.get() != previous_config:
tools = _get_tools_for_llm(new_mcp_client)
mcp_client = new_mcp_client
if mcp_client is None or new_mcp_config.get() != previous_config:
if mcp_client is not None:
asyncio.run(mcp_client.close())
mcp_client = McpClient(new_mcp_config)
mcp_config = new_mcp_config
tools = _get_tools_for_llm(mcp_client)
else:
if tools is None:
tools = _get_tools_for_llm(new_mcp_client)
tools = _get_tools_for_llm(mcp_client)
return tools

class MCP(llm.Toolbox):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ async def test_sse_deepwiki_mcp():
)
assert result is not None, "Tool call should return a result"
assert "react" in str(result).lower(), "Available pages for facebook/react"
await mcp_client.close()


@pytest.mark.asyncio
Expand Down Expand Up @@ -60,3 +61,4 @@ async def test_remote_fetch_mcp():
assert "fetch" in tool_names, (
f"Should have a fetching tool. Found tools: {tool_names}"
)
await mcp_client.close()
3 changes: 3 additions & 0 deletions tests/test_llm_tools_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ async def long_task(files: list[str], ctx: Context) -> str:
assert result is not None, "Tool call should return a result"
result_str = str(result)
assert "Tool output" in result_str, "Should find completion message"
await mcp_client.close()


@pytest.mark.asyncio
Expand Down Expand Up @@ -128,6 +129,7 @@ async def test_stdio():
assert result is not None, "Tool call should return a result"
result_str = str(result)
assert "This is test file 1" in result_str, "Should find test file content"
await mcp_client.close()

finally:
import shutil
Expand Down Expand Up @@ -173,3 +175,4 @@ async def long_task_http(files: list[str]) -> str:
assert result is not None, "Tool call should return a result"
result_str = str(result)
assert "Tool output" in result_str, "Should find completion message"
await mcp_client.close()
Loading