Skip to content
This repository has been archived by the owner on Nov 9, 2024. It is now read-only.

Commit

Permalink
fix: better filters for chat storage/retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
ToJen committed Oct 22, 2024
1 parent b29c0fd commit fd9111f
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions hive_agent/chat/chat_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from datetime import datetime, timezone
from typing import Any, List, Optional

from hive_agent.database.database import DatabaseManager
from llama_index.core.agent.runner.base import AgentRunner
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.schema import ImageDocument

from swarmzero.database.database import DatabaseManager


class ChatManager:

Expand All @@ -26,8 +27,10 @@ async def add_message(self, db_manager: DatabaseManager, role: str, content: Any
"role": role,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
if "HIVE_AGENT_ID" in os.environ:
data["agent_id"] = os.getenv("HIVE_AGENT_ID", "")
if "AGENT_ID" in os.environ:
data["agent_id"] = os.getenv("AGENT_ID", "")
if "SWARM_ID" in os.environ:
data["swarm_id"] = os.getenv("SWARM_ID", "")

await db_manager.insert_data(
table_name="chats",
Expand All @@ -36,12 +39,22 @@ async def add_message(self, db_manager: DatabaseManager, role: str, content: Any

async def get_messages(self, db_manager: DatabaseManager):
filters = {"user_id": [self.user_id], "session_id": [self.session_id]}
if "AGENT_ID" in os.environ:
filters["agent_id"] = os.getenv("AGENT_ID", "")
if "SWARM_ID" in os.environ:
filters["swarm_id"] = os.getenv("SWARM_ID", "")

db_chat_history = await db_manager.read_data("chats", filters)
chat_history = [ChatMessage(role=chat["role"], content=chat["message"]) for chat in db_chat_history]
return chat_history

async def get_all_chats_for_user(self, db_manager: DatabaseManager):
filters = {"user_id": [self.user_id]}
if "AGENT_ID" in os.environ:
filters["agent_id"] = os.getenv("AGENT_ID", "")
if "SWARM_ID" in os.environ:
filters["swarm_id"] = os.getenv("SWARM_ID", "")

db_chat_history = await db_manager.read_data("chats", filters)

chats_by_session: dict[str, list] = {}
Expand Down

0 comments on commit fd9111f

Please sign in to comment.