Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 109 additions & 42 deletions crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@

import importlib.util
import os
import shutil
import subprocess
from types import ModuleType
from typing import Any, Dict, List, Optional, Type

from crewai.tools import BaseTool
from docker import DockerClient
from docker import from_env as docker_from_env
from docker.errors import ImageNotFound, NotFound
from docker.models.containers import Container
from typing import Any, ClassVar

from crewai.tools import BaseTool # type: ignore[import-untyped]
from docker import DockerClient # type: ignore[import-untyped]
from docker import from_env as docker_from_env # type: ignore[import-untyped]
from docker.errors import ImageNotFound, NotFound # type: ignore[import-untyped]
from docker.models.containers import Container # type: ignore[import-untyped]
from packaging.requirements import InvalidRequirement, Requirement
from pydantic import BaseModel, Field

from crewai_tools.printer import Printer
Expand All @@ -32,7 +35,7 @@ class CodeInterpreterSchema(BaseModel):
description="Python3 code used to be interpreted in the Docker container. ALWAYS PRINT the final result and the output of the code",
)

libraries_used: List[str] = Field(
libraries_used: list[str] = Field(
...,
description="List of libraries used in the code with proper installing names separated by commas. Example: numpy,pandas,beautifulsoup4",
)
Expand All @@ -46,7 +49,7 @@ class SandboxPython:
environment where harmful operations are blocked.
"""

BLOCKED_MODULES = {
BLOCKED_MODULES: ClassVar[set[str]] = {
"os",
"sys",
"subprocess",
Expand All @@ -58,7 +61,7 @@ class SandboxPython:
"builtins",
}

UNSAFE_BUILTINS = {
UNSAFE_BUILTINS: ClassVar[set[str]] = {
"exec",
"eval",
"open",
Expand All @@ -74,9 +77,9 @@ class SandboxPython:
@staticmethod
def restricted_import(
name: str,
custom_globals: Optional[Dict[str, Any]] = None,
custom_locals: Optional[Dict[str, Any]] = None,
fromlist: Optional[List[str]] = None,
custom_globals: dict[str, Any] | None = None,
custom_locals: dict[str, Any] | None = None,
fromlist: list[str] | None = None,
level: int = 0,
) -> ModuleType:
"""A restricted import function that blocks importing of unsafe modules.
Expand All @@ -99,7 +102,7 @@ def restricted_import(
return __import__(name, custom_globals, custom_locals, fromlist or (), level)

@staticmethod
def safe_builtins() -> Dict[str, Any]:
def safe_builtins() -> dict[str, Any]:
"""Creates a dictionary of built-in functions with unsafe ones removed.

Returns:
Expand All @@ -116,14 +119,14 @@ def safe_builtins() -> Dict[str, Any]:
return safe_builtins

@staticmethod
def exec(code: str, locals: Dict[str, Any]) -> None:
def exec(code: str, locals: dict[str, Any]) -> None:
"""Executes Python code in a restricted environment.

Args:
code: The Python code to execute as a string.
locals: A dictionary that will be used for local variable storage.
"""
exec(code, {"__builtins__": SandboxPython.safe_builtins()}, locals)
exec(code, {"__builtins__": SandboxPython.safe_builtins()}, locals) # noqa: S102


class CodeInterpreterTool(BaseTool):
Expand All @@ -136,11 +139,11 @@ class CodeInterpreterTool(BaseTool):

name: str = "Code Interpreter"
description: str = "Interprets Python3 code strings with a final print statement."
args_schema: Type[BaseModel] = CodeInterpreterSchema
args_schema: type[BaseModel] = CodeInterpreterSchema
default_image_tag: str = "code-interpreter:latest"
code: Optional[str] = None
user_dockerfile_path: Optional[str] = None
user_docker_base_url: Optional[str] = None
code: str | None = None
user_dockerfile_path: str | None = None
user_docker_base_url: str | None = None
unsafe_mode: bool = False

@staticmethod
Expand All @@ -151,6 +154,8 @@ def _get_installed_package_path() -> str:
The directory path where the package is installed.
"""
spec = importlib.util.find_spec("crewai_tools")
if spec is None or spec.origin is None:
raise FileNotFoundError("Unable to locate crewai_tools package installation path.") from None
return os.path.dirname(spec.origin)

def _verify_docker_image(self) -> None:
Expand Down Expand Up @@ -183,7 +188,7 @@ def _verify_docker_image(self) -> None:
if not os.path.exists(dockerfile_path):
raise FileNotFoundError(
f"Dockerfile not found in {dockerfile_path}"
)
) from None

client.images.build(
path=dockerfile_path,
Expand All @@ -203,12 +208,20 @@ def _run(self, **kwargs) -> str:
code = kwargs.get("code", self.code)
libraries_used = kwargs.get("libraries_used", [])

if self.unsafe_mode:
return self.run_code_unsafe(code, libraries_used)
execution_code = code if isinstance(code, str) else (self.code or "")

if isinstance(libraries_used, list):
libraries = [str(library) for library in libraries_used]
elif libraries_used:
libraries = [str(libraries_used)]
else:
return self.run_code_safety(code, libraries_used)
libraries = []

def _install_libraries(self, container: Container, libraries: List[str]) -> None:
if self.unsafe_mode:
return self.run_code_unsafe(execution_code, libraries)
return self.run_code_safety(execution_code, libraries)

def _install_libraries(self, container: Container, libraries: list[str]) -> None:
"""Installs required Python libraries in the Docker container.

Args:
Expand Down Expand Up @@ -257,11 +270,14 @@ def _check_docker_available(self) -> bool:
Returns:
True if Docker is available and running, False otherwise.
"""
import subprocess
docker_executable = shutil.which("docker")
if not docker_executable:
Printer.print("Docker is not installed", color="bold_purple")
return False

try:
subprocess.run(
["docker", "info"],
subprocess.run( # noqa: S603
[docker_executable, "info"],
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
Expand All @@ -278,7 +294,7 @@ def _check_docker_available(self) -> bool:
Printer.print("Docker is not installed", color="bold_purple")
return False

def run_code_safety(self, code: str, libraries_used: List[str]) -> str:
def run_code_safety(self, code: str, libraries_used: list[str]) -> str:
"""Runs code in the safest available environment.

Attempts to run code in Docker if available, falls back to a restricted
Expand All @@ -293,10 +309,9 @@ def run_code_safety(self, code: str, libraries_used: List[str]) -> str:
"""
if self._check_docker_available():
return self.run_code_in_docker(code, libraries_used)
else:
return self.run_code_in_restricted_sandbox(code)
return self.run_code_in_restricted_sandbox(code)

def run_code_in_docker(self, code: str, libraries_used: List[str]) -> str:
def run_code_in_docker(self, code: str, libraries_used: list[str]) -> str:
"""Runs Python code in a Docker container for safe isolation.

Creates a Docker container, installs the required libraries, executes the code,
Expand Down Expand Up @@ -337,14 +352,14 @@ def run_code_in_restricted_sandbox(self, code: str) -> str:
or an error message if execution failed.
"""
Printer.print("Running code in restricted sandbox", color="yellow")
exec_locals = {}
exec_locals: dict[str, Any] = {}
try:
SandboxPython.exec(code=code, locals=exec_locals)
return exec_locals.get("result", "No result variable found.")
except Exception as e:
return f"An error occurred: {str(e)}"
return f"An error occurred: {e!s}"

def run_code_unsafe(self, code: str, libraries_used: List[str]) -> str:
def run_code_unsafe(self, code: str, libraries_used: list[str]) -> str:
"""Runs code directly on the host machine without any safety restrictions.

WARNING: This mode is unsafe and should only be used in trusted environments
Expand All @@ -360,14 +375,66 @@ def run_code_unsafe(self, code: str, libraries_used: List[str]) -> str:
"""

Printer.print("WARNING: Running code in unsafe mode", color="bold_magenta")
# Install libraries on the host machine
for library in libraries_used:
os.system(f"pip install {library}")

# Execute the code
try:
exec_locals = {}
exec(code, {}, exec_locals)
self._install_libraries_on_host(libraries_used)
except (RuntimeError, ValueError) as error:
return f"An error occurred while installing libraries: {error!s}"

try:
exec_locals: dict[str, Any] = {}
exec(code, {}, exec_locals) # noqa: S102
return exec_locals.get("result", "No result variable found.")
except Exception as e:
return f"An error occurred: {str(e)}"
return f"An error occurred while executing code in unsafe mode: {type(e).__name__}: {e!s}"

@staticmethod
def _sanitize_library_requirement(library: str) -> str:
"""Validates and normalizes a pip library specification."""

if not isinstance(library, str):
raise ValueError("Library specification must be a string.")

normalized = library.strip()
if not normalized:
raise ValueError("Library specification cannot be empty.")

if normalized.startswith("-"):
raise ValueError("Library specification cannot start with '-'.")

try:
requirement = Requirement(normalized)
except InvalidRequirement as error:
raise ValueError(f"Invalid library specification '{library}': {error}") from error

if getattr(requirement, "url", None):
raise ValueError(f"URL-based requirements are not allowed: '{library}'")

return normalized

def _install_libraries_on_host(self, libraries: list[str]) -> None:
"""Safely installs libraries on the host machine using pip."""

if not libraries:
return

pip_executable = shutil.which("pip")
if pip_executable is None:
raise RuntimeError("Unable to locate 'pip' executable on host system.")

for library in libraries:
sanitized = self._sanitize_library_requirement(library)

try:
subprocess.run( # noqa: S603
[pip_executable, "install", sanitized],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
except subprocess.CalledProcessError as error:
error_output = error.stderr.strip() or error.stdout.strip()
raise RuntimeError(
f"Failed to install dependency '{sanitized}'. {error_output}"
) from error
55 changes: 54 additions & 1 deletion tests/tools/test_code_interpreter_tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import subprocess
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -161,7 +162,7 @@ def test_unsafe_mode_running_with_no_result_variable(


def test_unsafe_mode_running_unsafe_code(printer_mock, docker_unavailable_mock):
"""Test behavior when no result variable is set."""
"""Test unsafe code execution returns expected result."""
tool = CodeInterpreterTool(unsafe_mode=True)
code = """
import os
Expand All @@ -173,3 +174,55 @@ def test_unsafe_mode_running_unsafe_code(printer_mock, docker_unavailable_mock):
"WARNING: Running code in unsafe mode", color="bold_magenta"
)
assert 5.0 == result


def test_sanitize_library_requirement_valid():
sanitized = CodeInterpreterTool._sanitize_library_requirement("requests>=2.0")
assert sanitized == "requests>=2.0"


@pytest.mark.parametrize(
"library",
["", " ", "--help", "-r requirements.txt", "git+https://example.com/repo.git"],
)
def test_sanitize_library_requirement_invalid(library):
with pytest.raises(ValueError):
CodeInterpreterTool._sanitize_library_requirement(library)


@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.subprocess.run")
def test_install_libraries_on_host_invokes_pip(mock_run):
mock_run.return_value = subprocess.CompletedProcess(
args=["pip", "install", "requests"], returncode=0, stdout="", stderr=""
)
with patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.shutil.which") as which_mock:
which_mock.return_value = "/usr/bin/pip"
tool = CodeInterpreterTool()
tool._install_libraries_on_host(["requests"])

which_mock.assert_called_once_with("pip")
mock_run.assert_called_once_with(
["/usr/bin/pip", "install", "requests"],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)


@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.subprocess.run")
def test_install_libraries_on_host_invalid_requirement(mock_run):
mock_run.side_effect = subprocess.CalledProcessError(
1,
["pip", "install", "requests"],
stderr="error",
)
with patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.shutil.which") as which_mock:
which_mock.return_value = "/usr/bin/pip"
tool = CodeInterpreterTool()

with pytest.raises(RuntimeError) as exc:
tool._install_libraries_on_host(["requests"])

which_mock.assert_called_once_with("pip")
assert "Failed to install dependency" in str(exc.value)
Loading