Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ strix = "strix.interface.main:main"
python = "^3.12"
fastapi = "*"
uvicorn = "*"
litellm = { version = "~1.79.1", extras = ["proxy"] }
openai = ">=1.99.5,<1.100.0"
httpx="0.28.1"
mcp = "^1.23.1"
litellm = { version = "^1.79.1", extras = ["proxy"] }
openai = "^2.8.0"
tenacity = "^9.0.0"
numpydoc = "^1.8.0"
pydantic = {extras = ["email"], version = "^2.11.3"}
Expand All @@ -61,6 +63,7 @@ xmltodict = "^0.13.0"
pyte = "^0.8.1"
requests = "^2.32.0"
libtmux = "^0.46.2"
idna = "^3.11"

[tool.poetry.group.dev.dependencies]
# Type checking and static analysis
Expand All @@ -85,6 +88,7 @@ isort = "^6.0.1"
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"


# ============================================================================
# Type Checking Configuration
# ============================================================================
Expand Down
15 changes: 15 additions & 0 deletions strix/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from strix.llm import LLM, LLMConfig, LLMRequestFailedError
from strix.llm.utils import clean_content
from strix.tools import process_tool_invocations
from strix.tools.mcp_tools.mcp_tools import MCP

from .state import AgentState

Expand Down Expand Up @@ -60,6 +61,9 @@ def __init__(self, config: dict[str, Any]):

if "max_iterations" in config:
self.max_iterations = config["max_iterations"]
self.mcp_server: MCP | None = None
if hasattr(config, "mcp_config_path"):
self.connect_mcp(config["mcp_config_path"])

self.llm_config_name = config.get("llm_config_name", "default")
self.llm_config = config.get("llm_config", self.default_llm_config)
Expand Down Expand Up @@ -145,6 +149,17 @@ def _add_to_agents_graph(self) -> None:
if self.state.parent_id is None and agents_graph_actions._root_agent_id is None:
agents_graph_actions._root_agent_id = self.state.agent_id

def connect_mcp(self, config) -> None:
import json

if isinstance(config, str):
config_path = Path(config)
with open(config_path, encoding="utf-8") as f:
config_data = json.load(f)
else:
raise ValueError("MCP configuration must be a file path or a dictionary")
self.mcp_server = MCP(config_data)

def cancel_current_execution(self) -> None:
if self._current_task and not self._current_task.done():
self._current_task.cancel()
Expand Down
7 changes: 7 additions & 0 deletions strix/interface/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def parse_arguments() -> argparse.Namespace:
"Default is interactive mode with TUI."
),
)
parser.add_argument("--mcp-config", type=str, help="Path to MCP configuration JSON file")

args = parser.parse_args()

Expand All @@ -326,6 +327,12 @@ def parse_arguments() -> argparse.Namespace:
except Exception as e: # noqa: BLE001
parser.error(f"Failed to read instruction file '{instruction_path}': {e}")

if args.mcp_config:
mcp_config_path = Path(args.mcp_config)
if not mcp_config_path.exists() or not mcp_config_path.is_file():
parser.error(
f"MCP configuration file '{mcp_config_path}' does not exist or is not a file"
)
args.targets_info = []
for target in args.target:
try:
Expand Down
3 changes: 3 additions & 0 deletions strix/interface/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ def _build_agent_config(self, args: argparse.Namespace) -> dict[str, Any]:
if getattr(args, "local_sources", None):
config["local_sources"] = args.local_sources

if getattr(args, "mcp_config", None):
config["mcp_config_path"] = args.mcp_config

return config

def _setup_cleanup_handlers(self) -> None:
Expand Down
155 changes: 155 additions & 0 deletions strix/tools/mcp_tools/mcp_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import os
import sys
from contextlib import AsyncExitStack


current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, "../../../"))
sys.path.append(project_root)
from dataclasses import dataclass
from typing import Any

from mcp import ClientSession, StdioServerParameters, stdio_client
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client

from strix.tools.registry import register_mcp_tool


class TransportType:
STDIO = "stdio"
STREAMABLE_HTTP = "streamable-http"
SSE = "sse"


@dataclass
class Configuration:
transport_type: str = TransportType.STDIO
command: str = "npx" # Example default
args: list | None = None
env: dict[str, str] | None = None
cwd: str | None = None
url: str | None = None
headers: dict[str, Any] | None = None
encoding: str = "utf-8"


class MCPClient:
def __init__(self, config: Configuration | dict[str, Any], timeout: int = 300):
if isinstance(config, dict):
self.config = Configuration(**config)
else:
self.config = config

self.exit_stack = AsyncExitStack()
self.session = None
self.timeout = timeout
print(f"MCPClient initialized with transport type: {self.config.transport_type}")

async def connect(self):
transport_type = self.config.transport_type

if transport_type == TransportType.STDIO:
server_params = StdioServerParameters(
command=self.config.command,
args=self.config.args or [],
env=self.config.env,
cwd=self.config.cwd,
encoding=self.config.encoding,
)

stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(server=server_params)
)
read, write = stdio_transport

self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
elif transport_type == TransportType.STREAMABLE_HTTP:
if not self.config.url:
raise ValueError("URL must be provided for STREAMABLE_HTTP transport.")

http_transport = await self.exit_stack.enter_async_context(
streamablehttp_client(
url=self.config.url,
headers=self.config.headers or {},
)
)
read, write, _ = http_transport

self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))

elif transport_type == TransportType.SSE:
if not self.config.url:
raise ValueError("URL must be provided for SSE transport.")
sse_transport = await self.exit_stack.enter_async_context(
sse_client(
url=self.config.url,
headers=self.config.headers,
)
)
read, write = sse_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))

else:
raise ValueError(f"Unsupported transport type: {transport_type}")

await self.session.initialize()

def _generate_xml_schema(self, name, inputSchema, description) -> str:
name_str = f'<tool name="{name}">'
desc_str = f"<description>{description}</description>"
properties = ""
if inputSchema["properties"]:
for key, value in inputSchema["properties"].items():
properties += f'<property name="{key}" type="{value["type"]}" require="{"true" if key in inputSchema["required"] else "false"}"/>'

return f"""{name_str}\n {desc_str}\n <input>\n{properties if inputSchema["properties"] else ""}\n </input>\n</tool>"""

async def register_tools(self):
if not self.session:
raise RuntimeError("Client is not connected. Call connect() first.")

response = await self.session.list_tools()
tools = response.tools

for tool in tools:
name = tool.name

async def dummy_func(tool_name=name, **kwargs) -> Any:
return await self.session.call_tool(tool_name, arguments=kwargs)

tool_xml = self._generate_xml_schema(
name=tool.name, inputSchema=tool.inputSchema, description=tool.description
)
register_mcp_tool(
name=tool.name, func=dummy_func, module="unknown", xml_schema=tool_xml
)

async def __aexit__(self, exc_type, exc_value, traceback):
await self.cleanup()

async def __aenter__(self):
await self.connect()
return self

async def cleanup(self):
await self.exit_stack.aclose()


class MCP:
def __init__(self, config: dict[str, Any], timeout: int = 300):
self.config = config
self.timeout = timeout
self.client: list[MCPClient] = []

async def connect(self) -> MCPClient:
mcp_server_config = self.config.get("mcpServers", {})
for server_name, server_config in mcp_server_config.items():
client = MCPClient(config=server_config, timeout=self.timeout)
await client.connect()
await client.register_tools()
self.client.append(client)

async def cleanup(self):
for client in self.client:
await client.cleanup()
20 changes: 20 additions & 0 deletions strix/tools/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,26 @@ def _get_module_name(func: Callable[..., Any]) -> str:
return "unknown"


def register_mcp_tool(
name: str,
func: Callable[..., Any],
module: str = "unknown",
xml_schema: str | None = None,
) -> None:
func_dict = {
"name": name,
"function": func,
"module": module,
"sandbox_execution": False,
}

if xml_schema:
func_dict["xml_schema"] = xml_schema

tools.append(func_dict)
_tools_by_name[str(func_dict["name"])] = func


def register_tool(
func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
) -> Callable[..., Any]:
Expand Down