From 8eb31ac1b7211a045276f1a7800d7c84986259aa Mon Sep 17 00:00:00 2001 From: Michal Fudala Date: Fri, 13 Jun 2025 16:35:29 +0200 Subject: [PATCH] Refactor persistent MCP sessions --- src/llm_tools_mcp/mcp_client.py | 139 +++++++++++++++++++--------- src/llm_tools_mcp/register_tools.py | 11 ++- tests/test_integration.py | 2 + tests/test_llm_tools_mcp.py | 3 + 4 files changed, 106 insertions(+), 49 deletions(-) diff --git a/src/llm_tools_mcp/mcp_client.py b/src/llm_tools_mcp/mcp_client.py index deb7f79..708225f 100644 --- a/src/llm_tools_mcp/mcp_client.py +++ b/src/llm_tools_mcp/mcp_client.py @@ -17,13 +17,92 @@ import os import traceback import uuid +import asyncio from contextlib import asynccontextmanager -from typing import TextIO +from dataclasses import dataclass +from typing import TextIO, Any, AsyncContextManager + + +@dataclass +class _CachedSession: + session: ClientSession + connection_cm: AsyncContextManager[Any] + session_cm: AsyncContextManager[ClientSession] + log_file: TextIO | None + + async def close(self) -> None: + await self.session_cm.__aexit__(None, None, None) + await self.connection_cm.__aexit__(None, None, None) + if self.log_file: + self.log_file.close() class McpClient: def __init__(self, config: McpConfig): self.config = config + self._sessions: dict[str, _CachedSession] = {} + + async def __aenter__(self) -> "McpClient": + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() + + def __del__(self) -> None: + if self._sessions: + try: + asyncio.run(self.close()) + except Exception: + pass + + async def _get_or_create_session(self, name: str) -> ClientSession | None: + if name in self._sessions: + return self._sessions[name].session + + server_config = self.config.get().mcpServers.get(name) + if not server_config: + raise ValueError(f"There is no such MCP server: {name}") + + log_file: TextIO | None = None + connection_cm: AsyncContextManager[Any] + + if isinstance(server_config, SseServerConfig): + connection_cm = sse_client(server_config.url) + read, write = await connection_cm.__aenter__() + elif isinstance(server_config, HttpServerConfig): + connection_cm = streamablehttp_client(server_config.url) + read, write, _ = await connection_cm.__aenter__() + elif isinstance(server_config, StdioServerConfig): + params = StdioServerParameters( + command=server_config.command, + args=server_config.args or [], + env=server_config.env, + ) + log_file = self._log_file_for_session(name) + connection_cm = stdio_client(params, errlog=log_file) + read, write = await connection_cm.__aenter__() + else: + raise ValueError(f"Unknown server config type: {type(server_config)}") + + assert connection_cm is not None + session_cm = self._client_session_with_logging(name, read, write) + session = await session_cm.__aenter__() + + if session is None: + await session_cm.__aexit__(None, None, None) + await connection_cm.__aexit__(None, None, None) + if log_file: + log_file.close() + return None + + self._sessions[name] = _CachedSession( + session=session, + connection_cm=connection_cm, + session_cm=session_cm, + log_file=log_file, + ) + + return session @asynccontextmanager async def _client_session_with_logging(self, name, read, write): @@ -44,38 +123,6 @@ async def _client_session_with_logging(self, name, read, write): print(traceback.format_exc(), file=sys.stderr) yield None - @asynccontextmanager - async def _client_session(self, name: str): - server_config = self.config.get().mcpServers.get(name) - if not server_config: - raise ValueError(f"There is no such MCP server: {name}") - if isinstance(server_config, SseServerConfig): - async with sse_client(server_config.url) as (read, write): - async with self._client_session_with_logging( - name, read, write - ) as session: - yield session - elif isinstance(server_config, HttpServerConfig): - async with streamablehttp_client(server_config.url) as (read, write, _): - async with self._client_session_with_logging( - name, read, write - ) as session: - yield session - elif isinstance(server_config, StdioServerConfig): - params = StdioServerParameters( - command=server_config.command, - args=server_config.args or [], - env=server_config.env, - ) - log_file = self._log_file_for_session(name) - async with stdio_client(params, errlog=log_file) as (read, write): - async with self._client_session_with_logging( - name, read, write - ) as session: - yield session - else: - raise ValueError(f"Unknown server config type: {type(server_config)}") - def _log_file_for_session(self, name: str) -> TextIO: log_file = ( self.config.log_path.parent @@ -86,10 +133,10 @@ def _log_file_for_session(self, name: str) -> TextIO: return open(log_file, "w") async def get_tools_for(self, name: str) -> ListToolsResult: - async with self._client_session(name) as session: - if session is None: - return ListToolsResult(tools=[]) - return await session.list_tools() + session = await self._get_or_create_session(name) + if session is None: + return ListToolsResult(tools=[]) + return await session.list_tools() async def get_all_tools(self) -> dict[str, list[Tool]]: tools_for_server: dict[str, list[Tool]] = dict() @@ -99,10 +146,14 @@ async def get_all_tools(self) -> dict[str, list[Tool]]: return tools_for_server async def call_tool(self, server_name: str, name: str, **kwargs): - async with self._client_session(server_name) as session: - if session is None: - return ( - f"Error: Failed to call tool {name} from MCP server {server_name}" - ) - tool_result = await session.call_tool(name, kwargs) - return str(tool_result.content) + session = await self._get_or_create_session(server_name) + if session is None: + return f"Error: Failed to call tool {name} from MCP server {server_name}" + tool_result = await session.call_tool(name, kwargs) + return str(tool_result.content) + + async def close(self) -> None: + """Close all cached MCP sessions.""" + for name, cached in list(self._sessions.items()): + await cached.close() + del self._sessions[name] diff --git a/src/llm_tools_mcp/register_tools.py b/src/llm_tools_mcp/register_tools.py index 7b39900..5ed8652 100644 --- a/src/llm_tools_mcp/register_tools.py +++ b/src/llm_tools_mcp/register_tools.py @@ -60,14 +60,15 @@ def compute_tools(config_path: str = DEFAULT_MCP_JSON_PATH) -> list[llm.Tool]: nonlocal mcp_client previous_config = mcp_config.get() if mcp_config else None new_mcp_config = McpConfig.for_file_path(config_path) - new_mcp_client = McpClient(new_mcp_config) - if previous_config is None or new_mcp_config.get() != previous_config: - tools = _get_tools_for_llm(new_mcp_client) - mcp_client = new_mcp_client + if mcp_client is None or new_mcp_config.get() != previous_config: + if mcp_client is not None: + asyncio.run(mcp_client.close()) + mcp_client = McpClient(new_mcp_config) mcp_config = new_mcp_config + tools = _get_tools_for_llm(mcp_client) else: if tools is None: - tools = _get_tools_for_llm(new_mcp_client) + tools = _get_tools_for_llm(mcp_client) return tools class MCP(llm.Toolbox): diff --git a/tests/test_integration.py b/tests/test_integration.py index 8a900af..b6e94b0 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -30,6 +30,7 @@ async def test_sse_deepwiki_mcp(): ) assert result is not None, "Tool call should return a result" assert "react" in str(result).lower(), "Available pages for facebook/react" + await mcp_client.close() @pytest.mark.asyncio @@ -60,3 +61,4 @@ async def test_remote_fetch_mcp(): assert "fetch" in tool_names, ( f"Should have a fetching tool. Found tools: {tool_names}" ) + await mcp_client.close() diff --git a/tests/test_llm_tools_mcp.py b/tests/test_llm_tools_mcp.py index 253d420..b137121 100644 --- a/tests/test_llm_tools_mcp.py +++ b/tests/test_llm_tools_mcp.py @@ -81,6 +81,7 @@ async def long_task(files: list[str], ctx: Context) -> str: assert result is not None, "Tool call should return a result" result_str = str(result) assert "Tool output" in result_str, "Should find completion message" + await mcp_client.close() @pytest.mark.asyncio @@ -128,6 +129,7 @@ async def test_stdio(): assert result is not None, "Tool call should return a result" result_str = str(result) assert "This is test file 1" in result_str, "Should find test file content" + await mcp_client.close() finally: import shutil @@ -173,3 +175,4 @@ async def long_task_http(files: list[str]) -> str: assert result is not None, "Tool call should return a result" result_str = str(result) assert "Tool output" in result_str, "Should find completion message" + await mcp_client.close()