From 6af815c0024dbdb9cfe40560ed273b66c017335f Mon Sep 17 00:00:00 2001 From: Marco Lehner Date: Wed, 27 Mar 2024 16:06:29 +0100 Subject: [PATCH] Add FAISS tests --- src/brdata_rag_tools/databases.py | 3 +++ src/brdata_rag_tools/datastructures.py | 0 test/test_databases.py | 36 ++++++++++++++++++++++++++ 3 files changed, 39 insertions(+) mode change 100755 => 100644 src/brdata_rag_tools/datastructures.py diff --git a/src/brdata_rag_tools/databases.py b/src/brdata_rag_tools/databases.py index 9cdf94a..31775a4 100644 --- a/src/brdata_rag_tools/databases.py +++ b/src/brdata_rag_tools/databases.py @@ -185,6 +185,9 @@ def write_rows(self, rows: List[Type[BaseClass]], create_embeddings: bool = True if expunge: session.expunge_all() + def write_index(self): + raise NotImplementedError() + def _create_engine(self): if self.database is None: return create_engine("sqlite+pysqlite:///:memory:", echo=self.verbose) diff --git a/src/brdata_rag_tools/datastructures.py b/src/brdata_rag_tools/datastructures.py old mode 100755 new mode 100644 diff --git a/test/test_databases.py b/test/test_databases.py index 860e9d7..a6ac5bf 100755 --- a/test/test_databases.py +++ b/test/test_databases.py @@ -1,3 +1,5 @@ +import os + from src.brdata_rag_tools.databases import PGVector, FAISS from src.brdata_rag_tools.embeddings import EmbeddingConfig, Embedder, register @@ -83,6 +85,40 @@ class Podcast(abstract_table): assert isinstance(simcont[0], dict) assert simcont[0]["cosine_dist"] == 0 +def test_faiss_persistence(): + db_path = "test/data/faiss.db" + database = FAISS(db_path) + assert type(database) == FAISS + + abstract_table = database.create_abstract_embedding_table(EmbeddingConfig.TEST) + assert len(set(abstract_table.__annotations__.keys()) & set(["id", "embedding_source", "embedding"])) == 3 + + class Podcast(abstract_table): + __tablename__ = "test" + title: Mapped[str] = mapped_column(String) + url: Mapped[str] = mapped_column(String) + + database.create_tables() + assert "test" in list(database.metadata.tables.keys()) + + podcasts = [] + + for i in range(3): + podcasts.append(Podcast(title="TRUE CRIME - Unter Verdacht", + id=i, + url="example.com", + embedding_source="test") + ) + print("bla") + database.write_index() + + assert "Podcast.index" in os.listdir(db_path) and "Podcast.mapping" in os.listdir(db_path) + + database2 = FAISS(db_path) + + assert database2 == database + + def test_pgvector(remove_table): database = PGVector()