Skip to content

Commit

Permalink
Add FAISS tests
Browse files Browse the repository at this point in the history
  • Loading branch information
redadmiral committed Mar 27, 2024
1 parent 87e6b9f commit 6af815c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/brdata_rag_tools/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def write_rows(self, rows: List[Type[BaseClass]], create_embeddings: bool = True
if expunge:
session.expunge_all()

def write_index(self):
raise NotImplementedError()

def _create_engine(self):
if self.database is None:
return create_engine("sqlite+pysqlite:///:memory:", echo=self.verbose)
Expand Down
Empty file modified src/brdata_rag_tools/datastructures.py
100755 → 100644
Empty file.
36 changes: 36 additions & 0 deletions test/test_databases.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

from src.brdata_rag_tools.databases import PGVector, FAISS
from src.brdata_rag_tools.embeddings import EmbeddingConfig, Embedder, register

Expand Down Expand Up @@ -83,6 +85,40 @@ class Podcast(abstract_table):
assert isinstance(simcont[0], dict)
assert simcont[0]["cosine_dist"] == 0

def test_faiss_persistence():
db_path = "test/data/faiss.db"
database = FAISS(db_path)
assert type(database) == FAISS

abstract_table = database.create_abstract_embedding_table(EmbeddingConfig.TEST)
assert len(set(abstract_table.__annotations__.keys()) & set(["id", "embedding_source", "embedding"])) == 3

class Podcast(abstract_table):
__tablename__ = "test"
title: Mapped[str] = mapped_column(String)
url: Mapped[str] = mapped_column(String)

database.create_tables()
assert "test" in list(database.metadata.tables.keys())

podcasts = []

for i in range(3):
podcasts.append(Podcast(title="TRUE CRIME - Unter Verdacht",
id=i,
url="example.com",
embedding_source="test")
)
print("bla")
database.write_index()

assert "Podcast.index" in os.listdir(db_path) and "Podcast.mapping" in os.listdir(db_path)

database2 = FAISS(db_path)

assert database2 == database



def test_pgvector(remove_table):
database = PGVector()
Expand Down

0 comments on commit 6af815c

Please sign in to comment.