Skip to content

Commit

Permalink
✨ Add threshold param
Browse files Browse the repository at this point in the history
This parameter allows users to filter results by a similarity threshold, so only the most similar results are returned.
  • Loading branch information
redadmiral committed Feb 7, 2024
1 parent 0ac2eb4 commit 06d7aea
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions src/brdata_rag_tools/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def drop_table(self, name: str):
session.commit()

def retrieve_similar_content(self, prompt, table: Type[BaseClass],
embedding_type: EmbeddingConfig, limit: int = 50):
embedding_type: EmbeddingConfig, limit: int = 50, threshold: float = None):
raise NotImplementedError()

def get_existing_row_ids(self, table: BaseClass):
Expand Down Expand Up @@ -193,7 +193,7 @@ def _create_engine(self):

def retrieve_similar_content(self, prompt, table: Type[BaseClass],
embedding_type: EmbeddingConfig,
limit: int = 50) -> List:
limit: int = 50, max_dist: float = None) -> List:
"""
Retrieve similar content based on a prompt. The function creates an embedding with the specified embedding type
and queries the associated database for the most similar matches.
Expand Down Expand Up @@ -228,8 +228,9 @@ def retrieve_similar_content(self, prompt, table: Type[BaseClass],
for i, row in enumerate(results):
results[i][0].embedding = np.frombuffer(row[0].embedding)
d = results[i]._asdict()
d["cosine_dist"] = distances[i]
dict_result.append(d)
if max_dist is not None and distances[i] < max_dist:
d["cosine_dist"] = distances[i]
dict_result.append(d)

return dict_result

Expand Down Expand Up @@ -292,7 +293,7 @@ def _create_engine(self):
echo=self.verbose)

def retrieve_similar_content(self, prompt, table: Type[BaseClass],
embedding_type: EmbeddingConfig, limit: int = 50):
embedding_type: EmbeddingConfig, limit: int = 50, max_dist: float = None):
"""
Retrieve similar content based on a prompt. The function creates an embedding with the specified embedding type
and queries the associated database for the most similar matches.
Expand All @@ -310,6 +311,12 @@ def retrieve_similar_content(self, prompt, table: Type[BaseClass],
results = session.execute(select(table, table.embedding.cosine_distance(
prompt_embedding).label("cosine_dist")).order_by("cosine_dist").limit(
limit)).all()

for i in range(len(results)):
if max_dist <= results[i].cosine_dist:
results = results[:i]
break

return [x._asdict() for x in results]

def retrieve_embedding(self, row_id: str, table: BaseClass) -> np.array:
Expand Down

0 comments on commit 06d7aea

Please sign in to comment.