From 99d16524093b6c5d7667c5845b13be6f785868c6 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 2 Apr 2026 04:44:02 +0000 Subject: [PATCH 01/10] Add Memory capability with pluggable storage backends Implements a Memory capability (AbstractCapability subclass) for persistent key-value memory across agent sessions, addressing #30. - MemoryStore protocol with InMemoryStore (dict-based, for testing) and FileStore (JSON file on disk, for persistence) backends - Five tools via get_toolset(): save_memory, recall_memory, search_memories, list_memories, delete_memory - Dynamic instructions via get_instructions() that inject stored memories into the system prompt at run start - Substring-based search across keys, content, and tags - Spec serialization support (Memory.from_spec with backend="memory"|"file") - 48 tests covering all code paths, passing lint, format, and typecheck Co-Authored-By: Claude Opus 4.6 (1M context) --- PLAN.md | 51 ++++ src/pydantic_harness/__init__.py | 10 +- src/pydantic_harness/memory.py | 323 ++++++++++++++++++++++++++ tests/test_memory.py | 387 +++++++++++++++++++++++++++++++ 4 files changed, 770 insertions(+), 1 deletion(-) create mode 100644 PLAN.md create mode 100644 src/pydantic_harness/memory.py create mode 100644 tests/test_memory.py diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..ca533d8 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,51 @@ +# Memory Capability + +## Summary + +Implements a `Memory` capability (`AbstractCapability` subclass) that provides persistent key-value memory across agent sessions, referencing issues #30 and #31. + +## Design + +### Architecture + +- **`Memory`** dataclass extends `AbstractCapability[AgentDepsT]` + - `get_instructions()` returns a dynamic callable that injects stored memories into the system prompt at run start + - `get_toolset()` returns a `FunctionToolset` with five tools: `save_memory`, `recall_memory`, `search_memories`, `list_memories`, `delete_memory` + - Tool functions use closures over `self.store` (no dependency on agent `deps`) + +### Storage + +- **`MemoryStore`** protocol: pluggable backend with `get`, `put`, `delete`, `list_all`, `search` +- **`InMemoryStore`**: dict-based, ephemeral, for testing (default) +- **`FileStore`**: JSON file on disk, reads on init, writes on every mutation + +### Memory Model + +- **`MemoryEntry`** dataclass: `key`, `content`, `tags` (list[str]), `created_at`, `updated_at` +- Search is substring-based (case-insensitive) across key, content, and tags + +### Spec Serialization + +- `Memory.get_serialization_name()` returns `"Memory"` +- `Memory.from_spec(backend="file", path="...")` creates a `FileStore`-backed instance + +## Configuration + +| Field | Default | Description | +|-------|---------|-------------| +| `store` | `InMemoryStore()` | Storage backend | +| `inject_memories_in_instructions` | `True` | Include memories in system prompt | +| `max_instructions_memories` | `20` | Cap on memories injected into prompt | + +## Files + +- `src/pydantic_harness/memory.py` - Capability, stores, entry model +- `src/pydantic_harness/__init__.py` - Re-exports +- `tests/test_memory.py` - 48 tests covering all code paths + +## Future Work + +- Semantic/vector search backend (e.g. embedding-based `MemoryStore`) +- TTL / expiration on entries +- Session-scoped memory isolation via `for_run()` +- SQLite / Redis backends for production persistence diff --git a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py index 9d728b6..62831af 100644 --- a/src/pydantic_harness/__init__.py +++ b/src/pydantic_harness/__init__.py @@ -7,4 +7,12 @@ # Each capability module is imported and re-exported here. # Capabilities are listed alphabetically. -__all__: list[str] = [] +from pydantic_harness.memory import FileStore, InMemoryStore, Memory, MemoryEntry, MemoryStore + +__all__: list[str] = [ + 'FileStore', + 'InMemoryStore', + 'Memory', + 'MemoryEntry', + 'MemoryStore', +] diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py new file mode 100644 index 0000000..8e8db26 --- /dev/null +++ b/src/pydantic_harness/memory.py @@ -0,0 +1,323 @@ +"""Memory capability for persistent agent memory across sessions. + +Provides tools for saving, recalling, searching, listing, and deleting +key-value memories, with pluggable storage backends (`InMemoryStore` for +testing, `FileStore` for on-disk persistence). +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Protocol, runtime_checkable + +from pydantic_ai._instructions import AgentInstructions +from pydantic_ai.capabilities.abstract import AbstractCapability +from pydantic_ai.tools import AgentDepsT, RunContext, Tool +from pydantic_ai.toolsets import AgentToolset +from pydantic_ai.toolsets.function import FunctionToolset + + +@dataclass +class MemoryEntry: + """A single memory entry with content, tags, and timestamps.""" + + key: str + """Unique identifier for this memory.""" + + content: str + """The content of the memory.""" + + tags: list[str] = field(default_factory=list[str]) + """Optional tags for categorization and search.""" + + created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + """ISO 8601 timestamp of when the memory was first created.""" + + updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + """ISO 8601 timestamp of the last update.""" + + def to_dict(self) -> dict[str, Any]: + """Serialize to a plain dict for JSON storage.""" + return { + 'key': self.key, + 'content': self.content, + 'tags': self.tags, + 'created_at': self.created_at, + 'updated_at': self.updated_at, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> MemoryEntry: + """Deserialize from a plain dict.""" + return cls( + key=data['key'], + content=data['content'], + tags=data.get('tags', []), + created_at=data.get('created_at', ''), + updated_at=data.get('updated_at', ''), + ) + + +@runtime_checkable +class MemoryStore(Protocol): + """Protocol for pluggable memory storage backends.""" + + def get(self, key: str) -> MemoryEntry | None: + """Retrieve a memory entry by key, or None if not found.""" + ... + + def put(self, entry: MemoryEntry) -> None: + """Store or update a memory entry.""" + ... + + def delete(self, key: str) -> bool: + """Delete a memory entry by key. Returns True if it existed.""" + ... + + def list_all(self) -> list[MemoryEntry]: + """Return all stored memory entries.""" + ... + + def search(self, query: str) -> list[MemoryEntry]: + """Search entries by substring match on key, content, or tags.""" + ... + + +class InMemoryStore: + """Dict-based in-memory store, suitable for testing. + + All data lives in a plain ``dict`` and is lost when the process exits. + """ + + def __init__(self) -> None: + """Initialize an empty in-memory store.""" + self._entries: dict[str, MemoryEntry] = {} + + def get(self, key: str) -> MemoryEntry | None: + """Retrieve a memory entry by key.""" + return self._entries.get(key) + + def put(self, entry: MemoryEntry) -> None: + """Store or update a memory entry.""" + self._entries[entry.key] = entry + + def delete(self, key: str) -> bool: + """Delete a memory entry by key.""" + return self._entries.pop(key, None) is not None + + def list_all(self) -> list[MemoryEntry]: + """Return all stored memory entries.""" + return list(self._entries.values()) + + def search(self, query: str) -> list[MemoryEntry]: + """Search entries by substring match on key, content, or tags.""" + q = query.lower() + return [ + entry + for entry in self._entries.values() + if q in entry.key.lower() or q in entry.content.lower() or any(q in tag.lower() for tag in entry.tags) + ] + + +class FileStore: + """JSON-file-based store for simple on-disk persistence. + + Reads the file on initialization and writes back on every mutation. + """ + + def __init__(self, path: str | Path) -> None: + """Initialize a file-backed store at the given path.""" + self._path = Path(path) + self._entries: dict[str, MemoryEntry] = {} + self._load() + + def _load(self) -> None: + if self._path.exists(): + raw: dict[str, Any] = json.loads(self._path.read_text(encoding='utf-8')) + self._entries = {key: MemoryEntry.from_dict(val) for key, val in raw.items()} + + def _save(self) -> None: + self._path.parent.mkdir(parents=True, exist_ok=True) + data = {key: entry.to_dict() for key, entry in self._entries.items()} + self._path.write_text(json.dumps(data, indent=2), encoding='utf-8') + + def get(self, key: str) -> MemoryEntry | None: + """Retrieve a memory entry by key.""" + return self._entries.get(key) + + def put(self, entry: MemoryEntry) -> None: + """Store or update a memory entry.""" + self._entries[entry.key] = entry + self._save() + + def delete(self, key: str) -> bool: + """Delete a memory entry by key.""" + existed = self._entries.pop(key, None) is not None + if existed: + self._save() + return existed + + def list_all(self) -> list[MemoryEntry]: + """Return all stored memory entries.""" + return list(self._entries.values()) + + def search(self, query: str) -> list[MemoryEntry]: + """Search entries by substring match on key, content, or tags.""" + q = query.lower() + return [ + entry + for entry in self._entries.values() + if q in entry.key.lower() or q in entry.content.lower() or any(q in tag.lower() for tag in entry.tags) + ] + + +def format_entry(entry: MemoryEntry) -> str: + """Format a memory entry as a human-readable string.""" + line = f'[{entry.key}] {entry.content}' + if entry.tags: + line += f' (tags: {", ".join(entry.tags)})' + return line + + +@dataclass +class Memory(AbstractCapability[AgentDepsT]): + """Capability for persistent memory across agent sessions. + + Provides tools for saving, recalling, searching, listing, and deleting + key-value memories. Uses a pluggable `MemoryStore` backend for storage. + + Example: + ```python {test="skip" lint="skip"} + from pydantic_ai import Agent + from pydantic_harness.memory import Memory, InMemoryStore + + agent = Agent('openai:gpt-4o', capabilities=[Memory(store=InMemoryStore())]) + ``` + """ + + store: MemoryStore = field(default_factory=InMemoryStore) + """The storage backend. Defaults to `InMemoryStore` (ephemeral, dict-based).""" + + inject_memories_in_instructions: bool = True + """Whether to inject existing memories into the system prompt at run start.""" + + max_instructions_memories: int = 20 + """Maximum number of memories to include in the system prompt.""" + + @classmethod + def get_serialization_name(cls) -> str | None: + """Return the name used for spec serialization.""" + return 'Memory' + + @classmethod + def from_spec(cls, *args: Any, **kwargs: Any) -> Memory[Any]: + """Create from spec arguments. + + Supports `backend` kwarg: ``"memory"`` (default) or ``"file"`` (requires `path`). + """ + backend = kwargs.pop('backend', 'memory') + if backend == 'file': + path = kwargs.pop('path', '.memories.json') + return cls(store=FileStore(path), **kwargs) + return cls(store=InMemoryStore(), **kwargs) + + def build_instructions(self, ctx: RunContext[AgentDepsT]) -> str: + """Build dynamic instructions that include currently stored memories.""" + parts: list[str] = [ + 'You have access to a persistent memory system. ' + 'Use it to save important information that should be remembered across conversations.', + ] + if self.inject_memories_in_instructions: + entries = self.store.list_all() + if entries: + parts.append('\nCurrently stored memories:') + for entry in entries[: self.max_instructions_memories]: + parts.append(f'- {format_entry(entry)}') + overflow = len(entries) - self.max_instructions_memories + if overflow > 0: + parts.append(f'... and {overflow} more (use list_memories or search_memories to see all).') + return '\n'.join(parts) + + def get_instructions(self) -> AgentInstructions[AgentDepsT] | None: + """Return dynamic instructions that include stored memories.""" + return self.build_instructions + + def get_toolset(self) -> AgentToolset[AgentDepsT] | None: + """Return a toolset with memory management tools. + + Tool functions close over ``self`` to access the store without + requiring anything from the agent's ``deps``. + """ + store = self.store + + def save_memory(key: str, content: str, tags: list[str] | None = None) -> str: + """Save or update a memory entry. + + Args: + key: Unique key for this memory. + content: The content to remember. + tags: Optional tags for categorization and search. + """ + now = datetime.now(timezone.utc).isoformat() + existing = store.get(key) + entry = MemoryEntry( + key=key, + content=content, + tags=tags or [], + created_at=existing.created_at if existing else now, + updated_at=now, + ) + store.put(entry) + return f'Memory saved: {key}' + + def recall_memory(key: str) -> str: + """Recall a specific memory by its key. + + Args: + key: The key of the memory to recall. + """ + entry = store.get(key) + if entry is None: + return f'No memory found for key: {key}' + return format_entry(entry) + + def search_memories(query: str) -> str: + """Search memories by substring match on keys, content, or tags. + + Args: + query: The search query string. + """ + results = store.search(query) + if not results: + return f'No memories found matching: {query}' + return '\n'.join(format_entry(entry) for entry in results) + + def list_memories() -> str: + """List all stored memories.""" + entries = store.list_all() + if not entries: + return 'No memories stored.' + return '\n'.join(format_entry(entry) for entry in entries) + + def delete_memory(key: str) -> str: + """Delete a memory by its key. + + Args: + key: The key of the memory to delete. + """ + if store.delete(key): + return f'Memory deleted: {key}' + return f'No memory found for key: {key}' + + return FunctionToolset( + [ + Tool(save_memory, takes_ctx=False), + Tool(recall_memory, takes_ctx=False), + Tool(search_memories, takes_ctx=False), + Tool(list_memories, takes_ctx=False), + Tool(delete_memory, takes_ctx=False), + ], + ) diff --git a/tests/test_memory.py b/tests/test_memory.py new file mode 100644 index 0000000..890bf00 --- /dev/null +++ b/tests/test_memory.py @@ -0,0 +1,387 @@ +"""Tests for the Memory capability.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from pydantic_ai._run_context import RunContext +from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.usage import RunUsage + +from pydantic_harness.memory import ( + FileStore, + InMemoryStore, + Memory, + MemoryEntry, + MemoryStore, + format_entry, +) + +# --- MemoryEntry --- + + +class TestMemoryEntry: + def test_round_trip(self) -> None: + entry = MemoryEntry(key='k', content='v', tags=['a', 'b'], created_at='t1', updated_at='t2') + assert MemoryEntry.from_dict(entry.to_dict()) == entry + + def test_from_dict_defaults(self) -> None: + entry = MemoryEntry.from_dict({'key': 'k', 'content': 'v'}) + assert entry.tags == [] + assert entry.created_at == '' + assert entry.updated_at == '' + + def test_default_timestamps(self) -> None: + entry = MemoryEntry(key='k', content='v') + assert entry.created_at # non-empty ISO string + assert entry.updated_at + + +# --- InMemoryStore --- + + +class TestInMemoryStore: + def test_put_and_get(self) -> None: + store = InMemoryStore() + entry = MemoryEntry(key='greeting', content='hello') + store.put(entry) + assert store.get('greeting') is entry + + def test_get_missing(self) -> None: + store = InMemoryStore() + assert store.get('nope') is None + + def test_put_overwrites(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v1')) + store.put(MemoryEntry(key='k', content='v2')) + result = store.get('k') + assert result is not None + assert result.content == 'v2' + + def test_delete_existing(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + assert store.delete('k') is True + assert store.get('k') is None + + def test_delete_missing(self) -> None: + store = InMemoryStore() + assert store.delete('nope') is False + + def test_list_all_empty(self) -> None: + store = InMemoryStore() + assert store.list_all() == [] + + def test_list_all(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='a', content='alpha')) + store.put(MemoryEntry(key='b', content='beta')) + entries = store.list_all() + assert len(entries) == 2 + assert {e.key for e in entries} == {'a', 'b'} + + def test_search_by_key(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='user_name', content='Alice')) + store.put(MemoryEntry(key='color', content='blue')) + results = store.search('user') + assert len(results) == 1 + assert results[0].key == 'user_name' + + def test_search_by_content(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k1', content='the quick brown fox')) + store.put(MemoryEntry(key='k2', content='lazy dog')) + results = store.search('fox') + assert len(results) == 1 + assert results[0].key == 'k1' + + def test_search_by_tag(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k1', content='x', tags=['important'])) + store.put(MemoryEntry(key='k2', content='y', tags=['trivial'])) + results = store.search('important') + assert len(results) == 1 + assert results[0].key == 'k1' + + def test_search_case_insensitive(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='K1', content='Hello World')) + results = store.search('hello') + assert len(results) == 1 + + def test_search_no_results(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + assert store.search('zzz') == [] + + +# --- FileStore --- + + +class TestFileStore: + def test_put_and_get(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v')) + assert store.get('k') is not None + assert store.get('k').content == 'v' # type: ignore[union-attr] + + def test_persistence(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store1 = FileStore(path) + store1.put(MemoryEntry(key='k', content='persisted')) + + # New store instance should load from disk + store2 = FileStore(path) + result = store2.get('k') + assert result is not None + assert result.content == 'persisted' + + def test_delete_saves(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v')) + store.delete('k') + + # Reload and verify deletion persisted + store2 = FileStore(path) + assert store2.get('k') is None + + def test_list_all(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='a', content='alpha')) + store.put(MemoryEntry(key='b', content='beta')) + assert len(store.list_all()) == 2 + + def test_search(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k1', content='hello', tags=['greeting'])) + store.put(MemoryEntry(key='k2', content='world')) + assert len(store.search('greeting')) == 1 + assert len(store.search('hello')) == 1 + + def test_empty_file(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + # File does not exist yet + store = FileStore(path) + assert store.list_all() == [] + + def test_creates_parent_dirs(self, tmp_path: Path) -> None: + path = tmp_path / 'sub' / 'dir' / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v')) + assert path.exists() + + def test_file_format(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v', tags=['t'], created_at='c', updated_at='u')) + raw = json.loads(path.read_text()) + assert raw == { + 'k': { + 'key': 'k', + 'content': 'v', + 'tags': ['t'], + 'created_at': 'c', + 'updated_at': 'u', + } + } + + +# --- format_entry --- + + +class TestFormatEntry: + def test_no_tags(self) -> None: + entry = MemoryEntry(key='k', content='hello') + assert format_entry(entry) == '[k] hello' + + def test_with_tags(self) -> None: + entry = MemoryEntry(key='k', content='hello', tags=['a', 'b']) + assert format_entry(entry) == '[k] hello (tags: a, b)' + + +# --- Memory capability --- + + +class TestMemoryCapability: + def test_serialization_name(self) -> None: + assert Memory.get_serialization_name() == 'Memory' + + def test_from_spec_default(self) -> None: + cap = Memory.from_spec() + assert isinstance(cap.store, InMemoryStore) + + def test_from_spec_file(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + cap = Memory.from_spec(backend='file', path=str(path)) + assert isinstance(cap.store, FileStore) + + def test_default_store(self) -> None: + cap: Memory[None] = Memory() + assert isinstance(cap.store, InMemoryStore) + + def test_get_toolset_returns_function_toolset(self) -> None: + cap: Memory[None] = Memory() + toolset = cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + + def test_toolset_has_expected_tools(self) -> None: + cap: Memory[None] = Memory() + toolset = cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + tool_names = set(toolset.tools.keys()) + assert tool_names == {'save_memory', 'recall_memory', 'search_memories', 'list_memories', 'delete_memory'} + + +# --- Tool functions (via closure) --- + + +class TestMemoryTools: + """Test the tool functions exposed by the Memory capability.""" + + @staticmethod + def _get_tools(store: InMemoryStore | None = None) -> dict[str, Any]: + cap: Memory[None] = Memory(store=store or InMemoryStore()) + toolset = cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + return {name: tool.function for name, tool in toolset.tools.items()} + + def test_save_and_recall(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + result = tools['save_memory']('greeting', 'hello world') + assert result == 'Memory saved: greeting' + + recalled = tools['recall_memory']('greeting') + assert '[greeting] hello world' in recalled + + def test_recall_missing(self) -> None: + tools = self._get_tools() + assert 'No memory found' in tools['recall_memory']('nope') + + def test_save_updates_existing(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v1') + original = store.get('k') + assert original is not None + original_created = original.created_at + + tools['save_memory']('k', 'v2') + updated = store.get('k') + assert updated is not None + assert updated.content == 'v2' + # created_at should be preserved + assert updated.created_at == original_created + + def test_save_with_tags(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', ['tag1', 'tag2']) + entry = store.get('k') + assert entry is not None + assert entry.tags == ['tag1', 'tag2'] + + def test_search(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('user_name', 'Alice') + tools['save_memory']('color', 'blue') + + result = tools['search_memories']('alice') + assert 'Alice' in result + assert 'blue' not in result + + def test_search_no_results(self) -> None: + tools = self._get_tools() + assert 'No memories found' in tools['search_memories']('zzz') + + def test_list_empty(self) -> None: + tools = self._get_tools() + assert tools['list_memories']() == 'No memories stored.' + + def test_list_with_entries(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('a', 'alpha') + tools['save_memory']('b', 'beta') + result = tools['list_memories']() + assert '[a] alpha' in result + assert '[b] beta' in result + + def test_delete_existing(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v') + assert tools['delete_memory']('k') == 'Memory deleted: k' + assert store.get('k') is None + + def test_delete_missing(self) -> None: + tools = self._get_tools() + assert 'No memory found' in tools['delete_memory']('nope') + + +# --- Instructions --- + + +class TestMemoryInstructions: + @staticmethod + def _make_ctx() -> RunContext[None]: + from unittest.mock import MagicMock + + return RunContext( + deps=None, + model=MagicMock(), + usage=RunUsage(), + ) + + def test_get_instructions_is_callable(self) -> None: + cap: Memory[None] = Memory() + assert callable(cap.get_instructions()) + + def test_instructions_with_no_memories(self) -> None: + cap: Memory[None] = Memory() + text = cap.build_instructions(self._make_ctx()) + assert 'persistent memory system' in text + assert 'Currently stored memories' not in text + + def test_instructions_with_memories(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='user', content='Alice')) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._make_ctx()) + assert 'Currently stored memories' in text + assert '[user] Alice' in text + + def test_instructions_respects_max(self) -> None: + store = InMemoryStore() + for i in range(25): + store.put(MemoryEntry(key=f'k{i}', content=f'v{i}')) + cap: Memory[None] = Memory(store=store, max_instructions_memories=5) + text = cap.build_instructions(self._make_ctx()) + assert '... and 20 more' in text + + def test_instructions_disabled(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + cap: Memory[None] = Memory(store=store, inject_memories_in_instructions=False) + text = cap.build_instructions(self._make_ctx()) + assert 'Currently stored memories' not in text + + +# --- MemoryStore protocol --- + + +class TestMemoryStoreProtocol: + def test_in_memory_store_satisfies_protocol(self) -> None: + assert isinstance(InMemoryStore(), MemoryStore) + + def test_file_store_satisfies_protocol(self, tmp_path: Path) -> None: + assert isinstance(FileStore(tmp_path / 'mem.json'), MemoryStore) From 6feffca8435b70896d7d8877c3dd160787d38479 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 2 Apr 2026 05:50:03 +0000 Subject: [PATCH 02/10] Improve Memory capability: better search, scoping, TTL, dedup warning Address audit findings from PR review: - Better search: word-boundary matching with relevance scoring (count of matching words across key/content/tags, sorted by score descending). Underscores and hyphens treated as word separators. - Memory scoping: `scope: str = 'global'` field on MemoryEntry, with optional `scope` parameter on `search_memories` and `list_memories` tools and `list_all`/`search` store methods. - TTL/expiration: `expires_at: str | None = None` on MemoryEntry with `is_expired()` method. Stores filter out expired entries automatically. `save_memory` tool accepts optional `ttl_minutes` parameter. - Dedup warning: when saving a memory whose key is very similar to an existing key (same 10-char prefix, Levenshtein distance <= 2), log a warning via the `pydantic_harness.memory` logger. Tests: 48 -> 99, all passing with 100% coverage. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pydantic_harness/memory.py | 206 +++++++++++++++++---- tests/test_memory.py | 314 ++++++++++++++++++++++++++++++++- 2 files changed, 482 insertions(+), 38 deletions(-) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index 8e8db26..5649265 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -8,8 +8,10 @@ from __future__ import annotations import json +import logging +import re from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Any, Protocol, runtime_checkable @@ -19,6 +21,8 @@ from pydantic_ai.toolsets import AgentToolset from pydantic_ai.toolsets.function import FunctionToolset +logger = logging.getLogger(__name__) + @dataclass class MemoryEntry: @@ -33,18 +37,32 @@ class MemoryEntry: tags: list[str] = field(default_factory=list[str]) """Optional tags for categorization and search.""" + scope: str = 'global' + """Namespace scope for this memory (default ``'global'``).""" + + expires_at: str | None = None + """Optional ISO 8601 expiration timestamp. ``None`` means no expiry.""" + created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) """ISO 8601 timestamp of when the memory was first created.""" updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) """ISO 8601 timestamp of the last update.""" + def is_expired(self) -> bool: + """Return True if this entry has passed its expiration time.""" + if self.expires_at is None: + return False + return datetime.fromisoformat(self.expires_at) <= datetime.now(timezone.utc) + def to_dict(self) -> dict[str, Any]: """Serialize to a plain dict for JSON storage.""" return { 'key': self.key, 'content': self.content, 'tags': self.tags, + 'scope': self.scope, + 'expires_at': self.expires_at, 'created_at': self.created_at, 'updated_at': self.updated_at, } @@ -56,33 +74,86 @@ def from_dict(cls, data: dict[str, Any]) -> MemoryEntry: key=data['key'], content=data['content'], tags=data.get('tags', []), + scope=data.get('scope', 'global'), + expires_at=data.get('expires_at'), created_at=data.get('created_at', ''), updated_at=data.get('updated_at', ''), ) +def _score_entry(entry: MemoryEntry, words: list[str]) -> int: + r"""Score a memory entry by counting word-boundary matches across fields. + + Each query word that appears as a whole word (case-insensitive) in the + key, content, or any tag contributes one point per field it appears in. + Underscores and hyphens are treated as word separators in addition to + the standard ``\b`` boundaries. + """ + score = 0 + for word in words: + # Use a boundary pattern that also treats _ and - as separators. + escaped = re.escape(word) + pattern = re.compile(rf'(? bool: + """Return True if two keys share the same first 10 characters and differ only slightly. + + Uses a simple character-level edit distance check: keys are considered + similar when they share the same 10-char prefix and differ by at most 2 + characters (Levenshtein-like). + """ + if len(a) < 10 or len(b) < 10: + return False + if a[:10] != b[:10]: + return False + if a == b: + return False + # Simple Levenshtein-like check: allow at most 2 edits + if abs(len(a) - len(b)) > 2: + return False + # Bounded character-level distance (sufficient for dedup warnings) + max_edits = 2 + m, n = len(a), len(b) + prev = list(range(n + 1)) + for i in range(1, m + 1): + curr = [i] + [0] * n + for j in range(1, n + 1): + cost = 0 if a[i - 1] == b[j - 1] else 1 + curr[j] = min(curr[j - 1] + 1, prev[j] + 1, prev[j - 1] + cost) + prev = curr + return prev[n] <= max_edits + + @runtime_checkable class MemoryStore(Protocol): """Protocol for pluggable memory storage backends.""" - def get(self, key: str) -> MemoryEntry | None: + def get(self, key: str) -> MemoryEntry | None: # pragma: no cover """Retrieve a memory entry by key, or None if not found.""" ... - def put(self, entry: MemoryEntry) -> None: + def put(self, entry: MemoryEntry) -> None: # pragma: no cover """Store or update a memory entry.""" ... - def delete(self, key: str) -> bool: + def delete(self, key: str) -> bool: # pragma: no cover """Delete a memory entry by key. Returns True if it existed.""" ... - def list_all(self) -> list[MemoryEntry]: - """Return all stored memory entries.""" + def list_all(self, *, scope: str | None = None) -> list[MemoryEntry]: # pragma: no cover + """Return all non-expired entries, optionally filtered by scope.""" ... - def search(self, query: str) -> list[MemoryEntry]: - """Search entries by substring match on key, content, or tags.""" + def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: # pragma: no cover + """Search non-expired entries with word-boundary matching, sorted by relevance.""" ... @@ -108,19 +179,31 @@ def delete(self, key: str) -> bool: """Delete a memory entry by key.""" return self._entries.pop(key, None) is not None - def list_all(self) -> list[MemoryEntry]: - """Return all stored memory entries.""" - return list(self._entries.values()) - - def search(self, query: str) -> list[MemoryEntry]: - """Search entries by substring match on key, content, or tags.""" - q = query.lower() + def list_all(self, *, scope: str | None = None) -> list[MemoryEntry]: + """Return all non-expired entries, optionally filtered by scope.""" return [ entry for entry in self._entries.values() - if q in entry.key.lower() or q in entry.content.lower() or any(q in tag.lower() for tag in entry.tags) + if not entry.is_expired() and (scope is None or entry.scope == scope) ] + def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: + """Search non-expired entries with word-boundary matching, sorted by relevance.""" + words = query.lower().split() + if not words: + return [] + scored: list[tuple[int, MemoryEntry]] = [] + for entry in self._entries.values(): + if entry.is_expired(): + continue + if scope is not None and entry.scope != scope: + continue + score = _score_entry(entry, words) + if score > 0: + scored.append((score, entry)) + scored.sort(key=lambda pair: pair[0], reverse=True) + return [entry for _, entry in scored] + class FileStore: """JSON-file-based store for simple on-disk persistence. @@ -160,25 +243,44 @@ def delete(self, key: str) -> bool: self._save() return existed - def list_all(self) -> list[MemoryEntry]: - """Return all stored memory entries.""" - return list(self._entries.values()) - - def search(self, query: str) -> list[MemoryEntry]: - """Search entries by substring match on key, content, or tags.""" - q = query.lower() + def list_all(self, *, scope: str | None = None) -> list[MemoryEntry]: + """Return all non-expired entries, optionally filtered by scope.""" return [ entry for entry in self._entries.values() - if q in entry.key.lower() or q in entry.content.lower() or any(q in tag.lower() for tag in entry.tags) + if not entry.is_expired() and (scope is None or entry.scope == scope) ] + def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: + """Search non-expired entries with word-boundary matching, sorted by relevance.""" + words = query.lower().split() + if not words: + return [] + scored: list[tuple[int, MemoryEntry]] = [] + for entry in self._entries.values(): + if entry.is_expired(): + continue + if scope is not None and entry.scope != scope: + continue + score = _score_entry(entry, words) + if score > 0: + scored.append((score, entry)) + scored.sort(key=lambda pair: pair[0], reverse=True) + return [entry for _, entry in scored] + def format_entry(entry: MemoryEntry) -> str: """Format a memory entry as a human-readable string.""" line = f'[{entry.key}] {entry.content}' + extras: list[str] = [] if entry.tags: - line += f' (tags: {", ".join(entry.tags)})' + extras.append(f'tags: {", ".join(entry.tags)}') + if entry.scope != 'global': + extras.append(f'scope: {entry.scope}') + if entry.expires_at is not None: + extras.append(f'expires: {entry.expires_at}') + if extras: + line += f' ({"; ".join(extras)})' return line @@ -253,22 +355,47 @@ def get_toolset(self) -> AgentToolset[AgentDepsT] | None: """ store = self.store - def save_memory(key: str, content: str, tags: list[str] | None = None) -> str: + def save_memory( + key: str, + content: str, + tags: list[str] | None = None, + scope: str = 'global', + ttl_minutes: int | None = None, + ) -> str: """Save or update a memory entry. Args: key: Unique key for this memory. content: The content to remember. tags: Optional tags for categorization and search. + scope: Namespace scope (default ``'global'``). + ttl_minutes: Optional time-to-live in minutes. The entry will expire after this duration. """ - now = datetime.now(timezone.utc).isoformat() + now = datetime.now(timezone.utc) + now_iso = now.isoformat() existing = store.get(key) + + # Dedup warning: check for similar keys among existing entries + for existing_entry in store.list_all(): + if _simple_similarity(key, existing_entry.key): + logger.warning( + 'New memory key %r is very similar to existing key %r — possible duplicate', + key, + existing_entry.key, + ) + + expires_at: str | None = None + if ttl_minutes is not None: + expires_at = (now + timedelta(minutes=ttl_minutes)).isoformat() + entry = MemoryEntry( key=key, content=content, tags=tags or [], - created_at=existing.created_at if existing else now, - updated_at=now, + scope=scope, + expires_at=expires_at, + created_at=existing.created_at if existing else now_iso, + updated_at=now_iso, ) store.put(entry) return f'Memory saved: {key}' @@ -282,22 +409,29 @@ def recall_memory(key: str) -> str: entry = store.get(key) if entry is None: return f'No memory found for key: {key}' + if entry.is_expired(): + return f'No memory found for key: {key}' return format_entry(entry) - def search_memories(query: str) -> str: - """Search memories by substring match on keys, content, or tags. + def search_memories(query: str, scope: str | None = None) -> str: + """Search memories by word-boundary matching on keys, content, or tags, sorted by relevance. Args: - query: The search query string. + query: The search query string (space-separated words). + scope: Optional scope to restrict the search to. """ - results = store.search(query) + results = store.search(query, scope=scope) if not results: return f'No memories found matching: {query}' return '\n'.join(format_entry(entry) for entry in results) - def list_memories() -> str: - """List all stored memories.""" - entries = store.list_all() + def list_memories(scope: str | None = None) -> str: + """List all stored memories, optionally filtered by scope. + + Args: + scope: Optional scope to filter by. + """ + entries = store.list_all(scope=scope) if not entries: return 'No memories stored.' return '\n'.join(format_entry(entry) for entry in entries) diff --git a/tests/test_memory.py b/tests/test_memory.py index 890bf00..6aa9f1e 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -3,6 +3,8 @@ from __future__ import annotations import json +import logging +from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Any @@ -16,6 +18,8 @@ Memory, MemoryEntry, MemoryStore, + _score_entry, + _simple_similarity, format_entry, ) @@ -24,12 +28,22 @@ class TestMemoryEntry: def test_round_trip(self) -> None: - entry = MemoryEntry(key='k', content='v', tags=['a', 'b'], created_at='t1', updated_at='t2') + entry = MemoryEntry( + key='k', + content='v', + tags=['a', 'b'], + scope='project', + expires_at='2099-01-01T00:00:00+00:00', + created_at='t1', + updated_at='t2', + ) assert MemoryEntry.from_dict(entry.to_dict()) == entry def test_from_dict_defaults(self) -> None: entry = MemoryEntry.from_dict({'key': 'k', 'content': 'v'}) assert entry.tags == [] + assert entry.scope == 'global' + assert entry.expires_at is None assert entry.created_at == '' assert entry.updated_at == '' @@ -38,6 +52,88 @@ def test_default_timestamps(self) -> None: assert entry.created_at # non-empty ISO string assert entry.updated_at + def test_default_scope(self) -> None: + entry = MemoryEntry(key='k', content='v') + assert entry.scope == 'global' + + def test_is_expired_no_expiry(self) -> None: + entry = MemoryEntry(key='k', content='v') + assert not entry.is_expired() + + def test_is_expired_future(self) -> None: + future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + entry = MemoryEntry(key='k', content='v', expires_at=future) + assert not entry.is_expired() + + def test_is_expired_past(self) -> None: + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + entry = MemoryEntry(key='k', content='v', expires_at=past) + assert entry.is_expired() + + +# --- _score_entry --- + + +class TestScoreEntry: + def test_no_match(self) -> None: + entry = MemoryEntry(key='greeting', content='hello world') + assert _score_entry(entry, ['zzz']) == 0 + + def test_key_match(self) -> None: + entry = MemoryEntry(key='greeting', content='some text') + assert _score_entry(entry, ['greeting']) == 1 + + def test_content_match(self) -> None: + entry = MemoryEntry(key='k', content='hello world') + assert _score_entry(entry, ['hello']) == 1 + + def test_tag_match(self) -> None: + entry = MemoryEntry(key='k', content='text', tags=['important']) + assert _score_entry(entry, ['important']) == 1 + + def test_multiple_field_match(self) -> None: + entry = MemoryEntry(key='hello', content='hello world', tags=['hello']) + # 'hello' appears in key (1) + content (1) + tags (1) = 3 + assert _score_entry(entry, ['hello']) == 3 + + def test_multiple_words(self) -> None: + entry = MemoryEntry(key='user', content='Alice likes blue') + # 'alice' in content (1), 'blue' in content (1) = 2 + assert _score_entry(entry, ['alice', 'blue']) == 2 + + def test_word_boundary_no_partial(self) -> None: + # 'fox' should NOT match 'foxes' with word-boundary matching + entry = MemoryEntry(key='k', content='foxes jump') + assert _score_entry(entry, ['fox']) == 0 + + +# --- _simple_similarity --- + + +class TestSimpleSimilarity: + def test_identical_keys_not_similar(self) -> None: + assert not _simple_similarity('abcdefghij', 'abcdefghij') + + def test_short_keys_not_similar(self) -> None: + assert not _simple_similarity('abc', 'abd') + + def test_similar_long_keys(self) -> None: + # Differ by 2 characters ('fo' vs 'ba') — within the edit-distance threshold + assert _simple_similarity('abcdefghij_fo', 'abcdefghij_ba') + + def test_different_prefix(self) -> None: + assert not _simple_similarity('xxxxxxxxxxfoo', 'yyyyyyyyyyfoo') + + def test_same_prefix_large_edit(self) -> None: + assert not _simple_similarity('abcdefghijklmnop', 'abcdefghijzzzzzz') + + def test_length_diff_too_large(self) -> None: + # Same 10-char prefix but length differs by more than 2 + assert not _simple_similarity('abcdefghij_x', 'abcdefghij_xyzw') + + def test_one_char_diff(self) -> None: + assert _simple_similarity('abcdefghij_x', 'abcdefghij_y') + # --- InMemoryStore --- @@ -83,6 +179,29 @@ def test_list_all(self) -> None: assert len(entries) == 2 assert {e.key for e in entries} == {'a', 'b'} + def test_list_all_filters_expired(self) -> None: + store = InMemoryStore() + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='fresh')) + store.put(MemoryEntry(key='dead', content='stale', expires_at=past)) + entries = store.list_all() + assert len(entries) == 1 + assert entries[0].key == 'alive' + + def test_list_all_scope_filter(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='a', content='x', scope='project')) + store.put(MemoryEntry(key='b', content='y', scope='global')) + entries = store.list_all(scope='project') + assert len(entries) == 1 + assert entries[0].key == 'a' + + def test_list_all_scope_none_returns_all(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='a', content='x', scope='project')) + store.put(MemoryEntry(key='b', content='y', scope='global')) + assert len(store.list_all(scope=None)) == 2 + def test_search_by_key(self) -> None: store = InMemoryStore() store.put(MemoryEntry(key='user_name', content='Alice')) @@ -118,6 +237,39 @@ def test_search_no_results(self) -> None: store.put(MemoryEntry(key='k', content='v')) assert store.search('zzz') == [] + def test_search_filters_expired(self) -> None: + store = InMemoryStore() + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='hello world')) + store.put(MemoryEntry(key='dead', content='hello world', expires_at=past)) + results = store.search('hello') + assert len(results) == 1 + assert results[0].key == 'alive' + + def test_search_scope_filter(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='a', content='hello world', scope='project')) + store.put(MemoryEntry(key='b', content='hello world', scope='global')) + results = store.search('hello', scope='project') + assert len(results) == 1 + assert results[0].key == 'a' + + def test_search_relevance_ordering(self) -> None: + store = InMemoryStore() + # 'hello' appears in key + content = score 2 + store.put(MemoryEntry(key='hello', content='hello there')) + # 'hello' appears only in content = score 1 + store.put(MemoryEntry(key='other', content='hello world')) + results = store.search('hello') + assert len(results) == 2 + assert results[0].key == 'hello' # higher score first + assert results[1].key == 'other' + + def test_search_empty_query(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + assert store.search('') == [] + # --- FileStore --- @@ -151,6 +303,11 @@ def test_delete_saves(self, tmp_path: Path) -> None: store2 = FileStore(path) assert store2.get('k') is None + def test_delete_missing(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + assert store.delete('nope') is False + def test_list_all(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' store = FileStore(path) @@ -188,11 +345,68 @@ def test_file_format(self, tmp_path: Path) -> None: 'key': 'k', 'content': 'v', 'tags': ['t'], + 'scope': 'global', + 'expires_at': None, 'created_at': 'c', 'updated_at': 'u', } } + def test_list_all_filters_expired(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='x')) + store.put(MemoryEntry(key='dead', content='y', expires_at=past)) + assert len(store.list_all()) == 1 + + def test_search_filters_expired(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='hello world')) + store.put(MemoryEntry(key='dead', content='hello world', expires_at=past)) + assert len(store.search('hello')) == 1 + + def test_list_all_scope(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='a', content='x', scope='project')) + store.put(MemoryEntry(key='b', content='y', scope='global')) + assert len(store.list_all(scope='project')) == 1 + + def test_search_scope(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='a', content='hello world', scope='project')) + store.put(MemoryEntry(key='b', content='hello world', scope='global')) + assert len(store.search('hello', scope='project')) == 1 + + def test_search_empty_query(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v')) + assert store.search('') == [] + + def test_scope_persists(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store1 = FileStore(path) + store1.put(MemoryEntry(key='k', content='v', scope='session')) + store2 = FileStore(path) + entry = store2.get('k') + assert entry is not None + assert entry.scope == 'session' + + def test_expires_at_persists(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + store1 = FileStore(path) + store1.put(MemoryEntry(key='k', content='v', expires_at=future)) + store2 = FileStore(path) + entry = store2.get('k') + assert entry is not None + assert entry.expires_at == future + # --- format_entry --- @@ -206,6 +420,28 @@ def test_with_tags(self) -> None: entry = MemoryEntry(key='k', content='hello', tags=['a', 'b']) assert format_entry(entry) == '[k] hello (tags: a, b)' + def test_with_scope(self) -> None: + entry = MemoryEntry(key='k', content='hello', scope='project') + assert format_entry(entry) == '[k] hello (scope: project)' + + def test_global_scope_omitted(self) -> None: + entry = MemoryEntry(key='k', content='hello', scope='global') + assert format_entry(entry) == '[k] hello' + + def test_with_expires_at(self) -> None: + entry = MemoryEntry(key='k', content='hello', expires_at='2099-01-01T00:00:00+00:00') + assert format_entry(entry) == '[k] hello (expires: 2099-01-01T00:00:00+00:00)' + + def test_all_extras(self) -> None: + entry = MemoryEntry( + key='k', + content='hello', + tags=['t'], + scope='project', + expires_at='2099-01-01T00:00:00+00:00', + ) + assert format_entry(entry) == '[k] hello (tags: t; scope: project; expires: 2099-01-01T00:00:00+00:00)' + # --- Memory capability --- @@ -266,6 +502,13 @@ def test_recall_missing(self) -> None: tools = self._get_tools() assert 'No memory found' in tools['recall_memory']('nope') + def test_recall_expired(self) -> None: + store = InMemoryStore() + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='old', content='stale', expires_at=past)) + tools = self._get_tools(store) + assert 'No memory found' in tools['recall_memory']('old') + def test_save_updates_existing(self) -> None: store = InMemoryStore() tools = self._get_tools(store) @@ -289,13 +532,33 @@ def test_save_with_tags(self) -> None: assert entry is not None assert entry.tags == ['tag1', 'tag2'] + def test_save_with_scope(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, 'project') + entry = store.get('k') + assert entry is not None + assert entry.scope == 'project' + + def test_save_with_ttl(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, 'global', 60) + entry = store.get('k') + assert entry is not None + assert entry.expires_at is not None + expires = datetime.fromisoformat(entry.expires_at) + # Should expire roughly 60 minutes from now + assert expires > datetime.now(timezone.utc) + timedelta(minutes=59) + assert expires < datetime.now(timezone.utc) + timedelta(minutes=61) + def test_search(self) -> None: store = InMemoryStore() tools = self._get_tools(store) tools['save_memory']('user_name', 'Alice') tools['save_memory']('color', 'blue') - result = tools['search_memories']('alice') + result = tools['search_memories']('Alice') assert 'Alice' in result assert 'blue' not in result @@ -303,6 +566,15 @@ def test_search_no_results(self) -> None: tools = self._get_tools() assert 'No memories found' in tools['search_memories']('zzz') + def test_search_with_scope(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('a', 'hello world', None, 'project') + tools['save_memory']('b', 'hello world', None, 'global') + result = tools['search_memories']('hello', 'project') + assert '[a]' in result + assert '[b]' not in result + def test_list_empty(self) -> None: tools = self._get_tools() assert tools['list_memories']() == 'No memories stored.' @@ -316,6 +588,15 @@ def test_list_with_entries(self) -> None: assert '[a] alpha' in result assert '[b] beta' in result + def test_list_with_scope(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('a', 'alpha', None, 'project') + tools['save_memory']('b', 'beta', None, 'global') + result = tools['list_memories']('project') + assert '[a] alpha' in result + assert '[b]' not in result + def test_delete_existing(self) -> None: store = InMemoryStore() tools = self._get_tools(store) @@ -328,6 +609,35 @@ def test_delete_missing(self) -> None: assert 'No memory found' in tools['delete_memory']('nope') +# --- Dedup warning --- + + +class TestDedupWarning: + def test_similar_key_logs_warning(self, caplog: Any) -> None: + store = InMemoryStore() + tools = TestMemoryTools._get_tools(store) + tools['save_memory']('abcdefghij_x', 'first value') + with caplog.at_level(logging.WARNING, logger='pydantic_harness.memory'): + tools['save_memory']('abcdefghij_y', 'second value') + assert any('possible duplicate' in record.message.lower() for record in caplog.records) + + def test_different_keys_no_warning(self, caplog: Any) -> None: + store = InMemoryStore() + tools = TestMemoryTools._get_tools(store) + tools['save_memory']('first_key_long', 'first value') + with caplog.at_level(logging.WARNING, logger='pydantic_harness.memory'): + tools['save_memory']('other_key_long', 'second value') + assert not any('possible duplicate' in record.message.lower() for record in caplog.records) + + def test_short_keys_no_warning(self, caplog: Any) -> None: + store = InMemoryStore() + tools = TestMemoryTools._get_tools(store) + tools['save_memory']('abc', 'first value') + with caplog.at_level(logging.WARNING, logger='pydantic_harness.memory'): + tools['save_memory']('abd', 'second value') + assert not any('possible duplicate' in record.message.lower() for record in caplog.records) + + # --- Instructions --- From d9ce68835d1114464d432d534c070444d6ca0b8e Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 16:55:23 -0500 Subject: [PATCH 03/10] refactor(memory): add MemoryEntryDict TypedDict, eliminate avoidable Any types Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pydantic_harness/memory.py | 47 ++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index 5649265..ee651ba 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -13,7 +13,7 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Any, Protocol, runtime_checkable +from typing import Any, Protocol, TypedDict, runtime_checkable from pydantic_ai._instructions import AgentInstructions from pydantic_ai.capabilities.abstract import AbstractCapability @@ -24,6 +24,27 @@ logger = logging.getLogger(__name__) +class _MemoryEntryDictRequired(TypedDict): + """Required fields for MemoryEntryDict.""" + + key: str + content: str + + +class MemoryEntryDict(_MemoryEntryDictRequired, total=False): + """Serialized form of a MemoryEntry for JSON storage. + + Only `key` and `content` are required; the remaining fields are + optional so that `from_dict` can accept legacy data missing some keys. + """ + + tags: list[str] + scope: str + expires_at: str | None + created_at: str + updated_at: str + + @dataclass class MemoryEntry: """A single memory entry with content, tags, and timestamps.""" @@ -34,14 +55,14 @@ class MemoryEntry: content: str """The content of the memory.""" - tags: list[str] = field(default_factory=list[str]) + tags: list[str] = field(default_factory=lambda: list[str]()) """Optional tags for categorization and search.""" scope: str = 'global' - """Namespace scope for this memory (default ``'global'``).""" + """Namespace scope for this memory (default `'global'`).""" expires_at: str | None = None - """Optional ISO 8601 expiration timestamp. ``None`` means no expiry.""" + """Optional ISO 8601 expiration timestamp. `None` means no expiry.""" created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) """ISO 8601 timestamp of when the memory was first created.""" @@ -55,7 +76,7 @@ def is_expired(self) -> bool: return False return datetime.fromisoformat(self.expires_at) <= datetime.now(timezone.utc) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> MemoryEntryDict: """Serialize to a plain dict for JSON storage.""" return { 'key': self.key, @@ -68,7 +89,7 @@ def to_dict(self) -> dict[str, Any]: } @classmethod - def from_dict(cls, data: dict[str, Any]) -> MemoryEntry: + def from_dict(cls, data: MemoryEntryDict) -> MemoryEntry: """Deserialize from a plain dict.""" return cls( key=data['key'], @@ -87,7 +108,7 @@ def _score_entry(entry: MemoryEntry, words: list[str]) -> int: Each query word that appears as a whole word (case-insensitive) in the key, content, or any tag contributes one point per field it appears in. Underscores and hyphens are treated as word separators in addition to - the standard ``\b`` boundaries. + the standard `\\b` boundaries. """ score = 0 for word in words: @@ -160,7 +181,7 @@ def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: class InMemoryStore: """Dict-based in-memory store, suitable for testing. - All data lives in a plain ``dict`` and is lost when the process exits. + All data lives in a plain `dict` and is lost when the process exits. """ def __init__(self) -> None: @@ -219,7 +240,7 @@ def __init__(self, path: str | Path) -> None: def _load(self) -> None: if self._path.exists(): - raw: dict[str, Any] = json.loads(self._path.read_text(encoding='utf-8')) + raw = json.loads(self._path.read_text(encoding='utf-8')) self._entries = {key: MemoryEntry.from_dict(val) for key, val in raw.items()} def _save(self) -> None: @@ -318,7 +339,7 @@ def get_serialization_name(cls) -> str | None: def from_spec(cls, *args: Any, **kwargs: Any) -> Memory[Any]: """Create from spec arguments. - Supports `backend` kwarg: ``"memory"`` (default) or ``"file"`` (requires `path`). + Supports `backend` kwarg: `"memory"` (default) or `"file"` (requires `path`). """ backend = kwargs.pop('backend', 'memory') if backend == 'file': @@ -350,8 +371,8 @@ def get_instructions(self) -> AgentInstructions[AgentDepsT] | None: def get_toolset(self) -> AgentToolset[AgentDepsT] | None: """Return a toolset with memory management tools. - Tool functions close over ``self`` to access the store without - requiring anything from the agent's ``deps``. + Tool functions close over `self` to access the store without + requiring anything from the agent's `deps`. """ store = self.store @@ -368,7 +389,7 @@ def save_memory( key: Unique key for this memory. content: The content to remember. tags: Optional tags for categorization and search. - scope: Namespace scope (default ``'global'``). + scope: Namespace scope (default `'global'`). ttl_minutes: Optional time-to-live in minutes. The entry will expire after this duration. """ now = datetime.now(timezone.utc) From 63cd254c15ac424b174f3c1fab61724dd4b08958 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 16:58:11 -0500 Subject: [PATCH 04/10] refactor(memory): extract _BaseDictStore to deduplicate InMemoryStore and FileStore Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pydantic_harness/memory.py | 57 ++++++++++------------------------ 1 file changed, 17 insertions(+), 40 deletions(-) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index ee651ba..44c5e6a 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -178,15 +178,10 @@ def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: ... -class InMemoryStore: - """Dict-based in-memory store, suitable for testing. - - All data lives in a plain `dict` and is lost when the process exits. - """ +class _BaseDictStore: + """Base class for dict-backed memory stores.""" - def __init__(self) -> None: - """Initialize an empty in-memory store.""" - self._entries: dict[str, MemoryEntry] = {} + _entries: dict[str, MemoryEntry] def get(self, key: str) -> MemoryEntry | None: """Retrieve a memory entry by key.""" @@ -226,7 +221,18 @@ def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: return [entry for _, entry in scored] -class FileStore: +class InMemoryStore(_BaseDictStore): + """Dict-based in-memory store, suitable for testing. + + All data lives in a plain `dict` and is lost when the process exits. + """ + + def __init__(self) -> None: + """Initialize an empty in-memory store.""" + self._entries: dict[str, MemoryEntry] = {} + + +class FileStore(_BaseDictStore): """JSON-file-based store for simple on-disk persistence. Reads the file on initialization and writes back on every mutation. @@ -248,47 +254,18 @@ def _save(self) -> None: data = {key: entry.to_dict() for key, entry in self._entries.items()} self._path.write_text(json.dumps(data, indent=2), encoding='utf-8') - def get(self, key: str) -> MemoryEntry | None: - """Retrieve a memory entry by key.""" - return self._entries.get(key) - def put(self, entry: MemoryEntry) -> None: """Store or update a memory entry.""" - self._entries[entry.key] = entry + super().put(entry) self._save() def delete(self, key: str) -> bool: """Delete a memory entry by key.""" - existed = self._entries.pop(key, None) is not None + existed = super().delete(key) if existed: self._save() return existed - def list_all(self, *, scope: str | None = None) -> list[MemoryEntry]: - """Return all non-expired entries, optionally filtered by scope.""" - return [ - entry - for entry in self._entries.values() - if not entry.is_expired() and (scope is None or entry.scope == scope) - ] - - def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: - """Search non-expired entries with word-boundary matching, sorted by relevance.""" - words = query.lower().split() - if not words: - return [] - scored: list[tuple[int, MemoryEntry]] = [] - for entry in self._entries.values(): - if entry.is_expired(): - continue - if scope is not None and entry.scope != scope: - continue - score = _score_entry(entry, words) - if score > 0: - scored.append((score, entry)) - scored.sort(key=lambda pair: pair[0], reverse=True) - return [entry for _, entry in scored] - def format_entry(entry: MemoryEntry) -> str: """Format a memory entry as a human-readable string.""" From f9b10667cd20bea020370416bd097ba0e54b2ef6 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 17:00:03 -0500 Subject: [PATCH 05/10] fix(memory): handle malformed JSON gracefully in FileStore._load Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pydantic_harness/memory.py | 11 +++++++++-- tests/test_memory.py | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index 44c5e6a..cd9997d 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -246,8 +246,15 @@ def __init__(self, path: str | Path) -> None: def _load(self) -> None: if self._path.exists(): - raw = json.loads(self._path.read_text(encoding='utf-8')) - self._entries = {key: MemoryEntry.from_dict(val) for key, val in raw.items()} + try: + raw: dict[str, MemoryEntryDict] = json.loads(self._path.read_text(encoding='utf-8')) + if not isinstance(raw, dict): # pyright: ignore[reportUnnecessaryIsInstance] + logger.warning('Memory file %s contains non-dict JSON, starting empty', self._path) + return + self._entries = {key: MemoryEntry.from_dict(val) for key, val in raw.items()} + except (json.JSONDecodeError, KeyError, TypeError) as e: + logger.warning('Failed to load memory file %s: %s, starting empty', self._path, e) + self._entries = {} def _save(self) -> None: self._path.parent.mkdir(parents=True, exist_ok=True) diff --git a/tests/test_memory.py b/tests/test_memory.py index 6aa9f1e..8d944f1 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -388,6 +388,24 @@ def test_search_empty_query(self, tmp_path: Path) -> None: store.put(MemoryEntry(key='k', content='v')) assert store.search('') == [] + def test_load_malformed_json(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + path.write_text('not json at all', encoding='utf-8') + store = FileStore(path) + assert store.list_all() == [] + + def test_load_wrong_structure(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + path.write_text('["a", "b"]', encoding='utf-8') + store = FileStore(path) + assert store.list_all() == [] + + def test_load_missing_entry_fields(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + path.write_text('{"k": {"not_a_key": "oops"}}', encoding='utf-8') + store = FileStore(path) + assert store.list_all() == [] + def test_scope_persists(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' store1 = FileStore(path) From 7ddf098ffcc22b2fa79655f4296dd8ae24dcb98a Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 17:01:36 -0500 Subject: [PATCH 06/10] refactor(memory): make from_spec signature explicit, raise on unknown backend Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pydantic_harness/memory.py | 32 +++++++++++++++++++++++++------- tests/test_memory.py | 21 +++++++++++++++++++++ 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index cd9997d..6dce2e6 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -320,16 +320,34 @@ def get_serialization_name(cls) -> str | None: return 'Memory' @classmethod - def from_spec(cls, *args: Any, **kwargs: Any) -> Memory[Any]: + def from_spec( + cls, + *, + backend: str = 'memory', + path: str = '.memories.json', + inject_memories_in_instructions: bool = True, + max_instructions_memories: int = 20, + ) -> Memory[Any]: """Create from spec arguments. - Supports `backend` kwarg: `"memory"` (default) or `"file"` (requires `path`). + Args: + backend: Storage backend, `"memory"` (default) or `"file"`. + path: File path for the `"file"` backend (default `".memories.json"`). + inject_memories_in_instructions: Whether to inject memories into the system prompt. + max_instructions_memories: Maximum memories to inject into the system prompt. """ - backend = kwargs.pop('backend', 'memory') - if backend == 'file': - path = kwargs.pop('path', '.memories.json') - return cls(store=FileStore(path), **kwargs) - return cls(store=InMemoryStore(), **kwargs) + store: MemoryStore + if backend == 'memory': + store = InMemoryStore() + elif backend == 'file': + store = FileStore(path) + else: + raise ValueError(f'Unknown memory backend: {backend!r}. Use "memory" or "file".') + return cls( + store=store, + inject_memories_in_instructions=inject_memories_in_instructions, + max_instructions_memories=max_instructions_memories, + ) def build_instructions(self, ctx: RunContext[AgentDepsT]) -> str: """Build dynamic instructions that include currently stored memories.""" diff --git a/tests/test_memory.py b/tests/test_memory.py index 8d944f1..0057e5f 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -477,6 +477,27 @@ def test_from_spec_file(self, tmp_path: Path) -> None: cap = Memory.from_spec(backend='file', path=str(path)) assert isinstance(cap.store, FileStore) + def test_from_spec_unknown_backend(self) -> None: + import pytest + + with pytest.raises(ValueError, match='Unknown memory backend'): + Memory.from_spec(backend='redis') + + def test_from_spec_explicit_memory_backend(self) -> None: + cap = Memory.from_spec(backend='memory') + assert isinstance(cap.store, InMemoryStore) + + def test_from_spec_with_options(self, tmp_path: Path) -> None: + cap = Memory.from_spec( + backend='file', + path=str(tmp_path / 'mem.json'), + inject_memories_in_instructions=False, + max_instructions_memories=10, + ) + assert isinstance(cap.store, FileStore) + assert cap.inject_memories_in_instructions is False + assert cap.max_instructions_memories == 10 + def test_default_store(self) -> None: cap: Memory[None] = Memory() assert isinstance(cap.store, InMemoryStore) From 11e944cece812ec718d059f592b8747a6ea6c09c Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 17:04:14 -0500 Subject: [PATCH 07/10] test(memory): add edge case tests for scoring, similarity, format, TTL, and conformance Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_memory.py | 76 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/tests/test_memory.py b/tests/test_memory.py index 0057e5f..82a45f7 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -106,6 +106,27 @@ def test_word_boundary_no_partial(self) -> None: entry = MemoryEntry(key='k', content='foxes jump') assert _score_entry(entry, ['fox']) == 0 + def test_regex_metacharacters_in_query(self) -> None: + entry = MemoryEntry(key='lang', content='I use c++ daily') + assert _score_entry(entry, ['c++']) == 1 + + def test_empty_words_list(self) -> None: + entry = MemoryEntry(key='k', content='hello') + assert _score_entry(entry, []) == 0 + + def test_underscore_word_boundary(self) -> None: + entry = MemoryEntry(key='user_name', content='text') + assert _score_entry(entry, ['name']) == 1 + + def test_hyphen_word_boundary(self) -> None: + entry = MemoryEntry(key='my-project', content='text') + assert _score_entry(entry, ['project']) == 1 + + def test_partial_word_match(self) -> None: + entry = MemoryEntry(key='k', content='alice likes blue') + # 'alice' matches (1), 'zzz' does not (0) = score 1 + assert _score_entry(entry, ['alice', 'zzz']) == 1 + # --- _simple_similarity --- @@ -134,6 +155,18 @@ def test_length_diff_too_large(self) -> None: def test_one_char_diff(self) -> None: assert _simple_similarity('abcdefghij_x', 'abcdefghij_y') + def test_edit_distance_exactly_three(self) -> None: + # Just over the threshold -- should NOT be similar + assert not _simple_similarity('abcdefghij_abc', 'abcdefghij_xyz') + + def test_nine_char_keys(self) -> None: + # Just below the 10-char minimum + assert not _simple_similarity('abcdefghi', 'abcdefghj') + + def test_exactly_ten_char_keys_not_similar(self) -> None: + # 10-char keys differing at position 10 do NOT share a 10-char prefix + assert not _simple_similarity('abcdefghij', 'abcdefghik') + # --- InMemoryStore --- @@ -460,6 +493,14 @@ def test_all_extras(self) -> None: ) assert format_entry(entry) == '[k] hello (tags: t; scope: project; expires: 2099-01-01T00:00:00+00:00)' + def test_empty_content(self) -> None: + entry = MemoryEntry(key='k', content='') + assert format_entry(entry) == '[k] ' + + def test_empty_key(self) -> None: + entry = MemoryEntry(key='', content='hello') + assert format_entry(entry) == '[] hello' + # --- Memory capability --- @@ -647,6 +688,16 @@ def test_delete_missing(self) -> None: tools = self._get_tools() assert 'No memory found' in tools['delete_memory']('nope') + def test_save_with_ttl_zero(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, 'global', 0) + entry = store.get('k') + assert entry is not None + assert entry.expires_at is not None + # TTL=0 means it expires immediately + assert entry.is_expired() + # --- Dedup warning --- @@ -724,6 +775,16 @@ def test_instructions_disabled(self) -> None: text = cap.build_instructions(self._make_ctx()) assert 'Currently stored memories' not in text + def test_instructions_exact_max_no_overflow(self) -> None: + store = InMemoryStore() + for i in range(5): + store.put(MemoryEntry(key=f'k{i}', content=f'v{i}')) + cap: Memory[None] = Memory(store=store, max_instructions_memories=5) + text = cap.build_instructions(self._make_ctx()) + assert '... and' not in text + assert '[k0]' in text + assert '[k4]' in text + # --- MemoryStore protocol --- @@ -734,3 +795,18 @@ def test_in_memory_store_satisfies_protocol(self) -> None: def test_file_store_satisfies_protocol(self, tmp_path: Path) -> None: assert isinstance(FileStore(tmp_path / 'mem.json'), MemoryStore) + + +# --- AbstractCapability conformance --- + + +class TestAbstractCapabilityConformance: + def test_is_abstract_capability_subclass(self) -> None: + from pydantic_ai.capabilities.abstract import AbstractCapability + + assert issubclass(Memory, AbstractCapability) + + def test_instance_is_abstract_capability(self) -> None: + from pydantic_ai.capabilities.abstract import AbstractCapability + + assert isinstance(Memory(), AbstractCapability) From c9dc52cfd3f2b428dd5ef78eb84ffb79c066eee2 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 17:05:26 -0500 Subject: [PATCH 08/10] chore(memory): update exports and plan to reflect review changes Co-Authored-By: Claude Opus 4.6 (1M context) --- PLAN.md | 11 +++++++---- src/pydantic_harness/__init__.py | 3 ++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/PLAN.md b/PLAN.md index ca533d8..0f95390 100644 --- a/PLAN.md +++ b/PLAN.md @@ -21,8 +21,12 @@ Implements a `Memory` capability (`AbstractCapability` subclass) that provides p ### Memory Model -- **`MemoryEntry`** dataclass: `key`, `content`, `tags` (list[str]), `created_at`, `updated_at` -- Search is substring-based (case-insensitive) across key, content, and tags +- **`MemoryEntry`** dataclass: `key`, `content`, `tags` (list[str]), `scope`, `expires_at`, `created_at`, `updated_at` +- **`MemoryEntryDict`** TypedDict for serialization +- Word-boundary search with relevance scoring (case-insensitive) across key, content, and tags +- Scoping/namespaces via `scope` field with filtering on search/list +- TTL/expiration via `expires_at` with `is_expired()` auto-filtering +- Dedup warning on save when keys are similar (Levenshtein distance <= 2) ### Spec Serialization @@ -41,11 +45,10 @@ Implements a `Memory` capability (`AbstractCapability` subclass) that provides p - `src/pydantic_harness/memory.py` - Capability, stores, entry model - `src/pydantic_harness/__init__.py` - Re-exports -- `tests/test_memory.py` - 48 tests covering all code paths +- `tests/test_memory.py` - 113 tests covering all code paths ## Future Work - Semantic/vector search backend (e.g. embedding-based `MemoryStore`) -- TTL / expiration on entries - Session-scoped memory isolation via `for_run()` - SQLite / Redis backends for production persistence diff --git a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py index 62831af..d47b53e 100644 --- a/src/pydantic_harness/__init__.py +++ b/src/pydantic_harness/__init__.py @@ -7,12 +7,13 @@ # Each capability module is imported and re-exported here. # Capabilities are listed alphabetically. -from pydantic_harness.memory import FileStore, InMemoryStore, Memory, MemoryEntry, MemoryStore +from pydantic_harness.memory import FileStore, InMemoryStore, Memory, MemoryEntry, MemoryEntryDict, MemoryStore __all__: list[str] = [ 'FileStore', 'InMemoryStore', 'Memory', 'MemoryEntry', + 'MemoryEntryDict', 'MemoryStore', ] From aa80c70a2c3a0cd0570a047eddde0440440248bf Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 17:24:55 -0500 Subject: [PATCH 09/10] feat(memory): add 3 example scripts with logfire instrumentation - personal_assistant.py: FileStore persistence, preferences, instructions injection - study_coach.py: TTL/spaced repetition, tags, search - coding_assistant.py: procedural memory, rules, search, delete All examples assert on memory state and are instrumented with logfire spans. Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/memory/coding_assistant.py | 101 ++++++++++++++++++++++++++ examples/memory/personal_assistant.py | 89 +++++++++++++++++++++++ examples/memory/study_coach.py | 76 +++++++++++++++++++ 3 files changed, 266 insertions(+) create mode 100644 examples/memory/coding_assistant.py create mode 100644 examples/memory/personal_assistant.py create mode 100644 examples/memory/study_coach.py diff --git a/examples/memory/coding_assistant.py b/examples/memory/coding_assistant.py new file mode 100644 index 0000000..65c4123 --- /dev/null +++ b/examples/memory/coding_assistant.py @@ -0,0 +1,101 @@ +"""Self-Improving Coding Assistant — procedural memory via instructions injection. + +Demonstrates: instructions injection as self-modifying prompt, scoping, search, delete. +""" + +from __future__ import annotations + +import sys + +import logfire +from pydantic_ai import Agent + +from pydantic_harness.memory import InMemoryStore, Memory + +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_openai() + + +def main() -> None: + """Run the coding assistant example.""" + store = InMemoryStore() + memory = Memory(store=store, max_instructions_memories=10) + + agent = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory], + system_prompt=( + 'You are a coding assistant that learns from user corrections. ' + 'When the user gives you a coding rule or correction, save it as a memory ' + 'with scope "rules" and tags like ["python", "style"] or ["typescript", "testing"]. ' + 'Use descriptive keys like "rule_python_fstrings" or "rule_ts_const". ' + 'When asked to write code, search your memories for relevant rules first.' + ), + ) + + # --- Teach rules --- + with logfire.span('teach-rules'): + result1 = agent.run_sync( + 'Remember these coding rules:\n' + '1. Always use f-strings in Python, never .format() or % formatting\n' + '2. In TypeScript, prefer const over let, never use var\n' + '3. Always add type hints to Python function signatures' + ) + print(f'Assistant: {result1.output}') + + rules = store.list_all() + print(f'\nRules stored: {len(rules)}') + for r in rules: + print(f' [{r.key}] {r.content} (scope={r.scope}, tags={r.tags})') + + assert len(rules) >= 3, f'Expected at least 3 rules saved, got {len(rules)}' + + # Check that search works across stored rules + python_rules = store.search('python') + print(f'Rules matching "python": {len(python_rules)}') + assert len(python_rules) >= 1, 'Expected at least 1 rule matching "python"' + + # --- Verify instructions injection --- + # Build instructions should now include the rules + from unittest.mock import MagicMock + + from pydantic_ai._run_context import RunContext + from pydantic_ai.usage import RunUsage + + ctx: RunContext[None] = RunContext(deps=None, model=MagicMock(), usage=RunUsage()) + instructions = memory.build_instructions(ctx) + print(f'\nInstructions preview (first 300 chars):\n{instructions[:300]}...') + + assert 'Currently stored memories' in instructions, 'Expected memories in instructions' + + # --- Ask for code, verify rules are considered --- + with logfire.span('apply-rules'): + result2 = agent.run_sync( + 'Write a Python function that greets a user by name. Follow all coding rules you know.' + ) + print(f'\nAssistant: {result2.output}') + + # The output should use f-strings and type hints (based on rules) + output_lower = result2.output.lower() + assert "f'" in result2.output or 'f"' in result2.output or 'f-string' in output_lower, ( + 'Expected f-string usage in code output' + ) + + # --- Delete an obsolete rule --- + with logfire.span('delete-rule'): + result3 = agent.run_sync('Actually, the TypeScript const rule is outdated for this project. Delete it.') + print(f'\nAssistant: {result3.output}') + + remaining = store.list_all() + print(f'\nRules after deletion: {len(remaining)}') + for r in remaining: + print(f' [{r.key}] {r.content}') + + # Should have fewer rules now + assert len(remaining) < len(rules), 'Expected at least one rule deleted' + + print('\n--- Coding Assistant example passed! ---') + + +if __name__ == '__main__': + sys.exit(main() or 0) diff --git a/examples/memory/personal_assistant.py b/examples/memory/personal_assistant.py new file mode 100644 index 0000000..cd477f1 --- /dev/null +++ b/examples/memory/personal_assistant.py @@ -0,0 +1,89 @@ +"""Personal Assistant — remembers user preferences across sessions. + +Demonstrates: FileStore persistence, save/recall, instructions injection, tags, scoping. +""" + +from __future__ import annotations + +import sys +import tempfile +from pathlib import Path + +import logfire +from pydantic_ai import Agent + +from pydantic_harness.memory import FileStore, Memory + +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_openai() + + +def main() -> None: + """Run the personal assistant example.""" + with tempfile.TemporaryDirectory() as tmpdir: + mem_path = Path(tmpdir) / 'preferences.json' + store = FileStore(mem_path) + memory = Memory(store=store) + + agent = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory], + system_prompt=( + 'You are a helpful personal assistant. ' + 'When the user tells you about their preferences, save each one as a memory ' + 'with scope "user_prefs" and appropriate tags. ' + 'Use descriptive keys like "preferred_name" or "theme_preference".' + ), + ) + + # --- Session 1: user shares preferences --- + with logfire.span('session-1-save-preferences'): + result1 = agent.run_sync("Hi! My name is Alice, I prefer dark mode, and I'm vegetarian.") + print(f'Assistant: {result1.output}') + + entries = store.list_all() + print(f'\nMemories after session 1: {len(entries)}') + for e in entries: + print(f' [{e.key}] {e.content} (tags={e.tags}, scope={e.scope})') + + assert len(entries) >= 2, f'Expected at least 2 memories saved, got {len(entries)}' + all_content = ' '.join(e.content.lower() for e in entries) + assert 'alice' in all_content or any('alice' in e.key.lower() for e in entries), 'Expected a memory about Alice' + + # --- Session 2: new agent instance loads from same file (persistence) --- + store2 = FileStore(mem_path) + memory2 = Memory(store=store2) + agent2 = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory2], + system_prompt='You are a helpful personal assistant.', + ) + + loaded_entries = store2.list_all() + print(f'\nMemories loaded in session 2: {len(loaded_entries)}') + assert len(loaded_entries) == len(entries), 'FileStore persistence failed' + + with logfire.span('session-2-recall-preferences'): + result2 = agent2.run_sync('What do you know about me?') + print(f'Assistant: {result2.output}') + + # The instructions injection should have included the memories + assert 'alice' in result2.output.lower() or 'dark' in result2.output.lower(), ( + 'Expected assistant to recall preferences from instructions injection' + ) + + # --- Session 3: update a preference --- + with logfire.span('session-3-update-preference'): + result3 = agent2.run_sync('Actually, I go by Ali now. Please update my name.') + print(f'\nAssistant: {result3.output}') + + updated_entries = store2.list_all() + print(f'\nMemories after update: {len(updated_entries)}') + for e in updated_entries: + print(f' [{e.key}] {e.content} (tags={e.tags})') + + print('\n--- Personal Assistant example passed! ---') + + +if __name__ == '__main__': + sys.exit(main() or 0) diff --git a/examples/memory/study_coach.py b/examples/memory/study_coach.py new file mode 100644 index 0000000..702b335 --- /dev/null +++ b/examples/memory/study_coach.py @@ -0,0 +1,76 @@ +"""Study Coach — spaced repetition with TTL. + +Demonstrates: TTL/expiration, save with ttl_minutes, list/search, tags. +""" + +from __future__ import annotations + +import sys + +import logfire +from pydantic_ai import Agent + +from pydantic_harness.memory import InMemoryStore, Memory + +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_openai() + + +def main() -> None: + """Run the study coach example.""" + store = InMemoryStore() + memory = Memory(store=store) + + agent = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory], + system_prompt=( + 'You are a study coach that helps users learn facts. ' + 'When the user provides a fact to learn, save it as a memory with ' + 'tag "study" and a ttl_minutes value: use 1 for new/hard facts, ' + '60 for reviewed facts, and 1440 for mastered facts. ' + 'Use descriptive keys like "biology_mitochondria" or "history_magna_carta".' + ), + ) + + # --- Learn some facts --- + with logfire.span('learn-facts'): + result1 = agent.run_sync( + 'I need to learn these facts:\n' + '1. Mitochondria are the powerhouse of the cell\n' + '2. The Magna Carta was signed in 1215\n' + '3. Water boils at 100 degrees Celsius at sea level' + ) + print(f'Coach: {result1.output}') + + entries = store.list_all() + print(f'\nFacts stored: {len(entries)}') + for e in entries: + print(f' [{e.key}] {e.content} (tags={e.tags}, ttl={e.expires_at})') + + assert len(entries) >= 3, f'Expected at least 3 facts saved, got {len(entries)}' + + # Check that TTL was set on at least some entries + entries_with_ttl = [e for e in entries if e.expires_at is not None] + assert len(entries_with_ttl) >= 1, 'Expected at least 1 entry with TTL set' + print(f'Entries with TTL: {len(entries_with_ttl)}') + + # Check tags + entries_with_study_tag = [e for e in entries if 'study' in e.tags] + assert len(entries_with_study_tag) >= 1, 'Expected at least 1 entry with "study" tag' + + # --- Search for facts --- + with logfire.span('search-facts'): + result2 = agent.run_sync('Search my memories for anything about biology.') + print(f'\nCoach: {result2.output}') + + # --- List all facts --- + with logfire.span('list-facts'): + result3 = agent.run_sync('List all my study memories.') + print(f'\nCoach: {result3.output}') + + print('\n--- Study Coach example passed! ---') + + +if __name__ == '__main__': + sys.exit(main() or 0) From 58c70a716a74434742c216b7420e6eb4176d9a2f Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Sat, 4 Apr 2026 11:27:47 -0500 Subject: [PATCH 10/10] chore: remove settings.local.json from tracking, restore original deps Co-Authored-By: Claude Opus 4.6 (1M context) --- .agents/settings.local.json | 5 ----- .gitignore | 5 +++++ uv.lock | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) delete mode 100644 .agents/settings.local.json diff --git a/.agents/settings.local.json b/.agents/settings.local.json deleted file mode 100644 index 8b311a3..0000000 --- a/.agents/settings.local.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "permissions": { - "allow": [] - } -} diff --git a/.gitignore b/.gitignore index 00e73a6..36c255e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +.env* +.mcp.json +.DS_Store +.agents/settings.local.json + # IDE .idea/ diff --git a/uv.lock b/uv.lock index 0730281..6178ac2 100644 --- a/uv.lock +++ b/uv.lock @@ -319,15 +319,15 @@ wheels = [ [[package]] name = "opentelemetry-api" -version = "1.40.0" +version = "1.39.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "importlib-metadata" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2c/1d/4049a9e8698361cc1a1aa03a6c59e4fa4c71e0c0f94a30f988a6876a2ae6/opentelemetry_api-1.40.0.tar.gz", hash = "sha256:159be641c0b04d11e9ecd576906462773eb97ae1b657730f0ecf64d32071569f", size = 70851, upload-time = "2026-03-04T14:17:21.555Z" } +sdist = { url = "https://files.pythonhosted.org/packages/97/b9/3161be15bb8e3ad01be8be5a968a9237c3027c5be504362ff800fca3e442/opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c", size = 65767, upload-time = "2025-12-11T13:32:39.182Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/bf/93795954016c522008da367da292adceed71cca6ee1717e1d64c83089099/opentelemetry_api-1.40.0-py3-none-any.whl", hash = "sha256:82dd69331ae74b06f6a874704be0cfaa49a1650e1537d4a813b86ecef7d0ecf9", size = 68676, upload-time = "2026-03-04T14:17:01.24Z" }, + { url = "https://files.pythonhosted.org/packages/cf/df/d3f1ddf4bb4cb50ed9b1139cc7b1c54c34a1e7ce8fd1b9a37c0d1551a6bd/opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950", size = 66356, upload-time = "2025-12-11T13:32:17.304Z" }, ] [[package]]