diff --git a/pyproject.toml b/pyproject.toml index 95d971e5..cedf9b54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,8 +45,10 @@ strix = "strix.interface.main:main" python = "^3.12" fastapi = "*" uvicorn = "*" -litellm = { version = "~1.79.1", extras = ["proxy"] } -openai = ">=1.99.5,<1.100.0" +httpx="0.28.1" +mcp = "^1.23.1" +litellm = { version = "^1.79.1", extras = ["proxy"] } +openai = "^2.8.0" tenacity = "^9.0.0" numpydoc = "^1.8.0" pydantic = {extras = ["email"], version = "^2.11.3"} @@ -61,6 +63,7 @@ xmltodict = "^0.13.0" pyte = "^0.8.1" requests = "^2.32.0" libtmux = "^0.46.2" +idna = "^3.11" [tool.poetry.group.dev.dependencies] # Type checking and static analysis @@ -85,6 +88,7 @@ isort = "^6.0.1" requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + # ============================================================================ # Type Checking Configuration # ============================================================================ diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index 67aeb383..eaecc540 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -17,6 +17,7 @@ from strix.llm import LLM, LLMConfig, LLMRequestFailedError from strix.llm.utils import clean_content from strix.tools import process_tool_invocations +from strix.tools.mcp_tools.mcp_tools import MCP from .state import AgentState @@ -60,6 +61,9 @@ def __init__(self, config: dict[str, Any]): if "max_iterations" in config: self.max_iterations = config["max_iterations"] + self.mcp_server: MCP | None = None + if hasattr(config, "mcp_config_path"): + self.connect_mcp(config["mcp_config_path"]) self.llm_config_name = config.get("llm_config_name", "default") self.llm_config = config.get("llm_config", self.default_llm_config) @@ -145,6 +149,17 @@ def _add_to_agents_graph(self) -> None: if self.state.parent_id is None and agents_graph_actions._root_agent_id is None: agents_graph_actions._root_agent_id = self.state.agent_id + def connect_mcp(self, config) -> None: + import json + + if isinstance(config, str): + config_path = Path(config) + with open(config_path, encoding="utf-8") as f: + config_data = json.load(f) + else: + raise ValueError("MCP configuration must be a file path or a dictionary") + self.mcp_server = MCP(config_data) + def cancel_current_execution(self) -> None: if self._current_task and not self._current_task.done(): self._current_task.cancel() diff --git a/strix/interface/main.py b/strix/interface/main.py index 6c244cf3..7000552f 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -312,6 +312,7 @@ def parse_arguments() -> argparse.Namespace: "Default is interactive mode with TUI." ), ) + parser.add_argument("--mcp-config", type=str, help="Path to MCP configuration JSON file") args = parser.parse_args() @@ -326,6 +327,12 @@ def parse_arguments() -> argparse.Namespace: except Exception as e: # noqa: BLE001 parser.error(f"Failed to read instruction file '{instruction_path}': {e}") + if args.mcp_config: + mcp_config_path = Path(args.mcp_config) + if not mcp_config_path.exists() or not mcp_config_path.is_file(): + parser.error( + f"MCP configuration file '{mcp_config_path}' does not exist or is not a file" + ) args.targets_info = [] for target in args.target: try: diff --git a/strix/interface/tui.py b/strix/interface/tui.py index 58fa120d..7453e401 100644 --- a/strix/interface/tui.py +++ b/strix/interface/tui.py @@ -329,6 +329,9 @@ def _build_agent_config(self, args: argparse.Namespace) -> dict[str, Any]: if getattr(args, "local_sources", None): config["local_sources"] = args.local_sources + if getattr(args, "mcp_config", None): + config["mcp_config_path"] = args.mcp_config + return config def _setup_cleanup_handlers(self) -> None: diff --git a/strix/tools/mcp_tools/mcp_tools.py b/strix/tools/mcp_tools/mcp_tools.py new file mode 100644 index 00000000..4d1f2fd7 --- /dev/null +++ b/strix/tools/mcp_tools/mcp_tools.py @@ -0,0 +1,155 @@ +import os +import sys +from contextlib import AsyncExitStack + + +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(current_dir, "../../../")) +sys.path.append(project_root) +from dataclasses import dataclass +from typing import Any + +from mcp import ClientSession, StdioServerParameters, stdio_client +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client + +from strix.tools.registry import register_mcp_tool + + +class TransportType: + STDIO = "stdio" + STREAMABLE_HTTP = "streamable-http" + SSE = "sse" + + +@dataclass +class Configuration: + transport_type: str = TransportType.STDIO + command: str = "npx" # Example default + args: list | None = None + env: dict[str, str] | None = None + cwd: str | None = None + url: str | None = None + headers: dict[str, Any] | None = None + encoding: str = "utf-8" + + +class MCPClient: + def __init__(self, config: Configuration | dict[str, Any], timeout: int = 300): + if isinstance(config, dict): + self.config = Configuration(**config) + else: + self.config = config + + self.exit_stack = AsyncExitStack() + self.session = None + self.timeout = timeout + print(f"MCPClient initialized with transport type: {self.config.transport_type}") + + async def connect(self): + transport_type = self.config.transport_type + + if transport_type == TransportType.STDIO: + server_params = StdioServerParameters( + command=self.config.command, + args=self.config.args or [], + env=self.config.env, + cwd=self.config.cwd, + encoding=self.config.encoding, + ) + + stdio_transport = await self.exit_stack.enter_async_context( + stdio_client(server=server_params) + ) + read, write = stdio_transport + + self.session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + elif transport_type == TransportType.STREAMABLE_HTTP: + if not self.config.url: + raise ValueError("URL must be provided for STREAMABLE_HTTP transport.") + + http_transport = await self.exit_stack.enter_async_context( + streamablehttp_client( + url=self.config.url, + headers=self.config.headers or {}, + ) + ) + read, write, _ = http_transport + + self.session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + + elif transport_type == TransportType.SSE: + if not self.config.url: + raise ValueError("URL must be provided for SSE transport.") + sse_transport = await self.exit_stack.enter_async_context( + sse_client( + url=self.config.url, + headers=self.config.headers, + ) + ) + read, write = sse_transport + self.session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + + else: + raise ValueError(f"Unsupported transport type: {transport_type}") + + await self.session.initialize() + + def _generate_xml_schema(self, name, inputSchema, description) -> str: + name_str = f'' + desc_str = f"{description}" + properties = "" + if inputSchema["properties"]: + for key, value in inputSchema["properties"].items(): + properties += f'' + + return f"""{name_str}\n {desc_str}\n \n{properties if inputSchema["properties"] else ""}\n \n""" + + async def register_tools(self): + if not self.session: + raise RuntimeError("Client is not connected. Call connect() first.") + + response = await self.session.list_tools() + tools = response.tools + + for tool in tools: + name = tool.name + + async def dummy_func(tool_name=name, **kwargs) -> Any: + return await self.session.call_tool(tool_name, arguments=kwargs) + + tool_xml = self._generate_xml_schema( + name=tool.name, inputSchema=tool.inputSchema, description=tool.description + ) + register_mcp_tool( + name=tool.name, func=dummy_func, module="unknown", xml_schema=tool_xml + ) + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.cleanup() + + async def __aenter__(self): + await self.connect() + return self + + async def cleanup(self): + await self.exit_stack.aclose() + + +class MCP: + def __init__(self, config: dict[str, Any], timeout: int = 300): + self.config = config + self.timeout = timeout + self.client: list[MCPClient] = [] + + async def connect(self) -> MCPClient: + mcp_server_config = self.config.get("mcpServers", {}) + for server_name, server_config in mcp_server_config.items(): + client = MCPClient(config=server_config, timeout=self.timeout) + await client.connect() + await client.register_tools() + self.client.append(client) + + async def cleanup(self): + for client in self.client: + await client.cleanup() diff --git a/strix/tools/registry.py b/strix/tools/registry.py index a12ae2b0..1336d3b2 100644 --- a/strix/tools/registry.py +++ b/strix/tools/registry.py @@ -95,6 +95,26 @@ def _get_module_name(func: Callable[..., Any]) -> str: return "unknown" +def register_mcp_tool( + name: str, + func: Callable[..., Any], + module: str = "unknown", + xml_schema: str | None = None, +) -> None: + func_dict = { + "name": name, + "function": func, + "module": module, + "sandbox_execution": False, + } + + if xml_schema: + func_dict["xml_schema"] = xml_schema + + tools.append(func_dict) + _tools_by_name[str(func_dict["name"])] = func + + def register_tool( func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True ) -> Callable[..., Any]: