Skip to content

Commit

Permalink
Handle missing imports with a context manager and a wraper for classe…
Browse files Browse the repository at this point in the history
…s and objects (#540)

* Add check_for_missing_imports

* Add requires_optional_import decorator

* Comment out DummyModule

* refactoring

* refactoring

* Use import module to detect missing modules

* Use sys.modules instead of globals and cleanup print statements

* Use option_import block to import instead of try block

* Fix typo

* Revert requires block

* Fix typo in require_option_import

* Move require_option_import to correct position

* Use optional_import_block

* Install websurfer optional dependency in tests

* Update test_import_utils test

* Remove unused AutogenImportError

* Fix failing test in browser_utils

* Polish code

* Skip test in windows

---------

Co-authored-by: Davor Runje <[email protected]>
  • Loading branch information
kumaranvpl and davorrunje authored Jan 20, 2025
1 parent 4fa386e commit 2dd1493
Show file tree
Hide file tree
Showing 27 changed files with 617 additions and 157 deletions.
15 changes: 5 additions & 10 deletions autogen/agentchat/contrib/capabilities/text_compressors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from typing import Any, Optional, Protocol
from typing import Any, Protocol

IMPORT_ERROR: Optional[Exception] = None
try:
from ....import_utils import optional_import_block, require_optional_import

with optional_import_block() as result:
import llmlingua
except ImportError:
IMPORT_ERROR = ImportError("LLMLingua is not installed. Please install it with `pip install autogen[long-context]`")
PromptCompressor = object
else:
from llmlingua import PromptCompressor


Expand All @@ -27,6 +24,7 @@ def compress_text(self, text: str, **compression_params) -> dict[str, Any]:
...


@require_optional_import("llmlingua", "long-context")
class LLMLingua:
"""Compresses text messages using LLMLingua for improved efficiency in processing and response generation.
Expand Down Expand Up @@ -55,9 +53,6 @@ def __init__(
Raises:
ImportError: If the llmlingua library is not installed.
"""
if IMPORT_ERROR:
raise IMPORT_ERROR

self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)

assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor)
Expand Down
10 changes: 5 additions & 5 deletions autogen/agentchat/contrib/llamaindex_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,18 @@
from autogen.agentchat import Agent, ConversableAgent
from autogen.agentchat.contrib.vectordb.utils import get_logger

from ...import_utils import optional_import_block, require_optional_import

logger = get_logger(__name__)

try:
with optional_import_block() as result:
from llama_index.core.agent.runner.base import AgentRunner
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.chat_engine.types import AgentChatResponse
from pydantic import BaseModel
from pydantic import __version__ as pydantic_version

if result.is_successful:
# let's Avoid: AttributeError: type object 'Config' has no attribute 'copy'
# check for v1 like in autogen/_pydantic.py
is_pydantic_v1 = pydantic_version.startswith("1.")
Expand All @@ -35,11 +38,8 @@ class Config:
# Added to mitigate PydanticSchemaGenerationError
BaseModel.model_config = Config

except ImportError as e:
logger.fatal("Failed to import llama-index. Try running 'pip install llama-index'")
raise e


@require_optional_import("llama_index", "neo4j")
class LLamaIndexConversableAgent(ConversableAgent):
def __init__(
self,
Expand Down
11 changes: 6 additions & 5 deletions autogen/agentchat/contrib/math_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
from autogen.code_utils import UNKNOWN, execute_code, extract_code, infer_lang
from autogen.math_utils import get_answer

from ...import_utils import optional_import_block, require_optional_import

with optional_import_block() as result:
import wolframalpha

PROMPTS = {
# default
"default": """Let's use Python to solve a math problem.
Expand Down Expand Up @@ -402,16 +407,12 @@ class Config:

@root_validator(skip_on_failure=True)
@classmethod
@require_optional_import("wolframalpha", "mathchat")
def validate_environment(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
wolfram_alpha_appid = get_from_dict_or_env(values, "wolfram_alpha_appid", "WOLFRAM_ALPHA_APPID")
values["wolfram_alpha_appid"] = wolfram_alpha_appid

try:
import wolframalpha

except ImportError as e:
raise ImportError("wolframalpha is not installed. Please install it with `pip install wolframalpha`") from e
client = wolframalpha.Client(wolfram_alpha_appid)
values["wolfram_client"] = client

Expand Down
10 changes: 6 additions & 4 deletions autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
)
from autogen.retrieve_utils import TEXT_FORMATS, get_files_from_dir, split_files_to_chunks

from ...import_utils import optional_import_block, require_optional_import

logger = get_logger(__name__)

try:
with optional_import_block():
import fastembed # noqa: F401
from qdrant_client import QdrantClient, models
from qdrant_client.fastembed_common import QueryResponse
except ImportError as e:
logger.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
raise e


@require_optional_import(["fastembed", "qdrant_client"], "retrievechat-qdrant")
class QdrantRetrieveUserProxyAgent(RetrieveUserProxyAgent):
def __init__(
self,
Expand Down Expand Up @@ -158,6 +158,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self._results = results


@require_optional_import(["fastembed", "qdrant_client"], "retrievechat-qdrant")
def create_qdrant_from_dir(
dir_path: str,
max_tokens: int = 4000,
Expand Down Expand Up @@ -263,6 +264,7 @@ def create_qdrant_from_dir(
)


@require_optional_import("qdrant_client", "retrievechat-qdrant")
def query_qdrant(
query_texts: list[str],
n_results: int = 10,
Expand Down
6 changes: 4 additions & 2 deletions autogen/agentchat/contrib/reasoning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
from typing import Optional

from ...import_utils import optional_import_block
from ..agent import Agent
from ..assistant_agent import AssistantAgent

Expand Down Expand Up @@ -169,9 +170,10 @@ def visualize_tree(root: ThinkNode) -> None:
Args:
root (ThinkNode): The root node of the tree.
"""
try:
with optional_import_block() as result:
from graphviz import Digraph
except ImportError:

if not result.is_successful:
print("Please install graphviz: pip install graphviz")
return

Expand Down
9 changes: 5 additions & 4 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@

from IPython import get_ipython

try:
import chromadb
except ImportError as e:
raise ImportError(f"{e}. You can try `pip install autogen[retrievechat]`, or install `chromadb` manually.")
from autogen.agentchat import UserProxyAgent
from autogen.agentchat.agent import Agent
from autogen.agentchat.contrib.vectordb.base import Document, QueryResults, VectorDB, VectorDBFactory
Expand All @@ -35,6 +31,10 @@
from autogen.token_count_utils import count_token

from ...formatting_utils import colored
from ...import_utils import optional_import_block, require_optional_import

with optional_import_block():
import chromadb

logger = get_logger(__name__)

Expand Down Expand Up @@ -91,6 +91,7 @@
UPDATE_CONTEXT_IN_PROMPT = "you should reply exactly `UPDATE CONTEXT`"


@require_optional_import("chromadb", "retrievechat")
class RetrieveUserProxyAgent(UserProxyAgent):
"""(In preview) The Retrieval-Augmented User Proxy retrieves document chunks based on the embedding
similarity, and sends them along with the question to the Retrieval-Augmented Assistant
Expand Down
14 changes: 8 additions & 6 deletions autogen/agentchat/contrib/vectordb/chromadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,26 @@
import os
from typing import Callable

from ....import_utils import optional_import_block, require_optional_import
from .base import Document, ItemID, QueryResults, VectorDB
from .utils import chroma_results_to_query_results, filter_results_by_distance, get_logger

try:
with optional_import_block() as result:
import chromadb

if chromadb.__version__ < "0.4.15":
raise ImportError("Please upgrade chromadb to version 0.4.15 or later.")
import chromadb.errors
import chromadb.utils.embedding_functions as ef
from chromadb.api.models.Collection import Collection
except ImportError:
raise ImportError("Please install chromadb: `pip install chromadb`")

if result.is_successful:
if chromadb.__version__ < "0.4.15":
raise ImportError("Please upgrade chromadb to version 0.4.15 or later.")


CHROMADB_MAX_BATCH_SIZE = os.environ.get("CHROMADB_MAX_BATCH_SIZE", 40000)
logger = get_logger(__name__)


@require_optional_import("chromadb", "retrievechat")
class ChromaVectorDB(VectorDB):
"""A vector database that uses ChromaDB as the backend."""

Expand Down
13 changes: 5 additions & 8 deletions autogen/agentchat/contrib/vectordb/pgvectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,20 @@
import numpy as np
from sentence_transformers import SentenceTransformer

from ....import_utils import optional_import_block, require_optional_import
from .base import Document, ItemID, QueryResults, VectorDB
from .utils import get_logger

try:
with optional_import_block():
import pgvector # noqa: F401
from pgvector.psycopg import register_vector
except ImportError:
raise ImportError("Please install pgvector: `pip install pgvector`")

try:
import psycopg
except ImportError:
raise ImportError("Please install pgvector: `pip install psycopg`")
from pgvector.psycopg import register_vector

PGVECTOR_MAX_BATCH_SIZE = os.environ.get("PGVECTOR_MAX_BATCH_SIZE", 40000)
logger = get_logger(__name__)


@require_optional_import("psycopg", "retrievechat-pgvector")
class Collection:
"""A Collection object for PGVector.
Expand Down Expand Up @@ -543,6 +539,7 @@ def create_collection(
cursor.close()


@require_optional_import(["pgvector", "psycopg"], "retrievechat-pgvector")
class PGVectorDB(VectorDB):
"""A vector database that uses PGVector as the backend."""

Expand Down
15 changes: 6 additions & 9 deletions autogen/agentchat/contrib/vectordb/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from collections.abc import Sequence
from typing import Optional, Union

from ....import_utils import optional_import_block, require_optional_import
from .base import Document, ItemID, QueryResults, VectorDB
from .utils import get_logger

try:
with optional_import_block():
from fastembed import TextEmbedding
from qdrant_client import QdrantClient, models
except ImportError:
raise ImportError("Please install qdrant-client: `pip install qdrant-client`")


logger = get_logger(__name__)

Expand All @@ -28,6 +29,7 @@ def __call__(self, inputs: list[str]) -> list[Embeddings]:
raise NotImplementedError


@require_optional_import("fastembed", "retrievechat-qdrant")
class FastEmbedEmbeddingFunction(EmbeddingFunction):
"""Embedding function implementation using FastEmbed - https://qdrant.github.io/fastembed."""

Expand Down Expand Up @@ -57,12 +59,6 @@ def __init__(
Raises:
ValueError: If the model_name is not in the format `<org>/<model>` e.g. BAAI/bge-small-en-v1.5.
"""
try:
from fastembed import TextEmbedding
except ImportError as e:
raise ValueError(
"The 'fastembed' package is not installed. Please install it with `pip install fastembed`",
) from e
self._batch_size = batch_size
self._parallel = parallel
self._model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads, **kwargs)
Expand All @@ -73,6 +69,7 @@ def __call__(self, inputs: list[str]) -> list[Embeddings]:
return [embedding.tolist() for embedding in embeddings]


@require_optional_import("qdrant_client", "retrievechat-qdrant")
class QdrantVectorDB(VectorDB):
"""A vector database implementation that uses Qdrant as the backend."""

Expand Down
21 changes: 10 additions & 11 deletions autogen/browser_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,26 @@
from typing import Any, Optional, Union
from urllib.parse import urljoin, urlparse

import markdownify
import requests
from bs4 import BeautifulSoup
from .import_utils import optional_import_block, require_optional_import

with optional_import_block():
import markdownify
import requests
from bs4 import BeautifulSoup

# Optional PDF support
IS_PDF_CAPABLE = False
try:
with optional_import_block() as result:
import pdfminer
import pdfminer.high_level

IS_PDF_CAPABLE = True
except ModuleNotFoundError:
pass
IS_PDF_CAPABLE = result.is_successful

# Other optional dependencies
try:
with optional_import_block():
import pathvalidate
except ModuleNotFoundError:
pass


@require_optional_import(["markdownify", "requests", "bs4", "pdfminer", "pathvalidate"], "websurfer")
class SimpleTextBrowser:
"""(In preview) An extremely simple text-based web browser comparable to Lynx. Suitable for Agentic use."""

Expand Down
12 changes: 7 additions & 5 deletions autogen/cache/cache_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
from typing import Any, Optional, Union

from ..import_utils import optional_import_block
from .abstract_cache_base import AbstractCache
from .disk_cache import DiskCache

Expand Down Expand Up @@ -63,22 +64,23 @@ def cache_factory(
"""
if redis_url:
try:
with optional_import_block() as result:
from .redis_cache import RedisCache

if result.is_successful:
return RedisCache(seed, redis_url)
except ImportError:
else:
logging.warning(
"RedisCache is not available. Checking other cache options. The last fallback is DiskCache."
)

if cosmosdb_config:
try:
with optional_import_block() as result:
from .cosmos_db_cache import CosmosDBCache

if result.is_successful:
return CosmosDBCache.create_cache(seed, cosmosdb_config)

except ImportError:
else:
logging.warning("CosmosDBCache is not available. Fallback to DiskCache.")

# Default to DiskCache if neither Redis nor Cosmos DB configurations are provided
Expand Down
1 change: 1 addition & 0 deletions autogen/exception_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
__all__ = [
"AgentNameConflict",
"InvalidCarryOverType",
"ModelToolNotSupportedError",
"NoEligibleSpeaker",
"SenderRequired",
"UndefinedNextAgent",
Expand Down
Loading

0 comments on commit 2dd1493

Please sign in to comment.