[Feature] Addition of MongoDB Atlas datastore (#428)
* docker compose file.
* search example.
* mongodb atlas datastore.
* refactor, docstring and notebook cleaning.
* docstring.
* fix attribute names.
* Functional tests.
* Example adjustment.
* setup.md
* remove some useless comments.
* wrong docker image.
* Minor documentation fixes.
* Update example.
* refactor.
* default as a default collection.
* TODO resolved.
* Refactor delete.
* fix readme and setup.md
* add warning when delete without criteria.
* rename private function.
* replace pymongo with motor and fix integration test.
* Refactor code and adjust tests
* wait for assert function.
* Update docs/providers/mongodb_atlas/setup.md
  Co-authored-by: Jib <[email protected]>
* Update datastore/providers/mongodb_atlas_datastore.py
  Co-authored-by: Jib <[email protected]>
* Increase oversampling factor to 10.
* Update tests/datastore/providers/mongodb_atlas/test_mongodb_datastore.py
  Co-authored-by: Jib <[email protected]>
* Update tests/datastore/providers/mongodb_atlas/test_mongodb_datastore.py
  Co-authored-by: Jib <[email protected]>
* Update datastore/providers/mongodb_atlas_datastore.py
  Co-authored-by: Jib <[email protected]>
* Init docstring.
* Default parameters
* Update datastore/providers/mongodb_atlas_datastore.py
  Co-authored-by: Jib <[email protected]>
* refactor sample_embeddings.
* Apply suggestions from code review
  Co-authored-by: Jib <[email protected]>
* refactor delete.
* Version added.
* Update datastore/providers/mongodb_atlas_datastore.py
  Co-authored-by: Jib <[email protected]>
* Removed _atlas from folder name to keep it simple and self-consistent
* Expanded setup.md
* Fixed a couple typos in docstrings
* Add optional EMBEDDING_DIMENSION to get_embedding
* Fixed typo in kwarg
* Extended setup.md
* Edits to environment variable table
* Added authentication token descriptions
* Removed hardcoded vector size
* Added semantic search example
* Added instructions to integration tests
* Cleanup
* Removed pathname from example.
* Override DataStore.upsert in MongoDBAtlasDataStore to increase performance.
* upsert now returns ids of chunks, which is what each datastore document is
* Added full integration test
* test_integration now uses FastAPI TestClient
* Retries query until response contains number requested

---------

Co-authored-by: Emanuel Lupi <[email protected]>
Co-authored-by: Jib <[email protected]>
1 parent b808c10 · commit b28ddce · 11 changed files with 1,477 additions and 6 deletions.
datastore/providers/mongodb_atlas_datastore.py (new file, @@ -0,0 +1,252 @@)

```python
import os
from typing import Dict, List, Any, Optional
from loguru import logger
from importlib.metadata import version
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo.driver_info import DriverInfo
from pymongo import UpdateOne

from datastore.datastore import DataStore
from functools import cached_property
from models.models import (
    Document,
    DocumentChunk,
    DocumentChunkWithScore,
    DocumentMetadataFilter,
    QueryResult,
    QueryWithEmbedding,
)
from services.chunks import get_document_chunks
from services.date import to_unix_timestamp


MONGODB_CONNECTION_URI = os.environ.get("MONGODB_URI")
MONGODB_DATABASE = os.environ.get("MONGODB_DATABASE", "default")
MONGODB_COLLECTION = os.environ.get("MONGODB_COLLECTION", "default")
MONGODB_INDEX = os.environ.get("MONGODB_INDEX", "default")
OVERSAMPLING_FACTOR = 10
MAX_CANDIDATES = 10_000
```
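The module reads its connection settings from the environment at import time. Not part of the diff, but a minimal configuration sketch with placeholder values (set these in your shell or deployment environment before the module is imported):

```python
import os

# Placeholder values for illustration only.
os.environ["MONGODB_URI"] = "mongodb+srv://<user>:<password>@<cluster>.mongodb.net"
os.environ["MONGODB_DATABASE"] = "default"    # falls back to "default" if unset
os.environ["MONGODB_COLLECTION"] = "default"  # falls back to "default" if unset
os.environ["MONGODB_INDEX"] = "default"       # name of the Atlas vector search index
```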
```python
class MongoDBAtlasDataStore(DataStore):

    def __init__(
        self,
        atlas_connection_uri: str = MONGODB_CONNECTION_URI,
        index_name: str = MONGODB_INDEX,
        database_name: str = MONGODB_DATABASE,
        collection_name: str = MONGODB_COLLECTION,
        oversampling_factor: float = OVERSAMPLING_FACTOR,
    ):
        """
        Initialize a MongoDBAtlasDataStore instance.

        Parameters:
        - atlas_connection_uri (str, optional): Connection URI for the Atlas cluster.
          If not provided, the MONGODB_URI environment variable is used.
        - index_name (str, optional): Vector search index. If not provided, the default index name is used.
        - database_name (str, optional): Database. If not provided, the default database name is used.
        - collection_name (str, optional): Collection. If not provided, the default collection name is used.
        - oversampling_factor (float, optional): Multiple of top_k used as the number of
          vector search candidates. Default is OVERSAMPLING_FACTOR.

        Raises:
        - ValueError: If index_name is not a valid string.

        Attributes:
        - index_name (str): Name of the index.
        - database_name (str): Name of the database.
        - collection_name (str): Name of the collection.
        - oversampling_factor (float): Oversampling factor applied to vector search candidates.
        """
        self.atlas_connection_uri = atlas_connection_uri
        self.oversampling_factor = oversampling_factor
        self.database_name = database_name
        self.collection_name = collection_name

        if not (index_name and isinstance(index_name, str)):
            raise ValueError("Provide a valid index name")
        self.index_name = index_name

        # TODO: Create index via driver https://jira.mongodb.org/browse/PYTHON-4175
        # self._create_search_index(num_dimensions=1536, path="embedding", similarity="dotProduct", type="vector")

    @cached_property
    def client(self):
        # Use the URI passed to the constructor rather than the module-level
        # default, so a custom atlas_connection_uri is honored.
        return self._connect_to_mongodb_atlas(
            atlas_connection_uri=self.atlas_connection_uri
        )
```
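A minimal instantiation sketch (not part of the diff; the URI is a placeholder and every argument falls back to the environment-derived defaults above):

```python
# Hypothetical instantiation for illustration. The client is a cached
# property, so no connection is made until `datastore.client` is first used.
datastore = MongoDBAtlasDataStore(
    atlas_connection_uri="mongodb+srv://<user>:<password>@<cluster>.mongodb.net",
    index_name="default",
    database_name="default",
    collection_name="default",
)
```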
```python
    async def upsert(
        self, documents: List[Document], chunk_token_size: Optional[int] = None
    ) -> List[str]:
        """
        Takes in a list of Documents, chunks them, and upserts the chunks into the database.
        Returns a list of the ids of the document chunks.
        """
        chunks = get_document_chunks(documents, chunk_token_size)
        return await self._upsert(chunks)

    async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
        """
        Takes in a dict of lists of document chunks and inserts them into the database.
        Returns a list of document chunk ids.
        """
        documents_to_upsert = []
        inserted_ids = []
        for chunk_list in chunks.values():
            for chunk in chunk_list:
                inserted_ids.append(chunk.id)
                documents_to_upsert.append(
                    UpdateOne({'_id': chunk.id}, {"$set": chunk.dict()}, upsert=True)
                )
        logger.info(f"Upserting documents into MongoDB collection {self.database_name}.{self.collection_name}")
        await self.client[self.database_name][self.collection_name].bulk_write(documents_to_upsert)
        logger.info("Upsert successful")

        return inserted_ids
```
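A sketch of an upsert call, assuming the plugin's `Document` model fields (`id`, `text`, `metadata`). The returned ids are chunk ids, which become the `_id` values in the collection:

```python
from models.models import Document

async def index_example(datastore: MongoDBAtlasDataStore) -> None:
    # Each Document is split into DocumentChunks; one MongoDB document per chunk.
    chunk_ids = await datastore.upsert(
        [Document(id="doc-1", text="MongoDB Atlas supports vector search.")]
    )
    print(chunk_ids)  # e.g. ["doc-1_0", ...] -- chunk id format is illustrative
```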
```python
    async def _query(
        self,
        queries: List[QueryWithEmbedding],
    ) -> List[QueryResult]:
        """
        Takes in a list of queries with embeddings and filters and returns
        a list of query results with matching document chunks and scores.
        """
        results = []
        for query in queries:
            query_result = await self._execute_embedding_query(query)
            results.append(query_result)

        return results
```
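A sketch of a direct embedding query (not part of the diff). The 1536-dimension vector matches the dimensionality referenced in the TODO comment above; the zero vector is a placeholder for a real embedding:

```python
from models.models import QueryWithEmbedding

async def query_example(datastore: MongoDBAtlasDataStore) -> None:
    query = QueryWithEmbedding(
        query="What does Atlas support?",
        embedding=[0.0] * 1536,  # placeholder; use a real embedding in practice
        top_k=3,
    )
    (result,) = await datastore._query([query])
    for chunk in result.results:
        print(chunk.score, chunk.text)
```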
```python
    async def _execute_embedding_query(self, query: QueryWithEmbedding) -> QueryResult:
        """
        Execute a MongoDB query using vector search on the specified collection and
        return the result of the query, including matched documents and their scores.
        """
        pipeline = [
            {
                '$vectorSearch': {
                    'index': self.index_name,
                    'path': 'embedding',
                    'queryVector': query.embedding,
                    'numCandidates': min(query.top_k * self.oversampling_factor, MAX_CANDIDATES),
                    'limit': query.top_k
                }
            }, {
                '$project': {
                    'text': 1,
                    'metadata': 1,
                    'score': {
                        '$meta': 'vectorSearchScore'
                    }
                }
            }
        ]

        async with self.client[self.database_name][self.collection_name].aggregate(pipeline) as cursor:
            results = [
                self._convert_mongodb_document_to_document_chunk_with_score(doc)
                async for doc in cursor
            ]

        return QueryResult(
            query=query.query,
            results=results,
        )
```
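The candidate pool scales with the requested result count: `top_k=3` with the default `oversampling_factor=10` asks Atlas to consider 30 candidates, capped at `MAX_CANDIDATES`. A worked example of that arithmetic:

```python
# Worked example of the numCandidates computation used in the pipeline above.
top_k, oversampling_factor = 3, 10
assert min(top_k * oversampling_factor, 10_000) == 30
assert min(2_000 * oversampling_factor, 10_000) == 10_000  # the cap applies
```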
```python
    async def delete(
        self,
        ids: Optional[List[str]] = None,
        filter: Optional[DocumentMetadataFilter] = None,
        delete_all: Optional[bool] = None,
    ) -> bool:
        """
        Removes documents by ids, filter, or everything in the datastore.
        Returns whether the operation was successful.
        Note that ids refer to those in the datastore,
        which are the ids of the **DocumentChunks**.
        """
        # Delete all documents from the collection if delete_all is True
        if delete_all:
            logger.info("Deleting all documents from collection")
            mg_filter = {}

        # Delete by ids
        elif ids:
            logger.info(f"Deleting documents with ids: {ids}")
            mg_filter = {"_id": {"$in": ids}}

        # Delete by filters
        elif filter:
            mg_filter = self._build_mongo_filter(filter)
            logger.info(f"Deleting documents with filter: {mg_filter}")
        # Do nothing
        else:
            logger.warning(
                f"No criteria set; nothing to delete. args: ids: {ids}, filter: {filter}, delete_all: {delete_all}"
            )
            return True

        try:
            await self.client[self.database_name][self.collection_name].delete_many(mg_filter)
            logger.info("Deleted documents successfully")
        except Exception as e:
            logger.error(f"Error deleting documents with filter: {mg_filter} -- error: {e}")
            return False

        return True
```
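A sketch of the three deletion modes (not part of the diff; note that `ids` means chunk ids, per the docstring):

```python
async def delete_examples(datastore: MongoDBAtlasDataStore) -> None:
    # Delete specific chunks by their datastore ids (DocumentChunk ids).
    await datastore.delete(ids=["doc-1_0"])  # illustrative chunk id

    # Delete everything in the collection.
    await datastore.delete(delete_all=True)

    # Calling with no criteria logs a warning and deletes nothing.
    await datastore.delete()
```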
```python
    def _convert_mongodb_document_to_document_chunk_with_score(
        self, document: Dict
    ) -> DocumentChunkWithScore:
        # Convert MongoDB document to DocumentChunkWithScore
        return DocumentChunkWithScore(
            id=document.get("_id"),
            text=document["text"],
            metadata=document.get("metadata"),
            score=document.get("score"),
        )
```
```python
    def _build_mongo_filter(
        self, filter: Optional[DocumentMetadataFilter] = None
    ) -> Dict[str, Any]:
        """
        Generate MongoDB query filters based on the provided DocumentMetadataFilter.
        """
        if filter is None:
            return {}

        mongo_filters = {
            "$and": [],
        }

        # For each field in the MetadataFilter,
        # check if it has a value and add the corresponding MongoDB filter expression
        for field, value in filter.dict().items():
            if value is not None:
                if field == "start_date":
                    mongo_filters["$and"].append(
                        {"created_at": {"$gte": to_unix_timestamp(value)}}
                    )
                elif field == "end_date":
                    mongo_filters["$and"].append(
                        {"created_at": {"$lte": to_unix_timestamp(value)}}
                    )
                else:
                    mongo_filters["$and"].append(
                        {f"metadata.{field}": value}
                    )

        return mongo_filters
```
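For illustration, a source filter plus a date range produces an `$and` of three clauses. The field names below assume the plugin's `DocumentMetadataFilter` model, and the exact output shape is a sketch:

```python
from models.models import DocumentMetadataFilter

# Hypothetical input for illustration.
f = DocumentMetadataFilter(
    source="file",
    start_date="2024-01-01T00:00:00Z",
    end_date="2024-02-01T00:00:00Z",
)
# _build_mongo_filter(f) yields something like:
# {"$and": [
#     {"metadata.source": "file"},
#     {"created_at": {"$gte": 1704067200}},
#     {"created_at": {"$lte": 1706745600}},
# ]}
```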
```python
    @staticmethod
    def _connect_to_mongodb_atlas(atlas_connection_uri: str):
        """
        Establish a connection to MongoDB Atlas.
        """
        client = AsyncIOMotorClient(
            atlas_connection_uri,
            driver=DriverInfo(name="Chatgpt Retrieval Plugin", version=version("chatgpt_retrieval_plugin")))
        return client
```