diff --git a/flagscale/agent/__init__.py b/flagscale/agent/__init__.py index e69de29bb..57a939c96 100644 --- a/flagscale/agent/__init__.py +++ b/flagscale/agent/__init__.py @@ -0,0 +1,26 @@ +"""FlagScale Agent Module + +Modules for collaboration, tool matching, and memory management. +""" + +# Export collaboration module +from .collaboration import Collaborator + +# Export memory module +from .memory import MemoryManager, MemoryModuleConfig, create_memory_toolkit, register_memory_tools + +# Export tool matching module +from .tool_match import ToolMatcher, ToolRegistry + +__all__ = [ + # Collaboration + "Collaborator", + # Tool matching + "ToolRegistry", + "ToolMatcher", + # Memory management + "MemoryManager", + "register_memory_tools", + "create_memory_toolkit", + "MemoryModuleConfig", +] diff --git a/flagscale/agent/memory/README.md b/flagscale/agent/memory/README.md new file mode 100644 index 000000000..2f85210ea --- /dev/null +++ b/flagscale/agent/memory/README.md @@ -0,0 +1,265 @@ +# FlagScale Agent Memory Module + +Memory management module for FlagScale Agent, providing short-term and long-term memory capabilities. + +## Overview + +- **Short-term Memory (InMemoryMemory)**: Fast access to conversation history, stored in memory +- **Long-term Memory (Mem0LongTermMemory)**: Vector retrieval and semantic search based on mem0 + +### Difference from Collaborator + +| Module | Layer | Responsibility | Storage | Scope | +|--------|-------|----------------|---------|-------| +| **Collaborator** | Coordination | Agent collaboration, state synchronization, message passing | Redis (distributed) | Cross-Agent | +| **Memory** | Cognitive | Conversation history, knowledge retrieval, context management | Local memory/Vector database | Single Agent internal | + +## Quick Start + +### Basic Usage (Short-term Memory Only) + +```python +from flagscale.agent.memory import MemoryManager, Msg +import asyncio + +async def main(): + memory = MemoryManager(enable_short_term=True, max_size=100) + + await memory.short_term_memory.add(Msg(role="user", content="Hello")) + await memory.short_term_memory.add(Msg(role="assistant", content="Hi there!")) + + results = await memory.short_term_memory.retrieve("Hello", limit=5) + print(f"Found {len(results)} related memories") + +asyncio.run(main()) +``` + +### Integration with ToolRegistry + +```python +from flagscale.agent.tool_match import ToolRegistry +from flagscale.agent.memory import MemoryManager, register_memory_tools + +memory = MemoryManager(enable_short_term=True, max_size=100) +registry = ToolRegistry() +registered = register_memory_tools(registry, memory) +``` + +### Using Configuration File + +```python +from flagscale.agent.memory import load_config_from_file, MemoryManager + +# Load configuration +config = load_config_from_file("memory_config.yaml") + +# Create memory manager (LLM and Embedding models required) +memory = MemoryManager.from_config( + config, + llm_model=your_llm_model, + embedding_model=your_embedding_model, +) +``` + +## Configuration + +### Configuration File Format + +```yaml +memory: + # Short-term memory configuration + short_term: + max_size: 1000 # Maximum number of messages to store + auto_cleanup: true # Automatically clean up old messages + + # Long-term memory configuration + long_term: + provider: "mem0" # Use mem0 as long-term memory provider + agent_name: "my_agent" # Agent name (required when enabling long-term memory) + user_name: null # User name (optional) + run_name: null # Run session name (optional) + + # Vector 
store configuration + vector_store_type: "qdrant" # Vector database type + vector_store_path: "./qdrant_storage" # Storage path + collection_name: "memory_collection" # Collection name + on_disk: true # Use disk persistence + + default_memory_type: null # Default memory type + + # LLM model configuration (optional, provide in code if not configured) + llm: + provider: "openai" # Provider: openai + model: "gpt-3.5-turbo" # Model name + api_base: null # API endpoint (null uses default endpoint) + api_key: "${OPENAI_API_KEY}" # API key (supports environment variable format) + temperature: 0.7 # Temperature parameter + max_tokens: 2000 # Maximum tokens + + # Embedding model configuration (optional, provide in code if not configured) + embedding: + provider: "openai" # Provider: openai + model: "text-embedding-3-small" # Model name + api_base: null # API endpoint (null uses default endpoint) + api_key: "${OPENAI_API_KEY}" # API key (supports environment variable format) + dimensions: 1536 # Vector dimensions + + # Global configuration + enable_short_term: true # Enable short-term memory + enable_long_term: false # Enable long-term memory + auto_record: false # Automatically record to long-term memory + auto_record_threshold: 10 # Auto-record message count threshold +``` + +### Loading Configuration File + +```python +from flagscale.agent.memory import load_config_from_file, load_config, MemoryManager + +# Method 1: Direct path specification +config = load_config_from_file("memory_config.yaml") + +# Method 2: Read path from environment variable +# export FLAGSCALE_MEMORY_CONFIG=/path/to/memory_config.yaml +config = load_config() + +# Create from configuration +memory = MemoryManager.from_config(config, llm_model=llm, embedding_model=emb) +``` + +## LLM and Embedding Model Setup + +Long-term memory requires LLM and Embedding models. Models must be callable and support synchronous or asynchronous interfaces. + +### Embedding Dimension Adaptation + +Different Embedding models have different output dimensions (e.g., OpenAI text-embedding-3-small is 1536-dimensional, text-embedding-3-large is 3072-dimensional). The system adapts through the following mechanisms: + +1. **Configuration File Specification**: Specify dimensions in the `embedding.dimensions` field of the configuration file + ```yaml + embedding: + provider: "openai" + model: "text-embedding-3-small" + dimensions: 1536 # Specify vector dimensions + ``` + +2. **Automatic Transfer to Vector Store**: Dimension information is automatically passed to mem0's vector store configuration (`embedding_model_dims`), ensuring the vector database uses the correct dimensions + +3. **OpenAI Models**: For OpenAI models that support dimension parameters, the `dimensions` parameter is passed to the API, allowing you to adjust output dimensions + +4. **Custom Models**: When using custom Embedding models, you need to manually specify `dimensions` in the configuration, or pass it through the `embedding_dimensions` parameter: + ```python + memory = MemoryManager( + config=config, + embedding_model=your_custom_embedding_model, + embedding_dimensions=1024 # Specify custom model dimensions + ) + ``` + +**Note**: If dimensions are not specified, some vector databases may not initialize correctly. It is recommended to always explicitly specify dimensions in the configuration. 
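+
+For reference, the sketch below shows roughly how the configured dimension reaches the vector store: when long-term memory is initialized, `embedding.dimensions` is copied into mem0's `embedding_model_dims` field alongside the other Qdrant settings from the configuration. The dictionary is illustrative only, not a public API:
+
+```python
+embedding_cfg = {"provider": "openai", "model": "text-embedding-3-small", "dimensions": 1536}
+
+# Roughly what the long-term memory backend assembles for mem0's vector store
+vector_store_cfg = {
+    "provider": "qdrant",
+    "config": {
+        "collection_name": "memory_collection",
+        "path": "./qdrant_storage",
+        "on_disk": True,
+        "embedding_model_dims": embedding_cfg["dimensions"],  # taken from embedding.dimensions
+    },
+}
+```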
+ +### Model Interfaces + +**LLM Model Interface:** +```python +def llm_model(messages: List[Dict[str, str]], tools: List[Dict] = None) -> str: + # messages: [{"role": "user", "content": "..."}, ...] + # Returns: str or object containing content/text/message.content + pass +``` + +**Embedding Model Interface:** +```python +from typing import List, Union + +def embedding_model(texts: List[str]) -> Union[List[float], List[List[float]]]: + # texts: List of texts + # Returns: List[float] for single text, List[List[float]] for multiple texts + pass +``` + +### Usage Example + +```python +from openai import AsyncOpenAI +from flagscale.agent.memory import load_config_from_file, MemoryManager + +config = load_config_from_file("memory_config.yaml") +client = AsyncOpenAI(api_key="your-api-key") + +async def openai_llm(messages, tools=None): + response = await client.chat.completions.create( + model="gpt-3.5-turbo", messages=messages, tools=tools + ) + return response.choices[0].message.content + +async def openai_embedding(texts): + response = await client.embeddings.create( + model="text-embedding-3-small", input=texts + ) + embeddings = [item.embedding for item in response.data] + return embeddings[0] if len(texts) == 1 else embeddings + +memory = MemoryManager.from_config( + config, + llm_model=openai_llm, + embedding_model=openai_embedding, +) +``` + +## API Reference + +### MemoryManager + +```python +memory = MemoryManager( + enable_short_term=True, # Whether to enable short-term memory + enable_long_term=False, # Whether to enable long-term memory + max_size=1000, # Maximum capacity of short-term memory + llm_model=None, # LLM model (required for long-term memory) + embedding_model=None, # Embedding model (required for long-term memory) + agent_name=None, # Agent name + vector_store_path="./qdrant_storage", # Vector store path +) +``` + +### register_memory_tools + +```python +registered_tools = register_memory_tools( + tool_registry, # ToolRegistry instance + memory_manager, # MemoryManager instance + category="memory" # Tool category +) +``` + +## FAQ + +**Q: What if the configuration file is not found?** +A: You must explicitly specify the configuration file path, or use `create_if_not_exists=True` to automatically create it. + +**Q: Must models be callable?** +A: Yes, models must implement the `__call__` method. + +**Q: Are both synchronous and asynchronous interfaces supported?** +A: Yes, the adapter automatically detects and handles both. + +**Q: What dependencies are required for long-term memory?** +A: You need to install `mem0ai >= 0.1.115` and `qdrant-client`. + +**Q: What's the difference between short-term and long-term memory?** +A: Short-term memory is memory-based for fast access to conversation history; long-term memory is vector-based retrieval requiring LLM and Embedding models. + +## Dependencies + +- `mem0ai >= 0.1.115`: Long-term memory backend (optional) +- `qdrant-client`: Vector database client (optional) +- `pyyaml`: Configuration file support +- `pydantic >= 2.0`: Data validation + +## Notes + +1. **Short-term Memory**: No additional dependencies required, can be used directly +2. **Long-term Memory**: Requires LLM and Embedding models, as well as the mem0 library +3. **Qdrant**: Uses embedded mode by default, no separate service deployment required +4. 
**Asynchronous Operations**: All memory operations are asynchronous and require `await` diff --git a/flagscale/agent/memory/__init__.py b/flagscale/agent/memory/__init__.py new file mode 100644 index 000000000..6061a48f1 --- /dev/null +++ b/flagscale/agent/memory/__init__.py @@ -0,0 +1,51 @@ +"""FlagScale Agent Memory Module + +Memory management module providing short-term and long-term memory capabilities. +""" + +from .base import LongTermMemoryBase, MemoryBase, Msg, StateModule, TextBlock, ToolResponse +from .long_term_memory import Mem0LongTermMemory +from .memory_config import ( + EmbeddingConfig, + LLMConfig, + LongTermMemoryConfig, + MemoryModuleConfig, + ModelFactory, + ShortTermMemoryConfig, + create_config_file, + create_default_config, + get_config_from_env, + load_config, + load_config_from_file, +) +from .memory_manager import MemoryManager +from .memory_tools import MemoryToolkit, create_memory_toolkit, register_memory_tools +from .short_term_memory import InMemoryMemory + +__all__ = [ + "MemoryManager", + "register_memory_tools", + "create_memory_toolkit", + "MemoryBase", + "LongTermMemoryBase", + "StateModule", + "Msg", + "TextBlock", + "ToolResponse", + "InMemoryMemory", + "Mem0LongTermMemory", + "MemoryToolkit", + "MemoryModuleConfig", + "ShortTermMemoryConfig", + "LongTermMemoryConfig", + "LLMConfig", + "EmbeddingConfig", + "create_default_config", + "load_config_from_file", + "load_config", + "create_config_file", + "get_config_from_env", + "ModelFactory", +] + +__version__ = "1.0.0" diff --git a/flagscale/agent/memory/base.py b/flagscale/agent/memory/base.py new file mode 100644 index 000000000..b35dc6d76 --- /dev/null +++ b/flagscale/agent/memory/base.py @@ -0,0 +1,346 @@ +"""Base class definitions and message data structures for memory system.""" + +import uuid + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Dict, Iterable, List, Optional, Union + + +@dataclass +class Msg: + """Message class compatible with AgentScope's Msg structure. + + This is the core data structure in the memory system for storing and passing messages. + """ + + id: str = field(default_factory=lambda: str(uuid.uuid4())) + """Unique identifier for the message""" + + role: str = "assistant" + """Message role, can be 'user', 'assistant', or 'system'""" + + content: Union[str, List[Dict[str, Any]]] = "" + """Message content, can be a string or a list of structured content blocks""" + + name: Optional[str] = None + """Message sender name (optional)""" + + metadata: Optional[Dict[str, Any]] = None + """Message metadata (optional)""" + + timestamp: datetime = field(default_factory=datetime.now) + """Message timestamp""" + + def to_dict(self) -> Dict[str, Any]: + """Convert message to dictionary format. + + Returns: + Dict[str, Any]: Dictionary representation of the message + """ + data = { + "id": self.id, + "role": self.role, + "content": self.content, + "timestamp": self.timestamp.isoformat(), + } + + if self.name is not None: + data["name"] = self.name + + if self.metadata is not None: + data["metadata"] = self.metadata + + return data + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Msg": + """Create message instance from dictionary. 
+ + Args: + data (Dict[str, Any]): Dictionary containing message data + + Returns: + Msg: Message instance + """ + # Handle timestamp + timestamp = data.get("timestamp") + if isinstance(timestamp, str): + timestamp = datetime.fromisoformat(timestamp) + elif timestamp is None: + timestamp = datetime.now() + + return cls( + id=data.get("id", str(uuid.uuid4())), + role=data.get("role", "assistant"), + content=data.get("content", ""), + name=data.get("name"), + metadata=data.get("metadata"), + timestamp=timestamp, + ) + + def __str__(self) -> str: + """Return string representation of the message.""" + content_str = self.content + if isinstance(self.content, list): + # If structured content, extract text parts + text_parts = [ + str(block.get("text", "")) if isinstance(block, dict) else str(block) + for block in self.content + ] + content_str = " ".join(text_parts) + + return f"Msg(role={self.role}, content={content_str[:50]}...)" + + +@dataclass +class TextBlock: + """Text content block.""" + + text: str = "" + + def __init__(self, text: str = ""): + """Initialize text block. + + Args: + text (str): Text content + """ + self.text = text + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + return {"type": "text", "text": self.text} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TextBlock": + """Create instance from dictionary.""" + return cls(text=data.get("text", "")) + + +@dataclass +class ToolResponse: + """Tool response class for encapsulating tool function return results.""" + + content: Union[str, List[TextBlock], List[Dict[str, Any]]] = field(default_factory=list) + + def __init__(self, content: Union[str, List[TextBlock], List[Dict[str, Any]]] = None): + """Initialize tool response. + + Args: + content: Response content, can be a string, list of TextBlocks, or list of dictionaries + """ + if content is None: + content = [] + + if isinstance(content, str): + content = [TextBlock(text=content)] + elif isinstance(content, list): + # Convert dicts to TextBlocks + converted = [] + for item in content: + if isinstance(item, TextBlock): + converted.append(item) + elif isinstance(item, dict): + if item.get("type") == "text": + converted.append(TextBlock(text=item.get("text", ""))) + else: + # Keep other types of blocks unchanged + converted.append(item) + else: + # Convert to text block + converted.append(TextBlock(text=str(item))) + content = converted + + self.content = content + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + if isinstance(self.content, list): + return { + "content": [ + item.to_dict() if isinstance(item, TextBlock) else item for item in self.content + ] + } + return {"content": self.content} + + def get_text(self) -> str: + """Get plain text content of the response. + + Returns: + str: Plain text content + """ + if isinstance(self.content, str): + return self.content + + if isinstance(self.content, list): + text_parts = [] + for item in self.content: + if isinstance(item, TextBlock): + text_parts.append(item.text) + elif isinstance(item, dict): + text_parts.append(str(item.get("text", ""))) + else: + text_parts.append(str(item)) + return "\n".join(text_parts) + + return str(self.content) + + +class StateModule(ABC): + """State module base class providing state persistence functionality.""" + + @abstractmethod + def state_dict(self) -> dict: + """Get module state dictionary. 
+ + Returns: + dict: State dictionary + """ + pass + + @abstractmethod + def load_state_dict(self, state_dict: dict, strict: bool = True) -> None: + """Load module state from state dictionary. + + Args: + state_dict (dict): State dictionary + strict (bool): If True, raise error when keys are missing in state dictionary + """ + pass + + +class MemoryBase(StateModule): + """Base class for memory system, defining common interface for all memory types. + + This base class references AgentScope's MemoryBase design to ensure architectural consistency. + """ + + @abstractmethod + async def add(self, *args: Any, **kwargs: Any) -> None: + """Add items to memory. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + """ + pass + + @abstractmethod + async def delete(self, *args: Any, **kwargs: Any) -> None: + """Delete items from memory. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + """ + pass + + @abstractmethod + async def retrieve(self, *args: Any, **kwargs: Any) -> Any: + """Retrieve items from memory. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Any: Retrieved results + """ + pass + + @abstractmethod + async def size(self) -> int: + """Get memory size. + + Returns: + int: Number of items in memory + """ + pass + + @abstractmethod + async def clear(self) -> None: + """Clear memory content.""" + pass + + @abstractmethod + async def get_memory(self, *args: Any, **kwargs: Any) -> List[Msg]: + """Get memory content. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + List[Msg]: List of messages + """ + pass + + +class LongTermMemoryBase(StateModule): + """Long-term memory base class, defining dedicated interface for long-term memory. + + Long-term memory is a temporal memory management system that supports: + - record/retrieve: Developer interface for actively managing memory in code + - record_to_memory/retrieve_from_memory: Tool interface for agent autonomous calls + """ + + @abstractmethod + async def record(self, msgs: Union[Msg, List[Msg], None], **kwargs: Any) -> None: + """Developer interface: Record messages to long-term memory. + + This method is called by developers in code, e.g., automatically record after each conversation. + + Args: + msgs: Messages or list of messages to record + **kwargs: Additional keyword arguments + """ + pass + + @abstractmethod + async def retrieve(self, msg: Union[Msg, List[Msg], None], **kwargs: Any) -> str: + """Developer interface: Retrieve information from long-term memory. + + This method is called by developers in code, e.g., retrieve related memories before each response. + + Args: + msg: Messages or list of messages used for retrieval + **kwargs: Additional keyword arguments + + Returns: + str: Retrieved memory content + """ + pass + + @abstractmethod + async def record_to_memory(self, thinking: str, content: List[str], **kwargs: Any) -> Any: + """Tool interface: Record important information to long-term memory. + + This method is wrapped as a tool function for agent autonomous calls. + Agent can actively decide when to record information. + + Args: + thinking (str): Agent's thinking process and reasoning + content (List[str]): List of content to record + **kwargs: Additional keyword arguments + + Returns: + Tool response object + """ + pass + + @abstractmethod + async def retrieve_from_memory(self, keywords: List[str], **kwargs: Any) -> Any: + """Tool interface: Retrieve information from long-term memory. 
+ + This method is wrapped as a tool function for agent autonomous calls. + Agent can actively decide when to retrieve memories. + + Args: + keywords (List[str]): List of retrieval keywords + **kwargs: Additional keyword arguments + + Returns: + Tool response object + """ + pass diff --git a/flagscale/agent/memory/long_term_memory.py b/flagscale/agent/memory/long_term_memory.py new file mode 100644 index 000000000..c562dff2a --- /dev/null +++ b/flagscale/agent/memory/long_term_memory.py @@ -0,0 +1,723 @@ +"""Long-term memory implementation using mem0 library for vector retrieval and semantic search.""" + +import asyncio +import json +import logging + +from importlib import metadata +from typing import Any, Dict, List, Literal, Optional, Union + +from packaging import version + +from .base import LongTermMemoryBase, Msg, TextBlock, ToolResponse + +logger = logging.getLogger(__name__) + +# Try to import mem0 +try: + import mem0 + + from mem0.configs.base import MemoryConfig + from mem0.configs.embeddings.base import BaseEmbedderConfig + from mem0.configs.llms.base import BaseLlmConfig + from mem0.embeddings.base import EmbeddingBase + from mem0.llms.base import LLMBase + from mem0.vector_stores.configs import VectorStoreConfig + + MEM0_AVAILABLE = True +except ImportError as e: + MEM0_AVAILABLE = False + logger.warning("mem0 library not installed, long-term memory will be unavailable") + + +class CustomLLMAdapter(LLMBase): + """Custom LLM adapter for integrating custom LLM models into mem0. + + This adapter references AgentScope's AgentScopeLLM implementation, + can adapt any LLM model that implements the standard interface. + """ + + def __init__(self, config: BaseLlmConfig = None): + """Initialize LLM adapter. + + Args: + config: LLM configuration object, should contain 'model' parameter + + Raises: + ValueError: If required configuration parameters are missing + """ + super().__init__(config) + + if self.config.model is None: + raise ValueError("`model` parameter is required") + + self.llm_model = self.config.model + logger.info(f"CustomLLMAdapter initialized: {type(self.llm_model).__name__}") + + def generate_response( + self, + messages: List[Dict[str, str]], + response_format: Any = None, + tools: List[Dict] = None, + tool_choice: str = "auto", + ) -> str: + """Generate response using custom LLM model. + + Args: + messages: List of messages, each containing 'role' and 'content' + response_format: Response format (not used in this adapter) + tools: List of tools (not used in this adapter) + tool_choice: Tool choice method (not used in this adapter) + + Returns: + str: Generated response text + + Raises: + RuntimeError: If response generation fails + """ + try: + # Check if llm_model has __call__ method + if not callable(self.llm_model): + raise ValueError("LLM model must be callable") + + # Async models are not supported in this sync method to avoid event loop conflicts. + if asyncio.iscoroutinefunction(self.llm_model): + raise TypeError( + "Async LLM models are not supported by this adapter. " + "Please provide a synchronous callable." 
+ ) + + # Sync call + response = self.llm_model(messages, tools=tools) + + # Extract text response + return self._extract_text_from_response(response) + + except Exception as e: + logger.error(f"LLM response generation failed: {e}") + raise RuntimeError(f"Error generating response with custom LLM model: {str(e)}") from e + + async def _async_generate( + self, messages: List[Dict[str, str]], tools: List[Dict] = None + ) -> Any: + """Helper method for async response generation.""" + return await self.llm_model(messages, tools=tools) + + def _extract_text_from_response(self, response: Any) -> str: + """Extract text content from response. + + Args: + response: LLM model response object + + Returns: + str: Extracted text content + """ + # If response is already a string, return directly + if isinstance(response, str): + return response + + # If response has content attribute + if hasattr(response, "content"): + content = response.content + + # content is a string + if isinstance(content, str): + return content + + # content is a list (may contain multiple blocks) + if isinstance(content, list): + text_parts = [] + thinking_parts = [] + tool_parts = [] + + for block in content: + # Handle dict-type blocks + if isinstance(block, dict): + block_type = block.get("type") + + if block_type == "text": + text_parts.append(block.get("text", "")) + elif block_type == "thinking": + thinking_parts.append(f"[Thinking: {block.get('thinking', '')}]") + elif block_type == "tool_use": + tool_name = block.get("name") + tool_input = block.get("input", {}) + tool_parts.append(f"[Tool: {tool_name} - {str(tool_input)}]") + # Handle object-type blocks + elif hasattr(block, "type"): + if block.type == "text" and hasattr(block, "text"): + text_parts.append(block.text) + elif block.type == "thinking" and hasattr(block, "thinking"): + thinking_parts.append(f"[Thinking: {block.thinking}]") + + # Combine all parts + all_parts = thinking_parts + text_parts + tool_parts + if all_parts: + return "\n".join(all_parts) + + # If response has text attribute + if hasattr(response, "text"): + return response.text + + # If response has message attribute + if hasattr(response, "message") and hasattr(response.message, "content"): + return response.message.content + + # Finally, try to convert to string + return str(response) + + +class CustomEmbeddingAdapter(EmbeddingBase): + """Custom Embedding adapter for integrating custom Embedding models into mem0. + + This adapter references AgentScope's AgentScopeEmbedding implementation, + can adapt any Embedding model that implements the standard interface. + """ + + def __init__(self, config: BaseEmbedderConfig = None): + """Initialize Embedding adapter. + + Args: + config: Embedding configuration object, should contain 'model' parameter + + Raises: + ValueError: If required configuration parameters are missing + """ + super().__init__(config) + + if self.config.model is None: + raise ValueError("`model` parameter is required") + + self.embedding_model = self.config.model + logger.info(f"CustomEmbeddingAdapter initialized: {type(self.embedding_model).__name__}") + + def embed( + self, text: Union[str, List[str]], memory_action: Literal["add", "search", "update"] = None + ) -> List[float]: + """Generate embeddings using custom Embedding model. 
+ + Args: + text: Text or list of texts to embed + memory_action: Memory action type (not used in this adapter) + + Returns: + List[float]: Embedding vector + + Raises: + RuntimeError: If embedding generation fails + """ + try: + # Convert to list format + text_list = [text] if isinstance(text, str) else text + + # Check if embedding_model has __call__ method + if not callable(self.embedding_model): + raise ValueError("Embedding model must be callable") + + # Async models are not supported in this sync method to avoid event loop conflicts. + if asyncio.iscoroutinefunction(self.embedding_model): + raise TypeError( + "Async embedding models are not supported by this adapter. " + "Please provide a synchronous callable." + ) + + # Sync call + response = self.embedding_model(text_list) + + # Extract embedding vector + return self._extract_embedding_from_response(response) + + except Exception as e: + logger.error(f"Embedding generation failed: {e}") + raise RuntimeError( + f"Error generating embeddings with custom Embedding model: {str(e)}" + ) from e + + async def _async_embed(self, text_list: List[str]) -> Any: + """Helper method for async embedding generation.""" + return await self.embedding_model(text_list) + + def _extract_embedding_from_response(self, response: Any) -> List[float]: + """Extract embedding vector from response. + + Args: + response: Embedding model response object + + Returns: + List[float]: Embedding vector + + Raises: + ValueError: If embedding vector cannot be extracted + """ + # If response is already a list, return directly + if isinstance(response, list): + # Check if it's an embedding vector (list of numbers) + if response and isinstance(response[0], (int, float)): + return response + # If it's a list of objects, try to extract the first one + if hasattr(response[0], "embedding"): + return response[0].embedding + + # If response has embeddings attribute (list) + if hasattr(response, "embeddings"): + embeddings = response.embeddings + if isinstance(embeddings, list) and embeddings: + # Get first embedding + first_embedding = embeddings[0] + + # If it's a vector + if isinstance(first_embedding, list): + return first_embedding + + # If it's an object + if hasattr(first_embedding, "embedding"): + return first_embedding.embedding + + # If it's directly a vector + if isinstance(first_embedding, (int, float)): + return embeddings + + # If response has embedding attribute + if hasattr(response, "embedding"): + return response.embedding + + # If response has data attribute + if hasattr(response, "data"): + data = response.data + if isinstance(data, list) and data: + if hasattr(data[0], "embedding"): + return data[0].embedding + + raise ValueError( + f"Cannot extract embedding vector from response. Response type: {type(response)}" + ) + + +def register_custom_adapters_to_mem0(): + """Register custom adapters to mem0 factory. + + This function needs to be called before using mem0 to register custom LLM and Embedding adapters. 
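+
+    Returns:
+        bool: True if the adapters were registered successfully, False otherwise.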
+ """ + if not MEM0_AVAILABLE: + logger.error("mem0 library not available, cannot register adapters") + return False + + try: + from mem0.utils.factory import EmbedderFactory, LlmFactory + + # Check mem0 version + current_version = metadata.version("mem0ai") + is_mem0_version_low = version.parse(current_version) <= version.parse("0.1.115") + + # Register Embedding adapter + EmbedderFactory.provider_to_class["custom"] = f"{__name__}.CustomEmbeddingAdapter" + + # Register LLM adapter (use different format based on version) + if is_mem0_version_low: + # mem0 version <= 0.1.115 + LlmFactory.provider_to_class["custom"] = f"{__name__}.CustomLLMAdapter" + else: + # mem0 version > 0.1.115 + LlmFactory.provider_to_class["custom"] = (f"{__name__}.CustomLLMAdapter", BaseLlmConfig) + + logger.info(f"Custom adapters registered to mem0 (version: {current_version})") + return True + + except Exception as e: + logger.error(f"Failed to register custom adapters: {e}") + return False + + +class Mem0LongTermMemory(LongTermMemoryBase): + """Long-term memory implementation using mem0 library. + + This class references AgentScope's Mem0LongTermMemory implementation, providing: + - Vectorized storage and semantic retrieval + - Persistent storage (Qdrant) + - Tool interface for agent calls + - Developer interface for code management + + Key features: + - record/retrieve: Used by developers in code + - record_to_memory/retrieve_from_memory: Wrapped as tools for agent calls + """ + + def __init__( + self, + agent_name: str = None, + user_name: str = None, + run_name: str = None, + llm_model: Any = None, + embedding_model: Any = None, + vector_store_config: VectorStoreConfig = None, + mem0_config: MemoryConfig = None, + default_memory_type: str = None, + embedding_dimensions: Optional[int] = None, + **kwargs: Any, + ) -> None: + """Initialize long-term memory. 
+ + Args: + agent_name (str): Agent name (at least one of agent_name, user_name, run_name must be provided) + user_name (str): User name + run_name (str): Run/session name + llm_model: LLM model instance + embedding_model: Embedding model instance + vector_store_config: Vector store configuration + mem0_config: mem0 configuration (if provided, will override other configs) + default_memory_type (str): Default memory type + **kwargs: Additional configuration parameters + + Raises: + ImportError: If mem0 library is not installed + ValueError: If required parameters are missing + """ + super().__init__() + + if not MEM0_AVAILABLE: + raise ImportError( + "Please install mem0 library: pip install mem0ai\n" + "Long-term memory requires mem0 library support" + ) + + # Validate that at least one identifier is provided + if agent_name is None and user_name is None and run_name is None: + raise ValueError("At least one of agent_name, user_name, or run_name must be provided") + + # Store identifiers + self.agent_id = agent_name + self.user_id = user_name + self.run_id = run_name + + # Store embedding dimensions for config + self._embedding_dimensions = embedding_dimensions or kwargs.get("embedding_dimensions") + + # Initialize mem0 configuration + self._init_mem0_config( + llm_model=llm_model, + embedding_model=embedding_model, + vector_store_config=vector_store_config, + mem0_config=mem0_config, + embedding_dimensions=self._embedding_dimensions, + **kwargs, + ) + + # Store default memory type + self.default_memory_type = default_memory_type + + logger.info( + f"Mem0LongTermMemory initialized " + f"(agent: {agent_name}, user: {user_name}, run: {run_name})" + ) + + def _init_mem0_config( + self, + llm_model: Any, + embedding_model: Any, + vector_store_config: VectorStoreConfig, + mem0_config: MemoryConfig, + embedding_dimensions: Optional[int] = None, + **kwargs: Any, + ) -> None: + """Initialize mem0 configuration.""" + # Register custom adapters + register_custom_adapters_to_mem0() + + # Dynamically create config classes + _LlmConfig, _EmbedderConfig = self._create_config_classes() + + # Prepare embedder config (dimensions should be set in vector_store, not here) + embedder_config_dict = {"model": embedding_model} + + if mem0_config is not None: + # Use provided mem0 config, but allow overrides + if llm_model is not None: + mem0_config.llm = _LlmConfig(provider="custom", config={"model": llm_model}) + + if embedding_model is not None: + mem0_config.embedder = _EmbedderConfig( + provider="custom", config=embedder_config_dict + ) + + if vector_store_config is not None: + mem0_config.vector_store = vector_store_config + + else: + # Create new mem0 configuration + if llm_model is None or embedding_model is None: + raise ValueError( + "If mem0_config is not provided, llm_model and embedding_model must be provided" + ) + + mem0_config = mem0.configs.base.MemoryConfig( + llm=_LlmConfig(provider="custom", config={"model": llm_model}), + embedder=_EmbedderConfig(provider="custom", config=embedder_config_dict), + ) + + # Set vector store + if vector_store_config is not None: + mem0_config.vector_store = vector_store_config + else: + # Build VectorStoreConfig from kwargs + on_disk = kwargs.get("on_disk", True) + collection_name = kwargs.get("collection_name", "memory_collection") + # Accept both `path` and `vector_store_path` for compatibility + path = kwargs.get("path") or kwargs.get("vector_store_path") or "./qdrant_storage" + provider = kwargs.get("vector_store_type", "qdrant") + + # Include dimensions in 
vector store config if provided + vector_store_config_dict = { + "on_disk": on_disk, + "collection_name": collection_name, + "path": path, + } + if embedding_dimensions is not None: + vector_store_config_dict["embedding_model_dims"] = embedding_dimensions + + mem0_config.vector_store = mem0.vector_stores.configs.VectorStoreConfig( + provider=provider, config=vector_store_config_dict + ) + + # Initialize AsyncMemory + self.long_term_memory = mem0.AsyncMemory(mem0_config) + logger.info("mem0 AsyncMemory initialized") + + def _create_config_classes(self): + """Create custom configuration classes.""" + from mem0.embeddings.configs import EmbedderConfig + from mem0.llms.configs import LlmConfig + from pydantic import field_validator + + class _CustomLlmConfig(LlmConfig): + """Custom LLM configuration class.""" + + @field_validator("config") + @classmethod + def validate_config(cls, v: Any, values: Any) -> Any: + from mem0.utils.factory import LlmFactory + + provider = values.data.get("provider") + if provider in LlmFactory.provider_to_class: + return v + raise ValueError(f"Unsupported LLM provider: {provider}") + + class _CustomEmbedderConfig(EmbedderConfig): + """Custom Embedder configuration class.""" + + @field_validator("config") + @classmethod + def validate_config(cls, v: Any, values: Any) -> Any: + from mem0.utils.factory import EmbedderFactory + + provider = values.data.get("provider") + if provider in EmbedderFactory.provider_to_class: + return v + raise ValueError(f"Unsupported Embedder provider: {provider}") + + return _CustomLlmConfig, _CustomEmbedderConfig + + async def record( + self, + msgs: Union[Msg, List[Msg], None], + memory_type: str = None, + infer: bool = True, + **kwargs: Any, + ) -> None: + """Developer interface: Record messages to long-term memory. + + Args: + msgs: Messages or list of messages to record + memory_type (str): Memory type + infer (bool): Whether to infer memory content + **kwargs: Additional parameters + """ + if msgs is None: + return + + if isinstance(msgs, Msg): + msgs = [msgs] + + # Filter None + msg_list = [m for m in msgs if m is not None] + if not all(isinstance(m, Msg) for m in msg_list): + raise TypeError("Input must be Msg object or list of Msg objects") + + # Convert to mem0 format + messages = [ + { + "role": "assistant", + "content": "\n".join([str(m.content) for m in msg_list]), + "name": "assistant", + } + ] + + await self._mem0_record(messages, memory_type=memory_type, infer=infer, **kwargs) + + async def retrieve( + self, msg: Union[Msg, List[Msg], None], limit: int = 5, **kwargs: Any + ) -> str: + """Developer interface: Retrieve information from long-term memory. 
+ + Args: + msg: Messages or list of messages used for retrieval + limit (int): Maximum number of results to return + **kwargs: Additional parameters + + Returns: + str: Retrieved memory content + """ + if isinstance(msg, Msg): + msg = [msg] + + if not isinstance(msg, list) or not all(isinstance(m, Msg) for m in msg): + raise TypeError("Input must be Msg object or list of Msg objects") + + # Convert to query string + msg_strs = [json.dumps(m.to_dict()["content"], ensure_ascii=False) for m in msg] + + results = [] + for query in msg_strs: + result = await self.long_term_memory.search( + query=query, + agent_id=self.agent_id, + user_id=self.user_id, + run_id=self.run_id, + limit=limit, + ) + if result and "results" in result: + results.extend([item["memory"] for item in result["results"]]) + + return "\n".join(results) + + async def record_to_memory( + self, thinking: str, content: List[str], **kwargs: Any + ) -> ToolResponse: + """Tool interface: Record important information to long-term memory. + + This method is wrapped as a tool function for agent calls. + + Args: + thinking (str): Agent's thinking process + content (List[str]): List of content to record + **kwargs: Additional parameters + + Returns: + ToolResponse: Tool response object + """ + try: + # Merge thinking process and content + if thinking: + full_content = [thinking] + content + else: + full_content = content + + # Call mem0 record + results = await self._mem0_record( + [{"role": "assistant", "content": "\n".join(full_content), "name": "assistant"}], + **kwargs, + ) + + return ToolResponse( + content=[ + TextBlock( + text=f"Successfully recorded content to long-term memory. Result: {results}" + ) + ] + ) + + except Exception as e: + logger.error(f"Failed to record to memory: {e}") + return ToolResponse(content=[TextBlock(text=f"Failed to record memory: {str(e)}")]) + + async def retrieve_from_memory( + self, keywords: List[str], limit: int = 5, **kwargs: Any + ) -> ToolResponse: + """Tool interface: Retrieve information from long-term memory. + + This method is wrapped as a tool function for agent calls. + + Args: + keywords (List[str]): List of retrieval keywords + limit (int): Maximum number of results to return + **kwargs: Additional parameters + + Returns: + ToolResponse: Tool response object + """ + try: + results = [] + for keyword in keywords: + result = await self.long_term_memory.search( + query=keyword, + agent_id=self.agent_id, + user_id=self.user_id, + run_id=self.run_id, + limit=limit, + ) + if result and "results" in result: + results.extend([item["memory"] for item in result["results"]]) + + if results: + return ToolResponse(content=[TextBlock(text="\n".join(results))]) + else: + return ToolResponse(content=[TextBlock(text="No related memories found")]) + + except Exception as e: + logger.error(f"Failed to retrieve from memory: {e}") + return ToolResponse(content=[TextBlock(text=f"Failed to retrieve memory: {str(e)}")]) + + async def _mem0_record( + self, + messages: Union[str, List[Dict]], + memory_type: str = None, + infer: bool = True, + **kwargs: Any, + ) -> Dict: + """Internal method: Record content using mem0. 
+ + Args: + messages: Messages to record + memory_type (str): Memory type + infer (bool): Whether to infer memory + **kwargs: Additional parameters + + Returns: + Dict: mem0 return result + """ + results = await self.long_term_memory.add( + messages=messages, + agent_id=self.agent_id, + user_id=self.user_id, + run_id=self.run_id, + memory_type=(memory_type if memory_type is not None else self.default_memory_type), + infer=infer, + **kwargs, + ) + logger.debug(f"mem0 record result: {results}") + return results + + def state_dict(self) -> Dict: + """Get state dictionary. + + Returns: + Dict: State dictionary + """ + return { + "agent_id": self.agent_id, + "user_id": self.user_id, + "run_id": self.run_id, + "default_memory_type": self.default_memory_type, + } + + def load_state_dict(self, state_dict: Dict, strict: bool = True) -> None: + """Load from state dictionary. + + Args: + state_dict (Dict): State dictionary + strict (bool): Strict mode + """ + self.agent_id = state_dict.get("agent_id") + self.user_id = state_dict.get("user_id") + self.run_id = state_dict.get("run_id") + self.default_memory_type = state_dict.get("default_memory_type") + logger.info("Loaded long-term memory configuration from state dictionary") diff --git a/flagscale/agent/memory/memory_config.py b/flagscale/agent/memory/memory_config.py new file mode 100644 index 000000000..138f7704c --- /dev/null +++ b/flagscale/agent/memory/memory_config.py @@ -0,0 +1,921 @@ +"""Memory module configuration management.""" + +import logging +import os + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import yaml + +logger = logging.getLogger(__name__) + + +@dataclass +class ShortTermMemoryConfig: + """Short-term memory configuration.""" + + max_size: int = 1000 + """Maximum capacity of short-term memory""" + + auto_cleanup: bool = True + """Whether to automatically clean up old messages""" + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> "ShortTermMemoryConfig": + """Create configuration from dictionary.""" + return cls( + max_size=config_dict.get("max_size", 1000), + auto_cleanup=config_dict.get("auto_cleanup", True), + ) + + +@dataclass +class LLMConfig: + """LLM model configuration.""" + + provider: Optional[str] = None + """Provider type: openai, huggingface, custom, null""" + + model: Optional[str] = None + """Model name or path""" + + api_base: Optional[str] = None + """API base URL (for external APIs)""" + + api_key: Optional[str] = None + """API key (supports environment variables, format: ${ENV_VAR_NAME})""" + + temperature: float = 0.7 + """Temperature parameter""" + + max_tokens: int = 2000 + """Maximum number of tokens""" + + # Other custom parameters + extra_params: Dict[str, Any] = field(default_factory=dict) + """Additional custom parameters""" + + @classmethod + def from_dict(cls, config_dict: Optional[Dict[str, Any]]) -> Optional["LLMConfig"]: + """Create configuration from dictionary. + + Supports multiple configuration formats (in priority order): + 1. Direct configuration (recommended): {"provider": "openai", "model": "gpt-3.5-turbo", ...} + 2. Simplified configuration: {"provider": "openai", "model": "gpt-3.5-turbo"} # others use defaults + 3. Preset configuration: {"preset": "openai-gpt35"} # backward compatible, not recommended + 4. 
String shorthand: {"openai": "gpt-3.5-turbo"} # backward compatible + """ + if config_dict is None: + return None + + # Prioritize direct configuration (most readable) + if "provider" in config_dict or "model" in config_dict: + return cls( + provider=config_dict.get("provider"), + model=config_dict.get("model"), + api_base=config_dict.get("api_base"), + api_key=config_dict.get("api_key"), + temperature=config_dict.get("temperature", 0.7), + max_tokens=config_dict.get("max_tokens", 2000), + extra_params=config_dict.get("extra_params", {}), + ) + + # Handle preset configuration (backward compatible) + if "preset" in config_dict: + preset = config_dict["preset"] + return cls._from_preset(preset, config_dict) + + # Handle string shorthand format (backward compatible): {"openai": "gpt-3.5-turbo"} + if isinstance(config_dict, dict) and len(config_dict) == 1: + key, value = next(iter(config_dict.items())) + if isinstance(value, str) and key in ["openai", "huggingface", "custom"]: + return cls(provider=key, model=value, temperature=0.7, max_tokens=2000) + else: + logger.warning(f"Unrecognized LLM shorthand config: {config_dict}") + + logger.warning(f"Unrecognized LLM config format: {config_dict}") + return None + + @classmethod + def _from_preset(cls, preset: str, overrides: Dict[str, Any] = None) -> "LLMConfig": + """Create configuration from preset.""" + presets = { + "openai-gpt35": {"provider": "openai", "model": "gpt-3.5-turbo"}, + "openai-gpt4": {"provider": "openai", "model": "gpt-4"}, + "openai-gpt4o": {"provider": "openai", "model": "gpt-4o"}, + "huggingface-llama": { + "provider": "huggingface", + "model": "meta-llama/Llama-2-7b-chat-hf", + }, + } + + if preset not in presets: + logger.warning(f"Unknown preset configuration: {preset}, using default configuration") + preset_config = {} + else: + preset_config = presets[preset].copy() + + if overrides: + preset_config.update({k: v for k, v in overrides.items() if k != "preset"}) + + return cls.from_dict(preset_config) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + result = {} + if self.provider is not None: + result["provider"] = self.provider + if self.model is not None: + result["model"] = self.model + if self.api_base is not None: + result["api_base"] = self.api_base + if self.api_key is not None: + result["api_key"] = self.api_key + if self.temperature != 0.7: + result["temperature"] = self.temperature + if self.max_tokens != 2000: + result["max_tokens"] = self.max_tokens + if self.extra_params: + result["extra_params"] = self.extra_params + return result + + +@dataclass +class EmbeddingConfig: + """Embedding model configuration.""" + + provider: Optional[str] = None + """Provider type: openai, huggingface, custom, null""" + + model: Optional[str] = None + """Model name or path""" + + api_base: Optional[str] = None + """API base URL (for external APIs)""" + + api_key: Optional[str] = None + """API key (supports environment variables, format: ${ENV_VAR_NAME})""" + + dimensions: Optional[int] = None + """Vector dimensions""" + + # Other custom parameters + extra_params: Dict[str, Any] = field(default_factory=dict) + """Additional custom parameters""" + + @classmethod + def from_dict(cls, config_dict: Optional[Dict[str, Any]]) -> Optional["EmbeddingConfig"]: + """Create configuration from dictionary. + + Supports multiple configuration formats (in priority order): + 1. Direct configuration (recommended): {"provider": "openai", "model": "text-embedding-3-small", ...} + 2. 
Simplified configuration: {"provider": "openai", "model": "text-embedding-3-small"} # others use defaults + 3. Preset configuration: {"preset": "openai-small"} # backward compatible, not recommended + 4. String shorthand: {"openai": "text-embedding-3-small"} # backward compatible + """ + if config_dict is None: + return None + + # Prioritize direct configuration (most readable) + if "provider" in config_dict or "model" in config_dict: + return cls( + provider=config_dict.get("provider"), + model=config_dict.get("model"), + api_base=config_dict.get("api_base"), + api_key=config_dict.get("api_key"), + dimensions=config_dict.get("dimensions"), + extra_params=config_dict.get("extra_params", {}), + ) + + # Handle preset configuration (backward compatible) + if "preset" in config_dict: + preset = config_dict["preset"] + return cls._from_preset(preset, config_dict) + + # Handle string shorthand format (backward compatible): {"openai": "text-embedding-3-small"} + if isinstance(config_dict, dict) and len(config_dict) == 1: + key, value = next(iter(config_dict.items())) + if isinstance(value, str) and key in ["openai", "huggingface", "custom"]: + return cls(provider=key, model=value) + else: + logger.warning(f"Unrecognized Embedding shorthand config: {config_dict}") + + logger.warning(f"Unrecognized Embedding config format: {config_dict}") + return None + + @classmethod + def _from_preset(cls, preset: str, overrides: Dict[str, Any] = None) -> "EmbeddingConfig": + """Create configuration from preset.""" + presets = { + "openai-small": {"provider": "openai", "model": "text-embedding-3-small"}, + "openai-large": {"provider": "openai", "model": "text-embedding-3-large"}, + "huggingface-minilm": { + "provider": "huggingface", + "model": "sentence-transformers/all-MiniLM-L6-v2", + }, + } + + if preset not in presets: + logger.warning(f"Unknown preset configuration: {preset}, using default configuration") + preset_config = {} + else: + preset_config = presets[preset].copy() + + if overrides: + preset_config.update({k: v for k, v in overrides.items() if k != "preset"}) + + return cls.from_dict(preset_config) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + result = {} + if self.provider is not None: + result["provider"] = self.provider + if self.model is not None: + result["model"] = self.model + if self.api_base is not None: + result["api_base"] = self.api_base + if self.api_key is not None: + result["api_key"] = self.api_key + if self.dimensions is not None: + result["dimensions"] = self.dimensions + if self.extra_params: + result["extra_params"] = self.extra_params + return result + + +@dataclass +class LongTermMemoryConfig: + """Long-term memory configuration.""" + + provider: str = "mem0" + """Long-term memory provider""" + + agent_name: Optional[str] = None + """Agent name""" + + user_name: Optional[str] = None + """User name""" + + run_name: Optional[str] = None + """Run session name""" + + vector_store_type: str = "qdrant" + """Vector store type""" + + vector_store_path: str = "./qdrant_storage" + """Vector store path""" + + collection_name: str = "memory_collection" + """Collection name""" + + on_disk: bool = True + """Whether to use disk storage""" + + default_memory_type: Optional[str] = None + """Default memory type""" + + llm: Optional[LLMConfig] = None + """LLM model configuration""" + + embedding: Optional[EmbeddingConfig] = None + """Embedding model configuration""" + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> "LongTermMemoryConfig": + 
"""Create configuration from dictionary.""" + llm_dict = config_dict.get("llm") + embedding_dict = config_dict.get("embedding") + + return cls( + provider=config_dict.get("provider", "mem0"), + agent_name=config_dict.get("agent_name"), + user_name=config_dict.get("user_name"), + run_name=config_dict.get("run_name"), + vector_store_type=config_dict.get("vector_store_type", "qdrant"), + vector_store_path=config_dict.get("vector_store_path", "./qdrant_storage"), + collection_name=config_dict.get("collection_name", "memory_collection"), + on_disk=config_dict.get("on_disk", True), + default_memory_type=config_dict.get("default_memory_type"), + llm=LLMConfig.from_dict(llm_dict), + embedding=EmbeddingConfig.from_dict(embedding_dict), + ) + + +@dataclass +class MemoryModuleConfig: + """Memory module overall configuration.""" + + short_term: ShortTermMemoryConfig = field(default_factory=ShortTermMemoryConfig) + """Short-term memory configuration""" + + long_term: LongTermMemoryConfig = field(default_factory=LongTermMemoryConfig) + """Long-term memory configuration""" + + enable_short_term: bool = True + """Whether to enable short-term memory""" + + enable_long_term: bool = True + """Whether to enable long-term memory""" + + auto_record: bool = False + """Whether to automatically record conversations to long-term memory""" + + auto_record_threshold: int = 10 + """Threshold for number of messages to auto-record""" + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> "MemoryModuleConfig": + """Create configuration from dictionary.""" + short_term_dict = config_dict.get("short_term", {}) + long_term_dict = config_dict.get("long_term", {}) + + return cls( + short_term=ShortTermMemoryConfig.from_dict(short_term_dict), + long_term=LongTermMemoryConfig.from_dict(long_term_dict), + enable_short_term=config_dict.get("enable_short_term", True), + enable_long_term=config_dict.get("enable_long_term", True), + auto_record=config_dict.get("auto_record", False), + auto_record_threshold=config_dict.get("auto_record_threshold", 10), + ) + + @classmethod + def from_yaml(cls, yaml_path: str) -> "MemoryModuleConfig": + """Load configuration from YAML file. + + Args: + yaml_path (str): YAML configuration file path + + Returns: + MemoryModuleConfig: Configuration object + + Raises: + FileNotFoundError: If configuration file does not exist + yaml.YAMLError: If configuration file format is invalid + """ + if not os.path.exists(yaml_path): + raise FileNotFoundError(f"Configuration file does not exist: {yaml_path}") + + with open(yaml_path, "r", encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + + # Get memory section configuration + memory_config = config_dict.get("memory", config_dict) + + logger.info(f"Loaded memory configuration from {yaml_path}") + return cls.from_dict(memory_config) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format. 
+ + Returns: + Dict[str, Any]: Configuration dictionary + """ + return { + "short_term": { + "max_size": self.short_term.max_size, + "auto_cleanup": self.short_term.auto_cleanup, + }, + "long_term": { + "provider": self.long_term.provider, + "agent_name": self.long_term.agent_name, + "user_name": self.long_term.user_name, + "run_name": self.long_term.run_name, + "vector_store_type": self.long_term.vector_store_type, + "vector_store_path": self.long_term.vector_store_path, + "collection_name": self.long_term.collection_name, + "on_disk": self.long_term.on_disk, + "default_memory_type": self.long_term.default_memory_type, + "llm": self.long_term.llm.to_dict() if self.long_term.llm else None, + "embedding": ( + self.long_term.embedding.to_dict() if self.long_term.embedding else None + ), + }, + "enable_short_term": self.enable_short_term, + "enable_long_term": self.enable_long_term, + "auto_record": self.auto_record, + "auto_record_threshold": self.auto_record_threshold, + } + + def to_yaml(self, yaml_path: str) -> None: + """Save configuration to YAML file. + + Args: + yaml_path (str): YAML configuration file path + """ + config_dict = {"memory": self.to_dict()} + + with open(yaml_path, "w", encoding="utf-8") as f: + yaml.dump(config_dict, f, allow_unicode=True, default_flow_style=False) + + logger.info(f"Configuration saved to {yaml_path}") + + def validate(self) -> bool: + """Validate configuration validity. + + Returns: + bool: Whether configuration is valid + """ + # Validate that long-term memory has at least one identifier + if self.enable_long_term: + if not any( + [self.long_term.agent_name, self.long_term.user_name, self.long_term.run_name] + ): + logger.error( + "Long-term memory configuration error: must provide at least one of agent_name, user_name, or run_name" + ) + return False + + # Validate short-term memory capacity + if self.short_term.max_size <= 0: + logger.error("Short-term memory configuration error: max_size must be greater than 0") + return False + + logger.info("Configuration validation passed") + return True + + +def create_default_config( + agent_name: str = "default_agent", save_path: str = None +) -> MemoryModuleConfig: + """Create default configuration. + + Args: + agent_name (str): Agent name + save_path (str): If provided, save configuration to this path + + Returns: + MemoryModuleConfig: Default configuration object + """ + config = MemoryModuleConfig( + short_term=ShortTermMemoryConfig(max_size=1000, auto_cleanup=True), + long_term=LongTermMemoryConfig( + provider="mem0", + agent_name=agent_name, + vector_store_type="qdrant", + vector_store_path="./qdrant_storage", + collection_name="memory_collection", + on_disk=True, + ), + enable_short_term=True, + enable_long_term=True, + auto_record=False, + auto_record_threshold=10, + ) + + if save_path: + config.to_yaml(save_path) + + logger.info(f"Created default configuration (agent: {agent_name})") + return config + + +# Configuration Loading Functions + + +def load_config_from_file( + config_path: str, create_if_not_exists: bool = False, default_agent_name: str = "default_agent" +) -> MemoryModuleConfig: + """Load memory module configuration from configuration file. 
+ + Args: + config_path: Configuration file path (supports YAML format) + create_if_not_exists: Whether to create default configuration if file does not exist + default_agent_name: Agent name to use when creating default configuration + + Returns: + MemoryModuleConfig: Configuration object + + Raises: + FileNotFoundError: If configuration file does not exist and create_if_not_exists=False + ValueError: If configuration file format is invalid + """ + config_path = Path(config_path).expanduser().resolve() + + # If file does not exist + if not config_path.exists(): + if create_if_not_exists: + logger.info(f"Configuration file does not exist, creating default: {config_path}") + config = create_default_config( + agent_name=default_agent_name, save_path=str(config_path) + ) + return config + else: + raise FileNotFoundError( + f"Configuration file does not exist: {config_path}\n" + f"Hint: Use create_if_not_exists=True to automatically create default configuration" + ) + + # Load configuration + try: + config = MemoryModuleConfig.from_yaml(str(config_path)) + logger.info(f"Successfully loaded configuration file: {config_path}") + + # Validate configuration + if not config.validate(): + raise ValueError("Configuration validation failed, please check configuration items") + + return config + + except Exception as e: + logger.error(f"Failed to load configuration file: {e}") + raise ValueError(f"Configuration file format error: {e}") from e + + +def create_config_file( + config_path: str, agent_name: str = "default_agent", enable_long_term: bool = True +) -> MemoryModuleConfig: + """Create configuration file. + + Args: + config_path: Configuration file save path + agent_name: Agent name + enable_long_term: Whether to enable long-term memory + + Returns: + MemoryModuleConfig: Created configuration object + """ + config_path = Path(config_path).expanduser().resolve() + + # Ensure directory exists + config_path.parent.mkdir(parents=True, exist_ok=True) + + # Create configuration + config = MemoryModuleConfig( + short_term=ShortTermMemoryConfig(max_size=1000, auto_cleanup=True), + long_term=LongTermMemoryConfig( + agent_name=agent_name if enable_long_term else None, + vector_store_path="./qdrant_storage", + ), + enable_short_term=True, + enable_long_term=enable_long_term, + ) + + # Save configuration + config.to_yaml(str(config_path)) + logger.info(f"Configuration file created: {config_path}") + + return config + + +def get_config_from_env() -> Optional[MemoryModuleConfig]: + """Load configuration path from environment variable and load configuration. + + Environment variables: + FLAGSCALE_MEMORY_CONFIG: Configuration file path + + Returns: + MemoryModuleConfig: Configuration object, or None if environment variable is not set + """ + config_path = os.getenv("FLAGSCALE_MEMORY_CONFIG") + if config_path: + return load_config_from_file(config_path) + return None + + +def load_config( + config_path: Optional[str] = None, + agent_name: str = "default_agent", + create_if_not_exists: bool = False, +) -> MemoryModuleConfig: + """Load configuration. + + Configuration file path is determined by the following priority: + 1. If config_path is provided, use that path + 2. Read from environment variable FLAGSCALE_MEMORY_CONFIG + 3. 
If neither is provided, raise error + + Args: + config_path: Configuration file path (optional) + agent_name: Agent name (used when creating default configuration) + create_if_not_exists: Whether to create default configuration if file does not exist + + Returns: + MemoryModuleConfig: Configuration object + + Raises: + ValueError: If configuration file path is not provided and environment variable is not set + """ + # Use provided path if available + if config_path: + return load_config_from_file( + config_path, create_if_not_exists=create_if_not_exists, default_agent_name=agent_name + ) + + # Otherwise, try reading from environment variable + env_config = get_config_from_env() + if env_config: + return env_config + + # Neither provided, raise error + raise ValueError( + "Configuration file path not provided. Please use one of the following:\n" + "1. Specify configuration file path: load_config(config_path='memory_config.yaml')\n" + "2. Set environment variable: export FLAGSCALE_MEMORY_CONFIG=/path/to/config.yaml" + ) + + +# Model Factory + + +def _resolve_env_var(value: str) -> str: + """Resolve environment variable reference. + + Args: + value: String that may contain environment variable reference, format: ${ENV_VAR_NAME} + + Returns: + Resolved value + """ + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + env_var = value[2:-1] + resolved = os.getenv(env_var, "") + if not resolved: + logger.warning(f"Environment variable {env_var} is not set") + return resolved + return value + + +class ModelFactory: + """Model factory for creating model instances from configuration.""" + + @staticmethod + def create_llm_model(config: Optional[LLMConfig]) -> Optional[Any]: + """Create LLM model from configuration. + + Args: + config: LLM configuration object + + Returns: + LLM model instance, or None if config is None + + Raises: + ValueError: If configuration is invalid or provider is not supported + """ + if config is None: + return None + + if config.provider is None: + logger.warning("LLM provider not specified, skipping creation") + return None + + provider = config.provider.lower() + + if provider == "openai": + return ModelFactory._create_openai_llm(config) + elif provider == "huggingface": + return ModelFactory._create_huggingface_llm(config) + elif provider == "custom": + return ModelFactory._create_custom_llm(config) + else: + raise ValueError(f"Unsupported LLM provider: {provider}") + + @staticmethod + def create_embedding_model(config: Optional[EmbeddingConfig]) -> Optional[Any]: + """Create Embedding model from configuration. 
+ + Args: + config: Embedding configuration object + + Returns: + Embedding model instance, or None if config is None + + Raises: + ValueError: If configuration is invalid or provider is not supported + """ + if config is None: + return None + + if config.provider is None: + logger.warning("Embedding provider not specified, skipping creation") + return None + + provider = config.provider.lower() + + if provider == "openai": + return ModelFactory._create_openai_embedding(config) + elif provider == "huggingface": + return ModelFactory._create_huggingface_embedding(config) + elif provider == "custom": + return ModelFactory._create_custom_embedding(config) + else: + raise ValueError(f"Unsupported Embedding provider: {provider}") + + @staticmethod + def _create_openai_llm(config: LLMConfig) -> Any: + """Create OpenAI-compatible LLM model.""" + try: + from openai import AsyncOpenAI + + # If api_key is not specified in config, try to read from environment variables + if config.api_key: + api_key = _resolve_env_var(config.api_key) + else: + # Try to read from common environment variables + api_key = ( + os.getenv("OPENAI_API_KEY") + or os.getenv("OPENAI_KEY") + or os.getenv("ZHIPU_API_KEY") + ) + + api_base = config.api_base + + client = AsyncOpenAI( + api_key=api_key if api_key else "EMPTY", base_url=api_base if api_base else None + ) + + model_name = config.model or "gpt-3.5-turbo" + temperature = config.temperature + max_tokens = config.max_tokens + + async def llm_model(messages: List[Dict[str, str]], tools: List[Dict] = None) -> str: + """OpenAI LLM model call interface.""" + params = { + "model": model_name, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + } + + if tools: + params["tools"] = tools + + # Merge extra parameters + params.update(config.extra_params) + + response = await client.chat.completions.create(**params) + return response.choices[0].message.content + + logger.info(f"Created OpenAI LLM: {model_name}") + return llm_model + + except ImportError: + logger.error("Failed to import OpenAI library, please install: pip install openai") + raise + + @staticmethod + def _create_openai_embedding(config: EmbeddingConfig) -> Any: + """Create OpenAI-compatible Embedding model.""" + try: + from openai import AsyncOpenAI + + # If api_key is not specified in config, try to read from environment variables + if config.api_key: + api_key = _resolve_env_var(config.api_key) + else: + # Try to read from common environment variables + api_key = ( + os.getenv("OPENAI_API_KEY") + or os.getenv("OPENAI_KEY") + or os.getenv("ZHIPU_API_KEY") + ) + + api_base = config.api_base + + client = AsyncOpenAI( + api_key=api_key if api_key else "EMPTY", base_url=api_base if api_base else None + ) + + model_name = config.model or "text-embedding-3-small" + dimensions = config.dimensions + + async def embedding_model(texts: List[str]) -> Union[List[float], List[List[float]]]: + """OpenAI Embedding model call interface.""" + is_single = len(texts) == 1 + + params = {"model": model_name, "input": texts} + + if dimensions: + params["dimensions"] = dimensions + + # Merge extra parameters + params.update(config.extra_params) + + response = await client.embeddings.create(**params) + embeddings = [item.embedding for item in response.data] + + return embeddings[0] if is_single else embeddings + + logger.info(f"Created OpenAI Embedding: {model_name}") + return embedding_model + + except ImportError: + logger.error("Failed to import OpenAI library, please install: pip install openai") + 
raise + + @staticmethod + def _create_huggingface_llm(config: LLMConfig) -> Any: + """Create HuggingFace LLM model.""" + try: + import torch + + from transformers import AutoModelForCausalLM, AutoTokenizer + + if not config.model: + raise ValueError("HuggingFace LLM requires model parameter to be specified") + + model_name = config.model + device = "cuda" if torch.cuda.is_available() else "cpu" + + logger.info(f"Loading HuggingFace LLM model: {model_name} (device: {device})") + + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float16 if device == "cuda" else torch.float32, + device_map="auto" if device == "cuda" else None, + ) + + if device == "cpu": + model = model.to(device) + + max_tokens = config.max_tokens + temperature = config.temperature + + def llm_model(messages: List[Dict[str, str]], tools: List[Dict] = None) -> str: + """HuggingFace LLM model call interface (synchronous).""" + # Convert messages to text + text = "" + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + text += f"{role}: {content}\n" + text += "assistant:" + + inputs = tokenizer(text, return_tensors="pt").to(device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=max_tokens, + temperature=temperature, + do_sample=temperature > 0, + pad_token_id=tokenizer.eos_token_id, + ) + + response = tokenizer.decode( + outputs[0][inputs.input_ids.shape[1] :], skip_special_tokens=True + ) + return response.strip() + + logger.info(f"Created HuggingFace LLM: {model_name}") + return llm_model + + except ImportError: + logger.error( + "Failed to import transformers library, please install: pip install transformers torch" + ) + raise + + @staticmethod + def _create_huggingface_embedding(config: EmbeddingConfig) -> Any: + """Create HuggingFace Embedding model.""" + try: + import torch + + from sentence_transformers import SentenceTransformer + + if not config.model: + raise ValueError("HuggingFace Embedding requires model parameter to be specified") + + model_name = config.model + device = "cuda" if torch.cuda.is_available() else "cpu" + + logger.info(f"Loading HuggingFace Embedding model: {model_name} (device: {device})") + + model = SentenceTransformer(model_name, device=device) + + def embedding_model(texts: List[str]) -> Union[List[float], List[List[float]]]: + """HuggingFace Embedding model call interface (synchronous).""" + is_single = len(texts) == 1 + embeddings = model.encode(texts, convert_to_numpy=True).tolist() + return embeddings[0] if is_single else embeddings + + logger.info(f"Created HuggingFace Embedding: {model_name}") + return embedding_model + + except ImportError: + logger.error( + "Failed to import sentence-transformers library, please install: pip install sentence-transformers" + ) + raise + + @staticmethod + def _create_custom_llm(config: LLMConfig) -> Any: + """Create custom LLM model (requires user to provide model instance in code).""" + logger.warning( + "custom provider requires user to provide model instance in code, " + "cannot be automatically created from config. Please use MemoryManager(..., llm_model=your_model)" + ) + return None + + @staticmethod + def _create_custom_embedding(config: EmbeddingConfig) -> Any: + """Create custom Embedding model (requires user to provide model instance in code).""" + logger.warning( + "custom provider requires user to provide model instance in code, " + "cannot be automatically created from config. 
Please use MemoryManager(..., embedding_model=your_model)" + ) + return None diff --git a/flagscale/agent/memory/memory_config.yaml b/flagscale/agent/memory/memory_config.yaml new file mode 100644 index 000000000..1be0af89b --- /dev/null +++ b/flagscale/agent/memory/memory_config.yaml @@ -0,0 +1,40 @@ +# Memory Module Configuration + +memory: + short_term: + max_size: 1000 + auto_cleanup: true + + long_term: + provider: "mem0" + agent_name: "my_agent" + user_name: null + run_name: null + + vector_store_type: "qdrant" + vector_store_path: "./qdrant_storage" + collection_name: "memory_collection" + on_disk: true + + default_memory_type: null + + llm: + provider: "openai" + model: "gpt-3.5-turbo" + api_base: null + api_key: "${OPENAI_API_KEY}" + temperature: 0.7 + max_tokens: 2000 + + embedding: + provider: "openai" + model: "text-embedding-3-small" + api_base: null + api_key: "${OPENAI_API_KEY}" + dimensions: 1536 + + enable_short_term: true + enable_long_term: false + auto_record: false + auto_record_threshold: 10 + diff --git a/flagscale/agent/memory/memory_manager.py b/flagscale/agent/memory/memory_manager.py new file mode 100644 index 000000000..261b9ad2e --- /dev/null +++ b/flagscale/agent/memory/memory_manager.py @@ -0,0 +1,233 @@ +"""Memory manager for unified management of short-term and long-term memory.""" + +import logging + +from typing import Any, Optional + +from .long_term_memory import Mem0LongTermMemory +from .memory_config import ( + LongTermMemoryConfig, + MemoryModuleConfig, + ModelFactory, + ShortTermMemoryConfig, +) +from .short_term_memory import InMemoryMemory + +logger = logging.getLogger(__name__) + + +class MemoryManager: + """Memory manager for unified management of short-term and long-term memory. + + This class provides a simplified interface to manage Agent's memory system, + adapted for FlagScale usage scenarios. + """ + + def __init__( + self, + config: Optional[MemoryModuleConfig] = None, + short_term_memory: Optional[InMemoryMemory] = None, + long_term_memory: Optional[Mem0LongTermMemory] = None, + **kwargs: Any, + ): + """Initialize memory manager. 
+ + Args: + config: Memory module configuration (optional) + short_term_memory: Short-term memory instance (optional, used directly if provided) + long_term_memory: Long-term memory instance (optional, used directly if provided) + **kwargs: Additional configuration parameters for creating memory instances + """ + self.config = config + # Internal status for diagnostics + self._status = { + "short_term": {"enabled": False, "reason": None}, + "long_term": { + "enabled": False, + "reason": None, + "errors": [], + "llm_created": False, + "embedding_created": False, + }, + } + + # Initialize short-term memory + if short_term_memory is not None: + self.short_term_memory = short_term_memory + self._status["short_term"]["enabled"] = True + elif config and config.enable_short_term: + self.short_term_memory = InMemoryMemory(max_size=config.short_term.max_size) + self._status["short_term"]["enabled"] = True + elif kwargs.get("enable_short_term", True): + max_size = kwargs.get("max_size", 1000) + self.short_term_memory = InMemoryMemory(max_size=max_size) + self._status["short_term"]["enabled"] = True + else: + self.short_term_memory = None + self._status["short_term"]["reason"] = "disabled_by_config" + + # Initialize long-term memory + if long_term_memory is not None: + self.long_term_memory = long_term_memory + self._status["long_term"]["enabled"] = True + elif config and config.enable_long_term: + # Try to get models from config or kwargs + llm_model = kwargs.get("llm_model") + embedding_model = kwargs.get("embedding_model") + + # If models not provided, try to create from config + if llm_model is None and config.long_term.llm: + try: + llm_model = ModelFactory.create_llm_model(config.long_term.llm) + if llm_model: + logger.info("Successfully created LLM model from config") + self._status["long_term"]["llm_created"] = True + except Exception as e: + logger.warning(f"Failed to create LLM model from config: {e}") + self._status["long_term"]["errors"].append(str(e)) + self._status["long_term"]["reason"] = "llm_model_creation_failed" + + if embedding_model is None and config.long_term.embedding: + try: + embedding_model = ModelFactory.create_embedding_model( + config.long_term.embedding + ) + if embedding_model: + logger.info("Successfully created Embedding model from config") + self._status["long_term"]["embedding_created"] = True + except Exception as e: + logger.warning(f"Failed to create Embedding model from config: {e}") + self._status["long_term"]["errors"].append(str(e)) + # keep first reason if already set by llm failure + if not self._status["long_term"]["reason"]: + self._status["long_term"]["reason"] = "embedding_model_creation_failed" + + if llm_model is None or embedding_model is None: + logger.warning( + "Long-term memory requires llm_model and embedding_model. 
" + "Not provided and cannot be created from config, long-term memory will be unavailable" + ) + self.long_term_memory = None + missing = [] + if llm_model is None: + missing.append("llm_model") + if embedding_model is None: + missing.append("embedding_model") + if not self._status["long_term"]["reason"]: + self._status["long_term"]["reason"] = f"missing_models: {', '.join(missing)}" + else: + # Get embedding dimensions from config if available + embedding_dimensions = None + if config.long_term.embedding: + embedding_dimensions = config.long_term.embedding.dimensions + + self.long_term_memory = Mem0LongTermMemory( + agent_name=config.long_term.agent_name, + user_name=config.long_term.user_name, + run_name=config.long_term.run_name, + llm_model=llm_model, + embedding_model=embedding_model, + embedding_dimensions=embedding_dimensions, + path=config.long_term.vector_store_path, + vector_store_type=config.long_term.vector_store_type, + collection_name=config.long_term.collection_name, + on_disk=config.long_term.on_disk, + default_memory_type=config.long_term.default_memory_type, + ) + self._status["long_term"]["enabled"] = True + elif kwargs.get("enable_long_term", False): + llm_model = kwargs.get("llm_model") + embedding_model = kwargs.get("embedding_model") + + if llm_model is None or embedding_model is None: + logger.warning("Long-term memory requires llm_model and embedding_model") + self.long_term_memory = None + missing = [] + if llm_model is None: + missing.append("llm_model") + if embedding_model is None: + missing.append("embedding_model") + self._status["long_term"]["reason"] = f"missing_models: {', '.join(missing)}" + else: + self.long_term_memory = Mem0LongTermMemory( + agent_name=kwargs.get("agent_name"), + user_name=kwargs.get("user_name"), + run_name=kwargs.get("run_name"), + llm_model=llm_model, + embedding_model=embedding_model, + embedding_dimensions=kwargs.get("embedding_dimensions"), + path=kwargs.get("vector_store_path", "./qdrant_storage"), + vector_store_type=kwargs.get("vector_store_type", "qdrant"), + collection_name=kwargs.get("collection_name", "memory_collection"), + on_disk=kwargs.get("on_disk", True), + default_memory_type=kwargs.get("default_memory_type"), + ) + self._status["long_term"]["enabled"] = True + else: + self.long_term_memory = None + self._status["long_term"]["reason"] = "disabled_by_config" + + logger.info( + f"MemoryManager initialized - " + f"Short-term memory: {'enabled' if self.short_term_memory else 'disabled'}, " + f"Long-term memory: {'enabled' if self.long_term_memory else 'disabled'}" + ) + + @classmethod + def from_config( + cls, + config: MemoryModuleConfig, + llm_model: Optional[Any] = None, + embedding_model: Optional[Any] = None, + auto_create_models: bool = True, + ) -> "MemoryManager": + """Create memory manager from configuration. 
+ + Args: + config: Memory module configuration + llm_model: LLM model (optional, will try to create from config if not provided and auto_create_models=True) + embedding_model: Embedding model (optional, will try to create from config if not provided and auto_create_models=True) + auto_create_models: Whether to automatically create models from config (default True) + + Returns: + MemoryManager: Memory manager instance + """ + # If auto-create enabled and models not provided, try to create from config + if auto_create_models: + if llm_model is None and config.long_term.llm: + try: + llm_model = ModelFactory.create_llm_model(config.long_term.llm) + except Exception as e: + logger.warning(f"Failed to create LLM model from config: {e}") + + if embedding_model is None and config.long_term.embedding: + try: + embedding_model = ModelFactory.create_embedding_model( + config.long_term.embedding + ) + except Exception as e: + logger.warning(f"Failed to create Embedding model from config: {e}") + + return cls(config=config, llm_model=llm_model, embedding_model=embedding_model) + + def get_short_term_memory(self) -> Optional[InMemoryMemory]: + """Get short-term memory instance.""" + return self.short_term_memory + + def get_long_term_memory(self) -> Optional[Mem0LongTermMemory]: + """Get long-term memory instance.""" + return self.long_term_memory + + def has_short_term(self) -> bool: + """Check if short-term memory is enabled.""" + return self.short_term_memory is not None + + def has_long_term(self) -> bool: + """Check if long-term memory is enabled.""" + return self.long_term_memory is not None + + def get_status(self) -> dict: + """Return diagnostics status for memory manager. + Includes enable flags and reasons/errors when long-term memory is unavailable. + """ + return self._status diff --git a/flagscale/agent/memory/memory_tools.py b/flagscale/agent/memory/memory_tools.py new file mode 100644 index 000000000..a9945e789 --- /dev/null +++ b/flagscale/agent/memory/memory_tools.py @@ -0,0 +1,478 @@ +"""Memory tool functions for agent calls.""" + +import logging + +from typing import Any, Dict, List, Optional + +from .base import TextBlock, ToolResponse +from .long_term_memory import Mem0LongTermMemory +from .short_term_memory import InMemoryMemory + +logger = logging.getLogger(__name__) + + +class MemoryToolkit: + """Memory toolkit that wraps memory operations as tool functions. + + This class provides a set of tool functions that can be registered to agent's toolkit, + allowing agents to autonomously decide when to use the memory system. + """ + + def __init__( + self, short_term_memory: InMemoryMemory = None, long_term_memory: Mem0LongTermMemory = None + ): + """Initialize memory toolkit. + + Args: + short_term_memory: Short-term memory instance + long_term_memory: Long-term memory instance + """ + self.short_term_memory = short_term_memory + self.long_term_memory = long_term_memory + logger.info("MemoryToolkit initialized") + + async def record_to_long_term_memory( + self, thinking: str, content: List[str], **kwargs: Any + ) -> Dict[str, Any]: + """Tool function: Record important information to long-term memory. + + Use this function to record important information that needs to be saved long-term, + which can be retrieved and used in the future. 
+ Use cases: + - User's important preferences and habits + - Key facts and knowledge + - Tasks and goals that need to be remembered long-term + + Args: + thinking (str): Your thinking and reasoning about the content to record, + explaining why this information is important + content (List[str]): List of content to record, + each item should be specific and clear information + **kwargs: Additional metadata, such as importance, category, etc. + + Returns: + Dict[str, Any]: Dictionary containing operation results + - success (bool): Whether the operation was successful + - message (str): Result message + - content (str): Summary of recorded content + + Example: + >>> result = await record_to_long_term_memory( + ... thinking="User explicitly stated they like coffee in the morning", + ... content=["User drinks black coffee at 7 AM every morning", "No sugar or milk"], + ... importance="high", + ... category="preferences" + ... ) + """ + if self.long_term_memory is None: + return {"success": False, "message": "Long-term memory not initialized", "content": ""} + + try: + # Call long-term memory tool interface + tool_response = await self.long_term_memory.record_to_memory( + thinking=thinking, content=content, **kwargs + ) + + # Extract response text + result_text = tool_response.get_text() + + return { + "success": True, + "message": "Successfully recorded to long-term memory", + "content": result_text, + } + + except Exception as e: + logger.error(f"Failed to record to long-term memory: {e}") + return {"success": False, "message": f"Recording failed: {str(e)}", "content": ""} + + async def retrieve_from_long_term_memory( + self, keywords: List[str], limit: int = 5, **kwargs: Any + ) -> Dict[str, Any]: + """Tool function: Retrieve information from long-term memory. + + Use this function to retrieve related information from long-term memory. + The system will use semantic search to find memories most relevant to the keywords. + + Args: + keywords (List[str]): List of retrieval keywords, + should be specific and clear words, such as names, places, events, etc. + limit (int): Maximum number of results to return, default 5 + **kwargs: Additional retrieval parameters + + Returns: + Dict[str, Any]: Dictionary containing retrieval results + - success (bool): Whether the operation was successful + - message (str): Result message + - results (List[str]): List of retrieved memories + - count (int): Number of results + + Example: + >>> result = await retrieve_from_long_term_memory( + ... keywords=["coffee", "breakfast habits"], + ... limit=3 + ... 
) + """ + if self.long_term_memory is None: + return { + "success": False, + "message": "Long-term memory not initialized", + "results": [], + "count": 0, + } + + try: + # Call long-term memory tool interface + tool_response = await self.long_term_memory.retrieve_from_memory( + keywords=keywords, limit=limit, **kwargs + ) + + # Extract response text + result_text = tool_response.get_text() + + # Split results + results = [r.strip() for r in result_text.split("\n") if r.strip()] + + return { + "success": True, + "message": f"Found {len(results)} related memories", + "results": results, + "count": len(results), + } + + except Exception as e: + logger.error(f"Failed to retrieve from long-term memory: {e}") + return { + "success": False, + "message": f"Retrieval failed: {str(e)}", + "results": [], + "count": 0, + } + + async def search_short_term_memory( + self, query: str, limit: int = 10, **kwargs: Any + ) -> Dict[str, Any]: + """Tool function: Search short-term memory. + + Search for information in the current conversation's short-term memory. + Suitable for finding recent conversation content. + + Args: + query (str): Search query + limit (int): Maximum number of results to return, default 10 + **kwargs: Additional search parameters + + Returns: + Dict[str, Any]: Dictionary containing search results + - success (bool): Whether the operation was successful + - message (str): Result message + - results (List[Dict]): List of found messages + - count (int): Number of results + """ + if self.short_term_memory is None: + return { + "success": False, + "message": "Short-term memory not initialized", + "results": [], + "count": 0, + } + + try: + # Search short-term memory + messages = await self.short_term_memory.search(query=query, limit=limit) + + # Convert to dictionary format + results = [ + { + "role": msg.role, + "content": str(msg.content), + "timestamp": msg.timestamp.isoformat(), + } + for msg in messages + ] + + return { + "success": True, + "message": f"Found {len(results)} matching messages", + "results": results, + "count": len(results), + } + + except Exception as e: + logger.error(f"Failed to search short-term memory: {e}") + return { + "success": False, + "message": f"Search failed: {str(e)}", + "results": [], + "count": 0, + } + + async def get_recent_conversation(self, limit: int = 10, **kwargs: Any) -> Dict[str, Any]: + """Tool function: Get recent conversation records. + + Get recent conversation history for reviewing previous exchanges. 
+ + Args: + limit (int): Number of messages to return, default 10 + **kwargs: Additional parameters + + Returns: + Dict[str, Any]: Dictionary containing conversation records + - success (bool): Whether the operation was successful + - message (str): Result message + - conversation (List[Dict]): List of conversation messages + - count (int): Number of messages + """ + if self.short_term_memory is None: + return { + "success": False, + "message": "Short-term memory not initialized", + "conversation": [], + "count": 0, + } + + try: + # Get recent messages + messages = await self.short_term_memory.get_recent_messages(limit=limit) + + # Convert to conversation format + conversation = [ + { + "role": msg.role, + "content": str(msg.content), + "timestamp": msg.timestamp.isoformat(), + } + for msg in messages + ] + + return { + "success": True, + "message": f"Retrieved {len(conversation)} recent conversations", + "conversation": conversation, + "count": len(conversation), + } + + except Exception as e: + logger.error(f"Failed to get recent conversation: {e}") + return { + "success": False, + "message": f"Retrieval failed: {str(e)}", + "conversation": [], + "count": 0, + } + + +def create_memory_tool_functions(toolkit: MemoryToolkit) -> Dict[str, callable]: + """Create memory tool function dictionary. + + This function returns a dictionary containing all tool functions that can be registered to agent toolkit. + + Args: + toolkit (MemoryToolkit): Memory toolkit instance + + Returns: + Dict[str, callable]: Tool function dictionary, keys are function names, values are function objects + + Example: + >>> toolkit = MemoryToolkit(short_term, long_term) + >>> tools = create_memory_tool_functions(toolkit) + >>> # Register to agent + >>> for name, func in tools.items(): + >>> agent.toolkit.register_tool_function(func, name=name) + """ + return { + "record_to_long_term_memory": toolkit.record_to_long_term_memory, + "retrieve_from_long_term_memory": toolkit.retrieve_from_long_term_memory, + "search_short_term_memory": toolkit.search_short_term_memory, + "get_recent_conversation": toolkit.get_recent_conversation, + } + + +# Tool Registration Functions + + +def register_memory_tools( + tool_registry: Any, memory_manager: Any, category: str = "memory" +) -> List[str]: + """Register memory functionality to ToolRegistry. + + Args: + tool_registry: ToolRegistry instance + memory_manager: MemoryManager instance + category: Tool category, default "memory" + + Returns: + List[str]: List of registered tool names + + Example: + >>> from flagscale.agent.tool_match import ToolRegistry + >>> from flagscale.agent.memory import MemoryManager + >>> + >>> registry = ToolRegistry() + >>> memory = MemoryManager(...) + >>> + >>> registered_tools = register_memory_tools(registry, memory) + >>> print(f"Registered {len(registered_tools)} memory tools") + """ + # Import here to avoid circular import + from .memory_manager import MemoryManager + + registered_tools = [] + + # Create MemoryToolkit + toolkit = MemoryToolkit( + short_term_memory=memory_manager.get_short_term_memory(), + long_term_memory=memory_manager.get_long_term_memory(), + ) + + # Define tool list + tools = [] + + # Record to long-term memory + if memory_manager.has_long_term(): + tools.append( + { + "function": { + "name": "record_to_long_term_memory", + "description": ( + "Record important information to long-term memory. Use this tool to record " + "information that needs to be saved long-term, such as user preferences, " + "key facts, important tasks, etc." 
+ ), + "parameters": { + "type": "object", + "properties": { + "thinking": { + "type": "string", + "description": "Thinking and reasoning about the content to record, explaining why this information is important", + }, + "content": { + "type": "array", + "items": {"type": "string"}, + "description": "List of content to record, each item should be specific and clear information", + }, + }, + "required": ["content"], + }, + }, + "func": toolkit.record_to_long_term_memory, + } + ) + + # Retrieve from long-term memory + if memory_manager.has_long_term(): + tools.append( + { + "function": { + "name": "retrieve_from_long_term_memory", + "description": ( + "Retrieve related information from long-term memory. Uses semantic search " + "to find memories most relevant to the keywords." + ), + "parameters": { + "type": "object", + "properties": { + "keywords": { + "type": "array", + "items": {"type": "string"}, + "description": "List of retrieval keywords, should be specific and clear words", + }, + "limit": { + "type": "integer", + "description": "Maximum number of results to return", + "default": 5, + }, + }, + "required": ["keywords"], + }, + }, + "func": toolkit.retrieve_from_long_term_memory, + } + ) + + # Search short-term memory + if memory_manager.has_short_term(): + tools.append( + { + "function": { + "name": "search_short_term_memory", + "description": ( + "Search for information in the current conversation's short-term memory. " + "Suitable for finding recent conversation content." + ), + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + "limit": { + "type": "integer", + "description": "Maximum number of results to return", + "default": 10, + }, + }, + "required": ["query"], + }, + }, + "func": toolkit.search_short_term_memory, + } + ) + + # Get recent conversation + if memory_manager.has_short_term(): + tools.append( + { + "function": { + "name": "get_recent_conversation", + "description": ( + "Get recent conversation records. Used to review previous exchange content." + ), + "parameters": { + "type": "object", + "properties": { + "limit": { + "type": "integer", + "description": "Number of messages to return", + "default": 10, + } + }, + }, + }, + "func": toolkit.get_recent_conversation, + } + ) + + # Register tools to ToolRegistry + for tool in tools: + try: + # Ensure tool dictionary contains function reference + tool_dict = { + "function": tool["function"], + "func": tool["func"], # Store actual function object + } + tool_registry.register_tool(tool_dict, category=category) + registered_tools.append(tool["function"]["name"]) + logger.debug(f"Registered memory tool: {tool['function']['name']}") + except Exception as e: + logger.error(f"Failed to register tool {tool['function']['name']}: {e}") + + logger.info( + f"Registered {len(registered_tools)} memory tools to ToolRegistry " + f"(category: {category})" + ) + + return registered_tools + + +def create_memory_toolkit(memory_manager: Any) -> MemoryToolkit: + """Create MemoryToolkit instance. 
+ + Args: + memory_manager: MemoryManager instance + + Returns: + MemoryToolkit: Memory toolkit instance + """ + return MemoryToolkit( + short_term_memory=memory_manager.get_short_term_memory(), + long_term_memory=memory_manager.get_long_term_memory(), + ) diff --git a/flagscale/agent/memory/short_term_memory.py b/flagscale/agent/memory/short_term_memory.py new file mode 100644 index 000000000..d6638ad60 --- /dev/null +++ b/flagscale/agent/memory/short_term_memory.py @@ -0,0 +1,240 @@ +"""Short-term memory implementation with fast in-memory access.""" + +import logging + +from typing import Any, Iterable, List, Union + +from .base import MemoryBase, Msg + +logger = logging.getLogger(__name__) + + +class InMemoryMemory(MemoryBase): + """In-memory short-term memory implementation. + + This class references AgentScope's InMemoryMemory implementation, providing: + - Fast message storage and access + - Index-based deletion operations + - Duplicate control + - State persistence support + """ + + def __init__(self, max_size: int = 1000) -> None: + """Initialize short-term memory. + + Args: + max_size (int): Maximum number of messages to store, default 1000 + """ + super().__init__() + self.content: List[Msg] = [] + self.max_size = max_size + logger.info(f"InMemoryMemory initialized, max capacity: {max_size}") + + def state_dict(self) -> dict: + """Convert current memory to dictionary format. + + Returns: + dict: Dictionary containing memory content and configuration + """ + return {"content": [msg.to_dict() for msg in self.content], "max_size": self.max_size} + + def load_state_dict(self, state_dict: dict, strict: bool = True) -> None: + """Load memory state from dictionary. + + Args: + state_dict (dict): State dictionary + strict (bool): If True, raise error when keys are missing + """ + self.content = [] + for data in state_dict.get("content", []): + # Remove possible type field (for compatibility) + data.pop("type", None) + self.content.append(Msg.from_dict(data)) + + self.max_size = state_dict.get("max_size", 1000) + logger.info(f"Loaded {len(self.content)} messages from state dictionary") + + async def size(self) -> int: + """Get memory size. + + Returns: + int: Current number of stored messages + """ + return len(self.content) + + async def retrieve(self, query: str = "", limit: int = 10) -> List[Msg]: + """Retrieve messages from memory. + + Args: + query (str): Retrieval query (simple text matching) + limit (int): Maximum number of results to return + + Returns: + List[Msg]: List of matching messages + """ + if not query: + # If no query, return recent messages + return self.content[-limit:] if limit > 0 else self.content + + # Simple text matching retrieval + results = [] + query_lower = query.lower() + + for msg in self.content: + content_str = msg.content + if isinstance(msg.content, list): + # Handle structured content + content_str = " ".join( + str(block.get("text", "")) if isinstance(block, dict) else str(block) + for block in msg.content + ) + + if query_lower in str(content_str).lower(): + results.append(msg) + + # Return recent matching results + return results[-limit:] if limit > 0 else results + + async def delete(self, index: Union[Iterable, int]) -> None: + """Delete messages at specified indices. 
+ + Args: + index: Index or list of indices to delete + + Raises: + IndexError: If index does not exist + """ + if isinstance(index, int): + index = [index] + + # Check for invalid indices + invalid_index = [i for i in index if i < 0 or i >= len(self.content)] + + if invalid_index: + raise IndexError(f"Index {invalid_index} does not exist.") + + # Delete messages at specified indices + self.content = [msg for idx, msg in enumerate(self.content) if idx not in index] + logger.debug(f"Deleted {len(list(index))} messages") + + async def add( + self, memories: Union[List[Msg], Msg, None], allow_duplicates: bool = False + ) -> None: + """Add messages to memory. + + Args: + memories: Messages or list of messages to add + allow_duplicates (bool): Whether to allow adding duplicate messages (same id) + + Raises: + TypeError: If message type is incorrect + """ + if memories is None: + return + + if isinstance(memories, Msg): + memories = [memories] + + if not isinstance(memories, list): + raise TypeError( + f"memories should be a list of Msg or a single Msg, " f"but got {type(memories)}." + ) + + for msg in memories: + if not isinstance(msg, Msg): + raise TypeError( + f"memories should be a list of Msg or a single Msg, " f"but got {type(msg)}." + ) + + # Deduplication + if not allow_duplicates: + existing_ids = {msg.id for msg in self.content} + memories = [msg for msg in memories if msg.id not in existing_ids] + + self.content.extend(memories) + + # If exceeds max capacity, remove oldest messages + while len(self.content) > self.max_size: + removed = self.content.pop(0) + logger.debug(f"Memory full, removing oldest message: {removed.id}") + + logger.debug(f"Added {len(memories)} messages to short-term memory") + + async def get_memory(self, recent: int = 0) -> List[Msg]: + """Get memory content. + + Args: + recent (int): If greater than 0, only return the most recent N messages + + Returns: + List[Msg]: List of messages + """ + if recent > 0: + return self.content[-recent:] + return self.content.copy() + + async def clear(self) -> None: + """Clear memory content.""" + count = len(self.content) + self.content = [] + logger.info(f"Cleared short-term memory, deleted {count} messages") + + async def get_recent_messages(self, limit: int = 20) -> List[Msg]: + """Get recent messages. + + Args: + limit (int): Number of messages to return + + Returns: + List[Msg]: List of recent messages + """ + return self.content[-limit:] if limit > 0 else self.content + + async def search(self, query: str, limit: int = 10) -> List[Msg]: + """Search messages (similar to retrieve, for compatibility). + + Args: + query (str): Search query + limit (int): Maximum number of results to return + + Returns: + List[Msg]: List of matching messages + """ + return await self.retrieve(query, limit) + + async def update(self, message_id: str, content: str = None, metadata: dict = None) -> bool: + """Update message with specified ID. + + Args: + message_id (str): Message ID + content (str): New content (optional) + metadata (dict): New metadata (optional) + + Returns: + bool: Whether update was successful + """ + for msg in self.content: + if msg.id == message_id: + if content is not None: + msg.content = content + if metadata is not None: + msg.metadata = metadata + logger.debug(f"Updated message: {message_id}") + return True + + logger.warning(f"Message not found: {message_id}") + return False + + async def get_by_id(self, message_id: str) -> Union[Msg, None]: + """Get message by ID. 
+ + Args: + message_id (str): Message ID + + Returns: + Union[Msg, None]: Message object or None + """ + for msg in self.content: + if msg.id == message_id: + return msg + return None
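The configuration helpers in memory_config.py, the MemoryManager diagnostics, and InMemoryMemory fit together as follows: create_default_config builds and optionally persists a MemoryModuleConfig, MemoryManager wires up short-term and long-term memory from it, and get_status() reports why long-term memory was or was not enabled. Below is a minimal usage sketch, not a definitive example: it assumes the module layout introduced in this diff (create_default_config living in memory_config.py and MemoryManager exported from flagscale.agent.memory), and the agent name and output path are illustrative.

```python
# Usage sketch only: "demo_agent" and the output path are illustrative, and the
# import paths assume the module layout introduced in this diff.
import asyncio

from flagscale.agent.memory import MemoryManager
from flagscale.agent.memory.memory_config import create_default_config


async def main():
    # Build a default configuration, validate it, and persist it as YAML.
    config = create_default_config(agent_name="demo_agent", save_path="demo_memory_config.yaml")
    assert config.validate()

    # No llm_model/embedding_model are passed here, so long-term memory cannot
    # be initialized; the diagnostics in get_status() record the reason.
    memory = MemoryManager(config=config)
    print(memory.has_short_term())            # True: InMemoryMemory(max_size=1000)
    print(memory.has_long_term())             # typically False when no models are available
    print(memory.get_status()["long_term"])   # enabled flag, reason, collected errors

    # Short-term memory is still usable for conversation history.
    print(await memory.short_term_memory.size())


asyncio.run(main())
```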