diff --git a/.gitignore b/.gitignore
index c972d59de..cedb23efa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -237,3 +237,8 @@ outputs
 evaluation/data/
 test_add_pipeline.py
 test_file_pipeline.py
+
+# LanceDB local storage and scripts
+data/
+inspect_lancedb.py
+memos_server.log
diff --git a/pyproject.toml b/pyproject.toml
index e7fca38ff..0c38ec160 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -71,6 +71,7 @@ memos = "memos.cli:main"
 tree-mem = [
     "neo4j (>=5.28.1,<6.0.0)",  # Graph database
     "schedule (>=1.2.2,<2.0.0)",  # Task scheduling
+    "lancedb (>=0.30.1,<1.0.0)",  # LanceDB
 ]
 
 # MemScheduler
@@ -102,6 +103,13 @@ skill-mem = [
     "alibabacloud-oss-v2 (>=1.2.2,<1.2.3)",
 ]
 
+# Lance Vector DB
+lance-mem = [
+    "lancedb (>=0.17.0,<1.0.0)",  # Lance vector database
+    "pyarrow (>=18.0.0,<20.0.0)",  # Arrow format support for Lance
+    "tantivy (>=0.22.0,<1.0.0)",  # FTS engine for LanceDB
+]
+
 # Tavily Search
 tavily = [
     "tavily-python (>=0.5.0,<1.0.0)",
@@ -135,6 +143,9 @@ all = [
     "rake-nltk (>=1.0.6,<1.1.0)",
     "alibabacloud-oss-v2 (>=1.2.2,<1.2.3)",
     "tavily-python (>=0.5.0,<1.0.0)",
+    "lancedb (>=0.17.0,<1.0.0)",
+    "pyarrow (>=18.0.0,<20.0.0)",
+    "tantivy (>=0.22.0,<1.0.0)",
     # Uncategorized dependencies
 ]
 
@@ -199,8 +210,8 @@ langgraph = "^0.5.1"
 pymysql = "^1.1.2"
 
 [[tool.poetry.source]]
-name = "mirrors"
-url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/"
+name = "volces"
+url = "https://mirrors.volces.com/pypi/simple/"
 priority = "supplemental"
diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index c68deae5a..87c1c4f58 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -1049,6 +1049,16 @@ def get_start_default_config() -> dict[str, Any]:
 
         return config
 
+    @staticmethod
+    def get_lance_graph_config(user_id: str | None = None) -> dict[str, Any]:
+        """Get LanceDB graph configuration."""
+        base_uri = os.getenv("LANCE_URI", "./data/lance_db")
+        return {
+            "uri": base_uri,
+            "user_name": user_id,
+            "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", "2048")),
+        }
+
     @staticmethod
     def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "GeneralMemCube"]:
         """Create configuration for a specific user."""
@@ -1126,6 +1136,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
         neo4j_community_config = APIConfig.get_neo4j_community_config(user_id)
         neo4j_config = APIConfig.get_neo4j_config(user_id)
         polardb_config = APIConfig.get_polardb_config(user_id)
+        lance_config = APIConfig.get_lance_graph_config(user_id)
         internet_config = (
             APIConfig.get_internet_config()
             if os.getenv("ENABLE_INTERNET", "false").lower() == "true"
@@ -1137,6 +1148,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
             "neo4j": neo4j_config,
             "polardb": polardb_config,
             "postgres": postgres_config,
+            "lance": lance_config,
         }
         # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
         graph_db_backend = os.getenv(
@@ -1210,11 +1222,13 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None":
         neo4j_config = APIConfig.get_neo4j_config(user_id="default")
         polardb_config = APIConfig.get_polardb_config(user_id="default")
         postgres_config = APIConfig.get_postgres_config(user_id="default")
+        lance_config = APIConfig.get_lance_graph_config(user_id="default")
         graph_db_backend_map = {
             "neo4j-community": neo4j_community_config,
             "neo4j": neo4j_config,
             "polardb": polardb_config,
             "postgres": postgres_config,
+            "lance": lance_config,
         }
         internet_config = (
             APIConfig.get_internet_config()
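For reviewers trying the new backend locally, here is a minimal sketch of the environment this PR reads. The variable names and defaults are taken from `get_lance_graph_config()` above and `build_vec_db_config()` below; everything else about the deployment is assumed.

```python
import os

# Opt both the graph store and the vector store into LanceDB.
os.environ["GRAPH_DB_BACKEND"] = "lance"     # read via GRAPH_DB_BACKEND (or legacy NEO4J_BACKEND)
os.environ["MOS_VEC_DB_BACKEND"] = "lance"   # read in build_vec_db_config()
os.environ["LANCE_URI"] = "./data/lance_db"  # local dataset directory (the default)
os.environ["EMBEDDING_DIMENSION"] = "2048"   # must match the configured embedder
```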
diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py
index d29429fc9..7088d54b1 100644
--- a/src/memos/api/handlers/config_builders.py
+++ b/src/memos/api/handlers/config_builders.py
@@ -41,6 +41,7 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
         "neo4j": APIConfig.get_neo4j_config(user_id=user_id),
         "polardb": APIConfig.get_polardb_config(user_id=user_id),
         "postgres": APIConfig.get_postgres_config(user_id=user_id),
+        "lance": APIConfig.get_lance_graph_config(user_id=user_id),
     }
 
     # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
@@ -62,10 +63,27 @@ def build_vec_db_config() -> dict[str, Any]:
     Returns:
         Validated vector database configuration dictionary
     """
+    vec_db_backend = os.getenv("MOS_VEC_DB_BACKEND", "milvus").lower()
+
+    config = {}
+    if vec_db_backend == "milvus":
+        config = APIConfig.get_milvus_config()
+    elif vec_db_backend == "lance":
+        base_uri = os.getenv("LANCE_URI", "./data/lance_db")
+        config = {
+            "uri": base_uri,
+            "collection_name": ["memories"],
+            "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", "2048")),
+        }
+    elif vec_db_backend == "qdrant":
+        config = APIConfig.get_qdrant_config()
+    else:
+        raise ValueError(f"Unsupported vector DB backend: {vec_db_backend}")
+
     return VectorDBConfigFactory.model_validate(
         {
-            "backend": "milvus",
-            "config": APIConfig.get_milvus_config(),
+            "backend": vec_db_backend,
+            "config": config,
         }
     )
diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py
index ba1c50b07..64578ffad 100644
--- a/src/memos/api/handlers/search_handler.py
+++ b/src/memos/api/handlers/search_handler.py
@@ -6,6 +6,7 @@
 """
 
 import copy
+import logging
 import math
 
 from typing import Any
@@ -71,7 +72,40 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
         results = cube_view.search_memories(search_req_local)
         if not search_req_local.relativity:
             search_req_local.relativity = 0
+        self.logger.info(f"[SearchHandler] Relativity filter: {search_req_local.relativity}")
+
+        # Extract and log scores for visibility before filtering
+        if self.logger.isEnabledFor(logging.DEBUG):
+            score_details = []
+            for key in ("text_mem", "pref_mem"):
+                buckets = results.get(key)
+                if not isinstance(buckets, list):
+                    continue
+                for bucket in buckets:
+                    memories = bucket.get("memories")
+                    if not isinstance(memories, list):
+                        continue
+                    for mem in memories:
+                        if not isinstance(mem, dict):
+                            continue
+                        mem_text = (mem.get("memory") or "").replace("\n", " ")
+                        # Truncate to 100 chars to avoid log flooding
+                        if len(mem_text) > 100:
+                            mem_text = mem_text[:100] + "..."
+                        meta = mem.get("metadata", {})
+                        score = meta.get("relativity", 1.0) if isinstance(meta, dict) else 1.0
+                        try:
+                            score_val = float(score) if score is not None else 1.0
+                        except (TypeError, ValueError):
+                            score_val = 1.0
+                        score_details.append(f"[{score_val:.4f}] {mem_text}")
+
+            if score_details:
+                self.logger.debug(
+                    f"[SearchHandler] Reranker scores before threshold ({search_req_local.relativity}):\n"
+                    + "\n".join(score_details)
+                )
 
         results = self._apply_relativity_threshold(results, search_req_local.relativity)
         if search_req_local.dedup == "sim":
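The score dump above is gated on the handler logger's effective level, so it costs nothing in production. A sketch of turning it on for a debugging session; the exact logger name depends on how `memos.log` wires loggers, so the root-logger approach below is the blunt but reliable option:

```python
import logging

# Enable the per-memory reranker score dump added in search_handler.py.
logging.getLogger().setLevel(logging.DEBUG)
# Each emitted entry then renders as "[<score>] <first 100 chars of memory>",
# e.g. "[0.8731] Alice went to Beijing and loved it!"
```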
+ """ + + uri: str = Field(..., description="The URI/path to the LanceDB dataset") + user_name: str | None = Field( + default=None, + description="Logical user or tenant ID for data isolation", + ) + embedding_dimension: int = Field(default=768, description="Dimension of vector embedding") + compaction_version_threshold: int = Field( + default=500, description="Number of new versions to accumulate before triggering compaction" + ) + compaction_interval_mins: int = Field( + default=30, description="Fallback interval in minutes to check and run compaction" + ) + cleanup_older_than_days: int = Field( + default=7, description="Number of days to keep old versions before pruning" + ) + + class GraphDBConfigFactory(BaseModel): backend: str = Field(..., description="Backend for graph database") config: dict[str, Any] = Field(..., description="Configuration for the graph database backend") @@ -250,6 +272,7 @@ class GraphDBConfigFactory(BaseModel): "neo4j-community": Neo4jCommunityGraphDBConfig, "polardb": PolarDBGraphDBConfig, "postgres": PostgresGraphDBConfig, + "lance": LanceGraphDBConfig, } @field_validator("backend") diff --git a/src/memos/graph_dbs/base.py b/src/memos/graph_dbs/base.py index 0bc4a54f8..cd653edec 100644 --- a/src/memos/graph_dbs/base.py +++ b/src/memos/graph_dbs/base.py @@ -302,3 +302,7 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N - metadata: dict[str, Any] - Node metadata user_name: Optional user name (will use config default if not provided) """ + + @abstractmethod + def node_not_exist(self, scope: str, user_name: str | None = None) -> bool: + pass diff --git a/src/memos/graph_dbs/factory.py b/src/memos/graph_dbs/factory.py index 93b5971ec..1b47a7c4a 100644 --- a/src/memos/graph_dbs/factory.py +++ b/src/memos/graph_dbs/factory.py @@ -2,6 +2,7 @@ from memos.configs.graph_db import GraphDBConfigFactory from memos.graph_dbs.base import BaseGraphDB +from memos.graph_dbs.lance import LanceGraphDB from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB from memos.graph_dbs.polardb import PolarDBGraphDB @@ -16,6 +17,7 @@ class GraphStoreFactory(BaseGraphDB): "neo4j-community": Neo4jCommunityGraphDB, "polardb": PolarDBGraphDB, "postgres": PostgresGraphDB, + "lance": LanceGraphDB, } @classmethod diff --git a/src/memos/graph_dbs/lance.py b/src/memos/graph_dbs/lance.py new file mode 100644 index 000000000..cd980e5bc --- /dev/null +++ b/src/memos/graph_dbs/lance.py @@ -0,0 +1,957 @@ +from __future__ import annotations + +import json +import os +import threading +import time +import uuid + +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from memos.configs.graph_db import LanceGraphDBConfig + +from memos.dependency import require_python_package +from memos.graph_dbs.base import BaseGraphDB +from memos.log import get_logger + + +logger = get_logger(__name__) + + +class LanceGraphDB(BaseGraphDB): + """ + LanceDB implementation for MemOS GraphDB interface. + Features: + - Flattened 'memory_type' and 'status' for blazing fast scalar filtering. + - Native Full-Text Search (FTS) using Tantivy. + - BFS based multi-hop graph traversal. 
+ """ + + @require_python_package(import_name="lancedb", install_command="pip install lancedb tantivy") + def __init__(self, config: LanceGraphDBConfig): + self.config = config + self.uri = config.uri + self.user_name = config.user_name or "default" + self.dim = config.embedding_dimension + + # Compaction settings + self.compaction_version_threshold = config.compaction_version_threshold + self.compaction_interval_mins = config.compaction_interval_mins + self.cleanup_older_than_days = config.cleanup_older_than_days + + self.memories_uri = os.path.join(self.uri, "memories") + self.edges_uri = os.path.join(self.uri, "edges") + + self._init_schema() + + # Start LanceDB background optimizer thread + self._last_compact_versions = { + "memories": self._get_memories_table().version, + "edges": self._get_edges_table().version, + } + self._optimizer_thread = threading.Thread( + target=self._db_optimizer_loop, + daemon=True, + name="lancedb-optimizer", + ) + self._optimizer_thread.start() + + def _db_optimizer_loop(self): + """Background loop to periodically trigger table optimization.""" + import schedule + + schedule.every(self.compaction_interval_mins).minutes.do(self._force_optimize) + + logger.info( + f"Started LanceDB optimizer thread. Compaction interval: {self.compaction_interval_mins}m, " + f"Version threshold: {self.compaction_version_threshold}" + ) + + while True: + try: + # 1. Check version threshold + self._check_and_trigger_compaction() + + # 2. Run scheduled fallback compaction + schedule.run_pending() + except Exception as e: + logger.error(f"Error in LanceDB optimizer loop: {e}", stack_info=True) + + time.sleep(5) # Avoid busy waiting + + def _check_and_trigger_compaction(self): + """Trigger compaction if any table's version diff exceeds the threshold.""" + try: + memories_ds = self._get_memories_table() + if ( + memories_ds.version - self._last_compact_versions["memories"] + > self.compaction_version_threshold + ): + self._optimize_table("memories", memories_ds) + + edges_ds = self._get_edges_table() + if ( + edges_ds.version - self._last_compact_versions["edges"] + > self.compaction_version_threshold + ): + self._optimize_table("edges", edges_ds) + except Exception as e: + logger.error(f"Failed to check compaction versions: {e}") + + def _optimize_table(self, table_name: str, ds): + """Helper method to optimize a specific LanceDB table.""" + try: + current_version = ds.version + last_version = self._last_compact_versions[table_name] + + if current_version > last_version: + logger.info( + f"Triggering LanceDB optimization for '{table_name}'. 
" + f"Current version: {current_version}, Last compacted: {last_version}" + ) + + stats = ds.optimize(cleanup_older_than=timedelta(days=self.cleanup_older_than_days)) + + stats_msg = "" + if stats: + compaction = getattr(stats, "compaction", None) + if compaction: + stats_msg += ( + f" | Compaction: " + f"-{getattr(compaction, 'fragments_removed', 0)}/" + f"+{getattr(compaction, 'fragments_added', 0)} fragments, " + f"-{getattr(compaction, 'files_removed', 0)}/" + f"+{getattr(compaction, 'files_added', 0)} files" + ) + + prune = getattr(stats, "prune", None) + if prune: + stats_msg += ( + f" | Prune: -{getattr(prune, 'bytes_removed', 0)} bytes, " + f"-{getattr(prune, 'old_versions_removed', 0)} versions" + ) + + # Reload the table to get the updated version after optimization + if table_name == "memories": + ds = self._get_memories_table() + elif table_name == "edges": + ds = self._get_edges_table() + + self._last_compact_versions[table_name] = ds.version + logger.info( + f"LanceDB '{table_name}' optimization completed successfully. " + f"New version: {self._last_compact_versions[table_name]}{stats_msg}" + ) + except Exception as e: + logger.error(f"LanceDB '{table_name}' optimization failed: {e}") + + def _force_optimize(self): + # Optimize Memories Table + self._optimize_table("memories", self._get_memories_table()) + # Optimize Edges Table + self._optimize_table("edges", self._get_edges_table()) + + def _init_schema(self): + import lancedb + import pyarrow as pa + + os.makedirs(self.uri, exist_ok=True) + self.db = lancedb.connect(self.uri) + + if hasattr(self.db, "table_names"): + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + table_names = self.db.table_names() + else: + table_names = [tbl.name for tbl in self.db.list_tables()] + + if "memories" not in table_names: + schema = pa.schema( + [ + pa.field("id", pa.string()), + pa.field("memory", pa.string()), + pa.field("properties", pa.string()), # Store arbitrary JSON as string + pa.field("embedding", pa.list_(pa.float32(), self.dim)), + pa.field("user_name", pa.string()), + pa.field("memory_type", pa.string()), # Flattened for performance + pa.field("status", pa.string()), # Flattened for performance + pa.field("created_at", pa.string()), + pa.field("updated_at", pa.string()), + ] + ) + empty_table = pa.Table.from_pylist([], schema=schema) + self.db.create_table("memories", data=empty_table) + logger.info("Created LanceDB table for memories.") + + try: + ds = self.db.open_table("memories") + + # Create vector index (aligned with memory-lancedb TS implementation) + import math + + row_count = ds.count_rows() + if row_count > 256: # LanceDB requires at least 256 rows to train vector index + num_partitions = max(1, math.floor(math.sqrt(row_count))) + ds.create_index( + metric="cosine", + vector_column_name="embedding", + num_partitions=num_partitions, + ) + logger.info( + f"Created IVF_FLAT index for memories.embedding with metric=cosine, partitions={num_partitions}" + ) + else: + logger.debug( + f"Skipping vector index creation, not enough rows ({row_count} <= 256)" + ) + + # Create full-text search index + ds.create_fts_index("memory", replace=True) + logger.info("Created FTS index for memories.memory") + except Exception as e: + logger.warning(f"Failed to create LanceDB indices: {e}") + + if "edges" not in table_names: + edge_schema = pa.schema( + [ + pa.field("id", pa.string()), + pa.field("source_id", pa.string()), + pa.field("target_id", pa.string()), + 
pa.field("edge_type", pa.string()), + pa.field("user_name", pa.string()), + pa.field("created_at", pa.string()), + ] + ) + empty_edge_table = pa.Table.from_pylist([], schema=edge_schema) + self.db.create_table("edges", data=empty_edge_table) + logger.info("Created LanceDB table for edges.") + + def _get_memories_table(self): + import lancedb + + if not hasattr(self, "db"): + self.db = lancedb.connect(self.uri) + return self.db.open_table("memories") + + def _get_edges_table(self): + import lancedb + + if not hasattr(self, "db"): + self.db = lancedb.connect(self.uri) + return self.db.open_table("edges") + + def add_node( + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + ) -> None: + self.add_nodes_batch( + [{"id": id, "memory": memory, "metadata": metadata}], user_name=user_name + ) + + def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = None) -> None: + target_user = user_name or self.user_name + data = [] + now = datetime.now().isoformat() + + ids = [n["id"] for n in nodes if "id" in n] + if ids: + self.delete_node_by_prams(ids, user_name=target_user) + + for node in nodes: + node_id = node.get("id", str(uuid.uuid4())) + mem = node.get("memory", "") + meta = node.get("metadata", {}) + embedding = node.get("embedding", meta.get("embedding")) + + if embedding is None: + embedding = [0.0] * self.dim + + mem_type = meta.get("memory_type", "") + status = meta.get("status", "") + + if "embedding" in meta: + meta = meta.copy() + del meta["embedding"] + + data.append( + { + "id": str(node_id), + "memory": str(mem), + "properties": json.dumps(meta), + "embedding": embedding, + "user_name": target_user, + "memory_type": mem_type, + "status": status, + "created_at": now, + "updated_at": now, + } + ) + + if data: + self._get_memories_table().add(data) + + # Rebuild FTS index automatically after batch insert + try: + ds = self._get_memories_table() + ds.create_fts_index("memory", replace=True) + except Exception as e: + logger.warning(f"Failed to create LanceDB FTS index (tantivy missing?): {e}") + + def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: + target_user = user_name or self.user_name + node = self.get_node(id, include_embedding=True, user_name=target_user) + if not node: + return + + meta = node.get("metadata", {}) + + # In Neo4j, fields are top-level properties. In LanceDB, most properties go into metadata. + # We should merge the provided fields into metadata, except for memory and embedding. 
+    def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None:
+        target_user = user_name or self.user_name
+        node = self.get_node(id, include_embedding=True, user_name=target_user)
+        if not node:
+            return
+
+        meta = node.get("metadata", {})
+
+        # In Neo4j, fields are top-level properties. In LanceDB, most properties go into metadata.
+        # Merge the provided fields into metadata, except for memory and embedding.
+        for k, v in fields.items():
+            if k == "metadata" and isinstance(v, dict):
+                meta.update(v)
+            elif k not in ("memory", "embedding"):
+                meta[k] = v
+
+        new_mem = fields.get("memory", node.get("memory"))
+        new_emb = fields.get("embedding", meta.get("embedding"))
+
+        # Keep the (possibly updated) embedding in metadata so the re-insert
+        # below does not silently reset the vector to zeros; add_nodes_batch
+        # pops it from the stored properties again.
+        if new_emb is not None:
+            meta["embedding"] = new_emb
+
+        self.add_node(id, new_mem, meta, user_name=target_user)
+
+    def delete_node(self, id: str, user_name: str | None = None) -> None:
+        self.delete_node_by_prams([id], user_name=user_name)
+
+    def delete_node_by_prams(self, ids: list[str], user_name: str | None = None) -> None:
+        if not ids:
+            return
+        target_user = user_name or self.user_name
+        ds = self._get_memories_table()
+        id_list = ", ".join([f"'{i}'" for i in ids])
+        try:
+            ds.delete(f"id IN ({id_list}) AND user_name = '{target_user}'")
+            edges_ds = self._get_edges_table()
+            edges_ds.delete(
+                f"(source_id IN ({id_list}) OR target_id IN ({id_list})) AND user_name = '{target_user}'"
+            )
+        except Exception as e:
+            logger.error(f"Error deleting nodes in LanceDB: {e}")
+
+    def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None:
+        nodes = self.get_nodes([id], include_embedding=include_embedding, **kwargs)
+        return nodes[0] if nodes else None
+
+    def get_nodes(
+        self,
+        ids: list[str],
+        include_embedding: bool = False,
+        user_name: str | None = None,
+        **kwargs,
+    ) -> list[dict[str, Any]]:
+        if not ids:
+            return []
+        target_user = user_name or self.user_name
+        ds = self._get_memories_table()
+        id_list = ", ".join([f"'{i}'" for i in ids])
+        try:
+            df = ds.search().where(f"id IN ({id_list}) AND user_name = '{target_user}'").to_pandas()
+            return [self._parse_row(row, include_embedding) for _, row in df.iterrows()]
+        except Exception as e:
+            logger.error(f"Error getting nodes in LanceDB: {e}")
+            return []
+
+    def _parse_row(self, row, include_embedding=False) -> dict[str, Any]:
+        properties = json.loads(row["properties"]) if row.get("properties") else {}
+        # Restore flattened fields into metadata
+        if row.get("memory_type"):
+            properties["memory_type"] = row["memory_type"]
+        if row.get("status"):
+            properties["status"] = row["status"]
+
+        if include_embedding and "embedding" in row:
+            vec = row["embedding"]
+            properties["embedding"] = vec.tolist() if hasattr(vec, "tolist") else vec
+
+        return {"id": row["id"], "memory": row.get("memory", ""), "metadata": properties}
+
+    def add_edge(
+        self, source_id: str, target_id: str, type: str, user_name: str | None = None
+    ) -> None:
+        target_user = user_name or self.user_name
+        if self.edge_exists(source_id, target_id, type, user_name=target_user):
+            return
+
+        now = datetime.now().isoformat()
+        data = [
+            {
+                "id": str(uuid.uuid4()),
+                "source_id": source_id,
+                "target_id": target_id,
+                "edge_type": type,
+                "user_name": target_user,
+                "created_at": now,
+            }
+        ]
+        self._get_edges_table().add(data)
+
+    def delete_edge(
+        self, source_id: str, target_id: str, type: str, user_name: str | None = None
+    ) -> None:
+        target_user = user_name or self.user_name
+        ds = self._get_edges_table()
+        try:
+            ds.delete(
+                f"source_id = '{source_id}' AND target_id = '{target_id}' AND edge_type = '{type}' AND user_name = '{target_user}'"
+            )
+        except Exception as e:
+            logger.error(f"Error deleting edge in LanceDB: {e}")
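+    # Caveat: the WHERE clauses in this class are built with f-strings, so an
+    # id or user_name containing a single quote would break (or widen) the
+    # filter. Ids are UUID-like today; escape quotes before interpolating if
+    # that assumption ever changes.
+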
+    def edge_exists(
+        self, source_id: str, target_id: str, type: str, user_name: str | None = None
+    ) -> bool:
+        target_user = user_name or self.user_name
+        ds = self._get_edges_table()
+        try:
+            return (
+                len(
+                    ds.search()
+                    .where(
+                        f"source_id = '{source_id}' AND target_id = '{target_id}' AND edge_type = '{type}' AND user_name = '{target_user}'"
+                    )
+                    .limit(1)
+                    .to_list()
+                )
+                > 0
+            )
+        except Exception:
+            return False
+
+    def get_neighbors(
+        self,
+        node_id: str,
+        edge_type: str | None = None,
+        direction: str = "OUT",
+        user_name: str | None = None,
+    ) -> list[dict[str, Any]]:
+        target_user = user_name or self.user_name
+        ds = self._get_edges_table()
+
+        conditions = [f"user_name = '{target_user}'"]
+        if edge_type:
+            conditions.append(f"edge_type = '{edge_type}'")
+
+        if direction == "OUT":
+            conditions.append(f"source_id = '{node_id}'")
+        elif direction == "IN":
+            conditions.append(f"target_id = '{node_id}'")
+        else:
+            conditions.append(f"(source_id = '{node_id}' OR target_id = '{node_id}')")
+
+        filter_str = " AND ".join(conditions)
+
+        try:
+            df = ds.search().where(filter_str).to_pandas()
+
+            neighbor_ids = []
+            for _, row in df.iterrows():
+                if direction == "OUT":
+                    neighbor_ids.append(row["target_id"])
+                elif direction == "IN":
+                    neighbor_ids.append(row["source_id"])
+                else:
+                    nid = row["target_id"] if row["source_id"] == node_id else row["source_id"]
+                    neighbor_ids.append(nid)
+
+            if not neighbor_ids:
+                return []
+
+            return self.get_nodes(list(set(neighbor_ids)), user_name=target_user)
+        except Exception as e:
+            logger.error(f"Error getting neighbors in LanceDB: {e}")
+            return []
+
+    def get_path(
+        self, source_id: str, target_id: str, max_depth: int = 3, user_name: str | None = None
+    ) -> list[str]:
+        if source_id == target_id:
+            return [source_id]
+
+        target_user = user_name or self.user_name
+        ds = self._get_edges_table()
+
+        queue = [[source_id]]
+        visited = {source_id}
+
+        while queue:
+            path = queue.pop(0)
+            current = path[-1]
+
+            if len(path) > max_depth:
+                continue
+
+            try:
+                df = (
+                    ds.search()
+                    .where(f"source_id = '{current}' AND user_name = '{target_user}'")
+                    .to_pandas()
+                )
+                for _, row in df.iterrows():
+                    neighbor = row["target_id"]
+                    if neighbor == target_id:
+                        return [*path, neighbor]
+                    if neighbor not in visited:
+                        visited.add(neighbor)
+                        queue.append([*path, neighbor])
+            except Exception:
+                pass
+
+        return []
+
+    def get_subgraph(
+        self,
+        center_id: str,
+        depth: int = 2,
+        center_status: str = "activated",
+        user_name: str | None = None,
+    ) -> list[str]:
+        target_user = user_name or self.user_name
+
+        center_node = self.get_node(center_id, user_name=target_user)
+        if not center_node or center_node.get("metadata", {}).get("status") != center_status:
+            return []
+
+        visited = {center_id}
+        current_layer = {center_id}
+        ds = self._get_edges_table()
+
+        for _ in range(depth):
+            next_layer = set()
+            for node in current_layer:
+                try:
+                    df = (
+                        ds.search()
+                        .where(
+                            f"(source_id = '{node}' OR target_id = '{node}') AND user_name = '{target_user}'"
+                        )
+                        .to_pandas()
+                    )
+                    for _, row in df.iterrows():
+                        n1, n2 = row["source_id"], row["target_id"]
+                        if n1 not in visited:
+                            next_layer.add(n1)
+                            visited.add(n1)
+                        if n2 not in visited:
+                            next_layer.add(n2)
+                            visited.add(n2)
+                except Exception:
+                    pass
+            current_layer = next_layer
+
+        return list(visited)
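+    # Usage sketch: with edges n1 -> n2 -> n3, get_path("n1", "n3", max_depth=3)
+    # returns ["n1", "n2", "n3"]; an empty list means no directed path exists
+    # within max_depth hops. get_subgraph() walks the same edge table but
+    # ignores direction and returns the visited node ids.
+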
df.iloc[0]["target_id"] + chain.append(current) + except Exception: + break + return chain + + def search_by_embedding( + self, + vector: list[float], + top_k: int = 5, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list | None = None, + return_fields: list[str] | None = None, + **kwargs, + ) -> list[dict]: + target_user = user_name or self.user_name + ds = self._get_memories_table() + + conditions = [] + if getattr(self.config, "use_multi_db", False) is False and target_user: + conditions.append(f"user_name = '{target_user}'") + + # Fast scalar filtering using flattened columns + if scope and scope != "All": + conditions.append(f"memory_type = '{scope}'") + if status: + conditions.append(f"status = '{status}'") + + # Fallback to string matching for dynamic JSON properties + if search_filter: + for k, v in search_filter.items(): + if isinstance(v, str): + conditions.append(f'properties LIKE \'%"{k}": "{v}"%\'') + else: + conditions.append(f"properties LIKE '%\"{k}\": {json.dumps(v)}%'") + + where_clause = " AND ".join(conditions) if conditions else None + + try: + query = ds.search(vector, vector_column_name="embedding") + if where_clause: + query = query.where(where_clause) + + df = query.limit(top_k).to_pandas() + results = [] + + for _, row in df.iterrows(): + score = 1.0 - row.get("_distance", 0.0) + if threshold is not None and score < threshold: + continue + + item = {"id": row["id"], "score": score} + + if return_fields: + props = json.loads(row["properties"]) if row.get("properties") else {} + for field in return_fields: + if field == "memory": + item["memory"] = row.get("memory", "") + elif field == "memory_type" or field == "status": + item[field] = row.get(field, "") + elif field in props: + item[field] = props[field] + + results.append(item) + + return results + except Exception as e: + logger.error(f"Error in LanceDB search_by_embedding: {e}") + return [] + + def get_by_metadata( + self, + filters: list[dict[str, Any]], + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + user_name_flag: bool = True, + status: str | None = None, + ) -> list[str]: + target_user = user_name or self.user_name + ds = self._get_memories_table() + + conditions = [] + if user_name_flag: + conditions.append(f"user_name = '{target_user}'") + + if status: + conditions.append(f"status = '{status}'") + + for f in filters: + field = f["field"] + op = f.get("op", "=") + value = f["value"] + + # Use flattened fast columns if possible + if field in ["memory_type", "status"]: + if op == "=": + conditions.append(f"{field} = '{value}'") + elif op == "in": + in_conds = [f"{field} = '{v}'" for v in value] + if in_conds: + conditions.append(f"({' OR '.join(in_conds)})") + continue + + # Use LIKE for JSON properties + if op == "=": + conditions.append(f'properties LIKE \'%"{field}": "{value}"%\'') + elif op == "in": + in_conds = [f'properties LIKE \'%"{field}": "{v}"%\'' for v in value] + if in_conds: + conditions.append(f"({' OR '.join(in_conds)})") + elif op == "contains": + conditions.append(f"properties LIKE '%\"{value}\"%'") + + where_clause = " AND ".join(conditions) if conditions else None + + try: + if where_clause: + df = ds.search().where(where_clause).select(["id"]).to_pandas() + else: + df = ds.search().select(["id"]).to_pandas() + return df["id"].tolist() + except Exception as e: + 
logger.error(f"Error in LanceDB get_by_metadata: {e}") + return [] + + def search_by_fulltext( + self, + query_words: list[str], + top_k: int = 10, + **kwargs, + ) -> list[dict]: + """ + Implements Native Full-Text Search (FTS) using LanceDB's Tantivy integration. + This enables MemOS to perform Multi-way Recall (Vector + BM25/FTS) seamlessly. + """ + target_user = kwargs.get("user_name") or self.user_name + ds = self._get_memories_table() + query_str = " ".join(query_words) + + try: + # Execute native FTS query + query = ds.search(query_str) + if getattr(self.config, "use_multi_db", False) is False and target_user: + query = query.where(f"user_name = '{target_user}'") + + res = query.limit(top_k).to_list() + results = [] + for row in res: + results.append( + { + "id": row["id"], + "memory": row.get("memory", ""), + "score": row.get("_score", 1.0), # Tantivy relevance score + } + ) + return results + except Exception as e: + logger.error( + f"LanceDB FTS search failed (ensure tantivy is installed and index exists): {e}" + ) + return [] + + def get_all_memory_items( + self, + scope: str, + include_embedding: bool = False, + status: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: + target_user = kwargs.get("user_name") or self.user_name + ds = self._get_memories_table() + + conditions = [f"user_name = '{target_user}'"] + if scope: + conditions.append(f"memory_type = '{scope}'") + if status: + conditions.append(f"status = '{status}'") + + if filter: + for k, v in filter.items(): + if isinstance(v, str): + conditions.append(f'properties LIKE \'%"{k}": "{v}"%\'') + else: + import json + + conditions.append(f"properties LIKE '%\"{k}\": {json.dumps(v)}%'") + + where_clause = " AND ".join(conditions) + + try: + df = ds.search().where(where_clause).to_pandas() + results = [] + for _, row in df.iterrows(): + results.append(self._parse_row(row, include_embedding)) + return results + except Exception as e: + logger.error(f"Error getting all memory items in LanceDB: {e}") + return [] + + def get_structure_optimization_candidates( + self, scope: str, user_name: str | None = None, **kwargs + ): + target_user = user_name or ( + self.user_name if getattr(self, "user_name", "default") != "default" else None + ) + ds = self._get_memories_table() + edges_ds = self._get_edges_table() + + try: + # get all memories + query = ds.search().where(f"memory_type = '{scope}' AND status = 'activated'") + if getattr(self.config, "use_multi_db", False) is False and target_user: + query = ds.search().where( + f"memory_type = '{scope}' AND status = 'activated' AND user_name = '{target_user}'" + ) + df_memories = query.to_pandas() + if df_memories.empty: + return [] + + # get all edges to find isolated nodes + edge_query = edges_ds.search() + if getattr(self.config, "use_multi_db", False) is False and target_user: + edge_query = edge_query.where(f"user_name = '{target_user}'") + df_edges = edge_query.to_pandas() + connected_nodes = set() + if not df_edges.empty: + connected_nodes.update(df_edges["source_id"].tolist()) + connected_nodes.update(df_edges["target_id"].tolist()) + + results = [] + for _, row in df_memories.iterrows(): + if row["id"] not in connected_nodes: + results.append(self._parse_row(row, include_embedding=False)) + + return results + except Exception as e: + logger.error(f"Error getting structure optimization candidates in LanceDB: {e}") + return [] + + def deduplicate_nodes(self) -> None: + pass + + def detect_conflicts(self) -> 
+    def get_all_memory_items(
+        self,
+        scope: str,
+        include_embedding: bool = False,
+        status: str | None = None,
+        filter: dict | None = None,
+        knowledgebase_ids: list[str] | None = None,
+        **kwargs,
+    ) -> list[dict]:
+        target_user = kwargs.get("user_name") or self.user_name
+        ds = self._get_memories_table()
+
+        conditions = [f"user_name = '{target_user}'"]
+        if scope:
+            conditions.append(f"memory_type = '{scope}'")
+        if status:
+            conditions.append(f"status = '{status}'")
+
+        if filter:
+            for k, v in filter.items():
+                if isinstance(v, str):
+                    conditions.append(f'properties LIKE \'%"{k}": "{v}"%\'')
+                else:
+                    conditions.append(f"properties LIKE '%\"{k}\": {json.dumps(v)}%'")
+
+        where_clause = " AND ".join(conditions)
+
+        try:
+            df = ds.search().where(where_clause).to_pandas()
+            results = []
+            for _, row in df.iterrows():
+                results.append(self._parse_row(row, include_embedding))
+            return results
+        except Exception as e:
+            logger.error(f"Error getting all memory items in LanceDB: {e}")
+            return []
+
+    def get_structure_optimization_candidates(
+        self, scope: str, user_name: str | None = None, **kwargs
+    ):
+        target_user = user_name or (
+            self.user_name if getattr(self, "user_name", "default") != "default" else None
+        )
+        ds = self._get_memories_table()
+        edges_ds = self._get_edges_table()
+
+        try:
+            # Get all activated memories in scope
+            query = ds.search().where(f"memory_type = '{scope}' AND status = 'activated'")
+            if getattr(self.config, "use_multi_db", False) is False and target_user:
+                query = ds.search().where(
+                    f"memory_type = '{scope}' AND status = 'activated' AND user_name = '{target_user}'"
+                )
+            df_memories = query.to_pandas()
+            if df_memories.empty:
+                return []
+
+            # Get all edges to find isolated nodes
+            edge_query = edges_ds.search()
+            if getattr(self.config, "use_multi_db", False) is False and target_user:
+                edge_query = edge_query.where(f"user_name = '{target_user}'")
+            df_edges = edge_query.to_pandas()
+            connected_nodes = set()
+            if not df_edges.empty:
+                connected_nodes.update(df_edges["source_id"].tolist())
+                connected_nodes.update(df_edges["target_id"].tolist())
+
+            results = []
+            for _, row in df_memories.iterrows():
+                if row["id"] not in connected_nodes:
+                    results.append(self._parse_row(row, include_embedding=False))
+
+            return results
+        except Exception as e:
+            logger.error(f"Error getting structure optimization candidates in LanceDB: {e}")
+            return []
+
+    def deduplicate_nodes(self) -> None:
+        """No-op for the LanceDB backend."""
+
+    def detect_conflicts(self) -> list[tuple[str, str]]:
+        """Conflict detection is not implemented for the LanceDB backend."""
+        return []
+
+    def merge_nodes(self, id1: str, id2: str) -> str:
+        raise NotImplementedError
+
+    def get_grouped_counts(
+        self, group_fields: list[str], user_name: str | None = None
+    ) -> list[dict]:
+        """Grouped counts are not implemented for the LanceDB backend."""
+        return []
+
+    def search_by_hybrid(
+        self,
+        query_text: str,
+        vector: list[float],
+        top_k: int = 10,
+        user_name: str | None = None,
+        reranker: Any | None = None,
+        **kwargs,
+    ) -> list[dict]:
+        target_user = user_name or self.user_name
+        ds = self._get_memories_table()
+
+        try:
+            query = ds.search(query_type="hybrid").vector(vector).text(query_text)
+            if getattr(self.config, "use_multi_db", False) is False and target_user:
+                query = query.where(f"user_name = '{target_user}'")
+
+            query = query.limit(top_k)
+
+            if reranker:
+                query = query.rerank(reranker=reranker)
+
+            res = query.to_pandas()
+
+            results = []
+            for _, row in res.iterrows():
+                results.append(
+                    {
+                        "id": row["id"],
+                        "memory": row.get("memory", ""),
+                        "score": row.get("_score", 1.0),
+                    }
+                )
+            return results
+        except Exception as e:
+            logger.error(f"LanceDB Hybrid search failed: {e}")
+            return []
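+    # Usage sketch (reranker optional), using LanceDB's hybrid query API:
+    #   from lancedb.rerankers import RRFReranker
+    #   db.search_by_hybrid("forbidden city", vector=qvec, top_k=10,
+    #                       reranker=RRFReranker())
+    # which fuses the FTS and vector result lists before truncating to top_k.
+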
+    def remove_oldest_memory(
+        self, memory_type: str, keep_latest: int, user_name: str | None = None
+    ) -> None:
+        """
+        Keep the latest `keep_latest` memories of a specific `memory_type`, and remove the older ones.
+        """
+        target_user = user_name or self.user_name
+        ds = self._get_memories_table()
+        try:
+            # Query all matching memories sorted by created_at descending
+            df = (
+                ds.search()
+                .where(f"memory_type = '{memory_type}' AND user_name = '{target_user}'")
+                .to_pandas()
+            )
+            if len(df) <= keep_latest:
+                return
+
+            df = df.sort_values(by="created_at", ascending=False)
+            old_ids = df.iloc[keep_latest:]["id"].tolist()
+            if old_ids:
+                self.delete_node_by_prams(old_ids, user_name=target_user)
+        except Exception as e:
+            logger.error(f"Error removing oldest memory in LanceDB: {e}")
+
+    def clear(self, user_name: str | None = None) -> None:
+        target_user = user_name or self.user_name
+        try:
+            ds1 = self._get_memories_table()
+            ds1.delete(f"user_name = '{target_user}'")
+            ds2 = self._get_edges_table()
+            ds2.delete(f"user_name = '{target_user}'")
+        except Exception:
+            pass
+
+    def export_graph(self, include_embedding: bool = False, **kwargs) -> dict[str, Any]:
+        """Graph export is not implemented for the LanceDB backend."""
+        return {"nodes": [], "edges": []}
+
+    def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None:
+        """Graph import is not implemented for the LanceDB backend."""
+
+    def get_memory_count(self, scope: str | None = None, user_name: str | None = None) -> int:
+        target_user = user_name or (
+            self.user_name if getattr(self, "user_name", "default") != "default" else None
+        )
+        try:
+            ds = self._get_memories_table()
+            where_clauses = []
+            if getattr(self.config, "use_multi_db", False) is False and target_user:
+                where_clauses.append(f"user_name = '{target_user}'")
+            if scope:
+                where_clauses.append(f"memory_type = '{scope}'")
+
+            if where_clauses:
+                query_str = " AND ".join(where_clauses)
+                return len(ds.search().where(query_str).to_list())
+            return ds.count_rows()
+        except Exception:
+            return 0
+
+    def node_not_exist(self, scope: str, user_name: str | None = None) -> bool:
+        """Check if there is NO node with the given memory_type (scope) for the user."""
+        target_user = user_name or (
+            self.user_name if getattr(self, "user_name", "default") != "default" else None
+        )
+        try:
+            ds = self._get_memories_table()
+            where_clauses = []
+            if getattr(self.config, "use_multi_db", False) is False and target_user:
+                where_clauses.append(f"user_name = '{target_user}'")
+            if scope:
+                where_clauses.append(f"memory_type = '{scope}'")
+
+            query_str = " AND ".join(where_clauses)
+            if query_str:
+                return len(ds.search().where(query_str).to_list()) == 0
+            return ds.count_rows() == 0
+        except Exception:
+            return True
+
+    def close(self) -> None:
+        pass
diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py
index a57a40676..efd97434f 100644
--- a/src/memos/mem_feedback/feedback.py
+++ b/src/memos/mem_feedback/feedback.py
@@ -725,7 +725,7 @@ def filter_fault_update(self, operations: list[dict]):
         for judge in all_judge:
             valid_update = None
             if judge["judgement"] == "UPDATE_APPROVED":
-                valid_update = id2op.get(judge["id"], None)
+                valid_update = id2op.get(judge["id"])
 
             if valid_update:
                 valid_updates.append(valid_update)
diff --git a/src/memos/memories/textual/tree_text_memory/organize/handler.py b/src/memos/memories/textual/tree_text_memory/organize/handler.py
index 2d776912b..20312a017 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/handler.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/handler.py
@@ -155,12 +155,38 @@ def _resolve_in_graph(
         merged: TextualMemoryItem,
         user_name: str | None = None,
     ):
-        edges_a = self.graph_store.get_edges(
-            conflict_a.id, type="ANY", direction="ANY", user_name=user_name
-        )
-        edges_b = self.graph_store.get_edges(
-            conflict_b.id, type="ANY", direction="ANY", user_name=user_name
-        )
+        edges_a = []
+
+        if hasattr(self.graph_store, "get_edges"):
+            edges_a = self.graph_store.get_edges(
+                conflict_a.id, type="ANY", direction="ANY", user_name=user_name
+            )
+        elif hasattr(self.graph_store, "get_neighbors"):
+            neighbor_nodes_in = self.graph_store.get_neighbors(
+                conflict_a.id, edge_type="ANY", direction="IN", user_name=user_name
+            )
+            neighbor_nodes_out = self.graph_store.get_neighbors(
+                conflict_a.id, edge_type="ANY", direction="OUT", user_name=user_name
+            )
+            edges_a = [
+                {"from": n["id"], "to": conflict_a.id, "type": "ANY"} for n in neighbor_nodes_in
+            ] + [{"from": conflict_a.id, "to": n["id"], "type": "ANY"} for n in neighbor_nodes_out]
+
+        edges_b = []
+        if hasattr(self.graph_store, "get_edges"):
+            edges_b = self.graph_store.get_edges(
+                conflict_b.id, type="ANY", direction="ANY", user_name=user_name
+            )
+        elif hasattr(self.graph_store, "get_neighbors"):
+            neighbor_nodes_in = self.graph_store.get_neighbors(
+                conflict_b.id, edge_type="ANY", direction="IN", user_name=user_name
+            )
+            neighbor_nodes_out = self.graph_store.get_neighbors(
+                conflict_b.id, edge_type="ANY", direction="OUT", user_name=user_name
+            )
+            edges_b = [
+                {"from": n["id"], "to": conflict_b.id, "type": "ANY"} for n in neighbor_nodes_in
+            ] + [{"from": conflict_b.id, "to": n["id"], "type": "ANY"} for n in neighbor_nodes_out]
 
         all_edges = edges_a + edges_b
 
         self.graph_store.add_node(
@@ -175,7 +201,7 @@ def _resolve_in_graph(
             new_to = merged.id if edge["to"] in (conflict_a.id, conflict_b.id) else edge["to"]
             if new_from == new_to:
                 continue
-            # Check if the edge already exists before adding
+            # Check if the edge already exists before adding it
             if not self.graph_store.edge_exists(
                 new_from, new_to, edge["type"], direction="ANY", user_name=user_name
             ):
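The two fallback branches in `_resolve_in_graph` above are verbatim copies of each other. If this duck-typing pattern spreads to other call sites, it could be factored into a small helper along these lines (hypothetical name, same assumptions as the diff):

```python
def _edges_via_neighbors(graph_store, node_id: str, user_name: str | None) -> list[dict]:
    """Reconstruct untyped edge dicts for stores that only expose get_neighbors()."""
    incoming = graph_store.get_neighbors(
        node_id, edge_type="ANY", direction="IN", user_name=user_name
    )
    outgoing = graph_store.get_neighbors(
        node_id, edge_type="ANY", direction="OUT", user_name=user_name
    )
    return [{"from": n["id"], "to": node_id, "type": "ANY"} for n in incoming] + [
        {"from": node_id, "to": n["id"], "type": "ANY"} for n in outgoing
    ]
```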
diff --git a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py
index b7fb6b1a0..ef604937b 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py
@@ -270,10 +270,25 @@ def _check_deadline(where: str):
             )
             return
 
-        # Step 2: Partition nodes
-        if _check_deadline("[GraphStructureReorganize] Before partition"):
-            return
-        partitioned_groups = self._partition(nodes)
+        # Group nodes by user_name to prevent cross-user clustering
+        from collections import defaultdict
+
+        nodes_by_user = defaultdict(list)
+        for n in nodes:
+            # Some backends (e.g. LanceDB) may not expose user_name as a node
+            # attribute; fall back to "default" in that case
+            u_name = getattr(n, "user_name", "default")
+            nodes_by_user[u_name].append(n)
+
+        partitioned_groups = []
+        for u_name, u_nodes in nodes_by_user.items():
+            if len(u_nodes) < min_group_size:
+                logger.info(
+                    f"[GraphStructureReorganize] User {u_name} has only {len(u_nodes)} nodes, skipping."
+                )
+                continue
+            u_groups = self._partition(u_nodes)
+            partitioned_groups.extend(u_groups)
 
         logger.info(
             f"[GraphStructureReorganize] Partitioned into {len(partitioned_groups)} clusters."
         )
diff --git a/tests/graph_dbs/test_lance.py b/tests/graph_dbs/test_lance.py
new file mode 100644
index 000000000..42aa1dffb
--- /dev/null
+++ b/tests/graph_dbs/test_lance.py
@@ -0,0 +1,201 @@
+import os
+import tempfile
+
+from memos.configs.graph_db import LanceGraphDBConfig
+from memos.graph_dbs.lance import LanceGraphDB
+
+
+def test_lance_graph_db():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        db_uri = os.path.join(tmpdir, "test_lancedb_data")
+
+        config = LanceGraphDBConfig(uri=db_uri, user_name="test_user", embedding_dimension=3)
+
+        print("\nInitializing LanceGraphDB in temporary directory...")
+        db = LanceGraphDB(config)
+
+        print("\n--- 1. Testing Node Insertion (Batch & Upsert) ---")
+        nodes = [
+            {
+                "id": "node_1",
+                "memory": "Alice went to Beijing",
+                "metadata": {
+                    "memory_type": "LongTermMemory",
+                    "status": "activated",
+                    "tags": ["travel", "city"],
+                },
+                "embedding": [0.1, 0.2, 0.3],
+            },
+            {
+                "id": "node_2",
+                "memory": "Alice visited the Forbidden City",
+                "metadata": {
+                    "memory_type": "LongTermMemory",
+                    "status": "activated",
+                    "tags": ["travel", "history"],
+                },
+                "embedding": [0.15, 0.25, 0.35],
+            },
+            {
+                "id": "node_3",
+                "memory": "Bob likes programming in Python",
+                "metadata": {
+                    "memory_type": "ShortTermMemory",
+                    "status": "activated",
+                    "tags": ["tech"],
+                },
+                "embedding": [0.8, 0.1, 0.1],
+            },
+        ]
+        db.add_nodes_batch(nodes)
+
+        n1 = db.get_node("node_1")
+        assert n1["id"] == "node_1"
+        assert n1["memory"] == "Alice went to Beijing"
+        print("Node insertion verified.")
+
+        print("\n--- 2. Testing Edge Insertion ---")
+        db.add_edge("node_1", "node_2", "FOLLOWS")
+        db.add_edge("node_1", "node_3", "KNOWS")
+        print("Edge insertion verified.")
+
+        print("\n--- 3. Testing Vector Search ---")
+        res_vec = db.search_by_embedding(
+            [0.12, 0.22, 0.32], top_k=2, return_fields=["memory", "memory_type"]
+        )
+        assert len(res_vec) == 2
+        assert res_vec[0]["id"] in ["node_1", "node_2"]
+        print("Vector search verified.")
+
+        print("\n--- 4. Testing Metadata Filter (Scalar + JSON LIKE) ---")
+        res_meta = db.get_by_metadata(
+            filters=[{"field": "tags", "op": "contains", "value": "travel"}], status="activated"
+        )
+        assert "node_1" in res_meta
+        assert "node_2" in res_meta
+        print("Metadata filtering verified.")
+
+        print("\n--- 5. Testing Full-Text Search (FTS) ---")
+        try:
+            res_fts = db.search_by_fulltext(["Forbidden", "City"], top_k=2)
+            assert len(res_fts) > 0
+            assert res_fts[0]["id"] == "node_2"
+            print("FTS verified.")
+        except Exception as e:
+            print(f"FTS failed: {e}")
+
+        print("\n--- 6. Testing Hybrid Search (Multi-way Recall + Reranker) ---")
+        try:
+            from lancedb.rerankers import LinearCombinationReranker, RRFReranker
+
+            res_hybrid_default = db.search_by_hybrid(
+                query_text="Forbidden", vector=[0.1, 0.2, 0.3], top_k=2
+            )
+            assert len(res_hybrid_default) > 0
+
+            ratio_reranker = LinearCombinationReranker(weight=0.8)
+            res_hybrid_ratio = db.search_by_hybrid(
+                query_text="Forbidden", vector=[0.1, 0.2, 0.3], top_k=2, reranker=ratio_reranker
+            )
+            assert len(res_hybrid_ratio) > 0
+
+            rrf_reranker = RRFReranker()
+            res_hybrid_rrf = db.search_by_hybrid(
+                query_text="Forbidden", vector=[0.1, 0.2, 0.3], top_k=2, reranker=rrf_reranker
+            )
+            assert len(res_hybrid_rrf) > 0
+            print("Hybrid Search (Default/Ratio/RRF) verified.")
+        except Exception as e:
+            print(f"Hybrid search failed: {e}")
+
+        print("\n--- 7. Testing Graph Traversal (Neighbors) ---")
+        neighbors_out = db.get_neighbors("node_1", direction="OUT")
+        assert len(neighbors_out) == 2
+        print("Neighbors traversal verified.")
+
+        print("\n--- 8. Testing Graph Traversal (Subgraph BFS) ---")
+        subgraph = db.get_subgraph("node_1", depth=1)
+        assert len(subgraph) == 3
+        print("Subgraph BFS verified.")
+
+        print("\n--- 9. Testing Node Update ---")
+        db.update_node("node_1", {"memory": "Alice went to Beijing and loved it!"})
+        n1_updated = db.get_node("node_1")
+        assert n1_updated["memory"] == "Alice went to Beijing and loved it!"
+        print("Node update verified.")
+
+        print("\nCleaning up...")
+        db.clear()
+        print("Test finished successfully in temporary directory!")
+
+
+def test_lance_compaction_and_fts_effectiveness():
+    """
+    Test the effectiveness of the LanceDB _optimize_table mechanism,
+    including compaction of small files and FTS index functionality.
+    """
+    with tempfile.TemporaryDirectory() as tmpdir:
+        db_uri = os.path.join(tmpdir, "test_lancedb_compaction")
+        # Use a low threshold to force triggering
+        config = LanceGraphDBConfig(
+            uri=db_uri, user_name="test_user", embedding_dimension=3, compaction_version_threshold=2
+        )
+        db = LanceGraphDB(config)
+
+        # 1. Insert multiple single nodes to create small fragments
+        print("\nInserting 5 separate fragments...")
+        for i in range(5):
+            node = {
+                "id": f"node_c_{i}",
+                "memory": f"Alice went to the magical forest number {i}",
+                "metadata": {"memory_type": "LongTermMemory", "status": "activated"},
+                "embedding": [0.1 * i, 0.2 * i, 0.3 * i],
+            }
+            db.add_nodes_batch([node])
+
+        import lance
+
+        ds = lance.dataset(os.path.join(db_uri, "memories.lance"))
+        fragments_before = len(ds.get_fragments())
+        print(f"Fragments BEFORE optimize: {fragments_before}")
+
+        # 2. Test FTS before optimization
+        try:
+            res_fts_before = db.search_by_fulltext(["magical"], top_k=10)
+            print(f"FTS hits BEFORE optimize: {len(res_fts_before)}")
+        except Exception as e:
+            print(f"FTS failed before optimize: {e}")
+
+        # 3. Force the internal optimizer
+        print("Forcing LanceDB optimizer...")
+        db._force_optimize()
+
+        # 4. Check fragment count after compaction
+        ds = lance.dataset(os.path.join(db_uri, "memories.lance"))
+        fragments_after = len(ds.get_fragments())
+        print(f"Fragments AFTER optimize: {fragments_after}")
+
+        # 5. Verify FTS index effectiveness after optimization
+        res_fts_after = db.search_by_fulltext(["magical"], top_k=10)
+        assert len(res_fts_after) == 5, (
+            f"FTS should recall all 5 nodes, but got {len(res_fts_after)}"
+        )
+        print(f"FTS hits AFTER optimize: {len(res_fts_after)}")
+        # 6. Test prune/delete
+        db.delete_node("node_c_0")
+        db._force_optimize()
+
+        res_fts_deleted = db.search_by_fulltext(["magical"], top_k=10)
+        assert len(res_fts_deleted) == 4, (
+            f"FTS should recall 4 nodes after deletion, got {len(res_fts_deleted)}"
+        )
+        print(f"FTS hits AFTER deletion and optimize: {len(res_fts_deleted)}")
+
+        db.clear()
+
+
+if __name__ == "__main__":
+    test_lance_graph_db()
+    test_lance_compaction_and_fts_effectiveness()
diff --git a/tests/vec_dbs/test_qdrant.py b/tests/vec_dbs/test_qdrant.py
index 67f76d463..47d81cd00 100644
--- a/tests/vec_dbs/test_qdrant.py
+++ b/tests/vec_dbs/test_qdrant.py
@@ -1,6 +1,6 @@
 import uuid
 
-from unittest.mock import MagicMock, patch
+from unittest.mock import patch
 
 import pytest
 
@@ -89,7 +89,10 @@ def test_search(vec_db):
         },
     )()
     vec_db.client.query_points.return_value = mock_response
-    results = vec_db.search([0.1, 0.2, 0.3], top_k=1)
+    results = vec_db.search(
+        query_vector=[0.1, 0.2, 0.3],
+        top_k=1,
+    )
     assert len(results) == 1
     assert isinstance(results[0], VecDBItem)
    assert results[0].score == 0.9
@@ -100,15 +103,20 @@ def test_update_vector(vec_db):
     data = {"id": id, "vector": [0.4, 0.5, 0.6], "payload": {"new": "data"}}
     vec_db.update(id, data)
     vec_db.client.upsert.assert_called_once()
+    vec_db.client.set_payload.assert_not_called()
 
 
 def test_update_payload_only(vec_db):
-    vec_db.update("1", {"payload": {"only": "payload"}})
+    id = str(uuid.uuid4())
+    data = {"id": id, "payload": {"new": "data"}}
+    vec_db.update(id, data)
+
+    vec_db.client.update_vectors.assert_not_called()
     vec_db.client.set_payload.assert_called_once()
 
 
 def test_delete(vec_db):
-    vec_db.delete(["1", "2"])
+    id = str(uuid.uuid4())
+    vec_db.delete([id])
     vec_db.client.delete.assert_called_once()
 
 
@@ -119,12 +127,18 @@ def test_count(vec_db):
 
 
 def test_get_all(vec_db):
-    vec_db.get_by_filter = MagicMock(
-        return_value=[VecDBItem(id=str(uuid.uuid4()), vector=[0.1, 0.2, 0.3])]
+    id1, id2 = str(uuid.uuid4()), str(uuid.uuid4())
+    vec_db.client.scroll.return_value = (
+        [
+            type("obj", (object,), {"id": id1, "vector": [0.1], "payload": {}}),
+            type("obj", (object,), {"id": id2, "vector": [0.2], "payload": {}}),
+        ],
+        None,
     )
     results = vec_db.get_all()
-    assert len(results) == 1
-    assert isinstance(results[0], VecDBItem)
+    assert len(results) == 2
+    assert results[0].id == id1
+    assert results[1].id == id2
 
 
 def test_qdrant_client_cloud_init():