diff --git a/strix/runtime/docker_runtime.py b/strix/runtime/docker_runtime.py index 7ba04f88..9822429a 100644 --- a/strix/runtime/docker_runtime.py +++ b/strix/runtime/docker_runtime.py @@ -13,18 +13,42 @@ from .runtime import AbstractRuntime, SandboxInfo - STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10") logger = logging.getLogger(__name__) class DockerRuntime(AbstractRuntime): def __init__(self) -> None: + import platform + + is_mac = platform.system() == "Darwin" + mac_sockets = [ + os.path.expanduser("~/.docker/run/docker.sock"), + os.path.expanduser("~/Library/Containers/com.docker.docker/Data/docker.raw.sock"), + ] + + docker_socket = None + if is_mac and not os.environ.get("DOCKER_HOST"): + for sock in mac_sockets: + if os.path.exists(sock): + docker_socket = sock + break + try: - self.client = docker.from_env() + if docker_socket: + self.client = docker.DockerClient(base_url=f"unix://{docker_socket}") + else: + self.client = docker.from_env() + + self.client.ping() except DockerException as e: logger.exception("Failed to connect to Docker daemon") - raise RuntimeError("Docker is not available or not configured correctly.") from e + hint = ( + "Make sure Docker Desktop is running (macOS)" + if is_mac + else "Try starting Docker (Linux): sudo systemctl start docker" + ) + raise RuntimeError(f"Docker is not available or not configured correctly. {hint}") from e self._scan_container: Container | None = None self._tool_server_port: int | None = None @@ -36,7 +60,7 @@ def _generate_sandbox_token(self) -> str: def _find_available_port(self) -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", 0)) - return cast("int", s.getsockname()[1]) + return cast(int, s.getsockname()[1]) def _get_scan_id(self, agent_id: str) -> str: try: @@ -45,13 +69,21 @@ def _get_scan_id(self, agent_id: str) -> str: tracer = get_global_tracer() if tracer and tracer.scan_config: return str(tracer.scan_config.get("scan_id", "default-scan")) - except ImportError: - logger.debug("Failed to import tracer, using fallback scan ID") - except AttributeError: - logger.debug("Tracer missing scan_config, using fallback scan ID") - + except (ImportError, AttributeError): + logger.debug("Failed to get scan ID from tracer, using fallback") return f"scan-{agent_id.split('-')[0]}" + def _resolve_docker_host(self) -> str: + docker_host = os.getenv("DOCKER_HOST", "") + if not docker_host: + return "127.0.0.1" + + from urllib.parse import urlparse + parsed = urlparse(docker_host) + if parsed.scheme in ("tcp", "http", "https") and parsed.hostname: + return parsed.hostname + return "127.0.0.1" + def _verify_image_available(self, image_name: str, max_retries: int = 3) -> None: def _validate_image(image: docker.models.images.Image) -> None: if not image.id or not image.attrs: @@ -128,9 +160,7 @@ def _create_container_with_retry(self, scan_id: str, max_retries: int = 3) -> Co self._scan_container = container logger.info("Created container %s for scan %s", container.id, scan_id) - self._initialize_container( - container, caido_port, tool_server_port, tool_server_token - ) + self._initialize_container(container, caido_port, tool_server_port, tool_server_token) except DockerException as e: last_exception = e if attempt == max_retries - 1: @@ -138,20 +168,15 @@ def _create_container_with_retry(self, scan_id: str, max_retries: int = 3) -> Co break logger.warning(f"Container creation attempt {attempt + 1}/{max_retries} failed") - self._tool_server_port = None self._tool_server_token = None - - sleep_time = (2**attempt) + (0.1 * attempt) - time.sleep(sleep_time) + time.sleep((2**attempt) + (0.1 * attempt)) else: return container - raise RuntimeError( - f"Failed to create Docker container after {max_retries} attempts: {last_exception}" - ) from last_exception + raise RuntimeError(f"Failed to create Docker container after {max_retries} attempts: {last_exception}") from last_exception - def _get_or_create_scan_container(self, scan_id: str) -> Container: # noqa: PLR0912 + def _get_or_create_scan_container(self, scan_id: str) -> Container: container_name = f"strix-scan-{scan_id}" if self._scan_container: @@ -167,92 +192,40 @@ def _get_or_create_scan_container(self, scan_id: str) -> Container: # noqa: PLR try: container = self.client.containers.get(container_name) container.reload() - - if ( - "strix-scan-id" not in container.labels - or container.labels["strix-scan-id"] != scan_id - ): - logger.warning( - f"Container {container_name} exists but missing/wrong label, updating" - ) - if container.status != "running": logger.info(f"Starting existing container {container_name}") container.start() time.sleep(2) - self._scan_container = container - for env_var in container.attrs["Config"]["Env"]: if env_var.startswith("TOOL_SERVER_PORT="): self._tool_server_port = int(env_var.split("=")[1]) elif env_var.startswith("TOOL_SERVER_TOKEN="): self._tool_server_token = env_var.split("=")[1] - - logger.info(f"Reusing existing container {container_name}") - + return container except NotFound: pass except DockerException as e: logger.warning(f"Failed to get container by name {container_name}: {e}") - else: - return container - try: - containers = self.client.containers.list( - all=True, filters={"label": f"strix-scan-id={scan_id}"} - ) - if containers: - container = containers[0] - if container.status != "running": - container.start() - time.sleep(2) - self._scan_container = container - - for env_var in container.attrs["Config"]["Env"]: - if env_var.startswith("TOOL_SERVER_PORT="): - self._tool_server_port = int(env_var.split("=")[1]) - elif env_var.startswith("TOOL_SERVER_TOKEN="): - self._tool_server_token = env_var.split("=")[1] - - logger.info(f"Found existing container by label for scan {scan_id}") - return container - except DockerException as e: - logger.warning("Failed to find existing container by label for scan %s: %s", scan_id, e) - - logger.info("Creating new Docker container for scan %s", scan_id) return self._create_container_with_retry(scan_id) - def _initialize_container( - self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str - ) -> None: + def _initialize_container(self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str) -> None: logger.info("Initializing Caido proxy on port %s", caido_port) - result = container.exec_run( - f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'", - detach=False, - ) - + container.exec_run(f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'", detach=False) time.sleep(5) - - result = container.exec_run( - "bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'", user="pentester" - ) - caido_token = result.output.decode().strip() if result.exit_code == 0 else "" - + container.exec_run("bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'", user="pentester") container.exec_run( f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && " - f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} " + f"STRIX_SANDBOX_MODE=true CAIDO_PORT={caido_port} CAIDO_API_TOKEN=$CAIDO_API_TOKEN " f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} " f"--host 0.0.0.0 --port {tool_server_port} &'", detach=True, user="pentester", ) - time.sleep(5) - def _copy_local_directory_to_container( - self, container: Container, local_path: str, target_name: str | None = None - ) -> None: + def _copy_local_directory_to_container(self, container: Container, local_path: str, target_name: str | None = None) -> None: import tarfile from io import BytesIO @@ -262,13 +235,8 @@ def _copy_local_directory_to_container( logger.warning(f"Local path does not exist or is not directory: {local_path_obj}") return - if target_name: - logger.info( - f"Copying local directory {local_path_obj} to container at " - f"/workspace/{target_name}" - ) - else: - logger.info(f"Copying local directory {local_path_obj} to container") + target_display = f"/workspace/{target_name}" if target_name else "/workspace" + logger.info(f"Copying local directory {local_path_obj} to container at {target_display}") tar_buffer = BytesIO() with tarfile.open(fileobj=tar_buffer, mode="w") as tar: @@ -280,23 +248,12 @@ def _copy_local_directory_to_container( tar_buffer.seek(0) container.put_archive("/workspace", tar_buffer.getvalue()) - - container.exec_run( - "chown -R pentester:pentester /workspace && chmod -R 755 /workspace", - user="root", - ) - + container.exec_run("chown -R pentester:pentester /workspace && chmod -R 755 /workspace", user="root") logger.info("Successfully copied local directory to /workspace") - except (OSError, DockerException): logger.exception("Failed to copy local directory to container") - async def create_sandbox( - self, - agent_id: str, - existing_token: str | None = None, - local_sources: list[dict[str, str]] | None = None, - ) -> SandboxInfo: + async def create_sandbox(self, agent_id: str, existing_token: str | None = None, local_sources: list[dict[str, str]] | None = None) -> SandboxInfo: scan_id = self._get_scan_id(agent_id) container = self._get_or_create_scan_container(scan_id) @@ -306,11 +263,7 @@ async def create_sandbox( source_path = source.get("source_path") if not source_path: continue - - target_name = source.get("workspace_subdir") - if not target_name: - target_name = Path(source_path).name or f"target_{index}" - + target_name = source.get("workspace_subdir") or Path(source_path).name or f"target_{index}" self._copy_local_directory_to_container(container, source_path, target_name) setattr(self, source_copied_key, True) @@ -318,13 +271,11 @@ async def create_sandbox( if container_id is None: raise RuntimeError("Docker container ID is unexpectedly None") - token = existing_token if existing_token is not None else self._tool_server_token - + token = existing_token or self._tool_server_token if self._tool_server_port is None or token is None: raise RuntimeError("Tool server not initialized or no token available") api_url = await self.get_sandbox_url(container_id, self._tool_server_port) - await self._register_agent_with_tool_server(api_url, agent_id, token) return { @@ -335,11 +286,8 @@ async def create_sandbox( "agent_id": agent_id, } - async def _register_agent_with_tool_server( - self, api_url: str, agent_id: str, token: str - ) -> None: + async def _register_agent_with_tool_server(self, api_url: str, agent_id: str, token: str) -> None: import httpx - try: async with httpx.AsyncClient(trust_env=False) as client: response = await client.post( @@ -357,29 +305,12 @@ async def get_sandbox_url(self, container_id: str, port: int) -> str: try: container = self.client.containers.get(container_id) container.reload() - host = self._resolve_docker_host() - except NotFound: raise ValueError(f"Container {container_id} not found.") from None except DockerException as e: raise RuntimeError(f"Failed to get container URL for {container_id}: {e}") from e - else: - return f"http://{host}:{port}" - - def _resolve_docker_host(self) -> str: - docker_host = os.getenv("DOCKER_HOST", "") - if not docker_host: - return "127.0.0.1" - - from urllib.parse import urlparse - - parsed = urlparse(docker_host) - - if parsed.scheme in ("tcp", "http", "https") and parsed.hostname: - return parsed.hostname - - return "127.0.0.1" + return f"http://{host}:{port}" async def destroy_sandbox(self, container_id: str) -> None: logger.info("Destroying scan container %s", container_id) @@ -388,11 +319,9 @@ async def destroy_sandbox(self, container_id: str) -> None: container.stop() container.remove() logger.info("Successfully destroyed container %s", container_id) - self._scan_container = None self._tool_server_port = None self._tool_server_token = None - except NotFound: logger.warning("Container %s not found for destruction.", container_id) except DockerException as e: