diff --git a/docs/docs/agents_tools.md b/docs/docs/agents_tools.md index dfcd56d71..2f5b31d81 100644 --- a/docs/docs/agents_tools.md +++ b/docs/docs/agents_tools.md @@ -250,6 +250,31 @@ external information retrieval. - Each MCP tool is wrapped for synchronous execution so it integrates seamlessly with the ReAct agent loop - Tool names from MCP servers are namespaced to avoid conflicts with built-in tools +### Built-in Archi MCP Server + +When `services.mcp_server.enabled: true` is set, the chat service also exposes +its own MCP server at `/mcp/sse`. This lets IDEs and MCP clients connect +directly to an Archi deployment and use Archi-native tools over SSE. + +The built-in Archi MCP server currently exposes these read-only tools: + +- `archi_query` — ask the deployment a question through the normal RAG/chat pipeline +- `archi_list_documents` — page through indexed documents with source, status, and enabled state +- `archi_search_document_metadata` — search by metadata fields such as `source_type`, `ticket_id`, `url`, or `relative_path` +- `archi_list_metadata_schema` — inspect the metadata keys and common values supported by metadata search +- `archi_search_document_content` — grep-like exact or regex search over indexed document contents +- `archi_get_document_content` — fetch the raw text content for a document by hash +- `archi_get_document_chunks` — inspect stored chunk boundaries and chunk text for a document +- `archi_get_data_stats` — view corpus-level document, chunk, source, and ingestion statistics +- `archi_get_deployment_info` — inspect active model, retrieval settings, embedding config, and MCP runtime info +- `archi_list_agents` — list available agent specs and their configured tools +- `archi_get_agent_spec` — fetch the full markdown agent spec for a named agent +- `archi_health` — basic deployment/database health check + +These tools are especially useful from VS Code, Cursor, Claude Desktop, and +Claude Code when you want direct access to Archi's 
indexed corpus without +having to proxy through a separate MCP server. + --- ## Vector Store & Retrieval diff --git a/docs/docs/services.md b/docs/docs/services.md index 70e43e9f2..61e192bf6 100644 --- a/docs/docs/services.md +++ b/docs/docs/services.md @@ -25,6 +25,7 @@ The primary user-facing service. Provides a web-based chat application for inter - Streaming responses with tool-call visualization - Agent selector dropdown for switching between agents - Built-in [Data Viewer](data_sources.md#data-viewer) at `/data` +- Optional built-in MCP server at `/mcp/sse` for IDE and agent integrations - Settings panel for model/provider selection - [BYOK](models_providers.md#bring-your-own-key-byok) support - Conversation history @@ -51,6 +52,24 @@ services: archi create [...] --services chatbot ``` +### Built-in MCP Server + +The chat service can expose Archi itself as an MCP server over Server-Sent +Events. Enable it when you want tools like VS Code, Cursor, Claude Desktop, or +Claude Code to connect directly to your deployment. + +```yaml +services: + mcp_server: + enabled: true + url: "https://chat.example.org" +``` + +- **Endpoint:** `/mcp/sse` +- **Auth page:** `/mcp/auth` for generating bearer tokens when auth is enabled +- **Tools exposed:** query, document discovery, metadata search, content grep, + chunk inspection, corpus stats, deployment info, and agent-spec inspection + --- ## Service Status Board & Alert Banners @@ -279,31 +298,194 @@ archi create [...] --services chatbot,redmine-mailer ## Mattermost Interface -Reads posts from a Mattermost forum and posts draft responses to a specified channel. +Connects Archi to a Mattermost channel. 
Supports two operating modes: -### Configuration +- **Webhook mode** — Mattermost pushes outgoing webhooks to Archi (recommended) +- **Polling mode** — Archi polls a channel periodically via the Mattermost API + +**Default port:** `5000` + +### Setup + +#### Secrets + +```bash +# Required for webhook mode +MATTERMOST_WEBHOOK=https://mattermost.example.com/hooks/... # Incoming webhook URL +MATTERMOST_OUTGOING_TOKEN=... # Outgoing webhook token for request validation + +# Required for polling mode only +MATTERMOST_PAK=... # Personal Access Token for the bot account +MATTERMOST_CHANNEL_ID_READ=... # Channel to read posts from +MATTERMOST_CHANNEL_ID_WRITE=... # Channel to post responses to + +# Required for SSO auth (db mode) +SSO_CLIENT_ID=... +SSO_CLIENT_SECRET=... +BYOK_ENCRYPTION_KEY=... # Used to encrypt stored refresh tokens +PG_PASSWORD=... +``` + +#### Basic Configuration ```yaml services: mattermost: - update_time: 60 + update_time: 60 # polling interval in seconds (polling mode only) + port: 5000 + external_port: 5000 ``` -### Secrets +#### Running ```bash -MATTERMOST_WEBHOOK=... -MATTERMOST_PAK=... -MATTERMOST_CHANNEL_ID_READ=... -MATTERMOST_CHANNEL_ID_WRITE=... +archi create [...] --services chatbot,mattermost ``` -### Running +--- -```bash -archi create [...] --services chatbot,mattermost +### Authentication + +By default auth is disabled and the bot responds to all users. Two auth modes are available. + +#### Mode 1: Config (Static Allowlist) + +Roles are assigned to Mattermost users via a static map in the config. No SSO or database required. + +```yaml +services: + mattermost: + auth: + enabled: true + token_store: config + default_role: mattermost-restricted # role for users not in user_roles + user_roles: + jsmith: [archi-expert] # Mattermost username → list of roles + ahmedmu: [archi-admins] + someuser: [archi-expert, base-user] +``` + +- Users in `user_roles` get the specified roles. +- Users not in `user_roles` get `default_role`. 
+- If `default_role` is not defined in `auth_roles`, those users have no permissions and are denied.
+
+#### Mode 2: DB / SSO (Recommended)
+
+Roles come from the CERN SSO JWT token. On first message, the bot sends the user a login link. After authenticating, their roles are stored in the database and reused on subsequent messages — no re-login required until the session expires.
+
+```yaml
+services:
+  mattermost:
+    auth:
+      enabled: true
+      token_store: db
+      session_lifetime_days: 30   # full re-login required after this period
+      roles_refresh_hours: 24     # silent background role refresh interval
+      login_base_url: "https://your-mattermost-service-host:5000"
+      sso:
+        server_metadata_url: "https://auth.cern.ch/auth/realms/cern/.well-known/openid-configuration"
+        token_endpoint: "https://auth.cern.ch/auth/realms/cern/protocol/openid-connect/token"
+```
+
+**SSO registration requirement:** The callback URL `/mattermost-auth/callback` must be registered as a valid redirect URI in your SSO client (Keycloak / CERN Auth).
+
+**Login flow:**
+
+```
+1. User sends message to bot (no token stored)
+2. Bot replies: "Please login: https://<host>:5000/mattermost-auth?state=<state>&username=<username>"
+3. User clicks link → redirected to CERN SSO
+4. After SSO login → redirected to /mattermost-auth/callback
+5. Roles extracted from JWT, stored in mattermost_tokens table
+6. User sees success page, closes tab, returns to Mattermost
+7. Future messages use stored roles (silent refresh every 24h)
+```
+
+**Session lifecycle:**
+
+| Event | Behaviour |
+|-------|-----------|
+| First message | Login link sent |
+| Token valid, roles fresh | Respond normally |
+| Roles stale (`> roles_refresh_hours`) | Silent refresh via stored refresh token |
+| Session expired (`> session_lifetime_days`) | Login link sent again |
+| Admin invalidates token | Login link sent on next message |
+
+---
+
+### Role-Based Access Control
+
+Mattermost auth integrates with the same RBAC system used by the chat app. 
Roles are defined under `services.chat_app.auth.auth_roles`. + +#### Restricting Access + +To allow only users with a specific role (e.g. `archi-expert` and above), add the `mattermost:access` permission to those roles and **not** to `base-user`: + +```yaml +services: + chat_app: + auth: + auth_roles: + roles: + base-user: + permissions: + - chat:query + - chat:history + # no mattermost:access here + + archi-expert: + inherits: [base-user] + permissions: + - mattermost:access # grants access to the Mattermost bot + - documents:view + - config:view + # ... + + archi-admins: + permissions: + - "*" # wildcard includes mattermost:access + + permissions: + mattermost:access: + description: "Access the Mattermost bot" + category: "mattermost" ``` +- `base-user` only → denied with "you don't have permission" message +- `archi-expert` → allowed (has `mattermost:access`) +- `archi-admins` → allowed (wildcard) + +#### Tool-Level Permissions + +Tool permissions work the same as in the chat app. Add permissions like `tools:http_get` to roles that should be able to use specific agent tools. The Mattermost user context is propagated through the full call stack so tool checks apply correctly. + +```yaml + archi-expert: + permissions: + - mattermost:access + - tools:http_get # allow HTTP GET tool for this role +``` + +#### Database + +A `mattermost_tokens` table is required when using `token_store: db`. It is created automatically by `init.sql` on first deploy. 
For existing deployments, run the migration manually: + +```sql +CREATE TABLE IF NOT EXISTS mattermost_tokens ( + mattermost_user_id VARCHAR(255) PRIMARY KEY, + mattermost_username VARCHAR(255), + email VARCHAR(255), + roles JSONB NOT NULL DEFAULT '[]', + refresh_token BYTEA, + token_expires_at TIMESTAMPTZ, + roles_refreshed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +``` + +Refresh tokens are encrypted at rest using `pgp_sym_encrypt` (requires `BYOK_ENCRYPTION_KEY`). + --- ## Grafana Monitoring diff --git a/src/archi/pipelines/agents/base_react.py b/src/archi/pipelines/agents/base_react.py index 43e08f604..e1feaf143 100644 --- a/src/archi/pipelines/agents/base_react.py +++ b/src/archi/pipelines/agents/base_react.py @@ -998,20 +998,39 @@ def refresh_agent( extra_tools: Optional[Sequence[Callable]] = None, middleware: Optional[Sequence[Callable]] = None, force: bool = False, + user_id: Optional[str] = None, ) -> CompiledStateGraph: """Ensure the LangGraph agent reflects the latest tool set.""" base_tools = list(static_tools) if static_tools is not None else self.tools toolset: List[Callable] = list(base_tools) if "mcp" in self.selected_tool_names: - if self._mcp_tools is None: - built = self._build_mcp_tools() - self._mcp_tools = list(built or []) - toolset.extend(self._mcp_tools) + # When user_id is present, always rebuild so each request fetches a + # fresh (possibly refreshed) token from the DB for SSO-auth servers. + # Without a user_id (anonymous), cache the tools as before. + if self._mcp_tools is None or user_id: + built = self._build_mcp_tools(user_id=user_id) + if not user_id: + self._mcp_tools = list(built or []) + toolset.extend(built or []) + else: + toolset.extend(self._mcp_tools) if extra_tools: toolset.extend(extra_tools) + # OpenAI enforces a hard 128-tool limit per request. 
+ _OPENAI_MAX_TOOLS = 128 + if len(toolset) > _OPENAI_MAX_TOOLS: + logger.warning( + f"Toolset has {len(toolset)} tools, exceeding OpenAI max of {_OPENAI_MAX_TOOLS}. " + f"Truncating MCP tools to fit. Static tools ({len(base_tools)}) are preserved." + ) + # Keep all static/extra tools; trim only the MCP portion + n_static = len(base_tools) + (len(list(extra_tools)) if extra_tools else 0) + mcp_budget = max(0, _OPENAI_MAX_TOOLS - n_static) + toolset = toolset[:n_static] + toolset[n_static:n_static + mcp_budget] + middleware = list(middleware) if middleware is not None else self.middleware requires_refresh = ( @@ -1057,14 +1076,14 @@ def _build_static_tools(self) -> List[Callable]: static_names = [name for name in selected if name != "mcp"] return self._select_tools_from_registry(static_names) - def _build_mcp_tools(self) -> List[Callable]: + def _build_mcp_tools(self, user_id: Optional[str] = None) -> List[Callable]: """Retrieve MCP tools from servers defined in the config and keep those server connections alive""" try: self._async_runner = AsyncLoopThread.get_instance() # Initialize MCP client on the background loop # The client and sessions will live on this loop - client, mcp_tools = self._async_runner.run(initialize_mcp_client()) + client, mcp_tools = self._async_runner.run(initialize_mcp_client(user_id=user_id)) if client is None: logger.info("No MCP servers configured.") return None @@ -1153,7 +1172,8 @@ def _prepare_agent_inputs(self, **kwargs) -> Dict[str, Any]: if hasattr(self, "_vector_tools"): extra_tools = self._vector_tools if self._vector_tools else None # type: ignore[attr-defined] - self.refresh_agent(extra_tools=extra_tools) + user_id = kwargs.get("user_id") + self.refresh_agent(extra_tools=extra_tools, user_id=user_id) inputs = self._prepare_inputs(history=kwargs.get("history")) history_messages = inputs["history"] diff --git a/src/archi/pipelines/agents/tools/base.py b/src/archi/pipelines/agents/tools/base.py index 003bf73e9..920f063ea 100644 --- 
a/src/archi/pipelines/agents/tools/base.py +++ b/src/archi/pipelines/agents/tools/base.py @@ -35,7 +35,30 @@ def check_tool_permission(required_permission: str) -> tuple[bool, Optional[str] try: from flask import session, has_request_context from src.utils.rbac.registry import get_registry - + + # Check Mattermost context first — covers webhook mode (Flask context, no session) + # and polling mode (no Flask context). ContextVar is set by mattermost_user_context(). + try: + from src.utils.rbac.mattermost_context import get_mattermost_context + mm_ctx = get_mattermost_context() + if mm_ctx is not None: + registry = get_registry() + if registry.has_permission(mm_ctx.roles, required_permission): + logger.debug( + f"Mattermost user @{mm_ctx.username} granted '{required_permission}'" + ) + return True, None + logger.info( + f"Mattermost user @{mm_ctx.username} denied '{required_permission}' " + f"(roles: {mm_ctx.roles})" + ) + return False, ( + f"Permission denied for @{mm_ctx.username}: " + f"requires '{required_permission}'." 
+ ) + except Exception as mm_exc: + logger.debug(f"Mattermost context check skipped: {mm_exc}") + # If we're not in a request context, allow the tool (for testing/CLI usage) if not has_request_context(): logger.debug("No request context, allowing tool access") diff --git a/src/archi/pipelines/agents/tools/local_files.py b/src/archi/pipelines/agents/tools/local_files.py index 9190bb5e2..2c30cb8f6 100644 --- a/src/archi/pipelines/agents/tools/local_files.py +++ b/src/archi/pipelines/agents/tools/local_files.py @@ -2,6 +2,7 @@ import os import re +import time from dataclasses import dataclass, field from pathlib import Path from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple @@ -31,6 +32,8 @@ def __init__( port: int = 7871, external_port: Optional[int] = None, timeout: float = 30.0, + retry_attempts: int = 3, + retry_backoff_seconds: float = 1.0, api_token: Optional[str] = None, ): host_mode_flag = self._resolve_host_mode(host_mode) @@ -42,10 +45,57 @@ def __init__( final_port = external_port if host_mode_flag and external_port else port self.base_url = f"http://{host}:{final_port}" self.timeout = timeout + self.retry_attempts = max(int(retry_attempts), 1) + self.retry_backoff_seconds = max(float(retry_backoff_seconds), 0.0) self._headers: Dict[str, str] = {} if api_token: self._headers["Authorization"] = f"Bearer {api_token}" + def _get(self, path: str, *, params: Optional[Dict[str, object]] = None) -> requests.Response: + last_exc: Optional[Exception] = None + for attempt in range(1, self.retry_attempts + 1): + try: + resp = requests.get( + f"{self.base_url}{path}", + params=params, + headers=self._headers, + timeout=self.timeout, + ) + except (requests.Timeout, requests.ConnectionError) as exc: + last_exc = exc + if attempt >= self.retry_attempts: + raise + sleep_s = self.retry_backoff_seconds * (2 ** (attempt - 1)) + logger.warning( + "Catalog request failed (%s/%s) %s: %s; retrying in %.2fs", + attempt, + self.retry_attempts, + path, + 
exc, + sleep_s, + ) + time.sleep(sleep_s) + continue + + if resp.status_code >= 500 and attempt < self.retry_attempts: + sleep_s = self.retry_backoff_seconds * (2 ** (attempt - 1)) + logger.warning( + "Catalog request got HTTP %s (%s/%s) %s; retrying in %.2fs", + resp.status_code, + attempt, + self.retry_attempts, + path, + sleep_s, + ) + time.sleep(sleep_s) + continue + + return resp + + if last_exc is not None: + raise last_exc + raise RuntimeError(f"Catalog request exhausted retries for {path}") + @classmethod def from_deployment_config(cls, config: Optional[Dict[str, object]]) -> "RemoteCatalogClient": """Create a client using the standard archi deployment config structure.""" @@ -61,6 +111,9 @@ def from_deployment_config(cls, config: Optional[Dict[str, object]]) -> "RemoteC hostname=data_manager_cfg.get("hostname") or data_manager_cfg.get("host"), port=data_manager_cfg.get("port", 7871), external_port=data_manager_cfg.get("external_port"), + timeout=float(data_manager_cfg.get("catalog_timeout_seconds", 60.0)), + retry_attempts=int(data_manager_cfg.get("catalog_retry_attempts", 3)), + retry_backoff_seconds=float(data_manager_cfg.get("catalog_retry_backoff_seconds", 0.25)), api_token=api_token, ) @@ -103,22 +156,15 @@ def search( params["after"] = after if max_matches_per_file is not None: params["max_matches_per_file"] = max_matches_per_file - resp = requests.get( - f"{self.base_url}/api/catalog/search", - params=params, - headers=self._headers, - timeout=self.timeout, - ) + resp = self._get("/api/catalog/search", params=params) resp.raise_for_status() data = resp.json() return data.get("hits", []) or [] def get_document(self, resource_hash: str, *, max_chars: int = 4000) -> Optional[Dict[str, object]]: - resp = requests.get( - f"{self.base_url}/api/catalog/document/{resource_hash}", + resp = self._get( + f"/api/catalog/document/{resource_hash}", params={"max_chars": max_chars}, - headers=self._headers, - timeout=self.timeout, ) if resp.status_code == 404: 
return None @@ -126,11 +172,7 @@ def get_document(self, resource_hash: str, *, max_chars: int = 4000) -> Optional return resp.json() def schema(self) -> Dict[str, object]: - resp = requests.get( - f"{self.base_url}/api/catalog/schema", - headers=self._headers, - timeout=self.timeout, - ) + resp = self._get("/api/catalog/schema") resp.raise_for_status() return resp.json() diff --git a/src/archi/pipelines/agents/tools/mcp.py b/src/archi/pipelines/agents/tools/mcp.py index 46b034cd8..537ce9f76 100644 --- a/src/archi/pipelines/agents/tools/mcp.py +++ b/src/archi/pipelines/agents/tools/mcp.py @@ -1,6 +1,8 @@ from __future__ import annotations +import os from typing import List, Any, Tuple, Optional +import httpx from langchain_mcp_adapters.client import MultiServerMCPClient from langchain_mcp_adapters.tools import load_mcp_tools from langchain.tools import BaseTool @@ -10,22 +12,71 @@ logger = get_logger(__name__) -async def initialize_mcp_client() -> Tuple[Optional[MultiServerMCPClient], List[BaseTool]]: +_CERN_CA_BUNDLE = "/etc/ssl/certs/tls-ca-bundle.pem" + + +def _make_httpx_factory(ca_bundle: str): + """Return an httpx_client_factory that uses the given CA bundle for SSL verification.""" + def factory( + headers: dict | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient( + headers=headers or {}, + timeout=timeout, + auth=auth, + verify=ca_bundle, + follow_redirects=True, + ) + return factory + + +async def initialize_mcp_client(user_id: Optional[str] = None) -> Tuple[Optional[MultiServerMCPClient], List[BaseTool]]: """ Initializes the MCP client and fetches tool definitions. + Args: + user_id: SSO user ID used to look up a valid MCP OAuth token from the DB + for servers configured with sso_auth: true. Returns: client: The active client instance (must be kept alive by the caller). tools: The list of LangChain-compatible tools. 
""" + from src.utils.mcp_oauth_service import MCPOAuthService mcp_servers = get_mcp_servers_config() - logger.info(f"Configuring MCP client with servers: {list(mcp_servers.keys())}") - client = MultiServerMCPClient(mcp_servers) + _mcp_oauth = MCPOAuthService() + + _use_cern_ca = os.path.exists(_CERN_CA_BUNDLE) + if _use_cern_ca: + logger.info(f"Using CERN CA bundle for MCP SSL verification: {_CERN_CA_BUNDLE}") + + # Resolve per-server config, injecting Bearer auth where sso_auth is enabled. + # Skip SSO-auth servers when no valid MCP OAuth token is available. + resolved_servers = {} + for name, cfg in mcp_servers.items(): + server_cfg = dict(cfg) + requires_sso = server_cfg.pop('sso_auth', False) + if requires_sso: + access_token = _mcp_oauth.get_access_token(user_id, name) if user_id else None + if not access_token: + logger.info(f"Skipping MCP server '{name}': sso_auth=true but no valid token for user_id={user_id!r}") + continue + server_cfg.setdefault('headers', {})['Authorization'] = f'Bearer {access_token}' + + # Inject CERN CA bundle via httpx_client_factory (SSE/streamable_http transports) + if _use_cern_ca and server_cfg.get('transport') in ('sse', 'streamable_http'): + server_cfg['httpx_client_factory'] = _make_httpx_factory(_CERN_CA_BUNDLE) + + resolved_servers[name] = server_cfg + + logger.info(f"Configuring MCP client with servers: {list(resolved_servers.keys())}") + client = MultiServerMCPClient(resolved_servers) all_tools: List[BaseTool] = [] failed_servers: dict[str, str] = {} - for name in mcp_servers.keys(): + for name in resolved_servers.keys(): try: tools = await client.get_tools(server_name=name) for tool in tools: @@ -35,7 +86,7 @@ async def initialize_mcp_client() -> Tuple[Optional[MultiServerMCPClient], List[ logger.error(f"Failed to fetch tools from MCP server '{name}': {e}") failed_servers[name] = str(e) - logger.info(f"Active MCP servers: {[n for n in mcp_servers if n not in failed_servers]}") + logger.info(f"Active MCP servers: {[n for n 
in resolved_servers if n not in failed_servers]}") logger.warning(f"Failed MCP servers: {list(failed_servers.keys())}") return client, all_tools diff --git a/src/bin/service_mattermost.py b/src/bin/service_mattermost.py index db776ea12..6382feb2e 100755 --- a/src/bin/service_mattermost.py +++ b/src/bin/service_mattermost.py @@ -2,14 +2,21 @@ import multiprocessing as mp import os import time +from threading import Thread from src.interfaces import mattermost from src.utils.env import read_secret from src.utils.logging import setup_logging +from src.utils.postgres_service_factory import PostgresServiceFactory # set basicConfig for logging setup_logging() +def run_polling(mattermost_agent, update_time): + while True: + mattermost_agent.process_posts() + time.sleep(update_time) + def main(): # set openai os.environ['OPENAI_API_KEY'] = read_secret("OPENAI_API_KEY") @@ -18,13 +25,26 @@ def main(): time.sleep(30) # temporary hack to prevent mattermost from starting at the same time as other services; eventually replace this with more robust solution - print("Initializing Mattermost Service") - mattermost_agent = mattermost.Mattermost() - update_time = int(mattermost_agent.mattermost_config["update_time"]) + # Initialize Postgres config service (required before any get_full_config() call) + factory = PostgresServiceFactory.from_env(password_override=read_secret("PG_PASSWORD")) + PostgresServiceFactory.set_instance(factory) - while True: - mattermost_agent.process_posts() - time.sleep(update_time) + # Start webhook server first — its __init__ initializes the config service via MattermostAIWrapper + print("Initializing Mattermost webhook server") + webhook_server = mattermost.MattermostWebhookServer() + + # Start polling loop in background thread if PAK is available (config service now ready) + pak = read_secret("MATTERMOST_PAK") + if pak: + print("Initializing Mattermost polling service") + mattermost_agent = mattermost.Mattermost() + update_time = 
int(mattermost_agent.mattermost_config.get("update_time", 60)) + polling_thread = Thread(target=run_polling, args=(mattermost_agent, update_time), daemon=True) + polling_thread.start() + else: + print("MATTERMOST_PAK not set — skipping polling mode") + + webhook_server.run(host='0.0.0.0', port=webhook_server.port) if __name__ == "__main__": mp.set_start_method("spawn", force=True) diff --git a/src/cli/managers/templates_manager.py b/src/cli/managers/templates_manager.py index cf40c4293..60bef3c24 100644 --- a/src/cli/managers/templates_manager.py +++ b/src/cli/managers/templates_manager.py @@ -478,6 +478,11 @@ def _render_compose_file(self, context: TemplateContext) -> None: template_vars.setdefault("prompt_files", []) template_vars.setdefault("rubrics", []) + # SSL cert file for HTTPS verification (e.g. CERN CA bundle) + chat_config = context.config_manager.config.get("services", {}).get("chat_app", {}) + template_vars.setdefault("ssl_cert_host", chat_config.get("ssl_cert_host", "")) + template_vars.setdefault("ssl_cert_file", chat_config.get("ssl_cert_file", "")) + if context.plan.get_service("grader").enabled: template_vars["rubrics"] = self._get_grader_rubrics(context.config_manager) diff --git a/src/cli/service_registry.py b/src/cli/service_registry.py index 2f2e4718a..ed43b2f60 100644 --- a/src/cli/service_registry.py +++ b/src/cli/service_registry.py @@ -136,8 +136,11 @@ def _register_default_services(self): name='mattermost', description='Integration service for Mattermost channels', category='integration', - required_secrets=['MATTERMOST_WEBHOOK', 'MATTERMOST_CHANNEL_ID_READ', - 'MATTERMOST_CHANNEL_ID_WRITE', 'MATTERMOST_PAK'] + requires_volume=True, + required_secrets=['MATTERMOST_WEBHOOK', + # 'MATTERMOST_CHANNEL_ID_READ', + # 'MATTERMOST_CHANNEL_ID_WRITE', 'MATTERMOST_PAK' + 'MATTERMOST_OUTGOING_TOKEN'] )) self.register(ServiceDefinition( diff --git a/src/cli/templates/base-compose.yaml b/src/cli/templates/base-compose.yaml index 65244fa93..b793752b8 
100644 --- a/src/cli/templates/base-compose.yaml +++ b/src/cli/templates/base-compose.yaml @@ -141,6 +141,9 @@ services: NVIDIA_VISIBLE_DEVICES: all NVIDIA_DRIVER_CAPABILITIES: compute,utility,graphics {% endif %} + {% if ssl_cert_file -%} + SSL_CERT_FILE: {{ ssl_cert_file }} + {%- endif %} env_file: - .env {% if gpu_ids and not use_podman -%} @@ -168,6 +171,9 @@ services: {% if gpu_ids -%} - archi-models:/root/models/ {%- endif %} + {% if ssl_cert_host and ssl_cert_file -%} + - {{ ssl_cert_host }}:{{ ssl_cert_file }}:ro + {%- endif %} logging: options: max-size: 10m @@ -386,6 +392,7 @@ services: PGPORT: {{ postgres_port }} PGDATABASE: archi-db PGUSER: archi + PG_PASSWORD: ${PG_PASSWORD} {% for secret in required_secrets | default([]) -%} {{ secret.upper() }}_FILE: /run/secrets/{{ secret.lower() }} {% endfor %} @@ -394,6 +401,9 @@ services: NVIDIA_DRIVER_CAPABILITIES: compute,utility,graphics {% endif %} VERBOSITY: {{ verbosity }} + {% if ssl_cert_file -%} + SSL_CERT_FILE: {{ ssl_cert_file }} + {%- endif %} env_file: - .env {% if gpu_ids and not use_podman -%} @@ -409,8 +419,10 @@ services: {% for secret in required_secrets | default([]) -%} - {{ secret.lower() }} {% endfor %} + ports: + - "{{ services.mattermost.external_port | default(5000) }}:{{ services.mattermost.port | default(5000) }}" volumes: - - {{ data_volume_name }}:/root/data/ + - {{ data_manager_volume_name }}:/root/data/ - ./configs:/root/archi/configs - ./data/prompts:/root/archi/data/prompts:ro - ./data/agents:/root/archi/agents:ro @@ -421,6 +433,9 @@ services: {% if gpu_ids -%} - archi-models:/root/models/ {%- endif %} + {% if ssl_cert_host and ssl_cert_file -%} + - {{ ssl_cert_host }}:{{ ssl_cert_file }}:ro + {%- endif %} logging: options: max-size: 10m diff --git a/src/cli/templates/base-config.yaml b/src/cli/templates/base-config.yaml index ef6d21fee..9d8f7d7e5 100644 --- a/src/cli/templates/base-config.yaml +++ b/src/cli/templates/base-config.yaml @@ -15,6 +15,7 @@ mcp_servers: {{ server_name 
}}: transport: {{ server_config.transport | default('streamable_http', true) }} url: {{ server_config.url | default('', true) }} + sso_auth: {{ server_config.sso_auth | default(false, true) }} {%- endfor %} services: @@ -51,6 +52,31 @@ services: update_time: {{ services.piazza.update_time | default(60, true) }} mattermost: update_time: {{ services.mattermost.update_time | default(60, true) }} + port: {{ services.mattermost.port | default(5000, true) }} + external_port: {{ services.mattermost.external_port | default(5000, true) }} + base_url: "{{ services.mattermost.base_url | default('https://mattermost.web.cern.ch/', true) }}" + # Bot account Mattermost user ID — used to skip self-replies and add reactions. + # Leave blank to auto-fetch from GET /api/v4/users/me on startup (requires MATTERMOST_PAK). + bot_user_id: "{{ services.mattermost.bot_user_id | default('', true) }}" + # How many prior conversation turns (user+AI pairs) to pass to the AI pipeline per thread. + context_window: {{ services.mattermost.context_window | default(20, true) }} + # Path to JSON file used for polling-mode deduplication. + tracking_file: "{{ services.mattermost.tracking_file | default('/root/data/mattermost/answered_posts.json', true) }}" + # Emoji reactions to add while processing / on success / on error. 
+ reactions: + processing: "{{ services.mattermost.reactions.processing | default('eyes', true) }}" + done: "{{ services.mattermost.reactions.done | default('white_check_mark', true) }}" + error: "{{ services.mattermost.reactions.error | default('x', true) }}" + auth: + enabled: {{ services.mattermost.auth.enabled | default(false, true) }} + token_store: {{ services.mattermost.auth.token_store | default('config', true) }} + default_role: {{ services.mattermost.auth.default_role | default('mattermost-restricted', true) }} + session_lifetime_days: {{ services.mattermost.auth.session_lifetime_days | default(30, true) }} + roles_refresh_hours: {{ services.mattermost.auth.roles_refresh_hours | default(24, true) }} + login_base_url: {{ services.mattermost.auth.login_base_url | default('', true) }} + sso: + token_endpoint: {{ services.mattermost.auth.sso.token_endpoint | default('', true) }} + user_roles: {{ services.mattermost.auth.user_roles | default({}, true) | tojson }} redmine_mailbox: agent_class: {{ services.redmine_mailbox.agent_class | default('CMSCompOpsAgent', true) }} agents_dir: "{{ services.redmine_mailbox.agents_dir | default('', true) }}" @@ -142,6 +168,13 @@ services: grafana: port: {{ services.grafana.port | default(3000, true) }} external_port: {{ services.grafana.external_port | default(3000, true) }} + mcp_server: + enabled: {{ services.mcp_server.enabled | default(false, true) }} + # Public URL of the chat service that MCP clients will connect to. + # Defaults to the chat app's hostname and external port. + url: "{{ services.mcp_server.url | default('http://' + (services.chat_app.hostname | default('localhost', true)) + ':' + (services.chat_app.external_port | default(7861, true) | string), true) }}" + # HTTP request timeout in seconds. 
+ timeout: {{ services.mcp_server.timeout | default(120, true) }} data_manager: collection_name: {{ collection_name | default("default_collection", true) }} diff --git a/src/cli/templates/init.sql b/src/cli/templates/init.sql index 1334fc23c..3f635ffb8 100644 --- a/src/cli/templates/init.sql +++ b/src/cli/templates/init.sql @@ -74,7 +74,27 @@ CREATE INDEX IF NOT EXISTS idx_users_auth_provider ON users(auth_provider); CREATE UNIQUE INDEX IF NOT EXISTS idx_users_github_id ON users(github_id) WHERE github_id IS NOT NULL; -- ============================================================================ --- 1.1 SESSIONS +-- 1.1 MATTERMOST TOKENS +-- ============================================================================ +-- Stores SSO refresh tokens for Mattermost users, enabling role-based access +-- without requiring re-login on every message. + +CREATE TABLE IF NOT EXISTS mattermost_tokens ( + mattermost_user_id VARCHAR(255) PRIMARY KEY, + mattermost_username VARCHAR(255), + email VARCHAR(255), + roles JSONB NOT NULL DEFAULT '[]', + refresh_token BYTEA, -- pgp_sym_encrypt(token, BYOK_ENCRYPTION_KEY) + token_expires_at TIMESTAMPTZ, -- when re-login is required (configurable session lifetime) + roles_refreshed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_mm_tokens_username ON mattermost_tokens(mattermost_username); + +-- ============================================================================ +-- 1.2 SESSIONS -- ============================================================================ CREATE TABLE IF NOT EXISTS sessions ( @@ -88,6 +108,80 @@ CREATE TABLE IF NOT EXISTS sessions ( CREATE INDEX IF NOT EXISTS idx_sessions_user ON sessions(user_id); CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at); +-- ============================================================================ +-- 1.3 SSO TOKENS (for MCP Bearer 
auth)
+-- ============================================================================
+
+CREATE TABLE IF NOT EXISTS sso_tokens (
+    user_id VARCHAR(200) PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
+    access_token BYTEA,                  -- pgp_sym_encrypt(token, BYOK_ENCRYPTION_KEY)
+    refresh_token BYTEA,                 -- pgp_sym_encrypt(token, BYOK_ENCRYPTION_KEY)
+    access_token_expires_at TIMESTAMPTZ,
+    session_expires_at TIMESTAMPTZ,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+-- OAuth2 client registrations for MCP servers (one row per MCP server)
+CREATE TABLE IF NOT EXISTS mcp_oauth_clients (
+    server_name VARCHAR(200) PRIMARY KEY,
+    server_url TEXT NOT NULL,
+    client_id TEXT NOT NULL,
+    client_secret TEXT NOT NULL DEFAULT '',
+    redirect_uri TEXT NOT NULL,
+    auth_meta JSONB NOT NULL DEFAULT '{}'::jsonb,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+-- Per-user per-server MCP OAuth2 tokens
+CREATE TABLE IF NOT EXISTS mcp_oauth_tokens (
+    user_id VARCHAR(200) REFERENCES users(id) ON DELETE CASCADE,
+    server_name VARCHAR(200) NOT NULL,
+    access_token BYTEA,                  -- pgp_sym_encrypt(token, encryption_key)
+    refresh_token BYTEA,                 -- pgp_sym_encrypt(token, encryption_key)
+    access_token_expires_at TIMESTAMPTZ,
+    session_expires_at TIMESTAMPTZ,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    PRIMARY KEY (user_id, server_name)
+);
+
+-- ============================================================================
+-- 1.4 MCP API TOKENS (VS Code / Cursor integration)
+-- ============================================================================
+
+CREATE TABLE IF NOT EXISTS mcp_tokens (
+    token VARCHAR(64) PRIMARY KEY,       -- secrets.token_hex(32)
+    user_id VARCHAR(200) NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    display_name TEXT,                   -- e.g.
"VS Code – work laptop" + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + last_used_at TIMESTAMPTZ, + expires_at TIMESTAMPTZ -- NULL = never expires +); + +CREATE INDEX IF NOT EXISTS idx_mcp_tokens_user ON mcp_tokens(user_id); + +-- Short-lived authorization codes for the OAuth2 PKCE flow used by MCP clients. +CREATE TABLE IF NOT EXISTS mcp_auth_codes ( + code VARCHAR(64) PRIMARY KEY, + user_id VARCHAR(200) NOT NULL REFERENCES users(id) ON DELETE CASCADE, + code_challenge VARCHAR(128) NOT NULL, + code_challenge_method VARCHAR(10) NOT NULL DEFAULT 'S256', + redirect_uri TEXT NOT NULL, + client_id VARCHAR(100) NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '10 minutes', + used BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE INDEX IF NOT EXISTS idx_mcp_auth_codes_expires ON mcp_auth_codes(expires_at); + +-- OAuth2 dynamic client registrations (RFC 7591) used by MCP clients. +CREATE TABLE IF NOT EXISTS mcp_oauth_clients ( + client_id VARCHAR(32) PRIMARY KEY, -- secrets.token_hex(16) + client_name TEXT, + redirect_uris TEXT[] NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + -- ============================================================================ -- 2. STATIC CONFIGURATION (Deploy-Time) -- ============================================================================ @@ -355,11 +449,20 @@ CREATE TABLE IF NOT EXISTS conversation_metadata ( title TEXT, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), last_message_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - archi_version VARCHAR(50) + archi_version VARCHAR(50), + -- Cross-platform fields (added for Mattermost ↔ web-chat continuity) + archi_service TEXT NOT NULL DEFAULT 'chat', -- 'chat' | 'mattermost' + source_ref TEXT DEFAULT NULL -- e.g. 
"mm_thread_" ); -CREATE INDEX IF NOT EXISTS idx_conv_meta_user ON conversation_metadata(user_id); -CREATE INDEX IF NOT EXISTS idx_conv_meta_client ON conversation_metadata(client_id); +CREATE INDEX IF NOT EXISTS idx_conv_meta_user ON conversation_metadata(user_id); +CREATE INDEX IF NOT EXISTS idx_conv_meta_client ON conversation_metadata(client_id); +CREATE INDEX IF NOT EXISTS idx_conv_meta_source_ref ON conversation_metadata(source_ref); + +-- Live-database migrations: add columns that may be missing in existing deployments +ALTER TABLE conversation_metadata ADD COLUMN IF NOT EXISTS archi_service TEXT NOT NULL DEFAULT 'chat'; +ALTER TABLE conversation_metadata ADD COLUMN IF NOT EXISTS source_ref TEXT DEFAULT NULL; +CREATE INDEX IF NOT EXISTS idx_conv_meta_source_ref ON conversation_metadata(source_ref); -- Add FK to conversation_doc_overrides now that conversation_metadata exists DO $$ diff --git a/src/interfaces/chat_app/app.py b/src/interfaces/chat_app/app.py index 6c3e877dc..88631bd8a 100644 --- a/src/interfaces/chat_app/app.py +++ b/src/interfaces/chat_app/app.py @@ -9,7 +9,10 @@ from threading import Lock from typing import Any, Dict, Iterator, List, Optional from pathlib import Path -from urllib.parse import urlparse +import base64 +import hashlib +import secrets +from urllib.parse import urlparse, urlencode, urlunparse, parse_qs, urljoin from functools import wraps import requests @@ -60,6 +63,7 @@ SQL_GET_PENDING_AB_COMPARISON, SQL_DELETE_AB_COMPARISON, SQL_GET_AB_COMPARISONS_BY_CONVERSATION, SQL_CREATE_AGENT_TRACE, SQL_UPDATE_AGENT_TRACE, SQL_GET_AGENT_TRACE, SQL_GET_TRACE_BY_MESSAGE, SQL_GET_ACTIVE_TRACE, SQL_CANCEL_ACTIVE_TRACES, + SQL_LIST_CONVERSATIONS_ALL_SOURCES, SQL_GET_CONVERSATION_METADATA_ALL_SOURCES, ) from src.interfaces.chat_app.document_utils import * from src.interfaces.chat_app.service_alerts import ( @@ -67,6 +71,7 @@ ) from src.interfaces.chat_app.utils import collapse_assistant_sequences from src.utils.user_service import UserService 
+from src.utils.sso_token_service import SSOTokenService # RBAC imports for role-based access control from src.utils.rbac import ( @@ -1120,6 +1125,12 @@ def update_conversation_timestamp(self, conversation_id: int, client_id: str, us # update timestamp if user_id: cursor.execute(SQL_UPDATE_CONVERSATION_TIMESTAMP_BY_USER, (now, conversation_id, user_id, client_id)) + if cursor.rowcount == 0: + # Mattermost-originated conversation: client_id is "mm_user_" + username = session.get('user', {}).get('username', '') or '' + if username: + mm_client_id = f"mm_user_{username}" + cursor.execute(SQL_UPDATE_CONVERSATION_TIMESTAMP_BY_USER, (now, conversation_id, user_id, mm_client_id)) else: cursor.execute(SQL_UPDATE_CONVERSATION_TIMESTAMP, (now, conversation_id, client_id)) conn.commit() @@ -1562,7 +1573,7 @@ def __call__(self, message: List[str], conversation_id: int|None, client_id: str requested_config = self._resolve_config_name(config_name) self.update_config(config_name=requested_config) - result = self.archi(history=context.history, conversation_id=context.conversation_id) + result = self.archi(history=context.history, conversation_id=context.conversation_id, user_id=user_id) timestamps["chain_finished_ts"] = datetime.now(timezone.utc) # keep track of total number of queries and log this amount @@ -1712,7 +1723,7 @@ def _remember_tool_call(tool_call_id: str, tool_name: Any, tool_args: Any) -> No pipeline_name=self.archi.pipeline_name if hasattr(self.archi, 'pipeline_name') else None, ) - for output in self.archi.stream(history=context.history, conversation_id=context.conversation_id): + for output in self.archi.stream(history=context.history, conversation_id=context.conversation_id, user_id=user_id): if client_timeout and time.time() - stream_start_time > client_timeout: if trace_id: total_duration_ms = int((time.time() - stream_start_time) * 1000) @@ -2164,6 +2175,7 @@ def __init__(self, app, **configs): self.auth_enabled = auth_config.get('enabled', False) 
self.sso_enabled = auth_config.get('sso', {}).get('enabled', False) self.basic_auth_enabled = auth_config.get('basic', {}).get('enabled', False) + self.mcp_enabled = self.services_config.get('mcp_server', {}).get('enabled', False) logger.info(f"Auth enabled: {self.auth_enabled}, SSO: {self.sso_enabled}, Basic: {self.basic_auth_enabled}") @@ -2298,9 +2310,236 @@ def _inject_alerts(): self.add_endpoint('/api/permissions', 'get_permissions', self.get_permissions, methods=['GET']) self.add_endpoint('/api/permissions/check', 'check_permission', self.check_permission_endpoint, methods=['POST']) - if self.sso_enabled: self.add_endpoint('/redirect', 'sso_callback', self.sso_callback) + self.add_endpoint('/mcp/authorize', 'mcp_authorize', self.mcp_authorize, methods=['GET']) + self.add_endpoint('/mcp/callback', 'mcp_callback', self.mcp_callback, methods=['GET']) + + # MCP SSE endpoint – exposes archi as MCP tools over HTTP+SSE. + if self.mcp_enabled: + logger.info("MCP server enabled – registering /mcp/* endpoints") + from src.interfaces.chat_app.mcp_sse import register_mcp_sse + _mcp_auth_required = self.auth_enabled and self.sso_enabled + _mcp_public_url = self.services_config.get('mcp_server', {}).get('url', '').rstrip('/') + register_mcp_sse( + self.app, self, + pg_config=self.pg_config, + auth_enabled=_mcp_auth_required, + public_url=_mcp_public_url or None, + ) + self.add_endpoint('/.well-known/oauth-authorization-server', 'oauth_metadata', self.oauth_metadata, methods=['GET']) + self.add_endpoint('/.well-known/oauth-protected-resource', 'oauth_protected_resource', self.oauth_protected_resource, methods=['GET']) + self.add_endpoint('/mcp/oauth/register', 'oauth_register', self.oauth_register, methods=['POST']) + self.add_endpoint('/mcp/oauth/authorize', 'oauth_authorize', self.oauth_authorize, methods=['GET']) + self.add_endpoint('/mcp/oauth/token', 'oauth_token', self.oauth_token, methods=['POST']) + if self.auth_enabled and self.sso_enabled: + 
self.add_endpoint('/mcp/auth', 'mcp_auth', self.mcp_auth, methods=['GET'])
+                self.add_endpoint('/mcp/auth/regenerate', 'mcp_auth_regenerate', self.mcp_auth_regenerate, methods=['POST'])
+        else:
+            logger.info("MCP server disabled (set services.mcp_server.enabled: true to enable)")
+
+    # ------------------------------------------------------------------
+    # OAuth2 PKCE endpoints (used by MCP clients like Claude Desktop)
+    # ------------------------------------------------------------------
+
+    def _mcp_public_base_url(self) -> str:
+        """Return the public base URL for MCP endpoints."""
+        configured = self.services_config.get('mcp_server', {}).get('url', '').rstrip('/')
+        if configured:
+            return configured
+        fwd_proto = request.headers.get("X-Forwarded-Proto") or request.scheme
+        fwd_host = request.headers.get("X-Forwarded-Host") or request.host
+        return f"{fwd_proto}://{fwd_host}"
+
+    def oauth_metadata(self):
+        """GET /.well-known/oauth-authorization-server — RFC 8414 discovery doc."""
+        base = self._mcp_public_base_url()
+        return jsonify({
+            "issuer": base,
+            "authorization_endpoint": base + "/mcp/oauth/authorize",
+            "token_endpoint": base + "/mcp/oauth/token",
+            "registration_endpoint": base + "/mcp/oauth/register",
+            "response_types_supported": ["code"],
+            "grant_types_supported": ["authorization_code"],
+            "code_challenge_methods_supported": ["S256"],
+        })
+
+    def oauth_protected_resource(self):
+        """GET /.well-known/oauth-protected-resource — RFC 9728 resource metadata."""
+        base = self._mcp_public_base_url()
+        return jsonify({
+            "resource": base + "/",
+            "authorization_servers": [base],
+        })
+
+    def oauth_register(self):
+        """POST /mcp/oauth/register — RFC 7591 dynamic client registration.
+
+        MCP clients (mcp-remote, Claude Desktop, VS Code) call this once to
+        obtain a client_id before starting the authorization flow. We keep
+        it simple: any caller may register; we persist the client and return
+        a client_id immediately (no client_secret for public clients).
+ """ + body = request.get_json(silent=True) or {} + redirect_uris = body.get("redirect_uris", []) + client_name = body.get("client_name", "") + + if not redirect_uris: + return jsonify({"error": "invalid_client_metadata", + "error_description": "redirect_uris is required"}), 400 + + client_id = secrets.token_hex(16) + conn = psycopg2.connect(**self.pg_config) + try: + with conn.cursor() as cur: + cur.execute( + """INSERT INTO mcp_oauth_clients (client_id, client_name, redirect_uris) + VALUES (%s, %s, %s)""", + (client_id, client_name, redirect_uris), + ) + conn.commit() + finally: + conn.close() + + base = self._mcp_public_base_url() + return jsonify({ + "client_id": client_id, + "client_name": client_name, + "redirect_uris": redirect_uris, + "grant_types": ["authorization_code"], + "response_types": ["code"], + "token_endpoint_auth_method": "none", + "registration_client_uri": base + "/mcp/oauth/register", + }), 201 + + def oauth_authorize(self): + """GET /mcp/oauth/authorize — OAuth2 authorization code endpoint with PKCE. + + If the user is not logged in they are redirected to SSO login and + returned here afterwards via ``session['sso_next']``. Once logged in + an auth code is generated, stored in ``mcp_auth_codes``, and the + browser is redirected back to the client's ``redirect_uri``. + """ + client_id = request.args.get('client_id', '') + response_type = request.args.get('response_type', '') + code_challenge = request.args.get('code_challenge', '') + code_challenge_method = request.args.get('code_challenge_method', 'S256') + redirect_uri = request.args.get('redirect_uri', '') + state = request.args.get('state', '') + + if response_type != 'code' or not code_challenge or not redirect_uri: + return jsonify({"error": "invalid_request", + "error_description": "response_type=code, code_challenge, and redirect_uri are required"}), 400 + + if not session.get('logged_in'): + # Preserve all OAuth params so we return here after SSO login. 
session['sso_next'] = request.url
+            if self.sso_enabled and self.oauth:
+                # Reuse the existing SSO config directly — no intermediate login page.
+                sso_callback_uri = url_for('sso_callback', _external=True)
+                return self.oauth.sso.authorize_redirect(sso_callback_uri)
+            return redirect(url_for('login'))
+
+        user_id = session.get('user', {}).get('id')
+        if not user_id:
+            return jsonify({"error": "server_error", "error_description": "Could not determine user identity"}), 500
+
+        # Create a short-lived auth code.
+        code = secrets.token_hex(32)
+        conn = psycopg2.connect(**self.pg_config)
+        try:
+            with conn.cursor() as cur:
+                cur.execute(
+                    """INSERT INTO mcp_auth_codes
+                       (code, user_id, code_challenge, code_challenge_method, redirect_uri, client_id)
+                       VALUES (%s, %s, %s, %s, %s, %s)""",
+                    (code, user_id, code_challenge, code_challenge_method, redirect_uri, client_id),
+                )
+            conn.commit()
+        finally:
+            conn.close()
+
+        # Safely append code (and optional state) to the redirect_uri.
+        parsed = urlparse(redirect_uri)
+        params = {"code": code}
+        if state:
+            params["state"] = state
+        new_query = urlencode(params) if not parsed.query else parsed.query + '&' + urlencode(params)
+        return redirect(urlunparse(parsed._replace(query=new_query)))
+
+    def oauth_token(self):
+        """POST /mcp/oauth/token — OAuth2 token exchange with PKCE verification."""
+        grant_type = request.form.get('grant_type', '')
+        code = request.form.get('code', '')
+        code_verifier = request.form.get('code_verifier', '')
+        redirect_uri = request.form.get('redirect_uri', '')
+
+        if grant_type != 'authorization_code' or not code or not code_verifier:
+            return jsonify({"error": "invalid_request",
+                            "error_description": "grant_type=authorization_code, code, and code_verifier are required"}), 400
+
+        # Verify PKCE before touching the DB.
+ digest = hashlib.sha256(code_verifier.encode()).digest() + computed_challenge = base64.urlsafe_b64encode(digest).rstrip(b'=').decode() + + conn = psycopg2.connect(**self.pg_config) + try: + with conn.cursor() as cur: + # Atomically mark the code as used and return its fields. + # This prevents replay attacks without a separate SELECT + UPDATE. + cur.execute( + """UPDATE mcp_auth_codes + SET used = TRUE + WHERE code = %s + AND used = FALSE + AND expires_at > NOW() + RETURNING user_id, code_challenge, redirect_uri""", + (code,), + ) + row = cur.fetchone() + if not row: + return jsonify({"error": "invalid_grant", + "error_description": "Authorization code is invalid, expired, or already used"}), 400 + + user_id, stored_challenge, stored_redirect = row + + # Validate redirect_uri matches what was used in /authorize. + if redirect_uri and redirect_uri != stored_redirect: + return jsonify({"error": "invalid_grant", + "error_description": "redirect_uri does not match the authorization request"}), 400 + + # Verify PKCE: BASE64URL(SHA256(code_verifier)) == code_challenge + if computed_challenge != stored_challenge: + return jsonify({"error": "invalid_grant", + "error_description": "code_verifier does not match code_challenge"}), 400 + + # Opportunistically delete expired codes to keep the table tidy. + cur.execute("DELETE FROM mcp_auth_codes WHERE expires_at < NOW()") + + # Fetch or create the user's long-lived MCP token in the same + # connection to avoid opening extra DB connections. 
+ cur.execute( + """SELECT token FROM mcp_tokens + WHERE user_id = %s + AND (expires_at IS NULL OR expires_at > NOW()) + ORDER BY created_at DESC LIMIT 1""", + (user_id,), + ) + token_row = cur.fetchone() + if token_row: + token = token_row[0] + else: + token = secrets.token_hex(32) + cur.execute( + "INSERT INTO mcp_tokens (token, user_id) VALUES (%s, %s)", + (token, user_id), + ) + + conn.commit() + finally: + conn.close() + + return jsonify({"access_token": token, "token_type": "bearer"}) def _set_user_session(self, email: str, name: str, username: str, user_id: str = '', auth_method: str = 'sso', roles: list = None): """Set user session with well-defined structure.""" @@ -2358,6 +2597,29 @@ def _setup_sso(self): logger.info(f"SSO configured with server: {server_metadata_url}") + # Derive token endpoint for silent refresh (Keycloak-style metadata URL) + token_endpoint = sso_config.get('token_endpoint', '') + if not token_endpoint and server_metadata_url: + import re as _re + token_endpoint = _re.sub(r'/\.well-known/.*', '/protocol/openid-connect/token', server_metadata_url) + + session_lifetime_days = self.chat_app_config.get('auth', {}).get('session_lifetime_days', 30) + self.sso_token_service = SSOTokenService( + pg_config=self.pg_config, + token_endpoint=token_endpoint, + session_lifetime_days=int(session_lifetime_days), + ) + + # MCP OAuth2 service — handles per-server authorization code + PKCE flow + from src.utils.mcp_oauth_service import MCPOAuthService + from src.utils.config_access import get_mcp_servers_config as _get_mcp_cfg + mcp_server_cfg = self.config.get('services', {}).get('mcp_server', {}) + app_base_url = mcp_server_cfg.get('url', '') or f"http://localhost:{self.chat_app_config.get('port', 7861)}" + self.mcp_oauth_service = MCPOAuthService( + pg_config=self.pg_config, + app_base_url=app_base_url, + ) + def login(self): """Unified login endpoint supporting multiple auth methods""" # If user is already logged in, redirect to index @@ -2465,7 
+2727,20 @@ def sso_callback(self): auth_method='sso', roles=user_roles ) - + + # Persist access + refresh tokens in DB so SSO-auth MCP servers can + # authenticate on behalf of this user across requests. + if sso_user_id and token.get('access_token'): + try: + self.sso_token_service.store_token( + user_id=sso_user_id, + access_token=token['access_token'], + refresh_token=token.get('refresh_token'), + expires_in=int(token.get('expires_in', 300)), + ) + except Exception as te: + logger.warning(f"Failed to persist SSO token for MCP auth: {te}") + # Log successful authentication log_authentication_event( user=user_email, @@ -2476,7 +2751,25 @@ def sso_callback(self): ) logger.info(f"SSO login successful for user: {user_email} with roles: {user_roles}") - + + # After login, authorize any MCP servers that need it + # Use preferred_username (same key used by Mattermost) so token lookups match. + mcp_user_id = user_info.get('preferred_username', '') or sso_user_id + if hasattr(self, 'mcp_oauth_service') and mcp_user_id: + try: + from src.utils.config_access import get_mcp_servers_config as _get_mcp + mcp_servers = _get_mcp() + needing = self.mcp_oauth_service.get_servers_needing_auth(mcp_user_id, mcp_servers) + if needing: + logger.info(f"Redirecting user {mcp_user_id!r} to authorize MCP server(s): {needing}") + return redirect(url_for('mcp_authorize', server=needing[0], next=url_for('index'))) + except Exception as me: + logger.warning(f"MCP auth check failed after login: {me}") + # Honour any pending post-login redirect (e.g. 
/mcp/auth) + next_url = session.pop('sso_next', None) + if next_url: + return redirect(next_url) + # Redirect to main page return redirect(url_for('index')) @@ -2492,6 +2785,195 @@ def sso_callback(self): flash(f"Authentication failed: {str(e)}") return redirect(url_for('login')) + def mcp_authorize(self): + """Redirect user to an MCP server's OAuth2 authorization endpoint.""" + if not session.get('logged_in'): + return redirect(url_for('login')) + + server_name = request.args.get('server', '') + next_url = request.args.get('next', url_for('index')) + + from src.utils.config_access import get_mcp_servers_config as _get_mcp + mcp_servers = _get_mcp() + server_cfg = mcp_servers.get(server_name) + if not server_cfg or not server_cfg.get('sso_auth'): + logger.warning(f"mcp_authorize: unknown or non-sso_auth server '{server_name}'") + return redirect(next_url) + + server_url = server_cfg.get('url', '') + result = self.mcp_oauth_service.get_authorization_url(server_name, server_url) + if not result: + logger.error(f"mcp_authorize: could not build auth URL for '{server_name}'") + flash(f"Could not connect to MCP server '{server_name}'. 
Check logs for details.") + return redirect(next_url) + + auth_url, state, code_verifier = result + session[f'mcp_state_{server_name}'] = state + session[f'mcp_verifier_{server_name}'] = code_verifier + session['mcp_pending_server'] = server_name + session['mcp_next_url'] = next_url + logger.info(f"Redirecting user to MCP OAuth for server '{server_name}'") + return redirect(auth_url) + + def mcp_callback(self): + """Handle the OAuth2 callback from an MCP server.""" + if not session.get('logged_in'): + return redirect(url_for('login')) + + code = request.args.get('code', '') + state = request.args.get('state', '') + error = request.args.get('error', '') + + server_name = session.pop('mcp_pending_server', '') + next_url = session.pop('mcp_next_url', url_for('index')) + + if error or not code or not server_name: + logger.warning(f"MCP callback error for '{server_name}': error={error!r}, code present={bool(code)}") + flash(f"MCP authorization failed for '{server_name}': {error or 'missing code'}") + return redirect(next_url) + + expected_state = session.pop(f'mcp_state_{server_name}', '') + code_verifier = session.pop(f'mcp_verifier_{server_name}', '') + + if state != expected_state: + logger.error(f"MCP callback state mismatch for '{server_name}'") + flash("MCP authorization failed: state mismatch.") + return redirect(next_url) + + token_data = self.mcp_oauth_service.exchange_code(server_name, code, code_verifier) + if not token_data or not token_data.get('access_token'): + logger.error(f"MCP token exchange failed for '{server_name}'") + flash(f"MCP authorization failed for '{server_name}': token exchange error.") + return redirect(next_url) + + # Use preferred_username as the MCP token key so that Mattermost lookups + # (which pass ctx.username = preferred_username) find the stored token. + # Fall back to SSO sub UUID only if username is absent. 
user_id = session.get('user', {}).get('username', '') or session.get('user', {}).get('id', '')
+        self.mcp_oauth_service.store_user_token(
+            user_id=user_id,
+            server_name=server_name,
+            access_token=token_data['access_token'],
+            refresh_token=token_data.get('refresh_token'),
+            expires_in=int(token_data.get('expires_in', 3600)),
+        )
+        logger.info(f"MCP OAuth complete for user={user_id!r}, server='{server_name}'")
+
+        # If more servers need auth, chain to next one
+        if hasattr(self, 'mcp_oauth_service') and user_id:
+            try:
+                from src.utils.config_access import get_mcp_servers_config as _get_mcp
+                mcp_servers = _get_mcp()
+                needing = self.mcp_oauth_service.get_servers_needing_auth(user_id, mcp_servers)
+                if needing:
+                    return redirect(url_for('mcp_authorize', server=needing[0], next=next_url))
+            except Exception as me:
+                logger.warning(f"MCP chain auth check failed: {me}")
+
+        return redirect(next_url)
+
+    # ------------------------------------------------------------------
+    # MCP token helpers
+    # ------------------------------------------------------------------
+
+    def _get_mcp_token(self, user_id: str) -> Optional[str]:
+        """Return the existing MCP token for a user, or None."""
+        conn = psycopg2.connect(**self.pg_config)
+        try:
+            with conn.cursor() as cur:
+                cur.execute(
+                    """SELECT token FROM mcp_tokens
+                       WHERE user_id = %s
+                         AND (expires_at IS NULL OR expires_at > NOW())
+                       ORDER BY created_at DESC LIMIT 1""",
+                    (user_id,),
+                )
+                row = cur.fetchone()
+                return row[0] if row else None
+        finally:
+            conn.close()
+
+    def _create_mcp_token(self, user_id: str) -> str:
+        """Create and store a new MCP token for a user, returning the token string."""
+        token = secrets.token_hex(32)  # module-level `secrets` import
+        conn = psycopg2.connect(**self.pg_config)
+        try:
+            with conn.cursor() as cur:
+                cur.execute(
+                    "INSERT INTO mcp_tokens (token, user_id) VALUES (%s, %s)",
+                    (token, user_id),
+                )
+            conn.commit()
+        finally:
+            conn.close()
+        return token
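The `/mcp/oauth/authorize` → `/mcp/oauth/token` handshake above hinges on a single check: `BASE64URL(SHA256(code_verifier))` must equal the stored `code_challenge`. A minimal standalone sketch of that S256 relation, under the same construction the endpoints use (the helper names here are illustrative, not part of the app's API):

```python
import base64
import hashlib
import secrets


def make_pkce_pair() -> tuple[str, str]:
    """Return (code_verifier, code_challenge) for an S256 PKCE flow.

    The client keeps code_verifier secret and sends only code_challenge
    to the /authorize endpoint.
    """
    verifier = secrets.token_urlsafe(32)
    digest = hashlib.sha256(verifier.encode()).digest()
    # BASE64URL encoding without '=' padding, per the S256 method.
    challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
    return verifier, challenge


def verify_pkce(code_verifier: str, code_challenge: str) -> bool:
    """Check BASE64URL(SHA256(code_verifier)) == code_challenge."""
    digest = hashlib.sha256(code_verifier.encode()).digest()
    computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
    # Constant-time comparison to avoid leaking digest prefixes via timing.
    return secrets.compare_digest(computed, code_challenge)
```

Using `secrets.compare_digest` for the final comparison avoids a timing side channel; the endpoint code above uses a plain `!=`, which is acceptable here because both sides are SHA-256 digests of attacker-supplied input rather than long-lived secrets.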
+
+    def _rotate_mcp_token(self, user_id: str) -> str:
+        """Delete all existing MCP tokens for a user and create a fresh one."""
+        token = secrets.token_hex(32)  # module-level `secrets` import
+        conn = psycopg2.connect(**self.pg_config)
+        try:
+            with conn.cursor() as cur:
+                cur.execute("DELETE FROM mcp_tokens WHERE user_id = %s", (user_id,))
+                cur.execute(
+                    "INSERT INTO mcp_tokens (token, user_id) VALUES (%s, %s)",
+                    (token, user_id),
+                )
+            conn.commit()
+        finally:
+            conn.close()
+        return token
+
+    # ------------------------------------------------------------------
+    # MCP auth page
+    # ------------------------------------------------------------------
+
+    def mcp_auth(self):
+        """Show the MCP token page.
+
+        Requires SSO login. If the user is not logged in they are
+        redirected to the SSO provider; after login the SSO callback
+        returns them here via ``session['sso_next']``.
+        """
+        if not session.get('logged_in'):
+            session['sso_next'] = '/mcp/auth'
+            return redirect(url_for('login') + '?method=sso')
+
+        user_info = session.get('user', {})
+        user_id = user_info.get('id')
+        if not user_id:
+            flash("Could not determine your user identity.
Please log in again.") + return redirect(url_for('login')) + + token = self._get_mcp_token(user_id) + if not token: + token = self._create_mcp_token(user_id) + + mcp_url = self._mcp_public_base_url() + '/mcp/sse' + regenerated = request.args.get('regenerated') == '1' + + return render_template( + 'mcp_auth.html', + token=token, + mcp_url=mcp_url, + user=user_info, + regenerated=regenerated, + ) + + def mcp_auth_regenerate(self): + """Rotate the user's MCP token (POST /mcp/auth/regenerate).""" + if not session.get('logged_in'): + return jsonify({'error': 'Authentication required'}), 401 + + user_id = session.get('user', {}).get('id') + if not user_id: + return jsonify({'error': 'Could not determine user identity'}), 400 + + self._rotate_mcp_token(user_id) + return redirect(url_for('mcp_auth') + '?regenerated=1') + def get_user(self): """API endpoint to get current user information including roles and permissions""" if session.get('logged_in'): @@ -3802,7 +4284,10 @@ def list_conversations(self): conn = psycopg2.connect(**self.pg_config) cursor = conn.cursor() if user_id: - cursor.execute(SQL_LIST_CONVERSATIONS_BY_USER, (user_id, client_id, limit)) + # Include Mattermost-originated conversations via "mm_user_" client_id + username = session.get('user', {}).get('username', '') or '' + mm_client_id = f"mm_user_{username}" if username else "" + cursor.execute(SQL_LIST_CONVERSATIONS_ALL_SOURCES, (user_id, client_id, mm_client_id, limit)) else: cursor.execute(SQL_LIST_CONVERSATIONS, (client_id, limit)) rows = cursor.fetchall() @@ -3814,6 +4299,7 @@ def list_conversations(self): 'title': row[1] or "New Chat", 'created_at': row[2].isoformat() if row[2] else None, 'last_message_at': row[3].isoformat() if row[3] else None, + 'archi_service': row[4] if len(row) > 4 else 'chat', }) # clean up database connection state @@ -3853,9 +4339,11 @@ def load_conversation(self): conn = psycopg2.connect(**self.pg_config) cursor = conn.cursor() - # get conversation metadata + # get 
conversation metadata — include Mattermost conversations for authenticated users
         if user_id:
-            cursor.execute(SQL_GET_CONVERSATION_METADATA_BY_USER, (conversation_id, user_id, client_id))
+            username = session.get('user', {}).get('username', '') or ''
+            mm_client_id = f"mm_user_{username}" if username else ""
+            cursor.execute(SQL_GET_CONVERSATION_METADATA_ALL_SOURCES, (conversation_id, user_id, client_id, mm_client_id))
         else:
             cursor.execute(SQL_GET_CONVERSATION_METADATA, (conversation_id, client_id))
         meta_row = cursor.fetchone()
@@ -3918,6 +4406,7 @@ def load_conversation(self):
             'title': meta_row[1] or "New Conversation",
             'created_at': meta_row[2].isoformat() if meta_row[2] else None,
             'last_message_at': meta_row[3].isoformat() if meta_row[3] else None,
+            'archi_service': meta_row[4] if len(meta_row) > 4 else 'chat',
             'messages': messages
         }

diff --git a/src/interfaces/chat_app/mcp_sse.py b/src/interfaces/chat_app/mcp_sse.py
new file mode 100644
index 000000000..c77e28aad
--- /dev/null
+++ b/src/interfaces/chat_app/mcp_sse.py
@@ -0,0 +1,1430 @@
+"""
+MCP SSE endpoint – exposes archi's RAG capabilities as MCP tools over HTTP+SSE.
+
+AI assistants in VS Code (GitHub Copilot), Cursor, Claude Desktop, Claude Code,
+and any other MCP-compatible client can connect with just a URL:
+
+    http://<host>:<port>/mcp/sse
+
+No local installation required on the client side.
+
+VS Code (.vscode/mcp.json):
+    {
+      "servers": {
+        "archi": { "type": "sse", "url": "http://localhost:7861/mcp/sse" }
+      }
+    }
+
+Cursor (~/.cursor/mcp.json):
+    {
+      "mcpServers": {
+        "archi": { "url": "http://localhost:7861/mcp/sse" }
+      }
+    }
+
+Claude Desktop (~/Library/Application Support/Claude/claude_desktop_config.json):
+    {
+      "mcpServers": {
+        "archi": {
+          "command": "npx",
+          "args": [
+            "mcp-remote",
+            "http://localhost:7861/mcp/sse",
+            "--header",
+            "Authorization: Bearer <token>"
+          ]
+        }
+      }
+    }
+
+Claude Code (run once in terminal):
+    claude mcp add --transport sse archi http://localhost:7861/mcp/sse
+
+    Or add to .mcp.json in your project root:
+    {
+      "mcpServers": {
+        "archi": { "type": "sse", "url": "http://localhost:7861/mcp/sse" }
+      }
+    }
+
+Implements the MCP SSE transport (JSON-RPC 2.0 over Server-Sent Events)
+directly in Flask using thread-safe queues — no external ``mcp`` package needed.
+"""
+
+from __future__ import annotations
+
+import json
+import queue
+import re
+import shlex
+import textwrap
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime, timezone
+from threading import BoundedSemaphore, Lock
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional
+
+import psycopg2
+import yaml
+from flask import Blueprint, Response, jsonify, request, stream_with_context
+
+from src.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+# ---------------------------------------------------------------------------
+# Constants
+# ---------------------------------------------------------------------------
+
+_MCP_VERSION = "2024-11-05"
+_SERVER_INFO = {"name": "archi", "version": "1.0.0"}
+_KEEPALIVE_TIMEOUT = 30  # seconds between keepalive pings
+_MCP_DISPATCH_MAX_WORKERS = 48
+_MCP_DISPATCH_MAX_INFLIGHT = 512
+
+# ---------------------------------------------------------------------------
+# Session registry (session_id → {"queue": Queue, "user_id": str|None})
+#
--------------------------------------------------------------------------- + +_sessions: Dict[str, Dict] = {} +_sessions_lock = Lock() +_dispatch_executor = ThreadPoolExecutor( + max_workers=_MCP_DISPATCH_MAX_WORKERS, + thread_name_prefix="mcp-dispatch", +) +_dispatch_slots = BoundedSemaphore(_MCP_DISPATCH_MAX_INFLIGHT) + + +# --------------------------------------------------------------------------- +# Token validation +# --------------------------------------------------------------------------- + + +def _validate_mcp_token(token: str, pg_config: Optional[dict]) -> Optional[str]: + """Validate an MCP bearer token and return the user_id, or None if invalid.""" + if not token or not pg_config: + return None + try: + conn = psycopg2.connect(**pg_config) + try: + with conn.cursor() as cur: + cur.execute( + """SELECT user_id FROM mcp_tokens + WHERE token = %s + AND (expires_at IS NULL OR expires_at > NOW())""", + (token,), + ) + row = cur.fetchone() + if row: + cur.execute( + "UPDATE mcp_tokens SET last_used_at = NOW() WHERE token = %s", + (token,), + ) + conn.commit() + return row[0] + finally: + conn.close() + except Exception: + logger.exception("Error validating MCP token") + return None + + +def _extract_bearer_token(req) -> Optional[str]: + auth_header = req.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + return auth_header[7:].strip() + return None + +# --------------------------------------------------------------------------- +# MCP tool definitions +# --------------------------------------------------------------------------- + +_TOOLS = [ + { + "name": "archi_query", + "description": textwrap.dedent("""\ + Ask a question to the archi RAG (Retrieval-Augmented Generation) system. + + archi retrieves relevant documents from its knowledge base and uses an LLM + to compose a grounded answer. 
Use this tool when you need information that + is stored in the connected archi deployment (documentation, tickets, wiki + pages, research papers, course material, etc.). + + You may continue a conversation by passing the conversation_id returned by + a previous call. + """), + "inputSchema": { + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "The question or request to send to archi.", + }, + "conversation_id": { + "type": "integer", + "description": ( + "Optional. Pass the conversation_id from a previous archi_query " + "call to continue the same conversation thread." + ), + }, + "provider": { + "type": "string", + "description": "Optional. Override the LLM provider (e.g. 'openai', 'anthropic').", + }, + "model": { + "type": "string", + "description": "Optional. Override the specific model (e.g. 'gpt-4o').", + }, + "config_name": { + "type": "string", + "description": "Optional. The deployment config name to use (e.g. 'comp_ops'). Defaults to the active config.", + }, + "client_timeout": { + "type": "number", + "description": "Optional. Request timeout in milliseconds (default 18000000 = 5 hours).", + "default": 18000000, + }, + }, + "required": ["question"], + }, + }, + { + "name": "archi_list_documents", + "description": textwrap.dedent("""\ + List the documents that have been indexed into archi's knowledge base. + + Returns a paginated list of document metadata (filename, source type, + URL, enabled state, ingestion status, etc.). Use this tool to discover + what information archi has access to before querying it, or to find a + specific document's hash for use with archi_get_document_content. + """), + "inputSchema": { + "type": "object", + "properties": { + "conversation_id": { + "type": "integer", + "description": ( + "Optional. Filter enabled/disabled state for a specific " + "conversation." 
+ ), + }, + "search": { + "type": "string", + "description": "Optional keyword to filter documents by name or URL.", + }, + "source_type": { + "type": "string", + "description": "Optional. Filter by source type: 'web', 'git', 'local', 'jira', etc.", + }, + "enabled": { + "type": "string", + "description": ( + "Optional. Filter by enabled state: 'enabled', 'disabled', " + "or 'all' (default)." + ), + }, + "limit": { + "type": "integer", + "description": "Max results to return (default 50, max 500).", + "default": 50, + }, + "offset": { + "type": "integer", + "description": "Pagination offset (default 0).", + "default": 0, + }, + }, + "required": [], + }, + }, + { + "name": "archi_get_document_content", + "description": textwrap.dedent("""\ + Retrieve the full text content of a document indexed in archi. + + Use archi_list_documents first to obtain a document's hash, then pass + it here to read the raw source text that archi ingested. + """), + "inputSchema": { + "type": "object", + "properties": { + "document_hash": { + "type": "string", + "description": "The document hash returned by archi_list_documents.", + }, + "max_size": { + "type": "integer", + "description": ( + "Optional maximum number of bytes/chars to return " + "(default 100000, max 1000000)." + ), + "default": 100000, + }, + }, + "required": ["document_hash"], + }, + }, + { + "name": "archi_search_document_metadata", + "description": textwrap.dedent("""\ + Search the indexed document catalog by metadata, paths, URLs, ticket IDs, + and other stored document attributes. + + Supports free text plus exact `key:value` filters. Multiple filter groups + can be OR-ed with the literal token `OR`, matching the same metadata-query + syntax used by archi's built-in agent tools. + """), + "inputSchema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": ( + "Metadata query string, e.g. " + "`source_type:git relative_path:docs/README.md`." 
+ ), + }, + "limit": { + "type": "integer", + "description": "Max results to return (default 10, max 100).", + "default": 10, + }, + }, + "required": ["query"], + }, + }, + { + "name": "archi_list_metadata_schema", + "description": textwrap.dedent("""\ + List the metadata filter keys and common values supported by + archi_search_document_metadata. + + Use this tool before metadata searches when you do not know which + fields exist in the catalog. + """), + "inputSchema": {"type": "object", "properties": {}, "required": []}, + }, + { + "name": "archi_search_document_content", + "description": textwrap.dedent("""\ + Search indexed document contents for an exact phrase or regex pattern. + + This is a grep-like content search intended for logs, error messages, + code snippets, and other exact-text lookups. Optionally pre-filter the + candidate documents with a metadata query. + """), + "inputSchema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Phrase or regex pattern to search for.", + }, + "metadata_query": { + "type": "string", + "description": ( + "Optional metadata pre-filter using the same syntax as " + "archi_search_document_metadata." 
+ ), + }, + "regex": { + "type": "boolean", + "description": "Treat `query` as a regular expression.", + "default": False, + }, + "case_sensitive": { + "type": "boolean", + "description": "Perform a case-sensitive match.", + "default": False, + }, + "before": { + "type": "integer", + "description": "Number of context lines before each match.", + "default": 0, + }, + "after": { + "type": "integer", + "description": "Number of context lines after each match.", + "default": 0, + }, + "max_matches_per_document": { + "type": "integer", + "description": "Maximum matches to show per document.", + "default": 3, + }, + "limit": { + "type": "integer", + "description": "Maximum documents to return (default 5, max 20).", + "default": 5, + }, + }, + "required": ["query"], + }, + }, + { + "name": "archi_get_document_chunks", + "description": textwrap.dedent("""\ + Inspect the stored chunks for a document as they exist in archi's + vectorized corpus. + + Useful for debugging chunk boundaries, truncation, and ingestion issues. + """), + "inputSchema": { + "type": "object", + "properties": { + "document_hash": { + "type": "string", + "description": "The document hash returned by archi_list_documents.", + }, + "offset": { + "type": "integer", + "description": "Chunk offset to start from (default 0).", + "default": 0, + }, + "limit": { + "type": "integer", + "description": "Maximum chunks to return (default 20, max 100).", + "default": 20, + }, + "max_chars_per_chunk": { + "type": "integer", + "description": "Maximum characters to show per chunk (default 600).", + "default": 600, + }, + }, + "required": ["document_hash"], + }, + }, + { + "name": "archi_get_data_stats", + "description": textwrap.dedent("""\ + Return corpus-level statistics for the connected archi deployment. + + Includes total documents, total chunks, enabled/disabled counts, + ingestion status counts, bytes stored, and a breakdown by source type. 
+ """), + "inputSchema": { + "type": "object", + "properties": { + "conversation_id": { + "type": "integer", + "description": ( + "Optional. Compute enabled/disabled counts for a specific " + "conversation." + ), + }, + }, + "required": [], + }, + }, + { + "name": "archi_get_deployment_info", + "description": textwrap.dedent("""\ + Return configuration and status information about the connected archi + deployment. + + Includes the active LLM pipeline and model, retrieval settings (number of + documents retrieved, hybrid search weights), embedding model, and the list + of available pipelines. Useful for understanding how archi is configured + before issuing queries. + """), + "inputSchema": {"type": "object", "properties": {}, "required": []}, + }, + { + "name": "archi_list_agents", + "description": textwrap.dedent("""\ + Return the agent configurations (agent specs) available in this archi + deployment. + + Each agent spec defines a name, a system prompt, and the set of tools + (retriever, MCP servers, local file search, etc.) that agent can use. + """), + "inputSchema": {"type": "object", "properties": {}, "required": []}, + }, + { + "name": "archi_get_agent_spec", + "description": textwrap.dedent("""\ + Retrieve the full agent spec markdown for a named archi agent. + + Use archi_list_agents first to discover available agent names, then call + this tool to inspect the exact tools and prompt configured for that agent. + """), + "inputSchema": { + "type": "object", + "properties": { + "agent_name": { + "type": "string", + "description": "The agent name returned by archi_list_agents.", + }, + }, + "required": ["agent_name"], + }, + }, + { + "name": "archi_health", + "description": ( + "Check whether the archi deployment is reachable and its database is healthy." 
+ ), + "inputSchema": {"type": "object", "properties": {}, "required": []}, + }, +] + +# --------------------------------------------------------------------------- +# JSON-RPC helpers +# --------------------------------------------------------------------------- + + +def _ok(result: Any, rpc_id: Any) -> Dict: + return {"jsonrpc": "2.0", "id": rpc_id, "result": result} + + +def _err(code: int, message: str, rpc_id: Any) -> Dict: + return {"jsonrpc": "2.0", "id": rpc_id, "error": {"code": code, "message": message}} + + +def _text(text: str) -> Dict: + """Wrap a string as an MCP tool result.""" + return {"content": [{"type": "text", "text": str(text)}]} + + +_METADATA_ALIAS_MAP = { + "resource_type": "source_type", + "resource_id": "ticket_id", +} + +_METADATA_FILTER_KEYS = [ + "path", + "file_path", + "display_name", + "source_type", + "url", + "ticket_id", + "suffix", + "size_bytes", + "original_path", + "base_path", + "relative_path", + "created_at", + "modified_at", + "file_modified_at", + "ingested_at", +] + + +def _clamp_int(value: Any, *, default: int, minimum: int, maximum: int) -> int: + try: + parsed = int(value) + except (TypeError, ValueError): + parsed = default + return max(minimum, min(parsed, maximum)) + + +def _truncate_text(value: Any, *, max_chars: int) -> str: + text = str(value or "") + if len(text) <= max_chars: + return text + return text[: max_chars - 3].rstrip() + "..." + + +def _parse_metadata_query(query: str) -> tuple[Dict[str, str] | list[Dict[str, str]], str]: + filter_groups: list[Dict[str, str]] = [] + current_group: Dict[str, str] = {} + free_tokens: list[str] = [] + + try: + tokens = shlex.split(query) + except ValueError as exc: + # Fall back to whitespace tokenization for malformed quoted input. 
+ logger.warning("Invalid metadata query syntax; using fallback tokenization: %s", exc) + tokens = query.split() + + for token in tokens: + if token.upper() == "OR": + if current_group: + filter_groups.append(current_group) + current_group = {} + continue + if ":" in token: + key, value = token.split(":", 1) + key = _METADATA_ALIAS_MAP.get(key.strip(), key.strip()) + value = value.strip() + if key and value: + current_group[key] = value + continue + free_tokens.append(token) + + if current_group: + filter_groups.append(current_group) + + if not filter_groups: + filters: Dict[str, str] | list[Dict[str, str]] = {} + elif len(filter_groups) == 1: + filters = filter_groups[0] + else: + filters = filter_groups + + return filters, " ".join(free_tokens) + + +def _compile_query_pattern(query: str, *, regex: bool, case_sensitive: bool) -> re.Pattern[str]: + flags = 0 if case_sensitive else re.IGNORECASE + pattern = query if regex else re.escape(query) + return re.compile(pattern, flags) + + +def _grep_text_lines( + text: str, + pattern: re.Pattern[str], + *, + before: int = 0, + after: int = 0, + max_matches: int = 3, +) -> list[Dict[str, Any]]: + if max_matches <= 0: + return [] + lines = text.splitlines() + matches: list[Dict[str, Any]] = [] + for idx, line in enumerate(lines): + if not pattern.search(line): + continue + matches.append( + { + "line": idx + 1, + "text": line, + "before": lines[max(0, idx - before):idx] if before else [], + "after": lines[idx + 1: idx + 1 + after] if after else [], + } + ) + if len(matches) >= max_matches: + break + return matches + + +def _document_display_name(doc: Dict[str, Any]) -> str: + return ( + doc.get("display_name") + or doc.get("filename") + or doc.get("url") + or doc.get("hash") + or doc.get("id") + or "unknown" + ) + + +def _parse_agent_frontmatter(path: Path) -> Optional[Dict[str, Any]]: + try: + text = path.read_text(encoding="utf-8") + except Exception: + return None + + lines = text.splitlines() + idx = 0 + while idx < 
len(lines) and not lines[idx].strip(): + idx += 1 + if idx >= len(lines) or lines[idx].strip() != "---": + return None + + idx += 1 + frontmatter_lines: list[str] = [] + while idx < len(lines): + if lines[idx].strip() == "---": + idx += 1 + break + frontmatter_lines.append(lines[idx]) + idx += 1 + else: + return None + + try: + frontmatter = yaml.safe_load("\n".join(frontmatter_lines)) or {} + except Exception: + return None + + if not isinstance(frontmatter, dict): + return None + + name = frontmatter.get("name") + tools = frontmatter.get("tools") + if not isinstance(name, str) or not name.strip(): + return None + if not isinstance(tools, list) or not all(isinstance(tool, str) and tool.strip() for tool in tools): + return None + + return { + "name": name.strip(), + "tools": [tool.strip() for tool in tools], + "path": path, + "content": text, + } + + +def _list_agent_specs(agents_dir: Path) -> list[Dict[str, Any]]: + if not agents_dir.exists() or not agents_dir.is_dir(): + return [] + + specs: list[Dict[str, Any]] = [] + for path in sorted(agents_dir.iterdir()): + if not path.is_file() or path.suffix.lower() != ".md": + continue + spec = _parse_agent_frontmatter(path) + if spec is not None: + specs.append(spec) + return specs + + +# --------------------------------------------------------------------------- +# Tool handlers (run inside the Flask process – no HTTP round-trip) +# --------------------------------------------------------------------------- + + +def _call_tool( + name: str, + arguments: Dict[str, Any], + wrapper, + user_id: Optional[str] = None, + notify=None, +) -> Dict: + """Dispatch a tools/call request to the appropriate archi internals. + + ``notify`` is an optional callable(message, progress, total) that sends a + ``notifications/progress`` event back to the MCP client over the SSE stream. + It is only provided when the client included ``_meta.progressToken`` in the + tools/call request. 
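+
+    A hypothetical handler body using ``notify`` might look like::
+
+        if notify:
+            notify("Retrieving documents…", progress=1, total=3)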
+ """ + try: + if name == "archi_query": + return _tool_query(arguments, wrapper, user_id, notify=notify) + elif name == "archi_list_documents": + return _tool_list_documents(arguments, wrapper) + elif name == "archi_get_document_content": + return _tool_get_document_content(arguments, wrapper) + elif name == "archi_search_document_metadata": + return _tool_search_document_metadata(arguments, wrapper) + elif name == "archi_list_metadata_schema": + return _tool_list_metadata_schema(wrapper) + elif name == "archi_search_document_content": + return _tool_search_document_content(arguments, wrapper) + elif name == "archi_get_document_chunks": + return _tool_get_document_chunks(arguments, wrapper) + elif name == "archi_get_data_stats": + return _tool_get_data_stats(arguments, wrapper) + elif name == "archi_get_deployment_info": + return _tool_deployment_info(wrapper) + elif name == "archi_list_agents": + return _tool_list_agents(wrapper) + elif name == "archi_get_agent_spec": + return _tool_get_agent_spec(arguments, wrapper) + elif name == "archi_health": + return _text("status: OK\ndatabase: OK") + else: + return _text(f"ERROR: Unknown tool '{name}'.") + except Exception as exc: + logger.exception("MCP tool %s raised an exception", name) + return _text(f"ERROR: {exc}") + + +def _tool_query( + arguments: Dict[str, Any], + wrapper, + user_id: Optional[str] = None, + notify=None, +) -> Dict: + question = (arguments.get("question") or "").strip() + if not question: + return _text("ERROR: 'question' is required.") + + conversation_id = arguments.get("conversation_id") + provider = arguments.get("provider") or None + model = arguments.get("model") or None + config_name = arguments.get("config_name") or None + default_timeout_ms = 30000 + try: + chat_cfg = (wrapper.config or {}).get("services", {}).get("chat_app", {}) + default_timeout_ms = int(float(chat_cfg.get("client_timeout_seconds", 30)) * 1000) + except Exception: + pass + # client_timeout is in milliseconds (matching 
UI convention); convert to seconds + client_timeout_ms = arguments.get("client_timeout", default_timeout_ms) + try: + client_timeout = max(float(client_timeout_ms) / 1000.0, 1.0) + except (TypeError, ValueError): + client_timeout = max(float(default_timeout_ms) / 1000.0, 1.0) + client_id = f"mcp-sse-{uuid.uuid4().hex[:12]}" + now = datetime.now(timezone.utc) + + # Chunk events carry accumulated text (not deltas) — keep only the last one. + answer: str = "" + new_conv_id = None + + for event in wrapper.chat.stream( + [["User", question]], + conversation_id, + client_id, + False, # is_refresh + now, # server_received_msg_ts + now.timestamp(), # client_sent_msg_ts + client_timeout, # client_timeout (seconds, converted from ms) + config_name, # config_name (e.g. 'comp_ops', or None for active config) + provider=provider, + model=model, + user_id=user_id, + ): + etype = event.get("type", "") + + if etype == "error": + return _text(f"ERROR: {event.get('message', 'unknown error')}") + + elif etype == "thinking_start": + if notify: + notify("Thinking…") + + elif etype == "thinking_end": + if notify: + thinking = event.get("thinking_content", "") + if thinking: + preview = thinking[:120].replace("\n", " ") + notify(f"Thought: {preview}{'…' if len(thinking) > 120 else ''}") + + elif etype == "tool_start": + if notify: + tool_name = event.get("tool_name", "tool") + tool_args = event.get("tool_args") or {} + if tool_args: + args_preview = ", ".join( + f"{k}={str(v)[:40]}" for k, v in (tool_args if isinstance(tool_args, dict) else {}).items() + ) + notify(f"Calling {tool_name}({args_preview})") + else: + notify(f"Calling {tool_name}()") + + elif etype == "tool_output": + if notify: + notify(f"Got result from {event.get('tool_name', 'tool')}") + + elif etype == "chunk": + content = event.get("content", "") + if content: + answer = content + if notify: + notify("Generating answer…") + + elif etype == "final": + conv_id = event.get("conversation_id") + if conv_id is not None: + 
new_conv_id = conv_id + response = event.get("response") + final_answer = getattr(response, "answer", None) if response is not None else None + if final_answer: + answer = final_answer + parts = [answer] + if new_conv_id is not None: + parts.append( + f"\n\n---\n_conversation_id: {new_conv_id} " + "(pass this to archi_query to continue the conversation)_" + ) + return _text("".join(parts)) + + +def _tool_list_documents(arguments: Dict[str, Any], wrapper) -> Dict: + limit = _clamp_int(arguments.get("limit", 50), default=50, minimum=1, maximum=500) + offset = _clamp_int(arguments.get("offset", 0), default=0, minimum=0, maximum=1_000_000) + conversation_id = arguments.get("conversation_id") + search: Optional[str] = arguments.get("search") or None + source_type: Optional[str] = arguments.get("source_type") or None + enabled_filter = (arguments.get("enabled") or "").strip().lower() or None + if enabled_filter not in {None, "enabled", "disabled", "all"}: + return _text("ERROR: 'enabled' must be one of: enabled, disabled, all.") + + result = wrapper.chat.data_viewer.list_documents( + conversation_id=conversation_id, + source_type=source_type, + search=search, + enabled_filter=None if enabled_filter in {None, "all"} else enabled_filter, + limit=limit, + offset=offset, + ) + docs = result.get("documents", result.get("items", [])) + total = result.get("total", len(docs)) + + lines = [f"Found {total} document(s) (offset={offset}, limit={limit}):\n"] + for doc in docs: + display = _document_display_name(doc) + source = doc.get("source_type", doc.get("type", "")) + doc_hash = doc.get("hash", doc.get("id", "")) + status = doc.get("ingestion_status", "unknown") + enabled = doc.get("enabled") + extra: list[str] = [source] if source else [] + if status: + extra.append(f"status={status}") + if enabled is not None: + extra.append(f"enabled={'yes' if enabled else 'no'}") + lines.append(f" • {display} [{' | '.join(extra)}] hash={doc_hash}") + + lines.append( + "\nUse 
archi_get_document_content(document_hash=) to read a document." + ) + return _text("\n".join(lines)) + + +def _tool_get_document_content(arguments: Dict[str, Any], wrapper) -> Dict: + doc_hash = (arguments.get("document_hash") or "").strip() + if not doc_hash: + return _text("ERROR: 'document_hash' is required.") + + max_size = _clamp_int(arguments.get("max_size", 100000), default=100000, minimum=1000, maximum=1_000_000) + result = wrapper.chat.data_viewer.get_document_content(doc_hash, max_size) + if result is None: + return _text(f"ERROR: Document not found: {doc_hash}") + + content = result.get("content", result.get("text", json.dumps(result, indent=2))) + if result.get("truncated"): + content = f"{content}\n\n---\n(truncated at {max_size} bytes/chars)" + return _text(content) + + +def _tool_search_document_metadata(arguments: Dict[str, Any], wrapper) -> Dict: + query = (arguments.get("query") or "").strip() + if not query: + return _text("ERROR: 'query' is required.") + + limit = _clamp_int(arguments.get("limit", 10), default=10, minimum=1, maximum=100) + filters, free_query = _parse_metadata_query(query) + catalog = wrapper.chat.data_viewer.catalog + results = catalog.search_metadata( + free_query, + limit=limit, + filters=filters or None, + ) + + if not results: + return _text("No documents matched that metadata query.") + + lines = [f"Found {len(results)} metadata match(es):\n"] + for item in results: + metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {} + path = item.get("path") + display = ( + metadata.get("display_name") + or metadata.get("file_name") + or metadata.get("title") + or metadata.get("url") + or item.get("hash") + or "unknown" + ) + lines.append(f" • {display} hash={item.get('hash')}") + lines.append(f" Path: {path}") + if metadata.get("source_type"): + lines.append(f" Source: {metadata.get('source_type')}") + if metadata.get("ticket_id"): + lines.append(f" Ticket: {metadata.get('ticket_id')}") + if 
metadata.get("relative_path"): + lines.append(f" Relative path: {metadata.get('relative_path')}") + if metadata.get("url"): + lines.append(f" URL: {_truncate_text(metadata.get('url'), max_chars=180)}") + + lines.append( + "\nUse archi_get_document_content(document_hash=) to inspect a result." + ) + return _text("\n".join(lines)) + + +def _tool_list_metadata_schema(wrapper) -> Dict: + catalog = wrapper.chat.data_viewer.catalog + distinct = catalog.get_distinct_metadata(["source_type", "suffix"]) + keys = sorted(_METADATA_FILTER_KEYS) + source_types = distinct.get("source_type", []) + suffixes = distinct.get("suffix", []) + + lines = [ + "Supported metadata keys: " + (", ".join(keys) or "none"), + "source_type values: " + (", ".join(source_types) or "none"), + "suffix values: " + (", ".join(suffixes) or "none"), + "", + "Examples:", + " source_type:git relative_path:docs/README.md", + " ticket_id:CMSPROD-1234", + " source_type:web OR source_type:git", + " url:github.com/org/repo", + "", + "Legacy aliases: resource_type -> source_type, resource_id -> ticket_id", + ] + return _text("\n".join(lines)) + + +def _tool_search_document_content(arguments: Dict[str, Any], wrapper) -> Dict: + from src.data_manager.vectorstore.loader_utils import load_text_from_path + + query = (arguments.get("query") or "").strip() + if not query: + return _text("ERROR: 'query' is required.") + + regex = bool(arguments.get("regex", False)) + case_sensitive = bool(arguments.get("case_sensitive", False)) + before = _clamp_int(arguments.get("before", 0), default=0, minimum=0, maximum=20) + after = _clamp_int(arguments.get("after", 0), default=0, minimum=0, maximum=20) + max_matches_per_document = _clamp_int( + arguments.get("max_matches_per_document", 3), + default=3, + minimum=1, + maximum=20, + ) + limit = _clamp_int(arguments.get("limit", 5), default=5, minimum=1, maximum=20) + metadata_query = (arguments.get("metadata_query") or "").strip() + + try: + pattern = _compile_query_pattern(query, 
regex=regex, case_sensitive=case_sensitive) + except re.error as exc: + return _text(f"ERROR: invalid regex: {exc}") + + catalog = wrapper.chat.data_viewer.catalog + candidate_metadata: Dict[str, Dict[str, Any]] = {} + if metadata_query: + filters, free_query = _parse_metadata_query(metadata_query) + candidates = catalog.search_metadata( + free_query, + limit=None, + filters=filters or None, + ) + iterable = [] + for item in candidates: + resource_hash = item.get("hash") + if not resource_hash: + continue + path = catalog.get_filepath_for_hash(resource_hash) + if path: + iterable.append((resource_hash, path)) + metadata = item.get("metadata") + if isinstance(metadata, dict): + candidate_metadata[resource_hash] = metadata + else: + iterable = list(catalog.iter_files()) + + hits: list[Dict[str, Any]] = [] + for resource_hash, path in iterable: + metadata = candidate_metadata.get(resource_hash) or catalog.get_metadata_for_hash(resource_hash) or {} + text = load_text_from_path(path) or "" + if not text: + continue + matches = _grep_text_lines( + text, + pattern, + before=before, + after=after, + max_matches=max_matches_per_document, + ) + if not matches: + continue + hits.append( + { + "hash": resource_hash, + "path": path, + "metadata": metadata, + "matches": matches, + } + ) + if len(hits) >= limit: + break + + if not hits: + return _text("No indexed document contents matched that search query.") + + lines = [f"Found {len(hits)} matching document(s):\n"] + for item in hits: + metadata = item["metadata"] if isinstance(item["metadata"], dict) else {} + display = ( + metadata.get("display_name") + or metadata.get("file_name") + or metadata.get("title") + or str(item["path"]) + ) + source = metadata.get("source_type") or "unknown" + lines.append(f" • {display} [{source}] hash={item['hash']}") + lines.append(f" Path: {item['path']}") + for match in item["matches"]: + before_lines = match.get("before") or [] + after_lines = match.get("after") or [] + for ctx in 
before_lines: + lines.append(f" B: {_truncate_text(ctx, max_chars=240)}") + lines.append(f" L{match.get('line', '?')}: {_truncate_text(match.get('text'), max_chars=240)}") + for ctx in after_lines: + lines.append(f" A: {_truncate_text(ctx, max_chars=240)}") + + return _text("\n".join(lines)) + + +def _tool_get_document_chunks(arguments: Dict[str, Any], wrapper) -> Dict: + doc_hash = (arguments.get("document_hash") or "").strip() + if not doc_hash: + return _text("ERROR: 'document_hash' is required.") + + offset = _clamp_int(arguments.get("offset", 0), default=0, minimum=0, maximum=1_000_000) + limit = _clamp_int(arguments.get("limit", 20), default=20, minimum=1, maximum=100) + max_chars_per_chunk = _clamp_int( + arguments.get("max_chars_per_chunk", 600), + default=600, + minimum=80, + maximum=5000, + ) + + chunks = wrapper.chat.data_viewer.get_document_chunks(doc_hash) + if not chunks: + return _text(f"No stored chunks found for document: {doc_hash}") + + selected = chunks[offset: offset + limit] + lines = [ + f"Document {doc_hash} has {len(chunks)} chunk(s); showing {len(selected)} from offset {offset}:\n" + ] + for chunk in selected: + start_char = chunk.get("start_char") + end_char = chunk.get("end_char") + lines.append( + f" • chunk {chunk.get('index')} chars={start_char}-{end_char}\n" + f" {_truncate_text(chunk.get('text'), max_chars=max_chars_per_chunk)}" + ) + return _text("\n".join(lines)) + + +def _tool_get_data_stats(arguments: Dict[str, Any], wrapper) -> Dict: + conversation_id = arguments.get("conversation_id") + stats = wrapper.chat.data_viewer.get_stats(conversation_id) + by_source_type = stats.get("by_source_type") or {} + status_counts = stats.get("status_counts") or {} + + lines = [ + "Corpus statistics:", + f" Total documents: {stats.get('total_documents', 0)}", + f" Total chunks: {stats.get('total_chunks', 0)}", + f" Enabled documents: {stats.get('enabled_documents', 0)}", + f" Disabled documents: {stats.get('disabled_documents', 0)}", + f" Total 
size (bytes): {stats.get('total_size_bytes', 0)}", + f" Last sync: {stats.get('last_sync') or 'n/a'}", + "", + "Ingestion status:", + f" pending={status_counts.get('pending', 0)}", + f" embedding={status_counts.get('embedding', 0)}", + f" embedded={status_counts.get('embedded', 0)}", + f" failed={status_counts.get('failed', 0)}", + ] + + if by_source_type: + lines.append("") + lines.append("By source type:") + for source_type, counts in sorted(by_source_type.items()): + total = counts.get("total", 0) if isinstance(counts, dict) else counts + enabled = counts.get("enabled", total) if isinstance(counts, dict) else total + lines.append(f" {source_type}: total={total}, enabled={enabled}") + + return _text("\n".join(lines)) + + +def _tool_deployment_info(wrapper) -> Dict: + from src.utils.config_access import get_dynamic_config, get_full_config, get_static_config + + config = get_full_config() or {} + static = get_static_config() + services = config.get("services", {}) + chat_cfg = services.get("chat_app", {}) + dm_cfg = services.get("data_manager", {}) + mcp_servers = config.get("mcp_servers", {}) or {} + + try: + dynamic = get_dynamic_config() + except Exception: + dynamic = None + + lines = [ + f"# archi Deployment: {config.get('name', 'unknown')}", + "", + "## Active configuration", + f" Pipeline: {chat_cfg.get('pipeline', 'n/a')}", + f" Agent class: {chat_cfg.get('agent_class', chat_cfg.get('pipeline', 'n/a'))}", + ] + if dynamic: + lines += [ + f" Active agent: {dynamic.active_agent_name or getattr(getattr(wrapper.chat, 'agent_spec', None), 'name', 'n/a')}", + f" Model: {dynamic.active_model}", + f" Temperature: {dynamic.temperature}", + f" Max tokens: {dynamic.max_tokens}", + f" Docs retrieved (k): {dynamic.num_documents_to_retrieve}", + f" Hybrid search: {dynamic.use_hybrid_search}", + f" BM25 weight: {dynamic.bm25_weight}", + f" Semantic weight: {dynamic.semantic_weight}", + ] + else: + lines += [ + f" Active agent: {getattr(getattr(wrapper.chat, 'agent_spec', 
None), 'name', 'n/a')}", + ] + + embedding_cfg = dm_cfg.get("embedding", {}) + lines += [ + "", + "## Embedding", + f" Model: {embedding_cfg.get('model', 'n/a')}", + f" Chunk size: {embedding_cfg.get('chunk_size', 'n/a')}", + f" Chunk overlap: {embedding_cfg.get('chunk_overlap', 'n/a')}", + "", + "## Runtime", + f" Available providers: {', '.join(static.available_providers or []) or 'n/a'}", + f" Available pipelines: {', '.join(static.available_pipelines or []) or 'n/a'}", + f" MCP servers: {', '.join(sorted(mcp_servers.keys())) or 'none'}", + f" MCP endpoint enabled: {services.get('mcp_server', {}).get('enabled', False)}", + ] + return _text("\n".join(lines)) + + +def _tool_list_agents(wrapper) -> Dict: + agents_dir = wrapper._get_agents_dir() + lines = [] + for spec in _list_agent_specs(agents_dir): + tools_str = ", ".join(spec.get("tools", []) or []) or "none" + path = spec["path"] + lines.append(f" • {spec['name']} ({path.name})") + lines.append(f" Tools: {tools_str}") + + if not lines: + return _text("No agent specs found in this deployment.") + return _text("Available agents:\n" + "\n".join(lines)) + + +def _tool_get_agent_spec(arguments: Dict[str, Any], wrapper) -> Dict: + agent_name = (arguments.get("agent_name") or "").strip() + if not agent_name: + return _text("ERROR: 'agent_name' is required.") + + agents_dir = wrapper._get_agents_dir() + for spec in _list_agent_specs(agents_dir): + if spec["name"] == agent_name: + return _text(spec["content"]) + return _text(f"ERROR: Agent not found: {agent_name}") + + +# --------------------------------------------------------------------------- +# JSON-RPC dispatcher +# --------------------------------------------------------------------------- + + +def _dispatch(body: Dict, session_queue: queue.Queue, wrapper, user_id: Optional[str] = None) -> None: + """Process one incoming JSON-RPC message and enqueue the response if needed.""" + rpc_id = body.get("id") + method = body.get("method", "") + params = 
body.get("params") or {} + + # Notifications have no id – no response expected. + if rpc_id is None: + return + + if method == "initialize": + response = _ok( + { + "protocolVersion": _MCP_VERSION, + "capabilities": {"tools": {}}, + "serverInfo": _SERVER_INFO, + }, + rpc_id, + ) + elif method == "tools/list": + response = _ok({"tools": _TOOLS}, rpc_id) + elif method == "tools/call": + # Extract optional progress token from _meta so we can stream status + # events back to the client while archi works. + meta = params.get("_meta") or {} + progress_token = meta.get("progressToken") + + logger.info( + "tools/call %s – progressToken=%s", + params.get("name", "?"), + progress_token if progress_token is not None else "", + ) + + notify_fn: Optional[Callable[[str, Optional[int], Optional[int]], None]] = None + if progress_token is not None: + _progress_counter = [0] + + def _notify( + message: str, + progress: Optional[int] = None, + total: Optional[int] = None, + ) -> None: + _progress_counter[0] += 1 + p = progress if progress is not None else _progress_counter[0] + notification: Dict[str, Any] = { + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": { + "progressToken": progress_token, + "progress": p, + "message": message, + }, + } + if total is not None: + notification["params"]["total"] = total + session_queue.put(notification) + + notify_fn = _notify + + result = _call_tool( + params.get("name", ""), + params.get("arguments") or {}, + wrapper, + user_id, + notify=notify_fn, + ) + response = _ok(result, rpc_id) + elif method == "ping": + response = _ok({}, rpc_id) + else: + response = _err(-32601, f"Method not found: {method}", rpc_id) + + session_queue.put(response) + + +def _dispatch_and_release( + body: Dict, + session_queue: queue.Queue, + wrapper, + user_id: Optional[str] = None, +) -> None: + try: + _dispatch(body, session_queue, wrapper, user_id) + finally: + _dispatch_slots.release() + + +# 
--------------------------------------------------------------------------- +# Blueprint factory +# --------------------------------------------------------------------------- + + +def register_mcp_sse( + app, + wrapper, + pg_config: Optional[dict] = None, + auth_enabled: bool = False, + public_url: Optional[str] = None, +) -> None: + """Register the MCP SSE endpoints on a Flask app. + + ``public_url``: externally reachable base URL (e.g. ``https://example.com``). + When set, the ``endpoint`` SSE event uses it to build the absolute POST URL + instead of inferring from request headers. + """ + mcp = Blueprint("mcp_sse", __name__) + + def _auth_check(): + """Return (user_id, error_response) tuple. error_response is None on success.""" + if not auth_enabled: + return None, None + token = _extract_bearer_token(request) + if not token: + resp = jsonify({ + "error": "unauthorized", + "message": "MCP access requires a bearer token. " + "Visit /mcp/auth to generate one after logging in.", + "login_url": "/mcp/auth", + }) + resp.status_code = 401 + return None, resp + user_id = _validate_mcp_token(token, pg_config) + if not user_id: + resp = jsonify({ + "error": "invalid_token", + "message": "The bearer token is invalid or has expired. " + "Visit /mcp/auth to generate a new token.", + "login_url": "/mcp/auth", + }) + resp.status_code = 401 + return None, resp + return user_id, None + + @mcp.route("/mcp/sse") + def sse(): + """Open an SSE stream for one MCP client session.""" + user_id, err = _auth_check() + if err is not None: + return err + + session_id = uuid.uuid4().hex + q: queue.Queue = queue.Queue() + # Resolve base URL now — generators run outside request context. 
+ if public_url: + _base = public_url.rstrip("/") + else: + fwd_proto = request.headers.get("X-Forwarded-Proto") or request.scheme + fwd_host = request.headers.get("X-Forwarded-Host") or request.host + _base = f"{fwd_proto}://{fwd_host}" + post_url = f"{_base}/mcp/messages?session_id={session_id}" + with _sessions_lock: + _sessions[session_id] = {"queue": q, "user_id": user_id} + logger.info("MCP SSE session %s opened (user=%s)", session_id, user_id) + + def generate(): + yield f"event: endpoint\ndata: {post_url}\n\n" + try: + while True: + try: + msg = q.get(timeout=_KEEPALIVE_TIMEOUT) + if msg is None: + break + yield f"event: message\ndata: {json.dumps(msg)}\n\n" + except queue.Empty: + yield ": keepalive\n\n" + finally: + with _sessions_lock: + _sessions.pop(session_id, None) + + return Response( + stream_with_context(generate()), + content_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + "Connection": "keep-alive", + }, + ) + + @mcp.route("/mcp/messages", methods=["POST"]) + def messages(): + """Receive a JSON-RPC message from an MCP client.""" + # Auth was verified when the SSE stream was opened; session_id is the proof. + session_id = request.args.get("session_id", "") + with _sessions_lock: + session_entry = _sessions.get(session_id) + if session_entry is None: + logger.warning("MCP /mcp/messages: unknown session_id=%r (active sessions: %s)", + session_id, list(_sessions.keys())) + return {"error": "unknown or expired session_id"}, 404 + + q = session_entry["queue"] + user_id = session_entry.get("user_id") + + body = request.get_json(silent=True) + if not body: + return {"error": "request body must be valid JSON"}, 400 + + logger.info("MCP /mcp/messages session=%s method=%s id=%s", + session_id, body.get("method", "?"), body.get("id")) + # Use a bounded dispatch pool to prevent thread storms under high parallel load. 
+ if not _dispatch_slots.acquire(blocking=False): + logger.warning( + "MCP dispatch overloaded: rejecting request method=%s id=%s", + body.get("method", "?"), + body.get("id"), + ) + rpc_id = body.get("id") + if rpc_id is not None: + q.put(_err(-32001, "Server is busy. Please retry shortly.", rpc_id)) + return "", 202 + + try: + _dispatch_executor.submit(_dispatch_and_release, body, q, wrapper, user_id) + except Exception: + _dispatch_slots.release() + raise + return "", 202 + + app.register_blueprint(mcp) + if auth_enabled: + logger.info("Registered MCP SSE endpoint at /mcp/sse (auth required – Bearer token)") + else: + logger.info("Registered MCP SSE endpoint at /mcp/sse (no auth)") diff --git a/src/interfaces/chat_app/service_alerts.py b/src/interfaces/chat_app/service_alerts.py index 3cc92b5f7..f7114249e 100644 --- a/src/interfaces/chat_app/service_alerts.py +++ b/src/interfaces/chat_app/service_alerts.py @@ -34,6 +34,16 @@ _auth_enabled: bool = False _chat_app_config: dict = {} +# Cache for banner alerts to avoid a DB connection on every template render. +_alerts_cache: list = [] +_alerts_cache_until: datetime = datetime.min +_ALERTS_CACHE_TTL = timedelta(seconds=30) + + +def _invalidate_alerts_cache(): + global _alerts_cache_until + _alerts_cache_until = datetime.min + # --------------------------------------------------------------------------- # Helpers (public — called by the context processor in app.py) @@ -71,14 +81,20 @@ def is_alert_manager() -> bool: def get_active_banner_alerts() -> list: - """Return alerts that should appear in the page banner (non-expired, active).""" + """Return alerts that should appear in the page banner (non-expired, active). + + Results are cached for 30 seconds to avoid a DB round-trip on every render. 
+ """ + global _alerts_cache, _alerts_cache_until + if datetime.now() < _alerts_cache_until: + return _alerts_cache try: conn = psycopg2.connect(**_pg_config) cursor = conn.cursor() try: cursor.execute(SQL_LIST_ACTIVE_BANNER_ALERTS) rows = cursor.fetchall() - return [ + _alerts_cache = [ { 'id': row[0], 'severity': row[1], @@ -90,12 +106,14 @@ def get_active_banner_alerts() -> list: } for row in rows ] + _alerts_cache_until = datetime.now() + _ALERTS_CACHE_TTL + return _alerts_cache finally: cursor.close() conn.close() except Exception as exc: logger.warning("Failed to fetch banner alerts: %s", exc) - return [] + return _alerts_cache # serve stale cache on error rather than failing # --------------------------------------------------------------------------- @@ -193,6 +211,7 @@ def create_alert(): cursor.execute(SQL_SET_ALERT_EXPIRY, (expires_at, alert_id)) conn.commit() + _invalidate_alerts_cache() logger.info( "Service alert %d created by %s: [%s] %s", alert_id, created_by, severity, message, @@ -219,6 +238,8 @@ def delete_alert(alert_id: int): cursor.execute(SQL_DELETE_ALERT, (alert_id,)) deleted = cursor.rowcount > 0 conn.commit() + if deleted: + _invalidate_alerts_cache() finally: cursor.close() conn.close() diff --git a/src/interfaces/chat_app/static/chat.css b/src/interfaces/chat_app/static/chat.css index 775d0772c..baaf0075f 100644 --- a/src/interfaces/chat_app/static/chat.css +++ b/src/interfaces/chat_app/static/chat.css @@ -468,6 +468,16 @@ body { color: var(--error-text); } +/* Mattermost-originated conversations get a subtle left accent */ +.conversation-item.from-mattermost { + border-left: 2px solid var(--accent-color, #1565c0); + padding-left: 10px; +} + +.conversation-item.from-mattermost .conversation-item-icon { + color: var(--accent-color, #1565c0); +} + /* ----------------------------------------------------------------------------- User Profile Widget (Bottom of Sidebar) 
----------------------------------------------------------------------------- */ diff --git a/src/interfaces/chat_app/static/chat.js b/src/interfaces/chat_app/static/chat.js index 187f7b4dc..2cb8c29f8 100644 --- a/src/interfaces/chat_app/static/chat.js +++ b/src/interfaces/chat_app/static/chat.js @@ -1777,13 +1777,22 @@ const UI = { for (const conv of items) { const isActive = conv.conversation_id === activeId; const title = Utils.escapeHtml(conv.title || `Conversation ${conv.conversation_id}`); - + const isMattermost = conv.archi_service === 'mattermost'; + + // Icon: Mattermost grid for MM conversations, chat bubble for web-chat + const icon = isMattermost + ? `` + : ``; + html += ` -
- +
+ ${icon} ${title} + + + +
+ +
+

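All of the client configurations below ultimately perform the same handshake against the built-in endpoint: open `GET /mcp/sse`, read the `endpoint` SSE event to learn the per-session POST URL, then exchange JSON-RPC messages, with responses arriving back as `message` events on the stream. A minimal client-side sketch of that framing, assuming a placeholder deployment URL; the `protocolVersion` value is also an assumption, since the server advertises its own `_MCP_VERSION`:

```python
import json


def parse_sse_frame(frame: str):
    """Split one SSE frame ("event: ...", "data: ...", blank line) into (event, data)."""
    event, data = None, []
    for line in frame.strip().splitlines():
        if line.startswith("event:"):
            event = line[len("event:"):].strip()
        elif line.startswith("data:"):
            data.append(line[len("data:"):].strip())
    return event, "\n".join(data)


def build_initialize(rpc_id: int = 1) -> dict:
    """First JSON-RPC request a client POSTs to the per-session URL."""
    return {
        "jsonrpc": "2.0",
        "id": rpc_id,
        "method": "initialize",
        "params": {
            "protocolVersion": "2024-11-05",  # assumed; the server reports its own version
            "capabilities": {},
            "clientInfo": {"name": "sketch-client", "version": "0.0.1"},
        },
    }


# The first frame on the stream names the POST target for this session:
frame = "event: endpoint\ndata: https://chat.example.org/mcp/messages?session_id=abc123\n\n"
event, post_url = parse_sse_frame(frame)
# A real client would now POST build_initialize() to post_url and read the
# matching JSON-RPC response off the SSE stream as an "event: message" frame.
```

The POST itself only returns `202 Accepted`; the actual response is delivered on the SSE stream, which is why every client below must hold the `/mcp/sse` connection open for the lifetime of the session.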
+ File: .vscode/mcp.json (in your project root) or + settings.json under mcp.servers. + VS Code will open a browser to log in on first use — no token needed in the config. +

+
{
+  "servers": {
+    "archi": {
+      "type": "sse",
+      "url": "{{ mcp_url }}"
+    }
+  }
+}
+ +
+ +
+

+ File: ~/.cursor/mcp.json. + Cursor will open a browser to log in on first use — no token needed in the config. +

+
{
+  "mcpServers": {
+    "archi": {
+      "url": "{{ mcp_url }}"
+    }
+  }
+}
+ +
+ +
+

+ File: ~/Library/Application Support/Claude/claude_desktop_config.json (macOS) + or %APPDATA%\Claude\claude_desktop_config.json (Windows). + Restart Claude Desktop after saving. mcp-remote will open a browser to + log in on first use and store the token locally — no token needed in the config. +

+
{
+  "mcpServers": {
+    "archi": {
+      "command": "npx",
+      "args": [
+        "mcp-remote",
+        "{{ mcp_url }}"
+      ]
+    }
+  }
+}
+ +
+ +
+

+ Run this once in your terminal to register archi globally. Claude Code stores it in + ~/.claude.json. It will open a browser to log in on first use. +

+
claude mcp add --transport sse archi {{ mcp_url }}
+ +

+ Or add it to .mcp.json in your project root for project-scoped access: +

+
{
+  "mcpServers": {
+    "archi": {
+      "type": "sse",
+      "url": "{{ mcp_url }}"
+    }
+  }
+}
+ +
+
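A client using the manual bearer token described below sends it on the initial `GET /mcp/sse` request; subsequent POSTs are authorized by the `session_id` in the returned endpoint URL. Separately, when the server's bounded dispatch pool is saturated it answers with JSON-RPC error `-32001` ("Server is busy"), so a polite client backs off and retries. A hedged sketch of both pieces; the token value is a placeholder and the backoff policy is an assumption, not part of the server contract:

```python
import time
from typing import Callable, Optional


def auth_headers(token: str) -> dict:
    # Exactly the header shape the /mcp/sse auth check extracts the token from.
    return {"Authorization": f"Bearer {token}"}


def is_busy_error(response: dict) -> bool:
    """True when the server rejected the call with its 'Server is busy' code."""
    error = response.get("error")
    return isinstance(error, dict) and error.get("code") == -32001


def call_with_backoff(send: Callable[[], dict], max_retries: int = 3) -> Optional[dict]:
    """Retry a JSON-RPC call with exponential backoff while the server is busy."""
    for attempt in range(max_retries + 1):
        response = send()
        if not is_busy_error(response):
            return response
        if attempt < max_retries:
            time.sleep(2 ** attempt)  # 1 s, 2 s, 4 s between attempts
    return None  # still busy after all retries
```

`auth_headers(token)` would be passed to the SSE GET; the POST requests need no token of their own because a valid `session_id` proves the stream was opened with working credentials.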
+ + +
+ + Advanced: manual bearer token + +
+

+ For clients that don't support OAuth, use this static token directly in the + Authorization: Bearer <token> header. + Keep it secret — it grants access to archi on your behalf. +

+
+
{{ token }}
+ +
+
+

Rotate token

+

+ Regenerate a new token. Your current token will stop working immediately. + OAuth sessions are unaffected. +

+
+ +
+
+
+
+ + + + + diff --git a/src/interfaces/mattermost.py b/src/interfaces/mattermost.py index 4d91ec9fa..2661cca11 100644 --- a/src/interfaces/mattermost.py +++ b/src/interfaces/mattermost.py @@ -1,95 +1,769 @@ import json import os -import time -from threading import Thread +import threading +from pathlib import Path +from typing import Dict, List, Optional, Tuple import requests -from flask import Flask +from authlib.integrations.flask_client import OAuth +from flask import Flask, request as flask_request, jsonify, session, url_for from src.archi.archi import archi +from src.archi.pipelines.agents.agent_spec import AgentSpecError, select_agent_spec from src.data_manager.data_manager import DataManager from src.utils.env import read_secret from src.utils.logging import get_logger from src.utils.config_access import get_full_config +from src.utils.mattermost_auth import MattermostAuthManager +from src.utils.mattermost_token_service import MattermostTokenService +from src.utils.rbac.jwt_parser import get_user_roles +from src.utils.rbac.mattermost_context import get_mattermost_context, mattermost_user_context +from src.utils.rbac.registry import get_registry +from src.utils.rbac.permission_enum import Permission logger = get_logger(__name__) + +# --------------------------------------------------------------------------- +# MattermostClient — stateless REST API wrapper +# --------------------------------------------------------------------------- + +class MattermostClient: + """Stateless HTTP wrapper for the Mattermost REST API. + + Requires a Personal Access Token (PAK). All methods raise + requests.HTTPError on unexpected status codes; send_typing is + best-effort and swallows exceptions. 
+ """ + + def __init__(self, base_url: str, personal_access_token: str): + self._base = base_url.rstrip('/') + self._headers = { + 'Authorization': f'Bearer {personal_access_token}', + 'Content-Type': 'application/json', + } + + def create_post(self, channel_id: str, message: str, root_id: str = "") -> dict: + """Create a post. Pass root_id to make it a thread reply.""" + payload: dict = {"channel_id": channel_id, "message": message} + if root_id: + payload["root_id"] = root_id + r = requests.post(f"{self._base}/api/v4/posts", json=payload, headers=self._headers) + r.raise_for_status() + return r.json() + + def get_thread(self, post_id: str) -> dict: + """GET /api/v4/posts/{post_id}/thread — full thread with ordered post list.""" + r = requests.get(f"{self._base}/api/v4/posts/{post_id}/thread", headers=self._headers) + r.raise_for_status() + return r.json() + + def get_channel_posts(self, channel_id: str, per_page: int = 60, before: str = "") -> dict: + """GET /api/v4/channels/{channel_id}/posts""" + params: dict = {"per_page": per_page} + if before: + params["before"] = before + r = requests.get( + f"{self._base}/api/v4/channels/{channel_id}/posts", + params=params, + headers=self._headers, + ) + r.raise_for_status() + return r.json() + + def get_me(self) -> dict: + """GET /api/v4/users/me — fetch the bot's own user info.""" + r = requests.get(f"{self._base}/api/v4/users/me", headers=self._headers) + r.raise_for_status() + return r.json() + + def add_reaction(self, user_id: str, post_id: str, emoji_name: str) -> None: + """POST /api/v4/reactions""" + payload = {"user_id": user_id, "post_id": post_id, "emoji_name": emoji_name} + r = requests.post(f"{self._base}/api/v4/reactions", json=payload, headers=self._headers) + r.raise_for_status() + + def delete_reaction(self, user_id: str, post_id: str, emoji_name: str) -> None: + """DELETE /api/v4/users/{user_id}/posts/{post_id}/reactions/{emoji_name}""" + r = requests.delete( + 
f"{self._base}/api/v4/users/{user_id}/posts/{post_id}/reactions/{emoji_name}", + headers=self._headers, + ) + if r.status_code not in (200, 204, 404): + r.raise_for_status() + + def send_typing(self, channel_id: str, parent_id: str = "") -> None: + """POST /api/v4/users/me/typing — best-effort, swallows all exceptions.""" + payload: dict = {"channel_id": channel_id} + if parent_id: + payload["parent_id"] = parent_id + try: + requests.post( + f"{self._base}/api/v4/users/me/typing", + json=payload, + headers=self._headers, + timeout=3, + ) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# ThreadContextManager — per-thread conversation history via PostgreSQL +# --------------------------------------------------------------------------- + +class ThreadContextManager: + """Per-thread conversation history backed by PostgreSQL ConversationService. + + Mattermost messages are stored in the same `conversations` table as web-chat + messages, keyed by a proper INTEGER `conversation_id` from `conversation_metadata`. + A stable string `source_ref` (e.g. "mm_thread_") is stored in + `conversation_metadata.source_ref` so the same integer id is reused across restarts. + + The `client_id` stored in conversation_metadata is "mm_user_", which + the web-chat `list_conversations` endpoint uses to surface Mattermost conversations + in the user's sidebar. 
+ """ + + def __init__(self, conversation_service, bot_user_id: str = "", context_window: int = 20): + self._svc = conversation_service + self._bot_user_id = bot_user_id + self._context_window = context_window + # In-process cache: source_ref → integer conversation_id (avoids repeated DB lookups) + self._id_cache: Dict[str, int] = {} + + # ------------------------------------------------------------------ + # Source-ref helpers (stable string keys used by the event handler) + # ------------------------------------------------------------------ + + @staticmethod + def source_ref_for_thread(root_post_id: str) -> str: + return f"mm_thread_{root_post_id}" + + @staticmethod + def source_ref_for_user_channel(channel_id: str, user_id: str) -> str: + """Fallback key for non-threaded messages: per-user per-channel conversation.""" + return f"mm_channel_{channel_id}_user_{user_id}" + + @staticmethod + def mm_client_id(username: str) -> str: + """Bridge key written into conversation_metadata.client_id. + + The web-chat list_conversations query includes ``client_id = 'mm_user_'`` + so Mattermost conversations appear automatically in the authenticated user's sidebar. + """ + return f"mm_user_{username}" if username else "" + + # ------------------------------------------------------------------ + # Integer ID resolution (creates conversation_metadata row on first use) + # ------------------------------------------------------------------ + + def _get_or_create_int_id( + self, source_ref: str, username: str = "", title: str = "" + ) -> Optional[int]: + """Return the integer conversation_id for *source_ref*, creating if needed. + + Results are cached in-process to avoid redundant DB round-trips. + Returns None only on DB error (caller falls back to stateless mode). 
+ """ + if source_ref in self._id_cache: + return self._id_cache[source_ref] + try: + client_id = self.mm_client_id(username) + int_id = self._svc.get_or_create_conversation_for_ref( + source_ref=source_ref, + mm_client_id=client_id, + title=title, + ) + self._id_cache[source_ref] = int_id + return int_id + except Exception as e: + logger.warning("ThreadContextManager: could not resolve int id for %r: %s", source_ref, e) + return None + + # ------------------------------------------------------------------ + # History retrieval + # ------------------------------------------------------------------ + + def build_history_from_db( + self, source_ref: str, username: str = "", title: str = "" + ) -> Tuple[Optional[int], List[Tuple[str, str]]]: + """Return (int_conv_id, history) for *source_ref*. + + int_conv_id is the INTEGER primary key of the conversation_metadata row + (used for storing the exchange). history is a list of (role, content) tuples. + """ + int_id = self._get_or_create_int_id(source_ref, username=username, title=title) + if int_id is None: + return None, [] + try: + messages = self._svc.get_conversation_history(int_id, limit=self._context_window * 2) + except Exception as e: + logger.warning("ThreadContextManager: failed to fetch DB history for id=%s: %s", int_id, e) + return int_id, [] + history = [ + ("User" if m.sender == "user" else "AI", m.content) + for m in messages + ] + return int_id, history + + def build_history_from_thread(self, thread_data: dict) -> List[Tuple[str, str]]: + """Cold-start: build history from a live Mattermost thread API response.""" + order = thread_data.get("order", []) + posts = thread_data.get("posts", {}) + history = [] + for post_id in order: + post = posts.get(post_id, {}) + if post.get("type", ""): # skip system messages + continue + content = post.get("message", "").strip() + if not content: + continue + role = "AI" if post.get("user_id") == self._bot_user_id else "User" + history.append((role, content)) + return 
history[-(self._context_window * 2):] + + # ------------------------------------------------------------------ + # Storage + # ------------------------------------------------------------------ + + def store_exchange( + self, + conv_int_id: int, + user_content: str, + bot_content: str, + model_used: Optional[str] = None, + pipeline_used: Optional[str] = None, + ) -> None: + """Persist the user message and bot reply to PostgreSQL. + + Uses the INTEGER conv_int_id so the insert respects the FK constraint + on conversations.conversation_id → conversation_metadata.conversation_id. + """ + from datetime import datetime, timezone + from src.utils.conversation_service import Message + now = datetime.now(timezone.utc) + try: + self._svc.insert_messages([ + Message( + conversation_id=conv_int_id, + sender="user", + content=user_content, + ts=now, + archi_service="mattermost", + ), + Message( + conversation_id=conv_int_id, + sender="assistant", + content=bot_content, + ts=now, + model_used=model_used, + pipeline_used=pipeline_used, + archi_service="mattermost", + ), + ]) + # Keep last_message_at fresh so the conversation floats to the top + # of the web-chat sidebar. + self._svc.update_conversation_timestamp_for_ref(conv_int_id) + except Exception as e: + logger.warning("ThreadContextManager: failed to store exchange: %s", e) + + +# --------------------------------------------------------------------------- +# MattermostEventHandler — orchestrates auth → AI → thread reply for one post +# --------------------------------------------------------------------------- + +class MattermostEventHandler: + """Handles a single Mattermost post end-to-end. + + Used by both MattermostWebhookServer and Mattermost (polling) so that + auth, AI, and reply logic is never duplicated. 
+ + post_data dict keys: id, channel_id, root_id, user_id, username, message + """ + + def __init__( + self, + ai_wrapper, + auth_manager: MattermostAuthManager, + auth_enabled: bool, + bot_user_id: str = "", + mm_client: Optional[MattermostClient] = None, + thread_ctx: Optional[ThreadContextManager] = None, + webhook_url: Optional[str] = None, + reactions: Optional[dict] = None, + chat_base_url: str = "", + ): + self._ai = ai_wrapper + self._auth_manager = auth_manager + self._auth_enabled = auth_enabled + self._bot_user_id = bot_user_id + self._client = mm_client + self._thread_ctx = thread_ctx + self._webhook_url = webhook_url + self._webhook_headers = {'Content-Type': 'application/json'} + self._reactions = reactions or { + "processing": "eyes", + "done": "white_check_mark", + "error": "x", + } + self._chat_base_url = chat_base_url.rstrip("/") + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _send_message(self, channel_id: str, message: str, root_id: str = "") -> None: + """Post a message, preferring thread-aware REST API, falling back to incoming webhook.""" + if self._client: + try: + self._client.create_post(channel_id, message, root_id=root_id) + return + except Exception as e: + logger.warning("MattermostClient.create_post failed, falling back to webhook: %s", e) + if self._webhook_url: + try: + requests.post( + self._webhook_url, + data=json.dumps({"text": message}), + headers=self._webhook_headers, + ) + except Exception as e: + logger.error("Webhook fallback also failed: %s", e) + + def _add_reaction(self, post_id: str, emoji: str) -> None: + if self._client and self._bot_user_id: + try: + self._client.add_reaction(self._bot_user_id, post_id, emoji) + except Exception as e: + logger.warning("Could not add reaction %s: %s", emoji, e) + + def _delete_reaction(self, post_id: str, emoji: str) -> None: + if self._client and 
self._bot_user_id: + try: + self._client.delete_reaction(self._bot_user_id, post_id, emoji) + except Exception as e: + logger.warning("Could not delete reaction %s: %s", emoji, e) + + def _send_typing(self, channel_id: str, root_id: str = "") -> None: + if self._client: + self._client.send_typing(channel_id, parent_id=root_id) + + def _notify_mcp_auth_needed(self, channel_id: str, username: str, root_id: str = "") -> None: + """If any SSO-auth MCP servers lack a valid token for this user, send a one-line notice. + + Non-blocking: the AI call proceeds regardless — tools on unauthorized servers are simply + unavailable for this message. The notice is only sent when chat_base_url is configured. + """ + if not self._chat_base_url or not username: + return + try: + from src.utils.mcp_oauth_service import MCPOAuthService + from src.utils.config_access import get_mcp_servers_config + mcp_servers = get_mcp_servers_config() + if not mcp_servers: + return + oauth_svc = MCPOAuthService() + needing = oauth_svc.get_servers_needing_auth(username, mcp_servers) + if needing: + links = " ".join( + f"[Authorize {n}]({self._chat_base_url}/mcp/authorize?server={n})" + for n in needing + ) + self._send_message( + channel_id, + f":key: Some tools require authorization. Please visit the web chat to grant access: {links}", + root_id=root_id, + ) + except Exception as e: + logger.debug("MCP auth-needed check failed (non-fatal): %s", e) + + def _call_ai_with_retry(self, history: List[Tuple[str, str]], ctx, max_retries: int = 2) -> str: + """Call archi with exponential-backoff retry on transient errors (e.g. OpenAI 500). + + Raises on the last attempt so the caller's except block can send the user an error. 
+ """ + import time + last_exc: Exception = RuntimeError("no attempts made") + for attempt in range(max_retries + 1): + try: + with mattermost_user_context(ctx): + result = self._ai.archi( + history=history, + user_id=ctx.username or None, + ) + return result["answer"] + except Exception as e: + last_exc = e + if attempt < max_retries: + wait = 2 ** attempt # 1 s, 2 s + logger.warning( + "MattermostEventHandler: AI call failed (attempt %d/%d), " + "retrying in %ds: %s", + attempt + 1, max_retries + 1, wait, e, + ) + time.sleep(wait) + raise last_exc + + def _resolve_source_ref( + self, post_id: str, root_id: str, channel_id: str, user_id: str + ) -> str: + """Return the stable source_ref string for this post. + + - Thread reply (root_id ≠ post_id): scoped to that thread so everyone + in the thread shares context, and the web-chat user can continue there. + - Root / non-threaded post: per-user-per-channel key so follow-up + messages still carry history. + """ + if self._thread_ctx is None: + return "" + if root_id and root_id != post_id: + return ThreadContextManager.source_ref_for_thread(root_id) + return ThreadContextManager.source_ref_for_user_channel(channel_id, user_id) + + def _build_history( + self, + post_id: str, + root_id: str, + channel_id: str, + user_id: str, + username: str, + message: str, + ) -> Tuple[Optional[int], List[Tuple[str, str]]]: + """Return (conv_int_id, history) for this post. + + conv_int_id is the INTEGER primary key of the conversation_metadata row + (None when ThreadContextManager is unavailable). History is built from + the DB first; falls back to the live Mattermost thread API on cold-start. 
+ """ + if self._thread_ctx is None: + return None, [("User", message)] + + source_ref = self._resolve_source_ref(post_id, root_id, channel_id, user_id) + conv_int_id, history = self._thread_ctx.build_history_from_db( + source_ref, username=username, title=message[:50] + ) + + # Cold-start: seed history from live Mattermost thread (PAK required) + if not history and root_id and root_id != post_id and self._client: + try: + thread_data = self._client.get_thread(root_id) + history = self._thread_ctx.build_history_from_thread(thread_data) + except Exception as e: + logger.warning("Cold-start thread fetch failed: %s", e) + + return conv_int_id, history + [("User", message)] + + # ------------------------------------------------------------------ + # Main entry point + # ------------------------------------------------------------------ + + def handle(self, post_data: dict) -> None: + """Full handling pipeline for one Mattermost post. + + Steps: self-filter → auth → RBAC → typing → 👀 → build history + → call AI → post reply in thread → store → ✅/❌ + """ + post_id = post_data.get("id", "") + channel_id = post_data.get("channel_id", "") + # root_id for the thread — if empty this is a root post, reply starts a thread + root_id = post_data.get("root_id") or post_id + user_id = post_data.get("user_id", "") + username = post_data.get("username", "") or post_data.get("user_name", "") + message = post_data.get("message", "").strip() + + # 1. Skip the bot's own posts (prevents infinite-loop on self-replies) + if self._bot_user_id and user_id == self._bot_user_id: + return + + # 2. Auth: build user context (None → no stored token → prompt login) + ctx = self._auth_manager.build_context(user_id=user_id, username=username) + if ctx is None: + login_url = self._auth_manager.login_url(user_id, username) + self._send_message( + channel_id, + f"Hi @{username}! 
To use this bot, please login first: {login_url}\n" + "After logging in, send your message again.", + root_id=root_id, + ) + return + + # 3. RBAC permission check + if self._auth_enabled: + registry = get_registry() + if not registry.has_permission(ctx.roles, Permission.Mattermost.ACCESS): + logger.info( + "MattermostEventHandler: access denied for user_id=%r (roles=%s)", + user_id, ctx.roles, + ) + self._send_message( + channel_id, + "Sorry, you don't have permission to use this bot. Please contact an administrator.", + root_id=root_id, + ) + return + + # 3.5. Notify user about MCP servers that still need OAuth authorization. + # This is informational-only — the AI call proceeds with available tools. + self._notify_mcp_auth_needed(channel_id, ctx.username, root_id=root_id) + + logger.info( + "MattermostEventHandler: message from @%s (id=%s, channel=%s): %r", + username, user_id, channel_id, message, + ) + + # 4. Typing indicator (best-effort, before the long AI call) + self._send_typing(channel_id, root_id=root_id) + + # 5. 👀 reaction to acknowledge receipt + self._add_reaction(post_id, self._reactions["processing"]) + + try: + # 6. Build conversation history (DB primary, cold-start from thread API). + # conv_int_id is the INTEGER pk of conversation_metadata (None = no storage). + conv_int_id, history = self._build_history( + post_id, root_id, channel_id, user_id, username, message + ) + + # 7. Call AI pipeline — retry up to 2x on transient upstream errors (e.g. OpenAI 500) + answer = self._call_ai_with_retry(history, ctx) + logger.debug("MattermostEventHandler: ANSWER = %s", answer) + + # 8. Post reply in thread + self._send_message(channel_id, answer, root_id=root_id) + + # 9. Store exchange in PostgreSQL (conv_int_id is None without storage layer) + if self._thread_ctx and conv_int_id is not None: + self._thread_ctx.store_exchange(conv_int_id, message, answer) + + # 10a. 
✅ reaction, remove 👀 + self._delete_reaction(post_id, self._reactions["processing"]) + self._add_reaction(post_id, self._reactions["done"]) + + except Exception as e: + logger.error( + "MattermostEventHandler: failed to handle post %s: %s", post_id, e, exc_info=True + ) + try: + self._send_message( + channel_id, + "Sorry, I encountered an error processing your message. Please try again.", + root_id=root_id, + ) + except Exception: + pass + # 10b. ❌ reaction, remove 👀 + self._delete_reaction(post_id, self._reactions["processing"]) + self._add_reaction(post_id, self._reactions["error"]) + + +# --------------------------------------------------------------------------- +# MattermostAIWrapper — initializes and calls the archi pipeline +# --------------------------------------------------------------------------- + class MattermostAIWrapper: def __init__(self): # initialize and update vector store self.data_manager = DataManager(run_ingestion=False) - # intialize chain - self.archi = archi() - - def __call__(self, post): + # initialize chain + config = get_full_config() + services_cfg = config.get("services", {}) + mm_cfg = services_cfg.get("mattermost", {}) + chat_cfg = services_cfg.get("chat_app", {}) + agent_class = mm_cfg.get("agent_class") or chat_cfg.get("agent_class", "QAPipeline") + agents_dir = mm_cfg.get("agents_dir") or chat_cfg.get("agents_dir") + agent_spec = None + if agents_dir: + try: + agent_spec = select_agent_spec(Path(agents_dir)) + except AgentSpecError as exc: + logger.warning(f"Failed to load agent spec: {exc}") + agent_spec = None + prompt_overrides = mm_cfg.get("prompts", {}) + self.archi = archi( + pipeline=agent_class, + agent_spec=agent_spec, + default_provider=mm_cfg.get("default_provider") or chat_cfg.get("default_provider"), + default_model=mm_cfg.get("default_model") or chat_cfg.get("default_model"), + prompt_overrides=prompt_overrides, + ) - # form the formatted history using the post - formatted_history = [] + def call_with_history( + 
self, history: List[Tuple[str, str]], user_id: Optional[str] = None + ) -> str: + """Call archi with explicit multi-turn history. Returns answer string.""" + answer = self.archi(history=history, user_id=user_id)["answer"] + logger.debug('ANSWER = %s', answer) + return answer + def __call__(self, post): + # Build single-turn history from post dict (backward-compatible path) post_str = post['message'] - formatted_history.append(("User", post_str)) + formatted_history = [("User", post_str)] - # call chain - answer = self.archi(formatted_history)["answer"] - logger.debug('ANSWER = ',answer) + # Resolve user_id for MCP tool authentication. + # ctx.username matches the CERN SSO preferred_username / sub used as + # the key in mcp_oauth_tokens, enabling cmspnr tools for authed users. + user_id = None + try: + mm_ctx = get_mattermost_context() + if mm_ctx is not None: + user_id = mm_ctx.username or None + except Exception: + pass + answer = self.archi(history=formatted_history, user_id=user_id)["answer"] + logger.debug('ANSWER = %s', answer) return answer, post_str + +# --------------------------------------------------------------------------- +# Shared setup helpers +# --------------------------------------------------------------------------- + +def _derive_chat_base_url(config: dict) -> str: + """Derive the web-chat base URL from config so Mattermost can send MCP auth links. + + Uses services.mcp_server.url (already the public host) and appends + services.chat_app.external_port. Falls back to services.mattermost.auth.login_base_url + with the chat port substituted, then empty string if nothing is available. 
+ """ + from urllib.parse import urlparse + try: + services = config.get("services", {}) + chat_port = int(services.get("chat_app", {}).get("external_port", 0)) + mcp_srv_url = services.get("mcp_server", {}).get("url", "") + if mcp_srv_url and chat_port: + parsed = urlparse(mcp_srv_url) + return f"{parsed.scheme}://{parsed.hostname}:{chat_port}" + # Fallback: derive from Mattermost login_base_url by swapping port + login_url = services.get("mattermost", {}).get("auth", {}).get("login_base_url", "") + if login_url and chat_port: + parsed = urlparse(login_url) + return f"{parsed.scheme}://{parsed.hostname}:{chat_port}" + except Exception: + pass + return "" + + +def _build_mm_client_and_context( + mm_config: dict, + mattermost_url: str, +) -> Tuple[Optional[MattermostClient], Optional[ThreadContextManager], str]: + """Create MattermostClient + ThreadContextManager if PAK is available. + + Returns (mm_client, thread_ctx, bot_user_id). + bot_user_id comes from config, or is auto-fetched from the API if blank. + """ + pak = read_secret("MATTERMOST_PAK") + bot_user_id = mm_config.get("bot_user_id", "") + + mm_client: Optional[MattermostClient] = None + if pak: + mm_client = MattermostClient(mattermost_url, pak) + # Auto-fetch bot user ID if not explicitly configured + if not bot_user_id: + try: + me = mm_client.get_me() + bot_user_id = me.get("id", "") + logger.info("MattermostClient: auto-fetched bot_user_id=%r", bot_user_id) + except Exception as e: + logger.warning("MattermostClient: could not auto-fetch bot user ID: %s", e) + + # ThreadContextManager is always created when PostgreSQL is available — + # it handles conversation storage/continuity independently of the PAK. 
+ thread_ctx: Optional[ThreadContextManager] = None + try: + from src.utils.postgres_service_factory import PostgresServiceFactory + factory = PostgresServiceFactory.get_instance() + if factory is not None: + thread_ctx = ThreadContextManager( + conversation_service=factory.conversation_service, + bot_user_id=bot_user_id, + context_window=int(mm_config.get("context_window", 20)), + ) + logger.info( + "MattermostClient: thread context manager initialised " + "(context_window=%s, pak=%s)", + mm_config.get("context_window", 20), + "yes" if pak else "no", + ) + except Exception as e: + logger.warning("Could not initialise ThreadContextManager: %s", e) + + return mm_client, thread_ctx, bot_user_id + + +# --------------------------------------------------------------------------- +# Mattermost — polling mode +# --------------------------------------------------------------------------- + class Mattermost: """ - Class to go through unresolved posts in Mattermost and suggest answers. - Filter feed for new posts and propose answers. - Also filter for new posts that have been resolved and add to vector store. - For now, just iterate through all posts and send replies for unresolved. + Polling-based Mattermost integration. + Periodically fetches new posts from a channel and replies via the event handler. 
""" def __init__(self): - logger.info('Mattermost::INIT') - self.mattermost_config = get_full_config().get("utils", {}).get("mattermost", None) - - # mattermost webhook for reading questions/sending responses - self.mattermost_url = 'https://mattermost.web.cern.ch/' + config = get_full_config() + self.mattermost_config = config.get("services", {}).get("mattermost", {}) + mm_config = self.mattermost_config or {} + + # Auth setup + auth_config = mm_config.get("auth", {}) + pg_config = { + "password": read_secret("PG_PASSWORD"), + **config.get("services", {}).get("postgres", {}), + } + auth_manager = MattermostAuthManager(auth_config, pg_config=pg_config) + auth_enabled = auth_config.get("enabled", False) + + # Mattermost connection details + self.mattermost_url = mm_config.get("base_url", "https://mattermost.web.cern.ch/") self.mattermost_webhook = read_secret("MATTERMOST_WEBHOOK") self.mattermost_channel_id_read = read_secret("MATTERMOST_CHANNEL_ID_READ") self.mattermost_channel_id_write = read_secret("MATTERMOST_CHANNEL_ID_WRITE") - self.PAK = read_secret("MATTERMOST_PAK") + self.PAK = read_secret("MATTERMOST_PAK") self.mattermost_headers = { 'Authorization': f'Bearer {self.PAK}', - 'Content-Type': 'application/json' + 'Content-Type': 'application/json', } - logger.debug('mattermost_webhook =', self.mattermost_webhook) - logger.debug('mattermost_channel_id_read =', self.mattermost_channel_id_read) - logger.debug('mattermost_channel_id_write =', self.mattermost_channel_id_write) - logger.debug('PAK =', self.PAK) + logger.debug('mattermost_webhook = %s', self.mattermost_webhook) + logger.debug('mattermost_channel_id_read = %s', self.mattermost_channel_id_read) + logger.debug('mattermost_channel_id_write = %s', self.mattermost_channel_id_write) + logger.debug('PAK = %s', self.PAK) - # initialize MattermostAIWrapper - self.ai_wrapper = MattermostAIWrapper() + # Tracking file for deduplication (config-driven, safe default) + self.min_next_post_file = mm_config.get( + 
"tracking_file", "/root/data/mattermost/answered_posts.json" + ) - # - self.min_next_post_file = "/root/data/LPC2025/min_next_post.json" + # AI wrapper + ai_wrapper = MattermostAIWrapper() - def post_response(self, response): + # MattermostClient + ThreadContextManager (PAK-gated) + mm_client, thread_ctx, bot_user_id = _build_mm_client_and_context( + mm_config, self.mattermost_url + ) -# TODO: support writing in a dedicated mattermost_channel_id_write -# url = f"{self.mattermost_url}/api/v4/posts" -# print('GOING TO WRITE HERE: ',url) - -# payload = { -# "channel_id": self.mattermost_channel_id_write, -# "message": response -# } - # send response to MM - #r = requests.post(url, data=json.dumps(payload), headers=self.mattermost_headers) - r = requests.post(self.mattermost_webhook, data=json.dumps({"text": response,"channel" : "town-square"}), headers=self.mattermost_headers) - - return + # Event handler — consolidates all auth + AI + reply logic + self.event_handler = MattermostEventHandler( + ai_wrapper=ai_wrapper, + auth_manager=auth_manager, + auth_enabled=auth_enabled, + bot_user_id=bot_user_id, + mm_client=mm_client, + thread_ctx=thread_ctx, + webhook_url=self.mattermost_webhook, + reactions=mm_config.get("reactions"), + chat_base_url=_derive_chat_base_url(config), + ) def write_min_next_post(self, answered_key): try: - # create directory if it does not exist os.makedirs(os.path.dirname(self.min_next_post_file), exist_ok=True) with open(self.min_next_post_file, "w") as f: json.dump({"answered_id": answered_key}, f) @@ -98,12 +772,11 @@ def write_min_next_post(self, answered_key): logger.debug(f"ERROR - Failed to write answered_key to file: {e}") def get_active_posts(self): - content = f"api/v4/channels/{self.mattermost_channel_id_read}/posts" r = requests.get(self.mattermost_url + content, headers=self.mattermost_headers) - active_posts={} + active_posts = {} for id in r.json()["order"]: - active_posts[id]=r.json()["posts"][id]["message"] + active_posts[id] = 
r.json()["posts"][id]["message"] return active_posts def filter_posts(self, posts, excluded_user_id): @@ -115,30 +788,25 @@ def filter_posts(self, posts, excluded_user_id): "system_leave_channel", "system_remove_from_channel", } - filtered = [] - for post in posts.values(): if post.get("user_id") == excluded_user_id: - continue # Skip this user - + continue if post.get("type") in system_types: - continue # Skip system messages - + continue filtered.append(post) - return filtered def get_last_post(self): - content = f"api/v4/channels/{self.mattermost_channel_id_read}/posts" - r = requests.get(self.mattermost_url + content, headers=self.mattermost_headers) data = r.json() posts = data.get('posts', {}) - excluded_archi_id = "ajb6wyizpinqir7m16owntod7o" - filtered_posts = self.filter_posts(posts, excluded_user_id=excluded_archi_id) + # bot_user_id comes from event_handler (config-driven or auto-fetched) + excluded_bot_id = self.event_handler._bot_user_id + + filtered_posts = self.filter_posts(posts, excluded_user_id=excluded_bot_id) sorted_posts = sorted(filtered_posts, key=lambda x: x['create_at'], reverse=True) if sorted_posts: @@ -151,31 +819,25 @@ def get_last_post(self): return sorted_posts[0] def checkAnswerExist(self, tmpID): - - # Check if file exists if not os.path.exists(self.min_next_post_file): logger.info("File does not exist, creating new one.") - data = {"answered_id": []} # Initialize with empty list return False else: - # Load existing data with open(self.min_next_post_file, "r") as f: data = json.load(f) - logger.info("Loaded data:", data) + logger.info("Loaded data: %s", data) answered_ids = data.get("answered_id", []) - # Ensure it's a list if not isinstance(answered_ids, list): answered_ids = [answered_ids] - # Only append if not already in the list (optional) if tmpID in answered_ids: logger.info(f"{tmpID} already exists") return True else: answered_ids.append(tmpID) - data["answered_id"] = answered_ids # Overwrite with new ID + 
data["answered_id"] = answered_ids with open(self.min_next_post_file, "w") as f: json.dump(data, f, indent=2) logger.info(f"Added {tmpID} for next iterations") @@ -183,26 +845,225 @@ def checkAnswerExist(self, tmpID): # for now just processes "main" posts, i.e. not replies/follow-ups def process_posts(self): - try: - # get last post topic = self.get_last_post() except Exception as e: logger.error("ERROR - Failed to parse feed due to the following exception:") logger.error(str(e)) return - if self.checkAnswerExist(topic['id']) : - # no need to answer someone already answered - logger.info('no need to answer someone already answered') + if self.checkAnswerExist(topic['id']): + logger.info('no need to answer someone already answered') + return + + post_data = { + "id": topic.get("id", ""), + "channel_id": topic.get("channel_id", self.mattermost_channel_id_read), + "root_id": topic.get("root_id", ""), + "user_id": topic.get("user_id", ""), + "username": topic.get("username", ""), + "message": topic.get("message", ""), + } + + try: + self.event_handler.handle(post_data) + self.write_min_next_post(topic['id']) + except Exception as e: + logger.error( + f"ERROR - Failed to process post {topic['id']} due to the following exception:" + ) + logger.error(str(e)) + + +# --------------------------------------------------------------------------- +# MattermostWebhookServer — event-driven (Flask) mode +# --------------------------------------------------------------------------- + +class MattermostWebhookServer: + """ + Event-driven alternative to the polling-based Mattermost class. + Runs a Flask HTTP server that receives messages via an outgoing webhook + (Mattermost pushes POSTs here) and replies via the REST API or incoming webhook. + No Personal Access Token required for basic operation; PAK enables thread + replies, reactions, and typing indicators. 
+ """ + def __init__(self): + logger.info('MattermostWebhookServer::INIT') + + self.mattermost_webhook = read_secret("MATTERMOST_WEBHOOK") + self.outgoing_token = read_secret("MATTERMOST_OUTGOING_TOKEN") + + config = get_full_config() + mm_config = config.get("services", {}).get("mattermost", {}) + self.port = int(mm_config.get("port", 5000)) + self.mattermost_url = mm_config.get("base_url", "https://mattermost.web.cern.ch/") + + # Auth setup + auth_config = mm_config.get("auth", {}) + pg_config = { + "password": read_secret("PG_PASSWORD"), + **config.get("services", {}).get("postgres", {}), + } + auth_manager = MattermostAuthManager(auth_config, pg_config=pg_config) + auth_enabled = auth_config.get("enabled", False) + + # AI wrapper + ai_wrapper = MattermostAIWrapper() + + # MattermostClient + ThreadContextManager (PAK-gated) + mm_client, thread_ctx, bot_user_id = _build_mm_client_and_context( + mm_config, self.mattermost_url + ) + + # Event handler + self.event_handler = MattermostEventHandler( + ai_wrapper=ai_wrapper, + auth_manager=auth_manager, + auth_enabled=auth_enabled, + bot_user_id=bot_user_id, + mm_client=mm_client, + thread_ctx=thread_ctx, + webhook_url=self.mattermost_webhook, + reactions=mm_config.get("reactions"), + chat_base_url=_derive_chat_base_url(config), + ) + + import secrets as _secrets + self.app = Flask(__name__) + self.app.secret_key = read_secret("FLASK_UPLOADER_APP_SECRET_KEY") or _secrets.token_hex(32) + self.app.add_url_rule('/webhook', 'webhook', self._handle_webhook, methods=['POST']) + + # SSO OAuth routes for Mattermost user authentication + sso_cfg = auth_config.get('sso', {}) + self._sso_enabled = bool(read_secret("SSO_CLIENT_ID") and read_secret("SSO_CLIENT_SECRET")) + if self._sso_enabled: + self._oauth = OAuth(self.app) + self._oauth.register( + name='sso', + client_id=read_secret("SSO_CLIENT_ID"), + client_secret=read_secret("SSO_CLIENT_SECRET"), + server_metadata_url=sso_cfg.get( + 'server_metadata_url', + 
'https://auth.cern.ch/auth/realms/cern/.well-known/openid-configuration', + ), + client_kwargs={'scope': 'openid profile email offline_access'}, + ) + self._token_service = MattermostTokenService( + pg_config=pg_config, + token_endpoint=sso_cfg.get('token_endpoint', ''), + session_lifetime_days=int(auth_config.get('session_lifetime_days', 30)), + roles_refresh_hours=int(auth_config.get('roles_refresh_hours', 24)), + ) + self.app.add_url_rule('/mattermost-auth', 'mattermost_auth_login', self._mattermost_auth_login) + self.app.add_url_rule('/mattermost-auth/callback', 'mattermost_auth_callback', self._mattermost_auth_callback) + logger.info("MattermostWebhookServer: SSO auth routes registered") + + def _handle_webhook(self): + # Mattermost outgoing webhooks send either application/x-www-form-urlencoded or application/json + if flask_request.is_json: + data = flask_request.get_json(silent=True) or {} else: - # otherwise, process it - try: - answer, post_str = self.ai_wrapper(topic) - print('topic',topic,' \n ANSWER: ',answer) - postedMM = self.post_response(answer) - post_str = self.write_min_next_post(topic['id']) + data = flask_request.form - except Exception as e: - logger.error(f"ERROR - Failed to process post {topic['id']} due to the following exception:") - logger.error(str(e)) + token = data.get('token', '') + if self.outgoing_token and token != self.outgoing_token: + logger.warning('MattermostWebhookServer: received request with invalid token') + return jsonify({}), 403 + + text = data.get('text', '').strip() + if not text: + return jsonify({}), 200 + + # Build a normalised post_data dict from the outgoing webhook payload. + # Mattermost outgoing webhook fields: post_id, root_id, user_id, user_name, + # channel_id, text, token, team_id, etc. 
+ post_data = { + "id": data.get("post_id", ""), + "channel_id": data.get("channel_id", ""), + "root_id": data.get("root_id", ""), # non-empty if this post is a thread reply + "user_id": data.get("user_id", ""), + "username": data.get("user_name", ""), + "message": text, + } + + # Process in a background thread — Mattermost outgoing webhooks have a ~5 second + # timeout and will retry on no response, causing duplicate AI calls. We must return + # 200 immediately and do the AI work asynchronously. + threading.Thread( + target=self.event_handler.handle, + args=(post_data,), + daemon=True, + ).start() + + # Without PAK, reactions and typing indicators are unavailable. + # Return a text acknowledgment in the HTTP response body instead — + # Mattermost posts this text to the channel immediately as visual feedback. + if not self.event_handler._client: + return jsonify({"text": ":hourglass_flowing_sand: _Processing..._"}), 200 + + return jsonify({}), 200 + + def _mattermost_auth_login(self): + """ + Step 1: user clicks the login link from Mattermost. + Stashes mm_username in session, then redirects to CERN SSO. + mm_user_id is passed as OAuth state and round-tripped back by SSO. + """ + mm_user_id = flask_request.args.get('state', '').strip() + mm_username = flask_request.args.get('username', '').strip() + if not mm_user_id: + return "Missing Mattermost user ID", 400 + session['_mm_pending_username'] = mm_username + redirect_uri = url_for('mattermost_auth_callback', _external=True) + return self._oauth.sso.authorize_redirect(redirect_uri, state=mm_user_id) + + def _mattermost_auth_callback(self): + """ + Step 2: CERN SSO redirects back here after the user authenticates. + Extracts roles from the JWT and stores them in mattermost_tokens. 
+        """
+        try:
+            token = self._oauth.sso.authorize_access_token()
+            mm_user_id = flask_request.args.get('state', '').strip()
+            mm_username = session.pop('_mm_pending_username', '')
+
+            if not mm_user_id:
+                return "Missing Mattermost user ID in callback state", 400
+
+            user_info = token.get('userinfo') or self._oauth.sso.userinfo(token=token)
+            user_email = user_info.get('email', user_info.get('preferred_username', ''))
+            user_roles = get_user_roles(token, user_email)
+
+            self._token_service.store_token(
+                mm_user_id=mm_user_id,
+                mm_username=mm_username or user_info.get('preferred_username', ''),
+                email=user_email,
+                roles=user_roles,
+                refresh_token=token.get('refresh_token'),
+            )
+
+            logger.info(
+                f"Mattermost auth successful: @{mm_username} (id={mm_user_id!r}) "
+                f"email={user_email!r} roles={user_roles}"
+            )
+            return (
+                "<html><body>"
+                "<h2>Login successful!</h2>"
+                f"<p>You are now authenticated as {user_email} "
+                f"with roles: {', '.join(user_roles)}.</p>"
+                "<p>You can close this tab and return to Mattermost.</p>"
+                "</body></html>"
+            )
+        except Exception as exc:
+            logger.error(f"Mattermost auth callback error: {exc}")
+            return (
+                "<html><body>"
+                "<h2>Authentication failed</h2>"
+                f"<p>Error: {exc}</p>"
+                "<p>Please try clicking the login link in Mattermost again.</p>"
+                "</body></html>"
+            ), 500
+
+    def run(self, host='0.0.0.0', port=5000):
+        logger.info(f'MattermostWebhookServer: starting on {host}:{port}')
+        self.app.run(host=host, port=port)
diff --git a/src/interfaces/uploader_app/app.py b/src/interfaces/uploader_app/app.py
index f7fd20cc8..9312699b3 100644
--- a/src/interfaces/uploader_app/app.py
+++ b/src/interfaces/uploader_app/app.py
@@ -507,7 +507,7 @@ def api_catalog_search(self):
             }
 
         if candidate_hashes is None:
-            iterable = list(self.catalog.iter_files())
+            iterable = self.catalog.iter_files()
         else:
             iterable = []
             for resource_hash in candidate_hashes:
@@ -560,7 +560,7 @@ def api_catalog_search(self):
             }
 
         if candidate_hashes is None:
-            iterable = list(self.catalog.iter_files())
+            iterable = self.catalog.iter_files()
         else:
             iterable = []
             for resource_hash in candidate_hashes:
@@ -666,7 +666,14 @@ def _parse_metadata_query(query: str) -> Tuple[Dict[str, str] | List[Dict[str, s
     filters_groups: List[Dict[str, str]] = []
     current_group: Dict[str, str] = {}
     free_tokens = []
-    for token in shlex.split(query):
+    try:
+        tokens = shlex.split(query)
+    except ValueError as exc:
+        # Avoid 500s on malformed quoted input generated by LLM tools.
+ logger.warning("Invalid metadata query syntax; using fallback tokenization: %s", exc) + tokens = query.split() + + for token in tokens: if token.upper() == "OR": if current_group: filters_groups.append(current_group) diff --git a/src/utils/config_service.py b/src/utils/config_service.py index e0921907f..bb92eaf86 100644 --- a/src/utils/config_service.py +++ b/src/utils/config_service.py @@ -219,8 +219,8 @@ def _ensure_config_tables(self) -> None: ADD COLUMN IF NOT EXISTS services_config JSONB DEFAULT '{}'::jsonb, ADD COLUMN IF NOT EXISTS data_manager_config JSONB DEFAULT '{}'::jsonb, ADD COLUMN IF NOT EXISTS archi_config JSONB DEFAULT '{}'::jsonb, - ADD COLUMN IF NOT EXISTS mcp_servers_config JSONB DEFAULT '{}'::jsonb, - ADD COLUMN IF NOT EXISTS global_config JSONB DEFAULT '{}'::jsonb + ADD COLUMN IF NOT EXISTS global_config JSONB DEFAULT '{}'::jsonb, + ADD COLUMN IF NOT EXISTS mcp_servers_config JSONB DEFAULT '{}'::jsonb """ ) cursor.execute( @@ -229,6 +229,50 @@ def _ensure_config_tables(self) -> None: ADD COLUMN IF NOT EXISTS active_agent_name VARCHAR(200) """ ) + # SSO token table for MCP Bearer auth (added for sso_auth MCP support) + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS sso_tokens ( + user_id VARCHAR(200) PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE, + access_token BYTEA, + refresh_token BYTEA, + access_token_expires_at TIMESTAMPTZ, + session_expires_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ) + """ + ) + # MCP OAuth2 client registrations and per-user tokens + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS mcp_oauth_clients ( + server_name VARCHAR(200) PRIMARY KEY, + server_url TEXT NOT NULL, + client_id TEXT NOT NULL, + client_secret TEXT NOT NULL DEFAULT '', + redirect_uri TEXT NOT NULL, + auth_meta JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ) + """ + ) + cursor.execute( 
+ """ + CREATE TABLE IF NOT EXISTS mcp_oauth_tokens ( + user_id VARCHAR(200) REFERENCES users(id) ON DELETE CASCADE, + server_name VARCHAR(200) NOT NULL, + access_token BYTEA, + refresh_token BYTEA, + access_token_expires_at TIMESTAMPTZ, + session_expires_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (user_id, server_name) + ) + """ + ) conn.commit() except psycopg2.Error as e: logger.debug("Could not ensure config tables/columns: %s", e) diff --git a/src/utils/conversation_service.py b/src/utils/conversation_service.py index fb683cdfd..1c2ace89f 100644 --- a/src/utils/conversation_service.py +++ b/src/utils/conversation_service.py @@ -12,6 +12,8 @@ import psycopg2 from psycopg2.extras import execute_values +import os + from src.utils.sql import ( SQL_INSERT_CONVO, SQL_INSERT_AB_COMPARISON, @@ -20,6 +22,9 @@ SQL_GET_PENDING_AB_COMPARISON, SQL_DELETE_AB_COMPARISON, SQL_GET_AB_COMPARISONS_BY_CONVERSATION, + SQL_MM_GET_CONV_ID_BY_SOURCE_REF, + SQL_MM_CREATE_CONVERSATION, + SQL_MM_UPDATE_CONVERSATION_TIMESTAMP, ) @@ -27,7 +32,7 @@ class Message: """A conversation message.""" message_id: Optional[int] = None - conversation_id: str = "" + conversation_id: Any = "" # int (FK to conversation_metadata) or "" when unset sender: str = "" # 'user' or 'assistant' content: str = "" link: Optional[str] = None @@ -252,10 +257,72 @@ def get_user_conversations( finally: self._release_connection(conn) + # ========================================================================= + # Cross-platform Conversation Operations + # ========================================================================= + + def get_or_create_conversation_for_ref( + self, + source_ref: str, + mm_client_id: str = "", + title: str = "", + ) -> int: + """Return the integer conversation_id for a Mattermost source_ref. + + Looks up `conversation_metadata` by `source_ref`. 
If no row exists yet,
+        creates one with `archi_service = 'mattermost'` and returns the new id.
+
+        Args:
+            source_ref: Stable external key, e.g. "mm_thread_<root_id>" or
+                "mm_channel_<channel_id>_user_<user_id>".
+            mm_client_id: Bridge key used by the web-chat list query, typically
+                "mm_user_<user_id>".
+            title: Short title for the conversation (truncated to 50 chars).
+
+        Returns:
+            Integer conversation_id from conversation_metadata.
+        """
+        conn = self._get_connection()
+        try:
+            with conn.cursor() as cur:
+                cur.execute(SQL_MM_GET_CONV_ID_BY_SOURCE_REF, (source_ref,))
+                row = cur.fetchone()
+                if row:
+                    return row[0]
+
+                # Create new metadata row
+                now = datetime.now(timezone.utc)
+                title_clean = (title[:50] if title else source_ref[:50]) or "Mattermost conversation"
+                version = os.getenv("APP_VERSION", "unknown")
+                cur.execute(
+                    SQL_MM_CREATE_CONVERSATION,
+                    (title_clean, now, now, mm_client_id, version, source_ref),
+                )
+                conv_id = cur.fetchone()[0]
+                conn.commit()
+                return conv_id
+        except Exception:
+            conn.rollback()
+            raise
+        finally:
+            self._release_connection(conn)
+
+    def update_conversation_timestamp_for_ref(self, conv_int_id: int) -> None:
+        """Update last_message_at for a Mattermost conversation to now."""
+        conn = self._get_connection()
+        try:
+            with conn.cursor() as cur:
+                cur.execute(SQL_MM_UPDATE_CONVERSATION_TIMESTAMP, (datetime.now(timezone.utc), conv_int_id))
+                conn.commit()
+        except Exception:
+            conn.rollback()
+        finally:
+            self._release_connection(conn)
+
     # =========================================================================
     # A/B Comparison Operations
     # =========================================================================
-    
+
     def create_ab_comparison(
         self,
         conversation_id: str,
diff --git a/src/utils/mattermost_auth.py b/src/utils/mattermost_auth.py
new file mode 100644
index 000000000..dedcf49bd
--- /dev/null
+++ b/src/utils/mattermost_auth.py
@@ -0,0 +1,143 @@
+"""
+Mattermost Auth Manager - Maps Mattermost user identity to RBAC roles.
+ +Supports two token_store modes: + config — static username→roles mapping in the config file (no DB, no SSO) + db — SSO-backed tokens stored in mattermost_tokens PostgreSQL table + +Config structure (services.mattermost.auth): + enabled: true + token_store: db # 'db' (SSO) or 'config' (static map) + default_role: mattermost-restricted + session_lifetime_days: 30 + roles_refresh_hours: 24 + login_base_url: "https://vocms248.cern.ch" + sso: + token_endpoint: "https://auth.cern.ch/auth/realms/cern/protocol/openid-connect/token" + # Only used when token_store=config: + user_roles: + ahmedmu: [archi-admins] +""" + +from typing import Dict, List, Optional + +from src.utils.rbac.mattermost_context import MattermostUserContext +from src.utils.logging import get_logger + +logger = get_logger(__name__) + + +class MattermostAuthManager: + """ + Resolves Mattermost users to RBAC roles. + + In 'config' mode: static username/user_id → roles map from config. + In 'db' mode: delegates to MattermostTokenService for SSO-backed roles. + Returns None when user has no stored token (triggers login prompt). + """ + + def __init__(self, auth_config: dict, pg_config: Optional[dict] = None): + self.enabled: bool = auth_config.get('enabled', False) + self.token_store: str = auth_config.get('token_store', 'config') + self.default_role: str = auth_config.get('default_role', 'mattermost-restricted') + self.login_base_url: str = auth_config.get('login_base_url', '').rstrip('/') + self.user_roles: Dict[str, List[str]] = auth_config.get('user_roles', {}) + + self._token_service = None + if self.enabled and self.token_store == 'db': + if pg_config: + self._init_token_service(auth_config, pg_config) + else: + logger.warning( + "MattermostAuthManager: token_store=db but no pg_config provided. " + "Falling back to token_store=config." 
+ ) + self.token_store = 'config' + + if self.enabled: + logger.info( + f"MattermostAuthManager: enabled=True, token_store={self.token_store!r}, " + f"default_role={self.default_role!r}" + ) + else: + logger.info( + f"MattermostAuthManager: disabled — all users get " + f"default_role={self.default_role!r}" + ) + + def _init_token_service(self, auth_config: dict, pg_config: dict) -> None: + try: + from src.utils.mattermost_token_service import MattermostTokenService + sso_cfg = auth_config.get('sso', {}) + self._token_service = MattermostTokenService( + pg_config=pg_config, + token_endpoint=sso_cfg.get('token_endpoint', ''), + session_lifetime_days=int(auth_config.get('session_lifetime_days', 30)), + roles_refresh_hours=int(auth_config.get('roles_refresh_hours', 24)), + ) + logger.info("MattermostAuthManager: DB token service initialized") + except Exception as exc: + logger.error(f"MattermostAuthManager: failed to init token service: {exc}") + self._token_service = None + self.token_store = 'config' + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def build_context( + self, user_id: str, username: str = "", email: str = "" + ) -> Optional[MattermostUserContext]: + """ + Return a MattermostUserContext for the given Mattermost user. + + Returns None in db mode when the user has no stored token — the caller + should send a login link and abort processing. + + Returns a context with default_role in config mode for unknown users. 
+ """ + if not self.enabled: + return MattermostUserContext( + user_id=user_id, username=username, + roles=[self.default_role], email=email, + ) + + if self.token_store == 'db' and self._token_service: + return self._token_service.get_user_context(user_id, username) + + # config mode — static lookup + roles = self._static_roles(username, user_id) + return MattermostUserContext( + user_id=user_id, username=username, roles=roles, email=email, + ) + + def login_url(self, user_id: str, username: str = "") -> str: + """Build the SSO login URL to send to an unauthenticated Mattermost user.""" + base = self.login_base_url or "http://localhost:7861" + url = f"{base}/mattermost-auth?state={user_id}" + if username: + url += f"&username={username}" + return url + + def invalidate(self, user_id: str) -> None: + """Force re-login for a user (admin action).""" + if self._token_service: + self._token_service.invalidate(user_id) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _static_roles(self, username: str, user_id: str) -> List[str]: + """Config-mode role lookup: username first, then user_id, then default.""" + roles = self.user_roles.get(username) or self.user_roles.get(user_id) + if roles: + logger.debug( + f"MattermostAuthManager: @{username!r} (id={user_id!r}) -> roles={roles}" + ) + return roles + logger.debug( + f"MattermostAuthManager: unknown user @{username!r} (id={user_id!r}), " + f"assigning default_role={self.default_role!r}" + ) + return [self.default_role] diff --git a/src/utils/mattermost_token_service.py b/src/utils/mattermost_token_service.py new file mode 100644 index 000000000..50fb6ac4e --- /dev/null +++ b/src/utils/mattermost_token_service.py @@ -0,0 +1,285 @@ +""" +Mattermost Token Service - Manages SSO tokens for Mattermost users. 
+ +Stores SSO refresh tokens and roles in PostgreSQL, enabling role-based access +without requiring re-login on every message. Silently refreshes roles using the +stored refresh token; only prompts the user to re-login when the session expires. + +Session lifetime: configurable (default 30 days) — full re-login required +Roles refresh: configurable (default 24h) — silent, uses refresh token +""" + +import json +import requests as http_requests +from datetime import datetime, timedelta, timezone +from typing import List, Optional + +from src.utils.env import read_secret +from src.utils.logging import get_logger +from src.utils.rbac.jwt_parser import get_user_roles +from src.utils.rbac.mattermost_context import MattermostUserContext + +logger = get_logger(__name__) + + +class MattermostTokenService: + """ + DB-backed token store for Mattermost SSO auth. + + Initialized with PostgreSQL config and OIDC token endpoint. The token + endpoint is used for silent role refresh via refresh_token grant. + """ + + def __init__( + self, + pg_config: dict, + token_endpoint: str = "", + session_lifetime_days: int = 30, + roles_refresh_hours: int = 24, + ): + self.pg_config = pg_config + self.token_endpoint = token_endpoint + self.session_lifetime_days = session_lifetime_days + self.roles_refresh_hours = roles_refresh_hours + self._encryption_key = read_secret("BYOK_ENCRYPTION_KEY") or read_secret("PG_ENCRYPTION_KEY") + if not self._encryption_key: + logger.warning( + "MattermostTokenService: no encryption key found (BYOK_ENCRYPTION_KEY). " + "Refresh tokens will not be stored — silent role refresh disabled." 
+ ) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def get_user_context( + self, mm_user_id: str, mm_username: str = "" + ) -> Optional[MattermostUserContext]: + """ + Return a MattermostUserContext for a known user, or None if: + - No token is stored (user must login) + - Session has expired (user must re-login) + + Silently refreshes roles if they are stale (older than roles_refresh_hours). + """ + row = self._fetch_row(mm_user_id) + if row is None: + logger.debug(f"No token for Mattermost user_id={mm_user_id!r}") + return None + + stored_username, email, roles, token_expires_at, roles_refreshed_at, refresh_token = row + + # Session expiry check — requires full re-login + now = datetime.now(timezone.utc) + if token_expires_at and now > token_expires_at: + logger.info(f"Session expired for Mattermost user_id={mm_user_id!r}, invalidating") + self.invalidate(mm_user_id) + return None + + # Silent role refresh if stale + if roles_refreshed_at: + stale_threshold = now - timedelta(hours=self.roles_refresh_hours) + if roles_refreshed_at < stale_threshold: + logger.info(f"Refreshing stale roles for Mattermost user_id={mm_user_id!r}") + fresh = self._refresh_roles(mm_user_id, email or "", refresh_token) + if fresh is not None: + roles = fresh + else: + logger.warning( + f"Role refresh failed for user_id={mm_user_id!r}, using cached roles" + ) + + return MattermostUserContext( + user_id=mm_user_id, + username=mm_username or stored_username or "", + roles=roles, + email=email or "", + ) + + def store_token( + self, + mm_user_id: str, + mm_username: str, + email: str, + roles: List[str], + refresh_token: Optional[str], + ) -> None: + """Store or update a token for a Mattermost user.""" + expires_at = datetime.now(timezone.utc) + timedelta(days=self.session_lifetime_days) + pool = self._get_pool() + with pool.get_connection() as conn: + with conn.cursor() as 
cur: + if self._encryption_key and refresh_token: + cur.execute( + """ + INSERT INTO mattermost_tokens + (mattermost_user_id, mattermost_username, email, roles, + refresh_token, token_expires_at, roles_refreshed_at, updated_at) + VALUES (%s, %s, %s, %s, + pgp_sym_encrypt(%s, %s), %s, NOW(), NOW()) + ON CONFLICT (mattermost_user_id) DO UPDATE SET + mattermost_username = EXCLUDED.mattermost_username, + email = EXCLUDED.email, + roles = EXCLUDED.roles, + refresh_token = EXCLUDED.refresh_token, + token_expires_at = EXCLUDED.token_expires_at, + roles_refreshed_at = NOW(), + updated_at = NOW() + """, + (mm_user_id, mm_username, email, json.dumps(roles), + refresh_token, self._encryption_key, expires_at), + ) + else: + # No encryption key or no refresh token — store without refresh capability + cur.execute( + """ + INSERT INTO mattermost_tokens + (mattermost_user_id, mattermost_username, email, roles, + token_expires_at, roles_refreshed_at, updated_at) + VALUES (%s, %s, %s, %s, %s, NOW(), NOW()) + ON CONFLICT (mattermost_user_id) DO UPDATE SET + mattermost_username = EXCLUDED.mattermost_username, + email = EXCLUDED.email, + roles = EXCLUDED.roles, + token_expires_at = EXCLUDED.token_expires_at, + roles_refreshed_at = NOW(), + updated_at = NOW() + """, + (mm_user_id, mm_username, email, json.dumps(roles), expires_at), + ) + conn.commit() + logger.info( + f"Stored token for Mattermost @{mm_username} (id={mm_user_id!r}), " + f"roles={roles}, expires={expires_at.date()}" + ) + + def invalidate(self, mm_user_id: str) -> None: + """Delete the stored token for a user, forcing re-login on next message.""" + pool = self._get_pool() + with pool.get_connection() as conn: + with conn.cursor() as cur: + cur.execute( + "DELETE FROM mattermost_tokens WHERE mattermost_user_id = %s", + (mm_user_id,), + ) + conn.commit() + logger.info(f"Invalidated Mattermost token for user_id={mm_user_id!r}") + + # ------------------------------------------------------------------ + # Internal helpers + # 
------------------------------------------------------------------ + + def _get_pool(self): + from src.utils.postgres_service_factory import PostgresServiceFactory + factory = PostgresServiceFactory.get_instance() + if factory: + return factory.connection_pool + from src.utils.connection_pool import ConnectionPool + return ConnectionPool(connection_params=self.pg_config) + + def _fetch_row(self, mm_user_id: str): + """Fetch token row from DB. Returns tuple or None.""" + pool = self._get_pool() + with pool.get_connection() as conn: + with conn.cursor() as cur: + if self._encryption_key: + cur.execute( + """ + SELECT mattermost_username, email, roles, + token_expires_at, roles_refreshed_at, + pgp_sym_decrypt(refresh_token, %s)::text AS refresh_token + FROM mattermost_tokens + WHERE mattermost_user_id = %s + """, + (self._encryption_key, mm_user_id), + ) + else: + cur.execute( + """ + SELECT mattermost_username, email, roles, + token_expires_at, roles_refreshed_at, + NULL AS refresh_token + FROM mattermost_tokens + WHERE mattermost_user_id = %s + """, + (mm_user_id,), + ) + row = cur.fetchone() + + if row is None: + return None + + stored_username, email, roles_raw, token_expires_at, roles_refreshed_at, refresh_token = row + roles = json.loads(roles_raw) if isinstance(roles_raw, str) else (roles_raw or []) + return stored_username, email, roles, token_expires_at, roles_refreshed_at, refresh_token + + def _refresh_roles( + self, mm_user_id: str, email: str, refresh_token: Optional[str] + ) -> Optional[List[str]]: + """ + Exchange refresh token for a new token, extract fresh roles, update DB. + Returns new roles on success, None on failure. 
+ """ + if not refresh_token or not self.token_endpoint: + return None + + client_id = read_secret("SSO_CLIENT_ID") + client_secret = read_secret("SSO_CLIENT_SECRET") + if not client_id or not client_secret: + return None + + try: + resp = http_requests.post( + self.token_endpoint, + data={ + "grant_type": "refresh_token", + "client_id": client_id, + "client_secret": client_secret, + "refresh_token": refresh_token, + }, + timeout=10, + ) + resp.raise_for_status() + new_token = resp.json() + except Exception as exc: + logger.warning(f"Token refresh HTTP error for user_id={mm_user_id!r}: {exc}") + return None + + try: + fresh_roles = get_user_roles(new_token, email) + except Exception as exc: + logger.warning(f"Role extraction failed for user_id={mm_user_id!r}: {exc}") + return None + + # Update DB — new refresh token if provided + new_refresh = new_token.get("refresh_token") or refresh_token + pool = self._get_pool() + try: + with pool.get_connection() as conn: + with conn.cursor() as cur: + if self._encryption_key and new_refresh: + cur.execute( + """ + UPDATE mattermost_tokens + SET roles = %s, roles_refreshed_at = NOW(), + refresh_token = pgp_sym_encrypt(%s, %s), updated_at = NOW() + WHERE mattermost_user_id = %s + """, + (json.dumps(fresh_roles), new_refresh, + self._encryption_key, mm_user_id), + ) + else: + cur.execute( + """ + UPDATE mattermost_tokens + SET roles = %s, roles_refreshed_at = NOW(), updated_at = NOW() + WHERE mattermost_user_id = %s + """, + (json.dumps(fresh_roles), mm_user_id), + ) + conn.commit() + except Exception as exc: + logger.warning(f"DB update failed after role refresh for user_id={mm_user_id!r}: {exc}") + + logger.info(f"Refreshed roles for user_id={mm_user_id!r}: {fresh_roles}") + return fresh_roles diff --git a/src/utils/mcp_oauth_service.py b/src/utils/mcp_oauth_service.py new file mode 100644 index 000000000..95a64d2b7 --- /dev/null +++ b/src/utils/mcp_oauth_service.py @@ -0,0 +1,419 @@ +""" +MCPOAuthService - OAuth2 
authorization for MCP servers (MCP 2025-11 spec). + +Implements the full authorization code + PKCE flow with dynamic client registration: + 1. Discover auth server via RFC 9728 /.well-known/oauth-protected-resource + 2. Fetch server metadata via RFC 8414 /.well-known/oauth-authorization-server + 3. Register archi as a client via RFC 7591 dynamic registration (once per server) + 4. User authorizes via /mcp/authorize → MCP server → /mcp/callback + 5. Exchange auth code for MCP-issued access + refresh tokens + 6. Store per-user per-server tokens encrypted in PostgreSQL + 7. Silently refresh tokens when expired +""" + +import base64 +import hashlib +import json +import os +import secrets +from datetime import datetime, timedelta, timezone +from typing import Optional, Tuple +from urllib.parse import urlencode, urlparse + +import requests as http_requests + +from src.utils.env import read_secret +from src.utils.logging import get_logger + +logger = get_logger(__name__) + +_CA_BUNDLE = "/etc/ssl/certs/tls-ca-bundle.pem" +_VERIFY = _CA_BUNDLE if os.path.exists(_CA_BUNDLE) else True + + +class MCPOAuthService: + """ + Manages the OAuth2 authorization code + PKCE flow for MCP servers + that implement their own authorization server (MCP 2025-11 spec). 
+ """ + + def __init__(self, pg_config: dict = None, app_base_url: str = ""): + self.pg_config = pg_config or {} + self.app_base_url = app_base_url.rstrip("/") + self._encryption_key = ( + read_secret("BYOK_ENCRYPTION_KEY") + or read_secret("PG_ENCRYPTION_KEY") + or read_secret("UPLOADER_SALT") + or read_secret("FLASK_UPLOADER_APP_SECRET_KEY") + ) + if not self._encryption_key: + logger.warning("MCPOAuthService: no encryption key found — tokens will not be persisted") + + # ------------------------------------------------------------------ + # Discovery & Registration + # ------------------------------------------------------------------ + + def discover_auth_server(self, server_url: str) -> Optional[dict]: + """ + Discover the OAuth2 authorization server metadata for an MCP server URL. + Uses RFC 9728 /.well-known/oauth-protected-resource then RFC 8414. + """ + parsed = urlparse(server_url) + host_url = f"{parsed.scheme}://{parsed.netloc}" + + try: + resp = http_requests.get( + f"{host_url}/.well-known/oauth-protected-resource", + verify=_VERIFY, timeout=10, + ) + if resp.status_code != 200: + logger.debug(f"No protected-resource metadata at {host_url}: {resp.status_code}") + return None + resource_meta = resp.json() + except Exception as e: + logger.warning(f"Failed to fetch protected-resource metadata for {server_url}: {e}") + return None + + auth_server_urls = resource_meta.get("authorization_servers", []) + if not auth_server_urls: + return None + + auth_base = auth_server_urls[0].rstrip("/") + for path in ["/.well-known/oauth-authorization-server", "/.well-known/openid-configuration"]: + try: + resp = http_requests.get(f"{auth_base}{path}", verify=_VERIFY, timeout=10) + if resp.status_code == 200: + return resp.json() + except Exception as e: + logger.debug(f"Failed fetching {auth_base}{path}: {e}") + + return None + + def get_or_register_client(self, server_name: str, server_url: str) -> Optional[dict]: + """ + Return existing client registration or perform 
dynamic registration (RFC 7591). + Returns dict with client_id, client_secret, redirect_uri, auth_meta. + """ + existing = self._fetch_client_registration(server_name) + if existing: + return existing + + auth_meta = self.discover_auth_server(server_url) + if not auth_meta: + logger.warning(f"Could not discover auth server for MCP server '{server_name}'") + return None + + registration_endpoint = auth_meta.get("registration_endpoint") + if not registration_endpoint: + logger.warning(f"MCP server '{server_name}' has no registration_endpoint") + return None + + redirect_uri = f"{self.app_base_url}/mcp/callback" + try: + resp = http_requests.post( + registration_endpoint, + json={ + "client_name": "archi", + "redirect_uris": [redirect_uri], + "grant_types": ["authorization_code"], + "response_types": ["code"], + "token_endpoint_auth_method": "none", + }, + verify=_VERIFY, timeout=10, + ) + resp.raise_for_status() + reg = resp.json() + except Exception as e: + logger.error(f"Client registration failed for MCP server '{server_name}': {e}") + return None + + client_id = reg.get("client_id") + if not client_id: + logger.error(f"No client_id in registration response for '{server_name}'") + return None + + client_secret = reg.get("client_secret", "") + self._store_client_registration(server_name, server_url, client_id, client_secret, + redirect_uri, auth_meta) + logger.info(f"Registered OAuth2 client for MCP server '{server_name}': {client_id!r}") + return { + "client_id": client_id, + "client_secret": client_secret, + "redirect_uri": redirect_uri, + "auth_meta": auth_meta, + } + + # ------------------------------------------------------------------ + # Authorization URL + PKCE + # ------------------------------------------------------------------ + + def get_authorization_url( + self, server_name: str, server_url: str + ) -> Optional[Tuple[str, str, str]]: + """ + Build the authorization redirect URL. + Returns (authorization_url, state, code_verifier) or None. 
+ """ + reg = self.get_or_register_client(server_name, server_url) + if not reg: + return None + + auth_endpoint = reg["auth_meta"].get("authorization_endpoint") + if not auth_endpoint: + return None + + code_verifier = secrets.token_urlsafe(64) + code_challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) + .rstrip(b"=") + .decode() + ) + state = secrets.token_urlsafe(32) + + params = { + "response_type": "code", + "client_id": reg["client_id"], + "redirect_uri": reg["redirect_uri"], + "state": state, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } + auth_url = f"{auth_endpoint}?{urlencode(params)}" + return auth_url, state, code_verifier + + # ------------------------------------------------------------------ + # Token Exchange & Refresh + # ------------------------------------------------------------------ + + def exchange_code(self, server_name: str, code: str, code_verifier: str) -> Optional[dict]: + """Exchange an authorization code for tokens.""" + reg = self._fetch_client_registration(server_name) + if not reg: + return None + + token_endpoint = reg["auth_meta"].get("token_endpoint") + if not token_endpoint: + return None + + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": reg["redirect_uri"], + "client_id": reg["client_id"], + "code_verifier": code_verifier, + } + if reg.get("client_secret"): + data["client_secret"] = reg["client_secret"] + + try: + resp = http_requests.post(token_endpoint, data=data, verify=_VERIFY, timeout=10) + resp.raise_for_status() + return resp.json() + except Exception as e: + logger.error(f"Token exchange failed for MCP server '{server_name}': {e}") + return None + + def _refresh_access_token(self, server_name: str, user_id: str, + refresh_token: str) -> Optional[str]: + reg = self._fetch_client_registration(server_name) + if not reg: + return None + + token_endpoint = reg["auth_meta"].get("token_endpoint") + if not token_endpoint or not 
refresh_token: + return None + + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": reg["client_id"], + } + if reg.get("client_secret"): + data["client_secret"] = reg["client_secret"] + + try: + resp = http_requests.post(token_endpoint, data=data, verify=_VERIFY, timeout=10) + resp.raise_for_status() + token_data = resp.json() + except Exception as e: + logger.warning(f"Token refresh failed for '{server_name}', user={user_id!r}: {e}") + return None + + new_access = token_data.get("access_token") + new_refresh = token_data.get("refresh_token") or refresh_token + expires_in = int(token_data.get("expires_in", 3600)) + if new_access: + self.store_user_token(user_id, server_name, new_access, new_refresh, expires_in) + return new_access + + # ------------------------------------------------------------------ + # User token storage + # ------------------------------------------------------------------ + + def store_user_token(self, user_id: str, server_name: str, access_token: str, + refresh_token: Optional[str], expires_in: int = 3600) -> None: + if not self._encryption_key: + return + + now = datetime.now(timezone.utc) + access_expires_at = now + timedelta(seconds=expires_in) + session_expires_at = now + timedelta(days=30) + + try: + with self._get_pool().get_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO mcp_oauth_tokens + (user_id, server_name, access_token, refresh_token, + access_token_expires_at, session_expires_at, updated_at) + VALUES (%s, %s, + pgp_sym_encrypt(%s, %s), + pgp_sym_encrypt(%s, %s), + %s, %s, NOW()) + ON CONFLICT (user_id, server_name) DO UPDATE SET + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + access_token_expires_at = EXCLUDED.access_token_expires_at, + session_expires_at = EXCLUDED.session_expires_at, + updated_at = NOW() + """, + ( + user_id, server_name, + access_token, self._encryption_key, + refresh_token or "", 
self._encryption_key, + access_expires_at, session_expires_at, + ), + ) + conn.commit() + logger.info(f"Stored MCP token for user={user_id!r}, server={server_name!r}, " + f"expires={access_expires_at.isoformat()}") + except Exception as e: + logger.error(f"Failed to store MCP token for user={user_id!r}, server={server_name!r}: {e}") + + def get_access_token(self, user_id: str, server_name: str) -> Optional[str]: + """Return a valid access token, silently refreshing if expired.""" + if not user_id or not self._encryption_key: + return None + + row = self._fetch_user_token(user_id, server_name) + if row is None: + return None + + access_token, refresh_token, access_expires_at, session_expires_at = row + now = datetime.now(timezone.utc) + + if session_expires_at and now > session_expires_at: + self.invalidate_user_token(user_id, server_name) + return None + + if access_expires_at and now < access_expires_at: + return access_token + + logger.info(f"MCP access token expired for user={user_id!r}, server={server_name!r}, refreshing") + return self._refresh_access_token(server_name, user_id, refresh_token) + + def invalidate_user_token(self, user_id: str, server_name: str) -> None: + try: + with self._get_pool().get_connection() as conn: + with conn.cursor() as cur: + cur.execute( + "DELETE FROM mcp_oauth_tokens WHERE user_id = %s AND server_name = %s", + (user_id, server_name), + ) + conn.commit() + except Exception as e: + logger.warning(f"Failed to invalidate MCP token for user={user_id!r}, server={server_name!r}: {e}") + + def get_servers_needing_auth(self, user_id: str, mcp_servers: dict) -> list: + """Return list of server names that require OAuth but have no valid token.""" + return [ + name for name, cfg in mcp_servers.items() + if cfg.get("sso_auth") and not self.get_access_token(user_id, name) + ] + + # ------------------------------------------------------------------ + # Internal DB helpers + # ------------------------------------------------------------------ 
+ + def _store_client_registration(self, server_name: str, server_url: str, + client_id: str, client_secret: str, + redirect_uri: str, auth_meta: dict) -> None: + try: + with self._get_pool().get_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO mcp_oauth_clients + (server_name, server_url, client_id, client_secret, + redirect_uri, auth_meta) + VALUES (%s, %s, %s, %s, %s, %s) + ON CONFLICT (server_name) DO UPDATE SET + server_url = EXCLUDED.server_url, + client_id = EXCLUDED.client_id, + client_secret = EXCLUDED.client_secret, + redirect_uri = EXCLUDED.redirect_uri, + auth_meta = EXCLUDED.auth_meta, + updated_at = NOW() + """, + (server_name, server_url, client_id, client_secret, + redirect_uri, json.dumps(auth_meta)), + ) + conn.commit() + except Exception as e: + logger.error(f"Failed to store client registration for '{server_name}': {e}") + + def _fetch_client_registration(self, server_name: str) -> Optional[dict]: + try: + with self._get_pool().get_connection() as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT client_id, client_secret, redirect_uri, auth_meta " + "FROM mcp_oauth_clients WHERE server_name = %s", + (server_name,), + ) + row = cur.fetchone() + if row is None: + return None + client_id, client_secret, redirect_uri, auth_meta_raw = row + return { + "client_id": client_id, + "client_secret": client_secret, + "redirect_uri": redirect_uri, + "auth_meta": ( + json.loads(auth_meta_raw) + if isinstance(auth_meta_raw, str) + else auth_meta_raw + ), + } + except Exception as e: + logger.warning(f"Failed to fetch client registration for '{server_name}': {e}") + return None + + def _fetch_user_token(self, user_id: str, server_name: str): + try: + with self._get_pool().get_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT pgp_sym_decrypt(access_token, %s)::text, + pgp_sym_decrypt(refresh_token, %s)::text, + access_token_expires_at, + session_expires_at + FROM mcp_oauth_tokens + 
WHERE user_id = %s AND server_name = %s + """, + (self._encryption_key, self._encryption_key, user_id, server_name), + ) + return cur.fetchone() + except Exception as e: + logger.warning(f"Failed to fetch MCP token for user={user_id!r}, server={server_name!r}: {e}") + return None + + def _get_pool(self): + from src.utils.postgres_service_factory import PostgresServiceFactory + factory = PostgresServiceFactory.get_instance() + if factory: + return factory.connection_pool + from src.utils.connection_pool import ConnectionPool + return ConnectionPool(connection_params=self.pg_config) diff --git a/src/utils/rbac/mattermost_context.py b/src/utils/rbac/mattermost_context.py new file mode 100644 index 000000000..c2a5c921f --- /dev/null +++ b/src/utils/rbac/mattermost_context.py @@ -0,0 +1,51 @@ +""" +Mattermost User Context - Thread-safe per-request user context for Mattermost. + +Provides a ContextVar-based mechanism to carry Mattermost user identity +through the call stack without needing Flask sessions. +""" + +from contextvars import ContextVar +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import List, Optional + +from src.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class MattermostUserContext: + user_id: str + username: str + roles: List[str] + email: str = "" + + +# Module-level ContextVar — default None means "no Mattermost context active" +_mm_context: ContextVar[Optional[MattermostUserContext]] = ContextVar( + 'mm_user_context', default=None +) + + +def get_mattermost_context() -> Optional[MattermostUserContext]: + """Return the active Mattermost user context, or None if not set.""" + return _mm_context.get() + + +@contextmanager +def mattermost_user_context(ctx: MattermostUserContext): + """ + Context manager that sets the Mattermost user context for the duration + of the block, then resets it. Thread-safe via ContextVar. 
+ + Usage: + with mattermost_user_context(ctx): + answer, _ = ai_wrapper(post) + """ + token = _mm_context.set(ctx) + try: + yield ctx + finally: + _mm_context.reset(token) diff --git a/src/utils/rbac/permission_enum.py b/src/utils/rbac/permission_enum.py index cde4998a2..ecc6fe222 100644 --- a/src/utils/rbac/permission_enum.py +++ b/src/utils/rbac/permission_enum.py @@ -26,6 +26,9 @@ class Chat(str, Enum): HISTORY = "chat:history" FEEDBACK = "chat:feedback" + class Mattermost(str, Enum): + ACCESS = "mattermost:access" # Gate for Mattermost bot access (not granted to base-user) + class Documents(str, Enum): VIEW = "documents:view" SELECT = "documents:select" diff --git a/src/utils/sql.py b/src/utils/sql.py index 166644488..73aea0fa5 100644 --- a/src/utils/sql.py +++ b/src/utils/sql.py @@ -159,6 +159,55 @@ WHERE conversation_id = %s AND (user_id = %s OR client_id = %s); """ +# ============================================================================= +# Cross-platform (Mattermost ↔ web-chat) conversation queries +# ============================================================================= + +# Look up integer conversation_id by external source_ref (e.g. 
"mm_thread_xxx")
+SQL_MM_GET_CONV_ID_BY_SOURCE_REF = """
+SELECT conversation_id FROM conversation_metadata
+WHERE source_ref = %s
+LIMIT 1;
+"""
+
+# Create a conversation_metadata row for a Mattermost thread/channel
+SQL_MM_CREATE_CONVERSATION = """
+INSERT INTO conversation_metadata (
+    title, created_at, last_message_at, client_id, archi_version, archi_service, source_ref
+)
+VALUES (%s, %s, %s, %s, %s, 'mattermost', %s)
+RETURNING conversation_id;
+"""
+
+# Update last_message_at for a Mattermost conversation (matched by its integer conversation_id)
+SQL_MM_UPDATE_CONVERSATION_TIMESTAMP = """
+UPDATE conversation_metadata
+SET last_message_at = %s
+WHERE conversation_id = %s;
+"""
+
+# List conversations for an authenticated user, including Mattermost ones linked by
+# mm_client_id = "mm_user_{preferred_username}". Three ownership checks:
+# 1. user_id matches (web-chat SSO conversations)
+# 2. client_id matches (anonymous / browser-keyed conversations)
+# 3. client_id = mm_client_id (Mattermost-originated conversations)
+SQL_LIST_CONVERSATIONS_ALL_SOURCES = """
+SELECT conversation_id, title, created_at, last_message_at,
+       COALESCE(archi_service, 'chat') AS archi_service
+FROM conversation_metadata
+WHERE user_id = %s OR client_id = %s OR client_id = %s
+ORDER BY last_message_at DESC
+LIMIT %s;
+"""
+
+# Fetch metadata for a single conversation, accepting any of the three ownership proofs
+SQL_GET_CONVERSATION_METADATA_ALL_SOURCES = """
+SELECT conversation_id, title, created_at, last_message_at,
+       COALESCE(archi_service, 'chat') AS archi_service
+FROM conversation_metadata
+WHERE conversation_id = %s AND (user_id = %s OR client_id = %s OR client_id = %s);
+"""
+
 # =============================================================================
 # Tool Calls Queries
 # =============================================================================
diff --git a/src/utils/sso_token_service.py b/src/utils/sso_token_service.py
new file mode 100644
index 000000000..0c202c13e
--- /dev/null +++ b/src/utils/sso_token_service.py @@ -0,0 +1,218 @@ +""" +SSOTokenService - DB-backed access/refresh token store for web SSO users. + +Mirrors the pattern from MattermostTokenService but for the main chat_app's +SSO flow. Stores pgp-encrypted tokens in PostgreSQL, refreshes access tokens +silently when they expire, so MCP servers can use per-user Bearer auth +without relying on short-lived Flask session state. + +Session lifetime: configurable (default 30 days) — full re-login required +Access token: refreshed silently via refresh_token when expired +""" + +import requests as http_requests +from datetime import datetime, timedelta, timezone +from typing import Optional + +from src.utils.env import read_secret +from src.utils.logging import get_logger + +logger = get_logger(__name__) + + +class SSOTokenService: + """ + DB-backed token store for chat_app SSO auth. + + Stores encrypted access_token + refresh_token per user_id (the SSO 'sub' + claim). When the access_token has expired, silently exchanges the + refresh_token for a new one. Full re-login is only required when the + refresh_token itself expires (session_lifetime_days). + """ + + def __init__( + self, + pg_config: dict = None, + token_endpoint: str = "", + session_lifetime_days: int = 30, + ): + self.pg_config = pg_config or {} + self.token_endpoint = token_endpoint + self.session_lifetime_days = session_lifetime_days + self._encryption_key = ( + read_secret("BYOK_ENCRYPTION_KEY") + or read_secret("PG_ENCRYPTION_KEY") + or read_secret("UPLOADER_SALT") + or read_secret("FLASK_UPLOADER_APP_SECRET_KEY") + ) + if not self._encryption_key: + logger.warning( + "SSOTokenService: no encryption key found " + "(BYOK_ENCRYPTION_KEY / PG_ENCRYPTION_KEY / UPLOADER_SALT). " + "SSO tokens will not be persisted — MCP sso_auth servers will be skipped." 
+ ) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def store_token( + self, + user_id: str, + access_token: str, + refresh_token: Optional[str], + expires_in: int = 300, + ) -> None: + """Persist access + refresh tokens after a successful SSO login.""" + if not self._encryption_key: + return + + now = datetime.now(timezone.utc) + access_expires_at = now + timedelta(seconds=expires_in) + session_expires_at = now + timedelta(days=self.session_lifetime_days) + + try: + with self._get_pool().get_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO sso_tokens + (user_id, access_token, refresh_token, + access_token_expires_at, session_expires_at, updated_at) + VALUES (%s, + pgp_sym_encrypt(%s, %s), + pgp_sym_encrypt(%s, %s), + %s, %s, NOW()) + ON CONFLICT (user_id) DO UPDATE SET + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + access_token_expires_at = EXCLUDED.access_token_expires_at, + session_expires_at = EXCLUDED.session_expires_at, + updated_at = NOW() + """, + ( + user_id, + access_token, self._encryption_key, + refresh_token or "", self._encryption_key, + access_expires_at, session_expires_at, + ), + ) + conn.commit() + logger.info( + f"Stored SSO tokens for user_id={user_id!r}, " + f"access_expires={access_expires_at.isoformat()}" + ) + except Exception as exc: + logger.error(f"Failed to store SSO token for user_id={user_id!r}: {exc}") + + def get_access_token(self, user_id: str) -> Optional[str]: + """ + Return a valid access token for the user. + + - Returns the stored token if it hasn't expired. + - Silently refreshes via refresh_token if the access token is stale. + - Returns None if no token is stored, the session has expired, or + the refresh fails (user must re-login). 
+ """ + if not user_id or not self._encryption_key: + return None + + row = self._fetch_row(user_id) + if row is None: + logger.debug(f"No SSO token stored for user_id={user_id!r}") + return None + + access_token, refresh_token, access_expires_at, session_expires_at = row + now = datetime.now(timezone.utc) + + # Hard session expiry — full re-login required + if session_expires_at and now > session_expires_at: + logger.info(f"SSO session expired for user_id={user_id!r}, invalidating") + self.invalidate(user_id) + return None + + # Access token still valid + if access_expires_at and now < access_expires_at: + return access_token + + # Access token expired — try silent refresh + logger.info(f"Access token expired for user_id={user_id!r}, refreshing") + return self._refresh_access_token(user_id, refresh_token) + + def invalidate(self, user_id: str) -> None: + """Delete stored tokens (e.g. on logout or hard session expiry).""" + try: + with self._get_pool().get_connection() as conn: + with conn.cursor() as cur: + cur.execute("DELETE FROM sso_tokens WHERE user_id = %s", (user_id,)) + conn.commit() + logger.info(f"Invalidated SSO tokens for user_id={user_id!r}") + except Exception as exc: + logger.warning(f"Failed to invalidate SSO token for user_id={user_id!r}: {exc}") + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_pool(self): + from src.utils.postgres_service_factory import PostgresServiceFactory + factory = PostgresServiceFactory.get_instance() + if factory: + return factory.connection_pool + from src.utils.connection_pool import ConnectionPool + return ConnectionPool(connection_params=self.pg_config) + + def _fetch_row(self, user_id: str): + try: + with self._get_pool().get_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT pgp_sym_decrypt(access_token, %s)::text, + pgp_sym_decrypt(refresh_token, %s)::text, + 
+                        access_token_expires_at,
+                        session_expires_at
+                    FROM sso_tokens
+                    WHERE user_id = %s
+                    """,
+                    (self._encryption_key, self._encryption_key, user_id),
+                )
+                return cur.fetchone()
+        except Exception as exc:
+            logger.warning(f"Failed to fetch SSO token for user_id={user_id!r}: {exc}")
+            return None
+
+    def _refresh_access_token(self, user_id: str, refresh_token: Optional[str]) -> Optional[str]:
+        if not refresh_token or not self.token_endpoint:
+            return None
+
+        client_id = read_secret("SSO_CLIENT_ID")
+        client_secret = read_secret("SSO_CLIENT_SECRET")
+        if not client_id or not client_secret:
+            return None
+
+        try:
+            resp = http_requests.post(
+                self.token_endpoint,
+                data={
+                    "grant_type": "refresh_token",
+                    "client_id": client_id,
+                    "client_secret": client_secret,
+                    "refresh_token": refresh_token,
+                },
+                timeout=10,
+            )
+            resp.raise_for_status()
+            new_token = resp.json()
+        except Exception as exc:
+            logger.warning(f"Token refresh HTTP error for user_id={user_id!r}: {exc}")
+            return None
+
+        new_access = new_token.get("access_token")
+        new_refresh = new_token.get("refresh_token") or refresh_token
+        expires_in = int(new_token.get("expires_in", 300))
+
+        if new_access:
+            self.store_token(user_id, new_access, new_refresh, expires_in)
+
+        return new_access
diff --git a/tests/unit/test_mcp_sse_tools.py b/tests/unit/test_mcp_sse_tools.py
new file mode 100644
index 000000000..985bc22e6
--- /dev/null
+++ b/tests/unit/test_mcp_sse_tools.py
@@ -0,0 +1,281 @@
+from __future__ import annotations
+
+import sys
+import types
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
+
+try:
+    import flask  # noqa: F401
+except ModuleNotFoundError:
+    fake_flask = types.ModuleType("flask")
+
+    class _Blueprint:
+        def __init__(self, *args, **kwargs):
+            pass
+
+        def route(self, *args, **kwargs):
+            def _decorator(func):
+                return func
+
+            return _decorator
+
+    class _Response:
+        def __init__(self, *args, **kwargs):
+            pass
+
+    fake_flask.Blueprint = _Blueprint
+    fake_flask.Response = _Response
+    fake_flask.jsonify = lambda payload=None, **kwargs: payload if payload is not None else kwargs
+    fake_flask.request = types.SimpleNamespace(
+        headers={},
+        args={},
+        host="localhost",
+        scheme="http",
+        get_json=lambda silent=True: None,
+    )
+    fake_flask.stream_with_context = lambda generator: generator
+    sys.modules["flask"] = fake_flask
+
+from src.interfaces.chat_app import mcp_sse
+
+
+def _tool_text(result: dict) -> str:
+    return result["content"][0]["text"]
+
+
+@pytest.fixture
+def wrapper():
+    wrapper = MagicMock()
+    wrapper.chat = MagicMock()
+    wrapper.chat.data_viewer = MagicMock()
+    return wrapper
+
+
+def test_tool_list_includes_repo_backed_discovery_tools():
+    names = {tool["name"] for tool in mcp_sse._TOOLS}
+
+    assert "archi_search_document_metadata" in names
+    assert "archi_list_metadata_schema" in names
+    assert "archi_search_document_content" in names
+    assert "archi_get_document_chunks" in names
+    assert "archi_get_data_stats" in names
+    assert "archi_get_agent_spec" in names
+
+
+def test_list_documents_passes_conversation_and_enabled_filter(wrapper):
+    wrapper.chat.data_viewer.list_documents.return_value = {
+        "documents": [
+            {
+                "hash": "doc-1",
+                "display_name": "Doc 1",
+                "source_type": "git",
+                "ingestion_status": "embedded",
+                "enabled": False,
+            }
+        ],
+        "total": 1,
+    }
+
+    result = mcp_sse._tool_list_documents(
+        {
+            "conversation_id": 42,
+            "enabled": "disabled",
+            "limit": 10,
+            "offset": 5,
+        },
+        wrapper,
+    )
+
+    wrapper.chat.data_viewer.list_documents.assert_called_once_with(
+        conversation_id=42,
+        source_type=None,
+        search=None,
+        enabled_filter="disabled",
+        limit=10,
+        offset=5,
+    )
+    assert "enabled=no" in _tool_text(result)
+
+
+def test_search_document_metadata_parses_or_filters(wrapper):
+    catalog = wrapper.chat.data_viewer.catalog
+    catalog.search_metadata.return_value = [
+        {
+            "hash": "hash-1",
+            "path": Path("/tmp/docs/readme.md"),
+            "metadata": {
+                "display_name": "README",
+                "source_type": "git",
+                "relative_path": "docs/readme.md",
+            },
+        }
+    ]
+
+    result = mcp_sse._tool_search_document_metadata(
+        {"query": "source_type:git OR source_type:web outage", "limit": 7},
+        wrapper,
+    )
+
+    catalog.search_metadata.assert_called_once_with(
+        "outage",
+        limit=7,
+        filters=[{"source_type": "git"}, {"source_type": "web"}],
+    )
+    assert "README" in _tool_text(result)
+    assert "hash-1" in _tool_text(result)
+
+
+def test_list_metadata_schema_formats_distinct_values(wrapper):
+    catalog = wrapper.chat.data_viewer.catalog
+    catalog.get_distinct_metadata.return_value = {
+        "source_type": ["git", "web"],
+        "suffix": [".md", ".py"],
+    }
+
+    result = mcp_sse._tool_list_metadata_schema(wrapper)
+    text = _tool_text(result)
+
+    assert "source_type values: git, web" in text
+    assert "suffix values: .md, .py" in text
+    assert "relative_path" in text
+
+
+def test_search_document_content_greps_indexed_files(wrapper, tmp_path):
+    catalog = wrapper.chat.data_viewer.catalog
+    doc_path = tmp_path / "example.log"
+    doc_path.write_text("alpha\nneedle here\nomega\n", encoding="utf-8")
+
+    catalog.iter_files.return_value = [("hash-1", doc_path)]
+    catalog.get_metadata_for_hash.return_value = {
+        "display_name": "Example Log",
+        "source_type": "local_files",
+    }
+
+    fake_loader_utils = types.ModuleType("src.data_manager.vectorstore.loader_utils")
+    fake_loader_utils.load_text_from_path = lambda path: Path(path).read_text(encoding="utf-8")
+
+    with patch.dict(sys.modules, {"src.data_manager.vectorstore.loader_utils": fake_loader_utils}):
+        result = mcp_sse._tool_search_document_content(
+            {"query": "needle", "before": 1, "after": 1},
+            wrapper,
+        )
+
+    text = _tool_text(result)
+    assert "Example Log" in text
+    assert "L2: needle here" in text
+    assert "B: alpha" in text
+    assert "A: omega" in text
+
+
+def test_get_document_chunks_paginates_and_truncates(wrapper):
+    wrapper.chat.data_viewer.get_document_chunks.return_value = [
+        {"index": 0, "text": "a" * 120, "start_char": 0, "end_char": 119},
+        {"index": 1, "text": "b" * 120, "start_char": 120, "end_char": 239},
+        {"index": 2, "text": "c" * 120, "start_char": 240, "end_char": 359},
+    ]
+
+    result = mcp_sse._tool_get_document_chunks(
+        {
+            "document_hash": "hash-1",
+            "offset": 1,
+            "limit": 1,
+            "max_chars_per_chunk": 80,
+        },
+        wrapper,
+    )
+
+    text = _tool_text(result)
+    assert "showing 1 from offset 1" in text
+    assert "chunk 1" in text
+    assert "chars=120-239" in text
+    assert "..." in text
+
+
+def test_get_data_stats_formats_source_breakdown(wrapper):
+    wrapper.chat.data_viewer.get_stats.return_value = {
+        "total_documents": 12,
+        "total_chunks": 34,
+        "enabled_documents": 10,
+        "disabled_documents": 2,
+        "total_size_bytes": 2048,
+        "last_sync": "2026-03-17T12:00:00+00:00",
+        "status_counts": {"pending": 1, "embedding": 2, "embedded": 8, "failed": 1},
+        "by_source_type": {
+            "git": {"total": 5, "enabled": 4},
+            "web": {"total": 7, "enabled": 6},
+        },
+    }
+
+    result = mcp_sse._tool_get_data_stats({"conversation_id": 99}, wrapper)
+    wrapper.chat.data_viewer.get_stats.assert_called_once_with(99)
+
+    text = _tool_text(result)
+    assert "Total documents: 12" in text
+    assert "git: total=5, enabled=4" in text
+    assert "web: total=7, enabled=6" in text
+
+
+def test_get_agent_spec_returns_full_markdown(wrapper, tmp_path):
+    agents_dir = tmp_path / "agents"
+    agents_dir.mkdir()
+    agent_path = agents_dir / "ops.md"
+    content = (
+        "---\n"
+        "name: Ops Agent\n"
+        "tools:\n"
+        " - search_vectorstore_hybrid\n"
+        "---\n\n"
+        "You help with ops questions.\n"
+    )
+    agent_path.write_text(content, encoding="utf-8")
+    wrapper._get_agents_dir.return_value = agents_dir
+
+    result = mcp_sse._tool_get_agent_spec({"agent_name": "Ops Agent"}, wrapper)
+
+    assert _tool_text(result) == content
+
+
+def test_deployment_info_includes_active_agent_and_mcp_servers(wrapper):
+    wrapper.chat.agent_spec = types.SimpleNamespace(name="Fallback Agent")
+
+    fake_static = types.SimpleNamespace(
+        available_providers=["openai", "anthropic"],
+        available_pipelines=["QAPipeline", "CMSCompOpsAgent"],
+    )
+    fake_dynamic = types.SimpleNamespace(
+        active_agent_name="Configured Agent",
+        active_model="openai/gpt-4o",
+        temperature=0.2,
+        max_tokens=2048,
+        num_documents_to_retrieve=6,
+        use_hybrid_search=True,
+        bm25_weight=0.4,
+        semantic_weight=0.6,
+    )
+
+    with patch("src.utils.config_access.get_full_config") as mock_full, \
+            patch("src.utils.config_access.get_static_config") as mock_static, \
+            patch("src.utils.config_access.get_dynamic_config") as mock_dynamic:
+        mock_full.return_value = {
+            "name": "demo",
+            "services": {
+                "chat_app": {"pipeline": "QAPipeline", "agent_class": "CMSCompOpsAgent"},
+                "data_manager": {"embedding": {"model": "text-embedding-3-small", "chunk_size": 800, "chunk_overlap": 100}},
+                "mcp_server": {"enabled": True},
+            },
+            "mcp_servers": {"deepwiki": {}, "search": {}},
+        }
+        mock_static.return_value = fake_static
+        mock_dynamic.return_value = fake_dynamic
+
+        result = mcp_sse._tool_deployment_info(wrapper)
+
+    text = _tool_text(result)
+    assert "Active agent: Configured Agent" in text
+    assert "MCP servers: deepwiki, search" in text
+    assert "Available providers: openai, anthropic" in text
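
Review note: the query syntax exercised by `test_search_document_metadata_parses_or_filters` (key:value clauses joined by `OR`, with remaining tokens treated as free text) can be sketched as a small standalone parser. This is a hypothetical illustration of the expected behaviour; `parse_metadata_query` is not a function in this PR, and the real parsing lives inside `mcp_sse._tool_search_document_metadata`.

```python
import re
from typing import Any


def parse_metadata_query(query: str) -> tuple[str, list[dict[str, Any]]]:
    """Split a metadata query into OR-ed key:value filters plus free text.

    Hypothetical sketch: 'source_type:git OR source_type:web outage'
    should yield free text 'outage' and one filter dict per clause.
    """
    filters: list[dict[str, Any]] = []
    free_terms: list[str] = []
    # 'OR' separates alternative key:value clauses; anything without a
    # colon is treated as a free-text search term.
    for clause in re.split(r"\s+OR\s+", query):
        for token in clause.split():
            if ":" in token:
                key, _, value = token.partition(":")
                filters.append({key: value})
            else:
                free_terms.append(token)
    return " ".join(free_terms), filters
```

Under this reading, the test's call with `"source_type:git OR source_type:web outage"` produces free text `"outage"` and filters `[{"source_type": "git"}, {"source_type": "web"}]`, matching the `search_metadata` assertion above.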