diff --git a/.github/workflows/deploy-proxy.yml b/.github/workflows/deploy-proxy.yml new file mode 100644 index 0000000..e15183c --- /dev/null +++ b/.github/workflows/deploy-proxy.yml @@ -0,0 +1,172 @@ +name: Deploy Proxy + +on: + workflow_dispatch: + push: + branches: [main] + paths: + - 'cmd/proxy/**' + - 'internal/proxy/**' + - 'deploy/proxy.service' + +jobs: + deploy: + runs-on: ubuntu-latest + environment: staging + permissions: + contents: read + id-token: write + + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + + - name: Build proxy binary + run: | + mkdir -p bin + CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-s -w' -o bin/proxy ./cmd/proxy + + - name: Authenticate to GCP + uses: google-github-actions/auth@v2 + with: + workload_identity_provider: ${{ secrets.GCP_WORKLOAD_IDENTITY_PROVIDER }} + service_account: ${{ secrets.GCP_SERVICE_ACCOUNT }} + + - name: Set up Cloud SDK + uses: google-github-actions/setup-gcloud@v2 + + - name: Discover and deploy proxy to VMD instances + env: + GCP_PROJECT: ${{ vars.GCP_PROJECT }} + VMD_LABEL: ${{ vars.VMD_LABEL }} + VMD_INSTALL_DIR: ${{ vars.VMD_INSTALL_DIR }} + SHA: ${{ github.sha }} + SANDBOX_ACCESS_TOKEN_SEED: ${{ secrets.SANDBOX_ACCESS_TOKEN_SEED_STAGING }} + TERMINAL_ALLOWED_ORIGINS: ${{ vars.TERMINAL_ALLOWED_ORIGINS_STAGING }} + REQUIRE_DATA_PLANE: ${{ vars.REQUIRE_DATA_PLANE_STAGING }} + run: | + python3 - <<'PYEOF' + import os, sys, subprocess, textwrap + from concurrent.futures import ThreadPoolExecutor, as_completed + import re as _re + + project = os.environ['GCP_PROJECT'] + label = os.environ.get('VMD_LABEL', 'component=vmd') + install_dir = os.environ.get('VMD_INSTALL_DIR', '/usr/local/bin') + sha = os.environ['SHA'][:8] + + access_seed = os.environ.get('SANDBOX_ACCESS_TOKEN_SEED', '') + if access_seed and not _re.fullmatch(r'[0-9a-fA-F]{64,}', access_seed): + print('ERROR: SANDBOX_ACCESS_TOKEN_SEED must be hex-encoded, >= 32 bytes (64 hex chars)', file=sys.stderr) + sys.exit(1) + terminal_origins = os.environ.get('TERMINAL_ALLOWED_ORIGINS', '') + if terminal_origins and not _re.fullmatch(r'[A-Za-z0-9.,:/*\-]+', terminal_origins): + print('ERROR: TERMINAL_ALLOWED_ORIGINS contains disallowed characters', file=sys.stderr) + sys.exit(1) + require_data_plane = os.environ.get('REQUIRE_DATA_PLANE', '') + if require_data_plane not in ('', '0', '1'): + print('ERROR: REQUIRE_DATA_PLANE must be empty, "0", or "1"', file=sys.stderr) + sys.exit(1) + + result = subprocess.run([ + 'gcloud', 'compute', 'instances', 'list', + f'--project={project}', + f'--filter=labels.{label} AND status=RUNNING', + '--format=csv[no-heading](name,zone)', + ], capture_output=True, text=True, check=True) + + instances = [ + {'name': r[0], 'zone': r[1]} + for line in result.stdout.strip().splitlines() + if line.strip() + for r in [line.strip().split(',')] + ] + + if not instances: + print(f'No instances with label {label} found in {project}', file=sys.stderr) + sys.exit(1) + + print(f'Deploying proxy to {len(instances)} instance(s)') + + def deploy(inst): + name, zone = inst['name'], inst['zone'] + tag = f'{name}/{zone}' + + for src, dst in [ + ('bin/proxy', f'/tmp/proxy-{sha}'), + ('deploy/proxy.service', '/tmp/proxy.service'), + ]: + subprocess.run([ + 'gcloud', 'compute', 'scp', src, f'{name}:{dst}', + f'--zone={zone}', f'--project={project}', + '--quiet', '--tunnel-through-iap', + ], check=True, capture_output=True) + print(f'[{tag}] proxy uploaded') + + # Deploy proxy only. VMD is NOT restarted. + deploy_script = textwrap.dedent(f''' + set -euo pipefail + + sudo mv /tmp/proxy-{sha} {install_dir}/proxy + sudo chmod +x {install_dir}/proxy + + sudo mv /tmp/proxy.service /etc/systemd/system/proxy.service + sudo systemctl daemon-reload + sudo systemctl enable proxy + + sudo mkdir -p /etc/sandbox + if [ -n "{access_seed}" ]; then + sudo tee /etc/sandbox/proxy.env > /dev/null <&2 + sudo systemctl status --no-pager proxy >&2 || true + sudo journalctl -u proxy --no-pager -n 40 >&2 || true + exit 1 + ) + ''') + r = subprocess.run([ + 'gcloud', 'compute', 'ssh', name, + f'--zone={zone}', f'--project={project}', + '--quiet', '--tunnel-through-iap', + '--command', deploy_script, + ], capture_output=True, text=True) + if r.returncode != 0: + raise RuntimeError( + f'proxy not healthy\n' + f'--- stdout ---\n{r.stdout}\n' + f'--- stderr ---\n{r.stderr}' + ) + print(f'[{tag}] proxy active') + + failed = [] + with ThreadPoolExecutor(max_workers=len(instances)) as ex: + futures = {ex.submit(deploy, inst): inst for inst in instances} + for f in as_completed(futures): + inst = futures[f] + try: + f.result() + except Exception as e: + tag = f"{inst['name']}/{inst['zone']}" + print(f'[{tag}] FAILED: {e}', file=sys.stderr) + failed.append(tag) + + if failed: + print(f'Deploy failed on: {", ".join(failed)}', file=sys.stderr) + sys.exit(1) + + print(f'Deployed proxy to {len(instances)} instance(s). sha={sha}') + PYEOF diff --git a/.github/workflows/deploy-vmd.yml b/.github/workflows/deploy-vmd.yml index ac7f4a4..764dec3 100644 --- a/.github/workflows/deploy-vmd.yml +++ b/.github/workflows/deploy-vmd.yml @@ -7,11 +7,13 @@ on: paths: - 'cmd/vmd/**' - 'cmd/boxd/**' - - 'cmd/proxy/**' - 'internal/vm/**' - 'internal/network/**' - - 'internal/proxy/**' - - 'deploy/proxy.service' + - 'deploy/superserve-vmd.service' + - 'deploy/firecracker@.service' + - 'deploy/firecracker-netns@.service' + - 'deploy/sandboxes.slice' + - 'scripts/fc-cleanup' jobs: deploy: @@ -30,12 +32,25 @@ jobs: go-version-file: go.mod cache: true + - name: Check if boxd source changed + id: boxd-changed + run: | + # Compare boxd-related paths against the parent commit. + if git diff --name-only HEAD~1 HEAD | grep -qE '^cmd/boxd/|^proto/boxdpb/|^proto/boxd\.proto'; then + echo "changed=true" >> "$GITHUB_OUTPUT" + else + echo "changed=false" >> "$GITHUB_OUTPUT" + fi + - name: Build binaries run: | mkdir -p bin - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-s -w' -o bin/vmd ./cmd/vmd - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-s -w' -o bin/boxd ./cmd/boxd - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-s -w' -o bin/proxy ./cmd/proxy + CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-s -w' -trimpath -o bin/vmd ./cmd/vmd + if [ "${{ steps.boxd-changed.outputs.changed }}" = "true" ]; then + CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags '-s -w' -trimpath -o bin/boxd ./cmd/boxd + else + echo "boxd source unchanged — skipping build" + fi - name: Authenticate to GCP uses: google-github-actions/auth@v2 @@ -53,18 +68,19 @@ jobs: VMD_SERVICE: ${{ vars.VMD_SERVICE }} VMD_INSTALL_DIR: ${{ vars.VMD_INSTALL_DIR }} SHA: ${{ github.sha }} + BOXD_CHANGED: ${{ steps.boxd-changed.outputs.changed }} run: | python3 - <<'PYEOF' - import os, sys, json, subprocess, textwrap + import os, sys, subprocess, textwrap from concurrent.futures import ThreadPoolExecutor, as_completed project = os.environ['GCP_PROJECT'] label = os.environ.get('VMD_LABEL', 'component=vmd') - service = os.environ.get('VMD_SERVICE', 'vmd') + service = os.environ.get('VMD_SERVICE', 'superserve-vmd') install_dir = os.environ.get('VMD_INSTALL_DIR', '/usr/local/bin') sha = os.environ['SHA'][:8] + boxd_changed = os.environ.get('BOXD_CHANGED', 'true') == 'true' - # Discover instances by label. result = subprocess.run([ 'gcloud', 'compute', 'instances', 'list', f'--project={project}', @@ -83,97 +99,89 @@ jobs: print(f'No instances with label {label} found in {project}', file=sys.stderr) sys.exit(1) - print(f'Deploying to {len(instances)} instance(s)') + print(f'Deploying VMD to {len(instances)} instance(s)') def deploy(inst): name, zone = inst['name'], inst['zone'] tag = f'{name}/{zone}' - - for binary in ('vmd', 'boxd', 'proxy'): - subprocess.run([ - 'gcloud', 'compute', 'scp', - f'bin/{binary}', - f'{name}:/tmp/{binary}-{sha}', - f'--zone={zone}', - f'--project={project}', - '--quiet', - '--tunnel-through-iap', - ], check=True, capture_output=True) - subprocess.run([ - 'gcloud', 'compute', 'scp', - 'deploy/proxy.service', - f'{name}:/tmp/proxy.service', - f'--zone={zone}', - f'--project={project}', - '--quiet', - '--tunnel-through-iap', + scp = lambda src, dst: subprocess.run([ + 'gcloud', 'compute', 'scp', src, f'{name}:{dst}', + f'--zone={zone}', f'--project={project}', + '--quiet', '--tunnel-through-iap', ], check=True, capture_output=True) - print(f'[{tag}] binaries uploaded') - - # Inject updated boxd into the base rootfs atomically: - # 1. copy ROOTFS -> ROOTFS.new on the same filesystem - # 2. mount the copy, cp boxd into it, umount - # 3. mv ROOTFS.new -> ROOTFS (rename is atomic on POSIX) - # The original rootfs is untouched until the final rename, so a - # failure at any earlier step leaves the system in its - # previous working state. The trap ensures we clean up any - # staging file/mount on failure. - # - # textwrap.dedent is used so we can keep the script indented - # consistently with the enclosing Python code (which YAML - # requires) without leaking that indentation into the shell. - # NOTE on braces: GitHub Actions scans the YAML for its own - # expression syntax (dollar-double-brace) EVERYWHERE in the - # file, including inside Python heredocs and string literals. - # We cannot use doubled curly braces in f-strings here because - # the raw YAML characters would get parsed as a GitHub - # expression and fail with "Unrecognized named-value: ROOTFS". - # Workaround: use $ROOTFS / $STAGING without curly braces. - # Bash terminates variable names at non-identifier chars - # (dot, space, slash), so $ROOTFS.new.$$ works identically. - # The cleanup callback is inlined into trap so we avoid any - # standalone curly-brace pairs from a bash function body. + + # Upload binaries + scp('bin/vmd', f'/tmp/vmd-{sha}') + if boxd_changed: + scp('bin/boxd', f'/tmp/boxd-{sha}') + + # Upload systemd units and helper scripts + scp('deploy/superserve-vmd.service', '/tmp/superserve-vmd.service') + scp('deploy/firecracker@.service', '/tmp/firecracker@.service') + scp('deploy/firecracker-netns@.service', '/tmp/firecracker-netns@.service') + scp('deploy/sandboxes.slice', '/tmp/sandboxes.slice') + scp('scripts/fc-cleanup', '/tmp/fc-cleanup') + print(f'[{tag}] files uploaded') + inject_script = textwrap.dedent(f''' set -euo pipefail - sudo mv /tmp/vmd-{sha} {install_dir}/vmd - sudo mv /tmp/boxd-{sha} {install_dir}/boxd - sudo chmod +x {install_dir}/vmd {install_dir}/boxd - - # Resolve the base rootfs path from whichever env file is - # present. We cannot use `grep -s ... file1 file2` because - # grep exits with 2 when ANY listed file is missing (even - # with -s), and `set -euo pipefail` treats that as a fatal - # error via the pipeline. Enumerate candidate files safely. - ROOTFS="" - for env_file in /etc/superserve/vmd.env /etc/agentbox/vmd.env; do - if [ -f "$env_file" ]; then - candidate=$(grep "^BASE_ROOTFS_PATH=" "$env_file" | head -1 | cut -d= -f2) || true - if [ -n "$candidate" ]; then - ROOTFS="$candidate" - break + # Install VMD binary + sudo mv /tmp/vmd-{sha} {install_dir}/vmd + sudo chmod +x {install_dir}/vmd + + # Install systemd units + sudo mv /tmp/superserve-vmd.service /etc/systemd/system/superserve-vmd.service + sudo mv /tmp/firecracker@.service /etc/systemd/system/firecracker@.service + sudo mv /tmp/firecracker-netns@.service /etc/systemd/system/firecracker-netns@.service + sudo mv /tmp/sandboxes.slice /etc/systemd/system/sandboxes.slice + sudo systemctl daemon-reload + + # Install helper scripts + sudo mv /tmp/fc-cleanup {install_dir}/fc-cleanup + sudo chmod +x {install_dir}/fc-cleanup + + # Only inject boxd + rebuild rootfs when boxd source changed. + # Skipping this preserves the rootfs hash so VMD's template + # cache works — no ~8s cold boot on VMD-only deploys. + BOXD_SRC_CHANGED={'true' if boxd_changed else 'false'} + if [ "$BOXD_SRC_CHANGED" = "true" ]; then + sudo mv /tmp/boxd-{sha} {install_dir}/boxd + sudo chmod +x {install_dir}/boxd + + ROOTFS="" + for env_file in /etc/sandbox/vmd.env; do + if [ -f "$env_file" ]; then + candidate=$(grep "^BASE_ROOTFS_PATH=" "$env_file" | head -1 | cut -d= -f2) || true + if [ -n "$candidate" ]; then + ROOTFS="$candidate" + break + fi fi + done + + if [ -n "$ROOTFS" ] && [ -f "$ROOTFS" ]; then + STAGING="$ROOTFS.new.$$" + MNT=$(mktemp -d) + trap '\''if mountpoint -q "$MNT" 2>/dev/null; then sudo umount "$MNT" || true; fi; rmdir "$MNT" 2>/dev/null || true; sudo rm -f "$STAGING" 2>/dev/null || true'\'' EXIT + + sudo cp --reflink=auto "$ROOTFS" "$STAGING" + sudo mount -o loop "$STAGING" "$MNT" + sudo cp {install_dir}/boxd "$MNT/usr/local/bin/boxd" + sudo chmod +x "$MNT/usr/local/bin/boxd" + sudo umount "$MNT" + rmdir "$MNT" + sudo mv "$STAGING" "$ROOTFS" + trap - EXIT + echo "boxd injected into rootfs" + else + echo "WARNING: BASE_ROOTFS_PATH not found; skipping rootfs inject" fi - done - - if [ -n "$ROOTFS" ] && [ -f "$ROOTFS" ]; then - STAGING="$ROOTFS.new.$$" - MNT=$(mktemp -d) - trap '\''if mountpoint -q "$MNT" 2>/dev/null; then sudo umount "$MNT" || true; fi; rmdir "$MNT" 2>/dev/null || true; sudo rm -f "$STAGING" 2>/dev/null || true'\'' EXIT - - sudo cp --reflink=auto "$ROOTFS" "$STAGING" - sudo mount -o loop "$STAGING" "$MNT" - sudo cp {install_dir}/boxd "$MNT/usr/local/bin/boxd" - sudo chmod +x "$MNT/usr/local/bin/boxd" - sudo umount "$MNT" - rmdir "$MNT" - sudo mv "$STAGING" "$ROOTFS" - trap - EXIT - echo "boxd injected into rootfs atomically" else - echo "WARNING: BASE_ROOTFS_PATH not found or not readable; skipping rootfs inject" + echo "boxd source unchanged — skipping build and rootfs inject" fi + # Restart VMD only. Proxy is NOT restarted — it has its own pipeline. sudo systemctl restart {service} sleep 3 sudo systemctl is-active --quiet {service} || ( @@ -182,24 +190,6 @@ jobs: sudo journalctl -u {service} --no-pager -n 40 >&2 || true exit 1 ) - - # Deploy proxy binary and service - sudo mv /tmp/proxy-{sha} {install_dir}/proxy - sudo chmod +x {install_dir}/proxy - - # Install or update the systemd service file - sudo mv /tmp/proxy.service /etc/systemd/system/proxy.service - sudo systemctl daemon-reload - sudo systemctl enable proxy - - sudo systemctl restart proxy - sleep 3 - sudo systemctl is-active --quiet proxy || ( - echo "ERROR: proxy failed to become active after restart" >&2 - sudo systemctl status --no-pager proxy >&2 || true - sudo journalctl -u proxy --no-pager -n 40 >&2 || true - exit 1 - ) ''') r = subprocess.run([ 'gcloud', 'compute', 'ssh', name, @@ -208,11 +198,6 @@ jobs: '--command', inject_script, ], capture_output=True, text=True) if r.returncode != 0: - # Surface both stdout and stderr so deploy failures are - # actually debuggable from the GitHub Actions log. The - # inject script prints journalctl output to stderr on - # service failure; silently swallowing it was a source of - # long debug sessions. raise RuntimeError( f'service not healthy\n' f'--- stdout ---\n{r.stdout}\n' @@ -238,5 +223,5 @@ jobs: print(f'Deploy failed on: {", ".join(failed)}', file=sys.stderr) sys.exit(1) - print(f'Deployed to {len(instances)} instance(s). sha={sha}') + print(f'Deployed VMD to {len(instances)} instance(s). sha={sha}') PYEOF diff --git a/api/openapi.yaml b/api/openapi.yaml index a78e4f1..e713392 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -8,7 +8,7 @@ info: ## Sandbox lifecycle ``` - starting --> active <--> idle --> deleted + starting --> active <--> paused --> deleted ``` | Endpoint | What it does | @@ -189,11 +189,11 @@ paths: ## Currently patchable fields - `network` — replaces the egress allow/deny rules. The sandbox - must be in the `active` state; patching a paused or idle sandbox + must be in the `active` state; patching a paused sandbox returns `409`. Rules take effect immediately and are persisted so they survive a future pause/resume cycle. - `metadata` — replaces the sandbox's metadata tags. Can be updated - regardless of sandbox state (active, paused, idle). + regardless of sandbox state (active, paused). security: - apiKey: [] requestBody: @@ -226,13 +226,13 @@ paths: summary: Pause a running sandbox description: | Snapshots the sandbox's full state (memory + disk), suspends the VM, - and transitions to `idle`. Resume it later to continue exactly where + and transitions to `paused`. Resume it later to continue exactly where it left off. security: - apiKey: [] responses: "200": - description: Sandbox is now idle + description: Sandbox is now paused content: application/json: schema: @@ -275,6 +275,7 @@ paths: "500": $ref: "#/components/responses/InternalError" + /sandboxes/{sandbox_id}/exec: parameters: - $ref: "#/components/parameters/SandboxId" @@ -349,76 +350,6 @@ paths: "500": $ref: "#/components/responses/InternalError" - /sandboxes/{sandbox_id}/files/{path}: - parameters: - - $ref: "#/components/parameters/SandboxId" - - name: path - in: path - required: true - schema: - type: string - description: File path inside the sandbox (without leading slash). - - put: - operationId: uploadFile - tags: [Files] - summary: Upload a file into a sandbox - description: Idle sandboxes are automatically resumed before the upload. - security: - - apiKey: [] - requestBody: - required: true - content: - application/octet-stream: - schema: - type: string - format: binary - responses: - "200": - description: File uploaded - content: - application/json: - schema: - type: object - properties: - path: - type: string - size: - type: integer - format: int64 - "404": - $ref: "#/components/responses/NotFound" - "409": - $ref: "#/components/responses/Conflict" - "401": - $ref: "#/components/responses/Unauthorized" - "500": - $ref: "#/components/responses/InternalError" - - get: - operationId: downloadFile - tags: [Files] - summary: Download a file from a sandbox - description: Idle sandboxes are automatically resumed before the download. - security: - - apiKey: [] - responses: - "200": - description: File content - content: - application/octet-stream: - schema: - type: string - format: binary - "404": - $ref: "#/components/responses/NotFound" - "409": - $ref: "#/components/responses/Conflict" - "401": - $ref: "#/components/responses/Unauthorized" - "500": - $ref: "#/components/responses/InternalError" - components: securitySchemes: apiKey: @@ -460,7 +391,7 @@ components: description: > Optional hard lifetime cap in seconds, measured from sandbox creation. When set, the sandbox is destroyed this many seconds - after creation regardless of state (active, paused, idle) or + after creation regardless of state (active, paused) or activity — the user asked for a hard deadline. When unset, the sandbox lives until explicitly paused or deleted. Maximum 604800 (7 days). @@ -489,6 +420,18 @@ components: example: env: prod owner: agent-7 + env_vars: + type: object + additionalProperties: + type: string + description: | + Environment variables injected into every process inside the + sandbox (terminal sessions, exec calls). Not persisted in the + database — they live in the VM agent's memory for the sandbox's + lifetime and survive pause/resume via snapshot. + example: + OPENAI_API_KEY: sk-... + DEBUG: "1" run_id: 7f3c-21 network: $ref: "#/components/schemas/NetworkConfig" @@ -503,11 +446,19 @@ components: type: string status: type: string - enum: [starting, active, pausing, idle, deleted] + enum: [starting, active, pausing, paused, deleted] vcpu_count: type: integer memory_mib: type: integer + access_token: + type: string + description: | + Per-sandbox HMAC access token for data-plane operations + (file upload/download, terminal). Pass as `X-Access-Token` + header when hitting the edge proxy at + `boxd-{id}.sandbox.superserve.ai`. Reusable for the + lifetime of the sandbox. snapshot_id: type: string format: uuid diff --git a/cmd/boxd/main.go b/cmd/boxd/main.go index 53c08a8..6461deb 100644 --- a/cmd/boxd/main.go +++ b/cmd/boxd/main.go @@ -17,6 +17,7 @@ import ( "time" "connectrpc.com/connect" + "connectrpc.com/otelconnect" "github.com/creack/pty" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -26,10 +27,9 @@ import ( ) const ( - httpPort = 49983 - defaultShell = "/bin/bash" - defaultHome = "/home/user" - maxUploadBytes = 512 * 1024 * 1024 // 512 MB upload limit + httpPort = 49983 + defaultShell = "/bin/bash" + defaultHome = "/home/user" ) // dangerousPaths are paths that must never be modified via the filesystem API. @@ -57,16 +57,31 @@ func main() { mux := http.NewServeMux() + env := &sandboxEnv{} + + // otelconnect server interceptor extracts trace context from inbound + // headers so spans link back to the controlplane → vmd → boxd chain. + // boxd does not initialise the telemetry SDK (it runs inside the VM + // where the collector is unreachable); spans are no-op locally but the + // trace IDs propagate correctly to anything boxd calls out to. + otelInt, otelErr := otelconnect.NewInterceptor() + var handlerOpts []connect.HandlerOption + if otelErr == nil { + handlerOpts = append(handlerOpts, connect.WithInterceptors(otelInt)) + } + // Connect RPC services. procService := &processService{ processes: &sync.Map{}, + env: env, } - mux.Handle(boxdpbconnect.NewProcessServiceHandler(procService)) - mux.Handle(boxdpbconnect.NewFilesystemServiceHandler(&filesystemService{})) + mux.Handle(boxdpbconnect.NewProcessServiceHandler(procService, handlerOpts...)) + mux.Handle(boxdpbconnect.NewFilesystemServiceHandler(&filesystemService{}, handlerOpts...)) - // Raw HTTP endpoints (file content transfer + health). + // Raw HTTP endpoints (file content transfer + health + init). mux.HandleFunc("/files", handleFiles) mux.HandleFunc("/health", handleHealth) + mux.HandleFunc("/init", handleInit(env)) addr := fmt.Sprintf("0.0.0.0:%d", httpPort) log.Printf("boxd listening on %s (Connect RPC + HTTP)", addr) @@ -88,6 +103,34 @@ func handleHealth(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, `{"status":"ok"}`) } +// handleInit accepts sandbox-level environment variables from VMD after boot. +// POST /init with JSON body {"env_vars": {"KEY": "VALUE", ...}}. +func handleInit(env *sandboxEnv) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.Header().Set("Allow", "POST") + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var body struct { + EnvVars map[string]string `json:"env_vars"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, `{"error":"invalid JSON"}`, http.StatusBadRequest) + return + } + + if len(body.EnvVars) > 0 { + env.set(body.EnvVars) + log.Printf("init: set %d env var(s)", len(body.EnvVars)) + } + + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"status":"ok"}`) + } +} + // --------------------------------------------------------------------------- // Process service (Connect RPC) // --------------------------------------------------------------------------- @@ -97,9 +140,50 @@ type runningProcess struct { tty *os.File // nil for non-PTY processes. } +// sandboxEnv holds sandbox-level environment variables set via POST /init. +// These are injected into every process boxd spawns, underneath per-request +// overrides from StartRequest.envs. +type sandboxEnv struct { + mu sync.RWMutex + vars map[string]string +} + +func (e *sandboxEnv) set(vars map[string]string) { + e.mu.Lock() + defer e.mu.Unlock() + e.vars = vars +} + +func (e *sandboxEnv) environ() []string { + e.mu.RLock() + defer e.mu.RUnlock() + out := make([]string, 0, len(e.vars)) + for k, v := range e.vars { + out = append(out, k+"="+v) + } + return out +} + type processService struct { boxdpbconnect.UnimplementedProcessServiceHandler processes *sync.Map // pid → *runningProcess + env *sandboxEnv +} + +// buildEnv assembles the environment for a child process. Layers (last wins): +// 1. OS base env 2. system defaults (PATH, HOME, USER) 3. sandbox-level +// env vars from /init 4. per-request env vars from StartRequest.envs. +func (s *processService) buildEnv(requestEnvs map[string]string) []string { + env := append(os.Environ(), + "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin", + "HOME="+defaultHome, + "USER=user", + ) + env = append(env, s.env.environ()...) + for k, v := range requestEnvs { + env = append(env, k+"="+v) + } + return env } func (s *processService) Start(ctx context.Context, req *connect.Request[pb.StartRequest], stream *connect.ServerStream[pb.ProcessEvent]) error { @@ -117,14 +201,7 @@ func (s *processService) Start(ctx context.Context, req *connect.Request[pb.Star cmd.Dir = defaultHome } - cmd.Env = append(os.Environ(), - "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin", - "HOME="+defaultHome, - "USER=user", - ) - for k, v := range msg.GetEnvs() { - cmd.Env = append(cmd.Env, k+"="+v) - } + cmd.Env = s.buildEnv(msg.GetEnvs()) timeout := time.Duration(msg.GetTimeoutMs()) * time.Millisecond if timeout > 0 { @@ -136,14 +213,7 @@ func (s *processService) Start(ctx context.Context, req *connect.Request[pb.Star if cmd.Dir == "" { cmd.Dir = defaultHome } - cmd.Env = append(os.Environ(), - "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin", - "HOME="+defaultHome, - "USER=user", - ) - for k, v := range msg.GetEnvs() { - cmd.Env = append(cmd.Env, k+"="+v) - } + cmd.Env = s.buildEnv(msg.GetEnvs()) } isPTY := msg.GetPty() != nil @@ -529,8 +599,42 @@ func handleFileDownload(w http.ResponseWriter, r *http.Request, path string) { http.ServeContent(w, r, filepath.Base(path), info.ModTime(), f) } +// storageFullResponse is the canonical 507 body we return whenever a +// write fails because the sandbox has run out of disk space. It's a +// stable shape (code + message) so SDKs and the eventual web UI can +// branch on `error.code == "sandbox_storage_full"` rather than parsing +// free-form text. +const storageFullResponse = `{"error":{"code":"sandbox_storage_full","message":"Sandbox storage limit reached."}}` + +// writeStorageFull sends the canonical 507 + cleans up any partial +// file that may have been left behind by a failed write. Extracted +// because we handle ENOSPC at two distinct syscall boundaries (open +// and write) and both need the same response. +func writeStorageFull(w http.ResponseWriter, partialPath string) { + if partialPath != "" { + // Best-effort: reclaim the bytes that did land before the + // kernel returned ENOSPC. If the remove itself fails (disk + // problem, race with a concurrent process), we swallow it — + // leaving an empty/partial file on a full disk is strictly + // better than failing the request a second time. + _ = os.Remove(partialPath) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInsufficientStorage) // 507 + _, _ = w.Write([]byte(storageFullResponse)) +} + func handleFileUpload(w http.ResponseWriter, r *http.Request, path string) { if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + // mkdir itself can hit ENOSPC if the sandbox is already + // brimming — inodes exhausted, no room for a new directory + // entry. Surface it as the same storage-full error so users + // get one consistent code for "you're out of disk" regardless + // of which syscall tripped it. + if errors.Is(err, syscall.ENOSPC) { + writeStorageFull(w, "") + return + } errJSON, _ := json.Marshal(map[string]string{"error": "mkdir: " + err.Error()}) http.Error(w, string(errJSON), http.StatusInternalServerError) return @@ -538,24 +642,33 @@ func handleFileUpload(w http.ResponseWriter, r *http.Request, path string) { f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644) if err != nil { + if errors.Is(err, syscall.ENOSPC) { + writeStorageFull(w, "") + return + } errJSON, _ := json.Marshal(map[string]string{"error": "create file: " + err.Error()}) http.Error(w, string(errJSON), http.StatusInternalServerError) return } - defer f.Close() - written, err := io.Copy(f, io.LimitReader(r.Body, maxUploadBytes+1)) + written, err := io.Copy(f, r.Body) + f.Close() if err != nil { + // Remove the partial file — a truncated upload is never useful + // to the caller. This handles both ENOSPC (disk full) and + // client disconnect (network drop, cancel) so interrupted + // uploads don't leave orphaned files eating disk space. + _ = os.Remove(path) + + if errors.Is(err, syscall.ENOSPC) { + writeStorageFull(w, "") + return + } errJSON, _ := json.Marshal(map[string]string{"error": "write: " + err.Error()}) http.Error(w, string(errJSON), http.StatusInternalServerError) return } - if written > maxUploadBytes { - os.Remove(path) - http.Error(w, `{"error":"file too large","max_bytes":536870912}`, http.StatusRequestEntityTooLarge) - return - } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{"path": path, "size": written}) + _ = json.NewEncoder(w).Encode(map[string]interface{}{"path": path, "size": written}) } diff --git a/cmd/controlplane/main.go b/cmd/controlplane/main.go index c71514e..026a2df 100644 --- a/cmd/controlplane/main.go +++ b/cmd/controlplane/main.go @@ -8,26 +8,35 @@ import ( "net/http" "os" "os/signal" - "strings" "syscall" "time" + "github.com/exaring/otelpgx" "github.com/jackc/pgx/v5/pgxpool" "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "github.com/superserve-ai/sandbox/internal/api" "github.com/superserve-ai/sandbox/internal/config" dbq "github.com/superserve-ai/sandbox/internal/db" + "github.com/superserve-ai/sandbox/internal/hostreg" + "github.com/superserve-ai/sandbox/internal/scheduler" + "github.com/superserve-ai/sandbox/internal/telemetry" + "github.com/superserve-ai/sandbox/internal/vmdclient" "github.com/superserve-ai/sandbox/proto/vmdpb" ) +// version is set by ldflags at build time; falls back to "dev". +var version = "dev" + func main() { zerolog.TimeFieldFormat = zerolog.TimeFormatUnix log.Logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}). - With().Timestamp().Caller().Logger() + With().Timestamp().Caller().Logger(). + Hook(telemetry.ZerologTraceHook{}) if err := run(); err != nil { log.Fatal().Err(err).Msg("controlplane exited with error") @@ -46,8 +55,30 @@ func run() error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // Connect to PostgreSQL. - dbPool, err := pgxpool.New(ctx, cfg.DatabaseURL) + // Telemetry. No-op when OTEL_EXPORTER_OTLP_ENDPOINT is unset, so local + // dev and tests pay nothing. Shut down with a fresh context so we still + // flush after ctx is cancelled by signal handling below. + tel, err := telemetry.New(ctx, "controlplane", version, os.Getenv("NODE_ID")) + if err != nil { + return fmt.Errorf("init telemetry: %w", err) + } + defer func() { + if err := tel.Shutdown(context.Background()); err != nil { + log.Warn().Err(err).Msg("telemetry shutdown") + } + }() + if err := tel.StartRuntimeInstrumentation(); err != nil { + log.Warn().Err(err).Msg("runtime instrumentation") + } + + // Connect to PostgreSQL with otelpgx tracer so every sqlc query is a + // child span of the request that issued it. + pgxCfg, err := pgxpool.ParseConfig(cfg.DatabaseURL) + if err != nil { + return fmt.Errorf("parse database URL: %w", err) + } + pgxCfg.ConnConfig.Tracer = otelpgx.NewTracer() + dbPool, err := pgxpool.NewWithConfig(ctx, pgxCfg) if err != nil { return fmt.Errorf("connect to database: %w", err) } @@ -57,9 +88,12 @@ func run() error { } log.Info().Msg("connected to database") - // Connect to VMD via gRPC. + // Connect to VMD via gRPC. otelgrpc stats handler propagates trace + // context across the boundary so spans from controlplane continue + // inside vmd (once vmd is wired in commit 3). grpcConn, err := grpc.NewClient(cfg.VMDAddress, grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithStatsHandler(otelgrpc.NewClientHandler()), ) if err != nil { return fmt.Errorf("dial VMD gRPC: %w", err) @@ -70,7 +104,24 @@ func run() error { // Build handlers and router. vmdClient := newGRPCVMDClient(grpcConn) queries := dbq.New(dbPool) + handlers := api.NewHandlers(vmdClient, queries, cfg) + + // Host registry: resolves host_id → VMDClient via DB lookup + gRPC dial. + // Falls back to the default vmdClient when the registry has no entry. + dialVMD := func(addr string) (vmdclient.Client, error) { + conn, err := grpc.NewClient(addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithStatsHandler(otelgrpc.NewClientHandler()), + ) + if err != nil { + return nil, err + } + return newGRPCVMDClient(conn), nil + } + handlers.Hosts = hostreg.New(queries, dialVMD) + handlers.Scheduler = &scheduler.LeastLoaded{DB: queries, DefaultHostID: cfg.DefaultHostID} + router := api.SetupRouter(ctx, handlers, dbPool) // Launch the timeout reaper. This goroutine destroys sandboxes whose @@ -78,6 +129,11 @@ func run() error { // to ctx so it exits on shutdown. handlers.StartTimeoutReaper(ctx, api.DefaultReaperConfig()) + // Launch the host health detector. Marks active hosts as unhealthy + // when their VMD heartbeat goes stale (>2 min). The scheduler + // excludes unhealthy hosts from placement. + go api.StartHostDetector(ctx, queries) + // Start HTTP server. srv := &http.Server{ Addr: ":" + cfg.Port, @@ -137,10 +193,11 @@ func newGRPCVMDClient(conn *grpc.ClientConn) *grpcVMDClient { } } -func (c *grpcVMDClient) CreateInstance(ctx context.Context, vmID string, vcpu, memMiB, diskMiB uint32, metadata map[string]string) (string, error) { +func (c *grpcVMDClient) CreateInstance(ctx context.Context, vmID string, vcpu, memMiB, diskMiB uint32, metadata map[string]string, envVars map[string]string) (string, uint32, uint32, error) { resp, err := c.client.CreateVM(ctx, &vmdpb.CreateVMRequest{ VmId: vmID, Metadata: metadata, + EnvVars: envVars, ResourceLimits: &vmdpb.ResourceLimits{ VcpuCount: vcpu, MemoryMib: memMiB, @@ -148,9 +205,14 @@ func (c *grpcVMDClient) CreateInstance(ctx context.Context, vmID string, vcpu, m }, }) if err != nil { - return "", fmt.Errorf("gRPC CreateVM: %w", err) + return "", 0, 0, fmt.Errorf("gRPC CreateVM: %w", err) } - return resp.IpAddress, nil + var actualVcpu, actualMemMiB uint32 + if rl := resp.GetResourceLimits(); rl != nil { + actualVcpu = rl.GetVcpuCount() + actualMemMiB = rl.GetMemoryMib() + } + return resp.IpAddress, actualVcpu, actualMemMiB, nil } func (c *grpcVMDClient) DestroyInstance(ctx context.Context, vmID string, force bool) error { @@ -175,16 +237,55 @@ func (c *grpcVMDClient) PauseInstance(ctx context.Context, vmID, snapshotDir str return resp.SnapshotPath, resp.MemFilePath, nil } -func (c *grpcVMDClient) ResumeInstance(ctx context.Context, vmID, snapshotPath, memPath string) (string, error) { +func (c *grpcVMDClient) ResumeInstance(ctx context.Context, vmID, snapshotPath, memPath string, envVars map[string]string) (string, uint32, uint32, error) { resp, err := c.client.ResumeVM(ctx, &vmdpb.ResumeVMRequest{ VmId: vmID, SnapshotPath: snapshotPath, MemFilePath: memPath, + EnvVars: envVars, + }) + if err != nil { + return "", 0, 0, fmt.Errorf("gRPC ResumeVM: %w", err) + } + var actualVcpu, actualMemMiB uint32 + if rl := resp.GetResourceLimits(); rl != nil { + actualVcpu = rl.GetVcpuCount() + actualMemMiB = rl.GetMemoryMib() + } + return resp.IpAddress, actualVcpu, actualMemMiB, nil +} + +// RestoreSnapshot is the stateless restore path — VMD creates a fresh VM +// instance from the snapshot files, bypassing any in-memory state. Used as +// a fallback when ResumeInstance returns NotFound (e.g. after VMD lost its +// map to a crash but the snapshot files are still on disk). +func (c *grpcVMDClient) RestoreSnapshot(ctx context.Context, vmID, snapshotPath, memPath string) (string, uint32, uint32, error) { + resp, err := c.client.RestoreSnapshot(ctx, &vmdpb.RestoreSnapshotRequest{ + VmId: vmID, + SnapshotPath: snapshotPath, + MemFilePath: memPath, + }) + if err != nil { + return "", 0, 0, fmt.Errorf("gRPC RestoreSnapshot: %w", err) + } + // RestoreSnapshotResponse doesn't carry ResourceLimits in the proto today, + // so we return 0,0 and let the caller keep the existing DB values. + return resp.IpAddress, 0, 0, nil +} + +// DeleteSnapshot removes the on-disk snapshot artifacts for a previous pause. +// Idempotent — VMD treats missing files as success. Path traversal is blocked +// VMD-side, so the control plane cannot use this to delete unrelated files. +func (c *grpcVMDClient) DeleteSnapshot(ctx context.Context, vmID, snapshotPath, memPath string) error { + _, err := c.client.DeleteSnapshot(ctx, &vmdpb.DeleteSnapshotRequest{ + VmId: vmID, + SnapshotPath: snapshotPath, + MemFilePath: memPath, }) if err != nil { - return "", fmt.Errorf("gRPC ResumeVM: %w", err) + return fmt.Errorf("gRPC DeleteSnapshot: %w", err) } - return resp.IpAddress, nil + return nil } func (c *grpcVMDClient) ExecCommand(ctx context.Context, vmID, command string, args []string, env map[string]string, workingDir string, timeoutS uint32) (string, string, int32, error) { @@ -248,91 +349,6 @@ func (c *grpcVMDClient) ExecCommandStream(ctx context.Context, vmID, command str } } -func (c *grpcVMDClient) UploadFile(ctx context.Context, vmID, path string, content io.Reader) (int64, error) { - stream, err := c.client.UploadFile(ctx) - if err != nil { - return 0, fmt.Errorf("gRPC UploadFile: %w", err) - } - - buf := make([]byte, 64*1024) - first := true - for { - n, readErr := content.Read(buf) - if n > 0 || first { - msg := &vmdpb.UploadFileRequest{Data: buf[:n]} - if first { - msg.VmId = vmID - msg.Path = path - first = false - } - if err := stream.Send(msg); err != nil { - return 0, fmt.Errorf("gRPC UploadFile send: %w", err) - } - } - if readErr != nil { - if readErr != io.EOF { - return 0, fmt.Errorf("gRPC UploadFile read content: %w", readErr) - } - break - } - } - - resp, err := stream.CloseAndRecv() - if err != nil { - return 0, fmt.Errorf("gRPC UploadFile close: %w", err) - } - return resp.BytesWritten, nil -} - -func (c *grpcVMDClient) DownloadFile(ctx context.Context, vmID, path string) (io.ReadCloser, error) { - streamCtx, streamCancel := context.WithCancel(ctx) - - stream, err := c.client.DownloadFile(streamCtx, &vmdpb.DownloadFileRequest{ - VmId: vmID, - Path: path, - }) - if err != nil { - streamCancel() - return nil, fmt.Errorf("gRPC DownloadFile: %w", err) - } - - first, err := stream.Recv() - if err != nil { - streamCancel() - if err == io.EOF { - return io.NopCloser(strings.NewReader("")), nil - } - return nil, fmt.Errorf("gRPC DownloadFile: %w", err) - } - - pr, pw := io.Pipe() - go func() { - defer pw.Close() - defer streamCancel() - if len(first.Data) > 0 { - if _, err := pw.Write(first.Data); err != nil { - return - } - } - for { - resp, err := stream.Recv() - if err != nil { - if err != io.EOF { - pw.CloseWithError(fmt.Errorf("gRPC DownloadFile recv: %w", err)) - } - return - } - if len(resp.Data) > 0 { - if _, err := pw.Write(resp.Data); err != nil { - return - } - } - } - }() - - return pr, nil -} - func (c *grpcVMDClient) UpdateSandboxNetwork(ctx context.Context, vmID string, allowedCIDRs, deniedCIDRs, allowedDomains []string) error { _, err := c.client.UpdateSandboxNetwork(ctx, &vmdpb.UpdateSandboxNetworkRequest{ VmId: vmID, diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 22c17cb..1cf0e26 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -2,24 +2,35 @@ package main import ( "context" + "encoding/hex" "net/http" "os" "os/signal" + "strings" "syscall" + "time" "github.com/rs/zerolog" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "github.com/superserve-ai/sandbox/internal/auth" "github.com/superserve-ai/sandbox/internal/proxy" + "github.com/superserve-ai/sandbox/internal/telemetry" ) +// version is set by ldflags at build time; falls back to "dev". +var version = "dev" + func main() { zerolog.TimeFieldFormat = zerolog.TimeFormatUnix log := zerolog.New(os.Stdout).With(). Timestamp(). Str("service", "proxy"). - Logger() + Logger(). + Hook(telemetry.ZerologTraceHook{}) addr := envOrDefault("PROXY_ADDR", ":5007") + redirectAddr := envOrDefault("PROXY_REDIRECT_ADDR", ":5008") vmdAddr := envOrDefault("VMD_ADDR", "http://127.0.0.1:9090") domain := envOrDefault("PROXY_DOMAIN", "sandbox.superserve.ai") @@ -32,22 +43,110 @@ func main() { ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() + tel, err := telemetry.New(ctx, "proxy", version, os.Getenv("NODE_ID")) + if err != nil { + log.Fatal().Err(err).Msg("init telemetry") + } + defer func() { + if err := tel.Shutdown(context.Background()); err != nil { + log.Warn().Err(err).Msg("telemetry shutdown") + } + }() + if err := tel.StartRuntimeInstrumentation(); err != nil { + log.Warn().Err(err).Msg("runtime instrumentation") + } + resolver := proxy.NewVMDResolver(vmdAddr) proxyHandler := proxy.NewHandler(domain, resolver, log) proxyHandler.StartSweeper(ctx) - // Wrap with a health check endpoint for the GCP LB health probe. - // The LB hits /health directly on the instance IP (not a sandbox URL), - // so the proxy handler would reject it — intercept it first. + // Data-plane auth — the HMAC seed is shared with the control plane. + // Both sides derive per-sandbox access tokens as HMAC-SHA256(seed, sandboxID). + seedHex := os.Getenv("SANDBOX_ACCESS_TOKEN_SEED") + originsEnv := os.Getenv("TERMINAL_ALLOWED_ORIGINS") + required := os.Getenv("REQUIRE_DATA_PLANE") == "1" + + if seedHex == "" { + if required { + log.Fatal().Msg("REQUIRE_DATA_PLANE=1 but SANDBOX_ACCESS_TOKEN_SEED missing") + } + log.Warn().Msg("data-plane endpoints disabled (SANDBOX_ACCESS_TOKEN_SEED not configured)") + } else { + seed, err := hex.DecodeString(seedHex) + if err != nil { + log.Fatal().Err(err).Msg("SANDBOX_ACCESS_TOKEN_SEED is not valid hex") + } + if err := auth.ValidateSeed(seed); err != nil { + log.Fatal().Err(err).Msg("SANDBOX_ACCESS_TOKEN_SEED invalid") + } + + proxyHandler.WithAuth(seed) + proxyHandler.WithFiles() + log.Info().Msg("files endpoint enabled") + + if originsEnv != "" { + origins := splitCSV(originsEnv) + proxyHandler.WithTerminal(origins) + log.Info().Strs("allowed_origins", origins).Msg("terminal endpoint enabled") + } else { + log.Warn().Msg("terminal endpoint disabled (TERMINAL_ALLOWED_ORIGINS not configured)") + } + } + + // Health check for the GCP LB. Only responds on non-sandbox hosts + // so the boxd-label lockdown isn't bypassed. + domainSuffix := "." + domain mux := http.NewServeMux() mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + host := r.Host + if i := strings.IndexByte(host, ':'); i >= 0 { + host = host[:i] + } + if strings.HasSuffix(host, domainSuffix) { + proxyHandler.ServeHTTP(w, r) + return + } w.WriteHeader(http.StatusOK) }) mux.Handle("/", proxyHandler) - if err := proxy.ListenAndServe(ctx, addr, mux, log); err != nil { + // HTTP→HTTPS redirect listener with graceful shutdown. + redirectMux := http.NewServeMux() + redirectMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + host := r.Host + if i := strings.IndexByte(host, ':'); i >= 0 { + host = host[:i] + } + http.Redirect(w, r, "https://"+host+r.URL.RequestURI(), http.StatusMovedPermanently) + }) + redirectSrv := &http.Server{ + Addr: redirectAddr, + Handler: redirectMux, + } + go func() { + log.Info().Str("addr", redirectAddr).Msg("starting HTTP→HTTPS redirect listener") + if err := redirectSrv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Error().Err(err).Msg("redirect listener error") + } + }() + + // otelhttp wraps the entire mux so every forwarded request becomes a + // span. The /health filter avoids drowning the trace backend in + // noise from the LB liveness probe. + tracedMux := otelhttp.NewHandler(mux, "proxy", + otelhttp.WithFilter(func(r *http.Request) bool { + return r.URL.Path != "/health" + }), + ) + if err := proxy.ListenAndServe(ctx, addr, tracedMux, log); err != nil { log.Fatal().Err(err).Msg("proxy error") } + + // Shut down the redirect listener cleanly. + shutCtx, shutCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutCancel() + _ = redirectSrv.Shutdown(shutCtx) + log.Info().Msg("proxy stopped") } @@ -57,3 +156,13 @@ func envOrDefault(key, fallback string) string { } return fallback } + +func splitCSV(v string) []string { + var out []string + for _, s := range strings.Split(v, ",") { + if t := strings.TrimSpace(s); t != "" { + out = append(out, t) + } + } + return out +} diff --git a/cmd/vmd/main.go b/cmd/vmd/main.go index 1294f0a..2d7725b 100644 --- a/cmd/vmd/main.go +++ b/cmd/vmd/main.go @@ -7,20 +7,28 @@ import ( "os" "os/exec" "os/signal" + "path/filepath" "slices" "strconv" "sync" "syscall" "time" + "github.com/jackc/pgx/v5/pgxpool" "github.com/rs/zerolog" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "google.golang.org/grpc" + dbq "github.com/superserve-ai/sandbox/internal/db" "github.com/superserve-ai/sandbox/internal/network" + "github.com/superserve-ai/sandbox/internal/telemetry" "github.com/superserve-ai/sandbox/internal/vm" "github.com/superserve-ai/sandbox/proto/vmdpb" ) +// version is set by ldflags at build time; falls back to "dev". +var version = "dev" + // Config holds the daemon configuration sourced from environment variables. type Config struct { FirecrackerBin string @@ -31,6 +39,21 @@ type Config struct { RunDir string GRPCPort int HostInterface string + + // HostID identifies this bare-metal host in the `host` table. Used by + // the reconciler to scope its DB queries ("sandboxes on my host"). + HostID string + + // DatabaseURL is optional. When set, the reconciler does three-way + // reconciliation (BoltDB ↔ systemd ↔ control plane DB) and writes + // audit log entries. When unset, the reconciler only detects drift + // between BoltDB and systemd. + DatabaseURL string + + // ControlPlaneURL is the base URL of the control plane API. Used by + // the heartbeat goroutine to POST liveness. Optional — if unset, + // heartbeat is disabled. + ControlPlaneURL string } func loadConfig() (Config, error) { @@ -48,6 +71,9 @@ func loadConfig() (Config, error) { RunDir: envOrDefault("RUN_DIR", "/var/lib/sandbox/rundir"), GRPCPort: port, HostInterface: envOrDefault("HOST_INTERFACE", "eth0"), + HostID: envOrDefault("HOST_ID", "default"), + DatabaseURL: os.Getenv("DATABASE_URL"), + ControlPlaneURL: os.Getenv("CONTROL_PLANE_URL"), } if cfg.KernelPath == "" { @@ -181,7 +207,8 @@ func main() { log := zerolog.New(os.Stdout).With(). Timestamp(). Str("service", "vmd"). - Logger() + Logger(). + Hook(telemetry.ZerologTraceHook{}) cfg, err := loadConfig() if err != nil { @@ -212,6 +239,21 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // Telemetry. No-op when OTEL_EXPORTER_OTLP_ENDPOINT is unset. Use HostID + // as the node identifier so per-host metrics are distinguishable. + tel, telErr := telemetry.New(ctx, "vmd", version, cfg.HostID) + if telErr != nil { + log.Fatal().Err(telErr).Msg("init telemetry") + } + defer func() { + if err := tel.Shutdown(context.Background()); err != nil { + log.Warn().Err(err).Msg("telemetry shutdown") + } + }() + if err := tel.StartRuntimeInstrumentation(); err != nil { + log.Warn().Err(err).Msg("runtime instrumentation") + } + lc := newLifecycle(log) // ---- Network manager + host firewall ---- @@ -221,6 +263,12 @@ func main() { } lc.addCloser("network manager", func(_ context.Context) error { return netMgr.Close() }) + // ---- Pre-allocate network slots ---- + // Keeps 5 ready-to-use network namespaces so sandbox creation grabs + // one in microseconds instead of running ~11 shell commands (~10-30ms). + netPool := netMgr.StartPool(ctx, network.PoolConfig{}) + lc.addCloser("network pool", func(_ context.Context) error { netPool.Stop(); return nil }) + // ---- VM manager ---- mgr, err := vm.NewManager(vm.ManagerConfig{ FirecrackerBin: cfg.FirecrackerBin, @@ -233,18 +281,10 @@ func main() { if err != nil { log.Fatal().Err(err).Msg("failed to initialize VM manager") } - lc.addCloser("vm manager: active sandboxes", func(_ context.Context) error { - mgr.ShutdownAll() - return nil - }) - lc.addCloser("vm manager: template", func(_ context.Context) error { - mgr.CleanupTemplate() - return nil - }) // ---- TCP egress proxy ---- - // The nftables firewall in each sandbox namespace REDIRECTs TCP traffic - // to these ports for HTTP Host header / TLS SNI inspection. + // Must be set before ReattachAll or any VM operations so domain + // filtering is active from the start. const maxConnsPerSandbox = 256 egressProxy := network.NewEgressProxy( network.DefaultHTTPProxyPort, @@ -257,6 +297,74 @@ func main() { netMgr.SetEgressProxy(egressProxy) lc.start("egress proxy", func() error { return egressProxy.Start(ctx) }) + // ---- BoltDB state store ---- + statePath := envOrDefault("VMD_STATE_PATH", filepath.Join(filepath.Dir(cfg.RunDir), "vmd.db")) + stateStore, err := vm.OpenStateStore(statePath) + if err != nil { + log.Fatal().Err(err).Str("path", statePath).Msg("failed to open state store") + } + mgr.SetStateStore(stateStore) + lc.addCloser("state store", func(_ context.Context) error { return stateStore.Close() }) + + // ---- Reattach to running VMs from previous VMD lifetime ---- + reattached, stale := mgr.ReattachAll(ctx) + if reattached > 0 || stale > 0 { + log.Info().Int("reattached", reattached).Int("stale", stale).Msg("startup reattach complete") + } + + // ---- Optional DB connection for the reconciler ---- + // VMD does not need the DB for its request path (that stays on gRPC). + // The reconciler uses the DB for three-way drift detection and audit + // logging. If DATABASE_URL is unset, the reconciler falls back to a + // BoltDB ↔ systemd comparison only. + var reconcilerDB *dbq.Queries + if cfg.DatabaseURL != "" { + dbPool, dbErr := pgxpool.New(ctx, cfg.DatabaseURL) + if dbErr != nil { + log.Fatal().Err(dbErr).Msg("failed to connect to database for reconciler") + } + if err := dbPool.Ping(ctx); err != nil { + log.Fatal().Err(err).Msg("failed to ping database for reconciler") + } + reconcilerDB = dbq.New(dbPool) + lc.addCloser("reconciler db pool", func(_ context.Context) error { + dbPool.Close() + return nil + }) + log.Info().Msg("reconciler DB connection ready") + } else { + log.Warn().Msg("DATABASE_URL unset — reconciler will run in BoltDB↔systemd-only mode") + } + + // ---- Continuous reconciler ---- + reconcilerCfg := vm.DefaultReconcilerConfig() + reconcilerCfg.HostID = cfg.HostID + reconcilerCfg.DB = reconcilerDB + reconciler := vm.NewReconciler(mgr, reconcilerCfg) + lc.start("reconciler", func() error { reconciler.Run(ctx); return nil }) + + // ---- Heartbeat to control plane ---- + if cfg.ControlPlaneURL != "" { + lc.start("heartbeat", func() error { + vm.StartHeartbeat(ctx, vm.HeartbeatConfig{ + ControlPlaneURL: cfg.ControlPlaneURL, + HostID: cfg.HostID, + Token: os.Getenv("INTERNAL_API_TOKEN"), + }, log) + return nil + }) + } else { + log.Warn().Msg("CONTROL_PLANE_URL unset — heartbeat disabled") + } + + lc.addCloser("vm manager: active sandboxes", func(_ context.Context) error { + mgr.ShutdownAll() + return nil + }) + // Template files are NOT cleaned up on shutdown — they persist on + // disk so the next startup can reuse them via hash caching instead + // of cold-booting a new template (~3s saved per restart). + // ---- Default template ---- // Boot a throwaway VM from the base image, snapshot it, keep the // snapshot on disk. Every subsequent CreateVM restores from this @@ -283,6 +391,7 @@ func main() { } grpcServer := grpc.NewServer( grpc.MaxRecvMsgSize(64 << 20), // 64 MiB + grpc.StatsHandler(otelgrpc.NewServerHandler()), ) vmdpb.RegisterVMDaemonServer(grpcServer, vm.NewGRPCAdapter(mgr)) lc.start("grpc server", func() error { diff --git a/db/queries/hosts.sql b/db/queries/hosts.sql new file mode 100644 index 0000000..5407a88 --- /dev/null +++ b/db/queries/hosts.sql @@ -0,0 +1,62 @@ +-- name: GetHost :one +SELECT * FROM host WHERE id = $1; + +-- name: ListActiveHosts :many +SELECT * FROM host +WHERE status = 'active' +ORDER BY created_at ASC; + +-- name: ListHosts :many +SELECT * FROM host +ORDER BY created_at ASC; + +-- name: CreateHost :one +INSERT INTO host (id, vmd_addr, proxy_addr, region, capacity_memory_mib, capacity_vcpus) +VALUES ($1, $2, $3, $4, $5, $6) +RETURNING *; + +-- name: UpdateHostStatus :exec +UPDATE host +SET status = $2, updated_at = now() +WHERE id = $1; + +-- name: UpdateHostHeartbeat :one +-- Returns the host row so the caller can verify the host exists. Also +-- re-activates unhealthy hosts that resume heartbeating — this is the +-- automatic recovery path after a transient network outage. +UPDATE host +SET last_heartbeat_at = now(), + status = CASE WHEN status = 'unhealthy' THEN 'active' ELSE status END, + updated_at = now() +WHERE id = $1 +RETURNING *; + +-- name: MarkHostUnhealthy :exec +UPDATE host +SET status = 'unhealthy', updated_at = now() +WHERE id = $1 AND status = 'active'; + +-- name: ListStaleHosts :many +-- Returns active hosts whose last heartbeat is older than the given +-- threshold. Used by the unhealthy-host detector. +SELECT * FROM host +WHERE status = 'active' + AND last_heartbeat_at IS NOT NULL + AND last_heartbeat_at < $1 +ORDER BY last_heartbeat_at ASC; + +-- name: ListActiveHostsByLoad :many +-- Returns active hosts sorted by current sandbox count (ascending). +-- The scheduler picks the first row (least loaded host). One query +-- replaces N per-host lookups. +SELECT h.id, h.vmd_addr, h.proxy_addr, h.region, h.status, + h.capacity_memory_mib, h.capacity_vcpus, + h.last_heartbeat_at, h.created_at, h.updated_at, + COALESCE(COUNT(s.id), 0)::int AS active_sandbox_count +FROM host h +LEFT JOIN sandbox s ON s.host_id = h.id + AND s.status IN ('active', 'starting') + AND s.destroyed_at IS NULL +WHERE h.status = 'active' +GROUP BY h.id +ORDER BY COUNT(s.id) ASC; diff --git a/db/queries/reconciler_log.sql b/db/queries/reconciler_log.sql new file mode 100644 index 0000000..6ad9fec --- /dev/null +++ b/db/queries/reconciler_log.sql @@ -0,0 +1,9 @@ +-- name: InsertReconcilerLog :exec +INSERT INTO reconciler_log (host_id, sandbox_id, action, reason, drift_kind) +VALUES ($1, $2, $3, $4, $5); + +-- name: ListReconcilerLogByHost :many +SELECT * FROM reconciler_log +WHERE host_id = $1 +ORDER BY created_at DESC +LIMIT $2; diff --git a/db/queries/sandboxes.sql b/db/queries/sandboxes.sql index 3cc38fa..0d9649e 100644 --- a/db/queries/sandboxes.sql +++ b/db/queries/sandboxes.sql @@ -1,6 +1,10 @@ -- name: CreateSandbox :one -INSERT INTO sandbox (team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, timeout_seconds, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) +-- ID is supplied by the caller (generated in Go via uuid.New()) rather +-- than defaulted in SQL, so the caller can parallelize this INSERT with +-- the VMD CreateVM call — both need the same sandbox_id and generating +-- it client-side lets them run concurrently instead of strictly serially. +INSERT INTO sandbox (id, team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, timeout_seconds, metadata) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING *; -- name: GetSandbox :one @@ -33,10 +37,14 @@ UPDATE sandbox SET host_id = $2, ip_address = $3, pid = $4, updated_at = now() WHERE id = $1 AND team_id = $5 AND destroyed_at IS NULL; --- name: UpdateSandboxLastActivity :exec +-- name: ActivateSandbox :exec UPDATE sandbox -SET last_activity_at = now(), updated_at = now() -WHERE id = $1 AND team_id = $2 AND destroyed_at IS NULL; +SET status = 'active', + vcpu_count = $2, + memory_mib = $3, + ip_address = $4, + updated_at = now() +WHERE id = $1 AND team_id = $5 AND destroyed_at IS NULL; -- name: SetSandboxSnapshot :exec UPDATE sandbox @@ -51,12 +59,91 @@ WHERE id = $1 AND team_id = $2 AND destroyed_at IS NULL; -- name: SandboxExists :one SELECT EXISTS(SELECT 1 FROM sandbox WHERE id = $1 AND team_id = $2 AND destroyed_at IS NULL); --- name: ListIdleSandboxes :many +-- name: ListSandboxesByHost :many +-- Used by the VMD reconciler to find all non-deleted sandboxes scheduled on +-- this host. Includes both active and paused sandboxes because the reconciler +-- needs to validate both states (active → systemd unit, paused → snapshot file). SELECT * FROM sandbox -WHERE status = 'idle' - AND destroyed_at IS NULL - AND last_activity_at < $1 -ORDER BY last_activity_at ASC; +WHERE host_id = $1 AND destroyed_at IS NULL; + +-- name: MarkSandboxFailed :exec +-- Used by the reconciler to mark a sandbox failed when VMD detects it is +-- actually gone. No team_id filter — the reconciler runs with host scope, +-- not team scope. +UPDATE sandbox +SET status = 'failed', updated_at = now() +WHERE id = $1 AND destroyed_at IS NULL; + +-- name: BeginPause :one +-- Atomic ownership + state check + transition to 'pausing'. Replaces the +-- GetSandbox → check status → UpdateSandboxStatus sequence on the pause +-- hot path, collapsing two DB roundtrips into one. The WHERE clause +-- enforces the invariant (only active, non-deleted sandboxes owned by +-- this team can be paused); a 0-row result means "no such sandbox OR +-- wrong team OR not currently active", and the caller disambiguates via +-- a fallback GetSandbox in the rare error path. +UPDATE sandbox +SET status = 'pausing', updated_at = now() +WHERE id = $1 AND team_id = $2 AND destroyed_at IS NULL AND status = 'active' +RETURNING *; + +-- name: FinalizePause :one +-- Atomically insert the snapshot row, link it to the sandbox, and flip +-- status from 'pausing' to 'paused'. Replaces the sequence +-- CreateSnapshot → SetSandboxSnapshot → UpdateSandboxStatus, collapsing +-- three DB roundtrips into one. +-- +-- The INSERT is gated on a `WHERE EXISTS` against a non-deleted sandbox +-- in the same query. This prevents the common race where a sandbox is +-- soft-deleted before FinalizePause runs — without the gate, the CTE +-- INSERT would always execute (per PostgreSQL's rule that data-modifying +-- CTEs run independently of the main query), producing an orphan snapshot +-- row and a snapshot file on disk with no owner. A concurrent delete that +-- commits BETWEEN the EXISTS check and the INSERT under READ COMMITTED +-- can still race, but that window is microseconds and the resulting +-- orphan is detectable/cleanable by a background job. +-- +-- Also captures the sandbox's previous snapshot_id (before we overwrite it) +-- so the caller can garbage-collect the now-unreachable prior snapshot +-- asynchronously. Returns NULL for the first pause of a sandbox. +-- +-- When either the sandbox is missing/deleted or the INSERT did not fire, +-- the query returns 0 rows and the caller maps that to ErrSandboxGone. +WITH target AS ( + SELECT id, team_id, snapshot_id AS prev_snapshot_id FROM sandbox + WHERE id = $1 AND team_id = $2 AND destroyed_at IS NULL +), +new_snapshot AS ( + INSERT INTO snapshot (sandbox_id, team_id, path, mem_path, size_bytes, saved, name, trigger) + SELECT target.id, target.team_id, $3, $4, $5, $6, $7, $8 FROM target + RETURNING snapshot.id AS snap_id +) +UPDATE sandbox +SET snapshot_id = (SELECT snap_id FROM new_snapshot), + status = 'paused', + updated_at = now() +FROM new_snapshot +WHERE sandbox.id = $1 AND sandbox.team_id = $2 AND sandbox.destroyed_at IS NULL +RETURNING + new_snapshot.snap_id::uuid AS snapshot_id, + (SELECT prev_snapshot_id FROM target) AS prev_snapshot_id; + +-- name: GetSnapshotForCleanup :one +-- Fetch a snapshot's paths for garbage collection. Returns only non-saved +-- snapshots — saved=true rows are reserved for the (future) user-named +-- template feature and must never be auto-deleted. A 0-row result means +-- the row was already gone or is a saved snapshot; either way the caller +-- should skip deletion. +SELECT id, team_id, path, mem_path +FROM snapshot +WHERE id = $1 AND team_id = $2 AND saved = false; + +-- name: DeleteSnapshotRow :execrows +-- Remove a snapshot row. Guarded by saved=false so a future template feature +-- can rely on row durability — auto-GC callers cannot accidentally nuke a +-- user-saved snapshot even if they passed the wrong ID. +DELETE FROM snapshot +WHERE id = $1 AND team_id = $2 AND saved = false; -- name: UpdateSandboxNetworkConfig :exec UPDATE sandbox @@ -78,7 +165,7 @@ WHERE id = $1 AND team_id = $2 AND destroyed_at IS NULL; -- skip rows already being processed, so multi-replica Cloud Run deployments -- do not double-process the same sandbox. -- --- Only 'active' sandboxes are claimed — idle sandboxes are already stopped, +-- Only 'active' sandboxes are claimed — paused sandboxes are already stopped, -- and transient states (starting, pausing) are skipped to avoid racing with -- in-progress operations. The 60-second grace window prevents reaping a sandbox -- that was just created with a very short timeout before it finishes starting up. diff --git a/db/queries/snapshots.sql b/db/queries/snapshots.sql index b8232df..c6d4abf 100644 --- a/db/queries/snapshots.sql +++ b/db/queries/snapshots.sql @@ -1,9 +1,19 @@ -- name: CreateSnapshot :one -INSERT INTO snapshot (sandbox_id, team_id, path, size_bytes, saved, name, trigger) -VALUES ($1, $2, $3, $4, $5, $6, $7) +INSERT INTO snapshot (sandbox_id, team_id, path, mem_path, size_bytes, saved, name, trigger) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *; -- name: GetSnapshot :one +-- Team-scoped snapshot lookup for user-facing handlers. The join on +-- team_id enforces tenant isolation at the SQL layer so callers cannot +-- accidentally leak another team's snapshot metadata by forgetting the +-- in-Go team check. +SELECT * FROM snapshot +WHERE id = $1 AND team_id = $2; + +-- name: GetSnapshotByID :one +-- Unscoped snapshot lookup for internal (host-scoped) code paths such as +-- the VMD reconciler. DO NOT call from user-facing handlers. SELECT * FROM snapshot WHERE id = $1; diff --git a/deploy/dev/otel-collector-stdout.yaml b/deploy/dev/otel-collector-stdout.yaml new file mode 100644 index 0000000..5cb4a04 --- /dev/null +++ b/deploy/dev/otel-collector-stdout.yaml @@ -0,0 +1,33 @@ +# Local-dev OpenTelemetry Collector — prints everything to stdout. +# +# Use this when iterating on instrumentation code and you just want to see +# the SDK is emitting what you expect, without standing up Tempo/Loki/Mimir. +# +# Run: +# otelcol-contrib --config=deploy/dev/otel-collector-stdout.yaml +# +# Then in another terminal: +# export OTEL_EXPORTER_OTLP_ENDPOINT=localhost:4317 +# ./controlplane (or vmd / proxy) + +receivers: + otlp: + protocols: + grpc: + endpoint: 0.0.0.0:4317 + +exporters: + debug: + verbosity: detailed + +service: + pipelines: + traces: + receivers: [otlp] + exporters: [debug] + metrics: + receivers: [otlp] + exporters: [debug] + logs: + receivers: [otlp] + exporters: [debug] diff --git a/deploy/firecracker-netns@.service b/deploy/firecracker-netns@.service new file mode 100644 index 0000000..e0120fd --- /dev/null +++ b/deploy/firecracker-netns@.service @@ -0,0 +1,12 @@ +[Unit] +Description=Network namespace for Firecracker VM %i + +[Service] +Type=oneshot +RemainAfterExit=yes + +# Network setup/teardown scripts are generated by VMD and placed on the +# host. They create/destroy the network namespace, veth pair, TAP device, +# and nftables rules for the sandbox. +ExecStart=/usr/local/bin/fc-netns-setup %i +ExecStop=/usr/local/bin/fc-netns-teardown %i diff --git a/deploy/firecracker@.service b/deploy/firecracker@.service new file mode 100644 index 0000000..5cc8946 --- /dev/null +++ b/deploy/firecracker@.service @@ -0,0 +1,18 @@ +[Unit] +Description=Firecracker VM %i +After=network-online.target +StopWhenUnneeded=no + +[Service] +Type=simple +Slice=sandboxes.slice + +ExecStart=/var/lib/sandbox/rundir/%i/start.sh +ExecStopPost=/usr/local/bin/fc-cleanup %i + +Restart=no +KillMode=mixed +TimeoutStopSec=10 + +StandardOutput=journal +StandardError=journal diff --git a/deploy/otel-collector.yaml b/deploy/otel-collector.yaml new file mode 100644 index 0000000..f9a20d0 --- /dev/null +++ b/deploy/otel-collector.yaml @@ -0,0 +1,105 @@ +# Per-host OpenTelemetry Collector for superserve. +# +# Topology: +# apps (vmd, controlplane, proxy) → this collector :4317 → central pool +# +# This collector lives on the same machine as the apps so its only network +# hop to them is loopback. It batches and gzips before sending to the +# central pool, which keeps egress small and decouples app health from +# central-pool availability — a flap on the central side cannot back-pressure +# vmd. +# +# Set the central endpoint via env var (kept out of git so the same file +# works in staging and prod): +# export OTEL_CENTRAL_ENDPOINT=otel.staging.superserve-internal:4317 +# export OTEL_CENTRAL_AUTH=Bearer\ # if backend needs it + +receivers: + otlp: + protocols: + grpc: + endpoint: 0.0.0.0:4317 + http: + endpoint: 0.0.0.0:4318 + + # Host metrics — CPU, memory, disk, network, process count. + # 30s scrape; high-cardinality fields disabled to keep series count flat. + hostmetrics: + collection_interval: 30s + scrapers: + cpu: + memory: + disk: + # Real block devices only — skip loop, ram, and synthetic mounts. + include_devices: + match_type: regexp + devices: ["sd[a-z].*", "nvme.*", "vd[a-z].*"] + filesystem: + exclude_mount_points: + match_type: regexp + mount_points: ["/var/lib/docker/.*", "/proc/.*", "/sys/.*", "/run/.*"] + load: + network: + # Skip virtual interfaces — docker, veth, tap pollute series count. + exclude: + match_type: regexp + interfaces: ["docker.*", "veth.*", "br-.*", "tap.*", "lo"] + +processors: + batch: + timeout: 5s + send_batch_size: 512 + send_batch_max_size: 1024 + + # memory_limiter prevents the collector from OOMing the host if the + # central side is unavailable for a long time. + memory_limiter: + check_interval: 1s + limit_percentage: 75 + spike_limit_percentage: 25 + + # Add host_id to every signal for host-level pivoting in the backend. + resource: + attributes: + - key: host.id + from_attribute: host.name + action: insert + +exporters: + otlp: + endpoint: ${env:OTEL_CENTRAL_ENDPOINT} + compression: gzip + headers: + authorization: ${env:OTEL_CENTRAL_AUTH} + sending_queue: + enabled: true + queue_size: 5000 + retry_on_failure: + enabled: true + initial_interval: 5s + max_interval: 30s + max_elapsed_time: 5m + + # Useful for sanity-checking what's flowing during incidents: + # journalctl -u superserve-otel-collector | grep -A2 "Span\|Metric" + # Disabled by default — comment-in via the pipelines block below. + debug: + verbosity: basic + +service: + telemetry: + logs: + level: info + pipelines: + traces: + receivers: [otlp] + processors: [memory_limiter, batch, resource] + exporters: [otlp] + metrics: + receivers: [otlp, hostmetrics] + processors: [memory_limiter, batch, resource] + exporters: [otlp] + logs: + receivers: [otlp] + processors: [memory_limiter, batch, resource] + exporters: [otlp] diff --git a/deploy/proxy.service b/deploy/proxy.service index e9aecb5..d5b77ef 100644 --- a/deploy/proxy.service +++ b/deploy/proxy.service @@ -1,6 +1,6 @@ [Unit] Description=Superserve Edge Proxy -After=network-online.target vmd.service +After=network-online.target vmd.service superserve-otel-collector.service Wants=network-online.target [Service] @@ -10,9 +10,15 @@ Restart=always RestartSec=2 Environment=PROXY_ADDR=:5007 +Environment=PROXY_REDIRECT_ADDR=:5008 Environment=VMD_ADDR=http://127.0.0.1:9090 Environment=PROXY_DOMAIN=sandbox.superserve.ai +# Load host-local overrides (SANDBOX_ACCESS_TOKEN_SEED, TERMINAL_ALLOWED_ORIGINS, +# REQUIRE_DATA_PLANE) from an env file. Leading `-` makes the file optional — +# the proxy still starts without it, but data-plane endpoints will be disabled. +EnvironmentFile=-/etc/sandbox/proxy.env + # Run as a dedicated unprivileged user DynamicUser=yes NoNewPrivileges=true diff --git a/deploy/sandboxes.slice b/deploy/sandboxes.slice new file mode 100644 index 0000000..aaccd0f --- /dev/null +++ b/deploy/sandboxes.slice @@ -0,0 +1,13 @@ +[Unit] +Description=Sandbox VMs cgroup slice + +[Slice] +# Resource accounting for all sandbox Firecracker units. +CPUAccounting=yes +MemoryAccounting=yes +IOAccounting=yes + +# Safety ceiling: sandboxes cannot consume the entire host. +# Per-sandbox limits are set via drop-in files on each unit. +MemoryMax=95% +CPUQuota=90% diff --git a/deploy/superserve-otel-collector.service b/deploy/superserve-otel-collector.service new file mode 100644 index 0000000..0b954ea --- /dev/null +++ b/deploy/superserve-otel-collector.service @@ -0,0 +1,31 @@ +[Unit] +Description=Superserve OpenTelemetry Collector (per-host) +# Start before the apps so they have a collector listening when they boot. +# Apps are no-op when they can't reach the collector, but starting in the +# right order avoids harmless connect-refused warnings in journald. +Before=superserve-vmd.service proxy.service +After=network-online.target +Wants=network-online.target + +[Service] +Type=simple +ExecStart=/usr/local/bin/otelcol-contrib --config=/etc/sandbox/otel-collector.yaml +EnvironmentFile=/etc/sandbox/otel-collector.env + +# Restart on crash but back off so we don't loop hard on a config error. +Restart=always +RestartSec=5 +KillMode=process + +# Cap memory — collector should never starve apps. memory_limiter inside +# the config also enforces this in software for finer-grained shedding. +MemoryMax=512M + +LimitNOFILE=65536 + +StandardOutput=journal +StandardError=journal +SyslogIdentifier=superserve-otel-collector + +[Install] +WantedBy=multi-user.target diff --git a/deploy/superserve-vmd.service b/deploy/superserve-vmd.service new file mode 100644 index 0000000..28b002f --- /dev/null +++ b/deploy/superserve-vmd.service @@ -0,0 +1,33 @@ +[Unit] +Description=Superserve VM Daemon +After=network-online.target superserve-otel-collector.service +Wants=network-online.target +# Soft dependency — vmd runs fine if the collector is missing (telemetry +# becomes a no-op via OTEL_EXPORTER_OTLP_ENDPOINT being unset there). + +[Service] +Type=simple +ExecStart=/usr/local/bin/vmd + +# DO NOT add ExecStartPre commands that kill Firecracker processes or +# remove network namespaces. VMs are managed by systemd units +# (firecracker@.service) and must survive VMD restarts. + +# mixed: send SIGTERM to the main process only. Do NOT kill the cgroup +# (which would include firecracker@ units if they were children — they +# aren't in systemd mode, but this is defense-in-depth). +KillMode=process +Restart=always +RestartSec=2 + +# Environment +EnvironmentFile=/etc/sandbox/vmd.env + +LimitNOFILE=65536 + +StandardOutput=journal +StandardError=journal +SyslogIdentifier=vmd + +[Install] +WantedBy=multi-user.target diff --git a/go.mod b/go.mod index ae9bca5..f8a5be5 100644 --- a/go.mod +++ b/go.mod @@ -4,30 +4,55 @@ go 1.25.0 require ( connectrpc.com/connect v1.19.1 + connectrpc.com/otelconnect v0.9.0 + github.com/coder/websocket v1.8.14 github.com/creack/pty v1.1.24 + github.com/exaring/otelpgx v0.10.0 github.com/gin-gonic/gin v1.12.0 github.com/go-openapi/errors v0.22.7 github.com/go-openapi/runtime v0.29.3 github.com/go-openapi/strfmt v0.26.1 github.com/go-openapi/swag v0.25.5 github.com/go-openapi/validate v0.25.2 + github.com/google/nftables v0.3.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.9.1 github.com/rs/zerolog v1.34.0 + github.com/vishvananda/netns v0.0.5 + go.etcd.io/bbolt v1.4.3 + go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.68.0 + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 + go.opentelemetry.io/contrib/instrumentation/runtime v0.68.0 + go.opentelemetry.io/otel v1.43.0 + go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0 + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 + go.opentelemetry.io/otel/log v0.19.0 + go.opentelemetry.io/otel/sdk v1.43.0 + go.opentelemetry.io/otel/sdk/log v0.19.0 + go.opentelemetry.io/otel/sdk/metric v1.43.0 + go.opentelemetry.io/otel/trace v1.43.0 golang.org/x/net v0.52.0 - google.golang.org/grpc v1.79.3 + golang.org/x/sync v0.20.0 + golang.org/x/sys v0.42.0 + golang.org/x/time v0.15.0 + google.golang.org/grpc v1.80.0 google.golang.org/protobuf v1.36.11 ) require ( - github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/gopkg v0.1.4 // indirect github.com/bytedance/sonic v1.15.0 // indirect - github.com/bytedance/sonic/loader v0.5.0 // indirect + github.com/bytedance/sonic/loader v0.5.1 // indirect + github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/gabriel-vasile/mimetype v1.4.12 // indirect - github.com/gin-contrib/sse v1.1.0 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/gabriel-vasile/mimetype v1.4.13 // indirect + github.com/gin-contrib/sse v1.1.1 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/analysis v0.24.3 // indirect @@ -48,12 +73,12 @@ require ( github.com/go-openapi/swag/yamlutils v0.25.5 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.30.1 // indirect + github.com/go-playground/validator/v10 v10.30.2 // indirect github.com/go-viper/mapstructure/v2 v2.5.0 // indirect - github.com/goccy/go-json v0.10.5 // indirect + github.com/goccy/go-json v0.10.6 // indirect github.com/goccy/go-yaml v1.19.2 // indirect github.com/google/go-cmp v0.7.0 // indirect - github.com/google/nftables v0.3.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect @@ -67,24 +92,20 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/oklog/ulid/v2 v2.1.1 // indirect - github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pelletier/go-toml/v2 v2.3.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/quic-go v0.59.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.1 // indirect - github.com/vishvananda/netns v0.0.5 // indirect go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/otel v1.41.0 // indirect - go.opentelemetry.io/otel/metric v1.41.0 // indirect - go.opentelemetry.io/otel/trace v1.41.0 // indirect + go.opentelemetry.io/otel/metric v1.43.0 // indirect + go.opentelemetry.io/proto/otlp v1.10.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/arch v0.22.0 // indirect + golang.org/x/arch v0.25.0 // indirect golang.org/x/crypto v0.49.0 // indirect - golang.org/x/sync v0.20.0 // indirect - golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect - golang.org/x/time v0.15.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d // indirect ) diff --git a/go.sum b/go.sum index 46173d0..c6e0463 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,21 @@ connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14= connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= -github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= -github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +connectrpc.com/otelconnect v0.9.0 h1:NggB3pzRC3pukQWaYbRHJulxuXvmCKCKkQ9hbrHAWoA= +connectrpc.com/otelconnect v0.9.0/go.mod h1:AEkVLjCPXra+ObGFCOClcJkNjS7zPaQSqvO0lCyjfZc= +github.com/bytedance/gopkg v0.1.4 h1:oZnQwnX82KAIWb7033bEwtxvTqXcYMxDBaQxo5JJHWM= +github.com/bytedance/gopkg v0.1.4/go.mod h1:v1zWfPm21Fb+OsyXN2VAHdL6TBb2L88anLQgdyje6R4= github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= -github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= -github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/bytedance/sonic/loader v0.5.1 h1:Ygpfa9zwRCCKSlrp5bBP/b/Xzc3VxsAW+5NIYXrOOpI= +github.com/bytedance/sonic/loader v0.5.1/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= @@ -17,10 +23,14 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw= -github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= -github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= -github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= +github.com/exaring/otelpgx v0.10.0 h1:NGGegdoBQM3jNZDKG8ENhigUcgBN7d7943L0YlcIpZc= +github.com/exaring/otelpgx v0.10.0/go.mod h1:R5/M5LWsPPBZc1SrRE5e0DiU48bI78C1/GPTWs6I66U= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM= +github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= +github.com/gin-contrib/sse v1.1.1 h1:uGYpNwTacv5R68bSGMapo62iLTRa9l5zxGCps4hK6ko= +github.com/gin-contrib/sse v1.1.1/go.mod h1:QXzuVkA0YO7o/gun03UI1Q+FTI8ZV/n5t03kIQAI89s= github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8= github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -82,12 +92,12 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w= -github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM= +github.com/go-playground/validator/v10 v10.30.2 h1:JiFIMtSSHb2/XBUbWM4i/MpeQm9ZK2xqPNk8vgvu5JQ= +github.com/go-playground/validator/v10 v10.30.2/go.mod h1:mAf2pIOVXjTEBrwUMGKkCWKKPs9NheYGabeB04txQSc= github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= -github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= -github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/goccy/go-json v0.10.6 h1:p8HrPJzOakx/mn/bQtjgNjdTcN+/S6FcG2CTtQOrHVU= +github.com/goccy/go-json v0.10.6/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -100,6 +110,8 @@ github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -136,8 +148,8 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY github.com/oklog/ulid/v2 v2.1.1 h1:suPZ4ARWLOJLegGFiZZ1dFAkqzhMjL3J1TzI+5wHz8s= github.com/oklog/ulid/v2 v2.1.1/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= -github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= -github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM= +github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= @@ -169,26 +181,58 @@ github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo= +go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E= go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE= go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/otel v1.41.0 h1:YlEwVsGAlCvczDILpUXpIpPSL/VPugt7zHThEMLce1c= -go.opentelemetry.io/otel v1.41.0/go.mod h1:Yt4UwgEKeT05QbLwbyHXEwhnjxNO6D8L5PQP51/46dE= -go.opentelemetry.io/otel/metric v1.41.0 h1:rFnDcs4gRzBcsO9tS8LCpgR0dxg4aaxWlJxCno7JlTQ= -go.opentelemetry.io/otel/metric v1.41.0/go.mod h1:xPvCwd9pU0VN8tPZYzDZV/BMj9CM9vs00GuBjeKhJps= -go.opentelemetry.io/otel/sdk v1.41.0 h1:YPIEXKmiAwkGl3Gu1huk1aYWwtpRLeskpV+wPisxBp8= -go.opentelemetry.io/otel/sdk v1.41.0/go.mod h1:ahFdU0G5y8IxglBf0QBJXgSe7agzjE4GiTJ6HT9ud90= -go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= -go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= -go.opentelemetry.io/otel/trace v1.41.0 h1:Vbk2co6bhj8L59ZJ6/xFTskY+tGAbOnCtQGVVa9TIN0= -go.opentelemetry.io/otel/trace v1.41.0/go.mod h1:U1NU4ULCoxeDKc09yCWdWe+3QoyweJcISEVa1RBzOis= +go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.68.0 h1:5FXSL2s6afUC1bzNzl1iedZZ8yqR7GOhbCoEXtyeK6Q= +go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.68.0/go.mod h1:MdHW7tLtkeGJnR4TyOrnd5D0zUGZQB1l84uHCe8hRpE= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 h1:0Qx7VGBacMm9ZENQ7TnNObTYI4ShC+lHI16seduaxZo= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0/go.mod h1:Sje3i3MjSPKTSPvVWCaL8ugBzJwik3u4smCjUeuupqg= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo= +go.opentelemetry.io/contrib/instrumentation/runtime v0.68.0 h1:jhVIQEprwUTV+KfzzliLidclhoTOoHTgdz96kAyR8mU= +go.opentelemetry.io/contrib/instrumentation/runtime v0.68.0/go.mod h1:4HsdbLUbernaTnA8CNaNE+1g026SciXb3juRYe3l8EY= +go.opentelemetry.io/contrib/propagators/b3 v1.43.0 h1:CETqV3QLLPTy5yNrqyMr41VnAOOD4lsRved7n4QG00A= +go.opentelemetry.io/contrib/propagators/b3 v1.43.0/go.mod h1:Q4mCiCdziYzpNR0g+6UqVotAlCDZdzz6L8jwY4knOrw= +go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= +go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0 h1:Dn8rkudDzY6KV9dr/D/bTUuWgqDf9xe0rr4G2elrn0Y= +go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0/go.mod h1:gMk9F0xDgyN9M/3Ed5Y1wKcx/9mlU91NXY2SNq7RQuU= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 h1:8UQVDcZxOJLtX6gxtDt3vY2WTgvZqMQRzjsqiIHQdkc= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0/go.mod h1:2lmweYCiHYpEjQ/lSJBYhj9jP1zvCvQW4BqL9dnT7FQ= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 h1:88Y4s2C8oTui1LGM6bTWkw0ICGcOLCAI5l6zsD1j20k= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 h1:RAE+JPfvEmvy+0LzyUA25/SGawPwIUbZ6u0Wug54sLc= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0/go.mod h1:AGmbycVGEsRx9mXMZ75CsOyhSP6MFIcj/6dnG+vhVjk= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0 h1:mS47AX77OtFfKG4vtp+84kuGSFZHTyxtXIN269vChY0= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0/go.mod h1:PJnsC41lAGncJlPUniSwM81gc80GkgWJWr3cu2nKEtU= +go.opentelemetry.io/otel/log v0.19.0 h1:KUZs/GOsw79TBBMfDWsXS+KZ4g2Ckzksd1ymzsIEbo4= +go.opentelemetry.io/otel/log v0.19.0/go.mod h1:5DQYeGmxVIr4n0/BcJvF4upsraHjg6vudJJpnkL6Ipk= +go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= +go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= +go.opentelemetry.io/otel/sdk/log v0.19.0 h1:scYVLqT22D2gqXItnWiocLUKGH9yvkkeql5dBDiXyko= +go.opentelemetry.io/otel/sdk/log v0.19.0/go.mod h1:vFBowwXGLlW9AvpuF7bMgnNI95LiW10szrOdvzBHlAg= +go.opentelemetry.io/otel/sdk/log/logtest v0.19.0 h1:BEbF7ZBB6qQloV/Ub1+3NQoOUnVtcGkU3XX4Ws3GQfk= +go.opentelemetry.io/otel/sdk/log/logtest v0.19.0/go.mod h1:Lua81/3yM0wOmoHTokLj9y9ADeA02v1naRrVrkAZuKk= +go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= +go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= +go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g= +go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI= -golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= +golang.org/x/arch v0.25.0 h1:qnk6Ksugpi5Bz32947rkUgDt9/s5qvqDPl/gBKdMJLE= +golang.org/x/arch v0.25.0/go.mod h1:0X+GdSIP+kL5wPmpK7sdkEVTt2XoYP0cSjQSbZBwOi8= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= @@ -204,12 +248,14 @@ golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= -gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= -google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= -google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d h1:wT2n40TBqFY6wiwazVK9/iTWbsQrgk5ZfCSVFLO9LQA= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/api/handlers.go b/internal/api/handlers.go index bc6c2c0..0efd275 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -4,11 +4,11 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" "net/netip" "path/filepath" "strings" + "sync" "time" "github.com/gin-gonic/gin" @@ -17,29 +17,53 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/rs/zerolog/log" + "github.com/superserve-ai/sandbox/internal/auth" "github.com/superserve-ai/sandbox/internal/config" "github.com/superserve-ai/sandbox/internal/db" + "github.com/superserve-ai/sandbox/internal/telemetry" + "github.com/superserve-ai/sandbox/internal/vmdclient" ) -// VMDClient defines the subset of the VM daemon gRPC interface used by the -// control plane. This is satisfied by the gRPC adapter in cmd/controlplane. -type VMDClient interface { - CreateInstance(ctx context.Context, instanceID string, vcpu, memMiB, diskMiB uint32, metadata map[string]string) (ipAddress string, err error) - DestroyInstance(ctx context.Context, instanceID string, force bool) error - PauseInstance(ctx context.Context, instanceID, snapshotDir string) (snapshotPath, memPath string, err error) - ResumeInstance(ctx context.Context, instanceID, snapshotPath, memPath string) (ipAddress string, err error) - ExecCommand(ctx context.Context, instanceID, command string, args []string, env map[string]string, workingDir string, timeoutS uint32) (stdout, stderr string, exitCode int32, err error) - ExecCommandStream(ctx context.Context, instanceID, command string, args []string, env map[string]string, workingDir string, timeoutS uint32, onChunk func(stdout, stderr []byte, exitCode int32, finished bool)) error - UploadFile(ctx context.Context, instanceID, path string, content io.Reader) (int64, error) - DownloadFile(ctx context.Context, instanceID, path string) (io.ReadCloser, error) - UpdateSandboxNetwork(ctx context.Context, instanceID string, allowedCIDRs, deniedCIDRs, allowedDomains []string) error +// lifecycleTimer captures the start time and emits a single +// sandbox.lifecycle.duration record on completion. Outcome is derived from +// the gin response status: 2xx = ok, 408/504 = timeout, anything else = error. +// The `from` label is only meaningful for create; pass "" elsewhere. +func lifecycleTimer(c *gin.Context, op telemetry.SandboxLifecycleOp, from string) func() { + start := time.Now() + return func() { + outcome := telemetry.OutcomeOK + switch s := c.Writer.Status(); { + case s >= 200 && s < 400: + outcome = telemetry.OutcomeOK + case s == http.StatusRequestTimeout || s == http.StatusGatewayTimeout: + outcome = telemetry.OutcomeTimeout + default: + outcome = telemetry.OutcomeError + } + telemetry.RecordSandboxLifecycle(c.Request.Context(), op, outcome, from, time.Since(start).Seconds()) + } +} + +// VMDClient is the interface for talking to a VM daemon. +type VMDClient = vmdclient.Client + +// Scheduler selects a host for new sandboxes. +type Scheduler interface { + SelectHost(ctx context.Context) (hostID string, err error) +} + +// HostRegistry resolves a host ID to a VMD client. +type HostRegistry interface { + ClientFor(ctx context.Context, hostID string) (vmdclient.Client, error) } // Handlers holds shared dependencies for all route handlers. type Handlers struct { - VMD VMDClient - DB *db.Queries - Config *config.Config + VMD VMDClient // default VMD client (used when Hosts is nil or host lookup fails on legacy sandboxes) + DB *db.Queries + Config *config.Config + Hosts HostRegistry // when set, routes VMD calls via host_id + Scheduler Scheduler // when set, picks host on create } // NewHandlers creates a new Handlers instance. @@ -51,6 +75,26 @@ func NewHandlers(vmd VMDClient, queries *db.Queries, cfg *config.Config) *Handle } } +// vmdForHost returns the VMDClient for the given host. When a registry is +// configured, it resolves via DB lookup. If the lookup fails (e.g. legacy +// sandbox with a backfilled host_id that has no host row), falls back to +// the default VMD client so existing sandboxes keep working during the +// migration period. +func (h *Handlers) vmdForHost(ctx context.Context, hostID string) (VMDClient, error) { + if h.Hosts == nil { + return h.VMD, nil + } + c, err := h.Hosts.ClientFor(ctx, hostID) + if err != nil { + if h.VMD != nil { + log.Warn().Err(err).Str("host_id", hostID).Msg("host registry lookup failed, falling back to default VMD client") + return h.VMD, nil + } + return nil, err + } + return c, nil +} + // vmdTimeout is the default deadline for VMD gRPC calls. const vmdTimeout = 30 * time.Second @@ -84,105 +128,135 @@ func (h *Handlers) logActivityAsync(reqCtx context.Context, sandboxID, teamID uu }() } -// updateLastActivityAsync bumps last_activity_at in a background goroutine. -// Same detached-but-traced context pattern as logActivityAsync. -func (h *Handlers) updateLastActivityAsync(reqCtx context.Context, sandboxID, teamID uuid.UUID) { +// cleanupOldSnapshotAsync garbage-collects a snapshot whose sandbox reference +// was just overwritten by a newer pause. Runs detached from the request +// context so it does not add latency to the pause response, but preserves +// the trace/span so the work is visible in request traces. +// +// Order of operations: +// 1. Look up the snapshot's paths in the DB (also filters out saved=true +// snapshots — the DB query returns 0 rows for those). +// 2. Call VMD to unlink the vmstate + memory files. Idempotent on VMD side. +// 3. Delete the DB row. +// +// On any failure we log and exit: files may remain on disk (detectable by a +// future janitor) or the row may remain in DB (inert — no sandbox references +// it). Both outcomes are bounded and safe; the next pause will produce the +// same cleanup attempt for the newly-orphaned snapshot. +// +// prevSnapshotID is typically nil/invalid on a sandbox's first pause — the +// helper short-circuits in that case. +func (h *Handlers) cleanupOldSnapshotAsync(reqCtx context.Context, sandboxID, teamID uuid.UUID, hostID string, prevSnapshotID pgtype.UUID) { + if !prevSnapshotID.Valid { + return + } + oldSnapshotID := uuid.UUID(prevSnapshotID.Bytes) + asyncCtx := context.WithoutCancel(reqCtx) go func() { - ctx, cancel := context.WithTimeout(asyncCtx, asyncTimeout) - defer cancel() - if err := h.DB.UpdateSandboxLastActivity(ctx, db.UpdateSandboxLastActivityParams{ - ID: sandboxID, + l := log.With(). + Str("sandbox_id", sandboxID.String()). + Str("snapshot_id", oldSnapshotID.String()). + Logger() + + lookupCtx, lookupCancel := context.WithTimeout(asyncCtx, asyncTimeout) + snap, err := h.DB.GetSnapshotForCleanup(lookupCtx, db.GetSnapshotForCleanupParams{ + ID: oldSnapshotID, TeamID: teamID, - }); err != nil { - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("async last_activity_at update failed") + }) + lookupCancel() + if err != nil { + // ErrNoRows: either the row is already gone, or it's a + // saved=true row we must not auto-delete. Either way, nothing + // to do. Any other error is transient — log and stop; a future + // janitor can retry. + if err != pgx.ErrNoRows { + l.Error().Err(err).Msg("cleanup: lookup old snapshot failed") + } + return } - }() -} -// AutoWake returns middleware that loads a sandbox, verifies team ownership, and -// transparently resumes idle sandboxes. On success it stores *db.Sandbox under -// the "sandbox" context key for downstream handlers. -func (h *Handlers) AutoWake() gin.HandlerFunc { - return func(c *gin.Context) { - sandboxID, err := parseSandboxID(c) - if err != nil { - c.Abort() + vmd, vmdErr := h.vmdForHost(asyncCtx, hostID) + if vmdErr != nil { + l.Error().Err(vmdErr).Str("host_id", hostID).Msg("cleanup: resolve VMD failed") return } - teamID, err := teamIDFromContext(c) + memPath := "" + if snap.MemPath != nil { + memPath = *snap.MemPath + } + + vmdCtx, vmdCancel := context.WithTimeout(asyncCtx, vmdTimeout) + err = vmd.DeleteSnapshot(vmdCtx, sandboxID.String(), snap.Path, memPath) + vmdCancel() if err != nil { - c.Abort() + // Files may linger on disk; row stays so a janitor (or the next + // pause-cleanup after another pause/resume cycle) can retry. + l.Error().Err(err).Msg("cleanup: VMD DeleteSnapshot failed, leaving row in place") return } - sandbox, err := h.DB.GetSandbox(c.Request.Context(), db.GetSandboxParams{ - ID: sandboxID, + delCtx, delCancel := context.WithTimeout(asyncCtx, asyncTimeout) + _, err = h.DB.DeleteSnapshotRow(delCtx, db.DeleteSnapshotRowParams{ + ID: oldSnapshotID, TeamID: teamID, }) + delCancel() if err != nil { - if err == pgx.ErrNoRows { - respondError(c, ErrSandboxNotFound) - } else { - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("DB GetSandbox failed") - respondError(c, ErrInternal) - } - c.Abort() + // Files are gone; row remains. Inert (no sandbox references it) + // and cleanable by a janitor later. + l.Error().Err(err).Msg("cleanup: DeleteSnapshotRow failed, files already removed") return } + l.Debug().Msg("cleanup: previous snapshot garbage-collected") + }() +} - switch sandbox.Status { - case db.SandboxStatusActive: - // Ready. - case db.SandboxStatusIdle: - vmdCtx, vmdCancel := context.WithTimeout(c.Request.Context(), vmdTimeout) - defer vmdCancel() - if _, err := h.VMD.ResumeInstance(vmdCtx, sandboxID.String(), "", ""); err != nil { - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("auto-wake ResumeInstance failed") - respondError(c, ErrInternal) - c.Abort() - return - } - // VM is running — detach from cancellation so a client - // disconnect cannot leave the row stuck in "idle", but - // keep the trace/span context so post-VMD writes show - // up in the same request trace. - postCtx, postCancel := context.WithTimeout(context.WithoutCancel(c.Request.Context()), vmdTimeout) - defer postCancel() - if err := h.DB.UpdateSandboxStatus(postCtx, db.UpdateSandboxStatusParams{ - ID: sandboxID, - Status: db.SandboxStatusActive, - TeamID: teamID, - }); err != nil { - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("auto-wake UpdateSandboxStatus failed") - respondError(c, ErrInternal) - c.Abort() - return - } - // Reapply persisted egress rules — nftables + proxy state are fresh after restore. - if err := h.reapplyNetworkConfig(postCtx, sandboxID.String(), sandbox.NetworkConfig); err != nil { - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("auto-wake reapply network config failed") - respondError(c, ErrInternal) - c.Abort() - return - } - default: - respondError(c, ErrInvalidState) - c.Abort() - return +// loadActiveSandbox fetches a sandbox by ID, verifies team ownership, and +// requires that it is in the `active` state. Non-active sandboxes (paused, +// pausing, failed, etc.) are rejected — callers must resume the sandbox +// explicitly via POST /sandboxes/:id/resume before operating on it. +// +// On any error path this writes the response and returns nil; the caller +// should simply return. On success it returns the loaded sandbox. +func (h *Handlers) loadActiveSandbox(c *gin.Context) *db.Sandbox { + sandboxID, err := parseSandboxID(c) + if err != nil { + return nil + } + teamID, err := teamIDFromContext(c) + if err != nil { + return nil + } + sandbox, err := h.DB.GetSandbox(c.Request.Context(), db.GetSandboxParams{ + ID: sandboxID, + TeamID: teamID, + }) + if err != nil { + if err == pgx.ErrNoRows { + respondError(c, ErrSandboxNotFound) + } else { + log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("DB GetSandbox failed") + respondError(c, ErrInternal) } - - c.Set("sandbox", &sandbox) - c.Next() + return nil + } + if sandbox.Status != db.SandboxStatusActive { + respondError(c, ErrInvalidState) + return nil } + return &sandbox } -// sandboxFromContext retrieves the *db.Sandbox stored by the AutoWake middleware. -func sandboxFromContext(c *gin.Context) *db.Sandbox { - val, _ := c.Get("sandbox") - sb, _ := val.(*db.Sandbox) - return sb +// resolveMemPath returns the memory snapshot path from a Snapshot record. +// Uses the stored mem_path column if set, otherwise falls back to the +// convention of placing mem.snap alongside the vmstate snapshot. +func resolveMemPath(snap db.Snapshot) string { + if snap.MemPath != nil && *snap.MemPath != "" { + return *snap.MemPath + } + return filepath.Join(filepath.Dir(snap.Path), "mem.snap") } // persistedEgressConfig mirrors the jsonb shape stored in sandbox.network_config. @@ -196,12 +270,12 @@ type persistedEgressConfig struct { // reapplyNetworkConfig reads the sandbox's persisted egress config from the DB // record and pushes it back to VMD. Called after every resume path (explicit -// /resume, AutoWake, post-restore in CreateSandbox) because the nftables rules -// and proxy state are fresh after a snapshot restore. +// /resume, post-restore in CreateSandbox) because the nftables rules and +// proxy state are fresh after a snapshot restore. // // Uses a caller-supplied context so the caller controls timeout/cancellation. // Silently returns nil if there is no persisted config (default allow-all). -func (h *Handlers) reapplyNetworkConfig(ctx context.Context, sandboxID string, raw []byte) error { +func (h *Handlers) reapplyNetworkConfig(ctx context.Context, vmd VMDClient, sandboxID string, raw []byte) error { if len(raw) == 0 { return nil } @@ -217,7 +291,7 @@ func (h *Handlers) reapplyNetworkConfig(ctx context.Context, sandboxID string, r return nil } - return h.VMD.UpdateSandboxNetwork(ctx, sandboxID, + return vmd.UpdateSandboxNetwork(ctx, sandboxID, cfg.Egress.AllowedCIDRs, cfg.Egress.DeniedCIDRs, cfg.Egress.AllowedDomains, @@ -235,250 +309,10 @@ func (h *Handlers) Health(c *gin.Context) { }) } -// --------------------------------------------------------------------------- -// Instance CRUD -// --------------------------------------------------------------------------- - -type createInstanceRequest struct { - Name string `json:"name" binding:"required,min=1,max=64"` -} - -func (h *Handlers) CreateInstance(c *gin.Context) { - var req createInstanceRequest - if err := c.ShouldBindJSON(&req); err != nil { - respondErrorMsg(c, "bad_request", fmt.Sprintf("Validation failed: %v", err), http.StatusBadRequest) - return - } - - instanceID := uuid.New().String() - - vmdCtx, vmdCancel := context.WithTimeout(c.Request.Context(), vmdTimeout) - defer vmdCancel() - _, err := h.VMD.CreateInstance(vmdCtx, instanceID, 0, 0, 0, nil) - if err != nil { - log.Error().Err(err).Str("instance_id", instanceID).Msg("VMD CreateInstance failed") - respondError(c, ErrInternal) - return - } - - c.JSON(http.StatusCreated, gin.H{ - "id": instanceID, - "name": req.Name, - "status": "RUNNING", - }) -} - -func (h *Handlers) GetInstance(c *gin.Context) { - instanceID, err := parseInstanceID(c) - if err != nil { - return - } - - // TODO: when DB is added, look up instance state from DB. - // For now, query VMD directly. - respondErrorMsg(c, "not_implemented", fmt.Sprintf("GetInstance %s — requires DB (not yet connected)", instanceID), http.StatusNotImplemented) -} - -func (h *Handlers) ListInstances(c *gin.Context) { - // TODO: when DB is added, list from DB. - respondErrorMsg(c, "not_implemented", "ListInstances — requires DB (not yet connected)", http.StatusNotImplemented) -} - -func (h *Handlers) DeleteInstance(c *gin.Context) { - instanceID, err := parseInstanceID(c) - if err != nil { - return - } - - vmdCtx, vmdCancel := context.WithTimeout(c.Request.Context(), vmdTimeout) - defer vmdCancel() - if err := h.VMD.DestroyInstance(vmdCtx, instanceID.String(), true); err != nil { - log.Error().Err(err).Str("instance_id", instanceID.String()).Msg("VMD DestroyInstance failed") - respondError(c, ErrInternal) - return - } - - c.Status(http.StatusNoContent) -} - -// --------------------------------------------------------------------------- -// Pause / Resume -// --------------------------------------------------------------------------- - -func (h *Handlers) PauseInstance(c *gin.Context) { - instanceID, err := parseInstanceID(c) - if err != nil { - return - } - - vmdCtx, vmdCancel := context.WithTimeout(c.Request.Context(), vmdTimeout) - defer vmdCancel() - _, _, err = h.VMD.PauseInstance(vmdCtx, instanceID.String(), "") - if err != nil { - log.Error().Err(err).Str("instance_id", instanceID.String()).Msg("VMD PauseInstance failed") - respondError(c, ErrInternal) - return - } - - c.JSON(http.StatusOK, gin.H{ - "id": instanceID.String(), - "status": "PAUSED", - }) -} - -func (h *Handlers) ResumeInstance(c *gin.Context) { - instanceID, err := parseInstanceID(c) - if err != nil { - return - } - - // TODO: when DB is added, read snapshot paths from DB. - // For now, pass empty paths — VMD uses its default. - vmdCtx, vmdCancel := context.WithTimeout(c.Request.Context(), vmdTimeout) - defer vmdCancel() - ipAddress, err := h.VMD.ResumeInstance(vmdCtx, instanceID.String(), "", "") - if err != nil { - log.Error().Err(err).Str("instance_id", instanceID.String()).Msg("VMD ResumeInstance failed") - respondError(c, ErrInternal) - return - } - - c.JSON(http.StatusOK, gin.H{ - "id": instanceID.String(), - "status": "RUNNING", - "ip_address": ipAddress, - }) -} - -// --------------------------------------------------------------------------- -// Exec -// --------------------------------------------------------------------------- - -type execRequest struct { - Command string `json:"command" binding:"required,min=1"` - Args []string `json:"args"` - Env map[string]string `json:"env"` - WorkingDir string `json:"working_dir"` - TimeoutS int `json:"timeout_s"` -} - -func (h *Handlers) ExecCommand(c *gin.Context) { - instanceID, err := parseInstanceID(c) - if err != nil { - return - } - - var req execRequest - if err := c.ShouldBindJSON(&req); err != nil { - respondErrorMsg(c, "bad_request", fmt.Sprintf("Validation failed: %v", err), http.StatusBadRequest) - return - } - - if req.TimeoutS <= 0 { - req.TimeoutS = 30 - } - - stdout, stderr, exitCode, err := h.VMD.ExecCommand(c.Request.Context(), instanceID.String(), - req.Command, req.Args, req.Env, req.WorkingDir, uint32(req.TimeoutS)) - if err != nil { - log.Error().Err(err).Str("instance_id", instanceID.String()).Msg("VMD ExecCommand failed") - respondError(c, ErrInternal) - return - } - - c.JSON(http.StatusOK, gin.H{ - "stdout": stdout, - "stderr": stderr, - "exit_code": exitCode, - }) -} - -// --------------------------------------------------------------------------- -// Files -// --------------------------------------------------------------------------- - -func (h *Handlers) UploadFile(c *gin.Context) { - instanceID, err := parseInstanceID(c) - if err != nil { - return - } - - filePath, err := cleanFilePath(c.Param("path")) - if err != nil { - respondErrorMsg(c, "bad_request", err.Error(), http.StatusBadRequest) - return - } - - bytesWritten, err := h.VMD.UploadFile(c.Request.Context(), instanceID.String(), filePath, c.Request.Body) - if err != nil { - log.Error().Err(err).Str("instance_id", instanceID.String()).Msg("file upload failed") - respondError(c, ErrInternal) - return - } - - c.JSON(http.StatusOK, gin.H{"path": filePath, "size": bytesWritten}) -} - -func (h *Handlers) DownloadFile(c *gin.Context) { - instanceID, err := parseInstanceID(c) - if err != nil { - return - } - - filePath, err := cleanFilePath(c.Param("path")) - if err != nil { - respondErrorMsg(c, "bad_request", err.Error(), http.StatusBadRequest) - return - } - - reader, err := h.VMD.DownloadFile(c.Request.Context(), instanceID.String(), filePath) - if err != nil { - log.Error().Err(err).Str("instance_id", instanceID.String()).Msg("file download failed") - errMsg := err.Error() - if strings.Contains(errMsg, "404") || strings.Contains(errMsg, "not found") { - respondErrorMsg(c, "not_found", - fmt.Sprintf("File not found: %s", filePath), - http.StatusNotFound) - } else { - respondError(c, ErrInternal) - } - return - } - defer reader.Close() - - c.Header("Content-Type", "application/octet-stream") - c.Status(http.StatusOK) - io.Copy(c.Writer, reader) -} - // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- -func cleanFilePath(raw string) (string, error) { - raw = strings.TrimPrefix(raw, "/") - if raw == "" { - return "", fmt.Errorf("file path is required") - } - if strings.Contains(raw, "..") { - return "", fmt.Errorf("path traversal not allowed") - } - cleaned := filepath.Clean("/" + raw) - return cleaned, nil -} - -func parseInstanceID(c *gin.Context) (uuid.UUID, error) { - raw := c.Param("instance_id") - id, err := uuid.Parse(raw) - if err != nil { - respondErrorMsg(c, "bad_request", - fmt.Sprintf("Invalid instance_id: %q is not a valid UUID", raw), - http.StatusBadRequest) - return uuid.Nil, err - } - return id, nil -} - func parseSandboxID(c *gin.Context) (uuid.UUID, error) { raw := c.Param("sandbox_id") id, err := uuid.Parse(raw) @@ -511,6 +345,7 @@ func teamIDFromContext(c *gin.Context) (uuid.UUID, error) { // --------------------------------------------------------------------------- func (h *Handlers) ResumeSandbox(c *gin.Context) { + defer lifecycleTimer(c, telemetry.OpResume, "")() sandboxID, err := parseSandboxID(c) if err != nil { return @@ -536,20 +371,23 @@ func (h *Handlers) ResumeSandbox(c *gin.Context) { return } - // Only idle sandboxes can be resumed. - if sandbox.Status != db.SandboxStatusIdle { + // Only paused sandboxes can be resumed. + if sandbox.Status != db.SandboxStatusPaused { respondError(c, ErrInvalidState) return } // Read the snapshot to get paths for VMD. if !sandbox.SnapshotID.Valid { - log.Error().Str("sandbox_id", sandboxID.String()).Msg("idle sandbox has no snapshot_id") + log.Error().Str("sandbox_id", sandboxID.String()).Msg("paused sandbox has no snapshot_id") respondError(c, ErrInternal) return } - snapshot, err := h.DB.GetSnapshot(c.Request.Context(), sandbox.SnapshotID.Bytes) + snapshot, err := h.DB.GetSnapshot(c.Request.Context(), db.GetSnapshotParams{ + ID: sandbox.SnapshotID.Bytes, + TeamID: teamID, + }) if err != nil { log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("DB GetSnapshot failed") respondError(c, ErrInternal) @@ -557,68 +395,87 @@ func (h *Handlers) ResumeSandbox(c *gin.Context) { } snapshotPath := snapshot.Path - memPath := filepath.Join(filepath.Dir(snapshotPath), "mem.snap") + memPath := resolveMemPath(snapshot) + + // Resolve the VMD client for this sandbox's host. + vmd, vmdLookupErr := h.vmdForHost(c.Request.Context(), sandbox.HostID) + if vmdLookupErr != nil { + log.Error().Err(vmdLookupErr).Str("sandbox_id", sandboxID.String()).Msg("resolve VMD for resume failed") + respondError(c, ErrInternal) + return + } // Resume the VM. Cancellation of this call still follows the request // context — if the client hangs up mid-resume, abort the VMD call. vmdCtx, vmdCancel := context.WithTimeout(c.Request.Context(), vmdTimeout) defer vmdCancel() - ipAddress, err := h.VMD.ResumeInstance(vmdCtx, sandboxID.String(), snapshotPath, memPath) + ipAddress, actualVcpu, actualMemMiB, err := vmd.ResumeInstance(vmdCtx, sandboxID.String(), snapshotPath, memPath, nil) if err != nil { - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("VMD ResumeInstance failed") - respondError(c, ErrInternal) - return + if isVMDNotFound(err) { + log.Warn().Err(err).Str("sandbox_id", sandboxID.String()). + Msg("VMD ResumeInstance: VM not in map, falling back to stateless RestoreSnapshot") + ipAddress, actualVcpu, actualMemMiB, err = vmd.RestoreSnapshot(vmdCtx, sandboxID.String(), snapshotPath, memPath) + if err != nil { + log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("VMD RestoreSnapshot fallback failed") + respondError(c, ErrInternal) + return + } + } else { + log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("VMD ResumeInstance failed") + respondError(c, ErrInternal) + return + } + } + + // The fallback may have returned 0 for vcpu/mem. Fall back to the DB values. + if actualVcpu == 0 { + actualVcpu = uint32(sandbox.VcpuCount) + } + if actualMemMiB == 0 { + actualMemMiB = uint32(sandbox.MemoryMib) } // Past this point the VM is running. Detach from cancellation so a - // client disconnect cannot leave the sandbox stuck in "idle" while + // client disconnect cannot leave the sandbox stuck in "paused" while // the VM is actually up, but preserve the trace/span context so // these DB writes still appear in the request trace. postCtx, postCancel := context.WithTimeout(context.WithoutCancel(c.Request.Context()), vmdTimeout) defer postCancel() - // Update sandbox status to active. - if err := h.DB.UpdateSandboxStatus(postCtx, db.UpdateSandboxStatusParams{ - ID: sandboxID, - Status: db.SandboxStatusActive, - TeamID: teamID, - }); err != nil { - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("DB UpdateSandboxStatus failed") - respondError(c, ErrInternal) - return + var ipAddr *netip.Addr + if ipAddress != "" { + if addr, parseErr := netip.ParseAddr(ipAddress); parseErr == nil { + ipAddr = &addr + } } - - // Update host runtime info. - ipAddr, _ := netip.ParseAddr(ipAddress) - if err := h.DB.UpdateSandboxHost(postCtx, db.UpdateSandboxHostParams{ + if err := h.DB.ActivateSandbox(postCtx, db.ActivateSandboxParams{ ID: sandboxID, - HostID: sandbox.HostID, - IpAddress: &ipAddr, + VcpuCount: int32(actualVcpu), + MemoryMib: int32(actualMemMiB), + IpAddress: ipAddr, TeamID: teamID, }); err != nil { - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("DB UpdateSandboxHost failed") + log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("DB ActivateSandbox failed") respondError(c, ErrInternal) return } // Reapply persisted egress rules — the nftables rules and proxy state // are fresh after a snapshot restore, so user rules must be re-pushed. - if err := h.reapplyNetworkConfig(postCtx, sandboxID.String(), sandbox.NetworkConfig); err != nil { + if err := h.reapplyNetworkConfig(postCtx, vmd, sandboxID.String(), sandbox.NetworkConfig); err != nil { log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("reapply network config on resume failed") respondError(c, ErrInternal) return } // Async observability writes. - h.updateLastActivityAsync(c.Request.Context(), sandboxID, teamID) h.logActivityAsync(c.Request.Context(), sandboxID, teamID, "sandbox", "resumed", "success", &sandbox.Name, nil, nil) - c.JSON(http.StatusOK, gin.H{ - "id": sandboxID.String(), - "name": sandbox.Name, - "status": "active", - "ip_address": ipAddress, - }) + sandbox.Status = db.SandboxStatusActive + sandbox.VcpuCount = int32(actualVcpu) + sandbox.MemoryMib = int32(actualMemMiB) + sandbox.IpAddress = ipAddr + c.JSON(http.StatusOK, h.sandboxToResponse(sandbox)) } // --------------------------------------------------------------------------- @@ -626,6 +483,7 @@ func (h *Handlers) ResumeSandbox(c *gin.Context) { // --------------------------------------------------------------------------- func (h *Handlers) DeleteSandbox(c *gin.Context) { + defer lifecycleTimer(c, telemetry.OpDestroy, "")() sandboxID, err := parseSandboxID(c) if err != nil { return @@ -653,13 +511,25 @@ func (h *Handlers) DeleteSandbox(c *gin.Context) { // Destroy the VM (skip if sandbox never booted). if sandbox.Status != db.SandboxStatusFailed { - vmdCtx, vmdCancel := context.WithTimeout(c.Request.Context(), vmdTimeout) - defer vmdCancel() - if err := h.VMD.DestroyInstance(vmdCtx, sandboxID.String(), true); err != nil { - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("VMD DestroyInstance failed") + vmd, vmdLookupErr := h.vmdForHost(c.Request.Context(), sandbox.HostID) + if vmdLookupErr != nil { + log.Error().Err(vmdLookupErr).Str("sandbox_id", sandboxID.String()).Msg("resolve VMD for delete failed") respondError(c, ErrInternal) return } + vmdCtx, vmdCancel := context.WithTimeout(c.Request.Context(), vmdTimeout) + defer vmdCancel() + if err := vmd.DestroyInstance(vmdCtx, sandboxID.String(), true); err != nil { + // Delete is idempotent — if the VM is already gone, proceed + // with DB cleanup instead of failing the request. + if isVMDNotFound(err) { + log.Warn().Err(err).Str("sandbox_id", sandboxID.String()).Msg("VMD DestroyInstance: VM already gone, proceeding with DB cleanup") + } else { + log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("VMD DestroyInstance failed") + respondError(c, ErrInternal) + return + } + } } // Soft-delete in DB. @@ -694,7 +564,7 @@ type createSandboxRequest struct { // TimeoutSeconds is a hard lifetime cap in seconds, measured from // created_at. When set, the reaper pauses the sandbox that many seconds - // after creation if it is still active. Already-idle sandboxes are left + // after creation if it is still active. Already-paused sandboxes are left // alone. Matches the user intent "stop this sandbox in N seconds so it // cannot burn resources indefinitely." // @@ -714,28 +584,35 @@ type createSandboxRequest struct { // Strings only — no nested objects, numbers, or arrays. This is // deliberate: it keeps URL filters unambiguous (no "is 42 the number // or the string?" questions) and matches what every other tagging - // system in this space does (E2B, AWS tags, GCE labels, k8s labels). + // system in this space does (AWS tags, GCE labels, k8s labels). // // Limits are enforced by validateMetadata: 64 keys, 256-byte keys, // 2 KB values, 16 KB total. Keys starting with `superserve.` or // `_superserve` are reserved for platform use and rejected. Metadata map[string]string `json:"metadata,omitempty"` + + // EnvVars are environment variables injected into every process inside + // the sandbox (terminal sessions, exec calls). Not stored in the DB — + // they live in boxd's memory for the sandbox's lifetime and survive + // pause/resume via snapshot. + EnvVars map[string]string `json:"env_vars,omitempty"` } type sandboxResponse struct { - ID uuid.UUID `json:"id"` - Name string `json:"name"` - Status string `json:"status"` - VcpuCount int32 `json:"vcpu_count"` - MemoryMib int32 `json:"memory_mib"` - SnapshotID *uuid.UUID `json:"snapshot_id,omitempty"` - CreatedAt time.Time `json:"created_at"` - TimeoutSeconds *int32 `json:"timeout_seconds,omitempty"` + ID uuid.UUID `json:"id"` + Name string `json:"name"` + Status string `json:"status"` + VcpuCount int32 `json:"vcpu_count"` + MemoryMib int32 `json:"memory_mib"` + AccessToken string `json:"access_token,omitempty"` + SnapshotID *uuid.UUID `json:"snapshot_id,omitempty"` + CreatedAt time.Time `json:"created_at"` + TimeoutSeconds *int32 `json:"timeout_seconds,omitempty"` Network *networkConfigRequest `json:"network,omitempty"` - Metadata map[string]string `json:"metadata"` + Metadata map[string]string `json:"metadata"` } -func sandboxToResponse(s db.Sandbox) sandboxResponse { +func (h *Handlers) sandboxToResponse(s db.Sandbox) sandboxResponse { resp := sandboxResponse{ ID: s.ID, Name: s.Name, @@ -745,6 +622,9 @@ func sandboxToResponse(s db.Sandbox) sandboxResponse { CreatedAt: s.CreatedAt, Metadata: decodeMetadata(s.Metadata), } + if h.Config != nil && h.Config.SandboxAccessTokenSeed != nil { + resp.AccessToken = auth.ComputeAccessToken(h.Config.SandboxAccessTokenSeed, s.ID.String()) + } if s.SnapshotID.Valid { id := uuid.UUID(s.SnapshotID.Bytes) resp.SnapshotID = &id @@ -838,7 +718,7 @@ func (h *Handlers) ListSandboxes(c *gin.Context) { out := make([]sandboxResponse, len(sandboxes)) for i, s := range sandboxes { - out[i] = sandboxToResponse(s) + out[i] = h.sandboxToResponse(s) } c.JSON(http.StatusOK, out) } @@ -901,10 +781,24 @@ func (h *Handlers) GetSandboxByID(c *gin.Context) { return } - c.JSON(http.StatusOK, sandboxToResponse(sandbox)) + c.JSON(http.StatusOK, h.sandboxToResponse(sandbox)) } func (h *Handlers) CreateSandbox(c *gin.Context) { + // `from` is determined inside the handler (snapshot vs cold) and the + // label would require mutation across the deferred closure. Use a + // pointer so the handler can update the from label before the timer + // fires. Default is empty (unlabelled). + from := "" + defer func(start time.Time) { + outcome := telemetry.OutcomeOK + if s := c.Writer.Status(); s < 200 || s >= 400 { + outcome = telemetry.OutcomeError + } + telemetry.RecordSandboxLifecycle(c.Request.Context(), telemetry.OpCreate, outcome, from, time.Since(start).Seconds()) + }(time.Now()) + _ = from // silence unused warning until we wire the snapshot/cold branch label + var req createSandboxRequest if err := bindJSONStrict(c, &req); err != nil { respondErrorMsg(c, "bad_request", fmt.Sprintf("Validation failed: %v", err), http.StatusBadRequest) @@ -947,6 +841,10 @@ func (h *Handlers) CreateSandbox(c *gin.Context) { respondErrorMsg(c, "bad_request", err.Error(), http.StatusBadRequest) return } + if err := validateEnvVars(req.EnvVars); err != nil { + respondErrorMsg(c, "bad_request", err.Error(), http.StatusBadRequest) + return + } // Marshal once into the canonical jsonb shape. Empty / nil maps are // stored as the empty object so the column is never NULL. metadataJSON, err := json.Marshal(req.Metadata) @@ -976,45 +874,72 @@ func (h *Handlers) CreateSandbox(c *gin.Context) { return } - snapshot, err := h.DB.GetSnapshot(c.Request.Context(), snapUUID) + snapshot, err := h.DB.GetSnapshot(c.Request.Context(), db.GetSnapshotParams{ + ID: snapUUID, + TeamID: teamID, + }) if err != nil { respondErrorMsg(c, "not_found", "Snapshot not found", http.StatusNotFound) return } - if snapshot.TeamID != teamID { - respondErrorMsg(c, "not_found", "Snapshot not found", http.StatusNotFound) - return - } snapshotID = pgtype.UUID{Bytes: snapUUID, Valid: true} snapshotPath = snapshot.Path - snapshotMemPath = filepath.Join(filepath.Dir(snapshotPath), "mem.snap") - } - - // Default template resources (1 vCPU, 512 MiB). - const defaultVcpu int32 = 1 - const defaultMemoryMib int32 = 512 - - // Insert sandbox with status=starting. timeout_seconds is optional — - // NULL means the sandbox lives until explicitly paused or deleted. - // metadata is always non-NULL (empty object when the user provided none) - // to match the DB column constraint and keep read paths nil-free. - sandbox, err := h.DB.CreateSandbox(c.Request.Context(), db.CreateSandboxParams{ - TeamID: teamID, - Name: req.Name, - Status: db.SandboxStatusStarting, - VcpuCount: defaultVcpu, - MemoryMib: defaultMemoryMib, - SnapshotID: snapshotID, - TimeoutSeconds: req.TimeoutSeconds, - Metadata: metadataJSON, - }) - if err != nil { - log.Error().Err(err).Msg("failed to create sandbox record") + snapshotMemPath = resolveMemPath(snapshot) + } + + // Select a host for this sandbox. + var hostID string + if h.Scheduler != nil { + hostID, err = h.Scheduler.SelectHost(c.Request.Context()) + if err != nil { + log.Error().Err(err).Msg("scheduler SelectHost failed") + respondErrorMsg(c, "service_unavailable", "No hosts available", http.StatusServiceUnavailable) + return + } + } else if h.Config != nil && h.Config.DefaultHostID != "" { + hostID = h.Config.DefaultHostID + } else { + hostID = "default" + } + + // Resolve the VMD client up front so we don't waste a DB INSERT on + // a host we can't reach. + vmd, vmdLookupErr := h.vmdForHost(c.Request.Context(), hostID) + if vmdLookupErr != nil { + log.Error().Err(vmdLookupErr).Msg("resolve VMD for create failed") respondError(c, ErrInternal) return } + // Generate the sandbox ID in Go so the DB INSERT and the VMD call + // can run in parallel — both need the same ID and neither needs to + // wait on the other. This hides the ~10-20ms INSERT roundtrip behind + // VMD's ~100-200ms create latency, shaving that much off the p50. + sandboxID := uuid.New() + + insertCtx := context.WithoutCancel(c.Request.Context()) + type insertResult struct { + sandbox db.Sandbox + err error + } + insertCh := make(chan insertResult, 1) + go func() { + sb, insertErr := h.DB.CreateSandbox(insertCtx, db.CreateSandboxParams{ + ID: sandboxID, + TeamID: teamID, + Name: req.Name, + Status: db.SandboxStatusStarting, + VcpuCount: 1, // placeholders; real values land via ActivateSandbox + MemoryMib: 1, + HostID: hostID, + SnapshotID: snapshotID, + TimeoutSeconds: req.TimeoutSeconds, + Metadata: metadataJSON, + }) + insertCh <- insertResult{sandbox: sb, err: insertErr} + }() + // Boot the VM synchronously — the client gets a response only after // the sandbox is fully running and ready to use. This call is still // scoped to the request context so that if the client hangs up, the @@ -1023,19 +948,41 @@ func (h *Handlers) CreateSandbox(c *gin.Context) { defer vmdCancel() var ipAddress string + var actualVcpu, actualMemMiB uint32 var vmdErr error if req.FromSnapshot != nil { - ipAddress, vmdErr = h.VMD.ResumeInstance(vmdCtx, sandbox.ID.String(), snapshotPath, snapshotMemPath) + ipAddress, actualVcpu, actualMemMiB, vmdErr = vmd.ResumeInstance(vmdCtx, sandboxID.String(), snapshotPath, snapshotMemPath, req.EnvVars) } else { - ipAddress, vmdErr = h.VMD.CreateInstance(vmdCtx, sandbox.ID.String(), - uint32(defaultVcpu), uint32(defaultMemoryMib), 0, nil) + ipAddress, actualVcpu, actualMemMiB, vmdErr = vmd.CreateInstance(vmdCtx, sandboxID.String(), + 0, 0, 0, nil, req.EnvVars) } - if vmdErr != nil { + + // Wait for the parallel INSERT to complete — its result determines + // how we handle a VMD failure (mark row failed vs. nothing to mark). + insertRes := <-insertCh + sandbox := insertRes.sandbox + dbErr := insertRes.err + + switch { + case dbErr != nil && vmdErr != nil: + // Both failed — nothing persisted, nothing to clean up. + log.Error().Err(dbErr).AnErr("vmd_err", vmdErr).Msg("CreateSandbox: DB and VMD both failed") + respondError(c, ErrInternal) + return + case dbErr != nil: + // DB insert failed but VMD succeeded — destroy the orphan VM so + // it doesn't linger on the host. Use a detached context so client + // disconnect doesn't leak the VM. + log.Error().Err(dbErr).Str("sandbox_id", sandboxID.String()).Msg("CreateSandbox: INSERT failed, destroying orphan VM") + cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(c.Request.Context()), vmdTimeout) + _ = vmd.DestroyInstance(cleanupCtx, sandboxID.String(), true) + cleanupCancel() + respondError(c, ErrInternal) + return + case vmdErr != nil: + // VMD failed but DB row exists — mark it failed so the reaper + // doesn't leave it stuck in "starting". log.Error().Err(vmdErr).Str("sandbox_id", sandbox.ID.String()).Msg("VMD create/resume failed") - // Mark the row failed using a cancellation-detached context so a - // disconnected client does not leave the sandbox stuck in - // "starting", but keep trace context so the failure write shows - // up in the same request span. failCtx, failCancel := context.WithTimeout(context.WithoutCancel(c.Request.Context()), asyncTimeout) defer failCancel() _ = h.DB.UpdateSandboxStatus(failCtx, db.UpdateSandboxStatusParams{ @@ -1053,42 +1000,59 @@ func (h *Handlers) CreateSandbox(c *gin.Context) { postCtx, postCancel := context.WithTimeout(context.WithoutCancel(c.Request.Context()), vmdTimeout) defer postCancel() - // Mark active in DB. - if err := h.DB.UpdateSandboxStatus(postCtx, db.UpdateSandboxStatusParams{ - ID: sandbox.ID, - Status: db.SandboxStatusActive, - TeamID: teamID, - }); err != nil { - log.Error().Err(err).Str("sandbox_id", sandbox.ID.String()).Msg("DB UpdateSandboxStatus(active) failed") - } - - // Persist the VM's assigned IP. host_id and pid are not tracked yet. + // Single atomic transition: starting → active with real resources + // and IP. VMD's response is the source of truth for vcpu/memory + // (they come from the template snapshot, not from what the control + // plane requested). + var ipAddr *netip.Addr if ipAddress != "" { if addr, parseErr := netip.ParseAddr(ipAddress); parseErr == nil { - if err := h.DB.UpdateSandboxHost(postCtx, db.UpdateSandboxHostParams{ - ID: sandbox.ID, - IpAddress: &addr, - TeamID: teamID, - }); err != nil { - log.Error().Err(err).Str("sandbox_id", sandbox.ID.String()).Msg("DB UpdateSandboxHost failed") - } + ipAddr = &addr } } + sandbox.Status = db.SandboxStatusActive + sandbox.VcpuCount = int32(actualVcpu) + sandbox.MemoryMib = int32(actualMemMiB) + sandbox.IpAddress = ipAddr - // Apply network rules if provided at creation. - if req.Network != nil && (len(req.Network.AllowOut) > 0 || len(req.Network.DenyOut) > 0) { - var allowedCIDRs, allowedDomains []string - for _, entry := range req.Network.AllowOut { - if isIPOrCIDR(entry) { - allowedCIDRs = append(allowedCIDRs, entry) - } else { - allowedDomains = append(allowedDomains, entry) - } + // ActivateSandbox (DB UPDATE) and network rule application (VMD call + // + DB UPDATE) are independent — both read VMD's result and write to + // their own sink. Run them in parallel so the response latency is + // max(activate, network) instead of activate + network. + hasNetworkRules := req.Network != nil && (len(req.Network.AllowOut) > 0 || len(req.Network.DenyOut) > 0) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if err := h.DB.ActivateSandbox(postCtx, db.ActivateSandboxParams{ + ID: sandbox.ID, + VcpuCount: int32(actualVcpu), + MemoryMib: int32(actualMemMiB), + IpAddress: ipAddr, + TeamID: teamID, + }); err != nil { + log.Error().Err(err).Str("sandbox_id", sandbox.ID.String()).Msg("DB ActivateSandbox failed") } + }() - if err := h.VMD.UpdateSandboxNetwork(postCtx, sandbox.ID.String(), allowedCIDRs, req.Network.DenyOut, allowedDomains); err != nil { - log.Error().Err(err).Str("sandbox_id", sandbox.ID.String()).Msg("failed to apply network rules at creation") - } else { + if hasNetworkRules { + wg.Add(1) + go func() { + defer wg.Done() + var allowedCIDRs, allowedDomains []string + for _, entry := range req.Network.AllowOut { + if isIPOrCIDR(entry) { + allowedCIDRs = append(allowedCIDRs, entry) + } else { + allowedDomains = append(allowedDomains, entry) + } + } + + if err := vmd.UpdateSandboxNetwork(postCtx, sandbox.ID.String(), allowedCIDRs, req.Network.DenyOut, allowedDomains); err != nil { + log.Error().Err(err).Str("sandbox_id", sandbox.ID.String()).Msg("failed to apply network rules at creation") + return + } networkConfig, _ := json.Marshal(map[string]any{ "egress": map[string]any{ "allowed_cidrs": allowedCIDRs, @@ -1101,17 +1065,14 @@ func (h *Handlers) CreateSandbox(c *gin.Context) { NetworkConfig: networkConfig, TeamID: teamID, }) - } + }() } + wg.Wait() h.logActivityAsync(c.Request.Context(), sandbox.ID, teamID, "sandbox", "started", "success", &sandbox.Name, nil, nil) - // Build the response from the freshly-inserted row so metadata, IP, and - // any other server-populated fields make it back to the client without - // having to re-read the row. The row's status is still "starting" from - // the INSERT — overwrite it with "active" since we just transitioned. - resp := sandboxToResponse(sandbox) - resp.Status = string(db.SandboxStatusActive) + sandbox.Status = db.SandboxStatusActive + resp := h.sandboxToResponse(sandbox) if req.Network != nil && (len(req.Network.AllowOut) > 0 || len(req.Network.DenyOut) > 0) { resp.Network = req.Network } @@ -1123,6 +1084,7 @@ func (h *Handlers) CreateSandbox(c *gin.Context) { // --------------------------------------------------------------------------- func (h *Handlers) PauseSandbox(c *gin.Context) { + defer lifecycleTimer(c, telemetry.OpPause, "")() sandboxID, err := parseSandboxID(c) if err != nil { return @@ -1133,34 +1095,43 @@ func (h *Handlers) PauseSandbox(c *gin.Context) { return } - // Verify sandbox exists and belongs to this team. - sandbox, err := h.DB.GetSandbox(c.Request.Context(), db.GetSandboxParams{ + // Atomic ownership + state check + transition to 'pausing'. Collapses + // GetSandbox + UpdateStatus into a single DB roundtrip on the happy + // path. An empty result means the sandbox either doesn't exist, isn't + // ours, or isn't currently active — we do a cheap existence check to + // return the right error code (404 vs 409). + sandbox, err := h.DB.BeginPause(c.Request.Context(), db.BeginPauseParams{ ID: sandboxID, TeamID: teamID, }) if err != nil { - if err == pgx.ErrNoRows { + if err != pgx.ErrNoRows { + log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("DB BeginPause failed") + respondError(c, ErrInternal) + return + } + // Disambiguate: missing row (404) vs wrong state (409). + exists, existsErr := h.DB.SandboxExists(c.Request.Context(), db.SandboxExistsParams{ + ID: sandboxID, + TeamID: teamID, + }) + if existsErr != nil { + log.Error().Err(existsErr).Str("sandbox_id", sandboxID.String()).Msg("DB SandboxExists (pause fallback) failed") + respondError(c, ErrInternal) + return + } + if !exists { respondError(c, ErrSandboxNotFound) return } - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("DB GetSandbox failed") - respondError(c, ErrInternal) - return - } - - // Only active sandboxes can be paused. - if sandbox.Status != db.SandboxStatusActive { respondError(c, ErrInvalidState) return } - // Mark as pausing before calling VMD. - if err := h.DB.UpdateSandboxStatus(c.Request.Context(), db.UpdateSandboxStatusParams{ - ID: sandboxID, - Status: db.SandboxStatusPausing, - TeamID: teamID, - }); err != nil { - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("DB UpdateSandboxStatus(pausing) failed") + // Resolve the VMD client for this sandbox's host. + vmd, vmdLookupErr := h.vmdForHost(c.Request.Context(), sandbox.HostID) + if vmdLookupErr != nil { + log.Error().Err(vmdLookupErr).Str("sandbox_id", sandboxID.String()).Msg("resolve VMD for pause failed") respondError(c, ErrInternal) return } @@ -1168,8 +1139,18 @@ func (h *Handlers) PauseSandbox(c *gin.Context) { // Call VMD to pause and snapshot the VM. vmdCtx, vmdCancel := context.WithTimeout(c.Request.Context(), vmdTimeout) defer vmdCancel() - snapshotPath, memPath, err := h.VMD.PauseInstance(vmdCtx, sandboxID.String(), "") + snapshotPath, memPath, err := vmd.PauseInstance(vmdCtx, sandboxID.String(), "") if err != nil { + // VMD says the VM doesn't exist — it crashed or was removed out-of-band. + // Mark the sandbox failed and return 410 Gone. No revert — the VM is + // already dead, "active" was a lie. + if isVMDNotFound(err) { + log.Warn().Err(err).Str("sandbox_id", sandboxID.String()).Msg("VMD PauseInstance: VM unavailable, marking sandbox failed") + h.markSandboxFailedAsync(c.Request.Context(), sandboxID, teamID) + respondError(c, ErrSandboxGone) + return + } + log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("VMD PauseInstance failed") // Revert status to active asynchronously. Detach cancellation so // the revert survives client disconnect, but keep trace context @@ -1190,9 +1171,6 @@ func (h *Handlers) PauseSandbox(c *gin.Context) { return } - // TODO: store memPath in snapshot table (requires adding a mem_path column to the - // snapshot schema). For now, ResumeSandbox derives it via - // filepath.Join(filepath.Dir(snapshotPath), "mem.snap") by convention. log.Debug(). Str("sandbox_id", sandboxID.String()). Str("snapshot_path", snapshotPath). @@ -1206,40 +1184,32 @@ func (h *Handlers) PauseSandbox(c *gin.Context) { postCtx, postCancel := context.WithTimeout(context.WithoutCancel(c.Request.Context()), vmdTimeout) defer postCancel() - // Create snapshot record in DB. + // Atomic post-VMD bookkeeping: insert the snapshot row, link it to + // the sandbox, and flip status from pausing → paused in a single CTE. + // Collapses three DB roundtrips into one. triggerName := "pause" - snapshot, err := h.DB.CreateSnapshot(postCtx, db.CreateSnapshotParams{ - SandboxID: sandboxID, + finalized, err := h.DB.FinalizePause(postCtx, db.FinalizePauseParams{ + ID: sandboxID, TeamID: teamID, Path: snapshotPath, + MemPath: &memPath, SizeBytes: 0, Saved: false, Name: &triggerName, Trigger: triggerName, }) if err != nil { - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("DB CreateSnapshot failed") - respondError(c, ErrInternal) - return - } - - // Link snapshot to sandbox and mark as idle. - if err := h.DB.SetSandboxSnapshot(postCtx, db.SetSandboxSnapshotParams{ - ID: sandboxID, - SnapshotID: pgtype.UUID{Bytes: snapshot.ID, Valid: true}, - TeamID: teamID, - }); err != nil { - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("DB SetSandboxSnapshot failed") - respondError(c, ErrInternal) - return - } - - if err := h.DB.UpdateSandboxStatus(postCtx, db.UpdateSandboxStatusParams{ - ID: sandboxID, - Status: db.SandboxStatusIdle, - TeamID: teamID, - }); err != nil { - log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("DB UpdateSandboxStatus(idle) failed") + // ErrNoRows here means the sandbox was soft-deleted between + // BeginPause and FinalizePause (a rare race with DeleteSandbox). + // The VM is already stopped and its snapshot files are on disk — + // we can't finalize bookkeeping for a sandbox that no longer + // exists, so return 410 Gone. + if err == pgx.ErrNoRows { + log.Warn().Str("sandbox_id", sandboxID.String()).Msg("FinalizePause: sandbox deleted mid-pause") + respondError(c, ErrSandboxGone) + return + } + log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("DB FinalizePause failed") respondError(c, ErrInternal) return } @@ -1247,82 +1217,20 @@ func (h *Handlers) PauseSandbox(c *gin.Context) { // Async observability. h.logActivityAsync(c.Request.Context(), sandboxID, teamID, "sandbox", "paused", "success", &sandbox.Name, nil, nil) + // Async orphan-snapshot cleanup. FinalizePause atomically swapped + // sandbox.snapshot_id to the new snapshot, so the previous one (if any) + // is now unreferenced. Delete its files + row in the background so pause + // latency is not affected by the extra VMD round-trip and DB write. + h.cleanupOldSnapshotAsync(c.Request.Context(), sandboxID, teamID, sandbox.HostID, finalized.PrevSnapshotID) + c.JSON(http.StatusOK, gin.H{ "id": sandboxID.String(), "name": sandbox.Name, - "status": "idle", - "snapshot_id": snapshot.ID.String(), + "status": "paused", + "snapshot_id": finalized.SnapshotID.String(), }) } -// --------------------------------------------------------------------------- -// Sandbox Files -// --------------------------------------------------------------------------- - -// UploadSandboxFile uploads a file to a sandbox. The sandbox is loaded and -// auto-woken by the AutoWake middleware. -func (h *Handlers) UploadSandboxFile(c *gin.Context) { - sandbox := sandboxFromContext(c) - if sandbox == nil { - respondError(c, ErrInternal) - return - } - - filePath, err := cleanFilePath(c.Param("path")) - if err != nil { - respondErrorMsg(c, "bad_request", err.Error(), http.StatusBadRequest) - return - } - - bytesWritten, err := h.VMD.UploadFile(c.Request.Context(), sandbox.ID.String(), filePath, c.Request.Body) - if err != nil { - log.Error().Err(err).Str("sandbox_id", sandbox.ID.String()).Msg("sandbox file upload failed") - respondError(c, ErrInternal) - return - } - - h.updateLastActivityAsync(c.Request.Context(), sandbox.ID, sandbox.TeamID) - - c.JSON(http.StatusOK, gin.H{"path": filePath, "size": bytesWritten}) -} - -// DownloadSandboxFile downloads a file from a sandbox. The sandbox is loaded -// and auto-woken by the AutoWake middleware. -func (h *Handlers) DownloadSandboxFile(c *gin.Context) { - sandbox := sandboxFromContext(c) - if sandbox == nil { - respondError(c, ErrInternal) - return - } - - filePath, err := cleanFilePath(c.Param("path")) - if err != nil { - respondErrorMsg(c, "bad_request", err.Error(), http.StatusBadRequest) - return - } - - reader, err := h.VMD.DownloadFile(c.Request.Context(), sandbox.ID.String(), filePath) - if err != nil { - log.Error().Err(err).Str("sandbox_id", sandbox.ID.String()).Msg("sandbox file download failed") - errMsg := err.Error() - if strings.Contains(errMsg, "404") || strings.Contains(errMsg, "not found") { - respondErrorMsg(c, "not_found", - fmt.Sprintf("File not found: %s", filePath), - http.StatusNotFound) - } else { - respondError(c, ErrInternal) - } - return - } - defer reader.Close() - - h.updateLastActivityAsync(c.Request.Context(), sandbox.ID, sandbox.TeamID) - - c.Header("Content-Type", "application/octet-stream") - c.Status(http.StatusOK) - io.Copy(c.Writer, reader) -} - // --------------------------------------------------------------------------- // Sandbox Exec // --------------------------------------------------------------------------- @@ -1336,11 +1244,11 @@ type sandboxExecRequest struct { } // ExecSandbox runs a command inside a sandbox and returns the result. -// The sandbox is loaded and auto-woken by the AutoWake middleware. +// The sandbox must already be active — callers must resume a paused sandbox +// via POST /sandboxes/:id/resume first. func (h *Handlers) ExecSandbox(c *gin.Context) { - sandbox := sandboxFromContext(c) + sandbox := h.loadActiveSandbox(c) if sandbox == nil { - respondError(c, ErrInternal) return } @@ -1354,18 +1262,30 @@ func (h *Handlers) ExecSandbox(c *gin.Context) { req.TimeoutS = 30 } + vmd, vmdLookupErr := h.vmdForHost(c.Request.Context(), sandbox.HostID) + if vmdLookupErr != nil { + log.Error().Err(vmdLookupErr).Str("sandbox_id", sandbox.ID.String()).Msg("resolve VMD for exec failed") + respondError(c, ErrInternal) + return + } + start := time.Now() - stdout, stderr, exitCode, err := h.VMD.ExecCommand(c.Request.Context(), sandbox.ID.String(), + stdout, stderr, exitCode, err := vmd.ExecCommand(c.Request.Context(), sandbox.ID.String(), req.Command, req.Args, req.Env, req.WorkingDir, uint32(req.TimeoutS)) durationMs := int32(time.Since(start).Milliseconds()) if err != nil { + if isVMDNotFound(err) { + log.Warn().Err(err).Str("sandbox_id", sandbox.ID.String()).Msg("VMD ExecCommand: VM unavailable, marking sandbox failed") + h.markSandboxFailedAsync(c.Request.Context(), sandbox.ID, sandbox.TeamID) + respondError(c, ErrSandboxGone) + return + } log.Error().Err(err).Str("sandbox_id", sandbox.ID.String()).Msg("VMD ExecCommand failed") respondError(c, ErrInternal) return } // Async observability writes. - h.updateLastActivityAsync(c.Request.Context(), sandbox.ID, sandbox.TeamID) metadata, _ := json.Marshal(map[string]any{ "command": req.Command, "exit_code": exitCode, @@ -1474,10 +1394,18 @@ func (h *Handlers) PatchSandbox(c *gin.Context) { } } + // Resolve the VMD client for this sandbox's host. + vmd, vmdLookupErr := h.vmdForHost(c.Request.Context(), sandbox.HostID) + if vmdLookupErr != nil { + log.Error().Err(vmdLookupErr).Str("sandbox_id", sandboxID.String()).Msg("resolve VMD for patch failed") + respondError(c, ErrInternal) + return + } + // Apply rules to the running VM via VMD. vmdCtx, vmdCancel := context.WithTimeout(c.Request.Context(), vmdTimeout) defer vmdCancel() - if err := h.VMD.UpdateSandboxNetwork(vmdCtx, sandboxID.String(), allowedCIDRs, body.Network.DenyOut, allowedDomains); err != nil { + if err := vmd.UpdateSandboxNetwork(vmdCtx, sandboxID.String(), allowedCIDRs, body.Network.DenyOut, allowedDomains); err != nil { log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("VMD UpdateSandboxNetwork failed") respondError(c, ErrInternal) return @@ -1609,10 +1537,6 @@ func validateMetadata(md map[string]string) error { return fmt.Errorf("metadata has %d keys, max is %d", len(md), metadataMaxKeys) } - // Track total size as we iterate so we can fail fast on the offending - // key rather than re-marshalling at the end. The size is approximate - // (it doesn't include json punctuation) but the cap is conservative - // enough that the difference doesn't matter. totalBytes := 0 for k, v := range md { if k == "" { @@ -1638,6 +1562,41 @@ func validateMetadata(md map[string]string) error { return nil } +// Env var validation limits. Same key count cap as metadata; values are +// larger (API keys, connection strings) so 8 KB per value, 64 KB total. +const ( + envVarsMaxKeys = 64 + envVarsMaxKeyLen = 256 + envVarsMaxValueLen = 8192 // 8 KB — API keys, tokens, DSNs + envVarsMaxTotalBytes = 65536 // 64 KB +) + +func validateEnvVars(env map[string]string) error { + if len(env) == 0 { + return nil + } + if len(env) > envVarsMaxKeys { + return fmt.Errorf("env_vars has %d keys, max is %d", len(env), envVarsMaxKeys) + } + totalBytes := 0 + for k, v := range env { + if k == "" { + return fmt.Errorf("env_vars keys cannot be empty") + } + if len(k) > envVarsMaxKeyLen { + return fmt.Errorf("env_vars key %q is %d bytes, max is %d", k, len(k), envVarsMaxKeyLen) + } + if len(v) > envVarsMaxValueLen { + return fmt.Errorf("env_vars value for key %q is %d bytes, max is %d", k, len(v), envVarsMaxValueLen) + } + totalBytes += len(k) + len(v) + if totalBytes > envVarsMaxTotalBytes { + return fmt.Errorf("env_vars exceeds %d bytes total", envVarsMaxTotalBytes) + } + } + return nil +} + // Returns nil on success or a 400-appropriate error message. func validateEgressRules(allowOut, denyOut []string) error { for _, entry := range denyOut { diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go index 89f6f2d..7d1d0e1 100644 --- a/internal/api/handlers_test.go +++ b/internal/api/handlers_test.go @@ -1,16 +1,14 @@ package api import ( - "bytes" "context" "encoding/json" - "errors" "fmt" - "io" "net/http" "net/http/httptest" "net/netip" "strings" + "sync/atomic" "testing" "time" @@ -28,20 +26,21 @@ import ( // --------------------------------------------------------------------------- type stubVMD struct { - createFn func(ctx context.Context, id string, vcpu, memMiB, diskMiB uint32, metadata map[string]string) (string, error) - destroyFn func(ctx context.Context, id string, force bool) error - pauseFn func(ctx context.Context, id, snapshotDir string) (string, string, error) - resumeFn func(ctx context.Context, id, snapshotPath, memPath string) (string, error) - execFn func(ctx context.Context, id, command string, args []string, env map[string]string, workingDir string, timeoutS uint32) (string, string, int32, error) - uploadFn func(ctx context.Context, id, path string, content io.Reader) (int64, error) - downloadFn func(ctx context.Context, id, path string) (io.ReadCloser, error) + createFn func(ctx context.Context, id string, vcpu, memMiB, diskMiB uint32, metadata map[string]string) (string, error) + destroyFn func(ctx context.Context, id string, force bool) error + pauseFn func(ctx context.Context, id, snapshotDir string) (string, string, error) + resumeFn func(ctx context.Context, id, snapshotPath, memPath string) (string, error) + restoreFn func(ctx context.Context, id, snapshotPath, memPath string) (string, error) + deleteSnapshotFn func(ctx context.Context, id, snapshotPath, memPath string) error + execFn func(ctx context.Context, id, command string, args []string, env map[string]string, workingDir string, timeoutS uint32) (string, string, int32, error) } -func (s *stubVMD) CreateInstance(ctx context.Context, id string, vcpu, memMiB, diskMiB uint32, metadata map[string]string) (string, error) { +func (s *stubVMD) CreateInstance(ctx context.Context, id string, vcpu, memMiB, diskMiB uint32, metadata map[string]string, envVars map[string]string) (string, uint32, uint32, error) { if s.createFn != nil { - return s.createFn(ctx, id, vcpu, memMiB, diskMiB, metadata) + ip, err := s.createFn(ctx, id, vcpu, memMiB, diskMiB, metadata) + return ip, 1, 1024, err } - return "10.0.0.1", nil + return "10.0.0.1", 1, 1024, nil } func (s *stubVMD) DestroyInstance(ctx context.Context, id string, force bool) error { if s.destroyFn != nil { @@ -55,11 +54,25 @@ func (s *stubVMD) PauseInstance(ctx context.Context, id, snapshotDir string) (st } return "/snapshots/vmstate.snap", "/snapshots/mem.snap", nil } -func (s *stubVMD) ResumeInstance(ctx context.Context, id, snapshotPath, memPath string) (string, error) { +func (s *stubVMD) ResumeInstance(ctx context.Context, id, snapshotPath, memPath string, envVars map[string]string) (string, uint32, uint32, error) { if s.resumeFn != nil { - return s.resumeFn(ctx, id, snapshotPath, memPath) + ip, err := s.resumeFn(ctx, id, snapshotPath, memPath) + return ip, 1, 1024, err } - return "10.0.0.1", nil + return "10.0.0.1", 1, 1024, nil +} +func (s *stubVMD) RestoreSnapshot(ctx context.Context, id, snapshotPath, memPath string) (string, uint32, uint32, error) { + if s.restoreFn != nil { + ip, err := s.restoreFn(ctx, id, snapshotPath, memPath) + return ip, 1, 1024, err + } + return "10.0.0.1", 1, 1024, nil +} +func (s *stubVMD) DeleteSnapshot(ctx context.Context, id, snapshotPath, memPath string) error { + if s.deleteSnapshotFn != nil { + return s.deleteSnapshotFn(ctx, id, snapshotPath, memPath) + } + return nil } func (s *stubVMD) ExecCommand(ctx context.Context, id, command string, args []string, env map[string]string, workingDir string, timeoutS uint32) (string, string, int32, error) { if s.execFn != nil { @@ -70,18 +83,6 @@ func (s *stubVMD) ExecCommand(ctx context.Context, id, command string, args []st func (s *stubVMD) ExecCommandStream(context.Context, string, string, []string, map[string]string, string, uint32, func([]byte, []byte, int32, bool)) error { return nil } -func (s *stubVMD) UploadFile(ctx context.Context, id, path string, content io.Reader) (int64, error) { - if s.uploadFn != nil { - return s.uploadFn(ctx, id, path, content) - } - return 0, nil -} -func (s *stubVMD) DownloadFile(ctx context.Context, id, path string) (io.ReadCloser, error) { - if s.downloadFn != nil { - return s.downloadFn(ctx, id, path) - } - return io.NopCloser(strings.NewReader("file-content")), nil -} func (s *stubVMD) UpdateSandboxNetwork(_ context.Context, _ string, _, _, _ []string) error { return nil } @@ -118,10 +119,10 @@ func (m *mockDBTX) Query(context.Context, string, ...any) (pgx.Rows, error) { // --------------------------------------------------------------------------- // sandboxRow returns a mockRow that populates a Sandbox from GetSandbox's Scan -// call (17 destination pointers matching the column order in sqlc-generated +// call (16 destination pointers matching the column order in sqlc-generated // queries: ID, TeamID, Name, Status, VcpuCount, MemoryMib, HostID, IpAddress, -// Pid, SnapshotID, LastActivityAt, CreatedAt, UpdatedAt, DestroyedAt, -// NetworkConfig, TimeoutSeconds, Metadata). +// Pid, SnapshotID, CreatedAt, UpdatedAt, DestroyedAt, NetworkConfig, +// TimeoutSeconds, Metadata). func sandboxRow(s db.Sandbox) *mockRow { return &mockRow{scanFn: func(dest ...any) error { *dest[0].(*uuid.UUID) = s.ID @@ -130,17 +131,16 @@ func sandboxRow(s db.Sandbox) *mockRow { *dest[3].(*db.SandboxStatus) = s.Status *dest[4].(*int32) = s.VcpuCount *dest[5].(*int32) = s.MemoryMib - *dest[6].(**string) = s.HostID + *dest[6].(*string) = s.HostID *dest[7].(**netip.Addr) = s.IpAddress *dest[8].(**int32) = s.Pid *dest[9].(*pgtype.UUID) = s.SnapshotID - *dest[10].(*time.Time) = s.LastActivityAt - *dest[11].(*time.Time) = s.CreatedAt - *dest[12].(*time.Time) = s.UpdatedAt - *dest[13].(*pgtype.Timestamptz) = s.DestroyedAt - *dest[14].(*[]byte) = s.NetworkConfig - *dest[15].(**int32) = s.TimeoutSeconds - *dest[16].(*[]byte) = s.Metadata + *dest[10].(*time.Time) = s.CreatedAt + *dest[11].(*time.Time) = s.UpdatedAt + *dest[12].(*pgtype.Timestamptz) = s.DestroyedAt + *dest[13].(*[]byte) = s.NetworkConfig + *dest[14].(**int32) = s.TimeoutSeconds + *dest[15].(*[]byte) = s.Metadata return nil }} } @@ -189,16 +189,8 @@ func setupTestRouter(h *Handlers, teamID string) *gin.Engine { r.POST("/sandboxes/:sandbox_id/resume", h.ResumeSandbox) r.POST("/sandboxes/:sandbox_id/pause", h.PauseSandbox) r.DELETE("/sandboxes/:sandbox_id", h.DeleteSandbox) - - // Routes with auto-wake middleware. - ops := r.Group("/sandboxes/:sandbox_id") - ops.Use(h.AutoWake()) - { - ops.POST("/exec", h.ExecSandbox) - ops.POST("/exec/stream", h.ExecSandboxStream) - ops.PUT("/files/*path", h.UploadSandboxFile) - ops.GET("/files/*path", h.DownloadSandboxFile) - } + r.POST("/sandboxes/:sandbox_id/exec", h.ExecSandbox) + r.POST("/sandboxes/:sandbox_id/exec/stream", h.ExecSandboxStream) return r } @@ -439,6 +431,25 @@ func snapshotRow(s db.Snapshot) *mockRow { }} } +// finalizePauseRow mocks the two-column RETURNING clause of FinalizePause: +// (snapshot_id uuid, prev_snapshot_id uuid NULL). Pass pgtype.UUID{Valid:false} +// as prev to simulate a sandbox being paused for the first time. +func finalizePauseRow(snapshotID uuid.UUID, prev pgtype.UUID) *mockRow { + return &mockRow{scanFn: func(dest ...any) error { + *dest[0].(*uuid.UUID) = snapshotID + *dest[1].(*pgtype.UUID) = prev + return nil + }} +} + +// boolRow mocks a single-bool scan (e.g. SandboxExists). +func boolRow(value bool) *mockRow { + return &mockRow{scanFn: func(dest ...any) error { + *dest[0].(*bool) = value + return nil + }} +} + func resumeRequest(sandboxID string) *http.Request { return httptest.NewRequest(http.MethodPost, "/sandboxes/"+sandboxID+"/resume", nil) } @@ -448,7 +459,7 @@ func idleSandboxWithSnapshot(sandboxID, teamID, snapshotID uuid.UUID) db.Sandbox ID: sandboxID, TeamID: teamID, Name: "test-sb", - Status: db.SandboxStatusIdle, + Status: db.SandboxStatusPaused, SnapshotID: pgtype.UUID{Bytes: snapshotID, Valid: true}, } } @@ -602,7 +613,7 @@ func TestResumeSandbox_NoSnapshotID(t *testing.T) { sandboxID := uuid.New() teamID := uuid.New() // Idle but no snapshot_id set. - sb := db.Sandbox{ID: sandboxID, TeamID: teamID, Name: "sb", Status: db.SandboxStatusIdle} + sb := db.Sandbox{ID: sandboxID, TeamID: teamID, Name: "sb", Status: db.SandboxStatusPaused} mock := &mockDBTX{ queryRowFn: func(context.Context, string, ...any) pgx.Row { return sandboxRow(sb) }, @@ -830,50 +841,35 @@ func TestExecSandbox_InvalidState(t *testing.T) { } } -func TestExecSandbox_AutoWakeIdle(t *testing.T) { +// Paused sandboxes must be resumed explicitly via POST /resume before exec +// works — there is no implicit auto-wake. +func TestExecSandbox_PausedRejected(t *testing.T) { sandboxID := uuid.New() teamID := uuid.New() - sb := db.Sandbox{ID: sandboxID, TeamID: teamID, Name: "idle-sb", Status: db.SandboxStatusIdle} + sb := db.Sandbox{ID: sandboxID, TeamID: teamID, Name: "paused-sb", Status: db.SandboxStatusPaused} - var resumeCalled, execCalled bool + var execCalled bool vmd := &stubVMD{ - resumeFn: func(_ context.Context, id, _, _ string) (string, error) { - resumeCalled = true - if id != sandboxID.String() { - t.Errorf("ResumeInstance id = %q, want %q", id, sandboxID) - } - return "10.0.0.1", nil - }, - execFn: func(_ context.Context, id, command string, _ []string, _ map[string]string, _ string, _ uint32) (string, string, int32, error) { + execFn: func(context.Context, string, string, []string, map[string]string, string, uint32) (string, string, int32, error) { execCalled = true - return "ok\n", "", 0, nil + return "", "", 0, nil }, } mock := &mockDBTX{ - queryRowFn: func(_ context.Context, sql string, _ ...any) pgx.Row { - if strings.Contains(sql, "FROM sandbox") { - return sandboxRow(sb) - } - return activityRow() - }, - execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { - return pgconn.NewCommandTag("UPDATE 1"), nil - }, + queryRowFn: func(context.Context, string, ...any) pgx.Row { return sandboxRow(sb) }, + execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { return pgconn.NewCommandTag(""), nil }, } h := &Handlers{VMD: vmd, DB: db.New(mock)} w := httptest.NewRecorder() - setupTestRouter(h, teamID.String()).ServeHTTP(w, sandboxExecReq(sandboxID.String(), `{"command":"echo","args":["hello"]}`)) + setupTestRouter(h, teamID.String()).ServeHTTP(w, sandboxExecReq(sandboxID.String(), `{"command":"echo"}`)) - if w.Code != http.StatusOK { - t.Errorf("status = %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String()) + if w.Code != http.StatusConflict { + t.Errorf("status = %d, want %d; body: %s", w.Code, http.StatusConflict, w.Body.String()) } - if !resumeCalled { - t.Error("VMD.ResumeInstance was not called for auto-wake") - } - if !execCalled { - t.Error("VMD.ExecCommand was not called after auto-wake") + if execCalled { + t.Error("VMD.ExecCommand should not be called on a paused sandbox") } } @@ -931,13 +927,7 @@ func TestCreateSandbox_Success(t *testing.T) { sandboxID := uuid.New() vmd := &stubVMD{ - createFn: func(_ context.Context, id string, vcpu, memMiB, _ uint32, _ map[string]string) (string, error) { - if vcpu != 1 { - t.Errorf("vcpu = %d, want 1", vcpu) - } - if memMiB != 512 { - t.Errorf("memMiB = %d, want 512", memMiB) - } + createFn: func(_ context.Context, id string, _, _, _ uint32, _ map[string]string) (string, error) { return "10.0.0.42", nil }, } @@ -974,6 +964,13 @@ func TestCreateSandbox_Success(t *testing.T) { if body["status"] != "active" { t.Errorf("status = %q, want active", body["status"]) } + // Resources should reflect what VMD reported, not the initial INSERT placeholders. + if v := body["vcpu_count"].(float64); v == 0 { + t.Error("vcpu_count is 0 — VMD's reported value was not propagated to the response") + } + if v := body["memory_mib"].(float64); v == 0 { + t.Error("memory_mib is 0 — VMD's reported value was not propagated to the response") + } } func TestCreateSandbox_InvalidBody(t *testing.T) { @@ -1068,18 +1065,18 @@ func TestPauseSandbox_Success(t *testing.T) { } snapshotID := uuid.New() - queryRowCall := 0 mock := &mockDBTX{ queryRowFn: func(_ context.Context, sql string, _ ...any) pgx.Row { - queryRowCall++ - if strings.Contains(sql, "FROM sandbox") { + switch { + case strings.Contains(sql, "'pausing'"): + // BeginPause: atomic transition active → pausing, returns * + return sandboxRow(sb) + case strings.Contains(sql, "INSERT INTO snapshot"): + // FinalizePause: CTE insert + update, returns snapshot_id + return finalizePauseRow(snapshotID, pgtype.UUID{}) + case strings.Contains(sql, "FROM sandbox"): + // Generic GetSandbox (fallback from BeginPause) return sandboxRow(sb) - } - if strings.Contains(sql, "INSERT INTO snapshot") { - return snapshotRow(db.Snapshot{ - ID: snapshotID, SandboxID: sandboxID, TeamID: teamID, - Path: "/snapshots/vmstate.snap", Trigger: "pause", - }) } return activityRow() }, @@ -1100,63 +1097,46 @@ func TestPauseSandbox_Success(t *testing.T) { } body := parseJSON(t, w) - if body["status"] != "idle" { - t.Errorf("status = %q, want %q", body["status"], "idle") + if body["status"] != "paused" { + t.Errorf("status = %q, want %q", body["status"], "paused") } if body["snapshot_id"] != snapshotID.String() { t.Errorf("snapshot_id = %q, want %q", body["snapshot_id"], snapshotID) } } -func TestPauseSandbox_NotActive(t *testing.T) { - sandboxID := uuid.New() - teamID := uuid.New() - sb := db.Sandbox{ID: sandboxID, TeamID: teamID, Name: "sb", Status: db.SandboxStatusIdle} - - mock := &mockDBTX{ - queryRowFn: func(context.Context, string, ...any) pgx.Row { return sandboxRow(sb) }, - execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { return pgconn.NewCommandTag(""), nil }, - } - vmd := &stubVMD{} - - h := &Handlers{VMD: vmd, DB: db.New(mock)} - w := httptest.NewRecorder() - setupTestRouter(h, teamID.String()).ServeHTTP(w, pauseRequest(sandboxID.String())) - - if w.Code != http.StatusConflict { - t.Errorf("status = %d, want %d", w.Code, http.StatusConflict) - } -} - -func TestPauseSandbox_NotFound(t *testing.T) { - mock := &mockDBTX{ - queryRowFn: func(context.Context, string, ...any) pgx.Row { return notFoundRow() }, - execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { return pgconn.NewCommandTag(""), nil }, - } - vmd := &stubVMD{} - - h := &Handlers{VMD: vmd, DB: db.New(mock)} - w := httptest.NewRecorder() - setupTestRouter(h, uuid.New().String()).ServeHTTP(w, pauseRequest(uuid.New().String())) - - if w.Code != http.StatusNotFound { - t.Errorf("status = %d, want %d", w.Code, http.StatusNotFound) - } -} - -func TestPauseSandbox_VMDError(t *testing.T) { +// TestPauseSandbox_FirstPauseNoCleanup verifies that the first pause of a +// sandbox (prev_snapshot_id IS NULL) does not trigger VMD.DeleteSnapshot. +func TestPauseSandbox_FirstPauseNoCleanup(t *testing.T) { sandboxID := uuid.New() teamID := uuid.New() sb := db.Sandbox{ID: sandboxID, TeamID: teamID, Name: "sb", Status: db.SandboxStatusActive} + var deleteSnapshotCalled int32 vmd := &stubVMD{ pauseFn: func(context.Context, string, string) (string, string, error) { - return "", "", fmt.Errorf("vmd unreachable") + return "/snapshots/new/vmstate.snap", "/snapshots/new/mem.snap", nil + }, + deleteSnapshotFn: func(context.Context, string, string, string) error { + atomic.AddInt32(&deleteSnapshotCalled, 1) + return nil }, } + snapshotID := uuid.New() mock := &mockDBTX{ - queryRowFn: func(context.Context, string, ...any) pgx.Row { return sandboxRow(sb) }, + queryRowFn: func(_ context.Context, sql string, _ ...any) pgx.Row { + switch { + case strings.Contains(sql, "'pausing'"): + return sandboxRow(sb) + case strings.Contains(sql, "INSERT INTO snapshot"): + // prev_snapshot_id = invalid (first pause) + return finalizePauseRow(snapshotID, pgtype.UUID{}) + case strings.Contains(sql, "FROM sandbox"): + return sandboxRow(sb) + } + return activityRow() + }, execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { return pgconn.NewCommandTag("UPDATE 1"), nil }, @@ -1166,676 +1146,287 @@ func TestPauseSandbox_VMDError(t *testing.T) { w := httptest.NewRecorder() setupTestRouter(h, teamID.String()).ServeHTTP(w, pauseRequest(sandboxID.String())) - if w.Code != http.StatusInternalServerError { - t.Errorf("status = %d, want %d", w.Code, http.StatusInternalServerError) + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body: %s", w.Code, w.Body.String()) } -} -func TestPauseSandbox_MissingTeamID(t *testing.T) { - vmd := &stubVMD{} - mock := &mockDBTX{ - queryRowFn: func(context.Context, string, ...any) pgx.Row { return notFoundRow() }, - execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { return pgconn.NewCommandTag(""), nil }, + // The cleanup is fire-and-forget; if prev is NULL the helper returns + // immediately before launching the goroutine. A brief sleep is enough + // to catch the false-positive case where a goroutine did run. + time.Sleep(50 * time.Millisecond) + if got := atomic.LoadInt32(&deleteSnapshotCalled); got != 0 { + t.Errorf("VMD.DeleteSnapshot called %d times on first pause, want 0", got) } - - h := &Handlers{VMD: vmd, DB: db.New(mock)} - w := httptest.NewRecorder() - setupTestRouter(h, "").ServeHTTP(w, pauseRequest(uuid.New().String())) - - if w.Code != http.StatusUnauthorized { - t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized) - } -} - -// --------------------------------------------------------------------------- -// Sandbox file operation tests -// --------------------------------------------------------------------------- - -func uploadFileReq(sandboxID, filePath string) *http.Request { - return httptest.NewRequest(http.MethodPut, "/sandboxes/"+sandboxID+"/files/"+filePath, strings.NewReader("file-content")) } -func downloadFileReq(sandboxID, filePath string) *http.Request { - return httptest.NewRequest(http.MethodGet, "/sandboxes/"+sandboxID+"/files/"+filePath, nil) -} - -func TestUploadSandboxFile_Success(t *testing.T) { +// TestPauseSandbox_PreviousSnapshotCleanedUp verifies that when a sandbox is +// re-paused (prev_snapshot_id is set), the old snapshot is garbage-collected: +// its files are removed via VMD.DeleteSnapshot and its DB row is deleted. +func TestPauseSandbox_PreviousSnapshotCleanedUp(t *testing.T) { sandboxID := uuid.New() teamID := uuid.New() + prevSnapshotID := uuid.New() + prevPath := "/snapshots/prev/vmstate.snap" + prevMemPath := "/snapshots/prev/mem.snap" sb := db.Sandbox{ID: sandboxID, TeamID: teamID, Name: "sb", Status: db.SandboxStatusActive} - var uploadCalled bool + deleteCalled := make(chan struct{}, 1) + var gotSnapshotPath, gotMemPath string vmd := &stubVMD{ - uploadFn: func(_ context.Context, id, path string, content io.Reader) (int64, error) { - uploadCalled = true - if id != sandboxID.String() { - t.Errorf("UploadFile id = %q, want %q", id, sandboxID) - } - if path != "/app/main.go" { - t.Errorf("UploadFile path = %q, want %q", path, "/app/main.go") + pauseFn: func(context.Context, string, string) (string, string, error) { + return "/snapshots/new/vmstate.snap", "/snapshots/new/mem.snap", nil + }, + deleteSnapshotFn: func(_ context.Context, _id, sp, mp string) error { + gotSnapshotPath = sp + gotMemPath = mp + select { + case deleteCalled <- struct{}{}: + default: } - data, _ := io.ReadAll(content) - return int64(len(data)), nil + return nil }, } + newSnapshotID := uuid.New() + var deleteRowCalled int32 mock := &mockDBTX{ queryRowFn: func(_ context.Context, sql string, _ ...any) pgx.Row { - if strings.Contains(sql, "FROM sandbox") { + switch { + case strings.Contains(sql, "'pausing'"): + return sandboxRow(sb) + case strings.Contains(sql, "INSERT INTO snapshot"): + return finalizePauseRow(newSnapshotID, pgtype.UUID{Bytes: prevSnapshotID, Valid: true}) + case strings.Contains(sql, "FROM snapshot"): + // GetSnapshotForCleanup: return prev snapshot paths. + mem := prevMemPath + return &mockRow{scanFn: func(dest ...any) error { + *dest[0].(*uuid.UUID) = prevSnapshotID + *dest[1].(*uuid.UUID) = teamID + *dest[2].(*string) = prevPath + *dest[3].(**string) = &mem + return nil + }} + case strings.Contains(sql, "FROM sandbox"): return sandboxRow(sb) } return activityRow() }, - execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { + execFn: func(_ context.Context, sql string, _ ...any) (pgconn.CommandTag, error) { + if strings.Contains(sql, "DELETE FROM snapshot") { + atomic.AddInt32(&deleteRowCalled, 1) + return pgconn.NewCommandTag("DELETE 1"), nil + } return pgconn.NewCommandTag("UPDATE 1"), nil }, } h := &Handlers{VMD: vmd, DB: db.New(mock)} w := httptest.NewRecorder() - setupTestRouter(h, teamID.String()).ServeHTTP(w, uploadFileReq(sandboxID.String(), "app/main.go")) + setupTestRouter(h, teamID.String()).ServeHTTP(w, pauseRequest(sandboxID.String())) if w.Code != http.StatusOK { - t.Errorf("status = %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String()) + t.Fatalf("status = %d, want 200; body: %s", w.Code, w.Body.String()) + } + + select { + case <-deleteCalled: + case <-time.After(500 * time.Millisecond): + t.Fatal("VMD.DeleteSnapshot was not called within 500ms") + } + + if gotSnapshotPath != prevPath { + t.Errorf("DeleteSnapshot snapshot_path = %q, want %q", gotSnapshotPath, prevPath) + } + if gotMemPath != prevMemPath { + t.Errorf("DeleteSnapshot mem_path = %q, want %q", gotMemPath, prevMemPath) } - if !uploadCalled { - t.Error("VMD.UploadFile was not called") + + // DB row delete runs after VMD succeeds; give it a tick. + deadline := time.Now().Add(500 * time.Millisecond) + for time.Now().Before(deadline) { + if atomic.LoadInt32(&deleteRowCalled) > 0 { + break + } + time.Sleep(10 * time.Millisecond) + } + if atomic.LoadInt32(&deleteRowCalled) == 0 { + t.Error("DeleteSnapshotRow was not called after successful VMD delete") } } -func TestDownloadSandboxFile_Success(t *testing.T) { +// TestPauseSandbox_CleanupVMDFailureLeavesRow verifies that when VMD delete +// fails, the DB row is left in place (so a future retry can clean it up). +func TestPauseSandbox_CleanupVMDFailureLeavesRow(t *testing.T) { sandboxID := uuid.New() teamID := uuid.New() + prevSnapshotID := uuid.New() sb := db.Sandbox{ID: sandboxID, TeamID: teamID, Name: "sb", Status: db.SandboxStatusActive} - var downloadCalled bool + deleteAttempted := make(chan struct{}, 1) vmd := &stubVMD{ - downloadFn: func(_ context.Context, id, path string) (io.ReadCloser, error) { - downloadCalled = true - if id != sandboxID.String() { - t.Errorf("DownloadFile id = %q, want %q", id, sandboxID) + pauseFn: func(context.Context, string, string) (string, string, error) { + return "/snapshots/new/vmstate.snap", "/snapshots/new/mem.snap", nil + }, + deleteSnapshotFn: func(context.Context, string, string, string) error { + select { + case deleteAttempted <- struct{}{}: + default: } - return io.NopCloser(strings.NewReader("hello world")), nil + return fmt.Errorf("vmd unreachable") }, } + newSnapshotID := uuid.New() + var deleteRowCalled int32 mock := &mockDBTX{ queryRowFn: func(_ context.Context, sql string, _ ...any) pgx.Row { - if strings.Contains(sql, "FROM sandbox") { + switch { + case strings.Contains(sql, "'pausing'"): + return sandboxRow(sb) + case strings.Contains(sql, "INSERT INTO snapshot"): + return finalizePauseRow(newSnapshotID, pgtype.UUID{Bytes: prevSnapshotID, Valid: true}) + case strings.Contains(sql, "FROM snapshot"): + mem := "/snapshots/prev/mem.snap" + return &mockRow{scanFn: func(dest ...any) error { + *dest[0].(*uuid.UUID) = prevSnapshotID + *dest[1].(*uuid.UUID) = teamID + *dest[2].(*string) = "/snapshots/prev/vmstate.snap" + *dest[3].(**string) = &mem + return nil + }} + case strings.Contains(sql, "FROM sandbox"): return sandboxRow(sb) } return activityRow() }, - execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { + execFn: func(_ context.Context, sql string, _ ...any) (pgconn.CommandTag, error) { + if strings.Contains(sql, "DELETE FROM snapshot") { + atomic.AddInt32(&deleteRowCalled, 1) + } return pgconn.NewCommandTag("UPDATE 1"), nil }, } h := &Handlers{VMD: vmd, DB: db.New(mock)} w := httptest.NewRecorder() - setupTestRouter(h, teamID.String()).ServeHTTP(w, downloadFileReq(sandboxID.String(), "app/main.go")) + setupTestRouter(h, teamID.String()).ServeHTTP(w, pauseRequest(sandboxID.String())) if w.Code != http.StatusOK { - t.Errorf("status = %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String()) + t.Fatalf("status = %d, want 200; body: %s", w.Code, w.Body.String()) } - if !downloadCalled { - t.Error("VMD.DownloadFile was not called") + + select { + case <-deleteAttempted: + case <-time.After(500 * time.Millisecond): + t.Fatal("VMD.DeleteSnapshot was not attempted within 500ms") } - if w.Body.String() != "hello world" { - t.Errorf("body = %q, want %q", w.Body.String(), "hello world") + + time.Sleep(50 * time.Millisecond) + if atomic.LoadInt32(&deleteRowCalled) != 0 { + t.Error("DeleteSnapshotRow should not run when VMD delete failed") } } -func TestUploadSandboxFile_PathTraversal(t *testing.T) { +func TestPauseSandbox_NotActive(t *testing.T) { sandboxID := uuid.New() teamID := uuid.New() - sb := db.Sandbox{ID: sandboxID, TeamID: teamID, Name: "sb", Status: db.SandboxStatusActive} - vmd := &stubVMD{} + // BeginPause's WHERE status = 'active' clause excludes an idle + // sandbox → 0 rows. The handler falls back to SandboxExists to + // disambiguate 404 vs 409; since the row exists (just in the wrong + // state), we return 409 Conflict. mock := &mockDBTX{ queryRowFn: func(_ context.Context, sql string, _ ...any) pgx.Row { - if strings.Contains(sql, "FROM sandbox") { - return sandboxRow(sb) + if strings.Contains(sql, "'pausing'") { + return notFoundRow() // BeginPause: no active row matched } - return activityRow() - }, - execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { - return pgconn.NewCommandTag("UPDATE 1"), nil + if strings.Contains(sql, "EXISTS") { + return boolRow(true) // fallback: row exists, but not active + } + return notFoundRow() }, + execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { return pgconn.NewCommandTag(""), nil }, } + vmd := &stubVMD{} h := &Handlers{VMD: vmd, DB: db.New(mock)} w := httptest.NewRecorder() - setupTestRouter(h, teamID.String()).ServeHTTP(w, uploadFileReq(sandboxID.String(), "../etc/passwd")) + setupTestRouter(h, teamID.String()).ServeHTTP(w, pauseRequest(sandboxID.String())) - if w.Code != http.StatusBadRequest { - t.Errorf("status = %d, want %d; body: %s", w.Code, http.StatusBadRequest, w.Body.String()) + if w.Code != http.StatusConflict { + t.Errorf("status = %d, want %d", w.Code, http.StatusConflict) } } -func TestDownloadSandboxFile_NotFound(t *testing.T) { - sandboxID := uuid.New() - teamID := uuid.New() - sb := db.Sandbox{ID: sandboxID, TeamID: teamID, Name: "sb", Status: db.SandboxStatusActive} - - vmd := &stubVMD{ - downloadFn: func(context.Context, string, string) (io.ReadCloser, error) { - return nil, fmt.Errorf("404 not found") - }, - } - +func TestPauseSandbox_NotFound(t *testing.T) { + // BeginPause returns 0 rows (sandbox doesn't exist), and the + // SandboxExists fallback also returns false → 404. mock := &mockDBTX{ queryRowFn: func(_ context.Context, sql string, _ ...any) pgx.Row { - if strings.Contains(sql, "FROM sandbox") { - return sandboxRow(sb) + if strings.Contains(sql, "'pausing'") { + return notFoundRow() } - return activityRow() - }, - execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { - return pgconn.NewCommandTag("UPDATE 1"), nil + if strings.Contains(sql, "EXISTS") { + return boolRow(false) + } + return notFoundRow() }, + execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { return pgconn.NewCommandTag(""), nil }, } + vmd := &stubVMD{} h := &Handlers{VMD: vmd, DB: db.New(mock)} w := httptest.NewRecorder() - setupTestRouter(h, teamID.String()).ServeHTTP(w, downloadFileReq(sandboxID.String(), "missing.txt")) + setupTestRouter(h, uuid.New().String()).ServeHTTP(w, pauseRequest(uuid.New().String())) if w.Code != http.StatusNotFound { t.Errorf("status = %d, want %d", w.Code, http.StatusNotFound) } } -// --------------------------------------------------------------------------- -// Instance handler mock (separate from sandbox stubVMD above) -// --------------------------------------------------------------------------- - -type mockVMD struct { - createInstanceFn func(ctx context.Context, instanceID string, vcpu, memMiB, diskMiB uint32, metadata map[string]string) (string, error) - destroyInstanceFn func(ctx context.Context, instanceID string, force bool) error - pauseInstanceFn func(ctx context.Context, instanceID, snapshotDir string) (string, string, error) - resumeInstanceFn func(ctx context.Context, instanceID, snapshotPath, memPath string) (string, error) - execCommandFn func(ctx context.Context, instanceID, command string, args []string, env map[string]string, workingDir string, timeoutS uint32) (string, string, int32, error) - execCommandStreamFn func(ctx context.Context, instanceID, command string, args []string, env map[string]string, workingDir string, timeoutS uint32, onChunk func([]byte, []byte, int32, bool)) error - uploadFileFn func(ctx context.Context, instanceID, path string, content io.Reader) (int64, error) - downloadFileFn func(ctx context.Context, instanceID, path string) (io.ReadCloser, error) -} - -func (m *mockVMD) CreateInstance(ctx context.Context, instanceID string, vcpu, memMiB, diskMiB uint32, metadata map[string]string) (string, error) { - if m.createInstanceFn != nil { - return m.createInstanceFn(ctx, instanceID, vcpu, memMiB, diskMiB, metadata) - } - return "10.0.0.1", nil -} -func (m *mockVMD) DestroyInstance(ctx context.Context, instanceID string, force bool) error { - if m.destroyInstanceFn != nil { - return m.destroyInstanceFn(ctx, instanceID, force) - } - return nil -} -func (m *mockVMD) PauseInstance(ctx context.Context, instanceID, snapshotDir string) (string, string, error) { - if m.pauseInstanceFn != nil { - return m.pauseInstanceFn(ctx, instanceID, snapshotDir) - } - return "/snap/path", "/mem/path", nil -} -func (m *mockVMD) ResumeInstance(ctx context.Context, instanceID, snapshotPath, memPath string) (string, error) { - if m.resumeInstanceFn != nil { - return m.resumeInstanceFn(ctx, instanceID, snapshotPath, memPath) - } - return "10.0.0.1", nil -} -func (m *mockVMD) ExecCommand(ctx context.Context, instanceID, command string, args []string, env map[string]string, workingDir string, timeoutS uint32) (string, string, int32, error) { - if m.execCommandFn != nil { - return m.execCommandFn(ctx, instanceID, command, args, env, workingDir, timeoutS) - } - return "hello\n", "", 0, nil -} -func (m *mockVMD) ExecCommandStream(ctx context.Context, instanceID, command string, args []string, env map[string]string, workingDir string, timeoutS uint32, onChunk func([]byte, []byte, int32, bool)) error { - if m.execCommandStreamFn != nil { - return m.execCommandStreamFn(ctx, instanceID, command, args, env, workingDir, timeoutS, onChunk) - } - onChunk([]byte("hello\n"), nil, 0, true) - return nil -} -func (m *mockVMD) UploadFile(ctx context.Context, instanceID, path string, content io.Reader) (int64, error) { - if m.uploadFileFn != nil { - return m.uploadFileFn(ctx, instanceID, path, content) - } - return 42, nil -} -func (m *mockVMD) DownloadFile(ctx context.Context, instanceID, path string) (io.ReadCloser, error) { - if m.downloadFileFn != nil { - return m.downloadFileFn(ctx, instanceID, path) - } - return io.NopCloser(strings.NewReader("file-content")), nil -} -func (m *mockVMD) UpdateSandboxNetwork(_ context.Context, _ string, _, _, _ []string) error { - return nil -} - -func newTestHandlers(vmd VMDClient) *Handlers { return &Handlers{VMD: vmd} } - -func jsonBody(v interface{}) *bytes.Buffer { b, _ := json.Marshal(v); return bytes.NewBuffer(b) } - -func setupInstanceTestRouter(h *Handlers, teamID string) *gin.Engine { - r := gin.New() - r.Use(func(c *gin.Context) { - if teamID != "" { - c.Set("team_id", teamID) - } - c.Next() - }) - r.GET("/health", h.Health) - r.POST("/instances", h.CreateInstance) - r.GET("/instances/:instance_id", h.GetInstance) - r.GET("/instances", h.ListInstances) - r.DELETE("/instances/:instance_id", h.DeleteInstance) - r.POST("/instances/:instance_id/pause", h.PauseInstance) - r.POST("/instances/:instance_id/resume", h.ResumeInstance) - r.POST("/instances/:instance_id/exec", h.ExecCommand) - r.POST("/instances/:instance_id/exec/stream", h.ExecCommandStream) - r.PUT("/instances/:instance_id/files/*path", h.UploadFile) - r.GET("/instances/:instance_id/files/*path", h.DownloadFile) - return r -} - -func TestHealth(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), "") - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/health", nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d", w.Code) - } - body := parseJSON(t, w) - if body["status"] != "ok" { - t.Errorf("status=%v want ok", body["status"]) - } -} - -func TestCreateInstance_Success(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances", jsonBody(map[string]string{"name": "test-box"})) - req.Header.Set("Content-Type", "application/json") - r.ServeHTTP(w, req) - if w.Code != http.StatusCreated { - t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String()) - } -} - -func TestCreateInstance_BadRequest_MissingName(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances", jsonBody(map[string]string{})) - req.Header.Set("Content-Type", "application/json") - r.ServeHTTP(w, req) - if w.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", w.Code) - } -} - -func TestCreateInstance_BadRequest_EmptyBody(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances", nil) - req.Header.Set("Content-Type", "application/json") - r.ServeHTTP(w, req) - if w.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", w.Code) - } -} - -func TestCreateInstance_VMDFailure(t *testing.T) { - vmd := &mockVMD{createInstanceFn: func(_ context.Context, _ string, _, _, _ uint32, _ map[string]string) (string, error) { - return "", errors.New("vmd unavailable") - }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances", jsonBody(map[string]string{"name": "fail-box"})) - req.Header.Set("Content-Type", "application/json") - r.ServeHTTP(w, req) - if w.Code != http.StatusInternalServerError { - t.Fatalf("expected 500, got %d", w.Code) - } -} - -func TestDeleteInstance_Success(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("DELETE", "/instances/"+uuid.New().String(), nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusNoContent { - t.Fatalf("expected 204, got %d: %s", w.Code, w.Body.String()) - } -} - -func TestDeleteInstance_InvalidUUID(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("DELETE", "/instances/not-a-uuid", nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", w.Code) - } -} - -func TestDeleteInstance_VMDFailure(t *testing.T) { - vmd := &mockVMD{destroyInstanceFn: func(_ context.Context, _ string, _ bool) error { return errors.New("destroy failed") }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("DELETE", "/instances/"+uuid.New().String(), nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusInternalServerError { - t.Fatalf("expected 500, got %d", w.Code) - } -} - -func TestPauseInstance_Success(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances/"+uuid.New().String()+"/pause", nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) - } - if parseJSON(t, w)["status"] != "PAUSED" { - t.Errorf("expected status=PAUSED") - } -} - -func TestPauseInstance_InvalidUUID(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances/bad/pause", nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", w.Code) - } -} - -func TestPauseInstance_VMDFailure(t *testing.T) { - vmd := &mockVMD{pauseInstanceFn: func(_ context.Context, _, _ string) (string, string, error) { return "", "", errors.New("pause failed") }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances/"+uuid.New().String()+"/pause", nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusInternalServerError { - t.Fatalf("expected 500, got %d", w.Code) - } -} - -func TestResumeInstance_Success(t *testing.T) { - vmd := &mockVMD{resumeInstanceFn: func(_ context.Context, _, _, _ string) (string, error) { return "10.0.0.42", nil }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances/"+uuid.New().String()+"/resume", nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) - } -} - -func TestResumeInstance_VMDFailure(t *testing.T) { - vmd := &mockVMD{resumeInstanceFn: func(_ context.Context, _, _, _ string) (string, error) { return "", errors.New("resume failed") }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances/"+uuid.New().String()+"/resume", nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusInternalServerError { - t.Fatalf("expected 500, got %d", w.Code) - } -} - -func TestExecCommand_Success(t *testing.T) { - vmd := &mockVMD{execCommandFn: func(_ context.Context, _, _ string, _ []string, _ map[string]string, _ string, _ uint32) (string, string, int32, error) { - return "hello world\n", "", 0, nil - }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances/"+uuid.New().String()+"/exec", - jsonBody(map[string]interface{}{"command": "echo"})) - req.Header.Set("Content-Type", "application/json") - r.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) - } - if parseJSON(t, w)["stdout"] != "hello world\n" { - t.Errorf("unexpected stdout") - } -} - -func TestExecCommand_BadRequest_MissingCommand(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances/"+uuid.New().String()+"/exec", jsonBody(map[string]interface{}{})) - req.Header.Set("Content-Type", "application/json") - r.ServeHTTP(w, req) - if w.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", w.Code) - } -} - -func TestExecCommand_DefaultTimeout(t *testing.T) { - var capturedTimeout uint32 - vmd := &mockVMD{execCommandFn: func(_ context.Context, _, _ string, _ []string, _ map[string]string, _ string, timeoutS uint32) (string, string, int32, error) { - capturedTimeout = timeoutS - return "", "", 0, nil - }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances/"+uuid.New().String()+"/exec", jsonBody(map[string]interface{}{"command": "ls"})) - req.Header.Set("Content-Type", "application/json") - r.ServeHTTP(w, req) - if capturedTimeout != 30 { - t.Errorf("expected default timeout=30, got %d", capturedTimeout) - } -} - -func TestExecCommand_VMDFailure(t *testing.T) { - vmd := &mockVMD{execCommandFn: func(_ context.Context, _, _ string, _ []string, _ map[string]string, _ string, _ uint32) (string, string, int32, error) { - return "", "", 0, errors.New("exec failed") - }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances/"+uuid.New().String()+"/exec", jsonBody(map[string]interface{}{"command": "fail"})) - req.Header.Set("Content-Type", "application/json") - r.ServeHTTP(w, req) - if w.Code != http.StatusInternalServerError { - t.Fatalf("expected 500, got %d", w.Code) - } -} +func TestPauseSandbox_VMDError(t *testing.T) { + sandboxID := uuid.New() + teamID := uuid.New() + sb := db.Sandbox{ID: sandboxID, TeamID: teamID, Name: "sb", Status: db.SandboxStatusActive} -func TestExecCommand_NonZeroExit(t *testing.T) { - vmd := &mockVMD{execCommandFn: func(_ context.Context, _, _ string, _ []string, _ map[string]string, _ string, _ uint32) (string, string, int32, error) { - return "", "not found\n", 127, nil - }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances/"+uuid.New().String()+"/exec", jsonBody(map[string]interface{}{"command": "missing"})) - req.Header.Set("Content-Type", "application/json") - r.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d", w.Code) - } - if parseJSON(t, w)["exit_code"] != float64(127) { - t.Errorf("expected exit_code=127") + vmd := &stubVMD{ + pauseFn: func(context.Context, string, string) (string, string, error) { + return "", "", fmt.Errorf("vmd unreachable") + }, } -} -func TestUploadFile_Success(t *testing.T) { - vmd := &mockVMD{uploadFileFn: func(_ context.Context, _, _ string, content io.Reader) (int64, error) { - data, _ := io.ReadAll(content) - return int64(len(data)), nil - }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("PUT", "/instances/"+uuid.New().String()+"/files/home/user/test.txt", strings.NewReader("file content here")) - r.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) - } - if parseJSON(t, w)["path"] != "/home/user/test.txt" { - t.Errorf("unexpected path") + mock := &mockDBTX{ + queryRowFn: func(context.Context, string, ...any) pgx.Row { return sandboxRow(sb) }, + execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { + return pgconn.NewCommandTag("UPDATE 1"), nil + }, } -} -func TestUploadFile_PathTraversal(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), uuid.New().String()) + h := &Handlers{VMD: vmd, DB: db.New(mock)} w := httptest.NewRecorder() - req, _ := http.NewRequest("PUT", "/instances/"+uuid.New().String()+"/files/../etc/passwd", strings.NewReader("bad")) - r.ServeHTTP(w, req) - if w.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String()) - } -} + setupTestRouter(h, teamID.String()).ServeHTTP(w, pauseRequest(sandboxID.String())) -func TestUploadFile_VMDFailure(t *testing.T) { - vmd := &mockVMD{uploadFileFn: func(_ context.Context, _, _ string, _ io.Reader) (int64, error) { return 0, errors.New("upload failed") }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("PUT", "/instances/"+uuid.New().String()+"/files/test.txt", strings.NewReader("data")) - r.ServeHTTP(w, req) if w.Code != http.StatusInternalServerError { - t.Fatalf("expected 500, got %d", w.Code) - } -} - -func TestDownloadFile_Success(t *testing.T) { - vmd := &mockVMD{downloadFileFn: func(_ context.Context, _, _ string) (io.ReadCloser, error) { - return io.NopCloser(strings.NewReader("downloaded-content")), nil - }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/instances/"+uuid.New().String()+"/files/data.txt", nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d", w.Code) - } - if w.Body.String() != "downloaded-content" { - t.Errorf("unexpected body: %s", w.Body.String()) - } -} - -func TestDownloadFile_NotFound(t *testing.T) { - vmd := &mockVMD{downloadFileFn: func(_ context.Context, _, _ string) (io.ReadCloser, error) { return nil, errors.New("404 not found") }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/instances/"+uuid.New().String()+"/files/missing.txt", nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusNotFound { - t.Fatalf("expected 404, got %d", w.Code) - } -} - -func TestDownloadFile_PathTraversal(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/instances/"+uuid.New().String()+"/files/../../../etc/shadow", nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String()) - } -} - -func TestExecCommandStream_Success(t *testing.T) { - vmd := &mockVMD{execCommandStreamFn: func(_ context.Context, _, _ string, _ []string, _ map[string]string, _ string, _ uint32, onChunk func([]byte, []byte, int32, bool)) error { - onChunk([]byte("line1\n"), nil, 0, false) - onChunk(nil, nil, 0, true) - return nil - }} - r := setupInstanceTestRouter(newTestHandlers(vmd), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances/"+uuid.New().String()+"/exec/stream", jsonBody(map[string]interface{}{"command": "echo"})) - req.Header.Set("Content-Type", "application/json") - r.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d", w.Code) - } - if w.Header().Get("Content-Type") != "text/event-stream" { - t.Errorf("expected text/event-stream") - } - if !strings.Contains(w.Body.String(), "line1") { - t.Errorf("expected stdout in SSE stream") + t.Errorf("status = %d, want %d", w.Code, http.StatusInternalServerError) } } -func TestExecCommandStream_BadRequest(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/instances/"+uuid.New().String()+"/exec/stream", jsonBody(map[string]interface{}{})) - req.Header.Set("Content-Type", "application/json") - r.ServeHTTP(w, req) - if w.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", w.Code) +func TestPauseSandbox_MissingTeamID(t *testing.T) { + vmd := &stubVMD{} + mock := &mockDBTX{ + queryRowFn: func(context.Context, string, ...any) pgx.Row { return notFoundRow() }, + execFn: func(context.Context, string, ...any) (pgconn.CommandTag, error) { return pgconn.NewCommandTag(""), nil }, } -} -func TestGetInstance_NotImplemented(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), uuid.New().String()) - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/instances/"+uuid.New().String(), nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusNotImplemented { - t.Fatalf("expected 501, got %d", w.Code) - } -} - -func TestListInstances_NotImplemented(t *testing.T) { - r := setupInstanceTestRouter(newTestHandlers(&mockVMD{}), uuid.New().String()) + h := &Handlers{VMD: vmd, DB: db.New(mock)} w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/instances", nil) - r.ServeHTTP(w, req) - if w.Code != http.StatusNotImplemented { - t.Fatalf("expected 501, got %d", w.Code) - } -} - -func TestCleanFilePath(t *testing.T) { - tests := []struct { - name string - input string - want string - wantErr bool - }{ - {"simple file", "/test.txt", "/test.txt", false}, - {"nested path", "/home/user/data.csv", "/home/user/data.csv", false}, - {"strips leading slash", "test.txt", "/test.txt", false}, - {"empty path", "/", "", true}, - {"traversal blocked", "/../etc/passwd", "", true}, - {"double dot in middle", "/foo/../bar", "", true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := cleanFilePath(tt.input) - if (err != nil) != tt.wantErr { - t.Errorf("cleanFilePath(%q): err=%v, wantErr=%v", tt.input, err, tt.wantErr) - return - } - if !tt.wantErr && got != tt.want { - t.Errorf("cleanFilePath(%q) = %q, want %q", tt.input, got, tt.want) - } - }) - } -} + setupTestRouter(h, "").ServeHTTP(w, pauseRequest(uuid.New().String())) -func TestParseInstanceID_Valid(t *testing.T) { - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Params = gin.Params{{Key: "instance_id", Value: uuid.New().String()}} - if _, err := parseInstanceID(c); err != nil { - t.Fatalf("expected no error, got %v", err) + if w.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized) } } -func TestParseInstanceID_Invalid(t *testing.T) { - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Params = gin.Params{{Key: "instance_id", Value: "bad"}} - if _, err := parseInstanceID(c); err == nil { - t.Fatal("expected error for invalid UUID") - } -} func TestParseSandboxID_Valid(t *testing.T) { w := httptest.NewRecorder() @@ -2048,8 +1639,9 @@ func TestCreateSandbox_WithMetadata(t *testing.T) { mock := &mockDBTX{ queryRowFn: func(_ context.Context, sql string, args ...any) pgx.Row { if strings.Contains(sql, "INSERT INTO sandbox") { - // args[10] is the metadata jsonb (11th positional, 0-indexed). - if b, ok := args[10].([]byte); ok { + // Positional args (0-indexed): id, team_id, name, status, + // vcpu, mem, host_id, ip, pid, snapshot_id, timeout, metadata. + if b, ok := args[11].([]byte); ok { capturedMetadata = b } // Echo metadata back through the row so the response carries it. @@ -2108,7 +1700,9 @@ func TestCreateSandbox_EmptyMetadataIsObjectNotNull(t *testing.T) { mock := &mockDBTX{ queryRowFn: func(_ context.Context, sql string, args ...any) pgx.Row { if strings.Contains(sql, "INSERT INTO sandbox") { - if b, ok := args[10].([]byte); ok { + // Positional args (0-indexed): id, team_id, name, status, + // vcpu, mem, host_id, ip, pid, snapshot_id, timeout, metadata. + if b, ok := args[11].([]byte); ok { capturedMetadata = b } return sandboxRow(db.Sandbox{ diff --git a/internal/api/host_detector.go b/internal/api/host_detector.go new file mode 100644 index 0000000..d6bfb81 --- /dev/null +++ b/internal/api/host_detector.go @@ -0,0 +1,73 @@ +package api + +import ( + "context" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/rs/zerolog/log" + + "github.com/superserve-ai/sandbox/internal/db" +) + +const ( + // heartbeatTimeout is how long a host can go without heartbeating + // before it's marked unhealthy. Matches the design doc's "2 minutes". + heartbeatTimeout = 2 * time.Minute + + // detectorInterval is how often we check for stale hosts. + detectorInterval = 30 * time.Second + + // detectorRunTimeout bounds each detection pass so a slow DB can't + // wedge the loop. + detectorRunTimeout = 15 * time.Second +) + +// StartHostDetector launches a background goroutine that periodically +// marks active hosts as unhealthy when their heartbeat goes stale. +// Blocks until ctx is cancelled. +func StartHostDetector(ctx context.Context, queries *db.Queries) { + log.Info(). + Dur("timeout", heartbeatTimeout). + Dur("interval", detectorInterval). + Msg("host detector started") + + ticker := time.NewTicker(detectorInterval) + defer ticker.Stop() + + detectOnce(ctx, queries) + + for { + select { + case <-ctx.Done(): + log.Info().Msg("host detector exiting") + return + case <-ticker.C: + runCtx, cancel := context.WithTimeout(ctx, detectorRunTimeout) + detectOnce(runCtx, queries) + cancel() + } + } +} + +func detectOnce(ctx context.Context, queries *db.Queries) { + cutoff := time.Now().Add(-heartbeatTimeout) + stale, err := queries.ListStaleHosts(ctx, pgtype.Timestamptz{ + Time: cutoff, + Valid: true, + }) + if err != nil { + log.Error().Err(err).Msg("host detector: ListStaleHosts failed") + return + } + + for _, host := range stale { + log.Warn().Str("host_id", host.ID). + Time("last_heartbeat", host.LastHeartbeatAt.Time). + Msg("host detector: marking host unhealthy (heartbeat timeout)") + + if err := queries.MarkHostUnhealthy(ctx, host.ID); err != nil { + log.Error().Err(err).Str("host_id", host.ID).Msg("host detector: MarkHostUnhealthy failed") + } + } +} diff --git a/internal/api/hosts.go b/internal/api/hosts.go new file mode 100644 index 0000000..513b2b6 --- /dev/null +++ b/internal/api/hosts.go @@ -0,0 +1,35 @@ +package api + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/jackc/pgx/v5" + "github.com/rs/zerolog/log" +) + +// HostHeartbeat handles POST /internal/hosts/:host_id/heartbeat. +// VMD calls this every 30s to prove liveness. The control plane updates +// last_heartbeat_at; a background detector marks hosts unhealthy after +// 2 minutes of silence. If the host was previously marked unhealthy, the +// heartbeat automatically re-activates it (recovery from transient outage). +func (h *Handlers) HostHeartbeat(c *gin.Context) { + hostID := c.Param("host_id") + if hostID == "" { + respondErrorMsg(c, "bad_request", "host_id is required", http.StatusBadRequest) + return + } + + host, err := h.DB.UpdateHostHeartbeat(c.Request.Context(), hostID) + if err != nil { + if err == pgx.ErrNoRows { + respondErrorMsg(c, "not_found", "host not found", http.StatusNotFound) + return + } + log.Error().Err(err).Str("host_id", hostID).Msg("UpdateHostHeartbeat failed") + respondError(c, ErrInternal) + return + } + + c.JSON(http.StatusOK, gin.H{"status": host.Status}) +} diff --git a/internal/api/internal_auth.go b/internal/api/internal_auth.go new file mode 100644 index 0000000..0dbf4f6 --- /dev/null +++ b/internal/api/internal_auth.go @@ -0,0 +1,42 @@ +package api + +import ( + "crypto/subtle" + "net/http" + "os" + "strings" + + "github.com/gin-gonic/gin" +) + +// InternalAuth returns middleware that authenticates internal API requests +// via a shared token in the Authorization header. The expected token is +// read from the INTERNAL_API_TOKEN env var. If the env var is unset, all +// requests are rejected (fail-closed). +func InternalAuth() gin.HandlerFunc { + token := os.Getenv("INTERNAL_API_TOKEN") + + return func(c *gin.Context) { + if token == "" { + respondErrorMsg(c, "unauthorized", "internal API not configured", http.StatusUnauthorized) + c.Abort() + return + } + + auth := c.GetHeader("Authorization") + provided := strings.TrimPrefix(auth, "Bearer ") + if provided == auth || provided == "" { + respondErrorMsg(c, "unauthorized", "missing or invalid Authorization header", http.StatusUnauthorized) + c.Abort() + return + } + + if subtle.ConstantTimeCompare([]byte(provided), []byte(token)) != 1 { + respondErrorMsg(c, "unauthorized", "invalid token", http.StatusUnauthorized) + c.Abort() + return + } + + c.Next() + } +} diff --git a/internal/api/reaper.go b/internal/api/reaper.go index 963df07..2df1932 100644 --- a/internal/api/reaper.go +++ b/internal/api/reaper.go @@ -4,11 +4,11 @@ import ( "context" "time" - "github.com/jackc/pgx/v5/pgtype" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/superserve-ai/sandbox/internal/db" + "github.com/superserve-ai/sandbox/internal/telemetry" ) // ReaperConfig controls the timeout reaper loop. @@ -91,6 +91,7 @@ func (h *Handlers) reapOnce(ctx context.Context, batchSize int32) { } log.Info().Int("count", len(expired)).Msg("reaper: pausing expired sandboxes") + telemetry.IncReaperReaped(ctx, "timeout", int64(len(expired))) for _, sbx := range expired { // Check for shutdown between each pause so we exit promptly. @@ -130,8 +131,15 @@ func (h *Handlers) pauseExpired(ctx context.Context, sbx db.ClaimExpiredSandboxe Str("name", sbx.Name). Logger() + vmd, vmdLookupErr := h.vmdForHost(ctx, sbx.HostID) + if vmdLookupErr != nil { + l.Error().Err(vmdLookupErr).Msg("reaper: resolve VMD failed — reverting to active") + h.revertToActiveOrFail(ctx, sbx, vmdLookupErr, l) + return + } + vmdCtx, vmdCancel := context.WithTimeout(ctx, vmdTimeout) - snapshotPath, memPath, err := h.VMD.PauseInstance(vmdCtx, sbx.ID.String(), "") + snapshotPath, memPath, err := vmd.PauseInstance(vmdCtx, sbx.ID.String(), "") vmdCancel() if err != nil { // VM never stopped — safe to revert DB to active so the reaper @@ -144,44 +152,32 @@ func (h *Handlers) pauseExpired(ctx context.Context, sbx db.ClaimExpiredSandboxe postCtx, postCancel := context.WithTimeout(ctx, vmdTimeout) defer postCancel() + // Atomic post-VMD bookkeeping: insert the snapshot row, link it to + // the sandbox, and flip status from pausing → idle in a single CTE. + // Same query as the user-initiated PauseSandbox handler, so the two + // code paths have identical atomicity guarantees. triggerName := "timeout" - snapshot, err := h.DB.CreateSnapshot(postCtx, db.CreateSnapshotParams{ - SandboxID: sbx.ID, + finalized, err := h.DB.FinalizePause(postCtx, db.FinalizePauseParams{ + ID: sbx.ID, TeamID: sbx.TeamID, Path: snapshotPath, + MemPath: &memPath, SizeBytes: 0, Saved: false, Name: &triggerName, Trigger: triggerName, }) if err != nil { - l.Error().Err(err).Msg("reaper: CreateSnapshot failed — rolling back VMD pause") - h.rollbackPausedVM(ctx, sbx, snapshotPath, memPath, err, l) - return - } - - if err := h.DB.SetSandboxSnapshot(postCtx, db.SetSandboxSnapshotParams{ - ID: sbx.ID, - SnapshotID: pgtype.UUID{Bytes: snapshot.ID, Valid: true}, - TeamID: sbx.TeamID, - }); err != nil { - l.Error().Err(err).Msg("reaper: SetSandboxSnapshot failed — rolling back VMD pause") - h.rollbackPausedVM(ctx, sbx, snapshotPath, memPath, err, l) - return - } - - if err := h.DB.UpdateSandboxStatus(postCtx, db.UpdateSandboxStatusParams{ - ID: sbx.ID, - Status: db.SandboxStatusIdle, - TeamID: sbx.TeamID, - }); err != nil { - l.Error().Err(err).Msg("reaper: UpdateSandboxStatus(idle) failed — rolling back VMD pause") + l.Error().Err(err).Msg("reaper: FinalizePause failed — rolling back VMD pause") h.rollbackPausedVM(ctx, sbx, snapshotPath, memPath, err, l) return } l.Info().Msg("reaper: sandbox paused due to timeout") h.logActivityAsync(ctx, sbx.ID, sbx.TeamID, "sandbox", "timeout_paused", "success", &sbx.Name, nil, nil) + + // Async GC for the now-unreachable previous snapshot, if any. + h.cleanupOldSnapshotAsync(ctx, sbx.ID, sbx.TeamID, sbx.HostID, finalized.PrevSnapshotID) } // rollbackPausedVM is the saga compensation for a failed pause. The VM is @@ -199,8 +195,15 @@ func (h *Handlers) rollbackPausedVM(ctx context.Context, sbx db.ClaimExpiredSand AnErr("cause", cause). Logger() + vmd, vmdLookupErr := h.vmdForHost(ctx, sbx.HostID) + if vmdLookupErr != nil { + rl.Error().Err(vmdLookupErr).Msg("reaper: resolve VMD for rollback failed") + h.markSandboxFailed(ctx, sbx, "resolve VMD failed during rollback", rl) + return + } + vmdCtx, vmdCancel := context.WithTimeout(ctx, vmdTimeout) - _, err := h.VMD.ResumeInstance(vmdCtx, sbx.ID.String(), snapshotPath, memPath) + _, _, _, err := vmd.ResumeInstance(vmdCtx, sbx.ID.String(), snapshotPath, memPath, nil) vmdCancel() if err != nil { rl.Error().Err(err).Msg("reaper: rollback resume failed") diff --git a/internal/api/reaper_test.go b/internal/api/reaper_test.go index e12fecf..9ae3633 100644 --- a/internal/api/reaper_test.go +++ b/internal/api/reaper_test.go @@ -42,7 +42,7 @@ func (r *stubRows) Scan(dest ...any) error { *dest[1].(*uuid.UUID) = row.TeamID *dest[2].(*string) = row.Name *dest[3].(*pgtype.UUID) = row.SnapshotID - *dest[4].(**string) = row.HostID + *dest[4].(*string) = row.HostID return nil } @@ -78,8 +78,14 @@ func (m *reaperMockDBTX) QueryRow(ctx context.Context, sql string, args ...any) if m.queryRowFn != nil { return m.queryRowFn(ctx, sql, args...) } - // Route by SQL content: snapshot insert vs activity insert. - if strings.Contains(sql, "snapshot") { + // Route by SQL content: + // - FinalizePause returns only a single snapshot_id uuid + // - legacy CreateSnapshot returns a full Snapshot row + // - activity insert returns an activity row + switch { + case strings.Contains(sql, "new_snapshot AS"): + return finalizePauseRow(uuid.New(), pgtype.UUID{}) + case strings.Contains(sql, "INSERT INTO snapshot"): return reaperSnapshotRow() } return activityRow() @@ -151,21 +157,24 @@ func TestReaper_NothingExpired(t *testing.T) { } } -// TestReaper_VMDSucceeds verifies that a claimed sandbox triggers a VMD pause, -// snapshot creation, and status update to idle. +// TestReaper_VMDSucceeds verifies that a claimed sandbox triggers a VMD +// pause followed by the atomic FinalizePause bookkeeping query. func TestReaper_VMDSucceeds(t *testing.T) { row := expiredRow("sbx-a") var pausedID string - var execCalls []string + var finalizeCalls int32 h := newReaperHandlers( &reaperMockDBTX{ queryFn: func(_ context.Context, _ string, _ ...any) (pgx.Rows, error) { return newStubRows([]db.ClaimExpiredSandboxesRow{row}), nil }, - execFn: func(_ context.Context, sql string, args ...any) (pgconn.CommandTag, error) { - execCalls = append(execCalls, sql) - return pgconn.CommandTag{}, nil + queryRowFn: func(_ context.Context, sql string, _ ...any) pgx.Row { + if strings.Contains(sql, "new_snapshot AS") { + atomic.AddInt32(&finalizeCalls, 1) + return finalizePauseRow(uuid.New(), pgtype.UUID{}) + } + return activityRow() }, }, &stubVMD{pauseFn: func(_ context.Context, id string, _ string) (string, string, error) { @@ -179,10 +188,8 @@ func TestReaper_VMDSucceeds(t *testing.T) { if pausedID != row.ID.String() { t.Fatalf("expected PauseInstance called with %s, got %q", row.ID, pausedID) } - - // Expect SetSandboxSnapshot and UpdateSandboxStatus(idle) execs. - if len(execCalls) < 2 { - t.Fatalf("expected at least 2 exec calls (SetSandboxSnapshot + UpdateSandboxStatus), got %d", len(execCalls)) + if got := atomic.LoadInt32(&finalizeCalls); got != 1 { + t.Fatalf("expected exactly 1 FinalizePause call, got %d", got) } } diff --git a/internal/api/router.go b/internal/api/router.go index 26429b6..dbce613 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -2,9 +2,11 @@ package api import ( "context" + "net/http" "github.com/gin-gonic/gin" "github.com/jackc/pgx/v5/pgxpool" + "go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin" ) // SetupRouter creates and configures the Gin router with all route groups. @@ -14,9 +16,16 @@ import ( // router instance doesn't leak a cleanup goroutine. func SetupRouter(ctx context.Context, h *Handlers, pool *pgxpool.Pool) *gin.Engine { r := gin.New() - // Global middleware: security headers, coarse per-IP rate limit - // (unauthenticated flood protection), logging, panic recovery. + // otelgin runs first so spans cover all downstream middleware (rate + // limiter rejects, auth failures, panic recovery) — those are the + // things you most want a trace for. r.Use( + otelgin.Middleware("controlplane", + otelgin.WithFilter(func(req *http.Request) bool { + // Health checks would dominate trace volume with no signal. + return req.URL.Path != "/health" + }), + ), SecurityHeaders(), RateLimit(ctx, DefaultIPRateLimitConfig()), RequestLogger(), @@ -30,20 +39,7 @@ func SetupRouter(ctx context.Context, h *Handlers, pool *pgxpool.Pool) *gin.Engi // and becomes meaningless for fairness. api.Use(APIKeyAuth(pool), TeamRateLimit(ctx, DefaultTeamRateLimitConfig())) { - api.POST("/instances", h.CreateInstance) - api.GET("/instances", h.ListInstances) - api.GET("/instances/:instance_id", h.GetInstance) - api.DELETE("/instances/:instance_id", h.DeleteInstance) - api.POST("/instances/:instance_id/pause", h.PauseInstance) - api.POST("/instances/:instance_id/resume", h.ResumeInstance) - - api.POST("/instances/:instance_id/exec", h.ExecCommand) - api.POST("/instances/:instance_id/exec/stream", h.ExecCommandStream) - - api.PUT("/instances/:instance_id/files/*path", h.UploadFile) - api.GET("/instances/:instance_id/files/*path", h.DownloadFile) - - // Sandbox lifecycle (no auto-wake). + // Sandbox lifecycle. api.POST("/sandboxes", h.CreateSandbox) api.GET("/sandboxes", h.ListSandboxes) api.GET("/sandboxes/:sandbox_id", h.GetSandboxByID) @@ -52,18 +48,24 @@ func SetupRouter(ctx context.Context, h *Handlers, pool *pgxpool.Pool) *gin.Engi api.DELETE("/sandboxes/:sandbox_id", h.DeleteSandbox) api.PATCH("/sandboxes/:sandbox_id", h.PatchSandbox) - // Sandbox operations with auto-wake middleware. - sandboxOps := api.Group("/sandboxes/:sandbox_id") - sandboxOps.Use(h.AutoWake()) - { - sandboxOps.POST("/exec", h.ExecSandbox) - sandboxOps.POST("/exec/stream", h.ExecSandboxStream) - sandboxOps.PUT("/files/*path", h.UploadSandboxFile) - sandboxOps.GET("/files/*path", h.DownloadSandboxFile) - } + // Sandbox operations. Sandbox must already be active — paused + // sandboxes must be resumed explicitly via /resume. + api.POST("/sandboxes/:sandbox_id/exec", h.ExecSandbox) + api.POST("/sandboxes/:sandbox_id/exec/stream", h.ExecSandboxStream) } r.GET("/health", h.Health) + // Internal endpoints — authenticated via a shared token (not per-team + // API keys). Called by infrastructure components (VMD heartbeat) and + // not exposed to customers. The token is checked by InternalAuth + // middleware; if INTERNAL_API_TOKEN is unset, the middleware rejects + // all requests (fail-closed). + internal := r.Group("/internal") + internal.Use(InternalAuth()) + { + internal.POST("/hosts/:host_id/heartbeat", h.HostHeartbeat) + } + return r } diff --git a/internal/api/streaming.go b/internal/api/streaming.go index 23a94dc..2535ae3 100644 --- a/internal/api/streaming.go +++ b/internal/api/streaming.go @@ -22,10 +22,12 @@ type streamExecRequest struct { TimeoutS int `json:"timeout_s"` } -// ExecCommandStream runs a shell command inside an instance and streams output via SSE. -func (h *Handlers) ExecCommandStream(c *gin.Context) { - instanceID, err := parseInstanceID(c) - if err != nil { +// ExecSandboxStream runs a command inside a sandbox and streams output via SSE. +// The sandbox must already be active — callers must resume a paused sandbox +// via POST /sandboxes/:id/resume first. +func (h *Handlers) ExecSandboxStream(c *gin.Context) { + sandbox := h.loadActiveSandbox(c) + if sandbox == nil { return } @@ -51,70 +53,9 @@ func (h *Handlers) ExecCommandStream(c *gin.Context) { return } - err = h.VMD.ExecCommandStream(c.Request.Context(), instanceID.String(), - req.Command, req.Args, req.Env, req.WorkingDir, uint32(req.TimeoutS), - func(stdout, stderr []byte, exitCode int32, finished bool) { - event := gin.H{ - "timestamp": time.Now().Format(time.RFC3339Nano), - } - if len(stdout) > 0 { - event["stdout"] = string(stdout) - } - if len(stderr) > 0 { - event["stderr"] = string(stderr) - } - if finished { - event["exit_code"] = exitCode - event["finished"] = true - } - - data, marshalErr := json.Marshal(event) - if marshalErr != nil { - return - } - - fmt.Fprintf(c.Writer, "data: %s\n\n", data) - flusher.Flush() - }) - - if err != nil { - log.Error().Err(err).Str("instance_id", instanceID.String()).Msg("streaming exec failed") - errEvent, _ := json.Marshal(gin.H{ - "error": err.Error(), - "finished": true, - }) - fmt.Fprintf(c.Writer, "data: %s\n\n", errEvent) - flusher.Flush() - } -} - -// ExecSandboxStream runs a command inside a sandbox and streams output via SSE. -// The sandbox is loaded and auto-woken by the AutoWake middleware. -func (h *Handlers) ExecSandboxStream(c *gin.Context) { - sandbox := sandboxFromContext(c) - if sandbox == nil { - respondError(c, ErrInternal) - return - } - - var req streamExecRequest - if err := c.ShouldBindJSON(&req); err != nil { - respondErrorMsg(c, "bad_request", fmt.Sprintf("Validation failed: %v", err), http.StatusBadRequest) - return - } - - if req.TimeoutS <= 0 { - req.TimeoutS = 30 - } - - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - c.Status(http.StatusOK) - - flusher, ok := c.Writer.(http.Flusher) - if !ok { + vmd, vmdLookupErr := h.vmdForHost(c.Request.Context(), sandbox.HostID) + if vmdLookupErr != nil { + log.Error().Err(vmdLookupErr).Str("sandbox_id", sandbox.ID.String()).Msg("resolve VMD for exec stream failed") respondError(c, ErrInternal) return } @@ -122,7 +63,7 @@ func (h *Handlers) ExecSandboxStream(c *gin.Context) { start := time.Now() var lastExitCode int32 - err := h.VMD.ExecCommandStream(c.Request.Context(), sandbox.ID.String(), + err := vmd.ExecCommandStream(c.Request.Context(), sandbox.ID.String(), req.Command, req.Args, req.Env, req.WorkingDir, uint32(req.TimeoutS), func(stdout, stderr []byte, exitCode int32, finished bool) { event := gin.H{ @@ -150,19 +91,35 @@ func (h *Handlers) ExecSandboxStream(c *gin.Context) { }) if err != nil { - log.Error().Err(err).Str("sandbox_id", sandbox.ID.String()).Msg("streaming sandbox exec failed") - errEvent, _ := json.Marshal(gin.H{ - "error": err.Error(), - "finished": true, - }) - fmt.Fprintf(c.Writer, "data: %s\n\n", errEvent) - flusher.Flush() + // If VMD says the VM is gone, mark the sandbox failed. The HTTP + // response has already committed 200 OK (SSE headers flushed + // before the call), so we can't downgrade the status code — + // instead emit a "gone" event in the stream so clients can + // distinguish this from transient errors. + if isVMDNotFound(err) { + log.Warn().Err(err).Str("sandbox_id", sandbox.ID.String()).Msg("VMD ExecCommandStream: VM unavailable, marking sandbox failed") + h.markSandboxFailedAsync(c.Request.Context(), sandbox.ID, sandbox.TeamID) + errEvent, _ := json.Marshal(gin.H{ + "error": "sandbox VM is no longer available", + "code": "gone", + "finished": true, + }) + fmt.Fprintf(c.Writer, "data: %s\n\n", errEvent) + flusher.Flush() + } else { + log.Error().Err(err).Str("sandbox_id", sandbox.ID.String()).Msg("streaming sandbox exec failed") + errEvent, _ := json.Marshal(gin.H{ + "error": err.Error(), + "finished": true, + }) + fmt.Fprintf(c.Writer, "data: %s\n\n", errEvent) + flusher.Flush() + } } durationMs := int32(time.Since(start).Milliseconds()) // Async observability writes. - h.updateLastActivityAsync(c.Request.Context(), sandbox.ID, sandbox.TeamID) actStatus := "success" if err != nil { actStatus = "error" diff --git a/internal/api/vmd_errors.go b/internal/api/vmd_errors.go new file mode 100644 index 0000000..1005a29 --- /dev/null +++ b/internal/api/vmd_errors.go @@ -0,0 +1,55 @@ +package api + +import ( + "context" + + "github.com/google/uuid" + "github.com/rs/zerolog/log" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/superserve-ai/sandbox/internal/db" +) + +// ErrSandboxGone is returned when a handler detects that the underlying VM +// is gone (VMD returned NotFound and the sandbox should be marked failed). +// Maps to HTTP 410 Gone — the resource existed but is permanently lost. +var ErrSandboxGone = &AppError{ + Code: "gone", + Message: "Sandbox VM is no longer available and has been marked failed", + HTTPStatus: 410, +} + +// isVMDNotFound returns true when VMD reports that the VM is gone +// (gRPC NotFound). This covers two cases: +// - VMD never had the VM in its map (lost BoltDB entry) +// - VMD had the VM but detected the process is dead, cleaned up, +// and returned NotFound +// +// In both cases the sandbox should be marked failed and the client +// gets 410 Gone. +func isVMDNotFound(err error) bool { + if err == nil { + return false + } + return status.Code(err) == codes.NotFound +} + +// markSandboxFailedAsync writes status=failed in a detached goroutine. +// Used when a handler discovers (via VMD NotFound) that the VM is gone. +// Detaches cancellation so the state transition survives client disconnect, +// but keeps the request's trace context so the write appears in the same span. +func (h *Handlers) markSandboxFailedAsync(reqCtx context.Context, sandboxID, teamID uuid.UUID) { + asyncCtx := context.WithoutCancel(reqCtx) + go func() { + ctx, cancel := context.WithTimeout(asyncCtx, asyncTimeout) + defer cancel() + if err := h.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ + ID: sandboxID, + Status: db.SandboxStatusFailed, + TeamID: teamID, + }); err != nil { + log.Error().Err(err).Str("sandbox_id", sandboxID.String()).Msg("async mark-failed write failed") + } + }() +} diff --git a/internal/auth/sandbox_token.go b/internal/auth/sandbox_token.go new file mode 100644 index 0000000..f65a6de --- /dev/null +++ b/internal/auth/sandbox_token.go @@ -0,0 +1,37 @@ +package auth + +import ( + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "encoding/hex" + "fmt" +) + +// ComputeAccessToken derives a per-sandbox access token from the +// shared seed and the sandbox ID. The result is a stable hex-encoded +// HMAC-SHA256 digest — same inputs always produce the same output. +func ComputeAccessToken(seed []byte, sandboxID string) string { + mac := hmac.New(sha256.New, seed) + mac.Write([]byte(sandboxID)) + return hex.EncodeToString(mac.Sum(nil)) +} + +// VerifyAccessToken checks whether a presented token matches the +// expected HMAC for the given sandbox ID. Uses constant-time +// comparison to prevent timing side-channels. +func VerifyAccessToken(seed []byte, sandboxID, presentedToken string) bool { + expected := ComputeAccessToken(seed, sandboxID) + return subtle.ConstantTimeCompare([]byte(expected), []byte(presentedToken)) == 1 +} + +// ValidateSeed checks that a seed key is present and of reasonable length. +func ValidateSeed(seed []byte) error { + if len(seed) == 0 { + return fmt.Errorf("auth: sandbox access token seed is empty") + } + if len(seed) < 32 { + return fmt.Errorf("auth: sandbox access token seed is too short (%d bytes, want >= 32)", len(seed)) + } + return nil +} diff --git a/internal/auth/sandbox_token_test.go b/internal/auth/sandbox_token_test.go new file mode 100644 index 0000000..a0b991e --- /dev/null +++ b/internal/auth/sandbox_token_test.go @@ -0,0 +1,75 @@ +package auth + +import "testing" + +func TestComputeAccessToken_Deterministic(t *testing.T) { + seed := []byte("test-seed-key-that-is-at-least-32-bytes-long!!") + tok1 := ComputeAccessToken(seed, "sandbox-123") + tok2 := ComputeAccessToken(seed, "sandbox-123") + if tok1 != tok2 { + t.Errorf("same inputs produced different tokens: %q vs %q", tok1, tok2) + } +} + +func TestComputeAccessToken_DifferentSandboxesDiffer(t *testing.T) { + seed := []byte("test-seed-key-that-is-at-least-32-bytes-long!!") + a := ComputeAccessToken(seed, "sandbox-aaa") + b := ComputeAccessToken(seed, "sandbox-bbb") + if a == b { + t.Error("different sandbox IDs produced the same token") + } +} + +func TestComputeAccessToken_DifferentSeedsDiffer(t *testing.T) { + seedA := []byte("seed-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + seedB := []byte("seed-bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") + tok1 := ComputeAccessToken(seedA, "sandbox-123") + tok2 := ComputeAccessToken(seedB, "sandbox-123") + if tok1 == tok2 { + t.Error("different seeds produced the same token") + } +} + +func TestVerifyAccessToken_Valid(t *testing.T) { + seed := []byte("test-seed-key-that-is-at-least-32-bytes-long!!") + tok := ComputeAccessToken(seed, "sandbox-123") + if !VerifyAccessToken(seed, "sandbox-123", tok) { + t.Error("valid token rejected") + } +} + +func TestVerifyAccessToken_WrongToken(t *testing.T) { + seed := []byte("test-seed-key-that-is-at-least-32-bytes-long!!") + if VerifyAccessToken(seed, "sandbox-123", "totally-wrong") { + t.Error("wrong token accepted") + } +} + +func TestVerifyAccessToken_WrongSandbox(t *testing.T) { + seed := []byte("test-seed-key-that-is-at-least-32-bytes-long!!") + tok := ComputeAccessToken(seed, "sandbox-aaa") + if VerifyAccessToken(seed, "sandbox-bbb", tok) { + t.Error("token for sandbox-aaa accepted for sandbox-bbb") + } +} + +func TestValidateSeed_Empty(t *testing.T) { + if err := ValidateSeed(nil); err == nil { + t.Error("nil seed accepted") + } + if err := ValidateSeed([]byte{}); err == nil { + t.Error("empty seed accepted") + } +} + +func TestValidateSeed_TooShort(t *testing.T) { + if err := ValidateSeed([]byte("short")); err == nil { + t.Error("short seed accepted") + } +} + +func TestValidateSeed_Valid(t *testing.T) { + if err := ValidateSeed([]byte("this-is-a-valid-seed-with-32-byt")); err != nil { + t.Errorf("valid seed rejected: %v", err) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 6551697..3fe53b9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,9 +1,14 @@ -// Package config loads application configuration from environment variables. package config import ( + "crypto/rand" + "encoding/hex" "fmt" "os" + + "github.com/rs/zerolog/log" + + "github.com/superserve-ai/sandbox/internal/auth" ) // Config holds all configuration for the Superserve Sandbox control plane. @@ -11,23 +16,71 @@ type Config struct { Port string // API_PORT, default "8080" VMDAddress string // VMD_GRPC_ADDRESS, default "localhost:50051" DatabaseURL string // DATABASE_URL, required + + // SandboxAccessTokenSeed is the HMAC seed shared with the edge + // proxy. Both sides derive per-sandbox access tokens as + // HMAC-SHA256(seed, sandboxID). Loaded from SANDBOX_ACCESS_TOKEN_SEED + // (hex-encoded, >= 32 bytes). + SandboxAccessTokenSeed []byte + + // EdgeProxyDomain is the public hostname suffix served by the edge + // proxy, used to construct URLs in sandbox responses. + EdgeProxyDomain string + + // DefaultHostID is the fallback host identifier used when no scheduler + // is configured. Set via DEFAULT_HOST_ID; defaults to "default". + DefaultHostID string } -// Load reads configuration from environment variables, applying defaults where -// appropriate. +// Load reads configuration from environment variables. func Load() (*Config, error) { dbURL := os.Getenv("DATABASE_URL") if dbURL == "" { return nil, fmt.Errorf("DATABASE_URL is required") } + + seed, err := loadSeed( + os.Getenv("SANDBOX_ACCESS_TOKEN_SEED"), + os.Getenv("ALLOW_EPHEMERAL_SEED") == "1", + ) + if err != nil { + return nil, fmt.Errorf("SANDBOX_ACCESS_TOKEN_SEED: %w", err) + } + cfg := &Config{ - Port: envOrDefault("API_PORT", "8080"), - VMDAddress: envOrDefault("VMD_GRPC_ADDRESS", "localhost:50051"), - DatabaseURL: dbURL, + Port: envOrDefault("API_PORT", "8080"), + VMDAddress: envOrDefault("VMD_GRPC_ADDRESS", "localhost:50051"), + DatabaseURL: dbURL, + SandboxAccessTokenSeed: seed, + EdgeProxyDomain: envOrDefault("EDGE_PROXY_DOMAIN", "sandbox.superserve.ai"), + DefaultHostID: envOrDefault("DEFAULT_HOST_ID", "default"), } return cfg, nil } +func loadSeed(envValue string, allowEphemeral bool) ([]byte, error) { + if envValue == "" { + if !allowEphemeral { + return nil, fmt.Errorf("required in production; set ALLOW_EPHEMERAL_SEED=1 for local dev") + } + seed := make([]byte, 32) + if _, err := rand.Read(seed); err != nil { + return nil, fmt.Errorf("generate ephemeral seed: %w", err) + } + log.Warn().Msg("SANDBOX_ACCESS_TOKEN_SEED unset — generated ephemeral seed (DO NOT USE IN PRODUCTION)") + return seed, nil + } + + seed, err := hex.DecodeString(envValue) + if err != nil { + return nil, fmt.Errorf("not valid hex: %w", err) + } + if err := auth.ValidateSeed(seed); err != nil { + return nil, err + } + return seed, nil +} + func envOrDefault(key, fallback string) string { if v := os.Getenv(key); v != "" { return v diff --git a/internal/db/hosts.sql.go b/internal/db/hosts.sql.go new file mode 100644 index 0000000..981107a --- /dev/null +++ b/internal/db/hosts.sql.go @@ -0,0 +1,309 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: hosts.sql + +package db + +import ( + "context" + "time" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createHost = `-- name: CreateHost :one +INSERT INTO host (id, vmd_addr, proxy_addr, region, capacity_memory_mib, capacity_vcpus) +VALUES ($1, $2, $3, $4, $5, $6) +RETURNING id, vmd_addr, proxy_addr, region, status, capacity_memory_mib, capacity_vcpus, last_heartbeat_at, created_at, updated_at +` + +type CreateHostParams struct { + ID string `json:"id"` + VmdAddr string `json:"vmd_addr"` + ProxyAddr string `json:"proxy_addr"` + Region string `json:"region"` + CapacityMemoryMib int32 `json:"capacity_memory_mib"` + CapacityVcpus int32 `json:"capacity_vcpus"` +} + +func (q *Queries) CreateHost(ctx context.Context, arg CreateHostParams) (Host, error) { + row := q.db.QueryRow(ctx, createHost, + arg.ID, + arg.VmdAddr, + arg.ProxyAddr, + arg.Region, + arg.CapacityMemoryMib, + arg.CapacityVcpus, + ) + var i Host + err := row.Scan( + &i.ID, + &i.VmdAddr, + &i.ProxyAddr, + &i.Region, + &i.Status, + &i.CapacityMemoryMib, + &i.CapacityVcpus, + &i.LastHeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getHost = `-- name: GetHost :one +SELECT id, vmd_addr, proxy_addr, region, status, capacity_memory_mib, capacity_vcpus, last_heartbeat_at, created_at, updated_at FROM host WHERE id = $1 +` + +func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) { + row := q.db.QueryRow(ctx, getHost, id) + var i Host + err := row.Scan( + &i.ID, + &i.VmdAddr, + &i.ProxyAddr, + &i.Region, + &i.Status, + &i.CapacityMemoryMib, + &i.CapacityVcpus, + &i.LastHeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const listActiveHosts = `-- name: ListActiveHosts :many +SELECT id, vmd_addr, proxy_addr, region, status, capacity_memory_mib, capacity_vcpus, last_heartbeat_at, created_at, updated_at FROM host +WHERE status = 'active' +ORDER BY created_at ASC +` + +func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) { + rows, err := q.db.Query(ctx, listActiveHosts) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Host{} + for rows.Next() { + var i Host + if err := rows.Scan( + &i.ID, + &i.VmdAddr, + &i.ProxyAddr, + &i.Region, + &i.Status, + &i.CapacityMemoryMib, + &i.CapacityVcpus, + &i.LastHeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listActiveHostsByLoad = `-- name: ListActiveHostsByLoad :many +SELECT h.id, h.vmd_addr, h.proxy_addr, h.region, h.status, + h.capacity_memory_mib, h.capacity_vcpus, + h.last_heartbeat_at, h.created_at, h.updated_at, + COALESCE(COUNT(s.id), 0)::int AS active_sandbox_count +FROM host h +LEFT JOIN sandbox s ON s.host_id = h.id + AND s.status IN ('active', 'starting') + AND s.destroyed_at IS NULL +WHERE h.status = 'active' +GROUP BY h.id +ORDER BY COUNT(s.id) ASC +` + +type ListActiveHostsByLoadRow struct { + ID string `json:"id"` + VmdAddr string `json:"vmd_addr"` + ProxyAddr string `json:"proxy_addr"` + Region string `json:"region"` + Status string `json:"status"` + CapacityMemoryMib int32 `json:"capacity_memory_mib"` + CapacityVcpus int32 `json:"capacity_vcpus"` + LastHeartbeatAt pgtype.Timestamptz `json:"last_heartbeat_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + ActiveSandboxCount int32 `json:"active_sandbox_count"` +} + +// Returns active hosts sorted by current sandbox count (ascending). +// The scheduler picks the first row (least loaded host). One query +// replaces N per-host lookups. +func (q *Queries) ListActiveHostsByLoad(ctx context.Context) ([]ListActiveHostsByLoadRow, error) { + rows, err := q.db.Query(ctx, listActiveHostsByLoad) + if err != nil { + return nil, err + } + defer rows.Close() + items := []ListActiveHostsByLoadRow{} + for rows.Next() { + var i ListActiveHostsByLoadRow + if err := rows.Scan( + &i.ID, + &i.VmdAddr, + &i.ProxyAddr, + &i.Region, + &i.Status, + &i.CapacityMemoryMib, + &i.CapacityVcpus, + &i.LastHeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.ActiveSandboxCount, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listHosts = `-- name: ListHosts :many +SELECT id, vmd_addr, proxy_addr, region, status, capacity_memory_mib, capacity_vcpus, last_heartbeat_at, created_at, updated_at FROM host +ORDER BY created_at ASC +` + +func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { + rows, err := q.db.Query(ctx, listHosts) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Host{} + for rows.Next() { + var i Host + if err := rows.Scan( + &i.ID, + &i.VmdAddr, + &i.ProxyAddr, + &i.Region, + &i.Status, + &i.CapacityMemoryMib, + &i.CapacityVcpus, + &i.LastHeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listStaleHosts = `-- name: ListStaleHosts :many +SELECT id, vmd_addr, proxy_addr, region, status, capacity_memory_mib, capacity_vcpus, last_heartbeat_at, created_at, updated_at FROM host +WHERE status = 'active' + AND last_heartbeat_at IS NOT NULL + AND last_heartbeat_at < $1 +ORDER BY last_heartbeat_at ASC +` + +// Returns active hosts whose last heartbeat is older than the given +// threshold. Used by the unhealthy-host detector. +func (q *Queries) ListStaleHosts(ctx context.Context, lastHeartbeatAt pgtype.Timestamptz) ([]Host, error) { + rows, err := q.db.Query(ctx, listStaleHosts, lastHeartbeatAt) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Host{} + for rows.Next() { + var i Host + if err := rows.Scan( + &i.ID, + &i.VmdAddr, + &i.ProxyAddr, + &i.Region, + &i.Status, + &i.CapacityMemoryMib, + &i.CapacityVcpus, + &i.LastHeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const markHostUnhealthy = `-- name: MarkHostUnhealthy :exec +UPDATE host +SET status = 'unhealthy', updated_at = now() +WHERE id = $1 AND status = 'active' +` + +func (q *Queries) MarkHostUnhealthy(ctx context.Context, id string) error { + _, err := q.db.Exec(ctx, markHostUnhealthy, id) + return err +} + +const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :one +UPDATE host +SET last_heartbeat_at = now(), + status = CASE WHEN status = 'unhealthy' THEN 'active' ELSE status END, + updated_at = now() +WHERE id = $1 +RETURNING id, vmd_addr, proxy_addr, region, status, capacity_memory_mib, capacity_vcpus, last_heartbeat_at, created_at, updated_at +` + +// Returns the host row so the caller can verify the host exists. Also +// re-activates unhealthy hosts that resume heartbeating — this is the +// automatic recovery path after a transient network outage. +func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id string) (Host, error) { + row := q.db.QueryRow(ctx, updateHostHeartbeat, id) + var i Host + err := row.Scan( + &i.ID, + &i.VmdAddr, + &i.ProxyAddr, + &i.Region, + &i.Status, + &i.CapacityMemoryMib, + &i.CapacityVcpus, + &i.LastHeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const updateHostStatus = `-- name: UpdateHostStatus :exec +UPDATE host +SET status = $2, updated_at = now() +WHERE id = $1 +` + +type UpdateHostStatusParams struct { + ID string `json:"id"` + Status string `json:"status"` +} + +func (q *Queries) UpdateHostStatus(ctx context.Context, arg UpdateHostStatusParams) error { + _, err := q.db.Exec(ctx, updateHostStatus, arg.ID, arg.Status) + return err +} diff --git a/internal/db/models.go b/internal/db/models.go index 22b23a4..edbe20c 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -20,7 +20,7 @@ const ( SandboxStatusStarting SandboxStatus = "starting" SandboxStatusActive SandboxStatus = "active" SandboxStatusPausing SandboxStatus = "pausing" - SandboxStatusIdle SandboxStatus = "idle" + SandboxStatusPaused SandboxStatus = "paused" SandboxStatusDeleted SandboxStatus = "deleted" SandboxStatusFailed SandboxStatus = "failed" ) @@ -107,6 +107,19 @@ type EarlyAccessRequest struct { CreatedAt time.Time `json:"created_at"` } +type Host struct { + ID string `json:"id"` + VmdAddr string `json:"vmd_addr"` + ProxyAddr string `json:"proxy_addr"` + Region string `json:"region"` + Status string `json:"status"` + CapacityMemoryMib int32 `json:"capacity_memory_mib"` + CapacityVcpus int32 `json:"capacity_vcpus"` + LastHeartbeatAt pgtype.Timestamptz `json:"last_heartbeat_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + type Profile struct { ID uuid.UUID `json:"id"` Email string `json:"email"` @@ -120,22 +133,31 @@ type Profile struct { UpdatedAt time.Time `json:"updated_at"` } +type ReconcilerLog struct { + ID int64 `json:"id"` + HostID string `json:"host_id"` + SandboxID pgtype.UUID `json:"sandbox_id"` + Action string `json:"action"` + Reason string `json:"reason"` + DriftKind *string `json:"drift_kind"` + CreatedAt time.Time `json:"created_at"` +} + type Sandbox struct { - ID uuid.UUID `json:"id"` - TeamID uuid.UUID `json:"team_id"` - Name string `json:"name"` - Status SandboxStatus `json:"status"` - VcpuCount int32 `json:"vcpu_count"` - MemoryMib int32 `json:"memory_mib"` - HostID *string `json:"host_id"` - IpAddress *netip.Addr `json:"ip_address"` - Pid *int32 `json:"pid"` - SnapshotID pgtype.UUID `json:"snapshot_id"` - LastActivityAt time.Time `json:"last_activity_at"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - DestroyedAt pgtype.Timestamptz `json:"destroyed_at"` - NetworkConfig []byte `json:"network_config"` + ID uuid.UUID `json:"id"` + TeamID uuid.UUID `json:"team_id"` + Name string `json:"name"` + Status SandboxStatus `json:"status"` + VcpuCount int32 `json:"vcpu_count"` + MemoryMib int32 `json:"memory_mib"` + HostID string `json:"host_id"` + IpAddress *netip.Addr `json:"ip_address"` + Pid *int32 `json:"pid"` + SnapshotID pgtype.UUID `json:"snapshot_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DestroyedAt pgtype.Timestamptz `json:"destroyed_at"` + NetworkConfig []byte `json:"network_config"` // Hard lifetime cap in seconds from created_at. NULL = no cap. The reaper destroys the sandbox when now() > created_at + (timeout_seconds || ' seconds')::interval, regardless of state (active, paused, idle). TimeoutSeconds *int32 `json:"timeout_seconds"` // User-supplied flat string→string tags attached at creation. Immutable. Filterable on list endpoints via jsonb @> containment. Always non-null; an absent value is the empty object {}, never NULL. @@ -152,6 +174,7 @@ type Snapshot struct { Name *string `json:"name"` Trigger string `json:"trigger"` CreatedAt time.Time `json:"created_at"` + MemPath *string `json:"mem_path"` } type Team struct { diff --git a/internal/db/reconciler_log.sql.go b/internal/db/reconciler_log.sql.go new file mode 100644 index 0000000..6f58983 --- /dev/null +++ b/internal/db/reconciler_log.sql.go @@ -0,0 +1,76 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: reconciler_log.sql + +package db + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const insertReconcilerLog = `-- name: InsertReconcilerLog :exec +INSERT INTO reconciler_log (host_id, sandbox_id, action, reason, drift_kind) +VALUES ($1, $2, $3, $4, $5) +` + +type InsertReconcilerLogParams struct { + HostID string `json:"host_id"` + SandboxID pgtype.UUID `json:"sandbox_id"` + Action string `json:"action"` + Reason string `json:"reason"` + DriftKind *string `json:"drift_kind"` +} + +func (q *Queries) InsertReconcilerLog(ctx context.Context, arg InsertReconcilerLogParams) error { + _, err := q.db.Exec(ctx, insertReconcilerLog, + arg.HostID, + arg.SandboxID, + arg.Action, + arg.Reason, + arg.DriftKind, + ) + return err +} + +const listReconcilerLogByHost = `-- name: ListReconcilerLogByHost :many +SELECT id, host_id, sandbox_id, action, reason, drift_kind, created_at FROM reconciler_log +WHERE host_id = $1 +ORDER BY created_at DESC +LIMIT $2 +` + +type ListReconcilerLogByHostParams struct { + HostID string `json:"host_id"` + Limit int32 `json:"limit"` +} + +func (q *Queries) ListReconcilerLogByHost(ctx context.Context, arg ListReconcilerLogByHostParams) ([]ReconcilerLog, error) { + rows, err := q.db.Query(ctx, listReconcilerLogByHost, arg.HostID, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + items := []ReconcilerLog{} + for rows.Next() { + var i ReconcilerLog + if err := rows.Scan( + &i.ID, + &i.HostID, + &i.SandboxID, + &i.Action, + &i.Reason, + &i.DriftKind, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/db/sandboxes.sql.go b/internal/db/sandboxes.sql.go index 54e6e4e..4c66ea4 100644 --- a/internal/db/sandboxes.sql.go +++ b/internal/db/sandboxes.sql.go @@ -8,12 +8,83 @@ package db import ( "context" "net/netip" - "time" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" ) +const activateSandbox = `-- name: ActivateSandbox :exec +UPDATE sandbox +SET status = 'active', + vcpu_count = $2, + memory_mib = $3, + ip_address = $4, + updated_at = now() +WHERE id = $1 AND team_id = $5 AND destroyed_at IS NULL +` + +type ActivateSandboxParams struct { + ID uuid.UUID `json:"id"` + VcpuCount int32 `json:"vcpu_count"` + MemoryMib int32 `json:"memory_mib"` + IpAddress *netip.Addr `json:"ip_address"` + TeamID uuid.UUID `json:"team_id"` +} + +func (q *Queries) ActivateSandbox(ctx context.Context, arg ActivateSandboxParams) error { + _, err := q.db.Exec(ctx, activateSandbox, + arg.ID, + arg.VcpuCount, + arg.MemoryMib, + arg.IpAddress, + arg.TeamID, + ) + return err +} + +const beginPause = `-- name: BeginPause :one +UPDATE sandbox +SET status = 'pausing', updated_at = now() +WHERE id = $1 AND team_id = $2 AND destroyed_at IS NULL AND status = 'active' +RETURNING id, team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, created_at, updated_at, destroyed_at, network_config, timeout_seconds, metadata +` + +type BeginPauseParams struct { + ID uuid.UUID `json:"id"` + TeamID uuid.UUID `json:"team_id"` +} + +// Atomic ownership + state check + transition to 'pausing'. Replaces the +// GetSandbox → check status → UpdateSandboxStatus sequence on the pause +// hot path, collapsing two DB roundtrips into one. The WHERE clause +// enforces the invariant (only active, non-deleted sandboxes owned by +// this team can be paused); a 0-row result means "no such sandbox OR +// wrong team OR not currently active", and the caller disambiguates via +// a fallback GetSandbox in the rare error path. +func (q *Queries) BeginPause(ctx context.Context, arg BeginPauseParams) (Sandbox, error) { + row := q.db.QueryRow(ctx, beginPause, arg.ID, arg.TeamID) + var i Sandbox + err := row.Scan( + &i.ID, + &i.TeamID, + &i.Name, + &i.Status, + &i.VcpuCount, + &i.MemoryMib, + &i.HostID, + &i.IpAddress, + &i.Pid, + &i.SnapshotID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DestroyedAt, + &i.NetworkConfig, + &i.TimeoutSeconds, + &i.Metadata, + ) + return i, err +} + const claimExpiredSandboxes = `-- name: ClaimExpiredSandboxes :many WITH expired AS ( SELECT id, team_id, name, snapshot_id, host_id @@ -39,7 +110,7 @@ type ClaimExpiredSandboxesRow struct { TeamID uuid.UUID `json:"team_id"` Name string `json:"name"` SnapshotID pgtype.UUID `json:"snapshot_id"` - HostID *string `json:"host_id"` + HostID string `json:"host_id"` } // Atomically claims active sandboxes whose hard timeout has elapsed and marks @@ -47,7 +118,7 @@ type ClaimExpiredSandboxesRow struct { // skip rows already being processed, so multi-replica Cloud Run deployments // do not double-process the same sandbox. // -// Only 'active' sandboxes are claimed — idle sandboxes are already stopped, +// Only 'active' sandboxes are claimed — paused sandboxes are already stopped, // and transient states (starting, pausing) are skipped to avoid racing with // in-progress operations. The 60-second grace window prevents reaping a sandbox // that was just created with a very short timeout before it finishes starting up. @@ -78,18 +149,19 @@ func (q *Queries) ClaimExpiredSandboxes(ctx context.Context, limit int32) ([]Cla } const createSandbox = `-- name: CreateSandbox :one -INSERT INTO sandbox (team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, timeout_seconds, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) -RETURNING id, team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, last_activity_at, created_at, updated_at, destroyed_at, network_config, timeout_seconds, metadata +INSERT INTO sandbox (id, team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, timeout_seconds, metadata) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) +RETURNING id, team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, created_at, updated_at, destroyed_at, network_config, timeout_seconds, metadata ` type CreateSandboxParams struct { + ID uuid.UUID `json:"id"` TeamID uuid.UUID `json:"team_id"` Name string `json:"name"` Status SandboxStatus `json:"status"` VcpuCount int32 `json:"vcpu_count"` MemoryMib int32 `json:"memory_mib"` - HostID *string `json:"host_id"` + HostID string `json:"host_id"` IpAddress *netip.Addr `json:"ip_address"` Pid *int32 `json:"pid"` SnapshotID pgtype.UUID `json:"snapshot_id"` @@ -97,8 +169,13 @@ type CreateSandboxParams struct { Metadata []byte `json:"metadata"` } +// ID is supplied by the caller (generated in Go via uuid.New()) rather +// than defaulted in SQL, so the caller can parallelize this INSERT with +// the VMD CreateVM call — both need the same sandbox_id and generating +// it client-side lets them run concurrently instead of strictly serially. func (q *Queries) CreateSandbox(ctx context.Context, arg CreateSandboxParams) (Sandbox, error) { row := q.db.QueryRow(ctx, createSandbox, + arg.ID, arg.TeamID, arg.Name, arg.Status, @@ -123,7 +200,6 @@ func (q *Queries) CreateSandbox(ctx context.Context, arg CreateSandboxParams) (S &i.IpAddress, &i.Pid, &i.SnapshotID, - &i.LastActivityAt, &i.CreatedAt, &i.UpdatedAt, &i.DestroyedAt, @@ -134,6 +210,27 @@ func (q *Queries) CreateSandbox(ctx context.Context, arg CreateSandboxParams) (S return i, err } +const deleteSnapshotRow = `-- name: DeleteSnapshotRow :execrows +DELETE FROM snapshot +WHERE id = $1 AND team_id = $2 AND saved = false +` + +type DeleteSnapshotRowParams struct { + ID uuid.UUID `json:"id"` + TeamID uuid.UUID `json:"team_id"` +} + +// Remove a snapshot row. Guarded by saved=false so a future template feature +// can rely on row durability — auto-GC callers cannot accidentally nuke a +// user-saved snapshot even if they passed the wrong ID. +func (q *Queries) DeleteSnapshotRow(ctx context.Context, arg DeleteSnapshotRowParams) (int64, error) { + result, err := q.db.Exec(ctx, deleteSnapshotRow, arg.ID, arg.TeamID) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} + const destroySandbox = `-- name: DestroySandbox :exec UPDATE sandbox SET destroyed_at = now(), status = 'deleted', updated_at = now() @@ -150,8 +247,82 @@ func (q *Queries) DestroySandbox(ctx context.Context, arg DestroySandboxParams) return err } +const finalizePause = `-- name: FinalizePause :one +WITH target AS ( + SELECT id, team_id, snapshot_id AS prev_snapshot_id FROM sandbox + WHERE id = $1 AND team_id = $2 AND destroyed_at IS NULL +), +new_snapshot AS ( + INSERT INTO snapshot (sandbox_id, team_id, path, mem_path, size_bytes, saved, name, trigger) + SELECT target.id, target.team_id, $3, $4, $5, $6, $7, $8 FROM target + RETURNING snapshot.id AS snap_id +) +UPDATE sandbox +SET snapshot_id = (SELECT snap_id FROM new_snapshot), + status = 'paused', + updated_at = now() +FROM new_snapshot +WHERE sandbox.id = $1 AND sandbox.team_id = $2 AND sandbox.destroyed_at IS NULL +RETURNING + new_snapshot.snap_id::uuid AS snapshot_id, + (SELECT prev_snapshot_id FROM target) AS prev_snapshot_id +` + +type FinalizePauseParams struct { + ID uuid.UUID `json:"id"` + TeamID uuid.UUID `json:"team_id"` + Path string `json:"path"` + MemPath *string `json:"mem_path"` + SizeBytes int64 `json:"size_bytes"` + Saved bool `json:"saved"` + Name *string `json:"name"` + Trigger string `json:"trigger"` +} + +type FinalizePauseRow struct { + SnapshotID uuid.UUID `json:"snapshot_id"` + PrevSnapshotID pgtype.UUID `json:"prev_snapshot_id"` +} + +// Atomically insert the snapshot row, link it to the sandbox, and flip +// status from 'pausing' to 'paused'. Replaces the sequence +// CreateSnapshot → SetSandboxSnapshot → UpdateSandboxStatus, collapsing +// three DB roundtrips into one. +// +// The INSERT is gated on a `WHERE EXISTS` against a non-deleted sandbox +// in the same query. This prevents the common race where a sandbox is +// soft-deleted before FinalizePause runs — without the gate, the CTE +// INSERT would always execute (per PostgreSQL's rule that data-modifying +// CTEs run independently of the main query), producing an orphan snapshot +// row and a snapshot file on disk with no owner. A concurrent delete that +// commits BETWEEN the EXISTS check and the INSERT under READ COMMITTED +// can still race, but that window is microseconds and the resulting +// orphan is detectable/cleanable by a background job. +// +// Also captures the sandbox's previous snapshot_id (before we overwrite it) +// so the caller can garbage-collect the now-unreachable prior snapshot +// asynchronously. Returns NULL for the first pause of a sandbox. +// +// When either the sandbox is missing/deleted or the INSERT did not fire, +// the query returns 0 rows and the caller maps that to ErrSandboxGone. +func (q *Queries) FinalizePause(ctx context.Context, arg FinalizePauseParams) (FinalizePauseRow, error) { + row := q.db.QueryRow(ctx, finalizePause, + arg.ID, + arg.TeamID, + arg.Path, + arg.MemPath, + arg.SizeBytes, + arg.Saved, + arg.Name, + arg.Trigger, + ) + var i FinalizePauseRow + err := row.Scan(&i.SnapshotID, &i.PrevSnapshotID) + return i, err +} + const getSandbox = `-- name: GetSandbox :one -SELECT id, team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, last_activity_at, created_at, updated_at, destroyed_at, network_config, timeout_seconds, metadata FROM sandbox +SELECT id, team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, created_at, updated_at, destroyed_at, network_config, timeout_seconds, metadata FROM sandbox WHERE id = $1 AND team_id = $2 AND destroyed_at IS NULL ` @@ -174,7 +345,6 @@ func (q *Queries) GetSandbox(ctx context.Context, arg GetSandboxParams) (Sandbox &i.IpAddress, &i.Pid, &i.SnapshotID, - &i.LastActivityAt, &i.CreatedAt, &i.UpdatedAt, &i.DestroyedAt, @@ -202,16 +372,51 @@ func (q *Queries) GetSandboxNetworkConfig(ctx context.Context, arg GetSandboxNet return network_config, err } -const listIdleSandboxes = `-- name: ListIdleSandboxes :many -SELECT id, team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, last_activity_at, created_at, updated_at, destroyed_at, network_config, timeout_seconds, metadata FROM sandbox -WHERE status = 'idle' - AND destroyed_at IS NULL - AND last_activity_at < $1 -ORDER BY last_activity_at ASC +const getSnapshotForCleanup = `-- name: GetSnapshotForCleanup :one +SELECT id, team_id, path, mem_path +FROM snapshot +WHERE id = $1 AND team_id = $2 AND saved = false ` -func (q *Queries) ListIdleSandboxes(ctx context.Context, lastActivityAt time.Time) ([]Sandbox, error) { - rows, err := q.db.Query(ctx, listIdleSandboxes, lastActivityAt) +type GetSnapshotForCleanupParams struct { + ID uuid.UUID `json:"id"` + TeamID uuid.UUID `json:"team_id"` +} + +type GetSnapshotForCleanupRow struct { + ID uuid.UUID `json:"id"` + TeamID uuid.UUID `json:"team_id"` + Path string `json:"path"` + MemPath *string `json:"mem_path"` +} + +// Fetch a snapshot's paths for garbage collection. Returns only non-saved +// snapshots — saved=true rows are reserved for the (future) user-named +// template feature and must never be auto-deleted. A 0-row result means +// the row was already gone or is a saved snapshot; either way the caller +// should skip deletion. +func (q *Queries) GetSnapshotForCleanup(ctx context.Context, arg GetSnapshotForCleanupParams) (GetSnapshotForCleanupRow, error) { + row := q.db.QueryRow(ctx, getSnapshotForCleanup, arg.ID, arg.TeamID) + var i GetSnapshotForCleanupRow + err := row.Scan( + &i.ID, + &i.TeamID, + &i.Path, + &i.MemPath, + ) + return i, err +} + +const listSandboxesByHost = `-- name: ListSandboxesByHost :many +SELECT id, team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, created_at, updated_at, destroyed_at, network_config, timeout_seconds, metadata FROM sandbox +WHERE host_id = $1 AND destroyed_at IS NULL +` + +// Used by the VMD reconciler to find all non-deleted sandboxes scheduled on +// this host. Includes both active and paused sandboxes because the reconciler +// needs to validate both states (active → systemd unit, paused → snapshot file). +func (q *Queries) ListSandboxesByHost(ctx context.Context, hostID string) ([]Sandbox, error) { + rows, err := q.db.Query(ctx, listSandboxesByHost, hostID) if err != nil { return nil, err } @@ -230,7 +435,6 @@ func (q *Queries) ListIdleSandboxes(ctx context.Context, lastActivityAt time.Tim &i.IpAddress, &i.Pid, &i.SnapshotID, - &i.LastActivityAt, &i.CreatedAt, &i.UpdatedAt, &i.DestroyedAt, @@ -249,7 +453,7 @@ func (q *Queries) ListIdleSandboxes(ctx context.Context, lastActivityAt time.Tim } const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many -SELECT id, team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, last_activity_at, created_at, updated_at, destroyed_at, network_config, timeout_seconds, metadata FROM sandbox +SELECT id, team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, created_at, updated_at, destroyed_at, network_config, timeout_seconds, metadata FROM sandbox WHERE team_id = $1 AND destroyed_at IS NULL ORDER BY created_at DESC ` @@ -274,7 +478,6 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID uuid.UUID) ([] &i.IpAddress, &i.Pid, &i.SnapshotID, - &i.LastActivityAt, &i.CreatedAt, &i.UpdatedAt, &i.DestroyedAt, @@ -293,7 +496,7 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID uuid.UUID) ([] } const listSandboxesByTeamWithFilter = `-- name: ListSandboxesByTeamWithFilter :many -SELECT id, team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, last_activity_at, created_at, updated_at, destroyed_at, network_config, timeout_seconds, metadata FROM sandbox +SELECT id, team_id, name, status, vcpu_count, memory_mib, host_id, ip_address, pid, snapshot_id, created_at, updated_at, destroyed_at, network_config, timeout_seconds, metadata FROM sandbox WHERE team_id = $1 AND destroyed_at IS NULL AND metadata @> $2 @@ -329,7 +532,6 @@ func (q *Queries) ListSandboxesByTeamWithFilter(ctx context.Context, arg ListSan &i.IpAddress, &i.Pid, &i.SnapshotID, - &i.LastActivityAt, &i.CreatedAt, &i.UpdatedAt, &i.DestroyedAt, @@ -347,6 +549,20 @@ func (q *Queries) ListSandboxesByTeamWithFilter(ctx context.Context, arg ListSan return items, nil } +const markSandboxFailed = `-- name: MarkSandboxFailed :exec +UPDATE sandbox +SET status = 'failed', updated_at = now() +WHERE id = $1 AND destroyed_at IS NULL +` + +// Used by the reconciler to mark a sandbox failed when VMD detects it is +// actually gone. No team_id filter — the reconciler runs with host scope, +// not team scope. +func (q *Queries) MarkSandboxFailed(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, markSandboxFailed, id) + return err +} + const sandboxExists = `-- name: SandboxExists :one SELECT EXISTS(SELECT 1 FROM sandbox WHERE id = $1 AND team_id = $2 AND destroyed_at IS NULL) ` @@ -388,7 +604,7 @@ WHERE id = $1 AND team_id = $5 AND destroyed_at IS NULL type UpdateSandboxHostParams struct { ID uuid.UUID `json:"id"` - HostID *string `json:"host_id"` + HostID string `json:"host_id"` IpAddress *netip.Addr `json:"ip_address"` Pid *int32 `json:"pid"` TeamID uuid.UUID `json:"team_id"` @@ -405,22 +621,6 @@ func (q *Queries) UpdateSandboxHost(ctx context.Context, arg UpdateSandboxHostPa return err } -const updateSandboxLastActivity = `-- name: UpdateSandboxLastActivity :exec -UPDATE sandbox -SET last_activity_at = now(), updated_at = now() -WHERE id = $1 AND team_id = $2 AND destroyed_at IS NULL -` - -type UpdateSandboxLastActivityParams struct { - ID uuid.UUID `json:"id"` - TeamID uuid.UUID `json:"team_id"` -} - -func (q *Queries) UpdateSandboxLastActivity(ctx context.Context, arg UpdateSandboxLastActivityParams) error { - _, err := q.db.Exec(ctx, updateSandboxLastActivity, arg.ID, arg.TeamID) - return err -} - const updateSandboxMetadata = `-- name: UpdateSandboxMetadata :exec UPDATE sandbox SET metadata = $2, updated_at = now() diff --git a/internal/db/snapshots.sql.go b/internal/db/snapshots.sql.go index d67e982..04a57bb 100644 --- a/internal/db/snapshots.sql.go +++ b/internal/db/snapshots.sql.go @@ -12,15 +12,16 @@ import ( ) const createSnapshot = `-- name: CreateSnapshot :one -INSERT INTO snapshot (sandbox_id, team_id, path, size_bytes, saved, name, trigger) -VALUES ($1, $2, $3, $4, $5, $6, $7) -RETURNING id, sandbox_id, team_id, path, size_bytes, saved, name, trigger, created_at +INSERT INTO snapshot (sandbox_id, team_id, path, mem_path, size_bytes, saved, name, trigger) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +RETURNING id, sandbox_id, team_id, path, size_bytes, saved, name, trigger, created_at, mem_path ` type CreateSnapshotParams struct { SandboxID uuid.UUID `json:"sandbox_id"` TeamID uuid.UUID `json:"team_id"` Path string `json:"path"` + MemPath *string `json:"mem_path"` SizeBytes int64 `json:"size_bytes"` Saved bool `json:"saved"` Name *string `json:"name"` @@ -32,6 +33,7 @@ func (q *Queries) CreateSnapshot(ctx context.Context, arg CreateSnapshotParams) arg.SandboxID, arg.TeamID, arg.Path, + arg.MemPath, arg.SizeBytes, arg.Saved, arg.Name, @@ -48,6 +50,7 @@ func (q *Queries) CreateSnapshot(ctx context.Context, arg CreateSnapshotParams) &i.Name, &i.Trigger, &i.CreatedAt, + &i.MemPath, ) return i, err } @@ -63,12 +66,46 @@ func (q *Queries) DeleteSnapshot(ctx context.Context, id uuid.UUID) error { } const getSnapshot = `-- name: GetSnapshot :one -SELECT id, sandbox_id, team_id, path, size_bytes, saved, name, trigger, created_at FROM snapshot +SELECT id, sandbox_id, team_id, path, size_bytes, saved, name, trigger, created_at, mem_path FROM snapshot +WHERE id = $1 AND team_id = $2 +` + +type GetSnapshotParams struct { + ID uuid.UUID `json:"id"` + TeamID uuid.UUID `json:"team_id"` +} + +// Team-scoped snapshot lookup for user-facing handlers. The join on +// team_id enforces tenant isolation at the SQL layer so callers cannot +// accidentally leak another team's snapshot metadata by forgetting the +// in-Go team check. +func (q *Queries) GetSnapshot(ctx context.Context, arg GetSnapshotParams) (Snapshot, error) { + row := q.db.QueryRow(ctx, getSnapshot, arg.ID, arg.TeamID) + var i Snapshot + err := row.Scan( + &i.ID, + &i.SandboxID, + &i.TeamID, + &i.Path, + &i.SizeBytes, + &i.Saved, + &i.Name, + &i.Trigger, + &i.CreatedAt, + &i.MemPath, + ) + return i, err +} + +const getSnapshotByID = `-- name: GetSnapshotByID :one +SELECT id, sandbox_id, team_id, path, size_bytes, saved, name, trigger, created_at, mem_path FROM snapshot WHERE id = $1 ` -func (q *Queries) GetSnapshot(ctx context.Context, id uuid.UUID) (Snapshot, error) { - row := q.db.QueryRow(ctx, getSnapshot, id) +// Unscoped snapshot lookup for internal (host-scoped) code paths such as +// the VMD reconciler. DO NOT call from user-facing handlers. +func (q *Queries) GetSnapshotByID(ctx context.Context, id uuid.UUID) (Snapshot, error) { + row := q.db.QueryRow(ctx, getSnapshotByID, id) var i Snapshot err := row.Scan( &i.ID, @@ -80,12 +117,13 @@ func (q *Queries) GetSnapshot(ctx context.Context, id uuid.UUID) (Snapshot, erro &i.Name, &i.Trigger, &i.CreatedAt, + &i.MemPath, ) return i, err } const listSnapshotsBySandbox = `-- name: ListSnapshotsBySandbox :many -SELECT id, sandbox_id, team_id, path, size_bytes, saved, name, trigger, created_at FROM snapshot +SELECT id, sandbox_id, team_id, path, size_bytes, saved, name, trigger, created_at, mem_path FROM snapshot WHERE sandbox_id = $1 ORDER BY created_at DESC ` @@ -109,6 +147,7 @@ func (q *Queries) ListSnapshotsBySandbox(ctx context.Context, sandboxID uuid.UUI &i.Name, &i.Trigger, &i.CreatedAt, + &i.MemPath, ); err != nil { return nil, err } @@ -121,7 +160,7 @@ func (q *Queries) ListSnapshotsBySandbox(ctx context.Context, sandboxID uuid.UUI } const listSnapshotsByTeam = `-- name: ListSnapshotsByTeam :many -SELECT id, sandbox_id, team_id, path, size_bytes, saved, name, trigger, created_at FROM snapshot +SELECT id, sandbox_id, team_id, path, size_bytes, saved, name, trigger, created_at, mem_path FROM snapshot WHERE team_id = $1 ORDER BY created_at DESC ` @@ -145,6 +184,7 @@ func (q *Queries) ListSnapshotsByTeam(ctx context.Context, teamID uuid.UUID) ([] &i.Name, &i.Trigger, &i.CreatedAt, + &i.MemPath, ); err != nil { return nil, err } diff --git a/internal/hostreg/registry.go b/internal/hostreg/registry.go new file mode 100644 index 0000000..ecd4f31 --- /dev/null +++ b/internal/hostreg/registry.go @@ -0,0 +1,63 @@ +package hostreg + +import ( + "context" + "fmt" + "sync" + + "github.com/superserve-ai/sandbox/internal/db" + "github.com/superserve-ai/sandbox/internal/vmdclient" +) + +// DialFunc creates a VMD client for the given gRPC address. +type DialFunc func(addr string) (vmdclient.Client, error) + +// Registry maps host IDs to VMD clients. Clients are lazily created on +// first use and cached. +type Registry struct { + db *db.Queries + dial DialFunc + mu sync.RWMutex + clients map[string]vmdclient.Client +} + +// New creates a Registry backed by the host table. +func New(queries *db.Queries, dial DialFunc) *Registry { + return &Registry{ + db: queries, + dial: dial, + clients: make(map[string]vmdclient.Client), + } +} + +// ClientFor returns the VMD client for the given host. It looks up the host +// in the DB on first access, dials gRPC, and caches the result. +func (r *Registry) ClientFor(ctx context.Context, hostID string) (vmdclient.Client, error) { + r.mu.RLock() + c, ok := r.clients[hostID] + r.mu.RUnlock() + if ok { + return c, nil + } + + r.mu.Lock() + defer r.mu.Unlock() + + // Double-check after acquiring write lock. + if c, ok := r.clients[hostID]; ok { + return c, nil + } + + host, err := r.db.GetHost(ctx, hostID) + if err != nil { + return nil, fmt.Errorf("get host %q: %w", hostID, err) + } + + c, err = r.dial(host.VmdAddr) + if err != nil { + return nil, fmt.Errorf("dial VMD at %s for host %q: %w", host.VmdAddr, hostID, err) + } + + r.clients[hostID] = c + return c, nil +} diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index ee21245..8ba6f67 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -114,15 +114,18 @@ func applyMigrations(ctx context.Context, pool *pgxpool.Pool) error { // values so that HTTP handlers can complete and write to the DB. type stubVMD struct{} -func (s *stubVMD) CreateInstance(_ context.Context, _ string, _, _, _ uint32, _ map[string]string) (string, error) { - return "10.0.0.1", nil +func (s *stubVMD) CreateInstance(_ context.Context, _ string, _, _, _ uint32, _ map[string]string, _ map[string]string) (string, uint32, uint32, error) { + return "10.0.0.1", 1, 1024, nil } func (s *stubVMD) DestroyInstance(_ context.Context, _ string, _ bool) error { return nil } func (s *stubVMD) PauseInstance(_ context.Context, _, _ string) (string, string, error) { return "/snapshots/disk.snap", "/snapshots/mem.snap", nil } -func (s *stubVMD) ResumeInstance(_ context.Context, _, _, _ string) (string, error) { - return "10.0.0.1", nil +func (s *stubVMD) ResumeInstance(_ context.Context, _, _, _ string, _ map[string]string) (string, uint32, uint32, error) { + return "10.0.0.1", 1, 1024, nil +} +func (s *stubVMD) RestoreSnapshot(_ context.Context, _, _, _ string) (string, uint32, uint32, error) { + return "10.0.0.1", 1, 1024, nil } func (s *stubVMD) ExecCommand(_ context.Context, _, _ string, _ []string, _ map[string]string, _ string, _ uint32) (string, string, int32, error) { return "hello\n", "", 0, nil @@ -132,13 +135,6 @@ func (s *stubVMD) ExecCommandStream(_ context.Context, _, _ string, _ []string, onChunk(nil, nil, 0, true) return nil } -func (s *stubVMD) UploadFile(_ context.Context, _, _ string, r io.Reader) (int64, error) { - n, _ := io.Copy(io.Discard, r) - return n, nil -} -func (s *stubVMD) DownloadFile(_ context.Context, _, _ string) (io.ReadCloser, error) { - return io.NopCloser(strings.NewReader("file-content")), nil -} func (s *stubVMD) UpdateSandboxNetwork(_ context.Context, _ string, _, _, _ []string) error { return nil } @@ -295,8 +291,8 @@ func TestIntegration_CreateSandbox_Success(t *testing.T) { if sb.VcpuCount != 1 { t.Errorf("vcpu_count = %d, want 1", sb.VcpuCount) } - if sb.MemoryMib != 512 { - t.Errorf("memory_mib = %d, want 512", sb.MemoryMib) + if sb.MemoryMib != 1024 { + t.Errorf("memory_mib = %d, want 1024", sb.MemoryMib) } } @@ -433,21 +429,21 @@ func TestIntegration_PauseSandbox_Success(t *testing.T) { t.Fatal("pause response missing snapshot_id") } - // DB: sandbox is idle, snapshot record exists and is linked. + // DB: sandbox is paused, snapshot record exists and is linked. sandboxID, _ := uuid.Parse(sid) sb, err := testQueries.GetSandbox(ctx, db.GetSandboxParams{ID: sandboxID, TeamID: teamID}) if err != nil { t.Fatalf("get sandbox: %v", err) } - if sb.Status != db.SandboxStatusIdle { - t.Errorf("DB status = %q, want idle", sb.Status) + if sb.Status != db.SandboxStatusPaused { + t.Errorf("DB status = %q, want paused", sb.Status) } if !sb.SnapshotID.Valid { t.Error("sandbox snapshot_id should be set after pause") } snapID, _ := uuid.Parse(snapshotIDStr) - snap, err := testQueries.GetSnapshot(ctx, snapID) + snap, err := testQueries.GetSnapshot(ctx, db.GetSnapshotParams{ID: snapID, TeamID: teamID}) if err != nil { t.Fatalf("snapshot not found in DB: %v", err) } @@ -614,12 +610,13 @@ func TestIntegration_ExecSandbox_Success(t *testing.T) { } } -func TestIntegration_ExecSandbox_AutoWakeIdleSandbox(t *testing.T) { - ctx := context.Background() - teamID, apiKey := seedTeamAndKey(t) +// Exec on a paused sandbox must be rejected — callers must resume explicitly +// via POST /resume. There is no implicit auto-wake on traffic. +func TestIntegration_ExecSandbox_PausedRejected(t *testing.T) { + _, apiKey := seedTeamAndKey(t) r := newRouter(t) - cw := do(r, "POST", "/sandboxes", apiKey, `{"name":"wake-box"}`) + cw := do(r, "POST", "/sandboxes", apiKey, `{"name":"paused-box"}`) if cw.Code != http.StatusCreated { t.Fatalf("create: %d", cw.Code) } @@ -630,69 +627,9 @@ func TestIntegration_ExecSandbox_AutoWakeIdleSandbox(t *testing.T) { t.Fatalf("pause: %d %s", pw.Code, pw.Body.String()) } - // Exec on idle sandbox — AutoWake middleware should resume transparently. ew := do(r, "POST", "/sandboxes/"+sid+"/exec", apiKey, `{"command":"echo hello"}`) - if ew.Code != http.StatusOK { - t.Fatalf("exec on idle: expected 200, got %d: %s", ew.Code, ew.Body.String()) - } - - // DB: active after auto-wake. - time.Sleep(50 * time.Millisecond) - sandboxID, _ := uuid.Parse(sid) - sb, err := testQueries.GetSandbox(ctx, db.GetSandboxParams{ID: sandboxID, TeamID: teamID}) - if err != nil { - t.Fatalf("get sandbox: %v", err) - } - if sb.Status != db.SandboxStatusActive { - t.Errorf("DB status = %q after auto-wake, want active", sb.Status) - } -} - -// --------------------------------------------------------------------------- -// PUT + GET /sandboxes/:id/files/*path -// --------------------------------------------------------------------------- - -func TestIntegration_FileUploadDownload(t *testing.T) { - _, apiKey := seedTeamAndKey(t) - r := newRouter(t) - - cw := do(r, "POST", "/sandboxes", apiKey, `{"name":"file-box"}`) - if cw.Code != http.StatusCreated { - t.Fatalf("create: %d %s", cw.Code, cw.Body.String()) - } - sid := mustJSON(t, cw)["id"].(string) - // Upload. - uw := doBinary(r, "PUT", "/sandboxes/"+sid+"/files/home/user/test.txt", apiKey, []byte("hello sandbox file")) - if uw.Code != http.StatusOK { - t.Fatalf("upload: expected 200, got %d: %s", uw.Code, uw.Body.String()) - } - ub := mustJSON(t, uw) - if ub["path"] != "/home/user/test.txt" { - t.Errorf("upload path = %q, want /home/user/test.txt", ub["path"]) - } - - // Download. - dw := do(r, "GET", "/sandboxes/"+sid+"/files/home/user/test.txt", apiKey, "") - if dw.Code != http.StatusOK { - t.Fatalf("download: expected 200, got %d: %s", dw.Code, dw.Body.String()) - } - if dw.Body.String() != "file-content" { // stubVMD always returns "file-content" - t.Errorf("download body = %q, want file-content", dw.Body.String()) - } -} - -func TestIntegration_FileUpload_PathTraversalRejected(t *testing.T) { - _, apiKey := seedTeamAndKey(t) - r := newRouter(t) - - cw := do(r, "POST", "/sandboxes", apiKey, `{"name":"pt-box"}`) - if cw.Code != http.StatusCreated { - t.Fatalf("create: %d", cw.Code) - } - sid := mustJSON(t, cw)["id"].(string) - w := doBinary(r, "PUT", "/sandboxes/"+sid+"/files/../../../etc/passwd", apiKey, []byte("x")) - if w.Code != http.StatusBadRequest { - t.Fatalf("expected 400 for path traversal, got %d", w.Code) + if ew.Code != http.StatusConflict { + t.Fatalf("exec on paused: expected 409, got %d: %s", ew.Code, ew.Body.String()) } } diff --git a/internal/network/manager.go b/internal/network/manager.go index 0c228b6..8924984 100644 --- a/internal/network/manager.go +++ b/internal/network/manager.go @@ -3,6 +3,7 @@ package network import ( "context" "fmt" + "os" "os/exec" "sync" "time" @@ -68,6 +69,9 @@ type VMNetInfo struct { // This supports up to ~32K concurrent VMs per node — hardware (RAM/CPU) is the real limit. const MaxSlots = 32000 +// ErrNoSlots is returned when no network slots are available. +var ErrNoSlots = fmt.Errorf("no available network slots (max %d concurrent VMs)", MaxSlots) + type Manager struct { hostInterface string log zerolog.Logger @@ -83,6 +87,9 @@ type Manager struct { // TCP egress proxy — receives per-sandbox rule updates and cleanup. egressProxy *EgressProxy + // Pre-allocated network slot pool (nil = disabled, on-demand setup). + pool *Pool + // Proxy ports for the TCP egress proxy. httpProxyPort uint16 tlsProxyPort uint16 @@ -150,6 +157,14 @@ func (m *Manager) SetProxyPorts(http, tls, other uint16) { // The host reaches the VM at hostIP:. NAT inside the namespace // translates to 169.254.0.21:. No guest IP reconfig needed. func (m *Manager) SetupVM(ctx context.Context, vmID string, cfg *Config) (*VMNetInfo, error) { + // Try the pre-allocated pool first (microseconds instead of ~10-30ms). + if m.pool != nil { + if info := m.pool.Claim(vmID); info != nil { + return info, nil + } + m.log.Debug().Str("vm_id", vmID).Msg("network pool empty, falling back to on-demand setup") + } + m.mu.Lock() var idx int if len(m.freeSlots) > 0 { @@ -158,16 +173,46 @@ func (m *Manager) SetupVM(ctx context.Context, vmID string, cfg *Config) (*VMNet } else { if m.nextSlot > MaxSlots { m.mu.Unlock() - return nil, fmt.Errorf("no available network slots (max %d concurrent VMs)", MaxSlots) + return nil, ErrNoSlots } idx = m.nextSlot m.nextSlot++ } m.mu.Unlock() - log := m.log.With().Str("vm_id", vmID).Int("slot", idx).Logger() + info, vethName, err := m.setupSlot(ctx, idx) + if err != nil { + m.mu.Lock() + m.freeSlots = append(m.freeSlots, idx) + m.mu.Unlock() + return nil, err + } + + // Host-level nftables rules require the vmID. + hostCIDR := fmt.Sprintf("%s/32", info.HostIP) + if err := m.hostFW.AddVM(vmID, vethName, hostCIDR); err != nil { + m.cleanupFull(info.Namespace, vethName) + return nil, fmt.Errorf("add host firewall rules: %w", err) + } - // Calculate IPs for this slot using /16 subnets. + m.mu.Lock() + m.devices[vmID] = info + m.mu.Unlock() + + m.log.Info(). + Str("vm_id", vmID). + Str("namespace", info.Namespace). + Str("host_ip", info.HostIP). + Msg("network namespace created") + + return info, nil +} + +// setupSlot runs the expensive network setup (namespace, veth, TAP, +// nftables, routing) for a single slot index. Used by both SetupVM +// (on-demand) and Pool (pre-allocation). Does NOT add host-level +// firewall rules — that requires a vmID and is done by the caller. +func (m *Manager) setupSlot(ctx context.Context, idx int) (*VMNetInfo, string, error) { hostIP := fmt.Sprintf("10.11.%d.%d", idx/256, idx%256) vpeerIP := fmt.Sprintf("10.12.%d.%d", (idx*2)/256, (idx*2)%256) vethIP := fmt.Sprintf("10.12.%d.%d", (idx*2+1)/256, (idx*2+1)%256) @@ -176,79 +221,77 @@ func (m *Manager) SetupVM(ctx context.Context, vmID string, cfg *Config) (*VMNet vpeerName := "eth0" hostCIDR := fmt.Sprintf("%s/32", hostIP) - // 1. Create network namespace. + // If the namespace already exists, this slot is in use by a + // running sandbox from a previous VMD lifetime. Skip it — the + // pool caller will retry with the next slot index. + if nsExists(nsName) { + return nil, "", fmt.Errorf("namespace %s already exists (slot in use)", nsName) + } + if err := run(ctx, "ip", "netns", "add", nsName); err != nil { - return nil, fmt.Errorf("create namespace: %w", err) + return nil, "", fmt.Errorf("create namespace: %w", err) } - // 2. Create veth pair inside the namespace. if err := nsRun(ctx, nsName, "ip", "link", "add", vethName, "type", "veth", "peer", "name", vpeerName); err != nil { m.removeNS(nsName) - return nil, fmt.Errorf("create veth pair: %w", err) + return nil, "", fmt.Errorf("create veth pair: %w", err) } - // 3. Configure vpeer (stays in namespace). if err := nsRun(ctx, nsName, "ip", "link", "set", vpeerName, "up"); err != nil { m.removeNS(nsName) - return nil, fmt.Errorf("bring up vpeer: %w", err) + return nil, "", fmt.Errorf("bring up vpeer: %w", err) } if err := nsRun(ctx, nsName, "ip", "link", "set", vpeerName, "mtu", ifaceMTU); err != nil { m.removeNS(nsName) - return nil, fmt.Errorf("set vpeer MTU: %w", err) + return nil, "", fmt.Errorf("set vpeer MTU: %w", err) } if err := nsRun(ctx, nsName, "ip", "addr", "add", vpeerIP+"/31", "dev", vpeerName); err != nil { m.removeNS(nsName) - return nil, fmt.Errorf("assign vpeer IP: %w", err) + return nil, "", fmt.Errorf("assign vpeer IP: %w", err) } - // 4. Move veth to host namespace. if err := nsRun(ctx, nsName, "ip", "link", "set", vethName, "netns", "1"); err != nil { m.removeNS(nsName) - return nil, fmt.Errorf("move veth to host: %w", err) + return nil, "", fmt.Errorf("move veth to host: %w", err) } - // 5. Configure veth on host side. if err := run(ctx, "ip", "link", "set", vethName, "up"); err != nil { m.removeNS(nsName) - return nil, fmt.Errorf("bring up veth: %w", err) + return nil, "", fmt.Errorf("bring up veth: %w", err) } if err := run(ctx, "ip", "link", "set", vethName, "mtu", ifaceMTU); err != nil { m.removeNS(nsName) - return nil, fmt.Errorf("set veth MTU: %w", err) + return nil, "", fmt.Errorf("set veth MTU: %w", err) } if err := run(ctx, "ip", "addr", "add", vethIP+"/31", "dev", vethName); err != nil { m.removeNS(nsName) - return nil, fmt.Errorf("assign veth IP: %w", err) + return nil, "", fmt.Errorf("assign veth IP: %w", err) } - // 6. Create TAP device inside namespace. if err := nsRun(ctx, nsName, "ip", "tuntap", "add", "dev", TAPName, "mode", "tap"); err != nil { m.cleanupFull(nsName, vethName) - return nil, fmt.Errorf("create TAP: %w", err) + return nil, "", fmt.Errorf("create TAP: %w", err) } if err := nsRun(ctx, nsName, "ip", "link", "set", TAPName, "up"); err != nil { m.cleanupFull(nsName, vethName) - return nil, fmt.Errorf("bring up TAP: %w", err) + return nil, "", fmt.Errorf("bring up TAP: %w", err) } if err := nsRun(ctx, nsName, "ip", "link", "set", TAPName, "mtu", ifaceMTU); err != nil { m.cleanupFull(nsName, vethName) - return nil, fmt.Errorf("set TAP MTU: %w", err) + return nil, "", fmt.Errorf("set TAP MTU: %w", err) } if err := nsRun(ctx, nsName, "ip", "addr", "add", tapCIDR, "dev", TAPName); err != nil { m.cleanupFull(nsName, vethName) - return nil, fmt.Errorf("assign TAP IP: %w", err) + return nil, "", fmt.Errorf("assign TAP IP: %w", err) } - // 7. Bring up loopback in namespace. _ = nsRun(ctx, nsName, "ip", "link", "set", "lo", "up") - // 8. Default route in namespace → via veth IP (on host side). if err := nsRun(ctx, nsName, "ip", "route", "add", "default", "via", vethIP); err != nil { m.cleanupFull(nsName, vethName) - return nil, fmt.Errorf("add default route in ns: %w", err) + return nil, "", fmt.Errorf("add default route in ns: %w", err) } - // 9. Initialize nftables firewall inside namespace (NAT + filtering + MSS clamping + TCP redirect). var fw *Firewall if err := nsExecGo(nsName, func() error { var fwErr error @@ -265,18 +308,11 @@ func (m *Manager) SetupVM(ctx context.Context, vmID string, cfg *Config) (*VMNet return fwErr }); err != nil { m.cleanupFull(nsName, vethName) - return nil, fmt.Errorf("init firewall: %w", err) + return nil, "", fmt.Errorf("init firewall: %w", err) } - // 10. Host routing: traffic to hostIP goes via vpeer through the veth. if err := run(ctx, "ip", "route", "add", hostCIDR, "via", vpeerIP, "dev", vethName); err != nil { - log.Debug().Err(err).Msg("host route (may already exist)") - } - - // 11. Host-level nftables: FORWARD + MASQUERADE + MSS clamping for this VM. - if err := m.hostFW.AddVM(vmID, vethName, hostCIDR); err != nil { - m.cleanupFull(nsName, vethName) - return nil, fmt.Errorf("add host firewall rules: %w", err) + m.log.Debug().Err(err).Str("ns", nsName).Msg("host route (may already exist)") } mac := fmt.Sprintf("AA:FC:00:%02X:%02X:%02X", 0, idx/256, idx%256) @@ -291,17 +327,7 @@ func (m *Manager) SetupVM(ctx context.Context, vmID string, cfg *Config) (*VMNet Firewall: fw, } - m.mu.Lock() - m.devices[vmID] = info - m.mu.Unlock() - - log.Info(). - Str("namespace", nsName). - Str("host_ip", hostIP). - Str("vm_ip", VMInternalIP). - Msg("network namespace created") - - return info, nil + return info, vethName, nil } func (m *Manager) GetVMNetInfo(vmID string) *VMNetInfo { @@ -327,51 +353,55 @@ func (m *Manager) CleanupVM(vmID string) { return } - // Parse slot index once. var idx int fmt.Sscanf(info.Namespace, "ns-%d", &idx) + vethName := fmt.Sprintf("veth-%d", idx) + + // Remove host-level nftables rules (vmID-specific). + if err := m.hostFW.RemoveVM(vmID); err != nil { + m.log.Warn().Err(err).Str("vm_id", vmID).Msg("error removing host firewall rules") + } + + // Remove per-sandbox egress proxy rules. + if m.egressProxy != nil { + m.egressProxy.RemoveRules(info.HostIP) + } + + // Try to recycle the slot into the pool instead of tearing it down. + // The namespace, veth, TAP, and base nftables stay configured — + // only the vmID-specific host firewall and egress rules were removed + // above. The next Claim re-adds them for the new vmID. + if m.pool != nil { + // Reset user-defined firewall rules to defaults before recycling. + if info.Firewall != nil { + _ = info.Firewall.ReplaceUserRules(nil, nil) + } + m.pool.Return(&preallocSlot{idx: idx, info: info, vethName: vethName}) + return + } - // Recycle the slot index for reuse. + // No pool — full teardown. m.mu.Lock() m.freeSlots = append(m.freeSlots, idx) m.mu.Unlock() - log := m.log.With().Str("vm_id", vmID).Logger() - - // Close nftables firewall inside namespace (kernel removes table + all rules). if info.Firewall != nil { if err := info.Firewall.Close(); err != nil { - log.Warn().Err(err).Msg("error closing namespace firewall") + m.log.Warn().Err(err).Str("vm_id", vmID).Msg("error closing namespace firewall") } } - // Remove host-level nftables rules for this VM. - if err := m.hostFW.RemoveVM(vmID); err != nil { - log.Warn().Err(err).Msg("error removing host firewall rules") - } - - // Remove per-sandbox rules and connection limiter entries from the egress proxy. - if m.egressProxy != nil { - m.egressProxy.RemoveRules(info.HostIP) - } - - vethName := fmt.Sprintf("veth-%d", idx) vpeerIP := fmt.Sprintf("10.12.%d.%d", (idx*2)/256, (idx*2)%256) hostCIDR := fmt.Sprintf("%s/32", info.HostIP) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - // Remove host route. _ = run(ctx, "ip", "route", "del", hostCIDR, "via", vpeerIP, "dev", vethName) - - // Delete veth (also removes peer in namespace). _ = run(ctx, "ip", "link", "del", vethName) - - // Delete namespace. _ = run(ctx, "ip", "netns", "del", info.Namespace) - log.Info().Str("namespace", info.Namespace).Msg("network namespace cleaned up") + m.log.Info().Str("vm_id", vmID).Str("namespace", info.Namespace).Msg("network namespace cleaned up") } // UpdateFirewallRules atomically replaces the user allow/deny sets for a VM's firewall. @@ -392,6 +422,11 @@ func (m *Manager) UpdateFirewallRules(vmID string, allowedCIDRs, deniedCIDRs []s // Helpers // --------------------------------------------------------------------------- +func nsExists(nsName string) bool { + _, err := os.Stat("/run/netns/" + nsName) + return err == nil +} + func (m *Manager) removeNS(nsName string) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/internal/network/pool.go b/internal/network/pool.go new file mode 100644 index 0000000..428b1f8 --- /dev/null +++ b/internal/network/pool.go @@ -0,0 +1,246 @@ +package network + +import ( + "context" + "sync" + + "github.com/rs/zerolog" +) + +// PoolConfig controls the pre-allocated network slot pool. +type PoolConfig struct { + // NewSize is the number of fresh pre-allocated slots to keep ready. + // Default: 32. + NewSize int + // RecycleSize is the capacity for recycled slots — network namespaces + // returned from destroyed sandboxes. Recycled slots skip the full + // setup (namespace, veth, TAP, nftables are already configured). + // Default: 100. + RecycleSize int +} + +// Pool pre-allocates network namespaces, veth pairs, TAP devices, and +// firewall rules so that SetupVM can claim a ready slot in microseconds +// instead of running ~11 shell commands on the hot path (~10-30ms). +// +// The pool is optional — if not started, SetupVM falls back to on-demand +// setup (the original behavior). Call StartPool after NewManager to enable. +type Pool struct { + mgr *Manager + log zerolog.Logger + newSize int + fresh chan *preallocSlot // pre-allocated from scratch + recycled chan *preallocSlot // returned from destroyed sandboxes + stopCh chan struct{} + wg sync.WaitGroup +} + +// preallocSlot holds a fully configured network namespace ready to be +// assigned to a VM. +type preallocSlot struct { + idx int + info *VMNetInfo + // vethName is needed for cleanup if the slot is never claimed. + vethName string +} + +// StartPool creates and starts the network slot pool. Blocks until the +// initial batch of fresh slots is filled, then refills in the background. +func (m *Manager) StartPool(ctx context.Context, cfg PoolConfig) *Pool { + newSize := cfg.NewSize + if newSize <= 0 { + newSize = 32 + } + recycleSize := cfg.RecycleSize + if recycleSize <= 0 { + recycleSize = 100 + } + + p := &Pool{ + mgr: m, + log: m.log.With().Str("component", "net_pool").Logger(), + newSize: newSize, + fresh: make(chan *preallocSlot, newSize), + recycled: make(chan *preallocSlot, recycleSize), + stopCh: make(chan struct{}), + } + + // Fill initial batch synchronously so the pool is warm on first create. + for i := 0; i < newSize; i++ { + slot, err := p.allocate(ctx) + if err != nil { + p.log.Error().Err(err).Int("filled", i).Msg("initial pool fill incomplete") + break + } + p.fresh <- slot + } + p.log.Info().Int("fresh", len(p.fresh)).Int("recycle_cap", recycleSize).Msg("network pool ready") + + p.wg.Add(1) + go p.refillLoop(ctx) + + m.pool = p + return p +} + +// Claim takes a slot from the pool and assigns it to the given VM ID. +// Prefers recycled slots (zero setup cost) over fresh ones (one nftables +// call). Returns nil if both pools are empty — caller falls back to +// on-demand SetupVM. +func (p *Pool) Claim(vmID string) *VMNetInfo { + var slot *preallocSlot + + // Prefer recycled slots — they already have host firewall rules + // from the previous owner, which get replaced below. + select { + case slot = <-p.recycled: + default: + select { + case slot = <-p.fresh: + default: + return nil + } + } + + // Add host-level firewall rules (requires vmID). + hostCIDR := slot.info.HostIP + "/32" + if err := p.mgr.hostFW.AddVM(vmID, slot.vethName, hostCIDR); err != nil { + p.log.Error().Err(err).Str("vm_id", vmID).Msg("claim: AddVM firewall failed") + p.cleanup(slot) + return nil + } + + p.mgr.mu.Lock() + p.mgr.devices[vmID] = slot.info + p.mgr.mu.Unlock() + + return slot.info +} + +// Return puts a slot back into the recycled pool after a sandbox is +// destroyed. The network namespace, veth, TAP, and nftables stay +// configured — the next Claim reuses them with zero setup cost. +// If the recycled pool is full, the slot is torn down instead. +func (p *Pool) Return(slot *preallocSlot) { + select { + case p.recycled <- slot: + default: + // Recycle pool full — tear down. + p.cleanup(slot) + } +} + +// Stop drains both pools and cleans up unclaimed slots. +func (p *Pool) Stop() { + close(p.stopCh) + p.wg.Wait() + + close(p.fresh) + for slot := range p.fresh { + p.cleanup(slot) + } + close(p.recycled) + for slot := range p.recycled { + p.cleanup(slot) + } + p.log.Info().Msg("network pool stopped") +} + +func (p *Pool) refillLoop(ctx context.Context) { + defer p.wg.Done() + for { + select { + case <-p.stopCh: + return + case <-ctx.Done(): + return + default: + } + + if len(p.fresh) >= p.newSize { + // Pool full — block until a slot is consumed or shutdown. + select { + case <-p.stopCh: + return + case <-ctx.Done(): + return + case p.fresh <- p.mustAllocate(ctx): + } + continue + } + + slot, err := p.allocate(ctx) + if err != nil { + p.log.Error().Err(err).Msg("pool refill failed") + continue + } + select { + case p.fresh <- slot: + case <-p.stopCh: + p.cleanup(slot) + return + case <-ctx.Done(): + p.cleanup(slot) + return + } + } +} + +func (p *Pool) mustAllocate(ctx context.Context) *preallocSlot { + for { + slot, err := p.allocate(ctx) + if err == nil { + return slot + } + p.log.Error().Err(err).Msg("pool allocate retry") + select { + case <-p.stopCh: + return nil + case <-ctx.Done(): + return nil + default: + } + } +} + +func (p *Pool) allocate(ctx context.Context) (*preallocSlot, error) { + // Grab a slot index from the manager's free list. + p.mgr.mu.Lock() + var idx int + if len(p.mgr.freeSlots) > 0 { + idx = p.mgr.freeSlots[len(p.mgr.freeSlots)-1] + p.mgr.freeSlots = p.mgr.freeSlots[:len(p.mgr.freeSlots)-1] + } else { + if p.mgr.nextSlot > MaxSlots { + p.mgr.mu.Unlock() + return nil, ErrNoSlots + } + idx = p.mgr.nextSlot + p.mgr.nextSlot++ + } + p.mgr.mu.Unlock() + + // Run the full network setup (namespace, veth, TAP, nftables). + // This is the expensive part we're moving off the hot path. + info, vethName, err := p.mgr.setupSlot(ctx, idx) + if err != nil { + // Don't return the index to freeSlots — the namespace may be in + // use by a running sandbox from a previous VMD lifetime. Returning + // it would cause an infinite retry loop. The index is consumed; + // the next allocate will use nextSlot or a different free index. + return nil, err + } + + return &preallocSlot{idx: idx, info: info, vethName: vethName}, nil +} + +func (p *Pool) cleanup(slot *preallocSlot) { + if slot == nil || slot.info == nil { + return + } + nsName := slot.info.Namespace + p.mgr.cleanupFull(nsName, slot.vethName) + p.mgr.mu.Lock() + p.mgr.freeSlots = append(p.mgr.freeSlots, slot.idx) + p.mgr.mu.Unlock() +} diff --git a/internal/proxy/authz.go b/internal/proxy/authz.go new file mode 100644 index 0000000..68bc06d --- /dev/null +++ b/internal/proxy/authz.go @@ -0,0 +1,64 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/superserve-ai/sandbox/internal/auth" + "github.com/superserve-ai/sandbox/internal/telemetry" +) + +// authzFailure is a structured rejection from authorizeSandboxRequest. +type authzFailure struct { + Status int + Message string +} + +func (f *authzFailure) write(w http.ResponseWriter) { + http.Error(w, f.Message, f.Status) +} + +// authorizeSandboxRequest verifies the per-sandbox HMAC access token +// and resolves the sandbox to a running VM. Shared by /terminal and +// /files on the boxd host label. +func (h *Handler) authorizeSandboxRequest( + ctx context.Context, + token string, + requestSandboxID string, +) (InstanceInfo, *authzFailure) { + if h.seedKey == nil { + panic("proxy: authorizeSandboxRequest called without WithAuth") + } + + if !auth.VerifyAccessToken(h.seedKey, requestSandboxID, token) { + telemetry.IncProxyHMACFailure(ctx) + return InstanceInfo{}, &authzFailure{ + Status: http.StatusUnauthorized, + Message: "invalid access token", + } + } + + info, err := h.resolver.Lookup(ctx, requestSandboxID) + if err != nil { + if errors.Is(err, ErrInstanceNotFound) { + return InstanceInfo{}, &authzFailure{ + Status: http.StatusNotFound, + Message: "sandbox not found", + } + } + return InstanceInfo{}, &authzFailure{ + Status: http.StatusServiceUnavailable, + Message: "sandbox unavailable", + } + } + if info.Status != "running" { + return InstanceInfo{}, &authzFailure{ + Status: http.StatusServiceUnavailable, + Message: fmt.Sprintf("sandbox is %s", info.Status), + } + } + + return info, nil +} diff --git a/internal/proxy/files.go b/internal/proxy/files.go new file mode 100644 index 0000000..7ea37cf --- /dev/null +++ b/internal/proxy/files.go @@ -0,0 +1,190 @@ +package proxy + +import ( + "fmt" + "net/http" + "net/http/httputil" + "net/url" + "strings" +) + +// File bridge constants. The /files path lives on the same boxd port that +// the terminal bridge talks to (boxdPort, defined in terminal.go), because +// boxd serves both its connect-rpc services and the raw /files HTTP +// endpoint on a single HTTP listener. The proxy treats all traffic to +// boxdPort as sensitive regardless of path — only /files is allowlisted +// through, everything else is 404'd so the in-VM connect-rpc services +// stay strictly internal. +const ( + // filesPath is the HTTP path the edge proxy forwards to boxd's + // raw /files handler after verifying the access token. + filesPath = "/files" + + // terminalPath is the HTTP path the edge proxy upgrades to a + // WebSocket and bridges to boxd's connect-rpc ProcessService. + // The bridge itself is implemented in terminal.go; this constant + // just names the route serveBoxdPort dispatches to it on. + terminalPath = "/terminal" + + // accessTokenHeader is the carrier for the per-sandbox HMAC access token. + accessTokenHeader = "X-Access-Token" +) + +// serveBoxdPort is the entry point for any request addressed at the +// reserved `boxd-{instanceID}.{domain}` host label. It dispatches by +// path to the concrete handler for each boxd-fronted feature. +// +// boxd is a special case: inside the VM a single HTTP listener serves +// both the raw /files endpoint and the full connect-rpc service +// surface (ProcessService, FilesystemService). We only ever expose the +// narrow set of paths we explicitly handle below; any other path +// returns an opaque 404 so a caller probing the in-VM surface cannot +// enumerate what exists behind the proxy. That includes `/health`, +// connect-rpc routes, and anything future boxd grows internally +// without our knowledge. +func (h *Handler) serveBoxdPort(w http.ResponseWriter, r *http.Request, instanceID string) { + if !h.sandboxConns.acquire(instanceID) { + http.Error(w, "too many connections to sandbox", http.StatusTooManyRequests) + return + } + defer h.sandboxConns.release(instanceID) + + clientIP := clientAddr(r) + if !h.ipConns.acquire(clientIP) { + http.Error(w, "too many connections from this IP", http.StatusTooManyRequests) + return + } + defer h.ipConns.release(clientIP) + + switch r.URL.Path { + case filesPath: + h.serveFiles(w, r, instanceID) + case terminalPath: + if h.terminal == nil { + // Proxy started without WithTerminal — don't leak that + // the feature exists but is off. + http.NotFound(w, r) + return + } + h.serveTerminal(w, r, instanceID) + default: + http.NotFound(w, r) + } +} + +// serveFiles handles POST/GET /files on the boxd host label. It +// verifies the sandbox access token, scrubs the token and caller- +// controlled headers, and reverse-proxies the request to boxd's +// internal /files handler. +func (h *Handler) serveFiles(w http.ResponseWriter, r *http.Request, instanceID string) { + if !h.filesEnabled { + // The proxy was started without WithFiles — either this is a + // legacy deployment that doesn't have the feature on yet or a + // misconfigured one. Don't leak which: return the same 404 a + // caller would see probing any other internal path. + http.NotFound(w, r) + return + } + + // boxd's /files handler only implements GET (download) and POST + // (upload). Anything else is a client bug and should surface loudly. + if r.Method != http.MethodGet && r.Method != http.MethodPost { + w.Header().Set("Allow", "GET, POST") + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + // Path traversal rejection. boxd's own safePath runs filepath.Clean, + // which silently resolves `..` segments instead of refusing them — + // `/home/user/../../../etc/x` quietly becomes `/etc/x` and gets + // written as root. That's technically no worse than what a caller + // could do via the exec endpoint, but it contradicts the documented + // "path traversal rejected" contract and turns a typo in a relative + // path into a silent write to an unintended location. Reject any + // request whose ?path= contains a literal `..` segment, before we + // hit the auth check. + requestedPath := r.URL.Query().Get("path") + if requestedPath == "" { + http.Error(w, "missing path query parameter", http.StatusBadRequest) + return + } + for _, seg := range strings.Split(requestedPath, "/") { + if seg == ".." { + http.Error(w, "path traversal not allowed", http.StatusBadRequest) + return + } + } + + token := r.Header.Get(accessTokenHeader) + if token == "" { + http.Error(w, "missing X-Access-Token header", http.StatusUnauthorized) + return + } + + // Scrub the token before forwarding to boxd. + r.Header.Del(accessTokenHeader) + + w.Header().Set("Referrer-Policy", "no-referrer") + + info, fail := h.authorizeSandboxRequest(r.Context(), token, instanceID) + if fail != nil { + h.log.Warn().Str("sandbox_id", instanceID).Int("status", fail.Status).Msg("files: auth failed") + fail.write(w) + return + } + + // From here on it's just a transparent reverse proxy to boxd. + // Reuse the lifecycle-keyed transport cache for the same reasons + // as the generic forwarder: one pooled set of TCP connections per + // sandbox incarnation, reset on pause/resume. + transport := h.transports.get(instanceID, info) + target := &url.URL{ + Scheme: "http", + Host: fmt.Sprintf("%s:%d", info.VMIP, boxdPort), + } + + rp := &httputil.ReverseProxy{ + Director: func(req *http.Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + // Preserve the original Host so boxd logs the public name, + // not the VM private IP. Also avoids Host-header confusion + // on any downstream middleware that trusts it. + req.Host = r.Host + // Strip all forwarding / origin headers — a caller could + // otherwise inject these to spoof identity in any boxd + // log or future handler that trusts them. Note the + // explicit `= nil` for X-Forwarded-For: httputil.ReverseProxy + // re-appends that header after the Director runs unless + // its value is the nil slice. A plain Del leaves it + // missing, which httputil then "helpfully" refills. + req.Header["X-Forwarded-For"] = nil + for _, hdr := range []string{ + "X-Forwarded-Host", + "X-Forwarded-Proto", + "X-Real-Ip", + "Forwarded", + } { + req.Header.Del(hdr) + } + }, + Transport: transport, + // FlushInterval -1 streams the response as it arrives, which is + // what we want for large downloads: the client sees bytes as + // boxd produces them, not after the whole file is buffered. + FlushInterval: -1, + ErrorHandler: func(rw http.ResponseWriter, req *http.Request, proxyErr error) { + h.log.Error().Err(proxyErr). + Str("instance", instanceID). + Str("target", target.Host). + Msg("files: upstream error") + // Invalidate so the next request re-resolves from VMD, + // in case the VM was replaced mid-stream. + h.resolver.Invalidate(instanceID) + rw.Header().Set("Retry-After", "2") + http.Error(rw, "sandbox unreachable", http.StatusBadGateway) + }, + } + rp.ServeHTTP(w, r) +} + diff --git a/internal/proxy/files_test.go b/internal/proxy/files_test.go new file mode 100644 index 0000000..ffe48dc --- /dev/null +++ b/internal/proxy/files_test.go @@ -0,0 +1,366 @@ +package proxy + +import ( + "context" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + "time" + + "github.com/rs/zerolog" + + "github.com/superserve-ai/sandbox/internal/auth" +) + +// --------------------------------------------------------------------------- +// Test harness +// --------------------------------------------------------------------------- + +type stubResolver struct { + info InstanceInfo + err error + invMu sync.Mutex + invIDs []string +} + +func (s *stubResolver) Lookup(_ context.Context, _ string) (InstanceInfo, error) { + if s.err != nil { + return InstanceInfo{}, s.err + } + return s.info, nil +} + +func (s *stubResolver) Invalidate(instanceID string) { + s.invMu.Lock() + s.invIDs = append(s.invIDs, instanceID) + s.invMu.Unlock() +} + +type filesTestEnv struct { + t *testing.T + seedKey []byte + handler *Handler + upstream *httptest.Server + sandboxID string + domain string + resolver *stubResolver + upstreamMu sync.Mutex + lastReq capturedRequest +} + +type capturedRequest struct { + method string + path string + rawQuery string + host string + hasToken bool + fwdFor string + body string + received bool +} + +func newFilesTestEnv(t *testing.T) *filesTestEnv { + t.Helper() + + seedKey := []byte("test-seed-key-that-is-at-least-32-bytes-long!!") + + env := &filesTestEnv{ + t: t, + seedKey: seedKey, + sandboxID: "sbx-" + strings.Repeat("a", 8), + domain: "sandbox.test", + } + + env.upstream = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bodyBytes, _ := io.ReadAll(r.Body) + env.upstreamMu.Lock() + env.lastReq = capturedRequest{ + method: r.Method, + path: r.URL.Path, + rawQuery: r.URL.RawQuery, + host: r.Host, + hasToken: r.Header.Get("X-Access-Token") != "", + fwdFor: r.Header.Get("X-Forwarded-For"), + body: string(bodyBytes), + received: true, + } + env.upstreamMu.Unlock() + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + t.Cleanup(env.upstream.Close) + + upURL, _ := url.Parse(env.upstream.URL) + env.resolver = &stubResolver{ + info: InstanceInfo{ + VMIP: upURL.Hostname(), + Status: "running", + StartedAt: time.Now().UnixNano(), + }, + } + + env.handler = NewHandler(env.domain, env.resolver, zerolog.Nop()) + env.handler.WithAuth(seedKey).WithTerminal([]string{"*"}).WithFiles() + + upHost := upURL.Host + env.handler.transports = &transportCache{ + items: map[string]*transportEntry{}, + } + redirTransport := &http.Transport{ + DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, upHost) + }, + DisableKeepAlives: true, + } + env.handler.transports.items[env.sandboxID] = &transportEntry{ + lifecycleKey: env.resolver.info.lifecycleKey(), + transport: redirTransport, + lastUsed: time.Now(), + } + + return env +} + +func (e *filesTestEnv) validToken() string { + return auth.ComputeAccessToken(e.seedKey, e.sandboxID) +} + +func (e *filesTestEnv) buildRequest(method, filePath, token string, body io.Reader) *http.Request { + q := url.Values{} + if filePath != "" { + q.Set("path", filePath) + } + target := "http://unused/files" + if len(q) > 0 { + target += "?" + q.Encode() + } + req := httptest.NewRequest(method, target, body) + req.Host = "boxd-" + e.sandboxID + "." + e.domain + if token != "" { + req.Header.Set("X-Access-Token", token) + } + return req +} + +// --------------------------------------------------------------------------- +// Happy paths +// --------------------------------------------------------------------------- + +func TestFiles_UploadHeaderCarrier_ForwardsToUpstream(t *testing.T) { + env := newFilesTestEnv(t) + tok := env.validToken() + + req := env.buildRequest(http.MethodPost, "/home/u/app.txt", tok, + strings.NewReader("file contents")) + req.Header.Set("Content-Type", "application/octet-stream") + + w := httptest.NewRecorder() + env.handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body: %s", w.Code, w.Body.String()) + } + if !env.lastReq.received { + t.Fatal("upstream never received the request") + } + if env.lastReq.method != http.MethodPost { + t.Errorf("upstream method = %q, want POST", env.lastReq.method) + } + if env.lastReq.path != "/files" { + t.Errorf("upstream path = %q, want /files", env.lastReq.path) + } + if !strings.Contains(env.lastReq.rawQuery, "path=%2Fhome%2Fu%2Fapp.txt") { + t.Errorf("upstream query missing path param: %q", env.lastReq.rawQuery) + } + if env.lastReq.body != "file contents" { + t.Errorf("upstream body = %q, want 'file contents'", env.lastReq.body) + } + if env.lastReq.hasToken { + t.Error("X-Access-Token leaked to upstream") + } + if env.lastReq.fwdFor != "" { + t.Errorf("X-Forwarded-For leaked: %q", env.lastReq.fwdFor) + } + if !strings.HasPrefix(env.lastReq.host, "boxd-") { + t.Errorf("Host = %q, want public sandbox label", env.lastReq.host) + } +} + +// --------------------------------------------------------------------------- +// Auth rejections +// --------------------------------------------------------------------------- + +func TestFiles_MissingToken_Unauthorized(t *testing.T) { + env := newFilesTestEnv(t) + req := env.buildRequest(http.MethodGet, "/f.txt", "", nil) + w := httptest.NewRecorder() + env.handler.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want 401", w.Code) + } +} + +func TestFiles_WrongTokenRejected(t *testing.T) { + env := newFilesTestEnv(t) + req := env.buildRequest(http.MethodGet, "/f.txt", "totally-wrong-token", nil) + w := httptest.NewRecorder() + env.handler.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want 401", w.Code) + } +} + +func TestFiles_WrongSandboxTokenRejected(t *testing.T) { + env := newFilesTestEnv(t) + wrongToken := auth.ComputeAccessToken(env.seedKey, "different-sandbox-id") + req := env.buildRequest(http.MethodGet, "/f.txt", wrongToken, nil) + w := httptest.NewRecorder() + env.handler.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want 401", w.Code) + } +} + +func TestFiles_TokenReusable(t *testing.T) { + env := newFilesTestEnv(t) + tok := env.validToken() + + // First request + req1 := env.buildRequest(http.MethodGet, "/f.txt", tok, nil) + w1 := httptest.NewRecorder() + env.handler.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("first call status = %d, want 200", w1.Code) + } + + // Same token again — should still work (not single-use anymore) + req2 := env.buildRequest(http.MethodGet, "/f.txt", tok, nil) + w2 := httptest.NewRecorder() + env.handler.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("second call status = %d, want 200 (token should be reusable)", w2.Code) + } +} + +func TestFiles_SandboxNotRunningReturns503(t *testing.T) { + env := newFilesTestEnv(t) + env.resolver.info.Status = "paused" + tok := env.validToken() + req := env.buildRequest(http.MethodGet, "/f.txt", tok, nil) + w := httptest.NewRecorder() + env.handler.ServeHTTP(w, req) + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want 503", w.Code) + } +} + +func TestFiles_SandboxNotFoundReturns404(t *testing.T) { + env := newFilesTestEnv(t) + env.resolver.err = ErrInstanceNotFound + tok := env.validToken() + req := env.buildRequest(http.MethodGet, "/f.txt", tok, nil) + w := httptest.NewRecorder() + env.handler.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Fatalf("status = %d, want 404", w.Code) + } +} + +// --------------------------------------------------------------------------- +// Boxd-port lockdown +// --------------------------------------------------------------------------- + +func TestFiles_NonFilesPathBlocked(t *testing.T) { + env := newFilesTestEnv(t) + tok := env.validToken() + + paths := []string{ + "/superserve.boxd.v1.ProcessService/Start", + "/superserve.boxd.v1.FilesystemService/ListDir", + "/health", + "/", + } + for _, p := range paths { + t.Run(p, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://unused"+p, nil) + req.Host = "boxd-" + env.sandboxID + "." + env.domain + req.Header.Set("X-Access-Token", tok) + + w := httptest.NewRecorder() + env.handler.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Errorf("path %q: status = %d, want 404", p, w.Code) + } + }) + } +} + +func TestFiles_PathTraversalRejected(t *testing.T) { + env := newFilesTestEnv(t) + tok := env.validToken() + + cases := []string{ + "/home/user/../../../etc/bad.txt", + "/home/user/x/../y", + "../x", + "..", + } + for _, p := range cases { + t.Run(p, func(t *testing.T) { + req := env.buildRequest(http.MethodPost, p, tok, + strings.NewReader("content")) + w := httptest.NewRecorder() + env.handler.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Errorf("status = %d, want 400", w.Code) + } + }) + } +} + +func TestFiles_MissingPathParam(t *testing.T) { + env := newFilesTestEnv(t) + tok := env.validToken() + + req := env.buildRequest(http.MethodPost, "", tok, + strings.NewReader("content")) + w := httptest.NewRecorder() + env.handler.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Errorf("status = %d, want 400", w.Code) + } +} + +func TestFiles_MethodNotAllowed(t *testing.T) { + env := newFilesTestEnv(t) + tok := env.validToken() + req := env.buildRequest(http.MethodDelete, "/f.txt", tok, nil) + w := httptest.NewRecorder() + env.handler.ServeHTTP(w, req) + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("status = %d, want 405", w.Code) + } + if got := w.Header().Get("Allow"); got != "GET, POST" { + t.Errorf("Allow = %q, want 'GET, POST'", got) + } +} + +func TestFiles_DisabledReturns404(t *testing.T) { + env := newFilesTestEnv(t) + env.handler.filesEnabled = false + + tok := env.validToken() + req := env.buildRequest(http.MethodGet, "/f.txt", tok, nil) + w := httptest.NewRecorder() + env.handler.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Fatalf("status = %d, want 404", w.Code) + } +} diff --git a/internal/proxy/host.go b/internal/proxy/host.go index 7184314..da831ae 100644 --- a/internal/proxy/host.go +++ b/internal/proxy/host.go @@ -12,28 +12,51 @@ import ( // Prevents path traversal (%2f, ..) from reaching the VMD resolver URL. var validInstanceID = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9-]{0,63}$`) -// ParseHost extracts port and instanceID from a Host header of the form -// {port}-{instanceID}.{domain} and validates both fields. +// boxdHostLabel is the reserved left-most label that addresses boxd's +// own HTTP endpoint on the edge proxy. We deliberately do NOT let +// callers reach boxd by typing its numeric port in the URL: boxd's +// internal port is an implementation detail of the VM and putting it +// in public URLs would (a) leak a magic number into every integration +// and (b) give the impression that the port itself is exposed, when +// in reality the proxy handles `boxd-...` traffic specially and never +// bounces arbitrary paths through to that port. +const boxdHostLabel = "boxd" + +// ParseHost extracts the routing label and instanceID from a Host +// header of the form {label}-{instanceID}.{domain} and validates both. +// +// The label is either: +// +// - the literal word "boxd", which maps to the boxd port +// (boxdPort). This is the only way to address boxd through the +// edge proxy; the numeric form is intentionally rejected. +// +// - a decimal number in [1, 65535] above the privileged-port +// threshold, which routes to that user-application port on the +// VM. // -// Returns ErrInvalidHost if the host doesn't end with the expected domain suffix, -// so the proxy rejects forged Host headers pointing at arbitrary backends. +// Returns ErrInvalidHost if the host doesn't end with the expected +// domain suffix, so the proxy rejects forged Host headers pointing at +// arbitrary backends. func ParseHost(host, domain string) (port int, instanceID string, err error) { - // Strip TCP port from Host if present (e.g. "49983-abc.sandbox.superserve.ai:443") + // Strip TCP port from Host if present (e.g. "boxd-abc.sandbox.superserve.ai:443") hostname, _, splitErr := net.SplitHostPort(host) if splitErr != nil { hostname = host } - // Validate domain suffix — prevents accepting any Host: port-id.attacker.com + // Validate domain suffix — prevents accepting any Host: label-id.attacker.com if !strings.HasSuffix(hostname, "."+domain) { return 0, "", fmt.Errorf("proxy: host %q does not end in .%s", hostname, domain) } - // Take the leftmost label only: "49983-abc123" + // Take the leftmost label only: "boxd-abc123" or "3000-mybox" label, _, _ := strings.Cut(hostname, ".") - // Split on the first "-" to separate port from instance ID - portStr, instanceID, ok := strings.Cut(label, "-") + // Split on the first "-" to separate the routing label from the + // instance ID. Note: instance IDs themselves contain hyphens (UUIDs), + // so we only split once. + routing, instanceID, ok := strings.Cut(label, "-") if !ok { return 0, "", fmt.Errorf("proxy: host label %q has no '-' separator", label) } @@ -47,9 +70,20 @@ func ParseHost(host, domain string) (port int, instanceID string, err error) { return 0, "", fmt.Errorf("proxy: instance ID %q contains invalid characters", instanceID) } - port, err = strconv.Atoi(portStr) + // Reserved label for boxd. + if routing == boxdHostLabel { + return boxdPort, instanceID, nil + } + + // Numeric label for user-application ports. We accept a decimal + // in [1, 65535] but explicitly refuse the boxd port number — that + // address form exists only under the "boxd" label. + port, err = strconv.Atoi(routing) if err != nil || port < 1 || port > 65535 { - return 0, "", fmt.Errorf("proxy: invalid port %q in host %q", portStr, host) + return 0, "", fmt.Errorf("proxy: invalid label %q in host %q", routing, host) + } + if port == boxdPort { + return 0, "", fmt.Errorf("proxy: boxd must be addressed as %q, not by port number", boxdHostLabel) } return port, instanceID, nil diff --git a/internal/proxy/host_test.go b/internal/proxy/host_test.go index f52189f..fc27f75 100644 --- a/internal/proxy/host_test.go +++ b/internal/proxy/host_test.go @@ -13,40 +13,53 @@ func TestParseHost(t *testing.T) { wantInstanceID string wantErr bool }{ - // Happy path + // Happy path — boxd label maps to boxdPort { - host: "49983-abc123.sandbox.superserve.ai", - wantPort: 49983, + host: "boxd-abc123.sandbox.superserve.ai", + wantPort: boxdPort, wantInstanceID: "abc123", }, + { + host: "boxd-abc123.sandbox.superserve.ai:443", + wantPort: boxdPort, + wantInstanceID: "abc123", + }, + // UUID-style instance IDs (our actual format) + { + host: "boxd-b150ee22-4956-4f5b-926a-f921ed8c37d6.sandbox.superserve.ai", + wantPort: boxdPort, + wantInstanceID: "b150ee22-4956-4f5b-926a-f921ed8c37d6", + }, + // User application ports — numeric label { host: "3000-mybox.sandbox.superserve.ai", wantPort: 3000, wantInstanceID: "mybox", }, { - host: "49983-abc123.sandbox.superserve.ai:443", - wantPort: 49983, - wantInstanceID: "abc123", + host: "8080-b150ee22-4956-4f5b-926a-f921ed8c37d6.sandbox.superserve.ai", + wantPort: 8080, + wantInstanceID: "b150ee22-4956-4f5b-926a-f921ed8c37d6", }, - // UUID-style instance IDs (our actual format) + + // Boxd's numeric port is intentionally refused — the only way + // to reach boxd is via the "boxd" label. { - host: "49983-b150ee22-4956-4f5b-926a-f921ed8c37d6.sandbox.superserve.ai", - wantPort: 49983, - wantInstanceID: "b150ee22-4956-4f5b-926a-f921ed8c37d6", + host: "49983-abc.sandbox.superserve.ai", + wantErr: true, }, // Domain suffix validation { - host: "49983-abc.attacker.com", + host: "boxd-abc.attacker.com", wantErr: true, }, { - host: "49983-abc.evil.sandbox.superserve.ai.attacker.com", + host: "boxd-abc.evil.sandbox.superserve.ai.attacker.com", wantErr: true, }, { - host: "49983-abc", // no domain at all + host: "boxd-abc", // no domain at all wantErr: true, }, @@ -76,7 +89,7 @@ func TestParseHost(t *testing.T) { // Empty instance ID { - host: "49983-.sandbox.superserve.ai", + host: "boxd-.sandbox.superserve.ai", wantErr: true, }, diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 6cb06b0..fdd0e52 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -57,6 +57,19 @@ type Handler struct { sandboxConns *connLimiter ipConns *connLimiter log zerolog.Logger + + // seedKey is the HMAC seed shared with the control plane. Both + // sides derive per-sandbox access tokens as HMAC-SHA256(seed, sandboxID). + // Set via WithAuth; nil means data-plane endpoints are disabled. + seedKey []byte + + // terminal holds the dependencies specific to the /terminal WebSocket + // bridge (allowed browser origins for the Origin check). Nil means + // the /terminal path is disabled. + terminal *terminalBridgeDeps + + // filesEnabled controls whether /files on boxdPort is served. + filesEnabled bool } // NewHandler creates a proxy Handler that only accepts requests whose Host @@ -107,6 +120,18 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + // Boxd traffic (port 49983) is handled by a dedicated, token-gated + // path. Everything on that port — even outside /files — must never + // fall through to the generic reverse proxy, because that would + // expose the in-VM connect-rpc services (ProcessService, + // FilesystemService) directly to any internet caller who can guess + // an instance ID. The file bridge allowlists /files only; anything + // else on the boxd port is refused. + if port == boxdPort { + h.serveBoxdPort(w, r, instanceID) + return + } + info, err := h.resolver.Lookup(r.Context(), instanceID) if err != nil { if errors.Is(err, ErrInstanceNotFound) { @@ -156,8 +181,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { req.Host = r.Host // Strip all forwarding headers — a client could inject these to // spoof origin info that boxd or user apps might trust. + // X-Forwarded-For needs an explicit nil assignment because + // httputil.ReverseProxy re-appends it after the Director + // runs unless the slot is the nil slice. + req.Header["X-Forwarded-For"] = nil for _, h := range []string{ - "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto", "X-Real-Ip", diff --git a/internal/proxy/terminal.go b/internal/proxy/terminal.go new file mode 100644 index 0000000..9050451 --- /dev/null +++ b/internal/proxy/terminal.go @@ -0,0 +1,507 @@ +package proxy + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "syscall" + "time" + + "connectrpc.com/connect" + "connectrpc.com/otelconnect" + "github.com/coder/websocket" + "github.com/rs/zerolog" + + pb "github.com/superserve-ai/sandbox/proto/boxdpb" + "github.com/superserve-ai/sandbox/proto/boxdpb/boxdpbconnect" + + "github.com/superserve-ai/sandbox/internal/auth" +) + +// terminalBridgeDeps holds the dependencies specific to the /terminal +// WebSocket bridge. Auth is handled by the shared HMAC seed on the +// Handler; all that remains here is the browser origin allowlist. +type terminalBridgeDeps struct { + allowedOrigins []string +} + +// terminalBoxdInterceptors injects trace context into outbound Connect calls +// to boxd. No-op when telemetry is disabled (the global tracer is the SDK +// noop). otelconnect.NewInterceptor only errors on impossible configuration, +// so a panic at startup is the right escalation. +var terminalBoxdInterceptors = func() connect.Option { + i, err := otelconnect.NewInterceptor() + if err != nil { + panic("otelconnect.NewInterceptor: " + err.Error()) + } + return connect.WithInterceptors(i) +}() + +// WithAuth sets the HMAC seed used by every data-plane endpoint on the +// boxd host label (/terminal, /files). Call once at proxy startup. +func (h *Handler) WithAuth(seedKey []byte) *Handler { + if err := auth.ValidateSeed(seedKey); err != nil { + panic("proxy: " + err.Error()) + } + h.seedKey = seedKey + return h +} + +// WithTerminal enables the /terminal WebSocket bridge. Requires +// WithAuth to have been called first. +func (h *Handler) WithTerminal(allowedOrigins []string) *Handler { + if len(allowedOrigins) == 0 { + panic("proxy: WithTerminal requires at least one allowed origin (use \"*\" for dev)") + } + if h.seedKey == nil { + panic("proxy: WithTerminal requires WithAuth to be called first") + } + h.terminal = &terminalBridgeDeps{allowedOrigins: allowedOrigins} + return h +} + +// WithFiles enables the /files HTTP reverse proxy on boxdPort. Requires +// WithAuth to have been called first. +func (h *Handler) WithFiles() *Handler { + if h.seedKey == nil { + panic("proxy: WithFiles requires WithAuth to be called first") + } + h.filesEnabled = true + return h +} + +// Terminal bridge constants. +const ( + // boxdPort is the port boxd's connect-rpc HTTP server listens on + // inside each VM. Terminal sessions target the same port as the + // regular proxy path for user apps at port 49983 — boxd is the + // single HTTP endpoint exposed by the VM. + boxdPort = 49983 + + // terminalProtocol is the identifying WebSocket subprotocol echoed + // back on a successful upgrade. Bump the version if the wire format + // ever changes so older clients break loudly instead of silently + // misinterpreting frames. + terminalProtocol = "superserve.terminal.v1" + + // tokenProtocolPrefix is how clients smuggle the auth token through + // the WebSocket handshake without putting it in a URL query param. + // Browser WebSocket APIs cannot set custom headers on upgrade, but + // they CAN set the Sec-WebSocket-Protocol header via the second arg + // of `new WebSocket(url, protocols)`. We look for an entry starting + // with this prefix and treat the suffix as the signed token. The + // server never echoes this value back — only terminalProtocol — so + // the token never lands in a response header or access log. + tokenProtocolPrefix = "token." + + // maxReadBytes bounds the size of a single WebSocket frame we will + // accept from the browser. Terminal input frames are keystrokes, + // typically 1-10 bytes. 64 KiB is several orders of magnitude + // above legitimate traffic and protects boxd from a malicious + // client asking us to forward a huge SendInput payload for every + // "keystroke." + maxReadBytes = 64 * 1024 + + // writeWait is the max duration we wait for a WS write to complete + // before closing the connection. If the browser side is slow or + // unresponsive we want to free the PTY rather than block forever. + writeWait = 10 * time.Second + + // idleCloseAfter is how long the WS can sit with no traffic in + // either direction before we tear it down. Protects against zombie + // connections (user closes laptop, leaves tab open, network drops + // without FIN). Re-set on every message in either direction. + idleCloseAfter = 10 * time.Minute + + // maxSessionDuration is a hard ceiling on a single terminal session + // regardless of traffic. Bounds the blast radius of a hijacked WS + // connection — even if an attacker sends just enough traffic to + // keep the idle timer alive, the session still terminates after + // this duration. Clients should reconnect for longer-running work. + maxSessionDuration = 4 * time.Hour + + // initialTerminalCols / initialTerminalRows are the PTY dimensions + // we start the shell with. The browser will almost immediately send + // a resize message once xterm.js measures its container, so these + // values are just placeholders that minimize visual glitches during + // the ~100ms handshake. + initialTerminalCols = 80 + initialTerminalRows = 24 +) + +// wsControlMessage is the wire format for text-frame messages on the WS. +// Binary frames are raw PTY bytes; text frames carry JSON control. +// +// We use a tagged union (discriminated by Type) so future message types +// can be added additively without breaking older clients. Unknown Type +// values are logged and dropped. +type wsControlMessage struct { + Type string `json:"type"` + + // Resize fields — populated when Type == "resize". + Cols uint32 `json:"cols,omitempty"` + Rows uint32 `json:"rows,omitempty"` + + // Signal fields — populated when Type == "signal". + // Name is a POSIX signal name ("SIGINT", "SIGTERM", "SIGKILL"). + // We translate to numeric values server-side so clients don't need + // to hardcode Linux signal numbers. + Name string `json:"name,omitempty"` +} + +// serveTerminal handles /terminal requests for a sandbox addressed by +// the `boxd-{id}.{domain}` host label. The caller (serveBoxdPort) has +// already parsed the instance ID and confirmed the terminal feature is +// enabled; this function is responsible for: +// +// 1. Extracting the HMAC access token from the WebSocket subprotocol. +// 2. Verifying the token against the sandbox ID. +// 3. Resolving the VM IP via the resolver. +// 4. Upgrading the HTTP connection to a WebSocket. +// 5. Opening a connect-rpc stream to boxd ProcessService.Start. +// 6. Bridging bytes until either side closes. +// +// Errors before the upgrade are returned as standard HTTP error responses. +// Errors after the upgrade are sent as WebSocket close frames with codes +// from the coder/websocket library. +func (h *Handler) serveTerminal(w http.ResponseWriter, r *http.Request, instanceID string) { + // The token is carried in the Sec-WebSocket-Protocol header, NOT the + // URL. This keeps the token out of GCP LB access logs, browser history, + // Referer headers on sub-resources, and any middleware request logger. + // See extractTerminalToken for the parser. + // + // Defence in depth: unconditionally scrub any ?t= query param before + // the request reaches any downstream logger, in case an older client + // still sends one. + r.URL.RawQuery = "" + + // Discourage browsers from sending Referer on any sub-resource the + // handshake might spawn. Terminal upgrades don't produce sub-resources + // but the header is cheap and matches the "token is sensitive" posture. + w.Header().Set("Referrer-Policy", "no-referrer") + + token := extractTerminalToken(r) + if token == "" { + http.Error(w, "missing token (pass as Sec-WebSocket-Protocol: token.)", http.StatusUnauthorized) + return + } + + info, fail := h.authorizeSandboxRequest(r.Context(), token, instanceID) + if fail != nil { + h.log.Warn().Str("sandbox_id", instanceID).Int("status", fail.Status).Msg("terminal: auth failed") + fail.write(w) + return + } + + // From here on, errors go back through the WebSocket (if the upgrade + // succeeds) because we've committed to streaming. + // + // Origin enforcement: Origin is a CSRF-like defense that stops a + // malicious page in the user's browser from leveraging a leaked or + // coerced token (e.g. via a mint-CSRF path) to open a live terminal. + // The set is configured at proxy startup. Passing `"*"` explicitly is + // how local dev opts out; the default is deny-unknown. + // + // Subprotocols: we declare terminalProtocol as the one we'll accept + // and echo. Clients must include it in their offered list. The + // token-carrier subprotocol (token.) is NOT listed here, so + // coder/websocket will never echo it back in the handshake response. + acceptOpts := &websocket.AcceptOptions{ + OriginPatterns: h.terminal.allowedOrigins, + Subprotocols: []string{terminalProtocol}, + CompressionMode: websocket.CompressionDisabled, + } + // Explicit opt-out: a single "*" entry in allowed origins disables + // the check for dev. The config loader is responsible for only + // allowing this via an explicit env var. + if len(h.terminal.allowedOrigins) == 1 && h.terminal.allowedOrigins[0] == "*" { + acceptOpts.OriginPatterns = nil + acceptOpts.InsecureSkipVerify = true + } + + ws, err := websocket.Accept(w, r, acceptOpts) + if err != nil { + h.log.Warn().Err(err).Msg("terminal: WS upgrade failed") + return + } + + // Bound the size of any single input frame. Keystrokes are tiny; + // anything approaching 64 KiB is either a paste (legitimate but + // still bounded) or an attack trying to amplify into SendInput. + ws.SetReadLimit(maxReadBytes) + + // Build the connect-rpc client to boxd. We use the transport cache so + // the bridge benefits from the same lifecycle keying and connection + // pooling as the generic proxy path. + transport := h.transports.get(instanceID, info) + httpClient := &http.Client{Transport: transport} + baseURL := fmt.Sprintf("http://%s:%d", info.VMIP, boxdPort) + procClient := boxdpbconnect.NewProcessServiceClient(httpClient, baseURL, terminalBoxdInterceptors) + + // Tie the bridge lifetime to the request context so shutdowns + // propagate cleanly. The WS will be closed in bridgeTerminal. + ctx := r.Context() + h.bridgeTerminal(ctx, ws, procClient, instanceID) +} + +// bridgeTerminal is the long-lived function that pumps bytes between the +// WebSocket and boxd until one side closes. It owns the WS handle and is +// responsible for closing it on the way out. +// +// Design: two goroutines, one per direction, plus the main goroutine that +// waits for either to finish. When either direction errors, we cancel the +// shared context and the other direction sees its read/write return, then +// exits. This is the simplest correct pattern for bidirectional bridging. +func (h *Handler) bridgeTerminal(ctx context.Context, ws *websocket.Conn, procClient boxdpbconnect.ProcessServiceClient, instanceID string) { + l := h.log.With(). + Str("sandbox_id", instanceID). + Logger() + + // Scoped context so either direction's failure cancels everything. + bridgeCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Start a shell in PTY mode. Initial size is a placeholder — the + // browser will resize immediately after mount. + startReq := connect.NewRequest(&pb.StartRequest{ + Cmd: "/bin/bash", + Pty: &pb.PtyConfig{ + Size: &pb.TerminalSize{ + Cols: initialTerminalCols, + Rows: initialTerminalRows, + }, + }, + }) + stream, err := procClient.Start(bridgeCtx, startReq) + if err != nil { + l.Error().Err(err).Msg("terminal: boxd Start failed") + _ = ws.Close(websocket.StatusInternalError, "failed to start shell") + return + } + + // Read the first event — must be a StartEvent so we learn the PID + // (needed for SendInput / Resize / Signal which all address the + // process by PID). + if !stream.Receive() { + l.Error().Err(stream.Err()).Msg("terminal: boxd stream empty on start") + _ = ws.Close(websocket.StatusInternalError, "shell did not start") + return + } + startEvent := stream.Msg().GetStart() + if startEvent == nil { + l.Error().Msg("terminal: first event was not StartEvent") + _ = ws.Close(websocket.StatusInternalError, "unexpected event") + return + } + pid := startEvent.GetPid() + l.Info().Uint32("pid", pid).Msg("terminal: bridge established") + + // Idle timer — reset on every message in either direction. Tears + // down sessions that have gone quiet for idleCloseAfter. + var idleMu sync.Mutex + lastActivity := time.Now() + touchIdle := func() { + idleMu.Lock() + lastActivity = time.Now() + idleMu.Unlock() + } + + // Hard session deadline — independent of activity. Bounds the + // blast radius of a hijacked WS connection regardless of how + // chatty the attacker is. Clients should reconnect for sessions + // longer than this. + sessionDeadline := time.NewTimer(maxSessionDuration) + defer sessionDeadline.Stop() + + go func() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + for { + select { + case <-bridgeCtx.Done(): + return + case <-sessionDeadline.C: + l.Info().Dur("max", maxSessionDuration).Msg("terminal: max session reached, closing") + cancel() + return + case <-ticker.C: + idleMu.Lock() + last := lastActivity + idleMu.Unlock() + if time.Since(last) > idleCloseAfter { + l.Info().Msg("terminal: idle timeout, closing") + cancel() + return + } + } + } + }() + + var wg sync.WaitGroup + wg.Add(2) + + // ------- boxd → browser ------- + // Read ProcessEvents from the connect-rpc stream; write PtyData as + // binary WebSocket frames. Non-PTY events (stdout/stderr outside + // PTY mode, End, Keepalive) are ignored — we asked for PTY so + // everything interactive comes through PtyData. + go func() { + defer wg.Done() + defer cancel() + for stream.Receive() { + msg := stream.Msg() + if d := msg.GetData(); d != nil { + if pty := d.GetPtyData(); len(pty) > 0 { + wctx, wcancel := context.WithTimeout(bridgeCtx, writeWait) + if err := ws.Write(wctx, websocket.MessageBinary, pty); err != nil { + wcancel() + if !errors.Is(err, context.Canceled) { + l.Debug().Err(err).Msg("terminal: WS write failed") + } + return + } + wcancel() + touchIdle() + } + } + if e := msg.GetEnd(); e != nil { + // Shell exited — close the WS with a clean + // code so xterm.js can show "session ended". + l.Info().Int32("exit_code", e.GetExitCode()).Msg("terminal: shell exited") + _ = ws.Close(websocket.StatusNormalClosure, "shell exited") + return + } + } + if err := stream.Err(); err != nil && !errors.Is(err, context.Canceled) { + l.Warn().Err(err).Msg("terminal: boxd stream error") + } + }() + + // ------- browser → boxd ------- + // Read WS frames. Binary frames are PTY input (forward as SendInput + // with the captured PID). Text frames are control JSON. + go func() { + defer wg.Done() + defer cancel() + for { + typ, data, err := ws.Read(bridgeCtx) + if err != nil { + if !errors.Is(err, context.Canceled) { + // Normal close is not an error. + closeErr := websocket.CloseStatus(err) + if closeErr != websocket.StatusNormalClosure && closeErr != websocket.StatusGoingAway { + l.Debug().Err(err).Msg("terminal: WS read ended") + } + } + return + } + touchIdle() + + switch typ { + case websocket.MessageBinary: + _, err := procClient.SendInput(bridgeCtx, connect.NewRequest(&pb.SendInputRequest{ + Pid: pid, + Data: data, + })) + if err != nil { + l.Warn().Err(err).Msg("terminal: boxd SendInput failed") + return + } + case websocket.MessageText: + h.handleControlMessage(bridgeCtx, procClient, pid, data, l) + } + } + }() + + wg.Wait() + _ = ws.Close(websocket.StatusNormalClosure, "bridge closed") +} + +// handleControlMessage parses a text frame as a wsControlMessage and +// dispatches to the appropriate boxd RPC. Unknown types are logged and +// dropped rather than crashing the bridge — a client speaking a newer +// protocol version shouldn't tear down the session. +func (h *Handler) handleControlMessage(ctx context.Context, client boxdpbconnect.ProcessServiceClient, pid uint32, data []byte, l zerolog.Logger) { + var msg wsControlMessage + if err := json.Unmarshal(data, &msg); err != nil { + l.Warn().Err(err).Msg("terminal: bad control JSON") + return + } + + switch msg.Type { + case "resize": + if msg.Cols == 0 || msg.Rows == 0 { + l.Warn().Msg("terminal: resize with zero dims") + return + } + _, err := client.Resize(ctx, connect.NewRequest(&pb.ResizeRequest{ + Pid: pid, + Size: &pb.TerminalSize{Cols: msg.Cols, Rows: msg.Rows}, + })) + if err != nil { + l.Warn().Err(err).Msg("terminal: boxd Resize failed") + } + + case "signal": + signum, ok := signalNameToNumber(msg.Name) + if !ok { + l.Warn().Str("name", msg.Name).Msg("terminal: unknown signal name") + return + } + _, err := client.Signal(ctx, connect.NewRequest(&pb.SignalRequest{ + Pid: pid, + Signal: signum, + })) + if err != nil { + l.Warn().Err(err).Msg("terminal: boxd Signal failed") + } + + default: + l.Debug().Str("type", msg.Type).Msg("terminal: unknown control type") + } +} + +// extractTerminalToken pulls the signed token out of the +// Sec-WebSocket-Protocol header. Clients include one entry of the form +// `token.` alongside the main terminalProtocol entry. This keeps +// the token out of URLs, logs, and referrers. +// +// Multiple entries per header line are comma-separated per RFC 6455; we +// also accept multiple header values. Entries are trimmed and compared +// case-sensitively (the prefix is ASCII, no folding needed). +// +// Returns "" if no token entry is found. The caller logs and rejects. +func extractTerminalToken(r *http.Request) string { + for _, hv := range r.Header.Values("Sec-WebSocket-Protocol") { + for _, part := range strings.Split(hv, ",") { + p := strings.TrimSpace(part) + if strings.HasPrefix(p, tokenProtocolPrefix) { + return strings.TrimPrefix(p, tokenProtocolPrefix) + } + } + } + return "" +} + +// signalNameToNumber maps POSIX signal names to their numeric values. +// Limited to signals a terminal user legitimately needs — we don't want +// browsers sending SIGKILL/SIGSTOP willy-nilly even though boxd would +// accept them. +func signalNameToNumber(name string) (int32, bool) { + switch name { + case "SIGINT": + return int32(syscall.SIGINT), true + case "SIGTERM": + return int32(syscall.SIGTERM), true + case "SIGHUP": + return int32(syscall.SIGHUP), true + case "SIGQUIT": + return int32(syscall.SIGQUIT), true + } + return 0, false +} + diff --git a/internal/proxy/terminal_test.go b/internal/proxy/terminal_test.go new file mode 100644 index 0000000..fb6c520 --- /dev/null +++ b/internal/proxy/terminal_test.go @@ -0,0 +1,475 @@ +package proxy + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/coder/websocket" + "github.com/rs/zerolog" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + + pb "github.com/superserve-ai/sandbox/proto/boxdpb" + "github.com/superserve-ai/sandbox/proto/boxdpb/boxdpbconnect" +) + +// fakeProcessService is a connect-rpc ProcessService implementation used by +// the bridge tests. It gives tests fine control over what the boxd-side of +// the bridge sees: callers can inject events to be emitted from Start, and +// every SendInput/Resize/Signal call is captured on channels for assertions. +// +// All operations are safe for concurrent use — the bridge's two goroutines +// call these methods from different goroutines. +type fakeProcessService struct { + boxdpbconnect.UnimplementedProcessServiceHandler + + // events is the queue of ProcessEvents the Start stream should emit + // to the client. Tests push events on this channel; Start drains it. + events chan *pb.ProcessEvent + + // inputs captures every SendInput call. Tests assert on length/content. + inputs chan *pb.SendInputRequest + + // resizes captures every Resize call. + resizes chan *pb.ResizeRequest + + // signals captures every Signal call. + signals chan *pb.SignalRequest + + // startErr, if set, is returned from Start immediately — lets tests + // drive the "boxd failed to start shell" path. + startErr error + + // sendInputErr, if set, is returned from SendInput — drives the + // "boxd errored mid-stream" path. + sendInputErr error +} + +func newFakeProcessService() *fakeProcessService { + return &fakeProcessService{ + events: make(chan *pb.ProcessEvent, 16), + inputs: make(chan *pb.SendInputRequest, 16), + resizes: make(chan *pb.ResizeRequest, 16), + signals: make(chan *pb.SignalRequest, 16), + } +} + +func (f *fakeProcessService) Start(ctx context.Context, req *connect.Request[pb.StartRequest], stream *connect.ServerStream[pb.ProcessEvent]) error { + if f.startErr != nil { + return f.startErr + } + for { + select { + case <-ctx.Done(): + return nil + case ev, ok := <-f.events: + if !ok { + return nil + } + if err := stream.Send(ev); err != nil { + return err + } + } + } +} + +func (f *fakeProcessService) SendInput(ctx context.Context, req *connect.Request[pb.SendInputRequest]) (*connect.Response[pb.SendInputResponse], error) { + if f.sendInputErr != nil { + return nil, f.sendInputErr + } + f.inputs <- req.Msg + return connect.NewResponse(&pb.SendInputResponse{}), nil +} + +func (f *fakeProcessService) Resize(ctx context.Context, req *connect.Request[pb.ResizeRequest]) (*connect.Response[pb.ResizeResponse], error) { + f.resizes <- req.Msg + return connect.NewResponse(&pb.ResizeResponse{}), nil +} + +func (f *fakeProcessService) Signal(ctx context.Context, req *connect.Request[pb.SignalRequest]) (*connect.Response[pb.SignalResponse], error) { + f.signals <- req.Msg + return connect.NewResponse(&pb.SignalResponse{}), nil +} + +// startEvent is a helper to push a StartEvent with the given PID onto the +// fake's event queue. This is the first event the bridge expects. +func (f *fakeProcessService) pushStart(pid uint32) { + f.events <- &pb.ProcessEvent{ + Event: &pb.ProcessEvent_Start{Start: &pb.StartEvent{Pid: pid}}, + } +} + +// pushPty enqueues a PtyData event — these become binary WS frames to the browser. +func (f *fakeProcessService) pushPty(data []byte) { + f.events <- &pb.ProcessEvent{ + Event: &pb.ProcessEvent_Data{ + Data: &pb.DataEvent{ + Output: &pb.DataEvent_PtyData{PtyData: data}, + }, + }, + } +} + +// pushEnd enqueues an EndEvent — signals the shell exited, should close the WS. +func (f *fakeProcessService) pushEnd(code int32) { + f.events <- &pb.ProcessEvent{ + Event: &pb.ProcessEvent_End{End: &pb.EndEvent{ExitCode: code}}, + } +} + +// --------------------------------------------------------------------------- +// Test harness +// --------------------------------------------------------------------------- + +// bridgeTestEnv wires up everything a bridge test needs: a fake boxd, a +// proxy handler that pipes a WS upgrade directly into bridgeTerminal, and a +// WS client dialled against the proxy. Tests just drive the fake and assert +// on what the client sees. +type bridgeTestEnv struct { + t *testing.T + fake *fakeProcessService + boxdSrv *httptest.Server + proxySrv *httptest.Server + clientWS *websocket.Conn + procClient boxdpbconnect.ProcessServiceClient +} + +func newBridgeTestEnv(t *testing.T) *bridgeTestEnv { + t.Helper() + fake := newFakeProcessService() + + // Fake boxd — mount the connect-rpc handler on httptest with h2c so + // connect streaming works over HTTP/2 without TLS setup. + path, handler := boxdpbconnect.NewProcessServiceHandler(fake) + boxdMux := http.NewServeMux() + boxdMux.Handle(path, handler) + boxdSrv := httptest.NewUnstartedServer(h2c.NewHandler(boxdMux, &http2.Server{})) + boxdSrv.EnableHTTP2 = true + boxdSrv.Start() + + // A connect-rpc client pointing at the fake. Uses an explicit + // http.Client with HTTP/2 transport so streaming flows correctly. + procClient := boxdpbconnect.NewProcessServiceClient( + boxdSrv.Client(), + boxdSrv.URL, + ) + + // Proxy handler that accepts a WS upgrade and hands it to the bridge. + h := &Handler{ + transports: newTransportCache(), + log: zerolog.Nop(), + } + + proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + InsecureSkipVerify: true, + CompressionMode: websocket.CompressionDisabled, + }) + if err != nil { + t.Errorf("ws accept: %v", err) + return + } + h.bridgeTerminal(r.Context(), ws, procClient, "sbx-test") + })) + + // Dial the proxy. + wsURL := "ws" + strings.TrimPrefix(proxySrv.URL, "http") + dialCtx, dialCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer dialCancel() + clientWS, _, err := websocket.Dial(dialCtx, wsURL, nil) + if err != nil { + t.Fatalf("ws dial: %v", err) + } + + env := &bridgeTestEnv{ + t: t, + fake: fake, + boxdSrv: boxdSrv, + proxySrv: proxySrv, + clientWS: clientWS, + procClient: procClient, + } + t.Cleanup(env.close) + return env +} + +func (e *bridgeTestEnv) close() { + _ = e.clientWS.Close(websocket.StatusNormalClosure, "test cleanup") + e.proxySrv.Close() + e.boxdSrv.Close() +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +// TestBridge_BinaryInputForwardedToSendInput verifies that a binary WS frame +// from the client arrives at boxd as a SendInput call carrying the same +// bytes and the captured PID. +func TestBridge_BinaryInputForwardedToSendInput(t *testing.T) { + env := newBridgeTestEnv(t) + env.fake.pushStart(42) + + // Give the bridge a moment to receive the StartEvent before we send input. + time.Sleep(100 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := env.clientWS.Write(ctx, websocket.MessageBinary, []byte("ls -la\n")); err != nil { + t.Fatalf("client write: %v", err) + } + + select { + case got := <-env.fake.inputs: + if string(got.Data) != "ls -la\n" { + t.Errorf("SendInput.Data = %q, want %q", got.Data, "ls -la\n") + } + if got.Pid != 42 { + t.Errorf("SendInput.Pid = %d, want 42", got.Pid) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for SendInput") + } +} + +// TestBridge_PtyDataForwardedToClient verifies that a PtyData event from +// boxd becomes a binary WS frame on the client side. +func TestBridge_PtyDataForwardedToClient(t *testing.T) { + env := newBridgeTestEnv(t) + env.fake.pushStart(42) + env.fake.pushPty([]byte("hello terminal\n")) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + typ, data, err := env.clientWS.Read(ctx) + if err != nil { + t.Fatalf("client read: %v", err) + } + if typ != websocket.MessageBinary { + t.Errorf("type = %v, want Binary", typ) + } + if string(data) != "hello terminal\n" { + t.Errorf("data = %q, want %q", data, "hello terminal\n") + } +} + +// TestBridge_ResizeControlMessage verifies that a text frame carrying a +// resize message dispatches to the Resize RPC with the correct dimensions. +func TestBridge_ResizeControlMessage(t *testing.T) { + env := newBridgeTestEnv(t) + env.fake.pushStart(42) + time.Sleep(100 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := env.clientWS.Write(ctx, websocket.MessageText, []byte(`{"type":"resize","cols":120,"rows":30}`)); err != nil { + t.Fatalf("client write: %v", err) + } + + select { + case got := <-env.fake.resizes: + if got.Pid != 42 { + t.Errorf("Resize.Pid = %d, want 42", got.Pid) + } + if got.Size.Cols != 120 || got.Size.Rows != 30 { + t.Errorf("Resize.Size = {%d,%d}, want {120,30}", got.Size.Cols, got.Size.Rows) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for Resize") + } +} + +// TestBridge_SignalAllowlist verifies that only the allowed signals get +// forwarded. A SIGKILL attempt from the browser should be dropped. +func TestBridge_SignalAllowlist(t *testing.T) { + env := newBridgeTestEnv(t) + env.fake.pushStart(42) + time.Sleep(100 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Allowed — should arrive. + _ = env.clientWS.Write(ctx, websocket.MessageText, []byte(`{"type":"signal","name":"SIGINT"}`)) + select { + case got := <-env.fake.signals: + if got.Signal != 2 { // SIGINT + t.Errorf("Signal = %d, want 2 (SIGINT)", got.Signal) + } + case <-time.After(2 * time.Second): + t.Fatal("allowed SIGINT did not arrive") + } + + // Blocked — should be dropped. We assert by checking nothing arrives + // within a short window. + _ = env.clientWS.Write(ctx, websocket.MessageText, []byte(`{"type":"signal","name":"SIGKILL"}`)) + select { + case got := <-env.fake.signals: + t.Errorf("SIGKILL should have been blocked, got signal %d", got.Signal) + case <-time.After(250 * time.Millisecond): + // Expected — nothing arrived. + } +} + +// TestBridge_ShellExitClosesWS verifies that when boxd emits an EndEvent +// the bridge closes the WebSocket with a normal close code. +func TestBridge_ShellExitClosesWS(t *testing.T) { + env := newBridgeTestEnv(t) + env.fake.pushStart(42) + env.fake.pushEnd(0) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Read until we get the close — the client may see the PTY flush + // first, then the close. + for { + _, _, err := env.clientWS.Read(ctx) + if err == nil { + continue + } + status := websocket.CloseStatus(err) + if status != websocket.StatusNormalClosure { + t.Errorf("close status = %d, want NormalClosure", status) + } + return + } +} + +// TestBridge_ClientCloseStopsBoxdStream verifies that closing the WS from +// the client side causes the bridge to cancel its upstream connect-rpc +// stream, so boxd doesn't leak goroutines. +func TestBridge_ClientCloseStopsBoxdStream(t *testing.T) { + env := newBridgeTestEnv(t) + env.fake.pushStart(42) + // Let the bridge receive the start event and enter the pump loop. + time.Sleep(100 * time.Millisecond) + + _ = env.clientWS.Close(websocket.StatusNormalClosure, "bye") + + // After the close propagates, the fake's Start call context should + // be cancelled. We observe this by checking that pushing a new event + // eventually blocks (no consumer) — using a very short timeout plus + // a best-effort push. + done := make(chan struct{}) + go func() { + select { + case env.fake.events <- &pb.ProcessEvent{ + Event: &pb.ProcessEvent_Data{ + Data: &pb.DataEvent{Output: &pb.DataEvent_PtyData{PtyData: []byte("ignored")}}, + }, + }: + default: + } + close(done) + }() + <-done + + // If the bridge is still alive, it would have consumed the event. + // We can't directly introspect goroutine state, but subsequent WS + // operations should fail — which is asserted implicitly by the WS + // already being closed. This is a smoke check. +} + +// TestBridge_StartErrorClosesWSImmediately verifies the early-failure path: +// if boxd rejects Start, the WS should be closed before any bytes flow. +func TestBridge_StartErrorClosesWSImmediately(t *testing.T) { + fake := newFakeProcessService() + fake.startErr = errors.New("boxd: start failed") + + path, handler := boxdpbconnect.NewProcessServiceHandler(fake) + boxdMux := http.NewServeMux() + boxdMux.Handle(path, handler) + boxdSrv := httptest.NewUnstartedServer(h2c.NewHandler(boxdMux, &http2.Server{})) + boxdSrv.EnableHTTP2 = true + boxdSrv.Start() + defer boxdSrv.Close() + + procClient := boxdpbconnect.NewProcessServiceClient(boxdSrv.Client(), boxdSrv.URL) + + h := &Handler{transports: newTransportCache(), log: zerolog.Nop()} + proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{InsecureSkipVerify: true}) + if err != nil { + return + } + h.bridgeTerminal(r.Context(), ws, procClient, "sbx") + })) + defer proxySrv.Close() + + wsURL := "ws" + strings.TrimPrefix(proxySrv.URL, "http") + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + client, _, err := websocket.Dial(ctx, wsURL, nil) + if err != nil { + t.Fatalf("ws dial: %v", err) + } + + // Expect the WS to be closed with InternalError because Start failed. + _, _, err = client.Read(ctx) + if err == nil { + t.Fatal("expected read error after Start failure") + } + if status := websocket.CloseStatus(err); status != websocket.StatusInternalError { + t.Errorf("close status = %d, want InternalError", status) + } +} + +// TestBridge_ConcurrentInputOutput exercises the two-goroutine pump under +// simultaneous traffic in both directions — catches races in shared state +// or ordering assumptions. +func TestBridge_ConcurrentInputOutput(t *testing.T) { + env := newBridgeTestEnv(t) + env.fake.pushStart(42) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Emit 20 PtyData events from boxd. + go func() { + for i := 0; i < 20; i++ { + env.fake.pushPty([]byte("line\n")) + } + }() + + // Write 20 input frames from the client. + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 20; i++ { + if err := env.clientWS.Write(ctx, websocket.MessageBinary, []byte("k")); err != nil { + return + } + } + }() + + // Read 20 messages from the client side. + received := 0 + for received < 20 { + _, _, err := env.clientWS.Read(ctx) + if err != nil { + t.Fatalf("read %d: %v", received, err) + } + received++ + } + + // Verify 20 SendInputs arrived at boxd. + wg.Wait() + inputsSeen := 0 + for inputsSeen < 20 { + select { + case <-env.fake.inputs: + inputsSeen++ + case <-time.After(2 * time.Second): + t.Fatalf("only %d of 20 inputs arrived", inputsSeen) + } + } +} diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go new file mode 100644 index 0000000..dfd4a9f --- /dev/null +++ b/internal/scheduler/scheduler.go @@ -0,0 +1,110 @@ +package scheduler + +import ( + "context" + "fmt" + "math/rand/v2" + "sync" + "time" + + "github.com/superserve-ai/sandbox/internal/db" +) + +// Scheduler selects a host for a new sandbox. +type Scheduler interface { + SelectHost(ctx context.Context) (hostID string, err error) +} + +const defaultCacheTTL = 30 * time.Second + +// LeastLoaded picks the active host with the fewest running sandboxes +// using the "power of two random choices" algorithm. Instead of always +// picking the globally least-loaded host (which causes thundering herd +// when many creates arrive simultaneously), it samples two random hosts +// from the active set and picks the one with fewer sandboxes. +// +// With one host this degenerates to always picking that host. With two +// or more it spreads load naturally without coordination. The algorithm +// is proven to reduce max load from O(log n / log log n) to O(log log n). +// +// If no host rows exist in the table, SelectHost falls back to +// DefaultHostID so sandbox creation works without populating the host table. +type LeastLoaded struct { + DB *db.Queries + DefaultHostID string // fallback when no host rows exist + TTL time.Duration // 0 = use defaultCacheTTL + + mu sync.RWMutex + cached []db.ListActiveHostsByLoadRow + cachedAt time.Time +} + +func (s *LeastLoaded) ttl() time.Duration { + if s.TTL > 0 { + return s.TTL + } + return defaultCacheTTL +} + +func (s *LeastLoaded) SelectHost(ctx context.Context) (string, error) { + hosts, err := s.loadHosts(ctx) + if err != nil { + return "", err + } + if len(hosts) == 0 { + if s.DefaultHostID != "" { + return s.DefaultHostID, nil + } + return "", fmt.Errorf("no active hosts available") + } + if len(hosts) == 1 { + return hosts[0].ID, nil + } + + // Power of two random choices: pick two random hosts, return the + // one with fewer active sandboxes. This avoids the thundering-herd + // problem where every concurrent create picks the same least-loaded + // host from a globally-sorted list. + a := rand.IntN(len(hosts)) + b := rand.IntN(len(hosts) - 1) + if b >= a { + b++ // ensures b != a + } + if hosts[a].ActiveSandboxCount <= hosts[b].ActiveSandboxCount { + return hosts[a].ID, nil + } + return hosts[b].ID, nil +} + +func (s *LeastLoaded) loadHosts(ctx context.Context) ([]db.ListActiveHostsByLoadRow, error) { + s.mu.RLock() + if s.cached != nil && time.Since(s.cachedAt) < s.ttl() { + hosts := s.cached + s.mu.RUnlock() + return hosts, nil + } + s.mu.RUnlock() + + s.mu.Lock() + defer s.mu.Unlock() + if s.cached != nil && time.Since(s.cachedAt) < s.ttl() { + return s.cached, nil + } + + hosts, err := s.DB.ListActiveHostsByLoad(ctx) + if err != nil { + return nil, fmt.Errorf("list active hosts by load: %w", err) + } + s.cached = hosts + s.cachedAt = time.Now() + return hosts, nil +} + +// Invalidate drops the cached host list so the next SelectHost reflects +// changes immediately. +func (s *LeastLoaded) Invalidate() { + s.mu.Lock() + s.cached = nil + s.cachedAt = time.Time{} + s.mu.Unlock() +} diff --git a/internal/telemetry/client.go b/internal/telemetry/client.go new file mode 100644 index 0000000..3f91361 --- /dev/null +++ b/internal/telemetry/client.go @@ -0,0 +1,147 @@ +// Package telemetry wires OpenTelemetry traces, metrics, and logs through a +// single Client. All three signals export via OTLP/gRPC to the endpoint named +// by OTEL_EXPORTER_OTLP_ENDPOINT. When that env var is unset the Client is a +// no-op so local dev and tests have zero telemetry overhead. +// +// Usage from a binary's main: +// +// tel, err := telemetry.New(ctx, "controlplane", version, nodeID) +// if err != nil { return err } +// defer tel.Shutdown(context.Background()) +// +// The Client installs global tracer/meter/log providers, so downstream code +// uses otel.Tracer(...) / otel.Meter(...) without holding a reference. +package telemetry + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/google/uuid" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/log/global" + "go.opentelemetry.io/otel/propagation" + sdklog "go.opentelemetry.io/otel/sdk/log" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + sdkresource "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" +) + +// EndpointEnv is the env var consulted by Client.New. When empty the client +// is a no-op. +const EndpointEnv = "OTEL_EXPORTER_OTLP_ENDPOINT" + +// Client owns the SDK providers for one process. Hold the pointer for the +// lifetime of the binary and call Shutdown before exit. +type Client struct { + TracerProvider *sdktrace.TracerProvider + MeterProvider *sdkmetric.MeterProvider + LoggerProvider *sdklog.LoggerProvider + + enabled bool +} + +// New initialises the telemetry providers. Returns a no-op Client when +// OTEL_EXPORTER_OTLP_ENDPOINT is unset; the returned Client is always safe +// to use and to Shutdown. +func New(ctx context.Context, serviceName, serviceVersion, nodeID string) (*Client, error) { + endpoint := os.Getenv(EndpointEnv) + if endpoint == "" { + return &Client{}, nil + } + + res, err := buildResource(ctx, serviceName, serviceVersion, nodeID) + if err != nil { + return nil, fmt.Errorf("build resource: %w", err) + } + + tp, err := newTracerProvider(ctx, res) + if err != nil { + return nil, fmt.Errorf("tracer provider: %w", err) + } + mp, err := newMeterProvider(ctx, res) + if err != nil { + _ = tp.Shutdown(ctx) + return nil, fmt.Errorf("meter provider: %w", err) + } + lp, err := newLoggerProvider(ctx, res) + if err != nil { + _ = tp.Shutdown(ctx) + _ = mp.Shutdown(ctx) + return nil, fmt.Errorf("logger provider: %w", err) + } + + otel.SetTracerProvider(tp) + otel.SetMeterProvider(mp) + global.SetLoggerProvider(lp) + otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator( + propagation.TraceContext{}, + propagation.Baggage{}, + )) + + return &Client{ + TracerProvider: tp, + MeterProvider: mp, + LoggerProvider: lp, + enabled: true, + }, nil +} + +// Enabled reports whether the Client is exporting (i.e. EndpointEnv was set). +func (c *Client) Enabled() bool { return c != nil && c.enabled } + +// Shutdown flushes and closes all providers. Safe to call on a no-op Client. +// Uses a bounded internal timeout so a stuck collector cannot hang process +// exit. +func (c *Client) Shutdown(ctx context.Context) error { + if c == nil || !c.enabled { + return nil + } + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + var firstErr error + if err := c.TracerProvider.Shutdown(ctx); err != nil && firstErr == nil { + firstErr = fmt.Errorf("tracer shutdown: %w", err) + } + if err := c.MeterProvider.Shutdown(ctx); err != nil && firstErr == nil { + firstErr = fmt.Errorf("meter shutdown: %w", err) + } + if err := c.LoggerProvider.Shutdown(ctx); err != nil && firstErr == nil { + firstErr = fmt.Errorf("logger shutdown: %w", err) + } + return firstErr +} + +func buildResource(_ context.Context, serviceName, serviceVersion, nodeID string) (*sdkresource.Resource, error) { + if nodeID == "" { + nodeID = os.Getenv("NODE_ID") + } + if nodeID == "" { + if h, err := os.Hostname(); err == nil { + nodeID = h + } else { + nodeID = "unknown" + } + } + env := os.Getenv("ENVIRONMENT") + if env == "" { + env = "dev" + } + + attrs := []attribute.KeyValue{ + semconv.ServiceName(serviceName), + semconv.ServiceVersion(serviceVersion), + semconv.ServiceInstanceID(uuid.NewString()), + semconv.HostID(nodeID), + semconv.DeploymentEnvironment(env), + } + return sdkresource.Merge( + sdkresource.Default(), + sdkresource.NewWithAttributes(semconv.SchemaURL, attrs...), + ) +} diff --git a/internal/telemetry/client_test.go b/internal/telemetry/client_test.go new file mode 100644 index 0000000..5ce5067 --- /dev/null +++ b/internal/telemetry/client_test.go @@ -0,0 +1,35 @@ +package telemetry + +import ( + "context" + "testing" +) + +// TestNewNoOpWhenEndpointUnset is the contract the rest of the codebase +// relies on: when OTEL_EXPORTER_OTLP_ENDPOINT is not set the constructor +// must succeed and Shutdown must be a no-op. Local dev and CI depend on +// this so they incur zero telemetry overhead. +func TestNewNoOpWhenEndpointUnset(t *testing.T) { + t.Setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "") + + c, err := New(context.Background(), "test", "v0", "node-1") + if err != nil { + t.Fatalf("New returned error with endpoint unset: %v", err) + } + if c == nil { + t.Fatal("New returned nil client") + } + if c.Enabled() { + t.Error("Client.Enabled() should be false when endpoint unset") + } + if err := c.Shutdown(context.Background()); err != nil { + t.Errorf("Shutdown on no-op client returned error: %v", err) + } +} + +func TestShutdownNilClient(t *testing.T) { + var c *Client + if err := c.Shutdown(context.Background()); err != nil { + t.Errorf("Shutdown on nil client returned error: %v", err) + } +} diff --git a/internal/telemetry/helpers.go b/internal/telemetry/helpers.go new file mode 100644 index 0000000..e703e40 --- /dev/null +++ b/internal/telemetry/helpers.go @@ -0,0 +1,41 @@ +package telemetry + +import ( + "context" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +// ReportEvent attaches a named event with attributes to the active span in +// ctx. No-op when no span is active. Use this at lifecycle boundaries +// (snapshot restored, network attached, VM resumed) — it shows up in the +// trace UI as a clickable marker on the timeline. +func ReportEvent(ctx context.Context, name string, attrs ...attribute.KeyValue) { + span := trace.SpanFromContext(ctx) + if !span.IsRecording() { + return + } + span.AddEvent(name, trace.WithAttributes(attrs...)) +} + +// ReportError records err on the active span and marks the span as failed. +// Returns err unchanged so callers can write `return telemetry.ReportError(...)`. +// No-op when no span is active. +func ReportError(ctx context.Context, msg string, err error, attrs ...attribute.KeyValue) error { + span := trace.SpanFromContext(ctx) + if span.IsRecording() { + span.RecordError(err, trace.WithAttributes(attrs...)) + span.SetStatus(codes.Error, msg) + } + return err +} + +// SetAttrs sets attributes on the active span. No-op when no span is active. +func SetAttrs(ctx context.Context, attrs ...attribute.KeyValue) { + span := trace.SpanFromContext(ctx) + if span.IsRecording() { + span.SetAttributes(attrs...) + } +} diff --git a/internal/telemetry/logs.go b/internal/telemetry/logs.go new file mode 100644 index 0000000..4083bc1 --- /dev/null +++ b/internal/telemetry/logs.go @@ -0,0 +1,26 @@ +package telemetry + +import ( + "context" + + "go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc" + sdklog "go.opentelemetry.io/otel/sdk/log" + sdkresource "go.opentelemetry.io/otel/sdk/resource" +) + +// newLoggerProvider builds an OTLP/gRPC log provider with a batch processor. +// Wired but not yet bridged from zerolog — Phase 1.5 will add that bridge so +// logs flow alongside traces in the same backend. For now logs continue to +// go to stdout/journald and the trace_id/span_id hook (zerolog.go) lets +// operators correlate manually. +func newLoggerProvider(ctx context.Context, res *sdkresource.Resource) (*sdklog.LoggerProvider, error) { + exp, err := otlploggrpc.New(ctx) + if err != nil { + return nil, err + } + lp := sdklog.NewLoggerProvider( + sdklog.WithResource(res), + sdklog.WithProcessor(sdklog.NewBatchProcessor(exp)), + ) + return lp, nil +} diff --git a/internal/telemetry/meters.go b/internal/telemetry/meters.go new file mode 100644 index 0000000..84192c5 --- /dev/null +++ b/internal/telemetry/meters.go @@ -0,0 +1,234 @@ +package telemetry + +import ( + "context" + "sync" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +// Cardinality discipline — read this before adding a metric. +// +// FORBIDDEN labels on counters and histograms: +// - sandbox_id +// - team_id / user_id +// - api_key_id +// - request_id / trace_id +// - any other unbounded identifier +// +// These belong on traces and logs, not metrics. One rogue label explodes +// backend storage cost and makes dashboards unusable. Per-sandbox dimensions +// are handled via delta-temporality observable gauges (Phase 2 sandbox +// observer) where the series dies when the sandbox dies. +// +// ALLOWED label dimensions (small bounded enums): +// - service, host_id, environment +// - state, outcome (ok|error|timeout) +// - kind, op, method, code +// - status_class (2xx|4xx|5xx) — never raw HTTP status codes as labels +// +// Code sites import this package and call e.g. +// telemetry.SandboxLifecycleDuration().Record(ctx, secs, ...attrs). + +const meterName = "github.com/superserve-ai/sandbox" + +var ( + initOnce sync.Once + + sandboxLifecycleDuration metric.Float64Histogram + snapshotOpDuration metric.Float64Histogram + snapshotOrphanGCTotal metric.Int64Counter + reconcilerRunDuration metric.Float64Histogram + reconcilerDriftTotal metric.Int64Counter + reaperReapedTotal metric.Int64Counter + proxyHMACFailuresTotal metric.Int64Counter + dataplaneAuthFailures metric.Int64Counter +) + +// initMeters lazily builds the metric handles against the global meter +// provider. Safe to call repeatedly; runs once. +func initMeters() { + initOnce.Do(func() { + m := otel.Meter(meterName) + + sandboxLifecycleDuration, _ = m.Float64Histogram( + "sandbox.lifecycle.duration", + metric.WithUnit("s"), + metric.WithDescription("End-to-end duration of a sandbox lifecycle operation."), + ) + snapshotOpDuration, _ = m.Float64Histogram( + "sandbox.snapshot.duration", + metric.WithUnit("s"), + metric.WithDescription("Duration of a snapshot operation (create, restore, delete, gc)."), + ) + snapshotOrphanGCTotal, _ = m.Int64Counter( + "sandbox.snapshot.orphan_gc.total", + metric.WithDescription("Snapshots reclaimed by the orphan GC."), + ) + reconcilerRunDuration, _ = m.Float64Histogram( + "vmd.reconciler.run.duration", + metric.WithUnit("s"), + metric.WithDescription("Duration of a reconciler tick."), + ) + reconcilerDriftTotal, _ = m.Int64Counter( + "vmd.reconciler.drift.total", + metric.WithDescription("Drift events detected by the reconciler, by kind."), + ) + reaperReapedTotal, _ = m.Int64Counter( + "controlplane.reaper.reaped.total", + metric.WithDescription("Sandboxes destroyed by the timeout reaper, by reason."), + ) + proxyHMACFailuresTotal, _ = m.Int64Counter( + "proxy.hmac.failures.total", + metric.WithDescription("Edge-proxy data-plane requests rejected for invalid HMAC."), + ) + dataplaneAuthFailures, _ = m.Int64Counter( + "controlplane.auth.failures.total", + metric.WithDescription("API key validation failures, by outcome."), + ) + }) +} + +// SandboxLifecycleOp identifies the operation labelled on lifecycle metrics. +type SandboxLifecycleOp string + +const ( + OpCreate SandboxLifecycleOp = "create" + OpPause SandboxLifecycleOp = "pause" + OpResume SandboxLifecycleOp = "resume" + OpDestroy SandboxLifecycleOp = "destroy" +) + +// Outcome is the small bounded enum used on counters/histograms that need +// to distinguish success from failure. Never use raw error strings here. +type Outcome string + +const ( + OutcomeOK Outcome = "ok" + OutcomeError Outcome = "error" + OutcomeTimeout Outcome = "timeout" +) + +// RecordSandboxLifecycle records the duration of a sandbox lifecycle op. +// `from` is "cold" | "warm_pool" | "snapshot" for create; empty otherwise. +func RecordSandboxLifecycle(ctx context.Context, op SandboxLifecycleOp, outcome Outcome, from string, seconds float64) { + initMeters() + if sandboxLifecycleDuration == nil { + return + } + attrs := []attribute.KeyValue{ + attribute.String("op", string(op)), + attribute.String("outcome", string(outcome)), + } + if from != "" { + attrs = append(attrs, attribute.String("from", from)) + } + sandboxLifecycleDuration.Record(ctx, seconds, metric.WithAttributes(attrs...)) +} + +// RecordSnapshotOp records the duration of a snapshot operation. +// op ∈ {"create", "restore", "delete", "gc"}. +func RecordSnapshotOp(ctx context.Context, op string, outcome Outcome, seconds float64) { + initMeters() + if snapshotOpDuration == nil { + return + } + snapshotOpDuration.Record(ctx, seconds, metric.WithAttributes( + attribute.String("op", op), + attribute.String("outcome", string(outcome)), + )) +} + +// IncSnapshotOrphanGC increments the orphan-snapshot GC counter. +func IncSnapshotOrphanGC(ctx context.Context, n int64) { + initMeters() + if snapshotOrphanGCTotal == nil || n == 0 { + return + } + snapshotOrphanGCTotal.Add(ctx, n) +} + +// RecordReconcilerRun records one reconciler tick duration. +func RecordReconcilerRun(ctx context.Context, seconds float64) { + initMeters() + if reconcilerRunDuration == nil { + return + } + reconcilerRunDuration.Record(ctx, seconds) +} + +// IncReconcilerDrift increments the drift counter for `kind`. Bounded enum. +func IncReconcilerDrift(ctx context.Context, kind string) { + initMeters() + if reconcilerDriftTotal == nil { + return + } + reconcilerDriftTotal.Add(ctx, 1, metric.WithAttributes(attribute.String("kind", kind))) +} + +// IncReaperReaped increments the reaper counter labelled by reason. +// reason ∈ {"timeout", "paused_max_age", "destroyed"}. +func IncReaperReaped(ctx context.Context, reason string, n int64) { + initMeters() + if reaperReapedTotal == nil || n == 0 { + return + } + reaperReapedTotal.Add(ctx, n, metric.WithAttributes(attribute.String("reason", reason))) +} + +// IncProxyHMACFailure records a rejected data-plane request. +func IncProxyHMACFailure(ctx context.Context) { + initMeters() + if proxyHMACFailuresTotal == nil { + return + } + proxyHMACFailuresTotal.Add(ctx, 1) +} + +// IncAuthFailure records a control-plane API key auth failure. +// outcome ∈ {"expired", "invalid", "revoked"}. +func IncAuthFailure(ctx context.Context, outcome string) { + initMeters() + if dataplaneAuthFailures == nil { + return + } + dataplaneAuthFailures.Add(ctx, 1, metric.WithAttributes(attribute.String("outcome", outcome))) +} + +// RegisterPoolGauge registers an observable up-down counter that publishes +// the current available count of `kind`. The callback is invoked on each +// metric export tick (15s). kind ∈ {"tap", "netns", "ip", "snapshot_overlay"}. +func RegisterPoolGauge(kind string, getter func() int64) error { + initMeters() + m := otel.Meter(meterName) + _, err := m.Int64ObservableUpDownCounter( + "vmd.pool.available", + metric.WithDescription("Available slots in a pre-allocated pool."), + metric.WithInt64Callback(func(_ context.Context, o metric.Int64Observer) error { + o.Observe(getter(), metric.WithAttributes(attribute.String("kind", kind))) + return nil + }), + ) + return err +} + +// RegisterActiveSandboxesGauge registers an observable up-down counter that +// publishes the count of sandboxes currently in `state`. Use a single +// callback that emits one observation per state to keep cardinality bounded. +func RegisterActiveSandboxesGauge(getter func() map[string]int64) error { + initMeters() + m := otel.Meter(meterName) + _, err := m.Int64ObservableUpDownCounter( + "sandbox.active", + metric.WithDescription("Sandboxes currently in each lifecycle state."), + metric.WithInt64Callback(func(_ context.Context, o metric.Int64Observer) error { + for state, n := range getter() { + o.Observe(n, metric.WithAttributes(attribute.String("state", state))) + } + return nil + }), + ) + return err +} diff --git a/internal/telemetry/metrics.go b/internal/telemetry/metrics.go new file mode 100644 index 0000000..8a15793 --- /dev/null +++ b/internal/telemetry/metrics.go @@ -0,0 +1,64 @@ +package telemetry + +import ( + "context" + "time" + + "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" + sdkresource "go.opentelemetry.io/otel/sdk/resource" +) + +// newMeterProvider builds an OTLP/gRPC meter provider. +// +// Two non-default choices that matter at scale: +// +// 1. Delta temporality. Per-sandbox gauges (Phase 2) only publish a series +// while the sandbox is alive; cumulative temporality would keep stale +// series forever and explode backend storage. +// +// 2. Base2 exponential histograms by default. Auto-bucketing means we don't +// have to pre-tune buckets per metric, and the backend gets a uniformly +// compact representation. MaxSize=160, MaxScale=20 mirrors what e2b uses +// in production. +func newMeterProvider(ctx context.Context, res *sdkresource.Resource) (*sdkmetric.MeterProvider, error) { + exp, err := otlpmetricgrpc.New(ctx, + otlpmetricgrpc.WithTemporalitySelector(deltaTemporalitySelector), + ) + if err != nil { + return nil, err + } + + reader := sdkmetric.NewPeriodicReader(exp, + sdkmetric.WithInterval(15*time.Second), + ) + + mp := sdkmetric.NewMeterProvider( + sdkmetric.WithResource(res), + sdkmetric.WithReader(reader), + sdkmetric.WithView(sdkmetric.NewView( + sdkmetric.Instrument{Kind: sdkmetric.InstrumentKindHistogram}, + sdkmetric.Stream{Aggregation: sdkmetric.AggregationBase2ExponentialHistogram{ + MaxSize: 160, + MaxScale: 20, + }}, + )), + ) + return mp, nil +} + +// deltaTemporalitySelector forces delta temporality for sums and histograms +// (so per-sandbox series die on sandbox exit) while leaving up-down counters +// as cumulative — they represent steady-state values like "active sandboxes" +// where deltas would lose meaning. +func deltaTemporalitySelector(kind sdkmetric.InstrumentKind) metricdata.Temporality { + switch kind { + case sdkmetric.InstrumentKindCounter, + sdkmetric.InstrumentKindHistogram, + sdkmetric.InstrumentKindObservableCounter: + return metricdata.DeltaTemporality + default: + return metricdata.CumulativeTemporality + } +} diff --git a/internal/telemetry/runtime.go b/internal/telemetry/runtime.go new file mode 100644 index 0000000..ab77e39 --- /dev/null +++ b/internal/telemetry/runtime.go @@ -0,0 +1,20 @@ +package telemetry + +import ( + "fmt" + + "go.opentelemetry.io/contrib/instrumentation/runtime" +) + +// StartRuntimeInstrumentation registers the standard Go runtime metrics +// (goroutine count, heap, GC pauses, etc.) on the global meter provider. +// Safe to call on a no-op Client; in that case it returns nil immediately. +func (c *Client) StartRuntimeInstrumentation() error { + if c == nil || !c.enabled { + return nil + } + if err := runtime.Start(runtime.WithMeterProvider(c.MeterProvider)); err != nil { + return fmt.Errorf("runtime instrumentation: %w", err) + } + return nil +} diff --git a/internal/telemetry/traces.go b/internal/telemetry/traces.go new file mode 100644 index 0000000..43dbdda --- /dev/null +++ b/internal/telemetry/traces.go @@ -0,0 +1,29 @@ +package telemetry + +import ( + "context" + + "go.opentelemetry.io/otel/exporters/otlp/otlptrace" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + sdkresource "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" +) + +// newTracerProvider builds an OTLP/gRPC tracer provider with a batch +// processor. Sampling is AlwaysSample at the SDK; if volume becomes a +// problem we add tail sampling at the collector instead. +// +// Endpoint, headers, and TLS come from the standard OTEL_EXPORTER_OTLP_* +// env vars, so operators get the full OTel config surface for free. +func newTracerProvider(ctx context.Context, res *sdkresource.Resource) (*sdktrace.TracerProvider, error) { + exp, err := otlptrace.New(ctx, otlptracegrpc.NewClient()) + if err != nil { + return nil, err + } + tp := sdktrace.NewTracerProvider( + sdktrace.WithBatcher(exp), + sdktrace.WithResource(res), + sdktrace.WithSampler(sdktrace.AlwaysSample()), + ) + return tp, nil +} diff --git a/internal/telemetry/zerolog.go b/internal/telemetry/zerolog.go new file mode 100644 index 0000000..b048501 --- /dev/null +++ b/internal/telemetry/zerolog.go @@ -0,0 +1,30 @@ +package telemetry + +import ( + "github.com/rs/zerolog" + "go.opentelemetry.io/otel/trace" +) + +// ZerologTraceHook injects trace_id and span_id from the active span (if +// any) into every log event. Install once at process start: +// +// log.Logger = log.Logger.Hook(telemetry.ZerologTraceHook{}) +// +// Zerolog hooks don't have access to context.Context directly, so callers +// must use log.Ctx(ctx) / log.With().Ctx(ctx) when emitting logs they want +// correlated. Lines without a context get no trace fields — same as today. +type ZerologTraceHook struct{} + +func (ZerologTraceHook) Run(e *zerolog.Event, _ zerolog.Level, _ string) { + ctx := e.GetCtx() + if ctx == nil { + return + } + sc := trace.SpanContextFromContext(ctx) + if sc.HasTraceID() { + e.Str("trace_id", sc.TraceID().String()) + } + if sc.HasSpanID() { + e.Str("span_id", sc.SpanID().String()) + } +} diff --git a/internal/vm/grpc_adapter.go b/internal/vm/grpc_adapter.go index 71fb6c7..9b30b3d 100644 --- a/internal/vm/grpc_adapter.go +++ b/internal/vm/grpc_adapter.go @@ -2,8 +2,6 @@ package vm import ( "context" - "fmt" - "io" "time" "google.golang.org/grpc" @@ -50,12 +48,22 @@ func (a *GRPCAdapter) CreateVM(ctx context.Context, req *vmdpb.CreateVMRequest) return nil, err } + if err := postBoxdInit(ctx, inst.IP, req.GetEnvVars()); err != nil { + a.mgr.log.Error().Err(err).Str("vm_id", inst.ID).Msg("failed to post env vars to boxd") + // Don't fail the create — the VM is running, env vars just didn't get set. + // The caller can still use per-request env vars as a fallback. + } + return &vmdpb.CreateVMResponse{ VmId: inst.ID, SocketPath: inst.SocketPath, IpAddress: inst.IP, TapDevice: inst.TAPDevice, Pid: uint32(inst.PID), + ResourceLimits: &vmdpb.ResourceLimits{ + VcpuCount: inst.Config.VCPU, + MemoryMib: inst.Config.MemoryMiB, + }, }, nil } @@ -87,11 +95,20 @@ func (a *GRPCAdapter) ResumeVM(ctx context.Context, req *vmdpb.ResumeVMRequest) if err != nil { return nil, err } + + if err := postBoxdInit(ctx, inst.IP, req.GetEnvVars()); err != nil { + a.mgr.log.Error().Err(err).Str("vm_id", inst.ID).Msg("failed to post env vars to boxd on resume") + } + return &vmdpb.ResumeVMResponse{ VmId: inst.ID, SocketPath: inst.SocketPath, IpAddress: inst.IP, Pid: uint32(inst.PID), + ResourceLimits: &vmdpb.ResourceLimits{ + VcpuCount: inst.Config.VCPU, + MemoryMib: inst.Config.MemoryMiB, + }, }, nil } @@ -140,6 +157,20 @@ func (a *GRPCAdapter) RestoreSnapshot(ctx context.Context, req *vmdpb.RestoreSna }, nil } +// DeleteSnapshot unlinks the vmstate + memory files for a previous snapshot. +// Idempotent; path traversal is blocked at the Manager layer. +func (a *GRPCAdapter) DeleteSnapshot(ctx context.Context, req *vmdpb.DeleteSnapshotRequest) (*vmdpb.DeleteSnapshotResponse, error) { + snapshotPath := req.GetSnapshotPath() + memPath := req.GetMemFilePath() + if snapshotPath == "" && memPath == "" { + return nil, status.Error(codes.InvalidArgument, "snapshot_path and/or mem_file_path must be set") + } + if err := a.mgr.DeleteSnapshotFiles(snapshotPath, memPath); err != nil { + return nil, err + } + return &vmdpb.DeleteSnapshotResponse{Deleted: true}, nil +} + func (a *GRPCAdapter) ExecCommand(req *vmdpb.ExecCommandRequest, stream grpc.ServerStreamingServer[vmdpb.ExecCommandResponse]) error { if req.GetCommand() == "" { return status.Error(codes.InvalidArgument, "command is required") @@ -221,89 +252,6 @@ func (a *GRPCAdapter) SetupNetwork(ctx context.Context, req *vmdpb.SetupNetworkR }, nil } -func (a *GRPCAdapter) UploadFile(stream grpc.ClientStreamingServer[vmdpb.UploadFileRequest, vmdpb.UploadFileResponse]) error { - // First message must contain vm_id and path. - first, err := stream.Recv() - if err != nil { - return status.Errorf(codes.InvalidArgument, "failed to receive first message: %v", err) - } - vmID := first.GetVmId() - path := first.GetPath() - if vmID == "" || path == "" { - return status.Error(codes.InvalidArgument, "vm_id and path are required in the first message") - } - - // Collect all data chunks into a buffer via io.Pipe so the Manager - // streams directly to boxd without buffering the entire file. - pr, pw := io.Pipe() - var uploadErr error - var bytesWritten int64 - - go func() { - defer pw.Close() - // Write first message's data if present. - if len(first.GetData()) > 0 { - if _, err := pw.Write(first.GetData()); err != nil { - return - } - } - for { - msg, err := stream.Recv() - if err == io.EOF { - return - } - if err != nil { - pw.CloseWithError(fmt.Errorf("recv: %w", err)) - return - } - if len(msg.GetData()) > 0 { - if _, err := pw.Write(msg.GetData()); err != nil { - return - } - } - } - }() - - bytesWritten, uploadErr = a.mgr.UploadFile(stream.Context(), vmID, path, pr) - if uploadErr != nil { - return status.Errorf(codes.Internal, "upload file: %v", uploadErr) - } - - return stream.SendAndClose(&vmdpb.UploadFileResponse{ - BytesWritten: bytesWritten, - }) -} - -func (a *GRPCAdapter) DownloadFile(req *vmdpb.DownloadFileRequest, stream grpc.ServerStreamingServer[vmdpb.DownloadFileChunk]) error { - if req.GetVmId() == "" || req.GetPath() == "" { - return status.Error(codes.InvalidArgument, "vm_id and path are required") - } - - reader, err := a.mgr.DownloadFile(stream.Context(), req.GetVmId(), req.GetPath()) - if err != nil { - return status.Errorf(codes.Internal, "download file: %v", err) - } - defer reader.Close() - - buf := make([]byte, 64*1024) // 64KB chunks - for { - n, readErr := reader.Read(buf) - if n > 0 { - chunk := make([]byte, n) - copy(chunk, buf[:n]) - if sendErr := stream.Send(&vmdpb.DownloadFileChunk{Data: chunk}); sendErr != nil { - return sendErr - } - } - if readErr == io.EOF { - return nil - } - if readErr != nil { - return status.Errorf(codes.Internal, "read file: %v", readErr) - } - } -} - func (a *GRPCAdapter) UpdateSandboxNetwork(ctx context.Context, req *vmdpb.UpdateSandboxNetworkRequest) (*vmdpb.UpdateSandboxNetworkResponse, error) { vmID := req.GetVmId() if vmID == "" { diff --git a/internal/vm/heartbeat.go b/internal/vm/heartbeat.go new file mode 100644 index 0000000..fe521f6 --- /dev/null +++ b/internal/vm/heartbeat.go @@ -0,0 +1,88 @@ +package vm + +import ( + "context" + "fmt" + "io" + "net/http" + "time" + + "github.com/rs/zerolog" +) + +// HeartbeatConfig controls the VMD → control plane heartbeat loop. +type HeartbeatConfig struct { + // ControlPlaneURL is the base URL of the control plane API (e.g. + // "http://localhost:8080"). The heartbeat POSTs to + // {ControlPlaneURL}/internal/hosts/{HostID}/heartbeat. + ControlPlaneURL string + + // HostID is this host's identifier in the host table. + HostID string + + // Token is the shared secret for authenticating internal API calls. + // Sent as `Authorization: Bearer `. + Token string + + // Interval is how often the heartbeat fires. Default: 30s. + Interval time.Duration +} + +// StartHeartbeat launches a background goroutine that periodically POSTs +// to the control plane's heartbeat endpoint. Blocks until ctx is cancelled. +func StartHeartbeat(ctx context.Context, cfg HeartbeatConfig, log zerolog.Logger) { + log = log.With().Str("component", "heartbeat").Logger() + + interval := cfg.Interval + if interval <= 0 { + interval = 30 * time.Second + } + + url := fmt.Sprintf("%s/internal/hosts/%s/heartbeat", cfg.ControlPlaneURL, cfg.HostID) + token := cfg.Token + client := &http.Client{Timeout: 10 * time.Second} + + log.Info(). + Str("url", url). + Dur("interval", interval). + Msg("heartbeat started") + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + // Fire once immediately so the host is marked alive on startup. + sendHeartbeat(ctx, client, url, token, log) + + for { + select { + case <-ctx.Done(): + log.Info().Msg("heartbeat exiting") + return + case <-ticker.C: + sendHeartbeat(ctx, client, url, token, log) + } + } +} + +func sendHeartbeat(ctx context.Context, client *http.Client, url, token string, log zerolog.Logger) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + if err != nil { + log.Error().Err(err).Msg("failed to create heartbeat request") + return + } + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + resp, err := client.Do(req) + if err != nil { + log.Warn().Err(err).Msg("heartbeat failed") + return + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Warn().Int("status", resp.StatusCode).Msg("heartbeat got non-200 response") + } +} diff --git a/internal/vm/manager.go b/internal/vm/manager.go index 7d33f30..c2f8ae4 100644 --- a/internal/vm/manager.go +++ b/internal/vm/manager.go @@ -2,11 +2,13 @@ package vm import ( "context" + "crypto/sha256" "fmt" "io" "os" "os/exec" "path/filepath" + "strings" "sync" "syscall" "time" @@ -104,7 +106,7 @@ type ManagerConfig struct { BaseRootfsPath string SnapshotDir string RunDir string - MaxConcurrent int // Max concurrent CreateVM operations (0 = default 10). + MaxConcurrent int // Max concurrent CreateVM operations (0 = default 10). } // TemplateSnapshot holds paths for a template snapshot created at startup. @@ -113,6 +115,8 @@ type TemplateSnapshot struct { MemFilePath string // e.g., snapshots/template/mem.snap DiskPath string // e.g., rundir/template/rootfs.ext4 RunDir string // e.g., rundir/template/ + VCPUCount uint32 // actual vCPU count baked into the snapshot + MemSizeMiB uint32 // actual RAM in MiB baked into the snapshot } // --------------------------------------------------------------------------- @@ -125,6 +129,7 @@ type Manager struct { netMgr *network.Manager egressProxy *network.EgressProxy log zerolog.Logger + state *StateStore // persistent local state (BoltDB); nil = no persistence mu sync.RWMutex vms map[string]*VMInstance @@ -139,14 +144,20 @@ func NewManager(cfg ManagerConfig, netMgr *network.Manager, log zerolog.Logger) maxConcurrent = 10 } return &Manager{ - cfg: cfg, - netMgr: netMgr, - log: log.With().Str("component", "vm_manager").Logger(), + cfg: cfg, + netMgr: netMgr, + log: log.With().Str("component", "vm_manager").Logger(), vms: make(map[string]*VMInstance), createSem: make(chan struct{}, maxConcurrent), }, nil } +// SetStateStore attaches a BoltDB state store for durable persistence. +// Must be called before any VM operations. +func (m *Manager) SetStateStore(s *StateStore) { + m.state = s +} + // SetEgressProxy sets the TCP egress proxy for domain-based filtering. // Must be called before any VMs are created. func (m *Manager) SetEgressProxy(proxy *network.EgressProxy) { @@ -162,51 +173,85 @@ func (m *Manager) templateRunDir() string { // InitDefaultTemplate — boot once, snapshot, reuse forever // --------------------------------------------------------------------------- -// InitDefaultTemplate cold-boots a throwaway VM from the base image, waits -// for the guest agent, snapshots the running state, and kills the VM — keeping -// the rundir and snapshot files on disk. Every subsequent CreateVM restores -// from this template snapshot instead of cold-booting. +// InitDefaultTemplate ensures a template snapshot is available for fast +// sandbox creation. If a valid cached template exists and the base rootfs +// hasn't changed since it was built, the cached template is reused — +// skipping the ~2-3s cold boot entirely. This makes VMD restarts fast +// when only the VMD binary changed (the common deploy case). // -// The template VM uses a fixed directory name ("template") so the snapshot's -// hardcoded path_on_host is always rundir/template/rootfs.ext4. Each new VM -// gets mount namespace isolation to present its own rootfs at that path. +// The template is rebuilt only when: +// - No cached snapshot exists on disk (first boot) +// - The base rootfs hash changed (new boxd version baked in) +// - The cached snapshot files are missing or corrupt func (m *Manager) InitDefaultTemplate(ctx context.Context) error { - // Use a fixed ID so the rundir is always "template". templateID := templateDirName log := m.log.With().Str("template_id", templateID).Logger() - log.Info().Msg("initializing default template — cold-booting throwaway VM") - // Cold-boot a throwaway VM from the base image. + snapshotDir := filepath.Join(m.cfg.SnapshotDir, templateDirName) + snapPath := filepath.Join(snapshotDir, "vmstate.snap") + memPath := filepath.Join(snapshotDir, "mem.snap") + diskPath := filepath.Join(m.templateRunDir(), "rootfs.ext4") + hashPath := filepath.Join(snapshotDir, "rootfs.sha256") + + // Check if we can reuse the cached template. + currentHash, hashErr := fileHash(m.cfg.BaseRootfsPath) + if hashErr != nil { + log.Warn().Err(hashErr).Msg("could not hash base rootfs — will rebuild template") + } + + metaPath := filepath.Join(snapshotDir, "template.meta") + + if hashErr == nil && m.canReuseTemplate(snapPath, memPath, diskPath, hashPath, currentHash) { + vcpu, mem := readTemplateMeta(metaPath) + log.Info().Uint32("vcpu", vcpu).Uint32("mem_mib", mem).Msg("base rootfs unchanged — reusing cached template snapshot") + m.defaultTemplate = &TemplateSnapshot{ + SnapshotPath: snapPath, + MemFilePath: memPath, + DiskPath: diskPath, + RunDir: m.templateRunDir(), + VCPUCount: vcpu, + MemSizeMiB: mem, + } + return nil + } + + // Cache miss — cold boot a throwaway VM, snapshot it, kill it. + log.Info().Msg("building new template — cold-booting throwaway VM") + inst, err := m.coldBootVM(ctx, templateID) if err != nil { return fmt.Errorf("boot template VM: %w", err) } - // Wait for boxd to be reachable via HTTP. if err := m.waitForBoxd(ctx, inst.IP, 30*time.Second); err != nil { _ = m.DestroyVM(ctx, templateID, true) return fmt.Errorf("boxd not ready: %w", err) } log.Info().Msg("guest agent ready — creating template snapshot") - // Snapshot the live VM. - snapshotDir := filepath.Join(m.cfg.SnapshotDir, templateDirName) - snapPath, memPath, err := m.CreateVMSnapshot(ctx, templateID, snapshotDir) + snapPath, memPath, err = m.CreateVMSnapshot(ctx, templateID, snapshotDir) if err != nil { _ = m.DestroyVM(ctx, templateID, true) return fmt.Errorf("snapshot template VM: %w", err) } - // Kill the VM process but keep the rundir on disk — the snapshot's - // path_on_host references these files. - diskPath := inst.DiskPath + diskPath = inst.DiskPath m.killVMKeepRunDir(templateID) + // Persist the rootfs hash and resource values so the next startup + // can skip the cold boot and restore the correct template config. + if currentHash != "" { + _ = os.WriteFile(hashPath, []byte(currentHash), 0o644) + } + writeTemplateMeta(metaPath, inst.Config.VCPU, inst.Config.MemoryMiB) + m.defaultTemplate = &TemplateSnapshot{ SnapshotPath: snapPath, MemFilePath: memPath, DiskPath: diskPath, RunDir: m.templateRunDir(), + VCPUCount: inst.Config.VCPU, + MemSizeMiB: inst.Config.MemoryMiB, } log.Info(). @@ -216,6 +261,54 @@ func (m *Manager) InitDefaultTemplate(ctx context.Context) error { return nil } +// readTemplateMeta reads the vCPU and memory values persisted alongside +// the template snapshot. Returns safe defaults (1 vCPU, 1024 MiB) if the +// file is missing or unreadable. +func readTemplateMeta(path string) (vcpu, memMiB uint32) { + data, err := os.ReadFile(path) + if err != nil { + return 1, 1024 + } + var v, m uint32 + if _, err := fmt.Sscanf(strings.TrimSpace(string(data)), "%d %d", &v, &m); err != nil { + return 1, 1024 + } + return v, m +} + +func writeTemplateMeta(path string, vcpu, memMiB uint32) { + _ = os.WriteFile(path, []byte(fmt.Sprintf("%d %d", vcpu, memMiB)), 0o644) +} + +// canReuseTemplate returns true when all template files exist on disk and +// the stored rootfs hash matches the current base image. +func (m *Manager) canReuseTemplate(snapPath, memPath, diskPath, hashPath, currentHash string) bool { + for _, p := range []string{snapPath, memPath, diskPath} { + if _, err := os.Stat(p); err != nil { + return false + } + } + stored, err := os.ReadFile(hashPath) + if err != nil { + return false + } + return strings.TrimSpace(string(stored)) == currentHash +} + +// fileHash returns the SHA-256 hex digest of a file. +func fileHash(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + return fmt.Sprintf("%x", h.Sum(nil)), nil +} + // --------------------------------------------------------------------------- // CreateVM — single code path via template snapshot restore // --------------------------------------------------------------------------- @@ -259,70 +352,115 @@ func (m *Manager) CreateVM(ctx context.Context, vmID string, vcpu, memMiB, diskM return nil, status.Errorf(codes.AlreadyExists, "vm %s already exists", vmID) } - rundirID := uuid.New().String() inst := &VMInstance{ ID: vmID, Status: StatusCreating, CreatedAt: time.Now(), Metadata: metadata, - RunDirID: rundirID, + RunDirID: vmID, Config: VMConfig{ - VCPU: vcpu, - MemoryMiB: memMiB, - DiskSizeMiB: diskMiB, + VCPU: m.defaultTemplate.VCPUCount, + MemoryMiB: m.defaultTemplate.MemSizeMiB, }, } m.vms[vmID] = inst m.mu.Unlock() cleanup := func() { - m.cleanupRunDir(rundirID) + m.cleanupRunDir(vmID) m.setStatus(vmID, StatusError) m.removeVM(vmID) } - // 1. Copy the template rootfs for this VM. - stepStart := time.Now() - perVMRootfs, err := m.copyRootfs(ctx, rundirID, m.defaultTemplate.DiskPath) - if err != nil { + // Steps 1 and 2 — copying the rootfs and setting up the network + // namespace — are independent. Run them in parallel so the total + // wall-clock for the pair is max(rootfs, netns) instead of their sum. + // On typical hardware this shaves ~10-30ms off create latency. + parallelStart := time.Now() + + type rootfsResult struct { + path string + err error + } + type netResult struct { + info *network.VMNetInfo + err error + } + rootfsCh := make(chan rootfsResult, 1) + netCh := make(chan netResult, 1) + + go func() { + p, err := m.copyRootfs(ctx, vmID, m.defaultTemplate.DiskPath) + rootfsCh <- rootfsResult{path: p, err: err} + }() + go func() { + info, err := m.netMgr.SetupVM(ctx, vmID, netCfg) + netCh <- netResult{info: info, err: err} + }() + + rfs := <-rootfsCh + nr := <-netCh + + // Both goroutines always run to completion so we know exactly which + // side(s) succeeded and need unwinding. Tear them down in reverse + // order of resource ownership: network first (it's tied to kernel + // state), rundir last (it's just files, handled by cleanup()). + if rfs.err != nil || nr.err != nil { + // If the network came up but the rootfs did not, the kernel + // namespace/veth/firewall state must be explicitly freed; the + // shared cleanup() only removes the rundir. + if nr.err == nil { + m.netMgr.CleanupVM(vmID) + } cleanup() - return nil, fmt.Errorf("copy rootfs: %w", err) + switch { + case rfs.err != nil && nr.err != nil: + return nil, fmt.Errorf("copy rootfs: %w; setup network: %v", rfs.err, nr.err) + case rfs.err != nil: + return nil, fmt.Errorf("copy rootfs: %w", rfs.err) + default: + return nil, fmt.Errorf("setup network: %w", nr.err) + } } - inst.DiskPath = perVMRootfs - log.Debug().Dur("duration_ms", time.Since(stepStart)).Msg("step: copy rootfs") - // 2. Set up networking. - stepStart = time.Now() - netInfo, err := m.netMgr.SetupVM(ctx, vmID, netCfg) - if err != nil { - cleanup() - return nil, fmt.Errorf("setup network: %w", err) - } + perVMRootfs := rfs.path + netInfo := nr.info + // Take inst.mu to write — concurrent readers via ExecCommand / + // LookupInstance / persistState take RLock. + inst.mu.Lock() + inst.DiskPath = perVMRootfs inst.IP = netInfo.HostIP inst.TAPDevice = netInfo.TAPDevice inst.MACAddress = netInfo.MACAddress inst.Namespace = netInfo.Namespace - log.Debug().Dur("duration_ms", time.Since(stepStart)).Msg("step: setup network") + inst.mu.Unlock() + log.Debug().Dur("duration_ms", time.Since(parallelStart)).Msg("step: copy rootfs + setup network (parallel)") // 3. Start Firecracker in a mount + network namespace. - stepStart = time.Now() - vmDir := filepath.Join(m.cfg.RunDir, rundirID) + startStep := time.Now() + vmDir := filepath.Join(m.cfg.RunDir, vmID) socketPath := filepath.Join(vmDir, "firecracker.sock") - inst.SocketPath = socketPath - pid, err := m.startFirecrackerInNamespace(vmID, socketPath, perVMRootfs, netInfo.Namespace) - if err != nil { + var ( + pid int + startErr error + ) + pid, startErr = m.startFirecrackerViaSystemd(ctx, vmID, socketPath, perVMRootfs, netInfo.Namespace) + if startErr != nil { m.netMgr.CleanupVM(vmID) cleanup() - return nil, fmt.Errorf("start firecracker: %w", err) + return nil, fmt.Errorf("start firecracker: %w", startErr) } + inst.mu.Lock() + inst.SocketPath = socketPath inst.PID = pid - log.Debug().Dur("duration_ms", time.Since(stepStart)).Msg("step: start firecracker") + inst.mu.Unlock() + log.Debug().Dur("duration_ms", time.Since(startStep)).Msg("step: start firecracker") // 4. Restore from the original (unpatched) template snapshot. // No IP reconfig needed — the VM uses a fixed internal IP (169.254.0.21) // and the network namespace provides isolation. - stepStart = time.Now() + restoreStep := time.Now() if err := RestoreSnapshotWithOverrides( socketPath, m.defaultTemplate.SnapshotPath, m.defaultTemplate.MemFilePath, "eth0", netInfo.TAPDevice, @@ -331,9 +469,11 @@ func (m *Manager) CreateVM(ctx context.Context, vmID string, vcpu, memMiB, diskM cleanup() return nil, fmt.Errorf("restore template snapshot: %w", err) } - log.Debug().Dur("duration_ms", time.Since(stepStart)).Msg("step: restore snapshot") + log.Debug().Dur("duration_ms", time.Since(restoreStep)).Msg("step: restore snapshot") m.setStatus(vmID, StatusRunning) + // Persist again now that PID, IP, and socket are set. + m.persistState(inst) log.Info(). Int("pid", pid). Str("host_ip", inst.IP). @@ -366,7 +506,7 @@ func (m *Manager) coldBootVM(ctx context.Context, vmID string) (*VMInstance, err RunDirID: vmID, Config: VMConfig{ VCPU: 1, - MemoryMiB: 512, + MemoryMiB: 1024, KernelPath: m.cfg.KernelPath, RootfsPath: m.cfg.BaseRootfsPath, }, @@ -384,7 +524,6 @@ func (m *Manager) coldBootVM(ctx context.Context, vmID string) (*VMInstance, err m.setStatus(vmID, StatusError) return nil, fmt.Errorf("copy rootfs: %w", err) } - inst.DiskPath = diskPath // 2. Set up networking. netInfo, err := m.netMgr.SetupVM(ctx, vmID, nil) @@ -393,15 +532,22 @@ func (m *Manager) coldBootVM(ctx context.Context, vmID string) (*VMInstance, err m.setStatus(vmID, StatusError) return nil, fmt.Errorf("setup network: %w", err) } + + // inst is already visible via m.vms; take inst.mu for writes so + // concurrent readers (ExecCommand, LookupInstance, persistState) + // see a consistent view. + inst.mu.Lock() + inst.DiskPath = diskPath inst.IP = netInfo.HostIP inst.TAPDevice = netInfo.TAPDevice inst.MACAddress = netInfo.MACAddress inst.Namespace = netInfo.Namespace + mac := inst.MACAddress + inst.mu.Unlock() // 3. Build Firecracker machine configuration. vmDir := filepath.Join(m.cfg.RunDir, vmID) socketPath := filepath.Join(vmDir, "firecracker.sock") - inst.SocketPath = socketPath fcCfg := FirecrackerConfig{ SocketPath: socketPath, @@ -411,7 +557,7 @@ func (m *Manager) coldBootVM(ctx context.Context, vmID string) (*VMInstance, err VCPUCount: 1, MemSizeMiB: 1024, TAPDevice: network.TAPName, - MACAddress: inst.MACAddress, + MACAddress: mac, VMID: vmID, VMIP: network.VMInternalIP, GatewayIP: network.VMGatewayIP, @@ -425,10 +571,14 @@ func (m *Manager) coldBootVM(ctx context.Context, vmID string) (*VMInstance, err m.setStatus(vmID, StatusError) return nil, fmt.Errorf("start firecracker: %w", err) } + + inst.mu.Lock() + inst.SocketPath = socketPath inst.PID = pid + inst.mu.Unlock() m.setStatus(vmID, StatusRunning) - log.Info().Int("pid", pid).Str("host_ip", inst.IP).Msg("VM cold-booted") + log.Info().Int("pid", pid).Str("host_ip", netInfo.HostIP).Msg("VM cold-booted") return inst, nil } @@ -446,24 +596,11 @@ func (m *Manager) DestroyVM(ctx context.Context, vmID string, force bool) error log := m.log.With().Str("vm_id", vmID).Logger() log.Info().Bool("force", force).Msg("destroying VM") - if inst.PID > 0 { - proc, findErr := os.FindProcess(inst.PID) - if findErr == nil { - if force { - _ = proc.Signal(syscall.SIGKILL) - } else { - _ = proc.Signal(syscall.SIGTERM) - done := make(chan error, 1) - go func() { _, e := proc.Wait(); done <- e }() - select { - case <-done: - case <-time.After(5 * time.Second): - log.Warn().Msg("SIGTERM timed out, sending SIGKILL") - _ = proc.Signal(syscall.SIGKILL) - } - } - } + // Stop the systemd unit — this kills Firecracker and runs ExecStopPost cleanup. + if err := stopUnit(ctx, systemdUnitName(vmID)); err != nil { + log.Warn().Err(err).Msg("systemctl stop failed (unit may already be stopped)") } + removeUnitDropIn(vmID) if inst.SocketPath != "" { _ = os.Remove(inst.SocketPath) @@ -507,21 +644,12 @@ func (m *Manager) PauseVM(ctx context.Context, vmID, snapshotDir string) (snapsh log.Info().Str("snapshot_path", snapshotPath).Msg("pausing VM — creating snapshot") if err := CreateSnapshot(inst.SocketPath, snapshotPath, memPath); err != nil { - return "", "", fmt.Errorf("create snapshot: %w", err) + return "", "", m.handleVMError(vmID, fmt.Errorf("create snapshot: %w", err)) } - if inst.PID > 0 { - if proc, e := os.FindProcess(inst.PID); e == nil { - _ = proc.Signal(syscall.SIGTERM) - done := make(chan struct{}) - go func() { proc.Wait(); close(done) }() //nolint:errcheck - select { - case <-done: - case <-time.After(500 * time.Millisecond): - _ = proc.Signal(syscall.SIGKILL) - <-done - } - } + // Stop the Firecracker process — snapshot is already on disk. + if err := stopUnit(ctx, systemdUnitName(vmID)); err != nil { + log.Warn().Err(err).Msg("systemctl stop failed during pause") } inst.mu.Lock() @@ -530,6 +658,7 @@ func (m *Manager) PauseVM(ctx context.Context, vmID, snapshotDir string) (snapsh inst.MemFilePath = memPath inst.mu.Unlock() + m.persistState(inst) log.Info().Msg("VM paused") return snapshotPath, memPath, nil } @@ -573,7 +702,7 @@ func (m *Manager) ResumeVM(ctx context.Context, vmID, snapshotPath, memPath stri vmDir := filepath.Join(m.cfg.RunDir, rundirKey) socketPath := filepath.Join(vmDir, "firecracker.sock") - pid, err := m.startFirecrackerInNamespace(vmID, socketPath, rootfsPath, inst.Namespace) + pid, err := m.startFirecrackerViaSystemd(ctx, vmID, socketPath, rootfsPath, inst.Namespace) if err != nil { return nil, fmt.Errorf("start firecracker for restore: %w", err) } @@ -589,6 +718,7 @@ func (m *Manager) ResumeVM(ctx context.Context, vmID, snapshotPath, memPath stri inst.Status = StatusRunning inst.mu.Unlock() + m.persistState(inst) log.Info().Int("pid", pid).Msg("VM resumed from snapshot") return inst, nil } @@ -625,6 +755,75 @@ func (m *Manager) CreateVMSnapshot(ctx context.Context, vmID, snapshotDir string return snapshotPath, memPath, nil } +// DeleteSnapshotFiles removes a snapshot's on-disk artifacts (vmstate + memory +// file). Both paths must resolve to locations under the configured snapshot +// directory — arbitrary paths are rejected as InvalidArgument to prevent the +// control plane from accidentally (or maliciously) unlinking unrelated files. +// +// The operation is idempotent: missing files are not an error. The enclosing +// directory is removed on a best-effort basis once both files are gone and it +// is empty; a non-empty directory is left alone. +// +// Callers are responsible for ensuring the snapshot is no longer referenced +// by any running VM. This method does not inspect instance state. +func (m *Manager) DeleteSnapshotFiles(snapshotPath, memPath string) error { + if snapshotPath == "" && memPath == "" { + return status.Error(codes.InvalidArgument, "at least one of snapshot_path/mem_file_path is required") + } + for _, p := range []string{snapshotPath, memPath} { + if p == "" { + continue + } + if err := m.assertUnderSnapshotDir(p); err != nil { + return err + } + } + + for _, p := range []string{snapshotPath, memPath} { + if p == "" { + continue + } + if err := os.Remove(p); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("remove %s: %w", p, err) + } + } + + // Best-effort: if the parent directory is now empty, clean it up. Any + // error here is swallowed — a non-empty or missing directory is fine. + for _, p := range []string{snapshotPath, memPath} { + if p == "" { + continue + } + dir := filepath.Dir(p) + // Only attempt to remove directories under SnapshotDir — never the + // root itself. + if dir == "" || dir == m.cfg.SnapshotDir { + continue + } + _ = os.Remove(dir) // removes only if empty + } + return nil +} + +// assertUnderSnapshotDir returns nil iff `p` is an absolute path that, after +// cleaning, lies under m.cfg.SnapshotDir. This is the guard that keeps +// DeleteSnapshotFiles from being used to unlink arbitrary files on the host. +func (m *Manager) assertUnderSnapshotDir(p string) error { + if m.cfg.SnapshotDir == "" { + return status.Error(codes.FailedPrecondition, "snapshot_dir not configured") + } + if !filepath.IsAbs(p) { + return status.Errorf(codes.InvalidArgument, "path must be absolute: %s", p) + } + cleaned := filepath.Clean(p) + root := filepath.Clean(m.cfg.SnapshotDir) + rel, err := filepath.Rel(root, cleaned) + if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return status.Errorf(codes.InvalidArgument, "path is outside snapshot directory: %s", p) + } + return nil +} + // RestoreVMSnapshot boots a VM from a previously captured snapshot. func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, memPath, diskPath string, resourceLimits VMConfig, netCfg *network.Config, @@ -636,23 +835,12 @@ func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem } m.mu.Lock() - existingInst, inPlace := m.vms[vmID] + _, inPlace := m.vms[vmID] if inPlace { - if existingInst.PID > 0 { - if proc, e := os.FindProcess(existingInst.PID); e == nil { - _ = proc.Signal(syscall.SIGTERM) - // Wait for process to exit instead of sleeping. - done := make(chan struct{}) - go func() { proc.Wait(); close(done) }() //nolint:errcheck - select { - case <-done: - case <-time.After(500 * time.Millisecond): - _ = proc.Signal(syscall.SIGKILL) - <-done - } - } - } delete(m.vms, vmID) + m.mu.Unlock() + _ = stopUnit(ctx, systemdUnitName(vmID)) + m.mu.Lock() } inst := &VMInstance{ @@ -676,7 +864,6 @@ func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem } log.Debug().Str("disk_path", diskPath).Msg("created rootfs copy for restored VM") } - inst.DiskPath = diskPath var tapDevice, macAddr, hostIP, nsName string @@ -702,25 +889,34 @@ func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem hostIP = netInfo.HostIP nsName = netInfo.Namespace } + + vmDir := filepath.Join(m.cfg.RunDir, vmID) + socketPath := filepath.Join(vmDir, "firecracker.sock") + + // Publish all the network/disk/socket fields before starting + // Firecracker so the in-memory view is consistent for concurrent + // readers. Lock once for the batch. + inst.mu.Lock() + inst.DiskPath = diskPath inst.IP = hostIP inst.TAPDevice = tapDevice inst.MACAddress = macAddr inst.Namespace = nsName - - vmDir := filepath.Join(m.cfg.RunDir, vmID) - socketPath := filepath.Join(vmDir, "firecracker.sock") inst.SocketPath = socketPath + inst.mu.Unlock() - pid, err := m.startFirecrackerInNamespace(vmID, socketPath, diskPath, nsName) - if err != nil { + pid, startErr := m.startFirecrackerViaSystemd(ctx, vmID, socketPath, diskPath, nsName) + if startErr != nil { if !inPlace { m.netMgr.CleanupVM(vmID) } m.cleanupRunDir(vmID) m.setStatus(vmID, StatusError) - return nil, fmt.Errorf("start firecracker: %w", err) + return nil, fmt.Errorf("start firecracker: %w", startErr) } + inst.mu.Lock() inst.PID = pid + inst.mu.Unlock() log.Info().Msg("restoring snapshot") @@ -740,6 +936,7 @@ func (m *Manager) RestoreVMSnapshot(ctx context.Context, vmID, snapshotPath, mem } m.setStatus(vmID, StatusRunning) + m.persistState(inst) log.Info().Int("pid", pid).Msg("VM restored from snapshot") return inst, nil } @@ -756,19 +953,109 @@ func (m *Manager) GetVMInfo(_ context.Context, vmID string) (*VMInstance, error) // ShutdownAll // --------------------------------------------------------------------------- +// ShutdownAll is a no-op — VMs are owned by systemd and outlive VMD. func (m *Manager) ShutdownAll() { - m.mu.RLock() - ids := make([]string, 0, len(m.vms)) - for id := range m.vms { - ids = append(ids, id) + m.log.Info().Msg("VMs are systemd-managed — they will continue running after VMD shutdown") +} + +// --------------------------------------------------------------------------- +// ReattachAll — startup recovery +// --------------------------------------------------------------------------- + +// ReattachAll reconstructs the in-memory VM map on startup from two sources: +// +// 1. BoltDB — VMD's own cache from the previous lifetime. +// 2. Systemd — ground truth for which Firecracker units are actually running. +// +// For each VM in BoltDB that systemd confirms is alive AND whose Firecracker +// API socket is reachable, VMD reattaches. Stale BoltDB entries (dead process) +// are cleaned up. Orphan systemd units (running but not in BoltDB) are logged +// so the Phase 3 reconciler can handle them. +func (m *Manager) ReattachAll(ctx context.Context) (reattached, stale int) { + if m.state == nil { + m.log.Warn().Msg("no state store configured — skipping reattach") + return 0, 0 + } + + records, err := m.state.All() + if err != nil { + m.log.Error().Err(err).Msg("failed to read BoltDB state — skipping reattach") + return 0, 0 + } + + // Build a set of BoltDB-known IDs for orphan detection. + knownIDs := make(map[string]bool, len(records)) + for _, rec := range records { + knownIDs[rec.ID] = true + } + + if len(records) == 0 { + m.log.Info().Msg("no VMs in BoltDB — checking systemd for orphans") + } else { + m.log.Info().Int("count", len(records)).Msg("reattaching VMs from BoltDB") + } + + // Phase A: reattach from BoltDB. + for _, rec := range records { + log := m.log.With().Str("vm_id", rec.ID).Logger() + + // Paused VMs legitimately have no running systemd unit — they + // were stopped during pause and are waiting for a resume via + // their snapshot. Reattach them with their paused status so the + // resume path can find them. + if rec.Status == StatusPaused { + inst := toInstance(rec) + m.mu.Lock() + m.vms[rec.ID] = inst + m.mu.Unlock() + log.Info().Msg("reattached paused VM") + reattached++ + continue + } + + // For running VMs, verify the systemd unit is still active. + if !isUnitActive(ctx, systemdUnitName(rec.ID)) { + log.Warn().Msg("VM in BoltDB but not running — cleaning up stale record") + m.state.Delete(rec.ID) + stale++ + continue + } + + // Verify the Firecracker API socket is actually reachable. + if rec.SocketPath != "" { + if _, statErr := os.Stat(rec.SocketPath); statErr != nil { + log.Warn().Str("socket", rec.SocketPath).Msg("VM unit active but socket missing — cleaning up") + m.state.Delete(rec.ID) + stale++ + continue + } + } + + // Reattach: add to in-memory map. + inst := toInstance(rec) + + m.mu.Lock() + m.vms[rec.ID] = inst + m.mu.Unlock() + + m.persistState(inst) + log.Info().Int("pid", inst.PID).Str("ip", inst.IP).Msg("reattached to running VM") + reattached++ } - m.mu.RUnlock() - for _, id := range ids { - if err := m.DestroyVM(context.Background(), id, true); err != nil { - m.log.Error().Err(err).Str("vm_id", id).Msg("failed to destroy VM during shutdown") + // Phase B: detect orphan systemd units not in BoltDB. + activeIDs, err := listActiveFirecrackerUnits(ctx) + if err != nil { + m.log.Warn().Err(err).Msg("failed to list active firecracker units — orphan detection skipped") + } else { + for _, id := range activeIDs { + if !knownIDs[id] { + m.log.Warn().Str("vm_id", id).Msg("orphan systemd unit detected (not in BoltDB) — will be handled by reconciler") + } } } + + return reattached, stale } // --------------------------------------------------------------------------- @@ -799,7 +1086,11 @@ func (m *Manager) ExecCommand(ctx context.Context, vmID, command string, timeout ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - return httpExec(ctx, vmIP, command, timeout, opts) + result, err := httpExec(ctx, vmIP, command, timeout, opts) + if err != nil { + return nil, m.handleVMError(vmID, err) + } + return result, nil } func (m *Manager) ExecCommandStream(ctx context.Context, vmID, command string, timeout time.Duration, opts *ExecOptions, @@ -828,31 +1119,17 @@ func (m *Manager) ExecCommandStream(ctx context.Context, vmID, command string, t ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - return httpExecStream(ctx, vmIP, command, timeout, opts, onChunk) + if err := httpExecStream(ctx, vmIP, command, timeout, opts, onChunk); err != nil { + return m.handleVMError(vmID, err) + } + return nil } // --------------------------------------------------------------------------- -// File operations (raw HTTP to boxd for content, Connect RPC for metadata) +// File operations (Connect RPC for metadata only; byte transfer lives +// on the edge proxy's /files endpoint.) // --------------------------------------------------------------------------- -// UploadFile writes content to a file inside a running VM. -func (m *Manager) UploadFile(ctx context.Context, vmID, filePath string, content io.Reader) (int64, error) { - vmIP, err := m.getRunningVMIP(vmID) - if err != nil { - return 0, err - } - return uploadFile(ctx, vmIP, filePath, content) -} - -// DownloadFile reads a file from inside a running VM. -func (m *Manager) DownloadFile(ctx context.Context, vmID, filePath string) (io.ReadCloser, error) { - vmIP, err := m.getRunningVMIP(vmID) - if err != nil { - return nil, err - } - return downloadFile(ctx, vmIP, filePath) -} - // DeleteFile removes a file or directory inside a running VM via Connect RPC. func (m *Manager) DeleteFile(ctx context.Context, vmID, filePath string) error { vmIP, err := m.getRunningVMIP(vmID) @@ -883,6 +1160,43 @@ func (m *Manager) getRunningVMIP(vmID string) (string, error) { return vmIP, nil } +// handleVMError checks whether a connection error to a VM means the VM is +// dead. If the systemd unit is no longer active, it marks the VM as failed +// in BoltDB, removes it from the in-memory map, and returns NotFound so +// the control plane returns 410 Gone. If the unit is still active (transient +// error), it returns the original error unchanged. +func (m *Manager) handleVMError(vmID string, origErr error) error { + if origErr == nil { + return nil + } + checkCtx, checkCancel := context.WithTimeout(context.Background(), 3*time.Second) + defer checkCancel() + if isUnitActive(checkCtx, systemdUnitName(vmID)) { + return origErr + } + + // Single lock acquisition for both status update and removal so + // concurrent callers can't race on the same VM. + m.mu.Lock() + inst, ok := m.vms[vmID] + if !ok { + m.mu.Unlock() + // Already cleaned up by another goroutine. + return status.Errorf(codes.NotFound, "vm %s is no longer running", vmID) + } + inst.mu.Lock() + inst.Status = StatusStopped + inst.mu.Unlock() + delete(m.vms, vmID) + m.mu.Unlock() + + m.log.Warn().Str("vm_id", vmID).Err(origErr). + Msg("VM process is dead — cleaning up and returning NotFound") + m.persistState(inst) + m.deleteState(vmID) + return status.Errorf(codes.NotFound, "vm %s is no longer running", vmID) +} + // InstanceInfo is a snapshot of a VM's address and status for proxy lookups. type InstanceInfo struct { VMIP string @@ -935,12 +1249,36 @@ func (m *Manager) setStatus(vmID string, s VMStatus) { inst.mu.Lock() inst.Status = s inst.mu.Unlock() + m.persistState(inst) +} + +// persistState writes the current VM state to BoltDB. No-op if no state +// store is configured. Errors are logged but not returned — BoltDB is a +// cache, not a source of truth. +func (m *Manager) persistState(inst *VMInstance) { + if m.state == nil { + return + } + if err := m.state.Put(toRecord(inst)); err != nil { + m.log.Error().Err(err).Str("vm_id", inst.ID).Msg("failed to persist VM state to BoltDB") + } +} + +// deleteState removes a VM record from BoltDB. +func (m *Manager) deleteState(vmID string) { + if m.state == nil { + return + } + if err := m.state.Delete(vmID); err != nil { + m.log.Error().Err(err).Str("vm_id", vmID).Msg("failed to delete VM state from BoltDB") + } } func (m *Manager) removeVM(vmID string) { m.mu.Lock() delete(m.vms, vmID) m.mu.Unlock() + m.deleteState(vmID) } // copyRootfs creates a per-VM rootfs by copying the source image. @@ -1003,11 +1341,10 @@ func (m *Manager) startFirecrackerColdBoot(ctx context.Context, vmID, socketPath return pid, nil } -// startFirecrackerInNamespace launches Firecracker in its own mount namespace -// AND inside the given network namespace. The mount namespace provides rootfs -// isolation (tmpfs + symlink to per-VM rootfs), and the network namespace -// provides network isolation (each VM uses the same internal IP). -func (m *Manager) startFirecrackerInNamespace(vmID, socketPath, perVMRootfs, netNS string) (int, error) { +// startFirecrackerViaSystemd writes the start script and launches Firecracker +// as a standalone systemd unit. The VM survives VMD restarts because systemd +// owns the process, not VMD. +func (m *Manager) startFirecrackerViaSystemd(ctx context.Context, vmID, socketPath, perVMRootfs, netNS string) (int, error) { if err := os.MkdirAll(filepath.Dir(socketPath), 0o755); err != nil { return 0, fmt.Errorf("mkdir socket dir: %w", err) } @@ -1016,31 +1353,60 @@ func (m *Manager) startFirecrackerInNamespace(vmID, socketPath, perVMRootfs, net templateDir := m.templateRunDir() rootfsLink := filepath.Join(templateDir, "rootfs.ext4") - // Write a temporary shell script to avoid shell injection from config values. - // The script sets up mount namespace isolation, then exec's Firecracker. + // Write the start script that the systemd unit's ExecStart calls. scriptPath := filepath.Join(filepath.Dir(socketPath), "start.sh") - scriptContent := fmt.Sprintf("#!/bin/sh\nmount --make-rprivate / && mount -t tmpfs tmpfs %q && ln -s %q %q && exec %q --api-sock %q --id %q\n", - templateDir, perVMRootfs, rootfsLink, m.cfg.FirecrackerBin, socketPath, vmID) + scriptContent := fmt.Sprintf("#!/bin/sh\nexec ip netns exec %s unshare -m -- sh -c 'mount --make-rprivate / && mount -t tmpfs tmpfs %q && ln -s %q %q && exec %q --api-sock %q --id %q'\n", + netNS, templateDir, perVMRootfs, rootfsLink, m.cfg.FirecrackerBin, socketPath, vmID) if err := os.WriteFile(scriptPath, []byte(scriptContent), 0o755); err != nil { return 0, fmt.Errorf("write start script: %w", err) } - // Run inside the network namespace with a private mount namespace. - cmd := exec.Command("ip", "netns", "exec", netNS, - "unshare", "-m", "--", "sh", scriptPath) - cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true} - - if err := cmd.Start(); err != nil { - return 0, fmt.Errorf("exec firecracker in namespace: %w", err) + // Start the systemd unit. + if err := startUnit(ctx, systemdUnitName(vmID)); err != nil { + return 0, fmt.Errorf("start systemd unit: %w", err) } + // Wait for the Firecracker API socket. if err := waitForSocket(socketPath, 5*time.Second); err != nil { - _ = cmd.Process.Kill() + _ = stopUnit(ctx, systemdUnitName(vmID)) return 0, fmt.Errorf("wait for socket: %w", err) } - go func() { _ = cmd.Wait() }() - return cmd.Process.Pid, nil + // Read the PID asynchronously so the create path isn't slowed down + // by the ~15ms dbus roundtrip. The PID is populated in the instance + // shortly after create returns and persisted to BoltDB. + go m.resolveAndSetPID(vmID) + + return 0, nil +} + +func (m *Manager) resolveAndSetPID(vmID string) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "systemctl", "show", "--property=MainPID", "--value", systemdUnitName(vmID)) + out, err := cmd.Output() + if err != nil { + return + } + var pid int + if _, err := fmt.Sscanf(strings.TrimSpace(string(out)), "%d", &pid); err != nil || pid == 0 { + return + } + + m.mu.RLock() + inst, ok := m.vms[vmID] + m.mu.RUnlock() + if !ok { + return + } + + inst.mu.Lock() + inst.PID = pid + inst.mu.Unlock() + + m.persistState(inst) + m.log.Debug().Str("vm_id", vmID).Int("pid", pid).Msg("resolved systemd MainPID") } func waitForSocket(path string, timeout time.Duration) error { @@ -1055,19 +1421,25 @@ func waitForSocket(path string, timeout time.Duration) error { } // killVMKeepRunDir terminates the VM process and releases networking but -// leaves the rundir intact on disk. +// leaves the rundir intact on disk. Used only for the throwaway template +// VM (which is cold-booted as a direct child, not a systemd unit). func (m *Manager) killVMKeepRunDir(vmID string) { inst, err := m.getInstance(vmID) if err != nil { return } + // Template VMs run as direct child processes (cold boot path). + // Regular VMs run as systemd units — stop the unit if it exists. if inst.PID > 0 { if proc, e := os.FindProcess(inst.PID); e == nil { _ = proc.Signal(syscall.SIGKILL) go proc.Wait() //nolint:errcheck } + } else { + _ = stopUnit(context.Background(), systemdUnitName(vmID)) } + if inst.SocketPath != "" { _ = os.Remove(inst.SocketPath) } @@ -1122,11 +1494,3 @@ func (m *Manager) CleanupTemplate() { m.log.Info().Msg("template files cleaned up") } -func firstNonEmpty(vals ...string) string { - for _, v := range vals { - if v != "" { - return v - } - } - return "" -} diff --git a/internal/vm/reconciler.go b/internal/vm/reconciler.go new file mode 100644 index 0000000..de910c2 --- /dev/null +++ b/internal/vm/reconciler.go @@ -0,0 +1,460 @@ +package vm + +import ( + "context" + "os" + "sync" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/superserve-ai/sandbox/internal/db" +) + +// ReconcilerConfig controls the periodic reconciler. +type ReconcilerConfig struct { + // Interval is how often the reconciler runs. + Interval time.Duration + // GracePeriod is the minimum time a drift must persist before the + // reconciler takes destructive action. Prevents races where VMD has + // just started a VM and systemd hasn't fully registered it yet. + GracePeriod time.Duration + // MaxAutoFailPerHour caps destructive actions per host to bound the + // blast radius of a reconciler bug. If exceeded, the reconciler + // logs a paging alert and stops taking destructive action until + // the counter resets. + MaxAutoFailPerHour int + // HostID is this host's identifier in the `host` table. The reconciler + // only operates on sandboxes with this host_id. + HostID string + // DB is optional. When set, the reconciler does three-way drift + // detection (BoltDB ↔ systemd ↔ DB) and writes audit log entries. + // When nil, it only compares BoltDB and systemd. + DB *db.Queries +} + +// DefaultReconcilerConfig returns sensible defaults from the design doc. +func DefaultReconcilerConfig() ReconcilerConfig { + return ReconcilerConfig{ + Interval: 30 * time.Second, + GracePeriod: 60 * time.Second, + MaxAutoFailPerHour: 5, + } +} + +// Reconciler detects and fixes drift between three sources of truth: +// +// - systemd: which firecracker@ units are actually running +// (authoritative for liveness) +// - Control plane DB: the sandbox rows scheduled on this host +// (authoritative for intent) +// - BoltDB: VMD's own fast-path cache (authoritative for nothing) +// +// The DB source is optional — if the reconciler is constructed without a +// DB, it falls back to a BoltDB ↔ systemd comparison only. Destructive +// actions are rate-limited via MaxAutoFailPerHour and require the drift +// to persist across at least two consecutive runs (GracePeriod). +type Reconciler struct { + mgr *Manager + cfg ReconcilerConfig + + // driftSeen tracks the first-seen timestamp for each drifted VM so + // we can enforce the grace period. Keyed by vmID. + mu sync.Mutex + driftSeen map[string]time.Time + autoFailLog []time.Time // timestamps of recent auto-fail actions +} + +// NewReconciler creates a reconciler bound to a Manager. +func NewReconciler(mgr *Manager, cfg ReconcilerConfig) *Reconciler { + return &Reconciler{ + mgr: mgr, + cfg: cfg, + driftSeen: make(map[string]time.Time), + } +} + +// runTimeout bounds each reconciliation pass so a slow DB or stuck +// systemctl call cannot wedge the loop. +const runTimeout = 25 * time.Second + +// Run launches the reconciler loop. Blocks until ctx is cancelled. +func (r *Reconciler) Run(ctx context.Context) { + log := r.mgr.log.With().Str("component", "reconciler").Logger() + log.Info(). + Dur("interval", r.cfg.Interval). + Dur("grace_period", r.cfg.GracePeriod). + Int("max_autofail_per_hour", r.cfg.MaxAutoFailPerHour). + Msg("reconciler started") + + ticker := time.NewTicker(r.cfg.Interval) + defer ticker.Stop() + + // Run once immediately so startup is observable. + r.runWithTimeout(ctx) + + for { + select { + case <-ctx.Done(): + log.Info().Msg("reconciler exiting") + return + case <-ticker.C: + r.runWithTimeout(ctx) + } + } +} + +// runWithTimeout bounds a single reconciliation pass. Runs deadlines +// below the tick interval so two consecutive runs cannot overlap. +func (r *Reconciler) runWithTimeout(parent context.Context) { + ctx, cancel := context.WithTimeout(parent, runTimeout) + defer cancel() + r.runOnce(ctx) +} + +// runOnce performs a single reconciliation pass. Each pass: +// 1. Queries BoltDB, systemd, and (optionally) the control plane DB. +// 2. Compares the three sets. +// 3. Records a "first seen" timestamp for every drift so we can enforce +// the grace period (rule C7). +// 4. Applies fixes that have persisted past the grace period, rate-limited +// by MaxAutoFailPerHour (rule C6). +func (r *Reconciler) runOnce(ctx context.Context) { + log := r.mgr.log.With().Str("component", "reconciler").Logger() + + if r.mgr.state == nil { + log.Debug().Msg("no state store — skipping run") + return + } + + // Source A: BoltDB records. + records, err := r.mgr.state.All() + if err != nil { + log.Error().Err(err).Msg("failed to read state store") + return + } + bolted := make(map[string]VMRecord, len(records)) + for _, rec := range records { + bolted[rec.ID] = rec + } + + // Source B: active systemd units. + ids, err := listActiveFirecrackerUnits(ctx) + if err != nil { + log.Error().Err(err).Msg("failed to list systemd units") + return + } + active := make(map[string]bool, len(ids)) + for _, id := range ids { + active[id] = true + } + + // Source C: DB sandbox rows for this host (optional). A short + // per-query deadline keeps a slow DB from stalling the whole run. + var dbSandboxes map[string]db.Sandbox + if r.cfg.DB != nil && r.cfg.HostID != "" { + qctx, cancel := context.WithTimeout(ctx, 10*time.Second) + rows, dbErr := r.cfg.DB.ListSandboxesByHost(qctx, r.cfg.HostID) + cancel() + if dbErr != nil { + log.Error().Err(dbErr).Msg("failed to list sandboxes from DB") + } else { + dbSandboxes = make(map[string]db.Sandbox, len(rows)) + for _, s := range rows { + dbSandboxes[s.ID.String()] = s + } + } + } + + now := time.Now() + + // Drift 1: DB says active, systemd/socket says dead. + // Action: mark sandbox failed in DB + clean up BoltDB + in-memory. + if dbSandboxes != nil { + for id, sb := range dbSandboxes { + if sb.Status != db.SandboxStatusActive { + continue + } + if r.isAlive(ctx, id, bolted) { + r.clearDrift(id) + continue + } + if !r.gracePeriodElapsed(id, now) { + continue + } + if !r.consumeAutoFailBudget(id) { + r.writeAudit(ctx, id, "budget_exhausted", "mark_failed suppressed by rate limit", "db_active_systemd_missing") + continue + } + log.Warn().Str("vm_id", id).Str("drift", "db_active_systemd_missing"). + Msg("DB says active but VM is dead — marking failed") + r.markFailedInDB(ctx, id) + r.markStale(id) + r.writeAudit(ctx, id, "mark_failed", "VM dead while DB said active", "db_active_systemd_missing") + } + } + + // Drift 2: BoltDB says running but VM is actually dead, and DB is + // unavailable (reconciler running in BoltDB-only mode). Fall back to + // the old behavior: just clean up the stale BoltDB entry. + if dbSandboxes == nil { + for id, rec := range bolted { + if rec.Status != StatusRunning { + continue + } + if r.isAlive(ctx, id, bolted) { + r.clearDrift(id) + continue + } + if !r.gracePeriodElapsed(id, now) { + continue + } + if !r.consumeAutoFailBudget(id) { + continue + } + log.Warn().Str("vm_id", id).Str("drift", "boltdb_running_unit_missing"). + Msg("dead Firecracker detected (no DB context)") + r.markStale(id) + r.writeAudit(ctx, id, "stale_cleanup", "VM dead, DB unavailable", "boltdb_running_unit_missing") + } + } + + // Drift 3: systemd has a unit, DB says the sandbox is deleted or has + // no row at all. This is an orphan — stop the unit + clean up. + if dbSandboxes != nil { + for id := range active { + sb, known := dbSandboxes[id] + deleted := known && sb.Status == db.SandboxStatusDeleted + if known && !deleted { + continue + } + if !r.gracePeriodElapsed("orphan:"+id, now) { + continue + } + if !r.consumeAutoFailBudget(id) { + r.writeAudit(ctx, id, "budget_exhausted", "orphan_stop suppressed by rate limit", "systemd_active_db_missing") + continue + } + reason := "systemd unit with no DB row" + kind := "systemd_active_db_missing" + if deleted { + reason = "systemd unit for soft-deleted sandbox" + kind = "systemd_active_db_deleted" + } + log.Warn().Str("vm_id", id).Str("drift", kind).Msg("orphan systemd unit — stopping") + if err := stopUnit(ctx, systemdUnitName(id)); err != nil { + log.Error().Err(err).Str("vm_id", id).Msg("failed to stop orphan unit") + continue + } + removeUnitDropIn(id) + r.markStale(id) + r.writeAudit(ctx, id, "orphan_stop", reason, kind) + r.clearDrift("orphan:" + id) + } + } + + // Drift 4: DB says paused, snapshot file missing on disk → mark failed. + if dbSandboxes != nil { + for id, sb := range dbSandboxes { + if sb.Status != db.SandboxStatusPaused || !sb.SnapshotID.Valid { + continue + } + snap, snapErr := r.getSnapshot(ctx, sb.SnapshotID.Bytes) + if snapErr != nil { + continue + } + if _, statErr := os.Stat(snap.Path); statErr == nil { + r.clearDrift("paused:" + id) + continue + } + if !r.gracePeriodElapsed("paused:"+id, now) { + continue + } + if !r.consumeAutoFailBudget(id) { + r.writeAudit(ctx, id, "budget_exhausted", "mark_failed suppressed by rate limit", "paused_snapshot_missing") + continue + } + log.Warn().Str("vm_id", id).Str("snapshot_path", snap.Path). + Str("drift", "paused_snapshot_missing"). + Msg("paused sandbox snapshot file missing — marking failed") + r.markFailedInDB(ctx, id) + r.writeAudit(ctx, id, "mark_failed", "snapshot file missing", "paused_snapshot_missing") + r.clearDrift("paused:" + id) + } + } + + // Drift 5: BoltDB record exists but DB has no corresponding sandbox + // row (either never written or soft-deleted). Clean up the BoltDB + // entry. If the VM is still live, we ALSO need to stop it — leaving + // a systemd unit running for a sandbox the control plane forgot about + // is a resource leak and a security risk. + if dbSandboxes != nil { + for id, rec := range bolted { + if _, ok := dbSandboxes[id]; ok { + continue + } + if !r.gracePeriodElapsed("bolt-orphan:"+id, now) { + continue + } + if !r.consumeAutoFailBudget(id) { + r.writeAudit(ctx, id, "budget_exhausted", "stale_cleanup suppressed by rate limit", "boltdb_present_db_missing") + continue + } + log.Warn().Str("vm_id", id).Str("drift", "boltdb_present_db_missing"). + Msg("BoltDB entry with no DB row — cleaning up") + // Stop the unit if it's still live so we don't leak a VM. + if rec.Status == StatusRunning { + if err := stopUnit(ctx, systemdUnitName(id)); err != nil { + log.Error().Err(err).Str("vm_id", id).Msg("failed to stop orphan unit from BoltDB") + } + removeUnitDropIn(id) + } + r.markStale(id) + r.writeAudit(ctx, id, "stale_cleanup", "BoltDB entry with no DB row", "boltdb_present_db_missing") + r.clearDrift("bolt-orphan:" + id) + } + } +} + +// isAlive returns true when the VM's systemd unit is currently active. +func (r *Reconciler) isAlive(ctx context.Context, vmID string, _ map[string]VMRecord) bool { + return isUnitActive(ctx, systemdUnitName(vmID)) +} + +// gracePeriodElapsed records the first-seen timestamp for a drifted ID and +// returns true once the configured grace period has passed. Used to absorb +// transient states (e.g. VMD just started a VM and systemd hasn't fully +// registered it yet). +func (r *Reconciler) gracePeriodElapsed(key string, now time.Time) bool { + r.mu.Lock() + defer r.mu.Unlock() + firstSeen, ok := r.driftSeen[key] + if !ok { + r.driftSeen[key] = now + return false + } + return now.Sub(firstSeen) >= r.cfg.GracePeriod +} + +// clearDrift removes a drift marker once the VM returns to a healthy state. +func (r *Reconciler) clearDrift(key string) { + r.mu.Lock() + delete(r.driftSeen, key) + r.mu.Unlock() +} + +// getSnapshot loads a snapshot row by ID via the internal (unscoped) +// query. The reconciler is host-scoped, not team-scoped, so it uses +// GetSnapshotByID which bypasses the tenant filter. Guards the call +// with dbQueryTimeout so a slow DB can't wedge the loop. +func (r *Reconciler) getSnapshot(ctx context.Context, id [16]byte) (db.Snapshot, error) { + qctx, cancel := context.WithTimeout(ctx, dbQueryTimeout) + defer cancel() + return r.cfg.DB.GetSnapshotByID(qctx, id) +} + +// dbQueryTimeout is the per-query deadline for short reconciler writes +// and single-row reads. Kept below runTimeout so a single slow query +// can't consume the whole run's budget. +const dbQueryTimeout = 5 * time.Second + +// markFailedInDB writes status=failed for the given sandbox ID. No-op if +// the DB is not configured. +func (r *Reconciler) markFailedInDB(ctx context.Context, vmID string) { + if r.cfg.DB == nil { + return + } + id, err := uuid.Parse(vmID) + if err != nil { + r.mgr.log.Error().Err(err).Str("vm_id", vmID).Msg("reconciler: invalid vm_id for DB mark-failed") + return + } + qctx, cancel := context.WithTimeout(ctx, dbQueryTimeout) + defer cancel() + if err := r.cfg.DB.MarkSandboxFailed(qctx, id); err != nil { + r.mgr.log.Error().Err(err).Str("vm_id", vmID).Msg("reconciler: failed to mark sandbox failed in DB") + } +} + +// writeAudit appends a row to the reconciler_log table. No-op if the DB +// is not configured. Rule C8: every reconciler action produces an audit +// record. +func (r *Reconciler) writeAudit(ctx context.Context, vmID, action, reason, driftKind string) { + if r.cfg.DB == nil { + return + } + var sandboxID pgtype.UUID + if id, err := uuid.Parse(vmID); err == nil { + sandboxID = pgtype.UUID{Bytes: id, Valid: true} + } + kind := driftKind + qctx, cancel := context.WithTimeout(ctx, dbQueryTimeout) + defer cancel() + if err := r.cfg.DB.InsertReconcilerLog(qctx, db.InsertReconcilerLogParams{ + HostID: r.cfg.HostID, + SandboxID: sandboxID, + Action: action, + Reason: reason, + DriftKind: &kind, + }); err != nil { + r.mgr.log.Error().Err(err).Str("vm_id", vmID).Msg("reconciler: failed to write audit log") + } +} + +// consumeAutoFailBudget enforces crash safety rule C6: bounded-blast-radius +// auto-failure. Returns false (and does not consume the budget) when the +// reconciler has already marked MaxAutoFailPerHour VMs stale in the last +// rolling hour. Mass drift is almost always a reconciler bug, not 50 +// simultaneous VM crashes. +func (r *Reconciler) consumeAutoFailBudget(vmID string) bool { + r.mu.Lock() + defer r.mu.Unlock() + + now := time.Now() + cutoff := now.Add(-time.Hour) + kept := r.autoFailLog[:0] + for _, t := range r.autoFailLog { + if t.After(cutoff) { + kept = append(kept, t) + } + } + r.autoFailLog = kept + + if len(r.autoFailLog) >= r.cfg.MaxAutoFailPerHour { + r.mgr.log.Error(). + Str("component", "reconciler"). + Str("vm_id", vmID). + Int("budget", r.cfg.MaxAutoFailPerHour). + Msg("auto-fail budget exhausted — halting destructive actions until budget resets") + return false + } + + r.autoFailLog = append(r.autoFailLog, now) + return true +} + +// markStale deletes the stale BoltDB entry and drops the VM from the +// in-memory map. The VM is already gone in reality; this just cleans up +// VMD's cache. +func (r *Reconciler) markStale(vmID string) { + // Delete from BoltDB first. If this fails, keep the in-memory entry + // so the state stays consistent — the reconciler will retry on the + // next run. Deleting from the map before BoltDB would cause + // ReattachAll to resurrect the stale record on next restart. + if err := r.mgr.state.Delete(vmID); err != nil { + r.mgr.log.Error().Err(err).Str("vm_id", vmID).Msg("reconciler: failed to delete stale state, will retry") + return + } + + r.mgr.mu.Lock() + delete(r.mgr.vms, vmID) + r.mgr.mu.Unlock() + + r.mu.Lock() + delete(r.driftSeen, vmID) + r.mu.Unlock() + + r.mgr.log.Warn().Str("component", "reconciler").Str("vm_id", vmID). + Str("action", "mark_stale").Msg("reconciler: cleaned up stale VM record") +} diff --git a/internal/vm/state.go b/internal/vm/state.go new file mode 100644 index 0000000..d774f4b --- /dev/null +++ b/internal/vm/state.go @@ -0,0 +1,170 @@ +package vm + +import ( + "encoding/json" + "fmt" + "time" + + bolt "go.etcd.io/bbolt" +) + +// State provides durable local persistence for VM instance metadata. +// It is a cache — systemd is the ground truth for liveness, the control +// plane DB is the ground truth for intent. State allows VMD to reattach +// to running Firecracker processes after a restart without querying the +// control plane. + +var bucketName = []byte("vms") + +// VMRecord is the serializable subset of VMInstance persisted to BoltDB. +// It contains everything VMD needs to reconstruct its in-memory map on +// startup and reattach to a live Firecracker process. +type VMRecord struct { + ID string `json:"id"` + PID int `json:"pid"` + SocketPath string `json:"socket_path"` + VsockPath string `json:"vsock_path,omitempty"` + IP string `json:"ip"` + TAPDevice string `json:"tap_device"` + MACAddress string `json:"mac_address"` + Status VMStatus `json:"status"` + RunDirID string `json:"rundir_id"` + Namespace string `json:"namespace"` + DiskPath string `json:"disk_path"` + SnapshotPath string `json:"snapshot_path,omitempty"` + MemFilePath string `json:"mem_file_path,omitempty"` + CreatedAt time.Time `json:"created_at"` + Metadata map[string]string `json:"metadata,omitempty"` + VCPU uint32 `json:"vcpu"` + MemoryMiB uint32 `json:"memory_mib"` +} + +// StateStore wraps a BoltDB database for VM state persistence. +type StateStore struct { + db *bolt.DB +} + +// OpenStateStore opens (or creates) the BoltDB file at path. +func OpenStateStore(path string) (*StateStore, error) { + db, err := bolt.Open(path, 0o600, &bolt.Options{Timeout: 1 * time.Second}) + if err != nil { + return nil, fmt.Errorf("open state store %s: %w", path, err) + } + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucketIfNotExists(bucketName) + return err + }); err != nil { + db.Close() + return nil, fmt.Errorf("create bucket: %w", err) + } + return &StateStore{db: db}, nil +} + +// Close flushes and closes the database. +func (s *StateStore) Close() error { + return s.db.Close() +} + +// Put persists a VM record. Called on every state transition. +func (s *StateStore) Put(rec VMRecord) error { + data, err := json.Marshal(rec) + if err != nil { + return fmt.Errorf("marshal vm record: %w", err) + } + return s.db.Update(func(tx *bolt.Tx) error { + return tx.Bucket(bucketName).Put([]byte(rec.ID), data) + }) +} + +// Get retrieves a single VM record by ID. Returns nil if not found. +func (s *StateStore) Get(vmID string) (*VMRecord, error) { + var rec VMRecord + err := s.db.View(func(tx *bolt.Tx) error { + v := tx.Bucket(bucketName).Get([]byte(vmID)) + if v == nil { + return nil + } + return json.Unmarshal(v, &rec) + }) + if err != nil { + return nil, err + } + if rec.ID == "" { + return nil, nil + } + return &rec, nil +} + +// Delete removes a VM record. +func (s *StateStore) Delete(vmID string) error { + return s.db.Update(func(tx *bolt.Tx) error { + return tx.Bucket(bucketName).Delete([]byte(vmID)) + }) +} + +// All returns every persisted VM record. +func (s *StateStore) All() ([]VMRecord, error) { + var records []VMRecord + err := s.db.View(func(tx *bolt.Tx) error { + b := tx.Bucket(bucketName) + return b.ForEach(func(_, v []byte) error { + var rec VMRecord + if err := json.Unmarshal(v, &rec); err != nil { + return fmt.Errorf("unmarshal vm record: %w", err) + } + records = append(records, rec) + return nil + }) + }) + return records, err +} + +// toRecord converts a VMInstance to a persistable VMRecord. +func toRecord(inst *VMInstance) VMRecord { + inst.mu.RLock() + defer inst.mu.RUnlock() + return VMRecord{ + ID: inst.ID, + PID: inst.PID, + SocketPath: inst.SocketPath, + VsockPath: inst.VsockPath, + IP: inst.IP, + TAPDevice: inst.TAPDevice, + MACAddress: inst.MACAddress, + Status: inst.Status, + RunDirID: inst.RunDirID, + Namespace: inst.Namespace, + DiskPath: inst.DiskPath, + SnapshotPath: inst.SnapshotPath, + MemFilePath: inst.MemFilePath, + CreatedAt: inst.CreatedAt, + Metadata: inst.Metadata, + VCPU: inst.Config.VCPU, + MemoryMiB: inst.Config.MemoryMiB, + } +} + +// toInstance converts a VMRecord back to a VMInstance. +func toInstance(rec VMRecord) *VMInstance { + return &VMInstance{ + ID: rec.ID, + PID: rec.PID, + SocketPath: rec.SocketPath, + VsockPath: rec.VsockPath, + IP: rec.IP, + TAPDevice: rec.TAPDevice, + MACAddress: rec.MACAddress, + Status: rec.Status, + RunDirID: rec.RunDirID, + Namespace: rec.Namespace, + DiskPath: rec.DiskPath, + SnapshotPath: rec.SnapshotPath, + MemFilePath: rec.MemFilePath, + CreatedAt: rec.CreatedAt, + Metadata: rec.Metadata, + Config: VMConfig{ + VCPU: rec.VCPU, + MemoryMiB: rec.MemoryMiB, + }, + } +} diff --git a/internal/vm/systemd.go b/internal/vm/systemd.go new file mode 100644 index 0000000..3fc0568 --- /dev/null +++ b/internal/vm/systemd.go @@ -0,0 +1,79 @@ +package vm + +import ( + "context" + "fmt" + "os" + "os/exec" + "strings" +) + +// systemdUnitName returns the systemd unit name for a sandbox. +func systemdUnitName(vmID string) string { + return "firecracker@" + vmID + ".service" +} + +// startUnit starts a systemd unit. Idempotent — starting an already-running +// unit is a no-op. +func startUnit(ctx context.Context, unit string) error { + cmd := exec.CommandContext(ctx, "systemctl", "start", unit) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("systemctl start %s: %s: %w", unit, strings.TrimSpace(string(out)), err) + } + return nil +} + +// stopUnit stops a systemd unit. Idempotent — stopping an already-stopped +// unit is a no-op. +func stopUnit(ctx context.Context, unit string) error { + cmd := exec.CommandContext(ctx, "systemctl", "stop", unit) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("systemctl stop %s: %s: %w", unit, strings.TrimSpace(string(out)), err) + } + return nil +} + +// isUnitActive checks if a systemd unit is currently active (running). +func isUnitActive(ctx context.Context, unit string) bool { + cmd := exec.CommandContext(ctx, "systemctl", "is-active", "--quiet", unit) + return cmd.Run() == nil +} + +// listActiveFirecrackerUnits returns the sandbox IDs of all running +// firecracker@ units. Used during startup reattach. +func listActiveFirecrackerUnits(ctx context.Context) ([]string, error) { + cmd := exec.CommandContext(ctx, "systemctl", "list-units", + "firecracker@*.service", "--state=active", "--no-legend", "--plain") + out, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("list firecracker units: %w", err) + } + + var ids []string + for _, line := range strings.Split(strings.TrimSpace(string(out)), "\n") { + if line == "" { + continue + } + // Each line: "firecracker@.service loaded active running ..." + fields := strings.Fields(line) + if len(fields) == 0 { + continue + } + unit := fields[0] + // Extract ID from "firecracker@.service" + unit = strings.TrimPrefix(unit, "firecracker@") + unit = strings.TrimSuffix(unit, ".service") + if unit != "" { + ids = append(ids, unit) + } + } + return ids, nil +} + +// removeUnitDropIn removes the drop-in directory for a firecracker@ unit. +func removeUnitDropIn(vmID string) { + dropInDir := fmt.Sprintf("/etc/systemd/system/firecracker@%s.service.d", vmID) + os.RemoveAll(dropInDir) +} diff --git a/internal/vm/vsock_exec.go b/internal/vm/vsock_exec.go index 921d29d..68305a5 100644 --- a/internal/vm/vsock_exec.go +++ b/internal/vm/vsock_exec.go @@ -1,20 +1,34 @@ package vm import ( + "bytes" "context" "encoding/json" "fmt" "io" "net/http" - "net/url" "time" "connectrpc.com/connect" + "connectrpc.com/otelconnect" pb "github.com/superserve-ai/sandbox/proto/boxdpb" "github.com/superserve-ai/sandbox/proto/boxdpb/boxdpbconnect" ) +// boxdInterceptors carries the otel client interceptor for outbound Connect +// calls to boxd. otelconnect.NewInterceptor only errors on impossible +// configuration, so panic at startup is the right escalation. The +// interceptor is no-op when the global TracerProvider is the SDK noop +// (i.e. telemetry disabled). +var boxdInterceptors = func() connect.Option { + i, err := otelconnect.NewInterceptor() + if err != nil { + panic("otelconnect.NewInterceptor: " + err.Error()) + } + return connect.WithInterceptors(i) +}() + // ExecResult holds the result of a command execution inside a VM. type ExecResult struct { Stdout []byte @@ -28,7 +42,7 @@ const boxdPort = 49983 // boxdProcessClient returns a Connect RPC client for the ProcessService. func boxdProcessClient(vmIP string) boxdpbconnect.ProcessServiceClient { baseURL := fmt.Sprintf("http://%s:%d", vmIP, boxdPort) - return boxdpbconnect.NewProcessServiceClient(http.DefaultClient, baseURL) + return boxdpbconnect.NewProcessServiceClient(http.DefaultClient, baseURL, boxdInterceptors) } // ExecOptions holds optional parameters for command execution. @@ -178,62 +192,47 @@ func waitForHTTPHealth(ctx context.Context, vmIP string, timeout time.Duration) return fmt.Errorf("boxd health check not ready after %s", timeout) } -// boxdFilesystemClient returns a Connect RPC client for the FilesystemService. -func boxdFilesystemClient(vmIP string) boxdpbconnect.FilesystemServiceClient { - baseURL := fmt.Sprintf("http://%s:%d", vmIP, boxdPort) - return boxdpbconnect.NewFilesystemServiceClient(http.DefaultClient, baseURL) -} - -// boxdFileURL returns the HTTP URL for file operations. -func boxdFileURL(vmIP, path string) string { - return fmt.Sprintf("http://%s:%d/files?path=%s", vmIP, boxdPort, url.QueryEscape(path)) -} - -// uploadFile uploads content to a file inside a VM via raw HTTP. -func uploadFile(ctx context.Context, vmIP, filePath string, content io.Reader) (int64, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, boxdFileURL(vmIP, filePath), content) - if err != nil { - return 0, fmt.Errorf("create request: %w", err) - } - req.Header.Set("Content-Type", "application/octet-stream") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return 0, fmt.Errorf("upload file: %w", err) +// postBoxdInit sends sandbox-level environment variables to boxd's /init +// endpoint. These vars are injected into every process boxd spawns. +func postBoxdInit(ctx context.Context, vmIP string, envVars map[string]string) error { + if len(envVars) == 0 { + return nil } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return 0, fmt.Errorf("upload failed (status %d): %s", resp.StatusCode, body) - } + body := struct { + EnvVars map[string]string `json:"env_vars"` + }{EnvVars: envVars} - var result struct { - Size int64 `json:"size"` - } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return 0, fmt.Errorf("decode response: %w", err) + buf, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("marshal init body: %w", err) } - return result.Size, nil -} -// downloadFile downloads a file from a VM via raw HTTP. -func downloadFile(ctx context.Context, vmIP, filePath string) (io.ReadCloser, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, boxdFileURL(vmIP, filePath), nil) + url := fmt.Sprintf("http://%s:%d/init", vmIP, boxdPort) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(buf)) if err != nil { - return nil, fmt.Errorf("create request: %w", err) + return fmt.Errorf("create init request: %w", err) } + req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) if err != nil { - return nil, fmt.Errorf("download file: %w", err) + return fmt.Errorf("POST /init: %w", err) } + defer resp.Body.Close() + io.Copy(io.Discard, resp.Body) if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - resp.Body.Close() - return nil, fmt.Errorf("download failed (status %d): %s", resp.StatusCode, body) + return fmt.Errorf("POST /init: status %d", resp.StatusCode) } + return nil +} - return resp.Body, nil +// boxdFilesystemClient returns a Connect RPC client for boxd's +// FilesystemService, used for metadata ops (Remove, Move, etc.) inside +// a VM. File byte transfer goes through the edge proxy directly. +func boxdFilesystemClient(vmIP string) boxdpbconnect.FilesystemServiceClient { + baseURL := fmt.Sprintf("http://%s:%d", vmIP, boxdPort) + return boxdpbconnect.NewFilesystemServiceClient(http.DefaultClient, baseURL, boxdInterceptors) } diff --git a/internal/vmdclient/client.go b/internal/vmdclient/client.go new file mode 100644 index 0000000..4756cfb --- /dev/null +++ b/internal/vmdclient/client.go @@ -0,0 +1,28 @@ +// Package vmdclient defines the interface for talking to a VM daemon. +// It lives in its own leaf package so both internal/api and +// internal/hostreg can reference it without circular imports. +package vmdclient + +import "context" + +// Client defines the subset of the VM daemon gRPC interface used by the +// control plane. Implementations: grpcVMDClient in cmd/controlplane, +// stubVMD in tests. +type Client interface { + CreateInstance(ctx context.Context, instanceID string, vcpu, memMiB, diskMiB uint32, metadata map[string]string, envVars map[string]string) (ipAddress string, actualVcpu, actualMemMiB uint32, err error) + DestroyInstance(ctx context.Context, instanceID string, force bool) error + PauseInstance(ctx context.Context, instanceID, snapshotDir string) (snapshotPath, memPath string, err error) + ResumeInstance(ctx context.Context, instanceID, snapshotPath, memPath string, envVars map[string]string) (ipAddress string, actualVcpu, actualMemMiB uint32, err error) + // RestoreSnapshot is the stateless restore path used as a fallback when + // ResumeInstance fails with NotFound (e.g. after a VMD crash lost the + // in-memory map but the snapshot files are still on disk). + RestoreSnapshot(ctx context.Context, instanceID, snapshotPath, memPath string) (ipAddress string, actualVcpu, actualMemMiB uint32, err error) + // DeleteSnapshot removes the on-disk vmstate + memory files for a + // previous snapshot. Idempotent: missing files return nil. Used by the + // control plane to garbage-collect the previous snapshot after a new + // pause writes a fresh one. + DeleteSnapshot(ctx context.Context, instanceID, snapshotPath, memPath string) error + ExecCommand(ctx context.Context, instanceID, command string, args []string, env map[string]string, workingDir string, timeoutS uint32) (stdout, stderr string, exitCode int32, err error) + ExecCommandStream(ctx context.Context, instanceID, command string, args []string, env map[string]string, workingDir string, timeoutS uint32, onChunk func(stdout, stderr []byte, exitCode int32, finished bool)) error + UpdateSandboxNetwork(ctx context.Context, instanceID string, allowedCIDRs, deniedCIDRs, allowedDomains []string) error +} diff --git a/proto/vmd.proto b/proto/vmd.proto index 9303411..b6891c4 100644 --- a/proto/vmd.proto +++ b/proto/vmd.proto @@ -24,6 +24,12 @@ service VMDaemon { // RestoreSnapshot boots a VM from a previously captured snapshot. rpc RestoreSnapshot(RestoreSnapshotRequest) returns (RestoreSnapshotResponse); + // DeleteSnapshot removes a snapshot's on-disk artifacts (vmstate + memory + // file). Idempotent — missing files are treated as success. Used by the + // control plane to garbage-collect the previous snapshot when a sandbox is + // re-paused and only the latest snapshot is reachable. + rpc DeleteSnapshot(DeleteSnapshotRequest) returns (DeleteSnapshotResponse); + // ExecCommand runs a command inside a VM and streams stdout/stderr back. rpc ExecCommand(ExecCommandRequest) returns (stream ExecCommandResponse); @@ -33,12 +39,6 @@ service VMDaemon { // SetupNetwork configures networking for a VM (TAP device, NAT, IP allocation). rpc SetupNetwork(SetupNetworkRequest) returns (SetupNetworkResponse); - // UploadFile streams file content into a VM. - rpc UploadFile(stream UploadFileRequest) returns (UploadFileResponse); - - // DownloadFile streams file content out of a VM. - rpc DownloadFile(DownloadFileRequest) returns (stream DownloadFileChunk); - // UpdateSandboxNetwork atomically replaces the egress allow/deny rules for a running VM. rpc UpdateSandboxNetwork(UpdateSandboxNetworkRequest) returns (UpdateSandboxNetworkResponse); } @@ -90,6 +90,7 @@ message CreateVMRequest { NetworkConfig network_config = 6; map metadata = 7; // Arbitrary key-value pairs. SandboxNetworkConfig sandbox_network = 8; // Per-sandbox egress rules (optional). + map env_vars = 9; // Environment variables injected into every process in the VM. } message CreateVMResponse { @@ -98,6 +99,7 @@ message CreateVMResponse { string ip_address = 3; string tap_device = 4; uint32 pid = 5; + ResourceLimits resource_limits = 6; } // --------------------------------------------------------------------------- @@ -138,6 +140,7 @@ message ResumeVMRequest { string snapshot_path = 2; string mem_file_path = 3; SandboxNetworkConfig sandbox_network = 4; // Reapply egress rules after resume. + map env_vars = 5; // Re-inject env vars after resume. } message ResumeVMResponse { @@ -145,6 +148,7 @@ message ResumeVMResponse { string socket_path = 2; string ip_address = 3; uint32 pid = 4; + ResourceLimits resource_limits = 5; } // --------------------------------------------------------------------------- @@ -183,6 +187,27 @@ message RestoreSnapshotResponse { uint32 pid = 4; } +// --------------------------------------------------------------------------- +// DeleteSnapshot +// --------------------------------------------------------------------------- + +message DeleteSnapshotRequest { + // Optional. Identifies the sandbox this snapshot belongs to — used for + // logging and for a sanity check that the paths live inside the expected + // snapshot directory. Never required for the delete itself. + string vm_id = 1; + // Absolute path to the vmstate snapshot file. + string snapshot_path = 2; + // Absolute path to the memory dump file. + string mem_file_path = 3; +} + +message DeleteSnapshotResponse { + // True when both files are no longer on disk after this call — either + // because this call removed them, or because they were already gone. + bool deleted = 1; +} + // --------------------------------------------------------------------------- // ExecCommand // --------------------------------------------------------------------------- @@ -249,33 +274,6 @@ message SetupNetworkResponse { string mac_address = 5; } -// --------------------------------------------------------------------------- -// UploadFile -// --------------------------------------------------------------------------- - -message UploadFileRequest { - string vm_id = 1; // Set in the first message of the stream. - string path = 2; // Set in the first message of the stream. - bytes data = 3; // File content chunks (all messages). -} - -message UploadFileResponse { - int64 bytes_written = 1; -} - -// --------------------------------------------------------------------------- -// DownloadFile -// --------------------------------------------------------------------------- - -message DownloadFileRequest { - string vm_id = 1; - string path = 2; -} - -message DownloadFileChunk { - bytes data = 1; -} - // --------------------------------------------------------------------------- // UpdateSandboxNetwork // --------------------------------------------------------------------------- diff --git a/proto/vmdpb/vmd.pb.go b/proto/vmdpb/vmd.pb.go index cac110d..b56cd19 100644 --- a/proto/vmdpb/vmd.pb.go +++ b/proto/vmdpb/vmd.pb.go @@ -329,8 +329,9 @@ type CreateVMRequest struct { KernelArgs string `protobuf:"bytes,4,opt,name=kernel_args,json=kernelArgs,proto3" json:"kernel_args,omitempty"` // Boot arguments for the kernel. ResourceLimits *ResourceLimits `protobuf:"bytes,5,opt,name=resource_limits,json=resourceLimits,proto3" json:"resource_limits,omitempty"` NetworkConfig *NetworkConfig `protobuf:"bytes,6,opt,name=network_config,json=networkConfig,proto3" json:"network_config,omitempty"` - Metadata map[string]string `protobuf:"bytes,7,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // Arbitrary key-value pairs. - SandboxNetwork *SandboxNetworkConfig `protobuf:"bytes,8,opt,name=sandbox_network,json=sandboxNetwork,proto3" json:"sandbox_network,omitempty"` // Per-sandbox egress rules (optional). + Metadata map[string]string `protobuf:"bytes,7,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // Arbitrary key-value pairs. + SandboxNetwork *SandboxNetworkConfig `protobuf:"bytes,8,opt,name=sandbox_network,json=sandboxNetwork,proto3" json:"sandbox_network,omitempty"` // Per-sandbox egress rules (optional). + EnvVars map[string]string `protobuf:"bytes,9,rep,name=env_vars,json=envVars,proto3" json:"env_vars,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // Environment variables injected into every process in the VM. unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -421,15 +422,23 @@ func (x *CreateVMRequest) GetSandboxNetwork() *SandboxNetworkConfig { return nil } +func (x *CreateVMRequest) GetEnvVars() map[string]string { + if x != nil { + return x.EnvVars + } + return nil +} + type CreateVMResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - VmId string `protobuf:"bytes,1,opt,name=vm_id,json=vmId,proto3" json:"vm_id,omitempty"` - SocketPath string `protobuf:"bytes,2,opt,name=socket_path,json=socketPath,proto3" json:"socket_path,omitempty"` - IpAddress string `protobuf:"bytes,3,opt,name=ip_address,json=ipAddress,proto3" json:"ip_address,omitempty"` - TapDevice string `protobuf:"bytes,4,opt,name=tap_device,json=tapDevice,proto3" json:"tap_device,omitempty"` - Pid uint32 `protobuf:"varint,5,opt,name=pid,proto3" json:"pid,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + VmId string `protobuf:"bytes,1,opt,name=vm_id,json=vmId,proto3" json:"vm_id,omitempty"` + SocketPath string `protobuf:"bytes,2,opt,name=socket_path,json=socketPath,proto3" json:"socket_path,omitempty"` + IpAddress string `protobuf:"bytes,3,opt,name=ip_address,json=ipAddress,proto3" json:"ip_address,omitempty"` + TapDevice string `protobuf:"bytes,4,opt,name=tap_device,json=tapDevice,proto3" json:"tap_device,omitempty"` + Pid uint32 `protobuf:"varint,5,opt,name=pid,proto3" json:"pid,omitempty"` + ResourceLimits *ResourceLimits `protobuf:"bytes,6,opt,name=resource_limits,json=resourceLimits,proto3" json:"resource_limits,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *CreateVMResponse) Reset() { @@ -497,6 +506,13 @@ func (x *CreateVMResponse) GetPid() uint32 { return 0 } +func (x *CreateVMResponse) GetResourceLimits() *ResourceLimits { + if x != nil { + return x.ResourceLimits + } + return nil +} + type DestroyVMRequest struct { state protoimpl.MessageState `protogen:"open.v1"` VmId string `protobuf:"bytes,1,opt,name=vm_id,json=vmId,proto3" json:"vm_id,omitempty"` @@ -718,7 +734,8 @@ type ResumeVMRequest struct { VmId string `protobuf:"bytes,1,opt,name=vm_id,json=vmId,proto3" json:"vm_id,omitempty"` SnapshotPath string `protobuf:"bytes,2,opt,name=snapshot_path,json=snapshotPath,proto3" json:"snapshot_path,omitempty"` MemFilePath string `protobuf:"bytes,3,opt,name=mem_file_path,json=memFilePath,proto3" json:"mem_file_path,omitempty"` - SandboxNetwork *SandboxNetworkConfig `protobuf:"bytes,4,opt,name=sandbox_network,json=sandboxNetwork,proto3" json:"sandbox_network,omitempty"` // Reapply egress rules after resume. + SandboxNetwork *SandboxNetworkConfig `protobuf:"bytes,4,opt,name=sandbox_network,json=sandboxNetwork,proto3" json:"sandbox_network,omitempty"` // Reapply egress rules after resume. + EnvVars map[string]string `protobuf:"bytes,5,rep,name=env_vars,json=envVars,proto3" json:"env_vars,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // Re-inject env vars after resume. unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -781,14 +798,22 @@ func (x *ResumeVMRequest) GetSandboxNetwork() *SandboxNetworkConfig { return nil } +func (x *ResumeVMRequest) GetEnvVars() map[string]string { + if x != nil { + return x.EnvVars + } + return nil +} + type ResumeVMResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - VmId string `protobuf:"bytes,1,opt,name=vm_id,json=vmId,proto3" json:"vm_id,omitempty"` - SocketPath string `protobuf:"bytes,2,opt,name=socket_path,json=socketPath,proto3" json:"socket_path,omitempty"` - IpAddress string `protobuf:"bytes,3,opt,name=ip_address,json=ipAddress,proto3" json:"ip_address,omitempty"` - Pid uint32 `protobuf:"varint,4,opt,name=pid,proto3" json:"pid,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + VmId string `protobuf:"bytes,1,opt,name=vm_id,json=vmId,proto3" json:"vm_id,omitempty"` + SocketPath string `protobuf:"bytes,2,opt,name=socket_path,json=socketPath,proto3" json:"socket_path,omitempty"` + IpAddress string `protobuf:"bytes,3,opt,name=ip_address,json=ipAddress,proto3" json:"ip_address,omitempty"` + Pid uint32 `protobuf:"varint,4,opt,name=pid,proto3" json:"pid,omitempty"` + ResourceLimits *ResourceLimits `protobuf:"bytes,5,opt,name=resource_limits,json=resourceLimits,proto3" json:"resource_limits,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ResumeVMResponse) Reset() { @@ -849,6 +874,13 @@ func (x *ResumeVMResponse) GetPid() uint32 { return 0 } +func (x *ResumeVMResponse) GetResourceLimits() *ResourceLimits { + if x != nil { + return x.ResourceLimits + } + return nil +} + type CreateSnapshotRequest struct { state protoimpl.MessageState `protogen:"open.v1"` VmId string `protobuf:"bytes,1,opt,name=vm_id,json=vmId,proto3" json:"vm_id,omitempty"` @@ -1121,6 +1153,117 @@ func (x *RestoreSnapshotResponse) GetPid() uint32 { return 0 } +type DeleteSnapshotRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Optional. Identifies the sandbox this snapshot belongs to — used for + // logging and for a sanity check that the paths live inside the expected + // snapshot directory. Never required for the delete itself. + VmId string `protobuf:"bytes,1,opt,name=vm_id,json=vmId,proto3" json:"vm_id,omitempty"` + // Absolute path to the vmstate snapshot file. + SnapshotPath string `protobuf:"bytes,2,opt,name=snapshot_path,json=snapshotPath,proto3" json:"snapshot_path,omitempty"` + // Absolute path to the memory dump file. + MemFilePath string `protobuf:"bytes,3,opt,name=mem_file_path,json=memFilePath,proto3" json:"mem_file_path,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DeleteSnapshotRequest) Reset() { + *x = DeleteSnapshotRequest{} + mi := &file_proto_vmd_proto_msgTypes[16] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeleteSnapshotRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteSnapshotRequest) ProtoMessage() {} + +func (x *DeleteSnapshotRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_vmd_proto_msgTypes[16] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteSnapshotRequest.ProtoReflect.Descriptor instead. +func (*DeleteSnapshotRequest) Descriptor() ([]byte, []int) { + return file_proto_vmd_proto_rawDescGZIP(), []int{16} +} + +func (x *DeleteSnapshotRequest) GetVmId() string { + if x != nil { + return x.VmId + } + return "" +} + +func (x *DeleteSnapshotRequest) GetSnapshotPath() string { + if x != nil { + return x.SnapshotPath + } + return "" +} + +func (x *DeleteSnapshotRequest) GetMemFilePath() string { + if x != nil { + return x.MemFilePath + } + return "" +} + +type DeleteSnapshotResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // True when both files are no longer on disk after this call — either + // because this call removed them, or because they were already gone. + Deleted bool `protobuf:"varint,1,opt,name=deleted,proto3" json:"deleted,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DeleteSnapshotResponse) Reset() { + *x = DeleteSnapshotResponse{} + mi := &file_proto_vmd_proto_msgTypes[17] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeleteSnapshotResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteSnapshotResponse) ProtoMessage() {} + +func (x *DeleteSnapshotResponse) ProtoReflect() protoreflect.Message { + mi := &file_proto_vmd_proto_msgTypes[17] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteSnapshotResponse.ProtoReflect.Descriptor instead. +func (*DeleteSnapshotResponse) Descriptor() ([]byte, []int) { + return file_proto_vmd_proto_rawDescGZIP(), []int{17} +} + +func (x *DeleteSnapshotResponse) GetDeleted() bool { + if x != nil { + return x.Deleted + } + return false +} + type ExecCommandRequest struct { state protoimpl.MessageState `protogen:"open.v1"` VmId string `protobuf:"bytes,1,opt,name=vm_id,json=vmId,proto3" json:"vm_id,omitempty"` @@ -1135,7 +1278,7 @@ type ExecCommandRequest struct { func (x *ExecCommandRequest) Reset() { *x = ExecCommandRequest{} - mi := &file_proto_vmd_proto_msgTypes[16] + mi := &file_proto_vmd_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1147,7 +1290,7 @@ func (x *ExecCommandRequest) String() string { func (*ExecCommandRequest) ProtoMessage() {} func (x *ExecCommandRequest) ProtoReflect() protoreflect.Message { - mi := &file_proto_vmd_proto_msgTypes[16] + mi := &file_proto_vmd_proto_msgTypes[18] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1160,7 +1303,7 @@ func (x *ExecCommandRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ExecCommandRequest.ProtoReflect.Descriptor instead. func (*ExecCommandRequest) Descriptor() ([]byte, []int) { - return file_proto_vmd_proto_rawDescGZIP(), []int{16} + return file_proto_vmd_proto_rawDescGZIP(), []int{18} } func (x *ExecCommandRequest) GetVmId() string { @@ -1217,7 +1360,7 @@ type ExecCommandResponse struct { func (x *ExecCommandResponse) Reset() { *x = ExecCommandResponse{} - mi := &file_proto_vmd_proto_msgTypes[17] + mi := &file_proto_vmd_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1229,7 +1372,7 @@ func (x *ExecCommandResponse) String() string { func (*ExecCommandResponse) ProtoMessage() {} func (x *ExecCommandResponse) ProtoReflect() protoreflect.Message { - mi := &file_proto_vmd_proto_msgTypes[17] + mi := &file_proto_vmd_proto_msgTypes[19] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1242,7 +1385,7 @@ func (x *ExecCommandResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ExecCommandResponse.ProtoReflect.Descriptor instead. func (*ExecCommandResponse) Descriptor() ([]byte, []int) { - return file_proto_vmd_proto_rawDescGZIP(), []int{17} + return file_proto_vmd_proto_rawDescGZIP(), []int{19} } func (x *ExecCommandResponse) GetStdout() []byte { @@ -1282,7 +1425,7 @@ type GetVMInfoRequest struct { func (x *GetVMInfoRequest) Reset() { *x = GetVMInfoRequest{} - mi := &file_proto_vmd_proto_msgTypes[18] + mi := &file_proto_vmd_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1294,7 +1437,7 @@ func (x *GetVMInfoRequest) String() string { func (*GetVMInfoRequest) ProtoMessage() {} func (x *GetVMInfoRequest) ProtoReflect() protoreflect.Message { - mi := &file_proto_vmd_proto_msgTypes[18] + mi := &file_proto_vmd_proto_msgTypes[20] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1307,7 +1450,7 @@ func (x *GetVMInfoRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetVMInfoRequest.ProtoReflect.Descriptor instead. func (*GetVMInfoRequest) Descriptor() ([]byte, []int) { - return file_proto_vmd_proto_rawDescGZIP(), []int{18} + return file_proto_vmd_proto_rawDescGZIP(), []int{20} } func (x *GetVMInfoRequest) GetVmId() string { @@ -1334,7 +1477,7 @@ type GetVMInfoResponse struct { func (x *GetVMInfoResponse) Reset() { *x = GetVMInfoResponse{} - mi := &file_proto_vmd_proto_msgTypes[19] + mi := &file_proto_vmd_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1346,7 +1489,7 @@ func (x *GetVMInfoResponse) String() string { func (*GetVMInfoResponse) ProtoMessage() {} func (x *GetVMInfoResponse) ProtoReflect() protoreflect.Message { - mi := &file_proto_vmd_proto_msgTypes[19] + mi := &file_proto_vmd_proto_msgTypes[21] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1359,7 +1502,7 @@ func (x *GetVMInfoResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetVMInfoResponse.ProtoReflect.Descriptor instead. func (*GetVMInfoResponse) Descriptor() ([]byte, []int) { - return file_proto_vmd_proto_rawDescGZIP(), []int{19} + return file_proto_vmd_proto_rawDescGZIP(), []int{21} } func (x *GetVMInfoResponse) GetVmId() string { @@ -1435,7 +1578,7 @@ type SetupNetworkRequest struct { func (x *SetupNetworkRequest) Reset() { *x = SetupNetworkRequest{} - mi := &file_proto_vmd_proto_msgTypes[20] + mi := &file_proto_vmd_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1447,7 +1590,7 @@ func (x *SetupNetworkRequest) String() string { func (*SetupNetworkRequest) ProtoMessage() {} func (x *SetupNetworkRequest) ProtoReflect() protoreflect.Message { - mi := &file_proto_vmd_proto_msgTypes[20] + mi := &file_proto_vmd_proto_msgTypes[22] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1460,7 +1603,7 @@ func (x *SetupNetworkRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetupNetworkRequest.ProtoReflect.Descriptor instead. func (*SetupNetworkRequest) Descriptor() ([]byte, []int) { - return file_proto_vmd_proto_rawDescGZIP(), []int{20} + return file_proto_vmd_proto_rawDescGZIP(), []int{22} } func (x *SetupNetworkRequest) GetVmId() string { @@ -1490,7 +1633,7 @@ type SetupNetworkResponse struct { func (x *SetupNetworkResponse) Reset() { *x = SetupNetworkResponse{} - mi := &file_proto_vmd_proto_msgTypes[21] + mi := &file_proto_vmd_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1502,7 +1645,7 @@ func (x *SetupNetworkResponse) String() string { func (*SetupNetworkResponse) ProtoMessage() {} func (x *SetupNetworkResponse) ProtoReflect() protoreflect.Message { - mi := &file_proto_vmd_proto_msgTypes[21] + mi := &file_proto_vmd_proto_msgTypes[23] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1515,7 +1658,7 @@ func (x *SetupNetworkResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetupNetworkResponse.ProtoReflect.Descriptor instead. func (*SetupNetworkResponse) Descriptor() ([]byte, []int) { - return file_proto_vmd_proto_rawDescGZIP(), []int{21} + return file_proto_vmd_proto_rawDescGZIP(), []int{23} } func (x *SetupNetworkResponse) GetVmId() string { @@ -1553,206 +1696,6 @@ func (x *SetupNetworkResponse) GetMacAddress() string { return "" } -type UploadFileRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - VmId string `protobuf:"bytes,1,opt,name=vm_id,json=vmId,proto3" json:"vm_id,omitempty"` // Set in the first message of the stream. - Path string `protobuf:"bytes,2,opt,name=path,proto3" json:"path,omitempty"` // Set in the first message of the stream. - Data []byte `protobuf:"bytes,3,opt,name=data,proto3" json:"data,omitempty"` // File content chunks (all messages). - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *UploadFileRequest) Reset() { - *x = UploadFileRequest{} - mi := &file_proto_vmd_proto_msgTypes[22] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *UploadFileRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*UploadFileRequest) ProtoMessage() {} - -func (x *UploadFileRequest) ProtoReflect() protoreflect.Message { - mi := &file_proto_vmd_proto_msgTypes[22] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use UploadFileRequest.ProtoReflect.Descriptor instead. -func (*UploadFileRequest) Descriptor() ([]byte, []int) { - return file_proto_vmd_proto_rawDescGZIP(), []int{22} -} - -func (x *UploadFileRequest) GetVmId() string { - if x != nil { - return x.VmId - } - return "" -} - -func (x *UploadFileRequest) GetPath() string { - if x != nil { - return x.Path - } - return "" -} - -func (x *UploadFileRequest) GetData() []byte { - if x != nil { - return x.Data - } - return nil -} - -type UploadFileResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - BytesWritten int64 `protobuf:"varint,1,opt,name=bytes_written,json=bytesWritten,proto3" json:"bytes_written,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *UploadFileResponse) Reset() { - *x = UploadFileResponse{} - mi := &file_proto_vmd_proto_msgTypes[23] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *UploadFileResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*UploadFileResponse) ProtoMessage() {} - -func (x *UploadFileResponse) ProtoReflect() protoreflect.Message { - mi := &file_proto_vmd_proto_msgTypes[23] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use UploadFileResponse.ProtoReflect.Descriptor instead. -func (*UploadFileResponse) Descriptor() ([]byte, []int) { - return file_proto_vmd_proto_rawDescGZIP(), []int{23} -} - -func (x *UploadFileResponse) GetBytesWritten() int64 { - if x != nil { - return x.BytesWritten - } - return 0 -} - -type DownloadFileRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - VmId string `protobuf:"bytes,1,opt,name=vm_id,json=vmId,proto3" json:"vm_id,omitempty"` - Path string `protobuf:"bytes,2,opt,name=path,proto3" json:"path,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *DownloadFileRequest) Reset() { - *x = DownloadFileRequest{} - mi := &file_proto_vmd_proto_msgTypes[24] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *DownloadFileRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*DownloadFileRequest) ProtoMessage() {} - -func (x *DownloadFileRequest) ProtoReflect() protoreflect.Message { - mi := &file_proto_vmd_proto_msgTypes[24] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use DownloadFileRequest.ProtoReflect.Descriptor instead. -func (*DownloadFileRequest) Descriptor() ([]byte, []int) { - return file_proto_vmd_proto_rawDescGZIP(), []int{24} -} - -func (x *DownloadFileRequest) GetVmId() string { - if x != nil { - return x.VmId - } - return "" -} - -func (x *DownloadFileRequest) GetPath() string { - if x != nil { - return x.Path - } - return "" -} - -type DownloadFileChunk struct { - state protoimpl.MessageState `protogen:"open.v1"` - Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *DownloadFileChunk) Reset() { - *x = DownloadFileChunk{} - mi := &file_proto_vmd_proto_msgTypes[25] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *DownloadFileChunk) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*DownloadFileChunk) ProtoMessage() {} - -func (x *DownloadFileChunk) ProtoReflect() protoreflect.Message { - mi := &file_proto_vmd_proto_msgTypes[25] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use DownloadFileChunk.ProtoReflect.Descriptor instead. -func (*DownloadFileChunk) Descriptor() ([]byte, []int) { - return file_proto_vmd_proto_rawDescGZIP(), []int{25} -} - -func (x *DownloadFileChunk) GetData() []byte { - if x != nil { - return x.Data - } - return nil -} - type UpdateSandboxNetworkRequest struct { state protoimpl.MessageState `protogen:"open.v1"` VmId string `protobuf:"bytes,1,opt,name=vm_id,json=vmId,proto3" json:"vm_id,omitempty"` @@ -1763,7 +1706,7 @@ type UpdateSandboxNetworkRequest struct { func (x *UpdateSandboxNetworkRequest) Reset() { *x = UpdateSandboxNetworkRequest{} - mi := &file_proto_vmd_proto_msgTypes[26] + mi := &file_proto_vmd_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1775,7 +1718,7 @@ func (x *UpdateSandboxNetworkRequest) String() string { func (*UpdateSandboxNetworkRequest) ProtoMessage() {} func (x *UpdateSandboxNetworkRequest) ProtoReflect() protoreflect.Message { - mi := &file_proto_vmd_proto_msgTypes[26] + mi := &file_proto_vmd_proto_msgTypes[24] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1788,7 +1731,7 @@ func (x *UpdateSandboxNetworkRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use UpdateSandboxNetworkRequest.ProtoReflect.Descriptor instead. func (*UpdateSandboxNetworkRequest) Descriptor() ([]byte, []int) { - return file_proto_vmd_proto_rawDescGZIP(), []int{26} + return file_proto_vmd_proto_rawDescGZIP(), []int{24} } func (x *UpdateSandboxNetworkRequest) GetVmId() string { @@ -1814,7 +1757,7 @@ type UpdateSandboxNetworkResponse struct { func (x *UpdateSandboxNetworkResponse) Reset() { *x = UpdateSandboxNetworkResponse{} - mi := &file_proto_vmd_proto_msgTypes[27] + mi := &file_proto_vmd_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1826,7 +1769,7 @@ func (x *UpdateSandboxNetworkResponse) String() string { func (*UpdateSandboxNetworkResponse) ProtoMessage() {} func (x *UpdateSandboxNetworkResponse) ProtoReflect() protoreflect.Message { - mi := &file_proto_vmd_proto_msgTypes[27] + mi := &file_proto_vmd_proto_msgTypes[25] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1839,7 +1782,7 @@ func (x *UpdateSandboxNetworkResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use UpdateSandboxNetworkResponse.ProtoReflect.Descriptor instead. func (*UpdateSandboxNetworkResponse) Descriptor() ([]byte, []int) { - return file_proto_vmd_proto_rawDescGZIP(), []int{27} + return file_proto_vmd_proto_rawDescGZIP(), []int{25} } func (x *UpdateSandboxNetworkResponse) GetVmId() string { @@ -1874,7 +1817,7 @@ const file_proto_vmd_proto_rawDesc = "" + "\x1aSandboxNetworkEgressConfig\x12#\n" + "\rallowed_cidrs\x18\x01 \x03(\tR\fallowedCidrs\x12!\n" + "\fdenied_cidrs\x18\x02 \x03(\tR\vdeniedCidrs\x12'\n" + - "\x0fallowed_domains\x18\x03 \x03(\tR\x0eallowedDomains\"\x84\x04\n" + + "\x0fallowed_domains\x18\x03 \x03(\tR\x0eallowedDomains\"\x8c\x05\n" + "\x0fCreateVMRequest\x12\x13\n" + "\x05vm_id\x18\x01 \x01(\tR\x04vmId\x12(\n" + "\x10base_rootfs_path\x18\x02 \x01(\tR\x0ebaseRootfsPath\x12\x1f\n" + @@ -1885,10 +1828,14 @@ const file_proto_vmd_proto_rawDesc = "" + "\x0fresource_limits\x18\x05 \x01(\v2!.superserve.vmd.v1.ResourceLimitsR\x0eresourceLimits\x12G\n" + "\x0enetwork_config\x18\x06 \x01(\v2 .superserve.vmd.v1.NetworkConfigR\rnetworkConfig\x12L\n" + "\bmetadata\x18\a \x03(\v20.superserve.vmd.v1.CreateVMRequest.MetadataEntryR\bmetadata\x12P\n" + - "\x0fsandbox_network\x18\b \x01(\v2'.superserve.vmd.v1.SandboxNetworkConfigR\x0esandboxNetwork\x1a;\n" + + "\x0fsandbox_network\x18\b \x01(\v2'.superserve.vmd.v1.SandboxNetworkConfigR\x0esandboxNetwork\x12J\n" + + "\benv_vars\x18\t \x03(\v2/.superserve.vmd.v1.CreateVMRequest.EnvVarsEntryR\aenvVars\x1a;\n" + "\rMetadataEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + - "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"\x98\x01\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\x1a:\n" + + "\fEnvVarsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"\xe4\x01\n" + "\x10CreateVMResponse\x12\x13\n" + "\x05vm_id\x18\x01 \x01(\tR\x04vmId\x12\x1f\n" + "\vsocket_path\x18\x02 \x01(\tR\n" + @@ -1897,7 +1844,8 @@ const file_proto_vmd_proto_rawDesc = "" + "ip_address\x18\x03 \x01(\tR\tipAddress\x12\x1d\n" + "\n" + "tap_device\x18\x04 \x01(\tR\ttapDevice\x12\x10\n" + - "\x03pid\x18\x05 \x01(\rR\x03pid\"=\n" + + "\x03pid\x18\x05 \x01(\rR\x03pid\x12J\n" + + "\x0fresource_limits\x18\x06 \x01(\v2!.superserve.vmd.v1.ResourceLimitsR\x0eresourceLimits\"=\n" + "\x10DestroyVMRequest\x12\x13\n" + "\x05vm_id\x18\x01 \x01(\tR\x04vmId\x12\x14\n" + "\x05force\x18\x02 \x01(\bR\x05force\"G\n" + @@ -1911,19 +1859,24 @@ const file_proto_vmd_proto_rawDesc = "" + "\x0fPauseVMResponse\x12\x13\n" + "\x05vm_id\x18\x01 \x01(\tR\x04vmId\x12#\n" + "\rsnapshot_path\x18\x02 \x01(\tR\fsnapshotPath\x12\"\n" + - "\rmem_file_path\x18\x03 \x01(\tR\vmemFilePath\"\xc1\x01\n" + + "\rmem_file_path\x18\x03 \x01(\tR\vmemFilePath\"\xc9\x02\n" + "\x0fResumeVMRequest\x12\x13\n" + "\x05vm_id\x18\x01 \x01(\tR\x04vmId\x12#\n" + "\rsnapshot_path\x18\x02 \x01(\tR\fsnapshotPath\x12\"\n" + "\rmem_file_path\x18\x03 \x01(\tR\vmemFilePath\x12P\n" + - "\x0fsandbox_network\x18\x04 \x01(\v2'.superserve.vmd.v1.SandboxNetworkConfigR\x0esandboxNetwork\"y\n" + + "\x0fsandbox_network\x18\x04 \x01(\v2'.superserve.vmd.v1.SandboxNetworkConfigR\x0esandboxNetwork\x12J\n" + + "\benv_vars\x18\x05 \x03(\v2/.superserve.vmd.v1.ResumeVMRequest.EnvVarsEntryR\aenvVars\x1a:\n" + + "\fEnvVarsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"\xc5\x01\n" + "\x10ResumeVMResponse\x12\x13\n" + "\x05vm_id\x18\x01 \x01(\tR\x04vmId\x12\x1f\n" + "\vsocket_path\x18\x02 \x01(\tR\n" + "socketPath\x12\x1d\n" + "\n" + "ip_address\x18\x03 \x01(\tR\tipAddress\x12\x10\n" + - "\x03pid\x18\x04 \x01(\rR\x03pid\"O\n" + + "\x03pid\x18\x04 \x01(\rR\x03pid\x12J\n" + + "\x0fresource_limits\x18\x05 \x01(\v2!.superserve.vmd.v1.ResourceLimitsR\x0eresourceLimits\"O\n" + "\x15CreateSnapshotRequest\x12\x13\n" + "\x05vm_id\x18\x01 \x01(\tR\x04vmId\x12!\n" + "\fsnapshot_dir\x18\x02 \x01(\tR\vsnapshotDir\"\x9e\x01\n" + @@ -1945,7 +1898,13 @@ const file_proto_vmd_proto_rawDesc = "" + "socketPath\x12\x1d\n" + "\n" + "ip_address\x18\x03 \x01(\tR\tipAddress\x12\x10\n" + - "\x03pid\x18\x04 \x01(\rR\x03pid\"\x9b\x02\n" + + "\x03pid\x18\x04 \x01(\rR\x03pid\"u\n" + + "\x15DeleteSnapshotRequest\x12\x13\n" + + "\x05vm_id\x18\x01 \x01(\tR\x04vmId\x12#\n" + + "\rsnapshot_path\x18\x02 \x01(\tR\fsnapshotPath\x12\"\n" + + "\rmem_file_path\x18\x03 \x01(\tR\vmemFilePath\"2\n" + + "\x16DeleteSnapshotResponse\x12\x18\n" + + "\adeleted\x18\x01 \x01(\bR\adeleted\"\x9b\x02\n" + "\x12ExecCommandRequest\x12\x13\n" + "\x05vm_id\x18\x01 \x01(\tR\x04vmId\x12\x18\n" + "\acommand\x18\x02 \x01(\tR\acommand\x12\x12\n" + @@ -1990,18 +1949,7 @@ const file_proto_vmd_proto_rawDesc = "" + "\n" + "gateway_ip\x18\x04 \x01(\tR\tgatewayIp\x12\x1f\n" + "\vmac_address\x18\x05 \x01(\tR\n" + - "macAddress\"P\n" + - "\x11UploadFileRequest\x12\x13\n" + - "\x05vm_id\x18\x01 \x01(\tR\x04vmId\x12\x12\n" + - "\x04path\x18\x02 \x01(\tR\x04path\x12\x12\n" + - "\x04data\x18\x03 \x01(\fR\x04data\"9\n" + - "\x12UploadFileResponse\x12#\n" + - "\rbytes_written\x18\x01 \x01(\x03R\fbytesWritten\">\n" + - "\x13DownloadFileRequest\x12\x13\n" + - "\x05vm_id\x18\x01 \x01(\tR\x04vmId\x12\x12\n" + - "\x04path\x18\x02 \x01(\tR\x04path\"'\n" + - "\x11DownloadFileChunk\x12\x12\n" + - "\x04data\x18\x01 \x01(\fR\x04data\"y\n" + + "macAddress\"y\n" + "\x1bUpdateSandboxNetworkRequest\x12\x13\n" + "\x05vm_id\x18\x01 \x01(\tR\x04vmId\x12E\n" + "\x06egress\x18\x02 \x01(\v2-.superserve.vmd.v1.SandboxNetworkEgressConfigR\x06egress\"3\n" + @@ -2013,20 +1961,18 @@ const file_proto_vmd_proto_rawDesc = "" + "\x11VM_STATUS_RUNNING\x10\x02\x12\x14\n" + "\x10VM_STATUS_PAUSED\x10\x03\x12\x15\n" + "\x11VM_STATUS_STOPPED\x10\x04\x12\x13\n" + - "\x0fVM_STATUS_ERROR\x10\x052\xfe\b\n" + + "\x0fVM_STATUS_ERROR\x10\x052\xa8\b\n" + "\bVMDaemon\x12S\n" + "\bCreateVM\x12\".superserve.vmd.v1.CreateVMRequest\x1a#.superserve.vmd.v1.CreateVMResponse\x12V\n" + "\tDestroyVM\x12#.superserve.vmd.v1.DestroyVMRequest\x1a$.superserve.vmd.v1.DestroyVMResponse\x12P\n" + "\aPauseVM\x12!.superserve.vmd.v1.PauseVMRequest\x1a\".superserve.vmd.v1.PauseVMResponse\x12S\n" + "\bResumeVM\x12\".superserve.vmd.v1.ResumeVMRequest\x1a#.superserve.vmd.v1.ResumeVMResponse\x12e\n" + "\x0eCreateSnapshot\x12(.superserve.vmd.v1.CreateSnapshotRequest\x1a).superserve.vmd.v1.CreateSnapshotResponse\x12h\n" + - "\x0fRestoreSnapshot\x12).superserve.vmd.v1.RestoreSnapshotRequest\x1a*.superserve.vmd.v1.RestoreSnapshotResponse\x12^\n" + + "\x0fRestoreSnapshot\x12).superserve.vmd.v1.RestoreSnapshotRequest\x1a*.superserve.vmd.v1.RestoreSnapshotResponse\x12e\n" + + "\x0eDeleteSnapshot\x12(.superserve.vmd.v1.DeleteSnapshotRequest\x1a).superserve.vmd.v1.DeleteSnapshotResponse\x12^\n" + "\vExecCommand\x12%.superserve.vmd.v1.ExecCommandRequest\x1a&.superserve.vmd.v1.ExecCommandResponse0\x01\x12V\n" + "\tGetVMInfo\x12#.superserve.vmd.v1.GetVMInfoRequest\x1a$.superserve.vmd.v1.GetVMInfoResponse\x12_\n" + - "\fSetupNetwork\x12&.superserve.vmd.v1.SetupNetworkRequest\x1a'.superserve.vmd.v1.SetupNetworkResponse\x12[\n" + - "\n" + - "UploadFile\x12$.superserve.vmd.v1.UploadFileRequest\x1a%.superserve.vmd.v1.UploadFileResponse(\x01\x12^\n" + - "\fDownloadFile\x12&.superserve.vmd.v1.DownloadFileRequest\x1a$.superserve.vmd.v1.DownloadFileChunk0\x01\x12w\n" + + "\fSetupNetwork\x12&.superserve.vmd.v1.SetupNetworkRequest\x1a'.superserve.vmd.v1.SetupNetworkResponse\x12w\n" + "\x14UpdateSandboxNetwork\x12..superserve.vmd.v1.UpdateSandboxNetworkRequest\x1a/.superserve.vmd.v1.UpdateSandboxNetworkResponseB.Z,github.com/superserve-ai/sandbox/proto/vmdpbb\x06proto3" var ( @@ -2061,19 +2007,19 @@ var file_proto_vmd_proto_goTypes = []any{ (*CreateSnapshotResponse)(nil), // 14: superserve.vmd.v1.CreateSnapshotResponse (*RestoreSnapshotRequest)(nil), // 15: superserve.vmd.v1.RestoreSnapshotRequest (*RestoreSnapshotResponse)(nil), // 16: superserve.vmd.v1.RestoreSnapshotResponse - (*ExecCommandRequest)(nil), // 17: superserve.vmd.v1.ExecCommandRequest - (*ExecCommandResponse)(nil), // 18: superserve.vmd.v1.ExecCommandResponse - (*GetVMInfoRequest)(nil), // 19: superserve.vmd.v1.GetVMInfoRequest - (*GetVMInfoResponse)(nil), // 20: superserve.vmd.v1.GetVMInfoResponse - (*SetupNetworkRequest)(nil), // 21: superserve.vmd.v1.SetupNetworkRequest - (*SetupNetworkResponse)(nil), // 22: superserve.vmd.v1.SetupNetworkResponse - (*UploadFileRequest)(nil), // 23: superserve.vmd.v1.UploadFileRequest - (*UploadFileResponse)(nil), // 24: superserve.vmd.v1.UploadFileResponse - (*DownloadFileRequest)(nil), // 25: superserve.vmd.v1.DownloadFileRequest - (*DownloadFileChunk)(nil), // 26: superserve.vmd.v1.DownloadFileChunk - (*UpdateSandboxNetworkRequest)(nil), // 27: superserve.vmd.v1.UpdateSandboxNetworkRequest - (*UpdateSandboxNetworkResponse)(nil), // 28: superserve.vmd.v1.UpdateSandboxNetworkResponse - nil, // 29: superserve.vmd.v1.CreateVMRequest.MetadataEntry + (*DeleteSnapshotRequest)(nil), // 17: superserve.vmd.v1.DeleteSnapshotRequest + (*DeleteSnapshotResponse)(nil), // 18: superserve.vmd.v1.DeleteSnapshotResponse + (*ExecCommandRequest)(nil), // 19: superserve.vmd.v1.ExecCommandRequest + (*ExecCommandResponse)(nil), // 20: superserve.vmd.v1.ExecCommandResponse + (*GetVMInfoRequest)(nil), // 21: superserve.vmd.v1.GetVMInfoRequest + (*GetVMInfoResponse)(nil), // 22: superserve.vmd.v1.GetVMInfoResponse + (*SetupNetworkRequest)(nil), // 23: superserve.vmd.v1.SetupNetworkRequest + (*SetupNetworkResponse)(nil), // 24: superserve.vmd.v1.SetupNetworkResponse + (*UpdateSandboxNetworkRequest)(nil), // 25: superserve.vmd.v1.UpdateSandboxNetworkRequest + (*UpdateSandboxNetworkResponse)(nil), // 26: superserve.vmd.v1.UpdateSandboxNetworkResponse + nil, // 27: superserve.vmd.v1.CreateVMRequest.MetadataEntry + nil, // 28: superserve.vmd.v1.CreateVMRequest.EnvVarsEntry + nil, // 29: superserve.vmd.v1.ResumeVMRequest.EnvVarsEntry nil, // 30: superserve.vmd.v1.ExecCommandRequest.EnvEntry nil, // 31: superserve.vmd.v1.GetVMInfoResponse.MetadataEntry } @@ -2081,46 +2027,48 @@ var file_proto_vmd_proto_depIdxs = []int32{ 4, // 0: superserve.vmd.v1.SandboxNetworkConfig.egress:type_name -> superserve.vmd.v1.SandboxNetworkEgressConfig 1, // 1: superserve.vmd.v1.CreateVMRequest.resource_limits:type_name -> superserve.vmd.v1.ResourceLimits 2, // 2: superserve.vmd.v1.CreateVMRequest.network_config:type_name -> superserve.vmd.v1.NetworkConfig - 29, // 3: superserve.vmd.v1.CreateVMRequest.metadata:type_name -> superserve.vmd.v1.CreateVMRequest.MetadataEntry + 27, // 3: superserve.vmd.v1.CreateVMRequest.metadata:type_name -> superserve.vmd.v1.CreateVMRequest.MetadataEntry 3, // 4: superserve.vmd.v1.CreateVMRequest.sandbox_network:type_name -> superserve.vmd.v1.SandboxNetworkConfig - 3, // 5: superserve.vmd.v1.ResumeVMRequest.sandbox_network:type_name -> superserve.vmd.v1.SandboxNetworkConfig - 1, // 6: superserve.vmd.v1.RestoreSnapshotRequest.resource_limits:type_name -> superserve.vmd.v1.ResourceLimits - 2, // 7: superserve.vmd.v1.RestoreSnapshotRequest.network_config:type_name -> superserve.vmd.v1.NetworkConfig - 30, // 8: superserve.vmd.v1.ExecCommandRequest.env:type_name -> superserve.vmd.v1.ExecCommandRequest.EnvEntry - 0, // 9: superserve.vmd.v1.GetVMInfoResponse.status:type_name -> superserve.vmd.v1.VMStatus - 1, // 10: superserve.vmd.v1.GetVMInfoResponse.resource_limits:type_name -> superserve.vmd.v1.ResourceLimits - 31, // 11: superserve.vmd.v1.GetVMInfoResponse.metadata:type_name -> superserve.vmd.v1.GetVMInfoResponse.MetadataEntry - 2, // 12: superserve.vmd.v1.SetupNetworkRequest.network_config:type_name -> superserve.vmd.v1.NetworkConfig - 4, // 13: superserve.vmd.v1.UpdateSandboxNetworkRequest.egress:type_name -> superserve.vmd.v1.SandboxNetworkEgressConfig - 5, // 14: superserve.vmd.v1.VMDaemon.CreateVM:input_type -> superserve.vmd.v1.CreateVMRequest - 7, // 15: superserve.vmd.v1.VMDaemon.DestroyVM:input_type -> superserve.vmd.v1.DestroyVMRequest - 9, // 16: superserve.vmd.v1.VMDaemon.PauseVM:input_type -> superserve.vmd.v1.PauseVMRequest - 11, // 17: superserve.vmd.v1.VMDaemon.ResumeVM:input_type -> superserve.vmd.v1.ResumeVMRequest - 13, // 18: superserve.vmd.v1.VMDaemon.CreateSnapshot:input_type -> superserve.vmd.v1.CreateSnapshotRequest - 15, // 19: superserve.vmd.v1.VMDaemon.RestoreSnapshot:input_type -> superserve.vmd.v1.RestoreSnapshotRequest - 17, // 20: superserve.vmd.v1.VMDaemon.ExecCommand:input_type -> superserve.vmd.v1.ExecCommandRequest - 19, // 21: superserve.vmd.v1.VMDaemon.GetVMInfo:input_type -> superserve.vmd.v1.GetVMInfoRequest - 21, // 22: superserve.vmd.v1.VMDaemon.SetupNetwork:input_type -> superserve.vmd.v1.SetupNetworkRequest - 23, // 23: superserve.vmd.v1.VMDaemon.UploadFile:input_type -> superserve.vmd.v1.UploadFileRequest - 25, // 24: superserve.vmd.v1.VMDaemon.DownloadFile:input_type -> superserve.vmd.v1.DownloadFileRequest - 27, // 25: superserve.vmd.v1.VMDaemon.UpdateSandboxNetwork:input_type -> superserve.vmd.v1.UpdateSandboxNetworkRequest - 6, // 26: superserve.vmd.v1.VMDaemon.CreateVM:output_type -> superserve.vmd.v1.CreateVMResponse - 8, // 27: superserve.vmd.v1.VMDaemon.DestroyVM:output_type -> superserve.vmd.v1.DestroyVMResponse - 10, // 28: superserve.vmd.v1.VMDaemon.PauseVM:output_type -> superserve.vmd.v1.PauseVMResponse - 12, // 29: superserve.vmd.v1.VMDaemon.ResumeVM:output_type -> superserve.vmd.v1.ResumeVMResponse - 14, // 30: superserve.vmd.v1.VMDaemon.CreateSnapshot:output_type -> superserve.vmd.v1.CreateSnapshotResponse - 16, // 31: superserve.vmd.v1.VMDaemon.RestoreSnapshot:output_type -> superserve.vmd.v1.RestoreSnapshotResponse - 18, // 32: superserve.vmd.v1.VMDaemon.ExecCommand:output_type -> superserve.vmd.v1.ExecCommandResponse - 20, // 33: superserve.vmd.v1.VMDaemon.GetVMInfo:output_type -> superserve.vmd.v1.GetVMInfoResponse - 22, // 34: superserve.vmd.v1.VMDaemon.SetupNetwork:output_type -> superserve.vmd.v1.SetupNetworkResponse - 24, // 35: superserve.vmd.v1.VMDaemon.UploadFile:output_type -> superserve.vmd.v1.UploadFileResponse - 26, // 36: superserve.vmd.v1.VMDaemon.DownloadFile:output_type -> superserve.vmd.v1.DownloadFileChunk - 28, // 37: superserve.vmd.v1.VMDaemon.UpdateSandboxNetwork:output_type -> superserve.vmd.v1.UpdateSandboxNetworkResponse - 26, // [26:38] is the sub-list for method output_type - 14, // [14:26] is the sub-list for method input_type - 14, // [14:14] is the sub-list for extension type_name - 14, // [14:14] is the sub-list for extension extendee - 0, // [0:14] is the sub-list for field type_name + 28, // 5: superserve.vmd.v1.CreateVMRequest.env_vars:type_name -> superserve.vmd.v1.CreateVMRequest.EnvVarsEntry + 1, // 6: superserve.vmd.v1.CreateVMResponse.resource_limits:type_name -> superserve.vmd.v1.ResourceLimits + 3, // 7: superserve.vmd.v1.ResumeVMRequest.sandbox_network:type_name -> superserve.vmd.v1.SandboxNetworkConfig + 29, // 8: superserve.vmd.v1.ResumeVMRequest.env_vars:type_name -> superserve.vmd.v1.ResumeVMRequest.EnvVarsEntry + 1, // 9: superserve.vmd.v1.ResumeVMResponse.resource_limits:type_name -> superserve.vmd.v1.ResourceLimits + 1, // 10: superserve.vmd.v1.RestoreSnapshotRequest.resource_limits:type_name -> superserve.vmd.v1.ResourceLimits + 2, // 11: superserve.vmd.v1.RestoreSnapshotRequest.network_config:type_name -> superserve.vmd.v1.NetworkConfig + 30, // 12: superserve.vmd.v1.ExecCommandRequest.env:type_name -> superserve.vmd.v1.ExecCommandRequest.EnvEntry + 0, // 13: superserve.vmd.v1.GetVMInfoResponse.status:type_name -> superserve.vmd.v1.VMStatus + 1, // 14: superserve.vmd.v1.GetVMInfoResponse.resource_limits:type_name -> superserve.vmd.v1.ResourceLimits + 31, // 15: superserve.vmd.v1.GetVMInfoResponse.metadata:type_name -> superserve.vmd.v1.GetVMInfoResponse.MetadataEntry + 2, // 16: superserve.vmd.v1.SetupNetworkRequest.network_config:type_name -> superserve.vmd.v1.NetworkConfig + 4, // 17: superserve.vmd.v1.UpdateSandboxNetworkRequest.egress:type_name -> superserve.vmd.v1.SandboxNetworkEgressConfig + 5, // 18: superserve.vmd.v1.VMDaemon.CreateVM:input_type -> superserve.vmd.v1.CreateVMRequest + 7, // 19: superserve.vmd.v1.VMDaemon.DestroyVM:input_type -> superserve.vmd.v1.DestroyVMRequest + 9, // 20: superserve.vmd.v1.VMDaemon.PauseVM:input_type -> superserve.vmd.v1.PauseVMRequest + 11, // 21: superserve.vmd.v1.VMDaemon.ResumeVM:input_type -> superserve.vmd.v1.ResumeVMRequest + 13, // 22: superserve.vmd.v1.VMDaemon.CreateSnapshot:input_type -> superserve.vmd.v1.CreateSnapshotRequest + 15, // 23: superserve.vmd.v1.VMDaemon.RestoreSnapshot:input_type -> superserve.vmd.v1.RestoreSnapshotRequest + 17, // 24: superserve.vmd.v1.VMDaemon.DeleteSnapshot:input_type -> superserve.vmd.v1.DeleteSnapshotRequest + 19, // 25: superserve.vmd.v1.VMDaemon.ExecCommand:input_type -> superserve.vmd.v1.ExecCommandRequest + 21, // 26: superserve.vmd.v1.VMDaemon.GetVMInfo:input_type -> superserve.vmd.v1.GetVMInfoRequest + 23, // 27: superserve.vmd.v1.VMDaemon.SetupNetwork:input_type -> superserve.vmd.v1.SetupNetworkRequest + 25, // 28: superserve.vmd.v1.VMDaemon.UpdateSandboxNetwork:input_type -> superserve.vmd.v1.UpdateSandboxNetworkRequest + 6, // 29: superserve.vmd.v1.VMDaemon.CreateVM:output_type -> superserve.vmd.v1.CreateVMResponse + 8, // 30: superserve.vmd.v1.VMDaemon.DestroyVM:output_type -> superserve.vmd.v1.DestroyVMResponse + 10, // 31: superserve.vmd.v1.VMDaemon.PauseVM:output_type -> superserve.vmd.v1.PauseVMResponse + 12, // 32: superserve.vmd.v1.VMDaemon.ResumeVM:output_type -> superserve.vmd.v1.ResumeVMResponse + 14, // 33: superserve.vmd.v1.VMDaemon.CreateSnapshot:output_type -> superserve.vmd.v1.CreateSnapshotResponse + 16, // 34: superserve.vmd.v1.VMDaemon.RestoreSnapshot:output_type -> superserve.vmd.v1.RestoreSnapshotResponse + 18, // 35: superserve.vmd.v1.VMDaemon.DeleteSnapshot:output_type -> superserve.vmd.v1.DeleteSnapshotResponse + 20, // 36: superserve.vmd.v1.VMDaemon.ExecCommand:output_type -> superserve.vmd.v1.ExecCommandResponse + 22, // 37: superserve.vmd.v1.VMDaemon.GetVMInfo:output_type -> superserve.vmd.v1.GetVMInfoResponse + 24, // 38: superserve.vmd.v1.VMDaemon.SetupNetwork:output_type -> superserve.vmd.v1.SetupNetworkResponse + 26, // 39: superserve.vmd.v1.VMDaemon.UpdateSandboxNetwork:output_type -> superserve.vmd.v1.UpdateSandboxNetworkResponse + 29, // [29:40] is the sub-list for method output_type + 18, // [18:29] is the sub-list for method input_type + 18, // [18:18] is the sub-list for extension type_name + 18, // [18:18] is the sub-list for extension extendee + 0, // [0:18] is the sub-list for field type_name } func init() { file_proto_vmd_proto_init() } diff --git a/proto/vmdpb/vmd_grpc.pb.go b/proto/vmdpb/vmd_grpc.pb.go index b34535b..8dbf0d4 100644 --- a/proto/vmdpb/vmd_grpc.pb.go +++ b/proto/vmdpb/vmd_grpc.pb.go @@ -25,11 +25,10 @@ const ( VMDaemon_ResumeVM_FullMethodName = "/superserve.vmd.v1.VMDaemon/ResumeVM" VMDaemon_CreateSnapshot_FullMethodName = "/superserve.vmd.v1.VMDaemon/CreateSnapshot" VMDaemon_RestoreSnapshot_FullMethodName = "/superserve.vmd.v1.VMDaemon/RestoreSnapshot" + VMDaemon_DeleteSnapshot_FullMethodName = "/superserve.vmd.v1.VMDaemon/DeleteSnapshot" VMDaemon_ExecCommand_FullMethodName = "/superserve.vmd.v1.VMDaemon/ExecCommand" VMDaemon_GetVMInfo_FullMethodName = "/superserve.vmd.v1.VMDaemon/GetVMInfo" VMDaemon_SetupNetwork_FullMethodName = "/superserve.vmd.v1.VMDaemon/SetupNetwork" - VMDaemon_UploadFile_FullMethodName = "/superserve.vmd.v1.VMDaemon/UploadFile" - VMDaemon_DownloadFile_FullMethodName = "/superserve.vmd.v1.VMDaemon/DownloadFile" VMDaemon_UpdateSandboxNetwork_FullMethodName = "/superserve.vmd.v1.VMDaemon/UpdateSandboxNetwork" ) @@ -51,16 +50,17 @@ type VMDaemonClient interface { CreateSnapshot(ctx context.Context, in *CreateSnapshotRequest, opts ...grpc.CallOption) (*CreateSnapshotResponse, error) // RestoreSnapshot boots a VM from a previously captured snapshot. RestoreSnapshot(ctx context.Context, in *RestoreSnapshotRequest, opts ...grpc.CallOption) (*RestoreSnapshotResponse, error) + // DeleteSnapshot removes a snapshot's on-disk artifacts (vmstate + memory + // file). Idempotent — missing files are treated as success. Used by the + // control plane to garbage-collect the previous snapshot when a sandbox is + // re-paused and only the latest snapshot is reachable. + DeleteSnapshot(ctx context.Context, in *DeleteSnapshotRequest, opts ...grpc.CallOption) (*DeleteSnapshotResponse, error) // ExecCommand runs a command inside a VM and streams stdout/stderr back. ExecCommand(ctx context.Context, in *ExecCommandRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExecCommandResponse], error) // GetVMInfo returns the current status and metadata of a VM. GetVMInfo(ctx context.Context, in *GetVMInfoRequest, opts ...grpc.CallOption) (*GetVMInfoResponse, error) // SetupNetwork configures networking for a VM (TAP device, NAT, IP allocation). SetupNetwork(ctx context.Context, in *SetupNetworkRequest, opts ...grpc.CallOption) (*SetupNetworkResponse, error) - // UploadFile streams file content into a VM. - UploadFile(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[UploadFileRequest, UploadFileResponse], error) - // DownloadFile streams file content out of a VM. - DownloadFile(ctx context.Context, in *DownloadFileRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[DownloadFileChunk], error) // UpdateSandboxNetwork atomically replaces the egress allow/deny rules for a running VM. UpdateSandboxNetwork(ctx context.Context, in *UpdateSandboxNetworkRequest, opts ...grpc.CallOption) (*UpdateSandboxNetworkResponse, error) } @@ -133,6 +133,16 @@ func (c *vMDaemonClient) RestoreSnapshot(ctx context.Context, in *RestoreSnapsho return out, nil } +func (c *vMDaemonClient) DeleteSnapshot(ctx context.Context, in *DeleteSnapshotRequest, opts ...grpc.CallOption) (*DeleteSnapshotResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(DeleteSnapshotResponse) + err := c.cc.Invoke(ctx, VMDaemon_DeleteSnapshot_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *vMDaemonClient) ExecCommand(ctx context.Context, in *ExecCommandRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExecCommandResponse], error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) stream, err := c.cc.NewStream(ctx, &VMDaemon_ServiceDesc.Streams[0], VMDaemon_ExecCommand_FullMethodName, cOpts...) @@ -172,38 +182,6 @@ func (c *vMDaemonClient) SetupNetwork(ctx context.Context, in *SetupNetworkReque return out, nil } -func (c *vMDaemonClient) UploadFile(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[UploadFileRequest, UploadFileResponse], error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - stream, err := c.cc.NewStream(ctx, &VMDaemon_ServiceDesc.Streams[1], VMDaemon_UploadFile_FullMethodName, cOpts...) - if err != nil { - return nil, err - } - x := &grpc.GenericClientStream[UploadFileRequest, UploadFileResponse]{ClientStream: stream} - return x, nil -} - -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type VMDaemon_UploadFileClient = grpc.ClientStreamingClient[UploadFileRequest, UploadFileResponse] - -func (c *vMDaemonClient) DownloadFile(ctx context.Context, in *DownloadFileRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[DownloadFileChunk], error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - stream, err := c.cc.NewStream(ctx, &VMDaemon_ServiceDesc.Streams[2], VMDaemon_DownloadFile_FullMethodName, cOpts...) - if err != nil { - return nil, err - } - x := &grpc.GenericClientStream[DownloadFileRequest, DownloadFileChunk]{ClientStream: stream} - if err := x.ClientStream.SendMsg(in); err != nil { - return nil, err - } - if err := x.ClientStream.CloseSend(); err != nil { - return nil, err - } - return x, nil -} - -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type VMDaemon_DownloadFileClient = grpc.ServerStreamingClient[DownloadFileChunk] - func (c *vMDaemonClient) UpdateSandboxNetwork(ctx context.Context, in *UpdateSandboxNetworkRequest, opts ...grpc.CallOption) (*UpdateSandboxNetworkResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(UpdateSandboxNetworkResponse) @@ -232,16 +210,17 @@ type VMDaemonServer interface { CreateSnapshot(context.Context, *CreateSnapshotRequest) (*CreateSnapshotResponse, error) // RestoreSnapshot boots a VM from a previously captured snapshot. RestoreSnapshot(context.Context, *RestoreSnapshotRequest) (*RestoreSnapshotResponse, error) + // DeleteSnapshot removes a snapshot's on-disk artifacts (vmstate + memory + // file). Idempotent — missing files are treated as success. Used by the + // control plane to garbage-collect the previous snapshot when a sandbox is + // re-paused and only the latest snapshot is reachable. + DeleteSnapshot(context.Context, *DeleteSnapshotRequest) (*DeleteSnapshotResponse, error) // ExecCommand runs a command inside a VM and streams stdout/stderr back. ExecCommand(*ExecCommandRequest, grpc.ServerStreamingServer[ExecCommandResponse]) error // GetVMInfo returns the current status and metadata of a VM. GetVMInfo(context.Context, *GetVMInfoRequest) (*GetVMInfoResponse, error) // SetupNetwork configures networking for a VM (TAP device, NAT, IP allocation). SetupNetwork(context.Context, *SetupNetworkRequest) (*SetupNetworkResponse, error) - // UploadFile streams file content into a VM. - UploadFile(grpc.ClientStreamingServer[UploadFileRequest, UploadFileResponse]) error - // DownloadFile streams file content out of a VM. - DownloadFile(*DownloadFileRequest, grpc.ServerStreamingServer[DownloadFileChunk]) error // UpdateSandboxNetwork atomically replaces the egress allow/deny rules for a running VM. UpdateSandboxNetwork(context.Context, *UpdateSandboxNetworkRequest) (*UpdateSandboxNetworkResponse, error) mustEmbedUnimplementedVMDaemonServer() @@ -272,6 +251,9 @@ func (UnimplementedVMDaemonServer) CreateSnapshot(context.Context, *CreateSnapsh func (UnimplementedVMDaemonServer) RestoreSnapshot(context.Context, *RestoreSnapshotRequest) (*RestoreSnapshotResponse, error) { return nil, status.Error(codes.Unimplemented, "method RestoreSnapshot not implemented") } +func (UnimplementedVMDaemonServer) DeleteSnapshot(context.Context, *DeleteSnapshotRequest) (*DeleteSnapshotResponse, error) { + return nil, status.Error(codes.Unimplemented, "method DeleteSnapshot not implemented") +} func (UnimplementedVMDaemonServer) ExecCommand(*ExecCommandRequest, grpc.ServerStreamingServer[ExecCommandResponse]) error { return status.Error(codes.Unimplemented, "method ExecCommand not implemented") } @@ -281,12 +263,6 @@ func (UnimplementedVMDaemonServer) GetVMInfo(context.Context, *GetVMInfoRequest) func (UnimplementedVMDaemonServer) SetupNetwork(context.Context, *SetupNetworkRequest) (*SetupNetworkResponse, error) { return nil, status.Error(codes.Unimplemented, "method SetupNetwork not implemented") } -func (UnimplementedVMDaemonServer) UploadFile(grpc.ClientStreamingServer[UploadFileRequest, UploadFileResponse]) error { - return status.Error(codes.Unimplemented, "method UploadFile not implemented") -} -func (UnimplementedVMDaemonServer) DownloadFile(*DownloadFileRequest, grpc.ServerStreamingServer[DownloadFileChunk]) error { - return status.Error(codes.Unimplemented, "method DownloadFile not implemented") -} func (UnimplementedVMDaemonServer) UpdateSandboxNetwork(context.Context, *UpdateSandboxNetworkRequest) (*UpdateSandboxNetworkResponse, error) { return nil, status.Error(codes.Unimplemented, "method UpdateSandboxNetwork not implemented") } @@ -419,6 +395,24 @@ func _VMDaemon_RestoreSnapshot_Handler(srv interface{}, ctx context.Context, dec return interceptor(ctx, in, info, handler) } +func _VMDaemon_DeleteSnapshot_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DeleteSnapshotRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(VMDaemonServer).DeleteSnapshot(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: VMDaemon_DeleteSnapshot_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(VMDaemonServer).DeleteSnapshot(ctx, req.(*DeleteSnapshotRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _VMDaemon_ExecCommand_Handler(srv interface{}, stream grpc.ServerStream) error { m := new(ExecCommandRequest) if err := stream.RecvMsg(m); err != nil { @@ -466,24 +460,6 @@ func _VMDaemon_SetupNetwork_Handler(srv interface{}, ctx context.Context, dec fu return interceptor(ctx, in, info, handler) } -func _VMDaemon_UploadFile_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(VMDaemonServer).UploadFile(&grpc.GenericServerStream[UploadFileRequest, UploadFileResponse]{ServerStream: stream}) -} - -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type VMDaemon_UploadFileServer = grpc.ClientStreamingServer[UploadFileRequest, UploadFileResponse] - -func _VMDaemon_DownloadFile_Handler(srv interface{}, stream grpc.ServerStream) error { - m := new(DownloadFileRequest) - if err := stream.RecvMsg(m); err != nil { - return err - } - return srv.(VMDaemonServer).DownloadFile(m, &grpc.GenericServerStream[DownloadFileRequest, DownloadFileChunk]{ServerStream: stream}) -} - -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type VMDaemon_DownloadFileServer = grpc.ServerStreamingServer[DownloadFileChunk] - func _VMDaemon_UpdateSandboxNetwork_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(UpdateSandboxNetworkRequest) if err := dec(in); err != nil { @@ -533,6 +509,10 @@ var VMDaemon_ServiceDesc = grpc.ServiceDesc{ MethodName: "RestoreSnapshot", Handler: _VMDaemon_RestoreSnapshot_Handler, }, + { + MethodName: "DeleteSnapshot", + Handler: _VMDaemon_DeleteSnapshot_Handler, + }, { MethodName: "GetVMInfo", Handler: _VMDaemon_GetVMInfo_Handler, @@ -552,16 +532,6 @@ var VMDaemon_ServiceDesc = grpc.ServiceDesc{ Handler: _VMDaemon_ExecCommand_Handler, ServerStreams: true, }, - { - StreamName: "UploadFile", - Handler: _VMDaemon_UploadFile_Handler, - ClientStreams: true, - }, - { - StreamName: "DownloadFile", - Handler: _VMDaemon_DownloadFile_Handler, - ServerStreams: true, - }, }, Metadata: "proto/vmd.proto", } diff --git a/scripts/fc-cleanup b/scripts/fc-cleanup new file mode 100755 index 0000000..323d48f --- /dev/null +++ b/scripts/fc-cleanup @@ -0,0 +1,15 @@ +#!/bin/sh +# fc-cleanup — post-stop cleanup for a Firecracker VM unit. +# Called by ExecStopPost in firecracker@.service. +# +# Usage: fc-cleanup +# RUN_DIR is inherited from the unit's EnvironmentFile. + +set -eu + +SANDBOX_ID="$1" +SOCKET="${RUN_DIR}/${SANDBOX_ID}/firecracker.sock" + +rm -f "${SOCKET}" + +echo "fc-cleanup: cleaned up ${SANDBOX_ID}" diff --git a/scripts/setup-proxy-infra.sh b/scripts/setup-proxy-infra.sh index ce2fd87..ac04869 100755 --- a/scripts/setup-proxy-infra.sh +++ b/scripts/setup-proxy-infra.sh @@ -34,21 +34,22 @@ set -euo pipefail PROJECT="${GCP_PROJECT}" REGION="${ZONE%-*}" # strips zone suffix, e.g. us-central1-a → us-central1 WILDCARD_DOMAIN="*.${DOMAIN}" -PROXY_PORT=5007 +PROXY_PORT=5007 # Main listener: HTTPS-after-LB-termination, HTTP/1.1, WebSocket +PROXY_REDIRECT_PORT=5008 # Tiny listener: HTTP→HTTPS 301 redirect # Resource names IP_NAME="sandbox-proxy-ip" IG_NAME="sandbox-proxy-ig" # unmanaged instance group -HC_NAME="sandbox-proxy-hc" # health check -BACKEND_NAME="sandbox-proxy-backend" +HC_NAME="sandbox-proxy-hc" # HTTP health check on PROXY_PORT +HC_REDIRECT_NAME="sandbox-proxy-redirect-hc" # TCP health check on PROXY_REDIRECT_PORT +BACKEND_NAME="sandbox-proxy-backend" # main TCP backend → proxy:5007 +BACKEND_REDIRECT_NAME="sandbox-proxy-redirect-backend" # TCP backend → proxy:5008 CERT_MAP_NAME="sandbox-proxy-cert-map" CERT_MAP_ENTRY="sandbox-proxy-cert-entry" CERT_NAME="sandbox-proxy-cert" DNS_AUTH_NAME="sandbox-proxy-dns-auth" -URL_MAP_NAME="sandbox-proxy-url-map" -URL_MAP_HTTP_NAME="sandbox-proxy-url-map-http" # for HTTP→HTTPS redirect -HTTPS_PROXY_NAME="sandbox-proxy-https" -HTTP_PROXY_NAME="sandbox-proxy-http" +SSL_PROXY_NAME="sandbox-proxy-ssl" # target SSL proxy (terminates TLS at LB) +TCP_PROXY_NAME="sandbox-proxy-tcp" # target TCP proxy (port 80 → redirect listener) FWD_RULE_HTTPS="sandbox-proxy-https-fwd" FWD_RULE_HTTP="sandbox-proxy-http-fwd" @@ -86,9 +87,11 @@ else --project="${PROJECT}" fi -# Define the named port so the backend service knows which port to target. +# Define the named ports so backend services know which ports to target. +# - "proxy" → main HTTP listener (TLS terminates at the LB, plain HTTP here) +# - "proxy-redirect" → tiny HTTP listener that 301-redirects to https:// gcloud compute instance-groups unmanaged set-named-ports "${IG_NAME}" \ - --named-ports="proxy:${PROXY_PORT}" \ + --named-ports="proxy:${PROXY_PORT},proxy-redirect:${PROXY_REDIRECT_PORT}" \ --zone="${ZONE}" \ --project="${PROJECT}" @@ -104,30 +107,33 @@ FW_LB_NAME="allow-sandbox-proxy-lb" if ! gcloud compute firewall-rules describe "${FW_HC_NAME}" --project="${PROJECT}" &>/dev/null; then gcloud compute firewall-rules create "${FW_HC_NAME}" \ --network="${NETWORK}" \ - --allow="tcp:${PROXY_PORT}" \ + --allow="tcp:${PROXY_PORT},tcp:${PROXY_REDIRECT_PORT}" \ --source-ranges="35.191.0.0/16,130.211.0.0/22" \ --target-tags="${INSTANCE_TAG:-vmd}" \ - --description="Allow GCP LB health check probes to edge proxy" \ + --description="Allow GCP LB health check probes to edge proxy (main + redirect ports)" \ --project="${PROJECT}" fi if ! gcloud compute firewall-rules describe "${FW_LB_NAME}" --project="${PROJECT}" &>/dev/null; then gcloud compute firewall-rules create "${FW_LB_NAME}" \ --network="${NETWORK}" \ - --allow="tcp:${PROXY_PORT}" \ + --allow="tcp:${PROXY_PORT},tcp:${PROXY_REDIRECT_PORT}" \ --source-ranges="130.211.0.0/22,35.191.0.0/16" \ --target-tags="${INSTANCE_TAG:-vmd}" \ - --description="Allow GCP LB backend traffic to edge proxy" \ + --description="Allow GCP LB backend traffic to edge proxy (main + redirect ports)" \ --project="${PROJECT}" fi # --------------------------------------------------------------------------- -# 4. HTTP health check on /health +# 4. Health checks +# - Main backend uses an HTTP health check on /health +# - Redirect backend uses a TCP health check (the redirect listener has +# no /health endpoint, only a 301 handler) # --------------------------------------------------------------------------- echo "" -echo "==> [4/9] Creating health check..." +echo "==> [4/9] Creating health checks..." if gcloud compute health-checks describe "${HC_NAME}" --global --project="${PROJECT}" &>/dev/null; then - echo " already exists, skipping" + echo " ${HC_NAME} already exists, skipping" else gcloud compute health-checks create http "${HC_NAME}" \ --global \ @@ -140,23 +146,85 @@ else --project="${PROJECT}" fi +if gcloud compute health-checks describe "${HC_REDIRECT_NAME}" --global --project="${PROJECT}" &>/dev/null; then + echo " ${HC_REDIRECT_NAME} already exists, skipping" +else + gcloud compute health-checks create tcp "${HC_REDIRECT_NAME}" \ + --global \ + --port-name=proxy-redirect \ + --check-interval=10 \ + --timeout=5 \ + --healthy-threshold=2 \ + --unhealthy-threshold=3 \ + --project="${PROJECT}" +fi + # --------------------------------------------------------------------------- -# 5. Backend service -# Timeout 630s > proxy server idle (620s) > proxy transport idle (610s) > GCP LB upstream (600s) +# 5. Backend services +# +# We use TWO backend services and TWO load balancers in front of the proxy: +# +# (a) Main backend (TCP protocol, port-name=proxy) — fronted by an SSL +# Proxy Network LB on port 443. The LB terminates TLS using the +# Certificate Manager wildcard cert and forwards plain TCP to the +# proxy on port 5007. The proxy serves plain HTTP — TLS termination +# is at the LB. +# +# Critically, SSL Proxy LB does NOT advertise HTTP/2 in TLS ALPN, so +# browsers fall back to HTTP/1.1. This is the *whole reason* we use +# SSL Proxy LB instead of the Application LB: GCP's Application LB +# advertises h2 in ALPN and then strips the WebSocket Upgrade headers +# during HTTP/2→HTTP/1.1 translation, breaking every browser-based +# WebSocket upgrade. Confirmed empirically; do not switch back. +# +# (b) Redirect backend (TCP protocol, port-name=proxy-redirect) — fronted +# by a TCP Proxy LB on port 80. Plain TCP forwarding to the proxy's +# tiny HTTP-only redirect listener on port 5008, which serves a 301 +# to the same URL on https://. Lives on the same instance group, just +# a different named port. +# +# Long timeouts on the main backend so streaming WebSocket connections +# (terminal sessions, exec streams) survive idle periods. The 86400s +# (24h) value matches the maximum the proxy itself will allow before +# its idle timer kicks in. # --------------------------------------------------------------------------- echo "" -echo "==> [5/9] Creating backend service..." +echo "==> [5/9] Creating backend services..." if gcloud compute backend-services describe "${BACKEND_NAME}" --global --project="${PROJECT}" &>/dev/null; then - echo " already exists, skipping" + echo " ${BACKEND_NAME} already exists, skipping" else gcloud compute backend-services create "${BACKEND_NAME}" \ --global \ - --protocol=HTTP \ + --load-balancing-scheme=EXTERNAL_MANAGED \ + --protocol=TCP \ --port-name=proxy \ --health-checks="${HC_NAME}" \ - --timeout=630 \ + --timeout=86400 \ + --connection-draining-timeout=300 \ + --enable-logging \ + --logging-sample-rate=1.0 \ --project="${PROJECT}" gcloud compute backend-services add-backend "${BACKEND_NAME}" \ + --global \ + --instance-group="${IG_NAME}" \ + --instance-group-zone="${ZONE}" \ + --balancing-mode=UTILIZATION \ + --max-utilization=0.8 \ + --project="${PROJECT}" +fi + +if gcloud compute backend-services describe "${BACKEND_REDIRECT_NAME}" --global --project="${PROJECT}" &>/dev/null; then + echo " ${BACKEND_REDIRECT_NAME} already exists, skipping" +else + gcloud compute backend-services create "${BACKEND_REDIRECT_NAME}" \ + --global \ + --load-balancing-scheme=EXTERNAL_MANAGED \ + --protocol=TCP \ + --port-name=proxy-redirect \ + --health-checks="${HC_REDIRECT_NAME}" \ + --timeout=30 \ + --project="${PROJECT}" + gcloud compute backend-services add-backend "${BACKEND_REDIRECT_NAME}" \ --global \ --instance-group="${IG_NAME}" \ --instance-group-zone="${ZONE}" \ @@ -204,70 +272,58 @@ if ! gcloud certificate-manager maps entries describe "${CERT_MAP_ENTRY}" \ fi # --------------------------------------------------------------------------- -# 7. URL maps +# 7. Target proxies +# +# - target-ssl-proxy → SSL Proxy LB on port 443. Terminates TLS at the LB +# using the Certificate Manager wildcard cert. Forwards plain TCP to the +# main backend. Does NOT speak HTTP, does NOT advertise h2 in ALPN. +# - target-tcp-proxy → TCP Proxy LB on port 80. Plain TCP forwarding to +# the redirect backend (which serves a 301). +# +# We deliberately do NOT use target-https-proxy / Application LB here +# because GCP's Application LB strips the WebSocket Upgrade headers when +# translating HTTP/2 client connections to HTTP/1.1 backend connections, +# breaking every browser-based WebSocket upgrade. Both classic and +# modern (EXTERNAL_MANAGED) Application LBs have this bug. Confirmed +# empirically. Do not switch back without re-verifying. # --------------------------------------------------------------------------- echo "" -echo "==> [7/9] Creating URL maps..." +echo "==> [7/9] Creating target proxies..." -# HTTPS: route everything to the backend -if ! gcloud compute url-maps describe "${URL_MAP_NAME}" --global --project="${PROJECT}" &>/dev/null; then - gcloud compute url-maps create "${URL_MAP_NAME}" \ - --default-service="${BACKEND_NAME}" \ - --global \ +if ! gcloud compute target-ssl-proxies describe "${SSL_PROXY_NAME}" --project="${PROJECT}" &>/dev/null; then + gcloud compute target-ssl-proxies create "${SSL_PROXY_NAME}" \ + --backend-service="${BACKEND_NAME}" \ + --certificate-map="${CERT_MAP_NAME}" \ --project="${PROJECT}" fi -# HTTP: redirect all traffic to HTTPS -if ! gcloud compute url-maps describe "${URL_MAP_HTTP_NAME}" --global --project="${PROJECT}" &>/dev/null; then - gcloud compute url-maps import "${URL_MAP_HTTP_NAME}" \ - --global \ - --project="${PROJECT}" \ - --source=/dev/stdin <<'YAML' -name: sandbox-proxy-url-map-http -defaultUrlRedirect: - redirectResponseCode: MOVED_PERMANENTLY_DEFAULT - httpsRedirect: true -YAML +if ! gcloud compute target-tcp-proxies describe "${TCP_PROXY_NAME}" --project="${PROJECT}" &>/dev/null; then + gcloud compute target-tcp-proxies create "${TCP_PROXY_NAME}" \ + --backend-service="${BACKEND_REDIRECT_NAME}" \ + --project="${PROJECT}" fi # --------------------------------------------------------------------------- -# 8. Target proxies and forwarding rules +# 8. Forwarding rules (both on the same global static IP) # --------------------------------------------------------------------------- echo "" -echo "==> [8/9] Creating target proxies and forwarding rules..." - -# HTTPS target proxy — references the certificate map -if ! gcloud compute target-https-proxies describe "${HTTPS_PROXY_NAME}" --global --project="${PROJECT}" &>/dev/null; then - gcloud compute target-https-proxies create "${HTTPS_PROXY_NAME}" \ - --url-map="${URL_MAP_NAME}" \ - --certificate-map="${CERT_MAP_NAME}" \ - --global \ - --project="${PROJECT}" -fi - -# HTTP target proxy (for redirect) -if ! gcloud compute target-http-proxies describe "${HTTP_PROXY_NAME}" --global --project="${PROJECT}" &>/dev/null; then - gcloud compute target-http-proxies create "${HTTP_PROXY_NAME}" \ - --url-map="${URL_MAP_HTTP_NAME}" \ - --global \ - --project="${PROJECT}" -fi +echo "==> [8/9] Creating forwarding rules..." -# HTTPS forwarding rule if ! gcloud compute forwarding-rules describe "${FWD_RULE_HTTPS}" --global --project="${PROJECT}" &>/dev/null; then gcloud compute forwarding-rules create "${FWD_RULE_HTTPS}" \ --global \ - --target-https-proxy="${HTTPS_PROXY_NAME}" \ + --load-balancing-scheme=EXTERNAL_MANAGED \ + --target-ssl-proxy="${SSL_PROXY_NAME}" \ --address="${IP_NAME}" \ --ports=443 \ --project="${PROJECT}" fi -# HTTP forwarding rule (redirect to HTTPS) if ! gcloud compute forwarding-rules describe "${FWD_RULE_HTTP}" --global --project="${PROJECT}" &>/dev/null; then gcloud compute forwarding-rules create "${FWD_RULE_HTTP}" \ --global \ - --target-http-proxy="${HTTP_PROXY_NAME}" \ + --load-balancing-scheme=EXTERNAL_MANAGED \ + --target-tcp-proxy="${TCP_PROXY_NAME}" \ --address="${IP_NAME}" \ --ports=80 \ --project="${PROJECT}" diff --git a/supabase/migrations/20260410000001_host_table.sql b/supabase/migrations/20260410000001_host_table.sql new file mode 100644 index 0000000..8f286fc --- /dev/null +++ b/supabase/migrations/20260410000001_host_table.sql @@ -0,0 +1,27 @@ +-- Host table: models bare-metal machines running VMD. One row per host. +-- Multi-host-ready from day one — adding a second host is an ops task +-- (insert a row, deploy VMD), not an engineering project. + +CREATE TABLE host ( + id text PRIMARY KEY, + vmd_addr text NOT NULL, + proxy_addr text NOT NULL, + region text NOT NULL, + status text NOT NULL DEFAULT 'active', + capacity_memory_mib int NOT NULL, + capacity_vcpus int NOT NULL, + last_heartbeat_at timestamptz, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + + CONSTRAINT host_status_valid CHECK (status IN ('active', 'draining', 'unhealthy')), + CONSTRAINT host_capacity_memory_positive CHECK (capacity_memory_mib > 0), + CONSTRAINT host_capacity_vcpus_positive CHECK (capacity_vcpus > 0) +); + +ALTER TABLE public.host ENABLE ROW LEVEL SECURITY; + +-- Backfill any sandbox rows with NULL host_id so we can add NOT NULL. +UPDATE sandbox SET host_id = 'default' WHERE host_id IS NULL; + +ALTER TABLE sandbox ALTER COLUMN host_id SET NOT NULL; diff --git a/supabase/migrations/20260410000002_reconciler_log.sql b/supabase/migrations/20260410000002_reconciler_log.sql new file mode 100644 index 0000000..023bfa7 --- /dev/null +++ b/supabase/migrations/20260410000002_reconciler_log.sql @@ -0,0 +1,24 @@ +-- Audit log for every reconciler action. On-call reads this table first +-- when diagnosing "my sandbox died at 3am" — no silent actions. + +CREATE TABLE reconciler_log ( + id bigserial PRIMARY KEY, + host_id text NOT NULL, + sandbox_id uuid, + action text NOT NULL, + reason text NOT NULL, + drift_kind text, + created_at timestamptz NOT NULL DEFAULT now(), + + CONSTRAINT reconciler_log_action_valid CHECK (action IN ( + 'mark_failed', + 'orphan_stop', + 'stale_cleanup', + 'budget_exhausted' + )) +); + +CREATE INDEX idx_reconciler_log_host ON reconciler_log(host_id, created_at DESC); +CREATE INDEX idx_reconciler_log_sandbox ON reconciler_log(sandbox_id, created_at DESC) WHERE sandbox_id IS NOT NULL; + +ALTER TABLE public.reconciler_log ENABLE ROW LEVEL SECURITY; diff --git a/supabase/migrations/20260414000001_snapshot_mem_path.sql b/supabase/migrations/20260414000001_snapshot_mem_path.sql new file mode 100644 index 0000000..2180641 --- /dev/null +++ b/supabase/migrations/20260414000001_snapshot_mem_path.sql @@ -0,0 +1,3 @@ +-- Add mem_path column to snapshot table. Previously derived by convention +-- from filepath.Dir(path) + "mem.snap", which is fragile. +ALTER TABLE snapshot ADD COLUMN mem_path text; diff --git a/supabase/migrations/20260414000002_drop_last_activity.sql b/supabase/migrations/20260414000002_drop_last_activity.sql new file mode 100644 index 0000000..92d3ebd --- /dev/null +++ b/supabase/migrations/20260414000002_drop_last_activity.sql @@ -0,0 +1,6 @@ +-- Drop last_activity_at column and its index. The column was used by the +-- (now removed) auto-wake middleware and idle-sandbox listing — with those +-- gone, nothing reads the column. + +DROP INDEX IF EXISTS idx_sandbox_last_activity; +ALTER TABLE sandbox DROP COLUMN IF EXISTS last_activity_at; diff --git a/supabase/migrations/20260414000003_rename_idle_to_paused.sql b/supabase/migrations/20260414000003_rename_idle_to_paused.sql new file mode 100644 index 0000000..338cd6f --- /dev/null +++ b/supabase/migrations/20260414000003_rename_idle_to_paused.sql @@ -0,0 +1,6 @@ +-- Rename the 'idle' sandbox status to 'paused'. The old name was a holdover +-- from an earlier design where idle-detection was meant to drive auto-pause; +-- the state itself just means "VM stopped, memory+disk snapshotted". Align +-- the name with the actual semantics and with the /pause, /resume endpoints. + +ALTER TYPE sandbox_status RENAME VALUE 'idle' TO 'paused';