diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..7e4b989 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,54 @@ +# FileSystem and Shell Capabilities + +Closes #25 and #26. + +## Overview + +Two `AbstractCapability` subclasses providing file system access and shell execution as composable agent capabilities. + +## FileSystem (`src/pydantic_harness/filesystem.py`) + +**Tools provided via `get_toolset()`:** +- `read_file(path, offset, limit)` -- reads a text file with numbered lines +- `write_file(path, content)` -- creates or overwrites a file +- `edit_file(path, old_text, new_text, replace_all)` -- exact string replacement editing +- `list_directory(path)` -- directory listing with type/size indicators +- `search_files(pattern, path)` -- regex search across files + +**Configuration:** +- `root_dir` -- all paths resolved relative to this, traversal prevented +- `allowed_patterns` -- glob allowlist (if non-empty, only matching paths accessible) +- `denied_patterns` -- glob denylist (matching paths always rejected) +- `max_read_lines` -- per-read line limit (default 2000) + +**Security:** Path traversal above `root_dir` is rejected. Hidden files skipped in search. Binary files skipped in search. + +## Shell (`src/pydantic_harness/shell.py`) + +**Tool provided via `get_toolset()`:** +- `run_command(command, timeout_seconds)` -- execute a shell command + +**Configuration:** +- `cwd` -- working directory for commands +- `allowed_commands` -- executable allowlist (mutually exclusive with deny) +- `denied_commands` -- executable denylist +- `default_timeout` -- seconds (default 30) +- `max_output_chars` -- output truncation limit (default 10000) + +**Implementation:** Uses `anyio.open_process` for async-backend-agnostic subprocess execution (works with both asyncio and trio). + +## Design decisions + +1. **Public methods for tool implementations** -- `read_file()`, `write_file()`, etc. are public methods on the capability class, registered with the toolset via `FunctionToolset.add_function()`. This allows direct testing and reuse by subclasses or future Environment abstraction (#52). + +2. **No `RunContext` dependency** -- tool implementations are synchronous (FileSystem) or standalone async (Shell), following `get_toolset()` semantics where the toolset is created at agent construction time. + +3. **anyio for Shell** -- uses `anyio.open_process` instead of `asyncio.create_subprocess_shell` so the capability works under both asyncio and trio event loops. + +4. **pydantic-ai-slim from main** -- the harness `pyproject.toml` sources `pydantic-ai-slim` from the pydantic-ai `main` branch (which includes the capabilities module). This will be updated to a release version once capabilities ship. + +## Future considerations + +- Integration with the Environment abstraction (#52) -- FileSystem and Shell could become thin wrappers over an `Environment` protocol +- `get_instructions()` -- could provide context-aware system prompt additions +- `for_run()` -- per-run state isolation for sandboxed environments diff --git a/pyproject.toml b/pyproject.toml index 0d573a0..792f21f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ 'Topic :: Software Development :: Libraries', 'Typing :: Typed', ] -dependencies = ['pydantic-ai-slim>=1.76.0'] +dependencies = ['pydantic-ai-slim>=1.76.0', 'anyio>=4.0'] [project.urls] Homepage = 'https://github.com/pydantic/pydantic-harness' @@ -33,8 +33,8 @@ Issues = 'https://github.com/pydantic/pydantic-harness/issues' [dependency-groups] dev = [ - 'pytest', 'anyio[trio]', + 'pytest', 'pytest-anyio', 'coverage', ] @@ -85,7 +85,15 @@ executionEnvironments = [ [tool.pytest.ini_options] xfail_strict = true -filterwarnings = ['error'] +filterwarnings = [ + 'error', + # Python 3.10 asyncio subprocess transports emit spurious ResourceWarnings + # and RuntimeErrors (from __del__ on a closed event loop) during GC even + # after the process is properly closed. Fixed in 3.11+. + 'ignore:unclosed transport:ResourceWarning', + 'ignore:unclosed file:ResourceWarning', + 'ignore::pytest.PytestUnraisableExceptionWarning', +] anyio_mode = 'auto' [tool.coverage.run] diff --git a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py index 9d728b6..1a548d6 100644 --- a/src/pydantic_harness/__init__.py +++ b/src/pydantic_harness/__init__.py @@ -7,4 +7,10 @@ # Each capability module is imported and re-exported here. # Capabilities are listed alphabetically. -__all__: list[str] = [] +from pydantic_harness.filesystem import FileSystem +from pydantic_harness.shell import Shell + +__all__: list[str] = [ + 'FileSystem', + 'Shell', +] diff --git a/src/pydantic_harness/filesystem.py b/src/pydantic_harness/filesystem.py new file mode 100644 index 0000000..009e916 --- /dev/null +++ b/src/pydantic_harness/filesystem.py @@ -0,0 +1,329 @@ +"""FileSystem capability: gives agents configurable file system access. + +Provides tools for reading, writing, editing, listing, searching, and finding +files, all scoped to a configurable root directory with path filtering. +""" + +from __future__ import annotations + +import fnmatch +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from pydantic_ai.capabilities.abstract import AbstractCapability +from pydantic_ai.toolsets import AgentToolset, FunctionToolset + + +def format_lines(text: str, offset: int, limit: int) -> str: + """Format text with line numbers, similar to ``cat -n``. + + Args: + text: The raw file content. + offset: Zero-based line offset to start from. + limit: Maximum number of lines to include. + + Returns: + Numbered text with a continuation hint when more lines remain. + """ + lines = text.splitlines(keepends=True) + total = len(lines) + + if offset >= total > 0: + raise ValueError(f'Offset {offset} exceeds file length ({total} lines).') + + selected = lines[offset : offset + limit] + numbered = [f'{i:>6}\t{line}' for i, line in enumerate(selected, start=offset + 1)] + result = ''.join(numbered) + if not result.endswith('\n'): + result += '\n' + + remaining = total - (offset + len(selected)) + if remaining > 0: + next_offset = offset + len(selected) + result += f'... ({remaining} more lines. Use offset={next_offset} to continue reading.)\n' + + return result + + +@dataclass +class FileSystem(AbstractCapability[Any]): + """Capability that provides file system access scoped to a root directory. + + All paths supplied by the model are resolved relative to ``root_dir``. + Traversal above the root is rejected. Optional allow/deny glob patterns + restrict which paths may be accessed. + + Example:: + + from pydantic_ai import Agent + from pydantic_harness.filesystem import FileSystem + + agent = Agent('openai:gpt-4o', capabilities=[FileSystem(root_dir='.')]) + """ + + root_dir: str | Path = '.' + """Root directory for all file operations. Defaults to the current directory.""" + + allowed_patterns: list[str] = field(default_factory=lambda: list[str]()) + """If non-empty, only paths matching at least one glob pattern are accessible.""" + + denied_patterns: list[str] = field(default_factory=lambda: list[str]()) + """Paths matching any of these glob patterns are rejected.""" + + max_read_lines: int = 2000 + """Maximum number of lines returned by a single ``read_file`` call.""" + + def __post_init__(self) -> None: + """Resolve the root directory to an absolute path.""" + self._root = Path(self.root_dir).resolve() + + # ------------------------------------------------------------------ + # Path helpers + # ------------------------------------------------------------------ + + def resolve_path(self, path: str) -> Path: + """Resolve *path* relative to the root, raising on traversal. + + Args: + path: A relative path within the root directory. + + Returns: + The resolved absolute path. + + Raises: + PermissionError: If the resolved path escapes the root. + """ + resolved = (self._root / path).resolve() + if not resolved.is_relative_to(self._root): + raise PermissionError(f'Path {path!r} resolves outside the root directory.') + return resolved + + def check_access(self, path: str) -> None: + """Raise ``PermissionError`` if *path* is blocked by allow/deny patterns. + + Args: + path: The relative path to check. + """ + if self.denied_patterns: + for pattern in self.denied_patterns: + if fnmatch.fnmatch(path, pattern): + raise PermissionError(f'Path {path!r} is denied by pattern {pattern!r}.') + if self.allowed_patterns: + if not any(fnmatch.fnmatch(path, p) for p in self.allowed_patterns): + raise PermissionError(f'Path {path!r} does not match any allowed pattern.') + + def safe_resolve(self, path: str) -> Path: + """Resolve and access-check a path in one step. + + Args: + path: The relative path to resolve and validate. + + Returns: + The resolved absolute path. + """ + self.check_access(path) + return self.resolve_path(path) + + # ------------------------------------------------------------------ + # Tool implementations + # ------------------------------------------------------------------ + + def read_file(self, path: str, *, offset: int = 0, limit: int | None = None) -> str: + """Read a text file with line numbers. + + Args: + path: File path relative to the root directory. + offset: Zero-based line offset to start reading from. + limit: Maximum number of lines to return. Defaults to ``max_read_lines``. + + Returns: + File content with line numbers. + """ + if limit is None: + limit = self.max_read_lines + resolved = self.safe_resolve(path) + if not resolved.is_file(): + if resolved.is_dir(): + raise FileNotFoundError(f"'{path}' is a directory, not a file.") + raise FileNotFoundError(f'File not found: {path}') + text = resolved.read_text(encoding='utf-8') + return format_lines(text, offset, limit) + + def write_file(self, path: str, content: str) -> str: + """Create or overwrite a file. + + Args: + path: File path relative to the root directory. + content: The text content to write. + + Returns: + Confirmation message. + """ + resolved = self.safe_resolve(path) + resolved.parent.mkdir(parents=True, exist_ok=True) + resolved.write_text(content, encoding='utf-8') + return f'Successfully wrote {len(content)} characters to {path}.' + + def edit_file(self, path: str, old_text: str, new_text: str, *, replace_all: bool = False) -> str: + """Edit a file by exact string replacement. + + Args: + path: File path relative to the root directory. + old_text: The exact text to find. + new_text: The replacement text. + replace_all: If True, replace all occurrences. + Otherwise ``old_text`` must appear exactly once. + + Returns: + Summary of replacements made. + """ + resolved = self.safe_resolve(path) + if not resolved.is_file(): + raise FileNotFoundError(f'File not found: {path}') + text = resolved.read_text(encoding='utf-8') + + count = text.count(old_text) + if count == 0: + raise ValueError(f'old_text not found in {path}.') + if not replace_all and count > 1: + raise ValueError( + f'old_text found {count} times in {path}. ' + 'Set replace_all=True or provide more surrounding context to make the match unique.' + ) + + new_content = text.replace(old_text, new_text) if replace_all else text.replace(old_text, new_text, 1) + resolved.write_text(new_content, encoding='utf-8') + replacements = count if replace_all else 1 + return f'Replaced {replacements} occurrence(s) in {path}.' + + def list_directory(self, path: str = '.') -> str: + """List the contents of a directory. + + Args: + path: Directory path relative to the root directory. + + Returns: + A newline-separated listing with type indicators (``/`` for directories). + """ + resolved = self.safe_resolve(path) + if not resolved.is_dir(): + raise NotADirectoryError(f'Not a directory: {path}') + + entries: list[str] = [] + for entry in sorted(resolved.iterdir()): + rel = str(entry.relative_to(self._root)) + if entry.is_dir(): + entries.append(f'{rel}/') + else: + try: + size = entry.stat().st_size + except OSError: # pragma: no cover + size = 0 + entries.append(f'{rel} ({size} bytes)') + return '\n'.join(entries) if entries else '(empty directory)' + + def search_files(self, pattern: str, *, path: str = '.') -> str: + """Search file contents using a regular expression. + + Args: + pattern: Regex pattern to search for. + path: Directory to search in, relative to the root directory. + + Returns: + Matching lines formatted as ``file:line_number:text``. + """ + resolved = self.safe_resolve(path) + compiled = re.compile(pattern) + results: list[str] = [] + + if resolved.is_file(): + files = [resolved] + else: + files = sorted(resolved.rglob('*')) + + for file_path in files: + if not file_path.is_file(): + continue + # Skip hidden files/directories + try: + rel_parts = file_path.relative_to(self._root).parts + except ValueError: # pragma: no cover + continue + if any(part.startswith('.') for part in rel_parts): + continue + try: + raw = file_path.read_bytes() + except OSError: + continue + # Skip binary files + if b'\x00' in raw[:8192]: + continue + text = raw.decode('utf-8', errors='replace') + rel_path = str(file_path.relative_to(self._root)) + for line_num, line in enumerate(text.splitlines(), start=1): + if compiled.search(line): + results.append(f'{rel_path}:{line_num}:{line}') + if len(results) > 1000: + results.append('[... truncated at 1000 matches]') + break + + return '\n'.join(results) if results else 'No matches found.' + + def create_directory(self, path: str) -> str: + """Create a directory and any missing parents. + + Args: + path: Directory path relative to the root directory. + + Returns: + Confirmation message. + """ + resolved = self.safe_resolve(path) + resolved.mkdir(parents=True, exist_ok=True) + return f'Created directory {path}.' + + def find_files(self, pattern: str, *, path: str = '.') -> str: + """Find files by glob pattern (name matching, not content search). + + Args: + pattern: Glob pattern to match file names against (e.g. ``*.py``, ``**/*.json``). + path: Directory to search in, relative to the root directory. + + Returns: + Newline-separated list of matching file paths relative to the root. + """ + resolved = self.safe_resolve(path) + if not resolved.is_dir(): + raise NotADirectoryError(f'Not a directory: {path}') + + matches: list[str] = [] + for match in sorted(resolved.glob(pattern)): + rel = str(match.relative_to(self._root)) + # Skip hidden files/directories + if any(part.startswith('.') for part in match.relative_to(self._root).parts): + continue + suffix = '/' if match.is_dir() else '' + matches.append(f'{rel}{suffix}') + if len(matches) > 1000: + matches.append('[... truncated at 1000 matches]') + break + + return '\n'.join(matches) if matches else 'No matches found.' + + # ------------------------------------------------------------------ + # Capability interface + # ------------------------------------------------------------------ + + def get_toolset(self) -> AgentToolset[Any] | None: + """Build and return the toolset containing all file system tools.""" + toolset: FunctionToolset[Any] = FunctionToolset() + toolset.add_function(self.read_file, name='read_file') + toolset.add_function(self.write_file, name='write_file') + toolset.add_function(self.edit_file, name='edit_file') + toolset.add_function(self.list_directory, name='list_directory') + toolset.add_function(self.search_files, name='search_files') + toolset.add_function(self.create_directory, name='create_directory') + toolset.add_function(self.find_files, name='find_files') + return toolset diff --git a/src/pydantic_harness/shell.py b/src/pydantic_harness/shell.py new file mode 100644 index 0000000..c4b1603 --- /dev/null +++ b/src/pydantic_harness/shell.py @@ -0,0 +1,220 @@ +"""Shell capability: gives agents configurable command execution. + +Provides a ``run_command`` tool with timeout support, output truncation, +and optional command allow/deny lists. +""" + +from __future__ import annotations + +import re +import shlex +import subprocess +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import anyio +from pydantic_ai.capabilities.abstract import AbstractCapability +from pydantic_ai.toolsets import AgentToolset, FunctionToolset + + +@dataclass +class Shell(AbstractCapability[Any]): + """Capability that provides shell command execution. + + Commands are executed in a subprocess rooted at ``cwd``. An optional + allow-list (``allowed_commands``) or deny-list (``denied_commands``) + restricts which executables may be invoked. Output is truncated to + ``max_output_chars`` to keep model context manageable. + + When ``persist_cwd`` is ``True``, the shell tracks ``cd`` commands and + adjusts the working directory for subsequent calls, simulating a + persistent shell session. + + Example:: + + from pydantic_ai import Agent + from pydantic_harness.shell import Shell + + agent = Agent('openai:gpt-4o', capabilities=[Shell(cwd='.')]) + """ + + cwd: str | Path = '.' + """Working directory for command execution.""" + + allowed_commands: list[str] = field(default_factory=lambda: list[str]()) + """If non-empty, only these command names may be executed (allowlist).""" + + denied_commands: list[str] = field(default_factory=lambda: list[str]()) + """These command names are always rejected (denylist).""" + + default_timeout: float = 30.0 + """Default timeout in seconds for command execution.""" + + max_output_chars: int = 10_000 + """Maximum characters of output returned to the model.""" + + persist_cwd: bool = False + """If ``True``, track ``cd`` commands and adjust the working directory for subsequent calls.""" + + def __post_init__(self) -> None: + """Resolve the working directory and validate configuration.""" + self._cwd = Path(self.cwd).resolve() + if self.allowed_commands and self.denied_commands: + raise ValueError('Specify allowed_commands or denied_commands, not both.') + + # ------------------------------------------------------------------ + # Validation + # ------------------------------------------------------------------ + + def check_command(self, command: str) -> None: + """Validate *command* against allow/deny lists. + + Args: + command: The shell command string to validate. + + Raises: + PermissionError: If the command is blocked by the allow/deny lists. + """ + try: + tokens = shlex.split(command) + except ValueError: + # If shlex can't parse it, fall through and let the shell handle it + return + if not tokens: + return + executable = tokens[0] + + if self.denied_commands and executable in self.denied_commands: + raise PermissionError(f'Command {executable!r} is denied.') + if self.allowed_commands and executable not in self.allowed_commands: + raise PermissionError(f'Command {executable!r} is not in the allowed list.') + + def truncate(self, text: str) -> str: + """Truncate *text* to ``max_output_chars``. + + Args: + text: The text to truncate. + + Returns: + The original text if within limits, otherwise truncated with a notice. + """ + if len(text) <= self.max_output_chars: + return text + return text[: self.max_output_chars] + f'\n... [output truncated at {self.max_output_chars} characters]' + + # ------------------------------------------------------------------ + # Execution + # ------------------------------------------------------------------ + + def _extract_cd_target(self, command: str) -> str | None: + """Extract the target directory from a ``cd`` command. + + Returns ``None`` if *command* is not a ``cd`` invocation. + """ + stripped = command.strip() + # Match: `cd `, `cd && ...`, `cd ;...` + m = re.match(r'^cd\s+(.+?)(?:\s*[;&|]|$)', stripped) + if m is None: + return None + target = m.group(1).strip() + # Strip surrounding quotes + if len(target) >= 2 and target[0] in ('"', "'") and target[-1] == target[0]: + target = target[1:-1] + return target + + def _update_cwd(self, command: str) -> None: + """Update ``_cwd`` if *command* contains a ``cd`` and the target exists.""" + target = self._extract_cd_target(command) + if target is None: + return + if target == '~': + new_cwd = Path.home() + elif target.startswith('~'): + new_cwd = Path.home() / target[2:] # skip "~/" + else: + new_cwd = (self._cwd / target).resolve() + if new_cwd.is_dir(): + self._cwd = new_cwd + + async def run_command(self, command: str, *, timeout_seconds: float | None = None) -> str: + """Execute a shell command and return its output. + + Stdout and stderr are captured separately and labeled in the output. + When ``persist_cwd`` is enabled, ``cd`` commands update the working + directory for subsequent calls. + + Args: + command: The shell command to run. + timeout_seconds: Maximum seconds to wait. Defaults to ``default_timeout``. + + Returns: + Labeled stdout/stderr output, with exit code appended on non-zero exit. + """ + self.check_command(command) + timeout = timeout_seconds if timeout_seconds is not None else self.default_timeout + + proc = await anyio.open_process( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=self._cwd, + ) + try: + assert proc.stdout is not None + assert proc.stderr is not None + stdout_chunks: list[bytes] = [] + stderr_chunks: list[bytes] = [] + + async def _read_stdout() -> None: + assert proc.stdout is not None + async for chunk in proc.stdout: + stdout_chunks.append(chunk) + + async def _read_stderr() -> None: + assert proc.stderr is not None + async for chunk in proc.stderr: + stderr_chunks.append(chunk) + + with anyio.fail_after(timeout): + async with anyio.create_task_group() as tg: + tg.start_soon(_read_stdout) + tg.start_soon(_read_stderr) + await proc.wait() + except TimeoutError: + proc.kill() + with anyio.CancelScope(shield=True): + await proc.wait() + return f'[Command timed out after {timeout} seconds]' + finally: + await proc.aclose() + + stdout = b''.join(stdout_chunks).decode('utf-8', errors='replace') + stderr = b''.join(stderr_chunks).decode('utf-8', errors='replace') + + parts: list[str] = [] + if stdout: + parts.append(f'[stdout]\n{stdout}') + if stderr: + parts.append(f'[stderr]\n{stderr}') + output = '\n'.join(parts) if parts else '' + + output = self.truncate(output) + exit_code = proc.returncode if proc.returncode is not None else 0 + + if self.persist_cwd and exit_code == 0: + self._update_cwd(command) + + if exit_code != 0: + return f'{output}\n[exit code: {exit_code}]' + return output + + # ------------------------------------------------------------------ + # Capability interface + # ------------------------------------------------------------------ + + def get_toolset(self) -> AgentToolset[Any] | None: + """Build and return the toolset containing the run_command tool.""" + toolset: FunctionToolset[Any] = FunctionToolset() + toolset.add_function(self.run_command, name='run_command') + return toolset diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py new file mode 100644 index 0000000..da205a7 --- /dev/null +++ b/tests/test_filesystem.py @@ -0,0 +1,366 @@ +"""Tests for the FileSystem capability.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from pydantic_harness.filesystem import FileSystem, format_lines + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +@pytest.fixture +def tmp_root(tmp_path: Path) -> Path: + """Create a temporary root directory populated with test files.""" + (tmp_path / 'hello.txt').write_text('line one\nline two\nline three\n') + (tmp_path / 'sub').mkdir() + (tmp_path / 'sub' / 'nested.py').write_text('print("hello")\n') + (tmp_path / 'secret.env').write_text('API_KEY=abc123\n') + return tmp_path + + +@pytest.fixture +def fs(tmp_root: Path) -> FileSystem: + """A FileSystem capability rooted at the test directory.""" + return FileSystem(root_dir=tmp_root) + + +# --------------------------------------------------------------------------- +# format_lines +# --------------------------------------------------------------------------- + + +class TestFormatLines: + def test_basic(self) -> None: + result = format_lines('a\nb\nc\n', 0, 10) + assert ' 1\ta\n' in result + assert ' 2\tb\n' in result + assert ' 3\tc\n' in result + + def test_offset(self) -> None: + result = format_lines('a\nb\nc\nd\n', 2, 1) + assert ' 3\tc\n' in result + assert 'a' not in result.split('\t')[0] # line 1 not present + assert '1 more lines' in result + + def test_offset_past_end(self) -> None: + with pytest.raises(ValueError, match='Offset 10 exceeds file length'): + format_lines('a\n', 10, 1) + + def test_continuation_hint(self) -> None: + result = format_lines('a\nb\nc\n', 0, 2) + assert 'more lines' in result + assert 'offset=2' in result + + def test_no_trailing_newline(self) -> None: + result = format_lines('a', 0, 10) + assert result.endswith('\n') + assert ' 1\ta' in result + + +# --------------------------------------------------------------------------- +# read_file +# --------------------------------------------------------------------------- + + +class TestReadFile: + def test_read_existing(self, fs: FileSystem) -> None: + result = fs.read_file('hello.txt') + assert 'line one' in result + assert 'line two' in result + + def test_read_with_offset(self, fs: FileSystem) -> None: + result = fs.read_file('hello.txt', offset=1, limit=1) + assert 'line two' in result + assert 'line one' not in result + + def test_read_nested(self, fs: FileSystem) -> None: + result = fs.read_file('sub/nested.py') + assert 'print' in result + + def test_read_missing(self, fs: FileSystem) -> None: + with pytest.raises(FileNotFoundError): + fs.read_file('nonexistent.txt') + + def test_read_directory(self, fs: FileSystem) -> None: + with pytest.raises(FileNotFoundError, match='is a directory'): + fs.read_file('sub') + + def test_traversal_blocked(self, fs: FileSystem) -> None: + with pytest.raises(PermissionError, match='outside the root'): + fs.read_file('../../../etc/passwd') + + +# --------------------------------------------------------------------------- +# write_file +# --------------------------------------------------------------------------- + + +class TestWriteFile: + def test_write_new(self, fs: FileSystem, tmp_root: Path) -> None: + result = fs.write_file('new.txt', 'hello world') + assert 'Successfully wrote' in result + assert (tmp_root / 'new.txt').read_text() == 'hello world' + + def test_write_creates_parents(self, fs: FileSystem, tmp_root: Path) -> None: + fs.write_file('deep/nested/file.txt', 'content') + assert (tmp_root / 'deep' / 'nested' / 'file.txt').read_text() == 'content' + + def test_write_overwrite(self, fs: FileSystem, tmp_root: Path) -> None: + fs.write_file('hello.txt', 'overwritten') + assert (tmp_root / 'hello.txt').read_text() == 'overwritten' + + +# --------------------------------------------------------------------------- +# edit_file +# --------------------------------------------------------------------------- + + +class TestEditFile: + def test_edit_single(self, fs: FileSystem, tmp_root: Path) -> None: + result = fs.edit_file('hello.txt', 'line two', 'LINE TWO') + assert '1 occurrence' in result + assert 'LINE TWO' in (tmp_root / 'hello.txt').read_text() + + def test_edit_not_found(self, fs: FileSystem) -> None: + with pytest.raises(ValueError, match='not found'): + fs.edit_file('hello.txt', 'does not exist', 'replacement') + + def test_edit_ambiguous(self, fs: FileSystem, tmp_root: Path) -> None: + (tmp_root / 'dup.txt').write_text('aaa\naaa\n') + with pytest.raises(ValueError, match='2 times'): + fs.edit_file('dup.txt', 'aaa', 'bbb') + + def test_edit_replace_all(self, fs: FileSystem, tmp_root: Path) -> None: + (tmp_root / 'dup.txt').write_text('aaa\naaa\n') + result = fs.edit_file('dup.txt', 'aaa', 'bbb', replace_all=True) + assert '2 occurrence' in result + assert (tmp_root / 'dup.txt').read_text() == 'bbb\nbbb\n' + + def test_edit_missing_file(self, fs: FileSystem) -> None: + with pytest.raises(FileNotFoundError): + fs.edit_file('nope.txt', 'a', 'b') + + +# --------------------------------------------------------------------------- +# list_directory +# --------------------------------------------------------------------------- + + +class TestListDirectory: + def test_list_root(self, fs: FileSystem) -> None: + result = fs.list_directory() + assert 'hello.txt' in result + assert 'sub/' in result + + def test_list_subdir(self, fs: FileSystem) -> None: + result = fs.list_directory('sub') + assert 'nested.py' in result + + def test_list_nonexistent(self, fs: FileSystem) -> None: + with pytest.raises(NotADirectoryError): + fs.list_directory('nonexistent') + + def test_list_empty(self, fs: FileSystem, tmp_root: Path) -> None: + (tmp_root / 'empty').mkdir() + result = fs.list_directory('empty') + assert result == '(empty directory)' + + +# --------------------------------------------------------------------------- +# search_files +# --------------------------------------------------------------------------- + + +class TestSearchFiles: + def test_search_match(self, fs: FileSystem) -> None: + result = fs.search_files('line') + assert 'hello.txt:1:line one' in result + + def test_search_regex(self, fs: FileSystem) -> None: + result = fs.search_files(r'line\s+t') + assert 'hello.txt' in result + + def test_search_no_match(self, fs: FileSystem) -> None: + result = fs.search_files('zzzzz_nothing') + assert result == 'No matches found.' + + def test_search_nested(self, fs: FileSystem) -> None: + result = fs.search_files('print') + assert 'sub/nested.py' in result + + def test_search_specific_file(self, fs: FileSystem) -> None: + result = fs.search_files('line', path='hello.txt') + assert 'hello.txt:1:line one' in result + + def test_search_skips_hidden(self, fs: FileSystem, tmp_root: Path) -> None: + (tmp_root / '.hidden').mkdir() + (tmp_root / '.hidden' / 'secret.txt').write_text('findme\n') + result = fs.search_files('findme') + assert result == 'No matches found.' + + def test_search_skips_binary(self, fs: FileSystem, tmp_root: Path) -> None: + (tmp_root / 'binary.dat').write_bytes(b'findme\x00binary') + result = fs.search_files('findme') + assert 'binary.dat' not in result + + def test_search_skips_unreadable(self, fs: FileSystem, tmp_root: Path) -> None: + target = tmp_root / 'unreadable.txt' + target.write_text('findme\n') + target.chmod(0o000) + try: + result = fs.search_files('findme') + assert 'unreadable.txt' not in result + finally: + target.chmod(0o644) + + def test_search_truncation(self, tmp_root: Path) -> None: + # Create enough matches to trigger truncation at 1000 + big = '\n'.join(f'match line {i}' for i in range(1100)) + (tmp_root / 'big.txt').write_text(big) + fs = FileSystem(root_dir=tmp_root) + result = fs.search_files('match') + assert 'truncated' in result + + +# --------------------------------------------------------------------------- +# create_directory +# --------------------------------------------------------------------------- + + +class TestCreateDirectory: + def test_create_simple(self, fs: FileSystem, tmp_root: Path) -> None: + result = fs.create_directory('newdir') + assert 'Created directory' in result + assert (tmp_root / 'newdir').is_dir() + + def test_create_nested(self, fs: FileSystem, tmp_root: Path) -> None: + result = fs.create_directory('a/b/c') + assert 'Created directory' in result + assert (tmp_root / 'a' / 'b' / 'c').is_dir() + + def test_create_existing(self, fs: FileSystem) -> None: + # Should not raise for existing directories (exist_ok=True) + result = fs.create_directory('sub') + assert 'Created directory' in result + + def test_create_traversal_blocked(self, fs: FileSystem) -> None: + with pytest.raises(PermissionError, match='outside the root'): + fs.create_directory('../../escape') + + def test_create_denied(self, tmp_root: Path) -> None: + fs = FileSystem(root_dir=tmp_root, denied_patterns=['*.secret']) + with pytest.raises(PermissionError, match='denied'): + fs.create_directory('stuff.secret') + + +# --------------------------------------------------------------------------- +# find_files +# --------------------------------------------------------------------------- + + +class TestFindFiles: + def test_find_by_extension(self, fs: FileSystem) -> None: + result = fs.find_files('*.txt') + assert 'hello.txt' in result + + def test_find_recursive(self, fs: FileSystem) -> None: + result = fs.find_files('**/*.py') + assert 'sub/nested.py' in result + + def test_find_no_match(self, fs: FileSystem) -> None: + result = fs.find_files('*.nonexistent') + assert result == 'No matches found.' + + def test_find_in_subdir(self, fs: FileSystem) -> None: + result = fs.find_files('*.py', path='sub') + assert 'nested.py' in result + + def test_find_not_a_directory(self, fs: FileSystem) -> None: + with pytest.raises(NotADirectoryError): + fs.find_files('*.txt', path='hello.txt') + + def test_find_skips_hidden(self, fs: FileSystem, tmp_root: Path) -> None: + (tmp_root / '.hidden').mkdir() + (tmp_root / '.hidden' / 'secret.py').write_text('hidden\n') + result = fs.find_files('**/*.py') + assert '.hidden' not in result + + def test_find_includes_directories(self, fs: FileSystem) -> None: + result = fs.find_files('sub') + assert 'sub/' in result + + def test_find_truncation(self, tmp_root: Path) -> None: + for i in range(1100): + (tmp_root / f'file_{i:04d}.dat').write_text('') + fs = FileSystem(root_dir=tmp_root) + result = fs.find_files('*.dat') + assert 'truncated' in result + + +# --------------------------------------------------------------------------- +# Path filtering (allowed_patterns / denied_patterns) +# --------------------------------------------------------------------------- + + +class TestPathFiltering: + def test_denied_pattern(self, tmp_root: Path) -> None: + fs = FileSystem(root_dir=tmp_root, denied_patterns=['*.env']) + with pytest.raises(PermissionError, match='denied'): + fs.read_file('secret.env') + # Other files still accessible + result = fs.read_file('hello.txt') + assert 'line one' in result + + def test_allowed_pattern(self, tmp_root: Path) -> None: + fs = FileSystem(root_dir=tmp_root, allowed_patterns=['*.txt']) + result = fs.read_file('hello.txt') + assert 'line one' in result + with pytest.raises(PermissionError, match='does not match'): + fs.read_file('secret.env') + + def test_denied_write(self, tmp_root: Path) -> None: + fs = FileSystem(root_dir=tmp_root, denied_patterns=['*.env']) + with pytest.raises(PermissionError, match='denied'): + fs.write_file('new.env', 'bad') + + def test_denied_edit(self, tmp_root: Path) -> None: + fs = FileSystem(root_dir=tmp_root, denied_patterns=['*.env']) + with pytest.raises(PermissionError, match='denied'): + fs.edit_file('secret.env', 'API_KEY', 'REDACTED') + + +# --------------------------------------------------------------------------- +# Toolset integration +# --------------------------------------------------------------------------- + + +class TestToolset: + def test_get_toolset_returns_function_toolset(self, fs: FileSystem) -> None: + from pydantic_ai.toolsets import FunctionToolset + + toolset = fs.get_toolset() + assert isinstance(toolset, FunctionToolset) + + def test_toolset_has_expected_tools(self, fs: FileSystem) -> None: + from pydantic_ai.toolsets import FunctionToolset + + toolset = fs.get_toolset() + assert isinstance(toolset, FunctionToolset) + tool_names = set(toolset.tools.keys()) + assert tool_names == { + 'read_file', + 'write_file', + 'edit_file', + 'list_directory', + 'search_files', + 'create_directory', + 'find_files', + } + + def test_serialization_name(self) -> None: + assert FileSystem.get_serialization_name() == 'FileSystem' diff --git a/tests/test_shell.py b/tests/test_shell.py new file mode 100644 index 0000000..c8647ec --- /dev/null +++ b/tests/test_shell.py @@ -0,0 +1,293 @@ +"""Tests for the Shell capability.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import pytest + +from pydantic_harness.shell import Shell + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +@pytest.fixture +def tmp_cwd(tmp_path: Path) -> Path: + """Create a temporary working directory.""" + (tmp_path / 'greeting.txt').write_text('hello world\n') + return tmp_path + + +@pytest.fixture +def sh(tmp_cwd: Path) -> Shell: + """A Shell capability rooted at the test directory.""" + return Shell(cwd=tmp_cwd) + + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + + +class TestConfig: + def test_defaults(self) -> None: + sh = Shell() + assert sh.default_timeout == 30.0 + assert sh.max_output_chars == 10_000 + + def test_cannot_mix_allow_deny(self) -> None: + with pytest.raises(ValueError, match='not both'): + Shell(allowed_commands=['ls'], denied_commands=['rm']) + + +# --------------------------------------------------------------------------- +# Command validation +# --------------------------------------------------------------------------- + + +class TestCommandValidation: + def test_denied_command(self) -> None: + sh = Shell(denied_commands=['rm']) + with pytest.raises(PermissionError, match='denied'): + sh.check_command('rm -rf /') + + def test_allowed_command(self) -> None: + sh = Shell(allowed_commands=['echo', 'cat']) + sh.check_command('echo hello') # should not raise + with pytest.raises(PermissionError, match='not in the allowed'): + sh.check_command('rm -rf /') + + def test_no_restrictions(self) -> None: + sh = Shell() + sh.check_command('anything goes') # should not raise + + def test_malformed_command(self) -> None: + sh = Shell(denied_commands=['rm']) + # Unterminated quote: shlex.split raises ValueError. + # The capability falls through and lets the shell handle it. + sh.check_command("echo 'unterminated") # should not raise + + def test_empty_command(self) -> None: + sh = Shell(allowed_commands=['echo']) + sh.check_command('') # empty string should not raise + + +# --------------------------------------------------------------------------- +# Output truncation +# --------------------------------------------------------------------------- + + +class TestTruncation: + def test_short_output(self) -> None: + sh = Shell(max_output_chars=100) + assert sh.truncate('short') == 'short' + + def test_long_output(self) -> None: + sh = Shell(max_output_chars=10) + result = sh.truncate('x' * 50) + assert len(result.splitlines()[0]) == 10 + assert 'truncated' in result + + +# --------------------------------------------------------------------------- +# Command execution +# --------------------------------------------------------------------------- + + +class TestRunCommand: + @pytest.mark.anyio + async def test_echo(self, sh: Shell) -> None: + result = await sh.run_command('echo hello') + assert '[stdout]' in result + assert 'hello' in result + + @pytest.mark.anyio + async def test_stderr_label(self, sh: Shell) -> None: + result = await sh.run_command('echo oops >&2') + assert '[stderr]' in result + assert 'oops' in result + + @pytest.mark.anyio + async def test_stdout_and_stderr(self, sh: Shell) -> None: + result = await sh.run_command( + f"{sys.executable} -c \"import sys; print('out'); print('err', file=sys.stderr)\"" + ) + assert '[stdout]' in result + assert '[stderr]' in result + assert 'out' in result + assert 'err' in result + + @pytest.mark.anyio + async def test_exit_code(self, sh: Shell) -> None: + result = await sh.run_command('exit 1') + assert 'exit code: 1' in result + + @pytest.mark.anyio + async def test_timeout(self) -> None: + sh = Shell() + result = await sh.run_command('sleep 10', timeout_seconds=0.1) + assert 'timed out' in result.lower() + + @pytest.mark.anyio + async def test_cwd(self, sh: Shell) -> None: + result = await sh.run_command('cat greeting.txt') + assert 'hello world' in result + + @pytest.mark.anyio + async def test_truncated_output(self, tmp_cwd: Path) -> None: + sh = Shell(cwd=tmp_cwd, max_output_chars=20) + result = await sh.run_command(f'{sys.executable} -c "print(\'x\' * 100)"') + assert 'truncated' in result + + @pytest.mark.anyio + async def test_denied_command_async(self) -> None: + sh = Shell(denied_commands=['rm']) + with pytest.raises(PermissionError, match='denied'): + await sh.run_command('rm -rf /') + + @pytest.mark.anyio + async def test_allowed_command_async(self) -> None: + sh = Shell(allowed_commands=['echo']) + result = await sh.run_command('echo works') + assert 'works' in result + with pytest.raises(PermissionError, match='not in the allowed'): + await sh.run_command('cat /etc/passwd') + + @pytest.mark.anyio + async def test_empty_output(self, sh: Shell) -> None: + result = await sh.run_command('true') + assert result == '' + + +# --------------------------------------------------------------------------- +# Persistent working directory +# --------------------------------------------------------------------------- + + +class TestPersistCwd: + @pytest.mark.anyio + async def test_cd_updates_cwd(self, tmp_cwd: Path) -> None: + subdir = tmp_cwd / 'subdir' + subdir.mkdir() + (subdir / 'marker.txt').write_text('found\n') + sh = Shell(cwd=tmp_cwd, persist_cwd=True) + await sh.run_command('cd subdir') + result = await sh.run_command('cat marker.txt') + assert 'found' in result + + @pytest.mark.anyio + async def test_cd_disabled_by_default(self, tmp_cwd: Path) -> None: + subdir = tmp_cwd / 'subdir' + subdir.mkdir() + (subdir / 'marker.txt').write_text('found\n') + sh = Shell(cwd=tmp_cwd) + await sh.run_command('cd subdir') + # Without persist_cwd, cwd should not change + result = await sh.run_command('cat marker.txt') + assert 'exit code' in result # cat fails because we're in the wrong dir + + @pytest.mark.anyio + async def test_cd_chained_command(self, tmp_cwd: Path) -> None: + subdir = tmp_cwd / 'sub' + subdir.mkdir() + sh = Shell(cwd=tmp_cwd, persist_cwd=True) + await sh.run_command('cd sub && echo hi') + assert sh._cwd == subdir + + @pytest.mark.anyio + async def test_cd_quoted_path(self, tmp_cwd: Path) -> None: + subdir = tmp_cwd / 'my dir' + subdir.mkdir() + sh = Shell(cwd=tmp_cwd, persist_cwd=True) + await sh.run_command("cd 'my dir'") + assert sh._cwd == subdir + + @pytest.mark.anyio + async def test_cd_nonexistent_ignored(self, tmp_cwd: Path) -> None: + sh = Shell(cwd=tmp_cwd, persist_cwd=True) + original_cwd = sh._cwd + await sh.run_command('cd nonexistent_dir') + # cd failed (exit code != 0), so _cwd should not change + assert sh._cwd == original_cwd + + @pytest.mark.anyio + async def test_cd_not_updated_on_failure(self, tmp_cwd: Path) -> None: + subdir = tmp_cwd / 'subdir' + subdir.mkdir() + sh = Shell(cwd=tmp_cwd, persist_cwd=True) + original_cwd = sh._cwd + # The cd itself succeeds but the second command fails + await sh.run_command('cd subdir && false') + assert sh._cwd == original_cwd + + @pytest.mark.anyio + async def test_cd_absolute_path(self, tmp_cwd: Path) -> None: + subdir = tmp_cwd / 'target' + subdir.mkdir() + sh = Shell(cwd=tmp_cwd, persist_cwd=True) + await sh.run_command(f'cd {subdir}') + assert sh._cwd == subdir + + @pytest.mark.anyio + async def test_cd_home(self, tmp_cwd: Path) -> None: + sh = Shell(cwd=tmp_cwd, persist_cwd=True) + await sh.run_command('cd ~') + assert sh._cwd == Path.home() + + @pytest.mark.anyio + async def test_cd_home_subdir(self, tmp_cwd: Path) -> None: + sh = Shell(cwd=tmp_cwd, persist_cwd=True) + await sh.run_command('cd ~/.') + assert sh._cwd == Path.home() + + def test_update_cwd_nonexistent_dir(self, tmp_cwd: Path) -> None: + sh = Shell(cwd=tmp_cwd, persist_cwd=True) + original = sh._cwd + # Call _update_cwd directly with a cd to a non-existent path + sh._update_cwd('cd does_not_exist') + assert sh._cwd == original + + def test_extract_cd_target_none(self) -> None: + sh = Shell() + assert sh._extract_cd_target('echo hello') is None + assert sh._extract_cd_target('ls -la') is None + + def test_extract_cd_target_simple(self) -> None: + sh = Shell() + assert sh._extract_cd_target('cd foo') == 'foo' + assert sh._extract_cd_target('cd /tmp') == '/tmp' + + def test_extract_cd_target_with_chain(self) -> None: + sh = Shell() + assert sh._extract_cd_target('cd foo && ls') == 'foo' + assert sh._extract_cd_target('cd bar; pwd') == 'bar' + + def test_extract_cd_double_quoted(self) -> None: + sh = Shell() + assert sh._extract_cd_target('cd "my dir"') == 'my dir' + + +# --------------------------------------------------------------------------- +# Toolset integration +# --------------------------------------------------------------------------- + + +class TestToolset: + def test_get_toolset_returns_function_toolset(self, sh: Shell) -> None: + from pydantic_ai.toolsets import FunctionToolset + + toolset = sh.get_toolset() + assert isinstance(toolset, FunctionToolset) + + def test_toolset_has_run_command(self, sh: Shell) -> None: + from pydantic_ai.toolsets import FunctionToolset + + toolset = sh.get_toolset() + assert isinstance(toolset, FunctionToolset) + assert set(toolset.tools.keys()) == {'run_command'} + + def test_serialization_name(self) -> None: + assert Shell.get_serialization_name() == 'Shell' diff --git a/uv.lock b/uv.lock index 0730281..f84f872 100644 --- a/uv.lock +++ b/uv.lock @@ -540,6 +540,7 @@ wheels = [ name = "pydantic-harness" source = { editable = "." } dependencies = [ + { name = "anyio" }, { name = "pydantic-ai-slim" }, ] @@ -556,7 +557,10 @@ lint = [ ] [package.metadata] -requires-dist = [{ name = "pydantic-ai-slim", specifier = ">=1.76.0" }] +requires-dist = [ + { name = "anyio", specifier = ">=4.0" }, + { name = "pydantic-ai-slim", specifier = ">=1.76.0" }, +] [package.metadata.requires-dev] dev = [