[Feature] Addition of MongoDB Atlas datastore (#428)
* docker compose file.

* search example.

* mongodb atlas datastore.

* refactor, docstring and notebook cleaning.

* docstring.

* fix attributes names.

* Functional tests.

* Example adjustment.

* setup.md

* remove some useless comments.

* wrong docker image.

* Minor documentation fixes.

* Update example.

* refactor.

* Use "default" as the default collection name.

* TODO resolved.

* Refactor delete.

* fix readme and setup.md

* add warning when delete without criteria.

* rename private function.

* replace pymongo with motor and fix integration test.

* Refactor code and adjust tests

* wait for assert function.

* Update docs/providers/mongodb_atlas/setup.md

Co-authored-by: Jib <[email protected]>

* Update datastore/providers/mongodb_atlas_datastore.py

Co-authored-by: Jib <[email protected]>

* Increase oversampling factor to 10.

* Update tests/datastore/providers/mongodb_atlas/test_mongodb_datastore.py

Co-authored-by: Jib <[email protected]>

* Update tests/datastore/providers/mongodb_atlas/test_mongodb_datastore.py

Co-authored-by: Jib <[email protected]>

* Update datastore/providers/mongodb_atlas_datastore.py

Co-authored-by: Jib <[email protected]>

* Init docstring.

* Default parameters

* Update datastore/providers/mongodb_atlas_datastore.py

Co-authored-by: Jib <[email protected]>

* refactor sample_embeddings.

* Apply suggestions from code review

Co-authored-by: Jib <[email protected]>

* refactor delete.

* Version added.

* Update datastore/providers/mongodb_atlas_datastore.py

Co-authored-by: Jib <[email protected]>

* Removed _atlas from folder name to keep it simple and self-consistent

* Expanded setup.md

* Fixed a couple typos in docstrings

* Add optional EMBEDDING_DIMENSION to get_embedding

* Fixed typo in kwarg

* Extended setup.md

* Edits to environment variable table

* Added authentication token descriptions

* Removed hardcoded vector size

* Added semantic search example

* Added instructions to integration tests

* Cleanup

* Removed pathname from example.

* Override DataStore.upsert in MongoDBAtlasDataStore to increase performance.

* upsert now returns the ids of the chunks, since each datastore document is a chunk

* Added full integration test

* test_integration now uses FastAPI TestClient

* Retries query until response contains number requested

---------

Co-authored-by: Emanuel Lupi <[email protected]>
Co-authored-by: Jib <[email protected]>
3 people authored Apr 24, 2024
1 parent b808c10 commit b28ddce
Showing 11 changed files with 1,477 additions and 6 deletions.
15 changes: 13 additions & 2 deletions README.md
@@ -42,6 +42,7 @@ This README provides detailed information on how to set up, develop, and deploy
- [Choosing a Vector Database](#choosing-a-vector-database)
- [Pinecone](#pinecone)
- [Elasticsearch](#elasticsearch)
- [MongoDB Atlas](#mongodb-atlas)
- [Weaviate](#weaviate)
- [Zilliz](#zilliz)
- [Milvus](#milvus)
@@ -190,6 +191,12 @@ Follow these steps to quickly set up and run the ChatGPT Retrieval Plugin:
export ELASTICSEARCH_INDEX=<elasticsearch_index_name>
export ELASTICSEARCH_REPLICAS=<elasticsearch_replicas>
export ELASTICSEARCH_SHARDS=<elasticsearch_shards>
# MongoDB Atlas
export MONGODB_URI=<mongodb_uri>
export MONGODB_DATABASE=<mongodb_database>
export MONGODB_COLLECTION=<mongodb_collection>
export MONGODB_INDEX=<mongodb_index>
```

10. Run the API locally: `poetry run start`
@@ -352,8 +359,8 @@ poetry install
The API requires the following environment variables to work:

| Name | Required | Description |
-| ---------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `DATASTORE` | Yes | This specifies the vector database provider you want to use to store and query embeddings. You can choose from `elasticsearch`, `chroma`, `pinecone`, `weaviate`, `zilliz`, `milvus`, `qdrant`, `redis`, `azuresearch`, `supabase`, `postgres`, `analyticdb`. |
+| ---------------- | -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `DATASTORE` | Yes | This specifies the vector database provider you want to use to store and query embeddings. You can choose from `elasticsearch`, `chroma`, `pinecone`, `weaviate`, `zilliz`, `milvus`, `qdrant`, `redis`, `azuresearch`, `supabase`, `postgres`, `analyticdb`, `mongodb-atlas`. |
| `BEARER_TOKEN` | Yes | This is a secret token that you need to authenticate your requests to the API. You can generate one using any tool or method you prefer, such as [jwt.io](https://jwt.io/). |
| `OPENAI_API_KEY` | Yes | This is your OpenAI API key that you need to generate embeddings using one of the OpenAI embedding models. You can get an API key by creating an account on [OpenAI](https://openai.com/). |

@@ -434,6 +441,10 @@ For detailed setup instructions, refer to [`/docs/providers/llama/setup.md`](/docs/providers/llama/setup.md)

[Elasticsearch](https://www.elastic.co/guide/en/elasticsearch/reference/current/index.html) currently supports storing vectors through the `dense_vector` field type and uses them to calculate document scores. Elasticsearch 8.0 builds on this functionality to support fast, approximate nearest neighbor search (ANN). This represents a much more scalable approach, allowing vector search to run efficiently on large datasets. For detailed setup instructions, refer to [`/docs/providers/elasticsearch/setup.md`](/docs/providers/elasticsearch/setup.md).

#### MongoDB Atlas

[MongoDB Atlas](https://www.mongodb.com/docs/atlas/getting-started/) supports vector search through Atlas Vector Search indexes, which you can create for any collection containing vector embeddings of 2048 or fewer dimensions. The embeddings can live alongside other data types on your Atlas cluster, and the indexes are created through the Atlas UI or the Atlas Administration API. For detailed setup instructions, refer to [`/docs/providers/mongodb_atlas/setup.md`](/docs/providers/mongodb_atlas/setup.md).

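As a reference, here is a minimal sketch of an Atlas Vector Search index definition matching what this datastore expects (1536-dimension embeddings stored under the `embedding` path with `dotProduct` similarity, per the provider code); it can be entered in the Atlas UI or sent through the Atlas Administration API:

```json
{
  "fields": [
    {
      "type": "vector",
      "path": "embedding",
      "numDimensions": 1536,
      "similarity": "dotProduct"
    }
  ]
}
```

The index name must match the `MONGODB_INDEX` environment variable (`default` if unset).
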
### Running the API locally

To run the API locally, you first need to set the requisite environment variables with the `export` command:
4 changes: 2 additions & 2 deletions datastore/datastore.py
@@ -44,7 +44,7 @@ async def upsert(
    @abstractmethod
    async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
        """
-        Takes in a list of list of document chunks and inserts them into the database.
+        Takes in a list of document chunks and inserts them into the database.
        Return a list of document ids.
        """

Expand All @@ -54,7 +54,7 @@ async def query(self, queries: List[Query]) -> List[QueryResult]:
"""
Takes in a list of queries and filters and returns a list of query results with matching document chunks and scores.
"""
# get a list of of just the queries from the Query list
# get a list of just the queries from the Query list
query_texts = [query.query for query in queries]
query_embeddings = get_embeddings(query_texts)
# hydrate the queries with embeddings
6 changes: 6 additions & 0 deletions datastore/factory.py
@@ -68,6 +68,12 @@ async def get_datastore() -> DataStore:
            )

            return ElasticsearchDataStore()
        case "mongodb" | "mongodb-atlas":
            # Accept both the short name and the `mongodb-atlas` value listed in the README.
            from datastore.providers.mongodb_atlas_datastore import (
                MongoDBAtlasDataStore,
            )

            return MongoDBAtlasDataStore()
        case _:
            raise ValueError(
                f"Unsupported vector database: {datastore}. "
252 changes: 252 additions & 0 deletions datastore/providers/mongodb_atlas_datastore.py
@@ -0,0 +1,252 @@
import os
from functools import cached_property
from importlib.metadata import version
from typing import Any, Dict, List, Optional

from loguru import logger
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import UpdateOne
from pymongo.driver_info import DriverInfo

from datastore.datastore import DataStore
from models.models import (
    Document,
    DocumentChunk,
    DocumentChunkWithScore,
    DocumentMetadataFilter,
    QueryResult,
    QueryWithEmbedding,
)
from services.chunks import get_document_chunks
from services.date import to_unix_timestamp


MONGODB_CONNECTION_URI = os.environ.get("MONGODB_URI")
MONGODB_DATABASE = os.environ.get("MONGODB_DATABASE", "default")
MONGODB_COLLECTION = os.environ.get("MONGODB_COLLECTION", "default")
MONGODB_INDEX = os.environ.get("MONGODB_INDEX", "default")
OVERSAMPLING_FACTOR = 10
MAX_CANDIDATES = 10_000
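# numCandidates for $vectorSearch is top_k * OVERSAMPLING_FACTOR,
# capped at MAX_CANDIDATES (see _execute_embedding_query below).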


class MongoDBAtlasDataStore(DataStore):

    def __init__(
        self,
        atlas_connection_uri: Optional[str] = MONGODB_CONNECTION_URI,
        index_name: str = MONGODB_INDEX,
        database_name: str = MONGODB_DATABASE,
        collection_name: str = MONGODB_COLLECTION,
        oversampling_factor: float = OVERSAMPLING_FACTOR,
    ):
        """
        Initialize a MongoDBAtlasDataStore instance.

        Parameters:
        - atlas_connection_uri (str, optional): Connection URI for MongoDB Atlas.
          Defaults to the MONGODB_URI environment variable.
        - index_name (str, optional): Vector search index. If not provided, the default index name is used.
        - database_name (str, optional): Database. If not provided, the default database name is used.
        - collection_name (str, optional): Collection. If not provided, the default collection name is used.
        - oversampling_factor (float, optional): Multiplier on top_k that sets how many
          candidates vector search considers. Default is OVERSAMPLING_FACTOR.

        Raises:
        - ValueError: If index_name is not a valid string.

        Attributes:
        - index_name (str): Name of the index.
        - database_name (str): Name of the database.
        - collection_name (str): Name of the collection.
        - oversampling_factor (float): Oversampling factor applied to top_k during vector search.
        """

        self.atlas_connection_uri = atlas_connection_uri
        self.oversampling_factor = oversampling_factor
        self.database_name = database_name
        self.collection_name = collection_name

        if not (index_name and isinstance(index_name, str)):
            raise ValueError("Provide a valid index name")
        self.index_name = index_name

        # TODO: Create index via driver https://jira.mongodb.org/browse/PYTHON-4175
        # self._create_search_index(num_dimensions=1536, path="embedding", similarity="dotProduct", type="vector")

    @cached_property
    def client(self):
        # Connect lazily on first access, using the URI passed to the constructor.
        return self._connect_to_mongodb_atlas(
            atlas_connection_uri=self.atlas_connection_uri
        )

    async def upsert(
        self, documents: List[Document], chunk_token_size: Optional[int] = None
    ) -> List[str]:
        """
        Takes in a list of Documents, chunks them, and upserts the chunks into the database.
        Returns a list of the ids of the document chunks.
        """
        chunks = get_document_chunks(documents, chunk_token_size)
        return await self._upsert(chunks)

    async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
        """
        Takes in a dict of lists of document chunks and inserts them into the database.
        Returns a list of document chunk ids.
        """
        documents_to_upsert = []
        inserted_ids = []
        for chunk_list in chunks.values():
            for chunk in chunk_list:
                inserted_ids.append(chunk.id)
                # Upsert on the chunk id: update the existing chunk document or insert a new one.
                documents_to_upsert.append(
                    UpdateOne({'_id': chunk.id}, {"$set": chunk.dict()}, upsert=True)
                )
        logger.info(f"Upserting documents into MongoDB collection {self.database_name}.{self.collection_name}")
        # bulk_write sends all upserts to the server in a single batched request.
        await self.client[self.database_name][self.collection_name].bulk_write(documents_to_upsert)
        logger.info("Upsert successful")

        return inserted_ids

    async def _query(
        self,
        queries: List[QueryWithEmbedding],
    ) -> List[QueryResult]:
        """
        Takes in a list of queries with embeddings and filters and returns
        a list of query results with matching document chunks and scores.
        """
        results = []
        for query in queries:
            query_result = await self._execute_embedding_query(query)
            results.append(query_result)

        return results

    async def _execute_embedding_query(self, query: QueryWithEmbedding) -> QueryResult:
        """
        Execute a MongoDB query using vector search on the specified collection and
        return the result of the query, including matched documents and their scores.
        """
        pipeline = [
            {
                '$vectorSearch': {
                    'index': self.index_name,
                    'path': 'embedding',
                    'queryVector': query.embedding,
                    # Oversample the candidate pool to improve recall, capped at MAX_CANDIDATES.
                    'numCandidates': min(query.top_k * self.oversampling_factor, MAX_CANDIDATES),
                    'limit': query.top_k
                }
            }, {
                '$project': {
                    'text': 1,
                    'metadata': 1,
                    'score': {
                        '$meta': 'vectorSearchScore'
                    }
                }
            }
        ]

        async with self.client[self.database_name][self.collection_name].aggregate(pipeline) as cursor:
            results = [
                self._convert_mongodb_document_to_document_chunk_with_score(doc)
                async for doc in cursor
            ]

        return QueryResult(
            query=query.query,
            results=results,
        )

    async def delete(
        self,
        ids: Optional[List[str]] = None,
        filter: Optional[DocumentMetadataFilter] = None,
        delete_all: Optional[bool] = None,
    ) -> bool:
        """
        Removes documents by ids, filter, or everything in the datastore.
        Returns whether the operation was successful.
        Note that ids refer to those in the datastore,
        which are those of the **DocumentChunks**.
        """
        # Delete all documents from the collection if delete_all is True
        if delete_all:
            logger.info("Deleting all documents from collection")
            mg_filter = {}

        # Delete by ids
        elif ids:
            logger.info(f"Deleting documents with ids: {ids}")
            mg_filter = {"_id": {"$in": ids}}

        # Delete by filters
        elif filter:
            mg_filter = self._build_mongo_filter(filter)
            logger.info(f"Deleting documents with filter: {mg_filter}")
        # Do nothing
        else:
            logger.warning(
                f"No deletion criteria provided (ids={ids}, filter={filter}, "
                f"delete_all={delete_all}); nothing to delete"
            )
            return True

        try:
            await self.client[self.database_name][self.collection_name].delete_many(mg_filter)
            logger.info("Deleted documents successfully")
        except Exception as e:
            logger.error(f"Error deleting documents with filter: {mg_filter} -- error: {e}")
            return False

        return True

    def _convert_mongodb_document_to_document_chunk_with_score(
        self, document: Dict
    ) -> DocumentChunkWithScore:
        # Convert MongoDB document to DocumentChunkWithScore
        return DocumentChunkWithScore(
            id=document.get("_id"),
            text=document["text"],
            metadata=document.get("metadata"),
            score=document.get("score"),
        )

    def _build_mongo_filter(
        self, filter: Optional[DocumentMetadataFilter] = None
    ) -> Dict[str, Any]:
        """
        Generate MongoDB query filters based on the provided DocumentMetadataFilter.
        start_date and end_date are compared against the chunk's created_at timestamp;
        all other fields are matched under the metadata subdocument.
        """
        if filter is None:
            return {}

        mongo_filters = {
            "$and": [],
        }

        # For each field in the MetadataFilter,
        # check if it has a value and add the corresponding MongoDB filter expression
        for field, value in filter.dict().items():
            if value is not None:
                if field == "start_date":
                    mongo_filters["$and"].append(
                        {"created_at": {"$gte": to_unix_timestamp(value)}}
                    )
                elif field == "end_date":
                    mongo_filters["$and"].append(
                        {"created_at": {"$lte": to_unix_timestamp(value)}}
                    )
                else:
                    mongo_filters["$and"].append(
                        {f"metadata.{field}": value}
                    )

        return mongo_filters

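    # Example (illustrative values): DocumentMetadataFilter(source="email",
    # start_date="2024-01-01T00:00:00Z") yields the clauses
    # {"metadata.source": "email"} and {"created_at": {"$gte": 1704067200}}
    # under a single "$and".
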
    @staticmethod
    def _connect_to_mongodb_atlas(atlas_connection_uri: str):
        """
        Establish a connection to MongoDB Atlas.
        """

        client = AsyncIOMotorClient(
            atlas_connection_uri,
            driver=DriverInfo(name="Chatgpt Retrieval Plugin", version=version("chatgpt_retrieval_plugin")),
        )
        return client
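
For orientation, here is a minimal sketch of exercising the new datastore end to end. It assumes `MONGODB_URI` and `OPENAI_API_KEY` are set and that the vector search index described above exists; the document id and text are illustrative:

```python
import asyncio

from datastore.providers.mongodb_atlas_datastore import MongoDBAtlasDataStore
from models.models import Document, DocumentMetadata, Query


async def main():
    datastore = MongoDBAtlasDataStore()

    # Chunk and upsert a document; returns the ids of the stored chunks.
    chunk_ids = await datastore.upsert([
        Document(
            id="doc1",  # illustrative id
            text="MongoDB Atlas Vector Search indexes embeddings for ANN queries.",
            metadata=DocumentMetadata(source="file"),
        )
    ])
    print(chunk_ids)

    # query() embeds the text, then _query() runs $vectorSearch under the hood.
    results = await datastore.query([Query(query="vector search", top_k=3)])
    for chunk in results[0].results:
        print(chunk.score, chunk.text)


asyncio.run(main())
```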