diff --git a/.pr_agent.toml b/.pr_agent.toml index 3e8c50bd7e..e4f65c23d3 100644 --- a/.pr_agent.toml +++ b/.pr_agent.toml @@ -18,3 +18,17 @@ push_commands = [ [review_agent] enabled = true publish_output = true + +[mcp] +# Set to true to enable MCP tool orchestration for /ask, /review, /improve +enabled = false +# Path to the MCP server config file (JSON/JSONC). Defaults to mcp_config.json in cwd. +# config_path = "/path/to/mcp_config.json" +# Whether to fail hard when the MCP config file is invalid (default: false = log and skip) +fail_on_invalid_config = false +# Expand environment variables and ~ in MCP server config values +resolve_env_vars = true +# Maximum number of tools to expose to the model per request +max_tool_catalog_tools = 12 +# Maximum total characters of tool schemas to include in the prompt +max_tool_catalog_schema_chars = 12000 diff --git a/docs/docs/usage-guide/additional_configurations.md b/docs/docs/usage-guide/additional_configurations.md index ba67f8dfc4..b354334e58 100644 --- a/docs/docs/usage-guide/additional_configurations.md +++ b/docs/docs/usage-guide/additional_configurations.md @@ -9,6 +9,27 @@ To print all the available configurations as a comment on your PR, you can use t /config ``` +When MCP is enabled, the `/config` comment also includes a small MCP runtime status block showing whether MCP is enabled and which servers are configured and connected. + +## MCP runtime configuration + +PR-Agent can load MCP servers from a server-side JSON or JSONC file. By default, it reads `mcp_config.json` from the current working directory, and you can override that path with the `MCP_CONFIG_PATH` environment variable or the `[mcp].config_path` setting. + +The file may use either the `servers` key, which matches the VS Code MCP schema, or `mcpServers`, which matches the Claude Desktop schema.
+ +For example, an AWS Knowledge MCP server can be configured like this: + +```json +{ + "servers": { + "AWS Knowledge": { + "url": "https://knowledge-mcp.global.api.aws", + "type": "http" + } + } +} +``` + ![possible_config1](https://codium.ai/images/pr_agent/possible_config1.png){width=512} To view the **actual** configurations used for a specific tool, after all the user settings are applied, you can add for each tool a `--config.output_relevant_configurations=true` suffix. diff --git a/docs/docs/usage-guide/automations_and_usage.md b/docs/docs/usage-guide/automations_and_usage.md index 668d8a85e7..b7ad5f45e6 100644 --- a/docs/docs/usage-guide/automations_and_usage.md +++ b/docs/docs/usage-guide/automations_and_usage.md @@ -75,6 +75,8 @@ For example, if you want to edit the `review` tool configurations, you can run: Any configuration value in [configuration file](https://github.com/the-pr-agent/pr-agent/blob/main/pr_agent/settings/configuration.toml) file can be similarly edited. Comment `/config` to see the list of available configurations. +If you want PR-Agent to use MCP tools, place a server-side MCP config file named `mcp_config.json` in the working directory, or point the `MCP_CONFIG_PATH` environment variable (or the `[mcp].config_path` setting) at another JSON/JSONC file. The `/config` comment will show the active MCP runtime status when MCP is enabled.
+ ## PR-Agent Automatic Feedback ### Disabling all automatic feedback diff --git a/pr_agent/algo/ai_handlers/base_ai_handler.py b/pr_agent/algo/ai_handlers/base_ai_handler.py index 956fcaffda..7adc2e42a9 100644 --- a/pr_agent/algo/ai_handlers/base_ai_handler.py +++ b/pr_agent/algo/ai_handlers/base_ai_handler.py @@ -1,4 +1,10 @@ +import inspect +import json +import logging from abc import ABC, abstractmethod +from typing import Any, Awaitable, Callable, Optional + +from pr_agent.mcp.runtime import MCPRuntimeError class BaseAiHandler(ABC): @@ -10,6 +16,8 @@ class BaseAiHandler(ABC): def __init__(self): pass + _logger = logging.getLogger(__name__) + @property @abstractmethod def deployment_id(self): @@ -26,3 +34,218 @@ async def chat_completion(self, model: str, system: str, user: str, temperature: temperature (float): the temperature to use for the chat completion """ pass + + async def chat_completion_with_tools( + self, + model: str, + system: str, + user: str, + tools: Optional[list[dict[str, Any]]] = None, + tool_executor: Optional[Callable[[str, dict[str, Any]], Any | Awaitable[Any]]] = None, + temperature: float = 0.2, + img_path: str = None, + max_tool_turns: int = 4, + max_tool_output_chars: int = 12000, + ): + """ + Run a structured tool-calling loop on top of plain chat completion. + + The model is instructed to emit JSON tool requests in the form: + {"type": "tool_call", "tool": "server.tool", "arguments": {...}} + and to finish with: + {"type": "final", "content": "..."} + + max_tool_output_chars is applied per tool call, not across all tool calls. 
+ """ + if not tools or tool_executor is None: + return await self.chat_completion(model, system, user, temperature=temperature, img_path=img_path) + + allowed_tool_names = self._extract_allowed_tool_names(tools) + tool_call_example = json.dumps( + { + "type": "tool_call", + "tool": "server.tool", + "arguments": {"param": "value"}, + }, + separators=(",", ":"), + ) + final_response_example = json.dumps( + {"type": "final", "content": "..."}, + separators=(",", ":"), + ) + + tool_catalog_text = json.dumps(tools, indent=2, sort_keys=True) + structured_system = ( + f"{system}\n\n" + f"Available MCP tools (JSON schema):\n{tool_catalog_text}\n\n" + "Always inspect the available tools first and use them before responding " + "whenever they can help answer the user's request.\n" + "When you need a tool, respond with ONLY a JSON object exactly in this shape:\n" + f"{tool_call_example}\n" + "Do not include a final answer in the same message as a tool call.\n" + "When you are finished, respond with ONLY a JSON object exactly in this shape:\n" + f"{final_response_example}\n" + "Do not wrap the JSON in markdown fences." 
+ ) + + conversation_history = [user] + remaining_turns = max_tool_turns + current_img_path = img_path + + while True: + current_user = "\n\n".join(conversation_history) + response_text, finish_reason = await self.chat_completion( + model=model, + system=structured_system, + user=current_user, + temperature=temperature, + img_path=current_img_path, + ) + current_img_path = None + + parsed_response = self._parse_tool_or_final_response(response_text) + if parsed_response is None: + return response_text, finish_reason + + response_type = parsed_response.get("type", "final") + if response_type == "final": + return str(parsed_response.get("content", "")), finish_reason + + if response_type != "tool_call": + return response_text, finish_reason + + if remaining_turns <= 0: + self._logger.warning("MCP tool orchestration exceeded the configured turn budget") + return response_text, finish_reason + + tool_name = str(parsed_response.get("tool", "")).strip() + arguments = parsed_response.get("arguments") or {} + if not tool_name: + self._logger.warning("MCP tool orchestration returned an empty tool name; aborting tool loop") + return response_text, finish_reason + if not isinstance(arguments, dict): + self._logger.warning("MCP tool orchestration arguments must be a JSON object; aborting tool loop") + return response_text, finish_reason + + if tool_name not in allowed_tool_names: + self._logger.warning("MCP tool '%s' was not in the advertised tool catalog; skipping", tool_name) + tool_result = f"Tool not available: {tool_name}" + else: + try: + tool_result = tool_executor(tool_name, arguments) + if inspect.isawaitable(tool_result): + tool_result = await tool_result + except (MCPRuntimeError, TypeError, ValueError, OSError, KeyError) as exc: + self._logger.warning("MCP tool '%s' raised an exception: %s", tool_name, exc) + tool_result = f"Tool error: {exc}" + + tool_result_text = self._normalize_tool_result_text( + tool_result, + max_tool_output_chars=max_tool_output_chars, + 
tool_name=tool_name, + ) + conversation_history.append(f"Previous assistant tool request:\n{response_text}") + conversation_history.append(f"Tool result for {tool_name}:\n{tool_result_text}") + remaining_turns -= 1 + + @classmethod + def _normalize_tool_result_text( + cls, + tool_result: Any, + max_tool_output_chars: int, + tool_name: str = "", + ) -> str: + if isinstance(tool_result, str): + result_text = tool_result + else: + result_text = json.dumps(tool_result, indent=2, sort_keys=True, default=str) + + if len(result_text) > max_tool_output_chars: + cls._logger.warning( + "Tool output for '%s' exceeded per-tool max_tool_output_chars (%s > %s); truncating output", + tool_name, + len(result_text), + max_tool_output_chars, + ) + if max_tool_output_chars <= 0: + return "" + suffix = "\n[tool output truncated]" + if max_tool_output_chars <= len(suffix): + return suffix[:max_tool_output_chars] + truncated_prefix_len = max(0, max_tool_output_chars - len(suffix)) + return result_text[:truncated_prefix_len] + suffix + return result_text + + @staticmethod + def _parse_tool_or_final_response(response_text: str) -> Optional[dict[str, Any]]: + candidate = response_text.strip() + if not candidate: + return None + + for json_candidate in BaseAiHandler._iter_json_object_candidates(candidate): + try: + parsed = json.loads(json_candidate) + except json.JSONDecodeError: + continue + + if isinstance(parsed, dict): + response_type = parsed.get("type") + if response_type in {"tool_call", "final"}: + return parsed + + return None + + @staticmethod + def _iter_json_object_candidates(text: str) -> list[str]: + candidates: list[str] = [] + depth = 0 + start_index: Optional[int] = None + in_string = False + is_escaped = False + + for index, char in enumerate(text): + if in_string: + if is_escaped: + is_escaped = False + elif char == "\\": + is_escaped = True + elif char == '"': + in_string = False + continue + + if char == '"': + in_string = True + continue + + if char == "{": + if depth 
== 0: + start_index = index + depth += 1 + continue + + if char == "}" and depth > 0: + depth -= 1 + if depth == 0 and start_index is not None: + candidates.append(text[start_index : index + 1]) + start_index = None + + return candidates + + @staticmethod + def _extract_allowed_tool_names(tools: list[dict[str, Any]]) -> set[str]: + allowed: set[str] = set() + for tool in tools: + if not isinstance(tool, dict): + continue + + function_info = tool.get("function") + if isinstance(function_info, dict): + function_name = function_info.get("name") + if isinstance(function_name, str) and function_name.strip(): + allowed.add(function_name.strip()) + + simple_name = tool.get("name") + if isinstance(simple_name, str) and simple_name.strip(): + allowed.add(simple_name.strip()) + + return allowed diff --git a/pr_agent/config_loader.py b/pr_agent/config_loader.py index ac7343f288..584dd35063 100644 --- a/pr_agent/config_loader.py +++ b/pr_agent/config_loader.py @@ -1,11 +1,15 @@ +import copy +import json +import os from os.path import abspath, dirname, join from pathlib import Path -from typing import Optional +from typing import Any, Optional from dynaconf import Dynaconf from starlette_context import context PR_AGENT_TOML_KEY = 'pr-agent' +MCP_CONFIG_ENV_VAR = "MCP_CONFIG_PATH" current_dir = dirname(abspath(__file__)) @@ -60,6 +64,202 @@ def get_settings(use_context=False): return global_settings +def _get_logger(): + try: + from pr_agent.log import get_logger + + return get_logger() + except ImportError: + class DummyLogger: + def debug(self, *args, **kwargs): + return None + + def info(self, *args, **kwargs): + return None + + def warning(self, *args, **kwargs): + return None + + def error(self, *args, **kwargs): + return None + + return DummyLogger() + + + +def _strip_json_comments(content: str) -> str: + """Strip line and block comments from JSONC-style config while preserving newlines.""" + stripped = [] + in_string = False + in_line_comment = False + in_block_comment = 
False + is_escaped = False + index = 0 + + while index < len(content): + char = content[index] + next_char = content[index + 1] if index + 1 < len(content) else "" + + if in_line_comment: + if char == "\n": + in_line_comment = False + stripped.append(char) + index += 1 + continue + + if in_block_comment: + if char == "*" and next_char == "/": + in_block_comment = False + index += 2 + continue + if char == "\n": + stripped.append(char) + index += 1 + continue + + if in_string: + stripped.append(char) + if is_escaped: + is_escaped = False + elif char == "\\": + is_escaped = True + elif char == '"': + in_string = False + index += 1 + continue + + if char == '"': + in_string = True + stripped.append(char) + index += 1 + continue + + if char == "/" and next_char == "/": + in_line_comment = True + index += 2 + continue + + if char == "/" and next_char == "*": + in_block_comment = True + index += 2 + continue + + stripped.append(char) + index += 1 + + return "".join(stripped) + + +def _strip_json_trailing_commas(content: str) -> str: + """Strip trailing commas outside strings so common JSONC files can be parsed.""" + stripped = [] + in_string = False + is_escaped = False + index = 0 + + while index < len(content): + char = content[index] + + if in_string: + stripped.append(char) + if is_escaped: + is_escaped = False + elif char == "\\": + is_escaped = True + elif char == '"': + in_string = False + index += 1 + continue + + if char == '"': + in_string = True + stripped.append(char) + index += 1 + continue + + if char == ",": + lookahead = index + 1 + while lookahead < len(content) and content[lookahead] in {" ", "\t", "\r", "\n"}: + lookahead += 1 + if lookahead < len(content) and content[lookahead] in {"]", "}"}: + index += 1 + continue + + stripped.append(char) + index += 1 + + return "".join(stripped) + + +def _resolve_mcp_config_path() -> Path: + env_path = os.getenv(MCP_CONFIG_ENV_VAR) + if env_path: + return Path(env_path).expanduser() + configured_path = 
get_settings().get("MCP.CONFIG_PATH") + if configured_path is None or str(configured_path).lower() in {"none", ""}: + return Path("mcp_config.json").expanduser() + return Path(str(configured_path)).expanduser() + + + +def _normalize_mcp_servers(config_data: dict[str, Any]) -> dict[str, Any]: + servers = config_data.get("servers") + if servers is None: + servers = config_data.get("mcpServers") + if servers is None: + raise ValueError("MCP config must define either 'servers' or 'mcpServers'") + if not isinstance(servers, dict): + raise ValueError("MCP server definitions must be a JSON object") + return servers + + +def load_mcp_server_config(config_path: Path) -> dict[str, Any]: + if not config_path.is_file(): + raise FileNotFoundError(f"MCP config file not found: {config_path}") + config_text = config_path.read_text(encoding="utf-8") + try: + normalized = _strip_json_trailing_commas(_strip_json_comments(config_text)) + config_data = json.loads(normalized) + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid MCP config JSON in {config_path}: {exc}") from exc + if not isinstance(config_data, dict): + raise ValueError("MCP config root must be a JSON object") + servers = _normalize_mcp_servers(config_data) + return {"servers": servers} + + +def _parse_bool_setting(value: Any, default: bool = False) -> bool: + """Safely parse a boolean setting value, correctly handling string 'false'/'true'.""" + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.strip().lower() not in {"false", "0", "no", "off", ""} + if value is None: + return default + return bool(value) + + +def apply_mcp_server_config(): + logger = _get_logger() + config_path = _resolve_mcp_config_path() + fail_on_invalid = _parse_bool_setting(get_settings().get("MCP.FAIL_ON_INVALID_CONFIG", False)) + try: + if not config_path.exists(): + logger.debug(f"MCP config file not found, skipping load: {config_path}") + return + config_data = 
load_mcp_server_config(config_path) + settings = get_settings() + settings.set("MCP.SERVERS", config_data["servers"], merge=False) + settings.set("MCP.SERVER_CONFIG", config_data, merge=False) + settings.set("MCP.ACTIVE_CONFIG_PATH", str(config_path), merge=False) + logger.info(f"Loaded MCP server configuration from {config_path}") + except (ValueError, OSError, FileNotFoundError) as exc: + logger.error(f"Failed to load MCP server configuration from {config_path}: {exc}") + if fail_on_invalid: + raise + + + # Add local configuration from pyproject.toml of the project being reviewed def _find_repository_root() -> Optional[Path]: """ @@ -86,9 +286,25 @@ def _find_pyproject() -> Optional[Path]: return None +def load_repo_pyproject_settings(pyproject_path: Optional[Path] = None, settings=None): + """Load repository pyproject settings while preserving trusted MCP configuration.""" + if pyproject_path is None: + pyproject_path = _find_pyproject() + if pyproject_path is None: + return + + if settings is None: + settings = get_settings() + + trusted_mcp_settings = copy.deepcopy(dict(settings.get("MCP", {}) or {})) + settings.load_file(pyproject_path, env=f"tool.{PR_AGENT_TOML_KEY}") + settings.set("MCP", trusted_mcp_settings, merge=False) + + pyproject_path = _find_pyproject() -if pyproject_path is not None: - get_settings().load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}') +load_repo_pyproject_settings(pyproject_path=pyproject_path) + +apply_mcp_server_config() def apply_secrets_manager_config(): @@ -97,8 +313,8 @@ def apply_secrets_manager_config(): """ try: # Dynamic imports to avoid circular dependency (secret_providers imports config_loader) - from pr_agent.secret_providers import get_secret_provider from pr_agent.log import get_logger + from pr_agent.secret_providers import get_secret_provider secret_provider = get_secret_provider() if not secret_provider: @@ -132,7 +348,8 @@ def apply_secrets_to_config(secrets: dict): except: def get_logger(): class 
DummyLogger: - def debug(self, msg): pass + def debug(self, msg): + return None return DummyLogger() for key, value in secrets.items(): diff --git a/pr_agent/mcp/__init__.py b/pr_agent/mcp/__init__.py new file mode 100644 index 0000000000..926164e629 --- /dev/null +++ b/pr_agent/mcp/__init__.py @@ -0,0 +1,15 @@ +from pr_agent.mcp.runtime import ( + MCPHttpClient, + MCPRuntime, + MCPRuntimeError, + MCPStdioClient, + MCPToolDefinition, +) + +__all__ = [ + "MCPRuntime", + "MCPRuntimeError", + "MCPToolDefinition", + "MCPStdioClient", + "MCPHttpClient", +] diff --git a/pr_agent/mcp/integration.py b/pr_agent/mcp/integration.py new file mode 100644 index 0000000000..dd7d77a6e8 --- /dev/null +++ b/pr_agent/mcp/integration.py @@ -0,0 +1,82 @@ +import asyncio +import functools +from typing import Optional + +from pr_agent.config_loader import get_settings +from pr_agent.mcp.runtime import MCPRuntime + + +def _get_tool_budget(setting_name: str, default_value: int) -> int: + value = get_settings().get(setting_name, default_value) + try: + return int(value) + except (TypeError, ValueError): + return default_value + + +async def maybe_chat_completion_with_mcp( + ai_handler, + model: str, + system: str, + user: str, + temperature: float = 0.2, + img_path: str = None, + command_name: Optional[str] = None, +): + runtime = MCPRuntime() + try: + if not runtime.enabled: + return await ai_handler.chat_completion( + model=model, + system=system, + user=user, + temperature=temperature, + img_path=img_path, + ) + + max_tools = _get_tool_budget("MCP.MAX_TOOL_CATALOG_TOOLS", 12) + max_schema_chars = _get_tool_budget("MCP.MAX_TOOL_CATALOG_SCHEMA_CHARS", 12000) + + loop = asyncio.get_running_loop() + tools = await loop.run_in_executor( + None, + functools.partial( + runtime.build_tool_schemas, + max_tools=max_tools, + max_schema_chars=max_schema_chars, + include_server_prefix=True, + ), + ) + if not tools: + return await ai_handler.chat_completion( + model=model, + system=system, + user=user, 
+ temperature=temperature, + img_path=img_path, + ) + + if command_name: + system = f"{system}\n\nCommand context: {command_name}" + + allowed_tool_names: set[str] = set() + for tool in tools: + if not isinstance(tool, dict): + continue + function_info = tool.get("function") + if isinstance(function_info, dict): + function_name = function_info.get("name") + if isinstance(function_name, str) and function_name.strip(): + allowed_tool_names.add(function_name.strip()) + + return await ai_handler.chat_completion_with_tools( + model=model, + system=system, + user=user, + tools=tools, + tool_executor=runtime.create_tool_executor(allowed_tool_names=allowed_tool_names), + temperature=temperature, + img_path=img_path, + ) + finally: + runtime.disconnect_all() diff --git a/pr_agent/mcp/runtime.py b/pr_agent/mcp/runtime.py new file mode 100644 index 0000000000..98319e5086 --- /dev/null +++ b/pr_agent/mcp/runtime.py @@ -0,0 +1,851 @@ +import json +import os +import subprocess +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Optional + +import requests + +from pr_agent.config_loader import get_settings + + +def _get_logger(): + try: + from pr_agent.log import get_logger + + return get_logger() + except ImportError: + class DummyLogger: + def debug(self, *args, **kwargs): + pass + + def info(self, *args, **kwargs): + pass + + def warning(self, *args, **kwargs): + pass + + def error(self, *args, **kwargs): + pass + + return DummyLogger() + + +class MCPRuntimeError(Exception): + pass + + +def _parse_bool(value: Any, default: bool = False) -> bool: + """Safely parse a boolean setting, handling string 'false'/'true' correctly.""" + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.strip().lower() not in {"false", "0", "no", "off", ""} + if value is None: + return default + return bool(value) + + +@dataclass(frozen=True) +class MCPToolDefinition: + server_name: str + name: str + 
@dataclass(frozen=True)
class MCPToolDefinition:
    """Immutable description of a single tool advertised by an MCP server."""

    server_name: str              # owning server, used to prefix the exposed name
    name: str                     # tool name as reported by tools/list
    description: str              # human-readable description shown to the model
    input_schema: dict[str, Any]  # JSON schema for the tool arguments

    def to_openai_tool(self, include_server_prefix: bool = True) -> dict[str, Any]:
        """Render this definition as an OpenAI-style function-tool entry."""
        exposed_name = f"{self.server_name}.{self.name}" if include_server_prefix else self.name
        return {
            "type": "function",
            "function": {
                "name": exposed_name,
                "description": self.description,
                "parameters": self.input_schema,
            },
        }


class BaseMCPClient(ABC):
    """Common interface shared by all MCP transport clients."""

    def __init__(self, server_name: str, config: dict[str, Any]):
        self.server_name = server_name
        self.config = config
        self.server_capabilities: dict[str, Any] = {}

    @abstractmethod
    def connect(self):
        """Open the transport and perform the MCP initialize handshake."""

    @abstractmethod
    def close(self):
        """Tear down the transport."""

    @abstractmethod
    def list_tools(self) -> list[MCPToolDefinition]:
        """Return the tools advertised by the server."""

    @abstractmethod
    def call_tool(self, tool_name: str, arguments: Optional[dict[str, Any]] = None) -> dict[str, Any]:
        """Invoke a tool and return the JSON-RPC result object."""


class MCPStdioClient(BaseMCPClient):
    """JSON-RPC MCP client talking to a child process over stdin/stdout.

    NOTE(review): messages are framed with LSP-style ``Content-Length`` headers.
    The MCP stdio transport specification uses newline-delimited JSON instead —
    confirm the target servers actually speak this framing before relying on it.
    """

    def __init__(self, server_name: str, config: dict[str, Any]):
        super().__init__(server_name, config)
        self.process: Optional[subprocess.Popen] = None
        self._request_id = 0
        self.timeout = self._parse_timeout(config.get("timeout", 30))

    def _parse_timeout(self, timeout_value: Any) -> float:
        """Coerce the configured timeout to float, failing with a clear error."""
        try:
            return float(timeout_value)
        except (TypeError, ValueError) as exc:
            raise MCPRuntimeError(
                f"Stdio MCP server '{self.server_name}' timeout must be a number"
            ) from exc

    def _normalize_args(self, args: Any) -> list[str]:
        """Validate the configured args list; convert path-like entries to str."""
        if not isinstance(args, list):
            raise MCPRuntimeError(f"Stdio MCP server '{self.server_name}' args must be a list")

        normalized: list[str] = []
        for entry in args:
            if isinstance(entry, os.PathLike):
                normalized.append(os.fspath(entry))
            elif isinstance(entry, str):
                normalized.append(entry)
            else:
                raise MCPRuntimeError(
                    f"Stdio MCP server '{self.server_name}' args must contain only strings or path-like values"
                )
        return normalized

    def connect(self):
        """Spawn the server process and run the initialize handshake."""
        command = self.config.get("command")
        if not command:
            raise MCPRuntimeError(f"Stdio MCP server '{self.server_name}' is missing 'command'")

        argv_tail = self._normalize_args(self.config.get("args") or [])

        env = os.environ.copy()
        extra_env = self.config.get("env") or {}
        if not isinstance(extra_env, dict):
            raise MCPRuntimeError(f"Stdio MCP server '{self.server_name}' env must be an object")
        env.update({str(k): str(v) for k, v in extra_env.items()})

        self.process = subprocess.Popen(
            [command, *argv_tail],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,  # server stderr logs are deliberately discarded
            env=env,
            cwd=self.config.get("cwd"),
        )
        try:
            init_result = self._send_request(
                "initialize",
                {
                    "protocolVersion": self.config.get("protocol_version", "2024-11-05"),
                    "capabilities": self.config.get("client_capabilities", {}),
                    "clientInfo": self.config.get(
                        "client_info",
                        {"name": "pr-agent", "version": "mcp-runtime"},
                    ),
                },
            )
            self.server_capabilities = init_result.get("capabilities", {})
            self._send_notification("notifications/initialized", {})
        except (MCPRuntimeError, OSError, subprocess.SubprocessError):
            # Leave no orphaned child process behind when the handshake fails.
            self._terminate_process()
            self.process = None
            raise

    def close(self):
        """Terminate the child process if it is still running."""
        if not self.process:
            return
        self._terminate_process()
        self.process = None

    def list_tools(self) -> list[MCPToolDefinition]:
        """Fetch and parse the server's tool catalog via tools/list."""
        result = self._send_request("tools/list", {})
        if not isinstance(result, dict):
            return []
        definitions: list[MCPToolDefinition] = []
        for raw_tool in result.get("tools") or []:
            if not isinstance(raw_tool, dict):
                continue
            definitions.append(
                MCPToolDefinition(
                    server_name=self.server_name,
                    name=raw_tool.get("name", ""),
                    description=raw_tool.get("description", ""),
                    input_schema=raw_tool.get("inputSchema", {}),
                )
            )
        return definitions

    def call_tool(self, tool_name: str, arguments: Optional[dict[str, Any]] = None) -> dict[str, Any]:
        """Invoke a tool via tools/call and return the JSON-RPC result."""
        return self._send_request(
            "tools/call",
            {
                "name": tool_name,
                "arguments": arguments or {},
            },
        )

    def _send_notification(self, method: str, params: dict[str, Any]):
        """Send a fire-and-forget JSON-RPC notification (no id, no response)."""
        self._write_message(
            {
                "jsonrpc": "2.0",
                "method": method,
                "params": params,
            }
        )

    def _send_request(self, method: str, params: dict[str, Any]) -> dict[str, Any]:
        """Send a JSON-RPC request and wait for its matching response."""
        self._request_id += 1
        request_id = self._request_id
        self._write_message(
            {
                "jsonrpc": "2.0",
                "id": request_id,
                "method": method,
                "params": params,
            }
        )
        response = self._read_response(request_id)
        if "error" in response:
            raise MCPRuntimeError(
                f"MCP server '{self.server_name}' returned error for '{method}': {response['error']}"
            )
        return response.get("result", {})

    def _write_message(self, payload: dict[str, Any]):
        """Serialize and write one Content-Length framed message to the child."""
        if not self.process or not self.process.stdin:
            raise MCPRuntimeError(f"Stdio MCP server '{self.server_name}' is not connected")

        encoded = json.dumps(payload).encode("utf-8")
        frame = f"Content-Length: {len(encoded)}\r\n\r\n".encode("utf-8") + encoded
        self.process.stdin.write(frame)
        self.process.stdin.flush()

    def _read_response(self, request_id: int) -> dict[str, Any]:
        """Read messages until one carries the requested id, skipping the rest."""
        while True:
            message = self._read_message_with_timeout()
            if not isinstance(message, dict):
                raise MCPRuntimeError(
                    f"MCP server '{self.server_name}' returned a non-object JSON-RPC message"
                )
            if message.get("id") == request_id:
                return message

    def _read_message_with_timeout(self) -> dict[str, Any]:
        """Read one framed message on a worker thread, enforcing self.timeout."""
        outcome: dict[str, Any] = {}
        failure: dict[str, Exception] = {}

        def _reader():
            try:
                outcome["message"] = self._read_message()
            except MCPRuntimeError as exc:
                failure["error"] = exc
            except (OSError, ValueError, json.JSONDecodeError) as exc:
                failure["error"] = exc

        worker = threading.Thread(target=_reader, daemon=True)
        worker.start()
        worker.join(timeout=self.timeout)

        if worker.is_alive():
            # The reader is stuck; kill the server so the daemon thread can unblock.
            self._terminate_process()
            self.process = None
            raise MCPRuntimeError(f"Stdio MCP server '{self.server_name}' timed out waiting for a response")

        if "error" in failure:
            error = failure["error"]
            if isinstance(error, MCPRuntimeError):
                raise error
            raise MCPRuntimeError(f"Stdio MCP server '{self.server_name}' failed while reading response") from error

        return outcome["message"]

    def _terminate_process(self):
        """Terminate the child gracefully, escalating to kill after 3 seconds."""
        if not self.process:
            return
        if self.process.poll() is None:
            self.process.terminate()
            try:
                self.process.wait(timeout=3)
            except subprocess.TimeoutExpired:
                self.process.kill()

    def _read_message(self) -> dict[str, Any]:
        """Read one Content-Length framed JSON message from the child's stdout."""
        if not self.process or not self.process.stdout:
            raise MCPRuntimeError(f"Stdio MCP server '{self.server_name}' is not connected")

        headers: dict[str, str] = {}
        while True:
            line = self.process.stdout.readline()
            if line == b"":
                raise MCPRuntimeError(f"MCP server '{self.server_name}' closed stdout unexpectedly")
            if line in (b"\r\n", b"\n"):
                break
            key, _, value = line.decode("utf-8").partition(":")
            headers[key.strip().lower()] = value.strip()

        content_length_value = headers.get("content-length")
        if not content_length_value:
            raise MCPRuntimeError(f"MCP server '{self.server_name}' response missing Content-Length")

        try:
            content_length = int(content_length_value)
        except ValueError as exc:
            raise MCPRuntimeError(
                f"MCP server '{self.server_name}' response has an invalid Content-Length: {content_length_value}"
            ) from exc
        if content_length <= 0:
            raise MCPRuntimeError(
                f"MCP server '{self.server_name}' response has a non-positive Content-Length: {content_length}"
            )

        body = self.process.stdout.read(content_length)
        if not body or len(body) != content_length:
            raise MCPRuntimeError(
                f"MCP server '{self.server_name}' returned an incomplete response body "
                f"({len(body) if body else 0}/{content_length} bytes)"
            )

        try:
            return json.loads(body.decode("utf-8"))
        except json.JSONDecodeError as exc:
            raise MCPRuntimeError(f"MCP server '{self.server_name}' response is not valid JSON") from exc
class MCPHttpClient(BaseMCPClient):
    """JSON-RPC MCP client that POSTs each message to a plain HTTP endpoint."""

    def __init__(self, server_name: str, config: dict[str, Any]):
        super().__init__(server_name, config)
        self.url = config.get("url")
        if not self.url:
            raise MCPRuntimeError(f"HTTP MCP server '{self.server_name}' is missing 'url'")

        try:
            self.timeout = float(config.get("timeout", 30))
        except (TypeError, ValueError) as exc:
            raise MCPRuntimeError(
                f"HTTP MCP server '{self.server_name}' timeout must be a number"
            ) from exc
        self._request_id = 0
        self._session = requests.Session()
        configured_headers = config.get("headers") or {}
        if isinstance(configured_headers, dict):
            self._session.headers.update({str(k): str(v) for k, v in configured_headers.items()})

    def connect(self):
        """Perform the MCP initialize handshake over HTTP."""
        init_result = self._send_request(
            "initialize",
            {
                "protocolVersion": self.config.get("protocol_version", "2024-11-05"),
                "capabilities": self.config.get("client_capabilities", {}),
                "clientInfo": self.config.get(
                    "client_info",
                    {"name": "pr-agent", "version": "mcp-runtime"},
                ),
            },
        )
        self.server_capabilities = init_result.get("capabilities", {})
        self._send_notification("notifications/initialized", {})

    def close(self):
        """Release the underlying HTTP connection pool."""
        self._session.close()

    def list_tools(self) -> list[MCPToolDefinition]:
        """Fetch and parse the server's tool catalog via tools/list."""
        result = self._send_request("tools/list", {})
        if not isinstance(result, dict):
            return []
        definitions: list[MCPToolDefinition] = []
        for raw_tool in result.get("tools") or []:
            if not isinstance(raw_tool, dict):
                continue
            definitions.append(
                MCPToolDefinition(
                    server_name=self.server_name,
                    name=raw_tool.get("name", ""),
                    description=raw_tool.get("description", ""),
                    input_schema=raw_tool.get("inputSchema", {}),
                )
            )
        return definitions

    def call_tool(self, tool_name: str, arguments: Optional[dict[str, Any]] = None) -> dict[str, Any]:
        """Invoke a tool via tools/call and return the JSON-RPC result."""
        return self._send_request(
            "tools/call",
            {
                "name": tool_name,
                "arguments": arguments or {},
            },
        )

    def _send_notification(self, method: str, params: dict[str, Any]):
        """POST a fire-and-forget JSON-RPC notification (no id, response ignored)."""
        payload = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
        }
        try:
            self._session.post(self.url, json=payload, timeout=self.timeout)
        except requests.RequestException as exc:
            raise MCPRuntimeError(f"MCP HTTP notification failed for '{self.server_name}': {exc}") from exc

    def _send_request(self, method: str, params: dict[str, Any]) -> dict[str, Any]:
        """POST a JSON-RPC request and validate/return its 'result' object."""
        self._request_id += 1
        payload = {
            "jsonrpc": "2.0",
            "id": self._request_id,
            "method": method,
            "params": params,
        }
        try:
            response = self._session.post(self.url, json=payload, timeout=self.timeout)
            response.raise_for_status()
            body = response.json()
        except requests.RequestException as exc:
            raise MCPRuntimeError(f"MCP HTTP request failed for '{self.server_name}': {exc}") from exc
        except ValueError as exc:
            raise MCPRuntimeError(f"MCP HTTP response is not valid JSON for '{self.server_name}'") from exc

        if not isinstance(body, dict):
            raise MCPRuntimeError(f"MCP HTTP response must be a JSON object for '{self.server_name}'")

        if "error" in body:
            raise MCPRuntimeError(
                f"MCP HTTP server '{self.server_name}' returned error for '{method}': {body['error']}"
            )

        result = body.get("result", {})
        if not isinstance(result, dict):
            raise MCPRuntimeError(f"MCP HTTP result for '{self.server_name}' must be a JSON object")
        return result
+ """ + + def __init__(self, server_name: str, config: dict[str, Any]): + super().__init__(server_name, config) + self.url = config.get("url") + if not self.url: + raise MCPRuntimeError(f"Streamable HTTP MCP server '{self.server_name}' is missing 'url'") + + try: + self.timeout = float(config.get("timeout", 30)) + except (TypeError, ValueError) as exc: + raise MCPRuntimeError( + f"Streamable HTTP MCP server '{self.server_name}' timeout must be a number" + ) from exc + self._request_id = 0 + self._session_id: Optional[str] = None + self._session = requests.Session() + self._session.headers.update( + { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + ) + headers = config.get("headers") or {} + if isinstance(headers, dict): + self._session.headers.update({str(k): str(v) for k, v in headers.items()}) + + def connect(self): + init_result = self._send_request( + "initialize", + { + "protocolVersion": self.config.get("protocol_version", "2024-11-05"), + "capabilities": self.config.get("client_capabilities", {}), + "clientInfo": self.config.get( + "client_info", + {"name": "pr-agent", "version": "mcp-runtime"}, + ), + }, + ) + self.server_capabilities = init_result.get("capabilities", {}) + self._send_notification("notifications/initialized", {}) + + def close(self): + self._session_id = None + self._session.close() + + def list_tools(self) -> list[MCPToolDefinition]: + result = self._send_request("tools/list", {}) + if not isinstance(result, dict): + return [] + tools = result.get("tools") or [] + parsed: list[MCPToolDefinition] = [] + for tool in tools: + if not isinstance(tool, dict): + continue + parsed.append( + MCPToolDefinition( + server_name=self.server_name, + name=tool.get("name", ""), + description=tool.get("description", ""), + input_schema=tool.get("inputSchema", {}), + ) + ) + return parsed + + def call_tool(self, tool_name: str, arguments: Optional[dict[str, Any]] = None) -> dict[str, Any]: + return 
self._send_request( + "tools/call", + { + "name": tool_name, + "arguments": arguments or {}, + }, + ) + + def _build_extra_headers(self) -> dict[str, str]: + if self._session_id: + return {"Mcp-Session-Id": self._session_id} + return {} + + def _send_notification(self, method: str, params: dict[str, Any]): + payload = { + "jsonrpc": "2.0", + "method": method, + "params": params, + } + try: + self._session.post( + self.url, + json=payload, + headers=self._build_extra_headers(), + timeout=self.timeout, + ) + except requests.RequestException as exc: + raise MCPRuntimeError( + f"MCP streamable HTTP notification failed for '{self.server_name}': {exc}" + ) from exc + + def _send_request(self, method: str, params: dict[str, Any]) -> dict[str, Any]: + self._request_id += 1 + request_id = self._request_id + payload = { + "jsonrpc": "2.0", + "id": request_id, + "method": method, + "params": params, + } + response: Optional["requests.Response"] = None + try: + response = self._session.post( + self.url, + json=payload, + headers=self._build_extra_headers(), + timeout=self.timeout, + stream=True, + ) + response.raise_for_status() + # Capture session ID returned on the initialize response. 
+ session_id = response.headers.get("Mcp-Session-Id") + if session_id and not self._session_id: + self._session_id = session_id + + content_type = response.headers.get("Content-Type", "") + if "text/event-stream" in content_type: + return self._parse_sse_response(response, request_id, method) + return self._parse_json_response(response, method) + except requests.RequestException as exc: + raise MCPRuntimeError( + f"MCP streamable HTTP request failed for '{self.server_name}': {exc}" + ) from exc + finally: + if response is not None: + response.close() + + def _parse_json_response(self, response: "requests.Response", method: str) -> dict[str, Any]: + try: + body = response.json() + except ValueError as exc: + raise MCPRuntimeError( + f"MCP streamable HTTP response is not valid JSON for '{self.server_name}'" + ) from exc + if not isinstance(body, dict): + raise MCPRuntimeError( + f"MCP streamable HTTP response must be a JSON object for '{self.server_name}'" + ) + return self._extract_result(body, method) + + def _parse_sse_response( + self, response: "requests.Response", request_id: int, method: str + ) -> dict[str, Any]: + """Read an SSE stream and return the JSON-RPC result that matches *request_id*.""" + try: + for raw_line in response.iter_lines(decode_unicode=True): + if not raw_line or not raw_line.startswith("data:"): + continue + data = raw_line[5:].lstrip(" ") + if not data: + continue + try: + message = json.loads(data) + except json.JSONDecodeError: + continue + if isinstance(message, dict) and message.get("id") == request_id: + return self._extract_result(message, method) + except requests.RequestException as exc: + raise MCPRuntimeError( + f"MCP streamable HTTP SSE stream error for '{self.server_name}': {exc}" + ) from exc + + raise MCPRuntimeError( + f"MCP streamable HTTP server '{self.server_name}' SSE stream ended without a matching " + f"response for request id {request_id}" + ) + + def _extract_result(self, body: dict[str, Any], method: 
Optional[str]) -> dict[str, Any]: + if not isinstance(body, dict): + raise MCPRuntimeError(f"MCP streamable HTTP response must be a JSON object for '{self.server_name}'") + if "error" in body: + raise MCPRuntimeError( + f"MCP streamable HTTP server '{self.server_name}' returned error" + + (f" for '{method}'" if method else "") + + f": {body['error']}" + ) + result = body.get("result", {}) + if not isinstance(result, dict): + raise MCPRuntimeError(f"MCP streamable HTTP result for '{self.server_name}' must be a JSON object") + return result + + +class MCPRuntime: + def __init__(self, servers_config: Optional[dict[str, Any]] = None): + self._logger = _get_logger() + self._resolve_env_vars = _parse_bool(get_settings().get("MCP.RESOLVE_ENV_VARS", True), default=True) + if servers_config is None: + servers_config = get_settings().get("MCP.SERVERS", {}) or {} + + if not isinstance(servers_config, dict): + self._logger.warning("MCP.SERVERS is not an object; ignoring MCP server configuration") + servers_config = {} + + self._servers_config = servers_config + self._clients: dict[str, BaseMCPClient] = {} + + @property + def configured_server_names(self) -> list[str]: + return list(self._servers_config.keys()) + + @property + def enabled(self) -> bool: + return _parse_bool(get_settings().get("MCP.ENABLED", False), default=False) + + @property + def enabled_server_names(self) -> list[str]: + return self.configured_server_names + + def connect_all(self): + if not self.enabled: + self._logger.debug("MCP runtime is disabled; skipping server connections") + return + for server_name in self._servers_config.keys(): + self.connect_server(server_name) + + def connect_server(self, server_name: str): + if not self.enabled: + raise MCPRuntimeError("MCP runtime is disabled") + if server_name in self._clients: + return + + server_config = self._servers_config.get(server_name) + if not isinstance(server_config, dict): + raise MCPRuntimeError(f"MCP server '{server_name}' config must be an 
object") + + client = self._build_client(server_name, self._resolve_config_values(server_config)) + try: + client.connect() + except (MCPRuntimeError, OSError, requests.RequestException, subprocess.SubprocessError): + client.close() + raise + + self._clients[server_name] = client + self._logger.info(f"Connected MCP server '{server_name}'") + + def disconnect_all(self): + for server_name in list(self._clients.keys()): + self.disconnect_server(server_name) + + def disconnect_server(self, server_name: str): + client = self._clients.pop(server_name, None) + if client: + client.close() + self._logger.info(f"Disconnected MCP server '{server_name}'") + + def list_server_tools(self, server_name: str) -> list[MCPToolDefinition]: + client = self._clients.get(server_name) + if not client: + self.connect_server(server_name) + client = self._clients[server_name] + try: + return client.list_tools() + except (MCPRuntimeError, OSError) as exc: + # Evict the stale client so the next call reconnects rather than reusing a dead connection. 
+ self._clients.pop(server_name, None) + client.close() + raise MCPRuntimeError( + f"MCP server '{server_name}' failed during tool listing; evicted from cache" + ) from exc + + def list_all_tools(self) -> list[MCPToolDefinition]: + if not self.enabled: + return [] + + tools: list[MCPToolDefinition] = [] + for server_name in self._servers_config.keys(): + try: + tools.extend(self.list_server_tools(server_name)) + except (MCPRuntimeError, OSError, ValueError, TypeError) as exc: + self._logger.warning(f"Failed to list tools for MCP server '{server_name}': {exc}") + return tools + + def build_tool_schemas( + self, + server_names: Optional[list[str]] = None, + max_tools: Optional[int] = None, + max_schema_chars: Optional[int] = None, + include_server_prefix: bool = True, + ) -> list[dict[str, Any]]: + tool_definitions = self.list_all_tools() + if server_names: + allowed_servers = set(server_names) + tool_definitions = [tool for tool in tool_definitions if tool.server_name in allowed_servers] + + schemas: list[dict[str, Any]] = [] + consumed_chars = 0 + for tool_definition in tool_definitions: + schema = tool_definition.to_openai_tool(include_server_prefix=include_server_prefix) + schema_text = json.dumps(schema, sort_keys=True) + if max_schema_chars is not None and consumed_chars + len(schema_text) > max_schema_chars: + self._logger.debug( + f"Skipping MCP tool '{tool_definition.server_name}.{tool_definition.name}': " + f"would exceed schema budget ({consumed_chars + len(schema_text)} > {max_schema_chars})" + ) + continue + schemas.append(schema) + consumed_chars += len(schema_text) + if max_tools is not None and len(schemas) >= max_tools: + break + + return schemas + + def create_tool_executor(self, allowed_tool_names: Optional[set[str]] = None): + import asyncio + import functools + + runtime_ref = self + allowed_names = {name for name in (allowed_tool_names or set()) if isinstance(name, str) and name} + + async def executor(tool_name: str, arguments: 
Optional[dict[str, Any]] = None): + if allowed_names and tool_name not in allowed_names: + raise MCPRuntimeError(f"Tool not available: {tool_name}") + server_name, server_tool_name = runtime_ref._split_tool_name(tool_name) + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + functools.partial(runtime_ref.call_tool, server_name, server_tool_name, arguments), + ) + + return executor + + def call_tool( + self, + server_name: str, + tool_name: str, + arguments: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + if not self.enabled: + raise MCPRuntimeError("MCP runtime is disabled") + client = self._clients.get(server_name) + if not client: + self.connect_server(server_name) + client = self._clients[server_name] + try: + return client.call_tool(tool_name, arguments) + except (MCPRuntimeError, OSError) as exc: + # Evict the stale client so the next call reconnects rather than reusing a dead connection. + self._clients.pop(server_name, None) + client.close() + raise MCPRuntimeError( + f"MCP server '{server_name}' failed during tool call '{tool_name}'; evicted from cache" + ) from exc + + def get_server_capabilities(self, server_name: str) -> dict[str, Any]: + client = self._clients.get(server_name) + if not client: + self.connect_server(server_name) + client = self._clients[server_name] + return client.server_capabilities + + def get_status(self) -> dict[str, Any]: + return { + "enabled": self.enabled, + "configured_servers": self.configured_server_names, + "connected_servers": list(self._clients.keys()), + } + + def _resolve_config_values(self, value: Any) -> Any: + if isinstance(value, str): + if not self._resolve_env_vars: + return value + return os.path.expanduser(os.path.expandvars(value)) + if isinstance(value, list): + return [self._resolve_config_values(item) for item in value] + if isinstance(value, dict): + return {key: self._resolve_config_values(item) for key, item in value.items()} + return value + + def 
_split_tool_name(self, tool_name: str) -> tuple[str, str]: + # Try to match against known server names first (handles server names containing '.') + for server_name in self._servers_config: + prefix = f"{server_name}." + if tool_name.startswith(prefix): + tool_short_name = tool_name[len(prefix):] + if tool_short_name: + return server_name, tool_short_name + + if len(self._servers_config) == 1: + server_name = next(iter(self._servers_config.keys())) + return server_name, tool_name + + raise MCPRuntimeError( + f"Tool name '{tool_name}' must use the '.' form when multiple MCP servers are configured" + ) + + def _build_client(self, server_name: str, server_config: dict[str, Any]) -> BaseMCPClient: + server_type = str(server_config.get("type", "")).lower() + + if not server_type: + if server_config.get("url"): + server_type = "http" + elif server_config.get("command"): + server_type = "stdio" + else: + raise MCPRuntimeError( + f"MCP server '{server_name}' must define a transport type or command/url" + ) + + if server_type == "stdio": + return MCPStdioClient(server_name, server_config) + if server_type in {"http", "https"}: + return MCPHttpClient(server_name, server_config) + if server_type == "streamable_http": + return MCPStreamableHttpClient(server_name, server_config) + + raise MCPRuntimeError( + f"MCP server '{server_name}' uses unsupported transport type '{server_type}'" + ) diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 16ffbcae2a..d555cb540e 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -70,6 +70,14 @@ extract_issue_from_branch = true # followed by hyphen or end (e.g. feature/1-test, 123-fix). GitHub only; other providers planned for later. 
branch_issue_regex = "" +[mcp] +enabled = false +config_path = "/etc/pr-agent/mcp.json" +fail_on_invalid_config = false +resolve_env_vars = true +max_tool_catalog_tools = 12 +max_tool_catalog_schema_chars = 12000 + [pr_reviewer] # /review # # enable/disable features diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index bbdf58e46d..8ed039425e 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -26,6 +26,7 @@ get_git_provider_with_context) from pr_agent.git_providers.git_provider import get_main_pr_language, GitProvider from pr_agent.log import get_logger +from pr_agent.mcp.integration import maybe_chat_completion_with_mcp from pr_agent.servers.help import HelpMessage from pr_agent.tools.pr_description import insert_br_after_x_chars @@ -390,8 +391,14 @@ async def _get_prediction(self, model: str, patches_diff: str, patches_diff_no_l environment = Environment(undefined=StrictUndefined) system_prompt = environment.from_string(self.pr_code_suggestions_prompt_system).render(variables) user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables) - response, finish_reason = await self.ai_handler.chat_completion( - model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt) + response, finish_reason = await maybe_chat_completion_with_mcp( + self.ai_handler, + model=model, + temperature=get_settings().config.temperature, + system=system_prompt, + user=user_prompt, + command_name="improve", + ) if not get_settings().config.publish_output: get_settings().system_prompt = system_prompt get_settings().user_prompt = user_prompt diff --git a/pr_agent/tools/pr_config.py b/pr_agent/tools/pr_config.py index 24ecaab97b..5f68153ba4 100644 --- a/pr_agent/tools/pr_config.py +++ b/pr_agent/tools/pr_config.py @@ -3,6 +3,7 @@ from pr_agent.config_loader import get_settings from pr_agent.git_providers import 
get_git_provider from pr_agent.log import get_logger +from pr_agent.mcp import MCPRuntime, MCPRuntimeError class PRConfig: @@ -48,6 +49,7 @@ def _prepare_pr_configs(self) -> str: get_logger().error("Caught exception during Dynaconf loading. Returning empty dict", artifact={"exception": e}) conf_settings = {} + markdown_text = self._prepare_mcp_status_block() configuration_headers = [header.lower() for header in conf_settings.keys()] relevant_configs = { header: configs for header, configs in get_settings().to_dict().items() @@ -65,7 +67,7 @@ def _prepare_pr_configs(self) -> str: skip_keys_lower = [key.lower() for key in skip_keys] - markdown_text = "
🛠️ PR-Agent Configurations: \n\n" + markdown_text += "
🛠️ PR-Agent Configurations: \n\n" markdown_text += f"\n\n```yaml\n\n" for header, configs in relevant_configs.items(): if configs: @@ -82,3 +84,20 @@ def _prepare_pr_configs(self) -> str: markdown_text += "\n
\n" get_logger().info(f"Possible Configurations outputted to PR comment", artifact=markdown_text) return markdown_text + + def _prepare_mcp_status_block(self) -> str: + try: + status = MCPRuntime().get_status() + except (MCPRuntimeError, ValueError, TypeError, KeyError): + return "" + if not status["enabled"] and not status["configured_servers"]: + return "" + + markdown_text = "
MCP Runtime Status
\n\n" + markdown_text += "```yaml\n" + markdown_text += f"mcp.enabled = {status['enabled']}\n" + markdown_text += f"mcp.configured_servers = {status['configured_servers']}\n" + markdown_text += f"mcp.connected_servers = {status['connected_servers']}\n" + markdown_text += "```\n" + markdown_text += "
\n\n" + return markdown_text diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 7cdb7984f7..5af7ba8245 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -9,9 +9,10 @@ from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import ModelType from pr_agent.config_loader import get_settings -from pr_agent.git_providers import get_git_provider, GitLabProvider +from pr_agent.git_providers import GitLabProvider, get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.log import get_logger +from pr_agent.mcp.integration import maybe_chat_completion_with_mcp from pr_agent.servers.help import HelpMessage @@ -79,16 +80,16 @@ async def run(self): return "" def identify_image_in_comment(self): - img_path = '' - if '![image]' in self.question_str: + img_path = "" + if "![image]" in self.question_str: # assuming structure: # /ask question ... > ![image](img_path) - img_path = self.question_str.split('![image]')[1].strip().strip('()') - self.vars['img_path'] = img_path - elif 'https://' in self.question_str and ('.png' in self.question_str or 'jpg' in self.question_str): # direct image link + img_path = self.question_str.split("![image]")[1].strip().strip("()") + self.vars["img_path"] = img_path + elif "https://" in self.question_str and (".png" in self.question_str or "jpg" in self.question_str): # direct image link # include https:// in the image path - img_path = 'https://' + self.question_str.split('https://')[1] - self.vars['img_path'] = img_path + img_path = "https://" + self.question_str.split("https://")[1] + self.vars["img_path"] = img_path return img_path async def _prepare_prediction(self, model: str): @@ -106,14 +107,16 @@ async def _get_prediction(self, model: str): environment = Environment(undefined=StrictUndefined) system_prompt = environment.from_string(get_settings().pr_questions_prompt.system).render(variables) user_prompt = 
environment.from_string(get_settings().pr_questions_prompt.user).render(variables) - if 'img_path' in variables: - img_path = self.vars['img_path'] - response, finish_reason = await (self.ai_handler.chat_completion - (model=model, temperature=get_settings().config.temperature, - system=system_prompt, user=user_prompt, img_path=img_path)) - else: - response, finish_reason = await self.ai_handler.chat_completion( - model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt) + img_path = variables.get("img_path") + response, finish_reason = await maybe_chat_completion_with_mcp( + self.ai_handler, + model=model, + temperature=get_settings().config.temperature, + system=system_prompt, + user=user_prompt, + img_path=img_path, + command_name="ask", + ) return response def gitlab_protections(self, model_answer: str) -> str: diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index c4917f3597..f5e230eda7 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -22,6 +22,7 @@ from pr_agent.git_providers.git_provider import (IncrementalPR, get_main_pr_language) from pr_agent.log import get_logger +from pr_agent.mcp.integration import maybe_chat_completion_with_mcp from pr_agent.servers.help import HelpMessage from pr_agent.tools.ticket_pr_compliance_check import ( extract_and_cache_pr_tickets, extract_tickets) @@ -217,11 +218,13 @@ async def _get_prediction(self, model: str) -> str: system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables) - response, finish_reason = await self.ai_handler.chat_completion( + response, finish_reason = await maybe_chat_completion_with_mcp( + self.ai_handler, model=model, temperature=get_settings().config.temperature, system=system_prompt, - user=user_prompt + user=user_prompt, + command_name="review", ) return response diff 
--git a/requirements.txt b/requirements.txt index 9ef63beb97..ece717c9be 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,6 +24,7 @@ PyJWT==2.10.1 PyYAML==6.0.1 python-gitlab==3.15.0 retry==0.9.2 +requests==2.32.3 starlette-context==0.3.6 tiktoken==0.8.0 ujson==5.8.0 diff --git a/tests/unittest/test_ai_handler_tool_orchestration.py b/tests/unittest/test_ai_handler_tool_orchestration.py new file mode 100644 index 0000000000..18ca0ecdf5 --- /dev/null +++ b/tests/unittest/test_ai_handler_tool_orchestration.py @@ -0,0 +1,263 @@ +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler + + +class FakeToolHandler(BaseAiHandler): + def __init__(self, responses): + self._responses = list(responses) + self.calls = [] + + @property + def deployment_id(self): + return None + + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None): + self.calls.append( + { + "model": model, + "system": system, + "user": user, + "temperature": temperature, + "img_path": img_path, + } + ) + if not self._responses: + raise AssertionError("No more fake responses available") + return self._responses.pop(0), "completed" + + +class TestToolOrchestration: + async def test_tool_loop_executes_and_returns_final_answer(self): + handler = FakeToolHandler( + [ + '{"type": "tool_call", "tool": "mcp.echo", "arguments": {"text": "hello"}}', + '{"type": "final", "content": "done"}', + ] + ) + + tool_calls = [] + + async def executor(tool_name, arguments): + tool_calls.append((tool_name, arguments)) + return {"output": arguments["text"].upper()} + + tools = [ + { + "name": "mcp.echo", + "description": "Echo input text", + "inputSchema": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ] + + response, finish_reason = await handler.chat_completion_with_tools( + model="gpt-5.4", + system="system prompt", + user="user prompt", + tools=tools, + tool_executor=executor, + 
max_tool_turns=2, + ) + + assert response == "done" + assert finish_reason == "completed" + assert tool_calls == [("mcp.echo", {"text": "hello"})] + assert len(handler.calls) == 2 + assert "Available MCP tools" in handler.calls[0]["system"] + assert "Always inspect the available tools first" in handler.calls[0]["system"] + assert "Tool result for mcp.echo" in handler.calls[1]["user"] + + async def test_tool_loop_ignores_trailing_final_after_tool_call(self): + first_turn_response = "".join( + [ + "Ask question and use the documented tool format.\n", + "{\"type\":\"tool_call\",", + "\"tool\":\"AWS Knowledge.aws___read_documentation\",", + "\"arguments\":{\"requests\":[", + "{\"url\":\"https://aws.amazon.com/about-aws/whats-new/\",\"max_length\":8000}", + "]}}\n", + "{\"type\":\"final\",\"content\":\"should be ignored in the first turn\"}", + ] + ) + handler = FakeToolHandler( + [ + first_turn_response, + '{"type": "final", "content": "done"}', + ] + ) + + tool_calls = [] + + async def executor(tool_name, arguments): + tool_calls.append((tool_name, arguments)) + return {"output": "ok"} + + tools = [ + { + "name": "AWS Knowledge.aws___read_documentation", + "description": "Read AWS documentation", + "inputSchema": { + "type": "object", + "properties": {"requests": {"type": "array"}}, + }, + } + ] + + response, finish_reason = await handler.chat_completion_with_tools( + model="gpt-5.4", + system="system prompt", + user="user prompt", + tools=tools, + tool_executor=executor, + max_tool_turns=2, + ) + + assert response == "done" + assert finish_reason == "completed" + assert tool_calls == [ + ( + "AWS Knowledge.aws___read_documentation", + {"requests": [{"url": "https://aws.amazon.com/about-aws/whats-new/", "max_length": 8000}]}, + ) + ] + assert len(handler.calls) == 2 + + async def test_tool_loop_falls_back_without_tools(self): + handler = FakeToolHandler(["plain response"]) + + response, finish_reason = await handler.chat_completion_with_tools( + model="gpt-5.4", + 
system="system prompt", + user="user prompt", + ) + + assert response == "plain response" + assert finish_reason == "completed" + assert len(handler.calls) == 1 + + async def test_tool_loop_blocks_non_advertised_tool_name(self): + handler = FakeToolHandler( + [ + '{"type": "tool_call", "tool": "mcp.hidden", "arguments": {}}', + '{"type": "final", "content": "done"}', + ] + ) + + tool_calls = [] + + async def executor(tool_name, arguments): + tool_calls.append((tool_name, arguments)) + return {"output": "should-not-run"} + + tools = [ + { + "type": "function", + "function": { + "name": "mcp.allowed", + "description": "Allowed tool", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + + response, finish_reason = await handler.chat_completion_with_tools( + model="gpt-5.4", + system="system prompt", + user="user prompt", + tools=tools, + tool_executor=executor, + max_tool_turns=2, + ) + + assert response == "done" + assert finish_reason == "completed" + assert tool_calls == [] + assert "Tool not available: mcp.hidden" in handler.calls[1]["user"] + + async def test_tool_loop_handles_expected_executor_error(self): + handler = FakeToolHandler( + [ + '{"type": "tool_call", "tool": "mcp.echo", "arguments": {}}', + '{"type": "final", "content": "done"}', + ] + ) + + async def executor(tool_name, arguments): + raise ValueError("bad input") + + tools = [{"name": "mcp.echo", "description": "Echo", "inputSchema": {"type": "object"}}] + + response, finish_reason = await handler.chat_completion_with_tools( + model="gpt-5.4", + system="system prompt", + user="user prompt", + tools=tools, + tool_executor=executor, + max_tool_turns=2, + ) + + assert response == "done" + assert finish_reason == "completed" + assert "Tool error: bad input" in handler.calls[1]["user"] + + async def test_tool_output_limit_is_per_tool_and_warns_when_truncated(self, caplog): + handler = FakeToolHandler( + [ + '{"type": "tool_call", "tool": "mcp.big_one", "arguments": {}}', + '{"type": 
"tool_call", "tool": "mcp.big_two", "arguments": {}}', + '{"type": "final", "content": "done"}', + ] + ) + + tool_calls = [] + + async def executor(tool_name, arguments): + tool_calls.append((tool_name, arguments)) + return "x" * 80 + + tools = [ + { + "name": "mcp.big_one", + "description": "Large output tool one", + "inputSchema": {"type": "object", "properties": {}}, + }, + { + "name": "mcp.big_two", + "description": "Large output tool two", + "inputSchema": {"type": "object", "properties": {}}, + }, + ] + + caplog.set_level("WARNING") + + response, finish_reason = await handler.chat_completion_with_tools( + model="gpt-5.4", + system="system prompt", + user="user prompt", + tools=tools, + tool_executor=executor, + max_tool_turns=3, + max_tool_output_chars=40, + ) + + assert response == "done" + assert finish_reason == "completed" + assert tool_calls == [("mcp.big_one", {}), ("mcp.big_two", {})] + assert "[tool output truncated]" in handler.calls[1]["user"] + assert "[tool output truncated]" in handler.calls[2]["user"] + + warning_messages = [record.message for record in caplog.records if "max_tool_output_chars" in record.message] + assert len(warning_messages) == 2 + assert any("mcp.big_one" in message for message in warning_messages) + assert any("mcp.big_two" in message for message in warning_messages) + + def test_tool_output_limit_never_exceeds_tiny_cap(self): + truncated = BaseAiHandler._normalize_tool_result_text( + tool_result="x" * 200, + max_tool_output_chars=5, + tool_name="mcp.tiny_cap", + ) + + assert len(truncated) == 5 diff --git a/tests/unittest/test_config_loader_mcp.py b/tests/unittest/test_config_loader_mcp.py new file mode 100644 index 0000000000..96dfd5238c --- /dev/null +++ b/tests/unittest/test_config_loader_mcp.py @@ -0,0 +1,193 @@ +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from pr_agent.config_loader import _strip_json_comments, apply_mcp_server_config, load_mcp_server_config 
+from pr_agent.config_loader import load_repo_pyproject_settings + + +class TestMCPConfigLoader: + def test_strip_json_comments_preserves_strings(self): + content = ( + '{\n' + ' // comment\n' + ' "url": "https://example.com//path",\n' + ' /* block */\n' + ' "key": "value"\n' + '}' + ) + stripped = _strip_json_comments(content) + data = json.loads(stripped) + assert data == {"url": "https://example.com//path", "key": "value"} + + def test_load_mcp_server_config_supports_vscode_schema(self, tmp_path): + config_path = tmp_path / "mcp.json" + config_path.write_text( + '{\n' + ' // VS Code schema\n' + ' "servers": {\n' + ' "redmine": {"type": "stdio", "command": "podman"}\n' + ' }\n' + '}', + encoding="utf-8", + ) + config_data = load_mcp_server_config(config_path) + assert config_data == { + "servers": { + "redmine": {"type": "stdio", "command": "podman"}, + } + } + + def test_load_mcp_server_config_supports_trailing_commas(self, tmp_path): + config_path = tmp_path / "mcp-trailing-commas.jsonc" + config_path.write_text( + "{\n" + ' "servers": {\n' + ' "redmine": {"type": "stdio", "command": "podman",},\n' + " },\n" + "}\n", + encoding="utf-8", + ) + + config_data = load_mcp_server_config(config_path) + + assert config_data == { + "servers": { + "redmine": {"type": "stdio", "command": "podman"}, + } + } + + def test_load_mcp_server_config_supports_claude_schema(self, tmp_path): + config_path = tmp_path / ".mcp.json" + config_path.write_text( + '{"mcpServers": {"sourcebot": {"type": "http", "url": "https://example.com/mcp"}}}', + encoding="utf-8", + ) + config_data = load_mcp_server_config(config_path) + assert config_data == { + "servers": { + "sourcebot": {"type": "http", "url": "https://example.com/mcp"}, + } + } + + def test_load_mcp_server_config_supports_aws_knowledge_schema(self, tmp_path): + config_path = tmp_path / "aws-knowledge.json" + config_path.write_text( + '{"servers": {"AWS Knowledge": {"url": "https://knowledge-mcp.global.api.aws", "type": "http"}}}', 
+ encoding="utf-8", + ) + config_data = load_mcp_server_config(config_path) + assert config_data == { + "servers": { + "AWS Knowledge": {"url": "https://knowledge-mcp.global.api.aws", "type": "http"}, + } + } + + def test_load_mcp_server_config_raises_on_missing_servers(self, tmp_path): + config_path = tmp_path / "bad.json" + config_path.write_text('{"other": {}}', encoding="utf-8") + with pytest.raises(ValueError, match="must define either"): + load_mcp_server_config(config_path) + + def test_load_mcp_server_config_raises_on_invalid_json(self, tmp_path): + config_path = tmp_path / "invalid.json" + config_path.write_text('{"servers":', encoding="utf-8") + with pytest.raises(ValueError, match="Invalid MCP config JSON"): + load_mcp_server_config(config_path) + + def test_load_mcp_server_config_raises_on_missing_file(self, tmp_path): + config_path = tmp_path / "nonexistent.json" + with pytest.raises(FileNotFoundError): + load_mcp_server_config(config_path) + + def test_apply_mcp_server_config_uses_env_override(self, tmp_path): + config_path = tmp_path / "override.json" + config_path.write_text( + '{"servers": {"knowledge": {"type": "http", "url": "https://kb/mcp"}}}', + encoding="utf-8", + ) + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: ( + False if key == "MCP.FAIL_ON_INVALID_CONFIG" else default + ) + with patch.dict("os.environ", {"MCP_CONFIG_PATH": str(config_path)}, clear=False), \ + patch("pr_agent.config_loader.get_settings", return_value=settings): + apply_mcp_server_config() + settings.set.assert_any_call( + "MCP.SERVERS", + {"knowledge": {"type": "http", "url": "https://kb/mcp"}}, + merge=False, + ) + settings.set.assert_any_call("MCP.ACTIVE_CONFIG_PATH", str(config_path), merge=False) + + def test_apply_mcp_server_config_raises_when_configured(self, tmp_path): + config_path = tmp_path / "invalid.json" + config_path.write_text('{"servers":', encoding="utf-8") + settings = MagicMock() + settings.get.side_effect = lambda key, 
default=None: ( + True if key == "MCP.FAIL_ON_INVALID_CONFIG" else default + ) + with patch.dict("os.environ", {"MCP_CONFIG_PATH": str(config_path)}, clear=False), \ + patch("pr_agent.config_loader.get_settings", return_value=settings): + with pytest.raises(ValueError, match="Invalid MCP config JSON"): + apply_mcp_server_config() + + def test_apply_mcp_server_config_skips_when_no_file(self, tmp_path): + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: ( + str(tmp_path / "nonexistent.json") if key == "MCP.CONFIG_PATH" else default + ) + with patch("pr_agent.config_loader.get_settings", return_value=settings): + apply_mcp_server_config() # must not raise + settings.set.assert_not_called() + + def test_apply_mcp_server_config_handles_exists_oserror(self): + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: ( + False if key == "MCP.FAIL_ON_INVALID_CONFIG" else default + ) + + with patch("pr_agent.config_loader.get_settings", return_value=settings), patch( + "pathlib.Path.exists", side_effect=PermissionError("denied") + ): + apply_mcp_server_config() + + settings.set.assert_not_called() + + def test_apply_mcp_server_config_raises_on_exists_oserror_when_configured(self): + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: ( + True if key == "MCP.FAIL_ON_INVALID_CONFIG" else default + ) + + with patch("pr_agent.config_loader.get_settings", return_value=settings), patch( + "pathlib.Path.exists", side_effect=PermissionError("denied") + ): + with pytest.raises(PermissionError, match="denied"): + apply_mcp_server_config() + + def test_load_repo_pyproject_settings_preserves_trusted_mcp_settings(self, tmp_path): + pyproject_path = tmp_path / "pyproject.toml" + pyproject_path.write_text( + "[tool.pr-agent.mcp]\n" + "enabled = true\n" + 'config_path = "/tmp/untrusted.json"\n', + encoding="utf-8", + ) + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: ( + {"ENABLED": 
False, "CONFIG_PATH": "/etc/pr-agent/mcp.json", "RESOLVE_ENV_VARS": True} + if key == "MCP" + else default + ) + + load_repo_pyproject_settings(pyproject_path=pyproject_path, settings=settings) + + settings.load_file.assert_called_once_with(pyproject_path, env="tool.pr-agent") + settings.set.assert_called_once_with( + "MCP", + {"ENABLED": False, "CONFIG_PATH": "/etc/pr-agent/mcp.json", "RESOLVE_ENV_VARS": True}, + merge=False, + ) diff --git a/tests/unittest/test_mcp_integration_helper.py b/tests/unittest/test_mcp_integration_helper.py new file mode 100644 index 0000000000..93586be386 --- /dev/null +++ b/tests/unittest/test_mcp_integration_helper.py @@ -0,0 +1,128 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from pr_agent.mcp.integration import maybe_chat_completion_with_mcp + + +class FakeHandler: + def __init__(self): + self.chat_calls = [] + self.tool_calls = [] + + async def chat_completion(self, model, system, user, temperature=0.2, img_path=None): + self.chat_calls.append( + { + "model": model, + "system": system, + "user": user, + "temperature": temperature, + "img_path": img_path, + } + ) + return "plain response", "completed" + + async def chat_completion_with_tools( + self, + model, + system, + user, + tools, + tool_executor, + temperature=0.2, + img_path=None, + ): + self.tool_calls.append( + { + "model": model, + "system": system, + "user": user, + "tools": tools, + "temperature": temperature, + "img_path": img_path, + } + ) + return "tool response", "completed" + + +class FakeRuntime: + def __init__(self, enabled, tools=None): + self.enabled = enabled + self.tools = tools or [] + self.executor_created = False + self.disconnected = False + self.allowed_tool_names = None + + def build_tool_schemas(self, **kwargs): + self.build_kwargs = kwargs + return self.tools + + def create_tool_executor(self, allowed_tool_names=None): + self.executor_created = True + self.allowed_tool_names = allowed_tool_names + + async def executor(tool_name, 
arguments=None): + return {"tool": tool_name, "arguments": arguments or {}} + + return executor + + def disconnect_all(self): + self.disconnected = True + + +class TestMCPIntegrationHelper: + @pytest.mark.asyncio + async def test_falls_back_when_runtime_disabled(self): + handler = FakeHandler() + runtime = FakeRuntime(enabled=False) + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: default + + with patch("pr_agent.mcp.integration.MCPRuntime", return_value=runtime), patch( + "pr_agent.mcp.integration.get_settings", return_value=settings + ): + response, finish_reason = await maybe_chat_completion_with_mcp( + handler, + model="gpt-5.4", + system="system prompt", + user="user prompt", + ) + + assert response == "plain response" + assert finish_reason == "completed" + assert len(handler.chat_calls) == 1 + assert not handler.tool_calls + assert runtime.disconnected + + @pytest.mark.asyncio + async def test_uses_tool_orchestration_when_enabled(self): + handler = FakeHandler() + runtime = FakeRuntime( + enabled=True, + tools=[{"type": "function", "function": {"name": "alpha.echo", "parameters": {}}}], + ) + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: { + "MCP.MAX_TOOL_CATALOG_TOOLS": 2, + "MCP.MAX_TOOL_CATALOG_SCHEMA_CHARS": 1000, + "MCP.ENABLED_SERVERS": None, + }.get(key, default) + + with patch("pr_agent.mcp.integration.MCPRuntime", return_value=runtime), patch( + "pr_agent.mcp.integration.get_settings", return_value=settings + ): + response, finish_reason = await maybe_chat_completion_with_mcp( + handler, + model="gpt-5.4", + system="system prompt", + user="user prompt", + command_name="review", + ) + + assert response == "tool response" + assert finish_reason == "completed" + assert runtime.executor_created + assert runtime.disconnected + assert len(handler.tool_calls) == 1 + assert "Command context: review" in handler.tool_calls[0]["system"] + assert runtime.allowed_tool_names == {"alpha.echo"} diff 
--git a/tests/unittest/test_mcp_runtime.py b/tests/unittest/test_mcp_runtime.py new file mode 100644 index 0000000000..ee98181fe0 --- /dev/null +++ b/tests/unittest/test_mcp_runtime.py @@ -0,0 +1,387 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from pr_agent.mcp.runtime import ( + MCPHttpClient, + MCPRuntime, + MCPRuntimeError, + MCPStdioClient, + MCPStreamableHttpClient, + MCPToolDefinition, +) + + +class FakeClient: + def __init__(self, name, tools=None): + self.name = name + self.tools = tools or [] + self.connected = False + self.server_capabilities = {"tools": True} + + def connect(self): + self.connected = True + + def close(self): + self.connected = False + + def list_tools(self): + return self.tools + + def call_tool(self, tool_name, arguments=None): + return { + "name": tool_name, + "arguments": arguments or {}, + "server": self.name, + } + + +class FailingConnectClient(FakeClient): + def __init__(self, name, tools=None): + super().__init__(name, tools) + self.closed = False + + def connect(self): + raise MCPRuntimeError("connect failed") + + def close(self): + self.closed = True + + +class TestMCPRuntime: + def test_runtime_uses_settings_when_not_provided(self): + with patch( + "pr_agent.mcp.runtime.get_settings", + return_value={ + "MCP.SERVERS": {"srv": {"type": "http", "url": "https://example.com/mcp"}}, + "MCP": {"ENABLED": True}, + }, + ): + runtime = MCPRuntime() + assert runtime.configured_server_names == ["srv"] + assert runtime.enabled + + def test_build_client_from_type(self): + runtime = MCPRuntime(servers_config={"s1": {"type": "stdio", "command": "echo"}}) + client = runtime._build_client("s1", {"type": "stdio", "command": "echo"}) + assert isinstance(client, MCPStdioClient) + + client = runtime._build_client("s2", {"type": "http", "url": "https://example.com/mcp"}) + assert isinstance(client, MCPHttpClient) + + def test_build_client_type_inferred(self): + runtime = 
MCPRuntime(servers_config={}) + + client = runtime._build_client("s1", {"command": "echo"}) + assert isinstance(client, MCPStdioClient) + + client = runtime._build_client("s2", {"url": "https://example.com/mcp"}) + assert isinstance(client, MCPHttpClient) + + def test_build_client_unsupported_transport(self): + runtime = MCPRuntime(servers_config={}) + with pytest.raises(MCPRuntimeError, match="unsupported transport"): + runtime._build_client("bad", {"type": "sse", "url": "https://example.com/sse"}) + + def test_list_all_tools(self): + alpha_tools = [MCPToolDefinition("alpha", "tool_a", "desc", {})] + beta_tools = [MCPToolDefinition("beta", "tool_b", "desc", {})] + fake_clients = { + "alpha": FakeClient("alpha", alpha_tools), + "beta": FakeClient("beta", beta_tools), + } + + with patch("pr_agent.mcp.runtime.get_settings", return_value={"MCP": {"ENABLED": True}}): + runtime = MCPRuntime( + servers_config={ + "alpha": {"type": "http", "url": "https://alpha.example.com/mcp"}, + "beta": {"type": "http", "url": "https://beta.example.com/mcp"}, + } + ) + + with patch.object(runtime, "_build_client", side_effect=lambda name, cfg: fake_clients[name]): + all_tools = runtime.list_all_tools() + + assert {tool.name for tool in all_tools} == {"tool_a", "tool_b"} + assert runtime.get_status() == { + "enabled": True, + "configured_servers": ["alpha", "beta"], + "connected_servers": ["alpha", "beta"], + } + + def test_call_tool_connects_lazy(self): + fake_client = FakeClient("alpha") + + with patch("pr_agent.mcp.runtime.get_settings", return_value={"MCP": {"ENABLED": True}}): + runtime = MCPRuntime( + servers_config={"alpha": {"type": "http", "url": "https://alpha.example.com/mcp"}} + ) + + with patch.object(runtime, "_build_client", return_value=fake_client): + result = runtime.call_tool("alpha", "sum", {"x": 1, "y": 2}) + + assert result["name"] == "sum" + assert result["arguments"] == {"x": 1, "y": 2} + assert fake_client.connected + + def 
test_connect_server_closes_client_when_connect_fails(self): + failing_client = FailingConnectClient("alpha") + + with patch("pr_agent.mcp.runtime.get_settings", return_value={"MCP": {"ENABLED": True}}): + runtime = MCPRuntime(servers_config={"alpha": {"type": "http", "url": "https://alpha.example.com/mcp"}}) + + with patch.object(runtime, "_build_client", return_value=failing_client): + with pytest.raises(MCPRuntimeError, match="connect failed"): + runtime.connect_server("alpha") + + assert failing_client.closed + assert "alpha" not in runtime._clients + + def test_disconnect_all(self): + fake_clients = { + "alpha": FakeClient("alpha"), + "beta": FakeClient("beta"), + } + + with patch("pr_agent.mcp.runtime.get_settings", return_value={"MCP": {"ENABLED": True}}): + runtime = MCPRuntime( + servers_config={ + "alpha": {"type": "http", "url": "https://alpha.example.com/mcp"}, + "beta": {"type": "http", "url": "https://beta.example.com/mcp"}, + } + ) + + with patch.object(runtime, "_build_client", side_effect=lambda name, cfg: fake_clients[name]): + runtime.connect_all() + runtime.disconnect_all() + + assert not fake_clients["alpha"].connected + assert not fake_clients["beta"].connected + + def test_resolve_env_vars_disabled_preserves_placeholders(self): + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: { + "MCP": {"ENABLED": True}, + "MCP.RESOLVE_ENV_VARS": False, + }.get(key, default) + + with patch("pr_agent.mcp.runtime.get_settings", return_value=settings), patch.dict( + "os.environ", {"MCP_TEST_ENV": "expanded"}, clear=False + ): + runtime = MCPRuntime(servers_config={}) + assert runtime._resolve_config_values("$MCP_TEST_ENV/path") == "$MCP_TEST_ENV/path" + + def test_resolve_env_vars_enabled_expands_placeholders(self): + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: { + "MCP": {"ENABLED": True}, + "MCP.RESOLVE_ENV_VARS": True, + }.get(key, default) + + with patch("pr_agent.mcp.runtime.get_settings", 
return_value=settings), patch.dict( + "os.environ", {"MCP_TEST_ENV": "expanded"}, clear=False + ): + runtime = MCPRuntime(servers_config={}) + assert runtime._resolve_config_values("$MCP_TEST_ENV/path") == "expanded/path" + + @pytest.mark.asyncio + async def test_tool_executor_rejects_non_allowlisted_tool(self): + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: { + "MCP": {"ENABLED": True}, + "MCP.RESOLVE_ENV_VARS": True, + }.get(key, default) + fake_client = FakeClient("alpha") + + with patch("pr_agent.mcp.runtime.get_settings", return_value=settings): + runtime = MCPRuntime( + servers_config={"alpha": {"type": "http", "url": "https://alpha.example.com/mcp"}} + ) + + with patch.object(runtime, "_build_client", return_value=fake_client): + executor = runtime.create_tool_executor(allowed_tool_names={"alpha.allowed"}) + with pytest.raises(MCPRuntimeError, match="Tool not available"): + await executor("alpha.blocked", {}) + + +def _make_fake_response(body: dict, content_type: str = "application/json", headers: dict = None): + """Build a fake requests.Response for patching.""" + resp = MagicMock(spec=requests.Response) + resp.status_code = 200 + resp.headers = {"Content-Type": content_type, **(headers or {})} + resp.json.return_value = body + resp.raise_for_status.return_value = None + return resp + + +def _make_sse_response(events: list[dict], extra_headers: dict = None): + """Build a fake SSE requests.Response.""" + lines = [] + for event in events: + lines.append(f"data: {json.dumps(event)}") + lines.append("") # SSE event separator + + resp = MagicMock(spec=requests.Response) + resp.status_code = 200 + resp.headers = {"Content-Type": "text/event-stream", **(extra_headers or {})} + resp.raise_for_status.return_value = None + resp.iter_lines.return_value = iter(lines) + return resp + + +class TestMCPStreamableHttpClient: + def _client(self, url="https://mcp.example.com/mcp", **extra): + return MCPStreamableHttpClient("TestServer", 
{"url": url, **extra}) + + def test_missing_url_raises(self): + with pytest.raises(MCPRuntimeError, match="missing 'url'"): + MCPStreamableHttpClient("TestServer", {}) + + def test_connect_plain_json_response(self): + client = self._client() + init_resp = _make_fake_response( + {"jsonrpc": "2.0", "id": 1, "result": {"capabilities": {"tools": {}}}}, + ) + notif_resp = _make_fake_response({}) + + with patch.object(client._session, "post", side_effect=[init_resp, notif_resp]) as mock_post: + client.connect() + + assert client.server_capabilities == {"tools": {}} + # initialize + notifications/initialized + assert mock_post.call_count == 2 + # Accept header must negotiate both formats + session_headers = dict(client._session.headers) + assert "text/event-stream" in session_headers.get("Accept", "") + + def test_connect_captures_session_id(self): + client = self._client() + init_resp = _make_fake_response( + {"jsonrpc": "2.0", "id": 1, "result": {"capabilities": {}}}, + headers={"Mcp-Session-Id": "abc123"}, + ) + notif_resp = _make_fake_response({}) + + with patch.object(client._session, "post", side_effect=[init_resp, notif_resp]): + client.connect() + + assert client._session_id == "abc123" + + def test_send_request_includes_session_id_header(self): + client = self._client() + client._session_id = "sess-42" + + resp = _make_fake_response({"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}) + with patch.object(client._session, "post", return_value=resp) as mock_post: + client._send_request("tools/list", {}) + + extra_headers = mock_post.call_args.kwargs.get("headers", {}) + assert extra_headers.get("Mcp-Session-Id") == "sess-42" + + def test_list_tools_plain_json(self): + client = self._client() + tools_resp = _make_fake_response( + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + {"name": "current_time", "description": "Get current time", "inputSchema": {"type": "object"}} + ] + }, + } + ) + with patch.object(client._session, "post", 
return_value=tools_resp): + tools = client.list_tools() + + assert len(tools) == 1 + assert tools[0].name == "current_time" + assert tools[0].server_name == "TestServer" + + def test_list_tools_sse_response(self): + client = self._client() + sse_resp = _make_sse_response( + [ + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [{"name": "get_time", "description": "Time", "inputSchema": {"type": "object"}}] + }, + } + ] + ) + with patch.object(client._session, "post", return_value=sse_resp): + tools = client.list_tools() + + assert len(tools) == 1 + assert tools[0].name == "get_time" + + def test_sse_stream_with_notification_before_result(self): + """Notifications (no 'id') in the SSE stream must be skipped.""" + client = self._client() + notification = {"jsonrpc": "2.0", "method": "notifications/progress", "params": {}} + result_event = {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}} + sse_resp = _make_sse_response([notification, result_event]) + + with patch.object(client._session, "post", return_value=sse_resp): + result = client._send_request("tools/list", {}) + + assert result == {"tools": []} + + def test_sse_stream_exhausted_without_match_raises(self): + client = self._client() + # Only a notification, never a matching id + sse_resp = _make_sse_response([{"jsonrpc": "2.0", "method": "ping", "params": {}}]) + + with patch.object(client._session, "post", return_value=sse_resp): + with pytest.raises(MCPRuntimeError, match="SSE stream ended without a matching response"): + client._send_request("tools/list", {}) + + def test_server_error_in_sse_raises(self): + client = self._client() + sse_resp = _make_sse_response( + [{"jsonrpc": "2.0", "id": 1, "error": {"code": -32600, "message": "bad request"}}] + ) + with patch.object(client._session, "post", return_value=sse_resp): + with pytest.raises(MCPRuntimeError, match="bad request"): + client._send_request("tools/list", {}) + + def test_http_error_raises(self): + client = self._client() + bad_resp = 
MagicMock(spec=requests.Response) + bad_resp.raise_for_status.side_effect = requests.RequestException("406 Not Acceptable") + + with patch.object(client._session, "post", return_value=bad_resp): + with pytest.raises(MCPRuntimeError, match="406 Not Acceptable"): + client._send_request("initialize", {}) + + def test_close_clears_session_id(self): + client = self._client() + client._session_id = "will-be-cleared" + with patch.object(client._session, "close"): + client.close() + assert client._session_id is None + + +class TestBuildClientStreamableHttp: + def test_build_client_streamable_http_type(self): + runtime = MCPRuntime(servers_config={}) + client = runtime._build_client( + "Sourcebot", {"type": "streamable_http", "url": "http://sourcebot.example.com/api/mcp"} + ) + assert isinstance(client, MCPStreamableHttpClient) + + +class TestMCPStdioClientConfigValidation: + def test_invalid_timeout_raises_mcp_runtime_error(self): + with pytest.raises(MCPRuntimeError, match="timeout must be a number"): + MCPStdioClient("TestServer", {"command": "echo", "timeout": "slow"}) + + def test_invalid_args_element_raises_mcp_runtime_error(self): + client = MCPStdioClient("TestServer", {"command": "echo", "timeout": 30}) + + with pytest.raises(MCPRuntimeError, match="args must contain only strings or path-like values"): + client._normalize_args(["ok", 123]) diff --git a/tests/unittest/test_mcp_tool_discovery.py b/tests/unittest/test_mcp_tool_discovery.py new file mode 100644 index 0000000000..e4493d1b4a --- /dev/null +++ b/tests/unittest/test_mcp_tool_discovery.py @@ -0,0 +1,38 @@ +from unittest.mock import patch + +from pr_agent.mcp.runtime import MCPRuntime, MCPToolDefinition + + +class TestMCPToolDiscovery: + def test_build_tool_schemas_filters_by_server_and_budget(self): + runtime = MCPRuntime(servers_config={"alpha": {}, "beta": {}}) + + tools = [ + MCPToolDefinition("alpha", "tool_a", "desc a", {"type": "object"}), + MCPToolDefinition("beta", "tool_b", "desc b", {"type": 
"object"}), + MCPToolDefinition("beta", "tool_c", "desc c", {"type": "object"}), + ] + + with patch("pr_agent.mcp.runtime.get_settings", return_value={"MCP": {"ENABLED": True}}): + runtime = MCPRuntime(servers_config={"alpha": {}, "beta": {}}) + + with patch.object(runtime, "list_all_tools", return_value=tools): + schemas = runtime.build_tool_schemas(server_names=["beta"], max_tools=1, include_server_prefix=True) + + assert len(schemas) == 1 + assert schemas[0]["function"]["name"] == "beta.tool_b" + + def test_build_tool_schemas_respects_character_budget(self): + with patch("pr_agent.mcp.runtime.get_settings", return_value={"MCP": {"ENABLED": True}}): + runtime = MCPRuntime(servers_config={"alpha": {}}) + + tools = [ + MCPToolDefinition("alpha", "tool_a", "x" * 50, {"type": "object"}), + MCPToolDefinition("alpha", "tool_b", "y" * 50, {"type": "object"}), + ] + + with patch.object(runtime, "list_all_tools", return_value=tools): + schemas = runtime.build_tool_schemas(max_schema_chars=250) + + assert len(schemas) == 1 + assert schemas[0]["function"]["name"] == "alpha.tool_a"