Skip to content

Commit

Permalink
Use skip_on_missing_imports to mark tests in test/agentchat/contrib/v…
Browse files Browse the repository at this point in the history
…ectordb files
  • Loading branch information
kumaranvpl committed Jan 23, 2025
1 parent 5011ec0 commit dac2b9d
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 33 deletions.
7 changes: 2 additions & 5 deletions test/agentchat/contrib/vectordb/test_chromadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest

from autogen.agentchat.contrib.vectordb.chromadb import ChromaVectorDB
from autogen.import_utils import optional_import_block
from autogen.import_utils import optional_import_block, skip_on_missing_imports

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

Expand All @@ -20,10 +20,7 @@
import sentence_transformers # noqa: F401


skip = not result.is_successful


@pytest.mark.skipif(skip, reason="dependency is not installed")
@skip_on_missing_imports(["chromadb", "sentence_transformers"], "retrievechat")
def test_chromadb():
# test create collection
db = ChromaVectorDB(path=".db")
Expand Down
32 changes: 18 additions & 14 deletions test/agentchat/contrib/vectordb/test_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,12 @@

from autogen.agentchat.contrib.vectordb.base import Document
from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB
from autogen.import_utils import optional_import_block
from autogen.import_utils import optional_import_block, skip_on_missing_imports

with optional_import_block() as result:
import pymongo # noqa: F401
import sentence_transformers # noqa: F401


if not result.is_successful:
# To display warning in pyproject.toml [tool.pytest.ini_options] set log_cli = true
logger = logging.getLogger(__name__)
logger.warning(f"skipping {__name__}. It requires one to pip install pymongo or the extra [retrievechat-mongodb]")
pytest.skip(allow_module_level=True)

from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.errors import OperationFailure
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.errors import OperationFailure

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -143,6 +133,7 @@ def collection_name():
return f"{MONGODB_COLLECTION}_{collection_id}"


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_create_collection(db, collection_name):
"""Def create_collection(collection_name: str,
overwrite: bool = False) -> Collection
Expand Down Expand Up @@ -172,6 +163,7 @@ def test_create_collection(db, collection_name):
db.create_collection(collection_name=collection_name, overwrite=False, get_or_create=False)


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_get_collection(db, collection_name):
with pytest.raises(ValueError):
db.get_collection()
Expand All @@ -185,6 +177,7 @@ def test_get_collection(db, collection_name):
assert collection_got.name == db.active_collection.name


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_delete_collection(db, collection_name):
assert collection_name not in db.list_collections()
collection = db.create_collection(collection_name)
Expand All @@ -193,6 +186,7 @@ def test_delete_collection(db, collection_name):
assert collection_name not in db.list_collections()


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_insert_docs(db, collection_name, example_documents):
# Test that there's an active collection
with pytest.raises(ValueError) as exc:
Expand All @@ -218,6 +212,7 @@ def test_insert_docs(db, collection_name, example_documents):
assert len(found[0]["embedding"]) == 384


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_update_docs(db_with_indexed_clxn, example_documents):
db, collection = db_with_indexed_clxn
# Use update_docs to insert new documents
Expand Down Expand Up @@ -253,6 +248,7 @@ def test_update_docs(db_with_indexed_clxn, example_documents):
assert collection.find_one({"_id": new_id}) is None


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_delete_docs(db_with_indexed_clxn, example_documents):
db, clxn = db_with_indexed_clxn
# Insert example documents
Expand All @@ -263,6 +259,7 @@ def test_delete_docs(db_with_indexed_clxn, example_documents):
assert {2, "2"} == {doc["_id"] for doc in clxn.find({})}


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_get_docs_by_ids(db_with_indexed_clxn, example_documents):
db, clxn = db_with_indexed_clxn
# Insert example documents
Expand All @@ -288,11 +285,13 @@ def test_get_docs_by_ids(db_with_indexed_clxn, example_documents):
assert len(docs) == 4


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_retrieve_docs_empty(db_with_indexed_clxn):
db, clxn = db_with_indexed_clxn
assert db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=2) == []


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_retrieve_docs_populated_db_empty_query(db_with_indexed_clxn, example_documents):
db, clxn = db_with_indexed_clxn
db.insert_docs(example_documents, collection_name=clxn.name)
Expand All @@ -301,6 +300,7 @@ def test_retrieve_docs_populated_db_empty_query(db_with_indexed_clxn, example_do
assert results == []


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_retrieve_docs(db_with_indexed_clxn, example_documents):
"""Begin testing Atlas Vector Search
NOTE: Indexing may take some time, so we must be patient on the first query.
Expand All @@ -324,6 +324,7 @@ def results_ready():
assert all(["embedding" not in doc[0] for doc in results[0]])


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_retrieve_docs_with_embedding(db_with_indexed_clxn, example_documents):
"""Begin testing Atlas Vector Search
NOTE: Indexing may take some time, so we must be patient on the first query.
Expand All @@ -347,6 +348,7 @@ def results_ready():
assert all(["embedding" in doc[0] for doc in results[0]])


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_retrieve_docs_multiple_queries(db_with_indexed_clxn, example_documents):
db, clxn = db_with_indexed_clxn
# Insert example documents
Expand All @@ -369,6 +371,7 @@ def results_ready():
assert {doc[0]["id"] for doc in results[1]} == {"1", "2"}


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_retrieve_docs_with_threshold(db_with_indexed_clxn, example_documents):
db, clxn = db_with_indexed_clxn
# Insert example documents
Expand All @@ -390,6 +393,7 @@ def results_ready():
assert all([doc[1] >= 0.7 for doc in results[0]])


@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb")
def test_wait_until_document_ready(collection_name, example_documents):
database = MongoClient(MONGODB_URI)[MONGODB_DATABASE]
_empty_collections_and_delete_indexes(database, [collection_name], wait=True)
Expand Down
9 changes: 3 additions & 6 deletions test/agentchat/contrib/vectordb/test_pgvectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,22 @@
import pytest

from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB
from autogen.import_utils import optional_import_block
from autogen.import_utils import optional_import_block, skip_on_missing_imports

from ....conftest import reason

with optional_import_block() as result:
import pgvector # noqa: F401
import psycopg
import sentence_transformers # noqa: F401


skip = not result.is_successful

reason = "do not run on MacOS or windows OR dependency is not installed OR " + reason


@pytest.mark.skipif(
sys.platform in ["darwin", "win32"] or skip,
sys.platform in ["darwin", "win32"],
reason=reason,
)
@skip_on_missing_imports(["pgvector", "psycopg", "sentence_transformers"], "retrievechat-pgvector")
def test_pgvector():
# test db config
db_config = {
Expand Down
10 changes: 2 additions & 8 deletions test/agentchat/contrib/vectordb/test_qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,16 @@
import sys
import uuid

import pytest

from autogen.agentchat.contrib.vectordb.qdrant import QdrantVectorDB
from autogen.import_utils import optional_import_block
from autogen.import_utils import optional_import_block, skip_on_missing_imports

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

with optional_import_block() as result:
from fastembed import TextEmbedding # noqa: F401
from qdrant_client import QdrantClient


skip = not result.is_successful


@pytest.mark.skipif(skip, reason="dependency is not installed")
@skip_on_missing_imports(["fastembed", "qdrant_client"], "retrievechat-qdrant")
def test_qdrant():
# test create collection
client = QdrantClient(location=":memory:")
Expand Down

0 comments on commit dac2b9d

Please sign in to comment.