diff --git a/erniebot-agent/erniebot_agent/memory/__init__.py b/erniebot-agent/erniebot_agent/memory/__init__.py index 90ba3de5c..6d8c64fd1 100644 --- a/erniebot-agent/erniebot_agent/memory/__init__.py +++ b/erniebot-agent/erniebot_agent/memory/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import Memory +from .base import Memory, MessageManager, PersistentMessageManager from .limit_token_memory import LimitTokensMemory from .sliding_window_memory import SlidingWindowMemory from .whole_memory import WholeMemory diff --git a/erniebot-agent/erniebot_agent/memory/base.py b/erniebot-agent/erniebot_agent/memory/base.py index 32f8eb200..cdc078e71 100644 --- a/erniebot-agent/erniebot_agent/memory/base.py +++ b/erniebot-agent/erniebot_agent/memory/base.py @@ -12,9 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import Dict, List, Optional, Union -from erniebot_agent.messages import AIMessage, Message +from erniebot_agent.messages import AIMessage, HumanMessage, Message + +# Test Cases + +user_AK_relation = {"AK-123": "user-123", "AK-124": "user-124"} +user_session_id_relation: Dict[str, List] = {"user-123": [], "user-124": ["session-124", "session-125"]} + +session_messages = { + "session-124": [HumanMessage(content="你好"), AIMessage(content="你好124", function_call=None)], + "session-125": [HumanMessage(content="你好"), AIMessage(content="你好125", function_call=None)], +} class MessageManager: @@ -38,15 +48,125 @@ def clear_messages(self) -> None: def update_last_message_token_count(self, token_count: int): self.messages[-1].token_count = token_count - def retrieve_messages(self) -> List[Message]: + def get_messages(self) -> List[Message]: return self.messages +class RemoteMemory: + """ + 远程memory的实现类, 用于管理一个user 在一个session中的messages。 + """ + + def __init__(self, session_id): + self.session_id: str = session_id + self.messages: list[Message] = session_messages[session_id] + + def add_message(self, message: Message): + "make changes to the session's memory" + session_messages[self.session_id].append(message) + + def pop_message(self): + """pop the message from the start""" + session_messages[self.session_id].pop(0) + + def clear_memory(self): + session_messages[self.session_id] = [] + + def get_messages(self): + if self.session_id not in session_messages.keys(): + raise KeyError(f"session_id {self.session_id} not found") + return session_messages[self.session_id] + + def search_memory(self, session_id, payload): # TODO: refer zep + pass + + # TODO: 关闭之后同步message的变化到数据库 + + +class MessageStorageServer: # 绑定user + """ + MessageStorageServer 用于管理一个user在多个session中的message切换。 + + Args: + request_url (str): 请求地址 + AK (str): 用户ID + session_id (str, optional): 用户选择的session对应的session id. Defaults to None. + """ + + def __init__(self, request_url: str, AK: str, session_id: Optional[str] = None): + self.request_url = request_url + self.AK = AK + self.user_id = user_AK_relation[AK] + self.sessions: List = user_session_id_relation[self.user_id] + if len(self.sessions) == 0: + self.create_session() + + self.session_id = session_id if session_id else self.sessions[-1] # TODO: session选择 + self.memory = RemoteMemory(self.session_id) + + def get_messages(self): + return self.memory.get_messages() + + def create_session( + self, + ): + """create a new session for user and return the session id""" + import uuid + + session_id = uuid.uuid4().hex # A new session identifier + self.sessions.append(session_id) + user_session_id_relation[self.user_id] = [session_id] + global session_messages + session_messages[session_id] = [] + # 同时在数据库中创建相应空间 + return session_id + + +class PersistentMessageManager: + """ + PersistentMessageManager 用于本地的持久化、隔离化message管理。 + """ + + def __init__(self, url: str, AK: str, session_id: Optional[str] = None): + self.client = MessageStorageServer( + request_url=url, AK=AK, session_id=session_id + ) # client 内确定了session_id + self.session_id = self.client.session_id # 统一内外的session_id + self.messages = self.get_messages() + + def add_message(self, message: Message): + self.client.memory.add_message(message=message) + + def clear_messages(self): + self.messages = [] + self.client.memory.clear_memory() + + def pop_message(self): # TODO: choose from pop_message and cherry_pick_message + delete_message = self.client.memory.pop_message() + return delete_message + + def get_messages( + self, + ) -> List[Message]: # system,AI,user,contains summary if necessary + memory = self.client.memory.get_messages() + return memory + + # def cherry_pick_message(self, query): # TODO: 不使用pop,而是利用存储后端的索引功能找到相关message,但不保证限制长度 + # from zep_python import MemorySearchPayload + + # payload: MemorySearchPayload = MemorySearchPayload(text=query) + + # return self.client.memory.search_memory(self.session_id, payload) + + def update_last_message_token_count(self, token_count: int): + self.client.memory.get_messages()[-1].token_count = token_count + + class Memory: """The base class of memory""" - def __init__(self): - self.msg_manager = MessageManager() + def __init__(self, message_manager: Union[PersistentMessageManager, MessageManager] = MessageManager()): + self.msg_manager = message_manager def add_messages(self, messages: List[Message]): for message in messages: @@ -55,17 +175,10 @@ def add_messages(self, messages: List[Message]): def add_message(self, message: Message): if isinstance(message, AIMessage): self.msg_manager.update_last_message_token_count(message.query_tokens_count) - self.msg_manager.add_message(message) + self.msg_manager.add_message(message=message) def get_messages(self) -> List[Message]: - return self.msg_manager.retrieve_messages() + return self.msg_manager.get_messages() def clear_chat_history(self): self.msg_manager.clear_messages() - - -class WholeMemory(Memory): - """The memory include all the messages""" - - def __init__(self): - super().__init__() diff --git a/erniebot-agent/erniebot_agent/memory/limit_token_memory.py b/erniebot-agent/erniebot_agent/memory/limit_token_memory.py index 9719c2225..46f41e542 100644 --- a/erniebot-agent/erniebot_agent/memory/limit_token_memory.py +++ b/erniebot-agent/erniebot_agent/memory/limit_token_memory.py @@ -13,7 +13,7 @@ # limitations under the License. -from erniebot_agent.memory import Memory +from erniebot_agent.memory import Memory, MessageManager from erniebot_agent.messages import AIMessage, Message @@ -22,8 +22,8 @@ class LimitTokensMemory(Memory): If tokens >= max_token_limit, pop message from memory. """ - def __init__(self, max_token_limit=None): - super().__init__() + def __init__(self, max_token_limit=None, message_manager=MessageManager()): + super().__init__(message_manager) self.max_token_limit = max_token_limit self.mem_token_count = 0 diff --git a/erniebot-agent/erniebot_agent/memory/sliding_window_memory.py b/erniebot-agent/erniebot_agent/memory/sliding_window_memory.py index 3bd65ccbe..eaa15f38d 100644 --- a/erniebot-agent/erniebot_agent/memory/sliding_window_memory.py +++ b/erniebot-agent/erniebot_agent/memory/sliding_window_memory.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from erniebot_agent.memory import Memory +from erniebot_agent.memory import Memory, MessageManager from erniebot_agent.messages import Message class SlidingWindowMemory(Memory): """This class controls max number of messages.""" - def __init__(self, max_num_message: int): - super().__init__() + def __init__(self, max_num_message: int, message_manager=MessageManager()): + super().__init__(message_manager) self.max_num_message = max_num_message assert (isinstance(max_num_message, int)) and ( diff --git a/erniebot-agent/tests/unit_tests/memory/test_limit_token_memory.py b/erniebot-agent/tests/unit_tests/memory/test_limit_token_memory.py index b2bae95bc..1a34dccb8 100644 --- a/erniebot-agent/tests/unit_tests/memory/test_limit_token_memory.py +++ b/erniebot-agent/tests/unit_tests/memory/test_limit_token_memory.py @@ -15,13 +15,15 @@ def setUp(self): async def test_limit_token_memory(self): messages = HumanMessage(content="What is the purpose of model regularization?") - memory = LimitTokensMemory(4000) + memory = LimitTokensMemory(4) memory.add_message(messages) message = await self.llm.async_chat([messages]) memory.add_message(message) memory.add_message(HumanMessage("OK, what else?")) message = await self.llm.async_chat(memory.get_messages()) + memory.add_message(message) self.assertTrue(message is not None) + self.assertTrue(memory.mem_token_count <= 4) @pytest.mark.asyncio async def test_limit_token_memory_truncate_tokens(self, k=3): # truncate through returned message diff --git a/erniebot-agent/tests/unit_tests/memory/test_persist_message_manager.py b/erniebot-agent/tests/unit_tests/memory/test_persist_message_manager.py new file mode 100644 index 000000000..2f5f7e93a --- /dev/null +++ b/erniebot-agent/tests/unit_tests/memory/test_persist_message_manager.py @@ -0,0 +1,69 @@ +import asyncio +import unittest + +import pytest +from erniebot_agent.memory import PersistentMessageManager, WholeMemory +from erniebot_agent.messages import HumanMessage + +from tests.unit_tests.testing_utils import MockErnieBot + + +class TestSlidingWindowMemory(unittest.TestCase): + def setUp(self): + self.llm = MockErnieBot(None, None, None) + + # @pytest.mark.asyncio + @pytest.mark.parametrize("k", [1, 2, 4, 5, 10]) + def test_sliding_window_memory(self, k=3): # asyn pytest + async def test_sliding_window_memory(k=3): # asyn pytest + # The memory + + memory = WholeMemory( + message_manager=PersistentMessageManager(AK="AK-123", url="not used", session_id=None) + ) + + for _ in range(k): + # 2 times of human message + memory.add_message(HumanMessage(content="What is the purpose of model regularization?")) + # AI message + message = await self.llm.async_chat(memory.get_messages()) + memory.add_message(message) + print( + "!!! test_sliding_window_memory_wo_sessionid, conversation output", + memory.msg_manager.client.memory.messages, + ) + + self.assertTrue(len(memory.get_messages()) == 2 * k) + + asyncio.run(test_sliding_window_memory(k)) + + @pytest.mark.parametrize("k", [1, 2, 4, 5, 10]) + def test_sliding_window_memory_with_sessionid(self, k=3): # asyn pytest + async def test_sliding_window_memory(k=3): # asyn pytest + # The memory + + memory = WholeMemory( + message_manager=PersistentMessageManager( + AK="AK-124", url="not used", session_id="session-124" + ), + ) + + for _ in range(k): + # 2 times of human message + memory.add_message(HumanMessage(content="What is the purpose of model regularization?")) + + # AI message + message = await self.llm.async_chat(memory.get_messages()) + memory.add_message(message) + print( + "!!! test_sliding_window_memory_with_sessionid, conversation output", + memory.msg_manager.client.memory.messages, + ) + + self.assertTrue(len(memory.get_messages()) == 2 * k + 2) + + asyncio.run(test_sliding_window_memory(k)) + + +if __name__ == "__main__": + unittest.main()