From e96d4267e3ac4f4fb3eeb98c7d7fda23873c8f05 Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra Date: Mon, 29 Sep 2025 12:21:20 +0200 Subject: [PATCH] feat: add scrapegraph --- crewai_tools/tools/__init__.py | 4 + .../tools/scrapegraph_scrape_tool/README.md | 302 ++++++++++++--- .../tools/scrapegraph_scrape_tool/__init__.py | 17 + .../scrapegraph_scrape_tool.py | 267 +++++++++++-- pyproject.toml | 2 +- tests/tools/test_scrapegraph_scrape_tool.py | 352 ++++++++++++++++++ 6 files changed, 873 insertions(+), 71 deletions(-) create mode 100644 crewai_tools/tools/scrapegraph_scrape_tool/__init__.py create mode 100644 tests/tools/test_scrapegraph_scrape_tool.py diff --git a/crewai_tools/tools/__init__.py b/crewai_tools/tools/__init__.py index 2b0bb968..c29b71c9 100644 --- a/crewai_tools/tools/__init__.py +++ b/crewai_tools/tools/__init__.py @@ -86,6 +86,10 @@ from .scrapegraph_scrape_tool.scrapegraph_scrape_tool import ( ScrapegraphScrapeTool, ScrapegraphScrapeToolSchema, + FixedScrapegraphScrapeToolSchema, + ScrapeMethod, + ScrapegraphError, + RateLimitError, ) from .scrapfly_scrape_website_tool.scrapfly_scrape_website_tool import ( ScrapflyScrapeWebsiteTool, diff --git a/crewai_tools/tools/scrapegraph_scrape_tool/README.md b/crewai_tools/tools/scrapegraph_scrape_tool/README.md index e006c0ff..f2ea2c55 100644 --- a/crewai_tools/tools/scrapegraph_scrape_tool/README.md +++ b/crewai_tools/tools/scrapegraph_scrape_tool/README.md @@ -1,84 +1,296 @@ -# ScrapegraphScrapeTool +# ScrapeGraph AI Multi-Method Scraper Tool ## Description -A tool that leverages Scrapegraph AI's SmartScraper API to intelligently extract content from websites. This tool provides advanced web scraping capabilities with AI-powered content extraction, making it ideal for targeted data collection and content analysis tasks. +A comprehensive CrewAI tool that integrates with ScrapeGraph AI to provide intelligent web scraping capabilities using multiple methods. This enhanced tool supports 6 different scraping approaches, from basic content extraction to complex interactive automation. + +## Features + +The tool supports 6 different scraping methods: + +### 1. SmartScraper (Default) +Intelligent content extraction using AI to understand and extract relevant information from web pages. + +```python +from crewai_tools import ScrapegraphScrapeTool, ScrapeMethod + +tool = ScrapegraphScrapeTool() +result = tool.run( + website_url="https://example.com", + method=ScrapeMethod.SMARTSCRAPER, + user_prompt="Extract company information" +) +``` + +### 2. SearchScraper +Search-based content gathering from multiple sources across the web. + +```python +result = tool.run( + method=ScrapeMethod.SEARCHSCRAPER, + user_prompt="Latest AI developments", + num_results=5 # 1-20 sources +) +``` + +### 3. AgenticScraper +Interactive scraping with automated actions like clicking buttons, filling forms, etc. + +```python +result = tool.run( + website_url="https://example.com", + method=ScrapeMethod.AGENTICSCRAPER, + steps=[ + "Type email@example.com in email field", + "Type password123 in password field", + "Click login button" + ], + use_session=True, + ai_extraction=True, + user_prompt="Extract dashboard information" +) +``` + +### 4. Crawl +Multi-page crawling with depth control and domain restrictions. + +```python +result = tool.run( + website_url="https://example.com", + method=ScrapeMethod.CRAWL, + user_prompt="Extract all product information", + depth=2, + max_pages=10, + same_domain_only=True, + cache_website=True +) +``` + +### 5. 
Scrape +Raw HTML extraction with JavaScript rendering support. + +```python +result = tool.run( + website_url="https://example.com", + method=ScrapeMethod.SCRAPE, + render_heavy_js=True, + headers={"User-Agent": "Custom Agent"} +) +``` + +### 6. Markdownify +Convert web content to markdown format. + +```python +result = tool.run( + website_url="https://example.com", + method=ScrapeMethod.MARKDOWNIFY +) +``` ## Installation Install the required packages: ```shell pip install 'crewai[tools]' +pip install scrapegraph-py +``` + +## Schema Support + +All methods support structured data extraction using JSON schemas: + +```python +schema = { + "type": "object", + "properties": { + "title": {"type": "string"}, + "content": {"type": "string"}, + "authors": { + "type": "array", + "items": {"type": "string"} + } + } +} + +result = tool.run( + website_url="https://example.com", + method=ScrapeMethod.SMARTSCRAPER, + data_schema=schema +) +``` + +## Configuration Options + +- `method`: Scraping method (ScrapeMethod enum) +- `render_heavy_js`: Enable JavaScript rendering for dynamic content +- `headers`: Custom HTTP headers +- `data_schema`: JSON schema for structured data extraction +- `steps`: Action steps for agentic scraping (required for AgenticScraper) +- `num_results`: Number of search results (1-20, for SearchScraper) +- `depth`: Crawling depth (1-5, for Crawl) +- `max_pages`: Maximum pages to crawl +- `same_domain_only`: Restrict crawling to same domain +- `cache_website`: Cache content for faster subsequent requests +- `use_session`: Maintain session state for agentic scraping +- `ai_extraction`: Enable AI extraction for agentic scraping +- `timeout`: Request timeout (10-600 seconds) + +## Setup + +1. Set your API key: +```bash +export SCRAPEGRAPH_API_KEY="your-api-key-here" ``` -## Example Usage +Or use a `.env` file: +``` +SCRAPEGRAPH_API_KEY=your-api-key-here +``` -### Basic Usage +2. 
Initialize the tool: ```python -from crewai_tools import ScrapegraphScrapeTool +from crewai_tools import ScrapegraphScrapeTool, ScrapeMethod + +# Basic initialization +tool = ScrapegraphScrapeTool() -# Basic usage with API key -tool = ScrapegraphScrapeTool(api_key="your_api_key") +# With fixed URL +tool = ScrapegraphScrapeTool(website_url="https://example.com") + +# With custom API key +tool = ScrapegraphScrapeTool(api_key="your-api-key") +``` + +## Advanced Examples + +### Interactive Form Automation +```python result = tool.run( - website_url="https://www.example.com", - user_prompt="Extract the main heading and summary" + website_url="https://dashboard.example.com", + method=ScrapeMethod.AGENTICSCRAPER, + steps=[ + "Type username@email.com in the email input field", + "Type mypassword in the password input field", + "Click the login button", + "Wait for the dashboard to load", + "Click on the profile section" + ], + use_session=True, + ai_extraction=True, + user_prompt="Extract user profile information and account details" ) ``` -### Fixed Website URL +### Multi-Source Research ```python -# Initialize with a fixed website URL -tool = ScrapegraphScrapeTool( - website_url="https://www.example.com", - api_key="your_api_key" +result = tool.run( + method=ScrapeMethod.SEARCHSCRAPER, + user_prompt="Latest developments in web scraping technology and tools", + num_results=10 ) -result = tool.run() +print(f"Research findings: {result['result']}") +print(f"Sources: {result['reference_urls']}") ``` -### Custom Prompt +### Comprehensive Website Crawling ```python -# With custom prompt -tool = ScrapegraphScrapeTool( - api_key="your_api_key", - user_prompt="Extract all product prices and descriptions" +result = tool.run( + website_url="https://company.com", + method=ScrapeMethod.CRAWL, + user_prompt="Extract all product information, pricing, and company details", + depth=3, + max_pages=20, + same_domain_only=True, + cache_website=True, + data_schema={ + "type": "object", + "properties": { + "products": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "price": {"type": "string"}, + "description": {"type": "string"} + } + } + }, + "company_info": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "description": {"type": "string"}, + "contact": {"type": "string"} + } + } + } + } ) -result = tool.run(website_url="https://www.example.com") ``` -### Error Handling +## Error Handling + +The tool handles various error conditions: + ```python +from crewai_tools import ScrapegraphScrapeTool, ScrapegraphError, RateLimitError + try: - tool = ScrapegraphScrapeTool(api_key="your_api_key") + tool = ScrapegraphScrapeTool() result = tool.run( - website_url="https://www.example.com", - user_prompt="Extract the main heading" + website_url="https://example.com", + method=ScrapeMethod.SMARTSCRAPER ) except ValueError as e: - print(f"Configuration error: {e}") # Handles invalid URLs or missing API keys + print(f"Configuration error: {e}") # Invalid parameters or missing API key +except RateLimitError as e: + print(f"Rate limit exceeded: {e}") # API rate limits exceeded +except ScrapegraphError as e: + print(f"ScrapeGraph API error: {e}") # General API errors except RuntimeError as e: - print(f"Scraping error: {e}") # Handles API or network errors + print(f"Scraping operation failed: {e}") # Other runtime errors ``` -## Arguments -- `website_url`: The URL of the website to scrape (required if not set during initialization) -- `user_prompt`: Custom 
instructions for content extraction (optional) -- `api_key`: Your Scrapegraph API key (required, can be set via SCRAPEGRAPH_API_KEY environment variable) - ## Environment Variables -- `SCRAPEGRAPH_API_KEY`: Your Scrapegraph API key, you can obtain one [here](https://scrapegraphai.com) +- `SCRAPEGRAPH_API_KEY`: Your ScrapeGraph API key, you can obtain one [here](https://scrapegraphai.com) ## Rate Limiting -The Scrapegraph API has rate limits that vary based on your subscription plan. Consider the following best practices: +The ScrapeGraph API has rate limits that vary based on your subscription plan. Consider the following best practices: - Implement appropriate delays between requests when processing multiple URLs - Handle rate limit errors gracefully in your application -- Check your API plan limits on the Scrapegraph dashboard - -## Error Handling -The tool may raise the following exceptions: -- `ValueError`: When API key is missing or URL format is invalid -- `RuntimeError`: When scraping operation fails (network issues, API errors) -- `RateLimitError`: When API rate limits are exceeded +- Check your API plan limits on the ScrapeGraph dashboard +- Use caching for frequently accessed content ## Best Practices -1. Always validate URLs before making requests -2. Implement proper error handling as shown in examples -3. Consider caching results for frequently accessed pages -4. Monitor your API usage through the Scrapegraph dashboard + +1. **Method Selection**: Choose the appropriate method for your use case: + - Use `SmartScraper` for general content extraction + - Use `SearchScraper` for research across multiple sources + - Use `AgenticScraper` for sites requiring interaction + - Use `Crawl` for comprehensive site mapping + - Use `Scrape` for raw HTML when you need full control + - Use `Markdownify` for content formatting + +2. **Schema Design**: When using `data_schema`, design clear, specific schemas for better extraction results + +3. **Session Management**: Use `use_session=True` for `AgenticScraper` when you need to maintain state across interactions + +4. **Performance**: Enable `cache_website=True` for crawling operations to improve performance + +5. **Error Handling**: Always implement proper error handling as shown in examples + +6. **Validation**: Validate URLs and parameters before making requests + +7. **Monitoring**: Monitor your API usage through the ScrapeGraph dashboard + +## Examples + +See `examples/scrapegraph_tool_examples.py` for complete working examples of all methods. 
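+### Using the Tool from a CrewAI Agent
+
+The snippet below is a minimal sketch of wiring the scraper into a standard CrewAI `Agent`/`Task`/`Crew` setup. The agent role, goal, task description, and the target URL are illustrative placeholders rather than part of this tool; the `SCRAPEGRAPH_API_KEY` environment variable must be set as described in the Setup section.
+
+```python
+from crewai import Agent, Task, Crew
+from crewai_tools import ScrapegraphScrapeTool
+
+# The API key is read from the SCRAPEGRAPH_API_KEY environment variable.
+# Fixing the URL here means the agent only supplies prompts and options at call time.
+scrape_tool = ScrapegraphScrapeTool(website_url="https://example.com")
+
+researcher = Agent(
+    role="Web Researcher",  # illustrative agent definition
+    goal="Summarize the key offerings listed on the target site",
+    backstory="An analyst who extracts structured facts from web pages",
+    tools=[scrape_tool],
+)
+
+task = Task(
+    description="Extract the main products and their prices from the site",
+    expected_output="A bullet list of products with prices",
+    agent=researcher,
+)
+
+crew = Crew(agents=[researcher], tasks=[task])
+print(crew.kickoff())
+```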
+ +## API Reference + +- **ScrapeMethod**: Enum of available scraping methods (SMARTSCRAPER, SEARCHSCRAPER, AGENTICSCRAPER, CRAWL, SCRAPE, MARKDOWNIFY) +- **ScrapegraphScrapeToolSchema**: Input validation schema for flexible URL usage +- **FixedScrapegraphScrapeToolSchema**: Schema for tools with fixed URLs +- **ScrapegraphError**: Base exception class for ScrapeGraph-related errors +- **RateLimitError**: Specialized exception for rate limiting scenarios \ No newline at end of file diff --git a/crewai_tools/tools/scrapegraph_scrape_tool/__init__.py b/crewai_tools/tools/scrapegraph_scrape_tool/__init__.py new file mode 100644 index 00000000..cf3d6080 --- /dev/null +++ b/crewai_tools/tools/scrapegraph_scrape_tool/__init__.py @@ -0,0 +1,17 @@ +from .scrapegraph_scrape_tool import ( + ScrapegraphScrapeTool, + ScrapegraphScrapeToolSchema, + FixedScrapegraphScrapeToolSchema, + ScrapeMethod, + ScrapegraphError, + RateLimitError, +) + +__all__ = [ + "ScrapegraphScrapeTool", + "ScrapegraphScrapeToolSchema", + "FixedScrapegraphScrapeToolSchema", + "ScrapeMethod", + "ScrapegraphError", + "RateLimitError", +] \ No newline at end of file diff --git a/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py b/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py index 34f42e52..04e8c7d0 100644 --- a/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py +++ b/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py @@ -1,6 +1,8 @@ import os -from typing import TYPE_CHECKING, Any, Optional, Type, List +import time +from typing import TYPE_CHECKING, Any, Optional, Type, List, Dict, Union from urllib.parse import urlparse +from enum import Enum from crewai.tools import BaseTool, EnvVar from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -18,18 +20,81 @@ class RateLimitError(ScrapegraphError): """Raised when API rate limits are exceeded""" +class ScrapeMethod(str, Enum): + """Available scraping methods""" + SMARTSCRAPER = "smartscraper" + SEARCHSCRAPER = "searchscraper" + AGENTICSCRAPER = "agenticscraper" + CRAWL = "crawl" + SCRAPE = "scrape" + MARKDOWNIFY = "markdownify" + + class FixedScrapegraphScrapeToolSchema(BaseModel): """Input for ScrapegraphScrapeTool when website_url is fixed.""" + method: ScrapeMethod = Field( + default=ScrapeMethod.SMARTSCRAPER, + description="Scraping method to use" + ) + user_prompt: str = Field( + default="Extract the main content of the webpage", + description="Prompt to guide the extraction of content", + ) + render_heavy_js: bool = Field( + default=False, + description="Enable JavaScript rendering for dynamic content" + ) + headers: Optional[Dict[str, str]] = Field( + default=None, + description="Custom headers for the request" + ) + data_schema: Optional[Dict[str, Any]] = Field( + default=None, + description="JSON schema for structured data extraction" + ) + steps: Optional[List[str]] = Field( + default=None, + description="List of steps for agentic scraping (e.g., ['click button', 'fill form'])" + ) + num_results: Optional[int] = Field( + default=3, + description="Number of search results for searchscraper (3-20)" + ) + depth: Optional[int] = Field( + default=1, + description="Crawling depth for crawl method" + ) + max_pages: Optional[int] = Field( + default=10, + description="Maximum pages to crawl" + ) + same_domain_only: bool = Field( + default=True, + description="Only crawl pages from the same domain" + ) + cache_website: bool = Field( + default=False, + description="Cache website content for 
faster subsequent requests" + ) + use_session: bool = Field( + default=False, + description="Use session for agentic scraping to maintain state" + ) + ai_extraction: bool = Field( + default=True, + description="Enable AI extraction for agentic scraping" + ) + timeout: Optional[int] = Field( + default=300, + description="Request timeout in seconds (max 600 for crawl operations)" + ) + class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema): """Input for ScrapegraphScrapeTool.""" website_url: str = Field(..., description="Mandatory website url to scrape") - user_prompt: str = Field( - default="Extract the main content of the webpage", - description="Prompt to guide the extraction of content", - ) @field_validator("website_url") def validate_url(cls, v): @@ -44,10 +109,39 @@ def validate_url(cls, v): "Invalid URL format. URL must include scheme (http/https) and domain" ) + @field_validator("num_results") + def validate_num_results(cls, v): + """Validate number of results for searchscraper""" + if v is not None and (v < 1 or v > 20): + raise ValueError("num_results must be between 1 and 20") + return v + + @field_validator("depth") + def validate_depth(cls, v): + """Validate crawling depth""" + if v is not None and (v < 1 or v > 5): + raise ValueError("depth must be between 1 and 5") + return v + + @field_validator("timeout") + def validate_timeout(cls, v): + """Validate timeout""" + if v is not None and (v < 10 or v > 600): + raise ValueError("timeout must be between 10 and 600 seconds") + return v + class ScrapegraphScrapeTool(BaseTool): """ - A tool that uses Scrapegraph AI to intelligently scrape website content. + A comprehensive tool that uses ScrapeGraph AI to intelligently scrape website content. + + Supports multiple scraping methods: + - smartscraper: Basic intelligent content extraction + - searchscraper: Search-based content gathering from multiple sources + - agenticscraper: Interactive scraping with automated actions (clicking, typing, etc.) + - crawl: Multi-page crawling with depth control + - scrape: Raw HTML extraction with JS rendering support + - markdownify: Convert web content to markdown format Raises: ValueError: If API key is missing or URL format is invalid @@ -57,9 +151,11 @@ class ScrapegraphScrapeTool(BaseTool): model_config = ConfigDict(arbitrary_types_allowed=True) - name: str = "Scrapegraph website scraper" + name: str = "ScrapeGraph AI Multi-Method Scraper" description: str = ( - "A tool that uses Scrapegraph AI to intelligently scrape website content." + "A comprehensive scraping tool using ScrapeGraph AI. Supports smartscraper (intelligent extraction), " + "searchscraper (multi-source search), agenticscraper (interactive automation), crawl (multi-page), " + "scrape (raw HTML), and markdownify (markdown conversion) methods." ) args_schema: Type[BaseModel] = ScrapegraphScrapeToolSchema website_url: Optional[str] = None @@ -78,6 +174,7 @@ def __init__( user_prompt: Optional[str] = None, api_key: Optional[str] = None, enable_logging: bool = False, + method: ScrapeMethod = ScrapeMethod.SMARTSCRAPER, **kwargs, ): super().__init__(**kwargs) @@ -111,7 +208,7 @@ def __init__( if website_url is not None: self._validate_url(website_url) self.website_url = website_url - self.description = f"A tool that uses Scrapegraph AI to intelligently scrape {website_url}'s content." + self.description = f"A tool that uses ScrapeGraph AI to scrape {website_url}'s content using {method.value} method." 
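+            # A fixed website_url means callers should not pass a URL at run time,
+            # so switch to the schema variant that omits the required website_url field.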
self.args_schema = FixedScrapegraphScrapeToolSchema if user_prompt is not None: @@ -133,10 +230,10 @@ def _validate_url(url: str) -> None: "Invalid URL format. URL must include scheme (http/https) and domain" ) - def _handle_api_response(self, response: dict) -> str: + def _handle_api_response(self, response: dict, method: ScrapeMethod) -> Any: """Handle and validate API response""" if not response: - raise RuntimeError("Empty response from Scrapegraph API") + raise RuntimeError("Empty response from ScrapeGraph API") if "error" in response: error_msg = response.get("error", {}).get("message", "Unknown error") @@ -144,40 +241,160 @@ def _handle_api_response(self, response: dict) -> str: raise RateLimitError(f"Rate limit exceeded: {error_msg}") raise RuntimeError(f"API error: {error_msg}") + # Handle different response formats based on method + if method == ScrapeMethod.CRAWL: + # Crawl may return async operation ID + if "id" in response or "task_id" in response: + return self._handle_async_crawl(response) + elif "result" in response: + return response["result"] + elif method == ScrapeMethod.SEARCHSCRAPER: + # SearchScraper returns result and reference_urls + if "result" in response: + return { + "result": response["result"], + "reference_urls": response.get("reference_urls", []) + } + elif method == ScrapeMethod.SCRAPE: + # Scrape returns HTML content + if "html" in response: + return response["html"] + + # Default handling for other methods if "result" not in response: - raise RuntimeError("Invalid response format from Scrapegraph API") + raise RuntimeError("Invalid response format from ScrapeGraph API") return response["result"] + def _handle_async_crawl(self, initial_response: dict) -> Any: + """Handle asynchronous crawl operations""" + crawl_id = initial_response.get("id") or initial_response.get("task_id") + if not crawl_id: + return initial_response + + # Poll for result with timeout + max_iterations = 60 # 5 minutes with 5-second intervals + for i in range(max_iterations): + time.sleep(5) + try: + result = self._client.get_crawl(crawl_id) + status = result.get("status") + + if status == "success" and result.get("result"): + return result["result"].get("llm_result", result["result"]) + elif status == "failed": + raise RuntimeError(f"Crawl operation failed: {result.get('error', 'Unknown error')}") + elif status in ["completed", "finished"]: + return result.get("result", result) + + except Exception as e: + if i == max_iterations - 1: # Last iteration + raise RuntimeError(f"Failed to get crawl result: {str(e)}") + continue + + raise RuntimeError("Crawl operation timed out after 5 minutes") + def _run( self, **kwargs: Any, ) -> Any: website_url = kwargs.get("website_url", self.website_url) + method = kwargs.get("method", ScrapeMethod.SMARTSCRAPER) user_prompt = ( kwargs.get("user_prompt", self.user_prompt) or "Extract the main content of the webpage" ) - if not website_url: - raise ValueError("website_url is required") - - # Validate URL format - self._validate_url(website_url) + # Extract additional parameters + render_heavy_js = kwargs.get("render_heavy_js", False) + headers = kwargs.get("headers") + data_schema = kwargs.get("data_schema") + steps = kwargs.get("steps") + num_results = kwargs.get("num_results", 3) + depth = kwargs.get("depth", 1) + max_pages = kwargs.get("max_pages", 10) + same_domain_only = kwargs.get("same_domain_only", True) + cache_website = kwargs.get("cache_website", False) + use_session = kwargs.get("use_session", False) + ai_extraction = 
kwargs.get("ai_extraction", True) + timeout = kwargs.get("timeout", 300) + + # Validate required parameters based on method + if method != ScrapeMethod.SEARCHSCRAPER and not website_url: + raise ValueError("website_url is required for this method") + + if method == ScrapeMethod.AGENTICSCRAPER and not steps: + raise ValueError("steps parameter is required for agentic scraping") + + if website_url: + self._validate_url(website_url) try: - # Make the SmartScraper request - response = self._client.smartscraper( - website_url=website_url, - user_prompt=user_prompt, - ) + # Route to appropriate method + if method == ScrapeMethod.SMARTSCRAPER: + response = self._client.smartscraper( + website_url=website_url, + user_prompt=user_prompt, + data_schema=data_schema + ) + + elif method == ScrapeMethod.SEARCHSCRAPER: + response = self._client.searchscraper( + user_prompt=user_prompt, + num_results=num_results + ) + + elif method == ScrapeMethod.AGENTICSCRAPER: + agenticscraper_kwargs = { + "url": website_url, + "steps": steps, + "use_session": use_session, + "ai_extraction": ai_extraction + } + if ai_extraction: + agenticscraper_kwargs["user_prompt"] = user_prompt + if data_schema: + agenticscraper_kwargs["output_schema"] = data_schema + + response = self._client.agenticscraper(**agenticscraper_kwargs) + + elif method == ScrapeMethod.CRAWL: + crawl_kwargs = { + "url": website_url, + "prompt": user_prompt, + "depth": depth, + "max_pages": max_pages, + "same_domain_only": same_domain_only, + "cache_website": cache_website + } + if data_schema: + crawl_kwargs["data_schema"] = data_schema + + response = self._client.crawl(**crawl_kwargs) + + elif method == ScrapeMethod.SCRAPE: + scrape_kwargs = { + "website_url": website_url, + "render_heavy_js": render_heavy_js + } + if headers: + scrape_kwargs["headers"] = headers + + response = self._client.scrape(**scrape_kwargs) + + elif method == ScrapeMethod.MARKDOWNIFY: + response = self._client.markdownify(website_url=website_url) + + else: + raise ValueError(f"Unsupported scraping method: {method}") - return response + return self._handle_api_response(response, method) except RateLimitError: raise # Re-raise rate limit errors except Exception as e: - raise RuntimeError(f"Scraping failed: {str(e)}") + raise RuntimeError(f"Scraping failed with {method.value}: {str(e)}") finally: # Always close the client - self._client.close() + if hasattr(self, '_client') and self._client: + self._client.close() diff --git a/pyproject.toml b/pyproject.toml index 82c73f8e..08651975 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,7 @@ spider-client = [ "spider-client>=0.1.25", ] scrapegraph-py = [ - "scrapegraph-py>=1.9.0", + "scrapegraph-py>=1.31.0", ] linkup-sdk = [ "linkup-sdk>=0.2.2", diff --git a/tests/tools/test_scrapegraph_scrape_tool.py b/tests/tools/test_scrapegraph_scrape_tool.py new file mode 100644 index 00000000..35e5e9ce --- /dev/null +++ b/tests/tools/test_scrapegraph_scrape_tool.py @@ -0,0 +1,352 @@ +import os +import pytest +from unittest.mock import Mock, patch, MagicMock +from pydantic import ValidationError + +from crewai_tools.tools.scrapegraph_scrape_tool.scrapegraph_scrape_tool import ( + ScrapegraphScrapeTool, + ScrapegraphScrapeToolSchema, + FixedScrapegraphScrapeToolSchema, + ScrapeMethod, + ScrapegraphError, + RateLimitError, +) + + +class TestScrapegraphScrapeToolSchema: + """Test the schema validation""" + + def test_valid_url(self): + schema = ScrapegraphScrapeToolSchema( + website_url="https://example.com", + user_prompt="Test prompt" + ) 
+ assert schema.website_url == "https://example.com" + assert schema.user_prompt == "Test prompt" + + def test_invalid_url_format(self): + with pytest.raises(ValidationError, match="Invalid URL format"): + ScrapegraphScrapeToolSchema(website_url="not-a-url") + + def test_invalid_num_results(self): + with pytest.raises(ValidationError, match="num_results must be between 1 and 20"): + ScrapegraphScrapeToolSchema( + website_url="https://example.com", + num_results=25 + ) + + def test_invalid_depth(self): + with pytest.raises(ValidationError, match="depth must be between 1 and 5"): + ScrapegraphScrapeToolSchema( + website_url="https://example.com", + depth=10 + ) + + def test_invalid_timeout(self): + with pytest.raises(ValidationError, match="timeout must be between 10 and 600"): + ScrapegraphScrapeToolSchema( + website_url="https://example.com", + timeout=1000 + ) + + def test_default_values(self): + schema = ScrapegraphScrapeToolSchema(website_url="https://example.com") + assert schema.method == ScrapeMethod.SMARTSCRAPER + assert schema.render_heavy_js is False + assert schema.num_results == 3 + assert schema.depth == 1 + assert schema.same_domain_only is True + + +class TestScrapegraphScrapeTool: + """Test the main tool functionality""" + + @pytest.fixture + def mock_client(self): + with patch('crewai_tools.tools.scrapegraph_scrape_tool.scrapegraph_scrape_tool.Client') as mock: + client_instance = Mock() + mock.return_value = client_instance + yield client_instance + + @pytest.fixture + def tool_with_api_key(self, mock_client): + with patch.dict(os.environ, {'SCRAPEGRAPH_API_KEY': 'test-api-key'}): + return ScrapegraphScrapeTool() + + def test_initialization_without_api_key(self, mock_client): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="Scrapegraph API key is required"): + ScrapegraphScrapeTool() + + def test_initialization_with_api_key(self, mock_client): + with patch.dict(os.environ, {'SCRAPEGRAPH_API_KEY': 'test-key'}): + tool = ScrapegraphScrapeTool() + assert tool.api_key == 'test-key' + + def test_initialization_with_fixed_url(self, mock_client): + with patch.dict(os.environ, {'SCRAPEGRAPH_API_KEY': 'test-key'}): + tool = ScrapegraphScrapeTool(website_url="https://example.com") + assert tool.website_url == "https://example.com" + assert tool.args_schema == FixedScrapegraphScrapeToolSchema + + def test_invalid_url_initialization(self, mock_client): + with patch.dict(os.environ, {'SCRAPEGRAPH_API_KEY': 'test-key'}): + with pytest.raises(ValueError, match="Invalid URL format"): + ScrapegraphScrapeTool(website_url="invalid-url") + + def test_smartscraper_method(self, tool_with_api_key, mock_client): + mock_client.smartscraper.return_value = {"result": "Extracted content"} + + result = tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.SMARTSCRAPER, + user_prompt="Extract content" + ) + + assert result == "Extracted content" + mock_client.smartscraper.assert_called_once_with( + website_url="https://example.com", + user_prompt="Extract content", + data_schema=None + ) + mock_client.close.assert_called_once() + + def test_searchscraper_method(self, tool_with_api_key, mock_client): + mock_client.searchscraper.return_value = { + "result": "Search results", + "reference_urls": ["https://source1.com", "https://source2.com"] + } + + result = tool_with_api_key.run( + method=ScrapeMethod.SEARCHSCRAPER, + user_prompt="Search for information", + num_results=5 + ) + + expected = { + "result": "Search results", + 
"reference_urls": ["https://source1.com", "https://source2.com"] + } + assert result == expected + mock_client.searchscraper.assert_called_once_with( + user_prompt="Search for information", + num_results=5 + ) + + def test_agenticscraper_method(self, tool_with_api_key, mock_client): + mock_client.agenticscraper.return_value = {"result": "Agentic result"} + + steps = ["Click login button", "Enter credentials"] + result = tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.AGENTICSCRAPER, + steps=steps, + use_session=True, + ai_extraction=True, + user_prompt="Extract after login" + ) + + assert result == "Agentic result" + mock_client.agenticscraper.assert_called_once_with( + url="https://example.com", + steps=steps, + use_session=True, + ai_extraction=True, + user_prompt="Extract after login" + ) + + def test_agenticscraper_missing_steps(self, tool_with_api_key): + with pytest.raises(ValueError, match="steps parameter is required for agentic scraping"): + tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.AGENTICSCRAPER + ) + + def test_crawl_method(self, tool_with_api_key, mock_client): + mock_client.crawl.return_value = {"result": "Crawl result"} + + result = tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.CRAWL, + user_prompt="Crawl website", + depth=2, + max_pages=5, + same_domain_only=True + ) + + assert result == "Crawl result" + mock_client.crawl.assert_called_once_with( + url="https://example.com", + prompt="Crawl website", + depth=2, + max_pages=5, + same_domain_only=True, + cache_website=False + ) + + def test_scrape_method(self, tool_with_api_key, mock_client): + mock_client.scrape.return_value = {"html": "Content"} + + result = tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.SCRAPE, + render_heavy_js=True, + headers={"User-Agent": "test"} + ) + + assert result == "Content" + mock_client.scrape.assert_called_once_with( + website_url="https://example.com", + render_heavy_js=True, + headers={"User-Agent": "test"} + ) + + def test_markdownify_method(self, tool_with_api_key, mock_client): + mock_client.markdownify.return_value = {"result": "# Markdown content"} + + result = tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.MARKDOWNIFY + ) + + assert result == "# Markdown content" + mock_client.markdownify.assert_called_once_with(website_url="https://example.com") + + def test_missing_website_url(self, tool_with_api_key): + with pytest.raises(ValueError, match="website_url is required for this method"): + tool_with_api_key.run(method=ScrapeMethod.SMARTSCRAPER) + + def test_invalid_method(self, tool_with_api_key): + with pytest.raises(ValueError, match="Unsupported scraping method"): + tool_with_api_key.run( + website_url="https://example.com", + method="invalid_method" + ) + + def test_rate_limit_error(self, tool_with_api_key, mock_client): + mock_client.smartscraper.return_value = { + "error": {"message": "Rate limit exceeded"} + } + + with pytest.raises(RateLimitError, match="Rate limit exceeded"): + tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.SMARTSCRAPER + ) + + def test_api_error(self, tool_with_api_key, mock_client): + mock_client.smartscraper.return_value = { + "error": {"message": "API error occurred"} + } + + with pytest.raises(RuntimeError, match="API error: API error occurred"): + tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.SMARTSCRAPER + ) 
+ + def test_empty_response(self, tool_with_api_key, mock_client): + mock_client.smartscraper.return_value = {} + + with pytest.raises(RuntimeError, match="Empty response from ScrapeGraph API"): + tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.SMARTSCRAPER + ) + + def test_missing_result_in_response(self, tool_with_api_key, mock_client): + mock_client.smartscraper.return_value = {"status": "success"} + + with pytest.raises(RuntimeError, match="Invalid response format"): + tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.SMARTSCRAPER + ) + + def test_async_crawl_success(self, tool_with_api_key, mock_client): + # Simulate async crawl with polling + mock_client.crawl.return_value = {"id": "task123"} + mock_client.get_crawl.return_value = { + "status": "success", + "result": {"llm_result": "Final crawl result"} + } + + with patch('time.sleep'): # Mock sleep to speed up test + result = tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.CRAWL, + user_prompt="Crawl website" + ) + + assert result == "Final crawl result" + mock_client.get_crawl.assert_called_with("task123") + + def test_async_crawl_failure(self, tool_with_api_key, mock_client): + mock_client.crawl.return_value = {"id": "task123"} + mock_client.get_crawl.return_value = { + "status": "failed", + "error": "Crawl failed" + } + + with patch('time.sleep'): + with pytest.raises(RuntimeError, match="Crawl operation failed: Crawl failed"): + tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.CRAWL + ) + + def test_client_close_called_on_exception(self, tool_with_api_key, mock_client): + mock_client.smartscraper.side_effect = Exception("Network error") + + with pytest.raises(RuntimeError, match="Scraping failed"): + tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.SMARTSCRAPER + ) + + mock_client.close.assert_called_once() + + def test_schema_with_data_schema(self, tool_with_api_key, mock_client): + mock_client.smartscraper.return_value = {"result": "Structured data"} + + data_schema = { + "type": "object", + "properties": { + "title": {"type": "string"}, + "content": {"type": "string"} + } + } + + result = tool_with_api_key.run( + website_url="https://example.com", + method=ScrapeMethod.SMARTSCRAPER, + data_schema=data_schema + ) + + assert result == "Structured data" + mock_client.smartscraper.assert_called_once_with( + website_url="https://example.com", + user_prompt="Extract the main content of the webpage", + data_schema=data_schema + ) + + +class TestScrapeMethod: + """Test the ScrapeMethod enum""" + + def test_all_methods_defined(self): + expected_methods = [ + "smartscraper", + "searchscraper", + "agenticscraper", + "crawl", + "scrape", + "markdownify" + ] + + actual_methods = [method.value for method in ScrapeMethod] + assert set(expected_methods) == set(actual_methods) + + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file