diff --git a/pyproject.toml b/pyproject.toml index f5136f334..bca725fc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,8 @@ dependencies = [ "pandas==2.3.2", "isort==6.0.1", "pre-commit>=4", - "psycopg2-binary==2.9.10" + "psycopg2-binary==2.9.10", + "docker==7.1.0" ] [project.scripts] diff --git a/requirements/requirements-base.txt b/requirements/requirements-base.txt index 39fcd0c43..410a862b9 100644 --- a/requirements/requirements-base.txt +++ b/requirements/requirements-base.txt @@ -85,3 +85,4 @@ aiohttp==3.9.5 nltk==3.9.1 sentence-transformers==5.1.2 rank_bm25==0.2.2 +docker==7.1.0 diff --git a/src/archi/pipelines/agents/base_react.py b/src/archi/pipelines/agents/base_react.py index 5bfe27157..39e131f23 100644 --- a/src/archi/pipelines/agents/base_react.py +++ b/src/archi/pipelines/agents/base_react.py @@ -25,6 +25,38 @@ logger = get_logger(__name__) +def _get_role_context() -> str: + """ + Get role context string for the current user if enabled. + + Requires SSO auth with auth_roles configured and pass_descriptions_to_agent: true. + Returns empty string if conditions not met or user not authenticated. + """ + try: + from flask import session, has_request_context + if not has_request_context(): + return "" + if not session.get('logged_in'): + return "" + + from src.utils.rbac.registry import get_registry + registry = get_registry() + + if not registry.pass_descriptions_to_agent: + return "" + + roles = session.get('roles', []) + if not roles: + return "" + + descriptions = registry.get_role_descriptions(roles) + if descriptions: + return f"\n\nUser roles: {descriptions}." 
+ return "" + except Exception as e: + logger.debug(f"Could not get role context: {e}") + return "" + class BaseReActAgent: """ BaseReActAgent provides a foundational structure for building pipeline classes that @@ -835,8 +867,20 @@ def refresh_agent( self._active_middleware = list(middleware) return self.agent + def _build_system_prompt(self) -> str: + """ + Build the full system prompt, appending role context if enabled. + + Role context is appended when SSO auth with auth_roles is configured + and pass_descriptions_to_agent is set to true. + """ + base_prompt = self.agent_prompt or "" + role_context = _get_role_context() + return base_prompt + role_context + def _create_agent(self, tools: Sequence[Callable], middleware: Sequence[Callable]) -> CompiledStateGraph: """Create the LangGraph agent with the specified LLM, tools, and system prompt.""" + system_prompt = self._build_system_prompt() logger.debug("Creating agent %s with:", self.__class__.__name__) logger.debug("%d tools", len(tools)) logger.debug("%d middleware components", len(middleware)) @@ -844,7 +888,7 @@ def _create_agent(self, tools: Sequence[Callable], middleware: Sequence[Callable model=self.agent_llm, tools=tools, middleware=middleware, - system_prompt=self.agent_prompt, + system_prompt=system_prompt, ) def _build_static_tools(self) -> List[Callable]: diff --git a/src/archi/pipelines/agents/cms_comp_ops_agent.py b/src/archi/pipelines/agents/cms_comp_ops_agent.py index a571b5791..1676ed8f8 100644 --- a/src/archi/pipelines/agents/cms_comp_ops_agent.py +++ b/src/archi/pipelines/agents/cms_comp_ops_agent.py @@ -13,6 +13,9 @@ create_metadata_search_tool, create_metadata_schema_tool, create_retriever_tool, + create_http_get_tool, + create_sandbox_tool, + initialize_mcp_client, RemoteCatalogClient, MONITOpenSearchClient, create_monit_opensearch_search_tool, @@ -170,6 +173,37 @@ def _build_fetch_tool(self) -> Callable: description=description, ) + http_get_tool = create_http_get_tool( + name="fetch_url", 
+ description=( + "Fetch live content from a URL via HTTP GET request. " + "Input: A valid HTTP or HTTPS URL. " + "Output: The response body text or an error message. " + "Use this to retrieve real-time data from web endpoints, APIs, documentation, or status pages. " + "Examples: checking endpoint status, fetching API data, retrieving documentation." + ), + timeout=15.0, + max_response_chars=600000, + ) + + all_tools = [file_search_tool, metadata_search_tool, metadata_schema_tool, fetch_tool, http_get_tool] + + # Add sandbox tool for code execution in isolated containers + sandbox_tool = create_sandbox_tool( + name="run_code", + description=( + "Execute code in a secure sandboxed Docker container. " + "Input: code (str), language ('python', 'bash', or 'sh'). " + "Output: stdout, stderr, exit code, and any files written to /workspace/output/. " + "The /workspace/ and /workspace/output/ directories are pre-created and writable — " + "do NOT call os.makedirs() for them. Do not show internal sandbox paths in the output. " + "Use this for running Python scripts, shell commands, data processing, API calls with curl, " + "rucio commands, or any code that needs to be executed safely. " + "The container is ephemeral and destroyed after execution." 
+ ), + ) + all_tools.append(sandbox_tool) + logger.info("Sandbox tool added to CMSCompOpsAgent") def _build_vector_tool_placeholder(self) -> List[Callable]: return [] diff --git a/src/archi/pipelines/agents/tools/__init__.py b/src/archi/pipelines/agents/tools/__init__.py index 32856b4fa..1da013a6a 100644 --- a/src/archi/pipelines/agents/tools/__init__.py +++ b/src/archi/pipelines/agents/tools/__init__.py @@ -1,3 +1,4 @@ +from .base import check_tool_permission, require_tool_permission from .local_files import ( create_document_fetch_tool, create_file_search_tool, @@ -12,8 +13,18 @@ create_monit_opensearch_search_tool, create_monit_opensearch_aggregation_tool, ) +from .http_get import create_http_get_tool +from .sandbox import ( + create_sandbox_tool, + create_sandbox_tool_with_files, + set_sandbox_context, + get_sandbox_artifacts, + clear_sandbox_context, +) __all__ = [ + "check_tool_permission", + "require_tool_permission", "create_document_fetch_tool", "create_file_search_tool", "create_metadata_search_tool", @@ -21,6 +32,12 @@ "RemoteCatalogClient", "create_retriever_tool", "initialize_mcp_client", + "create_http_get_tool", + "create_sandbox_tool", + "create_sandbox_tool_with_files", + "set_sandbox_context", + "get_sandbox_artifacts", + "clear_sandbox_context", "MONITOpenSearchClient", "create_monit_opensearch_search_tool", "create_monit_opensearch_aggregation_tool", diff --git a/src/archi/pipelines/agents/tools/base.py b/src/archi/pipelines/agents/tools/base.py index 23b4a6203..8306a51b8 100644 --- a/src/archi/pipelines/agents/tools/base.py +++ b/src/archi/pipelines/agents/tools/base.py @@ -1,8 +1,127 @@ -"""Abstract base class for all tools.""" +"""Base utilities and RBAC decorators for agent tools.""" + +from __future__ import annotations + +from functools import wraps +from typing import Callable, Optional, TypeVar -from typing import Callable from langchain.tools import tool +from src.utils.logging import get_logger + +logger = get_logger(__name__) + + +# 
Type variable for generic function signatures +F = TypeVar('F', bound=Callable) + + +def check_tool_permission(required_permission: str) -> tuple[bool, Optional[str]]: + """ + Check if the current user has permission to use a tool. + + Uses the Flask session to get user roles and checks against the RBAC registry. + This function is designed to fail open in non-web contexts (CLI, testing) + and when the RBAC system is not configured. + + Args: + required_permission: The permission string to check (e.g., 'tools:http_get') + + Returns: + (has_permission, error_message) tuple where: + - has_permission is True if access is granted + - error_message is None if granted, or a user-friendly error string if denied + """ + try: + from flask import session, has_request_context + from src.utils.rbac.registry import get_registry + + # If we're not in a request context, allow the tool (for testing/CLI usage) + if not has_request_context(): + logger.debug("No request context, allowing tool access") + return True, None + + # Get user roles from session + if not session.get('logged_in'): + logger.warning("User not logged in, denying tool access") + return False, "You must be logged in to use this feature." + + user_roles = session.get('roles', []) + + # Check permission using RBAC registry + try: + registry = get_registry() + if registry.has_permission(user_roles, required_permission): + logger.debug(f"User with roles {user_roles} granted permission '{required_permission}'") + return True, None + else: + logger.info(f"User with roles {user_roles} denied permission '{required_permission}'") + return False, ( + f"Permission denied: This tool requires '{required_permission}' permission. " + f"Your current role(s) ({', '.join(user_roles) if user_roles else 'none'}) " + "do not have access to this feature. Please contact an administrator " + "if you believe you should have access." 
+ ) + except Exception as e: + # If RBAC registry is not configured, log warning and allow access + logger.warning(f"RBAC registry not available, allowing tool access: {e}") + return True, None + + except ImportError as e: + # Flask not available (e.g., running outside web context) + logger.debug(f"Flask not available, allowing tool access: {e}") + return True, None + except Exception as e: + logger.error(f"Unexpected error checking tool permission: {e}") + # Fail open for unexpected errors to avoid breaking functionality + return True, None + + +def require_tool_permission(permission: Optional[str]) -> Callable[[F], F]: + """ + Decorator that enforces RBAC permission check before tool execution. + + This decorator wraps a tool function and checks if the current user + has the required permission before allowing the tool to execute. + If permission is denied, returns an error message instead of executing the tool. + + Args: + permission: The permission string required to use the tool (e.g., 'tools:http_get'). + If None, no permission check is performed (allow all). + + Returns: + A decorator function that wraps the tool with permission checking. + + Example: + @require_tool_permission("tools:http_get") + def _http_get_tool(url: str) -> str: + ... 
+ + Note: + - If permission is None, the decorator is a no-op (returns original function) + - Permission checks fail open in non-web contexts (CLI, testing) + - Permission checks fail open if RBAC registry is not configured + """ + def decorator(func: F) -> F: + if permission is None: + # No permission required, return original function + return func + + # Capture permission in closure for type checker (guaranteed non-None here) + required_perm: str = permission + + @wraps(func) + def wrapper(*args, **kwargs): + has_perm, error_msg = check_tool_permission(required_perm) + if not has_perm: + logger.warning(f"Tool '{func.__name__}' permission denied: {required_perm}") + return f"Error: {error_msg}" + return func(*args, **kwargs) + + return wrapper # type: ignore + + return decorator + def create_abstract_tool( *, diff --git a/src/archi/pipelines/agents/tools/http_get.py b/src/archi/pipelines/agents/tools/http_get.py new file mode 100644 index 000000000..bc853aac8 --- /dev/null +++ b/src/archi/pipelines/agents/tools/http_get.py @@ -0,0 +1,198 @@ +"""HTTP GET request tool for fetching live data from URLs.""" + +from __future__ import annotations + +from typing import Callable, Optional +from urllib.parse import urlparse, urlunparse + +import requests +from langchain.tools import tool + +from src.utils.logging import get_logger +from src.archi.pipelines.agents.tools.base import require_tool_permission + +logger = get_logger(__name__) + + +# Default permission required to use the HTTP GET tool +DEFAULT_REQUIRED_PERMISSION = "tools:http_get" + + +def _validate_url(url: str) -> tuple[bool, Optional[str]]: + """ + Validate that the URL is well-formed and uses HTTP/HTTPS. + + Returns: + (is_valid, error_message) tuple + """ + try: + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + return False, f"Invalid URL scheme '{parsed.scheme}'. Only HTTP and HTTPS are supported." + if not parsed.netloc: + return False, "Invalid URL: missing hostname." 
+ return True, None + except Exception as e: + return False, f"Invalid URL: {str(e)}" + + +def _sanitize_url_for_error(url: str) -> str: + """ + Remove credentials from URL for error messages. + + Example: http://user:pass@example.com -> http://***:***@example.com + """ + try: + parsed = urlparse(url) + if parsed.username or parsed.password: + sanitized_netloc = f"***:***@{parsed.hostname}" + (f":{parsed.port}" if parsed.port else "") + return urlunparse(( + parsed.scheme, + sanitized_netloc, + parsed.path, + parsed.params, + parsed.query, + parsed.fragment, + )) + return url + except Exception: + return "***" + + +def create_http_get_tool( + *, + name: str = "fetch_url", + description: Optional[str] = None, + timeout: float = 10.0, + max_response_chars: int = 40000, + required_permission: Optional[str] = DEFAULT_REQUIRED_PERMISSION, +) -> Callable[[str], str]: + """ + Create a LangChain tool that makes HTTP GET requests to fetch live data from URLs. + + This tool allows agents to retrieve real-time information from web endpoints, + APIs, or documentation URLs. Only GET requests are supported for security reasons. + + Args: + name: The name of the tool (used by the LLM when selecting tools). + description: Human-readable description of what the tool does. + If None, a default description is used. + timeout: Maximum time in seconds to wait for a response. Default is 10 seconds. + max_response_chars: Maximum number of characters to return from the response body. + Responses longer than this are truncated with a "[truncated]" indicator. + Default is 40000 characters. + required_permission: The RBAC permission required to use this tool. + Default is 'tools:http_get'. Set to None to disable permission checks. 
+ + Returns: + A callable LangChain tool that accepts a URL string and returns either: + - The response body text (truncated if needed) + - An error message describing what went wrong + + Example: + >>> from src.archi.pipelines.agents.tools import create_http_get_tool + >>> http_tool = create_http_get_tool( + ... name="fetch_endpoint", + ... description="Fetch data from a REST API endpoint", + ... timeout=15.0, + ... max_response_chars=60000, + ... ) + >>> # Add to agent's tool list + >>> tools = [retriever_tool, file_search_tool, http_tool] + + Security Notes: + - Only HTTP and HTTPS URLs are accepted + - Credentials in URLs are sanitized in error messages + - No authentication/authorization is built-in (use with public endpoints) + - Response size is limited to prevent context window overflow + - Timeouts prevent hanging on slow/unresponsive endpoints + - RBAC permission check is enforced at tool invocation time + + Error Handling: + The tool returns descriptive error strings rather than raising exceptions, + allowing the agent to handle failures gracefully and provide useful feedback + to the user. Common error cases: + - Permission denied (user lacks required RBAC permission) + - Invalid or malformed URLs + - Connection timeouts or failures + - HTTP error status codes (4xx, 5xx) + - Network errors + """ + tool_description = description or ( + "Fetch content from a URL via HTTP GET request.\n" + "Input: A valid HTTP or HTTPS URL string.\n" + "Output: The response body text (up to {max_chars} characters) or an error message.\n" + "Use this to retrieve live data from web endpoints, APIs, or documentation URLs.\n" + "Example input: 'https://example.com/api/status'\n" + "IMPORTANT: When using this tool, avoid providing general answers from your knowledge. " + "Instead, if you fail to retrieve the data, inform the user with the error message returned by this tool and ask if they would like a general answer instead." 
+ ).format(max_chars=max_response_chars) + + @tool(name, description=tool_description) + @require_tool_permission(required_permission) + def _http_get_tool(url: str) -> str: + """Fetch content from a URL via HTTP GET request.""" + # Validate URL + is_valid, error_msg = _validate_url(url) + if not is_valid: + logger.warning(f"HTTP GET tool received invalid URL: {_sanitize_url_for_error(url)}") + return f"Error: {error_msg}" + + # Make request with error handling + try: + logger.info(f"HTTP GET tool fetching: {_sanitize_url_for_error(url)}") + + response = requests.get( + url, + timeout=timeout, + allow_redirects=True, + ) + + # Check for authentication errors first + if response.status_code == 401: + logger.warning( + f"HTTP GET tool received 401 Unauthorized from {_sanitize_url_for_error(url)}" + ) + return ( + "Error: HTTP 401: Unauthorized. This endpoint requires authentication, " + "but the HTTP GET tool does not support authentication credentials. " + "Please use a public endpoint or provide the user with alternative access methods." + ) + + # Check for other HTTP errors (4xx, 5xx) + if response.status_code >= 400: + logger.warning( + f"HTTP GET tool received status {response.status_code} from {_sanitize_url_for_error(url)}" + ) + status_text = response.reason or "Error" + return f"Error: HTTP {response.status_code}: {status_text}" + + # Success - return response text (truncated if needed) + response_text = response.text + if len(response_text) > max_response_chars: + truncated = response_text[:max_response_chars].rstrip() + logger.info( + f"HTTP GET tool truncated response from {len(response_text)} to {max_response_chars} chars" + ) + return f"{truncated}\n\n... 
[response truncated at {max_response_chars} characters]" + + logger.info(f"HTTP GET tool successfully fetched {len(response_text)} chars") + return response_text + + except requests.exceptions.Timeout: + logger.warning(f"HTTP GET tool timeout after {timeout}s: {_sanitize_url_for_error(url)}") + return f"Error: Request timed out after {timeout} seconds. The endpoint may be slow or unresponsive." + + except requests.exceptions.ConnectionError as e: + logger.warning(f"HTTP GET tool connection error: {_sanitize_url_for_error(url)} - {str(e)}") + return f"Error: Connection failed. The endpoint may be unreachable or the URL may be incorrect." + + except requests.exceptions.RequestException as e: + logger.warning(f"HTTP GET tool request error: {_sanitize_url_for_error(url)} - {str(e)}") + return f"Error: Request failed - {type(e).__name__}. Please check the URL and try again." + + except Exception as e: + logger.error(f"HTTP GET tool unexpected error: {_sanitize_url_for_error(url)} - {str(e)}") + return f"Error: An unexpected error occurred while fetching the URL." + + return _http_get_tool diff --git a/src/archi/pipelines/agents/tools/http_get_example.py b/src/archi/pipelines/agents/tools/http_get_example.py new file mode 100644 index 000000000..80204270f --- /dev/null +++ b/src/archi/pipelines/agents/tools/http_get_example.py @@ -0,0 +1,50 @@ +""" +Example usage of the HTTP GET tool in an agent pipeline. + +This demonstrates how to add the HTTP GET tool to an agent's tool list. 
+""" + +# Example 1: Basic usage with defaults +from src.archi.pipelines.agents.tools import create_http_get_tool + +http_tool = create_http_get_tool() +# Tool is now ready to be added to an agent's tool list + +# Example 2: Customized configuration +http_tool_custom = create_http_get_tool( + name="fetch_api_data", + description="Fetch real-time data from REST API endpoints", + timeout=15.0, + max_response_chars=8000, +) + +# Example 3: Adding to an agent (pseudocode based on cms_comp_ops_agent.py pattern) +""" +class MyAgent(BaseReActAgent): + def _build_static_tools(self) -> List[Callable]: + # Existing tools + retriever_tool = create_retriever_tool(...) + file_search_tool = create_file_search_tool(...) + + # Add HTTP GET tool + http_tool = create_http_get_tool( + name="fetch_url", + description="Fetch live data from web endpoints or APIs", + timeout=10.0, + max_response_chars=4000, + ) + + return [retriever_tool, file_search_tool, http_tool] +""" + +# Example 4: Tool invocation (when used by the agent) +""" +When the LLM decides to use the tool, it will call: +result = http_tool.invoke("https://example.com/api/status") + +Possible returns: +- Success: "{'status': 'ok', 'version': '1.0'}" +- Error: "Error: Request timed out after 10.0 seconds..." +- Error: "Error: HTTP 404: Not Found" +- Error: "Error: Invalid URL scheme 'ftp'. Only HTTP and HTTPS are supported." 
+""" diff --git a/src/archi/pipelines/agents/tools/local_files.py b/src/archi/pipelines/agents/tools/local_files.py index f0b81a139..54b8b2048 100644 --- a/src/archi/pipelines/agents/tools/local_files.py +++ b/src/archi/pipelines/agents/tools/local_files.py @@ -14,6 +14,7 @@ from src.utils.logging import get_logger from src.utils.env import read_secret +from src.archi.pipelines.agents.tools.base import require_tool_permission logger = get_logger(__name__) @@ -212,8 +213,20 @@ def create_file_search_tool( max_results: int = 3, window: int = 240, store_docs: Optional[Callable[[str, Sequence[Path]], None]] = None, + required_permission: Optional[str] = None, ) -> Callable[[str], str]: - """Create a LangChain tool that performs keyword search in catalogued files.""" + """Create a LangChain tool that performs keyword search in catalogued files. + + Args: + catalog: The RemoteCatalogClient instance. + name: The name of the tool. + description: Human-readable description of the tool. + max_results: Maximum number of results to return. + window: Context window size for snippets. + store_docs: Optional callback to store retrieved documents. + required_permission: Optional RBAC permission required to use this tool. + If None, no permission check is performed (allow all). + """ _default_description = ( "Grep-like search over local document contents only (not filenames/paths).\n" @@ -228,6 +241,7 @@ def create_file_search_tool( ) @tool(name, description=tool_description) + @require_tool_permission(required_permission) def _search_local_files( query: str, regex: bool = False, @@ -300,8 +314,19 @@ def create_metadata_search_tool( description: Optional[str] = None, max_results: int = 5, store_docs: Optional[Callable[[str, Sequence[Path]], None]] = None, + required_permission: Optional[str] = None, ) -> Callable[[str], str]: - """Create a LangChain tool to search resource metadata catalogues.""" + """Create a LangChain tool to search resource metadata catalogues. 
+ + Args: + catalog: The RemoteCatalogClient instance. + name: The name of the tool. + description: Human-readable description of the tool. + max_results: Maximum number of results to return. + store_docs: Optional callback to store retrieved documents. + required_permission: Optional RBAC permission required to use this tool. + If None, no permission check is performed (allow all). + """ tool_description = ( description @@ -318,6 +343,7 @@ def create_metadata_search_tool( ) @tool(name, description=tool_description) + @require_tool_permission(required_permission) def _search_metadata(query: str) -> str: if not query.strip(): return "Please provide a non-empty search query." @@ -363,8 +389,17 @@ def create_metadata_schema_tool( *, name: str = "list_metadata_schema", description: Optional[str] = None, + required_permission: Optional[str] = None, ) -> Callable[[], str]: - """Create a tool that returns supported metadata keys and distinct values.""" + """Create a tool that returns supported metadata keys and distinct values. + + Args: + catalog: The RemoteCatalogClient instance. + name: The name of the tool. + description: Human-readable description of the tool. + required_permission: Optional RBAC permission required to use this tool. + If None, no permission check is performed (allow all). + """ tool_description = ( description @@ -375,6 +410,7 @@ def create_metadata_schema_tool( ) @tool(name, description=tool_description) + @require_tool_permission(required_permission) def _schema_tool() -> str: try: payload = catalog.schema() @@ -399,8 +435,18 @@ def create_document_fetch_tool( name: str = "fetch_catalog_document", description: Optional[str] = None, default_max_chars: int = 4000, + required_permission: Optional[str] = None, ) -> Callable[..., str]: - """Create a LangChain tool to fetch a full document by resource hash.""" + """Create a LangChain tool to fetch a full document by resource hash. + + Args: + catalog: The RemoteCatalogClient instance. 
+ name: The name of the tool. + description: Human-readable description of the tool. + default_max_chars: Default maximum characters to return. + required_permission: Optional RBAC permission required to use this tool. + If None, no permission check is performed (allow all). + """ tool_description = ( description @@ -413,6 +459,7 @@ def create_document_fetch_tool( ) @tool(name, description=tool_description) + @require_tool_permission(required_permission) def _fetch_document(resource_hash: str, max_chars: int = default_max_chars) -> str: if not resource_hash.strip(): return "Please provide a non-empty resource hash." diff --git a/src/archi/pipelines/agents/tools/retriever.py b/src/archi/pipelines/agents/tools/retriever.py index ea675ba4c..19d0e6e73 100644 --- a/src/archi/pipelines/agents/tools/retriever.py +++ b/src/archi/pipelines/agents/tools/retriever.py @@ -7,6 +7,7 @@ from langchain_core.retrievers import BaseRetriever from src.utils.logging import get_logger +from src.archi.pipelines.agents.tools.base import require_tool_permission logger = get_logger(__name__) @@ -66,6 +67,7 @@ def create_retriever_tool( max_documents: int = 4, max_chars: int = 800, store_docs: Optional[Callable[[str, Sequence[Document]], None]] = None, + required_permission: Optional[str] = None, ) -> Callable[[str], str]: """ Wrap a `BaseRetriever` instance in a LangChain tool. @@ -74,6 +76,16 @@ def create_retriever_tool( so the calling agent can ground its responses in the vector store content. If ``store_docs`` is provided, it will be invoked with the tool name and the list of retrieved ``Document`` objects before formatting the response. + + Args: + retriever: The LangChain retriever instance to wrap. + name: The name of the tool. + description: Human-readable description of the tool. + max_documents: Maximum number of documents to return. + max_chars: Maximum characters per document snippet. + store_docs: Optional callback to store retrieved documents. 
+ required_permission: Optional RBAC permission required to use this tool. + If None, no permission check is performed (allow all). """ tool_description = ( @@ -87,6 +99,7 @@ def create_retriever_tool( ) @tool(name, description=tool_description) + @require_tool_permission(required_permission) def _retriever_tool(query: str) -> str: results = retriever.invoke(query) docs = _normalize_results(results or []) diff --git a/src/archi/pipelines/agents/tools/sandbox.py b/src/archi/pipelines/agents/tools/sandbox.py new file mode 100644 index 000000000..c48f67f1b --- /dev/null +++ b/src/archi/pipelines/agents/tools/sandbox.py @@ -0,0 +1,774 @@ +""" +Sandbox code execution tool for agent pipelines. + +This module provides the create_sandbox_tool function that creates a LangChain-compatible +tool for executing code in isolated Docker containers. +""" + +from __future__ import annotations + +import base64 +import json +import os +import re +import threading +from pathlib import Path +from typing import Callable, Dict, List, Optional + +from langchain.tools import tool +from pydantic import BaseModel, Field + +from src.utils.logging import get_logger +from src.archi.pipelines.agents.tools.base import require_tool_permission + +logger = get_logger(__name__) + + +# Default permission required to use the sandbox tool +DEFAULT_REQUIRED_PERMISSION = "tools:sandbox" + +# --------------------------------------------------------------------------- +# Per-request context for sandbox artifact persistence. +# +# We use a module-level dictionary keyed by trace_id instead of thread-local +# or contextvars, because LangChain/LangGraph may execute tools in different +# threads. The trace_id is set as an environment variable which IS inherited +# by child threads. 
+# --------------------------------------------------------------------------- +_sandbox_contexts: Dict[str, Dict] = {} # trace_id -> {data_path, artifacts} +_sandbox_lock = threading.Lock() + +# Environment variable name for passing trace_id to tools +_TRACE_ID_ENV = "_ARCHI_SANDBOX_TRACE_ID" + +# Environment variable for conversation_id (needed for approval requests) +_CONVERSATION_ID_ENV = "_ARCHI_SANDBOX_CONVERSATION_ID" + +# Environment variable for session-level approval mode override +_APPROVAL_MODE_ENV = "_ARCHI_SANDBOX_APPROVAL_MODE" + +# Safe filename pattern +_SAFE_FILENAME_RE = re.compile(r"^[\w\-. ]+$") + +# Type for approval request callbacks +ApprovalRequestCallback = Callable[[Dict], None] + + +def set_sandbox_context( + trace_id: str, + data_path: str, + conversation_id: Optional[int] = None, + approval_callback: Optional[ApprovalRequestCallback] = None, + approval_mode_override: Optional[str] = None, +) -> None: + """ + Set the sandbox context for the current request (call before streaming). + + This stores context in a module-level dict and sets an environment variable + so that the trace_id can be retrieved from any thread. + + Args: + trace_id: Unique identifier for this request. + data_path: Path where artifacts should be stored. + conversation_id: The conversation ID for approval requests. + approval_callback: Optional callback to notify about approval requests. + approval_mode_override: Optional session-level approval mode ("auto" or "manual"). 
+ """ + with _sandbox_lock: + _sandbox_contexts[trace_id] = { + "data_path": data_path, + "artifacts": [], + "conversation_id": conversation_id, + "approval_callback": approval_callback, + "approval_mode_override": approval_mode_override, + } + # Set env vars so child threads can find the trace_id and conversation_id + os.environ[_TRACE_ID_ENV] = trace_id + if conversation_id is not None: + os.environ[_CONVERSATION_ID_ENV] = str(conversation_id) + if approval_mode_override is not None: + os.environ[_APPROVAL_MODE_ENV] = approval_mode_override + logger.debug( + "Sandbox context set: trace_id=%s, data_path=%s, conversation_id=%s, approval_mode=%s", + trace_id, data_path, conversation_id, approval_mode_override + ) + + +def get_sandbox_artifacts() -> List[Dict]: + """Return artifact metadata collected during the current request.""" + trace_id = os.environ.get(_TRACE_ID_ENV) + if not trace_id: + return [] + with _sandbox_lock: + ctx = _sandbox_contexts.get(trace_id) + return ctx["artifacts"] if ctx else [] + + +def clear_sandbox_context() -> None: + """Clear sandbox context (call after consuming artifacts).""" + trace_id = os.environ.pop(_TRACE_ID_ENV, None) + os.environ.pop(_CONVERSATION_ID_ENV, None) + os.environ.pop(_APPROVAL_MODE_ENV, None) + if trace_id: + with _sandbox_lock: + _sandbox_contexts.pop(trace_id, None) + logger.debug("Sandbox context cleared: trace_id=%s", trace_id) + + +def _get_sandbox_context() -> tuple: + """Get current trace_id, data_path, or (None, None) if not set.""" + trace_id = os.environ.get(_TRACE_ID_ENV) + if not trace_id: + return None, None + with _sandbox_lock: + ctx = _sandbox_contexts.get(trace_id) + if ctx: + return trace_id, ctx["data_path"] + return None, None + + +def _get_full_sandbox_context() -> Dict: + """ + Get the full sandbox context including conversation_id and approval_callback. + + Returns: + Dict with trace_id, data_path, conversation_id, approval_callback, approval_mode_override + or empty dict if not set. 
+ """ + trace_id = os.environ.get(_TRACE_ID_ENV) + if not trace_id: + return {} + with _sandbox_lock: + ctx = _sandbox_contexts.get(trace_id) + if ctx: + return { + "trace_id": trace_id, + "data_path": ctx.get("data_path"), + "conversation_id": ctx.get("conversation_id"), + "approval_callback": ctx.get("approval_callback"), + "approval_mode_override": ctx.get("approval_mode_override"), + } + return {} + + +def _sanitize_filename(filename: str) -> str: + """Sanitize a filename for safe storage.""" + name = os.path.basename(filename).strip() + if not name or not _SAFE_FILENAME_RE.match(name): + # Generate a safe fallback name + ext = Path(filename).suffix if filename else ".bin" + name = f"output{ext}" + return name + + +def _persist_sandbox_file(filename: str, mimetype: str, content_base64: str) -> Optional[Dict]: + """ + Persist a sandbox-generated file directly to disk. + + Returns artifact metadata dict with url, or None if context not available. + """ + trace_id, data_path = _get_sandbox_context() + + if not trace_id or not data_path: + logger.warning("Sandbox context not set - cannot persist file %s", filename) + return None + + # Validate trace_id format (UUID) + if not re.fullmatch(r"[0-9a-f\-]{36}", trace_id): + logger.error("Invalid trace_id format: %s", trace_id) + return None + + try: + # Create artifact directory + artifact_dir = Path(data_path) / "sandbox_artifacts" / trace_id + artifact_dir.mkdir(parents=True, exist_ok=True) + + # Sanitize and deduplicate filename + safe_name = _sanitize_filename(filename) + dest = artifact_dir / safe_name + counter = 1 + while dest.exists(): + stem, ext = os.path.splitext(safe_name) + dest = artifact_dir / f"{stem}_{counter}{ext}" + counter += 1 + + # Decode and write + raw = base64.b64decode(content_base64) + dest.write_bytes(raw) + + # Build URL + url = f"/api/sandbox-artifacts/{trace_id}/{dest.name}" + + artifact = { + "filename": dest.name, + "mimetype": mimetype, + "url": url, + "size": len(raw), + } + + # Store 
def _get_effective_approval_mode(config_approval_mode: "ApprovalMode") -> "ApprovalMode":
    """
    Resolve the approval mode to use for this execution.

    Priority:
        1. Session-level override read from the sandbox context
           (``approval_mode_override``), when it names a known mode.
        2. Config-level setting (from deployment config).

    Args:
        config_approval_mode: The approval mode from the effective config.

    Returns:
        The approval mode to use for this execution. An unrecognised override
        is logged and ignored in favour of ``config_approval_mode``.
    """
    from src.utils.sandbox.config import ApprovalMode

    override = _get_full_sandbox_context().get("approval_mode_override")
    if not override:
        return config_approval_mode

    # The override is normally a string, but tolerate an enum-like object
    # (anything exposing ``.value``) and stray whitespace/case instead of
    # raising AttributeError on ``.lower()``.
    normalized = str(getattr(override, "value", override)).strip().lower()
    known_modes = {
        "auto": ApprovalMode.AUTO,
        "manual": ApprovalMode.MANUAL,
    }
    mode = known_modes.get(normalized)
    if mode is not None:
        logger.debug("Using session-level approval_mode override: %s", normalized)
        return mode

    logger.warning(
        "Invalid approval_mode_override '%s', falling back to config: %s",
        override, config_approval_mode.value
    )
    return config_approval_mode


def _request_approval_and_wait(
    code: str,
    language: str,
    image: str,
    tool_call_id: str,
    timeout_seconds: float = 300.0,
) -> tuple[bool, Optional[str]]:
    """
    Request user approval for sandbox execution and wait for the response.

    An approval request is created, the context's ``approval_callback`` (if
    any) is invoked with the request dict so the streaming layer can surface
    it, and the call then blocks in ``wait_for_approval`` until the request
    is resolved or the timeout elapses.

    The function fails closed: if the sandbox context lacks a trace id or
    conversation id, no request can be created and execution is denied
    rather than silently approved.

    Args:
        code: The code to execute.
        language: Programming language.
        image: Docker image.
        tool_call_id: The tool call ID.
        timeout_seconds: How long to wait for approval.

    Returns:
        Tuple of ``(approved, error_message)``. ``error_message`` is ``None``
        when approved and a human-readable reason otherwise.
    """
    from src.utils.sandbox.approval import (
        ApprovalStatus,
        create_approval_request,
        wait_for_approval,
    )

    ctx = _get_full_sandbox_context()
    trace_id = ctx.get("trace_id")
    conversation_id = ctx.get("conversation_id")
    approval_callback = ctx.get("approval_callback")

    logger.debug(
        "Approval context: trace_id=%s, conversation_id=%s, has_callback=%s",
        trace_id, conversation_id, approval_callback is not None
    )

    if not trace_id or conversation_id is None:
        logger.warning(
            "Cannot request approval: missing context (trace_id=%s, conversation_id=%s)",
            trace_id, conversation_id
        )
        # Manual approval was required by the caller. Without a context there
        # is nobody to ask, so deny instead of running unreviewed code.
        return False, "Approval is required but no approval context is available"

    # Create the approval request
    request = create_approval_request(
        trace_id=trace_id,
        conversation_id=conversation_id,
        code=code,
        language=language,
        image=image,
        tool_call_id=tool_call_id,
        timeout_seconds=timeout_seconds,
    )

    # Notify via callback if available (so the streaming layer can emit an event).
    # A failing callback is logged but does not abort the wait below.
    if approval_callback:
        try:
            logger.info("Invoking approval callback for request: %s", request.approval_id)
            approval_callback(request.to_dict())
            logger.info("Approval callback completed successfully")
        except Exception as e:
            logger.error("Approval callback failed: %s", e, exc_info=True)
    else:
        logger.warning("No approval_callback available - frontend will not receive approval_request event")

    # Wait for approval
    try:
        resolved = wait_for_approval(
            request.approval_id,
            timeout=timeout_seconds,
            poll_interval=0.5,
        )
    except Exception as e:
        logger.error("Error waiting for approval: %s", e, exc_info=True)
        return False, f"Error waiting for approval: {e}"

    status = resolved.status
    if status == ApprovalStatus.APPROVED:
        logger.info("Sandbox execution approved: %s", request.approval_id)
        return True, None
    if status == ApprovalStatus.REJECTED:
        logger.info("Sandbox execution rejected: %s", request.approval_id)
        return False, "Execution rejected by user"
    if status == ApprovalStatus.EXPIRED:
        logger.info("Sandbox approval expired: %s", request.approval_id)
        return False, "Approval request timed out"
    if status == ApprovalStatus.CANCELLED:
        logger.info("Sandbox approval cancelled: %s", request.approval_id)
        return False, "Approval request cancelled"
    # Any other status is treated as a denial.
    return False, f"Unknown approval status: {status}"


class SandboxInput(BaseModel):
    """Input schema for the sandbox execution tool."""

    # Source text handed to the sandbox executor.
    code: str = Field(
        description="The code to execute. Should be complete, runnable code."
    )
    # Interpreter selector; defaults to Python.
    language: str = Field(
        default="python",
        description="Programming language: 'python', 'bash', or 'sh'. Default is 'python'."
    )
    # Optional image override; None means "use the configured default image".
    image: Optional[str] = Field(
        default=None,
        description="Docker image to use. If not specified, uses the default image. "
                    "Must be in the allowed images list."
    )
def _persist_and_count_files(files, parts: list) -> tuple:
    """Persist sandbox output files; return ``(images, saved_other, skipped)`` counts.

    Small UTF-8 text files that were persisted are additionally appended to
    *parts* so the agent can read their content.
    """
    image_count = 0
    saved_count = 0
    skipped_count = 0

    for f in files:
        logger.info(
            "Processing file: %s, mimetype=%s, size=%d, truncated=%s, has_content=%s",
            f.filename, f.mimetype, f.size, f.truncated, bool(f.content_base64),
        )

        # Truncated payloads are incomplete and empty payloads have nothing
        # to write, so neither is persisted.
        if f.truncated or not f.content_base64:
            skipped_count += 1
            continue

        # Persist file directly to disk (images and other files)
        artifact = _persist_sandbox_file(f.filename, f.mimetype, f.content_base64)
        if not artifact:
            skipped_count += 1
            continue

        mimetype = f.mimetype or ""
        if mimetype.startswith("image/"):
            image_count += 1
            continue

        saved_count += 1
        # Include small text files inline for agent context
        if mimetype.startswith("text/") and f.size < 5000:
            try:
                content = base64.b64decode(f.content_base64).decode("utf-8")
            except ValueError:
                # binascii.Error and UnicodeDecodeError are both ValueError
                # subclasses; inlining is best-effort, the file is already saved.
                logger.debug("Could not inline text file %s", f.filename, exc_info=True)
            else:
                parts.append(f"\n**Generated text file:**\n```\n{content}\n```")

    return image_count, saved_count, skipped_count


def _format_output_for_agent(result) -> str:
    """
    Format a sandbox result as structured text for the agent.

    Anything that is not a ``SandboxResult`` is returned as ``str(result)``.
    For a real result the text contains, in order: an execution error (which
    ends the report early), exit code, execution time, stdout with every line
    mentioning ``/workspace`` removed, stderr, a truncation warning, and a
    summary of generated files. Generated files with content are persisted
    via ``_persist_sandbox_file`` as a side effect.

    Only files that were actually persisted are reported as saved; truncated,
    empty, or failed-to-persist files are reported separately.

    Args:
        result: The value returned by the sandbox executor.

    Returns:
        The formatted, newline-joined report.
    """
    from src.utils.sandbox.executor import SandboxResult

    if not isinstance(result, SandboxResult):
        return str(result)

    parts = []

    # Handle system errors: nothing else in the result is reported in that case.
    if result.error:
        parts.append(f"**Execution Error**: {result.error}")
        if result.timed_out:
            parts.append("The execution was terminated due to timeout.")
        return "\n".join(parts)

    # Exit code
    if result.exit_code != 0:
        parts.append(f"**Exit Code**: {result.exit_code} (non-zero indicates an error in the code)")
    else:
        parts.append("**Exit Code**: 0 (success)")

    # Execution time
    parts.append(f"**Execution Time**: {result.execution_time:.2f}s")

    # Stdout - filter out container paths to prevent agent from echoing them
    if result.stdout:
        kept_lines = [
            line for line in result.stdout.split('\n')
            if '/workspace' not in line.lower()
        ]
        filtered_stdout = '\n'.join(kept_lines).strip()

        if filtered_stdout:
            parts.append("\n**Standard Output**:")
            parts.append("```")
            parts.append(filtered_stdout)
            parts.append("```")
        else:
            parts.append("\n**Standard Output**: (output contained only file path info, omitted)")
    else:
        parts.append("\n**Standard Output**: (empty)")

    # Stderr (reported unfiltered)
    if result.stderr:
        parts.append("\n**Standard Error**:")
        parts.append("```")
        parts.append(result.stderr)
        parts.append("```")

    # Truncation warning
    if result.truncated:
        parts.append("\n⚠️ Output was truncated due to size limits.")

    # Generated files - keep output minimal to avoid agent echoing details
    if result.files:
        logger.info("Formatting %d output file(s) for agent", len(result.files))
        image_count, saved_count, skipped_count = _persist_and_count_files(result.files, parts)

        # Summarize what was generated without exposing filenames
        if image_count:
            parts.append(
                f"\n**Generated output:** {image_count} image(s) will be displayed to the user "
                f"automatically below your response. Do NOT mention file paths or filenames."
            )
        if saved_count and not image_count:
            parts.append(f"\n**Generated output:** {saved_count} file(s) saved.")
        if skipped_count:
            parts.append(
                f"\n**Note:** {skipped_count} generated file(s) could not be saved "
                f"(truncated, empty, or failed to persist)."
            )

    return "\n".join(parts)
description="Execute Python or bash code in a sandbox", + ... ) + >>> # Add to agent's tool list + >>> tools = [retriever_tool, sandbox_tool] + + Security Notes: + - Code runs in ephemeral Docker containers that are destroyed after execution + - Containers have resource limits (CPU, memory, time) + - Only images from the configured allowlist can be used + - RBAC permission check is enforced at tool invocation time + - Containers are isolated from the host and internal services + + Output: + The tool returns a structured text output containing: + - Exit code + - Execution time + - Standard output (stdout) + - Standard error (stderr) + - Summary of generated files (if any) + """ + tool_description = description or ( + "Execute code in an isolated sandbox container.\n" + "Input: JSON with 'code' (required), 'language' (optional: python/bash/sh), " + "'image' (optional: Docker image from allowlist).\n" + "Output: Execution results including stdout, stderr, and exit code.\n" + "\n" + "Use this tool to:\n" + "- Run Python scripts for data analysis and plotting\n" + "- Execute shell commands\n" + "- Process data and generate outputs\n" + "\n" + "For plots, save to /workspace/output/. The directories are pre-created.\n" + "Example: plt.savefig('/workspace/output/plot.png')\n" + "\n" + "CRITICAL RULES FOR YOUR RESPONSE TO THE USER:\n" + "1. NEVER mention /workspace/, /workspace/output/, or any container paths\n" + "2. NEVER include filenames like 'plot.png' or 'File: xyz.png' in your response\n" + "3. Images are displayed AUTOMATICALLY - just describe what you plotted\n" + "4. Say things like 'Here is the plot' or 'The chart below shows...'\n" + "5. 
DO NOT echo the stdout if it contains paths - summarize the results instead\n" + "\n" + "Example input: {\"code\": \"print('Hello')\", \"language\": \"python\"}" + ) + + @tool(name, description=tool_description, args_schema=SandboxInput) + @require_tool_permission(required_permission) + def _sandbox_tool(code: str, language: str = "python", image: Optional[str] = None) -> str: + """Execute code in an isolated sandbox container.""" + + # Import here to avoid circular imports and allow lazy loading + from src.utils.sandbox import ( + SandboxExecutor, + get_sandbox_config, + resolve_effective_config, + ) + from src.utils.sandbox.config import ApprovalMode + + # Load base config + base_config = get_sandbox_config() + + if not base_config.enabled: + logger.warning("Sandbox tool invoked but sandbox is not enabled") + return "Error: Sandbox execution is not enabled for this deployment." + + # Get role overrides and resolve effective config + role_overrides, _ = _get_user_role_overrides() + effective_config = resolve_effective_config(base_config, role_overrides) + + # Validate image against effective allowlist + target_image = image or effective_config.default_image + if not effective_config.is_image_allowed(target_image): + allowed = ", ".join(effective_config.image_allowlist) + logger.warning(f"User requested disallowed image '{target_image}'") + return ( + f"Error: Image '{target_image}' is not allowed for your role.\n" + f"Allowed images: {allowed}" + ) + + # Check approval mode and request approval if needed + # Session-level override takes priority over config + effective_approval_mode = _get_effective_approval_mode(effective_config.approval_mode) + if effective_approval_mode == ApprovalMode.MANUAL: + # Generate a tool_call_id for tracking + import uuid + tool_call_id = str(uuid.uuid4()) + + approved, error_msg = _request_approval_and_wait( + code=code, + language=language, + image=target_image, + tool_call_id=tool_call_id, + timeout_seconds=effective_config.timeout 
def create_sandbox_tool_with_files(
    *,
    name: str = "execute_code_with_files",
    description: Optional[str] = None,
    required_permission: Optional[str] = DEFAULT_REQUIRED_PERMISSION,
    return_files: bool = True,
) -> Callable:
    """
    Build a sandbox tool whose result is a JSON document rather than prose.

    Use this variant when the caller needs programmatic access to what the
    run produced (for example to render generated images in chat).

    Args:
        name: Name under which the tool is registered.
        description: Tool description; a built-in default is used when falsy.
        required_permission: RBAC permission passed to ``require_tool_permission``.
        return_files: When False, each entry of the result's ``files`` list
            has its ``content_base64`` set to ``None`` before serialisation.

    Returns:
        The decorated tool callable. Every outcome, including errors and
        declined approvals, is returned as a JSON string.
    """
    if description:
        tool_description = description
    else:
        tool_description = (
            "Execute code in a sandbox and return structured results including generated files.\n"
            "Returns JSON with stdout, stderr, exit_code, execution_time, and files array."
        )

    @tool(name, description=tool_description, args_schema=SandboxInput)
    @require_tool_permission(required_permission)
    def _sandbox_tool_with_files(
        code: str,
        language: str = "python",
        image: Optional[str] = None
    ) -> str:
        """Execute code and return structured output with files."""

        # Imported lazily, inside the tool body, as in the sibling sandbox tool.
        from src.utils.sandbox import (
            SandboxExecutor,
            get_sandbox_config,
            resolve_effective_config,
        )
        from src.utils.sandbox.config import ApprovalMode

        deployment_cfg = get_sandbox_config()
        if not deployment_cfg.enabled:
            return json.dumps({"error": "Sandbox execution is not enabled"})

        # Layer the caller's role overrides (if any) on top of the deployment config.
        overrides, _ignored_error = _get_user_role_overrides()
        cfg = resolve_effective_config(deployment_cfg, overrides)

        chosen_image = image if image else cfg.default_image
        if not cfg.is_image_allowed(chosen_image):
            rejection = {
                "error": f"Image '{chosen_image}' is not allowed",
                "allowed_images": cfg.image_allowlist,
            }
            return json.dumps(rejection)

        # Session-level override takes priority over config when deciding
        # whether a human has to approve this run first.
        if _get_effective_approval_mode(cfg.approval_mode) == ApprovalMode.MANUAL:
            import uuid

            approved, error_msg = _request_approval_and_wait(
                code=code,
                language=language,
                image=chosen_image,
                tool_call_id=str(uuid.uuid4()),
                timeout_seconds=cfg.timeout * 10,
            )
            if not approved:
                declined = {
                    "error": f"Execution cancelled: {error_msg or 'User did not approve'}",
                    "approval_rejected": True,
                }
                return json.dumps(declined)

        try:
            run_result = SandboxExecutor(config=cfg).execute(
                code=code,
                language=language,
                image=chosen_image,
                timeout=cfg.timeout,
                limits=cfg.resource_limits,
            )
            payload = run_result.to_dict()

            # Optionally strip file content to reduce size
            if not return_files:
                for file_entry in payload.get("files", []):
                    file_entry["content_base64"] = None

            return json.dumps(payload)

        except Exception as e:
            logger.error(f"Sandbox execution failed: {e}", exc_info=True)
            return json.dumps({"error": str(e)})

    return _sandbox_tool_with_files
required_secrets | default([]) -%} {{ secret.upper() }}_FILE: /run/secrets/{{ secret.lower() }} {% endfor %} @@ -190,6 +200,29 @@ services: {%- endif %} {%- endif %} + {% if sandbox_enabled -%} + # Docker-in-Docker service for sandbox code execution + sandbox-dind: + image: docker:27-dind + container_name: {{ name }}-sandbox-dind + privileged: true + environment: + DOCKER_TLS_CERTDIR: "" # Disable TLS for internal communication + volumes: + - sandbox-dind-data:/var/lib/docker + command: ["dockerd", "--host=tcp://0.0.0.0:2375", "--host=unix:///var/run/docker.sock"] + logging: + options: + max-size: 10m + restart: always + {% if host_mode -%} + # In host mode, expose DinD on a specific port + ports: + - "127.0.0.1:2375:2375" + network_mode: bridge # DinD needs bridge mode even in host deployments + {% endif %} + {%- endif %} + {% if grafana_enabled -%} grafana: image: {{ grafana_image }}:{{ grafana_tag }} @@ -631,6 +664,9 @@ volumes: {{ volume }}: external: true {% endfor %} + {% if sandbox_enabled -%} + sandbox-dind-data: + {% endif %} {% if required_secrets %} secrets: diff --git a/src/cli/templates/base-config.yaml b/src/cli/templates/base-config.yaml index e520012cf..0c91cff9f 100644 --- a/src/cli/templates/base-config.yaml +++ b/src/cli/templates/base-config.yaml @@ -108,6 +108,9 @@ services: scope: {{ services.chat_app.auth.sso.client_kwargs.scope | default("openid profile email", true) }} basic: enabled: {{ services.chat_app.auth.basic.enabled | default(false, true) }} + {% if services.chat_app.auth.auth_roles is defined %} + auth_roles: {{ services.chat_app.auth.auth_roles | tojson }} + {% endif %} data_manager: auth: enabled: {{ services.data_manager.auth.enabled | default(false) }} @@ -265,3 +268,121 @@ data_manager: {%- endfor %} email_pattern: '{{ data_manager.utils.anonymizer.email_pattern | default("[\\w\\.-]+@[\\w\\.-]+\\.\\w+") }}' username_pattern: '{{ data_manager.utils.anonymizer.username_pattern | default("\\[~[^\\]]+\\]") }}' + +archi: + 
pipelines: {{ archi.pipelines | default(['QAPipeline'], true) }} + agent_description: {{ archi.agent_description | default('No description provided', true) }} + + # Sandbox configuration for containerized code execution + sandbox: + enabled: {{ archi.sandbox.enabled | default(false, true) }} + default_image: {{ archi.sandbox.default_image | default("python:3.11-slim", true) }} + image_allowlist: + {%- for image in archi.sandbox.image_allowlist | default(["python:3.11-slim"]) %} + - "{{ image }}" + {%- endfor %} + timeout: {{ archi.sandbox.timeout | default(30, true) }} + max_timeout: {{ archi.sandbox.max_timeout | default(300, true) }} + resource_limits: + memory: {{ archi.sandbox.resource_limits.memory | default("256m", true) }} + cpu: {{ archi.sandbox.resource_limits.cpu | default(0.5, true) }} + pids_limit: {{ archi.sandbox.resource_limits.pids_limit | default(100, true) }} + network_enabled: {{ archi.sandbox.network_enabled | default(true, true) }} + output_max_chars: {{ archi.sandbox.output_max_chars | default(100000, true) }} + # Custom Docker registry for private images + registry: + url: {{ archi.sandbox.registry.url | default("", true) }} + username_env: {{ archi.sandbox.registry.username_env | default("", true) }} + password_env: {{ archi.sandbox.registry.password_env | default("", true) }} + + providers: + openai: + enabled: {{ archi.providers.openai.enabled | default(true, true) }} + api_key_env: {{ archi.providers.openai.api_key_env | default("OPENAI_API_KEY", true) }} + base_url: {{ archi.providers.openai.base_url | default("", true) }} + default_model: {{ archi.providers.openai.default_model | default("gpt-4o", true) }} + models: + {%- for m in archi.providers.openai.models | default(["gpt-4o","gpt-4o-mini","gpt-3.5-turbo"]) %} + - {{ m }} + {%- endfor %} + anthropic: + enabled: {{ archi.providers.anthropic.enabled | default(false, true) }} + api_key_env: {{ archi.providers.anthropic.api_key_env | default("ANTHROPIC_API_KEY", true) }} + base_url: {{ 
archi.providers.anthropic.base_url | default("", true) }} + default_model: {{ archi.providers.anthropic.default_model | default("claude-3-opus-20240229", true) }} + models: + {%- for m in archi.providers.anthropic.models | default(["claude-3-opus-20240229","claude-3-sonnet-20240229"]) %} + - {{ m }} + {%- endfor %} + openrouter: + enabled: {{ archi.providers.openrouter.enabled | default(false, true) }} + api_key_env: {{ archi.providers.openrouter.api_key_env | default("OPENROUTER_API_KEY", true) }} + base_url: {{ archi.providers.openrouter.base_url | default("https://openrouter.ai/api/v1", true) }} + default_model: {{ archi.providers.openrouter.default_model | default("openai/gpt-4o-mini", true) }} + models: + {%- for m in archi.providers.openrouter.models | default(["openai/gpt-4o-mini","openai/gpt-4o"]) %} + - {{ m }} + {%- endfor %} + gemini: + enabled: {{ archi.providers.gemini.enabled | default(false, true) }} + api_key_env: {{ archi.providers.gemini.api_key_env | default("GEMINI_API_KEY", true) }} + base_url: {{ archi.providers.gemini.base_url | default("", true) }} + default_model: {{ archi.providers.gemini.default_model | default("gemini-1.5-pro", true) }} + models: + {%- for m in archi.providers.gemini.models | default(["gemini-1.5-pro","gemini-1.5-flash"]) %} + - {{ m }} + {%- endfor %} + local: + enabled: {{ archi.providers.local.enabled | default(true, true) }} + base_url: {{ archi.providers.local.base_url | default("http://localhost:11434", true) }} + mode: {{ archi.providers.local.mode | default("ollama", true) }} + default_model: {{ archi.providers.local.default_model | default("llama3.2", true) }} + models: + {%- for m in archi.providers.local.models | default(["llama3.2"]) %} + - {{ m }} + {%- endfor %} + + pipeline_map: + QAPipeline: + max_tokens: {{ archi.pipeline_map.QAPipeline.max_tokens | default(10000, true) }} + prompts: + required: + condense_prompt: {% if archi.pipeline_map.QAPipeline.prompts.required.condense_prompt %}"{{ 
archi.pipeline_map.QAPipeline.prompts.required.condense_prompt }}"{% else %}null{% endif %} + chat_prompt: {% if archi.pipeline_map.QAPipeline.prompts.required.chat_prompt %}"{{ archi.pipeline_map.QAPipeline.prompts.required.chat_prompt }}"{% else %}null{% endif %} + models: + required: + condense_model: {{ archi.pipeline_map.QAPipeline.models.required.condense_model | default('openai/gpt-4o-mini', true) }} + chat_model: {{ archi.pipeline_map.QAPipeline.models.required.chat_model | default('openai/gpt-4o', true) }} + GradingPipeline: + max_tokens: {{ archi.pipeline_map.GradingPipeline.max_tokens | default(10000, true) }} + prompts: + required: + final_grade_prompt: {% if archi.pipeline_map.GradingPipeline.prompts.required.final_grade_prompt %}"{{ archi.pipeline_map.GradingPipeline.prompts.required.final_grade_prompt }}"{% else %}null{% endif %} + optional: + summary_prompt: {% if archi.pipeline_map.GradingPipeline.prompts.optional.summary_prompt %}"{{ archi.pipeline_map.GradingPipeline.prompts.optional.summary_prompt }}"{% else %}null{% endif %} + analysis_prompt: {% if archi.pipeline_map.GradingPipeline.prompts.optional.analysis_prompt %}"{{ archi.pipeline_map.GradingPipeline.prompts.optional.analysis_prompt }}"{% else %}null{% endif %} + models: + required: + final_grade_model: {{ archi.pipeline_map.GradingPipeline.models.required.final_grade_model | default('openai/gpt-4o', true) }} + optional: + summary_model: {{ archi.pipeline_map.GradingPipeline.models.optional.summary_model | default('openai/gpt-4o', true) }} + analysis_model: {{ archi.pipeline_map.GradingPipeline.models.optional.analysis_model | default('openai/gpt-4o', true) }} + ImageProcessingPipeline: + max_tokens: {{ archi.pipeline_map.ImageProcessingPipeline.max_tokens | default(10000, true) }} + prompts: + required: + image_processing_prompt: {% if archi.pipeline_map.ImageProcessingPipeline.prompts.required.image_processing_prompt %}"{{ 
archi.pipeline_map.ImageProcessingPipeline.prompts.required.image_processing_prompt }}"{% else %}null{% endif %} + models: + required: + image_processing_model: {{ archi.pipeline_map.ImageProcessingPipeline.models.required.image_processing_model | default('openai/gpt-4o', true) }} + CMSCompOpsAgent: + recursion_limit: {{ archi.pipeline_map.CMSCompOpsAgent.recursion_limit | default(100, true) }} + prompts: + required: + agent_prompt: {% if archi.pipeline_map.CMSCompOpsAgent.prompts.required.agent_prompt %}"{{ archi.pipeline_map.CMSCompOpsAgent.prompts.required.agent_prompt }}"{% else %}null{% endif %} + models: + required: + agent_model: {{ archi.pipeline_map.CMSCompOpsAgent.models.required.agent_model | default('openai/gpt-4o', true) }} + + chain_update_time: {{ archi.chain_update_time | default(10, true) }} + mcp_servers: {{ archi.mcp_servers | default({}, true) }} diff --git a/src/interfaces/chat_app/app.py b/src/interfaces/chat_app/app.py index 9d52f9a48..caf5dee40 100644 --- a/src/interfaces/chat_app/app.py +++ b/src/interfaces/chat_app/app.py @@ -62,6 +62,18 @@ from src.interfaces.chat_app.document_utils import * from src.interfaces.chat_app.utils import collapse_assistant_sequences +# RBAC imports for role-based access control +from src.utils.rbac import ( + get_registry, + get_user_roles, + has_permission, + require_permission, + require_any_permission, + require_authenticated, +) +from src.utils.rbac.permissions import get_permission_context +from src.utils.rbac.audit import log_authentication_event + logger = get_logger(__name__) @@ -1532,6 +1544,11 @@ def stream( trace_events: List[Dict[str, Any]] = [] tool_call_count = 0 stream_start_time = time.time() + # Sandbox artifact handling (set if import succeeds) + sandbox_context_set = False + get_sandbox_artifacts = None + format_artifacts_markdown = None + clear_sandbox_context = None try: context, error_code = self._prepare_chat_context( @@ -1579,6 +1596,89 @@ def stream( 
pipeline_name=self.archi.pipeline_name if hasattr(self.archi, 'pipeline_name') else None, ) + # Queue for events from the agent stream (including approval requests) + # Threading is needed because tools may block waiting for approval, + # but we still need to send approval_request events to the frontend. + import queue + import threading + event_queue = queue.Queue() + stream_done = threading.Event() + stream_exception = [None] # Use list to allow mutation from thread + + def approval_callback(approval_request_dict): + """Called by sandbox tool when approval is needed.""" + logger.info("Approval callback invoked, putting request in queue: %s", approval_request_dict.get("approval_id")) + event_queue.put(("approval", approval_request_dict)) + + # Get session-level approval mode override (if user has set one) + session_approval_mode = session.get("sandbox_approval_mode") + + # Set up sandbox context so artifacts are persisted directly to disk + try: + from src.archi.pipelines.agents.tools.sandbox import ( + set_sandbox_context, + get_sandbox_artifacts as _get_sandbox_artifacts, + clear_sandbox_context as _clear_sandbox_context, + ) + from src.interfaces.chat_app.sandbox_artifacts import ( + format_artifacts_markdown as _format_artifacts_markdown, + ) + set_sandbox_context( + trace_id, + self.data_path, + conversation_id=context.conversation_id, + approval_callback=approval_callback, + approval_mode_override=session_approval_mode, + ) + sandbox_context_set = True + # Store references for use outside this try block + get_sandbox_artifacts = _get_sandbox_artifacts + format_artifacts_markdown = _format_artifacts_markdown + clear_sandbox_context = _clear_sandbox_context + except Exception as e: + logger.debug("Could not set sandbox context: %s", e) + get_sandbox_artifacts = None + format_artifacts_markdown = None + + # Run agent stream in background thread so we can yield approval + # requests while tools are blocking waiting for user response. 
+ def run_stream(): + try: + for output in self.archi.stream(history=context.history, conversation_id=context.conversation_id): + event_queue.put(("output", output)) + except Exception as e: + stream_exception[0] = e + finally: + stream_done.set() + + stream_thread = threading.Thread(target=run_stream, daemon=True) + stream_thread.start() + + # Consume events from queue, yielding to SSE stream + while not stream_done.is_set() or not event_queue.empty(): + try: + event_type_tag, event_data = event_queue.get(timeout=0.1) + except queue.Empty: + continue + + # Handle approval requests immediately + if event_type_tag == "approval": + logger.info("Yielding approval_request event to SSE stream: %s", event_data.get("approval_id")) + yield { + "type": "approval_request", + "approval_id": event_data.get("approval_id"), + "code": event_data.get("code"), + "language": event_data.get("language"), + "image": event_data.get("image"), + "tool_call_id": event_data.get("tool_call_id"), + "timeout_seconds": event_data.get("timeout_seconds", 300), + "conversation_id": context.conversation_id, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + continue + + # Handle agent output + output = event_data for output in self.archi.stream(history=context.history, conversation_id=context.conversation_id): if client_timeout and time.time() - stream_start_time > client_timeout: if trace_id: @@ -1730,6 +1830,10 @@ def stream( } timestamps["chain_finished_ts"] = datetime.now() + + # Check if stream thread raised an exception + if stream_exception[0] is not None: + raise stream_exception[0] if last_output is None: if trace_id: @@ -1796,6 +1900,29 @@ def stream( total_duration_ms=total_duration_ms, ) + # Retrieve sandbox-generated artifact metadata and append + # markdown links. The frontend fetches files from the + # /api/sandbox-artifacts// route. + # Files are already persisted to disk by the sandbox tool itself. 
+ if sandbox_context_set and get_sandbox_artifacts is not None: + try: + artifacts = get_sandbox_artifacts() + logger.info( + "Sandbox artifacts check: found %d artifact(s), trace_id=%s", + len(artifacts) if artifacts else 0, + trace_id, + ) + if artifacts and format_artifacts_markdown is not None: + markdown_suffix = format_artifacts_markdown(artifacts) + logger.info( + "Appending %d artifact(s) as markdown: %s", + len(artifacts), + [a['filename'] for a in artifacts], + ) + output += markdown_suffix + except Exception as e: + logger.warning("Failed to retrieve sandbox artifacts: %s", e, exc_info=True) + yield { "type": "final", "response": output, @@ -1823,6 +1950,12 @@ def stream( cancelled_by='user', cancellation_reason='Stream cancelled by client', ) + # Cancel any pending sandbox approvals for this trace + try: + from src.utils.sandbox.approval import cancel_approvals_for_trace + cancel_approvals_for_trace(trace_id) + except Exception as e: + logger.debug("Could not cancel sandbox approvals: %s", e) raise except ConversationAccessError as exc: logger.warning("Unauthorized conversation access attempt: %s", exc) @@ -1847,6 +1980,12 @@ def stream( ) yield {"type": "error", "status": 500, "message": "server error; see chat logs for message"} finally: + # Clean up sandbox context if it was set + if sandbox_context_set and clear_sandbox_context is not None: + try: + clear_sandbox_context() + except Exception: + pass if self.cursor is not None: self.cursor.close() if self.conn is not None: @@ -1935,7 +2074,8 @@ def __init__(self, app, **configs): self.add_endpoint('/terms', 'terms', self.require_auth(self.terms)) self.add_endpoint('/api/like', 'like', self.require_auth(self.like), methods=["POST"]) self.add_endpoint('/api/dislike', 'dislike', self.require_auth(self.dislike), methods=["POST"]) - self.add_endpoint('/api/update_config', 'update_config', self.require_auth(self.update_config), methods=["POST"]) + # Config modification requires config:modify permission 
(archi-expert or archi-admins) + self.add_endpoint('/api/update_config', 'update_config', self.require_perm('config:modify')(self.update_config), methods=["POST"]) self.add_endpoint('/api/get_configs', 'get_configs', self.require_auth(self.get_configs), methods=["GET"]) self.add_endpoint('/api/text_feedback', 'text_feedback', self.require_auth(self.text_feedback), methods=["POST"]) @@ -1958,6 +2098,53 @@ def __init__(self, app, **configs): self.add_endpoint('/api/trace/message/', 'get_trace_by_message', self.require_auth(self.get_trace_by_message), methods=["GET"]) self.add_endpoint('/api/cancel_stream', 'cancel_stream', self.require_auth(self.cancel_stream), methods=["POST"]) + # Sandbox artifact serving + self.add_endpoint( + '/api/sandbox-artifacts//', + 'serve_sandbox_artifact', + self.require_auth(self.serve_sandbox_artifact), + methods=["GET"], + ) + + # Sandbox approval endpoints + logger.info("Adding sandbox approval API endpoints") + self.add_endpoint( + '/api/sandbox/approval/', + 'sandbox_get_approval', + self.require_auth(self.sandbox_get_approval), + methods=["GET"], + ) + self.add_endpoint( + '/api/sandbox/approval//approve', + 'sandbox_approve', + self.require_auth(self.sandbox_approve), + methods=["POST"], + ) + self.add_endpoint( + '/api/sandbox/approval//reject', + 'sandbox_reject', + self.require_auth(self.sandbox_reject), + methods=["POST"], + ) + self.add_endpoint( + '/api/sandbox/approvals/pending', + 'sandbox_pending_approvals', + self.require_auth(self.sandbox_pending_approvals), + methods=["GET"], + ) + self.add_endpoint( + '/api/sandbox/config', + 'sandbox_get_config', + self.require_auth(self.sandbox_get_config), + methods=["GET"], + ) + self.add_endpoint( + '/api/sandbox/approval-mode', + 'sandbox_set_approval_mode', + self.require_auth(self.sandbox_set_approval_mode), + methods=["POST", "GET"], + ) + # Provider endpoints logger.info("Adding provider API endpoints") self.add_endpoint('/api/providers', 'get_providers', 
self.require_auth(self.get_providers), methods=["GET"]) @@ -1976,16 +2163,19 @@ def __init__(self, app, **configs): self.add_endpoint('/api/agents/active', 'set_active_agent', self.require_auth(self.set_active_agent), methods=["POST"]) # Data viewer endpoints + # View data page and list documents - requires documents:view permission + # Enable/disable documents - requires documents:select permission logger.info("Adding data viewer API endpoints") - self.add_endpoint('/data', 'data_viewer', self.require_auth(self.data_viewer_page)) - self.add_endpoint('/api/data/documents', 'list_data_documents', self.require_auth(self.list_data_documents), methods=["GET"]) - self.add_endpoint('/api/data/documents//content', 'get_data_document_content', self.require_auth(self.get_data_document_content), methods=["GET"]) - self.add_endpoint('/api/data/documents//chunks', 'get_data_document_chunks', self.require_auth(self.get_data_document_chunks), methods=["GET"]) - self.add_endpoint('/api/data/documents//enable', 'enable_data_document', self.require_auth(self.enable_data_document), methods=["POST"]) - self.add_endpoint('/api/data/documents//disable', 'disable_data_document', self.require_auth(self.disable_data_document), methods=["POST"]) - self.add_endpoint('/api/data/bulk-enable', 'bulk_enable_documents', self.require_auth(self.bulk_enable_documents), methods=["POST"]) - self.add_endpoint('/api/data/bulk-disable', 'bulk_disable_documents', self.require_auth(self.bulk_disable_documents), methods=["POST"]) - self.add_endpoint('/api/data/stats', 'get_data_stats', self.require_auth(self.get_data_stats), methods=["GET"]) + + self.add_endpoint('/data', 'data_viewer', self.require_perm('documents:view')(self.data_viewer_page)) + self.add_endpoint('/api/data/documents', 'list_data_documents', self.require_perm('documents:view')(self.list_data_documents), methods=["GET"]) + self.add_endpoint('/api/data/documents//content', 'get_data_document_content', 
self.require_perm('documents:view')(self.get_data_document_content), methods=["GET"]) + self.add_endpoint('/api/data/documents//chunks', 'get_data_document_chunks', self.require_perm('documents:view')(self.get_data_document_chunks), methods=["GET"]) + self.add_endpoint('/api/data/documents//enable', 'enable_data_document', self.require_perm('documents:select')(self.enable_data_document), methods=["POST"]) + self.add_endpoint('/api/data/documents//disable', 'disable_data_document', self.require_perm('documents:select')(self.disable_data_document), methods=["POST"]) + self.add_endpoint('/api/data/bulk-enable', 'bulk_enable_documents', self.require_perm('documents:select')(self.bulk_enable_documents), methods=["POST"]) + self.add_endpoint('/api/data/bulk-disable', 'bulk_disable_documents', self.require_perm('documents:select')(self.bulk_disable_documents), methods=["POST"]) + self.add_endpoint('/api/data/stats', 'get_data_stats', self.require_perm('documents:view')(self.get_data_stats), methods=["GET"]) # Data uploader endpoints logger.info("Adding data uploader API endpoints") @@ -2017,10 +2207,35 @@ def __init__(self, app, **configs): self.add_endpoint('/login', 'login', self.login, methods=['GET', 'POST']) self.add_endpoint('/logout', 'logout', self.logout) self.add_endpoint('/auth/user', 'get_user', self.get_user, methods=['GET']) + self.add_endpoint('/api/permissions', 'get_permissions', self.get_permissions, methods=['GET']) + self.add_endpoint('/api/permissions/check', 'check_permission', self.check_permission_endpoint, methods=['POST']) + if self.sso_enabled: self.add_endpoint('/redirect', 'sso_callback', self.sso_callback) + def _set_user_session(self, email: str, name: str, username: str, user_id: str = '', auth_method: str = 'sso', roles: list = None): + """Set user session with well-defined structure.""" + session['user'] = { + 'email': email, + 'name': name, + 'username': username, + 'id': user_id + } + session['logged_in'] = True + session['auth_method'] 
def logout(self):
    """Unified logout endpoint for all auth methods.

    Captures the identity of the departing user for the audit trail, removes
    every per-user key from the session, records the logout event, and
    redirects to the landing page.

    Returns:
        A redirect response to the ``landing`` endpoint.
    """
    # Read identity *before* clearing so the audit record is meaningful.
    auth_method = session.get('auth_method', 'unknown')
    user_email = self._get_session_user_email() or 'unknown'
    user_roles = session.get('roles', [])

    # Clear all per-user session data. 'sandbox_approval_mode' is the
    # session-level override written by the sandbox approval-mode endpoint;
    # it must not outlive the login that chose it, otherwise the next user
    # on the same browser session would inherit e.g. an "auto" approval mode.
    for key in ('user', 'logged_in', 'auth_method', 'roles', 'sandbox_approval_mode'):
        session.pop(key, None)

    # Log logout event
    log_authentication_event(
        user=user_email,
        event_type='logout',
        success=True,
        method=auth_method,
        details=f"Previous roles: {user_roles}"
    )

    logger.info("User %s logged out (method: %s)", user_email, auth_method)
    flash('You have been logged out successfully')
    return redirect(url_for('landing'))
def require_auth(self, f):
    """Decorator to require authentication for routes.

    Behaviour of the wrapped view:

    * authentication disabled (``self.auth_enabled`` falsy): the view runs;
    * session is logged in: the view runs;
    * no login, SSO enabled and the RBAC registry reports
      ``allow_anonymous`` false: the attempt is audit-logged and the client
      is redirected to the ``login`` endpoint;
    * otherwise: a JSON ``401 Unauthorized`` response is returned.

    Args:
        f: The view function to protect.

    Returns:
        The wrapped view function (metadata preserved via ``functools.wraps``).
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if not self.auth_enabled:
            # If auth is not enabled, allow access
            return f(*args, **kwargs)

        # Diagnostics for the sandbox approval endpoints. Kept at DEBUG (one
        # lazily-formatted record) so production logs are not flooded with
        # session/cookie key names on every sandbox request.
        if 'sandbox' in request.path:
            logger.debug(
                "Sandbox auth check: path=%s method=%s logged_in=%s "
                "session_keys=%s cookies=%s",
                request.path,
                request.method,
                session.get('logged_in'),
                list(session.keys()),
                list(request.cookies.keys()),
            )

        if session.get('logged_in'):
            return f(*args, **kwargs)

        # Not logged in: with SSO enabled and anonymous access blocked,
        # send the client to the login page instead of answering 401.
        # NOTE(review): this redirect also applies to /api/* routes, so
        # XHR/fetch callers receive a 302 rather than a JSON 401 - confirm
        # the frontend handles that.
        if self.sso_enabled and not get_registry().allow_anonymous:
            log_authentication_event(
                user='anonymous',
                event_type='anonymous_redirect',
                success=False,
                method='web',
                details=f"path={request.path}, method={request.method}"
            )
            # Redirect to login page which will trigger SSO
            return redirect(url_for('login'))

        # Return 401 Unauthorized response for API requests
        return jsonify({'error': 'Unauthorized', 'message': 'Authentication required'}), 401

    return decorated_function
def check_permission_endpoint(self):
    """API endpoint to check if user has a specific permission.

    Expects a JSON object body of the form ``{"permission": "<name>"}``.

    Returns:
        * ``401`` JSON when the session is not logged in;
        * ``400`` JSON when the body is missing, is not valid JSON, is not an
          object, or does not carry a non-empty string ``permission``;
        * otherwise JSON with the permission name, whether the session's
          roles grant it, the session's roles, and the roles the registry
          reports as granting it.
    """
    if not session.get('logged_in'):
        return jsonify({
            'error': 'Authentication required',
            'has_permission': False
        }), 401

    # silent=True: a wrong Content-Type or malformed body yields None instead
    # of raising, so the caller always receives the JSON 400 below rather
    # than a framework-generated error page.
    data = request.get_json(silent=True)
    # Guard against non-object bodies (e.g. a JSON list), which would pass
    # an ``in`` test but fail on key lookup.
    permission = data.get('permission') if isinstance(data, dict) else None
    if not isinstance(permission, str) or not permission:
        return jsonify({
            'error': 'Permission name required',
            'has_permission': False
        }), 400

    roles = session.get('roles', [])
    result = has_permission(permission, roles)

    # Get which roles would grant this permission
    registry = get_registry()
    roles_with_permission = registry.get_roles_with_permission(permission)

    return jsonify({
        'permission': permission,
        'has_permission': result,
        'user_roles': roles,
        'roles_with_permission': roles_with_permission
    })
def sandbox_approve(self, approval_id: str):
    """
    Approve a sandbox execution request.

    URL params:
    - approval_id: UUID of the approval request.

    Returns:
        JSON with updated approval status; ``404`` JSON when no request with
        that id exists; ``500`` JSON when the resolution call returns a
        falsy result.
    """
    from src.utils.sandbox.approval import resolve_approval, get_approval_request

    # NOTE(review): only existence is checked here - nothing ties the
    # approval to the caller's conversation or user. Confirm whether any
    # authenticated user is meant to be able to approve any request id.
    if not get_approval_request(approval_id):
        return jsonify({"error": "Approval request not found"}), 404

    # Identify the approver. The session may hold *empty* username/email
    # strings (e.g. SSO claims missing), so chain with ``or`` to make the
    # 'unknown' fallback effective for empty values as well as missing keys.
    user_info = session.get('user', {})
    username = user_info.get('username') or user_info.get('email') or 'unknown'

    updated = resolve_approval(approval_id, approved=True, resolved_by=username)
    if not updated:
        return jsonify({"error": "Failed to update approval"}), 500

    logger.info(
        "Sandbox execution approved: approval_id=%s by user=%s",
        approval_id, username
    )

    return jsonify(updated.to_dict())
def sandbox_set_approval_mode(self):
    """
    Get or set the session-level sandbox approval mode.

    GET: Returns the current session approval mode preference.
    POST: Sets the session approval mode preference.

    Request body (POST):
    {
        "mode": "auto" | "manual" | "default"
    }

    ``"default"`` removes the session override so the deployment
    configuration applies again.

    Returns:
        JSON with the current/updated approval mode preference, or a ``400``
        JSON error when the POST body does not carry one of the accepted
        mode strings.
    """
    from src.utils.sandbox.config import get_sandbox_config

    # Session key for storing user's approval mode preference
    SESSION_KEY = "sandbox_approval_mode"

    if request.method == "GET":
        # Get current session preference (or fallback to config default)
        session_mode = session.get(SESSION_KEY)
        config = get_sandbox_config()

        return jsonify({
            "session_mode": session_mode,  # None means using deployment default
            "effective_mode": session_mode or config.approval_mode.value,
            "default_mode": config.approval_mode.value,
        })

    # POST - set the approval mode.
    # silent=True keeps a wrong Content-Type / malformed body from raising;
    # the isinstance checks keep non-object bodies and non-string modes
    # (e.g. {"mode": null}) on the 400 path instead of crashing on .get/.lower.
    data = request.get_json(silent=True)
    raw_mode = data.get("mode") if isinstance(data, dict) else None
    mode = raw_mode.lower() if isinstance(raw_mode, str) else ""

    if mode not in ("auto", "manual", "default"):
        return jsonify({
            "error": "Invalid mode. Must be 'auto', 'manual', or 'default'."
        }), 400

    if mode == "default":
        # Clear session override, use deployment config default
        session.pop(SESSION_KEY, None)
        config = get_sandbox_config()
        return jsonify({
            "session_mode": None,
            "effective_mode": config.approval_mode.value,
            "message": "Using deployment default approval mode",
        })

    # Set session preference.
    # NOTE(review): any caller that reaches this endpoint can select "auto";
    # confirm the deployment-level policy is still enforced where the
    # override is consumed.
    session[SESSION_KEY] = mode

    logger.info("User set sandbox approval mode to: %s", mode)

    return jsonify({
        "session_mode": mode,
        "effective_mode": mode,
        "message": f"Approval mode set to '{mode}' for this session",
    })
+""" + +from __future__ import annotations + +import mimetypes +import os +import re +from pathlib import Path +from typing import Dict, List, Tuple + +from src.utils.logging import get_logger + +logger = get_logger(__name__) + +# Allowed characters in filenames served back to clients (security). +_SAFE_FILENAME_RE = re.compile(r"^[\w\-. ]+$") + +# Mime-type families that the frontend can display inline. +_IMAGE_MIMETYPES = frozenset({ + "image/png", "image/jpeg", "image/gif", "image/webp", "image/svg+xml", +}) + + +def _artifacts_root(data_path: str) -> Path: + """Return the top-level artifacts directory, creating it if needed.""" + root = Path(data_path) / "sandbox_artifacts" + root.mkdir(parents=True, exist_ok=True) + return root + + +def _trace_dir(data_path: str, trace_id: str) -> Path: + """Return the per-trace artifact directory, creating it if needed.""" + # Validate trace_id looks like a UUID to prevent path traversal. + if not re.fullmatch(r"[0-9a-f\-]{36}", trace_id): + raise ValueError(f"Invalid trace_id: {trace_id!r}") + d = _artifacts_root(data_path) / trace_id + d.mkdir(parents=True, exist_ok=True) + return d + + +def _sanitize_filename(filename: str) -> str: + """ + Sanitize a filename for safe storage — strip path separators, reject + suspicious patterns. Returns the cleaned name or raises ValueError. + """ + name = os.path.basename(filename).strip() + if not name or not _SAFE_FILENAME_RE.match(name): + raise ValueError(f"Unsafe filename: {filename!r}") + return name + + +# ------------------------------------------------------------------------- +# Public API +# ------------------------------------------------------------------------- + +def format_artifacts_markdown(artifacts: List[Dict]) -> str: + """ + Build markdown text referencing persisted artifacts. + + Parameters + ---------- + artifacts : list[dict] + Artifact metadata dicts from ``get_sandbox_artifacts()``. Each dict + has keys ``filename``, ``mimetype``, and ``url``. 
+ + Returns + ------- + str + Markdown block with images and/or download links. + Images → ``![filename](url)`` + Other files → ``[📎 filename](url)`` + """ + if not artifacts: + return "" + + parts: List[str] = [] + for art in artifacts: + mime = art.get("mimetype", "") + fname = art["filename"] + url = art["url"] + + if mime in _IMAGE_MIMETYPES: + parts.append(f"![{fname}]({url})") + else: + parts.append(f"[📎 {fname}]({url})") + + return "\n\n" + "\n\n".join(parts) + + +# ------------------------------------------------------------------------- +# Flask route handler +# ------------------------------------------------------------------------- + +def serve_artifact(data_path: str, trace_id: str, filename: str) -> Tuple: + """ + Serve an artifact file from disk. + + Returns a ``(response_body, status, headers)`` tuple suitable for Flask, + or a JSON error tuple on failure. + """ + try: + safe_name = _sanitize_filename(filename) + except ValueError: + return ({"error": "invalid filename"}, 400) + + try: + tdir = _trace_dir(data_path, trace_id) + except ValueError: + return ({"error": "invalid trace_id"}, 400) + + filepath = tdir / safe_name + if not filepath.is_file(): + return ({"error": "not found"}, 404) + + # Resolve to prevent symlink escape + resolved = filepath.resolve() + if not str(resolved).startswith(str(tdir.resolve())): + return ({"error": "access denied"}, 403) + + mime = mimetypes.guess_type(safe_name)[0] or "application/octet-stream" + + # For images, allow inline display; for everything else, suggest download. 
+ disposition = "inline" if mime in _IMAGE_MIMETYPES else f'attachment; filename="{safe_name}"' + + data = resolved.read_bytes() + return ( + data, + 200, + { + "Content-Type": mime, + "Content-Disposition": disposition, + "Cache-Control": "private, max-age=3600", + "Content-Length": str(len(data)), + }, + ) diff --git a/src/interfaces/chat_app/static/chat.css b/src/interfaces/chat_app/static/chat.css index 7fafce5bf..977b4186e 100644 --- a/src/interfaces/chat_app/static/chat.css +++ b/src/interfaces/chat_app/static/chat.css @@ -465,6 +465,181 @@ body { color: var(--error-text); } +/* ----------------------------------------------------------------------------- + User Profile Widget (Bottom of Sidebar) + ----------------------------------------------------------------------------- */ +.user-profile-widget { + display: none; + border-top: 1px solid var(--border-color); + background: var(--bg-secondary); +} + +.user-profile-content { + display: flex; + align-items: center; + gap: 10px; + padding: 12px 16px; + cursor: pointer; + transition: background var(--transition-fast); +} + +.user-profile-content:hover { + background: var(--bg-hover); +} + +.user-avatar { + display: flex; + align-items: center; + justify-content: center; + width: 36px; + height: 36px; + border-radius: var(--radius-full); + background: var(--accent-light); + color: var(--accent); + flex-shrink: 0; +} + +.user-info { + flex: 1; + min-width: 0; + overflow: hidden; +} + +.user-name { + font-size: var(--text-sm); + font-weight: 500; + color: var(--text-primary); + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; +} + +.user-email { + font-size: var(--text-xs); + color: var(--text-tertiary); + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; +} + +.user-roles-toggle { + display: flex; + align-items: center; + justify-content: center; + width: 28px; + height: 28px; + border: none; + border-radius: 4px; + background: transparent; + color: var(--text-tertiary); + 
cursor: pointer; + transition: all var(--transition-fast); + flex-shrink: 0; +} + +.user-roles-toggle:hover { + background: var(--bg-tertiary); + color: var(--text-primary); +} + +.user-profile-widget.expanded .user-roles-toggle { + transform: rotate(180deg); +} + +.user-roles-panel { + max-height: 0; + overflow: hidden; + transition: max-height var(--transition-normal); + border-top: 1px solid transparent; +} + +.user-profile-widget.expanded .user-roles-panel { + max-height: 300px; + border-top-color: var(--border-color); +} + +.user-roles-header { + padding: 12px 16px 8px; + font-size: var(--text-xs); + font-weight: 600; + color: var(--text-tertiary); + text-transform: uppercase; + letter-spacing: 0.05em; +} + +.user-roles-list { + padding: 0 12px 8px; + display: flex; + flex-direction: column; + gap: 4px; +} + +.user-role-badge { + display: inline-flex; + align-items: center; + gap: 6px; + padding: 6px 12px; + border-radius: 6px; + background: var(--accent-light); + color: var(--accent); + font-size: var(--text-xs); + font-weight: 500; +} + +.user-role-badge svg { + width: 12px; + height: 12px; + flex-shrink: 0; +} + +/* Special role styling */ +.user-role-badge.role-admin { + background: rgba(239, 68, 68, 0.1); + color: #dc2626; +} + +:root[data-theme="dark"] .user-role-badge.role-admin { + background: rgba(239, 68, 68, 0.2); + color: #f87171; +} + +.user-role-badge.role-expert { + background: rgba(59, 130, 246, 0.1); + color: #2563eb; +} + +:root[data-theme="dark"] .user-role-badge.role-expert { + background: rgba(59, 130, 246, 0.2); + color: #60a5fa; +} + +.user-roles-footer { + padding: 8px 12px 12px; + border-top: 1px solid var(--border-color); +} + +.user-logout-btn { + display: flex; + align-items: center; + justify-content: center; + gap: 8px; + width: 100%; + padding: 8px 12px; + border: none; + border-radius: 6px; + background: var(--bg-tertiary); + color: var(--text-secondary); + font-size: var(--text-sm); + font-weight: 500; + cursor: pointer; + 
transition: all var(--transition-fast); +} + +.user-logout-btn:hover { + background: var(--error-bg); + color: var(--error-text); +} + /* ----------------------------------------------------------------------------- Messages Area ----------------------------------------------------------------------------- */ @@ -584,6 +759,16 @@ body { text-decoration: underline; } +/* Inline images (e.g. sandbox-generated plots) */ +.message-content img { + max-width: 100%; + height: auto; + border-radius: 8px; + margin: 12px 0; + display: block; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15); +} + /* Inline Code */ .message-content code:not(pre code) { padding: 2px 6px; @@ -1342,6 +1527,26 @@ body { line-height: 1.5; } +.settings-note { + font-size: var(--text-xs); + color: var(--text-tertiary); + margin-top: 8px; + padding: 6px 10px; + background: var(--bg-tertiary); + border-radius: var(--radius-sm); + display: flex; + align-items: center; + gap: 6px; +} + +.sandbox-status-indicator { + display: inline-block; + width: 8px; + height: 8px; + border-radius: 50%; + background: var(--success-color, #22c55e); +} + /* Compact Toggle (for inline use) */ .settings-toggle-compact { position: relative; @@ -3469,3 +3674,257 @@ body { max-height: 150px; } } + +/* ----------------------------------------------------------------------------- + Sandbox Approval Modal + ----------------------------------------------------------------------------- */ +.approval-modal-overlay { + position: fixed; + top: 0; + left: 0; + right: 0; + bottom: 0; + background: rgba(0, 0, 0, 0.6); + display: flex; + align-items: center; + justify-content: center; + z-index: 10000; + animation: fadeIn 0.2s ease; +} + +@keyframes fadeIn { + from { opacity: 0; } + to { opacity: 1; } +} + +.approval-modal { + background: var(--bg-primary); + border-radius: 12px; + box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3); + max-width: 700px; + width: 90%; + max-height: 80vh; + display: flex; + flex-direction: column; + animation: slideUp 
0.2s ease; +} + +@keyframes slideUp { + from { transform: translateY(20px); opacity: 0; } + to { transform: translateY(0); opacity: 1; } +} + +.approval-modal-header { + padding: 20px 24px; + border-bottom: 1px solid var(--border-color); +} + +.approval-modal-header h3 { + margin: 0 0 8px 0; + font-size: 18px; + font-weight: 600; + color: var(--text-primary); +} + +.approval-modal-header p { + margin: 0; + color: var(--text-secondary); + font-size: var(--text-sm); +} + +.approval-modal-body { + padding: 20px 24px; + overflow-y: auto; + flex: 1; +} + +.approval-code-block { + background: var(--code-bg); + border-radius: 8px; + overflow: hidden; +} + +.approval-code-header { + display: flex; + gap: 8px; + padding: 8px 12px; + background: var(--code-header-bg); + border-bottom: 1px solid rgba(255, 255, 255, 0.1); +} + +.language-badge, +.image-badge { + font-size: 11px; + padding: 2px 8px; + border-radius: 4px; + font-family: var(--font-mono); +} + +.language-badge { + background: rgba(16, 163, 127, 0.2); + color: #10a37f; +} + +.image-badge { + background: rgba(99, 102, 241, 0.2); + color: #818cf8; +} + +.approval-code-block pre { + margin: 0; + padding: 16px; + max-height: 300px; + overflow-y: auto; +} + +.approval-code-block code { + font-family: var(--font-mono); + font-size: 13px; + color: var(--code-text); + white-space: pre-wrap; + word-break: break-word; +} + +.approval-modal-footer { + display: flex; + justify-content: flex-end; + gap: 12px; + padding: 16px 24px; + border-top: 1px solid var(--border-color); +} + +.approval-modal-footer button { + padding: 10px 24px; + border-radius: 6px; + font-size: var(--text-sm); + font-weight: 500; + cursor: pointer; + transition: all 0.2s ease; +} + +.btn-reject { + background: var(--bg-tertiary); + border: 1px solid var(--border-color); + color: var(--text-primary); +} + +.btn-reject:hover { + background: var(--bg-hover); +} + +.btn-approve { + background: var(--accent); + border: 1px solid var(--accent); + color: 
white; +} + +.btn-approve:hover { + background: var(--accent-hover); +} + +.approval-timeout-notice { + text-align: center; + padding: 8px 24px 16px; + font-size: var(--text-xs); + color: var(--text-tertiary); +} + +/* Approval request block in trace */ +.approval-request { + border-left-color: #f59e0b !important; +} + +.approval-request .tool-status.pending { + color: #f59e0b; + animation: pulse 1.5s infinite; +} + +@keyframes pulse { + 0%, 100% { opacity: 1; } + 50% { opacity: 0.5; } +} + +.approval-request .approval-meta { + display: flex; + gap: 12px; + margin-bottom: 12px; + flex-wrap: wrap; +} + +.approval-request .language-badge, +.approval-request .image-badge, +.approval-request .lines-badge { + display: inline-flex; + align-items: center; + padding: 4px 10px; + border-radius: 12px; + font-size: 11px; + font-weight: 500; +} + +.approval-request .language-badge { + background: #3b82f620; + color: #3b82f6; +} + +.approval-request .image-badge { + background: #8b5cf620; + color: #8b5cf6; +} + +.approval-request .lines-badge { + background: #6b728020; + color: var(--text-secondary); +} + +.approval-request .approval-code-block { + background: var(--code-bg); + border-radius: 8px; + overflow: hidden; + border: 1px solid var(--border-color); +} + +.approval-request .code-block-header { + display: flex; + justify-content: space-between; + align-items: center; + padding: 8px 12px; + background: var(--bg-tertiary); + border-bottom: 1px solid var(--border-color); +} + +.approval-request .code-language { + font-size: 11px; + font-weight: 500; + color: var(--text-secondary); + text-transform: uppercase; +} + +.approval-request .copy-btn { + background: none; + border: none; + cursor: pointer; + font-size: 14px; + padding: 2px 6px; + border-radius: 4px; + transition: background var(--transition-fast); +} + +.approval-request .copy-btn:hover { + background: var(--bg-secondary); +} + +.approval-request .approval-code { + margin: 0; + padding: 12px; + max-height: 300px; + 
overflow-y: auto; + font-family: var(--font-mono); + font-size: 12px; + line-height: 1.5; +} + +.approval-request .approval-code code { + font-family: var(--font-mono); + font-size: 12px; + color: var(--code-text); diff --git a/src/interfaces/chat_app/static/chat.js b/src/interfaces/chat_app/static/chat.js index f55fbd51a..1397ef9ed 100644 --- a/src/interfaces/chat_app/static/chat.js +++ b/src/interfaces/chat_app/static/chat.js @@ -634,6 +634,14 @@ const UI = { modelSelectPrimary: document.getElementById('model-select-primary'), providerSelectB: document.getElementById('provider-select-b'), providerStatus: document.getElementById('provider-status'), + // User profile elements + userProfileWidget: document.getElementById('user-profile-widget'), + userDisplayName: document.getElementById('user-display-name'), + userEmail: document.getElementById('user-email'), + userRolesToggle: document.getElementById('user-roles-toggle'), + userRolesPanel: document.getElementById('user-roles-panel'), + userRolesList: document.getElementById('user-roles-list'), + userLogoutBtn: document.getElementById('user-logout-btn'), customModelInput: document.getElementById('custom-model-input'), customModelRow: document.getElementById('custom-model-row'), activeModelLabel: document.getElementById('active-model-label'), @@ -842,6 +850,21 @@ const UI = { Chat.handleProviderBChange(e.target.value); }); + // User profile widget interactions + this.elements.userRolesToggle?.addEventListener('click', (e) => { + e.stopPropagation(); + this.toggleUserRolesPanel(); + }); + + this.elements.userProfileWidget?.addEventListener('click', () => { + this.toggleUserRolesPanel(); + }); + + this.elements.userLogoutBtn?.addEventListener('click', (e) => { + e.stopPropagation(); + window.location.href = '/logout'; + }); + // Close modal on Escape document.addEventListener('keydown', (e) => { if (e.key === 'Escape' && this.elements.settingsModal?.style.display !== 'none') { @@ -929,6 +952,85 @@ const UI = { } }, + 
toggleUserRolesPanel() { + this.elements.userProfileWidget?.classList.toggle('expanded'); + }, + + async loadUserProfile() { + try { + const response = await fetch('/auth/user'); + if (!response.ok) return; + + const data = await response.json(); + + if (!data.logged_in) { + // User not logged in, hide the widget + if (this.elements.userProfileWidget) { + this.elements.userProfileWidget.style.display = 'none'; + } + return; + } + + // Show the widget + if (this.elements.userProfileWidget) { + this.elements.userProfileWidget.style.display = 'block'; + } + + // Extract name from email (before @) + const email = data.email || 'User'; + const displayName = email.split('@')[0]; + + // Update user info + if (this.elements.userDisplayName) { + this.elements.userDisplayName.textContent = displayName; + } + if (this.elements.userEmail) { + this.elements.userEmail.textContent = email; + } + + // Render roles + this.renderUserRoles(data.roles || []); + + } catch (e) { + console.error('Failed to load user profile:', e); + // Hide widget on error + if (this.elements.userProfileWidget) { + this.elements.userProfileWidget.style.display = 'none'; + } + } + }, + + renderUserRoles(roles) { + if (!this.elements.userRolesList) return; + + if (!roles || roles.length === 0) { + this.elements.userRolesList.innerHTML = '

No roles assigned

'; + return; + } + + const getRoleClass = (role) => { + if (role.includes('admin')) return 'role-admin'; + if (role.includes('expert')) return 'role-expert'; + return ''; + }; + + const roleIcon = ` + + + + + `; + + this.elements.userRolesList.innerHTML = roles + .map(role => ` +
+ ${roleIcon} + ${Utils.escapeHtml(role)} +
+ `) + .join(''); + }, + async loadAgentInfo() { if (!this.elements.agentInfoContent) return; try { @@ -1648,6 +1750,38 @@ const UI = { }); }, + updateSandboxApprovalToggle(data) { + const toggle = document.getElementById('sandbox-approval-toggle'); + const statusText = document.getElementById('sandbox-approval-mode-text'); + const group = document.getElementById('sandbox-approval-group'); + + if (!toggle || !statusText) { + if (group) group.style.display = 'none'; + return; + } + + // Show the group + if (group) group.style.display = ''; + + // Set toggle state based on effective mode + const isManual = data.effective_mode === 'manual'; + toggle.checked = isManual; + + // Update status text + if (data.session_mode) { + // User has a session override + statusText.textContent = `Mode: ${data.effective_mode} (session preference)`; + } else { + // Using deployment default + statusText.textContent = `Mode: ${data.effective_mode} (deployment default)`; + } + + // Bind the toggle change event (remove existing listeners first) + toggle.onchange = () => { + Chat.toggleSandboxApprovalMode(toggle.checked); + }; + }, + renderConversations(conversations, activeId) { const list = this.elements.conversationList; if (!list) return; @@ -1726,6 +1860,12 @@ const UI = { } container.innerHTML = messages.map((msg) => this.createMessageHTML(msg)).join(''); + + // Enhance sandbox artifact images/links in loaded messages + if (typeof SandboxArtifacts !== 'undefined') { + SandboxArtifacts.enhance(container); + } + this.scrollToBottom(); }, @@ -2334,6 +2474,196 @@ const UI = { } }, + renderApprovalRequest(messageId, event) { + /** + * Render an approval request indicator in the trace view with full code display. 
+ */ + const container = document.querySelector(`.trace-container[data-message-id="${messageId}"]`); + if (!container) return; + + const toolsContainer = container.querySelector('.trace-tools'); + if (!toolsContainer) return; + + const code = event.code || ''; + const language = event.language || 'python'; + const image = event.image || 'default'; + const codeLines = code.split('\n').length; + const escapedCode = Utils.escapeHtml(code); + + const approvalBlock = document.createElement('div'); + approvalBlock.className = 'tool-block approval-request expanded'; + approvalBlock.dataset.approvalId = event.approval_id; + approvalBlock.innerHTML = ` +
+ ⚠️ + Sandbox Code Execution + Awaiting approval... +
+
+
+ ${Utils.escapeHtml(language)} + 📦 ${Utils.escapeHtml(image)} + ${codeLines} lines +
+
+
+ ${Utils.escapeHtml(language)} + +
+
${escapedCode}
+
+
+ `; + + // Add copy functionality + const copyBtn = approvalBlock.querySelector('.copy-btn'); + if (copyBtn) { + copyBtn.onclick = () => { + navigator.clipboard.writeText(code).then(() => { + copyBtn.textContent = '✓'; + setTimeout(() => { copyBtn.textContent = '📋'; }, 1500); + }); + }; + } + + toolsContainer.appendChild(approvalBlock); + this.scrollToBottom(); + + // Apply syntax highlighting if available + if (typeof hljs !== 'undefined') { + const codeEl = approvalBlock.querySelector('code'); + if (codeEl) hljs.highlightElement(codeEl); + } + }, + + async showApprovalModal(event) { + /** + * Show a modal dialog for the user to approve or reject sandbox code execution. + * Returns a Promise that resolves to true (approved) or false (rejected). + */ + return new Promise((resolve) => { + // Create modal overlay + const overlay = document.createElement('div'); + overlay.className = 'approval-modal-overlay'; + + const modal = document.createElement('div'); + modal.className = 'approval-modal'; + modal.innerHTML = ` +
+

⚠️ Sandbox Execution Approval

+

The AI wants to execute the following code. Please review and approve or reject.

+
+
+
+
+ ${Utils.escapeHtml(event.language || 'python')} + ${Utils.escapeHtml(event.image || 'default')} +
+
${Utils.escapeHtml(event.code || '')}
+
+
+ +
+ This request will expire in ${Math.floor((event.timeout_seconds || 300) / 60)} minutes. +
+ `; + + overlay.appendChild(modal); + document.body.appendChild(overlay); + + // Handle button clicks + const approveBtn = modal.querySelector('.btn-approve'); + const rejectBtn = modal.querySelector('.btn-reject'); + + const cleanup = () => { + overlay.remove(); + }; + + approveBtn.onclick = () => { + cleanup(); + // Update the approval block status + const approvalBlock = document.querySelector(`.approval-request[data-approval-id="${event.approval_id}"]`); + if (approvalBlock) { + approvalBlock.classList.remove('expanded'); + approvalBlock.classList.add('tool-success'); + const statusEl = approvalBlock.querySelector('.tool-status'); + if (statusEl) { + statusEl.innerHTML = ' Approved'; + statusEl.classList.remove('pending'); + } + } + resolve(true); + }; + + rejectBtn.onclick = () => { + cleanup(); + // Update the approval block status + const approvalBlock = document.querySelector(`.approval-request[data-approval-id="${event.approval_id}"]`); + if (approvalBlock) { + approvalBlock.classList.remove('expanded'); + approvalBlock.classList.add('tool-error'); + const statusEl = approvalBlock.querySelector('.tool-status'); + if (statusEl) { + statusEl.innerHTML = ' Rejected'; + statusEl.classList.remove('pending'); + } + } + resolve(false); + }; + + // Close on overlay click (treat as reject) + overlay.onclick = (e) => { + if (e.target === overlay) { + cleanup(); + resolve(false); + } + }; + + // Handle escape key + const handleEscape = (e) => { + if (e.key === 'Escape') { + cleanup(); + document.removeEventListener('keydown', handleEscape); + resolve(false); + } + }; + document.addEventListener('keydown', handleEscape); + }); + }, + + updateApprovalStatus(approvalId, status) { + /** + * Update the status display for an approval request. 
+ */ + const approvalBlock = document.querySelector(`.approval-request[data-approval-id="${approvalId}"]`); + if (!approvalBlock) return; + + const statusEl = approvalBlock.querySelector('.tool-status'); + if (!statusEl) return; + + statusEl.classList.remove('pending'); + if (status === 'approved') { + approvalBlock.classList.add('tool-success'); + statusEl.innerHTML = ' Approved'; + } else if (status === 'rejected') { + approvalBlock.classList.add('tool-error'); + statusEl.innerHTML = ' Rejected'; + } else if (status === 'expired') { + approvalBlock.classList.add('tool-error'); + statusEl.innerHTML = ' Expired'; + } + }, + + toggleToolExpanded(toolCallId) { + const toolBlock = document.querySelector(`.tool-block[data-tool-call-id="${toolCallId}"]`); + if (toolBlock) { + toolBlock.classList.toggle('expanded'); + } + }, + // ========================================================================= // Context Meter // ========================================================================= @@ -2831,6 +3161,8 @@ const Chat = { this.loadProviders(), this.loadPipelineDefaultModel(), this.loadApiKeyStatus(), + this.loadSandboxApprovalMode(), + UI.loadUserProfile(), this.loadAgents(), ]); @@ -3190,6 +3522,57 @@ const Chat = { } }, + // Sandbox Approval Mode Management + async loadSandboxApprovalMode() { + try { + const response = await fetch('/api/sandbox/approval-mode'); + if (!response.ok) { + // Sandbox might not be enabled - hide the toggle + this.hideSandboxApprovalToggle(); + return; + } + const data = await response.json(); + this.state.sandboxApprovalMode = data; + UI.updateSandboxApprovalToggle(data); + } catch (e) { + console.error('Failed to load sandbox approval mode:', e); + this.hideSandboxApprovalToggle(); + } + }, + + hideSandboxApprovalToggle() { + const group = document.getElementById('sandbox-approval-group'); + if (group) { + group.style.display = 'none'; + } + }, + + async toggleSandboxApprovalMode(enabled) { + const mode = enabled ? 
'manual' : 'auto'; + try { + const response = await fetch('/api/sandbox/approval-mode', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ mode }), + }); + if (!response.ok) { + throw new Error('Failed to update approval mode'); + } + const data = await response.json(); + this.state.sandboxApprovalMode = { + session_mode: data.session_mode, + effective_mode: data.effective_mode, + default_mode: this.state.sandboxApprovalMode?.default_mode, + }; + UI.updateSandboxApprovalToggle(this.state.sandboxApprovalMode); + console.log('Sandbox approval mode updated:', data.message); + } catch (e) { + console.error('Failed to toggle sandbox approval mode:', e); + // Revert the toggle UI + await this.loadSandboxApprovalMode(); + } + }, + async loadConversations() { try { const data = await API.getConversations(); @@ -3690,6 +4073,16 @@ const Chat = { if (showTrace) { UI.renderToolEnd(messageId, event); } + } else if (event.type === 'approval_request') { + // Sandbox approval request - show modal for user to approve/reject + this.state.activeTrace.events.push(event); + if (showTrace) { + UI.renderApprovalRequest(messageId, event); + } + // Show approval modal + const approved = await UI.showApprovalModal(event); + // Send approval decision to server + await this.handleSandboxApproval(event.approval_id, approved); } else if (event.type === 'thinking_start') { this.state.activeTrace.events.push(event); if (showTrace) { @@ -3739,6 +4132,12 @@ const Chat = { streaming: false, }); + // Enhance sandbox artifact images/links + if (typeof SandboxArtifacts !== 'undefined') { + const msgEl = document.getElementById(messageId); + if (msgEl) SandboxArtifacts.enhance(msgEl); + } + // Update message ID from backend so feedback works if (event.message_id != null) { const msg = this.state.messages.find(m => m.id === messageId); @@ -3801,6 +4200,29 @@ const Chat = { } }, + async handleSandboxApproval(approvalId, approved) { + /** + * Send approval 
decision for a sandbox execution request. + */ + const endpoint = approved + ? `/api/sandbox/approval/${approvalId}/approve` + : `/api/sandbox/approval/${approvalId}/reject`; + + try { + const response = await fetch(endpoint, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + credentials: 'include', // Include session cookies for auth + }); + + if (!response.ok) { + console.error('Failed to send approval decision:', response.status); + } + } catch (e) { + console.error('Error sending approval decision:', e); + } + }, + async cancelStream() { if (this.state.abortController) { this.state.abortController.abort(); diff --git a/src/interfaces/chat_app/static/sandbox-artifacts.js b/src/interfaces/chat_app/static/sandbox-artifacts.js new file mode 100644 index 000000000..0dc59e195 --- /dev/null +++ b/src/interfaces/chat_app/static/sandbox-artifacts.js @@ -0,0 +1,206 @@ +/** + * Sandbox Artifacts Rendering Module + * + * Provides enhanced rendering for sandbox-generated artifacts beyond what + * marked.js does by default. Images get lightbox-style click-to-zoom; + * other file types get download buttons with file-type icons. + * + * Usage: + * After rendering markdown, call `SandboxArtifacts.enhance(containerEl)` + * to upgrade any artifact links/images inside that element. + * + * The module detects artifacts by their URL pattern: + * /api/sandbox-artifacts// + */ + +// eslint-disable-next-line no-unused-vars +const SandboxArtifacts = (function () { + 'use strict'; + + const ARTIFACT_URL_RE = /^\/api\/sandbox-artifacts\/([0-9a-f-]{36})\/(.+)$/i; + + // File extension -> icon emoji (keep simple; no external icon lib needed). + const FILE_ICONS = { + csv: '📊', + json: '📋', + txt: '📄', + md: '📝', + py: '🐍', + js: '📜', + html: '🌐', + pdf: '📕', + zip: '📦', + tar: '📦', + gz: '📦', + }; + + /** + * Check if a URL is a sandbox artifact. 
+ * @param {string} url + * @returns {boolean} + */ + function isArtifactUrl(url) { + return ARTIFACT_URL_RE.test(url); + } + + /** + * Extract filename from an artifact URL. + * @param {string} url + * @returns {string|null} + */ + function extractFilename(url) { + const m = ARTIFACT_URL_RE.exec(url); + return m ? decodeURIComponent(m[2]) : null; + } + + /** + * Get a suitable icon for a filename. + * @param {string} filename + * @returns {string} + */ + function getFileIcon(filename) { + const ext = (filename.split('.').pop() || '').toLowerCase(); + return FILE_ICONS[ext] || '📎'; + } + + /** + * Enhance artifact images within a container. + * + * Adds click-to-open-in-new-tab and a subtle border style. + * @param {HTMLElement} container + */ + function enhanceImages(container) { + const imgs = container.querySelectorAll('img'); + imgs.forEach((img) => { + const src = img.getAttribute('src') || ''; + if (!isArtifactUrl(src)) return; + + // Style for sandbox artifact images + img.classList.add('sandbox-artifact-image'); + + // Wrap in a link if not already wrapped + if (img.parentElement.tagName !== 'A') { + const link = document.createElement('a'); + link.href = src; + link.target = '_blank'; + link.rel = 'noopener noreferrer'; + link.title = 'Click to open full size'; + img.parentNode.insertBefore(link, img); + link.appendChild(img); + } + }); + } + + /** + * Enhance artifact download links. + * + * Converts plain `[📎 file.csv](/api/sandbox-artifacts/...)` links into + * styled download buttons. + * @param {HTMLElement} container + */ + function enhanceLinks(container) { + const links = container.querySelectorAll('a'); + links.forEach((link) => { + const href = link.getAttribute('href') || ''; + if (!isArtifactUrl(href)) return; + + // Skip image links (handled by enhanceImages) + if (link.querySelector('img')) return; + + const filename = extractFilename(href); + if (!filename) return; + + // Already enhanced? 
+ if (link.classList.contains('sandbox-artifact-link')) return; + + link.classList.add('sandbox-artifact-link'); + link.setAttribute('download', filename); + link.title = `Download ${filename}`; + + // Rewrite content to include icon + filename + const icon = getFileIcon(filename); + link.innerHTML = `${icon} ${escapeHtml(filename)}`; + }); + } + + /** + * Simple HTML escape. + * @param {string} str + * @returns {string} + */ + function escapeHtml(str) { + const div = document.createElement('div'); + div.textContent = str; + return div.innerHTML; + } + + /** + * Main entry point. Call after markdown rendering to enhance artifacts. + * @param {HTMLElement} container + */ + function enhance(container) { + if (!container) return; + enhanceImages(container); + enhanceLinks(container); + } + + /** + * Inject default styles for artifact images and links. + * Call once on page load. + */ + function injectStyles() { + if (document.getElementById('sandbox-artifact-styles')) return; + + const style = document.createElement('style'); + style.id = 'sandbox-artifact-styles'; + style.textContent = ` + .sandbox-artifact-image { + max-width: 100%; + border: 1px solid var(--border-color, #444); + border-radius: 6px; + cursor: pointer; + transition: box-shadow 0.2s; + } + .sandbox-artifact-image:hover { + box-shadow: 0 0 8px rgba(102, 179, 255, 0.4); + } + .sandbox-artifact-link { + display: inline-flex; + align-items: center; + gap: 0.4em; + padding: 0.3em 0.7em; + background: var(--secondary-bg, #2d2d2d); + border: 1px solid var(--border-color, #444); + border-radius: 4px; + color: var(--primary-text, #e0e0e0); + text-decoration: none; + font-size: 0.9em; + transition: background 0.2s; + } + .sandbox-artifact-link:hover { + background: var(--hover-bg, #3a3a3a); + text-decoration: none; + } + .artifact-icon { + font-size: 1.1em; + } + `; + document.head.appendChild(style); + } + + // Auto-inject styles when module loads + if (typeof document !== 'undefined') { + if 
(document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', injectStyles); + } else { + injectStyles(); + } + } + + return { + enhance, + isArtifactUrl, + extractFilename, + injectStyles, + }; +})(); diff --git a/src/interfaces/chat_app/templates/index.html b/src/interfaces/chat_app/templates/index.html index ca59b893f..d380fc0b5 100644 --- a/src/interfaces/chat_app/templates/index.html +++ b/src/interfaces/chat_app/templates/index.html @@ -41,6 +41,43 @@
+ + + @@ -269,6 +306,22 @@

Advanced Settings

+ + +
+
+ Sandbox Code Approval + +
+

When enabled, you must approve each code execution before it runs in the sandbox.

+

+ + Loading... +

+
@@ -354,6 +407,7 @@

New Agent

+ diff --git a/src/utils/rbac/__init__.py b/src/utils/rbac/__init__.py new file mode 100644 index 000000000..ed67a13f3 --- /dev/null +++ b/src/utils/rbac/__init__.py @@ -0,0 +1,57 @@ +""" +RBAC (Role-Based Access Control) Module for A2rchi + +This module provides authentication and authorization functionality including: +- Permission registry and role-to-permission mappings +- Route protection decorators +- JWT token parsing for role extraction +- Audit logging for security events + +Usage: + from src.utils.rbac import require_permission, has_permission, get_user_roles + + @app.route('/api/upload') + @require_permission('upload:documents') + def upload(): + ... +""" + +from src.utils.rbac.registry import ( + RBACRegistry, + get_registry, + load_rbac_config, +) +from src.utils.rbac.decorators import ( + require_permission, + require_any_permission, + require_authenticated, +) +from src.utils.rbac.permissions import ( + has_permission, + get_user_permissions, + check_permission, +) +from src.utils.rbac.jwt_parser import ( + extract_roles_from_token, + get_user_roles, + assign_default_role, +) + +__all__ = [ + # Registry + 'RBACRegistry', + 'get_registry', + 'load_rbac_config', + # Decorators + 'require_permission', + 'require_any_permission', + 'require_authenticated', + # Permissions + 'has_permission', + 'get_user_permissions', + 'check_permission', + # JWT Parser + 'extract_roles_from_token', + 'get_user_roles', + 'assign_default_role', +] diff --git a/src/utils/rbac/audit.py b/src/utils/rbac/audit.py new file mode 100644 index 000000000..12f33540f --- /dev/null +++ b/src/utils/rbac/audit.py @@ -0,0 +1,145 @@ +""" +RBAC Audit Logging - Security event logging for access control + +This module provides audit logging for all permission checks, +supporting security analysis and compliance requirements. 
+""" + +import json +from datetime import datetime, timezone +from typing import List, Optional + +from src.utils.logging import get_logger + +# Dedicated audit logger +audit_logger = get_logger('rbac.audit') + + +def log_permission_check( + user: str, + permission: str, + granted: bool, + endpoint: str, + roles: List[str], + missing: Optional[List[str]] = None, + extra: Optional[dict] = None +) -> None: + """ + Log a permission check event for audit trail. + + Args: + user: Username or email of the user (or 'anonymous') + permission: Permission(s) being checked + granted: Whether access was granted + endpoint: Flask endpoint name + roles: User's current roles + missing: Permissions that were missing (if denied) + extra: Additional context information + """ + timestamp = datetime.now(timezone.utc).isoformat() + result = 'GRANTED' if granted else 'DENIED' + + # Structured log entry + log_entry = { + 'timestamp': timestamp, + 'user': user, + 'permission': permission, + 'result': result, + 'endpoint': endpoint, + 'roles': roles, + } + + if missing: + log_entry['missing_permissions'] = missing + + if extra: + log_entry.update(extra) + + # Log level based on result + log_message = f"{user} | {permission} | {result} | {endpoint} | roles: {roles}" + + if granted: + audit_logger.debug(log_message) + else: + audit_logger.warning(log_message) + # Also log structured JSON for easier parsing + audit_logger.info(f"AUDIT: {json.dumps(log_entry)}") + + +def log_role_assignment( + user: str, + roles: List[str], + source: str, + is_default: bool = False +) -> None: + """ + Log a role assignment event. 
+ + Args: + user: Username or email + roles: Roles assigned to user + source: Source of roles (e.g., 'jwt', 'default') + is_default: Whether default role was assigned + """ + timestamp = datetime.now(timezone.utc).isoformat() + + log_entry = { + 'timestamp': timestamp, + 'event': 'role_assignment', + 'user': user, + 'roles': roles, + 'source': source, + 'is_default': is_default, + } + + if is_default: + audit_logger.warning( + f"Default role assigned to {user}: {roles} (no JWT roles found)" + ) + else: + audit_logger.info(f"Roles assigned to {user}: {roles} (source: {source})") + + audit_logger.debug(f"AUDIT: {json.dumps(log_entry)}") + + +def log_authentication_event( + user: str, + event_type: str, + success: bool, + method: str, + details: Optional[str] = None +) -> None: + """ + Log an authentication event. + + Args: + user: Username or email (or 'unknown') + event_type: Type of event ('login', 'logout', 'token_refresh') + success: Whether the event succeeded + method: Authentication method ('sso', 'basic') + details: Additional details or error message + """ + timestamp = datetime.now(timezone.utc).isoformat() + result = 'SUCCESS' if success else 'FAILURE' + + log_entry = { + 'timestamp': timestamp, + 'event': event_type, + 'user': user, + 'result': result, + 'method': method, + } + + if details: + log_entry['details'] = details + + log_message = f"AUTH | {event_type} | {user} | {result} | method: {method}" + if details: + log_message += f" | {details}" + + if success: + audit_logger.info(log_message) + else: + audit_logger.warning(log_message) + + audit_logger.debug(f"AUDIT: {json.dumps(log_entry)}") diff --git a/src/utils/rbac/decorators.py b/src/utils/rbac/decorators.py new file mode 100644 index 000000000..12bb9b34f --- /dev/null +++ b/src/utils/rbac/decorators.py @@ -0,0 +1,338 @@ +""" +RBAC Decorators - Route protection decorators for Flask endpoints + +This module provides decorators to protect Flask routes with permission requirements. 
+Decorators handle authentication checks, permission validation, and audit logging. +""" + +from functools import wraps +from typing import Callable, List, Optional, Union +from flask import session, jsonify, redirect, url_for, request, g + +from src.utils.logging import get_logger +from src.utils.rbac.registry import get_registry +from src.utils.rbac.audit import log_permission_check + +logger = get_logger(__name__) + + +class PermissionDeniedError(Exception): + """Raised when a permission check fails.""" + def __init__(self, message: str, required_permission: str, user_roles: List[str]): + super().__init__(message) + self.required_permission = required_permission + self.user_roles = user_roles + + +def get_current_user_roles() -> List[str]: + """ + Get roles for the current user from session. + + Returns: + List of role names, or empty list if not authenticated + """ + if not session.get('logged_in'): + return [] + + return session.get('roles', []) + + +def is_authenticated() -> bool: + """ + Check if current user is authenticated. + + Returns: + True if user has an active session + """ + return session.get('logged_in', False) + + +def require_authenticated(f: Callable) -> Callable: + """ + Decorator that requires user to be authenticated. + + Does NOT check for specific permissions, only that user is logged in. + Use this for routes that should be accessible to all authenticated users. + + Usage: + @app.route('/profile') + @require_authenticated + def profile(): + ... 
def require_permission(permission: Union[str, List[str]]) -> Callable:
    """
    Decorator that requires specific permission(s) to access a route.

    If a list of permissions is provided, user must have ALL of them.
    For "any of" logic, use require_any_permission instead.

    Unauthenticated callers get a 401 JSON body (JSON or /api/ requests) or a
    redirect to the 'login' endpoint. Authenticated callers missing a
    permission get a 403 JSON body or the rendered 'error.html' template.
    Every decision is reported through log_permission_check.

    Usage:
        @app.route('/api/upload')
        @require_permission('upload:documents')
        def upload():
            ...

        @app.route('/api/admin/config')
        @require_permission(['config:view', 'config:modify'])
        def admin_config():
            ...

    Args:
        permission: A single permission string or a list of permission strings.

    Returns:
        A decorator that wraps the view function with the checks above.
    """
    # Normalize to a private list: a bare string becomes a one-element list and
    # a caller-supplied list is copied, so mutating it later cannot change the
    # requirement enforced by an already-decorated route.
    required_permissions = [permission] if isinstance(permission, str) else list(permission)
    permission_label = ','.join(required_permissions)

    def decorator(f: Callable) -> Callable:
        @wraps(f)
        def decorated_function(*args, **kwargs):
            # First check authentication
            if not is_authenticated():
                log_permission_check(
                    user='anonymous',
                    permission=permission_label,
                    granted=False,
                    endpoint=request.endpoint,
                    roles=[]
                )

                if request.is_json or request.path.startswith('/api/'):
                    return jsonify({
                        'error': 'Authentication required',
                        'message': 'Please log in to access this resource',
                        'status': 401
                    }), 401

                return redirect(url_for('login'))

            user_roles = get_current_user_roles()
            # Tolerate a session whose 'user' entry is missing or None.
            user_email = (session.get('user') or {}).get('email', 'unknown')

            registry = get_registry()

            # Collect every permission the user's roles do not grant.
            missing_permissions = [
                perm for perm in required_permissions
                if not registry.has_permission(user_roles, perm)
            ]

            if missing_permissions:
                log_permission_check(
                    user=user_email,
                    permission=permission_label,
                    granted=False,
                    endpoint=request.endpoint,
                    roles=user_roles,
                    missing=missing_permissions
                )

                # Roles that would grant the missing permissions, for a helpful
                # error message. Sorted so the response body and message are
                # deterministic instead of following set iteration order.
                roles_with_permission = sorted({
                    role
                    for perm in missing_permissions
                    for role in registry.get_roles_with_permission(perm)
                })

                if request.is_json or request.path.startswith('/api/'):
                    return jsonify({
                        'error': 'Insufficient permissions',
                        'required_permissions': missing_permissions,
                        'user_roles': user_roles,
                        'roles_with_permission': roles_with_permission,
                        'message': f"You need one of these roles to access this feature: {', '.join(roles_with_permission)}",
                        'status': 403
                    }), 403

                # For browser requests, return 403 page
                from flask import render_template
                return render_template(
                    'error.html',
                    error_code=403,
                    error_title='Permission Denied',
                    error_message="You don't have permission to access this feature.",
                    required_roles=roles_with_permission
                ), 403

            # Log successful access
            log_permission_check(
                user=user_email,
                permission=permission_label,
                granted=True,
                endpoint=request.endpoint,
                roles=user_roles
            )

            return f(*args, **kwargs)

        return decorated_function

    return decorator
permissions: + roles_with_permission.update(registry.get_roles_with_permission(perm)) + + if request.is_json or request.path.startswith('/api/'): + return jsonify({ + 'error': 'Insufficient permissions', + 'required_permissions': permissions, + 'user_roles': user_roles, + 'roles_with_permission': list(roles_with_permission), + 'message': f"You need one of these roles: {', '.join(roles_with_permission)}", + 'status': 403 + }), 403 + + from flask import render_template + return render_template('error.html', + error_code=403, + error_title='Permission Denied', + error_message=f"You don't have permission to access this feature.", + required_roles=list(roles_with_permission) + ), 403 + + # Log successful access + log_permission_check( + user=user_email, + permission=f"any({','.join(permissions)})", + granted=True, + endpoint=request.endpoint, + roles=user_roles + ) + + return f(*args, **kwargs) + + return decorated_function + + return decorator + + +def check_sso_required() -> Callable: + """ + Decorator to enforce SSO authentication when configured. + + When SSO is enabled and allow_anonymous is False, this decorator + redirects unauthenticated users to SSO login. + + Usage: + @app.route('/') + @check_sso_required() + def landing(): + ... 
def _decode_unverified_claims(encoded: Any) -> Optional[Dict[str, Any]]:
    """Return the claims of an encoded JWT without verifying it, or None if it cannot be decoded."""
    if not isinstance(encoded, str):
        return None
    try:
        # NOTE: the signature is deliberately NOT verified here; this relies on
        # the OAuth library having validated the token before it reaches us.
        return jwt.decode(encoded, options={"verify_signature": False})
    except jwt.DecodeError:
        return None


def extract_roles_from_token(
    token: Dict[str, Any],
    app_name: Optional[str] = None
) -> List[str]:
    """
    Extract roles from a JWT token's resource_access claim.

    The expected token structure from Keycloak/CERN SSO:
    {
        "resource_access": {
            "<app_name>": {
                "roles": ["role1", "role2", ...]
            }
        }
    }

    The claim is looked up, in order, in the decoded 'access_token' (or the
    token dict itself when there is no decodable access token), the decoded
    'id_token', and the 'userinfo' entry. The first non-empty claim wins.

    Args:
        token: Decoded JWT token dictionary (or the raw token response from OAuth)
        app_name: Application name to look for in resource_access.
                  If not provided, uses the configured app_name from registry.

    Returns:
        List of role strings extracted from the token. Empty when the claim,
        the application entry, or a well-formed roles list is missing, or when
        any unexpected error occurs while reading the token.
    """
    if app_name is None:
        app_name = get_registry().app_name

    try:
        # Handle both raw OAuth token responses and already-decoded JWTs.
        access_token_data = token
        if 'access_token' in token and isinstance(token['access_token'], str):
            decoded_access = _decode_unverified_claims(token['access_token'])
            if decoded_access is None:
                logger.warning("Could not decode access_token, using token as-is")
            else:
                access_token_data = decoded_access

        # The id_token may also carry roles; an undecodable one is ignored.
        id_token_data = _decode_unverified_claims(token.get('id_token')) or {}

        # Take the first source that carries a non-empty resource_access claim.
        # Sources that are None or not mappings (e.g. "userinfo": null) are
        # skipped instead of raising.
        resource_access = None
        for claims in (access_token_data, id_token_data, token.get('userinfo')):
            candidate = claims.get('resource_access') if isinstance(claims, dict) else None
            if candidate:
                resource_access = candidate
                break

        if not isinstance(resource_access, dict):
            logger.warning("No resource_access claim found in token")
            logger.debug(f"Token keys: {list(access_token_data.keys())}")
            return []

        # Get roles for our application
        app_access = resource_access.get(app_name)
        if not isinstance(app_access, dict) or not app_access:
            logger.warning(
                f"No roles found for app '{app_name}' in resource_access. "
                f"Available apps: {list(resource_access.keys())}"
            )
            return []

        roles = app_access.get('roles', [])
        if not isinstance(roles, list):
            logger.warning(f"roles claim is not a list: {type(roles)}")
            return []

        # Honour the List[str] contract: drop entries that are not strings so
        # callers never receive unhashable or otherwise malformed role values.
        role_names = [role for role in roles if isinstance(role, str)]
        if len(role_names) != len(roles):
            logger.warning(
                f"Ignoring {len(roles) - len(role_names)} non-string entries in roles claim"
            )

        logger.info(f"Extracted roles for app '{app_name}': {role_names}")
        return role_names

    except Exception as e:
        # Best-effort by design: a malformed token must not break login, it
        # only results in no roles being extracted.
        logger.error(f"Error extracting roles from token: {e}")
        return []
def decode_jwt_claims(
    token_string: str,
    verify: bool = False,
    key: Optional[Any] = None,
    algorithms: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Decode a JWT token string to extract claims.

    Args:
        token_string: Encoded JWT token
        verify: Whether to verify the signature. Verification needs both
            ``key`` and ``algorithms``; without them nothing can be verified
            and an empty dict is returned.
        key: Verification key, only used when ``verify`` is True
        algorithms: Accepted signing algorithms, only used when ``verify`` is True

    Returns:
        Decoded token claims dictionary, or an empty dict when the token
        cannot be decoded, has expired, or fails validation.
    """
    decode_kwargs: Dict[str, Any] = {"options": {"verify_signature": verify}}

    if verify:
        # Fail closed with a clear log line rather than letting the JWT
        # library reject the call for missing verification material.
        if key is None or not algorithms:
            logger.error(
                "Failed to decode JWT: signature verification requested "
                "without a key and algorithms"
            )
            return {}
        decode_kwargs["key"] = key
        decode_kwargs["algorithms"] = algorithms

    try:
        return jwt.decode(token_string, **decode_kwargs)
    except jwt.ExpiredSignatureError:
        logger.warning("JWT token has expired")
        return {}
    except jwt.DecodeError as e:
        logger.error(f"Failed to decode JWT: {e}")
        return {}
    except jwt.InvalidTokenError as e:
        # Remaining validation failures (audience, issuer, not-yet-valid, ...)
        # follow the same "empty claims" contract instead of propagating.
        logger.warning(f"JWT token failed validation: {e}")
        return {}
def has_any_permission(permissions: List[str], roles: Optional[List[str]] = None) -> bool:
    """
    Tell whether at least one of the given permissions is granted.

    Args:
        permissions: Permission strings to test
        roles: Roles to test against; when omitted, the roles stored in the
            current session are used and a logged-out session yields False

    Returns:
        True as soon as one permission is granted, False otherwise
    """
    role_names = roles
    if role_names is None:
        # No explicit roles: only a logged-in session can supply them.
        if not session.get('logged_in'):
            return False
        role_names = session.get('roles', [])

    # The session may hold an explicit None under 'roles'.
    if role_names is None:
        role_names = []

    registry = get_registry()
    return any(registry.has_permission(role_names, candidate) for candidate in permissions)
def get_user_permissions(roles: Optional[List[str]] = None) -> Set[str]:
    """
    Collect every permission granted to the given roles or the session user.

    Args:
        roles: Roles to resolve; when omitted, the roles stored in the current
            session are used

    Returns:
        Set of permission strings; empty when no roles are given and the
        session is not logged in
    """
    role_names = roles

    if role_names is None and not session.get('logged_in'):
        return set()

    if role_names is None:
        role_names = session.get('roles', [])

    # The session may hold an explicit None under 'roles'; treat it as no roles.
    effective_roles = [] if role_names is None else role_names
    return get_registry().get_all_permissions_for_roles(effective_roles)
def is_admin(roles: Optional[List[str]] = None) -> bool:
    """
    Check if the current user has admin role (wildcard permissions).

    Args:
        roles: Optional list of roles (uses session if not provided)

    Returns:
        True if user has admin-level access (any role with '*' permission)
    """
    if roles is None:
        if not session.get('logged_in'):
            return False
        roles = session.get('roles', [])

    # Ensure roles is a list
    if roles is None:
        roles = []

    # Ask the registry through its public accessor instead of reading its
    # private role table and cache; unknown roles resolve to an empty set.
    registry = get_registry()
    return any('*' in registry.get_role_permissions(role) for role in roles)
def get_permission_context() -> dict:
    """
    Build the permission flags handed to templates.

    Intended for Jinja2 templates that render UI conditionally.

    Returns:
        Dictionary with 'is_authenticated', one boolean per major permission,
        'is_admin', 'is_expert' and 'user_roles'. For a session that is not
        logged in, every flag is False and 'user_roles' is an empty list.
    """
    # Template flag -> permission string, in the order the keys are emitted.
    flag_permissions = (
        ('can_chat', 'chat:query'),
        ('can_view_documents', 'documents:view'),
        ('can_select_documents', 'documents:select'),
        ('can_upload_documents', 'upload:documents'),
        ('can_manage_api_keys', 'api-keys:manage'),
        ('can_view_config', 'config:view'),
        ('can_modify_config', 'config:modify'),
        ('can_view_metrics', 'view:metrics'),
    )

    if not session.get('logged_in'):
        context = {'is_authenticated': False}
        for flag, _permission in flag_permissions:
            context[flag] = False
        context['is_admin'] = False
        context['is_expert'] = False
        context['user_roles'] = []
        return context

    roles = session.get('roles', [])

    context = {'is_authenticated': True}
    for flag, permission in flag_permissions:
        context[flag] = has_permission(permission, roles)
    context['is_admin'] = is_admin(roles)
    context['is_expert'] = is_expert(roles)
    context['user_roles'] = roles
    return context
def _validate_config(self) -> None:
    """
    Validate the RBAC configuration.

    Checks, in order: at least one role is defined; the roles table is a
    mapping; the default role exists (falling back to 'base-user' or the
    first defined role, with a warning); every role definition is a mapping
    whose 'inherits' entry is a list of defined roles; and inheritance has
    no cycles.

    Raises:
        RBACConfigError: If configuration is invalid
    """
    # Check for required fields - fail startup if missing
    if not self._roles:
        raise RBACConfigError(
            "No roles defined in configuration. "
            "At least one role must be defined in auth_roles.roles"
        )

    if not isinstance(self._roles, dict):
        raise RBACConfigError(
            f"auth_roles.roles must be a mapping of role names to definitions, "
            f"got {type(self._roles).__name__}"
        )

    # Check if default_role is defined, warn if not but allow fallback
    if self._default_role not in self._roles:
        logger.warning(
            f"Default role '{self._default_role}' is not defined in roles. "
            f"This role will be assigned to users without configured roles. "
            f"Available roles: {list(self._roles.keys())}. "
            f"Recommend adding '{self._default_role}' to your configuration."
        )
        # If 'base-user' exists, use it; otherwise use first available role
        if 'base-user' in self._roles:
            self._default_role = 'base-user'
            logger.info("Using 'base-user' as default role")
        else:
            self._default_role = next(iter(self._roles))
            logger.warning(f"Falling back to first available role: {self._default_role}")

    # Validate role shapes and inherited roles before checking circular
    # inheritance, so malformed entries surface as RBACConfigError rather
    # than an AttributeError/TypeError deeper in the registry.
    for role_name, role_config in self._roles.items():
        if not isinstance(role_config, dict):
            # e.g. a YAML key with no body ("role-name:") loads as None
            raise RBACConfigError(
                f"Role '{role_name}' must be a mapping of settings, "
                f"got {type(role_config).__name__}"
            )

        parents = role_config.get('inherits', [])
        if not isinstance(parents, (list, tuple)):
            # A bare string would otherwise be iterated character by character.
            raise RBACConfigError(
                f"Role '{role_name}' has invalid 'inherits' value {parents!r}; "
                f"expected a list of role names"
            )

        for parent in parents:
            if parent not in self._roles:
                raise RBACConfigError(
                    f"Role '{role_name}' inherits from undefined role '{parent}'. "
                    f"Available roles: {list(self._roles.keys())}"
                )

    # Check for circular inheritance (after validating all roles exist)
    for role_name in self._roles:
        self._check_circular_inheritance(role_name, set(), [])

    logger.debug("RBAC configuration validated successfully")
def _resolve_permissions(self, role_name: str, visited: Set[str] = None) -> Set[str]:
    """
    Collect the permissions of a role together with those of its ancestors.

    Args:
        role_name: Role to start from
        visited: Roles to treat as already handled; every role reached during
            the walk is added to this set. A fresh set is used when omitted.

    Returns:
        Union of the 'permissions' entries of the role and of every inherited
        role reached that was not already in ``visited``. Empty when
        ``role_name`` itself is already in ``visited``.
    """
    seen = set() if visited is None else visited
    granted: Set[str] = set()

    # Explicit work list instead of recursion; each role is expanded once.
    pending = [role_name]
    while pending:
        current = pending.pop()
        if current in seen:
            continue  # already handled (also guards against cycles)
        seen.add(current)

        definition = self._roles.get(current, {})
        granted.update(definition.get('permissions', []))

        # Only follow parents that are actually defined roles.
        pending.extend(
            parent for parent in definition.get('inherits', [])
            if parent in self._roles
        )

    return granted
def has_permission(self, roles: List[str], permission: str) -> bool:
    """
    Check if any of the given roles has the specified permission.

    A role grants the permission when its resolved permission set contains
    either the permission itself or the wildcard '*'. Roles that are not in
    the cache grant nothing. Access is denied by default.

    Args:
        roles: List of role names the user has; None or empty means no roles
        permission: Permission string to check (e.g., 'upload:documents')

    Returns:
        True if any role grants the permission, False otherwise
    """
    # Short-circuits on the first role that grants access. A missing roles
    # list is treated as "no roles" so the check denies instead of raising.
    for role in roles or []:
        role_perms = self._role_permissions_cache.get(role, set())
        if '*' in role_perms or permission in role_perms:
            return True

    # Deny by default (fail closed). Warn when the permission is not granted
    # by any configured role at all, which usually means a typo or a
    # permission missing from the configuration.
    if permission != '*' and permission not in self._get_all_defined_permissions():
        logger.warning(
            f"Permission check for undefined permission '{permission}' - denying access. "
            f"Roles: {roles}. Define this permission in auth_roles config."
        )

    return False
def get_roles_with_permission(self, permission: str) -> List[str]:
    """
    List the roles whose resolved permissions include the given permission.

    Handy for error messages of the form "You need role X or Y to do this".
    A role holding the wildcard '*' counts as granting every permission.

    Args:
        permission: Permission string to look for

    Returns:
        Role names that grant the permission, in permission-cache order
    """
    return [
        role_name
        for role_name, granted in self._role_permissions_cache.items()
        if '*' in granted or permission in granted
    ]
def load_rbac_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Load RBAC configuration from main config or YAML file.

    Priority order:
    1. Main config (services.chat_app.auth.auth_roles) if it defines roles
    2. Explicit config_path if provided
    3. AUTH_ROLES_CONFIG_PATH environment variable
    4. Standard auth_roles.yaml locations
    5. Minimal defaults (a single 'base-user' role with wildcard permissions)

    Note: app_name is preferably sourced from SSO_CLIENT_ID environment variable.
    Use get_registry() to automatically populate app_name from SSO config.

    Args:
        config_path: Optional path to auth_roles.yaml. If not provided,
                     checks main config first, then standard locations.

    Returns:
        Configuration dictionary

    Raises:
        RBACConfigError: If the selected YAML file does not contain a mapping
            (for example, it is empty)
    """
    # First, try to load from main config. This is best-effort: any failure
    # here only means we fall through to the file-based lookup.
    try:
        from src.utils.config_access import get_full_config
        full_config = get_full_config()
        auth_roles_config = full_config.get('services', {}).get('chat_app', {}).get('auth', {}).get('auth_roles')

        if auth_roles_config and isinstance(auth_roles_config, dict) and auth_roles_config.get('roles'):
            logger.info("Loading RBAC configuration from main config (services.chat_app.auth.auth_roles)")
            return auth_roles_config
        elif auth_roles_config:
            logger.warning("auth_roles found in config but has no roles defined, falling back to defaults")
    except Exception as e:
        logger.debug(f"Could not load auth_roles from main config: {e}")

    # Fallback: try standalone auth_roles.yaml file
    # Allow overriding the auth_roles.yaml location via environment variable
    env_config_path = os.getenv('AUTH_ROLES_CONFIG_PATH')

    search_paths = [
        config_path,
        env_config_path,  # Environment-provided path (e.g., container runtime)
        os.path.join(os.getcwd(), 'configs', 'auth_roles.yaml'),  # Local dev
        os.path.join(os.path.dirname(__file__), '..', '..', '..', 'configs', 'auth_roles.yaml'),
    ]

    # First candidate that is set and points at an existing file wins.
    config_file = next(
        (path for path in search_paths if path and os.path.isfile(path)),
        None,
    )

    if config_file:
        logger.info(f"Loading RBAC configuration from: {config_file}")
        with open(config_file, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # An empty file loads as None and a top-level list/scalar is not a
        # usable configuration; reject it here with a clear error instead of
        # failing later when the registry calls .get() on it.
        if not isinstance(config, dict):
            raise RBACConfigError(
                f"RBAC configuration file '{config_file}' must contain a YAML mapping, "
                f"got {type(config).__name__}"
            )
        return config

    # Return minimal default config if no config found
    logger.warning("No auth_roles configuration found, granting wildcard permissions by default")
    return {
        'app_name': 'archi-app',
        'default_role': 'base-user',
        'sso': {'allow_anonymous': False},
        'roles': {
            'base-user': {
                'description': 'Default authenticated user (no roles configured)',
                'permissions': ['*']
            }
        },
        'permissions': {}
    }
force_reload: bool = False) -> RBACRegistry: + """ + Get the global RBAC registry instance (singleton). + + Automatically uses SSO_CLIENT_ID environment variable as app_name if available. + + Args: + config_path: Optional path to configuration file + force_reload: If True, reload configuration even if already loaded + + Returns: + RBACRegistry instance + """ + global _registry + + if _registry is None or force_reload: + config = load_rbac_config(config_path) + + # Use SSO_CLIENT_ID as app_name if available + from src.utils.env import read_secret + app_name = read_secret('SSO_CLIENT_ID') + + _registry = RBACRegistry(config, app_name=app_name) + + if app_name: + logger.info(f"RBAC registry initialized with app_name from SSO_CLIENT_ID: {app_name}") + + return _registry + + +def reset_registry() -> None: + """ + Reset the global registry (for testing purposes). + """ + global _registry + _registry = None diff --git a/src/utils/sandbox/__init__.py b/src/utils/sandbox/__init__.py new file mode 100644 index 000000000..0f785baee --- /dev/null +++ b/src/utils/sandbox/__init__.py @@ -0,0 +1,62 @@ +""" +Sandbox module for containerized code execution. + +This module provides secure, isolated code execution in ephemeral Docker containers. 
+""" + +from src.utils.sandbox.config import ( + ApprovalMode, + RegistryConfig, + ResourceLimits, + RoleSandboxOverrides, + SandboxConfig, + get_role_sandbox_overrides, + get_sandbox_config, + resolve_effective_config, +) +from src.utils.sandbox.executor import ( + FileOutput, + SandboxExecutor, + SandboxResult, +) +from src.utils.sandbox.approval import ( + ApprovalRequest, + ApprovalStatus, + cancel_approvals_for_trace, + cleanup_old_requests, + create_approval_request, + get_approval_request, + get_pending_approvals_for_conversation, + get_pending_approvals_for_trace, + register_approval_callback, + resolve_approval, + wait_for_approval, +) + +__all__ = [ + # Config + "ApprovalMode", + "RegistryConfig", + "ResourceLimits", + "RoleSandboxOverrides", + "SandboxConfig", + "get_role_sandbox_overrides", + "get_sandbox_config", + "resolve_effective_config", + # Executor + "FileOutput", + "SandboxExecutor", + "SandboxResult", + # Approval + "ApprovalRequest", + "ApprovalStatus", + "cancel_approvals_for_trace", + "cleanup_old_requests", + "create_approval_request", + "get_approval_request", + "get_pending_approvals_for_conversation", + "get_pending_approvals_for_trace", + "register_approval_callback", + "resolve_approval", + "wait_for_approval", +] diff --git a/src/utils/sandbox/approval.py b/src/utils/sandbox/approval.py new file mode 100644 index 000000000..16875f2e9 --- /dev/null +++ b/src/utils/sandbox/approval.py @@ -0,0 +1,367 @@ +""" +Sandbox approval mechanism for human-in-the-loop code execution. + +This module provides functionality to pause sandbox execution and wait for +user approval before running code. It supports both auto-approve and +manual approval modes. 
+""" + +from __future__ import annotations + +import threading +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Callable, Dict, List, Optional + +from src.utils.logging import get_logger + +logger = get_logger(__name__) + + +class ApprovalStatus(str, Enum): + """Status of an approval request.""" + + PENDING = "pending" + """Waiting for user decision.""" + + APPROVED = "approved" + """User approved the execution.""" + + REJECTED = "rejected" + """User rejected the execution.""" + + EXPIRED = "expired" + """Approval request timed out.""" + + CANCELLED = "cancelled" + """Request was cancelled (e.g., stream aborted).""" + + +@dataclass +class ApprovalRequest: + """Represents a pending sandbox approval request.""" + + approval_id: str + """Unique identifier for this approval request.""" + + trace_id: str + """The trace ID of the agent run that requested approval.""" + + conversation_id: int + """The conversation this request belongs to.""" + + code: str + """The code to be executed.""" + + language: str + """Programming language (python, bash, sh).""" + + image: str + """Docker image to use for execution.""" + + tool_call_id: str + """The tool call ID from the agent.""" + + status: ApprovalStatus = ApprovalStatus.PENDING + """Current status of the request.""" + + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + """When the request was created.""" + + resolved_at: Optional[datetime] = None + """When the request was approved/rejected/expired.""" + + resolved_by: Optional[str] = None + """User who resolved the request (if any).""" + + timeout_seconds: float = 300.0 + """How long to wait for approval before expiring.""" + + def is_expired(self) -> bool: + """Check if the request has expired.""" + if self.status != ApprovalStatus.PENDING: + return False + elapsed = (datetime.now(timezone.utc) - self.created_at).total_seconds() + return 
elapsed > self.timeout_seconds + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "approval_id": self.approval_id, + "trace_id": self.trace_id, + "conversation_id": self.conversation_id, + "code": self.code, + "language": self.language, + "image": self.image, + "tool_call_id": self.tool_call_id, + "status": self.status.value, + "created_at": self.created_at.isoformat(), + "resolved_at": self.resolved_at.isoformat() if self.resolved_at else None, + "resolved_by": self.resolved_by, + "timeout_seconds": self.timeout_seconds, + } + + +# Global registry of pending approval requests +_approval_requests: Dict[str, ApprovalRequest] = {} +_approval_lock = threading.Lock() + +# Callbacks for notifying approval status changes +_approval_callbacks: Dict[str, List[Callable[[ApprovalRequest], None]]] = {} + + +def create_approval_request( + *, + trace_id: str, + conversation_id: int, + code: str, + language: str, + image: str, + tool_call_id: str, + timeout_seconds: float = 300.0, +) -> ApprovalRequest: + """ + Create a new approval request for sandbox code execution. + + Args: + trace_id: The trace ID of the agent run. + conversation_id: The conversation ID. + code: The code to execute. + language: Programming language. + image: Docker image. + tool_call_id: The tool call ID. + timeout_seconds: How long to wait for approval. + + Returns: + The created ApprovalRequest. 
+ """ + approval_id = str(uuid.uuid4()) + + request = ApprovalRequest( + approval_id=approval_id, + trace_id=trace_id, + conversation_id=conversation_id, + code=code, + language=language, + image=image, + tool_call_id=tool_call_id, + timeout_seconds=timeout_seconds, + ) + + with _approval_lock: + _approval_requests[approval_id] = request + + logger.info( + "Created approval request: id=%s, trace=%s, conversation=%d, language=%s", + approval_id, trace_id, conversation_id, language + ) + + return request + + +def get_approval_request(approval_id: str) -> Optional[ApprovalRequest]: + """Get an approval request by ID.""" + with _approval_lock: + return _approval_requests.get(approval_id) + + +def get_pending_approvals_for_trace(trace_id: str) -> List[ApprovalRequest]: + """Get all pending approval requests for a trace.""" + with _approval_lock: + return [ + req for req in _approval_requests.values() + if req.trace_id == trace_id and req.status == ApprovalStatus.PENDING + ] + + +def get_pending_approvals_for_conversation(conversation_id: int) -> List[ApprovalRequest]: + """Get all pending approval requests for a conversation.""" + with _approval_lock: + return [ + req for req in _approval_requests.values() + if req.conversation_id == conversation_id + and req.status == ApprovalStatus.PENDING + ] + + +def resolve_approval( + approval_id: str, + approved: bool, + resolved_by: Optional[str] = None, +) -> Optional[ApprovalRequest]: + """ + Resolve an approval request (approve or reject). + + Args: + approval_id: The approval request ID. + approved: True to approve, False to reject. + resolved_by: Optional user identifier. + + Returns: + The updated ApprovalRequest, or None if not found. 
+ """ + with _approval_lock: + request = _approval_requests.get(approval_id) + if not request: + logger.warning("Approval request not found: %s", approval_id) + return None + + if request.status != ApprovalStatus.PENDING: + logger.warning( + "Approval request %s already resolved: %s", + approval_id, request.status.value + ) + return request + + if request.is_expired(): + request.status = ApprovalStatus.EXPIRED + request.resolved_at = datetime.now(timezone.utc) + logger.info("Approval request %s expired", approval_id) + else: + request.status = ApprovalStatus.APPROVED if approved else ApprovalStatus.REJECTED + request.resolved_at = datetime.now(timezone.utc) + request.resolved_by = resolved_by + logger.info( + "Approval request %s %s by %s", + approval_id, + "approved" if approved else "rejected", + resolved_by or "unknown" + ) + + # Notify callbacks + callbacks = _approval_callbacks.get(approval_id, []) + + # Call outside lock to avoid deadlocks + for callback in callbacks: + try: + callback(request) + except Exception as e: + logger.error("Approval callback error: %s", e, exc_info=True) + + return request + + +def cancel_approvals_for_trace(trace_id: str) -> List[ApprovalRequest]: + """ + Cancel all pending approvals for a trace (e.g., when stream is aborted). + + Args: + trace_id: The trace ID to cancel approvals for. + + Returns: + List of cancelled approval requests. 
+ """ + cancelled = [] + + with _approval_lock: + for request in _approval_requests.values(): + if request.trace_id == trace_id and request.status == ApprovalStatus.PENDING: + request.status = ApprovalStatus.CANCELLED + request.resolved_at = datetime.now(timezone.utc) + cancelled.append(request) + + if cancelled: + logger.info("Cancelled %d approval requests for trace %s", len(cancelled), trace_id) + + return cancelled + + +def wait_for_approval( + approval_id: str, + timeout: Optional[float] = None, + poll_interval: float = 0.5, +) -> ApprovalRequest: + """ + Wait for an approval request to be resolved. + + This is a blocking call that polls the approval status. + + Args: + approval_id: The approval request ID. + timeout: Maximum time to wait (uses request timeout if None). + poll_interval: How often to check status. + + Returns: + The resolved ApprovalRequest. + + Raises: + ValueError: If approval request not found. + """ + request = get_approval_request(approval_id) + if not request: + raise ValueError(f"Approval request not found: {approval_id}") + + effective_timeout = timeout if timeout is not None else request.timeout_seconds + start_time = time.time() + + while True: + # Check current status + with _approval_lock: + request = _approval_requests.get(approval_id) + + if not request: + raise ValueError(f"Approval request disappeared: {approval_id}") + + # Check if resolved + if request.status != ApprovalStatus.PENDING: + return request + + # Check if expired + if request.is_expired(): + resolve_approval(approval_id, approved=False) + with _approval_lock: + return _approval_requests.get(approval_id, request) + + # Check timeout + elapsed = time.time() - start_time + if elapsed >= effective_timeout: + resolve_approval(approval_id, approved=False) + with _approval_lock: + return _approval_requests.get(approval_id, request) + + # Wait before next poll + time.sleep(poll_interval) + + +def register_approval_callback( + approval_id: str, + callback: 
Callable[[ApprovalRequest], None], +) -> None: + """Register a callback to be notified when an approval is resolved.""" + with _approval_lock: + if approval_id not in _approval_callbacks: + _approval_callbacks[approval_id] = [] + _approval_callbacks[approval_id].append(callback) + + +def cleanup_old_requests(max_age_seconds: float = 3600) -> int: + """ + Remove old resolved requests to prevent memory leaks. + + Args: + max_age_seconds: Remove requests older than this. + + Returns: + Number of requests cleaned up. + """ + now = datetime.now(timezone.utc) + to_remove = [] + + with _approval_lock: + for approval_id, request in _approval_requests.items(): + if request.status != ApprovalStatus.PENDING: + age = (now - request.created_at).total_seconds() + if age > max_age_seconds: + to_remove.append(approval_id) + + for approval_id in to_remove: + del _approval_requests[approval_id] + _approval_callbacks.pop(approval_id, None) + + if to_remove: + logger.debug("Cleaned up %d old approval requests", len(to_remove)) + + return len(to_remove) diff --git a/src/utils/sandbox/config.py b/src/utils/sandbox/config.py new file mode 100644 index 000000000..cc3cc2907 --- /dev/null +++ b/src/utils/sandbox/config.py @@ -0,0 +1,331 @@ +""" +Sandbox configuration schema and validation. + +This module defines the configuration dataclasses for the containerized +sandbox code execution feature. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + +from src.utils.logging import get_logger + +logger = get_logger(__name__) + + +class ApprovalMode(str, Enum): + """Approval mode for sandbox command execution.""" + + AUTO = "auto" + """Commands are executed automatically without user approval.""" + + MANUAL = "manual" + """Each command requires explicit user approval before execution.""" + + +@dataclass +class RegistryConfig: + """Configuration for a custom Docker registry.""" + + url: str = "" + """Registry URL (e.g., 'registry.cern.ch', 'ghcr.io').""" + + username_env: str = "" + """Environment variable name containing the registry username.""" + + password_env: str = "" + """Environment variable name containing the registry password/token.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "RegistryConfig": + """Create RegistryConfig from a dictionary.""" + if not data: + return cls() + return cls( + url=data.get("url", ""), + username_env=data.get("username_env", ""), + password_env=data.get("password_env", ""), + ) + + def is_configured(self) -> bool: + """Check if registry is configured with credentials.""" + return bool(self.url and self.username_env and self.password_env) + + def get_credentials(self) -> Optional[Dict[str, str]]: + """Get registry credentials from environment variables.""" + import os + if not self.is_configured(): + return None + + username = os.environ.get(self.username_env) + password = os.environ.get(self.password_env) + + if not username or not password: + logger.warning( + f"Registry credentials not found in environment. 
" + f"Expected {self.username_env} and {self.password_env}" + ) + return None + + return { + "username": username, + "password": password, + "registry": self.url, + } + + +@dataclass +class ResourceLimits: + """Resource constraints for sandbox containers.""" + + memory: str = "256m" + """Docker memory limit format (e.g., '256m', '1g').""" + + cpu: float = 0.5 + """CPU cores limit (e.g., 0.5 = half a core).""" + + pids_limit: int = 100 + """Maximum number of processes to prevent fork bombs.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ResourceLimits": + """Create ResourceLimits from a dictionary.""" + return cls( + memory=data.get("memory", "256m"), + cpu=float(data.get("cpu", 0.5)), + pids_limit=int(data.get("pids_limit", 100)), + ) + + +@dataclass +class SandboxConfig: + """Configuration for the sandbox execution system.""" + + enabled: bool = False + """Whether sandbox execution is enabled for this deployment.""" + + approval_mode: ApprovalMode = ApprovalMode.AUTO + """Approval mode for sandbox commands: 'auto' or 'manual'.""" + + default_image: str = "python:3.11-slim" + """Default Docker image to use when not specified.""" + + image_allowlist: List[str] = field(default_factory=lambda: ["python:3.11-slim"]) + """List of allowed Docker images. 
Only images in this list can be used.""" + + timeout: float = 30.0 + """Default execution timeout in seconds.""" + + max_timeout: float = 300.0 + """Maximum allowed timeout (roles cannot exceed this).""" + + resource_limits: ResourceLimits = field(default_factory=ResourceLimits) + """Resource limits for containers.""" + + network_enabled: bool = True + """Whether containers have outbound network access.""" + + output_max_chars: int = 100000 + """Maximum characters for stdout/stderr output.""" + + output_max_file_size: int = 10 * 1024 * 1024 # 10MB + """Maximum size in bytes for captured output files.""" + + docker_socket: str = "/var/run/docker.sock" + """Path to Docker socket.""" + + registry: RegistryConfig = field(default_factory=RegistryConfig) + """Custom Docker registry configuration for private images.""" + + @classmethod + def from_dict(cls, data: Optional[Dict[str, Any]]) -> "SandboxConfig": + """Create SandboxConfig from a dictionary (e.g., from YAML config).""" + if not data: + return cls() + + resource_limits_data = data.get("resource_limits", {}) + resource_limits = ResourceLimits.from_dict(resource_limits_data) + + registry_data = data.get("registry", {}) + registry = RegistryConfig.from_dict(registry_data) + + # Parse approval_mode from string + approval_mode_str = data.get("approval_mode", "auto").lower() + try: + approval_mode = ApprovalMode(approval_mode_str) + except ValueError: + logger.warning( + f"Invalid approval_mode '{approval_mode_str}', defaulting to 'auto'" + ) + approval_mode = ApprovalMode.AUTO + + return cls( + enabled=bool(data.get("enabled", False)), + approval_mode=approval_mode, + default_image=data.get("default_image", "python:3.11-slim"), + image_allowlist=data.get("image_allowlist", ["python:3.11-slim"]), + timeout=float(data.get("timeout", 30.0)), + max_timeout=float(data.get("max_timeout", 300.0)), + resource_limits=resource_limits, + network_enabled=bool(data.get("network_enabled", True)), + 
output_max_chars=int(data.get("output_max_chars", 100000)), + output_max_file_size=int(data.get("output_max_file_size", 10 * 1024 * 1024)), + docker_socket=data.get("docker_socket", "/var/run/docker.sock"), + registry=registry, + ) + + def validate(self) -> List[str]: + """ + Validate the configuration and return a list of errors. + + Returns: + List of error messages (empty if valid). + """ + errors = [] + + if self.timeout <= 0: + errors.append("timeout must be positive") + + if self.max_timeout <= 0: + errors.append("max_timeout must be positive") + + if self.timeout > self.max_timeout: + errors.append("timeout cannot exceed max_timeout") + + if not self.image_allowlist: + errors.append("image_allowlist cannot be empty when sandbox is enabled") + + if self.default_image not in self.image_allowlist: + errors.append(f"default_image '{self.default_image}' must be in image_allowlist") + + if self.resource_limits.cpu <= 0: + errors.append("resource_limits.cpu must be positive") + + if self.resource_limits.pids_limit <= 0: + errors.append("resource_limits.pids_limit must be positive") + + return errors + + def is_image_allowed(self, image: str) -> bool: + """Check if an image is in the allowlist.""" + return image in self.image_allowlist + + +@dataclass +class RoleSandboxOverrides: + """Per-role sandbox configuration overrides.""" + + allowed_images: Union[List[str], Literal["*"]] = field(default_factory=list) + """ + Images this role can use. Either a list of image names (subset of deployment allowlist) + or "*" to allow all images from deployment allowlist. 
+ """ + + timeout: Optional[float] = None + """Role-specific timeout override (capped at deployment max_timeout).""" + + @classmethod + def from_dict(cls, data: Optional[Dict[str, Any]]) -> "RoleSandboxOverrides": + """Create RoleSandboxOverrides from a dictionary.""" + if not data: + return cls() + + allowed_images = data.get("allowed_images", []) + timeout = data.get("timeout") + + return cls( + allowed_images=allowed_images, + timeout=float(timeout) if timeout is not None else None, + ) + + +def resolve_effective_config( + base_config: SandboxConfig, + role_overrides: Optional[RoleSandboxOverrides] = None, +) -> SandboxConfig: + """ + Resolve effective sandbox config by applying role overrides to base config. + + Role overrides can only restrict (not expand) the base configuration. + + Args: + base_config: The deployment-level sandbox configuration. + role_overrides: Optional role-specific overrides. + + Returns: + Effective SandboxConfig with role restrictions applied. + """ + if not role_overrides: + return base_config + + # Resolve timeout (role can set custom, but capped at max_timeout) + effective_timeout = base_config.timeout + if role_overrides.timeout is not None: + effective_timeout = min(role_overrides.timeout, base_config.max_timeout) + + # Resolve allowed images + if role_overrides.allowed_images == "*": + effective_images = base_config.image_allowlist + elif role_overrides.allowed_images: + # Role can only use images that are in both role list AND deployment allowlist + effective_images = [ + img for img in role_overrides.allowed_images + if img in base_config.image_allowlist + ] + else: + effective_images = base_config.image_allowlist + + # Resolve default image (must be in effective images) + effective_default = base_config.default_image + if effective_default not in effective_images and effective_images: + effective_default = effective_images[0] + + return SandboxConfig( + enabled=base_config.enabled, + approval_mode=base_config.approval_mode, + 
default_image=effective_default, + image_allowlist=effective_images, + timeout=effective_timeout, + max_timeout=base_config.max_timeout, + resource_limits=base_config.resource_limits, + network_enabled=base_config.network_enabled, + output_max_chars=base_config.output_max_chars, + output_max_file_size=base_config.output_max_file_size, + docker_socket=base_config.docker_socket, + ) + + +def get_sandbox_config() -> SandboxConfig: + """ + Load sandbox configuration from the deployment config. + + Returns: + SandboxConfig loaded from archi.sandbox config section. + """ + try: + from src.utils.config_access import get_archi_config + archi_config = get_archi_config() + sandbox_data = archi_config.get("sandbox", {}) + return SandboxConfig.from_dict(sandbox_data) + except Exception as e: + logger.warning(f"Failed to load sandbox config, using defaults: {e}") + return SandboxConfig() + + +def get_role_sandbox_overrides(role_config: Dict[str, Any]) -> Optional[RoleSandboxOverrides]: + """ + Extract sandbox overrides from a role configuration. + + Args: + role_config: The role configuration dictionary. + + Returns: + RoleSandboxOverrides if present, None otherwise. + """ + sandbox_data = role_config.get("sandbox") + if not sandbox_data: + return None + return RoleSandboxOverrides.from_dict(sandbox_data) diff --git a/src/utils/sandbox/executor.py b/src/utils/sandbox/executor.py new file mode 100644 index 000000000..273720b58 --- /dev/null +++ b/src/utils/sandbox/executor.py @@ -0,0 +1,562 @@ +""" +Sandbox executor for running code in isolated Docker containers. + +This module provides the SandboxExecutor class that manages ephemeral container +lifecycle for secure code execution. 
+""" + +from __future__ import annotations + +import base64 +import io +import os +import tarfile +import tempfile +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + +from src.utils.logging import get_logger +from src.utils.sandbox.config import ResourceLimits, SandboxConfig + +logger = get_logger(__name__) + + +# Language to execution command mapping +LANGUAGE_COMMANDS = { + "python": ["python", "/workspace/script.py"], + "python3": ["python3", "/workspace/script.py"], + "bash": ["bash", "/workspace/script.sh"], + "sh": ["sh", "/workspace/script.sh"], +} + +# Language to file extension mapping +LANGUAGE_EXTENSIONS = { + "python": ".py", + "python3": ".py", + "bash": ".sh", + "sh": ".sh", +} + + +@dataclass +class FileOutput: + """Represents a file generated by sandbox execution.""" + + filename: str + """Name of the file.""" + + mimetype: str + """MIME type of the file.""" + + content_base64: str + """Base64-encoded file content.""" + + size: int + """Size of the file in bytes.""" + + truncated: bool = False + """Whether the file was truncated due to size limits.""" + + +@dataclass +class SandboxResult: + """Result of a sandbox code execution.""" + + stdout: str = "" + """Captured standard output.""" + + stderr: str = "" + """Captured standard error.""" + + exit_code: int = 0 + """Process exit code.""" + + execution_time: float = 0.0 + """Execution time in seconds.""" + + files: List[FileOutput] = field(default_factory=list) + """Generated output files.""" + + truncated: bool = False + """Whether stdout/stderr was truncated.""" + + error: Optional[str] = None + """Error message if execution failed (not a code error, but system error).""" + + timed_out: bool = False + """Whether execution timed out.""" + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "stdout": self.stdout, + "stderr": self.stderr, + "exit_code": self.exit_code, + 
"execution_time": self.execution_time, + "files": [ + { + "filename": f.filename, + "mimetype": f.mimetype, + "content_base64": f.content_base64, + "size": f.size, + "truncated": f.truncated, + } + for f in self.files + ], + "truncated": self.truncated, + "error": self.error, + "timed_out": self.timed_out, + } + + +def _guess_mimetype(filename: str) -> str: + """Guess MIME type from filename extension.""" + ext = Path(filename).suffix.lower() + mime_map = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".svg": "image/svg+xml", + ".pdf": "application/pdf", + ".json": "application/json", + ".csv": "text/csv", + ".txt": "text/plain", + ".html": "text/html", + ".xml": "application/xml", + } + return mime_map.get(ext, "application/octet-stream") + + +class SandboxExecutor: + """ + Manages ephemeral container execution for sandbox code. + + This executor creates Docker containers on-demand, runs code, + captures output, and destroys the container after execution. + """ + + def __init__( + self, + config: Optional[SandboxConfig] = None, + docker_client: Optional[Any] = None, + ): + """ + Initialize the sandbox executor. + + Args: + config: Sandbox configuration. If None, loads from deployment config. + docker_client: Optional Docker client instance. If None, creates one. 
+ """ + self.config = config or self._load_config() + self._docker_client = docker_client + self._client_initialized = False + + def _load_config(self) -> SandboxConfig: + """Load configuration from deployment config.""" + from src.utils.sandbox.config import get_sandbox_config + return get_sandbox_config() + + @property + def docker_client(self): + """Lazy-load Docker client.""" + if not self._client_initialized: + if self._docker_client is None: + try: + import docker + # docker.from_env() automatically reads DOCKER_HOST env var + # for Docker-in-Docker, this should be tcp://sandbox-dind:2375 + docker_host = os.environ.get("DOCKER_HOST", "unix:///var/run/docker.sock") + self._docker_client = docker.from_env() + logger.debug(f"Docker client initialized (DOCKER_HOST={docker_host})") + except Exception as e: + logger.error(f"Failed to initialize Docker client: {e}") + raise RuntimeError( + "Docker client initialization failed. " + "Ensure Docker/DinD is running and accessible." + ) from e + self._client_initialized = True + return self._docker_client + + def execute( + self, + code: str, + language: str = "python", + image: Optional[str] = None, + timeout: Optional[float] = None, + limits: Optional[ResourceLimits] = None, + ) -> SandboxResult: + """ + Execute code in an ephemeral container. + + Args: + code: The code to execute. + language: Programming language (python, bash, sh). + image: Docker image to use. If None, uses default_image. + timeout: Execution timeout in seconds. If None, uses config default. + limits: Resource limits. If None, uses config defaults. + + Returns: + SandboxResult with execution output and metadata. + """ + # Resolve defaults + image = image or self.config.default_image + timeout = timeout if timeout is not None else self.config.timeout + limits = limits or self.config.resource_limits + + # Validate image + if not self.config.is_image_allowed(image): + return SandboxResult( + error=f"Image '{image}' is not in the allowlist. 
" + f"Allowed images: {', '.join(self.config.image_allowlist)}", + exit_code=-1, + ) + + # Validate language + if language not in LANGUAGE_COMMANDS: + return SandboxResult( + error=f"Unsupported language '{language}'. " + f"Supported: {', '.join(LANGUAGE_COMMANDS.keys())}", + exit_code=-1, + ) + + container = None + start_time = time.time() + + try: + # Create container + container = self._create_container(image, limits) + logger.info(f"Created sandbox container {container.short_id} with image {image}") + + # Start container + container.start() + + # Prepare /workspace & /workspace/output with world-writable perms + # (runs as root so it works for any image user) + self._prepare_workspace(container) + + # Inject the user's code + self._inject_code(container, code, language) + + # Wait for completion with timeout + result = self._run_with_timeout(container, timeout, language) + result.execution_time = time.time() - start_time + + # Capture output files if execution succeeded + if not result.error and not result.timed_out: + try: + result.files = self._capture_output_files(container) + except Exception as e: + logger.warning(f"Failed to capture output files: {e}") + + return result + + except Exception as e: + logger.error(f"Sandbox execution failed: {e}", exc_info=True) + return SandboxResult( + error=f"Execution failed: {str(e)}", + exit_code=-1, + execution_time=time.time() - start_time, + ) + finally: + # Always cleanup container + if container: + self._cleanup_container(container) + + def _create_container(self, image: str, limits: ResourceLimits): + """Create an ephemeral container with resource limits. + + The container is configured so that ``/workspace`` and its ``output`` + sub-directory are world-writable (mode 0o777). This makes the + executor image-agnostic: it works whether the image runs as root + (e.g. ``python:3.11-slim``) or as an unprivileged user (e.g. + ``jupyter/scipy-notebook`` which runs as UID 1000 / ``jovyan``). 
+ """ + import docker + + # Build container configuration. + # We do NOT set ``working_dir`` here; the entrypoint script below + # creates /workspace with the right permissions first, then the + # actual code execution later uses ``workdir="/workspace"``. + container_config = { + "image": image, + "command": ["sleep", "infinity"], # Keep alive until we run code + "detach": True, + "mem_limit": limits.memory, + "nano_cpus": int(limits.cpu * 1e9), # Convert cores to nanocpus + "pids_limit": limits.pids_limit, + "network_mode": "bridge" if self.config.network_enabled else "none", + # Security settings + "read_only": False, # Need to write script and outputs + "security_opt": ["no-new-privileges"], + # Don't auto-remove so we can capture output + "auto_remove": False, + } + + try: + container = self.docker_client.containers.create(**container_config) + return container + except docker.errors.ImageNotFound: + # Try to pull the image with registry credentials if configured + logger.info(f"Image {image} not found locally, pulling...") + auth_config = self._get_registry_auth(image) + if auth_config: + logger.debug(f"Using registry credentials for {auth_config.get('registry', 'default')}") + self.docker_client.images.pull(image, auth_config=auth_config) + else: + self.docker_client.images.pull(image) + return self.docker_client.containers.create(**container_config) + + def _prepare_workspace(self, container) -> None: + """Create /workspace and /workspace/output with world-writable permissions. + + All commands are executed as **root** so they succeed regardless of the + image's default user. The directories are set to mode 0o777 so that + subsequent code execution (which may run as an unprivileged user) can + read, write, and create files freely. 
+ """ + setup_cmds = [ + ["mkdir", "-p", "/workspace/output"], + ["chmod", "777", "/workspace"], + ["chmod", "777", "/workspace/output"], + ] + for cmd in setup_cmds: + exit_code, output = container.exec_run(cmd, user="root") + if exit_code != 0: + detail = (output or b"").decode("utf-8", errors="replace") + raise RuntimeError( + f"Workspace preparation failed (cmd={cmd!r}, " + f"exit_code={exit_code}): {detail}" + ) + + def _get_registry_auth(self, image: str) -> Optional[Dict[str, str]]: + """ + Get registry authentication config for an image. + + Returns auth_config dict if the image matches the configured registry, + otherwise returns None. + """ + if not self.config.registry.is_configured(): + return None + + # Check if image is from the configured registry + registry_url = self.config.registry.url + + # Image could be formatted as: + # - registry.example.com/image:tag + # - registry.example.com/namespace/image:tag + if image.startswith(registry_url) or image.startswith(f"{registry_url}/"): + return self.config.registry.get_credentials() + + return None + + def _inject_code(self, container, code: str, language: str): + """Write code to a script file inside the container. + + The script is packaged in a tar archive and extracted to ``/workspace``. + Ownership is set to ``0:0`` (root) with mode ``0o755`` so that every + user in the container can read and execute it. 
+ """ + ext = LANGUAGE_EXTENSIONS.get(language, ".py") + script_name = f"script{ext}" + + # Build a tar archive containing the script + tar_stream = io.BytesIO() + with tarfile.open(fileobj=tar_stream, mode="w") as tar: + script_bytes = code.encode("utf-8") + info = tarfile.TarInfo(name=script_name) + info.size = len(script_bytes) + info.mode = 0o755 # world-readable + executable + info.uid = 0 + info.gid = 0 + tar.addfile(info, io.BytesIO(script_bytes)) + + tar_stream.seek(0) + container.put_archive("/workspace", tar_stream) + logger.debug(f"Injected {len(code)} bytes of {language} code into container") + + def _run_with_timeout( + self, + container, + timeout: float, + language: str, + ) -> SandboxResult: + """Run the code with timeout enforcement.""" + import docker + + command = LANGUAGE_COMMANDS[language] + + try: + # Execute the command + exec_result = container.exec_run( + command, + workdir="/workspace", + demux=True, # Separate stdout and stderr + ) + + # Parse output + stdout_bytes, stderr_bytes = exec_result.output + stdout = (stdout_bytes or b"").decode("utf-8", errors="replace") + stderr = (stderr_bytes or b"").decode("utf-8", errors="replace") + + # Truncate if needed + truncated = False + max_chars = self.config.output_max_chars + + if len(stdout) > max_chars: + stdout = stdout[:max_chars] + f"\n... [truncated at {max_chars} characters]" + truncated = True + + if len(stderr) > max_chars: + stderr = stderr[:max_chars] + f"\n... 
[truncated at {max_chars} characters]" + truncated = True + + return SandboxResult( + stdout=stdout, + stderr=stderr, + exit_code=exec_result.exit_code, + truncated=truncated, + ) + + except Exception as e: + # Check if it was a timeout + if "timeout" in str(e).lower(): + # Try to get partial output + try: + logs = container.logs(stdout=True, stderr=True).decode("utf-8", errors="replace") + except Exception: + logs = "" + + return SandboxResult( + stdout=logs, + stderr="", + exit_code=-1, + timed_out=True, + error=f"Execution timed out after {timeout} seconds", + ) + raise + + def _capture_output_files(self, container) -> List[FileOutput]: + """Capture generated files from the sandbox container. + + Looks in ``/workspace/output/`` first. If that directory is empty or + missing, falls back to scanning ``/workspace/`` for common output files + (images, CSVs, etc.) that the code may have written to the working + directory instead. + """ + files = self._extract_files_from_path(container, "/workspace/output/") + if files: + logger.info("Captured %d file(s) from /workspace/output/", len(files)) + return files + + # Fallback: scan /workspace/ root for image/data files the code may + # have saved outside the output directory. 
+ logger.debug("/workspace/output/ empty or missing; scanning /workspace/ for output files") + SCAN_EXTENSIONS = { + ".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", + ".csv", ".json", ".html", ".txt", ".xml", + } + fallback_files = self._extract_files_from_path(container, "/workspace/") + # Filter to only include output-like files (skip the injected script) + filtered: List[FileOutput] = [] + for f in fallback_files: + ext = Path(f.filename).suffix.lower() + if ext in SCAN_EXTENSIONS: + filtered.append(f) + if filtered: + logger.info( + "Captured %d output file(s) from /workspace/ fallback: %s", + len(filtered), + [f.filename for f in filtered], + ) + else: + logger.debug("No output files found in /workspace/ fallback either") + return filtered + + def _extract_files_from_path( + self, container, path: str + ) -> List[FileOutput]: + """Extract regular files from *path* inside the container.""" + files: List[FileOutput] = [] + try: + bits, stat = container.get_archive(path) + + tar_stream = io.BytesIO() + for chunk in bits: + tar_stream.write(chunk) + tar_stream.seek(0) + + with tarfile.open(fileobj=tar_stream, mode="r") as tar: + for member in tar.getmembers(): + if not member.isfile(): + continue + + if member.size > self.config.output_max_file_size: + logger.warning( + "Skipping large file %s (%d bytes)", member.name, member.size + ) + files.append( + FileOutput( + filename=Path(member.name).name, + mimetype=_guess_mimetype(member.name), + content_base64="", + size=member.size, + truncated=True, + ) + ) + continue + + f = tar.extractfile(member) + if f: + content = f.read() + filename = Path(member.name).name + mimetype = _guess_mimetype(member.name) + logger.debug( + "Extracted file %s (%s, %d bytes) from %s", + filename, mimetype, len(content), path, + ) + files.append( + FileOutput( + filename=filename, + mimetype=mimetype, + content_base64=base64.b64encode(content).decode("ascii"), + size=len(content), + truncated=False, + ) + ) + + except Exception as e: + 
err_str = str(e) + if "No such container path" not in err_str and "404" not in err_str: + logger.warning("Error capturing files from %s: %s", path, e) + else: + logger.debug("Path %s does not exist in container (expected)", path) + + return files + + def _cleanup_container(self, container): + """Force remove the container.""" + try: + container.remove(force=True) + logger.debug(f"Removed sandbox container {container.short_id}") + except Exception as e: + logger.warning(f"Failed to remove container {container.short_id}: {e}") + + def health_check(self) -> bool: + """ + Check if the sandbox executor is operational. + + Returns: + True if Docker is accessible and sandbox is enabled. + """ + if not self.config.enabled: + return False + + try: + self.docker_client.ping() + return True + except Exception as e: + logger.warning(f"Sandbox health check failed: {e}") + return False