diff --git a/CLAUDE.md b/CLAUDE.md index c73ec14..1cd3cd6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -16,6 +16,7 @@ uv run ruff check # Run linting uv run ruff format # Format code uv run pytest # Run tests uv run pytest tests/ # Run specific test directory +uv run pytest --run-api-tests # Run all tests, including API tests uv add # Add a dependency to pyproject.toml and update lock file uv remove # Remove a dependency from pyproject.toml and update lock file diff --git a/agent-memory-client/README.md b/agent-memory-client/README.md index 439989f..0f22dd0 100644 --- a/agent-memory-client/README.md +++ b/agent-memory-client/README.md @@ -206,8 +206,8 @@ await client.update_working_memory_data( # Append messages new_messages = [ - MemoryMessage(role="user", content="What's the weather?"), - MemoryMessage(role="assistant", content="It's sunny today!") + {"role": "user", "content": "What's the weather?"}, + {"role": "assistant", "content": "It's sunny today!"} ] await client.append_messages_to_working_memory( diff --git a/agent-memory-client/agent_memory_client/__init__.py b/agent-memory-client/agent_memory_client/__init__.py index 23f8612..37ec9d6 100644 --- a/agent-memory-client/agent_memory_client/__init__.py +++ b/agent-memory-client/agent_memory_client/__init__.py @@ -5,7 +5,7 @@ memory management capabilities for AI agents and applications. """ -__version__ = "0.9.0b3" +__version__ = "0.9.0b4" from .client import MemoryAPIClient, MemoryClientConfig, create_memory_client from .exceptions import ( diff --git a/agent-memory-client/agent_memory_client/client.py b/agent-memory-client/agent_memory_client/client.py index 168da53..2ed9e8d 100644 --- a/agent-memory-client/agent_memory_client/client.py +++ b/agent-memory-client/agent_memory_client/client.py @@ -5,7 +5,6 @@ """ import asyncio -import contextlib import re from collections.abc import AsyncIterator from typing import TYPE_CHECKING, Any, Literal, TypedDict @@ -166,7 +165,11 @@ async def health_check(self) -> HealthCheckResponse: raise async def list_sessions( - self, limit: int = 20, offset: int = 0, namespace: str | None = None + self, + limit: int = 20, + offset: int = 0, + namespace: str | None = None, + user_id: str | None = None, ) -> SessionListResponse: """ List available sessions with optional pagination and namespace filtering. @@ -175,6 +178,7 @@ async def list_sessions( limit: Maximum number of sessions to return (default: 20) offset: Offset for pagination (default: 0) namespace: Optional namespace filter + user_id: Optional user ID filter Returns: SessionListResponse containing session IDs and total count @@ -188,6 +192,9 @@ async def list_sessions( elif self.config.default_namespace is not None: params["namespace"] = self.config.default_namespace + if user_id is not None: + params["user_id"] = user_id + try: response = await self._client.get("/v1/working-memory/", params=params) response.raise_for_status() @@ -199,6 +206,7 @@ async def list_sessions( async def get_working_memory( self, session_id: str, + user_id: str | None = None, namespace: str | None = None, window_size: int | None = None, model_name: ModelNameLiteral | None = None, @@ -209,6 +217,7 @@ async def get_working_memory( Args: session_id: The session ID to retrieve working memory for + user_id: The user ID to retrieve working memory for namespace: Optional namespace for the session window_size: Optional number of messages to include model_name: Optional model name to determine context window size @@ -223,6 +232,9 @@ async def get_working_memory( """ params = {} + if user_id is not None: + params["user_id"] = user_id + if namespace is not None: params["namespace"] = namespace elif self.config.default_namespace is not None: @@ -248,7 +260,12 @@ async def get_working_memory( f"/v1/working-memory/{session_id}", params=params ) response.raise_for_status() - return WorkingMemoryResponse(**response.json()) + + # Get the raw JSON response + response_data = response.json() + + # Messages from JSON parsing are already in the correct dict format + return WorkingMemoryResponse(**response_data) except httpx.HTTPStatusError as e: self._handle_http_error(e.response) raise @@ -257,6 +274,7 @@ async def put_working_memory( self, session_id: str, memory: WorkingMemory, + user_id: str | None = None, model_name: str | None = None, context_window_max: int | None = None, ) -> WorkingMemoryResponse: @@ -266,6 +284,7 @@ async def put_working_memory( Args: session_id: The session ID to store memory for memory: WorkingMemory object with messages and optional context + user_id: Optional user ID for the session (overrides user_id in memory object) model_name: Optional model name for context window management context_window_max: Optional direct specification of context window max tokens @@ -279,6 +298,9 @@ async def put_working_memory( # Build query parameters for model-aware summarization params = {} + if user_id is not None: + params["user_id"] = user_id + # Use provided model_name or fall back to config default effective_model_name = model_name or self.config.default_model_name if effective_model_name is not None: @@ -304,7 +326,7 @@ async def put_working_memory( raise async def delete_working_memory( - self, session_id: str, namespace: str | None = None + self, session_id: str, namespace: str | None = None, user_id: str | None = None ) -> AckResponse: """ Delete working memory for a session. @@ -312,6 +334,7 @@ async def delete_working_memory( Args: session_id: The session ID to delete memory for namespace: Optional namespace for the session + user_id: Optional user ID for the session Returns: AckResponse indicating success @@ -322,6 +345,9 @@ async def delete_working_memory( elif self.config.default_namespace is not None: params["namespace"] = self.config.default_namespace + if user_id is not None: + params["user_id"] = user_id + try: response = await self._client.delete( f"/v1/working-memory/{session_id}", params=params @@ -369,11 +395,10 @@ async def set_working_memory_data( # Get existing memory if preserving existing_memory = None if preserve_existing: - with contextlib.suppress(Exception): - existing_memory = await self.get_working_memory( - session_id=session_id, - namespace=namespace, - ) + existing_memory = await self.get_working_memory( + session_id=session_id, + namespace=namespace, + ) # Create new working memory with the data working_memory = WorkingMemory( @@ -427,12 +452,10 @@ async def add_memories_to_working_memory( ``` """ # Get existing memory - existing_memory = None - with contextlib.suppress(Exception): - existing_memory = await self.get_working_memory( - session_id=session_id, - namespace=namespace, - ) + existing_memory = await self.get_working_memory( + session_id=session_id, + namespace=namespace, + ) # Determine final memories list if replace or not existing_memory: @@ -571,7 +594,7 @@ async def search_long_term_memory( print(f"Found {results.total} memories") for memory in results.memories: - print(f"- {memory.text[:100]}... (distance: {memory.distance})") + print(f"- {memory.text[:100]}... (distance: {memory.dist})") ``` """ # Convert dictionary filters to their proper filter objects if needed @@ -651,11 +674,12 @@ async def search_memory_tool( user_id: str | None = None, ) -> dict[str, Any]: """ - Simplified memory search designed for LLM tool use. + Simplified long-term memory search designed for LLM tool use. This method provides a streamlined interface for LLMs to search long-term memory with common parameters and user-friendly output. - Perfect for exposing as a tool to LLM frameworks. + Perfect for exposing as a tool to LLM frameworks. Note: This only + searches long-term memory, not working memory. Args: query: The search query text @@ -664,6 +688,7 @@ async def search_memory_tool( memory_type: Optional memory type ("episodic", "semantic", "message") max_results: Maximum results to return (default: 5) min_relevance: Optional minimum relevance score (0.0-1.0) + user_id: Optional user ID to filter memories by Returns: Dict with 'memories' list and 'summary' for LLM consumption @@ -729,8 +754,8 @@ async def search_memory_tool( "created_at": memory.created_at.isoformat() if memory.created_at else None, - "relevance_score": 1.0 - memory.distance - if hasattr(memory, "distance") and memory.distance is not None + "relevance_score": 1.0 - memory.dist + if hasattr(memory, "dist") and memory.dist is not None else None, } ) @@ -784,7 +809,7 @@ async def handle_tool_calls(client, tool_calls): "type": "function", "function": { "name": "search_memory", - "description": "Search long-term memory for relevant information based on a query. Use this when you need to recall past conversations, user preferences, or previously stored information.", + "description": "Search long-term memory for relevant information based on a query. Use this when you need to recall past conversations, user preferences, or previously stored information. Note: This searches only long-term memory, not current working memory.", "parameters": { "type": "object", "properties": { @@ -820,6 +845,10 @@ async def handle_tool_calls(client, tool_calls): "maximum": 1.0, "description": "Optional minimum relevance score (0.0-1.0, higher = more relevant)", }, + "user_id": { + "type": "string", + "description": "Optional user ID to filter memories by (e.g., 'user123')", + }, }, "required": ["query"], }, @@ -864,6 +893,7 @@ async def get_working_memory_tool( result = await self.get_working_memory( session_id=session_id, namespace=namespace or self.config.default_namespace, + user_id=user_id, ) # Format for LLM consumption @@ -1953,6 +1983,7 @@ async def update_working_memory_data( data_updates: dict[str, Any], namespace: str | None = None, merge_strategy: Literal["replace", "merge", "deep_merge"] = "merge", + user_id: str | None = None, ) -> WorkingMemoryResponse: """ Update specific data fields in working memory without replacing everything. @@ -1962,16 +1993,15 @@ async def update_working_memory_data( data_updates: Dictionary of updates to apply namespace: Optional namespace merge_strategy: How to handle existing data + user_id: Optional user ID for the session Returns: WorkingMemoryResponse with updated memory """ # Get existing memory - existing_memory = None - with contextlib.suppress(Exception): - existing_memory = await self.get_working_memory( - session_id=session_id, namespace=namespace - ) + existing_memory = await self.get_working_memory( + session_id=session_id, namespace=namespace, user_id=user_id + ) # Determine final data based on merge strategy if existing_memory and existing_memory.data: @@ -2002,11 +2032,11 @@ async def update_working_memory_data( async def append_messages_to_working_memory( self, session_id: str, - messages: list[Any], # Using Any since MemoryMessage isn't imported + messages: list[dict[str, Any]], # Expect proper message dicts namespace: str | None = None, - auto_summarize: bool = True, model_name: str | None = None, context_window_max: int | None = None, + user_id: str | None = None, ) -> WorkingMemoryResponse: """ Append new messages to existing working memory. @@ -2015,9 +2045,8 @@ async def append_messages_to_working_memory( Args: session_id: Target session - messages: List of messages to append + messages: List of message dictionaries with 'role' and 'content' keys namespace: Optional namespace - auto_summarize: Whether to allow automatic summarization model_name: Optional model name for token-based summarization context_window_max: Optional direct specification of context window max tokens @@ -2025,48 +2054,23 @@ async def append_messages_to_working_memory( WorkingMemoryResponse with updated memory (potentially summarized if token limit exceeded) """ # Get existing memory - existing_memory = None - with contextlib.suppress(Exception): - existing_memory = await self.get_working_memory( - session_id=session_id, namespace=namespace - ) + existing_memory = await self.get_working_memory( + session_id=session_id, namespace=namespace, user_id=user_id + ) - # Combine messages - convert MemoryMessage objects to dicts if needed - existing_messages = existing_memory.messages if existing_memory else [] - - # Convert existing messages to dict format if they're objects - converted_existing_messages = [] - for msg in existing_messages: - if hasattr(msg, "model_dump"): - converted_existing_messages.append(msg.model_dump()) - elif hasattr(msg, "role") and hasattr(msg, "content"): - converted_existing_messages.append( - {"role": msg.role, "content": msg.content} - ) - elif isinstance(msg, dict): - # Message is already a dictionary, use as-is - converted_existing_messages.append(msg) - else: - # Fallback for any other message type - convert to string content - converted_existing_messages.append( # type: ignore - {"role": "user", "content": str(msg)} + # Validate new messages have required structure + for msg in messages: + if not isinstance(msg, dict) or "role" not in msg or "content" not in msg: + raise ValueError( + "All messages must be dictionaries with 'role' and 'content' keys" ) - # Convert new messages to dict format if they're objects - new_messages = [] - for msg in messages: - if hasattr(msg, "model_dump"): - new_messages.append(msg.model_dump()) - elif hasattr(msg, "role") and hasattr(msg, "content"): - new_messages.append({"role": msg.role, "content": msg.content}) - elif isinstance(msg, dict): - # Message is already a dictionary, use as-is - new_messages.append(msg) - else: - # Fallback - assume it's already in the right format - new_messages.append(msg) + # Get existing messages (already in proper dict format from get_working_memory) + existing_messages = [] + if existing_memory and existing_memory.messages: + existing_messages = existing_memory.messages - final_messages = converted_existing_messages + new_messages + final_messages = existing_messages + messages # Create updated working memory working_memory = WorkingMemory( @@ -2095,6 +2099,7 @@ async def memory_prompt( model_name: str | None = None, context_window_max: int | None = None, long_term_search: dict[str, Any] | None = None, + user_id: str | None = None, ) -> dict[str, Any]: """ Hydrate a user query with memory context and return a prompt ready to send to an LLM. @@ -2107,6 +2112,7 @@ async def memory_prompt( model_name: Optional model name to determine context window size context_window_max: Optional direct specification of context window tokens long_term_search: Optional search parameters for long-term memory + user_id: Optional user ID for the session Returns: Dict with messages hydrated with relevant memory context @@ -2151,6 +2157,8 @@ async def memory_prompt( ) if effective_context_window_max is not None: session_params["context_window_max"] = str(effective_context_window_max) + if user_id is not None: + session_params["user_id"] = user_id payload["session"] = session_params # Add long-term search parameters if provided diff --git a/agent-memory-client/tests/test_client.py b/agent-memory-client/tests/test_client.py index d9bc309..7847f1b 100644 --- a/agent-memory-client/tests/test_client.py +++ b/agent-memory-client/tests/test_client.py @@ -527,8 +527,8 @@ async def test_append_messages_to_working_memory(self, enhanced_test_client): ) new_messages = [ - MemoryMessage(role="assistant", content="Second message"), - MemoryMessage(role="user", content="Third message"), + {"role": "assistant", "content": "Second message"}, + {"role": "user", "content": "Third message"}, ] with ( diff --git a/agent_memory_server/__init__.py b/agent_memory_server/__init__.py index 8e3a3ab..b685200 100644 --- a/agent_memory_server/__init__.py +++ b/agent_memory_server/__init__.py @@ -1,3 +1,3 @@ """Redis Agent Memory Server - A memory system for conversational AI.""" -__version__ = "0.9.0b3" +__version__ = "0.9.0b4" diff --git a/agent_memory_server/api.py b/agent_memory_server/api.py index bdaa2db..d7a21f5 100644 --- a/agent_memory_server/api.py +++ b/agent_memory_server/api.py @@ -188,7 +188,7 @@ async def list_sessions( Get a list of session IDs, with optional pagination. Args: - options: Query parameters (page, size, namespace) + options: Query parameters (limit, offset, namespace, user_id) Returns: List of session IDs @@ -200,6 +200,7 @@ async def list_sessions( limit=options.limit, offset=options.offset, namespace=options.namespace, + user_id=options.user_id, ) return SessionListResponse( @@ -211,8 +212,8 @@ async def list_sessions( @router.get("/v1/working-memory/{session_id}", response_model=WorkingMemoryResponse) async def get_working_memory( session_id: str, + user_id: str | None = None, namespace: str | None = None, - window_size: int = settings.window_size, # Deprecated: kept for backward compatibility model_name: ModelNameLiteral | None = None, context_window_max: int | None = None, current_user: UserInfo = Depends(get_current_user), @@ -225,8 +226,8 @@ async def get_working_memory( Args: session_id: The session ID + user_id: The user ID to retrieve working memory for namespace: The namespace to use for the session - window_size: DEPRECATED - The number of messages to include (kept for backward compatibility) model_name: The client's LLM model name (will determine context window size if provided) context_window_max: Direct specification of the context window max tokens (overrides model_name) @@ -240,6 +241,7 @@ async def get_working_memory( session_id=session_id, namespace=namespace, redis_client=redis, + user_id=user_id, ) if not working_mem: @@ -249,6 +251,7 @@ async def get_working_memory( memories=[], session_id=session_id, namespace=namespace, + user_id=user_id, ) # Apply token-based truncation if we have messages and model info @@ -266,10 +269,6 @@ async def get_working_memory( break working_mem.messages = truncated_messages - # Fallback to message-count truncation for backward compatibility - elif len(working_mem.messages) > window_size: - working_mem.messages = working_mem.messages[-window_size:] - return working_mem @@ -277,6 +276,7 @@ async def get_working_memory( async def put_working_memory( session_id: str, memory: WorkingMemory, + user_id: str | None = None, model_name: ModelNameLiteral | None = None, context_window_max: int | None = None, background_tasks=Depends(get_background_tasks), @@ -291,6 +291,7 @@ async def put_working_memory( Args: session_id: The session ID memory: Working memory to save + user_id: Optional user ID for the session (overrides user_id in memory object) model_name: The client's LLM model name for context window determination context_window_max: Direct specification of context window max tokens background_tasks: DocketBackgroundTasks instance (injected automatically) @@ -303,6 +304,10 @@ async def put_working_memory( # Ensure session_id matches memory.session_id = session_id + # Override user_id if provided as query parameter + if user_id is not None: + memory.user_id = user_id + # Validate that all structured memories have id (if any) for mem in memory.memories: if not mem.id: @@ -359,6 +364,7 @@ async def put_working_memory( @router.delete("/v1/working-memory/{session_id}", response_model=AckResponse) async def delete_working_memory( session_id: str, + user_id: str | None = None, namespace: str | None = None, current_user: UserInfo = Depends(get_current_user), ): @@ -369,6 +375,7 @@ async def delete_working_memory( Args: session_id: The session ID + user_id: Optional user ID for the session namespace: Optional namespace for the session Returns: @@ -379,6 +386,7 @@ async def delete_working_memory( # Delete unified working memory await working_memory.delete_working_memory( session_id=session_id, + user_id=user_id, namespace=namespace, redis_client=redis, ) @@ -558,6 +566,7 @@ async def memory_prompt( working_mem = await working_memory.get_working_memory( session_id=params.session.session_id, namespace=params.session.namespace, + user_id=params.session.user_id, redis_client=redis, ) diff --git a/agent_memory_server/config.py b/agent_memory_server/config.py index e5197b5..01a3ab5 100644 --- a/agent_memory_server/config.py +++ b/agent_memory_server/config.py @@ -20,7 +20,6 @@ def load_yaml_settings(): class Settings(BaseSettings): redis_url: str = "redis://localhost:6379" long_term_memory: bool = True - window_size: int = 20 openai_api_key: str | None = None anthropic_api_key: str | None = None generation_model: str = "gpt-4o-mini" @@ -66,6 +65,9 @@ class Settings(BaseSettings): auth0_client_id: str | None = None auth0_client_secret: str | None = None + # Working memory settings + window_size: int = 20 # Default number of recent messages to return + # Other Application settings log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" diff --git a/agent_memory_server/dependencies.py b/agent_memory_server/dependencies.py index f0d8ae4..44fead3 100644 --- a/agent_memory_server/dependencies.py +++ b/agent_memory_server/dependencies.py @@ -27,10 +27,19 @@ async def add_task( logger.info("Scheduling task through Docket") # Get the Redis connection that's already configured (will use testcontainer in tests) redis_conn = await get_redis_conn() - # Use the connection's URL instead of settings.redis_url directly - redis_url = redis_conn.connection_pool.connection_kwargs.get( - "url", settings.redis_url - ) + + # Extract Redis URL from the connection pool + connection_kwargs = redis_conn.connection_pool.connection_kwargs + if "host" in connection_kwargs and "port" in connection_kwargs: + redis_url = ( + f"redis://{connection_kwargs['host']}:{connection_kwargs['port']}" + ) + if "db" in connection_kwargs: + redis_url += f"/{connection_kwargs['db']}" + else: + # Fallback to settings if we can't extract from connection + redis_url = settings.redis_url + logger.info("redis_url: %s", redis_url) logger.info("docket_name: %s", settings.docket_name) async with Docket( diff --git a/agent_memory_server/long_term_memory.py b/agent_memory_server/long_term_memory.py index f3f6990..6e0c7bb 100644 --- a/agent_memory_server/long_term_memory.py +++ b/agent_memory_server/long_term_memory.py @@ -46,9 +46,6 @@ ) -DEFAULT_MEMORY_LIMIT = 1000 -MEMORY_INDEX = "memory_idx" - # Prompt for extracting memories from messages in working memory context WORKING_MEMORY_EXTRACTION_PROMPT = """ You are a memory extraction assistant. Your job is to analyze conversation @@ -354,15 +351,27 @@ async def compact_long_term_memories( # Find all memories with this hash # Use FT.SEARCH to find the actual memories with this hash # TODO: Use RedisVL index - search_query = ( - f"FT.SEARCH {index_name} " - f"(@memory_hash:{{{memory_hash}}}) {' '.join(filters)} " - "RETURN 6 id_ text last_accessed created_at user_id session_id " - "SORTBY last_accessed ASC" # Oldest first - ) + if filters: + # Combine hash query with filters using boolean AND + query_expr = f"(@memory_hash:{{{memory_hash}}}) ({' '.join(filters)})" + else: + query_expr = f"@memory_hash:{{{memory_hash}}}" search_results = await redis_client.execute_command( - search_query + "FT.SEARCH", + index_name, + f"'{query_expr}'", + "RETURN", + "6", + "id_", + "text", + "last_accessed", + "created_at", + "user_id", + "session_id", + "SORTBY", + "last_accessed", + "ASC", ) if search_results and search_results[0] > 1: @@ -1209,15 +1218,24 @@ async def deduplicate_by_hash( # Use FT.SEARCH to find memories with this hash # TODO: Use RedisVL - search_query = ( - f"FT.SEARCH {index_name} " - f"(@memory_hash:{{{memory_hash}}}) {filter_str} " - "RETURN 1 id_ " - "SORTBY last_accessed DESC" # Newest first + if filter_str: + # Combine hash query with filters using boolean AND + query_expr = f"(@memory_hash:{{{memory_hash}}}) ({filter_str})" + else: + query_expr = f"@memory_hash:{{{memory_hash}}}" + + search_results = await redis_client.execute_command( + "FT.SEARCH", + index_name, + f"'{query_expr}'", + "RETURN", + "1", + "id_", + "SORTBY", + "last_accessed", + "DESC", ) - search_results = await redis_client.execute_command(search_query) - if search_results and search_results[0] > 0: # Found existing memory with the same hash logger.info(f"Found existing memory with hash {memory_hash}") @@ -1285,15 +1303,25 @@ async def deduplicate_by_id( # Use FT.SEARCH to find memories with this id # TODO: Use RedisVL - search_query = ( - f"FT.SEARCH {index_name} " - f"(@id:{{{memory.id}}}) {filter_str} " - "RETURN 2 id_ persisted_at " - "SORTBY last_accessed DESC" # Newest first + if filter_str: + # Combine the id query with filters - Redis FT.SEARCH uses implicit AND between terms + query_expr = f"@id:{{{memory.id}}} {filter_str}" + else: + query_expr = f"@id:{{{memory.id}}}" + + search_results = await redis_client.execute_command( + "FT.SEARCH", + index_name, + f"'{query_expr}'", + "RETURN", + "2", + "id_", + "persisted_at", + "SORTBY", + "last_accessed", + "DESC", ) - search_results = await redis_client.execute_command(search_query) - if search_results and search_results[0] > 0: # Found existing memory with the same id logger.info(f"Found existing memory with id {memory.id}, will overwrite") diff --git a/agent_memory_server/mcp.py b/agent_memory_server/mcp.py index 7deeccb..cdde4e8 100644 --- a/agent_memory_server/mcp.py +++ b/agent_memory_server/mcp.py @@ -562,6 +562,7 @@ async def memory_prompt( session = WorkingMemoryRequest( session_id=_session_id, namespace=namespace.eq if namespace and namespace.eq else None, + user_id=user_id.eq if user_id and user_id.eq else None, window_size=window_size, model_name=model_name, context_window_max=context_window_max, diff --git a/agent_memory_server/models.py b/agent_memory_server/models.py index 10357e1..6411507 100644 --- a/agent_memory_server/models.py +++ b/agent_memory_server/models.py @@ -216,6 +216,7 @@ class WorkingMemoryRequest(BaseModel): session_id: str namespace: str | None = None + user_id: str | None = None window_size: int = settings.window_size model_name: ModelNameLiteral | None = None context_window_max: int | None = None @@ -257,6 +258,7 @@ class GetSessionsQuery(BaseModel): limit: int = Field(default=20, ge=1, le=100) offset: int = Field(default=0, ge=0) namespace: str | None = None + user_id: str | None = None class HealthCheckResponse(BaseModel): diff --git a/agent_memory_server/utils/keys.py b/agent_memory_server/utils/keys.py index 17fc35b..aec1b77 100644 --- a/agent_memory_server/utils/keys.py +++ b/agent_memory_server/utils/keys.py @@ -56,13 +56,22 @@ def metadata_key(session_id: str, namespace: str | None = None) -> str: ) @staticmethod - def working_memory_key(session_id: str, namespace: str | None = None) -> str: + def working_memory_key( + session_id: str, user_id: str | None = None, namespace: str | None = None + ) -> str: """Get the working memory key for a session.""" - return ( - f"working_memory:{namespace}:{session_id}" - if namespace - else f"working_memory:{session_id}" - ) + # Build key components, filtering out None values + key_parts = ["working_memory"] + + if namespace: + key_parts.append(namespace) + + if user_id: + key_parts.append(user_id) + + key_parts.append(session_id) + + return ":".join(key_parts) @staticmethod def search_index_name() -> str: diff --git a/agent_memory_server/working_memory.py b/agent_memory_server/working_memory.py index a754d11..182d1e1 100644 --- a/agent_memory_server/working_memory.py +++ b/agent_memory_server/working_memory.py @@ -27,8 +27,26 @@ async def list_sessions( limit: int = 10, offset: int = 0, namespace: str | None = None, + user_id: str | None = None, ) -> tuple[int, list[str]]: - """List sessions""" + """ + List sessions + + Args: + redis: Redis client + limit: Maximum number of sessions to return + offset: Offset for pagination + namespace: Optional namespace filter + user_id: Optional user ID filter (not yet implemented - sessions are stored in sorted sets) + + Returns: + Tuple of (total_count, session_ids) + + Note: + The user_id parameter is accepted for API compatibility but filtering by user_id + is not yet implemented. This would require changing how sessions are stored to + enable efficient user_id-based filtering. + """ # Calculate start and end indices (0-indexed start, inclusive end) start = offset end = offset + limit - 1 @@ -47,9 +65,9 @@ async def list_sessions( async def get_working_memory( session_id: str, + user_id: str | None = None, namespace: str | None = None, redis_client: Redis | None = None, - effective_window_size: int | None = None, ) -> WorkingMemory | None: """ Get working memory for a session. @@ -65,7 +83,11 @@ async def get_working_memory( if not redis_client: redis_client = await get_redis_conn() - key = Keys.working_memory_key(session_id, namespace) + key = Keys.working_memory_key( + session_id=session_id, + user_id=user_id, + namespace=namespace, + ) try: data = await redis_client.get(key) @@ -132,7 +154,11 @@ async def set_working_memory( if not memory.id: raise ValueError("All memory records in working memory must have an id") - key = Keys.working_memory_key(working_memory.session_id, working_memory.namespace) + key = Keys.working_memory_key( + session_id=working_memory.session_id, + user_id=working_memory.user_id, + namespace=working_memory.namespace, + ) # Update the updated_at timestamp working_memory.updated_at = datetime.now(UTC) @@ -179,6 +205,7 @@ async def set_working_memory( async def delete_working_memory( session_id: str, + user_id: str | None = None, namespace: str | None = None, redis_client: Redis | None = None, ) -> None: @@ -187,13 +214,16 @@ async def delete_working_memory( Args: session_id: The session ID + user_id: Optional user ID for the session namespace: Optional namespace for the session redis_client: Optional Redis client """ if not redis_client: redis_client = await get_redis_conn() - key = Keys.working_memory_key(session_id, namespace) + key = Keys.working_memory_key( + session_id=session_id, user_id=user_id, namespace=namespace + ) try: await redis_client.delete(key) diff --git a/docs/getting-started.md b/docs/getting-started.md index 9c2746d..09576cb 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -22,21 +22,27 @@ uv sync ## Running -The easiest way to start the REST and MCP servers is to use Docker Compose. See the Docker Compose section below for more details. +The easiest way to start the worker, REST API server, and MCP server is to use Docker Compose. See the Docker Compose section below for more details. -But you can also run these servers via the CLI commands. Here's how you +But you can also run these components via the CLI commands. Here's how you run the REST API server: ```bash uv run agent-memory api ``` -And the MCP server: +Or the MCP server: ```bash uv run agent-memory mcp --mode ``` +Both servers require a worker to be running, which you can start like this: + +```bash +uv run agent-memory task-worker +``` + **NOTE:** With uv, prefix the command with `uv`, e.g.: `uv run agent-memory --mode sse`. If you installed from source, you'll probably need to add `--directory` to tell uv where to find the code: `uv run --directory run agent-memory --mode stdio`. ## Docker Compose diff --git a/docs/memory-types.md b/docs/memory-types.md index b7856d1..c1cf549 100644 --- a/docs/memory-types.md +++ b/docs/memory-types.md @@ -356,9 +356,6 @@ LONG_TERM_MEMORY=true # Enable long-term memory features # Long-term memory settings ENABLE_DISCRETE_MEMORY_EXTRACTION=true # Extract memories from messages GENERATION_MODEL=gpt-4o-mini # Model for summarization/extraction - -# Search settings -DEFAULT_MEMORY_LIMIT=1000 # Default search result limit ``` For complete configuration options, see the [Configuration Guide](configuration.md). diff --git a/pyproject.toml b/pyproject.toml index 7cc8f02..be674aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,6 +122,14 @@ quote-style = "double" # Use spaces for indentation indent-style = "space" +[tool.uv.sources] +agent-memory-client = { path = "agent-memory-client" } + +[project.optional-dependencies] +dev = [ + "agent-memory-client" +] + [dependency-groups] dev = [ "pytest>=8.3.5", diff --git a/tests/conftest.py b/tests/conftest.py index 258e5fc..7eefafd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -275,11 +275,11 @@ def redis_url(redis_container): @pytest.fixture() -def async_redis_client(redis_url): +def async_redis_client(use_test_redis_connection): """ - An async Redis client that uses the dynamic `redis_url`. + An async Redis client that uses the same connection as other test fixtures. """ - return AsyncRedis.from_url(redis_url) + return use_test_redis_connection @pytest.fixture() diff --git a/tests/test_api.py b/tests/test_api.py index ed8729c..43b20dc 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -82,7 +82,7 @@ async def test_get_memory(self, client, session): session_id = session response = await client.get( - f"/v1/working-memory/{session_id}?namespace=test-namespace" + f"/v1/working-memory/{session_id}?namespace=test-namespace&user_id=test-user" ) assert response.status_code == 200 @@ -288,7 +288,7 @@ async def test_delete_memory(self, client, session): session_id = session response = await client.get( - f"/v1/working-memory/{session_id}?namespace=test-namespace" + f"/v1/working-memory/{session_id}?namespace=test-namespace&user_id=test-user" ) assert response.status_code == 200 @@ -297,7 +297,7 @@ async def test_delete_memory(self, client, session): assert len(data["messages"]) == 2 response = await client.delete( - f"/v1/working-memory/{session_id}?namespace=test-namespace" + f"/v1/working-memory/{session_id}?namespace=test-namespace&user_id=test-user" ) assert response.status_code == 200 @@ -307,7 +307,7 @@ async def test_delete_memory(self, client, session): assert data["status"] == "ok" response = await client.get( - f"/v1/working-memory/{session_id}?namespace=test-namespace" + f"/v1/working-memory/{session_id}?namespace=test-namespace&user_id=test-user" ) assert response.status_code == 200 diff --git a/tests/test_cli.py b/tests/test_cli.py index e3ae11e..2d06215 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -229,11 +229,11 @@ class TestTaskWorker: @patch("docket.Worker.run") @patch("agent_memory_server.cli.settings") - def test_task_worker_success(self, mock_settings, mock_worker_run): + def test_task_worker_success(self, mock_settings, mock_worker_run, redis_url): """Test successful task worker start.""" mock_settings.use_docket = True mock_settings.docket_name = "test-docket" - mock_settings.redis_url = "redis://localhost:6379/0" + mock_settings.redis_url = redis_url mock_worker_run.return_value = None @@ -258,11 +258,13 @@ def test_task_worker_docket_disabled(self, mock_settings): @patch("docket.Worker.run") @patch("agent_memory_server.cli.settings") - def test_task_worker_default_params(self, mock_settings, mock_worker_run): + def test_task_worker_default_params( + self, mock_settings, mock_worker_run, redis_url + ): """Test task worker with default parameters.""" mock_settings.use_docket = True mock_settings.docket_name = "test-docket" - mock_settings.redis_url = "redis://localhost:6379/0" + mock_settings.redis_url = redis_url mock_worker_run.return_value = None diff --git a/tests/test_client_enhancements.py b/tests/test_client_enhancements.py index d7fa483..93a0234 100644 --- a/tests/test_client_enhancements.py +++ b/tests/test_client_enhancements.py @@ -4,20 +4,19 @@ import pytest from agent_memory_client import MemoryAPIClient, MemoryClientConfig -from fastapi import FastAPI -from httpx import ASGITransport, AsyncClient - -from agent_memory_server.api import router as memory_router -from agent_memory_server.healthcheck import router as health_router -from agent_memory_server.models import ( +from agent_memory_client.models import ( AckResponse, ClientMemoryRecord, - MemoryMessage, MemoryRecordResult, MemoryRecordResults, MemoryTypeEnum, WorkingMemoryResponse, ) +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient + +from agent_memory_server.api import router as memory_router +from agent_memory_server.healthcheck import router as health_router @pytest.fixture @@ -520,7 +519,7 @@ async def test_append_messages_to_working_memory(self, enhanced_test_client): session_id = "test-session" existing_messages = [ - MemoryMessage(role="user", content="First message"), + {"role": "user", "content": "First message"}, ] existing_memory = WorkingMemoryResponse( @@ -533,8 +532,8 @@ async def test_append_messages_to_working_memory(self, enhanced_test_client): ) new_messages = [ - MemoryMessage(role="assistant", content="Second message"), - MemoryMessage(role="user", content="Third message"), + {"role": "assistant", "content": "Second message"}, + {"role": "user", "content": "Third message"}, ] with ( diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 7be3092..ba139d9 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -102,6 +102,7 @@ async def test_memory_prompt(self, session, mcp_test_setup): "query": "Test query", "session_id": {"eq": session}, "namespace": {"eq": "test-namespace"}, + "user_id": {"eq": "test-user"}, }, ) assert isinstance(prompt, CallToolResult) diff --git a/uv.lock b/uv.lock index 0ba3608..292ce56 100644 --- a/uv.lock +++ b/uv.lock @@ -19,6 +19,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/63/b1/8198e3cdd11a426b1df2912e3381018c4a4a55368f6d0857ba3ca418ef93/accelerate-1.6.0-py3-none-any.whl", hash = "sha256:1aee717d3d3735ad6d09710a7c26990ee4652b79b4e93df46551551b5227c2aa", size = 354748 }, ] +[[package]] +name = "agent-memory-client" +source = { directory = "agent-memory-client" } +dependencies = [ + { name = "httpx" }, + { name = "pydantic" }, + { name = "python-ulid" }, +] + +[package.metadata] +requires-dist = [ + { name = "httpx", specifier = ">=0.25.0" }, + { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.5.0" }, + { name = "pydantic", specifier = ">=2.0.0" }, + { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, + { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, + { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0.0" }, + { name = "pytest-httpx", marker = "extra == 'dev'", specifier = ">=0.21.0" }, + { name = "python-ulid", specifier = ">=3.0.0" }, + { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" }, +] + [[package]] name = "agent-memory-server" source = { editable = "." } @@ -50,6 +72,11 @@ dependencies = [ { name = "uvicorn" }, ] +[package.optional-dependencies] +dev = [ + { name = "agent-memory-client" }, +] + [package.dev-dependencies] dev = [ { name = "freezegun" }, @@ -66,6 +93,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "accelerate", specifier = ">=1.6.0" }, + { name = "agent-memory-client", marker = "extra == 'dev'", directory = "agent-memory-client" }, { name = "anthropic", specifier = ">=0.15.0" }, { name = "bertopic", specifier = ">=0.16.4,<0.17.0" }, { name = "click", specifier = ">=8.1.0" },