Integrate crawl4ai tool #697

Merged · 19 commits · Jan 29, 2025
Changes from all commits
13 changes: 12 additions & 1 deletion autogen/tools/dependency_injection.py
@@ -7,7 +7,7 @@
from abc import ABC
from collections.abc import Iterable
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_type_hints
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, get_type_hints

from fast_depends import Depends as FastDepends
from fast_depends import inject
@@ -26,6 +26,7 @@
"Field",
"get_context_params",
"inject_params",
"on",
]


@@ -75,6 +76,16 @@ def last_message(self) -> Optional[dict[str, Any]]:
return self._agent.last_message()


T = TypeVar("T")


def on(x: T) -> Callable[[], T]:
def inner(_x: T = x) -> T:
return _x

return inner


@export_module("autogen.tools")
def Depends(x: Any) -> Any: # noqa: N802
"""Creates a dependency for injection based on the provided context or type.
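The `on` helper that this file now exports is easiest to understand with a small sketch (the injected value and the `fetch` function below are hypothetical, not part of this diff):

from typing import Annotated, Any

from autogen.tools.dependency_injection import Depends, on

fixed_config: dict[str, Any] = {"model": "gpt-4o-mini"}  # illustrative value

def fetch(
    url: str,
    llm_config: Annotated[dict[str, Any], Depends(on(fixed_config))],
) -> str:
    # on(x) wraps the value in a zero-argument callable, so the dependency
    # injector supplies it at call time and the visible signature is just `url`.
    return f"fetching {url} with {llm_config['model']}"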
3 changes: 2 additions & 1 deletion autogen/tools/experimental/__init__.py
@@ -3,5 +3,6 @@
# SPDX-License-Identifier: Apache-2.0

from .browser_use import BrowserUseTool
from .crawl4ai import Crawl4AITool

__all__ = ["BrowserUseTool"]
__all__ = ["BrowserUseTool", "Crawl4AITool"]
20 changes: 5 additions & 15 deletions autogen/tools/experimental/browser_use/browser_use.py
@@ -2,12 +2,14 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Annotated, Any, Callable, Optional, TypeVar
from typing import Annotated, Any, Optional

from pydantic import BaseModel

from ....doc_utils import export_module
from ....import_utils import optional_import_block, require_optional_import
from ... import Depends, Tool
from ...dependency_injection import on

with optional_import_block():
from browser_use import Agent
@@ -19,8 +21,7 @@
__all__ = ["BrowserUseResult", "BrowserUseTool"]


# todo: add export_module decorator
# @export_module("autogen.tools.experimental.browser_use")
@export_module("autogen.tools.experimental.browser_use")
class BrowserUseResult(BaseModel):
"""The result of using the browser to perform a task.

@@ -33,19 +34,8 @@ class BrowserUseResult(BaseModel):
final_result: Optional[str]


T = TypeVar("T")


def on(x: T) -> Callable[[], T]:
def inner(_x: T = x) -> T:
return _x

return inner


@require_optional_import(["langchain_openai", "browser_use"], "browser-use")
# todo: add export_module decorator
# @export_module("autogen.tools.experimental")
@export_module("autogen.tools.experimental")
class BrowserUseTool(Tool):
"""BrowserUseTool is a tool that uses the browser to perform a task."""

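With the previously commented-out decorators now active, both classes surface under their documented public paths; a minimal sketch (assuming the browser-use optional dependencies are installed):

from autogen.tools.experimental import BrowserUseTool
from autogen.tools.experimental.browser_use import BrowserUseResult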
7 changes: 7 additions & 0 deletions autogen/tools/experimental/crawl4ai/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0

from .crawl4ai import Crawl4AITool

__all__ = ["Crawl4AITool"]
173 changes: 173 additions & 0 deletions autogen/tools/experimental/crawl4ai/crawl4ai.py
@@ -0,0 +1,173 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Annotated, Any, Optional, Type

from pydantic import BaseModel

from ....doc_utils import export_module
from ....import_utils import optional_import_block, require_optional_import
from ... import Tool
from ...dependency_injection import Depends, on

with optional_import_block():
from crawl4ai import AsyncWebCrawler, BrowserConfig, CacheMode, CrawlerRunConfig
from crawl4ai.extraction_strategy import LLMExtractionStrategy

__all__ = ["Crawl4AITool"]


@require_optional_import(["crawl4ai"], "crawl4ai")
@export_module("autogen.tools.experimental")
class Crawl4AITool(Tool):
"""
Crawl a website and extract information using the crawl4ai library.
"""

def __init__(
self,
llm_config: Optional[dict[str, Any]] = None,
extraction_model: Optional[Type[BaseModel]] = None,
llm_strategy_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
Initialize the Crawl4AITool.

Args:
llm_config: The config dictionary for the LLM model. If None, the tool will run without LLM.
extraction_model: The Pydantic model to use for extraction. If None, the tool will use the default schema.
llm_strategy_kwargs: The keyword arguments to pass to the LLM extraction strategy.
"""
Crawl4AITool._validate_llm_strategy_kwargs(llm_strategy_kwargs, llm_config_provided=(llm_config is not None))

async def crawl4ai_helper( # type: ignore[no-any-unimported]
url: str,
browser_cfg: Optional["BrowserConfig"] = None,
crawl_config: Optional["CrawlerRunConfig"] = None,
) -> Any:
async with AsyncWebCrawler(config=browser_cfg) as crawler:
result = await crawler.arun(
url=url,
config=crawl_config,
)

if crawl_config is None:
response = result.markdown
else:
response = result.extracted_content if result.success else result.error_message

return response

async def crawl4ai_without_llm(
url: Annotated[str, "The url to crawl and extract information from."],
) -> Any:
return await crawl4ai_helper(url=url)

async def crawl4ai_with_llm(
url: Annotated[str, "The url to crawl and extract information from."],
instruction: Annotated[str, "The instruction to provide on how and what to extract."],
llm_config: Annotated[dict[str, Any], Depends(on(llm_config))],
llm_strategy_kwargs: Annotated[Optional[dict[str, Any]], Depends(on(llm_strategy_kwargs))],
extraction_model: Annotated[Optional[Type[BaseModel]], Depends(on(extraction_model))],
) -> Any:
browser_cfg = BrowserConfig(headless=True)
crawl_config = Crawl4AITool._get_crawl_config(
llm_config=llm_config,
instruction=instruction,
extraction_model=extraction_model,
llm_strategy_kwargs=llm_strategy_kwargs,
)

return await crawl4ai_helper(url=url, browser_cfg=browser_cfg, crawl_config=crawl_config)

super().__init__(
name="crawl4ai",
description="Crawl a website and extract information.",
func_or_tool=crawl4ai_without_llm if llm_config is None else crawl4ai_with_llm,
)
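
        # Reader's note (not part of the diff): the Depends(on(...)) annotations
        # above hide llm_config, llm_strategy_kwargs, and extraction_model from
        # the tool's visible signature, so the LLM-facing schema is `url` alone
        # for crawl4ai_without_llm and `url` plus `instruction` for
        # crawl4ai_with_llm.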

@staticmethod
def _validate_llm_strategy_kwargs(llm_strategy_kwargs: Optional[dict[str, Any]], llm_config_provided: bool) -> None:
if not llm_strategy_kwargs:
return

if not llm_config_provided:
raise ValueError("llm_strategy_kwargs can only be provided if llm_config is also provided.")

check_parameters_error_msg = "".join(
f"'{key}' should not be provided in llm_strategy_kwargs. It is automatically set based on llm_config.\n"
for key in ["provider", "api_token"]
if key in llm_strategy_kwargs
)

        if "schema" in llm_strategy_kwargs:
            check_parameters_error_msg += (
                "'schema' should not be provided in llm_strategy_kwargs. It is automatically set based on extraction_model type.\n"
            )

        if "instruction" in llm_strategy_kwargs:
            check_parameters_error_msg += (
                "'instruction' should not be provided in llm_strategy_kwargs. It is provided at the time of calling the tool.\n"
            )

if check_parameters_error_msg:
raise ValueError(check_parameters_error_msg)
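
    # Illustration (not part of the diff): reserved keys fail fast. For example,
    #
    #     Crawl4AITool(
    #         llm_config={"config_list": [...]},
    #         llm_strategy_kwargs={"provider": "openai/gpt-4o"},
    #     )
    #
    # raises ValueError: 'provider' should not be provided in
    # llm_strategy_kwargs. It is automatically set based on llm_config.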

@staticmethod
def _get_provider_and_api_key(llm_config: dict[str, Any]) -> tuple[str, str]:
if "config_list" not in llm_config:
if "model" in llm_config:
model = llm_config["model"]
api_type = "openai"
api_key = os.getenv("OPENAI_API_KEY")
raise ValueError("llm_config must be a valid config dictionary.")
else:
try:
model = llm_config["config_list"][0]["model"]
api_type = llm_config["config_list"][0].get("api_type", "openai")
api_key = llm_config["config_list"][0]["api_key"]

except (KeyError, TypeError):
raise ValueError("llm_config must be a valid config dictionary.")

provider = f"{api_type}/{model}"
return provider, api_key # type: ignore[return-value]
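
    # Worked example (illustrative values, not from this PR): given
    # llm_config = {"config_list": [{"model": "gpt-4o-mini", "api_type": "openai", "api_key": "sk-..."}]}
    # this returns ("openai/gpt-4o-mini", "sk-...").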

@staticmethod
def _get_crawl_config( # type: ignore[no-any-unimported]
llm_config: dict[str, Any],
instruction: str,
llm_strategy_kwargs: Optional[dict[str, Any]] = None,
extraction_model: Optional[Type[BaseModel]] = None,
) -> "CrawlerRunConfig":
provider, api_key = Crawl4AITool._get_provider_and_api_key(llm_config)

if llm_strategy_kwargs is None:
llm_strategy_kwargs = {}

schema = (
extraction_model.model_json_schema()
if (extraction_model and issubclass(extraction_model, BaseModel))
else None
)

extraction_type = llm_strategy_kwargs.pop("extraction_type", "schema" if schema else "block")

# 1. Define the LLM extraction strategy
llm_strategy = LLMExtractionStrategy(
provider=provider,
api_token=api_key,
schema=schema,
extraction_type=extraction_type,
instruction=instruction,
**llm_strategy_kwargs,
)

# 2. Build the crawler config
crawl_config = CrawlerRunConfig(extraction_strategy=llm_strategy, cache_mode=CacheMode.BYPASS)

return crawl_config
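
Taken together, a hedged usage sketch (the model name, placeholder API key, and Product schema are illustrative, not from this PR):

from pydantic import BaseModel

from autogen.tools.experimental import Crawl4AITool

class Product(BaseModel):  # hypothetical schema for structured extraction
    name: str
    price: str

llm_config = {"config_list": [{"model": "gpt-4o-mini", "api_type": "openai", "api_key": "sk-..."}]}

# Without llm_config, the tool wraps crawl4ai_without_llm (url -> raw markdown).
plain_tool = Crawl4AITool()

# With llm_config (plus optional extraction_model and llm_strategy_kwargs),
# it wraps crawl4ai_with_llm (url + instruction -> extracted content).
llm_tool = Crawl4AITool(
    llm_config=llm_config,
    extraction_model=Product,
    llm_strategy_kwargs={"extraction_type": "schema"},
)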