Skip to content

Commit

Permalink
💡 Test filtering for maximum cosine distance
Browse files Browse the repository at this point in the history
  • Loading branch information
redadmiral committed Feb 7, 2024
1 parent 06d7aea commit b497a46
Showing 1 changed file with 63 additions and 16 deletions.
79 changes: 63 additions & 16 deletions test/test_databases.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
from src.brdata_rag_tools.databases import PGVector, FAISS
from src.brdata_rag_tools.embeddings import EmbeddingConfig
from src.brdata_rag_tools.embeddings import EmbeddingConfig, Embedder, register

from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import String, text
import numpy as np

from pytest import fixture


@fixture
def remove_table():
database = PGVector()
database.drop_table("test")
yield
database.drop_table("test")


class Test(Embedder):
def __init__(self):
super().__init__(endpoint="example.com", auth_token=None)

def create_embedding_bulk(self, rows):
"""
Takes an list of SQLAlchemy Table classes as input and returns them with embeddings assigned.
"""
for row in rows:
row.embedding = self.create_embedding(row.embedding_source)

return rows

def create_embedding(self, text: str) -> np.array:
if text == "test":
return np.array([1, 2, 3])
else:
return np.array([4, 5, 6])


register(Test, name="test", dimension=3)


def test_faiss():
database = FAISS()
assert type(database) == FAISS
Expand All @@ -28,21 +54,32 @@ class Podcast(abstract_table):
database.create_tables()
assert "test" in list(database.metadata.tables.keys())

podcast1 = Podcast(title="TRUE CRIME - Unter Verdacht",
id="1",
url = "example.com",
embedding_source="Wer wird hier zu Recht, wer zu Unrecht verdächtigt? Was, wenn Menschen unschuldig verurteilt werden und ihnen niemand glaubt? Oder andersherum: Wenn der wahre Täter oder die wahre Täterin ohne Strafe davonkommen? Unter Verdacht - In der 7. Staffel des erfolgreichen BAYERN 3 True Crime Podcasts sprechen Strafverteidiger Dr. Alexander Stevens und BAYERN 3 Moderatorin Jacqueline Belle über neue spannende Kriminalfälle. Diesmal geht es um Menschen, die unter Verdacht geraten sind. Wer ist schuldig? Wer lügt, wer sagt die Wahrheit? Und werden am Ende immer die Richtigen verurteilt?")
podcasts = []

for i in range(3):
podcasts.append(Podcast(title="TRUE CRIME - Unter Verdacht",
id=i,
url="example.com",
embedding_source="test")
)

database.write_rows([podcast1], create_embeddings=True)
podcasts.append(Podcast(title="TRUE CRIME - Unter Verdacht",
id=4,
url="example.com",
embedding_source="Different Vector")
)

database.write_rows(podcasts, create_embeddings=True)

with database.session() as s:
response = s.execute(text("SELECT count(1) as count FROM test;")).first()

assert response.count == 1
assert response.count == 4

simcont = database.retrieve_similar_content(prompt = "Hallo Test.", embedding_type=EmbeddingConfig.TEST, table=Podcast)
simcont = database.retrieve_similar_content(prompt="different vector.", embedding_type=EmbeddingConfig.TEST,
table=Podcast, max_dist=.05)

assert len(simcont) == 1
assert len(simcont) == 1 # filters out 3 results
assert isinstance(simcont[0], dict)
assert simcont[0]["cosine_dist"] == 0

Expand All @@ -61,20 +98,30 @@ class Podcast(abstract_table):

database.create_tables()
assert "test" in list(database.metadata.tables.keys())
podcasts = []

for i in range(3):
podcasts.append(Podcast(title="TRUE CRIME - Unter Verdacht",
id=i,
url="example.com",
embedding_source="test")
)

podcast1 = Podcast(title="TRUE CRIME - Unter Verdacht",
id="1",
url = "example.com",
embedding_source="Wer wird hier zu Recht, wer zu Unrecht verdächtigt? Was, wenn Menschen unschuldig verurteilt werden und ihnen niemand glaubt? Oder andersherum: Wenn der wahre Täter oder die wahre Täterin ohne Strafe davonkommen? Unter Verdacht - In der 7. Staffel des erfolgreichen BAYERN 3 True Crime Podcasts sprechen Strafverteidiger Dr. Alexander Stevens und BAYERN 3 Moderatorin Jacqueline Belle über neue spannende Kriminalfälle. Diesmal geht es um Menschen, die unter Verdacht geraten sind. Wer ist schuldig? Wer lügt, wer sagt die Wahrheit? Und werden am Ende immer die Richtigen verurteilt?")
podcasts.append(Podcast(title="TRUE CRIME - Unter Verdacht",
id=4,
url="example.com",
embedding_source="Different Vector")
)

database.write_rows([podcast1], create_embeddings=True)
database.write_rows(podcasts, create_embeddings=True)

with database.session() as s:
response = s.execute(text("SELECT count(1) FROM test;")).first()

assert response.count == 1
assert response.count == 4

simcont = database.retrieve_similar_content(prompt = "Hallo Test.", embedding_type=EmbeddingConfig.TEST, table=Podcast)
simcont = database.retrieve_similar_content(prompt="Hallo Test.", embedding_type=EmbeddingConfig.TEST,
table=Podcast, max_dist=.02)

assert len(simcont) == 1
assert isinstance(simcont[0], dict)
Expand Down

0 comments on commit b497a46

Please sign in to comment.