diff --git a/backend/app/api/api.py b/backend/app/api/api.py index d37ff8d1a..220f1e386 100644 --- a/backend/app/api/api.py +++ b/backend/app/api/api.py @@ -56,6 +56,9 @@ from app.api.endpoints.internal import ( callback_router, chat_storage_router, +) +from app.api.endpoints.internal import devices_router as internal_devices_router +from app.api.endpoints.internal import ( services_router, skills_router, subscriptions_router, @@ -175,6 +178,9 @@ api_router.include_router(skills_router, prefix="/internal", tags=["internal-skills"]) api_router.include_router(tables_router, prefix="/internal", tags=["internal-tables"]) +api_router.include_router( + internal_devices_router, prefix="/internal", tags=["internal-devices"] +) api_router.include_router( internal_bots_router, prefix="/internal", tags=["internal-bots"] ) diff --git a/backend/app/api/endpoints/devices.py b/backend/app/api/endpoints/devices.py index 61c5cde31..6f48c3d5f 100644 --- a/backend/app/api/endpoints/devices.py +++ b/backend/app/api/endpoints/devices.py @@ -61,6 +61,42 @@ class DeviceUpgradeResponse(BaseModel): message: str = Field(..., description="Human-readable status message") +class DeviceSandboxExecRequest(BaseModel): + """Request model for executing a command on a user device.""" + + command: str = Field(..., min_length=1, description="Command to execute") + working_dir: str = Field( + default="/home/user", + description="Working directory for command execution", + ) + timeout_seconds: int = Field( + default=300, + ge=1, + le=1800, + description="Command timeout in seconds", + ) + required_capability: Optional[str] = Field( + default=None, + description="Optional device capability required for routing", + ) + device_id: Optional[str] = Field( + default=None, + description="Optional explicit device ID override", + ) + + +class DeviceSandboxExecResponse(BaseModel): + """Response model for a device-backed command execution.""" + + success: bool = Field(..., description="Whether the command succeeded") + stdout: str = Field(default="", description="Standard output") + stderr: str = Field(default="", description="Standard error") + exit_code: int = Field(..., description="Process exit code") + execution_time: float = Field(..., description="Execution time in seconds") + device_id: str = Field(..., description="Device that executed the command") + backend: str = Field(default="device", description="Execution backend identifier") + + @router.get("", response_model=DeviceListResponse) async def get_all_devices( db: Session = Depends(get_db), @@ -162,6 +198,47 @@ async def delete_device( return {"message": f"Device '{device_id}' deleted"} +@router.post("/sandbox/exec", response_model=DeviceSandboxExecResponse) +async def execute_device_sandbox_command( + request: DeviceSandboxExecRequest, + db: Session = Depends(get_db), + current_user: User = Depends(security.get_current_user), +) -> DeviceSandboxExecResponse: + """ + Execute a command on an online user device through the existing device channel. + + The backend selects a compatible online device, forwards the command over + `/local-executor`, and returns the device's execution result. 
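+
+    Raises an HTTP 400 error when no compatible online device is available
+    or when the device dispatch fails.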
+ """ + from app.services.device_sandbox_service import ( + DeviceSandboxError, + device_sandbox_service, + ) + + try: + result = await device_sandbox_service.execute_command( + db=db, + user_id=current_user.id, + command=request.command, + working_dir=request.working_dir, + timeout_seconds=request.timeout_seconds, + required_capability=request.required_capability, + device_id=request.device_id, + ) + except DeviceSandboxError as exc: + logger.warning( + "[Device Sandbox] Command rejected: user_id=%s, error=%s", + current_user.id, + exc, + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(exc), + ) from exc + + return DeviceSandboxExecResponse(**result) + + @router.post("/{device_id}/upgrade", response_model=DeviceUpgradeResponse) async def trigger_device_upgrade( device_id: str, diff --git a/backend/app/api/endpoints/internal/__init__.py b/backend/app/api/endpoints/internal/__init__.py index 616bf619d..110d0570b 100644 --- a/backend/app/api/endpoints/internal/__init__.py +++ b/backend/app/api/endpoints/internal/__init__.py @@ -9,6 +9,7 @@ from .bots import router as bots_router from .callback import router as callback_router from .chat_storage import router as chat_storage_router +from .devices import router as devices_router from .services import router as services_router from .skills import router as skills_router from .subscriptions import router as subscriptions_router @@ -23,6 +24,7 @@ "bots_router", "callback_router", "chat_storage_router", + "devices_router", "services_router", "skills_router", "subscriptions_router", diff --git a/backend/app/api/endpoints/internal/devices.py b/backend/app/api/endpoints/internal/devices.py new file mode 100644 index 000000000..25b303472 --- /dev/null +++ b/backend/app/api/endpoints/internal/devices.py @@ -0,0 +1,377 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. 
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""Internal device APIs for service-to-service communication."""
+
+from typing import Any
+
+from fastapi import APIRouter, Depends, HTTPException, status
+from pydantic import BaseModel, Field
+from sqlalchemy.orm import Session
+
+from app.api.dependencies import get_db
+from app.api.endpoints.devices import DeviceSandboxExecResponse
+
+router = APIRouter(prefix="/devices", tags=["internal-devices"])
+
+
+class InternalDeviceSandboxExecRequest(BaseModel):
+    """Internal request model for device-backed command execution."""
+
+    task_id: int | None = Field(default=None, ge=1, description="Optional task ID")
+    user_id: int = Field(..., ge=1, description="Owner user ID")
+    command: str = Field(..., min_length=1, description="Command to execute")
+    working_dir: str = Field(
+        default="/home/user",
+        description="Working directory for command execution",
+    )
+    timeout_seconds: int = Field(
+        default=300,
+        ge=1,
+        le=1800,
+        description="Command timeout in seconds",
+    )
+    required_capability: str | None = Field(
+        default=None,
+        description="Optional device capability required for routing",
+    )
+    device_id: str | None = Field(
+        default=None,
+        description="Optional explicit device ID override",
+    )
+
+
+class InternalTaskSandboxBindingResponse(BaseModel):
+    """Current sticky sandbox binding for a task."""
+
+    backend: str | None = None
+    device_id: str | None = None
+
+
+class InternalDeviceSandboxReadFileRequest(BaseModel):
+    """Internal request model for device-backed file reads."""
+
+    task_id: int | None = Field(default=None, ge=1, description="Optional task ID")
+    user_id: int = Field(..., ge=1, description="Owner user ID")
+    file_path: str = Field(..., min_length=1, description="Path to read")
+    format: str = Field(default="text", description="Read format: text or bytes")
+    device_id: str | None = Field(
+        default=None, description="Optional explicit device ID override"
+    )
+
+
+class InternalDeviceSandboxListFilesRequest(BaseModel):
+    """Internal request model for device-backed file listing."""
+
+    task_id: int | None = Field(default=None, ge=1, description="Optional task ID")
+    user_id: int = Field(..., ge=1, description="Owner user ID")
+    path: str = Field(default="/home/user", description="Directory to list")
+    depth: int = Field(default=1, ge=1, le=10, description="Listing depth")
+    device_id: str | None = Field(
+        default=None, description="Optional explicit device ID override"
+    )
+
+
+class InternalDeviceSandboxWriteFileRequest(BaseModel):
+    """Internal request model for device-backed file writes."""
+
+    task_id: int | None = Field(default=None, ge=1, description="Optional task ID")
+    user_id: int = Field(..., ge=1, description="Owner user ID")
+    file_path: str = Field(..., min_length=1, description="Path to write")
+    content: str = Field(..., description="File content")
+    format: str = Field(default="text", description="Write format: text or bytes")
+    create_dirs: bool = Field(
+        default=True, description="Create parent directories automatically"
+    )
+    device_id: str | None = Field(
+        default=None, description="Optional explicit device ID override"
+    )
+
+
+class InternalDeviceSandboxDownloadAttachmentRequest(BaseModel):
+    """Internal request model for device-backed attachment downloads."""
+
+    task_id: int | None = Field(default=None, ge=1, description="Optional task ID")
+    user_id: int = Field(..., ge=1, description="Owner user ID")
+    attachment_url: str = Field(
+        ..., min_length=1, description="Attachment download URL"
+ ) + save_path: str = Field(..., min_length=1, description="Destination path on device") + auth_token: str = Field(..., min_length=1, description="Task or user auth token") + api_base_url: str = Field(..., min_length=1, description="Backend base URL") + timeout_seconds: int = Field( + default=300, ge=1, le=1800, description="Download timeout in seconds" + ) + device_id: str | None = Field( + default=None, description="Optional explicit device ID override" + ) + + +class InternalDeviceSandboxUploadAttachmentRequest(BaseModel): + """Internal request model for device-backed attachment uploads.""" + + task_id: int | None = Field(default=None, ge=1, description="Optional task ID") + user_id: int = Field(..., ge=1, description="Owner user ID") + file_path: str = Field(..., min_length=1, description="Path to the local file") + auth_token: str = Field(..., min_length=1, description="Task or user auth token") + api_base_url: str = Field(..., min_length=1, description="Backend base URL") + overwrite_attachment_id: int | None = Field( + default=None, ge=1, description="Optional attachment to overwrite" + ) + timeout_seconds: int = Field( + default=300, ge=1, le=1800, description="Upload timeout in seconds" + ) + device_id: str | None = Field( + default=None, description="Optional explicit device ID override" + ) + + +class DeviceSandboxGenericResponse(BaseModel): + """Generic response model for device-backed sandbox file helpers.""" + + success: bool = Field(..., description="Whether the operation succeeded") + device_id: str = Field(..., description="Device that executed the operation") + backend: str = Field(default="device", description="Execution backend identifier") + execution_time: float = Field(..., description="Execution time in seconds") + data: dict[str, Any] = Field( + default_factory=dict, description="Operation-specific payload" + ) + + +def _build_generic_response(result: dict[str, Any]) -> DeviceSandboxGenericResponse: + """Convert a device sandbox service response into the generic endpoint model.""" + payload = dict(result) + return DeviceSandboxGenericResponse( + success=bool(payload.pop("success", False)), + device_id=str(payload.pop("device_id")), + backend=str(payload.pop("backend", "device")), + execution_time=float(payload.pop("execution_time", 0.0)), + data=payload, + ) + + +@router.post("/sandbox/exec", response_model=DeviceSandboxExecResponse) +async def execute_device_sandbox_command_internal( + request: InternalDeviceSandboxExecRequest, + db: Session = Depends(get_db), +) -> DeviceSandboxExecResponse: + """Execute a command on a user's device for internal trusted services.""" + from app.services.device_sandbox_service import ( + DeviceSandboxError, + device_sandbox_service, + ) + + try: + result = await device_sandbox_service.execute_command( + db=db, + task_id=request.task_id, + user_id=request.user_id, + command=request.command, + working_dir=request.working_dir, + timeout_seconds=request.timeout_seconds, + required_capability=request.required_capability, + device_id=request.device_id, + ) + except DeviceSandboxError as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(exc), + ) from exc + + db.commit() + return DeviceSandboxExecResponse(**result) + + +@router.get( + "/sandbox/binding/{task_id}", response_model=InternalTaskSandboxBindingResponse +) +async def get_task_sandbox_binding( + task_id: int, + user_id: int, + db: Session = Depends(get_db), +) -> InternalTaskSandboxBindingResponse: + """Return the sticky sandbox binding 
currently stored on the task.""" + from app.models.task import TaskResource + from app.services.device_sandbox_service import ( + DEVICE_BACKEND_NAME, + SANDBOX_BACKEND_LABEL, + SANDBOX_DEVICE_ID_LABEL, + ) + + task = ( + db.query(TaskResource) + .filter( + TaskResource.id == task_id, + TaskResource.user_id == user_id, + TaskResource.kind == "Task", + TaskResource.is_active == TaskResource.STATE_ACTIVE, + ) + .first() + ) + if not task: + return InternalTaskSandboxBindingResponse() + + task_json = task.json if isinstance(task.json, dict) else {} + labels = task_json.get("metadata", {}).get("labels", {}) + backend = labels.get(SANDBOX_BACKEND_LABEL) + device_id = labels.get(SANDBOX_DEVICE_ID_LABEL) + if backend != DEVICE_BACKEND_NAME or not device_id: + return InternalTaskSandboxBindingResponse() + + return InternalTaskSandboxBindingResponse(backend=backend, device_id=device_id) + + +@router.post("/sandbox/read-file", response_model=DeviceSandboxGenericResponse) +async def read_device_sandbox_file_internal( + request: InternalDeviceSandboxReadFileRequest, + db: Session = Depends(get_db), +) -> DeviceSandboxGenericResponse: + """Read a file from the bound device-backed sandbox.""" + from app.services.device_sandbox_service import ( + DeviceSandboxError, + device_sandbox_service, + ) + + try: + result = await device_sandbox_service.read_file( + db=db, + user_id=request.user_id, + task_id=request.task_id, + file_path=request.file_path, + format=request.format, + device_id=request.device_id, + ) + except DeviceSandboxError as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc) + ) from exc + + db.commit() + return _build_generic_response(result) + + +@router.post("/sandbox/list-files", response_model=DeviceSandboxGenericResponse) +async def list_device_sandbox_files_internal( + request: InternalDeviceSandboxListFilesRequest, + db: Session = Depends(get_db), +) -> DeviceSandboxGenericResponse: + """List files from the bound device-backed sandbox.""" + from app.services.device_sandbox_service import ( + DeviceSandboxError, + device_sandbox_service, + ) + + try: + result = await device_sandbox_service.list_files( + db=db, + user_id=request.user_id, + task_id=request.task_id, + path=request.path, + depth=request.depth, + device_id=request.device_id, + ) + except DeviceSandboxError as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc) + ) from exc + + db.commit() + return _build_generic_response(result) + + +@router.post("/sandbox/write-file", response_model=DeviceSandboxGenericResponse) +async def write_device_sandbox_file_internal( + request: InternalDeviceSandboxWriteFileRequest, + db: Session = Depends(get_db), +) -> DeviceSandboxGenericResponse: + """Write a file into the bound device-backed sandbox.""" + from app.services.device_sandbox_service import ( + DeviceSandboxError, + device_sandbox_service, + ) + + try: + result = await device_sandbox_service.write_file( + db=db, + user_id=request.user_id, + task_id=request.task_id, + file_path=request.file_path, + content=request.content, + format=request.format, + create_dirs=request.create_dirs, + device_id=request.device_id, + ) + except DeviceSandboxError as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc) + ) from exc + + db.commit() + return _build_generic_response(result) + + +@router.post( + "/sandbox/download-attachment", response_model=DeviceSandboxGenericResponse +) +async def download_device_sandbox_attachment_internal( + request: 
InternalDeviceSandboxDownloadAttachmentRequest, + db: Session = Depends(get_db), +) -> DeviceSandboxGenericResponse: + """Download a Wegent attachment into the bound device-backed sandbox.""" + from app.services.device_sandbox_service import ( + DeviceSandboxError, + device_sandbox_service, + ) + + try: + result = await device_sandbox_service.download_attachment( + db=db, + user_id=request.user_id, + task_id=request.task_id, + attachment_url=request.attachment_url, + save_path=request.save_path, + auth_token=request.auth_token, + api_base_url=request.api_base_url, + timeout_seconds=request.timeout_seconds, + device_id=request.device_id, + ) + except DeviceSandboxError as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc) + ) from exc + + db.commit() + return _build_generic_response(result) + + +@router.post("/sandbox/upload-attachment", response_model=DeviceSandboxGenericResponse) +async def upload_device_sandbox_attachment_internal( + request: InternalDeviceSandboxUploadAttachmentRequest, + db: Session = Depends(get_db), +) -> DeviceSandboxGenericResponse: + """Upload a device-local file through the bound device-backed sandbox.""" + from app.services.device_sandbox_service import ( + DeviceSandboxError, + device_sandbox_service, + ) + + try: + result = await device_sandbox_service.upload_attachment( + db=db, + user_id=request.user_id, + task_id=request.task_id, + file_path=request.file_path, + auth_token=request.auth_token, + api_base_url=request.api_base_url, + overwrite_attachment_id=request.overwrite_attachment_id, + timeout_seconds=request.timeout_seconds, + device_id=request.device_id, + ) + except DeviceSandboxError as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc) + ) from exc + + db.commit() + return _build_generic_response(result) diff --git a/backend/app/api/ws/device_namespace.py b/backend/app/api/ws/device_namespace.py index efbcc2ae2..f3d4e26c5 100644 --- a/backend/app/api/ws/device_namespace.py +++ b/backend/app/api/ws/device_namespace.py @@ -165,6 +165,7 @@ def _register_device( client_ip: Optional[str] = None, device_type: Optional[str] = None, bind_shell: Optional[str] = None, + capabilities: Optional[list[str]] = None, ) -> tuple[bool, Optional[str]]: """ Register or update device CRD in database. @@ -189,6 +190,7 @@ def _register_device( client_ip=client_ip, device_type=device_type, bind_shell=bind_shell, + capabilities=capabilities, ) return True, None except Exception as e: @@ -759,6 +761,7 @@ async def on_device_register(self, sid: str, data: dict) -> dict: payload.client_ip, payload.device_type.value, payload.bind_shell.value, + payload.capabilities, ) if not success: return {"error": f"Registration failed: {error}"} diff --git a/backend/app/services/device_sandbox_service.py b/backend/app/services/device_sandbox_service.py new file mode 100644 index 000000000..91db31c74 --- /dev/null +++ b/backend/app/services/device_sandbox_service.py @@ -0,0 +1,465 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""Device-backed sandbox execution helpers.""" + +import logging +import time +from typing import Any, Optional + +from sqlalchemy.orm import Session +from sqlalchemy.orm.attributes import flag_modified + +from app.core.socketio import get_sio +from app.models.task import TaskResource +from app.schemas.device import DeviceType +from app.services.device_service import device_service + +logger = logging.getLogger(__name__) + +SANDBOX_BACKEND_LABEL = "sandboxBackend" +SANDBOX_DEVICE_ID_LABEL = "sandboxDeviceId" +DEVICE_BACKEND_NAME = "device" + + +class DeviceSandboxError(RuntimeError): + """Raised when a device-backed sandbox command cannot be executed.""" + + +class DeviceSandboxService: + """Service for forwarding sandbox commands to an online user device.""" + + async def execute_command( + self, + db: Session, + user_id: int, + command: str, + working_dir: str = "/home/user", + timeout_seconds: int = 300, + required_capability: Optional[str] = None, + device_id: Optional[str] = None, + task_id: Optional[int] = None, + ) -> dict[str, Any]: + """Execute a command on a user's online device via Socket.IO.""" + return await self._execute_device_event( + db=db, + user_id=user_id, + task_id=task_id, + event_name="sandbox:exec", + payload={ + "command": command, + "working_dir": working_dir, + "timeout_seconds": timeout_seconds, + }, + timeout_seconds=timeout_seconds, + required_capability=required_capability, + device_id=device_id, + ) + + async def read_file( + self, + db: Session, + user_id: int, + file_path: str, + format: str = "text", + device_id: Optional[str] = None, + task_id: Optional[int] = None, + ) -> dict[str, Any]: + """Read a file from the bound device.""" + return await self._execute_device_event( + db=db, + user_id=user_id, + task_id=task_id, + event_name="sandbox:read_file", + payload={"file_path": file_path, "format": format}, + timeout_seconds=60, + device_id=device_id, + ) + + async def list_files( + self, + db: Session, + user_id: int, + path: str = "/home/user", + depth: int = 1, + device_id: Optional[str] = None, + task_id: Optional[int] = None, + ) -> dict[str, Any]: + """List files from the bound device.""" + return await self._execute_device_event( + db=db, + user_id=user_id, + task_id=task_id, + event_name="sandbox:list_files", + payload={"path": path, "depth": depth}, + timeout_seconds=60, + device_id=device_id, + ) + + async def write_file( + self, + db: Session, + user_id: int, + file_path: str, + content: str, + format: str = "text", + create_dirs: bool = True, + device_id: Optional[str] = None, + task_id: Optional[int] = None, + ) -> dict[str, Any]: + """Write a file to the bound device.""" + return await self._execute_device_event( + db=db, + user_id=user_id, + task_id=task_id, + event_name="sandbox:write_file", + payload={ + "file_path": file_path, + "content": content, + "format": format, + "create_dirs": create_dirs, + }, + timeout_seconds=60, + device_id=device_id, + ) + + async def download_attachment( + self, + db: Session, + user_id: int, + attachment_url: str, + save_path: str, + auth_token: str, + api_base_url: str, + timeout_seconds: int = 300, + device_id: Optional[str] = None, + task_id: Optional[int] = None, + ) -> dict[str, Any]: + """Download a Wegent attachment to the bound device.""" + return await self._execute_device_event( + db=db, + user_id=user_id, + task_id=task_id, + event_name="sandbox:download_attachment", + payload={ + "attachment_url": attachment_url, + "save_path": save_path, + 
"auth_token": auth_token, + "api_base_url": api_base_url, + "timeout_seconds": timeout_seconds, + }, + timeout_seconds=timeout_seconds, + device_id=device_id, + ) + + async def upload_attachment( + self, + db: Session, + user_id: int, + file_path: str, + auth_token: str, + api_base_url: str, + overwrite_attachment_id: Optional[int] = None, + timeout_seconds: int = 300, + device_id: Optional[str] = None, + task_id: Optional[int] = None, + ) -> dict[str, Any]: + """Upload a device-local file back to Wegent attachments.""" + return await self._execute_device_event( + db=db, + user_id=user_id, + task_id=task_id, + event_name="sandbox:upload_attachment", + payload={ + "file_path": file_path, + "auth_token": auth_token, + "api_base_url": api_base_url, + "overwrite_attachment_id": overwrite_attachment_id, + "timeout_seconds": timeout_seconds, + }, + timeout_seconds=timeout_seconds, + device_id=device_id, + ) + + async def _execute_device_event( + self, + db: Session, + user_id: int, + task_id: Optional[int], + event_name: str, + payload: dict[str, Any], + timeout_seconds: int, + required_capability: Optional[str] = None, + device_id: Optional[str] = None, + ) -> dict[str, Any]: + """Select a device, persist sticky binding, and dispatch an event.""" + target_device = await self._select_target_device( + db=db, + user_id=user_id, + required_capability=required_capability, + device_id=device_id, + task_id=task_id, + ) + if target_device is None: + raise DeviceSandboxError("No compatible online device is available") + + target_device_id = target_device["device_id"] + socket_id = await self._get_target_socket_id(user_id, target_device_id) + + if task_id is not None: + self._persist_task_binding( + db=db, + user_id=user_id, + task_id=task_id, + device_id=target_device_id, + ) + + sio = get_sio() + started_at = time.monotonic() + + logger.info( + "[DeviceSandboxService] Forwarding event to device: user_id=%s, " + "device_id=%s, event=%s, timeout=%ss, required_capability=%s", + user_id, + target_device_id, + event_name, + timeout_seconds, + required_capability, + ) + + try: + response = await sio.call( + event_name, + payload, + to=socket_id, + namespace="/local-executor", + timeout=max(timeout_seconds + 5, 30), + ) + except Exception as exc: + logger.error( + "[DeviceSandboxService] Device event dispatch failed: user_id=%s, " + "device_id=%s, event=%s, error=%s", + user_id, + target_device_id, + event_name, + exc, + ) + raise DeviceSandboxError(f"Device event dispatch failed: {exc}") from exc + + if not isinstance(response, dict): + raise DeviceSandboxError("Device returned an invalid sandbox response") + + execution_time = response.get("execution_time") + if not isinstance(execution_time, (int, float)): + execution_time = time.monotonic() - started_at + + normalized = dict(response) + normalized.setdefault("success", False) + normalized["execution_time"] = execution_time + normalized["device_id"] = target_device_id + normalized["backend"] = DEVICE_BACKEND_NAME + + logger.info( + "[DeviceSandboxService] Device event completed: user_id=%s, device_id=%s, " + "socket_id=%s, event=%s, success=%s, execution_time=%.2fs", + user_id, + target_device_id, + socket_id, + event_name, + bool(normalized.get("success")), + execution_time, + ) + + return normalized + + async def _get_target_socket_id(self, user_id: int, target_device_id: str) -> str: + """Resolve the active socket ID for a selected device.""" + online_info = await device_service.get_device_online_info( + user_id, target_device_id + ) + if not 
online_info: + raise DeviceSandboxError(f"Device '{target_device_id}' is offline") + + socket_id = online_info.get("socket_id") + if not socket_id: + raise DeviceSandboxError( + f"Device '{target_device_id}' does not have an active socket session" + ) + return socket_id + + async def _select_target_device( + self, + db: Session, + user_id: int, + required_capability: Optional[str], + device_id: Optional[str], + task_id: Optional[int], + ) -> Optional[dict[str, Any]]: + """Pick an online device for sandbox execution.""" + online_devices = await device_service.get_online_devices(db, user_id) + if not online_devices: + return None + + bound_device_id = None + if task_id is not None: + bound_device_id = self._get_bound_task_device_id( + db=db, + user_id=user_id, + task_id=task_id, + ) + + resolved_device_id = bound_device_id or device_id + + if bound_device_id and not any( + device.get("device_id") == bound_device_id for device in online_devices + ): + raise DeviceSandboxError( + f"Bound sandbox device '{bound_device_id}' is offline" + ) + + compatible_devices = [ + device + for device in online_devices + if self._matches_device( + device=device, + required_capability=required_capability, + device_id=resolved_device_id, + ) + ] + if not compatible_devices: + if bound_device_id: + raise DeviceSandboxError( + f"Bound sandbox device '{bound_device_id}' is unavailable" + ) + return None + + def priority(device: dict[str, Any]) -> int: + device_type = device.get("device_type") + is_default = bool(device.get("is_default")) + if is_default and device_type == DeviceType.CLOUD.value: + return 0 + if is_default: + return 1 + if device_type == DeviceType.CLOUD.value: + return 2 + return 3 + + compatible_devices.sort(key=priority) + selected_device = compatible_devices[0] + logger.info( + "[DeviceSandboxService] Selected device: user_id=%s, device_id=%s, " + "device_name=%s, device_type=%s, is_default=%s, required_capability=%s, " + "requested_device_id=%s, bound_device_id=%s, task_id=%s, " + "capabilities=%s, compatible_candidates=%s", + user_id, + selected_device.get("device_id"), + selected_device.get("device_name"), + selected_device.get("device_type"), + selected_device.get("is_default"), + required_capability, + device_id, + bound_device_id, + task_id, + selected_device.get("capabilities") or [], + [ + { + "device_id": device.get("device_id"), + "device_name": device.get("device_name"), + "device_type": device.get("device_type"), + "is_default": device.get("is_default"), + } + for device in compatible_devices + ], + ) + return selected_device + + def _matches_device( + self, + device: dict[str, Any], + required_capability: Optional[str], + device_id: Optional[str], + ) -> bool: + """Check whether a device satisfies routing constraints.""" + if device_id and device.get("device_id") != device_id: + return False + + if not required_capability: + return True + + capabilities = device.get("capabilities") or [] + return required_capability in capabilities + + def _get_bound_task_device_id( + self, + db: Session, + user_id: int, + task_id: int, + ) -> Optional[str]: + """Return the task-bound device ID when device backend was already selected.""" + task = ( + db.query(TaskResource) + .filter( + TaskResource.id == task_id, + TaskResource.user_id == user_id, + TaskResource.kind == "Task", + TaskResource.is_active == TaskResource.STATE_ACTIVE, + ) + .first() + ) + if not task: + return None + + task_json = task.json if isinstance(task.json, dict) else {} + labels = task_json.get("metadata", {}).get("labels", 
{}) + if labels.get(SANDBOX_BACKEND_LABEL) != DEVICE_BACKEND_NAME: + return None + return labels.get(SANDBOX_DEVICE_ID_LABEL) + + def _persist_task_binding( + self, + db: Session, + user_id: int, + task_id: int, + device_id: str, + ) -> None: + """Persist the selected device backend on the task for sticky routing.""" + task = ( + db.query(TaskResource) + .filter( + TaskResource.id == task_id, + TaskResource.user_id == user_id, + TaskResource.kind == "Task", + TaskResource.is_active == TaskResource.STATE_ACTIVE, + ) + .first() + ) + if not task: + return + + task_json = task.json if isinstance(task.json, dict) else {} + metadata = task_json.setdefault("metadata", {}) + labels = metadata.setdefault("labels", {}) + + if ( + labels.get(SANDBOX_BACKEND_LABEL) == DEVICE_BACKEND_NAME + and labels.get(SANDBOX_DEVICE_ID_LABEL) == device_id + ): + return + + labels[SANDBOX_BACKEND_LABEL] = DEVICE_BACKEND_NAME + labels[SANDBOX_DEVICE_ID_LABEL] = device_id + task.json = task_json + flag_modified(task, "json") + + logger.info( + "[DeviceSandboxService] Persisted task sandbox binding: user_id=%s, " + "task_id=%s, backend=%s, device_id=%s", + user_id, + task_id, + DEVICE_BACKEND_NAME, + device_id, + ) + + +device_sandbox_service = DeviceSandboxService() diff --git a/backend/app/services/device_service.py b/backend/app/services/device_service.py index 2a2bc88d0..19ad1a1b8 100644 --- a/backend/app/services/device_service.py +++ b/backend/app/services/device_service.py @@ -270,6 +270,7 @@ def upsert_device_crd( client_ip: Optional[str] = None, device_type: Optional[str] = None, bind_shell: Optional[str] = None, + capabilities: Optional[List[str]] = None, ) -> Kind: """Create or update a Device CRD record. @@ -287,6 +288,7 @@ def upsert_device_crd( bind_shell: Shell runtime binding ('claudecode' or 'openclaw'). If None, defaults to 'claudecode' for new devices or preserves existing value. + capabilities: Optional capability tags persisted in spec.capabilities. 
Returns: Kind model instance for the device @@ -320,6 +322,8 @@ def upsert_device_crd( # Update client IP if provided if client_ip is not None: device_json["spec"]["clientIp"] = client_ip + if capabilities is not None: + device_json["spec"]["capabilities"] = capabilities # Update bind_shell if provided, otherwise preserve existing value if bind_shell is not None: device_json["spec"]["bindShell"] = bind_shell @@ -374,7 +378,7 @@ def upsert_device_crd( "connectionMode": "websocket", "bindShell": resolved_bind_shell, "isDefault": is_first_device, - "capabilities": None, + "capabilities": capabilities, "clientIp": client_ip, }, "status": { diff --git a/backend/init_data/skills/sandbox/command_tool.py b/backend/init_data/skills/sandbox/command_tool.py index e2bbcd5d4..ea96dfece 100644 --- a/backend/init_data/skills/sandbox/command_tool.py +++ b/backend/init_data/skills/sandbox/command_tool.py @@ -163,6 +163,7 @@ async def _arun( """ start_time = time.time() effective_timeout = timeout_seconds or self.default_command_timeout + raw_command = command # Wrap command with bash -c if it contains shell operators # This ensures operators like &&, ||, |, ;, >, < are properly interpreted @@ -193,6 +194,37 @@ async def _arun( try: # Get sandbox manager from base class sandbox_manager = self._get_sandbox_manager() + await sandbox_manager.ensure_device_binding_loaded() + + if sandbox_manager.should_use_device_backend_for_command(raw_command): + if sandbox_manager.is_device_backend_bound(): + logger.info( + "[SandboxCommandTool] Routing command to bound device backend: %s", + raw_command[:100], + ) + else: + logger.info( + "[SandboxCommandTool] Routing command to device backend: %s", + raw_command[:100], + ) + response = await sandbox_manager.execute_command_via_device( + command=raw_command, + working_dir=working_dir, + timeout_seconds=effective_timeout, + ) + + if response.get("success"): + await self._emit_tool_status( + "completed", "Command executed successfully", response + ) + else: + await self._emit_tool_status( + "failed", + f"Command failed with exit code {response.get('exit_code', -1)}", + response, + ) + + return json.dumps(response, ensure_ascii=False, indent=2) # Get or create sandbox logger.info(f"[SandboxCommandTool] Getting or creating sandbox...") diff --git a/backend/init_data/skills/sandbox/download_attachment_tool.py b/backend/init_data/skills/sandbox/download_attachment_tool.py index 610d90ebc..3b649fcc5 100644 --- a/backend/init_data/skills/sandbox/download_attachment_tool.py +++ b/backend/init_data/skills/sandbox/download_attachment_tool.py @@ -157,6 +157,32 @@ async def _arun( # Get sandbox manager from base class sandbox_manager = self._get_sandbox_manager() + await sandbox_manager.ensure_device_binding_loaded() + if sandbox_manager.is_device_backend_bound(): + logger.info( + "[SandboxDownloadAttachmentTool] Downloading attachment via bound device backend: %s -> %s", + attachment_url, + save_path, + ) + response = await sandbox_manager.download_attachment_via_device( + attachment_url=attachment_url, + save_path=save_path, + timeout_seconds=effective_timeout, + ) + if response.get("success"): + await self._emit_tool_status( + "completed", + f"File downloaded successfully ({response.get('file_size', 0)} bytes)", + response, + ) + else: + await self._emit_tool_status( + "failed", + response.get("error", "Failed to download file"), + response, + ) + return json.dumps(response, ensure_ascii=False, indent=2) + # Get or create sandbox logger.info( f"[SandboxDownloadAttachmentTool] Getting 
or creating sandbox..." diff --git a/backend/init_data/skills/sandbox/list_files_tool.py b/backend/init_data/skills/sandbox/list_files_tool.py index 84fd6176f..dffe80056 100644 --- a/backend/init_data/skills/sandbox/list_files_tool.py +++ b/backend/init_data/skills/sandbox/list_files_tool.py @@ -135,6 +135,30 @@ async def _arun( # Get sandbox manager from base class sandbox_manager = self._get_sandbox_manager() + await sandbox_manager.ensure_device_binding_loaded() + if sandbox_manager.is_device_backend_bound(): + logger.info( + "[SandboxListFilesTool] Listing files from bound device backend: %s", + path, + ) + response = await sandbox_manager.list_files_via_device( + path=path or "/home/user", + depth=depth or 1, + ) + if response.get("success"): + await self._emit_tool_status( + "completed", + f"Listed {response.get('total', 0)} entries", + response, + ) + else: + await self._emit_tool_status( + "failed", + response.get("error", "Failed to list files"), + response, + ) + return json.dumps(response, ensure_ascii=False, indent=2) + # Get or create sandbox logger.info(f"[SandboxListFilesTool] Getting or creating sandbox...") sandbox, error = await sandbox_manager.get_or_create_sandbox( diff --git a/backend/init_data/skills/sandbox/read_file_tool.py b/backend/init_data/skills/sandbox/read_file_tool.py index c09b64440..91241b83a 100644 --- a/backend/init_data/skills/sandbox/read_file_tool.py +++ b/backend/init_data/skills/sandbox/read_file_tool.py @@ -137,6 +137,30 @@ async def _arun( # Get sandbox manager from base class sandbox_manager = self._get_sandbox_manager() + await sandbox_manager.ensure_device_binding_loaded() + if sandbox_manager.is_device_backend_bound(): + logger.info( + "[SandboxReadFileTool] Reading file from bound device backend: %s", + file_path, + ) + response = await sandbox_manager.read_file_via_device( + file_path=file_path, + format=format or "text", + ) + if response.get("success"): + await self._emit_tool_status( + "completed", + f"File read successfully ({response.get('size', 0)} bytes)", + response, + ) + else: + await self._emit_tool_status( + "failed", + response.get("error", "Failed to read file"), + response, + ) + return json.dumps(response, ensure_ascii=False, indent=2) + # Get or create sandbox logger.info(f"[SandboxReadFileTool] Getting or creating sandbox...") sandbox, error = await sandbox_manager.get_or_create_sandbox( diff --git a/backend/init_data/skills/sandbox/upload_attachment_tool.py b/backend/init_data/skills/sandbox/upload_attachment_tool.py index 50834e56e..62e39f71a 100644 --- a/backend/init_data/skills/sandbox/upload_attachment_tool.py +++ b/backend/init_data/skills/sandbox/upload_attachment_tool.py @@ -165,6 +165,31 @@ async def _arun( # Get sandbox manager from base class sandbox_manager = self._get_sandbox_manager() + await sandbox_manager.ensure_device_binding_loaded() + if sandbox_manager.is_device_backend_bound(): + logger.info( + "[SandboxUploadAttachmentTool] Uploading file via bound device backend: %s", + file_path, + ) + response = await sandbox_manager.upload_attachment_via_device( + file_path=file_path, + overwrite_attachment_id=overwrite_attachment_id, + timeout_seconds=effective_timeout, + ) + if response.get("success"): + await self._emit_tool_status( + "completed", + f"File uploaded successfully ({response.get('file_size', 0)} bytes)", + response, + ) + else: + await self._emit_tool_status( + "failed", + response.get("error", "Failed to upload file"), + response, + ) + return json.dumps(response, ensure_ascii=False, indent=2) + 
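+        # No sticky device binding: fall through to the default E2B-backed sandbox below.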
# Get or create sandbox logger.info(f"[SandboxUploadAttachmentTool] Getting or creating sandbox...") sandbox, error = await sandbox_manager.get_or_create_sandbox( diff --git a/backend/init_data/skills/sandbox/write_file_tool.py b/backend/init_data/skills/sandbox/write_file_tool.py index 2de1e6684..7b42c0103 100644 --- a/backend/init_data/skills/sandbox/write_file_tool.py +++ b/backend/init_data/skills/sandbox/write_file_tool.py @@ -177,6 +177,32 @@ async def _arun( # Get sandbox manager from base class sandbox_manager = self._get_sandbox_manager() + await sandbox_manager.ensure_device_binding_loaded() + if sandbox_manager.is_device_backend_bound(): + logger.info( + "[SandboxWriteFileTool] Writing file to bound device backend: %s", + file_path, + ) + response = await sandbox_manager.write_file_via_device( + file_path=file_path, + content=content, + format=format or "text", + create_dirs=bool(create_dirs), + ) + if response.get("success"): + await self._emit_tool_status( + "completed", + f"File written successfully ({response.get('size', 0)} bytes)", + response, + ) + else: + await self._emit_tool_status( + "failed", + response.get("error", "Failed to write file"), + response, + ) + return json.dumps(response, ensure_ascii=False, indent=2) + # Get or create sandbox logger.info(f"[SandboxWriteFileTool] Getting or creating sandbox...") sandbox, error = await sandbox_manager.get_or_create_sandbox( diff --git a/backend/tests/services/test_device_sandbox_service.py b/backend/tests/services/test_device_sandbox_service.py new file mode 100644 index 000000000..ef3c7e13a --- /dev/null +++ b/backend/tests/services/test_device_sandbox_service.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for device-backed sandbox execution.""" + +from unittest.mock import ANY, AsyncMock, MagicMock, patch + +import pytest + +from app.services.device_sandbox_service import ( + DeviceSandboxError, + device_sandbox_service, +) + + +class TestDeviceSandboxService: + """Tests for DeviceSandboxService.""" + + @pytest.mark.asyncio + async def test_execute_command_prefers_default_cloud_device(self): + """Default cloud devices should be preferred over other online devices.""" + online_devices = [ + { + "device_id": "local-device", + "device_type": "local", + "is_default": True, + "capabilities": [], + }, + { + "device_id": "cloud-device", + "device_type": "cloud", + "is_default": True, + "capabilities": [], + }, + ] + mock_sio = MagicMock() + mock_sio.call = AsyncMock( + return_value={ + "success": True, + "stdout": "ok", + "stderr": "", + "exit_code": 0, + "execution_time": 0.12, + } + ) + + with ( + patch( + "app.services.device_sandbox_service.device_service.get_online_devices", + AsyncMock(return_value=online_devices), + ), + patch( + "app.services.device_sandbox_service.device_service.get_device_online_info", + AsyncMock(return_value={"socket_id": "socket-1"}), + ) as mock_online_info, + patch( + "app.services.device_sandbox_service.get_sio", + return_value=mock_sio, + ), + ): + result = await device_sandbox_service.execute_command( + db=MagicMock(), + user_id=1, + command="himalaya --help", + ) + + assert result["success"] is True + assert result["device_id"] == "cloud-device" + mock_online_info.assert_awaited_once_with(1, "cloud-device") + mock_sio.call.assert_awaited_once() + + @pytest.mark.asyncio + async def test_execute_command_raises_when_no_online_device(self): + """An explicit error should be raised when no online device is available.""" + 
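+        # Patch the device lookup at its use site inside device_sandbox_service.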
with patch( + "app.services.device_sandbox_service.device_service.get_online_devices", + AsyncMock(return_value=[]), + ): + with pytest.raises(DeviceSandboxError, match="No compatible online device"): + await device_sandbox_service.execute_command( + db=MagicMock(), + user_id=1, + command="himalaya --help", + ) + + @pytest.mark.asyncio + async def test_execute_command_prefers_task_bound_device(self): + """Task-bound device should override normal priority ordering.""" + online_devices = [ + { + "device_id": "local-device", + "device_type": "local", + "is_default": False, + "capabilities": [], + }, + { + "device_id": "cloud-device", + "device_type": "cloud", + "is_default": True, + "capabilities": [], + }, + ] + mock_sio = MagicMock() + mock_sio.call = AsyncMock( + return_value={ + "success": True, + "stdout": "ok", + "stderr": "", + "exit_code": 0, + "execution_time": 0.12, + } + ) + + with ( + patch( + "app.services.device_sandbox_service.device_service.get_online_devices", + AsyncMock(return_value=online_devices), + ), + patch( + "app.services.device_sandbox_service.device_service.get_device_online_info", + AsyncMock(return_value={"socket_id": "socket-1"}), + ) as mock_online_info, + patch( + "app.services.device_sandbox_service.get_sio", + return_value=mock_sio, + ), + patch.object( + device_sandbox_service, + "_get_bound_task_device_id", + return_value="local-device", + ), + patch.object(device_sandbox_service, "_persist_task_binding"), + ): + result = await device_sandbox_service.execute_command( + db=MagicMock(), + user_id=1, + task_id=794, + command="pwd", + ) + + assert result["success"] is True + assert result["device_id"] == "local-device" + mock_online_info.assert_awaited_once_with(1, "local-device") + mock_sio.call.assert_awaited_once() + + @pytest.mark.asyncio + async def test_execute_command_persists_selected_device_on_task(self): + """Selected device should be written back as sticky task binding.""" + online_devices = [ + { + "device_id": "cloud-device", + "device_type": "cloud", + "is_default": True, + "capabilities": [], + } + ] + mock_sio = MagicMock() + mock_sio.call = AsyncMock( + return_value={ + "success": False, + "stdout": "", + "stderr": "config missing", + "exit_code": 1, + "execution_time": 0.08, + } + ) + + with ( + patch( + "app.services.device_sandbox_service.device_service.get_online_devices", + AsyncMock(return_value=online_devices), + ), + patch( + "app.services.device_sandbox_service.device_service.get_device_online_info", + AsyncMock(return_value={"socket_id": "socket-1"}), + ), + patch( + "app.services.device_sandbox_service.get_sio", + return_value=mock_sio, + ), + patch.object( + device_sandbox_service, + "_get_bound_task_device_id", + return_value=None, + ), + patch.object( + device_sandbox_service, + "_persist_task_binding", + ) as mock_persist_task_binding, + ): + result = await device_sandbox_service.execute_command( + db=MagicMock(), + user_id=1, + task_id=794, + command="himalaya --help", + ) + + assert result["success"] is False + mock_persist_task_binding.assert_called_once_with( + db=ANY, + user_id=1, + task_id=794, + device_id="cloud-device", + ) diff --git a/chat_shell/.env.example b/chat_shell/.env.example index 7e9ef3a31..47e4fbbc2 100644 --- a/chat_shell/.env.example +++ b/chat_shell/.env.example @@ -77,9 +77,11 @@ CHAT_SHELL_CHAT_API_TIMEOUT_SECONDS=300 # Tool calling flow limits # Maximum LLM requests in tool calling flow -CHAT_SHELL_CHAT_TOOL_MAX_REQUESTS=10 +CHAT_SHELL_CHAT_TOOL_MAX_REQUESTS=50 # Maximum time for tool calling flow 
(seconds) CHAT_SHELL_CHAT_TOOL_MAX_TIME_SECONDS=60.0 +# Comma-separated command names that should route to the device backend +CHAT_SHELL_DEVICE_ROUTED_COMMANDS=himalaya # Group chat history configuration # In group chat mode, AI-bot sees: first N messages + last M messages (no duplicates) diff --git a/chat_shell/chat_shell/agent.py b/chat_shell/chat_shell/agent.py index 91e84a7ae..6b5d77a78 100644 --- a/chat_shell/chat_shell/agent.py +++ b/chat_shell/chat_shell/agent.py @@ -53,7 +53,7 @@ class AgentConfig: model_config: dict[str, Any] system_prompt: str = "" - max_iterations: int = 10 # Default, can be overridden by settings + max_iterations: int = 50 # Default, can be overridden by settings extra_tools: list[BaseTool] | None = None streaming: bool = True # Prompt enhancement options (handled internally by ChatAgent) diff --git a/chat_shell/chat_shell/agents/graph_builder.py b/chat_shell/chat_shell/agents/graph_builder.py index eb2df2c32..5324ee48c 100644 --- a/chat_shell/chat_shell/agents/graph_builder.py +++ b/chat_shell/chat_shell/agents/graph_builder.py @@ -233,7 +233,7 @@ def __init__( self, llm: BaseChatModel, tool_registry: ToolRegistry | None = None, - max_iterations: int = 10, + max_iterations: int = 50, enable_checkpointing: bool = False, max_truncation_retries: int | None = None, ): @@ -269,6 +269,52 @@ def __init__( # Automatically detect PromptModifierTool instances from registered tools self._prompt_modifier_tools = self._find_prompt_modifier_tools() + async def _stream_final_response_without_tools( + self, + lc_messages: list[BaseMessage], + system_notice: str, + ) -> AsyncGenerator[str, None]: + """Ask the model for a final response without allowing more tools.""" + final_messages = list(lc_messages) + [HumanMessage(content=system_notice)] + + async for chunk in self.llm.astream(final_messages): + if hasattr(chunk, "content"): + content = chunk.content + if isinstance(content, str) and content: + yield content + elif isinstance(content, list): + for part in content: + if isinstance(part, str) and part: + yield part + elif isinstance(part, dict): + text = part.get("text", "") + if text: + yield text + + async def _build_final_state_without_tools( + self, + lc_messages: list[BaseMessage], + system_notice: str, + ) -> dict[str, Any]: + """Get a final LLM response without allowing more tools.""" + final_messages = list(lc_messages) + [HumanMessage(content=system_notice)] + response = await self.llm.ainvoke(final_messages) + + final_content = "" + if hasattr(response, "content"): + if isinstance(response.content, str): + final_content = response.content + elif isinstance(response.content, list): + text_parts = [] + for part in response.content: + if isinstance(part, str): + text_parts.append(part) + elif isinstance(part, dict) and part.get("type") == "text": + text_parts.append(part.get("text", "")) + final_content = "".join(text_parts) + + return {"messages": list(lc_messages) + [AIMessage(content=final_content)]} + def _find_prompt_modifier_tools(self) -> list[Any]: """Find all tools that implement the PromptModifierTool protocol. 
@@ -1103,33 +1149,16 @@ async def stream_tokens( "Asking model to provide final response.", self.max_iterations, ) - - # Build messages with the limit reached notice - # Add a human message to prompt the model to provide final response - limit_messages = list(lc_messages) + [ - HumanMessage(content=TOOL_LIMIT_REACHED_MESSAGE) - ] - - # Call the LLM directly (without tools) to get final response try: - async for chunk in self.llm.astream(limit_messages): - if hasattr(chunk, "content"): - content = chunk.content - if isinstance(content, str) and content: - yield content - elif isinstance(content, list): - for part in content: - if isinstance(part, str) and part: - yield part - elif isinstance(part, dict): - text = part.get("text", "") - if text: - yield text - + async for chunk in self._stream_final_response_without_tools( + lc_messages, + TOOL_LIMIT_REACHED_MESSAGE, + ): + yield chunk logger.info( "[stream_tokens] Final response generated after tool limit reached" ) - except Exception as recovery_error: + except Exception: logger.exception( "Error generating final response after tool limit reached" ) @@ -1289,32 +1318,11 @@ async def stream_events_with_state( "Asking model to provide final response.", self.max_iterations, ) - - # Build messages with the limit reached notice - limit_messages = list(lc_messages) + [ - HumanMessage(content=TOOL_LIMIT_REACHED_MESSAGE) - ] - - # Call the LLM directly (without tools) to get final response try: - response = await self.llm.ainvoke(limit_messages) - final_content = "" - if hasattr(response, "content"): - if isinstance(response.content, str): - final_content = response.content - elif isinstance(response.content, list): - text_parts = [] - for part in response.content: - if isinstance(part, str): - text_parts.append(part) - elif isinstance(part, dict) and part.get("type") == "text": - text_parts.append(part.get("text", "")) - final_content = "".join(text_parts) - - # Create a final state with the response - final_state = { - "messages": list(lc_messages) + [AIMessage(content=final_content)] - } + final_state = await self._build_final_state_without_tools( + lc_messages, + TOOL_LIMIT_REACHED_MESSAGE, + ) logger.debug( "[stream_events_with_state] Final response generated after tool limit reached" diff --git a/chat_shell/chat_shell/api/schemas.py b/chat_shell/chat_shell/api/schemas.py index 2fe690377..2945c133c 100644 --- a/chat_shell/chat_shell/api/schemas.py +++ b/chat_shell/chat_shell/api/schemas.py @@ -76,7 +76,7 @@ class ChatEvent: extra_tools: list[Any] | None = None enable_web_search: bool = False search_engine: str | None = None - max_iterations: int = 10 + max_iterations: int = 50 # Metadata message_id: int | None = None @@ -119,7 +119,7 @@ def from_dict(cls, data: dict[str, Any]) -> "ChatEvent": extra_tools=data.get("extra_tools"), enable_web_search=data.get("enable_web_search", False), search_engine=data.get("search_engine"), - max_iterations=data.get("max_iterations", 10), + max_iterations=data.get("max_iterations", 50), message_id=data.get("message_id"), shell_type=data.get("shell_type", "Chat"), ) diff --git a/chat_shell/chat_shell/core/config.py b/chat_shell/chat_shell/core/config.py index e1cd779df..abc857cfa 100644 --- a/chat_shell/chat_shell/core/config.py +++ b/chat_shell/chat_shell/core/config.py @@ -79,8 +79,9 @@ def CHAT_SHELL_MODE(self) -> str: CHAT_API_TIMEOUT_SECONDS: int = 300 # Tool calling flow limits - CHAT_TOOL_MAX_REQUESTS: int = 30 + CHAT_TOOL_MAX_REQUESTS: int = 50 CHAT_TOOL_MAX_TIME_SECONDS: float = 60.0 + 
DEVICE_ROUTED_COMMANDS: str = "himalaya" # Group chat history configuration GROUP_CHAT_HISTORY_FIRST_MESSAGES: int = 10 diff --git a/chat_shell/chat_shell/skills/registry.py b/chat_shell/chat_shell/skills/registry.py index e2cfcf2ce..5a796189c 100644 --- a/chat_shell/chat_shell/skills/registry.py +++ b/chat_shell/chat_shell/skills/registry.py @@ -9,6 +9,7 @@ loading of providers from skill packages. """ +import importlib.machinery import importlib.util import logging import threading @@ -22,6 +23,24 @@ logger = logging.getLogger(__name__) +def _module_execution_priority(module_name: str) -> tuple[int, str]: + """Return execution priority for dynamically loaded skill modules. + + Skill packages are loaded from ZIP bytes without a real filesystem-backed + importer. Modules that depend on shared package state, such as ``_base.py``, + must execute before tool modules that import them. ``__init__.py`` is + executed last because it commonly re-exports submodules. + """ + + if module_name == "_base": + return (0, module_name) + if module_name == "provider": + return (1, module_name) + if module_name == "__init__": + return (3, module_name) + return (2, module_name) + + class SkillToolRegistry: """Central registry for skill tool providers. @@ -288,21 +307,37 @@ def load_provider_from_zip( ) return None - # Create the package module if it doesn't exist + # Create the package module if it doesn't exist. + # The explicit package spec makes relative imports work reliably + # for dynamically executed modules from in-memory ZIP content. if package_name not in sys.modules: package_module = types.ModuleType(package_name) + package_spec = importlib.machinery.ModuleSpec( + package_name, + loader=None, + is_package=True, + ) + package_spec.submodule_search_locations = [] package_module.__path__ = [] package_module.__package__ = package_name + package_module.__spec__ = package_spec sys.modules[package_name] = package_module - # Load all Python modules in the skill package + module_sources: dict[str, str] = {} + + # Pre-register all module objects before execution so intra-package + # imports resolve consistently during module initialization. for py_mod_name, file_path in python_files.items(): full_module_name = f"{package_name}.{py_mod_name}" if full_module_name in sys.modules: + module_sources[py_mod_name] = zip_file.read(file_path).decode( + "utf-8" + ) continue module_code = zip_file.read(file_path).decode("utf-8") + module_sources[py_mod_name] = module_code spec = importlib.util.spec_from_loader( full_module_name, @@ -320,8 +355,24 @@ def load_provider_from_zip( module.__package__ = package_name sys.modules[full_module_name] = module + # Execute modules in dependency-aware order. This avoids cases like + # command_tool importing ._base before _base.py has been executed. 
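+        # For a package {_base.py, provider.py, command_tool.py, __init__.py},
+        # the resulting order is: _base -> provider -> command_tool -> __init__.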
+ for py_mod_name in sorted( + python_files.keys(), key=_module_execution_priority + ): + full_module_name = f"{package_name}.{py_mod_name}" + module = sys.modules.get(full_module_name) + if module is None: + continue + + if getattr(module, "__skill_loaded__", False): + continue + + module_code = module_sources[py_mod_name] + try: exec(module_code, module.__dict__) + module.__skill_loaded__ = True except Exception as e: logger.error( f"[SkillToolRegistry] Failed to execute module " diff --git a/chat_shell/chat_shell/tools/sandbox/_base.py b/chat_shell/chat_shell/tools/sandbox/_base.py index 03f0a8841..a71d772df 100644 --- a/chat_shell/chat_shell/tools/sandbox/_base.py +++ b/chat_shell/chat_shell/tools/sandbox/_base.py @@ -16,15 +16,57 @@ """ import asyncio +import base64 import logging import os +import re from typing import Any, Optional +import httpx + logger = logging.getLogger(__name__) # Default configuration DEFAULT_EXECUTOR_MANAGER_URL = "http://localhost:8001" +DEFAULT_BACKEND_API_URL = "http://localhost:8000" DEFAULT_SANDBOX_TIMEOUT = 1800 # 30 minutes +DEVICE_EXEC_ENDPOINT = "/api/internal/devices/sandbox/exec" +DEVICE_READ_FILE_ENDPOINT = "/api/internal/devices/sandbox/read-file" +DEVICE_LIST_FILES_ENDPOINT = "/api/internal/devices/sandbox/list-files" +DEVICE_WRITE_FILE_ENDPOINT = "/api/internal/devices/sandbox/write-file" +DEVICE_DOWNLOAD_ATTACHMENT_ENDPOINT = ( + "/api/internal/devices/sandbox/download-attachment" +) +DEVICE_UPLOAD_ATTACHMENT_ENDPOINT = "/api/internal/devices/sandbox/upload-attachment" +DEVICE_BINDING_ENDPOINT_TEMPLATE = "/api/internal/devices/sandbox/binding/{task_id}" +DEFAULT_DEVICE_ROUTED_COMMANDS = "himalaya" + + +def _build_device_routed_command_pattern() -> re.Pattern[str]: + """Build a regex for commands that should be routed to the device backend.""" + raw_value = os.getenv( + "CHAT_SHELL_DEVICE_ROUTED_COMMANDS", + DEFAULT_DEVICE_ROUTED_COMMANDS, + ) + command_names = [name.strip() for name in raw_value.split(",") if name.strip()] + if not command_names: + return re.compile(r"$^") + + command_patterns: list[str] = [] + for command_name in command_names: + escaped_name = re.escape(command_name) + command_patterns.extend( + [ + escaped_name, + rf"command\s+-v\s+{escaped_name}", + rf"which\s+{escaped_name}", + ] + ) + + return re.compile(rf"(^|\s)({'|'.join(command_patterns)})(\s|$)") + + +DEVICE_ROUTED_COMMAND_PATTERN = _build_device_routed_command_pattern() # E2B SDK patching - must be done before any e2b imports # Setup environment variables first @@ -228,6 +270,9 @@ def __init__( self.timeout = timeout self.bot_config = bot_config or [] self.auth_token = auth_token + self._bound_backend: Optional[str] = None + self._bound_device_id: Optional[str] = None + self._binding_loaded: bool = False # Ensure E2B SDK is patched patch_e2b_sdk() @@ -370,6 +415,245 @@ def sandbox_id(self) -> Optional[str]: """ return None + def should_use_device_backend_for_command(self, command: str) -> bool: + """Return True when a command should prefer the device-backed executor.""" + return self.is_device_backend_bound() or bool( + DEVICE_ROUTED_COMMAND_PATTERN.search(command) + ) + + def is_device_backend_bound(self) -> bool: + """Return whether this task is already pinned to a device backend.""" + return self._bound_backend == "device" and bool(self._bound_device_id) + + def bind_device_backend(self, device_id: str) -> None: + """Pin the current task to a specific device-backed sandbox.""" + self._bound_backend = "device" + self._bound_device_id = device_id + logger.info( + 
"[SandboxManager] Bound task_id=%s to device backend: device_id=%s", + self.task_id, + device_id, + ) + + async def ensure_device_binding_loaded(self) -> None: + """Load sticky device binding from backend once per task if present.""" + if self._binding_loaded or self.is_device_backend_bound(): + return + + backend_url = _get_backend_api_url() + headers = {} + if self.auth_token: + headers["Authorization"] = f"Bearer {self.auth_token}" + + binding_url = f"{backend_url}{DEVICE_BINDING_ENDPOINT_TEMPLATE.format(task_id=self.task_id)}" + + try: + async with httpx.AsyncClient(timeout=30) as client: + response = await client.get( + binding_url, + params={"user_id": self.user_id}, + headers=headers, + ) + response.raise_for_status() + payload = response.json() + except Exception as exc: + logger.warning( + "[SandboxManager] Failed to load task sandbox binding: task_id=%s, error=%s", + self.task_id, + exc, + ) + self._binding_loaded = True + return + + backend = payload.get("backend") + device_id = payload.get("device_id") + if backend == "device" and isinstance(device_id, str) and device_id: + self.bind_device_backend(device_id) + + self._binding_loaded = True + + async def execute_command_via_device( + self, + command: str, + working_dir: Optional[str] = "/home/user", + timeout_seconds: int = 300, + required_capability: Optional[str] = None, + ) -> dict[str, Any]: + """Execute a command through the backend's device sandbox bridge.""" + return await self._post_device_backend( + endpoint=DEVICE_EXEC_ENDPOINT, + payload={ + "command": command, + "working_dir": working_dir or "/home/user", + "timeout_seconds": timeout_seconds, + "required_capability": required_capability, + }, + timeout_seconds=timeout_seconds, + ) + + async def read_file_via_device( + self, + file_path: str, + format: str = "text", + ) -> dict[str, Any]: + """Read a file through the backend's device sandbox bridge.""" + return await self._post_device_backend( + endpoint=DEVICE_READ_FILE_ENDPOINT, + payload={ + "file_path": file_path, + "format": format, + }, + timeout_seconds=60, + unwrap_data=True, + ) + + async def list_files_via_device( + self, + path: str = "/home/user", + depth: int = 1, + ) -> dict[str, Any]: + """List files through the backend's device sandbox bridge.""" + return await self._post_device_backend( + endpoint=DEVICE_LIST_FILES_ENDPOINT, + payload={ + "path": path, + "depth": depth, + }, + timeout_seconds=60, + unwrap_data=True, + ) + + async def write_file_via_device( + self, + file_path: str, + content: str, + format: str = "text", + create_dirs: bool = True, + ) -> dict[str, Any]: + """Write a file through the backend's device sandbox bridge.""" + return await self._post_device_backend( + endpoint=DEVICE_WRITE_FILE_ENDPOINT, + payload={ + "file_path": file_path, + "content": content, + "format": format, + "create_dirs": create_dirs, + }, + timeout_seconds=60, + unwrap_data=True, + ) + + async def download_attachment_via_device( + self, + attachment_url: str, + save_path: str, + timeout_seconds: int = 300, + ) -> dict[str, Any]: + """Download a Wegent attachment through the device sandbox bridge.""" + return await self._post_device_backend( + endpoint=DEVICE_DOWNLOAD_ATTACHMENT_ENDPOINT, + payload={ + "attachment_url": attachment_url, + "save_path": save_path, + "auth_token": self.auth_token, + "api_base_url": _get_backend_api_url(), + "timeout_seconds": timeout_seconds, + }, + timeout_seconds=timeout_seconds, + unwrap_data=True, + ) + + async def upload_attachment_via_device( + self, + file_path: str, + 
overwrite_attachment_id: Optional[int] = None, + timeout_seconds: int = 300, + ) -> dict[str, Any]: + """Upload a device-local file back to Wegent attachments.""" + return await self._post_device_backend( + endpoint=DEVICE_UPLOAD_ATTACHMENT_ENDPOINT, + payload={ + "file_path": file_path, + "auth_token": self.auth_token, + "api_base_url": _get_backend_api_url(), + "overwrite_attachment_id": overwrite_attachment_id, + "timeout_seconds": timeout_seconds, + }, + timeout_seconds=timeout_seconds, + unwrap_data=True, + ) + + async def _post_device_backend( + self, + endpoint: str, + payload: dict[str, Any], + timeout_seconds: int, + unwrap_data: bool = False, + ) -> dict[str, Any]: + """Call an internal backend endpoint that proxies to the sticky device.""" + await self.ensure_device_binding_loaded() + backend_url = _get_backend_api_url() + headers = {"Content-Type": "application/json"} + if self.auth_token: + headers["Authorization"] = f"Bearer {self.auth_token}" + + request_payload = { + "task_id": self.task_id, + "user_id": self.user_id, + **payload, + "device_id": self._bound_device_id, + } + + logger.info( + "[SandboxManager] Calling device backend: endpoint=%s, backend_url=%s, " + "timeout=%ss, bound_device_id=%s", + endpoint, + backend_url, + timeout_seconds, + self._bound_device_id, + ) + + async with httpx.AsyncClient(timeout=max(timeout_seconds + 10, 60)) as client: + response = await client.post( + f"{backend_url}{endpoint}", + json=request_payload, + headers=headers, + ) + response.raise_for_status() + result = response.json() + + device_id = result.get("device_id") + if isinstance(device_id, str) and device_id: + self.bind_device_backend(device_id) + + if unwrap_data and isinstance(result.get("data"), dict): + result = { + "success": result.get("success", False), + "execution_time": result.get("execution_time", 0.0), + "device_id": result.get("device_id"), + "backend": result.get("backend", "device"), + **result["data"], + } + + return result + + +def _get_backend_api_url() -> str: + """Resolve backend API base URL for chat_shell-side HTTP calls.""" + explicit_url = os.getenv("BACKEND_API_URL") + if explicit_url: + return explicit_url.rstrip("/") + + remote_storage_url = os.getenv("CHAT_SHELL_REMOTE_STORAGE_URL", "").rstrip("/") + if remote_storage_url: + if remote_storage_url.endswith("/api/internal"): + return remote_storage_url[: -len("/api/internal")] + if remote_storage_url.endswith("/api"): + return remote_storage_url[: -len("/api")] + return remote_storage_url + + return DEFAULT_BACKEND_API_URL + # Patch E2B SDK at module load time patch_e2b_sdk() diff --git a/chat_shell/tests/test_sandbox_manager.py b/chat_shell/tests/test_sandbox_manager.py new file mode 100644 index 000000000..1845c7538 --- /dev/null +++ b/chat_shell/tests/test_sandbox_manager.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for sticky device backend routing in sandbox manager.""" + +import pytest + +from chat_shell.tools.sandbox import _base +from chat_shell.tools.sandbox._base import SandboxManager + + +class _FakeResponse: + def __init__(self, payload): + self._payload = payload + + def raise_for_status(self) -> None: + return None + + def json(self): + return self._payload + + +@pytest.mark.asyncio +async def test_sandbox_manager_sticks_to_device_backend(monkeypatch): + """Once a task is routed to a device, later commands stay on that device.""" + captured_payloads: list[dict] = [] + + class _FakeAsyncClient: + def __init__(self, *args, **kwargs): + return None + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def post(self, url, json, headers): + captured_payloads.append(json) + return _FakeResponse( + { + "success": True, + "stdout": "ok", + "stderr": "", + "exit_code": 0, + "execution_time": 0.1, + "device_id": "device-1", + "backend": "device", + } + ) + + monkeypatch.setattr( + "chat_shell.tools.sandbox._base.httpx.AsyncClient", + _FakeAsyncClient, + ) + + manager = SandboxManager.get_instance( + task_id=794, + user_id=2, + user_name="sifang", + ) + + try: + assert manager.is_device_backend_bound() is False + assert manager.should_use_device_backend_for_command("ls /home") is False + + await manager.execute_command_via_device(command="himalaya --help") + + assert manager.is_device_backend_bound() is True + assert manager.should_use_device_backend_for_command("ls /home") is True + + await manager.execute_command_via_device(command="ls /home") + + assert captured_payloads[0]["task_id"] == 794 + assert captured_payloads[0]["device_id"] is None + assert captured_payloads[1]["device_id"] == "device-1" + finally: + SandboxManager.remove_instance(794) + + +def test_build_device_routed_command_pattern_uses_env(monkeypatch): + """Device-routed commands should be configurable via environment variables.""" + monkeypatch.setenv("CHAT_SHELL_DEVICE_ROUTED_COMMANDS", "mail-cli,custom-tool") + + pattern = _base._build_device_routed_command_pattern() + + assert pattern.search("mail-cli --help") + assert pattern.search("command -v custom-tool") + assert pattern.search("which custom-tool") + assert not pattern.search("himalaya --help") diff --git a/chat_shell/tests/test_skill_registry.py b/chat_shell/tests/test_skill_registry.py new file mode 100644 index 000000000..66f33495d --- /dev/null +++ b/chat_shell/tests/test_skill_registry.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for dynamic skill provider loading.""" + +import io +import sys +import zipfile + +from chat_shell.skills import SkillToolContext +from chat_shell.skills.registry import SkillToolRegistry + + +def _build_out_of_order_skill_zip() -> bytes: + """Create a skill ZIP whose tool module appears before _base.py.""" + + provider_code = """ +from chat_shell.skills import SkillToolProvider + + +class TestProvider(SkillToolProvider): + @property + def provider_name(self) -> str: + return "synthetic" + + @property + def supported_tools(self) -> list[str]: + return ["echo"] + + def create_tool(self, tool_name, context, tool_config=None): + if tool_name != "echo": + raise ValueError(f"Unknown tool: {tool_name}") + from .command_tool import EchoTool + return EchoTool() +""" + + command_tool_code = """ +from langchain_core.tools import BaseTool +from pydantic import BaseModel, Field + + +class EchoInput(BaseModel): + text: str = Field(...) + + +try: + from ._base import PREFIX +except ImportError: + import sys + + package_name = __name__.rsplit(".", 1)[0] + _base_module = sys.modules.get(f"{package_name}._base") + if _base_module: + PREFIX = _base_module.PREFIX + else: + raise ImportError(f"Cannot import _base from {package_name}") + + +class EchoTool(BaseTool): + name: str = "echo" + description: str = "Echo text with a prefix" + args_schema: type[BaseModel] = EchoInput + + def _run(self, text: str): + return f"{PREFIX}:{text}" +""" + + base_code = """ +PREFIX = "loaded" +""" + + init_code = """ +from . import _base +""" + + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zip_file: + # Write command_tool before _base to reproduce the original failure mode. + zip_file.writestr("zip-order-skill/command_tool.py", command_tool_code) + zip_file.writestr("zip-order-skill/provider.py", provider_code) + zip_file.writestr("zip-order-skill/_base.py", base_code) + zip_file.writestr("zip-order-skill/__init__.py", init_code) + return buffer.getvalue() + + +def test_load_provider_from_zip_handles_base_module_dependencies(): + """Tool modules importing ._base should load regardless of ZIP entry order.""" + + skill_name = "zip-order-skill" + package_name = "skill_pkg_zip_order_skill" + registry = SkillToolRegistry() + + provider = registry.load_provider_from_zip( + zip_content=_build_out_of_order_skill_zip(), + provider_config={"module": "provider", "class": "TestProvider"}, + skill_name=skill_name, + ) + + assert provider is not None + registry.register(provider) + + context = SkillToolContext( + task_id=1, + subtask_id=1, + user_id=1, + db_session=None, + ws_emitter=None, + ) + + tools = registry.create_tools_for_skill( + skill_config={"tools": [{"name": "echo", "provider": "synthetic"}]}, + context=context, + ) + + assert len(tools) == 1 + assert tools[0].invoke({"text": "hello"}) == "loaded:hello" + + for module_name in list(sys.modules): + if module_name == package_name or module_name.startswith(f"{package_name}."): + sys.modules.pop(module_name, None) diff --git a/executor/modes/local/__init__.py b/executor/modes/local/__init__.py index 6a117eda4..7a80e07b3 100644 --- a/executor/modes/local/__init__.py +++ b/executor/modes/local/__init__.py @@ -22,6 +22,7 @@ from executor.modes.local.events import ( ChatEvents, DeviceEvents, + SandboxEvents, TaskEvents, ) from executor.modes.local.runner import LocalRunner @@ -31,5 +32,6 @@ # Event classes "DeviceEvents", "TaskEvents", + "SandboxEvents", "ChatEvents", ] diff --git 
a/executor/modes/local/events.py b/executor/modes/local/events.py index ea4dcc4a3..83cddb220 100644 --- a/executor/modes/local/events.py +++ b/executor/modes/local/events.py @@ -34,6 +34,17 @@ class TaskEvents: CLOSE_SESSION = "task:close-session" +class SandboxEvents: + """Sandbox helper events for lightweight device-side execution.""" + + EXEC = "sandbox:exec" + READ_FILE = "sandbox:read_file" + LIST_FILES = "sandbox:list_files" + WRITE_FILE = "sandbox:write_file" + DOWNLOAD_ATTACHMENT = "sandbox:download_attachment" + UPLOAD_ATTACHMENT = "sandbox:upload_attachment" + + class ChatEvents: """Chat streaming events using OpenAI Responses API event types. diff --git a/executor/modes/local/handlers.py b/executor/modes/local/handlers.py index 264545e86..73e1e98bd 100644 --- a/executor/modes/local/handlers.py +++ b/executor/modes/local/handlers.py @@ -9,9 +9,19 @@ """ import asyncio +import base64 +import grp +import mimetypes +import os +import pwd +import subprocess import threading +import time +from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Optional +import requests + from shared.logger import setup_logger from shared.models.execution import ExecutionRequest @@ -369,9 +379,597 @@ async def delayed_restart(): pm = ProcessManager() success = pm.restart_executor() if success: - logger.info("[UpgradeHandler] New executor started, exiting current process") + logger.info( + "[UpgradeHandler] New executor started, exiting current process" + ) import sys + sys.exit(0) # Schedule the restart without awaiting it asyncio.create_task(delayed_restart()) + + +class SandboxHandler: + """Handler for lightweight sandbox-style device commands.""" + + MAX_READ_FILE_SIZE = 10 * 1024 * 1024 + MAX_WRITE_FILE_SIZE = 10 * 1024 * 1024 + MAX_UPLOAD_FILE_SIZE = 100 * 1024 * 1024 + + def __init__(self, runner: "LocalRunner"): + """Initialize the sandbox handler.""" + self.runner = runner + + async def handle_exec(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Execute a shell command on the device and return an ack payload.""" + command = str(data.get("command", "")).strip() + if not command: + return { + "success": False, + "stdout": "", + "stderr": "Command is required", + "exit_code": -1, + "execution_time": 0.0, + } + + working_dir = str(data.get("working_dir") or os.path.expanduser("~")) + timeout_seconds = int(data.get("timeout_seconds") or 300) + + logger.info( + "[SandboxHandler] Executing device command: cwd=%s, timeout=%ss, command=%s", + working_dir, + timeout_seconds, + command[:200], + ) + + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + self._execute_command_sync, + command, + working_dir, + timeout_seconds, + ) + + async def handle_read_file(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Read a file from the device filesystem.""" + file_path = str(data.get("file_path", "")).strip() + file_format = str(data.get("format") or "text") + if not file_path: + return self._error_response("file_path is required") + + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + self._read_file_sync, + file_path, + file_format, + ) + + async def handle_list_files(self, data: Dict[str, Any]) -> Dict[str, Any]: + """List files from the device filesystem.""" + path = str(data.get("path") or os.path.expanduser("~")) + depth = int(data.get("depth") or 1) + + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + self._list_files_sync, + path, + depth, + ) + + async def handle_write_file(self, data: Dict[str, 
Any]) -> Dict[str, Any]: + """Write a file to the device filesystem.""" + file_path = str(data.get("file_path", "")).strip() + content = data.get("content") + file_format = str(data.get("format") or "text") + create_dirs = bool(data.get("create_dirs", True)) + + if not file_path: + return self._error_response("file_path is required") + if content is None or content == "": + return self._error_response("content is required") + + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + self._write_file_sync, + file_path, + str(content), + file_format, + create_dirs, + ) + + async def handle_download_attachment(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Download a Wegent attachment onto the device.""" + attachment_url = str(data.get("attachment_url", "")).strip() + save_path = str(data.get("save_path", "")).strip() + auth_token = str(data.get("auth_token", "")).strip() + api_base_url = str(data.get("api_base_url", "")).rstrip("/") + timeout_seconds = int(data.get("timeout_seconds") or 300) + + if not attachment_url or not save_path: + return self._error_response("attachment_url and save_path are required") + if not auth_token: + return self._error_response("auth_token is required") + if not api_base_url: + return self._error_response("api_base_url is required") + + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + self._download_attachment_sync, + attachment_url, + save_path, + auth_token, + api_base_url, + timeout_seconds, + ) + + async def handle_upload_attachment(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Upload a device-local file back to Wegent attachments.""" + file_path = str(data.get("file_path", "")).strip() + auth_token = str(data.get("auth_token", "")).strip() + api_base_url = str(data.get("api_base_url", "")).rstrip("/") + overwrite_attachment_id = data.get("overwrite_attachment_id") + timeout_seconds = int(data.get("timeout_seconds") or 300) + + if not file_path: + return self._error_response("file_path is required") + if not auth_token: + return self._error_response("auth_token is required") + if not api_base_url: + return self._error_response("api_base_url is required") + + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + self._upload_attachment_sync, + file_path, + auth_token, + api_base_url, + overwrite_attachment_id, + timeout_seconds, + ) + + def _execute_command_sync( + self, + command: str, + working_dir: str, + timeout_seconds: int, + ) -> Dict[str, Any]: + """Execute a command synchronously in a worker thread.""" + started_at = time.monotonic() + resolved_working_dir = self._normalize_path(working_dir) + + try: + result = subprocess.run( + command, + shell=True, + capture_output=True, + text=True, + cwd=resolved_working_dir, + timeout=timeout_seconds, + encoding="utf-8", + errors="replace", + ) + except subprocess.TimeoutExpired as exc: + return { + "success": False, + "stdout": exc.stdout or "", + "stderr": (exc.stderr or "") + + f"\nCommand timed out after {timeout_seconds}s", + "exit_code": -1, + "execution_time": time.monotonic() - started_at, + } + except Exception as exc: + logger.error( + "[SandboxHandler] Device command failed: cwd=%s, resolved_cwd=%s, error=%s", + working_dir, + resolved_working_dir, + exc, + ) + return { + "success": False, + "stdout": "", + "stderr": str(exc), + "exit_code": -1, + "execution_time": time.monotonic() - started_at, + } + + return { + "success": result.returncode == 0, + "stdout": result.stdout or "", + "stderr": result.stderr 
or "", + "exit_code": result.returncode, + "execution_time": time.monotonic() - started_at, + } + + def _read_file_sync(self, file_path: str, file_format: str) -> Dict[str, Any]: + """Read a file synchronously.""" + started_at = time.monotonic() + try: + resolved_path = self._normalize_path(file_path) + if not os.path.exists(resolved_path): + return self._error_response( + f"File not found: {resolved_path}", + execution_time=time.monotonic() - started_at, + path=resolved_path, + size=0, + content="", + ) + if not os.path.isfile(resolved_path): + return self._error_response( + f"Path is not a file: {resolved_path}", + execution_time=time.monotonic() - started_at, + path=resolved_path, + size=0, + content="", + ) + + file_size = os.path.getsize(resolved_path) + if file_size > self.MAX_READ_FILE_SIZE: + return self._error_response( + f"File too large: {file_size} bytes", + execution_time=time.monotonic() - started_at, + path=resolved_path, + size=file_size, + content="", + ) + + if file_format == "bytes": + with open(resolved_path, "rb") as file_obj: + content = base64.b64encode(file_obj.read()).decode("ascii") + else: + with open( + resolved_path, + "r", + encoding="utf-8", + errors="replace", + ) as file_obj: + content = file_obj.read() + + return { + "success": True, + "content": content, + "size": file_size, + "path": resolved_path, + "format": file_format, + "modified_time": self._iso_mtime(resolved_path), + "execution_time": time.monotonic() - started_at, + } + except Exception as exc: + logger.error("[SandboxHandler] Device read_file failed: %s", exc) + return self._error_response( + str(exc), + execution_time=time.monotonic() - started_at, + path=file_path, + size=0, + content="", + ) + + def _list_files_sync(self, path: str, depth: int) -> Dict[str, Any]: + """List files synchronously.""" + started_at = time.monotonic() + try: + resolved_path = self._normalize_path(path) + if not os.path.exists(resolved_path): + return self._error_response( + f"Path not found: {resolved_path}", + execution_time=time.monotonic() - started_at, + path=resolved_path, + entries=[], + total=0, + ) + if not os.path.isdir(resolved_path): + return self._error_response( + f"Path is not a directory: {resolved_path}", + execution_time=time.monotonic() - started_at, + path=resolved_path, + entries=[], + total=0, + ) + + entries = self._collect_entries(resolved_path, max(depth, 1)) + return { + "success": True, + "entries": entries, + "total": len(entries), + "path": resolved_path, + "execution_time": time.monotonic() - started_at, + } + except Exception as exc: + logger.error("[SandboxHandler] Device list_files failed: %s", exc) + return self._error_response( + str(exc), + execution_time=time.monotonic() - started_at, + path=path, + entries=[], + total=0, + ) + + def _write_file_sync( + self, + file_path: str, + content: str, + file_format: str, + create_dirs: bool, + ) -> Dict[str, Any]: + """Write a file synchronously.""" + started_at = time.monotonic() + try: + resolved_path = self._normalize_path(file_path) + parent_dir = os.path.dirname(resolved_path) + if create_dirs and parent_dir: + Path(parent_dir).mkdir(parents=True, exist_ok=True) + + if file_format == "bytes": + content_bytes = base64.b64decode(content) + mode = "wb" + else: + content_bytes = content.encode("utf-8") + mode = "w" + + if len(content_bytes) > self.MAX_WRITE_FILE_SIZE: + return self._error_response( + f"Content too large: {len(content_bytes)} bytes", + execution_time=time.monotonic() - started_at, + path=resolved_path, + 
size=len(content_bytes), + ) + + with open( + resolved_path, + mode, + encoding=None if mode == "wb" else "utf-8", + ) as file_obj: + if mode == "wb": + file_obj.write(content_bytes) + else: + file_obj.write(content) + + file_size = os.path.getsize(resolved_path) + return { + "success": True, + "path": resolved_path, + "size": file_size, + "format": file_format, + "modified_time": self._iso_mtime(resolved_path), + "execution_time": time.monotonic() - started_at, + } + except Exception as exc: + logger.error("[SandboxHandler] Device write_file failed: %s", exc) + return self._error_response( + str(exc), + execution_time=time.monotonic() - started_at, + path=file_path, + size=0, + ) + + def _download_attachment_sync( + self, + attachment_url: str, + save_path: str, + auth_token: str, + api_base_url: str, + timeout_seconds: int, + ) -> Dict[str, Any]: + """Download an attachment synchronously.""" + started_at = time.monotonic() + resolved_path = self._normalize_path(save_path) + try: + Path(os.path.dirname(resolved_path)).mkdir(parents=True, exist_ok=True) + download_url = ( + attachment_url + if attachment_url.startswith(("http://", "https://")) + else f"{api_base_url}{attachment_url if attachment_url.startswith('/') else '/' + attachment_url}" + ) + response = requests.get( + download_url, + headers={"Authorization": f"Bearer {auth_token}"}, + timeout=timeout_seconds, + stream=True, + ) + response.raise_for_status() + with open(resolved_path, "wb") as file_obj: + for chunk in response.iter_content(chunk_size=1024 * 1024): + if chunk: + file_obj.write(chunk) + + file_size = os.path.getsize(resolved_path) + return { + "success": True, + "file_path": resolved_path, + "file_size": file_size, + "message": "File downloaded successfully", + "execution_time": time.monotonic() - started_at, + } + except Exception as exc: + logger.error("[SandboxHandler] Device download_attachment failed: %s", exc) + return self._error_response( + f"Failed to download file: {exc}", + execution_time=time.monotonic() - started_at, + file_path=resolved_path, + file_size=0, + ) + + def _upload_attachment_sync( + self, + file_path: str, + auth_token: str, + api_base_url: str, + overwrite_attachment_id: Optional[int], + timeout_seconds: int, + ) -> Dict[str, Any]: + """Upload an attachment synchronously.""" + started_at = time.monotonic() + resolved_path = self._normalize_path(file_path) + try: + if not os.path.exists(resolved_path): + return self._error_response( + f"File not found: {resolved_path}", + execution_time=time.monotonic() - started_at, + attachment_id=None, + filename=os.path.basename(resolved_path), + file_size=0, + download_url="", + ) + if not os.path.isfile(resolved_path): + return self._error_response( + f"Path is not a file: {resolved_path}", + execution_time=time.monotonic() - started_at, + attachment_id=None, + filename=os.path.basename(resolved_path), + file_size=0, + download_url="", + ) + + file_size = os.path.getsize(resolved_path) + if file_size > self.MAX_UPLOAD_FILE_SIZE: + return self._error_response( + f"File too large: {file_size} bytes", + execution_time=time.monotonic() - started_at, + attachment_id=None, + filename=os.path.basename(resolved_path), + file_size=file_size, + download_url="", + ) + + upload_url = f"{api_base_url}/api/attachments/upload" + if overwrite_attachment_id is not None: + upload_url = ( + f"{upload_url}?overwrite_attachment_id={overwrite_attachment_id}" + ) + + with open(resolved_path, "rb") as file_obj: + response = requests.post( + upload_url, + 
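+                    # requests builds the multipart/form-data body (and its
+                    # boundary) from the files= argument supplied below.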
headers={"Authorization": f"Bearer {auth_token}"}, + files={"file": (os.path.basename(resolved_path), file_obj)}, + timeout=timeout_seconds, + ) + response.raise_for_status() + + payload = response.json() + if "detail" in payload: + detail = payload["detail"] + detail_message = ( + detail.get("message") if isinstance(detail, dict) else str(detail) + ) + return self._error_response( + f"Upload API error: {detail_message}", + execution_time=time.monotonic() - started_at, + attachment_id=None, + filename=os.path.basename(resolved_path), + file_size=file_size, + download_url="", + ) + + attachment_id = payload.get("id") + return { + "success": True, + "attachment_id": attachment_id, + "filename": payload.get("filename", os.path.basename(resolved_path)), + "file_size": payload.get("file_size", file_size), + "mime_type": payload.get( + "mime_type", + mimetypes.guess_type(resolved_path)[0] + or "application/octet-stream", + ), + "download_url": f"/api/attachments/{attachment_id}/download", + "message": "File uploaded successfully", + "execution_time": time.monotonic() - started_at, + } + except Exception as exc: + logger.error("[SandboxHandler] Device upload_attachment failed: %s", exc) + return self._error_response( + f"Failed to upload file: {exc}", + execution_time=time.monotonic() - started_at, + attachment_id=None, + filename=os.path.basename(resolved_path), + file_size=0, + download_url="", + ) + + def _normalize_path(self, path: str) -> str: + """Map sandbox-style paths onto the local device home directory.""" + home_dir = os.path.expanduser("~") + normalized = os.path.expanduser(path) + if not os.path.isabs(normalized): + return os.path.join(home_dir, normalized) + if normalized == "/home/user": + return home_dir + if normalized.startswith("/home/user/"): + suffix = normalized[len("/home/user/") :] + return os.path.join(home_dir, suffix) + return normalized + + def _collect_entries(self, root_path: str, depth: int) -> list[Dict[str, Any]]: + """Collect directory entries recursively up to the requested depth.""" + entries: list[Dict[str, Any]] = [] + + def walk(current_path: str, remaining_depth: int) -> None: + with os.scandir(current_path) as iterator: + for entry in iterator: + entry_path = entry.path + stat_result = entry.stat(follow_symlinks=False) + entries.append( + { + "name": entry.name, + "path": entry_path, + "type": self._entry_type(entry), + "size": stat_result.st_size, + "permissions": oct(stat_result.st_mode & 0o777), + "owner": self._resolve_owner(stat_result.st_uid), + "group": self._resolve_group(stat_result.st_gid), + "modified_time": self._iso_mtime(entry_path), + **( + {"symlink_target": os.readlink(entry_path)} + if entry.is_symlink() + else {} + ), + } + ) + if remaining_depth > 1 and entry.is_dir(follow_symlinks=False): + walk(entry_path, remaining_depth - 1) + + walk(root_path, depth) + return entries + + def _entry_type(self, entry: os.DirEntry[str]) -> str: + """Return the normalized entry type.""" + if entry.is_symlink(): + return "symlink" + if entry.is_dir(follow_symlinks=False): + return "directory" + return "file" + + def _resolve_owner(self, uid: int) -> str: + """Resolve an owner name from uid.""" + try: + return pwd.getpwuid(uid).pw_name + except KeyError: + return str(uid) + + def _resolve_group(self, gid: int) -> str: + """Resolve a group name from gid.""" + try: + return grp.getgrgid(gid).gr_name + except KeyError: + return str(gid) + + def _iso_mtime(self, path: str) -> str: + """Return an ISO timestamp for a file's modification time.""" + return 
time.strftime( + "%Y-%m-%dT%H:%M:%S", time.localtime(os.path.getmtime(path)) + ) + + def _error_response(self, message: str, **kwargs: Any) -> Dict[str, Any]: + """Build a consistent sandbox handler error response.""" + return { + "success": False, + "error": message, + **kwargs, + } diff --git a/executor/modes/local/runner.py b/executor/modes/local/runner.py index 57f5e1d64..d50931021 100644 --- a/executor/modes/local/runner.py +++ b/executor/modes/local/runner.py @@ -26,8 +26,8 @@ from executor.config import config from executor.config.device_config import DeviceConfig -from executor.modes.local.events import ChatEvents, TaskEvents -from executor.modes.local.handlers import TaskHandler, UpgradeHandler +from executor.modes.local.events import ChatEvents, SandboxEvents, TaskEvents +from executor.modes.local.handlers import SandboxHandler, TaskHandler, UpgradeHandler from executor.modes.local.heartbeat import LocalHeartbeatService from executor.modes.local.websocket_client import WebSocketClient from executor.services.updater.process_manager import ProcessManager @@ -78,6 +78,7 @@ def __init__(self, device_config: Optional[DeviceConfig] = None): # Event handlers self.task_handler = TaskHandler(self) self.upgrade_handler = UpgradeHandler(self) + self.sandbox_handler = SandboxHandler(self) # Task queue for execution self.task_queue: asyncio.Queue = asyncio.Queue() @@ -233,6 +234,24 @@ def _register_handlers(self) -> None: self.websocket_client.on( "device:upgrade", self.upgrade_handler.handle_upgrade_command ) + self.websocket_client.on(SandboxEvents.EXEC, self.sandbox_handler.handle_exec) + self.websocket_client.on( + SandboxEvents.READ_FILE, self.sandbox_handler.handle_read_file + ) + self.websocket_client.on( + SandboxEvents.LIST_FILES, self.sandbox_handler.handle_list_files + ) + self.websocket_client.on( + SandboxEvents.WRITE_FILE, self.sandbox_handler.handle_write_file + ) + self.websocket_client.on( + SandboxEvents.DOWNLOAD_ATTACHMENT, + self.sandbox_handler.handle_download_attachment, + ) + self.websocket_client.on( + SandboxEvents.UPLOAD_ATTACHMENT, + self.sandbox_handler.handle_upload_attachment, + ) logger.info("WebSocket event handlers registered") diff --git a/executor/modes/local/websocket_client.py b/executor/modes/local/websocket_client.py index 1a146d5bf..ecb7857b3 100644 --- a/executor/modes/local/websocket_client.py +++ b/executor/modes/local/websocket_client.py @@ -111,6 +111,7 @@ def __init__( self.device_name = device_config.device_name or self._get_device_name() self.device_type = device_config.device_type or "local" self.bind_shell = device_config.bind_shell or "claudecode" + self.capabilities = device_config.capabilities or [] else: self.backend_url = backend_url or config.WEGENT_BACKEND_URL self.auth_token = self._normalize_token( @@ -120,6 +121,7 @@ def __init__( self.device_name = self._get_device_name() self.device_type = "local" self.bind_shell = "claudecode" + self.capabilities = [] # Reconnection settings reconnection_delay = reconnection_delay or config.LOCAL_RECONNECT_DELAY @@ -478,6 +480,7 @@ async def register_device(self, timeout: float = 10.0) -> bool: "name": self.device_name, "device_type": self.device_type, "bind_shell": self.bind_shell, + "capabilities": self.capabilities, "executor_version": get_version(), "client_ip": self._get_client_ip(), } diff --git a/executor/tests/test_local_sandbox_handler.py b/executor/tests/test_local_sandbox_handler.py new file mode 100644 index 000000000..230195792 --- /dev/null +++ 
b/executor/tests/test_local_sandbox_handler.py @@ -0,0 +1,158 @@ +# SPDX-FileCopyrightText: 2025 Weibo, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for lightweight device sandbox command handling.""" + +import base64 +import os +import subprocess +from unittest.mock import MagicMock, patch + +from executor.modes.local.handlers import SandboxHandler + + +class TestSandboxHandler: + """Tests for SandboxHandler.""" + + def test_execute_command_sync_returns_process_output(self): + """Successful subprocess output should be returned unchanged.""" + handler = SandboxHandler(runner=MagicMock()) + + completed = subprocess.CompletedProcess( + args=["echo", "hello"], + returncode=0, + stdout="hello\n", + stderr="", + ) + + with patch( + "executor.modes.local.handlers.subprocess.run", return_value=completed + ): + result = handler._execute_command_sync( + command="echo hello", + working_dir="/tmp", + timeout_seconds=5, + ) + + assert result["success"] is True + assert result["stdout"] == "hello\n" + assert result["stderr"] == "" + assert result["exit_code"] == 0 + + def test_execute_command_sync_returns_timeout_error(self): + """Timeouts should surface as structured command failures.""" + handler = SandboxHandler(runner=MagicMock()) + + with patch( + "executor.modes.local.handlers.subprocess.run", + side_effect=subprocess.TimeoutExpired( + cmd="sleep 10", + timeout=1, + output="partial", + stderr="still running", + ), + ): + result = handler._execute_command_sync( + command="sleep 10", + working_dir="/tmp", + timeout_seconds=1, + ) + + assert result["success"] is False + assert result["stdout"] == "partial" + assert "timed out" in result["stderr"] + assert result["exit_code"] == -1 + + def test_read_file_sync_reads_text_content(self, tmp_path): + """Text files should be read back with metadata.""" + handler = SandboxHandler(runner=MagicMock()) + target = tmp_path / "notes.txt" + target.write_text("hello device", encoding="utf-8") + + result = handler._read_file_sync(str(target), "text") + + assert result["success"] is True + assert result["content"] == "hello device" + assert result["size"] == len("hello device") + assert result["path"] == str(target) + assert result["format"] == "text" + + def test_list_files_sync_returns_recursive_entries(self, tmp_path): + """Directory listings should include nested entries up to depth.""" + handler = SandboxHandler(runner=MagicMock()) + nested_dir = tmp_path / "reports" + nested_dir.mkdir() + file_path = nested_dir / "weekly.txt" + file_path.write_text("done", encoding="utf-8") + + result = handler._list_files_sync(str(tmp_path), 2) + + assert result["success"] is True + assert result["path"] == str(tmp_path) + assert result["total"] == 2 + paths = {entry["path"] for entry in result["entries"]} + assert str(nested_dir) in paths + assert str(file_path) in paths + + def test_write_file_sync_writes_text_content(self, tmp_path): + """Text writes should create parent directories and persist content.""" + handler = SandboxHandler(runner=MagicMock()) + target = tmp_path / "mail" / "summary.txt" + + result = handler._write_file_sync(str(target), "device output", "text", True) + + assert result["success"] is True + assert result["path"] == str(target) + assert target.read_text(encoding="utf-8") == "device output" + assert result["size"] == len("device output".encode("utf-8")) + + def test_write_file_sync_writes_binary_content(self, tmp_path): + """Binary writes should decode base64 payloads before persisting.""" + handler = SandboxHandler(runner=MagicMock()) 
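+        # The "bytes" format carries base64 text over the wire; the handler
+        # decodes it before writing, so the file holds the raw bytes.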
+        target = tmp_path / "attachments" / "report.bin"
+        payload = base64.b64encode(b"\x00\x01\x02").decode("ascii")
+
+        result = handler._write_file_sync(str(target), payload, "bytes", True)
+
+        assert result["success"] is True
+        assert target.read_bytes() == b"\x00\x01\x02"
+        assert result["size"] == 3
+
+    def test_normalize_path_maps_home_user_to_local_home(self, tmp_path, monkeypatch):
+        """Sandbox-style /home/user paths should resolve under the device home."""
+        handler = SandboxHandler(runner=MagicMock())
+        # Stub expanduser faithfully: only "~" maps to the fake home, while
+        # other paths pass through unchanged, as with the real implementation.
+        monkeypatch.setattr(
+            os.path,
+            "expanduser",
+            lambda path: str(tmp_path) if path == "~" else path,
+        )
+
+        assert handler._normalize_path("/home/user") == str(tmp_path)
+        assert handler._normalize_path("/home/user/mail/config.toml") == str(
+            tmp_path / "mail" / "config.toml"
+        )
+        assert handler._normalize_path("relative.txt") == str(tmp_path / "relative.txt")
+
+    def test_execute_command_sync_normalizes_sandbox_working_dir(self, tmp_path, monkeypatch):
+        """Device exec should map /home/user to the local home before subprocess starts."""
+        handler = SandboxHandler(runner=MagicMock())
+        monkeypatch.setattr(
+            os.path,
+            "expanduser",
+            lambda path: str(tmp_path) if path == "~" else path,
+        )
+
+        completed = subprocess.CompletedProcess(
+            args=["pwd"],
+            returncode=0,
+            stdout=f"{tmp_path}\n",
+            stderr="",
+        )
+
+        with patch(
+            "executor.modes.local.handlers.subprocess.run",
+            return_value=completed,
+        ) as mock_run:
+            result = handler._execute_command_sync(
+                command="pwd",
+                working_dir="/home/user",
+                timeout_seconds=5,
+            )
+
+        assert result["success"] is True
+        assert result["stdout"] == f"{tmp_path}\n"
+        mock_run.assert_called_once()
+        assert mock_run.call_args.kwargs["cwd"] == str(tmp_path)
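
For reference, a minimal usage sketch of the sticky device routing added in
chat_shell/chat_shell/tools/sandbox/_base.py. It assumes a reachable backend
(resolved from BACKEND_API_URL) and an online device; get_instance and
remove_instance mirror their use in chat_shell/tests/test_sandbox_manager.py,
and the task/user IDs are placeholders. Treat it as an illustration of the
intended call flow, not part of the patch.

import asyncio

from chat_shell.tools.sandbox._base import SandboxManager


async def main() -> None:
    # Per-task manager, as in chat_shell/tests/test_sandbox_manager.py.
    manager = SandboxManager.get_instance(task_id=794, user_id=2, user_name="sifang")
    try:
        # "himalaya" matches DEFAULT_DEVICE_ROUTED_COMMANDS, so this command
        # prefers the device backend even before any binding exists.
        result = await manager.execute_command_via_device(command="himalaya --help")
        print(result["exit_code"], result["stdout"])

        # The device_id in the response pinned the task, so unrelated
        # follow-up calls now reuse the same device instead of E2B.
        if manager.is_device_backend_bound():
            listing = await manager.list_files_via_device(path="/home/user", depth=1)
            print(listing.get("total"))
    finally:
        SandboxManager.remove_instance(794)


asyncio.run(main())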