From b28ddce58474441da332d4e15c6dd60ddaa953ab Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Wed, 24 Apr 2024 13:48:16 -0400 Subject: [PATCH] [Feature] Addition of MongoDB Atlas datastore (#428) * docker compose file. * search example. * mongodb atlas datastore. * refactor, docstring and notebook cleaning. * docstring. * fix attributes names. * Functional tests. * Example adjustement. * setup.md * remove some useless comments. * wrong docker image. * Minor documentation fixes. * Update example. * refactor. * default as a default collection. * TODO resolved. * Refactor delete. * fix readme and setup.md * add warning when delete without criteria. * rename private function. * replace pymongo to motor and fix integration test. * Refactor code and adjust tests * wait for assert function. * Update docs/providers/mongodb_atlas/setup.md Co-authored-by: Jib * Update datastore/providers/mongodb_atlas_datastore.py Co-authored-by: Jib * Increase oversampling factor to 10. * Update tests/datastore/providers/mongodb_atlas/test_mongodb_datastore.py Co-authored-by: Jib * Update tests/datastore/providers/mongodb_atlas/test_mongodb_datastore.py Co-authored-by: Jib * Update datastore/providers/mongodb_atlas_datastore.py Co-authored-by: Jib * Init docstring. * Default parameters * Update datastore/providers/mongodb_atlas_datastore.py Co-authored-by: Jib * refactor sample_embeddings. * Apply suggestions from code review Co-authored-by: Jib * refactor delete. * Version added. * Update datastore/providers/mongodb_atlas_datastore.py Co-authored-by: Jib * Removed _atlas from folder name to keep it simple and self-consistent * Expanded setup.md * Fixed a couple typos in docstrings * Add optional EMBEDDING_DIMENSION to get_embedding * Fixed typo in kwarg * Extended setup.md * Edits to environment variable table * Added authentication token descriptions * Removed hardcoded vector size * Added semantic search example * Added instructions to integration tests * Cleanup * Removed pathname from example. * Override DataStore.upsert in MongoDBAtlasDataStore to increase performance. 
* upsert now returns ids of chunks, which is what each datastore document is * Added full integration test * test_integration now uses FastAPI TestClient * Retries query until response contains number requested --------- Co-authored-by: Emanuel Lupi Co-authored-by: Jib --- README.md | 15 +- datastore/datastore.py | 4 +- datastore/factory.py | 6 + .../providers/mongodb_atlas_datastore.py | 252 +++++++ docs/providers/mongodb/setup.md | 139 ++++ .../providers/mongodb/semantic-search.ipynb | 654 ++++++++++++++++++ poetry.lock | 27 +- pyproject.toml | 1 + services/openai.py | 3 +- .../mongodb_atlas/test_integration.py | 128 ++++ .../mongodb_atlas/test_mongodb_datastore.py | 254 +++++++ 11 files changed, 1477 insertions(+), 6 deletions(-) create mode 100644 datastore/providers/mongodb_atlas_datastore.py create mode 100644 docs/providers/mongodb/setup.md create mode 100644 examples/providers/mongodb/semantic-search.ipynb create mode 100644 tests/datastore/providers/mongodb_atlas/test_integration.py create mode 100644 tests/datastore/providers/mongodb_atlas/test_mongodb_datastore.py diff --git a/README.md b/README.md index edad38158..3d1ec311e 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ This README provides detailed information on how to set up, develop, and deploy - [Choosing a Vector Database](#choosing-a-vector-database) - [Pinecone](#pinecone) - [Elasticsearch](#elasticsearch) + - [MongoDB Atlas](#mongodb-atlas) - [Weaviate](#weaviate) - [Zilliz](#zilliz) - [Milvus](#milvus) @@ -190,6 +191,12 @@ Follow these steps to quickly set up and run the ChatGPT Retrieval Plugin: export ELASTICSEARCH_INDEX= export ELASTICSEARCH_REPLICAS= export ELASTICSEARCH_SHARDS= + + # MongoDB Atlas + export MONGODB_URI= + export MONGODB_DATABASE= + export MONGODB_COLLECTION= + export MONGODB_INDEX= ``` 10. Run the API locally: `poetry run start` @@ -352,8 +359,8 @@ poetry install The API requires the following environment variables to work: | Name | Required | Description | -| ---------------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `DATASTORE` | Yes | This specifies the vector database provider you want to use to store and query embeddings. You can choose from `elasticsearch`, `chroma`, `pinecone`, `weaviate`, `zilliz`, `milvus`, `qdrant`, `redis`, `azuresearch`, `supabase`, `postgres`, `analyticdb`. | +| ---------------- | -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `DATASTORE` | Yes | This specifies the vector database provider you want to use to store and query embeddings. You can choose from `elasticsearch`, `chroma`, `pinecone`, `weaviate`, `zilliz`, `milvus`, `qdrant`, `redis`, `azuresearch`, `supabase`, `postgres`, `analyticdb`, `mongodb-atlas`. | | `BEARER_TOKEN` | Yes | This is a secret token that you need to authenticate your requests to the API. You can generate one using any tool or method you prefer, such as [jwt.io](https://jwt.io/). | | `OPENAI_API_KEY` | Yes | This is your OpenAI API key that you need to generate embeddings using the one of the OpenAI embeddings model. 
You can get an API key by creating an account on [OpenAI](https://openai.com/). |
@@ -434,6 +441,10 @@ For detailed setup instructions, refer to [`/docs/providers/llama/setup.md`](/do
[Elasticsearch](https://www.elastic.co/guide/en/elasticsearch/reference/current/index.html) currently supports storing vectors through the `dense_vector` field type and uses them to calculate document scores. Elasticsearch 8.0 builds on this functionality to support fast, approximate nearest neighbor search (ANN). This represents a much more scalable approach, allowing vector search to run efficiently on large datasets. For detailed setup instructions, refer to [`/docs/providers/elasticsearch/setup.md`](/docs/providers/elasticsearch/setup.md).
+#### MongoDB Atlas
+
+[MongoDB Atlas](https://www.mongodb.com/docs/atlas/getting-started/) supports semantic search through Atlas Vector Search. You create an Atlas Vector Search index, via the Atlas UI or the Atlas Administration API, on any collection whose vector embeddings are 2048 dimensions or fewer; the embeddings can live alongside the rest of the data on your Atlas cluster. For detailed setup instructions, refer to [`/docs/providers/mongodb/setup.md`](/docs/providers/mongodb/setup.md).
+
### Running the API locally To run the API locally, you first need to set the requisite environment variables with the `export` command:
diff --git a/datastore/datastore.py b/datastore/datastore.py index ff0c79dd8..6941b4cf6 100644 --- a/datastore/datastore.py +++ b/datastore/datastore.py
@@ -44,7 +44,7 @@ async def upsert( @abstractmethod async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: """ - Takes in a list of list of document chunks and inserts them into the database. + Takes in a list of document chunks and inserts them into the database. Return a list of document ids. """
@@ -54,7 +54,7 @@ async def query(self, queries: List[Query]) -> List[QueryResult]: """ Takes in a list of queries and filters and returns a list of query results with matching document chunks and scores. """ - # get a list of of just the queries from the Query list + # get a list of just the queries from the Query list query_texts = [query.query for query in queries] query_embeddings = get_embeddings(query_texts) # hydrate the queries with embeddings
diff --git a/datastore/factory.py b/datastore/factory.py index 67f17eba5..577a6f1b1 100644 --- a/datastore/factory.py +++ b/datastore/factory.py
@@ -68,6 +68,12 @@ async def get_datastore() -> DataStore: ) return ElasticsearchDataStore() + case "mongodb": + from datastore.providers.mongodb_atlas_datastore import ( + MongoDBAtlasDataStore, + ) + + return MongoDBAtlasDataStore() case _: raise ValueError( f"Unsupported vector database: {datastore}.
" diff --git a/datastore/providers/mongodb_atlas_datastore.py b/datastore/providers/mongodb_atlas_datastore.py new file mode 100644 index 000000000..631fd192c --- /dev/null +++ b/datastore/providers/mongodb_atlas_datastore.py @@ -0,0 +1,252 @@ +import os +from typing import Dict, List, Any, Optional +from loguru import logger +from importlib.metadata import version +from motor.motor_asyncio import AsyncIOMotorClient +from pymongo.driver_info import DriverInfo +from pymongo import UpdateOne + +from datastore.datastore import DataStore +from functools import cached_property +from models.models import ( + Document, + DocumentChunk, + DocumentChunkWithScore, + DocumentMetadataFilter, + QueryResult, + QueryWithEmbedding, +) +from services.chunks import get_document_chunks +from services.date import to_unix_timestamp + + +MONGODB_CONNECTION_URI = os.environ.get("MONGODB_URI") +MONGODB_DATABASE = os.environ.get("MONGODB_DATABASE", "default") +MONGODB_COLLECTION = os.environ.get("MONGODB_COLLECTION", "default") +MONGODB_INDEX = os.environ.get("MONGODB_INDEX", "default") +OVERSAMPLING_FACTOR = 10 +MAX_CANDIDATES = 10_000 + + +class MongoDBAtlasDataStore(DataStore): + + def __init__( + self, + atlas_connection_uri: str = MONGODB_CONNECTION_URI, + index_name: str = MONGODB_INDEX, + database_name: str = MONGODB_DATABASE, + collection_name: str = MONGODB_COLLECTION, + oversampling_factor: float = OVERSAMPLING_FACTOR, + ): + """ + Initialize a MongoDBAtlasDataStore instance. + + Parameters: + - index_name (str, optional): Vector search index. If not provided, default index name is used. + - database_name (str, optional): Database. If not provided, default database name is used. + - collection_name (str, optional): Collection. If not provided, default collection name is used. + - oversampling_factor (float, optional): Oversampling factor for data augmentation. + Default is OVERSAMPLING_FACTOR. + + Raises: + - ValueError: If index_name is not a valid string. + + Attributes: + - index_name (str): Name of the index. + - database_name (str): Name of the database. + - collection_name (str): Name of the collection. + - oversampling_factor (float): Oversampling factor for data augmentation. + """ + + self.atlas_connection_uri = atlas_connection_uri + self.oversampling_factor = oversampling_factor + self.database_name = database_name + self.collection_name = collection_name + + if not (index_name and isinstance(index_name, str)): + raise ValueError("Provide a valid index name") + self.index_name = index_name + + # TODO: Create index via driver https://jira.mongodb.org/browse/PYTHON-4175 + # self._create_search_index(num_dimensions=1536, path="embedding", similarity="dotProduct", type="vector") + + @cached_property + def client(self): + return self._connect_to_mongodb_atlas( + atlas_connection_uri=MONGODB_CONNECTION_URI + ) + + async def upsert( + self, documents: List[Document], chunk_token_size: Optional[int] = None + ) -> List[str]: + """ + Takes in a list of Documents, chunks them, and upserts the chunks into the database. + Return a list the ids of the document chunks. + """ + chunks = get_document_chunks(documents, chunk_token_size) + return await self._upsert(chunks) + + async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: + """ + Takes in a list of document chunks and inserts them into the database. + Return a list of document ids. 
+ """ + documents_to_upsert = [] + inserted_ids = [] + for chunk_list in chunks.values(): + for chunk in chunk_list: + inserted_ids.append(chunk.id) + documents_to_upsert.append( + UpdateOne({'_id': chunk.id}, {"$set": chunk.dict()}, upsert=True) + ) + logger.info(f"Upsert documents into MongoDB collection: {self.database_name}: {self.collection_name}") + await self.client[self.database_name][self.collection_name].bulk_write(documents_to_upsert) + logger.info("Upsert successful") + + return inserted_ids + + async def _query( + self, + queries: List[QueryWithEmbedding], + ) -> List[QueryResult]: + """ + Takes in a list of queries with embeddings and filters and returns + a list of query results with matching document chunks and scores. + """ + results = [] + for query in queries: + query_result = await self._execute_embedding_query(query) + results.append(query_result) + + return results + + async def _execute_embedding_query(self, query: QueryWithEmbedding) -> QueryResult: + """ + Execute a MongoDB query using vector search on the specified collection and + return the result of the query, including matched documents and their scores. + """ + pipeline = [ + { + '$vectorSearch': { + 'index': self.index_name, + 'path': 'embedding', + 'queryVector': query.embedding, + 'numCandidates': min(query.top_k * self.oversampling_factor, MAX_CANDIDATES), + 'limit': query.top_k + } + }, { + '$project': { + 'text': 1, + 'metadata': 1, + 'score': { + '$meta': 'vectorSearchScore' + } + } + } + ] + + async with self.client[self.database_name][self.collection_name].aggregate(pipeline) as cursor: + results = [ + self._convert_mongodb_document_to_document_chunk_with_score(doc) + async for doc in cursor + ] + + return QueryResult( + query=query.query, + results=results, + ) + + async def delete( + self, + ids: Optional[List[str]] = None, + filter: Optional[DocumentMetadataFilter] = None, + delete_all: Optional[bool] = None, + ) -> bool: + """ + Removes documents by ids, filter, or everything in the datastore. + Returns whether the operation was successful. 
+ + Note that ids refer to those in the datastore, + which are those of the **DocumentChunks** + """ + # Delete all documents from the collection if delete_all is True + if delete_all: + logger.info("Deleting all documents from collection") + mg_filter = {} + + # Delete by ids + elif ids: + logger.info(f"Deleting documents with ids: {ids}") + mg_filter = {"_id": {"$in": ids}} + + # Delete by filters + elif filter: + mg_filter = self._build_mongo_filter(filter) + logger.info(f"Deleting documents with filter: {mg_filter}") + # Do nothing + else: + logger.warning("No criteria set; nothing to delete args: ids: %s, filter: %s delete_all: %s", ids, filter, delete_all) + return True + + try: + await self.client[self.database_name][self.collection_name].delete_many(mg_filter) + logger.info("Deleted documents successfully") + except Exception as e: + logger.error("Error deleting documents with filter: %s -- error: %s", mg_filter, e) + return False + + return True + + def _convert_mongodb_document_to_document_chunk_with_score( + self, document: Dict + ) -> DocumentChunkWithScore: + # Convert MongoDB document to DocumentChunkWithScore + return DocumentChunkWithScore( + id=document.get("_id"), + text=document["text"], + metadata=document.get("metadata"), + score=document.get("score"), + ) + + def _build_mongo_filter( + self, filter: Optional[DocumentMetadataFilter] = None + ) -> Dict[str, Any]: + """ + Generate MongoDB query filters based on the provided DocumentMetadataFilter. + """ + if filter is None: + return {} + + mongo_filters = { + "$and": [], + } + + # For each field in the MetadataFilter, + # check if it has a value and add the corresponding MongoDB filter expression + for field, value in filter.dict().items(): + if value is not None: + if field == "start_date": + mongo_filters["$and"].append( + {"created_at": {"$gte": to_unix_timestamp(value)}} + ) + elif field == "end_date": + mongo_filters["$and"].append( + {"created_at": {"$lte": to_unix_timestamp(value)}} + ) + else: + mongo_filters["$and"].append( + {f"metadata.{field}": value} + ) + + return mongo_filters + + @staticmethod + def _connect_to_mongodb_atlas(atlas_connection_uri: str): + """ + Establish a connection to MongoDB Atlas. + """ + + client = AsyncIOMotorClient( + atlas_connection_uri, + driver=DriverInfo(name="Chatgpt Retrieval Plugin", version=version("chatgpt_retrieval_plugin"))) + return client diff --git a/docs/providers/mongodb/setup.md b/docs/providers/mongodb/setup.md new file mode 100644 index 000000000..a1df10da3 --- /dev/null +++ b/docs/providers/mongodb/setup.md @@ -0,0 +1,139 @@ +# Setting up MongoDB Atlas as the Datastore Provider + +MongoDB Atlas is a multi-cloud database service made by the same people that build MongoDB. +Atlas simplifies deploying and managing your databases while offering the versatility you need +to build resilient and performant global applications on the cloud providers of your choice. + +You can perform semantic search on data in your Atlas cluster running MongoDB v6.0.11, v7.0.2, +or later using Atlas Vector Search. You can store vector embeddings for any kind of data along +with other data in your collection on the Atlas cluster. + +In the section, we set up a cluster, a database, test it, and finally create an Atlas Vector Search Index. + +### Deploy a Cluster + +Follow the [Getting-Started](https://www.mongodb.com/basics/mongodb-atlas-tutorial) documentation +to create an account, deploy an Atlas cluster, and connect to a database. 
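
If you prefer to verify the deployment with the same asynchronous driver the plugin itself uses (`motor`), a minimal sketch follows; the connection string it needs is retrieved in the next section, and the placeholder values are illustrative.

```python
import asyncio
from motor.motor_asyncio import AsyncIOMotorClient


async def ping(uri: str) -> None:
    client = AsyncIOMotorClient(uri)
    # 'ping' is a cheap admin command; it fails fast if the URI, credentials,
    # or network access list are wrong.
    await client.admin.command("ping")
    print("Pinged your deployment. You successfully connected to MongoDB!")


# Illustrative placeholder URI; see the next section for how to retrieve yours.
asyncio.run(ping("mongodb+srv://<username>:<password>@<cluster-host>/?retryWrites=true&w=majority"))
```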
+ + +### Retrieve the URI used by Python to connect to the Cluster + +When you deploy the ChatGPT Retrieval App, this will be stored as the environment variable: `MONGODB_URI` +It will look something like the following. The username and password, if not provided, +can be configured in *Database Access* under Security in the left panel. + +``` +export MONGODB_URI="mongodb+srv://:@chatgpt-retrieval-plugin.zeatahb.mongodb.net/?retryWrites=true&w=majority" +``` + +There are a number of ways to navigate the Atlas UI. Keep your eye out for "Connect" and "driver". + +On the left panel, navigate and click 'Database' under DEPLOYMENT. +Click the Connect button that appears, then Drivers. Select Python. +(Have no concern for the version. This is the PyMongo, not Python, version.) +Once you have got the Connect Window open, you will see an instruction to `pip install pymongo`. +You will also see a **connection string**. +This is the `uri` that a `pymongo.MongoClient` uses to connect to the Database. + + +### Test the connection + +Atlas provides a simple check. Once you have your `uri` and `pymongo` installed, +try the following in a python console. + +```python +from pymongo.mongo_client import MongoClient +client = MongoClient(uri) # Create a new client and connect to the server +try: + client.admin.command('ping') # Send a ping to confirm a successful connection + print("Pinged your deployment. You successfully connected to MongoDB!") +except Exception as e: + print(e) +``` + +**Troubleshooting** +* You can edit a Database's users and passwords on the 'Database Access' page, under Security. +* Remember to add your IP address. (Try `curl -4 ifconfig.co`) + +### Create a Database and Collection + +As mentioned, Vector Databases provide two functions. In addition to being the data store, +they provide very efficient search based on natural language queries. +With Vector Search, one will index and query data with a powerful vector search algorithm +using "Hierarchical Navigable Small World (HNSW) graphs to find vector similarity. + +The indexing runs beside the data as a separate service asynchronously. +The Search index monitors changes to the Collection that it applies to. +Subsequently, one need not upload the data first. +We will create an empty collection now, which will simplify setup in the example notebook. + +Back in the UI, navigate to the Database Deployments page by clicking Database on the left panel. +Click the "Browse Collections" and then "+ Create Database" buttons. +This will open a window where you choose Database and Collection names. (No additional preferences.) +Remember these values as they will be as the environment variables, +`MONGODB_DATABASE` and `MONGODB_COLLECTION`. Though arbitrary, we suggest "SQUAD" and "Beyonce" +as these describe the data that we will use in our example Jupyter Notebook. + + +### Set Datastore Environment Variables + +To establish a connection to the MongoDB Cluster, Database, and Collection, plus create a Vector Search Index, +define the following environment variables. +You can confirm that the required ones have been set like this: `assert "MONGODB_URI" in os.environ` + +**IMPORTANT** It is crucial that the choices are consistent between setup in Atlas and Python environment(s). 
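
Before launching the plugin, a short check along the following lines can confirm everything is in place (a sketch; the list simply mirrors the tables below, so adjust it to your deployment):

```python
import os

# Names mirror the variables described in the tables below (datastore settings plus auth tokens).
required = [
    "DATASTORE", "MONGODB_URI", "MONGODB_DATABASE", "MONGODB_COLLECTION",
    "MONGODB_INDEX", "EMBEDDING_MODEL", "EMBEDDING_DIMENSION",
    "OPENAI_API_KEY", "BEARER_TOKEN",
]
missing = [name for name in required if name not in os.environ]
assert not missing, f"Missing environment variables: {missing}"
assert os.environ["DATASTORE"] == "mongodb"
```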
+ +| Name | Description | Example | +|-----------------------|-----------------------------|----------------------------------------------------------------------------------| +| `MONGODB_URI` | Connection String | mongodb+srv://``:``@chatgpt-retrieval-plugin.zeatahb.mongodb.net | +| `MONGODB_DATABASE` | Database name | SQUAD | +| `MONGODB_COLLECTION` | Collection name | Beyonce | +| `MONGODB_INDEX` | Search index name | vector_index | +| `DATASTORE` | Datastore name | [must be] mongodb | +| `EMBEDDING_MODEL` | OpenAI Embedding Model | text-embedding-3-small | +| `EMBEDDING_DIMENSION` | Length of Embedding Vectors | 1536 | + +The following will also be required to authenticate with OpenAI and Plugin APIs. + +| Name | Description | +|------------------|-----------------------------------------------------------------| +| `OPENAI_API_KEY` | OpenAI token created at https://platform.openai.com/api-keys | +| `BEARER_TOKEN` | Secret string passed in HTTP request header that server expects | + +### Create an Atlas Vector Search Index + +The final step to configure MongoDB as the Datastore is to create a Vector Search Index. +The procedure is described [here](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure). + +Under Services on the left panel, choose Atlas Search > Create Search Index > +Atlas Vector Search JSON Editor. + +The Plugin expects an index definition like the following. +To begin, choose `numDimensions: 1536` along with the suggested EMBEDDING variables above. +You can experiment with these later. + +```json +{ + "fields": [ + { + "numDimensions": 1536, + "path": "embedding", + "similarity": "cosine", + "type": "vector" + } + ] +} +``` + + +### Running MongoDB Integration Tests + +In addition to the Jupyter Notebooks in `examples/`, +a suite of integration tests is available to verify the MongoDB integration. +The test suite needs the cluster up and running, and the environment variables defined above. + +Then, launch the test suite with this command: + +```bash +pytest ./tests/datastore/providers/mongodb_atlas/test_mongodb_datastore.py +``` diff --git a/examples/providers/mongodb/semantic-search.ipynb b/examples/providers/mongodb/semantic-search.ipynb new file mode 100644 index 000000000..548f60820 --- /dev/null +++ b/examples/providers/mongodb/semantic-search.ipynb @@ -0,0 +1,654 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "735ae737-86be-4497-a8e9-38525e422380", + "metadata": {}, + "source": [ + "# Semantic Search of one's own data with OpenAI Embedding Model and MongoDB Atlas Vector Search\n", + "\n", + "It is often a valuable exercise, when developing and documenting, to consider [User Stories](https://www.atlassian.com/agile/project-management/user-stories). We have a number of different personas interested in the ChatGPT Retrieval Plugin.\n", + "\n", + "1. The End User, who wishes to extract information from her organization's or personal data.\n", + "2. The Data Scientist, who curates the data.\n", + "3. The Application Engineer, who sets up and maintains the application.\n", + "\n", + "### Application Setup\n", + "\n", + "**The Application Engineer** has a number of tasks to complete in order to provide service to her two users.\n", + "\n", + "1. Set up the DataStore.\n", + "\n", + "\n", + " * Create a MongoDB Atlas cluster.\n", + " * Add a Vector Index Search to it.

\n", + "\n", + " Begin by following the detailed steps in **[setup.md](https://github.com/caseyclements/chatgpt-retrieval-plugin/blob/mongodb/docs/providers/mongodb/setup.md)**.\n", + " Once completed, you will have a running Cluster, with a Database, a Collection, and a Vector Search Index attached to it.\n", + "\n", + " You will also have a number of required environment variables. These need to be available to run this example.\n", + " We will check for them below, and suggest how to set them up with an `.env` file if that is your preference.\n", + "\n", + " \n", + "2. Create and Serve the ChatGPT Retrival Plugin.\n", + " * Provide an API for the Data Scientist to insert, update, and delete data.\n", + " * Provide an API for the End User to query the data using natural language.

\n", + " \n", + " Start the service in another terminal as described in the repo's **[QuickStart]( [here](https://github.com/openai/chatgpt-retrieval-plugin#quickstart)**. \n", + "\n", + " **IMPORTANT** Make sure the environment variables are set in the terminal before `poetry run start`.\n", + "\n", + "### Application Usage\n", + "\n", + "This notebook tells a story of a **Data Scientist** and an **End User** as they interact with the service.\n", + "\n", + "\n", + "We begin by collecting and fiiltering an example dataset, the Stanford Question Answering Dataset (SQuAD)[https://huggingface.co/datasets/squad].\n", + "We upsert the data into a MongoDB Collection via the `query` endpoint of the Plugin API. \n", + "Upon doing this, Atlas begins to automatically index the data in preparation for Semantic Search. \n", + "\n", + "We close by asking a question of the data, searching not for a particular text string, but using common language.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "42f2c141-2643-4bff-b431-532916dfedf9", + "metadata": {}, + "source": [ + "## 1) Application Engineering\n", + "\n", + "Of course, we cannot begin until we test that our environment is set up.\n", + "\n", + "### Check environment variables\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02c045c9-39c8-47e4-a726-7b2a4c1cef21", + "metadata": {}, + "outputs": [], + "source": [ + "!pwd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "143c3e90-cf24-45dc-af65-646bcf89b071", + "metadata": {}, + "outputs": [], + "source": [ + "!which python" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8cfe7dc4-820c-4117-bdb7-debf3f5ec5ff", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "required_vars = {'BEARER_TOKEN', 'OPENAI_API_KEY', 'DATASTORE', 'EMBEDDING_DIMENSION', 'EMBEDDING_MODEL',\n", + " 'MONGODB_COLLECTION', 'MONGODB_DATABASE', 'MONGODB_INDEX', 'MONGODB_URI'}\n", + "assert os.environ[\"DATASTORE\"] == 'mongodb'\n", + "missing = required_vars - set(os.environ)\n", + "if missing:\n", + " print(f\"It is strongly recommended to set these additional environment variables. {missing}=\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "88f07920-3616-4663-bb59-68179c97933e", + "metadata": {}, + "outputs": [], + "source": [ + "# If you keep the environment variables in a .env file, like that .env.example, do this:\n", + "if missing:\n", + " from dotenv import dotenv_values\n", + " from pathlib import Path\n", + " import os\n", + " config = dotenv_values(Path('../.env'))\n", + " os.environ.update(config)" + ] + }, + { + "cell_type": "markdown", + "id": "c0e152ba-4cb4-4703-ac30-035afbc84e67", + "metadata": {}, + "source": [ + "### Check MongoDB Atlas Datastore connection" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "1aa90750-3f23-4671-8419-53e867649a6b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pinged your deployment. You successfully connected to MongoDB!\n" + ] + } + ], + "source": [ + "from pymongo import MongoClient\n", + "client = MongoClient(os.environ[\"MONGODB_URI\"])\n", + "# Send a ping to confirm a successful connection\n", + "try:\n", + " client.admin.command('ping')\n", + " print(\"Pinged your deployment. 
You successfully connected to MongoDB!\")\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "52480c9d-fcf0-4cd2-8031-94999a0f87cc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Beyonce'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db = client[os.environ[\"MONGODB_DATABASE\"]]\n", + "clxn = db[os.environ[\"MONGODB_COLLECTION\"]]\n", + "clxn.name" + ] + }, + { + "cell_type": "markdown", + "id": "5e334b80-babe-414d-adb1-6c4b8baff137", + "metadata": {}, + "source": [ + "### Check OpenAI Connection\n", + "\n", + "These tests require the environment variables: `OPENAI_API_KEY, EMBEDDING_MODEL`\n", + "\n", + "We set the api_key, then query the API for its available models. We then loop over this list to find which can provide text embeddings, and their natural, full, default dimensions." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8435ce2e-ed38-48e8-a9eb-4595d9c8eee3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"model_dimensions={'text-embedding-3-small': 1536, 'text-embedding-ada-002': 1536, 'text-embedding-3-large': 3072}\"" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import openai\n", + "openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n", + "models = openai.Model.list()\n", + "model_names = [model[\"id\"] for model in models['data']]\n", + "model_dimensions = {}\n", + "for model_name in model_names:\n", + " try:\n", + " response = openai.Embedding.create(input=[\"Some input text\"], model=model_name)\n", + " model_dimensions[model_name] = len(response['data'][0]['embedding'])\n", + " except:\n", + " pass\n", + "f\"{model_dimensions=}\"" + ] + }, + { + "cell_type": "markdown", + "id": "bcddf850-3a7c-4164-8862-88553d7b3970", + "metadata": {}, + "source": [ + "## 2) Data Engineering\n", + "\n", + "### Prepare personal or organizational dataset\n", + "\n", + "The ChatGPT Retrieval Plug provides semantic search of your own data using OpenAI's Embedding Models and MongoDB's Vector Datastore and Semantic Search.\n", + "\n", + "In this example, we will use the **S**tanford **Qu**estion **A**nswering **D**ataset (SQuAD), which we download from Hugging Face Datasets." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6833a874-e6f8-4f7b-9889-9b5184f458aa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "len(data)=19029\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
[HTML-rendered preview of data.head() omitted; the same table appears in the text/plain output below]
" + ], + "text/plain": [ + " id title \\\n", + "0 56be85543aeaaa14008c9063 Beyoncé \n", + "15 56be86cf3aeaaa14008c9076 Beyoncé \n", + "27 56be88473aeaaa14008c9080 Beyoncé \n", + "39 56be892d3aeaaa14008c908b Beyoncé \n", + "52 56be8a583aeaaa14008c9094 Beyoncé \n", + "\n", + " context \\\n", + "0 Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b... \n", + "15 Following the disbandment of Destiny's Child i... \n", + "27 A self-described \"modern-day feminist\", Beyonc... \n", + "39 Beyoncé Giselle Knowles was born in Houston, T... \n", + "52 Beyoncé attended St. Mary's Elementary School ... \n", + "\n", + " question \\\n", + "0 When did Beyonce start becoming popular? \n", + "15 After her second solo album, what other entert... \n", + "27 In her music, what are some recurring elements... \n", + "39 Beyonce's younger sibling also sang with her i... \n", + "52 What town did Beyonce go to school in? \n", + "\n", + " answers \n", + "0 {'text': ['in the late 1990s'], 'answer_start'... \n", + "15 {'text': ['acting'], 'answer_start': [207]} \n", + "27 {'text': ['love, relationships, and monogamy']... \n", + "39 {'text': ['Destiny's Child'], 'answer_start': ... \n", + "52 {'text': ['Fredericksburg'], 'answer_start': [... " + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "from datasets import load_dataset\n", + "data = load_dataset(\"squad_v2\", split=\"train\")\n", + "data = data.to_pandas().drop_duplicates(subset=[\"context\"])\n", + "print(f'{len(data)=}')\n", + "data.head()" + ] + }, + { + "cell_type": "markdown", + "id": "29c69543-51d7-49f9-b936-b7b99804818c", + "metadata": {}, + "source": [ + "To speed up our example, let's focus specifically on questions about Beyoncé" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "b444b0f8-a341-4f65-ab26-0793300d275f", + "metadata": {}, + "outputs": [], + "source": [ + "data = data.loc[data['title']=='Beyoncé']" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2136e910-9944-422d-aaa5-50f1d3d7a5ed", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'id': '56be85543aeaaa14008c9063',\n", + " 'text': 'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny\\'s Child. Managed by her father, Mathew Knowles, the group became one of the world\\'s best-selling girl groups of all time. 
Their hiatus saw the release of Beyoncé\\'s debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles \"Crazy in Love\" and \"Baby Boy\".',\n", + " 'metadata': {'title': 'Beyoncé'}}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "documents = [\n", + " {\n", + " 'id': r['id'],\n", + " 'text': r['context'],\n", + " 'metadata': {\n", + " 'title': r['title']\n", + " }\n", + " } for r in data.to_dict(orient='records')\n", + "]\n", + "documents[0]" + ] + }, + { + "cell_type": "markdown", + "id": "98388d9a-bb33-4eea-8d43-e890517f829a", + "metadata": {}, + "source": [ + "## Upsert and Index data via the Plugin API\n", + "\n", + "Posting an `upsert` request to the ChatGPT Retrieval Plugin API performs two tasks on the backend. First, it inserts into (or updates) your data in the MONGODB_COLLECTION in the MongoDB Cluster that you setup. Second, Atlas asynchronously begins populating a Vector Search Index on the embedding key. \n", + "\n", + "If you have already created the Collection and a Vector Search Index through the Atlas UI while Setting up MongoDB Atlas Cluster in [setup.md](https://github.com/caseyclements/chatgpt-retrieval-plugin/blob/main/docs/providers/mongodb/setup.md), then indexing will begin immediately.\n", + "\n", + "If you haven't set up the Atlas Vector Search yet, no problem. `upsert` will insert the data. To start indexing, simply go back to the Atlas UI and add a Search Index. This will trigger indexing. Once complete, we can begin semantic queries!\n", + "\n", + "\n", + "The front end of the Plugin is a FastAPI web server. It's API provides simple `http` requests.'We will need to provide authorization in the form of the BEARER_TOKEN we set earlier. We do this below:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8e2891fb-4485-4f97-b73f-e2f4186238bd", + "metadata": {}, + "outputs": [], + "source": [ + "endpoint_url = 'http://0.0.0.0:8000'\n", + "headers = {\"Authorization\": f\"Bearer {os.environ['BEARER_TOKEN']}\"}" + ] + }, + { + "cell_type": "markdown", + "id": "5f5cc603-88f1-402b-a0a7-52cabb9f7d9d", + "metadata": {}, + "source": [ + "Although our sample data is not large, and the service and datastore are reponsive, we follow best-practice and execute bulk upserts in batches with retries." 
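
One way to implement that batched, retrying upsert is sketched below; the batch size, retry count, and use of the `requests` library are illustrative, and it reuses the `documents`, `endpoint_url`, and `headers` defined in the cells above.

```python
import time
import requests  # assumed HTTP client; any client that can POST JSON works

batch_size = 100   # illustrative
max_retries = 3    # illustrative

for start in range(0, len(documents), batch_size):
    batch = documents[start:start + batch_size]
    for attempt in range(1, max_retries + 1):
        response = requests.post(
            f"{endpoint_url}/upsert",
            headers=headers,
            json={"documents": batch},
        )
        if response.status_code == 200:
            break  # batch accepted; move on to the next one
        time.sleep(2 ** attempt)  # simple exponential backoff before retrying
    else:
        raise RuntimeError(f"Upsert failed for batch starting at index {start}: {response.text}")
```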
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1c0e2eb7-58b5-4f36-af2d-d29aee18e08d", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cb28c6a61b464858806fc8f3404e7a19", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00=3.7" +files = [ + {file = "motor-3.3.2-py3-none-any.whl", hash = "sha256:6fe7e6f0c4f430b9e030b9d22549b732f7c2226af3ab71ecc309e4a1b7d19953"}, + {file = "motor-3.3.2.tar.gz", hash = "sha256:d2fc38de15f1c8058f389c1a44a4d4105c0405c48c061cd492a654496f7bc26a"}, +] + +[package.dependencies] +pymongo = ">=4.5,<5" + +[package.extras] +aws = ["pymongo[aws] (>=4.5,<5)"] +encryption = ["pymongo[encryption] (>=4.5,<5)"] +gssapi = ["pymongo[gssapi] (>=4.5,<5)"] +ocsp = ["pymongo[ocsp] (>=4.5,<5)"] +snappy = ["pymongo[snappy] (>=4.5,<5)"] +srv = ["pymongo[srv] (>=4.5,<5)"] +test = ["aiohttp (<3.8.6)", "mockupdb", "motor[encryption]", "pytest (>=7)", "tornado (>=5)"] +zstd = ["pymongo[zstd] (>=4.5,<5)"] + [[package]] name = "mpmath" version = "1.3.0" diff --git a/pyproject.toml b/pyproject.toml index 1ba7bf7cc..e41676ee4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ psycopg2cffi = {version = "^2.9.0", optional = true} loguru = "^0.7.0" elasticsearch = "8.8.2" pymongo = "^4.3.3" +motor = "^3.3.2" [tool.poetry.scripts] start = "server.main:start" diff --git a/services/openai.py b/services/openai.py index 965ece664..2f1fb1b9c 100644 --- a/services/openai.py +++ b/services/openai.py @@ -6,6 +6,7 @@ from tenacity import retry, wait_random_exponential, stop_after_attempt EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large") +EMBEDDING_DIMENSION = int(os.environ.get("EMBEDDING_DIMENSION", 256)) @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(3)) @@ -28,7 +29,7 @@ def get_embeddings(texts: List[str]) -> List[List[float]]: response = {} if deployment is None: - response = openai.Embedding.create(input=texts, model=EMBEDDING_MODEL) + response = openai.Embedding.create(input=texts, model=EMBEDDING_MODEL, dimensions=EMBEDDING_DIMENSION) else: response = openai.Embedding.create(input=texts, deployment_id=deployment) diff --git a/tests/datastore/providers/mongodb_atlas/test_integration.py b/tests/datastore/providers/mongodb_atlas/test_integration.py new file mode 100644 index 000000000..cea67678c --- /dev/null +++ b/tests/datastore/providers/mongodb_atlas/test_integration.py @@ -0,0 +1,128 @@ +"""Integration Tests of ChatGPT Retrieval Plugin +with MongoDB Atlas Vector Datastore and OPENAI Embedding model. + +As described in docs/providers/mongodb/setup.md, to run this, one must +have a running MongoDB Atlas Cluster, and +provide a valid OPENAI_API_KEY. +""" + +import os +from time import sleep + +import openai +import pytest +from fastapi.testclient import TestClient +from httpx import Response +from pymongo import MongoClient + +from server.main import app + + +@pytest.fixture(scope="session") +def documents(): + """ List of documents represents data to be embedded in the datastore. + Minimum requirements fpr Documents in the /upsert endpoint's UpsertRequest. 
+ """ + return [ + {"text": "The quick brown fox jumped over the slimy green toad."}, + {"text": "The big brown bear jumped over the lazy dog."}, + {"text": "Toads are frogs."}, + {"text": "Green toads are basically red frogs."}, + ] + + +@pytest.fixture(scope="session", autouse=True) +def client(): + """TestClient makes requests to FastAPI service.""" + endpoint_url = "http://127.0.0.1:8000" + headers = {"Authorization": f"Bearer {os.environ['BEARER_TOKEN']}"} + with TestClient(app=app, base_url=endpoint_url, headers=headers) as client: + yield client + + +@pytest.fixture(scope="session") +def delete(client) -> bool: + """Drop existing documents from the collection""" + response = client.request("DELETE", "/delete", json={"delete_all": True}) + sleep(2) + return response + + +@pytest.fixture(scope="session") +def upsert(delete, documents, client) -> bool: + """Upload documents to the datastore via plugin's REST API.""" + response = client.post("/upsert", json={"documents": documents}) + sleep(2) # At this point, the Vector Search Index is being built + return response + + +def test_delete(delete) -> None: + """Simply confirm that delete fixture ran successfully""" + assert delete.status_code == 200 + assert delete.json()['success'] + + +def test_upsert(upsert) -> None: + """Simply confirm that upsert fixture has run successfully""" + assert upsert.status_code == 200 + assert len(upsert.json()['ids']) == 4 + + +def test_query(upsert, client) -> None: # upsert, + """Test queries produce reasonable results, + now that datastore contains embedded data which has been indexed + """ + question = "What did the fox jump over?" + n_requested = 2 # top N results per query + got_response = False + retries = 5 + query_result = {} + while retries and not got_response: + response = client.post("/query", json={'queries': [{"query": question, "top_k": n_requested}]}) + assert isinstance(response, Response) + assert response.status_code == 200 + assert len(response.json()) == 1 + query_result = response.json()['results'][0] + if len(query_result['results']) == n_requested: + got_response = True + else: + retries -= 1 + sleep(5) + + assert got_response # we got n_requested responses + assert query_result['query'] == question + answers = [] + scores = [] + for result in query_result['results']: + answers.append(result['text']) + scores.append(round(result['score'], 2)) + assert 0.8 < scores[0] < 0.9 + assert answers[0] == "The quick brown fox jumped over the slimy green toad." 
+ + +def test_required_vars() -> None: + """Confirm that the environment has all it needs""" + required_vars = {'BEARER_TOKEN', 'OPENAI_API_KEY', 'DATASTORE', 'EMBEDDING_DIMENSION', 'EMBEDDING_MODEL', + 'MONGODB_COLLECTION', 'MONGODB_DATABASE', 'MONGODB_INDEX', 'MONGODB_URI'} + assert os.environ["DATASTORE"] == 'mongodb' + missing = required_vars - set(os.environ) + assert len(missing) == 0 + + +def test_mongodb_connection() -> None: + """Confirm that the connection to the datastore works.""" + client = MongoClient(os.environ["MONGODB_URI"]) + assert client.admin.command('ping')['ok'] + + +def test_openai_connection() -> None: + """Check that we can call OpenAI Embedding models.""" + openai.api_key = os.environ["OPENAI_API_KEY"] + models = openai.Model.list() + model_names = [model["id"] for model in models['data']] + for model_name in model_names: + try: + response = openai.Embedding.create(input=["Some input text"], model=model_name) + assert len(response['data'][0]['embedding']) >= int(os.environ['EMBEDDING_DIMENSION']) + except: + pass # Not all models are for text embedding. diff --git a/tests/datastore/providers/mongodb_atlas/test_mongodb_datastore.py b/tests/datastore/providers/mongodb_atlas/test_mongodb_datastore.py new file mode 100644 index 000000000..ef585c7e3 --- /dev/null +++ b/tests/datastore/providers/mongodb_atlas/test_mongodb_datastore.py @@ -0,0 +1,254 @@ +""" +Integration tests of MongoDB Atlas Datastore. + +These tests require one to have a running Cluster, Database, Collection and Atlas Search Index +as described in docs/providers/mongodb/setup.md. + +One will also have to set the same environment variables. Although one CAN +use we the same collection and index used in examples/providers/mongodb/semantic-search.ipynb, +these tests will make changes to the data, so you may wish to create another collection. +If you have run the example notebook, you can reuse with the following. 
+ +MONGODB_DATABASE=SQUAD +MONGODB_COLLECTION=Beyonce +MONGODB_INDEX=vector_index +EMBEDDING_DIMENSION=1536 +MONGODB_URI=mongodb+srv://:@/?retryWrites=true&w=majority +""" + + +from inspect import iscoroutinefunction +import pytest +import time +from typing import Callable +import os + +from models.models import ( + DocumentChunkMetadata, + DocumentMetadataFilter, + DocumentChunk, + QueryWithEmbedding, + Source, +) +from services.date import to_unix_timestamp +from datetime import datetime +from datastore.providers.mongodb_atlas_datastore import ( + MongoDBAtlasDataStore, +) + + + +async def assert_when_ready(callable: Callable, tries: int = 5, interval: float = 1): + + for _ in range(tries): + if iscoroutinefunction(callable): + print("starting async call") + result = await callable() + print(f"finished async call with {result=}") + else: + result = callable() + if result: + return + time.sleep(interval) + + raise AssertionError("Condition not met after multiple attempts") + + +def collection_size_callback_factory(collection, num: int): + + async def predicate(): + num_documents = await collection.count_documents({}) + return num_documents == num + + return predicate + + +@pytest.fixture +def _mongodb_datastore(): + return MongoDBAtlasDataStore() + + +@pytest.fixture +async def mongodb_datastore(_mongodb_datastore): + await _mongodb_datastore.delete(delete_all=True) + collection = _mongodb_datastore.client[_mongodb_datastore.database_name][_mongodb_datastore.collection_name] + await assert_when_ready(collection_size_callback_factory(collection, 0)) + yield _mongodb_datastore + await _mongodb_datastore.delete(delete_all=True) + await assert_when_ready(collection_size_callback_factory(collection, 0)) + + +def sample_embedding(one_element_poz: int): + n_dims = int(os.environ["EMBEDDING_DIMENSION"]) + embedding = [0] * n_dims + embedding[one_element_poz % n_dims] = 1 + return embedding + + +def sample_embeddings(num: int, one_element_start: int = 0): + return [sample_embedding(x + one_element_start) for x in range(num)] + + +@pytest.fixture +def document_id(): + """ID of an unchunked document""" + return "a5991f75a315f755c3365ab2" + +@pytest.fixture +def chunk_ids(document_id): + """IDs of chunks""" + return [f"{document_id}_{i}" for i in range(3)] + + +@pytest.fixture +def one_documents_chunks(document_id, chunk_ids): + """Represents output of services.chunks.get_document_chunks + -> Dict[str, List[DocumentChunk]] + called on a list containing a single Document + """ + + n_chunks = len(chunk_ids) + + texts = [ + "Aenean euismod bibendum laoreet", + "Vivamus non enim vitae tortor", + "Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae", + ] + sources = [Source.email, Source.file, Source.chat] + created_ats = [ + "1929-10-28T09:30:00-05:00", + "2009-01-03T16:39:57-08:00", + "2021-01-21T10:00:00-02:00", + ] + authors = ["Fred Smith", "Bob Doe", "Appleton Doe"] + + embeddings = sample_embeddings(n_chunks) + doc_chunks = [] + for i in range(n_chunks): + chunk = DocumentChunk( + id=chunk_ids[i], + text=texts[i], + metadata=DocumentChunkMetadata( + document_id=document_id, + source=sources[i], + created_at=created_ats[i], + author=authors[i], + ), + embedding=embeddings[i], # type: ignore + ) + + doc_chunks.append(chunk) + + return {document_id: doc_chunks} + + +async def test_upsert(mongodb_datastore: MongoDBAtlasDataStore, one_documents_chunks, chunk_ids): + """This tests that data gets uploaded, but not that the search index is built.""" + res = await 
mongodb_datastore._upsert(one_documents_chunks) + assert res == chunk_ids + + collection = mongodb_datastore.client[mongodb_datastore.database_name][mongodb_datastore.collection_name] + await assert_when_ready(collection_size_callback_factory(collection, 3)) + + +async def test_upsert_query_all(mongodb_datastore, one_documents_chunks, chunk_ids): + """By running _query, this performs """ + res = await mongodb_datastore._upsert(one_documents_chunks) + await assert_when_ready(lambda: res == chunk_ids) + + query = QueryWithEmbedding( + query="Aenean", + top_k=10, + embedding=sample_embedding(0), # type: ignore + ) + + async def predicate(): + query_results = await mongodb_datastore._query(queries=[query]) + return 1 == len(query_results) and 3 == len(query_results[0].results) + + await assert_when_ready(predicate, tries=12, interval=5) + + +async def test_delete_with_document_id(mongodb_datastore, one_documents_chunks, chunk_ids): + res = await mongodb_datastore._upsert(one_documents_chunks) + assert res == chunk_ids + collection = mongodb_datastore.client[mongodb_datastore.database_name][mongodb_datastore.collection_name] + first_id = str((await collection.find_one())["_id"]) + await mongodb_datastore.delete(ids=[first_id]) + + await assert_when_ready(collection_size_callback_factory(collection, 2)) + + all_documents = [doc async for doc in collection.find()] + for document in all_documents: + assert document["metadata"]["author"] != "Fred Smith" + + +async def test_delete_with_source_filter(mongodb_datastore, one_documents_chunks, chunk_ids): + res = await mongodb_datastore._upsert(one_documents_chunks) + assert res == chunk_ids + + await mongodb_datastore.delete( + filter=DocumentMetadataFilter( + source=Source.email, + ) + ) + + query = QueryWithEmbedding( + query="Aenean", + top_k=9, + embedding=sample_embedding(0), # type: ignore + ) + + async def predicate(): + query_results = await mongodb_datastore._query(queries=[query]) + return 1 == len(query_results) and query_results[0].results + + await assert_when_ready(predicate, tries=12, interval=5) + query_results = await mongodb_datastore._query(queries=[query]) + for result in query_results[0].results: + assert result.text != "Aenean euismod bibendum laoreet" + + +@pytest.fixture +def build_mongo_filter(): + return MongoDBAtlasDataStore()._build_mongo_filter + + +async def test_build_mongo_filter_with_no_filter(build_mongo_filter): + result = build_mongo_filter() + assert result == {} + + +async def test_build_mongo_filter_with_start_date(build_mongo_filter): + date = datetime(2022, 1, 1).isoformat() + filter_data = {"start_date": date} + result = build_mongo_filter(DocumentMetadataFilter(**filter_data)) + + assert result == { + "$and": [ + {"created_at": {"$gte": to_unix_timestamp(date)}} + ] + } + + +async def test_build_mongo_filter_with_end_date(build_mongo_filter): + date = datetime(2022, 1, 1).isoformat() + filter_data = {"end_date": date} + result = build_mongo_filter(DocumentMetadataFilter(**filter_data)) + + assert result == { + "$and": [ + {"created_at": {"$lte": to_unix_timestamp(date)}} + ] + } + + +async def test_build_mongo_filter_with_metadata_field(build_mongo_filter): + filter_data = {"source": "email"} + result = build_mongo_filter(DocumentMetadataFilter(**filter_data)) + + assert result == { + "$and": [ + {"metadata.source": "email"} + ] + }
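
The same pattern extends to combined criteria. A hypothetical additional test, not part of the suite above and relying on the imports already present in this file, might look like this (membership checks avoid assuming the order of the generated `$and` clauses):

```python
async def test_build_mongo_filter_with_combined_fields(build_mongo_filter):
    """Illustrative only: a metadata field and a date bound combined in one filter."""
    date = datetime(2022, 1, 1).isoformat()
    result = build_mongo_filter(
        DocumentMetadataFilter(source="email", start_date=date)
    )
    clauses = result["$and"]
    # Clause order follows the model's field order, so check membership instead of equality.
    assert {"created_at": {"$gte": to_unix_timestamp(date)}} in clauses
    assert {"metadata.source": "email"} in clauses
```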