diff --git a/.gitignore b/.gitignore index a60302b..6be6f5f 100755 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,5 @@ docs/build .pypirc *.tar.gz *.whl -*.db \ No newline at end of file +*.db +.idea \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt index 0474921..e8bc130 100755 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -5,6 +5,7 @@ certifi==2023.11.17 cffi==1.16.0 charset-normalizer==3.3.2 cryptography==41.0.7 +chromadb==0.4.17 docutils==0.20.1 idna==3.6 imagesize==1.4.1 @@ -17,7 +18,7 @@ markdown-it-py==3.0.0 MarkupSafe==2.1.3 mdurl==0.1.2 more-itertools==10.2.0 -nh3==0.2.15 +opentelemetry-api~=1.12 packaging==23.2 pkginfo==1.9.6 pycparser==2.21 diff --git a/pyproject.toml b/pyproject.toml index e17aac2..470e34e 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "annotated-types==0.6.0", "anyio==4.2.0", "certifi==2023.11.17", + "chromadb==0.4.17", "charset-normalizer==3.3.2", "distro==1.9.0", "exceptiongroup==1.2.0", diff --git a/src/brdata_rag_tools/databases.py b/src/brdata_rag_tools/databases.py index 9cdf94a..e6e1b50 100644 --- a/src/brdata_rag_tools/databases.py +++ b/src/brdata_rag_tools/databases.py @@ -5,6 +5,10 @@ import faiss import numpy as np from pgvector.sqlalchemy import Vector + +import chromadb +from chromadb.api.types import Where + from sqlalchemy import String, text, BLOB from sqlalchemy import create_engine, select from sqlalchemy.orm import Session, Mapped, mapped_column @@ -140,6 +144,7 @@ def write_rows(self, rows: List[Type[BaseClass]], create_embeddings: bool = True if expunge: session.expunge_all() + @dataclass class IndexWrapper: embeddings: faiss.METRIC_INNER_PRODUCT @@ -334,3 +339,178 @@ def retrieve_embedding(self, row_id: str, table: BaseClass) -> np.array: table.id == row_id).first() return embedding + + +class Chroma(Database): + """ + This class represents a locally run ChromaDB. + + :param user: The username to connect to the database, if run as a different service. + :type user: str + :param database: The name of the database to connect to, if run as a different service. + :type database: str + :param password: The password to connect to the database, if run as a different service. If not provided, it will use the value of the "DATABASE_PASSWORD" environment variable. + :type password: str + :param host: The host address of the database, if run as a different service. Default is "localhost". + :type host: str + :param port: The port number of the database, if run as a different service. Default is 8000. + :type port: int + :param verbose: Whether to enable verbose output. Default is False. + :type verbose: bool + """ + def __init__(self, user: str = None, database: str = None, password: str = None, + host: str = None, port: int = None, verbose: bool = False): + super().__init__(user, database, password, host, port, verbose, vector_type=Vector) + + def _create_engine(self): + if self.database: + # run productively + return chromadb.PersistentClient(path=self.database) + else: + # run for test purposes without persistence + return chromadb.Client() + + def write_rows(self, rows: List[Type[BaseClass]], create_embeddings: bool = True, expunge=False, expire=True): + """ + Write rows to the database and optionally create embeddings for the rows. + + :param rows: A list of rows to be written to the database. Rows must be instances of BaseClass or its subclasses. + :param create_embeddings: A boolean value indicating whether embeddings should be created for the rows. + Default value is True. + :return: None + """ + table = type(rows[0]) + collection = self.engine.get_or_create_collection(name=table.__name__, metadata={"hnsw:space": "cosine"}) # https://docs.trychroma.com/usage-guide#changing-the-distance-function + + rows_with_embedding = self.get_existing_row_ids(table) + rows_wo_embedding = [x for x in rows if x.id not in rows_with_embedding] + + embedder = rows[0].embedding_type.model + custom_embedder = type(embedder).__name__ != 'ChromaEmbedder' + + if custom_embedder: + newly_embedded = embedder.create_embedding_bulk(rows_wo_embedding) + else: + newly_embedded = rows_wo_embedding + + for i, row in enumerate(newly_embedded): + # kill unneeded metadata from class + # not very elegant, i guess ;-) + metadata = row.__dict__.copy() + del metadata['id'] + del metadata['embedding_source'] + if custom_embedder: + del metadata['embedding'] + internal_keys = [k for k in list(metadata.keys()) if k.startswith("_")] + for ik in internal_keys: + del metadata[ik] + + # ChromaDB does not accept some kind of metadata, so change to str + dt = [k for k, v in metadata.items() if type(v) not in [str, float, int, bool]] + for d in dt: + metadata[d] = str(metadata[d]) + + if custom_embedder: + collection.add(documents=str(row.embedding_source), + embeddings=list(row.embedding), + metadatas=metadata, + ids=row.id + ) + else: + # use ChromaDB's own embedding + collection.add(documents=str(row.embedding_source), + metadatas=metadata, + ids=row.id + ) + + def update_rows(self, entries: List, update_metadatas: bool = False): + """ + Update Entries + + :return: None + """ + table = type(entries[0]) + + collection = self.engine.get_collection(name=table.__name__) + + ids = [] + metadatas = [] + for e in entries: + # kill unneeded metadata from class + # XXX not very elegant, i guess ;-) + ids.append(e.id) + metadata = e.__dict__.copy() + del metadata['id'] + internal_keys = [k for k in list(metadata.keys()) if k.startswith("_")] + for ik in internal_keys: + del metadata[ik] + + # ChromaDB does not accept some kind of metadata, so change to str + dt = [k for k, v in metadata.items() if type(v) not in [str, float, int, bool]] + for d in dt: + metadata[d] = str(metadata[d]) + metadatas.append(metadata) + + collection.update(ids=ids, + metadatas=metadatas, + ) + + def retrieve_similar_content(self, prompt, table: Type[BaseClass], + embedding_type: EmbeddingConfig = None, + limit: int = 50, max_dist: float = 100, where: Where = {}) -> List: + """ + Retrieve similar content based on a prompt. The function creates an embedding with the specified embedding type + and queries the associated database for the most similar matches. + + :param prompt: The prompt for which similar content needs to be found. + :param table: The table in which the content is stored. + :param embedding_type: The type of embedding to be used. (default: None, stored in table class) + :param limit: The maximum number of similar content to be retrieved (default: 50). + :param max_dist: The maximum cosine distance between embedding vectors (default: 100) + :param: where: query metadata parameters, see chromadb docs for details + :return: A list of results containing similar content. + """ + if embedding_type: + embedder = embedding_type.model + else: + embedder = table.embedding_type.model + + custom_embedder = type(embedder).__name__ != 'ChromaEmbedder' + collection = self.engine.get_or_create_collection(name=table.__name__) + + if custom_embedder: + prompt_embedding = list(embedder.create_embedding(prompt)) + query = collection.query(query_embeddings=[prompt_embedding], n_results=limit, where=where) + else: + # Use ChromaDB´s own embedder + query = collection.query(query_texts=[prompt], n_results=limit, where=where) + + results = [] + if len(query) > 0: + documents = query['documents'][0] + metadatas = query['metadatas'][0] + distances = query['distances'][0] + for i, id in enumerate(query['ids'][0]): + if distances[i] > max_dist: + break + entry = table() + entry.__dict__.update(metadatas[i]) + entry.__dict__.update(id=id, embedding_source=documents[i], cosine_dist = distances[i]) + results.append(entry) + + return results + + def retrieve_embedding(self, row_id: str, table: BaseClass) -> np.array: + collection = self.engine.get_collection(name=table.__name__) + all = collection.get(ids=[row_id], include=['embeddings']) + + return np.array(all['embeddings'][0]) + + def create_tables(self): + return NotImplementedError("create_tables is not implemented") + + def get_existing_row_ids(self, table: BaseClass): + collection = self.engine.get_collection(name=table.__name__) + all = collection.get(include=[]) + + return all['ids'] \ No newline at end of file diff --git a/src/brdata_rag_tools/embeddings.py b/src/brdata_rag_tools/embeddings.py index fd5c034..75e420d 100644 --- a/src/brdata_rag_tools/embeddings.py +++ b/src/brdata_rag_tools/embeddings.py @@ -12,6 +12,9 @@ "sentence_transformers": { "dimension": 1024 }, + "ChromaEmbedder": { + "dimension": 1024 # not sure... XXX + }, } user_models = {} @@ -49,6 +52,7 @@ class EmbeddingConfig(Enum): """ SENTENCE_TRANSFORMERS = "sentence_transformers" TF_IDF = "tfidf" + CHROMAEMBEDDER = "ChromaEmbedder" @property def dimension(self): @@ -69,6 +73,8 @@ def model(self): return SentenceTransformer() elif self == self.TF_IDF: raise NotImplementedError() + elif self == self.CHROMAEMBEDDER: + return ChromaEmbedder() else: try: return user_models[self.value]() @@ -213,4 +219,14 @@ def create_embedding_bulk(self, rows: List[Type[BaseClass]]) -> List[ return rows +class ChromaEmbedder(Embedder): + # Chroma's own embedder, no need for implementation + def __init__(self, endpoint: str = None, auth_token: Optional[str] = None): + self.endpoint = "" + self.auth_token = "" + + def create_embedding_bulk(self, rows: List[Type[BaseClass]]): + raise NotImplementedError() + def create_embedding(self, text: str) -> np.array: + raise NotImplementedError() diff --git a/test/test_databases.py b/test/test_databases.py index 860e9d7..e6489a2 100755 --- a/test/test_databases.py +++ b/test/test_databases.py @@ -1,4 +1,4 @@ -from src.brdata_rag_tools.databases import PGVector, FAISS +from src.brdata_rag_tools.databases import PGVector, FAISS, Chroma from src.brdata_rag_tools.embeddings import EmbeddingConfig, Embedder, register from sqlalchemy.orm import Mapped, mapped_column @@ -126,3 +126,38 @@ class Podcast(abstract_table): assert len(simcont) == 1 assert isinstance(simcont[0], dict) assert simcont[0]["cosine_dist"] == 0 + +def test_chroma(): + database = Chroma() + assert type(database) == Chroma + + abstract_table = database.create_abstract_embedding_table(EmbeddingConfig.CHROMAEMBEDDER) + assert len(set(abstract_table.__annotations__.keys()) & set(["id", "embedding_source", "embedding"])) == 3 + + class Podcast(abstract_table): + __tablename__ = "testchroma" + title: Mapped[str] = mapped_column(String) + url: Mapped[str] = mapped_column(String) + + podcasts = [] + + for i in range(3): + podcasts.append(Podcast(title="TRUE CRIME - Unter Verdacht", + id=str(i), # ChromaDb only accepts strings as ID + url="example.com", + embedding_source="test") + ) + + podcasts.append(Podcast(title="TRUE CRIME - Unter Verdacht", + id="4", + url="example.com", + embedding_source="Different Vector") + ) + + database.write_rows(podcasts, create_embeddings=True) + + simcont = database.retrieve_similar_content(prompt="Hallo Test.", table=Podcast, max_dist=0.5) + + assert len(simcont) == 3 + assert isinstance(simcont[0], Podcast) + assert simcont[0].cosine_dist < 0.5