diff --git a/docs/examples/rag-surrealdb.md b/docs/examples/rag-surrealdb.md
new file mode 100644
index 0000000000..6ec4461209
--- /dev/null
+++ b/docs/examples/rag-surrealdb.md
@@ -0,0 +1,57 @@
+# RAG with SurrealDB
+
+RAG search example using SurrealDB. This demo allows you to ask questions of the [logfire](https://pydantic.dev/logfire) documentation.
+
+Demonstrates:
+
+- [tools](../tools.md)
+- [Web Chat UI](../web.md)
+- RAG search with SurrealDB
+
+This is done by creating a database containing each section of the markdown documentation, then registering
+the search tool with the Pydantic AI agent.
+
+Logic for extracting sections from markdown files and a JSON file with that data is available in
+[this gist](https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992).
+
+Set up your OpenAI API key:
+
+```bash
+export OPENAI_API_KEY=your-api-key
+```
+
+Or store it in a `.env` file and add `--env-file .env` to your `uv run` commands.
+
+Build the search database (**warning**: this calls the OpenAI embedding API for every documentation section from the [Logfire docs JSON gist](https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992)):
+
+```bash
+uv run -m pydantic_ai_examples.rag_surrealdb build
+```
+
+Ask the agent a question with:
+
+```bash
+uv run -m pydantic_ai_examples.rag_surrealdb search "How do I configure logfire to work with FastAPI?"
+```
+
+Or use the web UI:
+
+```bash
+uv run uvicorn pydantic_ai_examples.rag_surrealdb:app --host 127.0.0.1 --port 7932
+```
+
+This example runs SurrealDB embedded. To run it in a separate process (useful if you want to explore the database with [Surrealist](https://surrealdb.com/surrealist)), follow the [installation instructions](https://surrealdb.com/docs/surrealdb/installation) or [run with docker](https://surrealdb.com/docs/surrealdb/installation/running/docker):
+
+```bash
+surreal start -u root -p root rocksdb:database
+```
+
+With docker:
+
+```bash
+docker run --rm --pull always -p 8000:8000 surrealdb/surrealdb:latest start -u root -p root rocksdb:database
+```
+
+## Example Code
+
+```snippet {path="/examples/pydantic_ai_examples/rag_surrealdb.py"}```
diff --git a/examples/pydantic_ai_examples/chat_app_surreal.py b/examples/pydantic_ai_examples/chat_app_surreal.py
new file mode 100644
index 0000000000..aeae7ac27c
--- /dev/null
+++ b/examples/pydantic_ai_examples/chat_app_surreal.py
@@ -0,0 +1,243 @@
+"""Simple chat app example built with FastAPI using SurrealDB embedded.
+
+Set up your OpenAI API key:
+
+    export OPENAI_API_KEY=your-api-key
+
+Or, store it in a .env file, and add `--env-file .env` to your `uv run` commands.
+ +Run with: + + uv run -m pydantic_ai_examples.chat_app_surreal +""" + +from __future__ import annotations as _annotations + +import json +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Annotated, Literal + +import fastapi +import logfire +from fastapi import Depends, Request +from fastapi.responses import FileResponse, Response, StreamingResponse +from surrealdb import AsyncEmbeddedSurrealConnection, AsyncSurreal +from typing_extensions import TypedDict + +from pydantic_ai import ( + Agent, + ModelMessage, + ModelMessagesTypeAdapter, + ModelRequest, + ModelResponse, + TextPart, + UnexpectedModelBehavior, + UserPromptPart, +) + +# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_pydantic_ai() +# TODO: enable this once https://github.com/pydantic/logfire/pull/1573 is released +# logfire.instrument_surrealdb() + +agent = Agent('openai:gpt-5') +THIS_DIR = Path(__file__).parent + + +@asynccontextmanager +async def lifespan(_app: fastapi.FastAPI): + async with Database.connect() as db: + yield {'db': db} + + +app = fastapi.FastAPI(lifespan=lifespan) +logfire.instrument_fastapi(app) + + +@app.get('/') +async def index() -> FileResponse: + return FileResponse((THIS_DIR / 'chat_app.html'), media_type='text/html') + + +@app.get('/chat_app.ts') +async def main_ts() -> FileResponse: + """Get the raw typescript code, it's compiled in the browser, forgive me.""" + return FileResponse((THIS_DIR / 'chat_app.ts'), media_type='text/plain') + + +async def get_db(request: Request) -> Database: + return request.state.db + + +@app.get('/chat/') +async def get_chat(database: Database = Depends(get_db)) -> Response: + msgs = await database.get_messages() + return Response( + 
b'\n'.join(json.dumps(to_chat_message(m)).encode('utf-8') for m in msgs), + media_type='text/plain', + ) + + +class ChatMessage(TypedDict): + """Format of messages sent to the browser.""" + + role: Literal['user', 'model'] + timestamp: str + content: str + + +def to_chat_message(m: ModelMessage) -> ChatMessage: + first_part = m.parts[0] + if isinstance(m, ModelRequest): + if isinstance(first_part, UserPromptPart): + assert isinstance(first_part.content, str) + return { + 'role': 'user', + 'timestamp': first_part.timestamp.isoformat(), + 'content': first_part.content, + } + elif isinstance(m, ModelResponse): + if isinstance(first_part, TextPart): + return { + 'role': 'model', + 'timestamp': m.timestamp.isoformat(), + 'content': first_part.content, + } + raise UnexpectedModelBehavior(f'Unexpected message type for chat app: {m}') + + +@app.post('/chat/') +async def post_chat( + prompt: Annotated[str, fastapi.Form()], database: Database = Depends(get_db) +) -> StreamingResponse: + async def stream_messages(): + """Streams new line delimited JSON `Message`s to the client.""" + # stream the user prompt so that can be displayed straight away + yield ( + json.dumps( + { + 'role': 'user', + 'timestamp': datetime.now(tz=timezone.utc).isoformat(), + 'content': prompt, + } + ).encode('utf-8') + + b'\n' + ) + # get the chat history so far to pass as context to the agent + messages = await database.get_messages() + # run the agent with the user prompt and the chat history + async with agent.run_stream(prompt, message_history=messages) as result: + async for text in result.stream_output(debounce_by=0.01): + # text here is a `str` and the frontend wants + # JSON encoded ModelResponse, so we create one + m = ModelResponse(parts=[TextPart(text)], timestamp=result.timestamp()) + yield json.dumps(to_chat_message(m)).encode('utf-8') + b'\n' + + # add new messages (e.g. 
the user prompt and the agent response in this case) to the database + await database.add_messages(result.new_messages_json()) + + return StreamingResponse(stream_messages(), media_type='text/plain') + + +@dataclass +class Database: + """Database to store chat messages in SurrealDB embedded. + + Uses file-based persistence to store messages across sessions. + """ + + db: AsyncEmbeddedSurrealConnection + namespace: str = 'chat_app' + database: str = 'messages' + + @classmethod + @asynccontextmanager + async def connect( + cls, db_path: Path = THIS_DIR / '.chat_app_messages_surrealdb' + ) -> AsyncIterator[Database]: + """Connect to SurrealDB embedded database. + + Uses file-based persistence so messages are saved across sessions. + """ + with logfire.span('connect to DB'): + db_url = f'file://{db_path}' + # Use async context manager to properly manage the connection + # The connection stays open for the entire lifespan of the FastAPI app + async with AsyncSurreal(db_url) as db: + if not isinstance(db, AsyncEmbeddedSurrealConnection): + raise ValueError( + f'Expected AsyncEmbeddedSurrealConnection, got {type(db)}' + ) + slf = cls(db) + # Set namespace and database + await slf.db.use(slf.namespace, slf.database) + # Create table schema if it doesn't exist + await slf._initialize_schema() + # Yield the database instance - connection stays open until lifespan ends + yield slf + # Connection will be closed automatically when the async with block exits + + async def _initialize_schema(self) -> None: + """Initialize the messages table schema.""" + # Define table if it doesn't exist + # SurrealDB will create the table automatically on first insert, + # but we can define it explicitly for better control + await self.db.query('DEFINE TABLE message SCHEMALESS;') + + async def add_messages(self, messages: bytes) -> None: + """Add new messages to the database. + + Messages are stored as JSON in the message_list field. 
+ """ + # Decode the bytes to get the JSON string + messages_json = messages.decode('utf-8') + # Validate it's valid JSON (will raise if invalid) + json.loads(messages_json) + + # Create a record with the message list + # Using a timestamp-based ID and created_at field for proper ordering + now = datetime.now(timezone.utc) + await self.db.create( + 'message', + { + 'message_list': messages_json, + 'created_at': now.isoformat(), + }, + ) + + async def get_messages(self) -> list[ModelMessage]: + """Retrieve all messages from the database, ordered by creation time.""" + # Query all messages ordered by created_at timestamp + result = await self.db.query( + 'SELECT message_list, created_at FROM message ORDER BY created_at ASC;' + ) + + messages: list[ModelMessage] = [] + if isinstance(result, list): + for record in result: + if isinstance(record, dict) and 'message_list' in record: + # Parse the JSON string and extend the messages list + messages.extend( + ModelMessagesTypeAdapter.validate_json( + str(record['message_list']) + ) + ) + else: + raise ValueError(f'Expected list, got {type(result)}') + + return messages + + +if __name__ == '__main__': + import uvicorn + + uvicorn.run( + 'pydantic_ai_examples.chat_app_surreal:app', + reload=True, + reload_dirs=[str(THIS_DIR)], + ) diff --git a/examples/pydantic_ai_examples/rag_surrealdb.py b/examples/pydantic_ai_examples/rag_surrealdb.py new file mode 100644 index 0000000000..52b1376a10 --- /dev/null +++ b/examples/pydantic_ai_examples/rag_surrealdb.py @@ -0,0 +1,298 @@ +"""RAG example with pydantic-ai — using vector search to augment a chat agent. + +Uses SurrealDB with HNSW vector indexes for persistent storage. + +Set up your OpenAI API key: + + export OPENAI_API_KEY=your-api-key + +Or, store it in a .env file, and add `--env-file .env` to your `uv run` commands. 
+ +Build the search DB with: + + uv run -m pydantic_ai_examples.rag_surrealdb build + +Ask the agent a question with: + + uv run -m pydantic_ai_examples.rag_surrealdb search "How do I configure logfire to work with FastAPI?" + +Or use the web UI: + + uv run uvicorn pydantic_ai_examples.rag_surrealdb:app --host 127.0.0.1 --port 7932 + +This example runs SurrealDB embedded. If you want to run it in a separate process (useful to explore the db using Surrealist) you can start it with (or with docker): + + surreal start -u root -p root rocksdb:database +""" + +from __future__ import annotations as _annotations + +import asyncio +import re +import sys +import unicodedata +from contextlib import asynccontextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Any, cast + +import httpx +import logfire +from anyio import create_task_group +from pydantic import TypeAdapter +from surrealdb import ( + AsyncHttpSurrealConnection, + AsyncSurreal, + AsyncWsSurrealConnection, + RecordID, + Value, +) +from typing_extensions import AsyncGenerator + +from pydantic_ai import Agent, Embedder + +# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_pydantic_ai() +logfire.instrument_openai() +# TODO: enable this once https://github.com/pydantic/logfire/pull/1573 is released +# logfire.instrument_surrealdb() + +THIS_DIR = Path(__file__).parent + +embedder = Embedder('openai:text-embedding-3-small') +agent = Agent('openai:gpt-5') + + +@agent.tool_plain +async def retrieve(search_query: str) -> str: + """Retrieve documentation sections based on a search query. + + Args: + search_query: The search query. 
+ """ + + @dataclass + class RetrievalQueryResult: + url: str + title: str + content: str + dist: float + + result_ta = TypeAdapter(list[RetrievalQueryResult]) + + with logfire.span( + 'create embedding for {search_query=}', search_query=search_query + ): + result = await embedder.embed_query(search_query) + embedding = result.embeddings + + assert len(embedding) == 1, ( + f'Expected 1 embedding, got {len(embedding)}, doc query: {search_query!r}' + ) + embedding_vector = list(embedding[0]) + + # SurrealDB vector search using HNSW index + async with database_connect(False) as db: + result = await db.query( + """ + SELECT url, title, content, vector::distance::knn() AS dist + FROM doc_sections + WHERE embedding <|8, 40|> $vector + ORDER BY dist ASC + """, + {'vector': cast(Value, embedding_vector)}, + ) + + # Process SurrealDB query result + try: + rows = result_ta.validate_python(result) + logfire.info('Retrieved {len} results', len=len(rows)) + except Exception as e: + logfire.error('Failed to validate JSON response: {error}', error=e) + raise + + return '\n\n'.join( + f'# {row.title}\nDocumentation URL:{row.url}\n\n{row.content}\n' for row in rows + ) + + +async def run_agent(question: str): + """Entry point to run the agent and perform RAG based question answering.""" + logfire.info('Asking "{question}"', question=question) + answer = await agent.run(question) + print(answer.output) + + +# Web chat UI +app = agent.to_web() + +####################################################### +# The rest of this file is dedicated to preparing the # +# search database, and some utilities. 
# +####################################################### + +# JSON document from +# https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992 +DOCS_JSON = ( + 'https://gist.githubusercontent.com/' + 'samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992/raw/' + '80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json' +) + + +async def build_search_db(): + """Build the search database.""" + async with httpx.AsyncClient() as client: + response = await client.get(DOCS_JSON) + response.raise_for_status() + sections = sections_ta.validate_json(response.content) + + async with database_connect(True) as db: + with logfire.span('create schema'): + await db.query(DB_SCHEMA) + + embedding_sem = asyncio.Semaphore(10) + db_sem = asyncio.Semaphore(1) + async with create_task_group() as tg: + for section in sections: + tg.start_soon(insert_doc_section, embedding_sem, db_sem, db, section) + + +async def insert_doc_section( + embedding_sem: asyncio.Semaphore, + db_sem: asyncio.Semaphore, + db: AsyncWsSurrealConnection | AsyncHttpSurrealConnection, + section: DocsSection, +) -> None: + async with embedding_sem: + url = section.url() + # Create a URL-safe record ID + url_slug = slugify(url, '_') + record_id = RecordID('doc_sections', url_slug) + + # Check if record exists + existing = await db.select(record_id) + if existing: + logfire.info('Skipping {url=}', url=url) + return + + with logfire.span('create embedding for {url=}', url=url): + result = await embedder.embed_documents([section.embedding_content()]) + embedding = result.embeddings + assert len(embedding) == 1, ( + f'Expected 1 embedding, got {len(embedding)}, doc section: {section}' + ) + embedding_vector = embedding[0] + + async with db_sem: + # Create record with embedding as array, using record ID directly + res = await db.create( + record_id, + { + 'url': url, + 'title': section.title, + 'content': section.content, + 'embedding': list(embedding_vector), + }, + ) + if not isinstance(res, dict): + raise 
ValueError(f'Unexpected response from database: {res}') + + +@dataclass +class DocsSection: + id: int + parent: int | None + path: str + level: int + title: str + content: str + + def url(self) -> str: + url_path = re.sub(r'\.md$', '', self.path) + return ( + f'https://logfire.pydantic.dev/docs/{url_path}/#{slugify(self.title, "-")}' + ) + + def embedding_content(self) -> str: + return '\n\n'.join((f'path: {self.path}', f'title: {self.title}', self.content)) + + +sections_ta = TypeAdapter(list[DocsSection]) + + +@asynccontextmanager +async def database_connect(create_db: bool = False) -> AsyncGenerator[Any, None]: + namespace = 'pydantic_ai_examples' + database = 'rag_surrealdb' + username = 'root' + password = 'root' + + # Running SurrealDB embedded + db_path = THIS_DIR / f'.{database}' + db_url = f'file://{db_path}' + requires_auth = False + + # Running SurrealDB in a separate process, connect with URL + # db_url = 'ws://localhost:8000/rpc' + # namespace = 'pydantic_ai_examples' + # database = 'rag_surrealdb' + # requires_auth = True + + async with AsyncSurreal(db_url) as db: + # Sign in to the database + if requires_auth: + await db.signin({'username': username, 'password': password}) + + # Set namespace and database + await db.use(namespace, database) + + # Initialize schema if creating database + if create_db: + with logfire.span('create schema'): + await db.query(DB_SCHEMA) + + yield db + + +DB_SCHEMA = """ +DEFINE TABLE doc_sections SCHEMALESS; + +DEFINE FIELD embedding ON doc_sections TYPE array; + +DEFINE INDEX hnsw_idx_doc_sections ON doc_sections + FIELDS embedding + HNSW DIMENSION 1536 + DIST COSINE + TYPE F32; +""" + + +def slugify(value: str, separator: str, unicode: bool = False) -> str: + """Slugify a string, to make it URL friendly.""" + # Taken unchanged from https://github.com/Python-Markdown/markdown/blob/3.7/markdown/extensions/toc.py#L38 + if not unicode: + # Replace Extended Latin characters with ASCII, i.e. 
`žlutý` => `zluty` + value = unicodedata.normalize('NFKD', value) + value = value.encode('ascii', 'ignore').decode('ascii') + value = re.sub(r'[^\w\s-]', '', value).strip().lower() + return re.sub(rf'[{separator}\s]+', separator, value) + + +if __name__ == '__main__': + action = sys.argv[1] if len(sys.argv) > 1 else None + if action == 'build': + asyncio.run(build_search_db()) + elif action == 'search': + if len(sys.argv) == 3: + q = sys.argv[2] + else: + q = 'How do I configure logfire to work with FastAPI?' + asyncio.run(run_agent(q)) + else: + print( + 'uv run --extra examples -m pydantic_ai_examples.rag_surrealdb build|search', + file=sys.stderr, + ) + sys.exit(1) diff --git a/examples/pyproject.toml b/examples/pyproject.toml index 01adcbc560..95bb872915 100644 --- a/examples/pyproject.toml +++ b/examples/pyproject.toml @@ -60,6 +60,7 @@ dependencies = [ "mcp[cli]>=1.4.1", "modal>=1.0.4", "duckdb>=1.3.2", + "surrealdb>=1.0.8", "datasets>=4.0.0", "pandas>=2.2.3", ] diff --git a/mkdocs.yml b/mkdocs.yml index cb191b2360..1edf68eb36 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -116,7 +116,9 @@ nav: - Data & Analytics: - examples/sql-gen.md - examples/data-analyst.md + - RAG: - examples/rag.md + - examples/rag-surrealdb.md - Streaming: - examples/stream-markdown.md - examples/stream-whales.md diff --git a/uv.lock b/uv.lock index bed14038cc..bdb261a8d8 100644 --- a/uv.lock +++ b/uv.lock @@ -6060,6 +6060,7 @@ dependencies = [ { name = "pydantic-evals" }, { name = "python-multipart" }, { name = "rich" }, + { name = "surrealdb" }, { name = "uvicorn" }, ] @@ -6079,6 +6080,7 @@ requires-dist = [ { name = "pydantic-evals", editable = "pydantic_evals" }, { name = "python-multipart", specifier = ">=0.0.17" }, { name = "rich", specifier = ">=13.9.2" }, + { name = "surrealdb", specifier = ">=1.0.8" }, { name = "uvicorn", specifier = ">=0.32.0" }, ] @@ -8185,6 +8187,26 @@ dependencies = [ { name = "pydantic" }, ] +[[package]] +name = "surrealdb" +version = "1.0.8" +source = { 
registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "pydantic-core" }, + { name = "requests" }, + { name = "typing-extensions", marker = "python_full_version < '3.12'" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/36/ad/6f7e69bddb77a7b8deb0874652a8e9a4a15e15736d09f13911f1a9490294/surrealdb-1.0.8.tar.gz", hash = "sha256:14a9b2e24b8a2fbe15b6894617a2c2aababaf02e7fb95bd755ab9182b40c92c6", size = 291033, upload-time = "2026-01-07T18:18:40.335Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/a5/682e642b0b161b49a43aec930604bbc9367dff6ebe7e53dd7768ed25195d/surrealdb-1.0.8-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:afc95b38d915ac7cb9adafc32d6e9b5a9548470095dad67efe626dd3b7bdbfc7", size = 5130558, upload-time = "2026-01-07T18:18:32.986Z" }, + { url = "https://files.pythonhosted.org/packages/73/94/8a0ef6934190e2aef75a3862246dca50b747c60fe87da79ef07ecea085ea/surrealdb-1.0.8-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:4052ea81bbb999bc4e48a38bbd852f89d52840bcc52573dbfa009b1260045271", size = 4991412, upload-time = "2026-01-07T18:18:34.649Z" }, + { url = "https://files.pythonhosted.org/packages/3a/3b/82703abfc8b96a3f5000b2edca28d6f093d07185022dd60a2e463f0c59a7/surrealdb-1.0.8-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:ca0d6d4deee59f2100580da8ca2df449543d2c2945dc12299217b712278bf812", size = 5789423, upload-time = "2026-01-07T18:18:35.995Z" }, + { url = "https://files.pythonhosted.org/packages/34/cb/dd598d0519ed537bd033f79dc7d008adee88469cc8a0e60e33a57d51989e/surrealdb-1.0.8-cp39-abi3-manylinux_2_24_x86_64.whl", hash = "sha256:b85d2ae0f43306496690081a07b06231f30383952280002275fe6083eafc2a2a", size = 5686857, upload-time = "2026-01-07T18:18:37.702Z" }, + { url = "https://files.pythonhosted.org/packages/e1/db/5e24536cb158edcd1a40992811ed49ad4b911b330cedc84371bfa0c1d160/surrealdb-1.0.8-cp39-abi3-win_amd64.whl", hash = 
"sha256:977c5f4602d16476f70557c6e729c4035c6323be580a756a26f79a103e0df46d", size = 5047898, upload-time = "2026-01-07T18:18:39.085Z" }, +] + [[package]] name = "sympy" version = "1.14.0"