Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 59 additions & 130 deletions strix/runtime/docker_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,42 @@

from .runtime import AbstractRuntime, SandboxInfo


STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10")
logger = logging.getLogger(__name__)


class DockerRuntime(AbstractRuntime):
def __init__(self) -> None:
import platform

is_mac = platform.system() == "Darwin"
mac_sockets = [
os.path.expanduser("~/.docker/run/docker.sock"),
os.path.expanduser("~/Library/Containers/com.docker.docker/Data/docker.raw.sock"),
]

docker_socket = None
if is_mac and not os.environ.get("DOCKER_HOST"):
for sock in mac_sockets:
if os.path.exists(sock):
docker_socket = sock
break

try:
self.client = docker.from_env()
if docker_socket:
self.client = docker.DockerClient(base_url=f"unix://{docker_socket}")
else:
self.client = docker.from_env()

self.client.ping()
except DockerException as e:
logger.exception("Failed to connect to Docker daemon")
raise RuntimeError("Docker is not available or not configured correctly.") from e
hint = (
"Make sure Docker Desktop is running (macOS)"
if is_mac
else "Try starting Docker (Linux): sudo systemctl start docker"
)
raise RuntimeError(f"Docker is not available or not configured correctly. {hint}") from e

self._scan_container: Container | None = None
self._tool_server_port: int | None = None
Expand All @@ -36,7 +60,7 @@ def _generate_sandbox_token(self) -> str:
def _find_available_port(self) -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return cast("int", s.getsockname()[1])
return cast(int, s.getsockname()[1])

def _get_scan_id(self, agent_id: str) -> str:
try:
Expand All @@ -45,13 +69,21 @@ def _get_scan_id(self, agent_id: str) -> str:
tracer = get_global_tracer()
if tracer and tracer.scan_config:
return str(tracer.scan_config.get("scan_id", "default-scan"))
except ImportError:
logger.debug("Failed to import tracer, using fallback scan ID")
except AttributeError:
logger.debug("Tracer missing scan_config, using fallback scan ID")

except (ImportError, AttributeError):
logger.debug("Failed to get scan ID from tracer, using fallback")
return f"scan-{agent_id.split('-')[0]}"

def _resolve_docker_host(self) -> str:
docker_host = os.getenv("DOCKER_HOST", "")
if not docker_host:
return "127.0.0.1"

from urllib.parse import urlparse
parsed = urlparse(docker_host)
if parsed.scheme in ("tcp", "http", "https") and parsed.hostname:
return parsed.hostname
return "127.0.0.1"

def _verify_image_available(self, image_name: str, max_retries: int = 3) -> None:
def _validate_image(image: docker.models.images.Image) -> None:
if not image.id or not image.attrs:
Expand Down Expand Up @@ -128,30 +160,23 @@ def _create_container_with_retry(self, scan_id: str, max_retries: int = 3) -> Co
self._scan_container = container
logger.info("Created container %s for scan %s", container.id, scan_id)

self._initialize_container(
container, caido_port, tool_server_port, tool_server_token
)
self._initialize_container(container, caido_port, tool_server_port, tool_server_token)
except DockerException as e:
last_exception = e
if attempt == max_retries - 1:
logger.exception(f"Failed to create container after {max_retries} attempts")
break

logger.warning(f"Container creation attempt {attempt + 1}/{max_retries} failed")

self._tool_server_port = None
self._tool_server_token = None

sleep_time = (2**attempt) + (0.1 * attempt)
time.sleep(sleep_time)
time.sleep((2**attempt) + (0.1 * attempt))
else:
return container

raise RuntimeError(
f"Failed to create Docker container after {max_retries} attempts: {last_exception}"
) from last_exception
raise RuntimeError(f"Failed to create Docker container after {max_retries} attempts: {last_exception}") from last_exception

def _get_or_create_scan_container(self, scan_id: str) -> Container: # noqa: PLR0912
def _get_or_create_scan_container(self, scan_id: str) -> Container:
container_name = f"strix-scan-{scan_id}"

if self._scan_container:
Expand All @@ -167,92 +192,40 @@ def _get_or_create_scan_container(self, scan_id: str) -> Container: # noqa: PLR
try:
container = self.client.containers.get(container_name)
container.reload()

if (
"strix-scan-id" not in container.labels
or container.labels["strix-scan-id"] != scan_id
):
logger.warning(
f"Container {container_name} exists but missing/wrong label, updating"
)

if container.status != "running":
logger.info(f"Starting existing container {container_name}")
container.start()
time.sleep(2)

self._scan_container = container

for env_var in container.attrs["Config"]["Env"]:
if env_var.startswith("TOOL_SERVER_PORT="):
self._tool_server_port = int(env_var.split("=")[1])
elif env_var.startswith("TOOL_SERVER_TOKEN="):
self._tool_server_token = env_var.split("=")[1]

logger.info(f"Reusing existing container {container_name}")

return container
except NotFound:
pass
except DockerException as e:
logger.warning(f"Failed to get container by name {container_name}: {e}")
else:
return container

try:
containers = self.client.containers.list(
all=True, filters={"label": f"strix-scan-id={scan_id}"}
)
if containers:
container = containers[0]
if container.status != "running":
container.start()
time.sleep(2)
self._scan_container = container

for env_var in container.attrs["Config"]["Env"]:
if env_var.startswith("TOOL_SERVER_PORT="):
self._tool_server_port = int(env_var.split("=")[1])
elif env_var.startswith("TOOL_SERVER_TOKEN="):
self._tool_server_token = env_var.split("=")[1]

logger.info(f"Found existing container by label for scan {scan_id}")
return container
except DockerException as e:
logger.warning("Failed to find existing container by label for scan %s: %s", scan_id, e)

logger.info("Creating new Docker container for scan %s", scan_id)
return self._create_container_with_retry(scan_id)

def _initialize_container(
self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str
) -> None:
def _initialize_container(self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str) -> None:
logger.info("Initializing Caido proxy on port %s", caido_port)
result = container.exec_run(
f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'",
detach=False,
)

container.exec_run(f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'", detach=False)
time.sleep(5)

result = container.exec_run(
"bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'", user="pentester"
)
caido_token = result.output.decode().strip() if result.exit_code == 0 else ""

container.exec_run("bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'", user="pentester")
container.exec_run(
f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && "
f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} "
f"STRIX_SANDBOX_MODE=true CAIDO_PORT={caido_port} CAIDO_API_TOKEN=$CAIDO_API_TOKEN "
f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} "
f"--host 0.0.0.0 --port {tool_server_port} &'",
detach=True,
user="pentester",
)

time.sleep(5)

def _copy_local_directory_to_container(
self, container: Container, local_path: str, target_name: str | None = None
) -> None:
def _copy_local_directory_to_container(self, container: Container, local_path: str, target_name: str | None = None) -> None:
import tarfile
from io import BytesIO

Expand All @@ -262,13 +235,8 @@ def _copy_local_directory_to_container(
logger.warning(f"Local path does not exist or is not directory: {local_path_obj}")
return

if target_name:
logger.info(
f"Copying local directory {local_path_obj} to container at "
f"/workspace/{target_name}"
)
else:
logger.info(f"Copying local directory {local_path_obj} to container")
target_display = f"/workspace/{target_name}" if target_name else "/workspace"
logger.info(f"Copying local directory {local_path_obj} to container at {target_display}")

tar_buffer = BytesIO()
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
Expand All @@ -280,23 +248,12 @@ def _copy_local_directory_to_container(

tar_buffer.seek(0)
container.put_archive("/workspace", tar_buffer.getvalue())

container.exec_run(
"chown -R pentester:pentester /workspace && chmod -R 755 /workspace",
user="root",
)

container.exec_run("chown -R pentester:pentester /workspace && chmod -R 755 /workspace", user="root")
logger.info("Successfully copied local directory to /workspace")

except (OSError, DockerException):
logger.exception("Failed to copy local directory to container")

async def create_sandbox(
self,
agent_id: str,
existing_token: str | None = None,
local_sources: list[dict[str, str]] | None = None,
) -> SandboxInfo:
async def create_sandbox(self, agent_id: str, existing_token: str | None = None, local_sources: list[dict[str, str]] | None = None) -> SandboxInfo:
scan_id = self._get_scan_id(agent_id)
container = self._get_or_create_scan_container(scan_id)

Expand All @@ -306,25 +263,19 @@ async def create_sandbox(
source_path = source.get("source_path")
if not source_path:
continue

target_name = source.get("workspace_subdir")
if not target_name:
target_name = Path(source_path).name or f"target_{index}"

target_name = source.get("workspace_subdir") or Path(source_path).name or f"target_{index}"
self._copy_local_directory_to_container(container, source_path, target_name)
setattr(self, source_copied_key, True)

container_id = container.id
if container_id is None:
raise RuntimeError("Docker container ID is unexpectedly None")

token = existing_token if existing_token is not None else self._tool_server_token

token = existing_token or self._tool_server_token
if self._tool_server_port is None or token is None:
raise RuntimeError("Tool server not initialized or no token available")

api_url = await self.get_sandbox_url(container_id, self._tool_server_port)

await self._register_agent_with_tool_server(api_url, agent_id, token)

return {
Expand All @@ -335,11 +286,8 @@ async def create_sandbox(
"agent_id": agent_id,
}

async def _register_agent_with_tool_server(
self, api_url: str, agent_id: str, token: str
) -> None:
async def _register_agent_with_tool_server(self, api_url: str, agent_id: str, token: str) -> None:
import httpx

try:
async with httpx.AsyncClient(trust_env=False) as client:
response = await client.post(
Expand All @@ -357,29 +305,12 @@ async def get_sandbox_url(self, container_id: str, port: int) -> str:
try:
container = self.client.containers.get(container_id)
container.reload()

host = self._resolve_docker_host()

except NotFound:
raise ValueError(f"Container {container_id} not found.") from None
except DockerException as e:
raise RuntimeError(f"Failed to get container URL for {container_id}: {e}") from e
else:
return f"http://{host}:{port}"

def _resolve_docker_host(self) -> str:
docker_host = os.getenv("DOCKER_HOST", "")
if not docker_host:
return "127.0.0.1"

from urllib.parse import urlparse

parsed = urlparse(docker_host)

if parsed.scheme in ("tcp", "http", "https") and parsed.hostname:
return parsed.hostname

return "127.0.0.1"
return f"http://{host}:{port}"

async def destroy_sandbox(self, container_id: str) -> None:
logger.info("Destroying scan container %s", container_id)
Expand All @@ -388,11 +319,9 @@ async def destroy_sandbox(self, container_id: str) -> None:
container.stop()
container.remove()
logger.info("Successfully destroyed container %s", container_id)

self._scan_container = None
self._tool_server_port = None
self._tool_server_token = None

except NotFound:
logger.warning("Container %s not found for destruction.", container_id)
except DockerException as e:
Expand Down