Add SessionPersistence capability #176
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| # Session Persistence Capability | ||
|
|
||
| ## Summary | ||
|
|
||
| This PR implements the `SessionPersistence` capability for saving and loading agent conversation sessions across process restarts. | ||
|
|
||
| ## Design | ||
|
|
||
| ### Storage Protocol | ||
|
|
||
| `SessionStore` is a `Protocol` with four methods: | ||
| - `save(session_id, messages)` — persist a list of `ModelMessage` | ||
| - `load(session_id)` — retrieve messages or `None` | ||
| - `list_sessions()` — enumerate stored session IDs | ||
| - `delete(session_id)` — remove a session | ||
|
|
||
| ### Backends | ||
|
|
||
| - **`InMemorySessionStore`** — dict-based, for testing (data lost on process exit) | ||
| - **`FileSessionStore`** — one JSON file per session in a directory, using `ModelMessagesTypeAdapter` for serialization/deserialization | ||
|
|
||
| ### Capability | ||
|
|
||
| `SessionPersistence(AbstractCapability)`: | ||
| - **`before_run`**: loads saved messages and prepends them to `ctx.messages` | ||
| - **`after_run`**: saves `result.all_messages()` to the store (when `auto_save=True`) | ||
| - **`session_id`**: auto-generated UUID4 if not provided | ||
| - **`from_spec`**: supports `backend="memory"` (default) and `backend="file"` (with configurable `directory`) | ||
|
|
||
| ### Key decisions | ||
|
|
||
| - Uses `before_run`/`after_run` hooks (not `before_model_request`) since session restore/save is a per-run concern, not per-request | ||
| - Prepends history via `ctx.messages[:0] = existing` for clean integration with the agent's message handling | ||
| - `InMemorySessionStore` returns copies to prevent aliasing bugs | ||
| - `FileSessionStore` uses `ModelMessagesTypeAdapter.dump_json`/`validate_json` for full-fidelity message serialization | ||
|
|
||
| ## Files | ||
|
|
||
| - `src/pydantic_harness/session_persistence.py` — stores, capability | ||
| - `src/pydantic_harness/__init__.py` — re-exports | ||
| - `tests/test_session_persistence.py` — 33 tests, 100% coverage |
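
The two list-handling decisions described above (prepending with `ctx.messages[:0] = existing` and returning copies from the in-memory store) can be sketched without the library. This is only an illustration: plain strings stand in for `ModelMessage` objects, and the module-level `stored` dict and `load` helper are hypothetical stand-ins, not the PR's classes.

```python
from typing import Optional

# Plain strings stand in for ModelMessage objects (an assumption for brevity).
stored = {"my-session": ["user: hi", "assistant: hello"]}


def load(session_id: str) -> Optional[list]:
    """Return a shallow copy so callers cannot mutate the stored list."""
    messages = stored.get(session_id)
    return None if messages is None else list(messages)


# The run's own message list, as a before_run hook might see it.
run_messages = ["user: what did I just say?"]
same_list = run_messages  # a second reference, e.g. one held by the run itself

existing = load("my-session")
if existing:
    # Slice assignment mutates the list in place, so every holder of the
    # reference sees the restored history ahead of the new message.
    run_messages[:0] = existing
    # Mutating the loaded copy afterwards leaves the store untouched.
    existing.append("tampered")
```

Rebinding the name instead (`run_messages = existing + run_messages`) would create a new list and leave `same_list` without the history, which is presumably why the in-place form was chosen.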
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,260 @@ | ||||||
| """Session persistence capability for saving and loading agent conversation history. | ||||||
|
|
||||||
| Provides automatic save/restore of conversation messages across agent runs, | ||||||
| with pluggable storage backends (``InMemorySessionStore`` for testing, | ||||||
| ``FileSessionStore`` for on-disk persistence via JSON files). | ||||||
| """ | ||||||
|
|
||||||
| from __future__ import annotations | ||||||
|
|
||||||
| import json as _json | ||||||
| from dataclasses import dataclass, field | ||||||
| from pathlib import Path | ||||||
| from typing import Any, Protocol, runtime_checkable | ||||||
| from uuid import uuid4 | ||||||
|
|
||||||
| from pydantic_ai.capabilities.abstract import AbstractCapability | ||||||
| from pydantic_ai.messages import ModelMessage, ModelMessagesTypeAdapter | ||||||
| from pydantic_ai.run import AgentRunResult | ||||||
| from pydantic_ai.tools import AgentDepsT, RunContext | ||||||
|
|
||||||
|
|
||||||
| @runtime_checkable | ||||||
| class SessionStore(Protocol): | ||||||
| """Protocol for pluggable session storage backends.""" | ||||||
|
|
||||||
| def save( | ||||||
| self, | ||||||
| session_id: str, | ||||||
| messages: list[ModelMessage], | ||||||
| *, | ||||||
| metadata: dict[str, Any] | None = None, | ||||||
| ) -> None: # pragma: no cover | ||||||
| """Persist conversation messages (and optional metadata) for the given session.""" | ||||||
| ... | ||||||
|
|
||||||
| def load(self, session_id: str) -> list[ModelMessage] | None: # pragma: no cover | ||||||
| """Load conversation messages for the given session, or None if not found.""" | ||||||
| ... | ||||||
|
|
||||||
| def load_metadata(self, session_id: str) -> dict[str, Any] | None: # pragma: no cover | ||||||
| """Load metadata for the given session, or None if not found.""" | ||||||
| ... | ||||||
|
|
||||||
| def list_sessions(self) -> list[str]: # pragma: no cover | ||||||
| """Return all stored session IDs.""" | ||||||
| ... | ||||||
|
|
||||||
| def delete(self, session_id: str) -> bool: # pragma: no cover | ||||||
| """Delete a session by ID. Returns True if it existed.""" | ||||||
| ... | ||||||
|
|
||||||
|
|
||||||
| class InMemorySessionStore: | ||||||
| """Dict-based in-memory session store, suitable for testing. | ||||||
|
|
||||||
| All data lives in a plain ``dict`` and is lost when the process exits. | ||||||
| """ | ||||||
|
|
||||||
| def __init__(self) -> None: | ||||||
| """Initialize an empty in-memory session store.""" | ||||||
| self._sessions: dict[str, list[ModelMessage]] = {} | ||||||
| self._metadata: dict[str, dict[str, Any]] = {} | ||||||
|
|
||||||
| def save( | ||||||
| self, | ||||||
| session_id: str, | ||||||
| messages: list[ModelMessage], | ||||||
| *, | ||||||
| metadata: dict[str, Any] | None = None, | ||||||
| ) -> None: | ||||||
| """Persist conversation messages (and optional metadata) for the given session.""" | ||||||
| self._sessions[session_id] = list(messages) | ||||||
| if metadata is not None: | ||||||
| self._metadata[session_id] = dict(metadata) | ||||||
| else: | ||||||
| self._metadata.pop(session_id, None) | ||||||
|
|
||||||
| def load(self, session_id: str) -> list[ModelMessage] | None: | ||||||
| """Load conversation messages for the given session.""" | ||||||
| messages = self._sessions.get(session_id) | ||||||
| if messages is None: | ||||||
| return None | ||||||
| return list(messages) | ||||||
|
|
||||||
| def load_metadata(self, session_id: str) -> dict[str, Any] | None: | ||||||
| """Load metadata for the given session.""" | ||||||
| meta = self._metadata.get(session_id) | ||||||
| if meta is None: | ||||||
| return None | ||||||
| return dict(meta) | ||||||
|
|
||||||
| def list_sessions(self) -> list[str]: | ||||||
| """Return all stored session IDs.""" | ||||||
| return list(self._sessions) | ||||||
|
|
||||||
| def delete(self, session_id: str) -> bool: | ||||||
| """Delete a session by ID.""" | ||||||
| self._metadata.pop(session_id, None) | ||||||
| return self._sessions.pop(session_id, None) is not None | ||||||
|
|
||||||
|
|
||||||
| class FileSessionStore: | ||||||
| """JSON-file-based session store for on-disk persistence. | ||||||
|
|
||||||
| Each session is stored as a separate JSON file in the configured directory, | ||||||
| using ``ModelMessagesTypeAdapter`` for serialization. | ||||||
| """ | ||||||
|
|
||||||
| def __init__(self, directory: str | Path) -> None: | ||||||
| """Initialize a file-backed session store at the given directory. | ||||||
|
|
||||||
| Args: | ||||||
| directory: Path to the directory where session files are stored. | ||||||
| Created automatically if it does not exist. | ||||||
| """ | ||||||
| self._directory = Path(directory) | ||||||
|
|
||||||
| def _path_for(self, session_id: str) -> Path: | ||||||
| return self._directory / f'{session_id}.json' | ||||||
|
|
||||||
| def _meta_path_for(self, session_id: str) -> Path: | ||||||
| return self._directory / f'{session_id}.meta.json' | ||||||
|
|
||||||
| def save( | ||||||
| self, | ||||||
| session_id: str, | ||||||
| messages: list[ModelMessage], | ||||||
| *, | ||||||
| metadata: dict[str, Any] | None = None, | ||||||
| ) -> None: | ||||||
| """Persist conversation messages (and optional metadata) as JSON files.""" | ||||||
| self._directory.mkdir(parents=True, exist_ok=True) | ||||||
| data = ModelMessagesTypeAdapter.dump_json(messages) | ||||||
| self._path_for(session_id).write_bytes(data) | ||||||
|
|
||||||
| meta_path = self._meta_path_for(session_id) | ||||||
| if metadata is not None: | ||||||
| meta_path.write_text(_json.dumps(metadata), encoding='utf-8') | ||||||
| elif meta_path.exists(): | ||||||
| meta_path.unlink() | ||||||
|
|
||||||
| def load(self, session_id: str) -> list[ModelMessage] | None: | ||||||
| """Load conversation messages from a JSON file.""" | ||||||
| path = self._path_for(session_id) | ||||||
| if not path.exists(): | ||||||
| return None | ||||||
| data = path.read_bytes() | ||||||
| return ModelMessagesTypeAdapter.validate_json(data) | ||||||
|
|
||||||
| def load_metadata(self, session_id: str) -> dict[str, Any] | None: | ||||||
| """Load metadata from a JSON file.""" | ||||||
| meta_path = self._meta_path_for(session_id) | ||||||
| if not meta_path.exists(): | ||||||
| return None | ||||||
| raw = meta_path.read_text(encoding='utf-8') | ||||||
| result: dict[str, Any] = _json.loads(raw) | ||||||
| return result | ||||||
|
|
||||||
| def list_sessions(self) -> list[str]: | ||||||
| """Return all session IDs found in the directory.""" | ||||||
| if not self._directory.exists(): | ||||||
| return [] | ||||||
| return sorted(p.stem for p in self._directory.glob('*.json') if not p.name.endswith('.meta.json')) | ||||||
|
🟡 `FileSessionStore.list_sessions()` silently drops session IDs ending with `.meta`

Example demonstrating the bug:

    store = FileSessionStore('/tmp/sessions')
    store.save('experiment.meta', [msg])
    print(store.list_sessions())  # [] (the session is missing)
    print(store.load('experiment.meta'))  # [msg] (but it loads fine)
|
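
One way to avoid the suffix collision is to keep metadata files out of the directory that `list_sessions()` globs. The sketch below is an assumption, not the suggested change from the review: the class name `SeparatedFileStore` and the `_meta/` subdirectory are hypothetical, and the standard `json` module stands in for `ModelMessagesTypeAdapter`.

```python
import json
import tempfile
from pathlib import Path


class SeparatedFileStore:
    """Hypothetical layout: messages in <dir>/, metadata in <dir>/_meta/."""

    def __init__(self, directory):
        self._directory = Path(directory)
        self._meta_directory = self._directory / "_meta"

    def save(self, session_id, messages, *, metadata=None):
        self._directory.mkdir(parents=True, exist_ok=True)
        message_path = self._directory / f"{session_id}.json"
        message_path.write_text(json.dumps(messages), encoding="utf-8")
        meta_path = self._meta_directory / f"{session_id}.json"
        if metadata is not None:
            self._meta_directory.mkdir(parents=True, exist_ok=True)
            meta_path.write_text(json.dumps(metadata), encoding="utf-8")
        elif meta_path.exists():
            meta_path.unlink()

    def list_sessions(self):
        if not self._directory.exists():
            return []
        # Path.glob('*.json') is non-recursive, so files under _meta/ are
        # never matched and no suffix filtering is needed.
        return sorted(p.stem for p in self._directory.glob("*.json"))


with tempfile.TemporaryDirectory() as tmp:
    store = SeparatedFileStore(tmp)
    store.save("experiment", ["hello"], metadata={"owner": "alice"})
    store.save("experiment.meta", ["hi"])
    listed = store.list_sessions()
```

With this layout the message file for `experiment.meta` and the metadata file for `experiment` no longer share a path, so both sessions are listed.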
||||||
|
|
||||||
| def delete(self, session_id: str) -> bool: | ||||||
| """Delete a session file and its metadata. Returns True if it existed.""" | ||||||
| path = self._path_for(session_id) | ||||||
| existed = path.exists() | ||||||
| if existed: | ||||||
| path.unlink() | ||||||
| meta_path = self._meta_path_for(session_id) | ||||||
| if meta_path.exists(): | ||||||
| meta_path.unlink() | ||||||
| return existed | ||||||
|
|
||||||
|
|
||||||
| @dataclass | ||||||
| class SessionPersistence(AbstractCapability[AgentDepsT]): | ||||||
| """Capability for saving and restoring conversation state across agent runs. | ||||||
|
|
||||||
| On run start, loads any previously saved messages for the session and | ||||||
| prepends them to the conversation. On run end, saves the full message | ||||||
| history back to the store. | ||||||
|
|
||||||
| Example: | ||||||
| ```python | ||||||
| from pydantic_ai import Agent | ||||||
| from pydantic_harness.session_persistence import ( | ||||||
| SessionPersistence, | ||||||
| InMemorySessionStore, | ||||||
| ) | ||||||
|
|
||||||
| store = InMemorySessionStore() | ||||||
| agent = Agent( | ||||||
| 'openai:gpt-4o', | ||||||
| capabilities=[SessionPersistence(store=store, session_id='my-session')], | ||||||
| ) | ||||||
| ``` | ||||||
| """ | ||||||
|
|
||||||
| store: SessionStore = field(default_factory=InMemorySessionStore) | ||||||
| """The storage backend. Defaults to ``InMemorySessionStore`` (ephemeral).""" | ||||||
|
|
||||||
| session_id: str = field(default_factory=lambda: str(uuid4())) | ||||||
| """Unique identifier for this session. Auto-generated (UUID4) if not provided.""" | ||||||
|
|
||||||
| auto_save: bool = True | ||||||
| """Whether to automatically save messages after each run.""" | ||||||
|
|
||||||
| metadata: dict[str, Any] | None = None | ||||||
| """Optional metadata to store alongside the session messages. | ||||||
|
|
||||||
| When set, this dict is persisted on each save and can be retrieved | ||||||
| via ``store.load_metadata(session_id)``. | ||||||
| """ | ||||||
|
|
||||||
| @classmethod | ||||||
| def get_serialization_name(cls) -> str | None: | ||||||
| """Return the name used for spec serialization.""" | ||||||
| return 'SessionPersistence' | ||||||
|
|
||||||
| @classmethod | ||||||
| def from_spec(cls, *args: Any, **kwargs: Any) -> SessionPersistence[Any]: | ||||||
| """Create from spec arguments. | ||||||
|
|
||||||
| Supports ``backend`` kwarg: ``"memory"`` (default) or ``"file"`` (optional ``directory``, defaulting to ``'.sessions'``). | ||||||
| """ | ||||||
| backend = kwargs.pop('backend', 'memory') | ||||||
| if backend == 'file': | ||||||
| directory = kwargs.pop('directory', '.sessions') | ||||||
| return cls(store=FileSessionStore(directory), **kwargs) | ||||||
| return cls(store=InMemorySessionStore(), **kwargs) | ||||||
|
|
||||||
| async def before_run( | ||||||
| self, | ||||||
| ctx: RunContext[AgentDepsT], | ||||||
| ) -> None: | ||||||
| """Load saved messages and prepend them to the conversation.""" | ||||||
| existing = self.store.load(self.session_id) | ||||||
|
|
||||||
| if existing: | ||||||
| ctx.messages[:0] = existing | ||||||
|
|
||||||
| async def after_run( | ||||||
| self, | ||||||
| ctx: RunContext[AgentDepsT], | ||||||
| *, | ||||||
| result: AgentRunResult[Any], | ||||||
| ) -> AgentRunResult[Any]: | ||||||
| """Save the full message history after a successful run.""" | ||||||
| if self.auto_save: | ||||||
| self.store.save(self.session_id, result.all_messages(), metadata=self.metadata) | ||||||
| return result | ||||||
|
Comment on lines +234 to +252
🚩 Synchronous store operations called from async hooks |
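
A common way to keep blocking store calls from stalling the event loop is to run them in a worker thread with `asyncio.to_thread` (Python 3.9+). The sketch below is illustrative only: `SlowStore`, `save_off_loop`, and `load_off_loop` are hypothetical names, and the `time.sleep` calls merely simulate disk I/O.

```python
import asyncio
import time


class SlowStore:
    """Stand-in for a blocking backend such as a file-based store."""

    def __init__(self):
        self._sessions = {}

    def save(self, session_id, messages):
        time.sleep(0.05)  # simulate blocking disk I/O
        self._sessions[session_id] = list(messages)

    def load(self, session_id):
        time.sleep(0.05)  # simulate blocking disk I/O
        messages = self._sessions.get(session_id)
        return None if messages is None else list(messages)


async def save_off_loop(store, session_id, messages):
    # asyncio.to_thread runs the blocking call in a worker thread, so the
    # event loop stays free to schedule other tasks in the meantime.
    await asyncio.to_thread(store.save, session_id, messages)


async def load_off_loop(store, session_id):
    return await asyncio.to_thread(store.load, session_id)


async def main():
    store = SlowStore()
    ticks = 0

    async def ticker():
        # Counts how often the loop gets to run while the store is busy.
        nonlocal ticks
        while True:
            await asyncio.sleep(0.005)
            ticks += 1

    task = asyncio.create_task(ticker())
    await save_off_loop(store, "s1", ["hello"])
    loaded = await load_off_loop(store, "s1")
    task.cancel()
    return loaded, ticks


loaded, ticks = asyncio.run(main())
```

In the capability, the same wrapping could be applied inside `before_run` and `after_run`; an alternative design would be an async variant of the store protocol.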
||||||
|
|
||||||
|
|
||||||
| __all__ = [ | ||||||
| 'FileSessionStore', | ||||||
| 'InMemorySessionStore', | ||||||
| 'SessionPersistence', | ||||||
| 'SessionStore', | ||||||
| ] | ||||||
🔴 Path traversal vulnerability in `FileSessionStore` allows reading/writing files outside the session directory
`FileSessionStore._path_for` and `_meta_path_for` directly interpolate `session_id` into a file path (`self._directory / f'{session_id}.json'`) without any sanitization. A `session_id` containing path traversal sequences like `../../` resolves to a path outside the intended directory. For example, `session_id="../../tmp/evil"` would cause `save()` to write to `/tmp/evil.json` and `load()` to read from it. Since `session_id` can be explicitly user-provided (e.g. via `from_spec(session_id=user_input)`) and is passed to `save`, `load`, `load_metadata`, and `delete`, this enables arbitrary file read/write/delete (with a `.json` or `.meta.json` suffix) anywhere the process has permissions.