From f37f585be5f0b99474e1162618d004219a053cc6 Mon Sep 17 00:00:00 2001 From: Saedbhati Date: Sat, 6 Dec 2025 19:45:33 +0530 Subject: [PATCH 1/6] feat: mcp_tools_integration --- pyproject.toml | 8 +- strix/tools/mcp_tools/mcp_tools.py | 149 +++++++++++++++++++++++++++++ strix/tools/registry.py | 20 ++++ 3 files changed, 175 insertions(+), 2 deletions(-) create mode 100644 strix/tools/mcp_tools/mcp_tools.py diff --git a/pyproject.toml b/pyproject.toml index 95d971e5..cedf9b54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,8 +45,10 @@ strix = "strix.interface.main:main" python = "^3.12" fastapi = "*" uvicorn = "*" -litellm = { version = "~1.79.1", extras = ["proxy"] } -openai = ">=1.99.5,<1.100.0" +httpx="0.28.1" +mcp = "^1.23.1" +litellm = { version = "^1.79.1", extras = ["proxy"] } +openai = "^2.8.0" tenacity = "^9.0.0" numpydoc = "^1.8.0" pydantic = {extras = ["email"], version = "^2.11.3"} @@ -61,6 +63,7 @@ xmltodict = "^0.13.0" pyte = "^0.8.1" requests = "^2.32.0" libtmux = "^0.46.2" +idna = "^3.11" [tool.poetry.group.dev.dependencies] # Type checking and static analysis @@ -85,6 +88,7 @@ isort = "^6.0.1" requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + # ============================================================================ # Type Checking Configuration # ============================================================================ diff --git a/strix/tools/mcp_tools/mcp_tools.py b/strix/tools/mcp_tools/mcp_tools.py new file mode 100644 index 00000000..8323483a --- /dev/null +++ b/strix/tools/mcp_tools/mcp_tools.py @@ -0,0 +1,149 @@ +import asyncio +import json +import os +import sys +from contextlib import AsyncExitStack +from dataclasses import dataclass +from pathlib import Path # Use Path for modern path manipulation +from typing import Any + +# INP001: Ensure 'strix\tools\mcp_tools' directory contains an __init__.py file +# to resolve this error. + +# Using pathlib for cleaner path handling (PTH120, PTH100, PTH118) +current_dir = Path(__file__).resolve().parent +project_root = current_dir.parent.parent.parent.resolve() +sys.path.append(str(project_root)) + +# E402: Imports moved to the top of the file +from strix.tools.registry import register_mcp_tool + + +class TransportType: + STDIO = "stdio" + + +@dataclass +class Configuration: + transport_type: str = TransportType.STDIO + command: str = "npx" # Example default + args: list | None = None + env: dict | None = None + cwd: str | None = None + encoding: str = "utf-8" + + +class MCPClient: + def __init__(self, config: Configuration | dict[str, Any], timeout: int = 300): + if isinstance(config, dict): + self.config = Configuration(**config) + else: + self.config = config + + self.exit_stack = AsyncExitStack() + self.session = None + self.timeout = timeout + # Removed T201: print(f"MCPClient initialized with transport type: {self.config.transport_type}") + + async def connect(self): + transport_type = self.config.transport_type + + if transport_type == TransportType.STDIO: + from mcp import ClientSession, StdioServerParameters, stdio_client + + server_params = StdioServerParameters( + command=self.config.command, + args=self.config.args or [], + env=self.config.env, + cwd=self.config.cwd, + encoding=self.config.encoding, + ) + + stdio_transport = await self.exit_stack.enter_async_context( + stdio_client(server=server_params) + ) + read, write = stdio_transport + + self.session = await self.exit_stack.enter_async_context( + ClientSession(read, write) + ) + + await self.session.initialize() + + else: + raise ValueError(f"Unsupported transport type: {transport_type}") + + # Renamed inputSchema to input_schema (N803) + def _generate_xml_schema(self, name, input_schema, description) -> str: + name_str = f'' + desc_str = f"{description}" + properties = "" + # Used tuple for multi-line string concatenation to avoid E501 + if input_schema["properties"]: + for key, value in input_schema["properties"].items(): + is_required = "true" if key in input_schema["required"] else "false" + properties += ( + f'' + ) + + # Used tuple for multi-line string concatenation to avoid E501 + input_content = properties if input_schema["properties"] else "" + return ( + f"""{name_str}\n   {desc_str}\n    \n""" + f"""{input_content}\n    \n""" + ) + + async def register_tools(self): + if not self.session: + raise RuntimeError("Client is not connected. Call connect() first.") + + response = await self.session.list_tools() + tools = response.tools + + for tool in tools: + name = tool.name + + async def dummy_func(tool_name=name, **kwargs) -> Any: + return await self.session.call_tool(tool_name, arguments=kwargs) + + tool_xml = self._generate_xml_schema( + name=tool.name, + input_schema=tool.inputSchema, # Note: The tool object still uses inputSchema + description=tool.description, + ) + register_mcp_tool( + name=tool.name, func=dummy_func, module="unknown", xml_schema=tool_xml + ) + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.cleanup() + + async def __aenter__(self): + await self.connect() + return self + + async def cleanup(self): + await self.exit_stack.aclose() + + +async def main(): + # Example Usage + # Removed ERA001: Commented-out code + + # Mocking data for demonstration + # Used Path.open() (PTH123) + with (Path(r"strix/tools/mcp_tools/mcp.json")).open() as f: + config_data = json.load(f) + # Removed T201: print("--- Connecting via explicit connect() ---") + + client = MCPClient(config_data["mcpServers"]["weather"]) + try: + await client.connect() + await client.register_tools() + finally: + await client.cleanup() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/strix/tools/registry.py b/strix/tools/registry.py index a12ae2b0..1336d3b2 100644 --- a/strix/tools/registry.py +++ b/strix/tools/registry.py @@ -95,6 +95,26 @@ def _get_module_name(func: Callable[..., Any]) -> str: return "unknown" +def register_mcp_tool( + name: str, + func: Callable[..., Any], + module: str = "unknown", + xml_schema: str | None = None, +) -> None: + func_dict = { + "name": name, + "function": func, + "module": module, + "sandbox_execution": False, + } + + if xml_schema: + func_dict["xml_schema"] = xml_schema + + tools.append(func_dict) + _tools_by_name[str(func_dict["name"])] = func + + def register_tool( func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True ) -> Callable[..., Any]: From 34f46c746c39a85224d9a13b62a7bdeca8e96f13 Mon Sep 17 00:00:00 2001 From: Saedbhati Date: Sat, 6 Dec 2025 21:52:51 +0530 Subject: [PATCH 2/6] update more transport --- strix/tools/mcp_tools/mcp.json | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 strix/tools/mcp_tools/mcp.json diff --git a/strix/tools/mcp_tools/mcp.json b/strix/tools/mcp_tools/mcp.json new file mode 100644 index 00000000..7cc98e5e --- /dev/null +++ b/strix/tools/mcp_tools/mcp.json @@ -0,0 +1,8 @@ +{ + "mcpServers": { + "weather": { + "url": "http://127.0.0.1:8000/sse", + "transport_type": "sse" + } + } +} From 7fed320e021a760afd7c30b3cc98ddc0dd289fd6 Mon Sep 17 00:00:00 2001 From: Saedbhati Date: Sat, 6 Dec 2025 21:53:51 +0530 Subject: [PATCH 3/6] update more transport --- strix/tools/mcp_tools/mcp.json | 8 --- strix/tools/mcp_tools/mcp_tools.py | 97 +++++++++++++++++------------- 2 files changed, 55 insertions(+), 50 deletions(-) delete mode 100644 strix/tools/mcp_tools/mcp.json diff --git a/strix/tools/mcp_tools/mcp.json b/strix/tools/mcp_tools/mcp.json deleted file mode 100644 index 7cc98e5e..00000000 --- a/strix/tools/mcp_tools/mcp.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "mcpServers": { - "weather": { - "url": "http://127.0.0.1:8000/sse", - "transport_type": "sse" - } - } -} diff --git a/strix/tools/mcp_tools/mcp_tools.py b/strix/tools/mcp_tools/mcp_tools.py index 8323483a..01b56bcc 100644 --- a/strix/tools/mcp_tools/mcp_tools.py +++ b/strix/tools/mcp_tools/mcp_tools.py @@ -3,33 +3,35 @@ import os import sys from contextlib import AsyncExitStack -from dataclasses import dataclass -from pathlib import Path # Use Path for modern path manipulation -from typing import Any -# INP001: Ensure 'strix\tools\mcp_tools' directory contains an __init__.py file -# to resolve this error. -# Using pathlib for cleaner path handling (PTH120, PTH100, PTH118) -current_dir = Path(__file__).resolve().parent -project_root = current_dir.parent.parent.parent.resolve() -sys.path.append(str(project_root)) +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(current_dir, "../../../")) +sys.path.append(project_root) +from dataclasses import dataclass +from typing import Any, Dict, List +from mcp import ClientSession, StdioServerParameters, stdio_client -# E402: Imports moved to the top of the file from strix.tools.registry import register_mcp_tool +from mcp.client.streamable_http import streamablehttp_client +from mcp.client.sse import sse_client class TransportType: STDIO = "stdio" + STREAMABLE_HTTP = "streamable-http" + SSE = "sse" @dataclass class Configuration: transport_type: str = TransportType.STDIO command: str = "npx" # Example default - args: list | None = None - env: dict | None = None + args: List | None = None + env: Dict[str, str] | None = None cwd: str | None = None + url: str | None = None + headers: Dict[str, Any] | None= None encoding: str = "utf-8" @@ -43,13 +45,12 @@ def __init__(self, config: Configuration | dict[str, Any], timeout: int = 300): self.exit_stack = AsyncExitStack() self.session = None self.timeout = timeout - # Removed T201: print(f"MCPClient initialized with transport type: {self.config.transport_type}") + print(f"MCPClient initialized with transport type: {self.config.transport_type}") async def connect(self): transport_type = self.config.transport_type if transport_type == TransportType.STDIO: - from mcp import ClientSession, StdioServerParameters, stdio_client server_params = StdioServerParameters( command=self.config.command, @@ -64,35 +65,48 @@ async def connect(self): ) read, write = stdio_transport - self.session = await self.exit_stack.enter_async_context( - ClientSession(read, write) + self.session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + elif transport_type == TransportType.STREAMABLE_HTTP : + if not self.config.url: + raise ValueError("URL must be provided for STREAMABLE_HTTP transport.") + + http_transport = await self.exit_stack.enter_async_context( + streamablehttp_client( + url=self.config.url, + headers=self.config.headers or {}, + ) ) + read, write, _ = http_transport + + self.session = await self.exit_stack.enter_async_context(ClientSession(read, write)) - await self.session.initialize() + elif transport_type == TransportType.SSE: + + if not self.config.url: + raise ValueError("URL must be provided for SSE transport.") + sse_transport = await self.exit_stack.enter_async_context( + sse_client( + url=self.config.url, + headers=self.config.headers , + ) ) + read, write= sse_transport + self.session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + else: raise ValueError(f"Unsupported transport type: {transport_type}") + + await self.session.initialize() - # Renamed inputSchema to input_schema (N803) - def _generate_xml_schema(self, name, input_schema, description) -> str: + def _generate_xml_schema(self, name, inputSchema, description) -> str: name_str = f'' desc_str = f"{description}" properties = "" - # Used tuple for multi-line string concatenation to avoid E501 - if input_schema["properties"]: - for key, value in input_schema["properties"].items(): - is_required = "true" if key in input_schema["required"] else "false" - properties += ( - f'' - ) + if inputSchema["properties"]: + for key, value in inputSchema["properties"].items(): + properties += f'' - # Used tuple for multi-line string concatenation to avoid E501 - input_content = properties if input_schema["properties"] else "" - return ( - f"""{name_str}\n   {desc_str}\n    \n""" - f"""{input_content}\n    \n""" - ) + return f"""{name_str}\n {desc_str}\n \n{properties if inputSchema["properties"] else ""}\n \n""" async def register_tools(self): if not self.session: @@ -108,14 +122,13 @@ async def dummy_func(tool_name=name, **kwargs) -> Any: return await self.session.call_tool(tool_name, arguments=kwargs) tool_xml = self._generate_xml_schema( - name=tool.name, - input_schema=tool.inputSchema, # Note: The tool object still uses inputSchema - description=tool.description, + name=tool.name, inputSchema=tool.inputSchema, description=tool.description ) register_mcp_tool( name=tool.name, func=dummy_func, module="unknown", xml_schema=tool_xml ) + async def __aexit__(self, exc_type, exc_value, traceback): await self.cleanup() @@ -129,14 +142,14 @@ async def cleanup(self): async def main(): # Example Usage - # Removed ERA001: Commented-out code + # Replace with your actual JSON loading logic + # with open(r"strix\tools\mcp_tools\mcp.json") as f: + # config_data = json.load(f) # Mocking data for demonstration - # Used Path.open() (PTH123) - with (Path(r"strix/tools/mcp_tools/mcp.json")).open() as f: + with open(r"strix\tools\mcp_tools\mcp.json") as f: config_data = json.load(f) - # Removed T201: print("--- Connecting via explicit connect() ---") - + print("--- Connecting via explicit connect() ---") client = MCPClient(config_data["mcpServers"]["weather"]) try: await client.connect() @@ -146,4 +159,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) From 0abd879f7c1b73e9c2d331813799ca7e9c3dd0f4 Mon Sep 17 00:00:00 2001 From: Saedbhati Date: Sat, 6 Dec 2025 21:54:10 +0530 Subject: [PATCH 4/6] update more transport --- strix/tools/mcp_tools/mcp_tools.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/strix/tools/mcp_tools/mcp_tools.py b/strix/tools/mcp_tools/mcp_tools.py index 01b56bcc..93fda929 100644 --- a/strix/tools/mcp_tools/mcp_tools.py +++ b/strix/tools/mcp_tools/mcp_tools.py @@ -9,12 +9,13 @@ project_root = os.path.abspath(os.path.join(current_dir, "../../../")) sys.path.append(project_root) from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any + from mcp import ClientSession, StdioServerParameters, stdio_client +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client from strix.tools.registry import register_mcp_tool -from mcp.client.streamable_http import streamablehttp_client -from mcp.client.sse import sse_client class TransportType: @@ -27,11 +28,11 @@ class TransportType: class Configuration: transport_type: str = TransportType.STDIO command: str = "npx" # Example default - args: List | None = None - env: Dict[str, str] | None = None + args: list | None = None + env: dict[str, str] | None = None cwd: str | None = None url: str | None = None - headers: Dict[str, Any] | None= None + headers: dict[str, Any] | None = None encoding: str = "utf-8" @@ -51,7 +52,6 @@ async def connect(self): transport_type = self.config.transport_type if transport_type == TransportType.STDIO: - server_params = StdioServerParameters( command=self.config.command, args=self.config.args or [], @@ -66,7 +66,7 @@ async def connect(self): read, write = stdio_transport self.session = await self.exit_stack.enter_async_context(ClientSession(read, write)) - elif transport_type == TransportType.STREAMABLE_HTTP : + elif transport_type == TransportType.STREAMABLE_HTTP: if not self.config.url: raise ValueError("URL must be provided for STREAMABLE_HTTP transport.") @@ -81,22 +81,21 @@ async def connect(self): self.session = await self.exit_stack.enter_async_context(ClientSession(read, write)) elif transport_type == TransportType.SSE: - if not self.config.url: raise ValueError("URL must be provided for SSE transport.") sse_transport = await self.exit_stack.enter_async_context( sse_client( url=self.config.url, - headers=self.config.headers , - ) ) - read, write= sse_transport + headers=self.config.headers, + ) + ) + read, write = sse_transport self.session = await self.exit_stack.enter_async_context(ClientSession(read, write)) - else: raise ValueError(f"Unsupported transport type: {transport_type}") - - await self.session.initialize() + + await self.session.initialize() def _generate_xml_schema(self, name, inputSchema, description) -> str: name_str = f'' @@ -128,7 +127,6 @@ async def dummy_func(tool_name=name, **kwargs) -> Any: name=tool.name, func=dummy_func, module="unknown", xml_schema=tool_xml ) - async def __aexit__(self, exc_type, exc_value, traceback): await self.cleanup() From 2ef01dee63dad4a9f9d6c4dd16c256fbf7aa5e66 Mon Sep 17 00:00:00 2001 From: Saedbhati Date: Sat, 6 Dec 2025 22:14:52 +0530 Subject: [PATCH 5/6] update --- strix/agents/base_agent.py | 15 +++++++++++++ strix/interface/main.py | 7 ++++++ strix/interface/tui.py | 3 +++ strix/tools/mcp_tools/mcp_tools.py | 34 ++++++++++++------------------ 4 files changed, 39 insertions(+), 20 deletions(-) diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index 67aeb383..eaecc540 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -17,6 +17,7 @@ from strix.llm import LLM, LLMConfig, LLMRequestFailedError from strix.llm.utils import clean_content from strix.tools import process_tool_invocations +from strix.tools.mcp_tools.mcp_tools import MCP from .state import AgentState @@ -60,6 +61,9 @@ def __init__(self, config: dict[str, Any]): if "max_iterations" in config: self.max_iterations = config["max_iterations"] + self.mcp_server: MCP | None = None + if hasattr(config, "mcp_config_path"): + self.connect_mcp(config["mcp_config_path"]) self.llm_config_name = config.get("llm_config_name", "default") self.llm_config = config.get("llm_config", self.default_llm_config) @@ -145,6 +149,17 @@ def _add_to_agents_graph(self) -> None: if self.state.parent_id is None and agents_graph_actions._root_agent_id is None: agents_graph_actions._root_agent_id = self.state.agent_id + def connect_mcp(self, config) -> None: + import json + + if isinstance(config, str): + config_path = Path(config) + with open(config_path, encoding="utf-8") as f: + config_data = json.load(f) + else: + raise ValueError("MCP configuration must be a file path or a dictionary") + self.mcp_server = MCP(config_data) + def cancel_current_execution(self) -> None: if self._current_task and not self._current_task.done(): self._current_task.cancel() diff --git a/strix/interface/main.py b/strix/interface/main.py index 6c244cf3..7000552f 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -312,6 +312,7 @@ def parse_arguments() -> argparse.Namespace: "Default is interactive mode with TUI." ), ) + parser.add_argument("--mcp-config", type=str, help="Path to MCP configuration JSON file") args = parser.parse_args() @@ -326,6 +327,12 @@ def parse_arguments() -> argparse.Namespace: except Exception as e: # noqa: BLE001 parser.error(f"Failed to read instruction file '{instruction_path}': {e}") + if args.mcp_config: + mcp_config_path = Path(args.mcp_config) + if not mcp_config_path.exists() or not mcp_config_path.is_file(): + parser.error( + f"MCP configuration file '{mcp_config_path}' does not exist or is not a file" + ) args.targets_info = [] for target in args.target: try: diff --git a/strix/interface/tui.py b/strix/interface/tui.py index 58fa120d..7453e401 100644 --- a/strix/interface/tui.py +++ b/strix/interface/tui.py @@ -329,6 +329,9 @@ def _build_agent_config(self, args: argparse.Namespace) -> dict[str, Any]: if getattr(args, "local_sources", None): config["local_sources"] = args.local_sources + if getattr(args, "mcp_config", None): + config["mcp_config_path"] = args.mcp_config + return config def _setup_cleanup_handlers(self) -> None: diff --git a/strix/tools/mcp_tools/mcp_tools.py b/strix/tools/mcp_tools/mcp_tools.py index 93fda929..bbf216a9 100644 --- a/strix/tools/mcp_tools/mcp_tools.py +++ b/strix/tools/mcp_tools/mcp_tools.py @@ -1,5 +1,3 @@ -import asyncio -import json import os import sys from contextlib import AsyncExitStack @@ -138,23 +136,19 @@ async def cleanup(self): await self.exit_stack.aclose() -async def main(): - # Example Usage - # Replace with your actual JSON loading logic - # with open(r"strix\tools\mcp_tools\mcp.json") as f: - # config_data = json.load(f) - - # Mocking data for demonstration - with open(r"strix\tools\mcp_tools\mcp.json") as f: - config_data = json.load(f) - print("--- Connecting via explicit connect() ---") - client = MCPClient(config_data["mcpServers"]["weather"]) - try: - await client.connect() - await client.register_tools() - finally: - await client.cleanup() +class MCP: + def __init__(self, config: dict[str, Any], timeout: int = 300): + self.config = config + self.timeout = timeout + self.client: list[MCPClient] = [] + async def connect(self) -> MCPClient: + mcp_server_config = self.config.get("mcpServers", {}) + for server_name, server_config in mcp_server_config.items(): + client = MCPClient(config=server_config, timeout=self.timeout) + await client.connect() + self.client.append(client) -if __name__ == "__main__": - asyncio.run(main()) + async def cleanup(self): + for client in self.client: + await client.cleanup() From 336a43fdd284d2396e709a30b620882a318ed8a0 Mon Sep 17 00:00:00 2001 From: Saedbhati Date: Sun, 7 Dec 2025 09:05:43 +0530 Subject: [PATCH 6/6] minor fixes --- strix/tools/mcp_tools/mcp_tools.py | 1 + 1 file changed, 1 insertion(+) diff --git a/strix/tools/mcp_tools/mcp_tools.py b/strix/tools/mcp_tools/mcp_tools.py index bbf216a9..4d1f2fd7 100644 --- a/strix/tools/mcp_tools/mcp_tools.py +++ b/strix/tools/mcp_tools/mcp_tools.py @@ -147,6 +147,7 @@ async def connect(self) -> MCPClient: for server_name, server_config in mcp_server_config.items(): client = MCPClient(config=server_config, timeout=self.timeout) await client.connect() + await client.register_tools() self.client.append(client) async def cleanup(self):