diff --git a/.agents/settings.local.json b/.agents/settings.local.json deleted file mode 100644 index 8b311a3..0000000 --- a/.agents/settings.local.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "permissions": { - "allow": [] - } -} diff --git a/.gitignore b/.gitignore index 00e73a6..36c255e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +.env* +.mcp.json +.DS_Store +.agents/settings.local.json + # IDE .idea/ diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..0f95390 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,54 @@ +# Memory Capability + +## Summary + +Implements a `Memory` capability (`AbstractCapability` subclass) that provides persistent key-value memory across agent sessions, referencing issues #30 and #31. + +## Design + +### Architecture + +- **`Memory`** dataclass extends `AbstractCapability[AgentDepsT]` + - `get_instructions()` returns a dynamic callable that injects stored memories into the system prompt at run start + - `get_toolset()` returns a `FunctionToolset` with five tools: `save_memory`, `recall_memory`, `search_memories`, `list_memories`, `delete_memory` + - Tool functions use closures over `self.store` (no dependency on agent `deps`) + +### Storage + +- **`MemoryStore`** protocol: pluggable backend with `get`, `put`, `delete`, `list_all`, `search` +- **`InMemoryStore`**: dict-based, ephemeral, for testing (default) +- **`FileStore`**: JSON file on disk, reads on init, writes on every mutation + +### Memory Model + +- **`MemoryEntry`** dataclass: `key`, `content`, `tags` (list[str]), `scope`, `expires_at`, `created_at`, `updated_at` +- **`MemoryEntryDict`** TypedDict for serialization +- Word-boundary search with relevance scoring (case-insensitive) across key, content, and tags +- Scoping/namespaces via `scope` field with filtering on search/list +- TTL/expiration via `expires_at` with `is_expired()` auto-filtering +- Dedup warning on save when keys are similar (Levenshtein distance <= 2) + +### Spec Serialization + +- 
`Memory.get_serialization_name()` returns `"Memory"` +- `Memory.from_spec(backend="file", path="...")` creates a `FileStore`-backed instance + +## Configuration + +| Field | Default | Description | +|-------|---------|-------------| +| `store` | `InMemoryStore()` | Storage backend | +| `inject_memories_in_instructions` | `True` | Include memories in system prompt | +| `max_instructions_memories` | `20` | Cap on memories injected into prompt | + +## Files + +- `src/pydantic_harness/memory.py` - Capability, stores, entry model +- `src/pydantic_harness/__init__.py` - Re-exports +- `tests/test_memory.py` - 113 tests covering all code paths + +## Future Work + +- Semantic/vector search backend (e.g. embedding-based `MemoryStore`) +- Session-scoped memory isolation via `for_run()` +- SQLite / Redis backends for production persistence diff --git a/examples/memory/coding_assistant.py b/examples/memory/coding_assistant.py new file mode 100644 index 0000000..65c4123 --- /dev/null +++ b/examples/memory/coding_assistant.py @@ -0,0 +1,101 @@ +"""Self-Improving Coding Assistant — procedural memory via instructions injection. + +Demonstrates: instructions injection as self-modifying prompt, scoping, search, delete. +""" + +from __future__ import annotations + +import sys + +import logfire +from pydantic_ai import Agent + +from pydantic_harness.memory import InMemoryStore, Memory + +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_openai() + + +def main() -> None: + """Run the coding assistant example.""" + store = InMemoryStore() + memory = Memory(store=store, max_instructions_memories=10) + + agent = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory], + system_prompt=( + 'You are a coding assistant that learns from user corrections. ' + 'When the user gives you a coding rule or correction, save it as a memory ' + 'with scope "rules" and tags like ["python", "style"] or ["typescript", "testing"]. 
' + 'Use descriptive keys like "rule_python_fstrings" or "rule_ts_const". ' + 'When asked to write code, search your memories for relevant rules first.' + ), + ) + + # --- Teach rules --- + with logfire.span('teach-rules'): + result1 = agent.run_sync( + 'Remember these coding rules:\n' + '1. Always use f-strings in Python, never .format() or % formatting\n' + '2. In TypeScript, prefer const over let, never use var\n' + '3. Always add type hints to Python function signatures' + ) + print(f'Assistant: {result1.output}') + + rules = store.list_all() + print(f'\nRules stored: {len(rules)}') + for r in rules: + print(f' [{r.key}] {r.content} (scope={r.scope}, tags={r.tags})') + + assert len(rules) >= 3, f'Expected at least 3 rules saved, got {len(rules)}' + + # Check that search works across stored rules + python_rules = store.search('python') + print(f'Rules matching "python": {len(python_rules)}') + assert len(python_rules) >= 1, 'Expected at least 1 rule matching "python"' + + # --- Verify instructions injection --- + # Build instructions should now include the rules + from unittest.mock import MagicMock + + from pydantic_ai._run_context import RunContext + from pydantic_ai.usage import RunUsage + + ctx: RunContext[None] = RunContext(deps=None, model=MagicMock(), usage=RunUsage()) + instructions = memory.build_instructions(ctx) + print(f'\nInstructions preview (first 300 chars):\n{instructions[:300]}...') + + assert 'Currently stored memories' in instructions, 'Expected memories in instructions' + + # --- Ask for code, verify rules are considered --- + with logfire.span('apply-rules'): + result2 = agent.run_sync( + 'Write a Python function that greets a user by name. Follow all coding rules you know.' 
+ ) + print(f'\nAssistant: {result2.output}') + + # The output should use f-strings and type hints (based on rules) + output_lower = result2.output.lower() + assert "f'" in result2.output or 'f"' in result2.output or 'f-string' in output_lower, ( + 'Expected f-string usage in code output' + ) + + # --- Delete an obsolete rule --- + with logfire.span('delete-rule'): + result3 = agent.run_sync('Actually, the TypeScript const rule is outdated for this project. Delete it.') + print(f'\nAssistant: {result3.output}') + + remaining = store.list_all() + print(f'\nRules after deletion: {len(remaining)}') + for r in remaining: + print(f' [{r.key}] {r.content}') + + # Should have fewer rules now + assert len(remaining) < len(rules), 'Expected at least one rule deleted' + + print('\n--- Coding Assistant example passed! ---') + + +if __name__ == '__main__': + sys.exit(main() or 0) diff --git a/examples/memory/personal_assistant.py b/examples/memory/personal_assistant.py new file mode 100644 index 0000000..cd477f1 --- /dev/null +++ b/examples/memory/personal_assistant.py @@ -0,0 +1,89 @@ +"""Personal Assistant — remembers user preferences across sessions. + +Demonstrates: FileStore persistence, save/recall, instructions injection, tags, scoping. +""" + +from __future__ import annotations + +import sys +import tempfile +from pathlib import Path + +import logfire +from pydantic_ai import Agent + +from pydantic_harness.memory import FileStore, Memory + +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_openai() + + +def main() -> None: + """Run the personal assistant example.""" + with tempfile.TemporaryDirectory() as tmpdir: + mem_path = Path(tmpdir) / 'preferences.json' + store = FileStore(mem_path) + memory = Memory(store=store) + + agent = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory], + system_prompt=( + 'You are a helpful personal assistant. 
' + 'When the user tells you about their preferences, save each one as a memory ' + 'with scope "user_prefs" and appropriate tags. ' + 'Use descriptive keys like "preferred_name" or "theme_preference".' + ), + ) + + # --- Session 1: user shares preferences --- + with logfire.span('session-1-save-preferences'): + result1 = agent.run_sync("Hi! My name is Alice, I prefer dark mode, and I'm vegetarian.") + print(f'Assistant: {result1.output}') + + entries = store.list_all() + print(f'\nMemories after session 1: {len(entries)}') + for e in entries: + print(f' [{e.key}] {e.content} (tags={e.tags}, scope={e.scope})') + + assert len(entries) >= 2, f'Expected at least 2 memories saved, got {len(entries)}' + all_content = ' '.join(e.content.lower() for e in entries) + assert 'alice' in all_content or any('alice' in e.key.lower() for e in entries), 'Expected a memory about Alice' + + # --- Session 2: new agent instance loads from same file (persistence) --- + store2 = FileStore(mem_path) + memory2 = Memory(store=store2) + agent2 = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory2], + system_prompt='You are a helpful personal assistant.', + ) + + loaded_entries = store2.list_all() + print(f'\nMemories loaded in session 2: {len(loaded_entries)}') + assert len(loaded_entries) == len(entries), 'FileStore persistence failed' + + with logfire.span('session-2-recall-preferences'): + result2 = agent2.run_sync('What do you know about me?') + print(f'Assistant: {result2.output}') + + # The instructions injection should have included the memories + assert 'alice' in result2.output.lower() or 'dark' in result2.output.lower(), ( + 'Expected assistant to recall preferences from instructions injection' + ) + + # --- Session 3: update a preference --- + with logfire.span('session-3-update-preference'): + result3 = agent2.run_sync('Actually, I go by Ali now. 
Please update my name.') + print(f'\nAssistant: {result3.output}') + + updated_entries = store2.list_all() + print(f'\nMemories after update: {len(updated_entries)}') + for e in updated_entries: + print(f' [{e.key}] {e.content} (tags={e.tags})') + + print('\n--- Personal Assistant example passed! ---') + + +if __name__ == '__main__': + sys.exit(main() or 0) diff --git a/examples/memory/study_coach.py b/examples/memory/study_coach.py new file mode 100644 index 0000000..702b335 --- /dev/null +++ b/examples/memory/study_coach.py @@ -0,0 +1,76 @@ +"""Study Coach — spaced repetition with TTL. + +Demonstrates: TTL/expiration, save with ttl_minutes, list/search, tags. +""" + +from __future__ import annotations + +import sys + +import logfire +from pydantic_ai import Agent + +from pydantic_harness.memory import InMemoryStore, Memory + +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_openai() + + +def main() -> None: + """Run the study coach example.""" + store = InMemoryStore() + memory = Memory(store=store) + + agent = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory], + system_prompt=( + 'You are a study coach that helps users learn facts. ' + 'When the user provides a fact to learn, save it as a memory with ' + 'tag "study" and a ttl_minutes value: use 1 for new/hard facts, ' + '60 for reviewed facts, and 1440 for mastered facts. ' + 'Use descriptive keys like "biology_mitochondria" or "history_magna_carta".' + ), + ) + + # --- Learn some facts --- + with logfire.span('learn-facts'): + result1 = agent.run_sync( + 'I need to learn these facts:\n' + '1. Mitochondria are the powerhouse of the cell\n' + '2. The Magna Carta was signed in 1215\n' + '3. 
Water boils at 100 degrees Celsius at sea level' + ) + print(f'Coach: {result1.output}') + + entries = store.list_all() + print(f'\nFacts stored: {len(entries)}') + for e in entries: + print(f' [{e.key}] {e.content} (tags={e.tags}, ttl={e.expires_at})') + + assert len(entries) >= 3, f'Expected at least 3 facts saved, got {len(entries)}' + + # Check that TTL was set on at least some entries + entries_with_ttl = [e for e in entries if e.expires_at is not None] + assert len(entries_with_ttl) >= 1, 'Expected at least 1 entry with TTL set' + print(f'Entries with TTL: {len(entries_with_ttl)}') + + # Check tags + entries_with_study_tag = [e for e in entries if 'study' in e.tags] + assert len(entries_with_study_tag) >= 1, 'Expected at least 1 entry with "study" tag' + + # --- Search for facts --- + with logfire.span('search-facts'): + result2 = agent.run_sync('Search my memories for anything about biology.') + print(f'\nCoach: {result2.output}') + + # --- List all facts --- + with logfire.span('list-facts'): + result3 = agent.run_sync('List all my study memories.') + print(f'\nCoach: {result3.output}') + + print('\n--- Study Coach example passed! ---') + + +if __name__ == '__main__': + sys.exit(main() or 0) diff --git a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py index 9d728b6..d47b53e 100644 --- a/src/pydantic_harness/__init__.py +++ b/src/pydantic_harness/__init__.py @@ -7,4 +7,13 @@ # Each capability module is imported and re-exported here. # Capabilities are listed alphabetically. 
-__all__: list[str] = [] +from pydantic_harness.memory import FileStore, InMemoryStore, Memory, MemoryEntry, MemoryEntryDict, MemoryStore + +__all__: list[str] = [ + 'FileStore', + 'InMemoryStore', + 'Memory', + 'MemoryEntry', + 'MemoryEntryDict', + 'MemoryStore', +] diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py new file mode 100644 index 0000000..6dce2e6 --- /dev/null +++ b/src/pydantic_harness/memory.py @@ -0,0 +1,480 @@ +"""Memory capability for persistent agent memory across sessions. + +Provides tools for saving, recalling, searching, listing, and deleting +key-value memories, with pluggable storage backends (`InMemoryStore` for +testing, `FileStore` for on-disk persistence). +""" + +from __future__ import annotations + +import json +import logging +import re +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any, Protocol, TypedDict, runtime_checkable + +from pydantic_ai._instructions import AgentInstructions +from pydantic_ai.capabilities.abstract import AbstractCapability +from pydantic_ai.tools import AgentDepsT, RunContext, Tool +from pydantic_ai.toolsets import AgentToolset +from pydantic_ai.toolsets.function import FunctionToolset + +logger = logging.getLogger(__name__) + + +class _MemoryEntryDictRequired(TypedDict): + """Required fields for MemoryEntryDict.""" + + key: str + content: str + + +class MemoryEntryDict(_MemoryEntryDictRequired, total=False): + """Serialized form of a MemoryEntry for JSON storage. + + Only `key` and `content` are required; the remaining fields are + optional so that `from_dict` can accept legacy data missing some keys. 
@dataclass
class MemoryEntry:
    """A single memory entry with content, tags, scope, expiry, and timestamps.

    All timestamps are kept as ISO 8601 strings so an entry can be written to
    JSON without a custom encoder.
    """

    key: str
    """Unique identifier for this memory."""

    content: str
    """The content of the memory."""

    tags: list[str] = field(default_factory=lambda: list[str]())
    """Optional tags for categorization and search."""

    scope: str = 'global'
    """Namespace scope for this memory (default `'global'`)."""

    expires_at: str | None = None
    """Optional ISO 8601 expiration timestamp. `None` means no expiry."""

    created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    """ISO 8601 timestamp of when the memory was first created."""

    updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    """ISO 8601 timestamp of the last update."""

    def is_expired(self) -> bool:
        """Return True if this entry has passed its expiration time.

        An entry without `expires_at` never expires. A timestamp that carries
        no UTC offset is interpreted as UTC so that it can be compared with
        the timezone-aware current time instead of raising `TypeError`.

        Raises:
            ValueError: If `expires_at` is not a valid ISO 8601 string.
        """
        if self.expires_at is None:
            return False
        deadline = datetime.fromisoformat(self.expires_at)
        if deadline.tzinfo is None:
            # Naive and aware datetimes cannot be ordered; assume UTC, which is
            # what the default timestamp factories of this class produce.
            deadline = deadline.replace(tzinfo=timezone.utc)
        return deadline <= datetime.now(timezone.utc)

    def to_dict(self) -> MemoryEntryDict:
        """Serialize to a plain dict for JSON storage.

        The returned `tags` list is a copy, so mutating the dict does not
        change this entry.
        """
        return {
            'key': self.key,
            'content': self.content,
            'tags': list(self.tags),
            'scope': self.scope,
            'expires_at': self.expires_at,
            'created_at': self.created_at,
            'updated_at': self.updated_at,
        }

    @classmethod
    def from_dict(cls, data: MemoryEntryDict) -> MemoryEntry:
        """Deserialize from a plain dict.

        Only `key` and `content` are required. Missing optional fields fall
        back to no tags, scope `'global'`, no expiry, and empty-string
        timestamps. The `tags` list is copied so the new entry does not share
        it with `data`.

        Raises:
            KeyError: If `key` or `content` is missing from `data`.
        """
        return cls(
            key=data['key'],
            content=data['content'],
            tags=list(data.get('tags', [])),
            scope=data.get('scope', 'global'),
            expires_at=data.get('expires_at'),
            created_at=data.get('created_at', ''),
            updated_at=data.get('updated_at', ''),
        )
def _simple_similarity(a: str, b: str) -> bool:
    """Return True if two distinct keys look like near-duplicates.

    Keys are considered similar when both are at least 10 characters long,
    share the same first 10 characters, are not identical, and are within an
    edit distance (insertions, deletions, substitutions) of 2.

    Args:
        a: First key.
        b: Second key.
    """
    max_edits = 2
    if len(a) < 10 or len(b) < 10:
        return False
    if a == b or a[:10] != b[:10]:
        return False
    # A length gap larger than the budget can never be bridged.
    if abs(len(a) - len(b)) > max_edits:
        return False

    # Two-row Levenshtein: `prev` holds distances for the previous prefix of `a`.
    prev = list(range(len(b) + 1))
    for i, ch_a in enumerate(a, start=1):
        curr = [i]
        for j, ch_b in enumerate(b, start=1):
            cost = 0 if ch_a == ch_b else 1
            curr.append(min(curr[j - 1] + 1, prev[j] + 1, prev[j - 1] + cost))
        # Row minima never decrease, so once a whole row exceeds the budget
        # the final distance must exceed it too.
        if min(curr) > max_edits:
            return False
        prev = curr
    return prev[-1] <= max_edits
class FileStore(_BaseDictStore):
    """JSON-file-based store for simple on-disk persistence.

    Reads the file on initialization and writes back on every mutation.
    Each write goes to a sibling temporary file that is then renamed over the
    target, so an interrupted write cannot leave a truncated JSON file behind.
    """

    def __init__(self, path: str | Path) -> None:
        """Initialize a file-backed store at the given path."""
        self._path = Path(path)
        self._entries: dict[str, MemoryEntry] = {}
        self._load()

    def _load(self) -> None:
        """Populate the in-memory entries from disk.

        A missing file leaves the store empty. Unparseable content (invalid
        JSON, undecodable bytes, a non-dict top level, or malformed entries)
        is logged as a warning and the store starts empty.
        """
        if not self._path.exists():
            return
        try:
            raw: dict[str, MemoryEntryDict] = json.loads(self._path.read_text(encoding='utf-8'))
            if not isinstance(raw, dict):  # pyright: ignore[reportUnnecessaryIsInstance]
                logger.warning('Memory file %s contains non-dict JSON, starting empty', self._path)
                return
            self._entries = {key: MemoryEntry.from_dict(val) for key, val in raw.items()}
        except (json.JSONDecodeError, UnicodeDecodeError, KeyError, TypeError) as e:
            logger.warning('Failed to load memory file %s: %s, starting empty', self._path, e)
            self._entries = {}

    def _save(self) -> None:
        """Write all entries to disk as indented JSON.

        Raises:
            OSError: If the temporary file cannot be written or renamed; the
                temporary file is removed before the error propagates.
        """
        self._path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps({key: entry.to_dict() for key, entry in self._entries.items()}, indent=2)
        # Same directory as the target so the rename stays on one filesystem.
        tmp_path = self._path.with_name(f'{self._path.name}.tmp')
        try:
            tmp_path.write_text(payload, encoding='utf-8')
            tmp_path.replace(self._path)
        except OSError:
            tmp_path.unlink(missing_ok=True)
            raise

    def put(self, entry: MemoryEntry) -> None:
        """Store or update a memory entry and persist the change."""
        super().put(entry)
        self._save()

    def delete(self, key: str) -> bool:
        """Delete a memory entry by key; persist only if something was removed."""
        existed = super().delete(key)
        if existed:
            self._save()
        return existed
+ + Provides tools for saving, recalling, searching, listing, and deleting + key-value memories. Uses a pluggable `MemoryStore` backend for storage. + + Example: + ```python {test="skip" lint="skip"} + from pydantic_ai import Agent + from pydantic_harness.memory import Memory, InMemoryStore + + agent = Agent('openai:gpt-4o', capabilities=[Memory(store=InMemoryStore())]) + ``` + """ + + store: MemoryStore = field(default_factory=InMemoryStore) + """The storage backend. Defaults to `InMemoryStore` (ephemeral, dict-based).""" + + inject_memories_in_instructions: bool = True + """Whether to inject existing memories into the system prompt at run start.""" + + max_instructions_memories: int = 20 + """Maximum number of memories to include in the system prompt.""" + + @classmethod + def get_serialization_name(cls) -> str | None: + """Return the name used for spec serialization.""" + return 'Memory' + + @classmethod + def from_spec( + cls, + *, + backend: str = 'memory', + path: str = '.memories.json', + inject_memories_in_instructions: bool = True, + max_instructions_memories: int = 20, + ) -> Memory[Any]: + """Create from spec arguments. + + Args: + backend: Storage backend, `"memory"` (default) or `"file"`. + path: File path for the `"file"` backend (default `".memories.json"`). + inject_memories_in_instructions: Whether to inject memories into the system prompt. + max_instructions_memories: Maximum memories to inject into the system prompt. + """ + store: MemoryStore + if backend == 'memory': + store = InMemoryStore() + elif backend == 'file': + store = FileStore(path) + else: + raise ValueError(f'Unknown memory backend: {backend!r}. 
Use "memory" or "file".') + return cls( + store=store, + inject_memories_in_instructions=inject_memories_in_instructions, + max_instructions_memories=max_instructions_memories, + ) + + def build_instructions(self, ctx: RunContext[AgentDepsT]) -> str: + """Build dynamic instructions that include currently stored memories.""" + parts: list[str] = [ + 'You have access to a persistent memory system. ' + 'Use it to save important information that should be remembered across conversations.', + ] + if self.inject_memories_in_instructions: + entries = self.store.list_all() + if entries: + parts.append('\nCurrently stored memories:') + for entry in entries[: self.max_instructions_memories]: + parts.append(f'- {format_entry(entry)}') + overflow = len(entries) - self.max_instructions_memories + if overflow > 0: + parts.append(f'... and {overflow} more (use list_memories or search_memories to see all).') + return '\n'.join(parts) + + def get_instructions(self) -> AgentInstructions[AgentDepsT] | None: + """Return dynamic instructions that include stored memories.""" + return self.build_instructions + + def get_toolset(self) -> AgentToolset[AgentDepsT] | None: + """Return a toolset with memory management tools. + + Tool functions close over `self` to access the store without + requiring anything from the agent's `deps`. + """ + store = self.store + + def save_memory( + key: str, + content: str, + tags: list[str] | None = None, + scope: str = 'global', + ttl_minutes: int | None = None, + ) -> str: + """Save or update a memory entry. + + Args: + key: Unique key for this memory. + content: The content to remember. + tags: Optional tags for categorization and search. + scope: Namespace scope (default `'global'`). + ttl_minutes: Optional time-to-live in minutes. The entry will expire after this duration. 
+ """ + now = datetime.now(timezone.utc) + now_iso = now.isoformat() + existing = store.get(key) + + # Dedup warning: check for similar keys among existing entries + for existing_entry in store.list_all(): + if _simple_similarity(key, existing_entry.key): + logger.warning( + 'New memory key %r is very similar to existing key %r — possible duplicate', + key, + existing_entry.key, + ) + + expires_at: str | None = None + if ttl_minutes is not None: + expires_at = (now + timedelta(minutes=ttl_minutes)).isoformat() + + entry = MemoryEntry( + key=key, + content=content, + tags=tags or [], + scope=scope, + expires_at=expires_at, + created_at=existing.created_at if existing else now_iso, + updated_at=now_iso, + ) + store.put(entry) + return f'Memory saved: {key}' + + def recall_memory(key: str) -> str: + """Recall a specific memory by its key. + + Args: + key: The key of the memory to recall. + """ + entry = store.get(key) + if entry is None: + return f'No memory found for key: {key}' + if entry.is_expired(): + return f'No memory found for key: {key}' + return format_entry(entry) + + def search_memories(query: str, scope: str | None = None) -> str: + """Search memories by word-boundary matching on keys, content, or tags, sorted by relevance. + + Args: + query: The search query string (space-separated words). + scope: Optional scope to restrict the search to. + """ + results = store.search(query, scope=scope) + if not results: + return f'No memories found matching: {query}' + return '\n'.join(format_entry(entry) for entry in results) + + def list_memories(scope: str | None = None) -> str: + """List all stored memories, optionally filtered by scope. + + Args: + scope: Optional scope to filter by. + """ + entries = store.list_all(scope=scope) + if not entries: + return 'No memories stored.' + return '\n'.join(format_entry(entry) for entry in entries) + + def delete_memory(key: str) -> str: + """Delete a memory by its key. + + Args: + key: The key of the memory to delete. 
+ """ + if store.delete(key): + return f'Memory deleted: {key}' + return f'No memory found for key: {key}' + + return FunctionToolset( + [ + Tool(save_memory, takes_ctx=False), + Tool(recall_memory, takes_ctx=False), + Tool(search_memories, takes_ctx=False), + Tool(list_memories, takes_ctx=False), + Tool(delete_memory, takes_ctx=False), + ], + ) diff --git a/tests/test_memory.py b/tests/test_memory.py new file mode 100644 index 0000000..82a45f7 --- /dev/null +++ b/tests/test_memory.py @@ -0,0 +1,812 @@ +"""Tests for the Memory capability.""" + +from __future__ import annotations + +import json +import logging +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any + +from pydantic_ai._run_context import RunContext +from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.usage import RunUsage + +from pydantic_harness.memory import ( + FileStore, + InMemoryStore, + Memory, + MemoryEntry, + MemoryStore, + _score_entry, + _simple_similarity, + format_entry, +) + +# --- MemoryEntry --- + + +class TestMemoryEntry: + def test_round_trip(self) -> None: + entry = MemoryEntry( + key='k', + content='v', + tags=['a', 'b'], + scope='project', + expires_at='2099-01-01T00:00:00+00:00', + created_at='t1', + updated_at='t2', + ) + assert MemoryEntry.from_dict(entry.to_dict()) == entry + + def test_from_dict_defaults(self) -> None: + entry = MemoryEntry.from_dict({'key': 'k', 'content': 'v'}) + assert entry.tags == [] + assert entry.scope == 'global' + assert entry.expires_at is None + assert entry.created_at == '' + assert entry.updated_at == '' + + def test_default_timestamps(self) -> None: + entry = MemoryEntry(key='k', content='v') + assert entry.created_at # non-empty ISO string + assert entry.updated_at + + def test_default_scope(self) -> None: + entry = MemoryEntry(key='k', content='v') + assert entry.scope == 'global' + + def test_is_expired_no_expiry(self) -> None: + entry = MemoryEntry(key='k', content='v') 
+ assert not entry.is_expired() + + def test_is_expired_future(self) -> None: + future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + entry = MemoryEntry(key='k', content='v', expires_at=future) + assert not entry.is_expired() + + def test_is_expired_past(self) -> None: + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + entry = MemoryEntry(key='k', content='v', expires_at=past) + assert entry.is_expired() + + +# --- _score_entry --- + + +class TestScoreEntry: + def test_no_match(self) -> None: + entry = MemoryEntry(key='greeting', content='hello world') + assert _score_entry(entry, ['zzz']) == 0 + + def test_key_match(self) -> None: + entry = MemoryEntry(key='greeting', content='some text') + assert _score_entry(entry, ['greeting']) == 1 + + def test_content_match(self) -> None: + entry = MemoryEntry(key='k', content='hello world') + assert _score_entry(entry, ['hello']) == 1 + + def test_tag_match(self) -> None: + entry = MemoryEntry(key='k', content='text', tags=['important']) + assert _score_entry(entry, ['important']) == 1 + + def test_multiple_field_match(self) -> None: + entry = MemoryEntry(key='hello', content='hello world', tags=['hello']) + # 'hello' appears in key (1) + content (1) + tags (1) = 3 + assert _score_entry(entry, ['hello']) == 3 + + def test_multiple_words(self) -> None: + entry = MemoryEntry(key='user', content='Alice likes blue') + # 'alice' in content (1), 'blue' in content (1) = 2 + assert _score_entry(entry, ['alice', 'blue']) == 2 + + def test_word_boundary_no_partial(self) -> None: + # 'fox' should NOT match 'foxes' with word-boundary matching + entry = MemoryEntry(key='k', content='foxes jump') + assert _score_entry(entry, ['fox']) == 0 + + def test_regex_metacharacters_in_query(self) -> None: + entry = MemoryEntry(key='lang', content='I use c++ daily') + assert _score_entry(entry, ['c++']) == 1 + + def test_empty_words_list(self) -> None: + entry = MemoryEntry(key='k', content='hello') + assert 
_score_entry(entry, []) == 0 + + def test_underscore_word_boundary(self) -> None: + entry = MemoryEntry(key='user_name', content='text') + assert _score_entry(entry, ['name']) == 1 + + def test_hyphen_word_boundary(self) -> None: + entry = MemoryEntry(key='my-project', content='text') + assert _score_entry(entry, ['project']) == 1 + + def test_partial_word_match(self) -> None: + entry = MemoryEntry(key='k', content='alice likes blue') + # 'alice' matches (1), 'zzz' does not (0) = score 1 + assert _score_entry(entry, ['alice', 'zzz']) == 1 + + +# --- _simple_similarity --- + + +class TestSimpleSimilarity: + def test_identical_keys_not_similar(self) -> None: + assert not _simple_similarity('abcdefghij', 'abcdefghij') + + def test_short_keys_not_similar(self) -> None: + assert not _simple_similarity('abc', 'abd') + + def test_similar_long_keys(self) -> None: + # Differ by 2 characters ('fo' vs 'ba') — within the edit-distance threshold + assert _simple_similarity('abcdefghij_fo', 'abcdefghij_ba') + + def test_different_prefix(self) -> None: + assert not _simple_similarity('xxxxxxxxxxfoo', 'yyyyyyyyyyfoo') + + def test_same_prefix_large_edit(self) -> None: + assert not _simple_similarity('abcdefghijklmnop', 'abcdefghijzzzzzz') + + def test_length_diff_too_large(self) -> None: + # Same 10-char prefix but length differs by more than 2 + assert not _simple_similarity('abcdefghij_x', 'abcdefghij_xyzw') + + def test_one_char_diff(self) -> None: + assert _simple_similarity('abcdefghij_x', 'abcdefghij_y') + + def test_edit_distance_exactly_three(self) -> None: + # Just over the threshold -- should NOT be similar + assert not _simple_similarity('abcdefghij_abc', 'abcdefghij_xyz') + + def test_nine_char_keys(self) -> None: + # Just below the 10-char minimum + assert not _simple_similarity('abcdefghi', 'abcdefghj') + + def test_exactly_ten_char_keys_not_similar(self) -> None: + # 10-char keys differing at position 10 do NOT share a 10-char prefix + assert not 
_simple_similarity('abcdefghij', 'abcdefghik') + + +# --- InMemoryStore --- + + +class TestInMemoryStore: + def test_put_and_get(self) -> None: + store = InMemoryStore() + entry = MemoryEntry(key='greeting', content='hello') + store.put(entry) + assert store.get('greeting') is entry + + def test_get_missing(self) -> None: + store = InMemoryStore() + assert store.get('nope') is None + + def test_put_overwrites(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v1')) + store.put(MemoryEntry(key='k', content='v2')) + result = store.get('k') + assert result is not None + assert result.content == 'v2' + + def test_delete_existing(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + assert store.delete('k') is True + assert store.get('k') is None + + def test_delete_missing(self) -> None: + store = InMemoryStore() + assert store.delete('nope') is False + + def test_list_all_empty(self) -> None: + store = InMemoryStore() + assert store.list_all() == [] + + def test_list_all(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='a', content='alpha')) + store.put(MemoryEntry(key='b', content='beta')) + entries = store.list_all() + assert len(entries) == 2 + assert {e.key for e in entries} == {'a', 'b'} + + def test_list_all_filters_expired(self) -> None: + store = InMemoryStore() + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='fresh')) + store.put(MemoryEntry(key='dead', content='stale', expires_at=past)) + entries = store.list_all() + assert len(entries) == 1 + assert entries[0].key == 'alive' + + def test_list_all_scope_filter(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='a', content='x', scope='project')) + store.put(MemoryEntry(key='b', content='y', scope='global')) + entries = store.list_all(scope='project') + assert len(entries) == 1 + assert entries[0].key == 'a' + + def 
test_list_all_scope_none_returns_all(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='a', content='x', scope='project')) + store.put(MemoryEntry(key='b', content='y', scope='global')) + assert len(store.list_all(scope=None)) == 2 + + def test_search_by_key(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='user_name', content='Alice')) + store.put(MemoryEntry(key='color', content='blue')) + results = store.search('user') + assert len(results) == 1 + assert results[0].key == 'user_name' + + def test_search_by_content(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k1', content='the quick brown fox')) + store.put(MemoryEntry(key='k2', content='lazy dog')) + results = store.search('fox') + assert len(results) == 1 + assert results[0].key == 'k1' + + def test_search_by_tag(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k1', content='x', tags=['important'])) + store.put(MemoryEntry(key='k2', content='y', tags=['trivial'])) + results = store.search('important') + assert len(results) == 1 + assert results[0].key == 'k1' + + def test_search_case_insensitive(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='K1', content='Hello World')) + results = store.search('hello') + assert len(results) == 1 + + def test_search_no_results(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + assert store.search('zzz') == [] + + def test_search_filters_expired(self) -> None: + store = InMemoryStore() + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='hello world')) + store.put(MemoryEntry(key='dead', content='hello world', expires_at=past)) + results = store.search('hello') + assert len(results) == 1 + assert results[0].key == 'alive' + + def test_search_scope_filter(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='a', content='hello world', scope='project')) + 
store.put(MemoryEntry(key='b', content='hello world', scope='global')) + results = store.search('hello', scope='project') + assert len(results) == 1 + assert results[0].key == 'a' + + def test_search_relevance_ordering(self) -> None: + store = InMemoryStore() + # 'hello' appears in key + content = score 2 + store.put(MemoryEntry(key='hello', content='hello there')) + # 'hello' appears only in content = score 1 + store.put(MemoryEntry(key='other', content='hello world')) + results = store.search('hello') + assert len(results) == 2 + assert results[0].key == 'hello' # higher score first + assert results[1].key == 'other' + + def test_search_empty_query(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + assert store.search('') == [] + + +# --- FileStore --- + + +class TestFileStore: + def test_put_and_get(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v')) + assert store.get('k') is not None + assert store.get('k').content == 'v' # type: ignore[union-attr] + + def test_persistence(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store1 = FileStore(path) + store1.put(MemoryEntry(key='k', content='persisted')) + + # New store instance should load from disk + store2 = FileStore(path) + result = store2.get('k') + assert result is not None + assert result.content == 'persisted' + + def test_delete_saves(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v')) + store.delete('k') + + # Reload and verify deletion persisted + store2 = FileStore(path) + assert store2.get('k') is None + + def test_delete_missing(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + assert store.delete('nope') is False + + def test_list_all(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + 
store.put(MemoryEntry(key='a', content='alpha')) + store.put(MemoryEntry(key='b', content='beta')) + assert len(store.list_all()) == 2 + + def test_search(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k1', content='hello', tags=['greeting'])) + store.put(MemoryEntry(key='k2', content='world')) + assert len(store.search('greeting')) == 1 + assert len(store.search('hello')) == 1 + + def test_empty_file(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + # File does not exist yet + store = FileStore(path) + assert store.list_all() == [] + + def test_creates_parent_dirs(self, tmp_path: Path) -> None: + path = tmp_path / 'sub' / 'dir' / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v')) + assert path.exists() + + def test_file_format(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v', tags=['t'], created_at='c', updated_at='u')) + raw = json.loads(path.read_text()) + assert raw == { + 'k': { + 'key': 'k', + 'content': 'v', + 'tags': ['t'], + 'scope': 'global', + 'expires_at': None, + 'created_at': 'c', + 'updated_at': 'u', + } + } + + def test_list_all_filters_expired(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='x')) + store.put(MemoryEntry(key='dead', content='y', expires_at=past)) + assert len(store.list_all()) == 1 + + def test_search_filters_expired(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='hello world')) + store.put(MemoryEntry(key='dead', content='hello world', expires_at=past)) + assert len(store.search('hello')) == 1 + + def test_list_all_scope(self, 
tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='a', content='x', scope='project')) + store.put(MemoryEntry(key='b', content='y', scope='global')) + assert len(store.list_all(scope='project')) == 1 + + def test_search_scope(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='a', content='hello world', scope='project')) + store.put(MemoryEntry(key='b', content='hello world', scope='global')) + assert len(store.search('hello', scope='project')) == 1 + + def test_search_empty_query(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v')) + assert store.search('') == [] + + def test_load_malformed_json(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + path.write_text('not json at all', encoding='utf-8') + store = FileStore(path) + assert store.list_all() == [] + + def test_load_wrong_structure(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + path.write_text('["a", "b"]', encoding='utf-8') + store = FileStore(path) + assert store.list_all() == [] + + def test_load_missing_entry_fields(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + path.write_text('{"k": {"not_a_key": "oops"}}', encoding='utf-8') + store = FileStore(path) + assert store.list_all() == [] + + def test_scope_persists(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store1 = FileStore(path) + store1.put(MemoryEntry(key='k', content='v', scope='session')) + store2 = FileStore(path) + entry = store2.get('k') + assert entry is not None + assert entry.scope == 'session' + + def test_expires_at_persists(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + store1 = FileStore(path) + store1.put(MemoryEntry(key='k', content='v', expires_at=future)) + store2 = 
FileStore(path) + entry = store2.get('k') + assert entry is not None + assert entry.expires_at == future + + +# --- format_entry --- + + +class TestFormatEntry: + def test_no_tags(self) -> None: + entry = MemoryEntry(key='k', content='hello') + assert format_entry(entry) == '[k] hello' + + def test_with_tags(self) -> None: + entry = MemoryEntry(key='k', content='hello', tags=['a', 'b']) + assert format_entry(entry) == '[k] hello (tags: a, b)' + + def test_with_scope(self) -> None: + entry = MemoryEntry(key='k', content='hello', scope='project') + assert format_entry(entry) == '[k] hello (scope: project)' + + def test_global_scope_omitted(self) -> None: + entry = MemoryEntry(key='k', content='hello', scope='global') + assert format_entry(entry) == '[k] hello' + + def test_with_expires_at(self) -> None: + entry = MemoryEntry(key='k', content='hello', expires_at='2099-01-01T00:00:00+00:00') + assert format_entry(entry) == '[k] hello (expires: 2099-01-01T00:00:00+00:00)' + + def test_all_extras(self) -> None: + entry = MemoryEntry( + key='k', + content='hello', + tags=['t'], + scope='project', + expires_at='2099-01-01T00:00:00+00:00', + ) + assert format_entry(entry) == '[k] hello (tags: t; scope: project; expires: 2099-01-01T00:00:00+00:00)' + + def test_empty_content(self) -> None: + entry = MemoryEntry(key='k', content='') + assert format_entry(entry) == '[k] ' + + def test_empty_key(self) -> None: + entry = MemoryEntry(key='', content='hello') + assert format_entry(entry) == '[] hello' + + +# --- Memory capability --- + + +class TestMemoryCapability: + def test_serialization_name(self) -> None: + assert Memory.get_serialization_name() == 'Memory' + + def test_from_spec_default(self) -> None: + cap = Memory.from_spec() + assert isinstance(cap.store, InMemoryStore) + + def test_from_spec_file(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + cap = Memory.from_spec(backend='file', path=str(path)) + assert isinstance(cap.store, FileStore) + + def 
test_from_spec_unknown_backend(self) -> None: + import pytest + + with pytest.raises(ValueError, match='Unknown memory backend'): + Memory.from_spec(backend='redis') + + def test_from_spec_explicit_memory_backend(self) -> None: + cap = Memory.from_spec(backend='memory') + assert isinstance(cap.store, InMemoryStore) + + def test_from_spec_with_options(self, tmp_path: Path) -> None: + cap = Memory.from_spec( + backend='file', + path=str(tmp_path / 'mem.json'), + inject_memories_in_instructions=False, + max_instructions_memories=10, + ) + assert isinstance(cap.store, FileStore) + assert cap.inject_memories_in_instructions is False + assert cap.max_instructions_memories == 10 + + def test_default_store(self) -> None: + cap: Memory[None] = Memory() + assert isinstance(cap.store, InMemoryStore) + + def test_get_toolset_returns_function_toolset(self) -> None: + cap: Memory[None] = Memory() + toolset = cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + + def test_toolset_has_expected_tools(self) -> None: + cap: Memory[None] = Memory() + toolset = cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + tool_names = set(toolset.tools.keys()) + assert tool_names == {'save_memory', 'recall_memory', 'search_memories', 'list_memories', 'delete_memory'} + + +# --- Tool functions (via closure) --- + + +class TestMemoryTools: + """Test the tool functions exposed by the Memory capability.""" + + @staticmethod + def _get_tools(store: InMemoryStore | None = None) -> dict[str, Any]: + cap: Memory[None] = Memory(store=store or InMemoryStore()) + toolset = cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + return {name: tool.function for name, tool in toolset.tools.items()} + + def test_save_and_recall(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + result = tools['save_memory']('greeting', 'hello world') + assert result == 'Memory saved: greeting' + + recalled = tools['recall_memory']('greeting') + assert '[greeting] 
hello world' in recalled + + def test_recall_missing(self) -> None: + tools = self._get_tools() + assert 'No memory found' in tools['recall_memory']('nope') + + def test_recall_expired(self) -> None: + store = InMemoryStore() + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='old', content='stale', expires_at=past)) + tools = self._get_tools(store) + assert 'No memory found' in tools['recall_memory']('old') + + def test_save_updates_existing(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v1') + original = store.get('k') + assert original is not None + original_created = original.created_at + + tools['save_memory']('k', 'v2') + updated = store.get('k') + assert updated is not None + assert updated.content == 'v2' + # created_at should be preserved + assert updated.created_at == original_created + + def test_save_with_tags(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', ['tag1', 'tag2']) + entry = store.get('k') + assert entry is not None + assert entry.tags == ['tag1', 'tag2'] + + def test_save_with_scope(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, 'project') + entry = store.get('k') + assert entry is not None + assert entry.scope == 'project' + + def test_save_with_ttl(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, 'global', 60) + entry = store.get('k') + assert entry is not None + assert entry.expires_at is not None + expires = datetime.fromisoformat(entry.expires_at) + # Should expire roughly 60 minutes from now + assert expires > datetime.now(timezone.utc) + timedelta(minutes=59) + assert expires < datetime.now(timezone.utc) + timedelta(minutes=61) + + def test_search(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + 
tools['save_memory']('user_name', 'Alice') + tools['save_memory']('color', 'blue') + + result = tools['search_memories']('Alice') + assert 'Alice' in result + assert 'blue' not in result + + def test_search_no_results(self) -> None: + tools = self._get_tools() + assert 'No memories found' in tools['search_memories']('zzz') + + def test_search_with_scope(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('a', 'hello world', None, 'project') + tools['save_memory']('b', 'hello world', None, 'global') + result = tools['search_memories']('hello', 'project') + assert '[a]' in result + assert '[b]' not in result + + def test_list_empty(self) -> None: + tools = self._get_tools() + assert tools['list_memories']() == 'No memories stored.' + + def test_list_with_entries(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('a', 'alpha') + tools['save_memory']('b', 'beta') + result = tools['list_memories']() + assert '[a] alpha' in result + assert '[b] beta' in result + + def test_list_with_scope(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('a', 'alpha', None, 'project') + tools['save_memory']('b', 'beta', None, 'global') + result = tools['list_memories']('project') + assert '[a] alpha' in result + assert '[b]' not in result + + def test_delete_existing(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v') + assert tools['delete_memory']('k') == 'Memory deleted: k' + assert store.get('k') is None + + def test_delete_missing(self) -> None: + tools = self._get_tools() + assert 'No memory found' in tools['delete_memory']('nope') + + def test_save_with_ttl_zero(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, 'global', 0) + entry = store.get('k') + assert entry is not None + assert entry.expires_at is not None + # TTL=0 means 
it expires immediately + assert entry.is_expired() + + +# --- Dedup warning --- + + +class TestDedupWarning: + def test_similar_key_logs_warning(self, caplog: Any) -> None: + store = InMemoryStore() + tools = TestMemoryTools._get_tools(store) + tools['save_memory']('abcdefghij_x', 'first value') + with caplog.at_level(logging.WARNING, logger='pydantic_harness.memory'): + tools['save_memory']('abcdefghij_y', 'second value') + assert any('possible duplicate' in record.message.lower() for record in caplog.records) + + def test_different_keys_no_warning(self, caplog: Any) -> None: + store = InMemoryStore() + tools = TestMemoryTools._get_tools(store) + tools['save_memory']('first_key_long', 'first value') + with caplog.at_level(logging.WARNING, logger='pydantic_harness.memory'): + tools['save_memory']('other_key_long', 'second value') + assert not any('possible duplicate' in record.message.lower() for record in caplog.records) + + def test_short_keys_no_warning(self, caplog: Any) -> None: + store = InMemoryStore() + tools = TestMemoryTools._get_tools(store) + tools['save_memory']('abc', 'first value') + with caplog.at_level(logging.WARNING, logger='pydantic_harness.memory'): + tools['save_memory']('abd', 'second value') + assert not any('possible duplicate' in record.message.lower() for record in caplog.records) + + +# --- Instructions --- + + +class TestMemoryInstructions: + @staticmethod + def _make_ctx() -> RunContext[None]: + from unittest.mock import MagicMock + + return RunContext( + deps=None, + model=MagicMock(), + usage=RunUsage(), + ) + + def test_get_instructions_is_callable(self) -> None: + cap: Memory[None] = Memory() + assert callable(cap.get_instructions()) + + def test_instructions_with_no_memories(self) -> None: + cap: Memory[None] = Memory() + text = cap.build_instructions(self._make_ctx()) + assert 'persistent memory system' in text + assert 'Currently stored memories' not in text + + def test_instructions_with_memories(self) -> None: + store = 
InMemoryStore() + store.put(MemoryEntry(key='user', content='Alice')) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._make_ctx()) + assert 'Currently stored memories' in text + assert '[user] Alice' in text + + def test_instructions_respects_max(self) -> None: + store = InMemoryStore() + for i in range(25): + store.put(MemoryEntry(key=f'k{i}', content=f'v{i}')) + cap: Memory[None] = Memory(store=store, max_instructions_memories=5) + text = cap.build_instructions(self._make_ctx()) + assert '... and 20 more' in text + + def test_instructions_disabled(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + cap: Memory[None] = Memory(store=store, inject_memories_in_instructions=False) + text = cap.build_instructions(self._make_ctx()) + assert 'Currently stored memories' not in text + + def test_instructions_exact_max_no_overflow(self) -> None: + store = InMemoryStore() + for i in range(5): + store.put(MemoryEntry(key=f'k{i}', content=f'v{i}')) + cap: Memory[None] = Memory(store=store, max_instructions_memories=5) + text = cap.build_instructions(self._make_ctx()) + assert '... 
and' not in text + assert '[k0]' in text + assert '[k4]' in text + + +# --- MemoryStore protocol --- + + +class TestMemoryStoreProtocol: + def test_in_memory_store_satisfies_protocol(self) -> None: + assert isinstance(InMemoryStore(), MemoryStore) + + def test_file_store_satisfies_protocol(self, tmp_path: Path) -> None: + assert isinstance(FileStore(tmp_path / 'mem.json'), MemoryStore) + + +# --- AbstractCapability conformance --- + + +class TestAbstractCapabilityConformance: + def test_is_abstract_capability_subclass(self) -> None: + from pydantic_ai.capabilities.abstract import AbstractCapability + + assert issubclass(Memory, AbstractCapability) + + def test_instance_is_abstract_capability(self) -> None: + from pydantic_ai.capabilities.abstract import AbstractCapability + + assert isinstance(Memory(), AbstractCapability) diff --git a/uv.lock b/uv.lock index 0730281..6178ac2 100644 --- a/uv.lock +++ b/uv.lock @@ -319,15 +319,15 @@ wheels = [ [[package]] name = "opentelemetry-api" -version = "1.40.0" +version = "1.39.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "importlib-metadata" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2c/1d/4049a9e8698361cc1a1aa03a6c59e4fa4c71e0c0f94a30f988a6876a2ae6/opentelemetry_api-1.40.0.tar.gz", hash = "sha256:159be641c0b04d11e9ecd576906462773eb97ae1b657730f0ecf64d32071569f", size = 70851, upload-time = "2026-03-04T14:17:21.555Z" } +sdist = { url = "https://files.pythonhosted.org/packages/97/b9/3161be15bb8e3ad01be8be5a968a9237c3027c5be504362ff800fca3e442/opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c", size = 65767, upload-time = "2025-12-11T13:32:39.182Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/bf/93795954016c522008da367da292adceed71cca6ee1717e1d64c83089099/opentelemetry_api-1.40.0-py3-none-any.whl", hash = 
"sha256:82dd69331ae74b06f6a874704be0cfaa49a1650e1537d4a813b86ecef7d0ecf9", size = 68676, upload-time = "2026-03-04T14:17:01.24Z" }, + { url = "https://files.pythonhosted.org/packages/cf/df/d3f1ddf4bb4cb50ed9b1139cc7b1c54c34a1e7ce8fd1b9a37c0d1551a6bd/opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950", size = 66356, upload-time = "2025-12-11T13:32:17.304Z" }, ] [[package]]