diff --git a/test/test_databases.py b/test/test_databases.py index 6ca4512..860e9d7 100755 --- a/test/test_databases.py +++ b/test/test_databases.py @@ -1,11 +1,13 @@ from src.brdata_rag_tools.databases import PGVector, FAISS -from src.brdata_rag_tools.embeddings import EmbeddingConfig +from src.brdata_rag_tools.embeddings import EmbeddingConfig, Embedder, register from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy import String, text +import numpy as np from pytest import fixture + @fixture def remove_table(): database = PGVector() @@ -13,6 +15,30 @@ def remove_table(): yield database.drop_table("test") + +class Test(Embedder): + def __init__(self): + super().__init__(endpoint="example.com", auth_token=None) + + def create_embedding_bulk(self, rows): + """ + Takes an list of SQLAlchemy Table classes as input and returns them with embeddings assigned. + """ + for row in rows: + row.embedding = self.create_embedding(row.embedding_source) + + return rows + + def create_embedding(self, text: str) -> np.array: + if text == "test": + return np.array([1, 2, 3]) + else: + return np.array([4, 5, 6]) + + +register(Test, name="test", dimension=3) + + def test_faiss(): database = FAISS() assert type(database) == FAISS @@ -28,21 +54,32 @@ class Podcast(abstract_table): database.create_tables() assert "test" in list(database.metadata.tables.keys()) - podcast1 = Podcast(title="TRUE CRIME - Unter Verdacht", - id="1", - url = "example.com", - embedding_source="Wer wird hier zu Recht, wer zu Unrecht verdächtigt? Was, wenn Menschen unschuldig verurteilt werden und ihnen niemand glaubt? Oder andersherum: Wenn der wahre Täter oder die wahre Täterin ohne Strafe davonkommen? Unter Verdacht - In der 7. Staffel des erfolgreichen BAYERN 3 True Crime Podcasts sprechen Strafverteidiger Dr. Alexander Stevens und BAYERN 3 Moderatorin Jacqueline Belle über neue spannende Kriminalfälle. Diesmal geht es um Menschen, die unter Verdacht geraten sind. Wer ist schuldig? Wer lügt, wer sagt die Wahrheit? Und werden am Ende immer die Richtigen verurteilt?") + podcasts = [] + + for i in range(3): + podcasts.append(Podcast(title="TRUE CRIME - Unter Verdacht", + id=i, + url="example.com", + embedding_source="test") + ) - database.write_rows([podcast1], create_embeddings=True) + podcasts.append(Podcast(title="TRUE CRIME - Unter Verdacht", + id=4, + url="example.com", + embedding_source="Different Vector") + ) + + database.write_rows(podcasts, create_embeddings=True) with database.session() as s: response = s.execute(text("SELECT count(1) as count FROM test;")).first() - assert response.count == 1 + assert response.count == 4 - simcont = database.retrieve_similar_content(prompt = "Hallo Test.", embedding_type=EmbeddingConfig.TEST, table=Podcast) + simcont = database.retrieve_similar_content(prompt="different vector.", embedding_type=EmbeddingConfig.TEST, + table=Podcast, max_dist=.05) - assert len(simcont) == 1 + assert len(simcont) == 1 # filters out 3 results assert isinstance(simcont[0], dict) assert simcont[0]["cosine_dist"] == 0 @@ -61,20 +98,30 @@ class Podcast(abstract_table): database.create_tables() assert "test" in list(database.metadata.tables.keys()) + podcasts = [] + + for i in range(3): + podcasts.append(Podcast(title="TRUE CRIME - Unter Verdacht", + id=i, + url="example.com", + embedding_source="test") + ) - podcast1 = Podcast(title="TRUE CRIME - Unter Verdacht", - id="1", - url = "example.com", - embedding_source="Wer wird hier zu Recht, wer zu Unrecht verdächtigt? Was, wenn Menschen unschuldig verurteilt werden und ihnen niemand glaubt? Oder andersherum: Wenn der wahre Täter oder die wahre Täterin ohne Strafe davonkommen? Unter Verdacht - In der 7. Staffel des erfolgreichen BAYERN 3 True Crime Podcasts sprechen Strafverteidiger Dr. Alexander Stevens und BAYERN 3 Moderatorin Jacqueline Belle über neue spannende Kriminalfälle. Diesmal geht es um Menschen, die unter Verdacht geraten sind. Wer ist schuldig? Wer lügt, wer sagt die Wahrheit? Und werden am Ende immer die Richtigen verurteilt?") + podcasts.append(Podcast(title="TRUE CRIME - Unter Verdacht", + id=4, + url="example.com", + embedding_source="Different Vector") + ) - database.write_rows([podcast1], create_embeddings=True) + database.write_rows(podcasts, create_embeddings=True) with database.session() as s: response = s.execute(text("SELECT count(1) FROM test;")).first() - assert response.count == 1 + assert response.count == 4 - simcont = database.retrieve_similar_content(prompt = "Hallo Test.", embedding_type=EmbeddingConfig.TEST, table=Podcast) + simcont = database.retrieve_similar_content(prompt="Hallo Test.", embedding_type=EmbeddingConfig.TEST, + table=Podcast, max_dist=.02) assert len(simcont) == 1 assert isinstance(simcont[0], dict)